feat(server): check Response headers for Connection: close in keep_alive loop
BREAKING CHANGE: Usage of Response.deconstruct() and construct() now use a &mut Headers, instead of the struct proper.
This commit is contained in:
		| @@ -33,7 +33,7 @@ pub use net::{Fresh, Streaming}; | ||||
|  | ||||
| use Error; | ||||
| use buffer::BufReader; | ||||
| use header::{Headers, Expect}; | ||||
| use header::{Headers, Expect, Connection}; | ||||
| use http; | ||||
| use method::Method; | ||||
| use net::{NetworkListener, NetworkStream, HttpListener}; | ||||
| @@ -142,7 +142,7 @@ L: NetworkListener + Send + 'static { | ||||
|  | ||||
|     debug!("threads = {:?}", threads); | ||||
|     let pool = ListenerPool::new(listener.clone()); | ||||
|     let work = move |mut stream| handle_connection(&mut stream, &handler); | ||||
|     let work = move |mut stream| Worker(&handler).handle_connection(&mut stream); | ||||
|  | ||||
|     let guard = thread::spawn(move || pool.accept(work, threads)); | ||||
|  | ||||
| @@ -152,8 +152,11 @@ L: NetworkListener + Send + 'static { | ||||
|     }) | ||||
| } | ||||
|  | ||||
| fn handle_connection<'h, S, H>(mut stream: &mut S, handler: &'h H) | ||||
| where S: NetworkStream + Clone, H: Handler { | ||||
| struct Worker<'a, H: Handler + 'static>(&'a H); | ||||
|  | ||||
| impl<'a, H: Handler + 'static> Worker<'a, H> { | ||||
|  | ||||
|     fn handle_connection<S>(&self, mut stream: &mut S) where S: NetworkStream + Clone { | ||||
|         debug!("Incoming stream"); | ||||
|         let addr = match stream.peer_addr() { | ||||
|             Ok(addr) => addr, | ||||
| @@ -165,9 +168,14 @@ where S: NetworkStream + Clone, H: Handler { | ||||
|  | ||||
|         // FIXME: Use Type ascription | ||||
|         let stream_clone: &mut NetworkStream = &mut stream.clone(); | ||||
|     let mut rdr = BufReader::new(stream_clone); | ||||
|     let mut wrt = BufWriter::new(stream); | ||||
|         let rdr = BufReader::new(stream_clone); | ||||
|         let wrt = BufWriter::new(stream); | ||||
|  | ||||
|         self.keep_alive_loop(rdr, wrt, addr); | ||||
|         debug!("keep_alive loop ending for {}", addr); | ||||
|     } | ||||
|  | ||||
|     fn keep_alive_loop<W: Write>(&self, mut rdr: BufReader<&mut NetworkStream>, mut wrt: W, addr: SocketAddr) { | ||||
|         let mut keep_alive = true; | ||||
|         while keep_alive { | ||||
|             let req = match Request::new(&mut rdr, addr) { | ||||
| @@ -187,27 +195,52 @@ where S: NetworkStream + Clone, H: Handler { | ||||
|                 } | ||||
|             }; | ||||
|  | ||||
|  | ||||
|             if !self.handle_expect(&req, &mut wrt) { | ||||
|                 break; | ||||
|             } | ||||
|  | ||||
|             keep_alive = http::should_keep_alive(req.version, &req.headers); | ||||
|             let version = req.version; | ||||
|             let mut res_headers = Headers::new(); | ||||
|             if !keep_alive { | ||||
|                 res_headers.set(Connection::close()); | ||||
|             } | ||||
|             { | ||||
|                 let mut res = Response::new(&mut wrt, &mut res_headers); | ||||
|                 res.version = version; | ||||
|                 self.0.handle(req, res); | ||||
|             } | ||||
|  | ||||
|             // if the request was keep-alive, we need to check that the server agrees | ||||
|             // if it wasn't, then the server cannot force it to be true anyways | ||||
|             if keep_alive { | ||||
|                 keep_alive = http::should_keep_alive(version, &res_headers); | ||||
|             } | ||||
|  | ||||
|             debug!("keep_alive = {:?} for {}", keep_alive, addr); | ||||
|         } | ||||
|   | ||||
|     } | ||||
|  | ||||
|     fn handle_expect<W: Write>(&self, req: &Request, wrt: &mut W) -> bool { | ||||
|          if req.version == Http11 && req.headers.get() == Some(&Expect::Continue) { | ||||
|             let status = handler.check_continue((&req.method, &req.uri, &req.headers)); | ||||
|             match write!(&mut wrt, "{} {}\r\n\r\n", Http11, status) { | ||||
|             let status = self.0.check_continue((&req.method, &req.uri, &req.headers)); | ||||
|             match write!(wrt, "{} {}\r\n\r\n", Http11, status) { | ||||
|                 Ok(..) => (), | ||||
|                 Err(e) => { | ||||
|                     error!("error writing 100-continue: {:?}", e); | ||||
|                     break; | ||||
|                     return false; | ||||
|                 } | ||||
|             } | ||||
|  | ||||
|             if status != StatusCode::Continue { | ||||
|                 debug!("non-100 status ({}) for Expect 100 request", status); | ||||
|                 break; | ||||
|                 return false; | ||||
|             } | ||||
|         } | ||||
|  | ||||
|         keep_alive = http::should_keep_alive(req.version, &req.headers); | ||||
|         let mut res = Response::new(&mut wrt); | ||||
|         res.version = req.version; | ||||
|         handler.handle(req, res); | ||||
|         debug!("keep_alive = {:?}", keep_alive); | ||||
|         true | ||||
|     } | ||||
| } | ||||
|  | ||||
| @@ -270,7 +303,7 @@ mod tests { | ||||
|     use status::StatusCode; | ||||
|     use uri::RequestUri; | ||||
|  | ||||
|     use super::{Request, Response, Fresh, Handler, handle_connection}; | ||||
|     use super::{Request, Response, Fresh, Handler, Worker}; | ||||
|  | ||||
|     #[test] | ||||
|     fn test_check_continue_default() { | ||||
| @@ -287,7 +320,7 @@ mod tests { | ||||
|             res.start().unwrap().end().unwrap(); | ||||
|         } | ||||
|  | ||||
|         handle_connection(&mut mock, &handle); | ||||
|         Worker(&handle).handle_connection(&mut mock); | ||||
|         let cont = b"HTTP/1.1 100 Continue\r\n\r\n"; | ||||
|         assert_eq!(&mock.write[..cont.len()], cont); | ||||
|         let res = b"HTTP/1.1 200 OK\r\n"; | ||||
| @@ -316,7 +349,7 @@ mod tests { | ||||
|             1234567890\ | ||||
|         "); | ||||
|  | ||||
|         handle_connection(&mut mock, &Reject); | ||||
|         Worker(&Reject).handle_connection(&mut mock); | ||||
|         assert_eq!(mock.write, &b"HTTP/1.1 417 Expectation Failed\r\n\r\n"[..]); | ||||
|     } | ||||
| } | ||||
|   | ||||
| @@ -28,7 +28,7 @@ pub struct Response<'a, W: Any = Fresh> { | ||||
|     // The status code for the request. | ||||
|     status: status::StatusCode, | ||||
|     // The outgoing headers on this response. | ||||
|     headers: header::Headers, | ||||
|     headers: &'a mut header::Headers, | ||||
|  | ||||
|     _writing: PhantomData<W> | ||||
| } | ||||
| @@ -39,13 +39,13 @@ impl<'a, W: Any> Response<'a, W> { | ||||
|     pub fn status(&self) -> status::StatusCode { self.status } | ||||
|  | ||||
|     /// The headers of this response. | ||||
|     pub fn headers(&self) -> &header::Headers { &self.headers } | ||||
|     pub fn headers(&self) -> &header::Headers { &*self.headers } | ||||
|  | ||||
|     /// Construct a Response from its constituent parts. | ||||
|     pub fn construct(version: version::HttpVersion, | ||||
|                      body: HttpWriter<&'a mut (Write + 'a)>, | ||||
|                      status: status::StatusCode, | ||||
|                      headers: header::Headers) -> Response<'a, Fresh> { | ||||
|                      headers: &'a mut header::Headers) -> Response<'a, Fresh> { | ||||
|         Response { | ||||
|             status: status, | ||||
|             version: version, | ||||
| @@ -57,7 +57,7 @@ impl<'a, W: Any> Response<'a, W> { | ||||
|  | ||||
|     /// Deconstruct this Response into its constituent parts. | ||||
|     pub fn deconstruct(self) -> (version::HttpVersion, HttpWriter<&'a mut (Write + 'a)>, | ||||
|                                  status::StatusCode, header::Headers) { | ||||
|                                  status::StatusCode, &'a mut header::Headers) { | ||||
|         unsafe { | ||||
|             let parts = ( | ||||
|                 self.version, | ||||
| @@ -114,11 +114,11 @@ impl<'a, W: Any> Response<'a, W> { | ||||
| impl<'a> Response<'a, Fresh> { | ||||
|     /// Creates a new Response that can be used to write to a network stream. | ||||
|     #[inline] | ||||
|     pub fn new(stream: &'a mut (Write + 'a)) -> Response<'a, Fresh> { | ||||
|     pub fn new(stream: &'a mut (Write + 'a), headers: &'a mut header::Headers) -> Response<'a, Fresh> { | ||||
|         Response { | ||||
|             status: status::StatusCode::Ok, | ||||
|             version: version::HttpVersion::Http11, | ||||
|             headers: header::Headers::new(), | ||||
|             headers: headers, | ||||
|             body: ThroughWriter(stream), | ||||
|             _writing: PhantomData, | ||||
|         } | ||||
| @@ -165,7 +165,7 @@ impl<'a> Response<'a, Fresh> { | ||||
|  | ||||
|     /// Get a mutable reference to the Headers. | ||||
|     #[inline] | ||||
|     pub fn headers_mut(&mut self) -> &mut header::Headers { &mut self.headers } | ||||
|     pub fn headers_mut(&mut self) -> &mut header::Headers { self.headers } | ||||
| } | ||||
|  | ||||
|  | ||||
| @@ -231,6 +231,7 @@ impl<'a, T: Any> Drop for Response<'a, T> { | ||||
|  | ||||
| #[cfg(test)] | ||||
| mod tests { | ||||
|     use header::Headers; | ||||
|     use mock::MockStream; | ||||
|     use super::Response; | ||||
|  | ||||
| @@ -252,9 +253,10 @@ mod tests { | ||||
|  | ||||
|     #[test] | ||||
|     fn test_fresh_start() { | ||||
|         let mut headers = Headers::new(); | ||||
|         let mut stream = MockStream::new(); | ||||
|         { | ||||
|             let res = Response::new(&mut stream); | ||||
|             let res = Response::new(&mut stream, &mut headers); | ||||
|             res.start().unwrap().deconstruct(); | ||||
|         } | ||||
|  | ||||
| @@ -268,9 +270,10 @@ mod tests { | ||||
|  | ||||
|     #[test] | ||||
|     fn test_streaming_end() { | ||||
|         let mut headers = Headers::new(); | ||||
|         let mut stream = MockStream::new(); | ||||
|         { | ||||
|             let res = Response::new(&mut stream); | ||||
|             let res = Response::new(&mut stream, &mut headers); | ||||
|             res.start().unwrap().end().unwrap(); | ||||
|         } | ||||
|  | ||||
| @@ -287,9 +290,10 @@ mod tests { | ||||
|     #[test] | ||||
|     fn test_fresh_drop() { | ||||
|         use status::StatusCode; | ||||
|         let mut headers = Headers::new(); | ||||
|         let mut stream = MockStream::new(); | ||||
|         { | ||||
|             let mut res = Response::new(&mut stream); | ||||
|             let mut res = Response::new(&mut stream, &mut headers); | ||||
|             *res.status_mut() = StatusCode::NotFound; | ||||
|         } | ||||
|  | ||||
| @@ -307,9 +311,10 @@ mod tests { | ||||
|     fn test_streaming_drop() { | ||||
|         use std::io::Write; | ||||
|         use status::StatusCode; | ||||
|         let mut headers = Headers::new(); | ||||
|         let mut stream = MockStream::new(); | ||||
|         { | ||||
|             let mut res = Response::new(&mut stream); | ||||
|             let mut res = Response::new(&mut stream, &mut headers); | ||||
|             *res.status_mut() = StatusCode::NotFound; | ||||
|             let mut stream = res.start().unwrap(); | ||||
|             stream.write_all(b"foo").unwrap(); | ||||
|   | ||||
		Reference in New Issue
	
	Block a user