From de5dcd78655e5e313f269ca06c11a32c729a3c38 Mon Sep 17 00:00:00 2001 From: Dirkjan Ochtman Date: Wed, 30 Oct 2019 22:03:29 +0100 Subject: [PATCH] refactor(lib): use dedicated enums for connection protocol versions This should make it easier to add H3 functionality. --- src/client/conn.rs | 54 +++++++++++++++++++++++--------- src/server/conn.rs | 78 ++++++++++++++++++++++++---------------------- 2 files changed, 81 insertions(+), 51 deletions(-) diff --git a/src/client/conn.rs b/src/client/conn.rs index f26791b3..c379b9a0 100644 --- a/src/client/conn.rs +++ b/src/client/conn.rs @@ -13,6 +13,7 @@ use std::sync::Arc; use bytes::Bytes; use futures_util::future::{self, Either, FutureExt as _}; +use pin_project::{pin_project, project}; use tokio_io::{AsyncRead, AsyncWrite}; use tower_service::Service; @@ -29,10 +30,15 @@ type Http1Dispatcher = proto::dispatch::Dispatcher< T, R, >; -type ConnEither = Either< - Http1Dispatcher, - proto::h2::ClientTask, ->; + +#[pin_project] +enum ProtoClient +where + B: Payload, +{ + H1(#[pin] Http1Dispatcher), + H2(#[pin] proto::h2::ClientTask), +} // Our defaults are chosen for the "majority" case, which usually are not // resource contrained, and so the spec default of 64kb can be too limiting @@ -70,7 +76,7 @@ where T: AsyncRead + AsyncWrite + Send + 'static, B: Payload + 'static, { - inner: Option>, + inner: Option>, } @@ -342,8 +348,8 @@ where /// Only works for HTTP/1 connections. HTTP/2 connections will panic. pub fn into_parts(self) -> Parts { let (io, read_buf, _) = match self.inner.expect("already upgraded") { - Either::Left(h1) => h1.into_inner(), - Either::Right(_h2) => { + ProtoClient::H1(h1) => h1.into_inner(), + ProtoClient::H2(_h2) => { panic!("http2 cannot into_inner"); } }; @@ -368,10 +374,10 @@ where /// to work with this function; or use the `without_shutdown` wrapper. pub fn poll_without_shutdown(&mut self, cx: &mut task::Context<'_>) -> Poll> { match self.inner.as_mut().expect("already upgraded") { - &mut Either::Left(ref mut h1) => { + &mut ProtoClient::H1(ref mut h1) => { h1.poll_without_shutdown(cx) }, - &mut Either::Right(ref mut h2) => { + &mut ProtoClient::H2(ref mut h2) => { Pin::new(h2).poll(cx).map_ok(|_| ()) } } @@ -403,7 +409,7 @@ where }, proto::Dispatched::Upgrade(pending) => { let h1 = match mem::replace(&mut self.inner, None) { - Some(Either::Left(h1)) => h1, + Some(ProtoClient::H1(h1)) => h1, _ => unreachable!("Upgrade expects h1"), }; @@ -534,7 +540,7 @@ impl Builder { trace!("client handshake HTTP/{}", if opts.http2 { 2 } else { 1 }); let (tx, rx) = dispatch::channel(); - let either = if !opts.http2 { + let proto = if !opts.http2 { let mut conn = proto::Conn::new(io); if !opts.h1_writev { conn.set_write_strategy_flatten(); @@ -550,11 +556,11 @@ impl Builder { } let cd = proto::h1::dispatch::Client::new(rx); let dispatch = proto::h1::Dispatcher::new(cd, conn); - Either::Left(dispatch) + ProtoClient::H1(dispatch) } else { let h2 = proto::h2::client::handshake(io, rx, &opts.h2_builder, opts.exec.clone()) .await?; - Either::Right(h2) + ProtoClient::H2(h2) }; Ok(( @@ -562,7 +568,7 @@ impl Builder { dispatch: tx, }, Connection { - inner: Some(either), + inner: Some(proto), }, )) } @@ -598,6 +604,26 @@ impl fmt::Debug for ResponseFuture { } } +// ===== impl ProtoClient + +impl Future for ProtoClient +where + T: AsyncRead + AsyncWrite + Send + Unpin + 'static, + B: Payload + Unpin + 'static, + B::Data: Unpin, +{ + type Output = crate::Result; + + #[project] + fn poll(self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll { + #[project] + match self.project() { + ProtoClient::H1(c) => c.poll(cx), + ProtoClient::H2(c) => c.poll(cx), + } + } +} + // assert trait markers trait AssertSend: Send {} diff --git a/src/server/conn.rs b/src/server/conn.rs index a8a9bcc4..dcb3e2e8 100644 --- a/src/server/conn.rs +++ b/src/server/conn.rs @@ -119,27 +119,26 @@ pub struct Connection where S: HttpService, { - pub(super) conn: Option, - S::ResBody, - T, - proto::ServerTransaction, - >, - proto::h2::Server< - Rewind, - S, - S::ResBody, - E, - >, - >>, + pub(super) conn: Option>, fallback: Fallback, } #[pin_project] -pub(super) enum Either { - A(#[pin] A), - B(#[pin] B), +pub(super) enum ProtoServer +where + S: HttpService, + B: Payload, +{ + H1( + #[pin] + proto::h1::Dispatcher< + proto::h1::dispatch::Server, + B, + T, + proto::ServerTransaction, + >, + ), + H2(#[pin] proto::h2::Server, S, B, E>), } #[derive(Clone, Debug)] @@ -384,7 +383,7 @@ impl Http { I: AsyncRead + AsyncWrite + Unpin, E: H2Exec, { - let either = match self.mode { + let proto = match self.mode { ConnectionMode::H1Only | ConnectionMode::Fallback => { let mut conn = proto::Conn::new(io); if !self.keep_alive { @@ -401,17 +400,17 @@ impl Http { conn.set_max_buf_size(max); } let sd = proto::h1::dispatch::Server::new(service); - Either::A(proto::h1::Dispatcher::new(sd, conn)) + ProtoServer::H1(proto::h1::Dispatcher::new(sd, conn)) } ConnectionMode::H2Only => { let rewind_io = Rewind::new(io); let h2 = proto::h2::Server::new(rewind_io, service, &self.h2_builder, self.exec.clone()); - Either::B(h2) + ProtoServer::H2(h2) } }; Connection { - conn: Some(either), + conn: Some(proto), fallback: if self.mode == ConnectionMode::Fallback { Fallback::ToHttp2(self.h2_builder.clone(), self.exec.clone()) } else { @@ -528,10 +527,10 @@ where /// can finish. pub fn graceful_shutdown(self: Pin<&mut Self>) { match self.project().conn.as_mut().unwrap() { - Either::A(ref mut h1) => { + ProtoServer::H1(ref mut h1) => { h1.disable_keep_alive(); }, - Either::B(ref mut h2) => { + ProtoServer::H2(ref mut h2) => { h2.graceful_shutdown(); } } @@ -555,7 +554,7 @@ where /// This method will return a `None` if this connection is using an h2 protocol. pub fn try_into_parts(self) -> Option> { match self.conn.unwrap() { - Either::A(h1) => { + ProtoServer::H1(h1) => { let (io, read_buf, dispatch) = h1.into_inner(); Some(Parts { io: io, @@ -564,7 +563,7 @@ where _inner: (), }) }, - Either::B(_h2) => None, + ProtoServer::H2(_h2) => None, } } @@ -587,8 +586,8 @@ where { loop { let polled = match *self.conn.as_mut().unwrap() { - Either::A(ref mut h1) => h1.poll_without_shutdown(cx), - Either::B(ref mut h2) => return Pin::new(h2).poll(cx).map_ok(|_| ()), + ProtoServer::H1(ref mut h1) => h1.poll_without_shutdown(cx), + ProtoServer::H2(ref mut h2) => return Pin::new(h2).poll(cx).map_ok(|_| ()), }; match ready!(polled) { Ok(x) => return Poll::Ready(Ok(x)), @@ -625,10 +624,10 @@ where let conn = self.conn.take(); let (io, read_buf, dispatch) = match conn.unwrap() { - Either::A(h1) => { + ProtoServer::H1(h1) => { h1.into_inner() }, - Either::B(_h2) => { + ProtoServer::H2(_h2) => { panic!("h2 cannot into_inner"); } }; @@ -646,7 +645,7 @@ where ); debug_assert!(self.conn.is_none()); - self.conn = Some(Either::B(h2)); + self.conn = Some(ProtoServer::H2(h2)); } /// Enable this connection to support higher-level HTTP upgrades. @@ -852,20 +851,25 @@ where } } +// ===== impl ProtoServer ===== -impl Future for Either +impl Future for ProtoServer where - A: Future, - B: Future, + T: AsyncRead + AsyncWrite + Unpin, + S: HttpService, + S::Error: Into>, + B: Payload, + B::Data: Unpin, + E: H2Exec, { - type Output = A::Output; + type Output = crate::Result; #[project] fn poll(self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll { #[project] match self.project() { - Either::A(a) => a.poll(cx), - Either::B(b) => b.poll(cx), + ProtoServer::H1(s) => s.poll(cx), + ProtoServer::H2(s) => s.poll(cx), } } } @@ -1050,7 +1054,7 @@ mod upgrades { Ok(proto::Dispatched::Shutdown) => return Poll::Ready(Ok(())), Ok(proto::Dispatched::Upgrade(pending)) => { let h1 = match mem::replace(&mut self.inner.conn, None) { - Some(Either::A(h1)) => h1, + Some(ProtoServer::H1(h1)) => h1, _ => unreachable!("Upgrade expects h1"), };