From c32015d48e666573f0ef0862ba941b2f19497fc8 Mon Sep 17 00:00:00 2001 From: Sean McArthur Date: Wed, 13 Sep 2017 14:10:27 -0700 Subject: [PATCH] add support for configuring max frame size - Adds `max_frame_size` to client and server builders - Pushes max_frame_size into Codec - Detects when the Codec triggers an error from a frame too big - Sends a GOAWAY when FRAME_SIZE_ERROR is encountered reading a frame --- src/client.rs | 10 +++ src/codec/framed_read.rs | 25 +++++++- src/codec/framed_write.rs | 2 +- src/codec/mod.rs | 7 ++- src/frame/reason.rs | 2 +- src/frame/settings.rs | 11 +++- src/server.rs | 15 ++++- tests/codec_read.rs | 10 +-- tests/stream_states.rs | 104 ++++++++++++++++++++++++++++++++ tests/support/src/frames.rs | 4 ++ tests/support/src/future_ext.rs | 34 +++++++++++ tests/support/src/mock.rs | 9 ++- tests/support/src/raw.rs | 6 ++ 13 files changed, 221 insertions(+), 18 deletions(-) diff --git a/src/client.rs b/src/client.rs index e35269e..5b5c44e 100644 --- a/src/client.rs +++ b/src/client.rs @@ -170,6 +170,12 @@ impl Builder { self } + /// Set the max frame size of received frames. + pub fn max_frame_size(&mut self, max: u32) -> &mut Self { + self.settings.set_max_frame_size(Some(max)); + self + } + /// Bind an H2 client connection. /// /// Returns a future which resolves to the connection value once the H2 @@ -203,6 +209,10 @@ where // Create the codec let mut codec = Codec::new(io); + if let Some(max) = self.settings.max_frame_size() { + codec.set_max_recv_frame_size(max as usize); + } + // Send initial settings frame codec .buffer(self.settings.clone().into()) diff --git a/src/codec/framed_read.rs b/src/codec/framed_read.rs index c4792f9..46ac723 100644 --- a/src/codec/framed_read.rs +++ b/src/codec/framed_read.rs @@ -1,6 +1,6 @@ use codec::RecvError; -use frame::{self, Frame, Kind}; -use frame::DEFAULT_SETTINGS_HEADER_TABLE_SIZE; +use frame::{self, Frame, Kind, Reason}; +use frame::{DEFAULT_MAX_FRAME_SIZE, DEFAULT_SETTINGS_HEADER_TABLE_SIZE, MAX_MAX_FRAME_SIZE}; use frame::Reason::*; use hpack; @@ -9,6 +9,8 @@ use futures::*; use bytes::BytesMut; +use std::io; + use tokio_io::AsyncRead; use tokio_io::codec::length_delimited; @@ -228,9 +230,12 @@ impl FramedRead { } /// Updates the max frame size setting. + /// + /// Must be within 16,384 and 16,777,215. #[cfg(feature = "unstable")] #[inline] pub fn set_max_frame_size(&mut self, val: usize) { + assert!(DEFAULT_MAX_FRAME_SIZE as usize <= val && val <= MAX_MAX_FRAME_SIZE as usize); self.inner.set_max_frame_length(val) } } @@ -245,7 +250,7 @@ where fn poll(&mut self) -> Poll, Self::Error> { loop { trace!("poll"); - let bytes = match try_ready!(self.inner.poll()) { + let bytes = match try_ready!(self.inner.poll().map_err(map_err)) { Some(bytes) => bytes, None => return Ok(Async::Ready(None)), }; @@ -258,3 +263,17 @@ where } } } + +fn map_err(err: io::Error) -> RecvError { + use std::error::Error; + + if let io::ErrorKind::InvalidData = err.kind() { + // woah, brittle... + // TODO: with tokio-io v0.1.4, we can check + // err.get_ref().is::() + if err.description() == "frame size too big" { + return RecvError::Connection(Reason::FrameSizeError); + } + } + err.into() +} diff --git a/src/codec/framed_write.rs b/src/codec/framed_write.rs index c4ed185..473b1ce 100644 --- a/src/codec/framed_write.rs +++ b/src/codec/framed_write.rs @@ -227,7 +227,7 @@ impl FramedWrite { /// Set the peer's max frame size. pub fn set_max_frame_size(&mut self, val: usize) { - assert!(val <= frame::MAX_MAX_FRAME_SIZE); + assert!(val <= frame::MAX_MAX_FRAME_SIZE as usize); self.max_frame_size = val as FrameSize; } diff --git a/src/codec/mod.rs b/src/codec/mod.rs index a4cf71a..057ca94 100644 --- a/src/codec/mod.rs +++ b/src/codec/mod.rs @@ -45,10 +45,12 @@ where .length_field_length(3) .length_adjustment(9) .num_skip(0) // Don't skip the header - .max_frame_length(max_frame_size) .new_read(framed_write); - let inner = FramedRead::new(delimited); + let mut inner = FramedRead::new(delimited); + + // Use FramedRead's method since it checks the value is within range. + inner.set_max_frame_size(max_frame_size); Codec { inner, @@ -66,7 +68,6 @@ impl Codec { #[cfg(feature = "unstable")] #[inline] pub fn set_max_recv_frame_size(&mut self, val: usize) { - // TODO: should probably make some assertions about max frame size... self.inner.set_max_frame_size(val) } diff --git a/src/frame/reason.rs b/src/frame/reason.rs index 0c804c2..1488181 100644 --- a/src/frame/reason.rs +++ b/src/frame/reason.rs @@ -33,7 +33,7 @@ impl Reason { FlowControlError => "flow-control protocol violated", SettingsTimeout => "settings ACK not received in timely manner", StreamClosed => "received frame when stream half-closed", - FrameSizeError => "frame sent with invalid size", + FrameSizeError => "frame with invalid size", RefusedStream => "refused stream before processing any application logic", Cancel => "stream no longer needed", CompressionError => "unable to maintain the header compression context", diff --git a/src/frame/settings.rs b/src/frame/settings.rs index 5da69bf..a2e9747 100644 --- a/src/frame/settings.rs +++ b/src/frame/settings.rs @@ -46,7 +46,7 @@ pub const DEFAULT_MAX_FRAME_SIZE: FrameSize = 16_384; pub const MAX_INITIAL_WINDOW_SIZE: usize = (1 << 31) - 1; /// MAX_FRAME_SIZE upper bound -pub const MAX_MAX_FRAME_SIZE: usize = (1 << 24) - 1; +pub const MAX_MAX_FRAME_SIZE: FrameSize = (1 << 24) - 1; // ===== impl Settings ===== @@ -78,6 +78,13 @@ impl Settings { self.max_frame_size } + pub fn set_max_frame_size(&mut self, size: Option) { + if let Some(val) = size { + assert!(DEFAULT_MAX_FRAME_SIZE <= val && val <= MAX_MAX_FRAME_SIZE); + } + self.max_frame_size = size; + } + pub fn load(head: Head, payload: &[u8]) -> Result { use self::Setting::*; @@ -131,7 +138,7 @@ impl Settings { settings.initial_window_size = Some(val); }, Some(MaxFrameSize(val)) => { - if val < DEFAULT_MAX_FRAME_SIZE || val as usize > MAX_MAX_FRAME_SIZE { + if val < DEFAULT_MAX_FRAME_SIZE || val > MAX_MAX_FRAME_SIZE { return Err(Error::InvalidSettingValue); } else { settings.max_frame_size = Some(val); diff --git a/src/server.rs b/src/server.rs index daa36c1..2711e3a 100644 --- a/src/server.rs +++ b/src/server.rs @@ -93,6 +93,10 @@ where // Create the codec let mut codec = Codec::new(io); + if let Some(max) = settings.max_frame_size() { + codec.set_max_recv_frame_size(max as usize); + } + // Send initial settings frame codec .buffer(settings.clone().into()) @@ -180,13 +184,20 @@ impl Builder { self } + /// Set the max frame size of received frames. + pub fn max_frame_size(&mut self, max: u32) -> &mut Self { + self.settings.set_max_frame_size(Some(max)); + self + } + /// Bind an H2 server connection. /// /// Returns a future which resolves to the connection value once the H2 /// handshake has been completed. pub fn handshake(&self, io: T) -> Handshake - where T: AsyncRead + AsyncWrite + 'static, - B: IntoBuf + 'static + where + T: AsyncRead + AsyncWrite + 'static, + B: IntoBuf + 'static, { Server::handshake2(io, self.settings.clone()) } diff --git a/tests/codec_read.rs b/tests/codec_read.rs index d9f6b05..c69ee5c 100644 --- a/tests/codec_read.rs +++ b/tests/codec_read.rs @@ -110,22 +110,24 @@ fn read_headers_empty_payload() {} #[test] fn update_max_frame_len_at_rest() { + let _ = ::env_logger::init(); // TODO: add test for updating max frame length in flight as well? let mut codec = raw_codec! { read => [ 0, 0, 5, 0, 0, 0, 0, 0, 1, "hello", - "world", + 0, 64, 1, 0, 0, 0, 0, 0, 1, + vec![0; 16_385], ]; }; assert_eq!(poll_data!(codec).payload(), &b"hello"[..]); - codec.set_max_recv_frame_size(2); + codec.set_max_recv_frame_size(16_384); - assert_eq!(codec.max_recv_frame_size(), 2); + assert_eq!(codec.max_recv_frame_size(), 16_384); assert_eq!( codec.poll().unwrap_err().description(), - "frame size too big" + "frame with invalid size" ); } diff --git a/tests/stream_states.rs b/tests/stream_states.rs index ff1be13..b3bc8e5 100644 --- a/tests/stream_states.rs +++ b/tests/stream_states.rs @@ -195,6 +195,110 @@ fn closed_streams_are_released() { let _ = h2.join(srv).wait().unwrap(); } +#[test] +fn errors_if_recv_frame_exceeds_max_frame_size() { + let _ = ::env_logger::init(); + let (io, mut srv) = mock::new(); + + let h2 = Client::handshake(io).unwrap().and_then(|mut h2| { + let request = Request::builder() + .method(Method::GET) + .uri("https://example.com/") + .body(()) + .unwrap(); + + let req = h2.request(request, true) + .unwrap() + .unwrap() + .and_then(|resp| { + assert_eq!(resp.status(), StatusCode::OK); + let body = resp.into_parts().1; + body.concat2().then(|res| { + let err = res.unwrap_err(); + assert_eq!(err.to_string(), "protocol error: frame with invalid size"); + Ok::<(), ()>(()) + }) + }); + + // client should see a conn error + let conn = h2.then(|res| { + let err = res.unwrap_err(); + assert_eq!(err.to_string(), "protocol error: frame with invalid size"); + Ok::<(), ()>(()) + }); + conn.unwrap().join(req) + }); + + // a bad peer + srv.codec_mut().set_max_send_frame_size(16_384 * 4); + + let srv = srv.assert_client_handshake() + .unwrap() + .ignore_settings() + .recv_frame( + frames::headers(1) + .request("GET", "https://example.com/") + .eos(), + ) + .send_frame(frames::headers(1).response(200)) + .send_frame(frames::data(1, vec![0; 16_385]).eos()) + .recv_frame(frames::go_away(0).frame_size()) + .close(); + + let _ = h2.join(srv).wait().unwrap(); +} + + +#[test] +fn configure_max_frame_size() { + let _ = ::env_logger::init(); + let (io, mut srv) = mock::new(); + + let h2 = Client::builder() + .max_frame_size(16_384 * 2) + .handshake::<_, Bytes>(io) + .expect("handshake") + .and_then(|mut h2| { + let request = Request::builder() + .method(Method::GET) + .uri("https://example.com/") + .body(()) + .unwrap(); + + let req = h2.request(request, true) + .unwrap() + .expect("response") + .and_then(|resp| { + assert_eq!(resp.status(), StatusCode::OK); + let body = resp.into_parts().1; + body.concat2().expect("body") + }) + .and_then(|buf| { + assert_eq!(buf.len(), 16_385); + Ok(()) + }); + + h2.expect("client").join(req) + }); + + // a good peer + srv.codec_mut().set_max_send_frame_size(16_384 * 2); + + let srv = srv.assert_client_handshake() + .unwrap() + .ignore_settings() + .recv_frame( + frames::headers(1) + .request("GET", "https://example.com/") + .eos(), + ) + .send_frame(frames::headers(1).response(200)) + .send_frame(frames::data(1, vec![0; 16_385]).eos()) + .close(); + + let _ = h2.join(srv).wait().expect("wait"); +} + /* #[test] fn send_data_after_headers_eos() { diff --git a/tests/support/src/frames.rs b/tests/support/src/frames.rs index af5719a..f9c5c15 100644 --- a/tests/support/src/frames.rs +++ b/tests/support/src/frames.rs @@ -146,6 +146,10 @@ impl Mock { pub fn flow_control(self) -> Self { Mock(frame::GoAway::new(self.0.last_stream_id(), frame::Reason::FlowControlError)) } + + pub fn frame_size(self) -> Self { + Mock(frame::GoAway::new(self.0.last_stream_id(), frame::Reason::FrameSizeError)) + } } impl From> for SendFrame { diff --git a/tests/support/src/future_ext.rs b/tests/support/src/future_ext.rs index f4397c7..414eb98 100644 --- a/tests/support/src/future_ext.rs +++ b/tests/support/src/future_ext.rs @@ -12,6 +12,18 @@ pub trait FutureExt: Future { Unwrap { inner: self } } + /// Panic on error, with a message. + fn expect(self, msg: T) -> Expect + where Self: Sized, + Self::Error: fmt::Debug, + T: fmt::Display, + { + Expect { + inner: self, + msg: msg.to_string(), + } + } + /// Drive `other` by polling `self`. /// /// `self` must not resolve before `other` does. @@ -51,6 +63,28 @@ impl Future for Unwrap } } + +// ===== Expect ====== + +/// Panic on error +pub struct Expect { + inner: T, + msg: String, +} + +impl Future for Expect + where T: Future, + T::Item: fmt::Debug, + T::Error: fmt::Debug, +{ + type Item = T::Item; + type Error = (); + + fn poll(&mut self) -> Poll { + Ok(self.inner.poll().expect(&self.msg)) + } +} + // ===== Drive ====== /// Drive a future to completion while also polling the driver diff --git a/tests/support/src/mock.rs b/tests/support/src/mock.rs index 67e4731..6bbc974 100644 --- a/tests/support/src/mock.rs +++ b/tests/support/src/mock.rs @@ -26,7 +26,7 @@ pub struct Handle { } #[derive(Debug)] -struct Pipe { +pub struct Pipe { inner: Arc>, } @@ -67,6 +67,11 @@ pub fn new() -> (Mock, Handle) { // ===== impl Handle ===== impl Handle { + /// Get a mutable reference to inner Codec. + pub fn codec_mut(&mut self) -> &mut ::Codec { + &mut self.codec + } + /// Send a frame pub fn send(&mut self, item: SendFrame) -> Result<(), SendError> { // Queue the frame @@ -237,7 +242,7 @@ impl io::Write for Mock { let mut me = self.pipe.inner.lock().unwrap(); if me.closed { - return Err(io::ErrorKind::BrokenPipe.into()); + return Err(io::Error::new(io::ErrorKind::BrokenPipe, "mock closed")); } me.tx.extend(buf); diff --git a/tests/support/src/raw.rs b/tests/support/src/raw.rs index 19afff0..fbaa109 100644 --- a/tests/support/src/raw.rs +++ b/tests/support/src/raw.rs @@ -44,3 +44,9 @@ impl<'a> Chunk for &'a str { dst.extend(self.as_bytes()) } } + +impl Chunk for Vec { + fn push(&self, dst: &mut Vec) { + dst.extend(self.iter()) + } +}