ethcore: minor optimization of modexp by using LR exponentiation (#9697)

This commit is contained in:
André Silva 2018-10-04 12:29:53 +01:00 committed by Marek Kotewicz
parent 5a2f3e700b
commit 726884afcb
2 changed files with 45 additions and 27 deletions

View File

@ -25,12 +25,10 @@ extern crate ethereum_types;
extern crate parity_bytes as bytes; extern crate parity_bytes as bytes;
extern crate rustc_hex; extern crate rustc_hex;
use std::collections::BTreeMap;
use bytes::BytesRef; use bytes::BytesRef;
use ethcore::builtin::Builtin; use ethcore::builtin::Builtin;
use ethcore::machine::EthereumMachine; use ethcore::machine::EthereumMachine;
use ethereum_types::{Address, U256}; use ethereum_types::U256;
use ethcore::ethereum::new_byzantium_test_machine; use ethcore::ethereum::new_byzantium_test_machine;
use rustc_hex::FromHex; use rustc_hex::FromHex;
use self::test::Bencher; use self::test::Bencher;

View File

@ -311,35 +311,51 @@ impl Impl for Ripemd160 {
} }
} }
// calculate modexp: exponentiation by squaring. the `num` crate has pow, but not modular. // calculate modexp: left-to-right binary exponentiation to keep multiplicands lower
fn modexp(mut base: BigUint, mut exp: BigUint, modulus: BigUint) -> BigUint { fn modexp(mut base: BigUint, exp: Vec<u8>, modulus: BigUint) -> BigUint {
use num::Integer; const BITS_PER_DIGIT: usize = 8;
if modulus <= BigUint::one() { // n^m % 0 || n^m % 1 // n^m % 0 || n^m % 1
if modulus <= BigUint::one() {
return BigUint::zero(); return BigUint::zero();
} }
if exp.is_zero() { // n^0 % m // normalize exponent
let mut exp = exp.into_iter().skip_while(|d| *d == 0).peekable();
// n^0 % m
if let None = exp.peek() {
return BigUint::one(); return BigUint::one();
} }
if base.is_zero() { // 0^n % m, n>0 // 0^n % m, n > 0
if base.is_zero() {
return BigUint::zero(); return BigUint::zero();
} }
let mut result = BigUint::one();
base = base % &modulus; base = base % &modulus;
// fast path for base divisible by modulus. // Fast path for base divisible by modulus.
if base.is_zero() { return BigUint::zero() } if base.is_zero() { return BigUint::zero() }
while !exp.is_zero() {
if exp.is_odd() {
result = (result * &base) % &modulus;
}
exp = exp >> 1; // Left-to-right binary exponentiation (Handbook of Applied Cryptography - Algorithm 14.79).
base = (base.clone() * base) % &modulus; // http://www.cacr.math.uwaterloo.ca/hac/about/chap14.pdf
let mut result = BigUint::one();
for digit in exp {
let mut mask = 1 << (BITS_PER_DIGIT - 1);
for _ in 0..BITS_PER_DIGIT {
result = &result * &result % &modulus;
if digit & mask > 0 {
result = result * &base % &modulus;
}
mask >>= 1;
}
} }
result result
} }
@ -366,15 +382,19 @@ impl Impl for ModexpImpl {
} else { } else {
// read the numbers themselves. // read the numbers themselves.
let mut buf = vec![0; max(mod_len, max(base_len, exp_len))]; let mut buf = vec![0; max(mod_len, max(base_len, exp_len))];
let mut read_num = |len| { let mut read_num = |reader: &mut io::Chain<&[u8], io::Repeat>, len: usize| {
reader.read_exact(&mut buf[..len]).expect("reading from zero-extended memory cannot fail; qed"); reader.read_exact(&mut buf[..len]).expect("reading from zero-extended memory cannot fail; qed");
BigUint::from_bytes_be(&buf[..len]) BigUint::from_bytes_be(&buf[..len])
}; };
let base = read_num(base_len); let base = read_num(&mut reader, base_len);
let exp = read_num(exp_len);
let modulus = read_num(mod_len); let mut exp_buf = vec![0; exp_len];
modexp(base, exp, modulus) reader.read_exact(&mut exp_buf[..exp_len]).expect("reading from zero-extended memory cannot fail; qed");
let modulus = read_num(&mut reader, mod_len);
modexp(base, exp_buf, modulus)
}; };
// write output to given memory, left padded and same length as the modulus. // write output to given memory, left padded and same length as the modulus.
@ -551,31 +571,31 @@ mod tests {
let mut base = BigUint::parse_bytes(b"12345", 10).unwrap(); let mut base = BigUint::parse_bytes(b"12345", 10).unwrap();
let mut exp = BigUint::zero(); let mut exp = BigUint::zero();
let mut modulus = BigUint::parse_bytes(b"789", 10).unwrap(); let mut modulus = BigUint::parse_bytes(b"789", 10).unwrap();
assert_eq!(me(base, exp, modulus), BigUint::one()); assert_eq!(me(base, exp.to_bytes_be(), modulus), BigUint::one());
// 0^n % m == 0 // 0^n % m == 0
base = BigUint::zero(); base = BigUint::zero();
exp = BigUint::parse_bytes(b"12345", 10).unwrap(); exp = BigUint::parse_bytes(b"12345", 10).unwrap();
modulus = BigUint::parse_bytes(b"789", 10).unwrap(); modulus = BigUint::parse_bytes(b"789", 10).unwrap();
assert_eq!(me(base, exp, modulus), BigUint::zero()); assert_eq!(me(base, exp.to_bytes_be(), modulus), BigUint::zero());
// n^m % 1 == 0 // n^m % 1 == 0
base = BigUint::parse_bytes(b"12345", 10).unwrap(); base = BigUint::parse_bytes(b"12345", 10).unwrap();
exp = BigUint::parse_bytes(b"789", 10).unwrap(); exp = BigUint::parse_bytes(b"789", 10).unwrap();
modulus = BigUint::one(); modulus = BigUint::one();
assert_eq!(me(base, exp, modulus), BigUint::zero()); assert_eq!(me(base, exp.to_bytes_be(), modulus), BigUint::zero());
// if n % d == 0, then n^m % d == 0 // if n % d == 0, then n^m % d == 0
base = BigUint::parse_bytes(b"12345", 10).unwrap(); base = BigUint::parse_bytes(b"12345", 10).unwrap();
exp = BigUint::parse_bytes(b"789", 10).unwrap(); exp = BigUint::parse_bytes(b"789", 10).unwrap();
modulus = BigUint::parse_bytes(b"15", 10).unwrap(); modulus = BigUint::parse_bytes(b"15", 10).unwrap();
assert_eq!(me(base, exp, modulus), BigUint::zero()); assert_eq!(me(base, exp.to_bytes_be(), modulus), BigUint::zero());
// others // others
base = BigUint::parse_bytes(b"12345", 10).unwrap(); base = BigUint::parse_bytes(b"12345", 10).unwrap();
exp = BigUint::parse_bytes(b"789", 10).unwrap(); exp = BigUint::parse_bytes(b"789", 10).unwrap();
modulus = BigUint::parse_bytes(b"97", 10).unwrap(); modulus = BigUint::parse_bytes(b"97", 10).unwrap();
assert_eq!(me(base, exp, modulus), BigUint::parse_bytes(b"55", 10).unwrap()); assert_eq!(me(base, exp.to_bytes_be(), modulus), BigUint::parse_bytes(b"55", 10).unwrap());
} }
#[test] #[test]