ethcore: minor optimization of modexp by using LR exponentiation (#9697)

This commit is contained in:
André Silva 2018-10-04 12:29:53 +01:00 committed by Marek Kotewicz
parent 5a2f3e700b
commit 726884afcb
2 changed files with 45 additions and 27 deletions

View File

@ -25,12 +25,10 @@ extern crate ethereum_types;
extern crate parity_bytes as bytes;
extern crate rustc_hex;
use std::collections::BTreeMap;
use bytes::BytesRef;
use ethcore::builtin::Builtin;
use ethcore::machine::EthereumMachine;
use ethereum_types::{Address, U256};
use ethereum_types::U256;
use ethcore::ethereum::new_byzantium_test_machine;
use rustc_hex::FromHex;
use self::test::Bencher;

View File

@ -311,35 +311,51 @@ impl Impl for Ripemd160 {
}
}
// calculate modexp: exponentiation by squaring. the `num` crate has pow, but not modular.
fn modexp(mut base: BigUint, mut exp: BigUint, modulus: BigUint) -> BigUint {
use num::Integer;
// calculate modexp: left-to-right binary exponentiation to keep multiplicands lower
fn modexp(mut base: BigUint, exp: Vec<u8>, modulus: BigUint) -> BigUint {
const BITS_PER_DIGIT: usize = 8;
if modulus <= BigUint::one() { // n^m % 0 || n^m % 1
// n^m % 0 || n^m % 1
if modulus <= BigUint::one() {
return BigUint::zero();
}
if exp.is_zero() { // n^0 % m
// normalize exponent
let mut exp = exp.into_iter().skip_while(|d| *d == 0).peekable();
// n^0 % m
if let None = exp.peek() {
return BigUint::one();
}
if base.is_zero() { // 0^n % m, n>0
// 0^n % m, n > 0
if base.is_zero() {
return BigUint::zero();
}
let mut result = BigUint::one();
base = base % &modulus;
// fast path for base divisible by modulus.
// Fast path for base divisible by modulus.
if base.is_zero() { return BigUint::zero() }
while !exp.is_zero() {
if exp.is_odd() {
result = (result * &base) % &modulus;
// Left-to-right binary exponentiation (Handbook of Applied Cryptography - Algorithm 14.79).
// http://www.cacr.math.uwaterloo.ca/hac/about/chap14.pdf
let mut result = BigUint::one();
for digit in exp {
let mut mask = 1 << (BITS_PER_DIGIT - 1);
for _ in 0..BITS_PER_DIGIT {
result = &result * &result % &modulus;
if digit & mask > 0 {
result = result * &base % &modulus;
}
exp = exp >> 1;
base = (base.clone() * base) % &modulus;
mask >>= 1;
}
}
result
}
@ -366,15 +382,19 @@ impl Impl for ModexpImpl {
} else {
// read the numbers themselves.
let mut buf = vec![0; max(mod_len, max(base_len, exp_len))];
let mut read_num = |len| {
let mut read_num = |reader: &mut io::Chain<&[u8], io::Repeat>, len: usize| {
reader.read_exact(&mut buf[..len]).expect("reading from zero-extended memory cannot fail; qed");
BigUint::from_bytes_be(&buf[..len])
};
let base = read_num(base_len);
let exp = read_num(exp_len);
let modulus = read_num(mod_len);
modexp(base, exp, modulus)
let base = read_num(&mut reader, base_len);
let mut exp_buf = vec![0; exp_len];
reader.read_exact(&mut exp_buf[..exp_len]).expect("reading from zero-extended memory cannot fail; qed");
let modulus = read_num(&mut reader, mod_len);
modexp(base, exp_buf, modulus)
};
// write output to given memory, left padded and same length as the modulus.
@ -551,31 +571,31 @@ mod tests {
let mut base = BigUint::parse_bytes(b"12345", 10).unwrap();
let mut exp = BigUint::zero();
let mut modulus = BigUint::parse_bytes(b"789", 10).unwrap();
assert_eq!(me(base, exp, modulus), BigUint::one());
assert_eq!(me(base, exp.to_bytes_be(), modulus), BigUint::one());
// 0^n % m == 0
base = BigUint::zero();
exp = BigUint::parse_bytes(b"12345", 10).unwrap();
modulus = BigUint::parse_bytes(b"789", 10).unwrap();
assert_eq!(me(base, exp, modulus), BigUint::zero());
assert_eq!(me(base, exp.to_bytes_be(), modulus), BigUint::zero());
// n^m % 1 == 0
base = BigUint::parse_bytes(b"12345", 10).unwrap();
exp = BigUint::parse_bytes(b"789", 10).unwrap();
modulus = BigUint::one();
assert_eq!(me(base, exp, modulus), BigUint::zero());
assert_eq!(me(base, exp.to_bytes_be(), modulus), BigUint::zero());
// if n % d == 0, then n^m % d == 0
base = BigUint::parse_bytes(b"12345", 10).unwrap();
exp = BigUint::parse_bytes(b"789", 10).unwrap();
modulus = BigUint::parse_bytes(b"15", 10).unwrap();
assert_eq!(me(base, exp, modulus), BigUint::zero());
assert_eq!(me(base, exp.to_bytes_be(), modulus), BigUint::zero());
// others
base = BigUint::parse_bytes(b"12345", 10).unwrap();
exp = BigUint::parse_bytes(b"789", 10).unwrap();
modulus = BigUint::parse_bytes(b"97", 10).unwrap();
assert_eq!(me(base, exp, modulus), BigUint::parse_bytes(b"55", 10).unwrap());
assert_eq!(me(base, exp.to_bytes_be(), modulus), BigUint::parse_bytes(b"55", 10).unwrap());
}
#[test]