fix(http1): only send 100 Continue if request body is polled
				
					
				
			Before, if a client request included an `Expect: 100-continue` header, the `100 Continue` response was sent immediately. However, this is problematic if the service is going to reply with some 4xx status code and reject the body. This change delays the automatic sending of the `100 Continue` status until the service has call `poll_data` on the request body once.
This commit is contained in:
		| @@ -8,7 +8,7 @@ use http::{HeaderMap, Method, Version}; | ||||
| use tokio::io::{AsyncRead, AsyncWrite}; | ||||
|  | ||||
| use super::io::Buffered; | ||||
| use super::{/*Decode,*/ Decoder, Encode, EncodedBuf, Encoder, Http1Transaction, ParseContext,}; | ||||
| use super::{Decoder, Encode, EncodedBuf, Encoder, Http1Transaction, ParseContext, Wants}; | ||||
| use crate::common::{task, Pin, Poll, Unpin}; | ||||
| use crate::headers::connection_keep_alive; | ||||
| use crate::proto::{BodyLength, DecodedLength, MessageHead}; | ||||
| @@ -114,7 +114,7 @@ where | ||||
|  | ||||
|     pub fn can_read_body(&self) -> bool { | ||||
|         match self.state.reading { | ||||
|             Reading::Body(..) => true, | ||||
|             Reading::Body(..) | Reading::Continue(..) => true, | ||||
|             _ => false, | ||||
|         } | ||||
|     } | ||||
| @@ -129,10 +129,10 @@ where | ||||
|         read_buf.len() >= 24 && read_buf[..24] == *H2_PREFACE | ||||
|     } | ||||
|  | ||||
|     pub fn poll_read_head( | ||||
|     pub(super) fn poll_read_head( | ||||
|         &mut self, | ||||
|         cx: &mut task::Context<'_>, | ||||
|     ) -> Poll<Option<crate::Result<(MessageHead<T::Incoming>, DecodedLength, bool)>>> { | ||||
|     ) -> Poll<Option<crate::Result<(MessageHead<T::Incoming>, DecodedLength, Wants)>>> { | ||||
|         debug_assert!(self.can_read_head()); | ||||
|         trace!("Conn::read_head"); | ||||
|  | ||||
| @@ -156,23 +156,28 @@ where | ||||
|         self.state.keep_alive &= msg.keep_alive; | ||||
|         self.state.version = msg.head.version; | ||||
|  | ||||
|         let mut wants = if msg.wants_upgrade { | ||||
|             Wants::UPGRADE | ||||
|         } else { | ||||
|             Wants::EMPTY | ||||
|         }; | ||||
|  | ||||
|         if msg.decode == DecodedLength::ZERO { | ||||
|             if log_enabled!(log::Level::Debug) && msg.expect_continue { | ||||
|             if msg.expect_continue { | ||||
|                 debug!("ignoring expect-continue since body is empty"); | ||||
|             } | ||||
|             self.state.reading = Reading::KeepAlive; | ||||
|             if !T::should_read_first() { | ||||
|                 self.try_keep_alive(cx); | ||||
|             } | ||||
|         } else if msg.expect_continue { | ||||
|             self.state.reading = Reading::Continue(Decoder::new(msg.decode)); | ||||
|             wants = wants.add(Wants::EXPECT); | ||||
|         } else { | ||||
|             if msg.expect_continue { | ||||
|                 let cont = b"HTTP/1.1 100 Continue\r\n\r\n"; | ||||
|                 self.io.headers_buf().extend_from_slice(cont); | ||||
|             } | ||||
|             self.state.reading = Reading::Body(Decoder::new(msg.decode)); | ||||
|         }; | ||||
|         } | ||||
|  | ||||
|         Poll::Ready(Some(Ok((msg.head, msg.decode, msg.wants_upgrade)))) | ||||
|         Poll::Ready(Some(Ok((msg.head, msg.decode, wants)))) | ||||
|     } | ||||
|  | ||||
|     fn on_read_head_error<Z>(&mut self, e: crate::Error) -> Poll<Option<crate::Result<Z>>> { | ||||
| @@ -239,7 +244,19 @@ where | ||||
|                     } | ||||
|                 } | ||||
|             } | ||||
|             _ => unreachable!("read_body invalid state: {:?}", self.state.reading), | ||||
|             Reading::Continue(ref decoder) => { | ||||
|                 // Write the 100 Continue if not already responded... | ||||
|                 if let Writing::Init = self.state.writing { | ||||
|                     trace!("automatically sending 100 Continue"); | ||||
|                     let cont = b"HTTP/1.1 100 Continue\r\n\r\n"; | ||||
|                     self.io.headers_buf().extend_from_slice(cont); | ||||
|                 } | ||||
|  | ||||
|                 // And now recurse once in the Reading::Body state... | ||||
|                 self.state.reading = Reading::Body(decoder.clone()); | ||||
|                 return self.poll_read_body(cx); | ||||
|             } | ||||
|             _ => unreachable!("poll_read_body invalid state: {:?}", self.state.reading), | ||||
|         }; | ||||
|  | ||||
|         self.state.reading = reading; | ||||
| @@ -346,7 +363,9 @@ where | ||||
|         // would finish. | ||||
|  | ||||
|         match self.state.reading { | ||||
|             Reading::Body(..) | Reading::KeepAlive | Reading::Closed => return, | ||||
|             Reading::Continue(..) | Reading::Body(..) | Reading::KeepAlive | Reading::Closed => { | ||||
|                 return | ||||
|             } | ||||
|             Reading::Init => (), | ||||
|         }; | ||||
|  | ||||
| @@ -711,6 +730,7 @@ struct State { | ||||
| #[derive(Debug)] | ||||
| enum Reading { | ||||
|     Init, | ||||
|     Continue(Decoder), | ||||
|     Body(Decoder), | ||||
|     KeepAlive, | ||||
|     Closed, | ||||
|   | ||||
| @@ -4,7 +4,7 @@ use bytes::{Buf, Bytes}; | ||||
| use http::{Request, Response, StatusCode}; | ||||
| use tokio::io::{AsyncRead, AsyncWrite}; | ||||
|  | ||||
| use super::Http1Transaction; | ||||
| use super::{Http1Transaction, Wants}; | ||||
| use crate::body::{Body, Payload}; | ||||
| use crate::common::{task, Future, Never, Pin, Poll, Unpin}; | ||||
| use crate::proto::{ | ||||
| @@ -235,16 +235,16 @@ where | ||||
|         } | ||||
|         // dispatch is ready for a message, try to read one | ||||
|         match ready!(self.conn.poll_read_head(cx)) { | ||||
|             Some(Ok((head, body_len, wants_upgrade))) => { | ||||
|             Some(Ok((head, body_len, wants))) => { | ||||
|                 let mut body = match body_len { | ||||
|                     DecodedLength::ZERO => Body::empty(), | ||||
|                     other => { | ||||
|                         let (tx, rx) = Body::new_channel(other); | ||||
|                         let (tx, rx) = Body::new_channel(other, wants.contains(Wants::EXPECT)); | ||||
|                         self.body_tx = Some(tx); | ||||
|                         rx | ||||
|                     } | ||||
|                 }; | ||||
|                 if wants_upgrade { | ||||
|                 if wants.contains(Wants::UPGRADE) { | ||||
|                     body.set_on_upgrade(self.conn.on_upgrade()); | ||||
|                 } | ||||
|                 self.dispatch.recv_msg(Ok((head, body)))?; | ||||
|   | ||||
| @@ -74,3 +74,22 @@ pub(crate) struct Encode<'a, T> { | ||||
|     req_method: &'a mut Option<Method>, | ||||
|     title_case_headers: bool, | ||||
| } | ||||
|  | ||||
| /// Extra flags that a request "wants", like expect-continue or upgrades. | ||||
| #[derive(Clone, Copy, Debug)] | ||||
| struct Wants(u8); | ||||
|  | ||||
| impl Wants { | ||||
|     const EMPTY: Wants = Wants(0b00); | ||||
|     const EXPECT: Wants = Wants(0b01); | ||||
|     const UPGRADE: Wants = Wants(0b10); | ||||
|  | ||||
|     #[must_use] | ||||
|     fn add(self, other: Wants) -> Wants { | ||||
|         Wants(self.0 | other.0) | ||||
|     } | ||||
|  | ||||
|     fn contains(&self, other: Wants) -> bool { | ||||
|         (self.0 & other.0) == other.0 | ||||
|     } | ||||
| } | ||||
|   | ||||
		Reference in New Issue
	
	Block a user