From f6ce08545700d74d9235e2d264e207e47c516c29 Mon Sep 17 00:00:00 2001 From: Sean McArthur Date: Mon, 22 Apr 2019 15:24:35 -0700 Subject: [PATCH] Propagate async timeout to response body (#503) --- src/async_impl/body.rs | 84 +++++++++++++++++++++++--------------- src/async_impl/client.rs | 4 +- src/async_impl/response.rs | 8 ++-- tests/async.rs | 40 ++++++++++++++++++ 4 files changed, 97 insertions(+), 39 deletions(-) diff --git a/src/async_impl/body.rs b/src/async_impl/body.rs index 8e7f1ae..1bc256d 100644 --- a/src/async_impl/body.rs +++ b/src/async_impl/body.rs @@ -1,8 +1,9 @@ -use std::{fmt, mem}; +use std::fmt; -use futures::{Stream, Poll, Async}; +use futures::{Future, Stream, Poll, Async}; use bytes::{Buf, Bytes}; use hyper::body::Payload; +use tokio::timer::Delay; /// An asynchronous `Stream`. pub struct Body { @@ -11,48 +12,43 @@ pub struct Body { enum Inner { Reusable(Bytes), - Hyper(::hyper::Body), + Hyper { + body: ::hyper::Body, + timeout: Option, + } } impl Body { - fn poll_inner(&mut self) -> &mut ::hyper::Body { - match self.inner { - Inner::Hyper(ref mut body) => return body, - Inner::Reusable(_) => (), - } - - let bytes = match mem::replace(&mut self.inner, Inner::Reusable(Bytes::new())) { - Inner::Reusable(bytes) => bytes, - Inner::Hyper(_) => unreachable!(), - }; - - self.inner = Inner::Hyper(bytes.into()); - - match self.inner { - Inner::Hyper(ref mut body) => return body, - Inner::Reusable(_) => unreachable!(), - } - } - pub(crate) fn content_length(&self) -> Option { match self.inner { Inner::Reusable(ref bytes) => Some(bytes.len() as u64), - Inner::Hyper(ref body) => body.content_length(), + Inner::Hyper { ref body, .. } => body.content_length(), + } + } + + #[inline] + pub(crate) fn response(body: ::hyper::Body, timeout: Option) -> Body { + Body { + inner: Inner::Hyper { + body, + timeout, + }, } } #[inline] pub(crate) fn wrap(body: ::hyper::Body) -> Body { Body { - inner: Inner::Hyper(body), + inner: Inner::Hyper { + body, + timeout: None, + }, } } #[inline] pub(crate) fn empty() -> Body { - Body { - inner: Inner::Hyper(::hyper::Body::empty()), - } + Body::wrap(::hyper::Body::empty()) } #[inline] @@ -66,7 +62,10 @@ impl Body { pub(crate) fn into_hyper(self) -> (Option, ::hyper::Body) { match self.inner { Inner::Reusable(chunk) => (Some(chunk.clone()), chunk.into()), - Inner::Hyper(b) => (None, b), + Inner::Hyper { body, timeout } => { + debug_assert!(timeout.is_none()); + (None, body) + }, } } } @@ -77,12 +76,29 @@ impl Stream for Body { #[inline] fn poll(&mut self) -> Poll, Self::Error> { - match try_!(self.poll_inner().poll()) { - Async::Ready(opt) => Ok(Async::Ready(opt.map(|chunk| Chunk { - inner: chunk, - }))), - Async::NotReady => Ok(Async::NotReady), - } + let opt = match self.inner { + Inner::Hyper { ref mut body, ref mut timeout } => { + if let Some(ref mut timeout) = timeout { + if let Async::Ready(()) = try_!(timeout.poll()) { + return Err(::error::timedout(None)); + } + } + try_ready!(body.poll_data().map_err(::error::from)) + }, + Inner::Reusable(ref mut bytes) => { + return if bytes.is_empty() { + Ok(Async::Ready(None)) + } else { + let chunk = Chunk::from_chunk(bytes.clone()); + *bytes = Bytes::new(); + Ok(Async::Ready(Some(chunk))) + }; + }, + }; + + Ok(Async::Ready(opt.map(|chunk| Chunk { + inner: chunk, + }))) } } diff --git a/src/async_impl/client.rs b/src/async_impl/client.rs index f9a4ce2..4b0feaf 100644 --- a/src/async_impl/client.rs +++ b/src/async_impl/client.rs @@ -346,7 +346,7 @@ impl ClientBuilder { /// Enables a request timeout. /// /// The timeout is applied from the when the request starts connecting - /// until the response headers are received. Bodies are not affected. + /// until the response body has finished. /// /// Default is no timeout. pub fn timeout(mut self, timeout: Duration) -> ClientBuilder { @@ -839,7 +839,7 @@ impl Future for PendingRequest { } } } - let res = Response::new(res, self.url.clone(), self.client.gzip); + let res = Response::new(res, self.url.clone(), self.client.gzip, self.timeout.take()); return Ok(Async::Ready(res)); } } diff --git a/src/async_impl/response.rs b/src/async_impl/response.rs index 368b6a9..23c37da 100644 --- a/src/async_impl/response.rs +++ b/src/async_impl/response.rs @@ -5,13 +5,15 @@ use std::net::SocketAddr; use futures::{Async, Future, Poll, Stream}; use futures::stream::Concat2; +use http; use hyper::{HeaderMap, StatusCode, Version}; use hyper::client::connect::HttpInfo; use hyper::header::{CONTENT_LENGTH}; +use tokio::timer::Delay; use serde::de::DeserializeOwned; use serde_json; use url::Url; -use http; + use cookie; use super::Decoder; @@ -31,14 +33,14 @@ pub struct Response { } impl Response { - pub(super) fn new(res: ::hyper::Response<::hyper::Body>, url: Url, gzip: bool) -> Response { + pub(super) fn new(res: ::hyper::Response<::hyper::Body>, url: Url, gzip: bool, timeout: Option) -> Response { let (parts, body) = res.into_parts(); let status = parts.status; let version = parts.version; let extensions = parts.extensions; let mut headers = parts.headers; - let decoder = Decoder::detect(&mut headers, Body::wrap(body), gzip); + let decoder = Decoder::detect(&mut headers, Body::response(body, timeout), gzip); debug!("Response: '{}' for {}", status, url); Response { diff --git a/tests/async.rs b/tests/async.rs index e944606..4a50662 100644 --- a/tests/async.rs +++ b/tests/async.rs @@ -138,6 +138,46 @@ fn request_timeout() { assert_eq!(err.url().map(|u| u.as_str()), Some(url.as_str())); } +#[test] +fn response_timeout() { + let _ = env_logger::try_init(); + + let server = server! { + request: b"\ + GET /slow HTTP/1.1\r\n\ + user-agent: $USERAGENT\r\n\ + accept: */*\r\n\ + accept-encoding: gzip\r\n\ + host: $HOST\r\n\ + \r\n\ + ", + response: b"\ + HTTP/1.1 200 OK\r\n\ + Content-Length: 5\r\n\ + \r\n\ + Hello\ + ", + write_timeout: Duration::from_secs(2) + }; + + let mut rt = Runtime::new().expect("new rt"); + + let client = Client::builder() + .timeout(Duration::from_millis(500)) + .build() + .unwrap(); + + let url = format!("http://{}/slow", server.addr()); + let fut = client + .get(&url) + .send() + .and_then(|res| res.into_body().concat2()); + + let err = rt.block_on(fut).unwrap_err(); + + assert!(err.is_timeout()); +} + fn gzip_case(response_size: usize, chunk_size: usize) { let content: String = (0..response_size).into_iter().map(|i| format!("test {}", i)).collect(); let mut encoder = ::libflate::gzip::Encoder::new(Vec::new()).unwrap();