add Client config to disable server push
- Adds `Client::builder().enable_push(false)` to disable push - Client sends a GO_AWAY if receiving a push when it's disabled
This commit is contained in:
		| @@ -176,6 +176,12 @@ impl Builder { | |||||||
|         self |         self | ||||||
|     } |     } | ||||||
|  |  | ||||||
|  |     /// Enable or disable the server to send push promises. | ||||||
|  |     pub fn enable_push(&mut self, enabled: bool) -> &mut Self { | ||||||
|  |         self.settings.set_enable_push(enabled); | ||||||
|  |         self | ||||||
|  |     } | ||||||
|  |  | ||||||
|     /// Bind an H2 client connection. |     /// Bind an H2 client connection. | ||||||
|     /// |     /// | ||||||
|     /// Returns a future which resolves to the connection value once the H2 |     /// Returns a future which resolves to the connection value once the H2 | ||||||
|   | |||||||
| @@ -127,8 +127,9 @@ where | |||||||
|                 } |                 } | ||||||
|             }, |             }, | ||||||
|             Frame::PushPromise(v) => { |             Frame::PushPromise(v) => { | ||||||
|                 debug!("unimplemented PUSH_PROMISE write; frame={:?}", v); |                 if let Some(continuation) = v.encode(&mut self.hpack, self.buf.get_mut()) { | ||||||
|                 unimplemented!(); |                     self.next = Some(Next::Continuation(continuation)); | ||||||
|  |                 } | ||||||
|             }, |             }, | ||||||
|             Frame::Settings(v) => { |             Frame::Settings(v) => { | ||||||
|                 v.encode(self.buf.get_mut()); |                 v.encode(self.buf.get_mut()); | ||||||
|   | |||||||
| @@ -1,12 +1,12 @@ | |||||||
| use super::{StreamDependency, StreamId}; | use super::{StreamDependency, StreamId}; | ||||||
| use frame::{self, Error, Frame, Head, Kind}; | use frame::{Error, Frame, Head, Kind}; | ||||||
| use hpack; | use hpack; | ||||||
|  |  | ||||||
| use http::{uri, HeaderMap, Method, StatusCode, Uri}; | use http::{uri, HeaderMap, Method, StatusCode, Uri}; | ||||||
| use http::header::{self, HeaderName, HeaderValue}; | use http::header::{self, HeaderName, HeaderValue}; | ||||||
|  |  | ||||||
| use byteorder::{BigEndian, ByteOrder}; | use byteorder::{BigEndian, ByteOrder}; | ||||||
| use bytes::{Bytes, BytesMut}; | use bytes::{BufMut, Bytes, BytesMut}; | ||||||
| use string::String; | use string::String; | ||||||
|  |  | ||||||
| use std::fmt; | use std::fmt; | ||||||
| @@ -23,12 +23,8 @@ pub struct Headers { | |||||||
|     /// The stream dependency information, if any. |     /// The stream dependency information, if any. | ||||||
|     stream_dep: Option<StreamDependency>, |     stream_dep: Option<StreamDependency>, | ||||||
|  |  | ||||||
|     /// The decoded header fields |     /// The header block fragment | ||||||
|     fields: HeaderMap, |     header_block: HeaderBlock, | ||||||
|  |  | ||||||
|     /// Pseudo headers, these are broken out as they must be sent as part of the |  | ||||||
|     /// headers frame. |  | ||||||
|     pseudo: Pseudo, |  | ||||||
|  |  | ||||||
|     /// The associated flags |     /// The associated flags | ||||||
|     flags: HeadersFlag, |     flags: HeadersFlag, | ||||||
| @@ -37,7 +33,7 @@ pub struct Headers { | |||||||
| #[derive(Copy, Clone, Eq, PartialEq)] | #[derive(Copy, Clone, Eq, PartialEq)] | ||||||
| pub struct HeadersFlag(u8); | pub struct HeadersFlag(u8); | ||||||
|  |  | ||||||
| #[derive(Debug, Eq, PartialEq)] | #[derive(Eq, PartialEq)] | ||||||
| pub struct PushPromise { | pub struct PushPromise { | ||||||
|     /// The ID of the stream with which this frame is associated. |     /// The ID of the stream with which this frame is associated. | ||||||
|     stream_id: StreamId, |     stream_id: StreamId, | ||||||
| @@ -45,11 +41,14 @@ pub struct PushPromise { | |||||||
|     /// The ID of the stream being reserved by this PushPromise. |     /// The ID of the stream being reserved by this PushPromise. | ||||||
|     promised_id: StreamId, |     promised_id: StreamId, | ||||||
|  |  | ||||||
|  |     /// The header block fragment | ||||||
|  |     header_block: HeaderBlock, | ||||||
|  |  | ||||||
|     /// The associated flags |     /// The associated flags | ||||||
|     flags: PushPromiseFlag, |     flags: PushPromiseFlag, | ||||||
| } | } | ||||||
|  |  | ||||||
| #[derive(Debug, Copy, Clone, Eq, PartialEq)] | #[derive(Copy, Clone, Eq, PartialEq)] | ||||||
| pub struct PushPromiseFlag(u8); | pub struct PushPromiseFlag(u8); | ||||||
|  |  | ||||||
| #[derive(Debug)] | #[derive(Debug)] | ||||||
| @@ -85,6 +84,16 @@ pub struct Iter { | |||||||
|     fields: header::IntoIter<HeaderValue>, |     fields: header::IntoIter<HeaderValue>, | ||||||
| } | } | ||||||
|  |  | ||||||
|  | #[derive(PartialEq, Eq)] | ||||||
|  | struct HeaderBlock { | ||||||
|  |     /// The decoded header fields | ||||||
|  |     fields: HeaderMap, | ||||||
|  |  | ||||||
|  |     /// Pseudo headers, these are broken out as they must be sent as part of the | ||||||
|  |     /// headers frame. | ||||||
|  |     pseudo: Pseudo, | ||||||
|  | } | ||||||
|  |  | ||||||
| const END_STREAM: u8 = 0x1; | const END_STREAM: u8 = 0x1; | ||||||
| const END_HEADERS: u8 = 0x4; | const END_HEADERS: u8 = 0x4; | ||||||
| const PADDED: u8 = 0x8; | const PADDED: u8 = 0x8; | ||||||
| @@ -99,8 +108,10 @@ impl Headers { | |||||||
|         Headers { |         Headers { | ||||||
|             stream_id: stream_id, |             stream_id: stream_id, | ||||||
|             stream_dep: None, |             stream_dep: None, | ||||||
|             fields: fields, |             header_block: HeaderBlock { | ||||||
|             pseudo: pseudo, |                 fields: fields, | ||||||
|  |                 pseudo: pseudo, | ||||||
|  |             }, | ||||||
|             flags: HeadersFlag::default(), |             flags: HeadersFlag::default(), | ||||||
|         } |         } | ||||||
|     } |     } | ||||||
| @@ -112,8 +123,10 @@ impl Headers { | |||||||
|         Headers { |         Headers { | ||||||
|             stream_id, |             stream_id, | ||||||
|             stream_dep: None, |             stream_dep: None, | ||||||
|             fields: fields, |             header_block: HeaderBlock { | ||||||
|             pseudo: Pseudo::default(), |                 fields: fields, | ||||||
|  |                 pseudo: Pseudo::default(), | ||||||
|  |             }, | ||||||
|             flags: flags, |             flags: flags, | ||||||
|         } |         } | ||||||
|     } |     } | ||||||
| @@ -164,8 +177,10 @@ impl Headers { | |||||||
|         let headers = Headers { |         let headers = Headers { | ||||||
|             stream_id: head.stream_id(), |             stream_id: head.stream_id(), | ||||||
|             stream_dep: stream_dep, |             stream_dep: stream_dep, | ||||||
|             fields: HeaderMap::new(), |             header_block: HeaderBlock { | ||||||
|             pseudo: Pseudo::default(), |                 fields: HeaderMap::new(), | ||||||
|  |                 pseudo: Pseudo::default(), | ||||||
|  |             }, | ||||||
|             flags: flags, |             flags: flags, | ||||||
|         }; |         }; | ||||||
|  |  | ||||||
| @@ -181,11 +196,11 @@ impl Headers { | |||||||
|                 if reg { |                 if reg { | ||||||
|                     trace!("load_hpack; header malformed -- pseudo not at head of block"); |                     trace!("load_hpack; header malformed -- pseudo not at head of block"); | ||||||
|                     malformed = true; |                     malformed = true; | ||||||
|                 } else if self.pseudo.$field.is_some() { |                 } else if self.header_block.pseudo.$field.is_some() { | ||||||
|                     trace!("load_hpack; header malformed -- repeated pseudo"); |                     trace!("load_hpack; header malformed -- repeated pseudo"); | ||||||
|                     malformed = true; |                     malformed = true; | ||||||
|                 } else { |                 } else { | ||||||
|                     self.pseudo.$field = Some($val); |                     self.header_block.pseudo.$field = Some($val); | ||||||
|                 } |                 } | ||||||
|             }} |             }} | ||||||
|         } |         } | ||||||
| @@ -216,7 +231,7 @@ impl Headers { | |||||||
|                         malformed = true; |                         malformed = true; | ||||||
|                     } else { |                     } else { | ||||||
|                         reg = true; |                         reg = true; | ||||||
|                         self.fields.append(name, value); |                         self.header_block.fields.append(name, value); | ||||||
|                     } |                     } | ||||||
|                 }, |                 }, | ||||||
|                 Authority(v) => set_pseudo!(authority, v), |                 Authority(v) => set_pseudo!(authority, v), | ||||||
| @@ -257,15 +272,15 @@ impl Headers { | |||||||
|     } |     } | ||||||
|  |  | ||||||
|     pub fn into_parts(self) -> (Pseudo, HeaderMap) { |     pub fn into_parts(self) -> (Pseudo, HeaderMap) { | ||||||
|         (self.pseudo, self.fields) |         (self.header_block.pseudo, self.header_block.fields) | ||||||
|     } |     } | ||||||
|  |  | ||||||
|     pub fn fields(&self) -> &HeaderMap { |     pub fn fields(&self) -> &HeaderMap { | ||||||
|         &self.fields |         &self.header_block.fields | ||||||
|     } |     } | ||||||
|  |  | ||||||
|     pub fn into_fields(self) -> HeaderMap { |     pub fn into_fields(self) -> HeaderMap { | ||||||
|         self.fields |         self.header_block.fields | ||||||
|     } |     } | ||||||
|  |  | ||||||
|     pub fn encode(self, encoder: &mut hpack::Encoder, dst: &mut BytesMut) -> Option<Continuation> { |     pub fn encode(self, encoder: &mut hpack::Encoder, dst: &mut BytesMut) -> Option<Continuation> { | ||||||
| @@ -278,27 +293,12 @@ impl Headers { | |||||||
|         head.encode(0, dst); |         head.encode(0, dst); | ||||||
|  |  | ||||||
|         // Encode the frame |         // Encode the frame | ||||||
|         let mut headers = Iter { |         let (len, cont) = self.header_block.encode(self.stream_id, encoder, dst); | ||||||
|             pseudo: Some(self.pseudo), |  | ||||||
|             fields: self.fields.into_iter(), |  | ||||||
|         }; |  | ||||||
|  |  | ||||||
|         let ret = match encoder.encode(None, &mut headers, dst) { |  | ||||||
|             hpack::Encode::Full => None, |  | ||||||
|             hpack::Encode::Partial(state) => Some(Continuation { |  | ||||||
|                 stream_id: self.stream_id, |  | ||||||
|                 hpack: state, |  | ||||||
|                 headers: headers, |  | ||||||
|             }), |  | ||||||
|         }; |  | ||||||
|  |  | ||||||
|         // Compute the frame length |  | ||||||
|         let len = (dst.len() - pos) - frame::HEADER_LEN; |  | ||||||
|  |  | ||||||
|         // Write the frame length |         // Write the frame length | ||||||
|         BigEndian::write_uint(&mut dst[pos..pos + 3], len as u64, 3); |         BigEndian::write_uint(&mut dst[pos..pos + 3], len, 3); | ||||||
|  |  | ||||||
|         ret |         cont | ||||||
|     } |     } | ||||||
|  |  | ||||||
|     fn head(&self) -> Head { |     fn head(&self) -> Head { | ||||||
| @@ -326,6 +326,23 @@ impl fmt::Debug for Headers { | |||||||
| // ===== impl PushPromise ===== | // ===== impl PushPromise ===== | ||||||
|  |  | ||||||
| impl PushPromise { | impl PushPromise { | ||||||
|  |     pub fn new( | ||||||
|  |         stream_id: StreamId, | ||||||
|  |         promised_id: StreamId, | ||||||
|  |         pseudo: Pseudo, | ||||||
|  |         fields: HeaderMap, | ||||||
|  |     ) -> Self { | ||||||
|  |         PushPromise { | ||||||
|  |             flags: PushPromiseFlag::default(), | ||||||
|  |             header_block: HeaderBlock { | ||||||
|  |                 fields, | ||||||
|  |                 pseudo, | ||||||
|  |             }, | ||||||
|  |             promised_id, | ||||||
|  |             stream_id, | ||||||
|  |         } | ||||||
|  |     } | ||||||
|  |  | ||||||
|     pub fn load(head: Head, payload: &[u8]) -> Result<Self, Error> { |     pub fn load(head: Head, payload: &[u8]) -> Result<Self, Error> { | ||||||
|         let flags = PushPromiseFlag(head.flag()); |         let flags = PushPromiseFlag(head.flag()); | ||||||
|  |  | ||||||
| @@ -334,9 +351,13 @@ impl PushPromise { | |||||||
|         let (promised_id, _) = StreamId::parse(&payload[..4]); |         let (promised_id, _) = StreamId::parse(&payload[..4]); | ||||||
|  |  | ||||||
|         Ok(PushPromise { |         Ok(PushPromise { | ||||||
|             stream_id: head.stream_id(), |  | ||||||
|             promised_id: promised_id, |  | ||||||
|             flags: flags, |             flags: flags, | ||||||
|  |             header_block: HeaderBlock { | ||||||
|  |                 fields: HeaderMap::new(), | ||||||
|  |                 pseudo: Pseudo::default(), | ||||||
|  |             }, | ||||||
|  |             promised_id: promised_id, | ||||||
|  |             stream_id: head.stream_id(), | ||||||
|         }) |         }) | ||||||
|     } |     } | ||||||
|  |  | ||||||
| @@ -347,6 +368,45 @@ impl PushPromise { | |||||||
|     pub fn promised_id(&self) -> StreamId { |     pub fn promised_id(&self) -> StreamId { | ||||||
|         self.promised_id |         self.promised_id | ||||||
|     } |     } | ||||||
|  |  | ||||||
|  |     pub fn is_end_headers(&self) -> bool { | ||||||
|  |         self.flags.is_end_headers() | ||||||
|  |     } | ||||||
|  |  | ||||||
|  |     pub fn into_parts(self) -> (Pseudo, HeaderMap) { | ||||||
|  |         (self.header_block.pseudo, self.header_block.fields) | ||||||
|  |     } | ||||||
|  |  | ||||||
|  |     pub fn fields(&self) -> &HeaderMap { | ||||||
|  |         &self.header_block.fields | ||||||
|  |     } | ||||||
|  |  | ||||||
|  |     pub fn into_fields(self) -> HeaderMap { | ||||||
|  |         self.header_block.fields | ||||||
|  |     } | ||||||
|  |  | ||||||
|  |     pub fn encode(self, encoder: &mut hpack::Encoder, dst: &mut BytesMut) -> Option<Continuation> { | ||||||
|  |         let head = self.head(); | ||||||
|  |         let pos = dst.len(); | ||||||
|  |  | ||||||
|  |         // At this point, we don't know how big the h2 frame will be. | ||||||
|  |         // So, we write the head with length 0, then write the body, and | ||||||
|  |         // finally write the length once we know the size. | ||||||
|  |         head.encode(0, dst); | ||||||
|  |  | ||||||
|  |         // Encode the frame | ||||||
|  |         dst.put_u32::<BigEndian>(self.promised_id.into()); | ||||||
|  |         let (len, cont) = self.header_block.encode(self.stream_id, encoder, dst); | ||||||
|  |  | ||||||
|  |         // Write the frame length | ||||||
|  |         BigEndian::write_uint(&mut dst[pos..pos + 3], len + 4, 3); | ||||||
|  |  | ||||||
|  |         cont | ||||||
|  |     } | ||||||
|  |  | ||||||
|  |     fn head(&self) -> Head { | ||||||
|  |         Head::new(Kind::PushPromise, self.flags.into(), self.stream_id) | ||||||
|  |     } | ||||||
| } | } | ||||||
|  |  | ||||||
| impl<T> From<PushPromise> for Frame<T> { | impl<T> From<PushPromise> for Frame<T> { | ||||||
| @@ -355,6 +415,17 @@ impl<T> From<PushPromise> for Frame<T> { | |||||||
|     } |     } | ||||||
| } | } | ||||||
|  |  | ||||||
|  | impl fmt::Debug for PushPromise { | ||||||
|  |     fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { | ||||||
|  |         f.debug_struct("PushPromise") | ||||||
|  |             .field("stream_id", &self.stream_id) | ||||||
|  |             .field("promised_id", &self.promised_id) | ||||||
|  |             .field("flags", &self.flags) | ||||||
|  |             // `fields` and `pseudo` purposefully not included | ||||||
|  |             .finish() | ||||||
|  |     } | ||||||
|  | } | ||||||
|  |  | ||||||
| // ===== impl Pseudo ===== | // ===== impl Pseudo ===== | ||||||
|  |  | ||||||
| impl Pseudo { | impl Pseudo { | ||||||
| @@ -509,3 +580,76 @@ impl fmt::Debug for HeadersFlag { | |||||||
|             .finish() |             .finish() | ||||||
|     } |     } | ||||||
| } | } | ||||||
|  |  | ||||||
|  | // ===== impl PushPromiseFlag ===== | ||||||
|  |  | ||||||
|  | impl PushPromiseFlag { | ||||||
|  |     pub fn empty() -> PushPromiseFlag { | ||||||
|  |         PushPromiseFlag(0) | ||||||
|  |     } | ||||||
|  |  | ||||||
|  |     pub fn load(bits: u8) -> PushPromiseFlag { | ||||||
|  |         PushPromiseFlag(bits & ALL) | ||||||
|  |     } | ||||||
|  |  | ||||||
|  |     pub fn is_end_headers(&self) -> bool { | ||||||
|  |         self.0 & END_HEADERS == END_HEADERS | ||||||
|  |     } | ||||||
|  |  | ||||||
|  |     pub fn is_padded(&self) -> bool { | ||||||
|  |         self.0 & PADDED == PADDED | ||||||
|  |     } | ||||||
|  | } | ||||||
|  |  | ||||||
|  | impl Default for PushPromiseFlag { | ||||||
|  |     /// Returns a `PushPromiseFlag` value with `END_HEADERS` set. | ||||||
|  |     fn default() -> Self { | ||||||
|  |         PushPromiseFlag(END_HEADERS) | ||||||
|  |     } | ||||||
|  | } | ||||||
|  |  | ||||||
|  | impl From<PushPromiseFlag> for u8 { | ||||||
|  |     fn from(src: PushPromiseFlag) -> u8 { | ||||||
|  |         src.0 | ||||||
|  |     } | ||||||
|  | } | ||||||
|  |  | ||||||
|  | impl fmt::Debug for PushPromiseFlag { | ||||||
|  |     fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result { | ||||||
|  |         fmt.debug_struct("PushPromiseFlag") | ||||||
|  |             .field("end_headers", &self.is_end_headers()) | ||||||
|  |             .field("padded", &self.is_padded()) | ||||||
|  |             .finish() | ||||||
|  |     } | ||||||
|  | } | ||||||
|  |  | ||||||
|  | // ===== HeaderBlock ===== | ||||||
|  |  | ||||||
|  | impl HeaderBlock { | ||||||
|  |     fn encode( | ||||||
|  |         self, | ||||||
|  |         stream_id: StreamId, | ||||||
|  |         encoder: &mut hpack::Encoder, | ||||||
|  |         dst: &mut BytesMut, | ||||||
|  |     ) -> (u64, Option<Continuation>) { | ||||||
|  |         let pos = dst.len(); | ||||||
|  |         let mut headers = Iter { | ||||||
|  |             pseudo: Some(self.pseudo), | ||||||
|  |             fields: self.fields.into_iter(), | ||||||
|  |         }; | ||||||
|  |  | ||||||
|  |         let cont = match encoder.encode(None, &mut headers, dst) { | ||||||
|  |             hpack::Encode::Full => None, | ||||||
|  |             hpack::Encode::Partial(state) => Some(Continuation { | ||||||
|  |                 stream_id: stream_id, | ||||||
|  |                 hpack: state, | ||||||
|  |                 headers: headers, | ||||||
|  |             }), | ||||||
|  |         }; | ||||||
|  |  | ||||||
|  |         // Compute the header block length | ||||||
|  |         let len = (dst.len() - pos) as u64; | ||||||
|  |  | ||||||
|  |         (len, cont) | ||||||
|  |     } | ||||||
|  | } | ||||||
|   | |||||||
| @@ -85,6 +85,14 @@ impl Settings { | |||||||
|         self.max_frame_size = size; |         self.max_frame_size = size; | ||||||
|     } |     } | ||||||
|  |  | ||||||
|  |     pub fn is_push_enabled(&self) -> bool { | ||||||
|  |         self.enable_push.unwrap_or(1) != 0 | ||||||
|  |     } | ||||||
|  |  | ||||||
|  |     pub fn set_enable_push(&mut self, enable: bool) { | ||||||
|  |         self.enable_push = Some(enable as u32); | ||||||
|  |     } | ||||||
|  |  | ||||||
|     pub fn load(head: Head, payload: &[u8]) -> Result<Settings, Error> { |     pub fn load(head: Head, payload: &[u8]) -> Result<Settings, Error> { | ||||||
|         use self::Setting::*; |         use self::Setting::*; | ||||||
|  |  | ||||||
|   | |||||||
| @@ -64,12 +64,13 @@ where | |||||||
|     ) -> Connection<T, P, B> { |     ) -> Connection<T, P, B> { | ||||||
|         // TODO: Actually configure |         // TODO: Actually configure | ||||||
|         let streams = Streams::new(streams::Config { |         let streams = Streams::new(streams::Config { | ||||||
|             max_remote_initiated: None, |             local_init_window_sz: settings | ||||||
|             init_remote_window_sz: DEFAULT_INITIAL_WINDOW_SIZE, |  | ||||||
|             max_local_initiated: None, |  | ||||||
|             init_local_window_sz: settings |  | ||||||
|                 .initial_window_size() |                 .initial_window_size() | ||||||
|                 .unwrap_or(DEFAULT_INITIAL_WINDOW_SIZE), |                 .unwrap_or(DEFAULT_INITIAL_WINDOW_SIZE), | ||||||
|  |             local_max_initiated: None, | ||||||
|  |             local_push_enabled: settings.is_push_enabled(), | ||||||
|  |             remote_init_window_sz: DEFAULT_INITIAL_WINDOW_SIZE, | ||||||
|  |             remote_max_initiated: None, | ||||||
|         }); |         }); | ||||||
|         Connection { |         Connection { | ||||||
|             state: State::Open, |             state: State::Open, | ||||||
|   | |||||||
| @@ -34,9 +34,9 @@ where | |||||||
|     /// Create a new `Counts` using the provided configuration values. |     /// Create a new `Counts` using the provided configuration values. | ||||||
|     pub fn new(config: &Config) -> Self { |     pub fn new(config: &Config) -> Self { | ||||||
|         Counts { |         Counts { | ||||||
|             max_send_streams: config.max_local_initiated, |             max_send_streams: config.local_max_initiated, | ||||||
|             num_send_streams: 0, |             num_send_streams: 0, | ||||||
|             max_recv_streams: config.max_remote_initiated, |             max_recv_streams: config.remote_max_initiated, | ||||||
|             num_recv_streams: 0, |             num_recv_streams: 0, | ||||||
|             blocked_open: None, |             blocked_open: None, | ||||||
|             _p: PhantomData, |             _p: PhantomData, | ||||||
|   | |||||||
| @@ -31,15 +31,18 @@ use http::{Request, Response}; | |||||||
|  |  | ||||||
| #[derive(Debug)] | #[derive(Debug)] | ||||||
| pub struct Config { | pub struct Config { | ||||||
|     /// Maximum number of remote initiated streams |     /// Initial window size of locally initiated streams | ||||||
|     pub max_remote_initiated: Option<usize>, |     pub local_init_window_sz: WindowSize, | ||||||
|  |  | ||||||
|     /// Initial window size of remote initiated streams |  | ||||||
|     pub init_remote_window_sz: WindowSize, |  | ||||||
|  |  | ||||||
|     /// Maximum number of locally initiated streams |     /// Maximum number of locally initiated streams | ||||||
|     pub max_local_initiated: Option<usize>, |     pub local_max_initiated: Option<usize>, | ||||||
|  |  | ||||||
|     /// Initial window size of locally initiated streams |     /// If the local peer is willing to receive push promises | ||||||
|     pub init_local_window_sz: WindowSize, |     pub local_push_enabled: bool, | ||||||
|  |  | ||||||
|  |     /// Initial window size of remote initiated streams | ||||||
|  |     pub remote_init_window_sz: WindowSize, | ||||||
|  |  | ||||||
|  |     /// Maximum number of remote initiated streams | ||||||
|  |     pub remote_max_initiated: Option<usize>, | ||||||
| } | } | ||||||
|   | |||||||
| @@ -49,11 +49,11 @@ where | |||||||
|     pub fn new(config: &Config) -> Prioritize<B, P> { |     pub fn new(config: &Config) -> Prioritize<B, P> { | ||||||
|         let mut flow = FlowControl::new(); |         let mut flow = FlowControl::new(); | ||||||
|  |  | ||||||
|         flow.inc_window(config.init_local_window_sz) |         flow.inc_window(config.local_init_window_sz) | ||||||
|             .ok() |             .ok() | ||||||
|             .expect("invalid initial window size"); |             .expect("invalid initial window size"); | ||||||
|  |  | ||||||
|         flow.assign_capacity(config.init_local_window_sz); |         flow.assign_capacity(config.local_init_window_sz); | ||||||
|  |  | ||||||
|         trace!("Prioritize::new; flow={:?}", flow); |         trace!("Prioritize::new; flow={:?}", flow); | ||||||
|  |  | ||||||
|   | |||||||
| @@ -38,6 +38,9 @@ where | |||||||
|     /// Refused StreamId, this represents a frame that must be sent out. |     /// Refused StreamId, this represents a frame that must be sent out. | ||||||
|     refused: Option<StreamId>, |     refused: Option<StreamId>, | ||||||
|  |  | ||||||
|  |     /// If push promises are allowed to be recevied. | ||||||
|  |     is_push_enabled: bool, | ||||||
|  |  | ||||||
|     _p: PhantomData<B>, |     _p: PhantomData<B>, | ||||||
| } | } | ||||||
|  |  | ||||||
| @@ -71,7 +74,7 @@ where | |||||||
|         flow.assign_capacity(DEFAULT_INITIAL_WINDOW_SIZE); |         flow.assign_capacity(DEFAULT_INITIAL_WINDOW_SIZE); | ||||||
|  |  | ||||||
|         Recv { |         Recv { | ||||||
|             init_window_sz: config.init_local_window_sz, |             init_window_sz: config.local_init_window_sz, | ||||||
|             flow: flow, |             flow: flow, | ||||||
|             next_stream_id: next_stream_id.into(), |             next_stream_id: next_stream_id.into(), | ||||||
|             pending_window_updates: store::Queue::new(), |             pending_window_updates: store::Queue::new(), | ||||||
| @@ -79,6 +82,7 @@ where | |||||||
|             pending_accept: store::Queue::new(), |             pending_accept: store::Queue::new(), | ||||||
|             buffer: Buffer::new(), |             buffer: Buffer::new(), | ||||||
|             refused: None, |             refused: None, | ||||||
|  |             is_push_enabled: config.local_push_enabled, | ||||||
|             _p: PhantomData, |             _p: PhantomData, | ||||||
|         } |         } | ||||||
|     } |     } | ||||||
| @@ -429,10 +433,20 @@ where | |||||||
|         // TODO: Are there other rules? |         // TODO: Are there other rules? | ||||||
|         if P::is_server() { |         if P::is_server() { | ||||||
|             // The remote is a client and cannot reserve |             // The remote is a client and cannot reserve | ||||||
|  |             trace!("recv_push_promise; error remote is client"); | ||||||
|             return Err(RecvError::Connection(ProtocolError)); |             return Err(RecvError::Connection(ProtocolError)); | ||||||
|         } |         } | ||||||
|  |  | ||||||
|         if !promised_id.is_server_initiated() { |         if !promised_id.is_server_initiated() { | ||||||
|  |             trace!( | ||||||
|  |                 "recv_push_promise; error promised id is invalid {:?}", | ||||||
|  |                 promised_id | ||||||
|  |             ); | ||||||
|  |             return Err(RecvError::Connection(ProtocolError)); | ||||||
|  |         } | ||||||
|  |  | ||||||
|  |         if !self.is_push_enabled { | ||||||
|  |             trace!("recv_push_promise; error push is disabled"); | ||||||
|             return Err(RecvError::Connection(ProtocolError)); |             return Err(RecvError::Connection(ProtocolError)); | ||||||
|         } |         } | ||||||
|  |  | ||||||
|   | |||||||
| @@ -35,7 +35,7 @@ where | |||||||
|  |  | ||||||
|         Send { |         Send { | ||||||
|             next_stream_id: next_stream_id.into(), |             next_stream_id: next_stream_id.into(), | ||||||
|             init_window_sz: config.init_local_window_sz, |             init_window_sz: config.local_init_window_sz, | ||||||
|             prioritize: Prioritize::new(config), |             prioritize: Prioritize::new(config), | ||||||
|         } |         } | ||||||
|     } |     } | ||||||
|   | |||||||
| @@ -285,6 +285,7 @@ impl State { | |||||||
|                 .. |                 .. | ||||||
|             } => true, |             } => true, | ||||||
|             HalfClosedLocal(AwaitingHeaders) => true, |             HalfClosedLocal(AwaitingHeaders) => true, | ||||||
|  |             ReservedRemote => true, | ||||||
|             _ => false, |             _ => false, | ||||||
|         } |         } | ||||||
|     } |     } | ||||||
|   | |||||||
							
								
								
									
										109
									
								
								tests/push_promise.rs
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										109
									
								
								tests/push_promise.rs
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,109 @@ | |||||||
|  | extern crate h2_test_support; | ||||||
|  | use h2_test_support::prelude::*; | ||||||
|  |  | ||||||
|  | #[test] | ||||||
|  | fn recv_push_works() { | ||||||
|  |     // tests that by default, received push promises work | ||||||
|  |     // TODO: once API exists, read the pushed response | ||||||
|  |     let _ = ::env_logger::init(); | ||||||
|  |  | ||||||
|  |     let (io, srv) = mock::new(); | ||||||
|  |     let mock = srv.assert_client_handshake() | ||||||
|  |         .unwrap() | ||||||
|  |         .recv_settings() | ||||||
|  |         .recv_frame( | ||||||
|  |             frames::headers(1) | ||||||
|  |                 .request("GET", "https://http2.akamai.com/") | ||||||
|  |                 .eos(), | ||||||
|  |         ) | ||||||
|  |         .send_frame( | ||||||
|  |             frames::push_promise(1, 2).request("GET", "https://http2.akamai.com/style.css"), | ||||||
|  |         ) | ||||||
|  |         .send_frame(frames::headers(1).response(200).eos()) | ||||||
|  |         .send_frame(frames::headers(2).response(200).eos()); | ||||||
|  |  | ||||||
|  |     let h2 = Client::handshake(io).unwrap().and_then(|mut h2| { | ||||||
|  |         let request = Request::builder() | ||||||
|  |             .method(Method::GET) | ||||||
|  |             .uri("https://http2.akamai.com/") | ||||||
|  |             .body(()) | ||||||
|  |             .unwrap(); | ||||||
|  |         let req = h2.request(request, true) | ||||||
|  |             .unwrap() | ||||||
|  |             .unwrap() | ||||||
|  |             .and_then(|resp| { | ||||||
|  |                 assert_eq!(resp.status(), StatusCode::OK); | ||||||
|  |                 Ok(()) | ||||||
|  |             }); | ||||||
|  |  | ||||||
|  |         h2.drive(req) | ||||||
|  |     }); | ||||||
|  |  | ||||||
|  |     h2.join(mock).wait().unwrap(); | ||||||
|  | } | ||||||
|  |  | ||||||
|  | #[test] | ||||||
|  | fn recv_push_when_push_disabled_is_conn_error() { | ||||||
|  |     let _ = ::env_logger::init(); | ||||||
|  |  | ||||||
|  |     let (io, srv) = mock::new(); | ||||||
|  |     let mock = srv.assert_client_handshake() | ||||||
|  |         .unwrap() | ||||||
|  |         .ignore_settings() | ||||||
|  |         .recv_frame( | ||||||
|  |             frames::headers(1) | ||||||
|  |                 .request("GET", "https://http2.akamai.com/") | ||||||
|  |                 .eos(), | ||||||
|  |         ) | ||||||
|  |         .send_frame( | ||||||
|  |             frames::push_promise(1, 3).request("GET", "https://http2.akamai.com/style.css"), | ||||||
|  |         ) | ||||||
|  |         .send_frame(frames::headers(1).response(200).eos()) | ||||||
|  |         .recv_frame(frames::go_away(0).protocol_error()); | ||||||
|  |  | ||||||
|  |     let h2 = Client::builder() | ||||||
|  |         .enable_push(false) | ||||||
|  |         .handshake::<_, Bytes>(io) | ||||||
|  |         .unwrap() | ||||||
|  |         .and_then(|mut h2| { | ||||||
|  |             let request = Request::builder() | ||||||
|  |                 .method(Method::GET) | ||||||
|  |                 .uri("https://http2.akamai.com/") | ||||||
|  |                 .body(()) | ||||||
|  |                 .unwrap(); | ||||||
|  |             let req = h2.request(request, true).unwrap().then(|res| { | ||||||
|  |                 let err = res.unwrap_err(); | ||||||
|  |                 assert_eq!( | ||||||
|  |                     err.to_string(), | ||||||
|  |                     "protocol error: unspecific protocol error detected" | ||||||
|  |                 ); | ||||||
|  |                 Ok::<(), ()>(()) | ||||||
|  |             }); | ||||||
|  |  | ||||||
|  |             // client should see a protocol error | ||||||
|  |             let conn = h2.then(|res| { | ||||||
|  |                 let err = res.unwrap_err(); | ||||||
|  |                 assert_eq!( | ||||||
|  |                     err.to_string(), | ||||||
|  |                     "protocol error: unspecific protocol error detected" | ||||||
|  |                 ); | ||||||
|  |                 Ok::<(), ()>(()) | ||||||
|  |             }); | ||||||
|  |  | ||||||
|  |             conn.unwrap().join(req) | ||||||
|  |         }); | ||||||
|  |  | ||||||
|  |     h2.join(mock).wait().unwrap(); | ||||||
|  | } | ||||||
|  |  | ||||||
|  | #[test] | ||||||
|  | #[ignore] | ||||||
|  | fn recv_push_promise_with_unsafe_method_is_stream_error() { | ||||||
|  |     // for instance, when :method = POST | ||||||
|  | } | ||||||
|  |  | ||||||
|  | #[test] | ||||||
|  | #[ignore] | ||||||
|  | fn recv_push_promise_with_wrong_authority_is_stream_error() { | ||||||
|  |     // if server is foo.com, :authority = bar.com is stream error | ||||||
|  | } | ||||||
| @@ -28,6 +28,18 @@ pub fn data<T, B>(id: T, buf: B) -> Mock<frame::Data> | |||||||
|     Mock(frame::Data::new(id.into(), buf.into())) |     Mock(frame::Data::new(id.into(), buf.into())) | ||||||
| } | } | ||||||
|  |  | ||||||
|  | pub fn push_promise<T1, T2>(id: T1, promised: T2) -> Mock<frame::PushPromise> | ||||||
|  | where T1: Into<StreamId>, | ||||||
|  |       T2: Into<StreamId>, | ||||||
|  | { | ||||||
|  |     Mock(frame::PushPromise::new( | ||||||
|  |         id.into(), | ||||||
|  |         promised.into(), | ||||||
|  |         frame::Pseudo::default(), | ||||||
|  |         HeaderMap::default(), | ||||||
|  |     )) | ||||||
|  | } | ||||||
|  |  | ||||||
| pub fn window_update<T>(id: T, sz: u32) -> frame::WindowUpdate | pub fn window_update<T>(id: T, sz: u32) -> frame::WindowUpdate | ||||||
|     where T: Into<StreamId>, |     where T: Into<StreamId>, | ||||||
| { | { | ||||||
| @@ -140,9 +152,54 @@ impl From<Mock<frame::Data>> for SendFrame { | |||||||
|     } |     } | ||||||
| } | } | ||||||
|  |  | ||||||
|  |  | ||||||
|  | // PushPromise helpers | ||||||
|  |  | ||||||
|  | impl Mock<frame::PushPromise> { | ||||||
|  |     pub fn request<M, U>(self, method: M, uri: U) -> Self | ||||||
|  |     where M: HttpTryInto<http::Method>, | ||||||
|  |           U: HttpTryInto<http::Uri>, | ||||||
|  |     { | ||||||
|  |         let method = method.try_into().unwrap(); | ||||||
|  |         let uri = uri.try_into().unwrap(); | ||||||
|  |         let (id, promised, _, fields) = self.into_parts(); | ||||||
|  |         let frame = frame::PushPromise::new( | ||||||
|  |             id, | ||||||
|  |             promised, | ||||||
|  |             frame::Pseudo::request(method, uri), | ||||||
|  |             fields | ||||||
|  |         ); | ||||||
|  |         Mock(frame) | ||||||
|  |     } | ||||||
|  |  | ||||||
|  |     pub fn fields(self, fields: HeaderMap) -> Self { | ||||||
|  |         let (id, promised, pseudo, _) = self.into_parts(); | ||||||
|  |         let frame = frame::PushPromise::new(id, promised, pseudo, fields); | ||||||
|  |         Mock(frame) | ||||||
|  |     } | ||||||
|  |  | ||||||
|  |     fn into_parts(self) -> (StreamId, StreamId, frame::Pseudo, HeaderMap) { | ||||||
|  |         assert!(self.0.is_end_headers(), "unset eoh will be lost"); | ||||||
|  |         let id = self.0.stream_id(); | ||||||
|  |         let promised = self.0.promised_id(); | ||||||
|  |         let parts = self.0.into_parts(); | ||||||
|  |         (id, promised, parts.0, parts.1) | ||||||
|  |     } | ||||||
|  | } | ||||||
|  |  | ||||||
|  | impl From<Mock<frame::PushPromise>> for SendFrame { | ||||||
|  |     fn from(src: Mock<frame::PushPromise>) -> Self { | ||||||
|  |         Frame::PushPromise(src.0) | ||||||
|  |     } | ||||||
|  | } | ||||||
|  |  | ||||||
| // GoAway helpers | // GoAway helpers | ||||||
|  |  | ||||||
| impl Mock<frame::GoAway> { | impl Mock<frame::GoAway> { | ||||||
|  |     pub fn protocol_error(self) -> Self { | ||||||
|  |         Mock(frame::GoAway::new(self.0.last_stream_id(), frame::Reason::ProtocolError)) | ||||||
|  |     } | ||||||
|  |  | ||||||
|     pub fn flow_control(self) -> Self { |     pub fn flow_control(self) -> Self { | ||||||
|         Mock(frame::GoAway::new(self.0.last_stream_id(), frame::Reason::FlowControlError)) |         Mock(frame::GoAway::new(self.0.last_stream_id(), frame::Reason::FlowControlError)) | ||||||
|     } |     } | ||||||
|   | |||||||
		Reference in New Issue
	
	Block a user