diff --git a/src/async_impl/client.rs b/src/async_impl/client.rs index 6a057d1..edaaba9 100644 --- a/src/async_impl/client.rs +++ b/src/async_impl/client.rs @@ -35,7 +35,7 @@ use {IntoUrl, Method, Proxy, StatusCode, Url}; #[cfg(feature = "tls")] use {Certificate, Identity}; #[cfg(feature = "tls")] -use ::tls::{TlsBackend, inner as tls_inner}; +use ::tls::TlsBackend; static DEFAULT_USER_AGENT: &'static str = concat!(env!("CARGO_PKG_NAME"), "/", env!("CARGO_PKG_VERSION")); @@ -128,32 +128,17 @@ impl ClientBuilder { tls.danger_accept_invalid_certs(!config.certs_verification); for cert in config.root_certs { - let cert = match cert.inner { - tls_inner::Certificate::Der(buf) => - try_!(::native_tls::Certificate::from_der(&buf)), - tls_inner::Certificate::Pem(buf) => - try_!(::native_tls::Certificate::from_pem(&buf)) - }; - tls.add_root_certificate(cert); + cert.add_to_native_tls(&mut tls); } if let Some(id) = config.identity { - let id = match id.inner { - tls_inner::Identity::Pkcs12(buf, passwd) => - try_!(::native_tls::Identity::from_pkcs12(&buf, &passwd)), - #[cfg(feature = "rustls-tls")] - _ => return Err(::error::from(::error::Kind::TlsIncompatible)) - }; - tls.identity(id); + id.add_to_native_tls(&mut tls)?; } Connector::new_default_tls(tls, proxies.clone())? }, #[cfg(feature = "rustls-tls")] TlsBackend::Rustls => { - use std::io::Cursor; - use rustls::TLSError; - use rustls::internal::pemfile; use ::tls::NoVerifier; let mut tls = ::rustls::ClientConfig::new(); @@ -164,44 +149,11 @@ impl ClientBuilder { } for cert in config.root_certs { - match cert.inner { - tls_inner::Certificate::Der(buf) => try_!(tls.root_store.add(&::rustls::Certificate(buf)) - .map_err(TLSError::WebPKIError)), - tls_inner::Certificate::Pem(buf) => { - let mut pem = Cursor::new(buf); - let mut certs = try_!(pemfile::certs(&mut pem) - .map_err(|_| TLSError::General(String::from("No valid certificate was found")))); - for c in certs { - try_!(tls.root_store.add(&c) - .map_err(TLSError::WebPKIError)); - } - } - } + cert.add_to_rustls(&mut tls)?; } if let Some(id) = config.identity { - let (key, certs) = match id.inner { - tls_inner::Identity::Pem(buf) => { - let mut pem = Cursor::new(buf); - let mut certs = try_!(pemfile::certs(&mut pem) - .map_err(|_| TLSError::General(String::from("No valid certificate was found")))); - pem.set_position(0); - let mut sk = try_!(pemfile::pkcs8_private_keys(&mut pem) - .or_else(|_| { - pem.set_position(0); - pemfile::rsa_private_keys(&mut pem) - }) - .map_err(|_| TLSError::General(String::from("No valid private key was found")))); - if let (Some(sk), false) = (sk.pop(), certs.is_empty()) { - (sk, certs) - } else { - return Err(::error::from(TLSError::General(String::from("private key or certificate not found")))); - } - }, - #[cfg(feature = "default-tls")] - _ => return Err(::error::from(::error::Kind::TlsIncompatible)) - }; - tls.set_single_client_cert(certs, key); + id.add_to_rustls(&mut tls)?; } Connector::new_rustls_tls(tls, proxies.clone())? diff --git a/src/tls.rs b/src/tls.rs index 2bf335e..6046ecc 100644 --- a/src/tls.rs +++ b/src/tls.rs @@ -6,25 +6,30 @@ use tokio_rustls::webpki::DNSNameRef; /// Represent a server X509 certificate. pub struct Certificate { - pub(crate) inner: inner::Certificate + #[cfg(feature = "default-tls")] + native: ::native_tls::Certificate, + #[cfg(feature = "rustls-tls")] + original: Cert, +} + +#[cfg(feature = "rustls-tls")] +enum Cert { + Der(Vec), + Pem(Vec) } /// Represent a private key and X509 cert as a client certificate. pub struct Identity { - pub(crate) inner: inner::Identity + inner: ClientCert, } -pub(crate) mod inner { - pub(crate) enum Certificate { - Der(Vec), - Pem(Vec) - } - - pub(crate) enum Identity { - #[cfg(feature = "default-tls")] - Pkcs12(Vec, String), - #[cfg(feature = "rustls-tls")] - Pem(Vec), +enum ClientCert { + #[cfg(feature = "default-tls")] + Pkcs12(::native_tls::Identity), + #[cfg(feature = "rustls-tls")] + Pem { + key: ::rustls::PrivateKey, + certs: Vec<::rustls::Certificate>, } } @@ -47,7 +52,10 @@ impl Certificate { /// ``` pub fn from_der(der: &[u8]) -> ::Result { Ok(Certificate { - inner: inner::Certificate::Der(der.to_owned()) + #[cfg(feature = "default-tls")] + native: try_!(::native_tls::Certificate::from_der(der)), + #[cfg(feature = "rustls-tls")] + original: Cert::Der(der.to_owned()), }) } @@ -68,11 +76,47 @@ impl Certificate { /// # Ok(()) /// # } /// ``` - pub fn from_pem(der: &[u8]) -> ::Result { + pub fn from_pem(pem: &[u8]) -> ::Result { Ok(Certificate { - inner: inner::Certificate::Pem(der.to_owned()) + #[cfg(feature = "default-tls")] + native: try_!(::native_tls::Certificate::from_pem(pem)), + #[cfg(feature = "rustls-tls")] + original: Cert::Pem(pem.to_owned()) }) } + + #[cfg(feature = "default-tls")] + pub(crate) fn add_to_native_tls( + self, + tls: &mut ::native_tls::TlsConnectorBuilder, + ) { + tls.add_root_certificate(self.native); + } + + #[cfg(feature = "rustls-tls")] + pub(crate) fn add_to_rustls( + self, + tls: &mut ::rustls::ClientConfig, + ) -> ::Result<()> { + use std::io::Cursor; + use rustls::TLSError; + use rustls::internal::pemfile; + + match self.original { + Cert::Der(buf) => try_!(tls.root_store.add(&::rustls::Certificate(buf)) + .map_err(TLSError::WebPKIError)), + Cert::Pem(buf) => { + let mut pem = Cursor::new(buf); + let mut certs = try_!(pemfile::certs(&mut pem) + .map_err(|_| TLSError::General(String::from("No valid certificate was found")))); + for c in certs { + try_!(tls.root_store.add(&c) + .map_err(TLSError::WebPKIError)); + } + } + } + Ok(()) + } } impl Identity { @@ -106,7 +150,9 @@ impl Identity { #[cfg(feature = "default-tls")] pub fn from_pkcs12_der(der: &[u8], password: &str) -> ::Result { Ok(Identity { - inner: inner::Identity::Pkcs12(der.to_owned(), password.to_owned()) + inner: ClientCert::Pkcs12( + try_!(::native_tls::Identity::from_pkcs12(der, password)) + ), }) } @@ -130,11 +176,66 @@ impl Identity { /// # } /// ``` #[cfg(feature = "rustls-tls")] - pub fn from_pem(pem: &[u8]) -> ::Result { + pub fn from_pem(buf: &[u8]) -> ::Result { + use std::io::Cursor; + use rustls::TLSError; + use rustls::internal::pemfile; + + let (key, certs) = { + let mut pem = Cursor::new(buf); + let mut certs = try_!(pemfile::certs(&mut pem) + .map_err(|_| TLSError::General(String::from("No valid certificate was found")))); + pem.set_position(0); + let mut sk = try_!(pemfile::pkcs8_private_keys(&mut pem) + .or_else(|_| { + pem.set_position(0); + pemfile::rsa_private_keys(&mut pem) + }) + .map_err(|_| TLSError::General(String::from("No valid private key was found")))); + if let (Some(sk), false) = (sk.pop(), certs.is_empty()) { + (sk, certs) + } else { + return Err(::error::from(TLSError::General(String::from("private key or certificate not found")))); + } + }; + Ok(Identity { - inner: inner::Identity::Pem(pem.to_owned()) + inner: ClientCert::Pem { + key, + certs, + }, }) } + + #[cfg(feature = "default-tls")] + pub(crate) fn add_to_native_tls( + self, + tls: &mut ::native_tls::TlsConnectorBuilder, + ) -> ::Result<()> { + match self.inner { + ClientCert::Pkcs12(id) => { + tls.identity(id); + Ok(()) + }, + #[cfg(feature = "rustls-tls")] + ClientCert::Pem { .. } => Err(::error::from(::error::Kind::TlsIncompatible)) + } + } + + #[cfg(feature = "rustls-tls")] + pub(crate) fn add_to_rustls( + self, + tls: &mut ::rustls::ClientConfig, + ) -> ::Result<()> { + match self.inner { + ClientCert::Pem { key, certs } => { + tls.set_single_client_cert(certs, key); + Ok(()) + }, + #[cfg(feature = "default-tls")] + ClientCert::Pkcs12(..) => return Err(::error::from(::error::Kind::TlsIncompatible)) + } + } } impl fmt::Debug for Certificate { @@ -183,3 +284,32 @@ impl ServerCertVerifier for NoVerifier { Ok(ServerCertVerified::assertion()) } } + +#[cfg(test)] +mod tests { + use super::*; + + #[cfg(feature = "default-tls")] + #[test] + fn certificate_from_der_invalid() { + Certificate::from_der(b"not der").unwrap_err(); + } + + #[cfg(feature = "default-tls")] + #[test] + fn certificate_from_pem_invalid() { + Certificate::from_pem(b"not pem").unwrap_err(); + } + + #[cfg(feature = "default-tls")] + #[test] + fn identity_from_pkcs12_der_invalid() { + Identity::from_pkcs12_der(b"not der", "nope").unwrap_err(); + } + + #[cfg(feature = "rustls-tls")] + #[test] + fn identity_from_pem_invalid() { + Identity::from_pem(b"not pem").unwrap_err(); + } +}