Refactor connect errors to not use io::Error (#782)
This commit is contained in:
		| @@ -177,14 +177,13 @@ impl Connector { | ||||
|                 if dst.scheme() == Some(&Scheme::HTTPS) { | ||||
|                     let host = dst | ||||
|                         .host() | ||||
|                         .ok_or(io::Error::new(io::ErrorKind::Other, "no host in url"))? | ||||
|                         .ok_or("no host in url")? | ||||
|                         .to_string(); | ||||
|                     let conn = socks::connect(proxy, dst, dns).await?; | ||||
|                     let tls_connector = tokio_tls::TlsConnector::from(tls.clone()); | ||||
|                     let io = tls_connector | ||||
|                         .connect(&host, conn) | ||||
|                         .await | ||||
|                         .map_err(|e| io::Error::new(io::ErrorKind::Other, e))?; | ||||
|                         .await?; | ||||
|                     return Ok(Conn { | ||||
|                         inner: self.verbose.wrap(NativeTlsConn { inner: io }), | ||||
|                         is_proxy: false, | ||||
| @@ -200,16 +199,15 @@ impl Connector { | ||||
|                     let tls = tls_proxy.clone(); | ||||
|                     let host = dst | ||||
|                         .host() | ||||
|                         .ok_or(io::Error::new(io::ErrorKind::Other, "no host in url"))? | ||||
|                         .ok_or("no host in url")? | ||||
|                         .to_string(); | ||||
|                     let conn = socks::connect(proxy, dst, dns).await?; | ||||
|                     let dnsname = DNSNameRef::try_from_ascii_str(&host) | ||||
|                         .map(|dnsname| dnsname.to_owned()) | ||||
|                         .map_err(|_| io::Error::new(io::ErrorKind::Other, "Invalid DNS Name"))?; | ||||
|                         .map_err(|_| "Invalid DNS Name")?; | ||||
|                     let io = RustlsConnector::from(tls) | ||||
|                         .connect(dnsname.as_ref(), conn) | ||||
|                         .await | ||||
|                         .map_err(|e| io::Error::new(io::ErrorKind::Other, e))?; | ||||
|                         .await?; | ||||
|                     return Ok(Conn { | ||||
|                         inner: self.verbose.wrap(RustlsTlsConn { inner: io }), | ||||
|                         is_proxy: false, | ||||
| @@ -321,7 +319,7 @@ impl Connector { | ||||
|                     let tunneled = tunnel( | ||||
|                         conn, | ||||
|                         host | ||||
|                             .ok_or(io::Error::new(io::ErrorKind::Other, "no host in url"))? | ||||
|                             .ok_or("no host in url")? | ||||
|                             .to_string(), | ||||
|                         port, | ||||
|                         self.user_agent.clone(), | ||||
| @@ -329,9 +327,8 @@ impl Connector { | ||||
|                     ).await?; | ||||
|                     let tls_connector = tokio_tls::TlsConnector::from(tls.clone()); | ||||
|                     let io = tls_connector | ||||
|                         .connect(&host.ok_or(io::Error::new(io::ErrorKind::Other, "no host in url"))?, tunneled) | ||||
|                         .await | ||||
|                         .map_err(|e| io::Error::new(io::ErrorKind::Other, e))?; | ||||
|                         .connect(&host.ok_or("no host in url")?, tunneled) | ||||
|                         .await?; | ||||
|                     return Ok(Conn { | ||||
|                         inner: self.verbose.wrap(NativeTlsConn { inner: io }), | ||||
|                         is_proxy: false, | ||||
| @@ -350,7 +347,7 @@ impl Connector { | ||||
|  | ||||
|                     let host = dst | ||||
|                         .host() | ||||
|                         .ok_or(io::Error::new(io::ErrorKind::Other, "no host in url"))? | ||||
|                         .ok_or("no host in url")? | ||||
|                         .to_string(); | ||||
|                     let port = dst.port().map(|r| r.as_u16()).unwrap_or(443); | ||||
|                     let mut http = http.clone(); | ||||
| @@ -361,13 +358,12 @@ impl Connector { | ||||
|                     log::trace!("tunneling HTTPS over proxy"); | ||||
|                     let maybe_dnsname = DNSNameRef::try_from_ascii_str(&host) | ||||
|                         .map(|dnsname| dnsname.to_owned()) | ||||
|                         .map_err(|_| io::Error::new(io::ErrorKind::Other, "Invalid DNS Name")); | ||||
|                         .map_err(|_| "Invalid DNS Name"); | ||||
|                     let tunneled = tunnel(conn, host, port, self.user_agent.clone(), auth).await?; | ||||
|                     let dnsname = maybe_dnsname?; | ||||
|                     let io = RustlsConnector::from(tls) | ||||
|                         .connect(dnsname.as_ref(), tunneled) | ||||
|                         .await | ||||
|                         .map_err(|e| io::Error::new(io::ErrorKind::Other, e))?; | ||||
|                         .await?; | ||||
|  | ||||
|                     return Ok(Conn { | ||||
|                         inner: self.verbose.wrap(RustlsTlsConn { inner: io }), | ||||
| @@ -412,7 +408,7 @@ where | ||||
| { | ||||
|     if let Some(to) = timeout { | ||||
|         match tokio::time::timeout(to, f).await { | ||||
|             Err(_elapsed) => Err(Box::new(io::Error::new(io::ErrorKind::TimedOut, "connect timed out")) as BoxError), | ||||
|             Err(_elapsed) => Err(Box::new(crate::error::TimedOut) as BoxError), | ||||
|             Ok(Ok(try_res)) => Ok(try_res), | ||||
|             Ok(Err(e)) => Err(e), | ||||
|         } | ||||
| @@ -478,7 +474,7 @@ impl AsyncRead for Conn { | ||||
|         self: Pin<&mut Self>, | ||||
|         cx: &mut Context, | ||||
|         buf: &mut [u8] | ||||
|     ) -> Poll<tokio::io::Result<usize>> { | ||||
|     ) -> Poll<io::Result<usize>> { | ||||
|         let this = self.project(); | ||||
|         AsyncRead::poll_read(this.inner, cx, buf) | ||||
|     } | ||||
| @@ -494,7 +490,7 @@ impl AsyncRead for Conn { | ||||
|         self: Pin<&mut Self>, | ||||
|         cx: &mut Context, | ||||
|         buf: &mut B | ||||
|     ) -> Poll<tokio::io::Result<usize>> | ||||
|     ) -> Poll<io::Result<usize>> | ||||
|         where | ||||
|             Self: Sized | ||||
|     { | ||||
| @@ -508,12 +504,12 @@ impl AsyncWrite for Conn { | ||||
|         self: Pin<&mut Self>, | ||||
|         cx: &mut Context, | ||||
|         buf: &[u8] | ||||
|     ) -> Poll<Result<usize, tokio::io::Error>> { | ||||
|     ) -> Poll<Result<usize, io::Error>> { | ||||
|         let this = self.project(); | ||||
|         AsyncWrite::poll_write(this.inner, cx, buf) | ||||
|     } | ||||
|  | ||||
|     fn poll_flush(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), tokio::io::Error>> { | ||||
|     fn poll_flush(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), io::Error>> { | ||||
|         let this = self.project(); | ||||
|         AsyncWrite::poll_flush(this.inner, cx) | ||||
|     } | ||||
| @@ -521,7 +517,7 @@ impl AsyncWrite for Conn { | ||||
|     fn poll_shutdown( | ||||
|         self: Pin<&mut Self>, | ||||
|         cx: &mut Context | ||||
|     ) -> Poll<Result<(), tokio::io::Error>> { | ||||
|     ) -> Poll<Result<(), io::Error>> { | ||||
|         let this = self.project(); | ||||
|         AsyncWrite::poll_shutdown(this.inner, cx) | ||||
|     } | ||||
| @@ -530,7 +526,7 @@ impl AsyncWrite for Conn { | ||||
|         self: Pin<&mut Self>, | ||||
|         cx: &mut Context, | ||||
|         buf: &mut B | ||||
|     ) -> Poll<Result<usize, tokio::io::Error>> where | ||||
|     ) -> Poll<Result<usize, io::Error>> where | ||||
|         Self: Sized { | ||||
|         let this = self.project(); | ||||
|         AsyncWrite::poll_write_buf(this.inner, cx, buf) | ||||
| @@ -547,7 +543,7 @@ async fn tunnel<T>( | ||||
|     port: u16, | ||||
|     user_agent: Option<HeaderValue>, | ||||
|     auth: Option<HeaderValue>, | ||||
| ) -> Result<T, io::Error> | ||||
| ) -> Result<T, BoxError> | ||||
| where | ||||
|     T: AsyncRead + AsyncWrite + Unpin, | ||||
| { | ||||
| @@ -601,29 +597,24 @@ where | ||||
|                 return Ok(conn); | ||||
|             } | ||||
|             if pos == buf.len() { | ||||
|                 return Err(io::Error::new( | ||||
|                     io::ErrorKind::Other, | ||||
|                     "proxy headers too long for tunnel", | ||||
|                 )); | ||||
|                 return Err( | ||||
|                     "proxy headers too long for tunnel".into() | ||||
|                 ); | ||||
|             } | ||||
|         // else read more | ||||
|         } else if recvd.starts_with(b"HTTP/1.1 407") { | ||||
|             return Err(io::Error::new( | ||||
|                 io::ErrorKind::Other, | ||||
|                 "proxy authentication required", | ||||
|             )); | ||||
|             return Err( | ||||
|                 "proxy authentication required".into() | ||||
|             ); | ||||
|         } else { | ||||
|             return Err(io::Error::new(io::ErrorKind::Other, "unsuccessful tunnel")); | ||||
|             return Err("unsuccessful tunnel".into()); | ||||
|         } | ||||
|     } | ||||
| } | ||||
|  | ||||
| #[cfg(feature = "__tls")] | ||||
| fn tunnel_eof() -> io::Error { | ||||
|     io::Error::new( | ||||
|         io::ErrorKind::UnexpectedEof, | ||||
|         "unexpected eof while tunneling", | ||||
|     ) | ||||
| fn tunnel_eof() -> BoxError { | ||||
|     "unexpected eof while tunneling".into() | ||||
| } | ||||
|  | ||||
| #[cfg(feature = "default-tls")] | ||||
|   | ||||
							
								
								
									
										21
									
								
								src/error.rs
									
									
									
									
									
								
							
							
						
						
									
										21
									
								
								src/error.rs
									
									
									
									
									
								
							| @@ -82,7 +82,16 @@ impl Error { | ||||
|  | ||||
|     /// Returns true if the error is related to a timeout. | ||||
|     pub fn is_timeout(&self) -> bool { | ||||
|         self.source().map(|e| e.is::<TimedOut>()).unwrap_or(false) | ||||
|         let mut source = self.source(); | ||||
|  | ||||
|         while let Some(err) = source { | ||||
|             if err.is::<TimedOut>() { | ||||
|                 return true; | ||||
|             } | ||||
|             source = err.source(); | ||||
|         } | ||||
|  | ||||
|         false | ||||
|     } | ||||
|  | ||||
|     /// Returns the status code, if the error was generated from a response. | ||||
| @@ -309,4 +318,14 @@ mod tests { | ||||
|             _ => panic!("{:?}", err), | ||||
|         } | ||||
|     } | ||||
|  | ||||
|     #[test] | ||||
|     fn is_timeout() { | ||||
|         let err = super::request(super::TimedOut); | ||||
|         assert!(err.is_timeout()); | ||||
|  | ||||
|         let io = io::Error::new(io::ErrorKind::Other, err); | ||||
|         let nested = super::request(io); | ||||
|         assert!(nested.is_timeout()); | ||||
|     } | ||||
| } | ||||
|   | ||||
		Reference in New Issue
	
	Block a user