diff --git a/src/hpack/encoder.rs b/src/hpack/encoder.rs index 09118e4..27528d3 100644 --- a/src/hpack/encoder.rs +++ b/src/hpack/encoder.rs @@ -17,7 +17,10 @@ pub enum Encode { } #[derive(Debug)] -pub struct EncodeState(Index); +pub struct EncodeState { + index: Index, + value: Option, +} #[derive(Debug, PartialEq, Eq)] pub enum EncoderError { @@ -74,7 +77,7 @@ impl Encoder { /// Encode a set of headers into the provide buffer pub fn encode(&mut self, resume: Option, headers: &mut I, dst: &mut BytesMut) -> Result - where I: Iterator, + where I: Iterator>>, { let len = dst.len(); @@ -89,28 +92,62 @@ impl Encoder { if let Some(resume) = resume { let len = dst.len(); - match self.encode_header(&resume.0, dst) { - Err(EncoderError::BufferOverflow) => { - dst.truncate(len); - return Ok(Encode::Partial(resume)); + if let Some(ref value) = resume.value { + unimplemented!(); + } else { + // Encode the header + match self.encode_header(&resume.index, dst) { + Err(EncoderError::BufferOverflow) => { + dst.truncate(len); + return Ok(Encode::Partial(resume)); + } + Err(e) => return Err(e), + Ok(_) => {} } - Err(e) => return Err(e), - Ok(_) => {} } } + let mut last_index = None; + for header in headers { let len = dst.len(); - let index = self.table.index(header); - match self.encode_header(&index, dst) { - Err(EncoderError::BufferOverflow) => { - dst.truncate(len); - return Ok(Encode::Partial(EncodeState(index))); + match header.reify() { + // The header has an associated name. In which case, try to + // index it in the table. + Ok(header) => { + let index = self.table.index(header); + let res = self.encode_header(&index, dst); + + if try!(is_buffer_overflow(res)) { + dst.truncate(len); + return Ok(Encode::Partial(EncodeState { + index: index, + value: None, + })); + } + + last_index = Some(index); } - Err(e) => return Err(e), - Ok(_) => {} - } + // The header does not have an associated name. This means that + // the name is the same as the previously yielded header. In + // which case, we skip table lookup and just use the same index + // as the previous entry. + Err(value) => { + let res = self.encode_header_without_name( + last_index.as_ref().unwrap(), + &value, + dst); + + if try!(is_buffer_overflow(res)) { + dst.truncate(len); + return Ok(Encode::Partial(EncodeState { + index: last_index.unwrap(), + value: Some(value), + })); + } + } + }; } Ok(Encode::Full) @@ -145,13 +182,11 @@ impl Encoder { Index::Name(idx, _) => { let header = self.table.resolve(&index); - if header.is_sensitive() { - try!(encode_int(idx, 4, 0b10000, dst)); - } else { - try!(encode_int(idx, 4, 0, dst)); - } - - try!(encode_str(header.value_slice(), dst)); + try!(encode_not_indexed( + idx, + header.value_slice(), + header.is_sensitive(), + dst)); } Index::Inserted(idx) => { let header = self.table.resolve(&index); @@ -178,18 +213,41 @@ impl Encoder { Index::NotIndexed(_) => { let header = self.table.resolve(&index); - if !dst.has_remaining_mut() { - return Err(EncoderError::BufferOverflow); - } + try!(encode_not_indexed2( + header.name().as_slice(), + header.value_slice(), + header.is_sensitive(), + dst)); + } + } - if header.is_sensitive() { - dst.put_u8(0b10000); - } else { - dst.put_u8(0); - } + Ok(()) + } - try!(encode_str(header.name().as_slice(), dst)); - try!(encode_str(header.value_slice(), dst)); + fn encode_header_without_name(&mut self, last: &Index, + value: &HeaderValue, dst: &mut BytesMut) + -> Result<(), EncoderError> + { + match *last { + Index::Indexed(idx, ..) | + Index::Name(idx, ..) | + Index::Inserted(idx) | + Index::InsertedValue(idx, ..) => + { + try!(encode_not_indexed( + idx, + value.as_ref(), + value.is_sensitive(), + dst)); + } + Index::NotIndexed(_) => { + let last = self.table.resolve(last); + + try!(encode_not_indexed2( + last.name().as_slice(), + value.as_ref(), + value.is_sensitive(), + dst)); } } @@ -203,6 +261,43 @@ impl Default for Encoder { } } +fn encode_size_update(val: usize, dst: &mut B) -> Result<(), EncoderError> { + encode_int(val, 5, 0b00100000, dst) +} + +fn encode_not_indexed(name: usize, value: &[u8], + sensitive: bool, dst: &mut BytesMut) + -> Result<(), EncoderError> +{ + if sensitive { + try!(encode_int(name, 4, 0b10000, dst)); + } else { + try!(encode_int(name, 4, 0, dst)); + } + + try!(encode_str(value, dst)); + Ok(()) +} + +fn encode_not_indexed2(name: &[u8], value: &[u8], + sensitive: bool, dst: &mut BytesMut) + -> Result<(), EncoderError> +{ + if !dst.has_remaining_mut() { + return Err(EncoderError::BufferOverflow); + } + + if sensitive { + dst.put_u8(0b10000); + } else { + dst.put_u8(0); + } + + try!(encode_str(name, dst)); + try!(encode_str(value, dst)); + Ok(()) +} + fn encode_str(val: &[u8], dst: &mut BytesMut) -> Result<(), EncoderError> { use std::io::Cursor; @@ -261,10 +356,6 @@ fn encode_str(val: &[u8], dst: &mut BytesMut) -> Result<(), EncoderError> { Ok(()) } -fn encode_size_update(val: usize, dst: &mut B) -> Result<(), EncoderError> { - encode_int(val, 5, 0b00100000, dst) -} - /// Encode an integer into the given destination buffer fn encode_int( mut value: usize, // The integer to encode @@ -321,6 +412,14 @@ fn encode_int_one_byte(value: usize, prefix_bits: usize) -> bool { value < (1 << prefix_bits) - 1 } +fn is_buffer_overflow(res: Result<(), EncoderError>) -> Result { + match res { + Err(EncoderError::BufferOverflow) => Ok(true), + Err(e) => Err(e), + Ok(_) => Ok(false), + } +} + #[cfg(test)] mod test { use super::*; @@ -452,7 +551,7 @@ mod test { let mut value = HeaderValue::try_from_bytes(b"12345").unwrap(); value.set_sensitive(true); - let header = Header::Field { name: name, value: value }; + let header = Header::Field { name: Some(name), value: value }; // Now, try to encode the sensitive header @@ -469,7 +568,7 @@ mod test { let mut value = HeaderValue::try_from_bytes(b"12345").unwrap(); value.set_sensitive(true); - let header = Header::Field { name: name, value: value }; + let header = Header::Field { name: Some(name), value: value }; let mut encoder = Encoder::default(); let res = encode(&mut encoder, vec![header]); @@ -487,7 +586,7 @@ mod test { let mut value = HeaderValue::try_from_bytes(b"12345").unwrap(); value.set_sensitive(true); - let header = Header::Field { name: name, value: value }; + let header = Header::Field { name: Some(name), value: value }; let res = encode(&mut encoder, vec![header]); assert_eq!(&[0b11111, 47], &res[..2]); @@ -649,23 +748,23 @@ mod test { // Not sure what the best way to do this is. } - fn encode(e: &mut Encoder, hdrs: Vec
) -> BytesMut { + fn encode(e: &mut Encoder, hdrs: Vec>>) -> BytesMut { let mut dst = BytesMut::with_capacity(1024); e.encode(None, &mut hdrs.into_iter(), &mut dst); dst } - fn method(s: &str) -> Header { + fn method(s: &str) -> Header> { Header::Method(Method::from_bytes(s.as_bytes()).unwrap()) } - fn header(name: &str, val: &str) -> Header { + fn header(name: &str, val: &str) -> Header> { use http::header::{HeaderName, HeaderValue}; let name = HeaderName::from_bytes(name.as_bytes()).unwrap(); let value = HeaderValue::try_from_bytes(val.as_bytes()).unwrap(); - Header::Field { name: name, value: value } + Header::Field { name: Some(name), value: value } } fn huff_decode(src: &[u8]) -> BytesMut { diff --git a/src/hpack/header.rs b/src/hpack/header.rs index 19270f7..c5b11b1 100644 --- a/src/hpack/header.rs +++ b/src/hpack/header.rs @@ -7,9 +7,9 @@ use bytes::Bytes; /// HTTP/2.0 Header #[derive(Debug, Clone, Eq, PartialEq)] -pub enum Header { +pub enum Header { Field { - name: HeaderName, + name: T, value: HeaderValue, }, Authority(ByteStr), @@ -35,6 +35,22 @@ pub fn len(name: &HeaderName, value: &HeaderValue) -> usize { 32 + n.len() + value.len() } +impl Header> { + pub fn reify(self) -> Result { + use self::Header::*; + + Ok(match self { + Field { name: Some(n), value } => Field { name: n, value: value }, + Field { name: None, value } => return Err(value), + Authority(v) => Authority(v), + Method(v) => Method(v), + Scheme(v) => Scheme(v), + Path(v) => Path(v), + Status(v) => Status(v), + }) + } +} + impl Header { pub fn new(name: Bytes, value: Bytes) -> Result { if name[0] == b':' { @@ -191,6 +207,20 @@ impl Header { } } +// Mostly for tests +impl From
for Header> { + fn from(src: Header) -> Self { + match src { + Header::Field { name, value } => Header::Field { name: Some(name), value }, + Header::Authority(v) => Header::Authority(v), + Header::Method(v) => Header::Method(v), + Header::Scheme(v) => Header::Scheme(v), + Header::Path(v) => Header::Path(v), + Header::Status(v) => Header::Status(v), + } + } +} + impl<'a> Name<'a> { pub fn into_entry(self, value: Bytes) -> Result { match self { diff --git a/src/hpack/test.rs b/src/hpack/test.rs index 1d370af..50e4f8f 100644 --- a/src/hpack/test.rs +++ b/src/hpack/test.rs @@ -111,7 +111,7 @@ struct FuzzHpack { #[derive(Debug, Clone)] struct HeaderFrame { resizes: Vec, - headers: Vec
, + headers: Vec>>, } impl FuzzHpack { @@ -124,7 +124,7 @@ impl FuzzHpack { let mut rng = StdRng::from_seed(&seed); // Generates a bunch of source headers - let mut source: Vec
= vec![]; + let mut source: Vec>> = vec![]; for _ in 0..2000 { source.push(gen_header(&mut rng)); @@ -221,7 +221,7 @@ impl FuzzHpack { // Decode the chunk! decoder.decode(&buf.into(), |e| { - assert_eq!(e, expect.remove(0)); + assert_eq!(e, expect.remove(0).reify().unwrap()); }).unwrap(); buf = BytesMut::with_capacity( @@ -232,7 +232,7 @@ impl FuzzHpack { // Decode the chunk! decoder.decode(&buf.into(), |e| { - assert_eq!(e, expect.remove(0)); + assert_eq!(e, expect.remove(0).reify().unwrap()); }).unwrap(); } @@ -246,7 +246,7 @@ impl Arbitrary for FuzzHpack { } } -fn gen_header(g: &mut StdRng) -> Header { +fn gen_header(g: &mut StdRng) -> Header> { use http::StatusCode; use http::method::{self, Method}; @@ -309,7 +309,7 @@ fn gen_header(g: &mut StdRng) -> Header { value.set_sensitive(true); } - Header::Field { name: name, value: value } + Header::Field { name: Some(name), value: value } } } @@ -528,13 +528,13 @@ fn test_story(story: Value) { } let mut input: Vec<_> = case.expect.iter().map(|&(ref name, ref value)| { - Header::new(name.clone().into(), value.clone().into()).unwrap() + Header::new(name.clone().into(), value.clone().into()).unwrap().into() }).collect(); encoder.encode(None, &mut input.clone().into_iter(), &mut buf).unwrap(); decoder.decode(&buf.into(), |e| { - assert_eq!(e, input.remove(0)); + assert_eq!(e, input.remove(0).reify().unwrap()); }).unwrap(); assert_eq!(0, input.len());