Implement the extended CONNECT protocol from RFC 8441 (#565)
This commit is contained in:
@@ -47,6 +47,16 @@ macro_rules! assert_settings {
|
||||
}};
|
||||
}
|
||||
|
||||
#[macro_export]
|
||||
macro_rules! assert_go_away {
|
||||
($frame:expr) => {{
|
||||
match $frame {
|
||||
h2::frame::Frame::GoAway(v) => v,
|
||||
f => panic!("expected GO_AWAY; actual={:?}", f),
|
||||
}
|
||||
}};
|
||||
}
|
||||
|
||||
#[macro_export]
|
||||
macro_rules! poll_err {
|
||||
($transport:expr) => {{
|
||||
@@ -80,6 +90,7 @@ macro_rules! assert_default_settings {
|
||||
|
||||
use h2::frame::Frame;
|
||||
|
||||
#[track_caller]
|
||||
pub fn assert_frame_eq<T: Into<Frame>, U: Into<Frame>>(t: T, u: U) {
|
||||
let actual: Frame = t.into();
|
||||
let expected: Frame = u.into();
|
||||
|
||||
@@ -4,7 +4,10 @@ use std::fmt;
|
||||
use bytes::Bytes;
|
||||
use http::{self, HeaderMap, StatusCode};
|
||||
|
||||
use h2::frame::{self, Frame, StreamId};
|
||||
use h2::{
|
||||
ext::Protocol,
|
||||
frame::{self, Frame, StreamId},
|
||||
};
|
||||
|
||||
pub const SETTINGS: &'static [u8] = &[0, 0, 0, 4, 0, 0, 0, 0, 0];
|
||||
pub const SETTINGS_ACK: &'static [u8] = &[0, 0, 0, 4, 1, 0, 0, 0, 0];
|
||||
@@ -109,7 +112,9 @@ impl Mock<frame::Headers> {
|
||||
let method = method.try_into().unwrap();
|
||||
let uri = uri.try_into().unwrap();
|
||||
let (id, _, fields) = self.into_parts();
|
||||
let frame = frame::Headers::new(id, frame::Pseudo::request(method, uri), fields);
|
||||
let extensions = Default::default();
|
||||
let pseudo = frame::Pseudo::request(method, uri, extensions);
|
||||
let frame = frame::Headers::new(id, pseudo, fields);
|
||||
Mock(frame)
|
||||
}
|
||||
|
||||
@@ -179,6 +184,15 @@ impl Mock<frame::Headers> {
|
||||
Mock(frame::Headers::new(id, pseudo, fields))
|
||||
}
|
||||
|
||||
pub fn protocol(self, value: &str) -> Self {
|
||||
let (id, mut pseudo, fields) = self.into_parts();
|
||||
let value = Protocol::from(value);
|
||||
|
||||
pseudo.set_protocol(value);
|
||||
|
||||
Mock(frame::Headers::new(id, pseudo, fields))
|
||||
}
|
||||
|
||||
pub fn eos(mut self) -> Self {
|
||||
self.0.set_end_stream();
|
||||
self
|
||||
@@ -230,8 +244,9 @@ impl Mock<frame::PushPromise> {
|
||||
let method = method.try_into().unwrap();
|
||||
let uri = uri.try_into().unwrap();
|
||||
let (id, promised, _, fields) = self.into_parts();
|
||||
let frame =
|
||||
frame::PushPromise::new(id, promised, frame::Pseudo::request(method, uri), fields);
|
||||
let extensions = Default::default();
|
||||
let pseudo = frame::Pseudo::request(method, uri, extensions);
|
||||
let frame = frame::PushPromise::new(id, promised, pseudo, fields);
|
||||
Mock(frame)
|
||||
}
|
||||
|
||||
@@ -352,6 +367,11 @@ impl Mock<frame::Settings> {
|
||||
self.0.set_enable_push(false);
|
||||
self
|
||||
}
|
||||
|
||||
pub fn enable_connect_protocol(mut self, val: u32) -> Self {
|
||||
self.0.set_enable_connect_protocol(Some(val));
|
||||
self
|
||||
}
|
||||
}
|
||||
|
||||
impl From<Mock<frame::Settings>> for frame::Settings {
|
||||
|
||||
@@ -221,22 +221,15 @@ impl Handle {
|
||||
let settings = settings.into();
|
||||
self.send(settings.into()).await.unwrap();
|
||||
|
||||
let frame = self.next().await;
|
||||
let settings = match frame {
|
||||
Some(frame) => match frame.unwrap() {
|
||||
Frame::Settings(settings) => {
|
||||
// Send the ACK
|
||||
let ack = frame::Settings::ack();
|
||||
let frame = self.next().await.expect("unexpected EOF").unwrap();
|
||||
let settings = assert_settings!(frame);
|
||||
|
||||
// TODO: Don't unwrap?
|
||||
self.send(ack.into()).await.unwrap();
|
||||
// Send the ACK
|
||||
let ack = frame::Settings::ack();
|
||||
|
||||
// TODO: Don't unwrap?
|
||||
self.send(ack.into()).await.unwrap();
|
||||
|
||||
settings
|
||||
}
|
||||
frame => panic!("unexpected frame; frame={:?}", frame),
|
||||
},
|
||||
None => panic!("unexpected EOF"),
|
||||
};
|
||||
let frame = self.next().await;
|
||||
let f = assert_settings!(frame.unwrap().unwrap());
|
||||
|
||||
|
||||
@@ -2,6 +2,7 @@
|
||||
pub use h2;
|
||||
|
||||
pub use h2::client;
|
||||
pub use h2::ext::Protocol;
|
||||
pub use h2::frame::StreamId;
|
||||
pub use h2::server;
|
||||
pub use h2::*;
|
||||
@@ -20,8 +21,8 @@ pub use super::{Codec, SendFrame};
|
||||
|
||||
// Re-export macros
|
||||
pub use super::{
|
||||
assert_closed, assert_data, assert_default_settings, assert_headers, assert_ping, poll_err,
|
||||
poll_frame, raw_codec,
|
||||
assert_closed, assert_data, assert_default_settings, assert_go_away, assert_headers,
|
||||
assert_ping, assert_settings, poll_err, poll_frame, raw_codec,
|
||||
};
|
||||
|
||||
pub use super::assert::assert_frame_eq;
|
||||
|
||||
@@ -1305,6 +1305,153 @@ async fn informational_while_local_streaming() {
|
||||
join(srv, h2).await;
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn extended_connect_protocol_disabled_by_default() {
|
||||
h2_support::trace_init!();
|
||||
let (io, mut srv) = mock::new();
|
||||
|
||||
let srv = async move {
|
||||
let settings = srv.assert_client_handshake().await;
|
||||
assert_default_settings!(settings);
|
||||
|
||||
srv.recv_frame(
|
||||
frames::headers(1)
|
||||
.request("GET", "https://example.com/")
|
||||
.eos(),
|
||||
)
|
||||
.await;
|
||||
srv.send_frame(frames::headers(1).response(200).eos()).await;
|
||||
};
|
||||
|
||||
let h2 = async move {
|
||||
let (mut client, mut h2) = client::handshake(io).await.unwrap();
|
||||
|
||||
// we send a simple req here just to drive the connection so we can
|
||||
// receive the server settings.
|
||||
let request = Request::get("https://example.com/").body(()).unwrap();
|
||||
// first request is allowed
|
||||
let (response, _) = client.send_request(request, true).unwrap();
|
||||
h2.drive(response).await.unwrap();
|
||||
|
||||
assert!(!client.is_extended_connect_protocol_enabled());
|
||||
};
|
||||
|
||||
join(srv, h2).await;
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn extended_connect_protocol_enabled_during_handshake() {
|
||||
h2_support::trace_init!();
|
||||
let (io, mut srv) = mock::new();
|
||||
|
||||
let srv = async move {
|
||||
let settings = srv
|
||||
.assert_client_handshake_with_settings(frames::settings().enable_connect_protocol(1))
|
||||
.await;
|
||||
assert_default_settings!(settings);
|
||||
|
||||
srv.recv_frame(
|
||||
frames::headers(1)
|
||||
.request("GET", "https://example.com/")
|
||||
.eos(),
|
||||
)
|
||||
.await;
|
||||
srv.send_frame(frames::headers(1).response(200).eos()).await;
|
||||
};
|
||||
|
||||
let h2 = async move {
|
||||
let (mut client, mut h2) = client::handshake(io).await.unwrap();
|
||||
|
||||
// we send a simple req here just to drive the connection so we can
|
||||
// receive the server settings.
|
||||
let request = Request::get("https://example.com/").body(()).unwrap();
|
||||
let (response, _) = client.send_request(request, true).unwrap();
|
||||
h2.drive(response).await.unwrap();
|
||||
|
||||
assert!(client.is_extended_connect_protocol_enabled());
|
||||
};
|
||||
|
||||
join(srv, h2).await;
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn invalid_connect_protocol_enabled_setting() {
|
||||
h2_support::trace_init!();
|
||||
|
||||
let (io, mut srv) = mock::new();
|
||||
|
||||
let srv = async move {
|
||||
// Send a settings frame
|
||||
srv.send(frames::settings().enable_connect_protocol(2).into())
|
||||
.await
|
||||
.unwrap();
|
||||
srv.read_preface().await.unwrap();
|
||||
|
||||
let settings = assert_settings!(srv.next().await.expect("unexpected EOF").unwrap());
|
||||
assert_default_settings!(settings);
|
||||
|
||||
// Send the ACK
|
||||
let ack = frame::Settings::ack();
|
||||
|
||||
// TODO: Don't unwrap?
|
||||
srv.send(ack.into()).await.unwrap();
|
||||
|
||||
let frame = srv.next().await.unwrap().unwrap();
|
||||
let go_away = assert_go_away!(frame);
|
||||
assert_eq!(go_away.reason(), Reason::PROTOCOL_ERROR);
|
||||
};
|
||||
|
||||
let h2 = async move {
|
||||
let (mut client, mut h2) = client::handshake(io).await.unwrap();
|
||||
|
||||
// we send a simple req here just to drive the connection so we can
|
||||
// receive the server settings.
|
||||
let request = Request::get("https://example.com/").body(()).unwrap();
|
||||
let (response, _) = client.send_request(request, true).unwrap();
|
||||
|
||||
let error = h2.drive(response).await.unwrap_err();
|
||||
assert_eq!(error.reason(), Some(Reason::PROTOCOL_ERROR));
|
||||
};
|
||||
|
||||
join(srv, h2).await;
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn extended_connect_request() {
|
||||
h2_support::trace_init!();
|
||||
|
||||
let (io, mut srv) = mock::new();
|
||||
|
||||
let srv = async move {
|
||||
let settings = srv
|
||||
.assert_client_handshake_with_settings(frames::settings().enable_connect_protocol(1))
|
||||
.await;
|
||||
assert_default_settings!(settings);
|
||||
|
||||
srv.recv_frame(
|
||||
frames::headers(1)
|
||||
.request("CONNECT", "http://bread/baguette")
|
||||
.protocol("the-bread-protocol")
|
||||
.eos(),
|
||||
)
|
||||
.await;
|
||||
srv.send_frame(frames::headers(1).response(200).eos()).await;
|
||||
};
|
||||
|
||||
let h2 = async move {
|
||||
let (mut client, mut h2) = client::handshake(io).await.unwrap();
|
||||
|
||||
let request = Request::connect("http://bread/baguette")
|
||||
.extension(Protocol::from("the-bread-protocol"))
|
||||
.body(())
|
||||
.unwrap();
|
||||
let (response, _) = client.send_request(request, true).unwrap();
|
||||
h2.drive(response).await.unwrap();
|
||||
};
|
||||
|
||||
join(srv, h2).await;
|
||||
}
|
||||
|
||||
const SETTINGS: &'static [u8] = &[0, 0, 0, 4, 0, 0, 0, 0, 0];
|
||||
const SETTINGS_ACK: &'static [u8] = &[0, 0, 0, 4, 1, 0, 0, 0, 0];
|
||||
|
||||
|
||||
@@ -1149,3 +1149,191 @@ async fn send_reset_explicitly() {
|
||||
|
||||
join(client, srv).await;
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn extended_connect_protocol_disabled_by_default() {
|
||||
h2_support::trace_init!();
|
||||
|
||||
let (io, mut client) = mock::new();
|
||||
|
||||
let client = async move {
|
||||
let settings = client.assert_server_handshake().await;
|
||||
|
||||
assert_eq!(settings.is_extended_connect_protocol_enabled(), None);
|
||||
|
||||
client
|
||||
.send_frame(
|
||||
frames::headers(1)
|
||||
.request("CONNECT", "http://bread/baguette")
|
||||
.protocol("the-bread-protocol"),
|
||||
)
|
||||
.await;
|
||||
|
||||
client.recv_frame(frames::reset(1).protocol_error()).await;
|
||||
};
|
||||
|
||||
let srv = async move {
|
||||
let mut srv = server::handshake(io).await.expect("handshake");
|
||||
|
||||
poll_fn(move |cx| srv.poll_closed(cx))
|
||||
.await
|
||||
.expect("server");
|
||||
};
|
||||
|
||||
join(client, srv).await;
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn extended_connect_protocol_enabled_during_handshake() {
|
||||
h2_support::trace_init!();
|
||||
|
||||
let (io, mut client) = mock::new();
|
||||
|
||||
let client = async move {
|
||||
let settings = client.assert_server_handshake().await;
|
||||
|
||||
assert_eq!(settings.is_extended_connect_protocol_enabled(), Some(true));
|
||||
|
||||
client
|
||||
.send_frame(
|
||||
frames::headers(1)
|
||||
.request("CONNECT", "http://bread/baguette")
|
||||
.protocol("the-bread-protocol"),
|
||||
)
|
||||
.await;
|
||||
|
||||
client.recv_frame(frames::headers(1).response(200)).await;
|
||||
};
|
||||
|
||||
let srv = async move {
|
||||
let mut builder = server::Builder::new();
|
||||
|
||||
builder.enable_connect_protocol();
|
||||
|
||||
let mut srv = builder.handshake::<_, Bytes>(io).await.expect("handshake");
|
||||
|
||||
let (_req, mut stream) = srv.next().await.unwrap().unwrap();
|
||||
|
||||
let rsp = Response::new(());
|
||||
stream.send_response(rsp, false).unwrap();
|
||||
|
||||
poll_fn(move |cx| srv.poll_closed(cx))
|
||||
.await
|
||||
.expect("server");
|
||||
};
|
||||
|
||||
join(client, srv).await;
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn reject_pseudo_protocol_on_non_connect_request() {
|
||||
h2_support::trace_init!();
|
||||
|
||||
let (io, mut client) = mock::new();
|
||||
|
||||
let client = async move {
|
||||
let settings = client.assert_server_handshake().await;
|
||||
|
||||
assert_eq!(settings.is_extended_connect_protocol_enabled(), Some(true));
|
||||
|
||||
client
|
||||
.send_frame(
|
||||
frames::headers(1)
|
||||
.request("GET", "http://bread/baguette")
|
||||
.protocol("the-bread-protocol"),
|
||||
)
|
||||
.await;
|
||||
|
||||
client.recv_frame(frames::reset(1).protocol_error()).await;
|
||||
};
|
||||
|
||||
let srv = async move {
|
||||
let mut builder = server::Builder::new();
|
||||
|
||||
builder.enable_connect_protocol();
|
||||
|
||||
let mut srv = builder.handshake::<_, Bytes>(io).await.expect("handshake");
|
||||
|
||||
assert!(srv.next().await.is_none());
|
||||
|
||||
poll_fn(move |cx| srv.poll_closed(cx))
|
||||
.await
|
||||
.expect("server");
|
||||
};
|
||||
|
||||
join(client, srv).await;
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn reject_authority_target_on_extended_connect_request() {
|
||||
h2_support::trace_init!();
|
||||
|
||||
let (io, mut client) = mock::new();
|
||||
|
||||
let client = async move {
|
||||
let settings = client.assert_server_handshake().await;
|
||||
|
||||
assert_eq!(settings.is_extended_connect_protocol_enabled(), Some(true));
|
||||
|
||||
client
|
||||
.send_frame(
|
||||
frames::headers(1)
|
||||
.request("CONNECT", "bread:80")
|
||||
.protocol("the-bread-protocol"),
|
||||
)
|
||||
.await;
|
||||
|
||||
client.recv_frame(frames::reset(1).protocol_error()).await;
|
||||
};
|
||||
|
||||
let srv = async move {
|
||||
let mut builder = server::Builder::new();
|
||||
|
||||
builder.enable_connect_protocol();
|
||||
|
||||
let mut srv = builder.handshake::<_, Bytes>(io).await.expect("handshake");
|
||||
|
||||
assert!(srv.next().await.is_none());
|
||||
|
||||
poll_fn(move |cx| srv.poll_closed(cx))
|
||||
.await
|
||||
.expect("server");
|
||||
};
|
||||
|
||||
join(client, srv).await;
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn reject_non_authority_target_on_connect_request() {
|
||||
h2_support::trace_init!();
|
||||
|
||||
let (io, mut client) = mock::new();
|
||||
|
||||
let client = async move {
|
||||
let settings = client.assert_server_handshake().await;
|
||||
|
||||
assert_eq!(settings.is_extended_connect_protocol_enabled(), Some(true));
|
||||
|
||||
client
|
||||
.send_frame(frames::headers(1).request("CONNECT", "https://bread/baguette"))
|
||||
.await;
|
||||
|
||||
client.recv_frame(frames::reset(1).protocol_error()).await;
|
||||
};
|
||||
|
||||
let srv = async move {
|
||||
let mut builder = server::Builder::new();
|
||||
|
||||
builder.enable_connect_protocol();
|
||||
|
||||
let mut srv = builder.handshake::<_, Bytes>(io).await.expect("handshake");
|
||||
|
||||
assert!(srv.next().await.is_none());
|
||||
|
||||
poll_fn(move |cx| srv.poll_closed(cx))
|
||||
.await
|
||||
.expect("server");
|
||||
};
|
||||
|
||||
join(client, srv).await;
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user