diff --git a/src/async_impl/client.rs b/src/async_impl/client.rs index 2e80479..f7d6695 100644 --- a/src/async_impl/client.rs +++ b/src/async_impl/client.rs @@ -1,9 +1,9 @@ #[cfg(any(feature = "native-tls", feature = "__rustls",))] use std::any::Any; -use std::convert::TryInto; use std::net::IpAddr; use std::sync::Arc; use std::time::Duration; +use std::{collections::HashMap, convert::TryInto, net::SocketAddr}; use std::{fmt, str}; use bytes::Bytes; @@ -107,6 +107,7 @@ struct Config { trust_dns: bool, error: Option, https_only: bool, + dns_overrides: HashMap, } impl Default for ClientBuilder { @@ -164,6 +165,7 @@ impl ClientBuilder { #[cfg(feature = "cookies")] cookie_store: None, https_only: false, + dns_overrides: HashMap::new(), }, } } @@ -194,9 +196,21 @@ impl ClientBuilder { } let http = match config.trust_dns { - false => HttpConnector::new_gai(), + false => { + if config.dns_overrides.is_empty() { + HttpConnector::new_gai() + } else { + HttpConnector::new_gai_with_overrides(config.dns_overrides) + } + } #[cfg(feature = "trust-dns")] - true => HttpConnector::new_trust_dns()?, + true => { + if config.dns_overrides.is_empty() { + HttpConnector::new_trust_dns()? + } else { + HttpConnector::new_trust_dns_with_overrides(config.dns_overrides)? + } + } #[cfg(not(feature = "trust-dns"))] true => unreachable!("trust-dns shouldn't be enabled unless the feature is"), }; @@ -1037,6 +1051,19 @@ impl ClientBuilder { self.config.https_only = enabled; self } + + /// Override DNS resolution for specific domains to particular IP addresses. + /// + /// Warning + /// + /// Since the DNS protocol has no notion of ports, if you wish to send + /// traffic to a particular port you must include this port in the URL + /// itself, any port in the overridden addr will be ignored and traffic sent + /// to the conventional port for the given scheme (e.g. 80 for http). + pub fn resolve(mut self, domain: &str, addr: SocketAddr) -> ClientBuilder { + self.config.dns_overrides.insert(domain.to_string(), addr); + self + } } type HyperClient = hyper::Client; @@ -1350,6 +1377,10 @@ impl Config { { f.field("tls_backend", &self.tls); } + + if !self.dns_overrides.is_empty() { + f.field("dns_overrides", &self.dns_overrides); + } } } diff --git a/src/connect.rs b/src/connect.rs index 9f20347..375d402 100644 --- a/src/connect.rs +++ b/src/connect.rs @@ -3,21 +3,24 @@ use futures_util::future::Either; use http::header::HeaderValue; use http::uri::{Authority, Scheme}; use http::Uri; -use hyper::client::connect::{Connected, Connection}; +use hyper::client::connect::{ + dns::{GaiResolver, Name}, + Connected, Connection, +}; use hyper::service::Service; #[cfg(feature = "native-tls-crate")] use native_tls_crate::{TlsConnector, TlsConnectorBuilder}; use tokio::io::{AsyncRead, AsyncWrite, ReadBuf}; use pin_project_lite::pin_project; -use std::future::Future; -use std::io; use std::io::IoSlice; use std::net::IpAddr; use std::pin::Pin; use std::sync::Arc; use std::task::{Context, Poll}; use std::time::Duration; +use std::{collections::HashMap, io}; +use std::{future::Future, net::SocketAddr}; #[cfg(feature = "default-tls")] use self::native_tls_conn::NativeTlsConn; @@ -31,8 +34,11 @@ use crate::proxy::{Proxy, ProxyScheme}; #[derive(Clone)] pub(crate) enum HttpConnector { Gai(hyper::client::HttpConnector), + GaiWithDnsOverrides(hyper::client::HttpConnector>), #[cfg(feature = "trust-dns")] TrustDns(hyper::client::HttpConnector), + #[cfg(feature = "trust-dns")] + TrustDnsWithOverrides(hyper::client::HttpConnector>), } impl HttpConnector { @@ -40,6 +46,14 @@ impl HttpConnector { Self::Gai(hyper::client::HttpConnector::new()) } + pub(crate) fn new_gai_with_overrides(overrides: HashMap) -> Self { + let gai = hyper::client::connect::dns::GaiResolver::new(); + let overridden_resolver = DnsResolverWithOverrides::new(gai, overrides); + Self::GaiWithDnsOverrides(hyper::client::HttpConnector::new_with_resolver( + overridden_resolver, + )) + } + #[cfg(feature = "trust-dns")] pub(crate) fn new_trust_dns() -> crate::Result { TrustDnsResolver::new() @@ -47,6 +61,17 @@ impl HttpConnector { .map(Self::TrustDns) .map_err(crate::error::builder) } + + #[cfg(feature = "trust-dns")] + pub(crate) fn new_trust_dns_with_overrides( + overrides: HashMap, + ) -> crate::Result { + TrustDnsResolver::new() + .map(|resolver| DnsResolverWithOverrides::new(resolver, overrides)) + .map(hyper::client::HttpConnector::new_with_resolver) + .map(Self::TrustDnsWithOverrides) + .map_err(crate::error::builder) + } } macro_rules! impl_http_connector { @@ -57,8 +82,11 @@ macro_rules! impl_http_connector { fn $name(&mut self, $($par_name: $par_type),*)$( -> $return)? { match self { Self::Gai(resolver) => resolver.$name($($par_name),*), + Self::GaiWithDnsOverrides(resolver) => resolver.$name($($par_name),*), #[cfg(feature = "trust-dns")] Self::TrustDns(resolver) => resolver.$name($($par_name),*), + #[cfg(feature = "trust-dns")] + Self::TrustDnsWithOverrides(resolver) => resolver.$name($($par_name),*), } } )+ @@ -77,29 +105,55 @@ impl Service for HttpConnector { type Response = >::Response; type Error = >::Error; #[cfg(feature = "trust-dns")] - type Future = Either< - >::Future, - as Service>::Future, - >; + type Future = + Either< + Either< + >::Future, + > as Service< + Uri, + >>::Future, + >, + Either< + as Service>::Future, + > as Service>::Future + > + >; #[cfg(not(feature = "trust-dns"))] - type Future = Either< - >::Future, - >::Future, - >; + type Future = + Either< + Either< + >::Future, + > as Service< + Uri, + >>::Future, + >, + Either< + >::Future, + >::Future, + >, + >; fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll> { match self { Self::Gai(resolver) => resolver.poll_ready(cx), + Self::GaiWithDnsOverrides(resolver) => resolver.poll_ready(cx), #[cfg(feature = "trust-dns")] Self::TrustDns(resolver) => resolver.poll_ready(cx), + #[cfg(feature = "trust-dns")] + Self::TrustDnsWithOverrides(resolver) => resolver.poll_ready(cx), } } fn call(&mut self, dst: Uri) -> Self::Future { match self { - Self::Gai(resolver) => Either::Left(resolver.call(dst)), + Self::Gai(resolver) => Either::Left(Either::Left(resolver.call(dst))), + Self::GaiWithDnsOverrides(resolver) => Either::Left(Either::Right(resolver.call(dst))), #[cfg(feature = "trust-dns")] - Self::TrustDns(resolver) => Either::Right(resolver.call(dst)), + Self::TrustDns(resolver) => Either::Right(Either::Left(resolver.call(dst))), + #[cfg(feature = "trust-dns")] + Self::TrustDnsWithOverrides(resolver) => { + Either::Right(Either::Right(resolver.call(dst))) + } } } } @@ -908,6 +962,103 @@ mod socks { } } +pub(crate) mod itertools { + pub(crate) enum Either { + Left(A), + Right(B), + } + + impl Iterator for Either + where + A: Iterator, + B: Iterator::Item>, + { + type Item = ::Item; + + fn next(&mut self) -> Option { + match self { + Either::Left(a) => a.next(), + Either::Right(b) => b.next(), + } + } + } +} + +pin_project! { + pub(crate) struct WrappedResolverFuture { + #[pin] + fut: Fut, + } +} + +impl std::future::Future for WrappedResolverFuture +where + Fut: std::future::Future>, + FutOutput: Iterator, +{ + type Output = Result>, FutError>; + + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + let this = self.project(); + this.fut + .poll(cx) + .map(|result| result.map(itertools::Either::Left)) + } +} + +#[derive(Clone)] +pub(crate) struct DnsResolverWithOverrides +where + Resolver: Clone, +{ + dns_resolver: Resolver, + overrides: Arc>, +} + +impl DnsResolverWithOverrides { + fn new(dns_resolver: Resolver, overrides: HashMap) -> Self { + DnsResolverWithOverrides { + dns_resolver, + overrides: Arc::new(overrides), + } + } +} + +impl Service for DnsResolverWithOverrides +where + Resolver: Service + Clone, + Iter: Iterator, +{ + type Response = itertools::Either>; + type Error = >::Error; + type Future = Either< + WrappedResolverFuture<>::Future>, + futures_util::future::Ready< + Result>, Self::Error>, + >, + >; + + fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll> { + self.dns_resolver.poll_ready(cx) + } + + fn call(&mut self, name: Name) -> Self::Future { + match self.overrides.get(name.as_str()) { + Some(dest) => { + let fut = futures_util::future::ready(Ok(itertools::Either::Right( + std::iter::once(dest.to_owned()), + ))); + Either::Right(fut) + } + None => { + let resolver_fut = self.dns_resolver.call(name); + let y = WrappedResolverFuture { fut: resolver_fut }; + Either::Left(y) + } + } + } +} + mod verbose { use hyper::client::connect::{Connected, Connection}; use std::fmt; diff --git a/tests/client.rs b/tests/client.rs index 2fbddc2..4890b34 100644 --- a/tests/client.rs +++ b/tests/client.rs @@ -167,6 +167,54 @@ async fn body_pipe_response() { assert_eq!(res2.status(), reqwest::StatusCode::OK); } +#[tokio::test] +async fn overridden_dns_resolution_with_gai() { + let _ = env_logger::builder().is_test(true).try_init(); + let server = server::http(move |_req| async { http::Response::new("Hello".into()) }); + + let overridden_domain = "rust-lang.org"; + let url = format!( + "http://{}:{}/domain_override", + overridden_domain, + server.addr().port() + ); + let client = reqwest::Client::builder() + .resolve(overridden_domain, server.addr()) + .build() + .expect("client builder"); + let req = client.get(&url); + let res = req.send().await.expect("request"); + + assert_eq!(res.status(), reqwest::StatusCode::OK); + let text = res.text().await.expect("Failed to get text"); + assert_eq!("Hello", text); +} + +#[cfg(feature = "trust-dns")] +#[tokio::test] +async fn overridden_dns_resolution_with_trust_dns() { + let _ = env_logger::builder().is_test(true).try_init(); + let server = server::http(move |_req| async { http::Response::new("Hello".into()) }); + + let overridden_domain = "rust-lang.org"; + let url = format!( + "http://{}:{}/domain_override", + overridden_domain, + server.addr().port() + ); + let client = reqwest::Client::builder() + .resolve(overridden_domain, server.addr()) + .trust_dns(true) + .build() + .expect("client builder"); + let req = client.get(&url); + let res = req.send().await.expect("request"); + + assert_eq!(res.status(), reqwest::StatusCode::OK); + let text = res.text().await.expect("Failed to get text"); + assert_eq!("Hello", text); +} + #[cfg(any(feature = "native-tls", feature = "__rustls",))] #[test] fn use_preconfigured_tls_with_bogus_backend() {