Re-enable rustls (#747)
This commit is contained in:
129
src/connect.rs
129
src/connect.rs
@@ -25,6 +25,8 @@ use crate::proxy::{Proxy, ProxyScheme};
|
||||
use crate::error::BoxError;
|
||||
#[cfg(feature = "default-tls")]
|
||||
use self::native_tls_conn::NativeTlsConn;
|
||||
#[cfg(feature = "rustls-tls")]
|
||||
use self::rustls_tls_conn::RustlsTlsConn;
|
||||
|
||||
//#[cfg(feature = "trust-dns")]
|
||||
//type HttpConnector = hyper::client::HttpConnector<TrustDnsResolver>;
|
||||
@@ -244,12 +246,13 @@ impl Connector {
|
||||
// Disable Nagle's algorithm for TLS handshake
|
||||
//
|
||||
// https://www.openssl.org/docs/man1.1.1/man3/SSL_connect.html#NOTES
|
||||
http.set_nodelay(no_delay || (dst.scheme() == Some(&Scheme::HTTPS)));
|
||||
http.set_nodelay(self.nodelay || (dst.scheme() == Some(&Scheme::HTTPS)));
|
||||
|
||||
let mut http = hyper_rustls::HttpsConnector::from((http, tls.clone()));
|
||||
let io = http.call(dst).await?;
|
||||
|
||||
let http = hyper_rustls::HttpsConnector::from((http, tls.clone()));
|
||||
let io = http.connect(dst).await?;
|
||||
if let hyper_rustls::MaybeHttpsStream::Https(stream) = &io {
|
||||
if !no_delay {
|
||||
if !self.nodelay {
|
||||
let (io, _) = stream.get_ref();
|
||||
io.set_nodelay(false)?;
|
||||
}
|
||||
@@ -320,35 +323,32 @@ impl Connector {
|
||||
tls_proxy,
|
||||
} => {
|
||||
if dst.scheme() == Some(&Scheme::HTTPS) {
|
||||
use rustls::Session;
|
||||
use tokio_rustls::webpki::DNSNameRef;
|
||||
use tokio_rustls::TlsConnector as RustlsConnector;
|
||||
|
||||
let host = dst.host().to_owned();
|
||||
let port = dst.port().unwrap_or(443);
|
||||
let host = dst.host()
|
||||
.ok_or(io::Error::new(io::ErrorKind::Other, "no host in url"))?
|
||||
.to_string();
|
||||
let port = dst.port().map(|r| r.as_u16()).unwrap_or(443);
|
||||
let mut http = http.clone();
|
||||
http.set_nodelay(no_delay);
|
||||
let http = hyper_rustls::HttpsConnector::from((http, tls_proxy.clone()));
|
||||
http.set_nodelay(self.nodelay);
|
||||
let mut http = hyper_rustls::HttpsConnector::from((http, tls_proxy.clone()));
|
||||
let tls = tls.clone();
|
||||
let (conn, connected) = http.connect(proxy_dst).await?;
|
||||
let conn = http.call(proxy_dst).await?;
|
||||
log::trace!("tunneling HTTPS over proxy");
|
||||
let maybe_dnsname = DNSNameRef::try_from_ascii_str(&host)
|
||||
.map(|dnsname| dnsname.to_owned())
|
||||
.map_err(|_| io::Error::new(io::ErrorKind::Other, "Invalid DNS Name"));
|
||||
let tunneled = tunnel(conn, host, port, auth).await?;
|
||||
let tunneled = tunnel(conn, host, port, self.user_agent.clone(), auth).await?;
|
||||
let dnsname = maybe_dnsname?;
|
||||
let io = RustlsConnector::from(tls)
|
||||
.connect(dnsname.as_ref(), tunneled)
|
||||
.await
|
||||
.map_err(|e| io::Error::new(io::ErrorKind::Other, e))?;
|
||||
let connected = if io.get_ref().1.get_alpn_protocol() == Some(b"h2") {
|
||||
connected.negotiated_h2()
|
||||
} else {
|
||||
connected
|
||||
};
|
||||
|
||||
return Ok(Conn {
|
||||
inner: Box::new(io),
|
||||
connected: Connected::new(),
|
||||
inner: Box::new(RustlsTlsConn { inner: io }),
|
||||
is_proxy: false,
|
||||
});
|
||||
}
|
||||
}
|
||||
@@ -682,6 +682,99 @@ mod native_tls_conn {
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(feature = "rustls-tls")]
|
||||
mod rustls_tls_conn {
|
||||
use rustls::Session;
|
||||
use std::mem::MaybeUninit;
|
||||
use std::{pin::Pin, task::{Context, Poll}};
|
||||
use bytes::{Buf, BufMut};
|
||||
use hyper::client::connect::{Connected, Connection};
|
||||
use pin_project_lite::pin_project;
|
||||
use tokio::io::{AsyncRead, AsyncWrite};
|
||||
use tokio_rustls::client::TlsStream;
|
||||
|
||||
|
||||
pin_project! {
|
||||
pub(super) struct RustlsTlsConn<T> {
|
||||
#[pin] pub(super) inner: TlsStream<T>,
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: Connection + AsyncRead + AsyncWrite + Unpin> Connection for RustlsTlsConn<T> {
|
||||
fn connected(&self) -> Connected {
|
||||
if self.inner.get_ref().1.get_alpn_protocol() == Some(b"h2") {
|
||||
self.inner.get_ref().0.connected().negotiated_h2()
|
||||
} else {
|
||||
self.inner.get_ref().0.connected()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: AsyncRead + AsyncWrite + Unpin> AsyncRead for RustlsTlsConn<T> {
|
||||
fn poll_read(
|
||||
self: Pin<&mut Self>,
|
||||
cx: &mut Context,
|
||||
buf: &mut [u8]
|
||||
) -> Poll<tokio::io::Result<usize>> {
|
||||
let this = self.project();
|
||||
AsyncRead::poll_read(this.inner, cx, buf)
|
||||
}
|
||||
|
||||
unsafe fn prepare_uninitialized_buffer(
|
||||
&self,
|
||||
buf: &mut [MaybeUninit<u8>]
|
||||
) -> bool {
|
||||
self.inner.prepare_uninitialized_buffer(buf)
|
||||
}
|
||||
|
||||
fn poll_read_buf<B: BufMut>(
|
||||
self: Pin<&mut Self>,
|
||||
cx: &mut Context,
|
||||
buf: &mut B
|
||||
) -> Poll<tokio::io::Result<usize>>
|
||||
where
|
||||
Self: Sized
|
||||
{
|
||||
let this = self.project();
|
||||
AsyncRead::poll_read_buf(this.inner, cx, buf)
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: AsyncRead + AsyncWrite + Unpin> AsyncWrite for RustlsTlsConn<T> {
|
||||
fn poll_write(
|
||||
self: Pin<&mut Self>,
|
||||
cx: &mut Context,
|
||||
buf: &[u8]
|
||||
) -> Poll<Result<usize, tokio::io::Error>> {
|
||||
let this = self.project();
|
||||
AsyncWrite::poll_write(this.inner, cx, buf)
|
||||
}
|
||||
|
||||
fn poll_flush(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), tokio::io::Error>> {
|
||||
let this = self.project();
|
||||
AsyncWrite::poll_flush(this.inner, cx)
|
||||
}
|
||||
|
||||
fn poll_shutdown(
|
||||
self: Pin<&mut Self>,
|
||||
cx: &mut Context
|
||||
) -> Poll<Result<(), tokio::io::Error>> {
|
||||
let this = self.project();
|
||||
AsyncWrite::poll_shutdown(this.inner, cx)
|
||||
}
|
||||
|
||||
fn poll_write_buf<B: Buf>(
|
||||
self: Pin<&mut Self>,
|
||||
cx: &mut Context,
|
||||
buf: &mut B
|
||||
) -> Poll<Result<usize, tokio::io::Error>> where
|
||||
Self: Sized {
|
||||
let this = self.project();
|
||||
AsyncWrite::poll_write_buf(this.inner, cx, buf)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(feature = "socks")]
|
||||
mod socks {
|
||||
use std::io;
|
||||
|
||||
Reference in New Issue
Block a user