feat(client): change Resolve to be Service<Name>
				
					
				
			Closes #1903 BREAKING CHANGE: The `Resolve` trait is gone. All custom resolves should implement `tower::Service` instead. The error type of `HttpConnector` has been changed away from `std::io::Error`.
This commit is contained in:
		| @@ -1,11 +1,26 @@ | ||||
| //! The `Resolve` trait, support types, and some basic implementations. | ||||
| //! DNS Resolution used by the `HttpConnector`. | ||||
| //! | ||||
| //! This module contains: | ||||
| //! | ||||
| //! - A [`GaiResolver`](dns::GaiResolver) that is the default resolver for the | ||||
| //!   `HttpConnector`. | ||||
| //! - The [`Resolve`](dns::Resolve) trait and related types to build a custom | ||||
| //!   resolver for use with the `HttpConnector`. | ||||
| //! - The `Name` type used as an argument to custom resolvers. | ||||
| //! | ||||
| //! # Resolvers are `Service`s | ||||
| //! | ||||
| //! A resolver is just a | ||||
| //! `Service<Name, Response = impl Iterator<Item = IpAddr>>`. | ||||
| //! | ||||
| //! A simple resolver that ignores the name and always returns a specific | ||||
| //! address: | ||||
| //! | ||||
| //! ```rust,ignore | ||||
| //! use std::{convert::Infallible, iter, net::IpAddr}; | ||||
| //! | ||||
| //! let resolver = tower::service_fn(|_name| async { | ||||
| //!     Ok::<_, Infallible>(iter::once(IpAddr::from([127, 0, 0, 1]))) | ||||
| //! }); | ||||
| //! ``` | ||||
| use std::{fmt, io, vec}; | ||||
| use std::error::Error; | ||||
| use std::net::{ | ||||
| @@ -15,19 +30,10 @@ use std::net::{ | ||||
| }; | ||||
| use std::str::FromStr; | ||||
|  | ||||
| use tokio_sync::{mpsc, oneshot}; | ||||
| use tower_service::Service; | ||||
| use crate::common::{Future, Pin, Poll, task}; | ||||
|  | ||||
| use crate::common::{Future, Never, Pin, Poll, task}; | ||||
|  | ||||
| /// Resolve a hostname to a set of IP addresses. | ||||
| pub trait Resolve { | ||||
|     /// The set of IP addresses to try to connect to. | ||||
|     type Addrs: Iterator<Item=IpAddr>; | ||||
|     /// A Future of the resolved set of addresses. | ||||
|     type Future: Future<Output=Result<Self::Addrs, io::Error>>; | ||||
|     /// Resolve a hostname. | ||||
|     fn resolve(&self, name: Name) -> Self::Future; | ||||
| } | ||||
| pub(super) use self::sealed::Resolve; | ||||
|  | ||||
| /// A domain name to resolve into IP addresses. | ||||
| #[derive(Clone, Hash, Eq, PartialEq)] | ||||
| @@ -41,15 +47,12 @@ pub struct GaiResolver { | ||||
|     _priv: (), | ||||
| } | ||||
|  | ||||
| #[derive(Clone)] | ||||
| struct ThreadPoolKeepAlive(mpsc::Sender<Never>); | ||||
|  | ||||
| /// An iterator of IP addresses returned from `getaddrinfo`. | ||||
| pub struct GaiAddrs { | ||||
|     inner: IpAddrs, | ||||
| } | ||||
|  | ||||
| /// A future to resole a name returned by `GaiResolver`. | ||||
| /// A future to resolve a name returned by `GaiResolver`. | ||||
| pub struct GaiFuture { | ||||
|     inner: tokio_executor::blocking::Blocking<Result<IpAddrs, io::Error>>, | ||||
| } | ||||
| @@ -110,11 +113,16 @@ impl GaiResolver { | ||||
|     } | ||||
| } | ||||
|  | ||||
| impl Resolve for GaiResolver { | ||||
|     type Addrs = GaiAddrs; | ||||
| impl Service<Name> for GaiResolver { | ||||
|     type Response = GaiAddrs; | ||||
|     type Error = io::Error; | ||||
|     type Future = GaiFuture; | ||||
|  | ||||
|     fn resolve(&self, name: Name) -> Self::Future { | ||||
|     fn poll_ready(&mut self, _cx: &mut task::Context<'_>) -> Poll<Result<(), io::Error>> { | ||||
|         Poll::Ready(Ok(())) | ||||
|     } | ||||
|  | ||||
|     fn call(&mut self, name: Name) -> Self::Future { | ||||
|         let blocking = tokio_executor::blocking::run(move || { | ||||
|             debug!("resolving host={:?}", name.host); | ||||
|             (&*name.host, 0).to_socket_addrs() | ||||
| @@ -164,39 +172,6 @@ impl fmt::Debug for GaiAddrs { | ||||
|     } | ||||
| } | ||||
|  | ||||
|  | ||||
| pub(super) struct GaiBlocking { | ||||
|     host: String, | ||||
|     tx: Option<oneshot::Sender<io::Result<IpAddrs>>>, | ||||
| } | ||||
|  | ||||
| impl GaiBlocking { | ||||
|     fn block(&self) -> io::Result<IpAddrs> { | ||||
|         debug!("resolving host={:?}", self.host); | ||||
|         (&*self.host, 0).to_socket_addrs() | ||||
|             .map(|i| IpAddrs { iter: i }) | ||||
|  | ||||
|     } | ||||
| } | ||||
|  | ||||
| impl Future for GaiBlocking { | ||||
|     type Output = (); | ||||
|  | ||||
|     fn poll(mut self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll<Self::Output> { | ||||
|         if self.tx.as_mut().expect("polled after complete").poll_closed(cx).is_ready() { | ||||
|             trace!("resolve future canceled for {:?}", self.host); | ||||
|             return Poll::Ready(()); | ||||
|         } | ||||
|  | ||||
|         let res = self.block(); | ||||
|  | ||||
|         let tx = self.tx.take().expect("polled after complete"); | ||||
|         let _ = tx.send(res); | ||||
|  | ||||
|         Poll::Ready(()) | ||||
|     } | ||||
| } | ||||
|  | ||||
| pub(super) struct IpAddrs { | ||||
|     iter: vec::IntoIter<SocketAddr>, | ||||
| } | ||||
| @@ -276,11 +251,16 @@ impl TokioThreadpoolGaiResolver { | ||||
| } | ||||
|  | ||||
| #[cfg(feature = "runtime")] | ||||
| impl Resolve for TokioThreadpoolGaiResolver { | ||||
|     type Addrs = GaiAddrs; | ||||
| impl Service<Name> for TokioThreadpoolGaiResolver { | ||||
|     type Response = GaiAddrs; | ||||
|     type Error = io::Error; | ||||
|     type Future = TokioThreadpoolGaiFuture; | ||||
|  | ||||
|     fn resolve(&self, name: Name) -> TokioThreadpoolGaiFuture { | ||||
|     fn poll_ready(&mut self, _cx: &mut task::Context<'_>) -> Poll<Result<(), io::Error>> { | ||||
|         Poll::Ready(Ok(())) | ||||
|     } | ||||
|  | ||||
|     fn call(&mut self, name: Name) -> Self::Future { | ||||
|         TokioThreadpoolGaiFuture { name } | ||||
|     } | ||||
| } | ||||
| @@ -299,6 +279,41 @@ impl Future for TokioThreadpoolGaiFuture { | ||||
|     } | ||||
| } | ||||
|  | ||||
| mod sealed { | ||||
|     use tower_service::Service; | ||||
|     use crate::common::{Future, Poll, task}; | ||||
|     use super::{IpAddr, Name}; | ||||
|  | ||||
|     // "Trait alias" for `Service<Name, Response = Addrs>` | ||||
|     pub trait Resolve { | ||||
|         type Addrs: Iterator<Item=IpAddr>; | ||||
|         type Error: Into<Box<dyn std::error::Error + Send + Sync>>; | ||||
|         type Future: Future<Output=Result<Self::Addrs, Self::Error>>; | ||||
|  | ||||
|         fn poll_ready(&mut self, cx: &mut task::Context<'_>) -> Poll<Result<(), Self::Error>>; | ||||
|         fn resolve(&mut self, name: Name) -> Self::Future; | ||||
|     } | ||||
|  | ||||
|     impl<S> Resolve for S | ||||
|     where | ||||
|         S: Service<Name>, | ||||
|         S::Response: Iterator<Item=IpAddr>, | ||||
|         S::Error: Into<Box<dyn std::error::Error + Send + Sync>>, | ||||
|     { | ||||
|         type Addrs = S::Response; | ||||
|         type Error = S::Error; | ||||
|         type Future = S::Future; | ||||
|  | ||||
|         fn poll_ready(&mut self, cx: &mut task::Context<'_>) -> Poll<Result<(), Self::Error>> { | ||||
|             Service::poll_ready(self, cx) | ||||
|         } | ||||
|  | ||||
|         fn resolve(&mut self, name: Name) -> Self::Future { | ||||
|             Service::call(self, name) | ||||
|         } | ||||
|     } | ||||
| } | ||||
|  | ||||
| #[cfg(test)] | ||||
| mod tests { | ||||
|     use std::net::{Ipv4Addr, Ipv6Addr}; | ||||
|   | ||||
| @@ -228,11 +228,18 @@ impl<R> HttpConnector<R> { | ||||
|     } | ||||
| } | ||||
|  | ||||
| static INVALID_NOT_HTTP: &str = "invalid URL, scheme is not http"; | ||||
| static INVALID_MISSING_SCHEME: &str = "invalid URL, scheme is missing"; | ||||
| static INVALID_MISSING_HOST: &str = "invalid URL, host is missing"; | ||||
|  | ||||
| impl<R: Resolve> HttpConnector<R> { | ||||
|     fn invalid_url(&self, err: InvalidUrl) -> HttpConnecting<R> { | ||||
|     fn invalid_url(&self, msg: impl Into<Box<str>>) -> HttpConnecting<R> { | ||||
|         HttpConnecting { | ||||
|             config: self.config.clone(), | ||||
|             state: State::Error(Some(io::Error::new(io::ErrorKind::InvalidInput, err))), | ||||
|             state: State::Error(Some(ConnectError { | ||||
|                 msg: msg.into(), | ||||
|                 cause: None, | ||||
|             })), | ||||
|             port: 0, | ||||
|         } | ||||
|     } | ||||
| @@ -252,14 +259,11 @@ where | ||||
|     R::Future: Send, | ||||
| { | ||||
|     type Response = (TcpStream, Connected); | ||||
|     type Error = io::Error; | ||||
|     type Error = ConnectError; | ||||
|     type Future = HttpConnecting<R>; | ||||
|  | ||||
|     fn poll_ready(&mut self, cx: &mut task::Context<'_>) -> Poll<Result<(), Self::Error>> { | ||||
|         // For now, always ready. | ||||
|         // TODO: When `Resolve` becomes an alias for `Service`, check | ||||
|         // the resolver's readiness. | ||||
|         drop(cx); | ||||
|         ready!(self.resolver.poll_ready(cx)).map_err(ConnectError::dns)?; | ||||
|         Poll::Ready(Ok(())) | ||||
|     } | ||||
|  | ||||
| @@ -273,15 +277,15 @@ where | ||||
|  | ||||
|         if self.config.enforce_http { | ||||
|             if dst.uri.scheme_part() != Some(&Scheme::HTTP) { | ||||
|                 return self.invalid_url(InvalidUrl::NotHttp); | ||||
|                 return self.invalid_url(INVALID_NOT_HTTP); | ||||
|             } | ||||
|         } else if dst.uri.scheme_part().is_none() { | ||||
|             return self.invalid_url(InvalidUrl::MissingScheme); | ||||
|             return self.invalid_url(INVALID_MISSING_SCHEME); | ||||
|         } | ||||
|  | ||||
|         let host = match dst.uri.host() { | ||||
|             Some(s) => s, | ||||
|             None => return self.invalid_url(InvalidUrl::MissingAuthority), | ||||
|             None => return self.invalid_url(INVALID_MISSING_HOST), | ||||
|         }; | ||||
|         let port = match dst.uri.port_part() { | ||||
|             Some(port) => port.as_u16(), | ||||
| @@ -302,7 +306,7 @@ where | ||||
|     R::Future: Send, | ||||
| { | ||||
|     type Response = TcpStream; | ||||
|     type Error = io::Error; | ||||
|     type Error = ConnectError; | ||||
|     type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send + 'static>>; | ||||
|  | ||||
|     fn poll_ready(&mut self, cx: &mut task::Context<'_>) -> Poll<Result<(), Self::Error>> { | ||||
| @@ -324,28 +328,73 @@ impl HttpInfo { | ||||
|     } | ||||
| } | ||||
|  | ||||
| #[derive(Debug, Clone, Copy)] | ||||
| enum InvalidUrl { | ||||
|     MissingScheme, | ||||
|     NotHttp, | ||||
|     MissingAuthority, | ||||
| // Not publicly exported (so missing_docs doesn't trigger). | ||||
| pub struct ConnectError { | ||||
|     msg: Box<str>, | ||||
|     cause: Option<Box<dyn StdError + Send + Sync>>, | ||||
| } | ||||
|  | ||||
| impl fmt::Display for InvalidUrl { | ||||
|     fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { | ||||
|         f.write_str(self.description()) | ||||
| impl ConnectError { | ||||
|     fn new<S, E>(msg: S, cause: E) -> ConnectError | ||||
|     where | ||||
|         S: Into<Box<str>>, | ||||
|         E: Into<Box<dyn StdError + Send + Sync>>, | ||||
|     { | ||||
|         ConnectError { | ||||
|             msg: msg.into(), | ||||
|             cause: Some(cause.into()), | ||||
|         } | ||||
|     } | ||||
| } | ||||
|  | ||||
| impl StdError for InvalidUrl { | ||||
|     fn description(&self) -> &str { | ||||
|         match *self { | ||||
|             InvalidUrl::MissingScheme => "invalid URL, missing scheme", | ||||
|             InvalidUrl::NotHttp => "invalid URL, scheme must be http", | ||||
|             InvalidUrl::MissingAuthority => "invalid URL, missing domain", | ||||
|     fn dns<E>(cause: E) -> ConnectError | ||||
|     where | ||||
|         E: Into<Box<dyn StdError + Send + Sync>>, | ||||
|     { | ||||
|         ConnectError::new("dns error", cause) | ||||
|     } | ||||
|  | ||||
|     fn m<S, E>(msg: S) -> impl FnOnce(E) -> ConnectError | ||||
|     where | ||||
|         S: Into<Box<str>>, | ||||
|         E: Into<Box<dyn StdError + Send + Sync>>, | ||||
|     { | ||||
|         move |cause| { | ||||
|             ConnectError::new(msg, cause) | ||||
|         } | ||||
|     } | ||||
| } | ||||
|  | ||||
| impl fmt::Debug for ConnectError { | ||||
|     fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { | ||||
|         if let Some(ref cause) = self.cause { | ||||
|             f.debug_tuple("ConnectError") | ||||
|                 .field(&self.msg) | ||||
|                 .field(cause) | ||||
|                 .finish() | ||||
|         } else { | ||||
|             self.msg.fmt(f) | ||||
|         } | ||||
|     } | ||||
| } | ||||
|  | ||||
| impl fmt::Display for ConnectError { | ||||
|     fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { | ||||
|         f.write_str(&self.msg)?; | ||||
|  | ||||
|         if let Some(ref cause) = self.cause { | ||||
|             write!(f, ": {}", cause)?; | ||||
|         } | ||||
|  | ||||
|         Ok(()) | ||||
|     } | ||||
| } | ||||
|  | ||||
| impl StdError for ConnectError { | ||||
|     fn source(&self) -> Option<&(dyn StdError + 'static)> { | ||||
|         self.cause.as_ref().map(|e| &**e as _) | ||||
|     } | ||||
| } | ||||
|  | ||||
| /// A Future representing work to connect to a URL. | ||||
| #[must_use = "futures do nothing unless polled"] | ||||
| #[pin_project] | ||||
| @@ -361,11 +410,11 @@ enum State<R: Resolve> { | ||||
|     Lazy(R, String), | ||||
|     Resolving(#[pin] R::Future), | ||||
|     Connecting(ConnectingTcp), | ||||
|     Error(Option<io::Error>), | ||||
|     Error(Option<ConnectError>), | ||||
| } | ||||
|  | ||||
| impl<R: Resolve> Future for HttpConnecting<R> { | ||||
|     type Output = Result<(TcpStream, Connected), io::Error>; | ||||
|     type Output = Result<(TcpStream, Connected), ConnectError>; | ||||
|  | ||||
|     #[project] | ||||
|     fn poll(self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll<Self::Output> { | ||||
| @@ -375,19 +424,20 @@ impl<R: Resolve> Future for HttpConnecting<R> { | ||||
|             let state; | ||||
|             #[project] | ||||
|             match me.state.as_mut().project() { | ||||
|                 State::Lazy(ref resolver, ref mut host) => { | ||||
|                 State::Lazy(ref mut resolver, ref mut host) => { | ||||
|                     // If the host is already an IP addr (v4 or v6), | ||||
|                     // skip resolving the dns and start connecting right away. | ||||
|                     if let Some(addrs) = dns::IpAddrs::try_parse(host, *me.port) { | ||||
|                         state = State::Connecting(ConnectingTcp::new( | ||||
|                             config.local_address, addrs, config.connect_timeout, config.happy_eyeballs_timeout, config.reuse_address)); | ||||
|                     } else { | ||||
|                         ready!(resolver.poll_ready(cx)).map_err(ConnectError::dns)?; | ||||
|                         let name = dns::Name::new(mem::replace(host, String::new())); | ||||
|                         state = State::Resolving(resolver.resolve(name)); | ||||
|                     } | ||||
|                 }, | ||||
|                 State::Resolving(future) => { | ||||
|                     let addrs =  ready!(future.poll(cx))?; | ||||
|                     let addrs =  ready!(future.poll(cx)).map_err(ConnectError::dns)?; | ||||
|                     let port = *me.port; | ||||
|                     let addrs = addrs | ||||
|                         .map(|addr| SocketAddr::new(addr, port)) | ||||
| @@ -397,24 +447,25 @@ impl<R: Resolve> Future for HttpConnecting<R> { | ||||
|                         config.local_address, addrs, config.connect_timeout, config.happy_eyeballs_timeout, config.reuse_address)); | ||||
|                 }, | ||||
|                 State::Connecting(ref mut c) => { | ||||
|                     let sock = ready!(c.poll(cx, &config.handle))?; | ||||
|                     let sock = ready!(c.poll(cx, &config.handle)) | ||||
|                         .map_err(ConnectError::m("tcp connect error"))?; | ||||
|  | ||||
|                     if let Some(dur) = config.keep_alive_timeout { | ||||
|                         sock.set_keepalive(Some(dur))?; | ||||
|                         sock.set_keepalive(Some(dur)).map_err(ConnectError::m("tcp set_keepalive error"))?; | ||||
|                     } | ||||
|  | ||||
|                     if let Some(size) = config.send_buffer_size { | ||||
|                         sock.set_send_buffer_size(size)?; | ||||
|                         sock.set_send_buffer_size(size).map_err(ConnectError::m("tcp set_send_buffer_size error"))?; | ||||
|                     } | ||||
|  | ||||
|                     if let Some(size) = config.recv_buffer_size { | ||||
|                         sock.set_recv_buffer_size(size)?; | ||||
|                         sock.set_recv_buffer_size(size).map_err(ConnectError::m("tcp set_recv_buffer_size error"))?; | ||||
|                     } | ||||
|  | ||||
|                     sock.set_nodelay(config.nodelay)?; | ||||
|                     sock.set_nodelay(config.nodelay).map_err(ConnectError::m("tcp set_nodelay error"))?; | ||||
|  | ||||
|                     let extra = HttpInfo { | ||||
|                         remote_addr: sock.peer_addr()?, | ||||
|                         remote_addr: sock.peer_addr().map_err(ConnectError::m("tcp peer_addr error"))?, | ||||
|                     }; | ||||
|                     let connected = Connected::new() | ||||
|                         .extra(extra); | ||||
| @@ -642,7 +693,6 @@ impl ConnectingTcp { | ||||
| mod tests { | ||||
|     use std::io; | ||||
|  | ||||
|     use tokio::runtime::current_thread::Runtime; | ||||
|     use tokio_net::driver::Handle; | ||||
|  | ||||
|     use super::{Connected, Destination, HttpConnector}; | ||||
| @@ -655,55 +705,29 @@ mod tests { | ||||
|         connector.connect(super::super::sealed::Internal, dst).await | ||||
|     } | ||||
|  | ||||
|     #[test] | ||||
|     fn test_errors_missing_authority() { | ||||
|         let mut rt = Runtime::new().unwrap(); | ||||
|         let uri = "/foo/bar?baz".parse().unwrap(); | ||||
|         let dst = Destination { | ||||
|             uri, | ||||
|         }; | ||||
|         let connector = HttpConnector::new(); | ||||
|  | ||||
|         rt.block_on(async { | ||||
|             assert_eq!( | ||||
|                 connect(connector, dst).await.unwrap_err().kind(), | ||||
|                 io::ErrorKind::InvalidInput, | ||||
|             ); | ||||
|         }) | ||||
|     } | ||||
|  | ||||
|     #[test] | ||||
|     fn test_errors_enforce_http() { | ||||
|         let mut rt = Runtime::new().unwrap(); | ||||
|     #[tokio::test] | ||||
|     async fn test_errors_enforce_http() { | ||||
|         let uri = "https://example.domain/foo/bar?baz".parse().unwrap(); | ||||
|         let dst = Destination { | ||||
|             uri, | ||||
|         }; | ||||
|         let connector = HttpConnector::new(); | ||||
|  | ||||
|         rt.block_on(async { | ||||
|             assert_eq!( | ||||
|                 connect(connector, dst).await.unwrap_err().kind(), | ||||
|                 io::ErrorKind::InvalidInput, | ||||
|             ); | ||||
|         }) | ||||
|         let err = connect(connector, dst).await.unwrap_err(); | ||||
|         assert_eq!(&*err.msg, super::INVALID_NOT_HTTP); | ||||
|     } | ||||
|  | ||||
|     #[test] | ||||
|     fn test_errors_missing_scheme() { | ||||
|         let mut rt = Runtime::new().unwrap(); | ||||
|     #[tokio::test] | ||||
|     async fn test_errors_missing_scheme() { | ||||
|         let uri = "example.domain".parse().unwrap(); | ||||
|         let dst = Destination { | ||||
|             uri, | ||||
|         }; | ||||
|         let connector = HttpConnector::new(); | ||||
|         let mut connector = HttpConnector::new(); | ||||
|         connector.enforce_http(false); | ||||
|  | ||||
|         rt.block_on(async { | ||||
|             assert_eq!( | ||||
|                 connect(connector, dst).await.unwrap_err().kind(), | ||||
|                 io::ErrorKind::InvalidInput, | ||||
|             ); | ||||
|         }); | ||||
|         let err = connect(connector, dst).await.unwrap_err(); | ||||
|         assert_eq!(&*err.msg, super::INVALID_MISSING_SCHEME); | ||||
|     } | ||||
|  | ||||
|     #[test] | ||||
|   | ||||
		Reference in New Issue
	
	Block a user