diff --git a/src/body/body.rs b/src/body/body.rs index 166b4203..3308d3b3 100644 --- a/src/body/body.rs +++ b/src/body/body.rs @@ -12,6 +12,7 @@ use http::HeaderMap; use http_body::{Body as HttpBody, SizeHint}; use crate::common::{task, Future, Never, Pin, Poll}; +use crate::proto::DecodedLength; use crate::upgrade::OnUpgrade; type BodySender = mpsc::Sender>; @@ -31,12 +32,12 @@ pub struct Body { enum Kind { Once(Option), Chan { - content_length: Option, + content_length: DecodedLength, abort_rx: oneshot::Receiver<()>, rx: mpsc::Receiver>, }, H2 { - content_length: Option, + content_length: DecodedLength, recv: h2::RecvStream, }, // NOTE: This requires `Sync` because of how easy it is to use `await` @@ -105,10 +106,10 @@ impl Body { /// Useful when wanting to stream chunks from another thread. #[inline] pub fn channel() -> (Sender, Body) { - Self::new_channel(None) + Self::new_channel(DecodedLength::CHUNKED) } - pub(crate) fn new_channel(content_length: Option) -> (Sender, Body) { + pub(crate) fn new_channel(content_length: DecodedLength) -> (Sender, Body) { let (tx, rx) = mpsc::channel(0); let (abort_tx, abort_rx) = oneshot::channel(); @@ -167,7 +168,7 @@ impl Body { Body { kind, extra: None } } - pub(crate) fn h2(recv: h2::RecvStream, content_length: Option) -> Self { + pub(crate) fn h2(recv: h2::RecvStream, content_length: DecodedLength) -> Self { Body::new(Kind::H2 { content_length, recv, @@ -243,20 +244,19 @@ impl Body { match ready!(Pin::new(rx).poll_next(cx)?) { Some(chunk) => { - if let Some(ref mut len) = *len { - debug_assert!(*len >= chunk.len() as u64); - *len -= chunk.len() as u64; - } + len.sub_if(chunk.len() as u64); Poll::Ready(Some(Ok(chunk))) } None => Poll::Ready(None), } } Kind::H2 { - recv: ref mut h2, .. + recv: ref mut h2, + content_length: ref mut len, } => match ready!(h2.poll_data(cx)) { Some(Ok(bytes)) => { let _ = h2.flow_control().release_capacity(bytes.len()); + len.sub_if(bytes.len() as u64); Poll::Ready(Some(Ok(bytes))) } Some(Err(e)) => Poll::Ready(Some(Err(crate::Error::new_body(e)))), @@ -317,7 +317,7 @@ impl HttpBody for Body { fn is_end_stream(&self) -> bool { match self.kind { Kind::Once(ref val) => val.is_none(), - Kind::Chan { content_length, .. } => content_length == Some(0), + Kind::Chan { content_length, .. } => content_length == DecodedLength::ZERO, Kind::H2 { recv: ref h2, .. } => h2.is_end_stream(), #[cfg(feature = "stream")] Kind::Wrapped(..) => false, @@ -337,7 +337,7 @@ impl HttpBody for Body { Kind::Chan { content_length, .. } | Kind::H2 { content_length, .. } => { let mut hint = SizeHint::default(); - if let Some(content_length) = content_length { + if let Some(content_length) = content_length.into_opt() { hint.set_exact(content_length as u64); } @@ -498,6 +498,7 @@ impl Sender { /// Aborts the body in an abnormal fashion. pub fn abort(self) { + // TODO(sean): this can just be `self.tx.clone().try_send()` let _ = self.abort_tx.send(()); } @@ -505,3 +506,39 @@ impl Sender { let _ = self.tx.try_send(Err(err)); } } + +#[cfg(test)] +mod tests { + use std::mem; + + use super::{Body, Sender}; + + #[test] + fn test_size_of() { + // These are mostly to help catch *accidentally* increasing + // the size by too much. + assert_eq!( + mem::size_of::(), + mem::size_of::() * 5 + mem::size_of::(), + "Body" + ); + + assert_eq!( + mem::size_of::(), + mem::size_of::>(), + "Option" + ); + + assert_eq!( + mem::size_of::(), + mem::size_of::() * 4, + "Sender" + ); + + assert_eq!( + mem::size_of::(), + mem::size_of::>(), + "Option" + ); + } +} diff --git a/src/proto/h1/dispatch.rs b/src/proto/h1/dispatch.rs index 9e4a6ec3..07d05fca 100644 --- a/src/proto/h1/dispatch.rs +++ b/src/proto/h1/dispatch.rs @@ -239,7 +239,7 @@ where let mut body = match body_len { DecodedLength::ZERO => Body::empty(), other => { - let (tx, rx) = Body::new_channel(other.into_opt()); + let (tx, rx) = Body::new_channel(other); self.body_tx = Some(tx); rx } diff --git a/src/proto/h2/client.rs b/src/proto/h2/client.rs index 60e28fa4..9b87121e 100644 --- a/src/proto/h2/client.rs +++ b/src/proto/h2/client.rs @@ -4,11 +4,10 @@ use futures_util::stream::StreamExt as _; use h2::client::{Builder, SendRequest}; use tokio::io::{AsyncRead, AsyncWrite}; -use super::{PipeToSendStream, SendBuf}; +use super::{decode_content_length, PipeToSendStream, SendBuf}; use crate::body::Payload; use crate::common::{task, Exec, Future, Never, Pin, Poll}; use crate::headers; -use crate::headers::content_length_parse_all; use crate::proto::Dispatched; use crate::{Body, Request, Response}; @@ -159,7 +158,7 @@ where let fut = fut.map(move |result| match result { Ok(res) => { - let content_length = content_length_parse_all(res.headers()); + let content_length = decode_content_length(res.headers()); let res = res.map(|stream| crate::Body::h2(stream, content_length)); Ok(res) } diff --git a/src/proto/h2/mod.rs b/src/proto/h2/mod.rs index 2a3bbe71..04d45e08 100644 --- a/src/proto/h2/mod.rs +++ b/src/proto/h2/mod.rs @@ -7,8 +7,10 @@ use http::header::{ use http::HeaderMap; use pin_project::pin_project; +use super::DecodedLength; use crate::body::Payload; use crate::common::{task, Future, Pin, Poll}; +use crate::headers::content_length_parse_all; pub(crate) mod client; pub(crate) mod server; @@ -71,6 +73,15 @@ fn strip_connection_headers(headers: &mut HeaderMap, is_request: bool) { } } +fn decode_content_length(headers: &HeaderMap) -> DecodedLength { + if let Some(len) = content_length_parse_all(headers) { + // If the length is u64::MAX, oh well, just reported chunked. + DecodedLength::checked_new(len).unwrap_or_else(|_| DecodedLength::CHUNKED) + } else { + DecodedLength::CHUNKED + } +} + // body adapters used by both Client and Server #[pin_project] diff --git a/src/proto/h2/server.rs b/src/proto/h2/server.rs index d5d9e025..84689b3d 100644 --- a/src/proto/h2/server.rs +++ b/src/proto/h2/server.rs @@ -6,12 +6,11 @@ use h2::Reason; use pin_project::{pin_project, project}; use tokio::io::{AsyncRead, AsyncWrite}; -use super::{PipeToSendStream, SendBuf}; +use super::{decode_content_length, PipeToSendStream, SendBuf}; use crate::body::Payload; use crate::common::exec::H2Exec; use crate::common::{task, Future, Pin, Poll}; use crate::headers; -use crate::headers::content_length_parse_all; use crate::proto::Dispatched; use crate::service::HttpService; @@ -168,7 +167,7 @@ where match ready!(self.conn.poll_accept(cx)) { Some(Ok((req, respond))) => { trace!("incoming request"); - let content_length = content_length_parse_all(req.headers()); + let content_length = decode_content_length(req.headers()); let req = req.map(|stream| crate::Body::h2(stream, content_length)); let fut = H2Stream::new(service.call(req), respond); exec.execute_h2stream(fut); diff --git a/src/proto/mod.rs b/src/proto/mod.rs index 5e2c631b..3dbac84f 100644 --- a/src/proto/mod.rs +++ b/src/proto/mod.rs @@ -1,7 +1,7 @@ //! Pieces pertaining to the HTTP message protocol. use http::{HeaderMap, Method, StatusCode, Uri, Version}; -use self::body_length::DecodedLength; +pub(crate) use self::body_length::DecodedLength; pub(crate) use self::h1::{dispatch, Conn, ServerTransaction}; pub(crate) mod h1; @@ -90,6 +90,15 @@ mod body_length { Err(crate::error::Parse::TooLarge) } } + + pub(crate) fn sub_if(&mut self, amt: u64) { + match *self { + DecodedLength::CHUNKED | DecodedLength::CLOSE_DELIMITED => (), + DecodedLength(ref mut known) => { + *known -= amt; + } + } + } } impl fmt::Debug for DecodedLength { @@ -112,4 +121,25 @@ mod body_length { } } } + + #[cfg(test)] + mod tests { + use super::*; + + #[test] + fn sub_if_known() { + let mut len = DecodedLength::new(30); + len.sub_if(20); + + assert_eq!(len.0, 10); + } + + #[test] + fn sub_if_chunked() { + let mut len = DecodedLength::CHUNKED; + len.sub_if(20); + + assert_eq!(len, DecodedLength::CHUNKED); + } + } }