diff --git a/util/src/network/host.rs b/util/src/network/host.rs index 3d55430bd..004410466 100644 --- a/util/src/network/host.rs +++ b/util/src/network/host.rs @@ -291,13 +291,14 @@ pub struct Host where Message: Send + Sync + Clone { tcp_listener: Mutex, handshakes: Arc>>, sessions: Arc>>, - discovery: Mutex, + discovery: Option>, nodes: RwLock, handlers: RwLock>>>, timers: RwLock>, timer_counter: RwLock, stats: Arc, public_endpoint: NodeEndpoint, + pinned_nodes: Vec, } impl Host where Message: Send + Sync + Clone { @@ -343,7 +344,9 @@ impl Host where Message: Send + Sync + Clone { }, |s| KeyPair::from_secret(s).expect("Error creating node secret key")) }; - let discovery = Discovery::new(&keys, listen_address.clone(), public_endpoint.clone(), DISCOVERY); + let discovery = if config.discovery_enabled && !config.pin { + Some(Discovery::new(&keys, listen_address.clone(), public_endpoint.clone(), DISCOVERY)) + } else { None }; let path = config.config_path.clone(); let mut host = Host:: { info: RwLock::new(HostInfo { @@ -355,7 +358,7 @@ impl Host where Message: Send + Sync + Clone { listen_port: 0, capabilities: Vec::new(), }), - discovery: Mutex::new(discovery), + discovery: discovery.map(Mutex::new), tcp_listener: Mutex::new(tcp_listener), handshakes: Arc::new(RwLock::new(Slab::new_starting_at(FIRST_HANDSHAKE, MAX_HANDSHAKES))), sessions: Arc::new(RwLock::new(Slab::new_starting_at(FIRST_SESSION, MAX_SESSIONS))), @@ -365,6 +368,7 @@ impl Host where Message: Send + Sync + Clone { timer_counter: RwLock::new(USER_TIMER), stats: Arc::new(NetworkStats::default()), public_endpoint: public_endpoint, + pinned_nodes: Vec::new(), }; let port = listen_address.port(); host.info.write().unwrap().deref_mut().listen_port = port; @@ -373,7 +377,9 @@ impl Host where Message: Send + Sync + Clone { for n in boot_nodes { host.add_node(&n); } - host.discovery.lock().unwrap().init_node_list(host.nodes.read().unwrap().unordered_entries()); + if let Some(ref mut discovery) = host.discovery { + discovery.lock().unwrap().init_node_list(host.nodes.read().unwrap().unordered_entries()); + } host } @@ -386,8 +392,11 @@ impl Host where Message: Send + Sync + Clone { Err(e) => { warn!("Could not add node: {:?}", e); }, Ok(n) => { let entry = NodeEntry { endpoint: n.endpoint.clone(), id: n.id.clone() }; + self.pinned_nodes.push(n.id.clone()); self.nodes.write().unwrap().add_node(n); - self.discovery.lock().unwrap().add_node(entry); + if let Some(ref mut discovery) = self.discovery { + discovery.lock().unwrap().add_node(entry); + } } } } @@ -437,6 +446,7 @@ impl Host where Message: Send + Sync + Clone { fn connect_peers(&self, io: &IoContext>) { let ideal_peers = { self.info.read().unwrap().deref().config.ideal_peers }; + let pin = { self.info.read().unwrap().deref().config.pin }; let session_count = self.session_count(); if session_count >= ideal_peers as usize { return; @@ -449,7 +459,7 @@ impl Host where Message: Send + Sync + Clone { return; } - let nodes = { self.nodes.read().unwrap().nodes() }; + let nodes = if pin { self.pinned_nodes.clone() } else { self.nodes.read().unwrap().nodes() }; for id in nodes.iter().filter(|ref id| !self.have_session(id) && !self.connecting_to(id)) .take(min(MAX_HANDSHAKES_PER_ROUND, handshake_limit - handshake_count)) { self.connect_peer(&id, io); @@ -670,7 +680,9 @@ impl Host where Message: Send + Sync + Clone { if let Ok(address) = session.remote_addr() { let entry = NodeEntry { id: session.id().clone(), endpoint: NodeEndpoint { address: address, udp_port: address.port() } }; self.nodes.write().unwrap().add_node(Node::new(entry.id.clone(), entry.endpoint.clone())); - self.discovery.lock().unwrap().add_node(entry); + if let Some(ref discovery) = self.discovery { + discovery.lock().unwrap().add_node(entry); + } } } Arc::new(Mutex::new(session)) @@ -759,8 +771,10 @@ impl IoHandler> for Host where Messa io.register_stream(TCP_ACCEPT).expect("Error registering TCP listener"); io.register_stream(DISCOVERY).expect("Error registering UDP listener"); io.register_timer(IDLE, MAINTENANCE_TIMEOUT).expect("Error registering Network idle timer"); - io.register_timer(DISCOVERY_REFRESH, 7200).expect("Error registering discovery timer"); - io.register_timer(DISCOVERY_ROUND, 300).expect("Error registering discovery timer"); + if self.discovery.is_some() { + io.register_timer(DISCOVERY_REFRESH, 7200).expect("Error registering discovery timer"); + io.register_timer(DISCOVERY_ROUND, 300).expect("Error registering discovery timer"); + } } fn stream_hup(&self, io: &IoContext>, stream: StreamToken) { @@ -777,7 +791,7 @@ impl IoHandler> for Host where Messa FIRST_SESSION ... LAST_SESSION => self.session_readable(stream, io), FIRST_HANDSHAKE ... LAST_HANDSHAKE => self.handshake_readable(stream, io), DISCOVERY => { - if let Some(node_changes) = self.discovery.lock().unwrap().readable() { + if let Some(node_changes) = self.discovery.as_ref().unwrap().lock().unwrap().readable() { self.update_nodes(io, node_changes); } io.update_registration(DISCOVERY).expect("Error updating disicovery registration"); @@ -792,7 +806,7 @@ impl IoHandler> for Host where Messa FIRST_SESSION ... LAST_SESSION => self.session_writable(stream, io), FIRST_HANDSHAKE ... LAST_HANDSHAKE => self.handshake_writable(stream, io), DISCOVERY => { - self.discovery.lock().unwrap().writable(); + self.discovery.as_ref().unwrap().lock().unwrap().writable(); io.update_registration(DISCOVERY).expect("Error updating disicovery registration"); } _ => panic!("Received unknown writable token"), @@ -805,11 +819,11 @@ impl IoHandler> for Host where Messa FIRST_SESSION ... LAST_SESSION => self.connection_timeout(token, io), FIRST_HANDSHAKE ... LAST_HANDSHAKE => self.connection_timeout(token, io), DISCOVERY_REFRESH => { - self.discovery.lock().unwrap().refresh(); + self.discovery.as_ref().unwrap().lock().unwrap().refresh(); io.update_registration(DISCOVERY).expect("Error updating disicovery registration"); }, DISCOVERY_ROUND => { - if let Some(node_changes) = self.discovery.lock().unwrap().round() { + if let Some(node_changes) = self.discovery.as_ref().unwrap().lock().unwrap().round() { self.update_nodes(io, node_changes); } io.update_registration(DISCOVERY).expect("Error updating disicovery registration"); @@ -880,7 +894,7 @@ impl IoHandler> for Host where Messa connection.lock().unwrap().register_socket(reg, event_loop).expect("Error registering socket"); } } - DISCOVERY => self.discovery.lock().unwrap().register_socket(event_loop).expect("Error registering discovery socket"), + DISCOVERY => self.discovery.as_ref().unwrap().lock().unwrap().register_socket(event_loop).expect("Error registering discovery socket"), TCP_ACCEPT => event_loop.register(self.tcp_listener.lock().unwrap().deref(), Token(TCP_ACCEPT), EventSet::all(), PollOpt::edge()).expect("Error registering stream"), _ => warn!("Unexpected stream registration") } @@ -922,7 +936,7 @@ impl IoHandler> for Host where Messa connection.lock().unwrap().update_socket(reg, event_loop).expect("Error updating socket"); } } - DISCOVERY => self.discovery.lock().unwrap().update_registration(event_loop).expect("Error reregistering discovery socket"), + DISCOVERY => self.discovery.as_ref().unwrap().lock().unwrap().update_registration(event_loop).expect("Error reregistering discovery socket"), TCP_ACCEPT => event_loop.reregister(self.tcp_listener.lock().unwrap().deref(), Token(TCP_ACCEPT), EventSet::all(), PollOpt::edge()).expect("Error reregistering stream"), _ => warn!("Unexpected stream update") }