diff --git a/src/http/conn.rs b/src/http/conn.rs index 09694917..420441ae 100644 --- a/src/http/conn.rs +++ b/src/http/conn.rs @@ -610,6 +610,9 @@ impl State { } fn busy(&mut self) { + if let KA::Disabled = self.keep_alive.status() { + return; + } self.keep_alive.busy(); } diff --git a/tests/server.rs b/tests/server.rs index c2d9c803..40b75aaf 100644 --- a/tests/server.rs +++ b/tests/server.rs @@ -146,10 +146,16 @@ fn connect(addr: &SocketAddr) -> TcpStream { } fn serve() -> Serve { - serve_with_timeout(None) + serve_with_options(Default::default()) } -fn serve_with_timeout(dur: Option) -> Serve { +#[derive(Default)] +struct ServeOptions { + keep_alive_disabled: bool, + timeout: Option, +} + +fn serve_with_options(options: ServeOptions) -> Serve { let _ = pretty_env_logger::init(); let (addr_tx, addr_rx) = mpsc::channel(); @@ -159,9 +165,12 @@ fn serve_with_timeout(dur: Option) -> Serve { let addr = "127.0.0.1:0".parse().unwrap(); + let keep_alive = !options.keep_alive_disabled; + let dur = options.timeout; + let thread_name = format!("test-server-{:?}", dur); let thread = thread::Builder::new().name(thread_name).spawn(move || { - let srv = Http::new().bind(&addr, TestService { + let srv = Http::new().keep_alive(keep_alive).bind(&addr, TestService { tx: Arc::new(Mutex::new(msg_tx.clone())), _timeout: dur, reply: reply_rx, @@ -456,3 +465,61 @@ fn server_keep_alive() { } } } + +#[test] +fn test_server_disable_keep_alive() { + let foo_bar = b"foo bar baz"; + let server = serve_with_options(ServeOptions { + keep_alive_disabled: true, + .. Default::default() + }); + server.reply() + .status(hyper::Ok) + .header(hyper::header::ContentLength(foo_bar.len() as u64)) + .body(foo_bar); + let mut req = connect(server.addr()); + req.write_all(b"\ + GET / HTTP/1.1\r\n\ + Host: example.domain\r\n\ + Connection: keep-alive\r\n\ + \r\n\ + ").expect("writing 1"); + + let mut buf = [0; 1024 * 8]; + loop { + let n = req.read(&mut buf[..]).expect("reading 1"); + if n < buf.len() { + if &buf[n - foo_bar.len()..n] == foo_bar { + break; + } else { + } + } + } + + // try again! + + let quux = b"zar quux"; + server.reply() + .status(hyper::Ok) + .header(hyper::header::ContentLength(quux.len() as u64)) + .body(quux); + + let _ = req.write_all(b"\ + GET /quux HTTP/1.1\r\n\ + Host: example.domain\r\n\ + Connection: close\r\n\ + \r\n\ + "); + + // the write can possibly succeed, since it fills the kernel buffer on the first write + let mut buf = [0; 1024 * 8]; + match req.read(&mut buf[..]) { + // Ok(0) means EOF, so a proper shutdown + // Err(_) could mean ConnReset or something, also fine + Ok(0) | + Err(_) => {} + Ok(n) => { + panic!("read {} bytes on a disabled keep-alive socket", n); + } + } +} \ No newline at end of file