Support writing continuation frames. (#198)

Large header sets might require being split up across multiple frames.
This patch adds support for doing so.
This commit is contained in:
Carl Lerche
2017-12-20 17:24:29 -08:00
committed by GitHub
parent a89401dd91
commit fc75311fae
4 changed files with 294 additions and 94 deletions

View File

@@ -59,6 +59,7 @@ impl<T> FramedRead<T> {
let head = frame::Head::parse(&bytes);
if self.partial.is_some() && head.kind() != Kind::Continuation {
trace!("connection error PROTOCOL_ERROR -- expected CONTINUATION, got {:?}", head.kind());
return Err(Connection(Reason::PROTOCOL_ERROR));
}
@@ -70,24 +71,36 @@ impl<T> FramedRead<T> {
Kind::Settings => {
let res = frame::Settings::load(head, &bytes[frame::HEADER_LEN..]);
res.map_err(|_| Connection(Reason::PROTOCOL_ERROR))?.into()
res.map_err(|e| {
debug!("connection error PROTOCOL_ERROR -- failed to load SETTINGS frame; err={:?}", e);
Connection(Reason::PROTOCOL_ERROR)
})?.into()
},
Kind::Ping => {
let res = frame::Ping::load(head, &bytes[frame::HEADER_LEN..]);
res.map_err(|_| Connection(Reason::PROTOCOL_ERROR))?.into()
res.map_err(|e| {
debug!("connection error PROTOCOL_ERROR -- failed to load PING frame; err={:?}", e);
Connection(Reason::PROTOCOL_ERROR)
})?.into()
},
Kind::WindowUpdate => {
let res = frame::WindowUpdate::load(head, &bytes[frame::HEADER_LEN..]);
res.map_err(|_| Connection(Reason::PROTOCOL_ERROR))?.into()
res.map_err(|e| {
debug!("connection error PROTOCOL_ERROR -- failed to load WINDOW_UPDATE frame; err={:?}", e);
Connection(Reason::PROTOCOL_ERROR)
})?.into()
},
Kind::Data => {
let _ = bytes.split_to(frame::HEADER_LEN);
let res = frame::Data::load(head, bytes.freeze());
// TODO: Should this always be connection level? Probably not...
res.map_err(|_| Connection(Reason::PROTOCOL_ERROR))?.into()
res.map_err(|e| {
debug!("connection error PROTOCOL_ERROR -- failed to load DATA frame; err={:?}", e);
Connection(Reason::PROTOCOL_ERROR)
})?.into()
},
Kind::Headers => {
// Drop the frame header
@@ -101,12 +114,16 @@ impl<T> FramedRead<T> {
// A stream cannot depend on itself. An endpoint MUST
// treat this as a stream error (Section 5.4.2) of type
// `PROTOCOL_ERROR`.
debug!("stream error PROTOCOL_ERROR -- invalid HEADERS dependency ID");
return Err(Stream {
id: head.stream_id(),
reason: Reason::PROTOCOL_ERROR,
});
},
_ => return Err(Connection(Reason::PROTOCOL_ERROR)),
Err(e) => {
debug!("connection error PROTOCOL_ERROR -- failed to load HEADERS frame; err={:?}", e);
return Err(Connection(Reason::PROTOCOL_ERROR));
}
};
if headers.is_end_headers() {
@@ -114,12 +131,16 @@ impl<T> FramedRead<T> {
match headers.load_hpack(payload, &mut self.hpack) {
Ok(_) => {},
Err(frame::Error::MalformedMessage) => {
debug!("stream error PROTOCOL_ERROR -- malformed HEADERS frame");
return Err(Stream {
id: head.stream_id(),
reason: Reason::PROTOCOL_ERROR,
});
},
Err(_) => return Err(Connection(Reason::PROTOCOL_ERROR)),
Err(e) => {
debug!("connection error PROTOCOL_ERROR -- failed HEADERS frame HPACK decoding; err={:?}", e);
return Err(Connection(Reason::PROTOCOL_ERROR));
}
}
headers.into()
@@ -148,19 +169,26 @@ impl<T> FramedRead<T> {
// Parse the frame w/o parsing the payload
let (mut push, payload) = frame::PushPromise::load(head, bytes)
.map_err(|_| Connection(Reason::PROTOCOL_ERROR))?;
.map_err(|e| {
debug!("connection error PROTOCOL_ERROR -- failed to load PUSH_PROMISE frame; err={:?}", e);
Connection(Reason::PROTOCOL_ERROR)
})?;
if push.is_end_headers() {
// Load the HPACK encoded headers & return the frame
match push.load_hpack(payload, &mut self.hpack) {
Ok(_) => {},
Err(frame::Error::MalformedMessage) => {
debug!("stream error PROTOCOL_ERROR -- malformed PUSH_PROMISE frame");
return Err(Stream {
id: head.stream_id(),
reason: Reason::PROTOCOL_ERROR,
});
},
Err(_) => return Err(Connection(Reason::PROTOCOL_ERROR)),
Err(e) => {
debug!("connection error PROTOCOL_ERROR -- failed PUSH_PROMISE frame HPACK decoding; err={:?}", e);
return Err(Connection(Reason::PROTOCOL_ERROR));
}
}
push.into()
@@ -186,6 +214,7 @@ impl<T> FramedRead<T> {
// A stream cannot depend on itself. An endpoint MUST
// treat this as a stream error (Section 5.4.2) of type
// `PROTOCOL_ERROR`.
debug!("stream error PROTOCOL_ERROR -- PRIORITY invalid dependency ID");
return Err(Stream {
id: head.stream_id(),
reason: Reason::PROTOCOL_ERROR,
@@ -200,7 +229,10 @@ impl<T> FramedRead<T> {
let mut partial = match self.partial.take() {
Some(partial) => partial,
None => return Err(Connection(Reason::PROTOCOL_ERROR)),
None => {
debug!("connection error PROTOCOL_ERROR -- received unexpected CONTINUATION frame");
return Err(Connection(Reason::PROTOCOL_ERROR));
}
};
// Extend the buf
@@ -213,12 +245,14 @@ impl<T> FramedRead<T> {
// The stream identifiers must match
if partial.frame.stream_id() != head.stream_id() {
debug!("connection error PROTOCOL_ERROR -- CONTINUATION frame stream ID does not match previous frame stream ID");
return Err(Connection(Reason::PROTOCOL_ERROR));
}
match partial.frame.load_hpack(partial.buf, &mut self.hpack) {
Ok(_) => {},
Err(frame::Error::MalformedMessage) => {
debug!("stream error PROTOCOL_ERROR -- malformed CONTINUATION frame");
return Err(Stream {
id: head.stream_id(),
reason: Reason::PROTOCOL_ERROR,
@@ -326,8 +360,14 @@ impl Continuable {
impl<T> From<Continuable> for Frame<T> {
fn from(cont: Continuable) -> Self {
match cont {
Continuable::Headers(headers) => headers.into(),
Continuable::PushPromise(push) => push.into(),
Continuable::Headers(mut headers) => {
headers.set_end_headers();
headers.into()
}
Continuable::PushPromise(mut push) => {
push.set_end_headers();
push.into()
}
}
}
}

View File

@@ -168,37 +168,48 @@ where
pub fn flush(&mut self) -> Poll<(), io::Error> {
trace!("flush");
while !self.is_empty() {
match self.next {
Some(Next::Data(ref mut frame)) => {
let mut buf = Buf::by_ref(&mut self.buf).chain(frame.payload_mut());
try_ready!(self.inner.write_buf(&mut buf));
},
_ => {
try_ready!(self.inner.write_buf(&mut self.buf));
},
loop {
while !self.is_empty() {
match self.next {
Some(Next::Data(ref mut frame)) => {
trace!(" -> queued data frame");
let mut buf = Buf::by_ref(&mut self.buf).chain(frame.payload_mut());
try_ready!(self.inner.write_buf(&mut buf));
},
_ => {
trace!(" -> not a queued data frame");
try_ready!(self.inner.write_buf(&mut self.buf));
},
}
}
}
// The data frame has been written, so unset it
match self.next.take() {
Some(Next::Data(frame)) => {
self.last_data_frame = Some(frame);
},
Some(Next::Continuation(_)) => {
unimplemented!();
},
None => {},
// Clear internal buffer
self.buf.set_position(0);
self.buf.get_mut().clear();
// The data frame has been written, so unset it
match self.next.take() {
Some(Next::Data(frame)) => {
self.last_data_frame = Some(frame);
debug_assert!(self.is_empty());
break;
},
Some(Next::Continuation(frame)) => {
// Buffer the continuation frame, then try to write again
if let Some(continuation) = frame.encode(&mut self.hpack, self.buf.get_mut()) {
self.next = Some(Next::Continuation(continuation));
}
},
None => {
break;
}
}
}
trace!("flushing buffer");
// Flush the upstream
try_nb!(self.inner.flush());
// Clear internal buffer
self.buf.set_position(0);
self.buf.get_mut().clear();
Ok(Async::Ready(()))
}

View File

@@ -56,11 +56,7 @@ pub struct Continuation {
/// Stream ID of continuation frame
stream_id: StreamId,
/// Argument to pass to the HPACK encoder to resume encoding
hpack: hpack::EncodeState,
/// remaining headers to encode
headers: Iter,
header_block: EncodingHeaderBlock,
}
// TODO: These fields shouldn't be `pub`
@@ -85,7 +81,7 @@ pub struct Iter {
fields: header::IntoIter<HeaderValue>,
}
#[derive(PartialEq, Eq)]
#[derive(Debug, PartialEq, Eq)]
struct HeaderBlock {
/// The decoded header fields
fields: HeaderMap,
@@ -95,6 +91,15 @@ struct HeaderBlock {
pseudo: Pseudo,
}
#[derive(Debug)]
struct EncodingHeaderBlock {
/// Argument to pass to the HPACK encoder to resume encoding
hpack: Option<hpack::EncodeState>,
/// remaining headers to encode
headers: Iter,
}
const END_STREAM: u8 = 0x1;
const END_HEADERS: u8 = 0x4;
const PADDED: u8 = 0x8;
@@ -200,6 +205,10 @@ impl Headers {
self.flags.is_end_headers()
}
pub fn set_end_headers(&mut self) {
self.flags.set_end_headers();
}
pub fn is_end_stream(&self) -> bool {
self.flags.is_end_stream()
}
@@ -226,21 +235,15 @@ impl Headers {
}
pub fn encode(self, encoder: &mut hpack::Encoder, dst: &mut BytesMut) -> Option<Continuation> {
// At this point, the `is_end_headers` flag should always be set
debug_assert!(self.flags.is_end_headers());
// Get the HEADERS frame head
let head = self.head();
let pos = dst.len();
// At this point, we don't know how big the h2 frame will be.
// So, we write the head with length 0, then write the body, and
// finally write the length once we know the size.
head.encode(0, dst);
// Encode the frame
let (len, cont) = self.header_block.encode(self.stream_id, encoder, dst);
// Write the frame length
BigEndian::write_uint(&mut dst[pos..pos + 3], len, 3);
cont
self.header_block.into_encoding()
.encode(&head, encoder, dst, |_| {
})
}
fn head(&self) -> Head {
@@ -325,25 +328,23 @@ impl PushPromise {
self.flags.is_end_headers()
}
pub fn set_end_headers(&mut self) {
self.flags.set_end_headers();
}
pub fn encode(self, encoder: &mut hpack::Encoder, dst: &mut BytesMut) -> Option<Continuation> {
use bytes::BufMut;
// At this point, the `is_end_headers` flag should always be set
debug_assert!(self.flags.is_end_headers());
let head = self.head();
let pos = dst.len();
let promised_id = self.promised_id;
// At this point, we don't know how big the h2 frame will be.
// So, we write the head with length 0, then write the body, and
// finally write the length once we know the size.
head.encode(0, dst);
// Encode the frame
dst.put_u32::<BigEndian>(self.promised_id.into());
let (len, cont) = self.header_block.encode(self.stream_id, encoder, dst);
// Write the frame length
BigEndian::write_uint(&mut dst[pos..pos + 3], len + 4, 3);
cont
self.header_block.into_encoding()
.encode(&head, encoder, dst, |dst| {
dst.put_u32::<BigEndian>(promised_id.into());
})
}
fn head(&self) -> Head {
@@ -400,6 +401,23 @@ impl fmt::Debug for PushPromise {
}
}
// ===== impl Continuation =====
impl Continuation {
fn head(&self) -> Head {
Head::new(Kind::Continuation, END_HEADERS, self.stream_id)
}
pub fn encode(self, encoder: &mut hpack::Encoder, dst: &mut BytesMut) -> Option<Continuation> {
// Get the CONTINUATION frame head
let head = self.head();
self.header_block
.encode(&head, encoder, dst, |_| {
})
}
}
// ===== impl Pseudo =====
impl Pseudo {
@@ -458,6 +476,58 @@ fn to_string(src: Bytes) -> String<Bytes> {
unsafe { String::from_utf8_unchecked(src) }
}
// ===== impl EncodingHeaderBlock =====
impl EncodingHeaderBlock {
fn encode<F>(mut self,
head: &Head,
encoder: &mut hpack::Encoder,
dst: &mut BytesMut,
f: F)
-> Option<Continuation>
where F: FnOnce(&mut BytesMut),
{
let head_pos = dst.len();
// At this point, we don't know how big the h2 frame will be.
// So, we write the head with length 0, then write the body, and
// finally write the length once we know the size.
head.encode(0, dst);
let payload_pos = dst.len();
f(dst);
// Now, encode the header payload
let continuation = match encoder.encode(self.hpack, &mut self.headers, dst) {
hpack::Encode::Full => None,
hpack::Encode::Partial(state) => Some(Continuation {
stream_id: head.stream_id(),
header_block: EncodingHeaderBlock {
hpack: Some(state),
headers: self.headers,
},
}),
};
// Compute the header block length
let payload_len = (dst.len() - payload_pos) as u64;
// Write the frame length
BigEndian::write_uint(&mut dst[head_pos..head_pos + 3], payload_len, 3);
if continuation.is_some() {
// There will be continuation frames, so the `is_end_headers` flag
// must be unset
debug_assert!(dst[head_pos + 4] & END_HEADERS == END_HEADERS);
dst[head_pos + 4] -= END_HEADERS;
}
continuation
}
}
// ===== impl Iter =====
impl Iterator for Iter {
@@ -515,13 +585,17 @@ impl HeadersFlag {
}
pub fn set_end_stream(&mut self) {
self.0 |= END_STREAM
self.0 |= END_STREAM;
}
pub fn is_end_headers(&self) -> bool {
self.0 & END_HEADERS == END_HEADERS
}
pub fn set_end_headers(&mut self) {
self.0 |= END_HEADERS;
}
pub fn is_padded(&self) -> bool {
self.0 & PADDED == PADDED
}
@@ -570,6 +644,10 @@ impl PushPromiseFlag {
self.0 & END_HEADERS == END_HEADERS
}
pub fn set_end_headers(&mut self) {
self.0 |= END_HEADERS;
}
pub fn is_padded(&self) -> bool {
self.0 & PADDED == PADDED
}
@@ -624,7 +702,10 @@ impl HeaderBlock {
// contain the entire payload. Later, we need to check for stream
// priority.
//
// TODO: Provide a way to abort decoding if an error is hit.
// If the header frame is malformed, we still have to continue decoding
// the headers. A malformed header frame is a stream level error, but
// the hpack state is connection level. In order to maintain correct
// state for other streams, the hpack decoding process must complete.
let res = decoder.decode(&mut src, |header| {
use hpack::Header::*;
@@ -673,30 +754,13 @@ impl HeaderBlock {
Ok(())
}
fn encode(
self,
stream_id: StreamId,
encoder: &mut hpack::Encoder,
dst: &mut BytesMut,
) -> (u64, Option<Continuation>) {
let pos = dst.len();
let mut headers = Iter {
pseudo: Some(self.pseudo),
fields: self.fields.into_iter(),
};
let cont = match encoder.encode(None, &mut headers, dst) {
hpack::Encode::Full => None,
hpack::Encode::Partial(state) => Some(Continuation {
stream_id: stream_id,
hpack: state,
headers: headers,
}),
};
// Compute the header block length
let len = (dst.len() - pos) as u64;
(len, cont)
fn into_encoding(self) -> EncodingHeaderBlock {
EncodingHeaderBlock {
hpack: None,
headers: Iter {
pseudo: Some(self.pseudo),
fields: self.fields.into_iter(),
},
}
}
}

View File

@@ -1 +1,86 @@
#[macro_use]
pub mod support;
use support::prelude::*;
#[test]
fn write_continuation_frames() {
// An invalid dependency ID results in a stream level error. The hpack
// payload should still be decoded.
let _ = ::env_logger::init();
let (io, srv) = mock::new();
let large = build_large_headers();
// Build the large request frame
let frame = large.iter().fold(
frames::headers(1).request("GET", "https://http2.akamai.com/"),
|frame, &(name, ref value)| frame.field(name, &value[..]));
let srv = srv.assert_client_handshake()
.unwrap()
.recv_settings()
.recv_frame(frame.eos())
.send_frame(
frames::headers(1)
.response(204)
.eos(),
)
.close();
let client = Client::handshake(io)
.expect("handshake")
.and_then(|(mut client, conn)| {
let mut request = Request::builder();
request.uri("https://http2.akamai.com/");
for &(name, ref value) in &large {
request.header(name, &value[..]);
}
let request = request
.body(())
.unwrap();
let req = client
.send_request(request, true)
.expect("send_request1")
.0
.then(|res| {
let response = res.unwrap();
assert_eq!(response.status(), StatusCode::NO_CONTENT);
Ok::<_, ()>(())
});
conn.drive(req)
.and_then(move |(h2, _)| {
h2.unwrap()
})
});
client.join(srv).wait().expect("wait");
}
fn build_large_headers() -> Vec<(&'static str, String)> {
vec![
("one", "hello".to_string()),
("two", build_large_string('2', 4 * 1024)),
("three", "three".to_string()),
("four", build_large_string('4', 4 * 1024)),
("five", "five".to_string()),
("six", build_large_string('6', 4 * 1024)),
("seven", "seven".to_string()),
("eight", build_large_string('8', 4 * 1024)),
("nine", "nine".to_string()),
("ten", build_large_string('0', 4 * 1024)),
]
}
fn build_large_string(ch: char, len: usize) -> String {
let mut ret = String::new();
for _ in 0..len {
ret.push(ch);
}
ret
}