Rewrite tests with a hyper server instead of raw TCP

This makes the tests much less brittle, by not depending on the exact
order of the HTTP headers, nor always requiring to check for every
single header.
This commit is contained in:
Sean McArthur
2019-09-23 11:33:04 -07:00
parent 3cf8ede960
commit f4100e4148
10 changed files with 881 additions and 1806 deletions

View File

@@ -1,14 +1,18 @@
//! A server builder helper for the integration tests.
use std::io::{Read, Write};
use std::convert::Infallible;
use std::future::Future;
use std::net;
use std::sync::mpsc;
use std::sync::mpsc as std_mpsc;
use std::thread;
use std::time::Duration;
use tokio::sync::oneshot;
pub use http::Response;
pub struct Server {
addr: net::SocketAddr,
panic_rx: mpsc::Receiver<()>,
panic_rx: std_mpsc::Receiver<()>,
shutdown_tx: Option<oneshot::Sender<()>>,
}
impl Server {
@@ -19,7 +23,11 @@ impl Server {
impl Drop for Server {
fn drop(&mut self) {
if !thread::panicking() {
if let Some(tx) = self.shutdown_tx.take() {
let _ = tx.send(());
}
if !::std::thread::panicking() {
self.panic_rx
.recv_timeout(Duration::from_secs(3))
.expect("test server should not panic");
@@ -27,197 +35,46 @@ impl Drop for Server {
}
}
#[derive(Debug, Default)]
pub struct Txn {
pub request: Vec<u8>,
pub response: Vec<u8>,
pub fn http<F, Fut>(func: F) -> Server
where
F: Fn(http::Request<hyper::Body>) -> Fut + Clone + Send + 'static,
Fut: Future<Output = http::Response<hyper::Body>> + Send + 'static,
{
let srv = hyper::Server::bind(&([127, 0, 0, 1], 0).into()).serve(
hyper::service::make_service_fn(move |_| {
let func = func.clone();
async move {
Ok::<_, Infallible>(hyper::service::service_fn(move |req| {
let fut = func(req);
async move { Ok::<_, Infallible>(fut.await) }
}))
}
}),
);
pub read_timeout: Option<Duration>,
pub read_closes: bool,
pub response_timeout: Option<Duration>,
pub write_timeout: Option<Duration>,
pub chunk_size: Option<usize>,
}
let addr = srv.local_addr();
let (shutdown_tx, shutdown_rx) = oneshot::channel();
let srv = srv.with_graceful_shutdown(async move {
let _ = shutdown_rx.await;
});
static DEFAULT_USER_AGENT: &'static str =
concat!(env!("CARGO_PKG_NAME"), "/", env!("CARGO_PKG_VERSION"));
pub fn spawn(txns: Vec<Txn>) -> Server {
let listener = net::TcpListener::bind("127.0.0.1:0").unwrap();
let addr = listener.local_addr().unwrap();
let (panic_tx, panic_rx) = mpsc::channel();
let (panic_tx, panic_rx) = std_mpsc::channel();
let tname = format!(
"test({})-support-server",
thread::current().name().unwrap_or("<unknown>")
);
thread::Builder::new().name(tname).spawn(move || {
'txns: for txn in txns {
let mut expected = txn.request;
let reply = txn.response;
let (mut socket, _addr) = listener.accept().unwrap();
thread::Builder::new()
.name(tname)
.spawn(move || {
let mut rt = tokio::runtime::current_thread::Runtime::new().expect("rt new");
rt.block_on(srv).unwrap();
let _ = panic_tx.send(());
})
.expect("thread spawn");
socket.set_read_timeout(Some(Duration::from_secs(5))).unwrap();
replace_expected_vars(&mut expected, addr.to_string().as_ref(), DEFAULT_USER_AGENT.as_ref());
if let Some(dur) = txn.read_timeout {
thread::park_timeout(dur);
}
let mut buf = vec![0; expected.len() + 256];
let mut n = 0;
while n < expected.len() {
match socket.read(&mut buf[n..]) {
Ok(0) => {
if !txn.read_closes {
panic!("server unexpected socket closed");
} else {
continue 'txns;
}
},
Ok(nread) => n += nread,
Err(err) => {
println!("server read error: {}", err);
break;
}
}
}
if txn.read_closes {
socket.set_read_timeout(Some(Duration::from_secs(1))).unwrap();
match socket.read(&mut [0; 256]) {
Ok(0) => {
continue 'txns
},
Ok(_) => {
panic!("server read expected EOF, found more bytes");
},
Err(err) => {
panic!("server read expected EOF, got error: {}", err);
}
}
}
match (std::str::from_utf8(&expected), std::str::from_utf8(&buf[..n])) {
(Ok(expected), Ok(received)) => {
if expected.len() > 300 && std::env::var("REQWEST_TEST_BODY_FULL").is_err() {
assert_eq!(
expected.len(),
received.len(),
"expected len = {}, received len = {}; to skip length check and see exact contents, re-run with REQWEST_TEST_BODY_FULL=1",
expected.len(),
received.len(),
);
}
assert_eq!(expected, received)
},
_ => {
assert_eq!(
expected.len(),
n,
"expected len = {}, received len = {}",
expected.len(),
n,
);
assert_eq!(expected, &buf[..n])
},
}
if let Some(dur) = txn.response_timeout {
thread::park_timeout(dur);
}
if let Some(dur) = txn.write_timeout {
let headers_end = b"\r\n\r\n";
let headers_end = reply.windows(headers_end.len()).position(|w| w == headers_end).unwrap() + 4;
socket.write_all(&reply[..headers_end]).unwrap();
let body = &reply[headers_end..];
if let Some(chunk_size) = txn.chunk_size {
for content in body.chunks(chunk_size) {
thread::park_timeout(dur);
socket.write_all(&content).unwrap();
}
} else {
thread::park_timeout(dur);
socket.write_all(&body).unwrap();
}
} else {
socket.write_all(&reply).unwrap();
}
}
let _ = panic_tx.send(());
}).expect("server thread spawn");
Server { addr, panic_rx }
}
fn replace_expected_vars(bytes: &mut Vec<u8>, host: &[u8], ua: &[u8]) {
// plenty horrible, but these are just tests, and gets the job done
let mut index = 0;
loop {
if index == bytes.len() {
return;
}
for b in (&bytes[index..]).iter() {
index += 1;
if *b == b'$' {
break;
}
}
let has_host = (&bytes[index..]).starts_with(b"HOST");
if has_host {
bytes.drain(index - 1..index + 4);
for (i, b) in host.iter().enumerate() {
bytes.insert(index - 1 + i, *b);
}
} else {
let has_ua = (&bytes[index..]).starts_with(b"USERAGENT");
if has_ua {
bytes.drain(index - 1..index + 9);
for (i, b) in ua.iter().enumerate() {
bytes.insert(index - 1 + i, *b);
}
}
}
Server {
addr,
panic_rx,
shutdown_tx: Some(shutdown_tx),
}
}
#[macro_export]
macro_rules! server {
($($($f:ident: $v:expr),+);*) => ({
let txns = vec![
$(__internal__txn! {
$($f: $v,)+
}),*
];
crate::support::server::spawn(txns)
})
}
#[macro_export]
macro_rules! __internal__txn {
($($field:ident: $val:expr,)+) => (
crate::support::server::Txn {
$( $field: __internal__prop!($field: $val), )+
.. Default::default()
}
)
}
#[macro_export]
macro_rules! __internal__prop {
(request: $val:expr) => {
From::from(&$val[..])
};
(response: $val:expr) => {
From::from(&$val[..])
};
($field:ident: $val:expr) => {
From::from($val)
};
}