From 6357e3256adf15de6b94181e787ff5b7b1cc79b0 Mon Sep 17 00:00:00 2001 From: Kornel Date: Tue, 16 Feb 2021 20:21:29 +0000 Subject: [PATCH] de-generify FramedRead::decode_frame (#509) * de-generify FramedRead::decode_frame * Rename arg to decode_frame Co-authored-by: Dan Burkert --- src/codec/framed_read.rs | 500 ++++++++++++++++++++------------------- 1 file changed, 256 insertions(+), 244 deletions(-) diff --git a/src/codec/framed_read.rs b/src/codec/framed_read.rs index 8bba125..9673c49 100644 --- a/src/codec/framed_read.rs +++ b/src/codec/framed_read.rs @@ -59,249 +59,6 @@ impl FramedRead { } } - fn decode_frame(&mut self, mut bytes: BytesMut) -> Result, RecvError> { - use self::RecvError::*; - let span = tracing::trace_span!("FramedRead::decode_frame", offset = bytes.len()); - let _e = span.enter(); - - tracing::trace!("decoding frame from {}B", bytes.len()); - - // Parse the head - let head = frame::Head::parse(&bytes); - - if self.partial.is_some() && head.kind() != Kind::Continuation { - proto_err!(conn: "expected CONTINUATION, got {:?}", head.kind()); - return Err(Connection(Reason::PROTOCOL_ERROR)); - } - - let kind = head.kind(); - - tracing::trace!(frame.kind = ?kind); - - macro_rules! header_block { - ($frame:ident, $head:ident, $bytes:ident) => ({ - // Drop the frame header - // TODO: Change to drain: carllerche/bytes#130 - let _ = $bytes.split_to(frame::HEADER_LEN); - - // Parse the header frame w/o parsing the payload - let (mut frame, mut payload) = match frame::$frame::load($head, $bytes) { - Ok(res) => res, - Err(frame::Error::InvalidDependencyId) => { - proto_err!(stream: "invalid HEADERS dependency ID"); - // A stream cannot depend on itself. An endpoint MUST - // treat this as a stream error (Section 5.4.2) of type - // `PROTOCOL_ERROR`. - return Err(Stream { - id: $head.stream_id(), - reason: Reason::PROTOCOL_ERROR, - }); - }, - Err(e) => { - proto_err!(conn: "failed to load frame; err={:?}", e); - return Err(Connection(Reason::PROTOCOL_ERROR)); - } - }; - - let is_end_headers = frame.is_end_headers(); - - // Load the HPACK encoded headers - match frame.load_hpack(&mut payload, self.max_header_list_size, &mut self.hpack) { - Ok(_) => {}, - Err(frame::Error::Hpack(hpack::DecoderError::NeedMore(_))) if !is_end_headers => {}, - Err(frame::Error::MalformedMessage) => { - let id = $head.stream_id(); - proto_err!(stream: "malformed header block; stream={:?}", id); - return Err(Stream { - id, - reason: Reason::PROTOCOL_ERROR, - }); - }, - Err(e) => { - proto_err!(conn: "failed HPACK decoding; err={:?}", e); - return Err(Connection(Reason::PROTOCOL_ERROR)); - } - } - - if is_end_headers { - frame.into() - } else { - tracing::trace!("loaded partial header block"); - // Defer returning the frame - self.partial = Some(Partial { - frame: Continuable::$frame(frame), - buf: payload, - }); - - return Ok(None); - } - }); - } - - let frame = match kind { - Kind::Settings => { - let res = frame::Settings::load(head, &bytes[frame::HEADER_LEN..]); - - res.map_err(|e| { - proto_err!(conn: "failed to load SETTINGS frame; err={:?}", e); - Connection(Reason::PROTOCOL_ERROR) - })? - .into() - } - Kind::Ping => { - let res = frame::Ping::load(head, &bytes[frame::HEADER_LEN..]); - - res.map_err(|e| { - proto_err!(conn: "failed to load PING frame; err={:?}", e); - Connection(Reason::PROTOCOL_ERROR) - })? - .into() - } - Kind::WindowUpdate => { - let res = frame::WindowUpdate::load(head, &bytes[frame::HEADER_LEN..]); - - res.map_err(|e| { - proto_err!(conn: "failed to load WINDOW_UPDATE frame; err={:?}", e); - Connection(Reason::PROTOCOL_ERROR) - })? - .into() - } - Kind::Data => { - let _ = bytes.split_to(frame::HEADER_LEN); - let res = frame::Data::load(head, bytes.freeze()); - - // TODO: Should this always be connection level? Probably not... - res.map_err(|e| { - proto_err!(conn: "failed to load DATA frame; err={:?}", e); - Connection(Reason::PROTOCOL_ERROR) - })? - .into() - } - Kind::Headers => header_block!(Headers, head, bytes), - Kind::Reset => { - let res = frame::Reset::load(head, &bytes[frame::HEADER_LEN..]); - res.map_err(|e| { - proto_err!(conn: "failed to load RESET frame; err={:?}", e); - Connection(Reason::PROTOCOL_ERROR) - })? - .into() - } - Kind::GoAway => { - let res = frame::GoAway::load(&bytes[frame::HEADER_LEN..]); - res.map_err(|e| { - proto_err!(conn: "failed to load GO_AWAY frame; err={:?}", e); - Connection(Reason::PROTOCOL_ERROR) - })? - .into() - } - Kind::PushPromise => header_block!(PushPromise, head, bytes), - Kind::Priority => { - if head.stream_id() == 0 { - // Invalid stream identifier - proto_err!(conn: "invalid stream ID 0"); - return Err(Connection(Reason::PROTOCOL_ERROR)); - } - - match frame::Priority::load(head, &bytes[frame::HEADER_LEN..]) { - Ok(frame) => frame.into(), - Err(frame::Error::InvalidDependencyId) => { - // A stream cannot depend on itself. An endpoint MUST - // treat this as a stream error (Section 5.4.2) of type - // `PROTOCOL_ERROR`. - let id = head.stream_id(); - proto_err!(stream: "PRIORITY invalid dependency ID; stream={:?}", id); - return Err(Stream { - id, - reason: Reason::PROTOCOL_ERROR, - }); - } - Err(e) => { - proto_err!(conn: "failed to load PRIORITY frame; err={:?};", e); - return Err(Connection(Reason::PROTOCOL_ERROR)); - } - } - } - Kind::Continuation => { - let is_end_headers = (head.flag() & 0x4) == 0x4; - - let mut partial = match self.partial.take() { - Some(partial) => partial, - None => { - proto_err!(conn: "received unexpected CONTINUATION frame"); - return Err(Connection(Reason::PROTOCOL_ERROR)); - } - }; - - // The stream identifiers must match - if partial.frame.stream_id() != head.stream_id() { - proto_err!(conn: "CONTINUATION frame stream ID does not match previous frame stream ID"); - return Err(Connection(Reason::PROTOCOL_ERROR)); - } - - // Extend the buf - if partial.buf.is_empty() { - partial.buf = bytes.split_off(frame::HEADER_LEN); - } else { - if partial.frame.is_over_size() { - // If there was left over bytes previously, they may be - // needed to continue decoding, even though we will - // be ignoring this frame. This is done to keep the HPACK - // decoder state up-to-date. - // - // Still, we need to be careful, because if a malicious - // attacker were to try to send a gigantic string, such - // that it fits over multiple header blocks, we could - // grow memory uncontrollably again, and that'd be a shame. - // - // Instead, we use a simple heuristic to determine if - // we should continue to ignore decoding, or to tell - // the attacker to go away. - if partial.buf.len() + bytes.len() > self.max_header_list_size { - proto_err!(conn: "CONTINUATION frame header block size over ignorable limit"); - return Err(Connection(Reason::COMPRESSION_ERROR)); - } - } - partial.buf.extend_from_slice(&bytes[frame::HEADER_LEN..]); - } - - match partial.frame.load_hpack( - &mut partial.buf, - self.max_header_list_size, - &mut self.hpack, - ) { - Ok(_) => {} - Err(frame::Error::Hpack(hpack::DecoderError::NeedMore(_))) - if !is_end_headers => {} - Err(frame::Error::MalformedMessage) => { - let id = head.stream_id(); - proto_err!(stream: "malformed CONTINUATION frame; stream={:?}", id); - return Err(Stream { - id, - reason: Reason::PROTOCOL_ERROR, - }); - } - Err(e) => { - proto_err!(conn: "failed HPACK decoding; err={:?}", e); - return Err(Connection(Reason::PROTOCOL_ERROR)); - } - } - - if is_end_headers { - partial.frame.into() - } else { - self.partial = Some(partial); - return Ok(None); - } - } - Kind::Unknown => { - // Unknown frames are ignored - return Ok(None); - } - }; - - Ok(Some(frame)) - } - pub fn get_ref(&self) -> &T { self.inner.get_ref() } @@ -333,6 +90,255 @@ impl FramedRead { } } +/// Decodes a frame. +/// +/// This method is intentionally de-generified and outlined because it is very large. +fn decode_frame( + hpack: &mut hpack::Decoder, + max_header_list_size: usize, + partial_inout: &mut Option, + mut bytes: BytesMut, +) -> Result, RecvError> { + use self::RecvError::*; + let span = tracing::trace_span!("FramedRead::decode_frame", offset = bytes.len()); + let _e = span.enter(); + + tracing::trace!("decoding frame from {}B", bytes.len()); + + // Parse the head + let head = frame::Head::parse(&bytes); + + if partial_inout.is_some() && head.kind() != Kind::Continuation { + proto_err!(conn: "expected CONTINUATION, got {:?}", head.kind()); + return Err(Connection(Reason::PROTOCOL_ERROR)); + } + + let kind = head.kind(); + + tracing::trace!(frame.kind = ?kind); + + macro_rules! header_block { + ($frame:ident, $head:ident, $bytes:ident) => ({ + // Drop the frame header + // TODO: Change to drain: carllerche/bytes#130 + let _ = $bytes.split_to(frame::HEADER_LEN); + + // Parse the header frame w/o parsing the payload + let (mut frame, mut payload) = match frame::$frame::load($head, $bytes) { + Ok(res) => res, + Err(frame::Error::InvalidDependencyId) => { + proto_err!(stream: "invalid HEADERS dependency ID"); + // A stream cannot depend on itself. An endpoint MUST + // treat this as a stream error (Section 5.4.2) of type + // `PROTOCOL_ERROR`. + return Err(Stream { + id: $head.stream_id(), + reason: Reason::PROTOCOL_ERROR, + }); + }, + Err(e) => { + proto_err!(conn: "failed to load frame; err={:?}", e); + return Err(Connection(Reason::PROTOCOL_ERROR)); + } + }; + + let is_end_headers = frame.is_end_headers(); + + // Load the HPACK encoded headers + match frame.load_hpack(&mut payload, max_header_list_size, hpack) { + Ok(_) => {}, + Err(frame::Error::Hpack(hpack::DecoderError::NeedMore(_))) if !is_end_headers => {}, + Err(frame::Error::MalformedMessage) => { + let id = $head.stream_id(); + proto_err!(stream: "malformed header block; stream={:?}", id); + return Err(Stream { + id, + reason: Reason::PROTOCOL_ERROR, + }); + }, + Err(e) => { + proto_err!(conn: "failed HPACK decoding; err={:?}", e); + return Err(Connection(Reason::PROTOCOL_ERROR)); + } + } + + if is_end_headers { + frame.into() + } else { + tracing::trace!("loaded partial header block"); + // Defer returning the frame + *partial_inout = Some(Partial { + frame: Continuable::$frame(frame), + buf: payload, + }); + + return Ok(None); + } + }); + } + + let frame = match kind { + Kind::Settings => { + let res = frame::Settings::load(head, &bytes[frame::HEADER_LEN..]); + + res.map_err(|e| { + proto_err!(conn: "failed to load SETTINGS frame; err={:?}", e); + Connection(Reason::PROTOCOL_ERROR) + })? + .into() + } + Kind::Ping => { + let res = frame::Ping::load(head, &bytes[frame::HEADER_LEN..]); + + res.map_err(|e| { + proto_err!(conn: "failed to load PING frame; err={:?}", e); + Connection(Reason::PROTOCOL_ERROR) + })? + .into() + } + Kind::WindowUpdate => { + let res = frame::WindowUpdate::load(head, &bytes[frame::HEADER_LEN..]); + + res.map_err(|e| { + proto_err!(conn: "failed to load WINDOW_UPDATE frame; err={:?}", e); + Connection(Reason::PROTOCOL_ERROR) + })? + .into() + } + Kind::Data => { + let _ = bytes.split_to(frame::HEADER_LEN); + let res = frame::Data::load(head, bytes.freeze()); + + // TODO: Should this always be connection level? Probably not... + res.map_err(|e| { + proto_err!(conn: "failed to load DATA frame; err={:?}", e); + Connection(Reason::PROTOCOL_ERROR) + })? + .into() + } + Kind::Headers => header_block!(Headers, head, bytes), + Kind::Reset => { + let res = frame::Reset::load(head, &bytes[frame::HEADER_LEN..]); + res.map_err(|e| { + proto_err!(conn: "failed to load RESET frame; err={:?}", e); + Connection(Reason::PROTOCOL_ERROR) + })? + .into() + } + Kind::GoAway => { + let res = frame::GoAway::load(&bytes[frame::HEADER_LEN..]); + res.map_err(|e| { + proto_err!(conn: "failed to load GO_AWAY frame; err={:?}", e); + Connection(Reason::PROTOCOL_ERROR) + })? + .into() + } + Kind::PushPromise => header_block!(PushPromise, head, bytes), + Kind::Priority => { + if head.stream_id() == 0 { + // Invalid stream identifier + proto_err!(conn: "invalid stream ID 0"); + return Err(Connection(Reason::PROTOCOL_ERROR)); + } + + match frame::Priority::load(head, &bytes[frame::HEADER_LEN..]) { + Ok(frame) => frame.into(), + Err(frame::Error::InvalidDependencyId) => { + // A stream cannot depend on itself. An endpoint MUST + // treat this as a stream error (Section 5.4.2) of type + // `PROTOCOL_ERROR`. + let id = head.stream_id(); + proto_err!(stream: "PRIORITY invalid dependency ID; stream={:?}", id); + return Err(Stream { + id, + reason: Reason::PROTOCOL_ERROR, + }); + } + Err(e) => { + proto_err!(conn: "failed to load PRIORITY frame; err={:?};", e); + return Err(Connection(Reason::PROTOCOL_ERROR)); + } + } + } + Kind::Continuation => { + let is_end_headers = (head.flag() & 0x4) == 0x4; + + let mut partial = match partial_inout.take() { + Some(partial) => partial, + None => { + proto_err!(conn: "received unexpected CONTINUATION frame"); + return Err(Connection(Reason::PROTOCOL_ERROR)); + } + }; + + // The stream identifiers must match + if partial.frame.stream_id() != head.stream_id() { + proto_err!(conn: "CONTINUATION frame stream ID does not match previous frame stream ID"); + return Err(Connection(Reason::PROTOCOL_ERROR)); + } + + // Extend the buf + if partial.buf.is_empty() { + partial.buf = bytes.split_off(frame::HEADER_LEN); + } else { + if partial.frame.is_over_size() { + // If there was left over bytes previously, they may be + // needed to continue decoding, even though we will + // be ignoring this frame. This is done to keep the HPACK + // decoder state up-to-date. + // + // Still, we need to be careful, because if a malicious + // attacker were to try to send a gigantic string, such + // that it fits over multiple header blocks, we could + // grow memory uncontrollably again, and that'd be a shame. + // + // Instead, we use a simple heuristic to determine if + // we should continue to ignore decoding, or to tell + // the attacker to go away. + if partial.buf.len() + bytes.len() > max_header_list_size { + proto_err!(conn: "CONTINUATION frame header block size over ignorable limit"); + return Err(Connection(Reason::COMPRESSION_ERROR)); + } + } + partial.buf.extend_from_slice(&bytes[frame::HEADER_LEN..]); + } + + match partial + .frame + .load_hpack(&mut partial.buf, max_header_list_size, hpack) + { + Ok(_) => {} + Err(frame::Error::Hpack(hpack::DecoderError::NeedMore(_))) if !is_end_headers => {} + Err(frame::Error::MalformedMessage) => { + let id = head.stream_id(); + proto_err!(stream: "malformed CONTINUATION frame; stream={:?}", id); + return Err(Stream { + id, + reason: Reason::PROTOCOL_ERROR, + }); + } + Err(e) => { + proto_err!(conn: "failed HPACK decoding; err={:?}", e); + return Err(Connection(Reason::PROTOCOL_ERROR)); + } + } + + if is_end_headers { + partial.frame.into() + } else { + *partial_inout = Some(partial); + return Ok(None); + } + } + Kind::Unknown => { + // Unknown frames are ignored + return Ok(None); + } + }; + + Ok(Some(frame)) +} + impl Stream for FramedRead where T: AsyncRead + Unpin, @@ -351,7 +357,13 @@ where }; tracing::trace!(read.bytes = bytes.len()); - if let Some(frame) = self.decode_frame(bytes)? { + let Self { + ref mut hpack, + max_header_list_size, + ref mut partial, + .. + } = *self; + if let Some(frame) = decode_frame(hpack, max_header_list_size, partial, bytes)? { tracing::debug!(?frame, "received"); return Poll::Ready(Some(Ok(frame))); }