diff --git a/ethcore/src/builtin.rs b/ethcore/src/builtin.rs index b6d7064eb..a752064cd 100644 --- a/ethcore/src/builtin.rs +++ b/ethcore/src/builtin.rs @@ -267,6 +267,34 @@ impl Impl for Ripemd160 { } } +// calculate modexp: exponentiation by squaring. the `num` crate has pow, but not modular. +fn modexp(mut base: BigUint, mut exp: BigUint, modulus: BigUint) -> BigUint { + use num::Integer; + + match (base.is_zero(), exp.is_zero()) { + (_, true) => return BigUint::one(), // n^0 % m + (true, false) => return BigUint::zero(), // 0^n % m, n>0 + (false, false) if modulus <= BigUint::one() => return BigUint::zero(), // a^b % 1 = 0. + _ => {} + } + + let mut result = BigUint::one(); + base = base % &modulus; + + // fast path for base divisible by modulus. + if base.is_zero() { return BigUint::zero() } + while !exp.is_zero() { + if exp.is_odd() { + result = (result * &base) % &modulus; + } + + exp = exp >> 1; + base = (base.clone() * base) % &modulus; + } + + result +} + impl Impl for ModexpImpl { fn execute(&self, input: &[u8], output: &mut BytesRef) -> Result<(), Error> { let mut reader = input.chain(io::repeat(0)); @@ -295,34 +323,6 @@ impl Impl for ModexpImpl { let exp = read_num(exp_len); let modulus = read_num(mod_len); - // calculate modexp: exponentiation by squaring. the `num` crate has pow, but not modular. - fn modexp(mut base: BigUint, mut exp: BigUint, modulus: BigUint) -> BigUint { - use num::Integer; - - match (base.is_zero(), exp.is_zero()) { - (_, true) => return BigUint::one(), // n^0 % m - (true, false) => return BigUint::zero(), // 0^n % m, n>0 - (false, false) if modulus <= BigUint::one() => return BigUint::zero(), // a^b % 1 = 0. - _ => {} - } - - let mut result = BigUint::one(); - base = base % &modulus; - - // fast path for base divisible by modulus. - if base.is_zero() { return result } - while !exp.is_zero() { - if exp.is_odd() { - result = (result * &base) % &modulus; - } - - exp = exp >> 1; - base = (base.clone() * base) % &modulus; - } - - result - } - // write output to given memory, left padded and same length as the modulus. let bytes = modexp(base, exp, modulus).to_bytes_be(); @@ -504,10 +504,44 @@ impl Impl for Bn128PairingImpl { #[cfg(test)] mod tests { - use super::{Builtin, Linear, ethereum_builtin, Pricer, Modexp}; + use super::{Builtin, Linear, ethereum_builtin, Pricer, Modexp, modexp as me}; use ethjson; use util::{U256, BytesRef}; use rustc_hex::FromHex; + use num::{BigUint, Zero, One}; + + #[test] + fn modexp_func() { + // n^0 % m == 1 + let mut base = BigUint::parse_bytes(b"12345", 10).unwrap(); + let mut exp = BigUint::zero(); + let mut modulus = BigUint::parse_bytes(b"789", 10).unwrap(); + assert_eq!(me(base, exp, modulus), BigUint::one()); + + // 0^n % m == 0 + base = BigUint::zero(); + exp = BigUint::parse_bytes(b"12345", 10).unwrap(); + modulus = BigUint::parse_bytes(b"789", 10).unwrap(); + assert_eq!(me(base, exp, modulus), BigUint::zero()); + + // n^m % 1 == 0 + base = BigUint::parse_bytes(b"12345", 10).unwrap(); + exp = BigUint::parse_bytes(b"789", 10).unwrap(); + modulus = BigUint::one(); + assert_eq!(me(base, exp, modulus), BigUint::zero()); + + // if n % d == 0, then n^m % d == 0 + base = BigUint::parse_bytes(b"12345", 10).unwrap(); + exp = BigUint::parse_bytes(b"789", 10).unwrap(); + modulus = BigUint::parse_bytes(b"15", 10).unwrap(); + assert_eq!(me(base, exp, modulus), BigUint::zero()); + + // others + base = BigUint::parse_bytes(b"12345", 10).unwrap(); + exp = BigUint::parse_bytes(b"789", 10).unwrap(); + modulus = BigUint::parse_bytes(b"97", 10).unwrap(); + assert_eq!(me(base, exp, modulus), BigUint::parse_bytes(b"55", 10).unwrap()); + } #[test] fn identity() {