Connection filter (#6359)

This commit is contained in:
Arkadiy Paronyan
2017-08-29 14:38:01 +02:00
committed by Gav Wood
parent 96e9a73a1b
commit d520aa2633
21 changed files with 346 additions and 24 deletions

View File

@@ -0,0 +1,31 @@
// Copyright 2015-2017 Parity Technologies (UK) Ltd.
// This file is part of Parity.
// Parity is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
// Parity is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
// You should have received a copy of the GNU General Public License
// along with Parity. If not, see <http://www.gnu.org/licenses/>.
//! Connection filter trait.
use super::NodeId;
/// Filtered connection direction.
pub enum ConnectionDirection {
Inbound,
Outbound,
}
/// Connection filter. Each connection is checked against `connection_allowed`.
pub trait ConnectionFilter : Send + Sync {
/// Filter a connection. Returns `true` if connection should be allowed. `false` if rejected.
fn connection_allowed(&self, own_id: &NodeId, connecting_id: &NodeId, direction: ConnectionDirection) -> bool;
}

View File

@@ -42,6 +42,7 @@ use discovery::{Discovery, TableUpdates, NodeEntry};
use ip_utils::{map_external_address, select_public_address};
use path::restrict_permissions_owner;
use parking_lot::{Mutex, RwLock};
use connection_filter::{ConnectionFilter, ConnectionDirection};
type Slab<T> = ::slab::Slab<T, usize>;
@@ -380,11 +381,12 @@ pub struct Host {
reserved_nodes: RwLock<HashSet<NodeId>>,
num_sessions: AtomicUsize,
stopping: AtomicBool,
filter: Option<Arc<ConnectionFilter>>,
}
impl Host {
/// Create a new instance
pub fn new(mut config: NetworkConfiguration, stats: Arc<NetworkStats>) -> Result<Host, NetworkError> {
pub fn new(mut config: NetworkConfiguration, stats: Arc<NetworkStats>, filter: Option<Arc<ConnectionFilter>>) -> Result<Host, NetworkError> {
let mut listen_address = match config.listen_address {
None => SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::new(0, 0, 0, 0), DEFAULT_PORT)),
Some(addr) => addr,
@@ -437,6 +439,7 @@ impl Host {
reserved_nodes: RwLock::new(HashSet::new()),
num_sessions: AtomicUsize::new(0),
stopping: AtomicBool::new(false),
filter: filter,
};
for n in boot_nodes {
@@ -691,8 +694,12 @@ impl Host {
let max_handshakes_per_round = max_handshakes / 2;
let mut started: usize = 0;
for id in nodes.filter(|id| !self.have_session(id) && !self.connecting_to(id) && *id != self_id)
.take(min(max_handshakes_per_round, max_handshakes - handshake_count)) {
for id in nodes.filter(|id|
!self.have_session(id) &&
!self.connecting_to(id) &&
*id != self_id &&
self.filter.as_ref().map_or(true, |f| f.connection_allowed(&self_id, &id, ConnectionDirection::Outbound))
).take(min(max_handshakes_per_round, max_handshakes - handshake_count)) {
self.connect_peer(&id, io);
started += 1;
}
@@ -827,7 +834,7 @@ impl Host {
Ok(SessionData::Ready) => {
self.num_sessions.fetch_add(1, AtomicOrdering::SeqCst);
let session_count = self.session_count();
let (min_peers, max_peers, reserved_only) = {
let (min_peers, max_peers, reserved_only, self_id) = {
let info = self.info.read();
let mut max_peers = info.config.max_peers;
for cap in s.info.capabilities.iter() {
@@ -836,7 +843,7 @@ impl Host {
break;
}
}
(info.config.min_peers as usize, max_peers as usize, info.config.non_reserved_mode == NonReservedPeerMode::Deny)
(info.config.min_peers as usize, max_peers as usize, info.config.non_reserved_mode == NonReservedPeerMode::Deny, info.id().clone())
};
let id = s.id().expect("Ready session always has id").clone();
@@ -852,6 +859,14 @@ impl Host {
break;
}
}
if !self.filter.as_ref().map_or(true, |f| f.connection_allowed(&self_id, &id, ConnectionDirection::Inbound)) {
trace!(target: "network", "Inbound connection not allowed for {:?}", id);
s.disconnect(io, DisconnectReason::UnexpectedIdentity);
kill = true;
break;
}
ready_id = Some(id);
// Add it to the node table
@@ -1266,7 +1281,7 @@ fn host_client_url() {
let mut config = NetworkConfiguration::new_local();
let key = "6f7b0d801bc7b5ce7bbd930b84fd0369b3eb25d09be58d64ba811091046f3aa2".parse().unwrap();
config.use_secret = Some(key);
let host: Host = Host::new(config, Arc::new(NetworkStats::new())).unwrap();
let host: Host = Host::new(config, Arc::new(NetworkStats::new()), None).unwrap();
assert!(host.local_url().starts_with("enode://101b3ef5a4ea7a1c7928e24c4c75fd053c235d7b80c22ae5c03d145d0ac7396e2a4ffff9adee3133a7b05044a5cee08115fd65145e5165d646bde371010d803c@"));
}

View File

@@ -44,7 +44,7 @@
//! }
//!
//! fn main () {
//! let mut service = NetworkService::new(NetworkConfiguration::new_local()).expect("Error creating network service");
//! let mut service = NetworkService::new(NetworkConfiguration::new_local(), None).expect("Error creating network service");
//! service.start().expect("Error starting service");
//! service.register_protocol(Arc::new(MyHandler), *b"myp", 1, &[1u8]);
//!
@@ -95,6 +95,7 @@ mod error;
mod node_table;
mod stats;
mod ip_utils;
mod connection_filter;
#[cfg(test)]
mod tests;
@@ -104,6 +105,7 @@ pub use service::NetworkService;
pub use error::NetworkError;
pub use stats::NetworkStats;
pub use session::SessionInfo;
pub use connection_filter::{ConnectionFilter, ConnectionDirection};
pub use io::TimerToken;
pub use node_table::{is_valid_node_url, NodeId};

View File

@@ -22,6 +22,7 @@ use io::*;
use parking_lot::RwLock;
use std::sync::Arc;
use ansi_term::Colour;
use connection_filter::ConnectionFilter;
struct HostHandler {
public_url: RwLock<Option<String>>
@@ -48,11 +49,12 @@ pub struct NetworkService {
stats: Arc<NetworkStats>,
host_handler: Arc<HostHandler>,
config: NetworkConfiguration,
filter: Option<Arc<ConnectionFilter>>,
}
impl NetworkService {
/// Starts IO event loop
pub fn new(config: NetworkConfiguration) -> Result<NetworkService, NetworkError> {
pub fn new(config: NetworkConfiguration, filter: Option<Arc<ConnectionFilter>>) -> Result<NetworkService, NetworkError> {
let host_handler = Arc::new(HostHandler { public_url: RwLock::new(None) });
let io_service = IoService::<NetworkIoMessage>::start()?;
@@ -65,6 +67,7 @@ impl NetworkService {
host: RwLock::new(None),
config: config,
host_handler: host_handler,
filter: filter,
})
}
@@ -115,7 +118,7 @@ impl NetworkService {
pub fn start(&self) -> Result<(), NetworkError> {
let mut host = self.host.write();
if host.is_none() {
let h = Arc::new(Host::new(self.config.clone(), self.stats.clone())?);
let h = Arc::new(Host::new(self.config.clone(), self.stats.clone(), self.filter.clone())?);
self.io_service.register_handler(h.clone())?;
*host = Some(h);
}

View File

@@ -92,7 +92,7 @@ impl NetworkProtocolHandler for TestProtocol {
#[test]
fn net_service() {
let service = NetworkService::new(NetworkConfiguration::new_local()).expect("Error creating network service");
let service = NetworkService::new(NetworkConfiguration::new_local(), None).expect("Error creating network service");
service.start().unwrap();
service.register_protocol(Arc::new(TestProtocol::new(false)), *b"myp", 1, &[1u8]).unwrap();
}
@@ -104,13 +104,13 @@ fn net_connect() {
let mut config1 = NetworkConfiguration::new_local();
config1.use_secret = Some(key1.secret().clone());
config1.boot_nodes = vec![ ];
let mut service1 = NetworkService::new(config1).unwrap();
let mut service1 = NetworkService::new(config1, None).unwrap();
service1.start().unwrap();
let handler1 = TestProtocol::register(&mut service1, false);
let mut config2 = NetworkConfiguration::new_local();
info!("net_connect: local URL: {}", service1.local_url().unwrap());
config2.boot_nodes = vec![ service1.local_url().unwrap() ];
let mut service2 = NetworkService::new(config2).unwrap();
let mut service2 = NetworkService::new(config2, None).unwrap();
service2.start().unwrap();
let handler2 = TestProtocol::register(&mut service2, false);
while !handler1.got_packet() && !handler2.got_packet() && (service1.stats().sessions() == 0 || service2.stats().sessions() == 0) {
@@ -123,7 +123,7 @@ fn net_connect() {
#[test]
fn net_start_stop() {
let config = NetworkConfiguration::new_local();
let service = NetworkService::new(config).unwrap();
let service = NetworkService::new(config, None).unwrap();
service.start().unwrap();
service.stop().unwrap();
service.start().unwrap();
@@ -135,12 +135,12 @@ fn net_disconnect() {
let mut config1 = NetworkConfiguration::new_local();
config1.use_secret = Some(key1.secret().clone());
config1.boot_nodes = vec![ ];
let mut service1 = NetworkService::new(config1).unwrap();
let mut service1 = NetworkService::new(config1, None).unwrap();
service1.start().unwrap();
let handler1 = TestProtocol::register(&mut service1, false);
let mut config2 = NetworkConfiguration::new_local();
config2.boot_nodes = vec![ service1.local_url().unwrap() ];
let mut service2 = NetworkService::new(config2).unwrap();
let mut service2 = NetworkService::new(config2, None).unwrap();
service2.start().unwrap();
let handler2 = TestProtocol::register(&mut service2, true);
while !(handler1.got_disconnect() && handler2.got_disconnect()) {
@@ -153,7 +153,7 @@ fn net_disconnect() {
#[test]
fn net_timeout() {
let config = NetworkConfiguration::new_local();
let mut service = NetworkService::new(config).unwrap();
let mut service = NetworkService::new(config, None).unwrap();
service.start().unwrap();
let handler = TestProtocol::register(&mut service, false);
while !handler.got_timeout() {