feat(server): check Response headers for Connection: close in keep_alive loop
BREAKING CHANGE: Usage of Response.deconstruct() and construct() now use a &mut Headers, instead of the struct proper.
This commit is contained in:
		| @@ -33,7 +33,7 @@ pub use net::{Fresh, Streaming}; | |||||||
|  |  | ||||||
| use Error; | use Error; | ||||||
| use buffer::BufReader; | use buffer::BufReader; | ||||||
| use header::{Headers, Expect}; | use header::{Headers, Expect, Connection}; | ||||||
| use http; | use http; | ||||||
| use method::Method; | use method::Method; | ||||||
| use net::{NetworkListener, NetworkStream, HttpListener}; | use net::{NetworkListener, NetworkStream, HttpListener}; | ||||||
| @@ -142,7 +142,7 @@ L: NetworkListener + Send + 'static { | |||||||
|  |  | ||||||
|     debug!("threads = {:?}", threads); |     debug!("threads = {:?}", threads); | ||||||
|     let pool = ListenerPool::new(listener.clone()); |     let pool = ListenerPool::new(listener.clone()); | ||||||
|     let work = move |mut stream| handle_connection(&mut stream, &handler); |     let work = move |mut stream| Worker(&handler).handle_connection(&mut stream); | ||||||
|  |  | ||||||
|     let guard = thread::spawn(move || pool.accept(work, threads)); |     let guard = thread::spawn(move || pool.accept(work, threads)); | ||||||
|  |  | ||||||
| @@ -152,62 +152,95 @@ L: NetworkListener + Send + 'static { | |||||||
|     }) |     }) | ||||||
| } | } | ||||||
|  |  | ||||||
| fn handle_connection<'h, S, H>(mut stream: &mut S, handler: &'h H) | struct Worker<'a, H: Handler + 'static>(&'a H); | ||||||
| where S: NetworkStream + Clone, H: Handler { |  | ||||||
|     debug!("Incoming stream"); |  | ||||||
|     let addr = match stream.peer_addr() { |  | ||||||
|         Ok(addr) => addr, |  | ||||||
|         Err(e) => { |  | ||||||
|             error!("Peer Name error: {:?}", e); |  | ||||||
|             return; |  | ||||||
|         } |  | ||||||
|     }; |  | ||||||
|  |  | ||||||
|     // FIXME: Use Type ascription | impl<'a, H: Handler + 'static> Worker<'a, H> { | ||||||
|     let stream_clone: &mut NetworkStream = &mut stream.clone(); |  | ||||||
|     let mut rdr = BufReader::new(stream_clone); |  | ||||||
|     let mut wrt = BufWriter::new(stream); |  | ||||||
|  |  | ||||||
|     let mut keep_alive = true; |     fn handle_connection<S>(&self, mut stream: &mut S) where S: NetworkStream + Clone { | ||||||
|     while keep_alive { |         debug!("Incoming stream"); | ||||||
|         let req = match Request::new(&mut rdr, addr) { |         let addr = match stream.peer_addr() { | ||||||
|             Ok(req) => req, |             Ok(addr) => addr, | ||||||
|             Err(Error::Io(ref e)) if e.kind() == ErrorKind::ConnectionAborted => { |  | ||||||
|                 trace!("tcp closed, cancelling keep-alive loop"); |  | ||||||
|                 break; |  | ||||||
|             } |  | ||||||
|             Err(Error::Io(e)) => { |  | ||||||
|                 debug!("ioerror in keepalive loop = {:?}", e); |  | ||||||
|                 break; |  | ||||||
|             } |  | ||||||
|             Err(e) => { |             Err(e) => { | ||||||
|                 //TODO: send a 400 response |                 error!("Peer Name error: {:?}", e); | ||||||
|                 error!("request error = {:?}", e); |                 return; | ||||||
|                 break; |  | ||||||
|             } |             } | ||||||
|         }; |         }; | ||||||
|  |  | ||||||
|         if req.version == Http11 && req.headers.get() == Some(&Expect::Continue) { |         // FIXME: Use Type ascription | ||||||
|             let status = handler.check_continue((&req.method, &req.uri, &req.headers)); |         let stream_clone: &mut NetworkStream = &mut stream.clone(); | ||||||
|             match write!(&mut wrt, "{} {}\r\n\r\n", Http11, status) { |         let rdr = BufReader::new(stream_clone); | ||||||
|  |         let wrt = BufWriter::new(stream); | ||||||
|  |  | ||||||
|  |         self.keep_alive_loop(rdr, wrt, addr); | ||||||
|  |         debug!("keep_alive loop ending for {}", addr); | ||||||
|  |     } | ||||||
|  |  | ||||||
|  |     fn keep_alive_loop<W: Write>(&self, mut rdr: BufReader<&mut NetworkStream>, mut wrt: W, addr: SocketAddr) { | ||||||
|  |         let mut keep_alive = true; | ||||||
|  |         while keep_alive { | ||||||
|  |             let req = match Request::new(&mut rdr, addr) { | ||||||
|  |                 Ok(req) => req, | ||||||
|  |                 Err(Error::Io(ref e)) if e.kind() == ErrorKind::ConnectionAborted => { | ||||||
|  |                     trace!("tcp closed, cancelling keep-alive loop"); | ||||||
|  |                     break; | ||||||
|  |                 } | ||||||
|  |                 Err(Error::Io(e)) => { | ||||||
|  |                     debug!("ioerror in keepalive loop = {:?}", e); | ||||||
|  |                     break; | ||||||
|  |                 } | ||||||
|  |                 Err(e) => { | ||||||
|  |                     //TODO: send a 400 response | ||||||
|  |                     error!("request error = {:?}", e); | ||||||
|  |                     break; | ||||||
|  |                 } | ||||||
|  |             }; | ||||||
|  |  | ||||||
|  |  | ||||||
|  |             if !self.handle_expect(&req, &mut wrt) { | ||||||
|  |                 break; | ||||||
|  |             } | ||||||
|  |  | ||||||
|  |             keep_alive = http::should_keep_alive(req.version, &req.headers); | ||||||
|  |             let version = req.version; | ||||||
|  |             let mut res_headers = Headers::new(); | ||||||
|  |             if !keep_alive { | ||||||
|  |                 res_headers.set(Connection::close()); | ||||||
|  |             } | ||||||
|  |             { | ||||||
|  |                 let mut res = Response::new(&mut wrt, &mut res_headers); | ||||||
|  |                 res.version = version; | ||||||
|  |                 self.0.handle(req, res); | ||||||
|  |             } | ||||||
|  |  | ||||||
|  |             // if the request was keep-alive, we need to check that the server agrees | ||||||
|  |             // if it wasn't, then the server cannot force it to be true anyways | ||||||
|  |             if keep_alive { | ||||||
|  |                 keep_alive = http::should_keep_alive(version, &res_headers); | ||||||
|  |             } | ||||||
|  |  | ||||||
|  |             debug!("keep_alive = {:?} for {}", keep_alive, addr); | ||||||
|  |         } | ||||||
|  |   | ||||||
|  |     } | ||||||
|  |  | ||||||
|  |     fn handle_expect<W: Write>(&self, req: &Request, wrt: &mut W) -> bool { | ||||||
|  |          if req.version == Http11 && req.headers.get() == Some(&Expect::Continue) { | ||||||
|  |             let status = self.0.check_continue((&req.method, &req.uri, &req.headers)); | ||||||
|  |             match write!(wrt, "{} {}\r\n\r\n", Http11, status) { | ||||||
|                 Ok(..) => (), |                 Ok(..) => (), | ||||||
|                 Err(e) => { |                 Err(e) => { | ||||||
|                     error!("error writing 100-continue: {:?}", e); |                     error!("error writing 100-continue: {:?}", e); | ||||||
|                     break; |                     return false; | ||||||
|                 } |                 } | ||||||
|             } |             } | ||||||
|  |  | ||||||
|             if status != StatusCode::Continue { |             if status != StatusCode::Continue { | ||||||
|                 debug!("non-100 status ({}) for Expect 100 request", status); |                 debug!("non-100 status ({}) for Expect 100 request", status); | ||||||
|                 break; |                 return false; | ||||||
|             } |             } | ||||||
|         } |         } | ||||||
|  |  | ||||||
|         keep_alive = http::should_keep_alive(req.version, &req.headers); |         true | ||||||
|         let mut res = Response::new(&mut wrt); |  | ||||||
|         res.version = req.version; |  | ||||||
|         handler.handle(req, res); |  | ||||||
|         debug!("keep_alive = {:?}", keep_alive); |  | ||||||
|     } |     } | ||||||
| } | } | ||||||
|  |  | ||||||
| @@ -270,7 +303,7 @@ mod tests { | |||||||
|     use status::StatusCode; |     use status::StatusCode; | ||||||
|     use uri::RequestUri; |     use uri::RequestUri; | ||||||
|  |  | ||||||
|     use super::{Request, Response, Fresh, Handler, handle_connection}; |     use super::{Request, Response, Fresh, Handler, Worker}; | ||||||
|  |  | ||||||
|     #[test] |     #[test] | ||||||
|     fn test_check_continue_default() { |     fn test_check_continue_default() { | ||||||
| @@ -287,7 +320,7 @@ mod tests { | |||||||
|             res.start().unwrap().end().unwrap(); |             res.start().unwrap().end().unwrap(); | ||||||
|         } |         } | ||||||
|  |  | ||||||
|         handle_connection(&mut mock, &handle); |         Worker(&handle).handle_connection(&mut mock); | ||||||
|         let cont = b"HTTP/1.1 100 Continue\r\n\r\n"; |         let cont = b"HTTP/1.1 100 Continue\r\n\r\n"; | ||||||
|         assert_eq!(&mock.write[..cont.len()], cont); |         assert_eq!(&mock.write[..cont.len()], cont); | ||||||
|         let res = b"HTTP/1.1 200 OK\r\n"; |         let res = b"HTTP/1.1 200 OK\r\n"; | ||||||
| @@ -316,7 +349,7 @@ mod tests { | |||||||
|             1234567890\ |             1234567890\ | ||||||
|         "); |         "); | ||||||
|  |  | ||||||
|         handle_connection(&mut mock, &Reject); |         Worker(&Reject).handle_connection(&mut mock); | ||||||
|         assert_eq!(mock.write, &b"HTTP/1.1 417 Expectation Failed\r\n\r\n"[..]); |         assert_eq!(mock.write, &b"HTTP/1.1 417 Expectation Failed\r\n\r\n"[..]); | ||||||
|     } |     } | ||||||
| } | } | ||||||
|   | |||||||
| @@ -28,7 +28,7 @@ pub struct Response<'a, W: Any = Fresh> { | |||||||
|     // The status code for the request. |     // The status code for the request. | ||||||
|     status: status::StatusCode, |     status: status::StatusCode, | ||||||
|     // The outgoing headers on this response. |     // The outgoing headers on this response. | ||||||
|     headers: header::Headers, |     headers: &'a mut header::Headers, | ||||||
|  |  | ||||||
|     _writing: PhantomData<W> |     _writing: PhantomData<W> | ||||||
| } | } | ||||||
| @@ -39,13 +39,13 @@ impl<'a, W: Any> Response<'a, W> { | |||||||
|     pub fn status(&self) -> status::StatusCode { self.status } |     pub fn status(&self) -> status::StatusCode { self.status } | ||||||
|  |  | ||||||
|     /// The headers of this response. |     /// The headers of this response. | ||||||
|     pub fn headers(&self) -> &header::Headers { &self.headers } |     pub fn headers(&self) -> &header::Headers { &*self.headers } | ||||||
|  |  | ||||||
|     /// Construct a Response from its constituent parts. |     /// Construct a Response from its constituent parts. | ||||||
|     pub fn construct(version: version::HttpVersion, |     pub fn construct(version: version::HttpVersion, | ||||||
|                      body: HttpWriter<&'a mut (Write + 'a)>, |                      body: HttpWriter<&'a mut (Write + 'a)>, | ||||||
|                      status: status::StatusCode, |                      status: status::StatusCode, | ||||||
|                      headers: header::Headers) -> Response<'a, Fresh> { |                      headers: &'a mut header::Headers) -> Response<'a, Fresh> { | ||||||
|         Response { |         Response { | ||||||
|             status: status, |             status: status, | ||||||
|             version: version, |             version: version, | ||||||
| @@ -57,7 +57,7 @@ impl<'a, W: Any> Response<'a, W> { | |||||||
|  |  | ||||||
|     /// Deconstruct this Response into its constituent parts. |     /// Deconstruct this Response into its constituent parts. | ||||||
|     pub fn deconstruct(self) -> (version::HttpVersion, HttpWriter<&'a mut (Write + 'a)>, |     pub fn deconstruct(self) -> (version::HttpVersion, HttpWriter<&'a mut (Write + 'a)>, | ||||||
|                                  status::StatusCode, header::Headers) { |                                  status::StatusCode, &'a mut header::Headers) { | ||||||
|         unsafe { |         unsafe { | ||||||
|             let parts = ( |             let parts = ( | ||||||
|                 self.version, |                 self.version, | ||||||
| @@ -114,11 +114,11 @@ impl<'a, W: Any> Response<'a, W> { | |||||||
| impl<'a> Response<'a, Fresh> { | impl<'a> Response<'a, Fresh> { | ||||||
|     /// Creates a new Response that can be used to write to a network stream. |     /// Creates a new Response that can be used to write to a network stream. | ||||||
|     #[inline] |     #[inline] | ||||||
|     pub fn new(stream: &'a mut (Write + 'a)) -> Response<'a, Fresh> { |     pub fn new(stream: &'a mut (Write + 'a), headers: &'a mut header::Headers) -> Response<'a, Fresh> { | ||||||
|         Response { |         Response { | ||||||
|             status: status::StatusCode::Ok, |             status: status::StatusCode::Ok, | ||||||
|             version: version::HttpVersion::Http11, |             version: version::HttpVersion::Http11, | ||||||
|             headers: header::Headers::new(), |             headers: headers, | ||||||
|             body: ThroughWriter(stream), |             body: ThroughWriter(stream), | ||||||
|             _writing: PhantomData, |             _writing: PhantomData, | ||||||
|         } |         } | ||||||
| @@ -165,7 +165,7 @@ impl<'a> Response<'a, Fresh> { | |||||||
|  |  | ||||||
|     /// Get a mutable reference to the Headers. |     /// Get a mutable reference to the Headers. | ||||||
|     #[inline] |     #[inline] | ||||||
|     pub fn headers_mut(&mut self) -> &mut header::Headers { &mut self.headers } |     pub fn headers_mut(&mut self) -> &mut header::Headers { self.headers } | ||||||
| } | } | ||||||
|  |  | ||||||
|  |  | ||||||
| @@ -231,6 +231,7 @@ impl<'a, T: Any> Drop for Response<'a, T> { | |||||||
|  |  | ||||||
| #[cfg(test)] | #[cfg(test)] | ||||||
| mod tests { | mod tests { | ||||||
|  |     use header::Headers; | ||||||
|     use mock::MockStream; |     use mock::MockStream; | ||||||
|     use super::Response; |     use super::Response; | ||||||
|  |  | ||||||
| @@ -252,9 +253,10 @@ mod tests { | |||||||
|  |  | ||||||
|     #[test] |     #[test] | ||||||
|     fn test_fresh_start() { |     fn test_fresh_start() { | ||||||
|  |         let mut headers = Headers::new(); | ||||||
|         let mut stream = MockStream::new(); |         let mut stream = MockStream::new(); | ||||||
|         { |         { | ||||||
|             let res = Response::new(&mut stream); |             let res = Response::new(&mut stream, &mut headers); | ||||||
|             res.start().unwrap().deconstruct(); |             res.start().unwrap().deconstruct(); | ||||||
|         } |         } | ||||||
|  |  | ||||||
| @@ -268,9 +270,10 @@ mod tests { | |||||||
|  |  | ||||||
|     #[test] |     #[test] | ||||||
|     fn test_streaming_end() { |     fn test_streaming_end() { | ||||||
|  |         let mut headers = Headers::new(); | ||||||
|         let mut stream = MockStream::new(); |         let mut stream = MockStream::new(); | ||||||
|         { |         { | ||||||
|             let res = Response::new(&mut stream); |             let res = Response::new(&mut stream, &mut headers); | ||||||
|             res.start().unwrap().end().unwrap(); |             res.start().unwrap().end().unwrap(); | ||||||
|         } |         } | ||||||
|  |  | ||||||
| @@ -287,9 +290,10 @@ mod tests { | |||||||
|     #[test] |     #[test] | ||||||
|     fn test_fresh_drop() { |     fn test_fresh_drop() { | ||||||
|         use status::StatusCode; |         use status::StatusCode; | ||||||
|  |         let mut headers = Headers::new(); | ||||||
|         let mut stream = MockStream::new(); |         let mut stream = MockStream::new(); | ||||||
|         { |         { | ||||||
|             let mut res = Response::new(&mut stream); |             let mut res = Response::new(&mut stream, &mut headers); | ||||||
|             *res.status_mut() = StatusCode::NotFound; |             *res.status_mut() = StatusCode::NotFound; | ||||||
|         } |         } | ||||||
|  |  | ||||||
| @@ -307,9 +311,10 @@ mod tests { | |||||||
|     fn test_streaming_drop() { |     fn test_streaming_drop() { | ||||||
|         use std::io::Write; |         use std::io::Write; | ||||||
|         use status::StatusCode; |         use status::StatusCode; | ||||||
|  |         let mut headers = Headers::new(); | ||||||
|         let mut stream = MockStream::new(); |         let mut stream = MockStream::new(); | ||||||
|         { |         { | ||||||
|             let mut res = Response::new(&mut stream); |             let mut res = Response::new(&mut stream, &mut headers); | ||||||
|             *res.status_mut() = StatusCode::NotFound; |             *res.status_mut() = StatusCode::NotFound; | ||||||
|             let mut stream = res.start().unwrap(); |             let mut stream = res.start().unwrap(); | ||||||
|             stream.write_all(b"foo").unwrap(); |             stream.write_all(b"foo").unwrap(); | ||||||
|   | |||||||
		Reference in New Issue
	
	Block a user