Update to tokio 1.0, bytes 1.0 (#1076)

Co-authored-by: Wim Looman <git@nemo157.com>
Co-authored-by: Paolo Barbolini <paolo@paolo565.org>
This commit is contained in:
messense
2020-12-31 01:57:50 +08:00
committed by GitHub
parent 5ee4fe5ab6
commit a19eb34196
16 changed files with 173 additions and 219 deletions

View File

@@ -2,22 +2,21 @@ use hyper::service::Service;
use http::uri::{Scheme, Authority};
use http::Uri;
use hyper::client::connect::{Connected, Connection};
use tokio::io::{AsyncRead, AsyncWrite};
use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
#[cfg(feature = "native-tls-crate")]
use native_tls_crate::{TlsConnector, TlsConnectorBuilder};
#[cfg(feature = "__tls")]
use http::header::HeaderValue;
use futures_util::future::Either;
use bytes::{Buf, BufMut};
use std::future::Future;
use std::io;
use std::io::IoSlice;
use std::net::IpAddr;
use std::pin::Pin;
use std::sync::Arc;
use std::task::{Context, Poll};
use std::time::Duration;
use std::mem::MaybeUninit;
use pin_project_lite::pin_project;
#[cfg(feature = "trust-dns")]
@@ -272,7 +271,7 @@ impl Connector {
.ok_or("no host in url")?
.to_string();
let conn = socks::connect(proxy, dst, dns).await?;
let tls_connector = tokio_tls::TlsConnector::from(tls.clone());
let tls_connector = tokio_native_tls::TlsConnector::from(tls.clone());
let io = tls_connector
.connect(&host, conn)
.await?;
@@ -342,13 +341,13 @@ impl Connector {
http.set_nodelay(true);
}
let tls_connector = tokio_tls::TlsConnector::from(tls.clone());
let tls_connector = tokio_native_tls::TlsConnector::from(tls.clone());
let mut http = hyper_tls::HttpsConnector::from((http, tls_connector));
let io = http.call(dst).await?;
if let hyper_tls::MaybeHttpsStream::Https(stream) = &io {
if !self.nodelay {
stream.get_ref().set_nodelay(false)?;
stream.get_ref().get_ref().get_ref().set_nodelay(false)?;
}
}
@@ -411,7 +410,7 @@ impl Connector {
let host = dst.host().to_owned();
let port = dst.port().map(|p| p.as_u16()).unwrap_or(443);
let http = http.clone();
let tls_connector = tokio_tls::TlsConnector::from(tls.clone());
let tls_connector = tokio_native_tls::TlsConnector::from(tls.clone());
let mut http = hyper_tls::HttpsConnector::from((http, tls_connector));
let conn = http.call(proxy_dst).await?;
log::trace!("tunneling HTTPS over proxy");
@@ -424,7 +423,7 @@ impl Connector {
self.user_agent.clone(),
auth
).await?;
let tls_connector = tokio_tls::TlsConnector::from(tls.clone());
let tls_connector = tokio_native_tls::TlsConnector::from(tls.clone());
let io = tls_connector
.connect(&host.ok_or("no host in url")?, tunneled)
.await?;
@@ -569,30 +568,11 @@ impl AsyncRead for Conn {
fn poll_read(
self: Pin<&mut Self>,
cx: &mut Context,
buf: &mut [u8]
) -> Poll<io::Result<usize>> {
buf: &mut ReadBuf<'_>
) -> Poll<io::Result<()>> {
let this = self.project();
AsyncRead::poll_read(this.inner, cx, buf)
}
unsafe fn prepare_uninitialized_buffer(
&self,
buf: &mut [MaybeUninit<u8>]
) -> bool {
self.inner.prepare_uninitialized_buffer(buf)
}
fn poll_read_buf<B: BufMut>(
self: Pin<&mut Self>,
cx: &mut Context,
buf: &mut B
) -> Poll<io::Result<usize>>
where
Self: Sized
{
let this = self.project();
AsyncRead::poll_read_buf(this.inner, cx, buf)
}
}
impl AsyncWrite for Conn {
@@ -605,6 +585,19 @@ impl AsyncWrite for Conn {
AsyncWrite::poll_write(this.inner, cx, buf)
}
fn poll_write_vectored(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
bufs: &[IoSlice<'_>]
) -> Poll<Result<usize, io::Error>> {
let this = self.project();
AsyncWrite::poll_write_vectored(this.inner, cx, bufs)
}
fn is_write_vectored(&self) -> bool {
self.inner.is_write_vectored()
}
fn poll_flush(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), io::Error>> {
let this = self.project();
AsyncWrite::poll_flush(this.inner, cx)
@@ -617,16 +610,6 @@ impl AsyncWrite for Conn {
let this = self.project();
AsyncWrite::poll_shutdown(this.inner, cx)
}
fn poll_write_buf<B: Buf>(
self: Pin<&mut Self>,
cx: &mut Context,
buf: &mut B
) -> Poll<Result<usize, io::Error>> where
Self: Sized {
let this = self.project();
AsyncWrite::poll_write_buf(this.inner, cx, buf)
}
}
pub(crate) type Connecting =
@@ -715,13 +698,11 @@ fn tunnel_eof() -> BoxError {
#[cfg(feature = "default-tls")]
mod native_tls_conn {
use std::mem::MaybeUninit;
use std::{pin::Pin, task::{Context, Poll}};
use bytes::{Buf, BufMut};
use std::{pin::Pin, task::{Context, Poll}, io::{self, IoSlice}};
use hyper::client::connect::{Connected, Connection};
use pin_project_lite::pin_project;
use tokio::io::{AsyncRead, AsyncWrite};
use tokio_tls::TlsStream;
use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
use tokio_native_tls::TlsStream;
pin_project! {
@@ -732,7 +713,7 @@ mod native_tls_conn {
impl<T: Connection + AsyncRead + AsyncWrite + Unpin> Connection for NativeTlsConn<T> {
fn connected(&self) -> Connected {
self.inner.get_ref().connected()
self.inner.get_ref().get_ref().get_ref().connected()
}
}
@@ -740,30 +721,11 @@ mod native_tls_conn {
fn poll_read(
self: Pin<&mut Self>,
cx: &mut Context,
buf: &mut [u8]
) -> Poll<tokio::io::Result<usize>> {
buf: &mut ReadBuf<'_>
) -> Poll<tokio::io::Result<()>> {
let this = self.project();
AsyncRead::poll_read(this.inner, cx, buf)
}
unsafe fn prepare_uninitialized_buffer(
&self,
buf: &mut [MaybeUninit<u8>]
) -> bool {
self.inner.prepare_uninitialized_buffer(buf)
}
fn poll_read_buf<B: BufMut>(
self: Pin<&mut Self>,
cx: &mut Context,
buf: &mut B
) -> Poll<tokio::io::Result<usize>>
where
Self: Sized
{
let this = self.project();
AsyncRead::poll_read_buf(this.inner, cx, buf)
}
}
impl<T: AsyncRead + AsyncWrite + Unpin> AsyncWrite for NativeTlsConn<T> {
@@ -776,6 +738,19 @@ mod native_tls_conn {
AsyncWrite::poll_write(this.inner, cx, buf)
}
fn poll_write_vectored(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
bufs: &[IoSlice<'_>]
) -> Poll<Result<usize, io::Error>> {
let this = self.project();
AsyncWrite::poll_write_vectored(this.inner, cx, bufs)
}
fn is_write_vectored(&self) -> bool {
self.inner.is_write_vectored()
}
fn poll_flush(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), tokio::io::Error>> {
let this = self.project();
AsyncWrite::poll_flush(this.inner, cx)
@@ -788,28 +763,16 @@ mod native_tls_conn {
let this = self.project();
AsyncWrite::poll_shutdown(this.inner, cx)
}
fn poll_write_buf<B: Buf>(
self: Pin<&mut Self>,
cx: &mut Context,
buf: &mut B
) -> Poll<Result<usize, tokio::io::Error>> where
Self: Sized {
let this = self.project();
AsyncWrite::poll_write_buf(this.inner, cx, buf)
}
}
}
#[cfg(feature = "__rustls")]
mod rustls_tls_conn {
use rustls::Session;
use std::mem::MaybeUninit;
use std::{pin::Pin, task::{Context, Poll}};
use bytes::{Buf, BufMut};
use std::{pin::Pin, task::{Context, Poll}, io::{self, IoSlice}};
use hyper::client::connect::{Connected, Connection};
use pin_project_lite::pin_project;
use tokio::io::{AsyncRead, AsyncWrite};
use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
use tokio_rustls::client::TlsStream;
@@ -833,30 +796,11 @@ mod rustls_tls_conn {
fn poll_read(
self: Pin<&mut Self>,
cx: &mut Context,
buf: &mut [u8]
) -> Poll<tokio::io::Result<usize>> {
buf: &mut ReadBuf<'_>
) -> Poll<tokio::io::Result<()>> {
let this = self.project();
AsyncRead::poll_read(this.inner, cx, buf)
}
unsafe fn prepare_uninitialized_buffer(
&self,
buf: &mut [MaybeUninit<u8>]
) -> bool {
self.inner.prepare_uninitialized_buffer(buf)
}
fn poll_read_buf<B: BufMut>(
self: Pin<&mut Self>,
cx: &mut Context,
buf: &mut B
) -> Poll<tokio::io::Result<usize>>
where
Self: Sized
{
let this = self.project();
AsyncRead::poll_read_buf(this.inner, cx, buf)
}
}
impl<T: AsyncRead + AsyncWrite + Unpin> AsyncWrite for RustlsTlsConn<T> {
@@ -869,6 +813,19 @@ mod rustls_tls_conn {
AsyncWrite::poll_write(this.inner, cx, buf)
}
fn poll_write_vectored(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
bufs: &[IoSlice<'_>]
) -> Poll<Result<usize, io::Error>> {
let this = self.project();
AsyncWrite::poll_write_vectored(this.inner, cx, bufs)
}
fn is_write_vectored(&self) -> bool {
self.inner.is_write_vectored()
}
fn poll_flush(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), tokio::io::Error>> {
let this = self.project();
AsyncWrite::poll_flush(this.inner, cx)
@@ -881,16 +838,6 @@ mod rustls_tls_conn {
let this = self.project();
AsyncWrite::poll_shutdown(this.inner, cx)
}
fn poll_write_buf<B: Buf>(
self: Pin<&mut Self>,
cx: &mut Context,
buf: &mut B
) -> Poll<Result<usize, tokio::io::Error>> where
Self: Sized {
let this = self.project();
AsyncWrite::poll_write_buf(this.inner, cx, buf)
}
}
}
@@ -961,10 +908,11 @@ mod socks {
mod verbose {
use std::fmt;
use std::io::{self, IoSlice};
use std::pin::Pin;
use std::task::{Context, Poll};
use hyper::client::connect::{Connected, Connection};
use tokio::io::{AsyncRead, AsyncWrite};
use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
pub(super) const OFF: Wrapper = Wrapper(false);
@@ -1000,12 +948,12 @@ mod verbose {
fn poll_read(
mut self: Pin<&mut Self>,
cx: &mut Context,
buf: &mut [u8]
) -> Poll<std::io::Result<usize>> {
buf: &mut ReadBuf<'_>
) -> Poll<std::io::Result<()>> {
match Pin::new(&mut self.inner).poll_read(cx, buf) {
Poll::Ready(Ok(n)) => {
log::trace!("{:08x} read: {:?}", self.id, Escape(&buf[..n]));
Poll::Ready(Ok(n))
Poll::Ready(Ok(())) => {
log::trace!("{:08x} read: {:?}", self.id, Escape(buf.filled()));
Poll::Ready(Ok(()))
},
Poll::Ready(Err(e)) => {
Poll::Ready(Err(e))
@@ -1033,6 +981,18 @@ mod verbose {
}
}
fn poll_write_vectored(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
bufs: &[IoSlice<'_>]
) -> Poll<Result<usize, io::Error>> {
Pin::new(&mut self.inner).poll_write_vectored(cx, bufs)
}
fn is_write_vectored(&self) -> bool {
self.inner.is_write_vectored()
}
fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), std::io::Error>> {
Pin::new(&mut self.inner).poll_flush(cx)
}
@@ -1137,7 +1097,7 @@ mod tests {
fn test_tunnel() {
let addr = mock_tunnel!();
let mut rt = runtime::Builder::new().basic_scheduler().enable_all().build().expect("new rt");
let rt = runtime::Builder::new_current_thread().enable_all().build().expect("new rt");
let f = async move {
let tcp = TcpStream::connect(&addr).await?;
let host = addr.ip().to_string();
@@ -1152,7 +1112,7 @@ mod tests {
fn test_tunnel_eof() {
let addr = mock_tunnel!(b"HTTP/1.1 200 OK");
let mut rt = runtime::Builder::new().basic_scheduler().enable_all().build().expect("new rt");
let rt = runtime::Builder::new_current_thread().enable_all().build().expect("new rt");
let f = async move {
let tcp = TcpStream::connect(&addr).await?;
let host = addr.ip().to_string();
@@ -1167,7 +1127,7 @@ mod tests {
fn test_tunnel_non_http_response() {
let addr = mock_tunnel!(b"foo bar baz hallo");
let mut rt = runtime::Builder::new().basic_scheduler().enable_all().build().expect("new rt");
let rt = runtime::Builder::new_current_thread().enable_all().build().expect("new rt");
let f = async move {
let tcp = TcpStream::connect(&addr).await?;
let host = addr.ip().to_string();
@@ -1188,7 +1148,7 @@ mod tests {
"
);
let mut rt = runtime::Builder::new().basic_scheduler().enable_all().build().expect("new rt");
let rt = runtime::Builder::new_current_thread().enable_all().build().expect("new rt");
let f = async move {
let tcp = TcpStream::connect(&addr).await?;
let host = addr.ip().to_string();
@@ -1207,7 +1167,7 @@ mod tests {
"Proxy-Authorization: Basic QWxhZGRpbjpvcGVuIHNlc2FtZQ==\r\n"
);
let mut rt = runtime::Builder::new().basic_scheduler().enable_all().build().expect("new rt");
let rt = runtime::Builder::new_current_thread().enable_all().build().expect("new rt");
let f = async move {
let tcp = TcpStream::connect(&addr).await?;
let host = addr.ip().to_string();