diff --git a/src/connect.rs b/src/connect.rs index 3db6a16..df7519b 100644 --- a/src/connect.rs +++ b/src/connect.rs @@ -177,14 +177,13 @@ impl Connector { if dst.scheme() == Some(&Scheme::HTTPS) { let host = dst .host() - .ok_or(io::Error::new(io::ErrorKind::Other, "no host in url"))? + .ok_or("no host in url")? .to_string(); let conn = socks::connect(proxy, dst, dns).await?; let tls_connector = tokio_tls::TlsConnector::from(tls.clone()); let io = tls_connector .connect(&host, conn) - .await - .map_err(|e| io::Error::new(io::ErrorKind::Other, e))?; + .await?; return Ok(Conn { inner: self.verbose.wrap(NativeTlsConn { inner: io }), is_proxy: false, @@ -200,16 +199,15 @@ impl Connector { let tls = tls_proxy.clone(); let host = dst .host() - .ok_or(io::Error::new(io::ErrorKind::Other, "no host in url"))? + .ok_or("no host in url")? .to_string(); let conn = socks::connect(proxy, dst, dns).await?; let dnsname = DNSNameRef::try_from_ascii_str(&host) .map(|dnsname| dnsname.to_owned()) - .map_err(|_| io::Error::new(io::ErrorKind::Other, "Invalid DNS Name"))?; + .map_err(|_| "Invalid DNS Name")?; let io = RustlsConnector::from(tls) .connect(dnsname.as_ref(), conn) - .await - .map_err(|e| io::Error::new(io::ErrorKind::Other, e))?; + .await?; return Ok(Conn { inner: self.verbose.wrap(RustlsTlsConn { inner: io }), is_proxy: false, @@ -321,7 +319,7 @@ impl Connector { let tunneled = tunnel( conn, host - .ok_or(io::Error::new(io::ErrorKind::Other, "no host in url"))? + .ok_or("no host in url")? .to_string(), port, self.user_agent.clone(), @@ -329,9 +327,8 @@ impl Connector { ).await?; let tls_connector = tokio_tls::TlsConnector::from(tls.clone()); let io = tls_connector - .connect(&host.ok_or(io::Error::new(io::ErrorKind::Other, "no host in url"))?, tunneled) - .await - .map_err(|e| io::Error::new(io::ErrorKind::Other, e))?; + .connect(&host.ok_or("no host in url")?, tunneled) + .await?; return Ok(Conn { inner: self.verbose.wrap(NativeTlsConn { inner: io }), is_proxy: false, @@ -350,7 +347,7 @@ impl Connector { let host = dst .host() - .ok_or(io::Error::new(io::ErrorKind::Other, "no host in url"))? + .ok_or("no host in url")? .to_string(); let port = dst.port().map(|r| r.as_u16()).unwrap_or(443); let mut http = http.clone(); @@ -361,13 +358,12 @@ impl Connector { log::trace!("tunneling HTTPS over proxy"); let maybe_dnsname = DNSNameRef::try_from_ascii_str(&host) .map(|dnsname| dnsname.to_owned()) - .map_err(|_| io::Error::new(io::ErrorKind::Other, "Invalid DNS Name")); + .map_err(|_| "Invalid DNS Name"); let tunneled = tunnel(conn, host, port, self.user_agent.clone(), auth).await?; let dnsname = maybe_dnsname?; let io = RustlsConnector::from(tls) .connect(dnsname.as_ref(), tunneled) - .await - .map_err(|e| io::Error::new(io::ErrorKind::Other, e))?; + .await?; return Ok(Conn { inner: self.verbose.wrap(RustlsTlsConn { inner: io }), @@ -412,7 +408,7 @@ where { if let Some(to) = timeout { match tokio::time::timeout(to, f).await { - Err(_elapsed) => Err(Box::new(io::Error::new(io::ErrorKind::TimedOut, "connect timed out")) as BoxError), + Err(_elapsed) => Err(Box::new(crate::error::TimedOut) as BoxError), Ok(Ok(try_res)) => Ok(try_res), Ok(Err(e)) => Err(e), } @@ -478,7 +474,7 @@ impl AsyncRead for Conn { self: Pin<&mut Self>, cx: &mut Context, buf: &mut [u8] - ) -> Poll> { + ) -> Poll> { let this = self.project(); AsyncRead::poll_read(this.inner, cx, buf) } @@ -494,7 +490,7 @@ impl AsyncRead for Conn { self: Pin<&mut Self>, cx: &mut Context, buf: &mut B - ) -> Poll> + ) -> Poll> where Self: Sized { @@ -508,12 +504,12 @@ impl AsyncWrite for Conn { self: Pin<&mut Self>, cx: &mut Context, buf: &[u8] - ) -> Poll> { + ) -> Poll> { let this = self.project(); AsyncWrite::poll_write(this.inner, cx, buf) } - fn poll_flush(self: Pin<&mut Self>, cx: &mut Context) -> Poll> { + fn poll_flush(self: Pin<&mut Self>, cx: &mut Context) -> Poll> { let this = self.project(); AsyncWrite::poll_flush(this.inner, cx) } @@ -521,7 +517,7 @@ impl AsyncWrite for Conn { fn poll_shutdown( self: Pin<&mut Self>, cx: &mut Context - ) -> Poll> { + ) -> Poll> { let this = self.project(); AsyncWrite::poll_shutdown(this.inner, cx) } @@ -530,7 +526,7 @@ impl AsyncWrite for Conn { self: Pin<&mut Self>, cx: &mut Context, buf: &mut B - ) -> Poll> where + ) -> Poll> where Self: Sized { let this = self.project(); AsyncWrite::poll_write_buf(this.inner, cx, buf) @@ -547,7 +543,7 @@ async fn tunnel( port: u16, user_agent: Option, auth: Option, -) -> Result +) -> Result where T: AsyncRead + AsyncWrite + Unpin, { @@ -601,29 +597,24 @@ where return Ok(conn); } if pos == buf.len() { - return Err(io::Error::new( - io::ErrorKind::Other, - "proxy headers too long for tunnel", - )); + return Err( + "proxy headers too long for tunnel".into() + ); } // else read more } else if recvd.starts_with(b"HTTP/1.1 407") { - return Err(io::Error::new( - io::ErrorKind::Other, - "proxy authentication required", - )); + return Err( + "proxy authentication required".into() + ); } else { - return Err(io::Error::new(io::ErrorKind::Other, "unsuccessful tunnel")); + return Err("unsuccessful tunnel".into()); } } } #[cfg(feature = "__tls")] -fn tunnel_eof() -> io::Error { - io::Error::new( - io::ErrorKind::UnexpectedEof, - "unexpected eof while tunneling", - ) +fn tunnel_eof() -> BoxError { + "unexpected eof while tunneling".into() } #[cfg(feature = "default-tls")] diff --git a/src/error.rs b/src/error.rs index 03e792a..5310631 100644 --- a/src/error.rs +++ b/src/error.rs @@ -82,7 +82,16 @@ impl Error { /// Returns true if the error is related to a timeout. pub fn is_timeout(&self) -> bool { - self.source().map(|e| e.is::()).unwrap_or(false) + let mut source = self.source(); + + while let Some(err) = source { + if err.is::() { + return true; + } + source = err.source(); + } + + false } /// Returns the status code, if the error was generated from a response. @@ -309,4 +318,14 @@ mod tests { _ => panic!("{:?}", err), } } + + #[test] + fn is_timeout() { + let err = super::request(super::TimedOut); + assert!(err.is_timeout()); + + let io = io::Error::new(io::ErrorKind::Other, err); + let nested = super::request(io); + assert!(nested.is_timeout()); + } }