// Copyright 2015, 2016 Parity Technologies (UK) Ltd. // This file is part of Parity. // Parity is free software: you can redistribute it and/or modify // it under the terms of the GNU General Public License as published by // the Free Software Foundation, either version 3 of the License, or // (at your option) any later version. // Parity is distributed in the hope that it will be useful, // but WITHOUT ANY WARRANTY; without even the implied warranty of // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the // GNU General Public License for more details. // You should have received a copy of the GNU General Public License // along with Parity. If not, see . use std::cell::RefCell; use std::{fs, str, thread}; use std::sync::Arc; use std::sync::atomic::AtomicBool; use std::path::PathBuf; use std::io::{self, Write}; use std::collections::HashMap; use mio; use tlsclient::{TlsClient, TlsClientError}; use url::Url; #[derive(Debug)] pub enum FetchError { InvalidAddress, ReadingCaCertificates, UnexpectedStatus(String), CaCertificates(io::Error), Io(io::Error), Notify(mio::NotifyError), Client(TlsClientError), } impl From for FetchError { fn from(e: io::Error) -> Self { FetchError::Io(e) } } impl From> for FetchError { fn from(e: mio::NotifyError) -> Self { FetchError::Notify(e) } } impl From for FetchError { fn from(e: TlsClientError) -> Self { FetchError::Client(e) } } pub type FetchResult = Result<(), FetchError>; pub enum ClientMessage { Fetch(Url, Box, Arc, Box), Shutdown, } pub struct Client { channel: mio::Sender, thread: Option>, } impl Drop for Client { fn drop(&mut self) { self.close_internal(); if let Some(thread) = self.thread.take() { thread.join().expect("Clean shutdown."); } } } impl Client { pub fn new() -> Result { Self::with_limit(None) } pub fn with_limit(size_limit: Option) -> Result { let mut event_loop = try!(mio::EventLoop::new()); let channel = event_loop.channel(); let thread = thread::spawn(move || { let mut client = ClientLoop { next_token: 0, sessions: HashMap::new(), size_limit: size_limit, }; event_loop.run(&mut client).unwrap(); }); Ok(Client { channel: channel, thread: Some(thread), }) } pub fn fetch_to_file(&self, url: Url, path: PathBuf, abort: Arc, callback: F) -> Result<(), FetchError> { let file = try!(fs::File::create(&path)); self.fetch(url, Box::new(file), abort, move |result| { if let Err(_) = result { // remove temporary file let _ = fs::remove_file(&path); } callback(result); }) } pub fn fetch(&self, url: Url, writer: Box, abort: Arc, callback: F) -> Result<(), FetchError> { let cell = RefCell::new(Some(callback)); try!(self.channel.send(ClientMessage::Fetch(url, writer, abort, Box::new(move |res| { cell.borrow_mut().take().expect("Called only once.")(res); })))); Ok(()) } pub fn close(mut self) { self.close_internal() } fn close_internal(&mut self) { if let Err(e) = self.channel.send(ClientMessage::Shutdown) { warn!("Error while closing client: {:?}. Already stopped?", e); } } } pub struct ClientLoop { next_token: usize, sessions: HashMap, size_limit: Option, } impl mio::Handler for ClientLoop { type Timeout = (); type Message = ClientMessage; fn ready(&mut self, event_loop: &mut mio::EventLoop, token: mio::Token, events: mio::EventSet) { let utoken = token.as_usize(); let remove = if let Some(mut tlsclient) = self.sessions.get_mut(&utoken) { tlsclient.ready(event_loop, token, events) } else { false }; if remove { self.sessions.remove(&utoken); } } fn notify(&mut self, event_loop: &mut mio::EventLoop, msg: Self::Message) { match msg { ClientMessage::Shutdown => event_loop.shutdown(), ClientMessage::Fetch(url, writer, abort, callback) => { let token = self.next_token; self.next_token += 1; if let Ok(mut tlsclient) = TlsClient::new(mio::Token(token), &url, writer, abort, callback, self.size_limit.clone()) { let httpreq = format!( "GET {} HTTP/1.1\r\nHost: {}\r\nConnection: close\r\nUser-Agent: {}/{}\r\nAccept-Encoding: identity\r\n\r\n", url.path(), url.hostname(), env!("CARGO_PKG_NAME"), env!("CARGO_PKG_VERSION") ); debug!("Requesting content: {}", httpreq); let _ = tlsclient.write(httpreq.as_bytes()); tlsclient.register(event_loop); self.sessions.insert(token, tlsclient); } } } } } #[test] fn should_successfuly_fetch_a_page() { use std::io::{self, Cursor}; use std::sync::{mpsc, Arc}; use std::sync::atomic::{AtomicUsize, AtomicBool, Ordering}; struct Writer { wrote: Arc, data: Cursor>, } impl io::Write for Writer { fn write(&mut self, buf: &[u8]) -> io::Result { let res = self.data.write(buf); if let Ok(count) = res { self.wrote.fetch_add(count, Ordering::Relaxed); } res } fn flush(&mut self) -> io::Result<()> { Ok(()) } } let client = Client::new().unwrap(); let wrote = Arc::new(AtomicUsize::new(0)); let writer = Writer { wrote: wrote.clone(), data: Cursor::new(Vec::new()), }; let (tx, rx) = mpsc::channel(); client.fetch(Url::new("github.com", 443, "/").unwrap(), Box::new(writer), Arc::new(AtomicBool::new(false)), move |result| { assert!(result.is_ok()); assert!(wrote.load(Ordering::Relaxed) > 0); tx.send(result).unwrap(); }).unwrap(); let _ = rx.recv().unwrap(); }