diff --git a/src/client.rs b/src/client.rs index 400fafb..9d4b073 100644 --- a/src/client.rs +++ b/src/client.rs @@ -176,6 +176,12 @@ impl Builder { self } + /// Enable or disable the server to send push promises. + pub fn enable_push(&mut self, enabled: bool) -> &mut Self { + self.settings.set_enable_push(enabled); + self + } + /// Bind an H2 client connection. /// /// Returns a future which resolves to the connection value once the H2 diff --git a/src/codec/framed_write.rs b/src/codec/framed_write.rs index 473b1ce..dbd867f 100644 --- a/src/codec/framed_write.rs +++ b/src/codec/framed_write.rs @@ -127,8 +127,9 @@ where } }, Frame::PushPromise(v) => { - debug!("unimplemented PUSH_PROMISE write; frame={:?}", v); - unimplemented!(); + if let Some(continuation) = v.encode(&mut self.hpack, self.buf.get_mut()) { + self.next = Some(Next::Continuation(continuation)); + } }, Frame::Settings(v) => { v.encode(self.buf.get_mut()); diff --git a/src/frame/headers.rs b/src/frame/headers.rs index 8d9b064..c261f16 100644 --- a/src/frame/headers.rs +++ b/src/frame/headers.rs @@ -1,12 +1,12 @@ use super::{StreamDependency, StreamId}; -use frame::{self, Error, Frame, Head, Kind}; +use frame::{Error, Frame, Head, Kind}; use hpack; use http::{uri, HeaderMap, Method, StatusCode, Uri}; use http::header::{self, HeaderName, HeaderValue}; use byteorder::{BigEndian, ByteOrder}; -use bytes::{Bytes, BytesMut}; +use bytes::{BufMut, Bytes, BytesMut}; use string::String; use std::fmt; @@ -23,12 +23,8 @@ pub struct Headers { /// The stream dependency information, if any. stream_dep: Option, - /// The decoded header fields - fields: HeaderMap, - - /// Pseudo headers, these are broken out as they must be sent as part of the - /// headers frame. - pseudo: Pseudo, + /// The header block fragment + header_block: HeaderBlock, /// The associated flags flags: HeadersFlag, @@ -37,7 +33,7 @@ pub struct Headers { #[derive(Copy, Clone, Eq, PartialEq)] pub struct HeadersFlag(u8); -#[derive(Debug, Eq, PartialEq)] +#[derive(Eq, PartialEq)] pub struct PushPromise { /// The ID of the stream with which this frame is associated. stream_id: StreamId, @@ -45,11 +41,14 @@ pub struct PushPromise { /// The ID of the stream being reserved by this PushPromise. promised_id: StreamId, + /// The header block fragment + header_block: HeaderBlock, + /// The associated flags flags: PushPromiseFlag, } -#[derive(Debug, Copy, Clone, Eq, PartialEq)] +#[derive(Copy, Clone, Eq, PartialEq)] pub struct PushPromiseFlag(u8); #[derive(Debug)] @@ -85,6 +84,16 @@ pub struct Iter { fields: header::IntoIter, } +#[derive(PartialEq, Eq)] +struct HeaderBlock { + /// The decoded header fields + fields: HeaderMap, + + /// Pseudo headers, these are broken out as they must be sent as part of the + /// headers frame. + pseudo: Pseudo, +} + const END_STREAM: u8 = 0x1; const END_HEADERS: u8 = 0x4; const PADDED: u8 = 0x8; @@ -99,8 +108,10 @@ impl Headers { Headers { stream_id: stream_id, stream_dep: None, - fields: fields, - pseudo: pseudo, + header_block: HeaderBlock { + fields: fields, + pseudo: pseudo, + }, flags: HeadersFlag::default(), } } @@ -112,8 +123,10 @@ impl Headers { Headers { stream_id, stream_dep: None, - fields: fields, - pseudo: Pseudo::default(), + header_block: HeaderBlock { + fields: fields, + pseudo: Pseudo::default(), + }, flags: flags, } } @@ -164,8 +177,10 @@ impl Headers { let headers = Headers { stream_id: head.stream_id(), stream_dep: stream_dep, - fields: HeaderMap::new(), - pseudo: Pseudo::default(), + header_block: HeaderBlock { + fields: HeaderMap::new(), + pseudo: Pseudo::default(), + }, flags: flags, }; @@ -181,11 +196,11 @@ impl Headers { if reg { trace!("load_hpack; header malformed -- pseudo not at head of block"); malformed = true; - } else if self.pseudo.$field.is_some() { + } else if self.header_block.pseudo.$field.is_some() { trace!("load_hpack; header malformed -- repeated pseudo"); malformed = true; } else { - self.pseudo.$field = Some($val); + self.header_block.pseudo.$field = Some($val); } }} } @@ -216,7 +231,7 @@ impl Headers { malformed = true; } else { reg = true; - self.fields.append(name, value); + self.header_block.fields.append(name, value); } }, Authority(v) => set_pseudo!(authority, v), @@ -257,15 +272,15 @@ impl Headers { } pub fn into_parts(self) -> (Pseudo, HeaderMap) { - (self.pseudo, self.fields) + (self.header_block.pseudo, self.header_block.fields) } pub fn fields(&self) -> &HeaderMap { - &self.fields + &self.header_block.fields } pub fn into_fields(self) -> HeaderMap { - self.fields + self.header_block.fields } pub fn encode(self, encoder: &mut hpack::Encoder, dst: &mut BytesMut) -> Option { @@ -278,27 +293,12 @@ impl Headers { head.encode(0, dst); // Encode the frame - let mut headers = Iter { - pseudo: Some(self.pseudo), - fields: self.fields.into_iter(), - }; - - let ret = match encoder.encode(None, &mut headers, dst) { - hpack::Encode::Full => None, - hpack::Encode::Partial(state) => Some(Continuation { - stream_id: self.stream_id, - hpack: state, - headers: headers, - }), - }; - - // Compute the frame length - let len = (dst.len() - pos) - frame::HEADER_LEN; + let (len, cont) = self.header_block.encode(self.stream_id, encoder, dst); // Write the frame length - BigEndian::write_uint(&mut dst[pos..pos + 3], len as u64, 3); + BigEndian::write_uint(&mut dst[pos..pos + 3], len, 3); - ret + cont } fn head(&self) -> Head { @@ -326,6 +326,23 @@ impl fmt::Debug for Headers { // ===== impl PushPromise ===== impl PushPromise { + pub fn new( + stream_id: StreamId, + promised_id: StreamId, + pseudo: Pseudo, + fields: HeaderMap, + ) -> Self { + PushPromise { + flags: PushPromiseFlag::default(), + header_block: HeaderBlock { + fields, + pseudo, + }, + promised_id, + stream_id, + } + } + pub fn load(head: Head, payload: &[u8]) -> Result { let flags = PushPromiseFlag(head.flag()); @@ -334,9 +351,13 @@ impl PushPromise { let (promised_id, _) = StreamId::parse(&payload[..4]); Ok(PushPromise { - stream_id: head.stream_id(), - promised_id: promised_id, flags: flags, + header_block: HeaderBlock { + fields: HeaderMap::new(), + pseudo: Pseudo::default(), + }, + promised_id: promised_id, + stream_id: head.stream_id(), }) } @@ -347,6 +368,45 @@ impl PushPromise { pub fn promised_id(&self) -> StreamId { self.promised_id } + + pub fn is_end_headers(&self) -> bool { + self.flags.is_end_headers() + } + + pub fn into_parts(self) -> (Pseudo, HeaderMap) { + (self.header_block.pseudo, self.header_block.fields) + } + + pub fn fields(&self) -> &HeaderMap { + &self.header_block.fields + } + + pub fn into_fields(self) -> HeaderMap { + self.header_block.fields + } + + pub fn encode(self, encoder: &mut hpack::Encoder, dst: &mut BytesMut) -> Option { + let head = self.head(); + let pos = dst.len(); + + // At this point, we don't know how big the h2 frame will be. + // So, we write the head with length 0, then write the body, and + // finally write the length once we know the size. + head.encode(0, dst); + + // Encode the frame + dst.put_u32::(self.promised_id.into()); + let (len, cont) = self.header_block.encode(self.stream_id, encoder, dst); + + // Write the frame length + BigEndian::write_uint(&mut dst[pos..pos + 3], len + 4, 3); + + cont + } + + fn head(&self) -> Head { + Head::new(Kind::PushPromise, self.flags.into(), self.stream_id) + } } impl From for Frame { @@ -355,6 +415,17 @@ impl From for Frame { } } +impl fmt::Debug for PushPromise { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + f.debug_struct("PushPromise") + .field("stream_id", &self.stream_id) + .field("promised_id", &self.promised_id) + .field("flags", &self.flags) + // `fields` and `pseudo` purposefully not included + .finish() + } +} + // ===== impl Pseudo ===== impl Pseudo { @@ -509,3 +580,76 @@ impl fmt::Debug for HeadersFlag { .finish() } } + +// ===== impl PushPromiseFlag ===== + +impl PushPromiseFlag { + pub fn empty() -> PushPromiseFlag { + PushPromiseFlag(0) + } + + pub fn load(bits: u8) -> PushPromiseFlag { + PushPromiseFlag(bits & ALL) + } + + pub fn is_end_headers(&self) -> bool { + self.0 & END_HEADERS == END_HEADERS + } + + pub fn is_padded(&self) -> bool { + self.0 & PADDED == PADDED + } +} + +impl Default for PushPromiseFlag { + /// Returns a `PushPromiseFlag` value with `END_HEADERS` set. + fn default() -> Self { + PushPromiseFlag(END_HEADERS) + } +} + +impl From for u8 { + fn from(src: PushPromiseFlag) -> u8 { + src.0 + } +} + +impl fmt::Debug for PushPromiseFlag { + fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result { + fmt.debug_struct("PushPromiseFlag") + .field("end_headers", &self.is_end_headers()) + .field("padded", &self.is_padded()) + .finish() + } +} + +// ===== HeaderBlock ===== + +impl HeaderBlock { + fn encode( + self, + stream_id: StreamId, + encoder: &mut hpack::Encoder, + dst: &mut BytesMut, + ) -> (u64, Option) { + let pos = dst.len(); + let mut headers = Iter { + pseudo: Some(self.pseudo), + fields: self.fields.into_iter(), + }; + + let cont = match encoder.encode(None, &mut headers, dst) { + hpack::Encode::Full => None, + hpack::Encode::Partial(state) => Some(Continuation { + stream_id: stream_id, + hpack: state, + headers: headers, + }), + }; + + // Compute the header block length + let len = (dst.len() - pos) as u64; + + (len, cont) + } +} diff --git a/src/frame/settings.rs b/src/frame/settings.rs index a2e9747..f655a2f 100644 --- a/src/frame/settings.rs +++ b/src/frame/settings.rs @@ -85,6 +85,14 @@ impl Settings { self.max_frame_size = size; } + pub fn is_push_enabled(&self) -> bool { + self.enable_push.unwrap_or(1) != 0 + } + + pub fn set_enable_push(&mut self, enable: bool) { + self.enable_push = Some(enable as u32); + } + pub fn load(head: Head, payload: &[u8]) -> Result { use self::Setting::*; diff --git a/src/proto/connection.rs b/src/proto/connection.rs index bbf74f3..dec3e6c 100644 --- a/src/proto/connection.rs +++ b/src/proto/connection.rs @@ -64,12 +64,13 @@ where ) -> Connection { // TODO: Actually configure let streams = Streams::new(streams::Config { - max_remote_initiated: None, - init_remote_window_sz: DEFAULT_INITIAL_WINDOW_SIZE, - max_local_initiated: None, - init_local_window_sz: settings + local_init_window_sz: settings .initial_window_size() .unwrap_or(DEFAULT_INITIAL_WINDOW_SIZE), + local_max_initiated: None, + local_push_enabled: settings.is_push_enabled(), + remote_init_window_sz: DEFAULT_INITIAL_WINDOW_SIZE, + remote_max_initiated: None, }); Connection { state: State::Open, diff --git a/src/proto/streams/counts.rs b/src/proto/streams/counts.rs index ffd1e3c..168e4f2 100644 --- a/src/proto/streams/counts.rs +++ b/src/proto/streams/counts.rs @@ -34,9 +34,9 @@ where /// Create a new `Counts` using the provided configuration values. pub fn new(config: &Config) -> Self { Counts { - max_send_streams: config.max_local_initiated, + max_send_streams: config.local_max_initiated, num_send_streams: 0, - max_recv_streams: config.max_remote_initiated, + max_recv_streams: config.remote_max_initiated, num_recv_streams: 0, blocked_open: None, _p: PhantomData, diff --git a/src/proto/streams/mod.rs b/src/proto/streams/mod.rs index 9783b23..1bf15cf 100644 --- a/src/proto/streams/mod.rs +++ b/src/proto/streams/mod.rs @@ -31,15 +31,18 @@ use http::{Request, Response}; #[derive(Debug)] pub struct Config { - /// Maximum number of remote initiated streams - pub max_remote_initiated: Option, - - /// Initial window size of remote initiated streams - pub init_remote_window_sz: WindowSize, + /// Initial window size of locally initiated streams + pub local_init_window_sz: WindowSize, /// Maximum number of locally initiated streams - pub max_local_initiated: Option, + pub local_max_initiated: Option, - /// Initial window size of locally initiated streams - pub init_local_window_sz: WindowSize, + /// If the local peer is willing to receive push promises + pub local_push_enabled: bool, + + /// Initial window size of remote initiated streams + pub remote_init_window_sz: WindowSize, + + /// Maximum number of remote initiated streams + pub remote_max_initiated: Option, } diff --git a/src/proto/streams/prioritize.rs b/src/proto/streams/prioritize.rs index 1488555..b0f2f1d 100644 --- a/src/proto/streams/prioritize.rs +++ b/src/proto/streams/prioritize.rs @@ -49,11 +49,11 @@ where pub fn new(config: &Config) -> Prioritize { let mut flow = FlowControl::new(); - flow.inc_window(config.init_local_window_sz) + flow.inc_window(config.local_init_window_sz) .ok() .expect("invalid initial window size"); - flow.assign_capacity(config.init_local_window_sz); + flow.assign_capacity(config.local_init_window_sz); trace!("Prioritize::new; flow={:?}", flow); diff --git a/src/proto/streams/recv.rs b/src/proto/streams/recv.rs index 5deaba2..8474cc7 100644 --- a/src/proto/streams/recv.rs +++ b/src/proto/streams/recv.rs @@ -38,6 +38,9 @@ where /// Refused StreamId, this represents a frame that must be sent out. refused: Option, + /// If push promises are allowed to be recevied. + is_push_enabled: bool, + _p: PhantomData, } @@ -71,7 +74,7 @@ where flow.assign_capacity(DEFAULT_INITIAL_WINDOW_SIZE); Recv { - init_window_sz: config.init_local_window_sz, + init_window_sz: config.local_init_window_sz, flow: flow, next_stream_id: next_stream_id.into(), pending_window_updates: store::Queue::new(), @@ -79,6 +82,7 @@ where pending_accept: store::Queue::new(), buffer: Buffer::new(), refused: None, + is_push_enabled: config.local_push_enabled, _p: PhantomData, } } @@ -429,10 +433,20 @@ where // TODO: Are there other rules? if P::is_server() { // The remote is a client and cannot reserve + trace!("recv_push_promise; error remote is client"); return Err(RecvError::Connection(ProtocolError)); } if !promised_id.is_server_initiated() { + trace!( + "recv_push_promise; error promised id is invalid {:?}", + promised_id + ); + return Err(RecvError::Connection(ProtocolError)); + } + + if !self.is_push_enabled { + trace!("recv_push_promise; error push is disabled"); return Err(RecvError::Connection(ProtocolError)); } diff --git a/src/proto/streams/send.rs b/src/proto/streams/send.rs index 5901fa7..bae09bd 100644 --- a/src/proto/streams/send.rs +++ b/src/proto/streams/send.rs @@ -35,7 +35,7 @@ where Send { next_stream_id: next_stream_id.into(), - init_window_sz: config.init_local_window_sz, + init_window_sz: config.local_init_window_sz, prioritize: Prioritize::new(config), } } diff --git a/src/proto/streams/state.rs b/src/proto/streams/state.rs index c6efdd5..b0dde1c 100644 --- a/src/proto/streams/state.rs +++ b/src/proto/streams/state.rs @@ -285,6 +285,7 @@ impl State { .. } => true, HalfClosedLocal(AwaitingHeaders) => true, + ReservedRemote => true, _ => false, } } diff --git a/tests/push_promise.rs b/tests/push_promise.rs new file mode 100644 index 0000000..59b1bad --- /dev/null +++ b/tests/push_promise.rs @@ -0,0 +1,109 @@ +extern crate h2_test_support; +use h2_test_support::prelude::*; + +#[test] +fn recv_push_works() { + // tests that by default, received push promises work + // TODO: once API exists, read the pushed response + let _ = ::env_logger::init(); + + let (io, srv) = mock::new(); + let mock = srv.assert_client_handshake() + .unwrap() + .recv_settings() + .recv_frame( + frames::headers(1) + .request("GET", "https://http2.akamai.com/") + .eos(), + ) + .send_frame( + frames::push_promise(1, 2).request("GET", "https://http2.akamai.com/style.css"), + ) + .send_frame(frames::headers(1).response(200).eos()) + .send_frame(frames::headers(2).response(200).eos()); + + let h2 = Client::handshake(io).unwrap().and_then(|mut h2| { + let request = Request::builder() + .method(Method::GET) + .uri("https://http2.akamai.com/") + .body(()) + .unwrap(); + let req = h2.request(request, true) + .unwrap() + .unwrap() + .and_then(|resp| { + assert_eq!(resp.status(), StatusCode::OK); + Ok(()) + }); + + h2.drive(req) + }); + + h2.join(mock).wait().unwrap(); +} + +#[test] +fn recv_push_when_push_disabled_is_conn_error() { + let _ = ::env_logger::init(); + + let (io, srv) = mock::new(); + let mock = srv.assert_client_handshake() + .unwrap() + .ignore_settings() + .recv_frame( + frames::headers(1) + .request("GET", "https://http2.akamai.com/") + .eos(), + ) + .send_frame( + frames::push_promise(1, 3).request("GET", "https://http2.akamai.com/style.css"), + ) + .send_frame(frames::headers(1).response(200).eos()) + .recv_frame(frames::go_away(0).protocol_error()); + + let h2 = Client::builder() + .enable_push(false) + .handshake::<_, Bytes>(io) + .unwrap() + .and_then(|mut h2| { + let request = Request::builder() + .method(Method::GET) + .uri("https://http2.akamai.com/") + .body(()) + .unwrap(); + let req = h2.request(request, true).unwrap().then(|res| { + let err = res.unwrap_err(); + assert_eq!( + err.to_string(), + "protocol error: unspecific protocol error detected" + ); + Ok::<(), ()>(()) + }); + + // client should see a protocol error + let conn = h2.then(|res| { + let err = res.unwrap_err(); + assert_eq!( + err.to_string(), + "protocol error: unspecific protocol error detected" + ); + Ok::<(), ()>(()) + }); + + conn.unwrap().join(req) + }); + + h2.join(mock).wait().unwrap(); +} + +#[test] +#[ignore] +fn recv_push_promise_with_unsafe_method_is_stream_error() { + // for instance, when :method = POST +} + +#[test] +#[ignore] +fn recv_push_promise_with_wrong_authority_is_stream_error() { + // if server is foo.com, :authority = bar.com is stream error +} diff --git a/tests/support/src/frames.rs b/tests/support/src/frames.rs index f9c5c15..ba43e40 100644 --- a/tests/support/src/frames.rs +++ b/tests/support/src/frames.rs @@ -28,6 +28,18 @@ pub fn data(id: T, buf: B) -> Mock Mock(frame::Data::new(id.into(), buf.into())) } +pub fn push_promise(id: T1, promised: T2) -> Mock +where T1: Into, + T2: Into, +{ + Mock(frame::PushPromise::new( + id.into(), + promised.into(), + frame::Pseudo::default(), + HeaderMap::default(), + )) +} + pub fn window_update(id: T, sz: u32) -> frame::WindowUpdate where T: Into, { @@ -140,9 +152,54 @@ impl From> for SendFrame { } } + +// PushPromise helpers + +impl Mock { + pub fn request(self, method: M, uri: U) -> Self + where M: HttpTryInto, + U: HttpTryInto, + { + let method = method.try_into().unwrap(); + let uri = uri.try_into().unwrap(); + let (id, promised, _, fields) = self.into_parts(); + let frame = frame::PushPromise::new( + id, + promised, + frame::Pseudo::request(method, uri), + fields + ); + Mock(frame) + } + + pub fn fields(self, fields: HeaderMap) -> Self { + let (id, promised, pseudo, _) = self.into_parts(); + let frame = frame::PushPromise::new(id, promised, pseudo, fields); + Mock(frame) + } + + fn into_parts(self) -> (StreamId, StreamId, frame::Pseudo, HeaderMap) { + assert!(self.0.is_end_headers(), "unset eoh will be lost"); + let id = self.0.stream_id(); + let promised = self.0.promised_id(); + let parts = self.0.into_parts(); + (id, promised, parts.0, parts.1) + } +} + +impl From> for SendFrame { + fn from(src: Mock) -> Self { + Frame::PushPromise(src.0) + } +} + // GoAway helpers impl Mock { + pub fn protocol_error(self) -> Self { + Mock(frame::GoAway::new(self.0.last_stream_id(), frame::Reason::ProtocolError)) + } + pub fn flow_control(self) -> Self { Mock(frame::GoAway::new(self.0.last_stream_id(), frame::Reason::FlowControlError)) }