Make use of NLL to clean up handshaking logic (#576)

This commit is contained in:
Anthony Ramine
2022-01-26 11:18:28 +01:00
committed by GitHub
parent 7de2ccc1a3
commit 556447c130
2 changed files with 85 additions and 93 deletions

View File

@@ -126,7 +126,7 @@ use std::future::Future;
use std::pin::Pin;
use std::task::{Context, Poll};
use std::time::Duration;
use std::{convert, fmt, io, mem};
use std::{fmt, io};
use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
use tracing::instrument::{Instrument, Instrumented};
@@ -301,8 +301,8 @@ enum Handshaking<T, B: Buf> {
Flushing(Instrumented<Flush<T, Prioritized<B>>>),
/// State 2. Connection is waiting for the client preface.
ReadingPreface(Instrumented<ReadPreface<T, Prioritized<B>>>),
/// Dummy state for `mem::replace`.
Empty,
/// State 3. Handshake is done, polling again would panic.
Done,
}
/// Flush a Sink
@@ -387,7 +387,8 @@ where
.expect("invalid SETTINGS frame");
// Create the handshake future.
let state = Handshaking::from(codec);
let state =
Handshaking::Flushing(Flush::new(codec).instrument(tracing::trace_span!("flush")));
drop(entered);
@@ -1269,63 +1270,58 @@ where
let span = self.span.clone(); // XXX(eliza): T_T
let _e = span.enter();
tracing::trace!(state = ?self.state);
use crate::server::Handshaking::*;
self.state = if let Flushing(ref mut flush) = self.state {
// We're currently flushing a pending SETTINGS frame. Poll the
// flush future, and, if it's completed, advance our state to wait
// for the client preface.
let codec = match Pin::new(flush).poll(cx)? {
Poll::Pending => {
tracing::trace!(flush.poll = %"Pending");
return Poll::Pending;
loop {
match &mut self.state {
Handshaking::Flushing(flush) => {
// We're currently flushing a pending SETTINGS frame. Poll the
// flush future, and, if it's completed, advance our state to wait
// for the client preface.
let codec = match Pin::new(flush).poll(cx)? {
Poll::Pending => {
tracing::trace!(flush.poll = %"Pending");
return Poll::Pending;
}
Poll::Ready(flushed) => {
tracing::trace!(flush.poll = %"Ready");
flushed
}
};
self.state = Handshaking::ReadingPreface(
ReadPreface::new(codec).instrument(tracing::trace_span!("read_preface")),
);
}
Poll::Ready(flushed) => {
tracing::trace!(flush.poll = %"Ready");
flushed
}
};
Handshaking::from(ReadPreface::new(codec))
} else {
// Otherwise, we haven't actually advanced the state, but we have
// to replace it with itself, because we have to return a value.
// (note that the assignment to `self.state` has to be outside of
// the `if let` block above in order to placate the borrow checker).
mem::replace(&mut self.state, Handshaking::Empty)
};
let poll = if let ReadingPreface(ref mut read) = self.state {
// We're now waiting for the client preface. Poll the `ReadPreface`
// future. If it has completed, we will create a `Connection` handle
// for the connection.
Pin::new(read).poll(cx)
// Actually creating the `Connection` has to occur outside of this
// `if let` block, because we've borrowed `self` mutably in order
// to poll the state and won't be able to borrow the SETTINGS frame
// as well until we release the borrow for `poll()`.
} else {
unreachable!("Handshake::poll() state was not advanced completely!")
};
poll?.map(|codec| {
let connection = proto::Connection::new(
codec,
Config {
next_stream_id: 2.into(),
// Server does not need to locally initiate any streams
initial_max_send_streams: 0,
max_send_buffer_size: self.builder.max_send_buffer_size,
reset_stream_duration: self.builder.reset_stream_duration,
reset_stream_max: self.builder.reset_stream_max,
settings: self.builder.settings.clone(),
},
);
Handshaking::ReadingPreface(read) => {
let codec = ready!(Pin::new(read).poll(cx)?);
tracing::trace!("connection established!");
let mut c = Connection { connection };
if let Some(sz) = self.builder.initial_target_connection_window_size {
c.set_target_window_size(sz);
self.state = Handshaking::Done;
let connection = proto::Connection::new(
codec,
Config {
next_stream_id: 2.into(),
// Server does not need to locally initiate any streams
initial_max_send_streams: 0,
max_send_buffer_size: self.builder.max_send_buffer_size,
reset_stream_duration: self.builder.reset_stream_duration,
reset_stream_max: self.builder.reset_stream_max,
settings: self.builder.settings.clone(),
},
);
tracing::trace!("connection established!");
let mut c = Connection { connection };
if let Some(sz) = self.builder.initial_target_connection_window_size {
c.set_target_window_size(sz);
}
return Poll::Ready(Ok(c));
}
Handshaking::Done => {
panic!("Handshaking::poll() called again after handshaking was complete")
}
}
Ok(c)
})
}
}
}
@@ -1548,42 +1544,9 @@ where
#[inline]
fn fmt(&self, f: &mut fmt::Formatter) -> Result<(), fmt::Error> {
match *self {
Handshaking::Flushing(_) => write!(f, "Handshaking::Flushing(_)"),
Handshaking::ReadingPreface(_) => write!(f, "Handshaking::ReadingPreface(_)"),
Handshaking::Empty => write!(f, "Handshaking::Empty"),
Handshaking::Flushing(_) => f.write_str("Flushing(_)"),
Handshaking::ReadingPreface(_) => f.write_str("ReadingPreface(_)"),
Handshaking::Done => f.write_str("Done"),
}
}
}
impl<T, B> convert::From<Flush<T, Prioritized<B>>> for Handshaking<T, B>
where
T: AsyncRead + AsyncWrite,
B: Buf,
{
#[inline]
fn from(flush: Flush<T, Prioritized<B>>) -> Self {
Handshaking::Flushing(flush.instrument(tracing::trace_span!("flush")))
}
}
impl<T, B> convert::From<ReadPreface<T, Prioritized<B>>> for Handshaking<T, B>
where
T: AsyncRead + AsyncWrite,
B: Buf,
{
#[inline]
fn from(read: ReadPreface<T, Prioritized<B>>) -> Self {
Handshaking::ReadingPreface(read.instrument(tracing::trace_span!("read_preface")))
}
}
impl<T, B> convert::From<Codec<T, Prioritized<B>>> for Handshaking<T, B>
where
T: AsyncRead + AsyncWrite,
B: Buf,
{
#[inline]
fn from(codec: Codec<T, Prioritized<B>>) -> Self {
Handshaking::from(Flush::new(codec))
}
}