fix(server): use a timeout for Server keep-alive

Server keep-alive is now **off** by default. In order to turn it on, the
`keep_alive` method must be called on the `Server` object.

Closes #368
This commit is contained in:
Sean McArthur
2015-10-08 16:30:48 -07:00
parent 388ddf6f3b
commit cdaa2547ed
2 changed files with 104 additions and 60 deletions

View File

@@ -111,8 +111,6 @@ use std::fmt;
use std::io::{self, ErrorKind, BufWriter, Write};
use std::net::{SocketAddr, ToSocketAddrs};
use std::thread::{self, JoinHandle};
#[cfg(feature = "timeouts")]
use std::time::Duration;
use num_cpus;
@@ -146,20 +144,16 @@ mod listener;
#[derive(Debug)]
pub struct Server<L = HttpListener> {
listener: L,
_timeouts: Timeouts,
timeouts: Timeouts,
}
#[cfg(feature = "timeouts")]
#[derive(Clone, Copy, Default, Debug)]
struct Timeouts {
read: Option<Duration>,
write: Option<Duration>,
keep_alive: Option<Duration>,
}
#[cfg(not(feature = "timeouts"))]
#[derive(Clone, Copy, Default, Debug)]
struct Timeouts;
macro_rules! try_option(
($e:expr) => {{
match $e {
@@ -175,18 +169,30 @@ impl<L: NetworkListener> Server<L> {
pub fn new(listener: L) -> Server<L> {
Server {
listener: listener,
_timeouts: Timeouts::default(),
timeouts: Timeouts::default(),
}
}
/// Enables keep-alive for this server.
///
/// The timeout duration passed will be used to determine how long
/// to keep the connection alive before dropping it.
///
/// **NOTE**: The timeout will only be used when the `timeouts` feature
/// is enabled for hyper, and rustc is 1.4 or greater.
#[inline]
pub fn keep_alive(&mut self, timeout: Duration) {
self.timeouts.keep_alive = Some(timeout);
}
#[cfg(feature = "timeouts")]
pub fn set_read_timeout(&mut self, dur: Option<Duration>) {
self._timeouts.read = dur;
self.timeouts.read = dur;
}
#[cfg(feature = "timeouts")]
pub fn set_write_timeout(&mut self, dur: Option<Duration>) {
self._timeouts.write = dur;
self.timeouts.write = dur;
}
@@ -228,7 +234,7 @@ L: NetworkListener + Send + 'static {
debug!("threads = {:?}", threads);
let pool = ListenerPool::new(server.listener);
let worker = Worker::new(handler, server._timeouts);
let worker = Worker::new(handler, server.timeouts);
let work = move |mut stream| worker.handle_connection(&mut stream);
let guard = thread::spawn(move || pool.accept(work, threads));
@@ -241,7 +247,7 @@ L: NetworkListener + Send + 'static {
struct Worker<H: Handler + 'static> {
handler: H,
_timeouts: Timeouts,
timeouts: Timeouts,
}
impl<H: Handler + 'static> Worker<H> {
@@ -249,7 +255,7 @@ impl<H: Handler + 'static> Worker<H> {
fn new(handler: H, timeouts: Timeouts) -> Worker<H> {
Worker {
handler: handler,
_timeouts: timeouts,
timeouts: timeouts,
}
}
@@ -258,7 +264,7 @@ impl<H: Handler + 'static> Worker<H> {
self.handler.on_connection_start();
if let Err(e) = self.set_timeouts(stream) {
if let Err(e) = self.set_timeouts(&*stream) {
error!("set_timeouts error: {:?}", e);
return;
}
@@ -273,73 +279,97 @@ impl<H: Handler + 'static> Worker<H> {
// FIXME: Use Type ascription
let stream_clone: &mut NetworkStream = &mut stream.clone();
let rdr = BufReader::new(stream_clone);
let wrt = BufWriter::new(stream);
let mut rdr = BufReader::new(stream_clone);
let mut wrt = BufWriter::new(stream);
self.keep_alive_loop(rdr, wrt, addr);
while self.keep_alive_loop(&mut rdr, &mut wrt, addr) {
if let Err(e) = self.set_read_timeout(*rdr.get_ref(), self.timeouts.keep_alive) {
error!("set_read_timeout keep_alive {:?}", e);
break;
}
}
self.handler.on_connection_end();
debug!("keep_alive loop ending for {}", addr);
}
fn set_timeouts(&self, s: &NetworkStream) -> io::Result<()> {
try!(self.set_read_timeout(s, self.timeouts.read));
self.set_write_timeout(s, self.timeouts.write)
}
#[cfg(not(feature = "timeouts"))]
fn set_timeouts<S>(&self, _: &mut S) -> io::Result<()> where S: NetworkStream {
fn set_write_timeout(&self, _s: &NetworkStream, _timeout: Option<Duration>) -> io::Result<()> {
Ok(())
}
#[cfg(feature = "timeouts")]
fn set_timeouts<S>(&self, s: &mut S) -> io::Result<()> where S: NetworkStream {
try!(s.set_read_timeout(self._timeouts.read));
s.set_write_timeout(self._timeouts.write)
fn set_write_timeout(&self, s: &NetworkStream, timeout: Option<Duration>) -> io::Result<()> {
s.set_write_timeout(timeout)
}
fn keep_alive_loop<W: Write>(&self, mut rdr: BufReader<&mut NetworkStream>,
mut wrt: W, addr: SocketAddr) {
let mut keep_alive = true;
while keep_alive {
let req = match Request::new(&mut rdr, addr) {
Ok(req) => req,
Err(Error::Io(ref e)) if e.kind() == ErrorKind::ConnectionAborted => {
trace!("tcp closed, cancelling keep-alive loop");
break;
}
Err(Error::Io(e)) => {
debug!("ioerror in keepalive loop = {:?}", e);
break;
}
Err(e) => {
//TODO: send a 400 response
error!("request error = {:?}", e);
break;
}
};
#[cfg(not(feature = "timeouts"))]
fn set_read_timeout(&self, _s: &NetworkStream, _timeout: Option<Duration>) -> io::Result<()> {
Ok(())
}
#[cfg(feature = "timeouts")]
fn set_read_timeout(&self, s: &NetworkStream, timeout: Option<Duration>) -> io::Result<()> {
s.set_read_timeout(timeout)
}
if !self.handle_expect(&req, &mut wrt) {
break;
fn keep_alive_loop<W: Write>(&self, mut rdr: &mut BufReader<&mut NetworkStream>,
wrt: &mut W, addr: SocketAddr) -> bool {
let req = match Request::new(rdr, addr) {
Ok(req) => req,
Err(Error::Io(ref e)) if e.kind() == ErrorKind::ConnectionAborted => {
trace!("tcp closed, cancelling keep-alive loop");
return false;
}
keep_alive = http::should_keep_alive(req.version, &req.headers);
let version = req.version;
let mut res_headers = Headers::new();
if !keep_alive {
res_headers.set(Connection::close());
Err(Error::Io(e)) => {
debug!("ioerror in keepalive loop = {:?}", e);
return false;
}
{
let mut res = Response::new(&mut wrt, &mut res_headers);
res.version = version;
self.handler.handle(req, res);
Err(e) => {
//TODO: send a 400 response
error!("request error = {:?}", e);
return false;
}
};
// if the request was keep-alive, we need to check that the server agrees
// if it wasn't, then the server cannot force it to be true anyways
if keep_alive {
keep_alive = http::should_keep_alive(version, &res_headers);
}
debug!("keep_alive = {:?} for {}", keep_alive, addr);
if !self.handle_expect(&req, wrt) {
return false;
}
if let Err(e) = req.set_read_timeout(self.timeouts.read) {
error!("set_read_timeout {:?}", e);
return false;
}
let mut keep_alive = self.timeouts.keep_alive.is_some() &&
http::should_keep_alive(req.version, &req.headers);
let version = req.version;
let mut res_headers = Headers::new();
if !keep_alive {
res_headers.set(Connection::close());
}
{
let mut res = Response::new(wrt, &mut res_headers);
res.version = version;
self.handler.handle(req, res);
}
// if the request was keep-alive, we need to check that the server agrees
// if it wasn't, then the server cannot force it to be true anyways
if keep_alive {
keep_alive = http::should_keep_alive(version, &res_headers);
}
debug!("keep_alive = {:?} for {}", keep_alive, addr);
keep_alive
}
fn handle_expect<W: Write>(&self, req: &Request, wrt: &mut W) -> bool {

View File

@@ -4,6 +4,7 @@
//! target URI, headers, and message body.
use std::io::{self, Read};
use std::net::SocketAddr;
use std::time::Duration;
use buffer::BufReader;
use net::NetworkStream;
@@ -64,6 +65,19 @@ impl<'a, 'b: 'a> Request<'a, 'b> {
})
}
/// Set the read timeout of the underlying NetworkStream.
#[cfg(feature = "timeouts")]
#[inline]
pub fn set_read_timeout(&self, timeout: Option<Duration>) -> io::Result<()> {
self.body.get_ref().get_ref().set_read_timeout(timeout)
}
/// Set the read timeout of the underlying NetworkStream.
#[cfg(not(feature = "timeouts"))]
#[inline]
pub fn set_read_timeout(&self, _timeout: Option<Duration>) -> io::Result<()> {
Ok(())
}
/// Get a reference to the underlying `NetworkStream`.
#[inline]
pub fn downcast_ref<T: NetworkStream>(&self) -> Option<&T> {