use std::cell::Cell;
use std::cmp;
use std::fmt;
use std::io::{self, IoSlice};

use bytes::{Buf, BufMut, Bytes, BytesMut};
use tokio::io::{AsyncRead, AsyncWrite};

use super::{Http1Transaction, ParseContext, ParsedMessage};
use crate::common::buf::BufList;
use crate::common::{task, Pin, Poll, Unpin};

/// The initial buffer size allocated before trying to read from IO.
///
/// This is also the floor below which the adaptive read strategy never
/// shrinks its next read size, and the initial capacity of the reusable
/// headers buffer used for writes.
pub(crate) const INIT_BUFFER_SIZE: usize = 8192;

/// The minimum value that can be set to max buffer size.
///
/// `Buffered::set_max_buf_size` panics if given anything smaller.
pub const MINIMUM_MAX_BUFFER_SIZE: usize = INIT_BUFFER_SIZE;

/// The default maximum read buffer size (417,792 bytes, i.e. 408 KiB). If the
/// buffer gets this big and a message is still not complete, a `TooLarge`
/// error is triggered.
///
/// It is also the default limit used when deciding whether more outgoing
/// data may be buffered before a flush.
// Note: if this changes, update server::conn::Http::max_buf_size docs.
pub(crate) const DEFAULT_MAX_BUFFER_SIZE: usize = 8192 + 4096 * 100;

/// The maximum number of distinct `Buf`s to hold in a list before requiring
/// a flush. Only applies when the write strategy queues buffers (`Auto` or
/// `Queue`).
///
/// Note that a flush can happen before reaching the maximum. This simply
/// forces a flush if the queue gets this big.
const MAX_BUF_LIST_BUFFERS: usize = 16; pub struct Buffered { flush_pipeline: bool, io: T, read_blocked: bool, read_buf: BytesMut, read_buf_strategy: ReadStrategy, write_buf: WriteBuf, } impl fmt::Debug for Buffered where B: Buf, { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { f.debug_struct("Buffered") .field("read_buf", &self.read_buf) .field("write_buf", &self.write_buf) .finish() } } impl Buffered where T: AsyncRead + AsyncWrite + Unpin, B: Buf, { pub fn new(io: T) -> Buffered { Buffered { flush_pipeline: false, io: io, read_blocked: false, read_buf: BytesMut::with_capacity(0), read_buf_strategy: ReadStrategy::default(), write_buf: WriteBuf::new(), } } pub fn set_flush_pipeline(&mut self, enabled: bool) { debug_assert!(!self.write_buf.has_remaining()); self.flush_pipeline = enabled; if enabled { self.set_write_strategy_flatten(); } } pub fn set_max_buf_size(&mut self, max: usize) { assert!( max >= MINIMUM_MAX_BUFFER_SIZE, "The max_buf_size cannot be smaller than {}.", MINIMUM_MAX_BUFFER_SIZE, ); self.read_buf_strategy = ReadStrategy::with_max(max); self.write_buf.max_buf_size = max; } pub fn set_read_buf_exact_size(&mut self, sz: usize) { self.read_buf_strategy = ReadStrategy::Exact(sz); } pub fn set_write_strategy_flatten(&mut self) { // this should always be called only at construction time, // so this assert is here to catch myself debug_assert!(self.write_buf.queue.bufs_cnt() == 0); self.write_buf.set_strategy(WriteStrategy::Flatten); } pub fn read_buf(&self) -> &[u8] { self.read_buf.as_ref() } #[cfg(test)] #[cfg(feature = "nightly")] pub(super) fn read_buf_mut(&mut self) -> &mut BytesMut { &mut self.read_buf } /// Return the "allocated" available space, not the potential space /// that could be allocated in the future. 
fn read_buf_remaining_mut(&self) -> usize { self.read_buf.capacity() - self.read_buf.len() } pub fn headers_buf(&mut self) -> &mut Vec { let buf = self.write_buf.headers_mut(); &mut buf.bytes } pub(super) fn write_buf(&mut self) -> &mut WriteBuf { &mut self.write_buf } pub fn buffer>(&mut self, buf: BB) { self.write_buf.buffer(buf) } pub fn can_buffer(&self) -> bool { self.flush_pipeline || self.write_buf.can_buffer() } pub fn consume_leading_lines(&mut self) { if !self.read_buf.is_empty() { let mut i = 0; while i < self.read_buf.len() { match self.read_buf[i] { b'\r' | b'\n' => i += 1, _ => break, } } self.read_buf.split_to(i); } } pub(super) fn parse( &mut self, cx: &mut task::Context<'_>, parse_ctx: ParseContext<'_>, ) -> Poll>> where S: Http1Transaction, { loop { match S::parse( &mut self.read_buf, ParseContext { cached_headers: parse_ctx.cached_headers, req_method: parse_ctx.req_method, }, )? { Some(msg) => { debug!("parsed {} headers", msg.head.headers.len()); return Poll::Ready(Ok(msg)); } None => { let max = self.read_buf_strategy.max(); if self.read_buf.len() >= max { debug!("max_buf_size ({}) reached, closing", max); return Poll::Ready(Err(crate::Error::new_too_large())); } } } match ready!(self.poll_read_from_io(cx)).map_err(crate::Error::new_io)? 
{ 0 => { trace!("parse eof"); return Poll::Ready(Err(crate::Error::new_incomplete())); } _ => {} } } } pub fn poll_read_from_io(&mut self, cx: &mut task::Context<'_>) -> Poll> { self.read_blocked = false; let next = self.read_buf_strategy.next(); if self.read_buf_remaining_mut() < next { self.read_buf.reserve(next); } match Pin::new(&mut self.io).poll_read_buf(cx, &mut self.read_buf) { Poll::Ready(Ok(n)) => { debug!("read {} bytes", n); self.read_buf_strategy.record(n); Poll::Ready(Ok(n)) } Poll::Pending => { self.read_blocked = true; Poll::Pending } Poll::Ready(Err(e)) => Poll::Ready(Err(e)), } } pub fn into_inner(self) -> (T, Bytes) { (self.io, self.read_buf.freeze()) } pub fn io_mut(&mut self) -> &mut T { &mut self.io } pub fn is_read_blocked(&self) -> bool { self.read_blocked } pub fn poll_flush(&mut self, cx: &mut task::Context<'_>) -> Poll> { if self.flush_pipeline && !self.read_buf.is_empty() { Poll::Ready(Ok(())) } else if self.write_buf.remaining() == 0 { Pin::new(&mut self.io).poll_flush(cx) } else { match self.write_buf.strategy { WriteStrategy::Flatten => return self.poll_flush_flattened(cx), _ => (), } loop { let n = ready!(Pin::new(&mut self.io).poll_write_buf(cx, &mut self.write_buf.auto()))?; debug!("flushed {} bytes", n); if self.write_buf.remaining() == 0 { break; } else if n == 0 { trace!( "write returned zero, but {} bytes remaining", self.write_buf.remaining() ); return Poll::Ready(Err(io::ErrorKind::WriteZero.into())); } } Pin::new(&mut self.io).poll_flush(cx) } } /// Specialized version of `flush` when strategy is Flatten. /// /// Since all buffered bytes are flattened into the single headers buffer, /// that skips some bookkeeping around using multiple buffers. 
fn poll_flush_flattened(&mut self, cx: &mut task::Context<'_>) -> Poll> { loop { let n = ready!(Pin::new(&mut self.io).poll_write(cx, self.write_buf.headers.bytes()))?; debug!("flushed {} bytes", n); self.write_buf.headers.advance(n); if self.write_buf.headers.remaining() == 0 { self.write_buf.headers.reset(); break; } else if n == 0 { trace!( "write returned zero, but {} bytes remaining", self.write_buf.remaining() ); return Poll::Ready(Err(io::ErrorKind::WriteZero.into())); } } Pin::new(&mut self.io).poll_flush(cx) } #[cfg(test)] fn flush<'a>(&'a mut self) -> impl std::future::Future> + 'a { futures_util::future::poll_fn(move |cx| self.poll_flush(cx)) } } // The `B` is a `Buf`, we never project a pin to it impl Unpin for Buffered {} // TODO: This trait is old... at least rename to PollBytes or something... pub trait MemRead { fn read_mem(&mut self, cx: &mut task::Context<'_>, len: usize) -> Poll>; } impl MemRead for Buffered where T: AsyncRead + AsyncWrite + Unpin, B: Buf, { fn read_mem(&mut self, cx: &mut task::Context<'_>, len: usize) -> Poll> { if !self.read_buf.is_empty() { let n = ::std::cmp::min(len, self.read_buf.len()); Poll::Ready(Ok(self.read_buf.split_to(n).freeze())) } else { let n = ready!(self.poll_read_from_io(cx))?; Poll::Ready(Ok(self.read_buf.split_to(::std::cmp::min(len, n)).freeze())) } } } #[derive(Clone, Copy, Debug)] enum ReadStrategy { Adaptive { decrease_now: bool, next: usize, max: usize, }, Exact(usize), } impl ReadStrategy { fn with_max(max: usize) -> ReadStrategy { ReadStrategy::Adaptive { decrease_now: false, next: INIT_BUFFER_SIZE, max, } } fn next(&self) -> usize { match *self { ReadStrategy::Adaptive { next, .. } => next, ReadStrategy::Exact(exact) => exact, } } fn max(&self) -> usize { match *self { ReadStrategy::Adaptive { max, .. } => max, ReadStrategy::Exact(exact) => exact, } } fn record(&mut self, bytes_read: usize) { match *self { ReadStrategy::Adaptive { ref mut decrease_now, ref mut next, max, .. 
} => { if bytes_read >= *next { *next = cmp::min(incr_power_of_two(*next), max); *decrease_now = false; } else { let decr_to = prev_power_of_two(*next); if bytes_read < decr_to { if *decrease_now { *next = cmp::max(decr_to, INIT_BUFFER_SIZE); *decrease_now = false; } else { // Decreasing is a two "record" process. *decrease_now = true; } } else { // A read within the current range should cancel // a potential decrease, since we just saw proof // that we still need this size. *decrease_now = false; } } } _ => (), } } } fn incr_power_of_two(n: usize) -> usize { n.saturating_mul(2) } fn prev_power_of_two(n: usize) -> usize { // Only way this shift can underflow is if n is less than 4. // (Which would means `usize::MAX >> 64` and underflowed!) debug_assert!(n >= 4); (::std::usize::MAX >> (n.leading_zeros() + 2)) + 1 } impl Default for ReadStrategy { fn default() -> ReadStrategy { ReadStrategy::with_max(DEFAULT_MAX_BUFFER_SIZE) } } #[derive(Clone)] pub struct Cursor { bytes: T, pos: usize, } impl> Cursor { #[inline] pub(crate) fn new(bytes: T) -> Cursor { Cursor { bytes: bytes, pos: 0, } } } impl Cursor> { fn reset(&mut self) { self.pos = 0; self.bytes.clear(); } } impl> fmt::Debug for Cursor { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { f.debug_struct("Cursor") .field("pos", &self.pos) .field("len", &self.bytes.as_ref().len()) .finish() } } impl> Buf for Cursor { #[inline] fn remaining(&self) -> usize { self.bytes.as_ref().len() - self.pos } #[inline] fn bytes(&self) -> &[u8] { &self.bytes.as_ref()[self.pos..] 
} #[inline] fn advance(&mut self, cnt: usize) { debug_assert!(self.pos + cnt <= self.bytes.as_ref().len()); self.pos += cnt; } } // an internal buffer to collect writes before flushes pub(super) struct WriteBuf { /// Re-usable buffer that holds message headers headers: Cursor>, max_buf_size: usize, /// Deque of user buffers if strategy is Queue queue: BufList, strategy: WriteStrategy, } impl WriteBuf { fn new() -> WriteBuf { WriteBuf { headers: Cursor::new(Vec::with_capacity(INIT_BUFFER_SIZE)), max_buf_size: DEFAULT_MAX_BUFFER_SIZE, queue: BufList::new(), strategy: WriteStrategy::Auto, } } } impl WriteBuf where B: Buf, { fn set_strategy(&mut self, strategy: WriteStrategy) { self.strategy = strategy; } #[inline] fn auto(&mut self) -> WriteBufAuto<'_, B> { WriteBufAuto::new(self) } pub(super) fn buffer>(&mut self, mut buf: BB) { debug_assert!(buf.has_remaining()); match self.strategy { WriteStrategy::Flatten => { let head = self.headers_mut(); //perf: This is a little faster than >::put, //but accomplishes the same result. 
loop { let adv = { let slice = buf.bytes(); if slice.is_empty() { return; } head.bytes.extend_from_slice(slice); slice.len() }; buf.advance(adv); } } WriteStrategy::Auto | WriteStrategy::Queue => { self.queue.push(buf.into()); } } } fn can_buffer(&self) -> bool { match self.strategy { WriteStrategy::Flatten => self.remaining() < self.max_buf_size, WriteStrategy::Auto | WriteStrategy::Queue => { self.queue.bufs_cnt() < MAX_BUF_LIST_BUFFERS && self.remaining() < self.max_buf_size } } } fn headers_mut(&mut self) -> &mut Cursor> { debug_assert!(!self.queue.has_remaining()); &mut self.headers } } impl fmt::Debug for WriteBuf { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { f.debug_struct("WriteBuf") .field("remaining", &self.remaining()) .field("strategy", &self.strategy) .finish() } } impl Buf for WriteBuf { #[inline] fn remaining(&self) -> usize { self.headers.remaining() + self.queue.remaining() } #[inline] fn bytes(&self) -> &[u8] { let headers = self.headers.bytes(); if !headers.is_empty() { headers } else { self.queue.bytes() } } #[inline] fn advance(&mut self, cnt: usize) { let hrem = self.headers.remaining(); if hrem == cnt { self.headers.reset(); } else if hrem > cnt { self.headers.advance(cnt); } else { let qcnt = cnt - hrem; self.headers.reset(); self.queue.advance(qcnt); } } #[inline] fn bytes_vectored<'t>(&'t self, dst: &mut [IoSlice<'t>]) -> usize { let n = self.headers.bytes_vectored(dst); self.queue.bytes_vectored(&mut dst[n..]) + n } } /// Detects when wrapped `WriteBuf` is used for vectored IO, and /// adjusts the `WriteBuf` strategy if not. 
struct WriteBufAuto<'a, B: Buf> { bytes_called: Cell, bytes_vec_called: Cell, inner: &'a mut WriteBuf, } impl<'a, B: Buf> WriteBufAuto<'a, B> { fn new(inner: &'a mut WriteBuf) -> WriteBufAuto<'a, B> { WriteBufAuto { bytes_called: Cell::new(false), bytes_vec_called: Cell::new(false), inner: inner, } } } impl<'a, B: Buf> Buf for WriteBufAuto<'a, B> { #[inline] fn remaining(&self) -> usize { self.inner.remaining() } #[inline] fn bytes(&self) -> &[u8] { self.bytes_called.set(true); self.inner.bytes() } #[inline] fn advance(&mut self, cnt: usize) { self.inner.advance(cnt) } #[inline] fn bytes_vectored<'t>(&'t self, dst: &mut [IoSlice<'t>]) -> usize { self.bytes_vec_called.set(true); self.inner.bytes_vectored(dst) } } impl<'a, B: Buf + 'a> Drop for WriteBufAuto<'a, B> { fn drop(&mut self) { if let WriteStrategy::Auto = self.inner.strategy { if self.bytes_vec_called.get() { self.inner.strategy = WriteStrategy::Queue; } else if self.bytes_called.get() { trace!("detected no usage of vectored write, flattening"); self.inner.strategy = WriteStrategy::Flatten; self.inner.headers.bytes.put(&mut self.inner.queue); } } } } #[derive(Debug)] enum WriteStrategy { Auto, Flatten, Queue, } #[cfg(test)] mod tests { use super::*; use std::time::Duration; use tokio_test::io::Builder as Mock; #[cfg(feature = "nightly")] use test::Bencher; /* impl MemRead for AsyncIo { fn read_mem(&mut self, len: usize) -> Poll { let mut v = vec![0; len]; let n = try_nb!(self.read(v.as_mut_slice())); Ok(Async::Ready(BytesMut::from(&v[..n]).freeze())) } } */ #[tokio::test] async fn iobuf_write_empty_slice() { // First, let's just check that the Mock would normally return an // error on an unexpected write, even if the buffer is empty... 
let mut mock = Mock::new().build(); futures_util::future::poll_fn(|cx| { Pin::new(&mut mock).poll_write_buf(cx, &mut Cursor::new(&[])) }) .await .expect_err("should be a broken pipe"); // underlying io will return the logic error upon write, // so we are testing that the io_buf does not trigger a write // when there is nothing to flush let mock = Mock::new().build(); let mut io_buf = Buffered::<_, Cursor>>::new(mock); io_buf.flush().await.expect("should short-circuit flush"); } #[tokio::test] async fn parse_reads_until_blocked() { use crate::proto::h1::ClientTransaction; let mock = Mock::new() // Split over multiple reads will read all of it .read(b"HTTP/1.1 200 OK\r\n") .read(b"Server: hyper\r\n") // missing last line ending .wait(Duration::from_secs(1)) .build(); let mut buffered = Buffered::<_, Cursor>>::new(mock); // We expect a `parse` to be not ready, and so can't await it directly. // Rather, this `poll_fn` will wrap the `Poll` result. futures_util::future::poll_fn(|cx| { let parse_ctx = ParseContext { cached_headers: &mut None, req_method: &mut None, }; assert!(buffered .parse::(cx, parse_ctx) .is_pending()); Poll::Ready(()) }) .await; assert_eq!( buffered.read_buf, b"HTTP/1.1 200 OK\r\nServer: hyper\r\n"[..] 
); } #[test] fn read_strategy_adaptive_increments() { let mut strategy = ReadStrategy::default(); assert_eq!(strategy.next(), 8192); // Grows if record == next strategy.record(8192); assert_eq!(strategy.next(), 16384); strategy.record(16384); assert_eq!(strategy.next(), 32768); // Enormous records still increment at same rate strategy.record(::std::usize::MAX); assert_eq!(strategy.next(), 65536); let max = strategy.max(); while strategy.next() < max { strategy.record(max); } assert_eq!(strategy.next(), max, "never goes over max"); strategy.record(max + 1); assert_eq!(strategy.next(), max, "never goes over max"); } #[test] fn read_strategy_adaptive_decrements() { let mut strategy = ReadStrategy::default(); strategy.record(8192); assert_eq!(strategy.next(), 16384); strategy.record(1); assert_eq!( strategy.next(), 16384, "first smaller record doesn't decrement yet" ); strategy.record(8192); assert_eq!(strategy.next(), 16384, "record was with range"); strategy.record(1); assert_eq!( strategy.next(), 16384, "in-range record should make this the 'first' again" ); strategy.record(1); assert_eq!(strategy.next(), 8192, "second smaller record decrements"); strategy.record(1); assert_eq!(strategy.next(), 8192, "first doesn't decrement"); strategy.record(1); assert_eq!(strategy.next(), 8192, "doesn't decrement under minimum"); } #[test] fn read_strategy_adaptive_stays_the_same() { let mut strategy = ReadStrategy::default(); strategy.record(8192); assert_eq!(strategy.next(), 16384); strategy.record(8193); assert_eq!( strategy.next(), 16384, "first smaller record doesn't decrement yet" ); strategy.record(8193); assert_eq!( strategy.next(), 16384, "with current step does not decrement" ); } #[test] fn read_strategy_adaptive_max_fuzz() { fn fuzz(max: usize) { let mut strategy = ReadStrategy::with_max(max); while strategy.next() < max { strategy.record(::std::usize::MAX); } let mut next = strategy.next(); while next > 8192 { strategy.record(1); strategy.record(1); next = 
strategy.next(); assert!( next.is_power_of_two(), "decrement should be powers of two: {} (max = {})", next, max, ); } } let mut max = 8192; while max < ::std::usize::MAX { fuzz(max); max = (max / 2).saturating_mul(3); } fuzz(::std::usize::MAX); } #[test] #[should_panic] #[cfg(debug_assertions)] // needs to trigger a debug_assert fn write_buf_requires_non_empty_bufs() { let mock = Mock::new().build(); let mut buffered = Buffered::<_, Cursor>>::new(mock); buffered.buffer(Cursor::new(Vec::new())); } /* TODO: needs tokio_test::io to allow configure write_buf calls #[test] fn write_buf_queue() { let _ = pretty_env_logger::try_init(); let mock = AsyncIo::new_buf(vec![], 1024); let mut buffered = Buffered::<_, Cursor>>::new(mock); buffered.headers_buf().extend(b"hello "); buffered.buffer(Cursor::new(b"world, ".to_vec())); buffered.buffer(Cursor::new(b"it's ".to_vec())); buffered.buffer(Cursor::new(b"hyper!".to_vec())); assert_eq!(buffered.write_buf.queue.bufs_cnt(), 3); buffered.flush().unwrap(); assert_eq!(buffered.io, b"hello world, it's hyper!"); assert_eq!(buffered.io.num_writes(), 1); assert_eq!(buffered.write_buf.queue.bufs_cnt(), 0); } */ #[tokio::test] async fn write_buf_flatten() { let _ = pretty_env_logger::try_init(); let mock = Mock::new() // Just a single write .write(b"hello world, it's hyper!") .build(); let mut buffered = Buffered::<_, Cursor>>::new(mock); buffered.write_buf.set_strategy(WriteStrategy::Flatten); buffered.headers_buf().extend(b"hello "); buffered.buffer(Cursor::new(b"world, ".to_vec())); buffered.buffer(Cursor::new(b"it's ".to_vec())); buffered.buffer(Cursor::new(b"hyper!".to_vec())); assert_eq!(buffered.write_buf.queue.bufs_cnt(), 0); buffered.flush().await.expect("flush"); } #[tokio::test] async fn write_buf_auto_flatten() { let _ = pretty_env_logger::try_init(); let mock = Mock::new() // Expects write_buf to only consume first buffer .write(b"hello ") // And then the Auto strategy will have flattened .write(b"world, it's hyper!") 
.build(); let mut buffered = Buffered::<_, Cursor>>::new(mock); // we have 4 buffers, but hope to detect that vectored IO isn't // being used, and switch to flattening automatically, // resulting in only 2 writes buffered.headers_buf().extend(b"hello "); buffered.buffer(Cursor::new(b"world, ".to_vec())); buffered.buffer(Cursor::new(b"it's ".to_vec())); buffered.buffer(Cursor::new(b"hyper!".to_vec())); assert_eq!(buffered.write_buf.queue.bufs_cnt(), 3); buffered.flush().await.expect("flush"); assert_eq!(buffered.write_buf.queue.bufs_cnt(), 0); } #[tokio::test] async fn write_buf_queue_disable_auto() { let _ = pretty_env_logger::try_init(); let mock = Mock::new() .write(b"hello ") .write(b"world, ") .write(b"it's ") .write(b"hyper!") .build(); let mut buffered = Buffered::<_, Cursor>>::new(mock); buffered.write_buf.set_strategy(WriteStrategy::Queue); // we have 4 buffers, and vec IO disabled, but explicitly said // don't try to auto detect (via setting strategy above) buffered.headers_buf().extend(b"hello "); buffered.buffer(Cursor::new(b"world, ".to_vec())); buffered.buffer(Cursor::new(b"it's ".to_vec())); buffered.buffer(Cursor::new(b"hyper!".to_vec())); assert_eq!(buffered.write_buf.queue.bufs_cnt(), 3); buffered.flush().await.expect("flush"); assert_eq!(buffered.write_buf.queue.bufs_cnt(), 0); } #[cfg(feature = "nightly")] #[bench] fn bench_write_buf_flatten_buffer_chunk(b: &mut Bencher) { let s = "Hello, World!"; b.bytes = s.len() as u64; let mut write_buf = WriteBuf::::new(); write_buf.set_strategy(WriteStrategy::Flatten); b.iter(|| { let chunk = bytes::Bytes::from(s); write_buf.buffer(chunk); ::test::black_box(&write_buf); write_buf.headers.bytes.clear(); }) } }