Implementing mul and full_mul

This commit is contained in:
Tomasz Drwięga 2016-03-08 01:13:00 +01:00
parent 76865694ce
commit 17b2d2a2d7

View File

@ -390,27 +390,47 @@ macro_rules! uint_overflowing_mul {
macro_rules! uint_overflowing_mul_reg {
($name:ident, $n_words: expr, $self_expr: expr, $other: expr) => ({
let mut res = $name::from(0u64);
let mut overflow = false;
let $name(ref me) = $self_expr;
let $name(ref you) = $other;
let mut ret = [0u64; 2*$n_words];
let mut current = $other;
let mut current_shift = 0;
let mut current_u32;
let mut i = 0;
for i in 0..$n_words {
let mut carry2 = 0u64;
let (b_u, b_l) = (you[i] >> 32, you[i] & 0xFFFFFFFF);
while i < 2*$n_words {
current_u32 = current.low_u32();
for j in 0..$n_words {
let a = me[j];
let v = overflowing!($self_expr.overflowing_mul_u32(current_u32), overflow);
let v_shifted = overflowing!(v.overflowing_shl(current_shift), overflow);
res = overflowing!(res.overflowing_add(v_shifted), overflow);
// multiply parts
let (c_l, overflow_l) = mul_u32(a, b_l as u32, ret[j + i]);
let (c_u, overflow_u) = mul_u32(a, b_u as u32, c_l >> 32);
current = current >> 32;
current_shift += 32;
i += 1;
// This won't overflow
ret[j + i] = (c_l & 0xFFFFFFFF) + (c_u << 32);
// carry1 = overflow_l + (c_u >> 32) + (overflow_u << 32) + carry2 + c0;
let (ca1, c1) = overflow_l.overflowing_add((c_u >> 32) + (overflow_u << 32));
let (ca1, c2) = ca1.overflowing_add(ret[j + i + 1]);
let (ca1, c3) = ca1.overflowing_add(carry2);
ret[j + i + 1] = ca1;
// Will never overflow
carry2 = (overflow_u >> 32) + c1 as u64 + c2 as u64 + c3 as u64;
}
}
(res, overflow)
let mut res = [0u64; $n_words];
let mut overflow = false;
for i in 0..$n_words {
res[i] = ret[i];
}
for i in $n_words..2*$n_words {
overflow |= ret[i] != 0;
}
($name(res), overflow)
})
}
@ -438,6 +458,19 @@ macro_rules! panic_on_overflow {
}
}
#[inline(always)]
fn mul_u32(a: u64, b: u32, carry: u64) -> (u64, u64) {
let b = b as u64;
let upper = b * (a >> 32);
let lower = b * (a & 0xFFFFFFFF);
let (res1, overflow1) = lower.overflowing_add(upper << 32);
let (res2, overflow2) = res1.overflowing_add(carry);
let carry = (upper >> 32) + overflow1 as u64 + overflow2 as u64;
(res2, carry)
}
/// Large, fixed-length unsigned integer type.
pub trait Uint: Sized + Default + FromStr + From<u64> + fmt::Debug + fmt::Display + PartialOrd + Ord + PartialEq + Eq + Hash {
@ -496,9 +529,6 @@ pub trait Uint: Sized + Default + FromStr + From<u64> + fmt::Debug + fmt::Displa
/// Returns negation of this `Uint` and overflow (always true)
fn overflowing_neg(self) -> (Self, bool);
/// Shifts this `Uint` and returns overflow
fn overflowing_shl(self, shift: u32) -> (Self, bool);
}
macro_rules! construct_uint {
@ -687,13 +717,6 @@ macro_rules! construct_uint {
fn overflowing_neg(self) -> ($name, bool) {
(!self, true)
}
fn overflowing_shl(self, shift32: u32) -> ($name, bool) {
let shift = shift32 as usize;
let res = self << shift;
(res, self != (res >> shift))
}
}
impl $name {
@ -709,19 +732,13 @@ macro_rules! construct_uint {
/// Overflowing multiplication by u32
fn overflowing_mul_u32(self, other: u32) -> (Self, bool) {
let $name(ref arr) = self;
let o = other as u64;
let mut ret = [0u64; $n_words];
let mut carry = 0;
for i in 0..$n_words {
let upper = o * (arr[i] >> 32);
let lower = o * (arr[i] & 0xFFFFFFFF);
let (res1, overflow1) = lower.overflowing_add(upper << 32);
let (res2, overflow2) = res1.overflowing_add(carry);
ret[i] = res2;
carry = (upper >> 32) + overflow1 as u64 + overflow2 as u64;
let (res, carry2) = mul_u32(arr[i], other, carry);
ret[i] = res;
carry = carry2;
}
($name(ret), carry > 0)
@ -1233,10 +1250,37 @@ impl U256 {
/// No overflow possible
#[cfg(not(all(asm_available, target_arch="x86_64")))]
pub fn full_mul(self, other: U256) -> U512 {
let self_512 = U512::from(self);
let other_512 = U512::from(other);
let (result, _) = self_512.overflowing_mul(other_512);
result
let U256(ref me) = self;
let U256(ref you) = other;
let mut ret = [0u64; 8];
for i in 0..4 {
let mut carry2 = 0u64;
let (b_u, b_l) = (you[i] >> 32, you[i] & 0xFFFFFFFF);
for j in 0..4 {
let a = me[j];
// multiply parts
let (c_l, overflow_l) = mul_u32(a, b_l as u32, ret[j + i]);
let (c_u, overflow_u) = mul_u32(a, b_u as u32, c_l >> 32);
// This won't overflow
ret[j + i] = (c_l & 0xFFFFFFFF) + (c_u << 32);
// carry1 = overflow_l + (c_u >> 32) + (overflow_u << 32) + carry2 + c0;
let (ca1, c1) = overflow_l.overflowing_add((c_u >> 32) + (overflow_u << 32));
let (ca1, c2) = ca1.overflowing_add(ret[j + i + 1]);
let (ca1, c3) = ca1.overflowing_add(carry2);
ret[j + i + 1] = ca1;
// Will never overflow
carry2 = (overflow_u >> 32) + c1 as u64 + c2 as u64 + c3 as u64;
}
}
U512(ret)
}
}
@ -1502,6 +1546,18 @@ mod tests {
//// TODO: bit inversion
}
#[test]
pub fn uint256_simple_mul() {
let a = U256::from_str("10000000000000000").unwrap();
let b = U256::from_str("10000000000000000").unwrap();
let c = U256::from_str("100000000000000000000000000000000").unwrap();
println!("Multiplying");
let result = a.overflowing_mul(b);
println!("Got result");
assert_eq!(result, (c, false))
}
#[test]
pub fn uint256_extreme_bitshift_test() {
//// Shifting a u64 by 64 bits gives an undefined value, so make sure that
@ -1664,21 +1720,16 @@ mod tests {
}
#[test]
pub fn uint256_shl_overflow() {
pub fn uint256_shl() {
assert_eq!(
U256::from_str("7fffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff").unwrap()
<< 4,
U256::from_str("fffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff0").unwrap()
);
assert_eq!(
U256::from_str("7fffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff").unwrap()
.overflowing_shl(4),
(U256::from_str("fffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff0").unwrap(), true)
);
}
#[test]
pub fn uint256_shl_overflow_words() {
pub fn uint256_shl_words() {
assert_eq!(
U256::from_str("0000000000000001ffffffffffffffffffffffffffffffffffffffffffffffff").unwrap()
<< 64,
@ -1689,45 +1740,6 @@ mod tests {
<< 64,
U256::from_str("ffffffffffffffffffffffffffffffffffffffffffffffff0000000000000000").unwrap()
);
assert_eq!(
U256::from_str("0000000000000001ffffffffffffffffffffffffffffffffffffffffffffffff").unwrap()
.overflowing_shl(64),
(U256::from_str("ffffffffffffffffffffffffffffffffffffffffffffffff0000000000000000").unwrap(), true)
);
assert_eq!(
U256::from_str("0000000000000000ffffffffffffffffffffffffffffffffffffffffffffffff").unwrap()
.overflowing_shl(64),
(U256::from_str("ffffffffffffffffffffffffffffffffffffffffffffffff0000000000000000").unwrap(), false)
);
}
#[test]
pub fn uint256_shl_overflow_words2() {
assert_eq!(
U256::from_str("00000000000000000000000000000001ffffffffffffffffffffffffffffffff").unwrap()
.overflowing_shl(128),
(U256::from_str("ffffffffffffffffffffffffffffffff00000000000000000000000000000000").unwrap(), true)
);
assert_eq!(
U256::from_str("00000000000000000000000000000000ffffffffffffffffffffffffffffffff").unwrap()
.overflowing_shl(128),
(U256::from_str("ffffffffffffffffffffffffffffffff00000000000000000000000000000000").unwrap(), false)
);
assert_eq!(
U256::from_str("00000000000000000000000000000000ffffffffffffffffffffffffffffffff").unwrap()
.overflowing_shl(129),
(U256::from_str("fffffffffffffffffffffffffffffffe00000000000000000000000000000000").unwrap(), true)
);
}
#[test]
pub fn uint256_shl_overflow2() {
assert_eq!(
U256::from_str("0fffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff").unwrap()
.overflowing_shl(4),
(U256::from_str("fffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff0").unwrap(), false)
);
}
#[test]