diff --git a/src/async_impl/client.rs b/src/async_impl/client.rs index c6283a5..8c05a7d 100644 --- a/src/async_impl/client.rs +++ b/src/async_impl/client.rs @@ -13,7 +13,7 @@ use http::header::{ }; use http::uri::Scheme; use http::Uri; -use hyper::client::ResponseFuture; +use hyper::client::{HttpConnector, ResponseFuture}; #[cfg(feature = "native-tls-crate")] use native_tls_crate::TlsConnector; use pin_project_lite::pin_project; @@ -28,9 +28,12 @@ use super::decoder::Accepts; use super::request::{Request, RequestBuilder}; use super::response::Response; use super::Body; -use crate::connect::{Connector, HttpConnector}; +use crate::connect::Connector; #[cfg(feature = "cookies")] use crate::cookie; +#[cfg(feature = "trust-dns")] +use crate::dns::trust_dns::TrustDnsResolver; +use crate::dns::{gai::GaiResolver, DnsResolverWithOverrides, DynResolver, Resolve}; use crate::error; use crate::into_url::{expect_uri, try_uri}; use crate::redirect::{self, remove_sensitive_headers}; @@ -121,6 +124,7 @@ struct Config { error: Option, https_only: bool, dns_overrides: HashMap>, + dns_resolver: Option>, } impl Default for ClientBuilder { @@ -188,6 +192,7 @@ impl ClientBuilder { cookie_store: None, https_only: false, dns_overrides: HashMap::new(), + dns_resolver: None, }, } } @@ -217,25 +222,23 @@ impl ClientBuilder { headers.get(USER_AGENT).cloned() } - let http = match config.trust_dns { - false => { - if config.dns_overrides.is_empty() { - HttpConnector::new_gai() - } else { - HttpConnector::new_gai_with_overrides(config.dns_overrides) - } - } + let mut resolver: Arc = match config.trust_dns { + false => Arc::new(GaiResolver::new()), #[cfg(feature = "trust-dns")] - true => { - if config.dns_overrides.is_empty() { - HttpConnector::new_trust_dns()? - } else { - HttpConnector::new_trust_dns_with_overrides(config.dns_overrides)? - } - } + true => Arc::new(TrustDnsResolver::new().map_err(crate::error::builder)?), #[cfg(not(feature = "trust-dns"))] true => unreachable!("trust-dns shouldn't be enabled unless the feature is"), }; + if let Some(dns_resolver) = config.dns_resolver { + resolver = dns_resolver; + } + if !config.dns_overrides.is_empty() { + resolver = Arc::new(DnsResolverWithOverrides::new( + resolver, + config.dns_overrides, + )); + } + let http = HttpConnector::new_with_resolver(DynResolver::new(resolver)); #[cfg(feature = "__tls")] match config.tls { @@ -1340,6 +1343,16 @@ impl ClientBuilder { .insert(domain.to_string(), addrs.to_vec()); self } + + /// Override the DNS resolver implementation. + /// + /// Pass an `Arc` wrapping a trait object implementing `Resolve`. + /// Overrides for specific names passed to `resolve` and `resolve_to_addrs` will + /// still be applied on top of this resolver. + pub fn dns_resolver(mut self, resolver: Arc) -> ClientBuilder { + self.config.dns_resolver = Some(resolver as _); + self + } } type HyperClient = hyper::Client; diff --git a/src/connect.rs b/src/connect.rs index 35843ec..388a39b 100644 --- a/src/connect.rs +++ b/src/connect.rs @@ -1,162 +1,31 @@ -use futures_util::future::Either; #[cfg(feature = "__tls")] use http::header::HeaderValue; use http::uri::{Authority, Scheme}; use http::Uri; -use hyper::client::connect::{ - dns::{GaiResolver, Name}, - Connected, Connection, -}; +use hyper::client::connect::{Connected, Connection}; use hyper::service::Service; #[cfg(feature = "native-tls-crate")] use native_tls_crate::{TlsConnector, TlsConnectorBuilder}; use tokio::io::{AsyncRead, AsyncWrite, ReadBuf}; use pin_project_lite::pin_project; -use std::io::IoSlice; +use std::future::Future; +use std::io::{self, IoSlice}; use std::net::IpAddr; use std::pin::Pin; use std::sync::Arc; use std::task::{Context, Poll}; use std::time::Duration; -use std::{collections::HashMap, io}; -use std::{future::Future, net::SocketAddr}; #[cfg(feature = "default-tls")] use self::native_tls_conn::NativeTlsConn; #[cfg(feature = "__rustls")] use self::rustls_tls_conn::RustlsTlsConn; -#[cfg(feature = "trust-dns")] -use crate::dns::TrustDnsResolver; +use crate::dns::DynResolver; use crate::error::BoxError; use crate::proxy::{Proxy, ProxyScheme}; -#[derive(Clone)] -pub(crate) enum HttpConnector { - Gai(hyper::client::HttpConnector), - GaiWithDnsOverrides(hyper::client::HttpConnector>), - #[cfg(feature = "trust-dns")] - TrustDns(hyper::client::HttpConnector), - #[cfg(feature = "trust-dns")] - TrustDnsWithOverrides(hyper::client::HttpConnector>), -} - -impl HttpConnector { - pub(crate) fn new_gai() -> Self { - Self::Gai(hyper::client::HttpConnector::new()) - } - - pub(crate) fn new_gai_with_overrides(overrides: HashMap>) -> Self { - let gai = hyper::client::connect::dns::GaiResolver::new(); - let overridden_resolver = DnsResolverWithOverrides::new(gai, overrides); - Self::GaiWithDnsOverrides(hyper::client::HttpConnector::new_with_resolver( - overridden_resolver, - )) - } - - #[cfg(feature = "trust-dns")] - pub(crate) fn new_trust_dns() -> crate::Result { - TrustDnsResolver::new() - .map(hyper::client::HttpConnector::new_with_resolver) - .map(Self::TrustDns) - .map_err(crate::error::builder) - } - - #[cfg(feature = "trust-dns")] - pub(crate) fn new_trust_dns_with_overrides( - overrides: HashMap>, - ) -> crate::Result { - TrustDnsResolver::new() - .map(|resolver| DnsResolverWithOverrides::new(resolver, overrides)) - .map(hyper::client::HttpConnector::new_with_resolver) - .map(Self::TrustDnsWithOverrides) - .map_err(crate::error::builder) - } -} - -macro_rules! impl_http_connector { - ($(fn $name:ident(&mut self, $($par_name:ident: $par_type:ty),*)$( -> $return:ty)?;)+) => { - #[allow(dead_code)] - impl HttpConnector { - $( - fn $name(&mut self, $($par_name: $par_type),*)$( -> $return)? { - match self { - Self::Gai(resolver) => resolver.$name($($par_name),*), - Self::GaiWithDnsOverrides(resolver) => resolver.$name($($par_name),*), - #[cfg(feature = "trust-dns")] - Self::TrustDns(resolver) => resolver.$name($($par_name),*), - #[cfg(feature = "trust-dns")] - Self::TrustDnsWithOverrides(resolver) => resolver.$name($($par_name),*), - } - } - )+ - } - }; -} - -impl_http_connector! { - fn set_local_address(&mut self, addr: Option); - fn enforce_http(&mut self, is_enforced: bool); - fn set_nodelay(&mut self, nodelay: bool); - fn set_keepalive(&mut self, dur: Option); -} - -impl Service for HttpConnector { - type Response = >::Response; - type Error = >::Error; - #[cfg(feature = "trust-dns")] - type Future = - Either< - Either< - >::Future, - > as Service< - Uri, - >>::Future, - >, - Either< - as Service>::Future, - > as Service>::Future - > - >; - #[cfg(not(feature = "trust-dns"))] - type Future = - Either< - Either< - >::Future, - > as Service< - Uri, - >>::Future, - >, - Either< - >::Future, - >::Future, - >, - >; - - fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll> { - match self { - Self::Gai(resolver) => resolver.poll_ready(cx), - Self::GaiWithDnsOverrides(resolver) => resolver.poll_ready(cx), - #[cfg(feature = "trust-dns")] - Self::TrustDns(resolver) => resolver.poll_ready(cx), - #[cfg(feature = "trust-dns")] - Self::TrustDnsWithOverrides(resolver) => resolver.poll_ready(cx), - } - } - - fn call(&mut self, dst: Uri) -> Self::Future { - match self { - Self::Gai(resolver) => Either::Left(Either::Left(resolver.call(dst))), - Self::GaiWithDnsOverrides(resolver) => Either::Left(Either::Right(resolver.call(dst))), - #[cfg(feature = "trust-dns")] - Self::TrustDns(resolver) => Either::Right(Either::Left(resolver.call(dst))), - #[cfg(feature = "trust-dns")] - Self::TrustDnsWithOverrides(resolver) => { - Either::Right(Either::Right(resolver.call(dst))) - } - } - } -} +pub(crate) type HttpConnector = hyper::client::HttpConnector; #[derive(Clone)] pub(crate) struct Connector { @@ -960,103 +829,6 @@ mod socks { } } -pub(crate) mod itertools { - pub(crate) enum Either { - Left(A), - Right(B), - } - - impl Iterator for Either - where - A: Iterator, - B: Iterator::Item>, - { - type Item = ::Item; - - fn next(&mut self) -> Option { - match self { - Either::Left(a) => a.next(), - Either::Right(b) => b.next(), - } - } - } -} - -pin_project! { - pub(crate) struct WrappedResolverFuture { - #[pin] - fut: Fut, - } -} - -impl std::future::Future for WrappedResolverFuture -where - Fut: std::future::Future>, - FutOutput: Iterator, -{ - type Output = Result>, FutError>; - - fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { - let this = self.project(); - this.fut - .poll(cx) - .map(|result| result.map(itertools::Either::Left)) - } -} - -#[derive(Clone)] -pub(crate) struct DnsResolverWithOverrides -where - Resolver: Clone, -{ - dns_resolver: Resolver, - overrides: Arc>>, -} - -impl DnsResolverWithOverrides { - fn new(dns_resolver: Resolver, overrides: HashMap>) -> Self { - DnsResolverWithOverrides { - dns_resolver, - overrides: Arc::new(overrides), - } - } -} - -impl Service for DnsResolverWithOverrides -where - Resolver: Service + Clone, - Iter: Iterator, -{ - type Response = itertools::Either>; - type Error = >::Error; - type Future = Either< - WrappedResolverFuture<>::Future>, - futures_util::future::Ready< - Result>, Self::Error>, - >, - >; - - fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll> { - self.dns_resolver.poll_ready(cx) - } - - fn call(&mut self, name: Name) -> Self::Future { - match self.overrides.get(name.as_str()) { - Some(dest) => { - let fut = futures_util::future::ready(Ok(itertools::Either::Right( - dest.clone().into_iter(), - ))); - Either::Right(fut) - } - None => { - let resolver_fut = self.dns_resolver.call(name); - let y = WrappedResolverFuture { fut: resolver_fut }; - Either::Left(y) - } - } - } -} - mod verbose { use hyper::client::connect::{Connected, Connection}; use std::fmt; diff --git a/src/dns/gai.rs b/src/dns/gai.rs new file mode 100644 index 0000000..f32f3b0 --- /dev/null +++ b/src/dns/gai.rs @@ -0,0 +1,32 @@ +use futures_util::future::FutureExt; +use hyper::client::connect::dns::{GaiResolver as HyperGaiResolver, Name}; +use hyper::service::Service; + +use crate::dns::{Addrs, Resolve, Resolving}; +use crate::error::BoxError; + +#[derive(Debug)] +pub struct GaiResolver(HyperGaiResolver); + +impl GaiResolver { + pub fn new() -> Self { + Self(HyperGaiResolver::new()) + } +} + +impl Default for GaiResolver { + fn default() -> Self { + GaiResolver::new() + } +} + +impl Resolve for GaiResolver { + fn resolve(&self, name: Name) -> Resolving { + let this = &mut self.0.clone(); + Box::pin(Service::::call(this, name).map(|result| { + result + .map(|addrs| -> Addrs { Box::new(addrs) }) + .map_err(|err| -> BoxError { Box::new(err) }) + })) + } +} diff --git a/src/dns/mod.rs b/src/dns/mod.rs new file mode 100644 index 0000000..40cdabf --- /dev/null +++ b/src/dns/mod.rs @@ -0,0 +1,9 @@ +//! DNS resolution + +pub use resolve::{Addrs, Resolve, Resolving}; +pub(crate) use resolve::{DnsResolverWithOverrides, DynResolver}; + +pub(crate) mod gai; +pub(crate) mod resolve; +#[cfg(feature = "trust-dns")] +pub(crate) mod trust_dns; diff --git a/src/dns/resolve.rs b/src/dns/resolve.rs new file mode 100644 index 0000000..3686765 --- /dev/null +++ b/src/dns/resolve.rs @@ -0,0 +1,84 @@ +use hyper::client::connect::dns::Name; +use hyper::service::Service; + +use std::collections::HashMap; +use std::future::Future; +use std::net::SocketAddr; +use std::pin::Pin; +use std::sync::Arc; +use std::task::{Context, Poll}; + +use crate::error::BoxError; + +/// Alias for an `Iterator` trait object over `SocketAddr`. +pub type Addrs = Box + Send>; + +/// Alias for the `Future` type returned by a DNS resolver. +pub type Resolving = Pin> + Send>>; + +/// Trait for customizing DNS resolution in reqwest. +pub trait Resolve: Send + Sync { + /// Performs DNS resolution on a `Name`. + /// The return type is a future containing an iterator of `SocketAddr`. + /// + /// It differs from `tower_service::Service` in several ways: + /// * It is assumed that `resolve` will always be ready to poll. + /// * It does not need a mutable reference to `self`. + /// * Since trait objects cannot make use of associated types, it requires + /// wrapping the returned `Future` and its contained `Iterator` with `Box`. + fn resolve(&self, name: Name) -> Resolving; +} + +#[derive(Clone)] +pub(crate) struct DynResolver { + resolver: Arc, +} + +impl DynResolver { + pub(crate) fn new(resolver: Arc) -> Self { + Self { resolver } + } +} + +impl Service for DynResolver { + type Response = Addrs; + type Error = BoxError; + type Future = Resolving; + + fn poll_ready(&mut self, _: &mut Context<'_>) -> Poll> { + Poll::Ready(Ok(())) + } + + fn call(&mut self, name: Name) -> Self::Future { + self.resolver.resolve(name) + } +} + +pub(crate) struct DnsResolverWithOverrides { + dns_resolver: Arc, + overrides: Arc>>, +} + +impl DnsResolverWithOverrides { + pub(crate) fn new( + dns_resolver: Arc, + overrides: HashMap>, + ) -> Self { + DnsResolverWithOverrides { + dns_resolver, + overrides: Arc::new(overrides), + } + } +} + +impl Resolve for DnsResolverWithOverrides { + fn resolve(&self, name: Name) -> Resolving { + match self.overrides.get(name.as_str()) { + Some(dest) => { + let addrs: Addrs = Box::new(dest.clone().into_iter()); + Box::pin(futures_util::future::ready(Ok(addrs))) + } + None => self.dns_resolver.resolve(name), + } + } +} diff --git a/src/dns.rs b/src/dns/trust_dns.rs similarity index 69% rename from src/dns.rs rename to src/dns/trust_dns.rs index 4f0f90d..129000c 100644 --- a/src/dns.rs +++ b/src/dns/trust_dns.rs @@ -1,20 +1,20 @@ -use std::future::Future; -use std::io; -use std::net::SocketAddr; -use std::pin::Pin; -use std::sync::Arc; -use std::task::{self, Poll}; +//! DNS resolution via the [trust_dns_resolver](https://github.com/bluejekyll/trust-dns) crate -use hyper::client::connect::dns as hyper_dns; -use hyper::service::Service; +use hyper::client::connect::dns::Name; use once_cell::sync::Lazy; use tokio::sync::Mutex; +pub use trust_dns_resolver::config::{ResolverConfig, ResolverOpts}; use trust_dns_resolver::{ - config::{ResolverConfig, ResolverOpts}, - lookup_ip::LookupIpIntoIter, - system_conf, AsyncResolver, TokioConnection, TokioConnectionProvider, TokioHandle, + lookup_ip::LookupIpIntoIter, system_conf, AsyncResolver, TokioConnection, + TokioConnectionProvider, TokioHandle, }; +use std::io; +use std::net::SocketAddr; +use std::sync::Arc; + +use super::{Addrs, Resolve, Resolving}; + use crate::error::BoxError; type SharedResolver = Arc>; @@ -22,22 +22,26 @@ type SharedResolver = Arc> = Lazy::new(|| system_conf::read_system_conf().map_err(io::Error::from)); -#[derive(Clone)] +/// Wrapper around an `AsyncResolver`, which implements the `Resolve` trait. +#[derive(Debug, Clone)] pub(crate) struct TrustDnsResolver { state: Arc>, } -pub(crate) struct SocketAddrs { +struct SocketAddrs { iter: LookupIpIntoIter, } +#[derive(Debug)] enum State { Init, Ready(SharedResolver), } impl TrustDnsResolver { - pub(crate) fn new() -> io::Result { + /// Create a new resolver with the default configuration, + /// which reads from `/etc/resolve.conf`. + pub fn new() -> io::Result { SYSTEM_CONF.as_ref().map_err(|e| { io::Error::new(e.kind(), format!("error reading DNS system conf: {}", e)) })?; @@ -51,16 +55,8 @@ impl TrustDnsResolver { } } -impl Service for TrustDnsResolver { - type Response = SocketAddrs; - type Error = BoxError; - type Future = Pin> + Send>>; - - fn poll_ready(&mut self, _: &mut task::Context<'_>) -> Poll> { - Poll::Ready(Ok(())) - } - - fn call(&mut self, name: hyper_dns::Name) -> Self::Future { +impl Resolve for TrustDnsResolver { + fn resolve(&self, name: Name) -> Resolving { let resolver = self.clone(); Box::pin(async move { let mut lock = resolver.state.lock().await; @@ -79,9 +75,10 @@ impl Service for TrustDnsResolver { drop(lock); let lookup = resolver.lookup_ip(name.as_str()).await?; - Ok(SocketAddrs { + let addrs: Addrs = Box::new(SocketAddrs { iter: lookup.into_iter(), - }) + }); + Ok(addrs) }) } } @@ -99,6 +96,13 @@ async fn new_resolver() -> Result { .as_ref() .expect("can't construct TrustDnsResolver if SYSTEM_CONF is error") .clone(); + new_resolver_with_config(config, opts) +} + +fn new_resolver_with_config( + config: ResolverConfig, + opts: ResolverOpts, +) -> Result { let resolver = AsyncResolver::new(config, opts, TokioHandle)?; Ok(Arc::new(resolver)) } diff --git a/src/lib.rs b/src/lib.rs index 1baea86..4866117 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -309,8 +309,7 @@ if_hyper! { mod connect; #[cfg(feature = "cookies")] pub mod cookie; - #[cfg(feature = "trust-dns")] - mod dns; + pub mod dns; mod proxy; pub mod redirect; #[cfg(feature = "__tls")]