From 8596a347ead1c053e50ebc6f412bacc9e50182ec Mon Sep 17 00:00:00 2001 From: Arkadiy Paronyan Date: Thu, 2 Jun 2016 11:49:56 +0200 Subject: [PATCH] Networking refactoring (#1172) * Networking refactoring * Make sure the same socket is reused * Safer atomic ordering * Replaced eq with == --- Cargo.lock | 7 +- util/Cargo.toml | 2 +- util/src/network/connection.rs | 59 ++------ util/src/network/handshake.rs | 54 ++------ util/src/network/host.rs | 245 ++++++++++----------------------- util/src/network/session.rs | 151 ++++++++++++++------ 6 files changed, 209 insertions(+), 309 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index ce5c1f481..723bca872 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -390,7 +390,7 @@ dependencies = [ "rustc_version 0.1.7 (registry+https://github.com/rust-lang/crates.io-index)", "serde 0.7.0 (registry+https://github.com/rust-lang/crates.io-index)", "sha3 0.1.0", - "slab 0.1.3 (registry+https://github.com/rust-lang/crates.io-index)", + "slab 0.2.0 (registry+https://github.com/rust-lang/crates.io-index)", "target_info 0.1.0 (registry+https://github.com/rust-lang/crates.io-index)", "time 0.1.35 (registry+https://github.com/rust-lang/crates.io-index)", "tiny-keccak 1.0.5 (registry+https://github.com/rust-lang/crates.io-index)", @@ -1176,6 +1176,11 @@ name = "slab" version = "0.1.3" source = "registry+https://github.com/rust-lang/crates.io-index" +[[package]] +name = "slab" +version = "0.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" + [[package]] name = "solicit" version = "0.4.4" diff --git a/util/Cargo.toml b/util/Cargo.toml index 7d03fd320..a88ffe037 100644 --- a/util/Cargo.toml +++ b/util/Cargo.toml @@ -25,7 +25,7 @@ elastic-array = "0.4" heapsize = "0.3" itertools = "0.4" crossbeam = "0.2" -slab = "0.1" +slab = "0.2" sha3 = { path = "sha3" } serde = "0.7.0" clippy = { version = "0.0.69", optional = true} diff --git a/util/src/network/connection.rs b/util/src/network/connection.rs index 589fc0106..3f20b8f7b 100644 --- a/util/src/network/connection.rs +++ b/util/src/network/connection.rs @@ -170,16 +170,16 @@ impl Connection { self.token } - /// Replace socket token - pub fn set_token(&mut self, token: StreamToken) { - self.token = token; - } - /// Get remote peer address pub fn remote_addr(&self) -> io::Result { self.socket.peer_addr() } + /// Get remote peer address string + pub fn remote_addr_str(&self) -> String { + self.socket.peer_addr().map(|a| a.to_string()).unwrap_or_else(|_| "Unknown".to_owned()) + } + /// Clone this connection. Clears the receiving buffer of the returned connection. pub fn try_clone(&self) -> io::Result { Ok(Connection { @@ -196,7 +196,7 @@ impl Connection { /// Register this connection with the IO event loop. pub fn register_socket(&self, reg: Token, event_loop: &mut EventLoop) -> io::Result<()> { trace!(target: "network", "connection register; token={:?}", reg); - if let Err(e) = event_loop.register(&self.socket, reg, self.interest, PollOpt::edge() | PollOpt::oneshot()) { + if let Err(e) = event_loop.register(&self.socket, reg, self.interest, PollOpt::edge() /* | PollOpt::oneshot() */) { // TODO: oneshot is broken on windows trace!(target: "network", "Failed to register {:?}, {:?}", reg, e); } Ok(()) @@ -205,7 +205,7 @@ impl Connection { /// Update connection registration. Should be called at the end of the IO handler. pub fn update_socket(&self, reg: Token, event_loop: &mut EventLoop) -> io::Result<()> { trace!(target: "network", "connection reregister; token={:?}", reg); - event_loop.reregister( &self.socket, reg, self.interest, PollOpt::edge() | PollOpt::oneshot()).or_else(|e| { + event_loop.reregister( &self.socket, reg, self.interest, PollOpt::edge() /* | PollOpt::oneshot() */ ).or_else(|e| { // TODO: oneshot is broken on windows trace!(target: "network", "Failed to reregister {:?}, {:?}", reg, e); Ok(()) }) @@ -246,7 +246,7 @@ enum EncryptedConnectionState { /// https://github.com/ethereum/devp2p/blob/master/rlpx.md#framing pub struct EncryptedConnection { /// Underlying tcp connection - connection: Connection, + pub connection: Connection, /// Egress data encryptor encoder: CtrMode, /// Ingress data decryptor @@ -266,27 +266,6 @@ pub struct EncryptedConnection { } impl EncryptedConnection { - - /// Get socket token - pub fn token(&self) -> StreamToken { - self.connection.token - } - - /// Replace socket token - pub fn set_token(&mut self, token: StreamToken) { - self.connection.set_token(token); - } - - /// Get remote peer address - pub fn remote_addr(&self) -> io::Result { - self.connection.remote_addr() - } - - /// Check if this connection has data to be sent. - pub fn is_sending(&self) -> bool { - self.connection.is_sending() - } - /// Create an encrypted connection out of the handshake. Consumes a handshake object. pub fn new(handshake: &mut Handshake) -> Result { let shared = try!(crypto::ecdh::agree(handshake.ecdhe.secret(), &handshake.remote_ephemeral)); @@ -323,8 +302,10 @@ impl EncryptedConnection { ingress_mac.update(&mac_material); ingress_mac.update(if handshake.originated { &handshake.ack_cipher } else { &handshake.auth_cipher }); + let old_connection = try!(handshake.connection.try_clone()); + let connection = ::std::mem::replace(&mut handshake.connection, old_connection); let mut enc = EncryptedConnection { - connection: try!(handshake.connection.try_clone()), + connection: connection, encoder: encoder, decoder: decoder, mac_encoder: mac_encoder, @@ -463,24 +444,6 @@ impl EncryptedConnection { try!(self.connection.writable()); Ok(()) } - - /// Register socket with the event lpop. This should be called at the end of the event loop. - pub fn register_socket(&self, reg: Token, event_loop: &mut EventLoop) -> Result<(), UtilError> { - try!(self.connection.register_socket(reg, event_loop)); - Ok(()) - } - - /// Update connection registration. This should be called at the end of the event loop. - pub fn update_socket(&self, reg: Token, event_loop: &mut EventLoop) -> Result<(), UtilError> { - try!(self.connection.update_socket(reg, event_loop)); - Ok(()) - } - - /// Delete connection registration. This should be called at the end of the event loop. - pub fn deregister_socket(&self, event_loop: &mut EventLoop) -> Result<(), UtilError> { - try!(self.connection.deregister_socket(event_loop)); - Ok(()) - } } #[test] diff --git a/util/src/network/handshake.rs b/util/src/network/handshake.rs index 123531d8d..e02da3d4c 100644 --- a/util/src/network/handshake.rs +++ b/util/src/network/handshake.rs @@ -16,7 +16,6 @@ use std::sync::Arc; use rand::random; -use mio::*; use mio::tcp::*; use hash::*; use rlp::*; @@ -102,21 +101,6 @@ impl Handshake { }) } - /// Get id of the remote node if known - pub fn id(&self) -> &NodeId { - &self.id - } - - /// Get stream token id - pub fn token(&self) -> StreamToken { - self.connection.token() - } - - /// Mark this handshake as inactive to be deleted lated. - pub fn set_expired(&mut self) { - self.expired = true; - } - /// Check if this handshake is expired. pub fn expired(&self) -> bool { self.expired @@ -177,7 +161,7 @@ impl Handshake { } /// Writabe IO handler. - pub fn writable(&mut self, io: &IoContext, _host: &HostInfo) -> Result<(), UtilError> where Message: Send + Clone { + pub fn writable(&mut self, io: &IoContext) -> Result<(), UtilError> where Message: Send + Clone { if !self.expired() { io.clear_timer(self.connection.token).unwrap(); try!(self.connection.writable()); @@ -188,28 +172,6 @@ impl Handshake { Ok(()) } - /// Register the socket with the event loop - pub fn register_socket>(&self, reg: Token, event_loop: &mut EventLoop) -> Result<(), UtilError> { - if !self.expired() { - try!(self.connection.register_socket(reg, event_loop)); - } - Ok(()) - } - - /// Update socket registration with the event loop. - pub fn update_socket>(&self, reg: Token, event_loop: &mut EventLoop) -> Result<(), UtilError> { - if !self.expired() { - try!(self.connection.update_socket(reg, event_loop)); - } - Ok(()) - } - - /// Delete registration - pub fn deregister_socket(&self, event_loop: &mut EventLoop) -> Result<(), UtilError> { - try!(self.connection.deregister_socket(event_loop)); - Ok(()) - } - fn set_auth(&mut self, host_secret: &Secret, sig: &[u8], remote_public: &[u8], remote_nonce: &[u8], remote_version: u64) -> Result<(), UtilError> { self.id.clone_from_slice(remote_public); self.remote_nonce.clone_from_slice(remote_nonce); @@ -222,7 +184,7 @@ impl Handshake { /// Parse, validate and confirm auth message fn read_auth(&mut self, secret: &Secret, data: &[u8]) -> Result<(), UtilError> { - trace!(target:"network", "Received handshake auth from {:?}", self.connection.socket.peer_addr()); + trace!(target:"network", "Received handshake auth from {:?}", self.connection.remote_addr_str()); if data.len() != V4_AUTH_PACKET_SIZE { debug!(target:"net", "Wrong auth packet size"); return Err(From::from(NetworkError::BadProtocol)); @@ -253,7 +215,7 @@ impl Handshake { } fn read_auth_eip8(&mut self, secret: &Secret, data: &[u8]) -> Result<(), UtilError> { - trace!(target:"network", "Received EIP8 handshake auth from {:?}", self.connection.socket.peer_addr()); + trace!(target:"network", "Received EIP8 handshake auth from {:?}", self.connection.remote_addr_str()); self.auth_cipher.extend_from_slice(data); let auth = try!(ecies::decrypt(secret, &self.auth_cipher[0..2], &self.auth_cipher[2..])); let rlp = UntrustedRlp::new(&auth); @@ -268,7 +230,7 @@ impl Handshake { /// Parse and validate ack message fn read_ack(&mut self, secret: &Secret, data: &[u8]) -> Result<(), UtilError> { - trace!(target:"network", "Received handshake auth to {:?}", self.connection.socket.peer_addr()); + trace!(target:"network", "Received handshake auth to {:?}", self.connection.remote_addr_str()); if data.len() != V4_ACK_PACKET_SIZE { debug!(target:"net", "Wrong ack packet size"); return Err(From::from(NetworkError::BadProtocol)); @@ -296,7 +258,7 @@ impl Handshake { } fn read_ack_eip8(&mut self, secret: &Secret, data: &[u8]) -> Result<(), UtilError> { - trace!(target:"network", "Received EIP8 handshake auth from {:?}", self.connection.socket.peer_addr()); + trace!(target:"network", "Received EIP8 handshake auth from {:?}", self.connection.remote_addr_str()); self.ack_cipher.extend_from_slice(data); let ack = try!(ecies::decrypt(secret, &self.ack_cipher[0..2], &self.ack_cipher[2..])); let rlp = UntrustedRlp::new(&ack); @@ -309,7 +271,7 @@ impl Handshake { /// Sends auth message fn write_auth(&mut self, secret: &Secret, public: &Public) -> Result<(), UtilError> { - trace!(target:"network", "Sending handshake auth to {:?}", self.connection.socket.peer_addr()); + trace!(target:"network", "Sending handshake auth to {:?}", self.connection.remote_addr_str()); let mut data = [0u8; /*Signature::SIZE*/ 65 + /*H256::SIZE*/ 32 + /*Public::SIZE*/ 64 + /*H256::SIZE*/ 32 + 1]; //TODO: use associated constants let len = data.len(); { @@ -336,7 +298,7 @@ impl Handshake { /// Sends ack message fn write_ack(&mut self) -> Result<(), UtilError> { - trace!(target:"network", "Sending handshake ack to {:?}", self.connection.socket.peer_addr()); + trace!(target:"network", "Sending handshake ack to {:?}", self.connection.remote_addr_str()); let mut data = [0u8; 1 + /*Public::SIZE*/ 64 + /*H256::SIZE*/ 32]; //TODO: use associated constants let len = data.len(); { @@ -355,7 +317,7 @@ impl Handshake { /// Sends EIP8 ack message fn write_ack_eip8(&mut self) -> Result<(), UtilError> { - trace!(target:"network", "Sending EIP8 handshake ack to {:?}", self.connection.socket.peer_addr()); + trace!(target:"network", "Sending EIP8 handshake ack to {:?}", self.connection.remote_addr_str()); let mut rlp = RlpStream::new_list(3); rlp.append(self.ecdhe.public()); rlp.append(&self.nonce); diff --git a/util/src/network/host.rs b/util/src/network/host.rs index 13b64eb3c..92a912a40 100644 --- a/util/src/network/host.rs +++ b/util/src/network/host.rs @@ -18,6 +18,7 @@ use std::net::{SocketAddr}; use std::collections::{HashMap}; use std::str::{FromStr}; use std::sync::*; +use std::sync::atomic::{AtomicUsize, Ordering as AtomicOrdering}; use std::ops::*; use std::cmp::min; use std::path::{Path, PathBuf}; @@ -31,7 +32,6 @@ use misc::version; use crypto::*; use sha3::Hashable; use rlp::*; -use network::handshake::Handshake; use network::session::{Session, SessionData}; use error::*; use io::*; @@ -44,8 +44,7 @@ use network::ip_utils::{map_external_address, select_public_address}; type Slab = ::slab::Slab; -const _DEFAULT_PORT: u16 = 30304; -const MAX_SESSIONS: usize = 1024; +const MAX_SESSIONS: usize = 1024 + MAX_HANDSHAKES; const MAX_HANDSHAKES: usize = 80; const MAX_HANDSHAKES_PER_ROUND: usize = 32; const MAINTENANCE_TIMEOUT: u64 = 1000; @@ -115,18 +114,17 @@ impl NetworkConfiguration { } // Tokens -const TCP_ACCEPT: usize = LAST_HANDSHAKE + 1; -const IDLE: usize = LAST_HANDSHAKE + 2; -const DISCOVERY: usize = LAST_HANDSHAKE + 3; -const DISCOVERY_REFRESH: usize = LAST_HANDSHAKE + 4; -const DISCOVERY_ROUND: usize = LAST_HANDSHAKE + 5; -const INIT_PUBLIC: usize = LAST_HANDSHAKE + 6; -const NODE_TABLE: usize = LAST_HANDSHAKE + 7; +const TCP_ACCEPT: usize = SYS_TIMER + 1; +const IDLE: usize = SYS_TIMER + 2; +const DISCOVERY: usize = SYS_TIMER + 3; +const DISCOVERY_REFRESH: usize = SYS_TIMER + 4; +const DISCOVERY_ROUND: usize = SYS_TIMER + 5; +const INIT_PUBLIC: usize = SYS_TIMER + 6; +const NODE_TABLE: usize = SYS_TIMER + 7; const FIRST_SESSION: usize = 0; const LAST_SESSION: usize = FIRST_SESSION + MAX_SESSIONS - 1; -const FIRST_HANDSHAKE: usize = LAST_SESSION + 1; -const LAST_HANDSHAKE: usize = FIRST_HANDSHAKE + MAX_HANDSHAKES - 1; -const USER_TIMER: usize = LAST_HANDSHAKE + 256; +const USER_TIMER: usize = LAST_SESSION + 256; +const SYS_TIMER: usize = LAST_SESSION + 1; /// Protocol handler level packet id pub type PacketId = u8; @@ -306,7 +304,6 @@ impl HostInfo { } type SharedSession = Arc>; -type SharedHandshake = Arc>; #[derive(Copy, Clone)] struct ProtocolTimer { @@ -318,7 +315,6 @@ struct ProtocolTimer { pub struct Host where Message: Send + Sync + Clone { pub info: RwLock, tcp_listener: Mutex, - handshakes: Arc>>, sessions: Arc>>, discovery: Mutex>, nodes: RwLock, @@ -327,6 +323,7 @@ pub struct Host where Message: Send + Sync + Clone { timer_counter: RwLock, stats: Arc, pinned_nodes: Vec, + num_sessions: AtomicUsize, } impl Host where Message: Send + Sync + Clone { @@ -370,7 +367,6 @@ impl Host where Message: Send + Sync + Clone { }), discovery: Mutex::new(None), tcp_listener: Mutex::new(tcp_listener), - handshakes: Arc::new(RwLock::new(Slab::new_starting_at(FIRST_HANDSHAKE, MAX_HANDSHAKES))), sessions: Arc::new(RwLock::new(Slab::new_starting_at(FIRST_SESSION, MAX_SESSIONS))), nodes: RwLock::new(NodeTable::new(path)), handlers: RwLock::new(HashMap::new()), @@ -378,6 +374,7 @@ impl Host where Message: Send + Sync + Clone { timer_counter: RwLock::new(USER_TIMER), stats: Arc::new(NetworkStats::default()), pinned_nodes: Vec::new(), + num_sessions: AtomicUsize::new(0), }; let boot_nodes = host.info.read().unwrap().config.boot_nodes.clone(); @@ -477,19 +474,19 @@ impl Host where Message: Send + Sync + Clone { } fn have_session(&self, id: &NodeId) -> bool { - self.sessions.read().unwrap().iter().any(|e| e.lock().unwrap().info.id.eq(&id)) + self.sessions.read().unwrap().iter().any(|e| e.lock().unwrap().info.id == Some(id.clone())) } fn session_count(&self) -> usize { - self.sessions.read().unwrap().count() + self.num_sessions.load(AtomicOrdering::Relaxed) } fn connecting_to(&self, id: &NodeId) -> bool { - self.handshakes.read().unwrap().iter().any(|e| e.lock().unwrap().id.eq(&id)) + self.sessions.read().unwrap().iter().any(|e| e.lock().unwrap().id() == Some(id)) } fn handshake_count(&self) -> usize { - self.handshakes.read().unwrap().count() + self.sessions.read().unwrap().count() - self.session_count() } fn keep_alive(&self, io: &IoContext>) { @@ -565,21 +562,31 @@ impl Host where Message: Send + Sync + Clone { } } }; - self.create_connection(socket, Some(id), io); + if let Err(e) = self.create_connection(socket, Some(id), io) { + debug!(target: "network", "Can't create connection: {:?}", e); + } } #[cfg_attr(feature="dev", allow(block_in_if_condition_stmt))] - fn create_connection(&self, socket: TcpStream, id: Option<&NodeId>, io: &IoContext>) { + fn create_connection(&self, socket: TcpStream, id: Option<&NodeId>, io: &IoContext>) -> Result<(), UtilError> { let nonce = self.info.write().unwrap().next_nonce(); - let mut handshakes = self.handshakes.write().unwrap(); - if handshakes.insert_with(|token| { - let mut handshake = Handshake::new(token, id, socket, &nonce, self.stats.clone()).expect("Can't create handshake"); - handshake.start(io, &self.info.read().unwrap(), id.is_some()).and_then(|_| io.register_stream(token)).unwrap_or_else (|e| { - debug!(target: "network", "Handshake create error: {:?}", e); - }); - Arc::new(Mutex::new(handshake)) - }).is_none() { - debug!(target: "network", "Max handshakes reached"); + let mut sessions = self.sessions.write().unwrap(); + let token = sessions.insert_with_opt(|token| { + match Session::new(io, socket, token, id, &nonce, self.stats.clone(), &self.info.read().unwrap()) { + Ok(s) => Some(Arc::new(Mutex::new(s))), + Err(e) => { + debug!(target: "network", "Session create error: {:?}", e); + None + } + } + }); + + match token { + Some(t) => io.register_stream(t), + None => { + debug!(target: "network", "Max sessions reached"); + Ok(()) + } } } @@ -594,19 +601,11 @@ impl Host where Message: Send + Sync + Clone { break }, }; - self.create_connection(socket, None, io); - } - io.update_registration(TCP_ACCEPT).expect("Error registering TCP listener"); - } - - fn handshake_writable(&self, token: StreamToken, io: &IoContext>) { - let handshake = { self.handshakes.read().unwrap().get(token).cloned() }; - if let Some(handshake) = handshake { - let mut h = handshake.lock().unwrap(); - if let Err(e) = h.writable(io, &self.info.read().unwrap()) { - trace!(target: "network", "Handshake write error: {}: {:?}", token, e); + if let Err(e) = self.create_connection(socket, None, io) { + debug!(target: "network", "Can't accept connection: {:?}", e); } } + io.update_registration(TCP_ACCEPT).expect("Error registering TCP listener"); } fn session_writable(&self, token: StreamToken, io: &IoContext>) { @@ -629,30 +628,6 @@ impl Host where Message: Send + Sync + Clone { self.kill_connection(token, io, true); } - fn handshake_readable(&self, token: StreamToken, io: &IoContext>) { - let mut create_session = false; - let mut kill = false; - let handshake = { self.handshakes.read().unwrap().get(token).cloned() }; - if let Some(handshake) = handshake { - let mut h = handshake.lock().unwrap(); - if let Err(e) = h.readable(io, &self.info.read().unwrap()) { - debug!(target: "network", "Handshake read error: {}: {:?}", token, e); - kill = true; - } - if h.done() { - create_session = true; - } - } - if kill { - self.kill_connection(token, io, true); - return; - } else if create_session { - self.start_session(token, io); - return; - } - io.update_registration(token).unwrap_or_else(|e| debug!(target: "network", "Token registration error: {:?}", e)); - } - fn session_readable(&self, token: StreamToken, io: &IoContext>) { let mut ready_data: Vec = Vec::new(); let mut packet_data: Option<(ProtocolId, PacketId, Vec)> = None; @@ -662,17 +637,37 @@ impl Host where Message: Send + Sync + Clone { let mut s = session.lock().unwrap(); match s.readable(io, &self.info.read().unwrap()) { Err(e) => { - trace!(target: "network", "Session read error: {}:{} ({:?}) {:?}", token, s.id(), s.remote_addr(), e); + trace!(target: "network", "Session read error: {}:{:?} ({:?}) {:?}", token, s.id(), s.remote_addr(), e); match e { UtilError::Network(NetworkError::Disconnect(DisconnectReason::UselessPeer)) | UtilError::Network(NetworkError::Disconnect(DisconnectReason::IncompatibleProtocol)) => { - self.nodes.write().unwrap().mark_as_useless(s.id()); + if let Some(id) = s.id() { + self.nodes.write().unwrap().mark_as_useless(id); + } } _ => (), } kill = true; }, Ok(SessionData::Ready) => { + if !s.info.originated { + let session_count = self.session_count(); + let ideal_peers = { self.info.read().unwrap().deref().config.ideal_peers }; + if session_count >= ideal_peers as usize { + s.disconnect(DisconnectReason::TooManyPeers); + return; + } + // Add it no node table + if let Ok(address) = s.remote_addr() { + let entry = NodeEntry { id: s.id().unwrap().clone(), endpoint: NodeEndpoint { address: address, udp_port: address.port() } }; + self.nodes.write().unwrap().add_node(Node::new(entry.id.clone(), entry.endpoint.clone())); + let mut discovery = self.discovery.lock().unwrap(); + if let Some(ref mut discovery) = *discovery.deref_mut() { + discovery.add_node(entry); + } + } + } + self.num_sessions.fetch_add(1, AtomicOrdering::SeqCst); for (p, _) in self.handlers.read().unwrap().iter() { if s.have_capability(p) { ready_data.push(p); @@ -697,6 +692,7 @@ impl Host where Message: Send + Sync + Clone { } for p in ready_data { let h = self.handlers.read().unwrap().get(p).unwrap().clone(); + self.stats.inc_sessions(); h.connected(&NetworkContext::new(io, p, session.clone(), self.sessions.clone()), &token); } if let Some((p, packet_id, data)) = packet_data { @@ -706,59 +702,6 @@ impl Host where Message: Send + Sync + Clone { io.update_registration(token).unwrap_or_else(|e| debug!(target: "network", "Token registration error: {:?}", e)); } - fn start_session(&self, token: StreamToken, io: &IoContext>) { - let mut handshakes = self.handshakes.write().unwrap(); - if handshakes.get(token).is_none() { - return; - } - - // turn a handshake into a session - let mut sessions = self.sessions.write().unwrap(); - let mut h = handshakes.get_mut(token).unwrap().lock().unwrap(); - if h.expired { - return; - } - io.deregister_stream(token).expect("Error deleting handshake registration"); - h.set_expired(); - let originated = h.originated; - let mut session = match Session::new(&mut h, &self.info.read().unwrap()) { - Ok(s) => s, - Err(e) => { - debug!(target: "network", "Session creation error: {:?}", e); - return; - } - }; - if !originated { - let session_count = sessions.count(); - let ideal_peers = { self.info.read().unwrap().deref().config.ideal_peers }; - if session_count >= ideal_peers as usize { - session.disconnect(DisconnectReason::TooManyPeers); - return; - } - } - let result = sessions.insert_with(move |session_token| { - session.set_token(session_token); - io.register_stream(session_token).expect("Error creating session registration"); - self.stats.inc_sessions(); - trace!(target: "network", "Creating session {} -> {}:{} ({:?})", token, session_token, session.id(), session.remote_addr()); - if !originated { - // Add it no node table - if let Ok(address) = session.remote_addr() { - let entry = NodeEntry { id: session.id().clone(), endpoint: NodeEndpoint { address: address, udp_port: address.port() } }; - self.nodes.write().unwrap().add_node(Node::new(entry.id.clone(), entry.endpoint.clone())); - let mut discovery = self.discovery.lock().unwrap(); - if let Some(ref mut discovery) = *discovery.deref_mut() { - discovery.add_node(entry); - } - } - } - Arc::new(Mutex::new(session)) - }); - if result.is_none() { - warn!("Max sessions reached"); - } - } - fn connection_timeout(&self, token: StreamToken, io: &IoContext>) { trace!(target: "network", "Connection timeout: {}", token); self.kill_connection(token, io, true) @@ -770,17 +713,6 @@ impl Host where Message: Send + Sync + Clone { let mut deregister = false; let mut expired_session = None; match token { - FIRST_HANDSHAKE ... LAST_HANDSHAKE => { - let handshakes = self.handshakes.write().unwrap(); - if let Some(handshake) = handshakes.get(token).cloned() { - let mut handshake = handshake.lock().unwrap(); - if !handshake.expired() { - handshake.set_expired(); - failure_id = Some(handshake.id().clone()); - deregister = true; - } - } - }, FIRST_SESSION ... LAST_SESSION => { let sessions = self.sessions.write().unwrap(); if let Some(session) = sessions.get(token).cloned() { @@ -790,12 +722,13 @@ impl Host where Message: Send + Sync + Clone { if s.is_ready() { for (p, _) in self.handlers.read().unwrap().iter() { if s.have_capability(p) { + self.num_sessions.fetch_sub(1, AtomicOrdering::SeqCst); to_disconnect.push(p); } } } s.set_expired(); - failure_id = Some(s.id().clone()); + failure_id = s.id().cloned(); } deregister = remote || s.done(); } @@ -821,20 +754,11 @@ impl Host where Message: Send + Sync + Clone { fn update_nodes(&self, io: &IoContext>, node_changes: TableUpdates) { let mut to_remove: Vec = Vec::new(); { - { - let handshakes = self.handshakes.write().unwrap(); - for c in handshakes.iter() { - let h = c.lock().unwrap(); - if node_changes.removed.contains(&h.id()) { - to_remove.push(h.token()); - } - } - } - { - let sessions = self.sessions.write().unwrap(); - for c in sessions.iter() { - let s = c.lock().unwrap(); - if node_changes.removed.contains(&s.id()) { + let sessions = self.sessions.write().unwrap(); + for c in sessions.iter() { + let s = c.lock().unwrap(); + if let Some(id) = s.id() { + if node_changes.removed.contains(id) { to_remove.push(s.token()); } } @@ -860,7 +784,6 @@ impl IoHandler> for Host where Messa trace!(target: "network", "Hup: {}", stream); match stream { FIRST_SESSION ... LAST_SESSION => self.connection_closed(stream, io), - FIRST_HANDSHAKE ... LAST_HANDSHAKE => self.connection_closed(stream, io), _ => warn!(target: "network", "Unexpected hup"), }; } @@ -868,7 +791,6 @@ impl IoHandler> for Host where Messa fn stream_readable(&self, io: &IoContext>, stream: StreamToken) { match stream { FIRST_SESSION ... LAST_SESSION => self.session_readable(stream, io), - FIRST_HANDSHAKE ... LAST_HANDSHAKE => self.handshake_readable(stream, io), DISCOVERY => { let node_changes = { self.discovery.lock().unwrap().as_mut().unwrap().readable() }; if let Some(node_changes) = node_changes { @@ -884,7 +806,6 @@ impl IoHandler> for Host where Messa fn stream_writable(&self, io: &IoContext>, stream: StreamToken) { match stream { FIRST_SESSION ... LAST_SESSION => self.session_writable(stream, io), - FIRST_HANDSHAKE ... LAST_HANDSHAKE => self.handshake_writable(stream, io), DISCOVERY => { self.discovery.lock().unwrap().as_mut().unwrap().writable(); io.update_registration(DISCOVERY).expect("Error updating discovery registration"); @@ -899,7 +820,6 @@ impl IoHandler> for Host where Messa INIT_PUBLIC => self.init_public_interface(io).unwrap_or_else(|e| warn!("Error initializing public interface: {:?}", e)), FIRST_SESSION ... LAST_SESSION => self.connection_timeout(token, io), - FIRST_HANDSHAKE ... LAST_HANDSHAKE => self.connection_timeout(token, io), DISCOVERY_REFRESH => { self.discovery.lock().unwrap().as_mut().unwrap().refresh(); io.update_registration(DISCOVERY).expect("Error updating discovery registration"); @@ -966,7 +886,9 @@ impl IoHandler> for Host where Messa let session = { self.sessions.read().unwrap().get(*peer).cloned() }; if let Some(session) = session { session.lock().unwrap().disconnect(DisconnectReason::DisconnectRequested); - self.nodes.write().unwrap().mark_as_useless(session.lock().unwrap().id()); + if let Some(id) = session.lock().unwrap().id() { + self.nodes.write().unwrap().mark_as_useless(id) + } } trace!(target: "network", "Disabling peer {}", peer); self.kill_connection(*peer, io, false); @@ -987,12 +909,6 @@ impl IoHandler> for Host where Messa session.lock().unwrap().register_socket(reg, event_loop).expect("Error registering socket"); } } - FIRST_HANDSHAKE ... LAST_HANDSHAKE => { - let connection = { self.handshakes.read().unwrap().get(stream).cloned() }; - if let Some(connection) = connection { - connection.lock().unwrap().register_socket(reg, event_loop).expect("Error registering socket"); - } - } DISCOVERY => self.discovery.lock().unwrap().as_ref().unwrap().register_socket(event_loop).expect("Error registering discovery socket"), TCP_ACCEPT => event_loop.register(self.tcp_listener.lock().unwrap().deref(), Token(TCP_ACCEPT), EventSet::all(), PollOpt::edge()).expect("Error registering stream"), _ => warn!("Unexpected stream registration") @@ -1008,13 +924,6 @@ impl IoHandler> for Host where Messa connections.remove(stream); } } - FIRST_HANDSHAKE ... LAST_HANDSHAKE => { - let mut connections = self.handshakes.write().unwrap(); - if let Some(connection) = connections.get(stream).cloned() { - connection.lock().unwrap().deregister_socket(event_loop).expect("Error deregistering socket"); - connections.remove(stream); - } - } DISCOVERY => (), _ => warn!("Unexpected stream deregistration") } @@ -1028,12 +937,6 @@ impl IoHandler> for Host where Messa connection.lock().unwrap().update_socket(reg, event_loop).expect("Error updating socket"); } } - FIRST_HANDSHAKE ... LAST_HANDSHAKE => { - let connection = { self.handshakes.read().unwrap().get(stream).cloned() }; - if let Some(connection) = connection { - connection.lock().unwrap().update_socket(reg, event_loop).expect("Error updating socket"); - } - } DISCOVERY => self.discovery.lock().unwrap().as_ref().unwrap().update_registration(event_loop).expect("Error reregistering discovery socket"), TCP_ACCEPT => event_loop.reregister(self.tcp_listener.lock().unwrap().deref(), Token(TCP_ACCEPT), EventSet::all(), PollOpt::edge()).expect("Error reregistering stream"), _ => warn!("Unexpected stream update") diff --git a/util/src/network/session.rs b/util/src/network/session.rs index 6c0a20a14..7b7f16c18 100644 --- a/util/src/network/session.rs +++ b/util/src/network/session.rs @@ -16,15 +16,19 @@ use std::net::SocketAddr; use std::io; +use std::sync::*; use mio::*; +use mio::tcp::*; use rlp::*; -use network::connection::{EncryptedConnection, Packet}; +use hash::*; +use network::connection::{EncryptedConnection, Packet, Connection}; use network::handshake::Handshake; use error::*; use io::{IoContext, StreamToken}; use network::error::{NetworkError, DisconnectReason}; use network::host::*; use network::node_table::NodeId; +use network::stats::NetworkStats; use time; const PING_TIMEOUT_SEC: u64 = 30; @@ -36,14 +40,18 @@ const PING_INTERVAL_SEC: u64 = 30; pub struct Session { /// Shared session information pub info: SessionInfo, - /// Underlying connection - connection: EncryptedConnection, /// Session ready flag. Set after successfull Hello packet exchange had_hello: bool, /// Session is no longer active flag. expired: bool, ping_time_ns: u64, pong_time_ns: Option, + state: State, +} + +enum State { + Handshake(Handshake), + Session(EncryptedConnection), } /// Structure used to report various session events. @@ -65,7 +73,7 @@ pub enum SessionData { /// Shared session information pub struct SessionInfo { /// Peer public key - pub id: NodeId, + pub id: Option, /// Peer client ID pub client_version: String, /// Peer RLPx protocol version @@ -74,6 +82,8 @@ pub struct SessionInfo { capabilities: Vec, /// Peer ping delay in milliseconds pub ping_ms: Option, + /// True if this session was originated by us. + pub originated: bool, } #[derive(Debug, PartialEq, Eq)] @@ -112,31 +122,52 @@ const PACKET_LAST: u8 = 0x7f; impl Session { /// Create a new session out of comepleted handshake. This clones the handshake connection object /// and leaves the handhsake in limbo to be deregistered from the event loop. - pub fn new(h: &mut Handshake, host: &HostInfo) -> Result { - let id = h.id.clone(); - let connection = try!(EncryptedConnection::new(h)); - let mut session = Session { - connection: connection, + pub fn new(io: &IoContext, socket: TcpStream, token: StreamToken, id: Option<&NodeId>, + nonce: &H256, stats: Arc, host: &HostInfo) -> Result + where Message: Send + Clone { + let originated = id.is_some(); + let mut handshake = Handshake::new(token, id, socket, &nonce, stats).expect("Can't create handshake"); + try!(handshake.start(io, host, originated)); + Ok(Session { + state: State::Handshake(handshake), had_hello: false, info: SessionInfo { - id: id, + id: id.cloned(), client_version: String::new(), protocol_version: 0, capabilities: Vec::new(), ping_ms: None, + originated: originated, }, ping_time_ns: 0, pong_time_ns: None, expired: false, + }) + } + + fn complete_handshake(&mut self, host: &HostInfo) -> Result<(), UtilError> { + let connection = if let State::Handshake(ref mut h) = self.state { + self.info.id = Some(h.id.clone()); + try!(EncryptedConnection::new(h)) + } else { + panic!("Unexpected state"); }; - try!(session.write_hello(host)); - try!(session.send_ping()); - Ok(session) + self.state = State::Session(connection); + try!(self.write_hello(host)); + try!(self.send_ping()); + Ok(()) + } + + fn connection(&self) -> &Connection { + match self.state { + State::Handshake(ref h) => &h.connection, + State::Session(ref s) => &s.connection, + } } /// Get id of the remote peer - pub fn id(&self) -> &NodeId { - &self.info.id + pub fn id(&self) -> Option<&NodeId> { + self.info.id.as_ref() } /// Check if session is ready to send/receive data @@ -151,21 +182,20 @@ impl Session { /// Check if this session is expired. pub fn expired(&self) -> bool { - self.expired + match self.state { + State::Handshake(ref h) => h.expired(), + _ => self.expired, + } } /// Check if this session is over and there is nothing to be sent. pub fn done(&self) -> bool { - self.expired() && !self.connection.is_sending() - } - /// Replace socket token - pub fn set_token(&mut self, token: StreamToken) { - self.connection.set_token(token); + self.expired() && !self.connection().is_sending() } /// Get remote peer address pub fn remote_addr(&self) -> io::Result { - self.connection.remote_addr() + self.connection().remote_addr() } /// Readable IO handler. Returns packet data if available. @@ -173,15 +203,37 @@ impl Session { if self.expired() { return Ok(SessionData::None) } - match try!(self.connection.readable(io)) { - Some(data) => Ok(try!(self.read_packet(data, host))), - None => Ok(SessionData::None) + let mut create_session = false; + let mut packet_data = None; + match self.state { + State::Handshake(ref mut h) => { + try!(h.readable(io, host)); + if h.done() { + create_session = true; + } + } + State::Session(ref mut c) => { + match try!(c.readable(io)) { + data @ Some(_) => packet_data = data, + None => return Ok(SessionData::None) + } + } } + if let Some(data) = packet_data { + return Ok(try!(self.read_packet(data, host))); + } + if create_session { + try!(self.complete_handshake(host)); + } + Ok(SessionData::None) } /// Writable IO handler. Sends pending packets. pub fn writable(&mut self, io: &IoContext, _host: &HostInfo) -> Result<(), UtilError> where Message: Send + Sync + Clone { - self.connection.writable(io) + match self.state { + State::Handshake(ref mut h) => h.writable(io), + State::Session(ref mut s) => s.writable(io), + } } /// Checks if peer supports given capability @@ -194,18 +246,20 @@ impl Session { if self.expired() { return Ok(()); } - try!(self.connection.register_socket(reg, event_loop)); + try!(self.connection().register_socket(reg, event_loop)); Ok(()) } /// Update registration with the event loop. Should be called at the end of the IO handler. pub fn update_socket(&self, reg:Token, event_loop: &mut EventLoop) -> Result<(), UtilError> { - self.connection.update_socket(reg, event_loop) + try!(self.connection().update_socket(reg, event_loop)); + Ok(()) } /// Delete registration pub fn deregister_socket(&self, event_loop: &mut EventLoop) -> Result<(), UtilError> { - self.connection.deregister_socket(event_loop) + try!(self.connection().deregister_socket(event_loop)); + Ok(()) } /// Send a protocol packet to peer. @@ -221,7 +275,7 @@ impl Session { while protocol != self.info.capabilities[i].protocol { i += 1; if i == self.info.capabilities.len() { - debug!(target: "net", "Unknown protocol: {:?}", protocol); + debug!(target: "network", "Unknown protocol: {:?}", protocol); return Ok(()) } } @@ -229,11 +283,14 @@ impl Session { let mut rlp = RlpStream::new(); rlp.append(&(pid as u32)); rlp.append_raw(data, 1); - self.connection.send_packet(&rlp.out()) + self.send(rlp) } /// Keep this session alive. Returns false if ping timeout happened pub fn keep_alive(&mut self, io: &IoContext) -> bool where Message: Send + Sync + Clone { + if let State::Handshake(_) = self.state { + return true; + } let timed_out = if let Some(pong) = self.pong_time_ns { pong - self.ping_time_ns > PING_TIMEOUT_SEC * 1000_000_000 } else { @@ -244,13 +301,13 @@ impl Session { if let Err(e) = self.send_ping() { debug!("Error sending ping message: {:?}", e); } - io.update_registration(self.token()).unwrap_or_else(|e| debug!(target: "net", "Session registration error: {:?}", e)); + io.update_registration(self.token()).unwrap_or_else(|e| debug!(target: "network", "Session registration error: {:?}", e)); } !timed_out } pub fn token(&self) -> StreamToken { - self.connection.token() + self.connection().token() } fn read_packet(&mut self, packet: Packet, host: &HostInfo) -> Result { @@ -288,7 +345,7 @@ impl Session { while packet_id < self.info.capabilities[i].id_offset { i += 1; if i == self.info.capabilities.len() { - debug!(target: "net", "Unknown packet: {:?}", packet_id); + debug!(target: "network", "Unknown packet: {:?}", packet_id); return Ok(SessionData::None) } } @@ -299,7 +356,7 @@ impl Session { Ok(SessionData::Packet { data: packet.data, protocol: protocol, packet_id: pid } ) }, _ => { - debug!(target: "net", "Unknown packet: {:?}", packet_id); + debug!(target: "network", "Unknown packet: {:?}", packet_id); Ok(SessionData::None) } } @@ -314,7 +371,7 @@ impl Session { .append(&host.capabilities) .append(&host.local_endpoint.address.port()) .append(host.id()); - self.connection.send_packet(&rlp.out()) + self.send(rlp) } fn read_hello(&mut self, rlp: &UntrustedRlp, host: &HostInfo) -> Result<(), UtilError> { @@ -384,11 +441,13 @@ impl Session { /// Disconnect this session pub fn disconnect(&mut self, reason: DisconnectReason) -> NetworkError { - let mut rlp = RlpStream::new(); - rlp.append(&(PACKET_DISCONNECT as u32)); - rlp.begin_list(1); - rlp.append(&(reason as u32)); - self.connection.send_packet(&rlp.out()).ok(); + if let State::Session(_) = self.state { + let mut rlp = RlpStream::new(); + rlp.append(&(PACKET_DISCONNECT as u32)); + rlp.begin_list(1); + rlp.append(&(reason as u32)); + self.send(rlp).ok(); + } NetworkError::Disconnect(reason) } @@ -400,7 +459,15 @@ impl Session { } fn send(&mut self, rlp: RlpStream) -> Result<(), UtilError> { - self.connection.send_packet(&rlp.out()) + match self.state { + State::Handshake(_) => { + warn!(target:"network", "Unexpected send request"); + }, + State::Session(ref mut s) => { + try!(s.send_packet(&rlp.out())) + }, + } + Ok(()) } }