fix modexp bug: return 0 if base=0 (#6424)
This commit is contained in:
parent
2faa28ce9b
commit
1d95fe481f
@ -267,6 +267,34 @@ impl Impl for Ripemd160 {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// calculate modexp: exponentiation by squaring. the `num` crate has pow, but not modular.
|
||||||
|
fn modexp(mut base: BigUint, mut exp: BigUint, modulus: BigUint) -> BigUint {
|
||||||
|
use num::Integer;
|
||||||
|
|
||||||
|
match (base.is_zero(), exp.is_zero()) {
|
||||||
|
(_, true) => return BigUint::one(), // n^0 % m
|
||||||
|
(true, false) => return BigUint::zero(), // 0^n % m, n>0
|
||||||
|
(false, false) if modulus <= BigUint::one() => return BigUint::zero(), // a^b % 1 = 0.
|
||||||
|
_ => {}
|
||||||
|
}
|
||||||
|
|
||||||
|
let mut result = BigUint::one();
|
||||||
|
base = base % &modulus;
|
||||||
|
|
||||||
|
// fast path for base divisible by modulus.
|
||||||
|
if base.is_zero() { return BigUint::zero() }
|
||||||
|
while !exp.is_zero() {
|
||||||
|
if exp.is_odd() {
|
||||||
|
result = (result * &base) % &modulus;
|
||||||
|
}
|
||||||
|
|
||||||
|
exp = exp >> 1;
|
||||||
|
base = (base.clone() * base) % &modulus;
|
||||||
|
}
|
||||||
|
|
||||||
|
result
|
||||||
|
}
|
||||||
|
|
||||||
impl Impl for ModexpImpl {
|
impl Impl for ModexpImpl {
|
||||||
fn execute(&self, input: &[u8], output: &mut BytesRef) -> Result<(), Error> {
|
fn execute(&self, input: &[u8], output: &mut BytesRef) -> Result<(), Error> {
|
||||||
let mut reader = input.chain(io::repeat(0));
|
let mut reader = input.chain(io::repeat(0));
|
||||||
@ -295,34 +323,6 @@ impl Impl for ModexpImpl {
|
|||||||
let exp = read_num(exp_len);
|
let exp = read_num(exp_len);
|
||||||
let modulus = read_num(mod_len);
|
let modulus = read_num(mod_len);
|
||||||
|
|
||||||
// calculate modexp: exponentiation by squaring. the `num` crate has pow, but not modular.
|
|
||||||
fn modexp(mut base: BigUint, mut exp: BigUint, modulus: BigUint) -> BigUint {
|
|
||||||
use num::Integer;
|
|
||||||
|
|
||||||
match (base.is_zero(), exp.is_zero()) {
|
|
||||||
(_, true) => return BigUint::one(), // n^0 % m
|
|
||||||
(true, false) => return BigUint::zero(), // 0^n % m, n>0
|
|
||||||
(false, false) if modulus <= BigUint::one() => return BigUint::zero(), // a^b % 1 = 0.
|
|
||||||
_ => {}
|
|
||||||
}
|
|
||||||
|
|
||||||
let mut result = BigUint::one();
|
|
||||||
base = base % &modulus;
|
|
||||||
|
|
||||||
// fast path for base divisible by modulus.
|
|
||||||
if base.is_zero() { return result }
|
|
||||||
while !exp.is_zero() {
|
|
||||||
if exp.is_odd() {
|
|
||||||
result = (result * &base) % &modulus;
|
|
||||||
}
|
|
||||||
|
|
||||||
exp = exp >> 1;
|
|
||||||
base = (base.clone() * base) % &modulus;
|
|
||||||
}
|
|
||||||
|
|
||||||
result
|
|
||||||
}
|
|
||||||
|
|
||||||
// write output to given memory, left padded and same length as the modulus.
|
// write output to given memory, left padded and same length as the modulus.
|
||||||
let bytes = modexp(base, exp, modulus).to_bytes_be();
|
let bytes = modexp(base, exp, modulus).to_bytes_be();
|
||||||
|
|
||||||
@ -504,10 +504,44 @@ impl Impl for Bn128PairingImpl {
|
|||||||
|
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
mod tests {
|
mod tests {
|
||||||
use super::{Builtin, Linear, ethereum_builtin, Pricer, Modexp};
|
use super::{Builtin, Linear, ethereum_builtin, Pricer, Modexp, modexp as me};
|
||||||
use ethjson;
|
use ethjson;
|
||||||
use util::{U256, BytesRef};
|
use util::{U256, BytesRef};
|
||||||
use rustc_hex::FromHex;
|
use rustc_hex::FromHex;
|
||||||
|
use num::{BigUint, Zero, One};
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn modexp_func() {
|
||||||
|
// n^0 % m == 1
|
||||||
|
let mut base = BigUint::parse_bytes(b"12345", 10).unwrap();
|
||||||
|
let mut exp = BigUint::zero();
|
||||||
|
let mut modulus = BigUint::parse_bytes(b"789", 10).unwrap();
|
||||||
|
assert_eq!(me(base, exp, modulus), BigUint::one());
|
||||||
|
|
||||||
|
// 0^n % m == 0
|
||||||
|
base = BigUint::zero();
|
||||||
|
exp = BigUint::parse_bytes(b"12345", 10).unwrap();
|
||||||
|
modulus = BigUint::parse_bytes(b"789", 10).unwrap();
|
||||||
|
assert_eq!(me(base, exp, modulus), BigUint::zero());
|
||||||
|
|
||||||
|
// n^m % 1 == 0
|
||||||
|
base = BigUint::parse_bytes(b"12345", 10).unwrap();
|
||||||
|
exp = BigUint::parse_bytes(b"789", 10).unwrap();
|
||||||
|
modulus = BigUint::one();
|
||||||
|
assert_eq!(me(base, exp, modulus), BigUint::zero());
|
||||||
|
|
||||||
|
// if n % d == 0, then n^m % d == 0
|
||||||
|
base = BigUint::parse_bytes(b"12345", 10).unwrap();
|
||||||
|
exp = BigUint::parse_bytes(b"789", 10).unwrap();
|
||||||
|
modulus = BigUint::parse_bytes(b"15", 10).unwrap();
|
||||||
|
assert_eq!(me(base, exp, modulus), BigUint::zero());
|
||||||
|
|
||||||
|
// others
|
||||||
|
base = BigUint::parse_bytes(b"12345", 10).unwrap();
|
||||||
|
exp = BigUint::parse_bytes(b"789", 10).unwrap();
|
||||||
|
modulus = BigUint::parse_bytes(b"97", 10).unwrap();
|
||||||
|
assert_eq!(me(base, exp, modulus), BigUint::parse_bytes(b"55", 10).unwrap());
|
||||||
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn identity() {
|
fn identity() {
|
||||||
|
Loading…
Reference in New Issue
Block a user