feat(server): add Http::max_buf_size() option

The internal connection's read and write bufs will be restricted from
growing bigger than the configured `max_buf_size`.

Closes #1368
This commit is contained in:
Sean McArthur
2018-01-23 16:09:17 -08:00
parent 7cb72d2019
commit d22deb6572
5 changed files with 84 additions and 17 deletions

View File

@@ -58,6 +58,10 @@ where I: AsyncRead + AsyncWrite,
self.io.set_flush_pipeline(enabled);
}
pub fn set_max_buf_size(&mut self, max: usize) {
self.io.set_max_buf_size(max);
}
#[cfg(feature = "tokio-proto")]
fn poll_incoming(&mut self) -> Poll<Option<Frame<super::MessageHead<T::Incoming>, super::Chunk, ::Error>>, io::Error> {
trace!("Conn::poll_incoming()");
@@ -1221,7 +1225,7 @@ mod tests {
let _: Result<(), ()> = future::lazy(|| {
let io = AsyncIo::new_buf(vec![], 0);
let mut conn = Conn::<_, proto::Chunk, ServerTransaction>::new(io, Default::default());
let max = ::proto::io::MAX_BUFFER_SIZE + 4096;
let max = ::proto::io::DEFAULT_MAX_BUFFER_SIZE + 4096;
conn.state.writing = Writing::Body(Encoder::length((max * 2) as u64), None);
assert!(conn.start_send(Frame::Body { chunk: Some(vec![b'a'; 1024 * 8].into()) }).unwrap().is_ready());

View File

@@ -10,11 +10,12 @@ use super::{Http1Transaction, MessageHead};
use bytes::{BytesMut, Bytes};
const INIT_BUFFER_SIZE: usize = 8192;
pub const MAX_BUFFER_SIZE: usize = 8192 + 4096 * 100;
pub const DEFAULT_MAX_BUFFER_SIZE: usize = 8192 + 4096 * 100;
pub struct Buffered<T> {
flush_pipeline: bool,
io: T,
max_buf_size: usize,
read_blocked: bool,
read_buf: BytesMut,
write_buf: WriteBuf,
@@ -34,6 +35,7 @@ impl<T: AsyncRead + AsyncWrite> Buffered<T> {
Buffered {
flush_pipeline: false,
io: io,
max_buf_size: DEFAULT_MAX_BUFFER_SIZE,
read_buf: BytesMut::with_capacity(0),
write_buf: WriteBuf::new(),
read_blocked: false,
@@ -44,6 +46,11 @@ impl<T: AsyncRead + AsyncWrite> Buffered<T> {
self.flush_pipeline = enabled;
}
pub fn set_max_buf_size(&mut self, max: usize) {
self.max_buf_size = max;
self.write_buf.max_buf_size = max;
}
pub fn read_buf(&self) -> &[u8] {
self.read_buf.as_ref()
}
@@ -51,7 +58,7 @@ impl<T: AsyncRead + AsyncWrite> Buffered<T> {
pub fn write_buf_mut(&mut self) -> &mut Vec<u8> {
self.write_buf.maybe_reset();
self.write_buf.maybe_reserve(0);
&mut self.write_buf.0.bytes
&mut self.write_buf.buf.bytes
}
pub fn consume_leading_lines(&mut self) {
@@ -75,8 +82,8 @@ impl<T: AsyncRead + AsyncWrite> Buffered<T> {
return Ok(Async::Ready(head))
},
None => {
if self.read_buf.capacity() >= MAX_BUFFER_SIZE {
debug!("MAX_BUFFER_SIZE reached, closing");
if self.read_buf.capacity() >= self.max_buf_size {
debug!("max_buf_size ({}) reached, closing", self.max_buf_size);
return Err(::Error::TooLarge);
}
},
@@ -259,22 +266,28 @@ impl<T: Write> AtomicWrite for T {
// an internal buffer to collect writes before flushes
#[derive(Debug)]
struct WriteBuf(Cursor<Vec<u8>>);
struct WriteBuf{
buf: Cursor<Vec<u8>>,
max_buf_size: usize,
}
impl WriteBuf {
fn new() -> WriteBuf {
WriteBuf(Cursor::new(Vec::new()))
WriteBuf {
buf: Cursor::new(Vec::new()),
max_buf_size: DEFAULT_MAX_BUFFER_SIZE,
}
}
fn write_into<W: Write>(&mut self, w: &mut W) -> io::Result<usize> {
self.0.write_to(w)
self.buf.write_to(w)
}
fn buffer(&mut self, data: &[u8]) -> usize {
trace!("WriteBuf::buffer() len = {:?}", data.len());
self.maybe_reset();
self.maybe_reserve(data.len());
let vec = &mut self.0.bytes;
let vec = &mut self.buf.bytes;
let len = cmp::min(vec.capacity() - vec.len(), data.len());
assert!(vec.capacity() - vec.len() >= len);
unsafe {
@@ -291,28 +304,28 @@ impl WriteBuf {
}
fn remaining(&self) -> usize {
self.0.remaining()
self.buf.remaining()
}
#[inline]
fn maybe_reserve(&mut self, needed: usize) {
let vec = &mut self.0.bytes;
let vec = &mut self.buf.bytes;
let cap = vec.capacity();
if cap == 0 {
let init = cmp::min(MAX_BUFFER_SIZE, cmp::max(INIT_BUFFER_SIZE, needed));
let init = cmp::min(self.max_buf_size, cmp::max(INIT_BUFFER_SIZE, needed));
trace!("WriteBuf reserving initial {}", init);
vec.reserve(init);
} else if cap < MAX_BUFFER_SIZE {
vec.reserve(cmp::min(needed, MAX_BUFFER_SIZE - cap));
} else if cap < self.max_buf_size {
vec.reserve(cmp::min(needed, self.max_buf_size - cap));
trace!("WriteBuf reserved {}", vec.capacity() - cap);
}
}
fn maybe_reset(&mut self) {
if self.0.pos != 0 && self.0.remaining() == 0 {
self.0.pos = 0;
if self.buf.pos != 0 && self.buf.remaining() == 0 {
self.buf.pos = 0;
unsafe {
self.0.bytes.set_len(0);
self.buf.bytes.set_len(0);
}
}
}

View File

@@ -53,6 +53,7 @@ pub use self::service::{const_service, service_fn};
/// which handle a connection to an HTTP server. Each instance of `Http` can be
/// configured with various protocol-level options such as keepalive.
pub struct Http<B = ::Chunk> {
max_buf_size: Option<usize>,
keep_alive: bool,
pipeline: bool,
_marker: PhantomData<B>,
@@ -129,6 +130,7 @@ impl<B: AsRef<[u8]> + 'static> Http<B> {
pub fn new() -> Http<B> {
Http {
keep_alive: true,
max_buf_size: None,
pipeline: false,
_marker: PhantomData,
}
@@ -142,6 +144,12 @@ impl<B: AsRef<[u8]> + 'static> Http<B> {
self
}
/// Set the maximum buffer size for the connection.
pub fn max_buf_size(&mut self, max: usize) -> &mut Self {
self.max_buf_size = Some(max);
self
}
/// Aggregates flushes to better support pipelined responses.
///
/// Experimental, may be have bugs.
@@ -226,6 +234,7 @@ impl<B: AsRef<[u8]> + 'static> Http<B> {
new_service: new_service,
protocol: Http {
keep_alive: self.keep_alive,
max_buf_size: self.max_buf_size,
pipeline: self.pipeline,
_marker: PhantomData,
},
@@ -250,6 +259,9 @@ impl<B: AsRef<[u8]> + 'static> Http<B> {
};
let mut conn = proto::Conn::new(io, ka);
conn.set_flush_pipeline(self.pipeline);
if let Some(max) = self.max_buf_size {
conn.set_max_buf_size(max);
}
Connection {
conn: proto::dispatch::Dispatcher::new(proto::dispatch::Server::new(service), conn),
}

View File

@@ -113,6 +113,9 @@ impl<T, B> ServerProto<T> for Http<B>
};
let mut conn = proto::Conn::new(io, ka);
conn.set_flush_pipeline(self.pipeline);
if let Some(max) = self.max_buf_size {
conn.set_max_buf_size(max);
}
__ProtoBindTransport {
inner: future::ok(conn),
}

View File

@@ -958,6 +958,41 @@ fn illegal_request_length_returns_400_response() {
core.run(fut).unwrap_err();
}
#[test]
fn max_buf_size() {
let _ = pretty_env_logger::try_init();
let mut core = Core::new().unwrap();
let listener = TcpListener::bind(&"127.0.0.1:0".parse().unwrap(), &core.handle()).unwrap();
let addr = listener.local_addr().unwrap();
const MAX: usize = 16_000;
thread::spawn(move || {
let mut tcp = connect(&addr);
tcp.write_all(b"POST /").expect("write 1");
tcp.write_all(&vec![b'a'; MAX]).expect("write 2");
tcp.write_all(b" HTTP/1.1\r\n\r\n").expect("write 3");
let mut buf = [0; 256];
tcp.read(&mut buf).expect("read 1");
let expected = "HTTP/1.1 400 ";
assert_eq!(s(&buf[..expected.len()]), expected);
});
let fut = listener.incoming()
.into_future()
.map_err(|_| unreachable!())
.and_then(|(item, _incoming)| {
let (socket, _) = item.unwrap();
Http::<hyper::Chunk>::new()
.max_buf_size(MAX)
.serve_connection(socket, HelloWorld)
.map(|_| ())
});
core.run(fut).unwrap_err();
}
#[test]
fn remote_addr() {
let server = serve();