refactor(lib): use dedicated enums for connection protocol versions

This should make it easier to add H3 functionality.
This commit is contained in:
Dirkjan Ochtman
2019-10-30 22:03:29 +01:00
committed by Sean McArthur
parent e6027bc02d
commit de5dcd7865
2 changed files with 81 additions and 51 deletions

View File

@@ -13,6 +13,7 @@ use std::sync::Arc;
use bytes::Bytes;
use futures_util::future::{self, Either, FutureExt as _};
use pin_project::{pin_project, project};
use tokio_io::{AsyncRead, AsyncWrite};
use tower_service::Service;
@@ -29,10 +30,15 @@ type Http1Dispatcher<T, B, R> = proto::dispatch::Dispatcher<
T,
R,
>;
type ConnEither<T, B> = Either<
Http1Dispatcher<T, B, proto::h1::ClientTransaction>,
proto::h2::ClientTask<B>,
>;
#[pin_project]
enum ProtoClient<T, B>
where
B: Payload,
{
H1(#[pin] Http1Dispatcher<T, B, proto::h1::ClientTransaction>),
H2(#[pin] proto::h2::ClientTask<B>),
}
// Our defaults are chosen for the "majority" case, which usually are not
// resource contrained, and so the spec default of 64kb can be too limiting
@@ -70,7 +76,7 @@ where
T: AsyncRead + AsyncWrite + Send + 'static,
B: Payload + 'static,
{
inner: Option<ConnEither<T, B>>,
inner: Option<ProtoClient<T, B>>,
}
@@ -342,8 +348,8 @@ where
/// Only works for HTTP/1 connections. HTTP/2 connections will panic.
pub fn into_parts(self) -> Parts<T> {
let (io, read_buf, _) = match self.inner.expect("already upgraded") {
Either::Left(h1) => h1.into_inner(),
Either::Right(_h2) => {
ProtoClient::H1(h1) => h1.into_inner(),
ProtoClient::H2(_h2) => {
panic!("http2 cannot into_inner");
}
};
@@ -368,10 +374,10 @@ where
/// to work with this function; or use the `without_shutdown` wrapper.
pub fn poll_without_shutdown(&mut self, cx: &mut task::Context<'_>) -> Poll<crate::Result<()>> {
match self.inner.as_mut().expect("already upgraded") {
&mut Either::Left(ref mut h1) => {
&mut ProtoClient::H1(ref mut h1) => {
h1.poll_without_shutdown(cx)
},
&mut Either::Right(ref mut h2) => {
&mut ProtoClient::H2(ref mut h2) => {
Pin::new(h2).poll(cx).map_ok(|_| ())
}
}
@@ -403,7 +409,7 @@ where
},
proto::Dispatched::Upgrade(pending) => {
let h1 = match mem::replace(&mut self.inner, None) {
Some(Either::Left(h1)) => h1,
Some(ProtoClient::H1(h1)) => h1,
_ => unreachable!("Upgrade expects h1"),
};
@@ -534,7 +540,7 @@ impl Builder {
trace!("client handshake HTTP/{}", if opts.http2 { 2 } else { 1 });
let (tx, rx) = dispatch::channel();
let either = if !opts.http2 {
let proto = if !opts.http2 {
let mut conn = proto::Conn::new(io);
if !opts.h1_writev {
conn.set_write_strategy_flatten();
@@ -550,11 +556,11 @@ impl Builder {
}
let cd = proto::h1::dispatch::Client::new(rx);
let dispatch = proto::h1::Dispatcher::new(cd, conn);
Either::Left(dispatch)
ProtoClient::H1(dispatch)
} else {
let h2 = proto::h2::client::handshake(io, rx, &opts.h2_builder, opts.exec.clone())
.await?;
Either::Right(h2)
ProtoClient::H2(h2)
};
Ok((
@@ -562,7 +568,7 @@ impl Builder {
dispatch: tx,
},
Connection {
inner: Some(either),
inner: Some(proto),
},
))
}
@@ -598,6 +604,26 @@ impl fmt::Debug for ResponseFuture {
}
}
// ===== impl ProtoClient
impl<T, B> Future for ProtoClient<T, B>
where
T: AsyncRead + AsyncWrite + Send + Unpin + 'static,
B: Payload + Unpin + 'static,
B::Data: Unpin,
{
type Output = crate::Result<proto::Dispatched>;
#[project]
fn poll(self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll<Self::Output> {
#[project]
match self.project() {
ProtoClient::H1(c) => c.poll(cx),
ProtoClient::H2(c) => c.poll(cx),
}
}
}
// assert trait markers
trait AssertSend: Send {}