refactor(lib): use dedicated enums for connection protocol versions
This should make it easier to add H3 functionality.
This commit is contained in:
committed by
Sean McArthur
parent
e6027bc02d
commit
de5dcd7865
@@ -13,6 +13,7 @@ use std::sync::Arc;
|
||||
|
||||
use bytes::Bytes;
|
||||
use futures_util::future::{self, Either, FutureExt as _};
|
||||
use pin_project::{pin_project, project};
|
||||
use tokio_io::{AsyncRead, AsyncWrite};
|
||||
use tower_service::Service;
|
||||
|
||||
@@ -29,10 +30,15 @@ type Http1Dispatcher<T, B, R> = proto::dispatch::Dispatcher<
|
||||
T,
|
||||
R,
|
||||
>;
|
||||
type ConnEither<T, B> = Either<
|
||||
Http1Dispatcher<T, B, proto::h1::ClientTransaction>,
|
||||
proto::h2::ClientTask<B>,
|
||||
>;
|
||||
|
||||
#[pin_project]
|
||||
enum ProtoClient<T, B>
|
||||
where
|
||||
B: Payload,
|
||||
{
|
||||
H1(#[pin] Http1Dispatcher<T, B, proto::h1::ClientTransaction>),
|
||||
H2(#[pin] proto::h2::ClientTask<B>),
|
||||
}
|
||||
|
||||
// Our defaults are chosen for the "majority" case, which usually are not
|
||||
// resource contrained, and so the spec default of 64kb can be too limiting
|
||||
@@ -70,7 +76,7 @@ where
|
||||
T: AsyncRead + AsyncWrite + Send + 'static,
|
||||
B: Payload + 'static,
|
||||
{
|
||||
inner: Option<ConnEither<T, B>>,
|
||||
inner: Option<ProtoClient<T, B>>,
|
||||
}
|
||||
|
||||
|
||||
@@ -342,8 +348,8 @@ where
|
||||
/// Only works for HTTP/1 connections. HTTP/2 connections will panic.
|
||||
pub fn into_parts(self) -> Parts<T> {
|
||||
let (io, read_buf, _) = match self.inner.expect("already upgraded") {
|
||||
Either::Left(h1) => h1.into_inner(),
|
||||
Either::Right(_h2) => {
|
||||
ProtoClient::H1(h1) => h1.into_inner(),
|
||||
ProtoClient::H2(_h2) => {
|
||||
panic!("http2 cannot into_inner");
|
||||
}
|
||||
};
|
||||
@@ -368,10 +374,10 @@ where
|
||||
/// to work with this function; or use the `without_shutdown` wrapper.
|
||||
pub fn poll_without_shutdown(&mut self, cx: &mut task::Context<'_>) -> Poll<crate::Result<()>> {
|
||||
match self.inner.as_mut().expect("already upgraded") {
|
||||
&mut Either::Left(ref mut h1) => {
|
||||
&mut ProtoClient::H1(ref mut h1) => {
|
||||
h1.poll_without_shutdown(cx)
|
||||
},
|
||||
&mut Either::Right(ref mut h2) => {
|
||||
&mut ProtoClient::H2(ref mut h2) => {
|
||||
Pin::new(h2).poll(cx).map_ok(|_| ())
|
||||
}
|
||||
}
|
||||
@@ -403,7 +409,7 @@ where
|
||||
},
|
||||
proto::Dispatched::Upgrade(pending) => {
|
||||
let h1 = match mem::replace(&mut self.inner, None) {
|
||||
Some(Either::Left(h1)) => h1,
|
||||
Some(ProtoClient::H1(h1)) => h1,
|
||||
_ => unreachable!("Upgrade expects h1"),
|
||||
};
|
||||
|
||||
@@ -534,7 +540,7 @@ impl Builder {
|
||||
trace!("client handshake HTTP/{}", if opts.http2 { 2 } else { 1 });
|
||||
|
||||
let (tx, rx) = dispatch::channel();
|
||||
let either = if !opts.http2 {
|
||||
let proto = if !opts.http2 {
|
||||
let mut conn = proto::Conn::new(io);
|
||||
if !opts.h1_writev {
|
||||
conn.set_write_strategy_flatten();
|
||||
@@ -550,11 +556,11 @@ impl Builder {
|
||||
}
|
||||
let cd = proto::h1::dispatch::Client::new(rx);
|
||||
let dispatch = proto::h1::Dispatcher::new(cd, conn);
|
||||
Either::Left(dispatch)
|
||||
ProtoClient::H1(dispatch)
|
||||
} else {
|
||||
let h2 = proto::h2::client::handshake(io, rx, &opts.h2_builder, opts.exec.clone())
|
||||
.await?;
|
||||
Either::Right(h2)
|
||||
ProtoClient::H2(h2)
|
||||
};
|
||||
|
||||
Ok((
|
||||
@@ -562,7 +568,7 @@ impl Builder {
|
||||
dispatch: tx,
|
||||
},
|
||||
Connection {
|
||||
inner: Some(either),
|
||||
inner: Some(proto),
|
||||
},
|
||||
))
|
||||
}
|
||||
@@ -598,6 +604,26 @@ impl fmt::Debug for ResponseFuture {
|
||||
}
|
||||
}
|
||||
|
||||
// ===== impl ProtoClient
|
||||
|
||||
impl<T, B> Future for ProtoClient<T, B>
|
||||
where
|
||||
T: AsyncRead + AsyncWrite + Send + Unpin + 'static,
|
||||
B: Payload + Unpin + 'static,
|
||||
B::Data: Unpin,
|
||||
{
|
||||
type Output = crate::Result<proto::Dispatched>;
|
||||
|
||||
#[project]
|
||||
fn poll(self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll<Self::Output> {
|
||||
#[project]
|
||||
match self.project() {
|
||||
ProtoClient::H1(c) => c.poll(cx),
|
||||
ProtoClient::H2(c) => c.poll(cx),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// assert trait markers
|
||||
|
||||
trait AssertSend: Send {}
|
||||
|
||||
@@ -119,27 +119,26 @@ pub struct Connection<T, S, E = Exec>
|
||||
where
|
||||
S: HttpService<Body>,
|
||||
{
|
||||
pub(super) conn: Option<Either<
|
||||
proto::h1::Dispatcher<
|
||||
proto::h1::dispatch::Server<S, Body>,
|
||||
S::ResBody,
|
||||
T,
|
||||
proto::ServerTransaction,
|
||||
>,
|
||||
proto::h2::Server<
|
||||
Rewind<T>,
|
||||
S,
|
||||
S::ResBody,
|
||||
E,
|
||||
>,
|
||||
>>,
|
||||
pub(super) conn: Option<ProtoServer<T, S::ResBody, S, E>>,
|
||||
fallback: Fallback<E>,
|
||||
}
|
||||
|
||||
#[pin_project]
|
||||
pub(super) enum Either<A, B> {
|
||||
A(#[pin] A),
|
||||
B(#[pin] B),
|
||||
pub(super) enum ProtoServer<T, B, S, E = Exec>
|
||||
where
|
||||
S: HttpService<Body>,
|
||||
B: Payload,
|
||||
{
|
||||
H1(
|
||||
#[pin]
|
||||
proto::h1::Dispatcher<
|
||||
proto::h1::dispatch::Server<S, Body>,
|
||||
B,
|
||||
T,
|
||||
proto::ServerTransaction,
|
||||
>,
|
||||
),
|
||||
H2(#[pin] proto::h2::Server<Rewind<T>, S, B, E>),
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug)]
|
||||
@@ -384,7 +383,7 @@ impl<E> Http<E> {
|
||||
I: AsyncRead + AsyncWrite + Unpin,
|
||||
E: H2Exec<S::Future, Bd>,
|
||||
{
|
||||
let either = match self.mode {
|
||||
let proto = match self.mode {
|
||||
ConnectionMode::H1Only | ConnectionMode::Fallback => {
|
||||
let mut conn = proto::Conn::new(io);
|
||||
if !self.keep_alive {
|
||||
@@ -401,17 +400,17 @@ impl<E> Http<E> {
|
||||
conn.set_max_buf_size(max);
|
||||
}
|
||||
let sd = proto::h1::dispatch::Server::new(service);
|
||||
Either::A(proto::h1::Dispatcher::new(sd, conn))
|
||||
ProtoServer::H1(proto::h1::Dispatcher::new(sd, conn))
|
||||
}
|
||||
ConnectionMode::H2Only => {
|
||||
let rewind_io = Rewind::new(io);
|
||||
let h2 = proto::h2::Server::new(rewind_io, service, &self.h2_builder, self.exec.clone());
|
||||
Either::B(h2)
|
||||
ProtoServer::H2(h2)
|
||||
}
|
||||
};
|
||||
|
||||
Connection {
|
||||
conn: Some(either),
|
||||
conn: Some(proto),
|
||||
fallback: if self.mode == ConnectionMode::Fallback {
|
||||
Fallback::ToHttp2(self.h2_builder.clone(), self.exec.clone())
|
||||
} else {
|
||||
@@ -528,10 +527,10 @@ where
|
||||
/// can finish.
|
||||
pub fn graceful_shutdown(self: Pin<&mut Self>) {
|
||||
match self.project().conn.as_mut().unwrap() {
|
||||
Either::A(ref mut h1) => {
|
||||
ProtoServer::H1(ref mut h1) => {
|
||||
h1.disable_keep_alive();
|
||||
},
|
||||
Either::B(ref mut h2) => {
|
||||
ProtoServer::H2(ref mut h2) => {
|
||||
h2.graceful_shutdown();
|
||||
}
|
||||
}
|
||||
@@ -555,7 +554,7 @@ where
|
||||
/// This method will return a `None` if this connection is using an h2 protocol.
|
||||
pub fn try_into_parts(self) -> Option<Parts<I, S>> {
|
||||
match self.conn.unwrap() {
|
||||
Either::A(h1) => {
|
||||
ProtoServer::H1(h1) => {
|
||||
let (io, read_buf, dispatch) = h1.into_inner();
|
||||
Some(Parts {
|
||||
io: io,
|
||||
@@ -564,7 +563,7 @@ where
|
||||
_inner: (),
|
||||
})
|
||||
},
|
||||
Either::B(_h2) => None,
|
||||
ProtoServer::H2(_h2) => None,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -587,8 +586,8 @@ where
|
||||
{
|
||||
loop {
|
||||
let polled = match *self.conn.as_mut().unwrap() {
|
||||
Either::A(ref mut h1) => h1.poll_without_shutdown(cx),
|
||||
Either::B(ref mut h2) => return Pin::new(h2).poll(cx).map_ok(|_| ()),
|
||||
ProtoServer::H1(ref mut h1) => h1.poll_without_shutdown(cx),
|
||||
ProtoServer::H2(ref mut h2) => return Pin::new(h2).poll(cx).map_ok(|_| ()),
|
||||
};
|
||||
match ready!(polled) {
|
||||
Ok(x) => return Poll::Ready(Ok(x)),
|
||||
@@ -625,10 +624,10 @@ where
|
||||
let conn = self.conn.take();
|
||||
|
||||
let (io, read_buf, dispatch) = match conn.unwrap() {
|
||||
Either::A(h1) => {
|
||||
ProtoServer::H1(h1) => {
|
||||
h1.into_inner()
|
||||
},
|
||||
Either::B(_h2) => {
|
||||
ProtoServer::H2(_h2) => {
|
||||
panic!("h2 cannot into_inner");
|
||||
}
|
||||
};
|
||||
@@ -646,7 +645,7 @@ where
|
||||
);
|
||||
|
||||
debug_assert!(self.conn.is_none());
|
||||
self.conn = Some(Either::B(h2));
|
||||
self.conn = Some(ProtoServer::H2(h2));
|
||||
}
|
||||
|
||||
/// Enable this connection to support higher-level HTTP upgrades.
|
||||
@@ -852,20 +851,25 @@ where
|
||||
}
|
||||
}
|
||||
|
||||
// ===== impl ProtoServer =====
|
||||
|
||||
impl<A, B> Future for Either<A, B>
|
||||
impl<T, B, S, E> Future for ProtoServer<T, B, S, E>
|
||||
where
|
||||
A: Future,
|
||||
B: Future<Output=A::Output>,
|
||||
T: AsyncRead + AsyncWrite + Unpin,
|
||||
S: HttpService<Body, ResBody = B>,
|
||||
S::Error: Into<Box<dyn StdError + Send + Sync>>,
|
||||
B: Payload,
|
||||
B::Data: Unpin,
|
||||
E: H2Exec<S::Future, B>,
|
||||
{
|
||||
type Output = A::Output;
|
||||
type Output = crate::Result<proto::Dispatched>;
|
||||
|
||||
#[project]
|
||||
fn poll(self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll<Self::Output> {
|
||||
#[project]
|
||||
match self.project() {
|
||||
Either::A(a) => a.poll(cx),
|
||||
Either::B(b) => b.poll(cx),
|
||||
ProtoServer::H1(s) => s.poll(cx),
|
||||
ProtoServer::H2(s) => s.poll(cx),
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1050,7 +1054,7 @@ mod upgrades {
|
||||
Ok(proto::Dispatched::Shutdown) => return Poll::Ready(Ok(())),
|
||||
Ok(proto::Dispatched::Upgrade(pending)) => {
|
||||
let h1 = match mem::replace(&mut self.inner.conn, None) {
|
||||
Some(Either::A(h1)) => h1,
|
||||
Some(ProtoServer::H1(h1)) => h1,
|
||||
_ => unreachable!("Upgrade expects h1"),
|
||||
};
|
||||
|
||||
|
||||
Reference in New Issue
Block a user