From 30ac01c1806236aef443a5ff9e119955941a188c Mon Sep 17 00:00:00 2001 From: Markus Westerlind Date: Wed, 4 Dec 2019 22:39:56 +0100 Subject: [PATCH] refactor(client): use async/await in HttpConnector (#2019) Closes #1984 --- src/client/connect/dns.rs | 110 ++++---- src/client/connect/http.rs | 528 ++++++++++++++++++------------------- 2 files changed, 326 insertions(+), 312 deletions(-) diff --git a/src/client/connect/dns.rs b/src/client/connect/dns.rs index 17b9396b..68f87714 100644 --- a/src/client/connect/dns.rs +++ b/src/client/connect/dns.rs @@ -21,18 +21,16 @@ //! Ok::<_, Infallible>(iter::once(IpAddr::from([127, 0, 0, 1]))) //! }); //! ``` -use std::{fmt, io, vec}; use std::error::Error; -use std::net::{ - IpAddr, Ipv4Addr, Ipv6Addr, - SocketAddr, ToSocketAddrs, - SocketAddrV4, SocketAddrV6, -}; +use std::net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr, SocketAddrV4, SocketAddrV6, ToSocketAddrs}; use std::str::FromStr; +use std::task::{self, Poll}; +use std::pin::Pin; +use std::future::Future; +use std::{fmt, io, vec}; use tokio::task::JoinHandle; use tower_service::Service; -use crate::common::{Future, Pin, Poll, task}; pub(super) use self::sealed::Resolve; @@ -60,9 +58,7 @@ pub struct GaiFuture { impl Name { pub(super) fn new(host: String) -> Name { - Name { - host, - } + Name { host } } /// View the hostname as a string slice. @@ -104,13 +100,10 @@ impl fmt::Display for InvalidNameError { impl Error for InvalidNameError {} - impl GaiResolver { /// Construct a new `GaiResolver`. pub fn new() -> Self { - GaiResolver { - _priv: (), - } + GaiResolver { _priv: () } } } @@ -126,13 +119,12 @@ impl Service for GaiResolver { fn call(&mut self, name: Name) -> Self::Future { let blocking = tokio::task::spawn_blocking(move || { debug!("resolving host={:?}", name.host); - (&*name.host, 0).to_socket_addrs() + (&*name.host, 0) + .to_socket_addrs() .map(|i| IpAddrs { iter: i }) }); - GaiFuture { - inner: blocking, - } + GaiFuture { inner: blocking } } } @@ -180,37 +172,46 @@ pub(super) struct IpAddrs { impl IpAddrs { pub(super) fn new(addrs: Vec) -> Self { - IpAddrs { iter: addrs.into_iter() } + IpAddrs { + iter: addrs.into_iter(), + } } pub(super) fn try_parse(host: &str, port: u16) -> Option { if let Ok(addr) = host.parse::() { let addr = SocketAddrV4::new(addr, port); - return Some(IpAddrs { iter: vec![SocketAddr::V4(addr)].into_iter() }) + return Some(IpAddrs { + iter: vec![SocketAddr::V4(addr)].into_iter(), + }); } let host = host.trim_start_matches('[').trim_end_matches(']'); if let Ok(addr) = host.parse::() { let addr = SocketAddrV6::new(addr, port, 0, 0); - return Some(IpAddrs { iter: vec![SocketAddr::V6(addr)].into_iter() }) + return Some(IpAddrs { + iter: vec![SocketAddr::V6(addr)].into_iter(), + }); } None } pub(super) fn split_by_preference(self, local_addr: Option) -> (IpAddrs, IpAddrs) { if let Some(local_addr) = local_addr { - let preferred = self.iter + let preferred = self + .iter .filter(|addr| addr.is_ipv6() == local_addr.is_ipv6()) .collect(); (IpAddrs::new(preferred), IpAddrs::new(vec![])) } else { - let preferring_v6 = self.iter + let preferring_v6 = self + .iter .as_slice() .first() .map(SocketAddr::is_ipv6) .unwrap_or(false); - let (preferred, fallback) = self.iter + let (preferred, fallback) = self + .iter .partition::, _>(|addr| addr.is_ipv6() == preferring_v6); (IpAddrs::new(preferred), IpAddrs::new(fallback)) @@ -281,8 +282,15 @@ impl Future for TokioThreadpoolGaiFuture { type Output = Result; fn poll(self: Pin<&mut Self>, _cx: &mut task::Context<'_>) -> Poll { - match ready!(tokio_executor::threadpool::blocking(|| (self.name.as_str(), 0).to_socket_addrs())) { - Ok(Ok(iter)) => Poll::Ready(Ok(GaiAddrs { inner: IpAddrs { iter } })), + match ready!(tokio_executor::threadpool::blocking(|| ( + self.name.as_str(), + 0 + ) + .to_socket_addrs())) + { + Ok(Ok(iter)) => Poll::Ready(Ok(GaiAddrs { + inner: IpAddrs { iter }, + })), Ok(Err(e)) => Poll::Ready(Err(e)), // a BlockingError, meaning not on a tokio_executor::threadpool :( Err(e) => Poll::Ready(Err(io::Error::new(io::ErrorKind::Other, e))), @@ -292,15 +300,15 @@ impl Future for TokioThreadpoolGaiFuture { */ mod sealed { - use tower_service::Service; - use crate::common::{Future, Poll, task}; use super::{IpAddr, Name}; + use crate::common::{task, Future, Poll}; + use tower_service::Service; // "Trait alias" for `Service` pub trait Resolve { - type Addrs: Iterator; + type Addrs: Iterator; type Error: Into>; - type Future: Future>; + type Future: Future>; fn poll_ready(&mut self, cx: &mut task::Context<'_>) -> Poll>; fn resolve(&mut self, name: Name) -> Self::Future; @@ -309,7 +317,7 @@ mod sealed { impl Resolve for S where S: Service, - S::Response: Iterator, + S::Response: Iterator, S::Error: Into>, { type Addrs = S::Response; @@ -326,33 +334,49 @@ mod sealed { } } +pub(crate) async fn resolve(resolver: &mut R, name: Name) -> Result +where + R: Resolve, +{ + futures_util::future::poll_fn(|cx| resolver.poll_ready(cx)).await?; + resolver.resolve(name).await +} + #[cfg(test)] mod tests { - use std::net::{Ipv4Addr, Ipv6Addr}; use super::*; + use std::net::{Ipv4Addr, Ipv6Addr}; #[test] fn test_ip_addrs_split_by_preference() { let v4_addr = (Ipv4Addr::new(127, 0, 0, 1), 80).into(); let v6_addr = (Ipv6Addr::new(0, 0, 0, 0, 0, 0, 0, 1), 80).into(); - let (mut preferred, mut fallback) = - IpAddrs { iter: vec![v4_addr, v6_addr].into_iter() }.split_by_preference(None); + let (mut preferred, mut fallback) = IpAddrs { + iter: vec![v4_addr, v6_addr].into_iter(), + } + .split_by_preference(None); assert!(preferred.next().unwrap().is_ipv4()); assert!(fallback.next().unwrap().is_ipv6()); - let (mut preferred, mut fallback) = - IpAddrs { iter: vec![v6_addr, v4_addr].into_iter() }.split_by_preference(None); + let (mut preferred, mut fallback) = IpAddrs { + iter: vec![v6_addr, v4_addr].into_iter(), + } + .split_by_preference(None); assert!(preferred.next().unwrap().is_ipv6()); assert!(fallback.next().unwrap().is_ipv4()); - let (mut preferred, fallback) = - IpAddrs { iter: vec![v4_addr, v6_addr].into_iter() }.split_by_preference(Some(v4_addr.ip())); + let (mut preferred, fallback) = IpAddrs { + iter: vec![v4_addr, v6_addr].into_iter(), + } + .split_by_preference(Some(v4_addr.ip())); assert!(preferred.next().unwrap().is_ipv4()); assert!(fallback.is_empty()); - let (mut preferred, fallback) = - IpAddrs { iter: vec![v4_addr, v6_addr].into_iter() }.split_by_preference(Some(v6_addr.ip())); + let (mut preferred, fallback) = IpAddrs { + iter: vec![v4_addr, v6_addr].into_iter(), + } + .split_by_preference(Some(v6_addr.ip())); assert!(preferred.next().unwrap().is_ipv6()); assert!(fallback.is_empty()); } @@ -370,10 +394,8 @@ mod tests { let uri = ::http::Uri::from_static("http://[::1]:8080/"); let dst = super::super::Destination { uri }; - let mut addrs = IpAddrs::try_parse( - dst.host(), - dst.port().expect("port") - ).expect("try_parse"); + let mut addrs = + IpAddrs::try_parse(dst.host(), dst.port().expect("port")).expect("try_parse"); let expected = "[::1]:8080".parse::().expect("expected"); diff --git a/src/client/connect/http.rs b/src/client/connect/http.rs index d505d34b..f9a08d5a 100644 --- a/src/client/connect/http.rs +++ b/src/client/connect/http.rs @@ -1,26 +1,23 @@ -use std::fmt; use std::error::Error as StdError; +use std::future::Future; +use std::pin::Pin; +use std::task::{self, Poll}; +use std::fmt; use std::io; -use std::mem; use std::net::{IpAddr, SocketAddr}; use std::sync::Arc; use std::time::Duration; +use futures_util::future::Either; use http::uri::{Scheme, Uri}; -use futures_util::{TryFutureExt}; use net2::TcpBuilder; -use pin_project::{pin_project, project}; use tokio::net::TcpStream; use tokio::time::Delay; -use crate::common::{Future, Pin, Poll, task}; +use super::dns::{self, resolve, GaiResolver, Resolve}; use super::{Connected, Destination}; -use super::dns::{self, GaiResolver, Resolve}; //#[cfg(feature = "runtime")] use super::dns::TokioThreadpoolGaiResolver; -// TODO: unbox me? -type ConnectFuture = Pin> + Send>>; - /// A connector for the `http` scheme. /// /// Performs DNS resolution in a thread pool, and then connects over TCP. @@ -102,7 +99,6 @@ impl HttpConnector { } */ - impl HttpConnector { /// Construct a new HttpConnector. /// @@ -223,35 +219,22 @@ static INVALID_NOT_HTTP: &str = "invalid URL, scheme is not http"; static INVALID_MISSING_SCHEME: &str = "invalid URL, scheme is missing"; static INVALID_MISSING_HOST: &str = "invalid URL, host is missing"; -impl HttpConnector { - fn invalid_url(&self, msg: impl Into>) -> HttpConnecting { - HttpConnecting { - config: self.config.clone(), - state: State::Error(Some(ConnectError { - msg: msg.into(), - cause: None, - })), - port: 0, - } - } -} - // R: Debug required for now to allow adding it to debug output later... impl fmt::Debug for HttpConnector { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - f.debug_struct("HttpConnector") - .finish() + f.debug_struct("HttpConnector").finish() } } impl tower_service::Service for HttpConnector where - R: Resolve + Clone + Send + Sync, + R: Resolve + Clone + Send + Sync + 'static, R::Future: Send, { type Response = (TcpStream, Connected); type Error = ConnectError; - type Future = HttpConnecting; + type Future = + Pin> + Send + 'static>>; fn poll_ready(&mut self, cx: &mut task::Context<'_>) -> Poll> { ready!(self.resolver.poll_ready(cx)).map_err(ConnectError::dns)?; @@ -259,6 +242,19 @@ where } fn call(&mut self, dst: Destination) -> Self::Future { + let mut self_ = self.clone(); + Box::pin(async move { self_.call_async(dst).await }) + } +} + +impl HttpConnector +where + R: Resolve, +{ + async fn call_async( + &mut self, + dst: Destination, + ) -> Result<(TcpStream, Connected), ConnectError> { trace!( "Http::connect; scheme={}, host={}, port={:?}", dst.scheme(), @@ -268,26 +264,85 @@ where if self.config.enforce_http { if dst.uri.scheme() != Some(&Scheme::HTTP) { - return self.invalid_url(INVALID_NOT_HTTP); + return Err(ConnectError { + msg: INVALID_NOT_HTTP.into(), + cause: None, + }); } } else if dst.uri.scheme().is_none() { - return self.invalid_url(INVALID_MISSING_SCHEME); + return Err(ConnectError { + msg: INVALID_MISSING_SCHEME.into(), + cause: None, + }); } let host = match dst.uri.host() { Some(s) => s, - None => return self.invalid_url(INVALID_MISSING_HOST), + None => { + return Err(ConnectError { + msg: INVALID_MISSING_HOST.into(), + cause: None, + }) + } }; let port = match dst.uri.port() { Some(port) => port.as_u16(), None => if dst.uri.scheme() == Some(&Scheme::HTTPS) { 443 } else { 80 }, }; - HttpConnecting { - config: self.config.clone(), - state: State::Lazy(self.resolver.clone(), host.into()), - port, + let config = &self.config; + + // If the host is already an IP addr (v4 or v6), + // skip resolving the dns and start connecting right away. + let addrs = if let Some(addrs) = dns::IpAddrs::try_parse(host, port) { + addrs + } else { + let addrs = resolve(&mut self.resolver, dns::Name::new(host.into())) + .await + .map_err(ConnectError::dns)?; + let addrs = addrs.map(|addr| SocketAddr::new(addr, port)).collect(); + dns::IpAddrs::new(addrs) + }; + + let c = ConnectingTcp::new( + config.local_address, + addrs, + config.connect_timeout, + config.happy_eyeballs_timeout, + config.reuse_address, + ); + + let sock = c + .connect() + .await + .map_err(ConnectError::m("tcp connect error"))?; + + if let Some(dur) = config.keep_alive_timeout { + sock.set_keepalive(Some(dur)) + .map_err(ConnectError::m("tcp set_keepalive error"))?; } + + if let Some(size) = config.send_buffer_size { + sock.set_send_buffer_size(size) + .map_err(ConnectError::m("tcp set_send_buffer_size error"))?; + } + + if let Some(size) = config.recv_buffer_size { + sock.set_recv_buffer_size(size) + .map_err(ConnectError::m("tcp set_recv_buffer_size error"))?; + } + + sock.set_nodelay(config.nodelay) + .map_err(ConnectError::m("tcp set_nodelay error"))?; + + let extra = HttpInfo { + remote_addr: sock + .peer_addr() + .map_err(ConnectError::m("tcp peer_addr error"))?, + }; + let connected = Connected::new().extra(extra); + + Ok((sock, connected)) } } @@ -298,14 +353,16 @@ where { type Response = TcpStream; type Error = ConnectError; - type Future = Pin> + Send + 'static>>; + type Future = + Pin> + Send + 'static>>; fn poll_ready(&mut self, cx: &mut task::Context<'_>) -> Poll> { tower_service::Service::::poll_ready(self, cx) } fn call(&mut self, uri: Uri) -> Self::Future { - Box::pin(self.call(Destination { uri }).map_ok(|(s, _)| s)) + let mut self_ = self.clone(); + Box::pin(async move { self_.call_async(Destination { uri }).await.map(|(s, _)| s) }) } } @@ -346,9 +403,7 @@ impl ConnectError { S: Into>, E: Into>, { - move |cause| { - ConnectError::new(msg, cause) - } + move |cause| ConnectError::new(msg, cause) } } @@ -383,96 +438,6 @@ impl StdError for ConnectError { } } -/// A Future representing work to connect to a URL. -#[must_use = "futures do nothing unless polled"] -#[pin_project] -pub struct HttpConnecting { - config: Arc, - #[pin] - state: State, - port: u16, -} - -#[pin_project] -enum State { - Lazy(R, String), - Resolving(#[pin] R::Future), - Connecting(ConnectingTcp), - Error(Option), -} - -impl Future for HttpConnecting { - type Output = Result<(TcpStream, Connected), ConnectError>; - - #[project] - fn poll(self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll { - let mut me = self.project(); - let config: &Config = &me.config; - loop { - let state; - #[project] - match me.state.as_mut().project() { - State::Lazy(ref mut resolver, ref mut host) => { - // If the host is already an IP addr (v4 or v6), - // skip resolving the dns and start connecting right away. - if let Some(addrs) = dns::IpAddrs::try_parse(host, *me.port) { - state = State::Connecting(ConnectingTcp::new( - config.local_address, addrs, config.connect_timeout, config.happy_eyeballs_timeout, config.reuse_address)); - } else { - ready!(resolver.poll_ready(cx)).map_err(ConnectError::dns)?; - let name = dns::Name::new(mem::replace(host, String::new())); - state = State::Resolving(resolver.resolve(name)); - } - }, - State::Resolving(future) => { - let addrs = ready!(future.poll(cx)).map_err(ConnectError::dns)?; - let port = *me.port; - let addrs = addrs - .map(|addr| SocketAddr::new(addr, port)) - .collect(); - let addrs = dns::IpAddrs::new(addrs); - state = State::Connecting(ConnectingTcp::new( - config.local_address, addrs, config.connect_timeout, config.happy_eyeballs_timeout, config.reuse_address)); - }, - State::Connecting(ref mut c) => { - let sock = ready!(c.poll(cx)) - .map_err(ConnectError::m("tcp connect error"))?; - - if let Some(dur) = config.keep_alive_timeout { - sock.set_keepalive(Some(dur)).map_err(ConnectError::m("tcp set_keepalive error"))?; - } - - if let Some(size) = config.send_buffer_size { - sock.set_send_buffer_size(size).map_err(ConnectError::m("tcp set_send_buffer_size error"))?; - } - - if let Some(size) = config.recv_buffer_size { - sock.set_recv_buffer_size(size).map_err(ConnectError::m("tcp set_recv_buffer_size error"))?; - } - - sock.set_nodelay(config.nodelay).map_err(ConnectError::m("tcp set_nodelay error"))?; - - let extra = HttpInfo { - remote_addr: sock.peer_addr().map_err(ConnectError::m("tcp peer_addr error"))?, - }; - let connected = Connected::new() - .extra(extra); - - return Poll::Ready(Ok((sock, connected))); - }, - State::Error(ref mut e) => return Poll::Ready(Err(e.take().expect("polled more than once"))), - } - me.state.set(state); - } - } -} - -impl fmt::Debug for HttpConnecting { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - f.pad("HttpConnecting") - } -} - struct ConnectingTcp { local_addr: Option, preferred: ConnectingTcpRemote, @@ -527,7 +492,6 @@ struct ConnectingTcpFallback { struct ConnectingTcpRemote { addrs: dns::IpAddrs, connect_timeout: Option, - current: Option, } impl ConnectingTcpRemote { @@ -537,45 +501,39 @@ impl ConnectingTcpRemote { Self { addrs, connect_timeout, - current: None, } } } impl ConnectingTcpRemote { - fn poll( + async fn connect( &mut self, - cx: &mut task::Context<'_>, local_addr: &Option, reuse_address: bool, - ) -> Poll> { + ) -> io::Result { let mut err = None; - loop { - if let Some(ref mut current) = self.current { - match current.as_mut().poll(cx) { - Poll::Ready(Ok(tcp)) => { - debug!("connected to {:?}", tcp.peer_addr().ok()); - return Poll::Ready(Ok(tcp)); - }, - Poll::Pending => return Poll::Pending, - Poll::Ready(Err(e)) => { - trace!("connect error {:?}", e); - err = Some(e); - if let Some(addr) = self.addrs.next() { - debug!("connecting to {}", addr); - *current = connect(&addr, local_addr, reuse_address, self.connect_timeout)?; - continue; - } - } + for addr in &mut self.addrs { + debug!("connecting to {}", addr); + match connect( + &addr, + local_addr, + reuse_address, + self.connect_timeout, + )? + .await + { + Ok(tcp) => { + debug!("connected to {:?}", tcp.peer_addr().ok()); + return Ok(tcp); + } + Err(e) => { + trace!("connect error {:?}", e); + err = Some(e); } - } else if let Some(addr) = self.addrs.next() { - debug!("connecting to {}", addr); - self.current = Some(connect(&addr, local_addr, reuse_address, self.connect_timeout)?); - continue; } - - return Poll::Ready(Err(err.take().expect("missing connect error"))); } + + return Err(err.take().expect("missing connect error")); } } @@ -584,7 +542,7 @@ fn connect( local_addr: &Option, reuse_address: bool, connect_timeout: Option, -) -> io::Result { +) -> io::Result>> { let builder = match addr { &SocketAddr::V4(_) => TcpBuilder::new_v4()?, &SocketAddr::V6(_) => TcpBuilder::new_v6()?, @@ -600,12 +558,8 @@ fn connect( } else if cfg!(windows) { // Windows requires a socket be bound before calling connect let any: SocketAddr = match addr { - &SocketAddr::V4(_) => { - ([0, 0, 0, 0], 0).into() - }, - &SocketAddr::V6(_) => { - ([0, 0, 0, 0, 0, 0, 0, 0], 0).into() - } + &SocketAddr::V4(_) => ([0, 0, 0, 0], 0).into(), + &SocketAddr::V6(_) => ([0, 0, 0, 0, 0, 0, 0, 0], 0).into(), }; builder.bind(any)?; } @@ -614,56 +568,58 @@ fn connect( let std_tcp = builder.to_tcp_stream()?; - Ok(Box::pin(async move { + Ok(async move { let connect = TcpStream::connect_std(std_tcp, &addr); match connect_timeout { Some(dur) => match tokio::time::timeout(dur, connect).await { Ok(Ok(s)) => Ok(s), Ok(Err(e)) => Err(e), Err(e) => Err(io::Error::new(io::ErrorKind::TimedOut, e)), - } + }, None => connect.await, } - })) + }) } impl ConnectingTcp { - fn poll(&mut self, cx: &mut task::Context<'_>) -> Poll> { - match self.fallback.take() { - None => self.preferred.poll(cx, &self.local_addr, self.reuse_address), - Some(mut fallback) => match self.preferred.poll(cx, &self.local_addr, self.reuse_address) { - Poll::Ready(Ok(stream)) => { - // Preferred successful - drop fallback. - Poll::Ready(Ok(stream)) - } - Poll::Pending => match Pin::new(&mut fallback.delay).poll(cx) { - Poll::Ready(()) => match fallback.remote.poll(cx, &self.local_addr, self.reuse_address) { - Poll::Ready(Ok(stream)) => { - // Fallback successful - drop current preferred, - // but keep fallback as new preferred. - self.preferred = fallback.remote; - Poll::Ready(Ok(stream)) + async fn connect(mut self) -> io::Result { + let Self { + ref local_addr, + reuse_address, + .. + } = self; + match self.fallback { + None => { + self.preferred + .connect(local_addr, reuse_address) + .await + } + Some(mut fallback) => { + let preferred_fut = self.preferred.connect(local_addr, reuse_address); + futures_util::pin_mut!(preferred_fut); + + let fallback_fut = fallback.remote.connect(local_addr, reuse_address); + futures_util::pin_mut!(fallback_fut); + + let (result, future) = + match futures_util::future::select(preferred_fut, fallback.delay).await { + Either::Left((result, _fallback_delay)) => { + (result, Either::Right(fallback_fut)) } - Poll::Pending => { - // Neither preferred nor fallback are ready. - self.fallback = Some(fallback); - Poll::Pending + Either::Right(((), preferred_fut)) => { + // Delay is done, start polling both the preferred and the fallback + futures_util::future::select(preferred_fut, fallback_fut) + .await + .factor_first() } - Poll::Ready(Err(_)) => { - // Fallback failed - resume with preferred only. - Poll::Pending - } - }, - Poll::Pending => { - // Too early to attempt fallback. - self.fallback = Some(fallback); - Poll::Pending - } - } - Poll::Ready(Err(_)) => { - // Preferred failed - use fallback as new preferred. - self.preferred = fallback.remote; - self.preferred.poll(cx, &self.local_addr, self.reuse_address) + }; + + if let Err(_) = result { + // Fallback to the remaining future (could be preferred or fallback) + // if we get an error + future.await + } else { + result } } } @@ -674,10 +630,13 @@ impl ConnectingTcp { mod tests { use std::io; - use super::{Connected, Destination, HttpConnector}; use super::super::sealed::Connect; + use super::{Connected, Destination, HttpConnector}; - async fn connect(connector: C, dst: Destination) -> Result<(C::Transport, Connected), C::Error> + async fn connect( + connector: C, + dst: Destination, + ) -> Result<(C::Transport, Connected), C::Error> where C: Connect, { @@ -687,9 +646,7 @@ mod tests { #[tokio::test] async fn test_errors_enforce_http() { let uri = "https://example.domain/foo/bar?baz".parse().unwrap(); - let dst = Destination { - uri, - }; + let dst = Destination { uri }; let connector = HttpConnector::new(); let err = connect(connector, dst).await.unwrap_err(); @@ -699,9 +656,7 @@ mod tests { #[tokio::test] async fn test_errors_missing_scheme() { let uri = "example.domain".parse().unwrap(); - let dst = Destination { - uri, - }; + let dst = Destination { uri }; let mut connector = HttpConnector::new(); connector.enforce_http(false); @@ -712,12 +667,9 @@ mod tests { #[test] #[cfg_attr(not(feature = "__internal_happy_eyeballs_tests"), ignore)] fn client_happy_eyeballs() { - use std::future::Future; use std::net::{IpAddr, Ipv4Addr, Ipv6Addr, TcpListener}; - use std::task::Poll; use std::time::{Duration, Instant}; - use crate::common::{Pin, task}; use super::dns; use super::ConnectingTcp; @@ -740,40 +692,81 @@ mod tests { let scenarios = &[ // Fast primary, without fallback. - (&[local_ipv4_addr()][..], - 4, local_timeout, false), - (&[local_ipv6_addr()][..], - 6, local_timeout, false), - + (&[local_ipv4_addr()][..], 4, local_timeout, false), + (&[local_ipv6_addr()][..], 6, local_timeout, false), // Fast primary, with (unused) fallback. - (&[local_ipv4_addr(), local_ipv6_addr()][..], - 4, local_timeout, false), - (&[local_ipv6_addr(), local_ipv4_addr()][..], - 6, local_timeout, false), - + ( + &[local_ipv4_addr(), local_ipv6_addr()][..], + 4, + local_timeout, + false, + ), + ( + &[local_ipv6_addr(), local_ipv4_addr()][..], + 6, + local_timeout, + false, + ), // Unreachable + fast primary, without fallback. - (&[unreachable_ipv4_addr(), local_ipv4_addr()][..], - 4, unreachable_v4_timeout, false), - (&[unreachable_ipv6_addr(), local_ipv6_addr()][..], - 6, unreachable_v6_timeout, false), - + ( + &[unreachable_ipv4_addr(), local_ipv4_addr()][..], + 4, + unreachable_v4_timeout, + false, + ), + ( + &[unreachable_ipv6_addr(), local_ipv6_addr()][..], + 6, + unreachable_v6_timeout, + false, + ), // Unreachable + fast primary, with (unused) fallback. - (&[unreachable_ipv4_addr(), local_ipv4_addr(), local_ipv6_addr()][..], - 4, unreachable_v4_timeout, false), - (&[unreachable_ipv6_addr(), local_ipv6_addr(), local_ipv4_addr()][..], - 6, unreachable_v6_timeout, true), - + ( + &[ + unreachable_ipv4_addr(), + local_ipv4_addr(), + local_ipv6_addr(), + ][..], + 4, + unreachable_v4_timeout, + false, + ), + ( + &[ + unreachable_ipv6_addr(), + local_ipv6_addr(), + local_ipv4_addr(), + ][..], + 6, + unreachable_v6_timeout, + true, + ), // Slow primary, with (used) fallback. - (&[slow_ipv4_addr(), local_ipv4_addr(), local_ipv6_addr()][..], - 6, fallback_timeout, false), - (&[slow_ipv6_addr(), local_ipv6_addr(), local_ipv4_addr()][..], - 4, fallback_timeout, true), - + ( + &[slow_ipv4_addr(), local_ipv4_addr(), local_ipv6_addr()][..], + 6, + fallback_timeout, + false, + ), + ( + &[slow_ipv6_addr(), local_ipv6_addr(), local_ipv4_addr()][..], + 4, + fallback_timeout, + true, + ), // Slow primary, with (used) unreachable + fast fallback. - (&[slow_ipv4_addr(), unreachable_ipv6_addr(), local_ipv6_addr()][..], - 6, fallback_timeout + unreachable_v6_timeout, false), - (&[slow_ipv6_addr(), unreachable_ipv4_addr(), local_ipv4_addr()][..], - 4, fallback_timeout + unreachable_v4_timeout, true), + ( + &[slow_ipv4_addr(), unreachable_ipv6_addr(), local_ipv6_addr()][..], + 6, + fallback_timeout + unreachable_v6_timeout, + false, + ), + ( + &[slow_ipv6_addr(), unreachable_ipv4_addr(), local_ipv4_addr()][..], + 4, + fallback_timeout + unreachable_v4_timeout, + true, + ), ]; // Scenarios for IPv6 -> IPv4 fallback require that host can access IPv6 network. @@ -785,14 +778,30 @@ mod tests { continue; } - let addrs = hosts.iter().map(|host| (host.clone(), addr.port()).into()).collect(); - let (res, duration) = rt.block_on(async move { - let connecting_tcp = ConnectingTcp::new(None, dns::IpAddrs::new(addrs), None, Some(fallback_timeout), false); - let fut = ConnectingTcpFuture(connecting_tcp); - let start = Instant::now(); - let res = fut.await.unwrap(); - (res, start.elapsed()) - }); + + let (start, stream) = rt + .block_on(async move { + let addrs = hosts + .iter() + .map(|host| (host.clone(), addr.port()).into()) + .collect(); + let connecting_tcp = ConnectingTcp::new( + None, + dns::IpAddrs::new(addrs), + None, + Some(fallback_timeout), + false, + ); + let start = Instant::now(); + Ok::<_, io::Error>((start, connecting_tcp.connect().await?)) + }) + .unwrap(); + let res = if stream.peer_addr().unwrap().is_ipv4() { + 4 + } else { + 6 + }; + let duration = start.elapsed(); // Allow actual duration to be +/- 150ms off. let min_duration = if timeout >= Duration::from_millis(150) { @@ -807,22 +816,6 @@ mod tests { assert!(duration <= max_duration); } - struct ConnectingTcpFuture(ConnectingTcp); - - impl Future for ConnectingTcpFuture { - type Output = Result; - - fn poll(mut self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll { - match self.0.poll(cx) { - Poll::Ready(Ok(stream)) => Poll::Ready(Ok( - if stream.peer_addr().unwrap().is_ipv4() { 4 } else { 6 } - )), - Poll::Ready(Err(e)) => Poll::Ready(Err(e)), - Poll::Pending => Poll::Pending, - } - } - } - fn local_ipv4_addr() -> IpAddr { Ipv4Addr::new(127, 0, 0, 1).into() } @@ -851,8 +844,8 @@ mod tests { fn measure_connect(addr: IpAddr) -> (bool, Duration) { let start = Instant::now(); - let result = ::std::net::TcpStream::connect_timeout( - &(addr, 80).into(), Duration::from_secs(1)); + let result = + ::std::net::TcpStream::connect_timeout(&(addr, 80).into(), Duration::from_secs(1)); let reachable = result.is_ok() || result.unwrap_err().kind() == io::ErrorKind::TimedOut; let duration = start.elapsed(); @@ -860,4 +853,3 @@ mod tests { } } } -