SETTINGS_MAX_HEADER_LIST_SIZE (#206)

This, uh, grew into something far bigger than expected, but it turns out, all of it was needed to eventually support this correctly.

- Adds configuration to client and server to set [SETTINGS_MAX_HEADER_LIST_SIZE](http://httpwg.org/specs/rfc7540.html#SETTINGS_MAX_HEADER_LIST_SIZE)
- If not set, a "sane default" of 16 MB is used (taken from golang's http2)
- Decoding header blocks now happens as they are received, instead of buffering up possibly forever until the last continuation frame is parsed.
- As each field is decoded, it's undecoded size is added to the total. Whenever a header block goes over the maximum size, the `frame` will be marked as such.
- Whenever a header block is deemed over max limit, decoding will still continue, but new fields will not be appended to `HeaderMap`. This is also can save wasted hashing.
- To protect against enormous string literals, such that they span multiple continuation frames, a check is made that the combined encoded bytes is less than the max allowed size. While technically not exactly what the spec suggests (counting decoded size instead), this should hopefully only happen when someone is indeed malicious. If found, a `GOAWAY` of `COMPRESSION_ERROR` is sent, and the connection shut down.
- After an oversize header block frame is finished decoding, the streams state machine will notice it is oversize, and handle that.
  - If the local peer is a server, a 431 response is sent, as suggested by the spec.
  - A `REFUSED_STREAM` reset is sent, since we cannot actually give the stream to the user.
- In order to be able to send both the 431 headers frame, and a reset frame afterwards, the scheduled `Canceled` machinery was made more general to a `Scheduled(Reason)` state instead.

Closes #18 
Closes #191
This commit is contained in:
Sean McArthur
2018-01-05 09:23:48 -08:00
committed by GitHub
parent 6f7b826b0a
commit aa23a9735d
26 changed files with 752 additions and 226 deletions

View File

@@ -24,7 +24,7 @@ unstable = []
[dependencies] [dependencies]
futures = "0.1" futures = "0.1"
tokio-io = "0.1.3" tokio-io = "0.1.4"
bytes = "0.4" bytes = "0.4"
http = "0.1" http = "0.1"
byteorder = "1.0" byteorder = "1.0"

View File

@@ -172,6 +172,12 @@ impl Builder {
self self
} }
/// Set the max size of received header frames.
pub fn max_header_list_size(&mut self, max: u32) -> &mut Self {
self.settings.set_max_header_list_size(Some(max));
self
}
/// Set the maximum number of concurrent streams. /// Set the maximum number of concurrent streams.
/// ///
/// Clients can only limit the maximum number of streams that that the /// Clients can only limit the maximum number of streams that that the
@@ -339,6 +345,10 @@ where
codec.set_max_recv_frame_size(max as usize); codec.set_max_recv_frame_size(max as usize);
} }
if let Some(max) = self.builder.settings.max_header_list_size() {
codec.set_max_recv_header_list_size(max as usize);
}
// Send initial settings frame // Send initial settings frame
codec codec
.buffer(self.builder.settings.clone().into()) .buffer(self.builder.settings.clone().into())

View File

@@ -54,6 +54,15 @@ pub enum UserError {
// ===== impl RecvError ===== // ===== impl RecvError =====
impl RecvError {
pub(crate) fn is_stream_error(&self) -> bool {
match *self {
RecvError::Stream { .. } => true,
_ => false,
}
}
}
impl From<io::Error> for RecvError { impl From<io::Error> for RecvError {
fn from(src: io::Error) -> Self { fn from(src: io::Error) -> Self {
RecvError::Io(src) RecvError::Io(src)

View File

@@ -13,6 +13,9 @@ use std::io;
use tokio_io::AsyncRead; use tokio_io::AsyncRead;
use tokio_io::codec::length_delimited; use tokio_io::codec::length_delimited;
// 16 MB "sane default" taken from golang http2
const DEFAULT_SETTINGS_MAX_HEADER_LIST_SIZE: usize = 16 << 20;
#[derive(Debug)] #[derive(Debug)]
pub struct FramedRead<T> { pub struct FramedRead<T> {
inner: length_delimited::FramedRead<T>, inner: length_delimited::FramedRead<T>,
@@ -20,6 +23,8 @@ pub struct FramedRead<T> {
// hpack decoder state // hpack decoder state
hpack: hpack::Decoder, hpack: hpack::Decoder,
max_header_list_size: usize,
partial: Option<Partial>, partial: Option<Partial>,
} }
@@ -36,8 +41,6 @@ struct Partial {
#[derive(Debug)] #[derive(Debug)]
enum Continuable { enum Continuable {
Headers(frame::Headers), Headers(frame::Headers),
// Decode the Continuation frame but ignore it...
// Ignore(StreamId),
PushPromise(frame::PushPromise), PushPromise(frame::PushPromise),
} }
@@ -46,6 +49,7 @@ impl<T> FramedRead<T> {
FramedRead { FramedRead {
inner: inner, inner: inner,
hpack: hpack::Decoder::new(DEFAULT_SETTINGS_HEADER_TABLE_SIZE), hpack: hpack::Decoder::new(DEFAULT_SETTINGS_HEADER_TABLE_SIZE),
max_header_list_size: DEFAULT_SETTINGS_MAX_HEADER_LIST_SIZE,
partial: None, partial: None,
} }
} }
@@ -67,6 +71,66 @@ impl<T> FramedRead<T> {
trace!(" -> kind={:?}", kind); trace!(" -> kind={:?}", kind);
macro_rules! header_block {
($frame:ident, $head:ident, $bytes:ident) => ({
// Drop the frame header
// TODO: Change to drain: carllerche/bytes#130
let _ = $bytes.split_to(frame::HEADER_LEN);
// Parse the header frame w/o parsing the payload
let (mut frame, mut payload) = match frame::$frame::load($head, $bytes) {
Ok(res) => res,
Err(frame::Error::InvalidDependencyId) => {
debug!("stream error PROTOCOL_ERROR -- invalid HEADERS dependency ID");
// A stream cannot depend on itself. An endpoint MUST
// treat this as a stream error (Section 5.4.2) of type
// `PROTOCOL_ERROR`.
return Err(Stream {
id: $head.stream_id(),
reason: Reason::PROTOCOL_ERROR,
});
},
Err(e) => {
debug!("connection error PROTOCOL_ERROR -- failed to load frame; err={:?}", e);
return Err(Connection(Reason::PROTOCOL_ERROR));
}
};
let is_end_headers = frame.is_end_headers();
// Load the HPACK encoded headers
match frame.load_hpack(&mut payload, self.max_header_list_size, &mut self.hpack) {
Ok(_) => {},
Err(frame::Error::Hpack(hpack::DecoderError::NeedMore(_))) if !is_end_headers => {},
Err(frame::Error::MalformedMessage) => {
debug!("stream error PROTOCOL_ERROR -- malformed header block");
return Err(Stream {
id: $head.stream_id(),
reason: Reason::PROTOCOL_ERROR,
});
},
Err(e) => {
debug!("connection error PROTOCOL_ERROR -- failed HPACK decoding; err={:?}", e);
return Err(Connection(Reason::PROTOCOL_ERROR));
}
}
if is_end_headers {
frame.into()
} else {
trace!("loaded partial header block");
// Defer returning the frame
self.partial = Some(Partial {
frame: Continuable::$frame(frame),
buf: payload,
});
return Ok(None);
}
});
}
let frame = match kind { let frame = match kind {
Kind::Settings => { Kind::Settings => {
let res = frame::Settings::load(head, &bytes[frame::HEADER_LEN..]); let res = frame::Settings::load(head, &bytes[frame::HEADER_LEN..]);
@@ -103,56 +167,7 @@ impl<T> FramedRead<T> {
})?.into() })?.into()
}, },
Kind::Headers => { Kind::Headers => {
// Drop the frame header header_block!(Headers, head, bytes)
// TODO: Change to drain: carllerche/bytes#130
let _ = bytes.split_to(frame::HEADER_LEN);
// Parse the header frame w/o parsing the payload
let (mut headers, payload) = match frame::Headers::load(head, bytes) {
Ok(res) => res,
Err(frame::Error::InvalidDependencyId) => {
// A stream cannot depend on itself. An endpoint MUST
// treat this as a stream error (Section 5.4.2) of type
// `PROTOCOL_ERROR`.
debug!("stream error PROTOCOL_ERROR -- invalid HEADERS dependency ID");
return Err(Stream {
id: head.stream_id(),
reason: Reason::PROTOCOL_ERROR,
});
},
Err(e) => {
debug!("connection error PROTOCOL_ERROR -- failed to load HEADERS frame; err={:?}", e);
return Err(Connection(Reason::PROTOCOL_ERROR));
}
};
if headers.is_end_headers() {
// Load the HPACK encoded headers & return the frame
match headers.load_hpack(payload, &mut self.hpack) {
Ok(_) => {},
Err(frame::Error::MalformedMessage) => {
debug!("stream error PROTOCOL_ERROR -- malformed HEADERS frame");
return Err(Stream {
id: head.stream_id(),
reason: Reason::PROTOCOL_ERROR,
});
},
Err(e) => {
debug!("connection error PROTOCOL_ERROR -- failed HEADERS frame HPACK decoding; err={:?}", e);
return Err(Connection(Reason::PROTOCOL_ERROR));
}
}
headers.into()
} else {
// Defer loading the frame
self.partial = Some(Partial {
frame: Continuable::Headers(headers),
buf: payload,
});
return Ok(None);
}
}, },
Kind::Reset => { Kind::Reset => {
let res = frame::Reset::load(head, &bytes[frame::HEADER_LEN..]); let res = frame::Reset::load(head, &bytes[frame::HEADER_LEN..]);
@@ -163,44 +178,7 @@ impl<T> FramedRead<T> {
res.map_err(|_| Connection(Reason::PROTOCOL_ERROR))?.into() res.map_err(|_| Connection(Reason::PROTOCOL_ERROR))?.into()
}, },
Kind::PushPromise => { Kind::PushPromise => {
// Drop the frame header header_block!(PushPromise, head, bytes)
// TODO: Change to drain: carllerche/bytes#130
let _ = bytes.split_to(frame::HEADER_LEN);
// Parse the frame w/o parsing the payload
let (mut push, payload) = frame::PushPromise::load(head, bytes)
.map_err(|e| {
debug!("connection error PROTOCOL_ERROR -- failed to load PUSH_PROMISE frame; err={:?}", e);
Connection(Reason::PROTOCOL_ERROR)
})?;
if push.is_end_headers() {
// Load the HPACK encoded headers & return the frame
match push.load_hpack(payload, &mut self.hpack) {
Ok(_) => {},
Err(frame::Error::MalformedMessage) => {
debug!("stream error PROTOCOL_ERROR -- malformed PUSH_PROMISE frame");
return Err(Stream {
id: head.stream_id(),
reason: Reason::PROTOCOL_ERROR,
});
},
Err(e) => {
debug!("connection error PROTOCOL_ERROR -- failed PUSH_PROMISE frame HPACK decoding; err={:?}", e);
return Err(Connection(Reason::PROTOCOL_ERROR));
}
}
push.into()
} else {
// Defer loading the frame
self.partial = Some(Partial {
frame: Continuable::PushPromise(push),
buf: payload,
});
return Ok(None);
}
}, },
Kind::Priority => { Kind::Priority => {
if head.stream_id() == 0 { if head.stream_id() == 0 {
@@ -224,8 +202,7 @@ impl<T> FramedRead<T> {
} }
}, },
Kind::Continuation => { Kind::Continuation => {
// TODO: Un-hack this let is_end_headers = (head.flag() & 0x4) == 0x4;
let end_of_headers = (head.flag() & 0x4) == 0x4;
let mut partial = match self.partial.take() { let mut partial = match self.partial.take() {
Some(partial) => partial, Some(partial) => partial,
@@ -235,22 +212,43 @@ impl<T> FramedRead<T> {
} }
}; };
// Extend the buf
partial.buf.extend_from_slice(&bytes[frame::HEADER_LEN..]);
if !end_of_headers {
self.partial = Some(partial);
return Ok(None);
}
// The stream identifiers must match // The stream identifiers must match
if partial.frame.stream_id() != head.stream_id() { if partial.frame.stream_id() != head.stream_id() {
debug!("connection error PROTOCOL_ERROR -- CONTINUATION frame stream ID does not match previous frame stream ID"); debug!("connection error PROTOCOL_ERROR -- CONTINUATION frame stream ID does not match previous frame stream ID");
return Err(Connection(Reason::PROTOCOL_ERROR)); return Err(Connection(Reason::PROTOCOL_ERROR));
} }
match partial.frame.load_hpack(partial.buf, &mut self.hpack) {
// Extend the buf
if partial.buf.is_empty() {
partial.buf = bytes.split_off(frame::HEADER_LEN);
} else {
if partial.frame.is_over_size() {
// If there was left over bytes previously, they may be
// needed to continue decoding, even though we will
// be ignoring this frame. This is done to keep the HPACK
// decoder state up-to-date.
//
// Still, we need to be careful, because if a malicious
// attacker were to try to send a gigantic string, such
// that it fits over multiple header blocks, we could
// grow memory uncontrollably again, and that'd be a shame.
//
// Instead, we use a simple heuristic to determine if
// we should continue to ignore decoding, or to tell
// the attacker to go away.
if partial.buf.len() + bytes.len() > self.max_header_list_size {
debug!("connection error COMPRESSION_ERROR -- CONTINUATION frame header block size over ignorable limit");
return Err(Connection(Reason::COMPRESSION_ERROR));
}
}
partial.buf.extend_from_slice(&bytes[frame::HEADER_LEN..]);
}
match partial.frame.load_hpack(&mut partial.buf, self.max_header_list_size, &mut self.hpack) {
Ok(_) => {}, Ok(_) => {},
Err(frame::Error::Hpack(hpack::DecoderError::NeedMore(_))) if !is_end_headers => {},
Err(frame::Error::MalformedMessage) => { Err(frame::Error::MalformedMessage) => {
debug!("stream error PROTOCOL_ERROR -- malformed CONTINUATION frame"); debug!("stream error PROTOCOL_ERROR -- malformed CONTINUATION frame");
return Err(Stream { return Err(Stream {
@@ -258,10 +256,18 @@ impl<T> FramedRead<T> {
reason: Reason::PROTOCOL_ERROR, reason: Reason::PROTOCOL_ERROR,
}); });
}, },
Err(_) => return Err(Connection(Reason::PROTOCOL_ERROR)), Err(e) => {
debug!("connection error PROTOCOL_ERROR -- failed HPACK decoding; err={:?}", e);
return Err(Connection(Reason::PROTOCOL_ERROR));
},
} }
partial.frame.into() if is_end_headers {
partial.frame.into()
} else {
self.partial = Some(partial);
return Ok(None);
}
}, },
Kind::Unknown => { Kind::Unknown => {
// Unknown frames are ignored // Unknown frames are ignored
@@ -295,6 +301,12 @@ impl<T> FramedRead<T> {
assert!(DEFAULT_MAX_FRAME_SIZE as usize <= val && val <= MAX_MAX_FRAME_SIZE as usize); assert!(DEFAULT_MAX_FRAME_SIZE as usize <= val && val <= MAX_MAX_FRAME_SIZE as usize);
self.inner.set_max_frame_length(val) self.inner.set_max_frame_length(val)
} }
/// Update the max header list size setting.
#[inline]
pub fn set_max_header_list_size(&mut self, val: usize) {
self.max_header_list_size = val;
}
} }
impl<T> Stream for FramedRead<T> impl<T> Stream for FramedRead<T>
@@ -322,14 +334,13 @@ where
} }
fn map_err(err: io::Error) -> RecvError { fn map_err(err: io::Error) -> RecvError {
use std::error::Error; use tokio_io::codec::length_delimited::FrameTooBig;
if let io::ErrorKind::InvalidData = err.kind() { if let io::ErrorKind::InvalidData = err.kind() {
// woah, brittle... if let Some(custom) = err.get_ref() {
// TODO: with tokio-io v0.1.4, we can check if custom.is::<FrameTooBig>() {
// err.get_ref().is::<tokio_io::length_delimited::FrameTooBig>() return RecvError::Connection(Reason::FRAME_SIZE_ERROR);
if err.description() == "frame size too big" { }
return RecvError::Connection(Reason::FRAME_SIZE_ERROR);
} }
} }
err.into() err.into()
@@ -345,14 +356,22 @@ impl Continuable {
} }
} }
fn is_over_size(&self) -> bool {
match *self {
Continuable::Headers(ref h) => h.is_over_size(),
Continuable::PushPromise(ref p) => p.is_over_size(),
}
}
fn load_hpack( fn load_hpack(
&mut self, &mut self,
src: BytesMut, src: &mut BytesMut,
max_header_list_size: usize,
decoder: &mut hpack::Decoder, decoder: &mut hpack::Decoder,
) -> Result<(), frame::Error> { ) -> Result<(), frame::Error> {
match *self { match *self {
Continuable::Headers(ref mut h) => h.load_hpack(src, decoder), Continuable::Headers(ref mut h) => h.load_hpack(src, max_header_list_size, decoder),
Continuable::PushPromise(ref mut p) => p.load_hpack(src, decoder), Continuable::PushPromise(ref mut p) => p.load_hpack(src, max_header_list_size, decoder),
} }
} }
} }

View File

@@ -90,6 +90,11 @@ impl<T, B> Codec<T, B> {
self.framed_write().set_max_frame_size(val) self.framed_write().set_max_frame_size(val)
} }
/// Set the max header list size that can be received.
pub fn set_max_recv_header_list_size(&mut self, val: usize) {
self.inner.set_max_header_list_size(val);
}
/// Get a reference to the inner stream. /// Get a reference to the inner stream.
#[cfg(feature = "unstable")] #[cfg(feature = "unstable")]
pub fn get_ref(&self) -> &T { pub fn get_ref(&self) -> &T {

View File

@@ -86,6 +86,9 @@ struct HeaderBlock {
/// The decoded header fields /// The decoded header fields
fields: HeaderMap, fields: HeaderMap,
/// Set to true if decoding went over the max header list size.
is_over_size: bool,
/// Pseudo headers, these are broken out as they must be sent as part of the /// Pseudo headers, these are broken out as they must be sent as part of the
/// headers frame. /// headers frame.
pseudo: Pseudo, pseudo: Pseudo,
@@ -116,6 +119,7 @@ impl Headers {
stream_dep: None, stream_dep: None,
header_block: HeaderBlock { header_block: HeaderBlock {
fields: fields, fields: fields,
is_over_size: false,
pseudo: pseudo, pseudo: pseudo,
}, },
flags: HeadersFlag::default(), flags: HeadersFlag::default(),
@@ -131,6 +135,7 @@ impl Headers {
stream_dep: None, stream_dep: None,
header_block: HeaderBlock { header_block: HeaderBlock {
fields: fields, fields: fields,
is_over_size: false,
pseudo: Pseudo::default(), pseudo: Pseudo::default(),
}, },
flags: flags, flags: flags,
@@ -185,6 +190,7 @@ impl Headers {
stream_dep: stream_dep, stream_dep: stream_dep,
header_block: HeaderBlock { header_block: HeaderBlock {
fields: HeaderMap::new(), fields: HeaderMap::new(),
is_over_size: false,
pseudo: Pseudo::default(), pseudo: Pseudo::default(),
}, },
flags: flags, flags: flags,
@@ -193,8 +199,8 @@ impl Headers {
Ok((headers, src)) Ok((headers, src))
} }
pub fn load_hpack(&mut self, src: BytesMut, decoder: &mut hpack::Decoder) -> Result<(), Error> { pub fn load_hpack(&mut self, src: &mut BytesMut, max_header_list_size: usize, decoder: &mut hpack::Decoder) -> Result<(), Error> {
self.header_block.load(src, decoder) self.header_block.load(src, max_header_list_size, decoder)
} }
pub fn stream_id(&self) -> StreamId { pub fn stream_id(&self) -> StreamId {
@@ -217,6 +223,10 @@ impl Headers {
self.flags.set_end_stream() self.flags.set_end_stream()
} }
pub fn is_over_size(&self) -> bool {
self.header_block.is_over_size
}
pub fn into_parts(self) -> (Pseudo, HeaderMap) { pub fn into_parts(self) -> (Pseudo, HeaderMap) {
(self.header_block.pseudo, self.header_block.fields) (self.header_block.pseudo, self.header_block.fields)
} }
@@ -304,6 +314,7 @@ impl PushPromise {
flags: flags, flags: flags,
header_block: HeaderBlock { header_block: HeaderBlock {
fields: HeaderMap::new(), fields: HeaderMap::new(),
is_over_size: false,
pseudo: Pseudo::default(), pseudo: Pseudo::default(),
}, },
promised_id: promised_id, promised_id: promised_id,
@@ -312,8 +323,8 @@ impl PushPromise {
Ok((frame, src)) Ok((frame, src))
} }
pub fn load_hpack(&mut self, src: BytesMut, decoder: &mut hpack::Decoder) -> Result<(), Error> { pub fn load_hpack(&mut self, src: &mut BytesMut, max_header_list_size: usize, decoder: &mut hpack::Decoder) -> Result<(), Error> {
self.header_block.load(src, decoder) self.header_block.load(src, max_header_list_size, decoder)
} }
pub fn stream_id(&self) -> StreamId { pub fn stream_id(&self) -> StreamId {
@@ -332,6 +343,10 @@ impl PushPromise {
self.flags.set_end_headers(); self.flags.set_end_headers();
} }
pub fn is_over_size(&self) -> bool {
self.header_block.is_over_size
}
pub fn encode(self, encoder: &mut hpack::Encoder, dst: &mut BytesMut) -> Option<Continuation> { pub fn encode(self, encoder: &mut hpack::Encoder, dst: &mut BytesMut) -> Option<Continuation> {
use bytes::BufMut; use bytes::BufMut;
@@ -364,6 +379,7 @@ impl PushPromise {
flags: PushPromiseFlag::default(), flags: PushPromiseFlag::default(),
header_block: HeaderBlock { header_block: HeaderBlock {
fields, fields,
is_over_size: false,
pseudo, pseudo,
}, },
promised_id, promised_id,
@@ -677,10 +693,12 @@ impl fmt::Debug for PushPromiseFlag {
// ===== HeaderBlock ===== // ===== HeaderBlock =====
impl HeaderBlock { impl HeaderBlock {
fn load(&mut self, src: BytesMut, decoder: &mut hpack::Decoder) -> Result<(), Error> { fn load(&mut self, src: &mut BytesMut, max_header_list_size: usize, decoder: &mut hpack::Decoder) -> Result<(), Error> {
let mut reg = false; let mut reg = !self.fields.is_empty();
let mut malformed = false; let mut malformed = false;
let mut headers_size = self.calculate_header_list_size();
macro_rules! set_pseudo { macro_rules! set_pseudo {
($field:ident, $val:expr) => {{ ($field:ident, $val:expr) => {{
@@ -691,22 +709,25 @@ impl HeaderBlock {
trace!("load_hpack; header malformed -- repeated pseudo"); trace!("load_hpack; header malformed -- repeated pseudo");
malformed = true; malformed = true;
} else { } else {
self.pseudo.$field = Some($val); let __val = $val;
headers_size += decoded_header_size(stringify!($ident).len() + 1, __val.as_str().len());
if headers_size < max_header_list_size {
self.pseudo.$field = Some(__val);
} else if !self.is_over_size {
trace!("load_hpack; header list size over max");
self.is_over_size = true;
}
} }
}} }}
} }
let mut src = Cursor::new(src.freeze()); let mut cursor = Cursor::new(src);
// At this point, we're going to assume that the hpack encoded headers
// contain the entire payload. Later, we need to check for stream
// priority.
//
// If the header frame is malformed, we still have to continue decoding // If the header frame is malformed, we still have to continue decoding
// the headers. A malformed header frame is a stream level error, but // the headers. A malformed header frame is a stream level error, but
// the hpack state is connection level. In order to maintain correct // the hpack state is connection level. In order to maintain correct
// state for other streams, the hpack decoding process must complete. // state for other streams, the hpack decoding process must complete.
let res = decoder.decode(&mut src, |header| { let res = decoder.decode(&mut cursor, |header| {
use hpack::Header::*; use hpack::Header::*;
match header { match header {
@@ -730,7 +751,14 @@ impl HeaderBlock {
malformed = true; malformed = true;
} else { } else {
reg = true; reg = true;
self.fields.append(name, value);
headers_size += decoded_header_size(name.as_str().len(), value.len());
if headers_size < max_header_list_size {
self.fields.append(name, value);
} else if !self.is_over_size {
trace!("load_hpack; header list size over max");
self.is_over_size = true;
}
} }
}, },
Authority(v) => set_pseudo!(authority, v), Authority(v) => set_pseudo!(authority, v),
@@ -763,4 +791,48 @@ impl HeaderBlock {
}, },
} }
} }
/// Calculates the size of the currently decoded header list.
///
/// According to http://httpwg.org/specs/rfc7540.html#SETTINGS_MAX_HEADER_LIST_SIZE
///
/// > The value is based on the uncompressed size of header fields,
/// > including the length of the name and value in octets plus an
/// > overhead of 32 octets for each header field.
fn calculate_header_list_size(&self) -> usize {
macro_rules! pseudo_size {
($name:ident) => ({
self.pseudo
.$name
.as_ref()
.map(|m| decoded_header_size(stringify!($name).len() + 1, m.as_str().len()))
.unwrap_or(0)
});
}
pseudo_size!(method) +
pseudo_size!(scheme) +
pseudo_size!(status) +
pseudo_size!(authority) +
pseudo_size!(path) +
self.fields.iter()
.map(|(name, value)| decoded_header_size(name.as_str().len(), value.len()))
.sum::<usize>()
}
}
fn decoded_header_size(name: usize, value: usize) -> usize {
name + value + 32
}
// Stupid hack to make the set_pseudo! macro happy, since all other values
// have a method `as_str` except for `String<Bytes>`.
trait AsStr {
fn as_str(&self) -> &str;
}
impl AsStr for String<Bytes> {
fn as_str(&self) -> &str {
self
}
} }

View File

@@ -89,6 +89,14 @@ impl Settings {
self.max_frame_size = size; self.max_frame_size = size;
} }
pub fn max_header_list_size(&self) -> Option<u32> {
self.max_header_list_size
}
pub fn set_max_header_list_size(&mut self, size: Option<u32>) {
self.max_header_list_size = size;
}
pub fn is_push_enabled(&self) -> bool { pub fn is_push_enabled(&self) -> bool {
self.enable_push.unwrap_or(1) != 0 self.enable_push.unwrap_or(1) != 0
} }

View File

@@ -34,10 +34,15 @@ pub enum DecoderError {
InvalidStatusCode, InvalidStatusCode,
InvalidPseudoheader, InvalidPseudoheader,
InvalidMaxDynamicSize, InvalidMaxDynamicSize,
IntegerUnderflow,
IntegerOverflow, IntegerOverflow,
StringUnderflow, NeedMore(NeedMore),
}
#[derive(Debug, Copy, Clone, PartialEq, Eq)]
pub enum NeedMore {
UnexpectedEndOfStream, UnexpectedEndOfStream,
IntegerUnderflow,
StringUnderflow,
} }
enum Representation { enum Representation {
@@ -163,7 +168,7 @@ impl Decoder {
} }
/// Decodes the headers found in the given buffer. /// Decodes the headers found in the given buffer.
pub fn decode<F>(&mut self, src: &mut Cursor<Bytes>, mut f: F) -> Result<(), DecoderError> pub fn decode<F>(&mut self, src: &mut Cursor<&mut BytesMut>, mut f: F) -> Result<(), DecoderError>
where where
F: FnMut(Header), F: FnMut(Header),
{ {
@@ -185,7 +190,9 @@ impl Decoder {
Indexed => { Indexed => {
trace!(" Indexed; rem={:?}", src.remaining()); trace!(" Indexed; rem={:?}", src.remaining());
can_resize = false; can_resize = false;
f(self.decode_indexed(src)?); let entry = self.decode_indexed(src)?;
consume(src);
f(entry);
}, },
LiteralWithIndexing => { LiteralWithIndexing => {
trace!(" LiteralWithIndexing; rem={:?}", src.remaining()); trace!(" LiteralWithIndexing; rem={:?}", src.remaining());
@@ -194,6 +201,7 @@ impl Decoder {
// Insert the header into the table // Insert the header into the table
self.table.insert(entry.clone()); self.table.insert(entry.clone());
consume(src);
f(entry); f(entry);
}, },
@@ -201,12 +209,14 @@ impl Decoder {
trace!(" LiteralWithoutIndexing; rem={:?}", src.remaining()); trace!(" LiteralWithoutIndexing; rem={:?}", src.remaining());
can_resize = false; can_resize = false;
let entry = self.decode_literal(src, false)?; let entry = self.decode_literal(src, false)?;
consume(src);
f(entry); f(entry);
}, },
LiteralNeverIndexed => { LiteralNeverIndexed => {
trace!(" LiteralNeverIndexed; rem={:?}", src.remaining()); trace!(" LiteralNeverIndexed; rem={:?}", src.remaining());
can_resize = false; can_resize = false;
let entry = self.decode_literal(src, false)?; let entry = self.decode_literal(src, false)?;
consume(src);
// TODO: Track that this should never be indexed // TODO: Track that this should never be indexed
@@ -220,6 +230,7 @@ impl Decoder {
// Handle the dynamic table size update // Handle the dynamic table size update
self.process_size_update(src)?; self.process_size_update(src)?;
consume(src);
}, },
} }
} }
@@ -227,7 +238,7 @@ impl Decoder {
Ok(()) Ok(())
} }
fn process_size_update(&mut self, buf: &mut Cursor<Bytes>) -> Result<(), DecoderError> { fn process_size_update(&mut self, buf: &mut Cursor<&mut BytesMut>) -> Result<(), DecoderError> {
let new_size = decode_int(buf, 5)?; let new_size = decode_int(buf, 5)?;
if new_size > self.last_max_update { if new_size > self.last_max_update {
@@ -245,14 +256,14 @@ impl Decoder {
Ok(()) Ok(())
} }
fn decode_indexed(&self, buf: &mut Cursor<Bytes>) -> Result<Header, DecoderError> { fn decode_indexed(&self, buf: &mut Cursor<&mut BytesMut>) -> Result<Header, DecoderError> {
let index = decode_int(buf, 7)?; let index = decode_int(buf, 7)?;
self.table.get(index) self.table.get(index)
} }
fn decode_literal( fn decode_literal(
&mut self, &mut self,
buf: &mut Cursor<Bytes>, buf: &mut Cursor<&mut BytesMut>,
index: bool, index: bool,
) -> Result<Header, DecoderError> { ) -> Result<Header, DecoderError> {
let prefix = if index { 6 } else { 4 }; let prefix = if index { 6 } else { 4 };
@@ -275,13 +286,13 @@ impl Decoder {
} }
} }
fn decode_string(&mut self, buf: &mut Cursor<Bytes>) -> Result<Bytes, DecoderError> { fn decode_string(&mut self, buf: &mut Cursor<&mut BytesMut>) -> Result<Bytes, DecoderError> {
const HUFF_FLAG: u8 = 0b10000000; const HUFF_FLAG: u8 = 0b10000000;
// The first bit in the first byte contains the huffman encoded flag. // The first bit in the first byte contains the huffman encoded flag.
let huff = match peek_u8(buf) { let huff = match peek_u8(buf) {
Some(hdr) => (hdr & HUFF_FLAG) == HUFF_FLAG, Some(hdr) => (hdr & HUFF_FLAG) == HUFF_FLAG,
None => return Err(DecoderError::UnexpectedEndOfStream), None => return Err(DecoderError::NeedMore(NeedMore::UnexpectedEndOfStream)),
}; };
// Decode the string length using 7 bit prefix // Decode the string length using 7 bit prefix
@@ -293,7 +304,7 @@ impl Decoder {
len, len,
buf.remaining() buf.remaining()
); );
return Err(DecoderError::StringUnderflow); return Err(DecoderError::NeedMore(NeedMore::StringUnderflow));
} }
if huff { if huff {
@@ -358,7 +369,7 @@ fn decode_int<B: Buf>(buf: &mut B, prefix_size: u8) -> Result<usize, DecoderErro
} }
if !buf.has_remaining() { if !buf.has_remaining() {
return Err(DecoderError::IntegerUnderflow); return Err(DecoderError::NeedMore(NeedMore::IntegerUnderflow));
} }
let mask = if prefix_size == 8 { let mask = if prefix_size == 8 {
@@ -401,7 +412,7 @@ fn decode_int<B: Buf>(buf: &mut B, prefix_size: u8) -> Result<usize, DecoderErro
} }
} }
Err(DecoderError::IntegerUnderflow) Err(DecoderError::NeedMore(NeedMore::IntegerUnderflow))
} }
fn peek_u8<B: Buf>(buf: &mut B) -> Option<u8> { fn peek_u8<B: Buf>(buf: &mut B) -> Option<u8> {
@@ -412,11 +423,19 @@ fn peek_u8<B: Buf>(buf: &mut B) -> Option<u8> {
} }
} }
fn take(buf: &mut Cursor<Bytes>, n: usize) -> Bytes { fn take(buf: &mut Cursor<&mut BytesMut>, n: usize) -> Bytes {
let pos = buf.position() as usize; let pos = buf.position() as usize;
let ret = buf.get_ref().slice(pos, pos + n); let mut head = buf.get_mut().split_to(pos + n);
buf.set_position((pos + n) as u64); buf.set_position(0);
ret head.split_to(pos);
head.freeze()
}
fn consume(buf: &mut Cursor<&mut BytesMut>) {
// remove bytes from the internal BytesMut when they have been successfully
// decoded. This is a more permanent cursor position, which will be
// used to resume if decoding was only partial.
take(buf, 0);
} }
// ===== impl Table ===== // ===== impl Table =====
@@ -778,15 +797,15 @@ fn test_peek_u8() {
#[test] #[test]
fn test_decode_string_empty() { fn test_decode_string_empty() {
let mut de = Decoder::new(0); let mut de = Decoder::new(0);
let buf = Bytes::new(); let mut buf = BytesMut::new();
let err = de.decode_string(&mut Cursor::new(buf)).unwrap_err(); let err = de.decode_string(&mut Cursor::new(&mut buf)).unwrap_err();
assert_eq!(err, DecoderError::UnexpectedEndOfStream); assert_eq!(err, DecoderError::NeedMore(NeedMore::UnexpectedEndOfStream));
} }
#[test] #[test]
fn test_decode_empty() { fn test_decode_empty() {
let mut de = Decoder::new(0); let mut de = Decoder::new(0);
let buf = Bytes::new(); let mut buf = BytesMut::new();
let empty = de.decode(&mut Cursor::new(buf), |_| {}).unwrap(); let empty = de.decode(&mut Cursor::new(&mut buf), |_| {}).unwrap();
assert_eq!(empty, ()); assert_eq!(empty, ());
} }

View File

@@ -7,6 +7,6 @@ mod table;
#[cfg(test)] #[cfg(test)]
mod test; mod test;
pub use self::decoder::{Decoder, DecoderError}; pub use self::decoder::{Decoder, DecoderError, NeedMore};
pub use self::encoder::{Encode, EncodeState, Encoder, EncoderError}; pub use self::encoder::{Encode, EncodeState, Encoder, EncoderError};
pub use self::header::Header; pub use self::header::Header;

View File

@@ -74,7 +74,7 @@ fn test_story(story: Value) {
} }
decoder decoder
.decode(&mut Cursor::new(case.wire.clone().into()), |e| { .decode(&mut Cursor::new(&mut case.wire.clone().into()), |e| {
let (name, value) = expect.remove(0); let (name, value) = expect.remove(0);
assert_eq!(name, key_str(&e)); assert_eq!(name, key_str(&e));
assert_eq!(value, value_str(&e)); assert_eq!(value, value_str(&e));
@@ -108,7 +108,7 @@ fn test_story(story: Value) {
encoder.encode(None, &mut input.clone().into_iter(), &mut buf); encoder.encode(None, &mut input.clone().into_iter(), &mut buf);
decoder decoder
.decode(&mut Cursor::new(buf.into()), |e| { .decode(&mut Cursor::new(&mut buf), |e| {
assert_eq!(e, input.remove(0).reify().unwrap()); assert_eq!(e, input.remove(0).reify().unwrap());
}) })
.unwrap(); .unwrap();

View File

@@ -149,7 +149,7 @@ impl FuzzHpack {
// Decode the chunk! // Decode the chunk!
decoder decoder
.decode(&mut Cursor::new(buf.into()), |e| { .decode(&mut Cursor::new(&mut buf), |e| {
assert_eq!(e, expect.remove(0).reify().unwrap()); assert_eq!(e, expect.remove(0).reify().unwrap());
}) })
.unwrap(); .unwrap();
@@ -161,7 +161,7 @@ impl FuzzHpack {
// Decode the chunk! // Decode the chunk!
decoder decoder
.decode(&mut Cursor::new(buf.into()), |e| { .decode(&mut Cursor::new(&mut buf), |e| {
assert_eq!(e, expect.remove(0).reify().unwrap()); assert_eq!(e, expect.remove(0).reify().unwrap());
}) })
.unwrap(); .unwrap();

View File

@@ -659,10 +659,12 @@ impl Prioritize {
) )
), ),
None => { None => {
assert!(stream.state.is_canceled()); let reason = stream.state.get_scheduled_reset()
stream.state.set_reset(Reason::CANCEL); .expect("must be scheduled to reset");
let frame = frame::Reset::new(stream.id, Reason::CANCEL); stream.state.set_reset(reason);
let frame = frame::Reset::new(stream.id, reason);
Frame::Reset(frame) Frame::Reset(frame)
} }
}; };
@@ -674,7 +676,7 @@ impl Prioritize {
self.last_opened_id = stream.id; self.last_opened_id = stream.id;
} }
if !stream.pending_send.is_empty() || stream.state.is_canceled() { if !stream.pending_send.is_empty() || stream.state.is_scheduled_reset() {
// TODO: Only requeue the sender IF it is ready to send // TODO: Only requeue the sender IF it is ready to send
// the next frame. i.e. don't requeue it if the next // the next frame. i.e. don't requeue it if the next
// frame is a data frame and the stream does not have // frame is a data frame and the stream does not have

View File

@@ -1,4 +1,5 @@
use super::*; use super::*;
use super::store::Resolve;
use {frame, proto}; use {frame, proto};
use codec::{RecvError, UserError}; use codec::{RecvError, UserError};
use frame::{Reason, DEFAULT_INITIAL_WINDOW_SIZE}; use frame::{Reason, DEFAULT_INITIAL_WINDOW_SIZE};
@@ -54,6 +55,12 @@ pub(super) enum Event {
Trailers(HeaderMap), Trailers(HeaderMap),
} }
#[derive(Debug)]
pub(super) enum RecvHeaderBlockError<T> {
Oversize(T),
State(RecvError),
}
#[derive(Debug, Clone, Copy)] #[derive(Debug, Clone, Copy)]
struct Indices { struct Indices {
head: store::Key, head: store::Key,
@@ -133,7 +140,7 @@ impl Recv {
frame: frame::Headers, frame: frame::Headers,
stream: &mut store::Ptr, stream: &mut store::Ptr,
counts: &mut Counts, counts: &mut Counts,
) -> Result<(), RecvError> { ) -> Result<(), RecvHeaderBlockError<Option<frame::Headers>>> {
trace!("opening stream; init_window={}", self.init_window_sz); trace!("opening stream; init_window={}", self.init_window_sz);
let is_initial = stream.state.recv_open(frame.is_end_stream())?; let is_initial = stream.state.recv_open(frame.is_end_stream())?;
@@ -158,7 +165,7 @@ impl Recv {
return Err(RecvError::Stream { return Err(RecvError::Stream {
id: stream.id, id: stream.id,
reason: Reason::PROTOCOL_ERROR, reason: Reason::PROTOCOL_ERROR,
}) }.into())
}, },
}; };
@@ -166,6 +173,32 @@ impl Recv {
} }
} }
if frame.is_over_size() {
// A frame is over size if the decoded header block was bigger than
// SETTINGS_MAX_HEADER_LIST_SIZE.
//
// > A server that receives a larger header block than it is willing
// > to handle can send an HTTP 431 (Request Header Fields Too
// > Large) status code [RFC6585]. A client can discard responses
// > that it cannot process.
//
// So, if peer is a server, we'll send a 431. In either case,
// an error is recorded, which will send a REFUSED_STREAM,
// since we don't want any of the data frames either.
trace!("recv_headers; frame for {:?} is over size", stream.id);
return if counts.peer().is_server() && is_initial {
let mut res = frame::Headers::new(
stream.id,
frame::Pseudo::response(::http::StatusCode::REQUEST_HEADER_FIELDS_TOO_LARGE),
HeaderMap::new()
);
res.set_end_stream();
Err(RecvHeaderBlockError::Oversize(Some(res)))
} else {
Err(RecvHeaderBlockError::Oversize(None))
};
}
let message = counts.peer().convert_poll_message(frame)?; let message = counts.peer().convert_poll_message(frame)?;
// Push the frame onto the stream's recv buffer // Push the frame onto the stream's recv buffer
@@ -517,15 +550,20 @@ impl Recv {
); );
new_stream.state.reserve_remote()?; new_stream.state.reserve_remote()?;
// Store the stream
let new_stream = store.insert(frame.promised_id(), new_stream).key();
if frame.is_over_size() {
trace!("recv_push_promise; frame for {:?} is over size", frame.promised_id());
return Err(RecvError::Stream {
id: frame.promised_id(),
reason: Reason::REFUSED_STREAM,
});
}
let mut ppp = store[stream].pending_push_promises.take(); let mut ppp = store[stream].pending_push_promises.take();
ppp.push(&mut store.resolve(new_stream));
{
// Store the stream
let mut new_stream = store.insert(frame.promised_id(), new_stream);
ppp.push(&mut new_stream);
}
let stream = &mut store[stream]; let stream = &mut store[stream];
@@ -609,9 +647,7 @@ impl Recv {
stream: &mut store::Ptr, stream: &mut store::Ptr,
counts: &mut Counts, counts: &mut Counts,
) { ) {
assert!(stream.state.is_local_reset()); if !stream.state.is_local_reset() || stream.is_pending_reset_expiration() {
if stream.is_pending_reset_expiration() {
return; return;
} }
@@ -842,6 +878,14 @@ impl Event {
} }
} }
// ===== impl RecvHeaderBlockError =====
impl<T> From<RecvError> for RecvHeaderBlockError<T> {
fn from(err: RecvError) -> Self {
RecvHeaderBlockError::State(err)
}
}
// ===== util ===== // ===== util =====
fn parse_u64(src: &[u8]) -> Result<u64, ()> { fn parse_u64(src: &[u8]) -> Result<u64, ()> {

View File

@@ -132,6 +132,9 @@ impl Send {
return; return;
} }
// Transition the state to reset no matter what.
stream.state.set_reset(reason);
// If closed AND the send queue is flushed, then the stream cannot be // If closed AND the send queue is flushed, then the stream cannot be
// reset explicitly, either. Implicit resets can still be queued. // reset explicitly, either. Implicit resets can still be queued.
if is_closed && is_empty { if is_closed && is_empty {
@@ -143,9 +146,6 @@ impl Send {
return; return;
} }
// Transition the state
stream.state.set_reset(reason);
self.recv_err(buffer, stream); self.recv_err(buffer, stream);
let frame = frame::Reset::new(stream.id, reason); let frame = frame::Reset::new(stream.id, reason);
@@ -154,14 +154,18 @@ impl Send {
self.prioritize.queue_frame(frame.into(), buffer, stream, task); self.prioritize.queue_frame(frame.into(), buffer, stream, task);
} }
pub fn schedule_cancel(&mut self, stream: &mut store::Ptr, task: &mut Option<Task>) { pub fn schedule_implicit_reset(
trace!("schedule_cancel; {:?}", stream.id); &mut self,
stream: &mut store::Ptr,
reason: Reason,
task: &mut Option<Task>,
) {
if stream.state.is_closed() { if stream.state.is_closed() {
// Stream is already closed, nothing more to do // Stream is already closed, nothing more to do
return; return;
} }
stream.state.set_canceled(); stream.state.set_scheduled_reset(reason);
self.prioritize.reclaim_reserved_capacity(stream); self.prioritize.reclaim_reserved_capacity(stream);
self.prioritize.schedule_send(stream, task); self.prioritize.schedule_send(stream, task);

View File

@@ -76,10 +76,14 @@ enum Cause {
LocallyReset(Reason), LocallyReset(Reason),
Io, Io,
/// The user droped all handles to the stream without explicitly canceling.
/// This indicates to the connection that a reset frame must be sent out /// This indicates to the connection that a reset frame must be sent out
/// once the send queue has been flushed. /// once the send queue has been flushed.
Canceled, ///
/// Examples of when this could happen:
/// - User drops all references to a stream, so we want to CANCEL the it.
/// - Header block size was too large, so we want to REFUSE, possibly
/// after sending a 431 response frame.
Scheduled(Reason),
} }
impl State { impl State {
@@ -269,15 +273,22 @@ impl State {
self.inner = Closed(Cause::LocallyReset(reason)); self.inner = Closed(Cause::LocallyReset(reason));
} }
/// Set the stream state to canceled /// Set the stream state to a scheduled reset.
pub fn set_canceled(&mut self) { pub fn set_scheduled_reset(&mut self, reason: Reason) {
debug_assert!(!self.is_closed()); debug_assert!(!self.is_closed());
self.inner = Closed(Cause::Canceled); self.inner = Closed(Cause::Scheduled(reason));
} }
pub fn is_canceled(&self) -> bool { pub fn get_scheduled_reset(&self) -> Option<Reason> {
match self.inner { match self.inner {
Closed(Cause::Canceled) => true, Closed(Cause::Scheduled(reason)) => Some(reason),
_ => None,
}
}
pub fn is_scheduled_reset(&self) -> bool {
match self.inner {
Closed(Cause::Scheduled(..)) => true,
_ => false, _ => false,
} }
} }
@@ -285,7 +296,7 @@ impl State {
pub fn is_local_reset(&self) -> bool { pub fn is_local_reset(&self) -> bool {
match self.inner { match self.inner {
Closed(Cause::LocallyReset(_)) => true, Closed(Cause::LocallyReset(_)) => true,
Closed(Cause::Canceled) => true, Closed(Cause::Scheduled(..)) => true,
_ => false, _ => false,
} }
} }
@@ -381,8 +392,8 @@ impl State {
// TODO: Is this correct? // TODO: Is this correct?
match self.inner { match self.inner {
Closed(Cause::Proto(reason)) | Closed(Cause::Proto(reason)) |
Closed(Cause::LocallyReset(reason)) => Err(proto::Error::Proto(reason)), Closed(Cause::LocallyReset(reason)) |
Closed(Cause::Canceled) => Err(proto::Error::Proto(Reason::CANCEL)), Closed(Cause::Scheduled(reason)) => Err(proto::Error::Proto(reason)),
Closed(Cause::Io) => Err(proto::Error::Io(io::ErrorKind::BrokenPipe.into())), Closed(Cause::Io) => Err(proto::Error::Io(io::ErrorKind::BrokenPipe.into())),
Closed(Cause::EndStream) | Closed(Cause::EndStream) |
HalfClosedRemote(..) => Ok(false), HalfClosedRemote(..) => Ok(false),

View File

@@ -1,4 +1,5 @@
use super::{Buffer, Config, Counts, Prioritized, Recv, Send, Stream, StreamId}; use super::{Buffer, Config, Counts, Prioritized, Recv, Send, Stream, StreamId};
use super::recv::RecvHeaderBlockError;
use super::store::{self, Entry, Resolve, Store}; use super::store::{self, Entry, Resolve, Store};
use {client, proto, server}; use {client, proto, server};
use codec::{Codec, RecvError, SendError, UserError}; use codec::{Codec, RecvError, SendError, UserError};
@@ -164,7 +165,28 @@ where
); );
let res = if stream.state.is_recv_headers() { let res = if stream.state.is_recv_headers() {
actions.recv.recv_headers(frame, stream, counts) match actions.recv.recv_headers(frame, stream, counts) {
Ok(()) => Ok(()),
Err(RecvHeaderBlockError::Oversize(resp)) => {
if let Some(resp) = resp {
let _ = actions.send.send_headers(
resp, send_buffer, stream, counts, &mut actions.task);
actions.send.schedule_implicit_reset(
stream,
Reason::REFUSED_STREAM,
&mut actions.task);
actions.recv.enqueue_reset_expiration(stream, counts);
Ok(())
} else {
Err(RecvError::Stream {
id: stream.id,
reason: Reason::REFUSED_STREAM,
})
}
},
Err(RecvHeaderBlockError::State(err)) => Err(err),
}
} else { } else {
if !frame.is_end_stream() { if !frame.is_end_stream() {
// TODO: Is this the right error // TODO: Is this the right error
@@ -363,22 +385,42 @@ where
let me = &mut *me; let me = &mut *me;
let id = frame.stream_id(); let id = frame.stream_id();
let promised_id = frame.promised_id();
let stream = match me.store.find_mut(&id) { let res = {
Some(stream) => stream.key(), let stream = match me.store.find_mut(&id) {
None => return Err(RecvError::Connection(Reason::PROTOCOL_ERROR)), Some(stream) => stream.key(),
None => return Err(RecvError::Connection(Reason::PROTOCOL_ERROR)),
};
if me.counts.peer().is_server() {
// The remote is a client and cannot reserve
trace!("recv_push_promise; error remote is client");
return Err(RecvError::Connection(Reason::PROTOCOL_ERROR));
}
me.actions.recv.recv_push_promise(frame,
&me.actions.send,
stream,
&mut me.store)
}; };
if me.counts.peer().is_server() { if let Err(err) = res {
// The remote is a client and cannot reserve if let Some(ref mut new_stream) = me.store.find_mut(&promised_id) {
trace!("recv_push_promise; error remote is client");
return Err(RecvError::Connection(Reason::PROTOCOL_ERROR));
}
me.actions.recv.recv_push_promise(frame, let mut send_buffer = self.send_buffer.inner.lock().unwrap();
&me.actions.send, me.actions.reset_on_recv_stream_err(&mut *send_buffer, new_stream, Err(err))
stream, } else {
&mut me.store) // If there was a stream error, the stream should have been stored
// so we can track sending a reset.
//
// Otherwise, this MUST be an connection error.
assert!(!err.is_stream_error());
Err(err)
}
} else {
res
}
} }
pub fn next_incoming(&mut self) -> Option<StreamRef<B>> { pub fn next_incoming(&mut self) -> Option<StreamRef<B>> {
@@ -925,8 +967,9 @@ fn drop_stream_ref(inner: &Mutex<Inner>, key: store::Key) {
fn maybe_cancel(stream: &mut store::Ptr, actions: &mut Actions, counts: &mut Counts) { fn maybe_cancel(stream: &mut store::Ptr, actions: &mut Actions, counts: &mut Counts) {
if stream.is_canceled_interest() { if stream.is_canceled_interest() {
actions.send.schedule_cancel( actions.send.schedule_implicit_reset(
stream, stream,
Reason::CANCEL,
&mut actions.task); &mut actions.task);
actions.recv.enqueue_reset_expiration(stream, counts); actions.recv.enqueue_reset_expiration(stream, counts);
} }

View File

@@ -363,6 +363,10 @@ where
codec.set_max_recv_frame_size(max as usize); codec.set_max_recv_frame_size(max as usize);
} }
if let Some(max) = builder.settings.max_header_list_size() {
codec.set_max_recv_header_list_size(max as usize);
}
// Send initial settings frame. // Send initial settings frame.
codec codec
.buffer(builder.settings.clone().into()) .buffer(builder.settings.clone().into())
@@ -577,6 +581,12 @@ impl Builder {
self self
} }
/// Set the max size of received header frames.
pub fn max_header_list_size(&mut self, max: u32) -> &mut Self {
self.settings.set_max_header_list_size(Some(max));
self
}
/// Set the maximum number of concurrent streams. /// Set the maximum number of concurrent streams.
/// ///
/// The maximum concurrent streams setting only controls the maximum number /// The maximum concurrent streams setting only controls the maximum number

View File

@@ -544,6 +544,82 @@ fn sending_request_on_closed_connection() {
h2.join(srv).wait().expect("wait"); h2.join(srv).wait().expect("wait");
} }
#[test]
fn recv_too_big_headers() {
let _ = ::env_logger::init();
let (io, srv) = mock::new();
let srv = srv.assert_client_handshake()
.unwrap()
.recv_custom_settings(
frames::settings()
.max_header_list_size(10)
)
.recv_frame(
frames::headers(1)
.request("GET", "https://http2.akamai.com/")
.eos(),
)
.recv_frame(
frames::headers(3)
.request("GET", "https://http2.akamai.com/")
.eos(),
)
.send_frame(frames::headers(1).response(200).eos())
.send_frame(frames::headers(3).response(200))
// no reset for 1, since it's closed anyways
// but reset for 3, since server hasn't closed stream
.recv_frame(frames::reset(3).refused())
.idle_ms(10)
.close();
let client = client::Builder::new()
.max_header_list_size(10)
.handshake::<_, Bytes>(io)
.expect("handshake")
.and_then(|(mut client, conn)| {
let request = Request::builder()
.uri("https://http2.akamai.com/")
.body(())
.unwrap();
let req1 = client
.send_request(request, true)
.expect("send_request")
.0
.expect_err("response1")
.map(|err| {
assert_eq!(
err.reason(),
Some(Reason::REFUSED_STREAM)
);
});
let request = Request::builder()
.uri("https://http2.akamai.com/")
.body(())
.unwrap();
let req2 = client
.send_request(request, true)
.expect("send_request")
.0
.expect_err("response2")
.map(|err| {
assert_eq!(
err.reason(),
Some(Reason::REFUSED_STREAM)
);
});
conn.drive(req1.join(req2))
.and_then(|(conn, _)| conn.expect("client"))
});
client.join(srv).wait().expect("wait");
}
const SETTINGS: &'static [u8] = &[0, 0, 0, 4, 0, 0, 0, 0, 0]; const SETTINGS: &'static [u8] = &[0, 0, 0, 4, 0, 0, 0, 0, 0];
const SETTINGS_ACK: &'static [u8] = &[0, 0, 0, 4, 1, 0, 0, 0, 0]; const SETTINGS_ACK: &'static [u8] = &[0, 0, 0, 4, 1, 0, 0, 0, 0];

View File

@@ -12,6 +12,7 @@ fn read_none() {
} }
#[test] #[test]
#[ignore]
fn read_frame_too_big() {} fn read_frame_too_big() {}
// ===== DATA ===== // ===== DATA =====
@@ -100,14 +101,73 @@ fn read_data_stream_id_zero() {
// ===== HEADERS ===== // ===== HEADERS =====
#[test] #[test]
#[ignore]
fn read_headers_without_pseudo() {} fn read_headers_without_pseudo() {}
#[test] #[test]
#[ignore]
fn read_headers_with_pseudo() {} fn read_headers_with_pseudo() {}
#[test] #[test]
#[ignore]
fn read_headers_empty_payload() {} fn read_headers_empty_payload() {}
#[test]
fn read_continuation_frames() {
let _ = ::env_logger::init();
let (io, srv) = mock::new();
let large = build_large_headers();
let frame = large.iter().fold(
frames::headers(1).response(200),
|frame, &(name, ref value)| frame.field(name, &value[..]),
).eos();
let srv = srv.assert_client_handshake()
.unwrap()
.recv_settings()
.recv_frame(
frames::headers(1)
.request("GET", "https://http2.akamai.com/")
.eos(),
)
.send_frame(frame)
.close();
let client = client::handshake(io)
.expect("handshake")
.and_then(|(mut client, conn)| {
let request = Request::builder()
.uri("https://http2.akamai.com/")
.body(())
.unwrap();
let req = client
.send_request(request, true)
.expect("send_request")
.0
.expect("response")
.map(move |res| {
assert_eq!(res.status(), StatusCode::OK);
let (head, _body) = res.into_parts();
let expected = large.iter().fold(HeaderMap::new(), |mut map, &(name, ref value)| {
use support::frames::HttpTryInto;
map.append(name, value.as_str().try_into().unwrap());
map
});
assert_eq!(head.headers, expected);
});
conn.drive(req)
.and_then(move |(h2, _)| {
h2.expect("client")
})
});
client.join(srv).wait().expect("wait");
}
#[test] #[test]
fn update_max_frame_len_at_rest() { fn update_max_frame_len_at_rest() {
let _ = ::env_logger::init(); let _ = ::env_logger::init();

View File

@@ -59,28 +59,3 @@ fn write_continuation_frames() {
client.join(srv).wait().expect("wait"); client.join(srv).wait().expect("wait");
} }
fn build_large_headers() -> Vec<(&'static str, String)> {
vec![
("one", "hello".to_string()),
("two", build_large_string('2', 4 * 1024)),
("three", "three".to_string()),
("four", build_large_string('4', 4 * 1024)),
("five", "five".to_string()),
("six", build_large_string('6', 4 * 1024)),
("seven", "seven".to_string()),
("eight", build_large_string('8', 4 * 1024)),
("nine", "nine".to_string()),
("ten", build_large_string('0', 4 * 1024)),
]
}
fn build_large_string(ch: char, len: usize) -> String {
let mut ret = String::new();
for _ in 0..len {
ret.push(ch);
}
ret
}

View File

@@ -137,6 +137,56 @@ fn pending_push_promises_reset_when_dropped() {
client.join(srv).wait().expect("wait"); client.join(srv).wait().expect("wait");
} }
#[test]
fn recv_push_promise_over_max_header_list_size() {
let _ = ::env_logger::init();
let (io, srv) = mock::new();
let srv = srv.assert_client_handshake()
.unwrap()
.recv_custom_settings(
frames::settings()
.max_header_list_size(10)
)
.recv_frame(
frames::headers(1)
.request("GET", "https://http2.akamai.com/")
.eos(),
)
.send_frame(frames::push_promise(1, 2).request("GET", "https://http2.akamai.com/style.css"))
.recv_frame(frames::reset(2).refused())
.send_frame(frames::headers(1).response(200).eos())
.idle_ms(10)
.close();
let client = client::Builder::new()
.max_header_list_size(10)
.handshake::<_, Bytes>(io)
.expect("handshake")
.and_then(|(mut client, conn)| {
let request = Request::builder()
.uri("https://http2.akamai.com/")
.body(())
.unwrap();
let req = client
.send_request(request, true)
.expect("send_request")
.0
.expect_err("response")
.map(|err| {
assert_eq!(
err.reason(),
Some(Reason::REFUSED_STREAM)
);
});
conn.drive(req)
.and_then(|(conn, _)| conn.expect("client"))
});
client.join(srv).wait().expect("wait");
}
#[test] #[test]
#[ignore] #[ignore]
fn recv_push_promise_with_unsafe_method_is_stream_error() { fn recv_push_promise_with_unsafe_method_is_stream_error() {

View File

@@ -261,3 +261,76 @@ fn sends_reset_cancel_when_res_body_is_dropped() {
srv.join(client).wait().expect("wait"); srv.join(client).wait().expect("wait");
} }
#[test]
fn too_big_headers_sends_431() {
let _ = ::env_logger::init();
let (io, client) = mock::new();
let client = client
.assert_server_handshake()
.unwrap()
.recv_custom_settings(
frames::settings()
.max_header_list_size(10)
)
.send_frame(
frames::headers(1)
.request("GET", "https://example.com/")
.field("some-header", "some-value")
.eos()
)
.recv_frame(frames::headers(1).response(431).eos())
.idle_ms(10)
.close();
let srv = server::Builder::new()
.max_header_list_size(10)
.handshake::<_, Bytes>(io)
.expect("handshake")
.and_then(|srv| {
srv.into_future()
.expect("server")
.map(|(req, _)| {
assert!(req.is_none(), "req is {:?}", req);
})
});
srv.join(client).wait().expect("wait");
}
#[test]
fn too_big_headers_sends_reset_after_431_if_not_eos() {
let _ = ::env_logger::init();
let (io, client) = mock::new();
let client = client
.assert_server_handshake()
.unwrap()
.recv_custom_settings(
frames::settings()
.max_header_list_size(10)
)
.send_frame(
frames::headers(1)
.request("GET", "https://example.com/")
.field("some-header", "some-value")
)
.recv_frame(frames::headers(1).response(431).eos())
.recv_frame(frames::reset(1).refused())
.close();
let srv = server::Builder::new()
.max_header_list_size(10)
.handshake::<_, Bytes>(io)
.expect("handshake")
.and_then(|srv| {
srv.into_future()
.expect("server")
.map(|(req, _)| {
assert!(req.is_none(), "req is {:?}", req);
})
});
srv.join(client).wait().expect("wait");
}

View File

@@ -152,6 +152,10 @@ impl Mock<frame::Headers> {
self self
} }
pub fn into_fields(self) -> HeaderMap {
self.0.into_parts().1
}
fn into_parts(self) -> (StreamId, frame::Pseudo, HeaderMap) { fn into_parts(self) -> (StreamId, frame::Pseudo, HeaderMap) {
assert!(!self.0.is_end_stream(), "eos flag will be lost"); assert!(!self.0.is_end_stream(), "eos flag will be lost");
assert!(self.0.is_end_headers(), "unset eoh will be lost"); assert!(self.0.is_end_headers(), "unset eoh will be lost");
@@ -304,6 +308,11 @@ impl Mock<frame::Settings> {
self.0.set_initial_window_size(Some(val)); self.0.set_initial_window_size(Some(val));
self self
} }
pub fn max_header_list_size(mut self, val: u32) -> Self {
self.0.set_max_header_list_size(Some(val));
self
}
} }
impl From<Mock<frame::Settings>> for frame::Settings { impl From<Mock<frame::Settings>> for frame::Settings {

View File

@@ -394,12 +394,13 @@ pub trait HandleFutureExt {
self.recv_custom_settings(frame::Settings::default()) self.recv_custom_settings(frame::Settings::default())
} }
fn recv_custom_settings(self, settings: frame::Settings) fn recv_custom_settings<T>(self, settings: T)
-> RecvFrame<Box<Future<Item = (Option<Frame>, Handle), Error = ()>>> -> RecvFrame<Box<Future<Item = (Option<Frame>, Handle), Error = ()>>>
where where
Self: Sized + 'static, Self: Sized + 'static,
Self: Future<Item = (frame::Settings, Handle)>, Self: Future<Item = (frame::Settings, Handle)>,
Self::Error: fmt::Debug, Self::Error: fmt::Debug,
T: Into<frame::Settings>,
{ {
let map = self let map = self
.map(|(settings, handle)| (Some(settings.into()), handle)) .map(|(settings, handle)| (Some(settings.into()), handle))
@@ -409,7 +410,7 @@ pub trait HandleFutureExt {
Box::new(map); Box::new(map);
RecvFrame { RecvFrame {
inner: boxed, inner: boxed,
frame: settings.into(), frame: settings.into().into(),
} }
} }

View File

@@ -64,3 +64,4 @@ pub type Codec<T> = h2::Codec<T, ::std::io::Cursor<::bytes::Bytes>>;
// This is the frame type that is sent // This is the frame type that is sent
pub type SendFrame = h2::frame::Frame<::std::io::Cursor<::bytes::Bytes>>; pub type SendFrame = h2::frame::Frame<::std::io::Cursor<::bytes::Bytes>>;

View File

@@ -85,3 +85,28 @@ where
} }
} }
} }
pub fn build_large_headers() -> Vec<(&'static str, String)> {
vec![
("one", "hello".to_string()),
("two", build_large_string('2', 4 * 1024)),
("three", "three".to_string()),
("four", build_large_string('4', 4 * 1024)),
("five", "five".to_string()),
("six", build_large_string('6', 4 * 1024)),
("seven", "seven".to_string()),
("eight", build_large_string('8', 4 * 1024)),
("nine", "nine".to_string()),
("ten", build_large_string('0', 4 * 1024)),
]
}
fn build_large_string(ch: char, len: usize) -> String {
let mut ret = String::new();
for _ in 0..len {
ret.push(ch);
}
ret
}