From 21f7e54ce8f3ded3d2e9bac3367511795d1d9a09 Mon Sep 17 00:00:00 2001 From: Sean McArthur Date: Fri, 15 Sep 2017 16:45:32 -0700 Subject: [PATCH] load headers when receiving PushPromise frames --- src/codec/framed_read.rs | 99 +++++++++++++++++------ src/frame/headers.rs | 171 +++++++++++++++++++++++---------------- 2 files changed, 177 insertions(+), 93 deletions(-) diff --git a/src/codec/framed_read.rs b/src/codec/framed_read.rs index 90f2094..c91c959 100644 --- a/src/codec/framed_read.rs +++ b/src/codec/framed_read.rs @@ -39,7 +39,7 @@ enum Continuable { Headers(frame::Headers), // Decode the Continuation frame but ignore it... // Ignore(StreamId), - // PushPromise(frame::PushPromise), + PushPromise(frame::PushPromise), } impl FramedRead { @@ -143,8 +143,38 @@ impl FramedRead { res.map_err(|_| Connection(ProtocolError))?.into() }, Kind::PushPromise => { - let res = frame::PushPromise::load(head, &bytes[frame::HEADER_LEN..]); - res.map_err(|_| Connection(ProtocolError))?.into() + // Drop the frame header + // TODO: Change to drain: carllerche/bytes#130 + let _ = bytes.split_to(frame::HEADER_LEN); + + // Parse the frame w/o parsing the payload + let (mut push, payload) = frame::PushPromise::load(head, bytes) + .map_err(|_| Connection(ProtocolError))?; + + if push.is_end_headers() { + // Load the HPACK encoded headers & return the frame + match push.load_hpack(payload, &mut self.hpack) { + Ok(_) => {}, + Err(frame::Error::MalformedMessage) => { + return Err(Stream { + id: head.stream_id(), + reason: ProtocolError, + }); + }, + Err(_) => return Err(Connection(ProtocolError)), + } + + push.into() + } else { + // Defer loading the frame + self.partial = Some(Partial { + frame: Continuable::PushPromise(push), + buf: payload, + }); + + return Ok(None); + } + }, Kind::Priority => { if head.stream_id() == 0 { @@ -183,27 +213,23 @@ impl FramedRead { return Ok(None); } - match partial.frame { - Continuable::Headers(mut frame) => { - // The stream identifiers must match - if frame.stream_id() != head.stream_id() { - return Err(Connection(ProtocolError)); - } - - match frame.load_hpack(partial.buf, &mut self.hpack) { - Ok(_) => {}, - Err(frame::Error::MalformedMessage) => { - return Err(Stream { - id: head.stream_id(), - reason: ProtocolError, - }); - }, - Err(_) => return Err(Connection(ProtocolError)), - } - - frame.into() - }, + // The stream identifiers must match + if partial.frame.stream_id() != head.stream_id() { + return Err(Connection(ProtocolError)); } + + match partial.frame.load_hpack(partial.buf, &mut self.hpack) { + Ok(_) => {}, + Err(frame::Error::MalformedMessage) => { + return Err(Stream { + id: head.stream_id(), + reason: ProtocolError, + }); + }, + Err(_) => return Err(Connection(ProtocolError)), + } + + partial.frame.into() }, Kind::Unknown => { // Unknown frames are ignored @@ -276,3 +302,30 @@ fn map_err(err: io::Error) -> RecvError { } err.into() } + +// ===== impl Continuable ===== + +impl Continuable { + fn stream_id(&self) -> frame::StreamId { + match *self { + Continuable::Headers(ref h) => h.stream_id(), + Continuable::PushPromise(ref p) => p.stream_id(), + } + } + + fn load_hpack(&mut self, src: BytesMut, decoder: &mut hpack::Decoder) -> Result<(), frame::Error> { + match *self { + Continuable::Headers(ref mut h) => h.load_hpack(src, decoder), + Continuable::PushPromise(ref mut p) => p.load_hpack(src, decoder), + } + } +} + +impl From for Frame { + fn from(cont: Continuable) -> Self { + match cont { + Continuable::Headers(headers) => headers.into(), + Continuable::PushPromise(push) => push.into(), + } + } +} diff --git a/src/frame/headers.rs b/src/frame/headers.rs index c261f16..4a4688e 100644 --- a/src/frame/headers.rs +++ b/src/frame/headers.rs @@ -188,71 +188,7 @@ impl Headers { } pub fn load_hpack(&mut self, src: BytesMut, decoder: &mut hpack::Decoder) -> Result<(), Error> { - let mut reg = false; - let mut malformed = false; - - macro_rules! set_pseudo { - ($field:ident, $val:expr) => {{ - if reg { - trace!("load_hpack; header malformed -- pseudo not at head of block"); - malformed = true; - } else if self.header_block.pseudo.$field.is_some() { - trace!("load_hpack; header malformed -- repeated pseudo"); - malformed = true; - } else { - self.header_block.pseudo.$field = Some($val); - } - }} - } - - let mut src = Cursor::new(src.freeze()); - - // At this point, we're going to assume that the hpack encoded headers - // contain the entire payload. Later, we need to check for stream - // priority. - // - // TODO: Provide a way to abort decoding if an error is hit. - let res = decoder.decode(&mut src, |header| { - use hpack::Header::*; - - match header { - Field { - name, - value, - } => { - // Connection level header fields are not supported and must - // result in a protocol error. - - if name == header::CONNECTION { - trace!("load_hpack; connection level header"); - malformed = true; - } else if name == header::TE && value != "trailers" { - trace!("load_hpack; TE header not set to trailers; val={:?}", value); - malformed = true; - } else { - reg = true; - self.header_block.fields.append(name, value); - } - }, - Authority(v) => set_pseudo!(authority, v), - Method(v) => set_pseudo!(method, v), - Scheme(v) => set_pseudo!(scheme, v), - Path(v) => set_pseudo!(path, v), - Status(v) => set_pseudo!(status, v), - } - }); - - if let Err(e) = res { - trace!("hpack decoding error; err={:?}", e); - return Err(e.into()); - } - - if malformed { - trace!("malformed message"); - return Err(Error::MalformedMessage.into()); - } - - Ok(()) + self.header_block.load(src, decoder) } pub fn stream_id(&self) -> StreamId { @@ -343,14 +279,36 @@ impl PushPromise { } } - pub fn load(head: Head, payload: &[u8]) -> Result { + /// Loads the push promise frame but doesn't actually do HPACK decoding. + /// + /// HPACK decoding is done in the `load_hpack` step. + pub fn load(head: Head, mut src: BytesMut) -> Result<(Self, BytesMut), Error> { let flags = PushPromiseFlag(head.flag()); + let mut pad = 0; - // TODO: Handle padding + // Read the padding length + if flags.is_padded() { + // TODO: Ensure payload is sized correctly + pad = src[0] as usize; - let (promised_id, _) = StreamId::parse(&payload[..4]); + // Drop the padding + let _ = src.split_to(1); + } - Ok(PushPromise { + let (promised_id, _) = StreamId::parse(&src[..4]); + // Drop promised_id bytes + let _ = src.split_to(5); + + if pad > 0 { + if pad > src.len() { + return Err(Error::TooMuchPadding); + } + + let len = src.len() - pad; + src.truncate(len); + } + + let frame = PushPromise { flags: flags, header_block: HeaderBlock { fields: HeaderMap::new(), @@ -358,7 +316,12 @@ impl PushPromise { }, promised_id: promised_id, stream_id: head.stream_id(), - }) + }; + Ok((frame, src)) + } + + pub fn load_hpack(&mut self, src: BytesMut, decoder: &mut hpack::Decoder) -> Result<(), Error> { + self.header_block.load(src, decoder) } pub fn stream_id(&self) -> StreamId { @@ -626,6 +589,74 @@ impl fmt::Debug for PushPromiseFlag { // ===== HeaderBlock ===== impl HeaderBlock { + fn load(&mut self, src: BytesMut, decoder: &mut hpack::Decoder) -> Result<(), Error> { + let mut reg = false; + let mut malformed = false; + + macro_rules! set_pseudo { + ($field:ident, $val:expr) => {{ + if reg { + trace!("load_hpack; header malformed -- pseudo not at head of block"); + malformed = true; + } else if self.pseudo.$field.is_some() { + trace!("load_hpack; header malformed -- repeated pseudo"); + malformed = true; + } else { + self.pseudo.$field = Some($val); + } + }} + } + + let mut src = Cursor::new(src.freeze()); + + // At this point, we're going to assume that the hpack encoded headers + // contain the entire payload. Later, we need to check for stream + // priority. + // + // TODO: Provide a way to abort decoding if an error is hit. + let res = decoder.decode(&mut src, |header| { + use hpack::Header::*; + + match header { + Field { + name, + value, + } => { + // Connection level header fields are not supported and must + // result in a protocol error. + + if name == header::CONNECTION { + trace!("load_hpack; connection level header"); + malformed = true; + } else if name == header::TE && value != "trailers" { + trace!("load_hpack; TE header not set to trailers; val={:?}", value); + malformed = true; + } else { + reg = true; + self.fields.append(name, value); + } + }, + Authority(v) => set_pseudo!(authority, v), + Method(v) => set_pseudo!(method, v), + Scheme(v) => set_pseudo!(scheme, v), + Path(v) => set_pseudo!(path, v), + Status(v) => set_pseudo!(status, v), + } + }); + + if let Err(e) = res { + trace!("hpack decoding error; err={:?}", e); + return Err(e.into()); + } + + if malformed { + trace!("malformed message"); + return Err(Error::MalformedMessage.into()); + } + + Ok(()) + } + fn encode( self, stream_id: StreamId,