From ad854c3ee8b01c66d9bd61985ebbaf866e601415 Mon Sep 17 00:00:00 2001 From: Sean McArthur Date: Fri, 4 Jan 2019 13:48:11 -0800 Subject: [PATCH] add Proxy::basic_auth support Closes #322 --- src/async_impl/client.rs | 64 ++++++++++++++++++++++++-- src/connect.rs | 93 ++++++++++++++++++++++++++++++------- src/proxy.rs | 99 +++++++++++++++++++++++++++++++++++++++- src/redirect.rs | 16 +++++-- tests/proxy.rs | 42 ++++++++++++++++- 5 files changed, 289 insertions(+), 25 deletions(-) diff --git a/src/async_impl/client.rs b/src/async_impl/client.rs index ca532bd..0ee60d4 100644 --- a/src/async_impl/client.rs +++ b/src/async_impl/client.rs @@ -4,10 +4,24 @@ use std::time::Duration; use bytes::Bytes; use futures::{Async, Future, Poll}; +use header::{ + HeaderMap, + HeaderValue, + ACCEPT, + ACCEPT_ENCODING, + CONTENT_LENGTH, + CONTENT_ENCODING, + CONTENT_TYPE, + LOCATION, + PROXY_AUTHORIZATION, + RANGE, + REFERER, + TRANSFER_ENCODING, + USER_AGENT, +}; +use http::Uri; use hyper::client::ResponseFuture; -use header::{HeaderMap, HeaderValue, LOCATION, USER_AGENT, REFERER, ACCEPT, - ACCEPT_ENCODING, RANGE, TRANSFER_ENCODING, CONTENT_TYPE, CONTENT_LENGTH, CONTENT_ENCODING}; -use mime::{self}; +use mime; #[cfg(feature = "default-tls")] use native_tls::TlsConnector; @@ -197,6 +211,10 @@ impl ClientBuilder { let hyper_client = ::hyper::Client::builder() .build(connector); + let proxies_maybe_http_auth = proxies + .iter() + .any(|p| p.maybe_has_http_auth()); + Ok(Client { inner: Arc::new(ClientRef { gzip: config.gzip, @@ -204,6 +222,8 @@ impl ClientBuilder { headers: config.headers, redirect_policy: config.redirect_policy, referer: config.referer, + proxies, + proxies_maybe_http_auth, }), }) } @@ -470,6 +490,8 @@ impl Client { } }; + self.proxy_auth(&uri, &mut headers); + let mut req = ::hyper::Request::builder() .method(method.clone()) .uri(uri.clone()) @@ -495,6 +517,40 @@ impl Client { }), } } + + fn proxy_auth(&self, dst: &Uri, headers: &mut HeaderMap) { + if !self.inner.proxies_maybe_http_auth { + return; + } + + // Only set the header here if the destination scheme is 'http', + // since otherwise, the header will be included in the CONNECT tunnel + // request instead. + if dst.scheme_part() != Some(&::http::uri::Scheme::HTTP) { + return; + } + + if headers.contains_key(PROXY_AUTHORIZATION) { + return; + } + + + for proxy in self.inner.proxies.iter() { + if proxy.is_match(dst) { + match proxy.auth() { + Some(::proxy::Auth::Basic(ref header)) => { + headers.insert( + PROXY_AUTHORIZATION, + header.clone() + ); + }, + None => (), + } + + break; + } + } + } } impl fmt::Debug for Client { @@ -520,6 +576,8 @@ struct ClientRef { hyper: HyperClient, redirect_policy: RedirectPolicy, referer: bool, + proxies: Arc>, + proxies_maybe_http_auth: bool, } pub struct Pending { diff --git a/src/connect.rs b/src/connect.rs index 2c95d0b..441e678 100644 --- a/src/connect.rs +++ b/src/connect.rs @@ -114,6 +114,9 @@ impl Connect for Connector { ndst.set_port(puri.port_part().map(|port| port.as_u16())); + #[cfg(feature = "tls")] + let auth = prox.auth().cloned(); + match &self.inner { #[cfg(feature = "default-tls")] Inner::DefaultTls(http, tls) => if dst.scheme() == "https" { @@ -125,7 +128,7 @@ impl Connect for Connector { let tls = tls.clone(); return Box::new(http.connect(ndst).and_then(move |(conn, connected)| { trace!("tunneling HTTPS over proxy"); - tunnel(conn, host.clone(), port) + tunnel(conn, host.clone(), port, auth) .and_then(move |tunneled| { tls.connect_async(&host, tunneled) .map_err(|e| io::Error::new(io::ErrorKind::Other, e)) @@ -148,7 +151,7 @@ impl Connect for Connector { let maybe_dnsname = DNSNameRef::try_from_ascii_str(&host) .map(|dnsname| dnsname.to_owned()) .map_err(|_| io::Error::new(io::ErrorKind::Other, "Invalid DNS Name")); - tunnel(conn, host, port) + tunnel(conn, host, port, auth) .and_then(move |tunneled| Ok((maybe_dnsname?, tunneled))) .and_then(move |(dnsname, tunneled)| { RustlsConnector::from(tls).connect(dnsname.as_ref(), tunneled) @@ -176,18 +179,30 @@ pub(crate) type Conn = Box; pub(crate) type Connecting = Box + Send>; #[cfg(feature = "tls")] -fn tunnel(conn: T, host: String, port: u16) -> Tunnel { - let buf = format!("\ +fn tunnel(conn: T, host: String, port: u16, auth: Option<::proxy::Auth>) -> Tunnel { + let mut buf = format!("\ CONNECT {0}:{1} HTTP/1.1\r\n\ Host: {0}:{1}\r\n\ - \r\n\ ", host, port).into_bytes(); - Tunnel { + match auth { + Some(::proxy::Auth::Basic(value)) => { + debug!("tunnel to {}:{} using basic auth", host, port); + buf.extend_from_slice(b"Proxy-Authorization: "); + buf.extend_from_slice(value.as_bytes()); + buf.extend_from_slice(b"\r\n"); + }, + None => (), + } + + // headers end + buf.extend_from_slice(b"\r\n"); + + Tunnel { buf: io::Cursor::new(buf), conn: Some(conn), state: TunnelState::Writing, - } + } } #[cfg(feature = "tls")] @@ -230,6 +245,8 @@ where T: AsyncRead + AsyncWrite { return Ok(self.conn.take().unwrap().into()); } // else read more + } else if read.starts_with(b"HTTP/1.1 407") { + return Err(io::Error::new(io::ErrorKind::Other, "proxy authentication required")); } else { return Err(io::Error::new(io::ErrorKind::Other, "unsuccessful tunnel")); } @@ -258,23 +275,29 @@ mod tests { use tokio::runtime::current_thread::Runtime; use tokio::net::TcpStream; use super::tunnel; + use proxy; + static TUNNEL_OK: &'static [u8] = b"\ + HTTP/1.1 200 OK\r\n\ + \r\n\ + "; macro_rules! mock_tunnel { () => ({ - mock_tunnel!(b"\ - HTTP/1.1 200 OK\r\n\ - \r\n\ - ") + mock_tunnel!(TUNNEL_OK) }); ($write:expr) => ({ + mock_tunnel!($write, "") + }); + ($write:expr, $auth:expr) => ({ let listener = TcpListener::bind("127.0.0.1:0").unwrap(); let addr = listener.local_addr().unwrap(); let connect_expected = format!("\ CONNECT {0}:{1} HTTP/1.1\r\n\ Host: {0}:{1}\r\n\ + {2}\ \r\n\ - ", addr.ip(), addr.port()).into_bytes(); + ", addr.ip(), addr.port(), $auth).into_bytes(); thread::spawn(move || { let (mut sock, _) = listener.accept().unwrap(); @@ -297,7 +320,7 @@ mod tests { let host = addr.ip().to_string(); let port = addr.port(); let work = work.and_then(|tcp| { - tunnel(tcp, host, port) + tunnel(tcp, host, port, None) }); rt.block_on(work).unwrap(); @@ -312,14 +335,14 @@ mod tests { let host = addr.ip().to_string(); let port = addr.port(); let work = work.and_then(|tcp| { - tunnel(tcp, host, port) + tunnel(tcp, host, port, None) }); rt.block_on(work).unwrap_err(); } #[test] - fn test_tunnel_bad_response() { + fn test_tunnel_non_http_response() { let addr = mock_tunnel!(b"foo bar baz hallo"); let mut rt = Runtime::new().unwrap(); @@ -327,9 +350,47 @@ mod tests { let host = addr.ip().to_string(); let port = addr.port(); let work = work.and_then(|tcp| { - tunnel(tcp, host, port) + tunnel(tcp, host, port, None) }); rt.block_on(work).unwrap_err(); } + + #[test] + fn test_tunnel_proxy_unauthorized() { + let addr = mock_tunnel!(b"\ + HTTP/1.1 407 Proxy Authentication Required\r\n\ + Proxy-Authenticate: Basic realm=\"nope\"\r\n\ + \r\n\ + "); + + let mut rt = Runtime::new().unwrap(); + let work = TcpStream::connect(&addr); + let host = addr.ip().to_string(); + let port = addr.port(); + let work = work.and_then(|tcp| { + tunnel(tcp, host, port, None) + }); + + let error = rt.block_on(work).unwrap_err(); + assert_eq!(error.to_string(), "proxy authentication required"); + } + + #[test] + fn test_tunnel_basic_auth() { + let addr = mock_tunnel!( + TUNNEL_OK, + "Proxy-Authorization: Basic QWxhZGRpbjpvcGVuIHNlc2FtZQ==\r\n" + ); + + let mut rt = Runtime::new().unwrap(); + let work = TcpStream::connect(&addr); + let host = addr.ip().to_string(); + let port = addr.port(); + let work = work.and_then(|tcp| { + tunnel(tcp, host, port, Some(proxy::Auth::basic("Aladdin", "open sesame"))) + }); + + rt.block_on(work).unwrap(); + } } diff --git a/src/proxy.rs b/src/proxy.rs index 3d62722..1845d90 100644 --- a/src/proxy.rs +++ b/src/proxy.rs @@ -1,6 +1,7 @@ use std::fmt; use std::sync::Arc; +use http::{header::HeaderValue, Uri}; use hyper::client::connect::Destination; use {into_url, IntoUrl, Url}; @@ -30,9 +31,15 @@ use {into_url, IntoUrl, Url}; /// would prevent a `Proxy` later in the list from ever working, so take care. #[derive(Clone, Debug)] pub struct Proxy { + auth: Option, intercept: Intercept, } +#[derive(Clone, Debug)] +pub(crate) enum Auth { + Basic(HeaderValue), +} + impl Proxy { /// Proxy all HTTP traffic to the passed URL. /// @@ -124,7 +131,43 @@ impl Proxy { fn new(intercept: Intercept) -> Proxy { Proxy { - intercept: intercept, + auth: None, + intercept, + } + } + + /// Set the `Proxy-Authorization` header using Basic auth. + /// + /// # Example + /// + /// ``` + /// # extern crate reqwest; + /// # fn run() -> Result<(), Box<::std::error::Error>> { + /// let proxy = reqwest::Proxy::https("http://localhost:1234")? + /// .basic_auth("Aladdin", "open sesame"); + /// # Ok(()) + /// # } + /// # fn main() {} + /// ``` + pub fn basic_auth(mut self, username: &str, password: &str) -> Proxy { + self.auth = Some(Auth::basic(username, password)); + self + } + + pub(crate) fn auth(&self) -> Option<&Auth> { + self.auth.as_ref() + } + + pub(crate) fn maybe_has_http_auth(&self) -> bool { + match self.auth { + Some(Auth::Basic(_)) => match self.intercept { + Intercept::All(_) | + Intercept::Http(_) | + // Custom *may* match 'http', so assume so. + Intercept::Custom(_) => true, + Intercept::Https(_) => false, + }, + None => false, } } @@ -161,6 +204,31 @@ impl Proxy { }, } } + + pub(crate) fn is_match(&self, uri: &D) -> bool { + match self.intercept { + Intercept::All(_) => true, + Intercept::Http(_) => { + uri.scheme() == "http" + }, + Intercept::Https(_) => { + uri.scheme() == "https" + }, + Intercept::Custom(ref fun) => { + (fun.0)( + &format!( + "{}://{}{}{}", + uri.scheme(), + uri.host(), + uri.port().map(|_| ":").unwrap_or(""), + uri.port().map(|p| p.to_string()).unwrap_or(String::new()) + ) + .parse() + .expect("should be valid Url") + ).is_some() + }, + } + } } #[derive(Clone, Debug)] @@ -203,6 +271,35 @@ impl Dst for Destination { } } +#[doc(hidden)] +impl Dst for Uri { + fn scheme(&self) -> &str { + self.scheme_part() + .expect("Uri should have a scheme") + .as_str() + } + + fn host(&self) -> &str { + Uri::host(self) + .expect("::host should have a str") + } + + fn port(&self) -> Option { + self.port_part().map(|p| p.as_u16()) + } +} + +impl Auth { + pub(crate) fn basic(username: &str, password: &str) -> Auth { + let val = format!("{}:{}", username, password); + let mut header = format!("Basic {}", base64::encode(&val)) + .parse::() + .expect("base64 is always valid HeaderValue"); + header.set_sensitive(true); + Auth::Basic(header) + } +} + #[cfg(test)] mod tests { use super::*; diff --git a/src/redirect.rs b/src/redirect.rs index d3352fa..1c03740 100644 --- a/src/redirect.rs +++ b/src/redirect.rs @@ -1,6 +1,13 @@ use std::fmt; -use header::HeaderMap; +use header::{ + HeaderMap, + AUTHORIZATION, + COOKIE, + PROXY_AUTHORIZATION, + WWW_AUTHENTICATE, + +}; use hyper::StatusCode; use Url; @@ -233,10 +240,11 @@ pub(crate) fn remove_sensitive_headers(headers: &mut HeaderMap, next: &Url, prev let cross_host = next.host_str() != previous.host_str() || next.port_or_known_default() != previous.port_or_known_default(); if cross_host { - headers.remove("authorization"); - headers.remove("cookie"); + headers.remove(AUTHORIZATION); + headers.remove(COOKIE); headers.remove("cookie2"); - headers.remove("www-authenticate"); + headers.remove(PROXY_AUTHORIZATION); + headers.remove(WWW_AUTHENTICATE); } } } diff --git a/tests/proxy.rs b/tests/proxy.rs index c805749..ae5a6d4 100644 --- a/tests/proxy.rs +++ b/tests/proxy.rs @@ -4,7 +4,7 @@ extern crate reqwest; mod support; #[test] -fn test_http_proxy() { +fn http_proxy() { let server = server! { request: b"\ GET http://hyper.rs/prox HTTP/1.1\r\n\ @@ -37,3 +37,43 @@ fn test_http_proxy() { assert_eq!(res.status(), reqwest::StatusCode::OK); assert_eq!(res.headers().get(reqwest::header::SERVER).unwrap(), &"proxied"); } + +#[test] +fn http_proxy_basic_auth() { + let server = server! { + request: b"\ + GET http://hyper.rs/prox HTTP/1.1\r\n\ + user-agent: $USERAGENT\r\n\ + accept: */*\r\n\ + accept-encoding: gzip\r\n\ + proxy-authorization: Basic QWxhZGRpbjpvcGVuIHNlc2FtZQ==\r\n\ + host: hyper.rs\r\n\ + \r\n\ + ", + response: b"\ + HTTP/1.1 200 OK\r\n\ + Server: proxied\r\n\ + Content-Length: 0\r\n\ + \r\n\ + " + }; + + let proxy = format!("http://{}", server.addr()); + + let url = "http://hyper.rs/prox"; + let res = reqwest::Client::builder() + .proxy( + reqwest::Proxy::http(&proxy) + .unwrap() + .basic_auth("Aladdin", "open sesame") + ) + .build() + .unwrap() + .get(url) + .send() + .unwrap(); + + assert_eq!(res.url().as_str(), url); + assert_eq!(res.status(), reqwest::StatusCode::OK); + assert_eq!(res.headers().get(reqwest::header::SERVER).unwrap(), &"proxied"); +}