diff --git a/util/network/src/host.rs b/util/network/src/host.rs index 2f236a5f7..05f0e83a4 100644 --- a/util/network/src/host.rs +++ b/util/network/src/host.rs @@ -683,8 +683,7 @@ impl Host { #[cfg_attr(feature="dev", allow(single_match))] fn connect_peer(&self, id: &NodeId, io: &IoContext) { - if self.have_session(id) - { + if self.have_session(id) { trace!(target: "network", "Aborted connect. Node already connected."); return; } @@ -788,102 +787,119 @@ impl Host { let mut packet_data: Vec<(ProtocolId, PacketId, Vec)> = Vec::new(); let mut kill = false; let session = { self.sessions.read().get(token).cloned() }; + let mut ready_id = None; if let Some(session) = session.clone() { - let mut s = session.lock(); - loop { - let session_result = s.readable(io, &self.info.read()); - match session_result { - Err(e) => { - trace!(target: "network", "Session read error: {}:{:?} ({:?}) {:?}", token, s.id(), s.remote_addr(), e); - if let NetworkError::Disconnect(DisconnectReason::IncompatibleProtocol) = e { - if let Some(id) = s.id() { - if !self.reserved_nodes.read().contains(id) { - self.nodes.write().mark_as_useless(id); + { + let mut s = session.lock(); + loop { + let session_result = s.readable(io, &self.info.read()); + match session_result { + Err(e) => { + trace!(target: "network", "Session read error: {}:{:?} ({:?}) {:?}", token, s.id(), s.remote_addr(), e); + if let NetworkError::Disconnect(DisconnectReason::IncompatibleProtocol) = e { + if let Some(id) = s.id() { + if !self.reserved_nodes.read().contains(id) { + self.nodes.write().mark_as_useless(id); + } } } - } - kill = true; - break; - }, - Ok(SessionData::Ready) => { - self.num_sessions.fetch_add(1, AtomicOrdering::SeqCst); - let session_count = self.session_count(); - let (min_peers, max_peers, reserved_only) = { - let info = self.info.read(); - let mut max_peers = info.config.max_peers; - for cap in s.info.capabilities.iter() { - if let Some(num) = info.config.reserved_protocols.get(&cap.protocol) { - max_peers += *num; - break; + kill = true; + break; + }, + Ok(SessionData::Ready) => { + self.num_sessions.fetch_add(1, AtomicOrdering::SeqCst); + let session_count = self.session_count(); + let (min_peers, max_peers, reserved_only) = { + let info = self.info.read(); + let mut max_peers = info.config.max_peers; + for cap in s.info.capabilities.iter() { + if let Some(num) = info.config.reserved_protocols.get(&cap.protocol) { + max_peers += *num; + break; + } + } + (info.config.min_peers as usize, max_peers as usize, info.config.non_reserved_mode == NonReservedPeerMode::Deny) + }; + + let id = s.id().expect("Ready session always has id").clone(); + + // Check for the session limit. session_counts accounts for the new session. + if reserved_only || + (s.info.originated && session_count > min_peers) || + (!s.info.originated && session_count > max_peers) { + // only proceed if the connecting peer is reserved. + if !self.reserved_nodes.read().contains(&id) { + s.disconnect(io, DisconnectReason::TooManyPeers); + return; } } - (info.config.min_peers as usize, max_peers as usize, info.config.non_reserved_mode == NonReservedPeerMode::Deny) - }; + ready_id = Some(id); - // Check for the session limit. session_counts accounts for the new session. - if reserved_only || - (s.info.originated && session_count > min_peers) || - (!s.info.originated && session_count > max_peers) { - // only proceed if the connecting peer is reserved. - if !self.reserved_nodes.read().contains(s.id().expect("Ready session always has id")) { - s.disconnect(io, DisconnectReason::TooManyPeers); - return; - } - } - - // Add it to the node table - if !s.info.originated { - if let Ok(address) = s.remote_addr() { - let entry = NodeEntry { id: s.id().expect("Ready session always has id").clone(), endpoint: NodeEndpoint { address: address, udp_port: address.port() } }; - self.nodes.write().add_node(Node::new(entry.id.clone(), entry.endpoint.clone())); - let mut discovery = self.discovery.lock(); - if let Some(ref mut discovery) = *discovery { - discovery.add_node(entry); + // Add it to the node table + if !s.info.originated { + if let Ok(address) = s.remote_addr() { + let entry = NodeEntry { id: id, endpoint: NodeEndpoint { address: address, udp_port: address.port() } }; + self.nodes.write().add_node(Node::new(entry.id.clone(), entry.endpoint.clone())); + let mut discovery = self.discovery.lock(); + if let Some(ref mut discovery) = *discovery { + discovery.add_node(entry); + } } } - } - for (p, _) in self.handlers.read().iter() { - if s.have_capability(*p) { - ready_data.push(*p); + for (p, _) in self.handlers.read().iter() { + if s.have_capability(*p) { + ready_data.push(*p); + } } - } - }, - Ok(SessionData::Packet { - data, - protocol, - packet_id, - }) => { - match self.handlers.read().get(&protocol) { - None => { warn!(target: "network", "No handler found for protocol: {:?}", protocol) }, - Some(_) => packet_data.push((protocol, packet_id, data)), - } - }, - Ok(SessionData::Continue) => (), - Ok(SessionData::None) => break, - } - } - } - if kill { - self.kill_connection(token, io, true); - } - let handlers = self.handlers.read(); - for p in ready_data { - self.stats.inc_sessions(); - let reserved = self.reserved_nodes.read(); - if let Some(h) = handlers.get(&p).clone() { - h.connected(&NetworkContext::new(io, p, session.clone(), self.sessions.clone(), &reserved), &token); - - // accumulate pending packets. - if let Some(session) = session.as_ref() { - let mut session = session.lock(); - packet_data.extend(session.mark_connected(p)); + }, + Ok(SessionData::Packet { + data, + protocol, + packet_id, + }) => { + match self.handlers.read().get(&protocol) { + None => { warn!(target: "network", "No handler found for protocol: {:?}", protocol) }, + Some(_) => packet_data.push((protocol, packet_id, data)), + } + }, + Ok(SessionData::Continue) => (), + Ok(SessionData::None) => break, + } } } - } - for (p, packet_id, data) in packet_data { - let reserved = self.reserved_nodes.read(); - if let Some(h) = handlers.get(&p).clone() { - h.read(&NetworkContext::new(io, p, session.clone(), self.sessions.clone(), &reserved), &token, packet_id, &data[1..]); + + if kill { + self.kill_connection(token, io, true); + } + + let handlers = self.handlers.read(); + if !ready_data.is_empty() { + let duplicate = self.sessions.read().iter().any(|e| { + let session = e.lock(); + session.token() != token && session.info.id == ready_id + }); + if duplicate { + trace!(target: "network", "Rejected duplicate connection: {}", token); + session.lock().disconnect(io, DisconnectReason::DuplicatePeer); + return; + } + for p in ready_data { + self.stats.inc_sessions(); + let reserved = self.reserved_nodes.read(); + if let Some(h) = handlers.get(&p).clone() { + h.connected(&NetworkContext::new(io, p, Some(session.clone()), self.sessions.clone(), &reserved), &token); + // accumulate pending packets. + let mut session = session.lock(); + packet_data.extend(session.mark_connected(p)); + } + } + } + + for (p, packet_id, data) in packet_data { + let reserved = self.reserved_nodes.read(); + if let Some(h) = handlers.get(&p).clone() { + h.read(&NetworkContext::new(io, p, Some(session.clone()), self.sessions.clone(), &reserved), &token, packet_id, &data[1..]); + } } } }