From 4c552a4960c1df7a42eed21df5a055917532b012 Mon Sep 17 00:00:00 2001 From: Aaron Hill Date: Thu, 22 Aug 2019 23:06:12 -0400 Subject: [PATCH] refactor(lib): Use pin-project crate to perform pin projections Remove all pin-related `unsafe` code from Hyper, as well as the now-unused 'pin-utils' dependency. --- Cargo.toml | 3 +- src/common/drain.rs | 17 +++--- src/proto/h2/server.rs | 51 ++++++++++-------- src/server/conn.rs | 120 +++++++++++++++++++++-------------------- src/server/mod.rs | 13 ++--- src/server/shutdown.rs | 68 ++++++++++++----------- 6 files changed, 147 insertions(+), 125 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 82ee388e..b77c149c 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -34,7 +34,8 @@ iovec = "0.1" itoa = "0.4.1" log = "0.4" net2 = { version = "0.2.32", optional = true } -pin-utils = "=0.1.0-alpha.4" +pin-project = { version = "0.4.0-alpha.7", features = ["project_attr"] } + time = "0.1" tokio = { version = "=0.2.0-alpha.4", optional = true, default-features = false, features = ["rt-full"] } tower-service = "=0.3.0-alpha.1" diff --git a/src/common/drain.rs b/src/common/drain.rs index 95eaaaeb..f5735ba5 100644 --- a/src/common/drain.rs +++ b/src/common/drain.rs @@ -2,6 +2,7 @@ use std::mem; use futures_util::FutureExt as _; use tokio_sync::{mpsc, watch}; +use pin_project::pin_project; use super::{Future, Never, Poll, Pin, task}; @@ -43,7 +44,9 @@ pub struct Watch { } #[allow(missing_debug_implementations)] +#[pin_project] pub struct Watching { + #[pin] future: F, state: State, watch: Watch, @@ -95,10 +98,10 @@ where { type Output = F::Output; - fn poll(self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll { - let me = unsafe { self.get_unchecked_mut() }; + fn poll(mut self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll { loop { - match mem::replace(&mut me.state, State::Draining) { + let me = self.project(); + match mem::replace(me.state, State::Draining) { State::Watch(on_drain) => { let recv = me.watch.rx.recv_ref(); futures_util::pin_mut!(recv); @@ -106,17 +109,17 @@ where match recv.poll_unpin(cx) { Poll::Ready(None) => { // Drain has been triggered! - on_drain(unsafe { Pin::new_unchecked(&mut me.future) }); + on_drain(me.future); }, Poll::Ready(Some(_/*State::Open*/)) | Poll::Pending => { - me.state = State::Watch(on_drain); - return unsafe { Pin::new_unchecked(&mut me.future) }.poll(cx); + *me.state = State::Watch(on_drain); + return me.future.poll(cx); }, } }, State::Draining => { - return unsafe { Pin::new_unchecked(&mut me.future) }.poll(cx); + return me.future.poll(cx) }, } } diff --git a/src/proto/h2/server.rs b/src/proto/h2/server.rs index 39e1f514..2e658d48 100644 --- a/src/proto/h2/server.rs +++ b/src/proto/h2/server.rs @@ -1,6 +1,7 @@ use std::error::Error as StdError; use std::marker::Unpin; +use pin_project::{pin_project, project}; use h2::Reason; use h2::server::{Builder, Connection, Handshake, SendResponse}; use tokio_io::{AsyncRead, AsyncWrite}; @@ -199,19 +200,22 @@ where } #[allow(missing_debug_implementations)] +#[pin_project] pub struct H2Stream where B: Payload, { reply: SendResponse>, + #[pin] state: H2StreamState, } +#[pin_project] enum H2StreamState where B: Payload, { - Service(F), + Service(#[pin] F), Body(PipeToSendStream), } @@ -229,6 +233,19 @@ where } } +macro_rules! reply { + ($me:expr, $res:expr, $eos:expr) => ({ + match $me.reply.send_response($res, $eos) { + Ok(tx) => tx, + Err(e) => { + debug!("send response error: {}", e); + $me.reply.send_reset(Reason::INTERNAL_ERROR); + return Poll::Ready(Err(crate::Error::new_h2(e))); + } + } + }) +} + impl H2Stream where F: Future, E>>, @@ -236,13 +253,14 @@ where B::Data: Unpin, E: Into>, { - fn poll2(self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll> { - // Safety: State::{Service, Body} futures are never moved - let me = unsafe { self.get_unchecked_mut() }; + #[project] + fn poll2(mut self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll> { loop { - let next = match me.state { - H2StreamState::Service(ref mut h) => { - let res = match unsafe { Pin::new_unchecked(h) }.poll(cx) { + let mut me = self.project(); + #[project] + let next = match me.state.project() { + H2StreamState::Service(h) => { + let res = match h.poll(cx) { Poll::Ready(Ok(r)) => r, Poll::Pending => { // Response is not yet ready, so we want to check if the client has sent a @@ -274,18 +292,7 @@ where .expect("DATE is a valid HeaderName") .or_insert_with(crate::proto::h1::date::update_and_header_value); - macro_rules! reply { - ($eos:expr) => ({ - match me.reply.send_response(res, $eos) { - Ok(tx) => tx, - Err(e) => { - debug!("send response error: {}", e); - me.reply.send_reset(Reason::INTERNAL_ERROR); - return Poll::Ready(Err(crate::Error::new_h2(e))); - } - } - }) - } + // automatically set Content-Length from body... if let Some(len) = body.size_hint().exact() { @@ -293,10 +300,10 @@ where } if !body.is_end_stream() { - let body_tx = reply!(false); + let body_tx = reply!(me, res, false); H2StreamState::Body(PipeToSendStream::new(body, body_tx)) } else { - reply!(true); + reply!(me, res, true); return Poll::Ready(Ok(())); } }, @@ -304,7 +311,7 @@ where return Pin::new(pipe).poll(cx); } }; - me.state = next; + me.state.set(next); } } } diff --git a/src/server/conn.rs b/src/server/conn.rs index d1458f1c..d0789d0e 100644 --- a/src/server/conn.rs +++ b/src/server/conn.rs @@ -16,8 +16,8 @@ use std::mem; use bytes::Bytes; use futures_core::Stream; -use pin_utils::{unsafe_pinned, unsafe_unpinned}; use tokio_io::{AsyncRead, AsyncWrite}; +use pin_project::{pin_project, project}; #[cfg(feature = "runtime")] use tokio_net::driver::Handle; use crate::body::{Body, Payload}; @@ -69,8 +69,10 @@ enum ConnectionMode { /// /// Yields `Connecting`s that are futures that should be put on a reactor. #[must_use = "streams do nothing unless polled"] +#[pin_project] #[derive(Debug)] pub struct Serve { + #[pin] incoming: I, make_service: S, protocol: Http, @@ -81,16 +83,20 @@ pub struct Serve { /// Wraps the future returned from `MakeService` into one that returns /// a `Connection`. #[must_use = "futures do nothing unless polled"] +#[pin_project] #[derive(Debug)] pub struct Connecting { + #[pin] future: F, io: Option, protocol: Http, } #[must_use = "futures do nothing unless polled"] +#[pin_project] #[derive(Debug)] pub(super) struct SpawnAll { + #[pin] pub(super) serve: Serve, } @@ -98,6 +104,7 @@ pub(super) struct SpawnAll { /// /// Polling this future will drive HTTP forward. #[must_use = "futures do nothing unless polled"] +#[pin_project] pub struct Connection where S: Service, @@ -119,9 +126,10 @@ where fallback: Fallback, } +#[pin_project] pub(super) enum Either { - A(A), - B(B), + A(#[pin] A), + B(#[pin] B), } #[derive(Clone, Debug)] @@ -484,10 +492,8 @@ where /// /// This `Connection` should continue to be polled until shutdown /// can finish. - pub fn graceful_shutdown(self: Pin<&mut Self>) { - // Safety: neither h1 nor h2 poll any of the generic futures - // in these methods. - match unsafe { self.get_unchecked_mut() }.conn.as_mut().unwrap() { + pub fn graceful_shutdown(mut self: Pin<&mut Self>) { + match self.project().conn.as_mut().unwrap() { Either::A(ref mut h1) => { h1.disable_keep_alive(); }, @@ -672,9 +678,6 @@ where // ===== impl Serve ===== impl Serve { - unsafe_pinned!(incoming: I); - unsafe_unpinned!(make_service: S); - /// Spawn all incoming connections onto the executor in `Http`. pub(super) fn spawn_all(self) -> SpawnAll { SpawnAll { @@ -709,7 +712,7 @@ where type Item = crate::Result>; fn poll_next(mut self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll> { - match ready!(self.as_mut().make_service().poll_ready_ref(cx)) { + match ready!(self.project().make_service.poll_ready_ref(cx)) { Ok(()) => (), Err(e) => { trace!("make_service closed"); @@ -717,9 +720,9 @@ where } } - if let Some(item) = ready!(self.as_mut().incoming().poll_next(cx)) { + if let Some(item) = ready!(self.project().incoming.poll_next(cx)) { let io = item.map_err(crate::Error::new_accept)?; - let new_fut = self.as_mut().make_service().make_service_ref(&io); + let new_fut = self.project().make_service.make_service_ref(&io); Poll::Ready(Some(Ok(Connecting { future: new_fut, io: Some(io), @@ -733,10 +736,6 @@ where // ===== impl Connecting ===== -impl Connecting { - unsafe_pinned!(future: F); - unsafe_unpinned!(io: Option); -} impl Future for Connecting where @@ -750,8 +749,8 @@ where type Output = Result, FE>; fn poll(mut self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll { - let service = ready!(self.as_mut().future().poll(cx))?; - let io = self.as_mut().io().take().expect("polled after complete"); + let service = ready!(self.project().future.poll(cx))?; + let io = self.project().io.take().expect("polled after complete"); Poll::Ready(Ok(self.protocol.serve_connection(io, service))) } } @@ -784,17 +783,15 @@ where B: Payload, E: H2Exec<>::Future, B>, { - pub(super) fn poll_watch(self: Pin<&mut Self>, cx: &mut task::Context<'_>, watcher: &W) -> Poll> + pub(super) fn poll_watch(mut self: Pin<&mut Self>, cx: &mut task::Context<'_>, watcher: &W) -> Poll> where E: NewSvcExec, W: Watcher, { - // Safety: futures are never moved... lolwtf - let me = unsafe { self.get_unchecked_mut() }; loop { - if let Some(connecting) = ready!(unsafe { Pin::new_unchecked(&mut me.serve) }.poll_next(cx)?) { + if let Some(connecting) = ready!(self.project().serve.poll_next(cx)?) { let fut = NewSvcTask::new(connecting, watcher.clone()); - me.serve.protocol.exec.execute_new_svc(fut)?; + self.project().serve.project().protocol.exec.execute_new_svc(fut)?; } else { return Poll::Ready(Ok(())); } @@ -810,13 +807,13 @@ where { type Output = A::Output; - fn poll(self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll { + #[project] + fn poll(mut self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll { // Just simple pin projection to the inner variants - unsafe { - match self.get_unchecked_mut() { - Either::A(a) => Pin::new_unchecked(a).poll(cx), - Either::B(b) => Pin::new_unchecked(b).poll(cx), - } + #[project] + match self.project() { + Either::A(a) => a.poll(cx), + Either::B(b) => b.poll(cx), } } } @@ -830,6 +827,7 @@ pub(crate) mod spawn_all { use crate::common::{Future, Pin, Poll, Unpin, task}; use crate::service::Service; use super::{Connecting, UpgradeableConnection}; + use pin_project::{pin_project, project}; // Used by `SpawnAll` to optionally watch a `Connection` future. // @@ -872,14 +870,18 @@ pub(crate) mod spawn_all { // // Users cannot import this type, nor the associated `NewSvcExec`. Instead, // a blanket implementation for `Executor` is sufficient. + + #[pin_project] #[allow(missing_debug_implementations)] pub struct NewSvcTask, E, W: Watcher> { + #[pin] state: State, } - enum State, E, W: Watcher> { - Connecting(Connecting, W), - Connected(W::Future), + #[pin_project] + pub enum State, E, W: Watcher> { + Connecting(#[pin] Connecting, W), + Connected(#[pin] W::Future), } impl, E, W: Watcher> NewSvcTask { @@ -903,39 +905,43 @@ pub(crate) mod spawn_all { { type Output = (); - fn poll(self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll { + #[project] + fn poll(mut self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll { // If it weren't for needing to name this type so the `Send` bounds // could be projected to the `Serve` executor, this could just be // an `async fn`, and much safer. Woe is me. - let me = unsafe { self.get_unchecked_mut() }; loop { - let next = match me.state { - State::Connecting(ref mut connecting, ref watcher) => { - let res = ready!(unsafe { Pin::new_unchecked(connecting).poll(cx) }); - let conn = match res { - Ok(conn) => conn, - Err(err) => { - let err = crate::Error::new_user_make_service(err); - debug!("connecting error: {}", err); - return Poll::Ready(()); - } - }; - let connected = watcher.watch(conn.with_upgrades()); - State::Connected(connected) - }, - State::Connected(ref mut future) => { - return unsafe { Pin::new_unchecked(future) } - .poll(cx) - .map(|res| { - if let Err(err) = res { - debug!("connection error: {}", err); + let mut me = self.project(); + let next = { + #[project] + match me.state.project() { + State::Connecting(connecting, watcher) => { + let res = ready!(connecting.poll(cx)); + let conn = match res { + Ok(conn) => conn, + Err(err) => { + let err = crate::Error::new_user_make_service(err); + debug!("connecting error: {}", err); + return Poll::Ready(()); } - }); + }; + let connected = watcher.watch(conn.with_upgrades()); + State::Connected(connected) + }, + State::Connected(future) => { + return future + .poll(cx) + .map(|res| { + if let Err(err) = res { + debug!("connection error: {}", err); + } + }); + } } }; - me.state = next; + me.state.set(next); } } } diff --git a/src/server/mod.rs b/src/server/mod.rs index bcd307ca..ef306afa 100644 --- a/src/server/mod.rs +++ b/src/server/mod.rs @@ -59,8 +59,8 @@ use std::fmt; #[cfg(feature = "runtime")] use std::time::Duration; use futures_core::Stream; -use pin_utils::unsafe_pinned; use tokio_io::{AsyncRead, AsyncWrite}; +use pin_project::pin_project; use crate::body::{Body, Payload}; use crate::common::exec::{Exec, H2Exec, NewSvcExec}; @@ -78,7 +78,9 @@ use self::shutdown::{Graceful, GracefulWatcher}; /// handlers. It is built using the [`Builder`](Builder), and the future /// completes when the server has been shutdown. It should be run by an /// `Executor`. +#[pin_project] pub struct Server { + #[pin] spawn_all: SpawnAll, } @@ -101,11 +103,6 @@ impl Server { } } -impl Server { - // Never moved, just projected - unsafe_pinned!(spawn_all: SpawnAll); -} - #[cfg(feature = "runtime")] impl Server { /// Binds to the provided address, and returns a [`Builder`](Builder). @@ -216,8 +213,8 @@ where { type Output = crate::Result<()>; - fn poll(self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll { - self.spawn_all().poll_watch(cx, &NoopWatcher) + fn poll(mut self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll { + self.project().spawn_all.poll_watch(cx, &NoopWatcher) } } diff --git a/src/server/shutdown.rs b/src/server/shutdown.rs index 462fd1e1..3080ae9b 100644 --- a/src/server/shutdown.rs +++ b/src/server/shutdown.rs @@ -2,6 +2,7 @@ use std::error::Error as StdError; use futures_core::Stream; use tokio_io::{AsyncRead, AsyncWrite}; +use pin_project::{pin_project, project}; use crate::body::{Body, Payload}; use crate::common::drain::{self, Draining, Signal, Watch, Watching}; @@ -11,14 +12,19 @@ use crate::service::{MakeServiceRef, Service}; use super::conn::{SpawnAll, UpgradeableConnection, Watcher}; #[allow(missing_debug_implementations)] +#[pin_project] pub struct Graceful { + #[pin] state: State, } -enum State { +#[pin_project] +pub(super) enum State { Running { drain: Option<(Signal, Watch)>, + #[pin] spawn_all: SpawnAll, + #[pin] signal: F, }, Draining(Draining), @@ -54,39 +60,41 @@ where { type Output = crate::Result<()>; - fn poll(self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll { - // Safety: the futures are NEVER moved, self.state is overwritten instead. - let me = unsafe { self.get_unchecked_mut() }; + #[project] + fn poll(mut self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll { + let mut me = self.project(); loop { - let next = match me.state { - State::Running { - ref mut drain, - ref mut spawn_all, - ref mut signal, - } => match unsafe { Pin::new_unchecked(signal) }.poll(cx) { - Poll::Ready(()) => { - debug!("signal received, starting graceful shutdown"); - let sig = drain - .take() - .expect("drain channel") - .0; - State::Draining(sig.drain()) + let next = { + #[project] + match me.state.project() { + State::Running { + drain, + spawn_all, + signal, + } => match signal.poll(cx) { + Poll::Ready(()) => { + debug!("signal received, starting graceful shutdown"); + let sig = drain + .take() + .expect("drain channel") + .0; + State::Draining(sig.drain()) + }, + Poll::Pending => { + let watch = drain + .as_ref() + .expect("drain channel") + .1 + .clone(); + return spawn_all.poll_watch(cx, &GracefulWatcher(watch)); + }, }, - Poll::Pending => { - let watch = drain - .as_ref() - .expect("drain channel") - .1 - .clone(); - return unsafe { Pin::new_unchecked(spawn_all) }.poll_watch(cx, &GracefulWatcher(watch)); - }, - }, - State::Draining(ref mut draining) => { - return Pin::new(draining).poll(cx).map(Ok); + State::Draining(ref mut draining) => { + return Pin::new(draining).poll(cx).map(Ok); + } } }; - // It's important to just assign, not mem::replace or anything. - me.state = next; + me.state.set(next); } } }