de-generify FramedRead::decode_frame (#509)

* de-generify FramedRead::decode_frame

* Rename arg to decode_frame

Co-authored-by: Dan Burkert <dan@danburkert.com>
This commit is contained in:
Kornel
2021-02-16 20:21:29 +00:00
committed by GitHub
parent fe938cb81c
commit 6357e3256a

View File

@@ -59,249 +59,6 @@ impl<T> FramedRead<T> {
}
}
fn decode_frame(&mut self, mut bytes: BytesMut) -> Result<Option<Frame>, RecvError> {
use self::RecvError::*;
let span = tracing::trace_span!("FramedRead::decode_frame", offset = bytes.len());
let _e = span.enter();
tracing::trace!("decoding frame from {}B", bytes.len());
// Parse the head
let head = frame::Head::parse(&bytes);
if self.partial.is_some() && head.kind() != Kind::Continuation {
proto_err!(conn: "expected CONTINUATION, got {:?}", head.kind());
return Err(Connection(Reason::PROTOCOL_ERROR));
}
let kind = head.kind();
tracing::trace!(frame.kind = ?kind);
macro_rules! header_block {
($frame:ident, $head:ident, $bytes:ident) => ({
// Drop the frame header
// TODO: Change to drain: carllerche/bytes#130
let _ = $bytes.split_to(frame::HEADER_LEN);
// Parse the header frame w/o parsing the payload
let (mut frame, mut payload) = match frame::$frame::load($head, $bytes) {
Ok(res) => res,
Err(frame::Error::InvalidDependencyId) => {
proto_err!(stream: "invalid HEADERS dependency ID");
// A stream cannot depend on itself. An endpoint MUST
// treat this as a stream error (Section 5.4.2) of type
// `PROTOCOL_ERROR`.
return Err(Stream {
id: $head.stream_id(),
reason: Reason::PROTOCOL_ERROR,
});
},
Err(e) => {
proto_err!(conn: "failed to load frame; err={:?}", e);
return Err(Connection(Reason::PROTOCOL_ERROR));
}
};
let is_end_headers = frame.is_end_headers();
// Load the HPACK encoded headers
match frame.load_hpack(&mut payload, self.max_header_list_size, &mut self.hpack) {
Ok(_) => {},
Err(frame::Error::Hpack(hpack::DecoderError::NeedMore(_))) if !is_end_headers => {},
Err(frame::Error::MalformedMessage) => {
let id = $head.stream_id();
proto_err!(stream: "malformed header block; stream={:?}", id);
return Err(Stream {
id,
reason: Reason::PROTOCOL_ERROR,
});
},
Err(e) => {
proto_err!(conn: "failed HPACK decoding; err={:?}", e);
return Err(Connection(Reason::PROTOCOL_ERROR));
}
}
if is_end_headers {
frame.into()
} else {
tracing::trace!("loaded partial header block");
// Defer returning the frame
self.partial = Some(Partial {
frame: Continuable::$frame(frame),
buf: payload,
});
return Ok(None);
}
});
}
let frame = match kind {
Kind::Settings => {
let res = frame::Settings::load(head, &bytes[frame::HEADER_LEN..]);
res.map_err(|e| {
proto_err!(conn: "failed to load SETTINGS frame; err={:?}", e);
Connection(Reason::PROTOCOL_ERROR)
})?
.into()
}
Kind::Ping => {
let res = frame::Ping::load(head, &bytes[frame::HEADER_LEN..]);
res.map_err(|e| {
proto_err!(conn: "failed to load PING frame; err={:?}", e);
Connection(Reason::PROTOCOL_ERROR)
})?
.into()
}
Kind::WindowUpdate => {
let res = frame::WindowUpdate::load(head, &bytes[frame::HEADER_LEN..]);
res.map_err(|e| {
proto_err!(conn: "failed to load WINDOW_UPDATE frame; err={:?}", e);
Connection(Reason::PROTOCOL_ERROR)
})?
.into()
}
Kind::Data => {
let _ = bytes.split_to(frame::HEADER_LEN);
let res = frame::Data::load(head, bytes.freeze());
// TODO: Should this always be connection level? Probably not...
res.map_err(|e| {
proto_err!(conn: "failed to load DATA frame; err={:?}", e);
Connection(Reason::PROTOCOL_ERROR)
})?
.into()
}
Kind::Headers => header_block!(Headers, head, bytes),
Kind::Reset => {
let res = frame::Reset::load(head, &bytes[frame::HEADER_LEN..]);
res.map_err(|e| {
proto_err!(conn: "failed to load RESET frame; err={:?}", e);
Connection(Reason::PROTOCOL_ERROR)
})?
.into()
}
Kind::GoAway => {
let res = frame::GoAway::load(&bytes[frame::HEADER_LEN..]);
res.map_err(|e| {
proto_err!(conn: "failed to load GO_AWAY frame; err={:?}", e);
Connection(Reason::PROTOCOL_ERROR)
})?
.into()
}
Kind::PushPromise => header_block!(PushPromise, head, bytes),
Kind::Priority => {
if head.stream_id() == 0 {
// Invalid stream identifier
proto_err!(conn: "invalid stream ID 0");
return Err(Connection(Reason::PROTOCOL_ERROR));
}
match frame::Priority::load(head, &bytes[frame::HEADER_LEN..]) {
Ok(frame) => frame.into(),
Err(frame::Error::InvalidDependencyId) => {
// A stream cannot depend on itself. An endpoint MUST
// treat this as a stream error (Section 5.4.2) of type
// `PROTOCOL_ERROR`.
let id = head.stream_id();
proto_err!(stream: "PRIORITY invalid dependency ID; stream={:?}", id);
return Err(Stream {
id,
reason: Reason::PROTOCOL_ERROR,
});
}
Err(e) => {
proto_err!(conn: "failed to load PRIORITY frame; err={:?};", e);
return Err(Connection(Reason::PROTOCOL_ERROR));
}
}
}
Kind::Continuation => {
let is_end_headers = (head.flag() & 0x4) == 0x4;
let mut partial = match self.partial.take() {
Some(partial) => partial,
None => {
proto_err!(conn: "received unexpected CONTINUATION frame");
return Err(Connection(Reason::PROTOCOL_ERROR));
}
};
// The stream identifiers must match
if partial.frame.stream_id() != head.stream_id() {
proto_err!(conn: "CONTINUATION frame stream ID does not match previous frame stream ID");
return Err(Connection(Reason::PROTOCOL_ERROR));
}
// Extend the buf
if partial.buf.is_empty() {
partial.buf = bytes.split_off(frame::HEADER_LEN);
} else {
if partial.frame.is_over_size() {
// If there was left over bytes previously, they may be
// needed to continue decoding, even though we will
// be ignoring this frame. This is done to keep the HPACK
// decoder state up-to-date.
//
// Still, we need to be careful, because if a malicious
// attacker were to try to send a gigantic string, such
// that it fits over multiple header blocks, we could
// grow memory uncontrollably again, and that'd be a shame.
//
// Instead, we use a simple heuristic to determine if
// we should continue to ignore decoding, or to tell
// the attacker to go away.
if partial.buf.len() + bytes.len() > self.max_header_list_size {
proto_err!(conn: "CONTINUATION frame header block size over ignorable limit");
return Err(Connection(Reason::COMPRESSION_ERROR));
}
}
partial.buf.extend_from_slice(&bytes[frame::HEADER_LEN..]);
}
match partial.frame.load_hpack(
&mut partial.buf,
self.max_header_list_size,
&mut self.hpack,
) {
Ok(_) => {}
Err(frame::Error::Hpack(hpack::DecoderError::NeedMore(_)))
if !is_end_headers => {}
Err(frame::Error::MalformedMessage) => {
let id = head.stream_id();
proto_err!(stream: "malformed CONTINUATION frame; stream={:?}", id);
return Err(Stream {
id,
reason: Reason::PROTOCOL_ERROR,
});
}
Err(e) => {
proto_err!(conn: "failed HPACK decoding; err={:?}", e);
return Err(Connection(Reason::PROTOCOL_ERROR));
}
}
if is_end_headers {
partial.frame.into()
} else {
self.partial = Some(partial);
return Ok(None);
}
}
Kind::Unknown => {
// Unknown frames are ignored
return Ok(None);
}
};
Ok(Some(frame))
}
pub fn get_ref(&self) -> &T {
self.inner.get_ref()
}
@@ -333,6 +90,255 @@ impl<T> FramedRead<T> {
}
}
/// Decodes a frame.
///
/// This method is intentionally de-generified and outlined because it is very large.
fn decode_frame(
hpack: &mut hpack::Decoder,
max_header_list_size: usize,
partial_inout: &mut Option<Partial>,
mut bytes: BytesMut,
) -> Result<Option<Frame>, RecvError> {
use self::RecvError::*;
let span = tracing::trace_span!("FramedRead::decode_frame", offset = bytes.len());
let _e = span.enter();
tracing::trace!("decoding frame from {}B", bytes.len());
// Parse the head
let head = frame::Head::parse(&bytes);
if partial_inout.is_some() && head.kind() != Kind::Continuation {
proto_err!(conn: "expected CONTINUATION, got {:?}", head.kind());
return Err(Connection(Reason::PROTOCOL_ERROR));
}
let kind = head.kind();
tracing::trace!(frame.kind = ?kind);
macro_rules! header_block {
($frame:ident, $head:ident, $bytes:ident) => ({
// Drop the frame header
// TODO: Change to drain: carllerche/bytes#130
let _ = $bytes.split_to(frame::HEADER_LEN);
// Parse the header frame w/o parsing the payload
let (mut frame, mut payload) = match frame::$frame::load($head, $bytes) {
Ok(res) => res,
Err(frame::Error::InvalidDependencyId) => {
proto_err!(stream: "invalid HEADERS dependency ID");
// A stream cannot depend on itself. An endpoint MUST
// treat this as a stream error (Section 5.4.2) of type
// `PROTOCOL_ERROR`.
return Err(Stream {
id: $head.stream_id(),
reason: Reason::PROTOCOL_ERROR,
});
},
Err(e) => {
proto_err!(conn: "failed to load frame; err={:?}", e);
return Err(Connection(Reason::PROTOCOL_ERROR));
}
};
let is_end_headers = frame.is_end_headers();
// Load the HPACK encoded headers
match frame.load_hpack(&mut payload, max_header_list_size, hpack) {
Ok(_) => {},
Err(frame::Error::Hpack(hpack::DecoderError::NeedMore(_))) if !is_end_headers => {},
Err(frame::Error::MalformedMessage) => {
let id = $head.stream_id();
proto_err!(stream: "malformed header block; stream={:?}", id);
return Err(Stream {
id,
reason: Reason::PROTOCOL_ERROR,
});
},
Err(e) => {
proto_err!(conn: "failed HPACK decoding; err={:?}", e);
return Err(Connection(Reason::PROTOCOL_ERROR));
}
}
if is_end_headers {
frame.into()
} else {
tracing::trace!("loaded partial header block");
// Defer returning the frame
*partial_inout = Some(Partial {
frame: Continuable::$frame(frame),
buf: payload,
});
return Ok(None);
}
});
}
let frame = match kind {
Kind::Settings => {
let res = frame::Settings::load(head, &bytes[frame::HEADER_LEN..]);
res.map_err(|e| {
proto_err!(conn: "failed to load SETTINGS frame; err={:?}", e);
Connection(Reason::PROTOCOL_ERROR)
})?
.into()
}
Kind::Ping => {
let res = frame::Ping::load(head, &bytes[frame::HEADER_LEN..]);
res.map_err(|e| {
proto_err!(conn: "failed to load PING frame; err={:?}", e);
Connection(Reason::PROTOCOL_ERROR)
})?
.into()
}
Kind::WindowUpdate => {
let res = frame::WindowUpdate::load(head, &bytes[frame::HEADER_LEN..]);
res.map_err(|e| {
proto_err!(conn: "failed to load WINDOW_UPDATE frame; err={:?}", e);
Connection(Reason::PROTOCOL_ERROR)
})?
.into()
}
Kind::Data => {
let _ = bytes.split_to(frame::HEADER_LEN);
let res = frame::Data::load(head, bytes.freeze());
// TODO: Should this always be connection level? Probably not...
res.map_err(|e| {
proto_err!(conn: "failed to load DATA frame; err={:?}", e);
Connection(Reason::PROTOCOL_ERROR)
})?
.into()
}
Kind::Headers => header_block!(Headers, head, bytes),
Kind::Reset => {
let res = frame::Reset::load(head, &bytes[frame::HEADER_LEN..]);
res.map_err(|e| {
proto_err!(conn: "failed to load RESET frame; err={:?}", e);
Connection(Reason::PROTOCOL_ERROR)
})?
.into()
}
Kind::GoAway => {
let res = frame::GoAway::load(&bytes[frame::HEADER_LEN..]);
res.map_err(|e| {
proto_err!(conn: "failed to load GO_AWAY frame; err={:?}", e);
Connection(Reason::PROTOCOL_ERROR)
})?
.into()
}
Kind::PushPromise => header_block!(PushPromise, head, bytes),
Kind::Priority => {
if head.stream_id() == 0 {
// Invalid stream identifier
proto_err!(conn: "invalid stream ID 0");
return Err(Connection(Reason::PROTOCOL_ERROR));
}
match frame::Priority::load(head, &bytes[frame::HEADER_LEN..]) {
Ok(frame) => frame.into(),
Err(frame::Error::InvalidDependencyId) => {
// A stream cannot depend on itself. An endpoint MUST
// treat this as a stream error (Section 5.4.2) of type
// `PROTOCOL_ERROR`.
let id = head.stream_id();
proto_err!(stream: "PRIORITY invalid dependency ID; stream={:?}", id);
return Err(Stream {
id,
reason: Reason::PROTOCOL_ERROR,
});
}
Err(e) => {
proto_err!(conn: "failed to load PRIORITY frame; err={:?};", e);
return Err(Connection(Reason::PROTOCOL_ERROR));
}
}
}
Kind::Continuation => {
let is_end_headers = (head.flag() & 0x4) == 0x4;
let mut partial = match partial_inout.take() {
Some(partial) => partial,
None => {
proto_err!(conn: "received unexpected CONTINUATION frame");
return Err(Connection(Reason::PROTOCOL_ERROR));
}
};
// The stream identifiers must match
if partial.frame.stream_id() != head.stream_id() {
proto_err!(conn: "CONTINUATION frame stream ID does not match previous frame stream ID");
return Err(Connection(Reason::PROTOCOL_ERROR));
}
// Extend the buf
if partial.buf.is_empty() {
partial.buf = bytes.split_off(frame::HEADER_LEN);
} else {
if partial.frame.is_over_size() {
// If there was left over bytes previously, they may be
// needed to continue decoding, even though we will
// be ignoring this frame. This is done to keep the HPACK
// decoder state up-to-date.
//
// Still, we need to be careful, because if a malicious
// attacker were to try to send a gigantic string, such
// that it fits over multiple header blocks, we could
// grow memory uncontrollably again, and that'd be a shame.
//
// Instead, we use a simple heuristic to determine if
// we should continue to ignore decoding, or to tell
// the attacker to go away.
if partial.buf.len() + bytes.len() > max_header_list_size {
proto_err!(conn: "CONTINUATION frame header block size over ignorable limit");
return Err(Connection(Reason::COMPRESSION_ERROR));
}
}
partial.buf.extend_from_slice(&bytes[frame::HEADER_LEN..]);
}
match partial
.frame
.load_hpack(&mut partial.buf, max_header_list_size, hpack)
{
Ok(_) => {}
Err(frame::Error::Hpack(hpack::DecoderError::NeedMore(_))) if !is_end_headers => {}
Err(frame::Error::MalformedMessage) => {
let id = head.stream_id();
proto_err!(stream: "malformed CONTINUATION frame; stream={:?}", id);
return Err(Stream {
id,
reason: Reason::PROTOCOL_ERROR,
});
}
Err(e) => {
proto_err!(conn: "failed HPACK decoding; err={:?}", e);
return Err(Connection(Reason::PROTOCOL_ERROR));
}
}
if is_end_headers {
partial.frame.into()
} else {
*partial_inout = Some(partial);
return Ok(None);
}
}
Kind::Unknown => {
// Unknown frames are ignored
return Ok(None);
}
};
Ok(Some(frame))
}
impl<T> Stream for FramedRead<T>
where
T: AsyncRead + Unpin,
@@ -351,7 +357,13 @@ where
};
tracing::trace!(read.bytes = bytes.len());
if let Some(frame) = self.decode_frame(bytes)? {
let Self {
ref mut hpack,
max_header_list_size,
ref mut partial,
..
} = *self;
if let Some(frame) = decode_frame(hpack, max_header_list_size, partial, bytes)? {
tracing::debug!(?frame, "received");
return Poll::Ready(Some(Ok(frame)));
}