Merge pull request #480 from ethcore/network

Networking fixes
This commit is contained in:
Gav Wood 2016-02-21 18:57:10 +01:00
commit 75129613c5
5 changed files with 161 additions and 77 deletions

View File

@ -175,13 +175,26 @@ impl Connection {
self.socket.peer_addr() self.socket.peer_addr()
} }
/// Clone this connection. Clears the receiving buffer of the returned connection.
pub fn try_clone(&self) -> io::Result<Self> {
Ok(Connection {
token: self.token,
socket: try!(self.socket.try_clone()),
rec_buf: Vec::new(),
rec_size: 0,
send_queue: self.send_queue.clone(),
interest: EventSet::hup() | EventSet::readable(),
stats: self.stats.clone(),
})
}
/// Register this connection with the IO event loop. /// Register this connection with the IO event loop.
pub fn register_socket<Host: Handler>(&self, reg: Token, event_loop: &mut EventLoop<Host>) -> io::Result<()> { pub fn register_socket<Host: Handler>(&self, reg: Token, event_loop: &mut EventLoop<Host>) -> io::Result<()> {
trace!(target: "net", "connection register; token={:?}", reg); trace!(target: "net", "connection register; token={:?}", reg);
event_loop.register(&self.socket, reg, self.interest, PollOpt::edge() | PollOpt::oneshot()).or_else(|e| { if let Err(e) = event_loop.register(&self.socket, reg, self.interest, PollOpt::edge() | PollOpt::oneshot()) {
debug!("Failed to register {:?}, {:?}", reg, e); debug!("Failed to register {:?}, {:?}", reg, e);
}
Ok(()) Ok(())
})
} }
/// Update connection registration. Should be called at the end of the IO handler. /// Update connection registration. Should be called at the end of the IO handler.
@ -265,7 +278,7 @@ impl EncryptedConnection {
} }
/// Create an encrypted connection out of the handshake. Consumes a handshake object. /// Create an encrypted connection out of the handshake. Consumes a handshake object.
pub fn new(mut handshake: Handshake) -> Result<EncryptedConnection, UtilError> { pub fn new(handshake: &mut Handshake) -> Result<EncryptedConnection, UtilError> {
let shared = try!(crypto::ecdh::agree(handshake.ecdhe.secret(), &handshake.remote_public)); let shared = try!(crypto::ecdh::agree(handshake.ecdhe.secret(), &handshake.remote_public));
let mut nonce_material = H512::new(); let mut nonce_material = H512::new();
if handshake.originated { if handshake.originated {
@ -300,9 +313,8 @@ impl EncryptedConnection {
ingress_mac.update(&mac_material); ingress_mac.update(&mac_material);
ingress_mac.update(if handshake.originated { &handshake.ack_cipher } else { &handshake.auth_cipher }); ingress_mac.update(if handshake.originated { &handshake.ack_cipher } else { &handshake.auth_cipher });
handshake.connection.expect(ENCRYPTED_HEADER_LEN); let mut enc = EncryptedConnection {
Ok(EncryptedConnection { connection: try!(handshake.connection.try_clone()),
connection: handshake.connection,
encoder: encoder, encoder: encoder,
decoder: decoder, decoder: decoder,
mac_encoder: mac_encoder, mac_encoder: mac_encoder,
@ -311,7 +323,9 @@ impl EncryptedConnection {
read_state: EncryptedConnectionState::Header, read_state: EncryptedConnectionState::Header,
protocol_id: 0, protocol_id: 0,
payload_len: 0 payload_len: 0
}) };
enc.connection.expect(ENCRYPTED_HEADER_LEN);
Ok(enc)
} }
/// Send a packet /// Send a packet
@ -440,6 +454,12 @@ impl EncryptedConnection {
Ok(()) Ok(())
} }
/// Register socket with the event lpop. This should be called at the end of the event loop.
pub fn register_socket<Host:Handler>(&self, reg: Token, event_loop: &mut EventLoop<Host>) -> Result<(), UtilError> {
try!(self.connection.register_socket(reg, event_loop));
Ok(())
}
/// Update connection registration. This should be called at the end of the event loop. /// Update connection registration. This should be called at the end of the event loop.
pub fn update_socket<Host:Handler>(&self, reg: Token, event_loop: &mut EventLoop<Host>) -> Result<(), UtilError> { pub fn update_socket<Host:Handler>(&self, reg: Token, event_loop: &mut EventLoop<Host>) -> Result<(), UtilError> {
try!(self.connection.update_socket(reg, event_loop)); try!(self.connection.update_socket(reg, event_loop));

View File

@ -63,7 +63,9 @@ pub struct Handshake {
/// A copy of received encryped auth packet /// A copy of received encryped auth packet
pub auth_cipher: Bytes, pub auth_cipher: Bytes,
/// A copy of received encryped ack packet /// A copy of received encryped ack packet
pub ack_cipher: Bytes pub ack_cipher: Bytes,
/// This Handshake is marked for deleteion flag
pub expired: bool,
} }
const AUTH_PACKET_SIZE: usize = 307; const AUTH_PACKET_SIZE: usize = 307;
@ -84,6 +86,7 @@ impl Handshake {
remote_nonce: H256::new(), remote_nonce: H256::new(),
auth_cipher: Bytes::new(), auth_cipher: Bytes::new(),
ack_cipher: Bytes::new(), ack_cipher: Bytes::new(),
expired: false,
}) })
} }
@ -97,6 +100,16 @@ impl Handshake {
self.connection.token() self.connection.token()
} }
/// Mark this handshake as inactive to be deleted lated.
pub fn set_expired(&mut self) {
self.expired = true;
}
/// Check if this handshake is expired.
pub fn expired(&self) -> bool {
self.expired
}
/// Start a handhsake /// Start a handhsake
pub fn start<Message>(&mut self, io: &IoContext<Message>, host: &HostInfo, originated: bool) -> Result<(), UtilError> where Message: Send + Clone{ pub fn start<Message>(&mut self, io: &IoContext<Message>, host: &HostInfo, originated: bool) -> Result<(), UtilError> where Message: Send + Clone{
self.originated = originated; self.originated = originated;
@ -118,6 +131,7 @@ impl Handshake {
/// Readable IO handler. Drives the state change. /// Readable IO handler. Drives the state change.
pub fn readable<Message>(&mut self, io: &IoContext<Message>, host: &HostInfo) -> Result<(), UtilError> where Message: Send + Clone { pub fn readable<Message>(&mut self, io: &IoContext<Message>, host: &HostInfo) -> Result<(), UtilError> where Message: Send + Clone {
if !self.expired() {
io.clear_timer(self.connection.token).unwrap(); io.clear_timer(self.connection.token).unwrap();
match self.state { match self.state {
HandshakeState::ReadingAuth => { HandshakeState::ReadingAuth => {
@ -138,27 +152,35 @@ impl Handshake {
if self.state != HandshakeState::StartSession { if self.state != HandshakeState::StartSession {
try!(io.update_registration(self.connection.token)); try!(io.update_registration(self.connection.token));
} }
}
Ok(()) Ok(())
} }
/// Writabe IO handler. /// Writabe IO handler.
pub fn writable<Message>(&mut self, io: &IoContext<Message>, _host: &HostInfo) -> Result<(), UtilError> where Message: Send + Clone { pub fn writable<Message>(&mut self, io: &IoContext<Message>, _host: &HostInfo) -> Result<(), UtilError> where Message: Send + Clone {
if !self.expired() {
io.clear_timer(self.connection.token).unwrap(); io.clear_timer(self.connection.token).unwrap();
try!(self.connection.writable()); try!(self.connection.writable());
if self.state != HandshakeState::StartSession { if self.state != HandshakeState::StartSession {
io.update_registration(self.connection.token).unwrap(); io.update_registration(self.connection.token).unwrap();
} }
}
Ok(()) Ok(())
} }
/// Register the socket with the event loop /// Register the socket with the event loop
pub fn register_socket<Host:Handler<Timeout=Token>>(&self, reg: Token, event_loop: &mut EventLoop<Host>) -> Result<(), UtilError> { pub fn register_socket<Host:Handler<Timeout=Token>>(&self, reg: Token, event_loop: &mut EventLoop<Host>) -> Result<(), UtilError> {
if !self.expired() {
try!(self.connection.register_socket(reg, event_loop)); try!(self.connection.register_socket(reg, event_loop));
}
Ok(()) Ok(())
} }
/// Update socket registration with the event loop.
pub fn update_socket<Host:Handler<Timeout=Token>>(&self, reg: Token, event_loop: &mut EventLoop<Host>) -> Result<(), UtilError> { pub fn update_socket<Host:Handler<Timeout=Token>>(&self, reg: Token, event_loop: &mut EventLoop<Host>) -> Result<(), UtilError> {
if !self.expired() {
try!(self.connection.update_socket(reg, event_loop)); try!(self.connection.update_socket(reg, event_loop));
}
Ok(()) Ok(())
} }

View File

@ -190,11 +190,11 @@ impl<'s, Message> NetworkContext<'s, Message> where Message: Send + Sync + Clone
let session = { self.sessions.read().unwrap().get(peer).cloned() }; let session = { self.sessions.read().unwrap().get(peer).cloned() };
if let Some(session) = session { if let Some(session) = session {
session.lock().unwrap().deref_mut().send_packet(self.protocol, packet_id as u8, &data).unwrap_or_else(|e| { session.lock().unwrap().deref_mut().send_packet(self.protocol, packet_id as u8, &data).unwrap_or_else(|e| {
warn!(target: "net", "Send error: {:?}", e); warn!(target: "network", "Send error: {:?}", e);
}); //TODO: don't copy vector data }); //TODO: don't copy vector data
try!(self.io.update_registration(peer)); try!(self.io.update_registration(peer));
} else { } else {
trace!(target: "net", "Send: Peer no longer exist") trace!(target: "network", "Send: Peer no longer exist")
} }
Ok(()) Ok(())
} }
@ -470,18 +470,18 @@ impl<Message> Host<Message> where Message: Send + Sync + Clone {
.take(min(MAX_HANDSHAKES_PER_ROUND, handshake_limit - handshake_count)) { .take(min(MAX_HANDSHAKES_PER_ROUND, handshake_limit - handshake_count)) {
self.connect_peer(&id, io); self.connect_peer(&id, io);
} }
debug!(target: "net", "Connecting peers: {} sessions, {} pending", self.session_count(), self.handshake_count()); debug!(target: "network", "Connecting peers: {} sessions, {} pending", self.session_count(), self.handshake_count());
} }
#[cfg_attr(feature="dev", allow(single_match))] #[cfg_attr(feature="dev", allow(single_match))]
fn connect_peer(&self, id: &NodeId, io: &IoContext<NetworkIoMessage<Message>>) { fn connect_peer(&self, id: &NodeId, io: &IoContext<NetworkIoMessage<Message>>) {
if self.have_session(id) if self.have_session(id)
{ {
trace!("Aborted connect. Node already connected."); trace!(target: "network", "Aborted connect. Node already connected.");
return; return;
} }
if self.connecting_to(id) { if self.connecting_to(id) {
trace!("Aborted connect. Node already connecting."); trace!(target: "network", "Aborted connect. Node already connecting.");
return; return;
} }
@ -493,7 +493,7 @@ impl<Message> Host<Message> where Message: Send + Sync + Clone {
node.endpoint.address node.endpoint.address
} }
else { else {
debug!("Connection to expired node aborted"); debug!(target: "network", "Connection to expired node aborted");
return; return;
} }
}; };
@ -515,16 +515,16 @@ impl<Message> Host<Message> where Message: Send + Sync + Clone {
if handshakes.insert_with(|token| { if handshakes.insert_with(|token| {
let mut handshake = Handshake::new(token, id, socket, &nonce, self.stats.clone()).expect("Can't create handshake"); let mut handshake = Handshake::new(token, id, socket, &nonce, self.stats.clone()).expect("Can't create handshake");
handshake.start(io, &self.info.read().unwrap(), id.is_some()).and_then(|_| io.register_stream(token)).unwrap_or_else (|e| { handshake.start(io, &self.info.read().unwrap(), id.is_some()).and_then(|_| io.register_stream(token)).unwrap_or_else (|e| {
debug!(target: "net", "Handshake create error: {:?}", e); debug!(target: "network", "Handshake create error: {:?}", e);
}); });
Arc::new(Mutex::new(handshake)) Arc::new(Mutex::new(handshake))
}).is_none() { }).is_none() {
debug!("Max handshakes reached"); debug!(target: "network", "Max handshakes reached");
} }
} }
fn accept(&self, io: &IoContext<NetworkIoMessage<Message>>) { fn accept(&self, io: &IoContext<NetworkIoMessage<Message>>) {
trace!(target: "net", "accept"); trace!(target: "network", "Accepting incoming connection");
loop { loop {
let socket = match self.tcp_listener.lock().unwrap().accept() { let socket = match self.tcp_listener.lock().unwrap().accept() {
Ok(None) => break, Ok(None) => break,
@ -544,7 +544,7 @@ impl<Message> Host<Message> where Message: Send + Sync + Clone {
if let Some(handshake) = handshake { if let Some(handshake) = handshake {
let mut h = handshake.lock().unwrap(); let mut h = handshake.lock().unwrap();
if let Err(e) = h.writable(io, &self.info.read().unwrap()) { if let Err(e) = h.writable(io, &self.info.read().unwrap()) {
debug!(target: "net", "Handshake write error: {}:{:?}", token, e); trace!(target: "network", "Handshake write error: {}: {:?}", token, e);
} }
} }
} }
@ -554,9 +554,9 @@ impl<Message> Host<Message> where Message: Send + Sync + Clone {
if let Some(session) = session { if let Some(session) = session {
let mut s = session.lock().unwrap(); let mut s = session.lock().unwrap();
if let Err(e) = s.writable(io, &self.info.read().unwrap()) { if let Err(e) = s.writable(io, &self.info.read().unwrap()) {
debug!(target: "net", "Session write error: {}:{:?}", token, e); trace!(target: "network", "Session write error: {}: {:?}", token, e);
} }
io.update_registration(token).unwrap_or_else(|e| debug!(target: "net", "Session registration error: {:?}", e)); io.update_registration(token).unwrap_or_else(|e| debug!(target: "network", "Session registration error: {:?}", e));
} }
} }
@ -571,7 +571,7 @@ impl<Message> Host<Message> where Message: Send + Sync + Clone {
if let Some(handshake) = handshake { if let Some(handshake) = handshake {
let mut h = handshake.lock().unwrap(); let mut h = handshake.lock().unwrap();
if let Err(e) = h.readable(io, &self.info.read().unwrap()) { if let Err(e) = h.readable(io, &self.info.read().unwrap()) {
debug!(target: "net", "Handshake read error: {}:{:?}", token, e); debug!(target: "network", "Handshake read error: {}: {:?}", token, e);
kill = true; kill = true;
} }
if h.done() { if h.done() {
@ -583,9 +583,9 @@ impl<Message> Host<Message> where Message: Send + Sync + Clone {
return; return;
} else if create_session { } else if create_session {
self.start_session(token, io); self.start_session(token, io);
io.update_registration(token).unwrap_or_else(|e| debug!(target: "net", "Session registration error: {:?}", e)); return;
} }
io.update_registration(token).unwrap_or_else(|e| debug!(target: "net", "Token registration error: {:?}", e)); io.update_registration(token).unwrap_or_else(|e| debug!(target: "network", "Token registration error: {:?}", e));
} }
fn session_readable(&self, token: StreamToken, io: &IoContext<NetworkIoMessage<Message>>) { fn session_readable(&self, token: StreamToken, io: &IoContext<NetworkIoMessage<Message>>) {
@ -597,7 +597,7 @@ impl<Message> Host<Message> where Message: Send + Sync + Clone {
let mut s = session.lock().unwrap(); let mut s = session.lock().unwrap();
match s.readable(io, &self.info.read().unwrap()) { match s.readable(io, &self.info.read().unwrap()) {
Err(e) => { Err(e) => {
debug!(target: "net", "Session read error: {}:{:?}", token, e); debug!(target: "network", "Session read error: {}: {:?}", token, e);
kill = true; kill = true;
}, },
Ok(SessionData::Ready) => { Ok(SessionData::Ready) => {
@ -613,7 +613,7 @@ impl<Message> Host<Message> where Message: Send + Sync + Clone {
packet_id, packet_id,
}) => { }) => {
match self.handlers.read().unwrap().get(protocol) { match self.handlers.read().unwrap().get(protocol) {
None => { warn!(target: "net", "No handler found for protocol: {:?}", protocol) }, None => { warn!(target: "network", "No handler found for protocol: {:?}", protocol) },
Some(_) => packet_data = Some((protocol, packet_id, data)), Some(_) => packet_data = Some((protocol, packet_id, data)),
} }
}, },
@ -631,7 +631,7 @@ impl<Message> Host<Message> where Message: Send + Sync + Clone {
let h = self.handlers.read().unwrap().get(p).unwrap().clone(); let h = self.handlers.read().unwrap().get(p).unwrap().clone();
h.read(&NetworkContext::new(io, p, Some(token), self.sessions.clone()), &token, packet_id, &data[1..]); h.read(&NetworkContext::new(io, p, Some(token), self.sessions.clone()), &token, packet_id, &data[1..]);
} }
io.update_registration(token).unwrap_or_else(|e| debug!(target: "net", "Token registration error: {:?}", e)); io.update_registration(token).unwrap_or_else(|e| debug!(target: "network", "Token registration error: {:?}", e));
} }
fn start_session(&self, token: StreamToken, io: &IoContext<NetworkIoMessage<Message>>) { fn start_session(&self, token: StreamToken, io: &IoContext<NetworkIoMessage<Message>>) {
@ -642,26 +642,25 @@ impl<Message> Host<Message> where Message: Send + Sync + Clone {
// turn a handshake into a session // turn a handshake into a session
let mut sessions = self.sessions.write().unwrap(); let mut sessions = self.sessions.write().unwrap();
let mut h = handshakes.remove(token).unwrap(); let mut h = handshakes.get_mut(token).unwrap().lock().unwrap();
// wait for other threads to stop using it if h.expired {
{ return;
while Arc::get_mut(&mut h).is_none() {
h.lock().ok();
} }
}
let h = Arc::try_unwrap(h).ok().unwrap().into_inner().unwrap();
let originated = h.originated; let originated = h.originated;
let mut session = match Session::new(h, &self.info.read().unwrap()) { let mut session = match Session::new(&mut h, &self.info.read().unwrap()) {
Ok(s) => s, Ok(s) => s,
Err(e) => { Err(e) => {
debug!("Session creation error: {:?}", e); debug!(target: "network", "Session creation error: {:?}", e);
return; return;
} }
}; };
let result = sessions.insert_with(move |session_token| { let result = sessions.insert_with(move |session_token| {
session.set_token(session_token); session.set_token(session_token);
io.update_registration(session_token).expect("Error updating session registration"); io.deregister_stream(token).expect("Error deleting handshake registration");
h.set_expired();
io.register_stream(session_token).expect("Error creating session registration");
self.stats.inc_sessions(); self.stats.inc_sessions();
trace!(target: "network", "Creating session {} -> {}", token, session_token);
if !originated { if !originated {
// Add it no node table // Add it no node table
if let Ok(address) = session.remote_addr() { if let Ok(address) = session.remote_addr() {
@ -690,13 +689,19 @@ impl<Message> Host<Message> where Message: Send + Sync + Clone {
FIRST_HANDSHAKE ... LAST_HANDSHAKE => { FIRST_HANDSHAKE ... LAST_HANDSHAKE => {
let handshakes = self.handshakes.write().unwrap(); let handshakes = self.handshakes.write().unwrap();
if let Some(handshake) = handshakes.get(token).cloned() { if let Some(handshake) = handshakes.get(token).cloned() {
failure_id = Some(handshake.lock().unwrap().id().clone()); let mut handshake = handshake.lock().unwrap();
if !handshake.expired() {
handshake.set_expired();
failure_id = Some(handshake.id().clone());
io.deregister_stream(token).expect("Error deregistering stream");
}
} }
}, },
FIRST_SESSION ... LAST_SESSION => { FIRST_SESSION ... LAST_SESSION => {
let sessions = self.sessions.write().unwrap(); let sessions = self.sessions.write().unwrap();
if let Some(session) = sessions.get(token).cloned() { if let Some(session) = sessions.get(token).cloned() {
let s = session.lock().unwrap(); let mut s = session.lock().unwrap();
if !s.expired() {
if s.is_ready() { if s.is_ready() {
for (p, _) in self.handlers.read().unwrap().iter() { for (p, _) in self.handlers.read().unwrap().iter() {
if s.have_capability(p) { if s.have_capability(p) {
@ -704,12 +709,14 @@ impl<Message> Host<Message> where Message: Send + Sync + Clone {
} }
} }
} }
s.set_expired();
failure_id = Some(s.id().clone()); failure_id = Some(s.id().clone());
io.deregister_stream(token).expect("Error deregistering stream");
}
} }
}, },
_ => {}, _ => {},
} }
io.deregister_stream(token).expect("Error deregistering stream");
if let Some(id) = failure_id { if let Some(id) = failure_id {
if remote { if remote {
self.nodes.write().unwrap().note_failure(&id); self.nodes.write().unwrap().note_failure(&id);
@ -764,11 +771,11 @@ impl<Message> IoHandler<NetworkIoMessage<Message>> for Host<Message> where Messa
} }
fn stream_hup(&self, io: &IoContext<NetworkIoMessage<Message>>, stream: StreamToken) { fn stream_hup(&self, io: &IoContext<NetworkIoMessage<Message>>, stream: StreamToken) {
trace!(target: "net", "Hup: {}", stream); trace!(target: "network", "Hup: {}", stream);
match stream { match stream {
FIRST_SESSION ... LAST_SESSION => self.connection_closed(stream, io), FIRST_SESSION ... LAST_SESSION => self.connection_closed(stream, io),
FIRST_HANDSHAKE ... LAST_HANDSHAKE => self.connection_closed(stream, io), FIRST_HANDSHAKE ... LAST_HANDSHAKE => self.connection_closed(stream, io),
_ => warn!(target: "net", "Unexpected hup"), _ => warn!(target: "network", "Unexpected hup"),
}; };
} }
@ -818,7 +825,7 @@ impl<Message> IoHandler<NetworkIoMessage<Message>> for Host<Message> where Messa
}, },
_ => match self.timers.read().unwrap().get(&token).cloned() { _ => match self.timers.read().unwrap().get(&token).cloned() {
Some(timer) => match self.handlers.read().unwrap().get(timer.protocol).cloned() { Some(timer) => match self.handlers.read().unwrap().get(timer.protocol).cloned() {
None => { warn!(target: "net", "No handler found for protocol: {:?}", timer.protocol) }, None => { warn!(target: "network", "No handler found for protocol: {:?}", timer.protocol) },
Some(h) => { h.timeout(&NetworkContext::new(io, timer.protocol, None, self.sessions.clone()), timer.token); } Some(h) => { h.timeout(&NetworkContext::new(io, timer.protocol, None, self.sessions.clone()), timer.token); }
}, },
None => { warn!("Unknown timer token: {}", token); } // timer is not registerd through us None => { warn!("Unknown timer token: {}", token); } // timer is not registerd through us
@ -874,7 +881,10 @@ impl<Message> IoHandler<NetworkIoMessage<Message>> for Host<Message> where Messa
fn register_stream(&self, stream: StreamToken, reg: Token, event_loop: &mut EventLoop<IoManager<NetworkIoMessage<Message>>>) { fn register_stream(&self, stream: StreamToken, reg: Token, event_loop: &mut EventLoop<IoManager<NetworkIoMessage<Message>>>) {
match stream { match stream {
FIRST_SESSION ... LAST_SESSION => { FIRST_SESSION ... LAST_SESSION => {
warn!("Unexpected session stream registration"); let session = { self.sessions.read().unwrap().get(stream).cloned() };
if let Some(session) = session {
session.lock().unwrap().register_socket(reg, event_loop).expect("Error registering socket");
}
} }
FIRST_HANDSHAKE ... LAST_HANDSHAKE => { FIRST_HANDSHAKE ... LAST_HANDSHAKE => {
let connection = { self.handshakes.read().unwrap().get(stream).cloned() }; let connection = { self.handshakes.read().unwrap().get(stream).cloned() };

View File

@ -41,6 +41,8 @@ pub struct Session {
connection: EncryptedConnection, connection: EncryptedConnection,
/// Session ready flag. Set after successfull Hello packet exchange /// Session ready flag. Set after successfull Hello packet exchange
had_hello: bool, had_hello: bool,
/// Session is no longer active flag.
expired: bool,
ping_time_ns: u64, ping_time_ns: u64,
pong_time_ns: Option<u64>, pong_time_ns: Option<u64>,
} }
@ -109,8 +111,9 @@ const PACKET_USER: u8 = 0x10;
const PACKET_LAST: u8 = 0x7f; const PACKET_LAST: u8 = 0x7f;
impl Session { impl Session {
/// Create a new session out of comepleted handshake. Consumes handshake object. /// Create a new session out of comepleted handshake. This clones the handshake connection object
pub fn new(h: Handshake, host: &HostInfo) -> Result<Session, UtilError> { /// and leaves the handhsake in limbo to be deregistered from the event loop.
pub fn new(h: &mut Handshake, host: &HostInfo) -> Result<Session, UtilError> {
let id = h.id.clone(); let id = h.id.clone();
let connection = try!(EncryptedConnection::new(h)); let connection = try!(EncryptedConnection::new(h));
let mut session = Session { let mut session = Session {
@ -125,6 +128,7 @@ impl Session {
}, },
ping_time_ns: 0, ping_time_ns: 0,
pong_time_ns: None, pong_time_ns: None,
expired: false,
}; };
try!(session.write_hello(host)); try!(session.write_hello(host));
try!(session.send_ping()); try!(session.send_ping());
@ -141,6 +145,16 @@ impl Session {
self.had_hello self.had_hello
} }
/// Mark this session as inactive to be deleted lated.
pub fn set_expired(&mut self) {
self.expired = true;
}
/// Check if this session is expired.
pub fn expired(&self) -> bool {
self.expired
}
/// Replace socket token /// Replace socket token
pub fn set_token(&mut self, token: StreamToken) { pub fn set_token(&mut self, token: StreamToken) {
self.connection.set_token(token); self.connection.set_token(token);
@ -153,6 +167,9 @@ impl Session {
/// Readable IO handler. Returns packet data if available. /// Readable IO handler. Returns packet data if available.
pub fn readable<Message>(&mut self, io: &IoContext<Message>, host: &HostInfo) -> Result<SessionData, UtilError> where Message: Send + Sync + Clone { pub fn readable<Message>(&mut self, io: &IoContext<Message>, host: &HostInfo) -> Result<SessionData, UtilError> where Message: Send + Sync + Clone {
if self.expired() {
return Ok(SessionData::None)
}
match try!(self.connection.readable(io)) { match try!(self.connection.readable(io)) {
Some(data) => Ok(try!(self.read_packet(data, host))), Some(data) => Ok(try!(self.read_packet(data, host))),
None => Ok(SessionData::None) None => Ok(SessionData::None)
@ -161,6 +178,9 @@ impl Session {
/// Writable IO handler. Sends pending packets. /// Writable IO handler. Sends pending packets.
pub fn writable<Message>(&mut self, io: &IoContext<Message>, _host: &HostInfo) -> Result<(), UtilError> where Message: Send + Sync + Clone { pub fn writable<Message>(&mut self, io: &IoContext<Message>, _host: &HostInfo) -> Result<(), UtilError> where Message: Send + Sync + Clone {
if self.expired() {
return Ok(())
}
self.connection.writable(io) self.connection.writable(io)
} }
@ -169,8 +189,20 @@ impl Session {
self.info.capabilities.iter().any(|c| c.protocol == protocol) self.info.capabilities.iter().any(|c| c.protocol == protocol)
} }
/// Register the session socket with the event loop
pub fn register_socket<Host:Handler<Timeout = Token>>(&self, reg: Token, event_loop: &mut EventLoop<Host>) -> Result<(), UtilError> {
if self.expired() {
return Ok(());
}
try!(self.connection.register_socket(reg, event_loop));
Ok(())
}
/// Update registration with the event loop. Should be called at the end of the IO handler. /// Update registration with the event loop. Should be called at the end of the IO handler.
pub fn update_socket<Host:Handler>(&self, reg:Token, event_loop: &mut EventLoop<Host>) -> Result<(), UtilError> { pub fn update_socket<Host:Handler>(&self, reg:Token, event_loop: &mut EventLoop<Host>) -> Result<(), UtilError> {
if self.expired() {
return Ok(());
}
self.connection.update_socket(reg, event_loop) self.connection.update_socket(reg, event_loop)
} }

View File

@ -74,7 +74,7 @@ impl NetworkProtocolHandler<TestProtocolMessage> for TestProtocol {
} }
fn connected(&self, io: &NetworkContext<TestProtocolMessage>, peer: &PeerId) { fn connected(&self, io: &NetworkContext<TestProtocolMessage>, peer: &PeerId) {
assert!(io.peer_info(*peer).contains("parity")); assert!(io.peer_info(*peer).contains("Parity"));
if self.drop_session { if self.drop_session {
io.disconnect_peer(*peer) io.disconnect_peer(*peer)
} else { } else {