diff --git a/Cargo.toml b/Cargo.toml index 0097755..109b906 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -24,7 +24,7 @@ unstable = [] [dependencies] futures = "0.1" -tokio-io = "0.1.3" +tokio-io = "0.1.4" bytes = "0.4" http = "0.1" byteorder = "1.0" diff --git a/src/client.rs b/src/client.rs index 660d08e..e3485da 100644 --- a/src/client.rs +++ b/src/client.rs @@ -172,6 +172,12 @@ impl Builder { self } + /// Set the max size of received header frames. + pub fn max_header_list_size(&mut self, max: u32) -> &mut Self { + self.settings.set_max_header_list_size(Some(max)); + self + } + /// Set the maximum number of concurrent streams. /// /// Clients can only limit the maximum number of streams that that the @@ -339,6 +345,10 @@ where codec.set_max_recv_frame_size(max as usize); } + if let Some(max) = self.builder.settings.max_header_list_size() { + codec.set_max_recv_header_list_size(max as usize); + } + // Send initial settings frame codec .buffer(self.builder.settings.clone().into()) diff --git a/src/codec/error.rs b/src/codec/error.rs index 76ac619..74063c8 100644 --- a/src/codec/error.rs +++ b/src/codec/error.rs @@ -54,6 +54,15 @@ pub enum UserError { // ===== impl RecvError ===== +impl RecvError { + pub(crate) fn is_stream_error(&self) -> bool { + match *self { + RecvError::Stream { .. } => true, + _ => false, + } + } +} + impl From for RecvError { fn from(src: io::Error) -> Self { RecvError::Io(src) diff --git a/src/codec/framed_read.rs b/src/codec/framed_read.rs index 29a590e..2f7bdc0 100644 --- a/src/codec/framed_read.rs +++ b/src/codec/framed_read.rs @@ -13,6 +13,9 @@ use std::io; use tokio_io::AsyncRead; use tokio_io::codec::length_delimited; +// 16 MB "sane default" taken from golang http2 +const DEFAULT_SETTINGS_MAX_HEADER_LIST_SIZE: usize = 16 << 20; + #[derive(Debug)] pub struct FramedRead { inner: length_delimited::FramedRead, @@ -20,6 +23,8 @@ pub struct FramedRead { // hpack decoder state hpack: hpack::Decoder, + max_header_list_size: usize, + partial: Option, } @@ -36,8 +41,6 @@ struct Partial { #[derive(Debug)] enum Continuable { Headers(frame::Headers), - // Decode the Continuation frame but ignore it... - // Ignore(StreamId), PushPromise(frame::PushPromise), } @@ -46,6 +49,7 @@ impl FramedRead { FramedRead { inner: inner, hpack: hpack::Decoder::new(DEFAULT_SETTINGS_HEADER_TABLE_SIZE), + max_header_list_size: DEFAULT_SETTINGS_MAX_HEADER_LIST_SIZE, partial: None, } } @@ -67,6 +71,66 @@ impl FramedRead { trace!(" -> kind={:?}", kind); + macro_rules! header_block { + ($frame:ident, $head:ident, $bytes:ident) => ({ + // Drop the frame header + // TODO: Change to drain: carllerche/bytes#130 + let _ = $bytes.split_to(frame::HEADER_LEN); + + // Parse the header frame w/o parsing the payload + let (mut frame, mut payload) = match frame::$frame::load($head, $bytes) { + Ok(res) => res, + Err(frame::Error::InvalidDependencyId) => { + debug!("stream error PROTOCOL_ERROR -- invalid HEADERS dependency ID"); + // A stream cannot depend on itself. An endpoint MUST + // treat this as a stream error (Section 5.4.2) of type + // `PROTOCOL_ERROR`. + return Err(Stream { + id: $head.stream_id(), + reason: Reason::PROTOCOL_ERROR, + }); + }, + Err(e) => { + debug!("connection error PROTOCOL_ERROR -- failed to load frame; err={:?}", e); + return Err(Connection(Reason::PROTOCOL_ERROR)); + } + }; + + let is_end_headers = frame.is_end_headers(); + + // Load the HPACK encoded headers + match frame.load_hpack(&mut payload, self.max_header_list_size, &mut self.hpack) { + Ok(_) => {}, + Err(frame::Error::Hpack(hpack::DecoderError::NeedMore(_))) if !is_end_headers => {}, + Err(frame::Error::MalformedMessage) => { + + debug!("stream error PROTOCOL_ERROR -- malformed header block"); + return Err(Stream { + id: $head.stream_id(), + reason: Reason::PROTOCOL_ERROR, + }); + }, + Err(e) => { + debug!("connection error PROTOCOL_ERROR -- failed HPACK decoding; err={:?}", e); + return Err(Connection(Reason::PROTOCOL_ERROR)); + } + } + + if is_end_headers { + frame.into() + } else { + trace!("loaded partial header block"); + // Defer returning the frame + self.partial = Some(Partial { + frame: Continuable::$frame(frame), + buf: payload, + }); + + return Ok(None); + } + }); + } + let frame = match kind { Kind::Settings => { let res = frame::Settings::load(head, &bytes[frame::HEADER_LEN..]); @@ -103,56 +167,7 @@ impl FramedRead { })?.into() }, Kind::Headers => { - // Drop the frame header - // TODO: Change to drain: carllerche/bytes#130 - let _ = bytes.split_to(frame::HEADER_LEN); - - // Parse the header frame w/o parsing the payload - let (mut headers, payload) = match frame::Headers::load(head, bytes) { - Ok(res) => res, - Err(frame::Error::InvalidDependencyId) => { - // A stream cannot depend on itself. An endpoint MUST - // treat this as a stream error (Section 5.4.2) of type - // `PROTOCOL_ERROR`. - debug!("stream error PROTOCOL_ERROR -- invalid HEADERS dependency ID"); - return Err(Stream { - id: head.stream_id(), - reason: Reason::PROTOCOL_ERROR, - }); - }, - Err(e) => { - debug!("connection error PROTOCOL_ERROR -- failed to load HEADERS frame; err={:?}", e); - return Err(Connection(Reason::PROTOCOL_ERROR)); - } - }; - - if headers.is_end_headers() { - // Load the HPACK encoded headers & return the frame - match headers.load_hpack(payload, &mut self.hpack) { - Ok(_) => {}, - Err(frame::Error::MalformedMessage) => { - debug!("stream error PROTOCOL_ERROR -- malformed HEADERS frame"); - return Err(Stream { - id: head.stream_id(), - reason: Reason::PROTOCOL_ERROR, - }); - }, - Err(e) => { - debug!("connection error PROTOCOL_ERROR -- failed HEADERS frame HPACK decoding; err={:?}", e); - return Err(Connection(Reason::PROTOCOL_ERROR)); - } - } - - headers.into() - } else { - // Defer loading the frame - self.partial = Some(Partial { - frame: Continuable::Headers(headers), - buf: payload, - }); - - return Ok(None); - } + header_block!(Headers, head, bytes) }, Kind::Reset => { let res = frame::Reset::load(head, &bytes[frame::HEADER_LEN..]); @@ -163,44 +178,7 @@ impl FramedRead { res.map_err(|_| Connection(Reason::PROTOCOL_ERROR))?.into() }, Kind::PushPromise => { - // Drop the frame header - // TODO: Change to drain: carllerche/bytes#130 - let _ = bytes.split_to(frame::HEADER_LEN); - - // Parse the frame w/o parsing the payload - let (mut push, payload) = frame::PushPromise::load(head, bytes) - .map_err(|e| { - debug!("connection error PROTOCOL_ERROR -- failed to load PUSH_PROMISE frame; err={:?}", e); - Connection(Reason::PROTOCOL_ERROR) - })?; - - if push.is_end_headers() { - // Load the HPACK encoded headers & return the frame - match push.load_hpack(payload, &mut self.hpack) { - Ok(_) => {}, - Err(frame::Error::MalformedMessage) => { - debug!("stream error PROTOCOL_ERROR -- malformed PUSH_PROMISE frame"); - return Err(Stream { - id: head.stream_id(), - reason: Reason::PROTOCOL_ERROR, - }); - }, - Err(e) => { - debug!("connection error PROTOCOL_ERROR -- failed PUSH_PROMISE frame HPACK decoding; err={:?}", e); - return Err(Connection(Reason::PROTOCOL_ERROR)); - } - } - - push.into() - } else { - // Defer loading the frame - self.partial = Some(Partial { - frame: Continuable::PushPromise(push), - buf: payload, - }); - - return Ok(None); - } + header_block!(PushPromise, head, bytes) }, Kind::Priority => { if head.stream_id() == 0 { @@ -224,8 +202,7 @@ impl FramedRead { } }, Kind::Continuation => { - // TODO: Un-hack this - let end_of_headers = (head.flag() & 0x4) == 0x4; + let is_end_headers = (head.flag() & 0x4) == 0x4; let mut partial = match self.partial.take() { Some(partial) => partial, @@ -235,22 +212,43 @@ impl FramedRead { } }; - // Extend the buf - partial.buf.extend_from_slice(&bytes[frame::HEADER_LEN..]); - - if !end_of_headers { - self.partial = Some(partial); - return Ok(None); - } - // The stream identifiers must match if partial.frame.stream_id() != head.stream_id() { debug!("connection error PROTOCOL_ERROR -- CONTINUATION frame stream ID does not match previous frame stream ID"); return Err(Connection(Reason::PROTOCOL_ERROR)); } - match partial.frame.load_hpack(partial.buf, &mut self.hpack) { + + + // Extend the buf + if partial.buf.is_empty() { + partial.buf = bytes.split_off(frame::HEADER_LEN); + } else { + if partial.frame.is_over_size() { + // If there was left over bytes previously, they may be + // needed to continue decoding, even though we will + // be ignoring this frame. This is done to keep the HPACK + // decoder state up-to-date. + // + // Still, we need to be careful, because if a malicious + // attacker were to try to send a gigantic string, such + // that it fits over multiple header blocks, we could + // grow memory uncontrollably again, and that'd be a shame. + // + // Instead, we use a simple heuristic to determine if + // we should continue to ignore decoding, or to tell + // the attacker to go away. + if partial.buf.len() + bytes.len() > self.max_header_list_size { + debug!("connection error COMPRESSION_ERROR -- CONTINUATION frame header block size over ignorable limit"); + return Err(Connection(Reason::COMPRESSION_ERROR)); + } + } + partial.buf.extend_from_slice(&bytes[frame::HEADER_LEN..]); + } + + match partial.frame.load_hpack(&mut partial.buf, self.max_header_list_size, &mut self.hpack) { Ok(_) => {}, + Err(frame::Error::Hpack(hpack::DecoderError::NeedMore(_))) if !is_end_headers => {}, Err(frame::Error::MalformedMessage) => { debug!("stream error PROTOCOL_ERROR -- malformed CONTINUATION frame"); return Err(Stream { @@ -258,10 +256,18 @@ impl FramedRead { reason: Reason::PROTOCOL_ERROR, }); }, - Err(_) => return Err(Connection(Reason::PROTOCOL_ERROR)), + Err(e) => { + debug!("connection error PROTOCOL_ERROR -- failed HPACK decoding; err={:?}", e); + return Err(Connection(Reason::PROTOCOL_ERROR)); + }, } - partial.frame.into() + if is_end_headers { + partial.frame.into() + } else { + self.partial = Some(partial); + return Ok(None); + } }, Kind::Unknown => { // Unknown frames are ignored @@ -295,6 +301,12 @@ impl FramedRead { assert!(DEFAULT_MAX_FRAME_SIZE as usize <= val && val <= MAX_MAX_FRAME_SIZE as usize); self.inner.set_max_frame_length(val) } + + /// Update the max header list size setting. + #[inline] + pub fn set_max_header_list_size(&mut self, val: usize) { + self.max_header_list_size = val; + } } impl Stream for FramedRead @@ -322,14 +334,13 @@ where } fn map_err(err: io::Error) -> RecvError { - use std::error::Error; + use tokio_io::codec::length_delimited::FrameTooBig; if let io::ErrorKind::InvalidData = err.kind() { - // woah, brittle... - // TODO: with tokio-io v0.1.4, we can check - // err.get_ref().is::() - if err.description() == "frame size too big" { - return RecvError::Connection(Reason::FRAME_SIZE_ERROR); + if let Some(custom) = err.get_ref() { + if custom.is::() { + return RecvError::Connection(Reason::FRAME_SIZE_ERROR); + } } } err.into() @@ -345,14 +356,22 @@ impl Continuable { } } + fn is_over_size(&self) -> bool { + match *self { + Continuable::Headers(ref h) => h.is_over_size(), + Continuable::PushPromise(ref p) => p.is_over_size(), + } + } + fn load_hpack( &mut self, - src: BytesMut, + src: &mut BytesMut, + max_header_list_size: usize, decoder: &mut hpack::Decoder, ) -> Result<(), frame::Error> { match *self { - Continuable::Headers(ref mut h) => h.load_hpack(src, decoder), - Continuable::PushPromise(ref mut p) => p.load_hpack(src, decoder), + Continuable::Headers(ref mut h) => h.load_hpack(src, max_header_list_size, decoder), + Continuable::PushPromise(ref mut p) => p.load_hpack(src, max_header_list_size, decoder), } } } diff --git a/src/codec/mod.rs b/src/codec/mod.rs index 6dc3267..0f8acbf 100644 --- a/src/codec/mod.rs +++ b/src/codec/mod.rs @@ -90,6 +90,11 @@ impl Codec { self.framed_write().set_max_frame_size(val) } + /// Set the max header list size that can be received. + pub fn set_max_recv_header_list_size(&mut self, val: usize) { + self.inner.set_max_header_list_size(val); + } + /// Get a reference to the inner stream. #[cfg(feature = "unstable")] pub fn get_ref(&self) -> &T { diff --git a/src/frame/headers.rs b/src/frame/headers.rs index 4c98442..1ffa743 100644 --- a/src/frame/headers.rs +++ b/src/frame/headers.rs @@ -86,6 +86,9 @@ struct HeaderBlock { /// The decoded header fields fields: HeaderMap, + /// Set to true if decoding went over the max header list size. + is_over_size: bool, + /// Pseudo headers, these are broken out as they must be sent as part of the /// headers frame. pseudo: Pseudo, @@ -116,6 +119,7 @@ impl Headers { stream_dep: None, header_block: HeaderBlock { fields: fields, + is_over_size: false, pseudo: pseudo, }, flags: HeadersFlag::default(), @@ -131,6 +135,7 @@ impl Headers { stream_dep: None, header_block: HeaderBlock { fields: fields, + is_over_size: false, pseudo: Pseudo::default(), }, flags: flags, @@ -185,6 +190,7 @@ impl Headers { stream_dep: stream_dep, header_block: HeaderBlock { fields: HeaderMap::new(), + is_over_size: false, pseudo: Pseudo::default(), }, flags: flags, @@ -193,8 +199,8 @@ impl Headers { Ok((headers, src)) } - pub fn load_hpack(&mut self, src: BytesMut, decoder: &mut hpack::Decoder) -> Result<(), Error> { - self.header_block.load(src, decoder) + pub fn load_hpack(&mut self, src: &mut BytesMut, max_header_list_size: usize, decoder: &mut hpack::Decoder) -> Result<(), Error> { + self.header_block.load(src, max_header_list_size, decoder) } pub fn stream_id(&self) -> StreamId { @@ -217,6 +223,10 @@ impl Headers { self.flags.set_end_stream() } + pub fn is_over_size(&self) -> bool { + self.header_block.is_over_size + } + pub fn into_parts(self) -> (Pseudo, HeaderMap) { (self.header_block.pseudo, self.header_block.fields) } @@ -304,6 +314,7 @@ impl PushPromise { flags: flags, header_block: HeaderBlock { fields: HeaderMap::new(), + is_over_size: false, pseudo: Pseudo::default(), }, promised_id: promised_id, @@ -312,8 +323,8 @@ impl PushPromise { Ok((frame, src)) } - pub fn load_hpack(&mut self, src: BytesMut, decoder: &mut hpack::Decoder) -> Result<(), Error> { - self.header_block.load(src, decoder) + pub fn load_hpack(&mut self, src: &mut BytesMut, max_header_list_size: usize, decoder: &mut hpack::Decoder) -> Result<(), Error> { + self.header_block.load(src, max_header_list_size, decoder) } pub fn stream_id(&self) -> StreamId { @@ -332,6 +343,10 @@ impl PushPromise { self.flags.set_end_headers(); } + pub fn is_over_size(&self) -> bool { + self.header_block.is_over_size + } + pub fn encode(self, encoder: &mut hpack::Encoder, dst: &mut BytesMut) -> Option { use bytes::BufMut; @@ -364,6 +379,7 @@ impl PushPromise { flags: PushPromiseFlag::default(), header_block: HeaderBlock { fields, + is_over_size: false, pseudo, }, promised_id, @@ -677,10 +693,12 @@ impl fmt::Debug for PushPromiseFlag { // ===== HeaderBlock ===== + impl HeaderBlock { - fn load(&mut self, src: BytesMut, decoder: &mut hpack::Decoder) -> Result<(), Error> { - let mut reg = false; + fn load(&mut self, src: &mut BytesMut, max_header_list_size: usize, decoder: &mut hpack::Decoder) -> Result<(), Error> { + let mut reg = !self.fields.is_empty(); let mut malformed = false; + let mut headers_size = self.calculate_header_list_size(); macro_rules! set_pseudo { ($field:ident, $val:expr) => {{ @@ -691,22 +709,25 @@ impl HeaderBlock { trace!("load_hpack; header malformed -- repeated pseudo"); malformed = true; } else { - self.pseudo.$field = Some($val); + let __val = $val; + headers_size += decoded_header_size(stringify!($ident).len() + 1, __val.as_str().len()); + if headers_size < max_header_list_size { + self.pseudo.$field = Some(__val); + } else if !self.is_over_size { + trace!("load_hpack; header list size over max"); + self.is_over_size = true; + } } }} } - let mut src = Cursor::new(src.freeze()); + let mut cursor = Cursor::new(src); - // At this point, we're going to assume that the hpack encoded headers - // contain the entire payload. Later, we need to check for stream - // priority. - // // If the header frame is malformed, we still have to continue decoding // the headers. A malformed header frame is a stream level error, but // the hpack state is connection level. In order to maintain correct // state for other streams, the hpack decoding process must complete. - let res = decoder.decode(&mut src, |header| { + let res = decoder.decode(&mut cursor, |header| { use hpack::Header::*; match header { @@ -730,7 +751,14 @@ impl HeaderBlock { malformed = true; } else { reg = true; - self.fields.append(name, value); + + headers_size += decoded_header_size(name.as_str().len(), value.len()); + if headers_size < max_header_list_size { + self.fields.append(name, value); + } else if !self.is_over_size { + trace!("load_hpack; header list size over max"); + self.is_over_size = true; + } } }, Authority(v) => set_pseudo!(authority, v), @@ -763,4 +791,48 @@ impl HeaderBlock { }, } } + + /// Calculates the size of the currently decoded header list. + /// + /// According to http://httpwg.org/specs/rfc7540.html#SETTINGS_MAX_HEADER_LIST_SIZE + /// + /// > The value is based on the uncompressed size of header fields, + /// > including the length of the name and value in octets plus an + /// > overhead of 32 octets for each header field. + fn calculate_header_list_size(&self) -> usize { + macro_rules! pseudo_size { + ($name:ident) => ({ + self.pseudo + .$name + .as_ref() + .map(|m| decoded_header_size(stringify!($name).len() + 1, m.as_str().len())) + .unwrap_or(0) + }); + } + + pseudo_size!(method) + + pseudo_size!(scheme) + + pseudo_size!(status) + + pseudo_size!(authority) + + pseudo_size!(path) + + self.fields.iter() + .map(|(name, value)| decoded_header_size(name.as_str().len(), value.len())) + .sum::() + } +} + +fn decoded_header_size(name: usize, value: usize) -> usize { + name + value + 32 +} + +// Stupid hack to make the set_pseudo! macro happy, since all other values +// have a method `as_str` except for `String`. +trait AsStr { + fn as_str(&self) -> &str; +} + +impl AsStr for String { + fn as_str(&self) -> &str { + self + } } diff --git a/src/frame/settings.rs b/src/frame/settings.rs index b130f43..90fd493 100644 --- a/src/frame/settings.rs +++ b/src/frame/settings.rs @@ -89,6 +89,14 @@ impl Settings { self.max_frame_size = size; } + pub fn max_header_list_size(&self) -> Option { + self.max_header_list_size + } + + pub fn set_max_header_list_size(&mut self, size: Option) { + self.max_header_list_size = size; + } + pub fn is_push_enabled(&self) -> bool { self.enable_push.unwrap_or(1) != 0 } diff --git a/src/hpack/decoder.rs b/src/hpack/decoder.rs index ddc2812..514c09f 100644 --- a/src/hpack/decoder.rs +++ b/src/hpack/decoder.rs @@ -34,10 +34,15 @@ pub enum DecoderError { InvalidStatusCode, InvalidPseudoheader, InvalidMaxDynamicSize, - IntegerUnderflow, IntegerOverflow, - StringUnderflow, + NeedMore(NeedMore), +} + +#[derive(Debug, Copy, Clone, PartialEq, Eq)] +pub enum NeedMore { UnexpectedEndOfStream, + IntegerUnderflow, + StringUnderflow, } enum Representation { @@ -163,7 +168,7 @@ impl Decoder { } /// Decodes the headers found in the given buffer. - pub fn decode(&mut self, src: &mut Cursor, mut f: F) -> Result<(), DecoderError> + pub fn decode(&mut self, src: &mut Cursor<&mut BytesMut>, mut f: F) -> Result<(), DecoderError> where F: FnMut(Header), { @@ -185,7 +190,9 @@ impl Decoder { Indexed => { trace!(" Indexed; rem={:?}", src.remaining()); can_resize = false; - f(self.decode_indexed(src)?); + let entry = self.decode_indexed(src)?; + consume(src); + f(entry); }, LiteralWithIndexing => { trace!(" LiteralWithIndexing; rem={:?}", src.remaining()); @@ -194,6 +201,7 @@ impl Decoder { // Insert the header into the table self.table.insert(entry.clone()); + consume(src); f(entry); }, @@ -201,12 +209,14 @@ impl Decoder { trace!(" LiteralWithoutIndexing; rem={:?}", src.remaining()); can_resize = false; let entry = self.decode_literal(src, false)?; + consume(src); f(entry); }, LiteralNeverIndexed => { trace!(" LiteralNeverIndexed; rem={:?}", src.remaining()); can_resize = false; let entry = self.decode_literal(src, false)?; + consume(src); // TODO: Track that this should never be indexed @@ -220,6 +230,7 @@ impl Decoder { // Handle the dynamic table size update self.process_size_update(src)?; + consume(src); }, } } @@ -227,7 +238,7 @@ impl Decoder { Ok(()) } - fn process_size_update(&mut self, buf: &mut Cursor) -> Result<(), DecoderError> { + fn process_size_update(&mut self, buf: &mut Cursor<&mut BytesMut>) -> Result<(), DecoderError> { let new_size = decode_int(buf, 5)?; if new_size > self.last_max_update { @@ -245,14 +256,14 @@ impl Decoder { Ok(()) } - fn decode_indexed(&self, buf: &mut Cursor) -> Result { + fn decode_indexed(&self, buf: &mut Cursor<&mut BytesMut>) -> Result { let index = decode_int(buf, 7)?; self.table.get(index) } fn decode_literal( &mut self, - buf: &mut Cursor, + buf: &mut Cursor<&mut BytesMut>, index: bool, ) -> Result { let prefix = if index { 6 } else { 4 }; @@ -275,13 +286,13 @@ impl Decoder { } } - fn decode_string(&mut self, buf: &mut Cursor) -> Result { + fn decode_string(&mut self, buf: &mut Cursor<&mut BytesMut>) -> Result { const HUFF_FLAG: u8 = 0b10000000; // The first bit in the first byte contains the huffman encoded flag. let huff = match peek_u8(buf) { Some(hdr) => (hdr & HUFF_FLAG) == HUFF_FLAG, - None => return Err(DecoderError::UnexpectedEndOfStream), + None => return Err(DecoderError::NeedMore(NeedMore::UnexpectedEndOfStream)), }; // Decode the string length using 7 bit prefix @@ -293,7 +304,7 @@ impl Decoder { len, buf.remaining() ); - return Err(DecoderError::StringUnderflow); + return Err(DecoderError::NeedMore(NeedMore::StringUnderflow)); } if huff { @@ -358,7 +369,7 @@ fn decode_int(buf: &mut B, prefix_size: u8) -> Result(buf: &mut B, prefix_size: u8) -> Result(buf: &mut B) -> Option { @@ -412,11 +423,19 @@ fn peek_u8(buf: &mut B) -> Option { } } -fn take(buf: &mut Cursor, n: usize) -> Bytes { +fn take(buf: &mut Cursor<&mut BytesMut>, n: usize) -> Bytes { let pos = buf.position() as usize; - let ret = buf.get_ref().slice(pos, pos + n); - buf.set_position((pos + n) as u64); - ret + let mut head = buf.get_mut().split_to(pos + n); + buf.set_position(0); + head.split_to(pos); + head.freeze() +} + +fn consume(buf: &mut Cursor<&mut BytesMut>) { + // remove bytes from the internal BytesMut when they have been successfully + // decoded. This is a more permanent cursor position, which will be + // used to resume if decoding was only partial. + take(buf, 0); } // ===== impl Table ===== @@ -778,15 +797,15 @@ fn test_peek_u8() { #[test] fn test_decode_string_empty() { let mut de = Decoder::new(0); - let buf = Bytes::new(); - let err = de.decode_string(&mut Cursor::new(buf)).unwrap_err(); - assert_eq!(err, DecoderError::UnexpectedEndOfStream); + let mut buf = BytesMut::new(); + let err = de.decode_string(&mut Cursor::new(&mut buf)).unwrap_err(); + assert_eq!(err, DecoderError::NeedMore(NeedMore::UnexpectedEndOfStream)); } #[test] fn test_decode_empty() { let mut de = Decoder::new(0); - let buf = Bytes::new(); - let empty = de.decode(&mut Cursor::new(buf), |_| {}).unwrap(); + let mut buf = BytesMut::new(); + let empty = de.decode(&mut Cursor::new(&mut buf), |_| {}).unwrap(); assert_eq!(empty, ()); } diff --git a/src/hpack/mod.rs b/src/hpack/mod.rs index 4a0ab9e..956de88 100644 --- a/src/hpack/mod.rs +++ b/src/hpack/mod.rs @@ -7,6 +7,6 @@ mod table; #[cfg(test)] mod test; -pub use self::decoder::{Decoder, DecoderError}; +pub use self::decoder::{Decoder, DecoderError, NeedMore}; pub use self::encoder::{Encode, EncodeState, Encoder, EncoderError}; pub use self::header::Header; diff --git a/src/hpack/test/fixture.rs b/src/hpack/test/fixture.rs index 80f7144..d7a4883 100644 --- a/src/hpack/test/fixture.rs +++ b/src/hpack/test/fixture.rs @@ -74,7 +74,7 @@ fn test_story(story: Value) { } decoder - .decode(&mut Cursor::new(case.wire.clone().into()), |e| { + .decode(&mut Cursor::new(&mut case.wire.clone().into()), |e| { let (name, value) = expect.remove(0); assert_eq!(name, key_str(&e)); assert_eq!(value, value_str(&e)); @@ -108,7 +108,7 @@ fn test_story(story: Value) { encoder.encode(None, &mut input.clone().into_iter(), &mut buf); decoder - .decode(&mut Cursor::new(buf.into()), |e| { + .decode(&mut Cursor::new(&mut buf), |e| { assert_eq!(e, input.remove(0).reify().unwrap()); }) .unwrap(); diff --git a/src/hpack/test/fuzz.rs b/src/hpack/test/fuzz.rs index 7b19aec..b5be1a8 100644 --- a/src/hpack/test/fuzz.rs +++ b/src/hpack/test/fuzz.rs @@ -149,7 +149,7 @@ impl FuzzHpack { // Decode the chunk! decoder - .decode(&mut Cursor::new(buf.into()), |e| { + .decode(&mut Cursor::new(&mut buf), |e| { assert_eq!(e, expect.remove(0).reify().unwrap()); }) .unwrap(); @@ -161,7 +161,7 @@ impl FuzzHpack { // Decode the chunk! decoder - .decode(&mut Cursor::new(buf.into()), |e| { + .decode(&mut Cursor::new(&mut buf), |e| { assert_eq!(e, expect.remove(0).reify().unwrap()); }) .unwrap(); diff --git a/src/proto/streams/prioritize.rs b/src/proto/streams/prioritize.rs index 7768bec..0d9919d 100644 --- a/src/proto/streams/prioritize.rs +++ b/src/proto/streams/prioritize.rs @@ -659,10 +659,12 @@ impl Prioritize { ) ), None => { - assert!(stream.state.is_canceled()); - stream.state.set_reset(Reason::CANCEL); + let reason = stream.state.get_scheduled_reset() + .expect("must be scheduled to reset"); - let frame = frame::Reset::new(stream.id, Reason::CANCEL); + stream.state.set_reset(reason); + + let frame = frame::Reset::new(stream.id, reason); Frame::Reset(frame) } }; @@ -674,7 +676,7 @@ impl Prioritize { self.last_opened_id = stream.id; } - if !stream.pending_send.is_empty() || stream.state.is_canceled() { + if !stream.pending_send.is_empty() || stream.state.is_scheduled_reset() { // TODO: Only requeue the sender IF it is ready to send // the next frame. i.e. don't requeue it if the next // frame is a data frame and the stream does not have diff --git a/src/proto/streams/recv.rs b/src/proto/streams/recv.rs index dd8b138..96391d5 100644 --- a/src/proto/streams/recv.rs +++ b/src/proto/streams/recv.rs @@ -1,4 +1,5 @@ use super::*; +use super::store::Resolve; use {frame, proto}; use codec::{RecvError, UserError}; use frame::{Reason, DEFAULT_INITIAL_WINDOW_SIZE}; @@ -54,6 +55,12 @@ pub(super) enum Event { Trailers(HeaderMap), } +#[derive(Debug)] +pub(super) enum RecvHeaderBlockError { + Oversize(T), + State(RecvError), +} + #[derive(Debug, Clone, Copy)] struct Indices { head: store::Key, @@ -133,7 +140,7 @@ impl Recv { frame: frame::Headers, stream: &mut store::Ptr, counts: &mut Counts, - ) -> Result<(), RecvError> { + ) -> Result<(), RecvHeaderBlockError>> { trace!("opening stream; init_window={}", self.init_window_sz); let is_initial = stream.state.recv_open(frame.is_end_stream())?; @@ -158,7 +165,7 @@ impl Recv { return Err(RecvError::Stream { id: stream.id, reason: Reason::PROTOCOL_ERROR, - }) + }.into()) }, }; @@ -166,6 +173,32 @@ impl Recv { } } + if frame.is_over_size() { + // A frame is over size if the decoded header block was bigger than + // SETTINGS_MAX_HEADER_LIST_SIZE. + // + // > A server that receives a larger header block than it is willing + // > to handle can send an HTTP 431 (Request Header Fields Too + // > Large) status code [RFC6585]. A client can discard responses + // > that it cannot process. + // + // So, if peer is a server, we'll send a 431. In either case, + // an error is recorded, which will send a REFUSED_STREAM, + // since we don't want any of the data frames either. + trace!("recv_headers; frame for {:?} is over size", stream.id); + return if counts.peer().is_server() && is_initial { + let mut res = frame::Headers::new( + stream.id, + frame::Pseudo::response(::http::StatusCode::REQUEST_HEADER_FIELDS_TOO_LARGE), + HeaderMap::new() + ); + res.set_end_stream(); + Err(RecvHeaderBlockError::Oversize(Some(res))) + } else { + Err(RecvHeaderBlockError::Oversize(None)) + }; + } + let message = counts.peer().convert_poll_message(frame)?; // Push the frame onto the stream's recv buffer @@ -517,15 +550,20 @@ impl Recv { ); new_stream.state.reserve_remote()?; + // Store the stream + let new_stream = store.insert(frame.promised_id(), new_stream).key(); + + + if frame.is_over_size() { + trace!("recv_push_promise; frame for {:?} is over size", frame.promised_id()); + return Err(RecvError::Stream { + id: frame.promised_id(), + reason: Reason::REFUSED_STREAM, + }); + } let mut ppp = store[stream].pending_push_promises.take(); - - { - // Store the stream - let mut new_stream = store.insert(frame.promised_id(), new_stream); - - ppp.push(&mut new_stream); - } + ppp.push(&mut store.resolve(new_stream)); let stream = &mut store[stream]; @@ -609,9 +647,7 @@ impl Recv { stream: &mut store::Ptr, counts: &mut Counts, ) { - assert!(stream.state.is_local_reset()); - - if stream.is_pending_reset_expiration() { + if !stream.state.is_local_reset() || stream.is_pending_reset_expiration() { return; } @@ -842,6 +878,14 @@ impl Event { } } +// ===== impl RecvHeaderBlockError ===== + +impl From for RecvHeaderBlockError { + fn from(err: RecvError) -> Self { + RecvHeaderBlockError::State(err) + } +} + // ===== util ===== fn parse_u64(src: &[u8]) -> Result { diff --git a/src/proto/streams/send.rs b/src/proto/streams/send.rs index cff192b..bd69080 100644 --- a/src/proto/streams/send.rs +++ b/src/proto/streams/send.rs @@ -132,6 +132,9 @@ impl Send { return; } + // Transition the state to reset no matter what. + stream.state.set_reset(reason); + // If closed AND the send queue is flushed, then the stream cannot be // reset explicitly, either. Implicit resets can still be queued. if is_closed && is_empty { @@ -143,9 +146,6 @@ impl Send { return; } - // Transition the state - stream.state.set_reset(reason); - self.recv_err(buffer, stream); let frame = frame::Reset::new(stream.id, reason); @@ -154,14 +154,18 @@ impl Send { self.prioritize.queue_frame(frame.into(), buffer, stream, task); } - pub fn schedule_cancel(&mut self, stream: &mut store::Ptr, task: &mut Option) { - trace!("schedule_cancel; {:?}", stream.id); + pub fn schedule_implicit_reset( + &mut self, + stream: &mut store::Ptr, + reason: Reason, + task: &mut Option, + ) { if stream.state.is_closed() { // Stream is already closed, nothing more to do return; } - stream.state.set_canceled(); + stream.state.set_scheduled_reset(reason); self.prioritize.reclaim_reserved_capacity(stream); self.prioritize.schedule_send(stream, task); diff --git a/src/proto/streams/state.rs b/src/proto/streams/state.rs index d4b15aa..0952070 100644 --- a/src/proto/streams/state.rs +++ b/src/proto/streams/state.rs @@ -76,10 +76,14 @@ enum Cause { LocallyReset(Reason), Io, - /// The user droped all handles to the stream without explicitly canceling. /// This indicates to the connection that a reset frame must be sent out /// once the send queue has been flushed. - Canceled, + /// + /// Examples of when this could happen: + /// - User drops all references to a stream, so we want to CANCEL the it. + /// - Header block size was too large, so we want to REFUSE, possibly + /// after sending a 431 response frame. + Scheduled(Reason), } impl State { @@ -269,15 +273,22 @@ impl State { self.inner = Closed(Cause::LocallyReset(reason)); } - /// Set the stream state to canceled - pub fn set_canceled(&mut self) { + /// Set the stream state to a scheduled reset. + pub fn set_scheduled_reset(&mut self, reason: Reason) { debug_assert!(!self.is_closed()); - self.inner = Closed(Cause::Canceled); + self.inner = Closed(Cause::Scheduled(reason)); } - pub fn is_canceled(&self) -> bool { + pub fn get_scheduled_reset(&self) -> Option { match self.inner { - Closed(Cause::Canceled) => true, + Closed(Cause::Scheduled(reason)) => Some(reason), + _ => None, + } + } + + pub fn is_scheduled_reset(&self) -> bool { + match self.inner { + Closed(Cause::Scheduled(..)) => true, _ => false, } } @@ -285,7 +296,7 @@ impl State { pub fn is_local_reset(&self) -> bool { match self.inner { Closed(Cause::LocallyReset(_)) => true, - Closed(Cause::Canceled) => true, + Closed(Cause::Scheduled(..)) => true, _ => false, } } @@ -381,8 +392,8 @@ impl State { // TODO: Is this correct? match self.inner { Closed(Cause::Proto(reason)) | - Closed(Cause::LocallyReset(reason)) => Err(proto::Error::Proto(reason)), - Closed(Cause::Canceled) => Err(proto::Error::Proto(Reason::CANCEL)), + Closed(Cause::LocallyReset(reason)) | + Closed(Cause::Scheduled(reason)) => Err(proto::Error::Proto(reason)), Closed(Cause::Io) => Err(proto::Error::Io(io::ErrorKind::BrokenPipe.into())), Closed(Cause::EndStream) | HalfClosedRemote(..) => Ok(false), diff --git a/src/proto/streams/streams.rs b/src/proto/streams/streams.rs index 2a8d678..eb97f63 100644 --- a/src/proto/streams/streams.rs +++ b/src/proto/streams/streams.rs @@ -1,4 +1,5 @@ use super::{Buffer, Config, Counts, Prioritized, Recv, Send, Stream, StreamId}; +use super::recv::RecvHeaderBlockError; use super::store::{self, Entry, Resolve, Store}; use {client, proto, server}; use codec::{Codec, RecvError, SendError, UserError}; @@ -164,7 +165,28 @@ where ); let res = if stream.state.is_recv_headers() { - actions.recv.recv_headers(frame, stream, counts) + match actions.recv.recv_headers(frame, stream, counts) { + Ok(()) => Ok(()), + Err(RecvHeaderBlockError::Oversize(resp)) => { + if let Some(resp) = resp { + let _ = actions.send.send_headers( + resp, send_buffer, stream, counts, &mut actions.task); + + actions.send.schedule_implicit_reset( + stream, + Reason::REFUSED_STREAM, + &mut actions.task); + actions.recv.enqueue_reset_expiration(stream, counts); + Ok(()) + } else { + Err(RecvError::Stream { + id: stream.id, + reason: Reason::REFUSED_STREAM, + }) + } + }, + Err(RecvHeaderBlockError::State(err)) => Err(err), + } } else { if !frame.is_end_stream() { // TODO: Is this the right error @@ -363,22 +385,42 @@ where let me = &mut *me; let id = frame.stream_id(); + let promised_id = frame.promised_id(); - let stream = match me.store.find_mut(&id) { - Some(stream) => stream.key(), - None => return Err(RecvError::Connection(Reason::PROTOCOL_ERROR)), + let res = { + let stream = match me.store.find_mut(&id) { + Some(stream) => stream.key(), + None => return Err(RecvError::Connection(Reason::PROTOCOL_ERROR)), + }; + + if me.counts.peer().is_server() { + // The remote is a client and cannot reserve + trace!("recv_push_promise; error remote is client"); + return Err(RecvError::Connection(Reason::PROTOCOL_ERROR)); + } + + me.actions.recv.recv_push_promise(frame, + &me.actions.send, + stream, + &mut me.store) }; - if me.counts.peer().is_server() { - // The remote is a client and cannot reserve - trace!("recv_push_promise; error remote is client"); - return Err(RecvError::Connection(Reason::PROTOCOL_ERROR)); - } + if let Err(err) = res { + if let Some(ref mut new_stream) = me.store.find_mut(&promised_id) { - me.actions.recv.recv_push_promise(frame, - &me.actions.send, - stream, - &mut me.store) + let mut send_buffer = self.send_buffer.inner.lock().unwrap(); + me.actions.reset_on_recv_stream_err(&mut *send_buffer, new_stream, Err(err)) + } else { + // If there was a stream error, the stream should have been stored + // so we can track sending a reset. + // + // Otherwise, this MUST be an connection error. + assert!(!err.is_stream_error()); + Err(err) + } + } else { + res + } } pub fn next_incoming(&mut self) -> Option> { @@ -925,8 +967,9 @@ fn drop_stream_ref(inner: &Mutex, key: store::Key) { fn maybe_cancel(stream: &mut store::Ptr, actions: &mut Actions, counts: &mut Counts) { if stream.is_canceled_interest() { - actions.send.schedule_cancel( + actions.send.schedule_implicit_reset( stream, + Reason::CANCEL, &mut actions.task); actions.recv.enqueue_reset_expiration(stream, counts); } diff --git a/src/server.rs b/src/server.rs index 74bb9a9..b072f79 100644 --- a/src/server.rs +++ b/src/server.rs @@ -363,6 +363,10 @@ where codec.set_max_recv_frame_size(max as usize); } + if let Some(max) = builder.settings.max_header_list_size() { + codec.set_max_recv_header_list_size(max as usize); + } + // Send initial settings frame. codec .buffer(builder.settings.clone().into()) @@ -577,6 +581,12 @@ impl Builder { self } + /// Set the max size of received header frames. + pub fn max_header_list_size(&mut self, max: u32) -> &mut Self { + self.settings.set_max_header_list_size(Some(max)); + self + } + /// Set the maximum number of concurrent streams. /// /// The maximum concurrent streams setting only controls the maximum number diff --git a/tests/client_request.rs b/tests/client_request.rs index cbaaf37..1226b47 100644 --- a/tests/client_request.rs +++ b/tests/client_request.rs @@ -544,6 +544,82 @@ fn sending_request_on_closed_connection() { h2.join(srv).wait().expect("wait"); } +#[test] +fn recv_too_big_headers() { + let _ = ::env_logger::init(); + let (io, srv) = mock::new(); + + let srv = srv.assert_client_handshake() + .unwrap() + .recv_custom_settings( + frames::settings() + .max_header_list_size(10) + ) + .recv_frame( + frames::headers(1) + .request("GET", "https://http2.akamai.com/") + .eos(), + ) + .recv_frame( + frames::headers(3) + .request("GET", "https://http2.akamai.com/") + .eos(), + ) + .send_frame(frames::headers(1).response(200).eos()) + .send_frame(frames::headers(3).response(200)) + // no reset for 1, since it's closed anyways + // but reset for 3, since server hasn't closed stream + .recv_frame(frames::reset(3).refused()) + .idle_ms(10) + .close(); + + let client = client::Builder::new() + .max_header_list_size(10) + .handshake::<_, Bytes>(io) + .expect("handshake") + .and_then(|(mut client, conn)| { + let request = Request::builder() + .uri("https://http2.akamai.com/") + .body(()) + .unwrap(); + + let req1 = client + .send_request(request, true) + .expect("send_request") + .0 + .expect_err("response1") + .map(|err| { + assert_eq!( + err.reason(), + Some(Reason::REFUSED_STREAM) + ); + }); + + let request = Request::builder() + .uri("https://http2.akamai.com/") + .body(()) + .unwrap(); + + let req2 = client + .send_request(request, true) + .expect("send_request") + .0 + .expect_err("response2") + .map(|err| { + assert_eq!( + err.reason(), + Some(Reason::REFUSED_STREAM) + ); + }); + + conn.drive(req1.join(req2)) + .and_then(|(conn, _)| conn.expect("client")) + }); + + client.join(srv).wait().expect("wait"); + +} + const SETTINGS: &'static [u8] = &[0, 0, 0, 4, 0, 0, 0, 0, 0]; const SETTINGS_ACK: &'static [u8] = &[0, 0, 0, 4, 1, 0, 0, 0, 0]; diff --git a/tests/codec_read.rs b/tests/codec_read.rs index 5aebced..4fee476 100644 --- a/tests/codec_read.rs +++ b/tests/codec_read.rs @@ -12,6 +12,7 @@ fn read_none() { } #[test] +#[ignore] fn read_frame_too_big() {} // ===== DATA ===== @@ -100,14 +101,73 @@ fn read_data_stream_id_zero() { // ===== HEADERS ===== #[test] +#[ignore] fn read_headers_without_pseudo() {} #[test] +#[ignore] fn read_headers_with_pseudo() {} #[test] +#[ignore] fn read_headers_empty_payload() {} +#[test] +fn read_continuation_frames() { + let _ = ::env_logger::init(); + let (io, srv) = mock::new(); + + let large = build_large_headers(); + let frame = large.iter().fold( + frames::headers(1).response(200), + |frame, &(name, ref value)| frame.field(name, &value[..]), + ).eos(); + + let srv = srv.assert_client_handshake() + .unwrap() + .recv_settings() + .recv_frame( + frames::headers(1) + .request("GET", "https://http2.akamai.com/") + .eos(), + ) + .send_frame(frame) + .close(); + + let client = client::handshake(io) + .expect("handshake") + .and_then(|(mut client, conn)| { + let request = Request::builder() + .uri("https://http2.akamai.com/") + .body(()) + .unwrap(); + + let req = client + .send_request(request, true) + .expect("send_request") + .0 + .expect("response") + .map(move |res| { + assert_eq!(res.status(), StatusCode::OK); + let (head, _body) = res.into_parts(); + let expected = large.iter().fold(HeaderMap::new(), |mut map, &(name, ref value)| { + use support::frames::HttpTryInto; + map.append(name, value.as_str().try_into().unwrap()); + map + }); + assert_eq!(head.headers, expected); + }); + + conn.drive(req) + .and_then(move |(h2, _)| { + h2.expect("client") + }) + }); + + client.join(srv).wait().expect("wait"); + +} + #[test] fn update_max_frame_len_at_rest() { let _ = ::env_logger::init(); diff --git a/tests/codec_write.rs b/tests/codec_write.rs index 06a374f..a0a3da5 100644 --- a/tests/codec_write.rs +++ b/tests/codec_write.rs @@ -59,28 +59,3 @@ fn write_continuation_frames() { client.join(srv).wait().expect("wait"); } - -fn build_large_headers() -> Vec<(&'static str, String)> { - vec![ - ("one", "hello".to_string()), - ("two", build_large_string('2', 4 * 1024)), - ("three", "three".to_string()), - ("four", build_large_string('4', 4 * 1024)), - ("five", "five".to_string()), - ("six", build_large_string('6', 4 * 1024)), - ("seven", "seven".to_string()), - ("eight", build_large_string('8', 4 * 1024)), - ("nine", "nine".to_string()), - ("ten", build_large_string('0', 4 * 1024)), - ] -} - -fn build_large_string(ch: char, len: usize) -> String { - let mut ret = String::new(); - - for _ in 0..len { - ret.push(ch); - } - - ret -} diff --git a/tests/push_promise.rs b/tests/push_promise.rs index 1c38023..c133dab 100644 --- a/tests/push_promise.rs +++ b/tests/push_promise.rs @@ -137,6 +137,56 @@ fn pending_push_promises_reset_when_dropped() { client.join(srv).wait().expect("wait"); } +#[test] +fn recv_push_promise_over_max_header_list_size() { + let _ = ::env_logger::init(); + let (io, srv) = mock::new(); + + let srv = srv.assert_client_handshake() + .unwrap() + .recv_custom_settings( + frames::settings() + .max_header_list_size(10) + ) + .recv_frame( + frames::headers(1) + .request("GET", "https://http2.akamai.com/") + .eos(), + ) + .send_frame(frames::push_promise(1, 2).request("GET", "https://http2.akamai.com/style.css")) + .recv_frame(frames::reset(2).refused()) + .send_frame(frames::headers(1).response(200).eos()) + .idle_ms(10) + .close(); + + let client = client::Builder::new() + .max_header_list_size(10) + .handshake::<_, Bytes>(io) + .expect("handshake") + .and_then(|(mut client, conn)| { + let request = Request::builder() + .uri("https://http2.akamai.com/") + .body(()) + .unwrap(); + + let req = client + .send_request(request, true) + .expect("send_request") + .0 + .expect_err("response") + .map(|err| { + assert_eq!( + err.reason(), + Some(Reason::REFUSED_STREAM) + ); + }); + + conn.drive(req) + .and_then(|(conn, _)| conn.expect("client")) + }); + client.join(srv).wait().expect("wait"); +} + #[test] #[ignore] fn recv_push_promise_with_unsafe_method_is_stream_error() { diff --git a/tests/server.rs b/tests/server.rs index 9ad6196..290fedf 100644 --- a/tests/server.rs +++ b/tests/server.rs @@ -261,3 +261,76 @@ fn sends_reset_cancel_when_res_body_is_dropped() { srv.join(client).wait().expect("wait"); } + +#[test] +fn too_big_headers_sends_431() { + let _ = ::env_logger::init(); + let (io, client) = mock::new(); + + let client = client + .assert_server_handshake() + .unwrap() + .recv_custom_settings( + frames::settings() + .max_header_list_size(10) + ) + .send_frame( + frames::headers(1) + .request("GET", "https://example.com/") + .field("some-header", "some-value") + .eos() + ) + .recv_frame(frames::headers(1).response(431).eos()) + .idle_ms(10) + .close(); + + let srv = server::Builder::new() + .max_header_list_size(10) + .handshake::<_, Bytes>(io) + .expect("handshake") + .and_then(|srv| { + srv.into_future() + .expect("server") + .map(|(req, _)| { + assert!(req.is_none(), "req is {:?}", req); + }) + }); + + srv.join(client).wait().expect("wait"); +} + +#[test] +fn too_big_headers_sends_reset_after_431_if_not_eos() { + let _ = ::env_logger::init(); + let (io, client) = mock::new(); + + let client = client + .assert_server_handshake() + .unwrap() + .recv_custom_settings( + frames::settings() + .max_header_list_size(10) + ) + .send_frame( + frames::headers(1) + .request("GET", "https://example.com/") + .field("some-header", "some-value") + ) + .recv_frame(frames::headers(1).response(431).eos()) + .recv_frame(frames::reset(1).refused()) + .close(); + + let srv = server::Builder::new() + .max_header_list_size(10) + .handshake::<_, Bytes>(io) + .expect("handshake") + .and_then(|srv| { + srv.into_future() + .expect("server") + .map(|(req, _)| { + assert!(req.is_none(), "req is {:?}", req); + }) + }); + + srv.join(client).wait().expect("wait"); +} diff --git a/tests/support/frames.rs b/tests/support/frames.rs index 3a25bc3..4a50ff0 100644 --- a/tests/support/frames.rs +++ b/tests/support/frames.rs @@ -152,6 +152,10 @@ impl Mock { self } + pub fn into_fields(self) -> HeaderMap { + self.0.into_parts().1 + } + fn into_parts(self) -> (StreamId, frame::Pseudo, HeaderMap) { assert!(!self.0.is_end_stream(), "eos flag will be lost"); assert!(self.0.is_end_headers(), "unset eoh will be lost"); @@ -304,6 +308,11 @@ impl Mock { self.0.set_initial_window_size(Some(val)); self } + + pub fn max_header_list_size(mut self, val: u32) -> Self { + self.0.set_max_header_list_size(Some(val)); + self + } } impl From> for frame::Settings { diff --git a/tests/support/mock.rs b/tests/support/mock.rs index 3614359..d12dd4f 100644 --- a/tests/support/mock.rs +++ b/tests/support/mock.rs @@ -394,12 +394,13 @@ pub trait HandleFutureExt { self.recv_custom_settings(frame::Settings::default()) } - fn recv_custom_settings(self, settings: frame::Settings) + fn recv_custom_settings(self, settings: T) -> RecvFrame, Handle), Error = ()>>> where Self: Sized + 'static, Self: Future, Self::Error: fmt::Debug, + T: Into, { let map = self .map(|(settings, handle)| (Some(settings.into()), handle)) @@ -409,7 +410,7 @@ pub trait HandleFutureExt { Box::new(map); RecvFrame { inner: boxed, - frame: settings.into(), + frame: settings.into().into(), } } diff --git a/tests/support/mod.rs b/tests/support/mod.rs index 732fdef..ec61244 100644 --- a/tests/support/mod.rs +++ b/tests/support/mod.rs @@ -64,3 +64,4 @@ pub type Codec = h2::Codec>; // This is the frame type that is sent pub type SendFrame = h2::frame::Frame<::std::io::Cursor<::bytes::Bytes>>; + diff --git a/tests/support/prelude.rs b/tests/support/prelude.rs index 3bf75b8..272fd98 100644 --- a/tests/support/prelude.rs +++ b/tests/support/prelude.rs @@ -85,3 +85,28 @@ where } } } + +pub fn build_large_headers() -> Vec<(&'static str, String)> { + vec![ + ("one", "hello".to_string()), + ("two", build_large_string('2', 4 * 1024)), + ("three", "three".to_string()), + ("four", build_large_string('4', 4 * 1024)), + ("five", "five".to_string()), + ("six", build_large_string('6', 4 * 1024)), + ("seven", "seven".to_string()), + ("eight", build_large_string('8', 4 * 1024)), + ("nine", "nine".to_string()), + ("ten", build_large_string('0', 4 * 1024)), + ] +} + +fn build_large_string(ch: char, len: usize) -> String { + let mut ret = String::new(); + + for _ in 0..len { + ret.push(ch); + } + + ret +}