refactor(client): use pin_project for Resolve futures

This commit is contained in:
Sean McArthur
2019-10-23 14:29:15 -07:00
parent 10cda4beff
commit f71304b449
2 changed files with 27 additions and 24 deletions

View File

@@ -17,14 +17,14 @@ use std::str::FromStr;
use tokio_sync::{mpsc, oneshot}; use tokio_sync::{mpsc, oneshot};
use crate::common::{Future, Never, Pin, Poll, Unpin, task}; use crate::common::{Future, Never, Pin, Poll, task};
/// Resolve a hostname to a set of IP addresses. /// Resolve a hostname to a set of IP addresses.
pub trait Resolve: Unpin { pub trait Resolve {
/// The set of IP addresses to try to connect to. /// The set of IP addresses to try to connect to.
type Addrs: Iterator<Item=IpAddr>; type Addrs: Iterator<Item=IpAddr>;
/// A Future of the resolved set of addresses. /// A Future of the resolved set of addresses.
type Future: Future<Output=Result<Self::Addrs, io::Error>> + Unpin; type Future: Future<Output=Result<Self::Addrs, io::Error>>;
/// Resolve a hostname. /// Resolve a hostname.
fn resolve(&self, name: Name) -> Self::Future; fn resolve(&self, name: Name) -> Self::Future;
} }

View File

@@ -9,6 +9,7 @@ use std::time::Duration;
use http::uri::{Scheme, Uri}; use http::uri::{Scheme, Uri};
use futures_util::{TryFutureExt, FutureExt}; use futures_util::{TryFutureExt, FutureExt};
use net2::TcpBuilder; use net2::TcpBuilder;
use pin_project::{pin_project, project};
use tokio_net::driver::Handle; use tokio_net::driver::Handle;
use tokio_net::tcp::TcpStream; use tokio_net::tcp::TcpStream;
use tokio_timer::{Delay, Timeout}; use tokio_timer::{Delay, Timeout};
@@ -359,7 +360,9 @@ impl StdError for InvalidUrl {
} }
/// A Future representing work to connect to a URL. /// A Future representing work to connect to a URL.
#[must_use = "futures do nothing unless polled"] #[must_use = "futures do nothing unless polled"]
#[pin_project]
pub struct HttpConnecting<R: Resolve = GaiResolver> { pub struct HttpConnecting<R: Resolve = GaiResolver> {
#[pin]
state: State<R>, state: State<R>,
handle: Option<Handle>, handle: Option<Handle>,
connect_timeout: Option<Duration>, connect_timeout: Option<Duration>,
@@ -372,61 +375,61 @@ pub struct HttpConnecting<R: Resolve = GaiResolver> {
recv_buffer_size: Option<usize>, recv_buffer_size: Option<usize>,
} }
#[pin_project]
enum State<R: Resolve> { enum State<R: Resolve> {
Lazy(R, String, Option<IpAddr>), Lazy(R, String, Option<IpAddr>),
Resolving(R::Future, Option<IpAddr>), Resolving(#[pin] R::Future, Option<IpAddr>),
Connecting(ConnectingTcp), Connecting(ConnectingTcp),
Error(Option<io::Error>), Error(Option<io::Error>),
} }
impl<R: Resolve> Future for HttpConnecting<R> impl<R: Resolve> Future for HttpConnecting<R> {
where
R::Future: Unpin,
{
type Output = Result<(TcpStream, Connected), io::Error>; type Output = Result<(TcpStream, Connected), io::Error>;
fn poll(mut self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll<Self::Output> { #[project]
let me = &mut *self; fn poll(self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll<Self::Output> {
let mut me = self.project();
loop { loop {
let state; let state;
match me.state { #[project]
State::Lazy(ref resolver, ref mut host, local_addr) => { match me.state.as_mut().project() {
State::Lazy(ref resolver, ref mut host, ref local_addr) => {
// If the host is already an IP addr (v4 or v6), // If the host is already an IP addr (v4 or v6),
// skip resolving the dns and start connecting right away. // skip resolving the dns and start connecting right away.
if let Some(addrs) = dns::IpAddrs::try_parse(host, me.port) { if let Some(addrs) = dns::IpAddrs::try_parse(host, *me.port) {
state = State::Connecting(ConnectingTcp::new( state = State::Connecting(ConnectingTcp::new(
local_addr, addrs, me.connect_timeout, me.happy_eyeballs_timeout, me.reuse_address)); **local_addr, addrs, *me.connect_timeout, *me.happy_eyeballs_timeout, *me.reuse_address));
} else { } else {
let name = dns::Name::new(mem::replace(host, String::new())); let name = dns::Name::new(mem::replace(host, String::new()));
state = State::Resolving(resolver.resolve(name), local_addr); state = State::Resolving(resolver.resolve(name), **local_addr);
} }
}, },
State::Resolving(ref mut future, local_addr) => { State::Resolving(future, local_addr) => {
let addrs = ready!(Pin::new(future).poll(cx))?; let addrs = ready!(future.poll(cx))?;
let port = me.port; let port = *me.port;
let addrs = addrs let addrs = addrs
.map(|addr| SocketAddr::new(addr, port)) .map(|addr| SocketAddr::new(addr, port))
.collect(); .collect();
let addrs = dns::IpAddrs::new(addrs); let addrs = dns::IpAddrs::new(addrs);
state = State::Connecting(ConnectingTcp::new( state = State::Connecting(ConnectingTcp::new(
local_addr, addrs, me.connect_timeout, me.happy_eyeballs_timeout, me.reuse_address)); *local_addr, addrs, *me.connect_timeout, *me.happy_eyeballs_timeout, *me.reuse_address));
}, },
State::Connecting(ref mut c) => { State::Connecting(ref mut c) => {
let sock = ready!(c.poll(cx, &me.handle))?; let sock = ready!(c.poll(cx, &me.handle))?;
if let Some(dur) = me.keep_alive_timeout { if let Some(dur) = me.keep_alive_timeout {
sock.set_keepalive(Some(dur))?; sock.set_keepalive(Some(*dur))?;
} }
if let Some(size) = me.send_buffer_size { if let Some(size) = me.send_buffer_size {
sock.set_send_buffer_size(size)?; sock.set_send_buffer_size(*size)?;
} }
if let Some(size) = me.recv_buffer_size { if let Some(size) = me.recv_buffer_size {
sock.set_recv_buffer_size(size)?; sock.set_recv_buffer_size(*size)?;
} }
sock.set_nodelay(me.nodelay)?; sock.set_nodelay(*me.nodelay)?;
let extra = HttpInfo { let extra = HttpInfo {
remote_addr: sock.peer_addr()?, remote_addr: sock.peer_addr()?,
@@ -438,7 +441,7 @@ where
}, },
State::Error(ref mut e) => return Poll::Ready(Err(e.take().expect("polled more than once"))), State::Error(ref mut e) => return Poll::Ready(Err(e.take().expect("polled more than once"))),
} }
me.state = state; me.state.set(state);
} }
} }
} }