Merge pull request #91 from seanmonstar/config-push-promise

add Client config to disable server push
This commit is contained in:
Sean McArthur
2017-09-18 11:16:49 -07:00
committed by GitHub
14 changed files with 580 additions and 152 deletions

View File

@@ -176,6 +176,12 @@ impl Builder {
self
}
/// Enable or disable the server to send push promises.
pub fn enable_push(&mut self, enabled: bool) -> &mut Self {
self.settings.set_enable_push(enabled);
self
}
/// Bind an H2 client connection.
///
/// Returns a future which resolves to the connection value once the H2

View File

@@ -39,7 +39,7 @@ enum Continuable {
Headers(frame::Headers),
// Decode the Continuation frame but ignore it...
// Ignore(StreamId),
// PushPromise(frame::PushPromise),
PushPromise(frame::PushPromise),
}
impl<T> FramedRead<T> {
@@ -143,8 +143,38 @@ impl<T> FramedRead<T> {
res.map_err(|_| Connection(ProtocolError))?.into()
},
Kind::PushPromise => {
let res = frame::PushPromise::load(head, &bytes[frame::HEADER_LEN..]);
res.map_err(|_| Connection(ProtocolError))?.into()
// Drop the frame header
// TODO: Change to drain: carllerche/bytes#130
let _ = bytes.split_to(frame::HEADER_LEN);
// Parse the frame w/o parsing the payload
let (mut push, payload) = frame::PushPromise::load(head, bytes)
.map_err(|_| Connection(ProtocolError))?;
if push.is_end_headers() {
// Load the HPACK encoded headers & return the frame
match push.load_hpack(payload, &mut self.hpack) {
Ok(_) => {},
Err(frame::Error::MalformedMessage) => {
return Err(Stream {
id: head.stream_id(),
reason: ProtocolError,
});
},
Err(_) => return Err(Connection(ProtocolError)),
}
push.into()
} else {
// Defer loading the frame
self.partial = Some(Partial {
frame: Continuable::PushPromise(push),
buf: payload,
});
return Ok(None);
}
},
Kind::Priority => {
if head.stream_id() == 0 {
@@ -183,27 +213,23 @@ impl<T> FramedRead<T> {
return Ok(None);
}
match partial.frame {
Continuable::Headers(mut frame) => {
// The stream identifiers must match
if frame.stream_id() != head.stream_id() {
return Err(Connection(ProtocolError));
}
match frame.load_hpack(partial.buf, &mut self.hpack) {
Ok(_) => {},
Err(frame::Error::MalformedMessage) => {
return Err(Stream {
id: head.stream_id(),
reason: ProtocolError,
});
},
Err(_) => return Err(Connection(ProtocolError)),
}
frame.into()
},
// The stream identifiers must match
if partial.frame.stream_id() != head.stream_id() {
return Err(Connection(ProtocolError));
}
match partial.frame.load_hpack(partial.buf, &mut self.hpack) {
Ok(_) => {},
Err(frame::Error::MalformedMessage) => {
return Err(Stream {
id: head.stream_id(),
reason: ProtocolError,
});
},
Err(_) => return Err(Connection(ProtocolError)),
}
partial.frame.into()
},
Kind::Unknown => {
// Unknown frames are ignored
@@ -276,3 +302,30 @@ fn map_err(err: io::Error) -> RecvError {
}
err.into()
}
// ===== impl Continuable =====
impl Continuable {
fn stream_id(&self) -> frame::StreamId {
match *self {
Continuable::Headers(ref h) => h.stream_id(),
Continuable::PushPromise(ref p) => p.stream_id(),
}
}
fn load_hpack(&mut self, src: BytesMut, decoder: &mut hpack::Decoder) -> Result<(), frame::Error> {
match *self {
Continuable::Headers(ref mut h) => h.load_hpack(src, decoder),
Continuable::PushPromise(ref mut p) => p.load_hpack(src, decoder),
}
}
}
impl<T> From<Continuable> for Frame<T> {
fn from(cont: Continuable) -> Self {
match cont {
Continuable::Headers(headers) => headers.into(),
Continuable::PushPromise(push) => push.into(),
}
}
}

View File

@@ -127,8 +127,9 @@ where
}
},
Frame::PushPromise(v) => {
debug!("unimplemented PUSH_PROMISE write; frame={:?}", v);
unimplemented!();
if let Some(continuation) = v.encode(&mut self.hpack, self.buf.get_mut()) {
self.next = Some(Next::Continuation(continuation));
}
},
Frame::Settings(v) => {
v.encode(self.buf.get_mut());

View File

@@ -1,12 +1,12 @@
use super::{StreamDependency, StreamId};
use frame::{self, Error, Frame, Head, Kind};
use frame::{Error, Frame, Head, Kind};
use hpack;
use http::{uri, HeaderMap, Method, StatusCode, Uri};
use http::header::{self, HeaderName, HeaderValue};
use byteorder::{BigEndian, ByteOrder};
use bytes::{Bytes, BytesMut};
use bytes::{BufMut, Bytes, BytesMut};
use string::String;
use std::fmt;
@@ -23,12 +23,8 @@ pub struct Headers {
/// The stream dependency information, if any.
stream_dep: Option<StreamDependency>,
/// The decoded header fields
fields: HeaderMap,
/// Pseudo headers, these are broken out as they must be sent as part of the
/// headers frame.
pseudo: Pseudo,
/// The header block fragment
header_block: HeaderBlock,
/// The associated flags
flags: HeadersFlag,
@@ -37,7 +33,7 @@ pub struct Headers {
#[derive(Copy, Clone, Eq, PartialEq)]
pub struct HeadersFlag(u8);
#[derive(Debug, Eq, PartialEq)]
#[derive(Eq, PartialEq)]
pub struct PushPromise {
/// The ID of the stream with which this frame is associated.
stream_id: StreamId,
@@ -45,11 +41,14 @@ pub struct PushPromise {
/// The ID of the stream being reserved by this PushPromise.
promised_id: StreamId,
/// The header block fragment
header_block: HeaderBlock,
/// The associated flags
flags: PushPromiseFlag,
}
#[derive(Debug, Copy, Clone, Eq, PartialEq)]
#[derive(Copy, Clone, Eq, PartialEq)]
pub struct PushPromiseFlag(u8);
#[derive(Debug)]
@@ -85,6 +84,16 @@ pub struct Iter {
fields: header::IntoIter<HeaderValue>,
}
#[derive(PartialEq, Eq)]
struct HeaderBlock {
/// The decoded header fields
fields: HeaderMap,
/// Pseudo headers, these are broken out as they must be sent as part of the
/// headers frame.
pseudo: Pseudo,
}
const END_STREAM: u8 = 0x1;
const END_HEADERS: u8 = 0x4;
const PADDED: u8 = 0x8;
@@ -99,8 +108,10 @@ impl Headers {
Headers {
stream_id: stream_id,
stream_dep: None,
fields: fields,
pseudo: pseudo,
header_block: HeaderBlock {
fields: fields,
pseudo: pseudo,
},
flags: HeadersFlag::default(),
}
}
@@ -112,8 +123,10 @@ impl Headers {
Headers {
stream_id,
stream_dep: None,
fields: fields,
pseudo: Pseudo::default(),
header_block: HeaderBlock {
fields: fields,
pseudo: Pseudo::default(),
},
flags: flags,
}
}
@@ -164,8 +177,10 @@ impl Headers {
let headers = Headers {
stream_id: head.stream_id(),
stream_dep: stream_dep,
fields: HeaderMap::new(),
pseudo: Pseudo::default(),
header_block: HeaderBlock {
fields: HeaderMap::new(),
pseudo: Pseudo::default(),
},
flags: flags,
};
@@ -173,71 +188,7 @@ impl Headers {
}
pub fn load_hpack(&mut self, src: BytesMut, decoder: &mut hpack::Decoder) -> Result<(), Error> {
let mut reg = false;
let mut malformed = false;
macro_rules! set_pseudo {
($field:ident, $val:expr) => {{
if reg {
trace!("load_hpack; header malformed -- pseudo not at head of block");
malformed = true;
} else if self.pseudo.$field.is_some() {
trace!("load_hpack; header malformed -- repeated pseudo");
malformed = true;
} else {
self.pseudo.$field = Some($val);
}
}}
}
let mut src = Cursor::new(src.freeze());
// At this point, we're going to assume that the hpack encoded headers
// contain the entire payload. Later, we need to check for stream
// priority.
//
// TODO: Provide a way to abort decoding if an error is hit.
let res = decoder.decode(&mut src, |header| {
use hpack::Header::*;
match header {
Field {
name,
value,
} => {
// Connection level header fields are not supported and must
// result in a protocol error.
if name == header::CONNECTION {
trace!("load_hpack; connection level header");
malformed = true;
} else if name == header::TE && value != "trailers" {
trace!("load_hpack; TE header not set to trailers; val={:?}", value);
malformed = true;
} else {
reg = true;
self.fields.append(name, value);
}
},
Authority(v) => set_pseudo!(authority, v),
Method(v) => set_pseudo!(method, v),
Scheme(v) => set_pseudo!(scheme, v),
Path(v) => set_pseudo!(path, v),
Status(v) => set_pseudo!(status, v),
}
});
if let Err(e) = res {
trace!("hpack decoding error; err={:?}", e);
return Err(e.into());
}
if malformed {
trace!("malformed message");
return Err(Error::MalformedMessage.into());
}
Ok(())
self.header_block.load(src, decoder)
}
pub fn stream_id(&self) -> StreamId {
@@ -257,15 +208,15 @@ impl Headers {
}
pub fn into_parts(self) -> (Pseudo, HeaderMap) {
(self.pseudo, self.fields)
(self.header_block.pseudo, self.header_block.fields)
}
pub fn fields(&self) -> &HeaderMap {
&self.fields
&self.header_block.fields
}
pub fn into_fields(self) -> HeaderMap {
self.fields
self.header_block.fields
}
pub fn encode(self, encoder: &mut hpack::Encoder, dst: &mut BytesMut) -> Option<Continuation> {
@@ -278,27 +229,12 @@ impl Headers {
head.encode(0, dst);
// Encode the frame
let mut headers = Iter {
pseudo: Some(self.pseudo),
fields: self.fields.into_iter(),
};
let ret = match encoder.encode(None, &mut headers, dst) {
hpack::Encode::Full => None,
hpack::Encode::Partial(state) => Some(Continuation {
stream_id: self.stream_id,
hpack: state,
headers: headers,
}),
};
// Compute the frame length
let len = (dst.len() - pos) - frame::HEADER_LEN;
let (len, cont) = self.header_block.encode(self.stream_id, encoder, dst);
// Write the frame length
BigEndian::write_uint(&mut dst[pos..pos + 3], len as u64, 3);
BigEndian::write_uint(&mut dst[pos..pos + 3], len, 3);
ret
cont
}
fn head(&self) -> Head {
@@ -326,18 +262,66 @@ impl fmt::Debug for Headers {
// ===== impl PushPromise =====
impl PushPromise {
pub fn load(head: Head, payload: &[u8]) -> Result<Self, Error> {
pub fn new(
stream_id: StreamId,
promised_id: StreamId,
pseudo: Pseudo,
fields: HeaderMap,
) -> Self {
PushPromise {
flags: PushPromiseFlag::default(),
header_block: HeaderBlock {
fields,
pseudo,
},
promised_id,
stream_id,
}
}
/// Loads the push promise frame but doesn't actually do HPACK decoding.
///
/// HPACK decoding is done in the `load_hpack` step.
pub fn load(head: Head, mut src: BytesMut) -> Result<(Self, BytesMut), Error> {
let flags = PushPromiseFlag(head.flag());
let mut pad = 0;
// TODO: Handle padding
// Read the padding length
if flags.is_padded() {
// TODO: Ensure payload is sized correctly
pad = src[0] as usize;
let (promised_id, _) = StreamId::parse(&payload[..4]);
// Drop the padding
let _ = src.split_to(1);
}
Ok(PushPromise {
stream_id: head.stream_id(),
promised_id: promised_id,
let (promised_id, _) = StreamId::parse(&src[..4]);
// Drop promised_id bytes
let _ = src.split_to(5);
if pad > 0 {
if pad > src.len() {
return Err(Error::TooMuchPadding);
}
let len = src.len() - pad;
src.truncate(len);
}
let frame = PushPromise {
flags: flags,
})
header_block: HeaderBlock {
fields: HeaderMap::new(),
pseudo: Pseudo::default(),
},
promised_id: promised_id,
stream_id: head.stream_id(),
};
Ok((frame, src))
}
pub fn load_hpack(&mut self, src: BytesMut, decoder: &mut hpack::Decoder) -> Result<(), Error> {
self.header_block.load(src, decoder)
}
pub fn stream_id(&self) -> StreamId {
@@ -347,6 +331,45 @@ impl PushPromise {
pub fn promised_id(&self) -> StreamId {
self.promised_id
}
pub fn is_end_headers(&self) -> bool {
self.flags.is_end_headers()
}
pub fn into_parts(self) -> (Pseudo, HeaderMap) {
(self.header_block.pseudo, self.header_block.fields)
}
pub fn fields(&self) -> &HeaderMap {
&self.header_block.fields
}
pub fn into_fields(self) -> HeaderMap {
self.header_block.fields
}
pub fn encode(self, encoder: &mut hpack::Encoder, dst: &mut BytesMut) -> Option<Continuation> {
let head = self.head();
let pos = dst.len();
// At this point, we don't know how big the h2 frame will be.
// So, we write the head with length 0, then write the body, and
// finally write the length once we know the size.
head.encode(0, dst);
// Encode the frame
dst.put_u32::<BigEndian>(self.promised_id.into());
let (len, cont) = self.header_block.encode(self.stream_id, encoder, dst);
// Write the frame length
BigEndian::write_uint(&mut dst[pos..pos + 3], len + 4, 3);
cont
}
fn head(&self) -> Head {
Head::new(Kind::PushPromise, self.flags.into(), self.stream_id)
}
}
impl<T> From<PushPromise> for Frame<T> {
@@ -355,6 +378,17 @@ impl<T> From<PushPromise> for Frame<T> {
}
}
impl fmt::Debug for PushPromise {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
f.debug_struct("PushPromise")
.field("stream_id", &self.stream_id)
.field("promised_id", &self.promised_id)
.field("flags", &self.flags)
// `fields` and `pseudo` purposefully not included
.finish()
}
}
// ===== impl Pseudo =====
impl Pseudo {
@@ -509,3 +543,144 @@ impl fmt::Debug for HeadersFlag {
.finish()
}
}
// ===== impl PushPromiseFlag =====
impl PushPromiseFlag {
pub fn empty() -> PushPromiseFlag {
PushPromiseFlag(0)
}
pub fn load(bits: u8) -> PushPromiseFlag {
PushPromiseFlag(bits & ALL)
}
pub fn is_end_headers(&self) -> bool {
self.0 & END_HEADERS == END_HEADERS
}
pub fn is_padded(&self) -> bool {
self.0 & PADDED == PADDED
}
}
impl Default for PushPromiseFlag {
/// Returns a `PushPromiseFlag` value with `END_HEADERS` set.
fn default() -> Self {
PushPromiseFlag(END_HEADERS)
}
}
impl From<PushPromiseFlag> for u8 {
fn from(src: PushPromiseFlag) -> u8 {
src.0
}
}
impl fmt::Debug for PushPromiseFlag {
fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result {
fmt.debug_struct("PushPromiseFlag")
.field("end_headers", &self.is_end_headers())
.field("padded", &self.is_padded())
.finish()
}
}
// ===== HeaderBlock =====
impl HeaderBlock {
fn load(&mut self, src: BytesMut, decoder: &mut hpack::Decoder) -> Result<(), Error> {
let mut reg = false;
let mut malformed = false;
macro_rules! set_pseudo {
($field:ident, $val:expr) => {{
if reg {
trace!("load_hpack; header malformed -- pseudo not at head of block");
malformed = true;
} else if self.pseudo.$field.is_some() {
trace!("load_hpack; header malformed -- repeated pseudo");
malformed = true;
} else {
self.pseudo.$field = Some($val);
}
}}
}
let mut src = Cursor::new(src.freeze());
// At this point, we're going to assume that the hpack encoded headers
// contain the entire payload. Later, we need to check for stream
// priority.
//
// TODO: Provide a way to abort decoding if an error is hit.
let res = decoder.decode(&mut src, |header| {
use hpack::Header::*;
match header {
Field {
name,
value,
} => {
// Connection level header fields are not supported and must
// result in a protocol error.
if name == header::CONNECTION {
trace!("load_hpack; connection level header");
malformed = true;
} else if name == header::TE && value != "trailers" {
trace!("load_hpack; TE header not set to trailers; val={:?}", value);
malformed = true;
} else {
reg = true;
self.fields.append(name, value);
}
},
Authority(v) => set_pseudo!(authority, v),
Method(v) => set_pseudo!(method, v),
Scheme(v) => set_pseudo!(scheme, v),
Path(v) => set_pseudo!(path, v),
Status(v) => set_pseudo!(status, v),
}
});
if let Err(e) = res {
trace!("hpack decoding error; err={:?}", e);
return Err(e.into());
}
if malformed {
trace!("malformed message");
return Err(Error::MalformedMessage.into());
}
Ok(())
}
fn encode(
self,
stream_id: StreamId,
encoder: &mut hpack::Encoder,
dst: &mut BytesMut,
) -> (u64, Option<Continuation>) {
let pos = dst.len();
let mut headers = Iter {
pseudo: Some(self.pseudo),
fields: self.fields.into_iter(),
};
let cont = match encoder.encode(None, &mut headers, dst) {
hpack::Encode::Full => None,
hpack::Encode::Partial(state) => Some(Continuation {
stream_id: stream_id,
hpack: state,
headers: headers,
}),
};
// Compute the header block length
let len = (dst.len() - pos) as u64;
(len, cont)
}
}

View File

@@ -85,6 +85,14 @@ impl Settings {
self.max_frame_size = size;
}
pub fn is_push_enabled(&self) -> bool {
self.enable_push.unwrap_or(1) != 0
}
pub fn set_enable_push(&mut self, enable: bool) {
self.enable_push = Some(enable as u32);
}
pub fn load(head: Head, payload: &[u8]) -> Result<Settings, Error> {
use self::Setting::*;

View File

@@ -64,12 +64,13 @@ where
) -> Connection<T, P, B> {
// TODO: Actually configure
let streams = Streams::new(streams::Config {
max_remote_initiated: None,
init_remote_window_sz: DEFAULT_INITIAL_WINDOW_SIZE,
max_local_initiated: None,
init_local_window_sz: settings
local_init_window_sz: settings
.initial_window_size()
.unwrap_or(DEFAULT_INITIAL_WINDOW_SIZE),
local_max_initiated: None,
local_push_enabled: settings.is_push_enabled(),
remote_init_window_sz: DEFAULT_INITIAL_WINDOW_SIZE,
remote_max_initiated: None,
});
Connection {
state: State::Open,

View File

@@ -34,9 +34,9 @@ where
/// Create a new `Counts` using the provided configuration values.
pub fn new(config: &Config) -> Self {
Counts {
max_send_streams: config.max_local_initiated,
max_send_streams: config.local_max_initiated,
num_send_streams: 0,
max_recv_streams: config.max_remote_initiated,
max_recv_streams: config.remote_max_initiated,
num_recv_streams: 0,
blocked_open: None,
_p: PhantomData,

View File

@@ -31,15 +31,18 @@ use http::{Request, Response};
#[derive(Debug)]
pub struct Config {
/// Maximum number of remote initiated streams
pub max_remote_initiated: Option<usize>,
/// Initial window size of remote initiated streams
pub init_remote_window_sz: WindowSize,
/// Initial window size of locally initiated streams
pub local_init_window_sz: WindowSize,
/// Maximum number of locally initiated streams
pub max_local_initiated: Option<usize>,
pub local_max_initiated: Option<usize>,
/// Initial window size of locally initiated streams
pub init_local_window_sz: WindowSize,
/// If the local peer is willing to receive push promises
pub local_push_enabled: bool,
/// Initial window size of remote initiated streams
pub remote_init_window_sz: WindowSize,
/// Maximum number of remote initiated streams
pub remote_max_initiated: Option<usize>,
}

View File

@@ -49,11 +49,11 @@ where
pub fn new(config: &Config) -> Prioritize<B, P> {
let mut flow = FlowControl::new();
flow.inc_window(config.init_local_window_sz)
flow.inc_window(config.local_init_window_sz)
.ok()
.expect("invalid initial window size");
flow.assign_capacity(config.init_local_window_sz);
flow.assign_capacity(config.local_init_window_sz);
trace!("Prioritize::new; flow={:?}", flow);

View File

@@ -38,6 +38,9 @@ where
/// Refused StreamId, this represents a frame that must be sent out.
refused: Option<StreamId>,
/// If push promises are allowed to be recevied.
is_push_enabled: bool,
_p: PhantomData<B>,
}
@@ -71,7 +74,7 @@ where
flow.assign_capacity(DEFAULT_INITIAL_WINDOW_SIZE);
Recv {
init_window_sz: config.init_local_window_sz,
init_window_sz: config.local_init_window_sz,
flow: flow,
next_stream_id: next_stream_id.into(),
pending_window_updates: store::Queue::new(),
@@ -79,6 +82,7 @@ where
pending_accept: store::Queue::new(),
buffer: Buffer::new(),
refused: None,
is_push_enabled: config.local_push_enabled,
_p: PhantomData,
}
}
@@ -429,10 +433,20 @@ where
// TODO: Are there other rules?
if P::is_server() {
// The remote is a client and cannot reserve
trace!("recv_push_promise; error remote is client");
return Err(RecvError::Connection(ProtocolError));
}
if !promised_id.is_server_initiated() {
trace!(
"recv_push_promise; error promised id is invalid {:?}",
promised_id
);
return Err(RecvError::Connection(ProtocolError));
}
if !self.is_push_enabled {
trace!("recv_push_promise; error push is disabled");
return Err(RecvError::Connection(ProtocolError));
}

View File

@@ -35,7 +35,7 @@ where
Send {
next_stream_id: next_stream_id.into(),
init_window_sz: config.init_local_window_sz,
init_window_sz: config.local_init_window_sz,
prioritize: Prioritize::new(config),
}
}

View File

@@ -285,6 +285,7 @@ impl State {
..
} => true,
HalfClosedLocal(AwaitingHeaders) => true,
ReservedRemote => true,
_ => false,
}
}

109
tests/push_promise.rs Normal file
View File

@@ -0,0 +1,109 @@
extern crate h2_test_support;
use h2_test_support::prelude::*;
#[test]
fn recv_push_works() {
// tests that by default, received push promises work
// TODO: once API exists, read the pushed response
let _ = ::env_logger::init();
let (io, srv) = mock::new();
let mock = srv.assert_client_handshake()
.unwrap()
.recv_settings()
.recv_frame(
frames::headers(1)
.request("GET", "https://http2.akamai.com/")
.eos(),
)
.send_frame(
frames::push_promise(1, 2).request("GET", "https://http2.akamai.com/style.css"),
)
.send_frame(frames::headers(1).response(200).eos())
.send_frame(frames::headers(2).response(200).eos());
let h2 = Client::handshake(io).unwrap().and_then(|mut h2| {
let request = Request::builder()
.method(Method::GET)
.uri("https://http2.akamai.com/")
.body(())
.unwrap();
let req = h2.request(request, true)
.unwrap()
.unwrap()
.and_then(|resp| {
assert_eq!(resp.status(), StatusCode::OK);
Ok(())
});
h2.drive(req)
});
h2.join(mock).wait().unwrap();
}
#[test]
fn recv_push_when_push_disabled_is_conn_error() {
let _ = ::env_logger::init();
let (io, srv) = mock::new();
let mock = srv.assert_client_handshake()
.unwrap()
.ignore_settings()
.recv_frame(
frames::headers(1)
.request("GET", "https://http2.akamai.com/")
.eos(),
)
.send_frame(
frames::push_promise(1, 3).request("GET", "https://http2.akamai.com/style.css"),
)
.send_frame(frames::headers(1).response(200).eos())
.recv_frame(frames::go_away(0).protocol_error());
let h2 = Client::builder()
.enable_push(false)
.handshake::<_, Bytes>(io)
.unwrap()
.and_then(|mut h2| {
let request = Request::builder()
.method(Method::GET)
.uri("https://http2.akamai.com/")
.body(())
.unwrap();
let req = h2.request(request, true).unwrap().then(|res| {
let err = res.unwrap_err();
assert_eq!(
err.to_string(),
"protocol error: unspecific protocol error detected"
);
Ok::<(), ()>(())
});
// client should see a protocol error
let conn = h2.then(|res| {
let err = res.unwrap_err();
assert_eq!(
err.to_string(),
"protocol error: unspecific protocol error detected"
);
Ok::<(), ()>(())
});
conn.unwrap().join(req)
});
h2.join(mock).wait().unwrap();
}
#[test]
#[ignore]
fn recv_push_promise_with_unsafe_method_is_stream_error() {
// for instance, when :method = POST
}
#[test]
#[ignore]
fn recv_push_promise_with_wrong_authority_is_stream_error() {
// if server is foo.com, :authority = bar.com is stream error
}

View File

@@ -28,6 +28,18 @@ pub fn data<T, B>(id: T, buf: B) -> Mock<frame::Data>
Mock(frame::Data::new(id.into(), buf.into()))
}
pub fn push_promise<T1, T2>(id: T1, promised: T2) -> Mock<frame::PushPromise>
where T1: Into<StreamId>,
T2: Into<StreamId>,
{
Mock(frame::PushPromise::new(
id.into(),
promised.into(),
frame::Pseudo::default(),
HeaderMap::default(),
))
}
pub fn window_update<T>(id: T, sz: u32) -> frame::WindowUpdate
where T: Into<StreamId>,
{
@@ -140,9 +152,54 @@ impl From<Mock<frame::Data>> for SendFrame {
}
}
// PushPromise helpers
impl Mock<frame::PushPromise> {
pub fn request<M, U>(self, method: M, uri: U) -> Self
where M: HttpTryInto<http::Method>,
U: HttpTryInto<http::Uri>,
{
let method = method.try_into().unwrap();
let uri = uri.try_into().unwrap();
let (id, promised, _, fields) = self.into_parts();
let frame = frame::PushPromise::new(
id,
promised,
frame::Pseudo::request(method, uri),
fields
);
Mock(frame)
}
pub fn fields(self, fields: HeaderMap) -> Self {
let (id, promised, pseudo, _) = self.into_parts();
let frame = frame::PushPromise::new(id, promised, pseudo, fields);
Mock(frame)
}
fn into_parts(self) -> (StreamId, StreamId, frame::Pseudo, HeaderMap) {
assert!(self.0.is_end_headers(), "unset eoh will be lost");
let id = self.0.stream_id();
let promised = self.0.promised_id();
let parts = self.0.into_parts();
(id, promised, parts.0, parts.1)
}
}
impl From<Mock<frame::PushPromise>> for SendFrame {
fn from(src: Mock<frame::PushPromise>) -> Self {
Frame::PushPromise(src.0)
}
}
// GoAway helpers
impl Mock<frame::GoAway> {
pub fn protocol_error(self) -> Self {
Mock(frame::GoAway::new(self.0.last_stream_id(), frame::Reason::ProtocolError))
}
pub fn flow_control(self) -> Self {
Mock(frame::GoAway::new(self.0.last_stream_id(), frame::Reason::FlowControlError))
}