diff --git a/sync/src/light_sync/response.rs b/sync/src/light_sync/response.rs index 2d64bb36b..131f7e4e2 100644 --- a/sync/src/light_sync/response.rs +++ b/sync/src/light_sync/response.rs @@ -21,10 +21,10 @@ use std::fmt; use ethcore::header::Header; use light::request::{HashOrNumber, Headers as HeadersRequest}; use rlp::{DecoderError, UntrustedRlp, View}; -use util::H256; +use util::{Bytes, H256}; /// Errors found when decoding headers and verifying with basic constraints. -#[derive(Debug, Clone)] +#[derive(Debug, PartialEq)] pub enum BasicError { /// Wrong skip value: expected, found (if any). WrongSkip(u64, Option), @@ -46,7 +46,7 @@ impl From for BasicError { impl fmt::Display for BasicError { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - write!(f, "Header response verification error: "); + try!(write!(f, "Header response verification error: ")); match *self { BasicError::WrongSkip(ref exp, ref got) @@ -58,7 +58,7 @@ impl fmt::Display for BasicError { BasicError::TooManyHeaders(ref max, ref got) => write!(f, "too many headers (max {}, got {})", max, got), BasicError::Decoder(ref err) - => write!(f, "invalid encoding ({})", err), + => write!(f, "{}", err), } } } @@ -73,7 +73,7 @@ pub trait Constraint { /// Decode a response and do basic verification against a request. pub fn decode_and_verify(headers: &[Bytes], request: &HeadersRequest) -> Result, BasicError> { - let headers: Vec<_> = try!!(headers.iter().map(|x| UntrustedRlp::new(&x).as_val()).collect()); + let headers: Vec<_> = try!(headers.iter().map(|x| UntrustedRlp::new(&x).as_val()).collect()); let reverse = request.reverse; @@ -84,6 +84,8 @@ pub fn decode_and_verify(headers: &[Bytes], request: &HeadersRequest) -> Result< } try!(SkipsBetween(request.skip).verify(&headers, reverse)); + + Ok(headers) } struct StartsAtNumber(u64); @@ -150,3 +152,107 @@ impl Constraint for Max { } } +#[cfg(test)] +mod tests { + use ethcore::header::Header; + use light::request::Headers as HeadersRequest; + + use super::*; + + #[test] + fn sequential_forward() { + let request = HeadersRequest { + start: 10.into(), + max: 30, + skip: 0, + reverse: false, + }; + + let mut parent_hash = None; + let headers: Vec<_> = (0..25).map(|x| x + 10).map(|x| { + let mut header = Header::default(); + header.set_number(x); + + if let Some(parent_hash) = parent_hash { + header.set_parent_hash(parent_hash); + } + + parent_hash = Some(header.hash()); + + ::rlp::encode(&header).to_vec() + }).collect(); + + assert!(decode_and_verify(&headers, &request).is_ok()); + } + + #[test] + fn sequential_backward() { + let request = HeadersRequest { + start: 10.into(), + max: 30, + skip: 0, + reverse: true, + }; + + let mut parent_hash = None; + let headers: Vec<_> = (0..25).map(|x| x + 10).rev().map(|x| { + let mut header = Header::default(); + header.set_number(x); + + if let Some(parent_hash) = parent_hash { + header.set_parent_hash(parent_hash); + } + + parent_hash = Some(header.hash()); + + ::rlp::encode(&header).to_vec() + }).collect(); + + assert!(decode_and_verify(&headers, &request).is_ok()); + } + + #[test] + fn too_many() { + let request = HeadersRequest { + start: 10.into(), + max: 20, + skip: 0, + reverse: false, + }; + + let mut parent_hash = None; + let headers: Vec<_> = (0..25).map(|x| x + 10).map(|x| { + let mut header = Header::default(); + header.set_number(x); + + if let Some(parent_hash) = parent_hash { + header.set_parent_hash(parent_hash); + } + + parent_hash = Some(header.hash()); + + ::rlp::encode(&header).to_vec() + }).collect(); + + assert_eq!(decode_and_verify(&headers, &request), Err(BasicError::TooManyHeaders(20, 25))); + } + + #[test] + fn wrong_skip() { + let request = HeadersRequest { + start: 10.into(), + max: 30, + skip: 5, + reverse: false, + }; + + let headers: Vec<_> = (0..25).map(|x| x * 3).map(|x| x + 10).map(|x| { + let mut header = Header::default(); + header.set_number(x); + + ::rlp::encode(&header).to_vec() + }).collect(); + + assert_eq!(decode_and_verify(&headers, &request), Err(BasicError::WrongSkip(5, Some(2)))); + } +}