Refactor connect errors to not use io::Error (#782)
This commit is contained in:
		| @@ -177,14 +177,13 @@ impl Connector { | |||||||
|                 if dst.scheme() == Some(&Scheme::HTTPS) { |                 if dst.scheme() == Some(&Scheme::HTTPS) { | ||||||
|                     let host = dst |                     let host = dst | ||||||
|                         .host() |                         .host() | ||||||
|                         .ok_or(io::Error::new(io::ErrorKind::Other, "no host in url"))? |                         .ok_or("no host in url")? | ||||||
|                         .to_string(); |                         .to_string(); | ||||||
|                     let conn = socks::connect(proxy, dst, dns).await?; |                     let conn = socks::connect(proxy, dst, dns).await?; | ||||||
|                     let tls_connector = tokio_tls::TlsConnector::from(tls.clone()); |                     let tls_connector = tokio_tls::TlsConnector::from(tls.clone()); | ||||||
|                     let io = tls_connector |                     let io = tls_connector | ||||||
|                         .connect(&host, conn) |                         .connect(&host, conn) | ||||||
|                         .await |                         .await?; | ||||||
|                         .map_err(|e| io::Error::new(io::ErrorKind::Other, e))?; |  | ||||||
|                     return Ok(Conn { |                     return Ok(Conn { | ||||||
|                         inner: self.verbose.wrap(NativeTlsConn { inner: io }), |                         inner: self.verbose.wrap(NativeTlsConn { inner: io }), | ||||||
|                         is_proxy: false, |                         is_proxy: false, | ||||||
| @@ -200,16 +199,15 @@ impl Connector { | |||||||
|                     let tls = tls_proxy.clone(); |                     let tls = tls_proxy.clone(); | ||||||
|                     let host = dst |                     let host = dst | ||||||
|                         .host() |                         .host() | ||||||
|                         .ok_or(io::Error::new(io::ErrorKind::Other, "no host in url"))? |                         .ok_or("no host in url")? | ||||||
|                         .to_string(); |                         .to_string(); | ||||||
|                     let conn = socks::connect(proxy, dst, dns).await?; |                     let conn = socks::connect(proxy, dst, dns).await?; | ||||||
|                     let dnsname = DNSNameRef::try_from_ascii_str(&host) |                     let dnsname = DNSNameRef::try_from_ascii_str(&host) | ||||||
|                         .map(|dnsname| dnsname.to_owned()) |                         .map(|dnsname| dnsname.to_owned()) | ||||||
|                         .map_err(|_| io::Error::new(io::ErrorKind::Other, "Invalid DNS Name"))?; |                         .map_err(|_| "Invalid DNS Name")?; | ||||||
|                     let io = RustlsConnector::from(tls) |                     let io = RustlsConnector::from(tls) | ||||||
|                         .connect(dnsname.as_ref(), conn) |                         .connect(dnsname.as_ref(), conn) | ||||||
|                         .await |                         .await?; | ||||||
|                         .map_err(|e| io::Error::new(io::ErrorKind::Other, e))?; |  | ||||||
|                     return Ok(Conn { |                     return Ok(Conn { | ||||||
|                         inner: self.verbose.wrap(RustlsTlsConn { inner: io }), |                         inner: self.verbose.wrap(RustlsTlsConn { inner: io }), | ||||||
|                         is_proxy: false, |                         is_proxy: false, | ||||||
| @@ -321,7 +319,7 @@ impl Connector { | |||||||
|                     let tunneled = tunnel( |                     let tunneled = tunnel( | ||||||
|                         conn, |                         conn, | ||||||
|                         host |                         host | ||||||
|                             .ok_or(io::Error::new(io::ErrorKind::Other, "no host in url"))? |                             .ok_or("no host in url")? | ||||||
|                             .to_string(), |                             .to_string(), | ||||||
|                         port, |                         port, | ||||||
|                         self.user_agent.clone(), |                         self.user_agent.clone(), | ||||||
| @@ -329,9 +327,8 @@ impl Connector { | |||||||
|                     ).await?; |                     ).await?; | ||||||
|                     let tls_connector = tokio_tls::TlsConnector::from(tls.clone()); |                     let tls_connector = tokio_tls::TlsConnector::from(tls.clone()); | ||||||
|                     let io = tls_connector |                     let io = tls_connector | ||||||
|                         .connect(&host.ok_or(io::Error::new(io::ErrorKind::Other, "no host in url"))?, tunneled) |                         .connect(&host.ok_or("no host in url")?, tunneled) | ||||||
|                         .await |                         .await?; | ||||||
|                         .map_err(|e| io::Error::new(io::ErrorKind::Other, e))?; |  | ||||||
|                     return Ok(Conn { |                     return Ok(Conn { | ||||||
|                         inner: self.verbose.wrap(NativeTlsConn { inner: io }), |                         inner: self.verbose.wrap(NativeTlsConn { inner: io }), | ||||||
|                         is_proxy: false, |                         is_proxy: false, | ||||||
| @@ -350,7 +347,7 @@ impl Connector { | |||||||
|  |  | ||||||
|                     let host = dst |                     let host = dst | ||||||
|                         .host() |                         .host() | ||||||
|                         .ok_or(io::Error::new(io::ErrorKind::Other, "no host in url"))? |                         .ok_or("no host in url")? | ||||||
|                         .to_string(); |                         .to_string(); | ||||||
|                     let port = dst.port().map(|r| r.as_u16()).unwrap_or(443); |                     let port = dst.port().map(|r| r.as_u16()).unwrap_or(443); | ||||||
|                     let mut http = http.clone(); |                     let mut http = http.clone(); | ||||||
| @@ -361,13 +358,12 @@ impl Connector { | |||||||
|                     log::trace!("tunneling HTTPS over proxy"); |                     log::trace!("tunneling HTTPS over proxy"); | ||||||
|                     let maybe_dnsname = DNSNameRef::try_from_ascii_str(&host) |                     let maybe_dnsname = DNSNameRef::try_from_ascii_str(&host) | ||||||
|                         .map(|dnsname| dnsname.to_owned()) |                         .map(|dnsname| dnsname.to_owned()) | ||||||
|                         .map_err(|_| io::Error::new(io::ErrorKind::Other, "Invalid DNS Name")); |                         .map_err(|_| "Invalid DNS Name"); | ||||||
|                     let tunneled = tunnel(conn, host, port, self.user_agent.clone(), auth).await?; |                     let tunneled = tunnel(conn, host, port, self.user_agent.clone(), auth).await?; | ||||||
|                     let dnsname = maybe_dnsname?; |                     let dnsname = maybe_dnsname?; | ||||||
|                     let io = RustlsConnector::from(tls) |                     let io = RustlsConnector::from(tls) | ||||||
|                         .connect(dnsname.as_ref(), tunneled) |                         .connect(dnsname.as_ref(), tunneled) | ||||||
|                         .await |                         .await?; | ||||||
|                         .map_err(|e| io::Error::new(io::ErrorKind::Other, e))?; |  | ||||||
|  |  | ||||||
|                     return Ok(Conn { |                     return Ok(Conn { | ||||||
|                         inner: self.verbose.wrap(RustlsTlsConn { inner: io }), |                         inner: self.verbose.wrap(RustlsTlsConn { inner: io }), | ||||||
| @@ -412,7 +408,7 @@ where | |||||||
| { | { | ||||||
|     if let Some(to) = timeout { |     if let Some(to) = timeout { | ||||||
|         match tokio::time::timeout(to, f).await { |         match tokio::time::timeout(to, f).await { | ||||||
|             Err(_elapsed) => Err(Box::new(io::Error::new(io::ErrorKind::TimedOut, "connect timed out")) as BoxError), |             Err(_elapsed) => Err(Box::new(crate::error::TimedOut) as BoxError), | ||||||
|             Ok(Ok(try_res)) => Ok(try_res), |             Ok(Ok(try_res)) => Ok(try_res), | ||||||
|             Ok(Err(e)) => Err(e), |             Ok(Err(e)) => Err(e), | ||||||
|         } |         } | ||||||
| @@ -478,7 +474,7 @@ impl AsyncRead for Conn { | |||||||
|         self: Pin<&mut Self>, |         self: Pin<&mut Self>, | ||||||
|         cx: &mut Context, |         cx: &mut Context, | ||||||
|         buf: &mut [u8] |         buf: &mut [u8] | ||||||
|     ) -> Poll<tokio::io::Result<usize>> { |     ) -> Poll<io::Result<usize>> { | ||||||
|         let this = self.project(); |         let this = self.project(); | ||||||
|         AsyncRead::poll_read(this.inner, cx, buf) |         AsyncRead::poll_read(this.inner, cx, buf) | ||||||
|     } |     } | ||||||
| @@ -494,7 +490,7 @@ impl AsyncRead for Conn { | |||||||
|         self: Pin<&mut Self>, |         self: Pin<&mut Self>, | ||||||
|         cx: &mut Context, |         cx: &mut Context, | ||||||
|         buf: &mut B |         buf: &mut B | ||||||
|     ) -> Poll<tokio::io::Result<usize>> |     ) -> Poll<io::Result<usize>> | ||||||
|         where |         where | ||||||
|             Self: Sized |             Self: Sized | ||||||
|     { |     { | ||||||
| @@ -508,12 +504,12 @@ impl AsyncWrite for Conn { | |||||||
|         self: Pin<&mut Self>, |         self: Pin<&mut Self>, | ||||||
|         cx: &mut Context, |         cx: &mut Context, | ||||||
|         buf: &[u8] |         buf: &[u8] | ||||||
|     ) -> Poll<Result<usize, tokio::io::Error>> { |     ) -> Poll<Result<usize, io::Error>> { | ||||||
|         let this = self.project(); |         let this = self.project(); | ||||||
|         AsyncWrite::poll_write(this.inner, cx, buf) |         AsyncWrite::poll_write(this.inner, cx, buf) | ||||||
|     } |     } | ||||||
|  |  | ||||||
|     fn poll_flush(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), tokio::io::Error>> { |     fn poll_flush(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), io::Error>> { | ||||||
|         let this = self.project(); |         let this = self.project(); | ||||||
|         AsyncWrite::poll_flush(this.inner, cx) |         AsyncWrite::poll_flush(this.inner, cx) | ||||||
|     } |     } | ||||||
| @@ -521,7 +517,7 @@ impl AsyncWrite for Conn { | |||||||
|     fn poll_shutdown( |     fn poll_shutdown( | ||||||
|         self: Pin<&mut Self>, |         self: Pin<&mut Self>, | ||||||
|         cx: &mut Context |         cx: &mut Context | ||||||
|     ) -> Poll<Result<(), tokio::io::Error>> { |     ) -> Poll<Result<(), io::Error>> { | ||||||
|         let this = self.project(); |         let this = self.project(); | ||||||
|         AsyncWrite::poll_shutdown(this.inner, cx) |         AsyncWrite::poll_shutdown(this.inner, cx) | ||||||
|     } |     } | ||||||
| @@ -530,7 +526,7 @@ impl AsyncWrite for Conn { | |||||||
|         self: Pin<&mut Self>, |         self: Pin<&mut Self>, | ||||||
|         cx: &mut Context, |         cx: &mut Context, | ||||||
|         buf: &mut B |         buf: &mut B | ||||||
|     ) -> Poll<Result<usize, tokio::io::Error>> where |     ) -> Poll<Result<usize, io::Error>> where | ||||||
|         Self: Sized { |         Self: Sized { | ||||||
|         let this = self.project(); |         let this = self.project(); | ||||||
|         AsyncWrite::poll_write_buf(this.inner, cx, buf) |         AsyncWrite::poll_write_buf(this.inner, cx, buf) | ||||||
| @@ -547,7 +543,7 @@ async fn tunnel<T>( | |||||||
|     port: u16, |     port: u16, | ||||||
|     user_agent: Option<HeaderValue>, |     user_agent: Option<HeaderValue>, | ||||||
|     auth: Option<HeaderValue>, |     auth: Option<HeaderValue>, | ||||||
| ) -> Result<T, io::Error> | ) -> Result<T, BoxError> | ||||||
| where | where | ||||||
|     T: AsyncRead + AsyncWrite + Unpin, |     T: AsyncRead + AsyncWrite + Unpin, | ||||||
| { | { | ||||||
| @@ -601,29 +597,24 @@ where | |||||||
|                 return Ok(conn); |                 return Ok(conn); | ||||||
|             } |             } | ||||||
|             if pos == buf.len() { |             if pos == buf.len() { | ||||||
|                 return Err(io::Error::new( |                 return Err( | ||||||
|                     io::ErrorKind::Other, |                     "proxy headers too long for tunnel".into() | ||||||
|                     "proxy headers too long for tunnel", |                 ); | ||||||
|                 )); |  | ||||||
|             } |             } | ||||||
|         // else read more |         // else read more | ||||||
|         } else if recvd.starts_with(b"HTTP/1.1 407") { |         } else if recvd.starts_with(b"HTTP/1.1 407") { | ||||||
|             return Err(io::Error::new( |             return Err( | ||||||
|                 io::ErrorKind::Other, |                 "proxy authentication required".into() | ||||||
|                 "proxy authentication required", |             ); | ||||||
|             )); |  | ||||||
|         } else { |         } else { | ||||||
|             return Err(io::Error::new(io::ErrorKind::Other, "unsuccessful tunnel")); |             return Err("unsuccessful tunnel".into()); | ||||||
|         } |         } | ||||||
|     } |     } | ||||||
| } | } | ||||||
|  |  | ||||||
| #[cfg(feature = "__tls")] | #[cfg(feature = "__tls")] | ||||||
| fn tunnel_eof() -> io::Error { | fn tunnel_eof() -> BoxError { | ||||||
|     io::Error::new( |     "unexpected eof while tunneling".into() | ||||||
|         io::ErrorKind::UnexpectedEof, |  | ||||||
|         "unexpected eof while tunneling", |  | ||||||
|     ) |  | ||||||
| } | } | ||||||
|  |  | ||||||
| #[cfg(feature = "default-tls")] | #[cfg(feature = "default-tls")] | ||||||
|   | |||||||
							
								
								
									
										21
									
								
								src/error.rs
									
									
									
									
									
								
							
							
						
						
									
										21
									
								
								src/error.rs
									
									
									
									
									
								
							| @@ -82,7 +82,16 @@ impl Error { | |||||||
|  |  | ||||||
|     /// Returns true if the error is related to a timeout. |     /// Returns true if the error is related to a timeout. | ||||||
|     pub fn is_timeout(&self) -> bool { |     pub fn is_timeout(&self) -> bool { | ||||||
|         self.source().map(|e| e.is::<TimedOut>()).unwrap_or(false) |         let mut source = self.source(); | ||||||
|  |  | ||||||
|  |         while let Some(err) = source { | ||||||
|  |             if err.is::<TimedOut>() { | ||||||
|  |                 return true; | ||||||
|  |             } | ||||||
|  |             source = err.source(); | ||||||
|  |         } | ||||||
|  |  | ||||||
|  |         false | ||||||
|     } |     } | ||||||
|  |  | ||||||
|     /// Returns the status code, if the error was generated from a response. |     /// Returns the status code, if the error was generated from a response. | ||||||
| @@ -309,4 +318,14 @@ mod tests { | |||||||
|             _ => panic!("{:?}", err), |             _ => panic!("{:?}", err), | ||||||
|         } |         } | ||||||
|     } |     } | ||||||
|  |  | ||||||
|  |     #[test] | ||||||
|  |     fn is_timeout() { | ||||||
|  |         let err = super::request(super::TimedOut); | ||||||
|  |         assert!(err.is_timeout()); | ||||||
|  |  | ||||||
|  |         let io = io::Error::new(io::ErrorKind::Other, err); | ||||||
|  |         let nested = super::request(io); | ||||||
|  |         assert!(nested.is_timeout()); | ||||||
|  |     } | ||||||
| } | } | ||||||
|   | |||||||
		Reference in New Issue
	
	Block a user