From aa23a9735dec88754fb9695d4432a6822e4d4cc9 Mon Sep 17 00:00:00 2001 From: Sean McArthur Date: Fri, 5 Jan 2018 09:23:48 -0800 Subject: [PATCH] SETTINGS_MAX_HEADER_LIST_SIZE (#206) This, uh, grew into something far bigger than expected, but it turns out, all of it was needed to eventually support this correctly. - Adds configuration to client and server to set [SETTINGS_MAX_HEADER_LIST_SIZE](http://httpwg.org/specs/rfc7540.html#SETTINGS_MAX_HEADER_LIST_SIZE) - If not set, a "sane default" of 16 MB is used (taken from golang's http2) - Decoding header blocks now happens as they are received, instead of buffering up possibly forever until the last continuation frame is parsed. - As each field is decoded, its undecoded size is added to the total. Whenever a header block goes over the maximum size, the `frame` will be marked as such. - Whenever a header block is deemed over max limit, decoding will still continue, but new fields will not be appended to `HeaderMap`. This can also save wasted hashing. - To protect against enormous string literals, such that they span multiple continuation frames, a check is made that the combined encoded bytes are less than the max allowed size. While technically not exactly what the spec suggests (counting decoded size instead), this should hopefully only happen when someone is indeed malicious. If found, a `GOAWAY` of `COMPRESSION_ERROR` is sent, and the connection shut down. - After an oversize header block frame is finished decoding, the streams state machine will notice it is oversize, and handle that. - If the local peer is a server, a 431 response is sent, as suggested by the spec. - A `REFUSED_STREAM` reset is sent, since we cannot actually give the stream to the user. - In order to be able to send both the 431 headers frame, and a reset frame afterwards, the scheduled `Canceled` machinery was made more general to a `Scheduled(Reason)` state instead. 
Closes #18 Closes #191 --- Cargo.toml | 2 +- src/client.rs | 10 ++ src/codec/error.rs | 9 ++ src/codec/framed_read.rs | 243 +++++++++++++++++--------------- src/codec/mod.rs | 5 + src/frame/headers.rs | 100 +++++++++++-- src/frame/settings.rs | 8 ++ src/hpack/decoder.rs | 61 +++++--- src/hpack/mod.rs | 2 +- src/hpack/test/fixture.rs | 4 +- src/hpack/test/fuzz.rs | 4 +- src/proto/streams/prioritize.rs | 10 +- src/proto/streams/recv.rs | 68 +++++++-- src/proto/streams/send.rs | 16 ++- src/proto/streams/state.rs | 31 ++-- src/proto/streams/streams.rs | 71 ++++++++-- src/server.rs | 10 ++ tests/client_request.rs | 76 ++++++++++ tests/codec_read.rs | 60 ++++++++ tests/codec_write.rs | 25 ---- tests/push_promise.rs | 50 +++++++ tests/server.rs | 73 ++++++++++ tests/support/frames.rs | 9 ++ tests/support/mock.rs | 5 +- tests/support/mod.rs | 1 + tests/support/prelude.rs | 25 ++++ 26 files changed, 752 insertions(+), 226 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 0097755..109b906 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -24,7 +24,7 @@ unstable = [] [dependencies] futures = "0.1" -tokio-io = "0.1.3" +tokio-io = "0.1.4" bytes = "0.4" http = "0.1" byteorder = "1.0" diff --git a/src/client.rs b/src/client.rs index 660d08e..e3485da 100644 --- a/src/client.rs +++ b/src/client.rs @@ -172,6 +172,12 @@ impl Builder { self } + /// Set the max size of received header frames. + pub fn max_header_list_size(&mut self, max: u32) -> &mut Self { + self.settings.set_max_header_list_size(Some(max)); + self + } + /// Set the maximum number of concurrent streams. 
/// /// Clients can only limit the maximum number of streams that that the @@ -339,6 +345,10 @@ where codec.set_max_recv_frame_size(max as usize); } + if let Some(max) = self.builder.settings.max_header_list_size() { + codec.set_max_recv_header_list_size(max as usize); + } + // Send initial settings frame codec .buffer(self.builder.settings.clone().into()) diff --git a/src/codec/error.rs b/src/codec/error.rs index 76ac619..74063c8 100644 --- a/src/codec/error.rs +++ b/src/codec/error.rs @@ -54,6 +54,15 @@ pub enum UserError { // ===== impl RecvError ===== +impl RecvError { + pub(crate) fn is_stream_error(&self) -> bool { + match *self { + RecvError::Stream { .. } => true, + _ => false, + } + } +} + impl From for RecvError { fn from(src: io::Error) -> Self { RecvError::Io(src) diff --git a/src/codec/framed_read.rs b/src/codec/framed_read.rs index 29a590e..2f7bdc0 100644 --- a/src/codec/framed_read.rs +++ b/src/codec/framed_read.rs @@ -13,6 +13,9 @@ use std::io; use tokio_io::AsyncRead; use tokio_io::codec::length_delimited; +// 16 MB "sane default" taken from golang http2 +const DEFAULT_SETTINGS_MAX_HEADER_LIST_SIZE: usize = 16 << 20; + #[derive(Debug)] pub struct FramedRead { inner: length_delimited::FramedRead, @@ -20,6 +23,8 @@ pub struct FramedRead { // hpack decoder state hpack: hpack::Decoder, + max_header_list_size: usize, + partial: Option, } @@ -36,8 +41,6 @@ struct Partial { #[derive(Debug)] enum Continuable { Headers(frame::Headers), - // Decode the Continuation frame but ignore it... - // Ignore(StreamId), PushPromise(frame::PushPromise), } @@ -46,6 +49,7 @@ impl FramedRead { FramedRead { inner: inner, hpack: hpack::Decoder::new(DEFAULT_SETTINGS_HEADER_TABLE_SIZE), + max_header_list_size: DEFAULT_SETTINGS_MAX_HEADER_LIST_SIZE, partial: None, } } @@ -67,6 +71,66 @@ impl FramedRead { trace!(" -> kind={:?}", kind); + macro_rules! 
header_block { + ($frame:ident, $head:ident, $bytes:ident) => ({ + // Drop the frame header + // TODO: Change to drain: carllerche/bytes#130 + let _ = $bytes.split_to(frame::HEADER_LEN); + + // Parse the header frame w/o parsing the payload + let (mut frame, mut payload) = match frame::$frame::load($head, $bytes) { + Ok(res) => res, + Err(frame::Error::InvalidDependencyId) => { + debug!("stream error PROTOCOL_ERROR -- invalid HEADERS dependency ID"); + // A stream cannot depend on itself. An endpoint MUST + // treat this as a stream error (Section 5.4.2) of type + // `PROTOCOL_ERROR`. + return Err(Stream { + id: $head.stream_id(), + reason: Reason::PROTOCOL_ERROR, + }); + }, + Err(e) => { + debug!("connection error PROTOCOL_ERROR -- failed to load frame; err={:?}", e); + return Err(Connection(Reason::PROTOCOL_ERROR)); + } + }; + + let is_end_headers = frame.is_end_headers(); + + // Load the HPACK encoded headers + match frame.load_hpack(&mut payload, self.max_header_list_size, &mut self.hpack) { + Ok(_) => {}, + Err(frame::Error::Hpack(hpack::DecoderError::NeedMore(_))) if !is_end_headers => {}, + Err(frame::Error::MalformedMessage) => { + + debug!("stream error PROTOCOL_ERROR -- malformed header block"); + return Err(Stream { + id: $head.stream_id(), + reason: Reason::PROTOCOL_ERROR, + }); + }, + Err(e) => { + debug!("connection error PROTOCOL_ERROR -- failed HPACK decoding; err={:?}", e); + return Err(Connection(Reason::PROTOCOL_ERROR)); + } + } + + if is_end_headers { + frame.into() + } else { + trace!("loaded partial header block"); + // Defer returning the frame + self.partial = Some(Partial { + frame: Continuable::$frame(frame), + buf: payload, + }); + + return Ok(None); + } + }); + } + let frame = match kind { Kind::Settings => { let res = frame::Settings::load(head, &bytes[frame::HEADER_LEN..]); @@ -103,56 +167,7 @@ impl FramedRead { })?.into() }, Kind::Headers => { - // Drop the frame header - // TODO: Change to drain: carllerche/bytes#130 - let _ = 
bytes.split_to(frame::HEADER_LEN); - - // Parse the header frame w/o parsing the payload - let (mut headers, payload) = match frame::Headers::load(head, bytes) { - Ok(res) => res, - Err(frame::Error::InvalidDependencyId) => { - // A stream cannot depend on itself. An endpoint MUST - // treat this as a stream error (Section 5.4.2) of type - // `PROTOCOL_ERROR`. - debug!("stream error PROTOCOL_ERROR -- invalid HEADERS dependency ID"); - return Err(Stream { - id: head.stream_id(), - reason: Reason::PROTOCOL_ERROR, - }); - }, - Err(e) => { - debug!("connection error PROTOCOL_ERROR -- failed to load HEADERS frame; err={:?}", e); - return Err(Connection(Reason::PROTOCOL_ERROR)); - } - }; - - if headers.is_end_headers() { - // Load the HPACK encoded headers & return the frame - match headers.load_hpack(payload, &mut self.hpack) { - Ok(_) => {}, - Err(frame::Error::MalformedMessage) => { - debug!("stream error PROTOCOL_ERROR -- malformed HEADERS frame"); - return Err(Stream { - id: head.stream_id(), - reason: Reason::PROTOCOL_ERROR, - }); - }, - Err(e) => { - debug!("connection error PROTOCOL_ERROR -- failed HEADERS frame HPACK decoding; err={:?}", e); - return Err(Connection(Reason::PROTOCOL_ERROR)); - } - } - - headers.into() - } else { - // Defer loading the frame - self.partial = Some(Partial { - frame: Continuable::Headers(headers), - buf: payload, - }); - - return Ok(None); - } + header_block!(Headers, head, bytes) }, Kind::Reset => { let res = frame::Reset::load(head, &bytes[frame::HEADER_LEN..]); @@ -163,44 +178,7 @@ impl FramedRead { res.map_err(|_| Connection(Reason::PROTOCOL_ERROR))?.into() }, Kind::PushPromise => { - // Drop the frame header - // TODO: Change to drain: carllerche/bytes#130 - let _ = bytes.split_to(frame::HEADER_LEN); - - // Parse the frame w/o parsing the payload - let (mut push, payload) = frame::PushPromise::load(head, bytes) - .map_err(|e| { - debug!("connection error PROTOCOL_ERROR -- failed to load PUSH_PROMISE frame; err={:?}", e); - 
Connection(Reason::PROTOCOL_ERROR) - })?; - - if push.is_end_headers() { - // Load the HPACK encoded headers & return the frame - match push.load_hpack(payload, &mut self.hpack) { - Ok(_) => {}, - Err(frame::Error::MalformedMessage) => { - debug!("stream error PROTOCOL_ERROR -- malformed PUSH_PROMISE frame"); - return Err(Stream { - id: head.stream_id(), - reason: Reason::PROTOCOL_ERROR, - }); - }, - Err(e) => { - debug!("connection error PROTOCOL_ERROR -- failed PUSH_PROMISE frame HPACK decoding; err={:?}", e); - return Err(Connection(Reason::PROTOCOL_ERROR)); - } - } - - push.into() - } else { - // Defer loading the frame - self.partial = Some(Partial { - frame: Continuable::PushPromise(push), - buf: payload, - }); - - return Ok(None); - } + header_block!(PushPromise, head, bytes) }, Kind::Priority => { if head.stream_id() == 0 { @@ -224,8 +202,7 @@ impl FramedRead { } }, Kind::Continuation => { - // TODO: Un-hack this - let end_of_headers = (head.flag() & 0x4) == 0x4; + let is_end_headers = (head.flag() & 0x4) == 0x4; let mut partial = match self.partial.take() { Some(partial) => partial, @@ -235,22 +212,43 @@ impl FramedRead { } }; - // Extend the buf - partial.buf.extend_from_slice(&bytes[frame::HEADER_LEN..]); - - if !end_of_headers { - self.partial = Some(partial); - return Ok(None); - } - // The stream identifiers must match if partial.frame.stream_id() != head.stream_id() { debug!("connection error PROTOCOL_ERROR -- CONTINUATION frame stream ID does not match previous frame stream ID"); return Err(Connection(Reason::PROTOCOL_ERROR)); } - match partial.frame.load_hpack(partial.buf, &mut self.hpack) { + + + // Extend the buf + if partial.buf.is_empty() { + partial.buf = bytes.split_off(frame::HEADER_LEN); + } else { + if partial.frame.is_over_size() { + // If there was left over bytes previously, they may be + // needed to continue decoding, even though we will + // be ignoring this frame. This is done to keep the HPACK + // decoder state up-to-date. 
+ // + // Still, we need to be careful, because if a malicious + // attacker were to try to send a gigantic string, such + // that it fits over multiple header blocks, we could + // grow memory uncontrollably again, and that'd be a shame. + // + // Instead, we use a simple heuristic to determine if + // we should continue to ignore decoding, or to tell + // the attacker to go away. + if partial.buf.len() + bytes.len() > self.max_header_list_size { + debug!("connection error COMPRESSION_ERROR -- CONTINUATION frame header block size over ignorable limit"); + return Err(Connection(Reason::COMPRESSION_ERROR)); + } + } + partial.buf.extend_from_slice(&bytes[frame::HEADER_LEN..]); + } + + match partial.frame.load_hpack(&mut partial.buf, self.max_header_list_size, &mut self.hpack) { Ok(_) => {}, + Err(frame::Error::Hpack(hpack::DecoderError::NeedMore(_))) if !is_end_headers => {}, Err(frame::Error::MalformedMessage) => { debug!("stream error PROTOCOL_ERROR -- malformed CONTINUATION frame"); return Err(Stream { @@ -258,10 +256,18 @@ impl FramedRead { reason: Reason::PROTOCOL_ERROR, }); }, - Err(_) => return Err(Connection(Reason::PROTOCOL_ERROR)), + Err(e) => { + debug!("connection error PROTOCOL_ERROR -- failed HPACK decoding; err={:?}", e); + return Err(Connection(Reason::PROTOCOL_ERROR)); + }, } - partial.frame.into() + if is_end_headers { + partial.frame.into() + } else { + self.partial = Some(partial); + return Ok(None); + } }, Kind::Unknown => { // Unknown frames are ignored @@ -295,6 +301,12 @@ impl FramedRead { assert!(DEFAULT_MAX_FRAME_SIZE as usize <= val && val <= MAX_MAX_FRAME_SIZE as usize); self.inner.set_max_frame_length(val) } + + /// Update the max header list size setting. 
+ #[inline] + pub fn set_max_header_list_size(&mut self, val: usize) { + self.max_header_list_size = val; + } } impl Stream for FramedRead @@ -322,14 +334,13 @@ where } fn map_err(err: io::Error) -> RecvError { - use std::error::Error; + use tokio_io::codec::length_delimited::FrameTooBig; if let io::ErrorKind::InvalidData = err.kind() { - // woah, brittle... - // TODO: with tokio-io v0.1.4, we can check - // err.get_ref().is::() - if err.description() == "frame size too big" { - return RecvError::Connection(Reason::FRAME_SIZE_ERROR); + if let Some(custom) = err.get_ref() { + if custom.is::() { + return RecvError::Connection(Reason::FRAME_SIZE_ERROR); + } } } err.into() @@ -345,14 +356,22 @@ impl Continuable { } } + fn is_over_size(&self) -> bool { + match *self { + Continuable::Headers(ref h) => h.is_over_size(), + Continuable::PushPromise(ref p) => p.is_over_size(), + } + } + fn load_hpack( &mut self, - src: BytesMut, + src: &mut BytesMut, + max_header_list_size: usize, decoder: &mut hpack::Decoder, ) -> Result<(), frame::Error> { match *self { - Continuable::Headers(ref mut h) => h.load_hpack(src, decoder), - Continuable::PushPromise(ref mut p) => p.load_hpack(src, decoder), + Continuable::Headers(ref mut h) => h.load_hpack(src, max_header_list_size, decoder), + Continuable::PushPromise(ref mut p) => p.load_hpack(src, max_header_list_size, decoder), } } } diff --git a/src/codec/mod.rs b/src/codec/mod.rs index 6dc3267..0f8acbf 100644 --- a/src/codec/mod.rs +++ b/src/codec/mod.rs @@ -90,6 +90,11 @@ impl Codec { self.framed_write().set_max_frame_size(val) } + /// Set the max header list size that can be received. + pub fn set_max_recv_header_list_size(&mut self, val: usize) { + self.inner.set_max_header_list_size(val); + } + /// Get a reference to the inner stream. 
#[cfg(feature = "unstable")] pub fn get_ref(&self) -> &T { diff --git a/src/frame/headers.rs b/src/frame/headers.rs index 4c98442..1ffa743 100644 --- a/src/frame/headers.rs +++ b/src/frame/headers.rs @@ -86,6 +86,9 @@ struct HeaderBlock { /// The decoded header fields fields: HeaderMap, + /// Set to true if decoding went over the max header list size. + is_over_size: bool, + /// Pseudo headers, these are broken out as they must be sent as part of the /// headers frame. pseudo: Pseudo, @@ -116,6 +119,7 @@ impl Headers { stream_dep: None, header_block: HeaderBlock { fields: fields, + is_over_size: false, pseudo: pseudo, }, flags: HeadersFlag::default(), @@ -131,6 +135,7 @@ impl Headers { stream_dep: None, header_block: HeaderBlock { fields: fields, + is_over_size: false, pseudo: Pseudo::default(), }, flags: flags, @@ -185,6 +190,7 @@ impl Headers { stream_dep: stream_dep, header_block: HeaderBlock { fields: HeaderMap::new(), + is_over_size: false, pseudo: Pseudo::default(), }, flags: flags, @@ -193,8 +199,8 @@ impl Headers { Ok((headers, src)) } - pub fn load_hpack(&mut self, src: BytesMut, decoder: &mut hpack::Decoder) -> Result<(), Error> { - self.header_block.load(src, decoder) + pub fn load_hpack(&mut self, src: &mut BytesMut, max_header_list_size: usize, decoder: &mut hpack::Decoder) -> Result<(), Error> { + self.header_block.load(src, max_header_list_size, decoder) } pub fn stream_id(&self) -> StreamId { @@ -217,6 +223,10 @@ impl Headers { self.flags.set_end_stream() } + pub fn is_over_size(&self) -> bool { + self.header_block.is_over_size + } + pub fn into_parts(self) -> (Pseudo, HeaderMap) { (self.header_block.pseudo, self.header_block.fields) } @@ -304,6 +314,7 @@ impl PushPromise { flags: flags, header_block: HeaderBlock { fields: HeaderMap::new(), + is_over_size: false, pseudo: Pseudo::default(), }, promised_id: promised_id, @@ -312,8 +323,8 @@ impl PushPromise { Ok((frame, src)) } - pub fn load_hpack(&mut self, src: BytesMut, decoder: &mut hpack::Decoder) 
-> Result<(), Error> { - self.header_block.load(src, decoder) + pub fn load_hpack(&mut self, src: &mut BytesMut, max_header_list_size: usize, decoder: &mut hpack::Decoder) -> Result<(), Error> { + self.header_block.load(src, max_header_list_size, decoder) } pub fn stream_id(&self) -> StreamId { @@ -332,6 +343,10 @@ impl PushPromise { self.flags.set_end_headers(); } + pub fn is_over_size(&self) -> bool { + self.header_block.is_over_size + } + pub fn encode(self, encoder: &mut hpack::Encoder, dst: &mut BytesMut) -> Option { use bytes::BufMut; @@ -364,6 +379,7 @@ impl PushPromise { flags: PushPromiseFlag::default(), header_block: HeaderBlock { fields, + is_over_size: false, pseudo, }, promised_id, @@ -677,10 +693,12 @@ impl fmt::Debug for PushPromiseFlag { // ===== HeaderBlock ===== + impl HeaderBlock { - fn load(&mut self, src: BytesMut, decoder: &mut hpack::Decoder) -> Result<(), Error> { - let mut reg = false; + fn load(&mut self, src: &mut BytesMut, max_header_list_size: usize, decoder: &mut hpack::Decoder) -> Result<(), Error> { + let mut reg = !self.fields.is_empty(); let mut malformed = false; + let mut headers_size = self.calculate_header_list_size(); macro_rules! set_pseudo { ($field:ident, $val:expr) => {{ @@ -691,22 +709,25 @@ impl HeaderBlock { trace!("load_hpack; header malformed -- repeated pseudo"); malformed = true; } else { - self.pseudo.$field = Some($val); + let __val = $val; + headers_size += decoded_header_size(stringify!($ident).len() + 1, __val.as_str().len()); + if headers_size < max_header_list_size { + self.pseudo.$field = Some(__val); + } else if !self.is_over_size { + trace!("load_hpack; header list size over max"); + self.is_over_size = true; + } } }} } - let mut src = Cursor::new(src.freeze()); + let mut cursor = Cursor::new(src); - // At this point, we're going to assume that the hpack encoded headers - // contain the entire payload. Later, we need to check for stream - // priority. 
- // // If the header frame is malformed, we still have to continue decoding // the headers. A malformed header frame is a stream level error, but // the hpack state is connection level. In order to maintain correct // state for other streams, the hpack decoding process must complete. - let res = decoder.decode(&mut src, |header| { + let res = decoder.decode(&mut cursor, |header| { use hpack::Header::*; match header { @@ -730,7 +751,14 @@ impl HeaderBlock { malformed = true; } else { reg = true; - self.fields.append(name, value); + + headers_size += decoded_header_size(name.as_str().len(), value.len()); + if headers_size < max_header_list_size { + self.fields.append(name, value); + } else if !self.is_over_size { + trace!("load_hpack; header list size over max"); + self.is_over_size = true; + } } }, Authority(v) => set_pseudo!(authority, v), @@ -763,4 +791,48 @@ impl HeaderBlock { }, } } + + /// Calculates the size of the currently decoded header list. + /// + /// According to http://httpwg.org/specs/rfc7540.html#SETTINGS_MAX_HEADER_LIST_SIZE + /// + /// > The value is based on the uncompressed size of header fields, + /// > including the length of the name and value in octets plus an + /// > overhead of 32 octets for each header field. + fn calculate_header_list_size(&self) -> usize { + macro_rules! pseudo_size { + ($name:ident) => ({ + self.pseudo + .$name + .as_ref() + .map(|m| decoded_header_size(stringify!($name).len() + 1, m.as_str().len())) + .unwrap_or(0) + }); + } + + pseudo_size!(method) + + pseudo_size!(scheme) + + pseudo_size!(status) + + pseudo_size!(authority) + + pseudo_size!(path) + + self.fields.iter() + .map(|(name, value)| decoded_header_size(name.as_str().len(), value.len())) + .sum::() + } +} + +fn decoded_header_size(name: usize, value: usize) -> usize { + name + value + 32 +} + +// Stupid hack to make the set_pseudo! macro happy, since all other values +// have a method `as_str` except for `String`. 
+trait AsStr { + fn as_str(&self) -> &str; +} + +impl AsStr for String { + fn as_str(&self) -> &str { + self + } } diff --git a/src/frame/settings.rs b/src/frame/settings.rs index b130f43..90fd493 100644 --- a/src/frame/settings.rs +++ b/src/frame/settings.rs @@ -89,6 +89,14 @@ impl Settings { self.max_frame_size = size; } + pub fn max_header_list_size(&self) -> Option { + self.max_header_list_size + } + + pub fn set_max_header_list_size(&mut self, size: Option) { + self.max_header_list_size = size; + } + pub fn is_push_enabled(&self) -> bool { self.enable_push.unwrap_or(1) != 0 } diff --git a/src/hpack/decoder.rs b/src/hpack/decoder.rs index ddc2812..514c09f 100644 --- a/src/hpack/decoder.rs +++ b/src/hpack/decoder.rs @@ -34,10 +34,15 @@ pub enum DecoderError { InvalidStatusCode, InvalidPseudoheader, InvalidMaxDynamicSize, - IntegerUnderflow, IntegerOverflow, - StringUnderflow, + NeedMore(NeedMore), +} + +#[derive(Debug, Copy, Clone, PartialEq, Eq)] +pub enum NeedMore { UnexpectedEndOfStream, + IntegerUnderflow, + StringUnderflow, } enum Representation { @@ -163,7 +168,7 @@ impl Decoder { } /// Decodes the headers found in the given buffer. 
- pub fn decode(&mut self, src: &mut Cursor, mut f: F) -> Result<(), DecoderError> + pub fn decode(&mut self, src: &mut Cursor<&mut BytesMut>, mut f: F) -> Result<(), DecoderError> where F: FnMut(Header), { @@ -185,7 +190,9 @@ impl Decoder { Indexed => { trace!(" Indexed; rem={:?}", src.remaining()); can_resize = false; - f(self.decode_indexed(src)?); + let entry = self.decode_indexed(src)?; + consume(src); + f(entry); }, LiteralWithIndexing => { trace!(" LiteralWithIndexing; rem={:?}", src.remaining()); @@ -194,6 +201,7 @@ impl Decoder { // Insert the header into the table self.table.insert(entry.clone()); + consume(src); f(entry); }, @@ -201,12 +209,14 @@ impl Decoder { trace!(" LiteralWithoutIndexing; rem={:?}", src.remaining()); can_resize = false; let entry = self.decode_literal(src, false)?; + consume(src); f(entry); }, LiteralNeverIndexed => { trace!(" LiteralNeverIndexed; rem={:?}", src.remaining()); can_resize = false; let entry = self.decode_literal(src, false)?; + consume(src); // TODO: Track that this should never be indexed @@ -220,6 +230,7 @@ impl Decoder { // Handle the dynamic table size update self.process_size_update(src)?; + consume(src); }, } } @@ -227,7 +238,7 @@ impl Decoder { Ok(()) } - fn process_size_update(&mut self, buf: &mut Cursor) -> Result<(), DecoderError> { + fn process_size_update(&mut self, buf: &mut Cursor<&mut BytesMut>) -> Result<(), DecoderError> { let new_size = decode_int(buf, 5)?; if new_size > self.last_max_update { @@ -245,14 +256,14 @@ impl Decoder { Ok(()) } - fn decode_indexed(&self, buf: &mut Cursor) -> Result { + fn decode_indexed(&self, buf: &mut Cursor<&mut BytesMut>) -> Result { let index = decode_int(buf, 7)?; self.table.get(index) } fn decode_literal( &mut self, - buf: &mut Cursor, + buf: &mut Cursor<&mut BytesMut>, index: bool, ) -> Result { let prefix = if index { 6 } else { 4 }; @@ -275,13 +286,13 @@ impl Decoder { } } - fn decode_string(&mut self, buf: &mut Cursor) -> Result { + fn decode_string(&mut self, 
buf: &mut Cursor<&mut BytesMut>) -> Result { const HUFF_FLAG: u8 = 0b10000000; // The first bit in the first byte contains the huffman encoded flag. let huff = match peek_u8(buf) { Some(hdr) => (hdr & HUFF_FLAG) == HUFF_FLAG, - None => return Err(DecoderError::UnexpectedEndOfStream), + None => return Err(DecoderError::NeedMore(NeedMore::UnexpectedEndOfStream)), }; // Decode the string length using 7 bit prefix @@ -293,7 +304,7 @@ impl Decoder { len, buf.remaining() ); - return Err(DecoderError::StringUnderflow); + return Err(DecoderError::NeedMore(NeedMore::StringUnderflow)); } if huff { @@ -358,7 +369,7 @@ fn decode_int(buf: &mut B, prefix_size: u8) -> Result(buf: &mut B, prefix_size: u8) -> Result(buf: &mut B) -> Option { @@ -412,11 +423,19 @@ fn peek_u8(buf: &mut B) -> Option { } } -fn take(buf: &mut Cursor, n: usize) -> Bytes { +fn take(buf: &mut Cursor<&mut BytesMut>, n: usize) -> Bytes { let pos = buf.position() as usize; - let ret = buf.get_ref().slice(pos, pos + n); - buf.set_position((pos + n) as u64); - ret + let mut head = buf.get_mut().split_to(pos + n); + buf.set_position(0); + head.split_to(pos); + head.freeze() +} + +fn consume(buf: &mut Cursor<&mut BytesMut>) { + // remove bytes from the internal BytesMut when they have been successfully + // decoded. This is a more permanent cursor position, which will be + // used to resume if decoding was only partial. 
+ take(buf, 0); } // ===== impl Table ===== @@ -778,15 +797,15 @@ fn test_peek_u8() { #[test] fn test_decode_string_empty() { let mut de = Decoder::new(0); - let buf = Bytes::new(); - let err = de.decode_string(&mut Cursor::new(buf)).unwrap_err(); - assert_eq!(err, DecoderError::UnexpectedEndOfStream); + let mut buf = BytesMut::new(); + let err = de.decode_string(&mut Cursor::new(&mut buf)).unwrap_err(); + assert_eq!(err, DecoderError::NeedMore(NeedMore::UnexpectedEndOfStream)); } #[test] fn test_decode_empty() { let mut de = Decoder::new(0); - let buf = Bytes::new(); - let empty = de.decode(&mut Cursor::new(buf), |_| {}).unwrap(); + let mut buf = BytesMut::new(); + let empty = de.decode(&mut Cursor::new(&mut buf), |_| {}).unwrap(); assert_eq!(empty, ()); } diff --git a/src/hpack/mod.rs b/src/hpack/mod.rs index 4a0ab9e..956de88 100644 --- a/src/hpack/mod.rs +++ b/src/hpack/mod.rs @@ -7,6 +7,6 @@ mod table; #[cfg(test)] mod test; -pub use self::decoder::{Decoder, DecoderError}; +pub use self::decoder::{Decoder, DecoderError, NeedMore}; pub use self::encoder::{Encode, EncodeState, Encoder, EncoderError}; pub use self::header::Header; diff --git a/src/hpack/test/fixture.rs b/src/hpack/test/fixture.rs index 80f7144..d7a4883 100644 --- a/src/hpack/test/fixture.rs +++ b/src/hpack/test/fixture.rs @@ -74,7 +74,7 @@ fn test_story(story: Value) { } decoder - .decode(&mut Cursor::new(case.wire.clone().into()), |e| { + .decode(&mut Cursor::new(&mut case.wire.clone().into()), |e| { let (name, value) = expect.remove(0); assert_eq!(name, key_str(&e)); assert_eq!(value, value_str(&e)); @@ -108,7 +108,7 @@ fn test_story(story: Value) { encoder.encode(None, &mut input.clone().into_iter(), &mut buf); decoder - .decode(&mut Cursor::new(buf.into()), |e| { + .decode(&mut Cursor::new(&mut buf), |e| { assert_eq!(e, input.remove(0).reify().unwrap()); }) .unwrap(); diff --git a/src/hpack/test/fuzz.rs b/src/hpack/test/fuzz.rs index 7b19aec..b5be1a8 100644 --- a/src/hpack/test/fuzz.rs +++ 
b/src/hpack/test/fuzz.rs @@ -149,7 +149,7 @@ impl FuzzHpack { // Decode the chunk! decoder - .decode(&mut Cursor::new(buf.into()), |e| { + .decode(&mut Cursor::new(&mut buf), |e| { assert_eq!(e, expect.remove(0).reify().unwrap()); }) .unwrap(); @@ -161,7 +161,7 @@ impl FuzzHpack { // Decode the chunk! decoder - .decode(&mut Cursor::new(buf.into()), |e| { + .decode(&mut Cursor::new(&mut buf), |e| { assert_eq!(e, expect.remove(0).reify().unwrap()); }) .unwrap(); diff --git a/src/proto/streams/prioritize.rs b/src/proto/streams/prioritize.rs index 7768bec..0d9919d 100644 --- a/src/proto/streams/prioritize.rs +++ b/src/proto/streams/prioritize.rs @@ -659,10 +659,12 @@ impl Prioritize { ) ), None => { - assert!(stream.state.is_canceled()); - stream.state.set_reset(Reason::CANCEL); + let reason = stream.state.get_scheduled_reset() + .expect("must be scheduled to reset"); - let frame = frame::Reset::new(stream.id, Reason::CANCEL); + stream.state.set_reset(reason); + + let frame = frame::Reset::new(stream.id, reason); Frame::Reset(frame) } }; @@ -674,7 +676,7 @@ impl Prioritize { self.last_opened_id = stream.id; } - if !stream.pending_send.is_empty() || stream.state.is_canceled() { + if !stream.pending_send.is_empty() || stream.state.is_scheduled_reset() { // TODO: Only requeue the sender IF it is ready to send // the next frame. i.e. 
don't requeue it if the next // frame is a data frame and the stream does not have diff --git a/src/proto/streams/recv.rs b/src/proto/streams/recv.rs index dd8b138..96391d5 100644 --- a/src/proto/streams/recv.rs +++ b/src/proto/streams/recv.rs @@ -1,4 +1,5 @@ use super::*; +use super::store::Resolve; use {frame, proto}; use codec::{RecvError, UserError}; use frame::{Reason, DEFAULT_INITIAL_WINDOW_SIZE}; @@ -54,6 +55,12 @@ pub(super) enum Event { Trailers(HeaderMap), } +#[derive(Debug)] +pub(super) enum RecvHeaderBlockError { + Oversize(T), + State(RecvError), +} + #[derive(Debug, Clone, Copy)] struct Indices { head: store::Key, @@ -133,7 +140,7 @@ impl Recv { frame: frame::Headers, stream: &mut store::Ptr, counts: &mut Counts, - ) -> Result<(), RecvError> { + ) -> Result<(), RecvHeaderBlockError>> { trace!("opening stream; init_window={}", self.init_window_sz); let is_initial = stream.state.recv_open(frame.is_end_stream())?; @@ -158,7 +165,7 @@ impl Recv { return Err(RecvError::Stream { id: stream.id, reason: Reason::PROTOCOL_ERROR, - }) + }.into()) }, }; @@ -166,6 +173,32 @@ impl Recv { } } + if frame.is_over_size() { + // A frame is over size if the decoded header block was bigger than + // SETTINGS_MAX_HEADER_LIST_SIZE. + // + // > A server that receives a larger header block than it is willing + // > to handle can send an HTTP 431 (Request Header Fields Too + // > Large) status code [RFC6585]. A client can discard responses + // > that it cannot process. + // + // So, if peer is a server, we'll send a 431. In either case, + // an error is recorded, which will send a REFUSED_STREAM, + // since we don't want any of the data frames either. 
+ trace!("recv_headers; frame for {:?} is over size", stream.id); + return if counts.peer().is_server() && is_initial { + let mut res = frame::Headers::new( + stream.id, + frame::Pseudo::response(::http::StatusCode::REQUEST_HEADER_FIELDS_TOO_LARGE), + HeaderMap::new() + ); + res.set_end_stream(); + Err(RecvHeaderBlockError::Oversize(Some(res))) + } else { + Err(RecvHeaderBlockError::Oversize(None)) + }; + } + let message = counts.peer().convert_poll_message(frame)?; // Push the frame onto the stream's recv buffer @@ -517,15 +550,20 @@ impl Recv { ); new_stream.state.reserve_remote()?; + // Store the stream + let new_stream = store.insert(frame.promised_id(), new_stream).key(); + + + if frame.is_over_size() { + trace!("recv_push_promise; frame for {:?} is over size", frame.promised_id()); + return Err(RecvError::Stream { + id: frame.promised_id(), + reason: Reason::REFUSED_STREAM, + }); + } let mut ppp = store[stream].pending_push_promises.take(); - - { - // Store the stream - let mut new_stream = store.insert(frame.promised_id(), new_stream); - - ppp.push(&mut new_stream); - } + ppp.push(&mut store.resolve(new_stream)); let stream = &mut store[stream]; @@ -609,9 +647,7 @@ impl Recv { stream: &mut store::Ptr, counts: &mut Counts, ) { - assert!(stream.state.is_local_reset()); - - if stream.is_pending_reset_expiration() { + if !stream.state.is_local_reset() || stream.is_pending_reset_expiration() { return; } @@ -842,6 +878,14 @@ impl Event { } } +// ===== impl RecvHeaderBlockError ===== + +impl From for RecvHeaderBlockError { + fn from(err: RecvError) -> Self { + RecvHeaderBlockError::State(err) + } +} + // ===== util ===== fn parse_u64(src: &[u8]) -> Result { diff --git a/src/proto/streams/send.rs b/src/proto/streams/send.rs index cff192b..bd69080 100644 --- a/src/proto/streams/send.rs +++ b/src/proto/streams/send.rs @@ -132,6 +132,9 @@ impl Send { return; } + // Transition the state to reset no matter what. 
+ stream.state.set_reset(reason); + // If closed AND the send queue is flushed, then the stream cannot be // reset explicitly, either. Implicit resets can still be queued. if is_closed && is_empty { @@ -143,9 +146,6 @@ impl Send { return; } - // Transition the state - stream.state.set_reset(reason); - self.recv_err(buffer, stream); let frame = frame::Reset::new(stream.id, reason); @@ -154,14 +154,18 @@ impl Send { self.prioritize.queue_frame(frame.into(), buffer, stream, task); } - pub fn schedule_cancel(&mut self, stream: &mut store::Ptr, task: &mut Option) { - trace!("schedule_cancel; {:?}", stream.id); + pub fn schedule_implicit_reset( + &mut self, + stream: &mut store::Ptr, + reason: Reason, + task: &mut Option, + ) { if stream.state.is_closed() { // Stream is already closed, nothing more to do return; } - stream.state.set_canceled(); + stream.state.set_scheduled_reset(reason); self.prioritize.reclaim_reserved_capacity(stream); self.prioritize.schedule_send(stream, task); diff --git a/src/proto/streams/state.rs b/src/proto/streams/state.rs index d4b15aa..0952070 100644 --- a/src/proto/streams/state.rs +++ b/src/proto/streams/state.rs @@ -76,10 +76,14 @@ enum Cause { LocallyReset(Reason), Io, - /// The user droped all handles to the stream without explicitly canceling. /// This indicates to the connection that a reset frame must be sent out /// once the send queue has been flushed. - Canceled, + /// + /// Examples of when this could happen: + /// - User drops all references to a stream, so we want to CANCEL the it. + /// - Header block size was too large, so we want to REFUSE, possibly + /// after sending a 431 response frame. + Scheduled(Reason), } impl State { @@ -269,15 +273,22 @@ impl State { self.inner = Closed(Cause::LocallyReset(reason)); } - /// Set the stream state to canceled - pub fn set_canceled(&mut self) { + /// Set the stream state to a scheduled reset. 
+ pub fn set_scheduled_reset(&mut self, reason: Reason) { debug_assert!(!self.is_closed()); - self.inner = Closed(Cause::Canceled); + self.inner = Closed(Cause::Scheduled(reason)); } - pub fn is_canceled(&self) -> bool { + pub fn get_scheduled_reset(&self) -> Option<Reason> { match self.inner { - Closed(Cause::Canceled) => true, + Closed(Cause::Scheduled(reason)) => Some(reason), + _ => None, + } + } + + pub fn is_scheduled_reset(&self) -> bool { + match self.inner { + Closed(Cause::Scheduled(..)) => true, _ => false, } } @@ -285,7 +296,7 @@ impl State { pub fn is_local_reset(&self) -> bool { match self.inner { Closed(Cause::LocallyReset(_)) => true, - Closed(Cause::Canceled) => true, + Closed(Cause::Scheduled(..)) => true, _ => false, } } @@ -381,8 +392,8 @@ impl State { // TODO: Is this correct? match self.inner { Closed(Cause::Proto(reason)) | - Closed(Cause::LocallyReset(reason)) => Err(proto::Error::Proto(reason)), - Closed(Cause::Canceled) => Err(proto::Error::Proto(Reason::CANCEL)), + Closed(Cause::LocallyReset(reason)) | + Closed(Cause::Scheduled(reason)) => Err(proto::Error::Proto(reason)), Closed(Cause::Io) => Err(proto::Error::Io(io::ErrorKind::BrokenPipe.into())), Closed(Cause::EndStream) | HalfClosedRemote(..)
=> Ok(false), diff --git a/src/proto/streams/streams.rs b/src/proto/streams/streams.rs index 2a8d678..eb97f63 100644 --- a/src/proto/streams/streams.rs +++ b/src/proto/streams/streams.rs @@ -1,4 +1,5 @@ use super::{Buffer, Config, Counts, Prioritized, Recv, Send, Stream, StreamId}; +use super::recv::RecvHeaderBlockError; use super::store::{self, Entry, Resolve, Store}; use {client, proto, server}; use codec::{Codec, RecvError, SendError, UserError}; @@ -164,7 +165,28 @@ where ); let res = if stream.state.is_recv_headers() { - actions.recv.recv_headers(frame, stream, counts) + match actions.recv.recv_headers(frame, stream, counts) { + Ok(()) => Ok(()), + Err(RecvHeaderBlockError::Oversize(resp)) => { + if let Some(resp) = resp { + let _ = actions.send.send_headers( + resp, send_buffer, stream, counts, &mut actions.task); + + actions.send.schedule_implicit_reset( + stream, + Reason::REFUSED_STREAM, + &mut actions.task); + actions.recv.enqueue_reset_expiration(stream, counts); + Ok(()) + } else { + Err(RecvError::Stream { + id: stream.id, + reason: Reason::REFUSED_STREAM, + }) + } + }, + Err(RecvHeaderBlockError::State(err)) => Err(err), + } } else { if !frame.is_end_stream() { // TODO: Is this the right error @@ -363,22 +385,42 @@ where let me = &mut *me; let id = frame.stream_id(); + let promised_id = frame.promised_id(); - let stream = match me.store.find_mut(&id) { - Some(stream) => stream.key(), - None => return Err(RecvError::Connection(Reason::PROTOCOL_ERROR)), + let res = { + let stream = match me.store.find_mut(&id) { + Some(stream) => stream.key(), + None => return Err(RecvError::Connection(Reason::PROTOCOL_ERROR)), + }; + + if me.counts.peer().is_server() { + // The remote is a client and cannot reserve + trace!("recv_push_promise; error remote is client"); + return Err(RecvError::Connection(Reason::PROTOCOL_ERROR)); + } + + me.actions.recv.recv_push_promise(frame, + &me.actions.send, + stream, + &mut me.store) }; - if me.counts.peer().is_server() { - // 
The remote is a client and cannot reserve - trace!("recv_push_promise; error remote is client"); - return Err(RecvError::Connection(Reason::PROTOCOL_ERROR)); - } + if let Err(err) = res { + if let Some(ref mut new_stream) = me.store.find_mut(&promised_id) { - me.actions.recv.recv_push_promise(frame, - &me.actions.send, - stream, - &mut me.store) + let mut send_buffer = self.send_buffer.inner.lock().unwrap(); + me.actions.reset_on_recv_stream_err(&mut *send_buffer, new_stream, Err(err)) + } else { + // If there was a stream error, the stream should have been stored + // so we can track sending a reset. + // + // Otherwise, this MUST be an connection error. + assert!(!err.is_stream_error()); + Err(err) + } + } else { + res + } } pub fn next_incoming(&mut self) -> Option> { @@ -925,8 +967,9 @@ fn drop_stream_ref(inner: &Mutex, key: store::Key) { fn maybe_cancel(stream: &mut store::Ptr, actions: &mut Actions, counts: &mut Counts) { if stream.is_canceled_interest() { - actions.send.schedule_cancel( + actions.send.schedule_implicit_reset( stream, + Reason::CANCEL, &mut actions.task); actions.recv.enqueue_reset_expiration(stream, counts); } diff --git a/src/server.rs b/src/server.rs index 74bb9a9..b072f79 100644 --- a/src/server.rs +++ b/src/server.rs @@ -363,6 +363,10 @@ where codec.set_max_recv_frame_size(max as usize); } + if let Some(max) = builder.settings.max_header_list_size() { + codec.set_max_recv_header_list_size(max as usize); + } + // Send initial settings frame. codec .buffer(builder.settings.clone().into()) @@ -577,6 +581,12 @@ impl Builder { self } + /// Set the max size of received header frames. + pub fn max_header_list_size(&mut self, max: u32) -> &mut Self { + self.settings.set_max_header_list_size(Some(max)); + self + } + /// Set the maximum number of concurrent streams. 
/// /// The maximum concurrent streams setting only controls the maximum number diff --git a/tests/client_request.rs b/tests/client_request.rs index cbaaf37..1226b47 100644 --- a/tests/client_request.rs +++ b/tests/client_request.rs @@ -544,6 +544,82 @@ fn sending_request_on_closed_connection() { h2.join(srv).wait().expect("wait"); } +#[test] +fn recv_too_big_headers() { + let _ = ::env_logger::init(); + let (io, srv) = mock::new(); + + let srv = srv.assert_client_handshake() + .unwrap() + .recv_custom_settings( + frames::settings() + .max_header_list_size(10) + ) + .recv_frame( + frames::headers(1) + .request("GET", "https://http2.akamai.com/") + .eos(), + ) + .recv_frame( + frames::headers(3) + .request("GET", "https://http2.akamai.com/") + .eos(), + ) + .send_frame(frames::headers(1).response(200).eos()) + .send_frame(frames::headers(3).response(200)) + // no reset for 1, since it's closed anyways + // but reset for 3, since server hasn't closed stream + .recv_frame(frames::reset(3).refused()) + .idle_ms(10) + .close(); + + let client = client::Builder::new() + .max_header_list_size(10) + .handshake::<_, Bytes>(io) + .expect("handshake") + .and_then(|(mut client, conn)| { + let request = Request::builder() + .uri("https://http2.akamai.com/") + .body(()) + .unwrap(); + + let req1 = client + .send_request(request, true) + .expect("send_request") + .0 + .expect_err("response1") + .map(|err| { + assert_eq!( + err.reason(), + Some(Reason::REFUSED_STREAM) + ); + }); + + let request = Request::builder() + .uri("https://http2.akamai.com/") + .body(()) + .unwrap(); + + let req2 = client + .send_request(request, true) + .expect("send_request") + .0 + .expect_err("response2") + .map(|err| { + assert_eq!( + err.reason(), + Some(Reason::REFUSED_STREAM) + ); + }); + + conn.drive(req1.join(req2)) + .and_then(|(conn, _)| conn.expect("client")) + }); + + client.join(srv).wait().expect("wait"); + +} + const SETTINGS: &'static [u8] = &[0, 0, 0, 4, 0, 0, 0, 0, 0]; const 
SETTINGS_ACK: &'static [u8] = &[0, 0, 0, 4, 1, 0, 0, 0, 0]; diff --git a/tests/codec_read.rs b/tests/codec_read.rs index 5aebced..4fee476 100644 --- a/tests/codec_read.rs +++ b/tests/codec_read.rs @@ -12,6 +12,7 @@ fn read_none() { } #[test] +#[ignore] fn read_frame_too_big() {} // ===== DATA ===== @@ -100,14 +101,73 @@ fn read_data_stream_id_zero() { // ===== HEADERS ===== #[test] +#[ignore] fn read_headers_without_pseudo() {} #[test] +#[ignore] fn read_headers_with_pseudo() {} #[test] +#[ignore] fn read_headers_empty_payload() {} +#[test] +fn read_continuation_frames() { + let _ = ::env_logger::init(); + let (io, srv) = mock::new(); + + let large = build_large_headers(); + let frame = large.iter().fold( + frames::headers(1).response(200), + |frame, &(name, ref value)| frame.field(name, &value[..]), + ).eos(); + + let srv = srv.assert_client_handshake() + .unwrap() + .recv_settings() + .recv_frame( + frames::headers(1) + .request("GET", "https://http2.akamai.com/") + .eos(), + ) + .send_frame(frame) + .close(); + + let client = client::handshake(io) + .expect("handshake") + .and_then(|(mut client, conn)| { + let request = Request::builder() + .uri("https://http2.akamai.com/") + .body(()) + .unwrap(); + + let req = client + .send_request(request, true) + .expect("send_request") + .0 + .expect("response") + .map(move |res| { + assert_eq!(res.status(), StatusCode::OK); + let (head, _body) = res.into_parts(); + let expected = large.iter().fold(HeaderMap::new(), |mut map, &(name, ref value)| { + use support::frames::HttpTryInto; + map.append(name, value.as_str().try_into().unwrap()); + map + }); + assert_eq!(head.headers, expected); + }); + + conn.drive(req) + .and_then(move |(h2, _)| { + h2.expect("client") + }) + }); + + client.join(srv).wait().expect("wait"); + +} + #[test] fn update_max_frame_len_at_rest() { let _ = ::env_logger::init(); diff --git a/tests/codec_write.rs b/tests/codec_write.rs index 06a374f..a0a3da5 100644 --- a/tests/codec_write.rs +++ 
b/tests/codec_write.rs @@ -59,28 +59,3 @@ fn write_continuation_frames() { client.join(srv).wait().expect("wait"); } - -fn build_large_headers() -> Vec<(&'static str, String)> { - vec![ - ("one", "hello".to_string()), - ("two", build_large_string('2', 4 * 1024)), - ("three", "three".to_string()), - ("four", build_large_string('4', 4 * 1024)), - ("five", "five".to_string()), - ("six", build_large_string('6', 4 * 1024)), - ("seven", "seven".to_string()), - ("eight", build_large_string('8', 4 * 1024)), - ("nine", "nine".to_string()), - ("ten", build_large_string('0', 4 * 1024)), - ] -} - -fn build_large_string(ch: char, len: usize) -> String { - let mut ret = String::new(); - - for _ in 0..len { - ret.push(ch); - } - - ret -} diff --git a/tests/push_promise.rs b/tests/push_promise.rs index 1c38023..c133dab 100644 --- a/tests/push_promise.rs +++ b/tests/push_promise.rs @@ -137,6 +137,56 @@ fn pending_push_promises_reset_when_dropped() { client.join(srv).wait().expect("wait"); } +#[test] +fn recv_push_promise_over_max_header_list_size() { + let _ = ::env_logger::init(); + let (io, srv) = mock::new(); + + let srv = srv.assert_client_handshake() + .unwrap() + .recv_custom_settings( + frames::settings() + .max_header_list_size(10) + ) + .recv_frame( + frames::headers(1) + .request("GET", "https://http2.akamai.com/") + .eos(), + ) + .send_frame(frames::push_promise(1, 2).request("GET", "https://http2.akamai.com/style.css")) + .recv_frame(frames::reset(2).refused()) + .send_frame(frames::headers(1).response(200).eos()) + .idle_ms(10) + .close(); + + let client = client::Builder::new() + .max_header_list_size(10) + .handshake::<_, Bytes>(io) + .expect("handshake") + .and_then(|(mut client, conn)| { + let request = Request::builder() + .uri("https://http2.akamai.com/") + .body(()) + .unwrap(); + + let req = client + .send_request(request, true) + .expect("send_request") + .0 + .expect_err("response") + .map(|err| { + assert_eq!( + err.reason(), + Some(Reason::REFUSED_STREAM) + 
); + }); + + conn.drive(req) + .and_then(|(conn, _)| conn.expect("client")) + }); + client.join(srv).wait().expect("wait"); +} + #[test] #[ignore] fn recv_push_promise_with_unsafe_method_is_stream_error() { diff --git a/tests/server.rs b/tests/server.rs index 9ad6196..290fedf 100644 --- a/tests/server.rs +++ b/tests/server.rs @@ -261,3 +261,76 @@ fn sends_reset_cancel_when_res_body_is_dropped() { srv.join(client).wait().expect("wait"); } + +#[test] +fn too_big_headers_sends_431() { + let _ = ::env_logger::init(); + let (io, client) = mock::new(); + + let client = client + .assert_server_handshake() + .unwrap() + .recv_custom_settings( + frames::settings() + .max_header_list_size(10) + ) + .send_frame( + frames::headers(1) + .request("GET", "https://example.com/") + .field("some-header", "some-value") + .eos() + ) + .recv_frame(frames::headers(1).response(431).eos()) + .idle_ms(10) + .close(); + + let srv = server::Builder::new() + .max_header_list_size(10) + .handshake::<_, Bytes>(io) + .expect("handshake") + .and_then(|srv| { + srv.into_future() + .expect("server") + .map(|(req, _)| { + assert!(req.is_none(), "req is {:?}", req); + }) + }); + + srv.join(client).wait().expect("wait"); +} + +#[test] +fn too_big_headers_sends_reset_after_431_if_not_eos() { + let _ = ::env_logger::init(); + let (io, client) = mock::new(); + + let client = client + .assert_server_handshake() + .unwrap() + .recv_custom_settings( + frames::settings() + .max_header_list_size(10) + ) + .send_frame( + frames::headers(1) + .request("GET", "https://example.com/") + .field("some-header", "some-value") + ) + .recv_frame(frames::headers(1).response(431).eos()) + .recv_frame(frames::reset(1).refused()) + .close(); + + let srv = server::Builder::new() + .max_header_list_size(10) + .handshake::<_, Bytes>(io) + .expect("handshake") + .and_then(|srv| { + srv.into_future() + .expect("server") + .map(|(req, _)| { + assert!(req.is_none(), "req is {:?}", req); + }) + }); + + 
srv.join(client).wait().expect("wait"); +} diff --git a/tests/support/frames.rs b/tests/support/frames.rs index 3a25bc3..4a50ff0 100644 --- a/tests/support/frames.rs +++ b/tests/support/frames.rs @@ -152,6 +152,10 @@ impl Mock { self } + pub fn into_fields(self) -> HeaderMap { + self.0.into_parts().1 + } + fn into_parts(self) -> (StreamId, frame::Pseudo, HeaderMap) { assert!(!self.0.is_end_stream(), "eos flag will be lost"); assert!(self.0.is_end_headers(), "unset eoh will be lost"); @@ -304,6 +308,11 @@ impl Mock { self.0.set_initial_window_size(Some(val)); self } + + pub fn max_header_list_size(mut self, val: u32) -> Self { + self.0.set_max_header_list_size(Some(val)); + self + } } impl From> for frame::Settings { diff --git a/tests/support/mock.rs b/tests/support/mock.rs index 3614359..d12dd4f 100644 --- a/tests/support/mock.rs +++ b/tests/support/mock.rs @@ -394,12 +394,13 @@ pub trait HandleFutureExt { self.recv_custom_settings(frame::Settings::default()) } - fn recv_custom_settings(self, settings: frame::Settings) + fn recv_custom_settings(self, settings: T) -> RecvFrame, Handle), Error = ()>>> where Self: Sized + 'static, Self: Future, Self::Error: fmt::Debug, + T: Into, { let map = self .map(|(settings, handle)| (Some(settings.into()), handle)) @@ -409,7 +410,7 @@ pub trait HandleFutureExt { Box::new(map); RecvFrame { inner: boxed, - frame: settings.into(), + frame: settings.into().into(), } } diff --git a/tests/support/mod.rs b/tests/support/mod.rs index 732fdef..ec61244 100644 --- a/tests/support/mod.rs +++ b/tests/support/mod.rs @@ -64,3 +64,4 @@ pub type Codec = h2::Codec>; // This is the frame type that is sent pub type SendFrame = h2::frame::Frame<::std::io::Cursor<::bytes::Bytes>>; + diff --git a/tests/support/prelude.rs b/tests/support/prelude.rs index 3bf75b8..272fd98 100644 --- a/tests/support/prelude.rs +++ b/tests/support/prelude.rs @@ -85,3 +85,28 @@ where } } } + +pub fn build_large_headers() -> Vec<(&'static str, String)> { + vec![ + ("one", 
"hello".to_string()), + ("two", build_large_string('2', 4 * 1024)), + ("three", "three".to_string()), + ("four", build_large_string('4', 4 * 1024)), + ("five", "five".to_string()), + ("six", build_large_string('6', 4 * 1024)), + ("seven", "seven".to_string()), + ("eight", build_large_string('8', 4 * 1024)), + ("nine", "nine".to_string()), + ("ten", build_large_string('0', 4 * 1024)), + ] +} + +fn build_large_string(ch: char, len: usize) -> String { + let mut ret = String::new(); + + for _ in 0..len { + ret.push(ch); + } + + ret +}