Implement the extended CONNECT protocol from RFC 8441 (#565)
This commit is contained in:
		| @@ -47,6 +47,16 @@ macro_rules! assert_settings { | ||||
|     }}; | ||||
| } | ||||
|  | ||||
| #[macro_export] | ||||
| macro_rules! assert_go_away { | ||||
|     ($frame:expr) => {{ | ||||
|         match $frame { | ||||
|             h2::frame::Frame::GoAway(v) => v, | ||||
|             f => panic!("expected GO_AWAY; actual={:?}", f), | ||||
|         } | ||||
|     }}; | ||||
| } | ||||
|  | ||||
| #[macro_export] | ||||
| macro_rules! poll_err { | ||||
|     ($transport:expr) => {{ | ||||
| @@ -80,6 +90,7 @@ macro_rules! assert_default_settings { | ||||
|  | ||||
| use h2::frame::Frame; | ||||
|  | ||||
| #[track_caller] | ||||
| pub fn assert_frame_eq<T: Into<Frame>, U: Into<Frame>>(t: T, u: U) { | ||||
|     let actual: Frame = t.into(); | ||||
|     let expected: Frame = u.into(); | ||||
|   | ||||
| @@ -4,7 +4,10 @@ use std::fmt; | ||||
| use bytes::Bytes; | ||||
| use http::{self, HeaderMap, StatusCode}; | ||||
|  | ||||
| use h2::frame::{self, Frame, StreamId}; | ||||
| use h2::{ | ||||
|     ext::Protocol, | ||||
|     frame::{self, Frame, StreamId}, | ||||
| }; | ||||
|  | ||||
| pub const SETTINGS: &'static [u8] = &[0, 0, 0, 4, 0, 0, 0, 0, 0]; | ||||
| pub const SETTINGS_ACK: &'static [u8] = &[0, 0, 0, 4, 1, 0, 0, 0, 0]; | ||||
| @@ -109,7 +112,9 @@ impl Mock<frame::Headers> { | ||||
|         let method = method.try_into().unwrap(); | ||||
|         let uri = uri.try_into().unwrap(); | ||||
|         let (id, _, fields) = self.into_parts(); | ||||
|         let frame = frame::Headers::new(id, frame::Pseudo::request(method, uri), fields); | ||||
|         let extensions = Default::default(); | ||||
|         let pseudo = frame::Pseudo::request(method, uri, extensions); | ||||
|         let frame = frame::Headers::new(id, pseudo, fields); | ||||
|         Mock(frame) | ||||
|     } | ||||
|  | ||||
| @@ -179,6 +184,15 @@ impl Mock<frame::Headers> { | ||||
|         Mock(frame::Headers::new(id, pseudo, fields)) | ||||
|     } | ||||
|  | ||||
|     pub fn protocol(self, value: &str) -> Self { | ||||
|         let (id, mut pseudo, fields) = self.into_parts(); | ||||
|         let value = Protocol::from(value); | ||||
|  | ||||
|         pseudo.set_protocol(value); | ||||
|  | ||||
|         Mock(frame::Headers::new(id, pseudo, fields)) | ||||
|     } | ||||
|  | ||||
|     pub fn eos(mut self) -> Self { | ||||
|         self.0.set_end_stream(); | ||||
|         self | ||||
| @@ -230,8 +244,9 @@ impl Mock<frame::PushPromise> { | ||||
|         let method = method.try_into().unwrap(); | ||||
|         let uri = uri.try_into().unwrap(); | ||||
|         let (id, promised, _, fields) = self.into_parts(); | ||||
|         let frame = | ||||
|             frame::PushPromise::new(id, promised, frame::Pseudo::request(method, uri), fields); | ||||
|         let extensions = Default::default(); | ||||
|         let pseudo = frame::Pseudo::request(method, uri, extensions); | ||||
|         let frame = frame::PushPromise::new(id, promised, pseudo, fields); | ||||
|         Mock(frame) | ||||
|     } | ||||
|  | ||||
| @@ -352,6 +367,11 @@ impl Mock<frame::Settings> { | ||||
|         self.0.set_enable_push(false); | ||||
|         self | ||||
|     } | ||||
|  | ||||
|     pub fn enable_connect_protocol(mut self, val: u32) -> Self { | ||||
|         self.0.set_enable_connect_protocol(Some(val)); | ||||
|         self | ||||
|     } | ||||
| } | ||||
|  | ||||
| impl From<Mock<frame::Settings>> for frame::Settings { | ||||
|   | ||||
| @@ -221,22 +221,15 @@ impl Handle { | ||||
|         let settings = settings.into(); | ||||
|         self.send(settings.into()).await.unwrap(); | ||||
|  | ||||
|         let frame = self.next().await; | ||||
|         let settings = match frame { | ||||
|             Some(frame) => match frame.unwrap() { | ||||
|                 Frame::Settings(settings) => { | ||||
|                     // Send the ACK | ||||
|                     let ack = frame::Settings::ack(); | ||||
|         let frame = self.next().await.expect("unexpected EOF").unwrap(); | ||||
|         let settings = assert_settings!(frame); | ||||
|  | ||||
|                     // TODO: Don't unwrap? | ||||
|                     self.send(ack.into()).await.unwrap(); | ||||
|         // Send the ACK | ||||
|         let ack = frame::Settings::ack(); | ||||
|  | ||||
|         // TODO: Don't unwrap? | ||||
|         self.send(ack.into()).await.unwrap(); | ||||
|  | ||||
|                     settings | ||||
|                 } | ||||
|                 frame => panic!("unexpected frame; frame={:?}", frame), | ||||
|             }, | ||||
|             None => panic!("unexpected EOF"), | ||||
|         }; | ||||
|         let frame = self.next().await; | ||||
|         let f = assert_settings!(frame.unwrap().unwrap()); | ||||
|  | ||||
|   | ||||
| @@ -2,6 +2,7 @@ | ||||
| pub use h2; | ||||
|  | ||||
| pub use h2::client; | ||||
| pub use h2::ext::Protocol; | ||||
| pub use h2::frame::StreamId; | ||||
| pub use h2::server; | ||||
| pub use h2::*; | ||||
| @@ -20,8 +21,8 @@ pub use super::{Codec, SendFrame}; | ||||
|  | ||||
| // Re-export macros | ||||
| pub use super::{ | ||||
|     assert_closed, assert_data, assert_default_settings, assert_headers, assert_ping, poll_err, | ||||
|     poll_frame, raw_codec, | ||||
|     assert_closed, assert_data, assert_default_settings, assert_go_away, assert_headers, | ||||
|     assert_ping, assert_settings, poll_err, poll_frame, raw_codec, | ||||
| }; | ||||
|  | ||||
| pub use super::assert::assert_frame_eq; | ||||
|   | ||||
| @@ -1305,6 +1305,153 @@ async fn informational_while_local_streaming() { | ||||
|     join(srv, h2).await; | ||||
| } | ||||
|  | ||||
| #[tokio::test] | ||||
| async fn extended_connect_protocol_disabled_by_default() { | ||||
|     h2_support::trace_init!(); | ||||
|     let (io, mut srv) = mock::new(); | ||||
|  | ||||
|     let srv = async move { | ||||
|         let settings = srv.assert_client_handshake().await; | ||||
|         assert_default_settings!(settings); | ||||
|  | ||||
|         srv.recv_frame( | ||||
|             frames::headers(1) | ||||
|                 .request("GET", "https://example.com/") | ||||
|                 .eos(), | ||||
|         ) | ||||
|         .await; | ||||
|         srv.send_frame(frames::headers(1).response(200).eos()).await; | ||||
|     }; | ||||
|  | ||||
|     let h2 = async move { | ||||
|         let (mut client, mut h2) = client::handshake(io).await.unwrap(); | ||||
|  | ||||
|         // we send a simple req here just to drive the connection so we can | ||||
|         // receive the server settings. | ||||
|         let request = Request::get("https://example.com/").body(()).unwrap(); | ||||
|         // first request is allowed | ||||
|         let (response, _) = client.send_request(request, true).unwrap(); | ||||
|         h2.drive(response).await.unwrap(); | ||||
|  | ||||
|         assert!(!client.is_extended_connect_protocol_enabled()); | ||||
|     }; | ||||
|  | ||||
|     join(srv, h2).await; | ||||
| } | ||||
|  | ||||
| #[tokio::test] | ||||
| async fn extended_connect_protocol_enabled_during_handshake() { | ||||
|     h2_support::trace_init!(); | ||||
|     let (io, mut srv) = mock::new(); | ||||
|  | ||||
|     let srv = async move { | ||||
|         let settings = srv | ||||
|             .assert_client_handshake_with_settings(frames::settings().enable_connect_protocol(1)) | ||||
|             .await; | ||||
|         assert_default_settings!(settings); | ||||
|  | ||||
|         srv.recv_frame( | ||||
|             frames::headers(1) | ||||
|                 .request("GET", "https://example.com/") | ||||
|                 .eos(), | ||||
|         ) | ||||
|         .await; | ||||
|         srv.send_frame(frames::headers(1).response(200).eos()).await; | ||||
|     }; | ||||
|  | ||||
|     let h2 = async move { | ||||
|         let (mut client, mut h2) = client::handshake(io).await.unwrap(); | ||||
|  | ||||
|         // we send a simple req here just to drive the connection so we can | ||||
|         // receive the server settings. | ||||
|         let request = Request::get("https://example.com/").body(()).unwrap(); | ||||
|         let (response, _) = client.send_request(request, true).unwrap(); | ||||
|         h2.drive(response).await.unwrap(); | ||||
|  | ||||
|         assert!(client.is_extended_connect_protocol_enabled()); | ||||
|     }; | ||||
|  | ||||
|     join(srv, h2).await; | ||||
| } | ||||
|  | ||||
| #[tokio::test] | ||||
| async fn invalid_connect_protocol_enabled_setting() { | ||||
|     h2_support::trace_init!(); | ||||
|  | ||||
|     let (io, mut srv) = mock::new(); | ||||
|  | ||||
|     let srv = async move { | ||||
|         // Send a settings frame | ||||
|         srv.send(frames::settings().enable_connect_protocol(2).into()) | ||||
|             .await | ||||
|             .unwrap(); | ||||
|         srv.read_preface().await.unwrap(); | ||||
|  | ||||
|         let settings = assert_settings!(srv.next().await.expect("unexpected EOF").unwrap()); | ||||
|         assert_default_settings!(settings); | ||||
|  | ||||
|         // Send the ACK | ||||
|         let ack = frame::Settings::ack(); | ||||
|  | ||||
|         // TODO: Don't unwrap? | ||||
|         srv.send(ack.into()).await.unwrap(); | ||||
|  | ||||
|         let frame = srv.next().await.unwrap().unwrap(); | ||||
|         let go_away = assert_go_away!(frame); | ||||
|         assert_eq!(go_away.reason(), Reason::PROTOCOL_ERROR); | ||||
|     }; | ||||
|  | ||||
|     let h2 = async move { | ||||
|         let (mut client, mut h2) = client::handshake(io).await.unwrap(); | ||||
|  | ||||
|         // we send a simple req here just to drive the connection so we can | ||||
|         // receive the server settings. | ||||
|         let request = Request::get("https://example.com/").body(()).unwrap(); | ||||
|         let (response, _) = client.send_request(request, true).unwrap(); | ||||
|  | ||||
|         let error = h2.drive(response).await.unwrap_err(); | ||||
|         assert_eq!(error.reason(), Some(Reason::PROTOCOL_ERROR)); | ||||
|     }; | ||||
|  | ||||
|     join(srv, h2).await; | ||||
| } | ||||
|  | ||||
| #[tokio::test] | ||||
| async fn extended_connect_request() { | ||||
|     h2_support::trace_init!(); | ||||
|  | ||||
|     let (io, mut srv) = mock::new(); | ||||
|  | ||||
|     let srv = async move { | ||||
|         let settings = srv | ||||
|             .assert_client_handshake_with_settings(frames::settings().enable_connect_protocol(1)) | ||||
|             .await; | ||||
|         assert_default_settings!(settings); | ||||
|  | ||||
|         srv.recv_frame( | ||||
|             frames::headers(1) | ||||
|                 .request("CONNECT", "http://bread/baguette") | ||||
|                 .protocol("the-bread-protocol") | ||||
|                 .eos(), | ||||
|         ) | ||||
|         .await; | ||||
|         srv.send_frame(frames::headers(1).response(200).eos()).await; | ||||
|     }; | ||||
|  | ||||
|     let h2 = async move { | ||||
|         let (mut client, mut h2) = client::handshake(io).await.unwrap(); | ||||
|  | ||||
|         let request = Request::connect("http://bread/baguette") | ||||
|             .extension(Protocol::from("the-bread-protocol")) | ||||
|             .body(()) | ||||
|             .unwrap(); | ||||
|         let (response, _) = client.send_request(request, true).unwrap(); | ||||
|         h2.drive(response).await.unwrap(); | ||||
|     }; | ||||
|  | ||||
|     join(srv, h2).await; | ||||
| } | ||||
|  | ||||
| const SETTINGS: &'static [u8] = &[0, 0, 0, 4, 0, 0, 0, 0, 0]; | ||||
| const SETTINGS_ACK: &'static [u8] = &[0, 0, 0, 4, 1, 0, 0, 0, 0]; | ||||
|  | ||||
|   | ||||
| @@ -1149,3 +1149,191 @@ async fn send_reset_explicitly() { | ||||
|  | ||||
|     join(client, srv).await; | ||||
| } | ||||
|  | ||||
| #[tokio::test] | ||||
| async fn extended_connect_protocol_disabled_by_default() { | ||||
|     h2_support::trace_init!(); | ||||
|  | ||||
|     let (io, mut client) = mock::new(); | ||||
|  | ||||
|     let client = async move { | ||||
|         let settings = client.assert_server_handshake().await; | ||||
|  | ||||
|         assert_eq!(settings.is_extended_connect_protocol_enabled(), None); | ||||
|  | ||||
|         client | ||||
|             .send_frame( | ||||
|                 frames::headers(1) | ||||
|                     .request("CONNECT", "http://bread/baguette") | ||||
|                     .protocol("the-bread-protocol"), | ||||
|             ) | ||||
|             .await; | ||||
|  | ||||
|         client.recv_frame(frames::reset(1).protocol_error()).await; | ||||
|     }; | ||||
|  | ||||
|     let srv = async move { | ||||
|         let mut srv = server::handshake(io).await.expect("handshake"); | ||||
|  | ||||
|         poll_fn(move |cx| srv.poll_closed(cx)) | ||||
|             .await | ||||
|             .expect("server"); | ||||
|     }; | ||||
|  | ||||
|     join(client, srv).await; | ||||
| } | ||||
|  | ||||
| #[tokio::test] | ||||
| async fn extended_connect_protocol_enabled_during_handshake() { | ||||
|     h2_support::trace_init!(); | ||||
|  | ||||
|     let (io, mut client) = mock::new(); | ||||
|  | ||||
|     let client = async move { | ||||
|         let settings = client.assert_server_handshake().await; | ||||
|  | ||||
|         assert_eq!(settings.is_extended_connect_protocol_enabled(), Some(true)); | ||||
|  | ||||
|         client | ||||
|             .send_frame( | ||||
|                 frames::headers(1) | ||||
|                     .request("CONNECT", "http://bread/baguette") | ||||
|                     .protocol("the-bread-protocol"), | ||||
|             ) | ||||
|             .await; | ||||
|  | ||||
|         client.recv_frame(frames::headers(1).response(200)).await; | ||||
|     }; | ||||
|  | ||||
|     let srv = async move { | ||||
|         let mut builder = server::Builder::new(); | ||||
|  | ||||
|         builder.enable_connect_protocol(); | ||||
|  | ||||
|         let mut srv = builder.handshake::<_, Bytes>(io).await.expect("handshake"); | ||||
|  | ||||
|         let (_req, mut stream) = srv.next().await.unwrap().unwrap(); | ||||
|  | ||||
|         let rsp = Response::new(()); | ||||
|         stream.send_response(rsp, false).unwrap(); | ||||
|  | ||||
|         poll_fn(move |cx| srv.poll_closed(cx)) | ||||
|             .await | ||||
|             .expect("server"); | ||||
|     }; | ||||
|  | ||||
|     join(client, srv).await; | ||||
| } | ||||
|  | ||||
| #[tokio::test] | ||||
| async fn reject_pseudo_protocol_on_non_connect_request() { | ||||
|     h2_support::trace_init!(); | ||||
|  | ||||
|     let (io, mut client) = mock::new(); | ||||
|  | ||||
|     let client = async move { | ||||
|         let settings = client.assert_server_handshake().await; | ||||
|  | ||||
|         assert_eq!(settings.is_extended_connect_protocol_enabled(), Some(true)); | ||||
|  | ||||
|         client | ||||
|             .send_frame( | ||||
|                 frames::headers(1) | ||||
|                     .request("GET", "http://bread/baguette") | ||||
|                     .protocol("the-bread-protocol"), | ||||
|             ) | ||||
|             .await; | ||||
|  | ||||
|         client.recv_frame(frames::reset(1).protocol_error()).await; | ||||
|     }; | ||||
|  | ||||
|     let srv = async move { | ||||
|         let mut builder = server::Builder::new(); | ||||
|  | ||||
|         builder.enable_connect_protocol(); | ||||
|  | ||||
|         let mut srv = builder.handshake::<_, Bytes>(io).await.expect("handshake"); | ||||
|  | ||||
|         assert!(srv.next().await.is_none()); | ||||
|  | ||||
|         poll_fn(move |cx| srv.poll_closed(cx)) | ||||
|             .await | ||||
|             .expect("server"); | ||||
|     }; | ||||
|  | ||||
|     join(client, srv).await; | ||||
| } | ||||
|  | ||||
| #[tokio::test] | ||||
| async fn reject_authority_target_on_extended_connect_request() { | ||||
|     h2_support::trace_init!(); | ||||
|  | ||||
|     let (io, mut client) = mock::new(); | ||||
|  | ||||
|     let client = async move { | ||||
|         let settings = client.assert_server_handshake().await; | ||||
|  | ||||
|         assert_eq!(settings.is_extended_connect_protocol_enabled(), Some(true)); | ||||
|  | ||||
|         client | ||||
|             .send_frame( | ||||
|                 frames::headers(1) | ||||
|                     .request("CONNECT", "bread:80") | ||||
|                     .protocol("the-bread-protocol"), | ||||
|             ) | ||||
|             .await; | ||||
|  | ||||
|         client.recv_frame(frames::reset(1).protocol_error()).await; | ||||
|     }; | ||||
|  | ||||
|     let srv = async move { | ||||
|         let mut builder = server::Builder::new(); | ||||
|  | ||||
|         builder.enable_connect_protocol(); | ||||
|  | ||||
|         let mut srv = builder.handshake::<_, Bytes>(io).await.expect("handshake"); | ||||
|  | ||||
|         assert!(srv.next().await.is_none()); | ||||
|  | ||||
|         poll_fn(move |cx| srv.poll_closed(cx)) | ||||
|             .await | ||||
|             .expect("server"); | ||||
|     }; | ||||
|  | ||||
|     join(client, srv).await; | ||||
| } | ||||
|  | ||||
| #[tokio::test] | ||||
| async fn reject_non_authority_target_on_connect_request() { | ||||
|     h2_support::trace_init!(); | ||||
|  | ||||
|     let (io, mut client) = mock::new(); | ||||
|  | ||||
|     let client = async move { | ||||
|         let settings = client.assert_server_handshake().await; | ||||
|  | ||||
|         assert_eq!(settings.is_extended_connect_protocol_enabled(), Some(true)); | ||||
|  | ||||
|         client | ||||
|             .send_frame(frames::headers(1).request("CONNECT", "https://bread/baguette")) | ||||
|             .await; | ||||
|  | ||||
|         client.recv_frame(frames::reset(1).protocol_error()).await; | ||||
|     }; | ||||
|  | ||||
|     let srv = async move { | ||||
|         let mut builder = server::Builder::new(); | ||||
|  | ||||
|         builder.enable_connect_protocol(); | ||||
|  | ||||
|         let mut srv = builder.handshake::<_, Bytes>(io).await.expect("handshake"); | ||||
|  | ||||
|         assert!(srv.next().await.is_none()); | ||||
|  | ||||
|         poll_fn(move |cx| srv.poll_closed(cx)) | ||||
|             .await | ||||
|             .expect("server"); | ||||
|     }; | ||||
|  | ||||
|     join(client, srv).await; | ||||
| } | ||||
|   | ||||
		Reference in New Issue
	
	Block a user