Merge pull request #527 from hyperium/server-keep-alive

feat(server): check Response headers for Connection: close in keep_alive loop
This commit is contained in:
Sean McArthur
2015-05-12 18:18:41 -07:00
2 changed files with 93 additions and 55 deletions

View File

@@ -33,7 +33,7 @@ pub use net::{Fresh, Streaming};
use Error; use Error;
use buffer::BufReader; use buffer::BufReader;
use header::{Headers, Expect}; use header::{Headers, Expect, Connection};
use http; use http;
use method::Method; use method::Method;
use net::{NetworkListener, NetworkStream, HttpListener}; use net::{NetworkListener, NetworkStream, HttpListener};
@@ -142,7 +142,7 @@ L: NetworkListener + Send + 'static {
debug!("threads = {:?}", threads); debug!("threads = {:?}", threads);
let pool = ListenerPool::new(listener.clone()); let pool = ListenerPool::new(listener.clone());
let work = move |mut stream| handle_connection(&mut stream, &handler); let work = move |mut stream| Worker(&handler).handle_connection(&mut stream);
let guard = thread::spawn(move || pool.accept(work, threads)); let guard = thread::spawn(move || pool.accept(work, threads));
@@ -152,8 +152,11 @@ L: NetworkListener + Send + 'static {
}) })
} }
fn handle_connection<'h, S, H>(mut stream: &mut S, handler: &'h H) struct Worker<'a, H: Handler + 'static>(&'a H);
where S: NetworkStream + Clone, H: Handler {
impl<'a, H: Handler + 'static> Worker<'a, H> {
fn handle_connection<S>(&self, mut stream: &mut S) where S: NetworkStream + Clone {
debug!("Incoming stream"); debug!("Incoming stream");
let addr = match stream.peer_addr() { let addr = match stream.peer_addr() {
Ok(addr) => addr, Ok(addr) => addr,
@@ -165,9 +168,14 @@ where S: NetworkStream + Clone, H: Handler {
// FIXME: Use Type ascription // FIXME: Use Type ascription
let stream_clone: &mut NetworkStream = &mut stream.clone(); let stream_clone: &mut NetworkStream = &mut stream.clone();
let mut rdr = BufReader::new(stream_clone); let rdr = BufReader::new(stream_clone);
let mut wrt = BufWriter::new(stream); let wrt = BufWriter::new(stream);
self.keep_alive_loop(rdr, wrt, addr);
debug!("keep_alive loop ending for {}", addr);
}
fn keep_alive_loop<W: Write>(&self, mut rdr: BufReader<&mut NetworkStream>, mut wrt: W, addr: SocketAddr) {
let mut keep_alive = true; let mut keep_alive = true;
while keep_alive { while keep_alive {
let req = match Request::new(&mut rdr, addr) { let req = match Request::new(&mut rdr, addr) {
@@ -187,27 +195,52 @@ where S: NetworkStream + Clone, H: Handler {
} }
}; };
if !self.handle_expect(&req, &mut wrt) {
break;
}
keep_alive = http::should_keep_alive(req.version, &req.headers);
let version = req.version;
let mut res_headers = Headers::new();
if !keep_alive {
res_headers.set(Connection::close());
}
{
let mut res = Response::new(&mut wrt, &mut res_headers);
res.version = version;
self.0.handle(req, res);
}
// if the request was keep-alive, we need to check that the server agrees
// if it wasn't, then the server cannot force it to be true anyways
if keep_alive {
keep_alive = http::should_keep_alive(version, &res_headers);
}
debug!("keep_alive = {:?} for {}", keep_alive, addr);
}
}
fn handle_expect<W: Write>(&self, req: &Request, wrt: &mut W) -> bool {
if req.version == Http11 && req.headers.get() == Some(&Expect::Continue) { if req.version == Http11 && req.headers.get() == Some(&Expect::Continue) {
let status = handler.check_continue((&req.method, &req.uri, &req.headers)); let status = self.0.check_continue((&req.method, &req.uri, &req.headers));
match write!(&mut wrt, "{} {}\r\n\r\n", Http11, status) { match write!(wrt, "{} {}\r\n\r\n", Http11, status) {
Ok(..) => (), Ok(..) => (),
Err(e) => { Err(e) => {
error!("error writing 100-continue: {:?}", e); error!("error writing 100-continue: {:?}", e);
break; return false;
} }
} }
if status != StatusCode::Continue { if status != StatusCode::Continue {
debug!("non-100 status ({}) for Expect 100 request", status); debug!("non-100 status ({}) for Expect 100 request", status);
break; return false;
} }
} }
keep_alive = http::should_keep_alive(req.version, &req.headers); true
let mut res = Response::new(&mut wrt);
res.version = req.version;
handler.handle(req, res);
debug!("keep_alive = {:?}", keep_alive);
} }
} }
@@ -270,7 +303,7 @@ mod tests {
use status::StatusCode; use status::StatusCode;
use uri::RequestUri; use uri::RequestUri;
use super::{Request, Response, Fresh, Handler, handle_connection}; use super::{Request, Response, Fresh, Handler, Worker};
#[test] #[test]
fn test_check_continue_default() { fn test_check_continue_default() {
@@ -287,7 +320,7 @@ mod tests {
res.start().unwrap().end().unwrap(); res.start().unwrap().end().unwrap();
} }
handle_connection(&mut mock, &handle); Worker(&handle).handle_connection(&mut mock);
let cont = b"HTTP/1.1 100 Continue\r\n\r\n"; let cont = b"HTTP/1.1 100 Continue\r\n\r\n";
assert_eq!(&mock.write[..cont.len()], cont); assert_eq!(&mock.write[..cont.len()], cont);
let res = b"HTTP/1.1 200 OK\r\n"; let res = b"HTTP/1.1 200 OK\r\n";
@@ -316,7 +349,7 @@ mod tests {
1234567890\ 1234567890\
"); ");
handle_connection(&mut mock, &Reject); Worker(&Reject).handle_connection(&mut mock);
assert_eq!(mock.write, &b"HTTP/1.1 417 Expectation Failed\r\n\r\n"[..]); assert_eq!(mock.write, &b"HTTP/1.1 417 Expectation Failed\r\n\r\n"[..]);
} }
} }

View File

@@ -28,7 +28,7 @@ pub struct Response<'a, W: Any = Fresh> {
// The status code for the request. // The status code for the request.
status: status::StatusCode, status: status::StatusCode,
// The outgoing headers on this response. // The outgoing headers on this response.
headers: header::Headers, headers: &'a mut header::Headers,
_writing: PhantomData<W> _writing: PhantomData<W>
} }
@@ -39,13 +39,13 @@ impl<'a, W: Any> Response<'a, W> {
pub fn status(&self) -> status::StatusCode { self.status } pub fn status(&self) -> status::StatusCode { self.status }
/// The headers of this response. /// The headers of this response.
pub fn headers(&self) -> &header::Headers { &self.headers } pub fn headers(&self) -> &header::Headers { &*self.headers }
/// Construct a Response from its constituent parts. /// Construct a Response from its constituent parts.
pub fn construct(version: version::HttpVersion, pub fn construct(version: version::HttpVersion,
body: HttpWriter<&'a mut (Write + 'a)>, body: HttpWriter<&'a mut (Write + 'a)>,
status: status::StatusCode, status: status::StatusCode,
headers: header::Headers) -> Response<'a, Fresh> { headers: &'a mut header::Headers) -> Response<'a, Fresh> {
Response { Response {
status: status, status: status,
version: version, version: version,
@@ -57,7 +57,7 @@ impl<'a, W: Any> Response<'a, W> {
/// Deconstruct this Response into its constituent parts. /// Deconstruct this Response into its constituent parts.
pub fn deconstruct(self) -> (version::HttpVersion, HttpWriter<&'a mut (Write + 'a)>, pub fn deconstruct(self) -> (version::HttpVersion, HttpWriter<&'a mut (Write + 'a)>,
status::StatusCode, header::Headers) { status::StatusCode, &'a mut header::Headers) {
unsafe { unsafe {
let parts = ( let parts = (
self.version, self.version,
@@ -114,11 +114,11 @@ impl<'a, W: Any> Response<'a, W> {
impl<'a> Response<'a, Fresh> { impl<'a> Response<'a, Fresh> {
/// Creates a new Response that can be used to write to a network stream. /// Creates a new Response that can be used to write to a network stream.
#[inline] #[inline]
pub fn new(stream: &'a mut (Write + 'a)) -> Response<'a, Fresh> { pub fn new(stream: &'a mut (Write + 'a), headers: &'a mut header::Headers) -> Response<'a, Fresh> {
Response { Response {
status: status::StatusCode::Ok, status: status::StatusCode::Ok,
version: version::HttpVersion::Http11, version: version::HttpVersion::Http11,
headers: header::Headers::new(), headers: headers,
body: ThroughWriter(stream), body: ThroughWriter(stream),
_writing: PhantomData, _writing: PhantomData,
} }
@@ -165,7 +165,7 @@ impl<'a> Response<'a, Fresh> {
/// Get a mutable reference to the Headers. /// Get a mutable reference to the Headers.
#[inline] #[inline]
pub fn headers_mut(&mut self) -> &mut header::Headers { &mut self.headers } pub fn headers_mut(&mut self) -> &mut header::Headers { self.headers }
} }
@@ -231,6 +231,7 @@ impl<'a, T: Any> Drop for Response<'a, T> {
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use header::Headers;
use mock::MockStream; use mock::MockStream;
use super::Response; use super::Response;
@@ -252,9 +253,10 @@ mod tests {
#[test] #[test]
fn test_fresh_start() { fn test_fresh_start() {
let mut headers = Headers::new();
let mut stream = MockStream::new(); let mut stream = MockStream::new();
{ {
let res = Response::new(&mut stream); let res = Response::new(&mut stream, &mut headers);
res.start().unwrap().deconstruct(); res.start().unwrap().deconstruct();
} }
@@ -268,9 +270,10 @@ mod tests {
#[test] #[test]
fn test_streaming_end() { fn test_streaming_end() {
let mut headers = Headers::new();
let mut stream = MockStream::new(); let mut stream = MockStream::new();
{ {
let res = Response::new(&mut stream); let res = Response::new(&mut stream, &mut headers);
res.start().unwrap().end().unwrap(); res.start().unwrap().end().unwrap();
} }
@@ -287,9 +290,10 @@ mod tests {
#[test] #[test]
fn test_fresh_drop() { fn test_fresh_drop() {
use status::StatusCode; use status::StatusCode;
let mut headers = Headers::new();
let mut stream = MockStream::new(); let mut stream = MockStream::new();
{ {
let mut res = Response::new(&mut stream); let mut res = Response::new(&mut stream, &mut headers);
*res.status_mut() = StatusCode::NotFound; *res.status_mut() = StatusCode::NotFound;
} }
@@ -307,9 +311,10 @@ mod tests {
fn test_streaming_drop() { fn test_streaming_drop() {
use std::io::Write; use std::io::Write;
use status::StatusCode; use status::StatusCode;
let mut headers = Headers::new();
let mut stream = MockStream::new(); let mut stream = MockStream::new();
{ {
let mut res = Response::new(&mut stream); let mut res = Response::new(&mut stream, &mut headers);
*res.status_mut() = StatusCode::NotFound; *res.status_mut() = StatusCode::NotFound;
let mut stream = res.start().unwrap(); let mut stream = res.start().unwrap();
stream.write_all(b"foo").unwrap(); stream.write_all(b"foo").unwrap();