refactor all to async/await (#617)
Co-authored-by: Danny Browning <danny.browning@protectwise.com> Co-authored-by: Daniel Eades <danieleades@hotmail.com>
This commit is contained in:
805
src/connect.rs
805
src/connect.rs
@@ -1,28 +1,26 @@
|
||||
use futures::Future;
|
||||
use futures::FutureExt;
|
||||
use http::uri::Scheme;
|
||||
use hyper::client::connect::{Connect, Connected, Destination};
|
||||
use tokio_io::{AsyncRead, AsyncWrite};
|
||||
use tokio_timer::Timeout;
|
||||
use tokio::io::{AsyncRead, AsyncWrite};
|
||||
|
||||
#[cfg(feature = "tls")]
|
||||
use bytes::BufMut;
|
||||
#[cfg(feature = "tls")]
|
||||
use futures::Poll;
|
||||
#[cfg(feature = "default-tls")]
|
||||
use native_tls::{TlsConnector, TlsConnectorBuilder};
|
||||
|
||||
use std::future::Future;
|
||||
use std::io;
|
||||
use std::net::IpAddr;
|
||||
use std::pin::Pin;
|
||||
use std::sync::Arc;
|
||||
use std::time::Duration;
|
||||
|
||||
#[cfg(feature = "trust-dns")]
|
||||
use crate::dns::TrustDnsResolver;
|
||||
//#[cfg(feature = "trust-dns")]
|
||||
//use crate::dns::TrustDnsResolver;
|
||||
use crate::proxy::{Proxy, ProxyScheme};
|
||||
use tokio::future::FutureExt as _;
|
||||
|
||||
#[cfg(feature = "trust-dns")]
|
||||
type HttpConnector = hyper::client::HttpConnector<TrustDnsResolver>;
|
||||
#[cfg(not(feature = "trust-dns"))]
|
||||
//#[cfg(feature = "trust-dns")]
|
||||
//type HttpConnector = hyper::client::HttpConnector<TrustDnsResolver>;
|
||||
//#[cfg(not(feature = "trust-dns"))]
|
||||
type HttpConnector = hyper::client::HttpConnector;
|
||||
|
||||
pub(crate) struct Connector {
|
||||
@@ -33,6 +31,7 @@ pub(crate) struct Connector {
|
||||
nodelay: bool,
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
enum Inner {
|
||||
#[cfg(not(feature = "tls"))]
|
||||
Http(HttpConnector),
|
||||
@@ -76,7 +75,7 @@ impl Connector {
|
||||
where
|
||||
T: Into<Option<IpAddr>>,
|
||||
{
|
||||
let tls = try_!(tls.build());
|
||||
let tls = tls.build().map_err(crate::error::from)?;
|
||||
|
||||
let mut http = http_connector()?;
|
||||
http.set_local_address(local_addr.into());
|
||||
@@ -130,25 +129,11 @@ impl Connector {
|
||||
}
|
||||
|
||||
#[cfg(feature = "socks")]
|
||||
fn connect_socks(&self, dst: Destination, proxy: ProxyScheme) -> Connecting {
|
||||
macro_rules! timeout {
|
||||
($future:expr) => {
|
||||
if let Some(dur) = self.timeout {
|
||||
Box::new(Timeout::new($future, dur).map_err(|err| {
|
||||
if err.is_inner() {
|
||||
err.into_inner().expect("is_inner")
|
||||
} else if err.is_elapsed() {
|
||||
io::Error::new(io::ErrorKind::TimedOut, "connect timed out")
|
||||
} else {
|
||||
io::Error::new(io::ErrorKind::Other, err)
|
||||
}
|
||||
}))
|
||||
} else {
|
||||
Box::new($future)
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
async fn connect_socks(
|
||||
&self,
|
||||
dst: Destination,
|
||||
proxy: ProxyScheme,
|
||||
) -> Result<(Conn, Connected), io::Error> {
|
||||
let dns = match proxy {
|
||||
ProxyScheme::Socks5 {
|
||||
remote_dns: false, ..
|
||||
@@ -167,14 +152,15 @@ impl Connector {
|
||||
if dst.scheme() == "https" {
|
||||
use self::native_tls_async::TlsConnectorExt;
|
||||
|
||||
let tls = tls.clone();
|
||||
let host = dst.host().to_owned();
|
||||
let socks_connecting = socks::connect(proxy, dst, dns);
|
||||
return timeout!(socks_connecting.and_then(move |(conn, connected)| {
|
||||
tls.connect_async(&host, conn)
|
||||
.map_err(|e| io::Error::new(io::ErrorKind::Other, e))
|
||||
.map(move |io| (Box::new(io) as Conn, connected))
|
||||
}));
|
||||
let (conn, connected) = socks::connect(proxy, dst, dns).await?;
|
||||
let tls_connector = tokio_tls::TlsConnector::from(tls.clone());
|
||||
let io = tls_connector
|
||||
.connect(&host, conn)
|
||||
.await
|
||||
.map_err(|e| io::Error::new(io::ErrorKind::Other, e))?;
|
||||
Ok((Box::new(io) as Conn, connected))
|
||||
}
|
||||
}
|
||||
#[cfg(feature = "rustls-tls")]
|
||||
@@ -185,40 +171,193 @@ impl Connector {
|
||||
|
||||
let tls = tls_proxy.clone();
|
||||
let host = dst.host().to_owned();
|
||||
let socks_connecting = socks::connect(proxy, dst, dns);
|
||||
return timeout!(socks_connecting.and_then(move |(conn, connected)| {
|
||||
let maybe_dnsname = DNSNameRef::try_from_ascii_str(&host)
|
||||
.map(|dnsname| dnsname.to_owned())
|
||||
.map_err(|_| io::Error::new(io::ErrorKind::Other, "Invalid DNS Name"));
|
||||
futures::future::result(maybe_dnsname)
|
||||
.and_then(move |dnsname| {
|
||||
RustlsConnector::from(tls)
|
||||
.connect(dnsname.as_ref(), conn)
|
||||
.map_err(|e| io::Error::new(io::ErrorKind::Other, e))
|
||||
})
|
||||
.map(move |io| (Box::new(io) as Conn, connected))
|
||||
}));
|
||||
let (conn, connected) = socks::connect(proxy, dst, dns);
|
||||
let dnsname = DNSNameRef::try_from_ascii_str(&host)
|
||||
.map(|dnsname| dnsname.to_owned())
|
||||
.map_err(|_| io::Error::new(io::ErrorKind::Other, "Invalid DNS Name"))?;
|
||||
let io = RustlsConnector::from(tls)
|
||||
.connect(dnsname.as_ref(), conn)
|
||||
.await
|
||||
.map_err(|e| io::Error::new(io::ErrorKind::Other, e))?;
|
||||
Ok((Box::new(io) as Conn, connected))
|
||||
}
|
||||
}
|
||||
#[cfg(not(feature = "tls"))]
|
||||
Inner::Http(_) => (),
|
||||
Inner::Http(_) => socks::connect(proxy, dst, dns),
|
||||
}
|
||||
|
||||
// else no TLS
|
||||
socks::connect(proxy, dst, dns)
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(feature = "trust-dns")]
|
||||
//#[cfg(feature = "trust-dns")]
|
||||
//fn http_connector() -> crate::Result<HttpConnector> {
|
||||
// TrustDnsResolver::new()
|
||||
// .map(HttpConnector::new_with_resolver)
|
||||
// .map_err(crate::error::dns_system_conf)
|
||||
//}
|
||||
|
||||
//#[cfg(not(feature = "trust-dns"))]
|
||||
fn http_connector() -> crate::Result<HttpConnector> {
|
||||
TrustDnsResolver::new()
|
||||
.map(HttpConnector::new_with_resolver)
|
||||
.map_err(crate::error::dns_system_conf)
|
||||
Ok(HttpConnector::new())
|
||||
}
|
||||
|
||||
#[cfg(not(feature = "trust-dns"))]
|
||||
fn http_connector() -> crate::Result<HttpConnector> {
|
||||
Ok(HttpConnector::new(4))
|
||||
async fn connect_with_maybe_proxy(
|
||||
inner: Inner,
|
||||
dst: Destination,
|
||||
is_proxy: bool,
|
||||
no_delay: bool,
|
||||
) -> Result<(Conn, Connected), io::Error> {
|
||||
match inner {
|
||||
#[cfg(not(feature = "tls"))]
|
||||
Inner::Http(http) => {
|
||||
drop(no_delay); // only used for TLS?
|
||||
let (io, connected) = http.connect(dst).await?;
|
||||
Ok((Box::new(io) as Conn, connected.proxy(is_proxy)))
|
||||
}
|
||||
#[cfg(feature = "default-tls")]
|
||||
Inner::DefaultTls(http, tls) => {
|
||||
let mut http = http.clone();
|
||||
|
||||
http.set_nodelay(no_delay || (dst.scheme() == "https"));
|
||||
|
||||
let tls_connector = tokio_tls::TlsConnector::from(tls.clone());
|
||||
let http = hyper_tls::HttpsConnector::from((http, tls_connector));
|
||||
let (io, connected) = http.connect(dst).await?;
|
||||
//TODO: where's this at now?
|
||||
//if let hyper_tls::MaybeHttpsStream::Https(_stream) = &io {
|
||||
// if !no_delay {
|
||||
// stream.set_nodelay(false)?;
|
||||
// }
|
||||
//}
|
||||
|
||||
Ok((Box::new(io) as Conn, connected.proxy(is_proxy)))
|
||||
}
|
||||
#[cfg(feature = "rustls-tls")]
|
||||
Inner::RustlsTls { http, tls, .. } => {
|
||||
let mut http = http.clone();
|
||||
|
||||
// Disable Nagle's algorithm for TLS handshake
|
||||
//
|
||||
// https://www.openssl.org/docs/man1.1.1/man3/SSL_connect.html#NOTES
|
||||
http.set_nodelay(nodelay || (dst.scheme() == "https"));
|
||||
|
||||
let http = hyper_rustls::HttpsConnector::from((http, tls.clone()));
|
||||
let (io, connected) = http.connect(dst).await;
|
||||
if let hyper_rustls::MaybeHttpsStream::Https(stream) = &io {
|
||||
if !nodelay {
|
||||
let (io, _) = stream.get_ref();
|
||||
io.set_nodelay(false)?;
|
||||
}
|
||||
}
|
||||
|
||||
Ok((Box::new(io) as Conn, connected.proxy(is_proxy)))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
async fn connect_via_proxy(
|
||||
inner: Inner,
|
||||
dst: Destination,
|
||||
proxy_scheme: ProxyScheme,
|
||||
no_delay: bool,
|
||||
) -> Result<(Conn, Connected), io::Error> {
|
||||
log::trace!("proxy({:?}) intercepts {:?}", proxy_scheme, dst);
|
||||
|
||||
let (puri, _auth) = match proxy_scheme {
|
||||
ProxyScheme::Http { uri, auth, .. } => (uri, auth),
|
||||
#[cfg(feature = "socks")]
|
||||
ProxyScheme::Socks5 { .. } => return this.connect_socks(dst, proxy_scheme),
|
||||
};
|
||||
|
||||
let mut ndst = dst.clone();
|
||||
|
||||
let new_scheme = puri.scheme_part().map(Scheme::as_str).unwrap_or("http");
|
||||
ndst.set_scheme(new_scheme)
|
||||
.expect("proxy target scheme should be valid");
|
||||
|
||||
ndst.set_host(puri.host().expect("proxy target should have host"))
|
||||
.expect("proxy target host should be valid");
|
||||
|
||||
ndst.set_port(puri.port_part().map(|port| port.as_u16()));
|
||||
|
||||
#[cfg(feature = "tls")]
|
||||
let auth = _auth;
|
||||
|
||||
match &inner {
|
||||
#[cfg(feature = "default-tls")]
|
||||
Inner::DefaultTls(http, tls) => {
|
||||
if dst.scheme() == "https" {
|
||||
let host = dst.host().to_owned();
|
||||
let port = dst.port().unwrap_or(443);
|
||||
let mut http = http.clone();
|
||||
http.set_nodelay(no_delay);
|
||||
let tls_connector = tokio_tls::TlsConnector::from(tls.clone());
|
||||
let http = hyper_tls::HttpsConnector::from((http, tls_connector));
|
||||
let (conn, connected) = http.connect(ndst).await?;
|
||||
log::trace!("tunneling HTTPS over proxy");
|
||||
let tunneled = tunnel(conn, host.clone(), port, auth).await?;
|
||||
let tls_connector = tokio_tls::TlsConnector::from(tls.clone());
|
||||
let io = tls_connector
|
||||
.connect(&host, tunneled)
|
||||
.await
|
||||
.map_err(|e| io::Error::new(io::ErrorKind::Other, e))?;
|
||||
return Ok((Box::new(io) as Conn, connected.proxy(true)));
|
||||
}
|
||||
}
|
||||
#[cfg(feature = "rustls-tls")]
|
||||
Inner::RustlsTls {
|
||||
http,
|
||||
tls,
|
||||
tls_proxy,
|
||||
} => {
|
||||
if dst.scheme() == "https" {
|
||||
use rustls::Session;
|
||||
use tokio_rustls::webpki::DNSNameRef;
|
||||
use tokio_rustls::TlsConnector as RustlsConnector;
|
||||
|
||||
let host = dst.host().to_owned();
|
||||
let port = dst.port().unwrap_or(443);
|
||||
let mut http = http.clone();
|
||||
http.set_nodelay(nodelay);
|
||||
let http = hyper_rustls::HttpsConnector::from((http, tls_proxy.clone()));
|
||||
let tls = tls.clone();
|
||||
let (conn, connected) = http.connect(ndst).await;
|
||||
log::trace!("tunneling HTTPS over proxy");
|
||||
let maybe_dnsname = DNSNameRef::try_from_ascii_str(&host)
|
||||
.map(|dnsname| dnsname.to_owned())
|
||||
.map_err(|_| io::Error::new(io::ErrorKind::Other, "Invalid DNS Name"));
|
||||
let tunneled = tunnel(conn, host, port, auth).await;
|
||||
let dnsname = maybe_dnsname?;
|
||||
let io = RustlsConnector::from(tls)
|
||||
.connect(dnsname.as_ref(), tunneled)
|
||||
.await
|
||||
.map_err(|e| io::Error::new(io::ErrorKind::Other, e))?;
|
||||
let connected = if io.get_ref().1.get_alpn_protocol() == Some(b"h2") {
|
||||
connected.negotiated_h2()
|
||||
} else {
|
||||
connected
|
||||
};
|
||||
return Ok((Box::new(io) as Conn, connected.proxy(true)));
|
||||
}
|
||||
}
|
||||
#[cfg(not(feature = "tls"))]
|
||||
Inner::Http(_) => (),
|
||||
}
|
||||
|
||||
connect_with_maybe_proxy(inner, ndst, true, no_delay).await
|
||||
}
|
||||
|
||||
async fn with_timeout<T, F>(f: F, timeout: Option<Duration>) -> Result<T, io::Error>
|
||||
where
|
||||
F: Future<Output = Result<T, io::Error>>,
|
||||
{
|
||||
if let Some(to) = timeout {
|
||||
match f.timeout(to).await {
|
||||
Err(_elapsed) => Err(io::Error::new(io::ErrorKind::TimedOut, "connect timed out")),
|
||||
Ok(try_res) => try_res,
|
||||
}
|
||||
} else {
|
||||
f.await
|
||||
}
|
||||
}
|
||||
|
||||
impl Connect for Connector {
|
||||
@@ -228,202 +367,47 @@ impl Connect for Connector {
|
||||
|
||||
fn connect(&self, dst: Destination) -> Self::Future {
|
||||
#[cfg(feature = "tls")]
|
||||
let nodelay = self.nodelay;
|
||||
|
||||
macro_rules! timeout {
|
||||
($future:expr) => {
|
||||
if let Some(dur) = self.timeout {
|
||||
Box::new(Timeout::new($future, dur).map_err(|err| {
|
||||
if err.is_inner() {
|
||||
err.into_inner().expect("is_inner")
|
||||
} else if err.is_elapsed() {
|
||||
io::Error::new(io::ErrorKind::TimedOut, "connect timed out")
|
||||
} else {
|
||||
io::Error::new(io::ErrorKind::Other, err)
|
||||
}
|
||||
}))
|
||||
} else {
|
||||
Box::new($future)
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
macro_rules! connect {
|
||||
( $http:expr, $dst:expr, $proxy:expr ) => {
|
||||
timeout!($http
|
||||
.connect($dst)
|
||||
.map(|(io, connected)| (Box::new(io) as Conn, connected.proxy($proxy))))
|
||||
};
|
||||
( $dst:expr, $proxy:expr ) => {
|
||||
match &self.inner {
|
||||
#[cfg(not(feature = "tls"))]
|
||||
Inner::Http(http) => connect!(http, $dst, $proxy),
|
||||
#[cfg(feature = "default-tls")]
|
||||
Inner::DefaultTls(http, tls) => {
|
||||
let mut http = http.clone();
|
||||
|
||||
http.set_nodelay(nodelay || ($dst.scheme() == "https"));
|
||||
|
||||
let http = hyper_tls::HttpsConnector::from((http, tls.clone()));
|
||||
timeout!(http.connect($dst).and_then(move |(io, connected)| {
|
||||
if let hyper_tls::MaybeHttpsStream::Https(stream) = &io {
|
||||
if !nodelay {
|
||||
stream.get_ref().get_ref().set_nodelay(false)?;
|
||||
}
|
||||
}
|
||||
|
||||
Ok((Box::new(io) as Conn, connected.proxy($proxy)))
|
||||
}))
|
||||
}
|
||||
#[cfg(feature = "rustls-tls")]
|
||||
Inner::RustlsTls { http, tls, .. } => {
|
||||
let mut http = http.clone();
|
||||
|
||||
// Disable Nagle's algorithm for TLS handshake
|
||||
//
|
||||
// https://www.openssl.org/docs/man1.1.1/man3/SSL_connect.html#NOTES
|
||||
http.set_nodelay(nodelay || ($dst.scheme() == "https"));
|
||||
|
||||
let http = hyper_rustls::HttpsConnector::from((http, tls.clone()));
|
||||
timeout!(http.connect($dst).and_then(move |(io, connected)| {
|
||||
if let hyper_rustls::MaybeHttpsStream::Https(stream) = &io {
|
||||
if !nodelay {
|
||||
let (io, _) = stream.get_ref();
|
||||
io.set_nodelay(false)?;
|
||||
}
|
||||
}
|
||||
|
||||
Ok((Box::new(io) as Conn, connected.proxy($proxy)))
|
||||
}))
|
||||
}
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
let no_delay = self.nodelay;
|
||||
#[cfg(not(feature = "tls"))]
|
||||
let no_delay = false;
|
||||
let timeout = self.timeout;
|
||||
for prox in self.proxies.iter() {
|
||||
if let Some(proxy_scheme) = prox.intercept(&dst) {
|
||||
log::trace!("proxy({:?}) intercepts {:?}", proxy_scheme, dst);
|
||||
|
||||
let (puri, _auth) = match proxy_scheme {
|
||||
ProxyScheme::Http { uri, auth, .. } => (uri, auth),
|
||||
#[cfg(feature = "socks")]
|
||||
ProxyScheme::Socks5 { .. } => return self.connect_socks(dst, proxy_scheme),
|
||||
};
|
||||
|
||||
let mut ndst = dst.clone();
|
||||
|
||||
let new_scheme = puri.scheme_part().map(Scheme::as_str).unwrap_or("http");
|
||||
ndst.set_scheme(new_scheme)
|
||||
.expect("proxy target scheme should be valid");
|
||||
|
||||
ndst.set_host(puri.host().expect("proxy target should have host"))
|
||||
.expect("proxy target host should be valid");
|
||||
|
||||
ndst.set_port(puri.port_part().map(|port| port.as_u16()));
|
||||
|
||||
#[cfg(feature = "tls")]
|
||||
let auth = _auth;
|
||||
|
||||
match &self.inner {
|
||||
#[cfg(feature = "default-tls")]
|
||||
Inner::DefaultTls(http, tls) => {
|
||||
if dst.scheme() == "https" {
|
||||
use self::native_tls_async::TlsConnectorExt;
|
||||
|
||||
let host = dst.host().to_owned();
|
||||
let port = dst.port().unwrap_or(443);
|
||||
let mut http = http.clone();
|
||||
http.set_nodelay(nodelay);
|
||||
let http = hyper_tls::HttpsConnector::from((http, tls.clone()));
|
||||
let tls = tls.clone();
|
||||
return timeout!(http.connect(ndst).and_then(
|
||||
move |(conn, connected)| {
|
||||
log::trace!("tunneling HTTPS over proxy");
|
||||
tunnel(conn, host.clone(), port, auth)
|
||||
.and_then(move |tunneled| {
|
||||
tls.connect_async(&host, tunneled).map_err(|e| {
|
||||
io::Error::new(io::ErrorKind::Other, e)
|
||||
})
|
||||
})
|
||||
.map(|io| (Box::new(io) as Conn, connected.proxy(true)))
|
||||
}
|
||||
));
|
||||
}
|
||||
}
|
||||
#[cfg(feature = "rustls-tls")]
|
||||
Inner::RustlsTls {
|
||||
http,
|
||||
tls,
|
||||
tls_proxy,
|
||||
} => {
|
||||
if dst.scheme() == "https" {
|
||||
use rustls::Session;
|
||||
use tokio_rustls::webpki::DNSNameRef;
|
||||
use tokio_rustls::TlsConnector as RustlsConnector;
|
||||
|
||||
let host = dst.host().to_owned();
|
||||
let port = dst.port().unwrap_or(443);
|
||||
let mut http = http.clone();
|
||||
http.set_nodelay(nodelay);
|
||||
let http =
|
||||
hyper_rustls::HttpsConnector::from((http, tls_proxy.clone()));
|
||||
let tls = tls.clone();
|
||||
return timeout!(http.connect(ndst).and_then(
|
||||
move |(conn, connected)| {
|
||||
log::trace!("tunneling HTTPS over proxy");
|
||||
let maybe_dnsname = DNSNameRef::try_from_ascii_str(&host)
|
||||
.map(|dnsname| dnsname.to_owned())
|
||||
.map_err(|_| {
|
||||
io::Error::new(io::ErrorKind::Other, "Invalid DNS Name")
|
||||
});
|
||||
tunnel(conn, host, port, auth)
|
||||
.and_then(move |tunneled| Ok((maybe_dnsname?, tunneled)))
|
||||
.and_then(move |(dnsname, tunneled)| {
|
||||
RustlsConnector::from(tls)
|
||||
.connect(dnsname.as_ref(), tunneled)
|
||||
.map_err(|e| {
|
||||
io::Error::new(io::ErrorKind::Other, e)
|
||||
})
|
||||
})
|
||||
.map(|io| {
|
||||
let connected = if io.get_ref().1.get_alpn_protocol()
|
||||
== Some(b"h2")
|
||||
{
|
||||
connected.negotiated_h2()
|
||||
} else {
|
||||
connected
|
||||
};
|
||||
(Box::new(io) as Conn, connected.proxy(true))
|
||||
})
|
||||
}
|
||||
));
|
||||
}
|
||||
}
|
||||
#[cfg(not(feature = "tls"))]
|
||||
Inner::Http(_) => (),
|
||||
}
|
||||
|
||||
return connect!(ndst, true);
|
||||
return with_timeout(
|
||||
connect_via_proxy(self.inner.clone(), dst, proxy_scheme, no_delay),
|
||||
timeout,
|
||||
)
|
||||
.boxed();
|
||||
}
|
||||
}
|
||||
|
||||
connect!(dst, false)
|
||||
with_timeout(
|
||||
connect_with_maybe_proxy(self.inner.clone(), dst, false, no_delay),
|
||||
timeout,
|
||||
)
|
||||
.boxed()
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) trait AsyncConn: AsyncRead + AsyncWrite {}
|
||||
impl<T: AsyncRead + AsyncWrite> AsyncConn for T {}
|
||||
pub(crate) type Conn = Box<dyn AsyncConn + Send + Sync + 'static>;
|
||||
pub(crate) type Conn = Box<dyn AsyncConn + Send + Sync + Unpin + 'static>;
|
||||
|
||||
pub(crate) type Connecting = Box<dyn Future<Item = (Conn, Connected), Error = io::Error> + Send>;
|
||||
pub(crate) type Connecting =
|
||||
Pin<Box<dyn Future<Output = Result<(Conn, Connected), io::Error>> + Send>>;
|
||||
|
||||
#[cfg(feature = "tls")]
|
||||
fn tunnel<T>(
|
||||
conn: T,
|
||||
async fn tunnel<T>(
|
||||
mut conn: T,
|
||||
host: String,
|
||||
port: u16,
|
||||
auth: Option<http::header::HeaderValue>,
|
||||
) -> Tunnel<T> {
|
||||
) -> Result<T, io::Error>
|
||||
where
|
||||
T: AsyncRead + AsyncWrite + Unpin,
|
||||
{
|
||||
use tokio::io::{AsyncReadExt, AsyncWriteExt};
|
||||
|
||||
let mut buf = format!(
|
||||
"\
|
||||
CONNECT {0}:{1} HTTP/1.1\r\n\
|
||||
@@ -443,84 +427,43 @@ fn tunnel<T>(
|
||||
// headers end
|
||||
buf.extend_from_slice(b"\r\n");
|
||||
|
||||
Tunnel {
|
||||
buf: io::Cursor::new(buf),
|
||||
conn: Some(conn),
|
||||
state: TunnelState::Writing,
|
||||
}
|
||||
}
|
||||
conn.write_all(&buf).await?;
|
||||
|
||||
#[cfg(feature = "tls")]
|
||||
struct Tunnel<T> {
|
||||
buf: io::Cursor<Vec<u8>>,
|
||||
conn: Option<T>,
|
||||
state: TunnelState,
|
||||
}
|
||||
let mut buf = [0; 8192];
|
||||
let mut pos = 0;
|
||||
|
||||
#[cfg(feature = "tls")]
|
||||
enum TunnelState {
|
||||
Writing,
|
||||
Reading,
|
||||
}
|
||||
loop {
|
||||
let n = conn.read(&mut buf[pos..]).await?;
|
||||
|
||||
#[cfg(feature = "tls")]
|
||||
impl<T> Future for Tunnel<T>
|
||||
where
|
||||
T: AsyncRead + AsyncWrite,
|
||||
{
|
||||
type Item = T;
|
||||
type Error = io::Error;
|
||||
if n == 0 {
|
||||
return Err(tunnel_eof());
|
||||
}
|
||||
pos += n;
|
||||
|
||||
fn poll(&mut self) -> Poll<Self::Item, Self::Error> {
|
||||
loop {
|
||||
if let TunnelState::Writing = self.state {
|
||||
let n = futures::try_ready!(self.conn.as_mut().unwrap().write_buf(&mut self.buf));
|
||||
if !self.buf.has_remaining_mut() {
|
||||
self.state = TunnelState::Reading;
|
||||
self.buf.get_mut().truncate(0);
|
||||
} else if n == 0 {
|
||||
return Err(tunnel_eof());
|
||||
}
|
||||
} else {
|
||||
let n = futures::try_ready!(self
|
||||
.conn
|
||||
.as_mut()
|
||||
.unwrap()
|
||||
.read_buf(&mut self.buf.get_mut()));
|
||||
let read = &self.buf.get_ref()[..];
|
||||
if n == 0 {
|
||||
return Err(tunnel_eof());
|
||||
} else if read.len() > 12 {
|
||||
if read.starts_with(b"HTTP/1.1 200") || read.starts_with(b"HTTP/1.0 200") {
|
||||
if read.ends_with(b"\r\n\r\n") {
|
||||
return Ok(self.conn.take().unwrap().into());
|
||||
}
|
||||
// else read more
|
||||
} else if read.starts_with(b"HTTP/1.1 407") {
|
||||
return Err(io::Error::new(
|
||||
io::ErrorKind::Other,
|
||||
"proxy authentication required",
|
||||
));
|
||||
} else if read.starts_with(b"HTTP/1.1 403") {
|
||||
return Err(io::Error::new(
|
||||
io::ErrorKind::Other,
|
||||
"proxy blocked this request",
|
||||
));
|
||||
} else {
|
||||
let (fst, _) = read.split_at(12);
|
||||
return Err(io::Error::new(
|
||||
io::ErrorKind::Other,
|
||||
format!("unsuccessful tunnel: {:?}", fst).as_str(),
|
||||
));
|
||||
}
|
||||
}
|
||||
let recvd = &buf[..pos];
|
||||
if recvd.starts_with(b"HTTP/1.1 200") || recvd.starts_with(b"HTTP/1.0 200") {
|
||||
if recvd.ends_with(b"\r\n\r\n") {
|
||||
return Ok(conn);
|
||||
}
|
||||
if pos == buf.len() {
|
||||
return Err(io::Error::new(
|
||||
io::ErrorKind::Other,
|
||||
"proxy headers too long for tunnel",
|
||||
));
|
||||
}
|
||||
// else read more
|
||||
} else if recvd.starts_with(b"HTTP/1.1 407") {
|
||||
return Err(io::Error::new(
|
||||
io::ErrorKind::Other,
|
||||
"proxy authentication required",
|
||||
));
|
||||
} else {
|
||||
return Err(io::Error::new(io::ErrorKind::Other, "unsuccessful tunnel"));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(feature = "tls")]
|
||||
#[inline]
|
||||
fn tunnel_eof() -> io::Error {
|
||||
io::Error::new(
|
||||
io::ErrorKind::UnexpectedEof,
|
||||
@@ -528,138 +471,6 @@ fn tunnel_eof() -> io::Error {
|
||||
)
|
||||
}
|
||||
|
||||
#[cfg(feature = "default-tls")]
|
||||
mod native_tls_async {
|
||||
use std::io::{self, Read, Write};
|
||||
|
||||
use futures::{Async, Future, Poll};
|
||||
use native_tls::{self, Error, HandshakeError, TlsConnector};
|
||||
use tokio_io::{try_nb, AsyncRead, AsyncWrite};
|
||||
|
||||
/// A wrapper around an underlying raw stream which implements the TLS or SSL
|
||||
/// protocol.
|
||||
///
|
||||
/// A `TlsStream<S>` represents a handshake that has been completed successfully
|
||||
/// and both the server and the client are ready for receiving and sending
|
||||
/// data. Bytes read from a `TlsStream` are decrypted from `S` and bytes written
|
||||
/// to a `TlsStream` are encrypted when passing through to `S`.
|
||||
#[derive(Debug)]
|
||||
pub struct TlsStream<S> {
|
||||
inner: native_tls::TlsStream<S>,
|
||||
}
|
||||
|
||||
/// Future returned from `TlsConnectorExt::connect_async` which will resolve
|
||||
/// once the connection handshake has finished.
|
||||
pub struct ConnectAsync<S> {
|
||||
inner: MidHandshake<S>,
|
||||
}
|
||||
|
||||
struct MidHandshake<S> {
|
||||
inner: Option<Result<native_tls::TlsStream<S>, HandshakeError<S>>>,
|
||||
}
|
||||
|
||||
/// Extension trait for the `TlsConnector` type in the `native_tls` crate.
|
||||
pub trait TlsConnectorExt: sealed::Sealed {
|
||||
/// Connects the provided stream with this connector, assuming the provided
|
||||
/// domain.
|
||||
///
|
||||
/// This function will internally call `TlsConnector::connect` to connect
|
||||
/// the stream and returns a future representing the resolution of the
|
||||
/// connection operation. The returned future will resolve to either
|
||||
/// `TlsStream<S>` or `Error` depending if it's successful or not.
|
||||
///
|
||||
/// This is typically used for clients who have already established, for
|
||||
/// example, a TCP connection to a remote server. That stream is then
|
||||
/// provided here to perform the client half of a connection to a
|
||||
/// TLS-powered server.
|
||||
///
|
||||
/// # Compatibility notes
|
||||
///
|
||||
/// Note that this method currently requires `S: Read + Write` but it's
|
||||
/// highly recommended to ensure that the object implements the `AsyncRead`
|
||||
/// and `AsyncWrite` traits as well, otherwise this function will not work
|
||||
/// properly.
|
||||
fn connect_async<S>(&self, domain: &str, stream: S) -> ConnectAsync<S>
|
||||
where
|
||||
S: Read + Write; // TODO: change to AsyncRead + AsyncWrite
|
||||
}
|
||||
|
||||
mod sealed {
|
||||
pub trait Sealed {}
|
||||
}
|
||||
|
||||
impl<S: Read + Write> Read for TlsStream<S> {
|
||||
fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
|
||||
self.inner.read(buf)
|
||||
}
|
||||
}
|
||||
|
||||
impl<S: Read + Write> Write for TlsStream<S> {
|
||||
fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
|
||||
self.inner.write(buf)
|
||||
}
|
||||
|
||||
fn flush(&mut self) -> io::Result<()> {
|
||||
self.inner.flush()
|
||||
}
|
||||
}
|
||||
|
||||
impl<S: AsyncRead + AsyncWrite> AsyncRead for TlsStream<S> {}
|
||||
|
||||
impl<S: AsyncRead + AsyncWrite> AsyncWrite for TlsStream<S> {
|
||||
fn shutdown(&mut self) -> Poll<(), io::Error> {
|
||||
try_nb!(self.inner.shutdown());
|
||||
self.inner.get_mut().shutdown()
|
||||
}
|
||||
}
|
||||
|
||||
impl TlsConnectorExt for TlsConnector {
|
||||
fn connect_async<S>(&self, domain: &str, stream: S) -> ConnectAsync<S>
|
||||
where
|
||||
S: Read + Write,
|
||||
{
|
||||
ConnectAsync {
|
||||
inner: MidHandshake {
|
||||
inner: Some(self.connect(domain, stream)),
|
||||
},
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl sealed::Sealed for TlsConnector {}
|
||||
|
||||
// TODO: change this to AsyncRead/AsyncWrite on next major version
|
||||
impl<S: Read + Write> Future for ConnectAsync<S> {
|
||||
type Item = TlsStream<S>;
|
||||
type Error = Error;
|
||||
|
||||
fn poll(&mut self) -> Poll<TlsStream<S>, Error> {
|
||||
self.inner.poll()
|
||||
}
|
||||
}
|
||||
|
||||
// TODO: change this to AsyncRead/AsyncWrite on next major version
|
||||
impl<S: Read + Write> Future for MidHandshake<S> {
|
||||
type Item = TlsStream<S>;
|
||||
type Error = Error;
|
||||
|
||||
fn poll(&mut self) -> Poll<TlsStream<S>, Error> {
|
||||
match self.inner.take().expect("cannot poll MidHandshake twice") {
|
||||
Ok(stream) => Ok(TlsStream { inner: stream }.into()),
|
||||
Err(HandshakeError::Failure(e)) => Err(e),
|
||||
Err(HandshakeError::WouldBlock(s)) => match s.handshake() {
|
||||
Ok(stream) => Ok(TlsStream { inner: stream }.into()),
|
||||
Err(HandshakeError::Failure(e)) => Err(e),
|
||||
Err(HandshakeError::WouldBlock(s)) => {
|
||||
self.inner = Some(Err(HandshakeError::WouldBlock(s)));
|
||||
Ok(Async::NotReady)
|
||||
}
|
||||
},
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(feature = "socks")]
|
||||
mod socks {
|
||||
use std::io;
|
||||
@@ -678,19 +489,18 @@ mod socks {
|
||||
Proxy,
|
||||
}
|
||||
|
||||
pub(super) fn connect(proxy: ProxyScheme, dst: Destination, dns: DnsResolve) -> Connecting {
|
||||
pub(super) async fn connect(
|
||||
proxy: ProxyScheme,
|
||||
dst: Destination,
|
||||
dns: DnsResolve,
|
||||
) -> Result<(super::Conn, Connected), io::Error> {
|
||||
let https = dst.scheme() == "https";
|
||||
let original_host = dst.host().to_owned();
|
||||
let mut host = original_host.clone();
|
||||
let port = dst.port().unwrap_or_else(|| if https { 443 } else { 80 });
|
||||
|
||||
if let DnsResolve::Local = dns {
|
||||
let maybe_new_target = match (host.as_str(), port).to_socket_addrs() {
|
||||
Ok(mut iter) => iter.next(),
|
||||
Err(err) => {
|
||||
return Box::new(future::err(err));
|
||||
}
|
||||
};
|
||||
let maybe_new_target = (host.as_str(), port).to_socket_addrs()?.next();
|
||||
if let Some(new_target) = maybe_new_target {
|
||||
host = new_target.ip().to_string();
|
||||
}
|
||||
@@ -702,39 +512,33 @@ mod socks {
|
||||
};
|
||||
|
||||
// Get a Tokio TcpStream
|
||||
let stream = future::result(
|
||||
if let Some((username, password)) = auth {
|
||||
Socks5Stream::connect_with_password(
|
||||
socket_addr,
|
||||
(host.as_str(), port),
|
||||
&username,
|
||||
&password,
|
||||
)
|
||||
} else {
|
||||
Socks5Stream::connect(socket_addr, (host.as_str(), port))
|
||||
}
|
||||
.and_then(|s| {
|
||||
TcpStream::from_std(s.into_inner(), &reactor::Handle::default())
|
||||
.map_err(|e| io::Error::new(io::ErrorKind::Other, e))
|
||||
}),
|
||||
);
|
||||
let stream = if let Some((username, password)) = auth {
|
||||
Socks5Stream::connect_with_password(
|
||||
socket_addr,
|
||||
(host.as_str(), port),
|
||||
&username,
|
||||
&password,
|
||||
)
|
||||
.await
|
||||
} else {
|
||||
let s = Socks5Stream::connect(socket_addr, (host.as_str(), port)).await;
|
||||
TcpStream::from_std(s.into_inner(), &reactor::Handle::default())
|
||||
.map_err(|e| io::Error::new(io::ErrorKind::Other, e))?
|
||||
};
|
||||
|
||||
Box::new(stream.map(|s| (Box::new(s) as super::Conn, Connected::new())))
|
||||
Ok((Box::new(s) as super::Conn, Connected::new()))
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(feature = "tls")]
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
extern crate tokio_tcp;
|
||||
|
||||
use self::tokio_tcp::TcpStream;
|
||||
use super::tunnel;
|
||||
use crate::proxy;
|
||||
use futures::Future;
|
||||
use std::io::{Read, Write};
|
||||
use std::net::TcpListener;
|
||||
use std::thread;
|
||||
use tokio::net::tcp::TcpStream;
|
||||
use tokio::runtime::current_thread::Runtime;
|
||||
|
||||
static TUNNEL_OK: &[u8] = b"\
|
||||
@@ -782,12 +586,14 @@ mod tests {
|
||||
let addr = mock_tunnel!();
|
||||
|
||||
let mut rt = Runtime::new().unwrap();
|
||||
let work = TcpStream::connect(&addr);
|
||||
let host = addr.ip().to_string();
|
||||
let port = addr.port();
|
||||
let work = work.and_then(|tcp| tunnel(tcp, host, port, None));
|
||||
let f = async move {
|
||||
let tcp = TcpStream::connect(&addr).await?;
|
||||
let host = addr.ip().to_string();
|
||||
let port = addr.port();
|
||||
tunnel(tcp, host, port, None).await
|
||||
};
|
||||
|
||||
rt.block_on(work).unwrap();
|
||||
rt.block_on(f).unwrap();
|
||||
}
|
||||
|
||||
#[test]
|
||||
@@ -795,12 +601,14 @@ mod tests {
|
||||
let addr = mock_tunnel!(b"HTTP/1.1 200 OK");
|
||||
|
||||
let mut rt = Runtime::new().unwrap();
|
||||
let work = TcpStream::connect(&addr);
|
||||
let host = addr.ip().to_string();
|
||||
let port = addr.port();
|
||||
let work = work.and_then(|tcp| tunnel(tcp, host, port, None));
|
||||
let f = async move {
|
||||
let tcp = TcpStream::connect(&addr).await?;
|
||||
let host = addr.ip().to_string();
|
||||
let port = addr.port();
|
||||
tunnel(tcp, host, port, None).await
|
||||
};
|
||||
|
||||
rt.block_on(work).unwrap_err();
|
||||
rt.block_on(f).unwrap_err();
|
||||
}
|
||||
|
||||
#[test]
|
||||
@@ -808,12 +616,14 @@ mod tests {
|
||||
let addr = mock_tunnel!(b"foo bar baz hallo");
|
||||
|
||||
let mut rt = Runtime::new().unwrap();
|
||||
let work = TcpStream::connect(&addr);
|
||||
let host = addr.ip().to_string();
|
||||
let port = addr.port();
|
||||
let work = work.and_then(|tcp| tunnel(tcp, host, port, None));
|
||||
let f = async move {
|
||||
let tcp = TcpStream::connect(&addr).await?;
|
||||
let host = addr.ip().to_string();
|
||||
let port = addr.port();
|
||||
tunnel(tcp, host, port, None).await
|
||||
};
|
||||
|
||||
rt.block_on(work).unwrap_err();
|
||||
rt.block_on(f).unwrap_err();
|
||||
}
|
||||
|
||||
#[test]
|
||||
@@ -827,12 +637,14 @@ mod tests {
|
||||
);
|
||||
|
||||
let mut rt = Runtime::new().unwrap();
|
||||
let work = TcpStream::connect(&addr);
|
||||
let host = addr.ip().to_string();
|
||||
let port = addr.port();
|
||||
let work = work.and_then(|tcp| tunnel(tcp, host, port, None));
|
||||
let f = async move {
|
||||
let tcp = TcpStream::connect(&addr).await?;
|
||||
let host = addr.ip().to_string();
|
||||
let port = addr.port();
|
||||
tunnel(tcp, host, port, None).await
|
||||
};
|
||||
|
||||
let error = rt.block_on(work).unwrap_err();
|
||||
let error = rt.block_on(f).unwrap_err();
|
||||
assert_eq!(error.to_string(), "proxy authentication required");
|
||||
}
|
||||
|
||||
@@ -844,18 +656,19 @@ mod tests {
|
||||
);
|
||||
|
||||
let mut rt = Runtime::new().unwrap();
|
||||
let work = TcpStream::connect(&addr);
|
||||
let host = addr.ip().to_string();
|
||||
let port = addr.port();
|
||||
let work = work.and_then(|tcp| {
|
||||
let f = async move {
|
||||
let tcp = TcpStream::connect(&addr).await?;
|
||||
let host = addr.ip().to_string();
|
||||
let port = addr.port();
|
||||
tunnel(
|
||||
tcp,
|
||||
host,
|
||||
port,
|
||||
Some(proxy::encode_basic_auth("Aladdin", "open sesame")),
|
||||
)
|
||||
});
|
||||
.await
|
||||
};
|
||||
|
||||
rt.block_on(work).unwrap();
|
||||
rt.block_on(f).unwrap();
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user