feat(http2): Implement Client-side CONNECT support over HTTP/2 (#2523)

Closes #2508
This commit is contained in:
Anthony Ramine
2021-05-24 20:20:44 +02:00
committed by GitHub
parent be9677a1e7
commit 5442b6fadd
10 changed files with 833 additions and 78 deletions

View File

@@ -3,6 +3,17 @@ use std::fmt;
#[derive(Clone, Copy, PartialEq, Eq)]
pub(crate) struct DecodedLength(u64);
#[cfg(any(feature = "http1", feature = "http2"))]
impl From<Option<u64>> for DecodedLength {
fn from(len: Option<u64>) -> Self {
len.and_then(|len| {
// If the length is u64::MAX, oh well, just reported chunked.
Self::checked_new(len).ok()
})
.unwrap_or(DecodedLength::CHUNKED)
}
}
#[cfg(any(feature = "http1", feature = "http2", test))]
const MAX_LEN: u64 = std::u64::MAX - 2;

View File

@@ -254,12 +254,9 @@ where
absolute_form(req.uri_mut());
} else {
origin_form(req.uri_mut());
};
}
} else if req.method() == Method::CONNECT {
debug!("client does not support CONNECT requests over HTTP2");
return Err(ClientError::Normal(
crate::Error::new_user_unsupported_request_method(),
));
authority_form(req.uri_mut());
}
let fut = pooled

View File

@@ -90,7 +90,7 @@ pub(super) enum User {
/// User tried to send a certain header in an unexpected context.
///
/// For example, sending both `content-length` and `transfer-encoding`.
#[cfg(feature = "http1")]
#[cfg(any(feature = "http1", feature = "http2"))]
#[cfg(feature = "server")]
UnexpectedHeader,
/// User tried to create a Request with bad version.
@@ -290,7 +290,7 @@ impl Error {
Error::new(Kind::User(user))
}
#[cfg(feature = "http1")]
#[cfg(any(feature = "http1", feature = "http2"))]
#[cfg(feature = "server")]
pub(super) fn new_user_header() -> Error {
Error::new_user(User::UnexpectedHeader)
@@ -405,7 +405,7 @@ impl Error {
Kind::User(User::MakeService) => "error from user's MakeService",
#[cfg(any(feature = "http1", feature = "http2"))]
Kind::User(User::Service) => "error from user's Service",
#[cfg(feature = "http1")]
#[cfg(any(feature = "http1", feature = "http2"))]
#[cfg(feature = "server")]
Kind::User(User::UnexpectedHeader) => "user sent unexpected header",
#[cfg(any(feature = "http1", feature = "http2"))]

View File

@@ -2,17 +2,21 @@ use std::error::Error as StdError;
#[cfg(feature = "runtime")]
use std::time::Duration;
use bytes::Bytes;
use futures_channel::{mpsc, oneshot};
use futures_util::future::{self, Either, FutureExt as _, TryFutureExt as _};
use futures_util::stream::StreamExt as _;
use h2::client::{Builder, SendRequest};
use http::{Method, StatusCode};
use tokio::io::{AsyncRead, AsyncWrite};
use super::{decode_content_length, ping, PipeToSendStream, SendBuf};
use super::{ping, H2Upgraded, PipeToSendStream, SendBuf};
use crate::body::HttpBody;
use crate::common::{exec::Exec, task, Future, Never, Pin, Poll};
use crate::headers;
use crate::proto::h2::UpgradedSendStream;
use crate::proto::Dispatched;
use crate::upgrade::Upgraded;
use crate::{Body, Request, Response};
type ClientRx<B> = crate::client::dispatch::Receiver<Request<B>, Response<Body>>;
@@ -233,8 +237,25 @@ where
headers::set_content_length_if_missing(req.headers_mut(), len);
}
}
let is_connect = req.method() == Method::CONNECT;
let eos = body.is_end_stream();
let (fut, body_tx) = match self.h2_tx.send_request(req, eos) {
let ping = self.ping.clone();
if is_connect {
if headers::content_length_parse_all(req.headers())
.map_or(false, |len| len != 0)
{
warn!("h2 connect request with non-zero body not supported");
cb.send(Err((
crate::Error::new_h2(h2::Reason::INTERNAL_ERROR.into()),
None,
)));
continue;
}
}
let (fut, body_tx) = match self.h2_tx.send_request(req, !is_connect && eos) {
Ok(ok) => ok,
Err(err) => {
debug!("client send request error: {}", err);
@@ -243,45 +264,81 @@ where
}
};
let ping = self.ping.clone();
if !eos {
let mut pipe = Box::pin(PipeToSendStream::new(body, body_tx)).map(|res| {
if let Err(e) = res {
debug!("client request body error: {}", e);
}
});
// eagerly see if the body pipe is ready and
// can thus skip allocating in the executor
match Pin::new(&mut pipe).poll(cx) {
Poll::Ready(_) => (),
Poll::Pending => {
let conn_drop_ref = self.conn_drop_ref.clone();
// keep the ping recorder's knowledge of an
// "open stream" alive while this body is
// still sending...
let ping = ping.clone();
let pipe = pipe.map(move |x| {
drop(conn_drop_ref);
drop(ping);
x
let send_stream = if !is_connect {
if !eos {
let mut pipe =
Box::pin(PipeToSendStream::new(body, body_tx)).map(|res| {
if let Err(e) = res {
debug!("client request body error: {}", e);
}
});
self.executor.execute(pipe);
// eagerly see if the body pipe is ready and
// can thus skip allocating in the executor
match Pin::new(&mut pipe).poll(cx) {
Poll::Ready(_) => (),
Poll::Pending => {
let conn_drop_ref = self.conn_drop_ref.clone();
// keep the ping recorder's knowledge of an
// "open stream" alive while this body is
// still sending...
let ping = ping.clone();
let pipe = pipe.map(move |x| {
drop(conn_drop_ref);
drop(ping);
x
});
self.executor.execute(pipe);
}
}
}
}
None
} else {
Some(body_tx)
};
let fut = fut.map(move |result| match result {
Ok(res) => {
// record that we got the response headers
ping.record_non_data();
let content_length = decode_content_length(res.headers());
let res = res.map(|stream| {
let ping = ping.for_stream(&stream);
crate::Body::h2(stream, content_length, ping)
});
Ok(res)
let content_length = headers::content_length_parse_all(res.headers());
if let (Some(mut send_stream), StatusCode::OK) =
(send_stream, res.status())
{
if content_length.map_or(false, |len| len != 0) {
warn!("h2 connect response with non-zero body not supported");
send_stream.send_reset(h2::Reason::INTERNAL_ERROR);
return Err((
crate::Error::new_h2(h2::Reason::INTERNAL_ERROR.into()),
None,
));
}
let (parts, recv_stream) = res.into_parts();
let mut res = Response::from_parts(parts, Body::empty());
let (pending, on_upgrade) = crate::upgrade::pending();
let io = H2Upgraded {
ping,
send_stream: unsafe { UpgradedSendStream::new(send_stream) },
recv_stream,
buf: Bytes::new(),
};
let upgraded = Upgraded::new(io, Bytes::new());
pending.fulfill(upgraded);
res.extensions_mut().insert(on_upgrade);
Ok(res)
} else {
let res = res.map(|stream| {
let ping = ping.for_stream(&stream);
crate::Body::h2(stream, content_length.into(), ping)
});
Ok(res)
}
}
Err(err) => {
ping.ensure_not_timed_out().map_err(|e| (e, None))?;

View File

@@ -1,5 +1,5 @@
use bytes::Buf;
use h2::SendStream;
use bytes::{Buf, Bytes};
use h2::{RecvStream, SendStream};
use http::header::{
HeaderName, CONNECTION, PROXY_AUTHENTICATE, PROXY_AUTHORIZATION, TE, TRAILER,
TRANSFER_ENCODING, UPGRADE,
@@ -7,11 +7,14 @@ use http::header::{
use http::HeaderMap;
use pin_project::pin_project;
use std::error::Error as StdError;
use std::io::IoSlice;
use std::io::{self, Cursor, IoSlice};
use std::mem;
use std::task::Context;
use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
use crate::body::{DecodedLength, HttpBody};
use crate::body::HttpBody;
use crate::common::{task, Future, Pin, Poll};
use crate::headers::content_length_parse_all;
use crate::proto::h2::ping::Recorder;
pub(crate) mod ping;
@@ -83,15 +86,6 @@ fn strip_connection_headers(headers: &mut HeaderMap, is_request: bool) {
}
}
fn decode_content_length(headers: &HeaderMap) -> DecodedLength {
if let Some(len) = content_length_parse_all(headers) {
// If the length is u64::MAX, oh well, just reported chunked.
DecodedLength::checked_new(len).unwrap_or_else(|_| DecodedLength::CHUNKED)
} else {
DecodedLength::CHUNKED
}
}
// body adapters used by both Client and Server
#[pin_project]
@@ -172,7 +166,7 @@ where
is_eos,
);
let buf = SendBuf(Some(chunk));
let buf = SendBuf::Buf(chunk);
me.body_tx
.send_data(buf, is_eos)
.map_err(crate::Error::new_body_write)?;
@@ -243,32 +237,202 @@ impl<B: Buf> SendStreamExt for SendStream<SendBuf<B>> {
fn send_eos_frame(&mut self) -> crate::Result<()> {
trace!("send body eos");
self.send_data(SendBuf(None), true)
self.send_data(SendBuf::None, true)
.map_err(crate::Error::new_body_write)
}
}
struct SendBuf<B>(Option<B>);
#[repr(usize)]
enum SendBuf<B> {
Buf(B),
Cursor(Cursor<Box<[u8]>>),
None,
}
impl<B: Buf> Buf for SendBuf<B> {
#[inline]
fn remaining(&self) -> usize {
self.0.as_ref().map(|b| b.remaining()).unwrap_or(0)
match *self {
Self::Buf(ref b) => b.remaining(),
Self::Cursor(ref c) => c.remaining(),
Self::None => 0,
}
}
#[inline]
fn chunk(&self) -> &[u8] {
self.0.as_ref().map(|b| b.chunk()).unwrap_or(&[])
match *self {
Self::Buf(ref b) => b.chunk(),
Self::Cursor(ref c) => c.chunk(),
Self::None => &[],
}
}
#[inline]
fn advance(&mut self, cnt: usize) {
if let Some(b) = self.0.as_mut() {
b.advance(cnt)
match *self {
Self::Buf(ref mut b) => b.advance(cnt),
Self::Cursor(ref mut c) => c.advance(cnt),
Self::None => {}
}
}
fn chunks_vectored<'a>(&'a self, dst: &mut [IoSlice<'a>]) -> usize {
self.0.as_ref().map(|b| b.chunks_vectored(dst)).unwrap_or(0)
match *self {
Self::Buf(ref b) => b.chunks_vectored(dst),
Self::Cursor(ref c) => c.chunks_vectored(dst),
Self::None => 0,
}
}
}
struct H2Upgraded<B>
where
B: Buf,
{
ping: Recorder,
send_stream: UpgradedSendStream<B>,
recv_stream: RecvStream,
buf: Bytes,
}
impl<B> AsyncRead for H2Upgraded<B>
where
B: Buf,
{
fn poll_read(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
read_buf: &mut ReadBuf<'_>,
) -> Poll<Result<(), io::Error>> {
if self.buf.is_empty() {
self.buf = loop {
match ready!(self.recv_stream.poll_data(cx)) {
None => return Poll::Ready(Ok(())),
Some(Ok(buf)) if buf.is_empty() && !self.recv_stream.is_end_stream() => {
continue
}
Some(Ok(buf)) => {
self.ping.record_data(buf.len());
break buf;
}
Some(Err(e)) => {
return Poll::Ready(Err(h2_to_io_error(e)));
}
}
};
}
let cnt = std::cmp::min(self.buf.len(), read_buf.remaining());
read_buf.put_slice(&self.buf[..cnt]);
self.buf.advance(cnt);
let _ = self.recv_stream.flow_control().release_capacity(cnt);
Poll::Ready(Ok(()))
}
}
impl<B> AsyncWrite for H2Upgraded<B>
where
B: Buf,
{
fn poll_write(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &[u8],
) -> Poll<Result<usize, io::Error>> {
if let Poll::Ready(reset) = self.send_stream.poll_reset(cx) {
return Poll::Ready(Err(h2_to_io_error(match reset {
Ok(reason) => reason.into(),
Err(e) => e,
})));
}
if buf.is_empty() {
return Poll::Ready(Ok(0));
}
self.send_stream.reserve_capacity(buf.len());
Poll::Ready(match ready!(self.send_stream.poll_capacity(cx)) {
None => Ok(0),
Some(Ok(cnt)) => self.send_stream.write(&buf[..cnt], false).map(|()| cnt),
Some(Err(e)) => Err(h2_to_io_error(e)),
})
}
fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
Poll::Ready(Ok(()))
}
fn poll_shutdown(
mut self: Pin<&mut Self>,
_cx: &mut Context<'_>,
) -> Poll<Result<(), io::Error>> {
Poll::Ready(self.send_stream.write(&[], true))
}
}
fn h2_to_io_error(e: h2::Error) -> io::Error {
if e.is_io() {
e.into_io().unwrap()
} else {
io::Error::new(io::ErrorKind::Other, e)
}
}
struct UpgradedSendStream<B>(SendStream<SendBuf<Neutered<B>>>);
impl<B> UpgradedSendStream<B>
where
B: Buf,
{
unsafe fn new(inner: SendStream<SendBuf<B>>) -> Self {
assert_eq!(mem::size_of::<B>(), mem::size_of::<Neutered<B>>());
Self(mem::transmute(inner))
}
fn reserve_capacity(&mut self, cnt: usize) {
unsafe { self.as_inner_unchecked().reserve_capacity(cnt) }
}
fn poll_capacity(&mut self, cx: &mut Context<'_>) -> Poll<Option<Result<usize, h2::Error>>> {
unsafe { self.as_inner_unchecked().poll_capacity(cx) }
}
fn poll_reset(&mut self, cx: &mut Context<'_>) -> Poll<Result<h2::Reason, h2::Error>> {
unsafe { self.as_inner_unchecked().poll_reset(cx) }
}
fn write(&mut self, buf: &[u8], end_of_stream: bool) -> Result<(), io::Error> {
let send_buf = SendBuf::Cursor(Cursor::new(buf.into()));
unsafe {
self.as_inner_unchecked()
.send_data(send_buf, end_of_stream)
.map_err(h2_to_io_error)
}
}
unsafe fn as_inner_unchecked(&mut self) -> &mut SendStream<SendBuf<B>> {
&mut *(&mut self.0 as *mut _ as *mut _)
}
}
#[repr(transparent)]
struct Neutered<B> {
_inner: B,
impossible: Impossible,
}
enum Impossible {}
unsafe impl<B> Send for Neutered<B> {}
impl<B> Buf for Neutered<B> {
fn remaining(&self) -> usize {
match self.impossible {}
}
fn chunk(&self) -> &[u8] {
match self.impossible {}
}
fn advance(&mut self, _cnt: usize) {
match self.impossible {}
}
}

View File

@@ -3,19 +3,24 @@ use std::marker::Unpin;
#[cfg(feature = "runtime")]
use std::time::Duration;
use bytes::Bytes;
use h2::server::{Connection, Handshake, SendResponse};
use h2::Reason;
use h2::{Reason, RecvStream};
use http::{Method, Request};
use pin_project::pin_project;
use tokio::io::{AsyncRead, AsyncWrite};
use super::{decode_content_length, ping, PipeToSendStream, SendBuf};
use super::{ping, PipeToSendStream, SendBuf};
use crate::body::HttpBody;
use crate::common::exec::ConnStreamExec;
use crate::common::{date, task, Future, Pin, Poll};
use crate::headers;
use crate::proto::h2::ping::Recorder;
use crate::proto::h2::{H2Upgraded, UpgradedSendStream};
use crate::proto::Dispatched;
use crate::service::HttpService;
use crate::upgrade::{OnUpgrade, Pending, Upgraded};
use crate::{Body, Response};
// Our defaults are chosen for the "majority" case, which usually are not
@@ -255,9 +260,9 @@ where
// When the service is ready, accepts an incoming request.
match ready!(self.conn.poll_accept(cx)) {
Some(Ok((req, respond))) => {
Some(Ok((req, mut respond))) => {
trace!("incoming request");
let content_length = decode_content_length(req.headers());
let content_length = headers::content_length_parse_all(req.headers());
let ping = self
.ping
.as_ref()
@@ -267,8 +272,36 @@ where
// Record the headers received
ping.record_non_data();
let req = req.map(|stream| crate::Body::h2(stream, content_length, ping));
let fut = H2Stream::new(service.call(req), respond);
let is_connect = req.method() == Method::CONNECT;
let (mut parts, stream) = req.into_parts();
let (req, connect_parts) = if !is_connect {
(
Request::from_parts(
parts,
crate::Body::h2(stream, content_length.into(), ping),
),
None,
)
} else {
if content_length.map_or(false, |len| len != 0) {
warn!("h2 connect request with non-zero body not supported");
respond.send_reset(h2::Reason::INTERNAL_ERROR);
return Poll::Ready(Ok(()));
}
let (pending, upgrade) = crate::upgrade::pending();
debug_assert!(parts.extensions.get::<OnUpgrade>().is_none());
parts.extensions.insert(upgrade);
(
Request::from_parts(parts, crate::Body::empty()),
Some(ConnectParts {
pending,
ping,
recv_stream: stream,
}),
)
};
let fut = H2Stream::new(service.call(req), connect_parts, respond);
exec.execute_h2stream(fut);
}
Some(Err(e)) => {
@@ -331,18 +364,28 @@ enum H2StreamState<F, B>
where
B: HttpBody,
{
Service(#[pin] F),
Service(#[pin] F, Option<ConnectParts>),
Body(#[pin] PipeToSendStream<B>),
}
struct ConnectParts {
pending: Pending,
ping: Recorder,
recv_stream: RecvStream,
}
impl<F, B> H2Stream<F, B>
where
B: HttpBody,
{
fn new(fut: F, respond: SendResponse<SendBuf<B::Data>>) -> H2Stream<F, B> {
fn new(
fut: F,
connect_parts: Option<ConnectParts>,
respond: SendResponse<SendBuf<B::Data>>,
) -> H2Stream<F, B> {
H2Stream {
reply: respond,
state: H2StreamState::Service(fut),
state: H2StreamState::Service(fut, connect_parts),
}
}
}
@@ -364,6 +407,7 @@ impl<F, B, E> H2Stream<F, B>
where
F: Future<Output = Result<Response<B>, E>>,
B: HttpBody,
B::Data: 'static,
B::Error: Into<Box<dyn StdError + Send + Sync>>,
E: Into<Box<dyn StdError + Send + Sync>>,
{
@@ -371,7 +415,7 @@ where
let mut me = self.project();
loop {
let next = match me.state.as_mut().project() {
H2StreamStateProj::Service(h) => {
H2StreamStateProj::Service(h, connect_parts) => {
let res = match h.poll(cx) {
Poll::Ready(Ok(r)) => r,
Poll::Pending => {
@@ -402,6 +446,29 @@ where
.entry(::http::header::DATE)
.or_insert_with(date::update_and_header_value);
if let Some(connect_parts) = connect_parts.take() {
if res.status().is_success() {
if headers::content_length_parse_all(res.headers())
.map_or(false, |len| len != 0)
{
warn!("h2 successful response to CONNECT request with body not supported");
me.reply.send_reset(h2::Reason::INTERNAL_ERROR);
return Poll::Ready(Err(crate::Error::new_user_header()));
}
let send_stream = reply!(me, res, false);
connect_parts.pending.fulfill(Upgraded::new(
H2Upgraded {
ping: connect_parts.ping,
recv_stream: connect_parts.recv_stream,
send_stream: unsafe { UpgradedSendStream::new(send_stream) },
buf: Bytes::new(),
},
Bytes::new(),
));
return Poll::Ready(Ok(()));
}
}
// automatically set Content-Length from body...
if let Some(len) = body.size_hint().exact() {
headers::set_content_length_if_missing(res.headers_mut(), len);
@@ -428,6 +495,7 @@ impl<F, B, E> Future for H2Stream<F, B>
where
F: Future<Output = Result<Response<B>, E>>,
B: HttpBody,
B::Data: 'static,
B::Error: Into<Box<dyn StdError + Send + Sync>>,
E: Into<Box<dyn StdError + Send + Sync>>,
{

View File

@@ -62,12 +62,12 @@ pub fn on<T: sealed::CanUpgrade>(msg: T) -> OnUpgrade {
msg.on_upgrade()
}
#[cfg(feature = "http1")]
#[cfg(any(feature = "http1", feature = "http2"))]
pub(super) struct Pending {
tx: oneshot::Sender<crate::Result<Upgraded>>,
}
#[cfg(feature = "http1")]
#[cfg(any(feature = "http1", feature = "http2"))]
pub(super) fn pending() -> (Pending, OnUpgrade) {
let (tx, rx) = oneshot::channel();
(Pending { tx }, OnUpgrade { rx: Some(rx) })
@@ -76,7 +76,7 @@ pub(super) fn pending() -> (Pending, OnUpgrade) {
// ===== impl Upgraded =====
impl Upgraded {
#[cfg(any(feature = "http1", test))]
#[cfg(any(feature = "http1", feature = "http2", test))]
pub(super) fn new<T>(io: T, read_buf: Bytes) -> Self
where
T: AsyncRead + AsyncWrite + Unpin + Send + 'static,
@@ -187,13 +187,14 @@ impl fmt::Debug for OnUpgrade {
// ===== impl Pending =====
#[cfg(feature = "http1")]
#[cfg(any(feature = "http1", feature = "http2"))]
impl Pending {
pub(super) fn fulfill(self, upgraded: Upgraded) {
trace!("pending upgrade fulfill");
let _ = self.tx.send(Ok(upgraded));
}
#[cfg(feature = "http1")]
/// Don't fulfill the pending Upgrade, but instead signal that
/// upgrades are handled manually.
pub(super) fn manual(self) {