diff --git a/Cargo.lock b/Cargo.lock index addfaa23a..fc6c35554 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -700,6 +700,7 @@ dependencies = [ "rustc-hex 1.0.0 (registry+https://github.com/rust-lang/crates.io-index)", "serde_json 1.0.3 (registry+https://github.com/rust-lang/crates.io-index)", "slab 0.2.0 (registry+https://github.com/rust-lang/crates.io-index)", + "snappy 0.1.0", "time 0.1.38 (registry+https://github.com/rust-lang/crates.io-index)", "tiny-keccak 1.3.1 (registry+https://github.com/rust-lang/crates.io-index)", ] @@ -2265,7 +2266,7 @@ dependencies = [ [[package]] name = "parity-ui-precompiled" version = "1.4.0" -source = "git+https://github.com/paritytech/js-precompiled.git#fe0b4dbdfe6e7ebebb247d565d937fd0a0feca5f" +source = "git+https://github.com/paritytech/js-precompiled.git#29360e67331334a9ec3aafdb3725d8f7d8b5d2a1" dependencies = [ "parity-dapps-glue 1.9.0 (registry+https://github.com/rust-lang/crates.io-index)", ] diff --git a/hw/src/trezor.rs b/hw/src/trezor.rs index 239968120..a77d7233c 100644 --- a/hw/src/trezor.rs +++ b/hw/src/trezor.rs @@ -37,8 +37,8 @@ use trezor_sys::messages::{EthereumAddress, PinMatrixAck, MessageType, EthereumT const TREZOR_VID: u16 = 0x534c; const TREZOR_PIDS: [u16; 1] = [0x0001]; // Trezor v1, keeping this as an array to leave room for Trezor v2 which is in progress -const ETH_DERIVATION_PATH: [u32; 4] = [0x8000002C, 0x8000003C, 0x80000000, 0]; // m/44'/60'/0'/0 -const ETC_DERIVATION_PATH: [u32; 4] = [0x8000002C, 0x8000003D, 0x80000000, 0]; // m/44'/61'/0'/0 +const ETH_DERIVATION_PATH: [u32; 5] = [0x8000002C, 0x8000003C, 0x80000000, 0, 0]; // m/44'/60'/0'/0/0 +const ETC_DERIVATION_PATH: [u32; 5] = [0x8000002C, 0x8000003D, 0x80000000, 0, 0]; // m/44'/61'/0'/0/0 /// Hardware wallet error. diff --git a/js-old/src/ui/TokenImage/tokenImage.js b/js-old/src/ui/TokenImage/tokenImage.js index e0e66d22b..af7a80a02 100644 --- a/js-old/src/ui/TokenImage/tokenImage.js +++ b/js-old/src/ui/TokenImage/tokenImage.js @@ -32,14 +32,19 @@ class TokenImage extends Component { }).isRequired }; + state = { + error: false + }; + render () { + const { error } = this.state; const { api } = this.context; const { image, token } = this.props; const imageurl = token.image || image; let imagesrc = unknownImage; - if (imageurl) { + if (imageurl && !error) { const host = /^(\/)?api/.test(imageurl) ? api.dappsUrl : ''; @@ -49,11 +54,16 @@ class TokenImage extends Component { return ( { ); } + + handleError = () => { + this.setState({ error: true }); + }; } function mapStateToProps (iniState) { diff --git a/js/src/redux/providers/tokensActions.js b/js/src/redux/providers/tokensActions.js new file mode 100644 index 000000000..2e1e8c052 --- /dev/null +++ b/js/src/redux/providers/tokensActions.js @@ -0,0 +1,250 @@ +// Copyright 2015-2017 Parity Technologies (UK) Ltd. +// This file is part of Parity. + +// Parity is free software: you can redistribute it and/or modify +// it under the terms of the GNU General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. + +// Parity is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU General Public License for more details. + +// You should have received a copy of the GNU General Public License +// along with Parity. If not, see . + +import { chunk, uniq } from 'lodash'; +import store from 'store'; + +import Contracts from '~/contracts'; +import { LOG_KEYS, getLogger } from '~/config'; +import { fetchTokenIds, fetchTokensBasics, fetchTokensInfo, fetchTokensImages } from '~/util/tokens'; + +import { setAddressImage } from './imagesActions'; + +const TOKENS_CACHE_LS_KEY_PREFIX = '_parity::tokens::'; +const log = getLogger(LOG_KEYS.Balances); + +function _setTokens (tokens) { + return { + type: 'setTokens', + tokens + }; +} + +export function setTokens (nextTokens) { + return (dispatch, getState) => { + const { nodeStatus, tokens: prevTokens } = getState(); + const { tokenReg } = Contracts.get(); + const tokens = { + ...prevTokens, + ...nextTokens + }; + + return tokenReg.getContract() + .then((tokenRegContract) => { + const lsKey = TOKENS_CACHE_LS_KEY_PREFIX + nodeStatus.netChain; + + store.set(lsKey, { + tokenreg: tokenRegContract.address, + tokens + }); + }) + .catch((error) => { + console.error(error); + }) + .then(() => { + dispatch(_setTokens(nextTokens)); + }); + }; +} + +function loadCachedTokens (tokenRegContract) { + return (dispatch, getState) => { + const { nodeStatus } = getState(); + + const lsKey = TOKENS_CACHE_LS_KEY_PREFIX + nodeStatus.netChain; + const cached = store.get(lsKey); + + if (cached) { + // Check if we have data from the right contract + if (cached.tokenreg === tokenRegContract.address && cached.tokens) { + log.debug('found cached tokens', cached.tokens); + + // Fetch all the tokens images on load + // (it's the only thing that might have changed) + const tokenIndexes = Object.values(cached.tokens) + .filter((t) => t && t.fetched) + .map((t) => t.index); + + fetchTokensData(tokenRegContract, tokenIndexes)(dispatch, getState); + } else { + store.remove(lsKey); + } + } + }; +} + +export function loadTokens (options = {}) { + log.debug('loading tokens', Object.keys(options).length ? options : ''); + + return (dispatch, getState) => { + const { tokenReg } = Contracts.get(); + + return tokenReg.getContract() + .then((tokenRegContract) => { + loadCachedTokens(tokenRegContract)(dispatch, getState); + return fetchTokenIds(tokenRegContract.instance); + }) + .then((tokenIndexes) => loadTokensBasics(tokenIndexes, options)(dispatch, getState)) + .catch((error) => { + console.warn('tokens::loadTokens', error); + }); + }; +} + +export function loadTokensBasics (tokenIndexes, options) { + const limit = 64; + + return (dispatch, getState) => { + const { api } = getState(); + const { tokenReg } = Contracts.get(); + const nextTokens = {}; + const count = tokenIndexes.length; + + log.debug('loading basic tokens', tokenIndexes); + + if (count === 0) { + return Promise.resolve(); + } + + return tokenReg.getContract() + .then((tokenRegContract) => { + let promise = Promise.resolve(); + const first = tokenIndexes[0]; + const last = tokenIndexes[tokenIndexes.length - 1]; + + for (let from = first; from <= last; from += limit) { + // No need to fetch `limit` elements + const lowerLimit = Math.min(limit, last - from + 1); + + promise = promise + .then(() => fetchTokensBasics(api, tokenRegContract, from, lowerLimit)) + .then((results) => { + results + .forEach((token) => { + nextTokens[token.id] = token; + }); + }); + } + + return promise; + }) + .then(() => { + log.debug('fetched tokens basic info', nextTokens); + + dispatch(setTokens(nextTokens)); + }) + .catch((error) => { + console.warn('tokens::fetchTokens', error); + }); + }; +} + +export function fetchTokens (_tokenIndexes) { + const tokenIndexes = uniq(_tokenIndexes || []); + const tokenChunks = chunk(tokenIndexes, 64); + + return (dispatch, getState) => { + const { tokenReg } = Contracts.get(); + + return tokenReg.getContract() + .then((tokenRegContract) => { + let promise = Promise.resolve(); + + tokenChunks.forEach((tokenChunk) => { + promise = promise + .then(() => fetchTokensData(tokenRegContract, tokenChunk)(dispatch, getState)); + }); + + return promise; + }) + .then(() => { + log.debug('fetched token', getState().tokens); + }) + .catch((error) => { + console.warn('tokens::fetchTokens', error); + }); + }; +} + +/** + * Split the given token indexes between those for whom + * we already have some info, and thus just need to fetch + * the image, and those for whom we don't have anything and + * need to fetch all the info. + */ +function fetchTokensData (tokenRegContract, tokenIndexes) { + return (dispatch, getState) => { + const { api, tokens, images } = getState(); + const allTokens = Object.values(tokens); + + const tokensIndexesMap = allTokens + .reduce((map, token) => { + map[token.index] = token; + return map; + }, {}); + + const fetchedTokenIndexes = allTokens + .filter((token) => token.fetched) + .map((token) => token.index); + + const fullIndexes = []; + const partialIndexes = []; + + tokenIndexes.forEach((tokenIndex) => { + if (fetchedTokenIndexes.includes(tokenIndex)) { + partialIndexes.push(tokenIndex); + } else { + fullIndexes.push(tokenIndex); + } + }); + + log.debug('need to fully fetch', fullIndexes); + log.debug('need to partially fetch', partialIndexes); + + const fullPromise = fetchTokensInfo(api, tokenRegContract, fullIndexes); + const partialPromise = fetchTokensImages(api, tokenRegContract, partialIndexes) + .then((imagesResult) => { + return imagesResult.map((image, index) => { + const tokenIndex = partialIndexes[index]; + const token = tokensIndexesMap[tokenIndex]; + + return { ...token, image }; + }); + }); + + return Promise.all([ fullPromise, partialPromise ]) + .then(([ fullResults, partialResults ]) => { + log.debug('fetched', { fullResults, partialResults }); + + return [].concat(fullResults, partialResults) + .filter(({ address }) => !/0x0*$/.test(address)) + .reduce((tokens, token) => { + const { id, image, address } = token; + + // dispatch only the changed images + if (images[address] !== image) { + dispatch(setAddressImage(address, image, true)); + } + + tokens[id] = token; + return tokens; + }, {}); + }) + .then((tokens) => { + dispatch(setTokens(tokens)); + }); + }; +} diff --git a/js/src/util/tokens/index.js b/js/src/util/tokens/index.js new file mode 100644 index 000000000..11ad0f903 --- /dev/null +++ b/js/src/util/tokens/index.js @@ -0,0 +1,299 @@ +// Copyright 2015-2017 Parity Technologies (UK) Ltd. +// This file is part of Parity. + +// Parity is free software: you can redistribute it and/or modify +// it under the terms of the GNU General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. + +// Parity is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU General Public License for more details. + +// You should have received a copy of the GNU General Public License +// along with Parity. If not, see . + +import { range } from 'lodash'; +import BigNumber from 'bignumber.js'; + +import { hashToImageUrl } from '~/redux/util'; +import { sha3 } from '~/api/util/sha3'; +import imagesEthereum from '~/../assets/images/contracts/ethereum-black-64x64.png'; +import { + tokenAddresses as tokenAddressesBytcode, + tokensBalances as tokensBalancesBytecode +} from './bytecodes'; + +export const ETH_TOKEN = { + address: '', + format: new BigNumber(10).pow(18), + id: getTokenId('eth_native_token'), + image: imagesEthereum, + name: 'Ethereum', + native: true, + tag: 'ETH' +}; + +export function fetchTokenIds (tokenregInstance) { + return tokenregInstance.tokenCount + .call() + .then((numTokens) => { + const tokenIndexes = range(numTokens.toNumber()); + + return tokenIndexes; + }); +} + +export function fetchTokensBasics (api, tokenReg, start = 0, limit = 100) { + const tokenAddressesCallData = encode( + api, + [ 'address', 'uint', 'uint' ], + [ tokenReg.address, start, limit ] + ); + + return api.eth + .call({ data: tokenAddressesBytcode + tokenAddressesCallData }) + .then((result) => { + return decodeArray(api, 'address[]', result); + }) + .then((tokenAddresses) => { + return tokenAddresses.map((tokenAddress, index) => { + if (/^0x0*$/.test(tokenAddress)) { + return null; + } + + const tokenIndex = start + index; + + return { + address: tokenAddress, + id: getTokenId(tokenAddress, tokenIndex), + index: tokenIndex, + + fetched: false + }; + }); + }) + .then((tokens) => tokens.filter((token) => token)) + .then((tokens) => { + const randomAddress = sha3(`${Date.now()}`).substr(0, 42); + + return fetchTokensBalances(api, tokens, [randomAddress]) + .then((_balances) => { + const balances = _balances[randomAddress]; + + return tokens.filter(({ id }) => balances[id].eq(0)); + }); + }); +} + +export function fetchTokensInfo (api, tokenReg, tokenIndexes) { + const requests = tokenIndexes.map((tokenIndex) => { + const tokenCalldata = tokenReg.getCallData(tokenReg.instance.token, {}, [tokenIndex]); + + return { to: tokenReg.address, data: tokenCalldata }; + }); + + const calls = requests.map((req) => api.eth.call(req)); + const imagesPromise = fetchTokensImages(api, tokenReg, tokenIndexes); + + return Promise.all(calls) + .then((results) => { + return imagesPromise.then((images) => [ results, images ]); + }) + .then(([ results, images ]) => { + return results.map((rawTokenData, index) => { + const tokenIndex = tokenIndexes[index]; + const tokenData = tokenReg.instance.token + .decodeOutput(rawTokenData) + .map((t) => t.value); + + const [ address, tag, format, name ] = tokenData; + const image = images[index]; + + const token = { + address, + id: getTokenId(address, tokenIndex), + index: tokenIndex, + + format: format.toString(), + image, + name, + tag, + + fetched: true + }; + + return token; + }); + }); +} + +export function fetchTokensImages (api, tokenReg, tokenIndexes) { + const requests = tokenIndexes.map((tokenIndex) => { + const metaCalldata = tokenReg.getCallData(tokenReg.instance.meta, {}, [tokenIndex, 'IMG']); + + return { to: tokenReg.address, data: metaCalldata }; + }); + + const calls = requests.map((req) => api.eth.call(req)); + + return Promise.all(calls) + .then((results) => { + return results.map((rawImage) => { + const image = tokenReg.instance.meta.decodeOutput(rawImage)[0].value; + + return hashToImageUrl(image); + }); + }); +} + +/** + * `updates` should be in the shape: + * { + * [ who ]: [ tokenId ] // Array of tokens to updates + * } + * + * Returns a Promise resolved with the balances in the shape: + * { + * [ who ]: { [ tokenId ]: BigNumber } // The balances of `who` + * } + */ +export function fetchAccountsBalances (api, tokens, updates) { + const accountAddresses = Object.keys(updates); + + // Updates for the ETH balances + const ethUpdates = accountAddresses + .filter((accountAddress) => { + return updates[accountAddress].find((tokenId) => tokenId === ETH_TOKEN.id); + }) + .reduce((nextUpdates, accountAddress) => { + nextUpdates[accountAddress] = [ETH_TOKEN.id]; + return nextUpdates; + }, {}); + + // Updates for Tokens balances + const tokenUpdates = Object.keys(updates) + .reduce((nextUpdates, accountAddress) => { + const tokenIds = updates[accountAddress].filter((tokenId) => tokenId !== ETH_TOKEN.id); + + if (tokenIds.length > 0) { + nextUpdates[accountAddress] = tokenIds; + } + + return nextUpdates; + }, {}); + + let ethBalances = {}; + let tokensBalances = {}; + + const ethPromise = fetchEthBalances(api, Object.keys(ethUpdates)) + .then((_ethBalances) => { + ethBalances = _ethBalances; + }); + + const tokenPromise = Object.keys(tokenUpdates) + .reduce((tokenPromise, accountAddress) => { + const tokenIds = tokenUpdates[accountAddress]; + const updateTokens = tokens + .filter((t) => tokenIds.includes(t.id)); + + return tokenPromise + .then(() => fetchTokensBalances(api, updateTokens, [ accountAddress ])) + .then((balances) => { + tokensBalances[accountAddress] = balances[accountAddress]; + }); + }, Promise.resolve()); + + return Promise.all([ ethPromise, tokenPromise ]) + .then(() => { + const balances = Object.assign({}, tokensBalances); + + Object.keys(ethBalances).forEach((accountAddress) => { + if (!balances[accountAddress]) { + balances[accountAddress] = {}; + } + + balances[accountAddress] = Object.assign( + {}, + balances[accountAddress], + ethBalances[accountAddress] + ); + }); + + return balances; + }); +} + +function fetchEthBalances (api, accountAddresses) { + const promises = accountAddresses + .map((accountAddress) => api.eth.getBalance(accountAddress)); + + return Promise.all(promises) + .then((balancesArray) => { + return balancesArray.reduce((balances, balance, index) => { + balances[accountAddresses[index]] = { + [ETH_TOKEN.id]: balance + }; + + return balances; + }, {}); + }); +} + +function fetchTokensBalances (api, tokens, accountAddresses) { + const tokenAddresses = tokens.map((t) => t.address); + const tokensBalancesCallData = encode( + api, + [ 'address[]', 'address[]' ], + [ accountAddresses, tokenAddresses ] + ); + + return api.eth + .call({ data: tokensBalancesBytecode + tokensBalancesCallData }) + .then((result) => { + const rawBalances = decodeArray(api, 'uint[]', result); + const balances = {}; + + accountAddresses.forEach((accountAddress, accountIndex) => { + const balance = {}; + const preIndex = accountIndex * tokenAddresses.length; + + tokenAddresses.forEach((tokenAddress, tokenIndex) => { + const index = preIndex + tokenIndex; + const token = tokens[tokenIndex]; + + balance[token.id] = rawBalances[index]; + }); + + balances[accountAddress] = balance; + }); + + return balances; + }); +} + +function getTokenId (...args) { + return sha3(args.join('')).slice(0, 10); +} + +function encode (api, types, values) { + return api.util.abiEncode( + null, + types, + values + ).replace('0x', ''); +} + +function decodeArray (api, type, data) { + return api.util + .abiDecode( + [type], + [ + '0x', + (32).toString(16).padStart(64, 0), + data.replace('0x', '') + ].join('') + )[0] + .map((t) => t.value); +} diff --git a/util/network/Cargo.toml b/util/network/Cargo.toml index 0960715bf..e989fb599 100644 --- a/util/network/Cargo.toml +++ b/util/network/Cargo.toml @@ -33,6 +33,7 @@ path = { path = "../path" } ethcore-logger = { path ="../../logger" } ipnetwork = "0.12.6" hash = { path = "../hash" } +snappy = { path = "../snappy" } serde_json = "1.0" [features] diff --git a/util/network/src/connection.rs b/util/network/src/connection.rs index fd61b4b38..726952648 100644 --- a/util/network/src/connection.rs +++ b/util/network/src/connection.rs @@ -40,6 +40,7 @@ use crypto; const ENCRYPTED_HEADER_LEN: usize = 32; const RECIEVE_PAYLOAD_TIMEOUT: u64 = 30000; +pub const MAX_PAYLOAD_SIZE: usize = (1 << 24) - 1; pub trait GenericSocket : Read + Write { } @@ -345,7 +346,7 @@ impl EncryptedConnection { ingress_mac: ingress_mac, read_state: EncryptedConnectionState::Header, protocol_id: 0, - payload_len: 0 + payload_len: 0, }; enc.connection.expect(ENCRYPTED_HEADER_LEN); Ok(enc) @@ -355,7 +356,7 @@ impl EncryptedConnection { pub fn send_packet(&mut self, io: &IoContext, payload: &[u8]) -> Result<(), NetworkError> where Message: Send + Clone + Sync + 'static { let mut header = RlpStream::new(); let len = payload.len(); - if len >= (1 << 24) { + if len > MAX_PAYLOAD_SIZE { return Err(NetworkError::OversizedPacket); } header.append_raw(&[(len >> 16) as u8, (len >> 8) as u8, len as u8], 1); diff --git a/util/network/src/error.rs b/util/network/src/error.rs index 54773d573..96fc1ff23 100644 --- a/util/network/src/error.rs +++ b/util/network/src/error.rs @@ -19,6 +19,7 @@ use rlp::*; use std::fmt; use ethkey::Error as KeyError; use crypto::Error as CryptoError; +use snappy; #[derive(Debug, Copy, Clone, PartialEq, Eq)] pub enum DisconnectReason @@ -107,6 +108,8 @@ pub enum NetworkError { StdIo(::std::io::Error), /// Packet size is over the protocol limit. OversizedPacket, + /// Decompression error. + Decompression(snappy::InvalidInput), } impl fmt::Display for NetworkError { @@ -126,6 +129,7 @@ impl fmt::Display for NetworkError { StdIo(ref err) => format!("{}", err), InvalidNodeId => "Invalid node id".into(), OversizedPacket => "Packet is too large".into(), + Decompression(ref err) => format!("Error decompressing packet: {}", err), }; f.write_fmt(format_args!("Network error ({})", msg)) @@ -162,6 +166,12 @@ impl From for NetworkError { } } +impl From for NetworkError { + fn from(err: snappy::InvalidInput) -> NetworkError { + NetworkError::Decompression(err) + } +} + impl From<::std::net::AddrParseError> for NetworkError { fn from(err: ::std::net::AddrParseError) -> NetworkError { NetworkError::AddressParse(err) diff --git a/util/network/src/host.rs b/util/network/src/host.rs index 13e5f74a3..3d21bf5fe 100644 --- a/util/network/src/host.rs +++ b/util/network/src/host.rs @@ -256,7 +256,7 @@ impl<'s> NetworkContext<'s> { pub fn send_protocol(&self, protocol: ProtocolId, peer: PeerId, packet_id: PacketId, data: Vec) -> Result<(), NetworkError> { let session = self.resolve_session(peer); if let Some(session) = session { - session.lock().send_packet(self.io, protocol, packet_id as u8, &data)?; + session.lock().send_packet(self.io, Some(protocol), packet_id as u8, &data)?; } else { trace!(target: "network", "Send: Peer no longer exist") } @@ -938,7 +938,7 @@ impl Host { for (p, packet_id, data) in packet_data { let reserved = self.reserved_nodes.read(); if let Some(h) = handlers.get(&p).clone() { - h.read(&NetworkContext::new(io, p, Some(session.clone()), self.sessions.clone(), &reserved), &token, packet_id, &data[1..]); + h.read(&NetworkContext::new(io, p, Some(session.clone()), self.sessions.clone(), &reserved), &token, packet_id, &data); } } } diff --git a/util/network/src/lib.rs b/util/network/src/lib.rs index 6257118b6..8f2ccb8d4 100644 --- a/util/network/src/lib.rs +++ b/util/network/src/lib.rs @@ -81,6 +81,7 @@ extern crate ethcore_logger; extern crate ipnetwork; extern crate hash; extern crate serde_json; +extern crate snappy; #[macro_use] extern crate log; @@ -115,7 +116,7 @@ pub use node_table::{is_valid_node_url, NodeId}; use ipnetwork::{IpNetwork, IpNetworkError}; use std::str::FromStr; -const PROTOCOL_VERSION: u32 = 4; +const PROTOCOL_VERSION: u32 = 5; /// Network IO protocol handler. This needs to be implemented for each new subprotocol. /// All the handler function are called from within IO event loop. diff --git a/util/network/src/session.rs b/util/network/src/session.rs index 992081237..cf6c196e3 100644 --- a/util/network/src/session.rs +++ b/util/network/src/session.rs @@ -25,7 +25,7 @@ use mio::deprecated::{Handler, EventLoop}; use mio::tcp::*; use bigint::hash::*; use rlp::*; -use connection::{EncryptedConnection, Packet, Connection}; +use connection::{EncryptedConnection, Packet, Connection, MAX_PAYLOAD_SIZE}; use handshake::Handshake; use io::{IoContext, StreamToken}; use error::{NetworkError, DisconnectReason}; @@ -33,10 +33,13 @@ use host::*; use node_table::NodeId; use stats::NetworkStats; use time; +use snappy; // Timeout must be less than (interval - 1). const PING_TIMEOUT_SEC: u64 = 60; const PING_INTERVAL_SEC: u64 = 120; +const MIN_PROTOCOL_VERSION: u32 = 4; +const MIN_COMPRESSION_PROTOCOL_VERSION: u32 = 5; #[derive(Debug, Clone)] enum ProtocolState { @@ -61,6 +64,7 @@ pub struct Session { state: State, // Protocol states -- accumulates pending packets until signaled as ready. protocol_states: HashMap, + compression: bool, } enum State { @@ -198,6 +202,7 @@ impl Session { pong_time_ns: None, expired: false, protocol_states: HashMap::new(), + compression: false, }) } @@ -211,7 +216,6 @@ impl Session { }; self.state = State::Session(connection); self.write_hello(io, host)?; - self.send_ping(io)?; Ok(()) } @@ -326,28 +330,43 @@ impl Session { } /// Send a protocol packet to peer. - pub fn send_packet(&mut self, io: &IoContext, protocol: [u8; 3], packet_id: u8, data: &[u8]) -> Result<(), NetworkError> + pub fn send_packet(&mut self, io: &IoContext, protocol: Option<[u8; 3]>, packet_id: u8, data: &[u8]) -> Result<(), NetworkError> where Message: Send + Sync + Clone { - if self.info.capabilities.is_empty() || !self.had_hello { - debug!(target: "network", "Sending to unconfirmed session {}, protocol: {}, packet: {}", self.token(), str::from_utf8(&protocol[..]).unwrap_or("??"), packet_id); + if protocol.is_some() && (self.info.capabilities.is_empty() || !self.had_hello) { + debug!(target: "network", "Sending to unconfirmed session {}, protocol: {:?}, packet: {}", self.token(), protocol.as_ref().map(|p| str::from_utf8(&p[..]).unwrap_or("??")), packet_id); return Err(From::from(NetworkError::BadProtocol)); } if self.expired() { return Err(From::from(NetworkError::Expired)); } let mut i = 0usize; - while protocol != self.info.capabilities[i].protocol { - i += 1; - if i == self.info.capabilities.len() { - debug!(target: "network", "Unknown protocol: {:?}", protocol); - return Ok(()) - } - } - let pid = self.info.capabilities[i].id_offset + packet_id; + let pid = match protocol { + Some(protocol) => { + while protocol != self.info.capabilities[i].protocol { + i += 1; + if i == self.info.capabilities.len() { + debug!(target: "network", "Unknown protocol: {:?}", protocol); + return Ok(()) + } + } + self.info.capabilities[i].id_offset + packet_id + }, + None => packet_id + }; let mut rlp = RlpStream::new(); rlp.append(&(pid as u32)); - rlp.append_raw(data, 1); - self.send(io, rlp) + let mut compressed = Vec::new(); + let mut payload = data; // create a reference with local lifetime + if self.compression { + if payload.len() > MAX_PAYLOAD_SIZE { + return Err(NetworkError::OversizedPacket); + } + let len = snappy::compress_into(&payload, &mut compressed); + trace!(target: "network", "compressed {} to {}", payload.len(), len); + payload = &compressed[0..len]; + } + rlp.append_raw(payload, 1); + self.send(io, &rlp.drain()) } /// Keep this session alive. Returns false if ping timeout happened @@ -396,14 +415,23 @@ impl Session { if packet_id != PACKET_HELLO && packet_id != PACKET_DISCONNECT && !self.had_hello { return Err(From::from(NetworkError::BadProtocol)); } + let data = if self.compression { + let compressed = &packet.data[1..]; + if snappy::decompressed_len(&compressed)? > MAX_PAYLOAD_SIZE { + return Err(NetworkError::OversizedPacket); + } + snappy::decompress(&compressed)? + } else { + packet.data[1..].to_owned() + }; match packet_id { PACKET_HELLO => { - let rlp = UntrustedRlp::new(&packet.data[1..]); //TODO: validate rlp expected size + let rlp = UntrustedRlp::new(&data); //TODO: validate rlp expected size self.read_hello(io, &rlp, host)?; Ok(SessionData::Ready) }, PACKET_DISCONNECT => { - let rlp = UntrustedRlp::new(&packet.data[1..]); + let rlp = UntrustedRlp::new(&data); let reason: u8 = rlp.val_at(0)?; if self.had_hello { debug!(target:"network", "Disconnected: {}: {:?}", self.token(), DisconnectReason::from_u8(reason)); @@ -439,11 +467,11 @@ impl Session { match *self.protocol_states.entry(protocol).or_insert_with(|| ProtocolState::Pending(Vec::new())) { ProtocolState::Connected => { trace!(target: "network", "Packet {} mapped to {:?}:{}, i={}, capabilities={:?}", packet_id, protocol, protocol_packet_id, i, self.info.capabilities); - Ok(SessionData::Packet { data: packet.data, protocol: protocol, packet_id: protocol_packet_id } ) + Ok(SessionData::Packet { data: data, protocol: protocol, packet_id: protocol_packet_id } ) } ProtocolState::Pending(ref mut pending) => { trace!(target: "network", "Packet {} deferred until protocol connection event completion", packet_id); - pending.push((packet.data, protocol_packet_id)); + pending.push((data, protocol_packet_id)); Ok(SessionData::Continue) } @@ -465,7 +493,7 @@ impl Session { .append_list(&host.capabilities) .append(&host.local_endpoint.address.port()) .append(host.id()); - self.send(io, rlp) + self.send(io, &rlp.drain()) } fn read_hello(&mut self, io: &IoContext, rlp: &UntrustedRlp, host: &HostInfo) -> Result<(), NetworkError> @@ -494,8 +522,7 @@ impl Session { while i < caps.len() { if caps.iter().any(|c| c.protocol == caps[i].protocol && c.version > caps[i].version) { caps.remove(i); - } - else { + } else { i += 1; } } @@ -520,52 +547,46 @@ impl Session { trace!(target: "network", "No common capabilities with peer."); return Err(From::from(self.disconnect(io, DisconnectReason::UselessPeer))); } - if protocol != host.protocol_version { + if protocol < MIN_PROTOCOL_VERSION { trace!(target: "network", "Peer protocol version mismatch: {}", protocol); return Err(From::from(self.disconnect(io, DisconnectReason::UselessPeer))); } + self.compression = protocol >= MIN_COMPRESSION_PROTOCOL_VERSION; + self.send_ping(io)?; self.had_hello = true; Ok(()) } /// Senf ping packet pub fn send_ping(&mut self, io: &IoContext) -> Result<(), NetworkError> where Message: Send + Sync + Clone { - self.send(io, Session::prepare(PACKET_PING)?)?; + self.send_packet(io, None, PACKET_PING, &EMPTY_LIST_RLP)?; self.ping_time_ns = time::precise_time_ns(); self.pong_time_ns = None; Ok(()) } fn send_pong(&mut self, io: &IoContext) -> Result<(), NetworkError> where Message: Send + Sync + Clone { - self.send(io, Session::prepare(PACKET_PONG)?) + self.send_packet(io, None, PACKET_PONG, &EMPTY_LIST_RLP) } /// Disconnect this session pub fn disconnect(&mut self, io: &IoContext, reason: DisconnectReason) -> NetworkError where Message: Send + Sync + Clone { if let State::Session(_) = self.state { let mut rlp = RlpStream::new(); - rlp.append(&(PACKET_DISCONNECT as u32)); rlp.begin_list(1); rlp.append(&(reason as u32)); - self.send(io, rlp).ok(); + self.send_packet(io, None, PACKET_DISCONNECT, &rlp.drain()).ok(); } NetworkError::Disconnect(reason) } - fn prepare(packet_id: u8) -> Result { - let mut rlp = RlpStream::new(); - rlp.append(&(packet_id as u32)); - rlp.begin_list(0); - Ok(rlp) - } - - fn send(&mut self, io: &IoContext, rlp: RlpStream) -> Result<(), NetworkError> where Message: Send + Sync + Clone { + fn send(&mut self, io: &IoContext, data: &[u8]) -> Result<(), NetworkError> where Message: Send + Sync + Clone { match self.state { State::Handshake(_) => { warn!(target:"network", "Unexpected send request"); }, State::Session(ref mut s) => { - s.send_packet(io, &rlp.out())? + s.send_packet(io, data)? }, } Ok(())