feat(server): Allow keep alive to be turned off for a connection (#1390)

Closes #1365
This commit is contained in:
Steven Fackler
2017-12-04 10:14:20 -08:00
committed by Sean McArthur
parent cecef9d402
commit eb9590e3da
4 changed files with 136 additions and 3 deletions

View File

@@ -453,6 +453,14 @@ where I: AsyncRead + AsyncWrite,
pub fn close_write(&mut self) { pub fn close_write(&mut self) {
self.state.close_write(); self.state.close_write();
} }
pub fn disable_keep_alive(&mut self) {
if self.state.is_idle() {
self.state.close_read();
} else {
self.state.disable_keep_alive();
}
}
} }
// ==== tokio_proto impl ==== // ==== tokio_proto impl ====
@@ -700,6 +708,10 @@ impl<B, K: KeepAlive> State<B, K> {
} }
} }
fn disable_keep_alive(&mut self) {
self.keep_alive.disable()
}
fn busy(&mut self) { fn busy(&mut self) {
if let KA::Disabled = self.keep_alive.status() { if let KA::Disabled = self.keep_alive.status() {
return; return;
@@ -869,7 +881,7 @@ mod tests {
other => panic!("unexpected frame: {:?}", other) other => panic!("unexpected frame: {:?}", other)
} }
// client // client
let io = AsyncIo::new_buf(vec![], 1); let io = AsyncIo::new_buf(vec![], 1);
let mut conn = Conn::<_, proto::Chunk, ClientTransaction>::new(io, Default::default()); let mut conn = Conn::<_, proto::Chunk, ClientTransaction>::new(io, Default::default());
conn.state.busy(); conn.state.busy();

View File

@@ -54,6 +54,10 @@ where
} }
} }
pub fn disable_keep_alive(&mut self) {
self.conn.disable_keep_alive()
}
fn poll_read(&mut self) -> Poll<(), ::Error> { fn poll_read(&mut self) -> Poll<(), ::Error> {
loop { loop {
if self.conn.can_read_head() { if self.conn.can_read_head() {

View File

@@ -536,6 +536,18 @@ where
} }
} }
impl<I, B, S> Connection<I, S>
where S: Service<Request = Request, Response = Response<B>, Error = ::Error> + 'static,
I: AsyncRead + AsyncWrite + 'static,
B: Stream<Error=::Error> + 'static,
B::Item: AsRef<[u8]>,
{
/// Disables keep-alive for this connection.
pub fn disable_keep_alive(&mut self) {
self.conn.disable_keep_alive()
}
}
mod unnameable { mod unnameable {
// This type is specifically not exported outside the crate, // This type is specifically not exported outside the crate,
// so no one can actually name the type. With no methods, we make no // so no one can actually name the type. With no methods, we make no

View File

@@ -6,7 +6,7 @@ extern crate pretty_env_logger;
extern crate tokio_core; extern crate tokio_core;
use futures::{Future, Stream}; use futures::{Future, Stream};
use futures::future::{self, FutureResult}; use futures::future::{self, FutureResult, Either};
use futures::sync::oneshot; use futures::sync::oneshot;
use tokio_core::net::TcpListener; use tokio_core::net::TcpListener;
@@ -551,6 +551,106 @@ fn pipeline_enabled() {
assert_eq!(n, 0); assert_eq!(n, 0);
} }
#[test]
fn disable_keep_alive_mid_request() {
let mut core = Core::new().unwrap();
let listener = TcpListener::bind(&"127.0.0.1:0".parse().unwrap(), &core.handle()).unwrap();
let addr = listener.local_addr().unwrap();
let (tx1, rx1) = oneshot::channel();
let (tx2, rx2) = oneshot::channel();
let child = thread::spawn(move || {
let mut req = connect(&addr);
req.write_all(b"GET / HTTP/1.1\r\n").unwrap();
tx1.send(()).unwrap();
rx2.wait().unwrap();
req.write_all(b"Host: localhost\r\n\r\n").unwrap();
let mut buf = vec![];
req.read_to_end(&mut buf).unwrap();
});
let fut = listener.incoming()
.into_future()
.map_err(|_| unreachable!())
.and_then(|(item, _incoming)| {
let (socket, _) = item.unwrap();
Http::<hyper::Chunk>::new().serve_connection(socket, HelloWorld)
.select2(rx1)
.then(|r| {
match r {
Ok(Either::A(_)) => panic!("expected rx first"),
Ok(Either::B(((), mut conn))) => {
conn.disable_keep_alive();
tx2.send(()).unwrap();
conn
}
Err(Either::A((e, _))) => panic!("unexpected error {}", e),
Err(Either::B((e, _))) => panic!("unexpected error {}", e),
}
})
});
core.run(fut).unwrap();
child.join().unwrap();
}
#[test]
fn disable_keep_alive_post_request() {
let mut core = Core::new().unwrap();
let listener = TcpListener::bind(&"127.0.0.1:0".parse().unwrap(), &core.handle()).unwrap();
let addr = listener.local_addr().unwrap();
let (tx1, rx1) = oneshot::channel();
let child = thread::spawn(move || {
let mut req = connect(&addr);
req.write_all(b"\
GET / HTTP/1.1\r\n\
Host: localhost\r\n\
\r\n\
").unwrap();
let mut buf = [0; 1024 * 8];
loop {
let n = req.read(&mut buf).expect("reading 1");
if n < buf.len() {
if &buf[n - HELLO.len()..n] == HELLO.as_bytes() {
break;
}
}
}
tx1.send(()).unwrap();
let nread = req.read(&mut buf).unwrap();
assert_eq!(nread, 0);
});
let fut = listener.incoming()
.into_future()
.map_err(|_| unreachable!())
.and_then(|(item, _incoming)| {
let (socket, _) = item.unwrap();
Http::<hyper::Chunk>::new().serve_connection(socket, HelloWorld)
.select2(rx1)
.then(|r| {
match r {
Ok(Either::A(_)) => panic!("expected rx first"),
Ok(Either::B(((), mut conn))) => {
conn.disable_keep_alive();
conn
}
Err(Either::A((e, _))) => panic!("unexpected error {}", e),
Err(Either::B((e, _))) => panic!("unexpected error {}", e),
}
})
});
core.run(fut).unwrap();
child.join().unwrap();
}
#[test] #[test]
fn no_proto_empty_parse_eof_does_not_return_error() { fn no_proto_empty_parse_eof_does_not_return_error() {
let mut core = Core::new().unwrap(); let mut core = Core::new().unwrap();
@@ -719,6 +819,8 @@ impl Service for TestService {
} }
const HELLO: &'static str = "hello";
struct HelloWorld; struct HelloWorld;
impl Service for HelloWorld { impl Service for HelloWorld {
@@ -728,7 +830,10 @@ impl Service for HelloWorld {
type Future = FutureResult<Self::Response, Self::Error>; type Future = FutureResult<Self::Response, Self::Error>;
fn call(&self, _req: Request) -> Self::Future { fn call(&self, _req: Request) -> Self::Future {
future::ok(Response::new()) let mut response = Response::new();
response.headers_mut().set(hyper::header::ContentLength(HELLO.len() as u64));
response.set_body(HELLO);
future::ok(response)
} }
} }