refactor(client): use pin_project for Resolve futures

This commit is contained in:
Sean McArthur
2019-10-23 14:29:15 -07:00
parent 10cda4beff
commit f71304b449
2 changed files with 27 additions and 24 deletions

View File

@@ -17,14 +17,14 @@ use std::str::FromStr;
use tokio_sync::{mpsc, oneshot};
use crate::common::{Future, Never, Pin, Poll, Unpin, task};
use crate::common::{Future, Never, Pin, Poll, task};
/// Resolve a hostname to a set of IP addresses.
pub trait Resolve: Unpin {
pub trait Resolve {
/// The set of IP addresses to try to connect to.
type Addrs: Iterator<Item=IpAddr>;
/// A Future of the resolved set of addresses.
type Future: Future<Output=Result<Self::Addrs, io::Error>> + Unpin;
type Future: Future<Output=Result<Self::Addrs, io::Error>>;
/// Resolve a hostname.
fn resolve(&self, name: Name) -> Self::Future;
}

View File

@@ -9,6 +9,7 @@ use std::time::Duration;
use http::uri::{Scheme, Uri};
use futures_util::{TryFutureExt, FutureExt};
use net2::TcpBuilder;
use pin_project::{pin_project, project};
use tokio_net::driver::Handle;
use tokio_net::tcp::TcpStream;
use tokio_timer::{Delay, Timeout};
@@ -359,7 +360,9 @@ impl StdError for InvalidUrl {
}
/// A Future representing work to connect to a URL.
#[must_use = "futures do nothing unless polled"]
#[pin_project]
pub struct HttpConnecting<R: Resolve = GaiResolver> {
#[pin]
state: State<R>,
handle: Option<Handle>,
connect_timeout: Option<Duration>,
@@ -372,61 +375,61 @@ pub struct HttpConnecting<R: Resolve = GaiResolver> {
recv_buffer_size: Option<usize>,
}
#[pin_project]
enum State<R: Resolve> {
Lazy(R, String, Option<IpAddr>),
Resolving(R::Future, Option<IpAddr>),
Resolving(#[pin] R::Future, Option<IpAddr>),
Connecting(ConnectingTcp),
Error(Option<io::Error>),
}
impl<R: Resolve> Future for HttpConnecting<R>
where
R::Future: Unpin,
{
impl<R: Resolve> Future for HttpConnecting<R> {
type Output = Result<(TcpStream, Connected), io::Error>;
fn poll(mut self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll<Self::Output> {
let me = &mut *self;
#[project]
fn poll(self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll<Self::Output> {
let mut me = self.project();
loop {
let state;
match me.state {
State::Lazy(ref resolver, ref mut host, local_addr) => {
#[project]
match me.state.as_mut().project() {
State::Lazy(ref resolver, ref mut host, ref local_addr) => {
// If the host is already an IP addr (v4 or v6),
// skip resolving the dns and start connecting right away.
if let Some(addrs) = dns::IpAddrs::try_parse(host, me.port) {
if let Some(addrs) = dns::IpAddrs::try_parse(host, *me.port) {
state = State::Connecting(ConnectingTcp::new(
local_addr, addrs, me.connect_timeout, me.happy_eyeballs_timeout, me.reuse_address));
**local_addr, addrs, *me.connect_timeout, *me.happy_eyeballs_timeout, *me.reuse_address));
} else {
let name = dns::Name::new(mem::replace(host, String::new()));
state = State::Resolving(resolver.resolve(name), local_addr);
state = State::Resolving(resolver.resolve(name), **local_addr);
}
},
State::Resolving(ref mut future, local_addr) => {
let addrs = ready!(Pin::new(future).poll(cx))?;
let port = me.port;
State::Resolving(future, local_addr) => {
let addrs = ready!(future.poll(cx))?;
let port = *me.port;
let addrs = addrs
.map(|addr| SocketAddr::new(addr, port))
.collect();
let addrs = dns::IpAddrs::new(addrs);
state = State::Connecting(ConnectingTcp::new(
local_addr, addrs, me.connect_timeout, me.happy_eyeballs_timeout, me.reuse_address));
*local_addr, addrs, *me.connect_timeout, *me.happy_eyeballs_timeout, *me.reuse_address));
},
State::Connecting(ref mut c) => {
let sock = ready!(c.poll(cx, &me.handle))?;
if let Some(dur) = me.keep_alive_timeout {
sock.set_keepalive(Some(dur))?;
sock.set_keepalive(Some(*dur))?;
}
if let Some(size) = me.send_buffer_size {
sock.set_send_buffer_size(size)?;
sock.set_send_buffer_size(*size)?;
}
if let Some(size) = me.recv_buffer_size {
sock.set_recv_buffer_size(size)?;
sock.set_recv_buffer_size(*size)?;
}
sock.set_nodelay(me.nodelay)?;
sock.set_nodelay(*me.nodelay)?;
let extra = HttpInfo {
remote_addr: sock.peer_addr()?,
@@ -438,7 +441,7 @@ where
},
State::Error(ref mut e) => return Poll::Ready(Err(e.take().expect("polled more than once"))),
}
me.state = state;
me.state.set(state);
}
}
}