refactor(lib): use dedicated enums for connection protocol versions

This should make it easier to add H3 functionality.
This commit is contained in:
Dirkjan Ochtman
2019-10-30 22:03:29 +01:00
committed by Sean McArthur
parent e6027bc02d
commit de5dcd7865
2 changed files with 81 additions and 51 deletions

View File

@@ -13,6 +13,7 @@ use std::sync::Arc;
use bytes::Bytes;
use futures_util::future::{self, Either, FutureExt as _};
use pin_project::{pin_project, project};
use tokio_io::{AsyncRead, AsyncWrite};
use tower_service::Service;
@@ -29,10 +30,15 @@ type Http1Dispatcher<T, B, R> = proto::dispatch::Dispatcher<
T,
R,
>;
type ConnEither<T, B> = Either<
Http1Dispatcher<T, B, proto::h1::ClientTransaction>,
proto::h2::ClientTask<B>,
>;
#[pin_project]
enum ProtoClient<T, B>
where
B: Payload,
{
H1(#[pin] Http1Dispatcher<T, B, proto::h1::ClientTransaction>),
H2(#[pin] proto::h2::ClientTask<B>),
}
// Our defaults are chosen for the "majority" case, which usually are not
// resource contrained, and so the spec default of 64kb can be too limiting
@@ -70,7 +76,7 @@ where
T: AsyncRead + AsyncWrite + Send + 'static,
B: Payload + 'static,
{
inner: Option<ConnEither<T, B>>,
inner: Option<ProtoClient<T, B>>,
}
@@ -342,8 +348,8 @@ where
/// Only works for HTTP/1 connections. HTTP/2 connections will panic.
pub fn into_parts(self) -> Parts<T> {
let (io, read_buf, _) = match self.inner.expect("already upgraded") {
Either::Left(h1) => h1.into_inner(),
Either::Right(_h2) => {
ProtoClient::H1(h1) => h1.into_inner(),
ProtoClient::H2(_h2) => {
panic!("http2 cannot into_inner");
}
};
@@ -368,10 +374,10 @@ where
/// to work with this function; or use the `without_shutdown` wrapper.
pub fn poll_without_shutdown(&mut self, cx: &mut task::Context<'_>) -> Poll<crate::Result<()>> {
match self.inner.as_mut().expect("already upgraded") {
&mut Either::Left(ref mut h1) => {
&mut ProtoClient::H1(ref mut h1) => {
h1.poll_without_shutdown(cx)
},
&mut Either::Right(ref mut h2) => {
&mut ProtoClient::H2(ref mut h2) => {
Pin::new(h2).poll(cx).map_ok(|_| ())
}
}
@@ -403,7 +409,7 @@ where
},
proto::Dispatched::Upgrade(pending) => {
let h1 = match mem::replace(&mut self.inner, None) {
Some(Either::Left(h1)) => h1,
Some(ProtoClient::H1(h1)) => h1,
_ => unreachable!("Upgrade expects h1"),
};
@@ -534,7 +540,7 @@ impl Builder {
trace!("client handshake HTTP/{}", if opts.http2 { 2 } else { 1 });
let (tx, rx) = dispatch::channel();
let either = if !opts.http2 {
let proto = if !opts.http2 {
let mut conn = proto::Conn::new(io);
if !opts.h1_writev {
conn.set_write_strategy_flatten();
@@ -550,11 +556,11 @@ impl Builder {
}
let cd = proto::h1::dispatch::Client::new(rx);
let dispatch = proto::h1::Dispatcher::new(cd, conn);
Either::Left(dispatch)
ProtoClient::H1(dispatch)
} else {
let h2 = proto::h2::client::handshake(io, rx, &opts.h2_builder, opts.exec.clone())
.await?;
Either::Right(h2)
ProtoClient::H2(h2)
};
Ok((
@@ -562,7 +568,7 @@ impl Builder {
dispatch: tx,
},
Connection {
inner: Some(either),
inner: Some(proto),
},
))
}
@@ -598,6 +604,26 @@ impl fmt::Debug for ResponseFuture {
}
}
// ===== impl ProtoClient
impl<T, B> Future for ProtoClient<T, B>
where
T: AsyncRead + AsyncWrite + Send + Unpin + 'static,
B: Payload + Unpin + 'static,
B::Data: Unpin,
{
type Output = crate::Result<proto::Dispatched>;
#[project]
fn poll(self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll<Self::Output> {
#[project]
match self.project() {
ProtoClient::H1(c) => c.poll(cx),
ProtoClient::H2(c) => c.poll(cx),
}
}
}
// assert trait markers
trait AssertSend: Send {}

View File

@@ -119,27 +119,26 @@ pub struct Connection<T, S, E = Exec>
where
S: HttpService<Body>,
{
pub(super) conn: Option<Either<
proto::h1::Dispatcher<
proto::h1::dispatch::Server<S, Body>,
S::ResBody,
T,
proto::ServerTransaction,
>,
proto::h2::Server<
Rewind<T>,
S,
S::ResBody,
E,
>,
>>,
pub(super) conn: Option<ProtoServer<T, S::ResBody, S, E>>,
fallback: Fallback<E>,
}
#[pin_project]
pub(super) enum Either<A, B> {
A(#[pin] A),
B(#[pin] B),
pub(super) enum ProtoServer<T, B, S, E = Exec>
where
S: HttpService<Body>,
B: Payload,
{
H1(
#[pin]
proto::h1::Dispatcher<
proto::h1::dispatch::Server<S, Body>,
B,
T,
proto::ServerTransaction,
>,
),
H2(#[pin] proto::h2::Server<Rewind<T>, S, B, E>),
}
#[derive(Clone, Debug)]
@@ -384,7 +383,7 @@ impl<E> Http<E> {
I: AsyncRead + AsyncWrite + Unpin,
E: H2Exec<S::Future, Bd>,
{
let either = match self.mode {
let proto = match self.mode {
ConnectionMode::H1Only | ConnectionMode::Fallback => {
let mut conn = proto::Conn::new(io);
if !self.keep_alive {
@@ -401,17 +400,17 @@ impl<E> Http<E> {
conn.set_max_buf_size(max);
}
let sd = proto::h1::dispatch::Server::new(service);
Either::A(proto::h1::Dispatcher::new(sd, conn))
ProtoServer::H1(proto::h1::Dispatcher::new(sd, conn))
}
ConnectionMode::H2Only => {
let rewind_io = Rewind::new(io);
let h2 = proto::h2::Server::new(rewind_io, service, &self.h2_builder, self.exec.clone());
Either::B(h2)
ProtoServer::H2(h2)
}
};
Connection {
conn: Some(either),
conn: Some(proto),
fallback: if self.mode == ConnectionMode::Fallback {
Fallback::ToHttp2(self.h2_builder.clone(), self.exec.clone())
} else {
@@ -528,10 +527,10 @@ where
/// can finish.
pub fn graceful_shutdown(self: Pin<&mut Self>) {
match self.project().conn.as_mut().unwrap() {
Either::A(ref mut h1) => {
ProtoServer::H1(ref mut h1) => {
h1.disable_keep_alive();
},
Either::B(ref mut h2) => {
ProtoServer::H2(ref mut h2) => {
h2.graceful_shutdown();
}
}
@@ -555,7 +554,7 @@ where
/// This method will return a `None` if this connection is using an h2 protocol.
pub fn try_into_parts(self) -> Option<Parts<I, S>> {
match self.conn.unwrap() {
Either::A(h1) => {
ProtoServer::H1(h1) => {
let (io, read_buf, dispatch) = h1.into_inner();
Some(Parts {
io: io,
@@ -564,7 +563,7 @@ where
_inner: (),
})
},
Either::B(_h2) => None,
ProtoServer::H2(_h2) => None,
}
}
@@ -587,8 +586,8 @@ where
{
loop {
let polled = match *self.conn.as_mut().unwrap() {
Either::A(ref mut h1) => h1.poll_without_shutdown(cx),
Either::B(ref mut h2) => return Pin::new(h2).poll(cx).map_ok(|_| ()),
ProtoServer::H1(ref mut h1) => h1.poll_without_shutdown(cx),
ProtoServer::H2(ref mut h2) => return Pin::new(h2).poll(cx).map_ok(|_| ()),
};
match ready!(polled) {
Ok(x) => return Poll::Ready(Ok(x)),
@@ -625,10 +624,10 @@ where
let conn = self.conn.take();
let (io, read_buf, dispatch) = match conn.unwrap() {
Either::A(h1) => {
ProtoServer::H1(h1) => {
h1.into_inner()
},
Either::B(_h2) => {
ProtoServer::H2(_h2) => {
panic!("h2 cannot into_inner");
}
};
@@ -646,7 +645,7 @@ where
);
debug_assert!(self.conn.is_none());
self.conn = Some(Either::B(h2));
self.conn = Some(ProtoServer::H2(h2));
}
/// Enable this connection to support higher-level HTTP upgrades.
@@ -852,20 +851,25 @@ where
}
}
// ===== impl ProtoServer =====
impl<A, B> Future for Either<A, B>
impl<T, B, S, E> Future for ProtoServer<T, B, S, E>
where
A: Future,
B: Future<Output=A::Output>,
T: AsyncRead + AsyncWrite + Unpin,
S: HttpService<Body, ResBody = B>,
S::Error: Into<Box<dyn StdError + Send + Sync>>,
B: Payload,
B::Data: Unpin,
E: H2Exec<S::Future, B>,
{
type Output = A::Output;
type Output = crate::Result<proto::Dispatched>;
#[project]
fn poll(self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll<Self::Output> {
#[project]
match self.project() {
Either::A(a) => a.poll(cx),
Either::B(b) => b.poll(cx),
ProtoServer::H1(s) => s.poll(cx),
ProtoServer::H2(s) => s.poll(cx),
}
}
}
@@ -1050,7 +1054,7 @@ mod upgrades {
Ok(proto::Dispatched::Shutdown) => return Poll::Ready(Ok(())),
Ok(proto::Dispatched::Upgrade(pending)) => {
let h1 = match mem::replace(&mut self.inner.conn, None) {
Some(Either::A(h1)) => h1,
Some(ProtoServer::H1(h1)) => h1,
_ => unreachable!("Upgrade expects h1"),
};