fix(lib): properly handle HTTP/1.0 remotes
- Downgrades internal semantics to HTTP/1.0 if peer sends a message with 1.0 version. - If downgraded, chunked writers become EOF writers, with the connection closing once the writing is complete. - When downgraded, if keep-alive was wanted, the `Connection: keep-alive` header is added. Closes #1304
This commit is contained in:
@@ -45,6 +45,9 @@ where I: AsyncRead + AsyncWrite,
|
||||
read_task: None,
|
||||
reading: Reading::Init,
|
||||
writing: Writing::Init,
|
||||
// We assume a modern world where the remote speaks HTTP/1.1.
|
||||
// If they tell us otherwise, we'll downgrade in `read_head`.
|
||||
version: Version::Http11,
|
||||
},
|
||||
_marker: PhantomData,
|
||||
}
|
||||
@@ -189,43 +192,44 @@ where I: AsyncRead + AsyncWrite,
|
||||
}
|
||||
};
|
||||
|
||||
match version {
|
||||
HttpVersion::Http10 | HttpVersion::Http11 => {
|
||||
let decoder = match T::decoder(&head, &mut self.state.method) {
|
||||
Ok(d) => d,
|
||||
Err(e) => {
|
||||
debug!("decoder error = {:?}", e);
|
||||
self.state.close_read();
|
||||
return Err(e);
|
||||
}
|
||||
};
|
||||
|
||||
debug!("incoming body is {}", decoder);
|
||||
|
||||
self.state.busy();
|
||||
if head.expecting_continue() {
|
||||
let msg = b"HTTP/1.1 100 Continue\r\n\r\n";
|
||||
self.state.writing = Writing::Continue(Cursor::new(msg));
|
||||
}
|
||||
let wants_keep_alive = head.should_keep_alive();
|
||||
self.state.keep_alive &= wants_keep_alive;
|
||||
let (body, reading) = if decoder.is_eof() {
|
||||
(false, Reading::KeepAlive)
|
||||
} else {
|
||||
(true, Reading::Body(decoder))
|
||||
};
|
||||
self.state.reading = reading;
|
||||
if !body {
|
||||
self.try_keep_alive();
|
||||
}
|
||||
Ok(Async::Ready(Some((head, body))))
|
||||
},
|
||||
self.state.version = match version {
|
||||
HttpVersion::Http10 => Version::Http10,
|
||||
HttpVersion::Http11 => Version::Http11,
|
||||
_ => {
|
||||
error!("unimplemented HTTP Version = {:?}", version);
|
||||
self.state.close_read();
|
||||
Err(::Error::Version)
|
||||
return Err(::Error::Version);
|
||||
}
|
||||
};
|
||||
|
||||
let decoder = match T::decoder(&head, &mut self.state.method) {
|
||||
Ok(d) => d,
|
||||
Err(e) => {
|
||||
debug!("decoder error = {:?}", e);
|
||||
self.state.close_read();
|
||||
return Err(e);
|
||||
}
|
||||
};
|
||||
|
||||
debug!("incoming body is {}", decoder);
|
||||
|
||||
self.state.busy();
|
||||
if head.expecting_continue() {
|
||||
let msg = b"HTTP/1.1 100 Continue\r\n\r\n";
|
||||
self.state.writing = Writing::Continue(Cursor::new(msg));
|
||||
}
|
||||
let wants_keep_alive = head.should_keep_alive();
|
||||
self.state.keep_alive &= wants_keep_alive;
|
||||
let (body, reading) = if decoder.is_eof() {
|
||||
(false, Reading::KeepAlive)
|
||||
} else {
|
||||
(true, Reading::Body(decoder))
|
||||
};
|
||||
self.state.reading = reading;
|
||||
if !body {
|
||||
self.try_keep_alive();
|
||||
}
|
||||
Ok(Async::Ready(Some((head, body))))
|
||||
}
|
||||
|
||||
pub fn read_body(&mut self) -> Poll<Option<super::Chunk>, io::Error> {
|
||||
@@ -414,11 +418,11 @@ where I: AsyncRead + AsyncWrite,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn write_head(&mut self, head: super::MessageHead<T::Outgoing>, body: bool) {
|
||||
pub fn write_head(&mut self, mut head: super::MessageHead<T::Outgoing>, body: bool) {
|
||||
debug_assert!(self.can_write_head());
|
||||
|
||||
let wants_keep_alive = head.should_keep_alive();
|
||||
self.state.keep_alive &= wants_keep_alive;
|
||||
self.enforce_version(&mut head);
|
||||
|
||||
let buf = self.io.write_buf_mut();
|
||||
// if a 100-continue has started but not finished sending, tack the
|
||||
// remainder on to the start of the buffer.
|
||||
@@ -435,6 +439,36 @@ where I: AsyncRead + AsyncWrite,
|
||||
};
|
||||
}
|
||||
|
||||
// If we know the remote speaks an older version, we try to fix up any messages
|
||||
// to work with our older peer.
|
||||
fn enforce_version(&mut self, head: &mut super::MessageHead<T::Outgoing>) {
|
||||
use header::Connection;
|
||||
|
||||
let wants_keep_alive = if self.state.wants_keep_alive() {
|
||||
let ka = head.should_keep_alive();
|
||||
self.state.keep_alive &= ka;
|
||||
ka
|
||||
} else {
|
||||
false
|
||||
};
|
||||
|
||||
match self.state.version {
|
||||
Version::Http10 => {
|
||||
// If the remote only knows HTTP/1.0, we should force ourselves
|
||||
// to do only speak HTTP/1.0 as well.
|
||||
head.version = HttpVersion::Http10;
|
||||
if wants_keep_alive {
|
||||
head.headers.set(Connection::keep_alive());
|
||||
}
|
||||
},
|
||||
Version::Http11 => {
|
||||
// If the remote speaks HTTP/1.1, then it *should* be fine with
|
||||
// both HTTP/1.0 and HTTP/1.1 from us. So again, we just let
|
||||
// the user's headers be.
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub fn write_body(&mut self, chunk: Option<B>) -> StartSend<Option<B>, io::Error> {
|
||||
debug_assert!(self.can_write_body());
|
||||
|
||||
@@ -486,7 +520,7 @@ where I: AsyncRead + AsyncWrite,
|
||||
}
|
||||
} else {
|
||||
// end of stream, that means we should try to eof
|
||||
match encoder.eof() {
|
||||
match encoder.end() {
|
||||
Ok(Some(end)) => Writing::Ending(Cursor::new(end)),
|
||||
Ok(None) => Writing::KeepAlive,
|
||||
Err(_not_eof) => Writing::Closed,
|
||||
@@ -701,6 +735,7 @@ struct State<B, K> {
|
||||
read_task: Option<Task>,
|
||||
reading: Reading,
|
||||
writing: Writing<B>,
|
||||
version: Version,
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
@@ -819,6 +854,14 @@ impl<B, K: KeepAlive> State<B, K> {
|
||||
self.keep_alive.disable();
|
||||
}
|
||||
|
||||
fn wants_keep_alive(&self) -> bool {
|
||||
if let KA::Disabled = self.keep_alive.status() {
|
||||
false
|
||||
} else {
|
||||
true
|
||||
}
|
||||
}
|
||||
|
||||
fn try_keep_alive(&mut self) {
|
||||
match (&self.reading, &self.writing) {
|
||||
(&Reading::KeepAlive, &Writing::KeepAlive) => {
|
||||
@@ -881,6 +924,12 @@ impl<B, K: KeepAlive> State<B, K> {
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Copy)]
|
||||
enum Version {
|
||||
Http10,
|
||||
Http11,
|
||||
}
|
||||
|
||||
// The DebugFrame and DebugChunk are simple Debug implementations that allow
|
||||
// us to dump the frame into logs, without logging the entirety of the bytes.
|
||||
#[cfg(feature = "tokio-proto")]
|
||||
|
||||
@@ -17,6 +17,11 @@ enum Kind {
|
||||
///
|
||||
/// Enforces that the body is not longer than the Content-Length header.
|
||||
Length(u64),
|
||||
/// An Encoder for when neither Content-Length nore Chunked encoding is set.
|
||||
///
|
||||
/// This is mostly only used with HTTP/1.0 with a length. This kind requires
|
||||
/// the connection to be closed when the body is finished.
|
||||
Eof
|
||||
}
|
||||
|
||||
impl Encoder {
|
||||
@@ -32,6 +37,12 @@ impl Encoder {
|
||||
}
|
||||
}
|
||||
|
||||
pub fn eof() -> Encoder {
|
||||
Encoder {
|
||||
kind: Kind::Eof,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn is_eof(&self) -> bool {
|
||||
match self.kind {
|
||||
Kind::Length(0) |
|
||||
@@ -40,7 +51,7 @@ impl Encoder {
|
||||
}
|
||||
}
|
||||
|
||||
pub fn eof(&self) -> Result<Option<&'static [u8]>, NotEof> {
|
||||
pub fn end(&self) -> Result<Option<&'static [u8]>, NotEof> {
|
||||
match self.kind {
|
||||
Kind::Length(0) => Ok(None),
|
||||
Kind::Chunked(Chunked::Init) => Ok(Some(b"0\r\n\r\n")),
|
||||
@@ -73,6 +84,12 @@ impl Encoder {
|
||||
trace!("encoded {} bytes, remaining = {}", n, remaining);
|
||||
Ok(n)
|
||||
},
|
||||
Kind::Eof => {
|
||||
if msg.is_empty() {
|
||||
return Ok(0);
|
||||
}
|
||||
w.write_atomic(&[msg])
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -10,7 +10,7 @@ use proto::{MessageHead, RawStatus, Http1Transaction, ParseResult,
|
||||
use proto::h1::{Encoder, Decoder, date};
|
||||
use method::Method;
|
||||
use status::StatusCode;
|
||||
use version::HttpVersion::{Http10, Http11};
|
||||
use version::HttpVersion::{self, Http10, Http11};
|
||||
|
||||
const MAX_HEADERS: usize = 100;
|
||||
const AVERAGE_HEADER_SIZE: usize = 30; // totally scientific
|
||||
@@ -166,7 +166,7 @@ impl ServerTransaction {
|
||||
};
|
||||
|
||||
if has_body && can_have_body {
|
||||
set_length(&mut head.headers)
|
||||
set_length(head.version, &mut head.headers)
|
||||
} else {
|
||||
head.headers.remove::<TransferEncoding>();
|
||||
if can_have_body {
|
||||
@@ -302,7 +302,7 @@ impl Http1Transaction for ClientTransaction {
|
||||
impl ClientTransaction {
|
||||
fn set_length(head: &mut RequestHead, has_body: bool) -> Encoder {
|
||||
if has_body {
|
||||
set_length(&mut head.headers)
|
||||
set_length(head.version, &mut head.headers)
|
||||
} else {
|
||||
head.headers.remove::<ContentLength>();
|
||||
head.headers.remove::<TransferEncoding>();
|
||||
@@ -311,12 +311,12 @@ impl ClientTransaction {
|
||||
}
|
||||
}
|
||||
|
||||
fn set_length(headers: &mut Headers) -> Encoder {
|
||||
fn set_length(version: HttpVersion, headers: &mut Headers) -> Encoder {
|
||||
let len = headers.get::<header::ContentLength>().map(|n| **n);
|
||||
|
||||
if let Some(len) = len {
|
||||
Encoder::length(len)
|
||||
} else {
|
||||
} else if version == Http11 {
|
||||
let encodings = match headers.get_mut::<header::TransferEncoding>() {
|
||||
Some(&mut header::TransferEncoding(ref mut encodings)) => {
|
||||
if encodings.last() != Some(&header::Encoding::Chunked) {
|
||||
@@ -331,6 +331,9 @@ fn set_length(headers: &mut Headers) -> Encoder {
|
||||
headers.set(header::TransferEncoding(vec![header::Encoding::Chunked]));
|
||||
}
|
||||
Encoder::chunked()
|
||||
} else {
|
||||
headers.remove::<TransferEncoding>();
|
||||
Encoder::eof()
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
149
tests/server.rs
149
tests/server.rs
@@ -145,6 +145,75 @@ fn get_chunked_response() {
|
||||
assert_eq!(&body[n..], "B\r\nfoo bar baz\r\n0\r\n\r\n");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn get_auto_response() {
|
||||
let foo_bar = b"foo bar baz";
|
||||
let server = serve();
|
||||
server.reply()
|
||||
.status(hyper::Ok)
|
||||
.body(foo_bar);
|
||||
let mut req = connect(server.addr());
|
||||
req.write_all(b"\
|
||||
GET / HTTP/1.1\r\n\
|
||||
Host: example.domain\r\n\
|
||||
Connection: close\r\n\
|
||||
\r\n\
|
||||
").unwrap();
|
||||
let mut body = String::new();
|
||||
req.read_to_string(&mut body).unwrap();
|
||||
|
||||
assert!(has_header(&body, "Transfer-Encoding: chunked"));
|
||||
|
||||
let n = body.find("\r\n\r\n").unwrap() + 4;
|
||||
assert_eq!(&body[n..], "B\r\nfoo bar baz\r\n0\r\n\r\n");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn http_10_get_auto_response() {
|
||||
let foo_bar = b"foo bar baz";
|
||||
let server = serve();
|
||||
server.reply()
|
||||
.status(hyper::Ok)
|
||||
.body(foo_bar);
|
||||
let mut req = connect(server.addr());
|
||||
req.write_all(b"\
|
||||
GET / HTTP/1.0\r\n\
|
||||
Host: example.domain\r\n\
|
||||
\r\n\
|
||||
").unwrap();
|
||||
let mut body = String::new();
|
||||
req.read_to_string(&mut body).unwrap();
|
||||
|
||||
assert!(!has_header(&body, "Transfer-Encoding:"));
|
||||
|
||||
let n = body.find("\r\n\r\n").unwrap() + 4;
|
||||
assert_eq!(&body[n..], "foo bar baz");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn http_10_get_chunked_response() {
|
||||
let foo_bar = b"foo bar baz";
|
||||
let server = serve();
|
||||
server.reply()
|
||||
.status(hyper::Ok)
|
||||
// this header should actually get removed
|
||||
.header(hyper::header::TransferEncoding::chunked())
|
||||
.body(foo_bar);
|
||||
let mut req = connect(server.addr());
|
||||
req.write_all(b"\
|
||||
GET / HTTP/1.0\r\n\
|
||||
Host: example.domain\r\n\
|
||||
\r\n\
|
||||
").unwrap();
|
||||
let mut body = String::new();
|
||||
req.read_to_string(&mut body).unwrap();
|
||||
|
||||
assert!(!has_header(&body, "Transfer-Encoding:"));
|
||||
|
||||
let n = body.find("\r\n\r\n").unwrap() + 4;
|
||||
assert_eq!(&body[n..], "foo bar baz");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn get_chunked_response_with_ka() {
|
||||
let foo_bar = b"foo bar baz";
|
||||
@@ -378,7 +447,6 @@ fn keep_alive() {
|
||||
req.write_all(b"\
|
||||
GET / HTTP/1.1\r\n\
|
||||
Host: example.domain\r\n\
|
||||
Connection: keep-alive\r\n\
|
||||
\r\n\
|
||||
").expect("writing 1");
|
||||
|
||||
@@ -388,7 +456,6 @@ fn keep_alive() {
|
||||
if n < buf.len() {
|
||||
if &buf[n - foo_bar.len()..n] == foo_bar {
|
||||
break;
|
||||
} else {
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -419,6 +486,57 @@ fn keep_alive() {
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn http_10_keep_alive() {
|
||||
let foo_bar = b"foo bar baz";
|
||||
let server = serve();
|
||||
server.reply()
|
||||
.status(hyper::Ok)
|
||||
.header(hyper::header::ContentLength(foo_bar.len() as u64))
|
||||
.body(foo_bar);
|
||||
let mut req = connect(server.addr());
|
||||
req.write_all(b"\
|
||||
GET / HTTP/1.0\r\n\
|
||||
Host: example.domain\r\n\
|
||||
Connection: keep-alive\r\n\
|
||||
\r\n\
|
||||
").expect("writing 1");
|
||||
|
||||
let mut buf = [0; 1024 * 8];
|
||||
loop {
|
||||
let n = req.read(&mut buf[..]).expect("reading 1");
|
||||
if n < buf.len() {
|
||||
if &buf[n - foo_bar.len()..n] == foo_bar {
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// try again!
|
||||
|
||||
let quux = b"zar quux";
|
||||
server.reply()
|
||||
.status(hyper::Ok)
|
||||
.header(hyper::header::ContentLength(quux.len() as u64))
|
||||
.body(quux);
|
||||
req.write_all(b"\
|
||||
GET /quux HTTP/1.0\r\n\
|
||||
Host: example.domain\r\n\
|
||||
\r\n\
|
||||
").expect("writing 2");
|
||||
|
||||
let mut buf = [0; 1024 * 8];
|
||||
loop {
|
||||
let n = req.read(&mut buf[..]).expect("reading 2");
|
||||
assert!(n > 0, "n = {}", n);
|
||||
if n < buf.len() {
|
||||
if &buf[n - quux.len()..n] == quux {
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn disable_keep_alive() {
|
||||
let foo_bar = b"foo bar baz";
|
||||
@@ -574,6 +692,23 @@ fn pipeline_enabled() {
|
||||
assert_eq!(n, 0);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn http_10_request_receives_http_10_response() {
|
||||
let server = serve();
|
||||
|
||||
let mut req = connect(server.addr());
|
||||
req.write_all(b"\
|
||||
GET / HTTP/1.0\r\n\
|
||||
\r\n\
|
||||
").unwrap();
|
||||
|
||||
let expected = "HTTP/1.0 200 OK\r\nContent-Length: 0\r\n";
|
||||
let mut buf = [0; 256];
|
||||
let n = req.read(&mut buf).unwrap();
|
||||
assert!(n >= expected.len(), "read: {:?} >= {:?}", n, expected.len());
|
||||
assert_eq!(s(&buf[..expected.len()]), expected);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn disable_keep_alive_mid_request() {
|
||||
let mut core = Core::new().unwrap();
|
||||
@@ -997,6 +1132,16 @@ fn serve_with_options(options: ServeOptions) -> Serve {
|
||||
}
|
||||
}
|
||||
|
||||
fn s(buf: &[u8]) -> &str {
|
||||
::std::str::from_utf8(buf).unwrap()
|
||||
}
|
||||
|
||||
fn has_header(msg: &str, name: &str) -> bool {
|
||||
let n = msg.find("\r\n\r\n").unwrap_or(msg.len());
|
||||
|
||||
msg[..n].contains(name)
|
||||
}
|
||||
|
||||
struct DebugStream<T, D> {
|
||||
stream: T,
|
||||
_debug: D,
|
||||
|
||||
Reference in New Issue
Block a user