add Client config to disable server push
- Adds `Client::builder().enable_push(false)` to disable push - Client sends a GO_AWAY if receiving a push when it's disabled
This commit is contained in:
		| @@ -1,12 +1,12 @@ | ||||
| use super::{StreamDependency, StreamId}; | ||||
| use frame::{self, Error, Frame, Head, Kind}; | ||||
| use frame::{Error, Frame, Head, Kind}; | ||||
| use hpack; | ||||
|  | ||||
| use http::{uri, HeaderMap, Method, StatusCode, Uri}; | ||||
| use http::header::{self, HeaderName, HeaderValue}; | ||||
|  | ||||
| use byteorder::{BigEndian, ByteOrder}; | ||||
| use bytes::{Bytes, BytesMut}; | ||||
| use bytes::{BufMut, Bytes, BytesMut}; | ||||
| use string::String; | ||||
|  | ||||
| use std::fmt; | ||||
| @@ -23,12 +23,8 @@ pub struct Headers { | ||||
|     /// The stream dependency information, if any. | ||||
|     stream_dep: Option<StreamDependency>, | ||||
|  | ||||
|     /// The decoded header fields | ||||
|     fields: HeaderMap, | ||||
|  | ||||
|     /// Pseudo headers, these are broken out as they must be sent as part of the | ||||
|     /// headers frame. | ||||
|     pseudo: Pseudo, | ||||
|     /// The header block fragment | ||||
|     header_block: HeaderBlock, | ||||
|  | ||||
|     /// The associated flags | ||||
|     flags: HeadersFlag, | ||||
| @@ -37,7 +33,7 @@ pub struct Headers { | ||||
| #[derive(Copy, Clone, Eq, PartialEq)] | ||||
| pub struct HeadersFlag(u8); | ||||
|  | ||||
| #[derive(Debug, Eq, PartialEq)] | ||||
| #[derive(Eq, PartialEq)] | ||||
| pub struct PushPromise { | ||||
|     /// The ID of the stream with which this frame is associated. | ||||
|     stream_id: StreamId, | ||||
| @@ -45,11 +41,14 @@ pub struct PushPromise { | ||||
|     /// The ID of the stream being reserved by this PushPromise. | ||||
|     promised_id: StreamId, | ||||
|  | ||||
|     /// The header block fragment | ||||
|     header_block: HeaderBlock, | ||||
|  | ||||
|     /// The associated flags | ||||
|     flags: PushPromiseFlag, | ||||
| } | ||||
|  | ||||
| #[derive(Debug, Copy, Clone, Eq, PartialEq)] | ||||
| #[derive(Copy, Clone, Eq, PartialEq)] | ||||
| pub struct PushPromiseFlag(u8); | ||||
|  | ||||
| #[derive(Debug)] | ||||
| @@ -85,6 +84,16 @@ pub struct Iter { | ||||
|     fields: header::IntoIter<HeaderValue>, | ||||
| } | ||||
|  | ||||
| #[derive(PartialEq, Eq)] | ||||
| struct HeaderBlock { | ||||
|     /// The decoded header fields | ||||
|     fields: HeaderMap, | ||||
|  | ||||
|     /// Pseudo headers, these are broken out as they must be sent as part of the | ||||
|     /// headers frame. | ||||
|     pseudo: Pseudo, | ||||
| } | ||||
|  | ||||
| const END_STREAM: u8 = 0x1; | ||||
| const END_HEADERS: u8 = 0x4; | ||||
| const PADDED: u8 = 0x8; | ||||
| @@ -99,8 +108,10 @@ impl Headers { | ||||
|         Headers { | ||||
|             stream_id: stream_id, | ||||
|             stream_dep: None, | ||||
|             fields: fields, | ||||
|             pseudo: pseudo, | ||||
|             header_block: HeaderBlock { | ||||
|                 fields: fields, | ||||
|                 pseudo: pseudo, | ||||
|             }, | ||||
|             flags: HeadersFlag::default(), | ||||
|         } | ||||
|     } | ||||
| @@ -112,8 +123,10 @@ impl Headers { | ||||
|         Headers { | ||||
|             stream_id, | ||||
|             stream_dep: None, | ||||
|             fields: fields, | ||||
|             pseudo: Pseudo::default(), | ||||
|             header_block: HeaderBlock { | ||||
|                 fields: fields, | ||||
|                 pseudo: Pseudo::default(), | ||||
|             }, | ||||
|             flags: flags, | ||||
|         } | ||||
|     } | ||||
| @@ -164,8 +177,10 @@ impl Headers { | ||||
|         let headers = Headers { | ||||
|             stream_id: head.stream_id(), | ||||
|             stream_dep: stream_dep, | ||||
|             fields: HeaderMap::new(), | ||||
|             pseudo: Pseudo::default(), | ||||
|             header_block: HeaderBlock { | ||||
|                 fields: HeaderMap::new(), | ||||
|                 pseudo: Pseudo::default(), | ||||
|             }, | ||||
|             flags: flags, | ||||
|         }; | ||||
|  | ||||
| @@ -181,11 +196,11 @@ impl Headers { | ||||
|                 if reg { | ||||
|                     trace!("load_hpack; header malformed -- pseudo not at head of block"); | ||||
|                     malformed = true; | ||||
|                 } else if self.pseudo.$field.is_some() { | ||||
|                 } else if self.header_block.pseudo.$field.is_some() { | ||||
|                     trace!("load_hpack; header malformed -- repeated pseudo"); | ||||
|                     malformed = true; | ||||
|                 } else { | ||||
|                     self.pseudo.$field = Some($val); | ||||
|                     self.header_block.pseudo.$field = Some($val); | ||||
|                 } | ||||
|             }} | ||||
|         } | ||||
| @@ -216,7 +231,7 @@ impl Headers { | ||||
|                         malformed = true; | ||||
|                     } else { | ||||
|                         reg = true; | ||||
|                         self.fields.append(name, value); | ||||
|                         self.header_block.fields.append(name, value); | ||||
|                     } | ||||
|                 }, | ||||
|                 Authority(v) => set_pseudo!(authority, v), | ||||
| @@ -257,15 +272,15 @@ impl Headers { | ||||
|     } | ||||
|  | ||||
|     pub fn into_parts(self) -> (Pseudo, HeaderMap) { | ||||
|         (self.pseudo, self.fields) | ||||
|         (self.header_block.pseudo, self.header_block.fields) | ||||
|     } | ||||
|  | ||||
|     pub fn fields(&self) -> &HeaderMap { | ||||
|         &self.fields | ||||
|         &self.header_block.fields | ||||
|     } | ||||
|  | ||||
|     pub fn into_fields(self) -> HeaderMap { | ||||
|         self.fields | ||||
|         self.header_block.fields | ||||
|     } | ||||
|  | ||||
|     pub fn encode(self, encoder: &mut hpack::Encoder, dst: &mut BytesMut) -> Option<Continuation> { | ||||
| @@ -278,27 +293,12 @@ impl Headers { | ||||
|         head.encode(0, dst); | ||||
|  | ||||
|         // Encode the frame | ||||
|         let mut headers = Iter { | ||||
|             pseudo: Some(self.pseudo), | ||||
|             fields: self.fields.into_iter(), | ||||
|         }; | ||||
|  | ||||
|         let ret = match encoder.encode(None, &mut headers, dst) { | ||||
|             hpack::Encode::Full => None, | ||||
|             hpack::Encode::Partial(state) => Some(Continuation { | ||||
|                 stream_id: self.stream_id, | ||||
|                 hpack: state, | ||||
|                 headers: headers, | ||||
|             }), | ||||
|         }; | ||||
|  | ||||
|         // Compute the frame length | ||||
|         let len = (dst.len() - pos) - frame::HEADER_LEN; | ||||
|         let (len, cont) = self.header_block.encode(self.stream_id, encoder, dst); | ||||
|  | ||||
|         // Write the frame length | ||||
|         BigEndian::write_uint(&mut dst[pos..pos + 3], len as u64, 3); | ||||
|         BigEndian::write_uint(&mut dst[pos..pos + 3], len, 3); | ||||
|  | ||||
|         ret | ||||
|         cont | ||||
|     } | ||||
|  | ||||
|     fn head(&self) -> Head { | ||||
| @@ -326,6 +326,23 @@ impl fmt::Debug for Headers { | ||||
| // ===== impl PushPromise ===== | ||||
|  | ||||
| impl PushPromise { | ||||
|     pub fn new( | ||||
|         stream_id: StreamId, | ||||
|         promised_id: StreamId, | ||||
|         pseudo: Pseudo, | ||||
|         fields: HeaderMap, | ||||
|     ) -> Self { | ||||
|         PushPromise { | ||||
|             flags: PushPromiseFlag::default(), | ||||
|             header_block: HeaderBlock { | ||||
|                 fields, | ||||
|                 pseudo, | ||||
|             }, | ||||
|             promised_id, | ||||
|             stream_id, | ||||
|         } | ||||
|     } | ||||
|  | ||||
|     pub fn load(head: Head, payload: &[u8]) -> Result<Self, Error> { | ||||
|         let flags = PushPromiseFlag(head.flag()); | ||||
|  | ||||
| @@ -334,9 +351,13 @@ impl PushPromise { | ||||
|         let (promised_id, _) = StreamId::parse(&payload[..4]); | ||||
|  | ||||
|         Ok(PushPromise { | ||||
|             stream_id: head.stream_id(), | ||||
|             promised_id: promised_id, | ||||
|             flags: flags, | ||||
|             header_block: HeaderBlock { | ||||
|                 fields: HeaderMap::new(), | ||||
|                 pseudo: Pseudo::default(), | ||||
|             }, | ||||
|             promised_id: promised_id, | ||||
|             stream_id: head.stream_id(), | ||||
|         }) | ||||
|     } | ||||
|  | ||||
| @@ -347,6 +368,45 @@ impl PushPromise { | ||||
|     pub fn promised_id(&self) -> StreamId { | ||||
|         self.promised_id | ||||
|     } | ||||
|  | ||||
|     pub fn is_end_headers(&self) -> bool { | ||||
|         self.flags.is_end_headers() | ||||
|     } | ||||
|  | ||||
|     pub fn into_parts(self) -> (Pseudo, HeaderMap) { | ||||
|         (self.header_block.pseudo, self.header_block.fields) | ||||
|     } | ||||
|  | ||||
|     pub fn fields(&self) -> &HeaderMap { | ||||
|         &self.header_block.fields | ||||
|     } | ||||
|  | ||||
|     pub fn into_fields(self) -> HeaderMap { | ||||
|         self.header_block.fields | ||||
|     } | ||||
|  | ||||
|     pub fn encode(self, encoder: &mut hpack::Encoder, dst: &mut BytesMut) -> Option<Continuation> { | ||||
|         let head = self.head(); | ||||
|         let pos = dst.len(); | ||||
|  | ||||
|         // At this point, we don't know how big the h2 frame will be. | ||||
|         // So, we write the head with length 0, then write the body, and | ||||
|         // finally write the length once we know the size. | ||||
|         head.encode(0, dst); | ||||
|  | ||||
|         // Encode the frame | ||||
|         dst.put_u32::<BigEndian>(self.promised_id.into()); | ||||
|         let (len, cont) = self.header_block.encode(self.stream_id, encoder, dst); | ||||
|  | ||||
|         // Write the frame length | ||||
|         BigEndian::write_uint(&mut dst[pos..pos + 3], len + 4, 3); | ||||
|  | ||||
|         cont | ||||
|     } | ||||
|  | ||||
|     fn head(&self) -> Head { | ||||
|         Head::new(Kind::PushPromise, self.flags.into(), self.stream_id) | ||||
|     } | ||||
| } | ||||
|  | ||||
| impl<T> From<PushPromise> for Frame<T> { | ||||
| @@ -355,6 +415,17 @@ impl<T> From<PushPromise> for Frame<T> { | ||||
|     } | ||||
| } | ||||
|  | ||||
| impl fmt::Debug for PushPromise { | ||||
|     fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { | ||||
|         f.debug_struct("PushPromise") | ||||
|             .field("stream_id", &self.stream_id) | ||||
|             .field("promised_id", &self.promised_id) | ||||
|             .field("flags", &self.flags) | ||||
|             // `fields` and `pseudo` purposefully not included | ||||
|             .finish() | ||||
|     } | ||||
| } | ||||
|  | ||||
| // ===== impl Pseudo ===== | ||||
|  | ||||
| impl Pseudo { | ||||
| @@ -509,3 +580,76 @@ impl fmt::Debug for HeadersFlag { | ||||
|             .finish() | ||||
|     } | ||||
| } | ||||
|  | ||||
| // ===== impl PushPromiseFlag ===== | ||||
|  | ||||
| impl PushPromiseFlag { | ||||
|     pub fn empty() -> PushPromiseFlag { | ||||
|         PushPromiseFlag(0) | ||||
|     } | ||||
|  | ||||
|     pub fn load(bits: u8) -> PushPromiseFlag { | ||||
|         PushPromiseFlag(bits & ALL) | ||||
|     } | ||||
|  | ||||
|     pub fn is_end_headers(&self) -> bool { | ||||
|         self.0 & END_HEADERS == END_HEADERS | ||||
|     } | ||||
|  | ||||
|     pub fn is_padded(&self) -> bool { | ||||
|         self.0 & PADDED == PADDED | ||||
|     } | ||||
| } | ||||
|  | ||||
| impl Default for PushPromiseFlag { | ||||
|     /// Returns a `PushPromiseFlag` value with `END_HEADERS` set. | ||||
|     fn default() -> Self { | ||||
|         PushPromiseFlag(END_HEADERS) | ||||
|     } | ||||
| } | ||||
|  | ||||
| impl From<PushPromiseFlag> for u8 { | ||||
|     fn from(src: PushPromiseFlag) -> u8 { | ||||
|         src.0 | ||||
|     } | ||||
| } | ||||
|  | ||||
| impl fmt::Debug for PushPromiseFlag { | ||||
|     fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result { | ||||
|         fmt.debug_struct("PushPromiseFlag") | ||||
|             .field("end_headers", &self.is_end_headers()) | ||||
|             .field("padded", &self.is_padded()) | ||||
|             .finish() | ||||
|     } | ||||
| } | ||||
|  | ||||
| // ===== HeaderBlock ===== | ||||
|  | ||||
| impl HeaderBlock { | ||||
|     fn encode( | ||||
|         self, | ||||
|         stream_id: StreamId, | ||||
|         encoder: &mut hpack::Encoder, | ||||
|         dst: &mut BytesMut, | ||||
|     ) -> (u64, Option<Continuation>) { | ||||
|         let pos = dst.len(); | ||||
|         let mut headers = Iter { | ||||
|             pseudo: Some(self.pseudo), | ||||
|             fields: self.fields.into_iter(), | ||||
|         }; | ||||
|  | ||||
|         let cont = match encoder.encode(None, &mut headers, dst) { | ||||
|             hpack::Encode::Full => None, | ||||
|             hpack::Encode::Partial(state) => Some(Continuation { | ||||
|                 stream_id: stream_id, | ||||
|                 hpack: state, | ||||
|                 headers: headers, | ||||
|             }), | ||||
|         }; | ||||
|  | ||||
|         // Compute the header block length | ||||
|         let len = (dst.len() - pos) as u64; | ||||
|  | ||||
|         (len, cont) | ||||
|     } | ||||
| } | ||||
|   | ||||
		Reference in New Issue
	
	Block a user