feat(net): add socket timeouts to Server and Client
While these methods are marked unstable in libstd, this is behind a feature flag, `timeouts`. The Client and Server both have `set_read_timeout` and `set_write_timeout` methods, that will affect all connections with that entity. BREAKING CHANGE: Any custom implementation of NetworkStream must now implement `set_read_timeout` and `set_write_timeout`, so those will break. Most users who only use the provided streams should work with no changes needed. Closes #315
This commit is contained in:
@@ -108,10 +108,13 @@
|
||||
//! `Request<Streaming>` object, that no longer has `headers_mut()`, but does
|
||||
//! implement `Write`.
|
||||
use std::fmt;
|
||||
use std::io::{ErrorKind, BufWriter, Write};
|
||||
use std::io::{self, ErrorKind, BufWriter, Write};
|
||||
use std::net::{SocketAddr, ToSocketAddrs};
|
||||
use std::thread::{self, JoinHandle};
|
||||
|
||||
#[cfg(feature = "timeouts")]
|
||||
use std::time::Duration;
|
||||
|
||||
use num_cpus;
|
||||
|
||||
pub use self::request::Request;
|
||||
@@ -143,8 +146,20 @@ mod listener;
|
||||
#[derive(Debug)]
|
||||
pub struct Server<L = HttpListener> {
|
||||
listener: L,
|
||||
_timeouts: Timeouts,
|
||||
}
|
||||
|
||||
#[cfg(feature = "timeouts")]
|
||||
#[derive(Clone, Copy, Default, Debug)]
|
||||
struct Timeouts {
|
||||
read: Option<Duration>,
|
||||
write: Option<Duration>,
|
||||
}
|
||||
|
||||
#[cfg(not(feature = "timeouts"))]
|
||||
#[derive(Clone, Copy, Default, Debug)]
|
||||
struct Timeouts;
|
||||
|
||||
macro_rules! try_option(
|
||||
($e:expr) => {{
|
||||
match $e {
|
||||
@@ -159,9 +174,22 @@ impl<L: NetworkListener> Server<L> {
|
||||
#[inline]
|
||||
pub fn new(listener: L) -> Server<L> {
|
||||
Server {
|
||||
listener: listener
|
||||
listener: listener,
|
||||
_timeouts: Timeouts::default(),
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(feature = "timeouts")]
|
||||
pub fn set_read_timeout(&mut self, dur: Option<Duration>) {
|
||||
self._timeouts.read = dur;
|
||||
}
|
||||
|
||||
#[cfg(feature = "timeouts")]
|
||||
pub fn set_write_timeout(&mut self, dur: Option<Duration>) {
|
||||
self._timeouts.write = dur;
|
||||
}
|
||||
|
||||
|
||||
}
|
||||
|
||||
impl Server<HttpListener> {
|
||||
@@ -183,24 +211,25 @@ impl<S: Ssl + Clone + Send> Server<HttpsListener<S>> {
|
||||
impl<L: NetworkListener + Send + 'static> Server<L> {
|
||||
/// Binds to a socket and starts handling connections.
|
||||
pub fn handle<H: Handler + 'static>(self, handler: H) -> ::Result<Listening> {
|
||||
with_listener(handler, self.listener, num_cpus::get() * 5 / 4)
|
||||
self.handle_threads(handler, num_cpus::get() * 5 / 4)
|
||||
}
|
||||
/// Binds to a socket and starts handling connections with the provided
|
||||
/// number of threads.
|
||||
pub fn handle_threads<H: Handler + 'static>(self, handler: H,
|
||||
threads: usize) -> ::Result<Listening> {
|
||||
with_listener(handler, self.listener, threads)
|
||||
handle(self, handler, threads)
|
||||
}
|
||||
}
|
||||
|
||||
fn with_listener<H, L>(handler: H, mut listener: L, threads: usize) -> ::Result<Listening>
|
||||
fn handle<H, L>(mut server: Server<L>, handler: H, threads: usize) -> ::Result<Listening>
|
||||
where H: Handler + 'static,
|
||||
L: NetworkListener + Send + 'static {
|
||||
let socket = try!(listener.local_addr());
|
||||
let socket = try!(server.listener.local_addr());
|
||||
|
||||
debug!("threads = {:?}", threads);
|
||||
let pool = ListenerPool::new(listener);
|
||||
let work = move |mut stream| Worker(&handler).handle_connection(&mut stream);
|
||||
let pool = ListenerPool::new(server.listener);
|
||||
let worker = Worker::new(handler, server._timeouts);
|
||||
let work = move |mut stream| worker.handle_connection(&mut stream);
|
||||
|
||||
let guard = thread::spawn(move || pool.accept(work, threads));
|
||||
|
||||
@@ -210,12 +239,28 @@ L: NetworkListener + Send + 'static {
|
||||
})
|
||||
}
|
||||
|
||||
struct Worker<'a, H: Handler + 'static>(&'a H);
|
||||
struct Worker<H: Handler + 'static> {
|
||||
handler: H,
|
||||
_timeouts: Timeouts,
|
||||
}
|
||||
|
||||
impl<'a, H: Handler + 'static> Worker<'a, H> {
|
||||
impl<H: Handler + 'static> Worker<H> {
|
||||
|
||||
fn new(handler: H, timeouts: Timeouts) -> Worker<H> {
|
||||
Worker {
|
||||
handler: handler,
|
||||
_timeouts: timeouts,
|
||||
}
|
||||
}
|
||||
|
||||
fn handle_connection<S>(&self, mut stream: &mut S) where S: NetworkStream + Clone {
|
||||
debug!("Incoming stream");
|
||||
|
||||
if let Err(e) = self.set_timeouts(stream) {
|
||||
error!("set_timeouts error: {:?}", e);
|
||||
return;
|
||||
}
|
||||
|
||||
let addr = match stream.peer_addr() {
|
||||
Ok(addr) => addr,
|
||||
Err(e) => {
|
||||
@@ -233,6 +278,17 @@ impl<'a, H: Handler + 'static> Worker<'a, H> {
|
||||
debug!("keep_alive loop ending for {}", addr);
|
||||
}
|
||||
|
||||
#[cfg(not(feature = "timeouts"))]
|
||||
fn set_timeouts<S>(&self, _: &mut S) -> io::Result<()> where S: NetworkStream {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[cfg(feature = "timeouts")]
|
||||
fn set_timeouts<S>(&self, s: &mut S) -> io::Result<()> where S: NetworkStream {
|
||||
try!(s.set_read_timeout(self._timeouts.read));
|
||||
s.set_write_timeout(self._timeouts.write)
|
||||
}
|
||||
|
||||
fn keep_alive_loop<W: Write>(&self, mut rdr: BufReader<&mut NetworkStream>,
|
||||
mut wrt: W, addr: SocketAddr) {
|
||||
let mut keep_alive = true;
|
||||
@@ -268,7 +324,7 @@ impl<'a, H: Handler + 'static> Worker<'a, H> {
|
||||
{
|
||||
let mut res = Response::new(&mut wrt, &mut res_headers);
|
||||
res.version = version;
|
||||
self.0.handle(req, res);
|
||||
self.handler.handle(req, res);
|
||||
}
|
||||
|
||||
// if the request was keep-alive, we need to check that the server agrees
|
||||
@@ -284,7 +340,7 @@ impl<'a, H: Handler + 'static> Worker<'a, H> {
|
||||
|
||||
fn handle_expect<W: Write>(&self, req: &Request, wrt: &mut W) -> bool {
|
||||
if req.version == Http11 && req.headers.get() == Some(&Expect::Continue) {
|
||||
let status = self.0.check_continue((&req.method, &req.uri, &req.headers));
|
||||
let status = self.handler.check_continue((&req.method, &req.uri, &req.headers));
|
||||
match write!(wrt, "{} {}\r\n\r\n", Http11, status) {
|
||||
Ok(..) => (),
|
||||
Err(e) => {
|
||||
@@ -327,7 +383,6 @@ impl Listening {
|
||||
pub fn close(&mut self) -> ::Result<()> {
|
||||
let _ = self._guard.take();
|
||||
debug!("closing server");
|
||||
//try!(self.acceptor.close());
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
@@ -379,7 +434,7 @@ mod tests {
|
||||
res.start().unwrap().end().unwrap();
|
||||
}
|
||||
|
||||
Worker(&handle).handle_connection(&mut mock);
|
||||
Worker::new(handle, Default::default()).handle_connection(&mut mock);
|
||||
let cont = b"HTTP/1.1 100 Continue\r\n\r\n";
|
||||
assert_eq!(&mock.write[..cont.len()], cont);
|
||||
let res = b"HTTP/1.1 200 OK\r\n";
|
||||
@@ -408,7 +463,7 @@ mod tests {
|
||||
1234567890\
|
||||
");
|
||||
|
||||
Worker(&Reject).handle_connection(&mut mock);
|
||||
Worker::new(Reject, Default::default()).handle_connection(&mut mock);
|
||||
assert_eq!(mock.write, &b"HTTP/1.1 417 Expectation Failed\r\n\r\n"[..]);
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user