load headers when receiving PushPromise frames

This commit is contained in:
Sean McArthur
2017-09-15 16:45:32 -07:00
parent a8a4cd2be1
commit 21f7e54ce8
2 changed files with 177 additions and 93 deletions

View File

@@ -39,7 +39,7 @@ enum Continuable {
Headers(frame::Headers),
// Decode the Continuation frame but ignore it...
// Ignore(StreamId),
// PushPromise(frame::PushPromise),
PushPromise(frame::PushPromise),
}
impl<T> FramedRead<T> {
@@ -143,8 +143,38 @@ impl<T> FramedRead<T> {
res.map_err(|_| Connection(ProtocolError))?.into()
},
Kind::PushPromise => {
let res = frame::PushPromise::load(head, &bytes[frame::HEADER_LEN..]);
res.map_err(|_| Connection(ProtocolError))?.into()
// Drop the frame header
// TODO: Change to drain: carllerche/bytes#130
let _ = bytes.split_to(frame::HEADER_LEN);
// Parse the frame w/o parsing the payload
let (mut push, payload) = frame::PushPromise::load(head, bytes)
.map_err(|_| Connection(ProtocolError))?;
if push.is_end_headers() {
// Load the HPACK encoded headers & return the frame
match push.load_hpack(payload, &mut self.hpack) {
Ok(_) => {},
Err(frame::Error::MalformedMessage) => {
return Err(Stream {
id: head.stream_id(),
reason: ProtocolError,
});
},
Err(_) => return Err(Connection(ProtocolError)),
}
push.into()
} else {
// Defer loading the frame
self.partial = Some(Partial {
frame: Continuable::PushPromise(push),
buf: payload,
});
return Ok(None);
}
},
Kind::Priority => {
if head.stream_id() == 0 {
@@ -183,27 +213,23 @@ impl<T> FramedRead<T> {
return Ok(None);
}
match partial.frame {
Continuable::Headers(mut frame) => {
// The stream identifiers must match
if frame.stream_id() != head.stream_id() {
return Err(Connection(ProtocolError));
}
match frame.load_hpack(partial.buf, &mut self.hpack) {
Ok(_) => {},
Err(frame::Error::MalformedMessage) => {
return Err(Stream {
id: head.stream_id(),
reason: ProtocolError,
});
},
Err(_) => return Err(Connection(ProtocolError)),
}
frame.into()
},
// The stream identifiers must match
if partial.frame.stream_id() != head.stream_id() {
return Err(Connection(ProtocolError));
}
match partial.frame.load_hpack(partial.buf, &mut self.hpack) {
Ok(_) => {},
Err(frame::Error::MalformedMessage) => {
return Err(Stream {
id: head.stream_id(),
reason: ProtocolError,
});
},
Err(_) => return Err(Connection(ProtocolError)),
}
partial.frame.into()
},
Kind::Unknown => {
// Unknown frames are ignored
@@ -276,3 +302,30 @@ fn map_err(err: io::Error) -> RecvError {
}
err.into()
}
// ===== impl Continuable =====
impl Continuable {
fn stream_id(&self) -> frame::StreamId {
match *self {
Continuable::Headers(ref h) => h.stream_id(),
Continuable::PushPromise(ref p) => p.stream_id(),
}
}
fn load_hpack(&mut self, src: BytesMut, decoder: &mut hpack::Decoder) -> Result<(), frame::Error> {
match *self {
Continuable::Headers(ref mut h) => h.load_hpack(src, decoder),
Continuable::PushPromise(ref mut p) => p.load_hpack(src, decoder),
}
}
}
impl<T> From<Continuable> for Frame<T> {
fn from(cont: Continuable) -> Self {
match cont {
Continuable::Headers(headers) => headers.into(),
Continuable::PushPromise(push) => push.into(),
}
}
}

View File

@@ -188,71 +188,7 @@ impl Headers {
}
pub fn load_hpack(&mut self, src: BytesMut, decoder: &mut hpack::Decoder) -> Result<(), Error> {
let mut reg = false;
let mut malformed = false;
macro_rules! set_pseudo {
($field:ident, $val:expr) => {{
if reg {
trace!("load_hpack; header malformed -- pseudo not at head of block");
malformed = true;
} else if self.header_block.pseudo.$field.is_some() {
trace!("load_hpack; header malformed -- repeated pseudo");
malformed = true;
} else {
self.header_block.pseudo.$field = Some($val);
}
}}
}
let mut src = Cursor::new(src.freeze());
// At this point, we're going to assume that the hpack encoded headers
// contain the entire payload. Later, we need to check for stream
// priority.
//
// TODO: Provide a way to abort decoding if an error is hit.
let res = decoder.decode(&mut src, |header| {
use hpack::Header::*;
match header {
Field {
name,
value,
} => {
// Connection level header fields are not supported and must
// result in a protocol error.
if name == header::CONNECTION {
trace!("load_hpack; connection level header");
malformed = true;
} else if name == header::TE && value != "trailers" {
trace!("load_hpack; TE header not set to trailers; val={:?}", value);
malformed = true;
} else {
reg = true;
self.header_block.fields.append(name, value);
}
},
Authority(v) => set_pseudo!(authority, v),
Method(v) => set_pseudo!(method, v),
Scheme(v) => set_pseudo!(scheme, v),
Path(v) => set_pseudo!(path, v),
Status(v) => set_pseudo!(status, v),
}
});
if let Err(e) = res {
trace!("hpack decoding error; err={:?}", e);
return Err(e.into());
}
if malformed {
trace!("malformed message");
return Err(Error::MalformedMessage.into());
}
Ok(())
self.header_block.load(src, decoder)
}
pub fn stream_id(&self) -> StreamId {
@@ -343,14 +279,36 @@ impl PushPromise {
}
}
pub fn load(head: Head, payload: &[u8]) -> Result<Self, Error> {
/// Loads the push promise frame but doesn't actually do HPACK decoding.
///
/// HPACK decoding is done in the `load_hpack` step.
pub fn load(head: Head, mut src: BytesMut) -> Result<(Self, BytesMut), Error> {
let flags = PushPromiseFlag(head.flag());
let mut pad = 0;
// TODO: Handle padding
// Read the padding length
if flags.is_padded() {
// TODO: Ensure payload is sized correctly
pad = src[0] as usize;
let (promised_id, _) = StreamId::parse(&payload[..4]);
// Drop the padding
let _ = src.split_to(1);
}
Ok(PushPromise {
let (promised_id, _) = StreamId::parse(&src[..4]);
// Drop promised_id bytes
let _ = src.split_to(5);
if pad > 0 {
if pad > src.len() {
return Err(Error::TooMuchPadding);
}
let len = src.len() - pad;
src.truncate(len);
}
let frame = PushPromise {
flags: flags,
header_block: HeaderBlock {
fields: HeaderMap::new(),
@@ -358,7 +316,12 @@ impl PushPromise {
},
promised_id: promised_id,
stream_id: head.stream_id(),
})
};
Ok((frame, src))
}
pub fn load_hpack(&mut self, src: BytesMut, decoder: &mut hpack::Decoder) -> Result<(), Error> {
self.header_block.load(src, decoder)
}
pub fn stream_id(&self) -> StreamId {
@@ -626,6 +589,74 @@ impl fmt::Debug for PushPromiseFlag {
// ===== HeaderBlock =====
impl HeaderBlock {
fn load(&mut self, src: BytesMut, decoder: &mut hpack::Decoder) -> Result<(), Error> {
let mut reg = false;
let mut malformed = false;
macro_rules! set_pseudo {
($field:ident, $val:expr) => {{
if reg {
trace!("load_hpack; header malformed -- pseudo not at head of block");
malformed = true;
} else if self.pseudo.$field.is_some() {
trace!("load_hpack; header malformed -- repeated pseudo");
malformed = true;
} else {
self.pseudo.$field = Some($val);
}
}}
}
let mut src = Cursor::new(src.freeze());
// At this point, we're going to assume that the hpack encoded headers
// contain the entire payload. Later, we need to check for stream
// priority.
//
// TODO: Provide a way to abort decoding if an error is hit.
let res = decoder.decode(&mut src, |header| {
use hpack::Header::*;
match header {
Field {
name,
value,
} => {
// Connection level header fields are not supported and must
// result in a protocol error.
if name == header::CONNECTION {
trace!("load_hpack; connection level header");
malformed = true;
} else if name == header::TE && value != "trailers" {
trace!("load_hpack; TE header not set to trailers; val={:?}", value);
malformed = true;
} else {
reg = true;
self.fields.append(name, value);
}
},
Authority(v) => set_pseudo!(authority, v),
Method(v) => set_pseudo!(method, v),
Scheme(v) => set_pseudo!(scheme, v),
Path(v) => set_pseudo!(path, v),
Status(v) => set_pseudo!(status, v),
}
});
if let Err(e) = res {
trace!("hpack decoding error; err={:?}", e);
return Err(e.into());
}
if malformed {
trace!("malformed message");
return Err(Error::MalformedMessage.into());
}
Ok(())
}
fn encode(
self,
stream_id: StreamId,