diff --git a/.gitignore b/.gitignore index 959045cf9..bc882627b 100644 --- a/.gitignore +++ b/.gitignore @@ -15,3 +15,6 @@ Cargo.lock # Vim *.swp + +# GDB +*.gdb_history diff --git a/Cargo.toml b/Cargo.toml index 486cc6a8c..9da1b73c9 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -11,13 +11,13 @@ log = "0.3" env_logger = "0.3" rustc-serialize = "0.3" arrayvec = "0.3" -mio = "0.4.4" +mio = "0.5.0" rand = "0.3.12" time = "0.1.34" tiny-keccak = "1.0" rocksdb = "0.2" -lazy_static = "0.1.*" -secp256k1 = "0.5.1" +lazy_static = "0.1" +eth-secp256k1 = { git = "https://github.com/arkpar/rust-secp256k1.git" } rust-crypto = "0.2.34" elastic-array = "0.4" heapsize = "0.2" diff --git a/src/bytes.rs b/src/bytes.rs index 5ffc72c08..8d33a4108 100644 --- a/src/bytes.rs +++ b/src/bytes.rs @@ -1,27 +1,27 @@ //! Unified interfaces for bytes operations on basic types -//! +//! //! # Examples //! ```rust //! extern crate ethcore_util as util; -//! +//! //! fn bytes_convertable() { //! use util::bytes::BytesConvertable; //! //! let arr = [0; 5]; //! let slice: &[u8] = arr.bytes(); //! } -//! +//! //! fn to_bytes() { //! use util::bytes::ToBytes; -//! +//! //! let a: Vec = "hello_world".to_bytes(); //! let b: Vec = 400u32.to_bytes(); //! let c: Vec = 0xffffffffffffffffu64.to_bytes(); //! } -//! +//! //! fn from_bytes() { //! use util::bytes::FromBytes; -//! +//! //! let a = String::from_bytes(&[b'd', b'o', b'g']); //! let b = u16::from_bytes(&[0xfa]); //! let c = u64::from_bytes(&[0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff]); @@ -142,7 +142,7 @@ impl <'a> ToBytes for &'a str { fn to_bytes(&self) -> Vec { From::from(*self) } - + fn to_bytes_len(&self) -> usize { self.len() } } @@ -151,7 +151,7 @@ impl ToBytes for String { let s: &str = self.as_ref(); From::from(s) } - + fn to_bytes_len(&self) -> usize { self.len() } } @@ -170,6 +170,14 @@ impl ToBytes for u64 { fn to_bytes_len(&self) -> usize { 8 - self.leading_zeros() as usize / 8 } } +impl ToBytes for bool { + fn to_bytes(&self) -> Vec { + vec![ if *self { 1u8 } else { 0u8 } ] + } + + fn to_bytes_len(&self) -> usize { 1 } +} + macro_rules! impl_map_to_bytes { ($from: ident, $to: ty) => { impl ToBytes for $from { @@ -186,7 +194,7 @@ impl_map_to_bytes!(u32, u64); macro_rules! impl_uint_to_bytes { ($name: ident) => { impl ToBytes for $name { - fn to_bytes(&self) -> Vec { + fn to_bytes(&self) -> Vec { let mut res= vec![]; let count = self.to_bytes_len(); res.reserve(count); @@ -214,7 +222,7 @@ impl ToBytes for T where T: FixedHash { ptr::copy(self.bytes().as_ptr(), res.as_mut_ptr(), T::size()); res.set_len(T::size()); } - + res } } @@ -268,6 +276,17 @@ impl FromBytes for u64 { } } +impl FromBytes for bool { + fn from_bytes(bytes: &[u8]) -> FromBytesResult { + match bytes.len() { + 0 => Ok(false), + 1 => Ok(bytes[0] != 0), + _ => Err(FromBytesError::DataIsTooLong), + } + } +} + + macro_rules! impl_map_from_bytes { ($from: ident, $to: ident) => { impl FromBytes for $from { diff --git a/src/crypto.rs b/src/crypto.rs index 70b3d0d0b..a80045e55 100644 --- a/src/crypto.rs +++ b/src/crypto.rs @@ -1,6 +1,5 @@ use hash::*; -use secp256k1::Secp256k1; -use secp256k1::key; +use secp256k1::{key, Secp256k1}; use rand::os::OsRng; pub type Secret = H256; @@ -52,8 +51,8 @@ impl From<::std::io::Error> for CryptoError { #[derive(Debug, PartialEq, Eq)] /// secp256k1 Key pair /// -/// Use `create()` to create a new random key pair. -/// +/// Use `create()` to create a new random key pair. +/// /// # Example /// ```rust /// extern crate ethcore_util; @@ -62,10 +61,10 @@ impl From<::std::io::Error> for CryptoError { /// fn main() { /// let pair = KeyPair::create().unwrap(); /// let message = H256::random(); -/// let signature = sign(pair.secret(), &message).unwrap(); +/// let signature = ec::sign(pair.secret(), &message).unwrap(); /// -/// assert!(verify(pair.public(), &signature, &message).unwrap()); -/// assert_eq!(recover(&signature, &message).unwrap(), *pair.public()); +/// assert!(ec::verify(pair.public(), &signature, &message).unwrap()); +/// assert_eq!(ec::recover(&signature, &message).unwrap(), *pair.public()); /// } /// ``` pub struct KeyPair { @@ -109,63 +108,207 @@ impl KeyPair { } /// Sign a message with our secret key. - pub fn sign(&self, message: &H256) -> Result { sign(&self.secret, message) } + pub fn sign(&self, message: &H256) -> Result { ec::sign(&self.secret, message) } } -/// Recovers Public key from signed message hash. -pub fn recover(signature: &Signature, message: &H256) -> Result { - use secp256k1::*; - let context = Secp256k1::new(); - let rsig = try!(RecoverableSignature::from_compact(&context, &signature[0..64], try!(RecoveryId::from_i32(signature[64] as i32)))); - let publ = try!(context.recover(&try!(Message::from_slice(&message)), &rsig)); - let serialized = publ.serialize_vec(&context, false); - let p: Public = Public::from_slice(&serialized[1..65]); - //TODO: check if it's the zero key and fail if so. +pub mod ec { + use hash::*; + use crypto::*; + use crypto::{self}; - Ok(p) -} + /// Recovers Public key from signed message hash. + pub fn recover(signature: &Signature, message: &H256) -> Result { + use secp256k1::*; + let context = Secp256k1::new(); + let rsig = try!(RecoverableSignature::from_compact(&context, &signature[0..64], try!(RecoveryId::from_i32(signature[64] as i32)))); + let publ = try!(context.recover(&try!(Message::from_slice(&message)), &rsig)); + let serialized = publ.serialize_vec(&context, false); + let p: Public = Public::from_slice(&serialized[1..65]); + //TODO: check if it's the zero key and fail if so. + Ok(p) + } + /// Returns siganture of message hash. + pub fn sign(secret: &Secret, message: &H256) -> Result { + use secp256k1::*; + let context = Secp256k1::new(); + let sec: &key::SecretKey = unsafe { ::std::mem::transmute(secret) }; + let s = try!(context.sign_recoverable(&try!(Message::from_slice(&message)), sec)); + let (rec_id, data) = s.serialize_compact(&context); + let mut signature: crypto::Signature = unsafe { ::std::mem::uninitialized() }; + signature.clone_from_slice(&data); + signature[64] = rec_id.to_i32() as u8; + Ok(signature) + } + /// Verify signature. + pub fn verify(public: &Public, signature: &Signature, message: &H256) -> Result { + use secp256k1::*; + let context = Secp256k1::new(); + let rsig = try!(RecoverableSignature::from_compact(&context, &signature[0..64], try!(RecoveryId::from_i32(signature[64] as i32)))); + let sig = rsig.to_standard(&context); -/// Returns siganture of message hash. -pub fn sign(secret: &Secret, message: &H256) -> Result { - use secp256k1::*; - let context = Secp256k1::new(); - let sec: &key::SecretKey = unsafe { ::std::mem::transmute(secret) }; - let s = try!(context.sign_recoverable(&try!(Message::from_slice(&message)), sec)); - let (rec_id, data) = s.serialize_compact(&context); - let mut signature: ::crypto::Signature = unsafe { ::std::mem::uninitialized() }; - signature.clone_from_slice(&data); - signature[64] = rec_id.to_i32() as u8; - Ok(signature) -} + let mut pdata: [u8; 65] = [4u8; 65]; + let ptr = pdata[1..].as_mut_ptr(); + let src = public.as_ptr(); + unsafe { ::std::ptr::copy_nonoverlapping(src, ptr, 64) }; + let publ = try!(key::PublicKey::from_slice(&context, &pdata)); + match context.verify(&try!(Message::from_slice(&message)), &sig, &publ) { + Ok(_) => Ok(true), + Err(Error::IncorrectSignature) => Ok(false), + Err(x) => Err(>::from(x)) + } + } -/// Check if each component of the signature is in range. -pub fn is_valid(sig: &Signature) -> bool { - sig[64] <= 1 && - H256::from_slice(&sig[0..32]) < h256_from_hex("fffffffffffffffffffffffffffffffebaaedce6af48a03bbfd25e8cd0364141") && - H256::from_slice(&sig[32..64]) < h256_from_hex("fffffffffffffffffffffffffffffffebaaedce6af48a03bbfd25e8cd0364141") && - H256::from_slice(&sig[32..64]) >= h256_from_u64(1) && - H256::from_slice(&sig[0..32]) >= h256_from_u64(1) -} - -/// Verify signature. -pub fn verify(public: &Public, signature: &Signature, message: &H256) -> Result { - use secp256k1::*; - let context = Secp256k1::new(); - let rsig = try!(RecoverableSignature::from_compact(&context, &signature[0..64], try!(RecoveryId::from_i32(signature[64] as i32)))); - let sig = rsig.to_standard(&context); - - let mut pdata: [u8; 65] = [4u8; 65]; - let ptr = pdata[1..].as_mut_ptr(); - let src = public.as_ptr(); - unsafe { ::std::ptr::copy_nonoverlapping(src, ptr, 64) }; - let publ = try!(key::PublicKey::from_slice(&context, &pdata)); - match context.verify(&try!(Message::from_slice(&message)), &sig, &publ) { - Ok(_) => Ok(true), - Err(Error::IncorrectSignature) => Ok(false), - Err(x) => Err(>::from(x)) + /// Check if each component of the signature is in range. + pub fn is_valid(sig: &Signature) -> bool { + sig[64] <= 1 && + H256::from_slice(&sig[0..32]) < h256_from_hex("fffffffffffffffffffffffffffffffebaaedce6af48a03bbfd25e8cd0364141") && + H256::from_slice(&sig[32..64]) < h256_from_hex("fffffffffffffffffffffffffffffffebaaedce6af48a03bbfd25e8cd0364141") && + H256::from_slice(&sig[32..64]) >= h256_from_u64(1) && + H256::from_slice(&sig[0..32]) >= h256_from_u64(1) } } +pub mod ecdh { + use crypto::*; + + pub fn agree(secret: &Secret, public: &Public, ) -> Result { + use secp256k1::*; + let context = Secp256k1::new(); + let mut pdata: [u8; 65] = [4u8; 65]; + let ptr = pdata[1..].as_mut_ptr(); + let src = public.as_ptr(); + unsafe { ::std::ptr::copy_nonoverlapping(src, ptr, 64) }; + let publ = try!(key::PublicKey::from_slice(&context, &pdata)); + let sec: &key::SecretKey = unsafe { ::std::mem::transmute(secret) }; + let shared = ecdh::SharedSecret::new_raw(&context, &publ, &sec); + let s: Secret = unsafe { ::std::mem::transmute(shared) }; + Ok(s) + } +} + +pub mod ecies { + use hash::*; + use bytes::*; + use crypto::*; + + pub fn encrypt(public: &Public, plain: &[u8]) -> Result { + use ::rcrypto::digest::Digest; + use ::rcrypto::sha2::Sha256; + use ::rcrypto::hmac::Hmac; + use ::rcrypto::mac::Mac; + let r = try!(KeyPair::create()); + let z = try!(ecdh::agree(r.secret(), public)); + let mut key = [0u8; 32]; + let mut mkey = [0u8; 32]; + kdf(&z, &[0u8; 0], &mut key); + let mut hasher = Sha256::new(); + let mkey_material = &key[16..32]; + hasher.input(mkey_material); + hasher.result(&mut mkey); + let ekey = &key[0..16]; + + let mut msg = vec![0u8; (1 + 64 + 16 + plain.len() + 32)]; + msg[0] = 0x04u8; + { + let msgd = &mut msg[1..]; + r.public().copy_to(&mut msgd[0..64]); + { + let cipher = &mut msgd[(64 + 16)..(64 + 16 + plain.len())]; + aes::encrypt(ekey, &H128::new(), plain, cipher); + } + let mut hmac = Hmac::new(Sha256::new(), &mkey); + { + let cipher_iv = &msgd[64..(64 + 16 + plain.len())]; + hmac.input(cipher_iv); + } + hmac.raw_result(&mut msgd[(64 + 16 + plain.len())..]); + } + Ok(msg) + } + + pub fn decrypt(secret: &Secret, encrypted: &[u8]) -> Result { + use ::rcrypto::digest::Digest; + use ::rcrypto::sha2::Sha256; + use ::rcrypto::hmac::Hmac; + use ::rcrypto::mac::Mac; + + let meta_len = 1 + 64 + 16 + 32; + if encrypted.len() < meta_len || encrypted[0] < 2 || encrypted[0] > 4 { + return Err(CryptoError::InvalidMessage); //invalid message: publickey + } + + let e = &encrypted[1..]; + let p = Public::from_slice(&e[0..64]); + let z = try!(ecdh::agree(secret, &p)); + let mut key = [0u8; 32]; + kdf(&z, &[0u8; 0], &mut key); + let ekey = &key[0..16]; + let mkey_material = &key[16..32]; + let mut hasher = Sha256::new(); + let mut mkey = [0u8; 32]; + hasher.input(mkey_material); + hasher.result(&mut mkey); + + let clen = encrypted.len() - meta_len; + let cipher_with_iv = &e[64..(64+16+clen)]; + let cipher_iv = &cipher_with_iv[0..16]; + let cipher_no_iv = &cipher_with_iv[16..]; + let msg_mac = &e[(64+16+clen)..]; + + // Verify tag + let mut hmac = Hmac::new(Sha256::new(), &mkey); + hmac.input(cipher_with_iv); + let mut mac = H256::new(); + hmac.raw_result(&mut mac); + if &mac[..] != msg_mac { + return Err(CryptoError::InvalidMessage); + } + + let mut msg = vec![0u8; clen]; + aes::decrypt(ekey, cipher_iv, cipher_no_iv, &mut msg[..]); + Ok(msg) + } + + fn kdf(secret: &Secret, s1: &[u8], dest: &mut [u8]) { + use ::rcrypto::digest::Digest; + use ::rcrypto::sha2::Sha256; + let mut hasher = Sha256::new(); + // SEC/ISO/Shoup specify counter size SHOULD be equivalent + // to size of hash output, however, it also notes that + // the 4 bytes is okay. NIST specifies 4 bytes. + let mut ctr = 1u32; + let mut written = 0usize; + while written < dest.len() { + let ctrs = [(ctr >> 24) as u8, (ctr >> 16) as u8, (ctr >> 8) as u8, ctr as u8]; + hasher.input(&ctrs); + hasher.input(secret); + hasher.input(s1); + hasher.result(&mut dest[written..(written + 32)]); + hasher.reset(); + written += 32; + ctr += 1; + } + } +} + +pub mod aes { + use ::rcrypto::blockmodes::*; + use ::rcrypto::aessafe::*; + use ::rcrypto::symmetriccipher::*; + use ::rcrypto::buffer::*; + + pub fn encrypt(k: &[u8], iv: &[u8], plain: &[u8], dest: &mut [u8]) { + let mut encryptor = CtrMode::new(AesSafe128Encryptor::new(k), iv.to_vec()); + encryptor.encrypt(&mut RefReadBuffer::new(plain), &mut RefWriteBuffer::new(dest), true).expect("Invalid length or padding"); + } + + pub fn decrypt(k: &[u8], iv: &[u8], encrypted: &[u8], dest: &mut [u8]) { + let mut encryptor = CtrMode::new(AesSafe128Encryptor::new(k), iv.to_vec()); + encryptor.decrypt(&mut RefReadBuffer::new(encrypted), &mut RefWriteBuffer::new(dest), true).expect("Invalid length or padding"); + } +} + + #[cfg(test)] mod tests { use hash::*; @@ -177,10 +320,10 @@ mod tests { fn test_signature() { let pair = KeyPair::create().unwrap(); let message = H256::random(); - let signature = sign(pair.secret(), &message).unwrap(); + let signature = ec::sign(pair.secret(), &message).unwrap(); - assert!(verify(pair.public(), &signature, &message).unwrap()); - assert_eq!(recover(&signature, &message).unwrap(), *pair.public()); + assert!(ec::verify(pair.public(), &signature, &message).unwrap()); + assert_eq!(ec::recover(&signature, &message).unwrap(), *pair.public()); } #[test] diff --git a/src/hash.rs b/src/hash.rs index 2fb4b71e6..9eafa5dfb 100644 --- a/src/hash.rs +++ b/src/hash.rs @@ -4,7 +4,7 @@ use std::str::FromStr; use std::fmt; use std::ops; use std::hash::{Hash, Hasher}; -use std::ops::{Index, IndexMut, Deref, DerefMut, BitOr, BitAnd}; +use std::ops::{Index, IndexMut, Deref, DerefMut, BitOr, BitAnd, BitXor}; use std::cmp::{PartialOrd, Ordering}; use rustc_serialize::hex::*; use error::EthcoreError; @@ -15,7 +15,7 @@ use math::log2; use uint::U256; /// Trait for a fixed-size byte array to be used as the output of hash functions. -/// +/// /// Note: types implementing `FixedHash` must be also `BytesConvertable`. pub trait FixedHash: Sized + BytesConvertable + Populatable { fn new() -> Self; @@ -24,10 +24,12 @@ pub trait FixedHash: Sized + BytesConvertable + Populatable { fn size() -> usize; fn from_slice(src: &[u8]) -> Self; fn clone_from_slice(&mut self, src: &[u8]) -> usize; + fn copy_to(&self, dest: &mut [u8]); fn shift_bloom<'a, T>(&'a mut self, b: &T) -> &'a mut Self where T: FixedHash; fn bloom_part(&self, m: usize) -> T where T: FixedHash; fn contains_bloom(&self, b: &T) -> bool where T: FixedHash; fn contains<'a>(&'a self, b: &'a Self) -> bool; + fn is_zero(&self) -> bool; } macro_rules! impl_hash { @@ -94,6 +96,13 @@ macro_rules! impl_hash { r } + fn copy_to(&self, dest: &mut[u8]) { + unsafe { + let min = ::std::cmp::min($size, dest.len()); + ::std::ptr::copy(self.0.as_ptr(), dest.as_mut_ptr(), min); + } + } + fn shift_bloom<'a, T>(&'a mut self, b: &T) -> &'a mut Self where T: FixedHash { let bp: Self = b.bloom_part($size); let new_self = &bp | self; @@ -152,6 +161,10 @@ macro_rules! impl_hash { fn contains<'a>(&'a self, b: &'a Self) -> bool { &(b & self) == b } + + fn is_zero(&self) -> bool { + self.eq(&Self::new()) + } } impl FromStr for $from { @@ -311,6 +324,30 @@ macro_rules! impl_hash { } } + /// BitXor on references + impl <'a> BitXor for &'a $from { + type Output = $from; + + fn bitxor(self, rhs: Self) -> Self::Output { + unsafe { + use std::mem; + let mut ret: $from = mem::uninitialized(); + for i in 0..$size { + ret.0[i] = self.0[i] ^ rhs.0[i]; + } + ret + } + } + } + + /// Moving BitXor + impl BitXor for $from { + type Output = $from; + + fn bitxor(self, rhs: Self) -> Self::Output { + &self ^ &rhs + } + } impl $from { pub fn hex(&self) -> String { format!("{}", self) diff --git a/src/lib.rs b/src/lib.rs index 7cc4c35eb..4827860c6 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -3,22 +3,22 @@ //! ### Rust version: //! - beta //! - nightly -//! +//! //! ### Supported platforms: //! - OSX //! - Linux //! //! ### Dependencies: -//! - RocksDB 3.13 -//! +//! - RocksDB 3.13 +//! //! ### Dependencies Installation: //! //! - OSX: -//! +//! //! ```bash //! brew install rocksdb //! ``` -//! +//! //! - From source: //! //! ```bash @@ -69,6 +69,7 @@ pub mod nibbleslice; pub mod heapsizeof; pub mod squeeze; pub mod semantic_version; +pub mod network; pub use common::*; pub use rlp::*; @@ -84,5 +85,4 @@ pub use nibbleslice::*; pub use heapsizeof::*; pub use squeeze::*; pub use semantic_version::*; - -//pub mod network; +pub use network::*; diff --git a/src/network/connection.rs b/src/network/connection.rs new file mode 100644 index 000000000..2a8025b83 --- /dev/null +++ b/src/network/connection.rs @@ -0,0 +1,364 @@ +use std::collections::VecDeque; +use mio::{Token, EventSet, EventLoop, Timeout, PollOpt, TryRead, TryWrite}; +use mio::tcp::*; +use hash::*; +use sha3::*; +use bytes::*; +use rlp::*; +use std::io::{self, Cursor, Read}; +use network::host::{Host}; +use network::Error; +use network::handshake::Handshake; +use crypto; +use rcrypto::blockmodes::*; +use rcrypto::aessafe::*; +use rcrypto::symmetriccipher::*; +use rcrypto::buffer::*; +use tiny_keccak::Keccak; + +const ENCRYPTED_HEADER_LEN: usize = 32; + +pub struct Connection { + pub token: Token, + pub socket: TcpStream, + rec_buf: Bytes, + rec_size: usize, + send_queue: VecDeque>, + interest: EventSet, +} + +#[derive(PartialEq, Eq)] +pub enum WriteStatus { + Ongoing, + Complete +} + +impl Connection { + pub fn new(token: Token, socket: TcpStream) -> Connection { + Connection { + token: token, + socket: socket, + send_queue: VecDeque::new(), + rec_buf: Bytes::new(), + rec_size: 0, + interest: EventSet::hup(), + } + } + + pub fn expect(&mut self, size: usize) { + if self.rec_size != self.rec_buf.len() { + warn!(target:"net", "Unexpected connection read start"); + } + unsafe { self.rec_buf.set_len(0) } + self.rec_size = size; + } + + //TODO: return a slice + pub fn readable(&mut self) -> io::Result> { + if self.rec_size == 0 || self.rec_buf.len() >= self.rec_size { + warn!(target:"net", "Unexpected connection read"); + } + let max = self.rec_size - self.rec_buf.len(); + // resolve "multiple applicable items in scope [E0034]" error + let sock_ref = ::by_ref(&mut self.socket); + match sock_ref.take(max as u64).try_read_buf(&mut self.rec_buf) { + Ok(Some(_)) if self.rec_buf.len() == self.rec_size => { + self.rec_size = 0; + Ok(Some(::std::mem::replace(&mut self.rec_buf, Bytes::new()))) + }, + Ok(_) => Ok(None), + Err(e) => Err(e), + } + } + + pub fn send(&mut self, data: Bytes) { + if data.len() != 0 { + self.send_queue.push_back(Cursor::new(data)); + } + if !self.interest.is_writable() { + self.interest.insert(EventSet::writable()); + } + } + + pub fn writable(&mut self) -> io::Result { + if self.send_queue.is_empty() { + return Ok(WriteStatus::Complete) + } + { + let buf = self.send_queue.front_mut().unwrap(); + let send_size = buf.get_ref().len(); + if (buf.position() as usize) >= send_size { + warn!(target:"net", "Unexpected connection data"); + return Ok(WriteStatus::Complete) + } + match self.socket.try_write_buf(buf) { + Ok(_) if (buf.position() as usize) < send_size => { + self.interest.insert(EventSet::writable()); + Ok(WriteStatus::Ongoing) + }, + Ok(_) if (buf.position() as usize) == send_size => { + Ok(WriteStatus::Complete) + }, + Ok(_) => { panic!("Wrote past buffer");}, + Err(e) => Err(e) + } + }.and_then(|r| { + if r == WriteStatus::Complete { + self.send_queue.pop_front(); + } + if self.send_queue.is_empty() { + self.interest.remove(EventSet::writable()); + } + else { + self.interest.insert(EventSet::writable()); + } + Ok(r) + }) + } + + pub fn register(&mut self, event_loop: &mut EventLoop) -> io::Result<()> { + trace!(target: "net", "connection register; token={:?}", self.token); + self.interest.insert(EventSet::readable()); + event_loop.register(&self.socket, self.token, self.interest, PollOpt::edge() | PollOpt::oneshot()).or_else(|e| { + error!("Failed to register {:?}, {:?}", self.token, e); + Err(e) + }) + } + + pub fn reregister(&mut self, event_loop: &mut EventLoop) -> io::Result<()> { + trace!(target: "net", "connection reregister; token={:?}", self.token); + event_loop.reregister( &self.socket, self.token, self.interest, PollOpt::edge() | PollOpt::oneshot()).or_else(|e| { + error!("Failed to reregister {:?}, {:?}", self.token, e); + Err(e) + }) + } +} + +pub struct Packet { + pub protocol: u16, + pub data: Bytes, +} + +enum EncryptedConnectionState { + Header, + Payload, +} + +pub struct EncryptedConnection { + connection: Connection, + encoder: CtrMode, + decoder: CtrMode, + mac_encoder: EcbEncryptor>, + egress_mac: Keccak, + ingress_mac: Keccak, + read_state: EncryptedConnectionState, + idle_timeout: Option, + protocol_id: u16, + payload_len: u32, +} + +impl EncryptedConnection { + pub fn new(handshake: Handshake) -> Result { + let shared = try!(crypto::ecdh::agree(handshake.ecdhe.secret(), &handshake.remote_public)); + let mut nonce_material = H512::new(); + if handshake.originated { + handshake.remote_nonce.copy_to(&mut nonce_material[0..32]); + handshake.nonce.copy_to(&mut nonce_material[32..64]); + } + else { + handshake.nonce.copy_to(&mut nonce_material[0..32]); + handshake.remote_nonce.copy_to(&mut nonce_material[32..64]); + } + let mut key_material = H512::new(); + shared.copy_to(&mut key_material[0..32]); + nonce_material.sha3_into(&mut key_material[32..64]); + key_material.sha3().copy_to(&mut key_material[32..64]); + key_material.sha3().copy_to(&mut key_material[32..64]); + + let iv = vec![0u8; 16]; + let encoder = CtrMode::new(AesSafe256Encryptor::new(&key_material[32..64]), iv); + let iv = vec![0u8; 16]; + let decoder = CtrMode::new(AesSafe256Encryptor::new(&key_material[32..64]), iv); + + key_material.sha3().copy_to(&mut key_material[32..64]); + let mac_encoder = EcbEncryptor::new(AesSafe256Encryptor::new(&key_material[32..64]), NoPadding); + + let mut egress_mac = Keccak::new_keccak256(); + let mut mac_material = &H256::from_slice(&key_material[32..64]) ^ &handshake.remote_nonce; + egress_mac.update(&mac_material); + egress_mac.update(if handshake.originated { &handshake.auth_cipher } else { &handshake.ack_cipher }); + + let mut ingress_mac = Keccak::new_keccak256(); + mac_material = &H256::from_slice(&key_material[32..64]) ^ &handshake.nonce; + ingress_mac.update(&mac_material); + ingress_mac.update(if handshake.originated { &handshake.ack_cipher } else { &handshake.auth_cipher }); + + Ok(EncryptedConnection { + connection: handshake.connection, + encoder: encoder, + decoder: decoder, + mac_encoder: mac_encoder, + egress_mac: egress_mac, + ingress_mac: ingress_mac, + read_state: EncryptedConnectionState::Header, + idle_timeout: None, + protocol_id: 0, + payload_len: 0 + }) + } + + pub fn send_packet(&mut self, payload: &[u8]) -> Result<(), Error> { + let mut header = RlpStream::new(); + let len = payload.len() as usize; + header.append_raw(&[(len >> 16) as u8, (len >> 8) as u8, len as u8], 1); + header.append_raw(&[0xc2u8, 0x80u8, 0x80u8], 1); + //TODO: ger rid of vectors here + let mut header = header.out(); + let padding = (16 - (payload.len() % 16)) % 16; + header.resize(16, 0u8); + + let mut packet = vec![0u8; (32 + payload.len() + padding + 16)]; + self.encoder.encrypt(&mut RefReadBuffer::new(&header), &mut RefWriteBuffer::new(&mut packet), false).expect("Invalid length or padding"); + EncryptedConnection::update_mac(&mut self.egress_mac, &mut self.mac_encoder, &packet[0..16]); + self.egress_mac.clone().finalize(&mut packet[16..32]); + self.encoder.encrypt(&mut RefReadBuffer::new(&payload), &mut RefWriteBuffer::new(&mut packet[32..(32 + len)]), padding == 0).expect("Invalid length or padding"); + if padding != 0 { + let pad = [08; 16]; + self.encoder.encrypt(&mut RefReadBuffer::new(&pad[0..padding]), &mut RefWriteBuffer::new(&mut packet[(32 + len)..(32 + len + padding)]), true).expect("Invalid length or padding"); + } + self.egress_mac.update(&packet[32..(32 + len + padding)]); + EncryptedConnection::update_mac(&mut self.egress_mac, &mut self.mac_encoder, &[0u8; 0]); + self.egress_mac.clone().finalize(&mut packet[(32 + len + padding)..]); + self.connection.send(packet); + Ok(()) + } + + fn read_header(&mut self, header: &[u8]) -> Result<(), Error> { + if header.len() != ENCRYPTED_HEADER_LEN { + return Err(Error::Auth); + } + EncryptedConnection::update_mac(&mut self.ingress_mac, &mut self.mac_encoder, &header[0..16]); + let mac = &header[16..]; + let mut expected = H256::new(); + self.ingress_mac.clone().finalize(&mut expected); + if mac != &expected[0..16] { + return Err(Error::Auth); + } + + let mut hdec = H128::new(); + self.decoder.decrypt(&mut RefReadBuffer::new(&header[0..16]), &mut RefWriteBuffer::new(&mut hdec), false).expect("Invalid length or padding"); + + let length = ((((hdec[0] as u32) << 8) + (hdec[1] as u32)) << 8) + (hdec[2] as u32); + let header_rlp = UntrustedRlp::new(&hdec[3..6]); + let protocol_id = try!(header_rlp.val_at::(0)); + + self.payload_len = length; + self.protocol_id = protocol_id; + self.read_state = EncryptedConnectionState::Payload; + + let padding = (16 - (length % 16)) % 16; + let full_length = length + padding + 16; + self.connection.expect(full_length as usize); + Ok(()) + } + + fn read_payload(&mut self, payload: &[u8]) -> Result { + let padding = (16 - (self.payload_len % 16)) % 16; + let full_length = (self.payload_len + padding + 16) as usize; + if payload.len() != full_length { + return Err(Error::Auth); + } + self.ingress_mac.update(&payload[0..payload.len() - 16]); + EncryptedConnection::update_mac(&mut self.ingress_mac, &mut self.mac_encoder, &[0u8; 0]); + let mac = &payload[(payload.len() - 16)..]; + let mut expected = H128::new(); + self.ingress_mac.clone().finalize(&mut expected); + if mac != &expected[..] { + return Err(Error::Auth); + } + + let mut packet = vec![0u8; self.payload_len as usize]; + self.decoder.decrypt(&mut RefReadBuffer::new(&payload[0..(full_length - 16)]), &mut RefWriteBuffer::new(&mut packet), false).expect("Invalid length or padding"); + packet.resize(self.payload_len as usize, 0u8); + Ok(Packet { + protocol: self.protocol_id, + data: packet + }) + } + + fn update_mac(mac: &mut Keccak, mac_encoder: &mut EcbEncryptor>, seed: &[u8]) { + let mut prev = H128::new(); + mac.clone().finalize(&mut prev); + let mut enc = H128::new(); + mac_encoder.encrypt(&mut RefReadBuffer::new(&prev), &mut RefWriteBuffer::new(&mut enc), true).unwrap(); + mac_encoder.reset(); + + enc = enc ^ if seed.is_empty() { prev } else { H128::from_slice(seed) }; + mac.update(&enc); + } + + pub fn readable(&mut self, event_loop: &mut EventLoop) -> Result, Error> { + self.idle_timeout.map(|t| event_loop.clear_timeout(t)); + try!(self.connection.reregister(event_loop)); + match self.read_state { + EncryptedConnectionState::Header => { + match try!(self.connection.readable()) { + Some(data) => { + try!(self.read_header(&data)); + }, + None => {} + }; + Ok(None) + }, + EncryptedConnectionState::Payload => { + match try!(self.connection.readable()) { + Some(data) => { + self.read_state = EncryptedConnectionState::Header; + self.connection.expect(ENCRYPTED_HEADER_LEN); + Ok(Some(try!(self.read_payload(&data)))) + }, + None => Ok(None) + } + } + } + } + + pub fn writable(&mut self, event_loop: &mut EventLoop) -> Result<(), Error> { + self.idle_timeout.map(|t| event_loop.clear_timeout(t)); + try!(self.connection.writable()); + try!(self.connection.reregister(event_loop)); + Ok(()) + } + + pub fn register(&mut self, event_loop: &mut EventLoop) -> Result<(), Error> { + self.connection.expect(ENCRYPTED_HEADER_LEN); + self.idle_timeout.map(|t| event_loop.clear_timeout(t)); + self.idle_timeout = event_loop.timeout_ms(self.connection.token, 1800).ok(); + try!(self.connection.reregister(event_loop)); + Ok(()) + } +} + +#[test] +pub fn test_encryption() { + use hash::*; + use std::str::FromStr; + let key = H256::from_str("2212767d793a7a3d66f869ae324dd11bd17044b82c9f463b8a541a4d089efec5").unwrap(); + let before = H128::from_str("12532abaec065082a3cf1da7d0136f15").unwrap(); + let before2 = H128::from_str("7e99f682356fdfbc6b67a9562787b18a").unwrap(); + let after = H128::from_str("89464c6b04e7c99e555c81d3f7266a05").unwrap(); + let after2 = H128::from_str("85c070030589ef9c7a2879b3a8489316").unwrap(); + + let mut got = H128::new(); + + let mut encoder = EcbEncryptor::new(AesSafe256Encryptor::new(&key), NoPadding); + encoder.encrypt(&mut RefReadBuffer::new(&before), &mut RefWriteBuffer::new(&mut got), true).unwrap(); + encoder.reset(); + assert_eq!(got, after); + got = H128::new(); + encoder.encrypt(&mut RefReadBuffer::new(&before2), &mut RefWriteBuffer::new(&mut got), true).unwrap(); + encoder.reset(); + assert_eq!(got, after2); +} + + diff --git a/src/network/discovery.rs b/src/network/discovery.rs new file mode 100644 index 000000000..f2d139f45 --- /dev/null +++ b/src/network/discovery.rs @@ -0,0 +1,206 @@ +// This module is a work in progress + +#![allow(dead_code)] //TODO: remove this after everything is done + +use std::collections::{HashSet, BTreeMap}; +use std::cell::{RefCell}; +use std::ops::{DerefMut}; +use mio::*; +use mio::udp::*; +use hash::*; +use sha3::Hashable; +use crypto::*; +use network::host::*; + +const ADDRESS_BYTES_SIZE: u32 = 32; ///< Size of address type in bytes. +const ADDRESS_BITS: u32 = 8 * ADDRESS_BYTES_SIZE; ///< Denoted by n in [Kademlia]. +const NODE_BINS: u32 = ADDRESS_BITS - 1; ///< Size of m_state (excludes root, which is us). +const DISCOVERY_MAX_STEPS: u16 = 8; ///< Max iterations of discovery. (discover) +const BUCKET_SIZE: u32 = 16; ///< Denoted by k in [Kademlia]. Number of nodes stored in each bucket. +const ALPHA: usize = 3; ///< Denoted by \alpha in [Kademlia]. Number of concurrent FindNode requests. + +struct NodeBucket { + distance: u32, + nodes: Vec +} + +impl NodeBucket { + fn new(distance: u32) -> NodeBucket { + NodeBucket { + distance: distance, + nodes: Vec::new() + } + } +} + +struct Discovery { + id: NodeId, + discovery_round: u16, + discovery_id: NodeId, + discovery_nodes: HashSet, + node_buckets: Vec, +} + +struct FindNodePacket; + +impl FindNodePacket { + fn new(_endpoint: &NodeEndpoint, _id: &NodeId) -> FindNodePacket { + FindNodePacket + } + + fn sign(&mut self, _secret: &Secret) { + } + + fn send(& self, _socket: &mut UdpSocket) { + } +} + +impl Discovery { + pub fn new(id: &NodeId) -> Discovery { + Discovery { + id: id.clone(), + discovery_round: 0, + discovery_id: NodeId::new(), + discovery_nodes: HashSet::new(), + node_buckets: (0..NODE_BINS).map(|x| NodeBucket::new(x)).collect(), + } + } + + pub fn add_node(&mut self, id: &NodeId) { + self.node_buckets[Discovery::distance(&self.id, &id) as usize].nodes.push(id.clone()); + } + + fn start_node_discovery(&mut self, event_loop: &mut EventLoop) { + self.discovery_round = 0; + self.discovery_id.randomize(); + self.discovery_nodes.clear(); + self.discover(event_loop); + } + + fn discover(&mut self, event_loop: &mut EventLoop) { + if self.discovery_round == DISCOVERY_MAX_STEPS + { + debug!("Restarting discovery"); + self.start_node_discovery(event_loop); + return; + } + let mut tried_count = 0; + { + let nearest = Discovery::nearest_node_entries(&self.id, &self.discovery_id, &self.node_buckets).into_iter(); + let nodes = RefCell::new(&mut self.discovery_nodes); + let nearest = nearest.filter(|x| nodes.borrow().contains(&x)).take(ALPHA); + for r in nearest { + //let mut p = FindNodePacket::new(&r.endpoint, &self.discovery_id); + //p.sign(&self.secret); + //p.send(&mut self.udp_socket); + let mut borrowed = nodes.borrow_mut(); + borrowed.deref_mut().insert(r.clone()); + tried_count += 1; + } + } + + if tried_count == 0 + { + debug!("Restarting discovery"); + self.start_node_discovery(event_loop); + return; + } + self.discovery_round += 1; + //event_loop.timeout_ms(Token(NODETABLE_DISCOVERY), 1200).unwrap(); + } + + fn distance(a: &NodeId, b: &NodeId) -> u32 { + let d = a.sha3() ^ b.sha3(); + let mut ret:u32 = 0; + for i in 0..32 { + let mut v: u8 = d[i]; + while v != 0 { + v >>= 1; + ret += 1; + } + } + ret + } + + fn nearest_node_entries<'b>(source: &NodeId, target: &NodeId, buckets: &'b Vec) -> Vec<&'b NodeId> + { + // send ALPHA FindNode packets to nodes we know, closest to target + const LAST_BIN: u32 = NODE_BINS - 1; + let mut head = Discovery::distance(source, target); + let mut tail = if head == 0 { LAST_BIN } else { (head - 1) % NODE_BINS }; + + let mut found: BTreeMap> = BTreeMap::new(); + let mut count = 0; + + // if d is 0, then we roll look forward, if last, we reverse, else, spread from d + if head > 1 && tail != LAST_BIN { + while head != tail && head < NODE_BINS && count < BUCKET_SIZE + { + for n in buckets[head as usize].nodes.iter() + { + if count < BUCKET_SIZE { + count += 1; + found.entry(Discovery::distance(target, &n)).or_insert(Vec::new()).push(n); + } + else { + break; + } + } + if count < BUCKET_SIZE && tail != 0 { + for n in buckets[tail as usize].nodes.iter() { + if count < BUCKET_SIZE { + count += 1; + found.entry(Discovery::distance(target, &n)).or_insert(Vec::new()).push(n); + } + else { + break; + } + } + } + + head += 1; + if tail > 0 { + tail -= 1; + } + } + } + else if head < 2 { + while head < NODE_BINS && count < BUCKET_SIZE { + for n in buckets[head as usize].nodes.iter() { + if count < BUCKET_SIZE { + count += 1; + found.entry(Discovery::distance(target, &n)).or_insert(Vec::new()).push(n); + } + else { + break; + } + } + head += 1; + } + } + else { + while tail > 0 && count < BUCKET_SIZE { + for n in buckets[tail as usize].nodes.iter() { + if count < BUCKET_SIZE { + count += 1; + found.entry(Discovery::distance(target, &n)).or_insert(Vec::new()).push(n); + } + else { + break; + } + } + tail -= 1; + } + } + + let mut ret:Vec<&NodeId> = Vec::new(); + for (_, nodes) in found { + for n in nodes { + if ret.len() < BUCKET_SIZE as usize /* && n->endpoint && n->endpoint.isAllowed() */ { + ret.push(n); + } + } + } + ret + } +} diff --git a/src/network/handshake.rs b/src/network/handshake.rs new file mode 100644 index 000000000..4df8cbe8b --- /dev/null +++ b/src/network/handshake.rs @@ -0,0 +1,189 @@ +use mio::*; +use mio::tcp::*; +use hash::*; +use sha3::Hashable; +use bytes::Bytes; +use crypto::*; +use crypto; +use network::connection::{Connection}; +use network::host::{NodeId, Host, HostInfo}; +use network::Error; + +#[derive(PartialEq, Eq, Debug)] +enum HandshakeState { + New, + ReadingAuth, + ReadingAck, + StartSession, +} + +pub struct Handshake { + pub id: NodeId, + pub connection: Connection, + state: HandshakeState, + pub originated: bool, + idle_timeout: Option, + pub ecdhe: KeyPair, + pub nonce: H256, + pub remote_public: Public, + pub remote_nonce: H256, + pub auth_cipher: Bytes, + pub ack_cipher: Bytes +} + +const AUTH_PACKET_SIZE: usize = 307; +const ACK_PACKET_SIZE: usize = 210; + +impl Handshake { + pub fn new(token: Token, id: &NodeId, socket: TcpStream, nonce: &H256) -> Result { + Ok(Handshake { + id: id.clone(), + connection: Connection::new(token, socket), + originated: false, + state: HandshakeState::New, + idle_timeout: None, + ecdhe: try!(KeyPair::create()), + nonce: nonce.clone(), + remote_public: Public::new(), + remote_nonce: H256::new(), + auth_cipher: Bytes::new(), + ack_cipher: Bytes::new(), + }) + } + + pub fn start(&mut self, host: &HostInfo, originated: bool) -> Result<(), Error> { + self.originated = originated; + if originated { + try!(self.write_auth(host)); + } + else { + self.state = HandshakeState::ReadingAuth; + self.connection.expect(AUTH_PACKET_SIZE); + }; + Ok(()) + } + + pub fn done(&self) -> bool { + self.state == HandshakeState::StartSession + } + + pub fn readable(&mut self, event_loop: &mut EventLoop, host: &HostInfo) -> Result<(), Error> { + self.idle_timeout.map(|t| event_loop.clear_timeout(t)); + match self.state { + HandshakeState::ReadingAuth => { + match try!(self.connection.readable()) { + Some(data) => { + try!(self.read_auth(host, &data)); + try!(self.write_ack()); + }, + None => {} + }; + }, + HandshakeState::ReadingAck => { + match try!(self.connection.readable()) { + Some(data) => { + try!(self.read_ack(host, &data)); + self.state = HandshakeState::StartSession; + }, + None => {} + }; + }, + _ => { panic!("Unexpected state"); } + } + if self.state != HandshakeState::StartSession { + try!(self.connection.reregister(event_loop)); + } + Ok(()) + } + + pub fn writable(&mut self, event_loop: &mut EventLoop, _host: &HostInfo) -> Result<(), Error> { + self.idle_timeout.map(|t| event_loop.clear_timeout(t)); + try!(self.connection.writable()); + if self.state != HandshakeState::StartSession { + try!(self.connection.reregister(event_loop)); + } + Ok(()) + } + + pub fn register(&mut self, event_loop: &mut EventLoop) -> Result<(), Error> { + self.idle_timeout.map(|t| event_loop.clear_timeout(t)); + self.idle_timeout = event_loop.timeout_ms(self.connection.token, 1800).ok(); + try!(self.connection.register(event_loop)); + Ok(()) + } + + fn read_auth(&mut self, host: &HostInfo, data: &[u8]) -> Result<(), Error> { + trace!(target:"net", "Received handshake auth to {:?}", self.connection.socket.peer_addr()); + assert!(data.len() == AUTH_PACKET_SIZE); + self.auth_cipher = data.to_vec(); + let auth = try!(ecies::decrypt(host.secret(), data)); + let (sig, rest) = auth.split_at(65); + let (hepubk, rest) = rest.split_at(32); + let (pubk, rest) = rest.split_at(64); + let (nonce, _) = rest.split_at(32); + self.remote_public.clone_from_slice(pubk); + self.remote_nonce.clone_from_slice(nonce); + let shared = try!(ecdh::agree(host.secret(), &self.remote_public)); + let signature = Signature::from_slice(sig); + let spub = try!(ec::recover(&signature, &(&shared ^ &self.remote_nonce))); + if &spub.sha3()[..] != hepubk { + trace!(target:"net", "Handshake hash mismath with {:?}", self.connection.socket.peer_addr()); + return Err(Error::Auth); + }; + self.write_ack() + } + + fn read_ack(&mut self, host: &HostInfo, data: &[u8]) -> Result<(), Error> { + trace!(target:"net", "Received handshake auth to {:?}", self.connection.socket.peer_addr()); + assert!(data.len() == ACK_PACKET_SIZE); + self.ack_cipher = data.to_vec(); + let ack = try!(ecies::decrypt(host.secret(), data)); + self.remote_public.clone_from_slice(&ack[0..64]); + self.remote_nonce.clone_from_slice(&ack[64..(64+32)]); + Ok(()) + } + + fn write_auth(&mut self, host: &HostInfo) -> Result<(), Error> { + trace!(target:"net", "Sending handshake auth to {:?}", self.connection.socket.peer_addr()); + let mut data = [0u8; /*Signature::SIZE*/ 65 + /*H256::SIZE*/ 32 + /*Public::SIZE*/ 64 + /*H256::SIZE*/ 32 + 1]; //TODO: use associated constants + let len = data.len(); + { + data[len - 1] = 0x0; + let (sig, rest) = data.split_at_mut(65); + let (hepubk, rest) = rest.split_at_mut(32); + let (pubk, rest) = rest.split_at_mut(64); + let (nonce, _) = rest.split_at_mut(32); + + // E(remote-pubk, S(ecdhe-random, ecdh-shared-secret^nonce) || H(ecdhe-random-pubk) || pubk || nonce || 0x0) + let shared = try!(crypto::ecdh::agree(host.secret(), &self.id)); + try!(crypto::ec::sign(self.ecdhe.secret(), &(&shared ^ &self.nonce))).copy_to(sig); + self.ecdhe.public().sha3_into(hepubk); + host.id().copy_to(pubk); + self.nonce.copy_to(nonce); + } + let message = try!(crypto::ecies::encrypt(&self.id, &data)); + self.auth_cipher = message.clone(); + self.connection.send(message); + self.connection.expect(ACK_PACKET_SIZE); + self.state = HandshakeState::ReadingAck; + Ok(()) + } + + fn write_ack(&mut self) -> Result<(), Error> { + trace!(target:"net", "Sending handshake ack to {:?}", self.connection.socket.peer_addr()); + let mut data = [0u8; 1 + /*Public::SIZE*/ 64 + /*H256::SIZE*/ 32]; //TODO: use associated constants + let len = data.len(); + { + data[len - 1] = 0x0; + let (epubk, rest) = data.split_at_mut(64); + let (nonce, _) = rest.split_at_mut(32); + self.ecdhe.public().copy_to(epubk); + self.nonce.copy_to(nonce); + } + let message = try!(crypto::ecies::encrypt(&self.id, &data)); + self.ack_cipher = message.clone(); + self.connection.send(message); + self.state = HandshakeState::StartSession; + Ok(()) + } +} diff --git a/src/network/host.rs b/src/network/host.rs new file mode 100644 index 000000000..9e2b3e101 --- /dev/null +++ b/src/network/host.rs @@ -0,0 +1,662 @@ +use std::net::{SocketAddr, ToSocketAddrs}; +use std::collections::{HashMap}; +use std::hash::{Hash, Hasher}; +use std::str::{FromStr}; +use mio::*; +use mio::util::{Slab}; +use mio::tcp::*; +use mio::udp::*; +use hash::*; +use crypto::*; +use sha3::Hashable; +use rlp::*; +use time::Tm; +use network::handshake::Handshake; +use network::session::{Session, SessionData}; +use network::{Error, ProtocolHandler}; + +const _DEFAULT_PORT: u16 = 30304; + +const MAX_CONNECTIONS: usize = 1024; +const MAX_USER_TIMERS: usize = 32; +const IDEAL_PEERS: u32 = 10; + +pub type NodeId = H512; +pub type TimerToken = usize; + +#[derive(Debug)] +struct NetworkConfiguration { + listen_address: SocketAddr, + public_address: SocketAddr, + nat_enabled: bool, + discovery_enabled: bool, + pin: bool, +} + +impl NetworkConfiguration { + fn new() -> NetworkConfiguration { + NetworkConfiguration { + listen_address: SocketAddr::from_str("0.0.0.0:30304").unwrap(), + public_address: SocketAddr::from_str("0.0.0.0:30304").unwrap(), + nat_enabled: true, + discovery_enabled: true, + pin: false, + } + } +} + +#[derive(Debug)] +pub struct NodeEndpoint { + address: SocketAddr, + address_str: String, + udp_port: u16 +} + +impl NodeEndpoint { + fn from_str(s: &str) -> Result { + let address = s.to_socket_addrs().map(|mut i| i.next()); + match address { + Ok(Some(a)) => Ok(NodeEndpoint { + address: a, + address_str: s.to_string(), + udp_port: a.port() + }), + Ok(_) => Err(Error::AddressResolve(None)), + Err(e) => Err(Error::AddressResolve(Some(e))) + } + } +} + +#[derive(PartialEq, Eq, Copy, Clone)] +enum PeerType { + Required, + Optional +} + +struct Node { + id: NodeId, + endpoint: NodeEndpoint, + peer_type: PeerType, + last_attempted: Option, +} + +impl FromStr for Node { + type Err = Error; + fn from_str(s: &str) -> Result { + let (id, endpoint) = if &s[0..8] == "enode://" && s.len() > 136 && &s[136..137] == "@" { + (try!(NodeId::from_str(&s[8..136])), try!(NodeEndpoint::from_str(&s[137..]))) + } + else { + (NodeId::new(), try!(NodeEndpoint::from_str(s))) + }; + + Ok(Node { + id: id, + endpoint: endpoint, + peer_type: PeerType::Optional, + last_attempted: None, + }) + } +} + +impl PartialEq for Node { + fn eq(&self, other: &Self) -> bool { + self.id == other.id + } +} +impl Eq for Node { } + +impl Hash for Node { + fn hash(&self, state: &mut H) where H: Hasher { + self.id.hash(state) + } +} + +// Tokens +const TCP_ACCEPT: usize = 1; +const IDLE: usize = 3; +const NODETABLE_RECEIVE: usize = 4; +const NODETABLE_MAINTAIN: usize = 5; +const NODETABLE_DISCOVERY: usize = 6; +const FIRST_CONNECTION: usize = 7; +const LAST_CONNECTION: usize = FIRST_CONNECTION + MAX_CONNECTIONS - 1; +const USER_TIMER: usize = LAST_CONNECTION; +const LAST_USER_TIMER: usize = USER_TIMER + MAX_USER_TIMERS - 1; + +pub type PacketId = u8; +pub type ProtocolId = &'static str; + +pub enum HostMessage { + Shutdown, + AddHandler { + handler: Box, + protocol: ProtocolId, + versions: Vec, + }, + Send { + peer: PeerId, + packet_id: PacketId, + protocol: ProtocolId, + data: Vec, + }, + UserMessage(UserMessage), +} + +pub type UserMessageId = u32; + +pub struct UserMessage { + pub protocol: ProtocolId, + pub id: UserMessageId, + pub data: Option>, +} + +pub type PeerId = usize; + +#[derive(Debug, PartialEq, Eq)] +pub struct CapabilityInfo { + pub protocol: ProtocolId, + pub version: u8, + pub packet_count: u8, +} + +impl Encodable for CapabilityInfo { + fn encode(&self, encoder: &mut E) -> () where E: Encoder { + encoder.emit_list(|e| { + self.protocol.encode(e); + (self.version as u32).encode(e); + }); + } +} + +/// IO access point +pub struct HostIo<'s> { + protocol: ProtocolId, + connections: &'s mut Slab, + timers: &'s mut Slab, + session: Option, + event_loop: &'s mut EventLoop, +} + +impl<'s> HostIo<'s> { + fn new(protocol: ProtocolId, session: Option, event_loop: &'s mut EventLoop, connections: &'s mut Slab, timers: &'s mut Slab) -> HostIo<'s> { + HostIo { + protocol: protocol, + session: session, + event_loop: event_loop, + connections: connections, + timers: timers, + } + } + + /// Send a packet over the network to another peer. + pub fn send(&mut self, peer: PeerId, packet_id: PacketId, data: Vec) -> Result<(), Error> { + match self.connections.get_mut(Token(peer)) { + Some(&mut ConnectionEntry::Session(ref mut s)) => { + s.send_packet(self.protocol, packet_id as u8, &data).unwrap_or_else(|e| { + warn!(target: "net", "Send error: {:?}", e); + }); //TODO: don't copy vector data + }, + _ => { + warn!(target: "net", "Send: Peer does not exist"); + } + } + Ok(()) + } + + /// Respond to a current network message. Panics if no there is no packet in the context. + pub fn respond(&mut self, packet_id: PacketId, data: Vec) -> Result<(), Error> { + match self.session { + Some(session) => self.send(session.as_usize(), packet_id, data), + None => { + panic!("Respond: Session does not exist") + } + } + } + + /// Register a new IO timer. Returns a new timer toke. 'ProtocolHandler::timeout' will be called with the token. + pub fn register_timer(&mut self, ms: u64) -> Result{ + match self.timers.insert(UserTimer { + delay: ms, + protocol: self.protocol, + }) { + Ok(token) => { + self.event_loop.timeout_ms(token, ms).expect("Error registering user timer"); + Ok(token.as_usize()) + }, + _ => { panic!("Max timers reached") } + } + } + + /// Broadcast a message to other IO clients + pub fn message(&mut self, id: UserMessageId, data: Option>) { + match self.event_loop.channel().send(HostMessage::UserMessage(UserMessage { + protocol: self.protocol, + id: id, + data: data + })) { + Ok(_) => {} + Err(e) => { panic!("Error sending io message {:?}", e); } + } + } + + /// Disable current protocol capability for given peer. If no capabilities left peer gets disconnected. + pub fn disable_peer(&mut self, _peer: PeerId) { + //TODO: remove capability, disconnect if no capabilities left + } + +} + +struct UserTimer { + protocol: ProtocolId, + delay: u64, +} + +pub struct HostInfo { + keys: KeyPair, + config: NetworkConfiguration, + nonce: H256, + pub protocol_version: u32, + pub client_version: String, + pub listen_port: u16, + pub capabilities: Vec +} + +impl HostInfo { + pub fn id(&self) -> &NodeId { + self.keys.public() + } + + pub fn secret(&self) -> &Secret { + self.keys.secret() + } + pub fn next_nonce(&mut self) -> H256 { + self.nonce = self.nonce.sha3(); + return self.nonce.clone(); + } +} + +enum ConnectionEntry { + Handshake(Handshake), + Session(Session) +} + +pub struct Host { + info: HostInfo, + _udp_socket: UdpSocket, + _listener: TcpListener, + connections: Slab, + timers: Slab, + nodes: HashMap, + handlers: HashMap>, + _idle_timeout: Timeout, +} + +impl Host { + pub fn start(event_loop: &mut EventLoop) -> Result<(), Error> { + let config = NetworkConfiguration::new(); + /* + match ::ifaces::Interface::get_all().unwrap().into_iter().filter(|x| x.kind == ::ifaces::Kind::Packet && x.addr.is_some()).next() { + Some(iface) => config.public_address = iface.addr.unwrap(), + None => warn!("No public network interface"), + } + */ + + let addr = config.listen_address; + // Setup the server socket + let listener = TcpListener::bind(&addr).unwrap(); + // Start listening for incoming connections + event_loop.register(&listener, Token(TCP_ACCEPT), EventSet::readable(), PollOpt::edge()).unwrap(); + let idle_timeout = event_loop.timeout_ms(Token(IDLE), 1000).unwrap(); //TODO: check delay + // open the udp socket + let udp_socket = UdpSocket::bound(&addr).unwrap(); + event_loop.register(&udp_socket, Token(NODETABLE_RECEIVE), EventSet::readable(), PollOpt::edge()).unwrap(); + event_loop.timeout_ms(Token(NODETABLE_MAINTAIN), 7200).unwrap(); + let port = config.listen_address.port(); + + let mut host = Host { + info: HostInfo { + keys: KeyPair::create().unwrap(), + config: config, + nonce: H256::random(), + protocol_version: 4, + client_version: "parity".to_string(), + listen_port: port, + capabilities: Vec::new(), + }, + _udp_socket: udp_socket, + _listener: listener, + connections: Slab::new_starting_at(Token(FIRST_CONNECTION), MAX_CONNECTIONS), + timers: Slab::new_starting_at(Token(USER_TIMER), MAX_USER_TIMERS), + nodes: HashMap::new(), + handlers: HashMap::new(), + _idle_timeout: idle_timeout, + }; + + host.add_node("enode://c022e7a27affdd1632f2e67dffeb87f02bf506344bb142e08d12b28e7e5c6e5dbb8183a46a77bff3631b51c12e8cf15199f797feafdc8834aaf078ad1a2bcfa0@127.0.0.1:30303"); + host.add_node("enode://5374c1bff8df923d3706357eeb4983cd29a63be40a269aaa2296ee5f3b2119a8978c0ed68b8f6fc84aad0df18790417daadf91a4bfbb786a16c9b0a199fa254a@gav.ethdev.com:30300"); + host.add_node("enode://e58d5e26b3b630496ec640f2530f3e7fa8a8c7dfe79d9e9c4aac80e3730132b869c852d3125204ab35bb1b1951f6f2d40996c1034fd8c5a69b383ee337f02ddc@gav.ethdev.com:30303"); + host.add_node("enode://a979fb575495b8d6db44f750317d0f4622bf4c2aa3365d6af7c284339968eef29b69ad0dce72a4d8db5ebb4968de0e3bec910127f134779fbcb0cb6d3331163c@52.16.188.185:30303"); + host.add_node("enode://7f25d3eab333a6b98a8b5ed68d962bb22c876ffcd5561fca54e3c2ef27f754df6f7fd7c9b74cc919067abac154fb8e1f8385505954f161ae440abc355855e034@54.207.93.166:30303"); + host.add_node("enode://5374c1bff8df923d3706357eeb4983cd29a63be40a269aaa2296ee5f3b2119a8978c0ed68b8f6fc84aad0df18790417daadf91a4bfbb786a16c9b0a199fa254a@92.51.165.126:30303"); + + try!(event_loop.run(&mut host)); + Ok(()) + } + + fn add_node(&mut self, id: &str) { + match Node::from_str(id) { + Err(e) => { warn!("Could not add node: {:?}", e); }, + Ok(n) => { + self.nodes.insert(n.id.clone(), n); + } + } + } + + fn maintain_network(&mut self, event_loop: &mut EventLoop) { + self.connect_peers(event_loop); + } + + fn have_session(&self, id: &NodeId) -> bool { + self.connections.iter().any(|e| match e { &ConnectionEntry::Session(ref s) => s.info.id.eq(&id), _ => false }) + } + + fn connecting_to(&self, id: &NodeId) -> bool { + self.connections.iter().any(|e| match e { &ConnectionEntry::Handshake(ref h) => h.id.eq(&id), _ => false }) + } + + fn connect_peers(&mut self, event_loop: &mut EventLoop) { + + struct NodeInfo { + id: NodeId, + peer_type: PeerType + } + + let mut to_connect: Vec = Vec::new(); + + let mut req_conn = 0; + //TODO: use nodes from discovery here + //for n in self.node_buckets.iter().flat_map(|n| &n.nodes).map(|id| NodeInfo { id: id.clone(), peer_type: self.nodes.get(id).unwrap().peer_type}) { + for n in self.nodes.values().map(|n| NodeInfo { id: n.id.clone(), peer_type: n.peer_type }) { + let connected = self.have_session(&n.id) || self.connecting_to(&n.id); + let required = n.peer_type == PeerType::Required; + if connected && required { + req_conn += 1; + } + else if !connected && (!self.info.config.pin || required) { + to_connect.push(n); + } + } + + for n in to_connect.iter() { + if n.peer_type == PeerType::Required { + if req_conn < IDEAL_PEERS { + self.connect_peer(&n.id, event_loop); + } + req_conn += 1; + } + } + + if !self.info.config.pin + { + let pending_count = 0; //TODO: + let peer_count = 0; + let mut open_slots = IDEAL_PEERS - peer_count - pending_count + req_conn; + if open_slots > 0 { + for n in to_connect.iter() { + if n.peer_type == PeerType::Optional && open_slots > 0 { + open_slots -= 1; + self.connect_peer(&n.id, event_loop); + } + } + } + } + } + + fn connect_peer(&mut self, id: &NodeId, event_loop: &mut EventLoop) { + if self.have_session(id) + { + warn!("Aborted connect. Node already connected."); + return; + } + if self.connecting_to(id) + { + warn!("Aborted connect. Node already connecting."); + return; + } + + let socket = { + let node = self.nodes.get_mut(id).unwrap(); + node.last_attempted = Some(::time::now()); + + match TcpStream::connect(&node.endpoint.address) { + Ok(socket) => socket, + Err(_) => { + warn!("Cannot connect to node"); + return; + } + } + }; + + let nonce = self.info.next_nonce(); + match self.connections.insert_with(|token| ConnectionEntry::Handshake(Handshake::new(token, id, socket, &nonce).expect("Can't create handshake"))) { + Some(token) => { + match self.connections.get_mut(token) { + Some(&mut ConnectionEntry::Handshake(ref mut h)) => { + h.start(&self.info, true) + .and_then(|_| h.register(event_loop)) + .unwrap_or_else (|e| { + debug!(target: "net", "Handshake create error: {:?}", e); + }); + }, + _ => {} + } + }, + None => { warn!("Max connections reached") } + } + } + + + fn accept(&mut self, _event_loop: &mut EventLoop) { + warn!(target: "net", "accept"); + } + + fn connection_writable(&mut self, token: Token, event_loop: &mut EventLoop) { + let mut kill = false; + let mut create_session = false; + match self.connections.get_mut(token) { + Some(&mut ConnectionEntry::Handshake(ref mut h)) => { + h.writable(event_loop, &self.info).unwrap_or_else(|e| { + debug!(target: "net", "Handshake write error: {:?}", e); + kill = true; + }); + create_session = h.done(); + }, + Some(&mut ConnectionEntry::Session(ref mut s)) => { + s.writable(event_loop, &self.info).unwrap_or_else(|e| { + debug!(target: "net", "Session write error: {:?}", e); + kill = true; + }); + } + _ => { + warn!(target: "net", "Received event for unknown connection"); + } + } + if kill { + self.kill_connection(token, event_loop); + } + if create_session { + self.start_session(token, event_loop); + } + } + + + fn connection_readable(&mut self, token: Token, event_loop: &mut EventLoop) { + let mut kill = false; + let mut create_session = false; + let mut ready_data: Vec = Vec::new(); + let mut packet_data: Option<(ProtocolId, PacketId, Vec)> = None; + match self.connections.get_mut(token) { + Some(&mut ConnectionEntry::Handshake(ref mut h)) => { + h.readable(event_loop, &self.info).unwrap_or_else(|e| { + debug!(target: "net", "Handshake read error: {:?}", e); + kill = true; + }); + create_session = h.done(); + }, + Some(&mut ConnectionEntry::Session(ref mut s)) => { + let sd = { s.readable(event_loop, &self.info).unwrap_or_else(|e| { + debug!(target: "net", "Session read error: {:?}", e); + kill = true; + SessionData::None + }) }; + match sd { + SessionData::Ready => { + for (p, _) in self.handlers.iter_mut() { + if s.have_capability(p) { + ready_data.push(p); + } + } + }, + SessionData::Packet { + data, + protocol, + packet_id, + } => { + match self.handlers.get_mut(protocol) { + None => { warn!(target: "net", "No handler found for protocol: {:?}", protocol) }, + Some(_) => packet_data = Some((protocol, packet_id, data)), + } + }, + SessionData::None => {}, + } + } + _ => { + warn!(target: "net", "Received event for unknown connection"); + } + } + if kill { + self.kill_connection(token, event_loop); + } + if create_session { + self.start_session(token, event_loop); + } + for p in ready_data { + let mut h = self.handlers.get_mut(p).unwrap(); + h.connected(&mut HostIo::new(p, Some(token), event_loop, &mut self.connections, &mut self.timers), &token.as_usize()); + } + if let Some((p, packet_id, data)) = packet_data { + let mut h = self.handlers.get_mut(p).unwrap(); + h.read(&mut HostIo::new(p, Some(token), event_loop, &mut self.connections, &mut self.timers), &token.as_usize(), packet_id, &data[1..]); + } + + } + + fn start_session(&mut self, token: Token, event_loop: &mut EventLoop) { + let info = &self.info; + self.connections.replace_with(token, |c| { + match c { + ConnectionEntry::Handshake(h) => Session::new(h, event_loop, info) + .map(|s| Some(ConnectionEntry::Session(s))) + .unwrap_or_else(|e| { + debug!(target: "net", "Session construction error: {:?}", e); + None + }), + _ => { panic!("No handshake to create a session from"); } + } + }).expect("Error updating slab with session"); + } + + fn connection_timeout(&mut self, token: Token, event_loop: &mut EventLoop) { + self.kill_connection(token, event_loop) + } + + fn kill_connection(&mut self, token: Token, _event_loop: &mut EventLoop) { + self.connections.remove(token); + } +} + +impl Handler for Host { + type Timeout = Token; + type Message = HostMessage; + + fn ready(&mut self, event_loop: &mut EventLoop, token: Token, events: EventSet) { + if events.is_readable() { + match token.as_usize() { + TCP_ACCEPT => self.accept(event_loop), + IDLE => self.maintain_network(event_loop), + FIRST_CONNECTION ... LAST_CONNECTION => self.connection_readable(token, event_loop), + NODETABLE_RECEIVE => {}, + _ => panic!("Received unknown readable token"), + } + } + else if events.is_writable() { + match token.as_usize() { + FIRST_CONNECTION ... LAST_CONNECTION => self.connection_writable(token, event_loop), + _ => panic!("Received unknown writable token"), + } + } + } + + fn timeout(&mut self, event_loop: &mut EventLoop, token: Token) { + match token.as_usize() { + IDLE => self.maintain_network(event_loop), + FIRST_CONNECTION ... LAST_CONNECTION => self.connection_timeout(token, event_loop), + NODETABLE_DISCOVERY => {}, + NODETABLE_MAINTAIN => {}, + USER_TIMER ... LAST_USER_TIMER => { + let (protocol, delay) = { + let timer = self.timers.get_mut(token).expect("Unknown user timer token"); + (timer.protocol, timer.delay) + }; + match self.handlers.get_mut(protocol) { + None => { warn!(target: "net", "No handler found for protocol: {:?}", protocol) }, + Some(h) => { + h.timeout(&mut HostIo::new(protocol, None, event_loop, &mut self.connections, &mut self.timers), token.as_usize()); + event_loop.timeout_ms(token, delay).expect("Error re-registering user timer"); + } + } + } + _ => panic!("Unknown timer token"), + } + } + + fn notify(&mut self, event_loop: &mut EventLoop, msg: Self::Message) { + match msg { + HostMessage::Shutdown => event_loop.shutdown(), + HostMessage::AddHandler { + handler, + protocol, + versions + } => { + self.handlers.insert(protocol, handler); + for v in versions { + self.info.capabilities.push(CapabilityInfo { protocol: protocol, version: v, packet_count:0 }); + } + }, + HostMessage::Send { + peer, + packet_id, + protocol, + data, + } => { + match self.connections.get_mut(Token(peer as usize)) { + Some(&mut ConnectionEntry::Session(ref mut s)) => { + s.send_packet(protocol, packet_id as u8, &data).unwrap_or_else(|e| { + warn!(target: "net", "Send error: {:?}", e); + }); //TODO: don't copy vector data + }, + _ => { + warn!(target: "net", "Send: Peer does not exist"); + } + } + }, + HostMessage::UserMessage(message) => { + for (p, h) in self.handlers.iter_mut() { + if p != &message.protocol { + h.message(&mut HostIo::new(message.protocol, None, event_loop, &mut self.connections, &mut self.timers), &message); + } + } + } + } + } +} diff --git a/src/network/mod.rs b/src/network/mod.rs new file mode 100644 index 000000000..df0da2c13 --- /dev/null +++ b/src/network/mod.rs @@ -0,0 +1,144 @@ +/// Network and general IO module. +/// +/// Example usage for craeting a network service and adding an IO handler: +/// +/// ```rust +/// extern crate ethcore_util as util; +/// use util::network::*; +/// +/// struct MyHandler; +/// +/// impl ProtocolHandler for MyHandler { +/// fn initialize(&mut self, io: &mut HandlerIo) { +/// io.register_timer(1000); +/// } +/// +/// fn read(&mut self, io: &mut HandlerIo, peer: &PeerId, packet_id: u8, data: &[u8]) { +/// println!("Received {} ({} bytes) from {}", packet_id, data.len(), peer); +/// } +/// +/// fn connected(&mut self, io: &mut HandlerIo, peer: &PeerId) { +/// println!("Connected {}", peer); +/// } +/// +/// fn disconnected(&mut self, io: &mut HandlerIo, peer: &PeerId) { +/// println!("Disconnected {}", peer); +/// } +/// +/// fn timeout(&mut self, io: &mut HandlerIo, timer: TimerToken) { +/// println!("Timeout {}", timer); +/// } +/// +/// fn message(&mut self, io: &mut HandlerIo, message: &Message) { +/// println!("Message {}:{}", message.protocol, message.id); +/// } +/// } +/// +/// fn main () { +/// let mut service = NetworkService::start().expect("Error creating network service"); +/// service.register_protocol(Box::new(MyHandler), "myproto", &[1u8]); +/// +/// // Wait for quit condition +/// // ... +/// // Drop the service +/// } +/// ``` +extern crate mio; +mod host; +mod connection; +mod handshake; +mod session; +mod discovery; +mod service; + +#[derive(Debug, Copy, Clone)] +pub enum DisconnectReason +{ + DisconnectRequested, + TCPError, + BadProtocol, + UselessPeer, + TooManyPeers, + DuplicatePeer, + IncompatibleProtocol, + NullIdentity, + ClientQuit, + UnexpectedIdentity, + LocalIdentity, + PingTimeout, +} + +#[derive(Debug)] +pub enum Error { + Crypto(::crypto::CryptoError), + Io(::std::io::Error), + Auth, + BadProtocol, + AddressParse(::std::net::AddrParseError), + AddressResolve(Option<::std::io::Error>), + NodeIdParse(::error::EthcoreError), + PeerNotFound, + Disconnect(DisconnectReason) +} + +impl From<::std::io::Error> for Error { + fn from(err: ::std::io::Error) -> Error { + Error::Io(err) + } +} + +impl From<::crypto::CryptoError> for Error { + fn from(err: ::crypto::CryptoError) -> Error { + Error::Crypto(err) + } +} + +impl From<::std::net::AddrParseError> for Error { + fn from(err: ::std::net::AddrParseError) -> Error { + Error::AddressParse(err) + } +} +impl From<::error::EthcoreError> for Error { + fn from(err: ::error::EthcoreError) -> Error { + Error::NodeIdParse(err) + } +} +impl From<::rlp::DecoderError> for Error { + fn from(_err: ::rlp::DecoderError) -> Error { + Error::Auth + } +} + +impl From<::mio::NotifyError> for Error { + fn from(_err: ::mio::NotifyError) -> Error { + Error::Io(::std::io::Error::new(::std::io::ErrorKind::ConnectionAborted, "Network IO notification error")) + } +} + +pub type PeerId = host::PeerId; +pub type PacketId = host::PacketId; +pub type TimerToken = host::TimerToken; +pub type HandlerIo<'s> = host::HostIo<'s>; +pub type Message = host::UserMessage; +pub type MessageId = host::UserMessageId; + +/// Network IO protocol handler. This needs to be implemented for each new subprotocol. +/// TODO: Separate p2p networking IO from IPC IO. `timeout` and `message` should go to a more genera IO provider. +/// All the handler function are called from within IO event loop. +pub trait ProtocolHandler: Send { + /// Initialize the hadler + fn initialize(&mut self, io: &mut HandlerIo); + /// Called when new network packet received. + fn read(&mut self, io: &mut HandlerIo, peer: &PeerId, packet_id: u8, data: &[u8]); + /// Called when new peer is connected. Only called when peer supports the same protocol. + fn connected(&mut self, io: &mut HandlerIo, peer: &PeerId); + /// Called when a previously connected peer disconnects. + fn disconnected(&mut self, io: &mut HandlerIo, peer: &PeerId); + /// Timer function called after a timeout created with `HandlerIo::timeout`. + fn timeout(&mut self, io: &mut HandlerIo, timer: TimerToken); + /// Called when a broadcasted message is received. The message can only be sent from a different protocol handler. + fn message(&mut self, io: &mut HandlerIo, message: &Message); +} + +pub type NetworkService = service::NetworkService; + diff --git a/src/network/service.rs b/src/network/service.rs new file mode 100644 index 000000000..7598ffdd6 --- /dev/null +++ b/src/network/service.rs @@ -0,0 +1,54 @@ +use std::thread::{self, JoinHandle}; +use mio::*; +use network::{Error, ProtocolHandler}; +use network::host::{Host, HostMessage, PeerId, PacketId, ProtocolId}; + +/// IO Service with networking +pub struct NetworkService { + thread: Option>, + host_channel: Sender +} + +impl NetworkService { + /// Starts IO event loop + pub fn start() -> Result { + let mut event_loop = EventLoop::new().unwrap(); + let channel = event_loop.channel(); + let thread = thread::spawn(move || { + Host::start(&mut event_loop).unwrap(); //TODO: + }); + Ok(NetworkService { + thread: Some(thread), + host_channel: channel + }) + } + + /// Send a message over the network. Normaly `HostIo::send` should be used. This can be used from non-io threads. + pub fn send(&mut self, peer: &PeerId, packet_id: PacketId, protocol: ProtocolId, data: &[u8]) -> Result<(), Error> { + try!(self.host_channel.send(HostMessage::Send { + peer: *peer, + packet_id: packet_id, + protocol: protocol, + data: data.to_vec() + })); + Ok(()) + } + + /// Regiter a new protocol handler with the event loop. + pub fn register_protocol(&mut self, handler: Box, protocol: ProtocolId, versions: &[u8]) -> Result<(), Error> { + try!(self.host_channel.send(HostMessage::AddHandler { + handler: handler, + protocol: protocol, + versions: versions.to_vec(), + })); + Ok(()) + } +} + +impl Drop for NetworkService { + fn drop(&mut self) { + self.host_channel.send(HostMessage::Shutdown).unwrap(); + self.thread.take().unwrap().join().unwrap(); + } +} + diff --git a/src/network/session.rs b/src/network/session.rs new file mode 100644 index 000000000..720902150 --- /dev/null +++ b/src/network/session.rs @@ -0,0 +1,247 @@ +use mio::*; +use hash::*; +use rlp::*; +use network::connection::{EncryptedConnection, Packet}; +use network::handshake::Handshake; +use network::{Error, DisconnectReason}; +use network::host::*; + +pub struct Session { + pub info: SessionInfo, + connection: EncryptedConnection, + had_hello: bool, +} + +pub enum SessionData { + None, + Ready, + Packet { + data: Vec, + protocol: &'static str, + packet_id: u8, + }, +} + +pub struct SessionInfo { + pub id: NodeId, + pub client_version: String, + pub protocol_version: u32, + pub capabilities: Vec, +} + +#[derive(Debug, PartialEq, Eq)] +pub struct PeerCapabilityInfo { + pub protocol: String, + pub version: u8, +} + +impl Decodable for PeerCapabilityInfo { + fn decode(decoder: &D) -> Result where D: Decoder { + let c = try!(decoder.as_list()); + let v: u32 = try!(Decodable::decode(&c[1])); + Ok(PeerCapabilityInfo { + protocol: try!(Decodable::decode(&c[0])), + version: v as u8, + }) + } +} + +#[derive(Debug, PartialEq, Eq)] +pub struct SessionCapabilityInfo { + pub protocol: &'static str, + pub version: u8, + pub packet_count: u8, + pub id_offset: u8, +} + +const PACKET_HELLO: u8 = 0x80; +const PACKET_DISCONNECT: u8 = 0x01; +const PACKET_PING: u8 = 0x02; +const PACKET_PONG: u8 = 0x03; +const PACKET_GET_PEERS: u8 = 0x04; +const PACKET_PEERS: u8 = 0x05; +const PACKET_USER: u8 = 0x10; +const PACKET_LAST: u8 = 0x7f; + +impl Session { + pub fn new(h: Handshake, event_loop: &mut EventLoop, host: &HostInfo) -> Result { + let id = h.id.clone(); + let connection = try!(EncryptedConnection::new(h)); + let mut session = Session { + connection: connection, + had_hello: false, + info: SessionInfo { + id: id, + client_version: String::new(), + protocol_version: 0, + capabilities: Vec::new(), + }, + }; + try!(session.write_hello(host)); + try!(session.write_ping()); + try!(session.connection.register(event_loop)); + Ok(session) + } + + pub fn readable(&mut self, event_loop: &mut EventLoop, host: &HostInfo) -> Result { + match try!(self.connection.readable(event_loop)) { + Some(data) => self.read_packet(data, host), + None => Ok(SessionData::None) + } + } + + pub fn writable(&mut self, event_loop: &mut EventLoop, _host: &HostInfo) -> Result<(), Error> { + self.connection.writable(event_loop) + } + + pub fn have_capability(&self, protocol: &str) -> bool { + self.info.capabilities.iter().any(|c| c.protocol == protocol) + } + + pub fn send_packet(&mut self, protocol: &str, packet_id: u8, data: &[u8]) -> Result<(), Error> { + let mut i = 0usize; + while protocol != self.info.capabilities[i].protocol { + i += 1; + if i == self.info.capabilities.len() { + debug!(target: "net", "Unkown protocol: {:?}", protocol); + return Ok(()) + } + } + let pid = self.info.capabilities[i].id_offset + packet_id; + let mut rlp = RlpStream::new(); + rlp.append(&(pid as u32)); + rlp.append_raw(data, 1); + self.connection.send_packet(&rlp.out()) + } + + fn read_packet(&mut self, packet: Packet, host: &HostInfo) -> Result { + if packet.data.len() < 2 { + return Err(Error::BadProtocol); + } + let packet_id = packet.data[0]; + if packet_id != PACKET_HELLO && packet_id != PACKET_DISCONNECT && !self.had_hello { + return Err(Error::BadProtocol); + } + match packet_id { + PACKET_HELLO => { + let rlp = UntrustedRlp::new(&packet.data[1..]); //TODO: validate rlp expected size + try!(self.read_hello(&rlp, host)); + Ok(SessionData::Ready) + }, + PACKET_DISCONNECT => Err(Error::Disconnect(DisconnectReason::DisconnectRequested)), + PACKET_PING => { + try!(self.write_pong()); + Ok(SessionData::None) + }, + PACKET_GET_PEERS => Ok(SessionData::None), //TODO; + PACKET_PEERS => Ok(SessionData::None), + PACKET_USER ... PACKET_LAST => { + let mut i = 0usize; + while packet_id < self.info.capabilities[i].id_offset { + i += 1; + if i == self.info.capabilities.len() { + debug!(target: "net", "Unkown packet: {:?}", packet_id); + return Ok(SessionData::None) + } + } + + // map to protocol + let protocol = self.info.capabilities[i].protocol; + let pid = packet_id - self.info.capabilities[i].id_offset; + return Ok(SessionData::Packet { data: packet.data, protocol: protocol, packet_id: pid } ) + }, + _ => { + debug!(target: "net", "Unkown packet: {:?}", packet_id); + Ok(SessionData::None) + } + } + } + + fn write_hello(&mut self, host: &HostInfo) -> Result<(), Error> { + let mut rlp = RlpStream::new(); + rlp.append(&(PACKET_HELLO as u32)); + rlp.append_list(5) + .append(&host.protocol_version) + .append(&host.client_version) + .append(&host.capabilities) + .append(&host.listen_port) + .append(host.id()); + self.connection.send_packet(&rlp.out()) + } + + fn read_hello(&mut self, rlp: &UntrustedRlp, host: &HostInfo) -> Result<(), Error> { + let protocol = try!(rlp.val_at::(0)); + let client_version = try!(rlp.val_at::(1)); + let peer_caps = try!(rlp.val_at::>(2)); + let id = try!(rlp.val_at::(4)); + + // Intersect with host capabilities + // Leave only highset mutually supported capability version + let mut caps: Vec = Vec::new(); + for hc in host.capabilities.iter() { + if peer_caps.iter().any(|c| c.protocol == hc.protocol && c.version == hc.version) { + caps.push(SessionCapabilityInfo { + protocol: hc.protocol, + version: hc.version, + id_offset: 0, + packet_count: hc.packet_count, + }); + } + } + + caps.retain(|c| host.capabilities.iter().any(|hc| hc.protocol == c.protocol && hc.version == c.version)); + let mut i = 0; + while i < caps.len() { + if caps.iter().any(|c| c.protocol == caps[i].protocol && c.version > caps[i].version) { + caps.remove(i); + } + else { + i += 1; + } + } + + i = 0; + let mut offset: u8 = PACKET_USER; + while i < caps.len() { + caps[i].id_offset = offset; + offset += caps[i].packet_count; + i += 1; + } + trace!(target: "net", "Hello: {} v{} {} {:?}", client_version, protocol, id, caps); + self.info.capabilities = caps; + if protocol != host.protocol_version { + return Err(self.disconnect(DisconnectReason::UselessPeer)); + } + self.had_hello = true; + Ok(()) + } + + fn write_ping(&mut self) -> Result<(), Error> { + self.send(try!(Session::prepare(PACKET_PING, 0))) + } + + fn write_pong(&mut self) -> Result<(), Error> { + self.send(try!(Session::prepare(PACKET_PONG, 0))) + } + + fn disconnect(&mut self, reason: DisconnectReason) -> Error { + let mut rlp = RlpStream::new(); + rlp.append(&(PACKET_DISCONNECT as u32)); + rlp.append_list(1); + rlp.append(&(reason.clone() as u32)); + self.connection.send_packet(&rlp.out()).ok(); + Error::Disconnect(reason) + } + + fn prepare(packet_id: u8, items: usize) -> Result { + let mut rlp = RlpStream::new_list(1); + rlp.append(&(packet_id as u32)); + rlp.append_list(items); + Ok(rlp) + } + + fn send(&mut self, rlp: RlpStream) -> Result<(), Error> { + self.connection.send_packet(&rlp.out()) + } +} + diff --git a/src/rlp/mod.rs b/src/rlp/mod.rs index bc2318fc4..39cc3a7e7 100644 --- a/src/rlp/mod.rs +++ b/src/rlp/mod.rs @@ -1,10 +1,10 @@ -//! Rlp serialization module -//! -//! Allows encoding, decoding, and view onto rlp-slice +//! Rlp serialization module +//! +//! Allows encoding, decoding, and view onto rlp-slice //! //!# What should you use when? //! -//!### Use `encode` function when: +//!### Use `encode` function when: //! * You want to encode something inline. //! * You do not work on big set of data. //! * You want to encode whole data structure at once. @@ -23,7 +23,7 @@ //! * You want to get view onto rlp-slice. //! * You don't want to decode whole rlp at once. //! -//!### Use `UntrustedRlp` when: +//!### Use `UntrustedRlp` when: //! * You are working on untrusted data (~corrupted). //! * You need to handle data corruption errors. //! * You are working on input data. @@ -47,14 +47,16 @@ pub use self::rlpstream::{RlpStream}; use super::hash::H256; pub const NULL_RLP: [u8; 1] = [0x80; 1]; +pub const EMPTY_LIST_RLP: [u8; 1] = [0xC0; 1]; pub const SHA3_NULL_RLP: H256 = H256( [0x56, 0xe8, 0x1f, 0x17, 0x1b, 0xcc, 0x55, 0xa6, 0xff, 0x83, 0x45, 0xe6, 0x92, 0xc0, 0xf8, 0x6e, 0x5b, 0x48, 0xe0, 0x1b, 0x99, 0x6c, 0xad, 0xc0, 0x01, 0x62, 0x2f, 0xb5, 0xe3, 0x63, 0xb4, 0x21] ); +pub const SHA3_EMPTY_LIST_RLP: H256 = H256( [0x1d, 0xcc, 0x4d, 0xe8, 0xde, 0xc7, 0x5d, 0x7a, 0xab, 0x85, 0xb5, 0x67, 0xb6, 0xcc, 0xd4, 0x1a, 0xd3, 0x12, 0x45, 0x1b, 0x94, 0x8a, 0x74, 0x13, 0xf0, 0xa1, 0x42, 0xfd, 0x40, 0xd4, 0x93, 0x47] ); /// Shortcut function to decode trusted rlp -/// +/// /// ```rust /// extern crate ethcore_util as util; /// use util::rlp::*; -/// +/// /// fn main () { /// let data = vec![0xc8, 0x83, b'c', b'a', b't', 0x83, b'd', b'o', b'g']; /// let animals: Vec = decode(&data); @@ -71,7 +73,7 @@ pub fn decode(bytes: &[u8]) -> T where T: Decodable { /// ```rust /// extern crate ethcore_util as util; /// use util::rlp::*; -/// +/// /// fn main () { /// let animals = vec!["cat", "dog"]; /// let out = encode(&animals); diff --git a/src/rlp/rlpin.rs b/src/rlp/rlpin.rs index ac830cc9c..2179643d1 100644 --- a/src/rlp/rlpin.rs +++ b/src/rlp/rlpin.rs @@ -29,8 +29,8 @@ impl<'a, 'view> View<'a, 'view> for Rlp<'a> where 'a: 'view { } } - fn raw(&'view self) -> &'a [u8] { - self.rlp.raw() + fn as_raw(&'view self) -> &'a [u8] { + self.rlp.as_raw() } fn prototype(&self) -> Self::Prototype { diff --git a/src/rlp/rlpstream.rs b/src/rlp/rlpstream.rs index 3a19cab7d..97ae1b484 100644 --- a/src/rlp/rlpstream.rs +++ b/src/rlp/rlpstream.rs @@ -26,7 +26,7 @@ pub struct RlpStream { } impl Stream for RlpStream { - fn new() -> Self { + fn new() -> Self { RlpStream { unfinished_lists: ElasticArray16::new(), encoder: BasicEncoder::new(), @@ -57,7 +57,7 @@ impl Stream for RlpStream { self.encoder.bytes.push(0xc0u8); self.note_appended(1); }, - _ => { + _ => { let position = self.encoder.bytes.len(); self.unfinished_lists.push(ListInfo::new(position, len)); }, @@ -66,7 +66,7 @@ impl Stream for RlpStream { // return chainable self self } - + fn append_empty_data<'a>(&'a mut self) -> &'a mut RlpStream { // self push raw item self.encoder.bytes.push(0x80); @@ -80,7 +80,7 @@ impl Stream for RlpStream { fn append_raw<'a>(&'a mut self, bytes: &[u8], item_count: usize) -> &'a mut RlpStream { // push raw items - self.encoder.bytes.append_slice(bytes); + self.encoder.bytes.append_slice(bytes); // try to finish and prepend the length self.note_appended(item_count); @@ -101,7 +101,7 @@ impl Stream for RlpStream { self.unfinished_lists.len() == 0 } - fn raw(&self) -> &[u8] { + fn as_raw(&self) -> &[u8] { &self.encoder.bytes } diff --git a/src/rlp/rlptraits.rs b/src/rlp/rlptraits.rs index 5955e132d..407d62daf 100644 --- a/src/rlp/rlptraits.rs +++ b/src/rlp/rlptraits.rs @@ -1,11 +1,11 @@ -use rlp::DecoderError; +use rlp::{DecoderError, UntrustedRlp}; pub trait Decoder: Sized { fn read_value(&self, f: F) -> Result where F: FnOnce(&[u8]) -> Result; fn as_list(&self) -> Result, DecoderError>; - + fn as_rlp<'a>(&'a self) -> &'a UntrustedRlp<'a>; fn as_raw(&self) -> &[u8]; } @@ -24,19 +24,19 @@ pub trait View<'a, 'view>: Sized { fn new(bytes: &'a [u8]) -> Self; /// The raw data of the RLP. - /// + /// /// ```rust /// extern crate ethcore_util as util; /// use util::rlp::*; - /// + /// /// fn main () { /// let data = vec![0xc8, 0x83, b'c', b'a', b't', 0x83, b'd', b'o', b'g']; /// let rlp = Rlp::new(&data); - /// let dog = rlp.at(1).raw(); + /// let dog = rlp.at(1).as_raw(); /// assert_eq!(dog, &[0x83, b'd', b'o', b'g']); /// } /// ``` - fn raw(&'view self) -> &'a [u8]; + fn as_raw(&'view self) -> &'a [u8]; /// Get the prototype of the RLP. fn prototype(&self) -> Self::Prototype; @@ -46,11 +46,11 @@ pub trait View<'a, 'view>: Sized { fn data(&'view self) -> Self::Data; /// Returns number of RLP items. - /// + /// /// ```rust /// extern crate ethcore_util as util; /// use util::rlp::*; - /// + /// /// fn main () { /// let data = vec![0xc8, 0x83, b'c', b'a', b't', 0x83, b'd', b'o', b'g']; /// let rlp = Rlp::new(&data); @@ -62,11 +62,11 @@ pub trait View<'a, 'view>: Sized { fn item_count(&self) -> usize; /// Returns the number of bytes in the data, or zero if it isn't data. - /// + /// /// ```rust /// extern crate ethcore_util as util; /// use util::rlp::*; - /// + /// /// fn main () { /// let data = vec![0xc8, 0x83, b'c', b'a', b't', 0x83, b'd', b'o', b'g']; /// let rlp = Rlp::new(&data); @@ -78,14 +78,14 @@ pub trait View<'a, 'view>: Sized { fn size(&self) -> usize; /// Get view onto RLP-slice at index. - /// + /// /// Caches offset to given index, so access to successive /// slices is faster. - /// + /// /// ```rust /// extern crate ethcore_util as util; /// use util::rlp::*; - /// + /// /// fn main () { /// let data = vec![0xc8, 0x83, b'c', b'a', b't', 0x83, b'd', b'o', b'g']; /// let rlp = Rlp::new(&data); @@ -95,11 +95,11 @@ pub trait View<'a, 'view>: Sized { fn at(&'view self, index: usize) -> Self::Item; /// No value - /// + /// /// ```rust /// extern crate ethcore_util as util; /// use util::rlp::*; - /// + /// /// fn main () { /// let data = vec![]; /// let rlp = Rlp::new(&data); @@ -109,11 +109,11 @@ pub trait View<'a, 'view>: Sized { fn is_null(&self) -> bool; /// Contains a zero-length string or zero-length list. - /// + /// /// ```rust /// extern crate ethcore_util as util; /// use util::rlp::*; - /// + /// /// fn main () { /// let data = vec![0xc0]; /// let rlp = Rlp::new(&data); @@ -123,11 +123,11 @@ pub trait View<'a, 'view>: Sized { fn is_empty(&self) -> bool; /// List value - /// + /// /// ```rust /// extern crate ethcore_util as util; /// use util::rlp::*; - /// + /// /// fn main () { /// let data = vec![0xc8, 0x83, b'c', b'a', b't', 0x83, b'd', b'o', b'g']; /// let rlp = Rlp::new(&data); @@ -137,11 +137,11 @@ pub trait View<'a, 'view>: Sized { fn is_list(&self) -> bool; /// String value - /// + /// /// ```rust /// extern crate ethcore_util as util; /// use util::rlp::*; - /// + /// /// fn main () { /// let data = vec![0xc8, 0x83, b'c', b'a', b't', 0x83, b'd', b'o', b'g']; /// let rlp = Rlp::new(&data); @@ -151,11 +151,11 @@ pub trait View<'a, 'view>: Sized { fn is_data(&self) -> bool; /// Int value - /// + /// /// ```rust /// extern crate ethcore_util as util; /// use util::rlp::*; - /// + /// /// fn main () { /// let data = vec![0xc1, 0x10]; /// let rlp = Rlp::new(&data); @@ -166,11 +166,11 @@ pub trait View<'a, 'view>: Sized { fn is_int(&self) -> bool; /// Get iterator over rlp-slices - /// + /// /// ```rust /// extern crate ethcore_util as util; /// use util::rlp::*; - /// + /// /// fn main () { /// let data = vec![0xc8, 0x83, b'c', b'a', b't', 0x83, b'd', b'o', b'g']; /// let rlp = Rlp::new(&data); @@ -207,7 +207,7 @@ pub trait Stream: Sized { /// ```rust /// extern crate ethcore_util as util; /// use util::rlp::*; - /// + /// /// fn main () { /// let mut stream = RlpStream::new_list(2); /// stream.append(&"cat").append(&"dog"); @@ -222,11 +222,11 @@ pub trait Stream: Sized { /// ```rust /// extern crate ethcore_util as util; /// use util::rlp::*; - /// + /// /// fn main () { /// let mut stream = RlpStream::new_list(2); /// stream.append_list(2).append(&"cat").append(&"dog"); - /// stream.append(&""); + /// stream.append(&""); /// let out = stream.out(); /// assert_eq!(out, vec![0xca, 0xc8, 0x83, b'c', b'a', b't', 0x83, b'd', b'o', b'g', 0x80]); /// } @@ -238,7 +238,7 @@ pub trait Stream: Sized { /// ```rust /// extern crate ethcore_util as util; /// use util::rlp::*; - /// + /// /// fn main () { /// let mut stream = RlpStream::new_list(2); /// stream.append_empty_data().append_empty_data(); @@ -252,11 +252,11 @@ pub trait Stream: Sized { fn append_raw<'a>(&'a mut self, bytes: &[u8], item_count: usize) -> &'a mut Self; /// Clear the output stream so far. - /// + /// /// ```rust /// extern crate ethcore_util as util; /// use util::rlp::*; - /// + /// /// fn main () { /// let mut stream = RlpStream::new_list(3); /// stream.append(&"cat"); @@ -272,7 +272,7 @@ pub trait Stream: Sized { /// ```rust /// extern crate ethcore_util as util; /// use util::rlp::*; - /// + /// /// fn main () { /// let mut stream = RlpStream::new_list(2); /// stream.append(&"cat"); @@ -284,10 +284,10 @@ pub trait Stream: Sized { /// } fn is_finished(&self) -> bool; - fn raw(&self) -> &[u8]; + fn as_raw(&self) -> &[u8]; /// Streams out encoded bytes. - /// + /// /// panic! if stream is not finished. fn out(self) -> Vec; } diff --git a/src/rlp/tests.rs b/src/rlp/tests.rs index 39af862ff..7c2099124 100644 --- a/src/rlp/tests.rs +++ b/src/rlp/tests.rs @@ -19,19 +19,19 @@ fn rlp_at() { let cat = rlp.at(0).unwrap(); assert!(cat.is_data()); - assert_eq!(cat.raw(), &[0x83, b'c', b'a', b't']); + assert_eq!(cat.as_raw(), &[0x83, b'c', b'a', b't']); //assert_eq!(String::decode_untrusted(&cat).unwrap(), "cat".to_string()); assert_eq!(cat.as_val::().unwrap(), "cat".to_string()); let dog = rlp.at(1).unwrap(); assert!(dog.is_data()); - assert_eq!(dog.raw(), &[0x83, b'd', b'o', b'g']); + assert_eq!(dog.as_raw(), &[0x83, b'd', b'o', b'g']); //assert_eq!(String::decode_untrusted(&dog).unwrap(), "dog".to_string()); assert_eq!(dog.as_val::().unwrap(), "dog".to_string()); let cat_again = rlp.at(0).unwrap(); assert!(cat_again.is_data()); - assert_eq!(cat_again.raw(), &[0x83, b'c', b'a', b't']); + assert_eq!(cat_again.as_raw(), &[0x83, b'c', b'a', b't']); //assert_eq!(String::decode_untrusted(&cat_again).unwrap(), "cat".to_string()); assert_eq!(cat_again.as_val::().unwrap(), "cat".to_string()); } @@ -61,18 +61,18 @@ fn rlp_iter() { let cat = iter.next().unwrap(); assert!(cat.is_data()); - assert_eq!(cat.raw(), &[0x83, b'c', b'a', b't']); + assert_eq!(cat.as_raw(), &[0x83, b'c', b'a', b't']); let dog = iter.next().unwrap(); assert!(dog.is_data()); - assert_eq!(dog.raw(), &[0x83, b'd', b'o', b'g']); + assert_eq!(dog.as_raw(), &[0x83, b'd', b'o', b'g']); let none = iter.next(); assert!(none.is_none()); let cat_again = rlp.at(0).unwrap(); assert!(cat_again.is_data()); - assert_eq!(cat_again.raw(), &[0x83, b'c', b'a', b't']); + assert_eq!(cat_again.as_raw(), &[0x83, b'c', b'a', b't']); } } @@ -155,7 +155,7 @@ fn encode_address() { use hash::*; let tests = vec![ - ETestPair(Address::from_str("ef2d6d194084c2de36e0dabfce45d046b37d1106").unwrap(), + ETestPair(Address::from_str("ef2d6d194084c2de36e0dabfce45d046b37d1106").unwrap(), vec![0x94, 0xef, 0x2d, 0x6d, 0x19, 0x40, 0x84, 0xc2, 0xde, 0x36, 0xe0, 0xda, 0xbf, 0xce, 0x45, 0xd0, 0x46, 0xb3, 0x7d, 0x11, 0x06]) @@ -290,7 +290,7 @@ fn decode_untrusted_address() { use hash::*; let tests = vec![ - DTestPair(Address::from_str("ef2d6d194084c2de36e0dabfce45d046b37d1106").unwrap(), + DTestPair(Address::from_str("ef2d6d194084c2de36e0dabfce45d046b37d1106").unwrap(), vec![0x94, 0xef, 0x2d, 0x6d, 0x19, 0x40, 0x84, 0xc2, 0xde, 0x36, 0xe0, 0xda, 0xbf, 0xce, 0x45, 0xd0, 0x46, 0xb3, 0x7d, 0x11, 0x06]) diff --git a/src/rlp/untrusted_rlp.rs b/src/rlp/untrusted_rlp.rs index 68956d8d0..5a12cbc5e 100644 --- a/src/rlp/untrusted_rlp.rs +++ b/src/rlp/untrusted_rlp.rs @@ -41,22 +41,24 @@ impl PayloadInfo { } /// Data-oriented view onto rlp-slice. -/// +/// /// This is immutable structere. No operations change it. -/// +/// /// Should be used in places where, error handling is required, /// eg. on input #[derive(Debug)] pub struct UntrustedRlp<'a> { bytes: &'a [u8], - cache: Cell, + offset_cache: Cell, + count_cache: Cell>, } impl<'a> Clone for UntrustedRlp<'a> { fn clone(&self) -> UntrustedRlp<'a> { UntrustedRlp { bytes: self.bytes, - cache: Cell::new(OffsetCache::new(usize::max_value(), 0)) + offset_cache: self.offset_cache.clone(), + count_cache: self.count_cache.clone(), } } } @@ -72,11 +74,12 @@ impl<'a, 'view> View<'a, 'view> for UntrustedRlp<'a> where 'a: 'view { fn new(bytes: &'a [u8]) -> UntrustedRlp<'a> { UntrustedRlp { bytes: bytes, - cache: Cell::new(OffsetCache::new(usize::max_value(), 0)), + offset_cache: Cell::new(OffsetCache::new(usize::max_value(), 0)), + count_cache: Cell::new(None) } } - - fn raw(&'view self) -> &'a [u8] { + + fn as_raw(&'view self) -> &'a [u8] { self.bytes } @@ -102,7 +105,14 @@ impl<'a, 'view> View<'a, 'view> for UntrustedRlp<'a> where 'a: 'view { fn item_count(&self) -> usize { match self.is_list() { - true => self.iter().count(), + true => match self.count_cache.get() { + Some(c) => c, + None => { + let c = self.iter().count(); + self.count_cache.set(Some(c)); + c + } + }, false => 0 } } @@ -122,7 +132,7 @@ impl<'a, 'view> View<'a, 'view> for UntrustedRlp<'a> where 'a: 'view { // move to cached position if it's index is less or equal to // current search index, otherwise move to beginning of list - let c = self.cache.get(); + let c = self.offset_cache.get(); let (mut bytes, to_skip) = match c.index <= index { true => (try!(UntrustedRlp::consume(self.bytes, c.offset)), index - c.index), false => (try!(self.consume_list_prefix()), index), @@ -132,7 +142,7 @@ impl<'a, 'view> View<'a, 'view> for UntrustedRlp<'a> where 'a: 'view { bytes = try!(UntrustedRlp::consume_items(bytes, to_skip)); // update the cache - self.cache.set(OffsetCache::new(index, self.bytes.len() - bytes.len())); + self.offset_cache.set(OffsetCache::new(index, self.bytes.len() - bytes.len())); // construct new rlp let found = try!(BasicDecoder::payload_info(bytes)); @@ -284,7 +294,7 @@ impl<'a> Decoder for BasicDecoder<'a> { fn read_value(&self, f: F) -> Result where F: FnOnce(&[u8]) -> Result { - let bytes = self.rlp.raw(); + let bytes = self.rlp.as_raw(); match bytes.first().map(|&x| x) { // rlp is too short @@ -306,7 +316,7 @@ impl<'a> Decoder for BasicDecoder<'a> { } fn as_raw(&self) -> &[u8] { - self.rlp.raw() + self.rlp.as_raw() } fn as_list(&self) -> Result, DecoderError> { @@ -315,6 +325,10 @@ impl<'a> Decoder for BasicDecoder<'a> { .collect(); Ok(v) } + + fn as_rlp<'s>(&'s self) -> &'s UntrustedRlp<'s> { + &self.rlp + } } impl Decodable for T where T: FromBytes { @@ -364,7 +378,7 @@ macro_rules! impl_array_decodable { if decoders.len() != $len { return Err(DecoderError::RlpIncorrectListLen); } - + for i in 0..decoders.len() { result[i] = try!(T::decode(&decoders[i])); } diff --git a/src/sha3.rs b/src/sha3.rs index f3faa4e3b..5b0a0c6a4 100644 --- a/src/sha3.rs +++ b/src/sha3.rs @@ -3,22 +3,25 @@ use std::mem::uninitialized; use tiny_keccak::Keccak; use bytes::{BytesConvertable,Populatable}; -use hash::H256; +use hash::{H256, FixedHash}; /// Types implementing this trait are sha3able. -/// +/// /// ``` /// extern crate ethcore_util as util; /// use std::str::FromStr; /// use util::sha3::*; /// use util::hash::*; -/// +/// /// fn main() { /// assert_eq!([0u8; 0].sha3(), H256::from_str("c5d2460186f7233c927e7db2dcc703c0e500b653ca82273b7bfad8045d85a470").unwrap()); /// } /// ``` pub trait Hashable { fn sha3(&self) -> H256; + fn sha3_into(&self, dest: &mut [u8]) { + self.sha3().copy_to(dest); + } } impl Hashable for T where T: BytesConvertable { @@ -31,6 +34,11 @@ impl Hashable for T where T: BytesConvertable { ret } } + fn sha3_into(&self, dest: &mut [u8]) { + let mut keccak = Keccak::new_keccak256(); + keccak.update(self.bytes()); + keccak.finalize(dest); + } } #[test] diff --git a/src/trie/node.rs b/src/trie/node.rs index 03f35a86d..b5745b66f 100644 --- a/src/trie/node.rs +++ b/src/trie/node.rs @@ -25,13 +25,13 @@ impl<'a> Node<'a> { // fed back into this function or inline RLP which can be fed back into this function). Prototype::List(2) => match NibbleSlice::from_encoded(r.at(0).data()) { (slice, true) => Node::Leaf(slice, r.at(1).data()), - (slice, false) => Node::Extension(slice, r.at(1).raw()), + (slice, false) => Node::Extension(slice, r.at(1).as_raw()), }, // branch - first 16 are nodes, 17th is a value (or empty). Prototype::List(17) => { let mut nodes: [&'a [u8]; 16] = unsafe { ::std::mem::uninitialized() }; for i in 0..16 { - nodes[i] = r.at(i).raw(); + nodes[i] = r.at(i).as_raw(); } Node::Branch(nodes, if r.at(16).is_empty() { None } else { Some(r.at(16).data()) }) }, diff --git a/src/trie/triedbmut.rs b/src/trie/triedbmut.rs index b866bc988..03d2b59de 100644 --- a/src/trie/triedbmut.rs +++ b/src/trie/triedbmut.rs @@ -309,22 +309,22 @@ impl<'db> TrieDBMut<'db> { /// removal instructions from the backing database. fn take_node<'a, 'rlp_view>(&'a self, rlp: &'rlp_view Rlp<'a>, journal: &mut Journal) -> &'a [u8] where 'a: 'rlp_view { if rlp.is_list() { - trace!("take_node {:?} (inline)", rlp.raw().pretty()); - rlp.raw() + trace!("take_node {:?} (inline)", rlp.as_raw().pretty()); + rlp.as_raw() } else if rlp.is_data() && rlp.size() == 32 { let h = rlp.as_val(); let r = self.db.lookup(&h).unwrap_or_else(||{ - println!("Node not found! rlp={:?}, node_hash={:?}", rlp.raw().pretty(), h); + println!("Node not found! rlp={:?}, node_hash={:?}", rlp.as_raw().pretty(), h); println!("Journal: {:?}", journal); panic!(); }); - trace!("take_node {:?} (indirect for {:?})", rlp.raw().pretty(), r); + trace!("take_node {:?} (indirect for {:?})", rlp.as_raw().pretty(), r); journal.delete_node_sha3(h); r } else { - trace!("take_node {:?} (???)", rlp.raw().pretty()); + trace!("take_node {:?} (???)", rlp.as_raw().pretty()); panic!("Empty or invalid node given?"); } } @@ -350,7 +350,7 @@ impl<'db> TrieDBMut<'db> { for i in 0..17 { match index == i { // not us - leave alone. - false => { s.append_raw(old_rlp.at(i).raw(), 1); }, + false => { s.append_raw(old_rlp.at(i).as_raw(), 1); }, // branch-leaf entry - just replace. true if i == 16 => { s.append(&value); }, // original had empty slot - place a leaf there. @@ -384,13 +384,13 @@ impl<'db> TrieDBMut<'db> { // not us - empty. _ if index != i => { s.append_empty_data(); }, // branch-value: just replace. - true if i == 16 => { s.append_raw(old_rlp.at(1).raw(), 1); }, + true if i == 16 => { s.append_raw(old_rlp.at(1).as_raw(), 1); }, // direct extension: just replace. - false if existing_key.len() == 1 => { s.append_raw(old_rlp.at(1).raw(), 1); }, + false if existing_key.len() == 1 => { s.append_raw(old_rlp.at(1).as_raw(), 1); }, // original has empty slot. true => journal.new_node(Self::compose_leaf(&existing_key.mid(1), old_rlp.at(1).data()), &mut s), // additional work required after branching. - false => journal.new_node(Self::compose_extension(&existing_key.mid(1), old_rlp.at(1).raw()), &mut s), + false => journal.new_node(Self::compose_extension(&existing_key.mid(1), old_rlp.at(1).as_raw()), &mut s), } }; self.augmented(&s.out(), partial, value, journal) @@ -422,7 +422,7 @@ impl<'db> TrieDBMut<'db> { trace!("partially-shared-prefix (exist={:?}; new={:?}; cp={:?}): AUGMENT-AT-END", existing_key.len(), partial.len(), cp); // low (farther from root) - let low = Self::compose_raw(&existing_key.mid(cp), old_rlp.at(1).raw(), is_leaf); + let low = Self::compose_raw(&existing_key.mid(cp), old_rlp.at(1).as_raw(), is_leaf); let augmented_low = self.augmented(&low, &partial.mid(cp), value, journal); // high (closer to root) diff --git a/src/triehash.rs b/src/triehash.rs index 74900971c..de54bb3a7 100644 --- a/src/triehash.rs +++ b/src/triehash.rs @@ -1,5 +1,5 @@ //! Generetes trie root. -//! +//! //! This module should be used to generate trie root hash. use std::collections::BTreeMap; @@ -11,13 +11,13 @@ use rlp::{RlpStream, Stream}; use vector::SharedPrefix; /// Generates a trie root hash for a vector of values -/// +/// /// ```rust /// extern crate ethcore_util as util; /// use std::str::FromStr; /// use util::triehash::*; /// use util::hash::*; -/// +/// /// fn main() { /// let v = vec![From::from("doe"), From::from("reindeer")]; /// let root = "e766d5d51b89dc39d981b41bda63248d7abce4f0225eefd023792a540bcffee3"; @@ -49,7 +49,7 @@ pub fn ordered_trie_root(input: Vec>) -> H256 { /// use std::str::FromStr; /// use util::triehash::*; /// use util::hash::*; -/// +/// /// fn main() { /// let v = vec![ /// (From::from("doe"), From::from("reindeer")), @@ -121,9 +121,9 @@ fn gen_trie_root(input: Vec<(Vec, Vec)>) -> H256 { /// Hex-prefix Notation. First nibble has flags: oddness = 2^0 & termination = 2^1. /// /// The "termination marker" and "leaf-node" specifier are completely equivalent. -/// +/// /// Input values are in range `[0, 0xf]`. -/// +/// /// ```markdown /// [0,0,1,2,3,4,5] 0x10012345 // 7 > 4 /// [0,1,2,3,4,5] 0x00012345 // 6 > 4 @@ -136,7 +136,7 @@ fn gen_trie_root(input: Vec<(Vec, Vec)>) -> H256 { /// [0,1,2,3,4,5,T] 0x20012345 // 6 > 4 /// [1,2,3,4,5,T] 0x312345 // 5 > 3 /// [1,2,3,4,T] 0x201234 // 4 > 3 -/// ``` +/// ``` fn hex_prefix_encode(nibbles: &[u8], leaf: bool) -> Vec { let inlen = nibbles.len(); let oddness_factor = inlen % 2; @@ -155,7 +155,7 @@ fn hex_prefix_encode(nibbles: &[u8], leaf: bool) -> Vec { res.push(first_byte); - let mut offset = oddness_factor; + let mut offset = oddness_factor; while offset < inlen { let byte = (nibbles[offset] << 4) + nibbles[offset + 1]; res.push(byte); @@ -203,7 +203,7 @@ fn hash256rlp(input: &[(Vec, Vec)], pre_len: usize, stream: &mut RlpStre // skip first element .skip(1) // get minimum number of shared nibbles between first and each successive - .fold(key.len(), | acc, &(ref k, _) | { + .fold(key.len(), | acc, &(ref k, _) | { cmp::min(key.shared_prefix_len(&k), acc) }); @@ -218,7 +218,7 @@ fn hash256rlp(input: &[(Vec, Vec)], pre_len: usize, stream: &mut RlpStre } // an item for every possible nibble/suffix - // + 1 for data + // + 1 for data stream.append_list(17); // if first key len is equal to prefix_len, move to next element @@ -233,10 +233,10 @@ fn hash256rlp(input: &[(Vec, Vec)], pre_len: usize, stream: &mut RlpStre let len = match begin < input.len() { true => input[begin..].iter() .take_while(| pair | pair.0[pre_len] == i ) - .count(), + .count(), false => 0 }; - + // if at least 1 successive element has the same nibble // append their suffixes match len { @@ -272,7 +272,7 @@ fn test_nibbles() { // A => 65 => 0x41 => [4, 1] let v: Vec = From::from("A"); - let e = vec![4, 1]; + let e = vec![4, 1]; assert_eq!(as_nibbles(&v), e); } @@ -338,4 +338,3 @@ mod tests { }); } } - diff --git a/src/uint.rs b/src/uint.rs index b8eccc4cc..7fc11e2df 100644 --- a/src/uint.rs +++ b/src/uint.rs @@ -52,11 +52,39 @@ macro_rules! construct_uint { impl $name { /// Conversion to u32 #[inline] - fn low_u32(&self) -> u32 { + pub fn low_u32(&self) -> u32 { let &$name(ref arr) = self; arr[0] as u32 } + /// Conversion to u64 + #[inline] + pub fn low_u64(&self) -> u64 { + let &$name(ref arr) = self; + arr[0] + } + + /// Conversion to u32 with overflow checking + #[inline] + pub fn as_u32(&self) -> u32 { + let &$name(ref arr) = self; + if (arr[0] & (0xffffffffu64 << 32)) != 0 { + panic!("Intger overflow when casting U256") + } + self.as_u64() as u32 + } + + /// Conversion to u64 with overflow checking + #[inline] + pub fn as_u64(&self) -> u64 { + let &$name(ref arr) = self; + for i in 1..$n_words { + if arr[i] != 0 { + panic!("Intger overflow when casting U256") + } + } + arr[0] + } /// Return the least number of bits needed to represent the number #[inline] pub fn bits(&self) -> usize { @@ -101,7 +129,7 @@ macro_rules! construct_uint { pub fn zero() -> $name { From::from(0u64) } - + #[inline] pub fn one() -> $name { From::from(1u64) @@ -410,7 +438,7 @@ macro_rules! construct_uint { fn from_dec_str(value: &str) -> Result { Ok(value.bytes() .map(|b| b - 48) - .fold($name::from(0u64), | acc, c | + .fold($name::from(0u64), | acc, c | // fast multiplication by 10 // (acc << 3) + (acc << 1) => acc * 10 (acc << 3) + (acc << 1) + $name::from(c) @@ -434,6 +462,18 @@ impl From for U256 { } } +impl From for u64 { + fn from(value: U256) -> u64 { + value.as_u64() + } +} + +impl From for u32 { + fn from(value: U256) -> u32 { + value.as_u32() + } +} + pub const ZERO_U256: U256 = U256([0x00u64; 4]); pub const ONE_U256: U256 = U256([0x01u64, 0x00u64, 0x00u64, 0x00u64]); pub const BAD_U256: U256 = U256([0xffffffffffffffffu64; 4]);