From ce51fe83d6ee07b9e6c3069a2622fcd4251c1b83 Mon Sep 17 00:00:00 2001 From: Sean McArthur Date: Mon, 22 Apr 2019 12:43:30 -0700 Subject: [PATCH] Add request timeout support to async Client (#501) Closes #496 --- src/async_impl/client.rs | 33 +++++++++++++++----- src/error.rs | 15 ++++++++-- tests/async.rs | 65 +++++++++++++++++++++++++++++++++------- 3 files changed, 93 insertions(+), 20 deletions(-) diff --git a/src/async_impl/client.rs b/src/async_impl/client.rs index 7703715..f9a4ce2 100644 --- a/src/async_impl/client.rs +++ b/src/async_impl/client.rs @@ -26,6 +26,7 @@ use hyper::client::ResponseFuture; use mime; #[cfg(feature = "default-tls")] use native_tls::TlsConnector; +use tokio::{clock, timer::Delay}; use super::request::{Request, RequestBuilder}; @@ -212,14 +213,15 @@ impl ClientBuilder { Ok(Client { inner: Arc::new(ClientRef { + cookie_store, gzip: config.gzip, hyper: hyper_client, headers: config.headers, redirect_policy: config.redirect_policy, referer: config.referer, + request_timeout: config.timeout, proxies, proxies_maybe_http_auth, - cookie_store, }), }) } @@ -341,16 +343,20 @@ impl ClientBuilder { self } - // Currently not used, so hide from docs. - #[doc(hidden)] + /// Enables a request timeout. + /// + /// The timeout is applied from the when the request starts connecting + /// until the response headers are received. Bodies are not affected. + /// + /// Default is no timeout. pub fn timeout(mut self, timeout: Duration) -> ClientBuilder { self.config.timeout = Some(timeout); self } /// Sets the maximum idle connection per host allowed in the pool. - // - // Default is usize::MAX (no limit). + /// + /// Default is usize::MAX (no limit). pub fn max_idle_per_host(mut self, max: usize) -> ClientBuilder { self.config.max_idle_per_host = max; self @@ -585,6 +591,10 @@ impl Client { let in_flight = self.inner.hyper.request(req); + let timeout = self.inner.request_timeout.map(|dur| { + Delay::new(clock::now() + dur) + }); + Pending { inner: PendingInner::Request(PendingRequest { method: method, @@ -597,6 +607,7 @@ impl Client { client: self.inner.clone(), in_flight: in_flight, + timeout, }), } } @@ -654,17 +665,18 @@ impl fmt::Debug for ClientBuilder { } struct ClientRef { + cookie_store: Option>, gzip: bool, headers: HeaderMap, hyper: HyperClient, redirect_policy: RedirectPolicy, referer: bool, + request_timeout: Option, proxies: Arc>, proxies_maybe_http_auth: bool, - cookie_store: Option>, } -pub struct Pending { +pub(super) struct Pending { inner: PendingInner, } @@ -684,6 +696,7 @@ struct PendingRequest { client: Arc, in_flight: ResponseFuture, + timeout: Option, } impl Pending { @@ -711,6 +724,12 @@ impl Future for PendingRequest { type Error = ::Error; fn poll(&mut self) -> Poll { + if let Some(ref mut delay) = self.timeout { + if let Async::Ready(()) = try_!(delay.poll(), &self.url) { + return Err(::error::timedout(Some(self.url.clone()))); + } + } + loop { let res = match try_!(self.in_flight.poll(), &self.url) { Async::Ready(res) => res, diff --git a/src/error.rs b/src/error.rs index b7e88f2..a9241e6 100644 --- a/src/error.rs +++ b/src/error.rs @@ -150,7 +150,8 @@ impl Error { Kind::RedirectLoop | Kind::ClientError(_) | Kind::ServerError(_) | - Kind::UnknownProxyScheme => None, + Kind::UnknownProxyScheme | + Kind::Timer => None, } } @@ -275,6 +276,7 @@ impl fmt::Display for Error { fmt::Display::fmt(code, f) } Kind::UnknownProxyScheme => f.write_str("Unknown proxy scheme"), + Kind::Timer => f.write_str("timer unavailable"), } } } @@ -303,6 +305,7 @@ impl StdError for Error { Kind::ClientError(_) => "Client Error", Kind::ServerError(_) => "Server Error", Kind::UnknownProxyScheme => "Unknown proxy scheme", + Kind::Timer => "timer unavailable", } } @@ -329,7 +332,8 @@ impl StdError for Error { Kind::RedirectLoop | Kind::ClientError(_) | Kind::ServerError(_) | - Kind::UnknownProxyScheme => None, + Kind::UnknownProxyScheme | + Kind::Timer => None, } } } @@ -357,6 +361,7 @@ pub(crate) enum Kind { ClientError(StatusCode), ServerError(StatusCode), UnknownProxyScheme, + Timer, } @@ -433,6 +438,12 @@ where T: Into { } } +impl From<::tokio::timer::Error> for Kind { + fn from(_err: ::tokio::timer::Error) -> Kind { + Kind::Timer + } +} + fn io_timeout() -> io::Error { io::Error::new(io::ErrorKind::TimedOut, "timed out") } diff --git a/tests/async.rs b/tests/async.rs index 7c30e12..e944606 100644 --- a/tests/async.rs +++ b/tests/async.rs @@ -7,24 +7,27 @@ extern crate tokio; #[macro_use] mod support; -use reqwest::async::Client; -use reqwest::async::multipart::{Form, Part}; -use futures::{Future, Stream}; use std::io::Write; use std::time::Duration; +use futures::{Future, Stream}; +use tokio::runtime::current_thread::Runtime; + +use reqwest::async::Client; +use reqwest::async::multipart::{Form, Part}; + #[test] -fn async_test_gzip_response() { - test_gzip(10_000, 4096); +fn gzip_response() { + gzip_case(10_000, 4096); } #[test] -fn async_test_gzip_single_byte_chunks() { - test_gzip(10, 1); +fn gzip_single_byte_chunks() { + gzip_case(10, 1); } #[test] -fn async_test_multipart() { +fn multipart() { let _ = env_logger::try_init(); let stream = futures::stream::once::<_, hyper::Error>(Ok(hyper::Chunk::from("part1 part2".to_owned()))); @@ -78,7 +81,7 @@ fn async_test_multipart() { let url = format!("http://{}/multipart/1", server.addr()); - let mut rt = tokio::runtime::current_thread::Runtime::new().expect("new rt"); + let mut rt = Runtime::new().expect("new rt"); let client = Client::new(); @@ -95,7 +98,47 @@ fn async_test_multipart() { rt.block_on(res_future).unwrap(); } -fn test_gzip(response_size: usize, chunk_size: usize) { +#[test] +fn request_timeout() { + let _ = env_logger::try_init(); + + let server = server! { + request: b"\ + GET /slow HTTP/1.1\r\n\ + user-agent: $USERAGENT\r\n\ + accept: */*\r\n\ + accept-encoding: gzip\r\n\ + host: $HOST\r\n\ + \r\n\ + ", + response: b"\ + HTTP/1.1 200 OK\r\n\ + Content-Length: 5\r\n\ + \r\n\ + Hello\ + ", + read_timeout: Duration::from_secs(2) + }; + + let mut rt = Runtime::new().expect("new rt"); + + let client = Client::builder() + .timeout(Duration::from_millis(500)) + .build() + .unwrap(); + + let url = format!("http://{}/slow", server.addr()); + let fut = client + .get(&url) + .send(); + + let err = rt.block_on(fut).unwrap_err(); + + assert!(err.is_timeout()); + assert_eq!(err.url().map(|u| u.as_str()), Some(url.as_str())); +} + +fn gzip_case(response_size: usize, chunk_size: usize) { let content: String = (0..response_size).into_iter().map(|i| format!("test {}", i)).collect(); let mut encoder = ::libflate::gzip::Encoder::new(Vec::new()).unwrap(); match encoder.write(content.as_bytes()) { @@ -128,7 +171,7 @@ fn test_gzip(response_size: usize, chunk_size: usize) { response: response }; - let mut rt = tokio::runtime::current_thread::Runtime::new().expect("new rt"); + let mut rt = Runtime::new().expect("new rt"); let client = Client::new();