feat(client): change Resolve to be Service<Name>

Closes #1903

BREAKING CHANGE: The `Resolve` trait is gone. All custom resolves should
  implement `tower::Service` instead.

  The error type of `HttpConnector` has been changed away from
  `std::io::Error`.
This commit is contained in:
Sean McArthur
2019-11-12 12:06:16 -08:00
parent 039281b89c
commit 9d9233ce7c
3 changed files with 171 additions and 132 deletions

View File

@@ -228,11 +228,18 @@ impl<R> HttpConnector<R> {
}
}
static INVALID_NOT_HTTP: &str = "invalid URL, scheme is not http";
static INVALID_MISSING_SCHEME: &str = "invalid URL, scheme is missing";
static INVALID_MISSING_HOST: &str = "invalid URL, host is missing";
impl<R: Resolve> HttpConnector<R> {
fn invalid_url(&self, err: InvalidUrl) -> HttpConnecting<R> {
fn invalid_url(&self, msg: impl Into<Box<str>>) -> HttpConnecting<R> {
HttpConnecting {
config: self.config.clone(),
state: State::Error(Some(io::Error::new(io::ErrorKind::InvalidInput, err))),
state: State::Error(Some(ConnectError {
msg: msg.into(),
cause: None,
})),
port: 0,
}
}
@@ -252,14 +259,11 @@ where
R::Future: Send,
{
type Response = (TcpStream, Connected);
type Error = io::Error;
type Error = ConnectError;
type Future = HttpConnecting<R>;
fn poll_ready(&mut self, cx: &mut task::Context<'_>) -> Poll<Result<(), Self::Error>> {
// For now, always ready.
// TODO: When `Resolve` becomes an alias for `Service`, check
// the resolver's readiness.
drop(cx);
ready!(self.resolver.poll_ready(cx)).map_err(ConnectError::dns)?;
Poll::Ready(Ok(()))
}
@@ -273,15 +277,15 @@ where
if self.config.enforce_http {
if dst.uri.scheme_part() != Some(&Scheme::HTTP) {
return self.invalid_url(InvalidUrl::NotHttp);
return self.invalid_url(INVALID_NOT_HTTP);
}
} else if dst.uri.scheme_part().is_none() {
return self.invalid_url(InvalidUrl::MissingScheme);
return self.invalid_url(INVALID_MISSING_SCHEME);
}
let host = match dst.uri.host() {
Some(s) => s,
None => return self.invalid_url(InvalidUrl::MissingAuthority),
None => return self.invalid_url(INVALID_MISSING_HOST),
};
let port = match dst.uri.port_part() {
Some(port) => port.as_u16(),
@@ -302,7 +306,7 @@ where
R::Future: Send,
{
type Response = TcpStream;
type Error = io::Error;
type Error = ConnectError;
type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send + 'static>>;
fn poll_ready(&mut self, cx: &mut task::Context<'_>) -> Poll<Result<(), Self::Error>> {
@@ -324,28 +328,73 @@ impl HttpInfo {
}
}
#[derive(Debug, Clone, Copy)]
enum InvalidUrl {
MissingScheme,
NotHttp,
MissingAuthority,
// Not publicly exported (so missing_docs doesn't trigger).
pub struct ConnectError {
msg: Box<str>,
cause: Option<Box<dyn StdError + Send + Sync>>,
}
impl fmt::Display for InvalidUrl {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.write_str(self.description())
impl ConnectError {
fn new<S, E>(msg: S, cause: E) -> ConnectError
where
S: Into<Box<str>>,
E: Into<Box<dyn StdError + Send + Sync>>,
{
ConnectError {
msg: msg.into(),
cause: Some(cause.into()),
}
}
}
impl StdError for InvalidUrl {
fn description(&self) -> &str {
match *self {
InvalidUrl::MissingScheme => "invalid URL, missing scheme",
InvalidUrl::NotHttp => "invalid URL, scheme must be http",
InvalidUrl::MissingAuthority => "invalid URL, missing domain",
fn dns<E>(cause: E) -> ConnectError
where
E: Into<Box<dyn StdError + Send + Sync>>,
{
ConnectError::new("dns error", cause)
}
fn m<S, E>(msg: S) -> impl FnOnce(E) -> ConnectError
where
S: Into<Box<str>>,
E: Into<Box<dyn StdError + Send + Sync>>,
{
move |cause| {
ConnectError::new(msg, cause)
}
}
}
impl fmt::Debug for ConnectError {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
if let Some(ref cause) = self.cause {
f.debug_tuple("ConnectError")
.field(&self.msg)
.field(cause)
.finish()
} else {
self.msg.fmt(f)
}
}
}
impl fmt::Display for ConnectError {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.write_str(&self.msg)?;
if let Some(ref cause) = self.cause {
write!(f, ": {}", cause)?;
}
Ok(())
}
}
impl StdError for ConnectError {
fn source(&self) -> Option<&(dyn StdError + 'static)> {
self.cause.as_ref().map(|e| &**e as _)
}
}
/// A Future representing work to connect to a URL.
#[must_use = "futures do nothing unless polled"]
#[pin_project]
@@ -361,11 +410,11 @@ enum State<R: Resolve> {
Lazy(R, String),
Resolving(#[pin] R::Future),
Connecting(ConnectingTcp),
Error(Option<io::Error>),
Error(Option<ConnectError>),
}
impl<R: Resolve> Future for HttpConnecting<R> {
type Output = Result<(TcpStream, Connected), io::Error>;
type Output = Result<(TcpStream, Connected), ConnectError>;
#[project]
fn poll(self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll<Self::Output> {
@@ -375,19 +424,20 @@ impl<R: Resolve> Future for HttpConnecting<R> {
let state;
#[project]
match me.state.as_mut().project() {
State::Lazy(ref resolver, ref mut host) => {
State::Lazy(ref mut resolver, ref mut host) => {
// If the host is already an IP addr (v4 or v6),
// skip resolving the dns and start connecting right away.
if let Some(addrs) = dns::IpAddrs::try_parse(host, *me.port) {
state = State::Connecting(ConnectingTcp::new(
config.local_address, addrs, config.connect_timeout, config.happy_eyeballs_timeout, config.reuse_address));
} else {
ready!(resolver.poll_ready(cx)).map_err(ConnectError::dns)?;
let name = dns::Name::new(mem::replace(host, String::new()));
state = State::Resolving(resolver.resolve(name));
}
},
State::Resolving(future) => {
let addrs = ready!(future.poll(cx))?;
let addrs = ready!(future.poll(cx)).map_err(ConnectError::dns)?;
let port = *me.port;
let addrs = addrs
.map(|addr| SocketAddr::new(addr, port))
@@ -397,24 +447,25 @@ impl<R: Resolve> Future for HttpConnecting<R> {
config.local_address, addrs, config.connect_timeout, config.happy_eyeballs_timeout, config.reuse_address));
},
State::Connecting(ref mut c) => {
let sock = ready!(c.poll(cx, &config.handle))?;
let sock = ready!(c.poll(cx, &config.handle))
.map_err(ConnectError::m("tcp connect error"))?;
if let Some(dur) = config.keep_alive_timeout {
sock.set_keepalive(Some(dur))?;
sock.set_keepalive(Some(dur)).map_err(ConnectError::m("tcp set_keepalive error"))?;
}
if let Some(size) = config.send_buffer_size {
sock.set_send_buffer_size(size)?;
sock.set_send_buffer_size(size).map_err(ConnectError::m("tcp set_send_buffer_size error"))?;
}
if let Some(size) = config.recv_buffer_size {
sock.set_recv_buffer_size(size)?;
sock.set_recv_buffer_size(size).map_err(ConnectError::m("tcp set_recv_buffer_size error"))?;
}
sock.set_nodelay(config.nodelay)?;
sock.set_nodelay(config.nodelay).map_err(ConnectError::m("tcp set_nodelay error"))?;
let extra = HttpInfo {
remote_addr: sock.peer_addr()?,
remote_addr: sock.peer_addr().map_err(ConnectError::m("tcp peer_addr error"))?,
};
let connected = Connected::new()
.extra(extra);
@@ -642,7 +693,6 @@ impl ConnectingTcp {
mod tests {
use std::io;
use tokio::runtime::current_thread::Runtime;
use tokio_net::driver::Handle;
use super::{Connected, Destination, HttpConnector};
@@ -655,55 +705,29 @@ mod tests {
connector.connect(super::super::sealed::Internal, dst).await
}
#[test]
fn test_errors_missing_authority() {
let mut rt = Runtime::new().unwrap();
let uri = "/foo/bar?baz".parse().unwrap();
let dst = Destination {
uri,
};
let connector = HttpConnector::new();
rt.block_on(async {
assert_eq!(
connect(connector, dst).await.unwrap_err().kind(),
io::ErrorKind::InvalidInput,
);
})
}
#[test]
fn test_errors_enforce_http() {
let mut rt = Runtime::new().unwrap();
#[tokio::test]
async fn test_errors_enforce_http() {
let uri = "https://example.domain/foo/bar?baz".parse().unwrap();
let dst = Destination {
uri,
};
let connector = HttpConnector::new();
rt.block_on(async {
assert_eq!(
connect(connector, dst).await.unwrap_err().kind(),
io::ErrorKind::InvalidInput,
);
})
let err = connect(connector, dst).await.unwrap_err();
assert_eq!(&*err.msg, super::INVALID_NOT_HTTP);
}
#[test]
fn test_errors_missing_scheme() {
let mut rt = Runtime::new().unwrap();
#[tokio::test]
async fn test_errors_missing_scheme() {
let uri = "example.domain".parse().unwrap();
let dst = Destination {
uri,
};
let connector = HttpConnector::new();
let mut connector = HttpConnector::new();
connector.enforce_http(false);
rt.block_on(async {
assert_eq!(
connect(connector, dst).await.unwrap_err().kind(),
io::ErrorKind::InvalidInput,
);
});
let err = connect(connector, dst).await.unwrap_err();
assert_eq!(&*err.msg, super::INVALID_MISSING_SCHEME);
}
#[test]