fix(http1): only send 100 Continue if request body is polled

Before, if a client request included an `Expect: 100-continue` header,
the `100 Continue` response was sent immediately. However, this is
problematic if the service is going to reply with some 4xx status code
and reject the body.

This change delays the automatic sending of the `100 Continue` status
until the service has call `poll_data` on the request body once.
This commit is contained in:
Sean McArthur
2020-01-28 16:23:03 -08:00
parent a354580e3f
commit c4bb4db5c2
7 changed files with 332 additions and 39 deletions

View File

@@ -11,7 +11,7 @@ use futures_util::TryStreamExt;
use http::HeaderMap;
use http_body::{Body as HttpBody, SizeHint};
use crate::common::{task, Future, Never, Pin, Poll};
use crate::common::{task, watch, Future, Never, Pin, Poll};
use crate::proto::DecodedLength;
use crate::upgrade::OnUpgrade;
@@ -33,7 +33,7 @@ enum Kind {
Once(Option<Bytes>),
Chan {
content_length: DecodedLength,
abort_rx: oneshot::Receiver<()>,
want_tx: watch::Sender,
rx: mpsc::Receiver<Result<Bytes, crate::Error>>,
},
H2 {
@@ -79,12 +79,14 @@ enum DelayEof {
/// Useful when wanting to stream chunks from another thread. See
/// [`Body::channel`](Body::channel) for more.
#[must_use = "Sender does nothing unless sent on"]
#[derive(Debug)]
pub struct Sender {
abort_tx: oneshot::Sender<()>,
want_rx: watch::Receiver,
tx: BodySender,
}
const WANT_PENDING: usize = 1;
const WANT_READY: usize = 2;
impl Body {
/// Create an empty `Body` stream.
///
@@ -106,17 +108,22 @@ impl Body {
/// Useful when wanting to stream chunks from another thread.
#[inline]
pub fn channel() -> (Sender, Body) {
Self::new_channel(DecodedLength::CHUNKED)
Self::new_channel(DecodedLength::CHUNKED, /*wanter =*/ false)
}
pub(crate) fn new_channel(content_length: DecodedLength) -> (Sender, Body) {
pub(crate) fn new_channel(content_length: DecodedLength, wanter: bool) -> (Sender, Body) {
let (tx, rx) = mpsc::channel(0);
let (abort_tx, abort_rx) = oneshot::channel();
let tx = Sender { abort_tx, tx };
// If wanter is true, `Sender::poll_ready()` won't becoming ready
// until the `Body` has been polled for data once.
let want = if wanter { WANT_PENDING } else { WANT_READY };
let (want_tx, want_rx) = watch::channel(want);
let tx = Sender { want_rx, tx };
let rx = Body::new(Kind::Chan {
content_length,
abort_rx,
want_tx,
rx,
});
@@ -236,11 +243,9 @@ impl Body {
Kind::Chan {
content_length: ref mut len,
ref mut rx,
ref mut abort_rx,
ref mut want_tx,
} => {
if let Poll::Ready(Ok(())) = Pin::new(abort_rx).poll(cx) {
return Poll::Ready(Some(Err(crate::Error::new_body_write_aborted())));
}
want_tx.send(WANT_READY);
match ready!(Pin::new(rx).poll_next(cx)?) {
Some(chunk) => {
@@ -460,19 +465,29 @@ impl From<Cow<'static, str>> for Body {
impl Sender {
/// Check to see if this `Sender` can send more data.
pub fn poll_ready(&mut self, cx: &mut task::Context<'_>) -> Poll<crate::Result<()>> {
match self.abort_tx.poll_canceled(cx) {
Poll::Ready(()) => return Poll::Ready(Err(crate::Error::new_closed())),
Poll::Pending => (), // fallthrough
}
// Check if the receiver end has tried polling for the body yet
ready!(self.poll_want(cx)?);
self.tx
.poll_ready(cx)
.map_err(|_| crate::Error::new_closed())
}
fn poll_want(&mut self, cx: &mut task::Context<'_>) -> Poll<crate::Result<()>> {
match self.want_rx.load(cx) {
WANT_READY => Poll::Ready(Ok(())),
WANT_PENDING => Poll::Pending,
watch::CLOSED => Poll::Ready(Err(crate::Error::new_closed())),
unexpected => unreachable!("want_rx value: {}", unexpected),
}
}
async fn ready(&mut self) -> crate::Result<()> {
futures_util::future::poll_fn(|cx| self.poll_ready(cx)).await
}
/// Send data on this channel when it is ready.
pub async fn send_data(&mut self, chunk: Bytes) -> crate::Result<()> {
futures_util::future::poll_fn(|cx| self.poll_ready(cx)).await?;
self.ready().await?;
self.tx
.try_send(Ok(chunk))
.map_err(|_| crate::Error::new_closed())
@@ -498,8 +513,11 @@ impl Sender {
/// Aborts the body in an abnormal fashion.
pub fn abort(self) {
// TODO(sean): this can just be `self.tx.clone().try_send()`
let _ = self.abort_tx.send(());
let _ = self
.tx
// clone so the send works even if buffer is full
.clone()
.try_send(Err(crate::Error::new_body_write_aborted()));
}
pub(crate) fn send_error(&mut self, err: crate::Error) {
@@ -507,11 +525,29 @@ impl Sender {
}
}
impl fmt::Debug for Sender {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
#[derive(Debug)]
struct Open;
#[derive(Debug)]
struct Closed;
let mut builder = f.debug_tuple("Sender");
match self.want_rx.peek() {
watch::CLOSED => builder.field(&Closed),
_ => builder.field(&Open),
};
builder.finish()
}
}
#[cfg(test)]
mod tests {
use std::mem;
use std::task::Poll;
use super::{Body, Sender};
use super::{Body, DecodedLength, HttpBody, Sender};
#[test]
fn test_size_of() {
@@ -541,4 +577,97 @@ mod tests {
"Option<Sender>"
);
}
#[tokio::test]
async fn channel_abort() {
let (tx, mut rx) = Body::channel();
tx.abort();
let err = rx.data().await.unwrap().unwrap_err();
assert!(err.is_body_write_aborted(), "{:?}", err);
}
#[tokio::test]
async fn channel_abort_when_buffer_is_full() {
let (mut tx, mut rx) = Body::channel();
tx.try_send_data("chunk 1".into()).expect("send 1");
// buffer is full, but can still send abort
tx.abort();
let chunk1 = rx.data().await.expect("item 1").expect("chunk 1");
assert_eq!(chunk1, "chunk 1");
let err = rx.data().await.unwrap().unwrap_err();
assert!(err.is_body_write_aborted(), "{:?}", err);
}
#[test]
fn channel_buffers_one() {
let (mut tx, _rx) = Body::channel();
tx.try_send_data("chunk 1".into()).expect("send 1");
// buffer is now full
let chunk2 = tx.try_send_data("chunk 2".into()).expect_err("send 2");
assert_eq!(chunk2, "chunk 2");
}
#[tokio::test]
async fn channel_empty() {
let (_, mut rx) = Body::channel();
assert!(rx.data().await.is_none());
}
#[test]
fn channel_ready() {
let (mut tx, _rx) = Body::new_channel(DecodedLength::CHUNKED, /*wanter = */ false);
let mut tx_ready = tokio_test::task::spawn(tx.ready());
assert!(tx_ready.poll().is_ready(), "tx is ready immediately");
}
#[test]
fn channel_wanter() {
let (mut tx, mut rx) = Body::new_channel(DecodedLength::CHUNKED, /*wanter = */ true);
let mut tx_ready = tokio_test::task::spawn(tx.ready());
let mut rx_data = tokio_test::task::spawn(rx.data());
assert!(
tx_ready.poll().is_pending(),
"tx isn't ready before rx has been polled"
);
assert!(rx_data.poll().is_pending(), "poll rx.data");
assert!(tx_ready.is_woken(), "rx poll wakes tx");
assert!(
tx_ready.poll().is_ready(),
"tx is ready after rx has been polled"
);
}
#[test]
fn channel_notices_closure() {
let (mut tx, rx) = Body::new_channel(DecodedLength::CHUNKED, /*wanter = */ true);
let mut tx_ready = tokio_test::task::spawn(tx.ready());
assert!(
tx_ready.poll().is_pending(),
"tx isn't ready before rx has been polled"
);
drop(rx);
assert!(tx_ready.is_woken(), "dropping rx wakes tx");
match tx_ready.poll() {
Poll::Ready(Err(ref e)) if e.is_closed() => (),
unexpected => panic!("tx poll ready unexpected: {:?}", unexpected),
}
}
}

View File

@@ -14,6 +14,7 @@ pub(crate) mod io;
mod lazy;
mod never;
pub(crate) mod task;
pub(crate) mod watch;
pub use self::exec::Executor;
pub(crate) use self::exec::{BoxSendFuture, Exec};

73
src/common/watch.rs Normal file
View File

@@ -0,0 +1,73 @@
//! An SPSC broadcast channel.
//!
//! - The value can only be a `usize`.
//! - The consumer is only notified if the value is different.
//! - The value `0` is reserved for closed.
use futures_util::task::AtomicWaker;
use std::sync::{
atomic::{AtomicUsize, Ordering},
Arc,
};
use std::task;
type Value = usize;
pub(crate) const CLOSED: usize = 0;
pub(crate) fn channel(initial: Value) -> (Sender, Receiver) {
debug_assert!(
initial != CLOSED,
"watch::channel initial state of 0 is reserved"
);
let shared = Arc::new(Shared {
value: AtomicUsize::new(initial),
waker: AtomicWaker::new(),
});
(
Sender {
shared: shared.clone(),
},
Receiver { shared },
)
}
pub(crate) struct Sender {
shared: Arc<Shared>,
}
pub(crate) struct Receiver {
shared: Arc<Shared>,
}
struct Shared {
value: AtomicUsize,
waker: AtomicWaker,
}
impl Sender {
pub(crate) fn send(&mut self, value: Value) {
if self.shared.value.swap(value, Ordering::SeqCst) != value {
self.shared.waker.wake();
}
}
}
impl Drop for Sender {
fn drop(&mut self) {
self.send(CLOSED);
}
}
impl Receiver {
pub(crate) fn load(&mut self, cx: &mut task::Context<'_>) -> Value {
self.shared.waker.register(cx.waker());
self.shared.value.load(Ordering::SeqCst)
}
pub(crate) fn peek(&self) -> Value {
self.shared.value.load(Ordering::Relaxed)
}
}

View File

@@ -8,7 +8,7 @@ use http::{HeaderMap, Method, Version};
use tokio::io::{AsyncRead, AsyncWrite};
use super::io::Buffered;
use super::{/*Decode,*/ Decoder, Encode, EncodedBuf, Encoder, Http1Transaction, ParseContext,};
use super::{Decoder, Encode, EncodedBuf, Encoder, Http1Transaction, ParseContext, Wants};
use crate::common::{task, Pin, Poll, Unpin};
use crate::headers::connection_keep_alive;
use crate::proto::{BodyLength, DecodedLength, MessageHead};
@@ -114,7 +114,7 @@ where
pub fn can_read_body(&self) -> bool {
match self.state.reading {
Reading::Body(..) => true,
Reading::Body(..) | Reading::Continue(..) => true,
_ => false,
}
}
@@ -129,10 +129,10 @@ where
read_buf.len() >= 24 && read_buf[..24] == *H2_PREFACE
}
pub fn poll_read_head(
pub(super) fn poll_read_head(
&mut self,
cx: &mut task::Context<'_>,
) -> Poll<Option<crate::Result<(MessageHead<T::Incoming>, DecodedLength, bool)>>> {
) -> Poll<Option<crate::Result<(MessageHead<T::Incoming>, DecodedLength, Wants)>>> {
debug_assert!(self.can_read_head());
trace!("Conn::read_head");
@@ -156,23 +156,28 @@ where
self.state.keep_alive &= msg.keep_alive;
self.state.version = msg.head.version;
let mut wants = if msg.wants_upgrade {
Wants::UPGRADE
} else {
Wants::EMPTY
};
if msg.decode == DecodedLength::ZERO {
if log_enabled!(log::Level::Debug) && msg.expect_continue {
if msg.expect_continue {
debug!("ignoring expect-continue since body is empty");
}
self.state.reading = Reading::KeepAlive;
if !T::should_read_first() {
self.try_keep_alive(cx);
}
} else if msg.expect_continue {
self.state.reading = Reading::Continue(Decoder::new(msg.decode));
wants = wants.add(Wants::EXPECT);
} else {
if msg.expect_continue {
let cont = b"HTTP/1.1 100 Continue\r\n\r\n";
self.io.headers_buf().extend_from_slice(cont);
}
self.state.reading = Reading::Body(Decoder::new(msg.decode));
};
}
Poll::Ready(Some(Ok((msg.head, msg.decode, msg.wants_upgrade))))
Poll::Ready(Some(Ok((msg.head, msg.decode, wants))))
}
fn on_read_head_error<Z>(&mut self, e: crate::Error) -> Poll<Option<crate::Result<Z>>> {
@@ -239,7 +244,19 @@ where
}
}
}
_ => unreachable!("read_body invalid state: {:?}", self.state.reading),
Reading::Continue(ref decoder) => {
// Write the 100 Continue if not already responded...
if let Writing::Init = self.state.writing {
trace!("automatically sending 100 Continue");
let cont = b"HTTP/1.1 100 Continue\r\n\r\n";
self.io.headers_buf().extend_from_slice(cont);
}
// And now recurse once in the Reading::Body state...
self.state.reading = Reading::Body(decoder.clone());
return self.poll_read_body(cx);
}
_ => unreachable!("poll_read_body invalid state: {:?}", self.state.reading),
};
self.state.reading = reading;
@@ -346,7 +363,9 @@ where
// would finish.
match self.state.reading {
Reading::Body(..) | Reading::KeepAlive | Reading::Closed => return,
Reading::Continue(..) | Reading::Body(..) | Reading::KeepAlive | Reading::Closed => {
return
}
Reading::Init => (),
};
@@ -711,6 +730,7 @@ struct State {
#[derive(Debug)]
enum Reading {
Init,
Continue(Decoder),
Body(Decoder),
KeepAlive,
Closed,

View File

@@ -4,7 +4,7 @@ use bytes::{Buf, Bytes};
use http::{Request, Response, StatusCode};
use tokio::io::{AsyncRead, AsyncWrite};
use super::Http1Transaction;
use super::{Http1Transaction, Wants};
use crate::body::{Body, Payload};
use crate::common::{task, Future, Never, Pin, Poll, Unpin};
use crate::proto::{
@@ -235,16 +235,16 @@ where
}
// dispatch is ready for a message, try to read one
match ready!(self.conn.poll_read_head(cx)) {
Some(Ok((head, body_len, wants_upgrade))) => {
Some(Ok((head, body_len, wants))) => {
let mut body = match body_len {
DecodedLength::ZERO => Body::empty(),
other => {
let (tx, rx) = Body::new_channel(other);
let (tx, rx) = Body::new_channel(other, wants.contains(Wants::EXPECT));
self.body_tx = Some(tx);
rx
}
};
if wants_upgrade {
if wants.contains(Wants::UPGRADE) {
body.set_on_upgrade(self.conn.on_upgrade());
}
self.dispatch.recv_msg(Ok((head, body)))?;

View File

@@ -74,3 +74,22 @@ pub(crate) struct Encode<'a, T> {
req_method: &'a mut Option<Method>,
title_case_headers: bool,
}
/// Extra flags that a request "wants", like expect-continue or upgrades.
#[derive(Clone, Copy, Debug)]
struct Wants(u8);
impl Wants {
const EMPTY: Wants = Wants(0b00);
const EXPECT: Wants = Wants(0b01);
const UPGRADE: Wants = Wants(0b10);
#[must_use]
fn add(self, other: Wants) -> Wants {
Wants(self.0 | other.0)
}
fn contains(&self, other: Wants) -> bool {
(self.0 & other.0) == other.0
}
}

View File

@@ -785,6 +785,57 @@ fn expect_continue_but_no_body_is_ignored() {
assert_eq!(&resp[..expected.len()], expected);
}
#[tokio::test]
async fn expect_continue_waits_for_body_poll() {
let _ = pretty_env_logger::try_init();
let mut listener = tcp_bind(&"127.0.0.1:0".parse().unwrap()).unwrap();
let addr = listener.local_addr().unwrap();
let child = thread::spawn(move || {
let mut tcp = connect(&addr);
tcp.write_all(
b"\
POST /foo HTTP/1.1\r\n\
Host: example.domain\r\n\
Expect: 100-continue\r\n\
Content-Length: 100\r\n\
Connection: Close\r\n\
\r\n\
",
)
.expect("write");
let expected = "HTTP/1.1 400 Bad Request\r\n";
let mut resp = String::new();
tcp.read_to_string(&mut resp).expect("read");
assert_eq!(&resp[..expected.len()], expected);
});
let (socket, _) = listener.accept().await.expect("accept");
Http::new()
.serve_connection(
socket,
service_fn(|req| {
assert_eq!(req.headers()["expect"], "100-continue");
// But! We're never going to poll the body!
tokio::time::delay_for(Duration::from_millis(50)).map(move |_| {
// Move and drop the req, so we don't auto-close
drop(req);
Response::builder()
.status(StatusCode::BAD_REQUEST)
.body(hyper::Body::empty())
})
}),
)
.await
.expect("serve_connection");
child.join().expect("client thread");
}
#[test]
fn pipeline_disabled() {
let server = serve();