Merge branch 'master' into bigint-opt

Conflicts:
	util/src/uint.rs
This commit is contained in:
Nikolay Volf
2016-02-27 19:17:51 +03:00
72 changed files with 4994 additions and 2049 deletions

View File

@@ -1,469 +0,0 @@
// Copyright 2015, 2016 Ethcore (UK) Ltd.
// This file is part of Parity.
// Parity is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
// Parity is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
// You should have received a copy of the GNU General Public License
// along with Parity. If not, see <http://www.gnu.org/licenses/>.
//! Multilevel blockchain bloom filter.
//!
//! ```
//! extern crate ethcore_util as util;
//! use std::str::FromStr;
//! use util::chainfilter::*;
//! use util::sha3::*;
//! use util::hash::*;
//!
//! fn main() {
//! let (index_size, bloom_levels) = (16, 3);
//! let mut cache = MemoryCache::new();
//!
//! let address = Address::from_str("ef2d6d194084c2de36e0dabfce45d046b37d1106").unwrap();
//!
//! // borrow cache for reading inside the scope
//! let modified_blooms = {
//! let filter = ChainFilter::new(&cache, index_size, bloom_levels);
//! let block_number = 39;
//! let mut bloom = H2048::new();
//! bloom.shift_bloomed(&address.sha3());
//! filter.add_bloom(&bloom, block_number)
//! };
//!
//! // number of updated blooms is equal number of levels
//! assert_eq!(modified_blooms.len(), bloom_levels as usize);
//!
//! // lets inserts modified blooms into the cache
//! cache.insert_blooms(modified_blooms);
//!
//! // borrow cache for another reading operations
//! {
//! let filter = ChainFilter::new(&cache, index_size, bloom_levels);
//! let blocks = filter.blocks_with_address(&address, 10, 40);
//! assert_eq!(blocks.len(), 1);
//! assert_eq!(blocks[0], 39);
//! }
//! }
//! ```
//!
use std::collections::{HashMap};
use hash::*;
use sha3::*;
/// Represents bloom index in cache
///
/// On cache level 0, every block bloom is represented by different index.
/// On higher cache levels, multiple block blooms are represented by one
/// index. Their `BloomIndex` can be created from block number and given level.
#[derive(Eq, PartialEq, Hash, Clone, Debug)]
pub struct BloomIndex {
/// Bloom level
pub level: u8,
/// Filter Index
pub index: usize,
}
impl BloomIndex {
/// Default constructor for `BloomIndex`
pub fn new(level: u8, index: usize) -> BloomIndex {
BloomIndex {
level: level,
index: index,
}
}
}
/// Types implementing this trait should provide read access for bloom filters database.
pub trait FilterDataSource {
/// returns reference to log at given position if it exists
fn bloom_at_index(&self, index: &BloomIndex) -> Option<&H2048>;
}
/// In memory cache for blooms.
///
/// Stores all blooms in HashMap, which indexes them by `BloomIndex`.
pub struct MemoryCache {
blooms: HashMap<BloomIndex, H2048>,
}
impl MemoryCache {
/// Default constructor for MemoryCache
pub fn new() -> MemoryCache {
MemoryCache { blooms: HashMap::new() }
}
/// inserts all blooms into cache
///
/// if bloom at given index already exists, overwrites it
pub fn insert_blooms(&mut self, blooms: HashMap<BloomIndex, H2048>) {
self.blooms.extend(blooms);
}
}
impl FilterDataSource for MemoryCache {
fn bloom_at_index(&self, index: &BloomIndex) -> Option<&H2048> {
self.blooms.get(index)
}
}
/// Should be used for search operations on blockchain.
pub struct ChainFilter<'a, D>
where D: FilterDataSource + 'a
{
data_source: &'a D,
index_size: usize,
level_sizes: Vec<usize>,
}
impl<'a, D> ChainFilter<'a, D> where D: FilterDataSource
{
/// Creates new filter instance.
///
/// Borrows `FilterDataSource` for reading.
pub fn new(data_source: &'a D, index_size: usize, levels: u8) -> Self {
if levels == 0 {
panic!("ChainFilter requires at least 1 level");
}
let mut filter = ChainFilter {
data_source: data_source,
index_size: index_size,
// 0 level has always a size of 1
level_sizes: vec![1]
};
// cache level sizes, so we do not have to calculate them all the time
// eg. if levels == 3, index_size = 16
// level_sizes = [1, 16, 256]
let additional: Vec<usize> = (1..).into_iter()
.scan(1, |acc, _| {
*acc = *acc * index_size;
Some(*acc)
})
.take(levels as usize - 1)
.collect();
filter.level_sizes.extend(additional);
filter
}
/// unsafely get level size
fn level_size(&self, level: u8) -> usize {
self.level_sizes[level as usize]
}
/// converts block number and level to `BloomIndex`
fn bloom_index(&self, block_number: usize, level: u8) -> BloomIndex {
BloomIndex {
level: level,
index: block_number / self.level_size(level),
}
}
/// return bloom which are dependencies for given index
///
/// bloom indexes are ordered from lowest to highest
fn lower_level_bloom_indexes(&self, index: &BloomIndex) -> Vec<BloomIndex> {
// this is the lowest level
if index.level == 0 {
return vec![];
}
let new_level = index.level - 1;
let offset = self.index_size * index.index;
(0..self.index_size).map(|i| BloomIndex::new(new_level, offset + i)).collect()
}
/// return number of levels
fn levels(&self) -> u8 {
self.level_sizes.len() as u8
}
/// returns max filter level
fn max_level(&self) -> u8 {
self.level_sizes.len() as u8 - 1
}
/// internal function which does bloom search recursively
fn blocks(&self, bloom: &H2048, from_block: usize, to_block: usize, level: u8, offset: usize) -> Option<Vec<usize>> {
let index = self.bloom_index(offset, level);
match self.data_source.bloom_at_index(&index) {
None => return None,
Some(level_bloom) => match level {
// if we are on the lowest level
// take the value, exclude to_block
0 if offset < to_block => return Some(vec![offset]),
// return None if it is is equal to to_block
0 => return None,
// return None if current level doesnt contain given bloom
_ if !level_bloom.contains(bloom) => return None,
// continue processing && go down
_ => ()
}
};
let level_size = self.level_size(level - 1);
let from_index = self.bloom_index(from_block, level - 1);
let to_index = self.bloom_index(to_block, level - 1);
let res: Vec<usize> = self.lower_level_bloom_indexes(&index).into_iter()
// chose only blooms in range
.filter(|li| li.index >= from_index.index && li.index <= to_index.index)
// map them to offsets
.map(|li| li.index * level_size)
// get all blocks that may contain our bloom
.map(|off| self.blocks(bloom, from_block, to_block, level - 1, off))
// filter existing ones
.filter_map(|x| x)
// flatten nested structures
.flat_map(|v| v)
.collect();
Some(res)
}
/// Adds new bloom to all filter levels
pub fn add_bloom(&self, bloom: &H2048, block_number: usize) -> HashMap<BloomIndex, H2048> {
let mut result: HashMap<BloomIndex, H2048> = HashMap::new();
for level in 0..self.levels() {
let bloom_index = self.bloom_index(block_number, level);
let new_bloom = match self.data_source.bloom_at_index(&bloom_index) {
Some(old_bloom) => old_bloom | bloom,
None => bloom.clone(),
};
result.insert(bloom_index, new_bloom);
}
result
}
/// Adds new blooms starting from block number.
pub fn add_blooms(&self, blooms: &[H2048], block_number: usize) -> HashMap<BloomIndex, H2048> {
let mut result: HashMap<BloomIndex, H2048> = HashMap::new();
for level in 0..self.levels() {
for i in 0..blooms.len() {
let bloom_index = self.bloom_index(block_number + i, level);
let is_new_bloom = match result.get_mut(&bloom_index) {
// it was already modified
Some(to_shift) => {
*to_shift = &blooms[i] | to_shift;
false
}
None => true,
};
// it hasn't been modified yet
if is_new_bloom {
let new_bloom = match self.data_source.bloom_at_index(&bloom_index) {
Some(old_bloom) => old_bloom | &blooms[i],
None => blooms[i].clone(),
};
result.insert(bloom_index, new_bloom);
}
}
}
result
}
/// Resets bloom at level 0 and forces rebuild on higher levels.
pub fn reset_bloom(&self, bloom: &H2048, block_number: usize) -> HashMap<BloomIndex, H2048> {
let mut result: HashMap<BloomIndex, H2048> = HashMap::new();
let mut reset_index = self.bloom_index(block_number, 0);
result.insert(reset_index.clone(), bloom.clone());
for level in 1..self.levels() {
let index = self.bloom_index(block_number, level);
// get all bloom indexes that were used to construct this bloom
let lower_indexes = self.lower_level_bloom_indexes(&index);
let new_bloom = lower_indexes.into_iter()
// skip reseted one
.filter(|li| li != &reset_index)
// get blooms for these indexes
.map(|li| self.data_source.bloom_at_index(&li))
// filter existing ones
.filter_map(|b| b)
// BitOr all of them
.fold(H2048::new(), |acc, bloom| &acc | bloom);
reset_index = index.clone();
result.insert(index, &new_bloom | bloom);
}
result
}
/// Sets lowest level bloom to 0 and forces rebuild on higher levels.
pub fn clear_bloom(&self, block_number: usize) -> HashMap<BloomIndex, H2048> {
self.reset_bloom(&H2048::new(), block_number)
}
/// Returns numbers of blocks that may contain Address.
pub fn blocks_with_address(&self, address: &Address, from_block: usize, to_block: usize) -> Vec<usize> {
let mut bloom = H2048::new();
bloom.shift_bloomed(&address.sha3());
self.blocks_with_bloom(&bloom, from_block, to_block)
}
/// Returns numbers of blocks that may contain Topic.
pub fn blocks_with_topic(&self, topic: &H256, from_block: usize, to_block: usize) -> Vec<usize> {
let mut bloom = H2048::new();
bloom.shift_bloomed(&topic.sha3());
self.blocks_with_bloom(&bloom, from_block, to_block)
}
/// Returns numbers of blocks that may log bloom.
pub fn blocks_with_bloom(&self, bloom: &H2048, from_block: usize, to_block: usize) -> Vec<usize> {
let mut result = vec![];
// lets start from highest level
let max_level = self.max_level();
let level_size = self.level_size(max_level);
let from_index = self.bloom_index(from_block, max_level);
let to_index = self.bloom_index(to_block, max_level);
for index in from_index.index..to_index.index + 1 {
// offset will be used to calculate where we are right now
let offset = level_size * index;
// go doooown!
if let Some(blocks) = self.blocks(bloom, from_block, to_block, max_level, offset) {
result.extend(blocks);
}
}
result
}
}
#[cfg(test)]
mod tests {
use hash::*;
use chainfilter::*;
use sha3::*;
use std::str::FromStr;
#[test]
fn test_level_size() {
let cache = MemoryCache::new();
let filter = ChainFilter::new(&cache, 16, 3);
assert_eq!(filter.level_size(0), 1);
assert_eq!(filter.level_size(1), 16);
assert_eq!(filter.level_size(2), 256);
}
#[test]
fn test_bloom_index() {
let cache = MemoryCache::new();
let filter = ChainFilter::new(&cache, 16, 3);
let bi0 = filter.bloom_index(0, 0);
assert_eq!(bi0.level, 0);
assert_eq!(bi0.index, 0);
let bi1 = filter.bloom_index(1, 0);
assert_eq!(bi1.level, 0);
assert_eq!(bi1.index, 1);
let bi2 = filter.bloom_index(2, 0);
assert_eq!(bi2.level, 0);
assert_eq!(bi2.index, 2);
let bi3 = filter.bloom_index(3, 1);
assert_eq!(bi3.level, 1);
assert_eq!(bi3.index, 0);
let bi4 = filter.bloom_index(15, 1);
assert_eq!(bi4.level, 1);
assert_eq!(bi4.index, 0);
let bi5 = filter.bloom_index(16, 1);
assert_eq!(bi5.level, 1);
assert_eq!(bi5.index, 1);
let bi6 = filter.bloom_index(255, 2);
assert_eq!(bi6.level, 2);
assert_eq!(bi6.index, 0);
let bi7 = filter.bloom_index(256, 2);
assert_eq!(bi7.level, 2);
assert_eq!(bi7.index, 1);
}
#[test]
fn test_lower_level_bloom_indexes() {
let cache = MemoryCache::new();
let filter = ChainFilter::new(&cache, 16, 3);
let bi = filter.bloom_index(256, 2);
assert_eq!(bi.level, 2);
assert_eq!(bi.index, 1);
let mut ebis = vec![];
for i in 16..32 {
ebis.push(BloomIndex::new(1, i));
}
let bis = filter.lower_level_bloom_indexes(&bi);
assert_eq!(ebis, bis);
}
#[test]
fn test_topic_basic_search() {
let index_size = 16;
let bloom_levels = 3;
let mut cache = MemoryCache::new();
let topic = H256::from_str("8d936b1bd3fc635710969ccfba471fb17d598d9d1971b538dd712e1e4b4f4dba").unwrap();
let modified_blooms = {
let filter = ChainFilter::new(&cache, index_size, bloom_levels);
let block_number = 23;
let mut bloom = H2048::new();
bloom.shift_bloomed(&topic.sha3());
filter.add_bloom(&bloom, block_number)
};
// number of modified blooms should always be equal number of levels
assert_eq!(modified_blooms.len(), bloom_levels as usize);
cache.insert_blooms(modified_blooms);
{
let filter = ChainFilter::new(&cache, index_size, bloom_levels);
let blocks = filter.blocks_with_topic(&topic, 0, 100);
assert_eq!(blocks.len(), 1);
assert_eq!(blocks[0], 23);
}
{
let filter = ChainFilter::new(&cache, index_size, bloom_levels);
let blocks = filter.blocks_with_topic(&topic, 0, 23);
assert_eq!(blocks.len(), 0);
}
{
let filter = ChainFilter::new(&cache, index_size, bloom_levels);
let blocks = filter.blocks_with_topic(&topic, 23, 24);
assert_eq!(blocks.len(), 1);
assert_eq!(blocks[0], 23);
}
{
let filter = ChainFilter::new(&cache, index_size, bloom_levels);
let blocks = filter.blocks_with_topic(&topic, 24, 100);
assert_eq!(blocks.len(), 0);
}
}
}

View File

@@ -235,11 +235,11 @@ macro_rules! impl_hash {
}
impl serde::Serialize for $from {
fn serialize<S>(&self, serializer: &mut S) -> Result<(), S::Error>
fn serialize<S>(&self, serializer: &mut S) -> Result<(), S::Error>
where S: serde::Serializer {
let mut hex = "0x".to_owned();
hex.push_str(self.to_hex().as_ref());
serializer.visit_str(hex.as_ref())
serializer.serialize_str(hex.as_ref())
}
}
@@ -250,14 +250,14 @@ macro_rules! impl_hash {
impl serde::de::Visitor for HashVisitor {
type Value = $from;
fn visit_str<E>(&mut self, value: &str) -> Result<Self::Value, E> where E: serde::Error {
// 0x + len
if value.len() != 2 + $size * 2 {
return Err(serde::Error::syntax("Invalid length."));
return Err(serde::Error::custom("Invalid length."));
}
value[2..].from_hex().map(|ref v| $from::from_slice(v)).map_err(|_| serde::Error::syntax("Invalid valid hex."))
value[2..].from_hex().map(|ref v| $from::from_slice(v)).map_err(|_| serde::Error::custom("Invalid valid hex."))
}
fn visit_string<E>(&mut self, value: String) -> Result<Self::Value, E> where E: serde::Error {
@@ -265,7 +265,7 @@ macro_rules! impl_hash {
}
}
deserializer.visit(HashVisitor)
deserializer.deserialize(HashVisitor)
}
}
@@ -719,4 +719,3 @@ mod tests {
assert_eq!(r, u);
}
}

View File

@@ -20,7 +20,7 @@ use common::*;
use rlp::*;
use hashdb::*;
use memorydb::*;
use rocksdb::{DB, Writable, WriteBatch, IteratorMode};
use kvdb::{Database, DBTransaction, DatabaseConfig};
#[cfg(test)]
use std::env;
@@ -33,7 +33,7 @@ use std::env;
/// the removals actually take effect.
pub struct JournalDB {
overlay: MemoryDB,
backing: Arc<DB>,
backing: Arc<Database>,
counters: Arc<RwLock<HashMap<H256, i32>>>,
}
@@ -47,21 +47,25 @@ impl Clone for JournalDB {
}
}
const LATEST_ERA_KEY : [u8; 4] = [ b'l', b'a', b's', b't' ];
const VERSION_KEY : [u8; 4] = [ b'j', b'v', b'e', b'r' ];
// all keys must be at least 12 bytes
const LATEST_ERA_KEY : [u8; 12] = [ b'l', b'a', b's', b't', 0, 0, 0, 0, 0, 0, 0, 0 ];
const VERSION_KEY : [u8; 12] = [ b'j', b'v', b'e', b'r', 0, 0, 0, 0, 0, 0, 0, 0 ];
const DB_VERSION: u32 = 2;
const DB_VERSION: u32 = 3;
const PADDING : [u8; 10] = [ 0u8; 10 ];
impl JournalDB {
/// Create a new instance given a `backing` database.
pub fn new(backing: DB) -> JournalDB {
let db = Arc::new(backing);
JournalDB::new_with_arc(db)
}
/// Create a new instance given a shared `backing` database.
pub fn new_with_arc(backing: Arc<DB>) -> JournalDB {
if backing.iterator(IteratorMode::Start).next().is_some() {
/// Create a new instance from file
pub fn new(path: &str) -> JournalDB {
let opts = DatabaseConfig {
prefix_size: Some(12) //use 12 bytes as prefix, this must match account_db prefix
};
let backing = Database::open(&opts, path).unwrap_or_else(|e| {
panic!("Error opening state db: {}", e);
});
if !backing.is_empty() {
match backing.get(&VERSION_KEY).map(|d| d.map(|v| decode::<u32>(&v))) {
Ok(Some(DB_VERSION)) => {},
v => panic!("Incompatible DB version, expected {}, got {:?}", DB_VERSION, v)
@@ -72,7 +76,7 @@ impl JournalDB {
let counters = JournalDB::read_counters(&backing);
JournalDB {
overlay: MemoryDB::new(),
backing: backing,
backing: Arc::new(backing),
counters: Arc::new(RwLock::new(counters)),
}
}
@@ -82,7 +86,7 @@ impl JournalDB {
pub fn new_temp() -> JournalDB {
let mut dir = env::temp_dir();
dir.push(H32::random().hex());
Self::new(DB::open_default(dir.to_str().unwrap()).unwrap())
Self::new(dir.to_str().unwrap())
}
/// Check if this database has any commits
@@ -117,16 +121,17 @@ impl JournalDB {
// and the key is safe to delete.
// record new commit's details.
let batch = WriteBatch::new();
let batch = DBTransaction::new();
let mut counters = self.counters.write().unwrap();
{
let mut index = 0usize;
let mut last;
while try!(self.backing.get({
let mut r = RlpStream::new_list(2);
let mut r = RlpStream::new_list(3);
r.append(&now);
r.append(&index);
r.append(&&PADDING[..]);
last = r.drain();
&last
})).is_some() {
@@ -154,9 +159,10 @@ impl JournalDB {
let mut to_remove: Vec<H256> = Vec::new();
let mut canon_inserts: Vec<H256> = Vec::new();
while let Some(rlp_data) = try!(self.backing.get({
let mut r = RlpStream::new_list(2);
let mut r = RlpStream::new_list(3);
r.append(&end_era);
r.append(&index);
r.append(&&PADDING[..]);
last = r.drain();
&last
})) {
@@ -226,16 +232,17 @@ impl JournalDB {
self.backing.get(&key.bytes()).expect("Low-level database error. Some issue with your hard disk?").map(|v| v.to_vec())
}
fn read_counters(db: &DB) -> HashMap<H256, i32> {
fn read_counters(db: &Database) -> HashMap<H256, i32> {
let mut res = HashMap::new();
if let Some(val) = db.get(&LATEST_ERA_KEY).expect("Low-level database error.") {
let mut era = decode::<u64>(&val);
loop {
let mut index = 0usize;
while let Some(rlp_data) = db.get({
let mut r = RlpStream::new_list(2);
let mut r = RlpStream::new_list(3);
r.append(&era);
r.append(&index);
r.append(&&PADDING[..]);
&r.drain()
}).expect("Low-level database error.") {
let rlp = Rlp::new(&rlp_data);
@@ -259,7 +266,7 @@ impl JournalDB {
impl HashDB for JournalDB {
fn keys(&self) -> HashMap<H256, i32> {
let mut ret: HashMap<H256, i32> = HashMap::new();
for (key, _) in self.backing.iterator(IteratorMode::Start) {
for (key, _) in self.backing.iter() {
let h = H256::from_slice(key.deref());
ret.insert(h, 1);
}
@@ -429,12 +436,11 @@ mod tests {
#[test]
fn reopen() {
use rocksdb::DB;
let mut dir = ::std::env::temp_dir();
dir.push(H32::random().hex());
let foo = {
let mut jdb = JournalDB::new(DB::open_default(dir.to_str().unwrap()).unwrap());
let mut jdb = JournalDB::new(dir.to_str().unwrap());
// history is 1
let foo = jdb.insert(b"foo");
jdb.commit(0, &b"0".sha3(), None).unwrap();
@@ -442,13 +448,13 @@ mod tests {
};
{
let mut jdb = JournalDB::new(DB::open_default(dir.to_str().unwrap()).unwrap());
let mut jdb = JournalDB::new(dir.to_str().unwrap());
jdb.remove(&foo);
jdb.commit(1, &b"1".sha3(), Some((0, b"0".sha3()))).unwrap();
}
{
let mut jdb = JournalDB::new(DB::open_default(dir.to_str().unwrap()).unwrap());
let mut jdb = JournalDB::new(dir.to_str().unwrap());
assert!(jdb.exists(&foo));
jdb.commit(2, &b"2".sha3(), Some((1, b"1".sha3()))).unwrap();
assert!(!jdb.exists(&foo));

View File

@@ -333,7 +333,9 @@ pub struct KeyFileContent {
/// Holds cypher and decrypt function settings.
pub crypto: KeyFileCrypto,
/// The identifier.
pub id: Uuid
pub id: Uuid,
/// Account (if present)
pub account: Option<Address>,
}
#[derive(Debug)]
@@ -374,7 +376,19 @@ impl KeyFileContent {
KeyFileContent {
id: new_uuid(),
version: KeyFileVersion::V3(3),
crypto: crypto
crypto: crypto,
account: None
}
}
/// Loads key from valid json, returns error and records warning if key is mallformed
pub fn load(json: &Json) -> Result<KeyFileContent, ()> {
match Self::from_json(json) {
Ok(key_file) => Ok(key_file),
Err(e) => {
warn!(target: "sstore", "Error parsing json for key: {:?}", e);
Err(())
}
}
}
@@ -407,6 +421,9 @@ impl KeyFileContent {
Ok(id) => id
};
let account = as_object.get("address").and_then(|json| json.as_string()).and_then(
|account_text| match Address::from_str(account_text) { Ok(account) => Some(account), Err(_) => None });
let crypto = match as_object.get("crypto") {
None => { return Err(KeyFileParseError::NoCryptoSection); }
Some(crypto_json) => match KeyFileCrypto::from_json(crypto_json) {
@@ -418,7 +435,8 @@ impl KeyFileContent {
Ok(KeyFileContent {
version: version,
id: id.clone(),
crypto: crypto
crypto: crypto,
account: account
})
}
@@ -427,6 +445,7 @@ impl KeyFileContent {
map.insert("id".to_owned(), Json::String(uuid_to_string(&self.id)));
map.insert("version".to_owned(), Json::U64(CURRENT_DECLARED_VERSION));
map.insert("crypto".to_owned(), self.crypto.to_json());
if let Some(ref address) = self.account { map.insert("address".to_owned(), Json::String(format!("{:?}", address))); }
Json::Object(map)
}
}
@@ -599,6 +618,8 @@ impl KeyDirectory {
Err(_) => Err(KeyFileLoadError::ParseError(KeyFileParseError::InvalidJson))
}
}
}
@@ -653,7 +674,7 @@ mod file_tests {
}
#[test]
fn can_read_scrypt_krf() {
fn can_read_scrypt_kdf() {
let json = Json::from_str(
r#"
{
@@ -689,6 +710,47 @@ mod file_tests {
}
}
#[test]
fn can_read_scrypt_kdf_params() {
let json = Json::from_str(
r#"
{
"crypto" : {
"cipher" : "aes-128-ctr",
"cipherparams" : {
"iv" : "83dbcc02d8ccb40e466191a123791e0e"
},
"ciphertext" : "d172bf743a674da9cdad04534d56926ef8358534d458fffccd4e6ad2fbde479c",
"kdf" : "scrypt",
"kdfparams" : {
"dklen" : 32,
"n" : 262144,
"r" : 1,
"p" : 8,
"salt" : "ab0c7876052600dd703518d6fc3fe8984592145b591fc8fb5c6d43190334ba19"
},
"mac" : "2103ac29920d71da29f15d75b4a16dbe95cfd7ff8faea1056c33131d846e3097"
},
"id" : "3198bc9c-6672-5ab3-d995-4942343ae5b6",
"version" : 3
}
"#).unwrap();
match KeyFileContent::from_json(&json) {
Ok(key_file) => {
match key_file.crypto.kdf {
KeyFileKdf::Scrypt(scrypt_params) => {
assert_eq!(262144, scrypt_params.n);
assert_eq!(1, scrypt_params.r);
assert_eq!(8, scrypt_params.p);
},
_ => { panic!("expected kdf params of crypto to be of scrypt type" ); }
}
},
Err(e) => panic!("Error parsing valid file: {:?}", e)
}
}
#[test]
fn can_return_error_no_id() {
let json = Json::from_str(
@@ -844,7 +906,7 @@ mod file_tests {
panic!("Should be error of no identifier, got ok");
},
Err(KeyFileParseError::Crypto(CryptoParseError::Scrypt(_))) => { },
Err(other_error) => { panic!("should be error of no identifier, got {:?}", other_error); }
Err(other_error) => { panic!("should be scrypt parse error, got {:?}", other_error); }
}
}

View File

@@ -0,0 +1,165 @@
// Copyright 2015, 2016 Ethcore (UK) Ltd.
// This file is part of Parity.
// Parity is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
// Parity is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
// You should have received a copy of the GNU General Public License
// along with Parity. If not, see <http://www.gnu.org/licenses/>.
//! Geth keys import/export tool
use common::*;
use keys::store::SecretStore;
use keys::directory::KeyFileContent;
/// Enumerates all geth keys in the directory and returns collection of tuples `(accountId, filename)`
pub fn enumerate_geth_keys(path: &Path) -> Result<Vec<(Address, String)>, io::Error> {
let mut entries = Vec::new();
for entry in try!(fs::read_dir(path)) {
let entry = try!(entry);
if !try!(fs::metadata(entry.path())).is_dir() {
match entry.file_name().to_str() {
Some(name) => {
let parts: Vec<&str> = name.split("--").collect();
if parts.len() != 3 { continue; }
match Address::from_str(parts[2]) {
Ok(account_id) => { entries.push((account_id, name.to_owned())); }
Err(e) => { panic!("error: {:?}", e); }
}
},
None => { continue; }
};
}
}
Ok(entries)
}
/// Geth import error
#[derive(Debug)]
pub enum ImportError {
/// Io error reading geth file
IoError(io::Error),
/// format error
FormatError,
}
impl From<io::Error> for ImportError {
fn from (err: io::Error) -> ImportError {
ImportError::IoError(err)
}
}
/// Imports one geth key to the store
pub fn import_geth_key(secret_store: &mut SecretStore, geth_keyfile_path: &Path) -> Result<(), ImportError> {
let mut file = try!(fs::File::open(geth_keyfile_path));
let mut buf = String::new();
try!(file.read_to_string(&mut buf));
let mut json_result = Json::from_str(&buf);
let mut json = match json_result {
Ok(ref mut parsed_json) => try!(parsed_json.as_object_mut().ok_or(ImportError::FormatError)),
Err(_) => { return Err(ImportError::FormatError); }
};
let crypto_object = try!(json.get("Crypto").and_then(|crypto| crypto.as_object()).ok_or(ImportError::FormatError)).clone();
json.insert("crypto".to_owned(), Json::Object(crypto_object));
json.remove("Crypto");
match KeyFileContent::load(&Json::Object(json.clone())) {
Ok(key_file) => try!(secret_store.import_key(key_file)),
Err(_) => { return Err(ImportError::FormatError); }
};
Ok(())
}
/// Imports all geth keys in the directory
pub fn import_geth_keys(secret_store: &mut SecretStore, geth_keyfiles_directory: &Path) -> Result<(), ImportError> {
use std::path::PathBuf;
let geth_files = try!(enumerate_geth_keys(geth_keyfiles_directory));
for &(ref address, ref file_path) in geth_files.iter() {
let mut path = PathBuf::new();
path.push(geth_keyfiles_directory);
path.push(file_path);
if let Err(e) = import_geth_key(secret_store, Path::new(&path)) {
warn!("Skipped geth address {}, error importing: {:?}", address, e)
}
}
Ok(())
}
#[cfg(test)]
mod tests {
use super::*;
use common::*;
use keys::store::SecretStore;
#[test]
fn can_enumerate() {
let keys = enumerate_geth_keys(Path::new("res/geth_keystore")).unwrap();
assert_eq!(2, keys.len());
}
#[test]
fn can_import() {
let temp = ::devtools::RandomTempPath::create_dir();
let mut secret_store = SecretStore::new_in(temp.as_path());
import_geth_key(&mut secret_store, Path::new("res/geth_keystore/UTC--2016-02-17T09-20-45.721400158Z--3f49624084b67849c7b4e805c5988c21a430f9d9")).unwrap();
let key = secret_store.account(&Address::from_str("3f49624084b67849c7b4e805c5988c21a430f9d9").unwrap());
assert!(key.is_some());
}
#[test]
fn can_import_directory() {
let temp = ::devtools::RandomTempPath::create_dir();
let mut secret_store = SecretStore::new_in(temp.as_path());
import_geth_keys(&mut secret_store, Path::new("res/geth_keystore")).unwrap();
let key = secret_store.account(&Address::from_str("3f49624084b67849c7b4e805c5988c21a430f9d9").unwrap());
assert!(key.is_some());
let key = secret_store.account(&Address::from_str("5ba4dcf897e97c2bdf8315b9ef26c13c085988cf").unwrap());
assert!(key.is_some());
}
#[test]
fn imports_as_scrypt_keys() {
use keys::directory::{KeyDirectory, KeyFileKdf};
let temp = ::devtools::RandomTempPath::create_dir();
{
let mut secret_store = SecretStore::new_in(temp.as_path());
import_geth_keys(&mut secret_store, Path::new("res/geth_keystore")).unwrap();
}
let key_directory = KeyDirectory::new(&temp.as_path());
let key_file = key_directory.get(&H128::from_str("62a0ad73556d496a8e1c0783d30d3ace").unwrap()).unwrap();
match key_file.crypto.kdf {
KeyFileKdf::Scrypt(scrypt_params) => {
assert_eq!(262144, scrypt_params.n);
assert_eq!(8, scrypt_params.r);
assert_eq!(1, scrypt_params.p);
},
_ => { panic!("expected kdf params of crypto to be of scrypt type" ); }
}
}
#[test]
fn can_decrypt_with_imported() {
use keys::store::EncryptedHashMap;
let temp = ::devtools::RandomTempPath::create_dir();
let mut secret_store = SecretStore::new_in(temp.as_path());
import_geth_keys(&mut secret_store, Path::new("res/geth_keystore")).unwrap();
let val = secret_store.get::<Bytes>(&H128::from_str("62a0ad73556d496a8e1c0783d30d3ace").unwrap(), "123");
assert!(val.is_ok());
assert_eq!(32, val.unwrap().len());
}
}

View File

@@ -18,3 +18,4 @@
pub mod directory;
pub mod store;
mod geth_import;

View File

@@ -19,11 +19,12 @@
use keys::directory::*;
use common::*;
use rcrypto::pbkdf2::*;
use rcrypto::scrypt::*;
use rcrypto::hmac::*;
use crypto;
const KEY_LENGTH: u32 = 32;
const KEY_ITERATIONS: u32 = 4096;
const KEY_ITERATIONS: u32 = 10240;
const KEY_LENGTH_AES: u32 = KEY_LENGTH/2;
const KEY_LENGTH_USIZE: usize = KEY_LENGTH as usize;
@@ -60,15 +61,62 @@ pub struct SecretStore {
}
impl SecretStore {
/// new instance of Secret Store
/// new instance of Secret Store in default home directory
pub fn new() -> SecretStore {
let mut path = ::std::env::home_dir().expect("Failed to get home dir");
path.push(".keys");
path.push(".parity");
path.push("keys");
Self::new_in(&path)
}
/// new instance of Secret Store in specific directory
pub fn new_in(path: &Path) -> SecretStore {
SecretStore {
directory: KeyDirectory::new(&path)
directory: KeyDirectory::new(path)
}
}
/// trys to import keys in the known locations
pub fn try_import_existing(&mut self) {
use std::path::PathBuf;
use keys::geth_import;
let mut import_path = PathBuf::new();
import_path.push(::std::env::home_dir().expect("Failed to get home dir"));
import_path.push(".ethereum");
import_path.push("keystore");
if let Err(e) = geth_import::import_geth_keys(self, &import_path) {
warn!(target: "sstore", "Error retrieving geth keys: {:?}", e)
}
}
/// Lists all accounts and corresponding key ids
pub fn accounts(&self) -> Result<Vec<(Address, H128)>, ::std::io::Error> {
let accounts = try!(self.directory.list()).iter().map(|key_id| self.directory.get(key_id))
.filter(|key| key.is_some())
.map(|key| { let some_key = key.unwrap(); (some_key.account, some_key.id) })
.filter(|&(ref account, _)| account.is_some())
.map(|(account, id)| (account.unwrap(), id))
.collect::<Vec<(Address, H128)>>();
Ok(accounts)
}
/// Resolves key_id by account address
pub fn account(&self, account: &Address) -> Option<H128> {
let mut accounts = match self.accounts() {
Ok(accounts) => accounts,
Err(e) => { warn!(target: "sstore", "Failed to load accounts: {}", e); return None; }
};
accounts.retain(|&(ref store_account, _)| account == store_account);
accounts.first().and_then(|&(_, ref key_id)| Some(key_id.clone()))
}
/// Imports pregenerated key, returns error if not saved correctly
pub fn import_key(&mut self, key_file: KeyFileContent) -> Result<(), ::std::io::Error> {
try!(self.directory.save(key_file));
Ok(())
}
#[cfg(test)]
fn new_test(path: &::devtools::RandomTempPath) -> SecretStore {
SecretStore {
@@ -90,6 +138,15 @@ fn derive_key(password: &str, salt: &H256) -> (Bytes, Bytes) {
derive_key_iterations(password, salt, KEY_ITERATIONS)
}
fn derive_key_scrypt(password: &str, salt: &H256, n: u32, p: u32, r: u32) -> (Bytes, Bytes) {
let mut derived_key = vec![0u8; KEY_LENGTH_USIZE];
let scrypt_params = ScryptParams::new(n.trailing_zeros() as u8, r, p);
scrypt(password.as_bytes(), &salt.as_slice(), &scrypt_params, &mut derived_key);
let derived_right_bits = &derived_key[0..KEY_LENGTH_AES_USIZE];
let derived_left_bits = &derived_key[KEY_LENGTH_AES_USIZE..KEY_LENGTH_USIZE];
(derived_right_bits.to_vec(), derived_left_bits.to_vec())
}
fn derive_mac(derived_left_bits: &[u8], cipher_text: &[u8]) -> Bytes {
let mut mac = vec![0u8; KEY_LENGTH_AES_USIZE + cipher_text.len()];
mac[0..KEY_LENGTH_AES_USIZE].clone_from_slice(derived_left_bits);
@@ -101,24 +158,22 @@ impl EncryptedHashMap<H128> for SecretStore {
fn get<Value: FromRawBytes + BytesConvertable>(&self, key: &H128, password: &str) -> Result<Value, EncryptedHashMapError> {
match self.directory.get(key) {
Some(key_file) => {
let decrypted_bytes = match key_file.crypto.kdf {
KeyFileKdf::Pbkdf2(ref params) => {
let (derived_left_bits, derived_right_bits) = derive_key_iterations(password, &params.salt, params.c);
if derive_mac(&derived_right_bits, &key_file.crypto.cipher_text)
.sha3() != key_file.crypto.mac { return Err(EncryptedHashMapError::InvalidPassword); }
let mut val = vec![0u8; key_file.crypto.cipher_text.len()];
match key_file.crypto.cipher_type {
CryptoCipherType::Aes128Ctr(ref iv) => {
crypto::aes::decrypt(&derived_left_bits, &iv.as_slice(), &key_file.crypto.cipher_text, &mut val);
}
}
val
}
_ => { unimplemented!(); }
let (derived_left_bits, derived_right_bits) = match key_file.crypto.kdf {
KeyFileKdf::Pbkdf2(ref params) => derive_key_iterations(password, &params.salt, params.c),
KeyFileKdf::Scrypt(ref params) => derive_key_scrypt(password, &params.salt, params.n, params.p, params.r)
};
match Value::from_bytes(&decrypted_bytes) {
if derive_mac(&derived_right_bits, &key_file.crypto.cipher_text)
.sha3() != key_file.crypto.mac { return Err(EncryptedHashMapError::InvalidPassword); }
let mut val = vec![0u8; key_file.crypto.cipher_text.len()];
match key_file.crypto.cipher_type {
CryptoCipherType::Aes128Ctr(ref iv) => {
crypto::aes::decrypt(&derived_left_bits, &iv.as_slice(), &key_file.crypto.cipher_text, &mut val);
}
};
match Value::from_bytes(&val) {
Ok(value) => Ok(value),
Err(bytes_error) => Err(EncryptedHashMapError::InvalidValueFormat(bytes_error))
}
@@ -259,6 +314,27 @@ mod tests {
result
}
fn pregenerate_accounts(temp: &RandomTempPath, count: usize) -> Vec<H128> {
use keys::directory::{KeyFileContent, KeyFileCrypto};
let mut write_sstore = SecretStore::new_test(&temp);
let mut result = Vec::new();
for i in 0..count {
let mut key_file =
KeyFileContent::new(
KeyFileCrypto::new_pbkdf2(
FromHex::from_hex("5318b4d5bcd28de64ee5559e671353e16f075ecae9f99c7a79a38af5f869aa46").unwrap(),
H128::from_str("6087dab2f9fdbbfaddc31a909735c1e6").unwrap(),
H256::from_str("ae3cd4e7013836a3df6bd7241b12db061dbe2c6785853cce422d148a624ce0bd").unwrap(),
H256::from_str("517ead924a9d0dc3124507e3393d175ce3ff7c1e96529c6c555ce9e51205e9b2").unwrap(),
262144,
32));
key_file.account = Some(x!(i as u64));
result.push(key_file.id.clone());
write_sstore.import_key(key_file).unwrap();
}
result
}
#[test]
fn can_get() {
let temp = RandomTempPath::create_dir();
@@ -293,5 +369,35 @@ mod tests {
assert_eq!(4, sstore.directory.list().unwrap().len())
}
#[test]
fn can_import_account() {
use keys::directory::{KeyFileContent, KeyFileCrypto};
let temp = RandomTempPath::create_dir();
let mut key_file =
KeyFileContent::new(
KeyFileCrypto::new_pbkdf2(
FromHex::from_hex("5318b4d5bcd28de64ee5559e671353e16f075ecae9f99c7a79a38af5f869aa46").unwrap(),
H128::from_str("6087dab2f9fdbbfaddc31a909735c1e6").unwrap(),
H256::from_str("ae3cd4e7013836a3df6bd7241b12db061dbe2c6785853cce422d148a624ce0bd").unwrap(),
H256::from_str("517ead924a9d0dc3124507e3393d175ce3ff7c1e96529c6c555ce9e51205e9b2").unwrap(),
262144,
32));
key_file.account = Some(Address::from_str("3f49624084b67849c7b4e805c5988c21a430f9d9").unwrap());
let mut sstore = SecretStore::new_test(&temp);
sstore.import_key(key_file).unwrap();
assert_eq!(1, sstore.accounts().unwrap().len());
assert!(sstore.account(&Address::from_str("3f49624084b67849c7b4e805c5988c21a430f9d9").unwrap()).is_some());
}
#[test]
fn can_list_accounts() {
let temp = RandomTempPath::create_dir();
pregenerate_accounts(&temp, 30);
let sstore = SecretStore::new_test(&temp);
let accounts = sstore.accounts().unwrap();
assert_eq!(30, accounts.len());
}
}

206
util/src/kvdb.rs Normal file
View File

@@ -0,0 +1,206 @@
// Copyright 2015, 2016 Ethcore (UK) Ltd.
// This file is part of Parity.
// Parity is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
// Parity is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
// You should have received a copy of the GNU General Public License
// along with Parity. If not, see <http://www.gnu.org/licenses/>.
//! Key-Value store abstraction with RocksDB backend.
use rocksdb::{DB, Writable, WriteBatch, IteratorMode, DBVector, DBIterator,
IndexType, Options, DBCompactionStyle, BlockBasedOptions, Direction};
/// Write transaction. Batches a sequence of put/delete operations for efficiency.
pub struct DBTransaction {
batch: WriteBatch,
}
impl DBTransaction {
/// Create new transaction.
pub fn new() -> DBTransaction {
DBTransaction { batch: WriteBatch::new() }
}
/// Insert a key-value pair in the transaction. Any existing value value will be overwritten upon write.
pub fn put(&self, key: &[u8], value: &[u8]) -> Result<(), String> {
self.batch.put(key, value)
}
/// Delete value by key.
pub fn delete(&self, key: &[u8]) -> Result<(), String> {
self.batch.delete(key)
}
}
/// Database configuration
pub struct DatabaseConfig {
/// Optional prefix size in bytes. Allows lookup by partial key.
pub prefix_size: Option<usize>
}
/// Database iterator
pub struct DatabaseIterator<'a> {
iter: DBIterator<'a>,
}
impl<'a> Iterator for DatabaseIterator<'a> {
type Item = (Box<[u8]>, Box<[u8]>);
#[cfg_attr(feature="dev", allow(type_complexity))]
fn next(&mut self) -> Option<(Box<[u8]>, Box<[u8]>)> {
self.iter.next()
}
}
/// Key-Value database.
pub struct Database {
db: DB,
}
impl Database {
/// Open database with default settings.
pub fn open_default(path: &str) -> Result<Database, String> {
Database::open(&DatabaseConfig { prefix_size: None }, path)
}
/// Open database file. Creates if it does not exist.
pub fn open(config: &DatabaseConfig, path: &str) -> Result<Database, String> {
let mut opts = Options::new();
opts.set_max_open_files(256);
opts.create_if_missing(true);
opts.set_use_fsync(false);
opts.set_compaction_style(DBCompactionStyle::DBUniversalCompaction);
/*
opts.set_bytes_per_sync(8388608);
opts.set_disable_data_sync(false);
opts.set_block_cache_size_mb(1024);
opts.set_table_cache_num_shard_bits(6);
opts.set_max_write_buffer_number(32);
opts.set_write_buffer_size(536870912);
opts.set_target_file_size_base(1073741824);
opts.set_min_write_buffer_number_to_merge(4);
opts.set_level_zero_stop_writes_trigger(2000);
opts.set_level_zero_slowdown_writes_trigger(0);
opts.set_compaction_style(DBUniversalCompaction);
opts.set_max_background_compactions(4);
opts.set_max_background_flushes(4);
opts.set_filter_deletes(false);
opts.set_disable_auto_compactions(false);
*/
if let Some(size) = config.prefix_size {
let mut block_opts = BlockBasedOptions::new();
block_opts.set_index_type(IndexType::HashSearch);
opts.set_block_based_table_factory(&block_opts);
opts.set_prefix_extractor_fixed_size(size);
}
let db = try!(DB::open(&opts, path));
Ok(Database { db: db })
}
/// Insert a key-value pair in the transaction. Any existing value value will be overwritten.
pub fn put(&self, key: &[u8], value: &[u8]) -> Result<(), String> {
self.db.put(key, value)
}
/// Delete value by key.
pub fn delete(&self, key: &[u8]) -> Result<(), String> {
self.db.delete(key)
}
/// Commit transaction to database.
pub fn write(&self, tr: DBTransaction) -> Result<(), String> {
self.db.write(tr.batch)
}
/// Get value by key.
pub fn get(&self, key: &[u8]) -> Result<Option<DBVector>, String> {
self.db.get(key)
}
/// Get value by partial key. Prefix size should match configured prefix size.
pub fn get_by_prefix(&self, prefix: &[u8]) -> Option<Box<[u8]>> {
let mut iter = self.db.iterator(IteratorMode::From(prefix, Direction::forward));
match iter.next() {
// TODO: use prefix_same_as_start read option (not availabele in C API currently)
Some((k, v)) => if k[0 .. prefix.len()] == prefix[..] { Some(v) } else { None },
_ => None
}
}
/// Check if there is anything in the database.
pub fn is_empty(&self) -> bool {
self.db.iterator(IteratorMode::Start).next().is_none()
}
/// Check if there is anything in the database.
pub fn iter(&self) -> DatabaseIterator {
DatabaseIterator { iter: self.db.iterator(IteratorMode::Start) }
}
}
#[cfg(test)]
mod tests {
use hash::*;
use super::*;
use devtools::*;
use std::str::FromStr;
use std::ops::Deref;
fn test_db(config: &DatabaseConfig) {
let path = RandomTempPath::create_dir();
let db = Database::open(config, path.as_path().to_str().unwrap()).unwrap();
let key1 = H256::from_str("02c69be41d0b7e40352fc85be1cd65eb03d40ef8427a0ca4596b1ead9a00e9fc").unwrap();
let key2 = H256::from_str("03c69be41d0b7e40352fc85be1cd65eb03d40ef8427a0ca4596b1ead9a00e9fc").unwrap();
let key3 = H256::from_str("01c69be41d0b7e40352fc85be1cd65eb03d40ef8427a0ca4596b1ead9a00e9fc").unwrap();
db.put(&key1, b"cat").unwrap();
db.put(&key2, b"dog").unwrap();
assert_eq!(db.get(&key1).unwrap().unwrap().deref(), b"cat");
let contents: Vec<_> = db.iter().collect();
assert_eq!(contents.len(), 2);
assert_eq!(&*contents[0].0, key1.deref());
assert_eq!(&*contents[0].1, b"cat");
assert_eq!(&*contents[1].0, key2.deref());
assert_eq!(&*contents[1].1, b"dog");
db.delete(&key1).unwrap();
assert!(db.get(&key1).unwrap().is_none());
db.put(&key1, b"cat").unwrap();
let transaction = DBTransaction::new();
transaction.put(&key3, b"elephant").unwrap();
transaction.delete(&key1).unwrap();
db.write(transaction).unwrap();
assert!(db.get(&key1).unwrap().is_none());
assert_eq!(db.get(&key3).unwrap().unwrap().deref(), b"elephant");
if config.prefix_size.is_some() {
assert_eq!(db.get_by_prefix(&key3).unwrap().deref(), b"elephant");
assert_eq!(db.get_by_prefix(&key2).unwrap().deref(), b"dog");
}
}
#[test]
fn kvdb() {
let path = RandomTempPath::create_dir();
let smoke = Database::open_default(path.as_path().to_str().unwrap()).unwrap();
assert!(smoke.is_empty());
test_db(&DatabaseConfig { prefix_size: None });
test_db(&DatabaseConfig { prefix_size: Some(1) });
test_db(&DatabaseConfig { prefix_size: Some(8) });
test_db(&DatabaseConfig { prefix_size: Some(32) });
}
}

View File

@@ -130,8 +130,8 @@ pub mod hashdb;
pub mod memorydb;
pub mod overlaydb;
pub mod journaldb;
pub mod kvdb;
mod math;
pub mod chainfilter;
pub mod crypto;
pub mod triehash;
pub mod trie;
@@ -154,7 +154,6 @@ pub use memorydb::*;
pub use overlaydb::*;
pub use journaldb::*;
pub use math::*;
pub use chainfilter::*;
pub use crypto::*;
pub use triehash::*;
pub use trie::*;
@@ -164,4 +163,5 @@ pub use semantic_version::*;
pub use network::*;
pub use io::*;
pub use log::*;
pub use kvdb::*;

View File

@@ -106,6 +106,7 @@ const IDLE: usize = LAST_HANDSHAKE + 2;
const DISCOVERY: usize = LAST_HANDSHAKE + 3;
const DISCOVERY_REFRESH: usize = LAST_HANDSHAKE + 4;
const DISCOVERY_ROUND: usize = LAST_HANDSHAKE + 5;
const INIT_PUBLIC: usize = LAST_HANDSHAKE + 6;
const FIRST_SESSION: usize = 0;
const LAST_SESSION: usize = FIRST_SESSION + MAX_SESSIONS - 1;
const FIRST_HANDSHAKE: usize = LAST_SESSION + 1;
@@ -261,7 +262,9 @@ pub struct HostInfo {
/// TCP connection port.
pub listen_port: u16,
/// Registered capabilities (handlers)
pub capabilities: Vec<CapabilityInfo>
pub capabilities: Vec<CapabilityInfo>,
/// Public address + discovery port
public_endpoint: NodeEndpoint,
}
impl HostInfo {
@@ -294,16 +297,15 @@ struct ProtocolTimer {
/// Root IO handler. Manages protocol handlers, IO timers and network connections.
pub struct Host<Message> where Message: Send + Sync + Clone {
pub info: RwLock<HostInfo>,
tcp_listener: Mutex<TcpListener>,
tcp_listener: Mutex<Option<TcpListener>>,
handshakes: Arc<RwLock<Slab<SharedHandshake>>>,
sessions: Arc<RwLock<Slab<SharedSession>>>,
discovery: Option<Mutex<Discovery>>,
discovery: Mutex<Option<Discovery>>,
nodes: RwLock<NodeTable>,
handlers: RwLock<HashMap<ProtocolId, Arc<NetworkProtocolHandler<Message>>>>,
timers: RwLock<HashMap<TimerToken, ProtocolTimer>>,
timer_counter: RwLock<usize>,
stats: Arc<NetworkStats>,
public_endpoint: NodeEndpoint,
pinned_nodes: Vec<NodeId>,
}
@@ -316,27 +318,6 @@ impl<Message> Host<Message> where Message: Send + Sync + Clone {
};
let udp_port = config.udp_port.unwrap_or(listen_address.port());
let public_endpoint = match config.public_address {
None => {
let public_address = select_public_address(listen_address.port());
let local_endpoint = NodeEndpoint { address: public_address, udp_port: udp_port };
if config.nat_enabled {
match map_external_address(&local_endpoint) {
Some(endpoint) => {
info!("NAT Mappped to external address {}", endpoint.address);
endpoint
},
None => local_endpoint
}
} else {
local_endpoint
}
}
Some(addr) => NodeEndpoint { address: addr, udp_port: udp_port }
};
// Setup the server socket
let tcp_listener = TcpListener::bind(&listen_address).unwrap();
let keys = if let Some(ref secret) = config.use_secret {
KeyPair::from_secret(secret.clone()).unwrap()
} else {
@@ -350,10 +331,8 @@ impl<Message> Host<Message> where Message: Send + Sync + Clone {
},
|s| KeyPair::from_secret(s).expect("Error creating node secret key"))
};
let discovery = if config.discovery_enabled && !config.pin {
Some(Discovery::new(&keys, listen_address.clone(), public_endpoint.clone(), DISCOVERY))
} else { None };
let path = config.config_path.clone();
let local_endpoint = NodeEndpoint { address: listen_address, udp_port: udp_port };
let mut host = Host::<Message> {
info: RwLock::new(HostInfo {
keys: keys,
@@ -363,9 +342,10 @@ impl<Message> Host<Message> where Message: Send + Sync + Clone {
client_version: version(),
listen_port: 0,
capabilities: Vec::new(),
public_endpoint: local_endpoint, // will be replaced by public once it is resolved
}),
discovery: discovery.map(Mutex::new),
tcp_listener: Mutex::new(tcp_listener),
discovery: Mutex::new(None),
tcp_listener: Mutex::new(None),
handshakes: Arc::new(RwLock::new(Slab::new_starting_at(FIRST_HANDSHAKE, MAX_HANDSHAKES))),
sessions: Arc::new(RwLock::new(Slab::new_starting_at(FIRST_SESSION, MAX_SESSIONS))),
nodes: RwLock::new(NodeTable::new(path)),
@@ -373,16 +353,12 @@ impl<Message> Host<Message> where Message: Send + Sync + Clone {
timers: RwLock::new(HashMap::new()),
timer_counter: RwLock::new(USER_TIMER),
stats: Arc::new(NetworkStats::default()),
public_endpoint: public_endpoint,
pinned_nodes: Vec::new(),
};
let port = listen_address.port();
host.info.write().unwrap().deref_mut().listen_port = port;
let boot_nodes = host.info.read().unwrap().config.boot_nodes.clone();
if let Some(ref mut discovery) = host.discovery {
discovery.lock().unwrap().init_node_list(host.nodes.read().unwrap().unordered_entries());
}
for n in boot_nodes {
host.add_node(&n);
}
@@ -400,8 +376,8 @@ impl<Message> Host<Message> where Message: Send + Sync + Clone {
let entry = NodeEntry { endpoint: n.endpoint.clone(), id: n.id.clone() };
self.pinned_nodes.push(n.id.clone());
self.nodes.write().unwrap().add_node(n);
if let Some(ref mut discovery) = self.discovery {
discovery.lock().unwrap().add_node(entry);
if let &mut Some(ref mut discovery) = self.discovery.lock().unwrap().deref_mut() {
discovery.add_node(entry);
}
}
}
@@ -412,7 +388,61 @@ impl<Message> Host<Message> where Message: Send + Sync + Clone {
}
pub fn client_url(&self) -> String {
format!("{}", Node::new(self.info.read().unwrap().id().clone(), self.public_endpoint.clone()))
format!("{}", Node::new(self.info.read().unwrap().id().clone(), self.info.read().unwrap().public_endpoint.clone()))
}
fn init_public_interface(&self, io: &IoContext<NetworkIoMessage<Message>>) {
io.clear_timer(INIT_PUBLIC).unwrap();
let mut tcp_listener = self.tcp_listener.lock().unwrap();
if tcp_listener.is_some() {
return;
}
// public_endpoint in host info contains local adderss at this point
let listen_address = self.info.read().unwrap().public_endpoint.address.clone();
let udp_port = self.info.read().unwrap().config.udp_port.unwrap_or(listen_address.port());
let public_endpoint = match self.info.read().unwrap().config.public_address {
None => {
let public_address = select_public_address(listen_address.port());
let local_endpoint = NodeEndpoint { address: public_address, udp_port: udp_port };
if self.info.read().unwrap().config.nat_enabled {
match map_external_address(&local_endpoint) {
Some(endpoint) => {
info!("NAT mappped to external address {}", endpoint.address);
endpoint
},
None => local_endpoint
}
} else {
local_endpoint
}
}
Some(addr) => NodeEndpoint { address: addr, udp_port: udp_port }
};
// Setup the server socket
*tcp_listener = Some(TcpListener::bind(&listen_address).unwrap());
self.info.write().unwrap().public_endpoint = public_endpoint.clone();
io.register_stream(TCP_ACCEPT).expect("Error registering TCP listener");
info!("Public node URL: {}", self.client_url());
// Initialize discovery.
let discovery = {
let info = self.info.read().unwrap();
if info.config.discovery_enabled && !info.config.pin {
Some(Discovery::new(&info.keys, listen_address.clone(), public_endpoint, DISCOVERY))
} else { None }
};
if let Some(mut discovery) = discovery {
discovery.init_node_list(self.nodes.read().unwrap().unordered_entries());
for n in self.nodes.read().unwrap().unordered_entries() {
discovery.add_node(n.clone());
}
io.register_stream(DISCOVERY).expect("Error registering UDP listener");
io.register_timer(DISCOVERY_REFRESH, 7200).expect("Error registering discovery timer");
io.register_timer(DISCOVERY_ROUND, 300).expect("Error registering discovery timer");
*self.discovery.lock().unwrap().deref_mut() = Some(discovery);
}
}
fn maintain_network(&self, io: &IoContext<NetworkIoMessage<Message>>) {
@@ -526,7 +556,7 @@ impl<Message> Host<Message> where Message: Send + Sync + Clone {
fn accept(&self, io: &IoContext<NetworkIoMessage<Message>>) {
trace!(target: "network", "Accepting incoming connection");
loop {
let socket = match self.tcp_listener.lock().unwrap().accept() {
let socket = match self.tcp_listener.lock().unwrap().as_ref().unwrap().accept() {
Ok(None) => break,
Ok(Some((sock, _addr))) => sock,
Err(e) => {
@@ -579,7 +609,7 @@ impl<Message> Host<Message> where Message: Send + Sync + Clone {
}
}
if kill {
self.kill_connection(token, io, true); //TODO: mark connection as dead an check in kill_connection
self.kill_connection(token, io, true);
return;
} else if create_session {
self.start_session(token, io);
@@ -621,7 +651,7 @@ impl<Message> Host<Message> where Message: Send + Sync + Clone {
}
}
if kill {
self.kill_connection(token, io, true); //TODO: mark connection as dead an check in kill_connection
self.kill_connection(token, io, true);
}
for p in ready_data {
let h = self.handlers.read().unwrap().get(p).unwrap().clone();
@@ -666,8 +696,9 @@ impl<Message> Host<Message> where Message: Send + Sync + Clone {
if let Ok(address) = session.remote_addr() {
let entry = NodeEntry { id: session.id().clone(), endpoint: NodeEndpoint { address: address, udp_port: address.port() } };
self.nodes.write().unwrap().add_node(Node::new(entry.id.clone(), entry.endpoint.clone()));
if let Some(ref discovery) = self.discovery {
discovery.lock().unwrap().add_node(entry);
let mut discovery = self.discovery.lock().unwrap();
if let &mut Some(ref mut discovery) = discovery.deref_mut() {
discovery.add_node(entry);
}
}
}
@@ -685,6 +716,7 @@ impl<Message> Host<Message> where Message: Send + Sync + Clone {
fn kill_connection(&self, token: StreamToken, io: &IoContext<NetworkIoMessage<Message>>, remote: bool) {
let mut to_disconnect: Vec<ProtocolId> = Vec::new();
let mut failure_id = None;
let mut deregister = false;
match token {
FIRST_HANDSHAKE ... LAST_HANDSHAKE => {
let handshakes = self.handshakes.write().unwrap();
@@ -693,7 +725,7 @@ impl<Message> Host<Message> where Message: Send + Sync + Clone {
if !handshake.expired() {
handshake.set_expired();
failure_id = Some(handshake.id().clone());
io.deregister_stream(token).expect("Error deregistering stream");
deregister = true;
}
}
},
@@ -711,7 +743,7 @@ impl<Message> Host<Message> where Message: Send + Sync + Clone {
}
s.set_expired();
failure_id = Some(s.id().clone());
io.deregister_stream(token).expect("Error deregistering stream");
deregister = true;
}
}
},
@@ -726,6 +758,9 @@ impl<Message> Host<Message> where Message: Send + Sync + Clone {
let h = self.handlers.read().unwrap().get(p).unwrap().clone();
h.disconnected(&NetworkContext::new(io, p, Some(token), self.sessions.clone()), &token);
}
if deregister {
io.deregister_stream(token).expect("Error deregistering stream");
}
}
fn update_nodes(&self, io: &IoContext<NetworkIoMessage<Message>>, node_changes: TableUpdates) {
@@ -760,13 +795,8 @@ impl<Message> Host<Message> where Message: Send + Sync + Clone {
impl<Message> IoHandler<NetworkIoMessage<Message>> for Host<Message> where Message: Send + Sync + Clone + 'static {
/// Initialize networking
fn initialize(&self, io: &IoContext<NetworkIoMessage<Message>>) {
io.register_stream(TCP_ACCEPT).expect("Error registering TCP listener");
io.register_timer(IDLE, MAINTENANCE_TIMEOUT).expect("Error registering Network idle timer");
if self.discovery.is_some() {
io.register_stream(DISCOVERY).expect("Error registering UDP listener");
io.register_timer(DISCOVERY_REFRESH, 7200).expect("Error registering discovery timer");
io.register_timer(DISCOVERY_ROUND, 300).expect("Error registering discovery timer");
}
io.register_timer(INIT_PUBLIC, 0).expect("Error registering initialization timer");
self.maintain_network(io)
}
@@ -784,7 +814,7 @@ impl<Message> IoHandler<NetworkIoMessage<Message>> for Host<Message> where Messa
FIRST_SESSION ... LAST_SESSION => self.session_readable(stream, io),
FIRST_HANDSHAKE ... LAST_HANDSHAKE => self.handshake_readable(stream, io),
DISCOVERY => {
let node_changes = { self.discovery.as_ref().unwrap().lock().unwrap().readable() };
let node_changes = { self.discovery.lock().unwrap().as_mut().unwrap().readable() };
if let Some(node_changes) = node_changes {
self.update_nodes(io, node_changes);
}
@@ -800,7 +830,7 @@ impl<Message> IoHandler<NetworkIoMessage<Message>> for Host<Message> where Messa
FIRST_SESSION ... LAST_SESSION => self.session_writable(stream, io),
FIRST_HANDSHAKE ... LAST_HANDSHAKE => self.handshake_writable(stream, io),
DISCOVERY => {
self.discovery.as_ref().unwrap().lock().unwrap().writable();
self.discovery.lock().unwrap().as_mut().unwrap().writable();
io.update_registration(DISCOVERY).expect("Error updating discovery registration");
}
_ => panic!("Received unknown writable token"),
@@ -810,14 +840,15 @@ impl<Message> IoHandler<NetworkIoMessage<Message>> for Host<Message> where Messa
fn timeout(&self, io: &IoContext<NetworkIoMessage<Message>>, token: TimerToken) {
match token {
IDLE => self.maintain_network(io),
INIT_PUBLIC => self.init_public_interface(io),
FIRST_SESSION ... LAST_SESSION => self.connection_timeout(token, io),
FIRST_HANDSHAKE ... LAST_HANDSHAKE => self.connection_timeout(token, io),
DISCOVERY_REFRESH => {
self.discovery.as_ref().unwrap().lock().unwrap().refresh();
self.discovery.lock().unwrap().as_mut().unwrap().refresh();
io.update_registration(DISCOVERY).expect("Error updating discovery registration");
},
DISCOVERY_ROUND => {
let node_changes = { self.discovery.as_ref().unwrap().lock().unwrap().round() };
let node_changes = { self.discovery.lock().unwrap().as_mut().unwrap().round() };
if let Some(node_changes) = node_changes {
self.update_nodes(io, node_changes);
}
@@ -892,8 +923,8 @@ impl<Message> IoHandler<NetworkIoMessage<Message>> for Host<Message> where Messa
connection.lock().unwrap().register_socket(reg, event_loop).expect("Error registering socket");
}
}
DISCOVERY => self.discovery.as_ref().unwrap().lock().unwrap().register_socket(event_loop).expect("Error registering discovery socket"),
TCP_ACCEPT => event_loop.register(self.tcp_listener.lock().unwrap().deref(), Token(TCP_ACCEPT), EventSet::all(), PollOpt::edge()).expect("Error registering stream"),
DISCOVERY => self.discovery.lock().unwrap().as_ref().unwrap().register_socket(event_loop).expect("Error registering discovery socket"),
TCP_ACCEPT => event_loop.register(self.tcp_listener.lock().unwrap().as_ref().unwrap(), Token(TCP_ACCEPT), EventSet::all(), PollOpt::edge()).expect("Error registering stream"),
_ => warn!("Unexpected stream registration")
}
}
@@ -915,7 +946,6 @@ impl<Message> IoHandler<NetworkIoMessage<Message>> for Host<Message> where Messa
}
}
DISCOVERY => (),
TCP_ACCEPT => event_loop.deregister(self.tcp_listener.lock().unwrap().deref()).unwrap(),
_ => warn!("Unexpected stream deregistration")
}
}
@@ -934,8 +964,8 @@ impl<Message> IoHandler<NetworkIoMessage<Message>> for Host<Message> where Messa
connection.lock().unwrap().update_socket(reg, event_loop).expect("Error updating socket");
}
}
DISCOVERY => self.discovery.as_ref().unwrap().lock().unwrap().update_registration(event_loop).expect("Error reregistering discovery socket"),
TCP_ACCEPT => event_loop.reregister(self.tcp_listener.lock().unwrap().deref(), Token(TCP_ACCEPT), EventSet::all(), PollOpt::edge()).expect("Error reregistering stream"),
DISCOVERY => self.discovery.lock().unwrap().as_ref().unwrap().update_registration(event_loop).expect("Error reregistering discovery socket"),
TCP_ACCEPT => event_loop.reregister(self.tcp_listener.lock().unwrap().as_ref().unwrap(), Token(TCP_ACCEPT), EventSet::all(), PollOpt::edge()).expect("Error reregistering stream"),
_ => warn!("Unexpected stream update")
}
}

View File

@@ -42,7 +42,6 @@ impl<Message> NetworkService<Message> where Message: Send + Sync + Clone + 'stat
let host = Arc::new(Host::new(config));
let stats = host.stats().clone();
let host_info = host.client_version();
info!("Node URL: {}", host.client_url());
try!(io_service.register_handler(host));
Ok(NetworkService {
io_service: io_service,

View File

@@ -26,7 +26,7 @@ use std::ops::*;
use std::sync::*;
use std::env;
use std::collections::HashMap;
use rocksdb::{DB, Writable, IteratorMode};
use kvdb::{Database};
/// Implementation of the HashDB trait for a disk-backed database with a memory overlay.
///
@@ -38,15 +38,15 @@ use rocksdb::{DB, Writable, IteratorMode};
/// queries have an immediate effect in terms of these functions.
pub struct OverlayDB {
overlay: MemoryDB,
backing: Arc<DB>,
backing: Arc<Database>,
}
impl OverlayDB {
/// Create a new instance of OverlayDB given a `backing` database.
pub fn new(backing: DB) -> OverlayDB { Self::new_with_arc(Arc::new(backing)) }
pub fn new(backing: Database) -> OverlayDB { Self::new_with_arc(Arc::new(backing)) }
/// Create a new instance of OverlayDB given a `backing` database.
pub fn new_with_arc(backing: Arc<DB>) -> OverlayDB {
pub fn new_with_arc(backing: Arc<Database>) -> OverlayDB {
OverlayDB{ overlay: MemoryDB::new(), backing: backing }
}
@@ -54,7 +54,7 @@ impl OverlayDB {
pub fn new_temp() -> OverlayDB {
let mut dir = env::temp_dir();
dir.push(H32::random().hex());
Self::new(DB::open_default(dir.to_str().unwrap()).unwrap())
Self::new(Database::open_default(dir.to_str().unwrap()).unwrap())
}
/// Commit all memory operations to the backing database.
@@ -164,7 +164,7 @@ impl OverlayDB {
impl HashDB for OverlayDB {
fn keys(&self) -> HashMap<H256, i32> {
let mut ret: HashMap<H256, i32> = HashMap::new();
for (key, _) in self.backing.iterator(IteratorMode::Start) {
for (key, _) in self.backing.iter() {
let h = H256::from_slice(key.deref());
let r = self.payload(&h).unwrap().1;
ret.insert(h, r as i32);
@@ -318,7 +318,7 @@ fn overlaydb_complex() {
fn playpen() {
use std::fs;
{
let db: DB = DB::open_default("/tmp/test").unwrap();
let db: Database = Database::open_default("/tmp/test").unwrap();
db.put(b"test", b"test2").unwrap();
match db.get(b"test") {
Ok(Some(value)) => println!("Got value {:?}", value.deref()),

View File

@@ -20,6 +20,7 @@ extern crate rand;
use bytes::*;
use sha3::*;
use hash::*;
use rlp::encode;
/// Alphabet to use when creating words for insertion into tries.
pub enum Alphabet {
@@ -39,6 +40,8 @@ pub enum ValueMode {
Mirror,
/// Randomly (50:50) 1 or 32 byte randomly string.
Random,
/// RLP-encoded index.
Index,
}
/// Standard test map for profiling tries.
@@ -89,19 +92,27 @@ impl StandardMap {
/// Create the standard map (set of keys and values) for the object's fields.
pub fn make(&self) -> Vec<(Bytes, Bytes)> {
self.make_with(&mut H256::new())
}
/// Create the standard map (set of keys and values) for the object's fields, using the given seed.
pub fn make_with(&self, seed: &mut H256) -> Vec<(Bytes, Bytes)> {
let low = b"abcdef";
let mid = b"@QWERTYUIOPASDFGHJKLZXCVBNM[/]^_";
let mut d: Vec<(Bytes, Bytes)> = Vec::new();
let mut seed = H256::new();
for _ in 0..self.count {
for index in 0..self.count {
let k = match self.alphabet {
Alphabet::All => Self::random_bytes(self.min_key, self.journal_key, &mut seed),
Alphabet::Low => Self::random_word(low, self.min_key, self.journal_key, &mut seed),
Alphabet::Mid => Self::random_word(mid, self.min_key, self.journal_key, &mut seed),
Alphabet::Custom(ref a) => Self::random_word(&a, self.min_key, self.journal_key, &mut seed),
Alphabet::All => Self::random_bytes(self.min_key, self.journal_key, seed),
Alphabet::Low => Self::random_word(low, self.min_key, self.journal_key, seed),
Alphabet::Mid => Self::random_word(mid, self.min_key, self.journal_key, seed),
Alphabet::Custom(ref a) => Self::random_word(&a, self.min_key, self.journal_key, seed),
};
let v = match self.value_mode {
ValueMode::Mirror => k.clone(),
ValueMode::Random => Self::random_value(seed),
ValueMode::Index => encode(&index).to_vec(),
};
let v = match self.value_mode { ValueMode::Mirror => k.clone(), ValueMode::Random => Self::random_value(&mut seed) };
d.push((k, v))
}
d

View File

@@ -687,31 +687,10 @@ mod tests {
use super::*;
use nibbleslice::*;
use rlp::*;
use rand::random;
use std::collections::HashSet;
use bytes::{ToPretty,Bytes,Populatable};
use bytes::ToPretty;
use super::super::node::*;
use super::super::trietraits::*;
fn random_key(alphabet: &[u8], min_count: usize, journal_count: usize) -> Vec<u8> {
let mut ret: Vec<u8> = Vec::new();
let r = min_count + if journal_count > 0 {random::<usize>() % journal_count} else {0};
for _ in 0..r {
ret.push(alphabet[random::<usize>() % alphabet.len()]);
}
ret
}
fn random_value_indexed(j: usize) -> Bytes {
match random::<usize>() % 2 {
0 => encode(&j).to_vec(),
_ => {
let mut h = H256::new();
h.as_slice_mut()[31] = j as u8;
encode(&h).to_vec()
},
}
}
use super::super::standardmap::*;
fn populate_trie<'db>(db: &'db mut HashDB, root: &'db mut H256, v: &[(Vec<u8>, Vec<u8>)]) -> TrieDBMut<'db> {
let mut t = TrieDBMut::new(db, root);
@@ -756,20 +735,18 @@ mod tests {
};*/
// panic!();
let mut seed = H256::new();
for test_i in 0..1 {
if test_i % 50 == 0 {
debug!("{:?} of 10000 stress tests done", test_i);
}
let mut x: Vec<(Vec<u8>, Vec<u8>)> = Vec::new();
let mut got: HashSet<Vec<u8>> = HashSet::new();
let alphabet = b"@QWERTYUIOPASDFGHJKLZXCVBNM[/]^_";
for j in 0..100usize {
let key = random_key(alphabet, 5, 0);
if !got.contains(&key) {
x.push((key.clone(), random_value_indexed(j)));
got.insert(key);
}
}
let x = StandardMap {
alphabet: Alphabet::Custom(b"@QWERTYUIOPASDFGHJKLZXCVBNM[/]^_".to_vec()),
min_key: 5,
journal_key: 0,
value_mode: ValueMode::Index,
count: 100,
}.make_with(&mut seed);
let real = trie_root(x.clone());
let mut memdb = MemoryDB::new();
@@ -1049,13 +1026,16 @@ mod tests {
#[test]
fn stress() {
let mut seed = H256::new();
for _ in 0..50 {
let mut x: Vec<(Vec<u8>, Vec<u8>)> = Vec::new();
let alphabet = b"@QWERTYUIOPASDFGHJKLZXCVBNM[/]^_";
for j in 0..4u32 {
let key = random_key(alphabet, 5, 1);
x.push((key, encode(&j).to_vec()));
}
let x = StandardMap {
alphabet: Alphabet::Custom(b"@QWERTYUIOPASDFGHJKLZXCVBNM[/]^_".to_vec()),
min_key: 5,
journal_key: 0,
value_mode: ValueMode::Index,
count: 4,
}.make_with(&mut seed);
let real = trie_root(x.clone());
let mut memdb = MemoryDB::new();
let mut root = H256::new();

View File

@@ -97,23 +97,69 @@ macro_rules! uint_overflowing_add {
let other_t: &[u64; 4] = unsafe { &mem::transmute($other) };
let overflow: u8;
unsafe {
asm!("
adc $9, $0
adc $10, $1
adc $11, $2
adc $12, $3
setc %al
"
: "=r"(result[0]), "=r"(result[1]), "=r"(result[2]), "=r"(result[3]), "={al}"(overflow)
: "0"(self_t[0]), "1"(self_t[1]), "2"(self_t[2]), "3"(self_t[3]),
unsafe {
asm!("
add $9, $0
adc $10, $1
adc $11, $2
adc $12, $3
setc %al
"
: "=r"(result[0]), "=r"(result[1]), "=r"(result[2]), "=r"(result[3]), "={al}"(overflow)
: "0"(self_t[0]), "1"(self_t[1]), "2"(self_t[2]), "3"(self_t[3]),
"mr"(other_t[0]), "mr"(other_t[1]), "mr"(other_t[2]), "mr"(other_t[3])
:
:
:
:
);
}
(U256(result), overflow != 0)
});
(U512, $n_words: expr, $self_expr: expr, $other: expr) => ({
let mut result: [u64; 8] = unsafe { mem::uninitialized() };
let self_t: &[u64; 8] = unsafe { &mem::transmute($self_expr) };
let other_t: &[u64; 8] = unsafe { &mem::transmute($other) };
let overflow: u8;
unsafe {
asm!("
add $15, $0
adc $16, $1
adc $17, $2
adc $18, $3
lodsq
adc $11, %rax
stosq
lodsq
adc $12, %rax
stosq
lodsq
adc $13, %rax
stosq
lodsq
adc $14, %rax
stosq
setc %al
": "=r"(result[0]), "=r"(result[1]), "=r"(result[2]), "=r"(result[3]),
"={al}"(overflow) /* $0 - $4 */
: "{rdi}"(&result[4] as *const u64) /* $5 */
"{rsi}"(&other_t[4] as *const u64) /* $6 */
"0"(self_t[0]), "1"(self_t[1]), "2"(self_t[2]), "3"(self_t[3]),
"m"(self_t[4]), "m"(self_t[5]), "m"(self_t[6]), "m"(self_t[7]),
/* $7 - $14 */
"mr"(other_t[0]), "mr"(other_t[1]), "mr"(other_t[2]), "mr"(other_t[3]),
"m"(other_t[4]), "m"(other_t[5]), "m"(other_t[6]), "m"(other_t[7]) /* $15 - $22 */
: "rdi", "rsi"
:
);
}
(U512(result), overflow != 0)
});
($name:ident, $n_words:expr, $self_expr: expr, $other: expr) => (
uint_overflowing_add_reg!($name, $n_words, $self_expr, $other)
)
@@ -138,12 +184,13 @@ macro_rules! uint_overflowing_sub {
let overflow: u8;
unsafe {
asm!("
sbb $9, $0
sbb $10, $1
sbb $11, $2
sbb $12, $3
setb %al"
: "=r"(result[0]), "=r"(result[1]), "=r"(result[2]), "=r"(result[3]), "={al}"(overflow)
sub $9, $0
sbb $10, $1
sbb $11, $2
sbb $12, $3
setb %al
"
: "=r"(result[0]), "=r"(result[1]), "=r"(result[2]), "=r"(result[3]), "={al}"(overflow)
: "0"(self_t[0]), "1"(self_t[1]), "2"(self_t[2]), "3"(self_t[3]), "mr"(other_t[0]), "mr"(other_t[1]), "mr"(other_t[2]), "mr"(other_t[3])
:
:
@@ -151,6 +198,51 @@ macro_rules! uint_overflowing_sub {
}
(U256(result), overflow != 0)
});
(U512, $n_words: expr, $self_expr: expr, $other: expr) => ({
let mut result: [u64; 8] = unsafe { mem::uninitialized() };
let self_t: &[u64; 8] = unsafe { &mem::transmute($self_expr) };
let other_t: &[u64; 8] = unsafe { &mem::transmute($other) };
let overflow: u8;
unsafe {
asm!("
sub $15, $0
sbb $16, $1
sbb $17, $2
sbb $18, $3
lodsq
sbb $19, %rax
stosq
lodsq
sbb $20, %rax
stosq
lodsq
sbb $21, %rax
stosq
lodsq
sbb $22, %rax
stosq
setb %al
"
: "=r"(result[0]), "=r"(result[1]), "=r"(result[2]), "=r"(result[3]),
"={al}"(overflow) /* $0 - $4 */
: "{rdi}"(&result[4] as *const u64) /* $5 */
"{rsi}"(&self_t[4] as *const u64) /* $6 */
"0"(self_t[0]), "1"(self_t[1]), "2"(self_t[2]), "3"(self_t[3]),
"m"(self_t[4]), "m"(self_t[5]), "m"(self_t[6]), "m"(self_t[7]),
/* $7 - $14 */
"m"(other_t[0]), "m"(other_t[1]), "m"(other_t[2]), "m"(other_t[3]),
"m"(other_t[4]), "m"(other_t[5]), "m"(other_t[6]), "m"(other_t[7]) /* $15 - $22 */
: "rdi", "rsi"
:
);
}
(U512(result), overflow != 0)
});
($name:ident, $n_words: expr, $self_expr: expr, $other: expr) => ({
let res = overflowing!((!$other).overflowing_add(From::from(1u64)));
let res = overflowing!($self_expr.overflowing_add(res));
@@ -675,7 +767,7 @@ macro_rules! construct_uint {
self.to_bytes(&mut bytes);
let len = cmp::max((self.bits() + 7) / 8, 1);
hex.push_str(bytes[bytes.len() - len..].to_hex().as_ref());
serializer.visit_str(hex.as_ref())
serializer.serialize_str(hex.as_ref())
}
}
@@ -1490,6 +1582,38 @@ mod tests {
assert_eq!(format!("{}", U256::from(0)), "0");
}
#[test]
fn u512_multi_adds() {
let (result, _) = U512([0, 0, 0, 0, 0, 0, 0, 0]).overflowing_add(U512([0, 0, 0, 0, 0, 0, 0, 0]));
assert_eq!(result, U512([0, 0, 0, 0, 0, 0, 0, 0]));
let (result, _) = U512([1, 0, 0, 0, 0, 0, 0, 1]).overflowing_add(U512([1, 0, 0, 0, 0, 0, 0, 1]));
assert_eq!(result, U512([2, 0, 0, 0, 0, 0, 0, 2]));
let (result, _) = U512([0, 0, 0, 0, 0, 0, 0, 1]).overflowing_add(U512([0, 0, 0, 0, 0, 0, 0, 1]));
assert_eq!(result, U512([0, 0, 0, 0, 0, 0, 0, 2]));
let (result, _) = U512([0, 0, 0, 0, 0, 0, 2, 1]).overflowing_add(U512([0, 0, 0, 0, 0, 0, 3, 1]));
assert_eq!(result, U512([0, 0, 0, 0, 0, 0, 5, 2]));
let (result, _) = U512([1, 2, 3, 4, 5, 6, 7, 8]).overflowing_add(U512([9, 10, 11, 12, 13, 14, 15, 16]));
assert_eq!(result, U512([10, 12, 14, 16, 18, 20, 22, 24]));
let (_, overflow) = U512([0, 0, 0, 0, 0, 0, 2, 1]).overflowing_add(U512([0, 0, 0, 0, 0, 0, 3, 1]));
assert!(!overflow);
let (_, overflow) = U512([::std::u64::MAX, ::std::u64::MAX, ::std::u64::MAX, ::std::u64::MAX, ::std::u64::MAX, ::std::u64::MAX, ::std::u64::MAX, ::std::u64::MAX])
.overflowing_add(U512([::std::u64::MAX, ::std::u64::MAX, ::std::u64::MAX, ::std::u64::MAX, ::std::u64::MAX, ::std::u64::MAX, ::std::u64::MAX, ::std::u64::MAX]));
assert!(overflow);
let (_, overflow) = U512([0, 0, 0, 0, 0, 0, 0, ::std::u64::MAX])
.overflowing_add(U512([0, 0, 0, 0, 0, 0, 0, ::std::u64::MAX]));
assert!(overflow);
let (_, overflow) = U512([0, 0, 0, 0, 0, 0, 0, ::std::u64::MAX])
.overflowing_add(U512([0, 0, 0, 0, 0, 0, 0, 0]));
assert!(!overflow);
}
#[test]
fn u256_multi_adds() {
@@ -1537,6 +1661,21 @@ mod tests {
assert_eq!(U256([::std::u64::MAX, ::std::u64::MAX, ::std::u64::MAX, 0]), result);
}
#[test]
fn u512_multi_subs() {
let (result, _) = U512([0, 0, 0, 0, 0, 0, 0, 0]).overflowing_sub(U512([0, 0, 0, 0, 0, 0, 0, 0]));
assert_eq!(result, U512([0, 0, 0, 0, 0, 0, 0, 0]));
let (result, _) = U512([10, 9, 8, 7, 6, 5, 4, 3]).overflowing_sub(U512([9, 8, 7, 6, 5, 4, 3, 2]));
assert_eq!(result, U512([1, 1, 1, 1, 1, 1, 1, 1]));
let (_, overflow) = U512([10, 9, 8, 7, 6, 5, 4, 3]).overflowing_sub(U512([9, 8, 7, 6, 5, 4, 3, 2]));
assert!(!overflow);
let (_, overflow) = U512([9, 8, 7, 6, 5, 4, 3, 2]).overflowing_sub(U512([10, 9, 8, 7, 6, 5, 4, 3]));
assert!(overflow);
}
#[test]
fn u256_multi_carry_all() {
let (result, _) = U256([::std::u64::MAX, 0, 0, 0]).overflowing_mul(U256([::std::u64::MAX, 0, 0, 0]));