Networking refactoring (#1172)

* Networking refactoring

* Make sure the same socket is reused

* Safer atomic ordering

* Replaced eq with ==
This commit is contained in:
Arkadiy Paronyan 2016-06-02 11:49:56 +02:00 committed by Gav Wood
parent e170916563
commit 8596a347ea
6 changed files with 209 additions and 309 deletions

7
Cargo.lock generated
View File

@ -390,7 +390,7 @@ dependencies = [
"rustc_version 0.1.7 (registry+https://github.com/rust-lang/crates.io-index)", "rustc_version 0.1.7 (registry+https://github.com/rust-lang/crates.io-index)",
"serde 0.7.0 (registry+https://github.com/rust-lang/crates.io-index)", "serde 0.7.0 (registry+https://github.com/rust-lang/crates.io-index)",
"sha3 0.1.0", "sha3 0.1.0",
"slab 0.1.3 (registry+https://github.com/rust-lang/crates.io-index)", "slab 0.2.0 (registry+https://github.com/rust-lang/crates.io-index)",
"target_info 0.1.0 (registry+https://github.com/rust-lang/crates.io-index)", "target_info 0.1.0 (registry+https://github.com/rust-lang/crates.io-index)",
"time 0.1.35 (registry+https://github.com/rust-lang/crates.io-index)", "time 0.1.35 (registry+https://github.com/rust-lang/crates.io-index)",
"tiny-keccak 1.0.5 (registry+https://github.com/rust-lang/crates.io-index)", "tiny-keccak 1.0.5 (registry+https://github.com/rust-lang/crates.io-index)",
@ -1176,6 +1176,11 @@ name = "slab"
version = "0.1.3" version = "0.1.3"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
[[package]]
name = "slab"
version = "0.2.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
[[package]] [[package]]
name = "solicit" name = "solicit"
version = "0.4.4" version = "0.4.4"

View File

@ -25,7 +25,7 @@ elastic-array = "0.4"
heapsize = "0.3" heapsize = "0.3"
itertools = "0.4" itertools = "0.4"
crossbeam = "0.2" crossbeam = "0.2"
slab = "0.1" slab = "0.2"
sha3 = { path = "sha3" } sha3 = { path = "sha3" }
serde = "0.7.0" serde = "0.7.0"
clippy = { version = "0.0.69", optional = true} clippy = { version = "0.0.69", optional = true}

View File

@ -170,16 +170,16 @@ impl Connection {
self.token self.token
} }
/// Replace socket token
pub fn set_token(&mut self, token: StreamToken) {
self.token = token;
}
/// Get remote peer address /// Get remote peer address
pub fn remote_addr(&self) -> io::Result<SocketAddr> { pub fn remote_addr(&self) -> io::Result<SocketAddr> {
self.socket.peer_addr() self.socket.peer_addr()
} }
/// Get remote peer address string
pub fn remote_addr_str(&self) -> String {
self.socket.peer_addr().map(|a| a.to_string()).unwrap_or_else(|_| "Unknown".to_owned())
}
/// Clone this connection. Clears the receiving buffer of the returned connection. /// Clone this connection. Clears the receiving buffer of the returned connection.
pub fn try_clone(&self) -> io::Result<Self> { pub fn try_clone(&self) -> io::Result<Self> {
Ok(Connection { Ok(Connection {
@ -196,7 +196,7 @@ impl Connection {
/// Register this connection with the IO event loop. /// Register this connection with the IO event loop.
pub fn register_socket<Host: Handler>(&self, reg: Token, event_loop: &mut EventLoop<Host>) -> io::Result<()> { pub fn register_socket<Host: Handler>(&self, reg: Token, event_loop: &mut EventLoop<Host>) -> io::Result<()> {
trace!(target: "network", "connection register; token={:?}", reg); trace!(target: "network", "connection register; token={:?}", reg);
if let Err(e) = event_loop.register(&self.socket, reg, self.interest, PollOpt::edge() | PollOpt::oneshot()) { if let Err(e) = event_loop.register(&self.socket, reg, self.interest, PollOpt::edge() /* | PollOpt::oneshot() */) { // TODO: oneshot is broken on windows
trace!(target: "network", "Failed to register {:?}, {:?}", reg, e); trace!(target: "network", "Failed to register {:?}, {:?}", reg, e);
} }
Ok(()) Ok(())
@ -205,7 +205,7 @@ impl Connection {
/// Update connection registration. Should be called at the end of the IO handler. /// Update connection registration. Should be called at the end of the IO handler.
pub fn update_socket<Host: Handler>(&self, reg: Token, event_loop: &mut EventLoop<Host>) -> io::Result<()> { pub fn update_socket<Host: Handler>(&self, reg: Token, event_loop: &mut EventLoop<Host>) -> io::Result<()> {
trace!(target: "network", "connection reregister; token={:?}", reg); trace!(target: "network", "connection reregister; token={:?}", reg);
event_loop.reregister( &self.socket, reg, self.interest, PollOpt::edge() | PollOpt::oneshot()).or_else(|e| { event_loop.reregister( &self.socket, reg, self.interest, PollOpt::edge() /* | PollOpt::oneshot() */ ).or_else(|e| { // TODO: oneshot is broken on windows
trace!(target: "network", "Failed to reregister {:?}, {:?}", reg, e); trace!(target: "network", "Failed to reregister {:?}, {:?}", reg, e);
Ok(()) Ok(())
}) })
@ -246,7 +246,7 @@ enum EncryptedConnectionState {
/// https://github.com/ethereum/devp2p/blob/master/rlpx.md#framing /// https://github.com/ethereum/devp2p/blob/master/rlpx.md#framing
pub struct EncryptedConnection { pub struct EncryptedConnection {
/// Underlying tcp connection /// Underlying tcp connection
connection: Connection, pub connection: Connection,
/// Egress data encryptor /// Egress data encryptor
encoder: CtrMode<AesSafe256Encryptor>, encoder: CtrMode<AesSafe256Encryptor>,
/// Ingress data decryptor /// Ingress data decryptor
@ -266,27 +266,6 @@ pub struct EncryptedConnection {
} }
impl EncryptedConnection { impl EncryptedConnection {
/// Get socket token
pub fn token(&self) -> StreamToken {
self.connection.token
}
/// Replace socket token
pub fn set_token(&mut self, token: StreamToken) {
self.connection.set_token(token);
}
/// Get remote peer address
pub fn remote_addr(&self) -> io::Result<SocketAddr> {
self.connection.remote_addr()
}
/// Check if this connection has data to be sent.
pub fn is_sending(&self) -> bool {
self.connection.is_sending()
}
/// Create an encrypted connection out of the handshake. Consumes a handshake object. /// Create an encrypted connection out of the handshake. Consumes a handshake object.
pub fn new(handshake: &mut Handshake) -> Result<EncryptedConnection, UtilError> { pub fn new(handshake: &mut Handshake) -> Result<EncryptedConnection, UtilError> {
let shared = try!(crypto::ecdh::agree(handshake.ecdhe.secret(), &handshake.remote_ephemeral)); let shared = try!(crypto::ecdh::agree(handshake.ecdhe.secret(), &handshake.remote_ephemeral));
@ -323,8 +302,10 @@ impl EncryptedConnection {
ingress_mac.update(&mac_material); ingress_mac.update(&mac_material);
ingress_mac.update(if handshake.originated { &handshake.ack_cipher } else { &handshake.auth_cipher }); ingress_mac.update(if handshake.originated { &handshake.ack_cipher } else { &handshake.auth_cipher });
let old_connection = try!(handshake.connection.try_clone());
let connection = ::std::mem::replace(&mut handshake.connection, old_connection);
let mut enc = EncryptedConnection { let mut enc = EncryptedConnection {
connection: try!(handshake.connection.try_clone()), connection: connection,
encoder: encoder, encoder: encoder,
decoder: decoder, decoder: decoder,
mac_encoder: mac_encoder, mac_encoder: mac_encoder,
@ -463,24 +444,6 @@ impl EncryptedConnection {
try!(self.connection.writable()); try!(self.connection.writable());
Ok(()) Ok(())
} }
/// Register socket with the event lpop. This should be called at the end of the event loop.
pub fn register_socket<Host:Handler>(&self, reg: Token, event_loop: &mut EventLoop<Host>) -> Result<(), UtilError> {
try!(self.connection.register_socket(reg, event_loop));
Ok(())
}
/// Update connection registration. This should be called at the end of the event loop.
pub fn update_socket<Host:Handler>(&self, reg: Token, event_loop: &mut EventLoop<Host>) -> Result<(), UtilError> {
try!(self.connection.update_socket(reg, event_loop));
Ok(())
}
/// Delete connection registration. This should be called at the end of the event loop.
pub fn deregister_socket<Host:Handler>(&self, event_loop: &mut EventLoop<Host>) -> Result<(), UtilError> {
try!(self.connection.deregister_socket(event_loop));
Ok(())
}
} }
#[test] #[test]

View File

@ -16,7 +16,6 @@
use std::sync::Arc; use std::sync::Arc;
use rand::random; use rand::random;
use mio::*;
use mio::tcp::*; use mio::tcp::*;
use hash::*; use hash::*;
use rlp::*; use rlp::*;
@ -102,21 +101,6 @@ impl Handshake {
}) })
} }
/// Get id of the remote node if known
pub fn id(&self) -> &NodeId {
&self.id
}
/// Get stream token id
pub fn token(&self) -> StreamToken {
self.connection.token()
}
/// Mark this handshake as inactive to be deleted lated.
pub fn set_expired(&mut self) {
self.expired = true;
}
/// Check if this handshake is expired. /// Check if this handshake is expired.
pub fn expired(&self) -> bool { pub fn expired(&self) -> bool {
self.expired self.expired
@ -177,7 +161,7 @@ impl Handshake {
} }
/// Writabe IO handler. /// Writabe IO handler.
pub fn writable<Message>(&mut self, io: &IoContext<Message>, _host: &HostInfo) -> Result<(), UtilError> where Message: Send + Clone { pub fn writable<Message>(&mut self, io: &IoContext<Message>) -> Result<(), UtilError> where Message: Send + Clone {
if !self.expired() { if !self.expired() {
io.clear_timer(self.connection.token).unwrap(); io.clear_timer(self.connection.token).unwrap();
try!(self.connection.writable()); try!(self.connection.writable());
@ -188,28 +172,6 @@ impl Handshake {
Ok(()) Ok(())
} }
/// Register the socket with the event loop
pub fn register_socket<Host:Handler<Timeout=Token>>(&self, reg: Token, event_loop: &mut EventLoop<Host>) -> Result<(), UtilError> {
if !self.expired() {
try!(self.connection.register_socket(reg, event_loop));
}
Ok(())
}
/// Update socket registration with the event loop.
pub fn update_socket<Host:Handler<Timeout=Token>>(&self, reg: Token, event_loop: &mut EventLoop<Host>) -> Result<(), UtilError> {
if !self.expired() {
try!(self.connection.update_socket(reg, event_loop));
}
Ok(())
}
/// Delete registration
pub fn deregister_socket<Host:Handler>(&self, event_loop: &mut EventLoop<Host>) -> Result<(), UtilError> {
try!(self.connection.deregister_socket(event_loop));
Ok(())
}
fn set_auth(&mut self, host_secret: &Secret, sig: &[u8], remote_public: &[u8], remote_nonce: &[u8], remote_version: u64) -> Result<(), UtilError> { fn set_auth(&mut self, host_secret: &Secret, sig: &[u8], remote_public: &[u8], remote_nonce: &[u8], remote_version: u64) -> Result<(), UtilError> {
self.id.clone_from_slice(remote_public); self.id.clone_from_slice(remote_public);
self.remote_nonce.clone_from_slice(remote_nonce); self.remote_nonce.clone_from_slice(remote_nonce);
@ -222,7 +184,7 @@ impl Handshake {
/// Parse, validate and confirm auth message /// Parse, validate and confirm auth message
fn read_auth(&mut self, secret: &Secret, data: &[u8]) -> Result<(), UtilError> { fn read_auth(&mut self, secret: &Secret, data: &[u8]) -> Result<(), UtilError> {
trace!(target:"network", "Received handshake auth from {:?}", self.connection.socket.peer_addr()); trace!(target:"network", "Received handshake auth from {:?}", self.connection.remote_addr_str());
if data.len() != V4_AUTH_PACKET_SIZE { if data.len() != V4_AUTH_PACKET_SIZE {
debug!(target:"net", "Wrong auth packet size"); debug!(target:"net", "Wrong auth packet size");
return Err(From::from(NetworkError::BadProtocol)); return Err(From::from(NetworkError::BadProtocol));
@ -253,7 +215,7 @@ impl Handshake {
} }
fn read_auth_eip8(&mut self, secret: &Secret, data: &[u8]) -> Result<(), UtilError> { fn read_auth_eip8(&mut self, secret: &Secret, data: &[u8]) -> Result<(), UtilError> {
trace!(target:"network", "Received EIP8 handshake auth from {:?}", self.connection.socket.peer_addr()); trace!(target:"network", "Received EIP8 handshake auth from {:?}", self.connection.remote_addr_str());
self.auth_cipher.extend_from_slice(data); self.auth_cipher.extend_from_slice(data);
let auth = try!(ecies::decrypt(secret, &self.auth_cipher[0..2], &self.auth_cipher[2..])); let auth = try!(ecies::decrypt(secret, &self.auth_cipher[0..2], &self.auth_cipher[2..]));
let rlp = UntrustedRlp::new(&auth); let rlp = UntrustedRlp::new(&auth);
@ -268,7 +230,7 @@ impl Handshake {
/// Parse and validate ack message /// Parse and validate ack message
fn read_ack(&mut self, secret: &Secret, data: &[u8]) -> Result<(), UtilError> { fn read_ack(&mut self, secret: &Secret, data: &[u8]) -> Result<(), UtilError> {
trace!(target:"network", "Received handshake auth to {:?}", self.connection.socket.peer_addr()); trace!(target:"network", "Received handshake auth to {:?}", self.connection.remote_addr_str());
if data.len() != V4_ACK_PACKET_SIZE { if data.len() != V4_ACK_PACKET_SIZE {
debug!(target:"net", "Wrong ack packet size"); debug!(target:"net", "Wrong ack packet size");
return Err(From::from(NetworkError::BadProtocol)); return Err(From::from(NetworkError::BadProtocol));
@ -296,7 +258,7 @@ impl Handshake {
} }
fn read_ack_eip8(&mut self, secret: &Secret, data: &[u8]) -> Result<(), UtilError> { fn read_ack_eip8(&mut self, secret: &Secret, data: &[u8]) -> Result<(), UtilError> {
trace!(target:"network", "Received EIP8 handshake auth from {:?}", self.connection.socket.peer_addr()); trace!(target:"network", "Received EIP8 handshake auth from {:?}", self.connection.remote_addr_str());
self.ack_cipher.extend_from_slice(data); self.ack_cipher.extend_from_slice(data);
let ack = try!(ecies::decrypt(secret, &self.ack_cipher[0..2], &self.ack_cipher[2..])); let ack = try!(ecies::decrypt(secret, &self.ack_cipher[0..2], &self.ack_cipher[2..]));
let rlp = UntrustedRlp::new(&ack); let rlp = UntrustedRlp::new(&ack);
@ -309,7 +271,7 @@ impl Handshake {
/// Sends auth message /// Sends auth message
fn write_auth(&mut self, secret: &Secret, public: &Public) -> Result<(), UtilError> { fn write_auth(&mut self, secret: &Secret, public: &Public) -> Result<(), UtilError> {
trace!(target:"network", "Sending handshake auth to {:?}", self.connection.socket.peer_addr()); trace!(target:"network", "Sending handshake auth to {:?}", self.connection.remote_addr_str());
let mut data = [0u8; /*Signature::SIZE*/ 65 + /*H256::SIZE*/ 32 + /*Public::SIZE*/ 64 + /*H256::SIZE*/ 32 + 1]; //TODO: use associated constants let mut data = [0u8; /*Signature::SIZE*/ 65 + /*H256::SIZE*/ 32 + /*Public::SIZE*/ 64 + /*H256::SIZE*/ 32 + 1]; //TODO: use associated constants
let len = data.len(); let len = data.len();
{ {
@ -336,7 +298,7 @@ impl Handshake {
/// Sends ack message /// Sends ack message
fn write_ack(&mut self) -> Result<(), UtilError> { fn write_ack(&mut self) -> Result<(), UtilError> {
trace!(target:"network", "Sending handshake ack to {:?}", self.connection.socket.peer_addr()); trace!(target:"network", "Sending handshake ack to {:?}", self.connection.remote_addr_str());
let mut data = [0u8; 1 + /*Public::SIZE*/ 64 + /*H256::SIZE*/ 32]; //TODO: use associated constants let mut data = [0u8; 1 + /*Public::SIZE*/ 64 + /*H256::SIZE*/ 32]; //TODO: use associated constants
let len = data.len(); let len = data.len();
{ {
@ -355,7 +317,7 @@ impl Handshake {
/// Sends EIP8 ack message /// Sends EIP8 ack message
fn write_ack_eip8(&mut self) -> Result<(), UtilError> { fn write_ack_eip8(&mut self) -> Result<(), UtilError> {
trace!(target:"network", "Sending EIP8 handshake ack to {:?}", self.connection.socket.peer_addr()); trace!(target:"network", "Sending EIP8 handshake ack to {:?}", self.connection.remote_addr_str());
let mut rlp = RlpStream::new_list(3); let mut rlp = RlpStream::new_list(3);
rlp.append(self.ecdhe.public()); rlp.append(self.ecdhe.public());
rlp.append(&self.nonce); rlp.append(&self.nonce);

View File

@ -18,6 +18,7 @@ use std::net::{SocketAddr};
use std::collections::{HashMap}; use std::collections::{HashMap};
use std::str::{FromStr}; use std::str::{FromStr};
use std::sync::*; use std::sync::*;
use std::sync::atomic::{AtomicUsize, Ordering as AtomicOrdering};
use std::ops::*; use std::ops::*;
use std::cmp::min; use std::cmp::min;
use std::path::{Path, PathBuf}; use std::path::{Path, PathBuf};
@ -31,7 +32,6 @@ use misc::version;
use crypto::*; use crypto::*;
use sha3::Hashable; use sha3::Hashable;
use rlp::*; use rlp::*;
use network::handshake::Handshake;
use network::session::{Session, SessionData}; use network::session::{Session, SessionData};
use error::*; use error::*;
use io::*; use io::*;
@ -44,8 +44,7 @@ use network::ip_utils::{map_external_address, select_public_address};
type Slab<T> = ::slab::Slab<T, usize>; type Slab<T> = ::slab::Slab<T, usize>;
const _DEFAULT_PORT: u16 = 30304; const MAX_SESSIONS: usize = 1024 + MAX_HANDSHAKES;
const MAX_SESSIONS: usize = 1024;
const MAX_HANDSHAKES: usize = 80; const MAX_HANDSHAKES: usize = 80;
const MAX_HANDSHAKES_PER_ROUND: usize = 32; const MAX_HANDSHAKES_PER_ROUND: usize = 32;
const MAINTENANCE_TIMEOUT: u64 = 1000; const MAINTENANCE_TIMEOUT: u64 = 1000;
@ -115,18 +114,17 @@ impl NetworkConfiguration {
} }
// Tokens // Tokens
const TCP_ACCEPT: usize = LAST_HANDSHAKE + 1; const TCP_ACCEPT: usize = SYS_TIMER + 1;
const IDLE: usize = LAST_HANDSHAKE + 2; const IDLE: usize = SYS_TIMER + 2;
const DISCOVERY: usize = LAST_HANDSHAKE + 3; const DISCOVERY: usize = SYS_TIMER + 3;
const DISCOVERY_REFRESH: usize = LAST_HANDSHAKE + 4; const DISCOVERY_REFRESH: usize = SYS_TIMER + 4;
const DISCOVERY_ROUND: usize = LAST_HANDSHAKE + 5; const DISCOVERY_ROUND: usize = SYS_TIMER + 5;
const INIT_PUBLIC: usize = LAST_HANDSHAKE + 6; const INIT_PUBLIC: usize = SYS_TIMER + 6;
const NODE_TABLE: usize = LAST_HANDSHAKE + 7; const NODE_TABLE: usize = SYS_TIMER + 7;
const FIRST_SESSION: usize = 0; const FIRST_SESSION: usize = 0;
const LAST_SESSION: usize = FIRST_SESSION + MAX_SESSIONS - 1; const LAST_SESSION: usize = FIRST_SESSION + MAX_SESSIONS - 1;
const FIRST_HANDSHAKE: usize = LAST_SESSION + 1; const USER_TIMER: usize = LAST_SESSION + 256;
const LAST_HANDSHAKE: usize = FIRST_HANDSHAKE + MAX_HANDSHAKES - 1; const SYS_TIMER: usize = LAST_SESSION + 1;
const USER_TIMER: usize = LAST_HANDSHAKE + 256;
/// Protocol handler level packet id /// Protocol handler level packet id
pub type PacketId = u8; pub type PacketId = u8;
@ -306,7 +304,6 @@ impl HostInfo {
} }
type SharedSession = Arc<Mutex<Session>>; type SharedSession = Arc<Mutex<Session>>;
type SharedHandshake = Arc<Mutex<Handshake>>;
#[derive(Copy, Clone)] #[derive(Copy, Clone)]
struct ProtocolTimer { struct ProtocolTimer {
@ -318,7 +315,6 @@ struct ProtocolTimer {
pub struct Host<Message> where Message: Send + Sync + Clone { pub struct Host<Message> where Message: Send + Sync + Clone {
pub info: RwLock<HostInfo>, pub info: RwLock<HostInfo>,
tcp_listener: Mutex<TcpListener>, tcp_listener: Mutex<TcpListener>,
handshakes: Arc<RwLock<Slab<SharedHandshake>>>,
sessions: Arc<RwLock<Slab<SharedSession>>>, sessions: Arc<RwLock<Slab<SharedSession>>>,
discovery: Mutex<Option<Discovery>>, discovery: Mutex<Option<Discovery>>,
nodes: RwLock<NodeTable>, nodes: RwLock<NodeTable>,
@ -327,6 +323,7 @@ pub struct Host<Message> where Message: Send + Sync + Clone {
timer_counter: RwLock<usize>, timer_counter: RwLock<usize>,
stats: Arc<NetworkStats>, stats: Arc<NetworkStats>,
pinned_nodes: Vec<NodeId>, pinned_nodes: Vec<NodeId>,
num_sessions: AtomicUsize,
} }
impl<Message> Host<Message> where Message: Send + Sync + Clone { impl<Message> Host<Message> where Message: Send + Sync + Clone {
@ -370,7 +367,6 @@ impl<Message> Host<Message> where Message: Send + Sync + Clone {
}), }),
discovery: Mutex::new(None), discovery: Mutex::new(None),
tcp_listener: Mutex::new(tcp_listener), tcp_listener: Mutex::new(tcp_listener),
handshakes: Arc::new(RwLock::new(Slab::new_starting_at(FIRST_HANDSHAKE, MAX_HANDSHAKES))),
sessions: Arc::new(RwLock::new(Slab::new_starting_at(FIRST_SESSION, MAX_SESSIONS))), sessions: Arc::new(RwLock::new(Slab::new_starting_at(FIRST_SESSION, MAX_SESSIONS))),
nodes: RwLock::new(NodeTable::new(path)), nodes: RwLock::new(NodeTable::new(path)),
handlers: RwLock::new(HashMap::new()), handlers: RwLock::new(HashMap::new()),
@ -378,6 +374,7 @@ impl<Message> Host<Message> where Message: Send + Sync + Clone {
timer_counter: RwLock::new(USER_TIMER), timer_counter: RwLock::new(USER_TIMER),
stats: Arc::new(NetworkStats::default()), stats: Arc::new(NetworkStats::default()),
pinned_nodes: Vec::new(), pinned_nodes: Vec::new(),
num_sessions: AtomicUsize::new(0),
}; };
let boot_nodes = host.info.read().unwrap().config.boot_nodes.clone(); let boot_nodes = host.info.read().unwrap().config.boot_nodes.clone();
@ -477,19 +474,19 @@ impl<Message> Host<Message> where Message: Send + Sync + Clone {
} }
fn have_session(&self, id: &NodeId) -> bool { fn have_session(&self, id: &NodeId) -> bool {
self.sessions.read().unwrap().iter().any(|e| e.lock().unwrap().info.id.eq(&id)) self.sessions.read().unwrap().iter().any(|e| e.lock().unwrap().info.id == Some(id.clone()))
} }
fn session_count(&self) -> usize { fn session_count(&self) -> usize {
self.sessions.read().unwrap().count() self.num_sessions.load(AtomicOrdering::Relaxed)
} }
fn connecting_to(&self, id: &NodeId) -> bool { fn connecting_to(&self, id: &NodeId) -> bool {
self.handshakes.read().unwrap().iter().any(|e| e.lock().unwrap().id.eq(&id)) self.sessions.read().unwrap().iter().any(|e| e.lock().unwrap().id() == Some(id))
} }
fn handshake_count(&self) -> usize { fn handshake_count(&self) -> usize {
self.handshakes.read().unwrap().count() self.sessions.read().unwrap().count() - self.session_count()
} }
fn keep_alive(&self, io: &IoContext<NetworkIoMessage<Message>>) { fn keep_alive(&self, io: &IoContext<NetworkIoMessage<Message>>) {
@ -565,21 +562,31 @@ impl<Message> Host<Message> where Message: Send + Sync + Clone {
} }
} }
}; };
self.create_connection(socket, Some(id), io); if let Err(e) = self.create_connection(socket, Some(id), io) {
debug!(target: "network", "Can't create connection: {:?}", e);
}
} }
#[cfg_attr(feature="dev", allow(block_in_if_condition_stmt))] #[cfg_attr(feature="dev", allow(block_in_if_condition_stmt))]
fn create_connection(&self, socket: TcpStream, id: Option<&NodeId>, io: &IoContext<NetworkIoMessage<Message>>) { fn create_connection(&self, socket: TcpStream, id: Option<&NodeId>, io: &IoContext<NetworkIoMessage<Message>>) -> Result<(), UtilError> {
let nonce = self.info.write().unwrap().next_nonce(); let nonce = self.info.write().unwrap().next_nonce();
let mut handshakes = self.handshakes.write().unwrap(); let mut sessions = self.sessions.write().unwrap();
if handshakes.insert_with(|token| { let token = sessions.insert_with_opt(|token| {
let mut handshake = Handshake::new(token, id, socket, &nonce, self.stats.clone()).expect("Can't create handshake"); match Session::new(io, socket, token, id, &nonce, self.stats.clone(), &self.info.read().unwrap()) {
handshake.start(io, &self.info.read().unwrap(), id.is_some()).and_then(|_| io.register_stream(token)).unwrap_or_else (|e| { Ok(s) => Some(Arc::new(Mutex::new(s))),
debug!(target: "network", "Handshake create error: {:?}", e); Err(e) => {
debug!(target: "network", "Session create error: {:?}", e);
None
}
}
}); });
Arc::new(Mutex::new(handshake))
}).is_none() { match token {
debug!(target: "network", "Max handshakes reached"); Some(t) => io.register_stream(t),
None => {
debug!(target: "network", "Max sessions reached");
Ok(())
}
} }
} }
@ -594,21 +601,13 @@ impl<Message> Host<Message> where Message: Send + Sync + Clone {
break break
}, },
}; };
self.create_connection(socket, None, io); if let Err(e) = self.create_connection(socket, None, io) {
debug!(target: "network", "Can't accept connection: {:?}", e);
}
} }
io.update_registration(TCP_ACCEPT).expect("Error registering TCP listener"); io.update_registration(TCP_ACCEPT).expect("Error registering TCP listener");
} }
fn handshake_writable(&self, token: StreamToken, io: &IoContext<NetworkIoMessage<Message>>) {
let handshake = { self.handshakes.read().unwrap().get(token).cloned() };
if let Some(handshake) = handshake {
let mut h = handshake.lock().unwrap();
if let Err(e) = h.writable(io, &self.info.read().unwrap()) {
trace!(target: "network", "Handshake write error: {}: {:?}", token, e);
}
}
}
fn session_writable(&self, token: StreamToken, io: &IoContext<NetworkIoMessage<Message>>) { fn session_writable(&self, token: StreamToken, io: &IoContext<NetworkIoMessage<Message>>) {
let session = { self.sessions.read().unwrap().get(token).cloned() }; let session = { self.sessions.read().unwrap().get(token).cloned() };
if let Some(session) = session { if let Some(session) = session {
@ -629,30 +628,6 @@ impl<Message> Host<Message> where Message: Send + Sync + Clone {
self.kill_connection(token, io, true); self.kill_connection(token, io, true);
} }
fn handshake_readable(&self, token: StreamToken, io: &IoContext<NetworkIoMessage<Message>>) {
let mut create_session = false;
let mut kill = false;
let handshake = { self.handshakes.read().unwrap().get(token).cloned() };
if let Some(handshake) = handshake {
let mut h = handshake.lock().unwrap();
if let Err(e) = h.readable(io, &self.info.read().unwrap()) {
debug!(target: "network", "Handshake read error: {}: {:?}", token, e);
kill = true;
}
if h.done() {
create_session = true;
}
}
if kill {
self.kill_connection(token, io, true);
return;
} else if create_session {
self.start_session(token, io);
return;
}
io.update_registration(token).unwrap_or_else(|e| debug!(target: "network", "Token registration error: {:?}", e));
}
fn session_readable(&self, token: StreamToken, io: &IoContext<NetworkIoMessage<Message>>) { fn session_readable(&self, token: StreamToken, io: &IoContext<NetworkIoMessage<Message>>) {
let mut ready_data: Vec<ProtocolId> = Vec::new(); let mut ready_data: Vec<ProtocolId> = Vec::new();
let mut packet_data: Option<(ProtocolId, PacketId, Vec<u8>)> = None; let mut packet_data: Option<(ProtocolId, PacketId, Vec<u8>)> = None;
@ -662,17 +637,37 @@ impl<Message> Host<Message> where Message: Send + Sync + Clone {
let mut s = session.lock().unwrap(); let mut s = session.lock().unwrap();
match s.readable(io, &self.info.read().unwrap()) { match s.readable(io, &self.info.read().unwrap()) {
Err(e) => { Err(e) => {
trace!(target: "network", "Session read error: {}:{} ({:?}) {:?}", token, s.id(), s.remote_addr(), e); trace!(target: "network", "Session read error: {}:{:?} ({:?}) {:?}", token, s.id(), s.remote_addr(), e);
match e { match e {
UtilError::Network(NetworkError::Disconnect(DisconnectReason::UselessPeer)) | UtilError::Network(NetworkError::Disconnect(DisconnectReason::UselessPeer)) |
UtilError::Network(NetworkError::Disconnect(DisconnectReason::IncompatibleProtocol)) => { UtilError::Network(NetworkError::Disconnect(DisconnectReason::IncompatibleProtocol)) => {
self.nodes.write().unwrap().mark_as_useless(s.id()); if let Some(id) = s.id() {
self.nodes.write().unwrap().mark_as_useless(id);
}
} }
_ => (), _ => (),
} }
kill = true; kill = true;
}, },
Ok(SessionData::Ready) => { Ok(SessionData::Ready) => {
if !s.info.originated {
let session_count = self.session_count();
let ideal_peers = { self.info.read().unwrap().deref().config.ideal_peers };
if session_count >= ideal_peers as usize {
s.disconnect(DisconnectReason::TooManyPeers);
return;
}
// Add it no node table
if let Ok(address) = s.remote_addr() {
let entry = NodeEntry { id: s.id().unwrap().clone(), endpoint: NodeEndpoint { address: address, udp_port: address.port() } };
self.nodes.write().unwrap().add_node(Node::new(entry.id.clone(), entry.endpoint.clone()));
let mut discovery = self.discovery.lock().unwrap();
if let Some(ref mut discovery) = *discovery.deref_mut() {
discovery.add_node(entry);
}
}
}
self.num_sessions.fetch_add(1, AtomicOrdering::SeqCst);
for (p, _) in self.handlers.read().unwrap().iter() { for (p, _) in self.handlers.read().unwrap().iter() {
if s.have_capability(p) { if s.have_capability(p) {
ready_data.push(p); ready_data.push(p);
@ -697,6 +692,7 @@ impl<Message> Host<Message> where Message: Send + Sync + Clone {
} }
for p in ready_data { for p in ready_data {
let h = self.handlers.read().unwrap().get(p).unwrap().clone(); let h = self.handlers.read().unwrap().get(p).unwrap().clone();
self.stats.inc_sessions();
h.connected(&NetworkContext::new(io, p, session.clone(), self.sessions.clone()), &token); h.connected(&NetworkContext::new(io, p, session.clone(), self.sessions.clone()), &token);
} }
if let Some((p, packet_id, data)) = packet_data { if let Some((p, packet_id, data)) = packet_data {
@ -706,59 +702,6 @@ impl<Message> Host<Message> where Message: Send + Sync + Clone {
io.update_registration(token).unwrap_or_else(|e| debug!(target: "network", "Token registration error: {:?}", e)); io.update_registration(token).unwrap_or_else(|e| debug!(target: "network", "Token registration error: {:?}", e));
} }
fn start_session(&self, token: StreamToken, io: &IoContext<NetworkIoMessage<Message>>) {
let mut handshakes = self.handshakes.write().unwrap();
if handshakes.get(token).is_none() {
return;
}
// turn a handshake into a session
let mut sessions = self.sessions.write().unwrap();
let mut h = handshakes.get_mut(token).unwrap().lock().unwrap();
if h.expired {
return;
}
io.deregister_stream(token).expect("Error deleting handshake registration");
h.set_expired();
let originated = h.originated;
let mut session = match Session::new(&mut h, &self.info.read().unwrap()) {
Ok(s) => s,
Err(e) => {
debug!(target: "network", "Session creation error: {:?}", e);
return;
}
};
if !originated {
let session_count = sessions.count();
let ideal_peers = { self.info.read().unwrap().deref().config.ideal_peers };
if session_count >= ideal_peers as usize {
session.disconnect(DisconnectReason::TooManyPeers);
return;
}
}
let result = sessions.insert_with(move |session_token| {
session.set_token(session_token);
io.register_stream(session_token).expect("Error creating session registration");
self.stats.inc_sessions();
trace!(target: "network", "Creating session {} -> {}:{} ({:?})", token, session_token, session.id(), session.remote_addr());
if !originated {
// Add it no node table
if let Ok(address) = session.remote_addr() {
let entry = NodeEntry { id: session.id().clone(), endpoint: NodeEndpoint { address: address, udp_port: address.port() } };
self.nodes.write().unwrap().add_node(Node::new(entry.id.clone(), entry.endpoint.clone()));
let mut discovery = self.discovery.lock().unwrap();
if let Some(ref mut discovery) = *discovery.deref_mut() {
discovery.add_node(entry);
}
}
}
Arc::new(Mutex::new(session))
});
if result.is_none() {
warn!("Max sessions reached");
}
}
fn connection_timeout(&self, token: StreamToken, io: &IoContext<NetworkIoMessage<Message>>) { fn connection_timeout(&self, token: StreamToken, io: &IoContext<NetworkIoMessage<Message>>) {
trace!(target: "network", "Connection timeout: {}", token); trace!(target: "network", "Connection timeout: {}", token);
self.kill_connection(token, io, true) self.kill_connection(token, io, true)
@ -770,17 +713,6 @@ impl<Message> Host<Message> where Message: Send + Sync + Clone {
let mut deregister = false; let mut deregister = false;
let mut expired_session = None; let mut expired_session = None;
match token { match token {
FIRST_HANDSHAKE ... LAST_HANDSHAKE => {
let handshakes = self.handshakes.write().unwrap();
if let Some(handshake) = handshakes.get(token).cloned() {
let mut handshake = handshake.lock().unwrap();
if !handshake.expired() {
handshake.set_expired();
failure_id = Some(handshake.id().clone());
deregister = true;
}
}
},
FIRST_SESSION ... LAST_SESSION => { FIRST_SESSION ... LAST_SESSION => {
let sessions = self.sessions.write().unwrap(); let sessions = self.sessions.write().unwrap();
if let Some(session) = sessions.get(token).cloned() { if let Some(session) = sessions.get(token).cloned() {
@ -790,12 +722,13 @@ impl<Message> Host<Message> where Message: Send + Sync + Clone {
if s.is_ready() { if s.is_ready() {
for (p, _) in self.handlers.read().unwrap().iter() { for (p, _) in self.handlers.read().unwrap().iter() {
if s.have_capability(p) { if s.have_capability(p) {
self.num_sessions.fetch_sub(1, AtomicOrdering::SeqCst);
to_disconnect.push(p); to_disconnect.push(p);
} }
} }
} }
s.set_expired(); s.set_expired();
failure_id = Some(s.id().clone()); failure_id = s.id().cloned();
} }
deregister = remote || s.done(); deregister = remote || s.done();
} }
@ -820,21 +753,12 @@ impl<Message> Host<Message> where Message: Send + Sync + Clone {
fn update_nodes(&self, io: &IoContext<NetworkIoMessage<Message>>, node_changes: TableUpdates) { fn update_nodes(&self, io: &IoContext<NetworkIoMessage<Message>>, node_changes: TableUpdates) {
let mut to_remove: Vec<PeerId> = Vec::new(); let mut to_remove: Vec<PeerId> = Vec::new();
{
{
let handshakes = self.handshakes.write().unwrap();
for c in handshakes.iter() {
let h = c.lock().unwrap();
if node_changes.removed.contains(&h.id()) {
to_remove.push(h.token());
}
}
}
{ {
let sessions = self.sessions.write().unwrap(); let sessions = self.sessions.write().unwrap();
for c in sessions.iter() { for c in sessions.iter() {
let s = c.lock().unwrap(); let s = c.lock().unwrap();
if node_changes.removed.contains(&s.id()) { if let Some(id) = s.id() {
if node_changes.removed.contains(id) {
to_remove.push(s.token()); to_remove.push(s.token());
} }
} }
@ -860,7 +784,6 @@ impl<Message> IoHandler<NetworkIoMessage<Message>> for Host<Message> where Messa
trace!(target: "network", "Hup: {}", stream); trace!(target: "network", "Hup: {}", stream);
match stream { match stream {
FIRST_SESSION ... LAST_SESSION => self.connection_closed(stream, io), FIRST_SESSION ... LAST_SESSION => self.connection_closed(stream, io),
FIRST_HANDSHAKE ... LAST_HANDSHAKE => self.connection_closed(stream, io),
_ => warn!(target: "network", "Unexpected hup"), _ => warn!(target: "network", "Unexpected hup"),
}; };
} }
@ -868,7 +791,6 @@ impl<Message> IoHandler<NetworkIoMessage<Message>> for Host<Message> where Messa
fn stream_readable(&self, io: &IoContext<NetworkIoMessage<Message>>, stream: StreamToken) { fn stream_readable(&self, io: &IoContext<NetworkIoMessage<Message>>, stream: StreamToken) {
match stream { match stream {
FIRST_SESSION ... LAST_SESSION => self.session_readable(stream, io), FIRST_SESSION ... LAST_SESSION => self.session_readable(stream, io),
FIRST_HANDSHAKE ... LAST_HANDSHAKE => self.handshake_readable(stream, io),
DISCOVERY => { DISCOVERY => {
let node_changes = { self.discovery.lock().unwrap().as_mut().unwrap().readable() }; let node_changes = { self.discovery.lock().unwrap().as_mut().unwrap().readable() };
if let Some(node_changes) = node_changes { if let Some(node_changes) = node_changes {
@ -884,7 +806,6 @@ impl<Message> IoHandler<NetworkIoMessage<Message>> for Host<Message> where Messa
fn stream_writable(&self, io: &IoContext<NetworkIoMessage<Message>>, stream: StreamToken) { fn stream_writable(&self, io: &IoContext<NetworkIoMessage<Message>>, stream: StreamToken) {
match stream { match stream {
FIRST_SESSION ... LAST_SESSION => self.session_writable(stream, io), FIRST_SESSION ... LAST_SESSION => self.session_writable(stream, io),
FIRST_HANDSHAKE ... LAST_HANDSHAKE => self.handshake_writable(stream, io),
DISCOVERY => { DISCOVERY => {
self.discovery.lock().unwrap().as_mut().unwrap().writable(); self.discovery.lock().unwrap().as_mut().unwrap().writable();
io.update_registration(DISCOVERY).expect("Error updating discovery registration"); io.update_registration(DISCOVERY).expect("Error updating discovery registration");
@ -899,7 +820,6 @@ impl<Message> IoHandler<NetworkIoMessage<Message>> for Host<Message> where Messa
INIT_PUBLIC => self.init_public_interface(io).unwrap_or_else(|e| INIT_PUBLIC => self.init_public_interface(io).unwrap_or_else(|e|
warn!("Error initializing public interface: {:?}", e)), warn!("Error initializing public interface: {:?}", e)),
FIRST_SESSION ... LAST_SESSION => self.connection_timeout(token, io), FIRST_SESSION ... LAST_SESSION => self.connection_timeout(token, io),
FIRST_HANDSHAKE ... LAST_HANDSHAKE => self.connection_timeout(token, io),
DISCOVERY_REFRESH => { DISCOVERY_REFRESH => {
self.discovery.lock().unwrap().as_mut().unwrap().refresh(); self.discovery.lock().unwrap().as_mut().unwrap().refresh();
io.update_registration(DISCOVERY).expect("Error updating discovery registration"); io.update_registration(DISCOVERY).expect("Error updating discovery registration");
@ -966,7 +886,9 @@ impl<Message> IoHandler<NetworkIoMessage<Message>> for Host<Message> where Messa
let session = { self.sessions.read().unwrap().get(*peer).cloned() }; let session = { self.sessions.read().unwrap().get(*peer).cloned() };
if let Some(session) = session { if let Some(session) = session {
session.lock().unwrap().disconnect(DisconnectReason::DisconnectRequested); session.lock().unwrap().disconnect(DisconnectReason::DisconnectRequested);
self.nodes.write().unwrap().mark_as_useless(session.lock().unwrap().id()); if let Some(id) = session.lock().unwrap().id() {
self.nodes.write().unwrap().mark_as_useless(id)
}
} }
trace!(target: "network", "Disabling peer {}", peer); trace!(target: "network", "Disabling peer {}", peer);
self.kill_connection(*peer, io, false); self.kill_connection(*peer, io, false);
@ -987,12 +909,6 @@ impl<Message> IoHandler<NetworkIoMessage<Message>> for Host<Message> where Messa
session.lock().unwrap().register_socket(reg, event_loop).expect("Error registering socket"); session.lock().unwrap().register_socket(reg, event_loop).expect("Error registering socket");
} }
} }
FIRST_HANDSHAKE ... LAST_HANDSHAKE => {
let connection = { self.handshakes.read().unwrap().get(stream).cloned() };
if let Some(connection) = connection {
connection.lock().unwrap().register_socket(reg, event_loop).expect("Error registering socket");
}
}
DISCOVERY => self.discovery.lock().unwrap().as_ref().unwrap().register_socket(event_loop).expect("Error registering discovery socket"), DISCOVERY => self.discovery.lock().unwrap().as_ref().unwrap().register_socket(event_loop).expect("Error registering discovery socket"),
TCP_ACCEPT => event_loop.register(self.tcp_listener.lock().unwrap().deref(), Token(TCP_ACCEPT), EventSet::all(), PollOpt::edge()).expect("Error registering stream"), TCP_ACCEPT => event_loop.register(self.tcp_listener.lock().unwrap().deref(), Token(TCP_ACCEPT), EventSet::all(), PollOpt::edge()).expect("Error registering stream"),
_ => warn!("Unexpected stream registration") _ => warn!("Unexpected stream registration")
@ -1008,13 +924,6 @@ impl<Message> IoHandler<NetworkIoMessage<Message>> for Host<Message> where Messa
connections.remove(stream); connections.remove(stream);
} }
} }
FIRST_HANDSHAKE ... LAST_HANDSHAKE => {
let mut connections = self.handshakes.write().unwrap();
if let Some(connection) = connections.get(stream).cloned() {
connection.lock().unwrap().deregister_socket(event_loop).expect("Error deregistering socket");
connections.remove(stream);
}
}
DISCOVERY => (), DISCOVERY => (),
_ => warn!("Unexpected stream deregistration") _ => warn!("Unexpected stream deregistration")
} }
@ -1028,12 +937,6 @@ impl<Message> IoHandler<NetworkIoMessage<Message>> for Host<Message> where Messa
connection.lock().unwrap().update_socket(reg, event_loop).expect("Error updating socket"); connection.lock().unwrap().update_socket(reg, event_loop).expect("Error updating socket");
} }
} }
FIRST_HANDSHAKE ... LAST_HANDSHAKE => {
let connection = { self.handshakes.read().unwrap().get(stream).cloned() };
if let Some(connection) = connection {
connection.lock().unwrap().update_socket(reg, event_loop).expect("Error updating socket");
}
}
DISCOVERY => self.discovery.lock().unwrap().as_ref().unwrap().update_registration(event_loop).expect("Error reregistering discovery socket"), DISCOVERY => self.discovery.lock().unwrap().as_ref().unwrap().update_registration(event_loop).expect("Error reregistering discovery socket"),
TCP_ACCEPT => event_loop.reregister(self.tcp_listener.lock().unwrap().deref(), Token(TCP_ACCEPT), EventSet::all(), PollOpt::edge()).expect("Error reregistering stream"), TCP_ACCEPT => event_loop.reregister(self.tcp_listener.lock().unwrap().deref(), Token(TCP_ACCEPT), EventSet::all(), PollOpt::edge()).expect("Error reregistering stream"),
_ => warn!("Unexpected stream update") _ => warn!("Unexpected stream update")

View File

@ -16,15 +16,19 @@
use std::net::SocketAddr; use std::net::SocketAddr;
use std::io; use std::io;
use std::sync::*;
use mio::*; use mio::*;
use mio::tcp::*;
use rlp::*; use rlp::*;
use network::connection::{EncryptedConnection, Packet}; use hash::*;
use network::connection::{EncryptedConnection, Packet, Connection};
use network::handshake::Handshake; use network::handshake::Handshake;
use error::*; use error::*;
use io::{IoContext, StreamToken}; use io::{IoContext, StreamToken};
use network::error::{NetworkError, DisconnectReason}; use network::error::{NetworkError, DisconnectReason};
use network::host::*; use network::host::*;
use network::node_table::NodeId; use network::node_table::NodeId;
use network::stats::NetworkStats;
use time; use time;
const PING_TIMEOUT_SEC: u64 = 30; const PING_TIMEOUT_SEC: u64 = 30;
@ -36,14 +40,18 @@ const PING_INTERVAL_SEC: u64 = 30;
pub struct Session { pub struct Session {
/// Shared session information /// Shared session information
pub info: SessionInfo, pub info: SessionInfo,
/// Underlying connection
connection: EncryptedConnection,
/// Session ready flag. Set after successfull Hello packet exchange /// Session ready flag. Set after successfull Hello packet exchange
had_hello: bool, had_hello: bool,
/// Session is no longer active flag. /// Session is no longer active flag.
expired: bool, expired: bool,
ping_time_ns: u64, ping_time_ns: u64,
pong_time_ns: Option<u64>, pong_time_ns: Option<u64>,
state: State,
}
enum State {
Handshake(Handshake),
Session(EncryptedConnection),
} }
/// Structure used to report various session events. /// Structure used to report various session events.
@ -65,7 +73,7 @@ pub enum SessionData {
/// Shared session information /// Shared session information
pub struct SessionInfo { pub struct SessionInfo {
/// Peer public key /// Peer public key
pub id: NodeId, pub id: Option<NodeId>,
/// Peer client ID /// Peer client ID
pub client_version: String, pub client_version: String,
/// Peer RLPx protocol version /// Peer RLPx protocol version
@ -74,6 +82,8 @@ pub struct SessionInfo {
capabilities: Vec<SessionCapabilityInfo>, capabilities: Vec<SessionCapabilityInfo>,
/// Peer ping delay in milliseconds /// Peer ping delay in milliseconds
pub ping_ms: Option<u64>, pub ping_ms: Option<u64>,
/// True if this session was originated by us.
pub originated: bool,
} }
#[derive(Debug, PartialEq, Eq)] #[derive(Debug, PartialEq, Eq)]
@ -112,31 +122,52 @@ const PACKET_LAST: u8 = 0x7f;
impl Session { impl Session {
/// Create a new session out of comepleted handshake. This clones the handshake connection object /// Create a new session out of comepleted handshake. This clones the handshake connection object
/// and leaves the handhsake in limbo to be deregistered from the event loop. /// and leaves the handhsake in limbo to be deregistered from the event loop.
pub fn new(h: &mut Handshake, host: &HostInfo) -> Result<Session, UtilError> { pub fn new<Message>(io: &IoContext<Message>, socket: TcpStream, token: StreamToken, id: Option<&NodeId>,
let id = h.id.clone(); nonce: &H256, stats: Arc<NetworkStats>, host: &HostInfo) -> Result<Session, UtilError>
let connection = try!(EncryptedConnection::new(h)); where Message: Send + Clone {
let mut session = Session { let originated = id.is_some();
connection: connection, let mut handshake = Handshake::new(token, id, socket, &nonce, stats).expect("Can't create handshake");
try!(handshake.start(io, host, originated));
Ok(Session {
state: State::Handshake(handshake),
had_hello: false, had_hello: false,
info: SessionInfo { info: SessionInfo {
id: id, id: id.cloned(),
client_version: String::new(), client_version: String::new(),
protocol_version: 0, protocol_version: 0,
capabilities: Vec::new(), capabilities: Vec::new(),
ping_ms: None, ping_ms: None,
originated: originated,
}, },
ping_time_ns: 0, ping_time_ns: 0,
pong_time_ns: None, pong_time_ns: None,
expired: false, expired: false,
})
}
fn complete_handshake(&mut self, host: &HostInfo) -> Result<(), UtilError> {
let connection = if let State::Handshake(ref mut h) = self.state {
self.info.id = Some(h.id.clone());
try!(EncryptedConnection::new(h))
} else {
panic!("Unexpected state");
}; };
try!(session.write_hello(host)); self.state = State::Session(connection);
try!(session.send_ping()); try!(self.write_hello(host));
Ok(session) try!(self.send_ping());
Ok(())
}
fn connection(&self) -> &Connection {
match self.state {
State::Handshake(ref h) => &h.connection,
State::Session(ref s) => &s.connection,
}
} }
/// Get id of the remote peer /// Get id of the remote peer
pub fn id(&self) -> &NodeId { pub fn id(&self) -> Option<&NodeId> {
&self.info.id self.info.id.as_ref()
} }
/// Check if session is ready to send/receive data /// Check if session is ready to send/receive data
@ -151,21 +182,20 @@ impl Session {
/// Check if this session is expired. /// Check if this session is expired.
pub fn expired(&self) -> bool { pub fn expired(&self) -> bool {
self.expired match self.state {
State::Handshake(ref h) => h.expired(),
_ => self.expired,
}
} }
/// Check if this session is over and there is nothing to be sent. /// Check if this session is over and there is nothing to be sent.
pub fn done(&self) -> bool { pub fn done(&self) -> bool {
self.expired() && !self.connection.is_sending() self.expired() && !self.connection().is_sending()
}
/// Replace socket token
pub fn set_token(&mut self, token: StreamToken) {
self.connection.set_token(token);
} }
/// Get remote peer address /// Get remote peer address
pub fn remote_addr(&self) -> io::Result<SocketAddr> { pub fn remote_addr(&self) -> io::Result<SocketAddr> {
self.connection.remote_addr() self.connection().remote_addr()
} }
/// Readable IO handler. Returns packet data if available. /// Readable IO handler. Returns packet data if available.
@ -173,15 +203,37 @@ impl Session {
if self.expired() { if self.expired() {
return Ok(SessionData::None) return Ok(SessionData::None)
} }
match try!(self.connection.readable(io)) { let mut create_session = false;
Some(data) => Ok(try!(self.read_packet(data, host))), let mut packet_data = None;
None => Ok(SessionData::None) match self.state {
State::Handshake(ref mut h) => {
try!(h.readable(io, host));
if h.done() {
create_session = true;
} }
} }
State::Session(ref mut c) => {
match try!(c.readable(io)) {
data @ Some(_) => packet_data = data,
None => return Ok(SessionData::None)
}
}
}
if let Some(data) = packet_data {
return Ok(try!(self.read_packet(data, host)));
}
if create_session {
try!(self.complete_handshake(host));
}
Ok(SessionData::None)
}
/// Writable IO handler. Sends pending packets. /// Writable IO handler. Sends pending packets.
pub fn writable<Message>(&mut self, io: &IoContext<Message>, _host: &HostInfo) -> Result<(), UtilError> where Message: Send + Sync + Clone { pub fn writable<Message>(&mut self, io: &IoContext<Message>, _host: &HostInfo) -> Result<(), UtilError> where Message: Send + Sync + Clone {
self.connection.writable(io) match self.state {
State::Handshake(ref mut h) => h.writable(io),
State::Session(ref mut s) => s.writable(io),
}
} }
/// Checks if peer supports given capability /// Checks if peer supports given capability
@ -194,18 +246,20 @@ impl Session {
if self.expired() { if self.expired() {
return Ok(()); return Ok(());
} }
try!(self.connection.register_socket(reg, event_loop)); try!(self.connection().register_socket(reg, event_loop));
Ok(()) Ok(())
} }
/// Update registration with the event loop. Should be called at the end of the IO handler. /// Update registration with the event loop. Should be called at the end of the IO handler.
pub fn update_socket<Host:Handler>(&self, reg:Token, event_loop: &mut EventLoop<Host>) -> Result<(), UtilError> { pub fn update_socket<Host:Handler>(&self, reg:Token, event_loop: &mut EventLoop<Host>) -> Result<(), UtilError> {
self.connection.update_socket(reg, event_loop) try!(self.connection().update_socket(reg, event_loop));
Ok(())
} }
/// Delete registration /// Delete registration
pub fn deregister_socket<Host:Handler>(&self, event_loop: &mut EventLoop<Host>) -> Result<(), UtilError> { pub fn deregister_socket<Host:Handler>(&self, event_loop: &mut EventLoop<Host>) -> Result<(), UtilError> {
self.connection.deregister_socket(event_loop) try!(self.connection().deregister_socket(event_loop));
Ok(())
} }
/// Send a protocol packet to peer. /// Send a protocol packet to peer.
@ -221,7 +275,7 @@ impl Session {
while protocol != self.info.capabilities[i].protocol { while protocol != self.info.capabilities[i].protocol {
i += 1; i += 1;
if i == self.info.capabilities.len() { if i == self.info.capabilities.len() {
debug!(target: "net", "Unknown protocol: {:?}", protocol); debug!(target: "network", "Unknown protocol: {:?}", protocol);
return Ok(()) return Ok(())
} }
} }
@ -229,11 +283,14 @@ impl Session {
let mut rlp = RlpStream::new(); let mut rlp = RlpStream::new();
rlp.append(&(pid as u32)); rlp.append(&(pid as u32));
rlp.append_raw(data, 1); rlp.append_raw(data, 1);
self.connection.send_packet(&rlp.out()) self.send(rlp)
} }
/// Keep this session alive. Returns false if ping timeout happened /// Keep this session alive. Returns false if ping timeout happened
pub fn keep_alive<Message>(&mut self, io: &IoContext<Message>) -> bool where Message: Send + Sync + Clone { pub fn keep_alive<Message>(&mut self, io: &IoContext<Message>) -> bool where Message: Send + Sync + Clone {
if let State::Handshake(_) = self.state {
return true;
}
let timed_out = if let Some(pong) = self.pong_time_ns { let timed_out = if let Some(pong) = self.pong_time_ns {
pong - self.ping_time_ns > PING_TIMEOUT_SEC * 1000_000_000 pong - self.ping_time_ns > PING_TIMEOUT_SEC * 1000_000_000
} else { } else {
@ -244,13 +301,13 @@ impl Session {
if let Err(e) = self.send_ping() { if let Err(e) = self.send_ping() {
debug!("Error sending ping message: {:?}", e); debug!("Error sending ping message: {:?}", e);
} }
io.update_registration(self.token()).unwrap_or_else(|e| debug!(target: "net", "Session registration error: {:?}", e)); io.update_registration(self.token()).unwrap_or_else(|e| debug!(target: "network", "Session registration error: {:?}", e));
} }
!timed_out !timed_out
} }
pub fn token(&self) -> StreamToken { pub fn token(&self) -> StreamToken {
self.connection.token() self.connection().token()
} }
fn read_packet(&mut self, packet: Packet, host: &HostInfo) -> Result<SessionData, UtilError> { fn read_packet(&mut self, packet: Packet, host: &HostInfo) -> Result<SessionData, UtilError> {
@ -288,7 +345,7 @@ impl Session {
while packet_id < self.info.capabilities[i].id_offset { while packet_id < self.info.capabilities[i].id_offset {
i += 1; i += 1;
if i == self.info.capabilities.len() { if i == self.info.capabilities.len() {
debug!(target: "net", "Unknown packet: {:?}", packet_id); debug!(target: "network", "Unknown packet: {:?}", packet_id);
return Ok(SessionData::None) return Ok(SessionData::None)
} }
} }
@ -299,7 +356,7 @@ impl Session {
Ok(SessionData::Packet { data: packet.data, protocol: protocol, packet_id: pid } ) Ok(SessionData::Packet { data: packet.data, protocol: protocol, packet_id: pid } )
}, },
_ => { _ => {
debug!(target: "net", "Unknown packet: {:?}", packet_id); debug!(target: "network", "Unknown packet: {:?}", packet_id);
Ok(SessionData::None) Ok(SessionData::None)
} }
} }
@ -314,7 +371,7 @@ impl Session {
.append(&host.capabilities) .append(&host.capabilities)
.append(&host.local_endpoint.address.port()) .append(&host.local_endpoint.address.port())
.append(host.id()); .append(host.id());
self.connection.send_packet(&rlp.out()) self.send(rlp)
} }
fn read_hello(&mut self, rlp: &UntrustedRlp, host: &HostInfo) -> Result<(), UtilError> { fn read_hello(&mut self, rlp: &UntrustedRlp, host: &HostInfo) -> Result<(), UtilError> {
@ -384,11 +441,13 @@ impl Session {
/// Disconnect this session /// Disconnect this session
pub fn disconnect(&mut self, reason: DisconnectReason) -> NetworkError { pub fn disconnect(&mut self, reason: DisconnectReason) -> NetworkError {
if let State::Session(_) = self.state {
let mut rlp = RlpStream::new(); let mut rlp = RlpStream::new();
rlp.append(&(PACKET_DISCONNECT as u32)); rlp.append(&(PACKET_DISCONNECT as u32));
rlp.begin_list(1); rlp.begin_list(1);
rlp.append(&(reason as u32)); rlp.append(&(reason as u32));
self.connection.send_packet(&rlp.out()).ok(); self.send(rlp).ok();
}
NetworkError::Disconnect(reason) NetworkError::Disconnect(reason)
} }
@ -400,7 +459,15 @@ impl Session {
} }
fn send(&mut self, rlp: RlpStream) -> Result<(), UtilError> { fn send(&mut self, rlp: RlpStream) -> Result<(), UtilError> {
self.connection.send_packet(&rlp.out()) match self.state {
State::Handshake(_) => {
warn!(target:"network", "Unexpected send request");
},
State::Session(ref mut s) => {
try!(s.send_packet(&rlp.out()))
},
}
Ok(())
} }
} }