ensure async request is canceled if there is a timeout

This commit is contained in:
Sean McArthur
2018-10-17 13:38:54 -07:00
parent a82232f0ee
commit 512b80a3ad
3 changed files with 129 additions and 15 deletions

View File

@@ -3,7 +3,7 @@ use std::sync::Arc;
use std::time::Duration;
use std::thread;
use futures::{Future, Stream};
use futures::{Async, Future, Stream};
use futures::future::{self, Either};
use futures::sync::{mpsc, oneshot};
@@ -437,9 +437,42 @@ impl ClientHandle {
};
let work = rx.for_each(move |(req, tx)| {
/*
let tx: oneshot::Sender<::Result<async_impl::Response>> = tx;
let task = client.execute(req)
.then(move |x| tx.send(x).map_err(|_| ()));
.then(move |r| {
trace!("result received: {:?}", r);
tx.send(r).map_err(|_| ())
});
*/
let mut tx_opt: Option<oneshot::Sender<::Result<async_impl::Response>>> = Some(tx);
let mut res_fut = client.execute(req);
let task = future::poll_fn(move || {
let canceled = tx_opt
.as_mut()
.expect("polled after complete")
.poll_cancel()
.expect("poll_cancel cannot error")
.is_ready();
if canceled {
trace!("response receiver is canceled");
Ok(Async::Ready(()))
} else {
let result = match res_fut.poll() {
Ok(Async::NotReady) => return Ok(Async::NotReady),
Ok(Async::Ready(res)) => Ok(res),
Err(err) => Err(err),
};
let _ = tx_opt
.take()
.expect("polled after complete")
.send(result);
Ok(Async::Ready(()))
}
});
::tokio::spawn(task);
Ok(())
});

View File

@@ -3,10 +3,12 @@
use std::io::{Read, Write};
use std::net;
use std::time::Duration;
use std::sync::mpsc;
use std::thread;
pub struct Server {
addr: net::SocketAddr,
panic_rx: mpsc::Receiver<()>,
}
impl Server {
@@ -15,12 +17,24 @@ impl Server {
}
}
impl Drop for Server {
fn drop(&mut self) {
if !::std::thread::panicking() {
self
.panic_rx
.recv_timeout(Duration::from_secs(3))
.expect("test server should not panic");
}
}
}
#[derive(Default)]
pub struct Txn {
pub request: Vec<u8>,
pub response: Vec<u8>,
pub read_timeout: Option<Duration>,
pub read_closes: bool,
pub response_timeout: Option<Duration>,
pub write_timeout: Option<Duration>,
pub chunk_size: Option<usize>,
@@ -32,9 +46,10 @@ static DEFAULT_USER_AGENT: &'static str =
pub fn spawn(txns: Vec<Txn>) -> Server {
let listener = net::TcpListener::bind("127.0.0.1:0").unwrap();
let addr = listener.local_addr().unwrap();
let (panic_tx, panic_rx) = mpsc::channel();
let tname = format!("test({})-support-server", thread::current().name().unwrap_or("<unknown>"));
thread::Builder::new().name(tname).spawn(
move || for txn in txns {
thread::Builder::new().name(tname).spawn(move || {
'txns: for txn in txns {
let mut expected = txn.request;
let reply = txn.response;
let (mut socket, _addr) = listener.accept().unwrap();
@@ -52,7 +67,13 @@ pub fn spawn(txns: Vec<Txn>) -> Server {
let mut n = 0;
while n < expected.len() {
match socket.read(&mut buf[n..]) {
Ok(0) => break,
Ok(0) => {
if !txn.read_closes {
panic!("server unexpected socket closed");
} else {
continue 'txns;
}
},
Ok(nread) => n += nread,
Err(err) => {
println!("server read error: {}", err);
@@ -61,6 +82,21 @@ pub fn spawn(txns: Vec<Txn>) -> Server {
}
}
if txn.read_closes {
socket.set_read_timeout(Some(Duration::from_secs(1))).unwrap();
match socket.read(&mut [0; 256]) {
Ok(0) => {
continue 'txns
},
Ok(_) => {
panic!("server read expected EOF, found more bytes");
},
Err(err) => {
panic!("server read expected EOF, got error: {}", err);
}
}
}
match (::std::str::from_utf8(&expected), ::std::str::from_utf8(&buf[..n])) {
(Ok(expected), Ok(received)) => {
assert_eq!(
@@ -108,10 +144,12 @@ pub fn spawn(txns: Vec<Txn>) -> Server {
socket.write_all(&reply).unwrap();
}
}
).expect("server thread spawn");
let _ = panic_tx.send(());
}).expect("server thread spawn");
Server {
addr: addr,
addr,
panic_rx,
}
}

View File

@@ -8,11 +8,58 @@ use std::io::Read;
use std::time::Duration;
#[test]
fn test_write_timeout() {
fn timeout_closes_connection() {
let _ = env_logger::try_init();
// Make Client drop *after* the Server, so the background doesn't
// close too early.
let client = reqwest::Client::builder()
.timeout(Duration::from_millis(500))
.build()
.unwrap();
let server = server! {
request: b"\
GET /closes HTTP/1.1\r\n\
user-agent: $USERAGENT\r\n\
accept: */*\r\n\
accept-encoding: gzip\r\n\
host: $HOST\r\n\
\r\n\
",
response: b"\
HTTP/1.1 200 OK\r\n\
Content-Length: 5\r\n\
\r\n\
Hello\
",
read_timeout: Duration::from_secs(2),
read_closes: true
};
let url = format!("http://{}/closes", server.addr());
let err = client
.get(&url)
.send()
.unwrap_err();
assert_eq!(err.get_ref().unwrap().to_string(), "timed out");
assert_eq!(err.url().map(|u| u.as_str()), Some(url.as_str()));
}
#[test]
fn write_timeout_large_body() {
let _ = env_logger::try_init();
let body = String::from_utf8(vec![b'x'; 20_000]).unwrap();
let len = 8192;
// Make Client drop *after* the Server, so the background doesn't
// close too early.
let client = reqwest::Client::builder()
.timeout(Duration::from_millis(500))
.build()
.unwrap();
let server = server! {
request: format!("\
POST /write-timeout HTTP/1.1\r\n\
@@ -30,17 +77,13 @@ fn test_write_timeout() {
\r\n\
Hello\
",
read_timeout: Duration::from_secs(2)
//response_timeout: Duration::from_secs(1)
read_timeout: Duration::from_secs(2),
read_closes: true
};
let cursor = ::std::io::Cursor::new(body.into_bytes());
let url = format!("http://{}/write-timeout", server.addr());
let err = reqwest::Client::builder()
.timeout(Duration::from_millis(500))
.build()
.unwrap()
let err = client
.post(&url)
.body(reqwest::Body::sized(cursor, len as u64))
.send()