feat(client): change Resolve to be Service<Name>
				
					
				
			Closes #1903 BREAKING CHANGE: The `Resolve` trait is gone. All custom resolves should implement `tower::Service` instead. The error type of `HttpConnector` has been changed away from `std::io::Error`.
This commit is contained in:
		| @@ -228,11 +228,18 @@ impl<R> HttpConnector<R> { | ||||
|     } | ||||
| } | ||||
|  | ||||
| static INVALID_NOT_HTTP: &str = "invalid URL, scheme is not http"; | ||||
| static INVALID_MISSING_SCHEME: &str = "invalid URL, scheme is missing"; | ||||
| static INVALID_MISSING_HOST: &str = "invalid URL, host is missing"; | ||||
|  | ||||
| impl<R: Resolve> HttpConnector<R> { | ||||
|     fn invalid_url(&self, err: InvalidUrl) -> HttpConnecting<R> { | ||||
|     fn invalid_url(&self, msg: impl Into<Box<str>>) -> HttpConnecting<R> { | ||||
|         HttpConnecting { | ||||
|             config: self.config.clone(), | ||||
|             state: State::Error(Some(io::Error::new(io::ErrorKind::InvalidInput, err))), | ||||
|             state: State::Error(Some(ConnectError { | ||||
|                 msg: msg.into(), | ||||
|                 cause: None, | ||||
|             })), | ||||
|             port: 0, | ||||
|         } | ||||
|     } | ||||
| @@ -252,14 +259,11 @@ where | ||||
|     R::Future: Send, | ||||
| { | ||||
|     type Response = (TcpStream, Connected); | ||||
|     type Error = io::Error; | ||||
|     type Error = ConnectError; | ||||
|     type Future = HttpConnecting<R>; | ||||
|  | ||||
|     fn poll_ready(&mut self, cx: &mut task::Context<'_>) -> Poll<Result<(), Self::Error>> { | ||||
|         // For now, always ready. | ||||
|         // TODO: When `Resolve` becomes an alias for `Service`, check | ||||
|         // the resolver's readiness. | ||||
|         drop(cx); | ||||
|         ready!(self.resolver.poll_ready(cx)).map_err(ConnectError::dns)?; | ||||
|         Poll::Ready(Ok(())) | ||||
|     } | ||||
|  | ||||
| @@ -273,15 +277,15 @@ where | ||||
|  | ||||
|         if self.config.enforce_http { | ||||
|             if dst.uri.scheme_part() != Some(&Scheme::HTTP) { | ||||
|                 return self.invalid_url(InvalidUrl::NotHttp); | ||||
|                 return self.invalid_url(INVALID_NOT_HTTP); | ||||
|             } | ||||
|         } else if dst.uri.scheme_part().is_none() { | ||||
|             return self.invalid_url(InvalidUrl::MissingScheme); | ||||
|             return self.invalid_url(INVALID_MISSING_SCHEME); | ||||
|         } | ||||
|  | ||||
|         let host = match dst.uri.host() { | ||||
|             Some(s) => s, | ||||
|             None => return self.invalid_url(InvalidUrl::MissingAuthority), | ||||
|             None => return self.invalid_url(INVALID_MISSING_HOST), | ||||
|         }; | ||||
|         let port = match dst.uri.port_part() { | ||||
|             Some(port) => port.as_u16(), | ||||
| @@ -302,7 +306,7 @@ where | ||||
|     R::Future: Send, | ||||
| { | ||||
|     type Response = TcpStream; | ||||
|     type Error = io::Error; | ||||
|     type Error = ConnectError; | ||||
|     type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send + 'static>>; | ||||
|  | ||||
|     fn poll_ready(&mut self, cx: &mut task::Context<'_>) -> Poll<Result<(), Self::Error>> { | ||||
| @@ -324,28 +328,73 @@ impl HttpInfo { | ||||
|     } | ||||
| } | ||||
|  | ||||
| #[derive(Debug, Clone, Copy)] | ||||
| enum InvalidUrl { | ||||
|     MissingScheme, | ||||
|     NotHttp, | ||||
|     MissingAuthority, | ||||
| // Not publicly exported (so missing_docs doesn't trigger). | ||||
| pub struct ConnectError { | ||||
|     msg: Box<str>, | ||||
|     cause: Option<Box<dyn StdError + Send + Sync>>, | ||||
| } | ||||
|  | ||||
| impl fmt::Display for InvalidUrl { | ||||
|     fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { | ||||
|         f.write_str(self.description()) | ||||
| impl ConnectError { | ||||
|     fn new<S, E>(msg: S, cause: E) -> ConnectError | ||||
|     where | ||||
|         S: Into<Box<str>>, | ||||
|         E: Into<Box<dyn StdError + Send + Sync>>, | ||||
|     { | ||||
|         ConnectError { | ||||
|             msg: msg.into(), | ||||
|             cause: Some(cause.into()), | ||||
|         } | ||||
|     } | ||||
| } | ||||
|  | ||||
| impl StdError for InvalidUrl { | ||||
|     fn description(&self) -> &str { | ||||
|         match *self { | ||||
|             InvalidUrl::MissingScheme => "invalid URL, missing scheme", | ||||
|             InvalidUrl::NotHttp => "invalid URL, scheme must be http", | ||||
|             InvalidUrl::MissingAuthority => "invalid URL, missing domain", | ||||
|     fn dns<E>(cause: E) -> ConnectError | ||||
|     where | ||||
|         E: Into<Box<dyn StdError + Send + Sync>>, | ||||
|     { | ||||
|         ConnectError::new("dns error", cause) | ||||
|     } | ||||
|  | ||||
|     fn m<S, E>(msg: S) -> impl FnOnce(E) -> ConnectError | ||||
|     where | ||||
|         S: Into<Box<str>>, | ||||
|         E: Into<Box<dyn StdError + Send + Sync>>, | ||||
|     { | ||||
|         move |cause| { | ||||
|             ConnectError::new(msg, cause) | ||||
|         } | ||||
|     } | ||||
| } | ||||
|  | ||||
| impl fmt::Debug for ConnectError { | ||||
|     fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { | ||||
|         if let Some(ref cause) = self.cause { | ||||
|             f.debug_tuple("ConnectError") | ||||
|                 .field(&self.msg) | ||||
|                 .field(cause) | ||||
|                 .finish() | ||||
|         } else { | ||||
|             self.msg.fmt(f) | ||||
|         } | ||||
|     } | ||||
| } | ||||
|  | ||||
| impl fmt::Display for ConnectError { | ||||
|     fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { | ||||
|         f.write_str(&self.msg)?; | ||||
|  | ||||
|         if let Some(ref cause) = self.cause { | ||||
|             write!(f, ": {}", cause)?; | ||||
|         } | ||||
|  | ||||
|         Ok(()) | ||||
|     } | ||||
| } | ||||
|  | ||||
| impl StdError for ConnectError { | ||||
|     fn source(&self) -> Option<&(dyn StdError + 'static)> { | ||||
|         self.cause.as_ref().map(|e| &**e as _) | ||||
|     } | ||||
| } | ||||
|  | ||||
| /// A Future representing work to connect to a URL. | ||||
| #[must_use = "futures do nothing unless polled"] | ||||
| #[pin_project] | ||||
| @@ -361,11 +410,11 @@ enum State<R: Resolve> { | ||||
|     Lazy(R, String), | ||||
|     Resolving(#[pin] R::Future), | ||||
|     Connecting(ConnectingTcp), | ||||
|     Error(Option<io::Error>), | ||||
|     Error(Option<ConnectError>), | ||||
| } | ||||
|  | ||||
| impl<R: Resolve> Future for HttpConnecting<R> { | ||||
|     type Output = Result<(TcpStream, Connected), io::Error>; | ||||
|     type Output = Result<(TcpStream, Connected), ConnectError>; | ||||
|  | ||||
|     #[project] | ||||
|     fn poll(self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll<Self::Output> { | ||||
| @@ -375,19 +424,20 @@ impl<R: Resolve> Future for HttpConnecting<R> { | ||||
|             let state; | ||||
|             #[project] | ||||
|             match me.state.as_mut().project() { | ||||
|                 State::Lazy(ref resolver, ref mut host) => { | ||||
|                 State::Lazy(ref mut resolver, ref mut host) => { | ||||
|                     // If the host is already an IP addr (v4 or v6), | ||||
|                     // skip resolving the dns and start connecting right away. | ||||
|                     if let Some(addrs) = dns::IpAddrs::try_parse(host, *me.port) { | ||||
|                         state = State::Connecting(ConnectingTcp::new( | ||||
|                             config.local_address, addrs, config.connect_timeout, config.happy_eyeballs_timeout, config.reuse_address)); | ||||
|                     } else { | ||||
|                         ready!(resolver.poll_ready(cx)).map_err(ConnectError::dns)?; | ||||
|                         let name = dns::Name::new(mem::replace(host, String::new())); | ||||
|                         state = State::Resolving(resolver.resolve(name)); | ||||
|                     } | ||||
|                 }, | ||||
|                 State::Resolving(future) => { | ||||
|                     let addrs =  ready!(future.poll(cx))?; | ||||
|                     let addrs =  ready!(future.poll(cx)).map_err(ConnectError::dns)?; | ||||
|                     let port = *me.port; | ||||
|                     let addrs = addrs | ||||
|                         .map(|addr| SocketAddr::new(addr, port)) | ||||
| @@ -397,24 +447,25 @@ impl<R: Resolve> Future for HttpConnecting<R> { | ||||
|                         config.local_address, addrs, config.connect_timeout, config.happy_eyeballs_timeout, config.reuse_address)); | ||||
|                 }, | ||||
|                 State::Connecting(ref mut c) => { | ||||
|                     let sock = ready!(c.poll(cx, &config.handle))?; | ||||
|                     let sock = ready!(c.poll(cx, &config.handle)) | ||||
|                         .map_err(ConnectError::m("tcp connect error"))?; | ||||
|  | ||||
|                     if let Some(dur) = config.keep_alive_timeout { | ||||
|                         sock.set_keepalive(Some(dur))?; | ||||
|                         sock.set_keepalive(Some(dur)).map_err(ConnectError::m("tcp set_keepalive error"))?; | ||||
|                     } | ||||
|  | ||||
|                     if let Some(size) = config.send_buffer_size { | ||||
|                         sock.set_send_buffer_size(size)?; | ||||
|                         sock.set_send_buffer_size(size).map_err(ConnectError::m("tcp set_send_buffer_size error"))?; | ||||
|                     } | ||||
|  | ||||
|                     if let Some(size) = config.recv_buffer_size { | ||||
|                         sock.set_recv_buffer_size(size)?; | ||||
|                         sock.set_recv_buffer_size(size).map_err(ConnectError::m("tcp set_recv_buffer_size error"))?; | ||||
|                     } | ||||
|  | ||||
|                     sock.set_nodelay(config.nodelay)?; | ||||
|                     sock.set_nodelay(config.nodelay).map_err(ConnectError::m("tcp set_nodelay error"))?; | ||||
|  | ||||
|                     let extra = HttpInfo { | ||||
|                         remote_addr: sock.peer_addr()?, | ||||
|                         remote_addr: sock.peer_addr().map_err(ConnectError::m("tcp peer_addr error"))?, | ||||
|                     }; | ||||
|                     let connected = Connected::new() | ||||
|                         .extra(extra); | ||||
| @@ -642,7 +693,6 @@ impl ConnectingTcp { | ||||
| mod tests { | ||||
|     use std::io; | ||||
|  | ||||
|     use tokio::runtime::current_thread::Runtime; | ||||
|     use tokio_net::driver::Handle; | ||||
|  | ||||
|     use super::{Connected, Destination, HttpConnector}; | ||||
| @@ -655,55 +705,29 @@ mod tests { | ||||
|         connector.connect(super::super::sealed::Internal, dst).await | ||||
|     } | ||||
|  | ||||
|     #[test] | ||||
|     fn test_errors_missing_authority() { | ||||
|         let mut rt = Runtime::new().unwrap(); | ||||
|         let uri = "/foo/bar?baz".parse().unwrap(); | ||||
|         let dst = Destination { | ||||
|             uri, | ||||
|         }; | ||||
|         let connector = HttpConnector::new(); | ||||
|  | ||||
|         rt.block_on(async { | ||||
|             assert_eq!( | ||||
|                 connect(connector, dst).await.unwrap_err().kind(), | ||||
|                 io::ErrorKind::InvalidInput, | ||||
|             ); | ||||
|         }) | ||||
|     } | ||||
|  | ||||
|     #[test] | ||||
|     fn test_errors_enforce_http() { | ||||
|         let mut rt = Runtime::new().unwrap(); | ||||
|     #[tokio::test] | ||||
|     async fn test_errors_enforce_http() { | ||||
|         let uri = "https://example.domain/foo/bar?baz".parse().unwrap(); | ||||
|         let dst = Destination { | ||||
|             uri, | ||||
|         }; | ||||
|         let connector = HttpConnector::new(); | ||||
|  | ||||
|         rt.block_on(async { | ||||
|             assert_eq!( | ||||
|                 connect(connector, dst).await.unwrap_err().kind(), | ||||
|                 io::ErrorKind::InvalidInput, | ||||
|             ); | ||||
|         }) | ||||
|         let err = connect(connector, dst).await.unwrap_err(); | ||||
|         assert_eq!(&*err.msg, super::INVALID_NOT_HTTP); | ||||
|     } | ||||
|  | ||||
|     #[test] | ||||
|     fn test_errors_missing_scheme() { | ||||
|         let mut rt = Runtime::new().unwrap(); | ||||
|     #[tokio::test] | ||||
|     async fn test_errors_missing_scheme() { | ||||
|         let uri = "example.domain".parse().unwrap(); | ||||
|         let dst = Destination { | ||||
|             uri, | ||||
|         }; | ||||
|         let connector = HttpConnector::new(); | ||||
|         let mut connector = HttpConnector::new(); | ||||
|         connector.enforce_http(false); | ||||
|  | ||||
|         rt.block_on(async { | ||||
|             assert_eq!( | ||||
|                 connect(connector, dst).await.unwrap_err().kind(), | ||||
|                 io::ErrorKind::InvalidInput, | ||||
|             ); | ||||
|         }); | ||||
|         let err = connect(connector, dst).await.unwrap_err(); | ||||
|         assert_eq!(&*err.msg, super::INVALID_MISSING_SCHEME); | ||||
|     } | ||||
|  | ||||
|     #[test] | ||||
|   | ||||
		Reference in New Issue
	
	Block a user