diff --git a/Cargo.toml b/Cargo.toml index a69d0e8..89dc44f 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -13,7 +13,7 @@ categories = ["web-programming::http-client"] bytes = "0.4" futures = "0.1.14" hyper = "0.11" -hyper-tls = "0.1" +hyper-tls = "0.1.1" libflate = "0.1.5" log = "0.3" native-tls = "0.1" @@ -21,6 +21,7 @@ serde = "1.0" serde_json = "1.0" serde_urlencoded = "0.5" tokio-core = "0.1.6" +tokio-io = "0.1" url = "1.2" [dev-dependencies] diff --git a/src/async_impl/client.rs b/src/async_impl/client.rs index 05e4315..5409c65 100644 --- a/src/async_impl/client.rs +++ b/src/async_impl/client.rs @@ -14,8 +14,10 @@ use tokio_core::reactor::Handle; use super::body; use super::request::{self, Request, RequestBuilder}; use super::response::{self, Response}; +use connect::Connector; +use into_url::to_uri; use redirect::{self, RedirectPolicy, check_redirect, remove_sensitive_headers}; -use {Certificate, IntoUrl, Method, StatusCode, Url}; +use {Certificate, IntoUrl, Method, proxy, Proxy, StatusCode, Url}; static DEFAULT_USER_AGENT: &'static str = concat!(env!("CARGO_PKG_NAME"), "/", env!("CARGO_PKG_VERSION")); @@ -79,6 +81,7 @@ pub struct ClientBuilder { struct Config { gzip: bool, hostname_verification: bool, + proxies: Vec, redirect_policy: RedirectPolicy, referer: bool, timeout: Option, @@ -97,6 +100,7 @@ impl ClientBuilder { config: Some(Config { gzip: true, hostname_verification: true, + proxies: Vec::new(), redirect_policy: RedirectPolicy::default(), referer: true, timeout: None, @@ -127,7 +131,9 @@ impl ClientBuilder { } */ - let hyper_client = create_hyper_client(tls, handle); + let proxies = Arc::new(config.proxies); + + let hyper_client = create_hyper_client(tls, proxies.clone(), handle); //let mut hyper_client = create_hyper_client(tls_client); //hyper_client.set_read_timeout(config.timeout); @@ -137,6 +143,7 @@ impl ClientBuilder { inner: Arc::new(ClientRef { gzip: config.gzip, hyper: hyper_client, + proxies: proxies, redirect_policy: config.redirect_policy, referer: config.referer, }), @@ -187,6 +194,13 @@ impl ClientBuilder { self } + /// Add a `Proxy` to the list of proxies the `Client` will use. + #[inline] + pub fn proxy(&mut self, proxy: Proxy) -> &mut ClientBuilder { + self.config_mut().proxies.push(proxy); + self + } + /// Set a `RedirectPolicy` for this client. /// /// Default will follow redirects up to a maximum of 10. @@ -226,14 +240,11 @@ impl ClientBuilder { } } -type HyperClient = ::hyper::Client<::hyper_tls::HttpsConnector<::hyper::client::HttpConnector>>; +type HyperClient = ::hyper::Client; -fn create_hyper_client(tls: TlsConnector, handle: &Handle) -> HyperClient { - let mut http = ::hyper::client::HttpConnector::new(4, handle); - http.enforce_http(false); - let https = ::hyper_tls::HttpsConnector::from((http, tls)); +fn create_hyper_client(tls: TlsConnector, proxies: Arc>, handle: &Handle) -> HyperClient { ::hyper::Client::configure() - .connector(https) + .connector(Connector::new(tls, proxies, handle)) .build(handle) } @@ -363,7 +374,8 @@ impl Client { headers.set(AcceptEncoding(vec![qitem(Encoding::Gzip)])); } - let mut req = ::hyper::Request::new(method.clone(), url_to_uri(&url)); + let uri = to_uri(&url); + let mut req = ::hyper::Request::new(method.clone(), uri.clone()); *req.headers_mut() = headers.clone(); let body = body.and_then(|body| { let (resuable, body) = body::into_hyper(body); @@ -371,6 +383,10 @@ impl Client { resuable }); + if proxy::is_proxied(&self.inner.proxies, &uri) { + req.set_proxy(true); + } + let in_flight = self.inner.hyper.request(req); Pending { @@ -408,6 +424,7 @@ impl fmt::Debug for ClientBuilder { struct ClientRef { gzip: bool, hyper: HyperClient, + proxies: Arc>, redirect_policy: RedirectPolicy, referer: bool, } @@ -473,14 +490,18 @@ impl Future for Pending { remove_sensitive_headers(&mut self.headers, &self.url, &self.urls); debug!("redirecting to {:?} '{}'", self.method, self.url); + let uri = to_uri(&self.url); let mut req = ::hyper::Request::new( self.method.clone(), - url_to_uri(&self.url) + uri.clone() ); *req.headers_mut() = self.headers.clone(); if let Some(ref body) = self.body { req.set_body(body.clone()); } + if proxy::is_proxied(&self.client.proxies, &uri) { + req.set_proxy(true); + } self.in_flight = self.client.hyper.request(req); continue; }, @@ -525,10 +546,6 @@ fn make_referer(next: &Url, previous: &Url) -> Option { Some(Referer::new(referer.into_string())) } -fn url_to_uri(url: &Url) -> ::hyper::Uri { - url.as_str().parse().expect("a parsed Url should always be a valid Uri") -} - // pub(crate) pub fn take_builder(builder: &mut ClientBuilder) -> ClientBuilder { diff --git a/src/client.rs b/src/client.rs index 9b9fd12..a3f57a7 100644 --- a/src/client.rs +++ b/src/client.rs @@ -7,7 +7,7 @@ use futures::sync::{mpsc, oneshot}; use request::{self, Request, RequestBuilder}; use response::{self, Response}; -use {async_impl, Certificate, Method, IntoUrl, RedirectPolicy, wait}; +use {async_impl, Certificate, Method, IntoUrl, Proxy, RedirectPolicy, wait}; /// A `Client` to make Requests with. /// @@ -141,6 +141,13 @@ impl ClientBuilder { self } + /// Add a `Proxy` to the list of proxies the `Client` will use. + #[inline] + pub fn proxy(&mut self, proxy: Proxy) -> &mut ClientBuilder { + self.inner.proxy(proxy); + self + } + /// Set a `RedirectPolicy` for this client. /// /// Default will follow redirects up to a maximum of 10. diff --git a/src/connect.rs b/src/connect.rs new file mode 100644 index 0000000..cf815c6 --- /dev/null +++ b/src/connect.rs @@ -0,0 +1,200 @@ +use bytes::{BufMut, IntoBuf}; +use futures::{Async, Future, Poll}; +use hyper::client::{HttpConnector, Service}; +use hyper::Uri; +use hyper_tls::{/*HttpsConnecting,*/ HttpsConnector, MaybeHttpsStream}; +use native_tls::TlsConnector; +use tokio_core::reactor::Handle; +use tokio_io::{AsyncRead, AsyncWrite}; + +use std::io::{self, Cursor}; +use std::sync::Arc; + +use {proxy, Proxy}; + +// pub(crate) + +pub struct Connector { + https: HttpsConnector, + proxies: Arc>, +} + +impl Connector { + pub fn new(tls: TlsConnector, proxies: Arc>, handle: &Handle) -> Connector { + let mut http = HttpConnector::new(4, handle); + http.enforce_http(false); + let https = HttpsConnector::from((http, tls)); + + Connector { + https: https, + proxies: proxies, + } + } +} + +impl Service for Connector { + type Request = Uri; + type Response = Conn; + type Error = io::Error; + type Future = Connecting; + + fn call(&self, uri: Uri) -> Self::Future { + for prox in self.proxies.iter() { + if let Some(puri) = proxy::proxies(prox, &uri) { + if uri.scheme() == Some("https") { + let host = uri.authority().unwrap().to_owned(); + return Box::new(self.https.call(puri).and_then(|conn| { + tunnel(conn, host) + })); + } + return Box::new(self.https.call(puri)); + } + } + Box::new(self.https.call(uri)) + } +} + +pub type Conn = MaybeHttpsStream<::Response>; +pub type Connecting = Box>; + +fn tunnel(conn: T, host: String) -> Tunnel { + let buf = format!("\ + CONNECT {0} HTTP/1.1\r\n\ + Host: {0}\r\n\ + \r\n\ + ", host).into_bytes(); + + Tunnel { + buf: buf.into_buf(), + conn: Some(conn), + state: TunnelState::Writing, + } +} + +struct Tunnel { + buf: Cursor>, + conn: Option, + state: TunnelState, +} + +enum TunnelState { + Writing, + Reading +} + +impl Future for Tunnel +where T: AsyncRead + AsyncWrite { + type Item = T; + type Error = io::Error; + + fn poll(&mut self) -> Poll { + loop { + if let TunnelState::Writing = self.state { + let n = try_ready!(self.conn.as_mut().unwrap().write_buf(&mut self.buf)); + if !self.buf.has_remaining_mut() { + self.state = TunnelState::Reading; + self.buf.get_mut().truncate(0); + } else if n == 0 { + return Err(io::Error::new(io::ErrorKind::UnexpectedEof, "unexpected eof while tunneling")); + } + } else { + let n = try_ready!(self.conn.as_mut().unwrap().read_buf(&mut self.buf.get_mut())); + let read = &self.buf.get_ref()[..]; + if n == 0 { + return Err(io::Error::new(io::ErrorKind::UnexpectedEof, "unexpected eof while tunneling")); + } else if read.len() > 12 { + if read.starts_with(b"HTTP/1.1 200") { + if read.ends_with(b"\r\n\r\n") { + return Ok(Async::Ready(self.conn.take().unwrap())); + } + // else read more + } else { + return Err(io::Error::new(io::ErrorKind::Other, "unsuccessful tunnel")); + } + } + } + } + } +} + +#[cfg(test)] +mod tests { + use std::io::{Read, Write}; + use std::net::TcpListener; + use std::thread; + use futures::Future; + use tokio_core::reactor::Core; + use tokio_core::net::TcpStream; + use super::tunnel; + + + macro_rules! mock_tunnel { + () => ({ + mock_tunnel!(b"\ + HTTP/1.1 200 OK\r\n\ + \r\n\ + ") + }); + ($write:expr) => ({ + let listener = TcpListener::bind("127.0.0.1:0").unwrap(); + let addr = listener.local_addr().unwrap(); + let connect_expected = format!("\ + CONNECT {0} HTTP/1.1\r\n\ + Host: {0}\r\n\ + \r\n\ + ", addr).into_bytes(); + + thread::spawn(move || { + let (mut sock, _) = listener.accept().unwrap(); + let mut buf = [0u8; 4096]; + let n = sock.read(&mut buf).unwrap(); + assert_eq!(&buf[..n], &connect_expected[..]); + + sock.write_all($write).unwrap(); + }); + addr + }) + } + + #[test] + fn test_tunnel() { + let addr = mock_tunnel!(); + + let mut core = Core::new().unwrap(); + let work = TcpStream::connect(&addr, &core.handle()); + let host = addr.to_string(); + let work = work.and_then(|tcp| { + tunnel(tcp, host) + }); + + core.run(work).unwrap(); + } + + #[test] + fn test_tunnel_eof() { + let addr = mock_tunnel!(b"HTTP/1.1 200 OK"); + + let mut core = Core::new().unwrap(); + let work = TcpStream::connect(&addr, &core.handle()); + let host = addr.to_string(); + let work = work.and_then(|tcp| { + tunnel(tcp, host) + }); + + core.run(work).unwrap_err(); + } + + #[test] + fn test_tunnel_bad_response() { + let addr = mock_tunnel!(b"foo bar baz hallo"); + + let mut core = Core::new().unwrap(); + let work = TcpStream::connect(&addr, &core.handle()); + let host = addr.to_string(); + let work = work.and_then(|tcp| { + tunnel(tcp, host) + }); + + core.run(work).unwrap_err(); + } +} diff --git a/src/into_url.rs b/src/into_url.rs index 5c50395..80de1bd 100644 --- a/src/into_url.rs +++ b/src/into_url.rs @@ -32,3 +32,7 @@ impl<'a> PolyfillTryInto for &'a String { Url::parse(self) } } + +pub fn to_uri(url: &Url) -> ::hyper::Uri { + url.as_str().parse().expect("a parsed Url should always be a valid Uri") +} diff --git a/src/lib.rs b/src/lib.rs index 5d2955d..1e2525c 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -131,6 +131,7 @@ extern crate serde; extern crate serde_json; extern crate serde_urlencoded; extern crate tokio_core; +extern crate tokio_io; extern crate url; pub use hyper::header; @@ -144,6 +145,7 @@ pub use self::client::{Client, ClientBuilder}; pub use self::error::{Error, Result}; pub use self::body::Body; pub use self::into_url::IntoUrl; +pub use self::proxy::Proxy; pub use self::redirect::{RedirectAction, RedirectAttempt, RedirectPolicy}; pub use self::request::{Request, RequestBuilder}; pub use self::response::Response; @@ -179,9 +181,11 @@ pub mod unstable { mod async_impl; +mod connect; mod body; mod client; mod into_url; +mod proxy; mod redirect; mod request; mod response; diff --git a/src/proxy.rs b/src/proxy.rs new file mode 100644 index 0000000..229dd10 --- /dev/null +++ b/src/proxy.rs @@ -0,0 +1,190 @@ +use hyper::Uri; +use {IntoUrl}; + +/// Configuration of a proxy that a `Client` should pass requests to. +/// +/// A `Proxy` has a couple pieces to it: +/// +/// - a URL of how to talk to the proxy +/// - rules on what `Client` requests should be directed to the proxy +/// +/// For instance, let's look at `Proxy::http`: +/// +/// ``` +/// # extern crate reqwest; +/// # fn run() -> Result<(), Box<::std::error::Error>> { +/// let proxy = reqwest::Proxy::http("https://secure.example")?; +/// # Ok(()) +/// # } +/// # fn main() {} +/// ``` +/// +/// This proxy will intercept all HTTP requests, and make use of the proxy +/// at `https://secure.example`. A request to `http://hyper.rs` will talk +/// to your proxy. A request to `https://hyper.rs` will not. +/// +/// Multiple `Proxy` rules can be configured for a `Client`. The `Client` will +/// check each `Proxy` in the order it was added. This could mean that a +/// `Proxy` added first with eager intercept rules, such as `Proxy::all`, +/// would prevent a `Proxy` later in the list from ever working, so take care. +#[derive(Clone, Debug)] +pub struct Proxy { + intercept: Intercept, + uri: Uri, +} + +impl Proxy { + /// Proxy all HTTP traffic to the passed URL. + /// + /// # Example + /// + /// ``` + /// # extern crate reqwest; + /// # fn run() -> Result<(), Box<::std::error::Error>> { + /// let client = reqwest::Client::builder()? + /// .proxy(reqwest::Proxy::http("https://my.prox")?) + /// .build()?; + /// # Ok(()) + /// # } + /// # fn main() {} + /// ``` + pub fn http(url: U) -> ::Result { + Proxy::new(Intercept::Http, url) + } + + /// Proxy all HTTPS traffic to the passed URL. + /// + /// # Example + /// + /// ``` + /// # extern crate reqwest; + /// # fn run() -> Result<(), Box<::std::error::Error>> { + /// let client = reqwest::Client::builder()? + /// .proxy(reqwest::Proxy::https("https://example.prox:4545")?) + /// .build()?; + /// # Ok(()) + /// # } + /// # fn main() {} + /// ``` + pub fn https(url: U) -> ::Result { + Proxy::new(Intercept::Https, url) + } + + /// Proxy **all** traffic to the passed URL. + /// + /// # Example + /// + /// ``` + /// # extern crate reqwest; + /// # fn run() -> Result<(), Box<::std::error::Error>> { + /// let client = reqwest::Client::builder()? + /// .proxy(reqwest::Proxy::all("http://pro.xy")?) + /// .build()?; + /// # Ok(()) + /// # } + /// # fn main() {} + /// ``` + pub fn all(url: U) -> ::Result { + Proxy::new(Intercept::All, url) + } + + /* + pub fn unix(path: P) -> Proxy { + + } + */ + + fn new(intercept: Intercept, url: U) -> ::Result { + let uri = ::into_url::to_uri(&try_!(url.into_url())); + Ok(Proxy { + intercept: intercept, + uri: uri, + }) + } + + fn proxies(&self, uri: &Uri) -> bool { + match self.intercept { + Intercept::All => true, + Intercept::Http => uri.scheme() == Some("http"), + Intercept::Https => uri.scheme() == Some("https"), + } + } +} + +#[derive(Clone, Debug)] +enum Intercept { + All, + Http, + Https, +} + +// pub(crate) + +pub fn proxies(proxy: &Proxy, uri: &Uri) -> Option { + if proxy.proxies(uri) { + Some(proxy.uri.clone()) + } else { + None + } +} + +pub fn is_proxied(proxies: &[Proxy], uri: &Uri) -> bool { + proxies.iter().any(|p| p.proxies(uri)) +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_http() { + let p = Proxy::http("http://example.domain").unwrap(); + + let http = "http://hyper.rs".parse().unwrap(); + let other = "https://hyper.rs".parse().unwrap(); + + assert!(p.proxies(&http)); + assert!(!p.proxies(&other)); + } + + #[test] + fn test_https() { + let p = Proxy::https("http://example.domain").unwrap(); + + let http = "http://hyper.rs".parse().unwrap(); + let other = "https://hyper.rs".parse().unwrap(); + + assert!(!p.proxies(&http)); + assert!(p.proxies(&other)); + } + + #[test] + fn test_all() { + let p = Proxy::all("http://example.domain").unwrap(); + + let http = "http://hyper.rs".parse().unwrap(); + let https = "https://hyper.rs".parse().unwrap(); + let other = "x-youve-never-heard-of-me-mr-proxy://hyper.rs".parse().unwrap(); + + assert!(p.proxies(&http)); + assert!(p.proxies(&https)); + assert!(p.proxies(&other)); + } + + #[test] + fn test_is_proxied() { + let proxies = vec![ + Proxy::http("http://example.domain").unwrap(), + Proxy::https("http://other.domain").unwrap(), + ]; + + let http = "http://hyper.rs".parse().unwrap(); + let https = "https://hyper.rs".parse().unwrap(); + let other = "x-other://hyper.rs".parse().unwrap(); + + assert!(is_proxied(&proxies, &http)); + assert!(is_proxied(&proxies, &https)); + assert!(!is_proxied(&proxies, &other)); + } + +} diff --git a/tests/proxy.rs b/tests/proxy.rs new file mode 100644 index 0000000..c75acfa --- /dev/null +++ b/tests/proxy.rs @@ -0,0 +1,42 @@ +extern crate reqwest; + +#[macro_use] +mod support; + +#[test] +fn test_http_proxy() { + let server = server! { + request: b"\ + GET http://hyper.rs/prox HTTP/1.1\r\n\ + Host: hyper.rs\r\n\ + User-Agent: $USERAGENT\r\n\ + Accept: */*\r\n\ + Accept-Encoding: gzip\r\n\ + \r\n\ + ", + response: b"\ + HTTP/1.1 200 OK\r\n\ + Server: proxied\r\n\ + Content-Length: 0\r\n\ + \r\n\ + " + }; + + let proxy = format!("http://{}", server.addr()); + + let url = "http://hyper.rs/prox"; + let res = reqwest::Client::builder() + .unwrap() + .proxy(reqwest::Proxy::http(&proxy).unwrap()) + .build() + .unwrap() + .get(url) + .unwrap() + .send() + .unwrap(); + + assert_eq!(res.url().as_str(), url); + assert_eq!(res.status(), reqwest::StatusCode::Ok); + assert_eq!(res.headers().get(), + Some(&reqwest::header::Server::new("proxied"))); +}