Implement the extended CONNECT protocol from RFC 8441 (#565)

This commit is contained in:
Anthony Ramine
2021-11-24 10:05:10 +01:00
committed by GitHub
parent dbaa3a4285
commit 87969c1f29
22 changed files with 694 additions and 120 deletions

View File

@@ -136,6 +136,7 @@
//! [`Error`]: ../struct.Error.html
use crate::codec::{Codec, SendError, UserError};
use crate::ext::Protocol;
use crate::frame::{Headers, Pseudo, Reason, Settings, StreamId};
use crate::proto::{self, Error};
use crate::{FlowControl, PingPong, RecvStream, SendStream};
@@ -517,6 +518,19 @@ where
(response, stream)
})
}
/// Returns whether the [extended CONNECT protocol][1] is enabled or not.
///
/// This setting is configured by the server peer by sending the
/// [`SETTINGS_ENABLE_CONNECT_PROTOCOL` parameter][2] in a `SETTINGS` frame.
/// This method returns the currently acknowledged value recieved from the
/// remote.
///
/// [1]: https://datatracker.ietf.org/doc/html/rfc8441#section-4
/// [2]: https://datatracker.ietf.org/doc/html/rfc8441#section-3
pub fn is_extended_connect_protocol_enabled(&self) -> bool {
self.inner.is_extended_connect_protocol_enabled()
}
}
impl<B> fmt::Debug for SendRequest<B>
@@ -1246,11 +1260,10 @@ where
/// This method returns the currently acknowledged value recieved from the
/// remote.
///
/// [settings]: https://tools.ietf.org/html/rfc7540#section-5.1.2
/// [1]: https://tools.ietf.org/html/rfc7540#section-5.1.2
pub fn max_concurrent_send_streams(&self) -> usize {
self.inner.max_send_streams()
}
/// Returns the maximum number of concurrent streams that may be initiated
/// by the server on this connection.
///
@@ -1416,6 +1429,7 @@ impl Peer {
pub fn convert_send_message(
id: StreamId,
request: Request<()>,
protocol: Option<Protocol>,
end_of_stream: bool,
) -> Result<Headers, SendError> {
use http::request::Parts;
@@ -1435,7 +1449,7 @@ impl Peer {
// Build the set pseudo header set. All requests will include `method`
// and `path`.
let mut pseudo = Pseudo::request(method, uri);
let mut pseudo = Pseudo::request(method, uri, protocol);
if pseudo.scheme.is_none() {
// If the scheme is not set, then there are a two options.

55
src/ext.rs Normal file
View File

@@ -0,0 +1,55 @@
//! Extensions specific to the HTTP/2 protocol.
use crate::hpack::BytesStr;
use bytes::Bytes;
use std::fmt;
/// Represents the `:protocol` pseudo-header used by
/// the [Extended CONNECT Protocol].
///
/// [Extended CONNECT Protocol]: https://datatracker.ietf.org/doc/html/rfc8441#section-4
#[derive(Clone, Eq, PartialEq)]
pub struct Protocol {
value: BytesStr,
}
impl Protocol {
/// Converts a static string to a protocol name.
pub const fn from_static(value: &'static str) -> Self {
Self {
value: BytesStr::from_static(value),
}
}
/// Returns a str representation of the header.
pub fn as_str(&self) -> &str {
self.value.as_str()
}
pub(crate) fn try_from(bytes: Bytes) -> Result<Self, std::str::Utf8Error> {
Ok(Self {
value: BytesStr::try_from(bytes)?,
})
}
}
impl<'a> From<&'a str> for Protocol {
fn from(value: &'a str) -> Self {
Self {
value: BytesStr::from(value),
}
}
}
impl AsRef<[u8]> for Protocol {
fn as_ref(&self) -> &[u8] {
self.value.as_ref()
}
}
impl fmt::Debug for Protocol {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
self.value.fmt(f)
}
}

View File

@@ -1,4 +1,5 @@
use super::{util, StreamDependency, StreamId};
use crate::ext::Protocol;
use crate::frame::{Error, Frame, Head, Kind};
use crate::hpack::{self, BytesStr};
@@ -66,6 +67,7 @@ pub struct Pseudo {
pub scheme: Option<BytesStr>,
pub authority: Option<BytesStr>,
pub path: Option<BytesStr>,
pub protocol: Option<Protocol>,
// Response
pub status: Option<StatusCode>,
@@ -292,6 +294,10 @@ impl fmt::Debug for Headers {
.field("stream_id", &self.stream_id)
.field("flags", &self.flags);
if let Some(ref protocol) = self.header_block.pseudo.protocol {
builder.field("protocol", protocol);
}
if let Some(ref dep) = self.stream_dep {
builder.field("stream_dep", dep);
}
@@ -529,7 +535,7 @@ impl Continuation {
// ===== impl Pseudo =====
impl Pseudo {
pub fn request(method: Method, uri: Uri) -> Self {
pub fn request(method: Method, uri: Uri, protocol: Option<Protocol>) -> Self {
let parts = uri::Parts::from(uri);
let mut path = parts
@@ -550,6 +556,7 @@ impl Pseudo {
scheme: None,
authority: None,
path: Some(path).filter(|p| !p.is_empty()),
protocol,
status: None,
};
@@ -575,6 +582,7 @@ impl Pseudo {
scheme: None,
authority: None,
path: None,
protocol: None,
status: Some(status),
}
}
@@ -593,6 +601,11 @@ impl Pseudo {
self.scheme = Some(bytes_str);
}
#[cfg(feature = "unstable")]
pub fn set_protocol(&mut self, protocol: Protocol) {
self.protocol = Some(protocol);
}
pub fn set_authority(&mut self, authority: BytesStr) {
self.authority = Some(authority);
}
@@ -681,6 +694,10 @@ impl Iterator for Iter {
return Some(Path(path));
}
if let Some(protocol) = pseudo.protocol.take() {
return Some(Protocol(protocol));
}
if let Some(status) = pseudo.status.take() {
return Some(Status(status));
}
@@ -879,6 +896,7 @@ impl HeaderBlock {
Method(v) => set_pseudo!(method, v),
Scheme(v) => set_pseudo!(scheme, v),
Path(v) => set_pseudo!(path, v),
Protocol(v) => set_pseudo!(protocol, v),
Status(v) => set_pseudo!(status, v),
}
});

View File

@@ -13,6 +13,7 @@ pub struct Settings {
initial_window_size: Option<u32>,
max_frame_size: Option<u32>,
max_header_list_size: Option<u32>,
enable_connect_protocol: Option<u32>,
}
/// An enum that lists all valid settings that can be sent in a SETTINGS
@@ -27,6 +28,7 @@ pub enum Setting {
InitialWindowSize(u32),
MaxFrameSize(u32),
MaxHeaderListSize(u32),
EnableConnectProtocol(u32),
}
#[derive(Copy, Clone, Eq, PartialEq, Default)]
@@ -107,6 +109,14 @@ impl Settings {
self.enable_push = Some(enable as u32);
}
pub fn is_extended_connect_protocol_enabled(&self) -> Option<bool> {
self.enable_connect_protocol.map(|val| val != 0)
}
pub fn set_enable_connect_protocol(&mut self, val: Option<u32>) {
self.enable_connect_protocol = val;
}
pub fn header_table_size(&self) -> Option<u32> {
self.header_table_size
}
@@ -181,6 +191,14 @@ impl Settings {
Some(MaxHeaderListSize(val)) => {
settings.max_header_list_size = Some(val);
}
Some(EnableConnectProtocol(val)) => match val {
0 | 1 => {
settings.enable_connect_protocol = Some(val);
}
_ => {
return Err(Error::InvalidSettingValue);
}
},
None => {}
}
}
@@ -236,6 +254,10 @@ impl Settings {
if let Some(v) = self.max_header_list_size {
f(MaxHeaderListSize(v));
}
if let Some(v) = self.enable_connect_protocol {
f(EnableConnectProtocol(v));
}
}
}
@@ -269,6 +291,9 @@ impl fmt::Debug for Settings {
Setting::MaxHeaderListSize(v) => {
builder.field("max_header_list_size", &v);
}
Setting::EnableConnectProtocol(v) => {
builder.field("enable_connect_protocol", &v);
}
});
builder.finish()
@@ -291,6 +316,7 @@ impl Setting {
4 => Some(InitialWindowSize(val)),
5 => Some(MaxFrameSize(val)),
6 => Some(MaxHeaderListSize(val)),
8 => Some(EnableConnectProtocol(val)),
_ => None,
}
}
@@ -322,6 +348,7 @@ impl Setting {
InitialWindowSize(v) => (4, v),
MaxFrameSize(v) => (5, v),
MaxHeaderListSize(v) => (6, v),
EnableConnectProtocol(v) => (8, v),
};
dst.put_u16(kind);

View File

@@ -1,4 +1,5 @@
use super::{DecoderError, NeedMore};
use crate::ext::Protocol;
use bytes::Bytes;
use http::header::{HeaderName, HeaderValue};
@@ -14,6 +15,7 @@ pub enum Header<T = HeaderName> {
Method(Method),
Scheme(BytesStr),
Path(BytesStr),
Protocol(Protocol),
Status(StatusCode),
}
@@ -25,6 +27,7 @@ pub enum Name<'a> {
Method,
Scheme,
Path,
Protocol,
Status,
}
@@ -51,6 +54,7 @@ impl Header<Option<HeaderName>> {
Method(v) => Method(v),
Scheme(v) => Scheme(v),
Path(v) => Path(v),
Protocol(v) => Protocol(v),
Status(v) => Status(v),
})
}
@@ -79,6 +83,10 @@ impl Header {
let value = BytesStr::try_from(value)?;
Ok(Header::Path(value))
}
b"protocol" => {
let value = Protocol::try_from(value)?;
Ok(Header::Protocol(value))
}
b"status" => {
let status = StatusCode::from_bytes(&value)?;
Ok(Header::Status(status))
@@ -104,6 +112,7 @@ impl Header {
Header::Method(ref v) => 32 + 7 + v.as_ref().len(),
Header::Scheme(ref v) => 32 + 7 + v.len(),
Header::Path(ref v) => 32 + 5 + v.len(),
Header::Protocol(ref v) => 32 + 9 + v.as_str().len(),
Header::Status(_) => 32 + 7 + 3,
}
}
@@ -116,6 +125,7 @@ impl Header {
Header::Method(..) => Name::Method,
Header::Scheme(..) => Name::Scheme,
Header::Path(..) => Name::Path,
Header::Protocol(..) => Name::Protocol,
Header::Status(..) => Name::Status,
}
}
@@ -127,6 +137,7 @@ impl Header {
Header::Method(ref v) => v.as_ref().as_ref(),
Header::Scheme(ref v) => v.as_ref(),
Header::Path(ref v) => v.as_ref(),
Header::Protocol(ref v) => v.as_ref(),
Header::Status(ref v) => v.as_str().as_ref(),
}
}
@@ -156,6 +167,10 @@ impl Header {
Header::Path(ref b) => a == b,
_ => false,
},
Header::Protocol(ref a) => match *other {
Header::Protocol(ref b) => a == b,
_ => false,
},
Header::Status(ref a) => match *other {
Header::Status(ref b) => a == b,
_ => false,
@@ -205,6 +220,7 @@ impl From<Header> for Header<Option<HeaderName>> {
Header::Method(v) => Header::Method(v),
Header::Scheme(v) => Header::Scheme(v),
Header::Path(v) => Header::Path(v),
Header::Protocol(v) => Header::Protocol(v),
Header::Status(v) => Header::Status(v),
}
}
@@ -221,6 +237,7 @@ impl<'a> Name<'a> {
Name::Method => Ok(Header::Method(Method::from_bytes(&*value)?)),
Name::Scheme => Ok(Header::Scheme(BytesStr::try_from(value)?)),
Name::Path => Ok(Header::Path(BytesStr::try_from(value)?)),
Name::Protocol => Ok(Header::Protocol(Protocol::try_from(value)?)),
Name::Status => {
match StatusCode::from_bytes(&value) {
Ok(status) => Ok(Header::Status(status)),
@@ -238,6 +255,7 @@ impl<'a> Name<'a> {
Name::Method => b":method",
Name::Scheme => b":scheme",
Name::Path => b":path",
Name::Protocol => b":protocol",
Name::Status => b":status",
}
}

View File

@@ -751,6 +751,7 @@ fn index_static(header: &Header) -> Option<(usize, bool)> {
"/index.html" => Some((5, true)),
_ => Some((4, false)),
},
Header::Protocol(..) => None,
Header::Status(ref v) => match u16::from(*v) {
200 => Some((8, true)),
204 => Some((9, true)),

View File

@@ -134,6 +134,7 @@ fn key_str(e: &Header) -> &str {
Header::Method(..) => ":method",
Header::Scheme(..) => ":scheme",
Header::Path(..) => ":path",
Header::Protocol(..) => ":protocol",
Header::Status(..) => ":status",
}
}
@@ -145,6 +146,7 @@ fn value_str(e: &Header) -> &str {
Header::Method(ref m) => m.as_str(),
Header::Scheme(ref v) => &**v,
Header::Path(ref v) => &**v,
Header::Protocol(ref v) => v.as_str(),
Header::Status(ref v) => v.as_str(),
}
}

View File

@@ -120,6 +120,7 @@ mod frame;
pub mod frame;
pub mod client;
pub mod ext;
pub mod server;
mod share;

View File

@@ -110,6 +110,10 @@ where
initial_max_send_streams: config.initial_max_send_streams,
local_next_stream_id: config.next_stream_id,
local_push_enabled: config.settings.is_push_enabled().unwrap_or(true),
extended_connect_protocol_enabled: config
.settings
.is_extended_connect_protocol_enabled()
.unwrap_or(false),
local_reset_duration: config.reset_stream_duration,
local_reset_max: config.reset_stream_max,
remote_init_window_sz: DEFAULT_INITIAL_WINDOW_SIZE,
@@ -147,6 +151,13 @@ where
self.inner.settings.send_settings(settings)
}
/// Send a new SETTINGS frame with extended CONNECT protocol enabled.
pub(crate) fn set_enable_connect_protocol(&mut self) -> Result<(), UserError> {
let mut settings = frame::Settings::default();
settings.set_enable_connect_protocol(Some(1));
self.inner.settings.send_settings(settings)
}
/// Returns the maximum number of concurrent streams that may be initiated
/// by this peer.
pub(crate) fn max_send_streams(&self) -> usize {

View File

@@ -117,6 +117,8 @@ impl Settings {
tracing::trace!("ACK sent; applying settings");
streams.apply_remote_settings(settings)?;
if let Some(val) = settings.header_table_size() {
dst.set_send_header_table_size(val as usize);
}
@@ -124,8 +126,6 @@ impl Settings {
if let Some(val) = settings.max_frame_size() {
dst.set_max_send_frame_size(val as usize);
}
streams.apply_remote_settings(settings)?;
}
self.remote = None;

View File

@@ -47,6 +47,9 @@ pub struct Config {
/// If the local peer is willing to receive push promises
pub local_push_enabled: bool,
/// If extended connect protocol is enabled.
pub extended_connect_protocol_enabled: bool,
/// How long a locally reset stream should ignore frames
pub local_reset_duration: Duration,

View File

@@ -56,6 +56,9 @@ pub(super) struct Recv {
/// If push promises are allowed to be received.
is_push_enabled: bool,
/// If extended connect protocol is enabled.
is_extended_connect_protocol_enabled: bool,
}
#[derive(Debug)]
@@ -103,6 +106,7 @@ impl Recv {
buffer: Buffer::new(),
refused: None,
is_push_enabled: config.local_push_enabled,
is_extended_connect_protocol_enabled: config.extended_connect_protocol_enabled,
}
}
@@ -216,6 +220,14 @@ impl Recv {
let stream_id = frame.stream_id();
let (pseudo, fields) = frame.into_parts();
if pseudo.protocol.is_some() {
if counts.peer().is_server() && !self.is_extended_connect_protocol_enabled {
proto_err!(stream: "cannot use :protocol if extended connect protocol is disabled; stream={:?}", stream.id);
return Err(Error::library_reset(stream.id, Reason::PROTOCOL_ERROR).into());
}
}
if !pseudo.is_informational() {
let message = counts
.peer()
@@ -449,12 +461,11 @@ impl Recv {
settings: &frame::Settings,
store: &mut Store,
) -> Result<(), proto::Error> {
let target = if let Some(val) = settings.initial_window_size() {
val
} else {
return Ok(());
};
if let Some(val) = settings.is_extended_connect_protocol_enabled() {
self.is_extended_connect_protocol_enabled = val;
}
if let Some(target) = settings.initial_window_size() {
let old_sz = self.init_window_sz;
self.init_window_sz = target;
@@ -483,13 +494,12 @@ impl Recv {
store.for_each(|mut stream| {
stream.recv_flow.dec_recv_window(dec);
Ok(())
})
} else if target > old_sz {
// We must increase the (local) window on every open stream.
let inc = target - old_sz;
tracing::trace!("incrementing all windows; inc={}", inc);
store.for_each(|mut stream| {
store.try_for_each(|mut stream| {
// XXX: Shouldn't the peer have already noticed our
// overflow and sent us a GOAWAY?
stream
@@ -497,14 +507,14 @@ impl Recv {
.inc_window(inc)
.map_err(proto::Error::library_go_away)?;
stream.recv_flow.assign_capacity(inc);
Ok(())
})
} else {
// size is the same... so do nothing
Ok(())
Ok::<_, proto::Error>(())
})?;
}
}
Ok(())
}
pub fn is_end_stream(&self, stream: &store::Ptr) -> bool {
if !stream.state.is_recv_closed() {
return false;

View File

@@ -35,6 +35,9 @@ pub(super) struct Send {
prioritize: Prioritize,
is_push_enabled: bool,
/// If extended connect protocol is enabled.
is_extended_connect_protocol_enabled: bool,
}
/// A value to detect which public API has called `poll_reset`.
@@ -53,6 +56,7 @@ impl Send {
next_stream_id: Ok(config.local_next_stream_id),
prioritize: Prioritize::new(config),
is_push_enabled: true,
is_extended_connect_protocol_enabled: false,
}
}
@@ -429,6 +433,10 @@ impl Send {
counts: &mut Counts,
task: &mut Option<Waker>,
) -> Result<(), Error> {
if let Some(val) = settings.is_extended_connect_protocol_enabled() {
self.is_extended_connect_protocol_enabled = val;
}
// Applies an update to the remote endpoint's initial window size.
//
// Per RFC 7540 §6.9.2:
@@ -490,16 +498,14 @@ impl Send {
// TODO: Should this notify the producer when the capacity
// of a stream is reduced? Maybe it should if the capacity
// is reduced to zero, allowing the producer to stop work.
Ok::<_, Error>(())
})?;
});
self.prioritize
.assign_connection_capacity(total_reclaimed, store, counts);
} else if val > old_val {
let inc = val - old_val;
store.for_each(|mut stream| {
store.try_for_each(|mut stream| {
self.recv_stream_window_update(inc, buffer, &mut stream, counts, task)
.map_err(Error::library_go_away)
})?;
@@ -554,4 +560,8 @@ impl Send {
}
}
}
pub(crate) fn is_extended_connect_protocol_enabled(&self) -> bool {
self.is_extended_connect_protocol_enabled
}
}

View File

@@ -4,6 +4,7 @@ use slab;
use indexmap::{self, IndexMap};
use std::convert::Infallible;
use std::fmt;
use std::marker::PhantomData;
use std::ops;
@@ -128,7 +129,20 @@ impl Store {
}
}
pub fn for_each<F, E>(&mut self, mut f: F) -> Result<(), E>
pub(crate) fn for_each<F>(&mut self, mut f: F)
where
F: FnMut(Ptr),
{
match self.try_for_each(|ptr| {
f(ptr);
Ok::<_, Infallible>(())
}) {
Ok(()) => (),
Err(infallible) => match infallible {},
}
}
pub fn try_for_each<F, E>(&mut self, mut f: F) -> Result<(), E>
where
F: FnMut(Ptr) -> Result<(), E>,
{

View File

@@ -2,6 +2,7 @@ use super::recv::RecvHeaderBlockError;
use super::store::{self, Entry, Resolve, Store};
use super::{Buffer, Config, Counts, Prioritized, Recv, Send, Stream, StreamId};
use crate::codec::{Codec, SendError, UserError};
use crate::ext::Protocol;
use crate::frame::{self, Frame, Reason};
use crate::proto::{peer, Error, Initiator, Open, Peer, WindowSize};
use crate::{client, proto, server};
@@ -214,6 +215,8 @@ where
use super::stream::ContentLength;
use http::Method;
let protocol = request.extensions_mut().remove::<Protocol>();
// Clear before taking lock, incase extensions contain a StreamRef.
request.extensions_mut().clear();
@@ -261,7 +264,8 @@ where
}
// Convert the message
let headers = client::Peer::convert_send_message(stream_id, request, end_of_stream)?;
let headers =
client::Peer::convert_send_message(stream_id, request, protocol, end_of_stream)?;
let mut stream = me.store.insert(stream.id, stream);
@@ -294,6 +298,15 @@ where
send_buffer: self.send_buffer.clone(),
})
}
pub(crate) fn is_extended_connect_protocol_enabled(&self) -> bool {
self.inner
.lock()
.unwrap()
.actions
.send
.is_extended_connect_protocol_enabled()
}
}
impl<B> DynStreams<'_, B> {
@@ -643,15 +656,12 @@ impl Inner {
let last_processed_id = actions.recv.last_processed_id();
self.store
.for_each(|stream| {
self.store.for_each(|stream| {
counts.transition(stream, |counts, stream| {
actions.recv.handle_error(&err, &mut *stream);
actions.send.handle_error(send_buffer, stream, counts);
Ok::<_, ()>(())
})
})
.unwrap();
});
actions.conn_error = Some(err);
@@ -674,19 +684,14 @@ impl Inner {
let err = Error::remote_go_away(frame.debug_data().clone(), frame.reason());
self.store
.for_each(|stream| {
self.store.for_each(|stream| {
if stream.id > last_stream_id {
counts.transition(stream, |counts, stream| {
actions.recv.handle_error(&err, &mut *stream);
actions.send.handle_error(send_buffer, stream, counts);
Ok::<_, ()>(())
})
} else {
Ok::<_, ()>(())
}
})
.unwrap();
});
actions.conn_error = Some(err);
@@ -807,18 +812,15 @@ impl Inner {
tracing::trace!("Streams::recv_eof");
self.store
.for_each(|stream| {
self.store.for_each(|stream| {
counts.transition(stream, |counts, stream| {
actions.recv.recv_eof(stream);
// This handles resetting send state associated with the
// stream
actions.send.handle_error(send_buffer, stream, counts);
Ok::<_, ()>(())
})
})
.expect("recv_eof");
});
actions.clear_queues(clear_pending_accept, &mut self.store, counts);
Ok(())

View File

@@ -470,6 +470,19 @@ where
Ok(())
}
/// Enables the [extended CONNECT protocol].
///
/// [extended CONNECT protocol]: https://datatracker.ietf.org/doc/html/rfc8441#section-4
///
/// # Errors
///
/// Returns an error if a previous call is still pending acknowledgement
/// from the remote endpoint.
pub fn enable_connect_protocol(&mut self) -> Result<(), crate::Error> {
self.connection.set_enable_connect_protocol()?;
Ok(())
}
/// Returns `Ready` when the underlying connection has closed.
///
/// If any new inbound streams are received during a call to `poll_closed`,
@@ -904,6 +917,14 @@ impl Builder {
self
}
/// Enables the [extended CONNECT protocol].
///
/// [extended CONNECT protocol]: https://datatracker.ietf.org/doc/html/rfc8441#section-4
pub fn enable_connect_protocol(&mut self) -> &mut Self {
self.settings.set_enable_connect_protocol(Some(1));
self
}
/// Creates a new configured HTTP/2 server backed by `io`.
///
/// It is expected that `io` already be in an appropriate state to commence
@@ -1360,7 +1381,7 @@ impl Peer {
_,
) = request.into_parts();
let pseudo = Pseudo::request(method, uri);
let pseudo = Pseudo::request(method, uri, None);
Ok(frame::PushPromise::new(
stream_id,
@@ -1410,6 +1431,11 @@ impl proto::Peer for Peer {
malformed!("malformed headers: missing method");
}
let has_protocol = pseudo.protocol.is_some();
if !is_connect && has_protocol {
malformed!("malformed headers: :protocol on non-CONNECT request");
}
if pseudo.status.is_some() {
malformed!("malformed headers: :status field on request");
}
@@ -1432,7 +1458,7 @@ impl proto::Peer for Peer {
// A :scheme is required, except CONNECT.
if let Some(scheme) = pseudo.scheme {
if is_connect {
if is_connect && !has_protocol {
malformed!(":scheme in CONNECT");
}
let maybe_scheme = scheme.parse();
@@ -1450,12 +1476,12 @@ impl proto::Peer for Peer {
if parts.authority.is_some() {
parts.scheme = Some(scheme);
}
} else if !is_connect {
} else if !is_connect || has_protocol {
malformed!("malformed headers: missing scheme");
}
if let Some(path) = pseudo.path {
if is_connect {
if is_connect && !has_protocol {
malformed!(":path in CONNECT");
}
@@ -1468,6 +1494,8 @@ impl proto::Peer for Peer {
parts.path_and_query = Some(maybe_path.or_else(|why| {
malformed!("malformed headers: malformed path ({:?}): {}", path, why,)
})?);
} else if is_connect && has_protocol {
malformed!("malformed headers: missing path in extended CONNECT");
}
b = b.uri(parts);

View File

@@ -47,6 +47,16 @@ macro_rules! assert_settings {
}};
}
#[macro_export]
macro_rules! assert_go_away {
($frame:expr) => {{
match $frame {
h2::frame::Frame::GoAway(v) => v,
f => panic!("expected GO_AWAY; actual={:?}", f),
}
}};
}
#[macro_export]
macro_rules! poll_err {
($transport:expr) => {{
@@ -80,6 +90,7 @@ macro_rules! assert_default_settings {
use h2::frame::Frame;
#[track_caller]
pub fn assert_frame_eq<T: Into<Frame>, U: Into<Frame>>(t: T, u: U) {
let actual: Frame = t.into();
let expected: Frame = u.into();

View File

@@ -4,7 +4,10 @@ use std::fmt;
use bytes::Bytes;
use http::{self, HeaderMap, StatusCode};
use h2::frame::{self, Frame, StreamId};
use h2::{
ext::Protocol,
frame::{self, Frame, StreamId},
};
pub const SETTINGS: &'static [u8] = &[0, 0, 0, 4, 0, 0, 0, 0, 0];
pub const SETTINGS_ACK: &'static [u8] = &[0, 0, 0, 4, 1, 0, 0, 0, 0];
@@ -109,7 +112,9 @@ impl Mock<frame::Headers> {
let method = method.try_into().unwrap();
let uri = uri.try_into().unwrap();
let (id, _, fields) = self.into_parts();
let frame = frame::Headers::new(id, frame::Pseudo::request(method, uri), fields);
let extensions = Default::default();
let pseudo = frame::Pseudo::request(method, uri, extensions);
let frame = frame::Headers::new(id, pseudo, fields);
Mock(frame)
}
@@ -179,6 +184,15 @@ impl Mock<frame::Headers> {
Mock(frame::Headers::new(id, pseudo, fields))
}
pub fn protocol(self, value: &str) -> Self {
let (id, mut pseudo, fields) = self.into_parts();
let value = Protocol::from(value);
pseudo.set_protocol(value);
Mock(frame::Headers::new(id, pseudo, fields))
}
pub fn eos(mut self) -> Self {
self.0.set_end_stream();
self
@@ -230,8 +244,9 @@ impl Mock<frame::PushPromise> {
let method = method.try_into().unwrap();
let uri = uri.try_into().unwrap();
let (id, promised, _, fields) = self.into_parts();
let frame =
frame::PushPromise::new(id, promised, frame::Pseudo::request(method, uri), fields);
let extensions = Default::default();
let pseudo = frame::Pseudo::request(method, uri, extensions);
let frame = frame::PushPromise::new(id, promised, pseudo, fields);
Mock(frame)
}
@@ -352,6 +367,11 @@ impl Mock<frame::Settings> {
self.0.set_enable_push(false);
self
}
pub fn enable_connect_protocol(mut self, val: u32) -> Self {
self.0.set_enable_connect_protocol(Some(val));
self
}
}
impl From<Mock<frame::Settings>> for frame::Settings {

View File

@@ -221,22 +221,15 @@ impl Handle {
let settings = settings.into();
self.send(settings.into()).await.unwrap();
let frame = self.next().await;
let settings = match frame {
Some(frame) => match frame.unwrap() {
Frame::Settings(settings) => {
let frame = self.next().await.expect("unexpected EOF").unwrap();
let settings = assert_settings!(frame);
// Send the ACK
let ack = frame::Settings::ack();
// TODO: Don't unwrap?
self.send(ack.into()).await.unwrap();
settings
}
frame => panic!("unexpected frame; frame={:?}", frame),
},
None => panic!("unexpected EOF"),
};
let frame = self.next().await;
let f = assert_settings!(frame.unwrap().unwrap());

View File

@@ -2,6 +2,7 @@
pub use h2;
pub use h2::client;
pub use h2::ext::Protocol;
pub use h2::frame::StreamId;
pub use h2::server;
pub use h2::*;
@@ -20,8 +21,8 @@ pub use super::{Codec, SendFrame};
// Re-export macros
pub use super::{
assert_closed, assert_data, assert_default_settings, assert_headers, assert_ping, poll_err,
poll_frame, raw_codec,
assert_closed, assert_data, assert_default_settings, assert_go_away, assert_headers,
assert_ping, assert_settings, poll_err, poll_frame, raw_codec,
};
pub use super::assert::assert_frame_eq;

View File

@@ -1305,6 +1305,153 @@ async fn informational_while_local_streaming() {
join(srv, h2).await;
}
#[tokio::test]
async fn extended_connect_protocol_disabled_by_default() {
h2_support::trace_init!();
let (io, mut srv) = mock::new();
let srv = async move {
let settings = srv.assert_client_handshake().await;
assert_default_settings!(settings);
srv.recv_frame(
frames::headers(1)
.request("GET", "https://example.com/")
.eos(),
)
.await;
srv.send_frame(frames::headers(1).response(200).eos()).await;
};
let h2 = async move {
let (mut client, mut h2) = client::handshake(io).await.unwrap();
// we send a simple req here just to drive the connection so we can
// receive the server settings.
let request = Request::get("https://example.com/").body(()).unwrap();
// first request is allowed
let (response, _) = client.send_request(request, true).unwrap();
h2.drive(response).await.unwrap();
assert!(!client.is_extended_connect_protocol_enabled());
};
join(srv, h2).await;
}
#[tokio::test]
async fn extended_connect_protocol_enabled_during_handshake() {
h2_support::trace_init!();
let (io, mut srv) = mock::new();
let srv = async move {
let settings = srv
.assert_client_handshake_with_settings(frames::settings().enable_connect_protocol(1))
.await;
assert_default_settings!(settings);
srv.recv_frame(
frames::headers(1)
.request("GET", "https://example.com/")
.eos(),
)
.await;
srv.send_frame(frames::headers(1).response(200).eos()).await;
};
let h2 = async move {
let (mut client, mut h2) = client::handshake(io).await.unwrap();
// we send a simple req here just to drive the connection so we can
// receive the server settings.
let request = Request::get("https://example.com/").body(()).unwrap();
let (response, _) = client.send_request(request, true).unwrap();
h2.drive(response).await.unwrap();
assert!(client.is_extended_connect_protocol_enabled());
};
join(srv, h2).await;
}
#[tokio::test]
async fn invalid_connect_protocol_enabled_setting() {
h2_support::trace_init!();
let (io, mut srv) = mock::new();
let srv = async move {
// Send a settings frame
srv.send(frames::settings().enable_connect_protocol(2).into())
.await
.unwrap();
srv.read_preface().await.unwrap();
let settings = assert_settings!(srv.next().await.expect("unexpected EOF").unwrap());
assert_default_settings!(settings);
// Send the ACK
let ack = frame::Settings::ack();
// TODO: Don't unwrap?
srv.send(ack.into()).await.unwrap();
let frame = srv.next().await.unwrap().unwrap();
let go_away = assert_go_away!(frame);
assert_eq!(go_away.reason(), Reason::PROTOCOL_ERROR);
};
let h2 = async move {
let (mut client, mut h2) = client::handshake(io).await.unwrap();
// we send a simple req here just to drive the connection so we can
// receive the server settings.
let request = Request::get("https://example.com/").body(()).unwrap();
let (response, _) = client.send_request(request, true).unwrap();
let error = h2.drive(response).await.unwrap_err();
assert_eq!(error.reason(), Some(Reason::PROTOCOL_ERROR));
};
join(srv, h2).await;
}
#[tokio::test]
async fn extended_connect_request() {
h2_support::trace_init!();
let (io, mut srv) = mock::new();
let srv = async move {
let settings = srv
.assert_client_handshake_with_settings(frames::settings().enable_connect_protocol(1))
.await;
assert_default_settings!(settings);
srv.recv_frame(
frames::headers(1)
.request("CONNECT", "http://bread/baguette")
.protocol("the-bread-protocol")
.eos(),
)
.await;
srv.send_frame(frames::headers(1).response(200).eos()).await;
};
let h2 = async move {
let (mut client, mut h2) = client::handshake(io).await.unwrap();
let request = Request::connect("http://bread/baguette")
.extension(Protocol::from("the-bread-protocol"))
.body(())
.unwrap();
let (response, _) = client.send_request(request, true).unwrap();
h2.drive(response).await.unwrap();
};
join(srv, h2).await;
}
const SETTINGS: &'static [u8] = &[0, 0, 0, 4, 0, 0, 0, 0, 0];
const SETTINGS_ACK: &'static [u8] = &[0, 0, 0, 4, 1, 0, 0, 0, 0];

View File

@@ -1149,3 +1149,191 @@ async fn send_reset_explicitly() {
join(client, srv).await;
}
#[tokio::test]
async fn extended_connect_protocol_disabled_by_default() {
h2_support::trace_init!();
let (io, mut client) = mock::new();
let client = async move {
let settings = client.assert_server_handshake().await;
assert_eq!(settings.is_extended_connect_protocol_enabled(), None);
client
.send_frame(
frames::headers(1)
.request("CONNECT", "http://bread/baguette")
.protocol("the-bread-protocol"),
)
.await;
client.recv_frame(frames::reset(1).protocol_error()).await;
};
let srv = async move {
let mut srv = server::handshake(io).await.expect("handshake");
poll_fn(move |cx| srv.poll_closed(cx))
.await
.expect("server");
};
join(client, srv).await;
}
#[tokio::test]
async fn extended_connect_protocol_enabled_during_handshake() {
h2_support::trace_init!();
let (io, mut client) = mock::new();
let client = async move {
let settings = client.assert_server_handshake().await;
assert_eq!(settings.is_extended_connect_protocol_enabled(), Some(true));
client
.send_frame(
frames::headers(1)
.request("CONNECT", "http://bread/baguette")
.protocol("the-bread-protocol"),
)
.await;
client.recv_frame(frames::headers(1).response(200)).await;
};
let srv = async move {
let mut builder = server::Builder::new();
builder.enable_connect_protocol();
let mut srv = builder.handshake::<_, Bytes>(io).await.expect("handshake");
let (_req, mut stream) = srv.next().await.unwrap().unwrap();
let rsp = Response::new(());
stream.send_response(rsp, false).unwrap();
poll_fn(move |cx| srv.poll_closed(cx))
.await
.expect("server");
};
join(client, srv).await;
}
#[tokio::test]
async fn reject_pseudo_protocol_on_non_connect_request() {
h2_support::trace_init!();
let (io, mut client) = mock::new();
let client = async move {
let settings = client.assert_server_handshake().await;
assert_eq!(settings.is_extended_connect_protocol_enabled(), Some(true));
client
.send_frame(
frames::headers(1)
.request("GET", "http://bread/baguette")
.protocol("the-bread-protocol"),
)
.await;
client.recv_frame(frames::reset(1).protocol_error()).await;
};
let srv = async move {
let mut builder = server::Builder::new();
builder.enable_connect_protocol();
let mut srv = builder.handshake::<_, Bytes>(io).await.expect("handshake");
assert!(srv.next().await.is_none());
poll_fn(move |cx| srv.poll_closed(cx))
.await
.expect("server");
};
join(client, srv).await;
}
#[tokio::test]
async fn reject_authority_target_on_extended_connect_request() {
h2_support::trace_init!();
let (io, mut client) = mock::new();
let client = async move {
let settings = client.assert_server_handshake().await;
assert_eq!(settings.is_extended_connect_protocol_enabled(), Some(true));
client
.send_frame(
frames::headers(1)
.request("CONNECT", "bread:80")
.protocol("the-bread-protocol"),
)
.await;
client.recv_frame(frames::reset(1).protocol_error()).await;
};
let srv = async move {
let mut builder = server::Builder::new();
builder.enable_connect_protocol();
let mut srv = builder.handshake::<_, Bytes>(io).await.expect("handshake");
assert!(srv.next().await.is_none());
poll_fn(move |cx| srv.poll_closed(cx))
.await
.expect("server");
};
join(client, srv).await;
}
#[tokio::test]
async fn reject_non_authority_target_on_connect_request() {
h2_support::trace_init!();
let (io, mut client) = mock::new();
let client = async move {
let settings = client.assert_server_handshake().await;
assert_eq!(settings.is_extended_connect_protocol_enabled(), Some(true));
client
.send_frame(frames::headers(1).request("CONNECT", "https://bread/baguette"))
.await;
client.recv_frame(frames::reset(1).protocol_error()).await;
};
let srv = async move {
let mut builder = server::Builder::new();
builder.enable_connect_protocol();
let mut srv = builder.handshake::<_, Bytes>(io).await.expect("handshake");
assert!(srv.next().await.is_none());
poll_fn(move |cx| srv.poll_closed(cx))
.await
.expect("server");
};
join(client, srv).await;
}