Implement the extended CONNECT protocol from RFC 8441 (#565)

This commit is contained in:
Anthony Ramine
2021-11-24 10:05:10 +01:00
committed by GitHub
parent dbaa3a4285
commit 87969c1f29
22 changed files with 694 additions and 120 deletions

View File

@@ -136,6 +136,7 @@
//! [`Error`]: ../struct.Error.html //! [`Error`]: ../struct.Error.html
use crate::codec::{Codec, SendError, UserError}; use crate::codec::{Codec, SendError, UserError};
use crate::ext::Protocol;
use crate::frame::{Headers, Pseudo, Reason, Settings, StreamId}; use crate::frame::{Headers, Pseudo, Reason, Settings, StreamId};
use crate::proto::{self, Error}; use crate::proto::{self, Error};
use crate::{FlowControl, PingPong, RecvStream, SendStream}; use crate::{FlowControl, PingPong, RecvStream, SendStream};
@@ -517,6 +518,19 @@ where
(response, stream) (response, stream)
}) })
} }
/// Returns whether the [extended CONNECT protocol][1] is enabled or not.
///
/// This setting is configured by the server peer by sending the
/// [`SETTINGS_ENABLE_CONNECT_PROTOCOL` parameter][2] in a `SETTINGS` frame.
/// This method returns the currently acknowledged value recieved from the
/// remote.
///
/// [1]: https://datatracker.ietf.org/doc/html/rfc8441#section-4
/// [2]: https://datatracker.ietf.org/doc/html/rfc8441#section-3
pub fn is_extended_connect_protocol_enabled(&self) -> bool {
self.inner.is_extended_connect_protocol_enabled()
}
} }
impl<B> fmt::Debug for SendRequest<B> impl<B> fmt::Debug for SendRequest<B>
@@ -1246,11 +1260,10 @@ where
/// This method returns the currently acknowledged value recieved from the /// This method returns the currently acknowledged value recieved from the
/// remote. /// remote.
/// ///
/// [settings]: https://tools.ietf.org/html/rfc7540#section-5.1.2 /// [1]: https://tools.ietf.org/html/rfc7540#section-5.1.2
pub fn max_concurrent_send_streams(&self) -> usize { pub fn max_concurrent_send_streams(&self) -> usize {
self.inner.max_send_streams() self.inner.max_send_streams()
} }
/// Returns the maximum number of concurrent streams that may be initiated /// Returns the maximum number of concurrent streams that may be initiated
/// by the server on this connection. /// by the server on this connection.
/// ///
@@ -1416,6 +1429,7 @@ impl Peer {
pub fn convert_send_message( pub fn convert_send_message(
id: StreamId, id: StreamId,
request: Request<()>, request: Request<()>,
protocol: Option<Protocol>,
end_of_stream: bool, end_of_stream: bool,
) -> Result<Headers, SendError> { ) -> Result<Headers, SendError> {
use http::request::Parts; use http::request::Parts;
@@ -1435,7 +1449,7 @@ impl Peer {
// Build the set pseudo header set. All requests will include `method` // Build the set pseudo header set. All requests will include `method`
// and `path`. // and `path`.
let mut pseudo = Pseudo::request(method, uri); let mut pseudo = Pseudo::request(method, uri, protocol);
if pseudo.scheme.is_none() { if pseudo.scheme.is_none() {
// If the scheme is not set, then there are a two options. // If the scheme is not set, then there are a two options.

55
src/ext.rs Normal file
View File

@@ -0,0 +1,55 @@
//! Extensions specific to the HTTP/2 protocol.
use crate::hpack::BytesStr;
use bytes::Bytes;
use std::fmt;
/// Represents the `:protocol` pseudo-header used by
/// the [Extended CONNECT Protocol].
///
/// [Extended CONNECT Protocol]: https://datatracker.ietf.org/doc/html/rfc8441#section-4
#[derive(Clone, Eq, PartialEq)]
pub struct Protocol {
value: BytesStr,
}
impl Protocol {
/// Converts a static string to a protocol name.
pub const fn from_static(value: &'static str) -> Self {
Self {
value: BytesStr::from_static(value),
}
}
/// Returns a str representation of the header.
pub fn as_str(&self) -> &str {
self.value.as_str()
}
pub(crate) fn try_from(bytes: Bytes) -> Result<Self, std::str::Utf8Error> {
Ok(Self {
value: BytesStr::try_from(bytes)?,
})
}
}
impl<'a> From<&'a str> for Protocol {
fn from(value: &'a str) -> Self {
Self {
value: BytesStr::from(value),
}
}
}
impl AsRef<[u8]> for Protocol {
fn as_ref(&self) -> &[u8] {
self.value.as_ref()
}
}
impl fmt::Debug for Protocol {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
self.value.fmt(f)
}
}

View File

@@ -1,4 +1,5 @@
use super::{util, StreamDependency, StreamId}; use super::{util, StreamDependency, StreamId};
use crate::ext::Protocol;
use crate::frame::{Error, Frame, Head, Kind}; use crate::frame::{Error, Frame, Head, Kind};
use crate::hpack::{self, BytesStr}; use crate::hpack::{self, BytesStr};
@@ -66,6 +67,7 @@ pub struct Pseudo {
pub scheme: Option<BytesStr>, pub scheme: Option<BytesStr>,
pub authority: Option<BytesStr>, pub authority: Option<BytesStr>,
pub path: Option<BytesStr>, pub path: Option<BytesStr>,
pub protocol: Option<Protocol>,
// Response // Response
pub status: Option<StatusCode>, pub status: Option<StatusCode>,
@@ -292,6 +294,10 @@ impl fmt::Debug for Headers {
.field("stream_id", &self.stream_id) .field("stream_id", &self.stream_id)
.field("flags", &self.flags); .field("flags", &self.flags);
if let Some(ref protocol) = self.header_block.pseudo.protocol {
builder.field("protocol", protocol);
}
if let Some(ref dep) = self.stream_dep { if let Some(ref dep) = self.stream_dep {
builder.field("stream_dep", dep); builder.field("stream_dep", dep);
} }
@@ -529,7 +535,7 @@ impl Continuation {
// ===== impl Pseudo ===== // ===== impl Pseudo =====
impl Pseudo { impl Pseudo {
pub fn request(method: Method, uri: Uri) -> Self { pub fn request(method: Method, uri: Uri, protocol: Option<Protocol>) -> Self {
let parts = uri::Parts::from(uri); let parts = uri::Parts::from(uri);
let mut path = parts let mut path = parts
@@ -550,6 +556,7 @@ impl Pseudo {
scheme: None, scheme: None,
authority: None, authority: None,
path: Some(path).filter(|p| !p.is_empty()), path: Some(path).filter(|p| !p.is_empty()),
protocol,
status: None, status: None,
}; };
@@ -575,6 +582,7 @@ impl Pseudo {
scheme: None, scheme: None,
authority: None, authority: None,
path: None, path: None,
protocol: None,
status: Some(status), status: Some(status),
} }
} }
@@ -593,6 +601,11 @@ impl Pseudo {
self.scheme = Some(bytes_str); self.scheme = Some(bytes_str);
} }
#[cfg(feature = "unstable")]
pub fn set_protocol(&mut self, protocol: Protocol) {
self.protocol = Some(protocol);
}
pub fn set_authority(&mut self, authority: BytesStr) { pub fn set_authority(&mut self, authority: BytesStr) {
self.authority = Some(authority); self.authority = Some(authority);
} }
@@ -681,6 +694,10 @@ impl Iterator for Iter {
return Some(Path(path)); return Some(Path(path));
} }
if let Some(protocol) = pseudo.protocol.take() {
return Some(Protocol(protocol));
}
if let Some(status) = pseudo.status.take() { if let Some(status) = pseudo.status.take() {
return Some(Status(status)); return Some(Status(status));
} }
@@ -879,6 +896,7 @@ impl HeaderBlock {
Method(v) => set_pseudo!(method, v), Method(v) => set_pseudo!(method, v),
Scheme(v) => set_pseudo!(scheme, v), Scheme(v) => set_pseudo!(scheme, v),
Path(v) => set_pseudo!(path, v), Path(v) => set_pseudo!(path, v),
Protocol(v) => set_pseudo!(protocol, v),
Status(v) => set_pseudo!(status, v), Status(v) => set_pseudo!(status, v),
} }
}); });

View File

@@ -13,6 +13,7 @@ pub struct Settings {
initial_window_size: Option<u32>, initial_window_size: Option<u32>,
max_frame_size: Option<u32>, max_frame_size: Option<u32>,
max_header_list_size: Option<u32>, max_header_list_size: Option<u32>,
enable_connect_protocol: Option<u32>,
} }
/// An enum that lists all valid settings that can be sent in a SETTINGS /// An enum that lists all valid settings that can be sent in a SETTINGS
@@ -27,6 +28,7 @@ pub enum Setting {
InitialWindowSize(u32), InitialWindowSize(u32),
MaxFrameSize(u32), MaxFrameSize(u32),
MaxHeaderListSize(u32), MaxHeaderListSize(u32),
EnableConnectProtocol(u32),
} }
#[derive(Copy, Clone, Eq, PartialEq, Default)] #[derive(Copy, Clone, Eq, PartialEq, Default)]
@@ -107,6 +109,14 @@ impl Settings {
self.enable_push = Some(enable as u32); self.enable_push = Some(enable as u32);
} }
pub fn is_extended_connect_protocol_enabled(&self) -> Option<bool> {
self.enable_connect_protocol.map(|val| val != 0)
}
pub fn set_enable_connect_protocol(&mut self, val: Option<u32>) {
self.enable_connect_protocol = val;
}
pub fn header_table_size(&self) -> Option<u32> { pub fn header_table_size(&self) -> Option<u32> {
self.header_table_size self.header_table_size
} }
@@ -181,6 +191,14 @@ impl Settings {
Some(MaxHeaderListSize(val)) => { Some(MaxHeaderListSize(val)) => {
settings.max_header_list_size = Some(val); settings.max_header_list_size = Some(val);
} }
Some(EnableConnectProtocol(val)) => match val {
0 | 1 => {
settings.enable_connect_protocol = Some(val);
}
_ => {
return Err(Error::InvalidSettingValue);
}
},
None => {} None => {}
} }
} }
@@ -236,6 +254,10 @@ impl Settings {
if let Some(v) = self.max_header_list_size { if let Some(v) = self.max_header_list_size {
f(MaxHeaderListSize(v)); f(MaxHeaderListSize(v));
} }
if let Some(v) = self.enable_connect_protocol {
f(EnableConnectProtocol(v));
}
} }
} }
@@ -269,6 +291,9 @@ impl fmt::Debug for Settings {
Setting::MaxHeaderListSize(v) => { Setting::MaxHeaderListSize(v) => {
builder.field("max_header_list_size", &v); builder.field("max_header_list_size", &v);
} }
Setting::EnableConnectProtocol(v) => {
builder.field("enable_connect_protocol", &v);
}
}); });
builder.finish() builder.finish()
@@ -291,6 +316,7 @@ impl Setting {
4 => Some(InitialWindowSize(val)), 4 => Some(InitialWindowSize(val)),
5 => Some(MaxFrameSize(val)), 5 => Some(MaxFrameSize(val)),
6 => Some(MaxHeaderListSize(val)), 6 => Some(MaxHeaderListSize(val)),
8 => Some(EnableConnectProtocol(val)),
_ => None, _ => None,
} }
} }
@@ -322,6 +348,7 @@ impl Setting {
InitialWindowSize(v) => (4, v), InitialWindowSize(v) => (4, v),
MaxFrameSize(v) => (5, v), MaxFrameSize(v) => (5, v),
MaxHeaderListSize(v) => (6, v), MaxHeaderListSize(v) => (6, v),
EnableConnectProtocol(v) => (8, v),
}; };
dst.put_u16(kind); dst.put_u16(kind);

View File

@@ -1,4 +1,5 @@
use super::{DecoderError, NeedMore}; use super::{DecoderError, NeedMore};
use crate::ext::Protocol;
use bytes::Bytes; use bytes::Bytes;
use http::header::{HeaderName, HeaderValue}; use http::header::{HeaderName, HeaderValue};
@@ -14,6 +15,7 @@ pub enum Header<T = HeaderName> {
Method(Method), Method(Method),
Scheme(BytesStr), Scheme(BytesStr),
Path(BytesStr), Path(BytesStr),
Protocol(Protocol),
Status(StatusCode), Status(StatusCode),
} }
@@ -25,6 +27,7 @@ pub enum Name<'a> {
Method, Method,
Scheme, Scheme,
Path, Path,
Protocol,
Status, Status,
} }
@@ -51,6 +54,7 @@ impl Header<Option<HeaderName>> {
Method(v) => Method(v), Method(v) => Method(v),
Scheme(v) => Scheme(v), Scheme(v) => Scheme(v),
Path(v) => Path(v), Path(v) => Path(v),
Protocol(v) => Protocol(v),
Status(v) => Status(v), Status(v) => Status(v),
}) })
} }
@@ -79,6 +83,10 @@ impl Header {
let value = BytesStr::try_from(value)?; let value = BytesStr::try_from(value)?;
Ok(Header::Path(value)) Ok(Header::Path(value))
} }
b"protocol" => {
let value = Protocol::try_from(value)?;
Ok(Header::Protocol(value))
}
b"status" => { b"status" => {
let status = StatusCode::from_bytes(&value)?; let status = StatusCode::from_bytes(&value)?;
Ok(Header::Status(status)) Ok(Header::Status(status))
@@ -104,6 +112,7 @@ impl Header {
Header::Method(ref v) => 32 + 7 + v.as_ref().len(), Header::Method(ref v) => 32 + 7 + v.as_ref().len(),
Header::Scheme(ref v) => 32 + 7 + v.len(), Header::Scheme(ref v) => 32 + 7 + v.len(),
Header::Path(ref v) => 32 + 5 + v.len(), Header::Path(ref v) => 32 + 5 + v.len(),
Header::Protocol(ref v) => 32 + 9 + v.as_str().len(),
Header::Status(_) => 32 + 7 + 3, Header::Status(_) => 32 + 7 + 3,
} }
} }
@@ -116,6 +125,7 @@ impl Header {
Header::Method(..) => Name::Method, Header::Method(..) => Name::Method,
Header::Scheme(..) => Name::Scheme, Header::Scheme(..) => Name::Scheme,
Header::Path(..) => Name::Path, Header::Path(..) => Name::Path,
Header::Protocol(..) => Name::Protocol,
Header::Status(..) => Name::Status, Header::Status(..) => Name::Status,
} }
} }
@@ -127,6 +137,7 @@ impl Header {
Header::Method(ref v) => v.as_ref().as_ref(), Header::Method(ref v) => v.as_ref().as_ref(),
Header::Scheme(ref v) => v.as_ref(), Header::Scheme(ref v) => v.as_ref(),
Header::Path(ref v) => v.as_ref(), Header::Path(ref v) => v.as_ref(),
Header::Protocol(ref v) => v.as_ref(),
Header::Status(ref v) => v.as_str().as_ref(), Header::Status(ref v) => v.as_str().as_ref(),
} }
} }
@@ -156,6 +167,10 @@ impl Header {
Header::Path(ref b) => a == b, Header::Path(ref b) => a == b,
_ => false, _ => false,
}, },
Header::Protocol(ref a) => match *other {
Header::Protocol(ref b) => a == b,
_ => false,
},
Header::Status(ref a) => match *other { Header::Status(ref a) => match *other {
Header::Status(ref b) => a == b, Header::Status(ref b) => a == b,
_ => false, _ => false,
@@ -205,6 +220,7 @@ impl From<Header> for Header<Option<HeaderName>> {
Header::Method(v) => Header::Method(v), Header::Method(v) => Header::Method(v),
Header::Scheme(v) => Header::Scheme(v), Header::Scheme(v) => Header::Scheme(v),
Header::Path(v) => Header::Path(v), Header::Path(v) => Header::Path(v),
Header::Protocol(v) => Header::Protocol(v),
Header::Status(v) => Header::Status(v), Header::Status(v) => Header::Status(v),
} }
} }
@@ -221,6 +237,7 @@ impl<'a> Name<'a> {
Name::Method => Ok(Header::Method(Method::from_bytes(&*value)?)), Name::Method => Ok(Header::Method(Method::from_bytes(&*value)?)),
Name::Scheme => Ok(Header::Scheme(BytesStr::try_from(value)?)), Name::Scheme => Ok(Header::Scheme(BytesStr::try_from(value)?)),
Name::Path => Ok(Header::Path(BytesStr::try_from(value)?)), Name::Path => Ok(Header::Path(BytesStr::try_from(value)?)),
Name::Protocol => Ok(Header::Protocol(Protocol::try_from(value)?)),
Name::Status => { Name::Status => {
match StatusCode::from_bytes(&value) { match StatusCode::from_bytes(&value) {
Ok(status) => Ok(Header::Status(status)), Ok(status) => Ok(Header::Status(status)),
@@ -238,6 +255,7 @@ impl<'a> Name<'a> {
Name::Method => b":method", Name::Method => b":method",
Name::Scheme => b":scheme", Name::Scheme => b":scheme",
Name::Path => b":path", Name::Path => b":path",
Name::Protocol => b":protocol",
Name::Status => b":status", Name::Status => b":status",
} }
} }

View File

@@ -751,6 +751,7 @@ fn index_static(header: &Header) -> Option<(usize, bool)> {
"/index.html" => Some((5, true)), "/index.html" => Some((5, true)),
_ => Some((4, false)), _ => Some((4, false)),
}, },
Header::Protocol(..) => None,
Header::Status(ref v) => match u16::from(*v) { Header::Status(ref v) => match u16::from(*v) {
200 => Some((8, true)), 200 => Some((8, true)),
204 => Some((9, true)), 204 => Some((9, true)),

View File

@@ -134,6 +134,7 @@ fn key_str(e: &Header) -> &str {
Header::Method(..) => ":method", Header::Method(..) => ":method",
Header::Scheme(..) => ":scheme", Header::Scheme(..) => ":scheme",
Header::Path(..) => ":path", Header::Path(..) => ":path",
Header::Protocol(..) => ":protocol",
Header::Status(..) => ":status", Header::Status(..) => ":status",
} }
} }
@@ -145,6 +146,7 @@ fn value_str(e: &Header) -> &str {
Header::Method(ref m) => m.as_str(), Header::Method(ref m) => m.as_str(),
Header::Scheme(ref v) => &**v, Header::Scheme(ref v) => &**v,
Header::Path(ref v) => &**v, Header::Path(ref v) => &**v,
Header::Protocol(ref v) => v.as_str(),
Header::Status(ref v) => v.as_str(), Header::Status(ref v) => v.as_str(),
} }
} }

View File

@@ -120,6 +120,7 @@ mod frame;
pub mod frame; pub mod frame;
pub mod client; pub mod client;
pub mod ext;
pub mod server; pub mod server;
mod share; mod share;

View File

@@ -110,6 +110,10 @@ where
initial_max_send_streams: config.initial_max_send_streams, initial_max_send_streams: config.initial_max_send_streams,
local_next_stream_id: config.next_stream_id, local_next_stream_id: config.next_stream_id,
local_push_enabled: config.settings.is_push_enabled().unwrap_or(true), local_push_enabled: config.settings.is_push_enabled().unwrap_or(true),
extended_connect_protocol_enabled: config
.settings
.is_extended_connect_protocol_enabled()
.unwrap_or(false),
local_reset_duration: config.reset_stream_duration, local_reset_duration: config.reset_stream_duration,
local_reset_max: config.reset_stream_max, local_reset_max: config.reset_stream_max,
remote_init_window_sz: DEFAULT_INITIAL_WINDOW_SIZE, remote_init_window_sz: DEFAULT_INITIAL_WINDOW_SIZE,
@@ -147,6 +151,13 @@ where
self.inner.settings.send_settings(settings) self.inner.settings.send_settings(settings)
} }
/// Send a new SETTINGS frame with extended CONNECT protocol enabled.
pub(crate) fn set_enable_connect_protocol(&mut self) -> Result<(), UserError> {
let mut settings = frame::Settings::default();
settings.set_enable_connect_protocol(Some(1));
self.inner.settings.send_settings(settings)
}
/// Returns the maximum number of concurrent streams that may be initiated /// Returns the maximum number of concurrent streams that may be initiated
/// by this peer. /// by this peer.
pub(crate) fn max_send_streams(&self) -> usize { pub(crate) fn max_send_streams(&self) -> usize {

View File

@@ -117,6 +117,8 @@ impl Settings {
tracing::trace!("ACK sent; applying settings"); tracing::trace!("ACK sent; applying settings");
streams.apply_remote_settings(settings)?;
if let Some(val) = settings.header_table_size() { if let Some(val) = settings.header_table_size() {
dst.set_send_header_table_size(val as usize); dst.set_send_header_table_size(val as usize);
} }
@@ -124,8 +126,6 @@ impl Settings {
if let Some(val) = settings.max_frame_size() { if let Some(val) = settings.max_frame_size() {
dst.set_max_send_frame_size(val as usize); dst.set_max_send_frame_size(val as usize);
} }
streams.apply_remote_settings(settings)?;
} }
self.remote = None; self.remote = None;

View File

@@ -47,6 +47,9 @@ pub struct Config {
/// If the local peer is willing to receive push promises /// If the local peer is willing to receive push promises
pub local_push_enabled: bool, pub local_push_enabled: bool,
/// If extended connect protocol is enabled.
pub extended_connect_protocol_enabled: bool,
/// How long a locally reset stream should ignore frames /// How long a locally reset stream should ignore frames
pub local_reset_duration: Duration, pub local_reset_duration: Duration,

View File

@@ -56,6 +56,9 @@ pub(super) struct Recv {
/// If push promises are allowed to be received. /// If push promises are allowed to be received.
is_push_enabled: bool, is_push_enabled: bool,
/// If extended connect protocol is enabled.
is_extended_connect_protocol_enabled: bool,
} }
#[derive(Debug)] #[derive(Debug)]
@@ -103,6 +106,7 @@ impl Recv {
buffer: Buffer::new(), buffer: Buffer::new(),
refused: None, refused: None,
is_push_enabled: config.local_push_enabled, is_push_enabled: config.local_push_enabled,
is_extended_connect_protocol_enabled: config.extended_connect_protocol_enabled,
} }
} }
@@ -216,6 +220,14 @@ impl Recv {
let stream_id = frame.stream_id(); let stream_id = frame.stream_id();
let (pseudo, fields) = frame.into_parts(); let (pseudo, fields) = frame.into_parts();
if pseudo.protocol.is_some() {
if counts.peer().is_server() && !self.is_extended_connect_protocol_enabled {
proto_err!(stream: "cannot use :protocol if extended connect protocol is disabled; stream={:?}", stream.id);
return Err(Error::library_reset(stream.id, Reason::PROTOCOL_ERROR).into());
}
}
if !pseudo.is_informational() { if !pseudo.is_informational() {
let message = counts let message = counts
.peer() .peer()
@@ -449,60 +461,58 @@ impl Recv {
settings: &frame::Settings, settings: &frame::Settings,
store: &mut Store, store: &mut Store,
) -> Result<(), proto::Error> { ) -> Result<(), proto::Error> {
let target = if let Some(val) = settings.initial_window_size() { if let Some(val) = settings.is_extended_connect_protocol_enabled() {
val self.is_extended_connect_protocol_enabled = val;
} else {
return Ok(());
};
let old_sz = self.init_window_sz;
self.init_window_sz = target;
tracing::trace!("update_initial_window_size; new={}; old={}", target, old_sz,);
// Per RFC 7540 §6.9.2:
//
// In addition to changing the flow-control window for streams that are
// not yet active, a SETTINGS frame can alter the initial flow-control
// window size for streams with active flow-control windows (that is,
// streams in the "open" or "half-closed (remote)" state). When the
// value of SETTINGS_INITIAL_WINDOW_SIZE changes, a receiver MUST adjust
// the size of all stream flow-control windows that it maintains by the
// difference between the new value and the old value.
//
// A change to `SETTINGS_INITIAL_WINDOW_SIZE` can cause the available
// space in a flow-control window to become negative. A sender MUST
// track the negative flow-control window and MUST NOT send new
// flow-controlled frames until it receives WINDOW_UPDATE frames that
// cause the flow-control window to become positive.
if target < old_sz {
// We must decrease the (local) window on every open stream.
let dec = old_sz - target;
tracing::trace!("decrementing all windows; dec={}", dec);
store.for_each(|mut stream| {
stream.recv_flow.dec_recv_window(dec);
Ok(())
})
} else if target > old_sz {
// We must increase the (local) window on every open stream.
let inc = target - old_sz;
tracing::trace!("incrementing all windows; inc={}", inc);
store.for_each(|mut stream| {
// XXX: Shouldn't the peer have already noticed our
// overflow and sent us a GOAWAY?
stream
.recv_flow
.inc_window(inc)
.map_err(proto::Error::library_go_away)?;
stream.recv_flow.assign_capacity(inc);
Ok(())
})
} else {
// size is the same... so do nothing
Ok(())
} }
if let Some(target) = settings.initial_window_size() {
let old_sz = self.init_window_sz;
self.init_window_sz = target;
tracing::trace!("update_initial_window_size; new={}; old={}", target, old_sz,);
// Per RFC 7540 §6.9.2:
//
// In addition to changing the flow-control window for streams that are
// not yet active, a SETTINGS frame can alter the initial flow-control
// window size for streams with active flow-control windows (that is,
// streams in the "open" or "half-closed (remote)" state). When the
// value of SETTINGS_INITIAL_WINDOW_SIZE changes, a receiver MUST adjust
// the size of all stream flow-control windows that it maintains by the
// difference between the new value and the old value.
//
// A change to `SETTINGS_INITIAL_WINDOW_SIZE` can cause the available
// space in a flow-control window to become negative. A sender MUST
// track the negative flow-control window and MUST NOT send new
// flow-controlled frames until it receives WINDOW_UPDATE frames that
// cause the flow-control window to become positive.
if target < old_sz {
// We must decrease the (local) window on every open stream.
let dec = old_sz - target;
tracing::trace!("decrementing all windows; dec={}", dec);
store.for_each(|mut stream| {
stream.recv_flow.dec_recv_window(dec);
})
} else if target > old_sz {
// We must increase the (local) window on every open stream.
let inc = target - old_sz;
tracing::trace!("incrementing all windows; inc={}", inc);
store.try_for_each(|mut stream| {
// XXX: Shouldn't the peer have already noticed our
// overflow and sent us a GOAWAY?
stream
.recv_flow
.inc_window(inc)
.map_err(proto::Error::library_go_away)?;
stream.recv_flow.assign_capacity(inc);
Ok::<_, proto::Error>(())
})?;
}
}
Ok(())
} }
pub fn is_end_stream(&self, stream: &store::Ptr) -> bool { pub fn is_end_stream(&self, stream: &store::Ptr) -> bool {

View File

@@ -35,6 +35,9 @@ pub(super) struct Send {
prioritize: Prioritize, prioritize: Prioritize,
is_push_enabled: bool, is_push_enabled: bool,
/// If extended connect protocol is enabled.
is_extended_connect_protocol_enabled: bool,
} }
/// A value to detect which public API has called `poll_reset`. /// A value to detect which public API has called `poll_reset`.
@@ -53,6 +56,7 @@ impl Send {
next_stream_id: Ok(config.local_next_stream_id), next_stream_id: Ok(config.local_next_stream_id),
prioritize: Prioritize::new(config), prioritize: Prioritize::new(config),
is_push_enabled: true, is_push_enabled: true,
is_extended_connect_protocol_enabled: false,
} }
} }
@@ -429,6 +433,10 @@ impl Send {
counts: &mut Counts, counts: &mut Counts,
task: &mut Option<Waker>, task: &mut Option<Waker>,
) -> Result<(), Error> { ) -> Result<(), Error> {
if let Some(val) = settings.is_extended_connect_protocol_enabled() {
self.is_extended_connect_protocol_enabled = val;
}
// Applies an update to the remote endpoint's initial window size. // Applies an update to the remote endpoint's initial window size.
// //
// Per RFC 7540 §6.9.2: // Per RFC 7540 §6.9.2:
@@ -490,16 +498,14 @@ impl Send {
// TODO: Should this notify the producer when the capacity // TODO: Should this notify the producer when the capacity
// of a stream is reduced? Maybe it should if the capacity // of a stream is reduced? Maybe it should if the capacity
// is reduced to zero, allowing the producer to stop work. // is reduced to zero, allowing the producer to stop work.
});
Ok::<_, Error>(())
})?;
self.prioritize self.prioritize
.assign_connection_capacity(total_reclaimed, store, counts); .assign_connection_capacity(total_reclaimed, store, counts);
} else if val > old_val { } else if val > old_val {
let inc = val - old_val; let inc = val - old_val;
store.for_each(|mut stream| { store.try_for_each(|mut stream| {
self.recv_stream_window_update(inc, buffer, &mut stream, counts, task) self.recv_stream_window_update(inc, buffer, &mut stream, counts, task)
.map_err(Error::library_go_away) .map_err(Error::library_go_away)
})?; })?;
@@ -554,4 +560,8 @@ impl Send {
} }
} }
} }
pub(crate) fn is_extended_connect_protocol_enabled(&self) -> bool {
self.is_extended_connect_protocol_enabled
}
} }

View File

@@ -4,6 +4,7 @@ use slab;
use indexmap::{self, IndexMap}; use indexmap::{self, IndexMap};
use std::convert::Infallible;
use std::fmt; use std::fmt;
use std::marker::PhantomData; use std::marker::PhantomData;
use std::ops; use std::ops;
@@ -128,7 +129,20 @@ impl Store {
} }
} }
pub fn for_each<F, E>(&mut self, mut f: F) -> Result<(), E> pub(crate) fn for_each<F>(&mut self, mut f: F)
where
F: FnMut(Ptr),
{
match self.try_for_each(|ptr| {
f(ptr);
Ok::<_, Infallible>(())
}) {
Ok(()) => (),
Err(infallible) => match infallible {},
}
}
pub fn try_for_each<F, E>(&mut self, mut f: F) -> Result<(), E>
where where
F: FnMut(Ptr) -> Result<(), E>, F: FnMut(Ptr) -> Result<(), E>,
{ {

View File

@@ -2,6 +2,7 @@ use super::recv::RecvHeaderBlockError;
use super::store::{self, Entry, Resolve, Store}; use super::store::{self, Entry, Resolve, Store};
use super::{Buffer, Config, Counts, Prioritized, Recv, Send, Stream, StreamId}; use super::{Buffer, Config, Counts, Prioritized, Recv, Send, Stream, StreamId};
use crate::codec::{Codec, SendError, UserError}; use crate::codec::{Codec, SendError, UserError};
use crate::ext::Protocol;
use crate::frame::{self, Frame, Reason}; use crate::frame::{self, Frame, Reason};
use crate::proto::{peer, Error, Initiator, Open, Peer, WindowSize}; use crate::proto::{peer, Error, Initiator, Open, Peer, WindowSize};
use crate::{client, proto, server}; use crate::{client, proto, server};
@@ -214,6 +215,8 @@ where
use super::stream::ContentLength; use super::stream::ContentLength;
use http::Method; use http::Method;
let protocol = request.extensions_mut().remove::<Protocol>();
// Clear before taking lock, incase extensions contain a StreamRef. // Clear before taking lock, incase extensions contain a StreamRef.
request.extensions_mut().clear(); request.extensions_mut().clear();
@@ -261,7 +264,8 @@ where
} }
// Convert the message // Convert the message
let headers = client::Peer::convert_send_message(stream_id, request, end_of_stream)?; let headers =
client::Peer::convert_send_message(stream_id, request, protocol, end_of_stream)?;
let mut stream = me.store.insert(stream.id, stream); let mut stream = me.store.insert(stream.id, stream);
@@ -294,6 +298,15 @@ where
send_buffer: self.send_buffer.clone(), send_buffer: self.send_buffer.clone(),
}) })
} }
pub(crate) fn is_extended_connect_protocol_enabled(&self) -> bool {
self.inner
.lock()
.unwrap()
.actions
.send
.is_extended_connect_protocol_enabled()
}
} }
impl<B> DynStreams<'_, B> { impl<B> DynStreams<'_, B> {
@@ -643,15 +656,12 @@ impl Inner {
let last_processed_id = actions.recv.last_processed_id(); let last_processed_id = actions.recv.last_processed_id();
self.store self.store.for_each(|stream| {
.for_each(|stream| { counts.transition(stream, |counts, stream| {
counts.transition(stream, |counts, stream| { actions.recv.handle_error(&err, &mut *stream);
actions.recv.handle_error(&err, &mut *stream); actions.send.handle_error(send_buffer, stream, counts);
actions.send.handle_error(send_buffer, stream, counts);
Ok::<_, ()>(())
})
}) })
.unwrap(); });
actions.conn_error = Some(err); actions.conn_error = Some(err);
@@ -674,19 +684,14 @@ impl Inner {
let err = Error::remote_go_away(frame.debug_data().clone(), frame.reason()); let err = Error::remote_go_away(frame.debug_data().clone(), frame.reason());
self.store self.store.for_each(|stream| {
.for_each(|stream| { if stream.id > last_stream_id {
if stream.id > last_stream_id { counts.transition(stream, |counts, stream| {
counts.transition(stream, |counts, stream| { actions.recv.handle_error(&err, &mut *stream);
actions.recv.handle_error(&err, &mut *stream); actions.send.handle_error(send_buffer, stream, counts);
actions.send.handle_error(send_buffer, stream, counts); })
Ok::<_, ()>(()) }
}) });
} else {
Ok::<_, ()>(())
}
})
.unwrap();
actions.conn_error = Some(err); actions.conn_error = Some(err);
@@ -807,18 +812,15 @@ impl Inner {
tracing::trace!("Streams::recv_eof"); tracing::trace!("Streams::recv_eof");
self.store self.store.for_each(|stream| {
.for_each(|stream| { counts.transition(stream, |counts, stream| {
counts.transition(stream, |counts, stream| { actions.recv.recv_eof(stream);
actions.recv.recv_eof(stream);
// This handles resetting send state associated with the // This handles resetting send state associated with the
// stream // stream
actions.send.handle_error(send_buffer, stream, counts); actions.send.handle_error(send_buffer, stream, counts);
Ok::<_, ()>(())
})
}) })
.expect("recv_eof"); });
actions.clear_queues(clear_pending_accept, &mut self.store, counts); actions.clear_queues(clear_pending_accept, &mut self.store, counts);
Ok(()) Ok(())

View File

@@ -470,6 +470,19 @@ where
Ok(()) Ok(())
} }
/// Enables the [extended CONNECT protocol].
///
/// [extended CONNECT protocol]: https://datatracker.ietf.org/doc/html/rfc8441#section-4
///
/// # Errors
///
/// Returns an error if a previous call is still pending acknowledgement
/// from the remote endpoint.
pub fn enable_connect_protocol(&mut self) -> Result<(), crate::Error> {
self.connection.set_enable_connect_protocol()?;
Ok(())
}
/// Returns `Ready` when the underlying connection has closed. /// Returns `Ready` when the underlying connection has closed.
/// ///
/// If any new inbound streams are received during a call to `poll_closed`, /// If any new inbound streams are received during a call to `poll_closed`,
@@ -904,6 +917,14 @@ impl Builder {
self self
} }
/// Enables the [extended CONNECT protocol].
///
/// [extended CONNECT protocol]: https://datatracker.ietf.org/doc/html/rfc8441#section-4
pub fn enable_connect_protocol(&mut self) -> &mut Self {
self.settings.set_enable_connect_protocol(Some(1));
self
}
/// Creates a new configured HTTP/2 server backed by `io`. /// Creates a new configured HTTP/2 server backed by `io`.
/// ///
/// It is expected that `io` already be in an appropriate state to commence /// It is expected that `io` already be in an appropriate state to commence
@@ -1360,7 +1381,7 @@ impl Peer {
_, _,
) = request.into_parts(); ) = request.into_parts();
let pseudo = Pseudo::request(method, uri); let pseudo = Pseudo::request(method, uri, None);
Ok(frame::PushPromise::new( Ok(frame::PushPromise::new(
stream_id, stream_id,
@@ -1410,6 +1431,11 @@ impl proto::Peer for Peer {
malformed!("malformed headers: missing method"); malformed!("malformed headers: missing method");
} }
let has_protocol = pseudo.protocol.is_some();
if !is_connect && has_protocol {
malformed!("malformed headers: :protocol on non-CONNECT request");
}
if pseudo.status.is_some() { if pseudo.status.is_some() {
malformed!("malformed headers: :status field on request"); malformed!("malformed headers: :status field on request");
} }
@@ -1432,7 +1458,7 @@ impl proto::Peer for Peer {
// A :scheme is required, except CONNECT. // A :scheme is required, except CONNECT.
if let Some(scheme) = pseudo.scheme { if let Some(scheme) = pseudo.scheme {
if is_connect { if is_connect && !has_protocol {
malformed!(":scheme in CONNECT"); malformed!(":scheme in CONNECT");
} }
let maybe_scheme = scheme.parse(); let maybe_scheme = scheme.parse();
@@ -1450,12 +1476,12 @@ impl proto::Peer for Peer {
if parts.authority.is_some() { if parts.authority.is_some() {
parts.scheme = Some(scheme); parts.scheme = Some(scheme);
} }
} else if !is_connect { } else if !is_connect || has_protocol {
malformed!("malformed headers: missing scheme"); malformed!("malformed headers: missing scheme");
} }
if let Some(path) = pseudo.path { if let Some(path) = pseudo.path {
if is_connect { if is_connect && !has_protocol {
malformed!(":path in CONNECT"); malformed!(":path in CONNECT");
} }
@@ -1468,6 +1494,8 @@ impl proto::Peer for Peer {
parts.path_and_query = Some(maybe_path.or_else(|why| { parts.path_and_query = Some(maybe_path.or_else(|why| {
malformed!("malformed headers: malformed path ({:?}): {}", path, why,) malformed!("malformed headers: malformed path ({:?}): {}", path, why,)
})?); })?);
} else if is_connect && has_protocol {
malformed!("malformed headers: missing path in extended CONNECT");
} }
b = b.uri(parts); b = b.uri(parts);

View File

@@ -47,6 +47,16 @@ macro_rules! assert_settings {
}}; }};
} }
#[macro_export]
macro_rules! assert_go_away {
($frame:expr) => {{
match $frame {
h2::frame::Frame::GoAway(v) => v,
f => panic!("expected GO_AWAY; actual={:?}", f),
}
}};
}
#[macro_export] #[macro_export]
macro_rules! poll_err { macro_rules! poll_err {
($transport:expr) => {{ ($transport:expr) => {{
@@ -80,6 +90,7 @@ macro_rules! assert_default_settings {
use h2::frame::Frame; use h2::frame::Frame;
#[track_caller]
pub fn assert_frame_eq<T: Into<Frame>, U: Into<Frame>>(t: T, u: U) { pub fn assert_frame_eq<T: Into<Frame>, U: Into<Frame>>(t: T, u: U) {
let actual: Frame = t.into(); let actual: Frame = t.into();
let expected: Frame = u.into(); let expected: Frame = u.into();

View File

@@ -4,7 +4,10 @@ use std::fmt;
use bytes::Bytes; use bytes::Bytes;
use http::{self, HeaderMap, StatusCode}; use http::{self, HeaderMap, StatusCode};
use h2::frame::{self, Frame, StreamId}; use h2::{
ext::Protocol,
frame::{self, Frame, StreamId},
};
pub const SETTINGS: &'static [u8] = &[0, 0, 0, 4, 0, 0, 0, 0, 0]; pub const SETTINGS: &'static [u8] = &[0, 0, 0, 4, 0, 0, 0, 0, 0];
pub const SETTINGS_ACK: &'static [u8] = &[0, 0, 0, 4, 1, 0, 0, 0, 0]; pub const SETTINGS_ACK: &'static [u8] = &[0, 0, 0, 4, 1, 0, 0, 0, 0];
@@ -109,7 +112,9 @@ impl Mock<frame::Headers> {
let method = method.try_into().unwrap(); let method = method.try_into().unwrap();
let uri = uri.try_into().unwrap(); let uri = uri.try_into().unwrap();
let (id, _, fields) = self.into_parts(); let (id, _, fields) = self.into_parts();
let frame = frame::Headers::new(id, frame::Pseudo::request(method, uri), fields); let extensions = Default::default();
let pseudo = frame::Pseudo::request(method, uri, extensions);
let frame = frame::Headers::new(id, pseudo, fields);
Mock(frame) Mock(frame)
} }
@@ -179,6 +184,15 @@ impl Mock<frame::Headers> {
Mock(frame::Headers::new(id, pseudo, fields)) Mock(frame::Headers::new(id, pseudo, fields))
} }
pub fn protocol(self, value: &str) -> Self {
let (id, mut pseudo, fields) = self.into_parts();
let value = Protocol::from(value);
pseudo.set_protocol(value);
Mock(frame::Headers::new(id, pseudo, fields))
}
pub fn eos(mut self) -> Self { pub fn eos(mut self) -> Self {
self.0.set_end_stream(); self.0.set_end_stream();
self self
@@ -230,8 +244,9 @@ impl Mock<frame::PushPromise> {
let method = method.try_into().unwrap(); let method = method.try_into().unwrap();
let uri = uri.try_into().unwrap(); let uri = uri.try_into().unwrap();
let (id, promised, _, fields) = self.into_parts(); let (id, promised, _, fields) = self.into_parts();
let frame = let extensions = Default::default();
frame::PushPromise::new(id, promised, frame::Pseudo::request(method, uri), fields); let pseudo = frame::Pseudo::request(method, uri, extensions);
let frame = frame::PushPromise::new(id, promised, pseudo, fields);
Mock(frame) Mock(frame)
} }
@@ -352,6 +367,11 @@ impl Mock<frame::Settings> {
self.0.set_enable_push(false); self.0.set_enable_push(false);
self self
} }
pub fn enable_connect_protocol(mut self, val: u32) -> Self {
self.0.set_enable_connect_protocol(Some(val));
self
}
} }
impl From<Mock<frame::Settings>> for frame::Settings { impl From<Mock<frame::Settings>> for frame::Settings {

View File

@@ -221,22 +221,15 @@ impl Handle {
let settings = settings.into(); let settings = settings.into();
self.send(settings.into()).await.unwrap(); self.send(settings.into()).await.unwrap();
let frame = self.next().await; let frame = self.next().await.expect("unexpected EOF").unwrap();
let settings = match frame { let settings = assert_settings!(frame);
Some(frame) => match frame.unwrap() {
Frame::Settings(settings) => {
// Send the ACK
let ack = frame::Settings::ack();
// TODO: Don't unwrap? // Send the ACK
self.send(ack.into()).await.unwrap(); let ack = frame::Settings::ack();
// TODO: Don't unwrap?
self.send(ack.into()).await.unwrap();
settings
}
frame => panic!("unexpected frame; frame={:?}", frame),
},
None => panic!("unexpected EOF"),
};
let frame = self.next().await; let frame = self.next().await;
let f = assert_settings!(frame.unwrap().unwrap()); let f = assert_settings!(frame.unwrap().unwrap());

View File

@@ -2,6 +2,7 @@
pub use h2; pub use h2;
pub use h2::client; pub use h2::client;
pub use h2::ext::Protocol;
pub use h2::frame::StreamId; pub use h2::frame::StreamId;
pub use h2::server; pub use h2::server;
pub use h2::*; pub use h2::*;
@@ -20,8 +21,8 @@ pub use super::{Codec, SendFrame};
// Re-export macros // Re-export macros
pub use super::{ pub use super::{
assert_closed, assert_data, assert_default_settings, assert_headers, assert_ping, poll_err, assert_closed, assert_data, assert_default_settings, assert_go_away, assert_headers,
poll_frame, raw_codec, assert_ping, assert_settings, poll_err, poll_frame, raw_codec,
}; };
pub use super::assert::assert_frame_eq; pub use super::assert::assert_frame_eq;

View File

@@ -1305,6 +1305,153 @@ async fn informational_while_local_streaming() {
join(srv, h2).await; join(srv, h2).await;
} }
#[tokio::test]
async fn extended_connect_protocol_disabled_by_default() {
h2_support::trace_init!();
let (io, mut srv) = mock::new();
let srv = async move {
let settings = srv.assert_client_handshake().await;
assert_default_settings!(settings);
srv.recv_frame(
frames::headers(1)
.request("GET", "https://example.com/")
.eos(),
)
.await;
srv.send_frame(frames::headers(1).response(200).eos()).await;
};
let h2 = async move {
let (mut client, mut h2) = client::handshake(io).await.unwrap();
// we send a simple req here just to drive the connection so we can
// receive the server settings.
let request = Request::get("https://example.com/").body(()).unwrap();
// first request is allowed
let (response, _) = client.send_request(request, true).unwrap();
h2.drive(response).await.unwrap();
assert!(!client.is_extended_connect_protocol_enabled());
};
join(srv, h2).await;
}
#[tokio::test]
async fn extended_connect_protocol_enabled_during_handshake() {
h2_support::trace_init!();
let (io, mut srv) = mock::new();
let srv = async move {
let settings = srv
.assert_client_handshake_with_settings(frames::settings().enable_connect_protocol(1))
.await;
assert_default_settings!(settings);
srv.recv_frame(
frames::headers(1)
.request("GET", "https://example.com/")
.eos(),
)
.await;
srv.send_frame(frames::headers(1).response(200).eos()).await;
};
let h2 = async move {
let (mut client, mut h2) = client::handshake(io).await.unwrap();
// we send a simple req here just to drive the connection so we can
// receive the server settings.
let request = Request::get("https://example.com/").body(()).unwrap();
let (response, _) = client.send_request(request, true).unwrap();
h2.drive(response).await.unwrap();
assert!(client.is_extended_connect_protocol_enabled());
};
join(srv, h2).await;
}
#[tokio::test]
async fn invalid_connect_protocol_enabled_setting() {
h2_support::trace_init!();
let (io, mut srv) = mock::new();
let srv = async move {
// Send a settings frame
srv.send(frames::settings().enable_connect_protocol(2).into())
.await
.unwrap();
srv.read_preface().await.unwrap();
let settings = assert_settings!(srv.next().await.expect("unexpected EOF").unwrap());
assert_default_settings!(settings);
// Send the ACK
let ack = frame::Settings::ack();
// TODO: Don't unwrap?
srv.send(ack.into()).await.unwrap();
let frame = srv.next().await.unwrap().unwrap();
let go_away = assert_go_away!(frame);
assert_eq!(go_away.reason(), Reason::PROTOCOL_ERROR);
};
let h2 = async move {
let (mut client, mut h2) = client::handshake(io).await.unwrap();
// we send a simple req here just to drive the connection so we can
// receive the server settings.
let request = Request::get("https://example.com/").body(()).unwrap();
let (response, _) = client.send_request(request, true).unwrap();
let error = h2.drive(response).await.unwrap_err();
assert_eq!(error.reason(), Some(Reason::PROTOCOL_ERROR));
};
join(srv, h2).await;
}
#[tokio::test]
async fn extended_connect_request() {
h2_support::trace_init!();
let (io, mut srv) = mock::new();
let srv = async move {
let settings = srv
.assert_client_handshake_with_settings(frames::settings().enable_connect_protocol(1))
.await;
assert_default_settings!(settings);
srv.recv_frame(
frames::headers(1)
.request("CONNECT", "http://bread/baguette")
.protocol("the-bread-protocol")
.eos(),
)
.await;
srv.send_frame(frames::headers(1).response(200).eos()).await;
};
let h2 = async move {
let (mut client, mut h2) = client::handshake(io).await.unwrap();
let request = Request::connect("http://bread/baguette")
.extension(Protocol::from("the-bread-protocol"))
.body(())
.unwrap();
let (response, _) = client.send_request(request, true).unwrap();
h2.drive(response).await.unwrap();
};
join(srv, h2).await;
}
const SETTINGS: &'static [u8] = &[0, 0, 0, 4, 0, 0, 0, 0, 0]; const SETTINGS: &'static [u8] = &[0, 0, 0, 4, 0, 0, 0, 0, 0];
const SETTINGS_ACK: &'static [u8] = &[0, 0, 0, 4, 1, 0, 0, 0, 0]; const SETTINGS_ACK: &'static [u8] = &[0, 0, 0, 4, 1, 0, 0, 0, 0];

View File

@@ -1149,3 +1149,191 @@ async fn send_reset_explicitly() {
join(client, srv).await; join(client, srv).await;
} }
#[tokio::test]
async fn extended_connect_protocol_disabled_by_default() {
h2_support::trace_init!();
let (io, mut client) = mock::new();
let client = async move {
let settings = client.assert_server_handshake().await;
assert_eq!(settings.is_extended_connect_protocol_enabled(), None);
client
.send_frame(
frames::headers(1)
.request("CONNECT", "http://bread/baguette")
.protocol("the-bread-protocol"),
)
.await;
client.recv_frame(frames::reset(1).protocol_error()).await;
};
let srv = async move {
let mut srv = server::handshake(io).await.expect("handshake");
poll_fn(move |cx| srv.poll_closed(cx))
.await
.expect("server");
};
join(client, srv).await;
}
#[tokio::test]
async fn extended_connect_protocol_enabled_during_handshake() {
h2_support::trace_init!();
let (io, mut client) = mock::new();
let client = async move {
let settings = client.assert_server_handshake().await;
assert_eq!(settings.is_extended_connect_protocol_enabled(), Some(true));
client
.send_frame(
frames::headers(1)
.request("CONNECT", "http://bread/baguette")
.protocol("the-bread-protocol"),
)
.await;
client.recv_frame(frames::headers(1).response(200)).await;
};
let srv = async move {
let mut builder = server::Builder::new();
builder.enable_connect_protocol();
let mut srv = builder.handshake::<_, Bytes>(io).await.expect("handshake");
let (_req, mut stream) = srv.next().await.unwrap().unwrap();
let rsp = Response::new(());
stream.send_response(rsp, false).unwrap();
poll_fn(move |cx| srv.poll_closed(cx))
.await
.expect("server");
};
join(client, srv).await;
}
#[tokio::test]
async fn reject_pseudo_protocol_on_non_connect_request() {
h2_support::trace_init!();
let (io, mut client) = mock::new();
let client = async move {
let settings = client.assert_server_handshake().await;
assert_eq!(settings.is_extended_connect_protocol_enabled(), Some(true));
client
.send_frame(
frames::headers(1)
.request("GET", "http://bread/baguette")
.protocol("the-bread-protocol"),
)
.await;
client.recv_frame(frames::reset(1).protocol_error()).await;
};
let srv = async move {
let mut builder = server::Builder::new();
builder.enable_connect_protocol();
let mut srv = builder.handshake::<_, Bytes>(io).await.expect("handshake");
assert!(srv.next().await.is_none());
poll_fn(move |cx| srv.poll_closed(cx))
.await
.expect("server");
};
join(client, srv).await;
}
#[tokio::test]
async fn reject_authority_target_on_extended_connect_request() {
h2_support::trace_init!();
let (io, mut client) = mock::new();
let client = async move {
let settings = client.assert_server_handshake().await;
assert_eq!(settings.is_extended_connect_protocol_enabled(), Some(true));
client
.send_frame(
frames::headers(1)
.request("CONNECT", "bread:80")
.protocol("the-bread-protocol"),
)
.await;
client.recv_frame(frames::reset(1).protocol_error()).await;
};
let srv = async move {
let mut builder = server::Builder::new();
builder.enable_connect_protocol();
let mut srv = builder.handshake::<_, Bytes>(io).await.expect("handshake");
assert!(srv.next().await.is_none());
poll_fn(move |cx| srv.poll_closed(cx))
.await
.expect("server");
};
join(client, srv).await;
}
#[tokio::test]
async fn reject_non_authority_target_on_connect_request() {
h2_support::trace_init!();
let (io, mut client) = mock::new();
let client = async move {
let settings = client.assert_server_handshake().await;
assert_eq!(settings.is_extended_connect_protocol_enabled(), Some(true));
client
.send_frame(frames::headers(1).request("CONNECT", "https://bread/baguette"))
.await;
client.recv_frame(frames::reset(1).protocol_error()).await;
};
let srv = async move {
let mut builder = server::Builder::new();
builder.enable_connect_protocol();
let mut srv = builder.handshake::<_, Bytes>(io).await.expect("handshake");
assert!(srv.next().await.is_none());
poll_fn(move |cx| srv.poll_closed(cx))
.await
.expect("server");
};
join(client, srv).await;
}