add Proxy::basic_auth support

Closes #322
This commit is contained in:
Sean McArthur
2019-01-04 13:48:11 -08:00
parent b9f4661332
commit ad854c3ee8
5 changed files with 289 additions and 25 deletions

View File

@@ -4,10 +4,24 @@ use std::time::Duration;
use bytes::Bytes;
use futures::{Async, Future, Poll};
use header::{
HeaderMap,
HeaderValue,
ACCEPT,
ACCEPT_ENCODING,
CONTENT_LENGTH,
CONTENT_ENCODING,
CONTENT_TYPE,
LOCATION,
PROXY_AUTHORIZATION,
RANGE,
REFERER,
TRANSFER_ENCODING,
USER_AGENT,
};
use http::Uri;
use hyper::client::ResponseFuture;
use header::{HeaderMap, HeaderValue, LOCATION, USER_AGENT, REFERER, ACCEPT,
ACCEPT_ENCODING, RANGE, TRANSFER_ENCODING, CONTENT_TYPE, CONTENT_LENGTH, CONTENT_ENCODING};
use mime::{self};
use mime;
#[cfg(feature = "default-tls")]
use native_tls::TlsConnector;
@@ -197,6 +211,10 @@ impl ClientBuilder {
let hyper_client = ::hyper::Client::builder()
.build(connector);
let proxies_maybe_http_auth = proxies
.iter()
.any(|p| p.maybe_has_http_auth());
Ok(Client {
inner: Arc::new(ClientRef {
gzip: config.gzip,
@@ -204,6 +222,8 @@ impl ClientBuilder {
headers: config.headers,
redirect_policy: config.redirect_policy,
referer: config.referer,
proxies,
proxies_maybe_http_auth,
}),
})
}
@@ -470,6 +490,8 @@ impl Client {
}
};
self.proxy_auth(&uri, &mut headers);
let mut req = ::hyper::Request::builder()
.method(method.clone())
.uri(uri.clone())
@@ -495,6 +517,40 @@ impl Client {
}),
}
}
fn proxy_auth(&self, dst: &Uri, headers: &mut HeaderMap) {
if !self.inner.proxies_maybe_http_auth {
return;
}
// Only set the header here if the destination scheme is 'http',
// since otherwise, the header will be included in the CONNECT tunnel
// request instead.
if dst.scheme_part() != Some(&::http::uri::Scheme::HTTP) {
return;
}
if headers.contains_key(PROXY_AUTHORIZATION) {
return;
}
for proxy in self.inner.proxies.iter() {
if proxy.is_match(dst) {
match proxy.auth() {
Some(::proxy::Auth::Basic(ref header)) => {
headers.insert(
PROXY_AUTHORIZATION,
header.clone()
);
},
None => (),
}
break;
}
}
}
}
impl fmt::Debug for Client {
@@ -520,6 +576,8 @@ struct ClientRef {
hyper: HyperClient,
redirect_policy: RedirectPolicy,
referer: bool,
proxies: Arc<Vec<Proxy>>,
proxies_maybe_http_auth: bool,
}
pub struct Pending {

View File

@@ -114,6 +114,9 @@ impl Connect for Connector {
ndst.set_port(puri.port_part().map(|port| port.as_u16()));
#[cfg(feature = "tls")]
let auth = prox.auth().cloned();
match &self.inner {
#[cfg(feature = "default-tls")]
Inner::DefaultTls(http, tls) => if dst.scheme() == "https" {
@@ -125,7 +128,7 @@ impl Connect for Connector {
let tls = tls.clone();
return Box::new(http.connect(ndst).and_then(move |(conn, connected)| {
trace!("tunneling HTTPS over proxy");
tunnel(conn, host.clone(), port)
tunnel(conn, host.clone(), port, auth)
.and_then(move |tunneled| {
tls.connect_async(&host, tunneled)
.map_err(|e| io::Error::new(io::ErrorKind::Other, e))
@@ -148,7 +151,7 @@ impl Connect for Connector {
let maybe_dnsname = DNSNameRef::try_from_ascii_str(&host)
.map(|dnsname| dnsname.to_owned())
.map_err(|_| io::Error::new(io::ErrorKind::Other, "Invalid DNS Name"));
tunnel(conn, host, port)
tunnel(conn, host, port, auth)
.and_then(move |tunneled| Ok((maybe_dnsname?, tunneled)))
.and_then(move |(dnsname, tunneled)| {
RustlsConnector::from(tls).connect(dnsname.as_ref(), tunneled)
@@ -176,13 +179,25 @@ pub(crate) type Conn = Box<dyn AsyncConn + Send + Sync + 'static>;
pub(crate) type Connecting = Box<Future<Item=(Conn, Connected), Error=io::Error> + Send>;
#[cfg(feature = "tls")]
fn tunnel<T>(conn: T, host: String, port: u16) -> Tunnel<T> {
let buf = format!("\
fn tunnel<T>(conn: T, host: String, port: u16, auth: Option<::proxy::Auth>) -> Tunnel<T> {
let mut buf = format!("\
CONNECT {0}:{1} HTTP/1.1\r\n\
Host: {0}:{1}\r\n\
\r\n\
", host, port).into_bytes();
match auth {
Some(::proxy::Auth::Basic(value)) => {
debug!("tunnel to {}:{} using basic auth", host, port);
buf.extend_from_slice(b"Proxy-Authorization: ");
buf.extend_from_slice(value.as_bytes());
buf.extend_from_slice(b"\r\n");
},
None => (),
}
// headers end
buf.extend_from_slice(b"\r\n");
Tunnel {
buf: io::Cursor::new(buf),
conn: Some(conn),
@@ -230,6 +245,8 @@ where T: AsyncRead + AsyncWrite {
return Ok(self.conn.take().unwrap().into());
}
// else read more
} else if read.starts_with(b"HTTP/1.1 407") {
return Err(io::Error::new(io::ErrorKind::Other, "proxy authentication required"));
} else {
return Err(io::Error::new(io::ErrorKind::Other, "unsuccessful tunnel"));
}
@@ -258,23 +275,29 @@ mod tests {
use tokio::runtime::current_thread::Runtime;
use tokio::net::TcpStream;
use super::tunnel;
use proxy;
static TUNNEL_OK: &'static [u8] = b"\
HTTP/1.1 200 OK\r\n\
\r\n\
";
macro_rules! mock_tunnel {
() => ({
mock_tunnel!(b"\
HTTP/1.1 200 OK\r\n\
\r\n\
")
mock_tunnel!(TUNNEL_OK)
});
($write:expr) => ({
mock_tunnel!($write, "")
});
($write:expr, $auth:expr) => ({
let listener = TcpListener::bind("127.0.0.1:0").unwrap();
let addr = listener.local_addr().unwrap();
let connect_expected = format!("\
CONNECT {0}:{1} HTTP/1.1\r\n\
Host: {0}:{1}\r\n\
{2}\
\r\n\
", addr.ip(), addr.port()).into_bytes();
", addr.ip(), addr.port(), $auth).into_bytes();
thread::spawn(move || {
let (mut sock, _) = listener.accept().unwrap();
@@ -297,7 +320,7 @@ mod tests {
let host = addr.ip().to_string();
let port = addr.port();
let work = work.and_then(|tcp| {
tunnel(tcp, host, port)
tunnel(tcp, host, port, None)
});
rt.block_on(work).unwrap();
@@ -312,14 +335,14 @@ mod tests {
let host = addr.ip().to_string();
let port = addr.port();
let work = work.and_then(|tcp| {
tunnel(tcp, host, port)
tunnel(tcp, host, port, None)
});
rt.block_on(work).unwrap_err();
}
#[test]
fn test_tunnel_bad_response() {
fn test_tunnel_non_http_response() {
let addr = mock_tunnel!(b"foo bar baz hallo");
let mut rt = Runtime::new().unwrap();
@@ -327,9 +350,47 @@ mod tests {
let host = addr.ip().to_string();
let port = addr.port();
let work = work.and_then(|tcp| {
tunnel(tcp, host, port)
tunnel(tcp, host, port, None)
});
rt.block_on(work).unwrap_err();
}
#[test]
fn test_tunnel_proxy_unauthorized() {
let addr = mock_tunnel!(b"\
HTTP/1.1 407 Proxy Authentication Required\r\n\
Proxy-Authenticate: Basic realm=\"nope\"\r\n\
\r\n\
");
let mut rt = Runtime::new().unwrap();
let work = TcpStream::connect(&addr);
let host = addr.ip().to_string();
let port = addr.port();
let work = work.and_then(|tcp| {
tunnel(tcp, host, port, None)
});
let error = rt.block_on(work).unwrap_err();
assert_eq!(error.to_string(), "proxy authentication required");
}
#[test]
fn test_tunnel_basic_auth() {
let addr = mock_tunnel!(
TUNNEL_OK,
"Proxy-Authorization: Basic QWxhZGRpbjpvcGVuIHNlc2FtZQ==\r\n"
);
let mut rt = Runtime::new().unwrap();
let work = TcpStream::connect(&addr);
let host = addr.ip().to_string();
let port = addr.port();
let work = work.and_then(|tcp| {
tunnel(tcp, host, port, Some(proxy::Auth::basic("Aladdin", "open sesame")))
});
rt.block_on(work).unwrap();
}
}

View File

@@ -1,6 +1,7 @@
use std::fmt;
use std::sync::Arc;
use http::{header::HeaderValue, Uri};
use hyper::client::connect::Destination;
use {into_url, IntoUrl, Url};
@@ -30,9 +31,15 @@ use {into_url, IntoUrl, Url};
/// would prevent a `Proxy` later in the list from ever working, so take care.
#[derive(Clone, Debug)]
pub struct Proxy {
auth: Option<Auth>,
intercept: Intercept,
}
#[derive(Clone, Debug)]
pub(crate) enum Auth {
Basic(HeaderValue),
}
impl Proxy {
/// Proxy all HTTP traffic to the passed URL.
///
@@ -124,7 +131,43 @@ impl Proxy {
fn new(intercept: Intercept) -> Proxy {
Proxy {
intercept: intercept,
auth: None,
intercept,
}
}
/// Set the `Proxy-Authorization` header using Basic auth.
///
/// # Example
///
/// ```
/// # extern crate reqwest;
/// # fn run() -> Result<(), Box<::std::error::Error>> {
/// let proxy = reqwest::Proxy::https("http://localhost:1234")?
/// .basic_auth("Aladdin", "open sesame");
/// # Ok(())
/// # }
/// # fn main() {}
/// ```
pub fn basic_auth(mut self, username: &str, password: &str) -> Proxy {
self.auth = Some(Auth::basic(username, password));
self
}
pub(crate) fn auth(&self) -> Option<&Auth> {
self.auth.as_ref()
}
pub(crate) fn maybe_has_http_auth(&self) -> bool {
match self.auth {
Some(Auth::Basic(_)) => match self.intercept {
Intercept::All(_) |
Intercept::Http(_) |
// Custom *may* match 'http', so assume so.
Intercept::Custom(_) => true,
Intercept::Https(_) => false,
},
None => false,
}
}
@@ -161,6 +204,31 @@ impl Proxy {
},
}
}
pub(crate) fn is_match<D: Dst>(&self, uri: &D) -> bool {
match self.intercept {
Intercept::All(_) => true,
Intercept::Http(_) => {
uri.scheme() == "http"
},
Intercept::Https(_) => {
uri.scheme() == "https"
},
Intercept::Custom(ref fun) => {
(fun.0)(
&format!(
"{}://{}{}{}",
uri.scheme(),
uri.host(),
uri.port().map(|_| ":").unwrap_or(""),
uri.port().map(|p| p.to_string()).unwrap_or(String::new())
)
.parse()
.expect("should be valid Url")
).is_some()
},
}
}
}
#[derive(Clone, Debug)]
@@ -203,6 +271,35 @@ impl Dst for Destination {
}
}
#[doc(hidden)]
impl Dst for Uri {
fn scheme(&self) -> &str {
self.scheme_part()
.expect("Uri should have a scheme")
.as_str()
}
fn host(&self) -> &str {
Uri::host(self)
.expect("<Uri as Dst>::host should have a str")
}
fn port(&self) -> Option<u16> {
self.port_part().map(|p| p.as_u16())
}
}
impl Auth {
pub(crate) fn basic(username: &str, password: &str) -> Auth {
let val = format!("{}:{}", username, password);
let mut header = format!("Basic {}", base64::encode(&val))
.parse::<HeaderValue>()
.expect("base64 is always valid HeaderValue");
header.set_sensitive(true);
Auth::Basic(header)
}
}
#[cfg(test)]
mod tests {
use super::*;

View File

@@ -1,6 +1,13 @@
use std::fmt;
use header::HeaderMap;
use header::{
HeaderMap,
AUTHORIZATION,
COOKIE,
PROXY_AUTHORIZATION,
WWW_AUTHENTICATE,
};
use hyper::StatusCode;
use Url;
@@ -233,10 +240,11 @@ pub(crate) fn remove_sensitive_headers(headers: &mut HeaderMap, next: &Url, prev
let cross_host = next.host_str() != previous.host_str() ||
next.port_or_known_default() != previous.port_or_known_default();
if cross_host {
headers.remove("authorization");
headers.remove("cookie");
headers.remove(AUTHORIZATION);
headers.remove(COOKIE);
headers.remove("cookie2");
headers.remove("www-authenticate");
headers.remove(PROXY_AUTHORIZATION);
headers.remove(WWW_AUTHENTICATE);
}
}
}

View File

@@ -4,7 +4,7 @@ extern crate reqwest;
mod support;
#[test]
fn test_http_proxy() {
fn http_proxy() {
let server = server! {
request: b"\
GET http://hyper.rs/prox HTTP/1.1\r\n\
@@ -37,3 +37,43 @@ fn test_http_proxy() {
assert_eq!(res.status(), reqwest::StatusCode::OK);
assert_eq!(res.headers().get(reqwest::header::SERVER).unwrap(), &"proxied");
}
#[test]
fn http_proxy_basic_auth() {
let server = server! {
request: b"\
GET http://hyper.rs/prox HTTP/1.1\r\n\
user-agent: $USERAGENT\r\n\
accept: */*\r\n\
accept-encoding: gzip\r\n\
proxy-authorization: Basic QWxhZGRpbjpvcGVuIHNlc2FtZQ==\r\n\
host: hyper.rs\r\n\
\r\n\
",
response: b"\
HTTP/1.1 200 OK\r\n\
Server: proxied\r\n\
Content-Length: 0\r\n\
\r\n\
"
};
let proxy = format!("http://{}", server.addr());
let url = "http://hyper.rs/prox";
let res = reqwest::Client::builder()
.proxy(
reqwest::Proxy::http(&proxy)
.unwrap()
.basic_auth("Aladdin", "open sesame")
)
.build()
.unwrap()
.get(url)
.send()
.unwrap();
assert_eq!(res.url().as_str(), url);
assert_eq!(res.status(), reqwest::StatusCode::OK);
assert_eq!(res.headers().get(reqwest::header::SERVER).unwrap(), &"proxied");
}