fix(lib): properly handle HTTP/1.0 remotes

- Downgrades internal semantics to HTTP/1.0 if peer sends a message with
  1.0 version.
- If downgraded, chunked writers become EOF writers, with the connection
  closing once the writing is complete.
- When downgraded, if keep-alive was wanted, the `Connection: keep-alive`
  header is added.

Closes #1304
This commit is contained in:
Sean McArthur
2018-01-22 10:08:27 -08:00
parent 7d493aafce
commit 36e66a5054
4 changed files with 258 additions and 44 deletions

View File

@@ -45,6 +45,9 @@ where I: AsyncRead + AsyncWrite,
read_task: None,
reading: Reading::Init,
writing: Writing::Init,
// We assume a modern world where the remote speaks HTTP/1.1.
// If they tell us otherwise, we'll downgrade in `read_head`.
version: Version::Http11,
},
_marker: PhantomData,
}
@@ -189,43 +192,44 @@ where I: AsyncRead + AsyncWrite,
}
};
match version {
HttpVersion::Http10 | HttpVersion::Http11 => {
let decoder = match T::decoder(&head, &mut self.state.method) {
Ok(d) => d,
Err(e) => {
debug!("decoder error = {:?}", e);
self.state.close_read();
return Err(e);
}
};
debug!("incoming body is {}", decoder);
self.state.busy();
if head.expecting_continue() {
let msg = b"HTTP/1.1 100 Continue\r\n\r\n";
self.state.writing = Writing::Continue(Cursor::new(msg));
}
let wants_keep_alive = head.should_keep_alive();
self.state.keep_alive &= wants_keep_alive;
let (body, reading) = if decoder.is_eof() {
(false, Reading::KeepAlive)
} else {
(true, Reading::Body(decoder))
};
self.state.reading = reading;
if !body {
self.try_keep_alive();
}
Ok(Async::Ready(Some((head, body))))
},
self.state.version = match version {
HttpVersion::Http10 => Version::Http10,
HttpVersion::Http11 => Version::Http11,
_ => {
error!("unimplemented HTTP Version = {:?}", version);
self.state.close_read();
Err(::Error::Version)
return Err(::Error::Version);
}
};
let decoder = match T::decoder(&head, &mut self.state.method) {
Ok(d) => d,
Err(e) => {
debug!("decoder error = {:?}", e);
self.state.close_read();
return Err(e);
}
};
debug!("incoming body is {}", decoder);
self.state.busy();
if head.expecting_continue() {
let msg = b"HTTP/1.1 100 Continue\r\n\r\n";
self.state.writing = Writing::Continue(Cursor::new(msg));
}
let wants_keep_alive = head.should_keep_alive();
self.state.keep_alive &= wants_keep_alive;
let (body, reading) = if decoder.is_eof() {
(false, Reading::KeepAlive)
} else {
(true, Reading::Body(decoder))
};
self.state.reading = reading;
if !body {
self.try_keep_alive();
}
Ok(Async::Ready(Some((head, body))))
}
pub fn read_body(&mut self) -> Poll<Option<super::Chunk>, io::Error> {
@@ -414,11 +418,11 @@ where I: AsyncRead + AsyncWrite,
}
}
pub fn write_head(&mut self, head: super::MessageHead<T::Outgoing>, body: bool) {
pub fn write_head(&mut self, mut head: super::MessageHead<T::Outgoing>, body: bool) {
debug_assert!(self.can_write_head());
let wants_keep_alive = head.should_keep_alive();
self.state.keep_alive &= wants_keep_alive;
self.enforce_version(&mut head);
let buf = self.io.write_buf_mut();
// if a 100-continue has started but not finished sending, tack the
// remainder on to the start of the buffer.
@@ -435,6 +439,36 @@ where I: AsyncRead + AsyncWrite,
};
}
// If we know the remote speaks an older version, we try to fix up any messages
// to work with our older peer.
fn enforce_version(&mut self, head: &mut super::MessageHead<T::Outgoing>) {
use header::Connection;
let wants_keep_alive = if self.state.wants_keep_alive() {
let ka = head.should_keep_alive();
self.state.keep_alive &= ka;
ka
} else {
false
};
match self.state.version {
Version::Http10 => {
// If the remote only knows HTTP/1.0, we should force ourselves
// to do only speak HTTP/1.0 as well.
head.version = HttpVersion::Http10;
if wants_keep_alive {
head.headers.set(Connection::keep_alive());
}
},
Version::Http11 => {
// If the remote speaks HTTP/1.1, then it *should* be fine with
// both HTTP/1.0 and HTTP/1.1 from us. So again, we just let
// the user's headers be.
}
}
}
pub fn write_body(&mut self, chunk: Option<B>) -> StartSend<Option<B>, io::Error> {
debug_assert!(self.can_write_body());
@@ -486,7 +520,7 @@ where I: AsyncRead + AsyncWrite,
}
} else {
// end of stream, that means we should try to eof
match encoder.eof() {
match encoder.end() {
Ok(Some(end)) => Writing::Ending(Cursor::new(end)),
Ok(None) => Writing::KeepAlive,
Err(_not_eof) => Writing::Closed,
@@ -701,6 +735,7 @@ struct State<B, K> {
read_task: Option<Task>,
reading: Reading,
writing: Writing<B>,
version: Version,
}
#[derive(Debug)]
@@ -819,6 +854,14 @@ impl<B, K: KeepAlive> State<B, K> {
self.keep_alive.disable();
}
fn wants_keep_alive(&self) -> bool {
if let KA::Disabled = self.keep_alive.status() {
false
} else {
true
}
}
fn try_keep_alive(&mut self) {
match (&self.reading, &self.writing) {
(&Reading::KeepAlive, &Writing::KeepAlive) => {
@@ -881,6 +924,12 @@ impl<B, K: KeepAlive> State<B, K> {
}
}
#[derive(Debug, Clone, Copy)]
enum Version {
Http10,
Http11,
}
// The DebugFrame and DebugChunk are simple Debug implementations that allow
// us to dump the frame into logs, without logging the entirety of the bytes.
#[cfg(feature = "tokio-proto")]

View File

@@ -17,6 +17,11 @@ enum Kind {
///
/// Enforces that the body is not longer than the Content-Length header.
Length(u64),
/// An Encoder for when neither Content-Length nore Chunked encoding is set.
///
/// This is mostly only used with HTTP/1.0 with a length. This kind requires
/// the connection to be closed when the body is finished.
Eof
}
impl Encoder {
@@ -32,6 +37,12 @@ impl Encoder {
}
}
pub fn eof() -> Encoder {
Encoder {
kind: Kind::Eof,
}
}
pub fn is_eof(&self) -> bool {
match self.kind {
Kind::Length(0) |
@@ -40,7 +51,7 @@ impl Encoder {
}
}
pub fn eof(&self) -> Result<Option<&'static [u8]>, NotEof> {
pub fn end(&self) -> Result<Option<&'static [u8]>, NotEof> {
match self.kind {
Kind::Length(0) => Ok(None),
Kind::Chunked(Chunked::Init) => Ok(Some(b"0\r\n\r\n")),
@@ -73,6 +84,12 @@ impl Encoder {
trace!("encoded {} bytes, remaining = {}", n, remaining);
Ok(n)
},
Kind::Eof => {
if msg.is_empty() {
return Ok(0);
}
w.write_atomic(&[msg])
}
}
}
}

View File

@@ -10,7 +10,7 @@ use proto::{MessageHead, RawStatus, Http1Transaction, ParseResult,
use proto::h1::{Encoder, Decoder, date};
use method::Method;
use status::StatusCode;
use version::HttpVersion::{Http10, Http11};
use version::HttpVersion::{self, Http10, Http11};
const MAX_HEADERS: usize = 100;
const AVERAGE_HEADER_SIZE: usize = 30; // totally scientific
@@ -166,7 +166,7 @@ impl ServerTransaction {
};
if has_body && can_have_body {
set_length(&mut head.headers)
set_length(head.version, &mut head.headers)
} else {
head.headers.remove::<TransferEncoding>();
if can_have_body {
@@ -302,7 +302,7 @@ impl Http1Transaction for ClientTransaction {
impl ClientTransaction {
fn set_length(head: &mut RequestHead, has_body: bool) -> Encoder {
if has_body {
set_length(&mut head.headers)
set_length(head.version, &mut head.headers)
} else {
head.headers.remove::<ContentLength>();
head.headers.remove::<TransferEncoding>();
@@ -311,12 +311,12 @@ impl ClientTransaction {
}
}
fn set_length(headers: &mut Headers) -> Encoder {
fn set_length(version: HttpVersion, headers: &mut Headers) -> Encoder {
let len = headers.get::<header::ContentLength>().map(|n| **n);
if let Some(len) = len {
Encoder::length(len)
} else {
} else if version == Http11 {
let encodings = match headers.get_mut::<header::TransferEncoding>() {
Some(&mut header::TransferEncoding(ref mut encodings)) => {
if encodings.last() != Some(&header::Encoding::Chunked) {
@@ -331,6 +331,9 @@ fn set_length(headers: &mut Headers) -> Encoder {
headers.set(header::TransferEncoding(vec![header::Encoding::Chunked]));
}
Encoder::chunked()
} else {
headers.remove::<TransferEncoding>();
Encoder::eof()
}
}

View File

@@ -145,6 +145,75 @@ fn get_chunked_response() {
assert_eq!(&body[n..], "B\r\nfoo bar baz\r\n0\r\n\r\n");
}
#[test]
fn get_auto_response() {
let foo_bar = b"foo bar baz";
let server = serve();
server.reply()
.status(hyper::Ok)
.body(foo_bar);
let mut req = connect(server.addr());
req.write_all(b"\
GET / HTTP/1.1\r\n\
Host: example.domain\r\n\
Connection: close\r\n\
\r\n\
").unwrap();
let mut body = String::new();
req.read_to_string(&mut body).unwrap();
assert!(has_header(&body, "Transfer-Encoding: chunked"));
let n = body.find("\r\n\r\n").unwrap() + 4;
assert_eq!(&body[n..], "B\r\nfoo bar baz\r\n0\r\n\r\n");
}
#[test]
fn http_10_get_auto_response() {
let foo_bar = b"foo bar baz";
let server = serve();
server.reply()
.status(hyper::Ok)
.body(foo_bar);
let mut req = connect(server.addr());
req.write_all(b"\
GET / HTTP/1.0\r\n\
Host: example.domain\r\n\
\r\n\
").unwrap();
let mut body = String::new();
req.read_to_string(&mut body).unwrap();
assert!(!has_header(&body, "Transfer-Encoding:"));
let n = body.find("\r\n\r\n").unwrap() + 4;
assert_eq!(&body[n..], "foo bar baz");
}
#[test]
fn http_10_get_chunked_response() {
let foo_bar = b"foo bar baz";
let server = serve();
server.reply()
.status(hyper::Ok)
// this header should actually get removed
.header(hyper::header::TransferEncoding::chunked())
.body(foo_bar);
let mut req = connect(server.addr());
req.write_all(b"\
GET / HTTP/1.0\r\n\
Host: example.domain\r\n\
\r\n\
").unwrap();
let mut body = String::new();
req.read_to_string(&mut body).unwrap();
assert!(!has_header(&body, "Transfer-Encoding:"));
let n = body.find("\r\n\r\n").unwrap() + 4;
assert_eq!(&body[n..], "foo bar baz");
}
#[test]
fn get_chunked_response_with_ka() {
let foo_bar = b"foo bar baz";
@@ -378,7 +447,6 @@ fn keep_alive() {
req.write_all(b"\
GET / HTTP/1.1\r\n\
Host: example.domain\r\n\
Connection: keep-alive\r\n\
\r\n\
").expect("writing 1");
@@ -388,7 +456,6 @@ fn keep_alive() {
if n < buf.len() {
if &buf[n - foo_bar.len()..n] == foo_bar {
break;
} else {
}
}
}
@@ -419,6 +486,57 @@ fn keep_alive() {
}
}
#[test]
fn http_10_keep_alive() {
let foo_bar = b"foo bar baz";
let server = serve();
server.reply()
.status(hyper::Ok)
.header(hyper::header::ContentLength(foo_bar.len() as u64))
.body(foo_bar);
let mut req = connect(server.addr());
req.write_all(b"\
GET / HTTP/1.0\r\n\
Host: example.domain\r\n\
Connection: keep-alive\r\n\
\r\n\
").expect("writing 1");
let mut buf = [0; 1024 * 8];
loop {
let n = req.read(&mut buf[..]).expect("reading 1");
if n < buf.len() {
if &buf[n - foo_bar.len()..n] == foo_bar {
break;
}
}
}
// try again!
let quux = b"zar quux";
server.reply()
.status(hyper::Ok)
.header(hyper::header::ContentLength(quux.len() as u64))
.body(quux);
req.write_all(b"\
GET /quux HTTP/1.0\r\n\
Host: example.domain\r\n\
\r\n\
").expect("writing 2");
let mut buf = [0; 1024 * 8];
loop {
let n = req.read(&mut buf[..]).expect("reading 2");
assert!(n > 0, "n = {}", n);
if n < buf.len() {
if &buf[n - quux.len()..n] == quux {
break;
}
}
}
}
#[test]
fn disable_keep_alive() {
let foo_bar = b"foo bar baz";
@@ -574,6 +692,23 @@ fn pipeline_enabled() {
assert_eq!(n, 0);
}
#[test]
fn http_10_request_receives_http_10_response() {
let server = serve();
let mut req = connect(server.addr());
req.write_all(b"\
GET / HTTP/1.0\r\n\
\r\n\
").unwrap();
let expected = "HTTP/1.0 200 OK\r\nContent-Length: 0\r\n";
let mut buf = [0; 256];
let n = req.read(&mut buf).unwrap();
assert!(n >= expected.len(), "read: {:?} >= {:?}", n, expected.len());
assert_eq!(s(&buf[..expected.len()]), expected);
}
#[test]
fn disable_keep_alive_mid_request() {
let mut core = Core::new().unwrap();
@@ -997,6 +1132,16 @@ fn serve_with_options(options: ServeOptions) -> Serve {
}
}
fn s(buf: &[u8]) -> &str {
::std::str::from_utf8(buf).unwrap()
}
fn has_header(msg: &str, name: &str) -> bool {
let n = msg.find("\r\n\r\n").unwrap_or(msg.len());
msg[..n].contains(name)
}
struct DebugStream<T, D> {
stream: T,
_debug: D,