refactor all to async/await (#617)

Co-authored-by: Danny Browning <danny.browning@protectwise.com>
Co-authored-by: Daniel Eades <danieleades@hotmail.com>
This commit is contained in:
Sean McArthur
2019-09-06 17:22:56 -07:00
committed by GitHub
parent d7fcd8ac2e
commit ba7b2a754e
30 changed files with 1106 additions and 1430 deletions

View File

@@ -8,7 +8,7 @@ environment:
MINGW_PATH: 'C:\MinGW\bin' MINGW_PATH: 'C:\MinGW\bin'
install: install:
- curl -sSf -o rustup-init.exe https://win.rustup.rs/ - curl -sSf -o rustup-init.exe https://win.rustup.rs/
- rustup-init.exe -y --default-host %TARGET% - rustup-init.exe -y --default-toolchain nightly --default-host %TARGET%
- set PATH=%PATH%;C:\Users\appveyor\.cargo\bin - set PATH=%PATH%;C:\Users\appveyor\.cargo\bin
- if defined MINGW_PATH set PATH=%PATH%;%MINGW_PATH% - if defined MINGW_PATH set PATH=%PATH%;%MINGW_PATH%
- rustc -vV - rustc -vV

View File

@@ -1,40 +1,44 @@
language: rust language: rust
matrix: matrix:
fast_finish: true fast_finish: true
allow_failures: #allow_failures:
- rust: nightly # - rust: nightly
include: include:
- os: osx - os: osx
rust: stable rust: nightly
#rust: stable
- rust: stable #- rust: stable
- rust: beta #- rust: beta
- rust: nightly - rust: nightly
# Disable default-tls # Disable default-tls
- rust: stable #- rust: stable
- rust: nightly
env: FEATURES="--no-default-features" env: FEATURES="--no-default-features"
# rustls-tls # rustls-tls
- rust: stable #- rust: stable
env: FEATURES="--no-default-features --features rustls-tls" #- rust: nightly
# env: FEATURES="--no-default-features --features rustls-tls"
# default-tls and rustls-tls # default-tls and rustls-tls
- rust: stable #- rust: stable
env: FEATURES="--features rustls-tls" #- rust: nightly
# env: FEATURES="--features rustls-tls"
# default-tls, rustls, and socks! # socks
- rust: stable #- rust: stable
env: FEATURES="--features rustls-tls,socks" #- rust: nightly
# env: FEATURES="--features socks"
- rust: stable #- rust: stable
env: FEATURES="--features hyper-011" #- rust: nightly
# env: FEATURES="--features trust-dns"
- rust: stable
env: FEATURES="--features trust-dns"
# android # android
- rust: stable #- rust: stable
- rust: nightly
env: TARGET=aarch64-linux-android env: TARGET=aarch64-linux-android
before_install: before_install:
- wget https://dl.google.com/android/repository/android-ndk-r19c-linux-x86_64.zip; - wget https://dl.google.com/android/repository/android-ndk-r19c-linux-x86_64.zip;
@@ -45,9 +49,16 @@ matrix:
# disable default-tls feature since cross-compiling openssl is dragons # disable default-tls feature since cross-compiling openssl is dragons
script: cargo build --target "$TARGET" --no-default-features script: cargo build --target "$TARGET" --no-default-features
# Check rustfmt
- name: "rustfmt check"
rust: stable
install: rustup component add rustfmt
script: cargo fmt -- --check
# minimum version # minimum version
- rust: 1.34.0 #- rust: 1.39.0
script: cargo build # script: cargo build
sudo: false sudo: false
dist: trusty dist: trusty
@@ -55,9 +66,6 @@ dist: trusty
env: env:
global: global:
- REQWEST_TEST_BODY_FULL=1 - REQWEST_TEST_BODY_FULL=1
before_script:
- rustup component add rustfmt
script: script:
- cargo fmt -- --check
- cargo build $FEATURES - cargo build $FEATURES
- cargo test -v $FEATURES -- --test-threads=1 - cargo test -v $FEATURES -- --test-threads=1

View File

@@ -20,62 +20,66 @@ all-features = true
base64 = "0.10" base64 = "0.10"
bytes = "0.4" bytes = "0.4"
encoding_rs = "0.8" encoding_rs = "0.8"
futures = "0.1.23" futures-preview = { version = "=0.3.0-alpha.18" }
http = "0.1.15" http = "0.1.15"
hyper = "0.12.22" hyper = "=0.13.0-alpha.1"
flate2 = { version = "^1.0.7", default-features = false, features = ["rust_backend"] }
log = "0.4" log = "0.4"
mime = "0.3.7" mime = "0.3.7"
mime_guess = "2.0" mime_guess = "2.0"
percent-encoding = "2.1" percent-encoding = "2.1"
tokio = { version = "=0.2.0-alpha.4", default-features = false, features = ["rt-full", "tcp"] }
tokio-executor = "=0.2.0-alpha.4"
url = "2.1"
uuid = { version = "0.7", features = ["v4"] }
time = "0.1.42"
# TODO: candidates for optional features
async-compression = { version = "0.1.0-alpha.4", default-features = false, features = ["gzip", "stream"] }
cookie_store = "0.9.0"
cookie = "0.12.0"
serde = "1.0" serde = "1.0"
serde_json = "1.0" serde_json = "1.0"
serde_urlencoded = "0.6.1" serde_urlencoded = "0.6.1"
tokio = { version = "0.1.7", default-features = false, features = ["rt-full", "tcp"] }
tokio-executor = "0.1.4" # a minimum version so trust-dns-resolver compiles
tokio-io = "0.1"
tokio-threadpool = "0.1.8" # a minimum version so tokio compiles
tokio-timer = "0.2.6" # a minimum version so trust-dns-resolver compiles
url = "2.1"
uuid = { version = "0.7", features = ["v4"] }
# Optional deps... # Optional deps...
hyper-old-types = { version = "0.11", optional = true, features = ["compat"] } ## default-tls
hyper-rustls = { version = "^0.17.1", optional = true } hyper-tls = { version = "=0.4.0-alpha.1", optional = true }
hyper-tls = { version = "0.3.2", optional = true }
native-tls = { version = "0.2", optional = true } native-tls = { version = "0.2", optional = true }
rustls = { version = "0.16", features = ["dangerous_configuration"], optional = true } tokio-tls = { version = "=0.3.0-alpha.4", optional = true }
socks = { version = "0.3.2", optional = true }
tokio-rustls = { version = "0.10", optional = true } ## rustls-tls
trust-dns-resolver = { version = "0.11", optional = true } #hyper-rustls = { git = "https://github.com/dbcfd/hyper-rustls.git", branch = "master", optional = true }
webpki-roots = { version = "0.17", optional = true } #rustls = { version = "0.16", features = ["dangerous_configuration"], optional = true }
cookie_store = "0.9.0" #tokio-rustls = { version = "=0.12.0-alpha.2", optional = true }
cookie = "0.12.0" #webpki-roots = { version = "0.17", optional = true }
time = "0.1.42"
## socks
#socks = { version = "0.3.2", optional = true }
## trust-dns
#trust-dns-resolver = { version = "0.11", optional = true }
[dev-dependencies] [dev-dependencies]
env_logger = "0.6" env_logger = "0.6"
serde = { version = "1.0", features = ["derive"] } serde = { version = "1.0", features = ["derive"] }
tokio = { version = "0.1.7", default-features = false, features = ["rt-full", "tcp", "fs"] }
tokio-tcp = "0.1"
libflate = "0.1" libflate = "0.1"
doc-comment = "0.3" doc-comment = "0.3"
bytes = "0.4" bytes = "0.4"
tokio-fs = { version = "=0.2.0-alpha.4" }
[features] [features]
default = ["default-tls"] default = ["default-tls"]
tls = [] tls = []
default-tls = ["hyper-tls", "native-tls", "tls"] default-tls = ["hyper-tls", "native-tls", "tls", "tokio-tls"]
default-tls-vendored = ["default-tls", "native-tls/vendored"] default-tls-vendored = ["default-tls", "native-tls/vendored"]
rustls-tls = ["hyper-rustls", "tokio-rustls", "webpki-roots", "rustls", "tls"] #rustls-tls = ["hyper-rustls", "tokio-rustls", "webpki-roots", "rustls", "tls"]
trust-dns = ["trust-dns-resolver"] #trust-dns = ["trust-dns-resolver"]
hyper-011 = ["hyper-old-types"]
[target.'cfg(windows)'.dependencies] [target.'cfg(windows)'.dependencies]
winreg = "0.6" winreg = "0.6"

View File

@@ -1,29 +1,16 @@
#![deny(warnings)] #![deny(warnings)]
use futures::{Future, Stream}; use reqwest::r#async::Client;
use reqwest::r#async::{Client, Decoder};
use std::io::{self, Cursor};
use std::mem;
fn fetch() -> impl Future<Item = (), Error = ()> { #[tokio::main]
Client::new() async fn main() -> Result<(), reqwest::Error> {
.get("https://hyper.rs") let mut res = Client::new().get("https://hyper.rs").send().await?;
.send()
.and_then(|mut res| {
println!("{}", res.status());
let body = mem::replace(res.body_mut(), Decoder::empty()); println!("Status: {}", res.status());
body.concat2()
}) let body = res.text().await?;
.map_err(|err| println!("request error: {}", err))
.map(|body| { println!("Body:\n\n{}", body);
let mut body = Cursor::new(body);
let _ = io::copy(&mut body, &mut io::stdout()).map_err(|err| { Ok(())
println!("stdout error: {}", err);
});
})
}
fn main() {
tokio::run(fetch());
} }

View File

@@ -1,8 +1,8 @@
#![deny(warnings)] #![deny(warnings)]
use futures::Future;
use reqwest::r#async::{Client, Response}; use reqwest::r#async::{Client, Response};
use serde::Deserialize; use serde::Deserialize;
use std::future::Future;
#[derive(Deserialize, Debug)] #[derive(Deserialize, Debug)]
struct Slideshow { struct Slideshow {
@@ -15,26 +15,27 @@ struct SlideshowContainer {
slideshow: Slideshow, slideshow: Slideshow,
} }
fn fetch() -> impl Future<Item = (), Error = ()> { async fn into_json<F>(f: F) -> Result<SlideshowContainer, reqwest::Error>
where
F: Future<Output = Result<Response, reqwest::Error>>,
{
let mut resp = f.await?;
resp.json::<SlideshowContainer>().await
}
#[tokio::main]
async fn main() -> Result<(), reqwest::Error> {
let client = Client::new(); let client = Client::new();
let json = |mut res: Response| res.json::<SlideshowContainer>(); let request1 = client.get("https://httpbin.org/json").send();
let request1 = client.get("https://httpbin.org/json").send().and_then(json); let request2 = client.get("https://httpbin.org/json").send();
let request2 = client.get("https://httpbin.org/json").send().and_then(json); let (try_json1, try_json2) =
futures::future::join(into_json(request1), into_json(request2)).await;
request1 println!("{:?}", try_json1?);
.join(request2) println!("{:?}", try_json2?);
.map(|(res1, res2)| {
println!("{:?}", res1); Ok(())
println!("{:?}", res2);
})
.map_err(|err| {
println!("stdout error: {}", err);
})
}
fn main() {
tokio::run(fetch());
} }

View File

@@ -1,74 +0,0 @@
#![deny(warnings)]
use std::io::{self, Cursor};
use std::mem;
use std::path::Path;
use bytes::Bytes;
use futures::{try_ready, Async, Future, Poll, Stream};
use reqwest::r#async::{Client, Decoder};
use tokio::fs::File;
use tokio::io::AsyncRead;
const CHUNK_SIZE: usize = 1024;
struct FileSource {
inner: File,
}
impl FileSource {
fn new(file: File) -> FileSource {
FileSource { inner: file }
}
}
impl Stream for FileSource {
type Item = Bytes;
type Error = io::Error;
fn poll(&mut self) -> Poll<Option<Self::Item>, Self::Error> {
let mut buf = [0; CHUNK_SIZE];
let size = try_ready!(self.inner.poll_read(&mut buf));
if size > 0 {
Ok(Async::Ready(Some(buf[0..size].into())))
} else {
Ok(Async::Ready(None))
}
}
}
fn post<P>(path: P) -> impl Future<Item = (), Error = ()>
where
P: AsRef<Path>,
{
File::open(path.as_ref().to_owned())
.map_err(|err| println!("request error: {}", err))
.and_then(|file| {
let source: Box<dyn Stream<Item = Bytes, Error = io::Error> + Send> =
Box::new(FileSource::new(file));
Client::new()
.post("https://httpbin.org/post")
.body(source)
.send()
.and_then(|mut res| {
println!("{}", res.status());
let body = mem::replace(res.body_mut(), Decoder::empty());
body.concat2()
})
.map_err(|err| println!("request error: {}", err))
.map(|body| {
let mut body = Cursor::new(body);
let _ = io::copy(&mut body, &mut io::stdout()).map_err(|err| {
println!("stdout error: {}", err);
});
})
})
}
fn main() {
let pool = tokio_threadpool::ThreadPool::new();
let path = concat!(env!("CARGO_MANIFEST_DIR"), "/LICENSE-APACHE");
tokio::run(pool.spawn_handle(post(path)));
}

View File

@@ -3,12 +3,11 @@
//! This is useful for some ad-hoc experiments and situations when you don't //! This is useful for some ad-hoc experiments and situations when you don't
//! really care about the structure of the JSON and just need to display it or //! really care about the structure of the JSON and just need to display it or
//! process it at runtime. //! process it at runtime.
use serde_json::json;
fn main() -> Result<(), reqwest::Error> { fn main() -> Result<(), reqwest::Error> {
let echo_json: serde_json::Value = reqwest::Client::new() let echo_json: serde_json::Value = reqwest::Client::new()
.post("https://jsonplaceholder.typicode.com/posts") .post("https://jsonplaceholder.typicode.com/posts")
.json(&json!({ .json(&serde_json::json!({
"title": "Reqwest.rs", "title": "Reqwest.rs",
"body": "https://docs.rs/reqwest", "body": "https://docs.rs/reqwest",
"userId": 1 "userId": 1

View File

@@ -1,6 +1,5 @@
#![deny(warnings)]
//! `cargo run --example simple` //! `cargo run --example simple`
#![deny(warnings)]
fn main() -> Result<(), Box<dyn std::error::Error>> { fn main() -> Result<(), Box<dyn std::error::Error>> {
env_logger::init(); env_logger::init();
@@ -13,7 +12,7 @@ fn main() -> Result<(), Box<dyn std::error::Error>> {
println!("Headers:\n{:?}", res.headers()); println!("Headers:\n{:?}", res.headers());
// copy the response body directly to stdout // copy the response body directly to stdout
std::io::copy(&mut res, &mut std::io::stdout())?; res.copy_to(&mut std::io::stdout())?;
println!("\n\nDone."); println!("\n\nDone.");
Ok(()) Ok(())

View File

@@ -0,0 +1,101 @@
#![deny(warnings)]
use std::io::{self, Cursor};
use std::mem;
use std::path::Path;
use std::pin::Pin;
use std::task::{Context, Poll};
use futures::{Stream, TryStreamExt};
use reqwest::r#async::{Body, Client, Decoder};
use tokio_fs::File;
use tokio::io::AsyncRead;
use failure::Fail;
#[derive(Debug, Fail)]
pub enum Error {
#[fail(display = "Io Error")]
Io(#[fail(cause)] std::io::Error),
#[fail(display = "Reqwest error")]
Reqwest(#[fail(cause)] reqwest::Error),
}
unsafe impl Send for Error {}
unsafe impl Sync for Error {}
struct AsyncReadWrapper<T> {
inner: T,
}
impl<T> AsyncReadWrapper<T> {
fn inner(self: Pin<&mut Self>) -> Pin<&mut T> {
unsafe {
Pin::map_unchecked_mut(self, |x| &mut x.inner)
}
}
}
impl<T> Stream for AsyncReadWrapper<T>
where T: AsyncRead
{
type Item = Result<hyper::Chunk, failure::Compat<Error>>;
fn poll_next(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
) -> Poll<Option<Self::Item>> {
let mut buf = vec![];
loop {
let mut read_buf = vec![];
match self.as_mut().inner().as_mut().poll_read(cx, &mut read_buf) {
Poll::Pending => {
if buf.is_empty() {
return Poll::Pending;
} else {
return Poll::Ready(Some(Ok(buf.into())));
}
}
Poll::Ready(Err(e)) => {
return Poll::Ready(Some(Err(Error::Io(e).compat())))
},
Poll::Ready(Ok(n)) => {
buf.extend_from_slice(&read_buf[..n]);
if buf.is_empty() && n == 0 {
return Poll::Ready(None);
} else {
return Poll::Ready(Some(Ok(buf.into())));
}
}
}
}
}
}
async fn post<P>(path: P) -> Result<(), Error>
where
P: AsRef<Path> + Send + Unpin + 'static,
{
let source = File::open(path)
.await.map_err(Error::Io)?;
let wrapper = AsyncReadWrapper { inner: source };
let mut res = Client::new()
.post("https://httpbin.org/post")
.body(Body::wrap_stream(wrapper))
.send()
.await.map_err(Error::Reqwest)?;
println!("{}", res.status());
let body = mem::replace(res.body_mut(), Decoder::empty());
let body: Result<_, _> = body.try_concat().await;
let mut body = Cursor::new(body.map_err(Error::Reqwest)?);
io::copy(&mut body, &mut io::stdout()).map_err(Error::Io)?;
Ok(())
}
#[tokio::main]
async fn main() -> Result<(), Error> {
let path = concat!(env!("CARGO_MANIFEST_DIR"), "/LICENSE-APACHE");
post(path).await
}

View File

@@ -1,8 +1,10 @@
use std::fmt;
use bytes::{Buf, Bytes}; use bytes::{Buf, Bytes};
use futures::{try_ready, Async, Future, Poll, Stream}; use futures::Stream;
use hyper::body::Payload; use hyper::body::Payload;
use std::fmt;
use std::future::Future;
use std::pin::Pin;
use std::task::{Context, Poll};
use tokio::timer::Delay; use tokio::timer::Delay;
/// An asynchronous `Stream`. /// An asynchronous `Stream`.
@@ -22,10 +24,38 @@ impl Body {
pub(crate) fn content_length(&self) -> Option<u64> { pub(crate) fn content_length(&self) -> Option<u64> {
match self.inner { match self.inner {
Inner::Reusable(ref bytes) => Some(bytes.len() as u64), Inner::Reusable(ref bytes) => Some(bytes.len() as u64),
Inner::Hyper { ref body, .. } => body.content_length(), Inner::Hyper { ref body, .. } => body.size_hint().exact(),
} }
} }
/// Wrap a futures `Stream` in a box inside `Body`.
///
/// # Example
///
/// ```
/// # use reqwest::r#async::Body;
/// # use futures;
/// # fn main() {
/// let chunks: Vec<Result<_, ::std::io::Error>> = vec![
/// Ok("hello"),
/// Ok(" "),
/// Ok("world"),
/// ];
///
/// let stream = futures::stream::iter(chunks);
///
/// let body = Body::wrap_stream(stream);
/// # }
/// ```
pub fn wrap_stream<S>(stream: S) -> Body
where
S: futures::TryStream + Send + Sync + 'static,
S::Error: Into<Box<dyn std::error::Error + Send + Sync>>,
hyper::Chunk: From<S::Ok>,
{
Body::wrap(hyper::body::Body::wrap_stream(stream))
}
#[inline] #[inline]
pub(crate) fn response(body: hyper::Body, timeout: Option<Delay>) -> Body { pub(crate) fn response(body: hyper::Body, timeout: Option<Delay>) -> Body {
Body { Body {
@@ -65,38 +95,45 @@ impl Body {
} }
} }
} }
fn inner(self: Pin<&mut Self>) -> Pin<&mut Inner> {
unsafe { Pin::map_unchecked_mut(self, |x| &mut x.inner) }
}
} }
impl Stream for Body { impl Stream for Body {
type Item = Chunk; type Item = Result<Chunk, crate::Error>;
type Error = crate::Error;
#[inline] #[inline]
fn poll(&mut self) -> Poll<Option<Self::Item>, Self::Error> { fn poll_next(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Option<Self::Item>> {
let opt = match self.inner { let opt_try_chunk = match self.inner().get_mut() {
Inner::Hyper { Inner::Hyper {
ref mut body, ref mut body,
ref mut timeout, ref mut timeout,
} => { } => {
if let Some(ref mut timeout) = timeout { if let Some(ref mut timeout) = timeout {
if let Async::Ready(()) = try_!(timeout.poll()) { if let Poll::Ready(()) = Pin::new(timeout).poll(cx) {
return Err(crate::error::timedout(None)); return Poll::Ready(Some(Err(crate::error::timedout(None))));
} }
} }
try_ready!(body.poll_data().map_err(crate::error::from)) futures::ready!(Pin::new(body).poll_data(cx)).map(|opt_chunk| {
opt_chunk
.map(|c| Chunk { inner: c })
.map_err(crate::error::from)
})
} }
Inner::Reusable(ref mut bytes) => { Inner::Reusable(ref mut bytes) => {
return if bytes.is_empty() { if bytes.is_empty() {
Ok(Async::Ready(None)) None
} else { } else {
let chunk = Chunk::from_chunk(bytes.clone()); let chunk = Chunk::from_chunk(bytes.clone());
*bytes = Bytes::new(); *bytes = Bytes::new();
Ok(Async::Ready(Some(chunk))) Some(Ok(chunk))
}; }
} }
}; };
Ok(Async::Ready(opt.map(|chunk| Chunk { inner: chunk }))) Poll::Ready(opt_try_chunk)
} }
} }
@@ -135,18 +172,6 @@ impl From<&'static str> for Body {
} }
} }
impl<I, E> From<Box<dyn Stream<Item = I, Error = E> + Send>> for Body
where
hyper::Chunk: From<I>,
I: 'static,
E: std::error::Error + Send + Sync + 'static,
{
#[inline]
fn from(s: Box<dyn Stream<Item = I, Error = E> + Send>) -> Body {
Body::wrap(hyper::Body::wrap_stream(s))
}
}
/// A chunk of bytes for a `Body`. /// A chunk of bytes for a `Body`.
/// ///
/// A `Chunk` can be treated like `&[u8]`. /// A `Chunk` can be treated like `&[u8]`.
@@ -247,6 +272,12 @@ impl From<Bytes> for Chunk {
} }
} }
impl From<Chunk> for Bytes {
fn from(chunk: Chunk) -> Bytes {
chunk.inner.into()
}
}
impl From<Chunk> for hyper::Chunk { impl From<Chunk> for hyper::Chunk {
fn from(val: Chunk) -> hyper::Chunk { fn from(val: Chunk) -> hyper::Chunk {
val.inner val.inner

View File

@@ -8,12 +8,14 @@ use crate::header::{
CONTENT_TYPE, LOCATION, PROXY_AUTHORIZATION, RANGE, REFERER, TRANSFER_ENCODING, USER_AGENT, CONTENT_TYPE, LOCATION, PROXY_AUTHORIZATION, RANGE, REFERER, TRANSFER_ENCODING, USER_AGENT,
}; };
use bytes::Bytes; use bytes::Bytes;
use futures::{Async, Future, Poll};
use http::Uri; use http::Uri;
use hyper::client::ResponseFuture; use hyper::client::ResponseFuture;
use mime; use mime;
#[cfg(feature = "default-tls")] #[cfg(feature = "default-tls")]
use native_tls::TlsConnector; use native_tls::TlsConnector;
use std::future::Future;
use std::pin::Pin;
use std::task::{Context, Poll};
use tokio::{clock, timer::Delay}; use tokio::{clock, timer::Delay};
use log::debug; use log::debug;
@@ -540,7 +542,10 @@ impl Client {
/// ///
/// This method fails if there was an error while sending request, /// This method fails if there was an error while sending request,
/// redirect loop was detected or redirect limit was exhausted. /// redirect loop was detected or redirect limit was exhausted.
pub fn execute(&self, request: Request) -> impl Future<Item = Response, Error = crate::Error> { pub fn execute(
&self,
request: Request,
) -> impl Future<Output = Result<Response, crate::Error>> {
self.execute_request(request) self.execute_request(request)
} }
@@ -593,7 +598,7 @@ impl Client {
let timeout = self let timeout = self
.inner .inner
.request_timeout .request_timeout
.map(|dur| Delay::new(clock::now() + dur)); .map(|dur| tokio::timer::delay(clock::now() + dur));
Pending { Pending {
inner: PendingInner::Request(PendingRequest { inner: PendingInner::Request(PendingRequest {
@@ -691,43 +696,65 @@ struct PendingRequest {
timeout: Option<Delay>, timeout: Option<Delay>,
} }
impl PendingRequest {
fn in_flight(self: Pin<&mut Self>) -> Pin<&mut ResponseFuture> {
unsafe { Pin::map_unchecked_mut(self, |x| &mut x.in_flight) }
}
fn timeout(self: Pin<&mut Self>) -> Pin<&mut Option<Delay>> {
unsafe { Pin::map_unchecked_mut(self, |x| &mut x.timeout) }
}
fn urls(self: Pin<&mut Self>) -> &mut Vec<Url> {
unsafe { &mut Pin::get_unchecked_mut(self).urls }
}
fn headers(self: Pin<&mut Self>) -> &mut HeaderMap {
unsafe { &mut Pin::get_unchecked_mut(self).headers }
}
}
impl Pending { impl Pending {
pub(super) fn new_err(err: crate::Error) -> Pending { pub(super) fn new_err(err: crate::Error) -> Pending {
Pending { Pending {
inner: PendingInner::Error(Some(err)), inner: PendingInner::Error(Some(err)),
} }
} }
fn inner(self: Pin<&mut Self>) -> Pin<&mut PendingInner> {
unsafe { Pin::map_unchecked_mut(self, |x| &mut x.inner) }
}
} }
impl Future for Pending { impl Future for Pending {
type Item = Response; type Output = Result<Response, crate::Error>;
type Error = crate::Error;
fn poll(&mut self) -> Poll<Self::Item, Self::Error> { fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
match self.inner { let inner = self.inner();
PendingInner::Request(ref mut req) => req.poll(), match inner.get_mut() {
PendingInner::Error(ref mut err) => { PendingInner::Request(ref mut req) => Pin::new(req).poll(cx),
Err(err.take().expect("Pending error polled more than once")) PendingInner::Error(ref mut err) => Poll::Ready(Err(err
} .take()
.expect("Pending error polled more than once"))),
} }
} }
} }
impl Future for PendingRequest { impl Future for PendingRequest {
type Item = Response; type Output = Result<Response, crate::Error>;
type Error = crate::Error;
fn poll(&mut self) -> Poll<Self::Item, Self::Error> { fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
if let Some(ref mut delay) = self.timeout { if let Some(delay) = self.as_mut().timeout().as_mut().as_pin_mut() {
if let Async::Ready(()) = try_!(delay.poll(), &self.url) { if let Poll::Ready(()) = delay.poll(cx) {
return Err(crate::error::timedout(Some(self.url.clone()))); return Poll::Ready(Err(crate::error::timedout(Some(self.url.clone()))));
} }
} }
loop { loop {
let res = match try_!(self.in_flight.poll(), &self.url) { let res = match self.as_mut().in_flight().as_mut().poll(cx) {
Async::Ready(res) => res, Poll::Ready(Err(e)) => return Poll::Ready(url_error!(e, &self.url)),
Async::NotReady => return Ok(Async::NotReady), Poll::Ready(Ok(res)) => res,
Poll::Pending => return Poll::Pending,
}; };
if let Some(store_wrapper) = self.client.cookie_store.as_ref() { if let Some(store_wrapper) = self.client.cookie_store.as_ref() {
let mut store = store_wrapper.write().unwrap(); let mut store = store_wrapper.write().unwrap();
@@ -795,7 +822,8 @@ impl Future for PendingRequest {
self.headers.insert(REFERER, referer); self.headers.insert(REFERER, referer);
} }
} }
self.urls.push(self.url.clone()); let url = self.url.clone();
self.as_mut().urls().push(url);
let action = self let action = self
.client .client
.redirect_policy .redirect_policy
@@ -805,7 +833,10 @@ impl Future for PendingRequest {
redirect::Action::Follow => { redirect::Action::Follow => {
self.url = loc; self.url = loc;
remove_sensitive_headers(&mut self.headers, &self.url, &self.urls); let mut headers =
std::mem::replace(self.as_mut().headers(), HeaderMap::new());
remove_sensitive_headers(&mut headers, &self.url, &self.urls);
debug!("redirecting to {:?} '{}'", self.method, self.url); debug!("redirecting to {:?} '{}'", self.method, self.url);
let uri = expect_uri(&self.url); let uri = expect_uri(&self.url);
let body = match self.body { let body = match self.body {
@@ -821,27 +852,30 @@ impl Future for PendingRequest {
// Add cookies from the cookie store. // Add cookies from the cookie store.
if let Some(cookie_store_wrapper) = self.client.cookie_store.as_ref() { if let Some(cookie_store_wrapper) = self.client.cookie_store.as_ref() {
let cookie_store = cookie_store_wrapper.read().unwrap(); let cookie_store = cookie_store_wrapper.read().unwrap();
add_cookie_header(&mut self.headers, &cookie_store, &self.url); add_cookie_header(&mut headers, &cookie_store, &self.url);
} }
*req.headers_mut() = self.headers.clone(); *req.headers_mut() = headers.clone();
self.in_flight = self.client.hyper.request(req); std::mem::swap(self.as_mut().headers(), &mut headers);
*self.as_mut().in_flight().get_mut() = self.client.hyper.request(req);
continue; continue;
} }
redirect::Action::Stop => { redirect::Action::Stop => {
debug!("redirect_policy disallowed redirection to '{}'", loc); debug!("redirect_policy disallowed redirection to '{}'", loc);
} }
redirect::Action::LoopDetected => { redirect::Action::LoopDetected => {
return Err(crate::error::loop_detected(self.url.clone())); return Poll::Ready(Err(crate::error::loop_detected(self.url.clone())));
} }
redirect::Action::TooManyRedirects => { redirect::Action::TooManyRedirects => {
return Err(crate::error::too_many_redirects(self.url.clone())); return Poll::Ready(Err(crate::error::too_many_redirects(
self.url.clone(),
)));
} }
} }
} }
} }
let res = Response::new(res, self.url.clone(), self.client.gzip, self.timeout.take()); let res = Response::new(res, self.url.clone(), self.client.gzip, self.timeout.take());
return Ok(Async::Ready(res)); return Poll::Ready(Ok(res));
} }
} }
} }

View File

@@ -9,25 +9,16 @@ Chunks are just passed along.
If the response is gzip, then the chunks are decompressed into a buffer. If the response is gzip, then the chunks are decompressed into a buffer.
Slices of that buffer are emitted as new chunks. Slices of that buffer are emitted as new chunks.
This module consists of a few main types:
- `ReadableChunks` is a `Read`-like wrapper around a stream
- `Decoder` is a layer over `ReadableChunks` that applies the right decompression
The following types directly support the gzip compression case:
- `Pending` is a non-blocking constructor for a `Decoder` in case the body needs to be checked for EOF
*/ */
use std::cmp;
use std::fmt; use std::fmt;
use std::io::{self, Read}; use std::future::Future;
use std::mem; use std::mem;
use std::pin::Pin;
use std::task::{Context, Poll};
use bytes::{Buf, BufMut, BytesMut}; use bytes::Bytes;
use flate2::read::GzDecoder; use futures::Stream;
use futures::{Async, Future, Poll, Stream};
use hyper::header::{CONTENT_ENCODING, CONTENT_LENGTH, TRANSFER_ENCODING}; use hyper::header::{CONTENT_ENCODING, CONTENT_LENGTH, TRANSFER_ENCODING};
use hyper::HeaderMap; use hyper::HeaderMap;
@@ -36,8 +27,6 @@ use log::warn;
use super::{Body, Chunk}; use super::{Body, Chunk};
use crate::error; use crate::error;
const INIT_BUFFER_SIZE: usize = 8192;
/// A response decompressor over a non-blocking stream of chunks. /// A response decompressor over a non-blocking stream of chunks.
/// ///
/// The inner decoder may be constructed asynchronously. /// The inner decoder may be constructed asynchronously.
@@ -49,22 +38,15 @@ enum Inner {
/// A `PlainText` decoder just returns the response content as is. /// A `PlainText` decoder just returns the response content as is.
PlainText(Body), PlainText(Body),
/// A `Gzip` decoder will uncompress the gzipped response content before returning it. /// A `Gzip` decoder will uncompress the gzipped response content before returning it.
Gzip(Gzip), Gzip(async_compression::stream::GzipDecoder<futures::stream::Peekable<BodyBytes>>),
/// A decoder that doesn't have a value yet. /// A decoder that doesn't have a value yet.
Pending(Pending), Pending(Pending),
} }
/// A future attempt to poll the response body for EOF so we know whether to use gzip or not. /// A future attempt to poll the response body for EOF so we know whether to use gzip or not.
struct Pending { struct Pending(futures::stream::Peekable<BodyBytes>);
body: ReadableChunks<Body>,
}
/// A gzip decoder that reads from a `flate2::read::GzDecoder` into a `BytesMut` and emits the results struct BodyBytes(Body);
/// as a `Chunk`.
struct Gzip {
inner: Box<GzDecoder<ReadableChunks<Body>>>,
buf: BytesMut,
}
impl fmt::Debug for Decoder { impl fmt::Debug for Decoder {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
@@ -86,7 +68,6 @@ impl Decoder {
/// A plain text decoder. /// A plain text decoder.
/// ///
/// This decoder will emit the underlying chunks as-is. /// This decoder will emit the underlying chunks as-is.
#[inline]
fn plain_text(body: Body) -> Decoder { fn plain_text(body: Body) -> Decoder {
Decoder { Decoder {
inner: Inner::PlainText(body), inner: Inner::PlainText(body),
@@ -96,12 +77,11 @@ impl Decoder {
/// A gzip decoder. /// A gzip decoder.
/// ///
/// This decoder will buffer and decompress chunks that are gzipped. /// This decoder will buffer and decompress chunks that are gzipped.
#[inline]
fn gzip(body: Body) -> Decoder { fn gzip(body: Body) -> Decoder {
use futures::stream::StreamExt;
Decoder { Decoder {
inner: Inner::Pending(Pending { inner: Inner::Pending(Pending(BodyBytes(body).peekable())),
body: ReadableChunks::new(body),
}),
} }
} }
@@ -148,189 +128,65 @@ impl Decoder {
} }
impl Stream for Decoder { impl Stream for Decoder {
type Item = Chunk; type Item = Result<Chunk, error::Error>;
type Error = error::Error;
fn poll(&mut self) -> Poll<Option<Self::Item>, Self::Error> { fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Option<Self::Item>> {
// Do a read or poll for a pendidng decoder value. // Do a read or poll for a pending decoder value.
let new_value = match self.inner { let new_value = match self.inner {
Inner::Pending(ref mut future) => match future.poll() { Inner::Pending(ref mut future) => match Pin::new(future).poll(cx) {
Ok(Async::Ready(inner)) => inner, Poll::Ready(Ok(inner)) => inner,
Ok(Async::NotReady) => return Ok(Async::NotReady), Poll::Ready(Err(e)) => return Poll::Ready(Some(Err(crate::error::from_io(e)))),
Err(e) => return Err(e), Poll::Pending => return Poll::Pending,
}, },
Inner::PlainText(ref mut body) => return body.poll(), Inner::PlainText(ref mut body) => return Pin::new(body).poll_next(cx),
Inner::Gzip(ref mut decoder) => return decoder.poll(), Inner::Gzip(ref mut decoder) => {
return match futures::ready!(Pin::new(decoder).poll_next(cx)) {
Some(Ok(bytes)) => Poll::Ready(Some(Ok(bytes.into()))),
Some(Err(err)) => Poll::Ready(Some(Err(crate::error::from_io(err)))),
None => Poll::Ready(None),
}
}
}; };
self.inner = new_value; self.inner = new_value;
self.poll() self.poll_next(cx)
} }
} }
impl Future for Pending { impl Future for Pending {
type Item = Inner; type Output = Result<Inner, std::io::Error>;
type Error = error::Error;
fn poll(&mut self) -> Poll<Self::Item, Self::Error> { fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
let body_state = match self.body.poll_stream() { use futures::stream::StreamExt;
Ok(Async::Ready(state)) => state,
Ok(Async::NotReady) => return Ok(Async::NotReady), match futures::ready!(Pin::new(&mut self.0).peek(cx)) {
Err(e) => return Err(e), Some(Ok(_)) => {
// fallthrough
}
Some(Err(_e)) => {
// error was just a ref, so we need to really poll to move it
return Poll::Ready(Err(futures::ready!(Pin::new(&mut self.0).poll_next(cx))
.expect("just peeked Some")
.unwrap_err()));
}
None => return Poll::Ready(Ok(Inner::PlainText(Body::empty()))),
}; };
let body = mem::replace(&mut self.body, ReadableChunks::new(Body::empty())); let body = mem::replace(&mut self.0, BodyBytes(Body::empty()).peekable());
match body_state { Poll::Ready(Ok(Inner::Gzip(
StreamState::Eof => Ok(Async::Ready(Inner::PlainText(Body::empty()))), async_compression::stream::GzipDecoder::new(body),
StreamState::HasMore => Ok(Async::Ready(Inner::Gzip(Gzip::new(body)))),
}
}
}
impl Gzip {
fn new(stream: ReadableChunks<Body>) -> Self {
Gzip {
buf: BytesMut::with_capacity(INIT_BUFFER_SIZE),
inner: Box::new(GzDecoder::new(stream)),
}
}
}
impl Stream for Gzip {
type Item = Chunk;
type Error = error::Error;
fn poll(&mut self) -> Poll<Option<Self::Item>, Self::Error> {
if self.buf.remaining_mut() == 0 {
self.buf.reserve(INIT_BUFFER_SIZE);
}
// The buffer contains uninitialised memory so getting a readable slice is unsafe.
// We trust the `flate2` and `miniz` writer not to read from the memory given.
//
// To be safe, this memory could be zeroed before passing to `flate2`.
// Otherwise we might need to deal with the case where `flate2` panics.
let read = try_io!(self.inner.read(unsafe { self.buf.bytes_mut() }));
if read == 0 {
// If GzDecoder reports EOF, it doesn't necessarily mean the
// underlying stream reached EOF (such as the `0\r\n\r\n`
// header meaning a chunked transfer has completed). If it
// isn't polled till EOF, the connection may not be able
// to be re-used.
//
// See https://github.com/seanmonstar/reqwest/issues/508.
let inner_read = try_io!(self.inner.get_mut().read(&mut [0]));
if inner_read == 0 {
Ok(Async::Ready(None))
} else {
Err(error::from(io::Error::new(
io::ErrorKind::InvalidData,
"unexpected data after gzip decoder signaled end-of-file",
))) )))
} }
} else {
unsafe { self.buf.advance_mut(read) };
let chunk = Chunk::from_chunk(self.buf.split_to(read).freeze());
Ok(Async::Ready(Some(chunk)))
}
}
} }
/// A `Read`able wrapper over a stream of chunks. impl Stream for BodyBytes {
pub struct ReadableChunks<S> { type Item = Result<Bytes, std::io::Error>;
state: ReadState,
stream: S,
}
enum ReadState { fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Option<Self::Item>> {
/// A chunk is ready to be read from. match futures::ready!(Pin::new(&mut self.0).poll_next(cx)) {
Ready(Chunk), Some(Ok(chunk)) => Poll::Ready(Some(Ok(chunk.into()))),
/// The next chunk isn't ready yet. Some(Err(err)) => Poll::Ready(Some(Err(err.into_io()))),
NotReady, None => Poll::Ready(None),
/// The stream has finished.
Eof,
}
enum StreamState {
/// More bytes can be read from the stream.
HasMore,
/// No more bytes can be read from the stream.
Eof,
}
impl<S> ReadableChunks<S> {
#[inline]
pub(crate) fn new(stream: S) -> Self {
ReadableChunks {
state: ReadState::NotReady,
stream,
}
}
}
impl<S> fmt::Debug for ReadableChunks<S> {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
f.debug_struct("ReadableChunks").finish()
}
}
impl<S> Read for ReadableChunks<S>
where
S: Stream<Item = Chunk, Error = error::Error>,
{
fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
loop {
let ret;
match self.state {
ReadState::Ready(ref mut chunk) => {
let len = cmp::min(buf.len(), chunk.remaining());
buf[..len].copy_from_slice(&chunk[..len]);
chunk.advance(len);
if chunk.is_empty() {
ret = len;
} else {
return Ok(len);
}
}
ReadState::NotReady => match self.poll_stream() {
Ok(Async::Ready(StreamState::HasMore)) => continue,
Ok(Async::Ready(StreamState::Eof)) => return Ok(0),
Ok(Async::NotReady) => return Err(io::ErrorKind::WouldBlock.into()),
Err(e) => return Err(error::into_io(e)),
},
ReadState::Eof => return Ok(0),
}
self.state = ReadState::NotReady;
return Ok(ret);
}
}
}
impl<S> ReadableChunks<S>
where
S: Stream<Item = Chunk, Error = error::Error>,
{
/// Poll the readiness of the inner reader.
///
/// This function will update the internal state and return a simplified
/// version of the `ReadState`.
fn poll_stream(&mut self) -> Poll<StreamState, error::Error> {
match self.stream.poll() {
Ok(Async::Ready(Some(chunk))) => {
self.state = ReadState::Ready(chunk);
Ok(Async::Ready(StreamState::HasMore))
}
Ok(Async::Ready(None)) => {
self.state = ReadState::Eof;
Ok(Async::Ready(StreamState::Eof))
}
Ok(Async::NotReady) => Ok(Async::NotReady),
Err(e) => Err(e),
} }
} }
} }

View File

@@ -1,6 +1,6 @@
pub use self::body::{Body, Chunk}; pub use self::body::{Body, Chunk};
pub use self::client::{Client, ClientBuilder}; pub use self::client::{Client, ClientBuilder};
pub use self::decoder::{Decoder, ReadableChunks}; pub use self::decoder::Decoder;
pub use self::request::{Request, RequestBuilder}; pub use self::request::{Request, RequestBuilder};
pub use self::response::{Response, ResponseBuilderExt}; pub use self::response::{Response, ResponseBuilderExt};

View File

@@ -7,9 +7,9 @@ use mime_guess::Mime;
use percent_encoding::{self, AsciiSet, NON_ALPHANUMERIC}; use percent_encoding::{self, AsciiSet, NON_ALPHANUMERIC};
use uuid::Uuid; use uuid::Uuid;
use futures::Stream; use futures::{Stream, StreamExt};
use super::{Body, Chunk}; use super::Body;
/// An async multipart/form-data request. /// An async multipart/form-data request.
pub struct Form { pub struct Form {
@@ -190,11 +190,11 @@ impl Part {
} }
/// Makes a new parameter from an arbitrary stream. /// Makes a new parameter from an arbitrary stream.
pub fn stream<T>(value: T) -> Part pub fn stream<T, I, E>(value: T) -> Part
where where
T: Stream + Send + 'static, T: Stream<Item = Result<I, E>> + Send + Sync + 'static,
T::Item: Into<Chunk>, E: std::error::Error + Send + Sync + 'static,
T::Error: std::error::Error + Send + Sync, hyper::Chunk: std::convert::From<I>,
{ {
Part::new(Body::wrap(hyper::Body::wrap_stream( Part::new(Body::wrap(hyper::Body::wrap_stream(
value.map(|chunk| chunk.into()), value.map(|chunk| chunk.into()),
@@ -210,7 +210,7 @@ impl Part {
/// Tries to set the mime of this part. /// Tries to set the mime of this part.
pub fn mime_str(self, mime: &str) -> crate::Result<Part> { pub fn mime_str(self, mime: &str) -> crate::Result<Part> {
Ok(self.mime(try_!(mime.parse()))) Ok(self.mime(mime.parse().map_err(crate::error::from)?))
} }
// Re-export when mime 0.4 is available, with split MediaType/MediaRange. // Re-export when mime 0.4 is available, with split MediaType/MediaRange.
@@ -480,6 +480,7 @@ impl PercentEncoding {
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use super::*; use super::*;
use futures::TryStreamExt;
use tokio; use tokio;
#[test] #[test]
@@ -487,9 +488,10 @@ mod tests {
let form = Form::new(); let form = Form::new();
let mut rt = tokio::runtime::current_thread::Runtime::new().expect("new rt"); let mut rt = tokio::runtime::current_thread::Runtime::new().expect("new rt");
let body_ft = form.stream(); let body = form.stream();
let s = body.map(|try_c| try_c.map(|c| c.into_bytes())).try_concat();
let out = rt.block_on(body_ft.map(|c| c.into_bytes()).concat2()); let out = rt.block_on(s);
assert_eq!(out.unwrap(), Vec::new()); assert_eq!(out.unwrap(), Vec::new());
} }
@@ -498,16 +500,20 @@ mod tests {
let mut form = Form::new() let mut form = Form::new()
.part( .part(
"reader1", "reader1",
Part::stream(futures::stream::once::<_, hyper::Error>(Ok(Chunk::from( Part::stream(futures::stream::once(futures::future::ready::<
"part1".to_owned(), Result<hyper::Chunk, hyper::Error>,
>(Ok(
hyper::Chunk::from("part1".to_owned()),
)))), )))),
) )
.part("key1", Part::text("value1")) .part("key1", Part::text("value1"))
.part("key2", Part::text("value2").mime(mime::IMAGE_BMP)) .part("key2", Part::text("value2").mime(mime::IMAGE_BMP))
.part( .part(
"reader2", "reader2",
Part::stream(futures::stream::once::<_, hyper::Error>(Ok(Chunk::from( Part::stream(futures::stream::once(futures::future::ready::<
"part2".to_owned(), Result<hyper::Chunk, hyper::Error>,
>(Ok(
hyper::Chunk::from("part2".to_owned()),
)))), )))),
) )
.part("key3", Part::text("value3").file_name("filename")); .part("key3", Part::text("value3").file_name("filename"));
@@ -530,11 +536,10 @@ mod tests {
Content-Disposition: form-data; name=\"key3\"; filename=\"filename\"\r\n\r\n\ Content-Disposition: form-data; name=\"key3\"; filename=\"filename\"\r\n\r\n\
value3\r\n--boundary--\r\n"; value3\r\n--boundary--\r\n";
let mut rt = tokio::runtime::current_thread::Runtime::new().expect("new rt"); let mut rt = tokio::runtime::current_thread::Runtime::new().expect("new rt");
let body_ft = form.stream(); let body = form.stream();
let s = body.map(|try_c| try_c.map(|c| c.into_bytes())).try_concat();
let out = rt let out = rt.block_on(s).unwrap();
.block_on(body_ft.map(|c| c.into_bytes()).concat2())
.unwrap();
// These prints are for debug purposes in case the test fails // These prints are for debug purposes in case the test fails
println!( println!(
"START REAL\n{}\nEND REAL", "START REAL\n{}\nEND REAL",
@@ -558,11 +563,10 @@ mod tests {
value2\r\n\ value2\r\n\
--boundary--\r\n"; --boundary--\r\n";
let mut rt = tokio::runtime::current_thread::Runtime::new().expect("new rt"); let mut rt = tokio::runtime::current_thread::Runtime::new().expect("new rt");
let body_ft = form.stream(); let body = form.stream();
let s = body.map(|try_c| try_c.map(|c| c.into_bytes())).try_concat();
let out = rt let out = rt.block_on(s).unwrap();
.block_on(body_ft.map(|c| c.into_bytes()).concat2())
.unwrap();
// These prints are for debug purposes in case the test fails // These prints are for debug purposes in case the test fails
println!( println!(
"START REAL\n{}\nEND REAL", "START REAL\n{}\nEND REAL",

View File

@@ -191,28 +191,20 @@ impl RequestBuilder {
/// Sends a multipart/form-data body. /// Sends a multipart/form-data body.
/// ///
/// ``` /// ```
/// # extern crate futures;
/// # extern crate reqwest;
///
/// # use reqwest::Error; /// # use reqwest::Error;
/// # use futures::future::Future;
/// ///
/// # fn run() -> Result<(), Error> { /// # async fn run() -> Result<(), Error> {
/// let client = reqwest::r#async::Client::new(); /// let client = reqwest::r#async::Client::new();
/// let form = reqwest::r#async::multipart::Form::new() /// let form = reqwest::r#async::multipart::Form::new()
/// .text("key3", "value3") /// .text("key3", "value3")
/// .text("key4", "value4"); /// .text("key4", "value4");
/// ///
/// let mut rt = tokio::runtime::current_thread::Runtime::new().expect("new rt");
/// ///
/// let response = client.post("your url") /// let response = client.post("your url")
/// .multipart(form) /// .multipart(form)
/// .send() /// .send()
/// .and_then(|_| { /// .await?;
/// Ok(()) /// # Ok(())
/// });
///
/// rt.block_on(response)
/// # } /// # }
/// ``` /// ```
pub fn multipart(self, mut multipart: multipart::Form) -> RequestBuilder { pub fn multipart(self, mut multipart: multipart::Form) -> RequestBuilder {
@@ -334,23 +326,17 @@ impl RequestBuilder {
/// # Example /// # Example
/// ///
/// ```no_run /// ```no_run
/// # extern crate futures;
/// # extern crate reqwest;
/// #
/// # use reqwest::Error; /// # use reqwest::Error;
/// # use futures::future::Future;
/// # /// #
/// # fn run() -> Result<(), Error> { /// # async fn run() -> Result<(), Error> {
/// let response = reqwest::r#async::Client::new() /// let response = reqwest::r#async::Client::new()
/// .get("https://hyper.rs") /// .get("https://hyper.rs")
/// .send() /// .send()
/// .map(|resp| println!("status: {}", resp.status())); /// .await?;
/// /// # Ok(())
/// let mut rt = tokio::runtime::current_thread::Runtime::new().expect("new rt");
/// rt.block_on(response)
/// # } /// # }
/// ``` /// ```
pub fn send(self) -> impl Future<Item = Response, Error = crate::Error> { pub fn send(self) -> impl Future<Output = Result<Response, crate::Error>> {
match self.request { match self.request {
Ok(req) => self.client.execute_request(req), Ok(req) => self.client.execute_request(req),
Err(err) => Pending::new_err(err), Err(err) => Pending::new_err(err),

View File

@@ -3,10 +3,11 @@ use std::fmt;
use std::marker::PhantomData; use std::marker::PhantomData;
use std::mem; use std::mem;
use std::net::SocketAddr; use std::net::SocketAddr;
use std::pin::Pin;
use std::task::{Context, Poll};
use encoding_rs::{Encoding, UTF_8}; use encoding_rs::{Encoding, UTF_8};
use futures::stream::Concat2; use futures::{Future, FutureExt, TryStreamExt};
use futures::{try_ready, Async, Future, Poll, Stream};
use http; use http;
use hyper::client::connect::HttpInfo; use hyper::client::connect::HttpInfo;
use hyper::header::CONTENT_LENGTH; use hyper::header::CONTENT_LENGTH;
@@ -20,8 +21,12 @@ use url::Url;
use super::body::Body; use super::body::Body;
use super::Decoder; use super::Decoder;
use crate::async_impl::Chunk;
use crate::cookie; use crate::cookie;
/// https://github.com/rust-lang-nursery/futures-rs/issues/1812
type ConcatDecoder = Pin<Box<dyn Future<Output = Result<Chunk, crate::Error>> + Send>>;
/// A Response to a submitted `Request`. /// A Response to a submitted `Request`.
pub struct Response { pub struct Response {
status: StatusCode, status: StatusCode,
@@ -139,7 +144,7 @@ impl Response {
} }
/// Get the response text /// Get the response text
pub fn text(&mut self) -> impl Future<Item = String, Error = crate::Error> { pub fn text(&mut self) -> impl Future<Output = Result<String, crate::Error>> {
self.text_with_charset("utf-8") self.text_with_charset("utf-8")
} }
@@ -147,7 +152,7 @@ impl Response {
pub fn text_with_charset( pub fn text_with_charset(
&mut self, &mut self,
default_encoding: &str, default_encoding: &str,
) -> impl Future<Item = String, Error = crate::Error> { ) -> impl Future<Output = Result<String, crate::Error>> {
let body = mem::replace(&mut self.body, Decoder::empty()); let body = mem::replace(&mut self.body, Decoder::empty());
let content_type = self let content_type = self
.headers .headers
@@ -160,18 +165,18 @@ impl Response {
.unwrap_or(default_encoding); .unwrap_or(default_encoding);
let encoding = Encoding::for_label(encoding_name.as_bytes()).unwrap_or(UTF_8); let encoding = Encoding::for_label(encoding_name.as_bytes()).unwrap_or(UTF_8);
Text { Text {
concat: body.concat2(), concat: body.try_concat().boxed(),
encoding, encoding,
} }
} }
/// Try to deserialize the response body as JSON using `serde`. /// Try to deserialize the response body as JSON using `serde`.
#[inline] #[inline]
pub fn json<T: DeserializeOwned>(&mut self) -> impl Future<Item = T, Error = crate::Error> { pub fn json<T: DeserializeOwned>(&mut self) -> impl Future<Output = Result<T, crate::Error>> {
let body = mem::replace(&mut self.body, Decoder::empty()); let body = mem::replace(&mut self.body, Decoder::empty());
Json { Json {
concat: body.concat2(), concat: body.try_concat().boxed(),
_marker: PhantomData, _marker: PhantomData,
} }
} }
@@ -270,17 +275,27 @@ impl<T: Into<Body>> From<http::Response<T>> for Response {
/// A JSON object. /// A JSON object.
struct Json<T> { struct Json<T> {
concat: Concat2<Decoder>, concat: ConcatDecoder,
_marker: PhantomData<T>, _marker: PhantomData<T>,
} }
impl<T> Json<T> {
fn concat(self: Pin<&mut Self>) -> Pin<&mut ConcatDecoder> {
unsafe { Pin::map_unchecked_mut(self, |x| &mut x.concat) }
}
}
impl<T: DeserializeOwned> Future for Json<T> { impl<T: DeserializeOwned> Future for Json<T> {
type Item = T; type Output = Result<T, crate::Error>;
type Error = crate::Error;
fn poll(&mut self) -> Poll<Self::Item, Self::Error> { fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
let bytes = try_ready!(self.concat.poll()); match futures::ready!(self.concat().as_mut().poll(cx)) {
let t = try_!(serde_json::from_slice(&bytes)); Err(e) => Poll::Ready(Err(e)),
Ok(Async::Ready(t)) Ok(chunk) => {
let t = serde_json::from_slice(&chunk).map_err(crate::error::from);
Poll::Ready(t)
}
}
} }
} }
@@ -290,28 +305,35 @@ impl<T> fmt::Debug for Json<T> {
} }
} }
#[derive(Debug)] //#[derive(Debug)]
struct Text { struct Text {
concat: Concat2<Decoder>, concat: ConcatDecoder,
encoding: &'static Encoding, encoding: &'static Encoding,
} }
impl Future for Text { impl Text {
type Item = String; fn concat(self: Pin<&mut Self>) -> Pin<&mut ConcatDecoder> {
type Error = crate::Error; unsafe { Pin::map_unchecked_mut(self, |x| &mut x.concat) }
fn poll(&mut self) -> Poll<Self::Item, Self::Error> {
let bytes = try_ready!(self.concat.poll());
// a block because of borrow checker
{
let (text, _, _) = self.encoding.decode(&bytes);
if let Cow::Owned(s) = text {
return Ok(Async::Ready(s));
} }
} }
impl Future for Text {
type Output = Result<String, crate::Error>;
fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
match futures::ready!(self.as_mut().concat().as_mut().poll(cx)) {
Err(e) => Poll::Ready(Err(e)),
Ok(chunk) => {
let (text, _, _) = self.as_mut().encoding.decode(&chunk);
if let Cow::Owned(s) = text {
return Poll::Ready(Ok(s));
}
unsafe { unsafe {
// decoding returned Cow::Borrowed, meaning these bytes // decoding returned Cow::Borrowed, meaning these bytes
// are already valid utf8 // are already valid utf8
Ok(Async::Ready(String::from_utf8_unchecked(bytes.to_vec()))) Poll::Ready(Ok(String::from_utf8_unchecked(chunk.to_vec())))
}
}
} }
} }
} }

View File

@@ -1,10 +1,9 @@
use std::fmt; use std::fmt;
use std::fs::File; use std::fs::File;
use std::future::Future;
use std::io::{self, Cursor, Read}; use std::io::{self, Cursor, Read};
use bytes::Bytes; use bytes::Bytes;
use futures::{try_ready, Future};
use hyper;
use crate::async_impl; use crate::async_impl;
@@ -213,26 +212,22 @@ pub(crate) struct Sender {
tx: hyper::body::Sender, tx: hyper::body::Sender,
} }
impl Sender { async fn send_future(sender: Sender) -> Result<(), crate::Error> {
// A `Future` that may do blocking read calls.
// As a `Future`, this integrates easily with `wait::timeout`.
pub(crate) fn send(self) -> impl Future<Item = (), Error = crate::Error> {
use bytes::{BufMut, BytesMut}; use bytes::{BufMut, BytesMut};
use futures::future;
use std::cmp; use std::cmp;
let con_len = self.body.1; let con_len = sender.body.1;
let cap = cmp::min(self.body.1.unwrap_or(8192), 8192); let cap = cmp::min(sender.body.1.unwrap_or(8192), 8192);
let mut written = 0; let mut written = 0;
let mut buf = BytesMut::with_capacity(cap as usize); let mut buf = BytesMut::with_capacity(cap as usize);
let mut body = self.body.0; let mut body = sender.body.0;
// Put in an option so that it can be consumed on error to call abort() // Put in an option so that it can be consumed on error to call abort()
let mut tx = Some(self.tx); let mut tx = Some(sender.tx);
future::poll_fn(move || loop { loop {
if Some(written) == con_len { if Some(written) == con_len {
// Written up to content-length, so stop. // Written up to content-length, so stop.
return Ok(().into()); return Ok(());
} }
// The input stream is read only if the buffer is empty so // The input stream is read only if the buffer is empty so
@@ -257,7 +252,7 @@ impl Sender {
Ok(0) => { Ok(0) => {
// The buffer was empty and nothing's left to // The buffer was empty and nothing's left to
// read. Return. // read. Return.
return Ok(().into()); return Ok(());
} }
Ok(n) => unsafe { Ok(n) => unsafe {
buf.advance_mut(n); buf.advance_mut(n);
@@ -272,18 +267,23 @@ impl Sender {
// The only way to get here is when the buffer is not empty. // The only way to get here is when the buffer is not empty.
// We can check the transmission channel // We can check the transmission channel
try_ready!(tx
.as_mut()
.expect("tx only taken on error")
.poll_ready()
.map_err(crate::error::from));
written += buf.len() as u64; let buf_len = buf.len() as u64;
let tx = tx.as_mut().expect("tx only taken on error"); tx.as_mut()
if tx.send_data(buf.take().freeze().into()).is_err() { .expect("tx only taken on error")
return Err(crate::error::timedout(None)); .send_data(buf.take().freeze().into())
.await
.map_err(crate::error::from)?;
written += buf_len;
} }
}) }
impl Sender {
// A `Future` that may do blocking read calls.
// As a `Future`, this integrates easily with `wait::timeout`.
pub(crate) fn send(self) -> impl Future<Output = Result<(), crate::Error>> {
send_future(self)
} }
} }

View File

@@ -1,14 +1,14 @@
use std::fmt; use std::fmt;
use std::future::Future;
use std::net::IpAddr; use std::net::IpAddr;
use std::sync::Arc; use std::sync::Arc;
use std::thread; use std::thread;
use std::time::Duration; use std::time::Duration;
use futures::future::{self, Either}; use futures::channel::{mpsc, oneshot};
use futures::sync::{mpsc, oneshot}; use futures::{StreamExt, TryFutureExt};
use futures::{Async, Future, Stream};
use log::trace; use log::{error, trace};
use crate::request::{Request, RequestBuilder}; use crate::request::{Request, RequestBuilder};
use crate::response::Response; use crate::response::Response;
@@ -523,10 +523,8 @@ struct ClientHandle {
inner: Arc<InnerClientHandle>, inner: Arc<InnerClientHandle>,
} }
type ThreadSender = mpsc::UnboundedSender<( type OneshotResponse = oneshot::Sender<crate::Result<async_impl::Response>>;
async_impl::Request, type ThreadSender = mpsc::UnboundedSender<(async_impl::Request, OneshotResponse)>;
oneshot::Sender<crate::Result<async_impl::Response>>,
)>;
struct InnerClientHandle { struct InnerClientHandle {
tx: Option<ThreadSender>, tx: Option<ThreadSender>,
@@ -544,69 +542,54 @@ impl ClientHandle {
fn new(builder: ClientBuilder) -> crate::Result<ClientHandle> { fn new(builder: ClientBuilder) -> crate::Result<ClientHandle> {
let timeout = builder.timeout; let timeout = builder.timeout;
let builder = builder.inner; let builder = builder.inner;
let (tx, rx) = mpsc::unbounded(); let (tx, rx) = mpsc::unbounded::<(async_impl::Request, OneshotResponse)>();
let (spawn_tx, spawn_rx) = oneshot::channel::<crate::Result<()>>(); let (spawn_tx, spawn_rx) = oneshot::channel::<crate::Result<()>>();
let handle = try_!(thread::Builder::new() let handle = thread::Builder::new()
.name("reqwest-internal-sync-runtime".into()) .name("reqwest-internal-sync-runtime".into())
.spawn(move || { .spawn(move || {
use tokio::runtime::current_thread::Runtime; use tokio::runtime::current_thread::Runtime;
let built = (|| { let mut rt = match Runtime::new().map_err(crate::error::from) {
let rt = try_!(Runtime::new());
let client = builder.build()?;
Ok((rt, client))
})();
let (mut rt, client) = match built {
Ok((rt, c)) => {
if spawn_tx.send(Ok(())).is_err() {
return;
}
(rt, c)
}
Err(e) => { Err(e) => {
let _ = spawn_tx.send(Err(e)); if let Err(e) = spawn_tx.send(Err(e)) {
error!("Failed to communicate runtime creation failure: {:?}", e);
}
return; return;
} }
Ok(v) => v,
}; };
let work = rx.for_each(move |(req, tx)| { let f = async move {
let mut tx_opt: Option<oneshot::Sender<crate::Result<async_impl::Response>>> = let client = match builder.build() {
Some(tx); Err(e) => {
let mut res_fut = client.execute(req); if let Err(e) = spawn_tx.send(Err(e)) {
error!("Failed to communicate client creation failure: {:?}", e);
let task = future::poll_fn(move || { }
let canceled = tx_opt return;
.as_mut() }
.expect("polled after complete") Ok(v) => v,
.poll_cancel() };
.expect("poll_cancel cannot error") if let Err(e) = spawn_tx.send(Ok(())) {
.is_ready(); error!("Failed to communicate successful startup: {:?}", e);
return;
if canceled {
trace!("response receiver is canceled");
Ok(Async::Ready(()))
} else {
let result = match res_fut.poll() {
Ok(Async::NotReady) => return Ok(Async::NotReady),
Ok(Async::Ready(res)) => Ok(res),
Err(err) => Err(err),
};
let _ = tx_opt.take().expect("polled after complete").send(result);
Ok(Async::Ready(()))
} }
});
tokio::spawn(task);
Ok(())
});
// work is Future<(), ()>, and our closure will never return Err let mut rx = rx;
rt.block_on(work).expect("runtime unexpected error");
})); while let Some((req, req_tx)) = rx.next().await {
let req_fut = client.execute(req);
tokio::spawn(forward(req_fut, req_tx));
}
trace!("Receiver is shutdown");
};
rt.block_on(f)
})
.map_err(crate::error::from)?;
// Wait for the runtime thread to start up... // Wait for the runtime thread to start up...
match spawn_rx.wait() { match wait::timeout(spawn_rx, None) {
Ok(Ok(())) => (), Ok(Ok(())) => (),
Ok(Err(err)) => return Err(err), Ok(Err(err)) => return Err(err),
Err(_canceled) => event_loop_panicked(), Err(_canceled) => event_loop_panicked(),
@@ -634,34 +617,60 @@ impl ClientHandle {
.unbounded_send((req, tx)) .unbounded_send((req, tx))
.expect("core thread panicked"); .expect("core thread panicked");
let write = if let Some(body) = body { let result: Result<crate::Result<async_impl::Response>, wait::Waited<crate::Error>> =
Either::A(body.send()) if let Some(body) = body {
//try_!(body.send(self.timeout.0), &url); let f = async move {
body.send().await?;
rx.await.map_err(|_canceled| event_loop_panicked())
};
wait::timeout(f, self.timeout.0)
} else { } else {
Either::B(future::ok(())) wait::timeout(
rx.map_err(|_canceled| event_loop_panicked()),
self.timeout.0,
)
}; };
let rx = rx.map_err(|_canceled| event_loop_panicked()); match result {
Ok(Err(err)) => Err(err.with_url(url)),
let fut = write.join(rx).map(|((), res)| res); Ok(Ok(res)) => Ok(Response::new(
let res = match wait::timeout(fut, self.timeout.0) {
Ok(res) => res,
Err(wait::Waited::TimedOut) => return Err(crate::error::timedout(Some(url))),
Err(wait::Waited::Executor(err)) => return Err(crate::error::from(err).with_url(url)),
Err(wait::Waited::Inner(err)) => {
return Err(err.with_url(url));
}
};
res.map(|res| {
Response::new(
res, res,
self.timeout.0, self.timeout.0,
KeepCoreThreadAlive(Some(self.inner.clone())), KeepCoreThreadAlive(Some(self.inner.clone())),
) )),
}) Err(wait::Waited::TimedOut) => Err(crate::error::timedout(Some(url))),
Err(wait::Waited::Executor(err)) => Err(crate::error::from(err).with_url(url)),
Err(wait::Waited::Inner(err)) => Err(err.with_url(url)),
} }
} }
}
async fn forward<F>(fut: F, mut tx: OneshotResponse)
where
F: Future<Output = crate::Result<async_impl::Response>>,
{
use std::task::Poll;
futures::pin_mut!(fut);
// "select" on the sender being canceled, and the future completing
let res = futures::future::poll_fn(|cx| {
match fut.as_mut().poll(cx) {
Poll::Ready(val) => Poll::Ready(Some(val)),
Poll::Pending => {
// check if the callback is canceled
futures::ready!(tx.poll_cancel(cx));
Poll::Ready(None)
}
}
})
.await;
if let Some(res) = res {
let _ = tx.send(res);
}
// else request is canceled
}
#[derive(Clone, Copy)] #[derive(Clone, Copy)]
struct Timeout(Option<Duration>); struct Timeout(Option<Duration>);

View File

@@ -1,28 +1,26 @@
use futures::Future; use futures::FutureExt;
use http::uri::Scheme; use http::uri::Scheme;
use hyper::client::connect::{Connect, Connected, Destination}; use hyper::client::connect::{Connect, Connected, Destination};
use tokio_io::{AsyncRead, AsyncWrite}; use tokio::io::{AsyncRead, AsyncWrite};
use tokio_timer::Timeout;
#[cfg(feature = "tls")]
use bytes::BufMut;
#[cfg(feature = "tls")]
use futures::Poll;
#[cfg(feature = "default-tls")] #[cfg(feature = "default-tls")]
use native_tls::{TlsConnector, TlsConnectorBuilder}; use native_tls::{TlsConnector, TlsConnectorBuilder};
use std::future::Future;
use std::io; use std::io;
use std::net::IpAddr; use std::net::IpAddr;
use std::pin::Pin;
use std::sync::Arc; use std::sync::Arc;
use std::time::Duration; use std::time::Duration;
#[cfg(feature = "trust-dns")] //#[cfg(feature = "trust-dns")]
use crate::dns::TrustDnsResolver; //use crate::dns::TrustDnsResolver;
use crate::proxy::{Proxy, ProxyScheme}; use crate::proxy::{Proxy, ProxyScheme};
use tokio::future::FutureExt as _;
#[cfg(feature = "trust-dns")] //#[cfg(feature = "trust-dns")]
type HttpConnector = hyper::client::HttpConnector<TrustDnsResolver>; //type HttpConnector = hyper::client::HttpConnector<TrustDnsResolver>;
#[cfg(not(feature = "trust-dns"))] //#[cfg(not(feature = "trust-dns"))]
type HttpConnector = hyper::client::HttpConnector; type HttpConnector = hyper::client::HttpConnector;
pub(crate) struct Connector { pub(crate) struct Connector {
@@ -33,6 +31,7 @@ pub(crate) struct Connector {
nodelay: bool, nodelay: bool,
} }
#[derive(Clone)]
enum Inner { enum Inner {
#[cfg(not(feature = "tls"))] #[cfg(not(feature = "tls"))]
Http(HttpConnector), Http(HttpConnector),
@@ -76,7 +75,7 @@ impl Connector {
where where
T: Into<Option<IpAddr>>, T: Into<Option<IpAddr>>,
{ {
let tls = try_!(tls.build()); let tls = tls.build().map_err(crate::error::from)?;
let mut http = http_connector()?; let mut http = http_connector()?;
http.set_local_address(local_addr.into()); http.set_local_address(local_addr.into());
@@ -130,25 +129,11 @@ impl Connector {
} }
#[cfg(feature = "socks")] #[cfg(feature = "socks")]
fn connect_socks(&self, dst: Destination, proxy: ProxyScheme) -> Connecting { async fn connect_socks(
macro_rules! timeout { &self,
($future:expr) => { dst: Destination,
if let Some(dur) = self.timeout { proxy: ProxyScheme,
Box::new(Timeout::new($future, dur).map_err(|err| { ) -> Result<(Conn, Connected), io::Error> {
if err.is_inner() {
err.into_inner().expect("is_inner")
} else if err.is_elapsed() {
io::Error::new(io::ErrorKind::TimedOut, "connect timed out")
} else {
io::Error::new(io::ErrorKind::Other, err)
}
}))
} else {
Box::new($future)
}
};
}
let dns = match proxy { let dns = match proxy {
ProxyScheme::Socks5 { ProxyScheme::Socks5 {
remote_dns: false, .. remote_dns: false, ..
@@ -167,14 +152,15 @@ impl Connector {
if dst.scheme() == "https" { if dst.scheme() == "https" {
use self::native_tls_async::TlsConnectorExt; use self::native_tls_async::TlsConnectorExt;
let tls = tls.clone();
let host = dst.host().to_owned(); let host = dst.host().to_owned();
let socks_connecting = socks::connect(proxy, dst, dns); let socks_connecting = socks::connect(proxy, dst, dns);
return timeout!(socks_connecting.and_then(move |(conn, connected)| { let (conn, connected) = socks::connect(proxy, dst, dns).await?;
tls.connect_async(&host, conn) let tls_connector = tokio_tls::TlsConnector::from(tls.clone());
.map_err(|e| io::Error::new(io::ErrorKind::Other, e)) let io = tls_connector
.map(move |io| (Box::new(io) as Conn, connected)) .connect(&host, conn)
})); .await
.map_err(|e| io::Error::new(io::ErrorKind::Other, e))?;
Ok((Box::new(io) as Conn, connected))
} }
} }
#[cfg(feature = "rustls-tls")] #[cfg(feature = "rustls-tls")]
@@ -185,95 +171,65 @@ impl Connector {
let tls = tls_proxy.clone(); let tls = tls_proxy.clone();
let host = dst.host().to_owned(); let host = dst.host().to_owned();
let socks_connecting = socks::connect(proxy, dst, dns); let (conn, connected) = socks::connect(proxy, dst, dns);
return timeout!(socks_connecting.and_then(move |(conn, connected)| { let dnsname = DNSNameRef::try_from_ascii_str(&host)
let maybe_dnsname = DNSNameRef::try_from_ascii_str(&host)
.map(|dnsname| dnsname.to_owned()) .map(|dnsname| dnsname.to_owned())
.map_err(|_| io::Error::new(io::ErrorKind::Other, "Invalid DNS Name")); .map_err(|_| io::Error::new(io::ErrorKind::Other, "Invalid DNS Name"))?;
futures::future::result(maybe_dnsname) let io = RustlsConnector::from(tls)
.and_then(move |dnsname| {
RustlsConnector::from(tls)
.connect(dnsname.as_ref(), conn) .connect(dnsname.as_ref(), conn)
.map_err(|e| io::Error::new(io::ErrorKind::Other, e)) .await
}) .map_err(|e| io::Error::new(io::ErrorKind::Other, e))?;
.map(move |io| (Box::new(io) as Conn, connected)) Ok((Box::new(io) as Conn, connected))
}));
} }
} }
#[cfg(not(feature = "tls"))] #[cfg(not(feature = "tls"))]
Inner::Http(_) => (), Inner::Http(_) => socks::connect(proxy, dst, dns),
} }
// else no TLS
socks::connect(proxy, dst, dns)
} }
} }
#[cfg(feature = "trust-dns")] //#[cfg(feature = "trust-dns")]
//fn http_connector() -> crate::Result<HttpConnector> {
// TrustDnsResolver::new()
// .map(HttpConnector::new_with_resolver)
// .map_err(crate::error::dns_system_conf)
//}
//#[cfg(not(feature = "trust-dns"))]
fn http_connector() -> crate::Result<HttpConnector> { fn http_connector() -> crate::Result<HttpConnector> {
TrustDnsResolver::new() Ok(HttpConnector::new())
.map(HttpConnector::new_with_resolver)
.map_err(crate::error::dns_system_conf)
} }
#[cfg(not(feature = "trust-dns"))] async fn connect_with_maybe_proxy(
fn http_connector() -> crate::Result<HttpConnector> { inner: Inner,
Ok(HttpConnector::new(4)) dst: Destination,
} is_proxy: bool,
no_delay: bool,
impl Connect for Connector { ) -> Result<(Conn, Connected), io::Error> {
type Transport = Conn; match inner {
type Error = io::Error;
type Future = Connecting;
fn connect(&self, dst: Destination) -> Self::Future {
#[cfg(feature = "tls")]
let nodelay = self.nodelay;
macro_rules! timeout {
($future:expr) => {
if let Some(dur) = self.timeout {
Box::new(Timeout::new($future, dur).map_err(|err| {
if err.is_inner() {
err.into_inner().expect("is_inner")
} else if err.is_elapsed() {
io::Error::new(io::ErrorKind::TimedOut, "connect timed out")
} else {
io::Error::new(io::ErrorKind::Other, err)
}
}))
} else {
Box::new($future)
}
};
}
macro_rules! connect {
( $http:expr, $dst:expr, $proxy:expr ) => {
timeout!($http
.connect($dst)
.map(|(io, connected)| (Box::new(io) as Conn, connected.proxy($proxy))))
};
( $dst:expr, $proxy:expr ) => {
match &self.inner {
#[cfg(not(feature = "tls"))] #[cfg(not(feature = "tls"))]
Inner::Http(http) => connect!(http, $dst, $proxy), Inner::Http(http) => {
drop(no_delay); // only used for TLS?
let (io, connected) = http.connect(dst).await?;
Ok((Box::new(io) as Conn, connected.proxy(is_proxy)))
}
#[cfg(feature = "default-tls")] #[cfg(feature = "default-tls")]
Inner::DefaultTls(http, tls) => { Inner::DefaultTls(http, tls) => {
let mut http = http.clone(); let mut http = http.clone();
http.set_nodelay(nodelay || ($dst.scheme() == "https")); http.set_nodelay(no_delay || (dst.scheme() == "https"));
let http = hyper_tls::HttpsConnector::from((http, tls.clone())); let tls_connector = tokio_tls::TlsConnector::from(tls.clone());
timeout!(http.connect($dst).and_then(move |(io, connected)| { let http = hyper_tls::HttpsConnector::from((http, tls_connector));
if let hyper_tls::MaybeHttpsStream::Https(stream) = &io { let (io, connected) = http.connect(dst).await?;
if !nodelay { //TODO: where's this at now?
stream.get_ref().get_ref().set_nodelay(false)?; //if let hyper_tls::MaybeHttpsStream::Https(_stream) = &io {
} // if !no_delay {
} // stream.set_nodelay(false)?;
// }
//}
Ok((Box::new(io) as Conn, connected.proxy($proxy))) Ok((Box::new(io) as Conn, connected.proxy(is_proxy)))
}))
} }
#[cfg(feature = "rustls-tls")] #[cfg(feature = "rustls-tls")]
Inner::RustlsTls { http, tls, .. } => { Inner::RustlsTls { http, tls, .. } => {
@@ -282,10 +238,10 @@ impl Connect for Connector {
// Disable Nagle's algorithm for TLS handshake // Disable Nagle's algorithm for TLS handshake
// //
// https://www.openssl.org/docs/man1.1.1/man3/SSL_connect.html#NOTES // https://www.openssl.org/docs/man1.1.1/man3/SSL_connect.html#NOTES
http.set_nodelay(nodelay || ($dst.scheme() == "https")); http.set_nodelay(nodelay || (dst.scheme() == "https"));
let http = hyper_rustls::HttpsConnector::from((http, tls.clone())); let http = hyper_rustls::HttpsConnector::from((http, tls.clone()));
timeout!(http.connect($dst).and_then(move |(io, connected)| { let (io, connected) = http.connect(dst).await;
if let hyper_rustls::MaybeHttpsStream::Https(stream) = &io { if let hyper_rustls::MaybeHttpsStream::Https(stream) = &io {
if !nodelay { if !nodelay {
let (io, _) = stream.get_ref(); let (io, _) = stream.get_ref();
@@ -293,21 +249,23 @@ impl Connect for Connector {
} }
} }
Ok((Box::new(io) as Conn, connected.proxy($proxy))) Ok((Box::new(io) as Conn, connected.proxy(is_proxy)))
}))
} }
} }
};
} }
for prox in self.proxies.iter() { async fn connect_via_proxy(
if let Some(proxy_scheme) = prox.intercept(&dst) { inner: Inner,
dst: Destination,
proxy_scheme: ProxyScheme,
no_delay: bool,
) -> Result<(Conn, Connected), io::Error> {
log::trace!("proxy({:?}) intercepts {:?}", proxy_scheme, dst); log::trace!("proxy({:?}) intercepts {:?}", proxy_scheme, dst);
let (puri, _auth) = match proxy_scheme { let (puri, _auth) = match proxy_scheme {
ProxyScheme::Http { uri, auth, .. } => (uri, auth), ProxyScheme::Http { uri, auth, .. } => (uri, auth),
#[cfg(feature = "socks")] #[cfg(feature = "socks")]
ProxyScheme::Socks5 { .. } => return self.connect_socks(dst, proxy_scheme), ProxyScheme::Socks5 { .. } => return this.connect_socks(dst, proxy_scheme),
}; };
let mut ndst = dst.clone(); let mut ndst = dst.clone();
@@ -324,30 +282,25 @@ impl Connect for Connector {
#[cfg(feature = "tls")] #[cfg(feature = "tls")]
let auth = _auth; let auth = _auth;
match &self.inner { match &inner {
#[cfg(feature = "default-tls")] #[cfg(feature = "default-tls")]
Inner::DefaultTls(http, tls) => { Inner::DefaultTls(http, tls) => {
if dst.scheme() == "https" { if dst.scheme() == "https" {
use self::native_tls_async::TlsConnectorExt;
let host = dst.host().to_owned(); let host = dst.host().to_owned();
let port = dst.port().unwrap_or(443); let port = dst.port().unwrap_or(443);
let mut http = http.clone(); let mut http = http.clone();
http.set_nodelay(nodelay); http.set_nodelay(no_delay);
let http = hyper_tls::HttpsConnector::from((http, tls.clone())); let tls_connector = tokio_tls::TlsConnector::from(tls.clone());
let tls = tls.clone(); let http = hyper_tls::HttpsConnector::from((http, tls_connector));
return timeout!(http.connect(ndst).and_then( let (conn, connected) = http.connect(ndst).await?;
move |(conn, connected)| {
log::trace!("tunneling HTTPS over proxy"); log::trace!("tunneling HTTPS over proxy");
tunnel(conn, host.clone(), port, auth) let tunneled = tunnel(conn, host.clone(), port, auth).await?;
.and_then(move |tunneled| { let tls_connector = tokio_tls::TlsConnector::from(tls.clone());
tls.connect_async(&host, tunneled).map_err(|e| { let io = tls_connector
io::Error::new(io::ErrorKind::Other, e) .connect(&host, tunneled)
}) .await
}) .map_err(|e| io::Error::new(io::ErrorKind::Other, e))?;
.map(|io| (Box::new(io) as Conn, connected.proxy(true))) return Ok((Box::new(io) as Conn, connected.proxy(true)));
}
));
} }
} }
#[cfg(feature = "rustls-tls")] #[cfg(feature = "rustls-tls")]
@@ -365,65 +318,96 @@ impl Connect for Connector {
let port = dst.port().unwrap_or(443); let port = dst.port().unwrap_or(443);
let mut http = http.clone(); let mut http = http.clone();
http.set_nodelay(nodelay); http.set_nodelay(nodelay);
let http = let http = hyper_rustls::HttpsConnector::from((http, tls_proxy.clone()));
hyper_rustls::HttpsConnector::from((http, tls_proxy.clone()));
let tls = tls.clone(); let tls = tls.clone();
return timeout!(http.connect(ndst).and_then( let (conn, connected) = http.connect(ndst).await;
move |(conn, connected)| {
log::trace!("tunneling HTTPS over proxy"); log::trace!("tunneling HTTPS over proxy");
let maybe_dnsname = DNSNameRef::try_from_ascii_str(&host) let maybe_dnsname = DNSNameRef::try_from_ascii_str(&host)
.map(|dnsname| dnsname.to_owned()) .map(|dnsname| dnsname.to_owned())
.map_err(|_| { .map_err(|_| io::Error::new(io::ErrorKind::Other, "Invalid DNS Name"));
io::Error::new(io::ErrorKind::Other, "Invalid DNS Name") let tunneled = tunnel(conn, host, port, auth).await;
}); let dnsname = maybe_dnsname?;
tunnel(conn, host, port, auth) let io = RustlsConnector::from(tls)
.and_then(move |tunneled| Ok((maybe_dnsname?, tunneled)))
.and_then(move |(dnsname, tunneled)| {
RustlsConnector::from(tls)
.connect(dnsname.as_ref(), tunneled) .connect(dnsname.as_ref(), tunneled)
.map_err(|e| { .await
io::Error::new(io::ErrorKind::Other, e) .map_err(|e| io::Error::new(io::ErrorKind::Other, e))?;
}) let connected = if io.get_ref().1.get_alpn_protocol() == Some(b"h2") {
})
.map(|io| {
let connected = if io.get_ref().1.get_alpn_protocol()
== Some(b"h2")
{
connected.negotiated_h2() connected.negotiated_h2()
} else { } else {
connected connected
}; };
(Box::new(io) as Conn, connected.proxy(true)) return Ok((Box::new(io) as Conn, connected.proxy(true)));
})
}
));
} }
} }
#[cfg(not(feature = "tls"))] #[cfg(not(feature = "tls"))]
Inner::Http(_) => (), Inner::Http(_) => (),
} }
return connect!(ndst, true); connect_with_maybe_proxy(inner, ndst, true, no_delay).await
}
async fn with_timeout<T, F>(f: F, timeout: Option<Duration>) -> Result<T, io::Error>
where
F: Future<Output = Result<T, io::Error>>,
{
if let Some(to) = timeout {
match f.timeout(to).await {
Err(_elapsed) => Err(io::Error::new(io::ErrorKind::TimedOut, "connect timed out")),
Ok(try_res) => try_res,
}
} else {
f.await
} }
} }
connect!(dst, false) impl Connect for Connector {
type Transport = Conn;
type Error = io::Error;
type Future = Connecting;
fn connect(&self, dst: Destination) -> Self::Future {
#[cfg(feature = "tls")]
let no_delay = self.nodelay;
#[cfg(not(feature = "tls"))]
let no_delay = false;
let timeout = self.timeout;
for prox in self.proxies.iter() {
if let Some(proxy_scheme) = prox.intercept(&dst) {
return with_timeout(
connect_via_proxy(self.inner.clone(), dst, proxy_scheme, no_delay),
timeout,
)
.boxed();
}
}
with_timeout(
connect_with_maybe_proxy(self.inner.clone(), dst, false, no_delay),
timeout,
)
.boxed()
} }
} }
pub(crate) trait AsyncConn: AsyncRead + AsyncWrite {} pub(crate) trait AsyncConn: AsyncRead + AsyncWrite {}
impl<T: AsyncRead + AsyncWrite> AsyncConn for T {} impl<T: AsyncRead + AsyncWrite> AsyncConn for T {}
pub(crate) type Conn = Box<dyn AsyncConn + Send + Sync + 'static>; pub(crate) type Conn = Box<dyn AsyncConn + Send + Sync + Unpin + 'static>;
pub(crate) type Connecting = Box<dyn Future<Item = (Conn, Connected), Error = io::Error> + Send>; pub(crate) type Connecting =
Pin<Box<dyn Future<Output = Result<(Conn, Connected), io::Error>> + Send>>;
#[cfg(feature = "tls")] #[cfg(feature = "tls")]
fn tunnel<T>( async fn tunnel<T>(
conn: T, mut conn: T,
host: String, host: String,
port: u16, port: u16,
auth: Option<http::header::HeaderValue>, auth: Option<http::header::HeaderValue>,
) -> Tunnel<T> { ) -> Result<T, io::Error>
where
T: AsyncRead + AsyncWrite + Unpin,
{
use tokio::io::{AsyncReadExt, AsyncWriteExt};
let mut buf = format!( let mut buf = format!(
"\ "\
CONNECT {0}:{1} HTTP/1.1\r\n\ CONNECT {0}:{1} HTTP/1.1\r\n\
@@ -443,84 +427,43 @@ fn tunnel<T>(
// headers end // headers end
buf.extend_from_slice(b"\r\n"); buf.extend_from_slice(b"\r\n");
Tunnel { conn.write_all(&buf).await?;
buf: io::Cursor::new(buf),
conn: Some(conn),
state: TunnelState::Writing,
}
}
#[cfg(feature = "tls")] let mut buf = [0; 8192];
struct Tunnel<T> { let mut pos = 0;
buf: io::Cursor<Vec<u8>>,
conn: Option<T>,
state: TunnelState,
}
#[cfg(feature = "tls")]
enum TunnelState {
Writing,
Reading,
}
#[cfg(feature = "tls")]
impl<T> Future for Tunnel<T>
where
T: AsyncRead + AsyncWrite,
{
type Item = T;
type Error = io::Error;
fn poll(&mut self) -> Poll<Self::Item, Self::Error> {
loop { loop {
if let TunnelState::Writing = self.state { let n = conn.read(&mut buf[pos..]).await?;
let n = futures::try_ready!(self.conn.as_mut().unwrap().write_buf(&mut self.buf));
if !self.buf.has_remaining_mut() {
self.state = TunnelState::Reading;
self.buf.get_mut().truncate(0);
} else if n == 0 {
return Err(tunnel_eof());
}
} else {
let n = futures::try_ready!(self
.conn
.as_mut()
.unwrap()
.read_buf(&mut self.buf.get_mut()));
let read = &self.buf.get_ref()[..];
if n == 0 { if n == 0 {
return Err(tunnel_eof()); return Err(tunnel_eof());
} else if read.len() > 12 { }
if read.starts_with(b"HTTP/1.1 200") || read.starts_with(b"HTTP/1.0 200") { pos += n;
if read.ends_with(b"\r\n\r\n") {
return Ok(self.conn.take().unwrap().into()); let recvd = &buf[..pos];
if recvd.starts_with(b"HTTP/1.1 200") || recvd.starts_with(b"HTTP/1.0 200") {
if recvd.ends_with(b"\r\n\r\n") {
return Ok(conn);
}
if pos == buf.len() {
return Err(io::Error::new(
io::ErrorKind::Other,
"proxy headers too long for tunnel",
));
} }
// else read more // else read more
} else if read.starts_with(b"HTTP/1.1 407") { } else if recvd.starts_with(b"HTTP/1.1 407") {
return Err(io::Error::new( return Err(io::Error::new(
io::ErrorKind::Other, io::ErrorKind::Other,
"proxy authentication required", "proxy authentication required",
)); ));
} else if read.starts_with(b"HTTP/1.1 403") {
return Err(io::Error::new(
io::ErrorKind::Other,
"proxy blocked this request",
));
} else { } else {
let (fst, _) = read.split_at(12); return Err(io::Error::new(io::ErrorKind::Other, "unsuccessful tunnel"));
return Err(io::Error::new(
io::ErrorKind::Other,
format!("unsuccessful tunnel: {:?}", fst).as_str(),
));
}
}
}
} }
} }
} }
#[cfg(feature = "tls")] #[cfg(feature = "tls")]
#[inline]
fn tunnel_eof() -> io::Error { fn tunnel_eof() -> io::Error {
io::Error::new( io::Error::new(
io::ErrorKind::UnexpectedEof, io::ErrorKind::UnexpectedEof,
@@ -528,138 +471,6 @@ fn tunnel_eof() -> io::Error {
) )
} }
#[cfg(feature = "default-tls")]
mod native_tls_async {
use std::io::{self, Read, Write};
use futures::{Async, Future, Poll};
use native_tls::{self, Error, HandshakeError, TlsConnector};
use tokio_io::{try_nb, AsyncRead, AsyncWrite};
/// A wrapper around an underlying raw stream which implements the TLS or SSL
/// protocol.
///
/// A `TlsStream<S>` represents a handshake that has been completed successfully
/// and both the server and the client are ready for receiving and sending
/// data. Bytes read from a `TlsStream` are decrypted from `S` and bytes written
/// to a `TlsStream` are encrypted when passing through to `S`.
#[derive(Debug)]
pub struct TlsStream<S> {
inner: native_tls::TlsStream<S>,
}
/// Future returned from `TlsConnectorExt::connect_async` which will resolve
/// once the connection handshake has finished.
pub struct ConnectAsync<S> {
inner: MidHandshake<S>,
}
struct MidHandshake<S> {
inner: Option<Result<native_tls::TlsStream<S>, HandshakeError<S>>>,
}
/// Extension trait for the `TlsConnector` type in the `native_tls` crate.
pub trait TlsConnectorExt: sealed::Sealed {
/// Connects the provided stream with this connector, assuming the provided
/// domain.
///
/// This function will internally call `TlsConnector::connect` to connect
/// the stream and returns a future representing the resolution of the
/// connection operation. The returned future will resolve to either
/// `TlsStream<S>` or `Error` depending if it's successful or not.
///
/// This is typically used for clients who have already established, for
/// example, a TCP connection to a remote server. That stream is then
/// provided here to perform the client half of a connection to a
/// TLS-powered server.
///
/// # Compatibility notes
///
/// Note that this method currently requires `S: Read + Write` but it's
/// highly recommended to ensure that the object implements the `AsyncRead`
/// and `AsyncWrite` traits as well, otherwise this function will not work
/// properly.
fn connect_async<S>(&self, domain: &str, stream: S) -> ConnectAsync<S>
where
S: Read + Write; // TODO: change to AsyncRead + AsyncWrite
}
mod sealed {
pub trait Sealed {}
}
impl<S: Read + Write> Read for TlsStream<S> {
fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
self.inner.read(buf)
}
}
impl<S: Read + Write> Write for TlsStream<S> {
fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
self.inner.write(buf)
}
fn flush(&mut self) -> io::Result<()> {
self.inner.flush()
}
}
impl<S: AsyncRead + AsyncWrite> AsyncRead for TlsStream<S> {}
impl<S: AsyncRead + AsyncWrite> AsyncWrite for TlsStream<S> {
fn shutdown(&mut self) -> Poll<(), io::Error> {
try_nb!(self.inner.shutdown());
self.inner.get_mut().shutdown()
}
}
impl TlsConnectorExt for TlsConnector {
fn connect_async<S>(&self, domain: &str, stream: S) -> ConnectAsync<S>
where
S: Read + Write,
{
ConnectAsync {
inner: MidHandshake {
inner: Some(self.connect(domain, stream)),
},
}
}
}
impl sealed::Sealed for TlsConnector {}
// TODO: change this to AsyncRead/AsyncWrite on next major version
impl<S: Read + Write> Future for ConnectAsync<S> {
type Item = TlsStream<S>;
type Error = Error;
fn poll(&mut self) -> Poll<TlsStream<S>, Error> {
self.inner.poll()
}
}
// TODO: change this to AsyncRead/AsyncWrite on next major version
impl<S: Read + Write> Future for MidHandshake<S> {
type Item = TlsStream<S>;
type Error = Error;
fn poll(&mut self) -> Poll<TlsStream<S>, Error> {
match self.inner.take().expect("cannot poll MidHandshake twice") {
Ok(stream) => Ok(TlsStream { inner: stream }.into()),
Err(HandshakeError::Failure(e)) => Err(e),
Err(HandshakeError::WouldBlock(s)) => match s.handshake() {
Ok(stream) => Ok(TlsStream { inner: stream }.into()),
Err(HandshakeError::Failure(e)) => Err(e),
Err(HandshakeError::WouldBlock(s)) => {
self.inner = Some(Err(HandshakeError::WouldBlock(s)));
Ok(Async::NotReady)
}
},
}
}
}
}
#[cfg(feature = "socks")] #[cfg(feature = "socks")]
mod socks { mod socks {
use std::io; use std::io;
@@ -678,19 +489,18 @@ mod socks {
Proxy, Proxy,
} }
pub(super) fn connect(proxy: ProxyScheme, dst: Destination, dns: DnsResolve) -> Connecting { pub(super) async fn connect(
proxy: ProxyScheme,
dst: Destination,
dns: DnsResolve,
) -> Result<(super::Conn, Connected), io::Error> {
let https = dst.scheme() == "https"; let https = dst.scheme() == "https";
let original_host = dst.host().to_owned(); let original_host = dst.host().to_owned();
let mut host = original_host.clone(); let mut host = original_host.clone();
let port = dst.port().unwrap_or_else(|| if https { 443 } else { 80 }); let port = dst.port().unwrap_or_else(|| if https { 443 } else { 80 });
if let DnsResolve::Local = dns { if let DnsResolve::Local = dns {
let maybe_new_target = match (host.as_str(), port).to_socket_addrs() { let maybe_new_target = (host.as_str(), port).to_socket_addrs()?.next();
Ok(mut iter) => iter.next(),
Err(err) => {
return Box::new(future::err(err));
}
};
if let Some(new_target) = maybe_new_target { if let Some(new_target) = maybe_new_target {
host = new_target.ip().to_string(); host = new_target.ip().to_string();
} }
@@ -702,39 +512,33 @@ mod socks {
}; };
// Get a Tokio TcpStream // Get a Tokio TcpStream
let stream = future::result( let stream = if let Some((username, password)) = auth {
if let Some((username, password)) = auth {
Socks5Stream::connect_with_password( Socks5Stream::connect_with_password(
socket_addr, socket_addr,
(host.as_str(), port), (host.as_str(), port),
&username, &username,
&password, &password,
) )
.await
} else { } else {
Socks5Stream::connect(socket_addr, (host.as_str(), port)) let s = Socks5Stream::connect(socket_addr, (host.as_str(), port)).await;
}
.and_then(|s| {
TcpStream::from_std(s.into_inner(), &reactor::Handle::default()) TcpStream::from_std(s.into_inner(), &reactor::Handle::default())
.map_err(|e| io::Error::new(io::ErrorKind::Other, e)) .map_err(|e| io::Error::new(io::ErrorKind::Other, e))?
}), };
);
Box::new(stream.map(|s| (Box::new(s) as super::Conn, Connected::new()))) Ok((Box::new(s) as super::Conn, Connected::new()))
} }
} }
#[cfg(feature = "tls")] #[cfg(feature = "tls")]
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
extern crate tokio_tcp;
use self::tokio_tcp::TcpStream;
use super::tunnel; use super::tunnel;
use crate::proxy; use crate::proxy;
use futures::Future;
use std::io::{Read, Write}; use std::io::{Read, Write};
use std::net::TcpListener; use std::net::TcpListener;
use std::thread; use std::thread;
use tokio::net::tcp::TcpStream;
use tokio::runtime::current_thread::Runtime; use tokio::runtime::current_thread::Runtime;
static TUNNEL_OK: &[u8] = b"\ static TUNNEL_OK: &[u8] = b"\
@@ -782,12 +586,14 @@ mod tests {
let addr = mock_tunnel!(); let addr = mock_tunnel!();
let mut rt = Runtime::new().unwrap(); let mut rt = Runtime::new().unwrap();
let work = TcpStream::connect(&addr); let f = async move {
let tcp = TcpStream::connect(&addr).await?;
let host = addr.ip().to_string(); let host = addr.ip().to_string();
let port = addr.port(); let port = addr.port();
let work = work.and_then(|tcp| tunnel(tcp, host, port, None)); tunnel(tcp, host, port, None).await
};
rt.block_on(work).unwrap(); rt.block_on(f).unwrap();
} }
#[test] #[test]
@@ -795,12 +601,14 @@ mod tests {
let addr = mock_tunnel!(b"HTTP/1.1 200 OK"); let addr = mock_tunnel!(b"HTTP/1.1 200 OK");
let mut rt = Runtime::new().unwrap(); let mut rt = Runtime::new().unwrap();
let work = TcpStream::connect(&addr); let f = async move {
let tcp = TcpStream::connect(&addr).await?;
let host = addr.ip().to_string(); let host = addr.ip().to_string();
let port = addr.port(); let port = addr.port();
let work = work.and_then(|tcp| tunnel(tcp, host, port, None)); tunnel(tcp, host, port, None).await
};
rt.block_on(work).unwrap_err(); rt.block_on(f).unwrap_err();
} }
#[test] #[test]
@@ -808,12 +616,14 @@ mod tests {
let addr = mock_tunnel!(b"foo bar baz hallo"); let addr = mock_tunnel!(b"foo bar baz hallo");
let mut rt = Runtime::new().unwrap(); let mut rt = Runtime::new().unwrap();
let work = TcpStream::connect(&addr); let f = async move {
let tcp = TcpStream::connect(&addr).await?;
let host = addr.ip().to_string(); let host = addr.ip().to_string();
let port = addr.port(); let port = addr.port();
let work = work.and_then(|tcp| tunnel(tcp, host, port, None)); tunnel(tcp, host, port, None).await
};
rt.block_on(work).unwrap_err(); rt.block_on(f).unwrap_err();
} }
#[test] #[test]
@@ -827,12 +637,14 @@ mod tests {
); );
let mut rt = Runtime::new().unwrap(); let mut rt = Runtime::new().unwrap();
let work = TcpStream::connect(&addr); let f = async move {
let tcp = TcpStream::connect(&addr).await?;
let host = addr.ip().to_string(); let host = addr.ip().to_string();
let port = addr.port(); let port = addr.port();
let work = work.and_then(|tcp| tunnel(tcp, host, port, None)); tunnel(tcp, host, port, None).await
};
let error = rt.block_on(work).unwrap_err(); let error = rt.block_on(f).unwrap_err();
assert_eq!(error.to_string(), "proxy authentication required"); assert_eq!(error.to_string(), "proxy authentication required");
} }
@@ -844,18 +656,19 @@ mod tests {
); );
let mut rt = Runtime::new().unwrap(); let mut rt = Runtime::new().unwrap();
let work = TcpStream::connect(&addr); let f = async move {
let tcp = TcpStream::connect(&addr).await?;
let host = addr.ip().to_string(); let host = addr.ip().to_string();
let port = addr.port(); let port = addr.port();
let work = work.and_then(|tcp| {
tunnel( tunnel(
tcp, tcp,
host, host,
port, port,
Some(proxy::encode_basic_auth("Aladdin", "open sesame")), Some(proxy::encode_basic_auth("Aladdin", "open sesame")),
) )
}); .await
};
rt.block_on(work).unwrap(); rt.block_on(f).unwrap();
} }
} }

View File

@@ -47,7 +47,7 @@ impl TrustDnsResolver {
impl hyper_dns::Resolve for TrustDnsResolver { impl hyper_dns::Resolve for TrustDnsResolver {
type Addrs = vec::IntoIter<IpAddr>; type Addrs = vec::IntoIter<IpAddr>;
type Future = Box<dyn Future<Item = Self::Addrs, Error = io::Error> + Send>; type Future = Box<dyn Future<Output = Result<Self::Addrs, io::Error>> + Send>;
fn resolve(&self, name: hyper_dns::Name) -> Self::Future { fn resolve(&self, name: hyper_dns::Name) -> Self::Future {
let inner = self.inner.clone(); let inner = self.inner.clone();

View File

@@ -2,8 +2,6 @@ use std::error::Error as StdError;
use std::fmt; use std::fmt;
use std::io; use std::io;
use tokio_executor::EnterError;
use crate::{StatusCode, Url}; use crate::{StatusCode, Url};
/// The Errors that may occur when processing a `Request`. /// The Errors that may occur when processing a `Request`.
@@ -95,7 +93,6 @@ impl Error {
} }
pub(crate) fn with_url(mut self, url: Url) -> Error { pub(crate) fn with_url(mut self, url: Url) -> Error {
debug_assert_eq!(self.inner.url, None, "with_url overriding existing url");
self.inner.url = Some(url); self.inner.url = Some(url);
self self
} }
@@ -221,6 +218,13 @@ impl Error {
_ => None, _ => None,
} }
} }
pub(crate) fn into_io(self) -> io::Error {
match self.inner.kind {
Kind::Io(io) => io,
_ => io::Error::new(io::ErrorKind::Other, self),
}
}
} }
impl fmt::Debug for Error { impl fmt::Debug for Error {
@@ -475,8 +479,8 @@ where
} }
} }
impl From<EnterError> for Kind { impl From<tokio_executor::EnterError> for Kind {
fn from(_err: EnterError) -> Kind { fn from(_err: tokio_executor::EnterError) -> Kind {
Kind::BlockingClientInFutureContext Kind::BlockingClientInFutureContext
} }
} }
@@ -521,10 +525,7 @@ where
} }
pub(crate) fn into_io(e: Error) -> io::Error { pub(crate) fn into_io(e: Error) -> io::Error {
match e.inner.kind { e.into_io()
Kind::Io(io) => io,
_ => io::Error::new(io::ErrorKind::Other, e),
}
} }
pub(crate) fn from_io(e: io::Error) -> Error { pub(crate) fn from_io(e: io::Error) -> Error {
@@ -538,39 +539,12 @@ pub(crate) fn from_io(e: io::Error) -> Error {
} }
} }
macro_rules! try_ { macro_rules! url_error {
($e:expr) => {
match $e {
Ok(v) => v,
Err(err) => {
return Err(crate::error::from(err));
}
}
};
($e:expr, $url:expr) => { ($e:expr, $url:expr) => {
match $e { Err(crate::Error::from(crate::error::InternalFrom(
Ok(v) => v, $e,
Err(err) => {
return Err(crate::Error::from(crate::error::InternalFrom(
err,
Some($url.clone()), Some($url.clone()),
))); )))
}
}
};
}
macro_rules! try_io {
($e:expr) => {
match $e {
Ok(v) => v,
Err(ref err) if err.kind() == std::io::ErrorKind::WouldBlock => {
return Ok(futures::Async::NotReady);
}
Err(err) => {
return Err(crate::error::from_io(err));
}
}
}; };
} }
@@ -607,6 +581,9 @@ pub(crate) fn unknown_proxy_scheme() -> Error {
mod tests { mod tests {
use super::*; use super::*;
fn assert_send<T: Send>() {}
fn assert_sync<T: Sync>() {}
#[allow(deprecated)] #[allow(deprecated)]
#[test] #[test]
fn test_cause_chain() { fn test_cause_chain() {
@@ -652,6 +629,8 @@ mod tests {
let err = Error::new(Kind::Io(io), None); let err = Error::new(Kind::Io(io), None);
assert!(err.cause().is_some()); assert!(err.cause().is_some());
assert_eq!(err.to_string(), "chain: root"); assert_eq!(err.to_string(), "chain: root");
assert_send::<Error>();
assert_sync::<Error>();
} }
#[test] #[test]

View File

@@ -27,7 +27,7 @@ impl PolyfillTryInto for Url {
impl<'a> PolyfillTryInto for &'a str { impl<'a> PolyfillTryInto for &'a str {
fn into_url(self) -> crate::Result<Url> { fn into_url(self) -> crate::Result<Url> {
try_!(Url::parse(self)).into_url() Url::parse(self).map_err(crate::error::from)?.into_url()
} }
} }

View File

@@ -158,8 +158,6 @@
//! - **default-tls-vendored**: Enables the `vendored` feature of `native-tls`. //! - **default-tls-vendored**: Enables the `vendored` feature of `native-tls`.
//! - **rustls-tls**: Provides TLS support via the `rustls` library. //! - **rustls-tls**: Provides TLS support via the `rustls` library.
//! - **socks**: Provides SOCKS5 proxy support. //! - **socks**: Provides SOCKS5 proxy support.
//! - **trust-dns**: Enables a trust-dns async resolver instead of default
//! threadpool using `getaddrinfo`.
//! - **hyper-011**: Provides support for hyper's old typed headers. //! - **hyper-011**: Provides support for hyper's old typed headers.
//! //!
//! //!
@@ -173,6 +171,9 @@
//! [Proxy]: ./struct.Proxy.html //! [Proxy]: ./struct.Proxy.html
//! [cargo-features]: https://doc.rust-lang.org/stable/cargo/reference/manifest.html#the-features-section //! [cargo-features]: https://doc.rust-lang.org/stable/cargo/reference/manifest.html#the-features-section
////! - **trust-dns**: Enables a trust-dns async resolver instead of default
////! threadpool using `getaddrinfo`.
extern crate cookie as cookie_crate; extern crate cookie as cookie_crate;
#[cfg(feature = "hyper-011")] #[cfg(feature = "hyper-011")]
pub use hyper_old_types as hyper_011; pub use hyper_old_types as hyper_011;
@@ -210,8 +211,8 @@ mod body;
mod client; mod client;
mod connect; mod connect;
pub mod cookie; pub mod cookie;
#[cfg(feature = "trust-dns")] //#[cfg(feature = "trust-dns")]
mod dns; //mod dns;
mod into_url; mod into_url;
mod proxy; mod proxy;
mod redirect; mod redirect;

View File

@@ -213,7 +213,6 @@ impl Part {
let file_name = path let file_name = path
.file_name() .file_name()
.map(|filename| filename.to_string_lossy().into_owned()); .map(|filename| filename.to_string_lossy().into_owned());
let ext = path.extension().and_then(|ext| ext.to_str()).unwrap_or(""); let ext = path.extension().and_then(|ext| ext.to_str()).unwrap_or("");
let mime = mime_guess::from_ext(ext).first_or_octet_stream(); let mime = mime_guess::from_ext(ext).first_or_octet_stream();
let file = File::open(path)?; let file = File::open(path)?;
@@ -235,7 +234,7 @@ impl Part {
/// Tries to set the mime of this part. /// Tries to set the mime of this part.
pub fn mime_str(self, mime: &str) -> crate::Result<Part> { pub fn mime_str(self, mime: &str) -> crate::Result<Part> {
Ok(self.mime(try_!(mime.parse()))) Ok(self.mime(mime.parse().map_err(crate::error::from)?))
} }
// Re-export when mime 0.4 is available, with split MediaType/MediaRange. // Re-export when mime 0.4 is available, with split MediaType/MediaRange.

View File

@@ -2,9 +2,9 @@ use std::fmt;
use std::io::{self, Read}; use std::io::{self, Read};
use std::mem; use std::mem;
use std::net::SocketAddr; use std::net::SocketAddr;
use std::pin::Pin;
use std::time::Duration; use std::time::Duration;
use futures::{Async, Poll, Stream};
use http; use http;
use serde::de::DeserializeOwned; use serde::de::DeserializeOwned;
@@ -16,7 +16,7 @@ use hyper::header::HeaderMap;
/// A Response to a submitted `Request`. /// A Response to a submitted `Request`.
pub struct Response { pub struct Response {
inner: async_impl::Response, inner: async_impl::Response,
body: Option<async_impl::ReadableChunks<WaitBody>>, body: Option<Pin<Box<dyn futures::io::AsyncRead + Send + Sync>>>,
timeout: Option<Duration>, timeout: Option<Duration>,
_thread_handle: KeepCoreThreadAlive, _thread_handle: KeepCoreThreadAlive,
} }
@@ -289,7 +289,6 @@ impl Response {
/// # Ok(()) /// # Ok(())
/// # } /// # }
/// ``` /// ```
#[inline]
pub fn copy_to<W: ?Sized>(&mut self, w: &mut W) -> crate::Result<u64> pub fn copy_to<W: ?Sized>(&mut self, w: &mut W) -> crate::Result<u64>
where where
W: io::Write, W: io::Write,
@@ -349,47 +348,32 @@ impl Response {
pub fn error_for_status_ref(&self) -> crate::Result<&Self> { pub fn error_for_status_ref(&self) -> crate::Result<&Self> {
self.inner.error_for_status_ref().and_then(|_| Ok(self)) self.inner.error_for_status_ref().and_then(|_| Ok(self))
} }
// private
fn body_mut(&mut self) -> Pin<&mut dyn futures::io::AsyncRead> {
use futures::stream::TryStreamExt;
if self.body.is_none() {
let body = mem::replace(self.inner.body_mut(), async_impl::Decoder::empty());
let body = body.map_err(crate::error::into_io).into_async_read();
self.body = Some(Box::pin(body));
}
self.body.as_mut().expect("body was init").as_mut()
}
} }
impl Read for Response { impl Read for Response {
#[inline]
fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> { fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
if self.body.is_none() { use futures::io::AsyncReadExt;
let body = mem::replace(self.inner.body_mut(), async_impl::Decoder::empty());
let body = async_impl::ReadableChunks::new(WaitBody {
inner: wait::stream(body, self.timeout),
});
self.body = Some(body);
}
let mut body = self.body.take().unwrap();
let bytes = body.read(buf);
self.body = Some(body);
bytes
}
}
struct WaitBody { let timeout = self.timeout;
inner: wait::WaitStream<async_impl::Decoder>, wait::timeout(self.body_mut().read(buf), timeout).map_err(|e| match e {
} wait::Waited::TimedOut => crate::error::timedout(None).into_io(),
wait::Waited::Executor(e) => crate::error::from(e).into_io(),
impl Stream for WaitBody {
type Item = <async_impl::Decoder as Stream>::Item;
type Error = <async_impl::Decoder as Stream>::Error;
fn poll(&mut self) -> Poll<Option<Self::Item>, Self::Error> {
match self.inner.next() {
Some(Ok(chunk)) => Ok(Async::Ready(Some(chunk))),
Some(Err(e)) => {
let req_err = match e {
wait::Waited::TimedOut => crate::error::timedout(None),
wait::Waited::Executor(e) => crate::error::from(e),
wait::Waited::Inner(e) => e, wait::Waited::Inner(e) => e,
}; })
Err(req_err)
}
None => Ok(Async::Ready(None)),
}
} }
} }

View File

@@ -55,7 +55,7 @@ impl Certificate {
pub fn from_der(der: &[u8]) -> crate::Result<Certificate> { pub fn from_der(der: &[u8]) -> crate::Result<Certificate> {
Ok(Certificate { Ok(Certificate {
#[cfg(feature = "default-tls")] #[cfg(feature = "default-tls")]
native: try_!(native_tls::Certificate::from_der(der)), native: native_tls::Certificate::from_der(der).map_err(crate::error::from)?,
#[cfg(feature = "rustls-tls")] #[cfg(feature = "rustls-tls")]
original: Cert::Der(der.to_owned()), original: Cert::Der(der.to_owned()),
}) })
@@ -80,7 +80,7 @@ impl Certificate {
pub fn from_pem(pem: &[u8]) -> crate::Result<Certificate> { pub fn from_pem(pem: &[u8]) -> crate::Result<Certificate> {
Ok(Certificate { Ok(Certificate {
#[cfg(feature = "default-tls")] #[cfg(feature = "default-tls")]
native: try_!(native_tls::Certificate::from_pem(pem)), native: native_tls::Certificate::from_pem(pem).map_err(crate::error::from)?,
#[cfg(feature = "rustls-tls")] #[cfg(feature = "rustls-tls")]
original: Cert::Pem(pem.to_owned()), original: Cert::Pem(pem.to_owned()),
}) })
@@ -146,7 +146,9 @@ impl Identity {
#[cfg(feature = "default-tls")] #[cfg(feature = "default-tls")]
pub fn from_pkcs12_der(der: &[u8], password: &str) -> crate::Result<Identity> { pub fn from_pkcs12_der(der: &[u8], password: &str) -> crate::Result<Identity> {
Ok(Identity { Ok(Identity {
inner: ClientCert::Pkcs12(try_!(native_tls::Identity::from_pkcs12(der, password))), inner: ClientCert::Pkcs12(
native_tls::Identity::from_pkcs12(der, password).map_err(crate::error::from)?,
),
}) })
} }
@@ -176,10 +178,11 @@ impl Identity {
let (key, certs) = { let (key, certs) = {
let mut pem = Cursor::new(buf); let mut pem = Cursor::new(buf);
let certs = try_!(pemfile::certs(&mut pem) let certs = pemfile::certs(&mut pem)
.map_err(|_| TLSError::General(String::from("No valid certificate was found")))); .map_err(|_| TLSError::General(String::from("No valid certificate was found")))
.map_err(crate::error::from)?;
pem.set_position(0); pem.set_position(0);
let mut sk = try_!(pemfile::pkcs8_private_keys(&mut pem) let mut sk = pemfile::pkcs8_private_keys(&mut pem)
.and_then(|pkcs8_keys| { .and_then(|pkcs8_keys| {
if pkcs8_keys.is_empty() { if pkcs8_keys.is_empty() {
Err(()) Err(())
@@ -191,7 +194,8 @@ impl Identity {
pem.set_position(0); pem.set_position(0);
pemfile::rsa_private_keys(&mut pem) pemfile::rsa_private_keys(&mut pem)
}) })
.map_err(|_| TLSError::General(String::from("No valid private key was found")))); .map_err(|_| TLSError::General(String::from("No valid private key was found")))
.map_err(crate::error::from)?;
if let (Some(sk), false) = (sk.pop(), certs.is_empty()) { if let (Some(sk), false) = (sk.pop(), certs.is_empty()) {
(sk, certs) (sk, certs)
} else { } else {

View File

@@ -1,26 +1,53 @@
use std::future::Future;
use std::sync::Arc; use std::sync::Arc;
use std::thread; use std::task::{Context, Poll};
use std::time::{Duration, Instant}; use std::time::Duration;
use futures::executor::{self, Notify}; use tokio::clock;
use futures::{Async, Future, Poll, Stream}; use tokio_executor::{
use tokio_executor::{enter, EnterError}; enter,
park::{Park, ParkThread, Unpark, UnparkThread},
EnterError,
};
pub(crate) fn timeout<F>(fut: F, timeout: Option<Duration>) -> Result<F::Item, Waited<F::Error>> pub(crate) fn timeout<F, I, E>(fut: F, timeout: Option<Duration>) -> Result<I, Waited<E>>
where where
F: Future, F: Future<Output = Result<I, E>>,
{ {
let mut spawn = executor::spawn(fut); let _entered = enter().map_err(Waited::Executor)?;
block_on(timeout, |notify| spawn.poll_future_notify(notify, 0)) let deadline = timeout.map(|d| {
log::trace!("wait at most {:?}", d);
clock::now() + d
});
let mut park = ParkThread::new();
// Arc shouldn't be necessary, since UnparkThread is reference counted internally,
// but let's just stay safe for now.
let waker = futures::task::waker(Arc::new(UnparkWaker(park.unpark())));
let mut cx = Context::from_waker(&waker);
futures::pin_mut!(fut);
loop {
match fut.as_mut().poll(&mut cx) {
Poll::Ready(Ok(val)) => return Ok(val),
Poll::Ready(Err(err)) => return Err(Waited::Inner(err)),
Poll::Pending => (), // fallthrough
} }
pub(crate) fn stream<S>(stream: S, timeout: Option<Duration>) -> WaitStream<S> if let Some(deadline) = deadline {
where let now = clock::now();
S: Stream, if now >= deadline {
{ log::trace!("wait timeout exceeded");
WaitStream { return Err(Waited::TimedOut);
stream: executor::spawn(stream), }
timeout,
log::trace!("park timeout {:?}", deadline - now);
park.park_timeout(deadline - now)
.expect("ParkThread doesn't error");
} else {
park.park().expect("ParkThread doesn't error");
}
} }
} }
@@ -31,71 +58,10 @@ pub(crate) enum Waited<E> {
Inner(E), Inner(E),
} }
impl<E> From<E> for Waited<E> { struct UnparkWaker(UnparkThread);
fn from(err: E) -> Waited<E> {
Waited::Inner(err)
}
}
pub(crate) struct WaitStream<S> { impl futures::task::ArcWake for UnparkWaker {
stream: executor::Spawn<S>, fn wake_by_ref(arc_self: &Arc<Self>) {
timeout: Option<Duration>, arc_self.0.unpark();
}
impl<S> Iterator for WaitStream<S>
where
S: Stream,
{
type Item = Result<S::Item, Waited<S::Error>>;
fn next(&mut self) -> Option<Self::Item> {
let res = block_on(self.timeout, |notify| {
self.stream.poll_stream_notify(notify, 0)
});
match res {
Ok(Some(val)) => Some(Ok(val)),
Ok(None) => None,
Err(err) => Some(Err(err)),
}
}
}
struct ThreadNotify {
thread: thread::Thread,
}
impl Notify for ThreadNotify {
fn notify(&self, _id: usize) {
self.thread.unpark();
}
}
fn block_on<F, U, E>(timeout: Option<Duration>, mut poll: F) -> Result<U, Waited<E>>
where
F: FnMut(&Arc<ThreadNotify>) -> Poll<U, E>,
{
let _entered = enter().map_err(Waited::Executor)?;
let deadline = timeout.map(|d| Instant::now() + d);
let notify = Arc::new(ThreadNotify {
thread: thread::current(),
});
loop {
match poll(&notify)? {
Async::Ready(val) => return Ok(val),
Async::NotReady => {}
}
if let Some(deadline) = deadline {
let now = Instant::now();
if now >= deadline {
return Err(Waited::TimedOut);
}
thread::park_timeout(deadline - now);
} else {
thread::park();
}
} }
} }

View File

@@ -1,29 +1,28 @@
#[macro_use] #[macro_use]
mod support; mod support;
use std::io::{self, Write}; use std::io::Write;
use std::time::Duration; use std::time::Duration;
use futures::{Future, Stream}; use futures::TryStreamExt;
use tokio::runtime::current_thread::Runtime;
use reqwest::r#async::multipart::{Form, Part}; use reqwest::r#async::multipart::{Form, Part};
use reqwest::r#async::{Chunk, Client}; use reqwest::r#async::{Body, Client};
use bytes::Bytes; use bytes::Bytes;
#[test] #[tokio::test]
fn gzip_response() { async fn gzip_response() {
gzip_case(10_000, 4096); gzip_case(10_000, 4096).await;
} }
#[test] #[tokio::test]
fn gzip_single_byte_chunks() { async fn gzip_single_byte_chunks() {
gzip_case(10, 1); gzip_case(10, 1).await;
} }
#[test] #[tokio::test]
fn response_text() { async fn response_text() {
let _ = env_logger::try_init(); let _ = env_logger::try_init();
let server = server! { let server = server! {
@@ -43,24 +42,19 @@ fn response_text() {
" "
}; };
let mut rt = Runtime::new().expect("new rt");
let client = Client::new(); let client = Client::new();
let res_future = client let mut res = client
.get(&format!("http://{}/text", server.addr())) .get(&format!("http://{}/text", server.addr()))
.send() .send()
.and_then(|mut res| res.text()) .await
.and_then(|text| { .expect("Failed to get");
let text = res.text().await.expect("Failed to get text");
assert_eq!("Hello", text); assert_eq!("Hello", text);
Ok(())
});
rt.block_on(res_future).unwrap();
} }
#[test] #[tokio::test]
fn response_json() { async fn response_json() {
let _ = env_logger::try_init(); let _ = env_logger::try_init();
let server = server! { let server = server! {
@@ -80,28 +74,24 @@ fn response_json() {
" "
}; };
let mut rt = Runtime::new().expect("new rt");
let client = Client::new(); let client = Client::new();
let res_future = client let mut res = client
.get(&format!("http://{}/json", server.addr())) .get(&format!("http://{}/json", server.addr()))
.send() .send()
.and_then(|mut res| res.json::<String>()) .await
.and_then(|text| { .expect("Failed to get");
let text = res.json::<String>().await.expect("Failed to get json");
assert_eq!("Hello", text); assert_eq!("Hello", text);
Ok(())
});
rt.block_on(res_future).unwrap();
} }
#[test] #[tokio::test]
fn multipart() { async fn multipart() {
let _ = env_logger::try_init(); let _ = env_logger::try_init();
let stream = let stream = futures::stream::once(futures::future::ready::<Result<_, hyper::Error>>(Ok(
futures::stream::once::<_, hyper::Error>(Ok(Chunk::from("part1 part2".to_owned()))); hyper::Chunk::from("part1 part2".to_owned()),
)));
let part = Part::stream(stream); let part = Part::stream(stream);
let form = Form::new().text("foo", "bar").part("part_stream", part); let form = Form::new().text("foo", "bar").part("part_stream", part);
@@ -153,22 +143,20 @@ fn multipart() {
let url = format!("http://{}/multipart/1", server.addr()); let url = format!("http://{}/multipart/1", server.addr());
let mut rt = Runtime::new().expect("new rt");
let client = Client::new(); let client = Client::new();
let res_future = client.post(&url).multipart(form).send().and_then(|res| { let res = client
.post(&url)
.multipart(form)
.send()
.await
.expect("Failed to post multipart");
assert_eq!(res.url().as_str(), &url); assert_eq!(res.url().as_str(), &url);
assert_eq!(res.status(), reqwest::StatusCode::OK); assert_eq!(res.status(), reqwest::StatusCode::OK);
Ok(())
});
rt.block_on(res_future).unwrap();
} }
#[test] #[tokio::test]
fn request_timeout() { async fn request_timeout() {
let _ = env_logger::try_init(); let _ = env_logger::try_init();
let server = server! { let server = server! {
@@ -189,24 +177,23 @@ fn request_timeout() {
read_timeout: Duration::from_secs(2) read_timeout: Duration::from_secs(2)
}; };
let mut rt = Runtime::new().expect("new rt");
let client = Client::builder() let client = Client::builder()
.timeout(Duration::from_millis(500)) .timeout(Duration::from_millis(500))
.build() .build()
.unwrap(); .unwrap();
let url = format!("http://{}/slow", server.addr()); let url = format!("http://{}/slow", server.addr());
let fut = client.get(&url).send();
let err = rt.block_on(fut).unwrap_err(); let res = client.get(&url).send().await;
let err = res.unwrap_err();
assert!(err.is_timeout()); assert!(err.is_timeout());
assert_eq!(err.url().map(|u| u.as_str()), Some(url.as_str())); assert_eq!(err.url().map(|u| u.as_str()), Some(url.as_str()));
} }
#[test] #[tokio::test]
fn response_timeout() { async fn response_timeout() {
let _ = env_logger::try_init(); let _ = env_logger::try_init();
let server = server! { let server = server! {
@@ -227,25 +214,21 @@ fn response_timeout() {
write_timeout: Duration::from_secs(2) write_timeout: Duration::from_secs(2)
}; };
let mut rt = Runtime::new().expect("new rt");
let client = Client::builder() let client = Client::builder()
.timeout(Duration::from_millis(500)) .timeout(Duration::from_millis(500))
.build() .build()
.unwrap(); .unwrap();
let url = format!("http://{}/slow", server.addr()); let url = format!("http://{}/slow", server.addr());
let fut = client let res = client.get(&url).send().await.expect("Failed to get");
.get(&url) let body: Result<_, _> = res.into_body().try_concat().await;
.send()
.and_then(|res| res.into_body().concat2());
let err = rt.block_on(fut).unwrap_err(); let err = body.unwrap_err();
assert!(err.is_timeout()); assert!(err.is_timeout());
} }
fn gzip_case(response_size: usize, chunk_size: usize) { async fn gzip_case(response_size: usize, chunk_size: usize) {
let content: String = (0..response_size) let content: String = (0..response_size)
.into_iter() .into_iter()
.map(|i| format!("test {}", i)) .map(|i| format!("test {}", i))
@@ -284,37 +267,26 @@ fn gzip_case(response_size: usize, chunk_size: usize) {
response: response response: response
}; };
let mut rt = Runtime::new().expect("new rt");
let client = Client::new(); let client = Client::new();
let res_future = client let mut res = client
.get(&format!("http://{}/gzip", server.addr())) .get(&format!("http://{}/gzip", server.addr()))
.send() .send()
.and_then(|res| { .await
let body = res.into_body(); .expect("response");
body.concat2()
})
.and_then(|buf| {
let body = std::str::from_utf8(&buf).unwrap();
assert_eq!(body, &content); let body = res.text().await.expect("text");
assert_eq!(body, content);
Ok(())
});
rt.block_on(res_future).unwrap();
} }
#[test] #[tokio::test]
fn body_stream() { async fn body_stream() {
let _ = env_logger::try_init(); let _ = env_logger::try_init();
let source: Box<dyn Stream<Item = Bytes, Error = io::Error> + Send> = let source = futures::stream::iter::<Vec<Result<Bytes, std::io::Error>>>(vec![
Box::new(futures::stream::iter_ok::<_, io::Error>(vec![ Ok(Bytes::from_static(b"123")),
Bytes::from_static(b"123"), Ok(Bytes::from_static(b"4567")),
Bytes::from_static(b"4567"), ]);
]));
let expected_body = "3\r\n123\r\n4\r\n4567\r\n0\r\n\r\n"; let expected_body = "3\r\n123\r\n4\r\n4567\r\n0\r\n\r\n";
@@ -339,16 +311,15 @@ fn body_stream() {
let url = format!("http://{}/post", server.addr()); let url = format!("http://{}/post", server.addr());
let mut rt = Runtime::new().expect("new rt");
let client = Client::new(); let client = Client::new();
let res_future = client.post(&url).body(source).send().and_then(|res| { let res = client
.post(&url)
.body(Body::wrap_stream(source))
.send()
.await
.expect("Failed to post");
assert_eq!(res.url().as_str(), &url); assert_eq!(res.url().as_str(), &url);
assert_eq!(res.status(), reqwest::StatusCode::OK); assert_eq!(res.status(), reqwest::StatusCode::OK);
Ok(())
});
rt.block_on(res_future).unwrap();
} }

View File

@@ -1,8 +1,6 @@
#[macro_use] #[macro_use]
mod support; mod support;
use std::io::Read;
#[test] #[test]
fn test_response_text() { fn test_response_text() {
let server = server! { let server = server! {
@@ -137,9 +135,7 @@ fn test_response_copy_to() {
&"5" &"5"
); );
let mut buf: Vec<u8> = vec![]; assert_eq!("Hello".to_owned(), res.text().unwrap());
res.copy_to(&mut buf).unwrap();
assert_eq!(b"Hello", buf.as_slice());
} }
#[test] #[test]
@@ -173,9 +169,7 @@ fn test_get() {
); );
assert_eq!(res.remote_addr(), Some(server.addr())); assert_eq!(res.remote_addr(), Some(server.addr()));
let mut buf = [0; 1024]; assert_eq!(res.text().unwrap().len(), 0)
let n = res.read(&mut buf).unwrap();
assert_eq!(n, 0)
} }
#[test] #[test]
@@ -214,9 +208,7 @@ fn test_post() {
&"0" &"0"
); );
let mut buf = [0; 1024]; assert_eq!(res.text().unwrap().len(), 0)
let n = res.read(&mut buf).unwrap();
assert_eq!(n, 0)
} }
#[test] #[test]

View File

@@ -1,9 +1,10 @@
#[macro_use] #[macro_use]
mod support; mod support;
use std::io::Read;
use std::time::Duration; use std::time::Duration;
/// Tests that internal client future cancels when the oneshot channel
/// is canceled.
#[test] #[test]
fn timeout_closes_connection() { fn timeout_closes_connection() {
let _ = env_logger::try_init(); let _ = env_logger::try_init();
@@ -156,7 +157,6 @@ fn test_read_timeout() {
&"5" &"5"
); );
let mut buf = [0; 1024]; let err = res.text().unwrap_err();
let err = res.read(&mut buf).unwrap_err();
assert_eq!(err.to_string(), "timed out"); assert_eq!(err.to_string(), "timed out");
} }