fix(server): use a timeout for Server keep-alive
Server keep-alive is now **off** by default. In order to turn it on, the `keep_alive` method must be called on the `Server` object. Closes #368
This commit is contained in:
		| @@ -111,8 +111,6 @@ use std::fmt; | ||||
| use std::io::{self, ErrorKind, BufWriter, Write}; | ||||
| use std::net::{SocketAddr, ToSocketAddrs}; | ||||
| use std::thread::{self, JoinHandle}; | ||||
|  | ||||
| #[cfg(feature = "timeouts")] | ||||
| use std::time::Duration; | ||||
|  | ||||
| use num_cpus; | ||||
| @@ -146,20 +144,16 @@ mod listener; | ||||
| #[derive(Debug)] | ||||
| pub struct Server<L = HttpListener> { | ||||
|     listener: L, | ||||
|     _timeouts: Timeouts, | ||||
|     timeouts: Timeouts, | ||||
| } | ||||
|  | ||||
| #[cfg(feature = "timeouts")] | ||||
| #[derive(Clone, Copy, Default, Debug)] | ||||
| struct Timeouts { | ||||
|     read: Option<Duration>, | ||||
|     write: Option<Duration>, | ||||
|     keep_alive: Option<Duration>, | ||||
| } | ||||
|  | ||||
| #[cfg(not(feature = "timeouts"))] | ||||
| #[derive(Clone, Copy, Default, Debug)] | ||||
| struct Timeouts; | ||||
|  | ||||
| macro_rules! try_option( | ||||
|     ($e:expr) => {{ | ||||
|         match $e { | ||||
| @@ -175,18 +169,30 @@ impl<L: NetworkListener> Server<L> { | ||||
|     pub fn new(listener: L) -> Server<L> { | ||||
|         Server { | ||||
|             listener: listener, | ||||
|             _timeouts: Timeouts::default(), | ||||
|             timeouts: Timeouts::default(), | ||||
|         } | ||||
|     } | ||||
|  | ||||
|     /// Enables keep-alive for this server. | ||||
|     /// | ||||
|     /// The timeout duration passed will be used to determine how long | ||||
|     /// to keep the connection alive before dropping it. | ||||
|     /// | ||||
|     /// **NOTE**: The timeout will only be used when the `timeouts` feature | ||||
|     /// is enabled for hyper, and rustc is 1.4 or greater. | ||||
|     #[inline] | ||||
|     pub fn keep_alive(&mut self, timeout: Duration) { | ||||
|         self.timeouts.keep_alive = Some(timeout); | ||||
|     } | ||||
|  | ||||
|     #[cfg(feature = "timeouts")] | ||||
|     pub fn set_read_timeout(&mut self, dur: Option<Duration>) { | ||||
|         self._timeouts.read = dur; | ||||
|         self.timeouts.read = dur; | ||||
|     } | ||||
|  | ||||
|     #[cfg(feature = "timeouts")] | ||||
|     pub fn set_write_timeout(&mut self, dur: Option<Duration>) { | ||||
|         self._timeouts.write = dur; | ||||
|         self.timeouts.write = dur; | ||||
|     } | ||||
|  | ||||
|  | ||||
| @@ -228,7 +234,7 @@ L: NetworkListener + Send + 'static { | ||||
|  | ||||
|     debug!("threads = {:?}", threads); | ||||
|     let pool = ListenerPool::new(server.listener); | ||||
|     let worker = Worker::new(handler, server._timeouts); | ||||
|     let worker = Worker::new(handler, server.timeouts); | ||||
|     let work = move |mut stream| worker.handle_connection(&mut stream); | ||||
|  | ||||
|     let guard = thread::spawn(move || pool.accept(work, threads)); | ||||
| @@ -241,7 +247,7 @@ L: NetworkListener + Send + 'static { | ||||
|  | ||||
| struct Worker<H: Handler + 'static> { | ||||
|     handler: H, | ||||
|     _timeouts: Timeouts, | ||||
|     timeouts: Timeouts, | ||||
| } | ||||
|  | ||||
| impl<H: Handler + 'static> Worker<H> { | ||||
| @@ -249,7 +255,7 @@ impl<H: Handler + 'static> Worker<H> { | ||||
|     fn new(handler: H, timeouts: Timeouts) -> Worker<H> { | ||||
|         Worker { | ||||
|             handler: handler, | ||||
|             _timeouts: timeouts, | ||||
|             timeouts: timeouts, | ||||
|         } | ||||
|     } | ||||
|  | ||||
| @@ -258,7 +264,7 @@ impl<H: Handler + 'static> Worker<H> { | ||||
|  | ||||
|         self.handler.on_connection_start(); | ||||
|  | ||||
|         if let Err(e) = self.set_timeouts(stream) { | ||||
|         if let Err(e) = self.set_timeouts(&*stream) { | ||||
|             error!("set_timeouts error: {:?}", e); | ||||
|             return; | ||||
|         } | ||||
| @@ -273,73 +279,97 @@ impl<H: Handler + 'static> Worker<H> { | ||||
|  | ||||
|         // FIXME: Use Type ascription | ||||
|         let stream_clone: &mut NetworkStream = &mut stream.clone(); | ||||
|         let rdr = BufReader::new(stream_clone); | ||||
|         let wrt = BufWriter::new(stream); | ||||
|         let mut rdr = BufReader::new(stream_clone); | ||||
|         let mut wrt = BufWriter::new(stream); | ||||
|  | ||||
|         self.keep_alive_loop(rdr, wrt, addr); | ||||
|         while self.keep_alive_loop(&mut rdr, &mut wrt, addr) { | ||||
|             if let Err(e) = self.set_read_timeout(*rdr.get_ref(), self.timeouts.keep_alive) { | ||||
|                 error!("set_read_timeout keep_alive {:?}", e); | ||||
|                 break; | ||||
|             } | ||||
|         } | ||||
|  | ||||
|         self.handler.on_connection_end(); | ||||
|  | ||||
|         debug!("keep_alive loop ending for {}", addr); | ||||
|     } | ||||
|  | ||||
|     fn set_timeouts(&self, s: &NetworkStream) -> io::Result<()> { | ||||
|         try!(self.set_read_timeout(s, self.timeouts.read)); | ||||
|         self.set_write_timeout(s, self.timeouts.write) | ||||
|     } | ||||
|  | ||||
|  | ||||
|     #[cfg(not(feature = "timeouts"))] | ||||
|     fn set_timeouts<S>(&self, _: &mut S) -> io::Result<()> where S: NetworkStream { | ||||
|     fn set_write_timeout(&self, _s: &NetworkStream, _timeout: Option<Duration>) -> io::Result<()> { | ||||
|         Ok(()) | ||||
|     } | ||||
|  | ||||
|     #[cfg(feature = "timeouts")] | ||||
|     fn set_timeouts<S>(&self, s: &mut S) -> io::Result<()> where S: NetworkStream { | ||||
|         try!(s.set_read_timeout(self._timeouts.read)); | ||||
|         s.set_write_timeout(self._timeouts.write) | ||||
|     fn set_write_timeout(&self, s: &NetworkStream, timeout: Option<Duration>) -> io::Result<()> { | ||||
|         s.set_write_timeout(timeout) | ||||
|     } | ||||
|  | ||||
|     fn keep_alive_loop<W: Write>(&self, mut rdr: BufReader<&mut NetworkStream>, | ||||
|             mut wrt: W, addr: SocketAddr) { | ||||
|         let mut keep_alive = true; | ||||
|         while keep_alive { | ||||
|             let req = match Request::new(&mut rdr, addr) { | ||||
|                 Ok(req) => req, | ||||
|                 Err(Error::Io(ref e)) if e.kind() == ErrorKind::ConnectionAborted => { | ||||
|                     trace!("tcp closed, cancelling keep-alive loop"); | ||||
|                     break; | ||||
|                 } | ||||
|                 Err(Error::Io(e)) => { | ||||
|                     debug!("ioerror in keepalive loop = {:?}", e); | ||||
|                     break; | ||||
|                 } | ||||
|                 Err(e) => { | ||||
|                     //TODO: send a 400 response | ||||
|                     error!("request error = {:?}", e); | ||||
|                     break; | ||||
|                 } | ||||
|             }; | ||||
|     #[cfg(not(feature = "timeouts"))] | ||||
|     fn set_read_timeout(&self, _s: &NetworkStream, _timeout: Option<Duration>) -> io::Result<()> { | ||||
|         Ok(()) | ||||
|     } | ||||
|  | ||||
|     #[cfg(feature = "timeouts")] | ||||
|     fn set_read_timeout(&self, s: &NetworkStream, timeout: Option<Duration>) -> io::Result<()> { | ||||
|         s.set_read_timeout(timeout) | ||||
|     } | ||||
|  | ||||
|             if !self.handle_expect(&req, &mut wrt) { | ||||
|                 break; | ||||
|     fn keep_alive_loop<W: Write>(&self, mut rdr: &mut BufReader<&mut NetworkStream>, | ||||
|             wrt: &mut W, addr: SocketAddr) -> bool { | ||||
|         let req = match Request::new(rdr, addr) { | ||||
|             Ok(req) => req, | ||||
|             Err(Error::Io(ref e)) if e.kind() == ErrorKind::ConnectionAborted => { | ||||
|                 trace!("tcp closed, cancelling keep-alive loop"); | ||||
|                 return false; | ||||
|             } | ||||
|  | ||||
|             keep_alive = http::should_keep_alive(req.version, &req.headers); | ||||
|             let version = req.version; | ||||
|             let mut res_headers = Headers::new(); | ||||
|             if !keep_alive { | ||||
|                 res_headers.set(Connection::close()); | ||||
|             Err(Error::Io(e)) => { | ||||
|                 debug!("ioerror in keepalive loop = {:?}", e); | ||||
|                 return false; | ||||
|             } | ||||
|             { | ||||
|                 let mut res = Response::new(&mut wrt, &mut res_headers); | ||||
|                 res.version = version; | ||||
|                 self.handler.handle(req, res); | ||||
|             Err(e) => { | ||||
|                 //TODO: send a 400 response | ||||
|                 error!("request error = {:?}", e); | ||||
|                 return false; | ||||
|             } | ||||
|         }; | ||||
|  | ||||
|             // if the request was keep-alive, we need to check that the server agrees | ||||
|             // if it wasn't, then the server cannot force it to be true anyways | ||||
|             if keep_alive { | ||||
|                 keep_alive = http::should_keep_alive(version, &res_headers); | ||||
|             } | ||||
|  | ||||
|             debug!("keep_alive = {:?} for {}", keep_alive, addr); | ||||
|         if !self.handle_expect(&req, wrt) { | ||||
|             return false; | ||||
|         } | ||||
|  | ||||
|         if let Err(e) = req.set_read_timeout(self.timeouts.read) { | ||||
|             error!("set_read_timeout {:?}", e); | ||||
|             return false; | ||||
|         } | ||||
|  | ||||
|         let mut keep_alive = self.timeouts.keep_alive.is_some() && | ||||
|             http::should_keep_alive(req.version, &req.headers); | ||||
|         let version = req.version; | ||||
|         let mut res_headers = Headers::new(); | ||||
|         if !keep_alive { | ||||
|             res_headers.set(Connection::close()); | ||||
|         } | ||||
|         { | ||||
|             let mut res = Response::new(wrt, &mut res_headers); | ||||
|             res.version = version; | ||||
|             self.handler.handle(req, res); | ||||
|         } | ||||
|  | ||||
|         // if the request was keep-alive, we need to check that the server agrees | ||||
|         // if it wasn't, then the server cannot force it to be true anyways | ||||
|         if keep_alive { | ||||
|             keep_alive = http::should_keep_alive(version, &res_headers); | ||||
|         } | ||||
|  | ||||
|         debug!("keep_alive = {:?} for {}", keep_alive, addr); | ||||
|         keep_alive | ||||
|     } | ||||
|  | ||||
|     fn handle_expect<W: Write>(&self, req: &Request, wrt: &mut W) -> bool { | ||||
|   | ||||
| @@ -4,6 +4,7 @@ | ||||
| //! target URI, headers, and message body. | ||||
| use std::io::{self, Read}; | ||||
| use std::net::SocketAddr; | ||||
| use std::time::Duration; | ||||
|  | ||||
| use buffer::BufReader; | ||||
| use net::NetworkStream; | ||||
| @@ -64,6 +65,19 @@ impl<'a, 'b: 'a> Request<'a, 'b> { | ||||
|         }) | ||||
|     } | ||||
|  | ||||
|     /// Set the read timeout of the underlying NetworkStream. | ||||
|     #[cfg(feature = "timeouts")] | ||||
|     #[inline] | ||||
|     pub fn set_read_timeout(&self, timeout: Option<Duration>) -> io::Result<()> { | ||||
|         self.body.get_ref().get_ref().set_read_timeout(timeout) | ||||
|     } | ||||
|  | ||||
|     /// Set the read timeout of the underlying NetworkStream. | ||||
|     #[cfg(not(feature = "timeouts"))] | ||||
|     #[inline] | ||||
|     pub fn set_read_timeout(&self, _timeout: Option<Duration>) -> io::Result<()> { | ||||
|         Ok(()) | ||||
|     } | ||||
|     /// Get a reference to the underlying `NetworkStream`. | ||||
|     #[inline] | ||||
|     pub fn downcast_ref<T: NetworkStream>(&self) -> Option<&T> { | ||||
|   | ||||
		Reference in New Issue
	
	Block a user