diff --git a/util/bigint/src/uint.rs b/util/bigint/src/uint.rs index ad4f0a99c..47a975d5f 100644 --- a/util/bigint/src/uint.rs +++ b/util/bigint/src/uint.rs @@ -71,29 +71,19 @@ macro_rules! uint_overflowing_add_reg { ($name:ident, $n_words:expr, $self_expr: expr, $other: expr) => ({ let $name(ref me) = $self_expr; let $name(ref you) = $other; + let mut ret = [0u64; $n_words]; - let mut carry = [0u64; $n_words]; - let mut b_carry = false; - let mut overflow = false; + let mut carry = [0u64; $n_words + 1]; for i in 0..$n_words { - ret[i] = me[i].wrapping_add(you[i]); + let (res1, overflow1) = me[i].overflowing_add(you[i]); + let (res2, overflow2) = res1.overflowing_add(carry[i]); - if ret[i] < me[i] { - if i < $n_words - 1 { - carry[i + 1] = 1; - b_carry = true; - } else { - overflow = true; - } - } - } - if b_carry { - let ret = overflowing!($name(ret).overflowing_add($name(carry)), overflow); - (ret, overflow) - } else { - ($name(ret), overflow) + ret[i] = res2; + carry[i+1] = overflow1 as u64 + overflow2 as u64; } + + ($name(ret), carry[$n_words] > 0) }) } @@ -673,37 +663,10 @@ macro_rules! construct_uint { } fn overflowing_shl(self, shift32: u32) -> ($name, bool) { - let $name(ref original) = self; - let mut ret = [0u64; $n_words]; let shift = shift32 as usize; - let word_shift = shift / 64; - let bit_shift = shift % 64; - for i in 0..$n_words { - // Shift - if i + word_shift < $n_words { - ret[i + word_shift] += original[i] << bit_shift; - } - // Carry - if bit_shift > 0 && i + word_shift + 1 < $n_words { - ret[i + word_shift + 1] += original[i] >> (64 - bit_shift); - } - } - // Detecting overflow - let last = $n_words - word_shift - if bit_shift > 0 { 1 } else { 0 }; - let overflow = if bit_shift > 0 { - (original[last] >> (64 - bit_shift)) > 0 - } else if word_shift > 0 { - original[last] > 0 - } else { - false - }; - for i in last+1..$n_words-1 { - if original[i] > 0 { - return ($name(ret), true); - } - } - ($name(ret), overflow) + let res = self << shift; + (res, self != (res >> shift)) } } @@ -987,14 +950,15 @@ macro_rules! construct_uint { let mut ret = [0u64; $n_words]; let word_shift = shift / 64; let bit_shift = shift % 64; - for i in 0..$n_words { - // Shift - if i + word_shift < $n_words { - ret[i + word_shift] += original[i] << bit_shift; - } - // Carry - if bit_shift > 0 && i + word_shift + 1 < $n_words { - ret[i + word_shift + 1] += original[i] >> (64 - bit_shift); + + // shift + for i in word_shift..$n_words { + ret[i] += original[i - word_shift] << bit_shift; + } + // carry + if bit_shift > 0 { + for i in word_shift+1..$n_words { + ret[i] += original[i - 1 - word_shift] >> (64 - bit_shift); } } $name(ret) @@ -1672,6 +1636,11 @@ mod tests { #[test] pub fn uint256_shl_overflow() { + assert_eq!( + U256::from_str("7fffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff").unwrap() + << 4, + U256::from_str("fffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff0").unwrap() + ); assert_eq!( U256::from_str("7fffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff").unwrap() .overflowing_shl(4), @@ -1681,6 +1650,16 @@ mod tests { #[test] pub fn uint256_shl_overflow_words() { + assert_eq!( + U256::from_str("0000000000000001ffffffffffffffffffffffffffffffffffffffffffffffff").unwrap() + << 64, + U256::from_str("ffffffffffffffffffffffffffffffffffffffffffffffff0000000000000000").unwrap() + ); + assert_eq!( + U256::from_str("0000000000000000ffffffffffffffffffffffffffffffffffffffffffffffff").unwrap() + << 64, + U256::from_str("ffffffffffffffffffffffffffffffffffffffffffffffff0000000000000000").unwrap() + ); assert_eq!( U256::from_str("0000000000000001ffffffffffffffffffffffffffffffffffffffffffffffff").unwrap() .overflowing_shl(64),