Merge pull request #527 from hyperium/server-keep-alive
feat(server): check Response headers for Connection: close in keep_alive loop
This commit is contained in:
@@ -33,7 +33,7 @@ pub use net::{Fresh, Streaming};
|
|||||||
|
|
||||||
use Error;
|
use Error;
|
||||||
use buffer::BufReader;
|
use buffer::BufReader;
|
||||||
use header::{Headers, Expect};
|
use header::{Headers, Expect, Connection};
|
||||||
use http;
|
use http;
|
||||||
use method::Method;
|
use method::Method;
|
||||||
use net::{NetworkListener, NetworkStream, HttpListener};
|
use net::{NetworkListener, NetworkStream, HttpListener};
|
||||||
@@ -142,7 +142,7 @@ L: NetworkListener + Send + 'static {
|
|||||||
|
|
||||||
debug!("threads = {:?}", threads);
|
debug!("threads = {:?}", threads);
|
||||||
let pool = ListenerPool::new(listener.clone());
|
let pool = ListenerPool::new(listener.clone());
|
||||||
let work = move |mut stream| handle_connection(&mut stream, &handler);
|
let work = move |mut stream| Worker(&handler).handle_connection(&mut stream);
|
||||||
|
|
||||||
let guard = thread::spawn(move || pool.accept(work, threads));
|
let guard = thread::spawn(move || pool.accept(work, threads));
|
||||||
|
|
||||||
@@ -152,62 +152,95 @@ L: NetworkListener + Send + 'static {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
fn handle_connection<'h, S, H>(mut stream: &mut S, handler: &'h H)
|
struct Worker<'a, H: Handler + 'static>(&'a H);
|
||||||
where S: NetworkStream + Clone, H: Handler {
|
|
||||||
debug!("Incoming stream");
|
|
||||||
let addr = match stream.peer_addr() {
|
|
||||||
Ok(addr) => addr,
|
|
||||||
Err(e) => {
|
|
||||||
error!("Peer Name error: {:?}", e);
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
// FIXME: Use Type ascription
|
impl<'a, H: Handler + 'static> Worker<'a, H> {
|
||||||
let stream_clone: &mut NetworkStream = &mut stream.clone();
|
|
||||||
let mut rdr = BufReader::new(stream_clone);
|
|
||||||
let mut wrt = BufWriter::new(stream);
|
|
||||||
|
|
||||||
let mut keep_alive = true;
|
fn handle_connection<S>(&self, mut stream: &mut S) where S: NetworkStream + Clone {
|
||||||
while keep_alive {
|
debug!("Incoming stream");
|
||||||
let req = match Request::new(&mut rdr, addr) {
|
let addr = match stream.peer_addr() {
|
||||||
Ok(req) => req,
|
Ok(addr) => addr,
|
||||||
Err(Error::Io(ref e)) if e.kind() == ErrorKind::ConnectionAborted => {
|
|
||||||
trace!("tcp closed, cancelling keep-alive loop");
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
Err(Error::Io(e)) => {
|
|
||||||
debug!("ioerror in keepalive loop = {:?}", e);
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
Err(e) => {
|
Err(e) => {
|
||||||
//TODO: send a 400 response
|
error!("Peer Name error: {:?}", e);
|
||||||
error!("request error = {:?}", e);
|
return;
|
||||||
break;
|
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
if req.version == Http11 && req.headers.get() == Some(&Expect::Continue) {
|
// FIXME: Use Type ascription
|
||||||
let status = handler.check_continue((&req.method, &req.uri, &req.headers));
|
let stream_clone: &mut NetworkStream = &mut stream.clone();
|
||||||
match write!(&mut wrt, "{} {}\r\n\r\n", Http11, status) {
|
let rdr = BufReader::new(stream_clone);
|
||||||
|
let wrt = BufWriter::new(stream);
|
||||||
|
|
||||||
|
self.keep_alive_loop(rdr, wrt, addr);
|
||||||
|
debug!("keep_alive loop ending for {}", addr);
|
||||||
|
}
|
||||||
|
|
||||||
|
fn keep_alive_loop<W: Write>(&self, mut rdr: BufReader<&mut NetworkStream>, mut wrt: W, addr: SocketAddr) {
|
||||||
|
let mut keep_alive = true;
|
||||||
|
while keep_alive {
|
||||||
|
let req = match Request::new(&mut rdr, addr) {
|
||||||
|
Ok(req) => req,
|
||||||
|
Err(Error::Io(ref e)) if e.kind() == ErrorKind::ConnectionAborted => {
|
||||||
|
trace!("tcp closed, cancelling keep-alive loop");
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
Err(Error::Io(e)) => {
|
||||||
|
debug!("ioerror in keepalive loop = {:?}", e);
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
Err(e) => {
|
||||||
|
//TODO: send a 400 response
|
||||||
|
error!("request error = {:?}", e);
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
|
||||||
|
if !self.handle_expect(&req, &mut wrt) {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
|
||||||
|
keep_alive = http::should_keep_alive(req.version, &req.headers);
|
||||||
|
let version = req.version;
|
||||||
|
let mut res_headers = Headers::new();
|
||||||
|
if !keep_alive {
|
||||||
|
res_headers.set(Connection::close());
|
||||||
|
}
|
||||||
|
{
|
||||||
|
let mut res = Response::new(&mut wrt, &mut res_headers);
|
||||||
|
res.version = version;
|
||||||
|
self.0.handle(req, res);
|
||||||
|
}
|
||||||
|
|
||||||
|
// if the request was keep-alive, we need to check that the server agrees
|
||||||
|
// if it wasn't, then the server cannot force it to be true anyways
|
||||||
|
if keep_alive {
|
||||||
|
keep_alive = http::should_keep_alive(version, &res_headers);
|
||||||
|
}
|
||||||
|
|
||||||
|
debug!("keep_alive = {:?} for {}", keep_alive, addr);
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
fn handle_expect<W: Write>(&self, req: &Request, wrt: &mut W) -> bool {
|
||||||
|
if req.version == Http11 && req.headers.get() == Some(&Expect::Continue) {
|
||||||
|
let status = self.0.check_continue((&req.method, &req.uri, &req.headers));
|
||||||
|
match write!(wrt, "{} {}\r\n\r\n", Http11, status) {
|
||||||
Ok(..) => (),
|
Ok(..) => (),
|
||||||
Err(e) => {
|
Err(e) => {
|
||||||
error!("error writing 100-continue: {:?}", e);
|
error!("error writing 100-continue: {:?}", e);
|
||||||
break;
|
return false;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if status != StatusCode::Continue {
|
if status != StatusCode::Continue {
|
||||||
debug!("non-100 status ({}) for Expect 100 request", status);
|
debug!("non-100 status ({}) for Expect 100 request", status);
|
||||||
break;
|
return false;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
keep_alive = http::should_keep_alive(req.version, &req.headers);
|
true
|
||||||
let mut res = Response::new(&mut wrt);
|
|
||||||
res.version = req.version;
|
|
||||||
handler.handle(req, res);
|
|
||||||
debug!("keep_alive = {:?}", keep_alive);
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -270,7 +303,7 @@ mod tests {
|
|||||||
use status::StatusCode;
|
use status::StatusCode;
|
||||||
use uri::RequestUri;
|
use uri::RequestUri;
|
||||||
|
|
||||||
use super::{Request, Response, Fresh, Handler, handle_connection};
|
use super::{Request, Response, Fresh, Handler, Worker};
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn test_check_continue_default() {
|
fn test_check_continue_default() {
|
||||||
@@ -287,7 +320,7 @@ mod tests {
|
|||||||
res.start().unwrap().end().unwrap();
|
res.start().unwrap().end().unwrap();
|
||||||
}
|
}
|
||||||
|
|
||||||
handle_connection(&mut mock, &handle);
|
Worker(&handle).handle_connection(&mut mock);
|
||||||
let cont = b"HTTP/1.1 100 Continue\r\n\r\n";
|
let cont = b"HTTP/1.1 100 Continue\r\n\r\n";
|
||||||
assert_eq!(&mock.write[..cont.len()], cont);
|
assert_eq!(&mock.write[..cont.len()], cont);
|
||||||
let res = b"HTTP/1.1 200 OK\r\n";
|
let res = b"HTTP/1.1 200 OK\r\n";
|
||||||
@@ -316,7 +349,7 @@ mod tests {
|
|||||||
1234567890\
|
1234567890\
|
||||||
");
|
");
|
||||||
|
|
||||||
handle_connection(&mut mock, &Reject);
|
Worker(&Reject).handle_connection(&mut mock);
|
||||||
assert_eq!(mock.write, &b"HTTP/1.1 417 Expectation Failed\r\n\r\n"[..]);
|
assert_eq!(mock.write, &b"HTTP/1.1 417 Expectation Failed\r\n\r\n"[..]);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -28,7 +28,7 @@ pub struct Response<'a, W: Any = Fresh> {
|
|||||||
// The status code for the request.
|
// The status code for the request.
|
||||||
status: status::StatusCode,
|
status: status::StatusCode,
|
||||||
// The outgoing headers on this response.
|
// The outgoing headers on this response.
|
||||||
headers: header::Headers,
|
headers: &'a mut header::Headers,
|
||||||
|
|
||||||
_writing: PhantomData<W>
|
_writing: PhantomData<W>
|
||||||
}
|
}
|
||||||
@@ -39,13 +39,13 @@ impl<'a, W: Any> Response<'a, W> {
|
|||||||
pub fn status(&self) -> status::StatusCode { self.status }
|
pub fn status(&self) -> status::StatusCode { self.status }
|
||||||
|
|
||||||
/// The headers of this response.
|
/// The headers of this response.
|
||||||
pub fn headers(&self) -> &header::Headers { &self.headers }
|
pub fn headers(&self) -> &header::Headers { &*self.headers }
|
||||||
|
|
||||||
/// Construct a Response from its constituent parts.
|
/// Construct a Response from its constituent parts.
|
||||||
pub fn construct(version: version::HttpVersion,
|
pub fn construct(version: version::HttpVersion,
|
||||||
body: HttpWriter<&'a mut (Write + 'a)>,
|
body: HttpWriter<&'a mut (Write + 'a)>,
|
||||||
status: status::StatusCode,
|
status: status::StatusCode,
|
||||||
headers: header::Headers) -> Response<'a, Fresh> {
|
headers: &'a mut header::Headers) -> Response<'a, Fresh> {
|
||||||
Response {
|
Response {
|
||||||
status: status,
|
status: status,
|
||||||
version: version,
|
version: version,
|
||||||
@@ -57,7 +57,7 @@ impl<'a, W: Any> Response<'a, W> {
|
|||||||
|
|
||||||
/// Deconstruct this Response into its constituent parts.
|
/// Deconstruct this Response into its constituent parts.
|
||||||
pub fn deconstruct(self) -> (version::HttpVersion, HttpWriter<&'a mut (Write + 'a)>,
|
pub fn deconstruct(self) -> (version::HttpVersion, HttpWriter<&'a mut (Write + 'a)>,
|
||||||
status::StatusCode, header::Headers) {
|
status::StatusCode, &'a mut header::Headers) {
|
||||||
unsafe {
|
unsafe {
|
||||||
let parts = (
|
let parts = (
|
||||||
self.version,
|
self.version,
|
||||||
@@ -114,11 +114,11 @@ impl<'a, W: Any> Response<'a, W> {
|
|||||||
impl<'a> Response<'a, Fresh> {
|
impl<'a> Response<'a, Fresh> {
|
||||||
/// Creates a new Response that can be used to write to a network stream.
|
/// Creates a new Response that can be used to write to a network stream.
|
||||||
#[inline]
|
#[inline]
|
||||||
pub fn new(stream: &'a mut (Write + 'a)) -> Response<'a, Fresh> {
|
pub fn new(stream: &'a mut (Write + 'a), headers: &'a mut header::Headers) -> Response<'a, Fresh> {
|
||||||
Response {
|
Response {
|
||||||
status: status::StatusCode::Ok,
|
status: status::StatusCode::Ok,
|
||||||
version: version::HttpVersion::Http11,
|
version: version::HttpVersion::Http11,
|
||||||
headers: header::Headers::new(),
|
headers: headers,
|
||||||
body: ThroughWriter(stream),
|
body: ThroughWriter(stream),
|
||||||
_writing: PhantomData,
|
_writing: PhantomData,
|
||||||
}
|
}
|
||||||
@@ -165,7 +165,7 @@ impl<'a> Response<'a, Fresh> {
|
|||||||
|
|
||||||
/// Get a mutable reference to the Headers.
|
/// Get a mutable reference to the Headers.
|
||||||
#[inline]
|
#[inline]
|
||||||
pub fn headers_mut(&mut self) -> &mut header::Headers { &mut self.headers }
|
pub fn headers_mut(&mut self) -> &mut header::Headers { self.headers }
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@@ -231,6 +231,7 @@ impl<'a, T: Any> Drop for Response<'a, T> {
|
|||||||
|
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
mod tests {
|
mod tests {
|
||||||
|
use header::Headers;
|
||||||
use mock::MockStream;
|
use mock::MockStream;
|
||||||
use super::Response;
|
use super::Response;
|
||||||
|
|
||||||
@@ -252,9 +253,10 @@ mod tests {
|
|||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn test_fresh_start() {
|
fn test_fresh_start() {
|
||||||
|
let mut headers = Headers::new();
|
||||||
let mut stream = MockStream::new();
|
let mut stream = MockStream::new();
|
||||||
{
|
{
|
||||||
let res = Response::new(&mut stream);
|
let res = Response::new(&mut stream, &mut headers);
|
||||||
res.start().unwrap().deconstruct();
|
res.start().unwrap().deconstruct();
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -268,9 +270,10 @@ mod tests {
|
|||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn test_streaming_end() {
|
fn test_streaming_end() {
|
||||||
|
let mut headers = Headers::new();
|
||||||
let mut stream = MockStream::new();
|
let mut stream = MockStream::new();
|
||||||
{
|
{
|
||||||
let res = Response::new(&mut stream);
|
let res = Response::new(&mut stream, &mut headers);
|
||||||
res.start().unwrap().end().unwrap();
|
res.start().unwrap().end().unwrap();
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -287,9 +290,10 @@ mod tests {
|
|||||||
#[test]
|
#[test]
|
||||||
fn test_fresh_drop() {
|
fn test_fresh_drop() {
|
||||||
use status::StatusCode;
|
use status::StatusCode;
|
||||||
|
let mut headers = Headers::new();
|
||||||
let mut stream = MockStream::new();
|
let mut stream = MockStream::new();
|
||||||
{
|
{
|
||||||
let mut res = Response::new(&mut stream);
|
let mut res = Response::new(&mut stream, &mut headers);
|
||||||
*res.status_mut() = StatusCode::NotFound;
|
*res.status_mut() = StatusCode::NotFound;
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -307,9 +311,10 @@ mod tests {
|
|||||||
fn test_streaming_drop() {
|
fn test_streaming_drop() {
|
||||||
use std::io::Write;
|
use std::io::Write;
|
||||||
use status::StatusCode;
|
use status::StatusCode;
|
||||||
|
let mut headers = Headers::new();
|
||||||
let mut stream = MockStream::new();
|
let mut stream = MockStream::new();
|
||||||
{
|
{
|
||||||
let mut res = Response::new(&mut stream);
|
let mut res = Response::new(&mut stream, &mut headers);
|
||||||
*res.status_mut() = StatusCode::NotFound;
|
*res.status_mut() = StatusCode::NotFound;
|
||||||
let mut stream = res.start().unwrap();
|
let mut stream = res.start().unwrap();
|
||||||
stream.write_all(b"foo").unwrap();
|
stream.write_all(b"foo").unwrap();
|
||||||
|
|||||||
Reference in New Issue
Block a user