diff --git a/secret_store/src/key_server_cluster/cluster.rs b/secret_store/src/key_server_cluster/cluster.rs index bd857bda8..209126605 100644 --- a/secret_store/src/key_server_cluster/cluster.rs +++ b/secret_store/src/key_server_cluster/cluster.rs @@ -16,7 +16,7 @@ use std::io; use std::time; -use std::sync::Arc; +use std::sync::{Arc, Weak}; use std::sync::atomic::{AtomicBool, Ordering}; use std::collections::{BTreeMap, BTreeSet, VecDeque}; use std::collections::btree_map::Entry; @@ -27,8 +27,8 @@ use parking_lot::{RwLock, Mutex}; use tokio_io::IoFuture; use tokio_core::reactor::{Handle, Remote, Interval}; use tokio_core::net::{TcpListener, TcpStream}; -use ethkey::{Secret, KeyPair, Signature, Random, Generator}; -use key_server_cluster::{Error, NodeId, SessionId, AclStorage, KeyStorage}; +use ethkey::{Public, Secret, KeyPair, Signature, Random, Generator}; +use key_server_cluster::{Error, NodeId, SessionId, AclStorage, KeyStorage, DocumentEncryptedKeyShadow}; use key_server_cluster::message::{self, Message, ClusterMessage, EncryptionMessage, DecryptionMessage}; use key_server_cluster::decryption_session::{SessionImpl as DecryptionSessionImpl, SessionState as DecryptionSessionState, SessionParams as DecryptionSessionParams, Session as DecryptionSession, DecryptionSessionId}; @@ -236,6 +236,28 @@ pub struct Connection { last_message_time: Mutex, } +/// Encryption session implementation, which removes session from cluster on drop. +struct EncryptionSessionWrapper { + /// Wrapped session. + session: Arc, + /// Session Id. + session_id: SessionId, + /// Cluster data reference. + cluster: Weak, +} + +/// Decryption session implementation, which removes session from cluster on drop. +struct DecryptionSessionWrapper { + /// Wrapped session. + session: Arc, + /// Session Id. + session_id: SessionId, + /// Session sub id. + access_key: Secret, + /// Cluster data reference. + cluster: Weak, +} + impl ClusterCore { pub fn new(handle: Handle, config: ClusterConfiguration) -> Result, Error> { let listen_address = make_socket_address(&config.listen_address.0, config.listen_address.1)?; @@ -1011,9 +1033,9 @@ impl ClusterClient for ClusterClientImpl { connected_nodes.insert(self.data.self_key_pair.public().clone()); let cluster = Arc::new(ClusterView::new(self.data.clone(), connected_nodes.clone())); - let session = self.data.sessions.new_encryption_session(self.data.self_key_pair.public().clone(), session_id, cluster)?; + let session = self.data.sessions.new_encryption_session(self.data.self_key_pair.public().clone(), session_id.clone(), cluster)?; session.initialize(threshold, connected_nodes)?; - Ok(session) + Ok(EncryptionSessionWrapper::new(Arc::downgrade(&self.data), session_id, session)) } fn new_decryption_session(&self, session_id: SessionId, requestor_signature: Signature, is_shadow_decryption: bool) -> Result, Error> { @@ -1022,9 +1044,9 @@ impl ClusterClient for ClusterClientImpl { let access_key = Random.generate()?.secret().clone(); let cluster = Arc::new(ClusterView::new(self.data.clone(), connected_nodes.clone())); - let session = self.data.sessions.new_decryption_session(self.data.self_key_pair.public().clone(), session_id, access_key, cluster)?; + let session = self.data.sessions.new_decryption_session(self.data.self_key_pair.public().clone(), session_id, access_key.clone(), cluster)?; session.initialize(requestor_signature, is_shadow_decryption)?; - Ok(session) + Ok(DecryptionSessionWrapper::new(Arc::downgrade(&self.data), session_id, access_key, session)) } #[cfg(test)] @@ -1043,6 +1065,64 @@ impl ClusterClient for ClusterClientImpl { } } +impl EncryptionSessionWrapper { + pub fn new(cluster: Weak, session_id: SessionId, session: Arc) -> Arc { + Arc::new(EncryptionSessionWrapper { + session: session, + session_id: session_id, + cluster: cluster, + }) + } +} + +impl EncryptionSession for EncryptionSessionWrapper { + fn state(&self) -> EncryptionSessionState { + self.session.state() + } + + fn wait(&self, timeout: Option) -> Result { + self.session.wait(timeout) + } + + #[cfg(test)] + fn joint_public_key(&self) -> Option> { + self.session.joint_public_key() + } +} + +impl Drop for EncryptionSessionWrapper { + fn drop(&mut self) { + if let Some(cluster) = self.cluster.upgrade() { + cluster.sessions.remove_encryption_session(&self.session_id); + } + } +} + +impl DecryptionSessionWrapper { + pub fn new(cluster: Weak, session_id: SessionId, access_key: Secret, session: Arc) -> Arc { + Arc::new(DecryptionSessionWrapper { + session: session, + session_id: session_id, + access_key: access_key, + cluster: cluster, + }) + } +} + +impl DecryptionSession for DecryptionSessionWrapper { + fn wait(&self) -> Result { + self.session.wait() + } +} + +impl Drop for DecryptionSessionWrapper { + fn drop(&mut self) { + if let Some(cluster) = self.cluster.upgrade() { + cluster.sessions.remove_decryption_session(&self.session_id, &self.access_key); + } + } +} + fn make_socket_address(address: &str, port: u16) -> Result { let ip_address: IpAddr = address.parse().map_err(|_| Error::InvalidNodeAddress)?; Ok(SocketAddr::new(ip_address, port))