feat(net): add socket timeouts to Server and Client

While these methods are marked unstable in libstd, this is behind a
feature flag, `timeouts`. The Client and Server both have
`set_read_timeout` and `set_write_timeout` methods, that will affect all
connections with that entity.

BREAKING CHANGE: Any custom implementation of NetworkStream must now
  implement `set_read_timeout` and `set_write_timeout`, so those will
  break. Most users who only use the provided streams should work with
  no changes needed.

Closes #315
This commit is contained in:
Sean McArthur
2015-06-16 11:02:36 -07:00
parent 421422b620
commit 7d1f154cb7
11 changed files with 311 additions and 50 deletions

View File

@@ -108,10 +108,13 @@
//! `Request<Streaming>` object, that no longer has `headers_mut()`, but does
//! implement `Write`.
use std::fmt;
use std::io::{ErrorKind, BufWriter, Write};
use std::io::{self, ErrorKind, BufWriter, Write};
use std::net::{SocketAddr, ToSocketAddrs};
use std::thread::{self, JoinHandle};
#[cfg(feature = "timeouts")]
use std::time::Duration;
use num_cpus;
pub use self::request::Request;
@@ -143,8 +146,20 @@ mod listener;
#[derive(Debug)]
pub struct Server<L = HttpListener> {
listener: L,
_timeouts: Timeouts,
}
#[cfg(feature = "timeouts")]
#[derive(Clone, Copy, Default, Debug)]
struct Timeouts {
read: Option<Duration>,
write: Option<Duration>,
}
#[cfg(not(feature = "timeouts"))]
#[derive(Clone, Copy, Default, Debug)]
struct Timeouts;
macro_rules! try_option(
($e:expr) => {{
match $e {
@@ -159,9 +174,22 @@ impl<L: NetworkListener> Server<L> {
#[inline]
pub fn new(listener: L) -> Server<L> {
Server {
listener: listener
listener: listener,
_timeouts: Timeouts::default(),
}
}
#[cfg(feature = "timeouts")]
pub fn set_read_timeout(&mut self, dur: Option<Duration>) {
self._timeouts.read = dur;
}
#[cfg(feature = "timeouts")]
pub fn set_write_timeout(&mut self, dur: Option<Duration>) {
self._timeouts.write = dur;
}
}
impl Server<HttpListener> {
@@ -183,24 +211,25 @@ impl<S: Ssl + Clone + Send> Server<HttpsListener<S>> {
impl<L: NetworkListener + Send + 'static> Server<L> {
/// Binds to a socket and starts handling connections.
pub fn handle<H: Handler + 'static>(self, handler: H) -> ::Result<Listening> {
with_listener(handler, self.listener, num_cpus::get() * 5 / 4)
self.handle_threads(handler, num_cpus::get() * 5 / 4)
}
/// Binds to a socket and starts handling connections with the provided
/// number of threads.
pub fn handle_threads<H: Handler + 'static>(self, handler: H,
threads: usize) -> ::Result<Listening> {
with_listener(handler, self.listener, threads)
handle(self, handler, threads)
}
}
fn with_listener<H, L>(handler: H, mut listener: L, threads: usize) -> ::Result<Listening>
fn handle<H, L>(mut server: Server<L>, handler: H, threads: usize) -> ::Result<Listening>
where H: Handler + 'static,
L: NetworkListener + Send + 'static {
let socket = try!(listener.local_addr());
let socket = try!(server.listener.local_addr());
debug!("threads = {:?}", threads);
let pool = ListenerPool::new(listener);
let work = move |mut stream| Worker(&handler).handle_connection(&mut stream);
let pool = ListenerPool::new(server.listener);
let worker = Worker::new(handler, server._timeouts);
let work = move |mut stream| worker.handle_connection(&mut stream);
let guard = thread::spawn(move || pool.accept(work, threads));
@@ -210,12 +239,28 @@ L: NetworkListener + Send + 'static {
})
}
struct Worker<'a, H: Handler + 'static>(&'a H);
struct Worker<H: Handler + 'static> {
handler: H,
_timeouts: Timeouts,
}
impl<'a, H: Handler + 'static> Worker<'a, H> {
impl<H: Handler + 'static> Worker<H> {
fn new(handler: H, timeouts: Timeouts) -> Worker<H> {
Worker {
handler: handler,
_timeouts: timeouts,
}
}
fn handle_connection<S>(&self, mut stream: &mut S) where S: NetworkStream + Clone {
debug!("Incoming stream");
if let Err(e) = self.set_timeouts(stream) {
error!("set_timeouts error: {:?}", e);
return;
}
let addr = match stream.peer_addr() {
Ok(addr) => addr,
Err(e) => {
@@ -233,6 +278,17 @@ impl<'a, H: Handler + 'static> Worker<'a, H> {
debug!("keep_alive loop ending for {}", addr);
}
#[cfg(not(feature = "timeouts"))]
fn set_timeouts<S>(&self, _: &mut S) -> io::Result<()> where S: NetworkStream {
Ok(())
}
#[cfg(feature = "timeouts")]
fn set_timeouts<S>(&self, s: &mut S) -> io::Result<()> where S: NetworkStream {
try!(s.set_read_timeout(self._timeouts.read));
s.set_write_timeout(self._timeouts.write)
}
fn keep_alive_loop<W: Write>(&self, mut rdr: BufReader<&mut NetworkStream>,
mut wrt: W, addr: SocketAddr) {
let mut keep_alive = true;
@@ -268,7 +324,7 @@ impl<'a, H: Handler + 'static> Worker<'a, H> {
{
let mut res = Response::new(&mut wrt, &mut res_headers);
res.version = version;
self.0.handle(req, res);
self.handler.handle(req, res);
}
// if the request was keep-alive, we need to check that the server agrees
@@ -284,7 +340,7 @@ impl<'a, H: Handler + 'static> Worker<'a, H> {
fn handle_expect<W: Write>(&self, req: &Request, wrt: &mut W) -> bool {
if req.version == Http11 && req.headers.get() == Some(&Expect::Continue) {
let status = self.0.check_continue((&req.method, &req.uri, &req.headers));
let status = self.handler.check_continue((&req.method, &req.uri, &req.headers));
match write!(wrt, "{} {}\r\n\r\n", Http11, status) {
Ok(..) => (),
Err(e) => {
@@ -327,7 +383,6 @@ impl Listening {
pub fn close(&mut self) -> ::Result<()> {
let _ = self._guard.take();
debug!("closing server");
//try!(self.acceptor.close());
Ok(())
}
}
@@ -379,7 +434,7 @@ mod tests {
res.start().unwrap().end().unwrap();
}
Worker(&handle).handle_connection(&mut mock);
Worker::new(handle, Default::default()).handle_connection(&mut mock);
let cont = b"HTTP/1.1 100 Continue\r\n\r\n";
assert_eq!(&mock.write[..cont.len()], cont);
let res = b"HTTP/1.1 200 OK\r\n";
@@ -408,7 +463,7 @@ mod tests {
1234567890\
");
Worker(&Reject).handle_connection(&mut mock);
Worker::new(Reject, Default::default()).handle_connection(&mut mock);
assert_eq!(mock.write, &b"HTTP/1.1 417 Expectation Failed\r\n\r\n"[..]);
}
}