Add ability to specify multiple IP addresses for resolver overrides (#1622)
This change allows the `ClientBuilder::resolve_to_addrs` method to accept a slice of `SocketAddr`s for overriding resolution for a single domain. Allowing multiple IPs more accurately reflects behavior of `getaddrinfo` and allows users to rely on hyper's happy eyeballs algorithm to connect to a host that can accept traffic on IPv4 and IPv6.
This commit is contained in:
		| @@ -120,7 +120,7 @@ struct Config { | |||||||
|     trust_dns: bool, |     trust_dns: bool, | ||||||
|     error: Option<crate::Error>, |     error: Option<crate::Error>, | ||||||
|     https_only: bool, |     https_only: bool, | ||||||
|     dns_overrides: HashMap<String, SocketAddr>, |     dns_overrides: HashMap<String, Vec<SocketAddr>>, | ||||||
| } | } | ||||||
|  |  | ||||||
| impl Default for ClientBuilder { | impl Default for ClientBuilder { | ||||||
| @@ -1314,7 +1314,7 @@ impl ClientBuilder { | |||||||
|         self |         self | ||||||
|     } |     } | ||||||
|  |  | ||||||
|     /// Override DNS resolution for specific domains to particular IP addresses. |     /// Override DNS resolution for specific domains to a particular IP address. | ||||||
|     /// |     /// | ||||||
|     /// Warning |     /// Warning | ||||||
|     /// |     /// | ||||||
| @@ -1322,8 +1322,22 @@ impl ClientBuilder { | |||||||
|     /// traffic to a particular port you must include this port in the URL |     /// traffic to a particular port you must include this port in the URL | ||||||
|     /// itself, any port in the overridden addr will be ignored and traffic sent |     /// itself, any port in the overridden addr will be ignored and traffic sent | ||||||
|     /// to the conventional port for the given scheme (e.g. 80 for http). |     /// to the conventional port for the given scheme (e.g. 80 for http). | ||||||
|     pub fn resolve(mut self, domain: &str, addr: SocketAddr) -> ClientBuilder { |     pub fn resolve(self, domain: &str, addr: SocketAddr) -> ClientBuilder { | ||||||
|         self.config.dns_overrides.insert(domain.to_string(), addr); |         self.resolve_to_addrs(domain, &[addr]) | ||||||
|  |     } | ||||||
|  |  | ||||||
|  |     /// Override DNS resolution for specific domains to particular IP addresses. | ||||||
|  |     /// | ||||||
|  |     /// Warning | ||||||
|  |     /// | ||||||
|  |     /// Since the DNS protocol has no notion of ports, if you wish to send | ||||||
|  |     /// traffic to a particular port you must include this port in the URL | ||||||
|  |     /// itself, any port in the overridden addresses will be ignored and traffic sent | ||||||
|  |     /// to the conventional port for the given scheme (e.g. 80 for http). | ||||||
|  |     pub fn resolve_to_addrs(mut self, domain: &str, addrs: &[SocketAddr]) -> ClientBuilder { | ||||||
|  |         self.config | ||||||
|  |             .dns_overrides | ||||||
|  |             .insert(domain.to_string(), addrs.to_vec()); | ||||||
|         self |         self | ||||||
|     } |     } | ||||||
| } | } | ||||||
|   | |||||||
| @@ -757,7 +757,7 @@ impl ClientBuilder { | |||||||
|         self.with_inner(|inner| inner.https_only(enabled)) |         self.with_inner(|inner| inner.https_only(enabled)) | ||||||
|     } |     } | ||||||
|  |  | ||||||
|     /// Override DNS resolution for specific domains to particular IP addresses. |     /// Override DNS resolution for specific domains to a particular IP address. | ||||||
|     /// |     /// | ||||||
|     /// Warning |     /// Warning | ||||||
|     /// |     /// | ||||||
| @@ -766,7 +766,19 @@ impl ClientBuilder { | |||||||
|     /// itself, any port in the overridden addr will be ignored and traffic sent |     /// itself, any port in the overridden addr will be ignored and traffic sent | ||||||
|     /// to the conventional port for the given scheme (e.g. 80 for http). |     /// to the conventional port for the given scheme (e.g. 80 for http). | ||||||
|     pub fn resolve(self, domain: &str, addr: SocketAddr) -> ClientBuilder { |     pub fn resolve(self, domain: &str, addr: SocketAddr) -> ClientBuilder { | ||||||
|         self.with_inner(|inner| inner.resolve(domain, addr)) |         self.resolve_to_addrs(domain, &[addr]) | ||||||
|  |     } | ||||||
|  |  | ||||||
|  |     /// Override DNS resolution for specific domains to particular IP addresses. | ||||||
|  |     /// | ||||||
|  |     /// Warning | ||||||
|  |     /// | ||||||
|  |     /// Since the DNS protocol has no notion of ports, if you wish to send | ||||||
|  |     /// traffic to a particular port you must include this port in the URL | ||||||
|  |     /// itself, any port in the overridden addresses will be ignored and traffic sent | ||||||
|  |     /// to the conventional port for the given scheme (e.g. 80 for http). | ||||||
|  |     pub fn resolve_to_addrs(self, domain: &str, addrs: &[SocketAddr]) -> ClientBuilder { | ||||||
|  |         self.with_inner(|inner| inner.resolve_to_addrs(domain, addrs)) | ||||||
|     } |     } | ||||||
|  |  | ||||||
|     // private |     // private | ||||||
|   | |||||||
| @@ -46,7 +46,7 @@ impl HttpConnector { | |||||||
|         Self::Gai(hyper::client::HttpConnector::new()) |         Self::Gai(hyper::client::HttpConnector::new()) | ||||||
|     } |     } | ||||||
|  |  | ||||||
|     pub(crate) fn new_gai_with_overrides(overrides: HashMap<String, SocketAddr>) -> Self { |     pub(crate) fn new_gai_with_overrides(overrides: HashMap<String, Vec<SocketAddr>>) -> Self { | ||||||
|         let gai = hyper::client::connect::dns::GaiResolver::new(); |         let gai = hyper::client::connect::dns::GaiResolver::new(); | ||||||
|         let overridden_resolver = DnsResolverWithOverrides::new(gai, overrides); |         let overridden_resolver = DnsResolverWithOverrides::new(gai, overrides); | ||||||
|         Self::GaiWithDnsOverrides(hyper::client::HttpConnector::new_with_resolver( |         Self::GaiWithDnsOverrides(hyper::client::HttpConnector::new_with_resolver( | ||||||
| @@ -64,7 +64,7 @@ impl HttpConnector { | |||||||
|  |  | ||||||
|     #[cfg(feature = "trust-dns")] |     #[cfg(feature = "trust-dns")] | ||||||
|     pub(crate) fn new_trust_dns_with_overrides( |     pub(crate) fn new_trust_dns_with_overrides( | ||||||
|         overrides: HashMap<String, SocketAddr>, |         overrides: HashMap<String, Vec<SocketAddr>>, | ||||||
|     ) -> crate::Result<HttpConnector> { |     ) -> crate::Result<HttpConnector> { | ||||||
|         TrustDnsResolver::new() |         TrustDnsResolver::new() | ||||||
|             .map(|resolver| DnsResolverWithOverrides::new(resolver, overrides)) |             .map(|resolver| DnsResolverWithOverrides::new(resolver, overrides)) | ||||||
| @@ -994,7 +994,7 @@ where | |||||||
|     Fut: std::future::Future<Output = Result<FutOutput, FutError>>, |     Fut: std::future::Future<Output = Result<FutOutput, FutError>>, | ||||||
|     FutOutput: Iterator<Item = SocketAddr>, |     FutOutput: Iterator<Item = SocketAddr>, | ||||||
| { | { | ||||||
|     type Output = Result<itertools::Either<FutOutput, std::iter::Once<SocketAddr>>, FutError>; |     type Output = Result<itertools::Either<FutOutput, std::vec::IntoIter<SocketAddr>>, FutError>; | ||||||
|  |  | ||||||
|     fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> { |     fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> { | ||||||
|         let this = self.project(); |         let this = self.project(); | ||||||
| @@ -1010,11 +1010,11 @@ where | |||||||
|     Resolver: Clone, |     Resolver: Clone, | ||||||
| { | { | ||||||
|     dns_resolver: Resolver, |     dns_resolver: Resolver, | ||||||
|     overrides: Arc<HashMap<String, SocketAddr>>, |     overrides: Arc<HashMap<String, Vec<SocketAddr>>>, | ||||||
| } | } | ||||||
|  |  | ||||||
| impl<Resolver: Clone> DnsResolverWithOverrides<Resolver> { | impl<Resolver: Clone> DnsResolverWithOverrides<Resolver> { | ||||||
|     fn new(dns_resolver: Resolver, overrides: HashMap<String, SocketAddr>) -> Self { |     fn new(dns_resolver: Resolver, overrides: HashMap<String, Vec<SocketAddr>>) -> Self { | ||||||
|         DnsResolverWithOverrides { |         DnsResolverWithOverrides { | ||||||
|             dns_resolver, |             dns_resolver, | ||||||
|             overrides: Arc::new(overrides), |             overrides: Arc::new(overrides), | ||||||
| @@ -1027,12 +1027,12 @@ where | |||||||
|     Resolver: Service<Name, Response = Iter> + Clone, |     Resolver: Service<Name, Response = Iter> + Clone, | ||||||
|     Iter: Iterator<Item = SocketAddr>, |     Iter: Iterator<Item = SocketAddr>, | ||||||
| { | { | ||||||
|     type Response = itertools::Either<Iter, std::iter::Once<SocketAddr>>; |     type Response = itertools::Either<Iter, std::vec::IntoIter<SocketAddr>>; | ||||||
|     type Error = <Resolver as Service<Name>>::Error; |     type Error = <Resolver as Service<Name>>::Error; | ||||||
|     type Future = Either< |     type Future = Either< | ||||||
|         WrappedResolverFuture<<Resolver as Service<Name>>::Future>, |         WrappedResolverFuture<<Resolver as Service<Name>>::Future>, | ||||||
|         futures_util::future::Ready< |         futures_util::future::Ready< | ||||||
|             Result<itertools::Either<Iter, std::iter::Once<SocketAddr>>, Self::Error>, |             Result<itertools::Either<Iter, std::vec::IntoIter<SocketAddr>>, Self::Error>, | ||||||
|         >, |         >, | ||||||
|     >; |     >; | ||||||
|  |  | ||||||
| @@ -1044,7 +1044,7 @@ where | |||||||
|         match self.overrides.get(name.as_str()) { |         match self.overrides.get(name.as_str()) { | ||||||
|             Some(dest) => { |             Some(dest) => { | ||||||
|                 let fut = futures_util::future::ready(Ok(itertools::Either::Right( |                 let fut = futures_util::future::ready(Ok(itertools::Either::Right( | ||||||
|                     std::iter::once(dest.to_owned()), |                     dest.clone().into_iter(), | ||||||
|                 ))); |                 ))); | ||||||
|                 Either::Right(fut) |                 Either::Right(fut) | ||||||
|             } |             } | ||||||
|   | |||||||
| @@ -190,6 +190,40 @@ async fn overridden_dns_resolution_with_gai() { | |||||||
|     assert_eq!("Hello", text); |     assert_eq!("Hello", text); | ||||||
| } | } | ||||||
|  |  | ||||||
|  | #[tokio::test] | ||||||
|  | async fn overridden_dns_resolution_with_gai_multiple() { | ||||||
|  |     let _ = env_logger::builder().is_test(true).try_init(); | ||||||
|  |     let server = server::http(move |_req| async { http::Response::new("Hello".into()) }); | ||||||
|  |  | ||||||
|  |     let overridden_domain = "rust-lang.org"; | ||||||
|  |     let url = format!( | ||||||
|  |         "http://{}:{}/domain_override", | ||||||
|  |         overridden_domain, | ||||||
|  |         server.addr().port() | ||||||
|  |     ); | ||||||
|  |     // the server runs on IPv4 localhost, so provide both IPv4 and IPv6 and let the happy eyeballs | ||||||
|  |     // algorithm decide which address to use. | ||||||
|  |     let client = reqwest::Client::builder() | ||||||
|  |         .resolve_to_addrs( | ||||||
|  |             overridden_domain, | ||||||
|  |             &[ | ||||||
|  |                 std::net::SocketAddr::new( | ||||||
|  |                     std::net::IpAddr::V6(std::net::Ipv6Addr::new(0, 0, 0, 0, 0, 0, 0, 1)), | ||||||
|  |                     server.addr().port(), | ||||||
|  |                 ), | ||||||
|  |                 server.addr(), | ||||||
|  |             ], | ||||||
|  |         ) | ||||||
|  |         .build() | ||||||
|  |         .expect("client builder"); | ||||||
|  |     let req = client.get(&url); | ||||||
|  |     let res = req.send().await.expect("request"); | ||||||
|  |  | ||||||
|  |     assert_eq!(res.status(), reqwest::StatusCode::OK); | ||||||
|  |     let text = res.text().await.expect("Failed to get text"); | ||||||
|  |     assert_eq!("Hello", text); | ||||||
|  | } | ||||||
|  |  | ||||||
| #[cfg(feature = "trust-dns")] | #[cfg(feature = "trust-dns")] | ||||||
| #[tokio::test] | #[tokio::test] | ||||||
| async fn overridden_dns_resolution_with_trust_dns() { | async fn overridden_dns_resolution_with_trust_dns() { | ||||||
| @@ -215,6 +249,42 @@ async fn overridden_dns_resolution_with_trust_dns() { | |||||||
|     assert_eq!("Hello", text); |     assert_eq!("Hello", text); | ||||||
| } | } | ||||||
|  |  | ||||||
|  | #[cfg(feature = "trust-dns")] | ||||||
|  | #[tokio::test] | ||||||
|  | async fn overridden_dns_resolution_with_trust_dns_multiple() { | ||||||
|  |     let _ = env_logger::builder().is_test(true).try_init(); | ||||||
|  |     let server = server::http(move |_req| async { http::Response::new("Hello".into()) }); | ||||||
|  |  | ||||||
|  |     let overridden_domain = "rust-lang.org"; | ||||||
|  |     let url = format!( | ||||||
|  |         "http://{}:{}/domain_override", | ||||||
|  |         overridden_domain, | ||||||
|  |         server.addr().port() | ||||||
|  |     ); | ||||||
|  |     // the server runs on IPv4 localhost, so provide both IPv4 and IPv6 and let the happy eyeballs | ||||||
|  |     // algorithm decide which address to use. | ||||||
|  |     let client = reqwest::Client::builder() | ||||||
|  |         .resolve_to_addrs( | ||||||
|  |             overridden_domain, | ||||||
|  |             &[ | ||||||
|  |                 std::net::SocketAddr::new( | ||||||
|  |                     std::net::IpAddr::V6(std::net::Ipv6Addr::new(0, 0, 0, 0, 0, 0, 0, 1)), | ||||||
|  |                     server.addr().port(), | ||||||
|  |                 ), | ||||||
|  |                 server.addr(), | ||||||
|  |             ], | ||||||
|  |         ) | ||||||
|  |         .trust_dns(true) | ||||||
|  |         .build() | ||||||
|  |         .expect("client builder"); | ||||||
|  |     let req = client.get(&url); | ||||||
|  |     let res = req.send().await.expect("request"); | ||||||
|  |  | ||||||
|  |     assert_eq!(res.status(), reqwest::StatusCode::OK); | ||||||
|  |     let text = res.text().await.expect("Failed to get text"); | ||||||
|  |     assert_eq!("Hello", text); | ||||||
|  | } | ||||||
|  |  | ||||||
| #[cfg(any(feature = "native-tls", feature = "__rustls",))] | #[cfg(any(feature = "native-tls", feature = "__rustls",))] | ||||||
| #[test] | #[test] | ||||||
| fn use_preconfigured_tls_with_bogus_backend() { | fn use_preconfigured_tls_with_bogus_backend() { | ||||||
|   | |||||||
		Reference in New Issue
	
	Block a user