From 87969c1f296173a7838956165014b4828dc5d5db Mon Sep 17 00:00:00 2001 From: Anthony Ramine <123095+nox@users.noreply.github.com> Date: Wed, 24 Nov 2021 10:05:10 +0100 Subject: [PATCH] Implement the extended CONNECT protocol from RFC 8441 (#565) --- src/client.rs | 20 ++- src/ext.rs | 55 ++++++++ src/frame/headers.rs | 20 ++- src/frame/settings.rs | 27 ++++ src/hpack/header.rs | 18 +++ src/hpack/table.rs | 1 + src/hpack/test/fixture.rs | 2 + src/lib.rs | 1 + src/proto/connection.rs | 11 ++ src/proto/settings.rs | 4 +- src/proto/streams/mod.rs | 3 + src/proto/streams/recv.rs | 116 ++++++++------- src/proto/streams/send.rs | 18 ++- src/proto/streams/store.rs | 16 ++- src/proto/streams/streams.rs | 66 ++++----- src/server.rs | 36 ++++- tests/h2-support/src/assert.rs | 11 ++ tests/h2-support/src/frames.rs | 28 +++- tests/h2-support/src/mock.rs | 21 +-- tests/h2-support/src/prelude.rs | 5 +- tests/h2-tests/tests/client_request.rs | 147 +++++++++++++++++++ tests/h2-tests/tests/server.rs | 188 +++++++++++++++++++++++++ 22 files changed, 694 insertions(+), 120 deletions(-) create mode 100644 src/ext.rs diff --git a/src/client.rs b/src/client.rs index 9cd0b8f..3a818a5 100644 --- a/src/client.rs +++ b/src/client.rs @@ -136,6 +136,7 @@ //! [`Error`]: ../struct.Error.html use crate::codec::{Codec, SendError, UserError}; +use crate::ext::Protocol; use crate::frame::{Headers, Pseudo, Reason, Settings, StreamId}; use crate::proto::{self, Error}; use crate::{FlowControl, PingPong, RecvStream, SendStream}; @@ -517,6 +518,19 @@ where (response, stream) }) } + + /// Returns whether the [extended CONNECT protocol][1] is enabled or not. + /// + /// This setting is configured by the server peer by sending the + /// [`SETTINGS_ENABLE_CONNECT_PROTOCOL` parameter][2] in a `SETTINGS` frame. + /// This method returns the currently acknowledged value recieved from the + /// remote. + /// + /// [1]: https://datatracker.ietf.org/doc/html/rfc8441#section-4 + /// [2]: https://datatracker.ietf.org/doc/html/rfc8441#section-3 + pub fn is_extended_connect_protocol_enabled(&self) -> bool { + self.inner.is_extended_connect_protocol_enabled() + } } impl fmt::Debug for SendRequest @@ -1246,11 +1260,10 @@ where /// This method returns the currently acknowledged value recieved from the /// remote. /// - /// [settings]: https://tools.ietf.org/html/rfc7540#section-5.1.2 + /// [1]: https://tools.ietf.org/html/rfc7540#section-5.1.2 pub fn max_concurrent_send_streams(&self) -> usize { self.inner.max_send_streams() } - /// Returns the maximum number of concurrent streams that may be initiated /// by the server on this connection. /// @@ -1416,6 +1429,7 @@ impl Peer { pub fn convert_send_message( id: StreamId, request: Request<()>, + protocol: Option, end_of_stream: bool, ) -> Result { use http::request::Parts; @@ -1435,7 +1449,7 @@ impl Peer { // Build the set pseudo header set. All requests will include `method` // and `path`. - let mut pseudo = Pseudo::request(method, uri); + let mut pseudo = Pseudo::request(method, uri, protocol); if pseudo.scheme.is_none() { // If the scheme is not set, then there are a two options. diff --git a/src/ext.rs b/src/ext.rs new file mode 100644 index 0000000..cf383a4 --- /dev/null +++ b/src/ext.rs @@ -0,0 +1,55 @@ +//! Extensions specific to the HTTP/2 protocol. + +use crate::hpack::BytesStr; + +use bytes::Bytes; +use std::fmt; + +/// Represents the `:protocol` pseudo-header used by +/// the [Extended CONNECT Protocol]. +/// +/// [Extended CONNECT Protocol]: https://datatracker.ietf.org/doc/html/rfc8441#section-4 +#[derive(Clone, Eq, PartialEq)] +pub struct Protocol { + value: BytesStr, +} + +impl Protocol { + /// Converts a static string to a protocol name. + pub const fn from_static(value: &'static str) -> Self { + Self { + value: BytesStr::from_static(value), + } + } + + /// Returns a str representation of the header. + pub fn as_str(&self) -> &str { + self.value.as_str() + } + + pub(crate) fn try_from(bytes: Bytes) -> Result { + Ok(Self { + value: BytesStr::try_from(bytes)?, + }) + } +} + +impl<'a> From<&'a str> for Protocol { + fn from(value: &'a str) -> Self { + Self { + value: BytesStr::from(value), + } + } +} + +impl AsRef<[u8]> for Protocol { + fn as_ref(&self) -> &[u8] { + self.value.as_ref() + } +} + +impl fmt::Debug for Protocol { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + self.value.fmt(f) + } +} diff --git a/src/frame/headers.rs b/src/frame/headers.rs index 2fc9561..05d7723 100644 --- a/src/frame/headers.rs +++ b/src/frame/headers.rs @@ -1,4 +1,5 @@ use super::{util, StreamDependency, StreamId}; +use crate::ext::Protocol; use crate::frame::{Error, Frame, Head, Kind}; use crate::hpack::{self, BytesStr}; @@ -66,6 +67,7 @@ pub struct Pseudo { pub scheme: Option, pub authority: Option, pub path: Option, + pub protocol: Option, // Response pub status: Option, @@ -292,6 +294,10 @@ impl fmt::Debug for Headers { .field("stream_id", &self.stream_id) .field("flags", &self.flags); + if let Some(ref protocol) = self.header_block.pseudo.protocol { + builder.field("protocol", protocol); + } + if let Some(ref dep) = self.stream_dep { builder.field("stream_dep", dep); } @@ -529,7 +535,7 @@ impl Continuation { // ===== impl Pseudo ===== impl Pseudo { - pub fn request(method: Method, uri: Uri) -> Self { + pub fn request(method: Method, uri: Uri, protocol: Option) -> Self { let parts = uri::Parts::from(uri); let mut path = parts @@ -550,6 +556,7 @@ impl Pseudo { scheme: None, authority: None, path: Some(path).filter(|p| !p.is_empty()), + protocol, status: None, }; @@ -575,6 +582,7 @@ impl Pseudo { scheme: None, authority: None, path: None, + protocol: None, status: Some(status), } } @@ -593,6 +601,11 @@ impl Pseudo { self.scheme = Some(bytes_str); } + #[cfg(feature = "unstable")] + pub fn set_protocol(&mut self, protocol: Protocol) { + self.protocol = Some(protocol); + } + pub fn set_authority(&mut self, authority: BytesStr) { self.authority = Some(authority); } @@ -681,6 +694,10 @@ impl Iterator for Iter { return Some(Path(path)); } + if let Some(protocol) = pseudo.protocol.take() { + return Some(Protocol(protocol)); + } + if let Some(status) = pseudo.status.take() { return Some(Status(status)); } @@ -879,6 +896,7 @@ impl HeaderBlock { Method(v) => set_pseudo!(method, v), Scheme(v) => set_pseudo!(scheme, v), Path(v) => set_pseudo!(path, v), + Protocol(v) => set_pseudo!(protocol, v), Status(v) => set_pseudo!(status, v), } }); diff --git a/src/frame/settings.rs b/src/frame/settings.rs index 523f20b..080d0f4 100644 --- a/src/frame/settings.rs +++ b/src/frame/settings.rs @@ -13,6 +13,7 @@ pub struct Settings { initial_window_size: Option, max_frame_size: Option, max_header_list_size: Option, + enable_connect_protocol: Option, } /// An enum that lists all valid settings that can be sent in a SETTINGS @@ -27,6 +28,7 @@ pub enum Setting { InitialWindowSize(u32), MaxFrameSize(u32), MaxHeaderListSize(u32), + EnableConnectProtocol(u32), } #[derive(Copy, Clone, Eq, PartialEq, Default)] @@ -107,6 +109,14 @@ impl Settings { self.enable_push = Some(enable as u32); } + pub fn is_extended_connect_protocol_enabled(&self) -> Option { + self.enable_connect_protocol.map(|val| val != 0) + } + + pub fn set_enable_connect_protocol(&mut self, val: Option) { + self.enable_connect_protocol = val; + } + pub fn header_table_size(&self) -> Option { self.header_table_size } @@ -181,6 +191,14 @@ impl Settings { Some(MaxHeaderListSize(val)) => { settings.max_header_list_size = Some(val); } + Some(EnableConnectProtocol(val)) => match val { + 0 | 1 => { + settings.enable_connect_protocol = Some(val); + } + _ => { + return Err(Error::InvalidSettingValue); + } + }, None => {} } } @@ -236,6 +254,10 @@ impl Settings { if let Some(v) = self.max_header_list_size { f(MaxHeaderListSize(v)); } + + if let Some(v) = self.enable_connect_protocol { + f(EnableConnectProtocol(v)); + } } } @@ -269,6 +291,9 @@ impl fmt::Debug for Settings { Setting::MaxHeaderListSize(v) => { builder.field("max_header_list_size", &v); } + Setting::EnableConnectProtocol(v) => { + builder.field("enable_connect_protocol", &v); + } }); builder.finish() @@ -291,6 +316,7 @@ impl Setting { 4 => Some(InitialWindowSize(val)), 5 => Some(MaxFrameSize(val)), 6 => Some(MaxHeaderListSize(val)), + 8 => Some(EnableConnectProtocol(val)), _ => None, } } @@ -322,6 +348,7 @@ impl Setting { InitialWindowSize(v) => (4, v), MaxFrameSize(v) => (5, v), MaxHeaderListSize(v) => (6, v), + EnableConnectProtocol(v) => (8, v), }; dst.put_u16(kind); diff --git a/src/hpack/header.rs b/src/hpack/header.rs index 8d6136e..e6df555 100644 --- a/src/hpack/header.rs +++ b/src/hpack/header.rs @@ -1,4 +1,5 @@ use super::{DecoderError, NeedMore}; +use crate::ext::Protocol; use bytes::Bytes; use http::header::{HeaderName, HeaderValue}; @@ -14,6 +15,7 @@ pub enum Header { Method(Method), Scheme(BytesStr), Path(BytesStr), + Protocol(Protocol), Status(StatusCode), } @@ -25,6 +27,7 @@ pub enum Name<'a> { Method, Scheme, Path, + Protocol, Status, } @@ -51,6 +54,7 @@ impl Header> { Method(v) => Method(v), Scheme(v) => Scheme(v), Path(v) => Path(v), + Protocol(v) => Protocol(v), Status(v) => Status(v), }) } @@ -79,6 +83,10 @@ impl Header { let value = BytesStr::try_from(value)?; Ok(Header::Path(value)) } + b"protocol" => { + let value = Protocol::try_from(value)?; + Ok(Header::Protocol(value)) + } b"status" => { let status = StatusCode::from_bytes(&value)?; Ok(Header::Status(status)) @@ -104,6 +112,7 @@ impl Header { Header::Method(ref v) => 32 + 7 + v.as_ref().len(), Header::Scheme(ref v) => 32 + 7 + v.len(), Header::Path(ref v) => 32 + 5 + v.len(), + Header::Protocol(ref v) => 32 + 9 + v.as_str().len(), Header::Status(_) => 32 + 7 + 3, } } @@ -116,6 +125,7 @@ impl Header { Header::Method(..) => Name::Method, Header::Scheme(..) => Name::Scheme, Header::Path(..) => Name::Path, + Header::Protocol(..) => Name::Protocol, Header::Status(..) => Name::Status, } } @@ -127,6 +137,7 @@ impl Header { Header::Method(ref v) => v.as_ref().as_ref(), Header::Scheme(ref v) => v.as_ref(), Header::Path(ref v) => v.as_ref(), + Header::Protocol(ref v) => v.as_ref(), Header::Status(ref v) => v.as_str().as_ref(), } } @@ -156,6 +167,10 @@ impl Header { Header::Path(ref b) => a == b, _ => false, }, + Header::Protocol(ref a) => match *other { + Header::Protocol(ref b) => a == b, + _ => false, + }, Header::Status(ref a) => match *other { Header::Status(ref b) => a == b, _ => false, @@ -205,6 +220,7 @@ impl From
for Header> { Header::Method(v) => Header::Method(v), Header::Scheme(v) => Header::Scheme(v), Header::Path(v) => Header::Path(v), + Header::Protocol(v) => Header::Protocol(v), Header::Status(v) => Header::Status(v), } } @@ -221,6 +237,7 @@ impl<'a> Name<'a> { Name::Method => Ok(Header::Method(Method::from_bytes(&*value)?)), Name::Scheme => Ok(Header::Scheme(BytesStr::try_from(value)?)), Name::Path => Ok(Header::Path(BytesStr::try_from(value)?)), + Name::Protocol => Ok(Header::Protocol(Protocol::try_from(value)?)), Name::Status => { match StatusCode::from_bytes(&value) { Ok(status) => Ok(Header::Status(status)), @@ -238,6 +255,7 @@ impl<'a> Name<'a> { Name::Method => b":method", Name::Scheme => b":scheme", Name::Path => b":path", + Name::Protocol => b":protocol", Name::Status => b":status", } } diff --git a/src/hpack/table.rs b/src/hpack/table.rs index 2328743..0124f21 100644 --- a/src/hpack/table.rs +++ b/src/hpack/table.rs @@ -751,6 +751,7 @@ fn index_static(header: &Header) -> Option<(usize, bool)> { "/index.html" => Some((5, true)), _ => Some((4, false)), }, + Header::Protocol(..) => None, Header::Status(ref v) => match u16::from(*v) { 200 => Some((8, true)), 204 => Some((9, true)), diff --git a/src/hpack/test/fixture.rs b/src/hpack/test/fixture.rs index 6d04484..3428c39 100644 --- a/src/hpack/test/fixture.rs +++ b/src/hpack/test/fixture.rs @@ -134,6 +134,7 @@ fn key_str(e: &Header) -> &str { Header::Method(..) => ":method", Header::Scheme(..) => ":scheme", Header::Path(..) => ":path", + Header::Protocol(..) => ":protocol", Header::Status(..) => ":status", } } @@ -145,6 +146,7 @@ fn value_str(e: &Header) -> &str { Header::Method(ref m) => m.as_str(), Header::Scheme(ref v) => &**v, Header::Path(ref v) => &**v, + Header::Protocol(ref v) => v.as_str(), Header::Status(ref v) => v.as_str(), } } diff --git a/src/lib.rs b/src/lib.rs index cb02aca..db6b488 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -120,6 +120,7 @@ mod frame; pub mod frame; pub mod client; +pub mod ext; pub mod server; mod share; diff --git a/src/proto/connection.rs b/src/proto/connection.rs index a75df43..d1b8b51 100644 --- a/src/proto/connection.rs +++ b/src/proto/connection.rs @@ -110,6 +110,10 @@ where initial_max_send_streams: config.initial_max_send_streams, local_next_stream_id: config.next_stream_id, local_push_enabled: config.settings.is_push_enabled().unwrap_or(true), + extended_connect_protocol_enabled: config + .settings + .is_extended_connect_protocol_enabled() + .unwrap_or(false), local_reset_duration: config.reset_stream_duration, local_reset_max: config.reset_stream_max, remote_init_window_sz: DEFAULT_INITIAL_WINDOW_SIZE, @@ -147,6 +151,13 @@ where self.inner.settings.send_settings(settings) } + /// Send a new SETTINGS frame with extended CONNECT protocol enabled. + pub(crate) fn set_enable_connect_protocol(&mut self) -> Result<(), UserError> { + let mut settings = frame::Settings::default(); + settings.set_enable_connect_protocol(Some(1)); + self.inner.settings.send_settings(settings) + } + /// Returns the maximum number of concurrent streams that may be initiated /// by this peer. pub(crate) fn max_send_streams(&self) -> usize { diff --git a/src/proto/settings.rs b/src/proto/settings.rs index 44f4c2d..6cc6172 100644 --- a/src/proto/settings.rs +++ b/src/proto/settings.rs @@ -117,6 +117,8 @@ impl Settings { tracing::trace!("ACK sent; applying settings"); + streams.apply_remote_settings(settings)?; + if let Some(val) = settings.header_table_size() { dst.set_send_header_table_size(val as usize); } @@ -124,8 +126,6 @@ impl Settings { if let Some(val) = settings.max_frame_size() { dst.set_max_send_frame_size(val as usize); } - - streams.apply_remote_settings(settings)?; } self.remote = None; diff --git a/src/proto/streams/mod.rs b/src/proto/streams/mod.rs index 608395c..0fd61a2 100644 --- a/src/proto/streams/mod.rs +++ b/src/proto/streams/mod.rs @@ -47,6 +47,9 @@ pub struct Config { /// If the local peer is willing to receive push promises pub local_push_enabled: bool, + /// If extended connect protocol is enabled. + pub extended_connect_protocol_enabled: bool, + /// How long a locally reset stream should ignore frames pub local_reset_duration: Duration, diff --git a/src/proto/streams/recv.rs b/src/proto/streams/recv.rs index be996b9..e613c26 100644 --- a/src/proto/streams/recv.rs +++ b/src/proto/streams/recv.rs @@ -56,6 +56,9 @@ pub(super) struct Recv { /// If push promises are allowed to be received. is_push_enabled: bool, + + /// If extended connect protocol is enabled. + is_extended_connect_protocol_enabled: bool, } #[derive(Debug)] @@ -103,6 +106,7 @@ impl Recv { buffer: Buffer::new(), refused: None, is_push_enabled: config.local_push_enabled, + is_extended_connect_protocol_enabled: config.extended_connect_protocol_enabled, } } @@ -216,6 +220,14 @@ impl Recv { let stream_id = frame.stream_id(); let (pseudo, fields) = frame.into_parts(); + + if pseudo.protocol.is_some() { + if counts.peer().is_server() && !self.is_extended_connect_protocol_enabled { + proto_err!(stream: "cannot use :protocol if extended connect protocol is disabled; stream={:?}", stream.id); + return Err(Error::library_reset(stream.id, Reason::PROTOCOL_ERROR).into()); + } + } + if !pseudo.is_informational() { let message = counts .peer() @@ -449,60 +461,58 @@ impl Recv { settings: &frame::Settings, store: &mut Store, ) -> Result<(), proto::Error> { - let target = if let Some(val) = settings.initial_window_size() { - val - } else { - return Ok(()); - }; - - let old_sz = self.init_window_sz; - self.init_window_sz = target; - - tracing::trace!("update_initial_window_size; new={}; old={}", target, old_sz,); - - // Per RFC 7540 §6.9.2: - // - // In addition to changing the flow-control window for streams that are - // not yet active, a SETTINGS frame can alter the initial flow-control - // window size for streams with active flow-control windows (that is, - // streams in the "open" or "half-closed (remote)" state). When the - // value of SETTINGS_INITIAL_WINDOW_SIZE changes, a receiver MUST adjust - // the size of all stream flow-control windows that it maintains by the - // difference between the new value and the old value. - // - // A change to `SETTINGS_INITIAL_WINDOW_SIZE` can cause the available - // space in a flow-control window to become negative. A sender MUST - // track the negative flow-control window and MUST NOT send new - // flow-controlled frames until it receives WINDOW_UPDATE frames that - // cause the flow-control window to become positive. - - if target < old_sz { - // We must decrease the (local) window on every open stream. - let dec = old_sz - target; - tracing::trace!("decrementing all windows; dec={}", dec); - - store.for_each(|mut stream| { - stream.recv_flow.dec_recv_window(dec); - Ok(()) - }) - } else if target > old_sz { - // We must increase the (local) window on every open stream. - let inc = target - old_sz; - tracing::trace!("incrementing all windows; inc={}", inc); - store.for_each(|mut stream| { - // XXX: Shouldn't the peer have already noticed our - // overflow and sent us a GOAWAY? - stream - .recv_flow - .inc_window(inc) - .map_err(proto::Error::library_go_away)?; - stream.recv_flow.assign_capacity(inc); - Ok(()) - }) - } else { - // size is the same... so do nothing - Ok(()) + if let Some(val) = settings.is_extended_connect_protocol_enabled() { + self.is_extended_connect_protocol_enabled = val; } + + if let Some(target) = settings.initial_window_size() { + let old_sz = self.init_window_sz; + self.init_window_sz = target; + + tracing::trace!("update_initial_window_size; new={}; old={}", target, old_sz,); + + // Per RFC 7540 §6.9.2: + // + // In addition to changing the flow-control window for streams that are + // not yet active, a SETTINGS frame can alter the initial flow-control + // window size for streams with active flow-control windows (that is, + // streams in the "open" or "half-closed (remote)" state). When the + // value of SETTINGS_INITIAL_WINDOW_SIZE changes, a receiver MUST adjust + // the size of all stream flow-control windows that it maintains by the + // difference between the new value and the old value. + // + // A change to `SETTINGS_INITIAL_WINDOW_SIZE` can cause the available + // space in a flow-control window to become negative. A sender MUST + // track the negative flow-control window and MUST NOT send new + // flow-controlled frames until it receives WINDOW_UPDATE frames that + // cause the flow-control window to become positive. + + if target < old_sz { + // We must decrease the (local) window on every open stream. + let dec = old_sz - target; + tracing::trace!("decrementing all windows; dec={}", dec); + + store.for_each(|mut stream| { + stream.recv_flow.dec_recv_window(dec); + }) + } else if target > old_sz { + // We must increase the (local) window on every open stream. + let inc = target - old_sz; + tracing::trace!("incrementing all windows; inc={}", inc); + store.try_for_each(|mut stream| { + // XXX: Shouldn't the peer have already noticed our + // overflow and sent us a GOAWAY? + stream + .recv_flow + .inc_window(inc) + .map_err(proto::Error::library_go_away)?; + stream.recv_flow.assign_capacity(inc); + Ok::<_, proto::Error>(()) + })?; + } + } + + Ok(()) } pub fn is_end_stream(&self, stream: &store::Ptr) -> bool { diff --git a/src/proto/streams/send.rs b/src/proto/streams/send.rs index 3735d13..e3fcf6d 100644 --- a/src/proto/streams/send.rs +++ b/src/proto/streams/send.rs @@ -35,6 +35,9 @@ pub(super) struct Send { prioritize: Prioritize, is_push_enabled: bool, + + /// If extended connect protocol is enabled. + is_extended_connect_protocol_enabled: bool, } /// A value to detect which public API has called `poll_reset`. @@ -53,6 +56,7 @@ impl Send { next_stream_id: Ok(config.local_next_stream_id), prioritize: Prioritize::new(config), is_push_enabled: true, + is_extended_connect_protocol_enabled: false, } } @@ -429,6 +433,10 @@ impl Send { counts: &mut Counts, task: &mut Option, ) -> Result<(), Error> { + if let Some(val) = settings.is_extended_connect_protocol_enabled() { + self.is_extended_connect_protocol_enabled = val; + } + // Applies an update to the remote endpoint's initial window size. // // Per RFC 7540 §6.9.2: @@ -490,16 +498,14 @@ impl Send { // TODO: Should this notify the producer when the capacity // of a stream is reduced? Maybe it should if the capacity // is reduced to zero, allowing the producer to stop work. - - Ok::<_, Error>(()) - })?; + }); self.prioritize .assign_connection_capacity(total_reclaimed, store, counts); } else if val > old_val { let inc = val - old_val; - store.for_each(|mut stream| { + store.try_for_each(|mut stream| { self.recv_stream_window_update(inc, buffer, &mut stream, counts, task) .map_err(Error::library_go_away) })?; @@ -554,4 +560,8 @@ impl Send { } } } + + pub(crate) fn is_extended_connect_protocol_enabled(&self) -> bool { + self.is_extended_connect_protocol_enabled + } } diff --git a/src/proto/streams/store.rs b/src/proto/streams/store.rs index ac58f43..3e34b7c 100644 --- a/src/proto/streams/store.rs +++ b/src/proto/streams/store.rs @@ -4,6 +4,7 @@ use slab; use indexmap::{self, IndexMap}; +use std::convert::Infallible; use std::fmt; use std::marker::PhantomData; use std::ops; @@ -128,7 +129,20 @@ impl Store { } } - pub fn for_each(&mut self, mut f: F) -> Result<(), E> + pub(crate) fn for_each(&mut self, mut f: F) + where + F: FnMut(Ptr), + { + match self.try_for_each(|ptr| { + f(ptr); + Ok::<_, Infallible>(()) + }) { + Ok(()) => (), + Err(infallible) => match infallible {}, + } + } + + pub fn try_for_each(&mut self, mut f: F) -> Result<(), E> where F: FnMut(Ptr) -> Result<(), E>, { diff --git a/src/proto/streams/streams.rs b/src/proto/streams/streams.rs index 4962db8..5c235c1 100644 --- a/src/proto/streams/streams.rs +++ b/src/proto/streams/streams.rs @@ -2,6 +2,7 @@ use super::recv::RecvHeaderBlockError; use super::store::{self, Entry, Resolve, Store}; use super::{Buffer, Config, Counts, Prioritized, Recv, Send, Stream, StreamId}; use crate::codec::{Codec, SendError, UserError}; +use crate::ext::Protocol; use crate::frame::{self, Frame, Reason}; use crate::proto::{peer, Error, Initiator, Open, Peer, WindowSize}; use crate::{client, proto, server}; @@ -214,6 +215,8 @@ where use super::stream::ContentLength; use http::Method; + let protocol = request.extensions_mut().remove::(); + // Clear before taking lock, incase extensions contain a StreamRef. request.extensions_mut().clear(); @@ -261,7 +264,8 @@ where } // Convert the message - let headers = client::Peer::convert_send_message(stream_id, request, end_of_stream)?; + let headers = + client::Peer::convert_send_message(stream_id, request, protocol, end_of_stream)?; let mut stream = me.store.insert(stream.id, stream); @@ -294,6 +298,15 @@ where send_buffer: self.send_buffer.clone(), }) } + + pub(crate) fn is_extended_connect_protocol_enabled(&self) -> bool { + self.inner + .lock() + .unwrap() + .actions + .send + .is_extended_connect_protocol_enabled() + } } impl DynStreams<'_, B> { @@ -643,15 +656,12 @@ impl Inner { let last_processed_id = actions.recv.last_processed_id(); - self.store - .for_each(|stream| { - counts.transition(stream, |counts, stream| { - actions.recv.handle_error(&err, &mut *stream); - actions.send.handle_error(send_buffer, stream, counts); - Ok::<_, ()>(()) - }) + self.store.for_each(|stream| { + counts.transition(stream, |counts, stream| { + actions.recv.handle_error(&err, &mut *stream); + actions.send.handle_error(send_buffer, stream, counts); }) - .unwrap(); + }); actions.conn_error = Some(err); @@ -674,19 +684,14 @@ impl Inner { let err = Error::remote_go_away(frame.debug_data().clone(), frame.reason()); - self.store - .for_each(|stream| { - if stream.id > last_stream_id { - counts.transition(stream, |counts, stream| { - actions.recv.handle_error(&err, &mut *stream); - actions.send.handle_error(send_buffer, stream, counts); - Ok::<_, ()>(()) - }) - } else { - Ok::<_, ()>(()) - } - }) - .unwrap(); + self.store.for_each(|stream| { + if stream.id > last_stream_id { + counts.transition(stream, |counts, stream| { + actions.recv.handle_error(&err, &mut *stream); + actions.send.handle_error(send_buffer, stream, counts); + }) + } + }); actions.conn_error = Some(err); @@ -807,18 +812,15 @@ impl Inner { tracing::trace!("Streams::recv_eof"); - self.store - .for_each(|stream| { - counts.transition(stream, |counts, stream| { - actions.recv.recv_eof(stream); + self.store.for_each(|stream| { + counts.transition(stream, |counts, stream| { + actions.recv.recv_eof(stream); - // This handles resetting send state associated with the - // stream - actions.send.handle_error(send_buffer, stream, counts); - Ok::<_, ()>(()) - }) + // This handles resetting send state associated with the + // stream + actions.send.handle_error(send_buffer, stream, counts); }) - .expect("recv_eof"); + }); actions.clear_queues(clear_pending_accept, &mut self.store, counts); Ok(()) diff --git a/src/server.rs b/src/server.rs index 4914464..1eb4031 100644 --- a/src/server.rs +++ b/src/server.rs @@ -470,6 +470,19 @@ where Ok(()) } + /// Enables the [extended CONNECT protocol]. + /// + /// [extended CONNECT protocol]: https://datatracker.ietf.org/doc/html/rfc8441#section-4 + /// + /// # Errors + /// + /// Returns an error if a previous call is still pending acknowledgement + /// from the remote endpoint. + pub fn enable_connect_protocol(&mut self) -> Result<(), crate::Error> { + self.connection.set_enable_connect_protocol()?; + Ok(()) + } + /// Returns `Ready` when the underlying connection has closed. /// /// If any new inbound streams are received during a call to `poll_closed`, @@ -904,6 +917,14 @@ impl Builder { self } + /// Enables the [extended CONNECT protocol]. + /// + /// [extended CONNECT protocol]: https://datatracker.ietf.org/doc/html/rfc8441#section-4 + pub fn enable_connect_protocol(&mut self) -> &mut Self { + self.settings.set_enable_connect_protocol(Some(1)); + self + } + /// Creates a new configured HTTP/2 server backed by `io`. /// /// It is expected that `io` already be in an appropriate state to commence @@ -1360,7 +1381,7 @@ impl Peer { _, ) = request.into_parts(); - let pseudo = Pseudo::request(method, uri); + let pseudo = Pseudo::request(method, uri, None); Ok(frame::PushPromise::new( stream_id, @@ -1410,6 +1431,11 @@ impl proto::Peer for Peer { malformed!("malformed headers: missing method"); } + let has_protocol = pseudo.protocol.is_some(); + if !is_connect && has_protocol { + malformed!("malformed headers: :protocol on non-CONNECT request"); + } + if pseudo.status.is_some() { malformed!("malformed headers: :status field on request"); } @@ -1432,7 +1458,7 @@ impl proto::Peer for Peer { // A :scheme is required, except CONNECT. if let Some(scheme) = pseudo.scheme { - if is_connect { + if is_connect && !has_protocol { malformed!(":scheme in CONNECT"); } let maybe_scheme = scheme.parse(); @@ -1450,12 +1476,12 @@ impl proto::Peer for Peer { if parts.authority.is_some() { parts.scheme = Some(scheme); } - } else if !is_connect { + } else if !is_connect || has_protocol { malformed!("malformed headers: missing scheme"); } if let Some(path) = pseudo.path { - if is_connect { + if is_connect && !has_protocol { malformed!(":path in CONNECT"); } @@ -1468,6 +1494,8 @@ impl proto::Peer for Peer { parts.path_and_query = Some(maybe_path.or_else(|why| { malformed!("malformed headers: malformed path ({:?}): {}", path, why,) })?); + } else if is_connect && has_protocol { + malformed!("malformed headers: missing path in extended CONNECT"); } b = b.uri(parts); diff --git a/tests/h2-support/src/assert.rs b/tests/h2-support/src/assert.rs index 8bc6d25..88e3d4f 100644 --- a/tests/h2-support/src/assert.rs +++ b/tests/h2-support/src/assert.rs @@ -47,6 +47,16 @@ macro_rules! assert_settings { }}; } +#[macro_export] +macro_rules! assert_go_away { + ($frame:expr) => {{ + match $frame { + h2::frame::Frame::GoAway(v) => v, + f => panic!("expected GO_AWAY; actual={:?}", f), + } + }}; +} + #[macro_export] macro_rules! poll_err { ($transport:expr) => {{ @@ -80,6 +90,7 @@ macro_rules! assert_default_settings { use h2::frame::Frame; +#[track_caller] pub fn assert_frame_eq, U: Into>(t: T, u: U) { let actual: Frame = t.into(); let expected: Frame = u.into(); diff --git a/tests/h2-support/src/frames.rs b/tests/h2-support/src/frames.rs index 824bc5c..f2c07ba 100644 --- a/tests/h2-support/src/frames.rs +++ b/tests/h2-support/src/frames.rs @@ -4,7 +4,10 @@ use std::fmt; use bytes::Bytes; use http::{self, HeaderMap, StatusCode}; -use h2::frame::{self, Frame, StreamId}; +use h2::{ + ext::Protocol, + frame::{self, Frame, StreamId}, +}; pub const SETTINGS: &'static [u8] = &[0, 0, 0, 4, 0, 0, 0, 0, 0]; pub const SETTINGS_ACK: &'static [u8] = &[0, 0, 0, 4, 1, 0, 0, 0, 0]; @@ -109,7 +112,9 @@ impl Mock { let method = method.try_into().unwrap(); let uri = uri.try_into().unwrap(); let (id, _, fields) = self.into_parts(); - let frame = frame::Headers::new(id, frame::Pseudo::request(method, uri), fields); + let extensions = Default::default(); + let pseudo = frame::Pseudo::request(method, uri, extensions); + let frame = frame::Headers::new(id, pseudo, fields); Mock(frame) } @@ -179,6 +184,15 @@ impl Mock { Mock(frame::Headers::new(id, pseudo, fields)) } + pub fn protocol(self, value: &str) -> Self { + let (id, mut pseudo, fields) = self.into_parts(); + let value = Protocol::from(value); + + pseudo.set_protocol(value); + + Mock(frame::Headers::new(id, pseudo, fields)) + } + pub fn eos(mut self) -> Self { self.0.set_end_stream(); self @@ -230,8 +244,9 @@ impl Mock { let method = method.try_into().unwrap(); let uri = uri.try_into().unwrap(); let (id, promised, _, fields) = self.into_parts(); - let frame = - frame::PushPromise::new(id, promised, frame::Pseudo::request(method, uri), fields); + let extensions = Default::default(); + let pseudo = frame::Pseudo::request(method, uri, extensions); + let frame = frame::PushPromise::new(id, promised, pseudo, fields); Mock(frame) } @@ -352,6 +367,11 @@ impl Mock { self.0.set_enable_push(false); self } + + pub fn enable_connect_protocol(mut self, val: u32) -> Self { + self.0.set_enable_connect_protocol(Some(val)); + self + } } impl From> for frame::Settings { diff --git a/tests/h2-support/src/mock.rs b/tests/h2-support/src/mock.rs index b5df9ad..cc314cd 100644 --- a/tests/h2-support/src/mock.rs +++ b/tests/h2-support/src/mock.rs @@ -221,22 +221,15 @@ impl Handle { let settings = settings.into(); self.send(settings.into()).await.unwrap(); - let frame = self.next().await; - let settings = match frame { - Some(frame) => match frame.unwrap() { - Frame::Settings(settings) => { - // Send the ACK - let ack = frame::Settings::ack(); + let frame = self.next().await.expect("unexpected EOF").unwrap(); + let settings = assert_settings!(frame); - // TODO: Don't unwrap? - self.send(ack.into()).await.unwrap(); + // Send the ACK + let ack = frame::Settings::ack(); + + // TODO: Don't unwrap? + self.send(ack.into()).await.unwrap(); - settings - } - frame => panic!("unexpected frame; frame={:?}", frame), - }, - None => panic!("unexpected EOF"), - }; let frame = self.next().await; let f = assert_settings!(frame.unwrap().unwrap()); diff --git a/tests/h2-support/src/prelude.rs b/tests/h2-support/src/prelude.rs index 1fcb0dc..86ef324 100644 --- a/tests/h2-support/src/prelude.rs +++ b/tests/h2-support/src/prelude.rs @@ -2,6 +2,7 @@ pub use h2; pub use h2::client; +pub use h2::ext::Protocol; pub use h2::frame::StreamId; pub use h2::server; pub use h2::*; @@ -20,8 +21,8 @@ pub use super::{Codec, SendFrame}; // Re-export macros pub use super::{ - assert_closed, assert_data, assert_default_settings, assert_headers, assert_ping, poll_err, - poll_frame, raw_codec, + assert_closed, assert_data, assert_default_settings, assert_go_away, assert_headers, + assert_ping, assert_settings, poll_err, poll_frame, raw_codec, }; pub use super::assert::assert_frame_eq; diff --git a/tests/h2-tests/tests/client_request.rs b/tests/h2-tests/tests/client_request.rs index 2af0bde..9635bcc 100644 --- a/tests/h2-tests/tests/client_request.rs +++ b/tests/h2-tests/tests/client_request.rs @@ -1305,6 +1305,153 @@ async fn informational_while_local_streaming() { join(srv, h2).await; } +#[tokio::test] +async fn extended_connect_protocol_disabled_by_default() { + h2_support::trace_init!(); + let (io, mut srv) = mock::new(); + + let srv = async move { + let settings = srv.assert_client_handshake().await; + assert_default_settings!(settings); + + srv.recv_frame( + frames::headers(1) + .request("GET", "https://example.com/") + .eos(), + ) + .await; + srv.send_frame(frames::headers(1).response(200).eos()).await; + }; + + let h2 = async move { + let (mut client, mut h2) = client::handshake(io).await.unwrap(); + + // we send a simple req here just to drive the connection so we can + // receive the server settings. + let request = Request::get("https://example.com/").body(()).unwrap(); + // first request is allowed + let (response, _) = client.send_request(request, true).unwrap(); + h2.drive(response).await.unwrap(); + + assert!(!client.is_extended_connect_protocol_enabled()); + }; + + join(srv, h2).await; +} + +#[tokio::test] +async fn extended_connect_protocol_enabled_during_handshake() { + h2_support::trace_init!(); + let (io, mut srv) = mock::new(); + + let srv = async move { + let settings = srv + .assert_client_handshake_with_settings(frames::settings().enable_connect_protocol(1)) + .await; + assert_default_settings!(settings); + + srv.recv_frame( + frames::headers(1) + .request("GET", "https://example.com/") + .eos(), + ) + .await; + srv.send_frame(frames::headers(1).response(200).eos()).await; + }; + + let h2 = async move { + let (mut client, mut h2) = client::handshake(io).await.unwrap(); + + // we send a simple req here just to drive the connection so we can + // receive the server settings. + let request = Request::get("https://example.com/").body(()).unwrap(); + let (response, _) = client.send_request(request, true).unwrap(); + h2.drive(response).await.unwrap(); + + assert!(client.is_extended_connect_protocol_enabled()); + }; + + join(srv, h2).await; +} + +#[tokio::test] +async fn invalid_connect_protocol_enabled_setting() { + h2_support::trace_init!(); + + let (io, mut srv) = mock::new(); + + let srv = async move { + // Send a settings frame + srv.send(frames::settings().enable_connect_protocol(2).into()) + .await + .unwrap(); + srv.read_preface().await.unwrap(); + + let settings = assert_settings!(srv.next().await.expect("unexpected EOF").unwrap()); + assert_default_settings!(settings); + + // Send the ACK + let ack = frame::Settings::ack(); + + // TODO: Don't unwrap? + srv.send(ack.into()).await.unwrap(); + + let frame = srv.next().await.unwrap().unwrap(); + let go_away = assert_go_away!(frame); + assert_eq!(go_away.reason(), Reason::PROTOCOL_ERROR); + }; + + let h2 = async move { + let (mut client, mut h2) = client::handshake(io).await.unwrap(); + + // we send a simple req here just to drive the connection so we can + // receive the server settings. + let request = Request::get("https://example.com/").body(()).unwrap(); + let (response, _) = client.send_request(request, true).unwrap(); + + let error = h2.drive(response).await.unwrap_err(); + assert_eq!(error.reason(), Some(Reason::PROTOCOL_ERROR)); + }; + + join(srv, h2).await; +} + +#[tokio::test] +async fn extended_connect_request() { + h2_support::trace_init!(); + + let (io, mut srv) = mock::new(); + + let srv = async move { + let settings = srv + .assert_client_handshake_with_settings(frames::settings().enable_connect_protocol(1)) + .await; + assert_default_settings!(settings); + + srv.recv_frame( + frames::headers(1) + .request("CONNECT", "http://bread/baguette") + .protocol("the-bread-protocol") + .eos(), + ) + .await; + srv.send_frame(frames::headers(1).response(200).eos()).await; + }; + + let h2 = async move { + let (mut client, mut h2) = client::handshake(io).await.unwrap(); + + let request = Request::connect("http://bread/baguette") + .extension(Protocol::from("the-bread-protocol")) + .body(()) + .unwrap(); + let (response, _) = client.send_request(request, true).unwrap(); + h2.drive(response).await.unwrap(); + }; + + join(srv, h2).await; +} + const SETTINGS: &'static [u8] = &[0, 0, 0, 4, 0, 0, 0, 0, 0]; const SETTINGS_ACK: &'static [u8] = &[0, 0, 0, 4, 1, 0, 0, 0, 0]; diff --git a/tests/h2-tests/tests/server.rs b/tests/h2-tests/tests/server.rs index e60483d..b3bf1a2 100644 --- a/tests/h2-tests/tests/server.rs +++ b/tests/h2-tests/tests/server.rs @@ -1149,3 +1149,191 @@ async fn send_reset_explicitly() { join(client, srv).await; } + +#[tokio::test] +async fn extended_connect_protocol_disabled_by_default() { + h2_support::trace_init!(); + + let (io, mut client) = mock::new(); + + let client = async move { + let settings = client.assert_server_handshake().await; + + assert_eq!(settings.is_extended_connect_protocol_enabled(), None); + + client + .send_frame( + frames::headers(1) + .request("CONNECT", "http://bread/baguette") + .protocol("the-bread-protocol"), + ) + .await; + + client.recv_frame(frames::reset(1).protocol_error()).await; + }; + + let srv = async move { + let mut srv = server::handshake(io).await.expect("handshake"); + + poll_fn(move |cx| srv.poll_closed(cx)) + .await + .expect("server"); + }; + + join(client, srv).await; +} + +#[tokio::test] +async fn extended_connect_protocol_enabled_during_handshake() { + h2_support::trace_init!(); + + let (io, mut client) = mock::new(); + + let client = async move { + let settings = client.assert_server_handshake().await; + + assert_eq!(settings.is_extended_connect_protocol_enabled(), Some(true)); + + client + .send_frame( + frames::headers(1) + .request("CONNECT", "http://bread/baguette") + .protocol("the-bread-protocol"), + ) + .await; + + client.recv_frame(frames::headers(1).response(200)).await; + }; + + let srv = async move { + let mut builder = server::Builder::new(); + + builder.enable_connect_protocol(); + + let mut srv = builder.handshake::<_, Bytes>(io).await.expect("handshake"); + + let (_req, mut stream) = srv.next().await.unwrap().unwrap(); + + let rsp = Response::new(()); + stream.send_response(rsp, false).unwrap(); + + poll_fn(move |cx| srv.poll_closed(cx)) + .await + .expect("server"); + }; + + join(client, srv).await; +} + +#[tokio::test] +async fn reject_pseudo_protocol_on_non_connect_request() { + h2_support::trace_init!(); + + let (io, mut client) = mock::new(); + + let client = async move { + let settings = client.assert_server_handshake().await; + + assert_eq!(settings.is_extended_connect_protocol_enabled(), Some(true)); + + client + .send_frame( + frames::headers(1) + .request("GET", "http://bread/baguette") + .protocol("the-bread-protocol"), + ) + .await; + + client.recv_frame(frames::reset(1).protocol_error()).await; + }; + + let srv = async move { + let mut builder = server::Builder::new(); + + builder.enable_connect_protocol(); + + let mut srv = builder.handshake::<_, Bytes>(io).await.expect("handshake"); + + assert!(srv.next().await.is_none()); + + poll_fn(move |cx| srv.poll_closed(cx)) + .await + .expect("server"); + }; + + join(client, srv).await; +} + +#[tokio::test] +async fn reject_authority_target_on_extended_connect_request() { + h2_support::trace_init!(); + + let (io, mut client) = mock::new(); + + let client = async move { + let settings = client.assert_server_handshake().await; + + assert_eq!(settings.is_extended_connect_protocol_enabled(), Some(true)); + + client + .send_frame( + frames::headers(1) + .request("CONNECT", "bread:80") + .protocol("the-bread-protocol"), + ) + .await; + + client.recv_frame(frames::reset(1).protocol_error()).await; + }; + + let srv = async move { + let mut builder = server::Builder::new(); + + builder.enable_connect_protocol(); + + let mut srv = builder.handshake::<_, Bytes>(io).await.expect("handshake"); + + assert!(srv.next().await.is_none()); + + poll_fn(move |cx| srv.poll_closed(cx)) + .await + .expect("server"); + }; + + join(client, srv).await; +} + +#[tokio::test] +async fn reject_non_authority_target_on_connect_request() { + h2_support::trace_init!(); + + let (io, mut client) = mock::new(); + + let client = async move { + let settings = client.assert_server_handshake().await; + + assert_eq!(settings.is_extended_connect_protocol_enabled(), Some(true)); + + client + .send_frame(frames::headers(1).request("CONNECT", "https://bread/baguette")) + .await; + + client.recv_frame(frames::reset(1).protocol_error()).await; + }; + + let srv = async move { + let mut builder = server::Builder::new(); + + builder.enable_connect_protocol(); + + let mut srv = builder.handshake::<_, Bytes>(io).await.expect("handshake"); + + assert!(srv.next().await.is_none()); + + poll_fn(move |cx| srv.poll_closed(cx)) + .await + .expect("server"); + }; + + join(client, srv).await; +}