Re-enable rustls (#747)
This commit is contained in:
		
							
								
								
									
										129
									
								
								src/connect.rs
									
									
									
									
									
								
							
							
						
						
									
										129
									
								
								src/connect.rs
									
									
									
									
									
								
							| @@ -25,6 +25,8 @@ use crate::proxy::{Proxy, ProxyScheme}; | ||||
| use crate::error::BoxError; | ||||
| #[cfg(feature = "default-tls")] | ||||
| use self::native_tls_conn::NativeTlsConn; | ||||
| #[cfg(feature = "rustls-tls")] | ||||
| use self::rustls_tls_conn::RustlsTlsConn; | ||||
|  | ||||
| //#[cfg(feature = "trust-dns")] | ||||
| //type HttpConnector = hyper::client::HttpConnector<TrustDnsResolver>; | ||||
| @@ -244,12 +246,13 @@ impl Connector { | ||||
|                 // Disable Nagle's algorithm for TLS handshake | ||||
|                 // | ||||
|                 // https://www.openssl.org/docs/man1.1.1/man3/SSL_connect.html#NOTES | ||||
|                 http.set_nodelay(no_delay || (dst.scheme() == Some(&Scheme::HTTPS))); | ||||
|                 http.set_nodelay(self.nodelay || (dst.scheme() == Some(&Scheme::HTTPS))); | ||||
|  | ||||
|                 let mut http = hyper_rustls::HttpsConnector::from((http, tls.clone())); | ||||
|                 let io = http.call(dst).await?; | ||||
|  | ||||
|                 let http = hyper_rustls::HttpsConnector::from((http, tls.clone())); | ||||
|                 let io = http.connect(dst).await?; | ||||
|                 if let hyper_rustls::MaybeHttpsStream::Https(stream) = &io { | ||||
|                     if !no_delay { | ||||
|                     if !self.nodelay { | ||||
|                         let (io, _) = stream.get_ref(); | ||||
|                         io.set_nodelay(false)?; | ||||
|                     } | ||||
| @@ -320,35 +323,32 @@ impl Connector { | ||||
|                 tls_proxy, | ||||
|             } => { | ||||
|                 if dst.scheme() == Some(&Scheme::HTTPS) { | ||||
|                     use rustls::Session; | ||||
|                     use tokio_rustls::webpki::DNSNameRef; | ||||
|                     use tokio_rustls::TlsConnector as RustlsConnector; | ||||
|  | ||||
|                     let host = dst.host().to_owned(); | ||||
|                     let port = dst.port().unwrap_or(443); | ||||
|                     let host = dst.host() | ||||
|                         .ok_or(io::Error::new(io::ErrorKind::Other, "no host in url"))? | ||||
|                         .to_string(); | ||||
|                     let port = dst.port().map(|r| r.as_u16()).unwrap_or(443); | ||||
|                     let mut http = http.clone(); | ||||
|                     http.set_nodelay(no_delay); | ||||
|                     let http = hyper_rustls::HttpsConnector::from((http, tls_proxy.clone())); | ||||
|                     http.set_nodelay(self.nodelay); | ||||
|                     let mut http = hyper_rustls::HttpsConnector::from((http, tls_proxy.clone())); | ||||
|                     let tls = tls.clone(); | ||||
|                     let (conn, connected) = http.connect(proxy_dst).await?; | ||||
|                     let conn = http.call(proxy_dst).await?; | ||||
|                     log::trace!("tunneling HTTPS over proxy"); | ||||
|                     let maybe_dnsname = DNSNameRef::try_from_ascii_str(&host) | ||||
|                         .map(|dnsname| dnsname.to_owned()) | ||||
|                         .map_err(|_| io::Error::new(io::ErrorKind::Other, "Invalid DNS Name")); | ||||
|                     let tunneled = tunnel(conn, host, port, auth).await?; | ||||
|                     let tunneled = tunnel(conn, host, port, self.user_agent.clone(), auth).await?; | ||||
|                     let dnsname = maybe_dnsname?; | ||||
|                     let io = RustlsConnector::from(tls) | ||||
|                         .connect(dnsname.as_ref(), tunneled) | ||||
|                         .await | ||||
|                         .map_err(|e| io::Error::new(io::ErrorKind::Other, e))?; | ||||
|                     let connected = if io.get_ref().1.get_alpn_protocol() == Some(b"h2") { | ||||
|                         connected.negotiated_h2() | ||||
|                     } else { | ||||
|                         connected | ||||
|                     }; | ||||
|  | ||||
|                     return Ok(Conn { | ||||
|                         inner: Box::new(io), | ||||
|                         connected: Connected::new(), | ||||
|                         inner: Box::new(RustlsTlsConn { inner: io }), | ||||
|                         is_proxy: false, | ||||
|                     }); | ||||
|                 } | ||||
|             } | ||||
| @@ -682,6 +682,99 @@ mod native_tls_conn { | ||||
|     } | ||||
| } | ||||
|  | ||||
| #[cfg(feature = "rustls-tls")] | ||||
| mod rustls_tls_conn { | ||||
|     use rustls::Session; | ||||
|     use std::mem::MaybeUninit; | ||||
|     use std::{pin::Pin, task::{Context, Poll}}; | ||||
|     use bytes::{Buf, BufMut}; | ||||
|     use hyper::client::connect::{Connected, Connection}; | ||||
|     use pin_project_lite::pin_project; | ||||
|     use tokio::io::{AsyncRead, AsyncWrite}; | ||||
|     use tokio_rustls::client::TlsStream; | ||||
|  | ||||
|  | ||||
|     pin_project! { | ||||
|         pub(super) struct RustlsTlsConn<T> { | ||||
|             #[pin] pub(super) inner: TlsStream<T>, | ||||
|         } | ||||
|     } | ||||
|  | ||||
|     impl<T: Connection + AsyncRead + AsyncWrite + Unpin> Connection for RustlsTlsConn<T> { | ||||
|         fn connected(&self) -> Connected { | ||||
|             if self.inner.get_ref().1.get_alpn_protocol() == Some(b"h2") { | ||||
|                 self.inner.get_ref().0.connected().negotiated_h2() | ||||
|             } else { | ||||
|                 self.inner.get_ref().0.connected() | ||||
|             } | ||||
|         } | ||||
|     } | ||||
|  | ||||
|     impl<T: AsyncRead + AsyncWrite + Unpin> AsyncRead for RustlsTlsConn<T> { | ||||
|         fn poll_read( | ||||
|             self: Pin<&mut Self>, | ||||
|             cx: &mut Context, | ||||
|             buf: &mut [u8] | ||||
|         ) -> Poll<tokio::io::Result<usize>> { | ||||
|             let this = self.project(); | ||||
|             AsyncRead::poll_read(this.inner, cx, buf) | ||||
|         } | ||||
|  | ||||
|         unsafe fn prepare_uninitialized_buffer( | ||||
|             &self, | ||||
|             buf: &mut [MaybeUninit<u8>] | ||||
|         ) -> bool { | ||||
|             self.inner.prepare_uninitialized_buffer(buf) | ||||
|         } | ||||
|  | ||||
|         fn poll_read_buf<B: BufMut>( | ||||
|             self: Pin<&mut Self>, | ||||
|             cx: &mut Context, | ||||
|             buf: &mut B | ||||
|         ) -> Poll<tokio::io::Result<usize>> | ||||
|             where | ||||
|                 Self: Sized | ||||
|         { | ||||
|             let this = self.project(); | ||||
|             AsyncRead::poll_read_buf(this.inner, cx, buf) | ||||
|         } | ||||
|     } | ||||
|  | ||||
|     impl<T: AsyncRead + AsyncWrite + Unpin> AsyncWrite for RustlsTlsConn<T> { | ||||
|         fn poll_write( | ||||
|             self: Pin<&mut Self>, | ||||
|             cx: &mut Context, | ||||
|             buf: &[u8] | ||||
|         ) -> Poll<Result<usize, tokio::io::Error>> { | ||||
|             let this = self.project(); | ||||
|             AsyncWrite::poll_write(this.inner, cx, buf) | ||||
|         } | ||||
|  | ||||
|         fn poll_flush(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), tokio::io::Error>> { | ||||
|             let this = self.project(); | ||||
|             AsyncWrite::poll_flush(this.inner, cx) | ||||
|         } | ||||
|  | ||||
|         fn poll_shutdown( | ||||
|             self: Pin<&mut Self>, | ||||
|             cx: &mut Context | ||||
|         ) -> Poll<Result<(), tokio::io::Error>> { | ||||
|             let this = self.project(); | ||||
|             AsyncWrite::poll_shutdown(this.inner, cx) | ||||
|         } | ||||
|  | ||||
|         fn poll_write_buf<B: Buf>( | ||||
|             self: Pin<&mut Self>, | ||||
|             cx: &mut Context, | ||||
|             buf: &mut B | ||||
|         ) -> Poll<Result<usize, tokio::io::Error>> where | ||||
|             Self: Sized { | ||||
|             let this = self.project(); | ||||
|             AsyncWrite::poll_write_buf(this.inner, cx, buf) | ||||
|         } | ||||
|     } | ||||
| } | ||||
|  | ||||
| #[cfg(feature = "socks")] | ||||
| mod socks { | ||||
|     use std::io; | ||||
|   | ||||
		Reference in New Issue
	
	Block a user