fix(server): use a timeout for Server keep-alive

Server keep-alive is now **off** by default. In order to turn it on, the
`keep_alive` method must be called on the `Server` object.

Closes #368
This commit is contained in:
Sean McArthur
2015-10-08 16:30:48 -07:00
parent 388ddf6f3b
commit cdaa2547ed
2 changed files with 104 additions and 60 deletions

View File

@@ -111,8 +111,6 @@ use std::fmt;
use std::io::{self, ErrorKind, BufWriter, Write}; use std::io::{self, ErrorKind, BufWriter, Write};
use std::net::{SocketAddr, ToSocketAddrs}; use std::net::{SocketAddr, ToSocketAddrs};
use std::thread::{self, JoinHandle}; use std::thread::{self, JoinHandle};
#[cfg(feature = "timeouts")]
use std::time::Duration; use std::time::Duration;
use num_cpus; use num_cpus;
@@ -146,20 +144,16 @@ mod listener;
#[derive(Debug)] #[derive(Debug)]
pub struct Server<L = HttpListener> { pub struct Server<L = HttpListener> {
listener: L, listener: L,
_timeouts: Timeouts, timeouts: Timeouts,
} }
#[cfg(feature = "timeouts")]
#[derive(Clone, Copy, Default, Debug)] #[derive(Clone, Copy, Default, Debug)]
struct Timeouts { struct Timeouts {
read: Option<Duration>, read: Option<Duration>,
write: Option<Duration>, write: Option<Duration>,
keep_alive: Option<Duration>,
} }
#[cfg(not(feature = "timeouts"))]
#[derive(Clone, Copy, Default, Debug)]
struct Timeouts;
macro_rules! try_option( macro_rules! try_option(
($e:expr) => {{ ($e:expr) => {{
match $e { match $e {
@@ -175,18 +169,30 @@ impl<L: NetworkListener> Server<L> {
pub fn new(listener: L) -> Server<L> { pub fn new(listener: L) -> Server<L> {
Server { Server {
listener: listener, listener: listener,
_timeouts: Timeouts::default(), timeouts: Timeouts::default(),
} }
} }
/// Enables keep-alive for this server.
///
/// The timeout duration passed will be used to determine how long
/// to keep the connection alive before dropping it.
///
/// **NOTE**: The timeout will only be used when the `timeouts` feature
/// is enabled for hyper, and rustc is 1.4 or greater.
#[inline]
pub fn keep_alive(&mut self, timeout: Duration) {
self.timeouts.keep_alive = Some(timeout);
}
#[cfg(feature = "timeouts")] #[cfg(feature = "timeouts")]
pub fn set_read_timeout(&mut self, dur: Option<Duration>) { pub fn set_read_timeout(&mut self, dur: Option<Duration>) {
self._timeouts.read = dur; self.timeouts.read = dur;
} }
#[cfg(feature = "timeouts")] #[cfg(feature = "timeouts")]
pub fn set_write_timeout(&mut self, dur: Option<Duration>) { pub fn set_write_timeout(&mut self, dur: Option<Duration>) {
self._timeouts.write = dur; self.timeouts.write = dur;
} }
@@ -228,7 +234,7 @@ L: NetworkListener + Send + 'static {
debug!("threads = {:?}", threads); debug!("threads = {:?}", threads);
let pool = ListenerPool::new(server.listener); let pool = ListenerPool::new(server.listener);
let worker = Worker::new(handler, server._timeouts); let worker = Worker::new(handler, server.timeouts);
let work = move |mut stream| worker.handle_connection(&mut stream); let work = move |mut stream| worker.handle_connection(&mut stream);
let guard = thread::spawn(move || pool.accept(work, threads)); let guard = thread::spawn(move || pool.accept(work, threads));
@@ -241,7 +247,7 @@ L: NetworkListener + Send + 'static {
struct Worker<H: Handler + 'static> { struct Worker<H: Handler + 'static> {
handler: H, handler: H,
_timeouts: Timeouts, timeouts: Timeouts,
} }
impl<H: Handler + 'static> Worker<H> { impl<H: Handler + 'static> Worker<H> {
@@ -249,7 +255,7 @@ impl<H: Handler + 'static> Worker<H> {
fn new(handler: H, timeouts: Timeouts) -> Worker<H> { fn new(handler: H, timeouts: Timeouts) -> Worker<H> {
Worker { Worker {
handler: handler, handler: handler,
_timeouts: timeouts, timeouts: timeouts,
} }
} }
@@ -258,7 +264,7 @@ impl<H: Handler + 'static> Worker<H> {
self.handler.on_connection_start(); self.handler.on_connection_start();
if let Err(e) = self.set_timeouts(stream) { if let Err(e) = self.set_timeouts(&*stream) {
error!("set_timeouts error: {:?}", e); error!("set_timeouts error: {:?}", e);
return; return;
} }
@@ -273,73 +279,97 @@ impl<H: Handler + 'static> Worker<H> {
// FIXME: Use Type ascription // FIXME: Use Type ascription
let stream_clone: &mut NetworkStream = &mut stream.clone(); let stream_clone: &mut NetworkStream = &mut stream.clone();
let rdr = BufReader::new(stream_clone); let mut rdr = BufReader::new(stream_clone);
let wrt = BufWriter::new(stream); let mut wrt = BufWriter::new(stream);
self.keep_alive_loop(rdr, wrt, addr); while self.keep_alive_loop(&mut rdr, &mut wrt, addr) {
if let Err(e) = self.set_read_timeout(*rdr.get_ref(), self.timeouts.keep_alive) {
error!("set_read_timeout keep_alive {:?}", e);
break;
}
}
self.handler.on_connection_end(); self.handler.on_connection_end();
debug!("keep_alive loop ending for {}", addr); debug!("keep_alive loop ending for {}", addr);
} }
fn set_timeouts(&self, s: &NetworkStream) -> io::Result<()> {
try!(self.set_read_timeout(s, self.timeouts.read));
self.set_write_timeout(s, self.timeouts.write)
}
#[cfg(not(feature = "timeouts"))] #[cfg(not(feature = "timeouts"))]
fn set_timeouts<S>(&self, _: &mut S) -> io::Result<()> where S: NetworkStream { fn set_write_timeout(&self, _s: &NetworkStream, _timeout: Option<Duration>) -> io::Result<()> {
Ok(()) Ok(())
} }
#[cfg(feature = "timeouts")] #[cfg(feature = "timeouts")]
fn set_timeouts<S>(&self, s: &mut S) -> io::Result<()> where S: NetworkStream { fn set_write_timeout(&self, s: &NetworkStream, timeout: Option<Duration>) -> io::Result<()> {
try!(s.set_read_timeout(self._timeouts.read)); s.set_write_timeout(timeout)
s.set_write_timeout(self._timeouts.write)
} }
fn keep_alive_loop<W: Write>(&self, mut rdr: BufReader<&mut NetworkStream>, #[cfg(not(feature = "timeouts"))]
mut wrt: W, addr: SocketAddr) { fn set_read_timeout(&self, _s: &NetworkStream, _timeout: Option<Duration>) -> io::Result<()> {
let mut keep_alive = true; Ok(())
while keep_alive { }
let req = match Request::new(&mut rdr, addr) {
Ok(req) => req,
Err(Error::Io(ref e)) if e.kind() == ErrorKind::ConnectionAborted => {
trace!("tcp closed, cancelling keep-alive loop");
break;
}
Err(Error::Io(e)) => {
debug!("ioerror in keepalive loop = {:?}", e);
break;
}
Err(e) => {
//TODO: send a 400 response
error!("request error = {:?}", e);
break;
}
};
#[cfg(feature = "timeouts")]
fn set_read_timeout(&self, s: &NetworkStream, timeout: Option<Duration>) -> io::Result<()> {
s.set_read_timeout(timeout)
}
if !self.handle_expect(&req, &mut wrt) { fn keep_alive_loop<W: Write>(&self, mut rdr: &mut BufReader<&mut NetworkStream>,
break; wrt: &mut W, addr: SocketAddr) -> bool {
let req = match Request::new(rdr, addr) {
Ok(req) => req,
Err(Error::Io(ref e)) if e.kind() == ErrorKind::ConnectionAborted => {
trace!("tcp closed, cancelling keep-alive loop");
return false;
} }
Err(Error::Io(e)) => {
keep_alive = http::should_keep_alive(req.version, &req.headers); debug!("ioerror in keepalive loop = {:?}", e);
let version = req.version; return false;
let mut res_headers = Headers::new();
if !keep_alive {
res_headers.set(Connection::close());
} }
{ Err(e) => {
let mut res = Response::new(&mut wrt, &mut res_headers); //TODO: send a 400 response
res.version = version; error!("request error = {:?}", e);
self.handler.handle(req, res); return false;
} }
};
// if the request was keep-alive, we need to check that the server agrees
// if it wasn't, then the server cannot force it to be true anyways
if keep_alive {
keep_alive = http::should_keep_alive(version, &res_headers);
}
debug!("keep_alive = {:?} for {}", keep_alive, addr); if !self.handle_expect(&req, wrt) {
return false;
} }
if let Err(e) = req.set_read_timeout(self.timeouts.read) {
error!("set_read_timeout {:?}", e);
return false;
}
let mut keep_alive = self.timeouts.keep_alive.is_some() &&
http::should_keep_alive(req.version, &req.headers);
let version = req.version;
let mut res_headers = Headers::new();
if !keep_alive {
res_headers.set(Connection::close());
}
{
let mut res = Response::new(wrt, &mut res_headers);
res.version = version;
self.handler.handle(req, res);
}
// if the request was keep-alive, we need to check that the server agrees
// if it wasn't, then the server cannot force it to be true anyways
if keep_alive {
keep_alive = http::should_keep_alive(version, &res_headers);
}
debug!("keep_alive = {:?} for {}", keep_alive, addr);
keep_alive
} }
fn handle_expect<W: Write>(&self, req: &Request, wrt: &mut W) -> bool { fn handle_expect<W: Write>(&self, req: &Request, wrt: &mut W) -> bool {

View File

@@ -4,6 +4,7 @@
//! target URI, headers, and message body. //! target URI, headers, and message body.
use std::io::{self, Read}; use std::io::{self, Read};
use std::net::SocketAddr; use std::net::SocketAddr;
use std::time::Duration;
use buffer::BufReader; use buffer::BufReader;
use net::NetworkStream; use net::NetworkStream;
@@ -64,6 +65,19 @@ impl<'a, 'b: 'a> Request<'a, 'b> {
}) })
} }
/// Set the read timeout of the underlying NetworkStream.
#[cfg(feature = "timeouts")]
#[inline]
pub fn set_read_timeout(&self, timeout: Option<Duration>) -> io::Result<()> {
self.body.get_ref().get_ref().set_read_timeout(timeout)
}
/// Set the read timeout of the underlying NetworkStream.
#[cfg(not(feature = "timeouts"))]
#[inline]
pub fn set_read_timeout(&self, _timeout: Option<Duration>) -> io::Result<()> {
Ok(())
}
/// Get a reference to the underlying `NetworkStream`. /// Get a reference to the underlying `NetworkStream`.
#[inline] #[inline]
pub fn downcast_ref<T: NetworkStream>(&self) -> Option<&T> { pub fn downcast_ref<T: NetworkStream>(&self) -> Option<&T> {