From ae2d5216649c9be04074e7ae2c37b129e8768470 Mon Sep 17 00:00:00 2001 From: lpraneis Date: Mon, 19 Sep 2022 15:53:36 -0500 Subject: [PATCH] Add ability to specify multiple IP addresses for resolver overrides (#1622) This change allows the `ClientBuilder::resolve_to_addrs` method to accept a slice of `SocketAddr`s for overriding resolution for a single domain. Allowing multiple IPs more accurately reflects behavior of `getaddrinfo` and allows users to rely on hyper's happy eyeballs algorithm to connect to a host that can accept traffic on IPv4 and IPv6. --- src/async_impl/client.rs | 22 ++++++++++--- src/blocking/client.rs | 16 +++++++-- src/connect.rs | 16 ++++----- tests/client.rs | 70 ++++++++++++++++++++++++++++++++++++++++ 4 files changed, 110 insertions(+), 14 deletions(-) diff --git a/src/async_impl/client.rs b/src/async_impl/client.rs index 4517328..c6283a5 100644 --- a/src/async_impl/client.rs +++ b/src/async_impl/client.rs @@ -120,7 +120,7 @@ struct Config { trust_dns: bool, error: Option, https_only: bool, - dns_overrides: HashMap, + dns_overrides: HashMap>, } impl Default for ClientBuilder { @@ -1314,7 +1314,7 @@ impl ClientBuilder { self } - /// Override DNS resolution for specific domains to particular IP addresses. + /// Override DNS resolution for specific domains to a particular IP address. /// /// Warning /// @@ -1322,8 +1322,22 @@ impl ClientBuilder { /// traffic to a particular port you must include this port in the URL /// itself, any port in the overridden addr will be ignored and traffic sent /// to the conventional port for the given scheme (e.g. 80 for http). - pub fn resolve(mut self, domain: &str, addr: SocketAddr) -> ClientBuilder { - self.config.dns_overrides.insert(domain.to_string(), addr); + pub fn resolve(self, domain: &str, addr: SocketAddr) -> ClientBuilder { + self.resolve_to_addrs(domain, &[addr]) + } + + /// Override DNS resolution for specific domains to particular IP addresses. + /// + /// Warning + /// + /// Since the DNS protocol has no notion of ports, if you wish to send + /// traffic to a particular port you must include this port in the URL + /// itself, any port in the overridden addresses will be ignored and traffic sent + /// to the conventional port for the given scheme (e.g. 80 for http). + pub fn resolve_to_addrs(mut self, domain: &str, addrs: &[SocketAddr]) -> ClientBuilder { + self.config + .dns_overrides + .insert(domain.to_string(), addrs.to_vec()); self } } diff --git a/src/blocking/client.rs b/src/blocking/client.rs index 6e40e4d..bfba12a 100644 --- a/src/blocking/client.rs +++ b/src/blocking/client.rs @@ -757,7 +757,7 @@ impl ClientBuilder { self.with_inner(|inner| inner.https_only(enabled)) } - /// Override DNS resolution for specific domains to particular IP addresses. + /// Override DNS resolution for specific domains to a particular IP address. /// /// Warning /// @@ -766,7 +766,19 @@ impl ClientBuilder { /// itself, any port in the overridden addr will be ignored and traffic sent /// to the conventional port for the given scheme (e.g. 80 for http). pub fn resolve(self, domain: &str, addr: SocketAddr) -> ClientBuilder { - self.with_inner(|inner| inner.resolve(domain, addr)) + self.resolve_to_addrs(domain, &[addr]) + } + + /// Override DNS resolution for specific domains to particular IP addresses. + /// + /// Warning + /// + /// Since the DNS protocol has no notion of ports, if you wish to send + /// traffic to a particular port you must include this port in the URL + /// itself, any port in the overridden addresses will be ignored and traffic sent + /// to the conventional port for the given scheme (e.g. 80 for http). + pub fn resolve_to_addrs(self, domain: &str, addrs: &[SocketAddr]) -> ClientBuilder { + self.with_inner(|inner| inner.resolve_to_addrs(domain, addrs)) } // private diff --git a/src/connect.rs b/src/connect.rs index 4f2c3db..35843ec 100644 --- a/src/connect.rs +++ b/src/connect.rs @@ -46,7 +46,7 @@ impl HttpConnector { Self::Gai(hyper::client::HttpConnector::new()) } - pub(crate) fn new_gai_with_overrides(overrides: HashMap) -> Self { + pub(crate) fn new_gai_with_overrides(overrides: HashMap>) -> Self { let gai = hyper::client::connect::dns::GaiResolver::new(); let overridden_resolver = DnsResolverWithOverrides::new(gai, overrides); Self::GaiWithDnsOverrides(hyper::client::HttpConnector::new_with_resolver( @@ -64,7 +64,7 @@ impl HttpConnector { #[cfg(feature = "trust-dns")] pub(crate) fn new_trust_dns_with_overrides( - overrides: HashMap, + overrides: HashMap>, ) -> crate::Result { TrustDnsResolver::new() .map(|resolver| DnsResolverWithOverrides::new(resolver, overrides)) @@ -994,7 +994,7 @@ where Fut: std::future::Future>, FutOutput: Iterator, { - type Output = Result>, FutError>; + type Output = Result>, FutError>; fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { let this = self.project(); @@ -1010,11 +1010,11 @@ where Resolver: Clone, { dns_resolver: Resolver, - overrides: Arc>, + overrides: Arc>>, } impl DnsResolverWithOverrides { - fn new(dns_resolver: Resolver, overrides: HashMap) -> Self { + fn new(dns_resolver: Resolver, overrides: HashMap>) -> Self { DnsResolverWithOverrides { dns_resolver, overrides: Arc::new(overrides), @@ -1027,12 +1027,12 @@ where Resolver: Service + Clone, Iter: Iterator, { - type Response = itertools::Either>; + type Response = itertools::Either>; type Error = >::Error; type Future = Either< WrappedResolverFuture<>::Future>, futures_util::future::Ready< - Result>, Self::Error>, + Result>, Self::Error>, >, >; @@ -1044,7 +1044,7 @@ where match self.overrides.get(name.as_str()) { Some(dest) => { let fut = futures_util::future::ready(Ok(itertools::Either::Right( - std::iter::once(dest.to_owned()), + dest.clone().into_iter(), ))); Either::Right(fut) } diff --git a/tests/client.rs b/tests/client.rs index 8a663fd..56267e6 100644 --- a/tests/client.rs +++ b/tests/client.rs @@ -190,6 +190,40 @@ async fn overridden_dns_resolution_with_gai() { assert_eq!("Hello", text); } +#[tokio::test] +async fn overridden_dns_resolution_with_gai_multiple() { + let _ = env_logger::builder().is_test(true).try_init(); + let server = server::http(move |_req| async { http::Response::new("Hello".into()) }); + + let overridden_domain = "rust-lang.org"; + let url = format!( + "http://{}:{}/domain_override", + overridden_domain, + server.addr().port() + ); + // the server runs on IPv4 localhost, so provide both IPv4 and IPv6 and let the happy eyeballs + // algorithm decide which address to use. + let client = reqwest::Client::builder() + .resolve_to_addrs( + overridden_domain, + &[ + std::net::SocketAddr::new( + std::net::IpAddr::V6(std::net::Ipv6Addr::new(0, 0, 0, 0, 0, 0, 0, 1)), + server.addr().port(), + ), + server.addr(), + ], + ) + .build() + .expect("client builder"); + let req = client.get(&url); + let res = req.send().await.expect("request"); + + assert_eq!(res.status(), reqwest::StatusCode::OK); + let text = res.text().await.expect("Failed to get text"); + assert_eq!("Hello", text); +} + #[cfg(feature = "trust-dns")] #[tokio::test] async fn overridden_dns_resolution_with_trust_dns() { @@ -215,6 +249,42 @@ async fn overridden_dns_resolution_with_trust_dns() { assert_eq!("Hello", text); } +#[cfg(feature = "trust-dns")] +#[tokio::test] +async fn overridden_dns_resolution_with_trust_dns_multiple() { + let _ = env_logger::builder().is_test(true).try_init(); + let server = server::http(move |_req| async { http::Response::new("Hello".into()) }); + + let overridden_domain = "rust-lang.org"; + let url = format!( + "http://{}:{}/domain_override", + overridden_domain, + server.addr().port() + ); + // the server runs on IPv4 localhost, so provide both IPv4 and IPv6 and let the happy eyeballs + // algorithm decide which address to use. + let client = reqwest::Client::builder() + .resolve_to_addrs( + overridden_domain, + &[ + std::net::SocketAddr::new( + std::net::IpAddr::V6(std::net::Ipv6Addr::new(0, 0, 0, 0, 0, 0, 0, 1)), + server.addr().port(), + ), + server.addr(), + ], + ) + .trust_dns(true) + .build() + .expect("client builder"); + let req = client.get(&url); + let res = req.send().await.expect("request"); + + assert_eq!(res.status(), reqwest::StatusCode::OK); + let text = res.text().await.expect("Failed to get text"); + assert_eq!("Hello", text); +} + #[cfg(any(feature = "native-tls", feature = "__rustls",))] #[test] fn use_preconfigured_tls_with_bogus_backend() {