// Graceful shutdown support for the server's accept loop.
//
// `Graceful` drives a `SpawnAll` until a user-supplied `signal` future
// completes, then switches to waiting on a drain channel
// (`common::drain`).
//
// NOTE(review): in this copy the generic parameter lists appear to have been
// stripped. For example, `Graceful`/`State` declare no type parameters yet use
// `F`, and bounds such as `Stream>` / `Into>` are empty. As shown this will not
// compile; restore the `<...>` lists from upstream before building.
use std::error::Error as StdError;
use futures_core::Stream;
use tokio_io::{AsyncRead, AsyncWrite};
use crate::body::{Body, Payload};
use crate::common::drain::{self, Draining, Signal, Watch, Watching};
use crate::common::exec::{H2Exec, NewSvcExec};
use crate::common::{Future, Pin, Poll, Unpin, task};
use crate::service::{MakeServiceRef, Service};
use super::conn::{SpawnAll, UpgradeableConnection, Watcher};

/// A future that runs a `SpawnAll` accept loop until `signal` resolves, and
/// then waits for the drain channel to finish draining.
#[allow(missing_debug_implementations)]
pub struct Graceful {
    state: State,
}

/// Internal state machine for [`Graceful`].
enum State {
    /// Accepting connections and waiting for the shutdown signal.
    Running {
        /// Both halves of the drain channel created in `Graceful::new`.
        /// Only the `Ready` branch of `poll` `take()`s it, and that branch
        /// immediately moves the state to `Draining`.
        drain: Option<(Signal, Watch)>,
        /// The accept loop being driven while no shutdown has been requested.
        spawn_all: SpawnAll,
        /// User-supplied future; its completion starts the graceful shutdown.
        signal: F,
    },
    /// The signal fired; waiting on the future returned by `Signal::drain`.
    Draining(Draining),
}

impl Graceful {
    /// Wraps `spawn_all` so that it shuts down gracefully once `signal` completes.
    ///
    /// A fresh drain channel is created here. It is stored in the `Running`
    /// state until the signal fires.
    pub(super) fn new(spawn_all: SpawnAll, signal: F) -> Self {
        let drain = Some(drain::channel());
        Graceful {
            state: State::Running {
                drain,
                spawn_all,
                signal,
            },
        }
    }
}

impl Future for Graceful
where
    I: Stream>,
    IE: Into>,
    IO: AsyncRead + AsyncWrite + Unpin + Send + 'static,
    S: MakeServiceRef,
    S::Service: 'static,
    S::Error: Into>,
    B: Payload,
    B::Data: Unpin,
    F: Future,
    E: H2Exec<::Future, B>,
    E: NewSvcExec,
{
    type Output = crate::Result<()>;

    /// Each call proceeds according to the current state:
    ///
    /// - `Running`: polls `signal` first.
    ///   - If the signal is still pending, clones the `Watch` half of the drain
    ///     channel, wraps it in a `GracefulWatcher`, and returns whatever
    ///     `spawn_all.poll_watch` returns.
    ///   - If the signal is ready, takes the `Signal` half, starts draining,
    ///     and loops so the new `Draining` state is polled in this same call.
    /// - `Draining`: polls the drain future and wraps its output in `Ok`.
    fn poll(mut self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll {
        // SAFETY: the pinned futures (`signal`, `spawn_all`, `draining`) are
        // NEVER moved out of `me`. `self.state` is only ever overwritten in
        // place (see the assignment at the end of the loop).
        let me = unsafe { self.get_unchecked_mut() };
        loop {
            let next = match me.state {
                State::Running {
                    ref mut drain,
                    ref mut spawn_all,
                    ref mut signal,
                } => match unsafe { Pin::new_unchecked(signal) }.poll(cx) {
                    // SAFETY (above): `signal` lives inside pinned `self` and is not moved.
                    Poll::Ready(()) => {
                        debug!("signal received, starting graceful shutdown");
                        // `drain` is Some: it is set in `new` and only taken here,
                        // right before leaving the `Running` state.
                        let sig = drain
                            .take()
                            .expect("drain channel")
                            .0;
                        State::Draining(sig.drain())
                    },
                    Poll::Pending => {
                        // Hand a clone of the `Watch` half to the accept loop.
                        // NOTE(review): presumably `poll_watch` uses the watcher
                        // to wrap each new connection — confirm in `super::conn`.
                        let watch = drain
                            .as_ref()
                            .expect("drain channel")
                            .1
                            .clone();
                        return unsafe { Pin::new_unchecked(spawn_all) }.poll_watch(cx, &GracefulWatcher(watch));
                    },
                },
                State::Draining(ref mut draining) => {
                    // Safe `Pin::new` is used here, so `Draining` must be `Unpin`.
                    return Pin::new(draining).poll(cx).map(Ok);
                }
            };
            // It's important to just assign, not mem::replace or anything:
            // assignment drops the old (pinned) `Running` value in place rather
            // than moving it out, which keeps the SAFETY promise above.
            me.state = next;
        }
    }
}

/// A `Watcher` that ties each connection to the graceful-shutdown drain channel.
///
/// Holds a clone of the drain channel's `Watch` half.
#[allow(missing_debug_implementations)]
#[derive(Clone)]
pub struct GracefulWatcher(Watch);

impl Watcher for GracefulWatcher
where
    I: AsyncRead + AsyncWrite + Unpin + Send + 'static,
    S: Service + 'static,
    ::Data: Unpin,
    E: H2Exec,
{
    type Future = Watching, fn(Pin<&mut UpgradeableConnection>)>;

    /// Clones the held `Watch` and wraps `conn` in a `Watching` future with
    /// `on_drain` as its callback.
    fn watch(&self, conn: UpgradeableConnection) -> Self::Future {
        self
            .0
            .clone()
            .watch(conn, on_drain)
    }
}

/// Callback passed to `Watch::watch` by `GracefulWatcher`.
///
/// Calls `graceful_shutdown()` on the pinned connection.
/// NOTE(review): presumably invoked by `Watching` once the drain signal
/// fires — confirm in `common::drain`.
fn on_drain(conn: Pin<&mut UpgradeableConnection>)
where
    S: Service,
    S::Error: Into>,
    I: AsyncRead + AsyncWrite + Unpin,
    S::ResBody: Payload + 'static,
    ::Data: Unpin,
    E: H2Exec,
{
    conn.graceful_shutdown()
}