Update to tokio 1.0, bytes 1.0 (#1076)

Co-authored-by: Wim Looman <git@nemo157.com>
Co-authored-by: Paolo Barbolini <paolo@paolo565.org>
This commit is contained in:
messense
2020-12-31 01:57:50 +08:00
committed by GitHub
parent 5ee4fe5ab6
commit a19eb34196
16 changed files with 173 additions and 219 deletions

View File

@@ -191,7 +191,7 @@ jobs:
strategy:
matrix:
rust: [1.39.0]
rust: [1.45.2]
steps:
- name: Checkout

View File

@@ -28,7 +28,7 @@ default = ["default-tls"]
# Note: this doesn't enable the 'native-tls' feature, which adds specific
# functionality for it.
default-tls = ["hyper-tls", "native-tls-crate", "__tls", "tokio-tls"]
default-tls = ["hyper-tls", "native-tls-crate", "__tls", "tokio-native-tls"]
# Enables native-tls specific functionality not available by default.
native-tls = ["default-tls"]
@@ -39,13 +39,13 @@ rustls-tls-manual-roots = ["__rustls"]
rustls-tls-webpki-roots = ["webpki-roots", "__rustls"]
rustls-tls-native-roots = ["rustls-native-certs", "__rustls"]
blocking = ["futures-util/io", "tokio/rt-threaded", "tokio/rt-core", "tokio/sync"]
blocking = ["futures-util/io", "tokio/rt-multi-thread", "tokio/sync"]
cookies = ["cookie_crate", "cookie_store", "time"]
gzip = ["async-compression", "async-compression/gzip"]
gzip = ["async-compression", "async-compression/gzip", "tokio-util"]
brotli = ["async-compression", "async-compression/brotli"]
brotli = ["async-compression", "async-compression/brotli", "tokio-util"]
json = ["serde_json"]
@@ -71,7 +71,7 @@ __internal_proxy_sys_no_cache = []
[dependencies]
http = "0.2"
url = "2.2"
bytes = "0.5"
bytes = "1.0"
serde = "1.0"
serde_urlencoded = "0.7"
mime_guess = "2.0"
@@ -83,29 +83,29 @@ base64 = "0.13"
encoding_rs = "0.8"
futures-core = { version = "0.3.0", default-features = false }
futures-util = { version = "0.3.0", default-features = false }
http-body = "0.3.0"
hyper = { version = "0.13.4", default-features = false, features = ["tcp"] }
http-body = "0.4.0"
hyper = { version = "0.14", default-features = false, features = ["tcp", "http1", "http2", "client"] }
lazy_static = "1.4"
log = "0.4"
mime = "0.3.7"
percent-encoding = "2.1"
tokio = { version = "0.2.5", default-features = false, features = ["tcp", "time"] }
tokio = { version = "1.0", default-features = false, features = ["net", "time"] }
pin-project-lite = "0.2.0"
ipnet = "2.3"
# Optional deps...
## default-tls
hyper-tls = { version = "0.4", optional = true }
hyper-tls = { version = "0.5", optional = true }
native-tls-crate = { version = "0.2", optional = true, package = "native-tls" }
tokio-tls = { version = "0.3.0", optional = true }
tokio-native-tls = { version = "0.3.0", optional = true }
# rustls-tls
hyper-rustls = { version = "0.21", default-features = false, optional = true }
rustls = { version = "0.18", features = ["dangerous_configuration"], optional = true }
tokio-rustls = { version = "0.14", optional = true }
webpki-roots = { version = "0.20", optional = true }
rustls-native-certs = { version = "0.4", optional = true }
hyper-rustls = { version = "0.22.1", default-features = false, optional = true }
rustls = { version = "0.19", features = ["dangerous_configuration"], optional = true }
tokio-rustls = { version = "0.22", optional = true }
webpki-roots = { version = "0.21", optional = true }
rustls-native-certs = { version = "0.5", optional = true }
## cookies
cookie_crate = { version = "0.14", package = "cookie", optional = true }
@@ -113,23 +113,23 @@ cookie_store = { version = "0.12", optional = true }
time = { version = "0.2.11", optional = true }
## compression
async-compression = { version = "0.3.0", default-features = false, features = ["stream"], optional = true }
async-compression = { version = "0.3.7", default-features = false, features = ["tokio"], optional = true }
tokio-util = { version = "0.6.0", default-features = false, features = ["codec", "io"], optional = true }
## socks
tokio-socks = { version = "0.3", optional = true }
tokio-socks = { version = "0.5", optional = true }
## trust-dns
trust-dns-resolver = { version = "0.19", optional = true }
trust-dns-resolver = { version = "0.20", optional = true }
[target.'cfg(not(target_arch = "wasm32"))'.dev-dependencies]
env_logger = "0.7"
hyper = { version = "0.13", default-features = false, features = ["tcp", "stream"] }
env_logger = "0.8"
hyper = { version = "0.14", default-features = false, features = ["tcp", "stream", "http1", "http2", "client", "server"] }
serde = { version = "1.0", features = ["derive"] }
libflate = "1.0"
brotli_crate = { package = "brotli", version = "3.3.0" }
doc-comment = "0.3"
tokio = { version = "0.2.0", default-features = false, features = ["macros"] }
tokio = { version = "1.0", default-features = false, features = ["macros", "rt-multi-thread"] }
[target.'cfg(windows)'.dependencies]
winreg = "0.7"

View File

@@ -7,7 +7,7 @@ use bytes::Bytes;
use futures_core::Stream;
use http_body::Body as HttpBody;
use pin_project_lite::pin_project;
use tokio::time::Delay;
use tokio::time::Sleep;
/// An asynchronous request body.
pub struct Body {
@@ -27,7 +27,7 @@ enum Inner {
+ Sync,
>,
>,
timeout: Option<Delay>,
timeout: Option<Pin<Box<Sleep>>>,
},
}
@@ -103,7 +103,7 @@ impl Body {
}
}
pub(crate) fn response(body: hyper::Body, timeout: Option<Delay>) -> Body {
pub(crate) fn response(body: hyper::Body, timeout: Option<Pin<Box<Sleep>>>) -> Body {
Body {
inner: Inner::Streaming {
body: Box::pin(WrapHyper(body)),
@@ -217,7 +217,7 @@ impl HttpBody for ImplStream {
ref mut timeout,
} => {
if let Some(ref mut timeout) = timeout {
if let Poll::Ready(()) = Pin::new(timeout).poll(cx) {
if let Poll::Ready(()) = timeout.as_mut().poll(cx) {
return Poll::Ready(Some(Err(crate::error::body(crate::error::TimedOut))));
}
}

View File

@@ -26,7 +26,7 @@ use rustls::RootCertStore;
use std::future::Future;
use std::pin::Pin;
use std::task::{Context, Poll};
use tokio::time::Delay;
use tokio::time::Sleep;
use pin_project_lite::pin_project;
use log::debug;
@@ -96,7 +96,6 @@ struct Config {
#[cfg(feature = "__tls")]
tls: TlsBackend,
http2_only: bool,
http1_writev: Option<bool>,
http1_title_case_headers: bool,
http2_initial_stream_window_size: Option<u32>,
http2_initial_connection_window_size: Option<u32>,
@@ -151,7 +150,6 @@ impl ClientBuilder {
#[cfg(feature = "__tls")]
tls: TlsBackend::default(),
http2_only: false,
http1_writev: None,
http1_title_case_headers: false,
http2_initial_stream_window_size: None,
http2_initial_connection_window_size: None,
@@ -316,10 +314,6 @@ impl ClientBuilder {
builder.http2_only(true);
}
if let Some(http1_writev) = config.http1_writev {
builder.http1_writev(http1_writev);
}
if let Some(http2_initial_stream_window_size) = config.http2_initial_stream_window_size {
builder.http2_initial_stream_window_size(http2_initial_stream_window_size);
}
@@ -655,14 +649,6 @@ impl ClientBuilder {
self
}
/// Force hyper to use either queued(if true), or flattened(if false) write strategy
/// This may eliminate unnecessary cloning of buffers for some TLS backends
/// By default hyper will try to guess which strategy to use
pub fn http1_writev(mut self, writev: bool) -> ClientBuilder {
self.config.http1_writev = Some(writev);
self
}
/// Only use HTTP/2.
pub fn http2_prior_knowledge(mut self) -> ClientBuilder {
self.config.http2_only = true;
@@ -1103,7 +1089,8 @@ impl Client {
let timeout = timeout
.or(self.inner.request_timeout)
.map(tokio::time::delay_for);
.map(tokio::time::sleep)
.map(Box::pin);
*req.headers_mut() = headers.clone();
@@ -1317,7 +1304,7 @@ pin_project! {
#[pin]
in_flight: ResponseFuture,
#[pin]
timeout: Option<Delay>,
timeout: Option<Pin<Box<Sleep>>>,
}
}
@@ -1326,7 +1313,7 @@ impl PendingRequest {
self.project().in_flight
}
fn timeout(self: Pin<&mut Self>) -> Pin<&mut Option<Delay>> {
fn timeout(self: Pin<&mut Self>) -> Pin<&mut Option<Pin<Box<Sleep>>>> {
self.project().timeout
}

View File

@@ -4,10 +4,10 @@ use std::pin::Pin;
use std::task::{Context, Poll};
#[cfg(feature = "gzip")]
use async_compression::stream::GzipDecoder;
use async_compression::tokio::bufread::GzipDecoder;
#[cfg(feature = "brotli")]
use async_compression::stream::BrotliDecoder;
use async_compression::tokio::bufread::BrotliDecoder;
use bytes::Bytes;
use futures_core::Stream;
@@ -15,6 +15,11 @@ use futures_util::stream::Peekable;
use http::HeaderMap;
use hyper::body::HttpBody;
#[cfg(any(feature = "gzip", feature = "brotli"))]
use tokio_util::io::StreamReader;
#[cfg(any(feature = "gzip", feature = "brotli"))]
use tokio_util::codec::{BytesCodec, FramedRead};
use super::super::Body;
use crate::error;
@@ -39,11 +44,11 @@ enum Inner {
/// A `Gzip` decoder will uncompress the gzipped response content before returning it.
#[cfg(feature = "gzip")]
Gzip(GzipDecoder<Peekable<IoStream>>),
Gzip(FramedRead<GzipDecoder<StreamReader<Peekable<IoStream>, Bytes>>, BytesCodec>),
/// A `Brotli` decoder will uncompress the brotlied response content before returning it.
#[cfg(feature = "brotli")]
Brotli(BrotliDecoder<Peekable<IoStream>>),
Brotli(FramedRead<BrotliDecoder<StreamReader<Peekable<IoStream>, Bytes>>, BytesCodec>),
/// A decoder that doesn't have a value yet.
#[cfg(any(feature = "brotli", feature = "gzip"))]
@@ -229,7 +234,7 @@ impl Stream for Decoder {
#[cfg(feature = "gzip")]
Inner::Gzip(ref mut decoder) => {
return match futures_core::ready!(Pin::new(decoder).poll_next(cx)) {
Some(Ok(bytes)) => Poll::Ready(Some(Ok(bytes))),
Some(Ok(bytes)) => Poll::Ready(Some(Ok(bytes.freeze()))),
Some(Err(err)) => Poll::Ready(Some(Err(crate::error::decode_io(err)))),
None => Poll::Ready(None),
};
@@ -237,7 +242,7 @@ impl Stream for Decoder {
#[cfg(feature = "brotli")]
Inner::Brotli(ref mut decoder) => {
return match futures_core::ready!(Pin::new(decoder).poll_next(cx)) {
Some(Ok(bytes)) => Poll::Ready(Some(Ok(bytes))),
Some(Ok(bytes)) => Poll::Ready(Some(Ok(bytes.freeze()))),
Some(Err(err)) => Poll::Ready(Some(Err(crate::error::decode_io(err)))),
None => Poll::Ready(None),
};
@@ -302,9 +307,9 @@ impl Future for Pending {
match self.1 {
#[cfg(feature = "brotli")]
DecoderType::Brotli => Poll::Ready(Ok(Inner::Brotli(BrotliDecoder::new(_body)))),
DecoderType::Brotli => Poll::Ready(Ok(Inner::Brotli(FramedRead::new(BrotliDecoder::new(StreamReader::new(_body)), BytesCodec::new())))),
#[cfg(feature = "gzip")]
DecoderType::Gzip => Poll::Ready(Ok(Inner::Gzip(GzipDecoder::new(_body)))),
DecoderType::Gzip => Poll::Ready(Ok(Inner::Gzip(FramedRead::new(GzipDecoder::new(StreamReader::new(_body)), BytesCodec::new())))),
}
}
}

View File

@@ -521,11 +521,7 @@ mod tests {
fn form_empty() {
let form = Form::new();
let mut rt = runtime::Builder::new()
.basic_scheduler()
.enable_all()
.build()
.expect("new rt");
let rt = runtime::Builder::new_current_thread().enable_all().build().expect("new rt");
let body = form.stream().into_stream();
let s = body.map_ok(|try_c| try_c.to_vec()).try_concat();
@@ -572,11 +568,7 @@ mod tests {
--boundary\r\n\
Content-Disposition: form-data; name=\"key3\"; filename=\"filename\"\r\n\r\n\
value3\r\n--boundary--\r\n";
let mut rt = runtime::Builder::new()
.basic_scheduler()
.enable_all()
.build()
.expect("new rt");
let rt = runtime::Builder::new_current_thread().enable_all().build().expect("new rt");
let body = form.stream().into_stream();
let s = body.map(|try_c| try_c.map(|r| r.to_vec())).try_concat();
@@ -603,11 +595,7 @@ mod tests {
\r\n\
value2\r\n\
--boundary--\r\n";
let mut rt = runtime::Builder::new()
.basic_scheduler()
.enable_all()
.build()
.expect("new rt");
let rt = runtime::Builder::new_current_thread().enable_all().build().expect("new rt");
let body = form.stream().into_stream();
let s = body.map(|try_c| try_c.map(|r| r.to_vec())).try_concat();

View File

@@ -1,6 +1,7 @@
use std::borrow::Cow;
use std::fmt;
use std::net::SocketAddr;
use std::pin::Pin;
use bytes::Bytes;
use encoding_rs::{Encoding, UTF_8};
@@ -12,7 +13,7 @@ use mime::Mime;
use serde::de::DeserializeOwned;
#[cfg(feature = "json")]
use serde_json;
use tokio::time::Delay;
use tokio::time::Sleep;
use url::Url;
use super::body::Body;
@@ -37,7 +38,7 @@ impl Response {
res: hyper::Response<hyper::Body>,
url: Url,
accepts: Accepts,
timeout: Option<Delay>,
timeout: Option<Pin<Box<Sleep>>>,
) -> Response {
let (parts, body) = res.into_parts();
let status = parts.status;

View File

@@ -2,10 +2,11 @@ use std::fmt;
use std::fs::File;
use std::future::Future;
use std::io::{self, Cursor, Read};
use std::mem::{self, MaybeUninit};
use std::mem;
use std::ptr;
use bytes::Bytes;
use bytes::buf::UninitSlice;
use crate::async_impl;
@@ -289,14 +290,14 @@ async fn send_future(sender: Sender) -> Result<(), crate::Error> {
if buf.remaining_mut() == 0 {
buf.reserve(8192);
// zero out the reserved memory
let uninit = buf.chunk_mut();
unsafe {
let uninit = mem::transmute::<&mut [MaybeUninit<u8>], &mut [u8]>(buf.bytes_mut());
ptr::write_bytes(uninit.as_mut_ptr(), 0, uninit.len());
}
}
let bytes = unsafe {
mem::transmute::<&mut [MaybeUninit<u8>], &mut [u8]>(buf.bytes_mut())
mem::transmute::<&mut UninitSlice, &mut [u8]>(buf.chunk_mut())
};
match body.read(bytes) {
Ok(0) => {

View File

@@ -764,7 +764,7 @@ impl ClientHandle {
.name("reqwest-internal-sync-runtime".into())
.spawn(move || {
use tokio::runtime;
let mut rt = match runtime::Builder::new().basic_scheduler().enable_all().build().map_err(crate::error::builder) {
let rt = match runtime::Builder::new_current_thread().enable_all().build().map_err(crate::error::builder) {
Err(e) => {
if let Err(e) = spawn_tx.send(Err(e)) {
error!("Failed to communicate runtime creation failure: {:?}", e);

View File

@@ -67,10 +67,9 @@ fn enter() {
// Check we aren't already in a runtime
#[cfg(debug_assertions)]
{
tokio::runtime::Builder::new()
.core_threads(1)
tokio::runtime::Builder::new_current_thread()
.build()
.expect("build shell runtime")
.enter(|| {});
.enter();
}
}

View File

@@ -2,22 +2,21 @@ use hyper::service::Service;
use http::uri::{Scheme, Authority};
use http::Uri;
use hyper::client::connect::{Connected, Connection};
use tokio::io::{AsyncRead, AsyncWrite};
use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
#[cfg(feature = "native-tls-crate")]
use native_tls_crate::{TlsConnector, TlsConnectorBuilder};
#[cfg(feature = "__tls")]
use http::header::HeaderValue;
use futures_util::future::Either;
use bytes::{Buf, BufMut};
use std::future::Future;
use std::io;
use std::io::IoSlice;
use std::net::IpAddr;
use std::pin::Pin;
use std::sync::Arc;
use std::task::{Context, Poll};
use std::time::Duration;
use std::mem::MaybeUninit;
use pin_project_lite::pin_project;
#[cfg(feature = "trust-dns")]
@@ -272,7 +271,7 @@ impl Connector {
.ok_or("no host in url")?
.to_string();
let conn = socks::connect(proxy, dst, dns).await?;
let tls_connector = tokio_tls::TlsConnector::from(tls.clone());
let tls_connector = tokio_native_tls::TlsConnector::from(tls.clone());
let io = tls_connector
.connect(&host, conn)
.await?;
@@ -342,13 +341,13 @@ impl Connector {
http.set_nodelay(true);
}
let tls_connector = tokio_tls::TlsConnector::from(tls.clone());
let tls_connector = tokio_native_tls::TlsConnector::from(tls.clone());
let mut http = hyper_tls::HttpsConnector::from((http, tls_connector));
let io = http.call(dst).await?;
if let hyper_tls::MaybeHttpsStream::Https(stream) = &io {
if !self.nodelay {
stream.get_ref().set_nodelay(false)?;
stream.get_ref().get_ref().get_ref().set_nodelay(false)?;
}
}
@@ -411,7 +410,7 @@ impl Connector {
let host = dst.host().to_owned();
let port = dst.port().map(|p| p.as_u16()).unwrap_or(443);
let http = http.clone();
let tls_connector = tokio_tls::TlsConnector::from(tls.clone());
let tls_connector = tokio_native_tls::TlsConnector::from(tls.clone());
let mut http = hyper_tls::HttpsConnector::from((http, tls_connector));
let conn = http.call(proxy_dst).await?;
log::trace!("tunneling HTTPS over proxy");
@@ -424,7 +423,7 @@ impl Connector {
self.user_agent.clone(),
auth
).await?;
let tls_connector = tokio_tls::TlsConnector::from(tls.clone());
let tls_connector = tokio_native_tls::TlsConnector::from(tls.clone());
let io = tls_connector
.connect(&host.ok_or("no host in url")?, tunneled)
.await?;
@@ -569,30 +568,11 @@ impl AsyncRead for Conn {
fn poll_read(
self: Pin<&mut Self>,
cx: &mut Context,
buf: &mut [u8]
) -> Poll<io::Result<usize>> {
buf: &mut ReadBuf<'_>
) -> Poll<io::Result<()>> {
let this = self.project();
AsyncRead::poll_read(this.inner, cx, buf)
}
unsafe fn prepare_uninitialized_buffer(
&self,
buf: &mut [MaybeUninit<u8>]
) -> bool {
self.inner.prepare_uninitialized_buffer(buf)
}
fn poll_read_buf<B: BufMut>(
self: Pin<&mut Self>,
cx: &mut Context,
buf: &mut B
) -> Poll<io::Result<usize>>
where
Self: Sized
{
let this = self.project();
AsyncRead::poll_read_buf(this.inner, cx, buf)
}
}
impl AsyncWrite for Conn {
@@ -605,6 +585,19 @@ impl AsyncWrite for Conn {
AsyncWrite::poll_write(this.inner, cx, buf)
}
fn poll_write_vectored(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
bufs: &[IoSlice<'_>]
) -> Poll<Result<usize, io::Error>> {
let this = self.project();
AsyncWrite::poll_write_vectored(this.inner, cx, bufs)
}
fn is_write_vectored(&self) -> bool {
self.inner.is_write_vectored()
}
fn poll_flush(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), io::Error>> {
let this = self.project();
AsyncWrite::poll_flush(this.inner, cx)
@@ -617,16 +610,6 @@ impl AsyncWrite for Conn {
let this = self.project();
AsyncWrite::poll_shutdown(this.inner, cx)
}
fn poll_write_buf<B: Buf>(
self: Pin<&mut Self>,
cx: &mut Context,
buf: &mut B
) -> Poll<Result<usize, io::Error>> where
Self: Sized {
let this = self.project();
AsyncWrite::poll_write_buf(this.inner, cx, buf)
}
}
pub(crate) type Connecting =
@@ -715,13 +698,11 @@ fn tunnel_eof() -> BoxError {
#[cfg(feature = "default-tls")]
mod native_tls_conn {
use std::mem::MaybeUninit;
use std::{pin::Pin, task::{Context, Poll}};
use bytes::{Buf, BufMut};
use std::{pin::Pin, task::{Context, Poll}, io::{self, IoSlice}};
use hyper::client::connect::{Connected, Connection};
use pin_project_lite::pin_project;
use tokio::io::{AsyncRead, AsyncWrite};
use tokio_tls::TlsStream;
use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
use tokio_native_tls::TlsStream;
pin_project! {
@@ -732,7 +713,7 @@ mod native_tls_conn {
impl<T: Connection + AsyncRead + AsyncWrite + Unpin> Connection for NativeTlsConn<T> {
fn connected(&self) -> Connected {
self.inner.get_ref().connected()
self.inner.get_ref().get_ref().get_ref().connected()
}
}
@@ -740,30 +721,11 @@ mod native_tls_conn {
fn poll_read(
self: Pin<&mut Self>,
cx: &mut Context,
buf: &mut [u8]
) -> Poll<tokio::io::Result<usize>> {
buf: &mut ReadBuf<'_>
) -> Poll<tokio::io::Result<()>> {
let this = self.project();
AsyncRead::poll_read(this.inner, cx, buf)
}
unsafe fn prepare_uninitialized_buffer(
&self,
buf: &mut [MaybeUninit<u8>]
) -> bool {
self.inner.prepare_uninitialized_buffer(buf)
}
fn poll_read_buf<B: BufMut>(
self: Pin<&mut Self>,
cx: &mut Context,
buf: &mut B
) -> Poll<tokio::io::Result<usize>>
where
Self: Sized
{
let this = self.project();
AsyncRead::poll_read_buf(this.inner, cx, buf)
}
}
impl<T: AsyncRead + AsyncWrite + Unpin> AsyncWrite for NativeTlsConn<T> {
@@ -776,6 +738,19 @@ mod native_tls_conn {
AsyncWrite::poll_write(this.inner, cx, buf)
}
fn poll_write_vectored(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
bufs: &[IoSlice<'_>]
) -> Poll<Result<usize, io::Error>> {
let this = self.project();
AsyncWrite::poll_write_vectored(this.inner, cx, bufs)
}
fn is_write_vectored(&self) -> bool {
self.inner.is_write_vectored()
}
fn poll_flush(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), tokio::io::Error>> {
let this = self.project();
AsyncWrite::poll_flush(this.inner, cx)
@@ -788,28 +763,16 @@ mod native_tls_conn {
let this = self.project();
AsyncWrite::poll_shutdown(this.inner, cx)
}
fn poll_write_buf<B: Buf>(
self: Pin<&mut Self>,
cx: &mut Context,
buf: &mut B
) -> Poll<Result<usize, tokio::io::Error>> where
Self: Sized {
let this = self.project();
AsyncWrite::poll_write_buf(this.inner, cx, buf)
}
}
}
#[cfg(feature = "__rustls")]
mod rustls_tls_conn {
use rustls::Session;
use std::mem::MaybeUninit;
use std::{pin::Pin, task::{Context, Poll}};
use bytes::{Buf, BufMut};
use std::{pin::Pin, task::{Context, Poll}, io::{self, IoSlice}};
use hyper::client::connect::{Connected, Connection};
use pin_project_lite::pin_project;
use tokio::io::{AsyncRead, AsyncWrite};
use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
use tokio_rustls::client::TlsStream;
@@ -833,30 +796,11 @@ mod rustls_tls_conn {
fn poll_read(
self: Pin<&mut Self>,
cx: &mut Context,
buf: &mut [u8]
) -> Poll<tokio::io::Result<usize>> {
buf: &mut ReadBuf<'_>
) -> Poll<tokio::io::Result<()>> {
let this = self.project();
AsyncRead::poll_read(this.inner, cx, buf)
}
unsafe fn prepare_uninitialized_buffer(
&self,
buf: &mut [MaybeUninit<u8>]
) -> bool {
self.inner.prepare_uninitialized_buffer(buf)
}
fn poll_read_buf<B: BufMut>(
self: Pin<&mut Self>,
cx: &mut Context,
buf: &mut B
) -> Poll<tokio::io::Result<usize>>
where
Self: Sized
{
let this = self.project();
AsyncRead::poll_read_buf(this.inner, cx, buf)
}
}
impl<T: AsyncRead + AsyncWrite + Unpin> AsyncWrite for RustlsTlsConn<T> {
@@ -869,6 +813,19 @@ mod rustls_tls_conn {
AsyncWrite::poll_write(this.inner, cx, buf)
}
fn poll_write_vectored(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
bufs: &[IoSlice<'_>]
) -> Poll<Result<usize, io::Error>> {
let this = self.project();
AsyncWrite::poll_write_vectored(this.inner, cx, bufs)
}
fn is_write_vectored(&self) -> bool {
self.inner.is_write_vectored()
}
fn poll_flush(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), tokio::io::Error>> {
let this = self.project();
AsyncWrite::poll_flush(this.inner, cx)
@@ -881,16 +838,6 @@ mod rustls_tls_conn {
let this = self.project();
AsyncWrite::poll_shutdown(this.inner, cx)
}
fn poll_write_buf<B: Buf>(
self: Pin<&mut Self>,
cx: &mut Context,
buf: &mut B
) -> Poll<Result<usize, tokio::io::Error>> where
Self: Sized {
let this = self.project();
AsyncWrite::poll_write_buf(this.inner, cx, buf)
}
}
}
@@ -961,10 +908,11 @@ mod socks {
mod verbose {
use std::fmt;
use std::io::{self, IoSlice};
use std::pin::Pin;
use std::task::{Context, Poll};
use hyper::client::connect::{Connected, Connection};
use tokio::io::{AsyncRead, AsyncWrite};
use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
pub(super) const OFF: Wrapper = Wrapper(false);
@@ -1000,12 +948,12 @@ mod verbose {
fn poll_read(
mut self: Pin<&mut Self>,
cx: &mut Context,
buf: &mut [u8]
) -> Poll<std::io::Result<usize>> {
buf: &mut ReadBuf<'_>
) -> Poll<std::io::Result<()>> {
match Pin::new(&mut self.inner).poll_read(cx, buf) {
Poll::Ready(Ok(n)) => {
log::trace!("{:08x} read: {:?}", self.id, Escape(&buf[..n]));
Poll::Ready(Ok(n))
Poll::Ready(Ok(())) => {
log::trace!("{:08x} read: {:?}", self.id, Escape(buf.filled()));
Poll::Ready(Ok(()))
},
Poll::Ready(Err(e)) => {
Poll::Ready(Err(e))
@@ -1033,6 +981,18 @@ mod verbose {
}
}
fn poll_write_vectored(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
bufs: &[IoSlice<'_>]
) -> Poll<Result<usize, io::Error>> {
Pin::new(&mut self.inner).poll_write_vectored(cx, bufs)
}
fn is_write_vectored(&self) -> bool {
self.inner.is_write_vectored()
}
fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), std::io::Error>> {
Pin::new(&mut self.inner).poll_flush(cx)
}
@@ -1137,7 +1097,7 @@ mod tests {
fn test_tunnel() {
let addr = mock_tunnel!();
let mut rt = runtime::Builder::new().basic_scheduler().enable_all().build().expect("new rt");
let rt = runtime::Builder::new_current_thread().enable_all().build().expect("new rt");
let f = async move {
let tcp = TcpStream::connect(&addr).await?;
let host = addr.ip().to_string();
@@ -1152,7 +1112,7 @@ mod tests {
fn test_tunnel_eof() {
let addr = mock_tunnel!(b"HTTP/1.1 200 OK");
let mut rt = runtime::Builder::new().basic_scheduler().enable_all().build().expect("new rt");
let rt = runtime::Builder::new_current_thread().enable_all().build().expect("new rt");
let f = async move {
let tcp = TcpStream::connect(&addr).await?;
let host = addr.ip().to_string();
@@ -1167,7 +1127,7 @@ mod tests {
fn test_tunnel_non_http_response() {
let addr = mock_tunnel!(b"foo bar baz hallo");
let mut rt = runtime::Builder::new().basic_scheduler().enable_all().build().expect("new rt");
let rt = runtime::Builder::new_current_thread().enable_all().build().expect("new rt");
let f = async move {
let tcp = TcpStream::connect(&addr).await?;
let host = addr.ip().to_string();
@@ -1188,7 +1148,7 @@ mod tests {
"
);
let mut rt = runtime::Builder::new().basic_scheduler().enable_all().build().expect("new rt");
let rt = runtime::Builder::new_current_thread().enable_all().build().expect("new rt");
let f = async move {
let tcp = TcpStream::connect(&addr).await?;
let host = addr.ip().to_string();
@@ -1207,7 +1167,7 @@ mod tests {
"Proxy-Authorization: Basic QWxhZGRpbjpvcGVuIHNlc2FtZQ==\r\n"
);
let mut rt = runtime::Builder::new().basic_scheduler().enable_all().build().expect("new rt");
let rt = runtime::Builder::new_current_thread().enable_all().build().expect("new rt");
let f = async move {
let tcp = TcpStream::connect(&addr).await?;
let host = addr.ip().to_string();

View File

@@ -3,6 +3,7 @@ use std::pin::Pin;
use std::sync::Arc;
use std::task::{self, Poll};
use std::io;
use std::net::SocketAddr;
use hyper::client::connect::dns as hyper_dns;
use hyper::service::Service;
@@ -10,7 +11,7 @@ use tokio::sync::Mutex;
use trust_dns_resolver::{
config::{ResolverConfig, ResolverOpts},
lookup_ip::LookupIpIntoIter,
system_conf, AsyncResolver, TokioConnection, TokioConnectionProvider,
system_conf, AsyncResolver, TokioConnection, TokioConnectionProvider, TokioHandle
};
use crate::error::BoxError;
@@ -26,6 +27,10 @@ pub(crate) struct TrustDnsResolver {
state: Arc<Mutex<State>>,
}
pub(crate) struct SocketAddrs {
iter: LookupIpIntoIter,
}
enum State {
Init,
Ready(SharedResolver),
@@ -47,7 +52,7 @@ impl TrustDnsResolver {
}
impl Service<hyper_dns::Name> for TrustDnsResolver {
type Response = LookupIpIntoIter;
type Response = SocketAddrs;
type Error = BoxError;
type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
@@ -62,7 +67,7 @@ impl Service<hyper_dns::Name> for TrustDnsResolver {
let resolver = match &*lock {
State::Init => {
let resolver = new_resolver(tokio::runtime::Handle::current()).await?;
let resolver = new_resolver().await?;
*lock = State::Ready(resolver.clone());
resolver
},
@@ -74,18 +79,24 @@ impl Service<hyper_dns::Name> for TrustDnsResolver {
drop(lock);
let lookup = resolver.lookup_ip(name.as_str()).await?;
Ok(lookup.into_iter())
Ok(SocketAddrs { iter: lookup.into_iter() })
})
}
}
/// Takes a `Handle` argument as an indicator that it must be called from
/// within the context of a Tokio runtime.
async fn new_resolver(handle: tokio::runtime::Handle) -> Result<SharedResolver, BoxError> {
impl Iterator for SocketAddrs {
type Item = SocketAddr;
fn next(&mut self) -> Option<Self::Item> {
self.iter.next().map(|ip_addr| SocketAddr::new(ip_addr, 0))
}
}
async fn new_resolver() -> Result<SharedResolver, BoxError> {
let (config, opts) = SYSTEM_CONF
.as_ref()
.expect("can't construct TrustDnsResolver if SYSTEM_CONF is error")
.clone();
let resolver = AsyncResolver::new(config, opts, handle).await?;
let resolver = AsyncResolver::new(config, opts, TokioHandle)?;
Ok(Arc::new(resolver))
}

View File

@@ -282,7 +282,9 @@ fn test_blocking_inside_a_runtime() {
let url = format!("http://{}/text", server.addr());
let mut rt = tokio::runtime::Builder::new().build().expect("new rt");
let rt = tokio::runtime::Builder::new_current_thread()
.build()
.expect("new rt");
rt.block_on(async move {
let _should_panic = reqwest::blocking::get(&url);

View File

@@ -155,14 +155,15 @@ fn test_redirect_307_does_not_try_if_reader_cannot_reset() {
async fn test_redirect_removes_sensitive_headers() {
use tokio::sync::watch;
let (tx, rx) = watch::channel(None);
let (tx, rx) = watch::channel::<Option<std::net::SocketAddr>>(None);
let end_server = server::http(move |req| {
let mut rx = rx.clone();
async move {
assert_eq!(req.headers().get("cookie"), None);
let mid_addr = rx.recv().await.unwrap().unwrap();
rx.changed().await.unwrap();
let mid_addr = rx.borrow().unwrap();
assert_eq!(
req.headers()["referer"],
format!("http://{}/sensitive", mid_addr)
@@ -182,7 +183,7 @@ async fn test_redirect_removes_sensitive_headers() {
.unwrap()
});
tx.broadcast(Some(mid_server.addr())).unwrap();
tx.send(Some(mid_server.addr())).unwrap();
reqwest::Client::builder()
.build()

View File

@@ -44,8 +44,7 @@ where
{
//Spawn new runtime in thread to prevent reactor execution context conflict
thread::spawn(move || {
let mut rt = runtime::Builder::new()
.basic_scheduler()
let rt = runtime::Builder::new_current_thread()
.enable_all()
.build()
.expect("new rt");

View File

@@ -11,7 +11,7 @@ async fn client_timeout() {
let server = server::http(move |_req| {
async {
// delay returning the response
tokio::time::delay_for(Duration::from_secs(2)).await;
tokio::time::sleep(Duration::from_secs(2)).await;
http::Response::default()
}
});
@@ -38,7 +38,7 @@ async fn request_timeout() {
let server = server::http(move |_req| {
async {
// delay returning the response
tokio::time::delay_for(Duration::from_secs(2)).await;
tokio::time::sleep(Duration::from_secs(2)).await;
http::Response::default()
}
});
@@ -94,7 +94,7 @@ async fn response_timeout() {
async {
// immediate response, but delayed body
let body = hyper::Body::wrap_stream(futures_util::stream::once(async {
tokio::time::delay_for(Duration::from_secs(2)).await;
tokio::time::sleep(Duration::from_secs(2)).await;
Ok::<_, std::convert::Infallible>("Hello")
}));
@@ -134,7 +134,7 @@ fn timeout_closes_connection() {
let server = server::http(move |_req| {
async {
// delay returning the response
tokio::time::delay_for(Duration::from_secs(2)).await;
tokio::time::sleep(Duration::from_secs(2)).await;
http::Response::default()
}
});
@@ -158,7 +158,7 @@ fn timeout_blocking_request() {
let server = server::http(move |_req| {
async {
// delay returning the response
tokio::time::delay_for(Duration::from_secs(2)).await;
tokio::time::sleep(Duration::from_secs(2)).await;
http::Response::default()
}
});
@@ -191,7 +191,7 @@ fn write_timeout_large_body() {
let server = server::http(move |_req| {
async {
// delay returning the response
tokio::time::delay_for(Duration::from_secs(2)).await;
tokio::time::sleep(Duration::from_secs(2)).await;
http::Response::default()
}
});