diff --git a/src/codec/framed_read.rs b/src/codec/framed_read.rs index 5f07d1c..29a590e 100644 --- a/src/codec/framed_read.rs +++ b/src/codec/framed_read.rs @@ -59,6 +59,7 @@ impl FramedRead { let head = frame::Head::parse(&bytes); if self.partial.is_some() && head.kind() != Kind::Continuation { + trace!("connection error PROTOCOL_ERROR -- expected CONTINUATION, got {:?}", head.kind()); return Err(Connection(Reason::PROTOCOL_ERROR)); } @@ -70,24 +71,36 @@ impl FramedRead { Kind::Settings => { let res = frame::Settings::load(head, &bytes[frame::HEADER_LEN..]); - res.map_err(|_| Connection(Reason::PROTOCOL_ERROR))?.into() + res.map_err(|e| { + debug!("connection error PROTOCOL_ERROR -- failed to load SETTINGS frame; err={:?}", e); + Connection(Reason::PROTOCOL_ERROR) + })?.into() }, Kind::Ping => { let res = frame::Ping::load(head, &bytes[frame::HEADER_LEN..]); - res.map_err(|_| Connection(Reason::PROTOCOL_ERROR))?.into() + res.map_err(|e| { + debug!("connection error PROTOCOL_ERROR -- failed to load PING frame; err={:?}", e); + Connection(Reason::PROTOCOL_ERROR) + })?.into() }, Kind::WindowUpdate => { let res = frame::WindowUpdate::load(head, &bytes[frame::HEADER_LEN..]); - res.map_err(|_| Connection(Reason::PROTOCOL_ERROR))?.into() + res.map_err(|e| { + debug!("connection error PROTOCOL_ERROR -- failed to load WINDOW_UPDATE frame; err={:?}", e); + Connection(Reason::PROTOCOL_ERROR) + })?.into() }, Kind::Data => { let _ = bytes.split_to(frame::HEADER_LEN); let res = frame::Data::load(head, bytes.freeze()); // TODO: Should this always be connection level? Probably not... - res.map_err(|_| Connection(Reason::PROTOCOL_ERROR))?.into() + res.map_err(|e| { + debug!("connection error PROTOCOL_ERROR -- failed to load DATA frame; err={:?}", e); + Connection(Reason::PROTOCOL_ERROR) + })?.into() }, Kind::Headers => { // Drop the frame header @@ -101,12 +114,16 @@ impl FramedRead { // A stream cannot depend on itself. An endpoint MUST // treat this as a stream error (Section 5.4.2) of type // `PROTOCOL_ERROR`. + debug!("stream error PROTOCOL_ERROR -- invalid HEADERS dependency ID"); return Err(Stream { id: head.stream_id(), reason: Reason::PROTOCOL_ERROR, }); }, - _ => return Err(Connection(Reason::PROTOCOL_ERROR)), + Err(e) => { + debug!("connection error PROTOCOL_ERROR -- failed to load HEADERS frame; err={:?}", e); + return Err(Connection(Reason::PROTOCOL_ERROR)); + } }; if headers.is_end_headers() { @@ -114,12 +131,16 @@ impl FramedRead { match headers.load_hpack(payload, &mut self.hpack) { Ok(_) => {}, Err(frame::Error::MalformedMessage) => { + debug!("stream error PROTOCOL_ERROR -- malformed HEADERS frame"); return Err(Stream { id: head.stream_id(), reason: Reason::PROTOCOL_ERROR, }); }, - Err(_) => return Err(Connection(Reason::PROTOCOL_ERROR)), + Err(e) => { + debug!("connection error PROTOCOL_ERROR -- failed HEADERS frame HPACK decoding; err={:?}", e); + return Err(Connection(Reason::PROTOCOL_ERROR)); + } } headers.into() @@ -148,19 +169,26 @@ impl FramedRead { // Parse the frame w/o parsing the payload let (mut push, payload) = frame::PushPromise::load(head, bytes) - .map_err(|_| Connection(Reason::PROTOCOL_ERROR))?; + .map_err(|e| { + debug!("connection error PROTOCOL_ERROR -- failed to load PUSH_PROMISE frame; err={:?}", e); + Connection(Reason::PROTOCOL_ERROR) + })?; if push.is_end_headers() { // Load the HPACK encoded headers & return the frame match push.load_hpack(payload, &mut self.hpack) { Ok(_) => {}, Err(frame::Error::MalformedMessage) => { + debug!("stream error PROTOCOL_ERROR -- malformed PUSH_PROMISE frame"); return Err(Stream { id: head.stream_id(), reason: Reason::PROTOCOL_ERROR, }); }, - Err(_) => return Err(Connection(Reason::PROTOCOL_ERROR)), + Err(e) => { + debug!("connection error PROTOCOL_ERROR -- failed PUSH_PROMISE frame HPACK decoding; err={:?}", e); + return Err(Connection(Reason::PROTOCOL_ERROR)); + } } push.into() @@ -186,6 +214,7 @@ impl FramedRead { // A stream cannot depend on itself. An endpoint MUST // treat this as a stream error (Section 5.4.2) of type // `PROTOCOL_ERROR`. + debug!("stream error PROTOCOL_ERROR -- PRIORITY invalid dependency ID"); return Err(Stream { id: head.stream_id(), reason: Reason::PROTOCOL_ERROR, @@ -200,7 +229,10 @@ impl FramedRead { let mut partial = match self.partial.take() { Some(partial) => partial, - None => return Err(Connection(Reason::PROTOCOL_ERROR)), + None => { + debug!("connection error PROTOCOL_ERROR -- received unexpected CONTINUATION frame"); + return Err(Connection(Reason::PROTOCOL_ERROR)); + } }; // Extend the buf @@ -213,12 +245,14 @@ impl FramedRead { // The stream identifiers must match if partial.frame.stream_id() != head.stream_id() { + debug!("connection error PROTOCOL_ERROR -- CONTINUATION frame stream ID does not match previous frame stream ID"); return Err(Connection(Reason::PROTOCOL_ERROR)); } match partial.frame.load_hpack(partial.buf, &mut self.hpack) { Ok(_) => {}, Err(frame::Error::MalformedMessage) => { + debug!("stream error PROTOCOL_ERROR -- malformed CONTINUATION frame"); return Err(Stream { id: head.stream_id(), reason: Reason::PROTOCOL_ERROR, @@ -326,8 +360,14 @@ impl Continuable { impl From for Frame { fn from(cont: Continuable) -> Self { match cont { - Continuable::Headers(headers) => headers.into(), - Continuable::PushPromise(push) => push.into(), + Continuable::Headers(mut headers) => { + headers.set_end_headers(); + headers.into() + } + Continuable::PushPromise(mut push) => { + push.set_end_headers(); + push.into() + } } } } diff --git a/src/codec/framed_write.rs b/src/codec/framed_write.rs index dbd867f..504e0c5 100644 --- a/src/codec/framed_write.rs +++ b/src/codec/framed_write.rs @@ -168,37 +168,48 @@ where pub fn flush(&mut self) -> Poll<(), io::Error> { trace!("flush"); - while !self.is_empty() { - match self.next { - Some(Next::Data(ref mut frame)) => { - let mut buf = Buf::by_ref(&mut self.buf).chain(frame.payload_mut()); - try_ready!(self.inner.write_buf(&mut buf)); - }, - _ => { - try_ready!(self.inner.write_buf(&mut self.buf)); - }, + loop { + while !self.is_empty() { + match self.next { + Some(Next::Data(ref mut frame)) => { + trace!(" -> queued data frame"); + let mut buf = Buf::by_ref(&mut self.buf).chain(frame.payload_mut()); + try_ready!(self.inner.write_buf(&mut buf)); + }, + _ => { + trace!(" -> not a queued data frame"); + try_ready!(self.inner.write_buf(&mut self.buf)); + }, + } } - } - // The data frame has been written, so unset it - match self.next.take() { - Some(Next::Data(frame)) => { - self.last_data_frame = Some(frame); - }, - Some(Next::Continuation(_)) => { - unimplemented!(); - }, - None => {}, + // Clear internal buffer + self.buf.set_position(0); + self.buf.get_mut().clear(); + + // The data frame has been written, so unset it + match self.next.take() { + Some(Next::Data(frame)) => { + self.last_data_frame = Some(frame); + debug_assert!(self.is_empty()); + break; + }, + Some(Next::Continuation(frame)) => { + // Buffer the continuation frame, then try to write again + if let Some(continuation) = frame.encode(&mut self.hpack, self.buf.get_mut()) { + self.next = Some(Next::Continuation(continuation)); + } + }, + None => { + break; + } + } } trace!("flushing buffer"); // Flush the upstream try_nb!(self.inner.flush()); - // Clear internal buffer - self.buf.set_position(0); - self.buf.get_mut().clear(); - Ok(Async::Ready(())) } diff --git a/src/frame/headers.rs b/src/frame/headers.rs index 06cb53f..4c98442 100644 --- a/src/frame/headers.rs +++ b/src/frame/headers.rs @@ -56,11 +56,7 @@ pub struct Continuation { /// Stream ID of continuation frame stream_id: StreamId, - /// Argument to pass to the HPACK encoder to resume encoding - hpack: hpack::EncodeState, - - /// remaining headers to encode - headers: Iter, + header_block: EncodingHeaderBlock, } // TODO: These fields shouldn't be `pub` @@ -85,7 +81,7 @@ pub struct Iter { fields: header::IntoIter, } -#[derive(PartialEq, Eq)] +#[derive(Debug, PartialEq, Eq)] struct HeaderBlock { /// The decoded header fields fields: HeaderMap, @@ -95,6 +91,15 @@ struct HeaderBlock { pseudo: Pseudo, } +#[derive(Debug)] +struct EncodingHeaderBlock { + /// Argument to pass to the HPACK encoder to resume encoding + hpack: Option, + + /// remaining headers to encode + headers: Iter, +} + const END_STREAM: u8 = 0x1; const END_HEADERS: u8 = 0x4; const PADDED: u8 = 0x8; @@ -200,6 +205,10 @@ impl Headers { self.flags.is_end_headers() } + pub fn set_end_headers(&mut self) { + self.flags.set_end_headers(); + } + pub fn is_end_stream(&self) -> bool { self.flags.is_end_stream() } @@ -226,21 +235,15 @@ impl Headers { } pub fn encode(self, encoder: &mut hpack::Encoder, dst: &mut BytesMut) -> Option { + // At this point, the `is_end_headers` flag should always be set + debug_assert!(self.flags.is_end_headers()); + + // Get the HEADERS frame head let head = self.head(); - let pos = dst.len(); - // At this point, we don't know how big the h2 frame will be. - // So, we write the head with length 0, then write the body, and - // finally write the length once we know the size. - head.encode(0, dst); - - // Encode the frame - let (len, cont) = self.header_block.encode(self.stream_id, encoder, dst); - - // Write the frame length - BigEndian::write_uint(&mut dst[pos..pos + 3], len, 3); - - cont + self.header_block.into_encoding() + .encode(&head, encoder, dst, |_| { + }) } fn head(&self) -> Head { @@ -325,25 +328,23 @@ impl PushPromise { self.flags.is_end_headers() } + pub fn set_end_headers(&mut self) { + self.flags.set_end_headers(); + } + pub fn encode(self, encoder: &mut hpack::Encoder, dst: &mut BytesMut) -> Option { use bytes::BufMut; + // At this point, the `is_end_headers` flag should always be set + debug_assert!(self.flags.is_end_headers()); + let head = self.head(); - let pos = dst.len(); + let promised_id = self.promised_id; - // At this point, we don't know how big the h2 frame will be. - // So, we write the head with length 0, then write the body, and - // finally write the length once we know the size. - head.encode(0, dst); - - // Encode the frame - dst.put_u32::(self.promised_id.into()); - let (len, cont) = self.header_block.encode(self.stream_id, encoder, dst); - - // Write the frame length - BigEndian::write_uint(&mut dst[pos..pos + 3], len + 4, 3); - - cont + self.header_block.into_encoding() + .encode(&head, encoder, dst, |dst| { + dst.put_u32::(promised_id.into()); + }) } fn head(&self) -> Head { @@ -400,6 +401,23 @@ impl fmt::Debug for PushPromise { } } +// ===== impl Continuation ===== + +impl Continuation { + fn head(&self) -> Head { + Head::new(Kind::Continuation, END_HEADERS, self.stream_id) + } + + pub fn encode(self, encoder: &mut hpack::Encoder, dst: &mut BytesMut) -> Option { + // Get the CONTINUATION frame head + let head = self.head(); + + self.header_block + .encode(&head, encoder, dst, |_| { + }) + } +} + // ===== impl Pseudo ===== impl Pseudo { @@ -458,6 +476,58 @@ fn to_string(src: Bytes) -> String { unsafe { String::from_utf8_unchecked(src) } } +// ===== impl EncodingHeaderBlock ===== + +impl EncodingHeaderBlock { + fn encode(mut self, + head: &Head, + encoder: &mut hpack::Encoder, + dst: &mut BytesMut, + f: F) + -> Option + where F: FnOnce(&mut BytesMut), + { + let head_pos = dst.len(); + + // At this point, we don't know how big the h2 frame will be. + // So, we write the head with length 0, then write the body, and + // finally write the length once we know the size. + head.encode(0, dst); + + let payload_pos = dst.len(); + + f(dst); + + // Now, encode the header payload + let continuation = match encoder.encode(self.hpack, &mut self.headers, dst) { + hpack::Encode::Full => None, + hpack::Encode::Partial(state) => Some(Continuation { + stream_id: head.stream_id(), + header_block: EncodingHeaderBlock { + hpack: Some(state), + headers: self.headers, + }, + }), + }; + + // Compute the header block length + let payload_len = (dst.len() - payload_pos) as u64; + + // Write the frame length + BigEndian::write_uint(&mut dst[head_pos..head_pos + 3], payload_len, 3); + + if continuation.is_some() { + // There will be continuation frames, so the `is_end_headers` flag + // must be unset + debug_assert!(dst[head_pos + 4] & END_HEADERS == END_HEADERS); + + dst[head_pos + 4] -= END_HEADERS; + } + + continuation + } +} + // ===== impl Iter ===== impl Iterator for Iter { @@ -515,13 +585,17 @@ impl HeadersFlag { } pub fn set_end_stream(&mut self) { - self.0 |= END_STREAM + self.0 |= END_STREAM; } pub fn is_end_headers(&self) -> bool { self.0 & END_HEADERS == END_HEADERS } + pub fn set_end_headers(&mut self) { + self.0 |= END_HEADERS; + } + pub fn is_padded(&self) -> bool { self.0 & PADDED == PADDED } @@ -570,6 +644,10 @@ impl PushPromiseFlag { self.0 & END_HEADERS == END_HEADERS } + pub fn set_end_headers(&mut self) { + self.0 |= END_HEADERS; + } + pub fn is_padded(&self) -> bool { self.0 & PADDED == PADDED } @@ -624,7 +702,10 @@ impl HeaderBlock { // contain the entire payload. Later, we need to check for stream // priority. // - // TODO: Provide a way to abort decoding if an error is hit. + // If the header frame is malformed, we still have to continue decoding + // the headers. A malformed header frame is a stream level error, but + // the hpack state is connection level. In order to maintain correct + // state for other streams, the hpack decoding process must complete. let res = decoder.decode(&mut src, |header| { use hpack::Header::*; @@ -673,30 +754,13 @@ impl HeaderBlock { Ok(()) } - fn encode( - self, - stream_id: StreamId, - encoder: &mut hpack::Encoder, - dst: &mut BytesMut, - ) -> (u64, Option) { - let pos = dst.len(); - let mut headers = Iter { - pseudo: Some(self.pseudo), - fields: self.fields.into_iter(), - }; - - let cont = match encoder.encode(None, &mut headers, dst) { - hpack::Encode::Full => None, - hpack::Encode::Partial(state) => Some(Continuation { - stream_id: stream_id, - hpack: state, - headers: headers, - }), - }; - - // Compute the header block length - let len = (dst.len() - pos) as u64; - - (len, cont) + fn into_encoding(self) -> EncodingHeaderBlock { + EncodingHeaderBlock { + hpack: None, + headers: Iter { + pseudo: Some(self.pseudo), + fields: self.fields.into_iter(), + }, + } } } diff --git a/tests/codec_write.rs b/tests/codec_write.rs index 8b13789..14a321f 100644 --- a/tests/codec_write.rs +++ b/tests/codec_write.rs @@ -1 +1,86 @@ +#[macro_use] +pub mod support; +use support::prelude::*; +#[test] +fn write_continuation_frames() { + // An invalid dependency ID results in a stream level error. The hpack + // payload should still be decoded. + let _ = ::env_logger::init(); + let (io, srv) = mock::new(); + + let large = build_large_headers(); + + // Build the large request frame + let frame = large.iter().fold( + frames::headers(1).request("GET", "https://http2.akamai.com/"), + |frame, &(name, ref value)| frame.field(name, &value[..])); + + let srv = srv.assert_client_handshake() + .unwrap() + .recv_settings() + .recv_frame(frame.eos()) + .send_frame( + frames::headers(1) + .response(204) + .eos(), + ) + .close(); + + let client = Client::handshake(io) + .expect("handshake") + .and_then(|(mut client, conn)| { + let mut request = Request::builder(); + request.uri("https://http2.akamai.com/"); + + for &(name, ref value) in &large { + request.header(name, &value[..]); + } + + let request = request + .body(()) + .unwrap(); + + let req = client + .send_request(request, true) + .expect("send_request1") + .0 + .then(|res| { + let response = res.unwrap(); + assert_eq!(response.status(), StatusCode::NO_CONTENT); + Ok::<_, ()>(()) + }); + + conn.drive(req) + .and_then(move |(h2, _)| { + h2.unwrap() + }) + }); + + client.join(srv).wait().expect("wait"); +} + +fn build_large_headers() -> Vec<(&'static str, String)> { + vec![ + ("one", "hello".to_string()), + ("two", build_large_string('2', 4 * 1024)), + ("three", "three".to_string()), + ("four", build_large_string('4', 4 * 1024)), + ("five", "five".to_string()), + ("six", build_large_string('6', 4 * 1024)), + ("seven", "seven".to_string()), + ("eight", build_large_string('8', 4 * 1024)), + ("nine", "nine".to_string()), + ("ten", build_large_string('0', 4 * 1024)), + ] +} + +fn build_large_string(ch: char, len: usize) -> String { + let mut ret = String::new(); + + for _ in 0..len { + ret.push(ch); + } + + ret +}