Add support for custom DNS resolution (#1653)

Closes #1125
This commit is contained in:
Thomas Smith
2022-11-02 06:27:07 -07:00
committed by GitHub
parent 231b18f835
commit f11e958433
7 changed files with 191 additions and 278 deletions

View File

@@ -13,7 +13,7 @@ use http::header::{
}; };
use http::uri::Scheme; use http::uri::Scheme;
use http::Uri; use http::Uri;
use hyper::client::ResponseFuture; use hyper::client::{HttpConnector, ResponseFuture};
#[cfg(feature = "native-tls-crate")] #[cfg(feature = "native-tls-crate")]
use native_tls_crate::TlsConnector; use native_tls_crate::TlsConnector;
use pin_project_lite::pin_project; use pin_project_lite::pin_project;
@@ -28,9 +28,12 @@ use super::decoder::Accepts;
use super::request::{Request, RequestBuilder}; use super::request::{Request, RequestBuilder};
use super::response::Response; use super::response::Response;
use super::Body; use super::Body;
use crate::connect::{Connector, HttpConnector}; use crate::connect::Connector;
#[cfg(feature = "cookies")] #[cfg(feature = "cookies")]
use crate::cookie; use crate::cookie;
#[cfg(feature = "trust-dns")]
use crate::dns::trust_dns::TrustDnsResolver;
use crate::dns::{gai::GaiResolver, DnsResolverWithOverrides, DynResolver, Resolve};
use crate::error; use crate::error;
use crate::into_url::{expect_uri, try_uri}; use crate::into_url::{expect_uri, try_uri};
use crate::redirect::{self, remove_sensitive_headers}; use crate::redirect::{self, remove_sensitive_headers};
@@ -121,6 +124,7 @@ struct Config {
error: Option<crate::Error>, error: Option<crate::Error>,
https_only: bool, https_only: bool,
dns_overrides: HashMap<String, Vec<SocketAddr>>, dns_overrides: HashMap<String, Vec<SocketAddr>>,
dns_resolver: Option<Arc<dyn Resolve>>,
} }
impl Default for ClientBuilder { impl Default for ClientBuilder {
@@ -188,6 +192,7 @@ impl ClientBuilder {
cookie_store: None, cookie_store: None,
https_only: false, https_only: false,
dns_overrides: HashMap::new(), dns_overrides: HashMap::new(),
dns_resolver: None,
}, },
} }
} }
@@ -217,25 +222,23 @@ impl ClientBuilder {
headers.get(USER_AGENT).cloned() headers.get(USER_AGENT).cloned()
} }
let http = match config.trust_dns { let mut resolver: Arc<dyn Resolve> = match config.trust_dns {
false => { false => Arc::new(GaiResolver::new()),
if config.dns_overrides.is_empty() {
HttpConnector::new_gai()
} else {
HttpConnector::new_gai_with_overrides(config.dns_overrides)
}
}
#[cfg(feature = "trust-dns")] #[cfg(feature = "trust-dns")]
true => { true => Arc::new(TrustDnsResolver::new().map_err(crate::error::builder)?),
if config.dns_overrides.is_empty() {
HttpConnector::new_trust_dns()?
} else {
HttpConnector::new_trust_dns_with_overrides(config.dns_overrides)?
}
}
#[cfg(not(feature = "trust-dns"))] #[cfg(not(feature = "trust-dns"))]
true => unreachable!("trust-dns shouldn't be enabled unless the feature is"), true => unreachable!("trust-dns shouldn't be enabled unless the feature is"),
}; };
if let Some(dns_resolver) = config.dns_resolver {
resolver = dns_resolver;
}
if !config.dns_overrides.is_empty() {
resolver = Arc::new(DnsResolverWithOverrides::new(
resolver,
config.dns_overrides,
));
}
let http = HttpConnector::new_with_resolver(DynResolver::new(resolver));
#[cfg(feature = "__tls")] #[cfg(feature = "__tls")]
match config.tls { match config.tls {
@@ -1340,6 +1343,16 @@ impl ClientBuilder {
.insert(domain.to_string(), addrs.to_vec()); .insert(domain.to_string(), addrs.to_vec());
self self
} }
/// Override the DNS resolver implementation.
///
/// Pass an `Arc` wrapping a trait object implementing `Resolve`.
/// Overrides for specific names passed to `resolve` and `resolve_to_addrs` will
/// still be applied on top of this resolver.
pub fn dns_resolver<R: Resolve + 'static>(mut self, resolver: Arc<R>) -> ClientBuilder {
self.config.dns_resolver = Some(resolver as _);
self
}
} }
type HyperClient = hyper::Client<Connector, super::body::ImplStream>; type HyperClient = hyper::Client<Connector, super::body::ImplStream>;

View File

@@ -1,162 +1,31 @@
use futures_util::future::Either;
#[cfg(feature = "__tls")] #[cfg(feature = "__tls")]
use http::header::HeaderValue; use http::header::HeaderValue;
use http::uri::{Authority, Scheme}; use http::uri::{Authority, Scheme};
use http::Uri; use http::Uri;
use hyper::client::connect::{ use hyper::client::connect::{Connected, Connection};
dns::{GaiResolver, Name},
Connected, Connection,
};
use hyper::service::Service; use hyper::service::Service;
#[cfg(feature = "native-tls-crate")] #[cfg(feature = "native-tls-crate")]
use native_tls_crate::{TlsConnector, TlsConnectorBuilder}; use native_tls_crate::{TlsConnector, TlsConnectorBuilder};
use tokio::io::{AsyncRead, AsyncWrite, ReadBuf}; use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
use pin_project_lite::pin_project; use pin_project_lite::pin_project;
use std::io::IoSlice; use std::future::Future;
use std::io::{self, IoSlice};
use std::net::IpAddr; use std::net::IpAddr;
use std::pin::Pin; use std::pin::Pin;
use std::sync::Arc; use std::sync::Arc;
use std::task::{Context, Poll}; use std::task::{Context, Poll};
use std::time::Duration; use std::time::Duration;
use std::{collections::HashMap, io};
use std::{future::Future, net::SocketAddr};
#[cfg(feature = "default-tls")] #[cfg(feature = "default-tls")]
use self::native_tls_conn::NativeTlsConn; use self::native_tls_conn::NativeTlsConn;
#[cfg(feature = "__rustls")] #[cfg(feature = "__rustls")]
use self::rustls_tls_conn::RustlsTlsConn; use self::rustls_tls_conn::RustlsTlsConn;
#[cfg(feature = "trust-dns")] use crate::dns::DynResolver;
use crate::dns::TrustDnsResolver;
use crate::error::BoxError; use crate::error::BoxError;
use crate::proxy::{Proxy, ProxyScheme}; use crate::proxy::{Proxy, ProxyScheme};
#[derive(Clone)] pub(crate) type HttpConnector = hyper::client::HttpConnector<DynResolver>;
pub(crate) enum HttpConnector {
Gai(hyper::client::HttpConnector),
GaiWithDnsOverrides(hyper::client::HttpConnector<DnsResolverWithOverrides<GaiResolver>>),
#[cfg(feature = "trust-dns")]
TrustDns(hyper::client::HttpConnector<TrustDnsResolver>),
#[cfg(feature = "trust-dns")]
TrustDnsWithOverrides(hyper::client::HttpConnector<DnsResolverWithOverrides<TrustDnsResolver>>),
}
impl HttpConnector {
pub(crate) fn new_gai() -> Self {
Self::Gai(hyper::client::HttpConnector::new())
}
pub(crate) fn new_gai_with_overrides(overrides: HashMap<String, Vec<SocketAddr>>) -> Self {
let gai = hyper::client::connect::dns::GaiResolver::new();
let overridden_resolver = DnsResolverWithOverrides::new(gai, overrides);
Self::GaiWithDnsOverrides(hyper::client::HttpConnector::new_with_resolver(
overridden_resolver,
))
}
#[cfg(feature = "trust-dns")]
pub(crate) fn new_trust_dns() -> crate::Result<HttpConnector> {
TrustDnsResolver::new()
.map(hyper::client::HttpConnector::new_with_resolver)
.map(Self::TrustDns)
.map_err(crate::error::builder)
}
#[cfg(feature = "trust-dns")]
pub(crate) fn new_trust_dns_with_overrides(
overrides: HashMap<String, Vec<SocketAddr>>,
) -> crate::Result<HttpConnector> {
TrustDnsResolver::new()
.map(|resolver| DnsResolverWithOverrides::new(resolver, overrides))
.map(hyper::client::HttpConnector::new_with_resolver)
.map(Self::TrustDnsWithOverrides)
.map_err(crate::error::builder)
}
}
macro_rules! impl_http_connector {
($(fn $name:ident(&mut self, $($par_name:ident: $par_type:ty),*)$( -> $return:ty)?;)+) => {
#[allow(dead_code)]
impl HttpConnector {
$(
fn $name(&mut self, $($par_name: $par_type),*)$( -> $return)? {
match self {
Self::Gai(resolver) => resolver.$name($($par_name),*),
Self::GaiWithDnsOverrides(resolver) => resolver.$name($($par_name),*),
#[cfg(feature = "trust-dns")]
Self::TrustDns(resolver) => resolver.$name($($par_name),*),
#[cfg(feature = "trust-dns")]
Self::TrustDnsWithOverrides(resolver) => resolver.$name($($par_name),*),
}
}
)+
}
};
}
impl_http_connector! {
fn set_local_address(&mut self, addr: Option<IpAddr>);
fn enforce_http(&mut self, is_enforced: bool);
fn set_nodelay(&mut self, nodelay: bool);
fn set_keepalive(&mut self, dur: Option<Duration>);
}
impl Service<Uri> for HttpConnector {
type Response = <hyper::client::HttpConnector as Service<Uri>>::Response;
type Error = <hyper::client::HttpConnector as Service<Uri>>::Error;
#[cfg(feature = "trust-dns")]
type Future =
Either<
Either<
<hyper::client::HttpConnector as Service<Uri>>::Future,
<hyper::client::HttpConnector<DnsResolverWithOverrides<GaiResolver>> as Service<
Uri,
>>::Future,
>,
Either<
<hyper::client::HttpConnector<TrustDnsResolver> as Service<Uri>>::Future,
<hyper::client::HttpConnector<DnsResolverWithOverrides<TrustDnsResolver>> as Service<Uri>>::Future
>
>;
#[cfg(not(feature = "trust-dns"))]
type Future =
Either<
Either<
<hyper::client::HttpConnector as Service<Uri>>::Future,
<hyper::client::HttpConnector<DnsResolverWithOverrides<GaiResolver>> as Service<
Uri,
>>::Future,
>,
Either<
<hyper::client::HttpConnector as Service<Uri>>::Future,
<hyper::client::HttpConnector as Service<Uri>>::Future,
>,
>;
fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
match self {
Self::Gai(resolver) => resolver.poll_ready(cx),
Self::GaiWithDnsOverrides(resolver) => resolver.poll_ready(cx),
#[cfg(feature = "trust-dns")]
Self::TrustDns(resolver) => resolver.poll_ready(cx),
#[cfg(feature = "trust-dns")]
Self::TrustDnsWithOverrides(resolver) => resolver.poll_ready(cx),
}
}
fn call(&mut self, dst: Uri) -> Self::Future {
match self {
Self::Gai(resolver) => Either::Left(Either::Left(resolver.call(dst))),
Self::GaiWithDnsOverrides(resolver) => Either::Left(Either::Right(resolver.call(dst))),
#[cfg(feature = "trust-dns")]
Self::TrustDns(resolver) => Either::Right(Either::Left(resolver.call(dst))),
#[cfg(feature = "trust-dns")]
Self::TrustDnsWithOverrides(resolver) => {
Either::Right(Either::Right(resolver.call(dst)))
}
}
}
}
#[derive(Clone)] #[derive(Clone)]
pub(crate) struct Connector { pub(crate) struct Connector {
@@ -960,103 +829,6 @@ mod socks {
} }
} }
pub(crate) mod itertools {
pub(crate) enum Either<A, B> {
Left(A),
Right(B),
}
impl<A, B> Iterator for Either<A, B>
where
A: Iterator,
B: Iterator<Item = <A as Iterator>::Item>,
{
type Item = <A as Iterator>::Item;
fn next(&mut self) -> Option<Self::Item> {
match self {
Either::Left(a) => a.next(),
Either::Right(b) => b.next(),
}
}
}
}
pin_project! {
pub(crate) struct WrappedResolverFuture<Fut> {
#[pin]
fut: Fut,
}
}
impl<Fut, FutOutput, FutError> std::future::Future for WrappedResolverFuture<Fut>
where
Fut: std::future::Future<Output = Result<FutOutput, FutError>>,
FutOutput: Iterator<Item = SocketAddr>,
{
type Output = Result<itertools::Either<FutOutput, std::vec::IntoIter<SocketAddr>>, FutError>;
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
let this = self.project();
this.fut
.poll(cx)
.map(|result| result.map(itertools::Either::Left))
}
}
#[derive(Clone)]
pub(crate) struct DnsResolverWithOverrides<Resolver>
where
Resolver: Clone,
{
dns_resolver: Resolver,
overrides: Arc<HashMap<String, Vec<SocketAddr>>>,
}
impl<Resolver: Clone> DnsResolverWithOverrides<Resolver> {
fn new(dns_resolver: Resolver, overrides: HashMap<String, Vec<SocketAddr>>) -> Self {
DnsResolverWithOverrides {
dns_resolver,
overrides: Arc::new(overrides),
}
}
}
impl<Resolver, Iter> Service<Name> for DnsResolverWithOverrides<Resolver>
where
Resolver: Service<Name, Response = Iter> + Clone,
Iter: Iterator<Item = SocketAddr>,
{
type Response = itertools::Either<Iter, std::vec::IntoIter<SocketAddr>>;
type Error = <Resolver as Service<Name>>::Error;
type Future = Either<
WrappedResolverFuture<<Resolver as Service<Name>>::Future>,
futures_util::future::Ready<
Result<itertools::Either<Iter, std::vec::IntoIter<SocketAddr>>, Self::Error>,
>,
>;
fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
self.dns_resolver.poll_ready(cx)
}
fn call(&mut self, name: Name) -> Self::Future {
match self.overrides.get(name.as_str()) {
Some(dest) => {
let fut = futures_util::future::ready(Ok(itertools::Either::Right(
dest.clone().into_iter(),
)));
Either::Right(fut)
}
None => {
let resolver_fut = self.dns_resolver.call(name);
let y = WrappedResolverFuture { fut: resolver_fut };
Either::Left(y)
}
}
}
}
mod verbose { mod verbose {
use hyper::client::connect::{Connected, Connection}; use hyper::client::connect::{Connected, Connection};
use std::fmt; use std::fmt;

32
src/dns/gai.rs Normal file
View File

@@ -0,0 +1,32 @@
use futures_util::future::FutureExt;
use hyper::client::connect::dns::{GaiResolver as HyperGaiResolver, Name};
use hyper::service::Service;
use crate::dns::{Addrs, Resolve, Resolving};
use crate::error::BoxError;
#[derive(Debug)]
pub struct GaiResolver(HyperGaiResolver);
impl GaiResolver {
pub fn new() -> Self {
Self(HyperGaiResolver::new())
}
}
impl Default for GaiResolver {
fn default() -> Self {
GaiResolver::new()
}
}
impl Resolve for GaiResolver {
fn resolve(&self, name: Name) -> Resolving {
let this = &mut self.0.clone();
Box::pin(Service::<Name>::call(this, name).map(|result| {
result
.map(|addrs| -> Addrs { Box::new(addrs) })
.map_err(|err| -> BoxError { Box::new(err) })
}))
}
}

9
src/dns/mod.rs Normal file
View File

@@ -0,0 +1,9 @@
//! DNS resolution
pub use resolve::{Addrs, Resolve, Resolving};
pub(crate) use resolve::{DnsResolverWithOverrides, DynResolver};
pub(crate) mod gai;
pub(crate) mod resolve;
#[cfg(feature = "trust-dns")]
pub(crate) mod trust_dns;

84
src/dns/resolve.rs Normal file
View File

@@ -0,0 +1,84 @@
use hyper::client::connect::dns::Name;
use hyper::service::Service;
use std::collections::HashMap;
use std::future::Future;
use std::net::SocketAddr;
use std::pin::Pin;
use std::sync::Arc;
use std::task::{Context, Poll};
use crate::error::BoxError;
/// Alias for an `Iterator` trait object over `SocketAddr`.
pub type Addrs = Box<dyn Iterator<Item = SocketAddr> + Send>;
/// Alias for the `Future` type returned by a DNS resolver.
pub type Resolving = Pin<Box<dyn Future<Output = Result<Addrs, BoxError>> + Send>>;
/// Trait for customizing DNS resolution in reqwest.
pub trait Resolve: Send + Sync {
/// Performs DNS resolution on a `Name`.
/// The return type is a future containing an iterator of `SocketAddr`.
///
/// It differs from `tower_service::Service<Name>` in several ways:
/// * It is assumed that `resolve` will always be ready to poll.
/// * It does not need a mutable reference to `self`.
/// * Since trait objects cannot make use of associated types, it requires
/// wrapping the returned `Future` and its contained `Iterator` with `Box`.
fn resolve(&self, name: Name) -> Resolving;
}
#[derive(Clone)]
pub(crate) struct DynResolver {
resolver: Arc<dyn Resolve>,
}
impl DynResolver {
pub(crate) fn new(resolver: Arc<dyn Resolve>) -> Self {
Self { resolver }
}
}
impl Service<Name> for DynResolver {
type Response = Addrs;
type Error = BoxError;
type Future = Resolving;
fn poll_ready(&mut self, _: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
Poll::Ready(Ok(()))
}
fn call(&mut self, name: Name) -> Self::Future {
self.resolver.resolve(name)
}
}
pub(crate) struct DnsResolverWithOverrides {
dns_resolver: Arc<dyn Resolve>,
overrides: Arc<HashMap<String, Vec<SocketAddr>>>,
}
impl DnsResolverWithOverrides {
pub(crate) fn new(
dns_resolver: Arc<dyn Resolve>,
overrides: HashMap<String, Vec<SocketAddr>>,
) -> Self {
DnsResolverWithOverrides {
dns_resolver,
overrides: Arc::new(overrides),
}
}
}
impl Resolve for DnsResolverWithOverrides {
fn resolve(&self, name: Name) -> Resolving {
match self.overrides.get(name.as_str()) {
Some(dest) => {
let addrs: Addrs = Box::new(dest.clone().into_iter());
Box::pin(futures_util::future::ready(Ok(addrs)))
}
None => self.dns_resolver.resolve(name),
}
}
}

View File

@@ -1,20 +1,20 @@
use std::future::Future; //! DNS resolution via the [trust_dns_resolver](https://github.com/bluejekyll/trust-dns) crate
use std::io;
use std::net::SocketAddr;
use std::pin::Pin;
use std::sync::Arc;
use std::task::{self, Poll};
use hyper::client::connect::dns as hyper_dns; use hyper::client::connect::dns::Name;
use hyper::service::Service;
use once_cell::sync::Lazy; use once_cell::sync::Lazy;
use tokio::sync::Mutex; use tokio::sync::Mutex;
pub use trust_dns_resolver::config::{ResolverConfig, ResolverOpts};
use trust_dns_resolver::{ use trust_dns_resolver::{
config::{ResolverConfig, ResolverOpts}, lookup_ip::LookupIpIntoIter, system_conf, AsyncResolver, TokioConnection,
lookup_ip::LookupIpIntoIter, TokioConnectionProvider, TokioHandle,
system_conf, AsyncResolver, TokioConnection, TokioConnectionProvider, TokioHandle,
}; };
use std::io;
use std::net::SocketAddr;
use std::sync::Arc;
use super::{Addrs, Resolve, Resolving};
use crate::error::BoxError; use crate::error::BoxError;
type SharedResolver = Arc<AsyncResolver<TokioConnection, TokioConnectionProvider>>; type SharedResolver = Arc<AsyncResolver<TokioConnection, TokioConnectionProvider>>;
@@ -22,22 +22,26 @@ type SharedResolver = Arc<AsyncResolver<TokioConnection, TokioConnectionProvider
static SYSTEM_CONF: Lazy<io::Result<(ResolverConfig, ResolverOpts)>> = static SYSTEM_CONF: Lazy<io::Result<(ResolverConfig, ResolverOpts)>> =
Lazy::new(|| system_conf::read_system_conf().map_err(io::Error::from)); Lazy::new(|| system_conf::read_system_conf().map_err(io::Error::from));
#[derive(Clone)] /// Wrapper around an `AsyncResolver`, which implements the `Resolve` trait.
#[derive(Debug, Clone)]
pub(crate) struct TrustDnsResolver { pub(crate) struct TrustDnsResolver {
state: Arc<Mutex<State>>, state: Arc<Mutex<State>>,
} }
pub(crate) struct SocketAddrs { struct SocketAddrs {
iter: LookupIpIntoIter, iter: LookupIpIntoIter,
} }
#[derive(Debug)]
enum State { enum State {
Init, Init,
Ready(SharedResolver), Ready(SharedResolver),
} }
impl TrustDnsResolver { impl TrustDnsResolver {
pub(crate) fn new() -> io::Result<Self> { /// Create a new resolver with the default configuration,
/// which reads from `/etc/resolve.conf`.
pub fn new() -> io::Result<Self> {
SYSTEM_CONF.as_ref().map_err(|e| { SYSTEM_CONF.as_ref().map_err(|e| {
io::Error::new(e.kind(), format!("error reading DNS system conf: {}", e)) io::Error::new(e.kind(), format!("error reading DNS system conf: {}", e))
})?; })?;
@@ -51,16 +55,8 @@ impl TrustDnsResolver {
} }
} }
impl Service<hyper_dns::Name> for TrustDnsResolver { impl Resolve for TrustDnsResolver {
type Response = SocketAddrs; fn resolve(&self, name: Name) -> Resolving {
type Error = BoxError;
type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
fn poll_ready(&mut self, _: &mut task::Context<'_>) -> Poll<Result<(), Self::Error>> {
Poll::Ready(Ok(()))
}
fn call(&mut self, name: hyper_dns::Name) -> Self::Future {
let resolver = self.clone(); let resolver = self.clone();
Box::pin(async move { Box::pin(async move {
let mut lock = resolver.state.lock().await; let mut lock = resolver.state.lock().await;
@@ -79,9 +75,10 @@ impl Service<hyper_dns::Name> for TrustDnsResolver {
drop(lock); drop(lock);
let lookup = resolver.lookup_ip(name.as_str()).await?; let lookup = resolver.lookup_ip(name.as_str()).await?;
Ok(SocketAddrs { let addrs: Addrs = Box::new(SocketAddrs {
iter: lookup.into_iter(), iter: lookup.into_iter(),
}) });
Ok(addrs)
}) })
} }
} }
@@ -99,6 +96,13 @@ async fn new_resolver() -> Result<SharedResolver, BoxError> {
.as_ref() .as_ref()
.expect("can't construct TrustDnsResolver if SYSTEM_CONF is error") .expect("can't construct TrustDnsResolver if SYSTEM_CONF is error")
.clone(); .clone();
new_resolver_with_config(config, opts)
}
fn new_resolver_with_config(
config: ResolverConfig,
opts: ResolverOpts,
) -> Result<SharedResolver, BoxError> {
let resolver = AsyncResolver::new(config, opts, TokioHandle)?; let resolver = AsyncResolver::new(config, opts, TokioHandle)?;
Ok(Arc::new(resolver)) Ok(Arc::new(resolver))
} }

View File

@@ -309,8 +309,7 @@ if_hyper! {
mod connect; mod connect;
#[cfg(feature = "cookies")] #[cfg(feature = "cookies")]
pub mod cookie; pub mod cookie;
#[cfg(feature = "trust-dns")] pub mod dns;
mod dns;
mod proxy; mod proxy;
pub mod redirect; pub mod redirect;
#[cfg(feature = "__tls")] #[cfg(feature = "__tls")]