feat(http2): Implement Client-side CONNECT support over HTTP/2 (#2523)

Closes #2508
This commit is contained in:
Anthony Ramine
2021-05-24 20:20:44 +02:00
committed by GitHub
parent be9677a1e7
commit 5442b6fadd
10 changed files with 833 additions and 78 deletions

View File

@@ -3,19 +3,24 @@ use std::marker::Unpin;
#[cfg(feature = "runtime")]
use std::time::Duration;
use bytes::Bytes;
use h2::server::{Connection, Handshake, SendResponse};
use h2::Reason;
use h2::{Reason, RecvStream};
use http::{Method, Request};
use pin_project::pin_project;
use tokio::io::{AsyncRead, AsyncWrite};
use super::{decode_content_length, ping, PipeToSendStream, SendBuf};
use super::{ping, PipeToSendStream, SendBuf};
use crate::body::HttpBody;
use crate::common::exec::ConnStreamExec;
use crate::common::{date, task, Future, Pin, Poll};
use crate::headers;
use crate::proto::h2::ping::Recorder;
use crate::proto::h2::{H2Upgraded, UpgradedSendStream};
use crate::proto::Dispatched;
use crate::service::HttpService;
use crate::upgrade::{OnUpgrade, Pending, Upgraded};
use crate::{Body, Response};
// Our defaults are chosen for the "majority" case, which usually are not
@@ -255,9 +260,9 @@ where
// When the service is ready, accepts an incoming request.
match ready!(self.conn.poll_accept(cx)) {
Some(Ok((req, respond))) => {
Some(Ok((req, mut respond))) => {
trace!("incoming request");
let content_length = decode_content_length(req.headers());
let content_length = headers::content_length_parse_all(req.headers());
let ping = self
.ping
.as_ref()
@@ -267,8 +272,36 @@ where
// Record the headers received
ping.record_non_data();
let req = req.map(|stream| crate::Body::h2(stream, content_length, ping));
let fut = H2Stream::new(service.call(req), respond);
let is_connect = req.method() == Method::CONNECT;
let (mut parts, stream) = req.into_parts();
let (req, connect_parts) = if !is_connect {
(
Request::from_parts(
parts,
crate::Body::h2(stream, content_length.into(), ping),
),
None,
)
} else {
if content_length.map_or(false, |len| len != 0) {
warn!("h2 connect request with non-zero body not supported");
respond.send_reset(h2::Reason::INTERNAL_ERROR);
return Poll::Ready(Ok(()));
}
let (pending, upgrade) = crate::upgrade::pending();
debug_assert!(parts.extensions.get::<OnUpgrade>().is_none());
parts.extensions.insert(upgrade);
(
Request::from_parts(parts, crate::Body::empty()),
Some(ConnectParts {
pending,
ping,
recv_stream: stream,
}),
)
};
let fut = H2Stream::new(service.call(req), connect_parts, respond);
exec.execute_h2stream(fut);
}
Some(Err(e)) => {
@@ -331,18 +364,28 @@ enum H2StreamState<F, B>
where
B: HttpBody,
{
Service(#[pin] F),
Service(#[pin] F, Option<ConnectParts>),
Body(#[pin] PipeToSendStream<B>),
}
struct ConnectParts {
pending: Pending,
ping: Recorder,
recv_stream: RecvStream,
}
impl<F, B> H2Stream<F, B>
where
B: HttpBody,
{
fn new(fut: F, respond: SendResponse<SendBuf<B::Data>>) -> H2Stream<F, B> {
fn new(
fut: F,
connect_parts: Option<ConnectParts>,
respond: SendResponse<SendBuf<B::Data>>,
) -> H2Stream<F, B> {
H2Stream {
reply: respond,
state: H2StreamState::Service(fut),
state: H2StreamState::Service(fut, connect_parts),
}
}
}
@@ -364,6 +407,7 @@ impl<F, B, E> H2Stream<F, B>
where
F: Future<Output = Result<Response<B>, E>>,
B: HttpBody,
B::Data: 'static,
B::Error: Into<Box<dyn StdError + Send + Sync>>,
E: Into<Box<dyn StdError + Send + Sync>>,
{
@@ -371,7 +415,7 @@ where
let mut me = self.project();
loop {
let next = match me.state.as_mut().project() {
H2StreamStateProj::Service(h) => {
H2StreamStateProj::Service(h, connect_parts) => {
let res = match h.poll(cx) {
Poll::Ready(Ok(r)) => r,
Poll::Pending => {
@@ -402,6 +446,29 @@ where
.entry(::http::header::DATE)
.or_insert_with(date::update_and_header_value);
if let Some(connect_parts) = connect_parts.take() {
if res.status().is_success() {
if headers::content_length_parse_all(res.headers())
.map_or(false, |len| len != 0)
{
warn!("h2 successful response to CONNECT request with body not supported");
me.reply.send_reset(h2::Reason::INTERNAL_ERROR);
return Poll::Ready(Err(crate::Error::new_user_header()));
}
let send_stream = reply!(me, res, false);
connect_parts.pending.fulfill(Upgraded::new(
H2Upgraded {
ping: connect_parts.ping,
recv_stream: connect_parts.recv_stream,
send_stream: unsafe { UpgradedSendStream::new(send_stream) },
buf: Bytes::new(),
},
Bytes::new(),
));
return Poll::Ready(Ok(()));
}
}
// automatically set Content-Length from body...
if let Some(len) = body.size_hint().exact() {
headers::set_content_length_if_missing(res.headers_mut(), len);
@@ -428,6 +495,7 @@ impl<F, B, E> Future for H2Stream<F, B>
where
F: Future<Output = Result<Response<B>, E>>,
B: HttpBody,
B::Data: 'static,
B::Error: Into<Box<dyn StdError + Send + Sync>>,
E: Into<Box<dyn StdError + Send + Sync>>,
{