Implement the extended CONNECT protocol from RFC 8441 (#565)

This commit is contained in:
Anthony Ramine
2021-11-24 10:05:10 +01:00
committed by GitHub
parent dbaa3a4285
commit 87969c1f29
22 changed files with 694 additions and 120 deletions

View File

@@ -1305,6 +1305,153 @@ async fn informational_while_local_streaming() {
join(srv, h2).await;
}
#[tokio::test]
async fn extended_connect_protocol_disabled_by_default() {
h2_support::trace_init!();
let (io, mut srv) = mock::new();
let srv = async move {
let settings = srv.assert_client_handshake().await;
assert_default_settings!(settings);
srv.recv_frame(
frames::headers(1)
.request("GET", "https://example.com/")
.eos(),
)
.await;
srv.send_frame(frames::headers(1).response(200).eos()).await;
};
let h2 = async move {
let (mut client, mut h2) = client::handshake(io).await.unwrap();
// we send a simple req here just to drive the connection so we can
// receive the server settings.
let request = Request::get("https://example.com/").body(()).unwrap();
// first request is allowed
let (response, _) = client.send_request(request, true).unwrap();
h2.drive(response).await.unwrap();
assert!(!client.is_extended_connect_protocol_enabled());
};
join(srv, h2).await;
}
#[tokio::test]
async fn extended_connect_protocol_enabled_during_handshake() {
h2_support::trace_init!();
let (io, mut srv) = mock::new();
let srv = async move {
let settings = srv
.assert_client_handshake_with_settings(frames::settings().enable_connect_protocol(1))
.await;
assert_default_settings!(settings);
srv.recv_frame(
frames::headers(1)
.request("GET", "https://example.com/")
.eos(),
)
.await;
srv.send_frame(frames::headers(1).response(200).eos()).await;
};
let h2 = async move {
let (mut client, mut h2) = client::handshake(io).await.unwrap();
// we send a simple req here just to drive the connection so we can
// receive the server settings.
let request = Request::get("https://example.com/").body(()).unwrap();
let (response, _) = client.send_request(request, true).unwrap();
h2.drive(response).await.unwrap();
assert!(client.is_extended_connect_protocol_enabled());
};
join(srv, h2).await;
}
#[tokio::test]
async fn invalid_connect_protocol_enabled_setting() {
h2_support::trace_init!();
let (io, mut srv) = mock::new();
let srv = async move {
// Send a settings frame
srv.send(frames::settings().enable_connect_protocol(2).into())
.await
.unwrap();
srv.read_preface().await.unwrap();
let settings = assert_settings!(srv.next().await.expect("unexpected EOF").unwrap());
assert_default_settings!(settings);
// Send the ACK
let ack = frame::Settings::ack();
// TODO: Don't unwrap?
srv.send(ack.into()).await.unwrap();
let frame = srv.next().await.unwrap().unwrap();
let go_away = assert_go_away!(frame);
assert_eq!(go_away.reason(), Reason::PROTOCOL_ERROR);
};
let h2 = async move {
let (mut client, mut h2) = client::handshake(io).await.unwrap();
// we send a simple req here just to drive the connection so we can
// receive the server settings.
let request = Request::get("https://example.com/").body(()).unwrap();
let (response, _) = client.send_request(request, true).unwrap();
let error = h2.drive(response).await.unwrap_err();
assert_eq!(error.reason(), Some(Reason::PROTOCOL_ERROR));
};
join(srv, h2).await;
}
#[tokio::test]
async fn extended_connect_request() {
h2_support::trace_init!();
let (io, mut srv) = mock::new();
let srv = async move {
let settings = srv
.assert_client_handshake_with_settings(frames::settings().enable_connect_protocol(1))
.await;
assert_default_settings!(settings);
srv.recv_frame(
frames::headers(1)
.request("CONNECT", "http://bread/baguette")
.protocol("the-bread-protocol")
.eos(),
)
.await;
srv.send_frame(frames::headers(1).response(200).eos()).await;
};
let h2 = async move {
let (mut client, mut h2) = client::handshake(io).await.unwrap();
let request = Request::connect("http://bread/baguette")
.extension(Protocol::from("the-bread-protocol"))
.body(())
.unwrap();
let (response, _) = client.send_request(request, true).unwrap();
h2.drive(response).await.unwrap();
};
join(srv, h2).await;
}
const SETTINGS: &'static [u8] = &[0, 0, 0, 4, 0, 0, 0, 0, 0];
const SETTINGS_ACK: &'static [u8] = &[0, 0, 0, 4, 1, 0, 0, 0, 0];

View File

@@ -1149,3 +1149,191 @@ async fn send_reset_explicitly() {
join(client, srv).await;
}
#[tokio::test]
async fn extended_connect_protocol_disabled_by_default() {
h2_support::trace_init!();
let (io, mut client) = mock::new();
let client = async move {
let settings = client.assert_server_handshake().await;
assert_eq!(settings.is_extended_connect_protocol_enabled(), None);
client
.send_frame(
frames::headers(1)
.request("CONNECT", "http://bread/baguette")
.protocol("the-bread-protocol"),
)
.await;
client.recv_frame(frames::reset(1).protocol_error()).await;
};
let srv = async move {
let mut srv = server::handshake(io).await.expect("handshake");
poll_fn(move |cx| srv.poll_closed(cx))
.await
.expect("server");
};
join(client, srv).await;
}
#[tokio::test]
async fn extended_connect_protocol_enabled_during_handshake() {
h2_support::trace_init!();
let (io, mut client) = mock::new();
let client = async move {
let settings = client.assert_server_handshake().await;
assert_eq!(settings.is_extended_connect_protocol_enabled(), Some(true));
client
.send_frame(
frames::headers(1)
.request("CONNECT", "http://bread/baguette")
.protocol("the-bread-protocol"),
)
.await;
client.recv_frame(frames::headers(1).response(200)).await;
};
let srv = async move {
let mut builder = server::Builder::new();
builder.enable_connect_protocol();
let mut srv = builder.handshake::<_, Bytes>(io).await.expect("handshake");
let (_req, mut stream) = srv.next().await.unwrap().unwrap();
let rsp = Response::new(());
stream.send_response(rsp, false).unwrap();
poll_fn(move |cx| srv.poll_closed(cx))
.await
.expect("server");
};
join(client, srv).await;
}
#[tokio::test]
async fn reject_pseudo_protocol_on_non_connect_request() {
h2_support::trace_init!();
let (io, mut client) = mock::new();
let client = async move {
let settings = client.assert_server_handshake().await;
assert_eq!(settings.is_extended_connect_protocol_enabled(), Some(true));
client
.send_frame(
frames::headers(1)
.request("GET", "http://bread/baguette")
.protocol("the-bread-protocol"),
)
.await;
client.recv_frame(frames::reset(1).protocol_error()).await;
};
let srv = async move {
let mut builder = server::Builder::new();
builder.enable_connect_protocol();
let mut srv = builder.handshake::<_, Bytes>(io).await.expect("handshake");
assert!(srv.next().await.is_none());
poll_fn(move |cx| srv.poll_closed(cx))
.await
.expect("server");
};
join(client, srv).await;
}
#[tokio::test]
async fn reject_authority_target_on_extended_connect_request() {
h2_support::trace_init!();
let (io, mut client) = mock::new();
let client = async move {
let settings = client.assert_server_handshake().await;
assert_eq!(settings.is_extended_connect_protocol_enabled(), Some(true));
client
.send_frame(
frames::headers(1)
.request("CONNECT", "bread:80")
.protocol("the-bread-protocol"),
)
.await;
client.recv_frame(frames::reset(1).protocol_error()).await;
};
let srv = async move {
let mut builder = server::Builder::new();
builder.enable_connect_protocol();
let mut srv = builder.handshake::<_, Bytes>(io).await.expect("handshake");
assert!(srv.next().await.is_none());
poll_fn(move |cx| srv.poll_closed(cx))
.await
.expect("server");
};
join(client, srv).await;
}
#[tokio::test]
async fn reject_non_authority_target_on_connect_request() {
h2_support::trace_init!();
let (io, mut client) = mock::new();
let client = async move {
let settings = client.assert_server_handshake().await;
assert_eq!(settings.is_extended_connect_protocol_enabled(), Some(true));
client
.send_frame(frames::headers(1).request("CONNECT", "https://bread/baguette"))
.await;
client.recv_frame(frames::reset(1).protocol_error()).await;
};
let srv = async move {
let mut builder = server::Builder::new();
builder.enable_connect_protocol();
let mut srv = builder.handshake::<_, Bytes>(io).await.expect("handshake");
assert!(srv.next().await.is_none());
poll_fn(move |cx| srv.poll_closed(cx))
.await
.expect("server");
};
join(client, srv).await;
}