From 512b80a3ada48a3c60328fc39a7b87a506a94acf Mon Sep 17 00:00:00 2001 From: Sean McArthur Date: Wed, 17 Oct 2018 13:38:54 -0700 Subject: [PATCH] ensure async request is canceled if there is a timeout --- src/client.rs | 37 ++++++++++++++++++++++++-- tests/support/server.rs | 48 +++++++++++++++++++++++++++++---- tests/timeouts.rs | 59 +++++++++++++++++++++++++++++++++++------ 3 files changed, 129 insertions(+), 15 deletions(-) diff --git a/src/client.rs b/src/client.rs index 0eccca5..31497c0 100644 --- a/src/client.rs +++ b/src/client.rs @@ -3,7 +3,7 @@ use std::sync::Arc; use std::time::Duration; use std::thread; -use futures::{Future, Stream}; +use futures::{Async, Future, Stream}; use futures::future::{self, Either}; use futures::sync::{mpsc, oneshot}; @@ -437,9 +437,42 @@ impl ClientHandle { }; let work = rx.for_each(move |(req, tx)| { + /* let tx: oneshot::Sender<::Result> = tx; let task = client.execute(req) - .then(move |x| tx.send(x).map_err(|_| ())); + .then(move |r| { + trace!("result received: {:?}", r); + tx.send(r).map_err(|_| ()) + }); + */ + let mut tx_opt: Option>> = Some(tx); + let mut res_fut = client.execute(req); + + let task = future::poll_fn(move || { + let canceled = tx_opt + .as_mut() + .expect("polled after complete") + .poll_cancel() + .expect("poll_cancel cannot error") + .is_ready(); + + if canceled { + trace!("response receiver is canceled"); + Ok(Async::Ready(())) + } else { + let result = match res_fut.poll() { + Ok(Async::NotReady) => return Ok(Async::NotReady), + Ok(Async::Ready(res)) => Ok(res), + Err(err) => Err(err), + }; + + let _ = tx_opt + .take() + .expect("polled after complete") + .send(result); + Ok(Async::Ready(())) + } + }); ::tokio::spawn(task); Ok(()) }); diff --git a/tests/support/server.rs b/tests/support/server.rs index 03fecc5..1f137ff 100644 --- a/tests/support/server.rs +++ b/tests/support/server.rs @@ -3,10 +3,12 @@ use std::io::{Read, Write}; use std::net; use std::time::Duration; +use std::sync::mpsc; use std::thread; pub struct Server { addr: net::SocketAddr, + panic_rx: mpsc::Receiver<()>, } impl Server { @@ -15,12 +17,24 @@ impl Server { } } +impl Drop for Server { + fn drop(&mut self) { + if !::std::thread::panicking() { + self + .panic_rx + .recv_timeout(Duration::from_secs(3)) + .expect("test server should not panic"); + } + } +} + #[derive(Default)] pub struct Txn { pub request: Vec, pub response: Vec, pub read_timeout: Option, + pub read_closes: bool, pub response_timeout: Option, pub write_timeout: Option, pub chunk_size: Option, @@ -32,9 +46,10 @@ static DEFAULT_USER_AGENT: &'static str = pub fn spawn(txns: Vec) -> Server { let listener = net::TcpListener::bind("127.0.0.1:0").unwrap(); let addr = listener.local_addr().unwrap(); + let (panic_tx, panic_rx) = mpsc::channel(); let tname = format!("test({})-support-server", thread::current().name().unwrap_or("")); - thread::Builder::new().name(tname).spawn( - move || for txn in txns { + thread::Builder::new().name(tname).spawn(move || { + 'txns: for txn in txns { let mut expected = txn.request; let reply = txn.response; let (mut socket, _addr) = listener.accept().unwrap(); @@ -52,7 +67,13 @@ pub fn spawn(txns: Vec) -> Server { let mut n = 0; while n < expected.len() { match socket.read(&mut buf[n..]) { - Ok(0) => break, + Ok(0) => { + if !txn.read_closes { + panic!("server unexpected socket closed"); + } else { + continue 'txns; + } + }, Ok(nread) => n += nread, Err(err) => { println!("server read error: {}", err); @@ -61,6 +82,21 @@ pub fn spawn(txns: Vec) -> Server { } } + if txn.read_closes { + socket.set_read_timeout(Some(Duration::from_secs(1))).unwrap(); + match socket.read(&mut [0; 256]) { + Ok(0) => { + continue 'txns + }, + Ok(_) => { + panic!("server read expected EOF, found more bytes"); + }, + Err(err) => { + panic!("server read expected EOF, got error: {}", err); + } + } + } + match (::std::str::from_utf8(&expected), ::std::str::from_utf8(&buf[..n])) { (Ok(expected), Ok(received)) => { assert_eq!( @@ -108,10 +144,12 @@ pub fn spawn(txns: Vec) -> Server { socket.write_all(&reply).unwrap(); } } - ).expect("server thread spawn"); + let _ = panic_tx.send(()); + }).expect("server thread spawn"); Server { - addr: addr, + addr, + panic_rx, } } diff --git a/tests/timeouts.rs b/tests/timeouts.rs index 70c8ac2..ea80da4 100644 --- a/tests/timeouts.rs +++ b/tests/timeouts.rs @@ -8,11 +8,58 @@ use std::io::Read; use std::time::Duration; #[test] -fn test_write_timeout() { +fn timeout_closes_connection() { + let _ = env_logger::try_init(); + + // Make Client drop *after* the Server, so the background doesn't + // close too early. + let client = reqwest::Client::builder() + .timeout(Duration::from_millis(500)) + .build() + .unwrap(); + + let server = server! { + request: b"\ + GET /closes HTTP/1.1\r\n\ + user-agent: $USERAGENT\r\n\ + accept: */*\r\n\ + accept-encoding: gzip\r\n\ + host: $HOST\r\n\ + \r\n\ + ", + response: b"\ + HTTP/1.1 200 OK\r\n\ + Content-Length: 5\r\n\ + \r\n\ + Hello\ + ", + read_timeout: Duration::from_secs(2), + read_closes: true + }; + + let url = format!("http://{}/closes", server.addr()); + let err = client + .get(&url) + .send() + .unwrap_err(); + + assert_eq!(err.get_ref().unwrap().to_string(), "timed out"); + assert_eq!(err.url().map(|u| u.as_str()), Some(url.as_str())); +} + +#[test] +fn write_timeout_large_body() { let _ = env_logger::try_init(); let body = String::from_utf8(vec![b'x'; 20_000]).unwrap(); let len = 8192; + // Make Client drop *after* the Server, so the background doesn't + // close too early. + let client = reqwest::Client::builder() + .timeout(Duration::from_millis(500)) + .build() + .unwrap(); + let server = server! { request: format!("\ POST /write-timeout HTTP/1.1\r\n\ @@ -30,17 +77,13 @@ fn test_write_timeout() { \r\n\ Hello\ ", - read_timeout: Duration::from_secs(2) - - //response_timeout: Duration::from_secs(1) + read_timeout: Duration::from_secs(2), + read_closes: true }; let cursor = ::std::io::Cursor::new(body.into_bytes()); let url = format!("http://{}/write-timeout", server.addr()); - let err = reqwest::Client::builder() - .timeout(Duration::from_millis(500)) - .build() - .unwrap() + let err = client .post(&url) .body(reqwest::Body::sized(cursor, len as u64)) .send()