Support writing continuation frames. (#198)
Large header sets might require being split up across multiple frames. This patch adds support for doing so.
This commit is contained in:
		| @@ -56,11 +56,7 @@ pub struct Continuation { | ||||
|     /// Stream ID of continuation frame | ||||
|     stream_id: StreamId, | ||||
|  | ||||
|     /// Argument to pass to the HPACK encoder to resume encoding | ||||
|     hpack: hpack::EncodeState, | ||||
|  | ||||
|     /// remaining headers to encode | ||||
|     headers: Iter, | ||||
|     header_block: EncodingHeaderBlock, | ||||
| } | ||||
|  | ||||
| // TODO: These fields shouldn't be `pub` | ||||
| @@ -85,7 +81,7 @@ pub struct Iter { | ||||
|     fields: header::IntoIter<HeaderValue>, | ||||
| } | ||||
|  | ||||
| #[derive(PartialEq, Eq)] | ||||
| #[derive(Debug, PartialEq, Eq)] | ||||
| struct HeaderBlock { | ||||
|     /// The decoded header fields | ||||
|     fields: HeaderMap, | ||||
| @@ -95,6 +91,15 @@ struct HeaderBlock { | ||||
|     pseudo: Pseudo, | ||||
| } | ||||
|  | ||||
| #[derive(Debug)] | ||||
| struct EncodingHeaderBlock { | ||||
|     /// Argument to pass to the HPACK encoder to resume encoding | ||||
|     hpack: Option<hpack::EncodeState>, | ||||
|  | ||||
|     /// remaining headers to encode | ||||
|     headers: Iter, | ||||
| } | ||||
|  | ||||
| const END_STREAM: u8 = 0x1; | ||||
| const END_HEADERS: u8 = 0x4; | ||||
| const PADDED: u8 = 0x8; | ||||
| @@ -200,6 +205,10 @@ impl Headers { | ||||
|         self.flags.is_end_headers() | ||||
|     } | ||||
|  | ||||
|     pub fn set_end_headers(&mut self) { | ||||
|         self.flags.set_end_headers(); | ||||
|     } | ||||
|  | ||||
|     pub fn is_end_stream(&self) -> bool { | ||||
|         self.flags.is_end_stream() | ||||
|     } | ||||
| @@ -226,21 +235,15 @@ impl Headers { | ||||
|     } | ||||
|  | ||||
|     pub fn encode(self, encoder: &mut hpack::Encoder, dst: &mut BytesMut) -> Option<Continuation> { | ||||
|         // At this point, the `is_end_headers` flag should always be set | ||||
|         debug_assert!(self.flags.is_end_headers()); | ||||
|  | ||||
|         // Get the HEADERS frame head | ||||
|         let head = self.head(); | ||||
|         let pos = dst.len(); | ||||
|  | ||||
|         // At this point, we don't know how big the h2 frame will be. | ||||
|         // So, we write the head with length 0, then write the body, and | ||||
|         // finally write the length once we know the size. | ||||
|         head.encode(0, dst); | ||||
|  | ||||
|         // Encode the frame | ||||
|         let (len, cont) = self.header_block.encode(self.stream_id, encoder, dst); | ||||
|  | ||||
|         // Write the frame length | ||||
|         BigEndian::write_uint(&mut dst[pos..pos + 3], len, 3); | ||||
|  | ||||
|         cont | ||||
|         self.header_block.into_encoding() | ||||
|             .encode(&head, encoder, dst, |_| { | ||||
|             }) | ||||
|     } | ||||
|  | ||||
|     fn head(&self) -> Head { | ||||
| @@ -325,25 +328,23 @@ impl PushPromise { | ||||
|         self.flags.is_end_headers() | ||||
|     } | ||||
|  | ||||
|     pub fn set_end_headers(&mut self) { | ||||
|         self.flags.set_end_headers(); | ||||
|     } | ||||
|  | ||||
|     pub fn encode(self, encoder: &mut hpack::Encoder, dst: &mut BytesMut) -> Option<Continuation> { | ||||
|         use bytes::BufMut; | ||||
|  | ||||
|         // At this point, the `is_end_headers` flag should always be set | ||||
|         debug_assert!(self.flags.is_end_headers()); | ||||
|  | ||||
|         let head = self.head(); | ||||
|         let pos = dst.len(); | ||||
|         let promised_id = self.promised_id; | ||||
|  | ||||
|         // At this point, we don't know how big the h2 frame will be. | ||||
|         // So, we write the head with length 0, then write the body, and | ||||
|         // finally write the length once we know the size. | ||||
|         head.encode(0, dst); | ||||
|  | ||||
|         // Encode the frame | ||||
|         dst.put_u32::<BigEndian>(self.promised_id.into()); | ||||
|         let (len, cont) = self.header_block.encode(self.stream_id, encoder, dst); | ||||
|  | ||||
|         // Write the frame length | ||||
|         BigEndian::write_uint(&mut dst[pos..pos + 3], len + 4, 3); | ||||
|  | ||||
|         cont | ||||
|         self.header_block.into_encoding() | ||||
|             .encode(&head, encoder, dst, |dst| { | ||||
|                 dst.put_u32::<BigEndian>(promised_id.into()); | ||||
|             }) | ||||
|     } | ||||
|  | ||||
|     fn head(&self) -> Head { | ||||
| @@ -400,6 +401,23 @@ impl fmt::Debug for PushPromise { | ||||
|     } | ||||
| } | ||||
|  | ||||
| // ===== impl Continuation ===== | ||||
|  | ||||
| impl Continuation { | ||||
|     fn head(&self) -> Head { | ||||
|         Head::new(Kind::Continuation, END_HEADERS, self.stream_id) | ||||
|     } | ||||
|  | ||||
|     pub fn encode(self, encoder: &mut hpack::Encoder, dst: &mut BytesMut) -> Option<Continuation> { | ||||
|         // Get the CONTINUATION frame head | ||||
|         let head = self.head(); | ||||
|  | ||||
|         self.header_block | ||||
|             .encode(&head, encoder, dst, |_| { | ||||
|             }) | ||||
|     } | ||||
| } | ||||
|  | ||||
| // ===== impl Pseudo ===== | ||||
|  | ||||
| impl Pseudo { | ||||
| @@ -458,6 +476,58 @@ fn to_string(src: Bytes) -> String<Bytes> { | ||||
|     unsafe { String::from_utf8_unchecked(src) } | ||||
| } | ||||
|  | ||||
| // ===== impl EncodingHeaderBlock ===== | ||||
|  | ||||
| impl EncodingHeaderBlock { | ||||
|     fn encode<F>(mut self, | ||||
|                  head: &Head, | ||||
|                  encoder: &mut hpack::Encoder, | ||||
|                  dst: &mut BytesMut, | ||||
|                  f: F) | ||||
|         -> Option<Continuation> | ||||
|     where F: FnOnce(&mut BytesMut), | ||||
|     { | ||||
|         let head_pos = dst.len(); | ||||
|  | ||||
|         // At this point, we don't know how big the h2 frame will be. | ||||
|         // So, we write the head with length 0, then write the body, and | ||||
|         // finally write the length once we know the size. | ||||
|         head.encode(0, dst); | ||||
|  | ||||
|         let payload_pos = dst.len(); | ||||
|  | ||||
|         f(dst); | ||||
|  | ||||
|         // Now, encode the header payload | ||||
|         let continuation = match encoder.encode(self.hpack, &mut self.headers, dst) { | ||||
|             hpack::Encode::Full => None, | ||||
|             hpack::Encode::Partial(state) => Some(Continuation { | ||||
|                 stream_id: head.stream_id(), | ||||
|                 header_block: EncodingHeaderBlock { | ||||
|                     hpack: Some(state), | ||||
|                     headers: self.headers, | ||||
|                 }, | ||||
|             }), | ||||
|         }; | ||||
|  | ||||
|         // Compute the header block length | ||||
|         let payload_len = (dst.len() - payload_pos) as u64; | ||||
|  | ||||
|         // Write the frame length | ||||
|         BigEndian::write_uint(&mut dst[head_pos..head_pos + 3], payload_len, 3); | ||||
|  | ||||
|         if continuation.is_some() { | ||||
|             // There will be continuation frames, so the `is_end_headers` flag | ||||
|             // must be unset | ||||
|             debug_assert!(dst[head_pos + 4] & END_HEADERS == END_HEADERS); | ||||
|  | ||||
|             dst[head_pos + 4] -= END_HEADERS; | ||||
|         } | ||||
|  | ||||
|         continuation | ||||
|     } | ||||
| } | ||||
|  | ||||
| // ===== impl Iter ===== | ||||
|  | ||||
| impl Iterator for Iter { | ||||
| @@ -515,13 +585,17 @@ impl HeadersFlag { | ||||
|     } | ||||
|  | ||||
|     pub fn set_end_stream(&mut self) { | ||||
|         self.0 |= END_STREAM | ||||
|         self.0 |= END_STREAM; | ||||
|     } | ||||
|  | ||||
|     pub fn is_end_headers(&self) -> bool { | ||||
|         self.0 & END_HEADERS == END_HEADERS | ||||
|     } | ||||
|  | ||||
|     pub fn set_end_headers(&mut self) { | ||||
|         self.0 |= END_HEADERS; | ||||
|     } | ||||
|  | ||||
|     pub fn is_padded(&self) -> bool { | ||||
|         self.0 & PADDED == PADDED | ||||
|     } | ||||
| @@ -570,6 +644,10 @@ impl PushPromiseFlag { | ||||
|         self.0 & END_HEADERS == END_HEADERS | ||||
|     } | ||||
|  | ||||
|     pub fn set_end_headers(&mut self) { | ||||
|         self.0 |= END_HEADERS; | ||||
|     } | ||||
|  | ||||
|     pub fn is_padded(&self) -> bool { | ||||
|         self.0 & PADDED == PADDED | ||||
|     } | ||||
| @@ -624,7 +702,10 @@ impl HeaderBlock { | ||||
|         // contain the entire payload. Later, we need to check for stream | ||||
|         // priority. | ||||
|         // | ||||
|         // TODO: Provide a way to abort decoding if an error is hit. | ||||
|         // If the header frame is malformed, we still have to continue decoding | ||||
|         // the headers. A malformed header frame is a stream level error, but | ||||
|         // the hpack state is connection level. In order to maintain correct | ||||
|         // state for other streams, the hpack decoding process must complete. | ||||
|         let res = decoder.decode(&mut src, |header| { | ||||
|             use hpack::Header::*; | ||||
|  | ||||
| @@ -673,30 +754,13 @@ impl HeaderBlock { | ||||
|         Ok(()) | ||||
|     } | ||||
|  | ||||
|     fn encode( | ||||
|         self, | ||||
|         stream_id: StreamId, | ||||
|         encoder: &mut hpack::Encoder, | ||||
|         dst: &mut BytesMut, | ||||
|     ) -> (u64, Option<Continuation>) { | ||||
|         let pos = dst.len(); | ||||
|         let mut headers = Iter { | ||||
|             pseudo: Some(self.pseudo), | ||||
|             fields: self.fields.into_iter(), | ||||
|         }; | ||||
|  | ||||
|         let cont = match encoder.encode(None, &mut headers, dst) { | ||||
|             hpack::Encode::Full => None, | ||||
|             hpack::Encode::Partial(state) => Some(Continuation { | ||||
|                 stream_id: stream_id, | ||||
|                 hpack: state, | ||||
|                 headers: headers, | ||||
|             }), | ||||
|         }; | ||||
|  | ||||
|         // Compute the header block length | ||||
|         let len = (dst.len() - pos) as u64; | ||||
|  | ||||
|         (len, cont) | ||||
|     fn into_encoding(self) -> EncodingHeaderBlock { | ||||
|         EncodingHeaderBlock { | ||||
|             hpack: None, | ||||
|             headers: Iter { | ||||
|                 pseudo: Some(self.pseudo), | ||||
|                 fields: self.fields.into_iter(), | ||||
|             }, | ||||
|         } | ||||
|     } | ||||
| } | ||||
|   | ||||
		Reference in New Issue
	
	Block a user