From 673e5cb1a3dadea178e51677fa660a1258610ae8 Mon Sep 17 00:00:00 2001 From: Sean McArthur Date: Thu, 13 Jul 2017 11:08:14 -0700 Subject: [PATCH] fix(server): improve detection of when a Response can have a body By knowing if the incoming Request was a HEAD, or checking for 204 or 304 status codes, the server will do a better job of either adding or removing `Content-Length` and `Transfer-Encoding` headers. Closes #1257 --- src/http/conn.rs | 31 +- src/http/h1/parse.rs | 219 +++++++----- src/http/mod.rs | 5 +- src/lib.rs | 2 +- tests/client.rs | 38 ++- tests/server.rs | 787 +++++++++++++++++++++++-------------------- 6 files changed, 609 insertions(+), 473 deletions(-) diff --git a/src/http/conn.rs b/src/http/conn.rs index 7b91efc9..ed79b09e 100644 --- a/src/http/conn.rs +++ b/src/http/conn.rs @@ -7,10 +7,10 @@ use futures::task::Task; use tokio_io::{AsyncRead, AsyncWrite}; use tokio_proto::streaming::pipeline::{Frame, Transport}; -use header::{ContentLength, TransferEncoding}; use http::{self, Http1Transaction, DebugTruncate}; use http::io::{Cursor, Buffered}; use http::h1::{Encoder, Decoder}; +use method::Method; use version::HttpVersion; @@ -37,10 +37,11 @@ where I: AsyncRead + AsyncWrite, Conn { io: Buffered::new(io), state: State { + keep_alive: keep_alive, + method: None, + read_task: None, reading: Reading::Init, writing: Writing::Init, - read_task: None, - keep_alive: keep_alive, }, _marker: PhantomData, } @@ -103,7 +104,7 @@ where I: AsyncRead + AsyncWrite, match version { HttpVersion::Http10 | HttpVersion::Http11 => { - let decoder = match T::decoder(&head) { + let decoder = match T::decoder(&head, &mut self.state.method) { Ok(d) => d, Err(e) => { debug!("decoder error = {:?}", e); @@ -234,17 +235,8 @@ where I: AsyncRead + AsyncWrite, } } - fn write_head(&mut self, mut head: http::MessageHead, body: bool) { + fn write_head(&mut self, head: http::MessageHead, body: bool) { debug_assert!(self.can_write_head()); - if !body { - head.headers.remove::(); - //TODO: check that this isn't a response to a HEAD - //request, which could include the content-length - //even if no body is to be written - if T::should_set_length(&head) { - head.headers.set(ContentLength(0)); - } - } let wants_keep_alive = head.should_keep_alive(); self.state.keep_alive &= wants_keep_alive; @@ -256,8 +248,8 @@ where I: AsyncRead + AsyncWrite, buf.extend_from_slice(pending.buf()); } } - let encoder = T::encode(head, buf); - self.state.writing = if body { + let encoder = T::encode(head, body, &mut self.state.method, buf); + self.state.writing = if !encoder.is_eof() { Writing::Body(encoder, None) } else { Writing::KeepAlive @@ -493,10 +485,11 @@ impl, T, K: fmt::Debug> fmt::Debug for Conn { } struct State { + keep_alive: K, + method: Option, + read_task: Option, reading: Reading, writing: Writing, - read_task: Option, - keep_alive: K, } #[derive(Debug)] @@ -522,6 +515,7 @@ impl, K: fmt::Debug> fmt::Debug for State { .field("reading", &self.reading) .field("writing", &self.writing) .field("keep_alive", &self.keep_alive) + .field("method", &self.method) .field("read_task", &self.read_task) .finish() } @@ -641,6 +635,7 @@ impl State { } fn idle(&mut self) { + self.method = None; self.reading = Reading::Init; self.writing = Writing::Init; self.keep_alive.idle(); diff --git a/src/http/h1/parse.rs b/src/http/h1/parse.rs index f9d964ab..71bcb77c 100644 --- a/src/http/h1/parse.rs +++ b/src/http/h1/parse.rs @@ -5,7 +5,8 @@ use httparse; use bytes::{BytesMut, Bytes}; use header::{self, Headers, ContentLength, TransferEncoding}; -use http::{MessageHead, RawStatus, Http1Transaction, ParseResult, ServerTransaction, ClientTransaction, RequestLine}; +use http::{MessageHead, RawStatus, Http1Transaction, ParseResult, + ServerTransaction, ClientTransaction, RequestLine, RequestHead}; use http::h1::{Encoder, Decoder, date}; use method::Method; use status::StatusCode; @@ -72,8 +73,11 @@ impl Http1Transaction for ServerTransaction { }, len))) } - fn decoder(head: &MessageHead) -> ::Result { + fn decoder(head: &MessageHead, method: &mut Option) -> ::Result { use ::header; + + *method = Some(head.subject.0.clone()); + // According to https://tools.ietf.org/html/rfc7230#section-3.3.3 // 1. (irrelevant to Request) // 2. (irrelevant to Request) @@ -105,30 +109,11 @@ impl Http1Transaction for ServerTransaction { } - fn encode(mut head: MessageHead, dst: &mut Vec) -> Encoder { - use ::header; - trace!("writing head: {:?}", head); + fn encode(mut head: MessageHead, has_body: bool, method: &mut Option, dst: &mut Vec) -> Encoder { + trace!("ServerTransaction::encode head={:?}, has_body={}, method={:?}", + head, has_body, method); - let len = head.headers.get::().map(|n| **n); - - let body = if let Some(len) = len { - Encoder::length(len) - } else { - let encodings = match head.headers.get_mut::() { - Some(&mut header::TransferEncoding(ref mut encodings)) => { - if encodings.last() != Some(&header::Encoding::Chunked) { - encodings.push(header::Encoding::Chunked); - } - false - }, - None => true - }; - - if encodings { - head.headers.set(header::TransferEncoding(vec![header::Encoding::Chunked])); - } - Encoder::chunked() - }; + let body = ServerTransaction::set_length(&mut head, has_body, method.as_ref()); debug!("encode headers = {:?}", head.headers); let init_cap = 30 + head.headers.len() * AVERAGE_HEADER_SIZE; @@ -150,16 +135,39 @@ impl Http1Transaction for ServerTransaction { extend(dst, b"\r\n"); body } +} - fn should_set_length(head: &MessageHead) -> bool { - //TODO: pass method, check if method == HEAD +impl ServerTransaction { + fn set_length(head: &mut MessageHead, has_body: bool, method: Option<&Method>) -> Encoder { + // these are here thanks to borrowck + // `if method == Some(&Method::Get)` says the RHS doesnt live long enough + const HEAD: Option<&'static Method> = Some(&Method::Head); + const CONNECT: Option<&'static Method> = Some(&Method::Connect); - match head.subject { - // TODO: support for 1xx codes needs improvement everywhere - // would be 100...199 => false - StatusCode::NoContent | - StatusCode::NotModified => false, - _ => true, + let can_have_body = { + if method == HEAD { + false + } else if method == CONNECT && head.subject.is_success() { + false + } else { + match head.subject { + // TODO: support for 1xx codes needs improvement everywhere + // would be 100...199 => false + StatusCode::NoContent | + StatusCode::NotModified => false, + _ => true, + } + } + }; + + if has_body && can_have_body { + set_length(&mut head.headers) + } else { + head.headers.remove::(); + if can_have_body { + head.headers.set(ContentLength(0)); + } + Encoder::length(0) } } } @@ -213,8 +221,7 @@ impl Http1Transaction for ClientTransaction { }, len))) } - fn decoder(inc: &MessageHead) -> ::Result { - use ::header; + fn decoder(inc: &MessageHead, method: &mut Option) -> ::Result { // According to https://tools.ietf.org/html/rfc7230#section-3.3.3 // 1. HEAD responses, and Status 1xx, 204, and 304 cannot have a body. // 2. Status 2xx to a CONNECT cannot have a body. @@ -224,7 +231,21 @@ impl Http1Transaction for ClientTransaction { // 6. (irrelevant to Response) // 7. Read till EOF. - //TODO: need a way to pass the Method that caused this Response + match *method { + Some(Method::Head) => { + return Ok(Decoder::length(0)); + } + Some(Method::Connect) => match inc.subject.0 { + 200...299 => { + return Ok(Decoder::length(0)); + }, + _ => {}, + }, + Some(_) => {}, + None => { + trace!("ClientTransaction::decoder is missing the Method"); + } + } match inc.subject.0 { 100...199 | @@ -251,39 +272,14 @@ impl Http1Transaction for ClientTransaction { } } - fn encode(mut head: MessageHead, dst: &mut Vec) -> Encoder { - trace!("writing head: {:?}", head); + fn encode(mut head: MessageHead, has_body: bool, method: &mut Option, dst: &mut Vec) -> Encoder { + trace!("ClientTransaction::encode head={:?}, has_body={}, method={:?}", + head, has_body, method); - let mut body = Encoder::length(0); - let expects_no_body = match head.subject.0 { - Method::Head | Method::Get | Method::Connect => true, - _ => false - }; - let mut chunked = false; + *method = Some(head.subject.0.clone()); - if let Some(con_len) = head.headers.get::() { - body = Encoder::length(**con_len); - } else { - chunked = !expects_no_body; - } - - if chunked { - body = Encoder::chunked(); - let encodings = match head.headers.get_mut::() { - Some(encodings) => { - if !encodings.contains(&header::Encoding::Chunked) { - encodings.push(header::Encoding::Chunked); - } - true - }, - None => false - }; - - if !encodings { - head.headers.set(TransferEncoding(vec![header::Encoding::Chunked])); - } - } + let body = ClientTransaction::set_length(&mut head, has_body); debug!("encode headers = {:?}", head.headers); let init_cap = 30 + head.headers.len() * AVERAGE_HEADER_SIZE; @@ -292,16 +288,43 @@ impl Http1Transaction for ClientTransaction { body } +} - - fn should_set_length(head: &MessageHead) -> bool { - match &head.subject.0 { - &Method::Get | &Method::Head => false, - _ => true +impl ClientTransaction { + fn set_length(head: &mut RequestHead, has_body: bool) -> Encoder { + if has_body { + set_length(&mut head.headers) + } else { + head.headers.remove::(); + head.headers.remove::(); + Encoder::length(0) } } } +fn set_length(headers: &mut Headers) -> Encoder { + let len = headers.get::().map(|n| **n); + + if let Some(len) = len { + Encoder::length(len) + } else { + let encodings = match headers.get_mut::() { + Some(&mut header::TransferEncoding(ref mut encodings)) => { + if encodings.last() != Some(&header::Encoding::Chunked) { + encodings.push(header::Encoding::Chunked); + } + false + }, + None => true + }; + + if encodings { + headers.set(header::TransferEncoding(vec![header::Encoding::Chunked])); + } + Encoder::chunked() + } +} + #[derive(Clone, Copy)] struct HeaderIndices { name: (usize, usize), @@ -421,63 +444,83 @@ mod tests { fn test_decoder_request() { use super::Decoder; + let method = &mut None; let mut head = MessageHead::<::http::RequestLine>::default(); head.subject.0 = ::Method::Get; - assert_eq!(Decoder::length(0), ServerTransaction::decoder(&head).unwrap()); + assert_eq!(Decoder::length(0), ServerTransaction::decoder(&head, method).unwrap()); + assert_eq!(*method, Some(::Method::Get)); + head.subject.0 = ::Method::Post; - assert_eq!(Decoder::length(0), ServerTransaction::decoder(&head).unwrap()); + assert_eq!(Decoder::length(0), ServerTransaction::decoder(&head, method).unwrap()); + assert_eq!(*method, Some(::Method::Post)); head.headers.set(TransferEncoding::chunked()); - assert_eq!(Decoder::chunked(), ServerTransaction::decoder(&head).unwrap()); + assert_eq!(Decoder::chunked(), ServerTransaction::decoder(&head, method).unwrap()); // transfer-encoding and content-length = chunked head.headers.set(ContentLength(10)); - assert_eq!(Decoder::chunked(), ServerTransaction::decoder(&head).unwrap()); + assert_eq!(Decoder::chunked(), ServerTransaction::decoder(&head, method).unwrap()); head.headers.remove::(); - assert_eq!(Decoder::length(10), ServerTransaction::decoder(&head).unwrap()); + assert_eq!(Decoder::length(10), ServerTransaction::decoder(&head, method).unwrap()); head.headers.set_raw("Content-Length", vec![b"5".to_vec(), b"5".to_vec()]); - assert_eq!(Decoder::length(5), ServerTransaction::decoder(&head).unwrap()); + assert_eq!(Decoder::length(5), ServerTransaction::decoder(&head, method).unwrap()); head.headers.set_raw("Content-Length", vec![b"10".to_vec(), b"11".to_vec()]); - ServerTransaction::decoder(&head).unwrap_err(); + ServerTransaction::decoder(&head, method).unwrap_err(); head.headers.remove::(); head.headers.set_raw("Transfer-Encoding", "gzip"); - ServerTransaction::decoder(&head).unwrap_err(); + ServerTransaction::decoder(&head, method).unwrap_err(); } #[test] fn test_decoder_response() { use super::Decoder; + let method = &mut Some(::Method::Get); let mut head = MessageHead::<::http::RawStatus>::default(); head.subject.0 = 204; - assert_eq!(Decoder::length(0), ClientTransaction::decoder(&head).unwrap()); + assert_eq!(Decoder::length(0), ClientTransaction::decoder(&head, method).unwrap()); head.subject.0 = 304; - assert_eq!(Decoder::length(0), ClientTransaction::decoder(&head).unwrap()); + assert_eq!(Decoder::length(0), ClientTransaction::decoder(&head, method).unwrap()); head.subject.0 = 200; - assert_eq!(Decoder::eof(), ClientTransaction::decoder(&head).unwrap()); + assert_eq!(Decoder::eof(), ClientTransaction::decoder(&head, method).unwrap()); + *method = Some(::Method::Head); + assert_eq!(Decoder::length(0), ClientTransaction::decoder(&head, method).unwrap()); + + *method = Some(::Method::Connect); + assert_eq!(Decoder::length(0), ClientTransaction::decoder(&head, method).unwrap()); + + + // CONNECT receiving non 200 can have a body + head.subject.0 = 404; + head.headers.set(ContentLength(10)); + assert_eq!(Decoder::length(10), ClientTransaction::decoder(&head, method).unwrap()); + head.headers.remove::(); + + + *method = Some(::Method::Get); head.headers.set(TransferEncoding::chunked()); - assert_eq!(Decoder::chunked(), ClientTransaction::decoder(&head).unwrap()); + assert_eq!(Decoder::chunked(), ClientTransaction::decoder(&head, method).unwrap()); // transfer-encoding and content-length = chunked head.headers.set(ContentLength(10)); - assert_eq!(Decoder::chunked(), ClientTransaction::decoder(&head).unwrap()); + assert_eq!(Decoder::chunked(), ClientTransaction::decoder(&head, method).unwrap()); head.headers.remove::(); - assert_eq!(Decoder::length(10), ClientTransaction::decoder(&head).unwrap()); + assert_eq!(Decoder::length(10), ClientTransaction::decoder(&head, method).unwrap()); head.headers.set_raw("Content-Length", vec![b"5".to_vec(), b"5".to_vec()]); - assert_eq!(Decoder::length(5), ClientTransaction::decoder(&head).unwrap()); + assert_eq!(Decoder::length(5), ClientTransaction::decoder(&head, method).unwrap()); head.headers.set_raw("Content-Length", vec![b"10".to_vec(), b"11".to_vec()]); - ClientTransaction::decoder(&head).unwrap_err(); + ClientTransaction::decoder(&head, method).unwrap_err(); } #[cfg(feature = "nightly")] @@ -541,7 +584,7 @@ mod tests { b.iter(|| { let mut vec = Vec::new(); - ServerTransaction::encode(head.clone(), &mut vec); + ServerTransaction::encode(head.clone(), true, &mut None, &mut vec); assert_eq!(vec.len(), len); ::test::black_box(vec); }) diff --git a/src/http/mod.rs b/src/http/mod.rs index b678e97d..20213b6e 100644 --- a/src/http/mod.rs +++ b/src/http/mod.rs @@ -144,9 +144,8 @@ pub trait Http1Transaction { type Incoming; type Outgoing: Default; fn parse(bytes: &mut BytesMut) -> ParseResult; - fn decoder(head: &MessageHead) -> ::Result; - fn encode(head: MessageHead, dst: &mut Vec) -> h1::Encoder; - fn should_set_length(head: &MessageHead) -> bool; + fn decoder(head: &MessageHead, method: &mut Option<::Method>) -> ::Result; + fn encode(head: MessageHead, has_body: bool, method: &mut Option, dst: &mut Vec) -> h1::Encoder; } pub type ParseResult = ::Result, usize)>>; diff --git a/src/lib.rs b/src/lib.rs index 0955a475..94dcac52 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,6 +1,6 @@ #![doc(html_root_url = "https://docs.rs/hyper/0.11.1")] #![deny(missing_docs)] -#![deny(warnings)] +//#![deny(warnings)] #![deny(missing_debug_implementations)] #![cfg_attr(all(test, feature = "nightly"), feature(test))] diff --git a/tests/client.rs b/tests/client.rs index e93426c8..f88a85d3 100644 --- a/tests/client.rs +++ b/tests/client.rs @@ -12,7 +12,7 @@ use std::time::Duration; use hyper::client::{Client, Request, HttpConnector}; use hyper::{Method, StatusCode}; -use futures::Future; +use futures::{Future, Stream}; use futures::sync::oneshot; use tokio_core::reactor::{Core, Handle}; @@ -93,6 +93,12 @@ macro_rules! test { $( assert_eq!(res.headers().get(), Some(&$response_headers)); )* + + let body = core.run(res.body().concat2()).unwrap(); + + let expected_res_body = Option::<&[u8]>::from($response_body) + .unwrap_or_default(); + assert_eq!(body.as_ref(), expected_res_body); } ); } @@ -225,6 +231,36 @@ test! { body: None, } + +test! { + name: client_head_ignores_body, + + server: + expected: "\ + HEAD /head HTTP/1.1\r\n\ + Host: {addr}\r\n\ + \r\n\ + ", + reply: "\ + HTTP/1.1 200 OK\r\n\ + Content-Length: 11\r\n\ + \r\n\ + Hello World\ + ", + + client: + request: + method: Head, + url: "http://{addr}/head", + headers: [], + body: None, + proxy: false, + response: + status: Ok, + headers: [], + body: None, +} + #[test] fn client_keep_alive() { let server = TcpListener::bind("127.0.0.1:0").unwrap(); diff --git a/tests/server.rs b/tests/server.rs index e5c38289..ca34c3db 100644 --- a/tests/server.rs +++ b/tests/server.rs @@ -16,6 +16,431 @@ use std::time::Duration; use hyper::server::{Http, Request, Response, Service, NewService}; +#[test] +fn get_should_ignore_body() { + let server = serve(); + + let mut req = connect(server.addr()); + // Connection: close = don't try to parse the body as a new request + req.write_all(b"\ + GET / HTTP/1.1\r\n\ + Host: example.domain\r\n\ + Connection: close\r\n\ + \r\n\ + I shouldn't be read.\r\n\ + ").unwrap(); + req.read(&mut [0; 256]).unwrap(); + + assert_eq!(server.body(), b""); +} + +#[test] +fn get_with_body() { + let server = serve(); + let mut req = connect(server.addr()); + req.write_all(b"\ + GET / HTTP/1.1\r\n\ + Host: example.domain\r\n\ + Content-Length: 19\r\n\ + \r\n\ + I'm a good request.\r\n\ + ").unwrap(); + req.read(&mut [0; 256]).unwrap(); + + // note: doesn't include trailing \r\n, cause Content-Length wasn't 21 + assert_eq!(server.body(), b"I'm a good request."); +} + +#[test] +fn get_fixed_response() { + let foo_bar = b"foo bar baz"; + let server = serve(); + server.reply() + .status(hyper::Ok) + .header(hyper::header::ContentLength(foo_bar.len() as u64)) + .body(foo_bar); + let mut req = connect(server.addr()); + req.write_all(b"\ + GET / HTTP/1.1\r\n\ + Host: example.domain\r\n\ + Connection: close\r\n\ + \r\n\ + ").unwrap(); + let mut body = String::new(); + req.read_to_string(&mut body).unwrap(); + let n = body.find("\r\n\r\n").unwrap() + 4; + + assert_eq!(&body[n..], "foo bar baz"); +} + +#[test] +fn get_chunked_response() { + let foo_bar = b"foo bar baz"; + let server = serve(); + server.reply() + .status(hyper::Ok) + .header(hyper::header::TransferEncoding::chunked()) + .body(foo_bar); + let mut req = connect(server.addr()); + req.write_all(b"\ + GET / HTTP/1.1\r\n\ + Host: example.domain\r\n\ + Connection: close\r\n\ + \r\n\ + ").unwrap(); + let mut body = String::new(); + req.read_to_string(&mut body).unwrap(); + let n = body.find("\r\n\r\n").unwrap() + 4; + + assert_eq!(&body[n..], "B\r\nfoo bar baz\r\n0\r\n\r\n"); +} + +#[test] +fn get_chunked_response_with_ka() { + let foo_bar = b"foo bar baz"; + let foo_bar_chunk = b"\r\nfoo bar baz\r\n0\r\n\r\n"; + let server = serve(); + server.reply() + .status(hyper::Ok) + .header(hyper::header::TransferEncoding::chunked()) + .body(foo_bar); + let mut req = connect(server.addr()); + req.write_all(b"\ + GET / HTTP/1.1\r\n\ + Host: example.domain\r\n\ + Connection: keep-alive\r\n\ + \r\n\ + ").expect("writing 1"); + + let mut buf = [0; 1024 * 4]; + let mut ntotal = 0; + loop { + let n = req.read(&mut buf[ntotal..]).expect("reading 1"); + ntotal = ntotal + n; + assert!(ntotal < buf.len()); + if &buf[ntotal - foo_bar_chunk.len()..ntotal] == foo_bar_chunk { + break; + } + } + + + // try again! + + let quux = b"zar quux"; + server.reply() + .status(hyper::Ok) + .header(hyper::header::ContentLength(quux.len() as u64)) + .body(quux); + req.write_all(b"\ + GET /quux HTTP/1.1\r\n\ + Host: example.domain\r\n\ + Connection: close\r\n\ + \r\n\ + ").expect("writing 2"); + + let mut buf = [0; 1024 * 8]; + loop { + let n = req.read(&mut buf[..]).expect("reading 2"); + assert!(n > 0, "n = {}", n); + if n < buf.len() && n > 0 { + if &buf[n - quux.len()..n] == quux { + break; + } + } + } +} + +#[test] +fn post_with_chunked_body() { + let server = serve(); + let mut req = connect(server.addr()); + req.write_all(b"\ + POST / HTTP/1.1\r\n\ + Host: example.domain\r\n\ + Transfer-Encoding: chunked\r\n\ + \r\n\ + 1\r\n\ + q\r\n\ + 2\r\n\ + we\r\n\ + 2\r\n\ + rt\r\n\ + 0\r\n\ + \r\n\ + ").unwrap(); + req.read(&mut [0; 256]).unwrap(); + + assert_eq!(server.body(), b"qwert"); +} + +#[test] +fn empty_response_chunked() { + let server = serve(); + + server.reply() + .status(hyper::Ok) + .body(""); + + let mut req = connect(server.addr()); + req.write_all(b"\ + GET / HTTP/1.1\r\n\ + Host: example.domain\r\n\ + Content-Length: 0\r\n\ + Connection: close\r\n\ + \r\n\ + ").unwrap(); + + + let mut response = String::new(); + req.read_to_string(&mut response).unwrap(); + + assert!(response.contains("Transfer-Encoding: chunked\r\n")); + + let mut lines = response.lines(); + assert_eq!(lines.next(), Some("HTTP/1.1 200 OK")); + + let mut lines = lines.skip_while(|line| !line.is_empty()); + assert_eq!(lines.next(), Some("")); + // 0\r\n\r\n + assert_eq!(lines.next(), Some("0")); + assert_eq!(lines.next(), Some("")); + assert_eq!(lines.next(), None); +} + +#[test] +fn empty_response_chunked_without_body_should_set_content_length() { + extern crate pretty_env_logger; + let _ = pretty_env_logger::init(); + let server = serve(); + server.reply() + .status(hyper::Ok) + .header(hyper::header::TransferEncoding::chunked()); + let mut req = connect(server.addr()); + req.write_all(b"\ + GET / HTTP/1.1\r\n\ + Host: example.domain\r\n\ + Connection: close\r\n\ + \r\n\ + ").unwrap(); + + let mut response = String::new(); + req.read_to_string(&mut response).unwrap(); + + assert!(!response.contains("Transfer-Encoding: chunked\r\n")); + assert!(response.contains("Content-Length: 0\r\n")); + + let mut lines = response.lines(); + assert_eq!(lines.next(), Some("HTTP/1.1 200 OK")); + + let mut lines = lines.skip_while(|line| !line.is_empty()); + assert_eq!(lines.next(), Some("")); + assert_eq!(lines.next(), None); +} + +#[test] +fn head_response_can_send_content_length() { + extern crate pretty_env_logger; + let _ = pretty_env_logger::init(); + let server = serve(); + server.reply() + .status(hyper::Ok) + .header(hyper::header::ContentLength(1024)); + let mut req = connect(server.addr()); + req.write_all(b"\ + HEAD / HTTP/1.1\r\n\ + Host: example.domain\r\n\ + Connection: close\r\n\ + \r\n\ + ").unwrap(); + + let mut response = String::new(); + req.read_to_string(&mut response).unwrap(); + + assert!(response.contains("Content-Length: 1024\r\n")); + + let mut lines = response.lines(); + assert_eq!(lines.next(), Some("HTTP/1.1 200 OK")); + + let mut lines = lines.skip_while(|line| !line.is_empty()); + assert_eq!(lines.next(), Some("")); + assert_eq!(lines.next(), None); +} + +#[test] +fn response_does_not_set_chunked_if_body_not_allowed() { + extern crate pretty_env_logger; + let _ = pretty_env_logger::init(); + let server = serve(); + server.reply() + .status(hyper::StatusCode::NotModified) + .header(hyper::header::TransferEncoding::chunked()); + let mut req = connect(server.addr()); + req.write_all(b"\ + GET / HTTP/1.1\r\n\ + Host: example.domain\r\n\ + Connection: close\r\n\ + \r\n\ + ").unwrap(); + + let mut response = String::new(); + req.read_to_string(&mut response).unwrap(); + + assert!(!response.contains("Transfer-Encoding")); + + let mut lines = response.lines(); + assert_eq!(lines.next(), Some("HTTP/1.1 304 Not Modified")); + + // no body or 0\r\n\r\n + let mut lines = lines.skip_while(|line| !line.is_empty()); + assert_eq!(lines.next(), Some("")); + assert_eq!(lines.next(), None); +} + +#[test] +fn keep_alive() { + let foo_bar = b"foo bar baz"; + let server = serve(); + server.reply() + .status(hyper::Ok) + .header(hyper::header::ContentLength(foo_bar.len() as u64)) + .body(foo_bar); + let mut req = connect(server.addr()); + req.write_all(b"\ + GET / HTTP/1.1\r\n\ + Host: example.domain\r\n\ + Connection: keep-alive\r\n\ + \r\n\ + ").expect("writing 1"); + + let mut buf = [0; 1024 * 8]; + loop { + let n = req.read(&mut buf[..]).expect("reading 1"); + if n < buf.len() { + if &buf[n - foo_bar.len()..n] == foo_bar { + break; + } else { + } + } + } + + // try again! + + let quux = b"zar quux"; + server.reply() + .status(hyper::Ok) + .header(hyper::header::ContentLength(quux.len() as u64)) + .body(quux); + req.write_all(b"\ + GET /quux HTTP/1.1\r\n\ + Host: example.domain\r\n\ + Connection: close\r\n\ + \r\n\ + ").expect("writing 2"); + + let mut buf = [0; 1024 * 8]; + loop { + let n = req.read(&mut buf[..]).expect("reading 2"); + assert!(n > 0, "n = {}", n); + if n < buf.len() { + if &buf[n - quux.len()..n] == quux { + break; + } + } + } +} + +#[test] +fn disable_keep_alive() { + let foo_bar = b"foo bar baz"; + let server = serve_with_options(ServeOptions { + keep_alive_disabled: true, + .. Default::default() + }); + server.reply() + .status(hyper::Ok) + .header(hyper::header::ContentLength(foo_bar.len() as u64)) + .body(foo_bar); + let mut req = connect(server.addr()); + req.write_all(b"\ + GET / HTTP/1.1\r\n\ + Host: example.domain\r\n\ + Connection: keep-alive\r\n\ + \r\n\ + ").expect("writing 1"); + + let mut buf = [0; 1024 * 8]; + loop { + let n = req.read(&mut buf[..]).expect("reading 1"); + if n < buf.len() { + if &buf[n - foo_bar.len()..n] == foo_bar { + break; + } else { + } + } + } + + // try again! + + let quux = b"zar quux"; + server.reply() + .status(hyper::Ok) + .header(hyper::header::ContentLength(quux.len() as u64)) + .body(quux); + + let _ = req.write_all(b"\ + GET /quux HTTP/1.1\r\n\ + Host: example.domain\r\n\ + Connection: close\r\n\ + \r\n\ + "); + + // the write can possibly succeed, since it fills the kernel buffer on the first write + let mut buf = [0; 1024 * 8]; + match req.read(&mut buf[..]) { + // Ok(0) means EOF, so a proper shutdown + // Err(_) could mean ConnReset or something, also fine + Ok(0) | + Err(_) => {} + Ok(n) => { + panic!("read {} bytes on a disabled keep-alive socket", n); + } + } +} + +#[test] +fn expect_continue() { + let server = serve(); + let mut req = connect(server.addr()); + server.reply().status(hyper::Ok); + + req.write_all(b"\ + POST /foo HTTP/1.1\r\n\ + Host: example.domain\r\n\ + Expect: 100-continue\r\n\ + Content-Length: 5\r\n\ + Connection: Close\r\n\ + \r\n\ + ").expect("write 1"); + + let msg = b"HTTP/1.1 100 Continue\r\n\r\n"; + let mut buf = vec![0; msg.len()]; + req.read_exact(&mut buf).expect("read 1"); + assert_eq!(buf, msg); + + let msg = b"hello"; + req.write_all(msg).expect("write 2"); + + let mut body = String::new(); + req.read_to_string(&mut body).expect("read 2"); + + let body = server.body(); + assert_eq!(body, msg); +} + +// ------------------------------------------------- +// the Server that is used to run all the tests with +// ------------------------------------------------- + struct Serve { addr: SocketAddr, msg_rx: mpsc::Receiver, @@ -190,366 +615,4 @@ fn serve_with_options(options: ServeOptions) -> Serve { } } -#[test] -fn server_get_should_ignore_body() { - let server = serve(); - let mut req = connect(server.addr()); - // Connection: close = don't try to parse the body as a new request - req.write_all(b"\ - GET / HTTP/1.1\r\n\ - Host: example.domain\r\n\ - Connection: close\r\n\ - \r\n\ - I shouldn't be read.\r\n\ - ").unwrap(); - req.read(&mut [0; 256]).unwrap(); - - assert_eq!(server.body(), b""); -} - -#[test] -fn server_get_with_body() { - let server = serve(); - let mut req = connect(server.addr()); - req.write_all(b"\ - GET / HTTP/1.1\r\n\ - Host: example.domain\r\n\ - Content-Length: 19\r\n\ - \r\n\ - I'm a good request.\r\n\ - ").unwrap(); - req.read(&mut [0; 256]).unwrap(); - - // note: doesn't include trailing \r\n, cause Content-Length wasn't 21 - assert_eq!(server.body(), b"I'm a good request."); -} - -#[test] -fn server_get_fixed_response() { - let foo_bar = b"foo bar baz"; - let server = serve(); - server.reply() - .status(hyper::Ok) - .header(hyper::header::ContentLength(foo_bar.len() as u64)) - .body(foo_bar); - let mut req = connect(server.addr()); - req.write_all(b"\ - GET / HTTP/1.1\r\n\ - Host: example.domain\r\n\ - Connection: close\r\n\ - \r\n\ - ").unwrap(); - let mut body = String::new(); - req.read_to_string(&mut body).unwrap(); - let n = body.find("\r\n\r\n").unwrap() + 4; - - assert_eq!(&body[n..], "foo bar baz"); -} - -#[test] -fn server_get_chunked_response() { - let foo_bar = b"foo bar baz"; - let server = serve(); - server.reply() - .status(hyper::Ok) - .header(hyper::header::TransferEncoding::chunked()) - .body(foo_bar); - let mut req = connect(server.addr()); - req.write_all(b"\ - GET / HTTP/1.1\r\n\ - Host: example.domain\r\n\ - Connection: close\r\n\ - \r\n\ - ").unwrap(); - let mut body = String::new(); - req.read_to_string(&mut body).unwrap(); - let n = body.find("\r\n\r\n").unwrap() + 4; - - assert_eq!(&body[n..], "B\r\nfoo bar baz\r\n0\r\n\r\n"); -} - -#[test] -fn server_get_chunked_response_with_ka() { - let foo_bar = b"foo bar baz"; - let foo_bar_chunk = b"\r\nfoo bar baz\r\n0\r\n\r\n"; - let server = serve(); - server.reply() - .status(hyper::Ok) - .header(hyper::header::TransferEncoding::chunked()) - .body(foo_bar); - let mut req = connect(server.addr()); - req.write_all(b"\ - GET / HTTP/1.1\r\n\ - Host: example.domain\r\n\ - Connection: keep-alive\r\n\ - \r\n\ - ").expect("writing 1"); - - let mut buf = [0; 1024 * 4]; - let mut ntotal = 0; - loop { - let n = req.read(&mut buf[ntotal..]).expect("reading 1"); - ntotal = ntotal + n; - assert!(ntotal < buf.len()); - if &buf[ntotal - foo_bar_chunk.len()..ntotal] == foo_bar_chunk { - break; - } - } - - - // try again! - - let quux = b"zar quux"; - server.reply() - .status(hyper::Ok) - .header(hyper::header::ContentLength(quux.len() as u64)) - .body(quux); - req.write_all(b"\ - GET /quux HTTP/1.1\r\n\ - Host: example.domain\r\n\ - Connection: close\r\n\ - \r\n\ - ").expect("writing 2"); - - let mut buf = [0; 1024 * 8]; - loop { - let n = req.read(&mut buf[..]).expect("reading 2"); - assert!(n > 0, "n = {}", n); - if n < buf.len() && n > 0 { - if &buf[n - quux.len()..n] == quux { - break; - } - } - } -} - - - -#[test] -fn server_post_with_chunked_body() { - let server = serve(); - let mut req = connect(server.addr()); - req.write_all(b"\ - POST / HTTP/1.1\r\n\ - Host: example.domain\r\n\ - Transfer-Encoding: chunked\r\n\ - \r\n\ - 1\r\n\ - q\r\n\ - 2\r\n\ - we\r\n\ - 2\r\n\ - rt\r\n\ - 0\r\n\ - \r\n\ - ").unwrap(); - req.read(&mut [0; 256]).unwrap(); - - assert_eq!(server.body(), b"qwert"); -} - -#[test] -fn server_empty_response_chunked() { - let server = serve(); - - server.reply() - .status(hyper::Ok) - .body(""); - - let mut req = connect(server.addr()); - req.write_all(b"\ - GET / HTTP/1.1\r\n\ - Host: example.domain\r\n\ - Content-Length: 0\r\n\ - Connection: close\r\n\ - \r\n\ - ").unwrap(); - - - let mut response = String::new(); - req.read_to_string(&mut response).unwrap(); - - assert!(response.contains("Transfer-Encoding: chunked\r\n")); - - let mut lines = response.lines(); - assert_eq!(lines.next(), Some("HTTP/1.1 200 OK")); - - let mut lines = lines.skip_while(|line| !line.is_empty()); - assert_eq!(lines.next(), Some("")); - // 0\r\n\r\n - assert_eq!(lines.next(), Some("0")); - assert_eq!(lines.next(), Some("")); - assert_eq!(lines.next(), None); -} - -#[test] -fn server_empty_response_chunked_without_body_should_set_content_length() { - extern crate pretty_env_logger; - let _ = pretty_env_logger::init(); - let server = serve(); - server.reply() - .status(hyper::Ok) - .header(hyper::header::TransferEncoding::chunked()); - let mut req = connect(server.addr()); - req.write_all(b"\ - GET / HTTP/1.1\r\n\ - Host: example.domain\r\n\ - Connection: close\r\n\ - \r\n\ - ").unwrap(); - - let mut response = String::new(); - req.read_to_string(&mut response).unwrap(); - - assert!(!response.contains("Transfer-Encoding: chunked\r\n")); - assert!(response.contains("Content-Length: 0\r\n")); - - let mut lines = response.lines(); - assert_eq!(lines.next(), Some("HTTP/1.1 200 OK")); - - let mut lines = lines.skip_while(|line| !line.is_empty()); - assert_eq!(lines.next(), Some("")); - assert_eq!(lines.next(), None); -} - -#[test] -fn server_keep_alive() { - let foo_bar = b"foo bar baz"; - let server = serve(); - server.reply() - .status(hyper::Ok) - .header(hyper::header::ContentLength(foo_bar.len() as u64)) - .body(foo_bar); - let mut req = connect(server.addr()); - req.write_all(b"\ - GET / HTTP/1.1\r\n\ - Host: example.domain\r\n\ - Connection: keep-alive\r\n\ - \r\n\ - ").expect("writing 1"); - - let mut buf = [0; 1024 * 8]; - loop { - let n = req.read(&mut buf[..]).expect("reading 1"); - if n < buf.len() { - if &buf[n - foo_bar.len()..n] == foo_bar { - break; - } else { - } - } - } - - // try again! - - let quux = b"zar quux"; - server.reply() - .status(hyper::Ok) - .header(hyper::header::ContentLength(quux.len() as u64)) - .body(quux); - req.write_all(b"\ - GET /quux HTTP/1.1\r\n\ - Host: example.domain\r\n\ - Connection: close\r\n\ - \r\n\ - ").expect("writing 2"); - - let mut buf = [0; 1024 * 8]; - loop { - let n = req.read(&mut buf[..]).expect("reading 2"); - assert!(n > 0, "n = {}", n); - if n < buf.len() { - if &buf[n - quux.len()..n] == quux { - break; - } - } - } -} - -#[test] -fn test_server_disable_keep_alive() { - let foo_bar = b"foo bar baz"; - let server = serve_with_options(ServeOptions { - keep_alive_disabled: true, - .. Default::default() - }); - server.reply() - .status(hyper::Ok) - .header(hyper::header::ContentLength(foo_bar.len() as u64)) - .body(foo_bar); - let mut req = connect(server.addr()); - req.write_all(b"\ - GET / HTTP/1.1\r\n\ - Host: example.domain\r\n\ - Connection: keep-alive\r\n\ - \r\n\ - ").expect("writing 1"); - - let mut buf = [0; 1024 * 8]; - loop { - let n = req.read(&mut buf[..]).expect("reading 1"); - if n < buf.len() { - if &buf[n - foo_bar.len()..n] == foo_bar { - break; - } else { - } - } - } - - // try again! - - let quux = b"zar quux"; - server.reply() - .status(hyper::Ok) - .header(hyper::header::ContentLength(quux.len() as u64)) - .body(quux); - - let _ = req.write_all(b"\ - GET /quux HTTP/1.1\r\n\ - Host: example.domain\r\n\ - Connection: close\r\n\ - \r\n\ - "); - - // the write can possibly succeed, since it fills the kernel buffer on the first write - let mut buf = [0; 1024 * 8]; - match req.read(&mut buf[..]) { - // Ok(0) means EOF, so a proper shutdown - // Err(_) could mean ConnReset or something, also fine - Ok(0) | - Err(_) => {} - Ok(n) => { - panic!("read {} bytes on a disabled keep-alive socket", n); - } - } -} - -#[test] -fn expect_continue() { - let server = serve(); - let mut req = connect(server.addr()); - server.reply().status(hyper::Ok); - - req.write_all(b"\ - POST /foo HTTP/1.1\r\n\ - Host: example.domain\r\n\ - Expect: 100-continue\r\n\ - Content-Length: 5\r\n\ - Connection: Close\r\n\ - \r\n\ - ").expect("write 1"); - - let msg = b"HTTP/1.1 100 Continue\r\n\r\n"; - let mut buf = vec![0; msg.len()]; - req.read_exact(&mut buf).expect("read 1"); - assert_eq!(buf, msg); - - let msg = b"hello"; - req.write_all(msg).expect("write 2"); - - let mut body = String::new(); - req.read_to_string(&mut body).expect("read 2"); - - let body = server.body(); - assert_eq!(body, msg); -}