Merge with master

This commit is contained in:
Robert Habermeier
2016-06-29 14:46:29 +02:00
parent 3850ee64bb
commit 49024a4f28
235 changed files with 36456 additions and 4814 deletions

View File

@@ -132,15 +132,10 @@ macro_rules! impl_hash {
$size
}
// TODO: remove once slice::clone_from_slice is stable
#[inline]
fn clone_from_slice(&mut self, src: &[u8]) -> usize {
let min = ::std::cmp::min($size, src.len());
let dst = &mut self.deref_mut()[.. min];
let src = &src[.. min];
for i in 0..min {
dst[i] = src[i];
}
let min = cmp::min($size, src.len());
self.0[..min].copy_from_slice(&src[..min]);
min
}
@@ -151,7 +146,7 @@ macro_rules! impl_hash {
}
fn copy_to(&self, dest: &mut[u8]) {
let min = ::std::cmp::min($size, dest.len());
let min = cmp::min($size, dest.len());
dest[..min].copy_from_slice(&self.0[..min]);
}

View File

@@ -20,12 +20,10 @@ use bytes::*;
use std::collections::HashMap;
/// Trait modelling datastore keyed by a 32-byte Keccak hash.
pub trait HashDB : AsHashDB {
pub trait HashDB: AsHashDB {
/// Get the keys in the database together with number of underlying references.
fn keys(&self) -> HashMap<H256, i32>;
/// Deprecated. use `get`.
fn lookup(&self, key: &H256) -> Option<&[u8]>; // TODO: rename to get.
/// Look up a given hash into the bytes that hash to it, returning None if the
/// hash is not known.
///
@@ -41,10 +39,8 @@ pub trait HashDB : AsHashDB {
/// assert_eq!(m.get(&hash).unwrap(), hello_bytes);
/// }
/// ```
fn get(&self, key: &H256) -> Option<&[u8]> { self.lookup(key) }
fn get(&self, key: &H256) -> Option<&[u8]>;
/// Deprecated. Use `contains`.
fn exists(&self, key: &H256) -> bool; // TODO: rename to contains.
/// Check for the existance of a hash-key.
///
/// # Examples
@@ -63,10 +59,10 @@ pub trait HashDB : AsHashDB {
/// assert!(!m.contains(&key));
/// }
/// ```
fn contains(&self, key: &H256) -> bool { self.exists(key) }
fn contains(&self, key: &H256) -> bool;
/// Insert a datum item into the DB and return the datum's hash for a later lookup. Insertions
/// are counted and the equivalent number of `kill()`s must be performed before the data
/// are counted and the equivalent number of `remove()`s must be performed before the data
/// is considered dead.
///
/// # Examples
@@ -86,8 +82,6 @@ pub trait HashDB : AsHashDB {
/// Like `insert()` , except you provide the key and the data is all moved.
fn emplace(&mut self, key: H256, value: Bytes);
/// Deprecated - use `remove`.
fn kill(&mut self, key: &H256); // TODO: rename to remove.
/// Remove a datum previously inserted. Insertions can be "owed" such that the same number of `insert()`s may
/// happen without the data being eventually being inserted into the DB.
///
@@ -109,7 +103,7 @@ pub trait HashDB : AsHashDB {
/// assert_eq!(m.get(key).unwrap(), d);
/// }
/// ```
fn remove(&mut self, key: &H256) { self.kill(key) }
fn remove(&mut self, key: &H256);
}
/// Upcast trait.
@@ -121,6 +115,10 @@ pub trait AsHashDB {
}
impl<T: HashDB> AsHashDB for T {
fn as_hashdb(&self) -> &HashDB { self }
fn as_hashdb_mut(&mut self) -> &mut HashDB { self }
fn as_hashdb(&self) -> &HashDB {
self
}
fn as_hashdb_mut(&mut self) -> &mut HashDB {
self
}
}

View File

@@ -142,7 +142,7 @@ mod tests {
#[test]
fn test_service_register_handler () {
let mut service = IoService::<MyMessage>::start().expect("Error creating network service");
let service = IoService::<MyMessage>::start().expect("Error creating network service");
service.register_handler(Arc::new(MyHandler)).unwrap();
}

View File

@@ -18,9 +18,10 @@ use std::sync::*;
use std::thread::{self, JoinHandle};
use std::collections::HashMap;
use mio::*;
use crossbeam::sync::chase_lev;
use slab::Slab;
use error::*;
use io::{IoError, IoHandler};
use crossbeam::sync::chase_lev;
use io::worker::{Worker, Work, WorkType};
use panics::*;
@@ -33,6 +34,7 @@ pub type HandlerId = usize;
/// Maximum number of tokens a handler can use
pub const TOKENS_PER_HANDLER: usize = 16384;
const MAX_HANDLERS: usize = 8;
/// Messages used to communicate with the event loop from other threads.
#[derive(Clone)]
@@ -43,6 +45,9 @@ pub enum IoMessage<Message> where Message: Send + Clone + Sized {
AddHandler {
handler: Arc<IoHandler<Message>+Send>,
},
RemoveHandler {
handler_id: HandlerId,
},
AddTimer {
handler_id: HandlerId,
token: TimerToken,
@@ -130,14 +135,24 @@ impl<Message> IoContext<Message> where Message: Send + Clone + 'static {
}
/// Broadcast a message to other IO clients
pub fn message(&self, message: Message) {
self.channel.send(message).expect("Error seding message");
pub fn message(&self, message: Message) -> Result<(), UtilError> {
try!(self.channel.send(message));
Ok(())
}
/// Get message channel
pub fn channel(&self) -> IoChannel<Message> {
self.channel.clone()
}
/// Unregister current IO handler.
pub fn unregister_handler(&self) -> Result<(), IoError> {
try!(self.channel.send_io(IoMessage::RemoveHandler {
handler_id: self.handler,
}));
Ok(())
}
}
#[derive(Clone)]
@@ -149,7 +164,7 @@ struct UserTimer {
/// Root IO handler. Manages user handlers, messages and IO timers.
pub struct IoManager<Message> where Message: Send + Sync {
timers: Arc<RwLock<HashMap<HandlerId, UserTimer>>>,
handlers: Vec<Arc<IoHandler<Message>>>,
handlers: Slab<Arc<IoHandler<Message>>, HandlerId>,
workers: Vec<Worker>,
worker_channel: chase_lev::Worker<Work<Message>>,
work_ready: Arc<Condvar>,
@@ -175,7 +190,7 @@ impl<Message> IoManager<Message> where Message: Send + Sync + Clone + 'static {
let mut io = IoManager {
timers: Arc::new(RwLock::new(HashMap::new())),
handlers: Vec::new(),
handlers: Slab::new(MAX_HANDLERS),
worker_channel: worker,
workers: workers,
work_ready: work_ready,
@@ -192,36 +207,31 @@ impl<Message> Handler for IoManager<Message> where Message: Send + Clone + Sync
fn ready(&mut self, _event_loop: &mut EventLoop<Self>, token: Token, events: EventSet) {
let handler_index = token.as_usize() / TOKENS_PER_HANDLER;
let token_id = token.as_usize() % TOKENS_PER_HANDLER;
if handler_index >= self.handlers.len() {
panic!("Unexpected stream token: {}", token.as_usize());
}
let handler = self.handlers[handler_index].clone();
if events.is_hup() {
self.worker_channel.push(Work { work_type: WorkType::Hup, token: token_id, handler: handler.clone(), handler_id: handler_index });
}
else {
if events.is_readable() {
self.worker_channel.push(Work { work_type: WorkType::Readable, token: token_id, handler: handler.clone(), handler_id: handler_index });
if let Some(handler) = self.handlers.get(handler_index) {
if events.is_hup() {
self.worker_channel.push(Work { work_type: WorkType::Hup, token: token_id, handler: handler.clone(), handler_id: handler_index });
}
if events.is_writable() {
self.worker_channel.push(Work { work_type: WorkType::Writable, token: token_id, handler: handler.clone(), handler_id: handler_index });
else {
if events.is_readable() {
self.worker_channel.push(Work { work_type: WorkType::Readable, token: token_id, handler: handler.clone(), handler_id: handler_index });
}
if events.is_writable() {
self.worker_channel.push(Work { work_type: WorkType::Writable, token: token_id, handler: handler.clone(), handler_id: handler_index });
}
}
self.work_ready.notify_all();
}
self.work_ready.notify_all();
}
fn timeout(&mut self, event_loop: &mut EventLoop<Self>, token: Token) {
let handler_index = token.as_usize() / TOKENS_PER_HANDLER;
let token_id = token.as_usize() % TOKENS_PER_HANDLER;
if handler_index >= self.handlers.len() {
panic!("Unexpected timer token: {}", token.as_usize());
}
if let Some(timer) = self.timers.read().unwrap().get(&token.as_usize()) {
event_loop.timeout_ms(token, timer.delay).expect("Error re-registering user timer");
let handler = self.handlers[handler_index].clone();
self.worker_channel.push(Work { work_type: WorkType::Timeout, token: token_id, handler: handler, handler_id: handler_index });
self.work_ready.notify_all();
if let Some(handler) = self.handlers.get(handler_index) {
if let Some(timer) = self.timers.read().unwrap().get(&token.as_usize()) {
event_loop.timeout_ms(token, timer.delay).expect("Error re-registering user timer");
self.worker_channel.push(Work { work_type: WorkType::Timeout, token: token_id, handler: handler.clone(), handler_id: handler_index });
self.work_ready.notify_all();
}
}
}
@@ -232,12 +242,13 @@ impl<Message> Handler for IoManager<Message> where Message: Send + Clone + Sync
event_loop.shutdown();
},
IoMessage::AddHandler { handler } => {
let handler_id = {
self.handlers.push(handler.clone());
self.handlers.len() - 1
};
let handler_id = self.handlers.insert(handler.clone()).unwrap_or_else(|_| panic!("Too many handlers registered"));
handler.initialize(&IoContext::new(IoChannel::new(event_loop.channel()), handler_id));
},
IoMessage::RemoveHandler { handler_id } => {
// TODO: flush event loop
self.handlers.remove(handler_id);
},
IoMessage::AddTimer { handler_id, token, delay } => {
let timer_id = token + handler_id * TOKENS_PER_HANDLER;
let timeout = event_loop.timeout_ms(Token(timer_id), delay).expect("Error registering user timer");
@@ -250,26 +261,32 @@ impl<Message> Handler for IoManager<Message> where Message: Send + Clone + Sync
}
},
IoMessage::RegisterStream { handler_id, token } => {
let handler = self.handlers.get(handler_id).expect("Unknown handler id").clone();
handler.register_stream(token, Token(token + handler_id * TOKENS_PER_HANDLER), event_loop);
if let Some(handler) = self.handlers.get(handler_id) {
handler.register_stream(token, Token(token + handler_id * TOKENS_PER_HANDLER), event_loop);
}
},
IoMessage::DeregisterStream { handler_id, token } => {
let handler = self.handlers.get(handler_id).expect("Unknown handler id").clone();
handler.deregister_stream(token, event_loop);
// unregister a timer associated with the token (if any)
let timer_id = token + handler_id * TOKENS_PER_HANDLER;
if let Some(timer) = self.timers.write().unwrap().remove(&timer_id) {
event_loop.clear_timeout(timer.timeout);
if let Some(handler) = self.handlers.get(handler_id) {
handler.deregister_stream(token, event_loop);
// unregister a timer associated with the token (if any)
let timer_id = token + handler_id * TOKENS_PER_HANDLER;
if let Some(timer) = self.timers.write().unwrap().remove(&timer_id) {
event_loop.clear_timeout(timer.timeout);
}
}
},
IoMessage::UpdateStreamRegistration { handler_id, token } => {
let handler = self.handlers.get(handler_id).expect("Unknown handler id").clone();
handler.update_stream(token, Token(token + handler_id * TOKENS_PER_HANDLER), event_loop);
if let Some(handler) = self.handlers.get(handler_id) {
handler.update_stream(token, Token(token + handler_id * TOKENS_PER_HANDLER), event_loop);
}
},
IoMessage::UserMessage(data) => {
for n in 0 .. self.handlers.len() {
let handler = self.handlers[n].clone();
self.worker_channel.push(Work { work_type: WorkType::Message(data.clone()), token: 0, handler: handler, handler_id: n });
//TODO: better way to iterate the slab
for id in 0 .. MAX_HANDLERS {
if let Some(h) = self.handlers.get(id) {
let handler = h.clone();
self.worker_channel.push(Work { work_type: WorkType::Message(data.clone()), token: 0, handler: handler, handler_id: id });
}
}
self.work_ready.notify_all();
}
@@ -335,7 +352,9 @@ impl<Message> IoService<Message> where Message: Send + Sync + Clone + 'static {
/// Starts IO event loop
pub fn start() -> Result<IoService<Message>, UtilError> {
let panic_handler = PanicHandler::new_in_arc();
let mut event_loop = EventLoop::new().unwrap();
let mut config = EventLoopConfig::new();
config.messages_per_tick(1024);
let mut event_loop = EventLoop::configured(config).expect("Error creating event loop");
let channel = event_loop.channel();
let panic = panic_handler.clone();
let thread = thread::spawn(move || {
@@ -351,8 +370,8 @@ impl<Message> IoService<Message> where Message: Send + Sync + Clone + 'static {
})
}
/// Regiter a IO hadnler with the event loop.
pub fn register_handler(&mut self, handler: Arc<IoHandler<Message>+Send>) -> Result<(), IoError> {
/// Regiter an IO handler with the event loop.
pub fn register_handler(&self, handler: Arc<IoHandler<Message>+Send>) -> Result<(), IoError> {
try!(self.host_channel.send(IoMessage::AddHandler {
handler: handler,
}));
@@ -360,13 +379,13 @@ impl<Message> IoService<Message> where Message: Send + Sync + Clone + 'static {
}
/// Send a message over the network. Normaly `HostIo::send` should be used. This can be used from non-io threads.
pub fn send_message(&mut self, message: Message) -> Result<(), IoError> {
pub fn send_message(&self, message: Message) -> Result<(), IoError> {
try!(self.host_channel.send(IoMessage::UserMessage(message)));
Ok(())
}
/// Create a new message channel
pub fn channel(&mut self) -> IoChannel<Message> {
pub fn channel(&self) -> IoChannel<Message> {
IoChannel { channel: Some(self.host_channel.clone()) }
}
}
@@ -374,7 +393,7 @@ impl<Message> IoService<Message> where Message: Send + Sync + Clone + 'static {
impl<Message> Drop for IoService<Message> where Message: Send + Sync + Clone {
fn drop(&mut self) {
trace!(target: "shutdown", "[IoService] Closing...");
self.host_channel.send(IoMessage::Shutdown).unwrap();
self.host_channel.send(IoMessage::Shutdown).unwrap_or_else(|e| warn!("Error on IO service shutdown: {:?}", e));
self.thread.take().unwrap().join().ok();
trace!(target: "shutdown", "[IoService] Closed.");
}

View File

@@ -20,6 +20,7 @@ use common::*;
use rlp::*;
use hashdb::*;
use memorydb::*;
use super::{DB_PREFIX_LEN, LATEST_ERA_KEY, VERSION_KEY};
use super::traits::JournalDB;
use kvdb::{Database, DBTransaction, DatabaseConfig};
#[cfg(test)]
@@ -38,19 +39,12 @@ pub struct ArchiveDB {
latest_era: Option<u64>,
}
// all keys must be at least 12 bytes
const LATEST_ERA_KEY : [u8; 12] = [ b'l', b'a', b's', b't', 0, 0, 0, 0, 0, 0, 0, 0 ];
const VERSION_KEY : [u8; 12] = [ b'j', b'v', b'e', b'r', 0, 0, 0, 0, 0, 0, 0, 0 ];
const DB_VERSION : u32 = 0x103;
impl ArchiveDB {
/// Create a new instance from file
pub fn new(path: &str) -> ArchiveDB {
let opts = DatabaseConfig {
//use 12 bytes as prefix, this must match account_db prefix
prefix_size: Some(12),
max_open_files: 256,
};
pub fn new(path: &str, config: DatabaseConfig) -> ArchiveDB {
let opts = config.prefix(DB_PREFIX_LEN);
let backing = Database::open(&opts, path).unwrap_or_else(|e| {
panic!("Error opening state db: {}", e);
});
@@ -76,7 +70,7 @@ impl ArchiveDB {
fn new_temp() -> ArchiveDB {
let mut dir = env::temp_dir();
dir.push(H32::random().hex());
Self::new(dir.to_str().unwrap())
Self::new(dir.to_str().unwrap(), DatabaseConfig::default())
}
fn payload(&self, key: &H256) -> Option<Bytes> {
@@ -99,7 +93,7 @@ impl HashDB for ArchiveDB {
ret
}
fn lookup(&self, key: &H256) -> Option<&[u8]> {
fn get(&self, key: &H256) -> Option<&[u8]> {
let k = self.overlay.raw(key);
match k {
Some(&(ref d, rc)) if rc > 0 => Some(d),
@@ -114,8 +108,8 @@ impl HashDB for ArchiveDB {
}
}
fn exists(&self, key: &H256) -> bool {
self.lookup(key).is_some()
fn contains(&self, key: &H256) -> bool {
self.get(key).is_some()
}
fn insert(&mut self, value: &[u8]) -> H256 {
@@ -124,8 +118,8 @@ impl HashDB for ArchiveDB {
fn emplace(&mut self, key: H256, value: Bytes) {
self.overlay.emplace(key, value);
}
fn kill(&mut self, key: &H256) {
self.overlay.kill(key);
fn remove(&mut self, key: &H256) {
self.overlay.remove(key);
}
}
@@ -188,6 +182,7 @@ mod tests {
use super::*;
use hashdb::*;
use journaldb::traits::JournalDB;
use kvdb::DatabaseConfig;
#[test]
fn insert_same_in_fork() {
@@ -208,7 +203,7 @@ mod tests {
jdb.commit(5, &b"1004a".sha3(), Some((3, b"1002a".sha3()))).unwrap();
jdb.commit(6, &b"1005a".sha3(), Some((4, b"1003a".sha3()))).unwrap();
assert!(jdb.exists(&x));
assert!(jdb.contains(&x));
}
#[test]
@@ -217,14 +212,14 @@ mod tests {
let mut jdb = ArchiveDB::new_temp();
let h = jdb.insert(b"foo");
jdb.commit(0, &b"0".sha3(), None).unwrap();
assert!(jdb.exists(&h));
assert!(jdb.contains(&h));
jdb.remove(&h);
jdb.commit(1, &b"1".sha3(), None).unwrap();
assert!(jdb.exists(&h));
assert!(jdb.contains(&h));
jdb.commit(2, &b"2".sha3(), None).unwrap();
assert!(jdb.exists(&h));
assert!(jdb.contains(&h));
jdb.commit(3, &b"3".sha3(), Some((0, b"0".sha3()))).unwrap();
assert!(jdb.exists(&h));
assert!(jdb.contains(&h));
jdb.commit(4, &b"4".sha3(), Some((1, b"1".sha3()))).unwrap();
}
@@ -236,26 +231,26 @@ mod tests {
let foo = jdb.insert(b"foo");
let bar = jdb.insert(b"bar");
jdb.commit(0, &b"0".sha3(), None).unwrap();
assert!(jdb.exists(&foo));
assert!(jdb.exists(&bar));
assert!(jdb.contains(&foo));
assert!(jdb.contains(&bar));
jdb.remove(&foo);
jdb.remove(&bar);
let baz = jdb.insert(b"baz");
jdb.commit(1, &b"1".sha3(), Some((0, b"0".sha3()))).unwrap();
assert!(jdb.exists(&foo));
assert!(jdb.exists(&bar));
assert!(jdb.exists(&baz));
assert!(jdb.contains(&foo));
assert!(jdb.contains(&bar));
assert!(jdb.contains(&baz));
let foo = jdb.insert(b"foo");
jdb.remove(&baz);
jdb.commit(2, &b"2".sha3(), Some((1, b"1".sha3()))).unwrap();
assert!(jdb.exists(&foo));
assert!(jdb.exists(&baz));
assert!(jdb.contains(&foo));
assert!(jdb.contains(&baz));
jdb.remove(&foo);
jdb.commit(3, &b"3".sha3(), Some((2, b"2".sha3()))).unwrap();
assert!(jdb.exists(&foo));
assert!(jdb.contains(&foo));
jdb.commit(4, &b"4".sha3(), Some((3, b"3".sha3()))).unwrap();
}
@@ -268,8 +263,8 @@ mod tests {
let foo = jdb.insert(b"foo");
let bar = jdb.insert(b"bar");
jdb.commit(0, &b"0".sha3(), None).unwrap();
assert!(jdb.exists(&foo));
assert!(jdb.exists(&bar));
assert!(jdb.contains(&foo));
assert!(jdb.contains(&bar));
jdb.remove(&foo);
let baz = jdb.insert(b"baz");
@@ -278,12 +273,12 @@ mod tests {
jdb.remove(&bar);
jdb.commit(1, &b"1b".sha3(), Some((0, b"0".sha3()))).unwrap();
assert!(jdb.exists(&foo));
assert!(jdb.exists(&bar));
assert!(jdb.exists(&baz));
assert!(jdb.contains(&foo));
assert!(jdb.contains(&bar));
assert!(jdb.contains(&baz));
jdb.commit(2, &b"2b".sha3(), Some((1, b"1b".sha3()))).unwrap();
assert!(jdb.exists(&foo));
assert!(jdb.contains(&foo));
}
#[test]
@@ -293,16 +288,16 @@ mod tests {
let foo = jdb.insert(b"foo");
jdb.commit(0, &b"0".sha3(), None).unwrap();
assert!(jdb.exists(&foo));
assert!(jdb.contains(&foo));
jdb.remove(&foo);
jdb.commit(1, &b"1".sha3(), Some((0, b"0".sha3()))).unwrap();
jdb.insert(b"foo");
assert!(jdb.exists(&foo));
assert!(jdb.contains(&foo));
jdb.commit(2, &b"2".sha3(), Some((1, b"1".sha3()))).unwrap();
assert!(jdb.exists(&foo));
assert!(jdb.contains(&foo));
jdb.commit(3, &b"2".sha3(), Some((0, b"2".sha3()))).unwrap();
assert!(jdb.exists(&foo));
assert!(jdb.contains(&foo));
}
#[test]
@@ -316,10 +311,10 @@ mod tests {
jdb.insert(b"foo");
jdb.commit(1, &b"1b".sha3(), Some((0, b"0".sha3()))).unwrap();
assert!(jdb.exists(&foo));
assert!(jdb.contains(&foo));
jdb.commit(2, &b"2a".sha3(), Some((1, b"1a".sha3()))).unwrap();
assert!(jdb.exists(&foo));
assert!(jdb.contains(&foo));
}
#[test]
@@ -329,7 +324,7 @@ mod tests {
let bar = H256::random();
let foo = {
let mut jdb = ArchiveDB::new(dir.to_str().unwrap());
let mut jdb = ArchiveDB::new(dir.to_str().unwrap(), DatabaseConfig::default());
// history is 1
let foo = jdb.insert(b"foo");
jdb.emplace(bar.clone(), b"bar".to_vec());
@@ -338,15 +333,15 @@ mod tests {
};
{
let mut jdb = ArchiveDB::new(dir.to_str().unwrap());
let mut jdb = ArchiveDB::new(dir.to_str().unwrap(), DatabaseConfig::default());
jdb.remove(&foo);
jdb.commit(1, &b"1".sha3(), Some((0, b"0".sha3()))).unwrap();
}
{
let mut jdb = ArchiveDB::new(dir.to_str().unwrap());
assert!(jdb.exists(&foo));
assert!(jdb.exists(&bar));
let mut jdb = ArchiveDB::new(dir.to_str().unwrap(), DatabaseConfig::default());
assert!(jdb.contains(&foo));
assert!(jdb.contains(&bar));
jdb.commit(2, &b"2".sha3(), Some((1, b"1".sha3()))).unwrap();
}
}
@@ -357,7 +352,7 @@ mod tests {
dir.push(H32::random().hex());
let foo = {
let mut jdb = ArchiveDB::new(dir.to_str().unwrap());
let mut jdb = ArchiveDB::new(dir.to_str().unwrap(), DatabaseConfig::default());
// history is 1
let foo = jdb.insert(b"foo");
jdb.commit(0, &b"0".sha3(), None).unwrap();
@@ -371,10 +366,10 @@ mod tests {
};
{
let mut jdb = ArchiveDB::new(dir.to_str().unwrap());
let mut jdb = ArchiveDB::new(dir.to_str().unwrap(), DatabaseConfig::default());
jdb.remove(&foo);
jdb.commit(3, &b"3".sha3(), Some((2, b"2".sha3()))).unwrap();
assert!(jdb.exists(&foo));
assert!(jdb.contains(&foo));
jdb.remove(&foo);
jdb.commit(4, &b"4".sha3(), Some((3, b"3".sha3()))).unwrap();
jdb.commit(5, &b"5".sha3(), Some((4, b"4".sha3()))).unwrap();
@@ -386,7 +381,7 @@ mod tests {
let mut dir = ::std::env::temp_dir();
dir.push(H32::random().hex());
let (foo, _, _) = {
let mut jdb = ArchiveDB::new(dir.to_str().unwrap());
let mut jdb = ArchiveDB::new(dir.to_str().unwrap(), DatabaseConfig::default());
// history is 1
let foo = jdb.insert(b"foo");
let bar = jdb.insert(b"bar");
@@ -401,9 +396,9 @@ mod tests {
};
{
let mut jdb = ArchiveDB::new(dir.to_str().unwrap());
let mut jdb = ArchiveDB::new(dir.to_str().unwrap(), DatabaseConfig::default());
jdb.commit(2, &b"2b".sha3(), Some((1, b"1b".sha3()))).unwrap();
assert!(jdb.exists(&foo));
assert!(jdb.contains(&foo));
}
}
@@ -412,14 +407,14 @@ mod tests {
let temp = ::devtools::RandomTempPath::new();
let key = {
let mut jdb = ArchiveDB::new(temp.as_str());
let mut jdb = ArchiveDB::new(temp.as_str(), DatabaseConfig::default());
let key = jdb.insert(b"foo");
jdb.commit(0, &b"0".sha3(), None).unwrap();
key
};
{
let jdb = ArchiveDB::new(temp.as_str());
let jdb = ArchiveDB::new(temp.as_str(), DatabaseConfig::default());
let state = jdb.state(&key);
assert!(state.is_some());
}

View File

@@ -20,6 +20,7 @@ use common::*;
use rlp::*;
use hashdb::*;
use memorydb::*;
use super::{DB_PREFIX_LEN, LATEST_ERA_KEY, VERSION_KEY};
use super::traits::JournalDB;
use kvdb::{Database, DBTransaction, DatabaseConfig};
#[cfg(test)]
@@ -67,20 +68,13 @@ pub struct EarlyMergeDB {
latest_era: Option<u64>,
}
// all keys must be at least 12 bytes
const LATEST_ERA_KEY : [u8; 12] = [ b'l', b'a', b's', b't', 0, 0, 0, 0, 0, 0, 0, 0 ];
const VERSION_KEY : [u8; 12] = [ b'j', b'v', b'e', b'r', 0, 0, 0, 0, 0, 0, 0, 0 ];
const DB_VERSION : u32 = 0x003;
const PADDING : [u8; 10] = [ 0u8; 10 ];
impl EarlyMergeDB {
/// Create a new instance from file
pub fn new(path: &str) -> EarlyMergeDB {
let opts = DatabaseConfig {
//use 12 bytes as prefix, this must match account_db prefix
prefix_size: Some(12),
max_open_files: 256,
};
pub fn new(path: &str, config: DatabaseConfig) -> EarlyMergeDB {
let opts = config.prefix(DB_PREFIX_LEN);
let backing = Database::open(&opts, path).unwrap_or_else(|e| {
panic!("Error opening state db: {}", e);
});
@@ -108,7 +102,7 @@ impl EarlyMergeDB {
fn new_temp() -> EarlyMergeDB {
let mut dir = env::temp_dir();
dir.push(H32::random().hex());
Self::new(dir.to_str().unwrap())
Self::new(dir.to_str().unwrap(), DatabaseConfig::default())
}
fn morph_key(key: &H256, index: u8) -> Bytes {
@@ -173,8 +167,8 @@ impl EarlyMergeDB {
trace!(target: "jdb.fine", "replay_keys: (end) refs={:?}", refs);
}
fn kill_keys(deletes: &[H256], refs: &mut HashMap<H256, RefInfo>, batch: &DBTransaction, from: RemoveFrom, trace: bool) {
// with a kill on {queue_refs: 1, in_archive: true}, we have two options:
fn remove_keys(deletes: &[H256], refs: &mut HashMap<H256, RefInfo>, batch: &DBTransaction, from: RemoveFrom, trace: bool) {
// with a remove on {queue_refs: 1, in_archive: true}, we have two options:
// - convert to {queue_refs: 1, in_archive: false} (i.e. remove it from the conceptual archive)
// - convert to {queue_refs: 0, in_archive: true} (i.e. remove it from the conceptual queue)
// (the latter option would then mean removing the RefInfo, since it would no longer be counted in the queue.)
@@ -187,13 +181,13 @@ impl EarlyMergeDB {
c.in_archive = false;
Self::reset_already_in(batch, h);
if trace {
trace!(target: "jdb.fine", " kill({}): In archive, 1 in queue: Reducing to queue only and recording", h);
trace!(target: "jdb.fine", " remove({}): In archive, 1 in queue: Reducing to queue only and recording", h);
}
continue;
} else if c.queue_refs > 1 {
c.queue_refs -= 1;
if trace {
trace!(target: "jdb.fine", " kill({}): In queue > 1 refs: Decrementing ref count to {}", h, c.queue_refs);
trace!(target: "jdb.fine", " remove({}): In queue > 1 refs: Decrementing ref count to {}", h, c.queue_refs);
}
continue;
} else {
@@ -205,14 +199,14 @@ impl EarlyMergeDB {
refs.remove(h);
Self::reset_already_in(batch, h);
if trace {
trace!(target: "jdb.fine", " kill({}): In archive, 1 in queue: Removing from queue and leaving in archive", h);
trace!(target: "jdb.fine", " remove({}): In archive, 1 in queue: Removing from queue and leaving in archive", h);
}
}
Some(RefInfo{queue_refs: 1, in_archive: false}) => {
refs.remove(h);
batch.delete(&h.bytes()).expect("Low-level database error. Some issue with your hard disk?");
if trace {
trace!(target: "jdb.fine", " kill({}): Not in archive, only 1 ref in queue: Removing from queue and DB", h);
trace!(target: "jdb.fine", " remove({}): Not in archive, only 1 ref in queue: Removing from queue and DB", h);
}
}
None => {
@@ -220,7 +214,7 @@ impl EarlyMergeDB {
//assert!(!Self::is_already_in(db, &h));
batch.delete(&h.bytes()).expect("Low-level database error. Some issue with your hard disk?");
if trace {
trace!(target: "jdb.fine", " kill({}): Not in queue - MUST BE IN ARCHIVE: Removing from DB", h);
trace!(target: "jdb.fine", " remove({}): Not in queue - MUST BE IN ARCHIVE: Removing from DB", h);
}
}
_ => panic!("Invalid value in refs: {:?}", n),
@@ -291,7 +285,7 @@ impl HashDB for EarlyMergeDB {
ret
}
fn lookup(&self, key: &H256) -> Option<&[u8]> {
fn get(&self, key: &H256) -> Option<&[u8]> {
let k = self.overlay.raw(key);
match k {
Some(&(ref d, rc)) if rc > 0 => Some(d),
@@ -306,8 +300,8 @@ impl HashDB for EarlyMergeDB {
}
}
fn exists(&self, key: &H256) -> bool {
self.lookup(key).is_some()
fn contains(&self, key: &H256) -> bool {
self.get(key).is_some()
}
fn insert(&mut self, value: &[u8]) -> H256 {
@@ -316,8 +310,8 @@ impl HashDB for EarlyMergeDB {
fn emplace(&mut self, key: H256, value: Bytes) {
self.overlay.emplace(key, value);
}
fn kill(&mut self, key: &H256) {
self.overlay.kill(key);
fn remove(&mut self, key: &H256) {
self.overlay.remove(key);
}
}
@@ -473,7 +467,7 @@ impl JournalDB for EarlyMergeDB {
if trace {
trace!(target: "jdb.ops", " Expunging: {:?}", deletes);
}
Self::kill_keys(&deletes, &mut refs, &batch, RemoveFrom::Archive, trace);
Self::remove_keys(&deletes, &mut refs, &batch, RemoveFrom::Archive, trace);
if trace {
trace!(target: "jdb.ops", " Finalising: {:?}", inserts);
@@ -505,7 +499,7 @@ impl JournalDB for EarlyMergeDB {
if trace {
trace!(target: "jdb.ops", " Reverting: {:?}", inserts);
}
Self::kill_keys(&inserts, &mut refs, &batch, RemoveFrom::Queue, trace);
Self::remove_keys(&inserts, &mut refs, &batch, RemoveFrom::Queue, trace);
}
try!(batch.delete(&last));
@@ -538,6 +532,7 @@ mod tests {
use super::super::traits::JournalDB;
use hashdb::*;
use log::init_log;
use kvdb::DatabaseConfig;
#[test]
fn insert_same_in_fork() {
@@ -566,7 +561,7 @@ mod tests {
jdb.commit(6, &b"1005a".sha3(), Some((4, b"1003a".sha3()))).unwrap();
assert!(jdb.can_reconstruct_refs());
assert!(jdb.exists(&x));
assert!(jdb.contains(&x));
}
#[test]
@@ -585,8 +580,8 @@ mod tests {
assert!(jdb.can_reconstruct_refs());
jdb.commit(2, &b"2".sha3(), Some((1, b"1".sha3()))).unwrap();
assert!(jdb.exists(&foo));
assert!(jdb.exists(&bar));
assert!(jdb.contains(&foo));
assert!(jdb.contains(&bar));
}
#[test]
@@ -596,20 +591,20 @@ mod tests {
let h = jdb.insert(b"foo");
jdb.commit(0, &b"0".sha3(), None).unwrap();
assert!(jdb.can_reconstruct_refs());
assert!(jdb.exists(&h));
assert!(jdb.contains(&h));
jdb.remove(&h);
jdb.commit(1, &b"1".sha3(), None).unwrap();
assert!(jdb.can_reconstruct_refs());
assert!(jdb.exists(&h));
assert!(jdb.contains(&h));
jdb.commit(2, &b"2".sha3(), None).unwrap();
assert!(jdb.can_reconstruct_refs());
assert!(jdb.exists(&h));
assert!(jdb.contains(&h));
jdb.commit(3, &b"3".sha3(), Some((0, b"0".sha3()))).unwrap();
assert!(jdb.can_reconstruct_refs());
assert!(jdb.exists(&h));
assert!(jdb.contains(&h));
jdb.commit(4, &b"4".sha3(), Some((1, b"1".sha3()))).unwrap();
assert!(jdb.can_reconstruct_refs());
assert!(!jdb.exists(&h));
assert!(!jdb.contains(&h));
}
#[test]
@@ -621,38 +616,38 @@ mod tests {
let bar = jdb.insert(b"bar");
jdb.commit(0, &b"0".sha3(), None).unwrap();
assert!(jdb.can_reconstruct_refs());
assert!(jdb.exists(&foo));
assert!(jdb.exists(&bar));
assert!(jdb.contains(&foo));
assert!(jdb.contains(&bar));
jdb.remove(&foo);
jdb.remove(&bar);
let baz = jdb.insert(b"baz");
jdb.commit(1, &b"1".sha3(), Some((0, b"0".sha3()))).unwrap();
assert!(jdb.can_reconstruct_refs());
assert!(jdb.exists(&foo));
assert!(jdb.exists(&bar));
assert!(jdb.exists(&baz));
assert!(jdb.contains(&foo));
assert!(jdb.contains(&bar));
assert!(jdb.contains(&baz));
let foo = jdb.insert(b"foo");
jdb.remove(&baz);
jdb.commit(2, &b"2".sha3(), Some((1, b"1".sha3()))).unwrap();
assert!(jdb.can_reconstruct_refs());
assert!(jdb.exists(&foo));
assert!(!jdb.exists(&bar));
assert!(jdb.exists(&baz));
assert!(jdb.contains(&foo));
assert!(!jdb.contains(&bar));
assert!(jdb.contains(&baz));
jdb.remove(&foo);
jdb.commit(3, &b"3".sha3(), Some((2, b"2".sha3()))).unwrap();
assert!(jdb.can_reconstruct_refs());
assert!(jdb.exists(&foo));
assert!(!jdb.exists(&bar));
assert!(!jdb.exists(&baz));
assert!(jdb.contains(&foo));
assert!(!jdb.contains(&bar));
assert!(!jdb.contains(&baz));
jdb.commit(4, &b"4".sha3(), Some((3, b"3".sha3()))).unwrap();
assert!(jdb.can_reconstruct_refs());
assert!(!jdb.exists(&foo));
assert!(!jdb.exists(&bar));
assert!(!jdb.exists(&baz));
assert!(!jdb.contains(&foo));
assert!(!jdb.contains(&bar));
assert!(!jdb.contains(&baz));
}
#[test]
@@ -664,8 +659,8 @@ mod tests {
let bar = jdb.insert(b"bar");
jdb.commit(0, &b"0".sha3(), None).unwrap();
assert!(jdb.can_reconstruct_refs());
assert!(jdb.exists(&foo));
assert!(jdb.exists(&bar));
assert!(jdb.contains(&foo));
assert!(jdb.contains(&bar));
jdb.remove(&foo);
let baz = jdb.insert(b"baz");
@@ -676,15 +671,15 @@ mod tests {
jdb.commit(1, &b"1b".sha3(), Some((0, b"0".sha3()))).unwrap();
assert!(jdb.can_reconstruct_refs());
assert!(jdb.exists(&foo));
assert!(jdb.exists(&bar));
assert!(jdb.exists(&baz));
assert!(jdb.contains(&foo));
assert!(jdb.contains(&bar));
assert!(jdb.contains(&baz));
jdb.commit(2, &b"2b".sha3(), Some((1, b"1b".sha3()))).unwrap();
assert!(jdb.can_reconstruct_refs());
assert!(jdb.exists(&foo));
assert!(!jdb.exists(&baz));
assert!(!jdb.exists(&bar));
assert!(jdb.contains(&foo));
assert!(!jdb.contains(&baz));
assert!(!jdb.contains(&bar));
}
#[test]
@@ -695,19 +690,19 @@ mod tests {
let foo = jdb.insert(b"foo");
jdb.commit(0, &b"0".sha3(), None).unwrap();
assert!(jdb.can_reconstruct_refs());
assert!(jdb.exists(&foo));
assert!(jdb.contains(&foo));
jdb.remove(&foo);
jdb.commit(1, &b"1".sha3(), Some((0, b"0".sha3()))).unwrap();
assert!(jdb.can_reconstruct_refs());
jdb.insert(b"foo");
assert!(jdb.exists(&foo));
assert!(jdb.contains(&foo));
jdb.commit(2, &b"2".sha3(), Some((1, b"1".sha3()))).unwrap();
assert!(jdb.can_reconstruct_refs());
assert!(jdb.exists(&foo));
assert!(jdb.contains(&foo));
jdb.commit(3, &b"2".sha3(), Some((0, b"2".sha3()))).unwrap();
assert!(jdb.can_reconstruct_refs());
assert!(jdb.exists(&foo));
assert!(jdb.contains(&foo));
}
#[test]
@@ -715,7 +710,7 @@ mod tests {
let mut dir = ::std::env::temp_dir();
dir.push(H32::random().hex());
let mut jdb = EarlyMergeDB::new(dir.to_str().unwrap());
let mut jdb = EarlyMergeDB::new(dir.to_str().unwrap(), DatabaseConfig::default());
jdb.commit(0, &b"0".sha3(), None).unwrap();
assert!(jdb.can_reconstruct_refs());
@@ -731,11 +726,11 @@ mod tests {
jdb.commit(1, &b"1c".sha3(), Some((0, b"0".sha3()))).unwrap();
assert!(jdb.can_reconstruct_refs());
assert!(jdb.exists(&foo));
assert!(jdb.contains(&foo));
jdb.commit(2, &b"2a".sha3(), Some((1, b"1a".sha3()))).unwrap();
assert!(jdb.can_reconstruct_refs());
assert!(jdb.exists(&foo));
assert!(jdb.contains(&foo));
}
#[test]
@@ -743,7 +738,7 @@ mod tests {
let mut dir = ::std::env::temp_dir();
dir.push(H32::random().hex());
let mut jdb = EarlyMergeDB::new(dir.to_str().unwrap());
let mut jdb = EarlyMergeDB::new(dir.to_str().unwrap(), DatabaseConfig::default());
jdb.commit(0, &b"0".sha3(), None).unwrap();
assert!(jdb.can_reconstruct_refs());
@@ -759,11 +754,11 @@ mod tests {
jdb.commit(1, &b"1c".sha3(), Some((0, b"0".sha3()))).unwrap();
assert!(jdb.can_reconstruct_refs());
assert!(jdb.exists(&foo));
assert!(jdb.contains(&foo));
jdb.commit(2, &b"2b".sha3(), Some((1, b"1b".sha3()))).unwrap();
assert!(jdb.can_reconstruct_refs());
assert!(jdb.exists(&foo));
assert!(jdb.contains(&foo));
}
#[test]
@@ -771,7 +766,7 @@ mod tests {
let mut dir = ::std::env::temp_dir();
dir.push(H32::random().hex());
let mut jdb = EarlyMergeDB::new(dir.to_str().unwrap());
let mut jdb = EarlyMergeDB::new(dir.to_str().unwrap(), DatabaseConfig::default());
jdb.commit(0, &b"0".sha3(), None).unwrap();
assert!(jdb.can_reconstruct_refs());
@@ -809,7 +804,7 @@ mod tests {
let bar = H256::random();
let foo = {
let mut jdb = EarlyMergeDB::new(dir.to_str().unwrap());
let mut jdb = EarlyMergeDB::new(dir.to_str().unwrap(), DatabaseConfig::default());
// history is 1
let foo = jdb.insert(b"foo");
jdb.emplace(bar.clone(), b"bar".to_vec());
@@ -819,19 +814,19 @@ mod tests {
};
{
let mut jdb = EarlyMergeDB::new(dir.to_str().unwrap());
let mut jdb = EarlyMergeDB::new(dir.to_str().unwrap(), DatabaseConfig::default());
jdb.remove(&foo);
jdb.commit(1, &b"1".sha3(), Some((0, b"0".sha3()))).unwrap();
assert!(jdb.can_reconstruct_refs());
}
{
let mut jdb = EarlyMergeDB::new(dir.to_str().unwrap());
assert!(jdb.exists(&foo));
assert!(jdb.exists(&bar));
let mut jdb = EarlyMergeDB::new(dir.to_str().unwrap(), DatabaseConfig::default());
assert!(jdb.contains(&foo));
assert!(jdb.contains(&bar));
jdb.commit(2, &b"2".sha3(), Some((1, b"1".sha3()))).unwrap();
assert!(jdb.can_reconstruct_refs());
assert!(!jdb.exists(&foo));
assert!(!jdb.contains(&foo));
}
}
@@ -841,7 +836,7 @@ mod tests {
let mut dir = ::std::env::temp_dir();
dir.push(H32::random().hex());
let mut jdb = EarlyMergeDB::new(dir.to_str().unwrap());
let mut jdb = EarlyMergeDB::new(dir.to_str().unwrap(), DatabaseConfig::default());
// history is 4
let foo = jdb.insert(b"foo");
@@ -870,7 +865,7 @@ mod tests {
let mut dir = ::std::env::temp_dir();
dir.push(H32::random().hex());
let mut jdb = EarlyMergeDB::new(dir.to_str().unwrap());
let mut jdb = EarlyMergeDB::new(dir.to_str().unwrap(), DatabaseConfig::default());
// history is 4
let foo = jdb.insert(b"foo");
@@ -919,7 +914,7 @@ mod tests {
let mut dir = ::std::env::temp_dir();
dir.push(H32::random().hex());
let mut jdb = EarlyMergeDB::new(dir.to_str().unwrap());
let mut jdb = EarlyMergeDB::new(dir.to_str().unwrap(), DatabaseConfig::default());
// history is 1
let foo = jdb.insert(b"foo");
jdb.commit(1, &b"1".sha3(), Some((0, b"0".sha3()))).unwrap();
@@ -934,7 +929,7 @@ mod tests {
jdb.insert(b"foo");
jdb.commit(3, &b"3".sha3(), Some((2, b"2".sha3()))).unwrap(); // BROKEN
assert!(jdb.can_reconstruct_refs());
assert!(jdb.exists(&foo));
assert!(jdb.contains(&foo));
jdb.remove(&foo);
jdb.commit(4, &b"4".sha3(), Some((3, b"3".sha3()))).unwrap();
@@ -942,7 +937,7 @@ mod tests {
jdb.commit(5, &b"5".sha3(), Some((4, b"4".sha3()))).unwrap();
assert!(jdb.can_reconstruct_refs());
assert!(!jdb.exists(&foo));
assert!(!jdb.contains(&foo));
}
#[test]
@@ -950,7 +945,7 @@ mod tests {
let mut dir = ::std::env::temp_dir();
dir.push(H32::random().hex());
let mut jdb = EarlyMergeDB::new(dir.to_str().unwrap());
let mut jdb = EarlyMergeDB::new(dir.to_str().unwrap(), DatabaseConfig::default());
// history is 4
let foo = jdb.insert(b"foo");
jdb.commit(0, &b"0".sha3(), None).unwrap();
@@ -990,7 +985,7 @@ mod tests {
let foo = b"foo".sha3();
{
let mut jdb = EarlyMergeDB::new(dir.to_str().unwrap());
let mut jdb = EarlyMergeDB::new(dir.to_str().unwrap(), DatabaseConfig::default());
// history is 1
jdb.insert(b"foo");
jdb.commit(0, &b"0".sha3(), None).unwrap();
@@ -1003,34 +998,34 @@ mod tests {
jdb.remove(&foo);
jdb.commit(2, &b"2".sha3(), Some((0, b"0".sha3()))).unwrap();
assert!(jdb.can_reconstruct_refs());
assert!(jdb.exists(&foo));
assert!(jdb.contains(&foo));
jdb.insert(b"foo");
jdb.commit(3, &b"3".sha3(), Some((1, b"1".sha3()))).unwrap();
assert!(jdb.can_reconstruct_refs());
assert!(jdb.exists(&foo));
assert!(jdb.contains(&foo));
// incantation to reopen the db
}; { let mut jdb = EarlyMergeDB::new(dir.to_str().unwrap());
}; { let mut jdb = EarlyMergeDB::new(dir.to_str().unwrap(), DatabaseConfig::default());
jdb.remove(&foo);
jdb.commit(4, &b"4".sha3(), Some((2, b"2".sha3()))).unwrap();
assert!(jdb.can_reconstruct_refs());
assert!(jdb.exists(&foo));
assert!(jdb.contains(&foo));
// incantation to reopen the db
}; { let mut jdb = EarlyMergeDB::new(dir.to_str().unwrap());
}; { let mut jdb = EarlyMergeDB::new(dir.to_str().unwrap(), DatabaseConfig::default());
jdb.commit(5, &b"5".sha3(), Some((3, b"3".sha3()))).unwrap();
assert!(jdb.can_reconstruct_refs());
assert!(jdb.exists(&foo));
assert!(jdb.contains(&foo));
// incantation to reopen the db
}; { let mut jdb = EarlyMergeDB::new(dir.to_str().unwrap());
}; { let mut jdb = EarlyMergeDB::new(dir.to_str().unwrap(), DatabaseConfig::default());
jdb.commit(6, &b"6".sha3(), Some((4, b"4".sha3()))).unwrap();
assert!(jdb.can_reconstruct_refs());
assert!(!jdb.exists(&foo));
assert!(!jdb.contains(&foo));
}
}
@@ -1039,7 +1034,7 @@ mod tests {
let mut dir = ::std::env::temp_dir();
dir.push(H32::random().hex());
let (foo, bar, baz) = {
let mut jdb = EarlyMergeDB::new(dir.to_str().unwrap());
let mut jdb = EarlyMergeDB::new(dir.to_str().unwrap(), DatabaseConfig::default());
// history is 1
let foo = jdb.insert(b"foo");
let bar = jdb.insert(b"bar");
@@ -1057,12 +1052,12 @@ mod tests {
};
{
let mut jdb = EarlyMergeDB::new(dir.to_str().unwrap());
let mut jdb = EarlyMergeDB::new(dir.to_str().unwrap(), DatabaseConfig::default());
jdb.commit(2, &b"2b".sha3(), Some((1, b"1b".sha3()))).unwrap();
assert!(jdb.can_reconstruct_refs());
assert!(jdb.exists(&foo));
assert!(!jdb.exists(&baz));
assert!(!jdb.exists(&bar));
assert!(jdb.contains(&foo));
assert!(!jdb.contains(&baz));
assert!(!jdb.contains(&bar));
}
}
}

View File

@@ -17,6 +17,7 @@
//! `JournalDB` interface and implementation.
use common::*;
use kvdb::DatabaseConfig;
/// Export the journaldb module.
pub mod traits;
@@ -71,11 +72,16 @@ impl fmt::Display for Algorithm {
}
/// Create a new `JournalDB` trait object.
pub fn new(path: &str, algorithm: Algorithm) -> Box<JournalDB> {
pub fn new(path: &str, algorithm: Algorithm, config: DatabaseConfig) -> Box<JournalDB> {
match algorithm {
Algorithm::Archive => Box::new(archivedb::ArchiveDB::new(path)),
Algorithm::EarlyMerge => Box::new(earlymergedb::EarlyMergeDB::new(path)),
Algorithm::OverlayRecent => Box::new(overlayrecentdb::OverlayRecentDB::new(path)),
Algorithm::RefCounted => Box::new(refcounteddb::RefCountedDB::new(path)),
Algorithm::Archive => Box::new(archivedb::ArchiveDB::new(path, config)),
Algorithm::EarlyMerge => Box::new(earlymergedb::EarlyMergeDB::new(path, config)),
Algorithm::OverlayRecent => Box::new(overlayrecentdb::OverlayRecentDB::new(path, config)),
Algorithm::RefCounted => Box::new(refcounteddb::RefCountedDB::new(path, config)),
}
}
// all keys must be at least 12 bytes
const DB_PREFIX_LEN : usize = 12;
const LATEST_ERA_KEY : [u8; DB_PREFIX_LEN] = [ b'l', b'a', b's', b't', 0, 0, 0, 0, 0, 0, 0, 0 ];
const VERSION_KEY : [u8; DB_PREFIX_LEN] = [ b'j', b'v', b'e', b'r', 0, 0, 0, 0, 0, 0, 0, 0 ];

View File

@@ -20,6 +20,7 @@ use common::*;
use rlp::*;
use hashdb::*;
use memorydb::*;
use super::{DB_PREFIX_LEN, LATEST_ERA_KEY, VERSION_KEY};
use kvdb::{Database, DBTransaction, DatabaseConfig};
#[cfg(test)]
use std::env;
@@ -92,25 +93,18 @@ impl Clone for OverlayRecentDB {
}
}
// all keys must be at least 12 bytes
const LATEST_ERA_KEY : [u8; 12] = [ b'l', b'a', b's', b't', 0, 0, 0, 0, 0, 0, 0, 0 ];
const VERSION_KEY : [u8; 12] = [ b'j', b'v', b'e', b'r', 0, 0, 0, 0, 0, 0, 0, 0 ];
const DB_VERSION : u32 = 0x203;
const PADDING : [u8; 10] = [ 0u8; 10 ];
impl OverlayRecentDB {
/// Create a new instance from file
pub fn new(path: &str) -> OverlayRecentDB {
Self::from_prefs(path)
pub fn new(path: &str, config: DatabaseConfig) -> OverlayRecentDB {
Self::from_prefs(path, config)
}
/// Create a new instance from file
pub fn from_prefs(path: &str) -> OverlayRecentDB {
let opts = DatabaseConfig {
//use 12 bytes as prefix, this must match account_db prefix
prefix_size: Some(12),
max_open_files: 256,
};
pub fn from_prefs(path: &str, config: DatabaseConfig) -> OverlayRecentDB {
let opts = config.prefix(DB_PREFIX_LEN);
let backing = Database::open(&opts, path).unwrap_or_else(|e| {
panic!("Error opening state db: {}", e);
});
@@ -136,7 +130,7 @@ impl OverlayRecentDB {
pub fn new_temp() -> OverlayRecentDB {
let mut dir = env::temp_dir();
dir.push(H32::random().hex());
Self::new(dir.to_str().unwrap())
Self::new(dir.to_str().unwrap(), DatabaseConfig::default())
}
#[cfg(test)]
@@ -289,11 +283,11 @@ impl JournalDB for OverlayRecentDB {
}
// update the overlay
for k in overlay_deletions {
journal_overlay.backing_overlay.kill(&k);
journal_overlay.backing_overlay.remove(&k);
}
// apply canon deletions
for k in canon_deletions {
if !journal_overlay.backing_overlay.exists(&k) {
if !journal_overlay.backing_overlay.contains(&k) {
try!(batch.delete(&k));
}
}
@@ -322,12 +316,12 @@ impl HashDB for OverlayRecentDB {
ret
}
fn lookup(&self, key: &H256) -> Option<&[u8]> {
fn get(&self, key: &H256) -> Option<&[u8]> {
let k = self.transaction_overlay.raw(key);
match k {
Some(&(ref d, rc)) if rc > 0 => Some(d),
_ => {
let v = self.journal_overlay.read().unwrap().backing_overlay.lookup(key).map(|v| v.to_vec());
let v = self.journal_overlay.read().unwrap().backing_overlay.get(key).map(|v| v.to_vec());
match v {
Some(x) => {
Some(&self.transaction_overlay.denote(key, x).0)
@@ -345,8 +339,8 @@ impl HashDB for OverlayRecentDB {
}
}
fn exists(&self, key: &H256) -> bool {
self.lookup(key).is_some()
fn contains(&self, key: &H256) -> bool {
self.get(key).is_some()
}
fn insert(&mut self, value: &[u8]) -> H256 {
@@ -355,8 +349,8 @@ impl HashDB for OverlayRecentDB {
fn emplace(&mut self, key: H256, value: Bytes) {
self.transaction_overlay.emplace(key, value);
}
fn kill(&mut self, key: &H256) {
self.transaction_overlay.kill(key);
fn remove(&mut self, key: &H256) {
self.transaction_overlay.remove(key);
}
}
@@ -370,6 +364,7 @@ mod tests {
use hashdb::*;
use log::init_log;
use journaldb::JournalDB;
use kvdb::DatabaseConfig;
#[test]
fn insert_same_in_fork() {
@@ -398,7 +393,7 @@ mod tests {
jdb.commit(6, &b"1005a".sha3(), Some((4, b"1003a".sha3()))).unwrap();
assert!(jdb.can_reconstruct_refs());
assert!(jdb.exists(&x));
assert!(jdb.contains(&x));
}
#[test]
@@ -408,20 +403,20 @@ mod tests {
let h = jdb.insert(b"foo");
jdb.commit(0, &b"0".sha3(), None).unwrap();
assert!(jdb.can_reconstruct_refs());
assert!(jdb.exists(&h));
assert!(jdb.contains(&h));
jdb.remove(&h);
jdb.commit(1, &b"1".sha3(), None).unwrap();
assert!(jdb.can_reconstruct_refs());
assert!(jdb.exists(&h));
assert!(jdb.contains(&h));
jdb.commit(2, &b"2".sha3(), None).unwrap();
assert!(jdb.can_reconstruct_refs());
assert!(jdb.exists(&h));
assert!(jdb.contains(&h));
jdb.commit(3, &b"3".sha3(), Some((0, b"0".sha3()))).unwrap();
assert!(jdb.can_reconstruct_refs());
assert!(jdb.exists(&h));
assert!(jdb.contains(&h));
jdb.commit(4, &b"4".sha3(), Some((1, b"1".sha3()))).unwrap();
assert!(jdb.can_reconstruct_refs());
assert!(!jdb.exists(&h));
assert!(!jdb.contains(&h));
}
#[test]
@@ -433,38 +428,38 @@ mod tests {
let bar = jdb.insert(b"bar");
jdb.commit(0, &b"0".sha3(), None).unwrap();
assert!(jdb.can_reconstruct_refs());
assert!(jdb.exists(&foo));
assert!(jdb.exists(&bar));
assert!(jdb.contains(&foo));
assert!(jdb.contains(&bar));
jdb.remove(&foo);
jdb.remove(&bar);
let baz = jdb.insert(b"baz");
jdb.commit(1, &b"1".sha3(), Some((0, b"0".sha3()))).unwrap();
assert!(jdb.can_reconstruct_refs());
assert!(jdb.exists(&foo));
assert!(jdb.exists(&bar));
assert!(jdb.exists(&baz));
assert!(jdb.contains(&foo));
assert!(jdb.contains(&bar));
assert!(jdb.contains(&baz));
let foo = jdb.insert(b"foo");
jdb.remove(&baz);
jdb.commit(2, &b"2".sha3(), Some((1, b"1".sha3()))).unwrap();
assert!(jdb.can_reconstruct_refs());
assert!(jdb.exists(&foo));
assert!(!jdb.exists(&bar));
assert!(jdb.exists(&baz));
assert!(jdb.contains(&foo));
assert!(!jdb.contains(&bar));
assert!(jdb.contains(&baz));
jdb.remove(&foo);
jdb.commit(3, &b"3".sha3(), Some((2, b"2".sha3()))).unwrap();
assert!(jdb.can_reconstruct_refs());
assert!(jdb.exists(&foo));
assert!(!jdb.exists(&bar));
assert!(!jdb.exists(&baz));
assert!(jdb.contains(&foo));
assert!(!jdb.contains(&bar));
assert!(!jdb.contains(&baz));
jdb.commit(4, &b"4".sha3(), Some((3, b"3".sha3()))).unwrap();
assert!(jdb.can_reconstruct_refs());
assert!(!jdb.exists(&foo));
assert!(!jdb.exists(&bar));
assert!(!jdb.exists(&baz));
assert!(!jdb.contains(&foo));
assert!(!jdb.contains(&bar));
assert!(!jdb.contains(&baz));
}
#[test]
@@ -476,8 +471,8 @@ mod tests {
let bar = jdb.insert(b"bar");
jdb.commit(0, &b"0".sha3(), None).unwrap();
assert!(jdb.can_reconstruct_refs());
assert!(jdb.exists(&foo));
assert!(jdb.exists(&bar));
assert!(jdb.contains(&foo));
assert!(jdb.contains(&bar));
jdb.remove(&foo);
let baz = jdb.insert(b"baz");
@@ -488,15 +483,15 @@ mod tests {
jdb.commit(1, &b"1b".sha3(), Some((0, b"0".sha3()))).unwrap();
assert!(jdb.can_reconstruct_refs());
assert!(jdb.exists(&foo));
assert!(jdb.exists(&bar));
assert!(jdb.exists(&baz));
assert!(jdb.contains(&foo));
assert!(jdb.contains(&bar));
assert!(jdb.contains(&baz));
jdb.commit(2, &b"2b".sha3(), Some((1, b"1b".sha3()))).unwrap();
assert!(jdb.can_reconstruct_refs());
assert!(jdb.exists(&foo));
assert!(!jdb.exists(&baz));
assert!(!jdb.exists(&bar));
assert!(jdb.contains(&foo));
assert!(!jdb.contains(&baz));
assert!(!jdb.contains(&bar));
}
#[test]
@@ -507,19 +502,19 @@ mod tests {
let foo = jdb.insert(b"foo");
jdb.commit(0, &b"0".sha3(), None).unwrap();
assert!(jdb.can_reconstruct_refs());
assert!(jdb.exists(&foo));
assert!(jdb.contains(&foo));
jdb.remove(&foo);
jdb.commit(1, &b"1".sha3(), Some((0, b"0".sha3()))).unwrap();
assert!(jdb.can_reconstruct_refs());
jdb.insert(b"foo");
assert!(jdb.exists(&foo));
assert!(jdb.contains(&foo));
jdb.commit(2, &b"2".sha3(), Some((1, b"1".sha3()))).unwrap();
assert!(jdb.can_reconstruct_refs());
assert!(jdb.exists(&foo));
assert!(jdb.contains(&foo));
jdb.commit(3, &b"2".sha3(), Some((0, b"2".sha3()))).unwrap();
assert!(jdb.can_reconstruct_refs());
assert!(jdb.exists(&foo));
assert!(jdb.contains(&foo));
}
#[test]
@@ -527,7 +522,7 @@ mod tests {
let mut dir = ::std::env::temp_dir();
dir.push(H32::random().hex());
let mut jdb = OverlayRecentDB::new(dir.to_str().unwrap());
let mut jdb = OverlayRecentDB::new(dir.to_str().unwrap(), DatabaseConfig::default());
jdb.commit(0, &b"0".sha3(), None).unwrap();
assert!(jdb.can_reconstruct_refs());
@@ -543,11 +538,11 @@ mod tests {
jdb.commit(1, &b"1c".sha3(), Some((0, b"0".sha3()))).unwrap();
assert!(jdb.can_reconstruct_refs());
assert!(jdb.exists(&foo));
assert!(jdb.contains(&foo));
jdb.commit(2, &b"2a".sha3(), Some((1, b"1a".sha3()))).unwrap();
assert!(jdb.can_reconstruct_refs());
assert!(jdb.exists(&foo));
assert!(jdb.contains(&foo));
}
#[test]
@@ -555,7 +550,7 @@ mod tests {
let mut dir = ::std::env::temp_dir();
dir.push(H32::random().hex());
let mut jdb = OverlayRecentDB::new(dir.to_str().unwrap());
let mut jdb = OverlayRecentDB::new(dir.to_str().unwrap(), DatabaseConfig::default());
jdb.commit(0, &b"0".sha3(), None).unwrap();
assert!(jdb.can_reconstruct_refs());
@@ -571,11 +566,11 @@ mod tests {
jdb.commit(1, &b"1c".sha3(), Some((0, b"0".sha3()))).unwrap();
assert!(jdb.can_reconstruct_refs());
assert!(jdb.exists(&foo));
assert!(jdb.contains(&foo));
jdb.commit(2, &b"2b".sha3(), Some((1, b"1b".sha3()))).unwrap();
assert!(jdb.can_reconstruct_refs());
assert!(jdb.exists(&foo));
assert!(jdb.contains(&foo));
}
#[test]
@@ -583,7 +578,7 @@ mod tests {
let mut dir = ::std::env::temp_dir();
dir.push(H32::random().hex());
let mut jdb = OverlayRecentDB::new(dir.to_str().unwrap());
let mut jdb = OverlayRecentDB::new(dir.to_str().unwrap(), DatabaseConfig::default());
jdb.commit(0, &b"0".sha3(), None).unwrap();
assert!(jdb.can_reconstruct_refs());
@@ -621,7 +616,7 @@ mod tests {
let bar = H256::random();
let foo = {
let mut jdb = OverlayRecentDB::new(dir.to_str().unwrap());
let mut jdb = OverlayRecentDB::new(dir.to_str().unwrap(), DatabaseConfig::default());
// history is 1
let foo = jdb.insert(b"foo");
jdb.emplace(bar.clone(), b"bar".to_vec());
@@ -631,19 +626,19 @@ mod tests {
};
{
let mut jdb = OverlayRecentDB::new(dir.to_str().unwrap());
let mut jdb = OverlayRecentDB::new(dir.to_str().unwrap(), DatabaseConfig::default());
jdb.remove(&foo);
jdb.commit(1, &b"1".sha3(), Some((0, b"0".sha3()))).unwrap();
assert!(jdb.can_reconstruct_refs());
}
{
let mut jdb = OverlayRecentDB::new(dir.to_str().unwrap());
assert!(jdb.exists(&foo));
assert!(jdb.exists(&bar));
let mut jdb = OverlayRecentDB::new(dir.to_str().unwrap(), DatabaseConfig::default());
assert!(jdb.contains(&foo));
assert!(jdb.contains(&bar));
jdb.commit(2, &b"2".sha3(), Some((1, b"1".sha3()))).unwrap();
assert!(jdb.can_reconstruct_refs());
assert!(!jdb.exists(&foo));
assert!(!jdb.contains(&foo));
}
}
@@ -653,7 +648,7 @@ mod tests {
let mut dir = ::std::env::temp_dir();
dir.push(H32::random().hex());
let mut jdb = OverlayRecentDB::new(dir.to_str().unwrap());
let mut jdb = OverlayRecentDB::new(dir.to_str().unwrap(), DatabaseConfig::default());
// history is 4
let foo = jdb.insert(b"foo");
@@ -682,7 +677,7 @@ mod tests {
let mut dir = ::std::env::temp_dir();
dir.push(H32::random().hex());
let mut jdb = OverlayRecentDB::new(dir.to_str().unwrap());
let mut jdb = OverlayRecentDB::new(dir.to_str().unwrap(), DatabaseConfig::default());
// history is 4
let foo = jdb.insert(b"foo");
@@ -731,7 +726,7 @@ mod tests {
let mut dir = ::std::env::temp_dir();
dir.push(H32::random().hex());
let mut jdb = OverlayRecentDB::new(dir.to_str().unwrap());
let mut jdb = OverlayRecentDB::new(dir.to_str().unwrap(), DatabaseConfig::default());
// history is 1
let foo = jdb.insert(b"foo");
jdb.commit(1, &b"1".sha3(), Some((0, b"0".sha3()))).unwrap();
@@ -746,7 +741,7 @@ mod tests {
jdb.insert(b"foo");
jdb.commit(3, &b"3".sha3(), Some((2, b"2".sha3()))).unwrap(); // BROKEN
assert!(jdb.can_reconstruct_refs());
assert!(jdb.exists(&foo));
assert!(jdb.contains(&foo));
jdb.remove(&foo);
jdb.commit(4, &b"4".sha3(), Some((3, b"3".sha3()))).unwrap();
@@ -754,7 +749,7 @@ mod tests {
jdb.commit(5, &b"5".sha3(), Some((4, b"4".sha3()))).unwrap();
assert!(jdb.can_reconstruct_refs());
assert!(!jdb.exists(&foo));
assert!(!jdb.contains(&foo));
}
#[test]
@@ -762,7 +757,7 @@ mod tests {
let mut dir = ::std::env::temp_dir();
dir.push(H32::random().hex());
let mut jdb = OverlayRecentDB::new(dir.to_str().unwrap());
let mut jdb = OverlayRecentDB::new(dir.to_str().unwrap(), DatabaseConfig::default());
// history is 4
let foo = jdb.insert(b"foo");
jdb.commit(0, &b"0".sha3(), None).unwrap();
@@ -802,7 +797,7 @@ mod tests {
let foo = b"foo".sha3();
{
let mut jdb = OverlayRecentDB::new(dir.to_str().unwrap());
let mut jdb = OverlayRecentDB::new(dir.to_str().unwrap(), DatabaseConfig::default());
// history is 1
jdb.insert(b"foo");
jdb.commit(0, &b"0".sha3(), None).unwrap();
@@ -815,34 +810,34 @@ mod tests {
jdb.remove(&foo);
jdb.commit(2, &b"2".sha3(), Some((0, b"0".sha3()))).unwrap();
assert!(jdb.can_reconstruct_refs());
assert!(jdb.exists(&foo));
assert!(jdb.contains(&foo));
jdb.insert(b"foo");
jdb.commit(3, &b"3".sha3(), Some((1, b"1".sha3()))).unwrap();
assert!(jdb.can_reconstruct_refs());
assert!(jdb.exists(&foo));
assert!(jdb.contains(&foo));
// incantation to reopen the db
}; { let mut jdb = OverlayRecentDB::new(dir.to_str().unwrap());
}; { let mut jdb = OverlayRecentDB::new(dir.to_str().unwrap(), DatabaseConfig::default());
jdb.remove(&foo);
jdb.commit(4, &b"4".sha3(), Some((2, b"2".sha3()))).unwrap();
assert!(jdb.can_reconstruct_refs());
assert!(jdb.exists(&foo));
assert!(jdb.contains(&foo));
// incantation to reopen the db
}; { let mut jdb = OverlayRecentDB::new(dir.to_str().unwrap());
}; { let mut jdb = OverlayRecentDB::new(dir.to_str().unwrap(), DatabaseConfig::default());
jdb.commit(5, &b"5".sha3(), Some((3, b"3".sha3()))).unwrap();
assert!(jdb.can_reconstruct_refs());
assert!(jdb.exists(&foo));
assert!(jdb.contains(&foo));
// incantation to reopen the db
}; { let mut jdb = OverlayRecentDB::new(dir.to_str().unwrap());
}; { let mut jdb = OverlayRecentDB::new(dir.to_str().unwrap(), DatabaseConfig::default());
jdb.commit(6, &b"6".sha3(), Some((4, b"4".sha3()))).unwrap();
assert!(jdb.can_reconstruct_refs());
assert!(!jdb.exists(&foo));
assert!(!jdb.contains(&foo));
}
}
@@ -851,7 +846,7 @@ mod tests {
let mut dir = ::std::env::temp_dir();
dir.push(H32::random().hex());
let (foo, bar, baz) = {
let mut jdb = OverlayRecentDB::new(dir.to_str().unwrap());
let mut jdb = OverlayRecentDB::new(dir.to_str().unwrap(), DatabaseConfig::default());
// history is 1
let foo = jdb.insert(b"foo");
let bar = jdb.insert(b"bar");
@@ -869,12 +864,12 @@ mod tests {
};
{
let mut jdb = OverlayRecentDB::new(dir.to_str().unwrap());
let mut jdb = OverlayRecentDB::new(dir.to_str().unwrap(), DatabaseConfig::default());
jdb.commit(2, &b"2b".sha3(), Some((1, b"1b".sha3()))).unwrap();
assert!(jdb.can_reconstruct_refs());
assert!(jdb.exists(&foo));
assert!(!jdb.exists(&baz));
assert!(!jdb.exists(&bar));
assert!(jdb.contains(&foo));
assert!(!jdb.contains(&baz));
assert!(!jdb.contains(&bar));
}
}
@@ -894,7 +889,7 @@ mod tests {
assert!(jdb.can_reconstruct_refs());
jdb.commit(2, &b"2".sha3(), Some((1, b"1".sha3()))).unwrap();
assert!(jdb.exists(&foo));
assert!(jdb.exists(&bar));
assert!(jdb.contains(&foo));
assert!(jdb.contains(&bar));
}
}

View File

@@ -20,6 +20,7 @@ use common::*;
use rlp::*;
use hashdb::*;
use overlaydb::*;
use super::{DB_PREFIX_LEN, LATEST_ERA_KEY, VERSION_KEY};
use super::traits::JournalDB;
use kvdb::{Database, DBTransaction, DatabaseConfig};
#[cfg(test)]
@@ -40,19 +41,13 @@ pub struct RefCountedDB {
removes: Vec<H256>,
}
const LATEST_ERA_KEY : [u8; 12] = [ b'l', b'a', b's', b't', 0, 0, 0, 0, 0, 0, 0, 0 ];
const VERSION_KEY : [u8; 12] = [ b'j', b'v', b'e', b'r', 0, 0, 0, 0, 0, 0, 0, 0 ];
const DB_VERSION : u32 = 0x200;
const PADDING : [u8; 10] = [ 0u8; 10 ];
impl RefCountedDB {
/// Create a new instance given a `backing` database.
pub fn new(path: &str) -> RefCountedDB {
let opts = DatabaseConfig {
//use 12 bytes as prefix, this must match account_db prefix
prefix_size: Some(12),
max_open_files: 256,
};
pub fn new(path: &str, config: DatabaseConfig) -> RefCountedDB {
let opts = config.prefix(DB_PREFIX_LEN);
let backing = Database::open(&opts, path).unwrap_or_else(|e| {
panic!("Error opening state db: {}", e);
});
@@ -82,17 +77,17 @@ impl RefCountedDB {
fn new_temp() -> RefCountedDB {
let mut dir = env::temp_dir();
dir.push(H32::random().hex());
Self::new(dir.to_str().unwrap())
Self::new(dir.to_str().unwrap(), DatabaseConfig::default())
}
}
impl HashDB for RefCountedDB {
fn keys(&self) -> HashMap<H256, i32> { self.forward.keys() }
fn lookup(&self, key: &H256) -> Option<&[u8]> { self.forward.lookup(key) }
fn exists(&self, key: &H256) -> bool { self.forward.exists(key) }
fn get(&self, key: &H256) -> Option<&[u8]> { self.forward.get(key) }
fn contains(&self, key: &H256) -> bool { self.forward.contains(key) }
fn insert(&mut self, value: &[u8]) -> H256 { let r = self.forward.insert(value); self.inserts.push(r.clone()); r }
fn emplace(&mut self, key: H256, value: Bytes) { self.inserts.push(key.clone()); self.forward.emplace(key, value); }
fn kill(&mut self, key: &H256) { self.removes.push(key.clone()); }
fn remove(&mut self, key: &H256) { self.removes.push(key.clone()); }
}
impl JournalDB for RefCountedDB {
@@ -212,16 +207,16 @@ mod tests {
let mut jdb = RefCountedDB::new_temp();
let h = jdb.insert(b"foo");
jdb.commit(0, &b"0".sha3(), None).unwrap();
assert!(jdb.exists(&h));
assert!(jdb.contains(&h));
jdb.remove(&h);
jdb.commit(1, &b"1".sha3(), None).unwrap();
assert!(jdb.exists(&h));
assert!(jdb.contains(&h));
jdb.commit(2, &b"2".sha3(), None).unwrap();
assert!(jdb.exists(&h));
assert!(jdb.contains(&h));
jdb.commit(3, &b"3".sha3(), Some((0, b"0".sha3()))).unwrap();
assert!(jdb.exists(&h));
assert!(jdb.contains(&h));
jdb.commit(4, &b"4".sha3(), Some((1, b"1".sha3()))).unwrap();
assert!(!jdb.exists(&h));
assert!(!jdb.contains(&h));
}
#[test]
@@ -251,34 +246,34 @@ mod tests {
let foo = jdb.insert(b"foo");
let bar = jdb.insert(b"bar");
jdb.commit(0, &b"0".sha3(), None).unwrap();
assert!(jdb.exists(&foo));
assert!(jdb.exists(&bar));
assert!(jdb.contains(&foo));
assert!(jdb.contains(&bar));
jdb.remove(&foo);
jdb.remove(&bar);
let baz = jdb.insert(b"baz");
jdb.commit(1, &b"1".sha3(), Some((0, b"0".sha3()))).unwrap();
assert!(jdb.exists(&foo));
assert!(jdb.exists(&bar));
assert!(jdb.exists(&baz));
assert!(jdb.contains(&foo));
assert!(jdb.contains(&bar));
assert!(jdb.contains(&baz));
let foo = jdb.insert(b"foo");
jdb.remove(&baz);
jdb.commit(2, &b"2".sha3(), Some((1, b"1".sha3()))).unwrap();
assert!(jdb.exists(&foo));
assert!(!jdb.exists(&bar));
assert!(jdb.exists(&baz));
assert!(jdb.contains(&foo));
assert!(!jdb.contains(&bar));
assert!(jdb.contains(&baz));
jdb.remove(&foo);
jdb.commit(3, &b"3".sha3(), Some((2, b"2".sha3()))).unwrap();
assert!(jdb.exists(&foo));
assert!(!jdb.exists(&bar));
assert!(!jdb.exists(&baz));
assert!(jdb.contains(&foo));
assert!(!jdb.contains(&bar));
assert!(!jdb.contains(&baz));
jdb.commit(4, &b"4".sha3(), Some((3, b"3".sha3()))).unwrap();
assert!(!jdb.exists(&foo));
assert!(!jdb.exists(&bar));
assert!(!jdb.exists(&baz));
assert!(!jdb.contains(&foo));
assert!(!jdb.contains(&bar));
assert!(!jdb.contains(&baz));
}
#[test]
@@ -289,8 +284,8 @@ mod tests {
let foo = jdb.insert(b"foo");
let bar = jdb.insert(b"bar");
jdb.commit(0, &b"0".sha3(), None).unwrap();
assert!(jdb.exists(&foo));
assert!(jdb.exists(&bar));
assert!(jdb.contains(&foo));
assert!(jdb.contains(&bar));
jdb.remove(&foo);
let baz = jdb.insert(b"baz");
@@ -299,13 +294,13 @@ mod tests {
jdb.remove(&bar);
jdb.commit(1, &b"1b".sha3(), Some((0, b"0".sha3()))).unwrap();
assert!(jdb.exists(&foo));
assert!(jdb.exists(&bar));
assert!(jdb.exists(&baz));
assert!(jdb.contains(&foo));
assert!(jdb.contains(&bar));
assert!(jdb.contains(&baz));
jdb.commit(2, &b"2b".sha3(), Some((1, b"1b".sha3()))).unwrap();
assert!(jdb.exists(&foo));
assert!(!jdb.exists(&baz));
assert!(!jdb.exists(&bar));
assert!(jdb.contains(&foo));
assert!(!jdb.contains(&baz));
assert!(!jdb.contains(&bar));
}
}

File diff suppressed because it is too large Load Diff

View File

@@ -1,261 +0,0 @@
// Copyright 2015, 2016 Ethcore (UK) Ltd.
// This file is part of Parity.
// Parity is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
// Parity is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
// You should have received a copy of the GNU General Public License
// along with Parity. If not, see <http://www.gnu.org/licenses/>.
//! Geth keys import/export tool
use common::*;
use keys::store::SecretStore;
use keys::directory::KeyFileContent;
use std::path::PathBuf;
use path;
/// Enumerates all geth keys in the directory and returns collection of tuples `(accountId, filename)`
pub fn enumerate_geth_keys(path: &Path) -> Result<Vec<(Address, String)>, ImportError> {
let mut entries = Vec::new();
for entry in try!(fs::read_dir(path)) {
let entry = try!(entry);
if !try!(fs::metadata(entry.path())).is_dir() {
match entry.file_name().to_str() {
Some(name) => {
let parts: Vec<&str> = name.split("--").collect();
if parts.len() != 3 { continue; }
let account_id = try!(Address::from_str(parts[2]).map_err(|_| ImportError::Format));
entries.push((account_id, name.to_owned()));
},
None => { continue; }
};
}
}
Ok(entries)
}
/// Geth import error
#[derive(Debug)]
pub enum ImportError {
/// Io error reading geth file
Io(io::Error),
/// format error
Format,
}
impl From<io::Error> for ImportError {
fn from (err: io::Error) -> ImportError {
ImportError::Io(err)
}
}
/// Imports one geth key to the store
pub fn import_geth_key(secret_store: &mut SecretStore, geth_keyfile_path: &Path) -> Result<(), ImportError> {
let mut file = try!(fs::File::open(geth_keyfile_path));
let mut buf = String::new();
try!(file.read_to_string(&mut buf));
let mut json_result = Json::from_str(&buf);
let mut json = match json_result {
Ok(ref mut parsed_json) => try!(parsed_json.as_object_mut().ok_or(ImportError::Format)),
Err(_) => { return Err(ImportError::Format); }
};
if let Some(crypto_object) = json.get("Crypto").and_then(|crypto| crypto.as_object()).cloned() {
json.insert("crypto".to_owned(), Json::Object(crypto_object));
json.remove("Crypto");
}
match KeyFileContent::load(&Json::Object(json.clone())) {
Ok(key_file) => try!(secret_store.import_key(key_file)),
Err(_) => { return Err(ImportError::Format); }
};
Ok(())
}
/// Imports all geth keys in the directory
pub fn import_geth_keys(secret_store: &mut SecretStore, geth_keyfiles_directory: &Path) -> Result<usize, ImportError> {
use std::path::PathBuf;
let geth_files = try!(enumerate_geth_keys(geth_keyfiles_directory));
let mut total = 0;
for &(ref address, ref file_path) in &geth_files {
let mut path = PathBuf::new();
path.push(geth_keyfiles_directory);
path.push(file_path);
if let Err(e) = import_geth_key(secret_store, Path::new(&path)) {
warn!("Skipped geth address {}, error importing: {:?}", address, e)
}
else { total = total + 1}
}
Ok(total)
}
/// Gets the default geth keystore directory.
pub fn keystore_dir(is_testnet: bool) -> PathBuf {
path::ethereum::with_default(if is_testnet {"testnet/keystore"} else {"keystore"})
}
/// Imports key(s) from provided file/directory
pub fn import_keys_path(secret_store: &mut SecretStore, path: &str) -> Result<usize, ImportError> {
// check if it is just one file or directory
if let Ok(meta) = fs::metadata(path) {
if meta.is_file() {
try!(import_geth_key(secret_store, Path::new(path)));
return Ok(1);
}
else if meta.is_dir() {
return Ok(try!(fs::read_dir(path)).fold(
0,
|total, p|
total +
match p {
Ok(dir_entry) => import_keys_path(secret_store, dir_entry.path().to_str().unwrap()).unwrap_or_else(|_| 0),
Err(e) => { warn!("Error importing dir entry: {:?}", e); 0 },
}
))
}
}
Ok(0)
}
/// Imports all keys from list of provided files/directories
pub fn import_keys_paths(secret_store: &mut SecretStore, path: &[String]) -> Result<usize, ImportError> {
Ok(path.iter().fold(0, |total, ref p| total + import_keys_path(secret_store, &p).unwrap_or_else(|_| 0)))
}
#[cfg(test)]
mod tests {
use super::*;
use common::*;
use keys::store::SecretStore;
fn test_path() -> &'static str {
match ::std::fs::metadata("res") {
Ok(_) => "res/geth_keystore",
Err(_) => "util/res/geth_keystore"
}
}
fn pat_path() -> &'static str {
match ::std::fs::metadata("res") {
Ok(_) => "res/pat",
Err(_) => "util/res/pat"
}
}
fn test_path_param(param_val: &'static str) -> String {
test_path().to_owned() + param_val
}
fn pat_path_param(param_val: &'static str) -> String {
pat_path().to_owned() + param_val
}
#[test]
fn can_enumerate() {
let keys = enumerate_geth_keys(Path::new(test_path())).unwrap();
assert_eq!(3, keys.len());
}
#[test]
fn can_import_geth_old() {
let temp = ::devtools::RandomTempPath::create_dir();
let mut secret_store = SecretStore::new_in(temp.as_path());
import_geth_key(&mut secret_store, Path::new(&test_path_param("/UTC--2016-02-17T09-20-45.721400158Z--3f49624084b67849c7b4e805c5988c21a430f9d9"))).unwrap();
let key = secret_store.account(&Address::from_str("3f49624084b67849c7b4e805c5988c21a430f9d9").unwrap());
assert!(key.is_some());
}
#[test]
fn can_import_geth140() {
let temp = ::devtools::RandomTempPath::create_dir();
let mut secret_store = SecretStore::new_in(temp.as_path());
import_geth_key(&mut secret_store, Path::new(&test_path_param("/UTC--2016-04-03T08-58-49.834202900Z--63121b431a52f8043c16fcf0d1df9cb7b5f66649"))).unwrap();
let key = secret_store.account(&Address::from_str("63121b431a52f8043c16fcf0d1df9cb7b5f66649").unwrap());
assert!(key.is_some());
}
#[test]
fn can_import_directory() {
let temp = ::devtools::RandomTempPath::create_dir();
let mut secret_store = SecretStore::new_in(temp.as_path());
import_geth_keys(&mut secret_store, Path::new(test_path())).unwrap();
let key = secret_store.account(&Address::from_str("3f49624084b67849c7b4e805c5988c21a430f9d9").unwrap());
assert!(key.is_some());
let key = secret_store.account(&Address::from_str("5ba4dcf897e97c2bdf8315b9ef26c13c085988cf").unwrap());
assert!(key.is_some());
}
#[test]
fn imports_as_scrypt_keys() {
use keys::directory::{KeyDirectory, KeyFileKdf};
let temp = ::devtools::RandomTempPath::create_dir();
{
let mut secret_store = SecretStore::new_in(temp.as_path());
import_geth_keys(&mut secret_store, Path::new(test_path())).unwrap();
}
let key_directory = KeyDirectory::new(&temp.as_path());
let key_file = key_directory.get(&H128::from_str("62a0ad73556d496a8e1c0783d30d3ace").unwrap()).unwrap();
match key_file.crypto.kdf {
KeyFileKdf::Scrypt(scrypt_params) => {
assert_eq!(262144, scrypt_params.n);
assert_eq!(8, scrypt_params.r);
assert_eq!(1, scrypt_params.p);
},
_ => { panic!("expected kdf params of crypto to be of scrypt type" ); }
}
}
#[test]
#[cfg(feature="heavy-tests")]
fn can_decrypt_with_imported() {
use keys::store::EncryptedHashMap;
let temp = ::devtools::RandomTempPath::create_dir();
let mut secret_store = SecretStore::new_in(temp.as_path());
import_geth_keys(&mut secret_store, Path::new(test_path())).unwrap();
let val = secret_store.get::<Bytes>(&H128::from_str("62a0ad73556d496a8e1c0783d30d3ace").unwrap(), "123");
assert!(val.is_ok());
assert_eq!(32, val.unwrap().len());
}
#[test]
fn can_import_by_filename() {
let temp = ::devtools::RandomTempPath::create_dir();
let mut secret_store = SecretStore::new_in(temp.as_path());
let amount = import_keys_path(&mut secret_store, &pat_path_param("/p1.json")).unwrap();
assert_eq!(1, amount);
}
#[test]
fn can_import_by_dir() {
let temp = ::devtools::RandomTempPath::create_dir();
let mut secret_store = SecretStore::new_in(temp.as_path());
let amount = import_keys_path(&mut secret_store, pat_path()).unwrap();
assert_eq!(2, amount);
}
#[test]
fn can_import_mulitple() {
let temp = ::devtools::RandomTempPath::create_dir();
let mut secret_store = SecretStore::new_in(temp.as_path());
let amount = import_keys_paths(&mut secret_store, &[pat_path_param("/p1.json"), pat_path_param("/p2.json")]).unwrap();
assert_eq!(2, amount);
}
}

View File

@@ -1,26 +0,0 @@
// Copyright 2015, 2016 Ethcore (UK) Ltd.
// This file is part of Parity.
// Parity is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
// Parity is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
// You should have received a copy of the GNU General Public License
// along with Parity. If not, see <http://www.gnu.org/licenses/>.
//! Key management module
pub mod directory;
pub mod store;
mod geth_import;
mod test_account_provider;
pub use self::store::AccountProvider;
pub use self::test_account_provider::{TestAccount, TestAccountProvider};
pub use self::geth_import::import_keys_paths;

View File

@@ -1,750 +0,0 @@
// Copyright 2015, 2016 Ethcore (UK) Ltd.
// This file is part of Parity.
// Parity is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
// Parity is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
// You should have received a copy of the GNU General Public License
// along with Parity. If not, see <http://www.gnu.org/licenses/>.
//! Secret Store
use keys::directory::*;
use common::*;
use rcrypto::pbkdf2::*;
use rcrypto::scrypt::*;
use rcrypto::hmac::*;
use crypto;
use chrono::*;
const KEY_LENGTH: u32 = 32;
const KEY_ITERATIONS: u32 = 10240;
const KEY_LENGTH_AES: u32 = KEY_LENGTH/2;
const KEY_LENGTH_USIZE: usize = KEY_LENGTH as usize;
const KEY_LENGTH_AES_USIZE: usize = KEY_LENGTH_AES as usize;
// TODO: this file needs repotting into several separate files.
/// Encrypted hash-map, each request should contain password
pub trait EncryptedHashMap<Key: Hash + Eq> {
/// Returns existing value for the key, if any
fn get<Value: FromRawBytesVariable + BytesConvertable>(&self, key: &Key, password: &str) -> Result<Value, EncryptedHashMapError>;
/// Insert new encrypted key-value and returns previous if there was any
fn insert<Value: FromRawBytesVariable + BytesConvertable>(&mut self, key: Key, value: Value, password: &str) -> Option<Value>;
/// Removes key-value by key and returns the removed one, if any exists and password was provided
fn remove<Value: FromRawBytesVariable + BytesConvertable> (&mut self, key: &Key, password: Option<&str>) -> Option<Value>;
/// Deletes key-value by key and returns if the key-value existed
fn delete(&mut self, key: &Key) -> bool {
self.remove::<Bytes>(key, None).is_some()
}
}
/// Error retrieving value from encrypted hashmap
#[derive(Debug)]
pub enum EncryptedHashMapError {
/// Encryption failed
InvalidPassword,
/// No key in the hashmap
UnknownIdentifier,
/// Stored value is not well formed for the requested type
InvalidValueFormat(FromBytesError),
}
/// Error while signing a message
#[derive(Debug)]
pub enum SigningError {
/// Account passed does not exist
NoAccount,
/// Account passed is not unlocked
AccountNotUnlocked,
/// Invalid passphrase
InvalidPassword,
/// Invalid secret in store
InvalidSecret
}
/// Represent service for storing encrypted arbitrary data
pub struct SecretStore {
directory: KeyDirectory,
unlocks: RwLock<HashMap<Address, AccountUnlock>>,
key_iterations: u32,
}
struct AccountUnlock {
secret: H256,
/// expiration datetime (None - never)
expires: Option<DateTime<UTC>>,
/// Sccount should be relocked after first use.
relock_on_use: bool,
}
/// Basic account management trait
pub trait AccountProvider: Send + Sync {
/// Lists all accounts
fn accounts(&self) -> Result<Vec<Address>, ::std::io::Error>;
/// Unlocks account with the password provided
fn unlock_account(&self, account: &Address, pass: &str) -> Result<(), EncryptedHashMapError>;
/// Unlocks account with the password provided; relocks it on the next call to `account_secret` or `sign`.
fn unlock_account_temp(&self, account: &Address, pass: &str) -> Result<(), EncryptedHashMapError>;
/// Creates account
fn new_account(&self, pass: &str) -> Result<Address, ::std::io::Error>;
/// Returns secret for unlocked `account`.
fn account_secret(&self, account: &Address) -> Result<crypto::Secret, SigningError>;
/// Returns secret for locked account given passphrase.
fn locked_account_secret(&self, account: &Address, pass: &str) -> Result<crypto::Secret, SigningError>;
/// Returns signature when unlocked `account` signs `message`.
fn sign(&self, account: &Address, message: &H256) -> Result<crypto::Signature, SigningError> {
self.account_secret(account).and_then(|s| crypto::ec::sign(&s, message).map_err(|_| SigningError::InvalidSecret))
}
}
/// Thread-safe accounts management
pub struct AccountService {
secret_store: RwLock<SecretStore>,
}
impl AccountProvider for AccountService {
/// Lists all accounts
fn accounts(&self) -> Result<Vec<Address>, ::std::io::Error> {
Ok(try!(self.secret_store.read().unwrap().accounts()).iter().map(|&(addr, _)| addr).collect::<Vec<Address>>())
}
/// Unlocks account with the password provided
fn unlock_account(&self, account: &Address, pass: &str) -> Result<(), EncryptedHashMapError> {
self.secret_store.read().unwrap().unlock_account(account, pass)
}
fn unlock_account_temp(&self, account: &Address, pass: &str) -> Result<(), EncryptedHashMapError> {
self.secret_store.read().unwrap().unlock_account_temp(account, pass)
}
/// Creates account
fn new_account(&self, pass: &str) -> Result<Address, ::std::io::Error> {
self.secret_store.write().unwrap().new_account(pass)
}
/// Returns secret for unlocked account
fn account_secret(&self, account: &Address) -> Result<crypto::Secret, SigningError> {
self.secret_store.read().unwrap().account_secret(account)
}
/// Returns secret for locked account given passphrase.
fn locked_account_secret(&self, account: &Address, pass: &str) -> Result<crypto::Secret, SigningError> {
self.secret_store.read().unwrap().locked_account_secret(account, pass)
}
/// Signs a message using key of given unlocked account address.
fn sign(&self, account: &Address, message: &H256) -> Result<crypto::Signature, SigningError> {
self.secret_store.read().unwrap().sign(account, message)
}
}
/// Which set of keys to import.
#[derive(PartialEq)]
pub enum ImportKeySet {
/// Empty set.
None,
/// Import legacy client's general keys.
Legacy,
/// Import legacy client's testnet keys.
LegacyTestnet,
}
impl AccountService {
/// New account service with the keys store in specific location and configured security parameters.
pub fn with_security(path: &Path, key_iterations: u32, import_keys: ImportKeySet) -> Self {
let secret_store = RwLock::new(SecretStore::with_security(path, key_iterations));
match import_keys {
ImportKeySet::None => {}
_ => { secret_store.write().unwrap().try_import_existing(import_keys == ImportKeySet::LegacyTestnet); }
}
AccountService {
secret_store: secret_store,
}
}
#[cfg(test)]
fn new_test(temp: &::devtools::RandomTempPath) -> Self {
let secret_store = RwLock::new(SecretStore::new_test(temp));
AccountService {
secret_store: secret_store
}
}
/// Ticks the account service
pub fn tick(&self) {
self.secret_store.write().unwrap().collect_garbage();
}
/// Unlocks account for use (no expiration of unlock)
pub fn unlock_account_no_expire(&self, account: &Address, pass: &str) -> Result<(), EncryptedHashMapError> {
self.secret_store.write().unwrap().unlock_account_with_expiration(account, pass, None, false)
}
}
impl SecretStore {
/// new instance of Secret Store in specific directory
pub fn new_in(path: &Path) -> Self {
SecretStore::with_security(path, KEY_ITERATIONS)
}
/// new instance of Secret Store in specific directory and configured security parameters
pub fn with_security(path: &Path, key_iterations: u32) -> Self {
::std::fs::create_dir_all(&path).expect("Cannot access requested key directory - critical");
SecretStore {
directory: KeyDirectory::new(path),
unlocks: RwLock::new(HashMap::new()),
key_iterations: key_iterations,
}
}
/// trys to import keys in the known locations
pub fn try_import_existing(&mut self, is_testnet: bool) {
use keys::geth_import;
let import_path = geth_import::keystore_dir(is_testnet);
if let Err(e) = geth_import::import_geth_keys(self, &import_path) {
trace!(target: "sstore", "Geth key not imported: {:?}", e);
}
}
/// Lists all accounts and corresponding key ids
pub fn accounts(&self) -> Result<Vec<(Address, H128)>, ::std::io::Error> {
let accounts = try!(self.directory.list()).iter().map(|key_id| self.directory.get(key_id))
.filter(|key| key.is_some())
.map(|key| { let some_key = key.unwrap(); (some_key.account, some_key.id) })
.filter(|&(ref account, _)| account.is_some())
.map(|(account, id)| (account.unwrap(), id))
.collect::<Vec<(Address, H128)>>();
Ok(accounts)
}
/// Resolves key_id by account address
pub fn account(&self, account: &Address) -> Option<H128> {
let mut accounts = match self.accounts() {
Ok(accounts) => accounts,
Err(e) => { warn!(target: "sstore", "Failed to load accounts: {}", e); return None; }
};
accounts.retain(|&(ref store_account, _)| account == store_account);
accounts.first().and_then(|&(_, ref key_id)| Some(key_id.clone()))
}
/// Imports pregenerated key, returns error if not saved correctly
pub fn import_key(&mut self, key_file: KeyFileContent) -> Result<(), ::std::io::Error> {
try!(self.directory.save(key_file));
Ok(())
}
#[cfg(test)]
fn new_test(path: &::devtools::RandomTempPath) -> SecretStore {
SecretStore {
directory: KeyDirectory::new(path.as_path()),
unlocks: RwLock::new(HashMap::new()),
key_iterations: KEY_ITERATIONS,
}
}
/// Unlocks account for use
pub fn unlock_account(&self, account: &Address, pass: &str) -> Result<(), EncryptedHashMapError> {
self.unlock_account_with_expiration(account, pass, Some(UTC::now() + Duration::minutes(20)), false)
}
/// Unlocks account for use (no expiration of unlock)
pub fn unlock_account_no_expire(&self, account: &Address, pass: &str) -> Result<(), EncryptedHashMapError> {
self.unlock_account_with_expiration(account, pass, None, false)
}
/// Unlocks account for use (no expiration of unlock)
pub fn unlock_account_temp(&self, account: &Address, pass: &str) -> Result<(), EncryptedHashMapError> {
self.unlock_account_with_expiration(account, pass, None, true)
}
fn unlock_account_with_expiration(&self, account: &Address, pass: &str, expiration: Option<DateTime<UTC>>, relock_on_use: bool) -> Result<(), EncryptedHashMapError> {
let secret_id = try!(self.account(&account).ok_or(EncryptedHashMapError::UnknownIdentifier));
let secret = try!(self.get(&secret_id, pass));
{
let mut write_lock = self.unlocks.write().unwrap();
let mut unlock = write_lock.entry(*account)
.or_insert_with(|| AccountUnlock { secret: secret, expires: Some(UTC::now()), relock_on_use: relock_on_use });
unlock.secret = secret;
unlock.expires = expiration;
}
Ok(())
}
/// Creates new account
pub fn new_account(&mut self, pass: &str) -> Result<Address, ::std::io::Error> {
let key_pair = crypto::KeyPair::create().expect("Error creating key-pair. Something wrong with crypto libraries?");
let address = Address::from(key_pair.public().sha3());
let key_id = H128::random();
self.insert(key_id.clone(), key_pair.secret().clone(), pass);
let mut key_file = self.directory.get(&key_id).expect("the key was just inserted");
key_file.account = Some(address);
try!(self.directory.save(key_file));
Ok(address)
}
/// Signs message with unlocked account.
pub fn sign(&self, account: &Address, message: &H256) -> Result<crypto::Signature, SigningError> {
let (relock, ret) = {
let read_lock = self.unlocks.read().unwrap();
if let Some(unlock) = read_lock.get(account) {
(unlock.relock_on_use, match crypto::KeyPair::from_secret(unlock.secret) {
Ok(pair) => match pair.sign(message) {
Ok(signature) => Ok(signature),
Err(_) => Err(SigningError::InvalidSecret)
},
Err(_) => Err(SigningError::InvalidSecret)
})
} else {
(false, Err(SigningError::AccountNotUnlocked))
}
};
if relock {
self.unlocks.write().unwrap().remove(account);
}
ret
}
/// Returns secret for unlocked account.
pub fn account_secret(&self, account: &Address) -> Result<crypto::Secret, SigningError> {
let (relock, ret) = {
let read_lock = self.unlocks.read().unwrap();
if let Some(unlock) = read_lock.get(account) {
(unlock.relock_on_use, Ok(unlock.secret as crypto::Secret))
} else {
(false, Err(SigningError::AccountNotUnlocked))
}
};
if relock {
self.unlocks.write().unwrap().remove(account);
}
ret
}
/// Returns secret for locked account.
pub fn locked_account_secret(&self, account: &Address, pass: &str) -> Result<crypto::Secret, SigningError> {
let secret_id = try!(self.account(&account).ok_or(SigningError::NoAccount));
self.get(&secret_id, pass).or_else(|e| Err(match e {
EncryptedHashMapError::InvalidPassword => SigningError::InvalidPassword,
EncryptedHashMapError::UnknownIdentifier => SigningError::NoAccount,
EncryptedHashMapError::InvalidValueFormat(_) => SigningError::InvalidSecret,
}))
}
/// Makes account unlocks expire and removes unused key files from memory
pub fn collect_garbage(&mut self) {
let mut garbage_lock = self.unlocks.write().unwrap();
self.directory.collect_garbage();
let utc = UTC::now();
let expired_addresses = garbage_lock.iter()
.filter(|&(_, unlock)| match unlock.expires { Some(ref expire_val) => expire_val < &utc, _ => false })
.map(|(address, _)| address.clone()).collect::<Vec<Address>>();
for expired in expired_addresses { garbage_lock.remove(&expired); }
garbage_lock.shrink_to_fit();
}
}
fn derive_key_iterations(password: &str, salt: &H256, c: u32) -> (Bytes, Bytes) {
let mut h_mac = Hmac::new(::rcrypto::sha2::Sha256::new(), password.as_bytes());
let mut derived_key = vec![0u8; KEY_LENGTH_USIZE];
pbkdf2(&mut h_mac, &salt.as_slice(), c, &mut derived_key);
let derived_right_bits = &derived_key[0..KEY_LENGTH_AES_USIZE];
let derived_left_bits = &derived_key[KEY_LENGTH_AES_USIZE..KEY_LENGTH_USIZE];
(derived_right_bits.to_vec(), derived_left_bits.to_vec())
}
fn derive_key(password: &str, salt: &H256, iterations: u32) -> (Bytes, Bytes) {
derive_key_iterations(password, salt, iterations)
}
fn derive_key_scrypt(password: &str, salt: &H256, n: u32, p: u32, r: u32) -> (Bytes, Bytes) {
let mut derived_key = vec![0u8; KEY_LENGTH_USIZE];
let scrypt_params = ScryptParams::new(n.trailing_zeros() as u8, r, p);
scrypt(password.as_bytes(), &salt.as_slice(), &scrypt_params, &mut derived_key);
let derived_right_bits = &derived_key[0..KEY_LENGTH_AES_USIZE];
let derived_left_bits = &derived_key[KEY_LENGTH_AES_USIZE..KEY_LENGTH_USIZE];
(derived_right_bits.to_vec(), derived_left_bits.to_vec())
}
fn derive_mac(derived_left_bits: &[u8], cipher_text: &[u8]) -> Bytes {
let mut mac = vec![0u8; KEY_LENGTH_AES_USIZE + cipher_text.len()];
mac[0..KEY_LENGTH_AES_USIZE].clone_from_slice(derived_left_bits);
mac[KEY_LENGTH_AES_USIZE..cipher_text.len()+KEY_LENGTH_AES_USIZE].clone_from_slice(cipher_text);
mac
}
impl EncryptedHashMap<H128> for SecretStore {
fn get<Value: FromRawBytesVariable + BytesConvertable>(&self, key: &H128, password: &str) -> Result<Value, EncryptedHashMapError> {
match self.directory.get(key) {
Some(key_file) => {
let (derived_left_bits, derived_right_bits) = match key_file.crypto.kdf {
KeyFileKdf::Pbkdf2(ref params) => derive_key_iterations(password, &params.salt, params.c),
KeyFileKdf::Scrypt(ref params) => derive_key_scrypt(password, &params.salt, params.n, params.p, params.r)
};
if derive_mac(&derived_right_bits, &key_file.crypto.cipher_text)
.sha3() != key_file.crypto.mac { return Err(EncryptedHashMapError::InvalidPassword); }
let mut val = vec![0u8; key_file.crypto.cipher_text.len()];
match key_file.crypto.cipher_type {
CryptoCipherType::Aes128Ctr(ref iv) => {
crypto::aes::decrypt(&derived_left_bits, &iv.as_slice(), &key_file.crypto.cipher_text, &mut val);
}
};
match Value::from_bytes_variable(&val) {
Ok(value) => Ok(value),
Err(bytes_error) => Err(EncryptedHashMapError::InvalidValueFormat(bytes_error))
}
},
None => Err(EncryptedHashMapError::UnknownIdentifier)
}
}
fn insert<Value: FromRawBytesVariable + BytesConvertable>(&mut self, key: H128, value: Value, password: &str) -> Option<Value> {
let previous = if let Ok(previous_value) = self.get(&key, password) { Some(previous_value) } else { None };
// crypto random initiators
let salt = H256::random();
let iv = H128::random();
// two parts of derived key
// DK = [ DK[0..15] DK[16..31] ] = [derived_left_bits, derived_right_bits]
let (derived_left_bits, derived_right_bits) = derive_key(password, &salt, self.key_iterations);
let mut cipher_text = vec![0u8; value.as_slice().len()];
// aes-128-ctr with initial vector of iv
crypto::aes::encrypt(&derived_left_bits, &iv.clone(), &value.as_slice(), &mut cipher_text);
// KECCAK(DK[16..31] ++ <ciphertext>), where DK[16..31] - derived_right_bits
let mac = derive_mac(&derived_right_bits, &cipher_text.clone()).sha3();
let mut key_file = KeyFileContent::new(
KeyFileCrypto::new_pbkdf2(
cipher_text,
iv,
salt,
mac,
self.key_iterations,
KEY_LENGTH));
key_file.id = key;
if let Err(io_error) = self.directory.save(key_file) {
warn!("Error saving key file: {:?}", io_error);
}
previous
}
fn remove<Value: FromRawBytesVariable + BytesConvertable>(&mut self, key: &H128, password: Option<&str>) -> Option<Value> {
let previous = if let Some(pass) = password {
if let Ok(previous_value) = self.get(&key, pass) { Some(previous_value) } else { None }
}
else { None };
if let Err(io_error) = self.directory.delete(key) {
warn!("Error saving key file: {:?}", io_error);
}
previous
}
}
#[cfg(all(test, feature="heavy-tests"))]
mod vector_tests {
use super::{derive_mac,derive_key_iterations};
use common::*;
#[test]
fn mac_vector() {
let password = "testpassword";
let salt = H256::from_str("ae3cd4e7013836a3df6bd7241b12db061dbe2c6785853cce422d148a624ce0bd").unwrap();
let cipher_text = FromHex::from_hex("5318b4d5bcd28de64ee5559e671353e16f075ecae9f99c7a79a38af5f869aa46").unwrap();
let iterations = 262144u32;
let (derived_left_bits, derived_right_bits) = derive_key_iterations(password, &salt, iterations);
assert_eq!("f06d69cdc7da0faffb1008270bca38f5", derived_left_bits.to_hex());
assert_eq!("e31891a3a773950e6d0fea48a7188551", derived_right_bits.to_hex());
let mac_body = derive_mac(&derived_right_bits, &cipher_text);
assert_eq!("e31891a3a773950e6d0fea48a71885515318b4d5bcd28de64ee5559e671353e16f075ecae9f99c7a79a38af5f869aa46", mac_body.to_hex());
let mac = mac_body.sha3();
assert_eq!("517ead924a9d0dc3124507e3393d175ce3ff7c1e96529c6c555ce9e51205e9b2", format!("{:?}", mac));
}
}
#[cfg(test)]
mod tests {
use super::*;
use devtools::*;
use common::*;
use crypto::KeyPair;
use chrono::*;
#[test]
fn can_insert() {
let temp = RandomTempPath::create_dir();
let mut sstore = SecretStore::new_test(&temp);
let id = H128::random();
sstore.insert(id.clone(), "Cat".to_owned(), "pass");
assert!(sstore.get::<String>(&id, "pass").is_ok());
}
#[test]
fn can_get_fail() {
let temp = RandomTempPath::create_dir();
{
use keys::directory::{KeyFileContent, KeyFileCrypto};
let mut write_sstore = SecretStore::new_test(&temp);
write_sstore.directory.save(
KeyFileContent::new(
KeyFileCrypto::new_pbkdf2(
FromHex::from_hex("5318b4d5bcd28de64ee5559e671353e16f075ecae9f99c7a79a38af5f869aa46").unwrap(),
H128::from_str("6087dab2f9fdbbfaddc31a909735c1e6").unwrap(),
H256::from_str("ae3cd4e7013836a3df6bd7241b12db061dbe2c6785853cce422d148a624ce0bd").unwrap(),
H256::from_str("517ead924a9d0dc3124507e3393d175ce3ff7c1e96529c6c555ce9e51205e9b2").unwrap(),
262144,
32)))
.unwrap();
}
let sstore = SecretStore::new_test(&temp);
if let Ok(_) = sstore.get::<Bytes>(&H128::from_str("3198bc9c66725ab3d9954942343ae5b6").unwrap(), "testpassword") {
panic!("should be error loading key, we requested the wrong key");
}
}
fn pregenerate_keys(temp: &RandomTempPath, count: usize) -> Vec<H128> {
use keys::directory::{KeyFileContent, KeyFileCrypto};
let mut write_sstore = SecretStore::new_test(&temp);
let mut result = Vec::new();
for _ in 0..count {
result.push(write_sstore.directory.save(
KeyFileContent::new(
KeyFileCrypto::new_pbkdf2(
FromHex::from_hex("5318b4d5bcd28de64ee5559e671353e16f075ecae9f99c7a79a38af5f869aa46").unwrap(),
H128::from_str("6087dab2f9fdbbfaddc31a909735c1e6").unwrap(),
H256::from_str("ae3cd4e7013836a3df6bd7241b12db061dbe2c6785853cce422d148a624ce0bd").unwrap(),
H256::from_str("517ead924a9d0dc3124507e3393d175ce3ff7c1e96529c6c555ce9e51205e9b2").unwrap(),
262144,
32)))
.unwrap());
}
result
}
fn pregenerate_accounts(temp: &RandomTempPath, count: usize) -> Vec<H128> {
use keys::directory::{KeyFileContent, KeyFileCrypto};
let mut write_sstore = SecretStore::new_test(&temp);
let mut result = Vec::new();
for i in 0..count {
let mut key_file =
KeyFileContent::new(
KeyFileCrypto::new_pbkdf2(
FromHex::from_hex("5318b4d5bcd28de64ee5559e671353e16f075ecae9f99c7a79a38af5f869aa46").unwrap(),
H128::from_str("6087dab2f9fdbbfaddc31a909735c1e6").unwrap(),
H256::from_str("ae3cd4e7013836a3df6bd7241b12db061dbe2c6785853cce422d148a624ce0bd").unwrap(),
H256::from_str("517ead924a9d0dc3124507e3393d175ce3ff7c1e96529c6c555ce9e51205e9b2").unwrap(),
262144,
32));
key_file.account = Some((i as u64).into());
result.push(key_file.id.clone());
write_sstore.import_key(key_file).unwrap();
}
result
}
#[test]
#[cfg(feature="heavy-tests")]
fn can_get() {
let temp = RandomTempPath::create_dir();
let key_id = {
use keys::directory::{KeyFileContent, KeyFileCrypto};
let mut write_sstore = SecretStore::new_test(&temp);
write_sstore.directory.save(
KeyFileContent::new(
KeyFileCrypto::new_pbkdf2(
FromHex::from_hex("5318b4d5bcd28de64ee5559e671353e16f075ecae9f99c7a79a38af5f869aa46").unwrap(),
H128::from_str("6087dab2f9fdbbfaddc31a909735c1e6").unwrap(),
H256::from_str("ae3cd4e7013836a3df6bd7241b12db061dbe2c6785853cce422d148a624ce0bd").unwrap(),
H256::from_str("517ead924a9d0dc3124507e3393d175ce3ff7c1e96529c6c555ce9e51205e9b2").unwrap(),
262144,
32)))
.unwrap()
};
let sstore = SecretStore::new_test(&temp);
if let Err(e) = sstore.get::<Bytes>(&key_id, "testpassword") {
panic!("got no key: {:?}", e);
}
}
#[test]
fn can_delete() {
let temp = RandomTempPath::create_dir();
let keys = pregenerate_keys(&temp, 5);
let mut sstore = SecretStore::new_test(&temp);
sstore.delete(&keys[2]);
assert_eq!(4, sstore.directory.list().unwrap().len())
}
#[test]
fn can_create_account() {
let temp = RandomTempPath::create_dir();
let mut sstore = SecretStore::new_test(&temp);
sstore.new_account("123").unwrap();
assert_eq!(1, sstore.accounts().unwrap().len());
}
#[test]
fn can_unlock_account() {
let temp = RandomTempPath::create_dir();
let mut sstore = SecretStore::new_test(&temp);
let address = sstore.new_account("123").unwrap();
let secret = sstore.unlock_account(&address, "123");
assert!(secret.is_ok());
}
#[test]
fn can_sign_data() {
let temp = RandomTempPath::create_dir();
let address = {
let mut sstore = SecretStore::new_test(&temp);
sstore.new_account("334").unwrap()
};
let signature = {
let sstore = SecretStore::new_test(&temp);
sstore.unlock_account(&address, "334").unwrap();
sstore.sign(&address, &H256::random()).unwrap()
};
assert!(signature != 0.into());
}
#[test]
fn can_relock_temp_account() {
let temp = RandomTempPath::create_dir();
let address = {
let mut sstore = SecretStore::new_test(&temp);
sstore.new_account("334").unwrap()
};
let signature = {
let sstore = SecretStore::new_test(&temp);
sstore.unlock_account_temp(&address, "334").unwrap();
sstore.sign(&address, &H256::random()).unwrap();
sstore.sign(&address, &H256::random())
};
assert!(signature.is_err());
let secret = {
let sstore = SecretStore::new_test(&temp);
sstore.unlock_account_temp(&address, "334").unwrap();
sstore.account_secret(&address).unwrap();
sstore.account_secret(&address)
};
assert!(secret.is_err());
}
#[test]
fn can_import_account() {
use keys::directory::{KeyFileContent, KeyFileCrypto};
let temp = RandomTempPath::create_dir();
let mut key_file =
KeyFileContent::new(
KeyFileCrypto::new_pbkdf2(
FromHex::from_hex("5318b4d5bcd28de64ee5559e671353e16f075ecae9f99c7a79a38af5f869aa46").unwrap(),
H128::from_str("6087dab2f9fdbbfaddc31a909735c1e6").unwrap(),
H256::from_str("ae3cd4e7013836a3df6bd7241b12db061dbe2c6785853cce422d148a624ce0bd").unwrap(),
H256::from_str("517ead924a9d0dc3124507e3393d175ce3ff7c1e96529c6c555ce9e51205e9b2").unwrap(),
262144,
32));
key_file.account = Some(Address::from_str("3f49624084b67849c7b4e805c5988c21a430f9d9").unwrap());
let mut sstore = SecretStore::new_test(&temp);
sstore.import_key(key_file).unwrap();
assert_eq!(1, sstore.accounts().unwrap().len());
assert!(sstore.account(&Address::from_str("3f49624084b67849c7b4e805c5988c21a430f9d9").unwrap()).is_some());
}
#[test]
fn can_list_accounts() {
let temp = RandomTempPath::create_dir();
pregenerate_accounts(&temp, 30);
let sstore = SecretStore::new_test(&temp);
let accounts = sstore.accounts().unwrap();
assert_eq!(30, accounts.len());
}
#[test]
fn validate_generated_addresses() {
let temp = RandomTempPath::create_dir();
let mut sstore = SecretStore::new_test(&temp);
let addr = sstore.new_account("test").unwrap();
sstore.unlock_account(&addr, "test").unwrap();
let secret = sstore.account_secret(&addr).unwrap();
let kp = KeyPair::from_secret(secret).unwrap();
assert_eq!(Address::from(kp.public().sha3()), addr);
}
#[test]
fn secret_for_locked_account() {
// given
let temp = RandomTempPath::create_dir();
let mut sstore = SecretStore::new_test(&temp);
let addr = sstore.new_account("test-pass").unwrap();
// when
// Invalid pass
let secret1 = sstore.locked_account_secret(&addr, "test-pass123");
// Valid pass
let secret2 = sstore.locked_account_secret(&addr, "test-pass");
// Account not unlocked
let secret3 = sstore.account_secret(&addr);
assert!(secret1.is_err(), "Invalid password should not return secret.");
assert!(secret2.is_ok(), "Should return secret provided valid passphrase.");
assert!(secret3.is_err(), "Account should still be locked.");
}
#[test]
fn can_create_service() {
let temp = RandomTempPath::create_dir();
let svc = AccountService::new_test(&temp);
assert!(svc.accounts().unwrap().is_empty());
}
#[test]
fn accounts_expire() {
use std::collections::hash_map::*;
let temp = RandomTempPath::create_dir();
let svc = AccountService::new_test(&temp);
let address = svc.new_account("pass").unwrap();
svc.unlock_account(&address, "pass").unwrap();
assert!(svc.account_secret(&address).is_ok());
{
let ss_rw = svc.secret_store.write().unwrap();
let mut ua_rw = ss_rw.unlocks.write().unwrap();
let entry = ua_rw.entry(address);
if let Entry::Occupied(mut occupied) = entry { occupied.get_mut().expires = Some(UTC::now() - Duration::minutes(1)) }
}
svc.tick();
assert!(svc.account_secret(&address).is_err());
}
}

View File

@@ -1,118 +0,0 @@
// Copyright 2015, 2016 Ethcore (UK) Ltd.
// This file is part of Parity.
// Parity is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
// Parity is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
// You should have received a copy of the GNU General Public License
// along with Parity. If not, see <http://www.gnu.org/licenses/>.
//! Test implementation of account provider.
use std::sync::RwLock;
use std::collections::HashMap;
use std::io;
use hash::Address;
use crypto::{Secret, KeyPair};
use super::store::{AccountProvider, SigningError, EncryptedHashMapError};
/// Account mock.
#[derive(Clone)]
pub struct TestAccount {
/// True if account is unlocked.
pub unlocked: bool,
/// Account's password.
pub password: String,
/// Account's secret.
pub secret: Secret,
}
impl TestAccount {
/// Creates new test account.
pub fn new(password: &str) -> Self {
let pair = KeyPair::create().unwrap();
TestAccount {
unlocked: false,
password: password.to_owned(),
secret: pair.secret().clone()
}
}
/// Returns account address.
pub fn address(&self) -> Address {
KeyPair::from_secret(self.secret.clone()).unwrap().address()
}
}
/// Test account provider.
pub struct TestAccountProvider {
/// Test provider accounts.
pub accounts: RwLock<HashMap<Address, TestAccount>>,
}
impl TestAccountProvider {
/// Basic constructor.
pub fn new(accounts: HashMap<Address, TestAccount>) -> Self {
TestAccountProvider {
accounts: RwLock::new(accounts),
}
}
}
impl AccountProvider for TestAccountProvider {
fn accounts(&self) -> Result<Vec<Address>, io::Error> {
Ok(self.accounts.read().unwrap().keys().cloned().collect())
}
fn unlock_account(&self, account: &Address, pass: &str) -> Result<(), EncryptedHashMapError> {
match self.accounts.write().unwrap().get_mut(account) {
Some(ref mut acc) if acc.password == pass => {
acc.unlocked = true;
Ok(())
},
Some(_) => Err(EncryptedHashMapError::InvalidPassword),
None => Err(EncryptedHashMapError::UnknownIdentifier),
}
}
fn unlock_account_temp(&self, account: &Address, pass: &str) -> Result<(), EncryptedHashMapError> {
// TODO; actually make it relock on use
self.unlock_account(account, pass)
}
fn new_account(&self, pass: &str) -> Result<Address, io::Error> {
let account = TestAccount::new(pass);
let address = KeyPair::from_secret(account.secret.clone()).unwrap().address();
self.accounts.write().unwrap().insert(address.clone(), account);
Ok(address)
}
fn account_secret(&self, address: &Address) -> Result<Secret, SigningError> {
// todo: consider checking if account is unlock. some test may need alteration then.
self.accounts
.read()
.unwrap()
.get(address)
.ok_or(SigningError::NoAccount)
.map(|acc| acc.secret.clone())
}
fn locked_account_secret(&self, address: &Address, pass: &str) -> Result<Secret, SigningError> {
let accounts = self.accounts.read().unwrap();
match accounts.get(address) {
Some(ref acc) if acc.password == pass => {
Ok(acc.secret.clone())
},
Some(ref _acc) => Err(SigningError::InvalidPassword),
_ => Err(SigningError::NoAccount),
}
}
}

View File

@@ -20,6 +20,9 @@ use std::default::Default;
use rocksdb::{DB, Writable, WriteBatch, IteratorMode, DBVector, DBIterator,
IndexType, Options, DBCompactionStyle, BlockBasedOptions, Direction};
const DB_BACKGROUND_FLUSHES: i32 = 2;
const DB_BACKGROUND_COMPACTIONS: i32 = 2;
/// Write transaction. Batches a sequence of put/delete operations for efficiency.
pub struct DBTransaction {
batch: WriteBatch,
@@ -48,12 +51,81 @@ impl DBTransaction {
}
}
/// Compaction profile for the database settings
pub struct CompactionProfile {
/// L0-L1 target file size
pub initial_file_size: u64,
/// L2-LN target file size multiplier
pub file_size_multiplier: i32,
/// rate limiter for background flushes and compactions, bytes/sec, if any
pub write_rate_limit: Option<u64>,
}
impl CompactionProfile {
/// Default profile suitable for most storage
pub fn default() -> CompactionProfile {
CompactionProfile {
initial_file_size: 32 * 1024 * 1024,
file_size_multiplier: 2,
write_rate_limit: None,
}
}
/// Slow hdd compaction profile
pub fn hdd() -> CompactionProfile {
CompactionProfile {
initial_file_size: 192 * 1024 * 1024,
file_size_multiplier: 1,
write_rate_limit: Some(8 * 1024 * 1024),
}
}
}
/// Database configuration
pub struct DatabaseConfig {
/// Optional prefix size in bytes. Allows lookup by partial key.
pub prefix_size: Option<usize>,
/// Max number of open files.
pub max_open_files: i32,
/// Cache-size
pub cache_size: Option<usize>,
/// Compaction profile
pub compaction: CompactionProfile,
}
impl DatabaseConfig {
/// Database with default settings and specified cache size
pub fn with_cache(cache_size: usize) -> DatabaseConfig {
DatabaseConfig {
cache_size: Some(cache_size),
prefix_size: None,
max_open_files: -1,
compaction: CompactionProfile::default(),
}
}
/// Modify the compaction profile
pub fn compaction(mut self, profile: CompactionProfile) -> Self {
self.compaction = profile;
self
}
/// Modify the prefix of the db
pub fn prefix(mut self, prefix_size: usize) -> Self {
self.prefix_size = Some(prefix_size);
self
}
}
impl Default for DatabaseConfig {
fn default() -> DatabaseConfig {
DatabaseConfig {
cache_size: None,
prefix_size: None,
max_open_files: -1,
compaction: CompactionProfile::default(),
}
}
}
/// Database iterator
@@ -77,16 +149,32 @@ pub struct Database {
impl Database {
/// Open database with default settings.
pub fn open_default(path: &str) -> Result<Database, String> {
Database::open(&DatabaseConfig { prefix_size: None, max_open_files: 256 }, path)
Database::open(&DatabaseConfig::default(), path)
}
/// Open database file. Creates if it does not exist.
pub fn open(config: &DatabaseConfig, path: &str) -> Result<Database, String> {
let mut opts = Options::new();
if let Some(rate_limit) = config.compaction.write_rate_limit {
try!(opts.set_parsed_options(&format!("rate_limiter_bytes_per_sec={}", rate_limit)));
}
opts.set_max_open_files(config.max_open_files);
opts.create_if_missing(true);
opts.set_use_fsync(false);
// compaction settings
opts.set_compaction_style(DBCompactionStyle::DBUniversalCompaction);
opts.set_target_file_size_base(config.compaction.initial_file_size);
opts.set_target_file_size_multiplier(config.compaction.file_size_multiplier);
opts.set_max_background_flushes(DB_BACKGROUND_FLUSHES);
opts.set_max_background_compactions(DB_BACKGROUND_COMPACTIONS);
if let Some(cache_size) = config.cache_size {
// half goes to read cache
opts.set_block_cache_size_mb(cache_size as u64 / 2);
// quarter goes to each of the two write buffers
opts.set_write_buffer_size(cache_size * 1024 * 256);
}
/*
opts.set_bytes_per_sync(8388608);
opts.set_disable_data_sync(false);
@@ -111,7 +199,16 @@ impl Database {
opts.set_block_based_table_factory(&block_opts);
opts.set_prefix_extractor_fixed_size(size);
}
let db = try!(DB::open(&opts, path));
let db = match DB::open(&opts, path) {
Ok(db) => db,
Err(ref s) if s.starts_with("Corruption:") => {
info!("{}", s);
info!("Attempting DB repair for {}", path);
try!(DB::repair(&opts, path));
try!(DB::open(&opts, path))
},
Err(s) => { return Err(s); }
};
Ok(Database { db: db })
}
@@ -205,10 +302,10 @@ mod tests {
let path = RandomTempPath::create_dir();
let smoke = Database::open_default(path.as_path().to_str().unwrap()).unwrap();
assert!(smoke.is_empty());
test_db(&DatabaseConfig { prefix_size: None, max_open_files: 256 });
test_db(&DatabaseConfig { prefix_size: Some(1), max_open_files: 256 });
test_db(&DatabaseConfig { prefix_size: Some(8), max_open_files: 256 });
test_db(&DatabaseConfig { prefix_size: Some(32), max_open_files: 256 });
test_db(&DatabaseConfig::default());
test_db(&DatabaseConfig::default().prefix(12));
test_db(&DatabaseConfig::default().prefix(22));
test_db(&DatabaseConfig::default().prefix(8));
}
}

View File

@@ -151,11 +151,11 @@ pub mod io;
pub mod network;
pub mod log;
pub mod panics;
pub mod keys;
pub mod table;
pub mod network_settings;
pub mod path;
pub mod snappy;
mod timer;
pub use common::*;
pub use misc::*;
@@ -177,6 +177,7 @@ pub use network::*;
pub use io::*;
pub use log::*;
pub use kvdb::*;
pub use timer::*;
#[cfg(test)]
mod tests {

View File

@@ -162,7 +162,7 @@ impl MemoryDB {
static NULL_RLP_STATIC: [u8; 1] = [0x80; 1];
impl HashDB for MemoryDB {
fn lookup(&self, key: &H256) -> Option<&[u8]> {
fn get(&self, key: &H256) -> Option<&[u8]> {
if key == &SHA3_NULL_RLP {
return Some(&NULL_RLP_STATIC);
}
@@ -176,7 +176,7 @@ impl HashDB for MemoryDB {
self.data.iter().filter_map(|(k, v)| if v.1 != 0 {Some((k.clone(), v.1))} else {None}).collect()
}
fn exists(&self, key: &H256) -> bool {
fn contains(&self, key: &H256) -> bool {
if key == &SHA3_NULL_RLP {
return true;
}
@@ -222,7 +222,7 @@ impl HashDB for MemoryDB {
self.data.insert(key, (value, 1));
}
fn kill(&mut self, key: &H256) {
fn remove(&mut self, key: &H256) {
if key == &SHA3_NULL_RLP {
return;
}

View File

@@ -17,6 +17,7 @@
use std::sync::Arc;
use std::collections::VecDeque;
use std::net::SocketAddr;
use std::sync::atomic::{AtomicBool, Ordering as AtomicOrdering};
use mio::{Handler, Token, EventSet, EventLoop, PollOpt, TryRead, TryWrite};
use mio::tcp::*;
use hash::*;
@@ -60,45 +61,57 @@ pub struct GenericConnection<Socket: GenericSocket> {
interest: EventSet,
/// Shared network statistics
stats: Arc<NetworkStats>,
/// Registered flag
registered: AtomicBool,
}
impl<Socket: GenericSocket> GenericConnection<Socket> {
pub fn expect(&mut self, size: usize) {
trace!(target:"network", "Expect to read {} bytes", size);
if self.rec_size != self.rec_buf.len() {
warn!(target:"net", "Unexpected connection read start");
warn!(target:"network", "Unexpected connection read start");
}
unsafe { self.rec_buf.set_len(0) }
self.rec_size = size;
}
/// Readable IO handler. Called when there is some data to be read.
pub fn readable(&mut self) -> io::Result<Option<Bytes>> {
if self.rec_size == 0 || self.rec_buf.len() >= self.rec_size {
warn!(target:"net", "Unexpected connection read");
warn!(target:"network", "Unexpected connection read");
}
let max = self.rec_size - self.rec_buf.len();
// resolve "multiple applicable items in scope [E0034]" error
let sock_ref = <Socket as Read>::by_ref(&mut self.socket);
match sock_ref.take(max as u64).try_read_buf(&mut self.rec_buf) {
Ok(Some(size)) if size != 0 => {
self.stats.inc_recv(size);
if self.rec_size != 0 && self.rec_buf.len() == self.rec_size {
self.rec_size = 0;
Ok(Some(::std::mem::replace(&mut self.rec_buf, Bytes::new())))
} else { Ok(None) }
},
Ok(_) => Ok(None),
Err(e) => Err(e),
}
}
loop {
let max = self.rec_size - self.rec_buf.len();
match sock_ref.take(max as u64).try_read_buf(&mut self.rec_buf) {
Ok(Some(size)) if size != 0 => {
self.stats.inc_recv(size);
trace!(target:"network", "{}: Read {} of {} bytes", self.token, self.rec_buf.len(), self.rec_size);
if self.rec_size != 0 && self.rec_buf.len() == self.rec_size {
self.rec_size = 0;
return Ok(Some(::std::mem::replace(&mut self.rec_buf, Bytes::new())))
}
else if self.rec_buf.len() > self.rec_size {
warn!(target:"network", "Read past buffer {} bytes", self.rec_buf.len() - self.rec_size);
return Ok(Some(::std::mem::replace(&mut self.rec_buf, Bytes::new())))
}
},
Ok(_) => return Ok(None),
Err(e) => {
debug!(target:"network", "Read error {} ({})", self.token, e);
return Err(e)
}
}
}
}
/// Add a packet to send queue.
pub fn send(&mut self, data: Bytes) {
pub fn send<Message>(&mut self, io: &IoContext<Message>, data: Bytes) where Message: Send + Clone {
if !data.is_empty() {
self.send_queue.push_back(Cursor::new(data));
}
if !self.interest.is_writable() {
self.interest.insert(EventSet::writable());
io.update_registration(self.token).ok();
}
}
@@ -108,7 +121,7 @@ impl<Socket: GenericSocket> GenericConnection<Socket> {
}
/// Writable IO handler. Called when the socket is ready to send.
pub fn writable(&mut self) -> io::Result<WriteStatus> {
pub fn writable<Message>(&mut self, io: &IoContext<Message>) -> Result<WriteStatus, UtilError> where Message: Send + Clone {
if self.send_queue.is_empty() {
return Ok(WriteStatus::Complete)
}
@@ -121,17 +134,17 @@ impl<Socket: GenericSocket> GenericConnection<Socket> {
}
match self.socket.try_write_buf(buf) {
Ok(Some(size)) if (buf.position() as usize) < send_size => {
self.interest.insert(EventSet::writable());
self.stats.inc_send(size);
Ok(WriteStatus::Ongoing)
},
Ok(Some(size)) if (buf.position() as usize) == send_size => {
self.stats.inc_send(size);
trace!(target:"network", "{}: Wrote {} bytes", self.token, send_size);
Ok(WriteStatus::Complete)
},
Ok(Some(_)) => { panic!("Wrote past buffer");},
Ok(None) => Ok(WriteStatus::Ongoing),
Err(e) => Err(e)
Err(e) => try!(Err(e))
}
}.and_then(|r| {
if r == WriteStatus::Complete {
@@ -139,9 +152,7 @@ impl<Socket: GenericSocket> GenericConnection<Socket> {
}
if self.send_queue.is_empty() {
self.interest.remove(EventSet::writable());
}
else {
self.interest.insert(EventSet::writable());
try!(io.update_registration(self.token));
}
Ok(r)
})
@@ -162,6 +173,7 @@ impl Connection {
rec_size: 0,
interest: EventSet::hup() | EventSet::readable(),
stats: stats,
registered: AtomicBool::new(false),
}
}
@@ -188,27 +200,36 @@ impl Connection {
rec_buf: Vec::new(),
rec_size: 0,
send_queue: self.send_queue.clone(),
interest: EventSet::hup() | EventSet::readable(),
interest: EventSet::hup(),
stats: self.stats.clone(),
registered: AtomicBool::new(false),
})
}
/// Register this connection with the IO event loop.
pub fn register_socket<Host: Handler>(&self, reg: Token, event_loop: &mut EventLoop<Host>) -> io::Result<()> {
if self.registered.load(AtomicOrdering::SeqCst) {
return Ok(());
}
trace!(target: "network", "connection register; token={:?}", reg);
if let Err(e) = event_loop.register(&self.socket, reg, self.interest, PollOpt::edge() /* | PollOpt::oneshot() */) { // TODO: oneshot is broken on windows
trace!(target: "network", "Failed to register {:?}, {:?}", reg, e);
}
self.registered.store(true, AtomicOrdering::SeqCst);
Ok(())
}
/// Update connection registration. Should be called at the end of the IO handler.
pub fn update_socket<Host: Handler>(&self, reg: Token, event_loop: &mut EventLoop<Host>) -> io::Result<()> {
trace!(target: "network", "connection reregister; token={:?}", reg);
event_loop.reregister( &self.socket, reg, self.interest, PollOpt::edge() /* | PollOpt::oneshot() */ ).or_else(|e| { // TODO: oneshot is broken on windows
trace!(target: "network", "Failed to reregister {:?}, {:?}", reg, e);
if !self.registered.load(AtomicOrdering::SeqCst) {
self.register_socket(reg, event_loop)
} else {
event_loop.reregister(&self.socket, reg, self.interest, PollOpt::edge() /* | PollOpt::oneshot() */ ).unwrap_or_else(|e| { // TODO: oneshot is broken on windows
trace!(target: "network", "Failed to reregister {:?}, {:?}", reg, e);
});
Ok(())
})
}
}
/// Delete connection registration. Should be called at the end of the IO handler.
@@ -266,7 +287,7 @@ pub struct EncryptedConnection {
}
impl EncryptedConnection {
/// Create an encrypted connection out of the handshake. Consumes a handshake object.
/// Create an encrypted connection out of the handshake.
pub fn new(handshake: &mut Handshake) -> Result<EncryptedConnection, UtilError> {
let shared = try!(crypto::ecdh::agree(handshake.ecdhe.secret(), &handshake.remote_ephemeral));
let mut nonce_material = H512::new();
@@ -320,7 +341,7 @@ impl EncryptedConnection {
}
/// Send a packet
pub fn send_packet(&mut self, payload: &[u8]) -> Result<(), UtilError> {
pub fn send_packet<Message>(&mut self, io: &IoContext<Message>, payload: &[u8]) -> Result<(), UtilError> where Message: Send + Clone {
let mut header = RlpStream::new();
let len = payload.len() as usize;
header.append_raw(&[(len >> 16) as u8, (len >> 8) as u8, len as u8], 1);
@@ -342,7 +363,7 @@ impl EncryptedConnection {
self.egress_mac.update(&packet[32..(32 + len + padding)]);
EncryptedConnection::update_mac(&mut self.egress_mac, &mut self.mac_encoder, &[0u8; 0]);
self.egress_mac.clone().finalize(&mut packet[(32 + len + padding)..]);
self.connection.send(packet);
self.connection.send(io, packet);
Ok(())
}
@@ -416,32 +437,30 @@ impl EncryptedConnection {
/// Readable IO handler. Tracker receive status and returns decoded packet if avaialable.
pub fn readable<Message>(&mut self, io: &IoContext<Message>) -> Result<Option<Packet>, UtilError> where Message: Send + Clone{
io.clear_timer(self.connection.token).unwrap();
match self.read_state {
EncryptedConnectionState::Header => {
if let Some(data) = try!(self.connection.readable()) {
try!(self.read_header(&data));
try!(io.register_timer(self.connection.token, RECIEVE_PAYLOAD_TIMEOUT));
}
Ok(None)
},
EncryptedConnectionState::Payload => {
match try!(self.connection.readable()) {
Some(data) => {
self.read_state = EncryptedConnectionState::Header;
self.connection.expect(ENCRYPTED_HEADER_LEN);
Ok(Some(try!(self.read_payload(&data))))
},
None => Ok(None)
}
try!(io.clear_timer(self.connection.token));
if let EncryptedConnectionState::Header = self.read_state {
if let Some(data) = try!(self.connection.readable()) {
try!(self.read_header(&data));
try!(io.register_timer(self.connection.token, RECIEVE_PAYLOAD_TIMEOUT));
}
};
if let EncryptedConnectionState::Payload = self.read_state {
match try!(self.connection.readable()) {
Some(data) => {
self.read_state = EncryptedConnectionState::Header;
self.connection.expect(ENCRYPTED_HEADER_LEN);
Ok(Some(try!(self.read_payload(&data))))
},
None => Ok(None)
}
} else {
Ok(None)
}
}
/// Writable IO handler. Processes send queeue.
pub fn writable<Message>(&mut self, io: &IoContext<Message>) -> Result<(), UtilError> where Message: Send + Clone {
io.clear_timer(self.connection.token).unwrap();
try!(self.connection.writable());
try!(self.connection.writable(io));
Ok(())
}
}
@@ -472,12 +491,14 @@ pub fn test_encryption() {
mod tests {
use super::*;
use std::sync::*;
use std::sync::atomic::AtomicBool;
use super::super::stats::*;
use std::io::{Read, Write, Error, Cursor, ErrorKind};
use mio::{EventSet};
use std::collections::VecDeque;
use bytes::*;
use devtools::*;
use io::*;
impl GenericSocket for TestSocket {}
@@ -521,6 +542,7 @@ mod tests {
rec_size: 0,
interest: EventSet::hup() | EventSet::readable(),
stats: Arc::<NetworkStats>::new(NetworkStats::new()),
registered: AtomicBool::new(false),
}
}
}
@@ -543,10 +565,15 @@ mod tests {
rec_size: 0,
interest: EventSet::hup() | EventSet::readable(),
stats: Arc::<NetworkStats>::new(NetworkStats::new()),
registered: AtomicBool::new(false),
}
}
}
fn test_io() -> IoContext<i32> {
IoContext::new(IoChannel::disconnected(), 0)
}
#[test]
fn connection_expect() {
let mut connection = TestConnection::new();
@@ -557,7 +584,7 @@ mod tests {
#[test]
fn connection_write_empty() {
let mut connection = TestConnection::new();
let status = connection.writable();
let status = connection.writable(&test_io());
assert!(status.is_ok());
assert!(WriteStatus::Complete == status.unwrap());
}
@@ -568,7 +595,7 @@ mod tests {
let data = Cursor::new(vec![0; 10240]);
connection.send_queue.push_back(data);
let status = connection.writable();
let status = connection.writable(&test_io());
assert!(status.is_ok());
assert!(WriteStatus::Complete == status.unwrap());
assert_eq!(10240, connection.socket.write_buffer.len());
@@ -581,7 +608,7 @@ mod tests {
let data = Cursor::new(vec![0; 10240]);
connection.send_queue.push_back(data);
let status = connection.writable();
let status = connection.writable(&test_io());
assert!(status.is_ok());
assert!(WriteStatus::Ongoing == status.unwrap());
@@ -594,7 +621,7 @@ mod tests {
let data = Cursor::new(vec![0; 10240]);
connection.send_queue.push_back(data);
let status = connection.writable();
let status = connection.writable(&test_io());
assert!(!status.is_ok());
assert_eq!(1, connection.send_queue.len());

View File

@@ -28,7 +28,7 @@ use crypto::*;
use rlp::*;
use network::node_table::*;
use network::error::NetworkError;
use io::StreamToken;
use io::{StreamToken, IoContext};
use network::PROTOCOL_VERSION;
@@ -283,7 +283,7 @@ impl Discovery {
ret
}
pub fn writable(&mut self) {
pub fn writable<Message>(&mut self, io: &IoContext<Message>) where Message: Send + Sync + Clone {
while !self.send_queue.is_empty() {
let data = self.send_queue.pop_front().unwrap();
match self.udp_socket.send_to(&data.payload, &data.address) {
@@ -302,15 +302,17 @@ impl Discovery {
}
}
}
io.update_registration(self.token).unwrap_or_else(|e| debug!("Error updating discovery registration: {:?}", e));
}
fn send_to(&mut self, payload: Bytes, address: SocketAddr) {
self.send_queue.push_back(Datagramm { payload: payload, address: address });
}
pub fn readable(&mut self) -> Option<TableUpdates> {
pub fn readable<Message>(&mut self, io: &IoContext<Message>) -> Option<TableUpdates> where Message: Send + Sync + Clone {
let mut buf: [u8; MAX_DATAGRAM_SIZE] = unsafe { mem::uninitialized() };
match self.udp_socket.recv_from(&mut buf) {
let writable = !self.send_queue.is_empty();
let res = match self.udp_socket.recv_from(&mut buf) {
Ok(Some((len, address))) => self.on_packet(&buf[0..len], address).unwrap_or_else(|e| {
debug!("Error processing UDP packet: {:?}", e);
None
@@ -320,7 +322,12 @@ impl Discovery {
debug!("Error reading UPD socket: {:?}", e);
None
}
};
let new_writable = !self.send_queue.is_empty();
if writable != new_writable {
io.update_registration(self.token).unwrap_or_else(|e| debug!("Error updating discovery registration: {:?}", e));
}
res
}
fn on_packet(&mut self, packet: &[u8], from: SocketAddr) -> Result<Option<TableUpdates>, NetworkError> {

View File

@@ -111,7 +111,7 @@ impl Handshake {
self.originated = originated;
io.register_timer(self.connection.token, HANDSHAKE_TIMEOUT).ok();
if originated {
try!(self.write_auth(host.secret(), host.id()));
try!(self.write_auth(io, host.secret(), host.id()));
}
else {
self.state = HandshakeState::ReadingAuth;
@@ -128,17 +128,16 @@ impl Handshake {
/// Readable IO handler. Drives the state change.
pub fn readable<Message>(&mut self, io: &IoContext<Message>, host: &HostInfo) -> Result<(), UtilError> where Message: Send + Clone {
if !self.expired() {
io.clear_timer(self.connection.token).unwrap();
match self.state {
HandshakeState::New => {}
HandshakeState::ReadingAuth => {
if let Some(data) = try!(self.connection.readable()) {
try!(self.read_auth(host.secret(), &data));
try!(self.read_auth(io, host.secret(), &data));
};
},
HandshakeState::ReadingAuthEip8 => {
if let Some(data) = try!(self.connection.readable()) {
try!(self.read_auth_eip8(host.secret(), &data));
try!(self.read_auth_eip8(io, host.secret(), &data));
};
},
HandshakeState::ReadingAck => {
@@ -151,10 +150,9 @@ impl Handshake {
try!(self.read_ack_eip8(host.secret(), &data));
};
},
HandshakeState::StartSession => {},
}
if self.state != HandshakeState::StartSession {
try!(io.update_registration(self.connection.token));
HandshakeState::StartSession => {
io.clear_timer(self.connection.token).ok();
},
}
}
Ok(())
@@ -163,11 +161,7 @@ impl Handshake {
/// Writabe IO handler.
pub fn writable<Message>(&mut self, io: &IoContext<Message>) -> Result<(), UtilError> where Message: Send + Clone {
if !self.expired() {
io.clear_timer(self.connection.token).unwrap();
try!(self.connection.writable());
if self.state != HandshakeState::StartSession {
io.update_registration(self.connection.token).unwrap();
}
try!(self.connection.writable(io));
}
Ok(())
}
@@ -183,7 +177,7 @@ impl Handshake {
}
/// Parse, validate and confirm auth message
fn read_auth(&mut self, secret: &Secret, data: &[u8]) -> Result<(), UtilError> {
fn read_auth<Message>(&mut self, io: &IoContext<Message>, secret: &Secret, data: &[u8]) -> Result<(), UtilError> where Message: Send + Clone {
trace!(target:"network", "Received handshake auth from {:?}", self.connection.remote_addr_str());
if data.len() != V4_AUTH_PACKET_SIZE {
debug!(target:"net", "Wrong auth packet size");
@@ -197,7 +191,7 @@ impl Handshake {
let (pubk, rest) = rest.split_at(64);
let (nonce, _) = rest.split_at(32);
try!(self.set_auth(secret, sig, pubk, nonce, PROTOCOL_VERSION));
try!(self.write_ack());
try!(self.write_ack(io));
}
Err(_) => {
// Try to interpret as EIP-8 packet
@@ -214,7 +208,7 @@ impl Handshake {
Ok(())
}
fn read_auth_eip8(&mut self, secret: &Secret, data: &[u8]) -> Result<(), UtilError> {
fn read_auth_eip8<Message>(&mut self, io: &IoContext<Message>, secret: &Secret, data: &[u8]) -> Result<(), UtilError> where Message: Send + Clone {
trace!(target:"network", "Received EIP8 handshake auth from {:?}", self.connection.remote_addr_str());
self.auth_cipher.extend_from_slice(data);
let auth = try!(ecies::decrypt(secret, &self.auth_cipher[0..2], &self.auth_cipher[2..]));
@@ -224,13 +218,13 @@ impl Handshake {
let remote_nonce: H256 = try!(rlp.val_at(2));
let remote_version: u64 = try!(rlp.val_at(3));
try!(self.set_auth(secret, &signature, &remote_public, &remote_nonce, remote_version));
try!(self.write_ack_eip8());
try!(self.write_ack_eip8(io));
Ok(())
}
/// Parse and validate ack message
fn read_ack(&mut self, secret: &Secret, data: &[u8]) -> Result<(), UtilError> {
trace!(target:"network", "Received handshake auth to {:?}", self.connection.remote_addr_str());
trace!(target:"network", "Received handshake ack from {:?}", self.connection.remote_addr_str());
if data.len() != V4_ACK_PACKET_SIZE {
debug!(target:"net", "Wrong ack packet size");
return Err(From::from(NetworkError::BadProtocol));
@@ -270,7 +264,7 @@ impl Handshake {
}
/// Sends auth message
fn write_auth(&mut self, secret: &Secret, public: &Public) -> Result<(), UtilError> {
fn write_auth<Message>(&mut self, io: &IoContext<Message>, secret: &Secret, public: &Public) -> Result<(), UtilError> where Message: Send + Clone {
trace!(target:"network", "Sending handshake auth to {:?}", self.connection.remote_addr_str());
let mut data = [0u8; /*Signature::SIZE*/ 65 + /*H256::SIZE*/ 32 + /*Public::SIZE*/ 64 + /*H256::SIZE*/ 32 + 1]; //TODO: use associated constants
let len = data.len();
@@ -290,14 +284,14 @@ impl Handshake {
}
let message = try!(crypto::ecies::encrypt(&self.id, &[], &data));
self.auth_cipher = message.clone();
self.connection.send(message);
self.connection.send(io, message);
self.connection.expect(V4_ACK_PACKET_SIZE);
self.state = HandshakeState::ReadingAck;
Ok(())
}
/// Sends ack message
fn write_ack(&mut self) -> Result<(), UtilError> {
fn write_ack<Message>(&mut self, io: &IoContext<Message>) -> Result<(), UtilError> where Message: Send + Clone {
trace!(target:"network", "Sending handshake ack to {:?}", self.connection.remote_addr_str());
let mut data = [0u8; 1 + /*Public::SIZE*/ 64 + /*H256::SIZE*/ 32]; //TODO: use associated constants
let len = data.len();
@@ -310,13 +304,13 @@ impl Handshake {
}
let message = try!(crypto::ecies::encrypt(&self.id, &[], &data));
self.ack_cipher = message.clone();
self.connection.send(message);
self.connection.send(io, message);
self.state = HandshakeState::StartSession;
Ok(())
}
/// Sends EIP8 ack message
fn write_ack_eip8(&mut self) -> Result<(), UtilError> {
fn write_ack_eip8<Message>(&mut self, io: &IoContext<Message>) -> Result<(), UtilError> where Message: Send + Clone {
trace!(target:"network", "Sending EIP8 handshake ack to {:?}", self.connection.remote_addr_str());
let mut rlp = RlpStream::new_list(3);
rlp.append(self.ecdhe.public());
@@ -333,7 +327,7 @@ impl Handshake {
let message = try!(crypto::ecies::encrypt(&self.id, &prefix, &encoded));
self.ack_cipher.extend_from_slice(&prefix);
self.ack_cipher.extend_from_slice(&message);
self.connection.send(self.ack_cipher.clone());
self.connection.send(io, self.ack_cipher.clone());
self.state = HandshakeState::StartSession;
Ok(())
}
@@ -347,6 +341,7 @@ mod test {
use super::*;
use crypto::*;
use hash::*;
use io::*;
use std::net::SocketAddr;
use mio::tcp::TcpStream;
use network::stats::NetworkStats;
@@ -371,6 +366,10 @@ mod test {
Handshake::new(0, to, socket, &nonce, Arc::new(NetworkStats::new())).unwrap()
}
fn test_io() -> IoContext<i32> {
IoContext::new(IoChannel::disconnected(), 0)
}
#[test]
fn test_handshake_auth_plain() {
let mut h = create_handshake(None);
@@ -387,7 +386,7 @@ mod test {
a4592ee77e2bd94d0be3691f3b406f9bba9b591fc63facc016bfa8\
".from_hex().unwrap();
h.read_auth(&secret, &auth).unwrap();
h.read_auth(&test_io(), &secret, &auth).unwrap();
assert_eq!(h.state, super::HandshakeState::StartSession);
check_auth(&h, 4);
}
@@ -411,9 +410,9 @@ mod test {
3bf7678318e2d5b5340c9e488eefea198576344afbdf66db5f51204a6961a63ce072c8926c\
".from_hex().unwrap();
h.read_auth(&secret, &auth[0..super::V4_AUTH_PACKET_SIZE]).unwrap();
h.read_auth(&test_io(), &secret, &auth[0..super::V4_AUTH_PACKET_SIZE]).unwrap();
assert_eq!(h.state, super::HandshakeState::ReadingAuthEip8);
h.read_auth_eip8(&secret, &auth[super::V4_AUTH_PACKET_SIZE..]).unwrap();
h.read_auth_eip8(&test_io(), &secret, &auth[super::V4_AUTH_PACKET_SIZE..]).unwrap();
assert_eq!(h.state, super::HandshakeState::StartSession);
check_auth(&h, 4);
}
@@ -438,9 +437,9 @@ mod test {
d490\
".from_hex().unwrap();
h.read_auth(&secret, &auth[0..super::V4_AUTH_PACKET_SIZE]).unwrap();
h.read_auth(&test_io(), &secret, &auth[0..super::V4_AUTH_PACKET_SIZE]).unwrap();
assert_eq!(h.state, super::HandshakeState::ReadingAuthEip8);
h.read_auth_eip8(&secret, &auth[super::V4_AUTH_PACKET_SIZE..]).unwrap();
h.read_auth_eip8(&test_io(), &secret, &auth[super::V4_AUTH_PACKET_SIZE..]).unwrap();
assert_eq!(h.state, super::HandshakeState::StartSession);
check_auth(&h, 56);
let ack = h.ack_cipher.clone();

View File

@@ -14,11 +14,11 @@
// You should have received a copy of the GNU General Public License
// along with Parity. If not, see <http://www.gnu.org/licenses/>.
use std::net::{SocketAddr};
use std::collections::{HashMap};
use std::str::{FromStr};
use std::net::SocketAddr;
use std::collections::{HashMap, HashSet};
use std::str::FromStr;
use std::sync::*;
use std::sync::atomic::{AtomicUsize, Ordering as AtomicOrdering};
use std::sync::atomic::{AtomicUsize, AtomicBool, Ordering as AtomicOrdering};
use std::ops::*;
use std::cmp::min;
use std::path::{Path, PathBuf};
@@ -35,12 +35,13 @@ use rlp::*;
use network::session::{Session, SessionData};
use error::*;
use io::*;
use network::{NetworkProtocolHandler, PROTOCOL_VERSION};
use network::{NetworkProtocolHandler, NonReservedPeerMode, PROTOCOL_VERSION};
use network::node_table::*;
use network::stats::NetworkStats;
use network::error::{NetworkError, DisconnectReason};
use network::discovery::{Discovery, TableUpdates, NodeEntry};
use network::ip_utils::{map_external_address, select_public_address};
use path::restrict_permissions_owner;
type Slab<T> = ::slab::Slab<T, usize>;
@@ -49,7 +50,7 @@ const MAX_HANDSHAKES: usize = 80;
const MAX_HANDSHAKES_PER_ROUND: usize = 32;
const MAINTENANCE_TIMEOUT: u64 = 1000;
#[derive(Debug)]
#[derive(Debug, Clone)]
/// Network service configuration
pub struct NetworkConfiguration {
/// Directory path to store network configuration. None means nothing will be saved
@@ -64,14 +65,16 @@ pub struct NetworkConfiguration {
pub nat_enabled: bool,
/// Enable discovery
pub discovery_enabled: bool,
/// Pin to boot nodes only
pub pin: bool,
/// List of initial node addresses
pub boot_nodes: Vec<String>,
/// Use provided node key instead of default
pub use_secret: Option<Secret>,
/// Number of connected peers to maintain
pub ideal_peers: u32,
/// List of reserved node addresses.
pub reserved_nodes: Vec<String>,
/// The non-reserved peer mode.
pub non_reserved_mode: NonReservedPeerMode,
}
impl Default for NetworkConfiguration {
@@ -90,10 +93,11 @@ impl NetworkConfiguration {
udp_port: None,
nat_enabled: true,
discovery_enabled: true,
pin: false,
boot_nodes: Vec::new(),
use_secret: None,
ideal_peers: 25,
reserved_nodes: Vec::new(),
non_reserved_mode: NonReservedPeerMode::Accept,
}
}
@@ -187,13 +191,15 @@ pub struct NetworkContext<'s, Message> where Message: Send + Sync + Clone + 'sta
sessions: Arc<RwLock<Slab<SharedSession>>>,
session: Option<SharedSession>,
session_id: Option<StreamToken>,
_reserved_peers: &'s HashSet<NodeId>,
}
impl<'s, Message> NetworkContext<'s, Message> where Message: Send + Sync + Clone + 'static, {
/// Create a new network IO access point. Takes references to all the data that can be updated within the IO handler.
fn new(io: &'s IoContext<NetworkIoMessage<Message>>,
protocol: ProtocolId,
session: Option<SharedSession>, sessions: Arc<RwLock<Slab<SharedSession>>>) -> NetworkContext<'s, Message> {
session: Option<SharedSession>, sessions: Arc<RwLock<Slab<SharedSession>>>,
reserved_peers: &'s HashSet<NodeId>) -> NetworkContext<'s, Message> {
let id = session.as_ref().map(|s| s.lock().unwrap().token());
NetworkContext {
io: io,
@@ -201,6 +207,7 @@ impl<'s, Message> NetworkContext<'s, Message> where Message: Send + Sync + Clone
session_id: id,
session: session,
sessions: sessions,
_reserved_peers: reserved_peers,
}
}
@@ -215,8 +222,7 @@ impl<'s, Message> NetworkContext<'s, Message> where Message: Send + Sync + Clone
pub fn send(&self, peer: PeerId, packet_id: PacketId, data: Vec<u8>) -> Result<(), UtilError> {
let session = self.resolve_session(peer);
if let Some(session) = session {
try!(session.lock().unwrap().deref_mut().send_packet(self.protocol, packet_id as u8, &data));
try!(self.io.update_registration(peer));
try!(session.lock().unwrap().send_packet(self.io, self.protocol, packet_id as u8, &data));
} else {
trace!(target: "network", "Send: Peer no longer exist")
}
@@ -230,19 +236,31 @@ impl<'s, Message> NetworkContext<'s, Message> where Message: Send + Sync + Clone
}
/// Send an IO message
pub fn message(&self, msg: Message) {
self.io.message(NetworkIoMessage::User(msg));
pub fn message(&self, msg: Message) -> Result<(), UtilError> {
self.io.message(NetworkIoMessage::User(msg))
}
/// Get an IoChannel.
pub fn io_channel(&self) -> IoChannel<NetworkIoMessage<Message>> {
self.io.channel()
}
/// Disable current protocol capability for given peer. If no capabilities left peer gets disconnected.
pub fn disable_peer(&self, peer: PeerId) {
//TODO: remove capability, disconnect if no capabilities left
self.io.message(NetworkIoMessage::DisablePeer(peer));
self.io.message(NetworkIoMessage::DisablePeer(peer))
.unwrap_or_else(|e| warn!("Error sending network IO message: {:?}", e));
}
/// Disconnect peer. Reconnect can be attempted later.
pub fn disconnect_peer(&self, peer: PeerId) {
self.io.message(NetworkIoMessage::Disconnect(peer));
self.io.message(NetworkIoMessage::Disconnect(peer))
.unwrap_or_else(|e| warn!("Error sending network IO message: {:?}", e));
}
/// Check if the session is still active.
pub fn is_expired(&self) -> bool {
self.session.as_ref().map_or(false, |s| s.lock().unwrap().expired())
}
/// Register a new IO timer. 'IoHandler::timeout' will be called with the token.
@@ -251,7 +269,7 @@ impl<'s, Message> NetworkContext<'s, Message> where Message: Send + Sync + Clone
token: token,
delay: ms,
protocol: self.protocol,
});
}).unwrap_or_else(|e| warn!("Error sending network IO message: {:?}", e));
Ok(())
}
@@ -322,13 +340,14 @@ pub struct Host<Message> where Message: Send + Sync + Clone {
timers: RwLock<HashMap<TimerToken, ProtocolTimer>>,
timer_counter: RwLock<usize>,
stats: Arc<NetworkStats>,
pinned_nodes: Vec<NodeId>,
reserved_nodes: RwLock<HashSet<NodeId>>,
num_sessions: AtomicUsize,
stopping: AtomicBool,
}
impl<Message> Host<Message> where Message: Send + Sync + Clone {
/// Create a new instance
pub fn new(config: NetworkConfiguration) -> Result<Host<Message>, UtilError> {
pub fn new(config: NetworkConfiguration, stats: Arc<NetworkStats>) -> Result<Host<Message>, UtilError> {
let mut listen_address = match config.listen_address {
None => SocketAddr::from_str("0.0.0.0:30304").unwrap(),
Some(addr) => addr,
@@ -354,6 +373,9 @@ impl<Message> Host<Message> where Message: Send + Sync + Clone {
let udp_port = config.udp_port.unwrap_or(listen_address.port());
let local_endpoint = NodeEndpoint { address: listen_address, udp_port: udp_port };
let boot_nodes = config.boot_nodes.clone();
let reserved_nodes = config.reserved_nodes.clone();
let mut host = Host::<Message> {
info: RwLock::new(HostInfo {
keys: keys,
@@ -372,20 +394,22 @@ impl<Message> Host<Message> where Message: Send + Sync + Clone {
handlers: RwLock::new(HashMap::new()),
timers: RwLock::new(HashMap::new()),
timer_counter: RwLock::new(USER_TIMER),
stats: Arc::new(NetworkStats::default()),
pinned_nodes: Vec::new(),
stats: stats,
reserved_nodes: RwLock::new(HashSet::new()),
num_sessions: AtomicUsize::new(0),
stopping: AtomicBool::new(false),
};
let boot_nodes = host.info.read().unwrap().config.boot_nodes.clone();
for n in boot_nodes {
host.add_node(&n);
}
Ok(host)
}
pub fn stats(&self) -> Arc<NetworkStats> {
self.stats.clone()
for n in reserved_nodes {
if let Err(e) = host.add_reserved_node(&n) {
debug!(target: "network", "Error parsing node id: {}: {:?}", n, e);
}
}
Ok(host)
}
pub fn add_node(&mut self, id: &str) {
@@ -393,17 +417,67 @@ impl<Message> Host<Message> where Message: Send + Sync + Clone {
Err(e) => { debug!(target: "network", "Could not add node {}: {:?}", id, e); },
Ok(n) => {
let entry = NodeEntry { endpoint: n.endpoint.clone(), id: n.id.clone() };
self.pinned_nodes.push(n.id.clone());
self.nodes.write().unwrap().add_node(n);
if let Some(ref mut discovery) = *self.discovery.lock().unwrap().deref_mut() {
if let Some(ref mut discovery) = *self.discovery.lock().unwrap() {
discovery.add_node(entry);
}
}
}
}
pub fn client_version(&self) -> String {
self.info.read().unwrap().client_version.clone()
pub fn add_reserved_node(&self, id: &str) -> Result<(), UtilError> {
let n = try!(Node::from_str(id));
let entry = NodeEntry { endpoint: n.endpoint.clone(), id: n.id.clone() };
self.reserved_nodes.write().unwrap().insert(n.id.clone());
if let Some(ref mut discovery) = *self.discovery.lock().unwrap() {
discovery.add_node(entry);
}
Ok(())
}
pub fn set_non_reserved_mode(&self, mode: NonReservedPeerMode, io: &IoContext<NetworkIoMessage<Message>>) {
let mut info = self.info.write().unwrap();
if info.config.non_reserved_mode != mode {
info.config.non_reserved_mode = mode.clone();
drop(info);
if let NonReservedPeerMode::Deny = mode {
// disconnect all non-reserved peers here.
let reserved: HashSet<NodeId> = self.reserved_nodes.read().unwrap().clone();
let mut to_kill = Vec::new();
for e in self.sessions.write().unwrap().iter_mut() {
let mut s = e.lock().unwrap();
{
let id = s.id();
if id.is_some() && reserved.contains(id.unwrap()) {
continue;
}
}
s.disconnect(io, DisconnectReason::ClientQuit);
to_kill.push(s.token());
}
for p in to_kill {
trace!(target: "network", "Disconnecting on reserved-only mode: {}", p);
self.kill_connection(p, io, false);
}
}
}
}
pub fn remove_reserved_node(&self, id: &str) -> Result<(), UtilError> {
let n = try!(Node::from_str(id));
self.reserved_nodes.write().unwrap().remove(&n.id);
Ok(())
}
pub fn client_version() -> String {
version()
}
pub fn external_url(&self) -> Option<String> {
@@ -416,6 +490,22 @@ impl<Message> Host<Message> where Message: Send + Sync + Clone {
r
}
pub fn stop(&self, io: &IoContext<NetworkIoMessage<Message>>) -> Result<(), UtilError> {
self.stopping.store(true, AtomicOrdering::Release);
let mut to_kill = Vec::new();
for e in self.sessions.write().unwrap().iter_mut() {
let mut s = e.lock().unwrap();
s.disconnect(io, DisconnectReason::ClientQuit);
to_kill.push(s.token());
}
for p in to_kill {
trace!(target: "network", "Disconnecting on shutdown: {}", p);
self.kill_connection(p, io, true);
}
try!(io.unregister_handler());
Ok(())
}
fn init_public_interface(&self, io: &IoContext<NetworkIoMessage<Message>>) -> Result<(), UtilError> {
io.clear_timer(INIT_PUBLIC).unwrap();
if self.info.read().unwrap().public_endpoint.is_some() {
@@ -448,7 +538,7 @@ impl<Message> Host<Message> where Message: Send + Sync + Clone {
// Initialize discovery.
let discovery = {
let info = self.info.read().unwrap();
if info.config.discovery_enabled && !info.config.pin {
if info.config.discovery_enabled && info.config.non_reserved_mode == NonReservedPeerMode::Accept {
Some(Discovery::new(&info.keys, public_endpoint.address.clone(), public_endpoint, DISCOVERY))
} else { None }
};
@@ -458,7 +548,7 @@ impl<Message> Host<Message> where Message: Send + Sync + Clone {
for n in self.nodes.read().unwrap().unordered_entries() {
discovery.add_node(n.clone());
}
*self.discovery.lock().unwrap().deref_mut() = Some(discovery);
*self.discovery.lock().unwrap() = Some(discovery);
io.register_stream(DISCOVERY).expect("Error registering UDP listener");
io.register_timer(DISCOVERY_REFRESH, 7200).expect("Error registering discovery timer");
io.register_timer(DISCOVERY_ROUND, 300).expect("Error registering discovery timer");
@@ -494,7 +584,7 @@ impl<Message> Host<Message> where Message: Send + Sync + Clone {
for e in self.sessions.write().unwrap().iter_mut() {
let mut s = e.lock().unwrap();
if !s.keep_alive(io) {
s.disconnect(DisconnectReason::PingTimeout);
s.disconnect(io, DisconnectReason::PingTimeout);
to_kill.push(s.token());
}
}
@@ -505,14 +595,26 @@ impl<Message> Host<Message> where Message: Send + Sync + Clone {
}
fn connect_peers(&self, io: &IoContext<NetworkIoMessage<Message>>) {
if self.info.read().unwrap().deref().capabilities.is_empty() {
return;
}
let ideal_peers = { self.info.read().unwrap().deref().config.ideal_peers };
let pin = { self.info.read().unwrap().deref().config.pin };
let (ideal_peers, mut pin) = {
let info = self.info.read().unwrap();
if info.capabilities.is_empty() {
return;
}
let config = &info.config;
(config.ideal_peers, config.non_reserved_mode == NonReservedPeerMode::Deny)
};
let session_count = self.session_count();
if session_count >= ideal_peers as usize {
return;
let reserved_nodes = self.reserved_nodes.read().unwrap();
if session_count >= ideal_peers as usize + reserved_nodes.len() {
// check if all pinned nodes are connected.
if reserved_nodes.iter().all(|n| self.have_session(n) && self.connecting_to(n)) {
return;
}
// if not, only attempt connect to reserved peers
pin = true;
}
let handshake_count = self.handshake_count();
@@ -522,12 +624,21 @@ impl<Message> Host<Message> where Message: Send + Sync + Clone {
return;
}
let nodes = if pin { self.pinned_nodes.clone() } else { self.nodes.read().unwrap().nodes() };
for id in nodes.iter().filter(|ref id| !self.have_session(id) && !self.connecting_to(id))
// iterate over all nodes, reserved ones coming first.
// if we are pinned to only reserved nodes, ignore all others.
let nodes = reserved_nodes.iter().cloned().chain(if !pin {
self.nodes.read().unwrap().nodes()
} else {
Vec::new()
});
let mut started: usize = 0;
for id in nodes.filter(|ref id| !self.have_session(id) && !self.connecting_to(id))
.take(min(MAX_HANDSHAKES_PER_ROUND, handshake_limit - handshake_count)) {
self.connect_peer(&id, io);
started += 1;
}
debug!(target: "network", "Connecting peers: {} sessions, {} pending", self.session_count(), self.handshake_count());
debug!(target: "network", "Connecting peers: {} sessions, {} pending, {} started", self.session_count(), self.handshake_count(), started);
}
#[cfg_attr(feature="dev", allow(single_match))]
@@ -605,7 +716,6 @@ impl<Message> Host<Message> where Message: Send + Sync + Clone {
debug!(target: "network", "Can't accept connection: {:?}", e);
}
}
io.update_registration(TCP_ACCEPT).expect("Error registering TCP listener");
}
fn session_writable(&self, token: StreamToken, io: &IoContext<NetworkIoMessage<Message>>) {
@@ -616,9 +726,7 @@ impl<Message> Host<Message> where Message: Send + Sync + Clone {
trace!(target: "network", "Session write error: {}: {:?}", token, e);
}
if s.done() {
io.deregister_stream(token).expect("Error deregistering stream");
} else {
io.update_registration(token).unwrap_or_else(|e| debug!(target: "network", "Session registration error: {:?}", e));
io.deregister_stream(token).unwrap_or_else(|e| debug!("Error deregistering stream: {:?}", e));
}
}
}
@@ -628,78 +736,90 @@ impl<Message> Host<Message> where Message: Send + Sync + Clone {
self.kill_connection(token, io, true);
}
#[cfg_attr(feature="dev", allow(collapsible_if))]
fn session_readable(&self, token: StreamToken, io: &IoContext<NetworkIoMessage<Message>>) {
let mut ready_data: Vec<ProtocolId> = Vec::new();
let mut packet_data: Option<(ProtocolId, PacketId, Vec<u8>)> = None;
let mut packet_data: Vec<(ProtocolId, PacketId, Vec<u8>)> = Vec::new();
let mut kill = false;
let session = { self.sessions.read().unwrap().get(token).cloned() };
if let Some(session) = session.clone() {
let mut s = session.lock().unwrap();
match s.readable(io, &self.info.read().unwrap()) {
Err(e) => {
trace!(target: "network", "Session read error: {}:{:?} ({:?}) {:?}", token, s.id(), s.remote_addr(), e);
match e {
UtilError::Network(NetworkError::Disconnect(DisconnectReason::UselessPeer)) |
UtilError::Network(NetworkError::Disconnect(DisconnectReason::IncompatibleProtocol)) => {
loop {
match s.readable(io, &self.info.read().unwrap()) {
Err(e) => {
trace!(target: "network", "Session read error: {}:{:?} ({:?}) {:?}", token, s.id(), s.remote_addr(), e);
if let UtilError::Network(NetworkError::Disconnect(DisconnectReason::IncompatibleProtocol)) = e {
if let Some(id) = s.id() {
self.nodes.write().unwrap().mark_as_useless(id);
}
}
_ => (),
}
kill = true;
},
Ok(SessionData::Ready) => {
if !s.info.originated {
let session_count = self.session_count();
let ideal_peers = { self.info.read().unwrap().deref().config.ideal_peers };
if session_count >= ideal_peers as usize {
s.disconnect(DisconnectReason::TooManyPeers);
return;
}
// Add it no node table
if let Ok(address) = s.remote_addr() {
let entry = NodeEntry { id: s.id().unwrap().clone(), endpoint: NodeEndpoint { address: address, udp_port: address.port() } };
self.nodes.write().unwrap().add_node(Node::new(entry.id.clone(), entry.endpoint.clone()));
let mut discovery = self.discovery.lock().unwrap();
if let Some(ref mut discovery) = *discovery.deref_mut() {
discovery.add_node(entry);
kill = true;
break;
},
Ok(SessionData::Ready) => {
self.num_sessions.fetch_add(1, AtomicOrdering::SeqCst);
if !s.info.originated {
let session_count = self.session_count();
let reserved_nodes = self.reserved_nodes.read().unwrap();
let (ideal_peers, reserved_only) = {
let info = self.info.read().unwrap();
(info.config.ideal_peers, info.config.non_reserved_mode == NonReservedPeerMode::Deny)
};
if session_count >= ideal_peers as usize || reserved_only {
// only proceed if the connecting peer is reserved.
if !reserved_nodes.contains(s.id().unwrap()) {
s.disconnect(io, DisconnectReason::TooManyPeers);
return;
}
}
// Add it no node table
if let Ok(address) = s.remote_addr() {
let entry = NodeEntry { id: s.id().unwrap().clone(), endpoint: NodeEndpoint { address: address, udp_port: address.port() } };
self.nodes.write().unwrap().add_node(Node::new(entry.id.clone(), entry.endpoint.clone()));
let mut discovery = self.discovery.lock().unwrap();
if let Some(ref mut discovery) = *discovery.deref_mut() {
discovery.add_node(entry);
}
}
}
}
self.num_sessions.fetch_add(1, AtomicOrdering::SeqCst);
for (p, _) in self.handlers.read().unwrap().iter() {
if s.have_capability(p) {
ready_data.push(p);
for (p, _) in self.handlers.read().unwrap().iter() {
if s.have_capability(p) {
ready_data.push(p);
}
}
}
},
Ok(SessionData::Packet {
data,
protocol,
packet_id,
}) => {
match self.handlers.read().unwrap().get(protocol) {
None => { warn!(target: "network", "No handler found for protocol: {:?}", protocol) },
Some(_) => packet_data = Some((protocol, packet_id, data)),
}
},
Ok(SessionData::None) => {},
}
},
Ok(SessionData::Packet {
data,
protocol,
packet_id,
}) => {
match self.handlers.read().unwrap().get(protocol) {
None => { warn!(target: "network", "No handler found for protocol: {:?}", protocol) },
Some(_) => packet_data.push((protocol, packet_id, data)),
}
},
Ok(SessionData::Continue) => (),
Ok(SessionData::None) => break,
}
}
}
if kill {
self.kill_connection(token, io, true);
}
let handlers = self.handlers.read().unwrap();
for p in ready_data {
let h = self.handlers.read().unwrap().get(p).unwrap().clone();
let h = handlers.get(p).unwrap().clone();
self.stats.inc_sessions();
h.connected(&NetworkContext::new(io, p, session.clone(), self.sessions.clone()), &token);
let reserved = self.reserved_nodes.read().unwrap();
h.connected(&NetworkContext::new(io, p, session.clone(), self.sessions.clone(), &reserved), &token);
}
if let Some((p, packet_id, data)) = packet_data {
let h = self.handlers.read().unwrap().get(p).unwrap().clone();
h.read(&NetworkContext::new(io, p, session.clone(), self.sessions.clone()), &token, packet_id, &data[1..]);
for (p, packet_id, data) in packet_data {
let h = handlers.get(p).unwrap().clone();
let reserved = self.reserved_nodes.read().unwrap();
h.read(&NetworkContext::new(io, p, session.clone(), self.sessions.clone(), &reserved), &token, packet_id, &data[1..]);
}
io.update_registration(token).unwrap_or_else(|e| debug!(target: "network", "Token registration error: {:?}", e));
}
fn connection_timeout(&self, token: StreamToken, io: &IoContext<NetworkIoMessage<Message>>) {
@@ -719,9 +839,9 @@ impl<Message> Host<Message> where Message: Send + Sync + Clone {
let mut s = session.lock().unwrap();
if !s.expired() {
if s.is_ready() {
self.num_sessions.fetch_sub(1, AtomicOrdering::SeqCst);
for (p, _) in self.handlers.read().unwrap().iter() {
if s.have_capability(p) {
self.num_sessions.fetch_sub(1, AtomicOrdering::SeqCst);
to_disconnect.push(p);
}
}
@@ -739,12 +859,11 @@ impl<Message> Host<Message> where Message: Send + Sync + Clone {
}
for p in to_disconnect {
let h = self.handlers.read().unwrap().get(p).unwrap().clone();
h.disconnected(&NetworkContext::new(io, p, expired_session.clone(), self.sessions.clone()), &token);
let reserved = self.reserved_nodes.read().unwrap();
h.disconnected(&NetworkContext::new(io, p, expired_session.clone(), self.sessions.clone(), &reserved), &token);
}
if deregister {
io.deregister_stream(token).expect("Error deregistering stream");
} else if expired_session.is_some() {
io.update_registration(token).unwrap_or_else(|e| debug!(target: "network", "Connection registration error: {:?}", e));
io.deregister_stream(token).unwrap_or_else(|e| debug!("Error deregistering stream: {:?}", e));
}
}
@@ -786,14 +905,16 @@ impl<Message> IoHandler<NetworkIoMessage<Message>> for Host<Message> where Messa
}
fn stream_readable(&self, io: &IoContext<NetworkIoMessage<Message>>, stream: StreamToken) {
if self.stopping.load(AtomicOrdering::Acquire) {
return;
}
match stream {
FIRST_SESSION ... LAST_SESSION => self.session_readable(stream, io),
DISCOVERY => {
let node_changes = { self.discovery.lock().unwrap().as_mut().unwrap().readable() };
let node_changes = { self.discovery.lock().unwrap().as_mut().unwrap().readable(io) };
if let Some(node_changes) = node_changes {
self.update_nodes(io, node_changes);
}
io.update_registration(DISCOVERY).expect("Error updating discovery registration");
},
TCP_ACCEPT => self.accept(io),
_ => panic!("Received unknown readable token"),
@@ -801,17 +922,22 @@ impl<Message> IoHandler<NetworkIoMessage<Message>> for Host<Message> where Messa
}
fn stream_writable(&self, io: &IoContext<NetworkIoMessage<Message>>, stream: StreamToken) {
if self.stopping.load(AtomicOrdering::Acquire) {
return;
}
match stream {
FIRST_SESSION ... LAST_SESSION => self.session_writable(stream, io),
DISCOVERY => {
self.discovery.lock().unwrap().as_mut().unwrap().writable();
io.update_registration(DISCOVERY).expect("Error updating discovery registration");
self.discovery.lock().unwrap().as_mut().unwrap().writable(io);
}
_ => panic!("Received unknown writable token"),
}
}
fn timeout(&self, io: &IoContext<NetworkIoMessage<Message>>, token: TimerToken) {
if self.stopping.load(AtomicOrdering::Acquire) {
return;
}
match token {
IDLE => self.maintain_network(io),
INIT_PUBLIC => self.init_public_interface(io).unwrap_or_else(|e|
@@ -819,22 +945,26 @@ impl<Message> IoHandler<NetworkIoMessage<Message>> for Host<Message> where Messa
FIRST_SESSION ... LAST_SESSION => self.connection_timeout(token, io),
DISCOVERY_REFRESH => {
self.discovery.lock().unwrap().as_mut().unwrap().refresh();
io.update_registration(DISCOVERY).expect("Error updating discovery registration");
io.update_registration(DISCOVERY).unwrap_or_else(|e| debug!("Error updating discovery registration: {:?}", e));
},
DISCOVERY_ROUND => {
let node_changes = { self.discovery.lock().unwrap().as_mut().unwrap().round() };
if let Some(node_changes) = node_changes {
self.update_nodes(io, node_changes);
}
io.update_registration(DISCOVERY).expect("Error updating discovery registration");
io.update_registration(DISCOVERY).unwrap_or_else(|e| debug!("Error updating discovery registration: {:?}", e));
},
NODE_TABLE => {
trace!(target: "network", "Refreshing node table");
self.nodes.write().unwrap().clear_useless();
},
_ => match self.timers.read().unwrap().get(&token).cloned() {
Some(timer) => match self.handlers.read().unwrap().get(timer.protocol).cloned() {
None => { warn!(target: "network", "No handler found for protocol: {:?}", timer.protocol) },
Some(h) => { h.timeout(&NetworkContext::new(io, timer.protocol, None, self.sessions.clone()), timer.token); }
None => { warn!(target: "network", "No handler found for protocol: {:?}", timer.protocol) },
Some(h) => {
let reserved = self.reserved_nodes.read().unwrap();
h.timeout(&NetworkContext::new(io, timer.protocol, None, self.sessions.clone(), &reserved), timer.token);
}
},
None => { warn!("Unknown timer token: {}", token); } // timer is not registerd through us
}
@@ -842,6 +972,9 @@ impl<Message> IoHandler<NetworkIoMessage<Message>> for Host<Message> where Messa
}
fn message(&self, io: &IoContext<NetworkIoMessage<Message>>, message: &NetworkIoMessage<Message>) {
if self.stopping.load(AtomicOrdering::Acquire) {
return;
}
match *message {
NetworkIoMessage::AddHandler {
ref handler,
@@ -849,7 +982,8 @@ impl<Message> IoHandler<NetworkIoMessage<Message>> for Host<Message> where Messa
ref versions
} => {
let h = handler.clone();
h.initialize(&NetworkContext::new(io, protocol, None, self.sessions.clone()));
let reserved = self.reserved_nodes.read().unwrap();
h.initialize(&NetworkContext::new(io, protocol, None, self.sessions.clone(), &reserved));
self.handlers.write().unwrap().insert(protocol, h);
let mut info = self.info.write().unwrap();
for v in versions {
@@ -863,18 +997,18 @@ impl<Message> IoHandler<NetworkIoMessage<Message>> for Host<Message> where Messa
} => {
let handler_token = {
let mut timer_counter = self.timer_counter.write().unwrap();
let counter = timer_counter.deref_mut();
let counter = &mut *timer_counter;
let handler_token = *counter;
*counter += 1;
handler_token
};
self.timers.write().unwrap().insert(handler_token, ProtocolTimer { protocol: protocol, token: *token });
io.register_timer(handler_token, *delay).expect("Error registering timer");
io.register_timer(handler_token, *delay).unwrap_or_else(|e| debug!("Error registering timer {}: {:?}", token, e));
},
NetworkIoMessage::Disconnect(ref peer) => {
let session = { self.sessions.read().unwrap().get(*peer).cloned() };
if let Some(session) = session {
session.lock().unwrap().disconnect(DisconnectReason::DisconnectRequested);
session.lock().unwrap().disconnect(io, DisconnectReason::DisconnectRequested);
}
trace!(target: "network", "Disconnect requested {}", peer);
self.kill_connection(*peer, io, false);
@@ -882,7 +1016,7 @@ impl<Message> IoHandler<NetworkIoMessage<Message>> for Host<Message> where Messa
NetworkIoMessage::DisablePeer(ref peer) => {
let session = { self.sessions.read().unwrap().get(*peer).cloned() };
if let Some(session) = session {
session.lock().unwrap().disconnect(DisconnectReason::DisconnectRequested);
session.lock().unwrap().disconnect(io, DisconnectReason::DisconnectRequested);
if let Some(id) = session.lock().unwrap().id() {
self.nodes.write().unwrap().mark_as_useless(id)
}
@@ -891,8 +1025,9 @@ impl<Message> IoHandler<NetworkIoMessage<Message>> for Host<Message> where Messa
self.kill_connection(*peer, io, false);
},
NetworkIoMessage::User(ref message) => {
let reserved = self.reserved_nodes.read().unwrap();
for (p, h) in self.handlers.read().unwrap().iter() {
h.message(&NetworkContext::new(io, p, None, self.sessions.clone()), &message);
h.message(&NetworkContext::new(io, p, None, self.sessions.clone(), &reserved), &message);
}
}
}
@@ -907,7 +1042,7 @@ impl<Message> IoHandler<NetworkIoMessage<Message>> for Host<Message> where Messa
}
}
DISCOVERY => self.discovery.lock().unwrap().as_ref().unwrap().register_socket(event_loop).expect("Error registering discovery socket"),
TCP_ACCEPT => event_loop.register(self.tcp_listener.lock().unwrap().deref(), Token(TCP_ACCEPT), EventSet::all(), PollOpt::edge()).expect("Error registering stream"),
TCP_ACCEPT => event_loop.register(&*self.tcp_listener.lock().unwrap(), Token(TCP_ACCEPT), EventSet::all(), PollOpt::edge()).expect("Error registering stream"),
_ => warn!("Unexpected stream registration")
}
}
@@ -935,7 +1070,7 @@ impl<Message> IoHandler<NetworkIoMessage<Message>> for Host<Message> where Messa
}
}
DISCOVERY => self.discovery.lock().unwrap().as_ref().unwrap().update_registration(event_loop).expect("Error reregistering discovery socket"),
TCP_ACCEPT => event_loop.reregister(self.tcp_listener.lock().unwrap().deref(), Token(TCP_ACCEPT), EventSet::all(), PollOpt::edge()).expect("Error reregistering stream"),
TCP_ACCEPT => event_loop.reregister(&*self.tcp_listener.lock().unwrap(), Token(TCP_ACCEPT), EventSet::all(), PollOpt::edge()).expect("Error reregistering stream"),
_ => warn!("Unexpected stream update")
}
}
@@ -948,13 +1083,17 @@ fn save_key(path: &Path, key: &Secret) {
return;
};
path_buf.push("key");
let mut file = match fs::File::create(path_buf.as_path()) {
let path = path_buf.as_path();
let mut file = match fs::File::create(&path) {
Ok(file) => file,
Err(e) => {
warn!("Error creating key file: {:?}", e);
return;
}
};
if let Err(e) = restrict_permissions_owner(&path) {
warn!(target: "network", "Failed to modify permissions of the file (chmod: {})", e);
}
if let Err(e) = file.write(&key.hex().into_bytes()) {
warn!("Error writing key file: {:?}", e);
}
@@ -1003,6 +1142,6 @@ fn host_client_url() {
let mut config = NetworkConfiguration::new();
let key = h256_from_hex("6f7b0d801bc7b5ce7bbd930b84fd0369b3eb25d09be58d64ba811091046f3aa2");
config.use_secret = Some(key);
let host: Host<u32> = Host::new(config).unwrap();
let host: Host<u32> = Host::new(config, Arc::new(NetworkStats::new())).unwrap();
assert!(host.local_url().starts_with("enode://101b3ef5a4ea7a1c7928e24c4c75fd053c235d7b80c22ae5c03d145d0ac7396e2a4ffff9adee3133a7b05044a5cee08115fd65145e5165d646bde371010d803c@"));
}

View File

@@ -14,8 +14,8 @@
// You should have received a copy of the GNU General Public License
// along with Parity. If not, see <http://www.gnu.org/licenses/>.
//! Network and general IO module.
//!
//! Network and general IO module.
//!
//! Example usage for craeting a network service and adding an IO handler:
//!
//! ```rust
@@ -56,8 +56,9 @@
//! }
//!
//! fn main () {
//! let mut service = NetworkService::<MyMessage>::start(NetworkConfiguration::new_local()).expect("Error creating network service");
//! let mut service = NetworkService::<MyMessage>::new(NetworkConfiguration::new_local()).expect("Error creating network service");
//! service.register_protocol(Arc::new(MyHandler), "myproto", &[1u8]);
//! service.start().expect("Error starting service");
//!
//! // Wait for quit condition
//! // ...
@@ -111,3 +112,22 @@ pub trait NetworkProtocolHandler<Message>: Sync + Send where Message: Send + Syn
fn message(&self, _io: &NetworkContext<Message>, _message: &Message) {}
}
/// Non-reserved peer modes.
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum NonReservedPeerMode {
/// Accept them. This is the default.
Accept,
/// Deny them.
Deny,
}
impl NonReservedPeerMode {
/// Attempt to parse the peer mode from a string.
pub fn parse(s: &str) -> Option<Self> {
match s {
"accept" => Some(NonReservedPeerMode::Accept),
"deny" => Some(NonReservedPeerMode::Deny),
_ => None,
}
}
}

View File

@@ -18,9 +18,9 @@ use std::sync::*;
use error::*;
use panics::*;
use network::{NetworkProtocolHandler, NetworkConfiguration};
use network::error::{NetworkError};
use network::error::NetworkError;
use network::host::{Host, NetworkIoMessage, ProtocolId};
use network::stats::{NetworkStats};
use network::stats::NetworkStats;
use io::*;
/// IO Service with networking
@@ -28,33 +28,33 @@ use io::*;
pub struct NetworkService<Message> where Message: Send + Sync + Clone + 'static {
io_service: IoService<NetworkIoMessage<Message>>,
host_info: String,
host: Arc<Host<Message>>,
host: RwLock<Option<Arc<Host<Message>>>>,
stats: Arc<NetworkStats>,
panic_handler: Arc<PanicHandler>
panic_handler: Arc<PanicHandler>,
config: NetworkConfiguration,
}
impl<Message> NetworkService<Message> where Message: Send + Sync + Clone + 'static {
/// Starts IO event loop
pub fn start(config: NetworkConfiguration) -> Result<NetworkService<Message>, UtilError> {
pub fn new(config: NetworkConfiguration) -> Result<NetworkService<Message>, UtilError> {
let panic_handler = PanicHandler::new_in_arc();
let mut io_service = try!(IoService::<NetworkIoMessage<Message>>::start());
let io_service = try!(IoService::<NetworkIoMessage<Message>>::start());
panic_handler.forward_from(&io_service);
let host = Arc::new(try!(Host::new(config)));
let stats = host.stats().clone();
let host_info = host.client_version();
try!(io_service.register_handler(host.clone()));
let stats = Arc::new(NetworkStats::new());
let host_info = Host::<Message>::client_version();
Ok(NetworkService {
io_service: io_service,
host_info: host_info,
stats: stats,
panic_handler: panic_handler,
host: host,
host: RwLock::new(None),
config: config,
})
}
/// Regiter a new protocol handler with the event loop.
pub fn register_protocol(&mut self, handler: Arc<NetworkProtocolHandler<Message>+Send + Sync>, protocol: ProtocolId, versions: &[u8]) -> Result<(), NetworkError> {
pub fn register_protocol(&self, handler: Arc<NetworkProtocolHandler<Message>+Send + Sync>, protocol: ProtocolId, versions: &[u8]) -> Result<(), NetworkError> {
try!(self.io_service.send_message(NetworkIoMessage::AddHandler {
handler: handler,
protocol: protocol,
@@ -69,8 +69,8 @@ impl<Message> NetworkService<Message> where Message: Send + Sync + Clone + 'stat
}
/// Returns underlying io service.
pub fn io(&mut self) -> &mut IoService<NetworkIoMessage<Message>> {
&mut self.io_service
pub fn io(&self) -> &IoService<NetworkIoMessage<Message>> {
&self.io_service
}
/// Returns network statistics.
@@ -80,12 +80,65 @@ impl<Message> NetworkService<Message> where Message: Send + Sync + Clone + 'stat
/// Returns external url if available.
pub fn external_url(&self) -> Option<String> {
self.host.external_url()
let host = self.host.read().unwrap();
host.as_ref().and_then(|h| h.external_url())
}
/// Returns external url if available.
pub fn local_url(&self) -> String {
self.host.local_url()
pub fn local_url(&self) -> Option<String> {
let host = self.host.read().unwrap();
host.as_ref().map(|h| h.local_url())
}
/// Start network IO
pub fn start(&self) -> Result<(), UtilError> {
let mut host = self.host.write().unwrap();
if host.is_none() {
let h = Arc::new(try!(Host::new(self.config.clone(), self.stats.clone())));
try!(self.io_service.register_handler(h.clone()));
*host = Some(h);
}
Ok(())
}
/// Stop network IO
pub fn stop(&self) -> Result<(), UtilError> {
let mut host = self.host.write().unwrap();
if let Some(ref host) = *host {
let io = IoContext::new(self.io_service.channel(), 0); //TODO: take token id from host
try!(host.stop(&io));
}
*host = None;
Ok(())
}
/// Try to add a reserved peer.
pub fn add_reserved_peer(&self, peer: &str) -> Result<(), UtilError> {
let host = self.host.read().unwrap();
if let Some(ref host) = *host {
host.add_reserved_node(peer)
} else {
Ok(())
}
}
/// Try to remove a reserved peer.
pub fn remove_reserved_peer(&self, peer: &str) -> Result<(), UtilError> {
let host = self.host.read().unwrap();
if let Some(ref host) = *host {
host.remove_reserved_node(peer)
} else {
Ok(())
}
}
/// Set the non-reserved peer mode.
pub fn set_non_reserved_mode(&self, mode: ::network::NonReservedPeerMode) {
let host = self.host.read().unwrap();
if let Some(ref host) = *host {
let io_ctxt = IoContext::new(self.io_service.channel(), 0);
host.set_non_reserved_mode(mode, &io_ctxt);
}
}
}

View File

@@ -68,6 +68,8 @@ pub enum SessionData {
/// Zero based packet ID
packet_id: u8,
},
/// Session has more data to be read
Continue,
}
/// Shared session information
@@ -145,7 +147,7 @@ impl Session {
})
}
fn complete_handshake(&mut self, host: &HostInfo) -> Result<(), UtilError> {
fn complete_handshake<Message>(&mut self, io: &IoContext<Message>, host: &HostInfo) -> Result<(), UtilError> where Message: Send + Sync + Clone {
let connection = if let State::Handshake(ref mut h) = self.state {
self.info.id = Some(h.id.clone());
try!(EncryptedConnection::new(h))
@@ -153,8 +155,8 @@ impl Session {
panic!("Unexpected state");
};
self.state = State::Session(connection);
try!(self.write_hello(host));
try!(self.send_ping());
try!(self.write_hello(io, host));
try!(self.send_ping(io));
Ok(())
}
@@ -220,10 +222,11 @@ impl Session {
}
}
if let Some(data) = packet_data {
return Ok(try!(self.read_packet(data, host)));
return Ok(try!(self.read_packet(io, data, host)));
}
if create_session {
try!(self.complete_handshake(host));
try!(self.complete_handshake(io, host));
io.update_registration(self.token()).unwrap_or_else(|e| debug!(target: "network", "Token registration error: {:?}", e));
}
Ok(SessionData::None)
}
@@ -263,7 +266,8 @@ impl Session {
}
/// Send a protocol packet to peer.
pub fn send_packet(&mut self, protocol: &str, packet_id: u8, data: &[u8]) -> Result<(), UtilError> {
pub fn send_packet<Message>(&mut self, io: &IoContext<Message>, protocol: &str, packet_id: u8, data: &[u8]) -> Result<(), UtilError>
where Message: Send + Sync + Clone {
if self.info.capabilities.is_empty() || !self.had_hello {
debug!(target: "network", "Sending to unconfirmed session {}, protocol: {}, packet: {}", self.token(), protocol, packet_id);
return Err(From::from(NetworkError::BadProtocol));
@@ -283,7 +287,7 @@ impl Session {
let mut rlp = RlpStream::new();
rlp.append(&(pid as u32));
rlp.append_raw(data, 1);
self.send(rlp)
self.send(io, rlp)
}
/// Keep this session alive. Returns false if ping timeout happened
@@ -298,10 +302,9 @@ impl Session {
};
if !timed_out && time::precise_time_ns() - self.ping_time_ns > PING_INTERVAL_SEC * 1000_000_000 {
if let Err(e) = self.send_ping() {
if let Err(e) = self.send_ping(io) {
debug!("Error sending ping message: {:?}", e);
}
io.update_registration(self.token()).unwrap_or_else(|e| debug!(target: "network", "Session registration error: {:?}", e));
}
!timed_out
}
@@ -310,7 +313,8 @@ impl Session {
self.connection().token()
}
fn read_packet(&mut self, packet: Packet, host: &HostInfo) -> Result<SessionData, UtilError> {
fn read_packet<Message>(&mut self, io: &IoContext<Message>, packet: Packet, host: &HostInfo) -> Result<SessionData, UtilError>
where Message: Send + Sync + Clone {
if packet.data.len() < 2 {
return Err(From::from(NetworkError::BadProtocol));
}
@@ -321,22 +325,25 @@ impl Session {
match packet_id {
PACKET_HELLO => {
let rlp = UntrustedRlp::new(&packet.data[1..]); //TODO: validate rlp expected size
try!(self.read_hello(&rlp, host));
try!(self.read_hello(io, &rlp, host));
Ok(SessionData::Ready)
},
PACKET_DISCONNECT => {
let rlp = UntrustedRlp::new(&packet.data[1..]);
let reason: u8 = try!(rlp.val_at(0));
if self.had_hello {
debug!("Disconnected: {}: {:?}", self.token(), DisconnectReason::from_u8(reason));
}
Err(From::from(NetworkError::Disconnect(DisconnectReason::from_u8(reason))))
}
PACKET_PING => {
try!(self.send_pong());
Ok(SessionData::None)
try!(self.send_pong(io));
Ok(SessionData::Continue)
},
PACKET_PONG => {
self.pong_time_ns = Some(time::precise_time_ns());
self.info.ping_ms = Some((self.pong_time_ns.unwrap() - self.ping_time_ns) / 1000_000);
Ok(SessionData::None)
Ok(SessionData::Continue)
},
PACKET_GET_PEERS => Ok(SessionData::None), //TODO;
PACKET_PEERS => Ok(SessionData::None),
@@ -346,7 +353,7 @@ impl Session {
i += 1;
if i == self.info.capabilities.len() {
debug!(target: "network", "Unknown packet: {:?}", packet_id);
return Ok(SessionData::None)
return Ok(SessionData::Continue)
}
}
@@ -357,12 +364,12 @@ impl Session {
},
_ => {
debug!(target: "network", "Unknown packet: {:?}", packet_id);
Ok(SessionData::None)
Ok(SessionData::Continue)
}
}
}
fn write_hello(&mut self, host: &HostInfo) -> Result<(), UtilError> {
fn write_hello<Message>(&mut self, io: &IoContext<Message>, host: &HostInfo) -> Result<(), UtilError> where Message: Send + Sync + Clone {
let mut rlp = RlpStream::new();
rlp.append_raw(&[PACKET_HELLO as u8], 0);
rlp.begin_list(5)
@@ -371,10 +378,11 @@ impl Session {
.append(&host.capabilities)
.append(&host.local_endpoint.address.port())
.append(host.id());
self.send(rlp)
self.send(io, rlp)
}
fn read_hello(&mut self, rlp: &UntrustedRlp, host: &HostInfo) -> Result<(), UtilError> {
fn read_hello<Message>(&mut self, io: &IoContext<Message>, rlp: &UntrustedRlp, host: &HostInfo) -> Result<(), UtilError>
where Message: Send + Sync + Clone {
let protocol = try!(rlp.val_at::<u32>(0));
let client_version = try!(rlp.val_at::<String>(1));
let peer_caps = try!(rlp.val_at::<Vec<PeerCapabilityInfo>>(2));
@@ -417,36 +425,36 @@ impl Session {
self.info.capabilities = caps;
if self.info.capabilities.is_empty() {
trace!(target: "network", "No common capabilities with peer.");
return Err(From::from(self.disconnect(DisconnectReason::UselessPeer)));
return Err(From::from(self.disconnect(io, DisconnectReason::UselessPeer)));
}
if protocol != host.protocol_version {
trace!(target: "network", "Peer protocol version mismatch: {}", protocol);
return Err(From::from(self.disconnect(DisconnectReason::UselessPeer)));
return Err(From::from(self.disconnect(io, DisconnectReason::UselessPeer)));
}
self.had_hello = true;
Ok(())
}
/// Senf ping packet
pub fn send_ping(&mut self) -> Result<(), UtilError> {
try!(self.send(try!(Session::prepare(PACKET_PING))));
pub fn send_ping<Message>(&mut self, io: &IoContext<Message>) -> Result<(), UtilError> where Message: Send + Sync + Clone {
try!(self.send(io, try!(Session::prepare(PACKET_PING))));
self.ping_time_ns = time::precise_time_ns();
self.pong_time_ns = None;
Ok(())
}
fn send_pong(&mut self) -> Result<(), UtilError> {
self.send(try!(Session::prepare(PACKET_PONG)))
fn send_pong<Message>(&mut self, io: &IoContext<Message>) -> Result<(), UtilError> where Message: Send + Sync + Clone {
self.send(io, try!(Session::prepare(PACKET_PONG)))
}
/// Disconnect this session
pub fn disconnect(&mut self, reason: DisconnectReason) -> NetworkError {
pub fn disconnect<Message>(&mut self, io: &IoContext<Message>, reason: DisconnectReason) -> NetworkError where Message: Send + Sync + Clone {
if let State::Session(_) = self.state {
let mut rlp = RlpStream::new();
rlp.append(&(PACKET_DISCONNECT as u32));
rlp.begin_list(1);
rlp.append(&(reason as u32));
self.send(rlp).ok();
self.send(io, rlp).ok();
}
NetworkError::Disconnect(reason)
}
@@ -458,13 +466,13 @@ impl Session {
Ok(rlp)
}
fn send(&mut self, rlp: RlpStream) -> Result<(), UtilError> {
fn send<Message>(&mut self, io: &IoContext<Message>, rlp: RlpStream) -> Result<(), UtilError> where Message: Send + Sync + Clone {
match self.state {
State::Handshake(_) => {
warn!(target:"network", "Unexpected send request");
},
State::Session(ref mut s) => {
try!(s.send_packet(&rlp.out()))
try!(s.send_packet(io, &rlp.out()))
},
}
Ok(())

View File

@@ -65,7 +65,7 @@ impl NetworkStats {
self.sessions.load(Ordering::Relaxed)
}
#[cfg(test)]
/// Create a new empty instance.
pub fn new() -> NetworkStats {
NetworkStats {
recv: AtomicUsize::new(0),

View File

@@ -97,7 +97,8 @@ impl NetworkProtocolHandler<TestProtocolMessage> for TestProtocol {
#[test]
fn net_service() {
let mut service = NetworkService::<TestProtocolMessage>::start(NetworkConfiguration::new_local()).expect("Error creating network service");
let service = NetworkService::<TestProtocolMessage>::new(NetworkConfiguration::new_local()).expect("Error creating network service");
service.start().unwrap();
service.register_protocol(Arc::new(TestProtocol::new(false)), "myproto", &[1u8]).unwrap();
}
@@ -108,12 +109,14 @@ fn net_connect() {
let mut config1 = NetworkConfiguration::new_local();
config1.use_secret = Some(key1.secret().clone());
config1.boot_nodes = vec![ ];
let mut service1 = NetworkService::<TestProtocolMessage>::start(config1).unwrap();
let mut service1 = NetworkService::<TestProtocolMessage>::new(config1).unwrap();
service1.start().unwrap();
let handler1 = TestProtocol::register(&mut service1, false);
let mut config2 = NetworkConfiguration::new_local();
info!("net_connect: local URL: {}", service1.local_url());
config2.boot_nodes = vec![ service1.local_url() ];
let mut service2 = NetworkService::<TestProtocolMessage>::start(config2).unwrap();
info!("net_connect: local URL: {}", service1.local_url().unwrap());
config2.boot_nodes = vec![ service1.local_url().unwrap() ];
let mut service2 = NetworkService::<TestProtocolMessage>::new(config2).unwrap();
service2.start().unwrap();
let handler2 = TestProtocol::register(&mut service2, false);
while !handler1.got_packet() && !handler2.got_packet() && (service1.stats().sessions() == 0 || service2.stats().sessions() == 0) {
thread::sleep(Duration::from_millis(50));
@@ -122,17 +125,28 @@ fn net_connect() {
assert!(service2.stats().sessions() >= 1);
}
#[test]
fn net_start_stop() {
let config = NetworkConfiguration::new_local();
let service = NetworkService::<TestProtocolMessage>::new(config).unwrap();
service.start().unwrap();
service.stop().unwrap();
service.start().unwrap();
}
#[test]
fn net_disconnect() {
let key1 = KeyPair::create().unwrap();
let mut config1 = NetworkConfiguration::new_local();
config1.use_secret = Some(key1.secret().clone());
config1.boot_nodes = vec![ ];
let mut service1 = NetworkService::<TestProtocolMessage>::start(config1).unwrap();
let mut service1 = NetworkService::<TestProtocolMessage>::new(config1).unwrap();
service1.start().unwrap();
let handler1 = TestProtocol::register(&mut service1, false);
let mut config2 = NetworkConfiguration::new_local();
config2.boot_nodes = vec![ service1.local_url() ];
let mut service2 = NetworkService::<TestProtocolMessage>::start(config2).unwrap();
config2.boot_nodes = vec![ service1.local_url().unwrap() ];
let mut service2 = NetworkService::<TestProtocolMessage>::new(config2).unwrap();
service2.start().unwrap();
let handler2 = TestProtocol::register(&mut service2, true);
while !(handler1.got_disconnect() && handler2.got_disconnect()) {
thread::sleep(Duration::from_millis(50));
@@ -144,7 +158,8 @@ fn net_disconnect() {
#[test]
fn net_timeout() {
let config = NetworkConfiguration::new_local();
let mut service = NetworkService::<TestProtocolMessage>::start(config).unwrap();
let mut service = NetworkService::<TestProtocolMessage>::new(config).unwrap();
service.start().unwrap();
let handler = TestProtocol::register(&mut service, false);
while !handler.got_timeout() {
thread::sleep(Duration::from_millis(50));

View File

@@ -92,7 +92,7 @@ impl OverlayDB {
///
/// Returns either an error or the number of items changed in the backing database.
///
/// Will return an error if the number of `kill()`s ever exceeds the number of
/// Will return an error if the number of `remove()`s ever exceeds the number of
/// `insert()`s for any key. This will leave the database in an undeterminate
/// state. Don't ever let it happen.
///
@@ -104,15 +104,15 @@ impl OverlayDB {
/// fn main() {
/// let mut m = OverlayDB::new_temp();
/// let key = m.insert(b"foo"); // insert item.
/// assert!(m.exists(&key)); // key exists (in memory).
/// assert!(m.contains(&key)); // key exists (in memory).
/// assert_eq!(m.commit().unwrap(), 1); // 1 item changed.
/// assert!(m.exists(&key)); // key still exists (in backing).
/// m.kill(&key); // delete item.
/// assert!(!m.exists(&key)); // key "doesn't exist" (though still does in backing).
/// m.kill(&key); // oh dear... more kills than inserts for the key...
/// assert!(m.contains(&key)); // key still exists (in backing).
/// m.remove(&key); // delete item.
/// assert!(!m.contains(&key)); // key "doesn't exist" (though still does in backing).
/// m.remove(&key); // oh dear... more removes than inserts for the key...
/// //m.commit().unwrap(); // this commit/unwrap would cause a panic.
/// m.revert(); // revert both kills.
/// assert!(m.exists(&key)); // key now still exists.
/// m.revert(); // revert both removes.
/// assert!(m.contains(&key)); // key now still exists.
/// }
/// ```
pub fn commit(&mut self) -> Result<u32, UtilError> {
@@ -224,7 +224,7 @@ impl HashDB for OverlayDB {
}
ret
}
fn lookup(&self, key: &H256) -> Option<&[u8]> {
fn get(&self, key: &H256) -> Option<&[u8]> {
// return ok if positive; if negative, check backing - might be enough references there to make
// it positive again.
let k = self.overlay.raw(key);
@@ -249,7 +249,7 @@ impl HashDB for OverlayDB {
}
}
}
fn exists(&self, key: &H256) -> bool {
fn contains(&self, key: &H256) -> bool {
// return ok if positive; if negative, check backing - might be enough references there to make
// it positive again.
let k = self.overlay.raw(key);
@@ -271,7 +271,7 @@ impl HashDB for OverlayDB {
}
fn insert(&mut self, value: &[u8]) -> H256 { self.overlay.insert(value) }
fn emplace(&mut self, key: H256, value: Bytes) { self.overlay.emplace(key, value); }
fn kill(&mut self, key: &H256) { self.overlay.kill(key); }
fn remove(&mut self, key: &H256) { self.overlay.remove(key); }
}
#[test]

View File

@@ -15,6 +15,7 @@
// along with Parity. If not, see <http://www.gnu.org/licenses/>.
//! Path utilities
use std::path::Path;
/// Default ethereum paths
pub mod ethereum {
@@ -62,3 +63,21 @@ pub mod ethereum {
pth
}
}
/// Restricts the permissions of given path only to the owner.
#[cfg(not(windows))]
pub fn restrict_permissions_owner(file_path: &Path) -> Result<(), i32> {
let cstr = ::std::ffi::CString::new(file_path.to_str().unwrap()).unwrap();
match unsafe { ::libc::chmod(cstr.as_ptr(), ::libc::S_IWUSR | ::libc::S_IRUSR) } {
0 => Ok(()),
x => Err(x),
}
}
/// Restricts the permissions of given path only to the owner.
#[cfg(windows)]
pub fn restrict_permissions_owner(_file_path: &Path) -> Result<(), i32> {
//TODO: implement me
Ok(())
}

View File

@@ -258,14 +258,7 @@ impl <T>FromBytes for T where T: FixedHash {
Ordering::Equal => ()
};
unsafe {
use std::{mem, ptr};
let mut res: T = mem::uninitialized();
ptr::copy(bytes.as_ptr(), res.as_slice_mut().as_mut_ptr(), T::len());
Ok(res)
}
Ok(T::from_slice(bytes))
}
}

View File

@@ -429,3 +429,10 @@ fn test_rlp_nested_empty_list_encode() {
assert_eq!(stream.drain()[..], [0xc2u8, 0xc0u8, 40u8][..]);
}
#[test]
fn test_rlp_list_length_overflow() {
let data: Vec<u8> = vec![0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0x00, 0x00, 0x00];
let rlp = UntrustedRlp::new(&data);
let as_val: Result<String, DecoderError> = rlp.val_at(0);
assert_eq!(Err(DecoderError::RlpIsTooShort), as_val);
}

View File

@@ -334,9 +334,9 @@ impl<'a> BasicDecoder<'a> {
/// Return first item info.
fn payload_info(bytes: &[u8]) -> Result<PayloadInfo, DecoderError> {
let item = try!(PayloadInfo::from(bytes));
match item.header_len + item.value_len <= bytes.len() {
true => Ok(item),
false => Err(DecoderError::RlpIsTooShort),
match item.header_len.checked_add(item.value_len) {
Some(x) if x <= bytes.len() => Ok(item),
_ => Err(DecoderError::RlpIsTooShort),
}
}
}

51
util/src/timer.rs Normal file
View File

@@ -0,0 +1,51 @@
// Copyright 2015, 2016 Ethcore (UK) Ltd.
// This file is part of Parity.
// Parity is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
// Parity is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
// You should have received a copy of the GNU General Public License
// along with Parity. If not, see <http://www.gnu.org/licenses/>.
//! Performance timer with logging
use time::precise_time_ns;
/// Performance timer with logging. Starts measuring time in the constructor, prints
/// elapsed time in the destructor or when `stop` is called.
pub struct PerfTimer {
name: &'static str,
start: u64,
stopped: bool,
}
impl PerfTimer {
/// Create an instance with given name.
pub fn new(name: &'static str) -> PerfTimer {
PerfTimer {
name: name,
start: precise_time_ns(),
stopped: false,
}
}
/// Stop the timer and print elapsed time on trace level with `perf` target.
pub fn stop(&mut self) {
if !self.stopped {
trace!(target: "perf", "{}: {:.2}ms", self.name, (precise_time_ns() - self.start) as f32 / 1000_000.0);
self.stopped = true;
}
}
}
impl Drop for PerfTimer {
fn drop(&mut self) {
self.stop()
}
}

View File

@@ -133,7 +133,7 @@ impl<'db> TrieDB<'db> {
/// Get the data of the root node.
fn root_data(&self) -> &[u8] {
self.db.lookup(&self.root).expect("Trie root not found!")
self.db.get(&self.root).expect("Trie root not found!")
}
/// Get the root node as a `Node`.
@@ -184,7 +184,7 @@ impl<'db> TrieDB<'db> {
/// Return optional data for a key given as a `NibbleSlice`. Returns `None` if no data exists.
fn do_lookup<'a, 'key>(&'a self, key: &NibbleSlice<'key>) -> Option<&'a [u8]> where 'a: 'key {
let root_rlp = self.db.lookup(&self.root).expect("Trie root not found!");
let root_rlp = self.db.get(&self.root).expect("Trie root not found!");
self.get_from_node(&root_rlp, key)
}
@@ -213,7 +213,7 @@ impl<'db> TrieDB<'db> {
// check if its sha3 + len
let r = Rlp::new(node);
match r.is_data() && r.size() == 32 {
true => self.db.lookup(&r.as_val::<H256>()).unwrap_or_else(|| panic!("Not found! {:?}", r.as_val::<H256>())),
true => self.db.get(&r.as_val::<H256>()).unwrap_or_else(|| panic!("Not found! {:?}", r.as_val::<H256>())),
false => node
}
}
@@ -349,7 +349,7 @@ impl<'db> Trie for TrieDB<'db> {
impl<'db> fmt::Debug for TrieDB<'db> {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
try!(writeln!(f, "c={:?} [", self.hash_count));
let root_rlp = self.db.lookup(&self.root).expect("Trie root not found!");
let root_rlp = self.db.get(&self.root).expect("Trie root not found!");
try!(self.fmt_all(Node::decoded(root_rlp), f, 0));
writeln!(f, "]")
}

View File

@@ -87,7 +87,7 @@ impl<'db> TrieDBMut<'db> {
/// Create a new trie with the backing database `db` and `root`.
/// Returns an error if `root` does not exist.
pub fn from_existing(db: &'db mut HashDB, root: &'db mut H256) -> Result<Self, TrieError> {
if !db.exists(root) {
if !db.contains(root) {
Err(TrieError::InvalidStateRoot)
} else {
Ok(TrieDBMut {
@@ -143,7 +143,7 @@ impl<'db> TrieDBMut<'db> {
/// Set the trie to a new root node's RLP, inserting the new RLP into the backing database
/// and removing the old.
fn set_root_rlp(&mut self, root_data: &[u8]) {
self.db.kill(&self.root);
self.db.remove(&self.root);
*self.root = self.db.insert(root_data);
self.hash_count += 1;
trace!("set_root_rlp {:?} {:?}", root_data.pretty(), self.root);
@@ -174,7 +174,7 @@ impl<'db> TrieDBMut<'db> {
/// Get the root node's RLP.
fn root_node(&self) -> Node {
Node::decoded(self.db.lookup(&self.root).expect("Trie root not found!"))
Node::decoded(self.db.get(&self.root).expect("Trie root not found!"))
}
/// Get the root node as a `Node`.
@@ -225,7 +225,7 @@ impl<'db> TrieDBMut<'db> {
/// Return optional data for a key given as a `NibbleSlice`. Returns `None` if no data exists.
fn do_lookup<'a, 'key>(&'a self, key: &NibbleSlice<'key>) -> Option<&'a [u8]> where 'a: 'key {
let root_rlp = self.db.lookup(&self.root).expect("Trie root not found!");
let root_rlp = self.db.get(&self.root).expect("Trie root not found!");
self.get_from_node(&root_rlp, key)
}
@@ -254,7 +254,7 @@ impl<'db> TrieDBMut<'db> {
// check if its sha3 + len
let r = Rlp::new(node);
match r.is_data() && r.size() == 32 {
true => self.db.lookup(&r.as_val::<H256>()).expect("Not found!"),
true => self.db.get(&r.as_val::<H256>()).expect("Not found!"),
false => node
}
}
@@ -266,7 +266,7 @@ impl<'db> TrieDBMut<'db> {
trace!("ADD: {:?} {:?}", key, value.pretty());
// determine what the new root is, insert new nodes and remove old as necessary.
let mut todo: Journal = Journal::new();
let root_rlp = self.augmented(self.db.lookup(&self.root).expect("Trie root not found!"), key, value, &mut todo);
let root_rlp = self.augmented(self.db.get(&self.root).expect("Trie root not found!"), key, value, &mut todo);
self.apply(todo);
self.set_root_rlp(&root_rlp);
trace!("/");
@@ -279,7 +279,7 @@ impl<'db> TrieDBMut<'db> {
trace!("DELETE: {:?}", key);
// determine what the new root is, insert new nodes and remove old as necessary.
let mut todo: Journal = Journal::new();
match self.cleared_from_slice(self.db.lookup(&self.root).expect("Trie root not found!"), key, &mut todo) {
match self.cleared_from_slice(self.db.get(&self.root).expect("Trie root not found!"), key, &mut todo) {
Some(root_rlp) => {
self.apply(todo);
self.set_root_rlp(&root_rlp);
@@ -335,7 +335,7 @@ impl<'db> TrieDBMut<'db> {
}
else if rlp.is_data() && rlp.size() == 32 {
let h = rlp.as_val();
let r = self.db.lookup(&h).unwrap_or_else(||{
let r = self.db.get(&h).unwrap_or_else(||{
println!("Node not found! rlp={:?}, node_hash={:?}", rlp.as_raw().pretty(), h);
println!("Journal: {:?}", journal);
panic!();
@@ -670,7 +670,7 @@ impl<'db> TrieMut for TrieDBMut<'db> {
impl<'db> fmt::Debug for TrieDBMut<'db> {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
try!(writeln!(f, "c={:?} [", self.hash_count));
let root_rlp = self.db.lookup(&self.root).expect("Trie root not found!");
let root_rlp = self.db.get(&self.root).expect("Trie root not found!");
try!(self.fmt_all(Node::decoded(root_rlp), f, 0));
writeln!(f, "]")
}