diff --git a/src/network/connection.rs b/src/network/connection.rs index 2a8025b83..131aae318 100644 --- a/src/network/connection.rs +++ b/src/network/connection.rs @@ -154,7 +154,7 @@ pub struct EncryptedConnection { read_state: EncryptedConnectionState, idle_timeout: Option, protocol_id: u16, - payload_len: u32, + payload_len: usize, } impl EncryptedConnection { @@ -223,7 +223,7 @@ impl EncryptedConnection { self.egress_mac.clone().finalize(&mut packet[16..32]); self.encoder.encrypt(&mut RefReadBuffer::new(&payload), &mut RefWriteBuffer::new(&mut packet[32..(32 + len)]), padding == 0).expect("Invalid length or padding"); if padding != 0 { - let pad = [08; 16]; + let pad = [0u8; 16]; self.encoder.encrypt(&mut RefReadBuffer::new(&pad[0..padding]), &mut RefWriteBuffer::new(&mut packet[(32 + len)..(32 + len + padding)]), true).expect("Invalid length or padding"); } self.egress_mac.update(&packet[32..(32 + len + padding)]); @@ -252,7 +252,7 @@ impl EncryptedConnection { let header_rlp = UntrustedRlp::new(&hdec[3..6]); let protocol_id = try!(header_rlp.val_at::(0)); - self.payload_len = length; + self.payload_len = length as usize; self.protocol_id = protocol_id; self.read_state = EncryptedConnectionState::Payload; @@ -264,7 +264,7 @@ impl EncryptedConnection { fn read_payload(&mut self, payload: &[u8]) -> Result { let padding = (16 - (self.payload_len % 16)) % 16; - let full_length = (self.payload_len + padding + 16) as usize; + let full_length = self.payload_len + padding + 16; if payload.len() != full_length { return Err(Error::Auth); } @@ -277,9 +277,10 @@ impl EncryptedConnection { return Err(Error::Auth); } - let mut packet = vec![0u8; self.payload_len as usize]; - self.decoder.decrypt(&mut RefReadBuffer::new(&payload[0..(full_length - 16)]), &mut RefWriteBuffer::new(&mut packet), false).expect("Invalid length or padding"); - packet.resize(self.payload_len as usize, 0u8); + let mut packet = vec![0u8; self.payload_len]; + self.decoder.decrypt(&mut RefReadBuffer::new(&payload[0..self.payload_len]), &mut RefWriteBuffer::new(&mut packet), false).expect("Invalid length or padding"); + let mut pad_buf = [0u8; 16]; + self.decoder.decrypt(&mut RefReadBuffer::new(&payload[self.payload_len..(payload.len() - 16)]), &mut RefWriteBuffer::new(&mut pad_buf), false).expect("Invalid length or padding"); Ok(Packet { protocol: self.protocol_id, data: packet @@ -299,7 +300,6 @@ impl EncryptedConnection { pub fn readable(&mut self, event_loop: &mut EventLoop) -> Result, Error> { self.idle_timeout.map(|t| event_loop.clear_timeout(t)); - try!(self.connection.reregister(event_loop)); match self.read_state { EncryptedConnectionState::Header => { match try!(self.connection.readable()) { @@ -326,7 +326,6 @@ impl EncryptedConnection { pub fn writable(&mut self, event_loop: &mut EventLoop) -> Result<(), Error> { self.idle_timeout.map(|t| event_loop.clear_timeout(t)); try!(self.connection.writable()); - try!(self.connection.reregister(event_loop)); Ok(()) } @@ -337,6 +336,12 @@ impl EncryptedConnection { try!(self.connection.reregister(event_loop)); Ok(()) } + + pub fn reregister(&mut self, event_loop: &mut EventLoop) -> Result<(), Error> { + try!(self.connection.reregister(event_loop)); + Ok(()) + } + } #[test] diff --git a/src/network/host.rs b/src/network/host.rs index 9e2b3e101..b9ca0152b 100644 --- a/src/network/host.rs +++ b/src/network/host.rs @@ -295,11 +295,10 @@ impl Host { pub fn start(event_loop: &mut EventLoop) -> Result<(), Error> { let config = NetworkConfiguration::new(); /* - match ::ifaces::Interface::get_all().unwrap().into_iter().filter(|x| x.kind == ::ifaces::Kind::Packet && x.addr.is_some()).next() { - Some(iface) => config.public_address = iface.addr.unwrap(), - None => warn!("No public network interface"), - } - */ + match ::ifaces::Interface::get_all().unwrap().into_iter().filter(|x| x.kind == ::ifaces::Kind::Packet && x.addr.is_some()).next() { + Some(iface) => config.public_address = iface.addr.unwrap(), + None => warn!("No public network interface"), + */ let addr = config.listen_address; // Setup the server socket @@ -487,8 +486,17 @@ impl Host { if create_session { self.start_session(token, event_loop); } + match self.connections.get_mut(token) { + Some(&mut ConnectionEntry::Session(ref mut s)) => { + s.reregister(event_loop).unwrap_or_else(|e| debug!(target: "net", "Session registration error: {:?}", e)); + }, + _ => (), + } } + fn connection_closed(&mut self, token: Token, event_loop: &mut EventLoop) { + self.kill_connection(token, event_loop); + } fn connection_readable(&mut self, token: Token, event_loop: &mut EventLoop) { let mut kill = false; @@ -549,6 +557,12 @@ impl Host { h.read(&mut HostIo::new(p, Some(token), event_loop, &mut self.connections, &mut self.timers), &token.as_usize(), packet_id, &data[1..]); } + match self.connections.get_mut(token) { + Some(&mut ConnectionEntry::Session(ref mut s)) => { + s.reregister(event_loop).unwrap_or_else(|e| debug!(target: "net", "Session registration error: {:?}", e)); + }, + _ => (), + } } fn start_session(&mut self, token: Token, event_loop: &mut EventLoop) { @@ -570,7 +584,23 @@ impl Host { self.kill_connection(token, event_loop) } - fn kill_connection(&mut self, token: Token, _event_loop: &mut EventLoop) { + fn kill_connection(&mut self, token: Token, event_loop: &mut EventLoop) { + let mut to_disconnect: Vec = Vec::new(); + match self.connections.get_mut(token) { + Some(&mut ConnectionEntry::Handshake(_)) => (), // just abandon handshake + Some(&mut ConnectionEntry::Session(ref mut s)) if s.is_ready() => { + for (p, _) in self.handlers.iter_mut() { + if s.have_capability(p) { + to_disconnect.push(p); + } + } + }, + _ => (), + } + for p in to_disconnect { + let mut h = self.handlers.get_mut(p).unwrap(); + h.disconnected(&mut HostIo::new(p, Some(token), event_loop, &mut self.connections, &mut self.timers), &token.as_usize()); + } self.connections.remove(token); } } @@ -580,7 +610,14 @@ impl Handler for Host { type Message = HostMessage; fn ready(&mut self, event_loop: &mut EventLoop, token: Token, events: EventSet) { - if events.is_readable() { + if events.is_hup() { + trace!(target: "net", "hup"); + match token.as_usize() { + FIRST_CONNECTION ... LAST_CONNECTION => self.connection_closed(token, event_loop), + _ => warn!(target: "net", "Unexpected hup"), + }; + } + else if events.is_readable() { match token.as_usize() { TCP_ACCEPT => self.accept(event_loop), IDLE => self.maintain_network(event_loop), diff --git a/src/network/session.rs b/src/network/session.rs index 720902150..8e89ff2c4 100644 --- a/src/network/session.rs +++ b/src/network/session.rs @@ -83,6 +83,10 @@ impl Session { Ok(session) } + pub fn is_ready(&self) -> bool { + self.had_hello + } + pub fn readable(&mut self, event_loop: &mut EventLoop, host: &HostInfo) -> Result { match try!(self.connection.readable(event_loop)) { Some(data) => self.read_packet(data, host), @@ -98,6 +102,10 @@ impl Session { self.info.capabilities.iter().any(|c| c.protocol == protocol) } + pub fn reregister(&mut self, event_loop: &mut EventLoop) -> Result<(), Error> { + self.connection.reregister(event_loop) + } + pub fn send_packet(&mut self, protocol: &str, packet_id: u8, data: &[u8]) -> Result<(), Error> { let mut i = 0usize; while protocol != self.info.capabilities[i].protocol { @@ -159,7 +167,7 @@ impl Session { fn write_hello(&mut self, host: &HostInfo) -> Result<(), Error> { let mut rlp = RlpStream::new(); - rlp.append(&(PACKET_HELLO as u32)); + rlp.append_raw(&[PACKET_HELLO as u8], 0); rlp.append_list(5) .append(&host.protocol_version) .append(&host.client_version) @@ -217,11 +225,11 @@ impl Session { } fn write_ping(&mut self) -> Result<(), Error> { - self.send(try!(Session::prepare(PACKET_PING, 0))) + self.send(try!(Session::prepare(PACKET_PING))) } fn write_pong(&mut self) -> Result<(), Error> { - self.send(try!(Session::prepare(PACKET_PONG, 0))) + self.send(try!(Session::prepare(PACKET_PONG))) } fn disconnect(&mut self, reason: DisconnectReason) -> Error { @@ -233,10 +241,10 @@ impl Session { Error::Disconnect(reason) } - fn prepare(packet_id: u8, items: usize) -> Result { - let mut rlp = RlpStream::new_list(1); + fn prepare(packet_id: u8) -> Result { + let mut rlp = RlpStream::new(); rlp.append(&(packet_id as u32)); - rlp.append_list(items); + rlp.append_list(0); Ok(rlp) } diff --git a/src/rlp/untrusted_rlp.rs b/src/rlp/untrusted_rlp.rs index 5a12cbc5e..452a198bb 100644 --- a/src/rlp/untrusted_rlp.rs +++ b/src/rlp/untrusted_rlp.rs @@ -188,7 +188,7 @@ impl<'a, 'view> View<'a, 'view> for UntrustedRlp<'a> where 'a: 'view { } fn val_at(&self, index: usize) -> Result where T: Decodable { - self.at(index).unwrap().as_val() + try!(self.at(index)).as_val() } }