feat(client): change Resolve to be Service<Name>
Closes #1903 BREAKING CHANGE: The `Resolve` trait is gone. All custom resolves should implement `tower::Service` instead. The error type of `HttpConnector` has been changed away from `std::io::Error`.
This commit is contained in:
@@ -1,11 +1,26 @@
|
||||
//! The `Resolve` trait, support types, and some basic implementations.
|
||||
//! DNS Resolution used by the `HttpConnector`.
|
||||
//!
|
||||
//! This module contains:
|
||||
//!
|
||||
//! - A [`GaiResolver`](dns::GaiResolver) that is the default resolver for the
|
||||
//! `HttpConnector`.
|
||||
//! - The [`Resolve`](dns::Resolve) trait and related types to build a custom
|
||||
//! resolver for use with the `HttpConnector`.
|
||||
//! - The `Name` type used as an argument to custom resolvers.
|
||||
//!
|
||||
//! # Resolvers are `Service`s
|
||||
//!
|
||||
//! A resolver is just a
|
||||
//! `Service<Name, Response = impl Iterator<Item = IpAddr>>`.
|
||||
//!
|
||||
//! A simple resolver that ignores the name and always returns a specific
|
||||
//! address:
|
||||
//!
|
||||
//! ```rust,ignore
|
||||
//! use std::{convert::Infallible, iter, net::IpAddr};
|
||||
//!
|
||||
//! let resolver = tower::service_fn(|_name| async {
|
||||
//! Ok::<_, Infallible>(iter::once(IpAddr::from([127, 0, 0, 1])))
|
||||
//! });
|
||||
//! ```
|
||||
use std::{fmt, io, vec};
|
||||
use std::error::Error;
|
||||
use std::net::{
|
||||
@@ -15,19 +30,10 @@ use std::net::{
|
||||
};
|
||||
use std::str::FromStr;
|
||||
|
||||
use tokio_sync::{mpsc, oneshot};
|
||||
use tower_service::Service;
|
||||
use crate::common::{Future, Pin, Poll, task};
|
||||
|
||||
use crate::common::{Future, Never, Pin, Poll, task};
|
||||
|
||||
/// Resolve a hostname to a set of IP addresses.
|
||||
pub trait Resolve {
|
||||
/// The set of IP addresses to try to connect to.
|
||||
type Addrs: Iterator<Item=IpAddr>;
|
||||
/// A Future of the resolved set of addresses.
|
||||
type Future: Future<Output=Result<Self::Addrs, io::Error>>;
|
||||
/// Resolve a hostname.
|
||||
fn resolve(&self, name: Name) -> Self::Future;
|
||||
}
|
||||
pub(super) use self::sealed::Resolve;
|
||||
|
||||
/// A domain name to resolve into IP addresses.
|
||||
#[derive(Clone, Hash, Eq, PartialEq)]
|
||||
@@ -41,15 +47,12 @@ pub struct GaiResolver {
|
||||
_priv: (),
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
struct ThreadPoolKeepAlive(mpsc::Sender<Never>);
|
||||
|
||||
/// An iterator of IP addresses returned from `getaddrinfo`.
|
||||
pub struct GaiAddrs {
|
||||
inner: IpAddrs,
|
||||
}
|
||||
|
||||
/// A future to resole a name returned by `GaiResolver`.
|
||||
/// A future to resolve a name returned by `GaiResolver`.
|
||||
pub struct GaiFuture {
|
||||
inner: tokio_executor::blocking::Blocking<Result<IpAddrs, io::Error>>,
|
||||
}
|
||||
@@ -110,11 +113,16 @@ impl GaiResolver {
|
||||
}
|
||||
}
|
||||
|
||||
impl Resolve for GaiResolver {
|
||||
type Addrs = GaiAddrs;
|
||||
impl Service<Name> for GaiResolver {
|
||||
type Response = GaiAddrs;
|
||||
type Error = io::Error;
|
||||
type Future = GaiFuture;
|
||||
|
||||
fn resolve(&self, name: Name) -> Self::Future {
|
||||
fn poll_ready(&mut self, _cx: &mut task::Context<'_>) -> Poll<Result<(), io::Error>> {
|
||||
Poll::Ready(Ok(()))
|
||||
}
|
||||
|
||||
fn call(&mut self, name: Name) -> Self::Future {
|
||||
let blocking = tokio_executor::blocking::run(move || {
|
||||
debug!("resolving host={:?}", name.host);
|
||||
(&*name.host, 0).to_socket_addrs()
|
||||
@@ -164,39 +172,6 @@ impl fmt::Debug for GaiAddrs {
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
pub(super) struct GaiBlocking {
|
||||
host: String,
|
||||
tx: Option<oneshot::Sender<io::Result<IpAddrs>>>,
|
||||
}
|
||||
|
||||
impl GaiBlocking {
|
||||
fn block(&self) -> io::Result<IpAddrs> {
|
||||
debug!("resolving host={:?}", self.host);
|
||||
(&*self.host, 0).to_socket_addrs()
|
||||
.map(|i| IpAddrs { iter: i })
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
impl Future for GaiBlocking {
|
||||
type Output = ();
|
||||
|
||||
fn poll(mut self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll<Self::Output> {
|
||||
if self.tx.as_mut().expect("polled after complete").poll_closed(cx).is_ready() {
|
||||
trace!("resolve future canceled for {:?}", self.host);
|
||||
return Poll::Ready(());
|
||||
}
|
||||
|
||||
let res = self.block();
|
||||
|
||||
let tx = self.tx.take().expect("polled after complete");
|
||||
let _ = tx.send(res);
|
||||
|
||||
Poll::Ready(())
|
||||
}
|
||||
}
|
||||
|
||||
pub(super) struct IpAddrs {
|
||||
iter: vec::IntoIter<SocketAddr>,
|
||||
}
|
||||
@@ -276,11 +251,16 @@ impl TokioThreadpoolGaiResolver {
|
||||
}
|
||||
|
||||
#[cfg(feature = "runtime")]
|
||||
impl Resolve for TokioThreadpoolGaiResolver {
|
||||
type Addrs = GaiAddrs;
|
||||
impl Service<Name> for TokioThreadpoolGaiResolver {
|
||||
type Response = GaiAddrs;
|
||||
type Error = io::Error;
|
||||
type Future = TokioThreadpoolGaiFuture;
|
||||
|
||||
fn resolve(&self, name: Name) -> TokioThreadpoolGaiFuture {
|
||||
fn poll_ready(&mut self, _cx: &mut task::Context<'_>) -> Poll<Result<(), io::Error>> {
|
||||
Poll::Ready(Ok(()))
|
||||
}
|
||||
|
||||
fn call(&mut self, name: Name) -> Self::Future {
|
||||
TokioThreadpoolGaiFuture { name }
|
||||
}
|
||||
}
|
||||
@@ -299,6 +279,41 @@ impl Future for TokioThreadpoolGaiFuture {
|
||||
}
|
||||
}
|
||||
|
||||
mod sealed {
|
||||
use tower_service::Service;
|
||||
use crate::common::{Future, Poll, task};
|
||||
use super::{IpAddr, Name};
|
||||
|
||||
// "Trait alias" for `Service<Name, Response = Addrs>`
|
||||
pub trait Resolve {
|
||||
type Addrs: Iterator<Item=IpAddr>;
|
||||
type Error: Into<Box<dyn std::error::Error + Send + Sync>>;
|
||||
type Future: Future<Output=Result<Self::Addrs, Self::Error>>;
|
||||
|
||||
fn poll_ready(&mut self, cx: &mut task::Context<'_>) -> Poll<Result<(), Self::Error>>;
|
||||
fn resolve(&mut self, name: Name) -> Self::Future;
|
||||
}
|
||||
|
||||
impl<S> Resolve for S
|
||||
where
|
||||
S: Service<Name>,
|
||||
S::Response: Iterator<Item=IpAddr>,
|
||||
S::Error: Into<Box<dyn std::error::Error + Send + Sync>>,
|
||||
{
|
||||
type Addrs = S::Response;
|
||||
type Error = S::Error;
|
||||
type Future = S::Future;
|
||||
|
||||
fn poll_ready(&mut self, cx: &mut task::Context<'_>) -> Poll<Result<(), Self::Error>> {
|
||||
Service::poll_ready(self, cx)
|
||||
}
|
||||
|
||||
fn resolve(&mut self, name: Name) -> Self::Future {
|
||||
Service::call(self, name)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use std::net::{Ipv4Addr, Ipv6Addr};
|
||||
|
||||
@@ -228,11 +228,18 @@ impl<R> HttpConnector<R> {
|
||||
}
|
||||
}
|
||||
|
||||
static INVALID_NOT_HTTP: &str = "invalid URL, scheme is not http";
|
||||
static INVALID_MISSING_SCHEME: &str = "invalid URL, scheme is missing";
|
||||
static INVALID_MISSING_HOST: &str = "invalid URL, host is missing";
|
||||
|
||||
impl<R: Resolve> HttpConnector<R> {
|
||||
fn invalid_url(&self, err: InvalidUrl) -> HttpConnecting<R> {
|
||||
fn invalid_url(&self, msg: impl Into<Box<str>>) -> HttpConnecting<R> {
|
||||
HttpConnecting {
|
||||
config: self.config.clone(),
|
||||
state: State::Error(Some(io::Error::new(io::ErrorKind::InvalidInput, err))),
|
||||
state: State::Error(Some(ConnectError {
|
||||
msg: msg.into(),
|
||||
cause: None,
|
||||
})),
|
||||
port: 0,
|
||||
}
|
||||
}
|
||||
@@ -252,14 +259,11 @@ where
|
||||
R::Future: Send,
|
||||
{
|
||||
type Response = (TcpStream, Connected);
|
||||
type Error = io::Error;
|
||||
type Error = ConnectError;
|
||||
type Future = HttpConnecting<R>;
|
||||
|
||||
fn poll_ready(&mut self, cx: &mut task::Context<'_>) -> Poll<Result<(), Self::Error>> {
|
||||
// For now, always ready.
|
||||
// TODO: When `Resolve` becomes an alias for `Service`, check
|
||||
// the resolver's readiness.
|
||||
drop(cx);
|
||||
ready!(self.resolver.poll_ready(cx)).map_err(ConnectError::dns)?;
|
||||
Poll::Ready(Ok(()))
|
||||
}
|
||||
|
||||
@@ -273,15 +277,15 @@ where
|
||||
|
||||
if self.config.enforce_http {
|
||||
if dst.uri.scheme_part() != Some(&Scheme::HTTP) {
|
||||
return self.invalid_url(InvalidUrl::NotHttp);
|
||||
return self.invalid_url(INVALID_NOT_HTTP);
|
||||
}
|
||||
} else if dst.uri.scheme_part().is_none() {
|
||||
return self.invalid_url(InvalidUrl::MissingScheme);
|
||||
return self.invalid_url(INVALID_MISSING_SCHEME);
|
||||
}
|
||||
|
||||
let host = match dst.uri.host() {
|
||||
Some(s) => s,
|
||||
None => return self.invalid_url(InvalidUrl::MissingAuthority),
|
||||
None => return self.invalid_url(INVALID_MISSING_HOST),
|
||||
};
|
||||
let port = match dst.uri.port_part() {
|
||||
Some(port) => port.as_u16(),
|
||||
@@ -302,7 +306,7 @@ where
|
||||
R::Future: Send,
|
||||
{
|
||||
type Response = TcpStream;
|
||||
type Error = io::Error;
|
||||
type Error = ConnectError;
|
||||
type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send + 'static>>;
|
||||
|
||||
fn poll_ready(&mut self, cx: &mut task::Context<'_>) -> Poll<Result<(), Self::Error>> {
|
||||
@@ -324,28 +328,73 @@ impl HttpInfo {
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Copy)]
|
||||
enum InvalidUrl {
|
||||
MissingScheme,
|
||||
NotHttp,
|
||||
MissingAuthority,
|
||||
// Not publicly exported (so missing_docs doesn't trigger).
|
||||
pub struct ConnectError {
|
||||
msg: Box<str>,
|
||||
cause: Option<Box<dyn StdError + Send + Sync>>,
|
||||
}
|
||||
|
||||
impl fmt::Display for InvalidUrl {
|
||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
f.write_str(self.description())
|
||||
impl ConnectError {
|
||||
fn new<S, E>(msg: S, cause: E) -> ConnectError
|
||||
where
|
||||
S: Into<Box<str>>,
|
||||
E: Into<Box<dyn StdError + Send + Sync>>,
|
||||
{
|
||||
ConnectError {
|
||||
msg: msg.into(),
|
||||
cause: Some(cause.into()),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl StdError for InvalidUrl {
|
||||
fn description(&self) -> &str {
|
||||
match *self {
|
||||
InvalidUrl::MissingScheme => "invalid URL, missing scheme",
|
||||
InvalidUrl::NotHttp => "invalid URL, scheme must be http",
|
||||
InvalidUrl::MissingAuthority => "invalid URL, missing domain",
|
||||
fn dns<E>(cause: E) -> ConnectError
|
||||
where
|
||||
E: Into<Box<dyn StdError + Send + Sync>>,
|
||||
{
|
||||
ConnectError::new("dns error", cause)
|
||||
}
|
||||
|
||||
fn m<S, E>(msg: S) -> impl FnOnce(E) -> ConnectError
|
||||
where
|
||||
S: Into<Box<str>>,
|
||||
E: Into<Box<dyn StdError + Send + Sync>>,
|
||||
{
|
||||
move |cause| {
|
||||
ConnectError::new(msg, cause)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl fmt::Debug for ConnectError {
|
||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
if let Some(ref cause) = self.cause {
|
||||
f.debug_tuple("ConnectError")
|
||||
.field(&self.msg)
|
||||
.field(cause)
|
||||
.finish()
|
||||
} else {
|
||||
self.msg.fmt(f)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl fmt::Display for ConnectError {
|
||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
f.write_str(&self.msg)?;
|
||||
|
||||
if let Some(ref cause) = self.cause {
|
||||
write!(f, ": {}", cause)?;
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
impl StdError for ConnectError {
|
||||
fn source(&self) -> Option<&(dyn StdError + 'static)> {
|
||||
self.cause.as_ref().map(|e| &**e as _)
|
||||
}
|
||||
}
|
||||
|
||||
/// A Future representing work to connect to a URL.
|
||||
#[must_use = "futures do nothing unless polled"]
|
||||
#[pin_project]
|
||||
@@ -361,11 +410,11 @@ enum State<R: Resolve> {
|
||||
Lazy(R, String),
|
||||
Resolving(#[pin] R::Future),
|
||||
Connecting(ConnectingTcp),
|
||||
Error(Option<io::Error>),
|
||||
Error(Option<ConnectError>),
|
||||
}
|
||||
|
||||
impl<R: Resolve> Future for HttpConnecting<R> {
|
||||
type Output = Result<(TcpStream, Connected), io::Error>;
|
||||
type Output = Result<(TcpStream, Connected), ConnectError>;
|
||||
|
||||
#[project]
|
||||
fn poll(self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll<Self::Output> {
|
||||
@@ -375,19 +424,20 @@ impl<R: Resolve> Future for HttpConnecting<R> {
|
||||
let state;
|
||||
#[project]
|
||||
match me.state.as_mut().project() {
|
||||
State::Lazy(ref resolver, ref mut host) => {
|
||||
State::Lazy(ref mut resolver, ref mut host) => {
|
||||
// If the host is already an IP addr (v4 or v6),
|
||||
// skip resolving the dns and start connecting right away.
|
||||
if let Some(addrs) = dns::IpAddrs::try_parse(host, *me.port) {
|
||||
state = State::Connecting(ConnectingTcp::new(
|
||||
config.local_address, addrs, config.connect_timeout, config.happy_eyeballs_timeout, config.reuse_address));
|
||||
} else {
|
||||
ready!(resolver.poll_ready(cx)).map_err(ConnectError::dns)?;
|
||||
let name = dns::Name::new(mem::replace(host, String::new()));
|
||||
state = State::Resolving(resolver.resolve(name));
|
||||
}
|
||||
},
|
||||
State::Resolving(future) => {
|
||||
let addrs = ready!(future.poll(cx))?;
|
||||
let addrs = ready!(future.poll(cx)).map_err(ConnectError::dns)?;
|
||||
let port = *me.port;
|
||||
let addrs = addrs
|
||||
.map(|addr| SocketAddr::new(addr, port))
|
||||
@@ -397,24 +447,25 @@ impl<R: Resolve> Future for HttpConnecting<R> {
|
||||
config.local_address, addrs, config.connect_timeout, config.happy_eyeballs_timeout, config.reuse_address));
|
||||
},
|
||||
State::Connecting(ref mut c) => {
|
||||
let sock = ready!(c.poll(cx, &config.handle))?;
|
||||
let sock = ready!(c.poll(cx, &config.handle))
|
||||
.map_err(ConnectError::m("tcp connect error"))?;
|
||||
|
||||
if let Some(dur) = config.keep_alive_timeout {
|
||||
sock.set_keepalive(Some(dur))?;
|
||||
sock.set_keepalive(Some(dur)).map_err(ConnectError::m("tcp set_keepalive error"))?;
|
||||
}
|
||||
|
||||
if let Some(size) = config.send_buffer_size {
|
||||
sock.set_send_buffer_size(size)?;
|
||||
sock.set_send_buffer_size(size).map_err(ConnectError::m("tcp set_send_buffer_size error"))?;
|
||||
}
|
||||
|
||||
if let Some(size) = config.recv_buffer_size {
|
||||
sock.set_recv_buffer_size(size)?;
|
||||
sock.set_recv_buffer_size(size).map_err(ConnectError::m("tcp set_recv_buffer_size error"))?;
|
||||
}
|
||||
|
||||
sock.set_nodelay(config.nodelay)?;
|
||||
sock.set_nodelay(config.nodelay).map_err(ConnectError::m("tcp set_nodelay error"))?;
|
||||
|
||||
let extra = HttpInfo {
|
||||
remote_addr: sock.peer_addr()?,
|
||||
remote_addr: sock.peer_addr().map_err(ConnectError::m("tcp peer_addr error"))?,
|
||||
};
|
||||
let connected = Connected::new()
|
||||
.extra(extra);
|
||||
@@ -642,7 +693,6 @@ impl ConnectingTcp {
|
||||
mod tests {
|
||||
use std::io;
|
||||
|
||||
use tokio::runtime::current_thread::Runtime;
|
||||
use tokio_net::driver::Handle;
|
||||
|
||||
use super::{Connected, Destination, HttpConnector};
|
||||
@@ -655,55 +705,29 @@ mod tests {
|
||||
connector.connect(super::super::sealed::Internal, dst).await
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_errors_missing_authority() {
|
||||
let mut rt = Runtime::new().unwrap();
|
||||
let uri = "/foo/bar?baz".parse().unwrap();
|
||||
let dst = Destination {
|
||||
uri,
|
||||
};
|
||||
let connector = HttpConnector::new();
|
||||
|
||||
rt.block_on(async {
|
||||
assert_eq!(
|
||||
connect(connector, dst).await.unwrap_err().kind(),
|
||||
io::ErrorKind::InvalidInput,
|
||||
);
|
||||
})
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_errors_enforce_http() {
|
||||
let mut rt = Runtime::new().unwrap();
|
||||
#[tokio::test]
|
||||
async fn test_errors_enforce_http() {
|
||||
let uri = "https://example.domain/foo/bar?baz".parse().unwrap();
|
||||
let dst = Destination {
|
||||
uri,
|
||||
};
|
||||
let connector = HttpConnector::new();
|
||||
|
||||
rt.block_on(async {
|
||||
assert_eq!(
|
||||
connect(connector, dst).await.unwrap_err().kind(),
|
||||
io::ErrorKind::InvalidInput,
|
||||
);
|
||||
})
|
||||
let err = connect(connector, dst).await.unwrap_err();
|
||||
assert_eq!(&*err.msg, super::INVALID_NOT_HTTP);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_errors_missing_scheme() {
|
||||
let mut rt = Runtime::new().unwrap();
|
||||
#[tokio::test]
|
||||
async fn test_errors_missing_scheme() {
|
||||
let uri = "example.domain".parse().unwrap();
|
||||
let dst = Destination {
|
||||
uri,
|
||||
};
|
||||
let connector = HttpConnector::new();
|
||||
let mut connector = HttpConnector::new();
|
||||
connector.enforce_http(false);
|
||||
|
||||
rt.block_on(async {
|
||||
assert_eq!(
|
||||
connect(connector, dst).await.unwrap_err().kind(),
|
||||
io::ErrorKind::InvalidInput,
|
||||
);
|
||||
});
|
||||
let err = connect(connector, dst).await.unwrap_err();
|
||||
assert_eq!(&*err.msg, super::INVALID_MISSING_SCHEME);
|
||||
}
|
||||
|
||||
#[test]
|
||||
|
||||
@@ -1722,7 +1722,7 @@ mod dispatch_impl {
|
||||
|
||||
impl hyper::service::Service<Destination> for DebugConnector {
|
||||
type Response = (DebugStream, Connected);
|
||||
type Error = io::Error;
|
||||
type Error = <HttpConnector as hyper::service::Service<Destination>>::Error;
|
||||
type Future = Pin<Box<dyn Future<
|
||||
Output = Result<Self::Response, Self::Error>
|
||||
> + Send>>;
|
||||
|
||||
Reference in New Issue
Block a user