fix(http1): only send 100 Continue if request body is polled

Before, if a client request included an `Expect: 100-continue` header,
the `100 Continue` response was sent immediately. However, this is
problematic if the service is going to reply with some 4xx status code
and reject the body.

This change delays the automatic sending of the `100 Continue` status
until the service has call `poll_data` on the request body once.
This commit is contained in:
Sean McArthur
2020-01-28 16:23:03 -08:00
parent a354580e3f
commit c4bb4db5c2
7 changed files with 332 additions and 39 deletions

View File

@@ -8,7 +8,7 @@ use http::{HeaderMap, Method, Version};
use tokio::io::{AsyncRead, AsyncWrite};
use super::io::Buffered;
use super::{/*Decode,*/ Decoder, Encode, EncodedBuf, Encoder, Http1Transaction, ParseContext,};
use super::{Decoder, Encode, EncodedBuf, Encoder, Http1Transaction, ParseContext, Wants};
use crate::common::{task, Pin, Poll, Unpin};
use crate::headers::connection_keep_alive;
use crate::proto::{BodyLength, DecodedLength, MessageHead};
@@ -114,7 +114,7 @@ where
pub fn can_read_body(&self) -> bool {
match self.state.reading {
Reading::Body(..) => true,
Reading::Body(..) | Reading::Continue(..) => true,
_ => false,
}
}
@@ -129,10 +129,10 @@ where
read_buf.len() >= 24 && read_buf[..24] == *H2_PREFACE
}
pub fn poll_read_head(
pub(super) fn poll_read_head(
&mut self,
cx: &mut task::Context<'_>,
) -> Poll<Option<crate::Result<(MessageHead<T::Incoming>, DecodedLength, bool)>>> {
) -> Poll<Option<crate::Result<(MessageHead<T::Incoming>, DecodedLength, Wants)>>> {
debug_assert!(self.can_read_head());
trace!("Conn::read_head");
@@ -156,23 +156,28 @@ where
self.state.keep_alive &= msg.keep_alive;
self.state.version = msg.head.version;
let mut wants = if msg.wants_upgrade {
Wants::UPGRADE
} else {
Wants::EMPTY
};
if msg.decode == DecodedLength::ZERO {
if log_enabled!(log::Level::Debug) && msg.expect_continue {
if msg.expect_continue {
debug!("ignoring expect-continue since body is empty");
}
self.state.reading = Reading::KeepAlive;
if !T::should_read_first() {
self.try_keep_alive(cx);
}
} else if msg.expect_continue {
self.state.reading = Reading::Continue(Decoder::new(msg.decode));
wants = wants.add(Wants::EXPECT);
} else {
if msg.expect_continue {
let cont = b"HTTP/1.1 100 Continue\r\n\r\n";
self.io.headers_buf().extend_from_slice(cont);
}
self.state.reading = Reading::Body(Decoder::new(msg.decode));
};
}
Poll::Ready(Some(Ok((msg.head, msg.decode, msg.wants_upgrade))))
Poll::Ready(Some(Ok((msg.head, msg.decode, wants))))
}
fn on_read_head_error<Z>(&mut self, e: crate::Error) -> Poll<Option<crate::Result<Z>>> {
@@ -239,7 +244,19 @@ where
}
}
}
_ => unreachable!("read_body invalid state: {:?}", self.state.reading),
Reading::Continue(ref decoder) => {
// Write the 100 Continue if not already responded...
if let Writing::Init = self.state.writing {
trace!("automatically sending 100 Continue");
let cont = b"HTTP/1.1 100 Continue\r\n\r\n";
self.io.headers_buf().extend_from_slice(cont);
}
// And now recurse once in the Reading::Body state...
self.state.reading = Reading::Body(decoder.clone());
return self.poll_read_body(cx);
}
_ => unreachable!("poll_read_body invalid state: {:?}", self.state.reading),
};
self.state.reading = reading;
@@ -346,7 +363,9 @@ where
// would finish.
match self.state.reading {
Reading::Body(..) | Reading::KeepAlive | Reading::Closed => return,
Reading::Continue(..) | Reading::Body(..) | Reading::KeepAlive | Reading::Closed => {
return
}
Reading::Init => (),
};
@@ -711,6 +730,7 @@ struct State {
#[derive(Debug)]
enum Reading {
Init,
Continue(Decoder),
Body(Decoder),
KeepAlive,
Closed,

View File

@@ -4,7 +4,7 @@ use bytes::{Buf, Bytes};
use http::{Request, Response, StatusCode};
use tokio::io::{AsyncRead, AsyncWrite};
use super::Http1Transaction;
use super::{Http1Transaction, Wants};
use crate::body::{Body, Payload};
use crate::common::{task, Future, Never, Pin, Poll, Unpin};
use crate::proto::{
@@ -235,16 +235,16 @@ where
}
// dispatch is ready for a message, try to read one
match ready!(self.conn.poll_read_head(cx)) {
Some(Ok((head, body_len, wants_upgrade))) => {
Some(Ok((head, body_len, wants))) => {
let mut body = match body_len {
DecodedLength::ZERO => Body::empty(),
other => {
let (tx, rx) = Body::new_channel(other);
let (tx, rx) = Body::new_channel(other, wants.contains(Wants::EXPECT));
self.body_tx = Some(tx);
rx
}
};
if wants_upgrade {
if wants.contains(Wants::UPGRADE) {
body.set_on_upgrade(self.conn.on_upgrade());
}
self.dispatch.recv_msg(Ok((head, body)))?;

View File

@@ -74,3 +74,22 @@ pub(crate) struct Encode<'a, T> {
req_method: &'a mut Option<Method>,
title_case_headers: bool,
}
/// Extra flags that a request "wants", like expect-continue or upgrades.
#[derive(Clone, Copy, Debug)]
struct Wants(u8);
impl Wants {
const EMPTY: Wants = Wants(0b00);
const EXPECT: Wants = Wants(0b01);
const UPGRADE: Wants = Wants(0b10);
#[must_use]
fn add(self, other: Wants) -> Wants {
Wants(self.0 | other.0)
}
fn contains(&self, other: Wants) -> bool {
(self.0 & other.0) == other.0
}
}