From ba7b2a754eab0d79817ea8551d0803806ae8af7d Mon Sep 17 00:00:00 2001 From: Sean McArthur Date: Fri, 6 Sep 2019 17:22:56 -0700 Subject: [PATCH] refactor all to async/await (#617) Co-authored-by: Danny Browning Co-authored-by: Daniel Eades --- .appveyor.yml | 2 +- .travis.yml | 56 +- Cargo.toml | 60 +- examples/async.rs | 35 +- examples/async_multiple_requests.rs | 37 +- examples/async_stream.rs | 74 --- examples/json_dynamic.rs | 3 +- examples/simple.rs | 5 +- examples_disabled/async_stream.rs | 101 +++ src/async_impl/body.rs | 87 ++- src/async_impl/client.rs | 90 ++- src/async_impl/decoder.rs | 252 ++----- src/{async_impl.rs => async_impl/mod.rs} | 2 +- src/async_impl/multipart.rs | 46 +- src/async_impl/request.rs | 28 +- src/async_impl/response.rs | 82 ++- src/body.rs | 140 ++-- src/client.rs | 167 ++--- src/connect.rs | 805 +++++++++-------------- src/dns.rs | 2 +- src/error.rs | 61 +- src/into_url.rs | 2 +- src/lib.rs | 9 +- src/multipart.rs | 3 +- src/response.rs | 56 +- src/tls.rs | 18 +- src/wait.rs | 132 ++-- tests/async.rs | 161 ++--- tests/client.rs | 14 +- tests/timeouts.rs | 6 +- 30 files changed, 1106 insertions(+), 1430 deletions(-) delete mode 100644 examples/async_stream.rs create mode 100644 examples_disabled/async_stream.rs rename src/{async_impl.rs => async_impl/mod.rs} (85%) diff --git a/.appveyor.yml b/.appveyor.yml index 9d0f3dd..a20323b 100644 --- a/.appveyor.yml +++ b/.appveyor.yml @@ -8,7 +8,7 @@ environment: MINGW_PATH: 'C:\MinGW\bin' install: - curl -sSf -o rustup-init.exe https://win.rustup.rs/ - - rustup-init.exe -y --default-host %TARGET% + - rustup-init.exe -y --default-toolchain nightly --default-host %TARGET% - set PATH=%PATH%;C:\Users\appveyor\.cargo\bin - if defined MINGW_PATH set PATH=%PATH%;%MINGW_PATH% - rustc -vV diff --git a/.travis.yml b/.travis.yml index 86dfb51..b6b87e1 100644 --- a/.travis.yml +++ b/.travis.yml @@ -1,40 +1,44 @@ language: rust matrix: fast_finish: true - allow_failures: - - rust: nightly + #allow_failures: + # - rust: nightly include: - os: osx - rust: stable + rust: nightly + #rust: stable - - rust: stable - - rust: beta + #- rust: stable + #- rust: beta - rust: nightly # Disable default-tls - - rust: stable + #- rust: stable + - rust: nightly env: FEATURES="--no-default-features" # rustls-tls - - rust: stable - env: FEATURES="--no-default-features --features rustls-tls" + #- rust: stable + #- rust: nightly + # env: FEATURES="--no-default-features --features rustls-tls" # default-tls and rustls-tls - - rust: stable - env: FEATURES="--features rustls-tls" + #- rust: stable + #- rust: nightly + # env: FEATURES="--features rustls-tls" - # default-tls, rustls, and socks! 
- - rust: stable - env: FEATURES="--features rustls-tls,socks" + # socks + #- rust: stable + #- rust: nightly + # env: FEATURES="--features socks" - - rust: stable - env: FEATURES="--features hyper-011" - - - rust: stable - env: FEATURES="--features trust-dns" + #- rust: stable + #- rust: nightly + # env: FEATURES="--features trust-dns" # android - - rust: stable + #- rust: stable + - rust: nightly env: TARGET=aarch64-linux-android before_install: - wget https://dl.google.com/android/repository/android-ndk-r19c-linux-x86_64.zip; @@ -45,9 +49,16 @@ matrix: # disable default-tls feature since cross-compiling openssl is dragons script: cargo build --target "$TARGET" --no-default-features + # Check rustfmt + - name: "rustfmt check" + rust: stable + install: rustup component add rustfmt + script: cargo fmt -- --check + + # minimum version - - rust: 1.34.0 - script: cargo build + #- rust: 1.39.0 + # script: cargo build sudo: false dist: trusty @@ -55,9 +66,6 @@ dist: trusty env: global: - REQWEST_TEST_BODY_FULL=1 -before_script: - - rustup component add rustfmt script: - - cargo fmt -- --check - cargo build $FEATURES - cargo test -v $FEATURES -- --test-threads=1 diff --git a/Cargo.toml b/Cargo.toml index a3aa598..f714435 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -20,62 +20,66 @@ all-features = true base64 = "0.10" bytes = "0.4" encoding_rs = "0.8" -futures = "0.1.23" +futures-preview = { version = "=0.3.0-alpha.18" } http = "0.1.15" -hyper = "0.12.22" -flate2 = { version = "^1.0.7", default-features = false, features = ["rust_backend"] } +hyper = "=0.13.0-alpha.1" log = "0.4" mime = "0.3.7" mime_guess = "2.0" percent-encoding = "2.1" +tokio = { version = "=0.2.0-alpha.4", default-features = false, features = ["rt-full", "tcp"] } +tokio-executor = "=0.2.0-alpha.4" +url = "2.1" +uuid = { version = "0.7", features = ["v4"] } +time = "0.1.42" + +# TODO: candidates for optional features + +async-compression = { version = "0.1.0-alpha.4", default-features = false, features = ["gzip", "stream"] } +cookie_store = "0.9.0" +cookie = "0.12.0" serde = "1.0" serde_json = "1.0" serde_urlencoded = "0.6.1" -tokio = { version = "0.1.7", default-features = false, features = ["rt-full", "tcp"] } -tokio-executor = "0.1.4" # a minimum version so trust-dns-resolver compiles -tokio-io = "0.1" -tokio-threadpool = "0.1.8" # a minimum version so tokio compiles -tokio-timer = "0.2.6" # a minimum version so trust-dns-resolver compiles -url = "2.1" -uuid = { version = "0.7", features = ["v4"] } # Optional deps... 
-hyper-old-types = { version = "0.11", optional = true, features = ["compat"] } -hyper-rustls = { version = "^0.17.1", optional = true } -hyper-tls = { version = "0.3.2", optional = true } +## default-tls +hyper-tls = { version = "=0.4.0-alpha.1", optional = true } native-tls = { version = "0.2", optional = true } -rustls = { version = "0.16", features = ["dangerous_configuration"], optional = true } -socks = { version = "0.3.2", optional = true } -tokio-rustls = { version = "0.10", optional = true } -trust-dns-resolver = { version = "0.11", optional = true } -webpki-roots = { version = "0.17", optional = true } -cookie_store = "0.9.0" -cookie = "0.12.0" -time = "0.1.42" +tokio-tls = { version = "=0.3.0-alpha.4", optional = true } + +## rustls-tls +#hyper-rustls = { git = "https://github.com/dbcfd/hyper-rustls.git", branch = "master", optional = true } +#rustls = { version = "0.16", features = ["dangerous_configuration"], optional = true } +#tokio-rustls = { version = "=0.12.0-alpha.2", optional = true } +#webpki-roots = { version = "0.17", optional = true } + +## socks +#socks = { version = "0.3.2", optional = true } + +## trust-dns +#trust-dns-resolver = { version = "0.11", optional = true } [dev-dependencies] env_logger = "0.6" serde = { version = "1.0", features = ["derive"] } -tokio = { version = "0.1.7", default-features = false, features = ["rt-full", "tcp", "fs"] } -tokio-tcp = "0.1" libflate = "0.1" doc-comment = "0.3" bytes = "0.4" +tokio-fs = { version = "=0.2.0-alpha.4" } [features] default = ["default-tls"] tls = [] -default-tls = ["hyper-tls", "native-tls", "tls"] +default-tls = ["hyper-tls", "native-tls", "tls", "tokio-tls"] default-tls-vendored = ["default-tls", "native-tls/vendored"] -rustls-tls = ["hyper-rustls", "tokio-rustls", "webpki-roots", "rustls", "tls"] +#rustls-tls = ["hyper-rustls", "tokio-rustls", "webpki-roots", "rustls", "tls"] -trust-dns = ["trust-dns-resolver"] - -hyper-011 = ["hyper-old-types"] +#trust-dns = ["trust-dns-resolver"] [target.'cfg(windows)'.dependencies] winreg = "0.6" diff --git a/examples/async.rs b/examples/async.rs index fbdb771..c4dad41 100644 --- a/examples/async.rs +++ b/examples/async.rs @@ -1,29 +1,16 @@ #![deny(warnings)] -use futures::{Future, Stream}; -use reqwest::r#async::{Client, Decoder}; -use std::io::{self, Cursor}; -use std::mem; +use reqwest::r#async::Client; -fn fetch() -> impl Future { - Client::new() - .get("https://hyper.rs") - .send() - .and_then(|mut res| { - println!("{}", res.status()); +#[tokio::main] +async fn main() -> Result<(), reqwest::Error> { + let mut res = Client::new().get("https://hyper.rs").send().await?; - let body = mem::replace(res.body_mut(), Decoder::empty()); - body.concat2() - }) - .map_err(|err| println!("request error: {}", err)) - .map(|body| { - let mut body = Cursor::new(body); - let _ = io::copy(&mut body, &mut io::stdout()).map_err(|err| { - println!("stdout error: {}", err); - }); - }) -} - -fn main() { - tokio::run(fetch()); + println!("Status: {}", res.status()); + + let body = res.text().await?; + + println!("Body:\n\n{}", body); + + Ok(()) } diff --git a/examples/async_multiple_requests.rs b/examples/async_multiple_requests.rs index 2884023..d52a60f 100644 --- a/examples/async_multiple_requests.rs +++ b/examples/async_multiple_requests.rs @@ -1,8 +1,8 @@ #![deny(warnings)] -use futures::Future; use reqwest::r#async::{Client, Response}; use serde::Deserialize; +use std::future::Future; #[derive(Deserialize, Debug)] struct Slideshow { @@ -15,26 +15,27 @@ struct SlideshowContainer { 
slideshow: Slideshow, } -fn fetch() -> impl Future { +async fn into_json(f: F) -> Result +where + F: Future>, +{ + let mut resp = f.await?; + resp.json::().await +} + +#[tokio::main] +async fn main() -> Result<(), reqwest::Error> { let client = Client::new(); - let json = |mut res: Response| res.json::(); + let request1 = client.get("https://httpbin.org/json").send(); - let request1 = client.get("https://httpbin.org/json").send().and_then(json); + let request2 = client.get("https://httpbin.org/json").send(); - let request2 = client.get("https://httpbin.org/json").send().and_then(json); + let (try_json1, try_json2) = + futures::future::join(into_json(request1), into_json(request2)).await; - request1 - .join(request2) - .map(|(res1, res2)| { - println!("{:?}", res1); - println!("{:?}", res2); - }) - .map_err(|err| { - println!("stdout error: {}", err); - }) -} - -fn main() { - tokio::run(fetch()); + println!("{:?}", try_json1?); + println!("{:?}", try_json2?); + + Ok(()) } diff --git a/examples/async_stream.rs b/examples/async_stream.rs deleted file mode 100644 index 6ebb7b3..0000000 --- a/examples/async_stream.rs +++ /dev/null @@ -1,74 +0,0 @@ -#![deny(warnings)] - -use std::io::{self, Cursor}; -use std::mem; -use std::path::Path; - -use bytes::Bytes; -use futures::{try_ready, Async, Future, Poll, Stream}; -use reqwest::r#async::{Client, Decoder}; -use tokio::fs::File; -use tokio::io::AsyncRead; - -const CHUNK_SIZE: usize = 1024; - -struct FileSource { - inner: File, -} - -impl FileSource { - fn new(file: File) -> FileSource { - FileSource { inner: file } - } -} - -impl Stream for FileSource { - type Item = Bytes; - type Error = io::Error; - - fn poll(&mut self) -> Poll, Self::Error> { - let mut buf = [0; CHUNK_SIZE]; - let size = try_ready!(self.inner.poll_read(&mut buf)); - if size > 0 { - Ok(Async::Ready(Some(buf[0..size].into()))) - } else { - Ok(Async::Ready(None)) - } - } -} - -fn post
<P>
(path: P) -> impl Future -where - P: AsRef, -{ - File::open(path.as_ref().to_owned()) - .map_err(|err| println!("request error: {}", err)) - .and_then(|file| { - let source: Box + Send> = - Box::new(FileSource::new(file)); - - Client::new() - .post("https://httpbin.org/post") - .body(source) - .send() - .and_then(|mut res| { - println!("{}", res.status()); - - let body = mem::replace(res.body_mut(), Decoder::empty()); - body.concat2() - }) - .map_err(|err| println!("request error: {}", err)) - .map(|body| { - let mut body = Cursor::new(body); - let _ = io::copy(&mut body, &mut io::stdout()).map_err(|err| { - println!("stdout error: {}", err); - }); - }) - }) -} - -fn main() { - let pool = tokio_threadpool::ThreadPool::new(); - let path = concat!(env!("CARGO_MANIFEST_DIR"), "/LICENSE-APACHE"); - tokio::run(pool.spawn_handle(post(path))); -} diff --git a/examples/json_dynamic.rs b/examples/json_dynamic.rs index fd5d520..1a31f61 100644 --- a/examples/json_dynamic.rs +++ b/examples/json_dynamic.rs @@ -3,12 +3,11 @@ //! This is useful for some ad-hoc experiments and situations when you don't //! really care about the structure of the JSON and just need to display it or //! process it at runtime. -use serde_json::json; fn main() -> Result<(), reqwest::Error> { let echo_json: serde_json::Value = reqwest::Client::new() .post("https://jsonplaceholder.typicode.com/posts") - .json(&json!({ + .json(&serde_json::json!({ "title": "Reqwest.rs", "body": "https://docs.rs/reqwest", "userId": 1 diff --git a/examples/simple.rs b/examples/simple.rs index 23cd99d..73753da 100644 --- a/examples/simple.rs +++ b/examples/simple.rs @@ -1,6 +1,5 @@ -#![deny(warnings)] - //! `cargo run --example simple` +#![deny(warnings)] fn main() -> Result<(), Box> { env_logger::init(); @@ -13,7 +12,7 @@ fn main() -> Result<(), Box> { println!("Headers:\n{:?}", res.headers()); // copy the response body directly to stdout - std::io::copy(&mut res, &mut std::io::stdout())?; + res.copy_to(&mut std::io::stdout())?; println!("\n\nDone."); Ok(()) diff --git a/examples_disabled/async_stream.rs b/examples_disabled/async_stream.rs new file mode 100644 index 0000000..78d1db6 --- /dev/null +++ b/examples_disabled/async_stream.rs @@ -0,0 +1,101 @@ +#![deny(warnings)] +use std::io::{self, Cursor}; +use std::mem; +use std::path::Path; +use std::pin::Pin; +use std::task::{Context, Poll}; + +use futures::{Stream, TryStreamExt}; +use reqwest::r#async::{Body, Client, Decoder}; +use tokio_fs::File; +use tokio::io::AsyncRead; + +use failure::Fail; + +#[derive(Debug, Fail)] +pub enum Error { + #[fail(display = "Io Error")] + Io(#[fail(cause)] std::io::Error), + #[fail(display = "Reqwest error")] + Reqwest(#[fail(cause)] reqwest::Error), +} + +unsafe impl Send for Error {} +unsafe impl Sync for Error {} + +struct AsyncReadWrapper { + inner: T, +} + +impl AsyncReadWrapper { + fn inner(self: Pin<&mut Self>) -> Pin<&mut T> { + unsafe { + Pin::map_unchecked_mut(self, |x| &mut x.inner) + } + } +} + +impl Stream for AsyncReadWrapper + where T: AsyncRead +{ + type Item = Result>; + fn poll_next( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> Poll> { + let mut buf = vec![]; + loop { + let mut read_buf = vec![]; + match self.as_mut().inner().as_mut().poll_read(cx, &mut read_buf) { + Poll::Pending => { + if buf.is_empty() { + return Poll::Pending; + } else { + return Poll::Ready(Some(Ok(buf.into()))); + } + } + Poll::Ready(Err(e)) => { + return Poll::Ready(Some(Err(Error::Io(e).compat()))) + }, + Poll::Ready(Ok(n)) => { + 
buf.extend_from_slice(&read_buf[..n]); + if buf.is_empty() && n == 0 { + return Poll::Ready(None); + } else { + return Poll::Ready(Some(Ok(buf.into()))); + } + } + } + } + } +} + +async fn post
<P>
(path: P) -> Result<(), Error> +where + P: AsRef + Send + Unpin + 'static, +{ + let source = File::open(path) + .await.map_err(Error::Io)?; + let wrapper = AsyncReadWrapper { inner: source }; + let mut res = Client::new() + .post("https://httpbin.org/post") + .body(Body::wrap_stream(wrapper)) + .send() + .await.map_err(Error::Reqwest)?; + + println!("{}", res.status()); + + let body = mem::replace(res.body_mut(), Decoder::empty()); + let body: Result<_, _> = body.try_concat().await; + + let mut body = Cursor::new(body.map_err(Error::Reqwest)?); + io::copy(&mut body, &mut io::stdout()).map_err(Error::Io)?; + + Ok(()) +} + +#[tokio::main] +async fn main() -> Result<(), Error> { + let path = concat!(env!("CARGO_MANIFEST_DIR"), "/LICENSE-APACHE"); + post(path).await +} \ No newline at end of file diff --git a/src/async_impl/body.rs b/src/async_impl/body.rs index bf5e7e0..dd4ac00 100644 --- a/src/async_impl/body.rs +++ b/src/async_impl/body.rs @@ -1,8 +1,10 @@ -use std::fmt; - use bytes::{Buf, Bytes}; -use futures::{try_ready, Async, Future, Poll, Stream}; +use futures::Stream; use hyper::body::Payload; +use std::fmt; +use std::future::Future; +use std::pin::Pin; +use std::task::{Context, Poll}; use tokio::timer::Delay; /// An asynchronous `Stream`. @@ -22,10 +24,38 @@ impl Body { pub(crate) fn content_length(&self) -> Option { match self.inner { Inner::Reusable(ref bytes) => Some(bytes.len() as u64), - Inner::Hyper { ref body, .. } => body.content_length(), + Inner::Hyper { ref body, .. } => body.size_hint().exact(), } } + /// Wrap a futures `Stream` in a box inside `Body`. + /// + /// # Example + /// + /// ``` + /// # use reqwest::r#async::Body; + /// # use futures; + /// # fn main() { + /// let chunks: Vec> = vec![ + /// Ok("hello"), + /// Ok(" "), + /// Ok("world"), + /// ]; + /// + /// let stream = futures::stream::iter(chunks); + /// + /// let body = Body::wrap_stream(stream); + /// # } + /// ``` + pub fn wrap_stream(stream: S) -> Body + where + S: futures::TryStream + Send + Sync + 'static, + S::Error: Into>, + hyper::Chunk: From, + { + Body::wrap(hyper::body::Body::wrap_stream(stream)) + } + #[inline] pub(crate) fn response(body: hyper::Body, timeout: Option) -> Body { Body { @@ -65,38 +95,45 @@ impl Body { } } } + + fn inner(self: Pin<&mut Self>) -> Pin<&mut Inner> { + unsafe { Pin::map_unchecked_mut(self, |x| &mut x.inner) } + } } impl Stream for Body { - type Item = Chunk; - type Error = crate::Error; + type Item = Result; #[inline] - fn poll(&mut self) -> Poll, Self::Error> { - let opt = match self.inner { + fn poll_next(self: Pin<&mut Self>, cx: &mut Context) -> Poll> { + let opt_try_chunk = match self.inner().get_mut() { Inner::Hyper { ref mut body, ref mut timeout, } => { if let Some(ref mut timeout) = timeout { - if let Async::Ready(()) = try_!(timeout.poll()) { - return Err(crate::error::timedout(None)); + if let Poll::Ready(()) = Pin::new(timeout).poll(cx) { + return Poll::Ready(Some(Err(crate::error::timedout(None)))); } } - try_ready!(body.poll_data().map_err(crate::error::from)) + futures::ready!(Pin::new(body).poll_data(cx)).map(|opt_chunk| { + opt_chunk + .map(|c| Chunk { inner: c }) + .map_err(crate::error::from) + }) } Inner::Reusable(ref mut bytes) => { - return if bytes.is_empty() { - Ok(Async::Ready(None)) + if bytes.is_empty() { + None } else { let chunk = Chunk::from_chunk(bytes.clone()); *bytes = Bytes::new(); - Ok(Async::Ready(Some(chunk))) - }; + Some(Ok(chunk)) + } } }; - Ok(Async::Ready(opt.map(|chunk| Chunk { inner: chunk }))) + Poll::Ready(opt_try_chunk) } } 
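For reference, the `Stream` shape used in the hunk above is the futures 0.3 / std-future one: the error type folds into the item as `Option<Result<_, _>>`, and polling takes a pinned receiver plus a task `Context` instead of relying on an implicit current task. A minimal standalone sketch (illustrative only, not part of this patch) that mirrors the `Inner::Reusable` arm:

    use std::pin::Pin;
    use std::task::{Context, Poll};

    use bytes::Bytes;
    use futures::Stream;

    // Yields one buffered chunk, then signals end-of-stream, much like the
    // `Reusable` branch of the `Body` implementation above.
    struct Reusable(Bytes);

    impl Stream for Reusable {
        type Item = Result<Bytes, std::io::Error>;

        fn poll_next(mut self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
            if self.0.is_empty() {
                Poll::Ready(None)
            } else {
                let chunk = std::mem::replace(&mut self.0, Bytes::new());
                Poll::Ready(Some(Ok(chunk)))
            }
        }
    }
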
@@ -135,18 +172,6 @@ impl From<&'static str> for Body { } } -impl From + Send>> for Body -where - hyper::Chunk: From, - I: 'static, - E: std::error::Error + Send + Sync + 'static, -{ - #[inline] - fn from(s: Box + Send>) -> Body { - Body::wrap(hyper::Body::wrap_stream(s)) - } -} - /// A chunk of bytes for a `Body`. /// /// A `Chunk` can be treated like `&[u8]`. @@ -247,6 +272,12 @@ impl From for Chunk { } } +impl From for Bytes { + fn from(chunk: Chunk) -> Bytes { + chunk.inner.into() + } +} + impl From for hyper::Chunk { fn from(val: Chunk) -> hyper::Chunk { val.inner diff --git a/src/async_impl/client.rs b/src/async_impl/client.rs index 1472da3..6581f4e 100644 --- a/src/async_impl/client.rs +++ b/src/async_impl/client.rs @@ -8,12 +8,14 @@ use crate::header::{ CONTENT_TYPE, LOCATION, PROXY_AUTHORIZATION, RANGE, REFERER, TRANSFER_ENCODING, USER_AGENT, }; use bytes::Bytes; -use futures::{Async, Future, Poll}; use http::Uri; use hyper::client::ResponseFuture; use mime; #[cfg(feature = "default-tls")] use native_tls::TlsConnector; +use std::future::Future; +use std::pin::Pin; +use std::task::{Context, Poll}; use tokio::{clock, timer::Delay}; use log::debug; @@ -540,7 +542,10 @@ impl Client { /// /// This method fails if there was an error while sending request, /// redirect loop was detected or redirect limit was exhausted. - pub fn execute(&self, request: Request) -> impl Future { + pub fn execute( + &self, + request: Request, + ) -> impl Future> { self.execute_request(request) } @@ -593,7 +598,7 @@ impl Client { let timeout = self .inner .request_timeout - .map(|dur| Delay::new(clock::now() + dur)); + .map(|dur| tokio::timer::delay(clock::now() + dur)); Pending { inner: PendingInner::Request(PendingRequest { @@ -691,43 +696,65 @@ struct PendingRequest { timeout: Option, } +impl PendingRequest { + fn in_flight(self: Pin<&mut Self>) -> Pin<&mut ResponseFuture> { + unsafe { Pin::map_unchecked_mut(self, |x| &mut x.in_flight) } + } + + fn timeout(self: Pin<&mut Self>) -> Pin<&mut Option> { + unsafe { Pin::map_unchecked_mut(self, |x| &mut x.timeout) } + } + + fn urls(self: Pin<&mut Self>) -> &mut Vec { + unsafe { &mut Pin::get_unchecked_mut(self).urls } + } + + fn headers(self: Pin<&mut Self>) -> &mut HeaderMap { + unsafe { &mut Pin::get_unchecked_mut(self).headers } + } +} + impl Pending { pub(super) fn new_err(err: crate::Error) -> Pending { Pending { inner: PendingInner::Error(Some(err)), } } + + fn inner(self: Pin<&mut Self>) -> Pin<&mut PendingInner> { + unsafe { Pin::map_unchecked_mut(self, |x| &mut x.inner) } + } } impl Future for Pending { - type Item = Response; - type Error = crate::Error; + type Output = Result; - fn poll(&mut self) -> Poll { - match self.inner { - PendingInner::Request(ref mut req) => req.poll(), - PendingInner::Error(ref mut err) => { - Err(err.take().expect("Pending error polled more than once")) - } + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + let inner = self.inner(); + match inner.get_mut() { + PendingInner::Request(ref mut req) => Pin::new(req).poll(cx), + PendingInner::Error(ref mut err) => Poll::Ready(Err(err + .take() + .expect("Pending error polled more than once"))), } } } impl Future for PendingRequest { - type Item = Response; - type Error = crate::Error; + type Output = Result; - fn poll(&mut self) -> Poll { - if let Some(ref mut delay) = self.timeout { - if let Async::Ready(()) = try_!(delay.poll(), &self.url) { - return Err(crate::error::timedout(Some(self.url.clone()))); + fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> 
Poll { + if let Some(delay) = self.as_mut().timeout().as_mut().as_pin_mut() { + if let Poll::Ready(()) = delay.poll(cx) { + return Poll::Ready(Err(crate::error::timedout(Some(self.url.clone())))); } } loop { - let res = match try_!(self.in_flight.poll(), &self.url) { - Async::Ready(res) => res, - Async::NotReady => return Ok(Async::NotReady), + let res = match self.as_mut().in_flight().as_mut().poll(cx) { + Poll::Ready(Err(e)) => return Poll::Ready(url_error!(e, &self.url)), + Poll::Ready(Ok(res)) => res, + Poll::Pending => return Poll::Pending, }; if let Some(store_wrapper) = self.client.cookie_store.as_ref() { let mut store = store_wrapper.write().unwrap(); @@ -795,7 +822,8 @@ impl Future for PendingRequest { self.headers.insert(REFERER, referer); } } - self.urls.push(self.url.clone()); + let url = self.url.clone(); + self.as_mut().urls().push(url); let action = self .client .redirect_policy @@ -805,7 +833,10 @@ impl Future for PendingRequest { redirect::Action::Follow => { self.url = loc; - remove_sensitive_headers(&mut self.headers, &self.url, &self.urls); + let mut headers = + std::mem::replace(self.as_mut().headers(), HeaderMap::new()); + + remove_sensitive_headers(&mut headers, &self.url, &self.urls); debug!("redirecting to {:?} '{}'", self.method, self.url); let uri = expect_uri(&self.url); let body = match self.body { @@ -821,27 +852,30 @@ impl Future for PendingRequest { // Add cookies from the cookie store. if let Some(cookie_store_wrapper) = self.client.cookie_store.as_ref() { let cookie_store = cookie_store_wrapper.read().unwrap(); - add_cookie_header(&mut self.headers, &cookie_store, &self.url); + add_cookie_header(&mut headers, &cookie_store, &self.url); } - *req.headers_mut() = self.headers.clone(); - self.in_flight = self.client.hyper.request(req); + *req.headers_mut() = headers.clone(); + std::mem::swap(self.as_mut().headers(), &mut headers); + *self.as_mut().in_flight().get_mut() = self.client.hyper.request(req); continue; } redirect::Action::Stop => { debug!("redirect_policy disallowed redirection to '{}'", loc); } redirect::Action::LoopDetected => { - return Err(crate::error::loop_detected(self.url.clone())); + return Poll::Ready(Err(crate::error::loop_detected(self.url.clone()))); } redirect::Action::TooManyRedirects => { - return Err(crate::error::too_many_redirects(self.url.clone())); + return Poll::Ready(Err(crate::error::too_many_redirects( + self.url.clone(), + ))); } } } } let res = Response::new(res, self.url.clone(), self.client.gzip, self.timeout.take()); - return Ok(Async::Ready(res)); + return Poll::Ready(Ok(res)); } } } diff --git a/src/async_impl/decoder.rs b/src/async_impl/decoder.rs index 3f95bd4..ff56c4e 100644 --- a/src/async_impl/decoder.rs +++ b/src/async_impl/decoder.rs @@ -9,25 +9,16 @@ Chunks are just passed along. If the response is gzip, then the chunks are decompressed into a buffer. Slices of that buffer are emitted as new chunks. 
- -This module consists of a few main types: - -- `ReadableChunks` is a `Read`-like wrapper around a stream -- `Decoder` is a layer over `ReadableChunks` that applies the right decompression - -The following types directly support the gzip compression case: - -- `Pending` is a non-blocking constructor for a `Decoder` in case the body needs to be checked for EOF */ -use std::cmp; use std::fmt; -use std::io::{self, Read}; +use std::future::Future; use std::mem; +use std::pin::Pin; +use std::task::{Context, Poll}; -use bytes::{Buf, BufMut, BytesMut}; -use flate2::read::GzDecoder; -use futures::{Async, Future, Poll, Stream}; +use bytes::Bytes; +use futures::Stream; use hyper::header::{CONTENT_ENCODING, CONTENT_LENGTH, TRANSFER_ENCODING}; use hyper::HeaderMap; @@ -36,8 +27,6 @@ use log::warn; use super::{Body, Chunk}; use crate::error; -const INIT_BUFFER_SIZE: usize = 8192; - /// A response decompressor over a non-blocking stream of chunks. /// /// The inner decoder may be constructed asynchronously. @@ -49,22 +38,15 @@ enum Inner { /// A `PlainText` decoder just returns the response content as is. PlainText(Body), /// A `Gzip` decoder will uncompress the gzipped response content before returning it. - Gzip(Gzip), + Gzip(async_compression::stream::GzipDecoder>), /// A decoder that doesn't have a value yet. Pending(Pending), } /// A future attempt to poll the response body for EOF so we know whether to use gzip or not. -struct Pending { - body: ReadableChunks, -} +struct Pending(futures::stream::Peekable); -/// A gzip decoder that reads from a `flate2::read::GzDecoder` into a `BytesMut` and emits the results -/// as a `Chunk`. -struct Gzip { - inner: Box>>, - buf: BytesMut, -} +struct BodyBytes(Body); impl fmt::Debug for Decoder { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { @@ -86,7 +68,6 @@ impl Decoder { /// A plain text decoder. /// /// This decoder will emit the underlying chunks as-is. - #[inline] fn plain_text(body: Body) -> Decoder { Decoder { inner: Inner::PlainText(body), @@ -96,12 +77,11 @@ impl Decoder { /// A gzip decoder. /// /// This decoder will buffer and decompress chunks that are gzipped. - #[inline] fn gzip(body: Body) -> Decoder { + use futures::stream::StreamExt; + Decoder { - inner: Inner::Pending(Pending { - body: ReadableChunks::new(body), - }), + inner: Inner::Pending(Pending(BodyBytes(body).peekable())), } } @@ -148,189 +128,65 @@ impl Decoder { } impl Stream for Decoder { - type Item = Chunk; - type Error = error::Error; + type Item = Result; - fn poll(&mut self) -> Poll, Self::Error> { - // Do a read or poll for a pendidng decoder value. + fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll> { + // Do a read or poll for a pending decoder value. 
let new_value = match self.inner { - Inner::Pending(ref mut future) => match future.poll() { - Ok(Async::Ready(inner)) => inner, - Ok(Async::NotReady) => return Ok(Async::NotReady), - Err(e) => return Err(e), + Inner::Pending(ref mut future) => match Pin::new(future).poll(cx) { + Poll::Ready(Ok(inner)) => inner, + Poll::Ready(Err(e)) => return Poll::Ready(Some(Err(crate::error::from_io(e)))), + Poll::Pending => return Poll::Pending, }, - Inner::PlainText(ref mut body) => return body.poll(), - Inner::Gzip(ref mut decoder) => return decoder.poll(), + Inner::PlainText(ref mut body) => return Pin::new(body).poll_next(cx), + Inner::Gzip(ref mut decoder) => { + return match futures::ready!(Pin::new(decoder).poll_next(cx)) { + Some(Ok(bytes)) => Poll::Ready(Some(Ok(bytes.into()))), + Some(Err(err)) => Poll::Ready(Some(Err(crate::error::from_io(err)))), + None => Poll::Ready(None), + } + } }; self.inner = new_value; - self.poll() + self.poll_next(cx) } } impl Future for Pending { - type Item = Inner; - type Error = error::Error; + type Output = Result; - fn poll(&mut self) -> Poll { - let body_state = match self.body.poll_stream() { - Ok(Async::Ready(state)) => state, - Ok(Async::NotReady) => return Ok(Async::NotReady), - Err(e) => return Err(e), + fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + use futures::stream::StreamExt; + + match futures::ready!(Pin::new(&mut self.0).peek(cx)) { + Some(Ok(_)) => { + // fallthrough + } + Some(Err(_e)) => { + // error was just a ref, so we need to really poll to move it + return Poll::Ready(Err(futures::ready!(Pin::new(&mut self.0).poll_next(cx)) + .expect("just peeked Some") + .unwrap_err())); + } + None => return Poll::Ready(Ok(Inner::PlainText(Body::empty()))), }; - let body = mem::replace(&mut self.body, ReadableChunks::new(Body::empty())); - match body_state { - StreamState::Eof => Ok(Async::Ready(Inner::PlainText(Body::empty()))), - StreamState::HasMore => Ok(Async::Ready(Inner::Gzip(Gzip::new(body)))), - } + let body = mem::replace(&mut self.0, BodyBytes(Body::empty()).peekable()); + Poll::Ready(Ok(Inner::Gzip( + async_compression::stream::GzipDecoder::new(body), + ))) } } -impl Gzip { - fn new(stream: ReadableChunks) -> Self { - Gzip { - buf: BytesMut::with_capacity(INIT_BUFFER_SIZE), - inner: Box::new(GzDecoder::new(stream)), - } - } -} - -impl Stream for Gzip { - type Item = Chunk; - type Error = error::Error; - - fn poll(&mut self) -> Poll, Self::Error> { - if self.buf.remaining_mut() == 0 { - self.buf.reserve(INIT_BUFFER_SIZE); - } - - // The buffer contains uninitialised memory so getting a readable slice is unsafe. - // We trust the `flate2` and `miniz` writer not to read from the memory given. - // - // To be safe, this memory could be zeroed before passing to `flate2`. - // Otherwise we might need to deal with the case where `flate2` panics. - let read = try_io!(self.inner.read(unsafe { self.buf.bytes_mut() })); - - if read == 0 { - // If GzDecoder reports EOF, it doesn't necessarily mean the - // underlying stream reached EOF (such as the `0\r\n\r\n` - // header meaning a chunked transfer has completed). If it - // isn't polled till EOF, the connection may not be able - // to be re-used. - // - // See https://github.com/seanmonstar/reqwest/issues/508. 
- let inner_read = try_io!(self.inner.get_mut().read(&mut [0])); - if inner_read == 0 { - Ok(Async::Ready(None)) - } else { - Err(error::from(io::Error::new( - io::ErrorKind::InvalidData, - "unexpected data after gzip decoder signaled end-of-file", - ))) - } - } else { - unsafe { self.buf.advance_mut(read) }; - let chunk = Chunk::from_chunk(self.buf.split_to(read).freeze()); - - Ok(Async::Ready(Some(chunk))) - } - } -} - -/// A `Read`able wrapper over a stream of chunks. -pub struct ReadableChunks { - state: ReadState, - stream: S, -} - -enum ReadState { - /// A chunk is ready to be read from. - Ready(Chunk), - /// The next chunk isn't ready yet. - NotReady, - /// The stream has finished. - Eof, -} - -enum StreamState { - /// More bytes can be read from the stream. - HasMore, - /// No more bytes can be read from the stream. - Eof, -} - -impl ReadableChunks { - #[inline] - pub(crate) fn new(stream: S) -> Self { - ReadableChunks { - state: ReadState::NotReady, - stream, - } - } -} - -impl fmt::Debug for ReadableChunks { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - f.debug_struct("ReadableChunks").finish() - } -} - -impl Read for ReadableChunks -where - S: Stream, -{ - fn read(&mut self, buf: &mut [u8]) -> io::Result { - loop { - let ret; - match self.state { - ReadState::Ready(ref mut chunk) => { - let len = cmp::min(buf.len(), chunk.remaining()); - - buf[..len].copy_from_slice(&chunk[..len]); - chunk.advance(len); - if chunk.is_empty() { - ret = len; - } else { - return Ok(len); - } - } - ReadState::NotReady => match self.poll_stream() { - Ok(Async::Ready(StreamState::HasMore)) => continue, - Ok(Async::Ready(StreamState::Eof)) => return Ok(0), - Ok(Async::NotReady) => return Err(io::ErrorKind::WouldBlock.into()), - Err(e) => return Err(error::into_io(e)), - }, - ReadState::Eof => return Ok(0), - } - self.state = ReadState::NotReady; - return Ok(ret); - } - } -} - -impl ReadableChunks -where - S: Stream, -{ - /// Poll the readiness of the inner reader. - /// - /// This function will update the internal state and return a simplified - /// version of the `ReadState`. 
- fn poll_stream(&mut self) -> Poll { - match self.stream.poll() { - Ok(Async::Ready(Some(chunk))) => { - self.state = ReadState::Ready(chunk); - - Ok(Async::Ready(StreamState::HasMore)) - } - Ok(Async::Ready(None)) => { - self.state = ReadState::Eof; - - Ok(Async::Ready(StreamState::Eof)) - } - Ok(Async::NotReady) => Ok(Async::NotReady), - Err(e) => Err(e), +impl Stream for BodyBytes { + type Item = Result; + + fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll> { + match futures::ready!(Pin::new(&mut self.0).poll_next(cx)) { + Some(Ok(chunk)) => Poll::Ready(Some(Ok(chunk.into()))), + Some(Err(err)) => Poll::Ready(Some(Err(err.into_io()))), + None => Poll::Ready(None), } } } diff --git a/src/async_impl.rs b/src/async_impl/mod.rs similarity index 85% rename from src/async_impl.rs rename to src/async_impl/mod.rs index 5997f06..8218d83 100644 --- a/src/async_impl.rs +++ b/src/async_impl/mod.rs @@ -1,6 +1,6 @@ pub use self::body::{Body, Chunk}; pub use self::client::{Client, ClientBuilder}; -pub use self::decoder::{Decoder, ReadableChunks}; +pub use self::decoder::Decoder; pub use self::request::{Request, RequestBuilder}; pub use self::response::{Response, ResponseBuilderExt}; diff --git a/src/async_impl/multipart.rs b/src/async_impl/multipart.rs index c2eb8c2..48ec7d1 100644 --- a/src/async_impl/multipart.rs +++ b/src/async_impl/multipart.rs @@ -7,9 +7,9 @@ use mime_guess::Mime; use percent_encoding::{self, AsciiSet, NON_ALPHANUMERIC}; use uuid::Uuid; -use futures::Stream; +use futures::{Stream, StreamExt}; -use super::{Body, Chunk}; +use super::Body; /// An async multipart/form-data request. pub struct Form { @@ -190,11 +190,11 @@ impl Part { } /// Makes a new parameter from an arbitrary stream. - pub fn stream(value: T) -> Part + pub fn stream(value: T) -> Part where - T: Stream + Send + 'static, - T::Item: Into, - T::Error: std::error::Error + Send + Sync, + T: Stream> + Send + Sync + 'static, + E: std::error::Error + Send + Sync + 'static, + hyper::Chunk: std::convert::From, { Part::new(Body::wrap(hyper::Body::wrap_stream( value.map(|chunk| chunk.into()), @@ -210,7 +210,7 @@ impl Part { /// Tries to set the mime of this part. pub fn mime_str(self, mime: &str) -> crate::Result { - Ok(self.mime(try_!(mime.parse()))) + Ok(self.mime(mime.parse().map_err(crate::error::from)?)) } // Re-export when mime 0.4 is available, with split MediaType/MediaRange. 
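The reworked `Part::stream` above now accepts any futures 0.3 stream whose items are `Result`s with an Ok type convertible into `hyper::Chunk`. A quick usage sketch (the field name and error type here are illustrative, not taken from the patch):

    use futures::stream;
    use reqwest::r#async::multipart::{Form, Part};

    fn streamed_form() -> Form {
        // Each item is a Result; &'static str converts into hyper::Chunk.
        let chunks: Vec<Result<&'static str, std::io::Error>> = vec![Ok("hello "), Ok("world")];
        let part = Part::stream(stream::iter(chunks));
        Form::new().part("greeting", part)
    }
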
@@ -480,6 +480,7 @@ impl PercentEncoding { #[cfg(test)] mod tests { use super::*; + use futures::TryStreamExt; use tokio; #[test] @@ -487,9 +488,10 @@ mod tests { let form = Form::new(); let mut rt = tokio::runtime::current_thread::Runtime::new().expect("new rt"); - let body_ft = form.stream(); + let body = form.stream(); + let s = body.map(|try_c| try_c.map(|c| c.into_bytes())).try_concat(); - let out = rt.block_on(body_ft.map(|c| c.into_bytes()).concat2()); + let out = rt.block_on(s); assert_eq!(out.unwrap(), Vec::new()); } @@ -498,16 +500,20 @@ mod tests { let mut form = Form::new() .part( "reader1", - Part::stream(futures::stream::once::<_, hyper::Error>(Ok(Chunk::from( - "part1".to_owned(), + Part::stream(futures::stream::once(futures::future::ready::< + Result, + >(Ok( + hyper::Chunk::from("part1".to_owned()), )))), ) .part("key1", Part::text("value1")) .part("key2", Part::text("value2").mime(mime::IMAGE_BMP)) .part( "reader2", - Part::stream(futures::stream::once::<_, hyper::Error>(Ok(Chunk::from( - "part2".to_owned(), + Part::stream(futures::stream::once(futures::future::ready::< + Result, + >(Ok( + hyper::Chunk::from("part2".to_owned()), )))), ) .part("key3", Part::text("value3").file_name("filename")); @@ -530,11 +536,10 @@ mod tests { Content-Disposition: form-data; name=\"key3\"; filename=\"filename\"\r\n\r\n\ value3\r\n--boundary--\r\n"; let mut rt = tokio::runtime::current_thread::Runtime::new().expect("new rt"); - let body_ft = form.stream(); + let body = form.stream(); + let s = body.map(|try_c| try_c.map(|c| c.into_bytes())).try_concat(); - let out = rt - .block_on(body_ft.map(|c| c.into_bytes()).concat2()) - .unwrap(); + let out = rt.block_on(s).unwrap(); // These prints are for debug purposes in case the test fails println!( "START REAL\n{}\nEND REAL", @@ -558,11 +563,10 @@ mod tests { value2\r\n\ --boundary--\r\n"; let mut rt = tokio::runtime::current_thread::Runtime::new().expect("new rt"); - let body_ft = form.stream(); + let body = form.stream(); + let s = body.map(|try_c| try_c.map(|c| c.into_bytes())).try_concat(); - let out = rt - .block_on(body_ft.map(|c| c.into_bytes()).concat2()) - .unwrap(); + let out = rt.block_on(s).unwrap(); // These prints are for debug purposes in case the test fails println!( "START REAL\n{}\nEND REAL", diff --git a/src/async_impl/request.rs b/src/async_impl/request.rs index b77e6ba..a3e1d30 100644 --- a/src/async_impl/request.rs +++ b/src/async_impl/request.rs @@ -191,28 +191,20 @@ impl RequestBuilder { /// Sends a multipart/form-data body. 
/// /// ``` - /// # extern crate futures; - /// # extern crate reqwest; - /// /// # use reqwest::Error; - /// # use futures::future::Future; /// - /// # fn run() -> Result<(), Error> { + /// # async fn run() -> Result<(), Error> { /// let client = reqwest::r#async::Client::new(); /// let form = reqwest::r#async::multipart::Form::new() /// .text("key3", "value3") /// .text("key4", "value4"); /// - /// let mut rt = tokio::runtime::current_thread::Runtime::new().expect("new rt"); /// /// let response = client.post("your url") /// .multipart(form) /// .send() - /// .and_then(|_| { - /// Ok(()) - /// }); - /// - /// rt.block_on(response) + /// .await?; + /// # Ok(()) /// # } /// ``` pub fn multipart(self, mut multipart: multipart::Form) -> RequestBuilder { @@ -334,23 +326,17 @@ impl RequestBuilder { /// # Example /// /// ```no_run - /// # extern crate futures; - /// # extern crate reqwest; - /// # /// # use reqwest::Error; - /// # use futures::future::Future; /// # - /// # fn run() -> Result<(), Error> { + /// # async fn run() -> Result<(), Error> { /// let response = reqwest::r#async::Client::new() /// .get("https://hyper.rs") /// .send() - /// .map(|resp| println!("status: {}", resp.status())); - /// - /// let mut rt = tokio::runtime::current_thread::Runtime::new().expect("new rt"); - /// rt.block_on(response) + /// .await?; + /// # Ok(()) /// # } /// ``` - pub fn send(self) -> impl Future { + pub fn send(self) -> impl Future> { match self.request { Ok(req) => self.client.execute_request(req), Err(err) => Pending::new_err(err), diff --git a/src/async_impl/response.rs b/src/async_impl/response.rs index 007595b..4413457 100644 --- a/src/async_impl/response.rs +++ b/src/async_impl/response.rs @@ -3,10 +3,11 @@ use std::fmt; use std::marker::PhantomData; use std::mem; use std::net::SocketAddr; +use std::pin::Pin; +use std::task::{Context, Poll}; use encoding_rs::{Encoding, UTF_8}; -use futures::stream::Concat2; -use futures::{try_ready, Async, Future, Poll, Stream}; +use futures::{Future, FutureExt, TryStreamExt}; use http; use hyper::client::connect::HttpInfo; use hyper::header::CONTENT_LENGTH; @@ -20,8 +21,12 @@ use url::Url; use super::body::Body; use super::Decoder; +use crate::async_impl::Chunk; use crate::cookie; +/// https://github.com/rust-lang-nursery/futures-rs/issues/1812 +type ConcatDecoder = Pin> + Send>>; + /// A Response to a submitted `Request`. pub struct Response { status: StatusCode, @@ -139,7 +144,7 @@ impl Response { } /// Get the response text - pub fn text(&mut self) -> impl Future { + pub fn text(&mut self) -> impl Future> { self.text_with_charset("utf-8") } @@ -147,7 +152,7 @@ impl Response { pub fn text_with_charset( &mut self, default_encoding: &str, - ) -> impl Future { + ) -> impl Future> { let body = mem::replace(&mut self.body, Decoder::empty()); let content_type = self .headers @@ -160,18 +165,18 @@ impl Response { .unwrap_or(default_encoding); let encoding = Encoding::for_label(encoding_name.as_bytes()).unwrap_or(UTF_8); Text { - concat: body.concat2(), + concat: body.try_concat().boxed(), encoding, } } /// Try to deserialize the response body as JSON using `serde`. #[inline] - pub fn json(&mut self) -> impl Future { + pub fn json(&mut self) -> impl Future> { let body = mem::replace(&mut self.body, Decoder::empty()); Json { - concat: body.concat2(), + concat: body.try_concat().boxed(), _marker: PhantomData, } } @@ -270,17 +275,27 @@ impl> From> for Response { /// A JSON object. 
struct Json { - concat: Concat2, + concat: ConcatDecoder, _marker: PhantomData, } +impl Json { + fn concat(self: Pin<&mut Self>) -> Pin<&mut ConcatDecoder> { + unsafe { Pin::map_unchecked_mut(self, |x| &mut x.concat) } + } +} + impl Future for Json { - type Item = T; - type Error = crate::Error; - fn poll(&mut self) -> Poll { - let bytes = try_ready!(self.concat.poll()); - let t = try_!(serde_json::from_slice(&bytes)); - Ok(Async::Ready(t)) + type Output = Result; + + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + match futures::ready!(self.concat().as_mut().poll(cx)) { + Err(e) => Poll::Ready(Err(e)), + Ok(chunk) => { + let t = serde_json::from_slice(&chunk).map_err(crate::error::from); + Poll::Ready(t) + } + } } } @@ -290,29 +305,36 @@ impl fmt::Debug for Json { } } -#[derive(Debug)] +//#[derive(Debug)] struct Text { - concat: Concat2, + concat: ConcatDecoder, encoding: &'static Encoding, } +impl Text { + fn concat(self: Pin<&mut Self>) -> Pin<&mut ConcatDecoder> { + unsafe { Pin::map_unchecked_mut(self, |x| &mut x.concat) } + } +} + impl Future for Text { - type Item = String; - type Error = crate::Error; - fn poll(&mut self) -> Poll { - let bytes = try_ready!(self.concat.poll()); - // a block because of borrow checker - { - let (text, _, _) = self.encoding.decode(&bytes); - if let Cow::Owned(s) = text { - return Ok(Async::Ready(s)); + type Output = Result; + + fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + match futures::ready!(self.as_mut().concat().as_mut().poll(cx)) { + Err(e) => Poll::Ready(Err(e)), + Ok(chunk) => { + let (text, _, _) = self.as_mut().encoding.decode(&chunk); + if let Cow::Owned(s) = text { + return Poll::Ready(Ok(s)); + } + unsafe { + // decoding returned Cow::Borrowed, meaning these bytes + // are already valid utf8 + Poll::Ready(Ok(String::from_utf8_unchecked(chunk.to_vec()))) + } } } - unsafe { - // decoding returned Cow::Borrowed, meaning these bytes - // are already valid utf8 - Ok(Async::Ready(String::from_utf8_unchecked(bytes.to_vec()))) - } } } diff --git a/src/body.rs b/src/body.rs index dc17289..5712af4 100644 --- a/src/body.rs +++ b/src/body.rs @@ -1,10 +1,9 @@ use std::fmt; use std::fs::File; +use std::future::Future; use std::io::{self, Cursor, Read}; use bytes::Bytes; -use futures::{try_ready, Future}; -use hyper; use crate::async_impl; @@ -213,77 +212,78 @@ pub(crate) struct Sender { tx: hyper::body::Sender, } +async fn send_future(sender: Sender) -> Result<(), crate::Error> { + use bytes::{BufMut, BytesMut}; + use std::cmp; + + let con_len = sender.body.1; + let cap = cmp::min(sender.body.1.unwrap_or(8192), 8192); + let mut written = 0; + let mut buf = BytesMut::with_capacity(cap as usize); + let mut body = sender.body.0; + // Put in an option so that it can be consumed on error to call abort() + let mut tx = Some(sender.tx); + + loop { + if Some(written) == con_len { + // Written up to content-length, so stop. + return Ok(()); + } + + // The input stream is read only if the buffer is empty so + // that there is only one read in the buffer at any time. + // + // We need to know whether there is any data to send before + // we check the transmission channel (with poll_ready below) + // because somestimes the receiver disappears as soon as is + // considers the data is completely transmitted, which may + // be true. + // + // The use case is a web server that closes its + // input stream as soon as the data received is valid JSON. 
+ // This behaviour is questionable, but it exists and the + // fact is that there is actually no remaining data to read. + if buf.is_empty() { + if buf.remaining_mut() == 0 { + buf.reserve(8192); + } + + match body.read(unsafe { buf.bytes_mut() }) { + Ok(0) => { + // The buffer was empty and nothing's left to + // read. Return. + return Ok(()); + } + Ok(n) => unsafe { + buf.advance_mut(n); + }, + Err(e) => { + let ret = io::Error::new(e.kind(), e.to_string()); + tx.take().expect("tx only taken on error").abort(); + return Err(crate::error::from(ret)); + } + } + } + + // The only way to get here is when the buffer is not empty. + // We can check the transmission channel + + let buf_len = buf.len() as u64; + tx.as_mut() + .expect("tx only taken on error") + .send_data(buf.take().freeze().into()) + .await + .map_err(crate::error::from)?; + + written += buf_len; + } +} + impl Sender { // A `Future` that may do blocking read calls. // As a `Future`, this integrates easily with `wait::timeout`. - pub(crate) fn send(self) -> impl Future { - use bytes::{BufMut, BytesMut}; - use futures::future; - use std::cmp; - - let con_len = self.body.1; - let cap = cmp::min(self.body.1.unwrap_or(8192), 8192); - let mut written = 0; - let mut buf = BytesMut::with_capacity(cap as usize); - let mut body = self.body.0; - // Put in an option so that it can be consumed on error to call abort() - let mut tx = Some(self.tx); - - future::poll_fn(move || loop { - if Some(written) == con_len { - // Written up to content-length, so stop. - return Ok(().into()); - } - - // The input stream is read only if the buffer is empty so - // that there is only one read in the buffer at any time. - // - // We need to know whether there is any data to send before - // we check the transmission channel (with poll_ready below) - // because somestimes the receiver disappears as soon as is - // considers the data is completely transmitted, which may - // be true. - // - // The use case is a web server that closes its - // input stream as soon as the data received is valid JSON. - // This behaviour is questionable, but it exists and the - // fact is that there is actually no remaining data to read. - if buf.is_empty() { - if buf.remaining_mut() == 0 { - buf.reserve(8192); - } - - match body.read(unsafe { buf.bytes_mut() }) { - Ok(0) => { - // The buffer was empty and nothing's left to - // read. Return. - return Ok(().into()); - } - Ok(n) => unsafe { - buf.advance_mut(n); - }, - Err(e) => { - let ret = io::Error::new(e.kind(), e.to_string()); - tx.take().expect("tx only taken on error").abort(); - return Err(crate::error::from(ret)); - } - } - } - - // The only way to get here is when the buffer is not empty. 
- // We can check the transmission channel - try_ready!(tx - .as_mut() - .expect("tx only taken on error") - .poll_ready() - .map_err(crate::error::from)); - - written += buf.len() as u64; - let tx = tx.as_mut().expect("tx only taken on error"); - if tx.send_data(buf.take().freeze().into()).is_err() { - return Err(crate::error::timedout(None)); - } - }) + pub(crate) fn send(self) -> impl Future> { + send_future(self) } } diff --git a/src/client.rs b/src/client.rs index 9e9d5a2..e34e54e 100644 --- a/src/client.rs +++ b/src/client.rs @@ -1,14 +1,14 @@ use std::fmt; +use std::future::Future; use std::net::IpAddr; use std::sync::Arc; use std::thread; use std::time::Duration; -use futures::future::{self, Either}; -use futures::sync::{mpsc, oneshot}; -use futures::{Async, Future, Stream}; +use futures::channel::{mpsc, oneshot}; +use futures::{StreamExt, TryFutureExt}; -use log::trace; +use log::{error, trace}; use crate::request::{Request, RequestBuilder}; use crate::response::Response; @@ -523,10 +523,8 @@ struct ClientHandle { inner: Arc, } -type ThreadSender = mpsc::UnboundedSender<( - async_impl::Request, - oneshot::Sender>, -)>; +type OneshotResponse = oneshot::Sender>; +type ThreadSender = mpsc::UnboundedSender<(async_impl::Request, OneshotResponse)>; struct InnerClientHandle { tx: Option, @@ -544,69 +542,54 @@ impl ClientHandle { fn new(builder: ClientBuilder) -> crate::Result { let timeout = builder.timeout; let builder = builder.inner; - let (tx, rx) = mpsc::unbounded(); + let (tx, rx) = mpsc::unbounded::<(async_impl::Request, OneshotResponse)>(); let (spawn_tx, spawn_rx) = oneshot::channel::>(); - let handle = try_!(thread::Builder::new() + let handle = thread::Builder::new() .name("reqwest-internal-sync-runtime".into()) .spawn(move || { use tokio::runtime::current_thread::Runtime; - let built = (|| { - let rt = try_!(Runtime::new()); - let client = builder.build()?; - Ok((rt, client)) - })(); - - let (mut rt, client) = match built { - Ok((rt, c)) => { - if spawn_tx.send(Ok(())).is_err() { - return; - } - (rt, c) - } + let mut rt = match Runtime::new().map_err(crate::error::from) { Err(e) => { - let _ = spawn_tx.send(Err(e)); + if let Err(e) = spawn_tx.send(Err(e)) { + error!("Failed to communicate runtime creation failure: {:?}", e); + } return; } + Ok(v) => v, }; - let work = rx.for_each(move |(req, tx)| { - let mut tx_opt: Option>> = - Some(tx); - let mut res_fut = client.execute(req); - - let task = future::poll_fn(move || { - let canceled = tx_opt - .as_mut() - .expect("polled after complete") - .poll_cancel() - .expect("poll_cancel cannot error") - .is_ready(); - - if canceled { - trace!("response receiver is canceled"); - Ok(Async::Ready(())) - } else { - let result = match res_fut.poll() { - Ok(Async::NotReady) => return Ok(Async::NotReady), - Ok(Async::Ready(res)) => Ok(res), - Err(err) => Err(err), - }; - - let _ = tx_opt.take().expect("polled after complete").send(result); - Ok(Async::Ready(())) + let f = async move { + let client = match builder.build() { + Err(e) => { + if let Err(e) = spawn_tx.send(Err(e)) { + error!("Failed to communicate client creation failure: {:?}", e); + } + return; } - }); - tokio::spawn(task); - Ok(()) - }); + Ok(v) => v, + }; + if let Err(e) = spawn_tx.send(Ok(())) { + error!("Failed to communicate successful startup: {:?}", e); + return; + } - // work is Future<(), ()>, and our closure will never return Err - rt.block_on(work).expect("runtime unexpected error"); - })); + let mut rx = rx; + + while let Some((req, req_tx)) = rx.next().await { + 
let req_fut = client.execute(req); + tokio::spawn(forward(req_fut, req_tx)); + } + + trace!("Receiver is shutdown"); + }; + + rt.block_on(f) + }) + .map_err(crate::error::from)?; // Wait for the runtime thread to start up... - match spawn_rx.wait() { + match wait::timeout(spawn_rx, None) { Ok(Ok(())) => (), Ok(Err(err)) => return Err(err), Err(_canceled) => event_loop_panicked(), @@ -634,35 +617,61 @@ impl ClientHandle { .unbounded_send((req, tx)) .expect("core thread panicked"); - let write = if let Some(body) = body { - Either::A(body.send()) - //try_!(body.send(self.timeout.0), &url); - } else { - Either::B(future::ok(())) - }; + let result: Result, wait::Waited> = + if let Some(body) = body { + let f = async move { + body.send().await?; + rx.await.map_err(|_canceled| event_loop_panicked()) + }; + wait::timeout(f, self.timeout.0) + } else { + wait::timeout( + rx.map_err(|_canceled| event_loop_panicked()), + self.timeout.0, + ) + }; - let rx = rx.map_err(|_canceled| event_loop_panicked()); - - let fut = write.join(rx).map(|((), res)| res); - - let res = match wait::timeout(fut, self.timeout.0) { - Ok(res) => res, - Err(wait::Waited::TimedOut) => return Err(crate::error::timedout(Some(url))), - Err(wait::Waited::Executor(err)) => return Err(crate::error::from(err).with_url(url)), - Err(wait::Waited::Inner(err)) => { - return Err(err.with_url(url)); - } - }; - res.map(|res| { - Response::new( + match result { + Ok(Err(err)) => Err(err.with_url(url)), + Ok(Ok(res)) => Ok(Response::new( res, self.timeout.0, KeepCoreThreadAlive(Some(self.inner.clone())), - ) - }) + )), + Err(wait::Waited::TimedOut) => Err(crate::error::timedout(Some(url))), + Err(wait::Waited::Executor(err)) => Err(crate::error::from(err).with_url(url)), + Err(wait::Waited::Inner(err)) => Err(err.with_url(url)), + } } } +async fn forward(fut: F, mut tx: OneshotResponse) +where + F: Future>, +{ + use std::task::Poll; + + futures::pin_mut!(fut); + + // "select" on the sender being canceled, and the future completing + let res = futures::future::poll_fn(|cx| { + match fut.as_mut().poll(cx) { + Poll::Ready(val) => Poll::Ready(Some(val)), + Poll::Pending => { + // check if the callback is canceled + futures::ready!(tx.poll_cancel(cx)); + Poll::Ready(None) + } + } + }) + .await; + + if let Some(res) = res { + let _ = tx.send(res); + } + // else request is canceled +} + #[derive(Clone, Copy)] struct Timeout(Option); diff --git a/src/connect.rs b/src/connect.rs index 8b9b9af..2c70bda 100644 --- a/src/connect.rs +++ b/src/connect.rs @@ -1,28 +1,26 @@ -use futures::Future; +use futures::FutureExt; use http::uri::Scheme; use hyper::client::connect::{Connect, Connected, Destination}; -use tokio_io::{AsyncRead, AsyncWrite}; -use tokio_timer::Timeout; +use tokio::io::{AsyncRead, AsyncWrite}; -#[cfg(feature = "tls")] -use bytes::BufMut; -#[cfg(feature = "tls")] -use futures::Poll; #[cfg(feature = "default-tls")] use native_tls::{TlsConnector, TlsConnectorBuilder}; +use std::future::Future; use std::io; use std::net::IpAddr; +use std::pin::Pin; use std::sync::Arc; use std::time::Duration; -#[cfg(feature = "trust-dns")] -use crate::dns::TrustDnsResolver; +//#[cfg(feature = "trust-dns")] +//use crate::dns::TrustDnsResolver; use crate::proxy::{Proxy, ProxyScheme}; +use tokio::future::FutureExt as _; -#[cfg(feature = "trust-dns")] -type HttpConnector = hyper::client::HttpConnector; -#[cfg(not(feature = "trust-dns"))] +//#[cfg(feature = "trust-dns")] +//type HttpConnector = hyper::client::HttpConnector; +//#[cfg(not(feature = "trust-dns"))] type 
HttpConnector = hyper::client::HttpConnector; pub(crate) struct Connector { @@ -33,6 +31,7 @@ pub(crate) struct Connector { nodelay: bool, } +#[derive(Clone)] enum Inner { #[cfg(not(feature = "tls"))] Http(HttpConnector), @@ -76,7 +75,7 @@ impl Connector { where T: Into>, { - let tls = try_!(tls.build()); + let tls = tls.build().map_err(crate::error::from)?; let mut http = http_connector()?; http.set_local_address(local_addr.into()); @@ -130,25 +129,11 @@ impl Connector { } #[cfg(feature = "socks")] - fn connect_socks(&self, dst: Destination, proxy: ProxyScheme) -> Connecting { - macro_rules! timeout { - ($future:expr) => { - if let Some(dur) = self.timeout { - Box::new(Timeout::new($future, dur).map_err(|err| { - if err.is_inner() { - err.into_inner().expect("is_inner") - } else if err.is_elapsed() { - io::Error::new(io::ErrorKind::TimedOut, "connect timed out") - } else { - io::Error::new(io::ErrorKind::Other, err) - } - })) - } else { - Box::new($future) - } - }; - } - + async fn connect_socks( + &self, + dst: Destination, + proxy: ProxyScheme, + ) -> Result<(Conn, Connected), io::Error> { let dns = match proxy { ProxyScheme::Socks5 { remote_dns: false, .. @@ -167,14 +152,15 @@ impl Connector { if dst.scheme() == "https" { use self::native_tls_async::TlsConnectorExt; - let tls = tls.clone(); let host = dst.host().to_owned(); let socks_connecting = socks::connect(proxy, dst, dns); - return timeout!(socks_connecting.and_then(move |(conn, connected)| { - tls.connect_async(&host, conn) - .map_err(|e| io::Error::new(io::ErrorKind::Other, e)) - .map(move |io| (Box::new(io) as Conn, connected)) - })); + let (conn, connected) = socks::connect(proxy, dst, dns).await?; + let tls_connector = tokio_tls::TlsConnector::from(tls.clone()); + let io = tls_connector + .connect(&host, conn) + .await + .map_err(|e| io::Error::new(io::ErrorKind::Other, e))?; + Ok((Box::new(io) as Conn, connected)) } } #[cfg(feature = "rustls-tls")] @@ -185,40 +171,193 @@ impl Connector { let tls = tls_proxy.clone(); let host = dst.host().to_owned(); - let socks_connecting = socks::connect(proxy, dst, dns); - return timeout!(socks_connecting.and_then(move |(conn, connected)| { - let maybe_dnsname = DNSNameRef::try_from_ascii_str(&host) - .map(|dnsname| dnsname.to_owned()) - .map_err(|_| io::Error::new(io::ErrorKind::Other, "Invalid DNS Name")); - futures::future::result(maybe_dnsname) - .and_then(move |dnsname| { - RustlsConnector::from(tls) - .connect(dnsname.as_ref(), conn) - .map_err(|e| io::Error::new(io::ErrorKind::Other, e)) - }) - .map(move |io| (Box::new(io) as Conn, connected)) - })); + let (conn, connected) = socks::connect(proxy, dst, dns); + let dnsname = DNSNameRef::try_from_ascii_str(&host) + .map(|dnsname| dnsname.to_owned()) + .map_err(|_| io::Error::new(io::ErrorKind::Other, "Invalid DNS Name"))?; + let io = RustlsConnector::from(tls) + .connect(dnsname.as_ref(), conn) + .await + .map_err(|e| io::Error::new(io::ErrorKind::Other, e))?; + Ok((Box::new(io) as Conn, connected)) } } #[cfg(not(feature = "tls"))] - Inner::Http(_) => (), + Inner::Http(_) => socks::connect(proxy, dst, dns), } - - // else no TLS - socks::connect(proxy, dst, dns) } } -#[cfg(feature = "trust-dns")] +//#[cfg(feature = "trust-dns")] +//fn http_connector() -> crate::Result { +// TrustDnsResolver::new() +// .map(HttpConnector::new_with_resolver) +// .map_err(crate::error::dns_system_conf) +//} + +//#[cfg(not(feature = "trust-dns"))] fn http_connector() -> crate::Result { - TrustDnsResolver::new() - .map(HttpConnector::new_with_resolver) - 
.map_err(crate::error::dns_system_conf) + Ok(HttpConnector::new()) } -#[cfg(not(feature = "trust-dns"))] -fn http_connector() -> crate::Result { - Ok(HttpConnector::new(4)) +async fn connect_with_maybe_proxy( + inner: Inner, + dst: Destination, + is_proxy: bool, + no_delay: bool, +) -> Result<(Conn, Connected), io::Error> { + match inner { + #[cfg(not(feature = "tls"))] + Inner::Http(http) => { + drop(no_delay); // only used for TLS? + let (io, connected) = http.connect(dst).await?; + Ok((Box::new(io) as Conn, connected.proxy(is_proxy))) + } + #[cfg(feature = "default-tls")] + Inner::DefaultTls(http, tls) => { + let mut http = http.clone(); + + http.set_nodelay(no_delay || (dst.scheme() == "https")); + + let tls_connector = tokio_tls::TlsConnector::from(tls.clone()); + let http = hyper_tls::HttpsConnector::from((http, tls_connector)); + let (io, connected) = http.connect(dst).await?; + //TODO: where's this at now? + //if let hyper_tls::MaybeHttpsStream::Https(_stream) = &io { + // if !no_delay { + // stream.set_nodelay(false)?; + // } + //} + + Ok((Box::new(io) as Conn, connected.proxy(is_proxy))) + } + #[cfg(feature = "rustls-tls")] + Inner::RustlsTls { http, tls, .. } => { + let mut http = http.clone(); + + // Disable Nagle's algorithm for TLS handshake + // + // https://www.openssl.org/docs/man1.1.1/man3/SSL_connect.html#NOTES + http.set_nodelay(nodelay || (dst.scheme() == "https")); + + let http = hyper_rustls::HttpsConnector::from((http, tls.clone())); + let (io, connected) = http.connect(dst).await; + if let hyper_rustls::MaybeHttpsStream::Https(stream) = &io { + if !nodelay { + let (io, _) = stream.get_ref(); + io.set_nodelay(false)?; + } + } + + Ok((Box::new(io) as Conn, connected.proxy(is_proxy))) + } + } +} + +async fn connect_via_proxy( + inner: Inner, + dst: Destination, + proxy_scheme: ProxyScheme, + no_delay: bool, +) -> Result<(Conn, Connected), io::Error> { + log::trace!("proxy({:?}) intercepts {:?}", proxy_scheme, dst); + + let (puri, _auth) = match proxy_scheme { + ProxyScheme::Http { uri, auth, .. } => (uri, auth), + #[cfg(feature = "socks")] + ProxyScheme::Socks5 { .. 
} => return this.connect_socks(dst, proxy_scheme), + }; + + let mut ndst = dst.clone(); + + let new_scheme = puri.scheme_part().map(Scheme::as_str).unwrap_or("http"); + ndst.set_scheme(new_scheme) + .expect("proxy target scheme should be valid"); + + ndst.set_host(puri.host().expect("proxy target should have host")) + .expect("proxy target host should be valid"); + + ndst.set_port(puri.port_part().map(|port| port.as_u16())); + + #[cfg(feature = "tls")] + let auth = _auth; + + match &inner { + #[cfg(feature = "default-tls")] + Inner::DefaultTls(http, tls) => { + if dst.scheme() == "https" { + let host = dst.host().to_owned(); + let port = dst.port().unwrap_or(443); + let mut http = http.clone(); + http.set_nodelay(no_delay); + let tls_connector = tokio_tls::TlsConnector::from(tls.clone()); + let http = hyper_tls::HttpsConnector::from((http, tls_connector)); + let (conn, connected) = http.connect(ndst).await?; + log::trace!("tunneling HTTPS over proxy"); + let tunneled = tunnel(conn, host.clone(), port, auth).await?; + let tls_connector = tokio_tls::TlsConnector::from(tls.clone()); + let io = tls_connector + .connect(&host, tunneled) + .await + .map_err(|e| io::Error::new(io::ErrorKind::Other, e))?; + return Ok((Box::new(io) as Conn, connected.proxy(true))); + } + } + #[cfg(feature = "rustls-tls")] + Inner::RustlsTls { + http, + tls, + tls_proxy, + } => { + if dst.scheme() == "https" { + use rustls::Session; + use tokio_rustls::webpki::DNSNameRef; + use tokio_rustls::TlsConnector as RustlsConnector; + + let host = dst.host().to_owned(); + let port = dst.port().unwrap_or(443); + let mut http = http.clone(); + http.set_nodelay(nodelay); + let http = hyper_rustls::HttpsConnector::from((http, tls_proxy.clone())); + let tls = tls.clone(); + let (conn, connected) = http.connect(ndst).await; + log::trace!("tunneling HTTPS over proxy"); + let maybe_dnsname = DNSNameRef::try_from_ascii_str(&host) + .map(|dnsname| dnsname.to_owned()) + .map_err(|_| io::Error::new(io::ErrorKind::Other, "Invalid DNS Name")); + let tunneled = tunnel(conn, host, port, auth).await; + let dnsname = maybe_dnsname?; + let io = RustlsConnector::from(tls) + .connect(dnsname.as_ref(), tunneled) + .await + .map_err(|e| io::Error::new(io::ErrorKind::Other, e))?; + let connected = if io.get_ref().1.get_alpn_protocol() == Some(b"h2") { + connected.negotiated_h2() + } else { + connected + }; + return Ok((Box::new(io) as Conn, connected.proxy(true))); + } + } + #[cfg(not(feature = "tls"))] + Inner::Http(_) => (), + } + + connect_with_maybe_proxy(inner, ndst, true, no_delay).await +} + +async fn with_timeout(f: F, timeout: Option) -> Result +where + F: Future>, +{ + if let Some(to) = timeout { + match f.timeout(to).await { + Err(_elapsed) => Err(io::Error::new(io::ErrorKind::TimedOut, "connect timed out")), + Ok(try_res) => try_res, + } + } else { + f.await + } } impl Connect for Connector { @@ -228,202 +367,47 @@ impl Connect for Connector { fn connect(&self, dst: Destination) -> Self::Future { #[cfg(feature = "tls")] - let nodelay = self.nodelay; - - macro_rules! timeout { - ($future:expr) => { - if let Some(dur) = self.timeout { - Box::new(Timeout::new($future, dur).map_err(|err| { - if err.is_inner() { - err.into_inner().expect("is_inner") - } else if err.is_elapsed() { - io::Error::new(io::ErrorKind::TimedOut, "connect timed out") - } else { - io::Error::new(io::ErrorKind::Other, err) - } - })) - } else { - Box::new($future) - } - }; - } - - macro_rules! 
connect { - ( $http:expr, $dst:expr, $proxy:expr ) => { - timeout!($http - .connect($dst) - .map(|(io, connected)| (Box::new(io) as Conn, connected.proxy($proxy)))) - }; - ( $dst:expr, $proxy:expr ) => { - match &self.inner { - #[cfg(not(feature = "tls"))] - Inner::Http(http) => connect!(http, $dst, $proxy), - #[cfg(feature = "default-tls")] - Inner::DefaultTls(http, tls) => { - let mut http = http.clone(); - - http.set_nodelay(nodelay || ($dst.scheme() == "https")); - - let http = hyper_tls::HttpsConnector::from((http, tls.clone())); - timeout!(http.connect($dst).and_then(move |(io, connected)| { - if let hyper_tls::MaybeHttpsStream::Https(stream) = &io { - if !nodelay { - stream.get_ref().get_ref().set_nodelay(false)?; - } - } - - Ok((Box::new(io) as Conn, connected.proxy($proxy))) - })) - } - #[cfg(feature = "rustls-tls")] - Inner::RustlsTls { http, tls, .. } => { - let mut http = http.clone(); - - // Disable Nagle's algorithm for TLS handshake - // - // https://www.openssl.org/docs/man1.1.1/man3/SSL_connect.html#NOTES - http.set_nodelay(nodelay || ($dst.scheme() == "https")); - - let http = hyper_rustls::HttpsConnector::from((http, tls.clone())); - timeout!(http.connect($dst).and_then(move |(io, connected)| { - if let hyper_rustls::MaybeHttpsStream::Https(stream) = &io { - if !nodelay { - let (io, _) = stream.get_ref(); - io.set_nodelay(false)?; - } - } - - Ok((Box::new(io) as Conn, connected.proxy($proxy))) - })) - } - } - }; - } - + let no_delay = self.nodelay; + #[cfg(not(feature = "tls"))] + let no_delay = false; + let timeout = self.timeout; for prox in self.proxies.iter() { if let Some(proxy_scheme) = prox.intercept(&dst) { - log::trace!("proxy({:?}) intercepts {:?}", proxy_scheme, dst); - - let (puri, _auth) = match proxy_scheme { - ProxyScheme::Http { uri, auth, .. } => (uri, auth), - #[cfg(feature = "socks")] - ProxyScheme::Socks5 { .. 
} => return self.connect_socks(dst, proxy_scheme), - }; - - let mut ndst = dst.clone(); - - let new_scheme = puri.scheme_part().map(Scheme::as_str).unwrap_or("http"); - ndst.set_scheme(new_scheme) - .expect("proxy target scheme should be valid"); - - ndst.set_host(puri.host().expect("proxy target should have host")) - .expect("proxy target host should be valid"); - - ndst.set_port(puri.port_part().map(|port| port.as_u16())); - - #[cfg(feature = "tls")] - let auth = _auth; - - match &self.inner { - #[cfg(feature = "default-tls")] - Inner::DefaultTls(http, tls) => { - if dst.scheme() == "https" { - use self::native_tls_async::TlsConnectorExt; - - let host = dst.host().to_owned(); - let port = dst.port().unwrap_or(443); - let mut http = http.clone(); - http.set_nodelay(nodelay); - let http = hyper_tls::HttpsConnector::from((http, tls.clone())); - let tls = tls.clone(); - return timeout!(http.connect(ndst).and_then( - move |(conn, connected)| { - log::trace!("tunneling HTTPS over proxy"); - tunnel(conn, host.clone(), port, auth) - .and_then(move |tunneled| { - tls.connect_async(&host, tunneled).map_err(|e| { - io::Error::new(io::ErrorKind::Other, e) - }) - }) - .map(|io| (Box::new(io) as Conn, connected.proxy(true))) - } - )); - } - } - #[cfg(feature = "rustls-tls")] - Inner::RustlsTls { - http, - tls, - tls_proxy, - } => { - if dst.scheme() == "https" { - use rustls::Session; - use tokio_rustls::webpki::DNSNameRef; - use tokio_rustls::TlsConnector as RustlsConnector; - - let host = dst.host().to_owned(); - let port = dst.port().unwrap_or(443); - let mut http = http.clone(); - http.set_nodelay(nodelay); - let http = - hyper_rustls::HttpsConnector::from((http, tls_proxy.clone())); - let tls = tls.clone(); - return timeout!(http.connect(ndst).and_then( - move |(conn, connected)| { - log::trace!("tunneling HTTPS over proxy"); - let maybe_dnsname = DNSNameRef::try_from_ascii_str(&host) - .map(|dnsname| dnsname.to_owned()) - .map_err(|_| { - io::Error::new(io::ErrorKind::Other, "Invalid DNS Name") - }); - tunnel(conn, host, port, auth) - .and_then(move |tunneled| Ok((maybe_dnsname?, tunneled))) - .and_then(move |(dnsname, tunneled)| { - RustlsConnector::from(tls) - .connect(dnsname.as_ref(), tunneled) - .map_err(|e| { - io::Error::new(io::ErrorKind::Other, e) - }) - }) - .map(|io| { - let connected = if io.get_ref().1.get_alpn_protocol() - == Some(b"h2") - { - connected.negotiated_h2() - } else { - connected - }; - (Box::new(io) as Conn, connected.proxy(true)) - }) - } - )); - } - } - #[cfg(not(feature = "tls"))] - Inner::Http(_) => (), - } - - return connect!(ndst, true); + return with_timeout( + connect_via_proxy(self.inner.clone(), dst, proxy_scheme, no_delay), + timeout, + ) + .boxed(); } } - connect!(dst, false) + with_timeout( + connect_with_maybe_proxy(self.inner.clone(), dst, false, no_delay), + timeout, + ) + .boxed() } } pub(crate) trait AsyncConn: AsyncRead + AsyncWrite {} impl AsyncConn for T {} -pub(crate) type Conn = Box; +pub(crate) type Conn = Box; -pub(crate) type Connecting = Box + Send>; +pub(crate) type Connecting = + Pin> + Send>>; #[cfg(feature = "tls")] -fn tunnel( - conn: T, +async fn tunnel( + mut conn: T, host: String, port: u16, auth: Option, -) -> Tunnel { +) -> Result +where + T: AsyncRead + AsyncWrite + Unpin, +{ + use tokio::io::{AsyncReadExt, AsyncWriteExt}; + let mut buf = format!( "\ CONNECT {0}:{1} HTTP/1.1\r\n\ @@ -443,84 +427,43 @@ fn tunnel( // headers end buf.extend_from_slice(b"\r\n"); - Tunnel { - buf: io::Cursor::new(buf), - conn: Some(conn), - state: 
TunnelState::Writing, - } -} + conn.write_all(&buf).await?; -#[cfg(feature = "tls")] -struct Tunnel { - buf: io::Cursor>, - conn: Option, - state: TunnelState, -} + let mut buf = [0; 8192]; + let mut pos = 0; -#[cfg(feature = "tls")] -enum TunnelState { - Writing, - Reading, -} + loop { + let n = conn.read(&mut buf[pos..]).await?; -#[cfg(feature = "tls")] -impl Future for Tunnel -where - T: AsyncRead + AsyncWrite, -{ - type Item = T; - type Error = io::Error; + if n == 0 { + return Err(tunnel_eof()); + } + pos += n; - fn poll(&mut self) -> Poll { - loop { - if let TunnelState::Writing = self.state { - let n = futures::try_ready!(self.conn.as_mut().unwrap().write_buf(&mut self.buf)); - if !self.buf.has_remaining_mut() { - self.state = TunnelState::Reading; - self.buf.get_mut().truncate(0); - } else if n == 0 { - return Err(tunnel_eof()); - } - } else { - let n = futures::try_ready!(self - .conn - .as_mut() - .unwrap() - .read_buf(&mut self.buf.get_mut())); - let read = &self.buf.get_ref()[..]; - if n == 0 { - return Err(tunnel_eof()); - } else if read.len() > 12 { - if read.starts_with(b"HTTP/1.1 200") || read.starts_with(b"HTTP/1.0 200") { - if read.ends_with(b"\r\n\r\n") { - return Ok(self.conn.take().unwrap().into()); - } - // else read more - } else if read.starts_with(b"HTTP/1.1 407") { - return Err(io::Error::new( - io::ErrorKind::Other, - "proxy authentication required", - )); - } else if read.starts_with(b"HTTP/1.1 403") { - return Err(io::Error::new( - io::ErrorKind::Other, - "proxy blocked this request", - )); - } else { - let (fst, _) = read.split_at(12); - return Err(io::Error::new( - io::ErrorKind::Other, - format!("unsuccessful tunnel: {:?}", fst).as_str(), - )); - } - } + let recvd = &buf[..pos]; + if recvd.starts_with(b"HTTP/1.1 200") || recvd.starts_with(b"HTTP/1.0 200") { + if recvd.ends_with(b"\r\n\r\n") { + return Ok(conn); } + if pos == buf.len() { + return Err(io::Error::new( + io::ErrorKind::Other, + "proxy headers too long for tunnel", + )); + } + // else read more + } else if recvd.starts_with(b"HTTP/1.1 407") { + return Err(io::Error::new( + io::ErrorKind::Other, + "proxy authentication required", + )); + } else { + return Err(io::Error::new(io::ErrorKind::Other, "unsuccessful tunnel")); } } } #[cfg(feature = "tls")] -#[inline] fn tunnel_eof() -> io::Error { io::Error::new( io::ErrorKind::UnexpectedEof, @@ -528,138 +471,6 @@ fn tunnel_eof() -> io::Error { ) } -#[cfg(feature = "default-tls")] -mod native_tls_async { - use std::io::{self, Read, Write}; - - use futures::{Async, Future, Poll}; - use native_tls::{self, Error, HandshakeError, TlsConnector}; - use tokio_io::{try_nb, AsyncRead, AsyncWrite}; - - /// A wrapper around an underlying raw stream which implements the TLS or SSL - /// protocol. - /// - /// A `TlsStream` represents a handshake that has been completed successfully - /// and both the server and the client are ready for receiving and sending - /// data. Bytes read from a `TlsStream` are decrypted from `S` and bytes written - /// to a `TlsStream` are encrypted when passing through to `S`. - #[derive(Debug)] - pub struct TlsStream { - inner: native_tls::TlsStream, - } - - /// Future returned from `TlsConnectorExt::connect_async` which will resolve - /// once the connection handshake has finished. - pub struct ConnectAsync { - inner: MidHandshake, - } - - struct MidHandshake { - inner: Option, HandshakeError>>, - } - - /// Extension trait for the `TlsConnector` type in the `native_tls` crate. 
- pub trait TlsConnectorExt: sealed::Sealed { - /// Connects the provided stream with this connector, assuming the provided - /// domain. - /// - /// This function will internally call `TlsConnector::connect` to connect - /// the stream and returns a future representing the resolution of the - /// connection operation. The returned future will resolve to either - /// `TlsStream` or `Error` depending if it's successful or not. - /// - /// This is typically used for clients who have already established, for - /// example, a TCP connection to a remote server. That stream is then - /// provided here to perform the client half of a connection to a - /// TLS-powered server. - /// - /// # Compatibility notes - /// - /// Note that this method currently requires `S: Read + Write` but it's - /// highly recommended to ensure that the object implements the `AsyncRead` - /// and `AsyncWrite` traits as well, otherwise this function will not work - /// properly. - fn connect_async(&self, domain: &str, stream: S) -> ConnectAsync - where - S: Read + Write; // TODO: change to AsyncRead + AsyncWrite - } - - mod sealed { - pub trait Sealed {} - } - - impl Read for TlsStream { - fn read(&mut self, buf: &mut [u8]) -> io::Result { - self.inner.read(buf) - } - } - - impl Write for TlsStream { - fn write(&mut self, buf: &[u8]) -> io::Result { - self.inner.write(buf) - } - - fn flush(&mut self) -> io::Result<()> { - self.inner.flush() - } - } - - impl AsyncRead for TlsStream {} - - impl AsyncWrite for TlsStream { - fn shutdown(&mut self) -> Poll<(), io::Error> { - try_nb!(self.inner.shutdown()); - self.inner.get_mut().shutdown() - } - } - - impl TlsConnectorExt for TlsConnector { - fn connect_async(&self, domain: &str, stream: S) -> ConnectAsync - where - S: Read + Write, - { - ConnectAsync { - inner: MidHandshake { - inner: Some(self.connect(domain, stream)), - }, - } - } - } - - impl sealed::Sealed for TlsConnector {} - - // TODO: change this to AsyncRead/AsyncWrite on next major version - impl Future for ConnectAsync { - type Item = TlsStream; - type Error = Error; - - fn poll(&mut self) -> Poll, Error> { - self.inner.poll() - } - } - - // TODO: change this to AsyncRead/AsyncWrite on next major version - impl Future for MidHandshake { - type Item = TlsStream; - type Error = Error; - - fn poll(&mut self) -> Poll, Error> { - match self.inner.take().expect("cannot poll MidHandshake twice") { - Ok(stream) => Ok(TlsStream { inner: stream }.into()), - Err(HandshakeError::Failure(e)) => Err(e), - Err(HandshakeError::WouldBlock(s)) => match s.handshake() { - Ok(stream) => Ok(TlsStream { inner: stream }.into()), - Err(HandshakeError::Failure(e)) => Err(e), - Err(HandshakeError::WouldBlock(s)) => { - self.inner = Some(Err(HandshakeError::WouldBlock(s))); - Ok(Async::NotReady) - } - }, - } - } - } -} - #[cfg(feature = "socks")] mod socks { use std::io; @@ -678,19 +489,18 @@ mod socks { Proxy, } - pub(super) fn connect(proxy: ProxyScheme, dst: Destination, dns: DnsResolve) -> Connecting { + pub(super) async fn connect( + proxy: ProxyScheme, + dst: Destination, + dns: DnsResolve, + ) -> Result<(super::Conn, Connected), io::Error> { let https = dst.scheme() == "https"; let original_host = dst.host().to_owned(); let mut host = original_host.clone(); let port = dst.port().unwrap_or_else(|| if https { 443 } else { 80 }); if let DnsResolve::Local = dns { - let maybe_new_target = match (host.as_str(), port).to_socket_addrs() { - Ok(mut iter) => iter.next(), - Err(err) => { - return Box::new(future::err(err)); - } - }; + let 
maybe_new_target = (host.as_str(), port).to_socket_addrs()?.next(); if let Some(new_target) = maybe_new_target { host = new_target.ip().to_string(); } @@ -702,39 +512,33 @@ mod socks { }; // Get a Tokio TcpStream - let stream = future::result( - if let Some((username, password)) = auth { - Socks5Stream::connect_with_password( - socket_addr, - (host.as_str(), port), - &username, - &password, - ) - } else { - Socks5Stream::connect(socket_addr, (host.as_str(), port)) - } - .and_then(|s| { - TcpStream::from_std(s.into_inner(), &reactor::Handle::default()) - .map_err(|e| io::Error::new(io::ErrorKind::Other, e)) - }), - ); + let stream = if let Some((username, password)) = auth { + Socks5Stream::connect_with_password( + socket_addr, + (host.as_str(), port), + &username, + &password, + ) + .await + } else { + let s = Socks5Stream::connect(socket_addr, (host.as_str(), port)).await; + TcpStream::from_std(s.into_inner(), &reactor::Handle::default()) + .map_err(|e| io::Error::new(io::ErrorKind::Other, e))? + }; - Box::new(stream.map(|s| (Box::new(s) as super::Conn, Connected::new()))) + Ok((Box::new(s) as super::Conn, Connected::new())) } } #[cfg(feature = "tls")] #[cfg(test)] mod tests { - extern crate tokio_tcp; - - use self::tokio_tcp::TcpStream; use super::tunnel; use crate::proxy; - use futures::Future; use std::io::{Read, Write}; use std::net::TcpListener; use std::thread; + use tokio::net::tcp::TcpStream; use tokio::runtime::current_thread::Runtime; static TUNNEL_OK: &[u8] = b"\ @@ -782,12 +586,14 @@ mod tests { let addr = mock_tunnel!(); let mut rt = Runtime::new().unwrap(); - let work = TcpStream::connect(&addr); - let host = addr.ip().to_string(); - let port = addr.port(); - let work = work.and_then(|tcp| tunnel(tcp, host, port, None)); + let f = async move { + let tcp = TcpStream::connect(&addr).await?; + let host = addr.ip().to_string(); + let port = addr.port(); + tunnel(tcp, host, port, None).await + }; - rt.block_on(work).unwrap(); + rt.block_on(f).unwrap(); } #[test] @@ -795,12 +601,14 @@ mod tests { let addr = mock_tunnel!(b"HTTP/1.1 200 OK"); let mut rt = Runtime::new().unwrap(); - let work = TcpStream::connect(&addr); - let host = addr.ip().to_string(); - let port = addr.port(); - let work = work.and_then(|tcp| tunnel(tcp, host, port, None)); + let f = async move { + let tcp = TcpStream::connect(&addr).await?; + let host = addr.ip().to_string(); + let port = addr.port(); + tunnel(tcp, host, port, None).await + }; - rt.block_on(work).unwrap_err(); + rt.block_on(f).unwrap_err(); } #[test] @@ -808,12 +616,14 @@ mod tests { let addr = mock_tunnel!(b"foo bar baz hallo"); let mut rt = Runtime::new().unwrap(); - let work = TcpStream::connect(&addr); - let host = addr.ip().to_string(); - let port = addr.port(); - let work = work.and_then(|tcp| tunnel(tcp, host, port, None)); + let f = async move { + let tcp = TcpStream::connect(&addr).await?; + let host = addr.ip().to_string(); + let port = addr.port(); + tunnel(tcp, host, port, None).await + }; - rt.block_on(work).unwrap_err(); + rt.block_on(f).unwrap_err(); } #[test] @@ -827,12 +637,14 @@ mod tests { ); let mut rt = Runtime::new().unwrap(); - let work = TcpStream::connect(&addr); - let host = addr.ip().to_string(); - let port = addr.port(); - let work = work.and_then(|tcp| tunnel(tcp, host, port, None)); + let f = async move { + let tcp = TcpStream::connect(&addr).await?; + let host = addr.ip().to_string(); + let port = addr.port(); + tunnel(tcp, host, port, None).await + }; - let error = rt.block_on(work).unwrap_err(); + let error = 
rt.block_on(f).unwrap_err(); assert_eq!(error.to_string(), "proxy authentication required"); } @@ -844,18 +656,19 @@ mod tests { ); let mut rt = Runtime::new().unwrap(); - let work = TcpStream::connect(&addr); - let host = addr.ip().to_string(); - let port = addr.port(); - let work = work.and_then(|tcp| { + let f = async move { + let tcp = TcpStream::connect(&addr).await?; + let host = addr.ip().to_string(); + let port = addr.port(); tunnel( tcp, host, port, Some(proxy::encode_basic_auth("Aladdin", "open sesame")), ) - }); + .await + }; - rt.block_on(work).unwrap(); + rt.block_on(f).unwrap(); } } diff --git a/src/dns.rs b/src/dns.rs index 726d9e9..9d03eee 100644 --- a/src/dns.rs +++ b/src/dns.rs @@ -47,7 +47,7 @@ impl TrustDnsResolver { impl hyper_dns::Resolve for TrustDnsResolver { type Addrs = vec::IntoIter; - type Future = Box + Send>; + type Future = Box> + Send>; fn resolve(&self, name: hyper_dns::Name) -> Self::Future { let inner = self.inner.clone(); diff --git a/src/error.rs b/src/error.rs index 694ebe9..69d1d49 100644 --- a/src/error.rs +++ b/src/error.rs @@ -2,8 +2,6 @@ use std::error::Error as StdError; use std::fmt; use std::io; -use tokio_executor::EnterError; - use crate::{StatusCode, Url}; /// The Errors that may occur when processing a `Request`. @@ -95,7 +93,6 @@ impl Error { } pub(crate) fn with_url(mut self, url: Url) -> Error { - debug_assert_eq!(self.inner.url, None, "with_url overriding existing url"); self.inner.url = Some(url); self } @@ -221,6 +218,13 @@ impl Error { _ => None, } } + + pub(crate) fn into_io(self) -> io::Error { + match self.inner.kind { + Kind::Io(io) => io, + _ => io::Error::new(io::ErrorKind::Other, self), + } + } } impl fmt::Debug for Error { @@ -475,8 +479,8 @@ where } } -impl From for Kind { - fn from(_err: EnterError) -> Kind { +impl From for Kind { + fn from(_err: tokio_executor::EnterError) -> Kind { Kind::BlockingClientInFutureContext } } @@ -521,10 +525,7 @@ where } pub(crate) fn into_io(e: Error) -> io::Error { - match e.inner.kind { - Kind::Io(io) => io, - _ => io::Error::new(io::ErrorKind::Other, e), - } + e.into_io() } pub(crate) fn from_io(e: io::Error) -> Error { @@ -538,39 +539,12 @@ pub(crate) fn from_io(e: io::Error) -> Error { } } -macro_rules! try_ { - ($e:expr) => { - match $e { - Ok(v) => v, - Err(err) => { - return Err(crate::error::from(err)); - } - } - }; +macro_rules! url_error { ($e:expr, $url:expr) => { - match $e { - Ok(v) => v, - Err(err) => { - return Err(crate::Error::from(crate::error::InternalFrom( - err, - Some($url.clone()), - ))); - } - } - }; -} - -macro_rules! 
try_io { - ($e:expr) => { - match $e { - Ok(v) => v, - Err(ref err) if err.kind() == std::io::ErrorKind::WouldBlock => { - return Ok(futures::Async::NotReady); - } - Err(err) => { - return Err(crate::error::from_io(err)); - } - } + Err(crate::Error::from(crate::error::InternalFrom( + $e, + Some($url.clone()), + ))) }; } @@ -607,6 +581,9 @@ pub(crate) fn unknown_proxy_scheme() -> Error { mod tests { use super::*; + fn assert_send() {} + fn assert_sync() {} + #[allow(deprecated)] #[test] fn test_cause_chain() { @@ -652,6 +629,8 @@ mod tests { let err = Error::new(Kind::Io(io), None); assert!(err.cause().is_some()); assert_eq!(err.to_string(), "chain: root"); + assert_send::(); + assert_sync::(); } #[test] diff --git a/src/into_url.rs b/src/into_url.rs index 3442d5c..0a53ac7 100644 --- a/src/into_url.rs +++ b/src/into_url.rs @@ -27,7 +27,7 @@ impl PolyfillTryInto for Url { impl<'a> PolyfillTryInto for &'a str { fn into_url(self) -> crate::Result { - try_!(Url::parse(self)).into_url() + Url::parse(self).map_err(crate::error::from)?.into_url() } } diff --git a/src/lib.rs b/src/lib.rs index 9c007be..cd144ff 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -158,8 +158,6 @@ //! - **default-tls-vendored**: Enables the `vendored` feature of `native-tls`. //! - **rustls-tls**: Provides TLS support via the `rustls` library. //! - **socks**: Provides SOCKS5 proxy support. -//! - **trust-dns**: Enables a trust-dns async resolver instead of default -//! threadpool using `getaddrinfo`. //! - **hyper-011**: Provides support for hyper's old typed headers. //! //! @@ -173,6 +171,9 @@ //! [Proxy]: ./struct.Proxy.html //! [cargo-features]: https://doc.rust-lang.org/stable/cargo/reference/manifest.html#the-features-section +////! - **trust-dns**: Enables a trust-dns async resolver instead of default +////! threadpool using `getaddrinfo`. + extern crate cookie as cookie_crate; #[cfg(feature = "hyper-011")] pub use hyper_old_types as hyper_011; @@ -210,8 +211,8 @@ mod body; mod client; mod connect; pub mod cookie; -#[cfg(feature = "trust-dns")] -mod dns; +//#[cfg(feature = "trust-dns")] +//mod dns; mod into_url; mod proxy; mod redirect; diff --git a/src/multipart.rs b/src/multipart.rs index 644dbd5..1b7bc46 100644 --- a/src/multipart.rs +++ b/src/multipart.rs @@ -213,7 +213,6 @@ impl Part { let file_name = path .file_name() .map(|filename| filename.to_string_lossy().into_owned()); - let ext = path.extension().and_then(|ext| ext.to_str()).unwrap_or(""); let mime = mime_guess::from_ext(ext).first_or_octet_stream(); let file = File::open(path)?; @@ -235,7 +234,7 @@ impl Part { /// Tries to set the mime of this part. pub fn mime_str(self, mime: &str) -> crate::Result { - Ok(self.mime(try_!(mime.parse()))) + Ok(self.mime(mime.parse().map_err(crate::error::from)?)) } // Re-export when mime 0.4 is available, with split MediaType/MediaRange. diff --git a/src/response.rs b/src/response.rs index 46cbcda..cd68dcf 100644 --- a/src/response.rs +++ b/src/response.rs @@ -2,9 +2,9 @@ use std::fmt; use std::io::{self, Read}; use std::mem; use std::net::SocketAddr; +use std::pin::Pin; use std::time::Duration; -use futures::{Async, Poll, Stream}; use http; use serde::de::DeserializeOwned; @@ -16,7 +16,7 @@ use hyper::header::HeaderMap; /// A Response to a submitted `Request`. 
pub struct Response { inner: async_impl::Response, - body: Option>, + body: Option>>, timeout: Option, _thread_handle: KeepCoreThreadAlive, } @@ -289,7 +289,6 @@ impl Response { /// # Ok(()) /// # } /// ``` - #[inline] pub fn copy_to(&mut self, w: &mut W) -> crate::Result where W: io::Write, @@ -349,47 +348,32 @@ impl Response { pub fn error_for_status_ref(&self) -> crate::Result<&Self> { self.inner.error_for_status_ref().and_then(|_| Ok(self)) } -} -impl Read for Response { - #[inline] - fn read(&mut self, buf: &mut [u8]) -> io::Result { + // private + + fn body_mut(&mut self) -> Pin<&mut dyn futures::io::AsyncRead> { + use futures::stream::TryStreamExt; if self.body.is_none() { let body = mem::replace(self.inner.body_mut(), async_impl::Decoder::empty()); - let body = async_impl::ReadableChunks::new(WaitBody { - inner: wait::stream(body, self.timeout), - }); - self.body = Some(body); + + let body = body.map_err(crate::error::into_io).into_async_read(); + + self.body = Some(Box::pin(body)); } - let mut body = self.body.take().unwrap(); - let bytes = body.read(buf); - self.body = Some(body); - bytes + self.body.as_mut().expect("body was init").as_mut() } } -struct WaitBody { - inner: wait::WaitStream, -} +impl Read for Response { + fn read(&mut self, buf: &mut [u8]) -> io::Result { + use futures::io::AsyncReadExt; -impl Stream for WaitBody { - type Item = ::Item; - type Error = ::Error; - - fn poll(&mut self) -> Poll, Self::Error> { - match self.inner.next() { - Some(Ok(chunk)) => Ok(Async::Ready(Some(chunk))), - Some(Err(e)) => { - let req_err = match e { - wait::Waited::TimedOut => crate::error::timedout(None), - wait::Waited::Executor(e) => crate::error::from(e), - wait::Waited::Inner(e) => e, - }; - - Err(req_err) - } - None => Ok(Async::Ready(None)), - } + let timeout = self.timeout; + wait::timeout(self.body_mut().read(buf), timeout).map_err(|e| match e { + wait::Waited::TimedOut => crate::error::timedout(None).into_io(), + wait::Waited::Executor(e) => crate::error::from(e).into_io(), + wait::Waited::Inner(e) => e, + }) } } diff --git a/src/tls.rs b/src/tls.rs index 9fa3946..19d9a73 100644 --- a/src/tls.rs +++ b/src/tls.rs @@ -55,7 +55,7 @@ impl Certificate { pub fn from_der(der: &[u8]) -> crate::Result { Ok(Certificate { #[cfg(feature = "default-tls")] - native: try_!(native_tls::Certificate::from_der(der)), + native: native_tls::Certificate::from_der(der).map_err(crate::error::from)?, #[cfg(feature = "rustls-tls")] original: Cert::Der(der.to_owned()), }) @@ -80,7 +80,7 @@ impl Certificate { pub fn from_pem(pem: &[u8]) -> crate::Result { Ok(Certificate { #[cfg(feature = "default-tls")] - native: try_!(native_tls::Certificate::from_pem(pem)), + native: native_tls::Certificate::from_pem(pem).map_err(crate::error::from)?, #[cfg(feature = "rustls-tls")] original: Cert::Pem(pem.to_owned()), }) @@ -146,7 +146,9 @@ impl Identity { #[cfg(feature = "default-tls")] pub fn from_pkcs12_der(der: &[u8], password: &str) -> crate::Result { Ok(Identity { - inner: ClientCert::Pkcs12(try_!(native_tls::Identity::from_pkcs12(der, password))), + inner: ClientCert::Pkcs12( + native_tls::Identity::from_pkcs12(der, password).map_err(crate::error::from)?, + ), }) } @@ -176,10 +178,11 @@ impl Identity { let (key, certs) = { let mut pem = Cursor::new(buf); - let certs = try_!(pemfile::certs(&mut pem) - .map_err(|_| TLSError::General(String::from("No valid certificate was found")))); + let certs = pemfile::certs(&mut pem) + .map_err(|_| TLSError::General(String::from("No valid certificate was found"))) + 
.map_err(crate::error::from)?; pem.set_position(0); - let mut sk = try_!(pemfile::pkcs8_private_keys(&mut pem) + let mut sk = pemfile::pkcs8_private_keys(&mut pem) .and_then(|pkcs8_keys| { if pkcs8_keys.is_empty() { Err(()) @@ -191,7 +194,8 @@ impl Identity { pem.set_position(0); pemfile::rsa_private_keys(&mut pem) }) - .map_err(|_| TLSError::General(String::from("No valid private key was found")))); + .map_err(|_| TLSError::General(String::from("No valid private key was found"))) + .map_err(crate::error::from)?; if let (Some(sk), false) = (sk.pop(), certs.is_empty()) { (sk, certs) } else { diff --git a/src/wait.rs b/src/wait.rs index 9beb41b..e96c904 100644 --- a/src/wait.rs +++ b/src/wait.rs @@ -1,26 +1,53 @@ +use std::future::Future; use std::sync::Arc; -use std::thread; -use std::time::{Duration, Instant}; +use std::task::{Context, Poll}; +use std::time::Duration; -use futures::executor::{self, Notify}; -use futures::{Async, Future, Poll, Stream}; -use tokio_executor::{enter, EnterError}; +use tokio::clock; +use tokio_executor::{ + enter, + park::{Park, ParkThread, Unpark, UnparkThread}, + EnterError, +}; -pub(crate) fn timeout(fut: F, timeout: Option) -> Result> +pub(crate) fn timeout(fut: F, timeout: Option) -> Result> where - F: Future, + F: Future>, { - let mut spawn = executor::spawn(fut); - block_on(timeout, |notify| spawn.poll_future_notify(notify, 0)) -} + let _entered = enter().map_err(Waited::Executor)?; + let deadline = timeout.map(|d| { + log::trace!("wait at most {:?}", d); + clock::now() + d + }); -pub(crate) fn stream(stream: S, timeout: Option) -> WaitStream -where - S: Stream, -{ - WaitStream { - stream: executor::spawn(stream), - timeout, + let mut park = ParkThread::new(); + // Arc shouldn't be necessary, since UnparkThread is reference counted internally, + // but let's just stay safe for now. + let waker = futures::task::waker(Arc::new(UnparkWaker(park.unpark()))); + let mut cx = Context::from_waker(&waker); + + futures::pin_mut!(fut); + + loop { + match fut.as_mut().poll(&mut cx) { + Poll::Ready(Ok(val)) => return Ok(val), + Poll::Ready(Err(err)) => return Err(Waited::Inner(err)), + Poll::Pending => (), // fallthrough + } + + if let Some(deadline) = deadline { + let now = clock::now(); + if now >= deadline { + log::trace!("wait timeout exceeded"); + return Err(Waited::TimedOut); + } + + log::trace!("park timeout {:?}", deadline - now); + park.park_timeout(deadline - now) + .expect("ParkThread doesn't error"); + } else { + park.park().expect("ParkThread doesn't error"); + } } } @@ -31,71 +58,10 @@ pub(crate) enum Waited { Inner(E), } -impl From for Waited { - fn from(err: E) -> Waited { - Waited::Inner(err) - } -} - -pub(crate) struct WaitStream { - stream: executor::Spawn, - timeout: Option, -} - -impl Iterator for WaitStream -where - S: Stream, -{ - type Item = Result>; - - fn next(&mut self) -> Option { - let res = block_on(self.timeout, |notify| { - self.stream.poll_stream_notify(notify, 0) - }); - - match res { - Ok(Some(val)) => Some(Ok(val)), - Ok(None) => None, - Err(err) => Some(Err(err)), - } - } -} - -struct ThreadNotify { - thread: thread::Thread, -} - -impl Notify for ThreadNotify { - fn notify(&self, _id: usize) { - self.thread.unpark(); - } -} - -fn block_on(timeout: Option, mut poll: F) -> Result> -where - F: FnMut(&Arc) -> Poll, -{ - let _entered = enter().map_err(Waited::Executor)?; - let deadline = timeout.map(|d| Instant::now() + d); - let notify = Arc::new(ThreadNotify { - thread: thread::current(), - }); - - loop { - match poll(¬ify)? 
{ - Async::Ready(val) => return Ok(val), - Async::NotReady => {} - } - - if let Some(deadline) = deadline { - let now = Instant::now(); - if now >= deadline { - return Err(Waited::TimedOut); - } - - thread::park_timeout(deadline - now); - } else { - thread::park(); - } +struct UnparkWaker(UnparkThread); + +impl futures::task::ArcWake for UnparkWaker { + fn wake_by_ref(arc_self: &Arc) { + arc_self.0.unpark(); } } diff --git a/tests/async.rs b/tests/async.rs index 64073a7..d8a6ea7 100644 --- a/tests/async.rs +++ b/tests/async.rs @@ -1,29 +1,28 @@ #[macro_use] mod support; -use std::io::{self, Write}; +use std::io::Write; use std::time::Duration; -use futures::{Future, Stream}; -use tokio::runtime::current_thread::Runtime; +use futures::TryStreamExt; use reqwest::r#async::multipart::{Form, Part}; -use reqwest::r#async::{Chunk, Client}; +use reqwest::r#async::{Body, Client}; use bytes::Bytes; -#[test] -fn gzip_response() { - gzip_case(10_000, 4096); +#[tokio::test] +async fn gzip_response() { + gzip_case(10_000, 4096).await; } -#[test] -fn gzip_single_byte_chunks() { - gzip_case(10, 1); +#[tokio::test] +async fn gzip_single_byte_chunks() { + gzip_case(10, 1).await; } -#[test] -fn response_text() { +#[tokio::test] +async fn response_text() { let _ = env_logger::try_init(); let server = server! { @@ -43,24 +42,19 @@ fn response_text() { " }; - let mut rt = Runtime::new().expect("new rt"); - let client = Client::new(); - let res_future = client + let mut res = client .get(&format!("http://{}/text", server.addr())) .send() - .and_then(|mut res| res.text()) - .and_then(|text| { - assert_eq!("Hello", text); - Ok(()) - }); - - rt.block_on(res_future).unwrap(); + .await + .expect("Failed to get"); + let text = res.text().await.expect("Failed to get text"); + assert_eq!("Hello", text); } -#[test] -fn response_json() { +#[tokio::test] +async fn response_json() { let _ = env_logger::try_init(); let server = server! 
{ @@ -80,28 +74,24 @@ fn response_json() { " }; - let mut rt = Runtime::new().expect("new rt"); - let client = Client::new(); - let res_future = client + let mut res = client .get(&format!("http://{}/json", server.addr())) .send() - .and_then(|mut res| res.json::()) - .and_then(|text| { - assert_eq!("Hello", text); - Ok(()) - }); - - rt.block_on(res_future).unwrap(); + .await + .expect("Failed to get"); + let text = res.json::().await.expect("Failed to get json"); + assert_eq!("Hello", text); } -#[test] -fn multipart() { +#[tokio::test] +async fn multipart() { let _ = env_logger::try_init(); - let stream = - futures::stream::once::<_, hyper::Error>(Ok(Chunk::from("part1 part2".to_owned()))); + let stream = futures::stream::once(futures::future::ready::>(Ok( + hyper::Chunk::from("part1 part2".to_owned()), + ))); let part = Part::stream(stream); let form = Form::new().text("foo", "bar").part("part_stream", part); @@ -153,22 +143,20 @@ fn multipart() { let url = format!("http://{}/multipart/1", server.addr()); - let mut rt = Runtime::new().expect("new rt"); - let client = Client::new(); - let res_future = client.post(&url).multipart(form).send().and_then(|res| { - assert_eq!(res.url().as_str(), &url); - assert_eq!(res.status(), reqwest::StatusCode::OK); - - Ok(()) - }); - - rt.block_on(res_future).unwrap(); + let res = client + .post(&url) + .multipart(form) + .send() + .await + .expect("Failed to post multipart"); + assert_eq!(res.url().as_str(), &url); + assert_eq!(res.status(), reqwest::StatusCode::OK); } -#[test] -fn request_timeout() { +#[tokio::test] +async fn request_timeout() { let _ = env_logger::try_init(); let server = server! { @@ -189,24 +177,23 @@ fn request_timeout() { read_timeout: Duration::from_secs(2) }; - let mut rt = Runtime::new().expect("new rt"); - let client = Client::builder() .timeout(Duration::from_millis(500)) .build() .unwrap(); let url = format!("http://{}/slow", server.addr()); - let fut = client.get(&url).send(); - let err = rt.block_on(fut).unwrap_err(); + let res = client.get(&url).send().await; + + let err = res.unwrap_err(); assert!(err.is_timeout()); assert_eq!(err.url().map(|u| u.as_str()), Some(url.as_str())); } -#[test] -fn response_timeout() { +#[tokio::test] +async fn response_timeout() { let _ = env_logger::try_init(); let server = server! 
{ @@ -227,25 +214,21 @@ fn response_timeout() { write_timeout: Duration::from_secs(2) }; - let mut rt = Runtime::new().expect("new rt"); - let client = Client::builder() .timeout(Duration::from_millis(500)) .build() .unwrap(); let url = format!("http://{}/slow", server.addr()); - let fut = client - .get(&url) - .send() - .and_then(|res| res.into_body().concat2()); + let res = client.get(&url).send().await.expect("Failed to get"); + let body: Result<_, _> = res.into_body().try_concat().await; - let err = rt.block_on(fut).unwrap_err(); + let err = body.unwrap_err(); assert!(err.is_timeout()); } -fn gzip_case(response_size: usize, chunk_size: usize) { +async fn gzip_case(response_size: usize, chunk_size: usize) { let content: String = (0..response_size) .into_iter() .map(|i| format!("test {}", i)) @@ -284,37 +267,26 @@ fn gzip_case(response_size: usize, chunk_size: usize) { response: response }; - let mut rt = Runtime::new().expect("new rt"); - let client = Client::new(); - let res_future = client + let mut res = client .get(&format!("http://{}/gzip", server.addr())) .send() - .and_then(|res| { - let body = res.into_body(); - body.concat2() - }) - .and_then(|buf| { - let body = std::str::from_utf8(&buf).unwrap(); + .await + .expect("response"); - assert_eq!(body, &content); - - Ok(()) - }); - - rt.block_on(res_future).unwrap(); + let body = res.text().await.expect("text"); + assert_eq!(body, content); } -#[test] -fn body_stream() { +#[tokio::test] +async fn body_stream() { let _ = env_logger::try_init(); - let source: Box + Send> = - Box::new(futures::stream::iter_ok::<_, io::Error>(vec![ - Bytes::from_static(b"123"), - Bytes::from_static(b"4567"), - ])); + let source = futures::stream::iter::>>(vec![ + Ok(Bytes::from_static(b"123")), + Ok(Bytes::from_static(b"4567")), + ]); let expected_body = "3\r\n123\r\n4\r\n4567\r\n0\r\n\r\n"; @@ -339,16 +311,15 @@ fn body_stream() { let url = format!("http://{}/post", server.addr()); - let mut rt = Runtime::new().expect("new rt"); - let client = Client::new(); - let res_future = client.post(&url).body(source).send().and_then(|res| { - assert_eq!(res.url().as_str(), &url); - assert_eq!(res.status(), reqwest::StatusCode::OK); + let res = client + .post(&url) + .body(Body::wrap_stream(source)) + .send() + .await + .expect("Failed to post"); - Ok(()) - }); - - rt.block_on(res_future).unwrap(); + assert_eq!(res.url().as_str(), &url); + assert_eq!(res.status(), reqwest::StatusCode::OK); } diff --git a/tests/client.rs b/tests/client.rs index 6f1bdf4..85ba691 100644 --- a/tests/client.rs +++ b/tests/client.rs @@ -1,8 +1,6 @@ #[macro_use] mod support; -use std::io::Read; - #[test] fn test_response_text() { let server = server! 
{ @@ -137,9 +135,7 @@ fn test_response_copy_to() { &"5" ); - let mut buf: Vec = vec![]; - res.copy_to(&mut buf).unwrap(); - assert_eq!(b"Hello", buf.as_slice()); + assert_eq!("Hello".to_owned(), res.text().unwrap()); } #[test] @@ -173,9 +169,7 @@ fn test_get() { ); assert_eq!(res.remote_addr(), Some(server.addr())); - let mut buf = [0; 1024]; - let n = res.read(&mut buf).unwrap(); - assert_eq!(n, 0) + assert_eq!(res.text().unwrap().len(), 0) } #[test] @@ -214,9 +208,7 @@ fn test_post() { &"0" ); - let mut buf = [0; 1024]; - let n = res.read(&mut buf).unwrap(); - assert_eq!(n, 0) + assert_eq!(res.text().unwrap().len(), 0) } #[test] diff --git a/tests/timeouts.rs b/tests/timeouts.rs index 0c57eae..707d08e 100644 --- a/tests/timeouts.rs +++ b/tests/timeouts.rs @@ -1,9 +1,10 @@ #[macro_use] mod support; -use std::io::Read; use std::time::Duration; +/// Tests that internal client future cancels when the oneshot channel +/// is canceled. #[test] fn timeout_closes_connection() { let _ = env_logger::try_init(); @@ -156,7 +157,6 @@ fn test_read_timeout() { &"5" ); - let mut buf = [0; 1024]; - let err = res.read(&mut buf).unwrap_err(); + let err = res.text().unwrap_err(); assert_eq!(err.to_string(), "timed out"); }
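
The blocking client in this patch no longer relies on futures 0.1's `executor::spawn`/`Notify`; the new `src/wait.rs` drives a std `Future` by repeatedly polling it on the calling thread and parking that thread between polls, with an optional deadline. As a rough, self-contained sketch of that idea — not the patch's exact code, which uses tokio's `ParkThread`/`UnparkThread`, `tokio::clock`, and a `Waited` error type — something like the following works against the futures 0.3 API:

use std::future::Future;
use std::sync::Arc;
use std::task::{Context, Poll};
use std::thread::{self, Thread};
use std::time::{Duration, Instant};

use futures::task::ArcWake;

// Waker that unparks the thread currently blocked in `block_on_with_timeout`.
struct ThreadWaker(Thread);

impl ArcWake for ThreadWaker {
    fn wake_by_ref(arc_self: &Arc<Self>) {
        arc_self.0.unpark();
    }
}

// Polls `fut` on the current thread, parking between polls.
// Returns `None` if the optional deadline elapses before the future completes.
fn block_on_with_timeout<F: Future>(fut: F, timeout: Option<Duration>) -> Option<F::Output> {
    let deadline = timeout.map(|dur| Instant::now() + dur);
    let waker = futures::task::waker(Arc::new(ThreadWaker(thread::current())));
    let mut cx = Context::from_waker(&waker);

    futures::pin_mut!(fut);

    loop {
        if let Poll::Ready(val) = fut.as_mut().poll(&mut cx) {
            return Some(val);
        }
        match deadline {
            Some(deadline) => {
                let now = Instant::now();
                if now >= deadline {
                    return None; // timed out
                }
                // Sleep until woken by the waker or until the remaining time elapses.
                thread::park_timeout(deadline - now);
            }
            None => thread::park(),
        }
    }
}

For an already-ready future, e.g. `block_on_with_timeout(async { 2 + 2 }, None)`, the first poll returns `Ready` and the function yields `Some(4)` without ever parking; a pending future only wakes the thread when its waker fires or the deadline passes.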