Support writing continuation frames. (#198)
Large header sets might require being split up across multiple frames. This patch adds support for doing so.
This commit is contained in:
		| @@ -59,6 +59,7 @@ impl<T> FramedRead<T> { | |||||||
|         let head = frame::Head::parse(&bytes); |         let head = frame::Head::parse(&bytes); | ||||||
|  |  | ||||||
|         if self.partial.is_some() && head.kind() != Kind::Continuation { |         if self.partial.is_some() && head.kind() != Kind::Continuation { | ||||||
|  |             trace!("connection error PROTOCOL_ERROR -- expected CONTINUATION, got {:?}", head.kind()); | ||||||
|             return Err(Connection(Reason::PROTOCOL_ERROR)); |             return Err(Connection(Reason::PROTOCOL_ERROR)); | ||||||
|         } |         } | ||||||
|  |  | ||||||
| @@ -70,24 +71,36 @@ impl<T> FramedRead<T> { | |||||||
|             Kind::Settings => { |             Kind::Settings => { | ||||||
|                 let res = frame::Settings::load(head, &bytes[frame::HEADER_LEN..]); |                 let res = frame::Settings::load(head, &bytes[frame::HEADER_LEN..]); | ||||||
|  |  | ||||||
|                 res.map_err(|_| Connection(Reason::PROTOCOL_ERROR))?.into() |                 res.map_err(|e| { | ||||||
|  |                     debug!("connection error PROTOCOL_ERROR -- failed to load SETTINGS frame; err={:?}", e); | ||||||
|  |                     Connection(Reason::PROTOCOL_ERROR) | ||||||
|  |                 })?.into() | ||||||
|             }, |             }, | ||||||
|             Kind::Ping => { |             Kind::Ping => { | ||||||
|                 let res = frame::Ping::load(head, &bytes[frame::HEADER_LEN..]); |                 let res = frame::Ping::load(head, &bytes[frame::HEADER_LEN..]); | ||||||
|  |  | ||||||
|                 res.map_err(|_| Connection(Reason::PROTOCOL_ERROR))?.into() |                 res.map_err(|e| { | ||||||
|  |                     debug!("connection error PROTOCOL_ERROR -- failed to load PING frame; err={:?}", e); | ||||||
|  |                     Connection(Reason::PROTOCOL_ERROR) | ||||||
|  |                 })?.into() | ||||||
|             }, |             }, | ||||||
|             Kind::WindowUpdate => { |             Kind::WindowUpdate => { | ||||||
|                 let res = frame::WindowUpdate::load(head, &bytes[frame::HEADER_LEN..]); |                 let res = frame::WindowUpdate::load(head, &bytes[frame::HEADER_LEN..]); | ||||||
|  |  | ||||||
|                 res.map_err(|_| Connection(Reason::PROTOCOL_ERROR))?.into() |                 res.map_err(|e| { | ||||||
|  |                     debug!("connection error PROTOCOL_ERROR -- failed to load WINDOW_UPDATE frame; err={:?}", e); | ||||||
|  |                     Connection(Reason::PROTOCOL_ERROR) | ||||||
|  |                 })?.into() | ||||||
|             }, |             }, | ||||||
|             Kind::Data => { |             Kind::Data => { | ||||||
|                 let _ = bytes.split_to(frame::HEADER_LEN); |                 let _ = bytes.split_to(frame::HEADER_LEN); | ||||||
|                 let res = frame::Data::load(head, bytes.freeze()); |                 let res = frame::Data::load(head, bytes.freeze()); | ||||||
|  |  | ||||||
|                 // TODO: Should this always be connection level? Probably not... |                 // TODO: Should this always be connection level? Probably not... | ||||||
|                 res.map_err(|_| Connection(Reason::PROTOCOL_ERROR))?.into() |                 res.map_err(|e| { | ||||||
|  |                     debug!("connection error PROTOCOL_ERROR -- failed to load DATA frame; err={:?}", e); | ||||||
|  |                     Connection(Reason::PROTOCOL_ERROR) | ||||||
|  |                 })?.into() | ||||||
|             }, |             }, | ||||||
|             Kind::Headers => { |             Kind::Headers => { | ||||||
|                 // Drop the frame header |                 // Drop the frame header | ||||||
| @@ -101,12 +114,16 @@ impl<T> FramedRead<T> { | |||||||
|                         // A stream cannot depend on itself. An endpoint MUST |                         // A stream cannot depend on itself. An endpoint MUST | ||||||
|                         // treat this as a stream error (Section 5.4.2) of type |                         // treat this as a stream error (Section 5.4.2) of type | ||||||
|                         // `PROTOCOL_ERROR`. |                         // `PROTOCOL_ERROR`. | ||||||
|  |                         debug!("stream error PROTOCOL_ERROR -- invalid HEADERS dependency ID"); | ||||||
|                         return Err(Stream { |                         return Err(Stream { | ||||||
|                             id: head.stream_id(), |                             id: head.stream_id(), | ||||||
|                             reason: Reason::PROTOCOL_ERROR, |                             reason: Reason::PROTOCOL_ERROR, | ||||||
|                         }); |                         }); | ||||||
|                     }, |                     }, | ||||||
|                     _ => return Err(Connection(Reason::PROTOCOL_ERROR)), |                     Err(e) => { | ||||||
|  |                         debug!("connection error PROTOCOL_ERROR -- failed to load HEADERS frame; err={:?}", e); | ||||||
|  |                         return Err(Connection(Reason::PROTOCOL_ERROR)); | ||||||
|  |                     } | ||||||
|                 }; |                 }; | ||||||
|  |  | ||||||
|                 if headers.is_end_headers() { |                 if headers.is_end_headers() { | ||||||
| @@ -114,12 +131,16 @@ impl<T> FramedRead<T> { | |||||||
|                     match headers.load_hpack(payload, &mut self.hpack) { |                     match headers.load_hpack(payload, &mut self.hpack) { | ||||||
|                         Ok(_) => {}, |                         Ok(_) => {}, | ||||||
|                         Err(frame::Error::MalformedMessage) => { |                         Err(frame::Error::MalformedMessage) => { | ||||||
|  |                             debug!("stream error PROTOCOL_ERROR -- malformed HEADERS frame"); | ||||||
|                             return Err(Stream { |                             return Err(Stream { | ||||||
|                                 id: head.stream_id(), |                                 id: head.stream_id(), | ||||||
|                                 reason: Reason::PROTOCOL_ERROR, |                                 reason: Reason::PROTOCOL_ERROR, | ||||||
|                             }); |                             }); | ||||||
|                         }, |                         }, | ||||||
|                         Err(_) => return Err(Connection(Reason::PROTOCOL_ERROR)), |                         Err(e) => { | ||||||
|  |                             debug!("connection error PROTOCOL_ERROR -- failed HEADERS frame HPACK decoding; err={:?}", e); | ||||||
|  |                             return Err(Connection(Reason::PROTOCOL_ERROR)); | ||||||
|  |                         } | ||||||
|                     } |                     } | ||||||
|  |  | ||||||
|                     headers.into() |                     headers.into() | ||||||
| @@ -148,19 +169,26 @@ impl<T> FramedRead<T> { | |||||||
|  |  | ||||||
|                 // Parse the frame w/o parsing the payload |                 // Parse the frame w/o parsing the payload | ||||||
|                 let (mut push, payload) = frame::PushPromise::load(head, bytes) |                 let (mut push, payload) = frame::PushPromise::load(head, bytes) | ||||||
|                     .map_err(|_| Connection(Reason::PROTOCOL_ERROR))?; |                     .map_err(|e| { | ||||||
|  |                         debug!("connection error PROTOCOL_ERROR -- failed to load PUSH_PROMISE frame; err={:?}", e); | ||||||
|  |                         Connection(Reason::PROTOCOL_ERROR) | ||||||
|  |                     })?; | ||||||
|  |  | ||||||
|                 if push.is_end_headers() { |                 if push.is_end_headers() { | ||||||
|                     // Load the HPACK encoded headers & return the frame |                     // Load the HPACK encoded headers & return the frame | ||||||
|                     match push.load_hpack(payload, &mut self.hpack) { |                     match push.load_hpack(payload, &mut self.hpack) { | ||||||
|                         Ok(_) => {}, |                         Ok(_) => {}, | ||||||
|                         Err(frame::Error::MalformedMessage) => { |                         Err(frame::Error::MalformedMessage) => { | ||||||
|  |                             debug!("stream error PROTOCOL_ERROR -- malformed PUSH_PROMISE frame"); | ||||||
|                             return Err(Stream { |                             return Err(Stream { | ||||||
|                                 id: head.stream_id(), |                                 id: head.stream_id(), | ||||||
|                                 reason: Reason::PROTOCOL_ERROR, |                                 reason: Reason::PROTOCOL_ERROR, | ||||||
|                             }); |                             }); | ||||||
|                         }, |                         }, | ||||||
|                         Err(_) => return Err(Connection(Reason::PROTOCOL_ERROR)), |                         Err(e) => { | ||||||
|  |                             debug!("connection error PROTOCOL_ERROR -- failed PUSH_PROMISE frame HPACK decoding; err={:?}", e); | ||||||
|  |                             return Err(Connection(Reason::PROTOCOL_ERROR)); | ||||||
|  |                         } | ||||||
|                     } |                     } | ||||||
|  |  | ||||||
|                     push.into() |                     push.into() | ||||||
| @@ -186,6 +214,7 @@ impl<T> FramedRead<T> { | |||||||
|                         // A stream cannot depend on itself. An endpoint MUST |                         // A stream cannot depend on itself. An endpoint MUST | ||||||
|                         // treat this as a stream error (Section 5.4.2) of type |                         // treat this as a stream error (Section 5.4.2) of type | ||||||
|                         // `PROTOCOL_ERROR`. |                         // `PROTOCOL_ERROR`. | ||||||
|  |                         debug!("stream error PROTOCOL_ERROR -- PRIORITY invalid dependency ID"); | ||||||
|                         return Err(Stream { |                         return Err(Stream { | ||||||
|                             id: head.stream_id(), |                             id: head.stream_id(), | ||||||
|                             reason: Reason::PROTOCOL_ERROR, |                             reason: Reason::PROTOCOL_ERROR, | ||||||
| @@ -200,7 +229,10 @@ impl<T> FramedRead<T> { | |||||||
|  |  | ||||||
|                 let mut partial = match self.partial.take() { |                 let mut partial = match self.partial.take() { | ||||||
|                     Some(partial) => partial, |                     Some(partial) => partial, | ||||||
|                     None => return Err(Connection(Reason::PROTOCOL_ERROR)), |                     None => { | ||||||
|  |                         debug!("connection error PROTOCOL_ERROR -- received unexpected CONTINUATION frame"); | ||||||
|  |                         return Err(Connection(Reason::PROTOCOL_ERROR)); | ||||||
|  |                     } | ||||||
|                 }; |                 }; | ||||||
|  |  | ||||||
|                 // Extend the buf |                 // Extend the buf | ||||||
| @@ -213,12 +245,14 @@ impl<T> FramedRead<T> { | |||||||
|  |  | ||||||
|                 // The stream identifiers must match |                 // The stream identifiers must match | ||||||
|                 if partial.frame.stream_id() != head.stream_id() { |                 if partial.frame.stream_id() != head.stream_id() { | ||||||
|  |                     debug!("connection error PROTOCOL_ERROR -- CONTINUATION frame stream ID does not match previous frame stream ID"); | ||||||
|                     return Err(Connection(Reason::PROTOCOL_ERROR)); |                     return Err(Connection(Reason::PROTOCOL_ERROR)); | ||||||
|                 } |                 } | ||||||
|  |  | ||||||
|                 match partial.frame.load_hpack(partial.buf, &mut self.hpack) { |                 match partial.frame.load_hpack(partial.buf, &mut self.hpack) { | ||||||
|                     Ok(_) => {}, |                     Ok(_) => {}, | ||||||
|                     Err(frame::Error::MalformedMessage) => { |                     Err(frame::Error::MalformedMessage) => { | ||||||
|  |                         debug!("stream error PROTOCOL_ERROR -- malformed CONTINUATION frame"); | ||||||
|                         return Err(Stream { |                         return Err(Stream { | ||||||
|                             id: head.stream_id(), |                             id: head.stream_id(), | ||||||
|                             reason: Reason::PROTOCOL_ERROR, |                             reason: Reason::PROTOCOL_ERROR, | ||||||
| @@ -326,8 +360,14 @@ impl Continuable { | |||||||
| impl<T> From<Continuable> for Frame<T> { | impl<T> From<Continuable> for Frame<T> { | ||||||
|     fn from(cont: Continuable) -> Self { |     fn from(cont: Continuable) -> Self { | ||||||
|         match cont { |         match cont { | ||||||
|             Continuable::Headers(headers) => headers.into(), |             Continuable::Headers(mut headers) => { | ||||||
|             Continuable::PushPromise(push) => push.into(), |                 headers.set_end_headers(); | ||||||
|  |                 headers.into() | ||||||
|  |             } | ||||||
|  |             Continuable::PushPromise(mut push) => { | ||||||
|  |                 push.set_end_headers(); | ||||||
|  |                 push.into() | ||||||
|  |             } | ||||||
|         } |         } | ||||||
|     } |     } | ||||||
| } | } | ||||||
|   | |||||||
| @@ -168,37 +168,48 @@ where | |||||||
|     pub fn flush(&mut self) -> Poll<(), io::Error> { |     pub fn flush(&mut self) -> Poll<(), io::Error> { | ||||||
|         trace!("flush"); |         trace!("flush"); | ||||||
|  |  | ||||||
|  |         loop { | ||||||
|             while !self.is_empty() { |             while !self.is_empty() { | ||||||
|                 match self.next { |                 match self.next { | ||||||
|                     Some(Next::Data(ref mut frame)) => { |                     Some(Next::Data(ref mut frame)) => { | ||||||
|  |                         trace!("  -> queued data frame"); | ||||||
|                         let mut buf = Buf::by_ref(&mut self.buf).chain(frame.payload_mut()); |                         let mut buf = Buf::by_ref(&mut self.buf).chain(frame.payload_mut()); | ||||||
|                         try_ready!(self.inner.write_buf(&mut buf)); |                         try_ready!(self.inner.write_buf(&mut buf)); | ||||||
|                     }, |                     }, | ||||||
|                     _ => { |                     _ => { | ||||||
|  |                         trace!("  -> not a queued data frame"); | ||||||
|                         try_ready!(self.inner.write_buf(&mut self.buf)); |                         try_ready!(self.inner.write_buf(&mut self.buf)); | ||||||
|                     }, |                     }, | ||||||
|                 } |                 } | ||||||
|             } |             } | ||||||
|  |  | ||||||
|  |             // Clear internal buffer | ||||||
|  |             self.buf.set_position(0); | ||||||
|  |             self.buf.get_mut().clear(); | ||||||
|  |  | ||||||
|             // The data frame has been written, so unset it |             // The data frame has been written, so unset it | ||||||
|             match self.next.take() { |             match self.next.take() { | ||||||
|                 Some(Next::Data(frame)) => { |                 Some(Next::Data(frame)) => { | ||||||
|                     self.last_data_frame = Some(frame); |                     self.last_data_frame = Some(frame); | ||||||
|  |                     debug_assert!(self.is_empty()); | ||||||
|  |                     break; | ||||||
|                 }, |                 }, | ||||||
|             Some(Next::Continuation(_)) => { |                 Some(Next::Continuation(frame)) => { | ||||||
|                 unimplemented!(); |                     // Buffer the continuation frame, then try to write again | ||||||
|  |                     if let Some(continuation) = frame.encode(&mut self.hpack, self.buf.get_mut()) { | ||||||
|  |                         self.next = Some(Next::Continuation(continuation)); | ||||||
|  |                     } | ||||||
|                 }, |                 }, | ||||||
|             None => {}, |                 None => { | ||||||
|  |                     break; | ||||||
|  |                 } | ||||||
|  |             } | ||||||
|         } |         } | ||||||
|  |  | ||||||
|         trace!("flushing buffer"); |         trace!("flushing buffer"); | ||||||
|         // Flush the upstream |         // Flush the upstream | ||||||
|         try_nb!(self.inner.flush()); |         try_nb!(self.inner.flush()); | ||||||
|  |  | ||||||
|         // Clear internal buffer |  | ||||||
|         self.buf.set_position(0); |  | ||||||
|         self.buf.get_mut().clear(); |  | ||||||
|  |  | ||||||
|         Ok(Async::Ready(())) |         Ok(Async::Ready(())) | ||||||
|     } |     } | ||||||
|  |  | ||||||
|   | |||||||
| @@ -56,11 +56,7 @@ pub struct Continuation { | |||||||
|     /// Stream ID of continuation frame |     /// Stream ID of continuation frame | ||||||
|     stream_id: StreamId, |     stream_id: StreamId, | ||||||
|  |  | ||||||
|     /// Argument to pass to the HPACK encoder to resume encoding |     header_block: EncodingHeaderBlock, | ||||||
|     hpack: hpack::EncodeState, |  | ||||||
|  |  | ||||||
|     /// remaining headers to encode |  | ||||||
|     headers: Iter, |  | ||||||
| } | } | ||||||
|  |  | ||||||
| // TODO: These fields shouldn't be `pub` | // TODO: These fields shouldn't be `pub` | ||||||
| @@ -85,7 +81,7 @@ pub struct Iter { | |||||||
|     fields: header::IntoIter<HeaderValue>, |     fields: header::IntoIter<HeaderValue>, | ||||||
| } | } | ||||||
|  |  | ||||||
| #[derive(PartialEq, Eq)] | #[derive(Debug, PartialEq, Eq)] | ||||||
| struct HeaderBlock { | struct HeaderBlock { | ||||||
|     /// The decoded header fields |     /// The decoded header fields | ||||||
|     fields: HeaderMap, |     fields: HeaderMap, | ||||||
| @@ -95,6 +91,15 @@ struct HeaderBlock { | |||||||
|     pseudo: Pseudo, |     pseudo: Pseudo, | ||||||
| } | } | ||||||
|  |  | ||||||
|  | #[derive(Debug)] | ||||||
|  | struct EncodingHeaderBlock { | ||||||
|  |     /// Argument to pass to the HPACK encoder to resume encoding | ||||||
|  |     hpack: Option<hpack::EncodeState>, | ||||||
|  |  | ||||||
|  |     /// remaining headers to encode | ||||||
|  |     headers: Iter, | ||||||
|  | } | ||||||
|  |  | ||||||
| const END_STREAM: u8 = 0x1; | const END_STREAM: u8 = 0x1; | ||||||
| const END_HEADERS: u8 = 0x4; | const END_HEADERS: u8 = 0x4; | ||||||
| const PADDED: u8 = 0x8; | const PADDED: u8 = 0x8; | ||||||
| @@ -200,6 +205,10 @@ impl Headers { | |||||||
|         self.flags.is_end_headers() |         self.flags.is_end_headers() | ||||||
|     } |     } | ||||||
|  |  | ||||||
|  |     pub fn set_end_headers(&mut self) { | ||||||
|  |         self.flags.set_end_headers(); | ||||||
|  |     } | ||||||
|  |  | ||||||
|     pub fn is_end_stream(&self) -> bool { |     pub fn is_end_stream(&self) -> bool { | ||||||
|         self.flags.is_end_stream() |         self.flags.is_end_stream() | ||||||
|     } |     } | ||||||
| @@ -226,21 +235,15 @@ impl Headers { | |||||||
|     } |     } | ||||||
|  |  | ||||||
|     pub fn encode(self, encoder: &mut hpack::Encoder, dst: &mut BytesMut) -> Option<Continuation> { |     pub fn encode(self, encoder: &mut hpack::Encoder, dst: &mut BytesMut) -> Option<Continuation> { | ||||||
|  |         // At this point, the `is_end_headers` flag should always be set | ||||||
|  |         debug_assert!(self.flags.is_end_headers()); | ||||||
|  |  | ||||||
|  |         // Get the HEADERS frame head | ||||||
|         let head = self.head(); |         let head = self.head(); | ||||||
|         let pos = dst.len(); |  | ||||||
|  |  | ||||||
|         // At this point, we don't know how big the h2 frame will be. |         self.header_block.into_encoding() | ||||||
|         // So, we write the head with length 0, then write the body, and |             .encode(&head, encoder, dst, |_| { | ||||||
|         // finally write the length once we know the size. |             }) | ||||||
|         head.encode(0, dst); |  | ||||||
|  |  | ||||||
|         // Encode the frame |  | ||||||
|         let (len, cont) = self.header_block.encode(self.stream_id, encoder, dst); |  | ||||||
|  |  | ||||||
|         // Write the frame length |  | ||||||
|         BigEndian::write_uint(&mut dst[pos..pos + 3], len, 3); |  | ||||||
|  |  | ||||||
|         cont |  | ||||||
|     } |     } | ||||||
|  |  | ||||||
|     fn head(&self) -> Head { |     fn head(&self) -> Head { | ||||||
| @@ -325,25 +328,23 @@ impl PushPromise { | |||||||
|         self.flags.is_end_headers() |         self.flags.is_end_headers() | ||||||
|     } |     } | ||||||
|  |  | ||||||
|  |     pub fn set_end_headers(&mut self) { | ||||||
|  |         self.flags.set_end_headers(); | ||||||
|  |     } | ||||||
|  |  | ||||||
|     pub fn encode(self, encoder: &mut hpack::Encoder, dst: &mut BytesMut) -> Option<Continuation> { |     pub fn encode(self, encoder: &mut hpack::Encoder, dst: &mut BytesMut) -> Option<Continuation> { | ||||||
|         use bytes::BufMut; |         use bytes::BufMut; | ||||||
|  |  | ||||||
|  |         // At this point, the `is_end_headers` flag should always be set | ||||||
|  |         debug_assert!(self.flags.is_end_headers()); | ||||||
|  |  | ||||||
|         let head = self.head(); |         let head = self.head(); | ||||||
|         let pos = dst.len(); |         let promised_id = self.promised_id; | ||||||
|  |  | ||||||
|         // At this point, we don't know how big the h2 frame will be. |         self.header_block.into_encoding() | ||||||
|         // So, we write the head with length 0, then write the body, and |             .encode(&head, encoder, dst, |dst| { | ||||||
|         // finally write the length once we know the size. |                 dst.put_u32::<BigEndian>(promised_id.into()); | ||||||
|         head.encode(0, dst); |             }) | ||||||
|  |  | ||||||
|         // Encode the frame |  | ||||||
|         dst.put_u32::<BigEndian>(self.promised_id.into()); |  | ||||||
|         let (len, cont) = self.header_block.encode(self.stream_id, encoder, dst); |  | ||||||
|  |  | ||||||
|         // Write the frame length |  | ||||||
|         BigEndian::write_uint(&mut dst[pos..pos + 3], len + 4, 3); |  | ||||||
|  |  | ||||||
|         cont |  | ||||||
|     } |     } | ||||||
|  |  | ||||||
|     fn head(&self) -> Head { |     fn head(&self) -> Head { | ||||||
| @@ -400,6 +401,23 @@ impl fmt::Debug for PushPromise { | |||||||
|     } |     } | ||||||
| } | } | ||||||
|  |  | ||||||
|  | // ===== impl Continuation ===== | ||||||
|  |  | ||||||
|  | impl Continuation { | ||||||
|  |     fn head(&self) -> Head { | ||||||
|  |         Head::new(Kind::Continuation, END_HEADERS, self.stream_id) | ||||||
|  |     } | ||||||
|  |  | ||||||
|  |     pub fn encode(self, encoder: &mut hpack::Encoder, dst: &mut BytesMut) -> Option<Continuation> { | ||||||
|  |         // Get the CONTINUATION frame head | ||||||
|  |         let head = self.head(); | ||||||
|  |  | ||||||
|  |         self.header_block | ||||||
|  |             .encode(&head, encoder, dst, |_| { | ||||||
|  |             }) | ||||||
|  |     } | ||||||
|  | } | ||||||
|  |  | ||||||
| // ===== impl Pseudo ===== | // ===== impl Pseudo ===== | ||||||
|  |  | ||||||
| impl Pseudo { | impl Pseudo { | ||||||
| @@ -458,6 +476,58 @@ fn to_string(src: Bytes) -> String<Bytes> { | |||||||
|     unsafe { String::from_utf8_unchecked(src) } |     unsafe { String::from_utf8_unchecked(src) } | ||||||
| } | } | ||||||
|  |  | ||||||
|  | // ===== impl EncodingHeaderBlock ===== | ||||||
|  |  | ||||||
|  | impl EncodingHeaderBlock { | ||||||
|  |     fn encode<F>(mut self, | ||||||
|  |                  head: &Head, | ||||||
|  |                  encoder: &mut hpack::Encoder, | ||||||
|  |                  dst: &mut BytesMut, | ||||||
|  |                  f: F) | ||||||
|  |         -> Option<Continuation> | ||||||
|  |     where F: FnOnce(&mut BytesMut), | ||||||
|  |     { | ||||||
|  |         let head_pos = dst.len(); | ||||||
|  |  | ||||||
|  |         // At this point, we don't know how big the h2 frame will be. | ||||||
|  |         // So, we write the head with length 0, then write the body, and | ||||||
|  |         // finally write the length once we know the size. | ||||||
|  |         head.encode(0, dst); | ||||||
|  |  | ||||||
|  |         let payload_pos = dst.len(); | ||||||
|  |  | ||||||
|  |         f(dst); | ||||||
|  |  | ||||||
|  |         // Now, encode the header payload | ||||||
|  |         let continuation = match encoder.encode(self.hpack, &mut self.headers, dst) { | ||||||
|  |             hpack::Encode::Full => None, | ||||||
|  |             hpack::Encode::Partial(state) => Some(Continuation { | ||||||
|  |                 stream_id: head.stream_id(), | ||||||
|  |                 header_block: EncodingHeaderBlock { | ||||||
|  |                     hpack: Some(state), | ||||||
|  |                     headers: self.headers, | ||||||
|  |                 }, | ||||||
|  |             }), | ||||||
|  |         }; | ||||||
|  |  | ||||||
|  |         // Compute the header block length | ||||||
|  |         let payload_len = (dst.len() - payload_pos) as u64; | ||||||
|  |  | ||||||
|  |         // Write the frame length | ||||||
|  |         BigEndian::write_uint(&mut dst[head_pos..head_pos + 3], payload_len, 3); | ||||||
|  |  | ||||||
|  |         if continuation.is_some() { | ||||||
|  |             // There will be continuation frames, so the `is_end_headers` flag | ||||||
|  |             // must be unset | ||||||
|  |             debug_assert!(dst[head_pos + 4] & END_HEADERS == END_HEADERS); | ||||||
|  |  | ||||||
|  |             dst[head_pos + 4] -= END_HEADERS; | ||||||
|  |         } | ||||||
|  |  | ||||||
|  |         continuation | ||||||
|  |     } | ||||||
|  | } | ||||||
|  |  | ||||||
| // ===== impl Iter ===== | // ===== impl Iter ===== | ||||||
|  |  | ||||||
| impl Iterator for Iter { | impl Iterator for Iter { | ||||||
| @@ -515,13 +585,17 @@ impl HeadersFlag { | |||||||
|     } |     } | ||||||
|  |  | ||||||
|     pub fn set_end_stream(&mut self) { |     pub fn set_end_stream(&mut self) { | ||||||
|         self.0 |= END_STREAM |         self.0 |= END_STREAM; | ||||||
|     } |     } | ||||||
|  |  | ||||||
|     pub fn is_end_headers(&self) -> bool { |     pub fn is_end_headers(&self) -> bool { | ||||||
|         self.0 & END_HEADERS == END_HEADERS |         self.0 & END_HEADERS == END_HEADERS | ||||||
|     } |     } | ||||||
|  |  | ||||||
|  |     pub fn set_end_headers(&mut self) { | ||||||
|  |         self.0 |= END_HEADERS; | ||||||
|  |     } | ||||||
|  |  | ||||||
|     pub fn is_padded(&self) -> bool { |     pub fn is_padded(&self) -> bool { | ||||||
|         self.0 & PADDED == PADDED |         self.0 & PADDED == PADDED | ||||||
|     } |     } | ||||||
| @@ -570,6 +644,10 @@ impl PushPromiseFlag { | |||||||
|         self.0 & END_HEADERS == END_HEADERS |         self.0 & END_HEADERS == END_HEADERS | ||||||
|     } |     } | ||||||
|  |  | ||||||
|  |     pub fn set_end_headers(&mut self) { | ||||||
|  |         self.0 |= END_HEADERS; | ||||||
|  |     } | ||||||
|  |  | ||||||
|     pub fn is_padded(&self) -> bool { |     pub fn is_padded(&self) -> bool { | ||||||
|         self.0 & PADDED == PADDED |         self.0 & PADDED == PADDED | ||||||
|     } |     } | ||||||
| @@ -624,7 +702,10 @@ impl HeaderBlock { | |||||||
|         // contain the entire payload. Later, we need to check for stream |         // contain the entire payload. Later, we need to check for stream | ||||||
|         // priority. |         // priority. | ||||||
|         // |         // | ||||||
|         // TODO: Provide a way to abort decoding if an error is hit. |         // If the header frame is malformed, we still have to continue decoding | ||||||
|  |         // the headers. A malformed header frame is a stream level error, but | ||||||
|  |         // the hpack state is connection level. In order to maintain correct | ||||||
|  |         // state for other streams, the hpack decoding process must complete. | ||||||
|         let res = decoder.decode(&mut src, |header| { |         let res = decoder.decode(&mut src, |header| { | ||||||
|             use hpack::Header::*; |             use hpack::Header::*; | ||||||
|  |  | ||||||
| @@ -673,30 +754,13 @@ impl HeaderBlock { | |||||||
|         Ok(()) |         Ok(()) | ||||||
|     } |     } | ||||||
|  |  | ||||||
|     fn encode( |     fn into_encoding(self) -> EncodingHeaderBlock { | ||||||
|         self, |         EncodingHeaderBlock { | ||||||
|         stream_id: StreamId, |             hpack: None, | ||||||
|         encoder: &mut hpack::Encoder, |             headers: Iter { | ||||||
|         dst: &mut BytesMut, |  | ||||||
|     ) -> (u64, Option<Continuation>) { |  | ||||||
|         let pos = dst.len(); |  | ||||||
|         let mut headers = Iter { |  | ||||||
|                 pseudo: Some(self.pseudo), |                 pseudo: Some(self.pseudo), | ||||||
|                 fields: self.fields.into_iter(), |                 fields: self.fields.into_iter(), | ||||||
|         }; |             }, | ||||||
|  |         } | ||||||
|         let cont = match encoder.encode(None, &mut headers, dst) { |  | ||||||
|             hpack::Encode::Full => None, |  | ||||||
|             hpack::Encode::Partial(state) => Some(Continuation { |  | ||||||
|                 stream_id: stream_id, |  | ||||||
|                 hpack: state, |  | ||||||
|                 headers: headers, |  | ||||||
|             }), |  | ||||||
|         }; |  | ||||||
|  |  | ||||||
|         // Compute the header block length |  | ||||||
|         let len = (dst.len() - pos) as u64; |  | ||||||
|  |  | ||||||
|         (len, cont) |  | ||||||
|     } |     } | ||||||
| } | } | ||||||
|   | |||||||
| @@ -1 +1,86 @@ | |||||||
|  | #[macro_use] | ||||||
|  | pub mod support; | ||||||
|  | use support::prelude::*; | ||||||
|  |  | ||||||
|  | #[test] | ||||||
|  | fn write_continuation_frames() { | ||||||
|  |     // An invalid dependency ID results in a stream level error. The hpack | ||||||
|  |     // payload should still be decoded. | ||||||
|  |     let _ = ::env_logger::init(); | ||||||
|  |     let (io, srv) = mock::new(); | ||||||
|  |  | ||||||
|  |     let large = build_large_headers(); | ||||||
|  |  | ||||||
|  |     // Build the large request frame | ||||||
|  |     let frame = large.iter().fold( | ||||||
|  |         frames::headers(1).request("GET", "https://http2.akamai.com/"), | ||||||
|  |         |frame, &(name, ref value)| frame.field(name, &value[..])); | ||||||
|  |  | ||||||
|  |     let srv = srv.assert_client_handshake() | ||||||
|  |         .unwrap() | ||||||
|  |         .recv_settings() | ||||||
|  |         .recv_frame(frame.eos()) | ||||||
|  |         .send_frame( | ||||||
|  |             frames::headers(1) | ||||||
|  |                 .response(204) | ||||||
|  |                 .eos(), | ||||||
|  |         ) | ||||||
|  |         .close(); | ||||||
|  |  | ||||||
|  |     let client = Client::handshake(io) | ||||||
|  |         .expect("handshake") | ||||||
|  |         .and_then(|(mut client, conn)| { | ||||||
|  |             let mut request = Request::builder(); | ||||||
|  |             request.uri("https://http2.akamai.com/"); | ||||||
|  |  | ||||||
|  |             for &(name, ref value) in &large { | ||||||
|  |                 request.header(name, &value[..]); | ||||||
|  |             } | ||||||
|  |  | ||||||
|  |             let request = request | ||||||
|  |                 .body(()) | ||||||
|  |                 .unwrap(); | ||||||
|  |  | ||||||
|  |             let req = client | ||||||
|  |                 .send_request(request, true) | ||||||
|  |                 .expect("send_request1") | ||||||
|  |                 .0 | ||||||
|  |                 .then(|res| { | ||||||
|  |                     let response = res.unwrap(); | ||||||
|  |                     assert_eq!(response.status(), StatusCode::NO_CONTENT); | ||||||
|  |                     Ok::<_, ()>(()) | ||||||
|  |                 }); | ||||||
|  |  | ||||||
|  |             conn.drive(req) | ||||||
|  |                 .and_then(move |(h2, _)| { | ||||||
|  |                     h2.unwrap() | ||||||
|  |                 }) | ||||||
|  |         }); | ||||||
|  |  | ||||||
|  |     client.join(srv).wait().expect("wait"); | ||||||
|  | } | ||||||
|  |  | ||||||
|  | fn build_large_headers() -> Vec<(&'static str, String)> { | ||||||
|  |     vec![ | ||||||
|  |         ("one", "hello".to_string()), | ||||||
|  |         ("two", build_large_string('2', 4 * 1024)), | ||||||
|  |         ("three", "three".to_string()), | ||||||
|  |         ("four", build_large_string('4', 4 * 1024)), | ||||||
|  |         ("five", "five".to_string()), | ||||||
|  |         ("six", build_large_string('6', 4 * 1024)), | ||||||
|  |         ("seven", "seven".to_string()), | ||||||
|  |         ("eight", build_large_string('8', 4 * 1024)), | ||||||
|  |         ("nine", "nine".to_string()), | ||||||
|  |         ("ten", build_large_string('0', 4 * 1024)), | ||||||
|  |     ] | ||||||
|  | } | ||||||
|  |  | ||||||
|  | fn build_large_string(ch: char, len: usize) -> String { | ||||||
|  |     let mut ret = String::new(); | ||||||
|  |  | ||||||
|  |     for _ in 0..len { | ||||||
|  |         ret.push(ch); | ||||||
|  |     } | ||||||
|  |  | ||||||
|  |     ret | ||||||
|  | } | ||||||
|   | |||||||
		Reference in New Issue
	
	Block a user