This commit is contained in:
arkpar
2016-06-17 18:26:54 +02:00
parent 67ffac1df9
commit c340d8a34f
12 changed files with 99 additions and 70 deletions

View File

@@ -142,7 +142,7 @@ mod tests {
#[test]
fn test_service_register_handler () {
let mut service = IoService::<MyMessage>::start().expect("Error creating network service");
let service = IoService::<MyMessage>::start().expect("Error creating network service");
service.register_handler(Arc::new(MyHandler)).unwrap();
}

View File

@@ -17,7 +17,6 @@
use std::sync::*;
use std::thread::{self, JoinHandle};
use std::collections::HashMap;
use std::ops::Deref;
use mio::*;
use crossbeam::sync::chase_lev;
use slab::Slab;
@@ -37,13 +36,6 @@ pub type HandlerId = usize;
pub const TOKENS_PER_HANDLER: usize = 16384;
const MAX_HANDLERS: usize = 8;
fn compare_arcs<T: ?Sized>(a: Arc<T>, b: Arc<T>) -> bool {
let p1 = &*a as *const T;
let p2 = &*b as *const T;
info!("{:p} == {:p} : {}", p1, p2 , p1 == p2);
p1 == p2
}
/// Messages used to communicate with the event loop from other threads.
#[derive(Clone)]
pub enum IoMessage<Message> where Message: Send + Clone + Sized {
@@ -214,30 +206,31 @@ impl<Message> Handler for IoManager<Message> where Message: Send + Clone + Sync
fn ready(&mut self, _event_loop: &mut EventLoop<Self>, token: Token, events: EventSet) {
let handler_index = token.as_usize() / TOKENS_PER_HANDLER;
let token_id = token.as_usize() % TOKENS_PER_HANDLER;
let handler = self.handlers.get(handler_index).unwrap_or_else(|| panic!("Unexpected stream token: {}", token.as_usize())).clone();
if events.is_hup() {
self.worker_channel.push(Work { work_type: WorkType::Hup, token: token_id, handler: handler.clone(), handler_id: handler_index });
}
else {
if events.is_readable() {
self.worker_channel.push(Work { work_type: WorkType::Readable, token: token_id, handler: handler.clone(), handler_id: handler_index });
if let Some(handler) = self.handlers.get(handler_index) {
if events.is_hup() {
self.worker_channel.push(Work { work_type: WorkType::Hup, token: token_id, handler: handler.clone(), handler_id: handler_index });
}
if events.is_writable() {
self.worker_channel.push(Work { work_type: WorkType::Writable, token: token_id, handler: handler.clone(), handler_id: handler_index });
else {
if events.is_readable() {
self.worker_channel.push(Work { work_type: WorkType::Readable, token: token_id, handler: handler.clone(), handler_id: handler_index });
}
if events.is_writable() {
self.worker_channel.push(Work { work_type: WorkType::Writable, token: token_id, handler: handler.clone(), handler_id: handler_index });
}
}
self.work_ready.notify_all();
}
self.work_ready.notify_all();
}
fn timeout(&mut self, event_loop: &mut EventLoop<Self>, token: Token) {
let handler_index = token.as_usize() / TOKENS_PER_HANDLER;
let token_id = token.as_usize() % TOKENS_PER_HANDLER;
let handler = self.handlers.get(handler_index).unwrap_or_else(|| panic!("Unexpected stream token: {}", token.as_usize())).clone();
if let Some(timer) = self.timers.read().unwrap().get(&token.as_usize()) {
event_loop.timeout_ms(token, timer.delay).expect("Error re-registering user timer");
self.worker_channel.push(Work { work_type: WorkType::Timeout, token: token_id, handler: handler, handler_id: handler_index });
self.work_ready.notify_all();
if let Some(handler) = self.handlers.get(handler_index) {
if let Some(timer) = self.timers.read().unwrap().get(&token.as_usize()) {
event_loop.timeout_ms(token, timer.delay).expect("Error re-registering user timer");
self.worker_channel.push(Work { work_type: WorkType::Timeout, token: token_id, handler: handler.clone(), handler_id: handler_index });
self.work_ready.notify_all();
}
}
}
@@ -254,7 +247,6 @@ impl<Message> Handler for IoManager<Message> where Message: Send + Clone + Sync
IoMessage::RemoveHandler { handler_id } => {
// TODO: flush event loop
self.handlers.remove(handler_id);
info!("{} left", self.handlers.count());
},
IoMessage::AddTimer { handler_id, token, delay } => {
let timer_id = token + handler_id * TOKENS_PER_HANDLER;
@@ -268,21 +260,24 @@ impl<Message> Handler for IoManager<Message> where Message: Send + Clone + Sync
}
},
IoMessage::RegisterStream { handler_id, token } => {
let handler = self.handlers.get(handler_id).expect("Unknown handler id").clone();
handler.register_stream(token, Token(token + handler_id * TOKENS_PER_HANDLER), event_loop);
if let Some(handler) = self.handlers.get(handler_id) {
handler.register_stream(token, Token(token + handler_id * TOKENS_PER_HANDLER), event_loop);
}
},
IoMessage::DeregisterStream { handler_id, token } => {
let handler = self.handlers.get(handler_id).expect("Unknown handler id").clone();
handler.deregister_stream(token, event_loop);
// unregister a timer associated with the token (if any)
let timer_id = token + handler_id * TOKENS_PER_HANDLER;
if let Some(timer) = self.timers.write().unwrap().remove(&timer_id) {
event_loop.clear_timeout(timer.timeout);
if let Some(handler) = self.handlers.get(handler_id) {
handler.deregister_stream(token, event_loop);
// unregister a timer associated with the token (if any)
let timer_id = token + handler_id * TOKENS_PER_HANDLER;
if let Some(timer) = self.timers.write().unwrap().remove(&timer_id) {
event_loop.clear_timeout(timer.timeout);
}
}
},
IoMessage::UpdateStreamRegistration { handler_id, token } => {
let handler = self.handlers.get(handler_id).expect("Unknown handler id").clone();
handler.update_stream(token, Token(token + handler_id * TOKENS_PER_HANDLER), event_loop);
if let Some(handler) = self.handlers.get(handler_id) {
handler.update_stream(token, Token(token + handler_id * TOKENS_PER_HANDLER), event_loop);
}
},
IoMessage::UserMessage(data) => {
//TODO: better way to iterate the slab

View File

@@ -333,6 +333,7 @@ pub struct Host<Message> where Message: Send + Sync + Clone {
stats: Arc<NetworkStats>,
pinned_nodes: Vec<NodeId>,
num_sessions: AtomicUsize,
stopping: AtomicBool,
}
impl<Message> Host<Message> where Message: Send + Sync + Clone {
@@ -384,6 +385,7 @@ impl<Message> Host<Message> where Message: Send + Sync + Clone {
stats: stats,
pinned_nodes: Vec::new(),
num_sessions: AtomicUsize::new(0),
stopping: AtomicBool::new(false),
};
let boot_nodes = host.info.read().unwrap().config.boot_nodes.clone();
@@ -422,19 +424,18 @@ impl<Message> Host<Message> where Message: Send + Sync + Clone {
}
pub fn stop(&self, io: &IoContext<NetworkIoMessage<Message>>) -> Result<(), UtilError> {
self.stopping.store(true, AtomicOrdering::Release);
let mut to_kill = Vec::new();
for e in self.sessions.write().unwrap().iter_mut() {
let mut s = e.lock().unwrap();
if !s.keep_alive(io) {
s.disconnect(io, DisconnectReason::PingTimeout);
to_kill.push(s.token());
}
s.disconnect(io, DisconnectReason::ClientQuit);
to_kill.push(s.token());
}
for p in to_kill {
trace!(target: "network", "Ping timeout: {}", p);
trace!(target: "network", "Disconnecting on shutdown: {}", p);
self.kill_connection(p, io, true);
}
io.unregister_handler();
try!(io.unregister_handler());
Ok(())
}
@@ -790,13 +791,6 @@ impl<Message> Host<Message> where Message: Send + Sync + Clone {
}
}
impl<Message> Drop for Host<Message> where Message: Send + Sync + Clone {
fn drop(&mut self) {
info!("Dropping host");
}
}
impl<Message> IoHandler<NetworkIoMessage<Message>> for Host<Message> where Message: Send + Sync + Clone + 'static {
/// Initialize networking
fn initialize(&self, io: &IoContext<NetworkIoMessage<Message>>) {
@@ -814,6 +808,9 @@ impl<Message> IoHandler<NetworkIoMessage<Message>> for Host<Message> where Messa
}
fn stream_readable(&self, io: &IoContext<NetworkIoMessage<Message>>, stream: StreamToken) {
if self.stopping.load(AtomicOrdering::Acquire) {
return;
}
match stream {
FIRST_SESSION ... LAST_SESSION => self.session_readable(stream, io),
DISCOVERY => {
@@ -829,6 +826,9 @@ impl<Message> IoHandler<NetworkIoMessage<Message>> for Host<Message> where Messa
}
fn stream_writable(&self, io: &IoContext<NetworkIoMessage<Message>>, stream: StreamToken) {
if self.stopping.load(AtomicOrdering::Acquire) {
return;
}
match stream {
FIRST_SESSION ... LAST_SESSION => self.session_writable(stream, io),
DISCOVERY => {
@@ -840,6 +840,9 @@ impl<Message> IoHandler<NetworkIoMessage<Message>> for Host<Message> where Messa
}
fn timeout(&self, io: &IoContext<NetworkIoMessage<Message>>, token: TimerToken) {
if self.stopping.load(AtomicOrdering::Acquire) {
return;
}
match token {
IDLE => self.maintain_network(io),
INIT_PUBLIC => self.init_public_interface(io).unwrap_or_else(|e|
@@ -870,6 +873,9 @@ impl<Message> IoHandler<NetworkIoMessage<Message>> for Host<Message> where Messa
}
fn message(&self, io: &IoContext<NetworkIoMessage<Message>>, message: &NetworkIoMessage<Message>) {
if self.stopping.load(AtomicOrdering::Acquire) {
return;
}
match *message {
NetworkIoMessage::AddHandler {
ref handler,
@@ -1031,6 +1037,6 @@ fn host_client_url() {
let mut config = NetworkConfiguration::new();
let key = h256_from_hex("6f7b0d801bc7b5ce7bbd930b84fd0369b3eb25d09be58d64ba811091046f3aa2");
config.use_secret = Some(key);
let host: Host<u32> = Host::new(config).unwrap();
let host: Host<u32> = Host::new(config, Arc::new(NetworkStats::new())).unwrap();
assert!(host.local_url().starts_with("enode://101b3ef5a4ea7a1c7928e24c4c75fd053c235d7b80c22ae5c03d145d0ac7396e2a4ffff9adee3133a7b05044a5cee08115fd65145e5165d646bde371010d803c@"));
}

View File

@@ -56,8 +56,9 @@
//! }
//!
//! fn main () {
//! let mut service = NetworkService::<MyMessage>::start(NetworkConfiguration::new_local()).expect("Error creating network service");
//! let mut service = NetworkService::<MyMessage>::new(NetworkConfiguration::new_local()).expect("Error creating network service");
//! service.register_protocol(Arc::new(MyHandler), "myproto", &[1u8]);
//! service.start().expect("Error starting service");
//!
//! // Wait for quit condition
//! // ...

View File

@@ -107,7 +107,7 @@ impl<Message> NetworkService<Message> where Message: Send + Sync + Clone + 'stat
if let Some(ref host) = *host {
info!("Unregistering handler");
let io = IoContext::new(self.io_service.channel(), 0); //TODO: take token id from host
host.stop(&io);
try!(host.stop(&io));
}
*host = None;
Ok(())

View File

@@ -97,7 +97,8 @@ impl NetworkProtocolHandler<TestProtocolMessage> for TestProtocol {
#[test]
fn net_service() {
let mut service = NetworkService::<TestProtocolMessage>::start(NetworkConfiguration::new_local()).expect("Error creating network service");
let service = NetworkService::<TestProtocolMessage>::new(NetworkConfiguration::new_local()).expect("Error creating network service");
service.start().unwrap();
service.register_protocol(Arc::new(TestProtocol::new(false)), "myproto", &[1u8]).unwrap();
}
@@ -108,12 +109,14 @@ fn net_connect() {
let mut config1 = NetworkConfiguration::new_local();
config1.use_secret = Some(key1.secret().clone());
config1.boot_nodes = vec![ ];
let mut service1 = NetworkService::<TestProtocolMessage>::start(config1).unwrap();
let mut service1 = NetworkService::<TestProtocolMessage>::new(config1).unwrap();
service1.start().unwrap();
let handler1 = TestProtocol::register(&mut service1, false);
let mut config2 = NetworkConfiguration::new_local();
info!("net_connect: local URL: {}", service1.local_url());
config2.boot_nodes = vec![ service1.local_url() ];
let mut service2 = NetworkService::<TestProtocolMessage>::start(config2).unwrap();
info!("net_connect: local URL: {}", service1.local_url().unwrap());
config2.boot_nodes = vec![ service1.local_url().unwrap() ];
let mut service2 = NetworkService::<TestProtocolMessage>::new(config2).unwrap();
service2.start().unwrap();
let handler2 = TestProtocol::register(&mut service2, false);
while !handler1.got_packet() && !handler2.got_packet() && (service1.stats().sessions() == 0 || service2.stats().sessions() == 0) {
thread::sleep(Duration::from_millis(50));
@@ -122,17 +125,28 @@ fn net_connect() {
assert!(service2.stats().sessions() >= 1);
}
#[test]
fn net_start_stop() {
let config = NetworkConfiguration::new_local();
let service = NetworkService::<TestProtocolMessage>::new(config).unwrap();
service.start().unwrap();
service.stop().unwrap();
service.start().unwrap();
}
#[test]
fn net_disconnect() {
let key1 = KeyPair::create().unwrap();
let mut config1 = NetworkConfiguration::new_local();
config1.use_secret = Some(key1.secret().clone());
config1.boot_nodes = vec![ ];
let mut service1 = NetworkService::<TestProtocolMessage>::start(config1).unwrap();
let mut service1 = NetworkService::<TestProtocolMessage>::new(config1).unwrap();
service1.start().unwrap();
let handler1 = TestProtocol::register(&mut service1, false);
let mut config2 = NetworkConfiguration::new_local();
config2.boot_nodes = vec![ service1.local_url() ];
let mut service2 = NetworkService::<TestProtocolMessage>::start(config2).unwrap();
config2.boot_nodes = vec![ service1.local_url().unwrap() ];
let mut service2 = NetworkService::<TestProtocolMessage>::new(config2).unwrap();
service2.start().unwrap();
let handler2 = TestProtocol::register(&mut service2, true);
while !(handler1.got_disconnect() && handler2.got_disconnect()) {
thread::sleep(Duration::from_millis(50));
@@ -144,7 +158,8 @@ fn net_disconnect() {
#[test]
fn net_timeout() {
let config = NetworkConfiguration::new_local();
let mut service = NetworkService::<TestProtocolMessage>::start(config).unwrap();
let mut service = NetworkService::<TestProtocolMessage>::new(config).unwrap();
service.start().unwrap();
let handler = TestProtocol::register(&mut service, false);
while !handler.got_timeout() {
thread::sleep(Duration::from_millis(50));