From 898e91950473cd6d3cbf4c34fcb8a24b555a0ce7 Mon Sep 17 00:00:00 2001 From: Sean McArthur Date: Thu, 31 May 2018 17:42:55 -0700 Subject: [PATCH] perf(h1): optimize for when Body is only 1 chunk - When the `Body` is created from a buffer of bytes (such as `Body::from("hello")`), we can skip some bookkeeping that is normally required for streaming bodies. - Orthogonally, optimize encoding body chunks when the strategy is to flatten into the headers buf, by skipping the EncodedBuf enum. --- src/body.rs | 37 +++++++++++++++++++ src/proto/h1/conn.rs | 50 ++++++++++++++++++------- src/proto/h1/dispatch.rs | 42 +++++++++++---------- src/proto/h1/encode.rs | 79 ++++++++++++++++++++++++++++++++++------ src/proto/h1/io.rs | 12 ++++-- src/proto/h1/role.rs | 41 ++++----------------- tests/server.rs | 29 +++++++++++++++ 7 files changed, 207 insertions(+), 83 deletions(-) diff --git a/src/body.rs b/src/body.rs index 1c19d942..71da99d8 100644 --- a/src/body.rs +++ b/src/body.rs @@ -26,8 +26,11 @@ use http::HeaderMap; use common::Never; pub use chunk::Chunk; +use self::internal::{FullDataArg, FullDataRet}; + type BodySender = mpsc::Sender>; + /// This trait represents a streaming body of a `Request` or `Response`. /// /// The built-in implementation of this trait is [`Body`](Body), in case you @@ -80,6 +83,16 @@ pub trait Payload: Send + 'static { fn content_length(&self) -> Option { None } + + // This API is unstable, and is impossible to use outside of hyper. Some + // form of it may become stable in a later version. + // + // The only thing a user *could* do is reference the method, but DON'T + // DO THAT! :) + #[doc(hidden)] + fn __hyper_full_data(&mut self, FullDataArg) -> FullDataRet { + FullDataRet(None) + } } impl Payload for Box { @@ -343,6 +356,14 @@ impl Payload for Body { Kind::Wrapped(..) => None, } } + + // We can improve the performance of `Body` when we know it is a Once kind. + fn __hyper_full_data(&mut self, _: FullDataArg) -> FullDataRet { + match self.kind { + Kind::Once(ref mut val) => FullDataRet(val.take()), + _ => FullDataRet(None), + } + } } impl Stream for Body { @@ -469,6 +490,22 @@ impl From> for Body { } } +// The full_data API is not stable, so these types are to try to prevent +// users from being able to: +// +// - Implment `__hyper_full_data` on their own Payloads. +// - Call `__hyper_full_data` on any Payload. +// +// That's because to implement it, they need to name these types, and +// they can't because they aren't exported. And to call it, they would +// need to create one of these values, which they also can't. +pub(crate) mod internal { + #[allow(missing_debug_implementations)] + pub struct FullDataArg(pub(crate) ()); + #[allow(missing_debug_implementations)] + pub struct FullDataRet(pub(crate) Option); +} + fn _assert_send_sync() { fn _assert_send() {} fn _assert_sync() {} diff --git a/src/proto/h1/conn.rs b/src/proto/h1/conn.rs index 128bc582..364f890c 100644 --- a/src/proto/h1/conn.rs +++ b/src/proto/h1/conn.rs @@ -388,7 +388,35 @@ where I: AsyncRead + AsyncWrite, self.io.can_buffer() } - pub fn write_head(&mut self, mut head: MessageHead, body: Option) { + pub fn write_head(&mut self, head: MessageHead, body: Option) { + if let Some(encoder) = self.encode_head(head, body) { + self.state.writing = if !encoder.is_eof() { + Writing::Body(encoder) + } else if encoder.is_last() { + Writing::Closed + } else { + Writing::KeepAlive + }; + } + } + + pub fn write_full_msg(&mut self, head: MessageHead, body: B) { + if let Some(encoder) = self.encode_head(head, Some(BodyLength::Known(body.remaining() as u64))) { + let is_last = encoder.is_last(); + // Make sure we don't write a body if we weren't actually allowed + // to do so, like because its a HEAD request. + if !encoder.is_eof() { + encoder.danger_full_buf(body, self.io.write_buf()); + } + self.state.writing = if is_last { + Writing::Closed + } else { + Writing::KeepAlive + } + } + } + + fn encode_head(&mut self, mut head: MessageHead, body: Option) -> Option { debug_assert!(self.can_write_head()); if !T::should_read_first() { @@ -398,7 +426,7 @@ where I: AsyncRead + AsyncWrite, self.enforce_version(&mut head); let buf = self.io.headers_buf(); - self.state.writing = match T::encode(Encode { + match T::encode(Encode { head: &mut head, body, keep_alive: self.state.wants_keep_alive(), @@ -409,19 +437,14 @@ where I: AsyncRead + AsyncWrite, debug_assert!(self.state.cached_headers.is_none()); debug_assert!(head.headers.is_empty()); self.state.cached_headers = Some(head.headers); - if !encoder.is_eof() { - Writing::Body(encoder) - } else if encoder.is_last() { - Writing::Closed - } else { - Writing::KeepAlive - } + Some(encoder) }, Err(err) => { self.state.error = Some(err); - Writing::Closed - } - }; + self.state.writing = Writing::Closed; + None + }, + } } // If we know the remote speaks an older version, we try to fix up any messages @@ -474,8 +497,7 @@ where I: AsyncRead + AsyncWrite, let state = match self.state.writing { Writing::Body(ref encoder) => { - let (encoded, can_keep_alive) = encoder.encode_and_end(chunk); - self.io.buffer(encoded); + let can_keep_alive = encoder.encode_and_end(chunk, self.io.write_buf()); if can_keep_alive { Writing::KeepAlive } else { diff --git a/src/proto/h1/dispatch.rs b/src/proto/h1/dispatch.rs index 364048c7..5b34c7e4 100644 --- a/src/proto/h1/dispatch.rs +++ b/src/proto/h1/dispatch.rs @@ -4,6 +4,7 @@ use http::{Request, Response, StatusCode}; use tokio_io::{AsyncRead, AsyncWrite}; use body::{Body, Payload}; +use body::internal::FullDataArg; use proto::{BodyLength, Conn, MessageHead, RequestHead, RequestLine, ResponseHead}; use super::Http1Transaction; use service::Service; @@ -20,7 +21,7 @@ pub(crate) trait Dispatch { type PollItem; type PollBody; type RecvItem; - fn poll_msg(&mut self) -> Poll)>, ::Error>; + fn poll_msg(&mut self) -> Poll, ::Error>; fn recv_msg(&mut self, msg: ::Result<(Self::RecvItem, Body)>) -> ::Result<()>; fn poll_ready(&mut self) -> Poll<(), ()>; fn should_poll(&self) -> bool; @@ -222,14 +223,26 @@ where if self.is_closing { return Ok(Async::Ready(())); } else if self.body_rx.is_none() && self.conn.can_write_head() && self.dispatch.should_poll() { - if let Some((head, body)) = try_ready!(self.dispatch.poll_msg()) { - let body_type = body.as_ref().map(|body| { - body.content_length() + if let Some((head, mut body)) = try_ready!(self.dispatch.poll_msg()) { + // Check if the body knows its full data immediately. + // + // If so, we can skip a bit of bookkeeping that streaming + // bodies need to do. + if let Some(full) = body.__hyper_full_data(FullDataArg(())).0 { + self.conn.write_full_msg(head, full); + return Ok(Async::Ready(())); + } + let body_type = if body.is_end_stream() { + self.body_rx = None; + None + } else { + let btype = body.content_length() .map(BodyLength::Known) - .unwrap_or(BodyLength::Unknown) - }); + .or_else(|| Some(BodyLength::Unknown)); + self.body_rx = Some(body); + btype + }; self.conn.write_head(head, body_type); - self.body_rx = body; } else { self.close(); return Ok(Async::Ready(())); @@ -349,7 +362,7 @@ where type PollBody = Bs; type RecvItem = RequestHead; - fn poll_msg(&mut self) -> Poll)>, ::Error> { + fn poll_msg(&mut self) -> Poll, ::Error> { if let Some(mut fut) = self.in_flight.take() { let resp = match fut.poll().map_err(::Error::new_user_service)? { Async::Ready(res) => res, @@ -364,11 +377,6 @@ where subject: parts.status, headers: parts.headers, }; - let body = if body.is_end_stream() { - None - } else { - Some(body) - }; Ok(Async::Ready(Some((head, body)))) } else { unreachable!("poll_msg shouldn't be called if no inflight"); @@ -419,7 +427,7 @@ where type PollBody = B; type RecvItem = ResponseHead; - fn poll_msg(&mut self) -> Poll)>, ::Error> { + fn poll_msg(&mut self) -> Poll, ::Error> { match self.rx.poll() { Ok(Async::Ready(Some((req, mut cb)))) => { // check that future hasn't been canceled already @@ -435,12 +443,6 @@ where subject: RequestLine(parts.method, parts.uri), headers: parts.headers, }; - - let body = if body.is_end_stream() { - None - } else { - Some(body) - }; self.callback = Some(cb); Ok(Async::Ready(Some((head, body)))) } diff --git a/src/proto/h1/encode.rs b/src/proto/h1/encode.rs index cfda1c30..fe7c0025 100644 --- a/src/proto/h1/encode.rs +++ b/src/proto/h1/encode.rs @@ -5,6 +5,7 @@ use bytes::buf::{Chain, Take}; use iovec::IoVec; use common::StaticBuf; +use super::io::WriteBuf; /// Encoders to handle different Transfer-Encodings. #[derive(Debug, Clone, PartialEq)] @@ -126,7 +127,7 @@ impl Encoder { } } - pub fn encode_and_end(&self, msg: B) -> (EncodedBuf, bool) + pub(super) fn encode_and_end(&self, msg: B, dst: &mut WriteBuf>) -> bool where B: IntoBuf, { @@ -134,13 +135,14 @@ impl Encoder { let len = msg.remaining(); debug_assert!(len > 0, "encode() called with empty buf"); - let (kind, eof) = match self.kind { + match self.kind { Kind::Chunked => { trace!("encoding chunked {}B", len); let buf = ChunkSize::new(len) .chain(msg) .chain(StaticBuf(b"\r\n0\r\n\r\n")); - (BufKind::Chunked(buf), !self.is_last) + dst.buffer(buf); + !self.is_last }, Kind::Length(remaining) => { use std::cmp::Ordering; @@ -148,25 +150,56 @@ impl Encoder { trace!("sized write, len = {}", len); match (len as u64).cmp(&remaining) { Ordering::Equal => { - (BufKind::Exact(msg), !self.is_last) + dst.buffer(msg); + !self.is_last }, Ordering::Greater => { - (BufKind::Limited(msg.take(remaining as usize)), !self.is_last) + dst.buffer(msg.take(remaining as usize)); + !self.is_last }, Ordering::Less => { - (BufKind::Exact(msg), false) + dst.buffer(msg); + false } } }, Kind::CloseDelimited => { trace!("close delimited write {}B", len); - (BufKind::Exact(msg), false) + dst.buffer(msg); + false } - }; + } + } - (EncodedBuf { - kind, - }, eof) + /// Encodes the full body, without verifying the remaining length matches. + /// + /// This is used in conjunction with Payload::__hyper_full_data(), which + /// means we can trust that the buf has the correct size (the buf itself + /// was checked to make the headers). + pub(super) fn danger_full_buf(self, msg: B, dst: &mut WriteBuf>) + where + B: IntoBuf, + { + let msg = msg.into_buf(); + debug_assert!(msg.remaining() > 0, "encode() called with empty buf"); + debug_assert!(match self.kind { + Kind::Length(len) => len == msg.remaining() as u64, + _ => true, + }, "danger_full_buf length mismatches"); + + match self.kind { + Kind::Chunked => { + let len = msg.remaining(); + trace!("encoding chunked {}B", len); + let buf = ChunkSize::new(len) + .chain(msg) + .chain(StaticBuf(b"\r\n0\r\n\r\n")); + dst.buffer(buf); + }, + _ => { + dst.buffer(msg); + }, + } } } @@ -283,6 +316,30 @@ impl fmt::Write for ChunkSize { } } +impl From for EncodedBuf { + fn from(buf: B) -> Self { + EncodedBuf { + kind: BufKind::Exact(buf), + } + } +} + +impl From> for EncodedBuf { + fn from(buf: Take) -> Self { + EncodedBuf { + kind: BufKind::Limited(buf), + } + } +} + +impl From, StaticBuf>> for EncodedBuf { + fn from(buf: Chain, StaticBuf>) -> Self { + EncodedBuf { + kind: BufKind::Chunked(buf), + } + } +} + #[cfg(test)] mod tests { use bytes::{BufMut}; diff --git a/src/proto/h1/io.rs b/src/proto/h1/io.rs index 0e280b00..993c254b 100644 --- a/src/proto/h1/io.rs +++ b/src/proto/h1/io.rs @@ -99,7 +99,11 @@ where &mut buf.bytes } - pub fn buffer(&mut self, buf: B) { + pub(super) fn write_buf(&mut self) -> &mut WriteBuf { + &mut self.write_buf + } + + pub fn buffer>(&mut self, buf: BB) { self.write_buf.buffer(buf) } @@ -300,7 +304,7 @@ impl> Buf for Cursor { } // an internal buffer to collect writes before flushes -struct WriteBuf { +pub(super) struct WriteBuf { /// Re-usable buffer that holds message headers headers: Cursor>, max_buf_size: usize, @@ -334,7 +338,7 @@ where WriteBufAuto::new(self) } - fn buffer(&mut self, buf: B) { + pub(super) fn buffer>(&mut self, buf: BB) { debug_assert!(buf.has_remaining()); match self.strategy { Strategy::Flatten => { @@ -342,7 +346,7 @@ where head.bytes.put(buf); }, Strategy::Auto | Strategy::Queue => { - self.queue.bufs.push_back(buf); + self.queue.bufs.push_back(buf.into()); }, } } diff --git a/src/proto/h1/role.rs b/src/proto/h1/role.rs index a3093d50..86e9bc95 100644 --- a/src/proto/h1/role.rs +++ b/src/proto/h1/role.rs @@ -431,6 +431,11 @@ where }; } + if !Server::can_have_body(msg.req_method, msg.head.subject) { + trace!("body not allowed for {:?} {:?}", msg.req_method, msg.head.subject); + encoder = Encoder::length(0); + } + // cached date is much faster than formatting every request if !wrote_date { dst.reserve(date::DATE_VALUE_LENGTH + 8); @@ -479,41 +484,9 @@ where } impl Server<()> { - /* - fn set_length(head: &mut MessageHead, body: Option, method: Option<&Method>) -> Encoder { - // these are here thanks to borrowck - // `if method == Some(&Method::Get)` says the RHS doesn't live long enough - const HEAD: Option<&'static Method> = Some(&Method::HEAD); - const CONNECT: Option<&'static Method> = Some(&Method::CONNECT); - - let can_have_body = { - if method == HEAD { - false - } else if method == CONNECT && head.subject.is_success() { - false - } else { - match head.subject { - // TODO: support for 1xx codes needs improvement everywhere - // would be 100...199 => false - StatusCode::SWITCHING_PROTOCOLS | - StatusCode::NO_CONTENT | - StatusCode::NOT_MODIFIED => false, - _ => true, - } - } - }; - - if let (Some(body), true) = (body, can_have_body) { - set_length(&mut head.headers, body, head.version == Version::HTTP_11) - } else { - head.headers.remove(header::TRANSFER_ENCODING); - if can_have_body { - headers::content_length_zero(&mut head.headers); - } - Encoder::length(0) - } + fn can_have_body(method: &Option, status: StatusCode) -> bool { + Server::can_chunked(method, status) } - */ fn can_chunked(method: &Option, status: StatusCode) -> bool { if method == &Some(Method::HEAD) { diff --git a/tests/server.rs b/tests/server.rs index 878f23db..09c2e9f5 100644 --- a/tests/server.rs +++ b/tests/server.rs @@ -480,6 +480,35 @@ fn head_response_can_send_content_length() { assert_eq!(lines.next(), None); } +#[test] +fn head_response_doesnt_send_body() { + extern crate pretty_env_logger; + let _ = pretty_env_logger::try_init(); + let foo_bar = b"foo bar baz"; + let server = serve(); + server.reply() + .body(foo_bar); + let mut req = connect(server.addr()); + req.write_all(b"\ + HEAD / HTTP/1.1\r\n\ + Host: example.domain\r\n\ + Connection: close\r\n\ + \r\n\ + ").unwrap(); + + let mut response = String::new(); + req.read_to_string(&mut response).unwrap(); + + assert!(response.contains("content-length: 11\r\n")); + + let mut lines = response.lines(); + assert_eq!(lines.next(), Some("HTTP/1.1 200 OK")); + + let mut lines = lines.skip_while(|line| !line.is_empty()); + assert_eq!(lines.next(), Some("")); + assert_eq!(lines.next(), None); +} + #[test] fn response_does_not_set_chunked_if_body_not_allowed() { extern crate pretty_env_logger;