Implement the extended CONNECT protocol from RFC 8441 (#565)
This commit is contained in:
		| @@ -136,6 +136,7 @@ | |||||||
| //! [`Error`]: ../struct.Error.html | //! [`Error`]: ../struct.Error.html | ||||||
|  |  | ||||||
| use crate::codec::{Codec, SendError, UserError}; | use crate::codec::{Codec, SendError, UserError}; | ||||||
|  | use crate::ext::Protocol; | ||||||
| use crate::frame::{Headers, Pseudo, Reason, Settings, StreamId}; | use crate::frame::{Headers, Pseudo, Reason, Settings, StreamId}; | ||||||
| use crate::proto::{self, Error}; | use crate::proto::{self, Error}; | ||||||
| use crate::{FlowControl, PingPong, RecvStream, SendStream}; | use crate::{FlowControl, PingPong, RecvStream, SendStream}; | ||||||
| @@ -517,6 +518,19 @@ where | |||||||
|                 (response, stream) |                 (response, stream) | ||||||
|             }) |             }) | ||||||
|     } |     } | ||||||
|  |  | ||||||
|  |     /// Returns whether the [extended CONNECT protocol][1] is enabled or not. | ||||||
|  |     /// | ||||||
|  |     /// This setting is configured by the server peer by sending the | ||||||
|  |     /// [`SETTINGS_ENABLE_CONNECT_PROTOCOL` parameter][2] in a `SETTINGS` frame. | ||||||
|  |     /// This method returns the currently acknowledged value recieved from the | ||||||
|  |     /// remote. | ||||||
|  |     /// | ||||||
|  |     /// [1]: https://datatracker.ietf.org/doc/html/rfc8441#section-4 | ||||||
|  |     /// [2]: https://datatracker.ietf.org/doc/html/rfc8441#section-3 | ||||||
|  |     pub fn is_extended_connect_protocol_enabled(&self) -> bool { | ||||||
|  |         self.inner.is_extended_connect_protocol_enabled() | ||||||
|  |     } | ||||||
| } | } | ||||||
|  |  | ||||||
| impl<B> fmt::Debug for SendRequest<B> | impl<B> fmt::Debug for SendRequest<B> | ||||||
| @@ -1246,11 +1260,10 @@ where | |||||||
|     /// This method returns the currently acknowledged value recieved from the |     /// This method returns the currently acknowledged value recieved from the | ||||||
|     /// remote. |     /// remote. | ||||||
|     /// |     /// | ||||||
|     /// [settings]: https://tools.ietf.org/html/rfc7540#section-5.1.2 |     /// [1]: https://tools.ietf.org/html/rfc7540#section-5.1.2 | ||||||
|     pub fn max_concurrent_send_streams(&self) -> usize { |     pub fn max_concurrent_send_streams(&self) -> usize { | ||||||
|         self.inner.max_send_streams() |         self.inner.max_send_streams() | ||||||
|     } |     } | ||||||
|  |  | ||||||
|     /// Returns the maximum number of concurrent streams that may be initiated |     /// Returns the maximum number of concurrent streams that may be initiated | ||||||
|     /// by the server on this connection. |     /// by the server on this connection. | ||||||
|     /// |     /// | ||||||
| @@ -1416,6 +1429,7 @@ impl Peer { | |||||||
|     pub fn convert_send_message( |     pub fn convert_send_message( | ||||||
|         id: StreamId, |         id: StreamId, | ||||||
|         request: Request<()>, |         request: Request<()>, | ||||||
|  |         protocol: Option<Protocol>, | ||||||
|         end_of_stream: bool, |         end_of_stream: bool, | ||||||
|     ) -> Result<Headers, SendError> { |     ) -> Result<Headers, SendError> { | ||||||
|         use http::request::Parts; |         use http::request::Parts; | ||||||
| @@ -1435,7 +1449,7 @@ impl Peer { | |||||||
|  |  | ||||||
|         // Build the set pseudo header set. All requests will include `method` |         // Build the set pseudo header set. All requests will include `method` | ||||||
|         // and `path`. |         // and `path`. | ||||||
|         let mut pseudo = Pseudo::request(method, uri); |         let mut pseudo = Pseudo::request(method, uri, protocol); | ||||||
|  |  | ||||||
|         if pseudo.scheme.is_none() { |         if pseudo.scheme.is_none() { | ||||||
|             // If the scheme is not set, then there are a two options. |             // If the scheme is not set, then there are a two options. | ||||||
|   | |||||||
							
								
								
									
										55
									
								
								src/ext.rs
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										55
									
								
								src/ext.rs
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,55 @@ | |||||||
|  | //! Extensions specific to the HTTP/2 protocol. | ||||||
|  |  | ||||||
|  | use crate::hpack::BytesStr; | ||||||
|  |  | ||||||
|  | use bytes::Bytes; | ||||||
|  | use std::fmt; | ||||||
|  |  | ||||||
|  | /// Represents the `:protocol` pseudo-header used by | ||||||
|  | /// the [Extended CONNECT Protocol]. | ||||||
|  | /// | ||||||
|  | /// [Extended CONNECT Protocol]: https://datatracker.ietf.org/doc/html/rfc8441#section-4 | ||||||
|  | #[derive(Clone, Eq, PartialEq)] | ||||||
|  | pub struct Protocol { | ||||||
|  |     value: BytesStr, | ||||||
|  | } | ||||||
|  |  | ||||||
|  | impl Protocol { | ||||||
|  |     /// Converts a static string to a protocol name. | ||||||
|  |     pub const fn from_static(value: &'static str) -> Self { | ||||||
|  |         Self { | ||||||
|  |             value: BytesStr::from_static(value), | ||||||
|  |         } | ||||||
|  |     } | ||||||
|  |  | ||||||
|  |     /// Returns a str representation of the header. | ||||||
|  |     pub fn as_str(&self) -> &str { | ||||||
|  |         self.value.as_str() | ||||||
|  |     } | ||||||
|  |  | ||||||
|  |     pub(crate) fn try_from(bytes: Bytes) -> Result<Self, std::str::Utf8Error> { | ||||||
|  |         Ok(Self { | ||||||
|  |             value: BytesStr::try_from(bytes)?, | ||||||
|  |         }) | ||||||
|  |     } | ||||||
|  | } | ||||||
|  |  | ||||||
|  | impl<'a> From<&'a str> for Protocol { | ||||||
|  |     fn from(value: &'a str) -> Self { | ||||||
|  |         Self { | ||||||
|  |             value: BytesStr::from(value), | ||||||
|  |         } | ||||||
|  |     } | ||||||
|  | } | ||||||
|  |  | ||||||
|  | impl AsRef<[u8]> for Protocol { | ||||||
|  |     fn as_ref(&self) -> &[u8] { | ||||||
|  |         self.value.as_ref() | ||||||
|  |     } | ||||||
|  | } | ||||||
|  |  | ||||||
|  | impl fmt::Debug for Protocol { | ||||||
|  |     fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { | ||||||
|  |         self.value.fmt(f) | ||||||
|  |     } | ||||||
|  | } | ||||||
| @@ -1,4 +1,5 @@ | |||||||
| use super::{util, StreamDependency, StreamId}; | use super::{util, StreamDependency, StreamId}; | ||||||
|  | use crate::ext::Protocol; | ||||||
| use crate::frame::{Error, Frame, Head, Kind}; | use crate::frame::{Error, Frame, Head, Kind}; | ||||||
| use crate::hpack::{self, BytesStr}; | use crate::hpack::{self, BytesStr}; | ||||||
|  |  | ||||||
| @@ -66,6 +67,7 @@ pub struct Pseudo { | |||||||
|     pub scheme: Option<BytesStr>, |     pub scheme: Option<BytesStr>, | ||||||
|     pub authority: Option<BytesStr>, |     pub authority: Option<BytesStr>, | ||||||
|     pub path: Option<BytesStr>, |     pub path: Option<BytesStr>, | ||||||
|  |     pub protocol: Option<Protocol>, | ||||||
|  |  | ||||||
|     // Response |     // Response | ||||||
|     pub status: Option<StatusCode>, |     pub status: Option<StatusCode>, | ||||||
| @@ -292,6 +294,10 @@ impl fmt::Debug for Headers { | |||||||
|             .field("stream_id", &self.stream_id) |             .field("stream_id", &self.stream_id) | ||||||
|             .field("flags", &self.flags); |             .field("flags", &self.flags); | ||||||
|  |  | ||||||
|  |         if let Some(ref protocol) = self.header_block.pseudo.protocol { | ||||||
|  |             builder.field("protocol", protocol); | ||||||
|  |         } | ||||||
|  |  | ||||||
|         if let Some(ref dep) = self.stream_dep { |         if let Some(ref dep) = self.stream_dep { | ||||||
|             builder.field("stream_dep", dep); |             builder.field("stream_dep", dep); | ||||||
|         } |         } | ||||||
| @@ -529,7 +535,7 @@ impl Continuation { | |||||||
| // ===== impl Pseudo ===== | // ===== impl Pseudo ===== | ||||||
|  |  | ||||||
| impl Pseudo { | impl Pseudo { | ||||||
|     pub fn request(method: Method, uri: Uri) -> Self { |     pub fn request(method: Method, uri: Uri, protocol: Option<Protocol>) -> Self { | ||||||
|         let parts = uri::Parts::from(uri); |         let parts = uri::Parts::from(uri); | ||||||
|  |  | ||||||
|         let mut path = parts |         let mut path = parts | ||||||
| @@ -550,6 +556,7 @@ impl Pseudo { | |||||||
|             scheme: None, |             scheme: None, | ||||||
|             authority: None, |             authority: None, | ||||||
|             path: Some(path).filter(|p| !p.is_empty()), |             path: Some(path).filter(|p| !p.is_empty()), | ||||||
|  |             protocol, | ||||||
|             status: None, |             status: None, | ||||||
|         }; |         }; | ||||||
|  |  | ||||||
| @@ -575,6 +582,7 @@ impl Pseudo { | |||||||
|             scheme: None, |             scheme: None, | ||||||
|             authority: None, |             authority: None, | ||||||
|             path: None, |             path: None, | ||||||
|  |             protocol: None, | ||||||
|             status: Some(status), |             status: Some(status), | ||||||
|         } |         } | ||||||
|     } |     } | ||||||
| @@ -593,6 +601,11 @@ impl Pseudo { | |||||||
|         self.scheme = Some(bytes_str); |         self.scheme = Some(bytes_str); | ||||||
|     } |     } | ||||||
|  |  | ||||||
|  |     #[cfg(feature = "unstable")] | ||||||
|  |     pub fn set_protocol(&mut self, protocol: Protocol) { | ||||||
|  |         self.protocol = Some(protocol); | ||||||
|  |     } | ||||||
|  |  | ||||||
|     pub fn set_authority(&mut self, authority: BytesStr) { |     pub fn set_authority(&mut self, authority: BytesStr) { | ||||||
|         self.authority = Some(authority); |         self.authority = Some(authority); | ||||||
|     } |     } | ||||||
| @@ -681,6 +694,10 @@ impl Iterator for Iter { | |||||||
|                 return Some(Path(path)); |                 return Some(Path(path)); | ||||||
|             } |             } | ||||||
|  |  | ||||||
|  |             if let Some(protocol) = pseudo.protocol.take() { | ||||||
|  |                 return Some(Protocol(protocol)); | ||||||
|  |             } | ||||||
|  |  | ||||||
|             if let Some(status) = pseudo.status.take() { |             if let Some(status) = pseudo.status.take() { | ||||||
|                 return Some(Status(status)); |                 return Some(Status(status)); | ||||||
|             } |             } | ||||||
| @@ -879,6 +896,7 @@ impl HeaderBlock { | |||||||
|                 Method(v) => set_pseudo!(method, v), |                 Method(v) => set_pseudo!(method, v), | ||||||
|                 Scheme(v) => set_pseudo!(scheme, v), |                 Scheme(v) => set_pseudo!(scheme, v), | ||||||
|                 Path(v) => set_pseudo!(path, v), |                 Path(v) => set_pseudo!(path, v), | ||||||
|  |                 Protocol(v) => set_pseudo!(protocol, v), | ||||||
|                 Status(v) => set_pseudo!(status, v), |                 Status(v) => set_pseudo!(status, v), | ||||||
|             } |             } | ||||||
|         }); |         }); | ||||||
|   | |||||||
| @@ -13,6 +13,7 @@ pub struct Settings { | |||||||
|     initial_window_size: Option<u32>, |     initial_window_size: Option<u32>, | ||||||
|     max_frame_size: Option<u32>, |     max_frame_size: Option<u32>, | ||||||
|     max_header_list_size: Option<u32>, |     max_header_list_size: Option<u32>, | ||||||
|  |     enable_connect_protocol: Option<u32>, | ||||||
| } | } | ||||||
|  |  | ||||||
| /// An enum that lists all valid settings that can be sent in a SETTINGS | /// An enum that lists all valid settings that can be sent in a SETTINGS | ||||||
| @@ -27,6 +28,7 @@ pub enum Setting { | |||||||
|     InitialWindowSize(u32), |     InitialWindowSize(u32), | ||||||
|     MaxFrameSize(u32), |     MaxFrameSize(u32), | ||||||
|     MaxHeaderListSize(u32), |     MaxHeaderListSize(u32), | ||||||
|  |     EnableConnectProtocol(u32), | ||||||
| } | } | ||||||
|  |  | ||||||
| #[derive(Copy, Clone, Eq, PartialEq, Default)] | #[derive(Copy, Clone, Eq, PartialEq, Default)] | ||||||
| @@ -107,6 +109,14 @@ impl Settings { | |||||||
|         self.enable_push = Some(enable as u32); |         self.enable_push = Some(enable as u32); | ||||||
|     } |     } | ||||||
|  |  | ||||||
|  |     pub fn is_extended_connect_protocol_enabled(&self) -> Option<bool> { | ||||||
|  |         self.enable_connect_protocol.map(|val| val != 0) | ||||||
|  |     } | ||||||
|  |  | ||||||
|  |     pub fn set_enable_connect_protocol(&mut self, val: Option<u32>) { | ||||||
|  |         self.enable_connect_protocol = val; | ||||||
|  |     } | ||||||
|  |  | ||||||
|     pub fn header_table_size(&self) -> Option<u32> { |     pub fn header_table_size(&self) -> Option<u32> { | ||||||
|         self.header_table_size |         self.header_table_size | ||||||
|     } |     } | ||||||
| @@ -181,6 +191,14 @@ impl Settings { | |||||||
|                 Some(MaxHeaderListSize(val)) => { |                 Some(MaxHeaderListSize(val)) => { | ||||||
|                     settings.max_header_list_size = Some(val); |                     settings.max_header_list_size = Some(val); | ||||||
|                 } |                 } | ||||||
|  |                 Some(EnableConnectProtocol(val)) => match val { | ||||||
|  |                     0 | 1 => { | ||||||
|  |                         settings.enable_connect_protocol = Some(val); | ||||||
|  |                     } | ||||||
|  |                     _ => { | ||||||
|  |                         return Err(Error::InvalidSettingValue); | ||||||
|  |                     } | ||||||
|  |                 }, | ||||||
|                 None => {} |                 None => {} | ||||||
|             } |             } | ||||||
|         } |         } | ||||||
| @@ -236,6 +254,10 @@ impl Settings { | |||||||
|         if let Some(v) = self.max_header_list_size { |         if let Some(v) = self.max_header_list_size { | ||||||
|             f(MaxHeaderListSize(v)); |             f(MaxHeaderListSize(v)); | ||||||
|         } |         } | ||||||
|  |  | ||||||
|  |         if let Some(v) = self.enable_connect_protocol { | ||||||
|  |             f(EnableConnectProtocol(v)); | ||||||
|  |         } | ||||||
|     } |     } | ||||||
| } | } | ||||||
|  |  | ||||||
| @@ -269,6 +291,9 @@ impl fmt::Debug for Settings { | |||||||
|             Setting::MaxHeaderListSize(v) => { |             Setting::MaxHeaderListSize(v) => { | ||||||
|                 builder.field("max_header_list_size", &v); |                 builder.field("max_header_list_size", &v); | ||||||
|             } |             } | ||||||
|  |             Setting::EnableConnectProtocol(v) => { | ||||||
|  |                 builder.field("enable_connect_protocol", &v); | ||||||
|  |             } | ||||||
|         }); |         }); | ||||||
|  |  | ||||||
|         builder.finish() |         builder.finish() | ||||||
| @@ -291,6 +316,7 @@ impl Setting { | |||||||
|             4 => Some(InitialWindowSize(val)), |             4 => Some(InitialWindowSize(val)), | ||||||
|             5 => Some(MaxFrameSize(val)), |             5 => Some(MaxFrameSize(val)), | ||||||
|             6 => Some(MaxHeaderListSize(val)), |             6 => Some(MaxHeaderListSize(val)), | ||||||
|  |             8 => Some(EnableConnectProtocol(val)), | ||||||
|             _ => None, |             _ => None, | ||||||
|         } |         } | ||||||
|     } |     } | ||||||
| @@ -322,6 +348,7 @@ impl Setting { | |||||||
|             InitialWindowSize(v) => (4, v), |             InitialWindowSize(v) => (4, v), | ||||||
|             MaxFrameSize(v) => (5, v), |             MaxFrameSize(v) => (5, v), | ||||||
|             MaxHeaderListSize(v) => (6, v), |             MaxHeaderListSize(v) => (6, v), | ||||||
|  |             EnableConnectProtocol(v) => (8, v), | ||||||
|         }; |         }; | ||||||
|  |  | ||||||
|         dst.put_u16(kind); |         dst.put_u16(kind); | ||||||
|   | |||||||
| @@ -1,4 +1,5 @@ | |||||||
| use super::{DecoderError, NeedMore}; | use super::{DecoderError, NeedMore}; | ||||||
|  | use crate::ext::Protocol; | ||||||
|  |  | ||||||
| use bytes::Bytes; | use bytes::Bytes; | ||||||
| use http::header::{HeaderName, HeaderValue}; | use http::header::{HeaderName, HeaderValue}; | ||||||
| @@ -14,6 +15,7 @@ pub enum Header<T = HeaderName> { | |||||||
|     Method(Method), |     Method(Method), | ||||||
|     Scheme(BytesStr), |     Scheme(BytesStr), | ||||||
|     Path(BytesStr), |     Path(BytesStr), | ||||||
|  |     Protocol(Protocol), | ||||||
|     Status(StatusCode), |     Status(StatusCode), | ||||||
| } | } | ||||||
|  |  | ||||||
| @@ -25,6 +27,7 @@ pub enum Name<'a> { | |||||||
|     Method, |     Method, | ||||||
|     Scheme, |     Scheme, | ||||||
|     Path, |     Path, | ||||||
|  |     Protocol, | ||||||
|     Status, |     Status, | ||||||
| } | } | ||||||
|  |  | ||||||
| @@ -51,6 +54,7 @@ impl Header<Option<HeaderName>> { | |||||||
|             Method(v) => Method(v), |             Method(v) => Method(v), | ||||||
|             Scheme(v) => Scheme(v), |             Scheme(v) => Scheme(v), | ||||||
|             Path(v) => Path(v), |             Path(v) => Path(v), | ||||||
|  |             Protocol(v) => Protocol(v), | ||||||
|             Status(v) => Status(v), |             Status(v) => Status(v), | ||||||
|         }) |         }) | ||||||
|     } |     } | ||||||
| @@ -79,6 +83,10 @@ impl Header { | |||||||
|                     let value = BytesStr::try_from(value)?; |                     let value = BytesStr::try_from(value)?; | ||||||
|                     Ok(Header::Path(value)) |                     Ok(Header::Path(value)) | ||||||
|                 } |                 } | ||||||
|  |                 b"protocol" => { | ||||||
|  |                     let value = Protocol::try_from(value)?; | ||||||
|  |                     Ok(Header::Protocol(value)) | ||||||
|  |                 } | ||||||
|                 b"status" => { |                 b"status" => { | ||||||
|                     let status = StatusCode::from_bytes(&value)?; |                     let status = StatusCode::from_bytes(&value)?; | ||||||
|                     Ok(Header::Status(status)) |                     Ok(Header::Status(status)) | ||||||
| @@ -104,6 +112,7 @@ impl Header { | |||||||
|             Header::Method(ref v) => 32 + 7 + v.as_ref().len(), |             Header::Method(ref v) => 32 + 7 + v.as_ref().len(), | ||||||
|             Header::Scheme(ref v) => 32 + 7 + v.len(), |             Header::Scheme(ref v) => 32 + 7 + v.len(), | ||||||
|             Header::Path(ref v) => 32 + 5 + v.len(), |             Header::Path(ref v) => 32 + 5 + v.len(), | ||||||
|  |             Header::Protocol(ref v) => 32 + 9 + v.as_str().len(), | ||||||
|             Header::Status(_) => 32 + 7 + 3, |             Header::Status(_) => 32 + 7 + 3, | ||||||
|         } |         } | ||||||
|     } |     } | ||||||
| @@ -116,6 +125,7 @@ impl Header { | |||||||
|             Header::Method(..) => Name::Method, |             Header::Method(..) => Name::Method, | ||||||
|             Header::Scheme(..) => Name::Scheme, |             Header::Scheme(..) => Name::Scheme, | ||||||
|             Header::Path(..) => Name::Path, |             Header::Path(..) => Name::Path, | ||||||
|  |             Header::Protocol(..) => Name::Protocol, | ||||||
|             Header::Status(..) => Name::Status, |             Header::Status(..) => Name::Status, | ||||||
|         } |         } | ||||||
|     } |     } | ||||||
| @@ -127,6 +137,7 @@ impl Header { | |||||||
|             Header::Method(ref v) => v.as_ref().as_ref(), |             Header::Method(ref v) => v.as_ref().as_ref(), | ||||||
|             Header::Scheme(ref v) => v.as_ref(), |             Header::Scheme(ref v) => v.as_ref(), | ||||||
|             Header::Path(ref v) => v.as_ref(), |             Header::Path(ref v) => v.as_ref(), | ||||||
|  |             Header::Protocol(ref v) => v.as_ref(), | ||||||
|             Header::Status(ref v) => v.as_str().as_ref(), |             Header::Status(ref v) => v.as_str().as_ref(), | ||||||
|         } |         } | ||||||
|     } |     } | ||||||
| @@ -156,6 +167,10 @@ impl Header { | |||||||
|                 Header::Path(ref b) => a == b, |                 Header::Path(ref b) => a == b, | ||||||
|                 _ => false, |                 _ => false, | ||||||
|             }, |             }, | ||||||
|  |             Header::Protocol(ref a) => match *other { | ||||||
|  |                 Header::Protocol(ref b) => a == b, | ||||||
|  |                 _ => false, | ||||||
|  |             }, | ||||||
|             Header::Status(ref a) => match *other { |             Header::Status(ref a) => match *other { | ||||||
|                 Header::Status(ref b) => a == b, |                 Header::Status(ref b) => a == b, | ||||||
|                 _ => false, |                 _ => false, | ||||||
| @@ -205,6 +220,7 @@ impl From<Header> for Header<Option<HeaderName>> { | |||||||
|             Header::Method(v) => Header::Method(v), |             Header::Method(v) => Header::Method(v), | ||||||
|             Header::Scheme(v) => Header::Scheme(v), |             Header::Scheme(v) => Header::Scheme(v), | ||||||
|             Header::Path(v) => Header::Path(v), |             Header::Path(v) => Header::Path(v), | ||||||
|  |             Header::Protocol(v) => Header::Protocol(v), | ||||||
|             Header::Status(v) => Header::Status(v), |             Header::Status(v) => Header::Status(v), | ||||||
|         } |         } | ||||||
|     } |     } | ||||||
| @@ -221,6 +237,7 @@ impl<'a> Name<'a> { | |||||||
|             Name::Method => Ok(Header::Method(Method::from_bytes(&*value)?)), |             Name::Method => Ok(Header::Method(Method::from_bytes(&*value)?)), | ||||||
|             Name::Scheme => Ok(Header::Scheme(BytesStr::try_from(value)?)), |             Name::Scheme => Ok(Header::Scheme(BytesStr::try_from(value)?)), | ||||||
|             Name::Path => Ok(Header::Path(BytesStr::try_from(value)?)), |             Name::Path => Ok(Header::Path(BytesStr::try_from(value)?)), | ||||||
|  |             Name::Protocol => Ok(Header::Protocol(Protocol::try_from(value)?)), | ||||||
|             Name::Status => { |             Name::Status => { | ||||||
|                 match StatusCode::from_bytes(&value) { |                 match StatusCode::from_bytes(&value) { | ||||||
|                     Ok(status) => Ok(Header::Status(status)), |                     Ok(status) => Ok(Header::Status(status)), | ||||||
| @@ -238,6 +255,7 @@ impl<'a> Name<'a> { | |||||||
|             Name::Method => b":method", |             Name::Method => b":method", | ||||||
|             Name::Scheme => b":scheme", |             Name::Scheme => b":scheme", | ||||||
|             Name::Path => b":path", |             Name::Path => b":path", | ||||||
|  |             Name::Protocol => b":protocol", | ||||||
|             Name::Status => b":status", |             Name::Status => b":status", | ||||||
|         } |         } | ||||||
|     } |     } | ||||||
|   | |||||||
| @@ -751,6 +751,7 @@ fn index_static(header: &Header) -> Option<(usize, bool)> { | |||||||
|             "/index.html" => Some((5, true)), |             "/index.html" => Some((5, true)), | ||||||
|             _ => Some((4, false)), |             _ => Some((4, false)), | ||||||
|         }, |         }, | ||||||
|  |         Header::Protocol(..) => None, | ||||||
|         Header::Status(ref v) => match u16::from(*v) { |         Header::Status(ref v) => match u16::from(*v) { | ||||||
|             200 => Some((8, true)), |             200 => Some((8, true)), | ||||||
|             204 => Some((9, true)), |             204 => Some((9, true)), | ||||||
|   | |||||||
| @@ -134,6 +134,7 @@ fn key_str(e: &Header) -> &str { | |||||||
|         Header::Method(..) => ":method", |         Header::Method(..) => ":method", | ||||||
|         Header::Scheme(..) => ":scheme", |         Header::Scheme(..) => ":scheme", | ||||||
|         Header::Path(..) => ":path", |         Header::Path(..) => ":path", | ||||||
|  |         Header::Protocol(..) => ":protocol", | ||||||
|         Header::Status(..) => ":status", |         Header::Status(..) => ":status", | ||||||
|     } |     } | ||||||
| } | } | ||||||
| @@ -145,6 +146,7 @@ fn value_str(e: &Header) -> &str { | |||||||
|         Header::Method(ref m) => m.as_str(), |         Header::Method(ref m) => m.as_str(), | ||||||
|         Header::Scheme(ref v) => &**v, |         Header::Scheme(ref v) => &**v, | ||||||
|         Header::Path(ref v) => &**v, |         Header::Path(ref v) => &**v, | ||||||
|  |         Header::Protocol(ref v) => v.as_str(), | ||||||
|         Header::Status(ref v) => v.as_str(), |         Header::Status(ref v) => v.as_str(), | ||||||
|     } |     } | ||||||
| } | } | ||||||
|   | |||||||
| @@ -120,6 +120,7 @@ mod frame; | |||||||
| pub mod frame; | pub mod frame; | ||||||
|  |  | ||||||
| pub mod client; | pub mod client; | ||||||
|  | pub mod ext; | ||||||
| pub mod server; | pub mod server; | ||||||
| mod share; | mod share; | ||||||
|  |  | ||||||
|   | |||||||
| @@ -110,6 +110,10 @@ where | |||||||
|                 initial_max_send_streams: config.initial_max_send_streams, |                 initial_max_send_streams: config.initial_max_send_streams, | ||||||
|                 local_next_stream_id: config.next_stream_id, |                 local_next_stream_id: config.next_stream_id, | ||||||
|                 local_push_enabled: config.settings.is_push_enabled().unwrap_or(true), |                 local_push_enabled: config.settings.is_push_enabled().unwrap_or(true), | ||||||
|  |                 extended_connect_protocol_enabled: config | ||||||
|  |                     .settings | ||||||
|  |                     .is_extended_connect_protocol_enabled() | ||||||
|  |                     .unwrap_or(false), | ||||||
|                 local_reset_duration: config.reset_stream_duration, |                 local_reset_duration: config.reset_stream_duration, | ||||||
|                 local_reset_max: config.reset_stream_max, |                 local_reset_max: config.reset_stream_max, | ||||||
|                 remote_init_window_sz: DEFAULT_INITIAL_WINDOW_SIZE, |                 remote_init_window_sz: DEFAULT_INITIAL_WINDOW_SIZE, | ||||||
| @@ -147,6 +151,13 @@ where | |||||||
|         self.inner.settings.send_settings(settings) |         self.inner.settings.send_settings(settings) | ||||||
|     } |     } | ||||||
|  |  | ||||||
|  |     /// Send a new SETTINGS frame with extended CONNECT protocol enabled. | ||||||
|  |     pub(crate) fn set_enable_connect_protocol(&mut self) -> Result<(), UserError> { | ||||||
|  |         let mut settings = frame::Settings::default(); | ||||||
|  |         settings.set_enable_connect_protocol(Some(1)); | ||||||
|  |         self.inner.settings.send_settings(settings) | ||||||
|  |     } | ||||||
|  |  | ||||||
|     /// Returns the maximum number of concurrent streams that may be initiated |     /// Returns the maximum number of concurrent streams that may be initiated | ||||||
|     /// by this peer. |     /// by this peer. | ||||||
|     pub(crate) fn max_send_streams(&self) -> usize { |     pub(crate) fn max_send_streams(&self) -> usize { | ||||||
|   | |||||||
| @@ -117,6 +117,8 @@ impl Settings { | |||||||
|  |  | ||||||
|             tracing::trace!("ACK sent; applying settings"); |             tracing::trace!("ACK sent; applying settings"); | ||||||
|  |  | ||||||
|  |             streams.apply_remote_settings(settings)?; | ||||||
|  |  | ||||||
|             if let Some(val) = settings.header_table_size() { |             if let Some(val) = settings.header_table_size() { | ||||||
|                 dst.set_send_header_table_size(val as usize); |                 dst.set_send_header_table_size(val as usize); | ||||||
|             } |             } | ||||||
| @@ -124,8 +126,6 @@ impl Settings { | |||||||
|             if let Some(val) = settings.max_frame_size() { |             if let Some(val) = settings.max_frame_size() { | ||||||
|                 dst.set_max_send_frame_size(val as usize); |                 dst.set_max_send_frame_size(val as usize); | ||||||
|             } |             } | ||||||
|  |  | ||||||
|             streams.apply_remote_settings(settings)?; |  | ||||||
|         } |         } | ||||||
|  |  | ||||||
|         self.remote = None; |         self.remote = None; | ||||||
|   | |||||||
| @@ -47,6 +47,9 @@ pub struct Config { | |||||||
|     /// If the local peer is willing to receive push promises |     /// If the local peer is willing to receive push promises | ||||||
|     pub local_push_enabled: bool, |     pub local_push_enabled: bool, | ||||||
|  |  | ||||||
|  |     /// If extended connect protocol is enabled. | ||||||
|  |     pub extended_connect_protocol_enabled: bool, | ||||||
|  |  | ||||||
|     /// How long a locally reset stream should ignore frames |     /// How long a locally reset stream should ignore frames | ||||||
|     pub local_reset_duration: Duration, |     pub local_reset_duration: Duration, | ||||||
|  |  | ||||||
|   | |||||||
| @@ -56,6 +56,9 @@ pub(super) struct Recv { | |||||||
|  |  | ||||||
|     /// If push promises are allowed to be received. |     /// If push promises are allowed to be received. | ||||||
|     is_push_enabled: bool, |     is_push_enabled: bool, | ||||||
|  |  | ||||||
|  |     /// If extended connect protocol is enabled. | ||||||
|  |     is_extended_connect_protocol_enabled: bool, | ||||||
| } | } | ||||||
|  |  | ||||||
| #[derive(Debug)] | #[derive(Debug)] | ||||||
| @@ -103,6 +106,7 @@ impl Recv { | |||||||
|             buffer: Buffer::new(), |             buffer: Buffer::new(), | ||||||
|             refused: None, |             refused: None, | ||||||
|             is_push_enabled: config.local_push_enabled, |             is_push_enabled: config.local_push_enabled, | ||||||
|  |             is_extended_connect_protocol_enabled: config.extended_connect_protocol_enabled, | ||||||
|         } |         } | ||||||
|     } |     } | ||||||
|  |  | ||||||
| @@ -216,6 +220,14 @@ impl Recv { | |||||||
|  |  | ||||||
|         let stream_id = frame.stream_id(); |         let stream_id = frame.stream_id(); | ||||||
|         let (pseudo, fields) = frame.into_parts(); |         let (pseudo, fields) = frame.into_parts(); | ||||||
|  |  | ||||||
|  |         if pseudo.protocol.is_some() { | ||||||
|  |             if counts.peer().is_server() && !self.is_extended_connect_protocol_enabled { | ||||||
|  |                 proto_err!(stream: "cannot use :protocol if extended connect protocol is disabled; stream={:?}", stream.id); | ||||||
|  |                 return Err(Error::library_reset(stream.id, Reason::PROTOCOL_ERROR).into()); | ||||||
|  |             } | ||||||
|  |         } | ||||||
|  |  | ||||||
|         if !pseudo.is_informational() { |         if !pseudo.is_informational() { | ||||||
|             let message = counts |             let message = counts | ||||||
|                 .peer() |                 .peer() | ||||||
| @@ -449,12 +461,11 @@ impl Recv { | |||||||
|         settings: &frame::Settings, |         settings: &frame::Settings, | ||||||
|         store: &mut Store, |         store: &mut Store, | ||||||
|     ) -> Result<(), proto::Error> { |     ) -> Result<(), proto::Error> { | ||||||
|         let target = if let Some(val) = settings.initial_window_size() { |         if let Some(val) = settings.is_extended_connect_protocol_enabled() { | ||||||
|             val |             self.is_extended_connect_protocol_enabled = val; | ||||||
|         } else { |         } | ||||||
|             return Ok(()); |  | ||||||
|         }; |  | ||||||
|  |  | ||||||
|  |         if let Some(target) = settings.initial_window_size() { | ||||||
|             let old_sz = self.init_window_sz; |             let old_sz = self.init_window_sz; | ||||||
|             self.init_window_sz = target; |             self.init_window_sz = target; | ||||||
|  |  | ||||||
| @@ -483,13 +494,12 @@ impl Recv { | |||||||
|  |  | ||||||
|                 store.for_each(|mut stream| { |                 store.for_each(|mut stream| { | ||||||
|                     stream.recv_flow.dec_recv_window(dec); |                     stream.recv_flow.dec_recv_window(dec); | ||||||
|                 Ok(()) |  | ||||||
|                 }) |                 }) | ||||||
|             } else if target > old_sz { |             } else if target > old_sz { | ||||||
|                 // We must increase the (local) window on every open stream. |                 // We must increase the (local) window on every open stream. | ||||||
|                 let inc = target - old_sz; |                 let inc = target - old_sz; | ||||||
|                 tracing::trace!("incrementing all windows; inc={}", inc); |                 tracing::trace!("incrementing all windows; inc={}", inc); | ||||||
|             store.for_each(|mut stream| { |                 store.try_for_each(|mut stream| { | ||||||
|                     // XXX: Shouldn't the peer have already noticed our |                     // XXX: Shouldn't the peer have already noticed our | ||||||
|                     // overflow and sent us a GOAWAY? |                     // overflow and sent us a GOAWAY? | ||||||
|                     stream |                     stream | ||||||
| @@ -497,14 +507,14 @@ impl Recv { | |||||||
|                         .inc_window(inc) |                         .inc_window(inc) | ||||||
|                         .map_err(proto::Error::library_go_away)?; |                         .map_err(proto::Error::library_go_away)?; | ||||||
|                     stream.recv_flow.assign_capacity(inc); |                     stream.recv_flow.assign_capacity(inc); | ||||||
|                 Ok(()) |                     Ok::<_, proto::Error>(()) | ||||||
|             }) |                 })?; | ||||||
|         } else { |  | ||||||
|             // size is the same... so do nothing |  | ||||||
|             Ok(()) |  | ||||||
|             } |             } | ||||||
|         } |         } | ||||||
|  |  | ||||||
|  |         Ok(()) | ||||||
|  |     } | ||||||
|  |  | ||||||
|     pub fn is_end_stream(&self, stream: &store::Ptr) -> bool { |     pub fn is_end_stream(&self, stream: &store::Ptr) -> bool { | ||||||
|         if !stream.state.is_recv_closed() { |         if !stream.state.is_recv_closed() { | ||||||
|             return false; |             return false; | ||||||
|   | |||||||
| @@ -35,6 +35,9 @@ pub(super) struct Send { | |||||||
|     prioritize: Prioritize, |     prioritize: Prioritize, | ||||||
|  |  | ||||||
|     is_push_enabled: bool, |     is_push_enabled: bool, | ||||||
|  |  | ||||||
|  |     /// If extended connect protocol is enabled. | ||||||
|  |     is_extended_connect_protocol_enabled: bool, | ||||||
| } | } | ||||||
|  |  | ||||||
| /// A value to detect which public API has called `poll_reset`. | /// A value to detect which public API has called `poll_reset`. | ||||||
| @@ -53,6 +56,7 @@ impl Send { | |||||||
|             next_stream_id: Ok(config.local_next_stream_id), |             next_stream_id: Ok(config.local_next_stream_id), | ||||||
|             prioritize: Prioritize::new(config), |             prioritize: Prioritize::new(config), | ||||||
|             is_push_enabled: true, |             is_push_enabled: true, | ||||||
|  |             is_extended_connect_protocol_enabled: false, | ||||||
|         } |         } | ||||||
|     } |     } | ||||||
|  |  | ||||||
| @@ -429,6 +433,10 @@ impl Send { | |||||||
|         counts: &mut Counts, |         counts: &mut Counts, | ||||||
|         task: &mut Option<Waker>, |         task: &mut Option<Waker>, | ||||||
|     ) -> Result<(), Error> { |     ) -> Result<(), Error> { | ||||||
|  |         if let Some(val) = settings.is_extended_connect_protocol_enabled() { | ||||||
|  |             self.is_extended_connect_protocol_enabled = val; | ||||||
|  |         } | ||||||
|  |  | ||||||
|         // Applies an update to the remote endpoint's initial window size. |         // Applies an update to the remote endpoint's initial window size. | ||||||
|         // |         // | ||||||
|         // Per RFC 7540 §6.9.2: |         // Per RFC 7540 §6.9.2: | ||||||
| @@ -490,16 +498,14 @@ impl Send { | |||||||
|                     // TODO: Should this notify the producer when the capacity |                     // TODO: Should this notify the producer when the capacity | ||||||
|                     // of a stream is reduced? Maybe it should if the capacity |                     // of a stream is reduced? Maybe it should if the capacity | ||||||
|                     // is reduced to zero, allowing the producer to stop work. |                     // is reduced to zero, allowing the producer to stop work. | ||||||
|  |                 }); | ||||||
|                     Ok::<_, Error>(()) |  | ||||||
|                 })?; |  | ||||||
|  |  | ||||||
|                 self.prioritize |                 self.prioritize | ||||||
|                     .assign_connection_capacity(total_reclaimed, store, counts); |                     .assign_connection_capacity(total_reclaimed, store, counts); | ||||||
|             } else if val > old_val { |             } else if val > old_val { | ||||||
|                 let inc = val - old_val; |                 let inc = val - old_val; | ||||||
|  |  | ||||||
|                 store.for_each(|mut stream| { |                 store.try_for_each(|mut stream| { | ||||||
|                     self.recv_stream_window_update(inc, buffer, &mut stream, counts, task) |                     self.recv_stream_window_update(inc, buffer, &mut stream, counts, task) | ||||||
|                         .map_err(Error::library_go_away) |                         .map_err(Error::library_go_away) | ||||||
|                 })?; |                 })?; | ||||||
| @@ -554,4 +560,8 @@ impl Send { | |||||||
|             } |             } | ||||||
|         } |         } | ||||||
|     } |     } | ||||||
|  |  | ||||||
|  |     pub(crate) fn is_extended_connect_protocol_enabled(&self) -> bool { | ||||||
|  |         self.is_extended_connect_protocol_enabled | ||||||
|  |     } | ||||||
| } | } | ||||||
|   | |||||||
| @@ -4,6 +4,7 @@ use slab; | |||||||
|  |  | ||||||
| use indexmap::{self, IndexMap}; | use indexmap::{self, IndexMap}; | ||||||
|  |  | ||||||
|  | use std::convert::Infallible; | ||||||
| use std::fmt; | use std::fmt; | ||||||
| use std::marker::PhantomData; | use std::marker::PhantomData; | ||||||
| use std::ops; | use std::ops; | ||||||
| @@ -128,7 +129,20 @@ impl Store { | |||||||
|         } |         } | ||||||
|     } |     } | ||||||
|  |  | ||||||
|     pub fn for_each<F, E>(&mut self, mut f: F) -> Result<(), E> |     pub(crate) fn for_each<F>(&mut self, mut f: F) | ||||||
|  |     where | ||||||
|  |         F: FnMut(Ptr), | ||||||
|  |     { | ||||||
|  |         match self.try_for_each(|ptr| { | ||||||
|  |             f(ptr); | ||||||
|  |             Ok::<_, Infallible>(()) | ||||||
|  |         }) { | ||||||
|  |             Ok(()) => (), | ||||||
|  |             Err(infallible) => match infallible {}, | ||||||
|  |         } | ||||||
|  |     } | ||||||
|  |  | ||||||
|  |     pub fn try_for_each<F, E>(&mut self, mut f: F) -> Result<(), E> | ||||||
|     where |     where | ||||||
|         F: FnMut(Ptr) -> Result<(), E>, |         F: FnMut(Ptr) -> Result<(), E>, | ||||||
|     { |     { | ||||||
|   | |||||||
| @@ -2,6 +2,7 @@ use super::recv::RecvHeaderBlockError; | |||||||
| use super::store::{self, Entry, Resolve, Store}; | use super::store::{self, Entry, Resolve, Store}; | ||||||
| use super::{Buffer, Config, Counts, Prioritized, Recv, Send, Stream, StreamId}; | use super::{Buffer, Config, Counts, Prioritized, Recv, Send, Stream, StreamId}; | ||||||
| use crate::codec::{Codec, SendError, UserError}; | use crate::codec::{Codec, SendError, UserError}; | ||||||
|  | use crate::ext::Protocol; | ||||||
| use crate::frame::{self, Frame, Reason}; | use crate::frame::{self, Frame, Reason}; | ||||||
| use crate::proto::{peer, Error, Initiator, Open, Peer, WindowSize}; | use crate::proto::{peer, Error, Initiator, Open, Peer, WindowSize}; | ||||||
| use crate::{client, proto, server}; | use crate::{client, proto, server}; | ||||||
| @@ -214,6 +215,8 @@ where | |||||||
|         use super::stream::ContentLength; |         use super::stream::ContentLength; | ||||||
|         use http::Method; |         use http::Method; | ||||||
|  |  | ||||||
|  |         let protocol = request.extensions_mut().remove::<Protocol>(); | ||||||
|  |  | ||||||
|         // Clear before taking lock, incase extensions contain a StreamRef. |         // Clear before taking lock, incase extensions contain a StreamRef. | ||||||
|         request.extensions_mut().clear(); |         request.extensions_mut().clear(); | ||||||
|  |  | ||||||
| @@ -261,7 +264,8 @@ where | |||||||
|         } |         } | ||||||
|  |  | ||||||
|         // Convert the message |         // Convert the message | ||||||
|         let headers = client::Peer::convert_send_message(stream_id, request, end_of_stream)?; |         let headers = | ||||||
|  |             client::Peer::convert_send_message(stream_id, request, protocol, end_of_stream)?; | ||||||
|  |  | ||||||
|         let mut stream = me.store.insert(stream.id, stream); |         let mut stream = me.store.insert(stream.id, stream); | ||||||
|  |  | ||||||
| @@ -294,6 +298,15 @@ where | |||||||
|             send_buffer: self.send_buffer.clone(), |             send_buffer: self.send_buffer.clone(), | ||||||
|         }) |         }) | ||||||
|     } |     } | ||||||
|  |  | ||||||
|  |     pub(crate) fn is_extended_connect_protocol_enabled(&self) -> bool { | ||||||
|  |         self.inner | ||||||
|  |             .lock() | ||||||
|  |             .unwrap() | ||||||
|  |             .actions | ||||||
|  |             .send | ||||||
|  |             .is_extended_connect_protocol_enabled() | ||||||
|  |     } | ||||||
| } | } | ||||||
|  |  | ||||||
| impl<B> DynStreams<'_, B> { | impl<B> DynStreams<'_, B> { | ||||||
| @@ -643,15 +656,12 @@ impl Inner { | |||||||
|  |  | ||||||
|         let last_processed_id = actions.recv.last_processed_id(); |         let last_processed_id = actions.recv.last_processed_id(); | ||||||
|  |  | ||||||
|         self.store |         self.store.for_each(|stream| { | ||||||
|             .for_each(|stream| { |  | ||||||
|             counts.transition(stream, |counts, stream| { |             counts.transition(stream, |counts, stream| { | ||||||
|                 actions.recv.handle_error(&err, &mut *stream); |                 actions.recv.handle_error(&err, &mut *stream); | ||||||
|                 actions.send.handle_error(send_buffer, stream, counts); |                 actions.send.handle_error(send_buffer, stream, counts); | ||||||
|                     Ok::<_, ()>(()) |  | ||||||
|             }) |             }) | ||||||
|             }) |         }); | ||||||
|             .unwrap(); |  | ||||||
|  |  | ||||||
|         actions.conn_error = Some(err); |         actions.conn_error = Some(err); | ||||||
|  |  | ||||||
| @@ -674,19 +684,14 @@ impl Inner { | |||||||
|  |  | ||||||
|         let err = Error::remote_go_away(frame.debug_data().clone(), frame.reason()); |         let err = Error::remote_go_away(frame.debug_data().clone(), frame.reason()); | ||||||
|  |  | ||||||
|         self.store |         self.store.for_each(|stream| { | ||||||
|             .for_each(|stream| { |  | ||||||
|             if stream.id > last_stream_id { |             if stream.id > last_stream_id { | ||||||
|                 counts.transition(stream, |counts, stream| { |                 counts.transition(stream, |counts, stream| { | ||||||
|                     actions.recv.handle_error(&err, &mut *stream); |                     actions.recv.handle_error(&err, &mut *stream); | ||||||
|                     actions.send.handle_error(send_buffer, stream, counts); |                     actions.send.handle_error(send_buffer, stream, counts); | ||||||
|                         Ok::<_, ()>(()) |  | ||||||
|                 }) |                 }) | ||||||
|                 } else { |  | ||||||
|                     Ok::<_, ()>(()) |  | ||||||
|             } |             } | ||||||
|             }) |         }); | ||||||
|             .unwrap(); |  | ||||||
|  |  | ||||||
|         actions.conn_error = Some(err); |         actions.conn_error = Some(err); | ||||||
|  |  | ||||||
| @@ -807,18 +812,15 @@ impl Inner { | |||||||
|  |  | ||||||
|         tracing::trace!("Streams::recv_eof"); |         tracing::trace!("Streams::recv_eof"); | ||||||
|  |  | ||||||
|         self.store |         self.store.for_each(|stream| { | ||||||
|             .for_each(|stream| { |  | ||||||
|             counts.transition(stream, |counts, stream| { |             counts.transition(stream, |counts, stream| { | ||||||
|                 actions.recv.recv_eof(stream); |                 actions.recv.recv_eof(stream); | ||||||
|  |  | ||||||
|                 // This handles resetting send state associated with the |                 // This handles resetting send state associated with the | ||||||
|                 // stream |                 // stream | ||||||
|                 actions.send.handle_error(send_buffer, stream, counts); |                 actions.send.handle_error(send_buffer, stream, counts); | ||||||
|                     Ok::<_, ()>(()) |  | ||||||
|             }) |             }) | ||||||
|             }) |         }); | ||||||
|             .expect("recv_eof"); |  | ||||||
|  |  | ||||||
|         actions.clear_queues(clear_pending_accept, &mut self.store, counts); |         actions.clear_queues(clear_pending_accept, &mut self.store, counts); | ||||||
|         Ok(()) |         Ok(()) | ||||||
|   | |||||||
| @@ -470,6 +470,19 @@ where | |||||||
|         Ok(()) |         Ok(()) | ||||||
|     } |     } | ||||||
|  |  | ||||||
|  |     /// Enables the [extended CONNECT protocol]. | ||||||
|  |     /// | ||||||
|  |     /// [extended CONNECT protocol]: https://datatracker.ietf.org/doc/html/rfc8441#section-4 | ||||||
|  |     /// | ||||||
|  |     /// # Errors | ||||||
|  |     /// | ||||||
|  |     /// Returns an error if a previous call is still pending acknowledgement | ||||||
|  |     /// from the remote endpoint. | ||||||
|  |     pub fn enable_connect_protocol(&mut self) -> Result<(), crate::Error> { | ||||||
|  |         self.connection.set_enable_connect_protocol()?; | ||||||
|  |         Ok(()) | ||||||
|  |     } | ||||||
|  |  | ||||||
|     /// Returns `Ready` when the underlying connection has closed. |     /// Returns `Ready` when the underlying connection has closed. | ||||||
|     /// |     /// | ||||||
|     /// If any new inbound streams are received during a call to `poll_closed`, |     /// If any new inbound streams are received during a call to `poll_closed`, | ||||||
| @@ -904,6 +917,14 @@ impl Builder { | |||||||
|         self |         self | ||||||
|     } |     } | ||||||
|  |  | ||||||
|  |     /// Enables the [extended CONNECT protocol]. | ||||||
|  |     /// | ||||||
|  |     /// [extended CONNECT protocol]: https://datatracker.ietf.org/doc/html/rfc8441#section-4 | ||||||
|  |     pub fn enable_connect_protocol(&mut self) -> &mut Self { | ||||||
|  |         self.settings.set_enable_connect_protocol(Some(1)); | ||||||
|  |         self | ||||||
|  |     } | ||||||
|  |  | ||||||
|     /// Creates a new configured HTTP/2 server backed by `io`. |     /// Creates a new configured HTTP/2 server backed by `io`. | ||||||
|     /// |     /// | ||||||
|     /// It is expected that `io` already be in an appropriate state to commence |     /// It is expected that `io` already be in an appropriate state to commence | ||||||
| @@ -1360,7 +1381,7 @@ impl Peer { | |||||||
|             _, |             _, | ||||||
|         ) = request.into_parts(); |         ) = request.into_parts(); | ||||||
|  |  | ||||||
|         let pseudo = Pseudo::request(method, uri); |         let pseudo = Pseudo::request(method, uri, None); | ||||||
|  |  | ||||||
|         Ok(frame::PushPromise::new( |         Ok(frame::PushPromise::new( | ||||||
|             stream_id, |             stream_id, | ||||||
| @@ -1410,6 +1431,11 @@ impl proto::Peer for Peer { | |||||||
|             malformed!("malformed headers: missing method"); |             malformed!("malformed headers: missing method"); | ||||||
|         } |         } | ||||||
|  |  | ||||||
|  |         let has_protocol = pseudo.protocol.is_some(); | ||||||
|  |         if !is_connect && has_protocol { | ||||||
|  |             malformed!("malformed headers: :protocol on non-CONNECT request"); | ||||||
|  |         } | ||||||
|  |  | ||||||
|         if pseudo.status.is_some() { |         if pseudo.status.is_some() { | ||||||
|             malformed!("malformed headers: :status field on request"); |             malformed!("malformed headers: :status field on request"); | ||||||
|         } |         } | ||||||
| @@ -1432,7 +1458,7 @@ impl proto::Peer for Peer { | |||||||
|  |  | ||||||
|         // A :scheme is required, except CONNECT. |         // A :scheme is required, except CONNECT. | ||||||
|         if let Some(scheme) = pseudo.scheme { |         if let Some(scheme) = pseudo.scheme { | ||||||
|             if is_connect { |             if is_connect && !has_protocol { | ||||||
|                 malformed!(":scheme in CONNECT"); |                 malformed!(":scheme in CONNECT"); | ||||||
|             } |             } | ||||||
|             let maybe_scheme = scheme.parse(); |             let maybe_scheme = scheme.parse(); | ||||||
| @@ -1450,12 +1476,12 @@ impl proto::Peer for Peer { | |||||||
|             if parts.authority.is_some() { |             if parts.authority.is_some() { | ||||||
|                 parts.scheme = Some(scheme); |                 parts.scheme = Some(scheme); | ||||||
|             } |             } | ||||||
|         } else if !is_connect { |         } else if !is_connect || has_protocol { | ||||||
|             malformed!("malformed headers: missing scheme"); |             malformed!("malformed headers: missing scheme"); | ||||||
|         } |         } | ||||||
|  |  | ||||||
|         if let Some(path) = pseudo.path { |         if let Some(path) = pseudo.path { | ||||||
|             if is_connect { |             if is_connect && !has_protocol { | ||||||
|                 malformed!(":path in CONNECT"); |                 malformed!(":path in CONNECT"); | ||||||
|             } |             } | ||||||
|  |  | ||||||
| @@ -1468,6 +1494,8 @@ impl proto::Peer for Peer { | |||||||
|             parts.path_and_query = Some(maybe_path.or_else(|why| { |             parts.path_and_query = Some(maybe_path.or_else(|why| { | ||||||
|                 malformed!("malformed headers: malformed path ({:?}): {}", path, why,) |                 malformed!("malformed headers: malformed path ({:?}): {}", path, why,) | ||||||
|             })?); |             })?); | ||||||
|  |         } else if is_connect && has_protocol { | ||||||
|  |             malformed!("malformed headers: missing path in extended CONNECT"); | ||||||
|         } |         } | ||||||
|  |  | ||||||
|         b = b.uri(parts); |         b = b.uri(parts); | ||||||
|   | |||||||
| @@ -47,6 +47,16 @@ macro_rules! assert_settings { | |||||||
|     }}; |     }}; | ||||||
| } | } | ||||||
|  |  | ||||||
|  | #[macro_export] | ||||||
|  | macro_rules! assert_go_away { | ||||||
|  |     ($frame:expr) => {{ | ||||||
|  |         match $frame { | ||||||
|  |             h2::frame::Frame::GoAway(v) => v, | ||||||
|  |             f => panic!("expected GO_AWAY; actual={:?}", f), | ||||||
|  |         } | ||||||
|  |     }}; | ||||||
|  | } | ||||||
|  |  | ||||||
| #[macro_export] | #[macro_export] | ||||||
| macro_rules! poll_err { | macro_rules! poll_err { | ||||||
|     ($transport:expr) => {{ |     ($transport:expr) => {{ | ||||||
| @@ -80,6 +90,7 @@ macro_rules! assert_default_settings { | |||||||
|  |  | ||||||
| use h2::frame::Frame; | use h2::frame::Frame; | ||||||
|  |  | ||||||
|  | #[track_caller] | ||||||
| pub fn assert_frame_eq<T: Into<Frame>, U: Into<Frame>>(t: T, u: U) { | pub fn assert_frame_eq<T: Into<Frame>, U: Into<Frame>>(t: T, u: U) { | ||||||
|     let actual: Frame = t.into(); |     let actual: Frame = t.into(); | ||||||
|     let expected: Frame = u.into(); |     let expected: Frame = u.into(); | ||||||
|   | |||||||
| @@ -4,7 +4,10 @@ use std::fmt; | |||||||
| use bytes::Bytes; | use bytes::Bytes; | ||||||
| use http::{self, HeaderMap, StatusCode}; | use http::{self, HeaderMap, StatusCode}; | ||||||
|  |  | ||||||
| use h2::frame::{self, Frame, StreamId}; | use h2::{ | ||||||
|  |     ext::Protocol, | ||||||
|  |     frame::{self, Frame, StreamId}, | ||||||
|  | }; | ||||||
|  |  | ||||||
| pub const SETTINGS: &'static [u8] = &[0, 0, 0, 4, 0, 0, 0, 0, 0]; | pub const SETTINGS: &'static [u8] = &[0, 0, 0, 4, 0, 0, 0, 0, 0]; | ||||||
| pub const SETTINGS_ACK: &'static [u8] = &[0, 0, 0, 4, 1, 0, 0, 0, 0]; | pub const SETTINGS_ACK: &'static [u8] = &[0, 0, 0, 4, 1, 0, 0, 0, 0]; | ||||||
| @@ -109,7 +112,9 @@ impl Mock<frame::Headers> { | |||||||
|         let method = method.try_into().unwrap(); |         let method = method.try_into().unwrap(); | ||||||
|         let uri = uri.try_into().unwrap(); |         let uri = uri.try_into().unwrap(); | ||||||
|         let (id, _, fields) = self.into_parts(); |         let (id, _, fields) = self.into_parts(); | ||||||
|         let frame = frame::Headers::new(id, frame::Pseudo::request(method, uri), fields); |         let extensions = Default::default(); | ||||||
|  |         let pseudo = frame::Pseudo::request(method, uri, extensions); | ||||||
|  |         let frame = frame::Headers::new(id, pseudo, fields); | ||||||
|         Mock(frame) |         Mock(frame) | ||||||
|     } |     } | ||||||
|  |  | ||||||
| @@ -179,6 +184,15 @@ impl Mock<frame::Headers> { | |||||||
|         Mock(frame::Headers::new(id, pseudo, fields)) |         Mock(frame::Headers::new(id, pseudo, fields)) | ||||||
|     } |     } | ||||||
|  |  | ||||||
|  |     pub fn protocol(self, value: &str) -> Self { | ||||||
|  |         let (id, mut pseudo, fields) = self.into_parts(); | ||||||
|  |         let value = Protocol::from(value); | ||||||
|  |  | ||||||
|  |         pseudo.set_protocol(value); | ||||||
|  |  | ||||||
|  |         Mock(frame::Headers::new(id, pseudo, fields)) | ||||||
|  |     } | ||||||
|  |  | ||||||
|     pub fn eos(mut self) -> Self { |     pub fn eos(mut self) -> Self { | ||||||
|         self.0.set_end_stream(); |         self.0.set_end_stream(); | ||||||
|         self |         self | ||||||
| @@ -230,8 +244,9 @@ impl Mock<frame::PushPromise> { | |||||||
|         let method = method.try_into().unwrap(); |         let method = method.try_into().unwrap(); | ||||||
|         let uri = uri.try_into().unwrap(); |         let uri = uri.try_into().unwrap(); | ||||||
|         let (id, promised, _, fields) = self.into_parts(); |         let (id, promised, _, fields) = self.into_parts(); | ||||||
|         let frame = |         let extensions = Default::default(); | ||||||
|             frame::PushPromise::new(id, promised, frame::Pseudo::request(method, uri), fields); |         let pseudo = frame::Pseudo::request(method, uri, extensions); | ||||||
|  |         let frame = frame::PushPromise::new(id, promised, pseudo, fields); | ||||||
|         Mock(frame) |         Mock(frame) | ||||||
|     } |     } | ||||||
|  |  | ||||||
| @@ -352,6 +367,11 @@ impl Mock<frame::Settings> { | |||||||
|         self.0.set_enable_push(false); |         self.0.set_enable_push(false); | ||||||
|         self |         self | ||||||
|     } |     } | ||||||
|  |  | ||||||
|  |     pub fn enable_connect_protocol(mut self, val: u32) -> Self { | ||||||
|  |         self.0.set_enable_connect_protocol(Some(val)); | ||||||
|  |         self | ||||||
|  |     } | ||||||
| } | } | ||||||
|  |  | ||||||
| impl From<Mock<frame::Settings>> for frame::Settings { | impl From<Mock<frame::Settings>> for frame::Settings { | ||||||
|   | |||||||
| @@ -221,22 +221,15 @@ impl Handle { | |||||||
|         let settings = settings.into(); |         let settings = settings.into(); | ||||||
|         self.send(settings.into()).await.unwrap(); |         self.send(settings.into()).await.unwrap(); | ||||||
|  |  | ||||||
|         let frame = self.next().await; |         let frame = self.next().await.expect("unexpected EOF").unwrap(); | ||||||
|         let settings = match frame { |         let settings = assert_settings!(frame); | ||||||
|             Some(frame) => match frame.unwrap() { |  | ||||||
|                 Frame::Settings(settings) => { |  | ||||||
|         // Send the ACK |         // Send the ACK | ||||||
|         let ack = frame::Settings::ack(); |         let ack = frame::Settings::ack(); | ||||||
|  |  | ||||||
|         // TODO: Don't unwrap? |         // TODO: Don't unwrap? | ||||||
|         self.send(ack.into()).await.unwrap(); |         self.send(ack.into()).await.unwrap(); | ||||||
|  |  | ||||||
|                     settings |  | ||||||
|                 } |  | ||||||
|                 frame => panic!("unexpected frame; frame={:?}", frame), |  | ||||||
|             }, |  | ||||||
|             None => panic!("unexpected EOF"), |  | ||||||
|         }; |  | ||||||
|         let frame = self.next().await; |         let frame = self.next().await; | ||||||
|         let f = assert_settings!(frame.unwrap().unwrap()); |         let f = assert_settings!(frame.unwrap().unwrap()); | ||||||
|  |  | ||||||
|   | |||||||
| @@ -2,6 +2,7 @@ | |||||||
| pub use h2; | pub use h2; | ||||||
|  |  | ||||||
| pub use h2::client; | pub use h2::client; | ||||||
|  | pub use h2::ext::Protocol; | ||||||
| pub use h2::frame::StreamId; | pub use h2::frame::StreamId; | ||||||
| pub use h2::server; | pub use h2::server; | ||||||
| pub use h2::*; | pub use h2::*; | ||||||
| @@ -20,8 +21,8 @@ pub use super::{Codec, SendFrame}; | |||||||
|  |  | ||||||
| // Re-export macros | // Re-export macros | ||||||
| pub use super::{ | pub use super::{ | ||||||
|     assert_closed, assert_data, assert_default_settings, assert_headers, assert_ping, poll_err, |     assert_closed, assert_data, assert_default_settings, assert_go_away, assert_headers, | ||||||
|     poll_frame, raw_codec, |     assert_ping, assert_settings, poll_err, poll_frame, raw_codec, | ||||||
| }; | }; | ||||||
|  |  | ||||||
| pub use super::assert::assert_frame_eq; | pub use super::assert::assert_frame_eq; | ||||||
|   | |||||||
| @@ -1305,6 +1305,153 @@ async fn informational_while_local_streaming() { | |||||||
|     join(srv, h2).await; |     join(srv, h2).await; | ||||||
| } | } | ||||||
|  |  | ||||||
|  | #[tokio::test] | ||||||
|  | async fn extended_connect_protocol_disabled_by_default() { | ||||||
|  |     h2_support::trace_init!(); | ||||||
|  |     let (io, mut srv) = mock::new(); | ||||||
|  |  | ||||||
|  |     let srv = async move { | ||||||
|  |         let settings = srv.assert_client_handshake().await; | ||||||
|  |         assert_default_settings!(settings); | ||||||
|  |  | ||||||
|  |         srv.recv_frame( | ||||||
|  |             frames::headers(1) | ||||||
|  |                 .request("GET", "https://example.com/") | ||||||
|  |                 .eos(), | ||||||
|  |         ) | ||||||
|  |         .await; | ||||||
|  |         srv.send_frame(frames::headers(1).response(200).eos()).await; | ||||||
|  |     }; | ||||||
|  |  | ||||||
|  |     let h2 = async move { | ||||||
|  |         let (mut client, mut h2) = client::handshake(io).await.unwrap(); | ||||||
|  |  | ||||||
|  |         // we send a simple req here just to drive the connection so we can | ||||||
|  |         // receive the server settings. | ||||||
|  |         let request = Request::get("https://example.com/").body(()).unwrap(); | ||||||
|  |         // first request is allowed | ||||||
|  |         let (response, _) = client.send_request(request, true).unwrap(); | ||||||
|  |         h2.drive(response).await.unwrap(); | ||||||
|  |  | ||||||
|  |         assert!(!client.is_extended_connect_protocol_enabled()); | ||||||
|  |     }; | ||||||
|  |  | ||||||
|  |     join(srv, h2).await; | ||||||
|  | } | ||||||
|  |  | ||||||
|  | #[tokio::test] | ||||||
|  | async fn extended_connect_protocol_enabled_during_handshake() { | ||||||
|  |     h2_support::trace_init!(); | ||||||
|  |     let (io, mut srv) = mock::new(); | ||||||
|  |  | ||||||
|  |     let srv = async move { | ||||||
|  |         let settings = srv | ||||||
|  |             .assert_client_handshake_with_settings(frames::settings().enable_connect_protocol(1)) | ||||||
|  |             .await; | ||||||
|  |         assert_default_settings!(settings); | ||||||
|  |  | ||||||
|  |         srv.recv_frame( | ||||||
|  |             frames::headers(1) | ||||||
|  |                 .request("GET", "https://example.com/") | ||||||
|  |                 .eos(), | ||||||
|  |         ) | ||||||
|  |         .await; | ||||||
|  |         srv.send_frame(frames::headers(1).response(200).eos()).await; | ||||||
|  |     }; | ||||||
|  |  | ||||||
|  |     let h2 = async move { | ||||||
|  |         let (mut client, mut h2) = client::handshake(io).await.unwrap(); | ||||||
|  |  | ||||||
|  |         // we send a simple req here just to drive the connection so we can | ||||||
|  |         // receive the server settings. | ||||||
|  |         let request = Request::get("https://example.com/").body(()).unwrap(); | ||||||
|  |         let (response, _) = client.send_request(request, true).unwrap(); | ||||||
|  |         h2.drive(response).await.unwrap(); | ||||||
|  |  | ||||||
|  |         assert!(client.is_extended_connect_protocol_enabled()); | ||||||
|  |     }; | ||||||
|  |  | ||||||
|  |     join(srv, h2).await; | ||||||
|  | } | ||||||
|  |  | ||||||
|  | #[tokio::test] | ||||||
|  | async fn invalid_connect_protocol_enabled_setting() { | ||||||
|  |     h2_support::trace_init!(); | ||||||
|  |  | ||||||
|  |     let (io, mut srv) = mock::new(); | ||||||
|  |  | ||||||
|  |     let srv = async move { | ||||||
|  |         // Send a settings frame | ||||||
|  |         srv.send(frames::settings().enable_connect_protocol(2).into()) | ||||||
|  |             .await | ||||||
|  |             .unwrap(); | ||||||
|  |         srv.read_preface().await.unwrap(); | ||||||
|  |  | ||||||
|  |         let settings = assert_settings!(srv.next().await.expect("unexpected EOF").unwrap()); | ||||||
|  |         assert_default_settings!(settings); | ||||||
|  |  | ||||||
|  |         // Send the ACK | ||||||
|  |         let ack = frame::Settings::ack(); | ||||||
|  |  | ||||||
|  |         // TODO: Don't unwrap? | ||||||
|  |         srv.send(ack.into()).await.unwrap(); | ||||||
|  |  | ||||||
|  |         let frame = srv.next().await.unwrap().unwrap(); | ||||||
|  |         let go_away = assert_go_away!(frame); | ||||||
|  |         assert_eq!(go_away.reason(), Reason::PROTOCOL_ERROR); | ||||||
|  |     }; | ||||||
|  |  | ||||||
|  |     let h2 = async move { | ||||||
|  |         let (mut client, mut h2) = client::handshake(io).await.unwrap(); | ||||||
|  |  | ||||||
|  |         // we send a simple req here just to drive the connection so we can | ||||||
|  |         // receive the server settings. | ||||||
|  |         let request = Request::get("https://example.com/").body(()).unwrap(); | ||||||
|  |         let (response, _) = client.send_request(request, true).unwrap(); | ||||||
|  |  | ||||||
|  |         let error = h2.drive(response).await.unwrap_err(); | ||||||
|  |         assert_eq!(error.reason(), Some(Reason::PROTOCOL_ERROR)); | ||||||
|  |     }; | ||||||
|  |  | ||||||
|  |     join(srv, h2).await; | ||||||
|  | } | ||||||
|  |  | ||||||
|  | #[tokio::test] | ||||||
|  | async fn extended_connect_request() { | ||||||
|  |     h2_support::trace_init!(); | ||||||
|  |  | ||||||
|  |     let (io, mut srv) = mock::new(); | ||||||
|  |  | ||||||
|  |     let srv = async move { | ||||||
|  |         let settings = srv | ||||||
|  |             .assert_client_handshake_with_settings(frames::settings().enable_connect_protocol(1)) | ||||||
|  |             .await; | ||||||
|  |         assert_default_settings!(settings); | ||||||
|  |  | ||||||
|  |         srv.recv_frame( | ||||||
|  |             frames::headers(1) | ||||||
|  |                 .request("CONNECT", "http://bread/baguette") | ||||||
|  |                 .protocol("the-bread-protocol") | ||||||
|  |                 .eos(), | ||||||
|  |         ) | ||||||
|  |         .await; | ||||||
|  |         srv.send_frame(frames::headers(1).response(200).eos()).await; | ||||||
|  |     }; | ||||||
|  |  | ||||||
|  |     let h2 = async move { | ||||||
|  |         let (mut client, mut h2) = client::handshake(io).await.unwrap(); | ||||||
|  |  | ||||||
|  |         let request = Request::connect("http://bread/baguette") | ||||||
|  |             .extension(Protocol::from("the-bread-protocol")) | ||||||
|  |             .body(()) | ||||||
|  |             .unwrap(); | ||||||
|  |         let (response, _) = client.send_request(request, true).unwrap(); | ||||||
|  |         h2.drive(response).await.unwrap(); | ||||||
|  |     }; | ||||||
|  |  | ||||||
|  |     join(srv, h2).await; | ||||||
|  | } | ||||||
|  |  | ||||||
| const SETTINGS: &'static [u8] = &[0, 0, 0, 4, 0, 0, 0, 0, 0]; | const SETTINGS: &'static [u8] = &[0, 0, 0, 4, 0, 0, 0, 0, 0]; | ||||||
| const SETTINGS_ACK: &'static [u8] = &[0, 0, 0, 4, 1, 0, 0, 0, 0]; | const SETTINGS_ACK: &'static [u8] = &[0, 0, 0, 4, 1, 0, 0, 0, 0]; | ||||||
|  |  | ||||||
|   | |||||||
| @@ -1149,3 +1149,191 @@ async fn send_reset_explicitly() { | |||||||
|  |  | ||||||
|     join(client, srv).await; |     join(client, srv).await; | ||||||
| } | } | ||||||
|  |  | ||||||
|  | #[tokio::test] | ||||||
|  | async fn extended_connect_protocol_disabled_by_default() { | ||||||
|  |     h2_support::trace_init!(); | ||||||
|  |  | ||||||
|  |     let (io, mut client) = mock::new(); | ||||||
|  |  | ||||||
|  |     let client = async move { | ||||||
|  |         let settings = client.assert_server_handshake().await; | ||||||
|  |  | ||||||
|  |         assert_eq!(settings.is_extended_connect_protocol_enabled(), None); | ||||||
|  |  | ||||||
|  |         client | ||||||
|  |             .send_frame( | ||||||
|  |                 frames::headers(1) | ||||||
|  |                     .request("CONNECT", "http://bread/baguette") | ||||||
|  |                     .protocol("the-bread-protocol"), | ||||||
|  |             ) | ||||||
|  |             .await; | ||||||
|  |  | ||||||
|  |         client.recv_frame(frames::reset(1).protocol_error()).await; | ||||||
|  |     }; | ||||||
|  |  | ||||||
|  |     let srv = async move { | ||||||
|  |         let mut srv = server::handshake(io).await.expect("handshake"); | ||||||
|  |  | ||||||
|  |         poll_fn(move |cx| srv.poll_closed(cx)) | ||||||
|  |             .await | ||||||
|  |             .expect("server"); | ||||||
|  |     }; | ||||||
|  |  | ||||||
|  |     join(client, srv).await; | ||||||
|  | } | ||||||
|  |  | ||||||
|  | #[tokio::test] | ||||||
|  | async fn extended_connect_protocol_enabled_during_handshake() { | ||||||
|  |     h2_support::trace_init!(); | ||||||
|  |  | ||||||
|  |     let (io, mut client) = mock::new(); | ||||||
|  |  | ||||||
|  |     let client = async move { | ||||||
|  |         let settings = client.assert_server_handshake().await; | ||||||
|  |  | ||||||
|  |         assert_eq!(settings.is_extended_connect_protocol_enabled(), Some(true)); | ||||||
|  |  | ||||||
|  |         client | ||||||
|  |             .send_frame( | ||||||
|  |                 frames::headers(1) | ||||||
|  |                     .request("CONNECT", "http://bread/baguette") | ||||||
|  |                     .protocol("the-bread-protocol"), | ||||||
|  |             ) | ||||||
|  |             .await; | ||||||
|  |  | ||||||
|  |         client.recv_frame(frames::headers(1).response(200)).await; | ||||||
|  |     }; | ||||||
|  |  | ||||||
|  |     let srv = async move { | ||||||
|  |         let mut builder = server::Builder::new(); | ||||||
|  |  | ||||||
|  |         builder.enable_connect_protocol(); | ||||||
|  |  | ||||||
|  |         let mut srv = builder.handshake::<_, Bytes>(io).await.expect("handshake"); | ||||||
|  |  | ||||||
|  |         let (_req, mut stream) = srv.next().await.unwrap().unwrap(); | ||||||
|  |  | ||||||
|  |         let rsp = Response::new(()); | ||||||
|  |         stream.send_response(rsp, false).unwrap(); | ||||||
|  |  | ||||||
|  |         poll_fn(move |cx| srv.poll_closed(cx)) | ||||||
|  |             .await | ||||||
|  |             .expect("server"); | ||||||
|  |     }; | ||||||
|  |  | ||||||
|  |     join(client, srv).await; | ||||||
|  | } | ||||||
|  |  | ||||||
|  | #[tokio::test] | ||||||
|  | async fn reject_pseudo_protocol_on_non_connect_request() { | ||||||
|  |     h2_support::trace_init!(); | ||||||
|  |  | ||||||
|  |     let (io, mut client) = mock::new(); | ||||||
|  |  | ||||||
|  |     let client = async move { | ||||||
|  |         let settings = client.assert_server_handshake().await; | ||||||
|  |  | ||||||
|  |         assert_eq!(settings.is_extended_connect_protocol_enabled(), Some(true)); | ||||||
|  |  | ||||||
|  |         client | ||||||
|  |             .send_frame( | ||||||
|  |                 frames::headers(1) | ||||||
|  |                     .request("GET", "http://bread/baguette") | ||||||
|  |                     .protocol("the-bread-protocol"), | ||||||
|  |             ) | ||||||
|  |             .await; | ||||||
|  |  | ||||||
|  |         client.recv_frame(frames::reset(1).protocol_error()).await; | ||||||
|  |     }; | ||||||
|  |  | ||||||
|  |     let srv = async move { | ||||||
|  |         let mut builder = server::Builder::new(); | ||||||
|  |  | ||||||
|  |         builder.enable_connect_protocol(); | ||||||
|  |  | ||||||
|  |         let mut srv = builder.handshake::<_, Bytes>(io).await.expect("handshake"); | ||||||
|  |  | ||||||
|  |         assert!(srv.next().await.is_none()); | ||||||
|  |  | ||||||
|  |         poll_fn(move |cx| srv.poll_closed(cx)) | ||||||
|  |             .await | ||||||
|  |             .expect("server"); | ||||||
|  |     }; | ||||||
|  |  | ||||||
|  |     join(client, srv).await; | ||||||
|  | } | ||||||
|  |  | ||||||
|  | #[tokio::test] | ||||||
|  | async fn reject_authority_target_on_extended_connect_request() { | ||||||
|  |     h2_support::trace_init!(); | ||||||
|  |  | ||||||
|  |     let (io, mut client) = mock::new(); | ||||||
|  |  | ||||||
|  |     let client = async move { | ||||||
|  |         let settings = client.assert_server_handshake().await; | ||||||
|  |  | ||||||
|  |         assert_eq!(settings.is_extended_connect_protocol_enabled(), Some(true)); | ||||||
|  |  | ||||||
|  |         client | ||||||
|  |             .send_frame( | ||||||
|  |                 frames::headers(1) | ||||||
|  |                     .request("CONNECT", "bread:80") | ||||||
|  |                     .protocol("the-bread-protocol"), | ||||||
|  |             ) | ||||||
|  |             .await; | ||||||
|  |  | ||||||
|  |         client.recv_frame(frames::reset(1).protocol_error()).await; | ||||||
|  |     }; | ||||||
|  |  | ||||||
|  |     let srv = async move { | ||||||
|  |         let mut builder = server::Builder::new(); | ||||||
|  |  | ||||||
|  |         builder.enable_connect_protocol(); | ||||||
|  |  | ||||||
|  |         let mut srv = builder.handshake::<_, Bytes>(io).await.expect("handshake"); | ||||||
|  |  | ||||||
|  |         assert!(srv.next().await.is_none()); | ||||||
|  |  | ||||||
|  |         poll_fn(move |cx| srv.poll_closed(cx)) | ||||||
|  |             .await | ||||||
|  |             .expect("server"); | ||||||
|  |     }; | ||||||
|  |  | ||||||
|  |     join(client, srv).await; | ||||||
|  | } | ||||||
|  |  | ||||||
|  | #[tokio::test] | ||||||
|  | async fn reject_non_authority_target_on_connect_request() { | ||||||
|  |     h2_support::trace_init!(); | ||||||
|  |  | ||||||
|  |     let (io, mut client) = mock::new(); | ||||||
|  |  | ||||||
|  |     let client = async move { | ||||||
|  |         let settings = client.assert_server_handshake().await; | ||||||
|  |  | ||||||
|  |         assert_eq!(settings.is_extended_connect_protocol_enabled(), Some(true)); | ||||||
|  |  | ||||||
|  |         client | ||||||
|  |             .send_frame(frames::headers(1).request("CONNECT", "https://bread/baguette")) | ||||||
|  |             .await; | ||||||
|  |  | ||||||
|  |         client.recv_frame(frames::reset(1).protocol_error()).await; | ||||||
|  |     }; | ||||||
|  |  | ||||||
|  |     let srv = async move { | ||||||
|  |         let mut builder = server::Builder::new(); | ||||||
|  |  | ||||||
|  |         builder.enable_connect_protocol(); | ||||||
|  |  | ||||||
|  |         let mut srv = builder.handshake::<_, Bytes>(io).await.expect("handshake"); | ||||||
|  |  | ||||||
|  |         assert!(srv.next().await.is_none()); | ||||||
|  |  | ||||||
|  |         poll_fn(move |cx| srv.poll_closed(cx)) | ||||||
|  |             .await | ||||||
|  |             .expect("server"); | ||||||
|  |     }; | ||||||
|  |  | ||||||
|  |     join(client, srv).await; | ||||||
|  | } | ||||||
|   | |||||||
		Reference in New Issue
	
	Block a user