Refactor connect errors to not use io::Error (#782)
This commit is contained in:
@@ -177,14 +177,13 @@ impl Connector {
|
|||||||
if dst.scheme() == Some(&Scheme::HTTPS) {
|
if dst.scheme() == Some(&Scheme::HTTPS) {
|
||||||
let host = dst
|
let host = dst
|
||||||
.host()
|
.host()
|
||||||
.ok_or(io::Error::new(io::ErrorKind::Other, "no host in url"))?
|
.ok_or("no host in url")?
|
||||||
.to_string();
|
.to_string();
|
||||||
let conn = socks::connect(proxy, dst, dns).await?;
|
let conn = socks::connect(proxy, dst, dns).await?;
|
||||||
let tls_connector = tokio_tls::TlsConnector::from(tls.clone());
|
let tls_connector = tokio_tls::TlsConnector::from(tls.clone());
|
||||||
let io = tls_connector
|
let io = tls_connector
|
||||||
.connect(&host, conn)
|
.connect(&host, conn)
|
||||||
.await
|
.await?;
|
||||||
.map_err(|e| io::Error::new(io::ErrorKind::Other, e))?;
|
|
||||||
return Ok(Conn {
|
return Ok(Conn {
|
||||||
inner: self.verbose.wrap(NativeTlsConn { inner: io }),
|
inner: self.verbose.wrap(NativeTlsConn { inner: io }),
|
||||||
is_proxy: false,
|
is_proxy: false,
|
||||||
@@ -200,16 +199,15 @@ impl Connector {
|
|||||||
let tls = tls_proxy.clone();
|
let tls = tls_proxy.clone();
|
||||||
let host = dst
|
let host = dst
|
||||||
.host()
|
.host()
|
||||||
.ok_or(io::Error::new(io::ErrorKind::Other, "no host in url"))?
|
.ok_or("no host in url")?
|
||||||
.to_string();
|
.to_string();
|
||||||
let conn = socks::connect(proxy, dst, dns).await?;
|
let conn = socks::connect(proxy, dst, dns).await?;
|
||||||
let dnsname = DNSNameRef::try_from_ascii_str(&host)
|
let dnsname = DNSNameRef::try_from_ascii_str(&host)
|
||||||
.map(|dnsname| dnsname.to_owned())
|
.map(|dnsname| dnsname.to_owned())
|
||||||
.map_err(|_| io::Error::new(io::ErrorKind::Other, "Invalid DNS Name"))?;
|
.map_err(|_| "Invalid DNS Name")?;
|
||||||
let io = RustlsConnector::from(tls)
|
let io = RustlsConnector::from(tls)
|
||||||
.connect(dnsname.as_ref(), conn)
|
.connect(dnsname.as_ref(), conn)
|
||||||
.await
|
.await?;
|
||||||
.map_err(|e| io::Error::new(io::ErrorKind::Other, e))?;
|
|
||||||
return Ok(Conn {
|
return Ok(Conn {
|
||||||
inner: self.verbose.wrap(RustlsTlsConn { inner: io }),
|
inner: self.verbose.wrap(RustlsTlsConn { inner: io }),
|
||||||
is_proxy: false,
|
is_proxy: false,
|
||||||
@@ -321,7 +319,7 @@ impl Connector {
|
|||||||
let tunneled = tunnel(
|
let tunneled = tunnel(
|
||||||
conn,
|
conn,
|
||||||
host
|
host
|
||||||
.ok_or(io::Error::new(io::ErrorKind::Other, "no host in url"))?
|
.ok_or("no host in url")?
|
||||||
.to_string(),
|
.to_string(),
|
||||||
port,
|
port,
|
||||||
self.user_agent.clone(),
|
self.user_agent.clone(),
|
||||||
@@ -329,9 +327,8 @@ impl Connector {
|
|||||||
).await?;
|
).await?;
|
||||||
let tls_connector = tokio_tls::TlsConnector::from(tls.clone());
|
let tls_connector = tokio_tls::TlsConnector::from(tls.clone());
|
||||||
let io = tls_connector
|
let io = tls_connector
|
||||||
.connect(&host.ok_or(io::Error::new(io::ErrorKind::Other, "no host in url"))?, tunneled)
|
.connect(&host.ok_or("no host in url")?, tunneled)
|
||||||
.await
|
.await?;
|
||||||
.map_err(|e| io::Error::new(io::ErrorKind::Other, e))?;
|
|
||||||
return Ok(Conn {
|
return Ok(Conn {
|
||||||
inner: self.verbose.wrap(NativeTlsConn { inner: io }),
|
inner: self.verbose.wrap(NativeTlsConn { inner: io }),
|
||||||
is_proxy: false,
|
is_proxy: false,
|
||||||
@@ -350,7 +347,7 @@ impl Connector {
|
|||||||
|
|
||||||
let host = dst
|
let host = dst
|
||||||
.host()
|
.host()
|
||||||
.ok_or(io::Error::new(io::ErrorKind::Other, "no host in url"))?
|
.ok_or("no host in url")?
|
||||||
.to_string();
|
.to_string();
|
||||||
let port = dst.port().map(|r| r.as_u16()).unwrap_or(443);
|
let port = dst.port().map(|r| r.as_u16()).unwrap_or(443);
|
||||||
let mut http = http.clone();
|
let mut http = http.clone();
|
||||||
@@ -361,13 +358,12 @@ impl Connector {
|
|||||||
log::trace!("tunneling HTTPS over proxy");
|
log::trace!("tunneling HTTPS over proxy");
|
||||||
let maybe_dnsname = DNSNameRef::try_from_ascii_str(&host)
|
let maybe_dnsname = DNSNameRef::try_from_ascii_str(&host)
|
||||||
.map(|dnsname| dnsname.to_owned())
|
.map(|dnsname| dnsname.to_owned())
|
||||||
.map_err(|_| io::Error::new(io::ErrorKind::Other, "Invalid DNS Name"));
|
.map_err(|_| "Invalid DNS Name");
|
||||||
let tunneled = tunnel(conn, host, port, self.user_agent.clone(), auth).await?;
|
let tunneled = tunnel(conn, host, port, self.user_agent.clone(), auth).await?;
|
||||||
let dnsname = maybe_dnsname?;
|
let dnsname = maybe_dnsname?;
|
||||||
let io = RustlsConnector::from(tls)
|
let io = RustlsConnector::from(tls)
|
||||||
.connect(dnsname.as_ref(), tunneled)
|
.connect(dnsname.as_ref(), tunneled)
|
||||||
.await
|
.await?;
|
||||||
.map_err(|e| io::Error::new(io::ErrorKind::Other, e))?;
|
|
||||||
|
|
||||||
return Ok(Conn {
|
return Ok(Conn {
|
||||||
inner: self.verbose.wrap(RustlsTlsConn { inner: io }),
|
inner: self.verbose.wrap(RustlsTlsConn { inner: io }),
|
||||||
@@ -412,7 +408,7 @@ where
|
|||||||
{
|
{
|
||||||
if let Some(to) = timeout {
|
if let Some(to) = timeout {
|
||||||
match tokio::time::timeout(to, f).await {
|
match tokio::time::timeout(to, f).await {
|
||||||
Err(_elapsed) => Err(Box::new(io::Error::new(io::ErrorKind::TimedOut, "connect timed out")) as BoxError),
|
Err(_elapsed) => Err(Box::new(crate::error::TimedOut) as BoxError),
|
||||||
Ok(Ok(try_res)) => Ok(try_res),
|
Ok(Ok(try_res)) => Ok(try_res),
|
||||||
Ok(Err(e)) => Err(e),
|
Ok(Err(e)) => Err(e),
|
||||||
}
|
}
|
||||||
@@ -478,7 +474,7 @@ impl AsyncRead for Conn {
|
|||||||
self: Pin<&mut Self>,
|
self: Pin<&mut Self>,
|
||||||
cx: &mut Context,
|
cx: &mut Context,
|
||||||
buf: &mut [u8]
|
buf: &mut [u8]
|
||||||
) -> Poll<tokio::io::Result<usize>> {
|
) -> Poll<io::Result<usize>> {
|
||||||
let this = self.project();
|
let this = self.project();
|
||||||
AsyncRead::poll_read(this.inner, cx, buf)
|
AsyncRead::poll_read(this.inner, cx, buf)
|
||||||
}
|
}
|
||||||
@@ -494,7 +490,7 @@ impl AsyncRead for Conn {
|
|||||||
self: Pin<&mut Self>,
|
self: Pin<&mut Self>,
|
||||||
cx: &mut Context,
|
cx: &mut Context,
|
||||||
buf: &mut B
|
buf: &mut B
|
||||||
) -> Poll<tokio::io::Result<usize>>
|
) -> Poll<io::Result<usize>>
|
||||||
where
|
where
|
||||||
Self: Sized
|
Self: Sized
|
||||||
{
|
{
|
||||||
@@ -508,12 +504,12 @@ impl AsyncWrite for Conn {
|
|||||||
self: Pin<&mut Self>,
|
self: Pin<&mut Self>,
|
||||||
cx: &mut Context,
|
cx: &mut Context,
|
||||||
buf: &[u8]
|
buf: &[u8]
|
||||||
) -> Poll<Result<usize, tokio::io::Error>> {
|
) -> Poll<Result<usize, io::Error>> {
|
||||||
let this = self.project();
|
let this = self.project();
|
||||||
AsyncWrite::poll_write(this.inner, cx, buf)
|
AsyncWrite::poll_write(this.inner, cx, buf)
|
||||||
}
|
}
|
||||||
|
|
||||||
fn poll_flush(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), tokio::io::Error>> {
|
fn poll_flush(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), io::Error>> {
|
||||||
let this = self.project();
|
let this = self.project();
|
||||||
AsyncWrite::poll_flush(this.inner, cx)
|
AsyncWrite::poll_flush(this.inner, cx)
|
||||||
}
|
}
|
||||||
@@ -521,7 +517,7 @@ impl AsyncWrite for Conn {
|
|||||||
fn poll_shutdown(
|
fn poll_shutdown(
|
||||||
self: Pin<&mut Self>,
|
self: Pin<&mut Self>,
|
||||||
cx: &mut Context
|
cx: &mut Context
|
||||||
) -> Poll<Result<(), tokio::io::Error>> {
|
) -> Poll<Result<(), io::Error>> {
|
||||||
let this = self.project();
|
let this = self.project();
|
||||||
AsyncWrite::poll_shutdown(this.inner, cx)
|
AsyncWrite::poll_shutdown(this.inner, cx)
|
||||||
}
|
}
|
||||||
@@ -530,7 +526,7 @@ impl AsyncWrite for Conn {
|
|||||||
self: Pin<&mut Self>,
|
self: Pin<&mut Self>,
|
||||||
cx: &mut Context,
|
cx: &mut Context,
|
||||||
buf: &mut B
|
buf: &mut B
|
||||||
) -> Poll<Result<usize, tokio::io::Error>> where
|
) -> Poll<Result<usize, io::Error>> where
|
||||||
Self: Sized {
|
Self: Sized {
|
||||||
let this = self.project();
|
let this = self.project();
|
||||||
AsyncWrite::poll_write_buf(this.inner, cx, buf)
|
AsyncWrite::poll_write_buf(this.inner, cx, buf)
|
||||||
@@ -547,7 +543,7 @@ async fn tunnel<T>(
|
|||||||
port: u16,
|
port: u16,
|
||||||
user_agent: Option<HeaderValue>,
|
user_agent: Option<HeaderValue>,
|
||||||
auth: Option<HeaderValue>,
|
auth: Option<HeaderValue>,
|
||||||
) -> Result<T, io::Error>
|
) -> Result<T, BoxError>
|
||||||
where
|
where
|
||||||
T: AsyncRead + AsyncWrite + Unpin,
|
T: AsyncRead + AsyncWrite + Unpin,
|
||||||
{
|
{
|
||||||
@@ -601,29 +597,24 @@ where
|
|||||||
return Ok(conn);
|
return Ok(conn);
|
||||||
}
|
}
|
||||||
if pos == buf.len() {
|
if pos == buf.len() {
|
||||||
return Err(io::Error::new(
|
return Err(
|
||||||
io::ErrorKind::Other,
|
"proxy headers too long for tunnel".into()
|
||||||
"proxy headers too long for tunnel",
|
);
|
||||||
));
|
|
||||||
}
|
}
|
||||||
// else read more
|
// else read more
|
||||||
} else if recvd.starts_with(b"HTTP/1.1 407") {
|
} else if recvd.starts_with(b"HTTP/1.1 407") {
|
||||||
return Err(io::Error::new(
|
return Err(
|
||||||
io::ErrorKind::Other,
|
"proxy authentication required".into()
|
||||||
"proxy authentication required",
|
);
|
||||||
));
|
|
||||||
} else {
|
} else {
|
||||||
return Err(io::Error::new(io::ErrorKind::Other, "unsuccessful tunnel"));
|
return Err("unsuccessful tunnel".into());
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[cfg(feature = "__tls")]
|
#[cfg(feature = "__tls")]
|
||||||
fn tunnel_eof() -> io::Error {
|
fn tunnel_eof() -> BoxError {
|
||||||
io::Error::new(
|
"unexpected eof while tunneling".into()
|
||||||
io::ErrorKind::UnexpectedEof,
|
|
||||||
"unexpected eof while tunneling",
|
|
||||||
)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
#[cfg(feature = "default-tls")]
|
#[cfg(feature = "default-tls")]
|
||||||
|
|||||||
21
src/error.rs
21
src/error.rs
@@ -82,7 +82,16 @@ impl Error {
|
|||||||
|
|
||||||
/// Returns true if the error is related to a timeout.
|
/// Returns true if the error is related to a timeout.
|
||||||
pub fn is_timeout(&self) -> bool {
|
pub fn is_timeout(&self) -> bool {
|
||||||
self.source().map(|e| e.is::<TimedOut>()).unwrap_or(false)
|
let mut source = self.source();
|
||||||
|
|
||||||
|
while let Some(err) = source {
|
||||||
|
if err.is::<TimedOut>() {
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
source = err.source();
|
||||||
|
}
|
||||||
|
|
||||||
|
false
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Returns the status code, if the error was generated from a response.
|
/// Returns the status code, if the error was generated from a response.
|
||||||
@@ -309,4 +318,14 @@ mod tests {
|
|||||||
_ => panic!("{:?}", err),
|
_ => panic!("{:?}", err),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn is_timeout() {
|
||||||
|
let err = super::request(super::TimedOut);
|
||||||
|
assert!(err.is_timeout());
|
||||||
|
|
||||||
|
let io = io::Error::new(io::ErrorKind::Other, err);
|
||||||
|
let nested = super::request(io);
|
||||||
|
assert!(nested.is_timeout());
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user