refactor(client): use async/await in HttpConnector (#2019)
Closes #1984
This commit is contained in:
committed by
Sean McArthur
parent
19a7aab51f
commit
30ac01c180
@@ -21,18 +21,16 @@
|
||||
//! Ok::<_, Infallible>(iter::once(IpAddr::from([127, 0, 0, 1])))
|
||||
//! });
|
||||
//! ```
|
||||
use std::{fmt, io, vec};
|
||||
use std::error::Error;
|
||||
use std::net::{
|
||||
IpAddr, Ipv4Addr, Ipv6Addr,
|
||||
SocketAddr, ToSocketAddrs,
|
||||
SocketAddrV4, SocketAddrV6,
|
||||
};
|
||||
use std::net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr, SocketAddrV4, SocketAddrV6, ToSocketAddrs};
|
||||
use std::str::FromStr;
|
||||
use std::task::{self, Poll};
|
||||
use std::pin::Pin;
|
||||
use std::future::Future;
|
||||
use std::{fmt, io, vec};
|
||||
|
||||
use tokio::task::JoinHandle;
|
||||
use tower_service::Service;
|
||||
use crate::common::{Future, Pin, Poll, task};
|
||||
|
||||
pub(super) use self::sealed::Resolve;
|
||||
|
||||
@@ -60,9 +58,7 @@ pub struct GaiFuture {
|
||||
|
||||
impl Name {
|
||||
pub(super) fn new(host: String) -> Name {
|
||||
Name {
|
||||
host,
|
||||
}
|
||||
Name { host }
|
||||
}
|
||||
|
||||
/// View the hostname as a string slice.
|
||||
@@ -104,13 +100,10 @@ impl fmt::Display for InvalidNameError {
|
||||
|
||||
impl Error for InvalidNameError {}
|
||||
|
||||
|
||||
impl GaiResolver {
|
||||
/// Construct a new `GaiResolver`.
|
||||
pub fn new() -> Self {
|
||||
GaiResolver {
|
||||
_priv: (),
|
||||
}
|
||||
GaiResolver { _priv: () }
|
||||
}
|
||||
}
|
||||
|
||||
@@ -126,13 +119,12 @@ impl Service<Name> for GaiResolver {
|
||||
fn call(&mut self, name: Name) -> Self::Future {
|
||||
let blocking = tokio::task::spawn_blocking(move || {
|
||||
debug!("resolving host={:?}", name.host);
|
||||
(&*name.host, 0).to_socket_addrs()
|
||||
(&*name.host, 0)
|
||||
.to_socket_addrs()
|
||||
.map(|i| IpAddrs { iter: i })
|
||||
});
|
||||
|
||||
GaiFuture {
|
||||
inner: blocking,
|
||||
}
|
||||
GaiFuture { inner: blocking }
|
||||
}
|
||||
}
|
||||
|
||||
@@ -180,37 +172,46 @@ pub(super) struct IpAddrs {
|
||||
|
||||
impl IpAddrs {
|
||||
pub(super) fn new(addrs: Vec<SocketAddr>) -> Self {
|
||||
IpAddrs { iter: addrs.into_iter() }
|
||||
IpAddrs {
|
||||
iter: addrs.into_iter(),
|
||||
}
|
||||
}
|
||||
|
||||
pub(super) fn try_parse(host: &str, port: u16) -> Option<IpAddrs> {
|
||||
if let Ok(addr) = host.parse::<Ipv4Addr>() {
|
||||
let addr = SocketAddrV4::new(addr, port);
|
||||
return Some(IpAddrs { iter: vec![SocketAddr::V4(addr)].into_iter() })
|
||||
return Some(IpAddrs {
|
||||
iter: vec![SocketAddr::V4(addr)].into_iter(),
|
||||
});
|
||||
}
|
||||
let host = host.trim_start_matches('[').trim_end_matches(']');
|
||||
if let Ok(addr) = host.parse::<Ipv6Addr>() {
|
||||
let addr = SocketAddrV6::new(addr, port, 0, 0);
|
||||
return Some(IpAddrs { iter: vec![SocketAddr::V6(addr)].into_iter() })
|
||||
return Some(IpAddrs {
|
||||
iter: vec![SocketAddr::V6(addr)].into_iter(),
|
||||
});
|
||||
}
|
||||
None
|
||||
}
|
||||
|
||||
pub(super) fn split_by_preference(self, local_addr: Option<IpAddr>) -> (IpAddrs, IpAddrs) {
|
||||
if let Some(local_addr) = local_addr {
|
||||
let preferred = self.iter
|
||||
let preferred = self
|
||||
.iter
|
||||
.filter(|addr| addr.is_ipv6() == local_addr.is_ipv6())
|
||||
.collect();
|
||||
|
||||
(IpAddrs::new(preferred), IpAddrs::new(vec![]))
|
||||
} else {
|
||||
let preferring_v6 = self.iter
|
||||
let preferring_v6 = self
|
||||
.iter
|
||||
.as_slice()
|
||||
.first()
|
||||
.map(SocketAddr::is_ipv6)
|
||||
.unwrap_or(false);
|
||||
|
||||
let (preferred, fallback) = self.iter
|
||||
let (preferred, fallback) = self
|
||||
.iter
|
||||
.partition::<Vec<_>, _>(|addr| addr.is_ipv6() == preferring_v6);
|
||||
|
||||
(IpAddrs::new(preferred), IpAddrs::new(fallback))
|
||||
@@ -281,8 +282,15 @@ impl Future for TokioThreadpoolGaiFuture {
|
||||
type Output = Result<GaiAddrs, io::Error>;
|
||||
|
||||
fn poll(self: Pin<&mut Self>, _cx: &mut task::Context<'_>) -> Poll<Self::Output> {
|
||||
match ready!(tokio_executor::threadpool::blocking(|| (self.name.as_str(), 0).to_socket_addrs())) {
|
||||
Ok(Ok(iter)) => Poll::Ready(Ok(GaiAddrs { inner: IpAddrs { iter } })),
|
||||
match ready!(tokio_executor::threadpool::blocking(|| (
|
||||
self.name.as_str(),
|
||||
0
|
||||
)
|
||||
.to_socket_addrs()))
|
||||
{
|
||||
Ok(Ok(iter)) => Poll::Ready(Ok(GaiAddrs {
|
||||
inner: IpAddrs { iter },
|
||||
})),
|
||||
Ok(Err(e)) => Poll::Ready(Err(e)),
|
||||
// a BlockingError, meaning not on a tokio_executor::threadpool :(
|
||||
Err(e) => Poll::Ready(Err(io::Error::new(io::ErrorKind::Other, e))),
|
||||
@@ -292,15 +300,15 @@ impl Future for TokioThreadpoolGaiFuture {
|
||||
*/
|
||||
|
||||
mod sealed {
|
||||
use tower_service::Service;
|
||||
use crate::common::{Future, Poll, task};
|
||||
use super::{IpAddr, Name};
|
||||
use crate::common::{task, Future, Poll};
|
||||
use tower_service::Service;
|
||||
|
||||
// "Trait alias" for `Service<Name, Response = Addrs>`
|
||||
pub trait Resolve {
|
||||
type Addrs: Iterator<Item=IpAddr>;
|
||||
type Addrs: Iterator<Item = IpAddr>;
|
||||
type Error: Into<Box<dyn std::error::Error + Send + Sync>>;
|
||||
type Future: Future<Output=Result<Self::Addrs, Self::Error>>;
|
||||
type Future: Future<Output = Result<Self::Addrs, Self::Error>>;
|
||||
|
||||
fn poll_ready(&mut self, cx: &mut task::Context<'_>) -> Poll<Result<(), Self::Error>>;
|
||||
fn resolve(&mut self, name: Name) -> Self::Future;
|
||||
@@ -309,7 +317,7 @@ mod sealed {
|
||||
impl<S> Resolve for S
|
||||
where
|
||||
S: Service<Name>,
|
||||
S::Response: Iterator<Item=IpAddr>,
|
||||
S::Response: Iterator<Item = IpAddr>,
|
||||
S::Error: Into<Box<dyn std::error::Error + Send + Sync>>,
|
||||
{
|
||||
type Addrs = S::Response;
|
||||
@@ -326,33 +334,49 @@ mod sealed {
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) async fn resolve<R>(resolver: &mut R, name: Name) -> Result<R::Addrs, R::Error>
|
||||
where
|
||||
R: Resolve,
|
||||
{
|
||||
futures_util::future::poll_fn(|cx| resolver.poll_ready(cx)).await?;
|
||||
resolver.resolve(name).await
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use std::net::{Ipv4Addr, Ipv6Addr};
|
||||
use super::*;
|
||||
use std::net::{Ipv4Addr, Ipv6Addr};
|
||||
|
||||
#[test]
|
||||
fn test_ip_addrs_split_by_preference() {
|
||||
let v4_addr = (Ipv4Addr::new(127, 0, 0, 1), 80).into();
|
||||
let v6_addr = (Ipv6Addr::new(0, 0, 0, 0, 0, 0, 0, 1), 80).into();
|
||||
|
||||
let (mut preferred, mut fallback) =
|
||||
IpAddrs { iter: vec![v4_addr, v6_addr].into_iter() }.split_by_preference(None);
|
||||
let (mut preferred, mut fallback) = IpAddrs {
|
||||
iter: vec![v4_addr, v6_addr].into_iter(),
|
||||
}
|
||||
.split_by_preference(None);
|
||||
assert!(preferred.next().unwrap().is_ipv4());
|
||||
assert!(fallback.next().unwrap().is_ipv6());
|
||||
|
||||
let (mut preferred, mut fallback) =
|
||||
IpAddrs { iter: vec![v6_addr, v4_addr].into_iter() }.split_by_preference(None);
|
||||
let (mut preferred, mut fallback) = IpAddrs {
|
||||
iter: vec![v6_addr, v4_addr].into_iter(),
|
||||
}
|
||||
.split_by_preference(None);
|
||||
assert!(preferred.next().unwrap().is_ipv6());
|
||||
assert!(fallback.next().unwrap().is_ipv4());
|
||||
|
||||
let (mut preferred, fallback) =
|
||||
IpAddrs { iter: vec![v4_addr, v6_addr].into_iter() }.split_by_preference(Some(v4_addr.ip()));
|
||||
let (mut preferred, fallback) = IpAddrs {
|
||||
iter: vec![v4_addr, v6_addr].into_iter(),
|
||||
}
|
||||
.split_by_preference(Some(v4_addr.ip()));
|
||||
assert!(preferred.next().unwrap().is_ipv4());
|
||||
assert!(fallback.is_empty());
|
||||
|
||||
let (mut preferred, fallback) =
|
||||
IpAddrs { iter: vec![v4_addr, v6_addr].into_iter() }.split_by_preference(Some(v6_addr.ip()));
|
||||
let (mut preferred, fallback) = IpAddrs {
|
||||
iter: vec![v4_addr, v6_addr].into_iter(),
|
||||
}
|
||||
.split_by_preference(Some(v6_addr.ip()));
|
||||
assert!(preferred.next().unwrap().is_ipv6());
|
||||
assert!(fallback.is_empty());
|
||||
}
|
||||
@@ -370,10 +394,8 @@ mod tests {
|
||||
let uri = ::http::Uri::from_static("http://[::1]:8080/");
|
||||
let dst = super::super::Destination { uri };
|
||||
|
||||
let mut addrs = IpAddrs::try_parse(
|
||||
dst.host(),
|
||||
dst.port().expect("port")
|
||||
).expect("try_parse");
|
||||
let mut addrs =
|
||||
IpAddrs::try_parse(dst.host(), dst.port().expect("port")).expect("try_parse");
|
||||
|
||||
let expected = "[::1]:8080".parse::<SocketAddr>().expect("expected");
|
||||
|
||||
|
||||
@@ -1,26 +1,23 @@
|
||||
use std::fmt;
|
||||
use std::error::Error as StdError;
|
||||
use std::future::Future;
|
||||
use std::pin::Pin;
|
||||
use std::task::{self, Poll};
|
||||
use std::fmt;
|
||||
use std::io;
|
||||
use std::mem;
|
||||
use std::net::{IpAddr, SocketAddr};
|
||||
use std::sync::Arc;
|
||||
use std::time::Duration;
|
||||
|
||||
use futures_util::future::Either;
|
||||
use http::uri::{Scheme, Uri};
|
||||
use futures_util::{TryFutureExt};
|
||||
use net2::TcpBuilder;
|
||||
use pin_project::{pin_project, project};
|
||||
use tokio::net::TcpStream;
|
||||
use tokio::time::Delay;
|
||||
|
||||
use crate::common::{Future, Pin, Poll, task};
|
||||
use super::dns::{self, resolve, GaiResolver, Resolve};
|
||||
use super::{Connected, Destination};
|
||||
use super::dns::{self, GaiResolver, Resolve};
|
||||
//#[cfg(feature = "runtime")] use super::dns::TokioThreadpoolGaiResolver;
|
||||
|
||||
// TODO: unbox me?
|
||||
type ConnectFuture = Pin<Box<dyn Future<Output = io::Result<TcpStream>> + Send>>;
|
||||
|
||||
/// A connector for the `http` scheme.
|
||||
///
|
||||
/// Performs DNS resolution in a thread pool, and then connects over TCP.
|
||||
@@ -102,7 +99,6 @@ impl HttpConnector<TokioThreadpoolGaiResolver> {
|
||||
}
|
||||
*/
|
||||
|
||||
|
||||
impl<R> HttpConnector<R> {
|
||||
/// Construct a new HttpConnector.
|
||||
///
|
||||
@@ -223,35 +219,22 @@ static INVALID_NOT_HTTP: &str = "invalid URL, scheme is not http";
|
||||
static INVALID_MISSING_SCHEME: &str = "invalid URL, scheme is missing";
|
||||
static INVALID_MISSING_HOST: &str = "invalid URL, host is missing";
|
||||
|
||||
impl<R: Resolve> HttpConnector<R> {
|
||||
fn invalid_url(&self, msg: impl Into<Box<str>>) -> HttpConnecting<R> {
|
||||
HttpConnecting {
|
||||
config: self.config.clone(),
|
||||
state: State::Error(Some(ConnectError {
|
||||
msg: msg.into(),
|
||||
cause: None,
|
||||
})),
|
||||
port: 0,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// R: Debug required for now to allow adding it to debug output later...
|
||||
impl<R: fmt::Debug> fmt::Debug for HttpConnector<R> {
|
||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
f.debug_struct("HttpConnector")
|
||||
.finish()
|
||||
f.debug_struct("HttpConnector").finish()
|
||||
}
|
||||
}
|
||||
|
||||
impl<R> tower_service::Service<Destination> for HttpConnector<R>
|
||||
where
|
||||
R: Resolve + Clone + Send + Sync,
|
||||
R: Resolve + Clone + Send + Sync + 'static,
|
||||
R::Future: Send,
|
||||
{
|
||||
type Response = (TcpStream, Connected);
|
||||
type Error = ConnectError;
|
||||
type Future = HttpConnecting<R>;
|
||||
type Future =
|
||||
Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send + 'static>>;
|
||||
|
||||
fn poll_ready(&mut self, cx: &mut task::Context<'_>) -> Poll<Result<(), Self::Error>> {
|
||||
ready!(self.resolver.poll_ready(cx)).map_err(ConnectError::dns)?;
|
||||
@@ -259,6 +242,19 @@ where
|
||||
}
|
||||
|
||||
fn call(&mut self, dst: Destination) -> Self::Future {
|
||||
let mut self_ = self.clone();
|
||||
Box::pin(async move { self_.call_async(dst).await })
|
||||
}
|
||||
}
|
||||
|
||||
impl<R> HttpConnector<R>
|
||||
where
|
||||
R: Resolve,
|
||||
{
|
||||
async fn call_async(
|
||||
&mut self,
|
||||
dst: Destination,
|
||||
) -> Result<(TcpStream, Connected), ConnectError> {
|
||||
trace!(
|
||||
"Http::connect; scheme={}, host={}, port={:?}",
|
||||
dst.scheme(),
|
||||
@@ -268,26 +264,85 @@ where
|
||||
|
||||
if self.config.enforce_http {
|
||||
if dst.uri.scheme() != Some(&Scheme::HTTP) {
|
||||
return self.invalid_url(INVALID_NOT_HTTP);
|
||||
return Err(ConnectError {
|
||||
msg: INVALID_NOT_HTTP.into(),
|
||||
cause: None,
|
||||
});
|
||||
}
|
||||
} else if dst.uri.scheme().is_none() {
|
||||
return self.invalid_url(INVALID_MISSING_SCHEME);
|
||||
return Err(ConnectError {
|
||||
msg: INVALID_MISSING_SCHEME.into(),
|
||||
cause: None,
|
||||
});
|
||||
}
|
||||
|
||||
let host = match dst.uri.host() {
|
||||
Some(s) => s,
|
||||
None => return self.invalid_url(INVALID_MISSING_HOST),
|
||||
None => {
|
||||
return Err(ConnectError {
|
||||
msg: INVALID_MISSING_HOST.into(),
|
||||
cause: None,
|
||||
})
|
||||
}
|
||||
};
|
||||
let port = match dst.uri.port() {
|
||||
Some(port) => port.as_u16(),
|
||||
None => if dst.uri.scheme() == Some(&Scheme::HTTPS) { 443 } else { 80 },
|
||||
};
|
||||
|
||||
HttpConnecting {
|
||||
config: self.config.clone(),
|
||||
state: State::Lazy(self.resolver.clone(), host.into()),
|
||||
port,
|
||||
let config = &self.config;
|
||||
|
||||
// If the host is already an IP addr (v4 or v6),
|
||||
// skip resolving the dns and start connecting right away.
|
||||
let addrs = if let Some(addrs) = dns::IpAddrs::try_parse(host, port) {
|
||||
addrs
|
||||
} else {
|
||||
let addrs = resolve(&mut self.resolver, dns::Name::new(host.into()))
|
||||
.await
|
||||
.map_err(ConnectError::dns)?;
|
||||
let addrs = addrs.map(|addr| SocketAddr::new(addr, port)).collect();
|
||||
dns::IpAddrs::new(addrs)
|
||||
};
|
||||
|
||||
let c = ConnectingTcp::new(
|
||||
config.local_address,
|
||||
addrs,
|
||||
config.connect_timeout,
|
||||
config.happy_eyeballs_timeout,
|
||||
config.reuse_address,
|
||||
);
|
||||
|
||||
let sock = c
|
||||
.connect()
|
||||
.await
|
||||
.map_err(ConnectError::m("tcp connect error"))?;
|
||||
|
||||
if let Some(dur) = config.keep_alive_timeout {
|
||||
sock.set_keepalive(Some(dur))
|
||||
.map_err(ConnectError::m("tcp set_keepalive error"))?;
|
||||
}
|
||||
|
||||
if let Some(size) = config.send_buffer_size {
|
||||
sock.set_send_buffer_size(size)
|
||||
.map_err(ConnectError::m("tcp set_send_buffer_size error"))?;
|
||||
}
|
||||
|
||||
if let Some(size) = config.recv_buffer_size {
|
||||
sock.set_recv_buffer_size(size)
|
||||
.map_err(ConnectError::m("tcp set_recv_buffer_size error"))?;
|
||||
}
|
||||
|
||||
sock.set_nodelay(config.nodelay)
|
||||
.map_err(ConnectError::m("tcp set_nodelay error"))?;
|
||||
|
||||
let extra = HttpInfo {
|
||||
remote_addr: sock
|
||||
.peer_addr()
|
||||
.map_err(ConnectError::m("tcp peer_addr error"))?,
|
||||
};
|
||||
let connected = Connected::new().extra(extra);
|
||||
|
||||
Ok((sock, connected))
|
||||
}
|
||||
}
|
||||
|
||||
@@ -298,14 +353,16 @@ where
|
||||
{
|
||||
type Response = TcpStream;
|
||||
type Error = ConnectError;
|
||||
type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send + 'static>>;
|
||||
type Future =
|
||||
Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send + 'static>>;
|
||||
|
||||
fn poll_ready(&mut self, cx: &mut task::Context<'_>) -> Poll<Result<(), Self::Error>> {
|
||||
tower_service::Service::<Destination>::poll_ready(self, cx)
|
||||
}
|
||||
|
||||
fn call(&mut self, uri: Uri) -> Self::Future {
|
||||
Box::pin(self.call(Destination { uri }).map_ok(|(s, _)| s))
|
||||
let mut self_ = self.clone();
|
||||
Box::pin(async move { self_.call_async(Destination { uri }).await.map(|(s, _)| s) })
|
||||
}
|
||||
}
|
||||
|
||||
@@ -346,9 +403,7 @@ impl ConnectError {
|
||||
S: Into<Box<str>>,
|
||||
E: Into<Box<dyn StdError + Send + Sync>>,
|
||||
{
|
||||
move |cause| {
|
||||
ConnectError::new(msg, cause)
|
||||
}
|
||||
move |cause| ConnectError::new(msg, cause)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -383,96 +438,6 @@ impl StdError for ConnectError {
|
||||
}
|
||||
}
|
||||
|
||||
/// A Future representing work to connect to a URL.
|
||||
#[must_use = "futures do nothing unless polled"]
|
||||
#[pin_project]
|
||||
pub struct HttpConnecting<R: Resolve = GaiResolver> {
|
||||
config: Arc<Config>,
|
||||
#[pin]
|
||||
state: State<R>,
|
||||
port: u16,
|
||||
}
|
||||
|
||||
#[pin_project]
|
||||
enum State<R: Resolve> {
|
||||
Lazy(R, String),
|
||||
Resolving(#[pin] R::Future),
|
||||
Connecting(ConnectingTcp),
|
||||
Error(Option<ConnectError>),
|
||||
}
|
||||
|
||||
impl<R: Resolve> Future for HttpConnecting<R> {
|
||||
type Output = Result<(TcpStream, Connected), ConnectError>;
|
||||
|
||||
#[project]
|
||||
fn poll(self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll<Self::Output> {
|
||||
let mut me = self.project();
|
||||
let config: &Config = &me.config;
|
||||
loop {
|
||||
let state;
|
||||
#[project]
|
||||
match me.state.as_mut().project() {
|
||||
State::Lazy(ref mut resolver, ref mut host) => {
|
||||
// If the host is already an IP addr (v4 or v6),
|
||||
// skip resolving the dns and start connecting right away.
|
||||
if let Some(addrs) = dns::IpAddrs::try_parse(host, *me.port) {
|
||||
state = State::Connecting(ConnectingTcp::new(
|
||||
config.local_address, addrs, config.connect_timeout, config.happy_eyeballs_timeout, config.reuse_address));
|
||||
} else {
|
||||
ready!(resolver.poll_ready(cx)).map_err(ConnectError::dns)?;
|
||||
let name = dns::Name::new(mem::replace(host, String::new()));
|
||||
state = State::Resolving(resolver.resolve(name));
|
||||
}
|
||||
},
|
||||
State::Resolving(future) => {
|
||||
let addrs = ready!(future.poll(cx)).map_err(ConnectError::dns)?;
|
||||
let port = *me.port;
|
||||
let addrs = addrs
|
||||
.map(|addr| SocketAddr::new(addr, port))
|
||||
.collect();
|
||||
let addrs = dns::IpAddrs::new(addrs);
|
||||
state = State::Connecting(ConnectingTcp::new(
|
||||
config.local_address, addrs, config.connect_timeout, config.happy_eyeballs_timeout, config.reuse_address));
|
||||
},
|
||||
State::Connecting(ref mut c) => {
|
||||
let sock = ready!(c.poll(cx))
|
||||
.map_err(ConnectError::m("tcp connect error"))?;
|
||||
|
||||
if let Some(dur) = config.keep_alive_timeout {
|
||||
sock.set_keepalive(Some(dur)).map_err(ConnectError::m("tcp set_keepalive error"))?;
|
||||
}
|
||||
|
||||
if let Some(size) = config.send_buffer_size {
|
||||
sock.set_send_buffer_size(size).map_err(ConnectError::m("tcp set_send_buffer_size error"))?;
|
||||
}
|
||||
|
||||
if let Some(size) = config.recv_buffer_size {
|
||||
sock.set_recv_buffer_size(size).map_err(ConnectError::m("tcp set_recv_buffer_size error"))?;
|
||||
}
|
||||
|
||||
sock.set_nodelay(config.nodelay).map_err(ConnectError::m("tcp set_nodelay error"))?;
|
||||
|
||||
let extra = HttpInfo {
|
||||
remote_addr: sock.peer_addr().map_err(ConnectError::m("tcp peer_addr error"))?,
|
||||
};
|
||||
let connected = Connected::new()
|
||||
.extra(extra);
|
||||
|
||||
return Poll::Ready(Ok((sock, connected)));
|
||||
},
|
||||
State::Error(ref mut e) => return Poll::Ready(Err(e.take().expect("polled more than once"))),
|
||||
}
|
||||
me.state.set(state);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<R: Resolve + fmt::Debug> fmt::Debug for HttpConnecting<R> {
|
||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
f.pad("HttpConnecting")
|
||||
}
|
||||
}
|
||||
|
||||
struct ConnectingTcp {
|
||||
local_addr: Option<IpAddr>,
|
||||
preferred: ConnectingTcpRemote,
|
||||
@@ -527,7 +492,6 @@ struct ConnectingTcpFallback {
|
||||
struct ConnectingTcpRemote {
|
||||
addrs: dns::IpAddrs,
|
||||
connect_timeout: Option<Duration>,
|
||||
current: Option<ConnectFuture>,
|
||||
}
|
||||
|
||||
impl ConnectingTcpRemote {
|
||||
@@ -537,45 +501,39 @@ impl ConnectingTcpRemote {
|
||||
Self {
|
||||
addrs,
|
||||
connect_timeout,
|
||||
current: None,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl ConnectingTcpRemote {
|
||||
fn poll(
|
||||
async fn connect(
|
||||
&mut self,
|
||||
cx: &mut task::Context<'_>,
|
||||
local_addr: &Option<IpAddr>,
|
||||
reuse_address: bool,
|
||||
) -> Poll<io::Result<TcpStream>> {
|
||||
) -> io::Result<TcpStream> {
|
||||
let mut err = None;
|
||||
loop {
|
||||
if let Some(ref mut current) = self.current {
|
||||
match current.as_mut().poll(cx) {
|
||||
Poll::Ready(Ok(tcp)) => {
|
||||
debug!("connected to {:?}", tcp.peer_addr().ok());
|
||||
return Poll::Ready(Ok(tcp));
|
||||
},
|
||||
Poll::Pending => return Poll::Pending,
|
||||
Poll::Ready(Err(e)) => {
|
||||
trace!("connect error {:?}", e);
|
||||
err = Some(e);
|
||||
if let Some(addr) = self.addrs.next() {
|
||||
debug!("connecting to {}", addr);
|
||||
*current = connect(&addr, local_addr, reuse_address, self.connect_timeout)?;
|
||||
continue;
|
||||
}
|
||||
}
|
||||
for addr in &mut self.addrs {
|
||||
debug!("connecting to {}", addr);
|
||||
match connect(
|
||||
&addr,
|
||||
local_addr,
|
||||
reuse_address,
|
||||
self.connect_timeout,
|
||||
)?
|
||||
.await
|
||||
{
|
||||
Ok(tcp) => {
|
||||
debug!("connected to {:?}", tcp.peer_addr().ok());
|
||||
return Ok(tcp);
|
||||
}
|
||||
Err(e) => {
|
||||
trace!("connect error {:?}", e);
|
||||
err = Some(e);
|
||||
}
|
||||
} else if let Some(addr) = self.addrs.next() {
|
||||
debug!("connecting to {}", addr);
|
||||
self.current = Some(connect(&addr, local_addr, reuse_address, self.connect_timeout)?);
|
||||
continue;
|
||||
}
|
||||
|
||||
return Poll::Ready(Err(err.take().expect("missing connect error")));
|
||||
}
|
||||
|
||||
return Err(err.take().expect("missing connect error"));
|
||||
}
|
||||
}
|
||||
|
||||
@@ -584,7 +542,7 @@ fn connect(
|
||||
local_addr: &Option<IpAddr>,
|
||||
reuse_address: bool,
|
||||
connect_timeout: Option<Duration>,
|
||||
) -> io::Result<ConnectFuture> {
|
||||
) -> io::Result<impl Future<Output = io::Result<TcpStream>>> {
|
||||
let builder = match addr {
|
||||
&SocketAddr::V4(_) => TcpBuilder::new_v4()?,
|
||||
&SocketAddr::V6(_) => TcpBuilder::new_v6()?,
|
||||
@@ -600,12 +558,8 @@ fn connect(
|
||||
} else if cfg!(windows) {
|
||||
// Windows requires a socket be bound before calling connect
|
||||
let any: SocketAddr = match addr {
|
||||
&SocketAddr::V4(_) => {
|
||||
([0, 0, 0, 0], 0).into()
|
||||
},
|
||||
&SocketAddr::V6(_) => {
|
||||
([0, 0, 0, 0, 0, 0, 0, 0], 0).into()
|
||||
}
|
||||
&SocketAddr::V4(_) => ([0, 0, 0, 0], 0).into(),
|
||||
&SocketAddr::V6(_) => ([0, 0, 0, 0, 0, 0, 0, 0], 0).into(),
|
||||
};
|
||||
builder.bind(any)?;
|
||||
}
|
||||
@@ -614,56 +568,58 @@ fn connect(
|
||||
|
||||
let std_tcp = builder.to_tcp_stream()?;
|
||||
|
||||
Ok(Box::pin(async move {
|
||||
Ok(async move {
|
||||
let connect = TcpStream::connect_std(std_tcp, &addr);
|
||||
match connect_timeout {
|
||||
Some(dur) => match tokio::time::timeout(dur, connect).await {
|
||||
Ok(Ok(s)) => Ok(s),
|
||||
Ok(Err(e)) => Err(e),
|
||||
Err(e) => Err(io::Error::new(io::ErrorKind::TimedOut, e)),
|
||||
}
|
||||
},
|
||||
None => connect.await,
|
||||
}
|
||||
}))
|
||||
})
|
||||
}
|
||||
|
||||
impl ConnectingTcp {
|
||||
fn poll(&mut self, cx: &mut task::Context<'_>) -> Poll<io::Result<TcpStream>> {
|
||||
match self.fallback.take() {
|
||||
None => self.preferred.poll(cx, &self.local_addr, self.reuse_address),
|
||||
Some(mut fallback) => match self.preferred.poll(cx, &self.local_addr, self.reuse_address) {
|
||||
Poll::Ready(Ok(stream)) => {
|
||||
// Preferred successful - drop fallback.
|
||||
Poll::Ready(Ok(stream))
|
||||
}
|
||||
Poll::Pending => match Pin::new(&mut fallback.delay).poll(cx) {
|
||||
Poll::Ready(()) => match fallback.remote.poll(cx, &self.local_addr, self.reuse_address) {
|
||||
Poll::Ready(Ok(stream)) => {
|
||||
// Fallback successful - drop current preferred,
|
||||
// but keep fallback as new preferred.
|
||||
self.preferred = fallback.remote;
|
||||
Poll::Ready(Ok(stream))
|
||||
async fn connect(mut self) -> io::Result<TcpStream> {
|
||||
let Self {
|
||||
ref local_addr,
|
||||
reuse_address,
|
||||
..
|
||||
} = self;
|
||||
match self.fallback {
|
||||
None => {
|
||||
self.preferred
|
||||
.connect(local_addr, reuse_address)
|
||||
.await
|
||||
}
|
||||
Some(mut fallback) => {
|
||||
let preferred_fut = self.preferred.connect(local_addr, reuse_address);
|
||||
futures_util::pin_mut!(preferred_fut);
|
||||
|
||||
let fallback_fut = fallback.remote.connect(local_addr, reuse_address);
|
||||
futures_util::pin_mut!(fallback_fut);
|
||||
|
||||
let (result, future) =
|
||||
match futures_util::future::select(preferred_fut, fallback.delay).await {
|
||||
Either::Left((result, _fallback_delay)) => {
|
||||
(result, Either::Right(fallback_fut))
|
||||
}
|
||||
Poll::Pending => {
|
||||
// Neither preferred nor fallback are ready.
|
||||
self.fallback = Some(fallback);
|
||||
Poll::Pending
|
||||
Either::Right(((), preferred_fut)) => {
|
||||
// Delay is done, start polling both the preferred and the fallback
|
||||
futures_util::future::select(preferred_fut, fallback_fut)
|
||||
.await
|
||||
.factor_first()
|
||||
}
|
||||
Poll::Ready(Err(_)) => {
|
||||
// Fallback failed - resume with preferred only.
|
||||
Poll::Pending
|
||||
}
|
||||
},
|
||||
Poll::Pending => {
|
||||
// Too early to attempt fallback.
|
||||
self.fallback = Some(fallback);
|
||||
Poll::Pending
|
||||
}
|
||||
}
|
||||
Poll::Ready(Err(_)) => {
|
||||
// Preferred failed - use fallback as new preferred.
|
||||
self.preferred = fallback.remote;
|
||||
self.preferred.poll(cx, &self.local_addr, self.reuse_address)
|
||||
};
|
||||
|
||||
if let Err(_) = result {
|
||||
// Fallback to the remaining future (could be preferred or fallback)
|
||||
// if we get an error
|
||||
future.await
|
||||
} else {
|
||||
result
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -674,10 +630,13 @@ impl ConnectingTcp {
|
||||
mod tests {
|
||||
use std::io;
|
||||
|
||||
use super::{Connected, Destination, HttpConnector};
|
||||
use super::super::sealed::Connect;
|
||||
use super::{Connected, Destination, HttpConnector};
|
||||
|
||||
async fn connect<C>(connector: C, dst: Destination) -> Result<(C::Transport, Connected), C::Error>
|
||||
async fn connect<C>(
|
||||
connector: C,
|
||||
dst: Destination,
|
||||
) -> Result<(C::Transport, Connected), C::Error>
|
||||
where
|
||||
C: Connect,
|
||||
{
|
||||
@@ -687,9 +646,7 @@ mod tests {
|
||||
#[tokio::test]
|
||||
async fn test_errors_enforce_http() {
|
||||
let uri = "https://example.domain/foo/bar?baz".parse().unwrap();
|
||||
let dst = Destination {
|
||||
uri,
|
||||
};
|
||||
let dst = Destination { uri };
|
||||
let connector = HttpConnector::new();
|
||||
|
||||
let err = connect(connector, dst).await.unwrap_err();
|
||||
@@ -699,9 +656,7 @@ mod tests {
|
||||
#[tokio::test]
|
||||
async fn test_errors_missing_scheme() {
|
||||
let uri = "example.domain".parse().unwrap();
|
||||
let dst = Destination {
|
||||
uri,
|
||||
};
|
||||
let dst = Destination { uri };
|
||||
let mut connector = HttpConnector::new();
|
||||
connector.enforce_http(false);
|
||||
|
||||
@@ -712,12 +667,9 @@ mod tests {
|
||||
#[test]
|
||||
#[cfg_attr(not(feature = "__internal_happy_eyeballs_tests"), ignore)]
|
||||
fn client_happy_eyeballs() {
|
||||
use std::future::Future;
|
||||
use std::net::{IpAddr, Ipv4Addr, Ipv6Addr, TcpListener};
|
||||
use std::task::Poll;
|
||||
use std::time::{Duration, Instant};
|
||||
|
||||
use crate::common::{Pin, task};
|
||||
use super::dns;
|
||||
use super::ConnectingTcp;
|
||||
|
||||
@@ -740,40 +692,81 @@ mod tests {
|
||||
|
||||
let scenarios = &[
|
||||
// Fast primary, without fallback.
|
||||
(&[local_ipv4_addr()][..],
|
||||
4, local_timeout, false),
|
||||
(&[local_ipv6_addr()][..],
|
||||
6, local_timeout, false),
|
||||
|
||||
(&[local_ipv4_addr()][..], 4, local_timeout, false),
|
||||
(&[local_ipv6_addr()][..], 6, local_timeout, false),
|
||||
// Fast primary, with (unused) fallback.
|
||||
(&[local_ipv4_addr(), local_ipv6_addr()][..],
|
||||
4, local_timeout, false),
|
||||
(&[local_ipv6_addr(), local_ipv4_addr()][..],
|
||||
6, local_timeout, false),
|
||||
|
||||
(
|
||||
&[local_ipv4_addr(), local_ipv6_addr()][..],
|
||||
4,
|
||||
local_timeout,
|
||||
false,
|
||||
),
|
||||
(
|
||||
&[local_ipv6_addr(), local_ipv4_addr()][..],
|
||||
6,
|
||||
local_timeout,
|
||||
false,
|
||||
),
|
||||
// Unreachable + fast primary, without fallback.
|
||||
(&[unreachable_ipv4_addr(), local_ipv4_addr()][..],
|
||||
4, unreachable_v4_timeout, false),
|
||||
(&[unreachable_ipv6_addr(), local_ipv6_addr()][..],
|
||||
6, unreachable_v6_timeout, false),
|
||||
|
||||
(
|
||||
&[unreachable_ipv4_addr(), local_ipv4_addr()][..],
|
||||
4,
|
||||
unreachable_v4_timeout,
|
||||
false,
|
||||
),
|
||||
(
|
||||
&[unreachable_ipv6_addr(), local_ipv6_addr()][..],
|
||||
6,
|
||||
unreachable_v6_timeout,
|
||||
false,
|
||||
),
|
||||
// Unreachable + fast primary, with (unused) fallback.
|
||||
(&[unreachable_ipv4_addr(), local_ipv4_addr(), local_ipv6_addr()][..],
|
||||
4, unreachable_v4_timeout, false),
|
||||
(&[unreachable_ipv6_addr(), local_ipv6_addr(), local_ipv4_addr()][..],
|
||||
6, unreachable_v6_timeout, true),
|
||||
|
||||
(
|
||||
&[
|
||||
unreachable_ipv4_addr(),
|
||||
local_ipv4_addr(),
|
||||
local_ipv6_addr(),
|
||||
][..],
|
||||
4,
|
||||
unreachable_v4_timeout,
|
||||
false,
|
||||
),
|
||||
(
|
||||
&[
|
||||
unreachable_ipv6_addr(),
|
||||
local_ipv6_addr(),
|
||||
local_ipv4_addr(),
|
||||
][..],
|
||||
6,
|
||||
unreachable_v6_timeout,
|
||||
true,
|
||||
),
|
||||
// Slow primary, with (used) fallback.
|
||||
(&[slow_ipv4_addr(), local_ipv4_addr(), local_ipv6_addr()][..],
|
||||
6, fallback_timeout, false),
|
||||
(&[slow_ipv6_addr(), local_ipv6_addr(), local_ipv4_addr()][..],
|
||||
4, fallback_timeout, true),
|
||||
|
||||
(
|
||||
&[slow_ipv4_addr(), local_ipv4_addr(), local_ipv6_addr()][..],
|
||||
6,
|
||||
fallback_timeout,
|
||||
false,
|
||||
),
|
||||
(
|
||||
&[slow_ipv6_addr(), local_ipv6_addr(), local_ipv4_addr()][..],
|
||||
4,
|
||||
fallback_timeout,
|
||||
true,
|
||||
),
|
||||
// Slow primary, with (used) unreachable + fast fallback.
|
||||
(&[slow_ipv4_addr(), unreachable_ipv6_addr(), local_ipv6_addr()][..],
|
||||
6, fallback_timeout + unreachable_v6_timeout, false),
|
||||
(&[slow_ipv6_addr(), unreachable_ipv4_addr(), local_ipv4_addr()][..],
|
||||
4, fallback_timeout + unreachable_v4_timeout, true),
|
||||
(
|
||||
&[slow_ipv4_addr(), unreachable_ipv6_addr(), local_ipv6_addr()][..],
|
||||
6,
|
||||
fallback_timeout + unreachable_v6_timeout,
|
||||
false,
|
||||
),
|
||||
(
|
||||
&[slow_ipv6_addr(), unreachable_ipv4_addr(), local_ipv4_addr()][..],
|
||||
4,
|
||||
fallback_timeout + unreachable_v4_timeout,
|
||||
true,
|
||||
),
|
||||
];
|
||||
|
||||
// Scenarios for IPv6 -> IPv4 fallback require that host can access IPv6 network.
|
||||
@@ -785,14 +778,30 @@ mod tests {
|
||||
continue;
|
||||
}
|
||||
|
||||
let addrs = hosts.iter().map(|host| (host.clone(), addr.port()).into()).collect();
|
||||
let (res, duration) = rt.block_on(async move {
|
||||
let connecting_tcp = ConnectingTcp::new(None, dns::IpAddrs::new(addrs), None, Some(fallback_timeout), false);
|
||||
let fut = ConnectingTcpFuture(connecting_tcp);
|
||||
let start = Instant::now();
|
||||
let res = fut.await.unwrap();
|
||||
(res, start.elapsed())
|
||||
});
|
||||
|
||||
let (start, stream) = rt
|
||||
.block_on(async move {
|
||||
let addrs = hosts
|
||||
.iter()
|
||||
.map(|host| (host.clone(), addr.port()).into())
|
||||
.collect();
|
||||
let connecting_tcp = ConnectingTcp::new(
|
||||
None,
|
||||
dns::IpAddrs::new(addrs),
|
||||
None,
|
||||
Some(fallback_timeout),
|
||||
false,
|
||||
);
|
||||
let start = Instant::now();
|
||||
Ok::<_, io::Error>((start, connecting_tcp.connect().await?))
|
||||
})
|
||||
.unwrap();
|
||||
let res = if stream.peer_addr().unwrap().is_ipv4() {
|
||||
4
|
||||
} else {
|
||||
6
|
||||
};
|
||||
let duration = start.elapsed();
|
||||
|
||||
// Allow actual duration to be +/- 150ms off.
|
||||
let min_duration = if timeout >= Duration::from_millis(150) {
|
||||
@@ -807,22 +816,6 @@ mod tests {
|
||||
assert!(duration <= max_duration);
|
||||
}
|
||||
|
||||
struct ConnectingTcpFuture(ConnectingTcp);
|
||||
|
||||
impl Future for ConnectingTcpFuture {
|
||||
type Output = Result<u8, std::io::Error>;
|
||||
|
||||
fn poll(mut self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll<Self::Output> {
|
||||
match self.0.poll(cx) {
|
||||
Poll::Ready(Ok(stream)) => Poll::Ready(Ok(
|
||||
if stream.peer_addr().unwrap().is_ipv4() { 4 } else { 6 }
|
||||
)),
|
||||
Poll::Ready(Err(e)) => Poll::Ready(Err(e)),
|
||||
Poll::Pending => Poll::Pending,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn local_ipv4_addr() -> IpAddr {
|
||||
Ipv4Addr::new(127, 0, 0, 1).into()
|
||||
}
|
||||
@@ -851,8 +844,8 @@ mod tests {
|
||||
|
||||
fn measure_connect(addr: IpAddr) -> (bool, Duration) {
|
||||
let start = Instant::now();
|
||||
let result = ::std::net::TcpStream::connect_timeout(
|
||||
&(addr, 80).into(), Duration::from_secs(1));
|
||||
let result =
|
||||
::std::net::TcpStream::connect_timeout(&(addr, 80).into(), Duration::from_secs(1));
|
||||
|
||||
let reachable = result.is_ok() || result.unwrap_err().kind() == io::ErrorKind::TimedOut;
|
||||
let duration = start.elapsed();
|
||||
@@ -860,4 +853,3 @@ mod tests {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user