ensure async request is canceled if there is a timeout
This commit is contained in:
@@ -3,7 +3,7 @@ use std::sync::Arc;
|
||||
use std::time::Duration;
|
||||
use std::thread;
|
||||
|
||||
use futures::{Future, Stream};
|
||||
use futures::{Async, Future, Stream};
|
||||
use futures::future::{self, Either};
|
||||
use futures::sync::{mpsc, oneshot};
|
||||
|
||||
@@ -437,9 +437,42 @@ impl ClientHandle {
|
||||
};
|
||||
|
||||
let work = rx.for_each(move |(req, tx)| {
|
||||
/*
|
||||
let tx: oneshot::Sender<::Result<async_impl::Response>> = tx;
|
||||
let task = client.execute(req)
|
||||
.then(move |x| tx.send(x).map_err(|_| ()));
|
||||
.then(move |r| {
|
||||
trace!("result received: {:?}", r);
|
||||
tx.send(r).map_err(|_| ())
|
||||
});
|
||||
*/
|
||||
let mut tx_opt: Option<oneshot::Sender<::Result<async_impl::Response>>> = Some(tx);
|
||||
let mut res_fut = client.execute(req);
|
||||
|
||||
let task = future::poll_fn(move || {
|
||||
let canceled = tx_opt
|
||||
.as_mut()
|
||||
.expect("polled after complete")
|
||||
.poll_cancel()
|
||||
.expect("poll_cancel cannot error")
|
||||
.is_ready();
|
||||
|
||||
if canceled {
|
||||
trace!("response receiver is canceled");
|
||||
Ok(Async::Ready(()))
|
||||
} else {
|
||||
let result = match res_fut.poll() {
|
||||
Ok(Async::NotReady) => return Ok(Async::NotReady),
|
||||
Ok(Async::Ready(res)) => Ok(res),
|
||||
Err(err) => Err(err),
|
||||
};
|
||||
|
||||
let _ = tx_opt
|
||||
.take()
|
||||
.expect("polled after complete")
|
||||
.send(result);
|
||||
Ok(Async::Ready(()))
|
||||
}
|
||||
});
|
||||
::tokio::spawn(task);
|
||||
Ok(())
|
||||
});
|
||||
|
||||
@@ -3,10 +3,12 @@
|
||||
use std::io::{Read, Write};
|
||||
use std::net;
|
||||
use std::time::Duration;
|
||||
use std::sync::mpsc;
|
||||
use std::thread;
|
||||
|
||||
pub struct Server {
|
||||
addr: net::SocketAddr,
|
||||
panic_rx: mpsc::Receiver<()>,
|
||||
}
|
||||
|
||||
impl Server {
|
||||
@@ -15,12 +17,24 @@ impl Server {
|
||||
}
|
||||
}
|
||||
|
||||
impl Drop for Server {
|
||||
fn drop(&mut self) {
|
||||
if !::std::thread::panicking() {
|
||||
self
|
||||
.panic_rx
|
||||
.recv_timeout(Duration::from_secs(3))
|
||||
.expect("test server should not panic");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Default)]
|
||||
pub struct Txn {
|
||||
pub request: Vec<u8>,
|
||||
pub response: Vec<u8>,
|
||||
|
||||
pub read_timeout: Option<Duration>,
|
||||
pub read_closes: bool,
|
||||
pub response_timeout: Option<Duration>,
|
||||
pub write_timeout: Option<Duration>,
|
||||
pub chunk_size: Option<usize>,
|
||||
@@ -32,9 +46,10 @@ static DEFAULT_USER_AGENT: &'static str =
|
||||
pub fn spawn(txns: Vec<Txn>) -> Server {
|
||||
let listener = net::TcpListener::bind("127.0.0.1:0").unwrap();
|
||||
let addr = listener.local_addr().unwrap();
|
||||
let (panic_tx, panic_rx) = mpsc::channel();
|
||||
let tname = format!("test({})-support-server", thread::current().name().unwrap_or("<unknown>"));
|
||||
thread::Builder::new().name(tname).spawn(
|
||||
move || for txn in txns {
|
||||
thread::Builder::new().name(tname).spawn(move || {
|
||||
'txns: for txn in txns {
|
||||
let mut expected = txn.request;
|
||||
let reply = txn.response;
|
||||
let (mut socket, _addr) = listener.accept().unwrap();
|
||||
@@ -52,7 +67,13 @@ pub fn spawn(txns: Vec<Txn>) -> Server {
|
||||
let mut n = 0;
|
||||
while n < expected.len() {
|
||||
match socket.read(&mut buf[n..]) {
|
||||
Ok(0) => break,
|
||||
Ok(0) => {
|
||||
if !txn.read_closes {
|
||||
panic!("server unexpected socket closed");
|
||||
} else {
|
||||
continue 'txns;
|
||||
}
|
||||
},
|
||||
Ok(nread) => n += nread,
|
||||
Err(err) => {
|
||||
println!("server read error: {}", err);
|
||||
@@ -61,6 +82,21 @@ pub fn spawn(txns: Vec<Txn>) -> Server {
|
||||
}
|
||||
}
|
||||
|
||||
if txn.read_closes {
|
||||
socket.set_read_timeout(Some(Duration::from_secs(1))).unwrap();
|
||||
match socket.read(&mut [0; 256]) {
|
||||
Ok(0) => {
|
||||
continue 'txns
|
||||
},
|
||||
Ok(_) => {
|
||||
panic!("server read expected EOF, found more bytes");
|
||||
},
|
||||
Err(err) => {
|
||||
panic!("server read expected EOF, got error: {}", err);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
match (::std::str::from_utf8(&expected), ::std::str::from_utf8(&buf[..n])) {
|
||||
(Ok(expected), Ok(received)) => {
|
||||
assert_eq!(
|
||||
@@ -108,10 +144,12 @@ pub fn spawn(txns: Vec<Txn>) -> Server {
|
||||
socket.write_all(&reply).unwrap();
|
||||
}
|
||||
}
|
||||
).expect("server thread spawn");
|
||||
let _ = panic_tx.send(());
|
||||
}).expect("server thread spawn");
|
||||
|
||||
Server {
|
||||
addr: addr,
|
||||
addr,
|
||||
panic_rx,
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -8,11 +8,58 @@ use std::io::Read;
|
||||
use std::time::Duration;
|
||||
|
||||
#[test]
|
||||
fn test_write_timeout() {
|
||||
fn timeout_closes_connection() {
|
||||
let _ = env_logger::try_init();
|
||||
|
||||
// Make Client drop *after* the Server, so the background doesn't
|
||||
// close too early.
|
||||
let client = reqwest::Client::builder()
|
||||
.timeout(Duration::from_millis(500))
|
||||
.build()
|
||||
.unwrap();
|
||||
|
||||
let server = server! {
|
||||
request: b"\
|
||||
GET /closes HTTP/1.1\r\n\
|
||||
user-agent: $USERAGENT\r\n\
|
||||
accept: */*\r\n\
|
||||
accept-encoding: gzip\r\n\
|
||||
host: $HOST\r\n\
|
||||
\r\n\
|
||||
",
|
||||
response: b"\
|
||||
HTTP/1.1 200 OK\r\n\
|
||||
Content-Length: 5\r\n\
|
||||
\r\n\
|
||||
Hello\
|
||||
",
|
||||
read_timeout: Duration::from_secs(2),
|
||||
read_closes: true
|
||||
};
|
||||
|
||||
let url = format!("http://{}/closes", server.addr());
|
||||
let err = client
|
||||
.get(&url)
|
||||
.send()
|
||||
.unwrap_err();
|
||||
|
||||
assert_eq!(err.get_ref().unwrap().to_string(), "timed out");
|
||||
assert_eq!(err.url().map(|u| u.as_str()), Some(url.as_str()));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn write_timeout_large_body() {
|
||||
let _ = env_logger::try_init();
|
||||
let body = String::from_utf8(vec![b'x'; 20_000]).unwrap();
|
||||
let len = 8192;
|
||||
|
||||
// Make Client drop *after* the Server, so the background doesn't
|
||||
// close too early.
|
||||
let client = reqwest::Client::builder()
|
||||
.timeout(Duration::from_millis(500))
|
||||
.build()
|
||||
.unwrap();
|
||||
|
||||
let server = server! {
|
||||
request: format!("\
|
||||
POST /write-timeout HTTP/1.1\r\n\
|
||||
@@ -30,17 +77,13 @@ fn test_write_timeout() {
|
||||
\r\n\
|
||||
Hello\
|
||||
",
|
||||
read_timeout: Duration::from_secs(2)
|
||||
|
||||
//response_timeout: Duration::from_secs(1)
|
||||
read_timeout: Duration::from_secs(2),
|
||||
read_closes: true
|
||||
};
|
||||
|
||||
let cursor = ::std::io::Cursor::new(body.into_bytes());
|
||||
let url = format!("http://{}/write-timeout", server.addr());
|
||||
let err = reqwest::Client::builder()
|
||||
.timeout(Duration::from_millis(500))
|
||||
.build()
|
||||
.unwrap()
|
||||
let err = client
|
||||
.post(&url)
|
||||
.body(reqwest::Body::sized(cursor, len as u64))
|
||||
.send()
|
||||
|
||||
Reference in New Issue
Block a user