add Client config to disable server push
- Adds `Client::builder().enable_push(false)` to disable push - Client sends a GO_AWAY if receiving a push when it's disabled
This commit is contained in:
		| @@ -176,6 +176,12 @@ impl Builder { | ||||
|         self | ||||
|     } | ||||
|  | ||||
|     /// Enable or disable the server to send push promises. | ||||
|     pub fn enable_push(&mut self, enabled: bool) -> &mut Self { | ||||
|         self.settings.set_enable_push(enabled); | ||||
|         self | ||||
|     } | ||||
|  | ||||
|     /// Bind an H2 client connection. | ||||
|     /// | ||||
|     /// Returns a future which resolves to the connection value once the H2 | ||||
|   | ||||
| @@ -127,8 +127,9 @@ where | ||||
|                 } | ||||
|             }, | ||||
|             Frame::PushPromise(v) => { | ||||
|                 debug!("unimplemented PUSH_PROMISE write; frame={:?}", v); | ||||
|                 unimplemented!(); | ||||
|                 if let Some(continuation) = v.encode(&mut self.hpack, self.buf.get_mut()) { | ||||
|                     self.next = Some(Next::Continuation(continuation)); | ||||
|                 } | ||||
|             }, | ||||
|             Frame::Settings(v) => { | ||||
|                 v.encode(self.buf.get_mut()); | ||||
|   | ||||
| @@ -1,12 +1,12 @@ | ||||
| use super::{StreamDependency, StreamId}; | ||||
| use frame::{self, Error, Frame, Head, Kind}; | ||||
| use frame::{Error, Frame, Head, Kind}; | ||||
| use hpack; | ||||
|  | ||||
| use http::{uri, HeaderMap, Method, StatusCode, Uri}; | ||||
| use http::header::{self, HeaderName, HeaderValue}; | ||||
|  | ||||
| use byteorder::{BigEndian, ByteOrder}; | ||||
| use bytes::{Bytes, BytesMut}; | ||||
| use bytes::{BufMut, Bytes, BytesMut}; | ||||
| use string::String; | ||||
|  | ||||
| use std::fmt; | ||||
| @@ -23,12 +23,8 @@ pub struct Headers { | ||||
|     /// The stream dependency information, if any. | ||||
|     stream_dep: Option<StreamDependency>, | ||||
|  | ||||
|     /// The decoded header fields | ||||
|     fields: HeaderMap, | ||||
|  | ||||
|     /// Pseudo headers, these are broken out as they must be sent as part of the | ||||
|     /// headers frame. | ||||
|     pseudo: Pseudo, | ||||
|     /// The header block fragment | ||||
|     header_block: HeaderBlock, | ||||
|  | ||||
|     /// The associated flags | ||||
|     flags: HeadersFlag, | ||||
| @@ -37,7 +33,7 @@ pub struct Headers { | ||||
| #[derive(Copy, Clone, Eq, PartialEq)] | ||||
| pub struct HeadersFlag(u8); | ||||
|  | ||||
| #[derive(Debug, Eq, PartialEq)] | ||||
| #[derive(Eq, PartialEq)] | ||||
| pub struct PushPromise { | ||||
|     /// The ID of the stream with which this frame is associated. | ||||
|     stream_id: StreamId, | ||||
| @@ -45,11 +41,14 @@ pub struct PushPromise { | ||||
|     /// The ID of the stream being reserved by this PushPromise. | ||||
|     promised_id: StreamId, | ||||
|  | ||||
|     /// The header block fragment | ||||
|     header_block: HeaderBlock, | ||||
|  | ||||
|     /// The associated flags | ||||
|     flags: PushPromiseFlag, | ||||
| } | ||||
|  | ||||
| #[derive(Debug, Copy, Clone, Eq, PartialEq)] | ||||
| #[derive(Copy, Clone, Eq, PartialEq)] | ||||
| pub struct PushPromiseFlag(u8); | ||||
|  | ||||
| #[derive(Debug)] | ||||
| @@ -85,6 +84,16 @@ pub struct Iter { | ||||
|     fields: header::IntoIter<HeaderValue>, | ||||
| } | ||||
|  | ||||
| #[derive(PartialEq, Eq)] | ||||
| struct HeaderBlock { | ||||
|     /// The decoded header fields | ||||
|     fields: HeaderMap, | ||||
|  | ||||
|     /// Pseudo headers, these are broken out as they must be sent as part of the | ||||
|     /// headers frame. | ||||
|     pseudo: Pseudo, | ||||
| } | ||||
|  | ||||
| const END_STREAM: u8 = 0x1; | ||||
| const END_HEADERS: u8 = 0x4; | ||||
| const PADDED: u8 = 0x8; | ||||
| @@ -99,8 +108,10 @@ impl Headers { | ||||
|         Headers { | ||||
|             stream_id: stream_id, | ||||
|             stream_dep: None, | ||||
|             header_block: HeaderBlock { | ||||
|                 fields: fields, | ||||
|                 pseudo: pseudo, | ||||
|             }, | ||||
|             flags: HeadersFlag::default(), | ||||
|         } | ||||
|     } | ||||
| @@ -112,8 +123,10 @@ impl Headers { | ||||
|         Headers { | ||||
|             stream_id, | ||||
|             stream_dep: None, | ||||
|             header_block: HeaderBlock { | ||||
|                 fields: fields, | ||||
|                 pseudo: Pseudo::default(), | ||||
|             }, | ||||
|             flags: flags, | ||||
|         } | ||||
|     } | ||||
| @@ -164,8 +177,10 @@ impl Headers { | ||||
|         let headers = Headers { | ||||
|             stream_id: head.stream_id(), | ||||
|             stream_dep: stream_dep, | ||||
|             header_block: HeaderBlock { | ||||
|                 fields: HeaderMap::new(), | ||||
|                 pseudo: Pseudo::default(), | ||||
|             }, | ||||
|             flags: flags, | ||||
|         }; | ||||
|  | ||||
| @@ -181,11 +196,11 @@ impl Headers { | ||||
|                 if reg { | ||||
|                     trace!("load_hpack; header malformed -- pseudo not at head of block"); | ||||
|                     malformed = true; | ||||
|                 } else if self.pseudo.$field.is_some() { | ||||
|                 } else if self.header_block.pseudo.$field.is_some() { | ||||
|                     trace!("load_hpack; header malformed -- repeated pseudo"); | ||||
|                     malformed = true; | ||||
|                 } else { | ||||
|                     self.pseudo.$field = Some($val); | ||||
|                     self.header_block.pseudo.$field = Some($val); | ||||
|                 } | ||||
|             }} | ||||
|         } | ||||
| @@ -216,7 +231,7 @@ impl Headers { | ||||
|                         malformed = true; | ||||
|                     } else { | ||||
|                         reg = true; | ||||
|                         self.fields.append(name, value); | ||||
|                         self.header_block.fields.append(name, value); | ||||
|                     } | ||||
|                 }, | ||||
|                 Authority(v) => set_pseudo!(authority, v), | ||||
| @@ -257,15 +272,15 @@ impl Headers { | ||||
|     } | ||||
|  | ||||
|     pub fn into_parts(self) -> (Pseudo, HeaderMap) { | ||||
|         (self.pseudo, self.fields) | ||||
|         (self.header_block.pseudo, self.header_block.fields) | ||||
|     } | ||||
|  | ||||
|     pub fn fields(&self) -> &HeaderMap { | ||||
|         &self.fields | ||||
|         &self.header_block.fields | ||||
|     } | ||||
|  | ||||
|     pub fn into_fields(self) -> HeaderMap { | ||||
|         self.fields | ||||
|         self.header_block.fields | ||||
|     } | ||||
|  | ||||
|     pub fn encode(self, encoder: &mut hpack::Encoder, dst: &mut BytesMut) -> Option<Continuation> { | ||||
| @@ -278,27 +293,12 @@ impl Headers { | ||||
|         head.encode(0, dst); | ||||
|  | ||||
|         // Encode the frame | ||||
|         let mut headers = Iter { | ||||
|             pseudo: Some(self.pseudo), | ||||
|             fields: self.fields.into_iter(), | ||||
|         }; | ||||
|  | ||||
|         let ret = match encoder.encode(None, &mut headers, dst) { | ||||
|             hpack::Encode::Full => None, | ||||
|             hpack::Encode::Partial(state) => Some(Continuation { | ||||
|                 stream_id: self.stream_id, | ||||
|                 hpack: state, | ||||
|                 headers: headers, | ||||
|             }), | ||||
|         }; | ||||
|  | ||||
|         // Compute the frame length | ||||
|         let len = (dst.len() - pos) - frame::HEADER_LEN; | ||||
|         let (len, cont) = self.header_block.encode(self.stream_id, encoder, dst); | ||||
|  | ||||
|         // Write the frame length | ||||
|         BigEndian::write_uint(&mut dst[pos..pos + 3], len as u64, 3); | ||||
|         BigEndian::write_uint(&mut dst[pos..pos + 3], len, 3); | ||||
|  | ||||
|         ret | ||||
|         cont | ||||
|     } | ||||
|  | ||||
|     fn head(&self) -> Head { | ||||
| @@ -326,6 +326,23 @@ impl fmt::Debug for Headers { | ||||
| // ===== impl PushPromise ===== | ||||
|  | ||||
| impl PushPromise { | ||||
|     pub fn new( | ||||
|         stream_id: StreamId, | ||||
|         promised_id: StreamId, | ||||
|         pseudo: Pseudo, | ||||
|         fields: HeaderMap, | ||||
|     ) -> Self { | ||||
|         PushPromise { | ||||
|             flags: PushPromiseFlag::default(), | ||||
|             header_block: HeaderBlock { | ||||
|                 fields, | ||||
|                 pseudo, | ||||
|             }, | ||||
|             promised_id, | ||||
|             stream_id, | ||||
|         } | ||||
|     } | ||||
|  | ||||
|     pub fn load(head: Head, payload: &[u8]) -> Result<Self, Error> { | ||||
|         let flags = PushPromiseFlag(head.flag()); | ||||
|  | ||||
| @@ -334,9 +351,13 @@ impl PushPromise { | ||||
|         let (promised_id, _) = StreamId::parse(&payload[..4]); | ||||
|  | ||||
|         Ok(PushPromise { | ||||
|             stream_id: head.stream_id(), | ||||
|             promised_id: promised_id, | ||||
|             flags: flags, | ||||
|             header_block: HeaderBlock { | ||||
|                 fields: HeaderMap::new(), | ||||
|                 pseudo: Pseudo::default(), | ||||
|             }, | ||||
|             promised_id: promised_id, | ||||
|             stream_id: head.stream_id(), | ||||
|         }) | ||||
|     } | ||||
|  | ||||
| @@ -347,6 +368,45 @@ impl PushPromise { | ||||
|     pub fn promised_id(&self) -> StreamId { | ||||
|         self.promised_id | ||||
|     } | ||||
|  | ||||
|     pub fn is_end_headers(&self) -> bool { | ||||
|         self.flags.is_end_headers() | ||||
|     } | ||||
|  | ||||
|     pub fn into_parts(self) -> (Pseudo, HeaderMap) { | ||||
|         (self.header_block.pseudo, self.header_block.fields) | ||||
|     } | ||||
|  | ||||
|     pub fn fields(&self) -> &HeaderMap { | ||||
|         &self.header_block.fields | ||||
|     } | ||||
|  | ||||
|     pub fn into_fields(self) -> HeaderMap { | ||||
|         self.header_block.fields | ||||
|     } | ||||
|  | ||||
|     pub fn encode(self, encoder: &mut hpack::Encoder, dst: &mut BytesMut) -> Option<Continuation> { | ||||
|         let head = self.head(); | ||||
|         let pos = dst.len(); | ||||
|  | ||||
|         // At this point, we don't know how big the h2 frame will be. | ||||
|         // So, we write the head with length 0, then write the body, and | ||||
|         // finally write the length once we know the size. | ||||
|         head.encode(0, dst); | ||||
|  | ||||
|         // Encode the frame | ||||
|         dst.put_u32::<BigEndian>(self.promised_id.into()); | ||||
|         let (len, cont) = self.header_block.encode(self.stream_id, encoder, dst); | ||||
|  | ||||
|         // Write the frame length | ||||
|         BigEndian::write_uint(&mut dst[pos..pos + 3], len + 4, 3); | ||||
|  | ||||
|         cont | ||||
|     } | ||||
|  | ||||
|     fn head(&self) -> Head { | ||||
|         Head::new(Kind::PushPromise, self.flags.into(), self.stream_id) | ||||
|     } | ||||
| } | ||||
|  | ||||
| impl<T> From<PushPromise> for Frame<T> { | ||||
| @@ -355,6 +415,17 @@ impl<T> From<PushPromise> for Frame<T> { | ||||
|     } | ||||
| } | ||||
|  | ||||
| impl fmt::Debug for PushPromise { | ||||
|     fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { | ||||
|         f.debug_struct("PushPromise") | ||||
|             .field("stream_id", &self.stream_id) | ||||
|             .field("promised_id", &self.promised_id) | ||||
|             .field("flags", &self.flags) | ||||
|             // `fields` and `pseudo` purposefully not included | ||||
|             .finish() | ||||
|     } | ||||
| } | ||||
|  | ||||
| // ===== impl Pseudo ===== | ||||
|  | ||||
| impl Pseudo { | ||||
| @@ -509,3 +580,76 @@ impl fmt::Debug for HeadersFlag { | ||||
|             .finish() | ||||
|     } | ||||
| } | ||||
|  | ||||
| // ===== impl PushPromiseFlag ===== | ||||
|  | ||||
| impl PushPromiseFlag { | ||||
|     pub fn empty() -> PushPromiseFlag { | ||||
|         PushPromiseFlag(0) | ||||
|     } | ||||
|  | ||||
|     pub fn load(bits: u8) -> PushPromiseFlag { | ||||
|         PushPromiseFlag(bits & ALL) | ||||
|     } | ||||
|  | ||||
|     pub fn is_end_headers(&self) -> bool { | ||||
|         self.0 & END_HEADERS == END_HEADERS | ||||
|     } | ||||
|  | ||||
|     pub fn is_padded(&self) -> bool { | ||||
|         self.0 & PADDED == PADDED | ||||
|     } | ||||
| } | ||||
|  | ||||
| impl Default for PushPromiseFlag { | ||||
|     /// Returns a `PushPromiseFlag` value with `END_HEADERS` set. | ||||
|     fn default() -> Self { | ||||
|         PushPromiseFlag(END_HEADERS) | ||||
|     } | ||||
| } | ||||
|  | ||||
| impl From<PushPromiseFlag> for u8 { | ||||
|     fn from(src: PushPromiseFlag) -> u8 { | ||||
|         src.0 | ||||
|     } | ||||
| } | ||||
|  | ||||
| impl fmt::Debug for PushPromiseFlag { | ||||
|     fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result { | ||||
|         fmt.debug_struct("PushPromiseFlag") | ||||
|             .field("end_headers", &self.is_end_headers()) | ||||
|             .field("padded", &self.is_padded()) | ||||
|             .finish() | ||||
|     } | ||||
| } | ||||
|  | ||||
| // ===== HeaderBlock ===== | ||||
|  | ||||
| impl HeaderBlock { | ||||
|     fn encode( | ||||
|         self, | ||||
|         stream_id: StreamId, | ||||
|         encoder: &mut hpack::Encoder, | ||||
|         dst: &mut BytesMut, | ||||
|     ) -> (u64, Option<Continuation>) { | ||||
|         let pos = dst.len(); | ||||
|         let mut headers = Iter { | ||||
|             pseudo: Some(self.pseudo), | ||||
|             fields: self.fields.into_iter(), | ||||
|         }; | ||||
|  | ||||
|         let cont = match encoder.encode(None, &mut headers, dst) { | ||||
|             hpack::Encode::Full => None, | ||||
|             hpack::Encode::Partial(state) => Some(Continuation { | ||||
|                 stream_id: stream_id, | ||||
|                 hpack: state, | ||||
|                 headers: headers, | ||||
|             }), | ||||
|         }; | ||||
|  | ||||
|         // Compute the header block length | ||||
|         let len = (dst.len() - pos) as u64; | ||||
|  | ||||
|         (len, cont) | ||||
|     } | ||||
| } | ||||
|   | ||||
| @@ -85,6 +85,14 @@ impl Settings { | ||||
|         self.max_frame_size = size; | ||||
|     } | ||||
|  | ||||
|     pub fn is_push_enabled(&self) -> bool { | ||||
|         self.enable_push.unwrap_or(1) != 0 | ||||
|     } | ||||
|  | ||||
|     pub fn set_enable_push(&mut self, enable: bool) { | ||||
|         self.enable_push = Some(enable as u32); | ||||
|     } | ||||
|  | ||||
|     pub fn load(head: Head, payload: &[u8]) -> Result<Settings, Error> { | ||||
|         use self::Setting::*; | ||||
|  | ||||
|   | ||||
| @@ -64,12 +64,13 @@ where | ||||
|     ) -> Connection<T, P, B> { | ||||
|         // TODO: Actually configure | ||||
|         let streams = Streams::new(streams::Config { | ||||
|             max_remote_initiated: None, | ||||
|             init_remote_window_sz: DEFAULT_INITIAL_WINDOW_SIZE, | ||||
|             max_local_initiated: None, | ||||
|             init_local_window_sz: settings | ||||
|             local_init_window_sz: settings | ||||
|                 .initial_window_size() | ||||
|                 .unwrap_or(DEFAULT_INITIAL_WINDOW_SIZE), | ||||
|             local_max_initiated: None, | ||||
|             local_push_enabled: settings.is_push_enabled(), | ||||
|             remote_init_window_sz: DEFAULT_INITIAL_WINDOW_SIZE, | ||||
|             remote_max_initiated: None, | ||||
|         }); | ||||
|         Connection { | ||||
|             state: State::Open, | ||||
|   | ||||
| @@ -34,9 +34,9 @@ where | ||||
|     /// Create a new `Counts` using the provided configuration values. | ||||
|     pub fn new(config: &Config) -> Self { | ||||
|         Counts { | ||||
|             max_send_streams: config.max_local_initiated, | ||||
|             max_send_streams: config.local_max_initiated, | ||||
|             num_send_streams: 0, | ||||
|             max_recv_streams: config.max_remote_initiated, | ||||
|             max_recv_streams: config.remote_max_initiated, | ||||
|             num_recv_streams: 0, | ||||
|             blocked_open: None, | ||||
|             _p: PhantomData, | ||||
|   | ||||
| @@ -31,15 +31,18 @@ use http::{Request, Response}; | ||||
|  | ||||
| #[derive(Debug)] | ||||
| pub struct Config { | ||||
|     /// Maximum number of remote initiated streams | ||||
|     pub max_remote_initiated: Option<usize>, | ||||
|  | ||||
|     /// Initial window size of remote initiated streams | ||||
|     pub init_remote_window_sz: WindowSize, | ||||
|     /// Initial window size of locally initiated streams | ||||
|     pub local_init_window_sz: WindowSize, | ||||
|  | ||||
|     /// Maximum number of locally initiated streams | ||||
|     pub max_local_initiated: Option<usize>, | ||||
|     pub local_max_initiated: Option<usize>, | ||||
|  | ||||
|     /// Initial window size of locally initiated streams | ||||
|     pub init_local_window_sz: WindowSize, | ||||
|     /// If the local peer is willing to receive push promises | ||||
|     pub local_push_enabled: bool, | ||||
|  | ||||
|     /// Initial window size of remote initiated streams | ||||
|     pub remote_init_window_sz: WindowSize, | ||||
|  | ||||
|     /// Maximum number of remote initiated streams | ||||
|     pub remote_max_initiated: Option<usize>, | ||||
| } | ||||
|   | ||||
| @@ -49,11 +49,11 @@ where | ||||
|     pub fn new(config: &Config) -> Prioritize<B, P> { | ||||
|         let mut flow = FlowControl::new(); | ||||
|  | ||||
|         flow.inc_window(config.init_local_window_sz) | ||||
|         flow.inc_window(config.local_init_window_sz) | ||||
|             .ok() | ||||
|             .expect("invalid initial window size"); | ||||
|  | ||||
|         flow.assign_capacity(config.init_local_window_sz); | ||||
|         flow.assign_capacity(config.local_init_window_sz); | ||||
|  | ||||
|         trace!("Prioritize::new; flow={:?}", flow); | ||||
|  | ||||
|   | ||||
| @@ -38,6 +38,9 @@ where | ||||
|     /// Refused StreamId, this represents a frame that must be sent out. | ||||
|     refused: Option<StreamId>, | ||||
|  | ||||
|     /// If push promises are allowed to be recevied. | ||||
|     is_push_enabled: bool, | ||||
|  | ||||
|     _p: PhantomData<B>, | ||||
| } | ||||
|  | ||||
| @@ -71,7 +74,7 @@ where | ||||
|         flow.assign_capacity(DEFAULT_INITIAL_WINDOW_SIZE); | ||||
|  | ||||
|         Recv { | ||||
|             init_window_sz: config.init_local_window_sz, | ||||
|             init_window_sz: config.local_init_window_sz, | ||||
|             flow: flow, | ||||
|             next_stream_id: next_stream_id.into(), | ||||
|             pending_window_updates: store::Queue::new(), | ||||
| @@ -79,6 +82,7 @@ where | ||||
|             pending_accept: store::Queue::new(), | ||||
|             buffer: Buffer::new(), | ||||
|             refused: None, | ||||
|             is_push_enabled: config.local_push_enabled, | ||||
|             _p: PhantomData, | ||||
|         } | ||||
|     } | ||||
| @@ -429,10 +433,20 @@ where | ||||
|         // TODO: Are there other rules? | ||||
|         if P::is_server() { | ||||
|             // The remote is a client and cannot reserve | ||||
|             trace!("recv_push_promise; error remote is client"); | ||||
|             return Err(RecvError::Connection(ProtocolError)); | ||||
|         } | ||||
|  | ||||
|         if !promised_id.is_server_initiated() { | ||||
|             trace!( | ||||
|                 "recv_push_promise; error promised id is invalid {:?}", | ||||
|                 promised_id | ||||
|             ); | ||||
|             return Err(RecvError::Connection(ProtocolError)); | ||||
|         } | ||||
|  | ||||
|         if !self.is_push_enabled { | ||||
|             trace!("recv_push_promise; error push is disabled"); | ||||
|             return Err(RecvError::Connection(ProtocolError)); | ||||
|         } | ||||
|  | ||||
|   | ||||
| @@ -35,7 +35,7 @@ where | ||||
|  | ||||
|         Send { | ||||
|             next_stream_id: next_stream_id.into(), | ||||
|             init_window_sz: config.init_local_window_sz, | ||||
|             init_window_sz: config.local_init_window_sz, | ||||
|             prioritize: Prioritize::new(config), | ||||
|         } | ||||
|     } | ||||
|   | ||||
| @@ -285,6 +285,7 @@ impl State { | ||||
|                 .. | ||||
|             } => true, | ||||
|             HalfClosedLocal(AwaitingHeaders) => true, | ||||
|             ReservedRemote => true, | ||||
|             _ => false, | ||||
|         } | ||||
|     } | ||||
|   | ||||
							
								
								
									
										109
									
								
								tests/push_promise.rs
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										109
									
								
								tests/push_promise.rs
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,109 @@ | ||||
| extern crate h2_test_support; | ||||
| use h2_test_support::prelude::*; | ||||
|  | ||||
| #[test] | ||||
| fn recv_push_works() { | ||||
|     // tests that by default, received push promises work | ||||
|     // TODO: once API exists, read the pushed response | ||||
|     let _ = ::env_logger::init(); | ||||
|  | ||||
|     let (io, srv) = mock::new(); | ||||
|     let mock = srv.assert_client_handshake() | ||||
|         .unwrap() | ||||
|         .recv_settings() | ||||
|         .recv_frame( | ||||
|             frames::headers(1) | ||||
|                 .request("GET", "https://http2.akamai.com/") | ||||
|                 .eos(), | ||||
|         ) | ||||
|         .send_frame( | ||||
|             frames::push_promise(1, 2).request("GET", "https://http2.akamai.com/style.css"), | ||||
|         ) | ||||
|         .send_frame(frames::headers(1).response(200).eos()) | ||||
|         .send_frame(frames::headers(2).response(200).eos()); | ||||
|  | ||||
|     let h2 = Client::handshake(io).unwrap().and_then(|mut h2| { | ||||
|         let request = Request::builder() | ||||
|             .method(Method::GET) | ||||
|             .uri("https://http2.akamai.com/") | ||||
|             .body(()) | ||||
|             .unwrap(); | ||||
|         let req = h2.request(request, true) | ||||
|             .unwrap() | ||||
|             .unwrap() | ||||
|             .and_then(|resp| { | ||||
|                 assert_eq!(resp.status(), StatusCode::OK); | ||||
|                 Ok(()) | ||||
|             }); | ||||
|  | ||||
|         h2.drive(req) | ||||
|     }); | ||||
|  | ||||
|     h2.join(mock).wait().unwrap(); | ||||
| } | ||||
|  | ||||
| #[test] | ||||
| fn recv_push_when_push_disabled_is_conn_error() { | ||||
|     let _ = ::env_logger::init(); | ||||
|  | ||||
|     let (io, srv) = mock::new(); | ||||
|     let mock = srv.assert_client_handshake() | ||||
|         .unwrap() | ||||
|         .ignore_settings() | ||||
|         .recv_frame( | ||||
|             frames::headers(1) | ||||
|                 .request("GET", "https://http2.akamai.com/") | ||||
|                 .eos(), | ||||
|         ) | ||||
|         .send_frame( | ||||
|             frames::push_promise(1, 3).request("GET", "https://http2.akamai.com/style.css"), | ||||
|         ) | ||||
|         .send_frame(frames::headers(1).response(200).eos()) | ||||
|         .recv_frame(frames::go_away(0).protocol_error()); | ||||
|  | ||||
|     let h2 = Client::builder() | ||||
|         .enable_push(false) | ||||
|         .handshake::<_, Bytes>(io) | ||||
|         .unwrap() | ||||
|         .and_then(|mut h2| { | ||||
|             let request = Request::builder() | ||||
|                 .method(Method::GET) | ||||
|                 .uri("https://http2.akamai.com/") | ||||
|                 .body(()) | ||||
|                 .unwrap(); | ||||
|             let req = h2.request(request, true).unwrap().then(|res| { | ||||
|                 let err = res.unwrap_err(); | ||||
|                 assert_eq!( | ||||
|                     err.to_string(), | ||||
|                     "protocol error: unspecific protocol error detected" | ||||
|                 ); | ||||
|                 Ok::<(), ()>(()) | ||||
|             }); | ||||
|  | ||||
|             // client should see a protocol error | ||||
|             let conn = h2.then(|res| { | ||||
|                 let err = res.unwrap_err(); | ||||
|                 assert_eq!( | ||||
|                     err.to_string(), | ||||
|                     "protocol error: unspecific protocol error detected" | ||||
|                 ); | ||||
|                 Ok::<(), ()>(()) | ||||
|             }); | ||||
|  | ||||
|             conn.unwrap().join(req) | ||||
|         }); | ||||
|  | ||||
|     h2.join(mock).wait().unwrap(); | ||||
| } | ||||
|  | ||||
| #[test] | ||||
| #[ignore] | ||||
| fn recv_push_promise_with_unsafe_method_is_stream_error() { | ||||
|     // for instance, when :method = POST | ||||
| } | ||||
|  | ||||
| #[test] | ||||
| #[ignore] | ||||
| fn recv_push_promise_with_wrong_authority_is_stream_error() { | ||||
|     // if server is foo.com, :authority = bar.com is stream error | ||||
| } | ||||
| @@ -28,6 +28,18 @@ pub fn data<T, B>(id: T, buf: B) -> Mock<frame::Data> | ||||
|     Mock(frame::Data::new(id.into(), buf.into())) | ||||
| } | ||||
|  | ||||
| pub fn push_promise<T1, T2>(id: T1, promised: T2) -> Mock<frame::PushPromise> | ||||
| where T1: Into<StreamId>, | ||||
|       T2: Into<StreamId>, | ||||
| { | ||||
|     Mock(frame::PushPromise::new( | ||||
|         id.into(), | ||||
|         promised.into(), | ||||
|         frame::Pseudo::default(), | ||||
|         HeaderMap::default(), | ||||
|     )) | ||||
| } | ||||
|  | ||||
| pub fn window_update<T>(id: T, sz: u32) -> frame::WindowUpdate | ||||
|     where T: Into<StreamId>, | ||||
| { | ||||
| @@ -140,9 +152,54 @@ impl From<Mock<frame::Data>> for SendFrame { | ||||
|     } | ||||
| } | ||||
|  | ||||
|  | ||||
| // PushPromise helpers | ||||
|  | ||||
| impl Mock<frame::PushPromise> { | ||||
|     pub fn request<M, U>(self, method: M, uri: U) -> Self | ||||
|     where M: HttpTryInto<http::Method>, | ||||
|           U: HttpTryInto<http::Uri>, | ||||
|     { | ||||
|         let method = method.try_into().unwrap(); | ||||
|         let uri = uri.try_into().unwrap(); | ||||
|         let (id, promised, _, fields) = self.into_parts(); | ||||
|         let frame = frame::PushPromise::new( | ||||
|             id, | ||||
|             promised, | ||||
|             frame::Pseudo::request(method, uri), | ||||
|             fields | ||||
|         ); | ||||
|         Mock(frame) | ||||
|     } | ||||
|  | ||||
|     pub fn fields(self, fields: HeaderMap) -> Self { | ||||
|         let (id, promised, pseudo, _) = self.into_parts(); | ||||
|         let frame = frame::PushPromise::new(id, promised, pseudo, fields); | ||||
|         Mock(frame) | ||||
|     } | ||||
|  | ||||
|     fn into_parts(self) -> (StreamId, StreamId, frame::Pseudo, HeaderMap) { | ||||
|         assert!(self.0.is_end_headers(), "unset eoh will be lost"); | ||||
|         let id = self.0.stream_id(); | ||||
|         let promised = self.0.promised_id(); | ||||
|         let parts = self.0.into_parts(); | ||||
|         (id, promised, parts.0, parts.1) | ||||
|     } | ||||
| } | ||||
|  | ||||
| impl From<Mock<frame::PushPromise>> for SendFrame { | ||||
|     fn from(src: Mock<frame::PushPromise>) -> Self { | ||||
|         Frame::PushPromise(src.0) | ||||
|     } | ||||
| } | ||||
|  | ||||
| // GoAway helpers | ||||
|  | ||||
| impl Mock<frame::GoAway> { | ||||
|     pub fn protocol_error(self) -> Self { | ||||
|         Mock(frame::GoAway::new(self.0.last_stream_id(), frame::Reason::ProtocolError)) | ||||
|     } | ||||
|  | ||||
|     pub fn flow_control(self) -> Self { | ||||
|         Mock(frame::GoAway::new(self.0.last_stream_id(), frame::Reason::FlowControlError)) | ||||
|     } | ||||
|   | ||||
		Reference in New Issue
	
	Block a user