Refactor connect errors to not use io::Error (#782)

This commit is contained in:
Sean McArthur
2020-01-13 13:29:14 -08:00
committed by GitHub
parent 14908ad3f0
commit e31d5221fe
2 changed files with 48 additions and 38 deletions

View File

@@ -177,14 +177,13 @@ impl Connector {
if dst.scheme() == Some(&Scheme::HTTPS) { if dst.scheme() == Some(&Scheme::HTTPS) {
let host = dst let host = dst
.host() .host()
.ok_or(io::Error::new(io::ErrorKind::Other, "no host in url"))? .ok_or("no host in url")?
.to_string(); .to_string();
let conn = socks::connect(proxy, dst, dns).await?; let conn = socks::connect(proxy, dst, dns).await?;
let tls_connector = tokio_tls::TlsConnector::from(tls.clone()); let tls_connector = tokio_tls::TlsConnector::from(tls.clone());
let io = tls_connector let io = tls_connector
.connect(&host, conn) .connect(&host, conn)
.await .await?;
.map_err(|e| io::Error::new(io::ErrorKind::Other, e))?;
return Ok(Conn { return Ok(Conn {
inner: self.verbose.wrap(NativeTlsConn { inner: io }), inner: self.verbose.wrap(NativeTlsConn { inner: io }),
is_proxy: false, is_proxy: false,
@@ -200,16 +199,15 @@ impl Connector {
let tls = tls_proxy.clone(); let tls = tls_proxy.clone();
let host = dst let host = dst
.host() .host()
.ok_or(io::Error::new(io::ErrorKind::Other, "no host in url"))? .ok_or("no host in url")?
.to_string(); .to_string();
let conn = socks::connect(proxy, dst, dns).await?; let conn = socks::connect(proxy, dst, dns).await?;
let dnsname = DNSNameRef::try_from_ascii_str(&host) let dnsname = DNSNameRef::try_from_ascii_str(&host)
.map(|dnsname| dnsname.to_owned()) .map(|dnsname| dnsname.to_owned())
.map_err(|_| io::Error::new(io::ErrorKind::Other, "Invalid DNS Name"))?; .map_err(|_| "Invalid DNS Name")?;
let io = RustlsConnector::from(tls) let io = RustlsConnector::from(tls)
.connect(dnsname.as_ref(), conn) .connect(dnsname.as_ref(), conn)
.await .await?;
.map_err(|e| io::Error::new(io::ErrorKind::Other, e))?;
return Ok(Conn { return Ok(Conn {
inner: self.verbose.wrap(RustlsTlsConn { inner: io }), inner: self.verbose.wrap(RustlsTlsConn { inner: io }),
is_proxy: false, is_proxy: false,
@@ -321,7 +319,7 @@ impl Connector {
let tunneled = tunnel( let tunneled = tunnel(
conn, conn,
host host
.ok_or(io::Error::new(io::ErrorKind::Other, "no host in url"))? .ok_or("no host in url")?
.to_string(), .to_string(),
port, port,
self.user_agent.clone(), self.user_agent.clone(),
@@ -329,9 +327,8 @@ impl Connector {
).await?; ).await?;
let tls_connector = tokio_tls::TlsConnector::from(tls.clone()); let tls_connector = tokio_tls::TlsConnector::from(tls.clone());
let io = tls_connector let io = tls_connector
.connect(&host.ok_or(io::Error::new(io::ErrorKind::Other, "no host in url"))?, tunneled) .connect(&host.ok_or("no host in url")?, tunneled)
.await .await?;
.map_err(|e| io::Error::new(io::ErrorKind::Other, e))?;
return Ok(Conn { return Ok(Conn {
inner: self.verbose.wrap(NativeTlsConn { inner: io }), inner: self.verbose.wrap(NativeTlsConn { inner: io }),
is_proxy: false, is_proxy: false,
@@ -350,7 +347,7 @@ impl Connector {
let host = dst let host = dst
.host() .host()
.ok_or(io::Error::new(io::ErrorKind::Other, "no host in url"))? .ok_or("no host in url")?
.to_string(); .to_string();
let port = dst.port().map(|r| r.as_u16()).unwrap_or(443); let port = dst.port().map(|r| r.as_u16()).unwrap_or(443);
let mut http = http.clone(); let mut http = http.clone();
@@ -361,13 +358,12 @@ impl Connector {
log::trace!("tunneling HTTPS over proxy"); log::trace!("tunneling HTTPS over proxy");
let maybe_dnsname = DNSNameRef::try_from_ascii_str(&host) let maybe_dnsname = DNSNameRef::try_from_ascii_str(&host)
.map(|dnsname| dnsname.to_owned()) .map(|dnsname| dnsname.to_owned())
.map_err(|_| io::Error::new(io::ErrorKind::Other, "Invalid DNS Name")); .map_err(|_| "Invalid DNS Name");
let tunneled = tunnel(conn, host, port, self.user_agent.clone(), auth).await?; let tunneled = tunnel(conn, host, port, self.user_agent.clone(), auth).await?;
let dnsname = maybe_dnsname?; let dnsname = maybe_dnsname?;
let io = RustlsConnector::from(tls) let io = RustlsConnector::from(tls)
.connect(dnsname.as_ref(), tunneled) .connect(dnsname.as_ref(), tunneled)
.await .await?;
.map_err(|e| io::Error::new(io::ErrorKind::Other, e))?;
return Ok(Conn { return Ok(Conn {
inner: self.verbose.wrap(RustlsTlsConn { inner: io }), inner: self.verbose.wrap(RustlsTlsConn { inner: io }),
@@ -412,7 +408,7 @@ where
{ {
if let Some(to) = timeout { if let Some(to) = timeout {
match tokio::time::timeout(to, f).await { match tokio::time::timeout(to, f).await {
Err(_elapsed) => Err(Box::new(io::Error::new(io::ErrorKind::TimedOut, "connect timed out")) as BoxError), Err(_elapsed) => Err(Box::new(crate::error::TimedOut) as BoxError),
Ok(Ok(try_res)) => Ok(try_res), Ok(Ok(try_res)) => Ok(try_res),
Ok(Err(e)) => Err(e), Ok(Err(e)) => Err(e),
} }
@@ -478,7 +474,7 @@ impl AsyncRead for Conn {
self: Pin<&mut Self>, self: Pin<&mut Self>,
cx: &mut Context, cx: &mut Context,
buf: &mut [u8] buf: &mut [u8]
) -> Poll<tokio::io::Result<usize>> { ) -> Poll<io::Result<usize>> {
let this = self.project(); let this = self.project();
AsyncRead::poll_read(this.inner, cx, buf) AsyncRead::poll_read(this.inner, cx, buf)
} }
@@ -494,7 +490,7 @@ impl AsyncRead for Conn {
self: Pin<&mut Self>, self: Pin<&mut Self>,
cx: &mut Context, cx: &mut Context,
buf: &mut B buf: &mut B
) -> Poll<tokio::io::Result<usize>> ) -> Poll<io::Result<usize>>
where where
Self: Sized Self: Sized
{ {
@@ -508,12 +504,12 @@ impl AsyncWrite for Conn {
self: Pin<&mut Self>, self: Pin<&mut Self>,
cx: &mut Context, cx: &mut Context,
buf: &[u8] buf: &[u8]
) -> Poll<Result<usize, tokio::io::Error>> { ) -> Poll<Result<usize, io::Error>> {
let this = self.project(); let this = self.project();
AsyncWrite::poll_write(this.inner, cx, buf) AsyncWrite::poll_write(this.inner, cx, buf)
} }
fn poll_flush(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), tokio::io::Error>> { fn poll_flush(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), io::Error>> {
let this = self.project(); let this = self.project();
AsyncWrite::poll_flush(this.inner, cx) AsyncWrite::poll_flush(this.inner, cx)
} }
@@ -521,7 +517,7 @@ impl AsyncWrite for Conn {
fn poll_shutdown( fn poll_shutdown(
self: Pin<&mut Self>, self: Pin<&mut Self>,
cx: &mut Context cx: &mut Context
) -> Poll<Result<(), tokio::io::Error>> { ) -> Poll<Result<(), io::Error>> {
let this = self.project(); let this = self.project();
AsyncWrite::poll_shutdown(this.inner, cx) AsyncWrite::poll_shutdown(this.inner, cx)
} }
@@ -530,7 +526,7 @@ impl AsyncWrite for Conn {
self: Pin<&mut Self>, self: Pin<&mut Self>,
cx: &mut Context, cx: &mut Context,
buf: &mut B buf: &mut B
) -> Poll<Result<usize, tokio::io::Error>> where ) -> Poll<Result<usize, io::Error>> where
Self: Sized { Self: Sized {
let this = self.project(); let this = self.project();
AsyncWrite::poll_write_buf(this.inner, cx, buf) AsyncWrite::poll_write_buf(this.inner, cx, buf)
@@ -547,7 +543,7 @@ async fn tunnel<T>(
port: u16, port: u16,
user_agent: Option<HeaderValue>, user_agent: Option<HeaderValue>,
auth: Option<HeaderValue>, auth: Option<HeaderValue>,
) -> Result<T, io::Error> ) -> Result<T, BoxError>
where where
T: AsyncRead + AsyncWrite + Unpin, T: AsyncRead + AsyncWrite + Unpin,
{ {
@@ -601,29 +597,24 @@ where
return Ok(conn); return Ok(conn);
} }
if pos == buf.len() { if pos == buf.len() {
return Err(io::Error::new( return Err(
io::ErrorKind::Other, "proxy headers too long for tunnel".into()
"proxy headers too long for tunnel", );
));
} }
// else read more // else read more
} else if recvd.starts_with(b"HTTP/1.1 407") { } else if recvd.starts_with(b"HTTP/1.1 407") {
return Err(io::Error::new( return Err(
io::ErrorKind::Other, "proxy authentication required".into()
"proxy authentication required", );
));
} else { } else {
return Err(io::Error::new(io::ErrorKind::Other, "unsuccessful tunnel")); return Err("unsuccessful tunnel".into());
} }
} }
} }
#[cfg(feature = "__tls")] #[cfg(feature = "__tls")]
fn tunnel_eof() -> io::Error { fn tunnel_eof() -> BoxError {
io::Error::new( "unexpected eof while tunneling".into()
io::ErrorKind::UnexpectedEof,
"unexpected eof while tunneling",
)
} }
#[cfg(feature = "default-tls")] #[cfg(feature = "default-tls")]

View File

@@ -82,7 +82,16 @@ impl Error {
/// Returns true if the error is related to a timeout. /// Returns true if the error is related to a timeout.
pub fn is_timeout(&self) -> bool { pub fn is_timeout(&self) -> bool {
self.source().map(|e| e.is::<TimedOut>()).unwrap_or(false) let mut source = self.source();
while let Some(err) = source {
if err.is::<TimedOut>() {
return true;
}
source = err.source();
}
false
} }
/// Returns the status code, if the error was generated from a response. /// Returns the status code, if the error was generated from a response.
@@ -309,4 +318,14 @@ mod tests {
_ => panic!("{:?}", err), _ => panic!("{:?}", err),
} }
} }
#[test]
fn is_timeout() {
let err = super::request(super::TimedOut);
assert!(err.is_timeout());
let io = io::Error::new(io::ErrorKind::Other, err);
let nested = super::request(io);
assert!(nested.is_timeout());
}
} }