Implement the extended CONNECT protocol from RFC 8441 (#565)
This commit is contained in:
@@ -136,6 +136,7 @@
|
||||
//! [`Error`]: ../struct.Error.html
|
||||
|
||||
use crate::codec::{Codec, SendError, UserError};
|
||||
use crate::ext::Protocol;
|
||||
use crate::frame::{Headers, Pseudo, Reason, Settings, StreamId};
|
||||
use crate::proto::{self, Error};
|
||||
use crate::{FlowControl, PingPong, RecvStream, SendStream};
|
||||
@@ -517,6 +518,19 @@ where
|
||||
(response, stream)
|
||||
})
|
||||
}
|
||||
|
||||
/// Returns whether the [extended CONNECT protocol][1] is enabled or not.
|
||||
///
|
||||
/// This setting is configured by the server peer by sending the
|
||||
/// [`SETTINGS_ENABLE_CONNECT_PROTOCOL` parameter][2] in a `SETTINGS` frame.
|
||||
/// This method returns the currently acknowledged value recieved from the
|
||||
/// remote.
|
||||
///
|
||||
/// [1]: https://datatracker.ietf.org/doc/html/rfc8441#section-4
|
||||
/// [2]: https://datatracker.ietf.org/doc/html/rfc8441#section-3
|
||||
pub fn is_extended_connect_protocol_enabled(&self) -> bool {
|
||||
self.inner.is_extended_connect_protocol_enabled()
|
||||
}
|
||||
}
|
||||
|
||||
impl<B> fmt::Debug for SendRequest<B>
|
||||
@@ -1246,11 +1260,10 @@ where
|
||||
/// This method returns the currently acknowledged value recieved from the
|
||||
/// remote.
|
||||
///
|
||||
/// [settings]: https://tools.ietf.org/html/rfc7540#section-5.1.2
|
||||
/// [1]: https://tools.ietf.org/html/rfc7540#section-5.1.2
|
||||
pub fn max_concurrent_send_streams(&self) -> usize {
|
||||
self.inner.max_send_streams()
|
||||
}
|
||||
|
||||
/// Returns the maximum number of concurrent streams that may be initiated
|
||||
/// by the server on this connection.
|
||||
///
|
||||
@@ -1416,6 +1429,7 @@ impl Peer {
|
||||
pub fn convert_send_message(
|
||||
id: StreamId,
|
||||
request: Request<()>,
|
||||
protocol: Option<Protocol>,
|
||||
end_of_stream: bool,
|
||||
) -> Result<Headers, SendError> {
|
||||
use http::request::Parts;
|
||||
@@ -1435,7 +1449,7 @@ impl Peer {
|
||||
|
||||
// Build the set pseudo header set. All requests will include `method`
|
||||
// and `path`.
|
||||
let mut pseudo = Pseudo::request(method, uri);
|
||||
let mut pseudo = Pseudo::request(method, uri, protocol);
|
||||
|
||||
if pseudo.scheme.is_none() {
|
||||
// If the scheme is not set, then there are a two options.
|
||||
|
||||
55
src/ext.rs
Normal file
55
src/ext.rs
Normal file
@@ -0,0 +1,55 @@
|
||||
//! Extensions specific to the HTTP/2 protocol.
|
||||
|
||||
use crate::hpack::BytesStr;
|
||||
|
||||
use bytes::Bytes;
|
||||
use std::fmt;
|
||||
|
||||
/// Represents the `:protocol` pseudo-header used by
|
||||
/// the [Extended CONNECT Protocol].
|
||||
///
|
||||
/// [Extended CONNECT Protocol]: https://datatracker.ietf.org/doc/html/rfc8441#section-4
|
||||
#[derive(Clone, Eq, PartialEq)]
|
||||
pub struct Protocol {
|
||||
value: BytesStr,
|
||||
}
|
||||
|
||||
impl Protocol {
|
||||
/// Converts a static string to a protocol name.
|
||||
pub const fn from_static(value: &'static str) -> Self {
|
||||
Self {
|
||||
value: BytesStr::from_static(value),
|
||||
}
|
||||
}
|
||||
|
||||
/// Returns a str representation of the header.
|
||||
pub fn as_str(&self) -> &str {
|
||||
self.value.as_str()
|
||||
}
|
||||
|
||||
pub(crate) fn try_from(bytes: Bytes) -> Result<Self, std::str::Utf8Error> {
|
||||
Ok(Self {
|
||||
value: BytesStr::try_from(bytes)?,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a> From<&'a str> for Protocol {
|
||||
fn from(value: &'a str) -> Self {
|
||||
Self {
|
||||
value: BytesStr::from(value),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl AsRef<[u8]> for Protocol {
|
||||
fn as_ref(&self) -> &[u8] {
|
||||
self.value.as_ref()
|
||||
}
|
||||
}
|
||||
|
||||
impl fmt::Debug for Protocol {
|
||||
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
|
||||
self.value.fmt(f)
|
||||
}
|
||||
}
|
||||
@@ -1,4 +1,5 @@
|
||||
use super::{util, StreamDependency, StreamId};
|
||||
use crate::ext::Protocol;
|
||||
use crate::frame::{Error, Frame, Head, Kind};
|
||||
use crate::hpack::{self, BytesStr};
|
||||
|
||||
@@ -66,6 +67,7 @@ pub struct Pseudo {
|
||||
pub scheme: Option<BytesStr>,
|
||||
pub authority: Option<BytesStr>,
|
||||
pub path: Option<BytesStr>,
|
||||
pub protocol: Option<Protocol>,
|
||||
|
||||
// Response
|
||||
pub status: Option<StatusCode>,
|
||||
@@ -292,6 +294,10 @@ impl fmt::Debug for Headers {
|
||||
.field("stream_id", &self.stream_id)
|
||||
.field("flags", &self.flags);
|
||||
|
||||
if let Some(ref protocol) = self.header_block.pseudo.protocol {
|
||||
builder.field("protocol", protocol);
|
||||
}
|
||||
|
||||
if let Some(ref dep) = self.stream_dep {
|
||||
builder.field("stream_dep", dep);
|
||||
}
|
||||
@@ -529,7 +535,7 @@ impl Continuation {
|
||||
// ===== impl Pseudo =====
|
||||
|
||||
impl Pseudo {
|
||||
pub fn request(method: Method, uri: Uri) -> Self {
|
||||
pub fn request(method: Method, uri: Uri, protocol: Option<Protocol>) -> Self {
|
||||
let parts = uri::Parts::from(uri);
|
||||
|
||||
let mut path = parts
|
||||
@@ -550,6 +556,7 @@ impl Pseudo {
|
||||
scheme: None,
|
||||
authority: None,
|
||||
path: Some(path).filter(|p| !p.is_empty()),
|
||||
protocol,
|
||||
status: None,
|
||||
};
|
||||
|
||||
@@ -575,6 +582,7 @@ impl Pseudo {
|
||||
scheme: None,
|
||||
authority: None,
|
||||
path: None,
|
||||
protocol: None,
|
||||
status: Some(status),
|
||||
}
|
||||
}
|
||||
@@ -593,6 +601,11 @@ impl Pseudo {
|
||||
self.scheme = Some(bytes_str);
|
||||
}
|
||||
|
||||
#[cfg(feature = "unstable")]
|
||||
pub fn set_protocol(&mut self, protocol: Protocol) {
|
||||
self.protocol = Some(protocol);
|
||||
}
|
||||
|
||||
pub fn set_authority(&mut self, authority: BytesStr) {
|
||||
self.authority = Some(authority);
|
||||
}
|
||||
@@ -681,6 +694,10 @@ impl Iterator for Iter {
|
||||
return Some(Path(path));
|
||||
}
|
||||
|
||||
if let Some(protocol) = pseudo.protocol.take() {
|
||||
return Some(Protocol(protocol));
|
||||
}
|
||||
|
||||
if let Some(status) = pseudo.status.take() {
|
||||
return Some(Status(status));
|
||||
}
|
||||
@@ -879,6 +896,7 @@ impl HeaderBlock {
|
||||
Method(v) => set_pseudo!(method, v),
|
||||
Scheme(v) => set_pseudo!(scheme, v),
|
||||
Path(v) => set_pseudo!(path, v),
|
||||
Protocol(v) => set_pseudo!(protocol, v),
|
||||
Status(v) => set_pseudo!(status, v),
|
||||
}
|
||||
});
|
||||
|
||||
@@ -13,6 +13,7 @@ pub struct Settings {
|
||||
initial_window_size: Option<u32>,
|
||||
max_frame_size: Option<u32>,
|
||||
max_header_list_size: Option<u32>,
|
||||
enable_connect_protocol: Option<u32>,
|
||||
}
|
||||
|
||||
/// An enum that lists all valid settings that can be sent in a SETTINGS
|
||||
@@ -27,6 +28,7 @@ pub enum Setting {
|
||||
InitialWindowSize(u32),
|
||||
MaxFrameSize(u32),
|
||||
MaxHeaderListSize(u32),
|
||||
EnableConnectProtocol(u32),
|
||||
}
|
||||
|
||||
#[derive(Copy, Clone, Eq, PartialEq, Default)]
|
||||
@@ -107,6 +109,14 @@ impl Settings {
|
||||
self.enable_push = Some(enable as u32);
|
||||
}
|
||||
|
||||
pub fn is_extended_connect_protocol_enabled(&self) -> Option<bool> {
|
||||
self.enable_connect_protocol.map(|val| val != 0)
|
||||
}
|
||||
|
||||
pub fn set_enable_connect_protocol(&mut self, val: Option<u32>) {
|
||||
self.enable_connect_protocol = val;
|
||||
}
|
||||
|
||||
pub fn header_table_size(&self) -> Option<u32> {
|
||||
self.header_table_size
|
||||
}
|
||||
@@ -181,6 +191,14 @@ impl Settings {
|
||||
Some(MaxHeaderListSize(val)) => {
|
||||
settings.max_header_list_size = Some(val);
|
||||
}
|
||||
Some(EnableConnectProtocol(val)) => match val {
|
||||
0 | 1 => {
|
||||
settings.enable_connect_protocol = Some(val);
|
||||
}
|
||||
_ => {
|
||||
return Err(Error::InvalidSettingValue);
|
||||
}
|
||||
},
|
||||
None => {}
|
||||
}
|
||||
}
|
||||
@@ -236,6 +254,10 @@ impl Settings {
|
||||
if let Some(v) = self.max_header_list_size {
|
||||
f(MaxHeaderListSize(v));
|
||||
}
|
||||
|
||||
if let Some(v) = self.enable_connect_protocol {
|
||||
f(EnableConnectProtocol(v));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -269,6 +291,9 @@ impl fmt::Debug for Settings {
|
||||
Setting::MaxHeaderListSize(v) => {
|
||||
builder.field("max_header_list_size", &v);
|
||||
}
|
||||
Setting::EnableConnectProtocol(v) => {
|
||||
builder.field("enable_connect_protocol", &v);
|
||||
}
|
||||
});
|
||||
|
||||
builder.finish()
|
||||
@@ -291,6 +316,7 @@ impl Setting {
|
||||
4 => Some(InitialWindowSize(val)),
|
||||
5 => Some(MaxFrameSize(val)),
|
||||
6 => Some(MaxHeaderListSize(val)),
|
||||
8 => Some(EnableConnectProtocol(val)),
|
||||
_ => None,
|
||||
}
|
||||
}
|
||||
@@ -322,6 +348,7 @@ impl Setting {
|
||||
InitialWindowSize(v) => (4, v),
|
||||
MaxFrameSize(v) => (5, v),
|
||||
MaxHeaderListSize(v) => (6, v),
|
||||
EnableConnectProtocol(v) => (8, v),
|
||||
};
|
||||
|
||||
dst.put_u16(kind);
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
use super::{DecoderError, NeedMore};
|
||||
use crate::ext::Protocol;
|
||||
|
||||
use bytes::Bytes;
|
||||
use http::header::{HeaderName, HeaderValue};
|
||||
@@ -14,6 +15,7 @@ pub enum Header<T = HeaderName> {
|
||||
Method(Method),
|
||||
Scheme(BytesStr),
|
||||
Path(BytesStr),
|
||||
Protocol(Protocol),
|
||||
Status(StatusCode),
|
||||
}
|
||||
|
||||
@@ -25,6 +27,7 @@ pub enum Name<'a> {
|
||||
Method,
|
||||
Scheme,
|
||||
Path,
|
||||
Protocol,
|
||||
Status,
|
||||
}
|
||||
|
||||
@@ -51,6 +54,7 @@ impl Header<Option<HeaderName>> {
|
||||
Method(v) => Method(v),
|
||||
Scheme(v) => Scheme(v),
|
||||
Path(v) => Path(v),
|
||||
Protocol(v) => Protocol(v),
|
||||
Status(v) => Status(v),
|
||||
})
|
||||
}
|
||||
@@ -79,6 +83,10 @@ impl Header {
|
||||
let value = BytesStr::try_from(value)?;
|
||||
Ok(Header::Path(value))
|
||||
}
|
||||
b"protocol" => {
|
||||
let value = Protocol::try_from(value)?;
|
||||
Ok(Header::Protocol(value))
|
||||
}
|
||||
b"status" => {
|
||||
let status = StatusCode::from_bytes(&value)?;
|
||||
Ok(Header::Status(status))
|
||||
@@ -104,6 +112,7 @@ impl Header {
|
||||
Header::Method(ref v) => 32 + 7 + v.as_ref().len(),
|
||||
Header::Scheme(ref v) => 32 + 7 + v.len(),
|
||||
Header::Path(ref v) => 32 + 5 + v.len(),
|
||||
Header::Protocol(ref v) => 32 + 9 + v.as_str().len(),
|
||||
Header::Status(_) => 32 + 7 + 3,
|
||||
}
|
||||
}
|
||||
@@ -116,6 +125,7 @@ impl Header {
|
||||
Header::Method(..) => Name::Method,
|
||||
Header::Scheme(..) => Name::Scheme,
|
||||
Header::Path(..) => Name::Path,
|
||||
Header::Protocol(..) => Name::Protocol,
|
||||
Header::Status(..) => Name::Status,
|
||||
}
|
||||
}
|
||||
@@ -127,6 +137,7 @@ impl Header {
|
||||
Header::Method(ref v) => v.as_ref().as_ref(),
|
||||
Header::Scheme(ref v) => v.as_ref(),
|
||||
Header::Path(ref v) => v.as_ref(),
|
||||
Header::Protocol(ref v) => v.as_ref(),
|
||||
Header::Status(ref v) => v.as_str().as_ref(),
|
||||
}
|
||||
}
|
||||
@@ -156,6 +167,10 @@ impl Header {
|
||||
Header::Path(ref b) => a == b,
|
||||
_ => false,
|
||||
},
|
||||
Header::Protocol(ref a) => match *other {
|
||||
Header::Protocol(ref b) => a == b,
|
||||
_ => false,
|
||||
},
|
||||
Header::Status(ref a) => match *other {
|
||||
Header::Status(ref b) => a == b,
|
||||
_ => false,
|
||||
@@ -205,6 +220,7 @@ impl From<Header> for Header<Option<HeaderName>> {
|
||||
Header::Method(v) => Header::Method(v),
|
||||
Header::Scheme(v) => Header::Scheme(v),
|
||||
Header::Path(v) => Header::Path(v),
|
||||
Header::Protocol(v) => Header::Protocol(v),
|
||||
Header::Status(v) => Header::Status(v),
|
||||
}
|
||||
}
|
||||
@@ -221,6 +237,7 @@ impl<'a> Name<'a> {
|
||||
Name::Method => Ok(Header::Method(Method::from_bytes(&*value)?)),
|
||||
Name::Scheme => Ok(Header::Scheme(BytesStr::try_from(value)?)),
|
||||
Name::Path => Ok(Header::Path(BytesStr::try_from(value)?)),
|
||||
Name::Protocol => Ok(Header::Protocol(Protocol::try_from(value)?)),
|
||||
Name::Status => {
|
||||
match StatusCode::from_bytes(&value) {
|
||||
Ok(status) => Ok(Header::Status(status)),
|
||||
@@ -238,6 +255,7 @@ impl<'a> Name<'a> {
|
||||
Name::Method => b":method",
|
||||
Name::Scheme => b":scheme",
|
||||
Name::Path => b":path",
|
||||
Name::Protocol => b":protocol",
|
||||
Name::Status => b":status",
|
||||
}
|
||||
}
|
||||
|
||||
@@ -751,6 +751,7 @@ fn index_static(header: &Header) -> Option<(usize, bool)> {
|
||||
"/index.html" => Some((5, true)),
|
||||
_ => Some((4, false)),
|
||||
},
|
||||
Header::Protocol(..) => None,
|
||||
Header::Status(ref v) => match u16::from(*v) {
|
||||
200 => Some((8, true)),
|
||||
204 => Some((9, true)),
|
||||
|
||||
@@ -134,6 +134,7 @@ fn key_str(e: &Header) -> &str {
|
||||
Header::Method(..) => ":method",
|
||||
Header::Scheme(..) => ":scheme",
|
||||
Header::Path(..) => ":path",
|
||||
Header::Protocol(..) => ":protocol",
|
||||
Header::Status(..) => ":status",
|
||||
}
|
||||
}
|
||||
@@ -145,6 +146,7 @@ fn value_str(e: &Header) -> &str {
|
||||
Header::Method(ref m) => m.as_str(),
|
||||
Header::Scheme(ref v) => &**v,
|
||||
Header::Path(ref v) => &**v,
|
||||
Header::Protocol(ref v) => v.as_str(),
|
||||
Header::Status(ref v) => v.as_str(),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -120,6 +120,7 @@ mod frame;
|
||||
pub mod frame;
|
||||
|
||||
pub mod client;
|
||||
pub mod ext;
|
||||
pub mod server;
|
||||
mod share;
|
||||
|
||||
|
||||
@@ -110,6 +110,10 @@ where
|
||||
initial_max_send_streams: config.initial_max_send_streams,
|
||||
local_next_stream_id: config.next_stream_id,
|
||||
local_push_enabled: config.settings.is_push_enabled().unwrap_or(true),
|
||||
extended_connect_protocol_enabled: config
|
||||
.settings
|
||||
.is_extended_connect_protocol_enabled()
|
||||
.unwrap_or(false),
|
||||
local_reset_duration: config.reset_stream_duration,
|
||||
local_reset_max: config.reset_stream_max,
|
||||
remote_init_window_sz: DEFAULT_INITIAL_WINDOW_SIZE,
|
||||
@@ -147,6 +151,13 @@ where
|
||||
self.inner.settings.send_settings(settings)
|
||||
}
|
||||
|
||||
/// Send a new SETTINGS frame with extended CONNECT protocol enabled.
|
||||
pub(crate) fn set_enable_connect_protocol(&mut self) -> Result<(), UserError> {
|
||||
let mut settings = frame::Settings::default();
|
||||
settings.set_enable_connect_protocol(Some(1));
|
||||
self.inner.settings.send_settings(settings)
|
||||
}
|
||||
|
||||
/// Returns the maximum number of concurrent streams that may be initiated
|
||||
/// by this peer.
|
||||
pub(crate) fn max_send_streams(&self) -> usize {
|
||||
|
||||
@@ -117,6 +117,8 @@ impl Settings {
|
||||
|
||||
tracing::trace!("ACK sent; applying settings");
|
||||
|
||||
streams.apply_remote_settings(settings)?;
|
||||
|
||||
if let Some(val) = settings.header_table_size() {
|
||||
dst.set_send_header_table_size(val as usize);
|
||||
}
|
||||
@@ -124,8 +126,6 @@ impl Settings {
|
||||
if let Some(val) = settings.max_frame_size() {
|
||||
dst.set_max_send_frame_size(val as usize);
|
||||
}
|
||||
|
||||
streams.apply_remote_settings(settings)?;
|
||||
}
|
||||
|
||||
self.remote = None;
|
||||
|
||||
@@ -47,6 +47,9 @@ pub struct Config {
|
||||
/// If the local peer is willing to receive push promises
|
||||
pub local_push_enabled: bool,
|
||||
|
||||
/// If extended connect protocol is enabled.
|
||||
pub extended_connect_protocol_enabled: bool,
|
||||
|
||||
/// How long a locally reset stream should ignore frames
|
||||
pub local_reset_duration: Duration,
|
||||
|
||||
|
||||
@@ -56,6 +56,9 @@ pub(super) struct Recv {
|
||||
|
||||
/// If push promises are allowed to be received.
|
||||
is_push_enabled: bool,
|
||||
|
||||
/// If extended connect protocol is enabled.
|
||||
is_extended_connect_protocol_enabled: bool,
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
@@ -103,6 +106,7 @@ impl Recv {
|
||||
buffer: Buffer::new(),
|
||||
refused: None,
|
||||
is_push_enabled: config.local_push_enabled,
|
||||
is_extended_connect_protocol_enabled: config.extended_connect_protocol_enabled,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -216,6 +220,14 @@ impl Recv {
|
||||
|
||||
let stream_id = frame.stream_id();
|
||||
let (pseudo, fields) = frame.into_parts();
|
||||
|
||||
if pseudo.protocol.is_some() {
|
||||
if counts.peer().is_server() && !self.is_extended_connect_protocol_enabled {
|
||||
proto_err!(stream: "cannot use :protocol if extended connect protocol is disabled; stream={:?}", stream.id);
|
||||
return Err(Error::library_reset(stream.id, Reason::PROTOCOL_ERROR).into());
|
||||
}
|
||||
}
|
||||
|
||||
if !pseudo.is_informational() {
|
||||
let message = counts
|
||||
.peer()
|
||||
@@ -449,60 +461,58 @@ impl Recv {
|
||||
settings: &frame::Settings,
|
||||
store: &mut Store,
|
||||
) -> Result<(), proto::Error> {
|
||||
let target = if let Some(val) = settings.initial_window_size() {
|
||||
val
|
||||
} else {
|
||||
return Ok(());
|
||||
};
|
||||
|
||||
let old_sz = self.init_window_sz;
|
||||
self.init_window_sz = target;
|
||||
|
||||
tracing::trace!("update_initial_window_size; new={}; old={}", target, old_sz,);
|
||||
|
||||
// Per RFC 7540 §6.9.2:
|
||||
//
|
||||
// In addition to changing the flow-control window for streams that are
|
||||
// not yet active, a SETTINGS frame can alter the initial flow-control
|
||||
// window size for streams with active flow-control windows (that is,
|
||||
// streams in the "open" or "half-closed (remote)" state). When the
|
||||
// value of SETTINGS_INITIAL_WINDOW_SIZE changes, a receiver MUST adjust
|
||||
// the size of all stream flow-control windows that it maintains by the
|
||||
// difference between the new value and the old value.
|
||||
//
|
||||
// A change to `SETTINGS_INITIAL_WINDOW_SIZE` can cause the available
|
||||
// space in a flow-control window to become negative. A sender MUST
|
||||
// track the negative flow-control window and MUST NOT send new
|
||||
// flow-controlled frames until it receives WINDOW_UPDATE frames that
|
||||
// cause the flow-control window to become positive.
|
||||
|
||||
if target < old_sz {
|
||||
// We must decrease the (local) window on every open stream.
|
||||
let dec = old_sz - target;
|
||||
tracing::trace!("decrementing all windows; dec={}", dec);
|
||||
|
||||
store.for_each(|mut stream| {
|
||||
stream.recv_flow.dec_recv_window(dec);
|
||||
Ok(())
|
||||
})
|
||||
} else if target > old_sz {
|
||||
// We must increase the (local) window on every open stream.
|
||||
let inc = target - old_sz;
|
||||
tracing::trace!("incrementing all windows; inc={}", inc);
|
||||
store.for_each(|mut stream| {
|
||||
// XXX: Shouldn't the peer have already noticed our
|
||||
// overflow and sent us a GOAWAY?
|
||||
stream
|
||||
.recv_flow
|
||||
.inc_window(inc)
|
||||
.map_err(proto::Error::library_go_away)?;
|
||||
stream.recv_flow.assign_capacity(inc);
|
||||
Ok(())
|
||||
})
|
||||
} else {
|
||||
// size is the same... so do nothing
|
||||
Ok(())
|
||||
if let Some(val) = settings.is_extended_connect_protocol_enabled() {
|
||||
self.is_extended_connect_protocol_enabled = val;
|
||||
}
|
||||
|
||||
if let Some(target) = settings.initial_window_size() {
|
||||
let old_sz = self.init_window_sz;
|
||||
self.init_window_sz = target;
|
||||
|
||||
tracing::trace!("update_initial_window_size; new={}; old={}", target, old_sz,);
|
||||
|
||||
// Per RFC 7540 §6.9.2:
|
||||
//
|
||||
// In addition to changing the flow-control window for streams that are
|
||||
// not yet active, a SETTINGS frame can alter the initial flow-control
|
||||
// window size for streams with active flow-control windows (that is,
|
||||
// streams in the "open" or "half-closed (remote)" state). When the
|
||||
// value of SETTINGS_INITIAL_WINDOW_SIZE changes, a receiver MUST adjust
|
||||
// the size of all stream flow-control windows that it maintains by the
|
||||
// difference between the new value and the old value.
|
||||
//
|
||||
// A change to `SETTINGS_INITIAL_WINDOW_SIZE` can cause the available
|
||||
// space in a flow-control window to become negative. A sender MUST
|
||||
// track the negative flow-control window and MUST NOT send new
|
||||
// flow-controlled frames until it receives WINDOW_UPDATE frames that
|
||||
// cause the flow-control window to become positive.
|
||||
|
||||
if target < old_sz {
|
||||
// We must decrease the (local) window on every open stream.
|
||||
let dec = old_sz - target;
|
||||
tracing::trace!("decrementing all windows; dec={}", dec);
|
||||
|
||||
store.for_each(|mut stream| {
|
||||
stream.recv_flow.dec_recv_window(dec);
|
||||
})
|
||||
} else if target > old_sz {
|
||||
// We must increase the (local) window on every open stream.
|
||||
let inc = target - old_sz;
|
||||
tracing::trace!("incrementing all windows; inc={}", inc);
|
||||
store.try_for_each(|mut stream| {
|
||||
// XXX: Shouldn't the peer have already noticed our
|
||||
// overflow and sent us a GOAWAY?
|
||||
stream
|
||||
.recv_flow
|
||||
.inc_window(inc)
|
||||
.map_err(proto::Error::library_go_away)?;
|
||||
stream.recv_flow.assign_capacity(inc);
|
||||
Ok::<_, proto::Error>(())
|
||||
})?;
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub fn is_end_stream(&self, stream: &store::Ptr) -> bool {
|
||||
|
||||
@@ -35,6 +35,9 @@ pub(super) struct Send {
|
||||
prioritize: Prioritize,
|
||||
|
||||
is_push_enabled: bool,
|
||||
|
||||
/// If extended connect protocol is enabled.
|
||||
is_extended_connect_protocol_enabled: bool,
|
||||
}
|
||||
|
||||
/// A value to detect which public API has called `poll_reset`.
|
||||
@@ -53,6 +56,7 @@ impl Send {
|
||||
next_stream_id: Ok(config.local_next_stream_id),
|
||||
prioritize: Prioritize::new(config),
|
||||
is_push_enabled: true,
|
||||
is_extended_connect_protocol_enabled: false,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -429,6 +433,10 @@ impl Send {
|
||||
counts: &mut Counts,
|
||||
task: &mut Option<Waker>,
|
||||
) -> Result<(), Error> {
|
||||
if let Some(val) = settings.is_extended_connect_protocol_enabled() {
|
||||
self.is_extended_connect_protocol_enabled = val;
|
||||
}
|
||||
|
||||
// Applies an update to the remote endpoint's initial window size.
|
||||
//
|
||||
// Per RFC 7540 §6.9.2:
|
||||
@@ -490,16 +498,14 @@ impl Send {
|
||||
// TODO: Should this notify the producer when the capacity
|
||||
// of a stream is reduced? Maybe it should if the capacity
|
||||
// is reduced to zero, allowing the producer to stop work.
|
||||
|
||||
Ok::<_, Error>(())
|
||||
})?;
|
||||
});
|
||||
|
||||
self.prioritize
|
||||
.assign_connection_capacity(total_reclaimed, store, counts);
|
||||
} else if val > old_val {
|
||||
let inc = val - old_val;
|
||||
|
||||
store.for_each(|mut stream| {
|
||||
store.try_for_each(|mut stream| {
|
||||
self.recv_stream_window_update(inc, buffer, &mut stream, counts, task)
|
||||
.map_err(Error::library_go_away)
|
||||
})?;
|
||||
@@ -554,4 +560,8 @@ impl Send {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn is_extended_connect_protocol_enabled(&self) -> bool {
|
||||
self.is_extended_connect_protocol_enabled
|
||||
}
|
||||
}
|
||||
|
||||
@@ -4,6 +4,7 @@ use slab;
|
||||
|
||||
use indexmap::{self, IndexMap};
|
||||
|
||||
use std::convert::Infallible;
|
||||
use std::fmt;
|
||||
use std::marker::PhantomData;
|
||||
use std::ops;
|
||||
@@ -128,7 +129,20 @@ impl Store {
|
||||
}
|
||||
}
|
||||
|
||||
pub fn for_each<F, E>(&mut self, mut f: F) -> Result<(), E>
|
||||
pub(crate) fn for_each<F>(&mut self, mut f: F)
|
||||
where
|
||||
F: FnMut(Ptr),
|
||||
{
|
||||
match self.try_for_each(|ptr| {
|
||||
f(ptr);
|
||||
Ok::<_, Infallible>(())
|
||||
}) {
|
||||
Ok(()) => (),
|
||||
Err(infallible) => match infallible {},
|
||||
}
|
||||
}
|
||||
|
||||
pub fn try_for_each<F, E>(&mut self, mut f: F) -> Result<(), E>
|
||||
where
|
||||
F: FnMut(Ptr) -> Result<(), E>,
|
||||
{
|
||||
|
||||
@@ -2,6 +2,7 @@ use super::recv::RecvHeaderBlockError;
|
||||
use super::store::{self, Entry, Resolve, Store};
|
||||
use super::{Buffer, Config, Counts, Prioritized, Recv, Send, Stream, StreamId};
|
||||
use crate::codec::{Codec, SendError, UserError};
|
||||
use crate::ext::Protocol;
|
||||
use crate::frame::{self, Frame, Reason};
|
||||
use crate::proto::{peer, Error, Initiator, Open, Peer, WindowSize};
|
||||
use crate::{client, proto, server};
|
||||
@@ -214,6 +215,8 @@ where
|
||||
use super::stream::ContentLength;
|
||||
use http::Method;
|
||||
|
||||
let protocol = request.extensions_mut().remove::<Protocol>();
|
||||
|
||||
// Clear before taking lock, incase extensions contain a StreamRef.
|
||||
request.extensions_mut().clear();
|
||||
|
||||
@@ -261,7 +264,8 @@ where
|
||||
}
|
||||
|
||||
// Convert the message
|
||||
let headers = client::Peer::convert_send_message(stream_id, request, end_of_stream)?;
|
||||
let headers =
|
||||
client::Peer::convert_send_message(stream_id, request, protocol, end_of_stream)?;
|
||||
|
||||
let mut stream = me.store.insert(stream.id, stream);
|
||||
|
||||
@@ -294,6 +298,15 @@ where
|
||||
send_buffer: self.send_buffer.clone(),
|
||||
})
|
||||
}
|
||||
|
||||
pub(crate) fn is_extended_connect_protocol_enabled(&self) -> bool {
|
||||
self.inner
|
||||
.lock()
|
||||
.unwrap()
|
||||
.actions
|
||||
.send
|
||||
.is_extended_connect_protocol_enabled()
|
||||
}
|
||||
}
|
||||
|
||||
impl<B> DynStreams<'_, B> {
|
||||
@@ -643,15 +656,12 @@ impl Inner {
|
||||
|
||||
let last_processed_id = actions.recv.last_processed_id();
|
||||
|
||||
self.store
|
||||
.for_each(|stream| {
|
||||
counts.transition(stream, |counts, stream| {
|
||||
actions.recv.handle_error(&err, &mut *stream);
|
||||
actions.send.handle_error(send_buffer, stream, counts);
|
||||
Ok::<_, ()>(())
|
||||
})
|
||||
self.store.for_each(|stream| {
|
||||
counts.transition(stream, |counts, stream| {
|
||||
actions.recv.handle_error(&err, &mut *stream);
|
||||
actions.send.handle_error(send_buffer, stream, counts);
|
||||
})
|
||||
.unwrap();
|
||||
});
|
||||
|
||||
actions.conn_error = Some(err);
|
||||
|
||||
@@ -674,19 +684,14 @@ impl Inner {
|
||||
|
||||
let err = Error::remote_go_away(frame.debug_data().clone(), frame.reason());
|
||||
|
||||
self.store
|
||||
.for_each(|stream| {
|
||||
if stream.id > last_stream_id {
|
||||
counts.transition(stream, |counts, stream| {
|
||||
actions.recv.handle_error(&err, &mut *stream);
|
||||
actions.send.handle_error(send_buffer, stream, counts);
|
||||
Ok::<_, ()>(())
|
||||
})
|
||||
} else {
|
||||
Ok::<_, ()>(())
|
||||
}
|
||||
})
|
||||
.unwrap();
|
||||
self.store.for_each(|stream| {
|
||||
if stream.id > last_stream_id {
|
||||
counts.transition(stream, |counts, stream| {
|
||||
actions.recv.handle_error(&err, &mut *stream);
|
||||
actions.send.handle_error(send_buffer, stream, counts);
|
||||
})
|
||||
}
|
||||
});
|
||||
|
||||
actions.conn_error = Some(err);
|
||||
|
||||
@@ -807,18 +812,15 @@ impl Inner {
|
||||
|
||||
tracing::trace!("Streams::recv_eof");
|
||||
|
||||
self.store
|
||||
.for_each(|stream| {
|
||||
counts.transition(stream, |counts, stream| {
|
||||
actions.recv.recv_eof(stream);
|
||||
self.store.for_each(|stream| {
|
||||
counts.transition(stream, |counts, stream| {
|
||||
actions.recv.recv_eof(stream);
|
||||
|
||||
// This handles resetting send state associated with the
|
||||
// stream
|
||||
actions.send.handle_error(send_buffer, stream, counts);
|
||||
Ok::<_, ()>(())
|
||||
})
|
||||
// This handles resetting send state associated with the
|
||||
// stream
|
||||
actions.send.handle_error(send_buffer, stream, counts);
|
||||
})
|
||||
.expect("recv_eof");
|
||||
});
|
||||
|
||||
actions.clear_queues(clear_pending_accept, &mut self.store, counts);
|
||||
Ok(())
|
||||
|
||||
@@ -470,6 +470,19 @@ where
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Enables the [extended CONNECT protocol].
|
||||
///
|
||||
/// [extended CONNECT protocol]: https://datatracker.ietf.org/doc/html/rfc8441#section-4
|
||||
///
|
||||
/// # Errors
|
||||
///
|
||||
/// Returns an error if a previous call is still pending acknowledgement
|
||||
/// from the remote endpoint.
|
||||
pub fn enable_connect_protocol(&mut self) -> Result<(), crate::Error> {
|
||||
self.connection.set_enable_connect_protocol()?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Returns `Ready` when the underlying connection has closed.
|
||||
///
|
||||
/// If any new inbound streams are received during a call to `poll_closed`,
|
||||
@@ -904,6 +917,14 @@ impl Builder {
|
||||
self
|
||||
}
|
||||
|
||||
/// Enables the [extended CONNECT protocol].
|
||||
///
|
||||
/// [extended CONNECT protocol]: https://datatracker.ietf.org/doc/html/rfc8441#section-4
|
||||
pub fn enable_connect_protocol(&mut self) -> &mut Self {
|
||||
self.settings.set_enable_connect_protocol(Some(1));
|
||||
self
|
||||
}
|
||||
|
||||
/// Creates a new configured HTTP/2 server backed by `io`.
|
||||
///
|
||||
/// It is expected that `io` already be in an appropriate state to commence
|
||||
@@ -1360,7 +1381,7 @@ impl Peer {
|
||||
_,
|
||||
) = request.into_parts();
|
||||
|
||||
let pseudo = Pseudo::request(method, uri);
|
||||
let pseudo = Pseudo::request(method, uri, None);
|
||||
|
||||
Ok(frame::PushPromise::new(
|
||||
stream_id,
|
||||
@@ -1410,6 +1431,11 @@ impl proto::Peer for Peer {
|
||||
malformed!("malformed headers: missing method");
|
||||
}
|
||||
|
||||
let has_protocol = pseudo.protocol.is_some();
|
||||
if !is_connect && has_protocol {
|
||||
malformed!("malformed headers: :protocol on non-CONNECT request");
|
||||
}
|
||||
|
||||
if pseudo.status.is_some() {
|
||||
malformed!("malformed headers: :status field on request");
|
||||
}
|
||||
@@ -1432,7 +1458,7 @@ impl proto::Peer for Peer {
|
||||
|
||||
// A :scheme is required, except CONNECT.
|
||||
if let Some(scheme) = pseudo.scheme {
|
||||
if is_connect {
|
||||
if is_connect && !has_protocol {
|
||||
malformed!(":scheme in CONNECT");
|
||||
}
|
||||
let maybe_scheme = scheme.parse();
|
||||
@@ -1450,12 +1476,12 @@ impl proto::Peer for Peer {
|
||||
if parts.authority.is_some() {
|
||||
parts.scheme = Some(scheme);
|
||||
}
|
||||
} else if !is_connect {
|
||||
} else if !is_connect || has_protocol {
|
||||
malformed!("malformed headers: missing scheme");
|
||||
}
|
||||
|
||||
if let Some(path) = pseudo.path {
|
||||
if is_connect {
|
||||
if is_connect && !has_protocol {
|
||||
malformed!(":path in CONNECT");
|
||||
}
|
||||
|
||||
@@ -1468,6 +1494,8 @@ impl proto::Peer for Peer {
|
||||
parts.path_and_query = Some(maybe_path.or_else(|why| {
|
||||
malformed!("malformed headers: malformed path ({:?}): {}", path, why,)
|
||||
})?);
|
||||
} else if is_connect && has_protocol {
|
||||
malformed!("malformed headers: missing path in extended CONNECT");
|
||||
}
|
||||
|
||||
b = b.uri(parts);
|
||||
|
||||
@@ -47,6 +47,16 @@ macro_rules! assert_settings {
|
||||
}};
|
||||
}
|
||||
|
||||
#[macro_export]
|
||||
macro_rules! assert_go_away {
|
||||
($frame:expr) => {{
|
||||
match $frame {
|
||||
h2::frame::Frame::GoAway(v) => v,
|
||||
f => panic!("expected GO_AWAY; actual={:?}", f),
|
||||
}
|
||||
}};
|
||||
}
|
||||
|
||||
#[macro_export]
|
||||
macro_rules! poll_err {
|
||||
($transport:expr) => {{
|
||||
@@ -80,6 +90,7 @@ macro_rules! assert_default_settings {
|
||||
|
||||
use h2::frame::Frame;
|
||||
|
||||
#[track_caller]
|
||||
pub fn assert_frame_eq<T: Into<Frame>, U: Into<Frame>>(t: T, u: U) {
|
||||
let actual: Frame = t.into();
|
||||
let expected: Frame = u.into();
|
||||
|
||||
@@ -4,7 +4,10 @@ use std::fmt;
|
||||
use bytes::Bytes;
|
||||
use http::{self, HeaderMap, StatusCode};
|
||||
|
||||
use h2::frame::{self, Frame, StreamId};
|
||||
use h2::{
|
||||
ext::Protocol,
|
||||
frame::{self, Frame, StreamId},
|
||||
};
|
||||
|
||||
pub const SETTINGS: &'static [u8] = &[0, 0, 0, 4, 0, 0, 0, 0, 0];
|
||||
pub const SETTINGS_ACK: &'static [u8] = &[0, 0, 0, 4, 1, 0, 0, 0, 0];
|
||||
@@ -109,7 +112,9 @@ impl Mock<frame::Headers> {
|
||||
let method = method.try_into().unwrap();
|
||||
let uri = uri.try_into().unwrap();
|
||||
let (id, _, fields) = self.into_parts();
|
||||
let frame = frame::Headers::new(id, frame::Pseudo::request(method, uri), fields);
|
||||
let extensions = Default::default();
|
||||
let pseudo = frame::Pseudo::request(method, uri, extensions);
|
||||
let frame = frame::Headers::new(id, pseudo, fields);
|
||||
Mock(frame)
|
||||
}
|
||||
|
||||
@@ -179,6 +184,15 @@ impl Mock<frame::Headers> {
|
||||
Mock(frame::Headers::new(id, pseudo, fields))
|
||||
}
|
||||
|
||||
pub fn protocol(self, value: &str) -> Self {
|
||||
let (id, mut pseudo, fields) = self.into_parts();
|
||||
let value = Protocol::from(value);
|
||||
|
||||
pseudo.set_protocol(value);
|
||||
|
||||
Mock(frame::Headers::new(id, pseudo, fields))
|
||||
}
|
||||
|
||||
pub fn eos(mut self) -> Self {
|
||||
self.0.set_end_stream();
|
||||
self
|
||||
@@ -230,8 +244,9 @@ impl Mock<frame::PushPromise> {
|
||||
let method = method.try_into().unwrap();
|
||||
let uri = uri.try_into().unwrap();
|
||||
let (id, promised, _, fields) = self.into_parts();
|
||||
let frame =
|
||||
frame::PushPromise::new(id, promised, frame::Pseudo::request(method, uri), fields);
|
||||
let extensions = Default::default();
|
||||
let pseudo = frame::Pseudo::request(method, uri, extensions);
|
||||
let frame = frame::PushPromise::new(id, promised, pseudo, fields);
|
||||
Mock(frame)
|
||||
}
|
||||
|
||||
@@ -352,6 +367,11 @@ impl Mock<frame::Settings> {
|
||||
self.0.set_enable_push(false);
|
||||
self
|
||||
}
|
||||
|
||||
pub fn enable_connect_protocol(mut self, val: u32) -> Self {
|
||||
self.0.set_enable_connect_protocol(Some(val));
|
||||
self
|
||||
}
|
||||
}
|
||||
|
||||
impl From<Mock<frame::Settings>> for frame::Settings {
|
||||
|
||||
@@ -221,22 +221,15 @@ impl Handle {
|
||||
let settings = settings.into();
|
||||
self.send(settings.into()).await.unwrap();
|
||||
|
||||
let frame = self.next().await;
|
||||
let settings = match frame {
|
||||
Some(frame) => match frame.unwrap() {
|
||||
Frame::Settings(settings) => {
|
||||
// Send the ACK
|
||||
let ack = frame::Settings::ack();
|
||||
let frame = self.next().await.expect("unexpected EOF").unwrap();
|
||||
let settings = assert_settings!(frame);
|
||||
|
||||
// TODO: Don't unwrap?
|
||||
self.send(ack.into()).await.unwrap();
|
||||
// Send the ACK
|
||||
let ack = frame::Settings::ack();
|
||||
|
||||
// TODO: Don't unwrap?
|
||||
self.send(ack.into()).await.unwrap();
|
||||
|
||||
settings
|
||||
}
|
||||
frame => panic!("unexpected frame; frame={:?}", frame),
|
||||
},
|
||||
None => panic!("unexpected EOF"),
|
||||
};
|
||||
let frame = self.next().await;
|
||||
let f = assert_settings!(frame.unwrap().unwrap());
|
||||
|
||||
|
||||
@@ -2,6 +2,7 @@
|
||||
pub use h2;
|
||||
|
||||
pub use h2::client;
|
||||
pub use h2::ext::Protocol;
|
||||
pub use h2::frame::StreamId;
|
||||
pub use h2::server;
|
||||
pub use h2::*;
|
||||
@@ -20,8 +21,8 @@ pub use super::{Codec, SendFrame};
|
||||
|
||||
// Re-export macros
|
||||
pub use super::{
|
||||
assert_closed, assert_data, assert_default_settings, assert_headers, assert_ping, poll_err,
|
||||
poll_frame, raw_codec,
|
||||
assert_closed, assert_data, assert_default_settings, assert_go_away, assert_headers,
|
||||
assert_ping, assert_settings, poll_err, poll_frame, raw_codec,
|
||||
};
|
||||
|
||||
pub use super::assert::assert_frame_eq;
|
||||
|
||||
@@ -1305,6 +1305,153 @@ async fn informational_while_local_streaming() {
|
||||
join(srv, h2).await;
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn extended_connect_protocol_disabled_by_default() {
|
||||
h2_support::trace_init!();
|
||||
let (io, mut srv) = mock::new();
|
||||
|
||||
let srv = async move {
|
||||
let settings = srv.assert_client_handshake().await;
|
||||
assert_default_settings!(settings);
|
||||
|
||||
srv.recv_frame(
|
||||
frames::headers(1)
|
||||
.request("GET", "https://example.com/")
|
||||
.eos(),
|
||||
)
|
||||
.await;
|
||||
srv.send_frame(frames::headers(1).response(200).eos()).await;
|
||||
};
|
||||
|
||||
let h2 = async move {
|
||||
let (mut client, mut h2) = client::handshake(io).await.unwrap();
|
||||
|
||||
// we send a simple req here just to drive the connection so we can
|
||||
// receive the server settings.
|
||||
let request = Request::get("https://example.com/").body(()).unwrap();
|
||||
// first request is allowed
|
||||
let (response, _) = client.send_request(request, true).unwrap();
|
||||
h2.drive(response).await.unwrap();
|
||||
|
||||
assert!(!client.is_extended_connect_protocol_enabled());
|
||||
};
|
||||
|
||||
join(srv, h2).await;
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn extended_connect_protocol_enabled_during_handshake() {
|
||||
h2_support::trace_init!();
|
||||
let (io, mut srv) = mock::new();
|
||||
|
||||
let srv = async move {
|
||||
let settings = srv
|
||||
.assert_client_handshake_with_settings(frames::settings().enable_connect_protocol(1))
|
||||
.await;
|
||||
assert_default_settings!(settings);
|
||||
|
||||
srv.recv_frame(
|
||||
frames::headers(1)
|
||||
.request("GET", "https://example.com/")
|
||||
.eos(),
|
||||
)
|
||||
.await;
|
||||
srv.send_frame(frames::headers(1).response(200).eos()).await;
|
||||
};
|
||||
|
||||
let h2 = async move {
|
||||
let (mut client, mut h2) = client::handshake(io).await.unwrap();
|
||||
|
||||
// we send a simple req here just to drive the connection so we can
|
||||
// receive the server settings.
|
||||
let request = Request::get("https://example.com/").body(()).unwrap();
|
||||
let (response, _) = client.send_request(request, true).unwrap();
|
||||
h2.drive(response).await.unwrap();
|
||||
|
||||
assert!(client.is_extended_connect_protocol_enabled());
|
||||
};
|
||||
|
||||
join(srv, h2).await;
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn invalid_connect_protocol_enabled_setting() {
|
||||
h2_support::trace_init!();
|
||||
|
||||
let (io, mut srv) = mock::new();
|
||||
|
||||
let srv = async move {
|
||||
// Send a settings frame
|
||||
srv.send(frames::settings().enable_connect_protocol(2).into())
|
||||
.await
|
||||
.unwrap();
|
||||
srv.read_preface().await.unwrap();
|
||||
|
||||
let settings = assert_settings!(srv.next().await.expect("unexpected EOF").unwrap());
|
||||
assert_default_settings!(settings);
|
||||
|
||||
// Send the ACK
|
||||
let ack = frame::Settings::ack();
|
||||
|
||||
// TODO: Don't unwrap?
|
||||
srv.send(ack.into()).await.unwrap();
|
||||
|
||||
let frame = srv.next().await.unwrap().unwrap();
|
||||
let go_away = assert_go_away!(frame);
|
||||
assert_eq!(go_away.reason(), Reason::PROTOCOL_ERROR);
|
||||
};
|
||||
|
||||
let h2 = async move {
|
||||
let (mut client, mut h2) = client::handshake(io).await.unwrap();
|
||||
|
||||
// we send a simple req here just to drive the connection so we can
|
||||
// receive the server settings.
|
||||
let request = Request::get("https://example.com/").body(()).unwrap();
|
||||
let (response, _) = client.send_request(request, true).unwrap();
|
||||
|
||||
let error = h2.drive(response).await.unwrap_err();
|
||||
assert_eq!(error.reason(), Some(Reason::PROTOCOL_ERROR));
|
||||
};
|
||||
|
||||
join(srv, h2).await;
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn extended_connect_request() {
|
||||
h2_support::trace_init!();
|
||||
|
||||
let (io, mut srv) = mock::new();
|
||||
|
||||
let srv = async move {
|
||||
let settings = srv
|
||||
.assert_client_handshake_with_settings(frames::settings().enable_connect_protocol(1))
|
||||
.await;
|
||||
assert_default_settings!(settings);
|
||||
|
||||
srv.recv_frame(
|
||||
frames::headers(1)
|
||||
.request("CONNECT", "http://bread/baguette")
|
||||
.protocol("the-bread-protocol")
|
||||
.eos(),
|
||||
)
|
||||
.await;
|
||||
srv.send_frame(frames::headers(1).response(200).eos()).await;
|
||||
};
|
||||
|
||||
let h2 = async move {
|
||||
let (mut client, mut h2) = client::handshake(io).await.unwrap();
|
||||
|
||||
let request = Request::connect("http://bread/baguette")
|
||||
.extension(Protocol::from("the-bread-protocol"))
|
||||
.body(())
|
||||
.unwrap();
|
||||
let (response, _) = client.send_request(request, true).unwrap();
|
||||
h2.drive(response).await.unwrap();
|
||||
};
|
||||
|
||||
join(srv, h2).await;
|
||||
}
|
||||
|
||||
const SETTINGS: &'static [u8] = &[0, 0, 0, 4, 0, 0, 0, 0, 0];
|
||||
const SETTINGS_ACK: &'static [u8] = &[0, 0, 0, 4, 1, 0, 0, 0, 0];
|
||||
|
||||
|
||||
@@ -1149,3 +1149,191 @@ async fn send_reset_explicitly() {
|
||||
|
||||
join(client, srv).await;
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn extended_connect_protocol_disabled_by_default() {
|
||||
h2_support::trace_init!();
|
||||
|
||||
let (io, mut client) = mock::new();
|
||||
|
||||
let client = async move {
|
||||
let settings = client.assert_server_handshake().await;
|
||||
|
||||
assert_eq!(settings.is_extended_connect_protocol_enabled(), None);
|
||||
|
||||
client
|
||||
.send_frame(
|
||||
frames::headers(1)
|
||||
.request("CONNECT", "http://bread/baguette")
|
||||
.protocol("the-bread-protocol"),
|
||||
)
|
||||
.await;
|
||||
|
||||
client.recv_frame(frames::reset(1).protocol_error()).await;
|
||||
};
|
||||
|
||||
let srv = async move {
|
||||
let mut srv = server::handshake(io).await.expect("handshake");
|
||||
|
||||
poll_fn(move |cx| srv.poll_closed(cx))
|
||||
.await
|
||||
.expect("server");
|
||||
};
|
||||
|
||||
join(client, srv).await;
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn extended_connect_protocol_enabled_during_handshake() {
|
||||
h2_support::trace_init!();
|
||||
|
||||
let (io, mut client) = mock::new();
|
||||
|
||||
let client = async move {
|
||||
let settings = client.assert_server_handshake().await;
|
||||
|
||||
assert_eq!(settings.is_extended_connect_protocol_enabled(), Some(true));
|
||||
|
||||
client
|
||||
.send_frame(
|
||||
frames::headers(1)
|
||||
.request("CONNECT", "http://bread/baguette")
|
||||
.protocol("the-bread-protocol"),
|
||||
)
|
||||
.await;
|
||||
|
||||
client.recv_frame(frames::headers(1).response(200)).await;
|
||||
};
|
||||
|
||||
let srv = async move {
|
||||
let mut builder = server::Builder::new();
|
||||
|
||||
builder.enable_connect_protocol();
|
||||
|
||||
let mut srv = builder.handshake::<_, Bytes>(io).await.expect("handshake");
|
||||
|
||||
let (_req, mut stream) = srv.next().await.unwrap().unwrap();
|
||||
|
||||
let rsp = Response::new(());
|
||||
stream.send_response(rsp, false).unwrap();
|
||||
|
||||
poll_fn(move |cx| srv.poll_closed(cx))
|
||||
.await
|
||||
.expect("server");
|
||||
};
|
||||
|
||||
join(client, srv).await;
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn reject_pseudo_protocol_on_non_connect_request() {
|
||||
h2_support::trace_init!();
|
||||
|
||||
let (io, mut client) = mock::new();
|
||||
|
||||
let client = async move {
|
||||
let settings = client.assert_server_handshake().await;
|
||||
|
||||
assert_eq!(settings.is_extended_connect_protocol_enabled(), Some(true));
|
||||
|
||||
client
|
||||
.send_frame(
|
||||
frames::headers(1)
|
||||
.request("GET", "http://bread/baguette")
|
||||
.protocol("the-bread-protocol"),
|
||||
)
|
||||
.await;
|
||||
|
||||
client.recv_frame(frames::reset(1).protocol_error()).await;
|
||||
};
|
||||
|
||||
let srv = async move {
|
||||
let mut builder = server::Builder::new();
|
||||
|
||||
builder.enable_connect_protocol();
|
||||
|
||||
let mut srv = builder.handshake::<_, Bytes>(io).await.expect("handshake");
|
||||
|
||||
assert!(srv.next().await.is_none());
|
||||
|
||||
poll_fn(move |cx| srv.poll_closed(cx))
|
||||
.await
|
||||
.expect("server");
|
||||
};
|
||||
|
||||
join(client, srv).await;
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn reject_authority_target_on_extended_connect_request() {
|
||||
h2_support::trace_init!();
|
||||
|
||||
let (io, mut client) = mock::new();
|
||||
|
||||
let client = async move {
|
||||
let settings = client.assert_server_handshake().await;
|
||||
|
||||
assert_eq!(settings.is_extended_connect_protocol_enabled(), Some(true));
|
||||
|
||||
client
|
||||
.send_frame(
|
||||
frames::headers(1)
|
||||
.request("CONNECT", "bread:80")
|
||||
.protocol("the-bread-protocol"),
|
||||
)
|
||||
.await;
|
||||
|
||||
client.recv_frame(frames::reset(1).protocol_error()).await;
|
||||
};
|
||||
|
||||
let srv = async move {
|
||||
let mut builder = server::Builder::new();
|
||||
|
||||
builder.enable_connect_protocol();
|
||||
|
||||
let mut srv = builder.handshake::<_, Bytes>(io).await.expect("handshake");
|
||||
|
||||
assert!(srv.next().await.is_none());
|
||||
|
||||
poll_fn(move |cx| srv.poll_closed(cx))
|
||||
.await
|
||||
.expect("server");
|
||||
};
|
||||
|
||||
join(client, srv).await;
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn reject_non_authority_target_on_connect_request() {
|
||||
h2_support::trace_init!();
|
||||
|
||||
let (io, mut client) = mock::new();
|
||||
|
||||
let client = async move {
|
||||
let settings = client.assert_server_handshake().await;
|
||||
|
||||
assert_eq!(settings.is_extended_connect_protocol_enabled(), Some(true));
|
||||
|
||||
client
|
||||
.send_frame(frames::headers(1).request("CONNECT", "https://bread/baguette"))
|
||||
.await;
|
||||
|
||||
client.recv_frame(frames::reset(1).protocol_error()).await;
|
||||
};
|
||||
|
||||
let srv = async move {
|
||||
let mut builder = server::Builder::new();
|
||||
|
||||
builder.enable_connect_protocol();
|
||||
|
||||
let mut srv = builder.handshake::<_, Bytes>(io).await.expect("handshake");
|
||||
|
||||
assert!(srv.next().await.is_none());
|
||||
|
||||
poll_fn(move |cx| srv.poll_closed(cx))
|
||||
.await
|
||||
.expect("server");
|
||||
};
|
||||
|
||||
join(client, srv).await;
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user