diff --git a/Cargo.lock b/Cargo.lock index 6ae4a06ca..8035d307c 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -174,6 +174,7 @@ dependencies = [ "crossbeam 0.1.6 (registry+https://github.com/rust-lang/crates.io-index)", "env_logger 0.3.2 (registry+https://github.com/rust-lang/crates.io-index)", "ethash 0.9.99", + "ethcore-devtools 0.9.99", "ethcore-util 0.9.99", "heapsize 0.3.2 (registry+https://github.com/rust-lang/crates.io-index)", "lazy_static 0.1.15 (registry+https://github.com/rust-lang/crates.io-index)", @@ -226,6 +227,7 @@ dependencies = [ "itertools 0.4.8 (registry+https://github.com/rust-lang/crates.io-index)", "json-tests 0.1.0", "lazy_static 0.1.15 (registry+https://github.com/rust-lang/crates.io-index)", + "libc 0.2.7 (registry+https://github.com/rust-lang/crates.io-index)", "log 0.3.5 (registry+https://github.com/rust-lang/crates.io-index)", "mio 0.5.0 (registry+https://github.com/rust-lang/crates.io-index)", "rand 0.3.14 (registry+https://github.com/rust-lang/crates.io-index)", diff --git a/ethcore/src/block_queue.rs b/ethcore/src/block_queue.rs index 851c56c05..c39f158f0 100644 --- a/ethcore/src/block_queue.rs +++ b/ethcore/src/block_queue.rs @@ -153,7 +153,7 @@ impl BlockQueue { } fn verify(verification: Arc>, engine: Arc>, wait: Arc, ready: Arc, deleting: Arc, empty: Arc) { - while !deleting.load(AtomicOrdering::Relaxed) { + while !deleting.load(AtomicOrdering::Acquire) { { let mut lock = verification.lock().unwrap(); @@ -161,11 +161,11 @@ impl BlockQueue { empty.notify_all(); } - while lock.unverified.is_empty() && !deleting.load(AtomicOrdering::Relaxed) { + while lock.unverified.is_empty() && !deleting.load(AtomicOrdering::Acquire) { lock = wait.wait(lock).unwrap(); } - if deleting.load(AtomicOrdering::Relaxed) { + if deleting.load(AtomicOrdering::Acquire) { return; } } @@ -347,7 +347,7 @@ impl MayPanic for BlockQueue { impl Drop for BlockQueue { fn drop(&mut self) { self.clear(); - self.deleting.store(true, AtomicOrdering::Relaxed); + self.deleting.store(true, AtomicOrdering::Release); self.more_to_verify.notify_all(); for t in self.verifiers.drain(..) { t.join().unwrap(); diff --git a/ethcore/src/client.rs b/ethcore/src/client.rs index 72e520ad7..c3ec4b4d0 100644 --- a/ethcore/src/client.rs +++ b/ethcore/src/client.rs @@ -419,11 +419,11 @@ impl BlockChainClient for Client { } fn state_data(&self, _hash: &H256) -> Option { - unimplemented!(); + None } fn block_receipts(&self, _hash: &H256) -> Option { - unimplemented!(); + None } fn import_block(&self, bytes: Bytes) -> ImportResult { diff --git a/ethcore/src/service.rs b/ethcore/src/service.rs index 9389f1db1..534aab49d 100644 --- a/ethcore/src/service.rs +++ b/ethcore/src/service.rs @@ -124,20 +124,18 @@ impl IoHandler for ClientIoHandler { } } -// TODO: rewrite into something that doesn't dependent on the testing environment having a particular port ready for use. -/* #[cfg(test)] mod tests { use super::*; use tests::helpers::*; use util::network::*; + use devtools::*; #[test] fn it_can_be_started() { let spec = get_test_spec(); let temp_path = RandomTempPath::new(); - let service = ClientService::start(spec, NetworkConfiguration::new(), &temp_path.as_path()); + let service = ClientService::start(spec, NetworkConfiguration::new_with_port(40456), &temp_path.as_path()); assert!(service.is_ok()); } } -*/ \ No newline at end of file diff --git a/parity/main.rs b/parity/main.rs index 3279e1fed..b89bd752e 100644 --- a/parity/main.rs +++ b/parity/main.rs @@ -37,6 +37,7 @@ extern crate ethcore_rpc as rpc; use std::net::{SocketAddr}; use std::env; +use std::path::PathBuf; use rlog::{LogLevelFilter}; use env_logger::LogBuilder; use ctrlc::CtrlC; @@ -69,8 +70,10 @@ Options: --no-bootstrap Don't bother trying to connect to any nodes initially. --listen-address URL Specify the IP/port on which to listen for peers [default: 0.0.0.0:30304]. - --public-address URL Specify the IP/port on which peers may connect [default: 0.0.0.0:30304]. + --public-address URL Specify the IP/port on which peers may connect. --address URL Equivalent to --listen-address URL --public-address URL. + --peers NUM Try to manintain that many peers [default: 25]. + --no-discovery Disable new peer discovery. --upnp Use UPnP to try to figure out the correct network settings. --node-key KEY Specify node secret key as hex string. @@ -95,8 +98,10 @@ struct Args { flag_keys_path: String, flag_no_bootstrap: bool, flag_listen_address: String, - flag_public_address: String, + flag_public_address: Option, flag_address: Option, + flag_peers: u32, + flag_no_discovery: bool, flag_upnp: bool, flag_node_key: Option, flag_cache_pref_size: usize, @@ -165,7 +170,7 @@ impl Configuration { self.args.flag_db_path.replace("$HOME", env::home_dir().unwrap().to_str().unwrap()) } - fn keys_path(&self) -> String { + fn _keys_path(&self) -> String { self.args.flag_keys_path.replace("$HOME", env::home_dir().unwrap().to_str().unwrap()) } @@ -187,21 +192,23 @@ impl Configuration { } } - fn net_addresses(&self) -> (SocketAddr, SocketAddr) { - let listen_address; - let public_address; + fn net_addresses(&self) -> (Option, Option) { + let mut listen_address = None; + let mut public_address = None; - match self.args.flag_address { - None => { - listen_address = SocketAddr::from_str(self.args.flag_listen_address.as_ref()).expect("Invalid listen address given with --listen-address"); - public_address = SocketAddr::from_str(self.args.flag_public_address.as_ref()).expect("Invalid public address given with --public-address"); + if let Some(ref a) = self.args.flag_address { + public_address = Some(SocketAddr::from_str(a.as_ref()).expect("Invalid listen/public address given with --address")); + listen_address = public_address; + } + if listen_address.is_none() { + listen_address = Some(SocketAddr::from_str(self.args.flag_listen_address.as_ref()).expect("Invalid listen address given with --listen-address")); + } + if let Some(ref a) = self.args.flag_public_address { + if public_address.is_some() { + panic!("Conflicting flags: --address and --public-address"); } - Some(ref a) => { - public_address = SocketAddr::from_str(a.as_ref()).expect("Invalid listen/public address given with --address"); - listen_address = public_address; - } - }; - + public_address = Some(SocketAddr::from_str(a.as_ref()).expect("Invalid listen address given with --public-address")); + } (listen_address, public_address) } @@ -236,6 +243,11 @@ impl Configuration { net_settings.listen_address = listen; net_settings.public_address = public; net_settings.use_secret = self.args.flag_node_key.as_ref().map(|s| Secret::from_str(&s).expect("Invalid key string")); + net_settings.discovery_enabled = !self.args.flag_no_discovery; + net_settings.ideal_peers = self.args.flag_peers; + let mut net_path = PathBuf::from(&self.path()); + net_path.push("network"); + net_settings.config_path = Some(net_path.to_str().unwrap().to_owned()); // Build client let mut service = ClientService::start(spec, net_settings, &Path::new(&self.path())).unwrap(); diff --git a/sync/src/chain.rs b/sync/src/chain.rs index 00e677538..5c79e08b6 100644 --- a/sync/src/chain.rs +++ b/sync/src/chain.rs @@ -82,7 +82,7 @@ const RECEIPTS_PACKET: u8 = 0x10; const NETWORK_ID: U256 = ONE_U256; //TODO: get this from parent -const CONNECTION_TIMEOUT_SEC: f64 = 30f64; +const CONNECTION_TIMEOUT_SEC: f64 = 10f64; struct Header { /// Header data @@ -314,7 +314,7 @@ impl ChainSync { } self.peers.insert(peer_id.clone(), peer); - info!(target: "sync", "Connected {}:{}", peer_id, io.peer_info(peer_id)); + debug!(target: "sync", "Connected {}:{}", peer_id, io.peer_info(peer_id)); self.sync_peer(io, peer_id, false); Ok(()) } @@ -545,7 +545,7 @@ impl ChainSync { pub fn on_peer_aborting(&mut self, io: &mut SyncIo, peer: PeerId) { trace!(target: "sync", "== Disconnecting {}", peer); if self.peers.contains_key(&peer) { - info!(target: "sync", "Disconnected {}", peer); + debug!(target: "sync", "Disconnected {}", peer); self.clear_peer_download(peer); self.peers.remove(&peer); self.continue_sync(io); @@ -1179,7 +1179,7 @@ impl ChainSync { for (peer_id, peer_number) in updated_peers { let mut peer_best = self.peers.get(&peer_id).unwrap().latest_hash.clone(); if best_number - peer_number > MAX_PEERS_PROPAGATION as BlockNumber { - // If we think peer is too far behind just end one latest hash + // If we think peer is too far behind just send one latest hash peer_best = last_parent.clone(); } sent = sent + match ChainSync::create_new_hashes_rlp(io.chain(), &peer_best, &local_best) { diff --git a/util/Cargo.toml b/util/Cargo.toml index 0ff7fe318..a6259176a 100644 --- a/util/Cargo.toml +++ b/util/Cargo.toml @@ -31,6 +31,8 @@ json-tests = { path = "json-tests" } target_info = "0.1.0" igd = "0.4.2" ethcore-devtools = { path = "../devtools" } +libc = "0.2.7" + [features] default = [] dev = ["clippy"] diff --git a/util/fdlimit/src/raise_fd_limit.rs b/util/fdlimit/src/raise_fd_limit.rs index f57ac2785..92127da35 100644 --- a/util/fdlimit/src/raise_fd_limit.rs +++ b/util/fdlimit/src/raise_fd_limit.rs @@ -57,5 +57,28 @@ pub unsafe fn raise_fd_limit() { } } -#[cfg(not(any(target_os = "macos", target_os = "ios")))] +#[cfg(any(target_os = "linux"))] +#[allow(non_camel_case_types)] +pub unsafe fn raise_fd_limit() { + use libc; + use std::io; + + // Fetch the current resource limits + let mut rlim = libc::rlimit{rlim_cur: 0, rlim_max: 0}; + if libc::getrlimit(libc::RLIMIT_NOFILE, &mut rlim) != 0 { + let err = io::Error::last_os_error(); + panic!("raise_fd_limit: error calling getrlimit: {}", err); + } + + // Set soft limit to hard imit + rlim.rlim_cur = rlim.rlim_max; + + // Set our newly-increased resource limit + if libc::setrlimit(libc::RLIMIT_NOFILE, &rlim) != 0 { + let err = io::Error::last_os_error(); + panic!("raise_fd_limit: error calling setrlimit: {}", err); + } +} + +#[cfg(not(any(target_os = "macos", target_os = "ios", target_os = "linux")))] pub unsafe fn raise_fd_limit() {} diff --git a/util/src/bytes.rs b/util/src/bytes.rs index 08c299ddf..5a4500ae6 100644 --- a/util/src/bytes.rs +++ b/util/src/bytes.rs @@ -170,32 +170,8 @@ pub trait BytesConvertable { fn to_bytes(&self) -> Bytes { self.as_slice().to_vec() } } -impl<'a> BytesConvertable for &'a [u8] { - fn bytes(&self) -> &[u8] { self } -} - -impl BytesConvertable for Vec { - fn bytes(&self) -> &[u8] { self } -} - -impl BytesConvertable for String { - fn bytes(&self) -> &[u8] { &self.as_bytes() } -} - -macro_rules! impl_bytes_convertable_for_array { - ($zero: expr) => (); - ($len: expr, $($idx: expr),*) => { - impl BytesConvertable for [u8; $len] { - fn bytes(&self) -> &[u8] { self } - } - impl_bytes_convertable_for_array! { $($idx),* } - } -} - -// -1 at the end is not expanded -impl_bytes_convertable_for_array! { - 32, 31, 30, 29, 28, 27, 26, 25, 24, 23, 22, 21, 20, 19, 18, 17, 16, - 15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0, -1 +impl BytesConvertable for T where T: AsRef<[u8]> { + fn bytes(&self) -> &[u8] { self.as_ref() } } #[test] diff --git a/util/src/hash.rs b/util/src/hash.rs index 924465e70..d436c2d81 100644 --- a/util/src/hash.rs +++ b/util/src/hash.rs @@ -77,12 +77,6 @@ macro_rules! impl_hash { /// Unformatted binary data of fixed length. pub struct $from (pub [u8; $size]); - impl BytesConvertable for $from { - fn bytes(&self) -> &[u8] { - &self.0 - } - } - impl Deref for $from { type Target = [u8]; @@ -92,6 +86,13 @@ macro_rules! impl_hash { } } + impl AsRef<[u8]> for $from { + #[inline] + fn as_ref(&self) -> &[u8] { + &self.0 + } + } + impl DerefMut for $from { #[inline] fn deref_mut(&mut self) -> &mut [u8] { diff --git a/util/src/io/service.rs b/util/src/io/service.rs index c5f4a6072..83fa71b8a 100644 --- a/util/src/io/service.rs +++ b/util/src/io/service.rs @@ -256,6 +256,11 @@ impl Handler for IoManager where Message: Send + Clone + Sync IoMessage::DeregisterStream { handler_id, token } => { let handler = self.handlers.get(handler_id).expect("Unknown handler id").clone(); handler.deregister_stream(token, event_loop); + // unregister a timer associated with the token (if any) + let timer_id = token + handler_id * TOKENS_PER_HANDLER; + if let Some(timer) = self.timers.write().unwrap().remove(&timer_id) { + event_loop.clear_timeout(timer.timeout); + } }, IoMessage::UpdateStreamRegistration { handler_id, token } => { let handler = self.handlers.get(handler_id).expect("Unknown handler id").clone(); diff --git a/util/src/io/worker.rs b/util/src/io/worker.rs index 1ba0318bc..b874ea0a4 100644 --- a/util/src/io/worker.rs +++ b/util/src/io/worker.rs @@ -44,6 +44,7 @@ pub struct Worker { thread: Option>, wait: Arc, deleting: Arc, + wait_mutex: Arc>, } impl Worker { @@ -61,6 +62,7 @@ impl Worker { thread: None, wait: wait.clone(), deleting: deleting.clone(), + wait_mutex: wait_mutex.clone(), }; worker.thread = Some(thread::Builder::new().name(format!("IO Worker #{}", index)).spawn( move || { @@ -77,13 +79,17 @@ impl Worker { wait_mutex: Arc>, deleting: Arc) where Message: Send + Sync + Clone + 'static { - while !deleting.load(AtomicOrdering::Relaxed) { + loop { { let lock = wait_mutex.lock().unwrap(); - let _ = wait.wait(lock).unwrap(); - if deleting.load(AtomicOrdering::Relaxed) { + if deleting.load(AtomicOrdering::Acquire) { return; } + let _ = wait.wait(lock).unwrap(); + } + + if deleting.load(AtomicOrdering::Acquire) { + return; } while let chase_lev::Steal::Data(work) = stealer.steal() { Worker::do_work(work, channel.clone()); @@ -114,7 +120,8 @@ impl Worker { impl Drop for Worker { fn drop(&mut self) { - self.deleting.store(true, AtomicOrdering::Relaxed); + let _ = self.wait_mutex.lock(); + self.deleting.store(true, AtomicOrdering::Release); self.wait.notify_all(); let thread = mem::replace(&mut self.thread, None).unwrap(); thread.join().ok(); diff --git a/util/src/lib.rs b/util/src/lib.rs index 8f354b7c7..9527341ad 100644 --- a/util/src/lib.rs +++ b/util/src/lib.rs @@ -106,8 +106,8 @@ extern crate serde; #[macro_use] extern crate log as rlog; extern crate igd; -#[cfg(test)] extern crate ethcore_devtools as devtools; +extern crate libc; pub mod standard; #[macro_use] diff --git a/util/src/network/connection.rs b/util/src/network/connection.rs index 746c745c4..9e9304ca6 100644 --- a/util/src/network/connection.rs +++ b/util/src/network/connection.rs @@ -16,6 +16,7 @@ use std::sync::Arc; use std::collections::VecDeque; +use std::net::SocketAddr; use mio::{Handler, Token, EventSet, EventLoop, PollOpt, TryRead, TryWrite}; use mio::tcp::*; use hash::*; @@ -159,6 +160,21 @@ impl Connection { } } + /// Get socket token + pub fn token(&self) -> StreamToken { + self.token + } + + /// Replace socket token + pub fn set_token(&mut self, token: StreamToken) { + self.token = token; + } + + /// Get remote peer address + pub fn remote_addr(&self) -> io::Result { + self.socket.peer_addr() + } + /// Register this connection with the IO event loop. pub fn register_socket(&self, reg: Token, event_loop: &mut EventLoop) -> io::Result<()> { trace!(target: "net", "connection register; token={:?}", reg); @@ -238,6 +254,16 @@ impl EncryptedConnection { self.connection.token } + /// Replace socket token + pub fn set_token(&mut self, token: StreamToken) { + self.connection.set_token(token); + } + + /// Get remote peer address + pub fn remote_addr(&self) -> io::Result { + self.connection.remote_addr() + } + /// Create an encrypted connection out of the handshake. Consumes a handshake object. pub fn new(mut handshake: Handshake) -> Result { let shared = try!(crypto::ecdh::agree(handshake.ecdhe.secret(), &handshake.remote_public)); diff --git a/util/src/network/discovery.rs b/util/src/network/discovery.rs index e28c79e80..01d2da52c 100644 --- a/util/src/network/discovery.rs +++ b/util/src/network/discovery.rs @@ -14,115 +14,180 @@ // You should have received a copy of the GNU General Public License // along with Parity. If not, see . -// This module is a work in progress - -#![allow(dead_code)] //TODO: remove this after everything is done - -use std::collections::{HashSet, BTreeMap}; -use std::cell::{RefCell}; -use std::ops::{DerefMut}; +use bytes::Bytes; +use std::net::SocketAddr; +use std::collections::{HashSet, HashMap, BTreeMap, VecDeque}; +use std::mem; +use std::cmp; use mio::*; use mio::udp::*; +use sha3::*; +use time; use hash::*; -use sha3::Hashable; use crypto::*; -use network::node::*; +use rlp::*; +use network::node_table::*; +use network::error::NetworkError; +use io::StreamToken; -const ADDRESS_BYTES_SIZE: u32 = 32; ///< Size of address type in bytes. -const ADDRESS_BITS: u32 = 8 * ADDRESS_BYTES_SIZE; ///< Denoted by n in [Kademlia]. -const NODE_BINS: u32 = ADDRESS_BITS - 1; ///< Size of m_state (excludes root, which is us). -const DISCOVERY_MAX_STEPS: u16 = 8; ///< Max iterations of discovery. (discover) -const BUCKET_SIZE: u32 = 16; ///< Denoted by k in [Kademlia]. Number of nodes stored in each bucket. -const ALPHA: usize = 3; ///< Denoted by \alpha in [Kademlia]. Number of concurrent FindNode requests. +use network::PROTOCOL_VERSION; + +const ADDRESS_BYTES_SIZE: u32 = 32; // Size of address type in bytes. +const ADDRESS_BITS: u32 = 8 * ADDRESS_BYTES_SIZE; // Denoted by n in [Kademlia]. +const NODE_BINS: u32 = ADDRESS_BITS - 1; // Size of m_state (excludes root, which is us). +const DISCOVERY_MAX_STEPS: u16 = 8; // Max iterations of discovery. (discover) +const BUCKET_SIZE: usize = 16; // Denoted by k in [Kademlia]. Number of nodes stored in each bucket. +const ALPHA: usize = 3; // Denoted by \alpha in [Kademlia]. Number of concurrent FindNode requests. +const MAX_DATAGRAM_SIZE: usize = 1280; + +const PACKET_PING: u8 = 1; +const PACKET_PONG: u8 = 2; +const PACKET_FIND_NODE: u8 = 3; +const PACKET_NEIGHBOURS: u8 = 4; + +const PING_TIMEOUT_MS: u64 = 300; + +#[derive(Clone, Debug)] +pub struct NodeEntry { + pub id: NodeId, + pub endpoint: NodeEndpoint, +} + +pub struct BucketEntry { + pub address: NodeEntry, + pub timeout: Option, +} struct NodeBucket { - distance: u32, - nodes: Vec + nodes: VecDeque, //sorted by last active } impl NodeBucket { - fn new(distance: u32) -> NodeBucket { + fn new() -> NodeBucket { NodeBucket { - distance: distance, - nodes: Vec::new() + nodes: VecDeque::new() } } } -struct Discovery { +struct Datagramm { + payload: Bytes, + address: SocketAddr, +} + +pub struct Discovery { id: NodeId, + secret: Secret, + public_endpoint: NodeEndpoint, + udp_socket: UdpSocket, + token: StreamToken, discovery_round: u16, discovery_id: NodeId, discovery_nodes: HashSet, node_buckets: Vec, + send_queue: VecDeque } -struct FindNodePacket; - -impl FindNodePacket { - fn new(_endpoint: &NodeEndpoint, _id: &NodeId) -> FindNodePacket { - FindNodePacket - } - - fn sign(&mut self, _secret: &Secret) { - } - - fn send(& self, _socket: &mut UdpSocket) { - } +pub struct TableUpdates { + pub added: HashMap, + pub removed: HashSet, } impl Discovery { - pub fn new(id: &NodeId) -> Discovery { + pub fn new(key: &KeyPair, listen: SocketAddr, public: NodeEndpoint, token: StreamToken) -> Discovery { + let socket = UdpSocket::bound(&listen).expect("Error binding UDP socket"); Discovery { - id: id.clone(), + id: key.public().clone(), + secret: key.secret().clone(), + public_endpoint: public, + token: token, discovery_round: 0, discovery_id: NodeId::new(), discovery_nodes: HashSet::new(), - node_buckets: (0..NODE_BINS).map(NodeBucket::new).collect(), + node_buckets: (0..NODE_BINS).map(|_| NodeBucket::new()).collect(), + udp_socket: socket, + send_queue: VecDeque::new(), } } - pub fn add_node(&mut self, id: &NodeId) { - self.node_buckets[Discovery::distance(&self.id, &id) as usize].nodes.push(id.clone()); + /// Add a new node to discovery table. Pings the node. + pub fn add_node(&mut self, e: NodeEntry) { + let endpoint = e.endpoint.clone(); + self.update_node(e); + self.ping(&endpoint); } - fn start_node_discovery(&mut self, event_loop: &mut EventLoop) { + /// Add a list of known nodes to the table. + pub fn init_node_list(&mut self, mut nodes: Vec) { + for n in nodes.drain(..) { + self.update_node(n); + } + } + + fn update_node(&mut self, e: NodeEntry) { + trace!(target: "discovery", "Inserting {:?}", &e); + let ping = { + let mut bucket = self.node_buckets.get_mut(Discovery::distance(&self.id, &e.id) as usize).unwrap(); + let updated = if let Some(node) = bucket.nodes.iter_mut().find(|n| n.address.id == e.id) { + node.address = e.clone(); + node.timeout = None; + true + } else { false }; + + if !updated { + bucket.nodes.push_front(BucketEntry { address: e, timeout: None }); + } + + if bucket.nodes.len() > BUCKET_SIZE { + //ping least active node + bucket.nodes.back_mut().unwrap().timeout = Some(time::precise_time_ns()); + Some(bucket.nodes.back().unwrap().address.endpoint.clone()) + } else { None } + }; + if let Some(endpoint) = ping { + self.ping(&endpoint); + } + } + + fn clear_ping(&mut self, id: &NodeId) { + let mut bucket = self.node_buckets.get_mut(Discovery::distance(&self.id, &id) as usize).unwrap(); + if let Some(node) = bucket.nodes.iter_mut().find(|n| &n.address.id == id) { + node.timeout = None; + } + } + + fn start(&mut self) { + trace!(target: "discovery", "Starting discovery"); self.discovery_round = 0; - self.discovery_id.randomize(); + self.discovery_id.randomize(); //TODO: use cryptographic nonce self.discovery_nodes.clear(); - self.discover(event_loop); } - fn discover(&mut self, event_loop: &mut EventLoop) { - if self.discovery_round == DISCOVERY_MAX_STEPS - { - debug!("Restarting discovery"); - self.start_node_discovery(event_loop); + fn discover(&mut self) { + if self.discovery_round == DISCOVERY_MAX_STEPS { return; } + trace!(target: "discovery", "Starting round {:?}", self.discovery_round); let mut tried_count = 0; { - let nearest = Discovery::nearest_node_entries(&self.id, &self.discovery_id, &self.node_buckets).into_iter(); - let nodes = RefCell::new(&mut self.discovery_nodes); - let nearest = nearest.filter(|x| nodes.borrow().contains(&x)).take(ALPHA); + let nearest = Discovery::nearest_node_entries(&self.discovery_id, &self.node_buckets).into_iter(); + let nearest = nearest.filter(|x| !self.discovery_nodes.contains(&x.id)).take(ALPHA).collect::>(); for r in nearest { - //let mut p = FindNodePacket::new(&r.endpoint, &self.discovery_id); - //p.sign(&self.secret); - //p.send(&mut self.udp_socket); - let mut borrowed = nodes.borrow_mut(); - borrowed.deref_mut().insert(r.clone()); + let rlp = encode(&(&[self.discovery_id.clone()][..])); + self.send_packet(PACKET_FIND_NODE, &r.endpoint.udp_address(), &rlp); + self.discovery_nodes.insert(r.id.clone()); tried_count += 1; + trace!(target: "discovery", "Sent FindNode to {:?}", &r.endpoint); } } - if tried_count == 0 - { - debug!("Restarting discovery"); - self.start_node_discovery(event_loop); + if tried_count == 0 { + trace!(target: "discovery", "Completing discovery"); + self.discovery_round = DISCOVERY_MAX_STEPS; + self.discovery_nodes.clear(); return; } self.discovery_round += 1; - //event_loop.timeout_ms(Token(NODETABLE_DISCOVERY), 1200).unwrap(); } fn distance(a: &NodeId, b: &NodeId) -> u32 { @@ -138,86 +203,353 @@ impl Discovery { ret } - #[cfg_attr(feature="dev", allow(cyclomatic_complexity))] - fn nearest_node_entries<'b>(source: &NodeId, target: &NodeId, buckets: &'b [NodeBucket]) -> Vec<&'b NodeId> - { - // send ALPHA FindNode packets to nodes we know, closest to target - const LAST_BIN: u32 = NODE_BINS - 1; - let mut head = Discovery::distance(source, target); - let mut tail = if head == 0 { LAST_BIN } else { (head - 1) % NODE_BINS }; + fn ping(&mut self, node: &NodeEndpoint) { + let mut rlp = RlpStream::new_list(3); + rlp.append(&PROTOCOL_VERSION); + self.public_endpoint.to_rlp_list(&mut rlp); + node.to_rlp_list(&mut rlp); + trace!(target: "discovery", "Sent Ping to {:?}", &node); + self.send_packet(PACKET_PING, &node.udp_address(), &rlp.drain()); + } - let mut found: BTreeMap> = BTreeMap::new(); + fn send_packet(&mut self, packet_id: u8, address: &SocketAddr, payload: &[u8]) { + let mut rlp = RlpStream::new(); + rlp.append_raw(&[packet_id], 1); + let source = Rlp::new(payload); + rlp.begin_list(source.item_count() + 1); + for i in 0 .. source.item_count() { + rlp.append_raw(source.at(i).as_raw(), 1); + } + let timestamp = time::get_time().sec as u32 + 60; + rlp.append(×tamp); + + let bytes = rlp.drain(); + let hash = bytes.as_ref().sha3(); + let signature = match ec::sign(&self.secret, &hash) { + Ok(s) => s, + Err(_) => { + warn!("Error signing UDP packet"); + return; + } + }; + let mut packet = Bytes::with_capacity(bytes.len() + 32 + 65); + packet.extend(hash.iter()); + packet.extend(signature.iter()); + packet.extend(bytes.iter()); + let signed_hash = (&packet[32..]).sha3(); + packet[0..32].clone_from_slice(&signed_hash); + self.send_to(packet, address.clone()); + } + + #[cfg_attr(feature="dev", allow(map_clone))] + fn nearest_node_entries(target: &NodeId, buckets: &[NodeBucket]) -> Vec { + let mut found: BTreeMap> = BTreeMap::new(); let mut count = 0; - // if d is 0, then we roll look forward, if last, we reverse, else, spread from d - if head > 1 && tail != LAST_BIN { - while head != tail && head < NODE_BINS && count < BUCKET_SIZE - { - for n in &buckets[head as usize].nodes - { - if count < BUCKET_SIZE { - count += 1; - found.entry(Discovery::distance(target, &n)).or_insert_with(Vec::new).push(n); - } - else { - break; - } - } - if count < BUCKET_SIZE && tail != 0 { - for n in &buckets[tail as usize].nodes { - if count < BUCKET_SIZE { - count += 1; - found.entry(Discovery::distance(target, &n)).or_insert_with(Vec::new).push(n); - } - else { - break; - } + // Sort nodes by distance to target + for bucket in buckets { + for node in &bucket.nodes { + let distance = Discovery::distance(target, &node.address.id); + found.entry(distance).or_insert_with(Vec::new).push(&node.address); + if count == BUCKET_SIZE { + // delete the most distant element + let remove = { + let (_, last) = found.iter_mut().next_back().unwrap(); + last.pop(); + last.is_empty() + }; + if remove { + found.remove(&distance); } } - - head += 1; - if tail > 0 { - tail -= 1; + else { + count += 1; } } } - else if head < 2 { - while head < NODE_BINS && count < BUCKET_SIZE { - for n in &buckets[head as usize].nodes { - if count < BUCKET_SIZE { - count += 1; - found.entry(Discovery::distance(target, &n)).or_insert_with(Vec::new).push(n); - } - else { - break; - } - } - head += 1; - } - } - else { - while tail > 0 && count < BUCKET_SIZE { - for n in &buckets[tail as usize].nodes { - if count < BUCKET_SIZE { - count += 1; - found.entry(Discovery::distance(target, &n)).or_insert_with(Vec::new).push(n); - } - else { - break; - } - } - tail -= 1; - } - } - let mut ret:Vec<&NodeId> = Vec::new(); + let mut ret:Vec = Vec::new(); for nodes in found.values() { - for n in nodes { - if ret.len() < BUCKET_SIZE as usize /* && n->endpoint && n->endpoint.isAllowed() */ { - ret.push(n); - } - } + ret.extend(nodes.iter().map(|&n| n.clone())); } ret } + + pub fn writable(&mut self) { + if self.send_queue.is_empty() { + return; + } + while !self.send_queue.is_empty() { + let data = self.send_queue.pop_front().unwrap(); + match self.udp_socket.send_to(&data.payload, &data.address) { + Ok(Some(size)) if size == data.payload.len() => { + }, + Ok(Some(_)) => { + warn!("UDP sent incomplete datagramm"); + }, + Ok(None) => { + self.send_queue.push_front(data); + return; + } + Err(e) => { + warn!("UDP send error: {:?}, address: {:?}", e, &data.address); + return; + } + } + } + } + + fn send_to(&mut self, payload: Bytes, address: SocketAddr) { + self.send_queue.push_back(Datagramm { payload: payload, address: address }); + } + + pub fn readable(&mut self) -> Option { + let mut buf: [u8; MAX_DATAGRAM_SIZE] = unsafe { mem::uninitialized() }; + match self.udp_socket.recv_from(&mut buf) { + Ok(Some((len, address))) => self.on_packet(&buf[0..len], address).unwrap_or_else(|e| { + debug!("Error processing UDP packet: {:?}", e); + None + }), + Ok(_) => None, + Err(e) => { + warn!("Error reading UPD socket: {:?}", e); + None + } + } + } + + fn on_packet(&mut self, packet: &[u8], from: SocketAddr) -> Result, NetworkError> { + // validate packet + if packet.len() < 32 + 65 + 4 + 1 { + return Err(NetworkError::BadProtocol); + } + + let hash_signed = (&packet[32..]).sha3(); + if hash_signed[..] != packet[0..32] { + return Err(NetworkError::BadProtocol); + } + + let signed = &packet[(32 + 65)..]; + let signature = Signature::from_slice(&packet[32..(32 + 65)]); + let node_id = try!(ec::recover(&signature, &signed.sha3())); + + let packet_id = signed[0]; + let rlp = UntrustedRlp::new(&signed[1..]); + match packet_id { + PACKET_PING => self.on_ping(&rlp, &node_id, &from), + PACKET_PONG => self.on_pong(&rlp, &node_id, &from), + PACKET_FIND_NODE => self.on_find_node(&rlp, &node_id, &from), + PACKET_NEIGHBOURS => self.on_neighbours(&rlp, &node_id, &from), + _ => { + debug!("Unknown UDP packet: {}", packet_id); + Ok(None) + } + } + } + + fn on_ping(&mut self, rlp: &UntrustedRlp, node: &NodeId, from: &SocketAddr) -> Result, NetworkError> { + trace!(target: "discovery", "Got Ping from {:?}", &from); + let version: u32 = try!(rlp.val_at(0)); + if version != PROTOCOL_VERSION { + debug!(target: "discovery", "Unexpected protocol version: {}", version); + return Err(NetworkError::BadProtocol); + } + let source = try!(NodeEndpoint::from_rlp(&try!(rlp.at(1)))); + let dest = try!(NodeEndpoint::from_rlp(&try!(rlp.at(2)))); + let timestamp: u64 = try!(rlp.val_at(3)); + if timestamp < time::get_time().sec as u64{ + debug!(target: "discovery", "Expired ping"); + return Err(NetworkError::Expired); + } + let mut added_map = HashMap::new(); + let entry = NodeEntry { id: node.clone(), endpoint: source.clone() }; + if !entry.endpoint.is_valid() || !entry.endpoint.is_global() { + debug!(target: "discovery", "Got bad address: {:?}", entry); + } + else { + self.update_node(entry.clone()); + added_map.insert(node.clone(), entry); + } + let hash = rlp.as_raw().sha3(); + let mut response = RlpStream::new_list(2); + dest.to_rlp_list(&mut response); + response.append(&hash); + self.send_packet(PACKET_PONG, from, &response.drain()); + + Ok(Some(TableUpdates { added: added_map, removed: HashSet::new() })) + } + + fn on_pong(&mut self, rlp: &UntrustedRlp, node: &NodeId, from: &SocketAddr) -> Result, NetworkError> { + trace!(target: "discovery", "Got Pong from {:?}", &from); + // TODO: validate pong packet + let dest = try!(NodeEndpoint::from_rlp(&try!(rlp.at(0)))); + let timestamp: u64 = try!(rlp.val_at(2)); + if timestamp < time::get_time().sec as u64 { + return Err(NetworkError::Expired); + } + let mut entry = NodeEntry { id: node.clone(), endpoint: dest }; + if !entry.endpoint.is_valid() { + debug!(target: "discovery", "Bad address: {:?}", entry); + entry.endpoint.address = from.clone(); + } + self.clear_ping(node); + let mut added_map = HashMap::new(); + added_map.insert(node.clone(), entry); + Ok(None) + } + + fn on_find_node(&mut self, rlp: &UntrustedRlp, _node: &NodeId, from: &SocketAddr) -> Result, NetworkError> { + trace!(target: "discovery", "Got FindNode from {:?}", &from); + let target: NodeId = try!(rlp.val_at(0)); + let timestamp: u64 = try!(rlp.val_at(1)); + if timestamp < time::get_time().sec as u64 { + return Err(NetworkError::Expired); + } + + let limit = (MAX_DATAGRAM_SIZE - 109) / 90; + let nearest = Discovery::nearest_node_entries(&target, &self.node_buckets); + if nearest.is_empty() { + return Ok(None); + } + let mut rlp = RlpStream::new_list(1); + rlp.begin_list(cmp::min(limit, nearest.len())); + for n in 0 .. nearest.len() { + rlp.begin_list(4); + nearest[n].endpoint.to_rlp(&mut rlp); + rlp.append(&nearest[n].id); + if (n + 1) % limit == 0 || n == nearest.len() - 1 { + self.send_packet(PACKET_NEIGHBOURS, &from, &rlp.drain()); + trace!(target: "discovery", "Sent {} Neighbours to {:?}", n, &from); + rlp = RlpStream::new_list(1); + rlp.begin_list(cmp::min(limit, nearest.len() - n)); + } + } + Ok(None) + } + + fn on_neighbours(&mut self, rlp: &UntrustedRlp, _node: &NodeId, from: &SocketAddr) -> Result, NetworkError> { + // TODO: validate packet + let mut added = HashMap::new(); + trace!(target: "discovery", "Got {} Neighbours from {:?}", try!(rlp.at(0)).item_count(), &from); + for r in try!(rlp.at(0)).iter() { + let endpoint = try!(NodeEndpoint::from_rlp(&r)); + if !endpoint.is_valid() { + debug!(target: "discovery", "Bad address: {:?}", endpoint); + continue; + } + let node_id: NodeId = try!(r.val_at(3)); + if node_id == self.id { + continue; + } + let entry = NodeEntry { id: node_id.clone(), endpoint: endpoint }; + added.insert(node_id, entry.clone()); + self.ping(&entry.endpoint); + self.update_node(entry); + } + Ok(Some(TableUpdates { added: added, removed: HashSet::new() })) + } + + fn check_expired(&mut self, force: bool) -> HashSet { + let now = time::precise_time_ns(); + let mut removed: HashSet = HashSet::new(); + for bucket in &mut self.node_buckets { + bucket.nodes.retain(|node| { + if let Some(timeout) = node.timeout { + if !force && now - timeout < PING_TIMEOUT_MS * 1000_0000 { + true + } + else { + trace!(target: "discovery", "Removed expired node {:?}", &node.address); + removed.insert(node.address.id.clone()); + false + } + } else { true } + }); + } + removed + } + + pub fn round(&mut self) -> Option { + let removed = self.check_expired(false); + self.discover(); + if !removed.is_empty() { + Some(TableUpdates { added: HashMap::new(), removed: removed }) + } else { None } + } + + pub fn refresh(&mut self) { + self.start(); + } + + pub fn register_socket(&self, event_loop: &mut EventLoop) -> Result<(), NetworkError> { + event_loop.register(&self.udp_socket, Token(self.token), EventSet::all(), PollOpt::edge()).expect("Error registering UDP socket"); + Ok(()) + } + + pub fn update_registration(&self, event_loop: &mut EventLoop) -> Result<(), NetworkError> { + let mut registration = EventSet::readable(); + if !self.send_queue.is_empty() { + registration = registration | EventSet::writable(); + } + event_loop.reregister(&self.udp_socket, Token(self.token), registration, PollOpt::edge()).expect("Error reregistering UDP socket"); + Ok(()) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use hash::*; + use std::net::*; + use network::node_table::*; + use crypto::KeyPair; + use std::str::FromStr; + + #[test] + fn discovery() { + let key1 = KeyPair::create().unwrap(); + let key2 = KeyPair::create().unwrap(); + let ep1 = NodeEndpoint { address: SocketAddr::from_str("127.0.0.1:40444").unwrap(), udp_port: 40444 }; + let ep2 = NodeEndpoint { address: SocketAddr::from_str("127.0.0.1:40445").unwrap(), udp_port: 40445 }; + let mut discovery1 = Discovery::new(&key1, ep1.address.clone(), ep1.clone(), 0); + let mut discovery2 = Discovery::new(&key2, ep2.address.clone(), ep2.clone(), 0); + + let node1 = Node::from_str("enode://a979fb575495b8d6db44f750317d0f4622bf4c2aa3365d6af7c284339968eef29b69ad0dce72a4d8db5ebb4968de0e3bec910127f134779fbcb0cb6d3331163c@127.0.0.1:7770").unwrap(); + let node2 = Node::from_str("enode://b979fb575495b8d6db44f750317d0f4622bf4c2aa3365d6af7c284339968eef29b69ad0dce72a4d8db5ebb4968de0e3bec910127f134779fbcb0cb6d3331163c@127.0.0.1:7771").unwrap(); + discovery1.add_node(NodeEntry { id: node1.id.clone(), endpoint: node1.endpoint.clone() }); + discovery1.add_node(NodeEntry { id: node2.id.clone(), endpoint: node2.endpoint.clone() }); + + discovery2.add_node(NodeEntry { id: key1.public().clone(), endpoint: ep1.clone() }); + discovery2.refresh(); + + for _ in 0 .. 10 { + while !discovery1.send_queue.is_empty() { + let datagramm = discovery1.send_queue.pop_front().unwrap(); + if datagramm.address == ep2.address { + discovery2.on_packet(&datagramm.payload, ep1.address.clone()).ok(); + } + } + while !discovery2.send_queue.is_empty() { + let datagramm = discovery2.send_queue.pop_front().unwrap(); + if datagramm.address == ep1.address { + discovery1.on_packet(&datagramm.payload, ep2.address.clone()).ok(); + } + } + discovery2.round(); + } + assert_eq!(Discovery::nearest_node_entries(&NodeId::new(), &discovery2.node_buckets).len(), 3) + } + + #[test] + fn removes_expired() { + let key = KeyPair::create().unwrap(); + let ep = NodeEndpoint { address: SocketAddr::from_str("127.0.0.1:40446").unwrap(), udp_port: 40444 }; + let mut discovery = Discovery::new(&key, ep.address.clone(), ep.clone(), 0); + for _ in 0..1200 { + discovery.add_node(NodeEntry { id: NodeId::random(), endpoint: ep.clone() }); + } + assert!(Discovery::nearest_node_entries(&NodeId::new(), &discovery.node_buckets).len() <= 16); + let removed = discovery.check_expired(true).len(); + assert!(removed > 0); + } } diff --git a/util/src/network/error.rs b/util/src/network/error.rs index 4d7fb483e..31e1d785b 100644 --- a/util/src/network/error.rs +++ b/util/src/network/error.rs @@ -15,23 +15,45 @@ // along with Parity. If not, see . use io::IoError; +use crypto::CryptoError; use rlp::*; -#[derive(Debug, Copy, Clone)] +#[derive(Debug, Copy, Clone, PartialEq, Eq)] pub enum DisconnectReason { DisconnectRequested, - _TCPError, - _BadProtocol, + TCPError, + BadProtocol, UselessPeer, - _TooManyPeers, - _DuplicatePeer, - _IncompatibleProtocol, - _NullIdentity, - _ClientQuit, - _UnexpectedIdentity, - _LocalIdentity, + TooManyPeers, + DuplicatePeer, + IncompatibleProtocol, + NullIdentity, + ClientQuit, + UnexpectedIdentity, + LocalIdentity, PingTimeout, + Unknown, +} + +impl DisconnectReason { + pub fn from_u8(n: u8) -> DisconnectReason { + match n { + 0 => DisconnectReason::DisconnectRequested, + 1 => DisconnectReason::TCPError, + 2 => DisconnectReason::BadProtocol, + 3 => DisconnectReason::UselessPeer, + 4 => DisconnectReason::TooManyPeers, + 5 => DisconnectReason::DuplicatePeer, + 6 => DisconnectReason::IncompatibleProtocol, + 7 => DisconnectReason::NullIdentity, + 8 => DisconnectReason::ClientQuit, + 9 => DisconnectReason::UnexpectedIdentity, + 10 => DisconnectReason::LocalIdentity, + 11 => DisconnectReason::PingTimeout, + _ => DisconnectReason::Unknown, + } + } } #[derive(Debug)] @@ -41,6 +63,8 @@ pub enum NetworkError { Auth, /// Unrecognised protocol. BadProtocol, + /// Message expired. + Expired, /// Peer not found. PeerNotFound, /// Peer is diconnected. @@ -61,3 +85,28 @@ impl From for NetworkError { } } +impl From for NetworkError { + fn from(_err: CryptoError) -> NetworkError { + NetworkError::Auth + } +} + +#[test] +fn test_errors() { + assert_eq!(DisconnectReason::ClientQuit, DisconnectReason::from_u8(8)); + let mut r = DisconnectReason::DisconnectRequested; + for i in 0 .. 20 { + r = DisconnectReason::from_u8(i); + } + assert_eq!(DisconnectReason::Unknown, r); + + match >::from(DecoderError::RlpIsTooBig) { + NetworkError::Auth => {}, + _ => panic!("Unexpeceted error"), + } + + match >::from(CryptoError::InvalidSecret) { + NetworkError::Auth => {}, + _ => panic!("Unexpeceted error"), + } +} diff --git a/util/src/network/handshake.rs b/util/src/network/handshake.rs index 4b23c4e16..5d43decd7 100644 --- a/util/src/network/handshake.rs +++ b/util/src/network/handshake.rs @@ -24,7 +24,7 @@ use crypto::*; use crypto; use network::connection::{Connection}; use network::host::{HostInfo}; -use network::node::NodeId; +use network::node_table::NodeId; use error::*; use network::error::NetworkError; use network::stats::NetworkStats; @@ -68,7 +68,7 @@ pub struct Handshake { const AUTH_PACKET_SIZE: usize = 307; const ACK_PACKET_SIZE: usize = 210; -const HANDSHAKE_TIMEOUT: u64 = 30000; +const HANDSHAKE_TIMEOUT: u64 = 5000; impl Handshake { /// Create a new handshake object @@ -87,6 +87,16 @@ impl Handshake { }) } + /// Get id of the remote node if known + pub fn id(&self) -> &NodeId { + &self.id + } + + /// Get stream token id + pub fn token(&self) -> StreamToken { + self.connection.token() + } + /// Start a handhsake pub fn start(&mut self, io: &IoContext, host: &HostInfo, originated: bool) -> Result<(), UtilError> where Message: Send + Clone{ self.originated = originated; diff --git a/util/src/network/host.rs b/util/src/network/host.rs index 5db724d32..feddf1952 100644 --- a/util/src/network/host.rs +++ b/util/src/network/host.rs @@ -14,15 +14,18 @@ // You should have received a copy of the GNU General Public License // along with Parity. If not, see . -use std::net::{SocketAddr, SocketAddrV4}; +use std::net::{SocketAddr}; use std::collections::{HashMap}; use std::hash::{Hasher}; use std::str::{FromStr}; use std::sync::*; use std::ops::*; +use std::cmp::min; +use std::path::{Path, PathBuf}; +use std::io::{Read, Write}; +use std::fs; use mio::*; use mio::tcp::*; -use mio::udp::*; use target_info::Target; use hash::*; use crypto::*; @@ -32,28 +35,32 @@ use network::handshake::Handshake; use network::session::{Session, SessionData}; use error::*; use io::*; -use network::NetworkProtocolHandler; -use network::node::*; +use network::{NetworkProtocolHandler, PROTOCOL_VERSION}; +use network::node_table::*; use network::stats::NetworkStats; use network::error::DisconnectReason; -use igd::{PortMappingProtocol,search_gateway}; +use network::discovery::{Discovery, TableUpdates, NodeEntry}; +use network::ip_utils::{map_external_address, select_public_address}; type Slab = ::slab::Slab; const _DEFAULT_PORT: u16 = 30304; - -const MAX_CONNECTIONS: usize = 1024; -const IDEAL_PEERS: u32 = 10; - +const MAX_SESSIONS: usize = 1024; +const MAX_HANDSHAKES: usize = 80; +const MAX_HANDSHAKES_PER_ROUND: usize = 32; const MAINTENANCE_TIMEOUT: u64 = 1000; #[derive(Debug)] /// Network service configuration pub struct NetworkConfiguration { - /// IP address to listen for incoming connections - pub listen_address: SocketAddr, - /// IP address to advertise - pub public_address: SocketAddr, + /// Directory path to store network configuration. None means nothing will be saved + pub config_path: Option, + /// IP address to listen for incoming connections. Listen to all connections by default + pub listen_address: Option, + /// IP address to advertise. Detected automatically if none. + pub public_address: Option, + /// Port for UDP connections, same as TCP by default + pub udp_port: Option, /// Enable NAT configuration pub nat_enabled: bool, /// Enable discovery @@ -64,77 +71,46 @@ pub struct NetworkConfiguration { pub boot_nodes: Vec, /// Use provided node key instead of default pub use_secret: Option, + /// Number of connected peers to maintain + pub ideal_peers: u32, } impl NetworkConfiguration { /// Create a new instance of default settings. pub fn new() -> NetworkConfiguration { NetworkConfiguration { - listen_address: SocketAddr::from_str("0.0.0.0:30304").unwrap(), - public_address: SocketAddr::from_str("0.0.0.0:30304").unwrap(), + config_path: None, + listen_address: None, + public_address: None, + udp_port: None, nat_enabled: true, discovery_enabled: true, pin: false, boot_nodes: Vec::new(), use_secret: None, + ideal_peers: 25, } } /// Create new default configuration with sepcified listen port. pub fn new_with_port(port: u16) -> NetworkConfiguration { let mut config = NetworkConfiguration::new(); - config.listen_address = SocketAddr::from_str(&format!("0.0.0.0:{}", port)).unwrap(); - config.public_address = SocketAddr::from_str(&format!("0.0.0.0:{}", port)).unwrap(); + config.listen_address = Some(SocketAddr::from_str(&format!("0.0.0.0:{}", port)).unwrap()); config } - - /// Conduct NAT if needed. - pub fn prepared(self) -> Self { - let mut listen = self.listen_address; - let mut public = self.public_address; - - if self.nat_enabled { - info!("Enabling NAT..."); - match search_gateway() { - Err(ref err) => info!("Error: {}", err), - Ok(gateway) => { - let int_addr = SocketAddrV4::from_str("127.0.0.1:30304").unwrap(); - match gateway.get_any_address(PortMappingProtocol::TCP, int_addr, 0, "Parity Node/TCP") { - Err(ref err) => { - info!("There was an error! {}", err); - }, - Ok(ext_addr) => { - info!("Local gateway: {}, External ip address: {}", gateway, ext_addr); - public = SocketAddr::V4(ext_addr); - listen = SocketAddr::V4(int_addr); - }, - } - }, - } - } - - NetworkConfiguration { - listen_address: listen, - public_address: public, - nat_enabled: false, - discovery_enabled: self.discovery_enabled, - pin: self.pin, - boot_nodes: self.boot_nodes, - use_secret: self.use_secret, - } - } } // Tokens -//const TOKEN_BEGIN: usize = USER_TOKEN_START; // TODO: ICE in rustc 1.7.0-nightly (49c382779 2016-01-12) -const TOKEN_BEGIN: usize = 32; -const TCP_ACCEPT: usize = TOKEN_BEGIN + 1; -const IDLE: usize = TOKEN_BEGIN + 2; -const NODETABLE_RECEIVE: usize = TOKEN_BEGIN + 3; -const NODETABLE_MAINTAIN: usize = TOKEN_BEGIN + 4; -const NODETABLE_DISCOVERY: usize = TOKEN_BEGIN + 5; -const FIRST_CONNECTION: usize = TOKEN_BEGIN + 16; -const LAST_CONNECTION: usize = FIRST_CONNECTION + MAX_CONNECTIONS - 1; +const TCP_ACCEPT: usize = LAST_HANDSHAKE + 1; +const IDLE: usize = LAST_HANDSHAKE + 2; +const DISCOVERY: usize = LAST_HANDSHAKE + 3; +const DISCOVERY_REFRESH: usize = LAST_HANDSHAKE + 4; +const DISCOVERY_ROUND: usize = LAST_HANDSHAKE + 5; +const FIRST_SESSION: usize = 0; +const LAST_SESSION: usize = FIRST_SESSION + MAX_SESSIONS - 1; +const FIRST_HANDSHAKE: usize = LAST_SESSION + 1; +const LAST_HANDSHAKE: usize = FIRST_HANDSHAKE + MAX_HANDSHAKES - 1; +const USER_TIMER: usize = LAST_HANDSHAKE + 256; /// Protocol handler level packet id pub type PacketId = u8; @@ -192,37 +168,33 @@ impl Encodable for CapabilityInfo { pub struct NetworkContext<'s, Message> where Message: Send + Sync + Clone + 'static, 's { io: &'s IoContext>, protocol: ProtocolId, - connections: Arc>>, + sessions: Arc>>, session: Option, } impl<'s, Message> NetworkContext<'s, Message> where Message: Send + Sync + Clone + 'static, { /// Create a new network IO access point. Takes references to all the data that can be updated within the IO handler. - fn new(io: &'s IoContext>, - protocol: ProtocolId, - session: Option, connections: Arc>>) -> NetworkContext<'s, Message> { + fn new(io: &'s IoContext>, + protocol: ProtocolId, + session: Option, sessions: Arc>>) -> NetworkContext<'s, Message> { NetworkContext { io: io, protocol: protocol, session: session, - connections: connections, + sessions: sessions, } } /// Send a packet over the network to another peer. pub fn send(&self, peer: PeerId, packet_id: PacketId, data: Vec) -> Result<(), UtilError> { - if let Some(connection) = self.connections.read().unwrap().get(peer).cloned() { - match *connection.lock().unwrap().deref_mut() { - ConnectionEntry::Session(ref mut s) => { - s.send_packet(self.protocol, packet_id as u8, &data).unwrap_or_else(|e| { + let session = { self.sessions.read().unwrap().get(peer).cloned() }; + if let Some(session) = session { + session.lock().unwrap().deref_mut().send_packet(self.protocol, packet_id as u8, &data).unwrap_or_else(|e| { warn!(target: "net", "Send error: {:?}", e); }); //TODO: don't copy vector data - try!(self.io.update_registration(peer)); - }, - _ => warn!(target: "net", "Send: Peer is not connected yet") - } + try!(self.io.update_registration(peer)); } else { - warn!(target: "net", "Send: Peer does not exist") + trace!(target: "net", "Send: Peer no longer exist") } Ok(()) } @@ -237,6 +209,12 @@ impl<'s, Message> NetworkContext<'s, Message> where Message: Send + Sync + Clone } } + /// Send an IO message + pub fn message(&self, msg: Message) { + self.io.message(NetworkIoMessage::User(msg)); + } + + /// Disable current protocol capability for given peer. If no capabilities left peer gets disconnected. pub fn disable_peer(&self, peer: PeerId) { //TODO: remove capability, disconnect if no capabilities left @@ -260,10 +238,9 @@ impl<'s, Message> NetworkContext<'s, Message> where Message: Send + Sync + Clone /// Returns peer identification string pub fn peer_info(&self, peer: PeerId) -> String { - if let Some(connection) = self.connections.read().unwrap().get(peer).cloned() { - if let ConnectionEntry::Session(ref s) = *connection.lock().unwrap().deref() { - return s.info.client_version.clone() - } + let session = { self.sessions.read().unwrap().get(peer).cloned() }; + if let Some(session) = session { + return session.lock().unwrap().info.client_version.clone() } "unknown".to_owned() } @@ -305,12 +282,8 @@ impl HostInfo { } } -enum ConnectionEntry { - Handshake(Handshake), - Session(Session) -} - -type SharedConnectionEntry = Arc>; +type SharedSession = Arc>; +type SharedHandshake = Arc>; #[derive(Copy, Clone)] struct ProtocolTimer { @@ -321,54 +294,95 @@ struct ProtocolTimer { /// Root IO handler. Manages protocol handlers, IO timers and network connections. pub struct Host where Message: Send + Sync + Clone { pub info: RwLock, - udp_socket: Mutex, tcp_listener: Mutex, - connections: Arc>>, - nodes: RwLock>, + handshakes: Arc>>, + sessions: Arc>>, + discovery: Option>, + nodes: RwLock, handlers: RwLock>>>, timers: RwLock>, timer_counter: RwLock, stats: Arc, + public_endpoint: NodeEndpoint, + pinned_nodes: Vec, } impl Host where Message: Send + Sync + Clone { /// Create a new instance pub fn new(config: NetworkConfiguration) -> Host { - let config = config.prepared(); + let listen_address = match config.listen_address { + None => SocketAddr::from_str("0.0.0.0:30304").unwrap(), + Some(addr) => addr, + }; + + let udp_port = config.udp_port.unwrap_or(listen_address.port()); + let public_endpoint = match config.public_address { + None => { + let public_address = select_public_address(listen_address.port()); + let local_endpoint = NodeEndpoint { address: public_address, udp_port: udp_port }; + if config.nat_enabled { + match map_external_address(&local_endpoint) { + Some(endpoint) => { + info!("NAT Mappped to external address {}", endpoint.address); + endpoint + }, + None => local_endpoint + } + } else { + local_endpoint + } + } + Some(addr) => NodeEndpoint { address: addr, udp_port: udp_port } + }; - let addr = config.listen_address; // Setup the server socket - let tcp_listener = TcpListener::bind(&addr).unwrap(); - let udp_socket = UdpSocket::bound(&addr).unwrap(); + let tcp_listener = TcpListener::bind(&listen_address).unwrap(); + let keys = if let Some(ref secret) = config.use_secret { + KeyPair::from_secret(secret.clone()).unwrap() + } else { + config.config_path.clone().and_then(|ref p| load_key(&Path::new(&p))) + .map_or_else(|| { + let key = KeyPair::create().unwrap(); + if let Some(path) = config.config_path.clone() { + save_key(&Path::new(&path), &key.secret()); + } + key + }, + |s| KeyPair::from_secret(s).expect("Error creating node secret key")) + }; + let discovery = if config.discovery_enabled && !config.pin { + Some(Discovery::new(&keys, listen_address.clone(), public_endpoint.clone(), DISCOVERY)) + } else { None }; + let path = config.config_path.clone(); let mut host = Host:: { info: RwLock::new(HostInfo { - keys: if let Some(ref secret) = config.use_secret { KeyPair::from_secret(secret.clone()).unwrap() } else { KeyPair::create().unwrap() }, + keys: keys, config: config, nonce: H256::random(), - protocol_version: 4, + protocol_version: PROTOCOL_VERSION, client_version: format!("Parity/{}/{}-{}-{}", env!("CARGO_PKG_VERSION"), Target::arch(), Target::env(), Target::os()), listen_port: 0, capabilities: Vec::new(), }), - udp_socket: Mutex::new(udp_socket), + discovery: discovery.map(Mutex::new), tcp_listener: Mutex::new(tcp_listener), - connections: Arc::new(RwLock::new(Slab::new_starting_at(FIRST_CONNECTION, MAX_CONNECTIONS))), - nodes: RwLock::new(HashMap::new()), + handshakes: Arc::new(RwLock::new(Slab::new_starting_at(FIRST_HANDSHAKE, MAX_HANDSHAKES))), + sessions: Arc::new(RwLock::new(Slab::new_starting_at(FIRST_SESSION, MAX_SESSIONS))), + nodes: RwLock::new(NodeTable::new(path)), handlers: RwLock::new(HashMap::new()), timers: RwLock::new(HashMap::new()), - timer_counter: RwLock::new(LAST_CONNECTION + 1), + timer_counter: RwLock::new(USER_TIMER), stats: Arc::new(NetworkStats::default()), + public_endpoint: public_endpoint, + pinned_nodes: Vec::new(), }; - let port = host.info.read().unwrap().config.listen_address.port(); + let port = listen_address.port(); host.info.write().unwrap().deref_mut().listen_port = port; - /* - match ::ifaces::Interface::get_all().unwrap().into_iter().filter(|x| x.kind == ::ifaces::Kind::Packet && x.addr.is_some()).next() { - Some(iface) => config.public_address = iface.addr.unwrap(), - None => warn!("No public network interface"), - */ - let boot_nodes = host.info.read().unwrap().config.boot_nodes.clone(); + if let Some(ref mut discovery) = host.discovery { + discovery.lock().unwrap().init_node_list(host.nodes.read().unwrap().unordered_entries()); + } for n in boot_nodes { host.add_node(&n); } @@ -383,7 +397,12 @@ impl Host where Message: Send + Sync + Clone { match Node::from_str(id) { Err(e) => { warn!("Could not add node: {:?}", e); }, Ok(n) => { - self.nodes.write().unwrap().insert(n.id.clone(), n); + let entry = NodeEntry { endpoint: n.endpoint.clone(), id: n.id.clone() }; + self.pinned_nodes.push(n.id.clone()); + self.nodes.write().unwrap().add_node(n); + if let Some(ref mut discovery) = self.discovery { + discovery.lock().unwrap().add_node(entry); + } } } } @@ -392,8 +411,8 @@ impl Host where Message: Send + Sync + Clone { self.info.read().unwrap().client_version.clone() } - pub fn client_id(&self) -> NodeId { - self.info.read().unwrap().id().clone() + pub fn client_url(&self) -> String { + format!("{}", Node::new(self.info.read().unwrap().id().clone(), self.public_endpoint.clone())) } fn maintain_network(&self, io: &IoContext>) { @@ -402,98 +421,86 @@ impl Host where Message: Send + Sync + Clone { } fn have_session(&self, id: &NodeId) -> bool { - self.connections.read().unwrap().iter().any(|e| match *e.lock().unwrap().deref() { ConnectionEntry::Session(ref s) => s.info.id.eq(&id), _ => false }) + self.sessions.read().unwrap().iter().any(|e| e.lock().unwrap().info.id.eq(&id)) + } + + fn session_count(&self) -> usize { + self.sessions.read().unwrap().count() } fn connecting_to(&self, id: &NodeId) -> bool { - self.connections.read().unwrap().iter().any(|e| match *e.lock().unwrap().deref() { ConnectionEntry::Handshake(ref h) => h.id.eq(&id), _ => false }) + self.handshakes.read().unwrap().iter().any(|e| e.lock().unwrap().id.eq(&id)) + } + + fn handshake_count(&self) -> usize { + self.handshakes.read().unwrap().count() } fn keep_alive(&self, io: &IoContext>) { let mut to_kill = Vec::new(); - for e in self.connections.write().unwrap().iter_mut() { - if let ConnectionEntry::Session(ref mut s) = *e.lock().unwrap().deref_mut() { - if !s.keep_alive(io) { - s.disconnect(DisconnectReason::PingTimeout); - to_kill.push(s.token()); - } + for e in self.sessions.write().unwrap().iter_mut() { + let mut s = e.lock().unwrap(); + if !s.keep_alive(io) { + s.disconnect(DisconnectReason::PingTimeout); + to_kill.push(s.token()); } } for p in to_kill { - self.kill_connection(p, io); + self.kill_connection(p, io, true); } } fn connect_peers(&self, io: &IoContext>) { - struct NodeInfo { - id: NodeId, - peer_type: PeerType + let ideal_peers = { self.info.read().unwrap().deref().config.ideal_peers }; + let pin = { self.info.read().unwrap().deref().config.pin }; + let session_count = self.session_count(); + if session_count >= ideal_peers as usize { + return; } - let mut to_connect: Vec = Vec::new(); - - let mut req_conn = 0; - //TODO: use nodes from discovery here - //for n in self.node_buckets.iter().flat_map(|n| &n.nodes).map(|id| NodeInfo { id: id.clone(), peer_type: self.nodes.get(id).unwrap().peer_type}) { - let pin = self.info.read().unwrap().deref().config.pin; - for n in self.nodes.read().unwrap().values().map(|n| NodeInfo { id: n.id.clone(), peer_type: n.peer_type }) { - let connected = self.have_session(&n.id) || self.connecting_to(&n.id); - let required = n.peer_type == PeerType::Required; - if connected && required { - req_conn += 1; - } - else if !connected && (!pin || required) { - to_connect.push(n); - } + let handshake_count = self.handshake_count(); + // allow 16 slots for incoming connections + let handshake_limit = MAX_HANDSHAKES - 16; + if handshake_count >= handshake_limit { + return; } - for n in &to_connect { - if n.peer_type == PeerType::Required { - if req_conn < IDEAL_PEERS { - self.connect_peer(&n.id, io); - } - req_conn += 1; - } - } - - if !pin { - let pending_count = 0; //TODO: - let peer_count = 0; - let mut open_slots = IDEAL_PEERS - peer_count - pending_count + req_conn; - if open_slots > 0 { - for n in &to_connect { - if n.peer_type == PeerType::Optional && open_slots > 0 { - open_slots -= 1; - self.connect_peer(&n.id, io); - } - } - } + let nodes = if pin { self.pinned_nodes.clone() } else { self.nodes.read().unwrap().nodes() }; + for id in nodes.iter().filter(|ref id| !self.have_session(id) && !self.connecting_to(id)) + .take(min(MAX_HANDSHAKES_PER_ROUND, handshake_limit - handshake_count)) { + self.connect_peer(&id, io); } + debug!(target: "net", "Connecting peers: {} sessions, {} pending", self.session_count(), self.handshake_count()); } #[cfg_attr(feature="dev", allow(single_match))] fn connect_peer(&self, id: &NodeId, io: &IoContext>) { if self.have_session(id) { - warn!("Aborted connect. Node already connected."); + trace!("Aborted connect. Node already connected."); return; } if self.connecting_to(id) { - warn!("Aborted connect. Node already connecting."); + trace!("Aborted connect. Node already connecting."); return; } let socket = { let address = { let mut nodes = self.nodes.write().unwrap(); - let node = nodes.get_mut(id).unwrap(); - node.last_attempted = Some(::time::now()); - node.endpoint.address + if let Some(node) = nodes.get_mut(id) { + node.last_attempted = Some(::time::now()); + node.endpoint.address + } + else { + debug!("Connection to expired node aborted"); + return; + } }; match TcpStream::connect(&address) { Ok(socket) => socket, - Err(_) => { - warn!("Cannot connect to node"); + Err(e) => { + warn!("Can't connect to node: {:?}", e); return; } } @@ -504,15 +511,15 @@ impl Host where Message: Send + Sync + Clone { #[cfg_attr(feature="dev", allow(block_in_if_condition_stmt))] fn create_connection(&self, socket: TcpStream, id: Option<&NodeId>, io: &IoContext>) { let nonce = self.info.write().unwrap().next_nonce(); - let mut connections = self.connections.write().unwrap(); - if connections.insert_with(|token| { + let mut handshakes = self.handshakes.write().unwrap(); + if handshakes.insert_with(|token| { let mut handshake = Handshake::new(token, id, socket, &nonce, self.stats.clone()).expect("Can't create handshake"); handshake.start(io, &self.info.read().unwrap(), id.is_some()).and_then(|_| io.register_stream(token)).unwrap_or_else (|e| { debug!(target: "net", "Handshake create error: {:?}", e); }); - Arc::new(Mutex::new(ConnectionEntry::Handshake(handshake))) + Arc::new(Mutex::new(handshake)) }).is_none() { - warn!("Max connections reached"); + debug!("Max handshakes reached"); } } @@ -532,192 +539,262 @@ impl Host where Message: Send + Sync + Clone { io.update_registration(TCP_ACCEPT).expect("Error registering TCP listener"); } - #[cfg_attr(feature="dev", allow(single_match))] - fn connection_writable(&self, token: StreamToken, io: &IoContext>) { - let mut create_session = false; - let mut kill = false; - if let Some(connection) = self.connections.read().unwrap().get(token).cloned() { - match *connection.lock().unwrap().deref_mut() { - ConnectionEntry::Handshake(ref mut h) => { - match h.writable(io, &self.info.read().unwrap()) { - Err(e) => { - debug!(target: "net", "Handshake write error: {:?}", e); - kill = true; - }, - Ok(_) => () - } - if h.done() { - create_session = true; - } - }, - ConnectionEntry::Session(ref mut s) => { - match s.writable(io, &self.info.read().unwrap()) { - Err(e) => { - debug!(target: "net", "Session write error: {:?}", e); - kill = true; - }, - Ok(_) => () - } - io.update_registration(token).unwrap_or_else(|e| debug!(target: "net", "Session registration error: {:?}", e)); - } + fn handshake_writable(&self, token: StreamToken, io: &IoContext>) { + let handshake = { self.handshakes.read().unwrap().get(token).cloned() }; + if let Some(handshake) = handshake { + let mut h = handshake.lock().unwrap(); + if let Err(e) = h.writable(io, &self.info.read().unwrap()) { + debug!(target: "net", "Handshake write error: {}:{:?}", token, e); + } + } + } + + fn session_writable(&self, token: StreamToken, io: &IoContext>) { + let session = { self.sessions.read().unwrap().get(token).cloned() }; + if let Some(session) = session { + let mut s = session.lock().unwrap(); + if let Err(e) = s.writable(io, &self.info.read().unwrap()) { + debug!(target: "net", "Session write error: {}:{:?}", token, e); } - } - if kill { - self.kill_connection(token, io); //TODO: mark connection as dead an check in kill_connection - return; - } else if create_session { - self.start_session(token, io); io.update_registration(token).unwrap_or_else(|e| debug!(target: "net", "Session registration error: {:?}", e)); } } fn connection_closed(&self, token: TimerToken, io: &IoContext>) { - self.kill_connection(token, io); + self.kill_connection(token, io, true); } - fn connection_readable(&self, token: StreamToken, io: &IoContext>) { - let mut ready_data: Vec = Vec::new(); - let mut packet_data: Option<(ProtocolId, PacketId, Vec)> = None; + fn handshake_readable(&self, token: StreamToken, io: &IoContext>) { let mut create_session = false; let mut kill = false; - if let Some(connection) = self.connections.read().unwrap().get(token).cloned() { - match *connection.lock().unwrap().deref_mut() { - ConnectionEntry::Handshake(ref mut h) => { - if let Err(e) = h.readable(io, &self.info.read().unwrap()) { - debug!(target: "net", "Handshake read error: {:?}", e); - kill = true; - } - if h.done() { - create_session = true; - } - }, - ConnectionEntry::Session(ref mut s) => { - match s.readable(io, &self.info.read().unwrap()) { - Err(e) => { - debug!(target: "net", "Handshake read error: {:?}", e); - kill = true; - }, - Ok(SessionData::Ready) => { - for (p, _) in self.handlers.read().unwrap().iter() { - if s.have_capability(p) { - ready_data.push(p); - } - } - }, - Ok(SessionData::Packet { - data, - protocol, - packet_id, - }) => { - match self.handlers.read().unwrap().get(protocol) { - None => { warn!(target: "net", "No handler found for protocol: {:?}", protocol) }, - Some(_) => packet_data = Some((protocol, packet_id, data)), - } - }, - Ok(SessionData::None) => {}, - } - } + let handshake = { self.handshakes.read().unwrap().get(token).cloned() }; + if let Some(handshake) = handshake { + let mut h = handshake.lock().unwrap(); + if let Err(e) = h.readable(io, &self.info.read().unwrap()) { + debug!(target: "net", "Handshake read error: {}:{:?}", token, e); + kill = true; } - } + if h.done() { + create_session = true; + } + } if kill { - self.kill_connection(token, io); //TODO: mark connection as dead an check in kill_connection + self.kill_connection(token, io, true); //TODO: mark connection as dead an check in kill_connection return; } else if create_session { self.start_session(token, io); io.update_registration(token).unwrap_or_else(|e| debug!(target: "net", "Session registration error: {:?}", e)); } + io.update_registration(token).unwrap_or_else(|e| debug!(target: "net", "Token registration error: {:?}", e)); + } + + fn session_readable(&self, token: StreamToken, io: &IoContext>) { + let mut ready_data: Vec = Vec::new(); + let mut packet_data: Option<(ProtocolId, PacketId, Vec)> = None; + let mut kill = false; + let session = { self.sessions.read().unwrap().get(token).cloned() }; + if let Some(session) = session { + let mut s = session.lock().unwrap(); + match s.readable(io, &self.info.read().unwrap()) { + Err(e) => { + debug!(target: "net", "Session read error: {}:{:?}", token, e); + kill = true; + }, + Ok(SessionData::Ready) => { + for (p, _) in self.handlers.read().unwrap().iter() { + if s.have_capability(p) { + ready_data.push(p); + } + } + }, + Ok(SessionData::Packet { + data, + protocol, + packet_id, + }) => { + match self.handlers.read().unwrap().get(protocol) { + None => { warn!(target: "net", "No handler found for protocol: {:?}", protocol) }, + Some(_) => packet_data = Some((protocol, packet_id, data)), + } + }, + Ok(SessionData::None) => {}, + } + } + if kill { + self.kill_connection(token, io, true); //TODO: mark connection as dead an check in kill_connection + } for p in ready_data { let h = self.handlers.read().unwrap().get(p).unwrap().clone(); - h.connected(&NetworkContext::new(io, p, Some(token), self.connections.clone()), &token); + h.connected(&NetworkContext::new(io, p, Some(token), self.sessions.clone()), &token); } if let Some((p, packet_id, data)) = packet_data { let h = self.handlers.read().unwrap().get(p).unwrap().clone(); - h.read(&NetworkContext::new(io, p, Some(token), self.connections.clone()), &token, packet_id, &data[1..]); + h.read(&NetworkContext::new(io, p, Some(token), self.sessions.clone()), &token, packet_id, &data[1..]); } io.update_registration(token).unwrap_or_else(|e| debug!(target: "net", "Token registration error: {:?}", e)); } fn start_session(&self, token: StreamToken, io: &IoContext>) { - let mut connections = self.connections.write().unwrap(); - if connections.get(token).is_none() { - return; // handshake expired + let mut handshakes = self.handshakes.write().unwrap(); + if handshakes.get(token).is_none() { + return; } - connections.replace_with(token, |c| { - match Arc::try_unwrap(c).ok().unwrap().into_inner().unwrap() { - ConnectionEntry::Handshake(h) => { - let session = Session::new(h, io, &self.info.read().unwrap()).expect("Session creation error"); - io.update_registration(token).expect("Error updating session registration"); - self.stats.inc_sessions(); - Some(Arc::new(Mutex::new(ConnectionEntry::Session(session)))) - }, - _ => { None } // handshake expired + + // turn a handshake into a session + let mut sessions = self.sessions.write().unwrap(); + let mut h = handshakes.remove(token).unwrap(); + // wait for other threads to stop using it + { + while Arc::get_mut(&mut h).is_none() { + h.lock().ok(); } - }).ok(); + } + let h = Arc::try_unwrap(h).ok().unwrap().into_inner().unwrap(); + let originated = h.originated; + let mut session = match Session::new(h, &self.info.read().unwrap()) { + Ok(s) => s, + Err(e) => { + debug!("Session creation error: {:?}", e); + return; + } + }; + let result = sessions.insert_with(move |session_token| { + session.set_token(session_token); + io.update_registration(session_token).expect("Error updating session registration"); + self.stats.inc_sessions(); + if !originated { + // Add it no node table + if let Ok(address) = session.remote_addr() { + let entry = NodeEntry { id: session.id().clone(), endpoint: NodeEndpoint { address: address, udp_port: address.port() } }; + self.nodes.write().unwrap().add_node(Node::new(entry.id.clone(), entry.endpoint.clone())); + if let Some(ref discovery) = self.discovery { + discovery.lock().unwrap().add_node(entry); + } + } + } + Arc::new(Mutex::new(session)) + }); + if result.is_none() { + warn!("Max sessions reached"); + } } fn connection_timeout(&self, token: StreamToken, io: &IoContext>) { - self.kill_connection(token, io) + self.kill_connection(token, io, true) } - fn kill_connection(&self, token: StreamToken, io: &IoContext>) { + fn kill_connection(&self, token: StreamToken, io: &IoContext>, remote: bool) { let mut to_disconnect: Vec = Vec::new(); - { - let mut connections = self.connections.write().unwrap(); - if let Some(connection) = connections.get(token).cloned() { - match *connection.lock().unwrap().deref_mut() { - ConnectionEntry::Handshake(_) => { - connections.remove(token); - }, - ConnectionEntry::Session(ref mut s) if s.is_ready() => { + let mut failure_id = None; + match token { + FIRST_HANDSHAKE ... LAST_HANDSHAKE => { + let handshakes = self.handshakes.write().unwrap(); + if let Some(handshake) = handshakes.get(token).cloned() { + failure_id = Some(handshake.lock().unwrap().id().clone()); + } + }, + FIRST_SESSION ... LAST_SESSION => { + let sessions = self.sessions.write().unwrap(); + if let Some(session) = sessions.get(token).cloned() { + let s = session.lock().unwrap(); + if s.is_ready() { for (p, _) in self.handlers.read().unwrap().iter() { if s.have_capability(p) { to_disconnect.push(p); } } - connections.remove(token); - }, - _ => {}, + } + failure_id = Some(s.id().clone()); } + }, + _ => {}, + } + io.deregister_stream(token).expect("Error deregistering stream"); + if let Some(id) = failure_id { + if remote { + self.nodes.write().unwrap().note_failure(&id); } - io.deregister_stream(token).expect("Error deregistering stream"); } for p in to_disconnect { let h = self.handlers.read().unwrap().get(p).unwrap().clone(); - h.disconnected(&NetworkContext::new(io, p, Some(token), self.connections.clone()), &token); + h.disconnected(&NetworkContext::new(io, p, Some(token), self.sessions.clone()), &token); } } + + fn update_nodes(&self, io: &IoContext>, node_changes: TableUpdates) { + let mut to_remove: Vec = Vec::new(); + { + { + let handshakes = self.handshakes.write().unwrap(); + for c in handshakes.iter() { + let h = c.lock().unwrap(); + if node_changes.removed.contains(&h.id()) { + to_remove.push(h.token()); + } + } + } + { + let sessions = self.sessions.write().unwrap(); + for c in sessions.iter() { + let s = c.lock().unwrap(); + if node_changes.removed.contains(&s.id()) { + to_remove.push(s.token()); + } + } + } + } + for i in to_remove { + self.kill_connection(i, io, false); + } + self.nodes.write().unwrap().update(node_changes); + } } impl IoHandler> for Host where Message: Send + Sync + Clone + 'static { /// Initialize networking fn initialize(&self, io: &IoContext>) { io.register_stream(TCP_ACCEPT).expect("Error registering TCP listener"); - io.register_stream(NODETABLE_RECEIVE).expect("Error registering UDP listener"); io.register_timer(IDLE, MAINTENANCE_TIMEOUT).expect("Error registering Network idle timer"); - //io.register_timer(NODETABLE_MAINTAIN, 7200); + if self.discovery.is_some() { + io.register_stream(DISCOVERY).expect("Error registering UDP listener"); + io.register_timer(DISCOVERY_REFRESH, 7200).expect("Error registering discovery timer"); + io.register_timer(DISCOVERY_ROUND, 300).expect("Error registering discovery timer"); + } + self.maintain_network(io) } fn stream_hup(&self, io: &IoContext>, stream: StreamToken) { trace!(target: "net", "Hup: {}", stream); match stream { - FIRST_CONNECTION ... LAST_CONNECTION => self.connection_closed(stream, io), + FIRST_SESSION ... LAST_SESSION => self.connection_closed(stream, io), + FIRST_HANDSHAKE ... LAST_HANDSHAKE => self.connection_closed(stream, io), _ => warn!(target: "net", "Unexpected hup"), }; } fn stream_readable(&self, io: &IoContext>, stream: StreamToken) { match stream { - FIRST_CONNECTION ... LAST_CONNECTION => self.connection_readable(stream, io), - NODETABLE_RECEIVE => {}, - TCP_ACCEPT => self.accept(io), + FIRST_SESSION ... LAST_SESSION => self.session_readable(stream, io), + FIRST_HANDSHAKE ... LAST_HANDSHAKE => self.handshake_readable(stream, io), + DISCOVERY => { + if let Some(node_changes) = self.discovery.as_ref().unwrap().lock().unwrap().readable() { + self.update_nodes(io, node_changes); + } + io.update_registration(DISCOVERY).expect("Error updating discovery registration"); + }, + TCP_ACCEPT => self.accept(io), _ => panic!("Received unknown readable token"), } } fn stream_writable(&self, io: &IoContext>, stream: StreamToken) { match stream { - FIRST_CONNECTION ... LAST_CONNECTION => self.connection_writable(stream, io), - NODETABLE_RECEIVE => {}, + FIRST_SESSION ... LAST_SESSION => self.session_writable(stream, io), + FIRST_HANDSHAKE ... LAST_HANDSHAKE => self.handshake_writable(stream, io), + DISCOVERY => { + self.discovery.as_ref().unwrap().lock().unwrap().writable(); + io.update_registration(DISCOVERY).expect("Error updating discovery registration"); + } _ => panic!("Received unknown writable token"), } } @@ -725,13 +802,22 @@ impl IoHandler> for Host where Messa fn timeout(&self, io: &IoContext>, token: TimerToken) { match token { IDLE => self.maintain_network(io), - FIRST_CONNECTION ... LAST_CONNECTION => self.connection_timeout(token, io), - NODETABLE_DISCOVERY => {}, - NODETABLE_MAINTAIN => {}, + FIRST_SESSION ... LAST_SESSION => self.connection_timeout(token, io), + FIRST_HANDSHAKE ... LAST_HANDSHAKE => self.connection_timeout(token, io), + DISCOVERY_REFRESH => { + self.discovery.as_ref().unwrap().lock().unwrap().refresh(); + io.update_registration(DISCOVERY).expect("Error updating discovery registration"); + }, + DISCOVERY_ROUND => { + if let Some(node_changes) = self.discovery.as_ref().unwrap().lock().unwrap().round() { + self.update_nodes(io, node_changes); + } + io.update_registration(DISCOVERY).expect("Error updating discovery registration"); + }, _ => match self.timers.read().unwrap().get(&token).cloned() { Some(timer) => match self.handlers.read().unwrap().get(timer.protocol).cloned() { None => { warn!(target: "net", "No handler found for protocol: {:?}", timer.protocol) }, - Some(h) => { h.timeout(&NetworkContext::new(io, timer.protocol, None, self.connections.clone()), timer.token); } + Some(h) => { h.timeout(&NetworkContext::new(io, timer.protocol, None, self.sessions.clone()), timer.token); } }, None => { warn!("Unknown timer token: {}", token); } // timer is not registerd through us } @@ -746,7 +832,7 @@ impl IoHandler> for Host where Messa ref versions } => { let h = handler.clone(); - h.initialize(&NetworkContext::new(io, protocol, None, self.connections.clone())); + h.initialize(&NetworkContext::new(io, protocol, None, self.sessions.clone())); self.handlers.write().unwrap().insert(protocol, h); let mut info = self.info.write().unwrap(); for v in versions { @@ -769,17 +855,15 @@ impl IoHandler> for Host where Messa io.register_timer(handler_token, *delay).expect("Error registering timer"); }, NetworkIoMessage::Disconnect(ref peer) => { - if let Some(connection) = self.connections.read().unwrap().get(*peer).cloned() { - match *connection.lock().unwrap().deref_mut() { - ConnectionEntry::Handshake(_) => {}, - ConnectionEntry::Session(ref mut s) => { s.disconnect(DisconnectReason::DisconnectRequested); } - } - } - self.kill_connection(*peer, io); + let session = { self.sessions.read().unwrap().get(*peer).cloned() }; + if let Some(session) = session { + session.lock().unwrap().disconnect(DisconnectReason::DisconnectRequested); + } + self.kill_connection(*peer, io, false); }, NetworkIoMessage::User(ref message) => { for (p, h) in self.handlers.read().unwrap().iter() { - h.message(&NetworkContext::new(io, p, None, self.connections.clone()), &message); + h.message(&NetworkContext::new(io, p, None, self.sessions.clone()), &message); } } } @@ -787,15 +871,16 @@ impl IoHandler> for Host where Messa fn register_stream(&self, stream: StreamToken, reg: Token, event_loop: &mut EventLoop>>) { match stream { - FIRST_CONNECTION ... LAST_CONNECTION => { - if let Some(connection) = self.connections.read().unwrap().get(stream).cloned() { - match *connection.lock().unwrap().deref() { - ConnectionEntry::Handshake(ref h) => h.register_socket(reg, event_loop).expect("Error registering socket"), - ConnectionEntry::Session(_) => warn!("Unexpected session stream registration") - } - } else {} // expired + FIRST_SESSION ... LAST_SESSION => { + warn!("Unexpected session stream registration"); } - NODETABLE_RECEIVE => event_loop.register(self.udp_socket.lock().unwrap().deref(), Token(NODETABLE_RECEIVE), EventSet::all(), PollOpt::edge()).expect("Error registering stream"), + FIRST_HANDSHAKE ... LAST_HANDSHAKE => { + let connection = { self.handshakes.read().unwrap().get(stream).cloned() }; + if let Some(connection) = connection { + connection.lock().unwrap().register_socket(reg, event_loop).expect("Error registering socket"); + } + } + DISCOVERY => self.discovery.as_ref().unwrap().lock().unwrap().register_socket(event_loop).expect("Error registering discovery socket"), TCP_ACCEPT => event_loop.register(self.tcp_listener.lock().unwrap().deref(), Token(TCP_ACCEPT), EventSet::all(), PollOpt::edge()).expect("Error registering stream"), _ => warn!("Unexpected stream registration") } @@ -803,17 +888,21 @@ impl IoHandler> for Host where Messa fn deregister_stream(&self, stream: StreamToken, event_loop: &mut EventLoop>>) { match stream { - FIRST_CONNECTION ... LAST_CONNECTION => { - let mut connections = self.connections.write().unwrap(); + FIRST_SESSION ... LAST_SESSION => { + let mut connections = self.sessions.write().unwrap(); if let Some(connection) = connections.get(stream).cloned() { - match *connection.lock().unwrap().deref() { - ConnectionEntry::Handshake(ref h) => h.deregister_socket(event_loop).expect("Error deregistering socket"), - ConnectionEntry::Session(ref s) => s.deregister_socket(event_loop).expect("Error deregistering session socket"), - } + connection.lock().unwrap().deregister_socket(event_loop).expect("Error deregistering socket"); connections.remove(stream); - } - }, - NODETABLE_RECEIVE => event_loop.deregister(self.udp_socket.lock().unwrap().deref()).unwrap(), + } + } + FIRST_HANDSHAKE ... LAST_HANDSHAKE => { + let mut connections = self.handshakes.write().unwrap(); + if let Some(connection) = connections.get(stream).cloned() { + connection.lock().unwrap().deregister_socket(event_loop).expect("Error deregistering socket"); + connections.remove(stream); + } + } + DISCOVERY => (), TCP_ACCEPT => event_loop.deregister(self.tcp_listener.lock().unwrap().deref()).unwrap(), _ => warn!("Unexpected stream deregistration") } @@ -821,17 +910,87 @@ impl IoHandler> for Host where Messa fn update_stream(&self, stream: StreamToken, reg: Token, event_loop: &mut EventLoop>>) { match stream { - FIRST_CONNECTION ... LAST_CONNECTION => { - if let Some(connection) = self.connections.read().unwrap().get(stream).cloned() { - match *connection.lock().unwrap().deref() { - ConnectionEntry::Handshake(ref h) => h.update_socket(reg, event_loop).expect("Error updating socket"), - ConnectionEntry::Session(ref s) => s.update_socket(reg, event_loop).expect("Error updating socket"), - } - } else {} // expired + FIRST_SESSION ... LAST_SESSION => { + let connection = { self.sessions.read().unwrap().get(stream).cloned() }; + if let Some(connection) = connection { + connection.lock().unwrap().update_socket(reg, event_loop).expect("Error updating socket"); + } } - NODETABLE_RECEIVE => event_loop.reregister(self.udp_socket.lock().unwrap().deref(), Token(NODETABLE_RECEIVE), EventSet::all(), PollOpt::edge()).expect("Error reregistering stream"), + FIRST_HANDSHAKE ... LAST_HANDSHAKE => { + let connection = { self.handshakes.read().unwrap().get(stream).cloned() }; + if let Some(connection) = connection { + connection.lock().unwrap().update_socket(reg, event_loop).expect("Error updating socket"); + } + } + DISCOVERY => self.discovery.as_ref().unwrap().lock().unwrap().update_registration(event_loop).expect("Error reregistering discovery socket"), TCP_ACCEPT => event_loop.reregister(self.tcp_listener.lock().unwrap().deref(), Token(TCP_ACCEPT), EventSet::all(), PollOpt::edge()).expect("Error reregistering stream"), _ => warn!("Unexpected stream update") } } } + +fn save_key(path: &Path, key: &Secret) { + let mut path_buf = PathBuf::from(path); + if let Err(e) = fs::create_dir_all(path_buf.as_path()) { + warn!("Error creating key directory: {:?}", e); + return; + }; + path_buf.push("key"); + let mut file = match fs::File::create(path_buf.as_path()) { + Ok(file) => file, + Err(e) => { + warn!("Error creating key file: {:?}", e); + return; + } + }; + if let Err(e) = file.write(&key.hex().into_bytes()) { + warn!("Error writing key file: {:?}", e); + } +} + +fn load_key(path: &Path) -> Option { + let mut path_buf = PathBuf::from(path); + path_buf.push("key"); + let mut file = match fs::File::open(path_buf.as_path()) { + Ok(file) => file, + Err(e) => { + debug!("Error opening key file: {:?}", e); + return None; + } + }; + let mut buf = String::new(); + match file.read_to_string(&mut buf) { + Ok(_) => {}, + Err(e) => { + warn!("Error reading key file: {:?}", e); + return None; + } + } + match Secret::from_str(&buf) { + Ok(key) => Some(key), + Err(e) => { + warn!("Error parsing key file: {:?}", e); + None + } + } +} + +#[test] +fn key_save_load() { + use ::devtools::RandomTempPath; + let temp_path = RandomTempPath::create_dir(); + let key = H256::random(); + save_key(temp_path.as_path(), &key); + let r = load_key(temp_path.as_path()); + assert_eq!(key, r.unwrap()); +} + + +#[test] +fn host_client_url() { + let mut config = NetworkConfiguration::new(); + let key = h256_from_hex("6f7b0d801bc7b5ce7bbd930b84fd0369b3eb25d09be58d64ba811091046f3aa2"); + config.use_secret = Some(key); + let host: Host = Host::new(config); + assert!(host.client_url().starts_with("enode://101b3ef5a4ea7a1c7928e24c4c75fd053c235d7b80c22ae5c03d145d0ac7396e2a4ffff9adee3133a7b05044a5cee08115fd65145e5165d646bde371010d803c@")); +} diff --git a/util/src/network/ip_utils.rs b/util/src/network/ip_utils.rs new file mode 100644 index 000000000..9696c601d --- /dev/null +++ b/util/src/network/ip_utils.rs @@ -0,0 +1,268 @@ +// Copyright 2015, 2016 Ethcore (UK) Ltd. +// This file is part of Parity. + +// Parity is free software: you can redistribute it and/or modify +// it under the terms of the GNU General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. + +// Parity is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU General Public License for more details. + +// You should have received a copy of the GNU General Public License +// along with Parity. If not, see . + +// Based on original work by David Levy https://raw.githubusercontent.com/dlevy47/rust-interfaces + +use std::net::{Ipv4Addr, Ipv6Addr, SocketAddr, SocketAddrV4, SocketAddrV6}; +use std::io; +use igd::{PortMappingProtocol, search_gateway_from_timeout}; +use std::time::Duration; +use network::node_table::{NodeEndpoint}; + +pub enum IpAddr{ + V4(Ipv4Addr), + V6(Ipv6Addr), +} + +/// Socket address extension for rustc beta. To be replaces with now unstable API +pub trait SocketAddrExt { + /// Returns true for the special 'unspecified' address 0.0.0.0. + fn is_unspecified_s(&self) -> bool; + /// Returns true if the address appears to be globally routable. + fn is_global_s(&self) -> bool; +} + +impl SocketAddrExt for Ipv4Addr { + fn is_unspecified_s(&self) -> bool { + self.octets() == [0, 0, 0, 0] + } + + fn is_global_s(&self) -> bool { + !self.is_private() && !self.is_loopback() && !self.is_link_local() && + !self.is_broadcast() && !self.is_documentation() + } +} + +impl SocketAddrExt for Ipv6Addr { + fn is_unspecified_s(&self) -> bool { + self.segments() == [0, 0, 0, 0, 0, 0, 0, 0] + } + + fn is_global_s(&self) -> bool { + if self.is_multicast() { + self.segments()[0] & 0x000f == 14 + } else { + !self.is_loopback() && !((self.segments()[0] & 0xffc0) == 0xfe80) && + !((self.segments()[0] & 0xffc0) == 0xfec0) && !((self.segments()[0] & 0xfe00) == 0xfc00) + } + } +} + +#[cfg(not(windows))] +mod getinterfaces { + use std::{mem, io, ptr}; + use libc::{AF_INET, AF_INET6}; + use libc::{getifaddrs, freeifaddrs, ifaddrs, sockaddr, sockaddr_in, sockaddr_in6}; + use std::net::{Ipv4Addr, Ipv6Addr}; + use super::IpAddr; + + fn convert_sockaddr (sa: *mut sockaddr) -> Option { + if sa == ptr::null_mut() { return None; } + + let (addr, _) = match unsafe { *sa }.sa_family as i32 { + AF_INET => { + let sa: *const sockaddr_in = unsafe { mem::transmute(sa) }; + let sa = & unsafe { *sa }; + let (addr, port) = (sa.sin_addr.s_addr, sa.sin_port); + (IpAddr::V4(Ipv4Addr::new( + (addr & 0x000000FF) as u8, + ((addr & 0x0000FF00) >> 8) as u8, + ((addr & 0x00FF0000) >> 16) as u8, + ((addr & 0xFF000000) >> 24) as u8)), + port) + }, + AF_INET6 => { + let sa: *const sockaddr_in6 = unsafe { mem::transmute(sa) }; + let sa = & unsafe { *sa }; + let (addr, port) = (sa.sin6_addr.s6_addr, sa.sin6_port); + let addr: [u16; 8] = unsafe { mem::transmute(addr) }; + (IpAddr::V6(Ipv6Addr::new( + addr[0], + addr[1], + addr[2], + addr[3], + addr[4], + addr[5], + addr[6], + addr[7])), + port) + }, + _ => return None, + }; + Some(addr) + } + + fn convert_ifaddrs (ifa: *mut ifaddrs) -> Option { + let ifa = unsafe { &mut *ifa }; + convert_sockaddr(ifa.ifa_addr) + } + + pub fn get_all() -> io::Result> { + let mut ifap: *mut ifaddrs = unsafe { mem::zeroed() }; + if unsafe { getifaddrs(&mut ifap as *mut _) } != 0 { + return Err(io::Error::last_os_error()); + } + + let mut ret = Vec::new(); + let mut cur: *mut ifaddrs = ifap; + while cur != ptr::null_mut() { + if let Some(ip_addr) = convert_ifaddrs(cur) { + ret.push(ip_addr); + } + + //TODO: do something else maybe? + cur = unsafe { (*cur).ifa_next }; + } + + unsafe { freeifaddrs(ifap) }; + Ok(ret) + } +} + +#[cfg(not(windows))] +fn get_if_addrs() -> io::Result> { + getinterfaces::get_all() +} + +#[cfg(windows)] +fn get_if_addrs() -> io::Result> { + Ok(Vec::new()) +} + +/// Select the best available public address +pub fn select_public_address(port: u16) -> SocketAddr { + match get_if_addrs() { + Ok(list) => { + //prefer IPV4 bindings + for addr in &list { //TODO: use better criteria than just the first in the list + match *addr { + IpAddr::V4(a) if !a.is_unspecified_s() && !a.is_loopback() && !a.is_link_local() => { + return SocketAddr::V4(SocketAddrV4::new(a, port)); + }, + _ => {}, + } + } + for addr in list { + match addr { + IpAddr::V6(a) if !a.is_unspecified_s() && !a.is_loopback() => { + return SocketAddr::V6(SocketAddrV6::new(a, port, 0, 0)); + }, + _ => {}, + } + } + }, + Err(e) => debug!("Error listing public interfaces: {:?}", e) + } + SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::new(127, 0, 0, 1), port)) +} + +pub fn map_external_address(local: &NodeEndpoint) -> Option { + if let SocketAddr::V4(ref local_addr) = local.address { + match search_gateway_from_timeout(local_addr.ip().clone(), Duration::new(5, 0)) { + Err(ref err) => debug!("Gateway search error: {}", err), + Ok(gateway) => { + match gateway.get_external_ip() { + Err(ref err) => { + debug!("IP request error: {}", err); + }, + Ok(external_addr) => { + match gateway.add_any_port(PortMappingProtocol::TCP, SocketAddrV4::new(local_addr.ip().clone(), local_addr.port()), 0, "Parity Node/TCP") { + Err(ref err) => { + debug!("Port mapping error: {}", err); + }, + Ok(tcp_port) => { + match gateway.add_any_port(PortMappingProtocol::UDP, SocketAddrV4::new(local_addr.ip().clone(), local.udp_port), 0, "Parity Node/UDP") { + Err(ref err) => { + debug!("Port mapping error: {}", err); + }, + Ok(udp_port) => { + return Some(NodeEndpoint { address: SocketAddr::V4(SocketAddrV4::new(external_addr, tcp_port)), udp_port: udp_port }); + }, + } + }, + } + }, + } + }, + } + } + None +} + +#[test] +fn can_select_public_address() { + let pub_address = select_public_address(40477); + assert!(pub_address.port() == 40477); +} + +#[test] +fn can_map_external_address_or_fail() { + let pub_address = select_public_address(40478); + let _ = map_external_address(&NodeEndpoint { address: pub_address, udp_port: 40478 }); +} + +#[test] +fn ipv4_properties() { + fn check(octets: &[u8; 4], unspec: bool, loopback: bool, + private: bool, link_local: bool, global: bool, + multicast: bool, broadcast: bool, documentation: bool) { + let ip = Ipv4Addr::new(octets[0], octets[1], octets[2], octets[3]); + assert_eq!(octets, &ip.octets()); + + assert_eq!(ip.is_unspecified_s(), unspec); + assert_eq!(ip.is_loopback(), loopback); + assert_eq!(ip.is_private(), private); + assert_eq!(ip.is_link_local(), link_local); + assert_eq!(ip.is_global_s(), global); + assert_eq!(ip.is_multicast(), multicast); + assert_eq!(ip.is_broadcast(), broadcast); + assert_eq!(ip.is_documentation(), documentation); + } + + // address unspec loopbk privt linloc global multicast brdcast doc + check(&[0, 0, 0, 0], true, false, false, false, true, false, false, false); + check(&[0, 0, 0, 1], false, false, false, false, true, false, false, false); + check(&[1, 0, 0, 0], false, false, false, false, true, false, false, false); + check(&[10, 9, 8, 7], false, false, true, false, false, false, false, false); + check(&[127, 1, 2, 3], false, true, false, false, false, false, false, false); + check(&[172, 31, 254, 253], false, false, true, false, false, false, false, false); + check(&[169, 254, 253, 242], false, false, false, true, false, false, false, false); + check(&[192, 0, 2, 183], false, false, false, false, false, false, false, true); + check(&[192, 1, 2, 183], false, false, false, false, true, false, false, false); + check(&[192, 168, 254, 253], false, false, true, false, false, false, false, false); + check(&[198, 51, 100, 0], false, false, false, false, false, false, false, true); + check(&[203, 0, 113, 0], false, false, false, false, false, false, false, true); + check(&[203, 2, 113, 0], false, false, false, false, true, false, false, false); + check(&[224, 0, 0, 0], false, false, false, false, true, true, false, false); + check(&[239, 255, 255, 255], false, false, false, false, true, true, false, false); + check(&[255, 255, 255, 255], false, false, false, false, false, false, true, false); +} + +#[test] +fn ipv6_properties() { + fn check(str_addr: &str, unspec: bool, loopback: bool, global: bool) { + let ip: Ipv6Addr = str_addr.parse().unwrap(); + assert_eq!(str_addr, ip.to_string()); + + assert_eq!(ip.is_unspecified_s(), unspec); + assert_eq!(ip.is_loopback(), loopback); + assert_eq!(ip.is_global_s(), global); + } + + // unspec loopbk global + check("::", true, false, true); + check("::1", false, true, false); +} diff --git a/util/src/network/mod.rs b/util/src/network/mod.rs index 6b58c87eb..50645f2be 100644 --- a/util/src/network/mod.rs +++ b/util/src/network/mod.rs @@ -56,7 +56,7 @@ //! } //! //! fn main () { -//! let mut service = NetworkService::::start(NetworkConfiguration::new()).expect("Error creating network service"); +//! let mut service = NetworkService::::start(NetworkConfiguration::new_with_port(40412)).expect("Error creating network service"); //! service.register_protocol(Arc::new(MyHandler), "myproto", &[1u8]); //! //! // Wait for quit condition @@ -71,8 +71,9 @@ mod session; mod discovery; mod service; mod error; -mod node; +mod node_table; mod stats; +mod ip_utils; #[cfg(test)] mod tests; @@ -88,6 +89,9 @@ pub use network::host::NetworkConfiguration; pub use network::stats::NetworkStats; use io::TimerToken; +pub use network::node_table::is_valid_node_url; + +const PROTOCOL_VERSION: u32 = 4; /// Network IO protocol handler. This needs to be implemented for each new subprotocol. /// All the handler function are called from within IO event loop. diff --git a/util/src/network/node.rs b/util/src/network/node.rs deleted file mode 100644 index e23dee9f5..000000000 --- a/util/src/network/node.rs +++ /dev/null @@ -1,134 +0,0 @@ -// Copyright 2015, 2016 Ethcore (UK) Ltd. -// This file is part of Parity. - -// Parity is free software: you can redistribute it and/or modify -// it under the terms of the GNU General Public License as published by -// the Free Software Foundation, either version 3 of the License, or -// (at your option) any later version. - -// Parity is distributed in the hope that it will be useful, -// but WITHOUT ANY WARRANTY; without even the implied warranty of -// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the -// GNU General Public License for more details. - -// You should have received a copy of the GNU General Public License -// along with Parity. If not, see . - -use std::net::{SocketAddr, ToSocketAddrs}; -use std::hash::{Hash, Hasher}; -use std::str::{FromStr}; -use hash::*; -use rlp::*; -use time::Tm; -use error::*; - -/// Node public key -pub type NodeId = H512; - -#[derive(Debug)] -/// Noe address info -pub struct NodeEndpoint { - /// IP(V4 or V6) address - pub address: SocketAddr, - /// Address as string (can be host name). - pub address_str: String, - /// Conneciton port. - pub udp_port: u16 -} - -impl FromStr for NodeEndpoint { - type Err = UtilError; - - /// Create endpoint from string. Performs name resolution if given a host name. - fn from_str(s: &str) -> Result { - let address = s.to_socket_addrs().map(|mut i| i.next()); - match address { - Ok(Some(a)) => Ok(NodeEndpoint { - address: a, - address_str: s.to_owned(), - udp_port: a.port() - }), - Ok(_) => Err(UtilError::AddressResolve(None)), - Err(e) => Err(UtilError::AddressResolve(Some(e))) - } - } -} - -#[derive(PartialEq, Eq, Copy, Clone)] -pub enum PeerType { - Required, - Optional -} - -pub struct Node { - pub id: NodeId, - pub endpoint: NodeEndpoint, - pub peer_type: PeerType, - pub last_attempted: Option, -} - -impl FromStr for Node { - type Err = UtilError; - fn from_str(s: &str) -> Result { - let (id, endpoint) = if &s[0..8] == "enode://" && s.len() > 136 && &s[136..137] == "@" { - (try!(NodeId::from_str(&s[8..136])), try!(NodeEndpoint::from_str(&s[137..]))) - } - else { - (NodeId::new(), try!(NodeEndpoint::from_str(s))) - }; - - Ok(Node { - id: id, - endpoint: endpoint, - peer_type: PeerType::Optional, - last_attempted: None, - }) - } -} - -impl PartialEq for Node { - fn eq(&self, other: &Self) -> bool { - self.id == other.id - } -} -impl Eq for Node { } - -impl Hash for Node { - fn hash(&self, state: &mut H) where H: Hasher { - self.id.hash(state) - } -} - -#[cfg(test)] -mod tests { - use super::*; - use std::str::FromStr; - use std::net::*; - use hash::*; - - #[test] - fn endpoint_parse() { - let endpoint = NodeEndpoint::from_str("123.99.55.44:7770"); - assert!(endpoint.is_ok()); - let v4 = match endpoint.unwrap().address { - SocketAddr::V4(v4address) => v4address, - _ => panic!("should ve v4 address") - }; - assert_eq!(SocketAddrV4::new(Ipv4Addr::new(123, 99, 55, 44), 7770), v4); - } - - #[test] - fn node_parse() { - let node = Node::from_str("enode://a979fb575495b8d6db44f750317d0f4622bf4c2aa3365d6af7c284339968eef29b69ad0dce72a4d8db5ebb4968de0e3bec910127f134779fbcb0cb6d3331163c@22.99.55.44:7770"); - assert!(node.is_ok()); - let node = node.unwrap(); - let v4 = match node.endpoint.address { - SocketAddr::V4(v4address) => v4address, - _ => panic!("should ve v4 address") - }; - assert_eq!(SocketAddrV4::new(Ipv4Addr::new(22, 99, 55, 44), 7770), v4); - assert_eq!( - H512::from_str("a979fb575495b8d6db44f750317d0f4622bf4c2aa3365d6af7c284339968eef29b69ad0dce72a4d8db5ebb4968de0e3bec910127f134779fbcb0cb6d3331163c").unwrap(), - node.id); - } -} diff --git a/util/src/network/node_table.rs b/util/src/network/node_table.rs new file mode 100644 index 000000000..ec6bad2aa --- /dev/null +++ b/util/src/network/node_table.rs @@ -0,0 +1,420 @@ +// Copyright 2015, 2016 Ethcore (UK) Ltd. +// This file is part of Parity. + +// Parity is free software: you can redistribute it and/or modify +// it under the terms of the GNU General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. + +// Parity is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU General Public License for more details. + +// You should have received a copy of the GNU General Public License +// along with Parity. If not, see . + +use std::mem; +use std::slice::from_raw_parts; +use std::net::{SocketAddr, ToSocketAddrs, SocketAddrV4, SocketAddrV6, Ipv4Addr, Ipv6Addr}; +use std::hash::{Hash, Hasher}; +use std::str::{FromStr}; +use std::collections::HashMap; +use std::fmt::{Display, Formatter}; +use std::path::{PathBuf}; +use std::fmt; +use std::fs; +use std::io::{Read, Write}; +use hash::*; +use rlp::*; +use time::Tm; +use error::*; +use network::discovery::{TableUpdates, NodeEntry}; +use network::ip_utils::*; +pub use rustc_serialize::json::Json; + +/// Node public key +pub type NodeId = H512; + +#[derive(Debug, Clone)] +/// Node address info +pub struct NodeEndpoint { + /// IP(V4 or V6) address + pub address: SocketAddr, + /// Conneciton port. + pub udp_port: u16 +} + +impl NodeEndpoint { + pub fn udp_address(&self) -> SocketAddr { + match self.address { + SocketAddr::V4(a) => SocketAddr::V4(SocketAddrV4::new(a.ip().clone(), self.udp_port)), + SocketAddr::V6(a) => SocketAddr::V6(SocketAddrV6::new(a.ip().clone(), self.udp_port, a.flowinfo(), a.scope_id())), + } + } +} + +impl NodeEndpoint { + pub fn from_rlp(rlp: &UntrustedRlp) -> Result { + let tcp_port = try!(rlp.val_at::(2)); + let udp_port = try!(rlp.val_at::(1)); + let addr_bytes = try!(try!(rlp.at(0)).data()); + let address = try!(match addr_bytes.len() { + 4 => Ok(SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::new(addr_bytes[0], addr_bytes[1], addr_bytes[2], addr_bytes[3]), tcp_port))), + 16 => unsafe { + let o: *const u16 = mem::transmute(addr_bytes.as_ptr()); + let o = from_raw_parts(o, 8); + Ok(SocketAddr::V6(SocketAddrV6::new(Ipv6Addr::new(o[0], o[1], o[2], o[3], o[4], o[5], o[6], o[7]), tcp_port, 0, 0))) + }, + _ => Err(DecoderError::RlpInconsistentLengthAndData) + }); + Ok(NodeEndpoint { address: address, udp_port: udp_port }) + } + + pub fn to_rlp(&self, rlp: &mut RlpStream) { + match self.address { + SocketAddr::V4(a) => { + rlp.append(&(&a.ip().octets()[..])); + } + SocketAddr::V6(a) => unsafe { + let o: *const u8 = mem::transmute(a.ip().segments().as_ptr()); + rlp.append(&from_raw_parts(o, 16)); + } + }; + rlp.append(&self.udp_port); + rlp.append(&self.address.port()); + } + + pub fn to_rlp_list(&self, rlp: &mut RlpStream) { + rlp.begin_list(3); + self.to_rlp(rlp); + } + + pub fn is_valid(&self) -> bool { + self.udp_port != 0 && self.address.port() != 0 && + match self.address { + SocketAddr::V4(a) => !a.ip().is_unspecified_s(), + SocketAddr::V6(a) => !a.ip().is_unspecified_s() + } + } + + pub fn is_global(&self) -> bool { + match self.address { + SocketAddr::V4(a) => a.ip().is_global_s(), + SocketAddr::V6(a) => a.ip().is_global_s() + } + } +} + +impl FromStr for NodeEndpoint { + type Err = UtilError; + + /// Create endpoint from string. Performs name resolution if given a host name. + fn from_str(s: &str) -> Result { + let address = s.to_socket_addrs().map(|mut i| i.next()); + match address { + Ok(Some(a)) => Ok(NodeEndpoint { + address: a, + udp_port: a.port() + }), + Ok(_) => Err(UtilError::AddressResolve(None)), + Err(e) => Err(UtilError::AddressResolve(Some(e))) + } + } +} + +#[derive(PartialEq, Eq, Copy, Clone)] +pub enum PeerType { + _Required, + Optional +} + +pub struct Node { + pub id: NodeId, + pub endpoint: NodeEndpoint, + pub peer_type: PeerType, + pub failures: u32, + pub last_attempted: Option, +} + +impl Node { + pub fn new(id: NodeId, endpoint: NodeEndpoint) -> Node { + Node { + id: id, + endpoint: endpoint, + peer_type: PeerType::Optional, + failures: 0, + last_attempted: None, + } + } +} + +impl Display for Node { + fn fmt(&self, f: &mut Formatter) -> fmt::Result { + if self.endpoint.udp_port != self.endpoint.address.port() { + try!(write!(f, "enode://{}@{}+{}", self.id.hex(), self.endpoint.address, self.endpoint.udp_port)); + } else { + try!(write!(f, "enode://{}@{}", self.id.hex(), self.endpoint.address)); + } + Ok(()) + } +} + +impl FromStr for Node { + type Err = UtilError; + fn from_str(s: &str) -> Result { + let (id, endpoint) = if &s[0..8] == "enode://" && s.len() > 136 && &s[136..137] == "@" { + (try!(NodeId::from_str(&s[8..136])), try!(NodeEndpoint::from_str(&s[137..]))) + } + else { + (NodeId::new(), try!(NodeEndpoint::from_str(s))) + }; + + Ok(Node { + id: id, + endpoint: endpoint, + peer_type: PeerType::Optional, + last_attempted: None, + failures: 0, + }) + } +} + +impl PartialEq for Node { + fn eq(&self, other: &Self) -> bool { + self.id == other.id + } +} +impl Eq for Node {} + +impl Hash for Node { + fn hash(&self, state: &mut H) where H: Hasher { + self.id.hash(state) + } +} + +/// Node table backed by disk file. +pub struct NodeTable { + nodes: HashMap, + path: Option, +} + +impl NodeTable { + pub fn new(path: Option) -> NodeTable { + NodeTable { + path: path.clone(), + nodes: NodeTable::load(path), + } + } + + /// Add a node to table + pub fn add_node(&mut self, mut node: Node) { + // preserve failure counter + let failures = self.nodes.get(&node.id).map_or(0, |n| n.failures); + node.failures = failures; + self.nodes.insert(node.id.clone(), node); + } + + /// Returns node ids sorted by number of failures + pub fn nodes(&self) -> Vec { + let mut refs: Vec<&Node> = self.nodes.values().collect(); + refs.sort_by(|a, b| a.failures.cmp(&b.failures)); + refs.iter().map(|n| n.id.clone()).collect() + } + + /// Unordered list of all entries + pub fn unordered_entries(&self) -> Vec { + // preserve failure counter + self.nodes.values().map(|n| NodeEntry { endpoint: n.endpoint.clone(), id: n.id.clone() }).collect() + } + + /// Get particular node + pub fn get_mut(&mut self, id: &NodeId) -> Option<&mut Node> { + self.nodes.get_mut(id) + } + + /// Apply table changes coming from discovery + pub fn update(&mut self, mut update: TableUpdates) { + for (_, node) in update.added.drain() { + let mut entry = self.nodes.entry(node.id.clone()).or_insert_with(|| Node::new(node.id.clone(), node.endpoint.clone())); + entry.endpoint = node.endpoint; + } + for r in update.removed { + self.nodes.remove(&r); + } + } + + /// Increase failure counte for a node + pub fn note_failure(&mut self, id: &NodeId) { + if let Some(node) = self.nodes.get_mut(id) { + node.failures += 1; + } + } + + fn save(&self) { + if let Some(ref path) = self.path { + let mut path_buf = PathBuf::from(path); + if let Err(e) = fs::create_dir_all(path_buf.as_path()) { + warn!("Error creating node table directory: {:?}", e); + return; + }; + path_buf.push("nodes.json"); + let mut json = String::new(); + json.push_str("{\n"); + json.push_str("\"nodes\": [\n"); + let node_ids = self.nodes(); + for i in 0 .. node_ids.len() { + let node = self.nodes.get(&node_ids[i]).unwrap(); + json.push_str(&format!("\t{{ \"url\": \"{}\", \"failures\": {} }}{}\n", node, node.failures, if i == node_ids.len() - 1 {""} else {","})) + } + json.push_str("]\n"); + json.push_str("}"); + let mut file = match fs::File::create(path_buf.as_path()) { + Ok(file) => file, + Err(e) => { + warn!("Error creating node table file: {:?}", e); + return; + } + }; + if let Err(e) = file.write(&json.into_bytes()) { + warn!("Error writing node table file: {:?}", e); + } + } + } + + fn load(path: Option) -> HashMap { + let mut nodes: HashMap = HashMap::new(); + if let Some(path) = path { + let mut path_buf = PathBuf::from(path); + path_buf.push("nodes.json"); + let mut file = match fs::File::open(path_buf.as_path()) { + Ok(file) => file, + Err(e) => { + debug!("Error opening node table file: {:?}", e); + return nodes; + } + }; + let mut buf = String::new(); + match file.read_to_string(&mut buf) { + Ok(_) => {}, + Err(e) => { + warn!("Error reading node table file: {:?}", e); + return nodes; + } + } + let json = match Json::from_str(&buf) { + Ok(json) => json, + Err(e) => { + warn!("Error parsing node table file: {:?}", e); + return nodes; + } + }; + if let Some(list) = json.as_object().and_then(|o| o.get("nodes")).and_then(|n| n.as_array()) { + for n in list.iter().filter_map(|n| n.as_object()) { + if let Some(url) = n.get("url").and_then(|u| u.as_string()) { + if let Ok(mut node) = Node::from_str(url) { + if let Some(failures) = n.get("failures").and_then(|f| f.as_u64()) { + node.failures = failures as u32; + } + nodes.insert(node.id.clone(), node); + } + } + } + } + } + nodes + } +} + +impl Drop for NodeTable { + fn drop(&mut self) { + self.save(); + } +} + +/// Check if node url is valid +pub fn is_valid_node_url(url: &str) -> bool { + use std::str::FromStr; + Node::from_str(url).is_ok() +} + +#[cfg(test)] +mod tests { + use super::*; + use std::str::FromStr; + use std::net::*; + use hash::*; + use devtools::*; + + #[test] + fn endpoint_parse() { + let endpoint = NodeEndpoint::from_str("123.99.55.44:7770"); + assert!(endpoint.is_ok()); + let v4 = match endpoint.unwrap().address { + SocketAddr::V4(v4address) => v4address, + _ => panic!("should ve v4 address") + }; + assert_eq!(SocketAddrV4::new(Ipv4Addr::new(123, 99, 55, 44), 7770), v4); + } + + #[test] + fn node_parse() { + assert!(is_valid_node_url("enode://a979fb575495b8d6db44f750317d0f4622bf4c2aa3365d6af7c284339968eef29b69ad0dce72a4d8db5ebb4968de0e3bec910127f134779fbcb0cb6d3331163c@22.99.55.44:7770")); + let node = Node::from_str("enode://a979fb575495b8d6db44f750317d0f4622bf4c2aa3365d6af7c284339968eef29b69ad0dce72a4d8db5ebb4968de0e3bec910127f134779fbcb0cb6d3331163c@22.99.55.44:7770"); + assert!(node.is_ok()); + let node = node.unwrap(); + let v4 = match node.endpoint.address { + SocketAddr::V4(v4address) => v4address, + _ => panic!("should ve v4 address") + }; + assert_eq!(SocketAddrV4::new(Ipv4Addr::new(22, 99, 55, 44), 7770), v4); + assert_eq!( + H512::from_str("a979fb575495b8d6db44f750317d0f4622bf4c2aa3365d6af7c284339968eef29b69ad0dce72a4d8db5ebb4968de0e3bec910127f134779fbcb0cb6d3331163c").unwrap(), + node.id); + } + + #[test] + fn table_failure_order() { + let node1 = Node::from_str("enode://a979fb575495b8d6db44f750317d0f4622bf4c2aa3365d6af7c284339968eef29b69ad0dce72a4d8db5ebb4968de0e3bec910127f134779fbcb0cb6d3331163c@22.99.55.44:7770").unwrap(); + let node2 = Node::from_str("enode://b979fb575495b8d6db44f750317d0f4622bf4c2aa3365d6af7c284339968eef29b69ad0dce72a4d8db5ebb4968de0e3bec910127f134779fbcb0cb6d3331163c@22.99.55.44:7770").unwrap(); + let node3 = Node::from_str("enode://c979fb575495b8d6db44f750317d0f4622bf4c2aa3365d6af7c284339968eef29b69ad0dce72a4d8db5ebb4968de0e3bec910127f134779fbcb0cb6d3331163c@22.99.55.44:7770").unwrap(); + let id1 = H512::from_str("a979fb575495b8d6db44f750317d0f4622bf4c2aa3365d6af7c284339968eef29b69ad0dce72a4d8db5ebb4968de0e3bec910127f134779fbcb0cb6d3331163c").unwrap(); + let id2 = H512::from_str("b979fb575495b8d6db44f750317d0f4622bf4c2aa3365d6af7c284339968eef29b69ad0dce72a4d8db5ebb4968de0e3bec910127f134779fbcb0cb6d3331163c").unwrap(); + let id3 = H512::from_str("c979fb575495b8d6db44f750317d0f4622bf4c2aa3365d6af7c284339968eef29b69ad0dce72a4d8db5ebb4968de0e3bec910127f134779fbcb0cb6d3331163c").unwrap(); + let mut table = NodeTable::new(None); + table.add_node(node3); + table.add_node(node1); + table.add_node(node2); + + table.note_failure(&id1); + table.note_failure(&id1); + table.note_failure(&id2); + + let r = table.nodes(); + assert_eq!(r[0][..], id3[..]); + assert_eq!(r[1][..], id2[..]); + assert_eq!(r[2][..], id1[..]); + } + + #[test] + fn table_save_load() { + let temp_path = RandomTempPath::create_dir(); + let node1 = Node::from_str("enode://a979fb575495b8d6db44f750317d0f4622bf4c2aa3365d6af7c284339968eef29b69ad0dce72a4d8db5ebb4968de0e3bec910127f134779fbcb0cb6d3331163c@22.99.55.44:7770").unwrap(); + let node2 = Node::from_str("enode://b979fb575495b8d6db44f750317d0f4622bf4c2aa3365d6af7c284339968eef29b69ad0dce72a4d8db5ebb4968de0e3bec910127f134779fbcb0cb6d3331163c@22.99.55.44:7770").unwrap(); + let id1 = H512::from_str("a979fb575495b8d6db44f750317d0f4622bf4c2aa3365d6af7c284339968eef29b69ad0dce72a4d8db5ebb4968de0e3bec910127f134779fbcb0cb6d3331163c").unwrap(); + let id2 = H512::from_str("b979fb575495b8d6db44f750317d0f4622bf4c2aa3365d6af7c284339968eef29b69ad0dce72a4d8db5ebb4968de0e3bec910127f134779fbcb0cb6d3331163c").unwrap(); + { + let mut table = NodeTable::new(Some(temp_path.as_path().to_str().unwrap().to_owned())); + table.add_node(node1); + table.add_node(node2); + table.note_failure(&id2); + } + + { + let table = NodeTable::new(Some(temp_path.as_path().to_str().unwrap().to_owned())); + let r = table.nodes(); + assert_eq!(r[0][..], id1[..]); + assert_eq!(r[1][..], id2[..]); + } + } +} diff --git a/util/src/network/service.rs b/util/src/network/service.rs index 60f0ec415..1cd48abe1 100644 --- a/util/src/network/service.rs +++ b/util/src/network/service.rs @@ -42,7 +42,7 @@ impl NetworkService where Message: Send + Sync + Clone + 'stat let host = Arc::new(Host::new(config)); let stats = host.stats().clone(); let host_info = host.client_version(); - info!("Host ID={:?}", host.client_id()); + info!("Node URL: {}", host.client_url()); try!(io_service.register_handler(host)); Ok(NetworkService { io_service: io_service, diff --git a/util/src/network/session.rs b/util/src/network/session.rs index c4ebe7a2a..b0db5f7ef 100644 --- a/util/src/network/session.rs +++ b/util/src/network/session.rs @@ -14,6 +14,8 @@ // You should have received a copy of the GNU General Public License // along with Parity. If not, see . +use std::net::SocketAddr; +use std::io; use mio::*; use hash::*; use rlp::*; @@ -23,7 +25,7 @@ use error::*; use io::{IoContext, StreamToken}; use network::error::{NetworkError, DisconnectReason}; use network::host::*; -use network::node::NodeId; +use network::node_table::NodeId; use time; const PING_TIMEOUT_SEC: u64 = 30; @@ -89,7 +91,7 @@ impl Decodable for PeerCapabilityInfo { } } -#[derive(Debug, PartialEq, Eq)] +#[derive(Debug)] struct SessionCapabilityInfo { pub protocol: &'static str, pub version: u8, @@ -108,7 +110,7 @@ const PACKET_LAST: u8 = 0x7f; impl Session { /// Create a new session out of comepleted handshake. Consumes handshake object. - pub fn new(h: Handshake, _io: &IoContext, host: &HostInfo) -> Result where Message: Send + Sync + Clone { + pub fn new(h: Handshake, host: &HostInfo) -> Result { let id = h.id.clone(); let connection = try!(EncryptedConnection::new(h)); let mut session = Session { @@ -129,11 +131,26 @@ impl Session { Ok(session) } + /// Get id of the remote peer + pub fn id(&self) -> &NodeId { + &self.info.id + } + /// Check if session is ready to send/receive data pub fn is_ready(&self) -> bool { self.had_hello } + /// Replace socket token + pub fn set_token(&mut self, token: StreamToken) { + self.connection.set_token(token); + } + + /// Get remote peer address + pub fn remote_addr(&self) -> io::Result { + self.connection.remote_addr() + } + /// Readable IO handler. Returns packet data if available. pub fn readable(&mut self, io: &IoContext, host: &HostInfo) -> Result where Message: Send + Sync + Clone { match try!(self.connection.readable(io)) { @@ -214,7 +231,11 @@ impl Session { try!(self.read_hello(&rlp, host)); Ok(SessionData::Ready) }, - PACKET_DISCONNECT => Err(From::from(NetworkError::Disconnect(DisconnectReason::DisconnectRequested))), + PACKET_DISCONNECT => { + let rlp = UntrustedRlp::new(&packet.data[1..]); + let reason: u8 = try!(rlp.val_at(0)); + Err(From::from(NetworkError::Disconnect(DisconnectReason::from_u8(reason)))) + } PACKET_PING => { try!(self.send_pong()); Ok(SessionData::None) @@ -301,7 +322,12 @@ impl Session { trace!(target: "net", "Hello: {} v{} {} {:?}", client_version, protocol, id, caps); self.info.client_version = client_version; self.info.capabilities = caps; + if self.info.capabilities.is_empty() { + trace!("No common capabilities with peer."); + return Err(From::from(self.disconnect(DisconnectReason::UselessPeer))); + } if protocol != host.protocol_version { + trace!("Peer protocol version mismatch: {}", protocol); return Err(From::from(self.disconnect(DisconnectReason::UselessPeer))); } self.had_hello = true; diff --git a/util/src/network/tests.rs b/util/src/network/tests.rs index c1b59df9b..44d53bdbe 100644 --- a/util/src/network/tests.rs +++ b/util/src/network/tests.rs @@ -23,17 +23,10 @@ use io::TimerToken; use crypto::KeyPair; pub struct TestProtocol { + drop_session: bool, pub packet: Mutex, pub got_timeout: AtomicBool, -} - -impl Default for TestProtocol { - fn default() -> Self { - TestProtocol { - packet: Mutex::new(Vec::new()), - got_timeout: AtomicBool::new(false), - } - } + pub got_disconnect: AtomicBool, } #[derive(Clone)] @@ -42,9 +35,17 @@ pub struct TestProtocolMessage { } impl TestProtocol { + pub fn new(drop_session: bool) -> Self { + TestProtocol { + packet: Mutex::new(Vec::new()), + got_timeout: AtomicBool::new(false), + got_disconnect: AtomicBool::new(false), + drop_session: drop_session, + } + } /// Creates and register protocol with the network service - pub fn register(service: &mut NetworkService) -> Arc { - let handler = Arc::new(TestProtocol::default()); + pub fn register(service: &mut NetworkService, drop_session: bool) -> Arc { + let handler = Arc::new(TestProtocol::new(drop_session)); service.register_protocol(handler.clone(), "test", &[42u8, 43u8]).expect("Error registering test protocol handler"); handler } @@ -56,6 +57,10 @@ impl TestProtocol { pub fn got_timeout(&self) -> bool { self.got_timeout.load(AtomicOrdering::Relaxed) } + + pub fn got_disconnect(&self) -> bool { + self.got_disconnect.load(AtomicOrdering::Relaxed) + } } impl NetworkProtocolHandler for TestProtocol { @@ -68,15 +73,22 @@ impl NetworkProtocolHandler for TestProtocol { self.packet.lock().unwrap().extend(data); } - fn connected(&self, io: &NetworkContext, _peer: &PeerId) { - io.respond(33, "hello".to_owned().into_bytes()).unwrap(); + fn connected(&self, io: &NetworkContext, peer: &PeerId) { + assert!(io.peer_info(*peer).contains("parity")); + if self.drop_session { + io.disconnect_peer(*peer) + } else { + io.respond(33, "hello".to_owned().into_bytes()).unwrap(); + } } fn disconnected(&self, _io: &NetworkContext, _peer: &PeerId) { + self.got_disconnect.store(true, AtomicOrdering::Relaxed); } /// Timer function called after a timeout created with `NetworkContext::timeout`. - fn timeout(&self, _io: &NetworkContext, timer: TimerToken) { + fn timeout(&self, io: &NetworkContext, timer: TimerToken) { + io.message(TestProtocolMessage { payload: 22 }); assert_eq!(timer, 0); self.got_timeout.store(true, AtomicOrdering::Relaxed); } @@ -85,34 +97,57 @@ impl NetworkProtocolHandler for TestProtocol { #[test] fn net_service() { - let mut service = NetworkService::::start(NetworkConfiguration::new()).expect("Error creating network service"); - service.register_protocol(Arc::new(TestProtocol::default()), "myproto", &[1u8]).unwrap(); + let mut service = NetworkService::::start(NetworkConfiguration::new_with_port(40414)).expect("Error creating network service"); + service.register_protocol(Arc::new(TestProtocol::new(false)), "myproto", &[1u8]).unwrap(); } #[test] fn net_connect() { let key1 = KeyPair::create().unwrap(); - let mut config1 = NetworkConfiguration::new_with_port(30344); + let mut config1 = NetworkConfiguration::new_with_port(30354); config1.use_secret = Some(key1.secret().clone()); + config1.nat_enabled = false; config1.boot_nodes = vec![ ]; - let mut config2 = NetworkConfiguration::new_with_port(30345); - config2.boot_nodes = vec![ format!("enode://{}@127.0.0.1:30344", key1.public().hex()) ]; + let mut config2 = NetworkConfiguration::new_with_port(30355); + config2.boot_nodes = vec![ format!("enode://{}@127.0.0.1:30354", key1.public().hex()) ]; + config2.nat_enabled = false; let mut service1 = NetworkService::::start(config1).unwrap(); let mut service2 = NetworkService::::start(config2).unwrap(); - let handler1 = TestProtocol::register(&mut service1); - let handler2 = TestProtocol::register(&mut service2); - while !handler1.got_packet() && !handler2.got_packet() { + let handler1 = TestProtocol::register(&mut service1, false); + let handler2 = TestProtocol::register(&mut service2, false); + while !handler1.got_packet() && !handler2.got_packet() && (service1.stats().sessions() == 0 || service2.stats().sessions() == 0) { thread::sleep(Duration::from_millis(50)); } assert!(service1.stats().sessions() >= 1); assert!(service2.stats().sessions() >= 1); } +#[test] +fn net_disconnect() { + let key1 = KeyPair::create().unwrap(); + let mut config1 = NetworkConfiguration::new_with_port(30364); + config1.use_secret = Some(key1.secret().clone()); + config1.nat_enabled = false; + config1.boot_nodes = vec![ ]; + let mut config2 = NetworkConfiguration::new_with_port(30365); + config2.boot_nodes = vec![ format!("enode://{}@127.0.0.1:30364", key1.public().hex()) ]; + config2.nat_enabled = false; + let mut service1 = NetworkService::::start(config1).unwrap(); + let mut service2 = NetworkService::::start(config2).unwrap(); + let handler1 = TestProtocol::register(&mut service1, false); + let handler2 = TestProtocol::register(&mut service2, true); + while !(handler1.got_disconnect() && handler2.got_disconnect()) { + thread::sleep(Duration::from_millis(50)); + } + assert!(handler1.got_disconnect()); + assert!(handler2.got_disconnect()); +} + #[test] fn net_timeout() { let config = NetworkConfiguration::new_with_port(30346); let mut service = NetworkService::::start(config).unwrap(); - let handler = TestProtocol::register(&mut service); + let handler = TestProtocol::register(&mut service, false); while !handler.got_timeout() { thread::sleep(Duration::from_millis(50)); }