Handle pinning and enable_discovery options

This commit is contained in:
arkpar 2016-02-16 02:31:17 +01:00
parent f771306867
commit 58fdfe77d3

View File

@ -291,13 +291,14 @@ pub struct Host<Message> where Message: Send + Sync + Clone {
tcp_listener: Mutex<TcpListener>, tcp_listener: Mutex<TcpListener>,
handshakes: Arc<RwLock<Slab<SharedHandshake>>>, handshakes: Arc<RwLock<Slab<SharedHandshake>>>,
sessions: Arc<RwLock<Slab<SharedSession>>>, sessions: Arc<RwLock<Slab<SharedSession>>>,
discovery: Mutex<Discovery>, discovery: Option<Mutex<Discovery>>,
nodes: RwLock<NodeTable>, nodes: RwLock<NodeTable>,
handlers: RwLock<HashMap<ProtocolId, Arc<NetworkProtocolHandler<Message>>>>, handlers: RwLock<HashMap<ProtocolId, Arc<NetworkProtocolHandler<Message>>>>,
timers: RwLock<HashMap<TimerToken, ProtocolTimer>>, timers: RwLock<HashMap<TimerToken, ProtocolTimer>>,
timer_counter: RwLock<usize>, timer_counter: RwLock<usize>,
stats: Arc<NetworkStats>, stats: Arc<NetworkStats>,
public_endpoint: NodeEndpoint, public_endpoint: NodeEndpoint,
pinned_nodes: Vec<NodeId>,
} }
impl<Message> Host<Message> where Message: Send + Sync + Clone { impl<Message> Host<Message> where Message: Send + Sync + Clone {
@ -343,7 +344,9 @@ impl<Message> Host<Message> where Message: Send + Sync + Clone {
}, },
|s| KeyPair::from_secret(s).expect("Error creating node secret key")) |s| KeyPair::from_secret(s).expect("Error creating node secret key"))
}; };
let discovery = Discovery::new(&keys, listen_address.clone(), public_endpoint.clone(), DISCOVERY); let discovery = if config.discovery_enabled && !config.pin {
Some(Discovery::new(&keys, listen_address.clone(), public_endpoint.clone(), DISCOVERY))
} else { None };
let path = config.config_path.clone(); let path = config.config_path.clone();
let mut host = Host::<Message> { let mut host = Host::<Message> {
info: RwLock::new(HostInfo { info: RwLock::new(HostInfo {
@ -355,7 +358,7 @@ impl<Message> Host<Message> where Message: Send + Sync + Clone {
listen_port: 0, listen_port: 0,
capabilities: Vec::new(), capabilities: Vec::new(),
}), }),
discovery: Mutex::new(discovery), discovery: discovery.map(Mutex::new),
tcp_listener: Mutex::new(tcp_listener), tcp_listener: Mutex::new(tcp_listener),
handshakes: Arc::new(RwLock::new(Slab::new_starting_at(FIRST_HANDSHAKE, MAX_HANDSHAKES))), handshakes: Arc::new(RwLock::new(Slab::new_starting_at(FIRST_HANDSHAKE, MAX_HANDSHAKES))),
sessions: Arc::new(RwLock::new(Slab::new_starting_at(FIRST_SESSION, MAX_SESSIONS))), sessions: Arc::new(RwLock::new(Slab::new_starting_at(FIRST_SESSION, MAX_SESSIONS))),
@ -365,6 +368,7 @@ impl<Message> Host<Message> where Message: Send + Sync + Clone {
timer_counter: RwLock::new(USER_TIMER), timer_counter: RwLock::new(USER_TIMER),
stats: Arc::new(NetworkStats::default()), stats: Arc::new(NetworkStats::default()),
public_endpoint: public_endpoint, public_endpoint: public_endpoint,
pinned_nodes: Vec::new(),
}; };
let port = listen_address.port(); let port = listen_address.port();
host.info.write().unwrap().deref_mut().listen_port = port; host.info.write().unwrap().deref_mut().listen_port = port;
@ -373,7 +377,9 @@ impl<Message> Host<Message> where Message: Send + Sync + Clone {
for n in boot_nodes { for n in boot_nodes {
host.add_node(&n); host.add_node(&n);
} }
host.discovery.lock().unwrap().init_node_list(host.nodes.read().unwrap().unordered_entries()); if let Some(ref mut discovery) = host.discovery {
discovery.lock().unwrap().init_node_list(host.nodes.read().unwrap().unordered_entries());
}
host host
} }
@ -386,8 +392,11 @@ impl<Message> Host<Message> where Message: Send + Sync + Clone {
Err(e) => { warn!("Could not add node: {:?}", e); }, Err(e) => { warn!("Could not add node: {:?}", e); },
Ok(n) => { Ok(n) => {
let entry = NodeEntry { endpoint: n.endpoint.clone(), id: n.id.clone() }; let entry = NodeEntry { endpoint: n.endpoint.clone(), id: n.id.clone() };
self.pinned_nodes.push(n.id.clone());
self.nodes.write().unwrap().add_node(n); self.nodes.write().unwrap().add_node(n);
self.discovery.lock().unwrap().add_node(entry); if let Some(ref mut discovery) = self.discovery {
discovery.lock().unwrap().add_node(entry);
}
} }
} }
} }
@ -437,6 +446,7 @@ impl<Message> Host<Message> where Message: Send + Sync + Clone {
fn connect_peers(&self, io: &IoContext<NetworkIoMessage<Message>>) { fn connect_peers(&self, io: &IoContext<NetworkIoMessage<Message>>) {
let ideal_peers = { self.info.read().unwrap().deref().config.ideal_peers }; let ideal_peers = { self.info.read().unwrap().deref().config.ideal_peers };
let pin = { self.info.read().unwrap().deref().config.pin };
let session_count = self.session_count(); let session_count = self.session_count();
if session_count >= ideal_peers as usize { if session_count >= ideal_peers as usize {
return; return;
@ -449,7 +459,7 @@ impl<Message> Host<Message> where Message: Send + Sync + Clone {
return; return;
} }
let nodes = { self.nodes.read().unwrap().nodes() }; let nodes = if pin { self.pinned_nodes.clone() } else { self.nodes.read().unwrap().nodes() };
for id in nodes.iter().filter(|ref id| !self.have_session(id) && !self.connecting_to(id)) for id in nodes.iter().filter(|ref id| !self.have_session(id) && !self.connecting_to(id))
.take(min(MAX_HANDSHAKES_PER_ROUND, handshake_limit - handshake_count)) { .take(min(MAX_HANDSHAKES_PER_ROUND, handshake_limit - handshake_count)) {
self.connect_peer(&id, io); self.connect_peer(&id, io);
@ -670,7 +680,9 @@ impl<Message> Host<Message> where Message: Send + Sync + Clone {
if let Ok(address) = session.remote_addr() { if let Ok(address) = session.remote_addr() {
let entry = NodeEntry { id: session.id().clone(), endpoint: NodeEndpoint { address: address, udp_port: address.port() } }; let entry = NodeEntry { id: session.id().clone(), endpoint: NodeEndpoint { address: address, udp_port: address.port() } };
self.nodes.write().unwrap().add_node(Node::new(entry.id.clone(), entry.endpoint.clone())); self.nodes.write().unwrap().add_node(Node::new(entry.id.clone(), entry.endpoint.clone()));
self.discovery.lock().unwrap().add_node(entry); if let Some(ref discovery) = self.discovery {
discovery.lock().unwrap().add_node(entry);
}
} }
} }
Arc::new(Mutex::new(session)) Arc::new(Mutex::new(session))
@ -759,9 +771,11 @@ impl<Message> IoHandler<NetworkIoMessage<Message>> for Host<Message> where Messa
io.register_stream(TCP_ACCEPT).expect("Error registering TCP listener"); io.register_stream(TCP_ACCEPT).expect("Error registering TCP listener");
io.register_stream(DISCOVERY).expect("Error registering UDP listener"); io.register_stream(DISCOVERY).expect("Error registering UDP listener");
io.register_timer(IDLE, MAINTENANCE_TIMEOUT).expect("Error registering Network idle timer"); io.register_timer(IDLE, MAINTENANCE_TIMEOUT).expect("Error registering Network idle timer");
if self.discovery.is_some() {
io.register_timer(DISCOVERY_REFRESH, 7200).expect("Error registering discovery timer"); io.register_timer(DISCOVERY_REFRESH, 7200).expect("Error registering discovery timer");
io.register_timer(DISCOVERY_ROUND, 300).expect("Error registering discovery timer"); io.register_timer(DISCOVERY_ROUND, 300).expect("Error registering discovery timer");
} }
}
fn stream_hup(&self, io: &IoContext<NetworkIoMessage<Message>>, stream: StreamToken) { fn stream_hup(&self, io: &IoContext<NetworkIoMessage<Message>>, stream: StreamToken) {
trace!(target: "net", "Hup: {}", stream); trace!(target: "net", "Hup: {}", stream);
@ -777,7 +791,7 @@ impl<Message> IoHandler<NetworkIoMessage<Message>> for Host<Message> where Messa
FIRST_SESSION ... LAST_SESSION => self.session_readable(stream, io), FIRST_SESSION ... LAST_SESSION => self.session_readable(stream, io),
FIRST_HANDSHAKE ... LAST_HANDSHAKE => self.handshake_readable(stream, io), FIRST_HANDSHAKE ... LAST_HANDSHAKE => self.handshake_readable(stream, io),
DISCOVERY => { DISCOVERY => {
if let Some(node_changes) = self.discovery.lock().unwrap().readable() { if let Some(node_changes) = self.discovery.as_ref().unwrap().lock().unwrap().readable() {
self.update_nodes(io, node_changes); self.update_nodes(io, node_changes);
} }
io.update_registration(DISCOVERY).expect("Error updating disicovery registration"); io.update_registration(DISCOVERY).expect("Error updating disicovery registration");
@ -792,7 +806,7 @@ impl<Message> IoHandler<NetworkIoMessage<Message>> for Host<Message> where Messa
FIRST_SESSION ... LAST_SESSION => self.session_writable(stream, io), FIRST_SESSION ... LAST_SESSION => self.session_writable(stream, io),
FIRST_HANDSHAKE ... LAST_HANDSHAKE => self.handshake_writable(stream, io), FIRST_HANDSHAKE ... LAST_HANDSHAKE => self.handshake_writable(stream, io),
DISCOVERY => { DISCOVERY => {
self.discovery.lock().unwrap().writable(); self.discovery.as_ref().unwrap().lock().unwrap().writable();
io.update_registration(DISCOVERY).expect("Error updating disicovery registration"); io.update_registration(DISCOVERY).expect("Error updating disicovery registration");
} }
_ => panic!("Received unknown writable token"), _ => panic!("Received unknown writable token"),
@ -805,11 +819,11 @@ impl<Message> IoHandler<NetworkIoMessage<Message>> for Host<Message> where Messa
FIRST_SESSION ... LAST_SESSION => self.connection_timeout(token, io), FIRST_SESSION ... LAST_SESSION => self.connection_timeout(token, io),
FIRST_HANDSHAKE ... LAST_HANDSHAKE => self.connection_timeout(token, io), FIRST_HANDSHAKE ... LAST_HANDSHAKE => self.connection_timeout(token, io),
DISCOVERY_REFRESH => { DISCOVERY_REFRESH => {
self.discovery.lock().unwrap().refresh(); self.discovery.as_ref().unwrap().lock().unwrap().refresh();
io.update_registration(DISCOVERY).expect("Error updating disicovery registration"); io.update_registration(DISCOVERY).expect("Error updating disicovery registration");
}, },
DISCOVERY_ROUND => { DISCOVERY_ROUND => {
if let Some(node_changes) = self.discovery.lock().unwrap().round() { if let Some(node_changes) = self.discovery.as_ref().unwrap().lock().unwrap().round() {
self.update_nodes(io, node_changes); self.update_nodes(io, node_changes);
} }
io.update_registration(DISCOVERY).expect("Error updating disicovery registration"); io.update_registration(DISCOVERY).expect("Error updating disicovery registration");
@ -880,7 +894,7 @@ impl<Message> IoHandler<NetworkIoMessage<Message>> for Host<Message> where Messa
connection.lock().unwrap().register_socket(reg, event_loop).expect("Error registering socket"); connection.lock().unwrap().register_socket(reg, event_loop).expect("Error registering socket");
} }
} }
DISCOVERY => self.discovery.lock().unwrap().register_socket(event_loop).expect("Error registering discovery socket"), DISCOVERY => self.discovery.as_ref().unwrap().lock().unwrap().register_socket(event_loop).expect("Error registering discovery socket"),
TCP_ACCEPT => event_loop.register(self.tcp_listener.lock().unwrap().deref(), Token(TCP_ACCEPT), EventSet::all(), PollOpt::edge()).expect("Error registering stream"), TCP_ACCEPT => event_loop.register(self.tcp_listener.lock().unwrap().deref(), Token(TCP_ACCEPT), EventSet::all(), PollOpt::edge()).expect("Error registering stream"),
_ => warn!("Unexpected stream registration") _ => warn!("Unexpected stream registration")
} }
@ -922,7 +936,7 @@ impl<Message> IoHandler<NetworkIoMessage<Message>> for Host<Message> where Messa
connection.lock().unwrap().update_socket(reg, event_loop).expect("Error updating socket"); connection.lock().unwrap().update_socket(reg, event_loop).expect("Error updating socket");
} }
} }
DISCOVERY => self.discovery.lock().unwrap().update_registration(event_loop).expect("Error reregistering discovery socket"), DISCOVERY => self.discovery.as_ref().unwrap().lock().unwrap().update_registration(event_loop).expect("Error reregistering discovery socket"),
TCP_ACCEPT => event_loop.reregister(self.tcp_listener.lock().unwrap().deref(), Token(TCP_ACCEPT), EventSet::all(), PollOpt::edge()).expect("Error reregistering stream"), TCP_ACCEPT => event_loop.reregister(self.tcp_listener.lock().unwrap().deref(), Token(TCP_ACCEPT), EventSet::all(), PollOpt::edge()).expect("Error reregistering stream"),
_ => warn!("Unexpected stream update") _ => warn!("Unexpected stream update")
} }