diff --git a/src/async_impl/client.rs b/src/async_impl/client.rs index 629d9ae..a830af7 100644 --- a/src/async_impl/client.rs +++ b/src/async_impl/client.rs @@ -137,6 +137,11 @@ impl ClientBuilder { let proxies = Arc::new(config.proxies); let mut connector = { + #[cfg(feature = "tls")] + fn user_agent(headers: &HeaderMap) -> HeaderValue { + headers[USER_AGENT].clone() + } + #[cfg(feature = "tls")] match config.tls { #[cfg(feature = "default-tls")] @@ -156,6 +161,7 @@ impl ClientBuilder { Connector::new_default_tls( tls, proxies.clone(), + user_agent(&config.headers), config.local_address, config.nodelay, )? @@ -189,6 +195,7 @@ impl ClientBuilder { Connector::new_rustls_tls( tls, proxies.clone(), + user_agent(&config.headers), config.local_address, config.nodelay, )? diff --git a/src/connect.rs b/src/connect.rs index 950a0b2..de439f2 100644 --- a/src/connect.rs +++ b/src/connect.rs @@ -5,6 +5,8 @@ use tokio::io::{AsyncRead, AsyncWrite}; #[cfg(feature = "default-tls")] use native_tls::{TlsConnector, TlsConnectorBuilder}; +#[cfg(feature = "tls")] +use http::header::HeaderValue; use std::future::Future; use std::io; @@ -23,12 +25,15 @@ use tokio::future::FutureExt as _; //#[cfg(not(feature = "trust-dns"))] type HttpConnector = hyper::client::HttpConnector; +#[derive(Clone)] pub(crate) struct Connector { inner: Inner, proxies: Arc>, timeout: Option, #[cfg(feature = "tls")] nodelay: bool, + #[cfg(feature = "tls")] + user_agent: HeaderValue, } #[derive(Clone)] @@ -69,6 +74,7 @@ impl Connector { pub(crate) fn new_default_tls( tls: TlsConnectorBuilder, proxies: Arc>, + user_agent: HeaderValue, local_addr: T, nodelay: bool, ) -> crate::Result @@ -86,6 +92,7 @@ impl Connector { proxies, timeout: None, nodelay, + user_agent, }) } @@ -93,6 +100,7 @@ impl Connector { pub(crate) fn new_rustls_tls( tls: rustls::ClientConfig, proxies: Arc>, + user_agent: HeaderValue, local_addr: T, nodelay: bool, ) -> crate::Result @@ -121,6 +129,7 @@ impl Connector { proxies, timeout: None, nodelay, + user_agent, }) } @@ -186,6 +195,149 @@ impl Connector { Inner::Http(_) => socks::connect(proxy, dst, dns), } } + + async fn connect_with_maybe_proxy( + self, + dst: Destination, + is_proxy: bool, + ) -> Result<(Conn, Connected), io::Error> { + match self.inner { + #[cfg(not(feature = "tls"))] + Inner::Http(http) => { + let (io, connected) = http.connect(dst).await?; + Ok((Box::new(io) as Conn, connected.proxy(is_proxy))) + } + #[cfg(feature = "default-tls")] + Inner::DefaultTls(http, tls) => { + let mut http = http.clone(); + + http.set_nodelay(self.nodelay || (dst.scheme() == "https")); + + let tls_connector = tokio_tls::TlsConnector::from(tls.clone()); + let http = hyper_tls::HttpsConnector::from((http, tls_connector)); + let (io, connected) = http.connect(dst).await?; + //TODO: where's this at now? + //if let hyper_tls::MaybeHttpsStream::Https(_stream) = &io { + // if !no_delay { + // stream.set_nodelay(false)?; + // } + //} + + Ok((Box::new(io) as Conn, connected.proxy(is_proxy))) + } + #[cfg(feature = "rustls-tls")] + Inner::RustlsTls { http, tls, .. } => { + let mut http = http.clone(); + + // Disable Nagle's algorithm for TLS handshake + // + // https://www.openssl.org/docs/man1.1.1/man3/SSL_connect.html#NOTES + http.set_nodelay(no_delay || (dst.scheme() == "https")); + + let http = hyper_rustls::HttpsConnector::from((http, tls.clone())); + let (io, connected) = http.connect(dst).await?; + if let hyper_rustls::MaybeHttpsStream::Https(stream) = &io { + if !no_delay { + let (io, _) = stream.get_ref(); + io.set_nodelay(false)?; + } + } + + Ok((Box::new(io) as Conn, connected.proxy(is_proxy))) + } + } + } + + async fn connect_via_proxy( + self, + dst: Destination, + proxy_scheme: ProxyScheme, + ) -> Result<(Conn, Connected), io::Error> { + log::trace!("proxy({:?}) intercepts {:?}", proxy_scheme, dst); + + let (puri, _auth) = match proxy_scheme { + ProxyScheme::Http { uri, auth, .. } => (uri, auth), + #[cfg(feature = "socks")] + ProxyScheme::Socks5 { .. } => return this.connect_socks(dst, proxy_scheme), + }; + + let mut ndst = dst.clone(); + + let new_scheme = puri.scheme_part().map(Scheme::as_str).unwrap_or("http"); + ndst.set_scheme(new_scheme) + .expect("proxy target scheme should be valid"); + + ndst.set_host(puri.host().expect("proxy target should have host")) + .expect("proxy target host should be valid"); + + ndst.set_port(puri.port_part().map(|port| port.as_u16())); + + #[cfg(feature = "tls")] + let auth = _auth; + + match &self.inner { + #[cfg(feature = "default-tls")] + Inner::DefaultTls(http, tls) => { + if dst.scheme() == "https" { + let host = dst.host().to_owned(); + let port = dst.port().unwrap_or(443); + let mut http = http.clone(); + http.set_nodelay(self.nodelay); + let tls_connector = tokio_tls::TlsConnector::from(tls.clone()); + let http = hyper_tls::HttpsConnector::from((http, tls_connector)); + let (conn, connected) = http.connect(ndst).await?; + log::trace!("tunneling HTTPS over proxy"); + let tunneled = tunnel(conn, host.clone(), port, self.user_agent.clone(), auth).await?; + let tls_connector = tokio_tls::TlsConnector::from(tls.clone()); + let io = tls_connector + .connect(&host, tunneled) + .await + .map_err(|e| io::Error::new(io::ErrorKind::Other, e))?; + return Ok((Box::new(io) as Conn, connected.proxy(true))); + } + } + #[cfg(feature = "rustls-tls")] + Inner::RustlsTls { + http, + tls, + tls_proxy, + } => { + if dst.scheme() == "https" { + use rustls::Session; + use tokio_rustls::webpki::DNSNameRef; + use tokio_rustls::TlsConnector as RustlsConnector; + + let host = dst.host().to_owned(); + let port = dst.port().unwrap_or(443); + let mut http = http.clone(); + http.set_nodelay(no_delay); + let http = hyper_rustls::HttpsConnector::from((http, tls_proxy.clone())); + let tls = tls.clone(); + let (conn, connected) = http.connect(ndst).await?; + log::trace!("tunneling HTTPS over proxy"); + let maybe_dnsname = DNSNameRef::try_from_ascii_str(&host) + .map(|dnsname| dnsname.to_owned()) + .map_err(|_| io::Error::new(io::ErrorKind::Other, "Invalid DNS Name")); + let tunneled = tunnel(conn, host, port, auth).await?; + let dnsname = maybe_dnsname?; + let io = RustlsConnector::from(tls) + .connect(dnsname.as_ref(), tunneled) + .await + .map_err(|e| io::Error::new(io::ErrorKind::Other, e))?; + let connected = if io.get_ref().1.get_alpn_protocol() == Some(b"h2") { + connected.negotiated_h2() + } else { + connected + }; + return Ok((Box::new(io) as Conn, connected.proxy(true))); + } + } + #[cfg(not(feature = "tls"))] + Inner::Http(_) => (), + } + + self.connect_with_maybe_proxy(ndst, true).await + } } //#[cfg(feature = "trust-dns")] @@ -200,151 +352,6 @@ fn http_connector() -> crate::Result { Ok(HttpConnector::new()) } -async fn connect_with_maybe_proxy( - inner: Inner, - dst: Destination, - is_proxy: bool, - no_delay: bool, -) -> Result<(Conn, Connected), io::Error> { - match inner { - #[cfg(not(feature = "tls"))] - Inner::Http(http) => { - drop(no_delay); // only used for TLS? - let (io, connected) = http.connect(dst).await?; - Ok((Box::new(io) as Conn, connected.proxy(is_proxy))) - } - #[cfg(feature = "default-tls")] - Inner::DefaultTls(http, tls) => { - let mut http = http.clone(); - - http.set_nodelay(no_delay || (dst.scheme() == "https")); - - let tls_connector = tokio_tls::TlsConnector::from(tls.clone()); - let http = hyper_tls::HttpsConnector::from((http, tls_connector)); - let (io, connected) = http.connect(dst).await?; - //TODO: where's this at now? - //if let hyper_tls::MaybeHttpsStream::Https(_stream) = &io { - // if !no_delay { - // stream.set_nodelay(false)?; - // } - //} - - Ok((Box::new(io) as Conn, connected.proxy(is_proxy))) - } - #[cfg(feature = "rustls-tls")] - Inner::RustlsTls { http, tls, .. } => { - let mut http = http.clone(); - - // Disable Nagle's algorithm for TLS handshake - // - // https://www.openssl.org/docs/man1.1.1/man3/SSL_connect.html#NOTES - http.set_nodelay(no_delay || (dst.scheme() == "https")); - - let http = hyper_rustls::HttpsConnector::from((http, tls.clone())); - let (io, connected) = http.connect(dst).await?; - if let hyper_rustls::MaybeHttpsStream::Https(stream) = &io { - if !no_delay { - let (io, _) = stream.get_ref(); - io.set_nodelay(false)?; - } - } - - Ok((Box::new(io) as Conn, connected.proxy(is_proxy))) - } - } -} - -async fn connect_via_proxy( - inner: Inner, - dst: Destination, - proxy_scheme: ProxyScheme, - no_delay: bool, -) -> Result<(Conn, Connected), io::Error> { - log::trace!("proxy({:?}) intercepts {:?}", proxy_scheme, dst); - - let (puri, _auth) = match proxy_scheme { - ProxyScheme::Http { uri, auth, .. } => (uri, auth), - #[cfg(feature = "socks")] - ProxyScheme::Socks5 { .. } => return this.connect_socks(dst, proxy_scheme), - }; - - let mut ndst = dst.clone(); - - let new_scheme = puri.scheme_part().map(Scheme::as_str).unwrap_or("http"); - ndst.set_scheme(new_scheme) - .expect("proxy target scheme should be valid"); - - ndst.set_host(puri.host().expect("proxy target should have host")) - .expect("proxy target host should be valid"); - - ndst.set_port(puri.port_part().map(|port| port.as_u16())); - - #[cfg(feature = "tls")] - let auth = _auth; - - match &inner { - #[cfg(feature = "default-tls")] - Inner::DefaultTls(http, tls) => { - if dst.scheme() == "https" { - let host = dst.host().to_owned(); - let port = dst.port().unwrap_or(443); - let mut http = http.clone(); - http.set_nodelay(no_delay); - let tls_connector = tokio_tls::TlsConnector::from(tls.clone()); - let http = hyper_tls::HttpsConnector::from((http, tls_connector)); - let (conn, connected) = http.connect(ndst).await?; - log::trace!("tunneling HTTPS over proxy"); - let tunneled = tunnel(conn, host.clone(), port, auth).await?; - let tls_connector = tokio_tls::TlsConnector::from(tls.clone()); - let io = tls_connector - .connect(&host, tunneled) - .await - .map_err(|e| io::Error::new(io::ErrorKind::Other, e))?; - return Ok((Box::new(io) as Conn, connected.proxy(true))); - } - } - #[cfg(feature = "rustls-tls")] - Inner::RustlsTls { - http, - tls, - tls_proxy, - } => { - if dst.scheme() == "https" { - use rustls::Session; - use tokio_rustls::webpki::DNSNameRef; - use tokio_rustls::TlsConnector as RustlsConnector; - - let host = dst.host().to_owned(); - let port = dst.port().unwrap_or(443); - let mut http = http.clone(); - http.set_nodelay(no_delay); - let http = hyper_rustls::HttpsConnector::from((http, tls_proxy.clone())); - let tls = tls.clone(); - let (conn, connected) = http.connect(ndst).await?; - log::trace!("tunneling HTTPS over proxy"); - let maybe_dnsname = DNSNameRef::try_from_ascii_str(&host) - .map(|dnsname| dnsname.to_owned()) - .map_err(|_| io::Error::new(io::ErrorKind::Other, "Invalid DNS Name")); - let tunneled = tunnel(conn, host, port, auth).await?; - let dnsname = maybe_dnsname?; - let io = RustlsConnector::from(tls) - .connect(dnsname.as_ref(), tunneled) - .await - .map_err(|e| io::Error::new(io::ErrorKind::Other, e))?; - let connected = if io.get_ref().1.get_alpn_protocol() == Some(b"h2") { - connected.negotiated_h2() - } else { - connected - }; - return Ok((Box::new(io) as Conn, connected.proxy(true))); - } - } - #[cfg(not(feature = "tls"))] - Inner::Http(_) => (), - } - - connect_with_maybe_proxy(inner, ndst, true, no_delay).await -} async fn with_timeout(f: F, timeout: Option) -> Result where @@ -366,15 +373,11 @@ impl Connect for Connector { type Future = Connecting; fn connect(&self, dst: Destination) -> Self::Future { - #[cfg(feature = "tls")] - let no_delay = self.nodelay; - #[cfg(not(feature = "tls"))] - let no_delay = false; let timeout = self.timeout; for prox in self.proxies.iter() { if let Some(proxy_scheme) = prox.intercept(&dst) { return with_timeout( - connect_via_proxy(self.inner.clone(), dst, proxy_scheme, no_delay), + self.clone().connect_via_proxy(dst, proxy_scheme), timeout, ) .boxed(); @@ -382,7 +385,7 @@ impl Connect for Connector { } with_timeout( - connect_with_maybe_proxy(self.inner.clone(), dst, false, no_delay), + self.clone().connect_with_maybe_proxy(dst, false), timeout, ) .boxed() @@ -401,7 +404,8 @@ async fn tunnel( mut conn: T, host: String, port: u16, - auth: Option, + user_agent: HeaderValue, + auth: Option, ) -> Result where T: AsyncRead + AsyncWrite + Unpin, @@ -417,6 +421,14 @@ where ) .into_bytes(); + + // user-agent + buf.extend_from_slice(b"User-Agent: "); + buf.extend_from_slice(user_agent.as_bytes()); + buf.extend_from_slice(b"\r\n"); + + + // proxy-authorization if let Some(value) = auth { log::debug!("tunnel to {}:{} using basic auth", host, port); buf.extend_from_slice(b"Proxy-Authorization: "); @@ -541,6 +553,7 @@ mod tests { use tokio::net::tcp::TcpStream; use tokio::runtime::current_thread::Runtime; + static TUNNEL_UA: &'static str = "tunnel-test/x.y"; static TUNNEL_OK: &[u8] = b"\ HTTP/1.1 200 OK\r\n\ \r\n\ @@ -560,11 +573,13 @@ mod tests { "\ CONNECT {0}:{1} HTTP/1.1\r\n\ Host: {0}:{1}\r\n\ - {2}\ + User-Agent: {2}\r\n\ + {3}\ \r\n\ ", addr.ip(), addr.port(), + TUNNEL_UA, $auth ) .into_bytes(); @@ -581,6 +596,10 @@ mod tests { }}; } + fn ua() -> http::header::HeaderValue { + http::header::HeaderValue::from_static(TUNNEL_UA) + } + #[test] fn test_tunnel() { let addr = mock_tunnel!(); @@ -590,7 +609,7 @@ mod tests { let tcp = TcpStream::connect(&addr).await?; let host = addr.ip().to_string(); let port = addr.port(); - tunnel(tcp, host, port, None).await + tunnel(tcp, host, port, ua(), None).await }; rt.block_on(f).unwrap(); @@ -605,7 +624,7 @@ mod tests { let tcp = TcpStream::connect(&addr).await?; let host = addr.ip().to_string(); let port = addr.port(); - tunnel(tcp, host, port, None).await + tunnel(tcp, host, port, ua(), None).await }; rt.block_on(f).unwrap_err(); @@ -620,7 +639,7 @@ mod tests { let tcp = TcpStream::connect(&addr).await?; let host = addr.ip().to_string(); let port = addr.port(); - tunnel(tcp, host, port, None).await + tunnel(tcp, host, port, ua(), None).await }; rt.block_on(f).unwrap_err(); @@ -641,7 +660,7 @@ mod tests { let tcp = TcpStream::connect(&addr).await?; let host = addr.ip().to_string(); let port = addr.port(); - tunnel(tcp, host, port, None).await + tunnel(tcp, host, port, ua(), None).await }; let error = rt.block_on(f).unwrap_err(); @@ -664,6 +683,7 @@ mod tests { tcp, host, port, + ua(), Some(proxy::encode_basic_auth("Aladdin", "open sesame")), ) .await