diff --git a/util/src/network/connection.rs b/util/src/network/connection.rs index 44d429164..242e8935e 100644 --- a/util/src/network/connection.rs +++ b/util/src/network/connection.rs @@ -164,6 +164,11 @@ impl Connection { self.token } + /// Replace socket token + pub fn set_token(&mut self, token: StreamToken) { + self.token = token; + } + /// Register this connection with the IO event loop. pub fn register_socket(&self, reg: Token, event_loop: &mut EventLoop) -> io::Result<()> { trace!(target: "net", "connection register; token={:?}", reg); @@ -243,6 +248,11 @@ impl EncryptedConnection { self.connection.token } + /// Replace socket token + pub fn set_token(&mut self, token: StreamToken) { + self.connection.set_token(token); + } + /// Create an encrypted connection out of the handshake. Consumes a handshake object. pub fn new(mut handshake: Handshake) -> Result { let shared = try!(crypto::ecdh::agree(handshake.ecdhe.secret(), &handshake.remote_public)); diff --git a/util/src/network/host.rs b/util/src/network/host.rs index 5c08ad5c8..04dda02bd 100644 --- a/util/src/network/host.rs +++ b/util/src/network/host.rs @@ -20,6 +20,7 @@ use std::hash::{Hasher}; use std::str::{FromStr}; use std::sync::*; use std::ops::*; +use std::cmp::min; use mio::*; use mio::tcp::*; use target_info::Target; @@ -41,9 +42,10 @@ use network::discovery::{Discovery, TableUpdates, NodeEntry}; type Slab = ::slab::Slab; const _DEFAULT_PORT: u16 = 30304; -const MAX_CONNECTIONS: usize = 1024; +const MAX_SESSIONS: usize = 1024; +const MAX_HANDSHAKES: usize = 256; +const MAX_HANDSHAKES_PER_ROUND: usize = 64; const MAINTENANCE_TIMEOUT: u64 = 1000; -const MAX_HANDSHAKES: usize = 100; #[derive(Debug)] /// Network service configuration @@ -132,13 +134,16 @@ impl NetworkConfiguration { } // Tokens -const TCP_ACCEPT: usize = MAX_CONNECTIONS + 1; -const IDLE: usize = MAX_CONNECTIONS + 2; -const DISCOVERY: usize = MAX_CONNECTIONS + 3; -const DISCOVERY_REFRESH: usize = MAX_CONNECTIONS + 4; -const DISCOVERY_ROUND: usize = MAX_CONNECTIONS + 5; -const FIRST_CONNECTION: usize = 0; -const LAST_CONNECTION: usize = FIRST_CONNECTION + MAX_CONNECTIONS - 1; +const TCP_ACCEPT: usize = LAST_HANDSHAKE + 1; +const IDLE: usize = LAST_HANDSHAKE + 2; +const DISCOVERY: usize = LAST_HANDSHAKE + 3; +const DISCOVERY_REFRESH: usize = LAST_HANDSHAKE + 4; +const DISCOVERY_ROUND: usize = LAST_HANDSHAKE + 5; +const FIRST_SESSION: usize = 0; +const LAST_SESSION: usize = FIRST_SESSION + MAX_SESSIONS - 1; +const FIRST_HANDSHAKE: usize = LAST_SESSION + 1; +const LAST_HANDSHAKE: usize = FIRST_HANDSHAKE + MAX_HANDSHAKES - 1; +const USER_TIMER: usize = LAST_HANDSHAKE + 256; /// Protocol handler level packet id pub type PacketId = u8; @@ -196,7 +201,7 @@ impl Encodable for CapabilityInfo { pub struct NetworkContext<'s, Message> where Message: Send + Sync + Clone + 'static, 's { io: &'s IoContext>, protocol: ProtocolId, - connections: Arc>>, + sessions: Arc>>, session: Option, } @@ -204,28 +209,23 @@ impl<'s, Message> NetworkContext<'s, Message> where Message: Send + Sync + Clone /// Create a new network IO access point. Takes references to all the data that can be updated within the IO handler. fn new(io: &'s IoContext>, protocol: ProtocolId, - session: Option, connections: Arc>>) -> NetworkContext<'s, Message> { + session: Option, sessions: Arc>>) -> NetworkContext<'s, Message> { NetworkContext { io: io, protocol: protocol, session: session, - connections: connections, + sessions: sessions, } } /// Send a packet over the network to another peer. pub fn send(&self, peer: PeerId, packet_id: PacketId, data: Vec) -> Result<(), UtilError> { - let connection = { self.connections.read().unwrap().get(peer).cloned() }; - if let Some(connection) = connection { - match *connection.lock().unwrap().deref_mut() { - ConnectionEntry::Session(ref mut s) => { - s.send_packet(self.protocol, packet_id as u8, &data).unwrap_or_else(|e| { + let session = { self.sessions.read().unwrap().get(peer).cloned() }; + if let Some(session) = session { + session.lock().unwrap().deref_mut().send_packet(self.protocol, packet_id as u8, &data).unwrap_or_else(|e| { warn!(target: "net", "Send error: {:?}", e); }); //TODO: don't copy vector data - try!(self.io.update_registration(peer)); - }, - _ => warn!(target: "net", "Send: Peer is not connected yet") - } + try!(self.io.update_registration(peer)); } else { trace!(target: "net", "Send: Peer no longer exist") } @@ -265,11 +265,9 @@ impl<'s, Message> NetworkContext<'s, Message> where Message: Send + Sync + Clone /// Returns peer identification string pub fn peer_info(&self, peer: PeerId) -> String { - let connection = { self.connections.read().unwrap().get(peer).cloned() }; - if let Some(connection) = connection { - if let ConnectionEntry::Session(ref s) = *connection.lock().unwrap().deref() { - return s.info.client_version.clone() - } + let session = { self.sessions.read().unwrap().get(peer).cloned() }; + if let Some(session) = session { + return session.lock().unwrap().info.client_version.clone() } "unknown".to_owned() } @@ -311,12 +309,8 @@ impl HostInfo { } } -enum ConnectionEntry { - Handshake(Handshake), - Session(Session) -} - -type SharedConnectionEntry = Arc>; +type SharedSession = Arc>; +type SharedHandshake = Arc>; #[derive(Copy, Clone)] struct ProtocolTimer { @@ -328,7 +322,8 @@ struct ProtocolTimer { pub struct Host where Message: Send + Sync + Clone { pub info: RwLock, tcp_listener: Mutex, - connections: Arc>>, + handshakes: Arc>>, + sessions: Arc>>, discovery: Mutex, nodes: RwLock, handlers: RwLock>>>, @@ -361,11 +356,12 @@ impl Host where Message: Send + Sync + Clone { }), discovery: Mutex::new(discovery), tcp_listener: Mutex::new(tcp_listener), - connections: Arc::new(RwLock::new(Slab::new_starting_at(FIRST_CONNECTION, MAX_CONNECTIONS))), + handshakes: Arc::new(RwLock::new(Slab::new_starting_at(FIRST_HANDSHAKE, MAX_HANDSHAKES))), + sessions: Arc::new(RwLock::new(Slab::new_starting_at(FIRST_SESSION, MAX_SESSIONS))), nodes: RwLock::new(NodeTable::new(path)), handlers: RwLock::new(HashMap::new()), timers: RwLock::new(HashMap::new()), - timer_counter: RwLock::new(LAST_CONNECTION + 1), + timer_counter: RwLock::new(USER_TIMER), stats: Arc::new(NetworkStats::default()), }; let port = host.info.read().unwrap().config.listen_address.port(); @@ -407,33 +403,28 @@ impl Host where Message: Send + Sync + Clone { } fn have_session(&self, id: &NodeId) -> bool { - self.connections.read().unwrap().iter().any(|e| - match *e.lock().unwrap().deref() { ConnectionEntry::Session(ref s) => s.info.id.eq(&id), _ => false }) + self.sessions.read().unwrap().iter().any(|e| e.lock().unwrap().info.id.eq(&id)) } fn session_count(&self) -> usize { - self.connections.read().unwrap().iter().filter(|e| - match *e.lock().unwrap().deref() { ConnectionEntry::Session(_) => true, _ => false }).count() + self.sessions.read().unwrap().count() } fn connecting_to(&self, id: &NodeId) -> bool { - self.connections.read().unwrap().iter().any(|e| - match *e.lock().unwrap().deref() { ConnectionEntry::Handshake(ref h) => h.id.eq(&id), _ => false }) + self.handshakes.read().unwrap().iter().any(|e| e.lock().unwrap().id.eq(&id)) } fn handshake_count(&self) -> usize { - self.connections.read().unwrap().iter().filter(|e| - match *e.lock().unwrap().deref() { ConnectionEntry::Handshake(_) => true, _ => false }).count() + self.handshakes.read().unwrap().count() } fn keep_alive(&self, io: &IoContext>) { let mut to_kill = Vec::new(); - for e in self.connections.write().unwrap().iter_mut() { - if let ConnectionEntry::Session(ref mut s) = *e.lock().unwrap().deref_mut() { - if !s.keep_alive(io) { - s.disconnect(DisconnectReason::PingTimeout); - to_kill.push(s.token()); - } + for e in self.sessions.write().unwrap().iter_mut() { + let mut s = e.lock().unwrap(); + if !s.keep_alive(io) { + s.disconnect(DisconnectReason::PingTimeout); + to_kill.push(s.token()); } } for p in to_kill { @@ -443,8 +434,8 @@ impl Host where Message: Send + Sync + Clone { fn connect_peers(&self, io: &IoContext>) { let ideal_peers = { self.info.read().unwrap().deref().config.ideal_peers }; - let connections = self.session_count(); - if connections >= ideal_peers as usize { + let session_count = self.session_count(); + if session_count >= ideal_peers as usize { return; } @@ -453,10 +444,9 @@ impl Host where Message: Send + Sync + Clone { return; } - let nodes = { self.nodes.read().unwrap().nodes() }; - - for id in nodes.iter().filter(|ref id| !self.have_session(id) && !self.connecting_to(id)).take(MAX_HANDSHAKES - handshake_count) { + for id in nodes.iter().filter(|ref id| !self.have_session(id) && !self.connecting_to(id)) + .take(min(MAX_HANDSHAKES_PER_ROUND, MAX_HANDSHAKES - handshake_count)) { self.connect_peer(&id, io); } debug!(target: "net", "Connecting peers: {} sessions, {} pending", self.session_count(), self.handshake_count()); @@ -495,15 +485,15 @@ impl Host where Message: Send + Sync + Clone { #[allow(block_in_if_condition_stmt)] fn create_connection(&self, socket: TcpStream, id: Option<&NodeId>, io: &IoContext>) { let nonce = self.info.write().unwrap().next_nonce(); - let mut connections = self.connections.write().unwrap(); - if connections.insert_with(|token| { + let mut handshakes = self.handshakes.write().unwrap(); + if handshakes.insert_with(|token| { let mut handshake = Handshake::new(token, id, socket, &nonce, self.stats.clone()).expect("Can't create handshake"); handshake.start(io, &self.info.read().unwrap(), id.is_some()).and_then(|_| io.register_stream(token)).unwrap_or_else (|e| { debug!(target: "net", "Handshake create error: {:?}", e); }); - Arc::new(Mutex::new(ConnectionEntry::Handshake(handshake))) + Arc::new(Mutex::new(handshake)) }).is_none() { - warn!("Max connections reached"); + warn!("Max handshakes reached"); } } @@ -523,35 +513,18 @@ impl Host where Message: Send + Sync + Clone { io.update_registration(TCP_ACCEPT).expect("Error registering TCP listener"); } - #[allow(single_match)] - fn connection_writable(&self, token: StreamToken, io: &IoContext>) { + fn handshake_writable(&self, token: StreamToken, io: &IoContext>) { let mut create_session = false; let mut kill = false; - let connection = { self.connections.read().unwrap().get(token).cloned() }; - if let Some(connection) = connection { - match *connection.lock().unwrap().deref_mut() { - ConnectionEntry::Handshake(ref mut h) => { - match h.writable(io, &self.info.read().unwrap()) { - Err(e) => { - debug!(target: "net", "Handshake write error: {}:{:?}", token, e); - kill = true; - }, - Ok(_) => () - } - if h.done() { - create_session = true; - } - }, - ConnectionEntry::Session(ref mut s) => { - match s.writable(io, &self.info.read().unwrap()) { - Err(e) => { - debug!(target: "net", "Session write error: {}:{:?}", token, e); - kill = true; - }, - Ok(_) => () - } - io.update_registration(token).unwrap_or_else(|e| debug!(target: "net", "Session registration error: {:?}", e)); - } + let handshake = { self.handshakes.read().unwrap().get(token).cloned() }; + if let Some(handshake) = handshake { + let mut h = handshake.lock().unwrap(); + if let Err(e) = h.writable(io, &self.info.read().unwrap()) { + debug!(target: "net", "Handshake write error: {}:{:?}", token, e); + kill = true; + } + if h.done() { + create_session = true; } } if kill { @@ -563,55 +536,40 @@ impl Host where Message: Send + Sync + Clone { } } + fn session_writable(&self, token: StreamToken, io: &IoContext>) { + let mut kill = false; + let session = { self.sessions.read().unwrap().get(token).cloned() }; + if let Some(session) = session { + let mut s = session.lock().unwrap(); + if let Err(e) = s.writable(io, &self.info.read().unwrap()) { + debug!(target: "net", "Session write error: {}:{:?}", token, e); + kill = true; + } + io.update_registration(token).unwrap_or_else(|e| debug!(target: "net", "Session registration error: {:?}", e)); + } + if kill { + self.kill_connection(token, io, true); //TODO: mark connection as dead an check in kill_connection + } + } + fn connection_closed(&self, token: TimerToken, io: &IoContext>) { self.kill_connection(token, io, true); } - fn connection_readable(&self, token: StreamToken, io: &IoContext>) { - let mut ready_data: Vec = Vec::new(); - let mut packet_data: Option<(ProtocolId, PacketId, Vec)> = None; + fn handshake_readable(&self, token: StreamToken, io: &IoContext>) { let mut create_session = false; let mut kill = false; - let connection = { self.connections.read().unwrap().get(token).cloned() }; - if let Some(connection) = connection { - match *connection.lock().unwrap().deref_mut() { - ConnectionEntry::Handshake(ref mut h) => { - if let Err(e) = h.readable(io, &self.info.read().unwrap()) { - debug!(target: "net", "Handshake read error: {}:{:?}", token, e); - kill = true; - } - if h.done() { - create_session = true; - } - }, - ConnectionEntry::Session(ref mut s) => { - match s.readable(io, &self.info.read().unwrap()) { - Err(e) => { - debug!(target: "net", "Handshake read error: {}:{:?}", token, e); - kill = true; - }, - Ok(SessionData::Ready) => { - for (p, _) in self.handlers.read().unwrap().iter() { - if s.have_capability(p) { - ready_data.push(p); - } - } - }, - Ok(SessionData::Packet { - data, - protocol, - packet_id, - }) => { - match self.handlers.read().unwrap().get(protocol) { - None => { warn!(target: "net", "No handler found for protocol: {:?}", protocol) }, - Some(_) => packet_data = Some((protocol, packet_id, data)), - } - }, - Ok(SessionData::None) => {}, - } - } + let handshake = { self.handshakes.read().unwrap().get(token).cloned() }; + if let Some(handshake) = handshake { + let mut h = handshake.lock().unwrap(); + if let Err(e) = h.readable(io, &self.info.read().unwrap()) { + debug!(target: "net", "Handshake read error: {}:{:?}", token, e); + kill = true; } - } + if h.done() { + create_session = true; + } + } if kill { self.kill_connection(token, io, true); //TODO: mark connection as dead an check in kill_connection return; @@ -619,40 +577,74 @@ impl Host where Message: Send + Sync + Clone { self.start_session(token, io); io.update_registration(token).unwrap_or_else(|e| debug!(target: "net", "Session registration error: {:?}", e)); } + io.update_registration(token).unwrap_or_else(|e| debug!(target: "net", "Token registration error: {:?}", e)); + } + + fn session_readable(&self, token: StreamToken, io: &IoContext>) { + let mut ready_data: Vec = Vec::new(); + let mut packet_data: Option<(ProtocolId, PacketId, Vec)> = None; + let mut kill = false; + let session = { self.sessions.read().unwrap().get(token).cloned() }; + if let Some(session) = session { + let mut s = session.lock().unwrap(); + match s.readable(io, &self.info.read().unwrap()) { + Err(e) => { + debug!(target: "net", "Session read error: {}:{:?}", token, e); + kill = true; + }, + Ok(SessionData::Ready) => { + for (p, _) in self.handlers.read().unwrap().iter() { + if s.have_capability(p) { + ready_data.push(p); + } + } + }, + Ok(SessionData::Packet { + data, + protocol, + packet_id, + }) => { + match self.handlers.read().unwrap().get(protocol) { + None => { warn!(target: "net", "No handler found for protocol: {:?}", protocol) }, + Some(_) => packet_data = Some((protocol, packet_id, data)), + } + }, + Ok(SessionData::None) => {}, + } + } + if kill { + self.kill_connection(token, io, true); //TODO: mark connection as dead an check in kill_connection + return; + } for p in ready_data { let h = self.handlers.read().unwrap().get(p).unwrap().clone(); - h.connected(&NetworkContext::new(io, p, Some(token), self.connections.clone()), &token); + h.connected(&NetworkContext::new(io, p, Some(token), self.sessions.clone()), &token); } if let Some((p, packet_id, data)) = packet_data { let h = self.handlers.read().unwrap().get(p).unwrap().clone(); - h.read(&NetworkContext::new(io, p, Some(token), self.connections.clone()), &token, packet_id, &data[1..]); + h.read(&NetworkContext::new(io, p, Some(token), self.sessions.clone()), &token, packet_id, &data[1..]); } io.update_registration(token).unwrap_or_else(|e| debug!(target: "net", "Token registration error: {:?}", e)); } fn start_session(&self, token: StreamToken, io: &IoContext>) { - let mut connections = self.connections.write().unwrap(); - let replace = { - let connection = { connections.get(token).cloned() }; - if let Some(connection) = connection { - match *connection.lock().unwrap().deref_mut() { - ConnectionEntry::Handshake(_) => true, - _ => false, - } - } else { false } - }; - if replace { - connections.replace_with(token, |c| { - match Arc::try_unwrap(c).ok().unwrap().into_inner().unwrap() { - ConnectionEntry::Handshake(h) => { - let session = Session::new(h, io, &self.info.read().unwrap()).expect("Session creation error"); - io.update_registration(token).expect("Error updating session registration"); - self.stats.inc_sessions(); - Some(Arc::new(Mutex::new(ConnectionEntry::Session(session)))) - }, - _ => { None } // handshake expired - } - }).ok(); + let mut handshakes = self.handshakes.write().unwrap(); + if handshakes.get(token).is_none() { + return; + } + + // turn a handshake into a session + let mut sessions = self.sessions.write().unwrap(); + let h = handshakes.remove(token).unwrap(); + let h = Arc::try_unwrap(h).ok().unwrap().into_inner().unwrap(); + let result = sessions.insert_with(move |session_token| { + let session = Session::new(h, session_token, &self.info.read().unwrap()).expect("Session creation error"); + io.update_registration(session_token).expect("Error updating session registration"); + self.stats.inc_sessions(); + Arc::new(Mutex::new(session)) + }); + if result.is_none() { + warn!("Max sessions reached"); } } @@ -663,28 +655,32 @@ impl Host where Message: Send + Sync + Clone { fn kill_connection(&self, token: StreamToken, io: &IoContext>, remote: bool) { let mut to_disconnect: Vec = Vec::new(); let mut failure_id = None; - { - let mut connections = self.connections.write().unwrap(); - if let Some(connection) = connections.get(token).cloned() { - match *connection.lock().unwrap().deref_mut() { - ConnectionEntry::Handshake(ref h) => { - connections.remove(token); - failure_id = Some(h.id().clone()); - }, - ConnectionEntry::Session(ref mut s) if s.is_ready() => { + match token { + FIRST_HANDSHAKE ... LAST_HANDSHAKE => { + let mut handshakes = self.handshakes.write().unwrap(); + if let Some(handshake) = handshakes.get(token).cloned() { + failure_id = Some(handshake.lock().unwrap().id().clone()); + handshakes.remove(token); + } + }, + FIRST_SESSION ... LAST_SESSION => { + let mut sessions = self.sessions.write().unwrap(); + if let Some(session) = sessions.get(token).cloned() { + let s = session.lock().unwrap(); + if s.is_ready() { for (p, _) in self.handlers.read().unwrap().iter() { if s.have_capability(p) { to_disconnect.push(p); } } - connections.remove(token); - failure_id = Some(s.id().clone()); - }, - _ => {}, + } + failure_id = Some(s.id().clone()); + sessions.remove(token); } - } - io.deregister_stream(token).expect("Error deregistering stream"); + }, + _ => {}, } + io.deregister_stream(token).expect("Error deregistering stream"); if let Some(id) = failure_id { if remote { self.nodes.write().unwrap().note_failure(&id); @@ -692,25 +688,28 @@ impl Host where Message: Send + Sync + Clone { } for p in to_disconnect { let h = self.handlers.read().unwrap().get(p).unwrap().clone(); - h.disconnected(&NetworkContext::new(io, p, Some(token), self.connections.clone()), &token); + h.disconnected(&NetworkContext::new(io, p, Some(token), self.sessions.clone()), &token); } } fn update_nodes(&self, io: &IoContext>, node_changes: TableUpdates) { let mut to_remove: Vec = Vec::new(); { - let connections = self.connections.write().unwrap(); - for c in connections.iter() { - match *c.lock().unwrap().deref_mut() { - ConnectionEntry::Handshake(ref h) => { - if node_changes.removed.contains(&h.id()) { - to_remove.push(h.token()); - } + { + let handshakes = self.handshakes.write().unwrap(); + for c in handshakes.iter() { + let h = c.lock().unwrap(); + if node_changes.removed.contains(&h.id()) { + to_remove.push(h.token()); } - ConnectionEntry::Session(ref s) => { - if node_changes.removed.contains(&s.id()) { - to_remove.push(s.token()); - } + } + } + { + let sessions = self.sessions.write().unwrap(); + for c in sessions.iter() { + let s = c.lock().unwrap(); + if node_changes.removed.contains(&s.id()) { + to_remove.push(s.token()); } } } @@ -735,14 +734,16 @@ impl IoHandler> for Host where Messa fn stream_hup(&self, io: &IoContext>, stream: StreamToken) { trace!(target: "net", "Hup: {}", stream); match stream { - FIRST_CONNECTION ... LAST_CONNECTION => self.connection_closed(stream, io), + FIRST_SESSION ... LAST_SESSION => self.connection_closed(stream, io), + FIRST_HANDSHAKE ... LAST_HANDSHAKE => self.connection_closed(stream, io), _ => warn!(target: "net", "Unexpected hup"), }; } fn stream_readable(&self, io: &IoContext>, stream: StreamToken) { match stream { - FIRST_CONNECTION ... LAST_CONNECTION => self.connection_readable(stream, io), + FIRST_SESSION ... LAST_SESSION => self.session_readable(stream, io), + FIRST_HANDSHAKE ... LAST_HANDSHAKE => self.handshake_readable(stream, io), DISCOVERY => { if let Some(node_changes) = self.discovery.lock().unwrap().readable() { self.update_nodes(io, node_changes); @@ -756,7 +757,8 @@ impl IoHandler> for Host where Messa fn stream_writable(&self, io: &IoContext>, stream: StreamToken) { match stream { - FIRST_CONNECTION ... LAST_CONNECTION => self.connection_writable(stream, io), + FIRST_SESSION ... LAST_SESSION => self.session_writable(stream, io), + FIRST_HANDSHAKE ... LAST_HANDSHAKE => self.handshake_writable(stream, io), DISCOVERY => { self.discovery.lock().unwrap().writable(); io.update_registration(DISCOVERY).expect("Error updating disicovery registration"); @@ -768,7 +770,8 @@ impl IoHandler> for Host where Messa fn timeout(&self, io: &IoContext>, token: TimerToken) { match token { IDLE => self.maintain_network(io), - FIRST_CONNECTION ... LAST_CONNECTION => self.connection_timeout(token, io), + FIRST_SESSION ... LAST_SESSION => self.connection_timeout(token, io), + FIRST_HANDSHAKE ... LAST_HANDSHAKE => self.connection_timeout(token, io), DISCOVERY_REFRESH => { self.discovery.lock().unwrap().refresh(); io.update_registration(DISCOVERY).expect("Error updating disicovery registration"); @@ -782,7 +785,7 @@ impl IoHandler> for Host where Messa _ => match self.timers.read().unwrap().get(&token).cloned() { Some(timer) => match self.handlers.read().unwrap().get(timer.protocol).cloned() { None => { warn!(target: "net", "No handler found for protocol: {:?}", timer.protocol) }, - Some(h) => { h.timeout(&NetworkContext::new(io, timer.protocol, None, self.connections.clone()), timer.token); } + Some(h) => { h.timeout(&NetworkContext::new(io, timer.protocol, None, self.sessions.clone()), timer.token); } }, None => { warn!("Unknown timer token: {}", token); } // timer is not registerd through us } @@ -797,7 +800,7 @@ impl IoHandler> for Host where Messa ref versions } => { let h = handler.clone(); - h.initialize(&NetworkContext::new(io, protocol, None, self.connections.clone())); + h.initialize(&NetworkContext::new(io, protocol, None, self.sessions.clone())); self.handlers.write().unwrap().insert(protocol, h); let mut info = self.info.write().unwrap(); for v in versions { @@ -820,18 +823,15 @@ impl IoHandler> for Host where Messa io.register_timer(handler_token, *delay).expect("Error registering timer"); }, NetworkIoMessage::Disconnect(ref peer) => { - let connection = { self.connections.read().unwrap().get(*peer).cloned() }; - if let Some(connection) = connection { - match *connection.lock().unwrap().deref_mut() { - ConnectionEntry::Handshake(_) => {}, - ConnectionEntry::Session(ref mut s) => { s.disconnect(DisconnectReason::DisconnectRequested); } - } + let session = { self.sessions.read().unwrap().get(*peer).cloned() }; + if let Some(session) = session { + session.lock().unwrap().disconnect(DisconnectReason::DisconnectRequested); } self.kill_connection(*peer, io, false); }, NetworkIoMessage::User(ref message) => { for (p, h) in self.handlers.read().unwrap().iter() { - h.message(&NetworkContext::new(io, p, None, self.connections.clone()), &message); + h.message(&NetworkContext::new(io, p, None, self.sessions.clone()), &message); } } } @@ -839,14 +839,14 @@ impl IoHandler> for Host where Messa fn register_stream(&self, stream: StreamToken, reg: Token, event_loop: &mut EventLoop>>) { match stream { - FIRST_CONNECTION ... LAST_CONNECTION => { - let connection = { self.connections.read().unwrap().get(stream).cloned() }; + FIRST_SESSION ... LAST_SESSION => { + warn!("Unexpected session stream registration"); + } + FIRST_HANDSHAKE ... LAST_HANDSHAKE => { + let connection = { self.handshakes.read().unwrap().get(stream).cloned() }; if let Some(connection) = connection { - match *connection.lock().unwrap().deref() { - ConnectionEntry::Handshake(ref h) => h.register_socket(reg, event_loop).expect("Error registering socket"), - ConnectionEntry::Session(_) => warn!("Unexpected session stream registration") - } - } else {} // expired + connection.lock().unwrap().register_socket(reg, event_loop).expect("Error registering socket"); + } } DISCOVERY => self.discovery.lock().unwrap().register_socket(event_loop).expect("Error registering discovery socket"), TCP_ACCEPT => event_loop.register(self.tcp_listener.lock().unwrap().deref(), Token(TCP_ACCEPT), EventSet::all(), PollOpt::edge()).expect("Error registering stream"), @@ -856,16 +856,20 @@ impl IoHandler> for Host where Messa fn deregister_stream(&self, stream: StreamToken, event_loop: &mut EventLoop>>) { match stream { - FIRST_CONNECTION ... LAST_CONNECTION => { - let mut connections = self.connections.write().unwrap(); + FIRST_SESSION ... LAST_SESSION => { + let mut connections = self.sessions.write().unwrap(); if let Some(connection) = connections.get(stream).cloned() { - match *connection.lock().unwrap().deref() { - ConnectionEntry::Handshake(ref h) => h.deregister_socket(event_loop).expect("Error deregistering socket"), - ConnectionEntry::Session(ref s) => s.deregister_socket(event_loop).expect("Error deregistering session socket"), - } + connection.lock().unwrap().deregister_socket(event_loop).expect("Error deregistering socket"); connections.remove(stream); - } - }, + } + } + FIRST_HANDSHAKE ... LAST_HANDSHAKE => { + let mut connections = self.handshakes.write().unwrap(); + if let Some(connection) = connections.get(stream).cloned() { + connection.lock().unwrap().deregister_socket(event_loop).expect("Error deregistering socket"); + connections.remove(stream); + } + } DISCOVERY => (), TCP_ACCEPT => event_loop.deregister(self.tcp_listener.lock().unwrap().deref()).unwrap(), _ => warn!("Unexpected stream deregistration") @@ -874,14 +878,17 @@ impl IoHandler> for Host where Messa fn update_stream(&self, stream: StreamToken, reg: Token, event_loop: &mut EventLoop>>) { match stream { - FIRST_CONNECTION ... LAST_CONNECTION => { - let connection = { self.connections.read().unwrap().get(stream).cloned() }; - if let Some(connection) = connection { - match *connection.lock().unwrap().deref() { - ConnectionEntry::Handshake(ref h) => h.update_socket(reg, event_loop).expect("Error updating socket"), - ConnectionEntry::Session(ref s) => s.update_socket(reg, event_loop).expect("Error updating socket"), - } - } else {} // expired + FIRST_SESSION ... LAST_SESSION => { + let connection = { self.sessions.read().unwrap().get(stream).cloned() }; + if let Some(connection) = connection { + connection.lock().unwrap().update_socket(reg, event_loop).expect("Error updating socket"); + } + } + FIRST_HANDSHAKE ... LAST_HANDSHAKE => { + let connection = { self.handshakes.read().unwrap().get(stream).cloned() }; + if let Some(connection) = connection { + connection.lock().unwrap().update_socket(reg, event_loop).expect("Error updating socket"); + } } DISCOVERY => self.discovery.lock().unwrap().update_registration(event_loop).expect("Error reregistering discovery socket"), TCP_ACCEPT => event_loop.reregister(self.tcp_listener.lock().unwrap().deref(), Token(TCP_ACCEPT), EventSet::all(), PollOpt::edge()).expect("Error reregistering stream"), diff --git a/util/src/network/session.rs b/util/src/network/session.rs index 3b49a8f5e..2763dfd82 100644 --- a/util/src/network/session.rs +++ b/util/src/network/session.rs @@ -108,7 +108,7 @@ const PACKET_LAST: u8 = 0x7f; impl Session { /// Create a new session out of comepleted handshake. Consumes handshake object. - pub fn new(h: Handshake, _io: &IoContext, host: &HostInfo) -> Result where Message: Send + Sync + Clone { + pub fn new(h: Handshake, token: StreamToken, host: &HostInfo) -> Result { let id = h.id.clone(); let connection = try!(EncryptedConnection::new(h)); let mut session = Session { @@ -124,6 +124,7 @@ impl Session { ping_time_ns: 0, pong_time_ns: None, }; + session.connection.set_token(token); try!(session.write_hello(host)); try!(session.send_ping()); Ok(session)