From 37039760f8c86213cf29b1d2cce2a7b1b32384e2 Mon Sep 17 00:00:00 2001 From: quininer Date: Tue, 18 Dec 2018 03:57:43 +0800 Subject: [PATCH] Add rustls support (#390) --- .travis.yml | 8 ++ Cargo.toml | 8 +- src/async_impl/client.rs | 160 +++++++++++++++++++------ src/client.rs | 35 ++++-- src/connect.rs | 244 +++++++++++++++++++-------------------- src/error.rs | 39 ++++++- src/lib.rs | 13 ++- src/tls.rs | 135 +++++++++++++++++----- tests/badssl.rs | 46 ++++++++ 9 files changed, 482 insertions(+), 206 deletions(-) create mode 100644 tests/badssl.rs diff --git a/.travis.yml b/.travis.yml index c9de06c..7a25554 100644 --- a/.travis.yml +++ b/.travis.yml @@ -15,6 +15,14 @@ matrix: - rust: stable env: FEATURES="--no-default-features" + # rustls-tls + - rust: stable + env: FEATURES="--no-default-features --features rustls-tls" + + # default-tls and rustls-tls + - rust: stable + env: FEATURES="--features rustls-tls" + - rust: stable env: FEATURES="--features hyper-011" diff --git a/Cargo.toml b/Cargo.toml index 4033296..7df90b5 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -30,6 +30,10 @@ tokio = "0.1.7" tokio-io = "0.1" url = "1.2" uuid = { version = "0.7", features = ["v4"] } +hyper-rustls = { version = "0.15", optional = true } +tokio-rustls = { version = "0.8", optional = true } +webpki-roots = { version = "0.15", optional = true } +rustls = { version = "0.14", features = ["dangerous_configuration"], optional = true } [dev-dependencies] env_logger = "0.6" @@ -37,8 +41,10 @@ serde_derive = "1.0" [features] default = ["default-tls"] +tls = [] hyper-011 = ["hyper-old-types"] -default-tls = ["hyper-tls", "native-tls"] +default-tls = ["hyper-tls", "native-tls", "tls"] +rustls-tls = ["hyper-rustls", "tokio-rustls", "webpki-roots", "rustls", "tls"] native-tls-vendored = ["native-tls/vendored"] [package.metadata.docs.rs] diff --git a/src/async_impl/client.rs b/src/async_impl/client.rs index 673abd5..4bb8a1b 100644 --- a/src/async_impl/client.rs +++ b/src/async_impl/client.rs @@ -9,7 +9,7 @@ use header::{HeaderMap, HeaderValue, LOCATION, USER_AGENT, REFERER, ACCEPT, ACCEPT_ENCODING, RANGE, TRANSFER_ENCODING, CONTENT_TYPE, CONTENT_LENGTH, CONTENT_ENCODING}; use mime::{self}; #[cfg(feature = "default-tls")] -use native_tls::{TlsConnector, TlsConnectorBuilder}; +use native_tls::TlsConnector; use super::request::{Request, RequestBuilder}; @@ -18,8 +18,10 @@ use connect::Connector; use into_url::to_uri; use redirect::{self, RedirectPolicy, remove_sensitive_headers}; use {IntoUrl, Method, Proxy, StatusCode, Url}; -#[cfg(feature = "default-tls")] +#[cfg(feature = "tls")] use {Certificate, Identity}; +#[cfg(feature = "tls")] +use ::tls::{ TLSBackend, inner }; static DEFAULT_USER_AGENT: &'static str = concat!(env!("CARGO_PKG_NAME"), "/", env!("CARGO_PKG_VERSION")); @@ -46,14 +48,18 @@ struct Config { headers: HeaderMap, #[cfg(feature = "default-tls")] hostname_verification: bool, - #[cfg(feature = "default-tls")] + #[cfg(feature = "tls")] certs_verification: bool, proxies: Vec, redirect_policy: RedirectPolicy, referer: bool, timeout: Option, - #[cfg(feature = "default-tls")] - tls: TlsConnectorBuilder, + #[cfg(feature = "tls")] + root_certs: Vec, + #[cfg(feature = "tls")] + identity: Option, + #[cfg(feature = "tls")] + tls: TLSBackend, dns_threads: usize, } @@ -70,14 +76,18 @@ impl ClientBuilder { headers: headers, #[cfg(feature = "default-tls")] hostname_verification: true, - #[cfg(feature = "default-tls")] + #[cfg(feature = "tls")] certs_verification: true, proxies: Vec::new(), redirect_policy: RedirectPolicy::default(), referer: true, timeout: None, - #[cfg(feature = "default-tls")] - tls: TlsConnector::builder(), + #[cfg(feature = "tls")] + root_certs: Vec::new(), + #[cfg(feature = "tls")] + identity: None, + #[cfg(feature = "tls")] + tls: TLSBackend::default(), dns_threads: 4, }, } @@ -87,32 +97,103 @@ impl ClientBuilder { /// /// # Errors /// - /// This method fails if native TLS backend cannot be initialized. + /// This method fails if TLS backend cannot be initialized. pub fn build(self) -> ::Result { let config = self.config; - + let proxies = Arc::new(config.proxies); let connector = { - #[cfg(feature = "default-tls")] - { - let mut tls = config.tls; - tls.danger_accept_invalid_hostnames(!config.hostname_verification); - tls.danger_accept_invalid_certs(!config.certs_verification); + #[cfg(feature = "tls")] + match config.tls { + #[cfg(feature = "default-tls")] + TLSBackend::Default => { + let mut tls = TlsConnector::builder(); + tls.danger_accept_invalid_hostnames(!config.hostname_verification); + tls.danger_accept_invalid_certs(!config.certs_verification); - let tls = try_!(tls.build()); + for cert in config.root_certs { + let cert = match cert.inner { + inner::Certificate::Der(buf) => + try_!(::native_tls::Certificate::from_der(&buf)), + inner::Certificate::Pem(buf) => + try_!(::native_tls::Certificate::from_pem(&buf)) + }; + tls.add_root_certificate(cert); + } - let proxies = Arc::new(config.proxies); + if let Some(id) = config.identity { + let id = match id.inner { + inner::Identity::Pkcs12(buf, passwd) => + try_!(::native_tls::Identity::from_pkcs12(&buf, &passwd)), + #[cfg(feature = "rustls-tls")] + _ => return Err(::error::from(::error::Kind::Incompatible)) + }; + tls.identity(id); + } - Connector::new(config.dns_threads, tls, proxies.clone()) + Connector::new_default_tls(config.dns_threads, tls, proxies.clone())? + }, + #[cfg(feature = "rustls-tls")] + TLSBackend::Rustls => { + use std::io::Cursor; + use rustls::TLSError; + use rustls::internal::pemfile; + use ::tls::NoVerifier; + + let mut tls = ::rustls::ClientConfig::new(); + tls.root_store.add_server_trust_anchors(&webpki_roots::TLS_SERVER_ROOTS); + + if !config.certs_verification { + tls.dangerous().set_certificate_verifier(Arc::new(NoVerifier)); + } + + for cert in config.root_certs { + match cert.inner { + inner::Certificate::Der(buf) => try_!(tls.root_store.add(&::rustls::Certificate(buf)) + .map_err(TLSError::WebPKIError)), + inner::Certificate::Pem(buf) => { + let mut pem = Cursor::new(buf); + let mut certs = try_!(pemfile::certs(&mut pem) + .map_err(|_| TLSError::General(String::from("No valid certificate was found")))); + for c in certs { + try_!(tls.root_store.add(&c) + .map_err(TLSError::WebPKIError)); + } + } + } + } + + if let Some(id) = config.identity { + let (key, certs) = match id.inner { + inner::Identity::Pem(buf) => { + let mut pem = Cursor::new(buf); + let mut certs = try_!(pemfile::certs(&mut pem) + .map_err(|_| TLSError::General(String::from("No valid certificate was found")))); + pem.set_position(0); + let mut sk = try_!(pemfile::pkcs8_private_keys(&mut pem) + .or_else(|_| { + pem.set_position(0); + pemfile::rsa_private_keys(&mut pem) + }) + .map_err(|_| TLSError::General(String::from("No valid private key was found")))); + if let (Some(sk), false) = (sk.pop(), certs.is_empty()) { + (sk, certs) + } else { + return Err(::error::from(TLSError::General(String::from("private key or certificate not found")))); + } + }, + #[cfg(feature = "default-tls")] + _ => return Err(::error::from(::error::Kind::Incompatible)) + }; + tls.set_single_client_cert(certs, key); + } + + Connector::new_rustls_tls(config.dns_threads, tls, proxies.clone())? + } } - - #[cfg(not(feature = "default-tls"))] - { - let proxies = Arc::new(config.proxies); - + #[cfg(not(feature = "tls"))] Connector::new(config.dns_threads, proxies.clone()) - } }; let hyper_client = ::hyper::Client::builder() @@ -129,20 +210,34 @@ impl ClientBuilder { }) } + /// Use native TLS backend. + #[cfg(feature = "default-tls")] + pub fn use_default_tls(mut self) -> ClientBuilder { + self.config.tls = TLSBackend::Default; + self + } + + /// Use rustls TLS backend. + #[cfg(feature = "rustls-tls")] + pub fn use_rustls_tls(mut self) -> ClientBuilder { + self.config.tls = TLSBackend::Rustls; + self + } + /// Add a custom root certificate. /// /// This can be used to connect to a server that has a self-signed /// certificate for example. - #[cfg(feature = "default-tls")] + #[cfg(feature = "tls")] pub fn add_root_certificate(mut self, cert: Certificate) -> ClientBuilder { - self.config.tls.add_root_certificate(cert.cert()); + self.config.root_certs.push(cert); self } /// Sets the identity to be used for client certificate authentication. - #[cfg(feature = "default-tls")] + #[cfg(feature = "tls")] pub fn identity(mut self, identity: Identity) -> ClientBuilder { - self.config.tls.identity(identity.pkcs12()); + self.config.identity = Some(identity); self } @@ -162,7 +257,6 @@ impl ClientBuilder { self } - /// Controls the use of certificate validation. /// /// Defaults to `false`. @@ -174,7 +268,7 @@ impl ClientBuilder { /// will be trusted for use. This includes expired certificates. This /// introduces significant vulnerabilities, and should only be used /// as a last resort. - #[cfg(feature = "default-tls")] + #[cfg(feature = "tls")] pub fn danger_accept_invalid_certs(mut self, accept_invalid_certs: bool) -> ClientBuilder { self.config.certs_verification = !accept_invalid_certs; self @@ -196,9 +290,9 @@ impl ClientBuilder { /// an `Accept-Encoding` **and** `Range` values, the `Accept-Encoding` header is set to `gzip`. /// The body is **not** automatically inflated. /// - When receiving a response, if it's headers contain a `Content-Encoding` value that - /// equals to `gzip`, both values `Content-Encoding` and `Content-Length` are removed from the + /// equals to `gzip`, both values `Content-Encoding` and `Content-Length` are removed from the /// headers' set. The body is automatically deinflated. - /// + /// /// Default is enabled. pub fn gzip(mut self, enable: bool) -> ClientBuilder { self.config.gzip = enable; @@ -247,7 +341,7 @@ impl Client { /// /// # Panics /// - /// This method panics if native TLS backend cannot be created or + /// This method panics if TLS backend cannot be created or /// initialized. Use `Client::builder()` if you wish to handle the failure /// as an `Error` instead of panicking. pub fn new() -> Client { diff --git a/src/client.rs b/src/client.rs index e234e16..fd3c64f 100644 --- a/src/client.rs +++ b/src/client.rs @@ -10,7 +10,7 @@ use futures::sync::{mpsc, oneshot}; use request::{Request, RequestBuilder}; use response::Response; use {async_impl, header, Method, IntoUrl, Proxy, RedirectPolicy, wait}; -#[cfg(feature = "default-tls")] +#[cfg(feature = "tls")] use {Certificate, Identity}; /// A `Client` to make Requests with. @@ -79,6 +79,18 @@ impl ClientBuilder { }) } + /// Use native TLS backend. + #[cfg(feature = "default-tls")] + pub fn use_default_tls(self) -> ClientBuilder { + self.with_inner(move |inner| inner.use_default_tls()) + } + + /// Use rustls TLS backend. + #[cfg(feature = "rustls-tls")] + pub fn use_rustls_tls(self) -> ClientBuilder { + self.with_inner(move |inner| inner.use_rustls_tls()) + } + /// Add a custom root certificate. /// /// This can be used to connect to a server that has a self-signed @@ -108,7 +120,7 @@ impl ClientBuilder { /// # Errors /// /// This method fails if adding root certificate was unsuccessful. - #[cfg(feature = "default-tls")] + #[cfg(feature = "tls")] pub fn add_root_certificate(self, cert: Certificate) -> ClientBuilder { self.with_inner(move |inner| inner.add_root_certificate(cert)) } @@ -123,10 +135,18 @@ impl ClientBuilder { /// # fn build_client() -> Result<(), Box> { /// // read a local PKCS12 bundle /// let mut buf = Vec::new(); - /// File::open("my-ident.pfx")?.read_to_end(&mut buf)?; /// + /// #[cfg(feature = "default-tls")] + /// File::open("my-ident.pfx")?.read_to_end(&mut buf)?; + /// #[cfg(feature = "rustls-tls")] + /// File::open("my-ident.pem")?.read_to_end(&mut buf)?; + /// + /// #[cfg(feature = "default-tls")] /// // create an Identity from the PKCS#12 archive /// let pkcs12 = reqwest::Identity::from_pkcs12_der(&buf, "my-privkey-password")?; + /// #[cfg(feature = "rustls-tls")] + /// // create an Identity from the PEM file + /// let pkcs12 = reqwest::Identity::from_pem(&buf)?; /// /// // get a client builder /// let client = reqwest::Client::builder() @@ -136,7 +156,7 @@ impl ClientBuilder { /// # Ok(()) /// # } /// ``` - #[cfg(feature = "default-tls")] + #[cfg(feature = "tls")] pub fn identity(self, identity: Identity) -> ClientBuilder { self.with_inner(move |inner| inner.identity(identity)) } @@ -157,7 +177,6 @@ impl ClientBuilder { self.with_inner(|inner| inner.danger_accept_invalid_hostnames(accept_invalid_hostname)) } - /// Controls the use of certificate validation. /// /// Defaults to `false`. @@ -169,7 +188,7 @@ impl ClientBuilder { /// will be trusted for use. This includes expired certificates. This /// introduces significant vulnerabilities, and should only be used /// as a last resort. - #[cfg(feature = "default-tls")] + #[cfg(feature = "tls")] pub fn danger_accept_invalid_certs(self, accept_invalid_certs: bool) -> ClientBuilder { self.with_inner(|inner| inner.danger_accept_invalid_certs(accept_invalid_certs)) } @@ -223,9 +242,9 @@ impl ClientBuilder { /// an `Accept-Encoding` **and** `Range` values, the `Accept-Encoding` header is set to `gzip`. /// The body is **not** automatically inflated. /// - When receiving a response, if it's headers contain a `Content-Encoding` value that - /// equals to `gzip`, both values `Content-Encoding` and `Content-Length` are removed from the + /// equals to `gzip`, both values `Content-Encoding` and `Content-Length` are removed from the /// headers' set. The body is automatically deinflated. - /// + /// /// Default is enabled. pub fn gzip(self, enable: bool) -> ClientBuilder { self.with_inner(|inner| inner.gzip(enable)) diff --git a/src/connect.rs b/src/connect.rs index ff890bb..4a5a4da 100644 --- a/src/connect.rs +++ b/src/connect.rs @@ -1,51 +1,70 @@ -use bytes::{Buf, BufMut}; -use futures::{Future, Poll}; +use futures::Future; use http::uri::Scheme; use hyper::client::{HttpConnector}; use hyper::client::connect::{Connect, Connected, Destination}; -#[cfg(feature = "default-tls")] -use hyper_tls::{HttpsConnector, MaybeHttpsStream}; -#[cfg(feature = "default-tls")] -use native_tls::TlsConnector; use tokio_io::{AsyncRead, AsyncWrite}; -#[cfg(feature = "default-tls")] -use connect_async::{TlsConnectorExt, TlsStream}; -use std::io::{self, Read, Write}; +#[cfg(feature = "default-tls")] +use native_tls::{TlsConnector, TlsConnectorBuilder}; +#[cfg(feature = "tls")] +use futures::Poll; +#[cfg(feature = "tls")] +use bytes::BufMut; + +use std::io; use std::sync::Arc; use Proxy; + pub(crate) struct Connector { - #[cfg(feature = "default-tls")] - http: HttpsConnector, - #[cfg(not(feature = "default-tls"))] - http: HttpConnector, proxies: Arc>, + inner: Inner +} + +enum Inner { + #[cfg(not(feature = "tls"))] + Http(HttpConnector), #[cfg(feature = "default-tls")] - tls: TlsConnector, + DefaultTls(::hyper_tls::HttpsConnector, TlsConnector), + #[cfg(feature = "rustls-tls")] + RustlsTls(::hyper_rustls::HttpsConnector, Arc) } impl Connector { - #[cfg(not(feature = "default-tls"))] + #[cfg(not(feature = "tls"))] pub(crate) fn new(threads: usize, proxies: Arc>) -> Connector { let http = HttpConnector::new(threads); Connector { - http, proxies, + inner: Inner::Http(http) } } + #[cfg(feature = "default-tls")] - pub(crate) fn new(threads: usize, tls: TlsConnector, proxies: Arc>) -> Connector { + pub(crate) fn new_default_tls(threads: usize, tls: TlsConnectorBuilder, proxies: Arc>) -> ::Result { + let tls = try_!(tls.build()); + let mut http = HttpConnector::new(threads); http.enforce_http(false); - let http = HttpsConnector::from((http, tls.clone())); + let http = ::hyper_tls::HttpsConnector::from((http, tls.clone())); - Connector { - http, + Ok(Connector { proxies, - tls, - } + inner: Inner::DefaultTls(http, tls) + }) + } + + #[cfg(feature = "rustls-tls")] + pub(crate) fn new_rustls_tls(threads: usize, tls: rustls::ClientConfig, proxies: Arc>) -> ::Result { + let mut http = HttpConnector::new(threads); + http.enforce_http(false); + let http = ::hyper_rustls::HttpsConnector::from((http, tls.clone())); + + Ok(Connector { + proxies, + inner: Inner::RustlsTls(http, Arc::new(tls)) + }) } } @@ -55,6 +74,23 @@ impl Connect for Connector { type Future = Connecting; fn connect(&self, dst: Destination) -> Self::Future { + macro_rules! connect { + ( $http:expr, $dst:expr, $proxy:expr ) => { + Box::new($http.connect($dst) + .map(|(io, connected)| (Box::new(io) as Conn, connected.proxy($proxy)))) + }; + ( $dst:expr, $proxy:expr ) => { + match &self.inner { + #[cfg(not(feature = "tls"))] + Inner::Http(http) => connect!(http, $dst, $proxy), + #[cfg(feature = "default-tls")] + Inner::DefaultTls(http, _) => connect!(http, $dst, $proxy), + #[cfg(feature = "rustls-tls")] + Inner::RustlsTls(http, _) => connect!(http, $dst, $proxy) + } + }; + } + for prox in self.proxies.iter() { if let Some(puri) = prox.intercept(&dst) { trace!("proxy({:?}) intercepts {:?}", puri, dst); @@ -69,116 +105,70 @@ impl Connect for Connector { ndst.set_host(puri.host().expect("proxy target should have host")) .expect("proxy target host should be valid"); - ndst.set_port(puri.port_part().map(|p| p.as_u16())); + ndst.set_port(puri.port_part().map(|port| port.as_u16())); - #[cfg(feature = "default-tls")] - { - if dst.scheme() == "https" { - let host = dst.host().to_owned(); - let port = dst.port().unwrap_or(443); - let tls = self.tls.clone(); - return Box::new(self.http.connect(ndst).and_then(move |(conn, connected)| { - trace!("tunneling HTTPS over proxy"); - tunnel(conn, host.clone(), port) - .and_then(move |tunneled| { - tls.connect_async(&host, tunneled) - .map_err(|e| io::Error::new(io::ErrorKind::Other, e)) - }) - .map(|io| (Conn::Proxied(io), connected.proxy(true))) - })); + match &self.inner { + #[cfg(feature = "default-tls")] + Inner::DefaultTls(http, tls) => if dst.scheme() == "https" { + #[cfg(feature = "default-tls")] + use connect_async::TlsConnectorExt; + + let host = dst.host().to_owned(); + let port = dst.port().unwrap_or(443); + let tls = tls.clone(); + return Box::new(http.connect(ndst).and_then(move |(conn, connected)| { + trace!("tunneling HTTPS over proxy"); + tunnel(conn, host.clone(), port) + .and_then(move |tunneled| { + tls.connect_async(&host, tunneled) + .map_err(|e| io::Error::new(io::ErrorKind::Other, e)) + }) + .map(|io| (Box::new(io) as Conn, connected.proxy(true))) + })); + }, + #[cfg(feature = "rustls-tls")] + Inner::RustlsTls(http, tls) => if dst.scheme() == "https" { + #[cfg(feature = "rustls-tls")] + use tokio_rustls::TlsConnector as RustlsConnector; + #[cfg(feature = "rustls-tls")] + use tokio_rustls::webpki::DNSNameRef; + + let host = dst.host().to_owned(); + let port = dst.port().unwrap_or(443); + let tls = tls.clone(); + return Box::new(http.connect(ndst).and_then(move |(conn, connected)| { + trace!("tunneling HTTPS over proxy"); + let maybe_dnsname = DNSNameRef::try_from_ascii_str(&host) + .map(|dnsname| dnsname.to_owned()) + .map_err(|_| io::Error::new(io::ErrorKind::Other, "Invalid DNS Name")); + tunnel(conn, host, port) + .and_then(move |tunneled| Ok((maybe_dnsname?, tunneled))) + .and_then(move |(dnsname, tunneled)| { + RustlsConnector::from(tls).connect(dnsname.as_ref(), tunneled) + .map_err(|e| io::Error::new(io::ErrorKind::Other, e)) + }) + .map(|io| (Box::new(io) as Conn, connected.proxy(true))) + })); + }, + #[cfg(not(feature = "tls"))] + Inner::Http(_) => () } - } - return Box::new(self.http.connect(ndst).map(|(io, connected)| (Conn::Normal(io), connected.proxy(true)))); + + return connect!(ndst, true); } } - Box::new(self.http.connect(dst).map(|(io, connected)| (Conn::Normal(io), connected))) + + connect!(dst, false) } } -type HttpStream = ::Transport; -#[cfg(feature = "default-tls")] -type HttpsStream = MaybeHttpsStream; - +pub(crate) trait AsyncConn: AsyncRead + AsyncWrite {} +impl AsyncConn for T {} +pub(crate) type Conn = Box; pub(crate) type Connecting = Box + Send>; -pub(crate) enum Conn { - #[cfg(feature = "default-tls")] - Normal(HttpsStream), - #[cfg(not(feature = "default-tls"))] - Normal(HttpStream), - #[cfg(feature = "default-tls")] - Proxied(TlsStream>), -} - -impl Read for Conn { - #[inline] - fn read(&mut self, buf: &mut [u8]) -> io::Result { - match *self { - Conn::Normal(ref mut s) => s.read(buf), - #[cfg(feature = "default-tls")] - Conn::Proxied(ref mut s) => s.read(buf), - } - } -} - -impl Write for Conn { - #[inline] - fn write(&mut self, buf: &[u8]) -> io::Result { - match *self { - Conn::Normal(ref mut s) => s.write(buf), - #[cfg(feature = "default-tls")] - Conn::Proxied(ref mut s) => s.write(buf), - } - } - - #[inline] - fn flush(&mut self) -> io::Result<()> { - match *self { - Conn::Normal(ref mut s) => s.flush(), - #[cfg(feature = "default-tls")] - Conn::Proxied(ref mut s) => s.flush(), - } - } -} - -impl AsyncRead for Conn { - unsafe fn prepare_uninitialized_buffer(&self, buf: &mut [u8]) -> bool { - match *self { - Conn::Normal(ref s) => s.prepare_uninitialized_buffer(buf), - #[cfg(feature = "default-tls")] - Conn::Proxied(ref s) => s.prepare_uninitialized_buffer(buf), - } - } - - fn read_buf(&mut self, buf: &mut B) -> Poll { - match *self { - Conn::Normal(ref mut s) => s.read_buf(buf), - #[cfg(feature = "default-tls")] - Conn::Proxied(ref mut s) => s.read_buf(buf), - } - } -} - -impl AsyncWrite for Conn { - fn shutdown(&mut self) -> Poll<(), io::Error> { - match *self { - Conn::Normal(ref mut s) => s.shutdown(), - #[cfg(feature = "default-tls")] - Conn::Proxied(ref mut s) => s.shutdown(), - } - } - - fn write_buf(&mut self, buf: &mut B) -> Poll { - match *self { - Conn::Normal(ref mut s) => s.write_buf(buf), - #[cfg(feature = "default-tls")] - Conn::Proxied(ref mut s) => s.write_buf(buf), - } - } -} - -#[cfg(feature = "default-tls")] +#[cfg(feature = "tls")] fn tunnel(conn: T, host: String, port: u16) -> Tunnel { let buf = format!("\ CONNECT {0}:{1} HTTP/1.1\r\n\ @@ -193,20 +183,20 @@ fn tunnel(conn: T, host: String, port: u16) -> Tunnel { } } -#[cfg(feature = "default-tls")] +#[cfg(feature = "tls")] struct Tunnel { buf: io::Cursor>, conn: Option, state: TunnelState, } -#[cfg(feature = "default-tls")] +#[cfg(feature = "tls")] enum TunnelState { Writing, Reading } -#[cfg(feature = "default-tls")] +#[cfg(feature = "tls")] impl Future for Tunnel where T: AsyncRead + AsyncWrite { type Item = T; @@ -242,7 +232,7 @@ where T: AsyncRead + AsyncWrite { } } -#[cfg(feature = "default-tls")] +#[cfg(feature = "tls")] #[inline] fn tunnel_eof() -> io::Error { io::Error::new( @@ -251,7 +241,7 @@ fn tunnel_eof() -> io::Error { ) } -#[cfg(feature = "default-tls")] +#[cfg(feature = "tls")] #[cfg(test)] mod tests { use std::io::{Read, Write}; diff --git a/src/error.rs b/src/error.rs index cf2edb2..0d2c17b 100644 --- a/src/error.rs +++ b/src/error.rs @@ -135,8 +135,12 @@ impl Error { Kind::Hyper(ref e) => Some(e), Kind::Mime(ref e) => Some(e), Kind::Url(ref e) => Some(e), + #[cfg(all(feature = "default-tls", feature = "rustls-tls"))] + Kind::Incompatible => None, #[cfg(feature = "default-tls")] - Kind::Tls(ref e) => Some(e), + Kind::NativeTls(ref e) => Some(e), + #[cfg(feature = "rustls-tls")] + Kind::Rustls(ref e) => Some(e), Kind::Io(ref e) => Some(e), Kind::UrlEncoded(ref e) => Some(e), Kind::Json(ref e) => Some(e), @@ -225,8 +229,12 @@ impl fmt::Display for Error { Kind::Mime(ref e) => fmt::Display::fmt(e, f), Kind::Url(ref e) => fmt::Display::fmt(e, f), Kind::UrlBadScheme => f.write_str("URL scheme is not allowed"), + #[cfg(all(feature = "default-tls", feature = "rustls-tls"))] + Kind::Incompatible => f.write_str("Incompatible identity type"), #[cfg(feature = "default-tls")] - Kind::Tls(ref e) => fmt::Display::fmt(e, f), + Kind::NativeTls(ref e) => fmt::Display::fmt(e, f), + #[cfg(feature = "rustls-tls")] + Kind::Rustls(ref e) => fmt::Display::fmt(e, f), Kind::Io(ref e) => fmt::Display::fmt(e, f), Kind::UrlEncoded(ref e) => fmt::Display::fmt(e, f), Kind::Json(ref e) => fmt::Display::fmt(e, f), @@ -252,8 +260,12 @@ impl StdError for Error { Kind::Mime(ref e) => e.description(), Kind::Url(ref e) => e.description(), Kind::UrlBadScheme => "URL scheme is not allowed", + #[cfg(all(feature = "default-tls", feature = "rustls-tls"))] + Kind::Incompatible => "Incompatible identity type", #[cfg(feature = "default-tls")] - Kind::Tls(ref e) => e.description(), + Kind::NativeTls(ref e) => e.description(), + #[cfg(feature = "rustls-tls")] + Kind::Rustls(ref e) => e.description(), Kind::Io(ref e) => e.description(), Kind::UrlEncoded(ref e) => e.description(), Kind::Json(ref e) => e.description(), @@ -270,8 +282,12 @@ impl StdError for Error { Kind::Hyper(ref e) => e.cause(), Kind::Mime(ref e) => e.cause(), Kind::Url(ref e) => e.cause(), + #[cfg(all(feature = "default-tls", feature = "rustls-tls"))] + Kind::Incompatible => None, #[cfg(feature = "default-tls")] - Kind::Tls(ref e) => e.cause(), + Kind::NativeTls(ref e) => e.cause(), + #[cfg(feature = "rustls-tls")] + Kind::Rustls(ref e) => e.cause(), Kind::Io(ref e) => e.cause(), Kind::UrlEncoded(ref e) => e.cause(), Kind::Json(ref e) => e.cause(), @@ -291,8 +307,12 @@ pub(crate) enum Kind { Mime(::mime::FromStrError), Url(::url::ParseError), UrlBadScheme, + #[cfg(all(feature = "default-tls", feature = "rustls-tls"))] + Incompatible, #[cfg(feature = "default-tls")] - Tls(::native_tls::Error), + NativeTls(::native_tls::Error), + #[cfg(feature = "rustls-tls")] + Rustls(::rustls::TLSError), Io(io::Error), UrlEncoded(::serde_urlencoded::ser::Error), Json(::serde_json::Error), @@ -355,7 +375,14 @@ impl From<::serde_json::Error> for Kind { #[cfg(feature = "default-tls")] impl From<::native_tls::Error> for Kind { fn from(err: ::native_tls::Error) -> Kind { - Kind::Tls(err) + Kind::NativeTls(err) + } +} + +#[cfg(feature = "rustls-tls")] +impl From<::rustls::TLSError> for Kind { + fn from(err: ::rustls::TLSError) -> Kind { + Kind::Rustls(err) } } diff --git a/src/lib.rs b/src/lib.rs index 5ad59d3..7ddca5d 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -166,6 +166,15 @@ extern crate tokio_io; extern crate url; extern crate uuid; +#[cfg(feature = "rustls-tls")] +extern crate hyper_rustls; +#[cfg(feature = "rustls-tls")] +extern crate tokio_rustls; +#[cfg(feature = "rustls-tls")] +extern crate webpki_roots; +#[cfg(feature = "rustls-tls")] +extern crate rustls; + pub use hyper::header; pub use hyper::Method; pub use hyper::{StatusCode, Version}; @@ -180,7 +189,7 @@ pub use self::proxy::Proxy; pub use self::redirect::{RedirectAction, RedirectAttempt, RedirectPolicy}; pub use self::request::{Request, RequestBuilder}; pub use self::response::Response; -#[cfg(feature = "default-tls")] +#[cfg(feature = "tls")] pub use self::tls::{Certificate, Identity}; @@ -199,7 +208,7 @@ mod proxy; mod redirect; mod request; mod response; -#[cfg(feature = "default-tls")] +#[cfg(feature = "tls")] mod tls; mod wait; diff --git a/src/tls.rs b/src/tls.rs index 37fe69c..c36dc60 100644 --- a/src/tls.rs +++ b/src/tls.rs @@ -1,8 +1,32 @@ use std::fmt; -use native_tls; +#[cfg(feature = "rustls-tls")] +use rustls::{TLSError, ServerCertVerifier, RootCertStore, ServerCertVerified}; +#[cfg(feature = "rustls-tls")] +use tokio_rustls::webpki::DNSNameRef; /// Represent an X509 certificate. -pub struct Certificate(native_tls::Certificate); +pub struct Certificate { + pub(crate) inner: inner::Certificate +} + +/// Represent a private key and X509 cert as a client certificate. +pub struct Identity { + pub(crate) inner: inner::Identity +} + +pub(crate) mod inner { + pub(crate) enum Certificate { + Der(Vec), + Pem(Vec) + } + + pub(crate) enum Identity { + #[cfg(feature = "default-tls")] + Pkcs12(Vec, String), + #[cfg(feature = "rustls-tls")] + Pem(Vec), + } +} impl Certificate { /// Create a `Certificate` from a binary DER encoded certificate @@ -24,10 +48,11 @@ impl Certificate { /// /// # Errors /// - /// If the provided buffer is not valid DER, an error will be returned. + /// It never returns error. pub fn from_der(der: &[u8]) -> ::Result { - let inner = try_!(native_tls::Certificate::from_der(der)); - Ok(Certificate(inner)) + Ok(Certificate { + inner: inner::Certificate::Der(der.to_owned()) + }) } @@ -50,29 +75,14 @@ impl Certificate { /// /// # Errors /// - /// If the provided buffer is not valid PEM, an error will be returned. + /// It never returns error. pub fn from_pem(der: &[u8]) -> ::Result { - let inner = try_!(native_tls::Certificate::from_pem(der)); - Ok(Certificate(inner)) - } - - pub(crate) fn cert(self) -> native_tls::Certificate { - self.0 - } - -} - -impl fmt::Debug for Certificate { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - f.debug_struct("Certificate") - .finish() + Ok(Certificate { + inner: inner::Certificate::Pem(der.to_owned()) + }) } } - -/// Represent a private key and X509 cert as a client certificate. -pub struct Identity(native_tls::Identity); - impl Identity { /// Parses a DER-formatted PKCS #12 archive, using the specified password to decrypt the key. /// @@ -104,14 +114,49 @@ impl Identity { /// /// # Errors /// - /// If the provided buffer is not valid DER, an error will be returned. + /// It never returns error. + #[cfg(feature = "default-tls")] pub fn from_pkcs12_der(der: &[u8], password: &str) -> ::Result { - let inner = try_!(native_tls::Identity::from_pkcs12(der, password)); - Ok(Identity(inner)) + Ok(Identity { + inner: inner::Identity::Pkcs12(der.to_owned(), password.to_owned()) + }) } - pub(crate) fn pkcs12(self) -> native_tls::Identity { - self.0 + /// Parses PEM encoded private key and certificate. + /// + /// The input should contain a PEM encoded private key + /// and at least one PEM encoded certificate. + /// + /// # Examples + /// + /// ``` + /// # use std::fs::File; + /// # use std::io::Read; + /// # fn pem() -> Result<(), Box> { + /// let mut buf = Vec::new(); + /// File::open("my-ident.pem")? + /// .read_to_end(&mut buf)?; + /// let id = reqwest::Identity::from_pem(&buf)?; + /// # drop(id); + /// # Ok(()) + /// # } + /// ``` + /// + /// # Errors + /// + /// It never returns error. + #[cfg(feature = "rustls-tls")] + pub fn from_pem(pem: &[u8]) -> ::Result { + Ok(Identity { + inner: inner::Identity::Pem(pem.to_owned()) + }) + } +} + +impl fmt::Debug for Certificate { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + f.debug_struct("Certificate") + .finish() } } @@ -122,3 +167,35 @@ impl fmt::Debug for Identity { } } +pub(crate) enum TLSBackend { + #[cfg(feature = "default-tls")] + Default, + #[cfg(feature = "rustls-tls")] + Rustls +} + +impl Default for TLSBackend { + fn default() -> TLSBackend { + #[cfg(feature = "default-tls")] + { TLSBackend::Default } + + #[cfg(all(feature = "rustls-tls", not(feature = "default-tls")))] + { TLSBackend::Rustls } + } +} + +#[cfg(feature = "rustls-tls")] +pub(crate) struct NoVerifier; + +#[cfg(feature = "rustls-tls")] +impl ServerCertVerifier for NoVerifier { + fn verify_server_cert( + &self, + _roots: &RootCertStore, + _presented_certs: &[rustls::Certificate], + _dns_name: DNSNameRef, + _ocsp_response: &[u8] + ) -> Result { + Ok(ServerCertVerified::assertion()) + } +} diff --git a/tests/badssl.rs b/tests/badssl.rs new file mode 100644 index 0000000..c4d84c9 --- /dev/null +++ b/tests/badssl.rs @@ -0,0 +1,46 @@ +extern crate reqwest; + + +#[cfg(feature = "tls")] +#[test] +fn test_badssl_modern() { + let text = reqwest::get("https://mozilla-modern.badssl.com/").unwrap() + .text().unwrap(); + + assert!(text.contains("mozilla-modern.badssl.com")); +} + +#[cfg(feature = "tls")] +#[test] +fn test_badssl_self_signed() { + let text = reqwest::Client::builder() + .danger_accept_invalid_certs(true) + .build().unwrap() + .get("https://self-signed.badssl.com/") + .send().unwrap() + .text().unwrap(); + + assert!(text.contains("self-signed.badssl.com")); +} + +#[cfg(feature = "default-tls")] +#[test] +fn test_badssl_wrong_host() { + let text = reqwest::Client::builder() + .danger_accept_invalid_hostnames(true) + .build().unwrap() + .get("https://wrong.host.badssl.com/") + .send().unwrap() + .text().unwrap(); + + assert!(text.contains("wrong.host.badssl.com")); + + + let result = reqwest::Client::builder() + .danger_accept_invalid_hostnames(true) + .build().unwrap() + .get("https://self-signed.badssl.com/") + .send(); + + assert!(result.is_err()); +}