diff --git a/src/client.rs b/src/client.rs
index 099e1f9..62aea85 100644
--- a/src/client.rs
+++ b/src/client.rs
@@ -1124,6 +1124,20 @@ where
 
 // ===== impl Connection =====
 
+async fn bind_connection<T>(io: &mut T) -> Result<(), crate::Error>
+where
+    T: AsyncRead + AsyncWrite + Unpin,
+{
+    tracing::debug!("binding client connection");
+
+    let msg: &'static [u8] = b"PRI * HTTP/2.0\r\n\r\nSM\r\n\r\n";
+    io.write_all(msg).await.map_err(crate::Error::from_io)?;
+
+    tracing::debug!("client connection bound");
+
+    Ok(())
+}
+
 impl<T, B> Connection<T, B>
 where
     T: AsyncRead + AsyncWrite + Unpin,
@@ -1133,12 +1147,7 @@ where
         mut io: T,
         builder: Builder,
     ) -> Result<(SendRequest<B>, Connection<T, B>), crate::Error> {
-        tracing::debug!("binding client connection");
-
-        let msg: &'static [u8] = b"PRI * HTTP/2.0\r\n\r\nSM\r\n\r\n";
-        io.write_all(msg).await.map_err(crate::Error::from_io)?;
-
-        tracing::debug!("client connection bound");
+        bind_connection(&mut io).await?;
 
         // Create the codec
         let mut codec = Codec::new(io);
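Note: the hunk above is the PR's core trick in miniature. `handshake2` is instantiated once per `(T, B)` pair, but writing the connection preface depends only on the I/O type, so hoisting it into a free `bind_connection<T>` means that body is compiled once per `T` no matter how many buffer types are in play. A minimal compilable sketch of the pattern, with illustrative names (`handshake` and `bind` are not h2 APIs):

// Instantiated once per (T, B) combination, but now only a thin shim.
fn handshake<T: std::fmt::Debug, B: Default>(io: T) -> (T, B) {
    bind(&io);
    (io, B::default())
}

// Instantiated once per T, shared by every buffer type B above.
fn bind<T: std::fmt::Debug>(io: &T) {
    println!("binding {:?}", io);
}

fn main() {
    let (_io, _buf): (u8, Vec<u8>) = handshake(1u8);
    let (_io, _buf): (u8, String) = handshake(2u8);
}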
diff --git a/src/codec/framed_write.rs b/src/codec/framed_write.rs
index 8ec2045..b69979a 100644
--- a/src/codec/framed_write.rs
+++ b/src/codec/framed_write.rs
@@ -23,6 +23,11 @@ pub struct FramedWrite<T, B> {
     /// Upstream `AsyncWrite`
     inner: T,
 
+    encoder: Encoder<B>,
+}
+
+#[derive(Debug)]
+struct Encoder<B> {
     /// HPACK encoder
     hpack: hpack::Encoder,
 
@@ -74,12 +79,14 @@ where
         let is_write_vectored = inner.is_write_vectored();
         FramedWrite {
             inner,
-            hpack: hpack::Encoder::default(),
-            buf: Cursor::new(BytesMut::with_capacity(DEFAULT_BUFFER_CAPACITY)),
-            next: None,
-            last_data_frame: None,
-            max_frame_size: frame::DEFAULT_MAX_FRAME_SIZE,
-            is_write_vectored,
+            encoder: Encoder {
+                hpack: hpack::Encoder::default(),
+                buf: Cursor::new(BytesMut::with_capacity(DEFAULT_BUFFER_CAPACITY)),
+                next: None,
+                last_data_frame: None,
+                max_frame_size: frame::DEFAULT_MAX_FRAME_SIZE,
+                is_write_vectored,
+            },
         }
     }
 
@@ -88,11 +95,11 @@ where
     /// Calling this function may result in the current contents of the buffer
     /// to be flushed to `T`.
     pub fn poll_ready(&mut self, cx: &mut Context) -> Poll<io::Result<()>> {
-        if !self.has_capacity() {
+        if !self.encoder.has_capacity() {
             // Try flushing
             ready!(self.flush(cx))?;
 
-            if !self.has_capacity() {
+            if !self.encoder.has_capacity() {
                 return Poll::Pending;
             }
         }
@@ -105,6 +112,128 @@ where
     /// `poll_ready` must be called first to ensure that a frame may be
     /// accepted.
     pub fn buffer(&mut self, item: Frame<B>) -> Result<(), UserError> {
+        self.encoder.buffer(item)
+    }
+
+    /// Flush buffered data to the wire
+    pub fn flush(&mut self, cx: &mut Context) -> Poll<io::Result<()>> {
+        let span = tracing::trace_span!("FramedWrite::flush");
+        let _e = span.enter();
+
+        loop {
+            while !self.encoder.is_empty() {
+                match self.encoder.next {
+                    Some(Next::Data(ref mut frame)) => {
+                        tracing::trace!(queued_data_frame = true);
+                        let mut buf = (&mut self.encoder.buf).chain(frame.payload_mut());
+                        ready!(write(
+                            &mut self.inner,
+                            self.encoder.is_write_vectored,
+                            &mut buf,
+                            cx,
+                        ))?
+                    }
+                    _ => {
+                        tracing::trace!(queued_data_frame = false);
+                        ready!(write(
+                            &mut self.inner,
+                            self.encoder.is_write_vectored,
+                            &mut self.encoder.buf,
+                            cx,
+                        ))?
+                    }
+                }
+            }
+
+            match self.encoder.unset_frame() {
+                ControlFlow::Continue => (),
+                ControlFlow::Break => break,
+            }
+        }
+
+        tracing::trace!("flushing buffer");
+        // Flush the upstream
+        ready!(Pin::new(&mut self.inner).poll_flush(cx))?;
+
+        Poll::Ready(Ok(()))
+    }
+
+    /// Close the codec
+    pub fn shutdown(&mut self, cx: &mut Context) -> Poll<io::Result<()>> {
+        ready!(self.flush(cx))?;
+        Pin::new(&mut self.inner).poll_shutdown(cx)
+    }
+}
+
+fn write<T, B>(
+    writer: &mut T,
+    is_write_vectored: bool,
+    buf: &mut B,
+    cx: &mut Context<'_>,
+) -> Poll<io::Result<()>>
+where
+    T: AsyncWrite + Unpin,
+    B: Buf,
+{
+    // TODO(eliza): when tokio-util 0.5.1 is released, this
+    // could just use `poll_write_buf`...
+    const MAX_IOVS: usize = 64;
+    let n = if is_write_vectored {
+        let mut bufs = [IoSlice::new(&[]); MAX_IOVS];
+        let cnt = buf.chunks_vectored(&mut bufs);
+        ready!(Pin::new(writer).poll_write_vectored(cx, &bufs[..cnt]))?
+    } else {
+        ready!(Pin::new(writer).poll_write(cx, buf.chunk()))?
+    };
+    buf.advance(n);
+    Ok(()).into()
+}
+
+#[must_use]
+enum ControlFlow {
+    Continue,
+    Break,
+}
+
+impl<B> Encoder<B>
+where
+    B: Buf,
+{
+    fn unset_frame(&mut self) -> ControlFlow {
+        // Clear internal buffer
+        self.buf.set_position(0);
+        self.buf.get_mut().clear();
+
+        // The data frame has been written, so unset it
+        match self.next.take() {
+            Some(Next::Data(frame)) => {
+                self.last_data_frame = Some(frame);
+                debug_assert!(self.is_empty());
+                ControlFlow::Break
+            }
+            Some(Next::Continuation(frame)) => {
+                // Buffer the continuation frame, then try to write again
+                let mut buf = limited_write_buf!(self);
+                if let Some(continuation) = frame.encode(&mut self.hpack, &mut buf) {
+                    // We previously had a CONTINUATION, and after encoding
+                    // it, we got *another* one? Let's just double check
+                    // that at least some progress is being made...
+                    if self.buf.get_ref().len() == frame::HEADER_LEN {
+                        // If *only* the CONTINUATION frame header was
+                        // written, and *no* header fields, we're stuck
+                        // in a loop...
+                        panic!("CONTINUATION frame write loop; header value too big to encode");
+                    }
+
+                    self.next = Some(Next::Continuation(continuation));
+                }
+                ControlFlow::Continue
+            }
+            None => ControlFlow::Break,
+        }
+    }
+
+    fn buffer(&mut self, item: Frame<B>) -> Result<(), UserError> {
         // Ensure that we have enough capacity to accept the write.
         assert!(self.has_capacity());
         let span = tracing::trace_span!("FramedWrite::buffer", frame = ?item);
@@ -185,93 +314,6 @@ where
         Ok(())
     }
 
-    /// Flush buffered data to the wire
-    pub fn flush(&mut self, cx: &mut Context) -> Poll<io::Result<()>> {
-        const MAX_IOVS: usize = 64;
-
-        let span = tracing::trace_span!("FramedWrite::flush");
-        let _e = span.enter();
-
-        loop {
-            while !self.is_empty() {
-                match self.next {
-                    Some(Next::Data(ref mut frame)) => {
-                        tracing::trace!(queued_data_frame = true);
-                        let mut buf = (&mut self.buf).chain(frame.payload_mut());
-                        // TODO(eliza): when tokio-util 0.5.1 is released, this
-                        // could just use `poll_write_buf`...
-                        let n = if self.is_write_vectored {
-                            let mut bufs = [IoSlice::new(&[]); MAX_IOVS];
-                            let cnt = buf.chunks_vectored(&mut bufs);
-                            ready!(Pin::new(&mut self.inner).poll_write_vectored(cx, &bufs[..cnt]))?
-                        } else {
-                            ready!(Pin::new(&mut self.inner).poll_write(cx, buf.chunk()))?
-                        };
-                        buf.advance(n);
-                    }
-                    _ => {
-                        tracing::trace!(queued_data_frame = false);
-                        let n = if self.is_write_vectored {
-                            let mut iovs = [IoSlice::new(&[]); MAX_IOVS];
-                            let cnt = self.buf.chunks_vectored(&mut iovs);
-                            ready!(
-                                Pin::new(&mut self.inner).poll_write_vectored(cx, &mut iovs[..cnt])
-                            )?
-                        } else {
-                            ready!(Pin::new(&mut self.inner).poll_write(cx, &mut self.buf.chunk()))?
-                        };
-                        self.buf.advance(n);
-                    }
-                }
-            }
-
-            // Clear internal buffer
-            self.buf.set_position(0);
-            self.buf.get_mut().clear();
-
-            // The data frame has been written, so unset it
-            match self.next.take() {
-                Some(Next::Data(frame)) => {
-                    self.last_data_frame = Some(frame);
-                    debug_assert!(self.is_empty());
-                    break;
-                }
-                Some(Next::Continuation(frame)) => {
-                    // Buffer the continuation frame, then try to write again
-                    let mut buf = limited_write_buf!(self);
-                    if let Some(continuation) = frame.encode(&mut self.hpack, &mut buf) {
-                        // We previously had a CONTINUATION, and after encoding
-                        // it, we got *another* one? Let's just double check
-                        // that at least some progress is being made...
-                        if self.buf.get_ref().len() == frame::HEADER_LEN {
-                            // If *only* the CONTINUATION frame header was
-                            // written, and *no* header fields, we're stuck
-                            // in a loop...
-                            panic!("CONTINUATION frame write loop; header value too big to encode");
-                        }
-
-                        self.next = Some(Next::Continuation(continuation));
-                    }
-                }
-                None => {
-                    break;
-                }
-            }
-        }
-
-        tracing::trace!("flushing buffer");
-        // Flush the upstream
-        ready!(Pin::new(&mut self.inner).poll_flush(cx))?;
-
-        Poll::Ready(Ok(()))
-    }
-
-    /// Close the codec
-    pub fn shutdown(&mut self, cx: &mut Context) -> Poll<io::Result<()>> {
-        ready!(self.flush(cx))?;
-        Pin::new(&mut self.inner).poll_shutdown(cx)
-    }
-
     fn has_capacity(&self) -> bool {
         self.next.is_none() && self.buf.get_ref().remaining_mut() >= MIN_BUFFER_CAPACITY
     }
@@ -284,26 +326,32 @@ where
     }
 }
 
+impl<B> Encoder<B> {
+    fn max_frame_size(&self) -> usize {
+        self.max_frame_size as usize
+    }
+}
+
 impl<T, B> FramedWrite<T, B> {
     /// Returns the max frame size that can be sent
     pub fn max_frame_size(&self) -> usize {
-        self.max_frame_size as usize
+        self.encoder.max_frame_size()
     }
 
     /// Set the peer's max frame size.
     pub fn set_max_frame_size(&mut self, val: usize) {
         assert!(val <= frame::MAX_MAX_FRAME_SIZE as usize);
-        self.max_frame_size = val as FrameSize;
+        self.encoder.max_frame_size = val as FrameSize;
     }
 
     /// Set the peer's header table size.
     pub fn set_header_table_size(&mut self, val: usize) {
-        self.hpack.update_max_size(val);
+        self.encoder.hpack.update_max_size(val);
     }
 
     /// Retrieve the last data frame that has been sent
    pub fn take_last_data_frame(&mut self) -> Option<frame::Data<B>> {
-        self.last_data_frame.take()
+        self.encoder.last_data_frame.take()
     }
 
     pub fn get_mut(&mut self) -> &mut T {
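The framed_write.rs change above splits the buffer-dependent state of `FramedWrite<T, B>` into `Encoder<B>`, and funnels all actual I/O through the free `write` function, so only a thin layer of glue still mentions both type parameters. A rough, compilable sketch of the shape (names are illustrative, not h2's):

use std::io::Write;

// Owns the buffer-side state; methods here never mention the transport `T`.
struct Enc<B> {
    queue: Vec<B>,
}

impl<B: AsRef<[u8]>> Enc<B> {
    fn buffered_len(&self) -> usize {
        self.queue.iter().map(|b| b.as_ref().len()).sum()
    }
}

struct Framed<T, B> {
    io: T,
    enc: Enc<B>,
}

impl<T: Write, B: AsRef<[u8]>> Framed<T, B> {
    // The only code instantiated per (T, B): drain the encoder into the transport.
    fn flush_all(&mut self) -> std::io::Result<()> {
        for buf in self.enc.queue.drain(..) {
            self.io.write_all(buf.as_ref())?;
        }
        self.io.flush()
    }
}

fn main() -> std::io::Result<()> {
    let mut f = Framed { io: Vec::new(), enc: Enc { queue: vec![b"hi".to_vec()] } };
    assert_eq!(f.enc.buffered_len(), 2);
    f.flush_all()?;
    assert_eq!(f.io, b"hi".to_vec());
    Ok(())
}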

diff --git a/src/proto/connection.rs b/src/proto/connection.rs
index 887c8f0..d408f7c 100644
--- a/src/proto/connection.rs
+++ b/src/proto/connection.rs
@@ -17,6 +17,19 @@ use tokio::io::{AsyncRead, AsyncWrite};
 /// An H2 connection
 #[derive(Debug)]
 pub(crate) struct Connection<T, P, B: Buf = Bytes>
+where
+    P: Peer,
+{
+    /// Read / write frame values
+    codec: Codec<T, Prioritized<B>>,
+
+    inner: ConnectionInner<P, B>,
+}
+
+// Extracted part of `Connection` which does not depend on `T`. Reduces the amount of duplicated
+// method instantiations.
+#[derive(Debug)]
+struct ConnectionInner<P, B: Buf = Bytes>
 where
     P: Peer,
 {
@@ -29,9 +42,6 @@ where
     /// graceful shutdown.
     error: Option<Reason>,
 
-    /// Read / write frame values
-    codec: Codec<T, Prioritized<B>>,
-
     /// Pending GOAWAY frames to write.
     go_away: GoAway,
 
@@ -51,6 +61,18 @@ where
     _phantom: PhantomData<P>,
 }
 
+struct DynConnection<'a, B: Buf = Bytes> {
+    state: &'a mut State,
+
+    go_away: &'a mut GoAway,
+
+    streams: DynStreams<'a, B>,
+
+    error: &'a mut Option<Reason>,
+
+    ping_pong: &'a mut PingPong,
+}
+
 #[derive(Debug, Clone)]
 pub(crate) struct Config {
     pub next_stream_id: StreamId,
@@ -79,51 +101,56 @@ where
     B: Buf,
 {
     pub fn new(codec: Codec<T, Prioritized<B>>, config: Config) -> Connection<T, P, B> {
-        let streams = Streams::new(streams::Config {
-            local_init_window_sz: config
-                .settings
-                .initial_window_size()
-                .unwrap_or(DEFAULT_INITIAL_WINDOW_SIZE),
-            initial_max_send_streams: config.initial_max_send_streams,
-            local_next_stream_id: config.next_stream_id,
-            local_push_enabled: config.settings.is_push_enabled().unwrap_or(true),
-            local_reset_duration: config.reset_stream_duration,
-            local_reset_max: config.reset_stream_max,
-            remote_init_window_sz: DEFAULT_INITIAL_WINDOW_SIZE,
-            remote_max_initiated: config
-                .settings
-                .max_concurrent_streams()
-                .map(|max| max as usize),
-        });
+        fn streams_config(config: &Config) -> streams::Config {
+            streams::Config {
+                local_init_window_sz: config
+                    .settings
+                    .initial_window_size()
+                    .unwrap_or(DEFAULT_INITIAL_WINDOW_SIZE),
+                initial_max_send_streams: config.initial_max_send_streams,
+                local_next_stream_id: config.next_stream_id,
+                local_push_enabled: config.settings.is_push_enabled().unwrap_or(true),
+                local_reset_duration: config.reset_stream_duration,
+                local_reset_max: config.reset_stream_max,
+                remote_init_window_sz: DEFAULT_INITIAL_WINDOW_SIZE,
+                remote_max_initiated: config
+                    .settings
+                    .max_concurrent_streams()
+                    .map(|max| max as usize),
+            }
+        }
+        let streams = Streams::new(streams_config(&config));
         Connection {
-            state: State::Open,
-            error: None,
             codec,
-            go_away: GoAway::new(),
-            ping_pong: PingPong::new(),
-            settings: Settings::new(config.settings),
-            streams,
-            span: tracing::debug_span!("Connection", peer = %P::NAME),
-            _phantom: PhantomData,
+            inner: ConnectionInner {
+                state: State::Open,
+                error: None,
+                go_away: GoAway::new(),
+                ping_pong: PingPong::new(),
+                settings: Settings::new(config.settings),
+                streams,
+                span: tracing::debug_span!("Connection", peer = %P::NAME),
+                _phantom: PhantomData,
+            },
         }
     }
 
     /// connection flow control
     pub(crate) fn set_target_window_size(&mut self, size: WindowSize) {
-        self.streams.set_target_connection_window_size(size);
+        self.inner.streams.set_target_connection_window_size(size);
     }
 
     /// Send a new SETTINGS frame with an updated initial window size.
     pub(crate) fn set_initial_window_size(&mut self, size: WindowSize) -> Result<(), UserError> {
         let mut settings = frame::Settings::default();
         settings.set_initial_window_size(Some(size));
-        self.settings.send_settings(settings)
+        self.inner.settings.send_settings(settings)
     }
 
     /// Returns the maximum number of concurrent streams that may be initiated
     /// by this peer.
     pub(crate) fn max_send_streams(&self) -> usize {
-        self.streams.max_send_streams()
+        self.inner.streams.max_send_streams()
     }
 
     /// Returns `Ready` when the connection is ready to receive a frame.
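`ConnectionInner` keeps every field that does not depend on the transport `T`, and `DynConnection` goes a step further: it is a short-lived view borrowing exactly the fields a state transition needs, which also sidesteps borrow-checker fights while `self.codec` is borrowed elsewhere. A compilable sketch of the view-struct idea (types simplified, not h2's):

struct Owner<B> {
    state: u8,
    log: Vec<String>,
    payload: B,
}

// Borrows only what the helpers need; no `B` parameter, so its methods
// are compiled once.
struct View<'a> {
    state: &'a mut u8,
    log: &'a mut Vec<String>,
}

impl<B> Owner<B> {
    fn as_view(&mut self) -> View<'_> {
        let Owner { state, log, .. } = self;
        View { state, log }
    }
}

impl View<'_> {
    fn close(&mut self, reason: &str) {
        *self.state = 1;
        self.log.push(format!("closing: {}", reason));
    }
}

fn main() {
    let mut owner = Owner { state: 0, log: Vec::new(), payload: [0u8; 4] };
    owner.as_view().close("shutdown");
    assert_eq!(owner.state, 1);
}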
@@ -131,16 +158,17 @@ where
     /// Returns `RecvError` as this may raise errors that are caused by delayed
     /// processing of received frames.
     fn poll_ready(&mut self, cx: &mut Context) -> Poll<Result<(), RecvError>> {
-        let _e = self.span.enter();
+        let _e = self.inner.span.enter();
         let span = tracing::trace_span!("poll_ready");
         let _e = span.enter();
 
         // The order of these calls don't really matter too much
-        ready!(self.ping_pong.send_pending_pong(cx, &mut self.codec))?;
-        ready!(self.ping_pong.send_pending_ping(cx, &mut self.codec))?;
+        ready!(self.inner.ping_pong.send_pending_pong(cx, &mut self.codec))?;
+        ready!(self.inner.ping_pong.send_pending_ping(cx, &mut self.codec))?;
         ready!(self
+            .inner
             .settings
-            .poll_send(cx, &mut self.codec, &mut self.streams))?;
-        ready!(self.streams.send_pending_refusal(cx, &mut self.codec))?;
+            .poll_send(cx, &mut self.codec, &mut self.inner.streams))?;
+        ready!(self.inner.streams.send_pending_refusal(cx, &mut self.codec))?;
 
         Poll::Ready(Ok(()))
     }
@@ -150,32 +178,15 @@ where
     /// This will return `Some(reason)` if the connection should be closed
     /// afterwards. If this is a graceful shutdown, this returns `None`.
     fn poll_go_away(&mut self, cx: &mut Context) -> Poll<Option<io::Result<Reason>>> {
-        self.go_away.send_pending_go_away(cx, &mut self.codec)
-    }
-
-    fn go_away(&mut self, id: StreamId, e: Reason) {
-        let frame = frame::GoAway::new(id, e);
-        self.streams.send_go_away(id);
-        self.go_away.go_away(frame);
-    }
-
-    fn go_away_now(&mut self, e: Reason) {
-        let last_processed_id = self.streams.last_processed_id();
-        let frame = frame::GoAway::new(last_processed_id, e);
-        self.go_away.go_away_now(frame);
+        self.inner.go_away.send_pending_go_away(cx, &mut self.codec)
     }
 
     pub fn go_away_from_user(&mut self, e: Reason) {
-        let last_processed_id = self.streams.last_processed_id();
-        let frame = frame::GoAway::new(last_processed_id, e);
-        self.go_away.go_away_from_user(frame);
-
-        // Notify all streams of reason we're abruptly closing.
-        self.streams.recv_err(&proto::Error::Proto(e));
+        self.inner.as_dyn().go_away_from_user(e)
     }
 
     fn take_error(&mut self, ours: Reason) -> Poll<Result<(), proto::Error>> {
-        let reason = if let Some(theirs) = self.error.take() {
+        let reason = if let Some(theirs) = self.inner.error.take() {
             match (ours, theirs) {
                 // If either side reported an error, return that
                 // to the user.
@@ -202,13 +213,13 @@ where
     pub fn maybe_close_connection_if_no_streams(&mut self) {
         // If we poll() and realize that there are no streams or references
         // then we can close the connection by transitioning to GOAWAY
-        if !self.streams.has_streams_or_other_references() {
-            self.go_away_now(Reason::NO_ERROR);
+        if !self.inner.streams.has_streams_or_other_references() {
+            self.inner.as_dyn().go_away_now(Reason::NO_ERROR);
         }
     }
 
     pub(crate) fn take_user_pings(&mut self) -> Option<UserPings> {
-        self.ping_pong.take_user_pings()
+        self.inner.ping_pong.take_user_pings()
     }
 
     /// Advances the internal state of the connection.
@@ -217,79 +228,39 @@ where
         // order to placate the borrow checker — `self` is mutably borrowed by
         // `poll2`, which means that we can't borrow `self.span` to enter it.
         // The clone is just an atomic ref bump.
-        let span = self.span.clone();
+        let span = self.inner.span.clone();
         let _e = span.enter();
         let span = tracing::trace_span!("poll");
         let _e = span.enter();
-        use crate::codec::RecvError::*;
 
         loop {
-            tracing::trace!(connection.state = ?self.state);
+            tracing::trace!(connection.state = ?self.inner.state);
             // TODO: probably clean up this glob of code
-            match self.state {
+            match self.inner.state {
                 // When open, continue to poll a frame
                 State::Open => {
-                    match self.poll2(cx) {
-                        // The connection has shutdown normally
-                        Poll::Ready(Ok(())) => self.state = State::Closing(Reason::NO_ERROR),
+                    let result = match self.poll2(cx) {
+                        Poll::Ready(result) => result,
                         // The connection is not ready to make progress
                         Poll::Pending => {
                             // Ensure all window updates have been sent.
                             //
                             // This will also handle flushing `self.codec`
-                            ready!(self.streams.poll_complete(cx, &mut self.codec))?;
+                            ready!(self.inner.streams.poll_complete(cx, &mut self.codec))?;
 
-                            if (self.error.is_some() || self.go_away.should_close_on_idle())
-                                && !self.streams.has_streams()
+                            if (self.inner.error.is_some()
+                                || self.inner.go_away.should_close_on_idle())
+                                && !self.inner.streams.has_streams()
                             {
-                                self.go_away_now(Reason::NO_ERROR);
+                                self.inner.as_dyn().go_away_now(Reason::NO_ERROR);
                                 continue;
                             }
 
                             return Poll::Pending;
                         }
-                        // Attempting to read a frame resulted in a connection level
-                        // error. This is handled by setting a GOAWAY frame followed by
-                        // terminating the connection.
-                        Poll::Ready(Err(Connection(e))) => {
-                            tracing::debug!(error = ?e, "Connection::poll; connection error");
+                    };
 
-                            // We may have already sent a GOAWAY for this error,
-                            // if so, don't send another, just flush and close up.
-                            if let Some(reason) = self.go_away.going_away_reason() {
-                                if reason == e {
-                                    tracing::trace!(" -> already going away");
-                                    self.state = State::Closing(e);
-                                    continue;
-                                }
-                            }
-
-                            // Reset all active streams
-                            self.streams.recv_err(&e.into());
-                            self.go_away_now(e);
-                        }
-                        // Attempting to read a frame resulted in a stream level error.
-                        // This is handled by resetting the frame then trying to read
-                        // another frame.
-                        Poll::Ready(Err(Stream { id, reason })) => {
-                            tracing::trace!(?id, ?reason, "stream error");
-                            self.streams.send_reset(id, reason);
-                        }
-                        // Attempting to read a frame resulted in an I/O error. All
-                        // active streams must be reset.
-                        //
-                        // TODO: Are I/O errors recoverable?
-                        Poll::Ready(Err(Io(e))) => {
-                            tracing::debug!(error = ?e, "Connection::poll; IO error");
-                            let e = e.into();
-
-                            // Reset all active streams
-                            self.streams.recv_err(&e);
-
-                            // Return the error
-                            return Poll::Ready(Err(e));
-                        }
-                    }
+                    self.inner.as_dyn().handle_poll2_result(result)?
                 }
                 State::Closing(reason) => {
                     tracing::trace!("connection closing after flush");
@@ -297,7 +268,7 @@ where
                     ready!(self.codec.shutdown(cx))?;
 
                     // Transition the state to error
-                    self.state = State::Closed(reason);
+                    self.inner.state = State::Closed(reason);
                 }
                 State::Closed(reason) => return self.take_error(reason),
             }
@@ -305,8 +276,6 @@ where
     }
 
     fn poll2(&mut self, cx: &mut Context) -> Poll<Result<(), RecvError>> {
-        use crate::frame::Frame::*;
-
         // This happens outside of the loop to prevent needing to do a clock
         // check and then comparison of the queue possibly multiple times a
         // second (and thus, the clock wouldn't have changed enough to matter).
@@ -319,8 +288,8 @@ where
         //   - poll_go_away may buffer a graceful shutdown GOAWAY frame
         //   - If it has, we've also added a PING to be sent in poll_ready
         if let Some(reason) = ready!(self.poll_go_away(cx)?) {
-            if self.go_away.should_close_now() {
-                if self.go_away.is_user_initiated() {
+            if self.inner.go_away.should_close_now() {
+                if self.inner.go_away.is_user_initiated() {
                     // A user initiated abrupt shutdown shouldn't return
                     // the same error back to the user.
                     return Poll::Ready(Ok(()));
@@ -337,61 +306,20 @@ where
         }
         ready!(self.poll_ready(cx))?;
 
-        match ready!(Pin::new(&mut self.codec).poll_next(cx)?) {
-            Some(Headers(frame)) => {
-                tracing::trace!(?frame, "recv HEADERS");
-                self.streams.recv_headers(frame)?;
+        match self
+            .inner
+            .as_dyn()
+            .recv_frame(ready!(Pin::new(&mut self.codec).poll_next(cx)?))?
+        {
+            ReceivedFrame::Settings(frame) => {
+                self.inner.settings.recv_settings(
+                    frame,
+                    &mut self.codec,
+                    &mut self.inner.streams,
+                )?;
             }
-            Some(Data(frame)) => {
-                tracing::trace!(?frame, "recv DATA");
-                self.streams.recv_data(frame)?;
-            }
-            Some(Reset(frame)) => {
-                tracing::trace!(?frame, "recv RST_STREAM");
-                self.streams.recv_reset(frame)?;
-            }
-            Some(PushPromise(frame)) => {
-                tracing::trace!(?frame, "recv PUSH_PROMISE");
-                self.streams.recv_push_promise(frame)?;
-            }
-            Some(Settings(frame)) => {
-                tracing::trace!(?frame, "recv SETTINGS");
-                self.settings
-                    .recv_settings(frame, &mut self.codec, &mut self.streams)?;
-            }
-            Some(GoAway(frame)) => {
-                tracing::trace!(?frame, "recv GOAWAY");
-                // This should prevent starting new streams,
-                // but should allow continuing to process current streams
-                // until they are all EOS. Once they are, State should
-                // transition to GoAway.
-                self.streams.recv_go_away(&frame)?;
-                self.error = Some(frame.reason());
-            }
-            Some(Ping(frame)) => {
-                tracing::trace!(?frame, "recv PING");
-                let status = self.ping_pong.recv_ping(frame);
-                if status.is_shutdown() {
-                    assert!(
-                        self.go_away.is_going_away(),
-                        "received unexpected shutdown ping"
-                    );
-
-                    let last_processed_id = self.streams.last_processed_id();
-                    self.go_away(last_processed_id, Reason::NO_ERROR);
-                }
-            }
-            Some(WindowUpdate(frame)) => {
-                tracing::trace!(?frame, "recv WINDOW_UPDATE");
-                self.streams.recv_window_update(frame)?;
-            }
-            Some(Priority(frame)) => {
-                tracing::trace!(?frame, "recv PRIORITY");
-                // TODO: handle
-            }
-            None => {
-                tracing::trace!("codec closed");
-                self.streams.recv_eof(false).expect("mutex poisoned");
+            ReceivedFrame::Continue => (),
+            ReceivedFrame::Done => {
                 return Poll::Ready(Ok(()));
             }
         }
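The `poll` loop above now only distinguishes `Poll::Pending` from `Poll::Ready`; the long match over the error variants moves into `DynConnection::handle_poll2_result` in the next hunk, so it is emitted once instead of once per transport type. A reduced sketch of that split (the error and state types are stand-ins, not h2's):

use std::task::Poll;

#[derive(Debug, PartialEq)]
enum State {
    Open,
    Closing(&'static str),
}

// Compiled once: all the error-classification logic lives here.
fn handle_result(state: &mut State, result: Result<(), &'static str>) -> Result<(), &'static str> {
    match result {
        Ok(()) => {
            *state = State::Closing("NO_ERROR");
            Ok(())
        }
        Err(e) => {
            *state = State::Closing(e);
            Ok(())
        }
    }
}

// Generic driver: only the Poll plumbing is duplicated per caller type.
fn drive<F: FnMut() -> Poll<Result<(), &'static str>>>(mut poll2: F) -> State {
    let mut state = State::Open;
    if let Poll::Ready(result) = poll2() {
        handle_result(&mut state, result).unwrap();
    }
    state
}

fn main() {
    assert_eq!(drive(|| Poll::Ready(Err("GO_AWAY"))), State::Closing("GO_AWAY"));
}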
@@ -399,17 +327,190 @@ where
     }
 
     fn clear_expired_reset_streams(&mut self) {
-        self.streams.clear_expired_reset_streams();
+        self.inner.streams.clear_expired_reset_streams();
     }
 }
 
+impl<P, B> ConnectionInner<P, B>
+where
+    P: Peer,
+    B: Buf,
+{
+    fn as_dyn(&mut self) -> DynConnection<'_, B> {
+        let ConnectionInner {
+            state,
+            go_away,
+            streams,
+            error,
+            ping_pong,
+            ..
+        } = self;
+        let streams = streams.as_dyn();
+        DynConnection {
+            state,
+            go_away,
+            streams,
+            error,
+            ping_pong,
+        }
+    }
+}
+
+impl<B> DynConnection<'_, B>
+where
+    B: Buf,
+{
+    fn go_away(&mut self, id: StreamId, e: Reason) {
+        let frame = frame::GoAway::new(id, e);
+        self.streams.send_go_away(id);
+        self.go_away.go_away(frame);
+    }
+
+    fn go_away_now(&mut self, e: Reason) {
+        let last_processed_id = self.streams.last_processed_id();
+        let frame = frame::GoAway::new(last_processed_id, e);
+        self.go_away.go_away_now(frame);
+    }
+
+    fn go_away_from_user(&mut self, e: Reason) {
+        let last_processed_id = self.streams.last_processed_id();
+        let frame = frame::GoAway::new(last_processed_id, e);
+        self.go_away.go_away_from_user(frame);
+
+        // Notify all streams of reason we're abruptly closing.
+        self.streams.recv_err(&proto::Error::Proto(e));
+    }
+
+    fn handle_poll2_result(&mut self, result: Result<(), RecvError>) -> Result<(), Error> {
+        use crate::codec::RecvError::*;
+        match result {
+            // The connection has shutdown normally
+            Ok(()) => {
+                *self.state = State::Closing(Reason::NO_ERROR);
+                Ok(())
+            }
+            // Attempting to read a frame resulted in a connection level
+            // error. This is handled by setting a GOAWAY frame followed by
+            // terminating the connection.
+            Err(Connection(e)) => {
+                tracing::debug!(error = ?e, "Connection::poll; connection error");
+
+                // We may have already sent a GOAWAY for this error,
+                // if so, don't send another, just flush and close up.
+                if let Some(reason) = self.go_away.going_away_reason() {
+                    if reason == e {
+                        tracing::trace!(" -> already going away");
+                        *self.state = State::Closing(e);
+                        return Ok(());
+                    }
+                }
+
+                // Reset all active streams
+                self.streams.recv_err(&e.into());
+                self.go_away_now(e);
+                Ok(())
+            }
+            // Attempting to read a frame resulted in a stream level error.
+            // This is handled by resetting the frame then trying to read
+            // another frame.
+            Err(Stream { id, reason }) => {
+                tracing::trace!(?id, ?reason, "stream error");
+                self.streams.send_reset(id, reason);
+                Ok(())
+            }
+            // Attempting to read a frame resulted in an I/O error. All
+            // active streams must be reset.
+            //
+            // TODO: Are I/O errors recoverable?
+            Err(Io(e)) => {
+                tracing::debug!(error = ?e, "Connection::poll; IO error");
+                let e = e.into();
+
+                // Reset all active streams
+                self.streams.recv_err(&e);
+
+                // Return the error
+                Err(e)
+            }
+        }
+    }
+
+    fn recv_frame(&mut self, frame: Option<Frame>) -> Result<ReceivedFrame, RecvError> {
+        use crate::frame::Frame::*;
+        match frame {
+            Some(Headers(frame)) => {
+                tracing::trace!(?frame, "recv HEADERS");
+                self.streams.recv_headers(frame)?;
+            }
+            Some(Data(frame)) => {
+                tracing::trace!(?frame, "recv DATA");
+                self.streams.recv_data(frame)?;
+            }
+            Some(Reset(frame)) => {
+                tracing::trace!(?frame, "recv RST_STREAM");
+                self.streams.recv_reset(frame)?;
+            }
+            Some(PushPromise(frame)) => {
+                tracing::trace!(?frame, "recv PUSH_PROMISE");
+                self.streams.recv_push_promise(frame)?;
+            }
+            Some(Settings(frame)) => {
+                tracing::trace!(?frame, "recv SETTINGS");
+                return Ok(ReceivedFrame::Settings(frame));
+            }
+            Some(GoAway(frame)) => {
+                tracing::trace!(?frame, "recv GOAWAY");
+                // This should prevent starting new streams,
+                // but should allow continuing to process current streams
+                // until they are all EOS. Once they are, State should
+                // transition to GoAway.
+                self.streams.recv_go_away(&frame)?;
+                *self.error = Some(frame.reason());
+            }
+            Some(Ping(frame)) => {
+                tracing::trace!(?frame, "recv PING");
+                let status = self.ping_pong.recv_ping(frame);
+                if status.is_shutdown() {
+                    assert!(
+                        self.go_away.is_going_away(),
+                        "received unexpected shutdown ping"
+                    );
+
+                    let last_processed_id = self.streams.last_processed_id();
+                    self.go_away(last_processed_id, Reason::NO_ERROR);
+                }
+            }
+            Some(WindowUpdate(frame)) => {
+                tracing::trace!(?frame, "recv WINDOW_UPDATE");
+                self.streams.recv_window_update(frame)?;
+            }
+            Some(Priority(frame)) => {
+                tracing::trace!(?frame, "recv PRIORITY");
+                // TODO: handle
+            }
+            None => {
+                tracing::trace!("codec closed");
+                self.streams.recv_eof(false).expect("mutex poisoned");
+                return Ok(ReceivedFrame::Done);
+            }
+        }
+        Ok(ReceivedFrame::Continue)
+    }
+}
+
+enum ReceivedFrame {
+    Settings(frame::Settings),
+    Continue,
+    Done,
+}
+
 impl<T, B> Connection<T, client::Peer, B>
 where
     T: AsyncRead + AsyncWrite,
     B: Buf,
 {
     pub(crate) fn streams(&self) -> &Streams<Prioritized<B>, client::Peer> {
-        &self.streams
+        &self.inner.streams
     }
 }
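`recv_frame` above handles every frame it can without touching the codec, and returns `ReceivedFrame::Settings` for the one case that must write an ACK on the transport, handing it back to the generic caller. A sketch of that "return what you can't handle" shape, with made-up frame types:

use std::io::Write;

enum Frame {
    Ping,
    Settings(u32),
}

enum Received {
    Settings(u32),
    Continue,
}

// Compiled once: no transport type parameter in sight.
fn recv_frame(frame: Frame) -> Received {
    match frame {
        Frame::Ping => Received::Continue,
        Frame::Settings(v) => Received::Settings(v),
    }
}

// Generic over the transport; only this thin tail is duplicated.
fn poll_frame<W: Write>(io: &mut W, frame: Frame) -> std::io::Result<()> {
    match recv_frame(frame) {
        Received::Settings(v) => writeln!(io, "ack SETTINGS({})", v),
        Received::Continue => Ok(()),
    }
}

fn main() -> std::io::Result<()> {
    let mut out = Vec::new();
    poll_frame(&mut out, Frame::Settings(7))?;
    poll_frame(&mut out, Frame::Ping)
}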
@@ -419,12 +520,12 @@ where
     B: Buf,
 {
     pub fn next_incoming(&mut self) -> Option<StreamRef<B>> {
-        self.streams.next_incoming()
+        self.inner.streams.next_incoming()
     }
 
     // Graceful shutdown only makes sense for server peers.
     pub fn go_away_gracefully(&mut self) {
-        if self.go_away.is_going_away() {
+        if self.inner.go_away.is_going_away() {
             // No reason to start a new one.
             return;
         }
@@ -440,11 +541,11 @@ where
         // > send another GOAWAY frame with an updated last stream identifier.
         // > This ensures that a connection can be cleanly shut down without
         // > losing requests.
-        self.go_away(StreamId::MAX, Reason::NO_ERROR);
+        self.inner.as_dyn().go_away(StreamId::MAX, Reason::NO_ERROR);
 
         // We take the advice of waiting 1 RTT literally, and wait
         // for a pong before proceeding.
-        self.ping_pong.ping_shutdown();
+        self.inner.ping_pong.ping_shutdown();
     }
 }
 
@@ -455,6 +556,6 @@ where
 {
     fn drop(&mut self) {
         // Ignore errors as this indicates that the mutex is poisoned.
-        let _ = self.streams.recv_eof(true);
+        let _ = self.inner.streams.recv_eof(true);
     }
 }
diff --git a/src/proto/mod.rs b/src/proto/mod.rs
index f9e068b..84fd854 100644
--- a/src/proto/mod.rs
+++ b/src/proto/mod.rs
@@ -10,7 +10,7 @@ pub(crate) use self::connection::{Config, Connection};
 pub(crate) use self::error::Error;
 pub(crate) use self::peer::{Dyn as DynPeer, Peer};
 pub(crate) use self::ping_pong::UserPings;
-pub(crate) use self::streams::{OpaqueStreamRef, StreamRef, Streams};
+pub(crate) use self::streams::{DynStreams, OpaqueStreamRef, StreamRef, Streams};
 pub(crate) use self::streams::{Open, PollReset, Prioritized};
 
 use crate::codec::Codec;
diff --git a/src/proto/streams/mod.rs b/src/proto/streams/mod.rs
index 508d9a1..608395c 100644
--- a/src/proto/streams/mod.rs
+++ b/src/proto/streams/mod.rs
@@ -12,7 +12,7 @@ mod streams;
 pub(crate) use self::prioritize::Prioritized;
 pub(crate) use self::recv::Open;
 pub(crate) use self::send::PollReset;
-pub(crate) use self::streams::{OpaqueStreamRef, StreamRef, Streams};
+pub(crate) use self::streams::{DynStreams, OpaqueStreamRef, StreamRef, Streams};
 
 use self::buffer::Buffer;
 use self::counts::Counts;
diff --git a/src/proto/streams/prioritize.rs b/src/proto/streams/prioritize.rs
index b7b616f..701b8f4 100644
--- a/src/proto/streams/prioritize.rs
+++ b/src/proto/streams/prioritize.rs
@@ -545,43 +545,57 @@ impl Prioritize {
 
         // First check if there are any data chunks to take back
         if let Some(frame) = dst.take_last_data_frame() {
-            tracing::trace!(
-                ?frame,
-                sz = frame.payload().inner.get_ref().remaining(),
-                "reclaimed"
-            );
+            self.reclaim_frame_inner(buffer, store, frame)
+        } else {
+            false
+        }
+    }
 
-            let mut eos = false;
-            let key = frame.payload().stream;
+    fn reclaim_frame_inner<B>(
+        &mut self,
+        buffer: &mut Buffer<Frame<B>>,
+        store: &mut Store,
+        frame: frame::Data<Prioritized<B>>,
+    ) -> bool
+    where
+        B: Buf,
+    {
+        tracing::trace!(
+            ?frame,
+            sz = frame.payload().inner.get_ref().remaining(),
+            "reclaimed"
+        );
 
-            match mem::replace(&mut self.in_flight_data_frame, InFlightData::Nothing) {
-                InFlightData::Nothing => panic!("wasn't expecting a frame to reclaim"),
-                InFlightData::Drop => {
-                    tracing::trace!("not reclaiming frame for cancelled stream");
-                    return false;
-                }
-                InFlightData::DataFrame(k) => {
-                    debug_assert_eq!(k, key);
-                }
+        let mut eos = false;
+        let key = frame.payload().stream;
+
+        match mem::replace(&mut self.in_flight_data_frame, InFlightData::Nothing) {
+            InFlightData::Nothing => panic!("wasn't expecting a frame to reclaim"),
+            InFlightData::Drop => {
+                tracing::trace!("not reclaiming frame for cancelled stream");
+                return false;
+            }
+            InFlightData::DataFrame(k) => {
+                debug_assert_eq!(k, key);
+            }
+        }
+
+        let mut frame = frame.map(|prioritized| {
+            // TODO: Ensure fully written
+            eos = prioritized.end_of_stream;
+            prioritized.inner.into_inner()
+        });
+
+        if frame.payload().has_remaining() {
+            let mut stream = store.resolve(key);
+
+            if eos {
+                frame.set_end_stream(true);
             }
 
-            let mut frame = frame.map(|prioritized| {
-                // TODO: Ensure fully written
-                eos = prioritized.end_of_stream;
-                prioritized.inner.into_inner()
-            });
+            self.push_back_frame(frame.into(), buffer, &mut stream);
 
-            if frame.payload().has_remaining() {
-                let mut stream = store.resolve(key);
-
-                if eos {
-                    frame.set_end_stream(true);
-                }
-
-                self.push_back_frame(frame.into(), buffer, &mut stream);
-
-                return true;
-            }
+            return true;
         }
 
         false
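In prioritize.rs the original `reclaim_frame` was generic over the transport `T` (through `dst`) as well as the buffer `B`, yet only the `take_last_data_frame` call needs `T`; `reclaim_frame_inner` carries the rest and is instantiated per `B` alone. A toy version of that outlining, with stand-in types:

struct Dst<T> {
    _io: T,
    last: Option<Vec<u8>>,
}

struct Queue {
    frames: Vec<Vec<u8>>,
}

impl Queue {
    // The bulk of the logic: never duplicated per transport type.
    fn reclaim_inner(&mut self, frame: Vec<u8>) -> bool {
        if frame.is_empty() {
            return false;
        }
        self.frames.push(frame);
        true
    }

    // Thin shim, the only part mentioning `T`.
    fn reclaim<T>(&mut self, dst: &mut Dst<T>) -> bool {
        if let Some(frame) = dst.last.take() {
            self.reclaim_inner(frame)
        } else {
            false
        }
    }
}

fn main() {
    let mut q = Queue { frames: Vec::new() };
    let mut dst = Dst { _io: (), last: Some(b"data".to_vec()) };
    assert!(q.reclaim(&mut dst));
    assert_eq!(q.frames.len(), 1);
}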
diff --git a/src/proto/streams/streams.rs b/src/proto/streams/streams.rs
index 7e9b403..7ba87eb 100644
--- a/src/proto/streams/streams.rs
+++ b/src/proto/streams/streams.rs
@@ -37,6 +37,17 @@ where
     _p: ::std::marker::PhantomData<P>,
 }
 
+// Like `Streams` but with a `peer::Dyn` field instead of a static `P: Peer` type parameter.
+// Ensures that the methods only get one instantiation, instead of two (client and server)
+#[derive(Debug)]
+pub(crate) struct DynStreams<'a, B> {
+    inner: &'a Mutex<Inner>,
+
+    send_buffer: &'a SendBuffer<B>,
+
+    peer: peer::Dyn,
+}
+
 /// Reference to the stream state
 #[derive(Debug)]
 pub(crate) struct StreamRef<B> {
@@ -101,17 +112,7 @@ where
         let peer = P::r#dyn();
 
         Streams {
-            inner: Arc::new(Mutex::new(Inner {
-                counts: Counts::new(peer, &config),
-                actions: Actions {
-                    recv: Recv::new(peer, &config),
-                    send: Send::new(&config),
-                    task: None,
-                    conn_error: None,
-                },
-                store: Store::new(),
-                refs: 1,
-            })),
+            inner: Inner::new(peer, config),
             send_buffer: Arc::new(SendBuffer::new()),
             _p: ::std::marker::PhantomData,
         }
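`DynStreams` replaces the compile-time `P: Peer` parameter with the runtime `peer::Dyn` value, so each method body is compiled once and serves both the client and the server instantiation. A self-contained sketch of the erasure (the `Peer`/`Dyn` shapes mirror h2's `proto::peer` module but are simplified here):

use std::marker::PhantomData;

#[derive(Clone, Copy, PartialEq)]
enum DynPeer {
    Client,
    Server,
}

trait Peer {
    fn r#dyn() -> DynPeer;
}

struct Server;

impl Peer for Server {
    fn r#dyn() -> DynPeer {
        DynPeer::Server
    }
}

// Statically typed wrapper, instantiated per `P`...
struct Streams<P> {
    _p: PhantomData<P>,
}

impl<P: Peer> Streams<P> {
    // ...which immediately erases `P` into a value.
    fn as_dyn(&self) -> DynStreams {
        DynStreams { peer: P::r#dyn() }
    }
}

// Methods on this type exist exactly once in the binary.
struct DynStreams {
    peer: DynPeer,
}

impl DynStreams {
    fn is_server(&self) -> bool {
        self.peer == DynPeer::Server
    }
}

fn main() {
    let s: Streams<Server> = Streams { _p: PhantomData };
    assert!(s.as_dyn().is_server());
}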

@@ -126,434 +127,6 @@ where
             .set_target_connection_window(size, &mut me.actions.task)
     }
 
-    /// Process inbound headers
-    pub fn recv_headers(&mut self, frame: frame::Headers) -> Result<(), RecvError> {
-        let id = frame.stream_id();
-        let mut me = self.inner.lock().unwrap();
-        let me = &mut *me;
-
-        // The GOAWAY process has begun. All streams with a greater ID than
-        // specified as part of GOAWAY should be ignored.
-        if id > me.actions.recv.max_stream_id() {
-            tracing::trace!(
-                "id ({:?}) > max_stream_id ({:?}), ignoring HEADERS",
-                id,
-                me.actions.recv.max_stream_id()
-            );
-            return Ok(());
-        }
-
-        let key = match me.store.find_entry(id) {
-            Entry::Occupied(e) => e.key(),
-            Entry::Vacant(e) => {
-                // Client: it's possible to send a request, and then send
-                // a RST_STREAM while the response HEADERS were in transit.
-                //
-                // Server: we can't reset a stream before having received
-                // the request headers, so don't allow.
-                if !P::is_server() {
-                    // This may be response headers for a stream we've already
-                    // forgotten about...
-                    if me.actions.may_have_forgotten_stream::<P>(id) {
-                        tracing::debug!(
-                            "recv_headers for old stream={:?}, sending STREAM_CLOSED",
-                            id,
-                        );
-                        return Err(RecvError::Stream {
-                            id,
-                            reason: Reason::STREAM_CLOSED,
-                        });
-                    }
-                }
-
-                match me.actions.recv.open(id, Open::Headers, &mut me.counts)? {
-                    Some(stream_id) => {
-                        let stream = Stream::new(
-                            stream_id,
-                            me.actions.send.init_window_sz(),
-                            me.actions.recv.init_window_sz(),
-                        );
-
-                        e.insert(stream)
-                    }
-                    None => return Ok(()),
-                }
-            }
-        };
-
-        let stream = me.store.resolve(key);
-
-        if stream.state.is_local_reset() {
-            // Locally reset streams must ignore frames "for some time".
-            // This is because the remote may have sent trailers before
-            // receiving the RST_STREAM frame.
-            tracing::trace!("recv_headers; ignoring trailers on {:?}", stream.id);
-            return Ok(());
-        }
-
-        let actions = &mut me.actions;
-        let mut send_buffer = self.send_buffer.inner.lock().unwrap();
-        let send_buffer = &mut *send_buffer;
-
-        me.counts.transition(stream, |counts, stream| {
-            tracing::trace!(
-                "recv_headers; stream={:?}; state={:?}",
-                stream.id,
-                stream.state
-            );
-
-            let res = if stream.state.is_recv_headers() {
-                match actions.recv.recv_headers(frame, stream, counts) {
-                    Ok(()) => Ok(()),
-                    Err(RecvHeaderBlockError::Oversize(resp)) => {
-                        if let Some(resp) = resp {
-                            let sent = actions.send.send_headers(
-                                resp, send_buffer, stream, counts, &mut actions.task);
-                            debug_assert!(sent.is_ok(), "oversize response should not fail");
-
-                            actions.send.schedule_implicit_reset(
-                                stream,
-                                Reason::REFUSED_STREAM,
-                                counts,
-                                &mut actions.task);
-
-                            actions.recv.enqueue_reset_expiration(stream, counts);
-
-                            Ok(())
-                        } else {
-                            Err(RecvError::Stream {
-                                id: stream.id,
-                                reason: Reason::REFUSED_STREAM,
-                            })
-                        }
-                    },
-                    Err(RecvHeaderBlockError::State(err)) => Err(err),
-                }
-            } else {
-                if !frame.is_end_stream() {
-                    // Receiving trailers that don't set EOS is a "malformed"
-                    // message. Malformed messages are a stream error.
-                    proto_err!(stream: "recv_headers: trailers frame was not EOS; stream={:?}", stream.id);
-                    return Err(RecvError::Stream {
-                        id: stream.id,
-                        reason: Reason::PROTOCOL_ERROR,
-                    });
-                }
-
-                actions.recv.recv_trailers(frame, stream)
-            };
-
-            actions.reset_on_recv_stream_err(send_buffer, stream, counts, res)
-        })
-    }
-
-    pub fn recv_data(&mut self, frame: frame::Data) -> Result<(), RecvError> {
-        let mut me = self.inner.lock().unwrap();
-        let me = &mut *me;
-
-        let id = frame.stream_id();
-
-        let stream = match me.store.find_mut(&id) {
-            Some(stream) => stream,
-            None => {
-                // The GOAWAY process has begun. All streams with a greater ID
-                // than specified as part of GOAWAY should be ignored.
-                if id > me.actions.recv.max_stream_id() {
-                    tracing::trace!(
-                        "id ({:?}) > max_stream_id ({:?}), ignoring DATA",
-                        id,
-                        me.actions.recv.max_stream_id()
-                    );
-                    return Ok(());
-                }
-
-                if me.actions.may_have_forgotten_stream::<P>(id) {
-                    tracing::debug!("recv_data for old stream={:?}, sending STREAM_CLOSED", id,);
-
-                    let sz = frame.payload().len();
-                    // This should have been enforced at the codec::FramedRead layer, so
-                    // this is just a sanity check.
-                    assert!(sz <= super::MAX_WINDOW_SIZE as usize);
-                    let sz = sz as WindowSize;
-
-                    me.actions.recv.ignore_data(sz)?;
-                    return Err(RecvError::Stream {
-                        id,
-                        reason: Reason::STREAM_CLOSED,
-                    });
-                }
-
-                proto_err!(conn: "recv_data: stream not found; id={:?}", id);
-                return Err(RecvError::Connection(Reason::PROTOCOL_ERROR));
-            }
-        };
-
-        let actions = &mut me.actions;
-        let mut send_buffer = self.send_buffer.inner.lock().unwrap();
-        let send_buffer = &mut *send_buffer;
-
-        me.counts.transition(stream, |counts, stream| {
-            let sz = frame.payload().len();
-            let res = actions.recv.recv_data(frame, stream);
-
-            // Any stream error after receiving a DATA frame means
-            // we won't give the data to the user, and so they can't
-            // release the capacity. We do it automatically.
-            if let Err(RecvError::Stream { .. }) = res {
-                actions
-                    .recv
-                    .release_connection_capacity(sz as WindowSize, &mut None);
-            }
-            actions.reset_on_recv_stream_err(send_buffer, stream, counts, res)
-        })
-    }
-
-    pub fn recv_reset(&mut self, frame: frame::Reset) -> Result<(), RecvError> {
-        let mut me = self.inner.lock().unwrap();
-        let me = &mut *me;
-
-        let id = frame.stream_id();
-
-        if id.is_zero() {
-            proto_err!(conn: "recv_reset: invalid stream ID 0");
-            return Err(RecvError::Connection(Reason::PROTOCOL_ERROR));
-        }
-
-        // The GOAWAY process has begun. All streams with a greater ID than
-        // specified as part of GOAWAY should be ignored.
-        if id > me.actions.recv.max_stream_id() {
-            tracing::trace!(
-                "id ({:?}) > max_stream_id ({:?}), ignoring RST_STREAM",
-                id,
-                me.actions.recv.max_stream_id()
-            );
-            return Ok(());
-        }
-
-        let stream = match me.store.find_mut(&id) {
-            Some(stream) => stream,
-            None => {
-                // TODO: Are there other error cases?
-                me.actions
-                    .ensure_not_idle(me.counts.peer(), id)
-                    .map_err(RecvError::Connection)?;
-
-                return Ok(());
-            }
-        };
-
-        let mut send_buffer = self.send_buffer.inner.lock().unwrap();
-        let send_buffer = &mut *send_buffer;
-
-        let actions = &mut me.actions;
-
-        me.counts.transition(stream, |counts, stream| {
-            actions.recv.recv_reset(frame, stream);
-            actions.send.recv_err(send_buffer, stream, counts);
-            assert!(stream.state.is_closed());
-            Ok(())
-        })
-    }
-
-    /// Handle a received error and return the ID of the last processed stream.
-    pub fn recv_err(&mut self, err: &proto::Error) -> StreamId {
-        let mut me = self.inner.lock().unwrap();
-        let me = &mut *me;
-
-        let actions = &mut me.actions;
-        let counts = &mut me.counts;
-        let mut send_buffer = self.send_buffer.inner.lock().unwrap();
-        let send_buffer = &mut *send_buffer;
-
-        let last_processed_id = actions.recv.last_processed_id();
-
-        me.store
-            .for_each(|stream| {
-                counts.transition(stream, |counts, stream| {
-                    actions.recv.recv_err(err, &mut *stream);
-                    actions.send.recv_err(send_buffer, stream, counts);
-                    Ok::<_, ()>(())
-                })
-            })
-            .unwrap();
-
-        actions.conn_error = Some(err.shallow_clone());
-
-        last_processed_id
-    }
-
-    pub fn recv_go_away(&mut self, frame: &frame::GoAway) -> Result<(), RecvError> {
-        let mut me = self.inner.lock().unwrap();
-        let me = &mut *me;
-
-        let actions = &mut me.actions;
-        let counts = &mut me.counts;
-        let mut send_buffer = self.send_buffer.inner.lock().unwrap();
-        let send_buffer = &mut *send_buffer;
-
-        let last_stream_id = frame.last_stream_id();
-
-        actions.send.recv_go_away(last_stream_id)?;
-
-        let err = frame.reason().into();
-
-        me.store
-            .for_each(|stream| {
-                if stream.id > last_stream_id {
-                    counts.transition(stream, |counts, stream| {
-                        actions.recv.recv_err(&err, &mut *stream);
-                        actions.send.recv_err(send_buffer, stream, counts);
-                        Ok::<_, ()>(())
-                    })
-                } else {
-                    Ok::<_, ()>(())
-                }
-            })
-            .unwrap();
-
-        actions.conn_error = Some(err);
-
-        Ok(())
-    }
-
-    pub fn last_processed_id(&self) -> StreamId {
-        self.inner.lock().unwrap().actions.recv.last_processed_id()
-    }
-
-    pub fn recv_window_update(&mut self, frame: frame::WindowUpdate) -> Result<(), RecvError> {
-        let id = frame.stream_id();
-        let mut me = self.inner.lock().unwrap();
-        let me = &mut *me;
-
-        let mut send_buffer = self.send_buffer.inner.lock().unwrap();
-        let send_buffer = &mut *send_buffer;
-
-        if id.is_zero() {
-            me.actions
-                .send
-                .recv_connection_window_update(frame, &mut me.store, &mut me.counts)
-                .map_err(RecvError::Connection)?;
-        } else {
-            // The remote may send window updates for streams that the local now
-            // considers closed. It's ok...
-            if let Some(mut stream) = me.store.find_mut(&id) {
-                // This result is ignored as there is nothing to do when there
-                // is an error. The stream is reset by the function on error and
-                // the error is informational.
-                let _ = me.actions.send.recv_stream_window_update(
-                    frame.size_increment(),
-                    send_buffer,
-                    &mut stream,
-                    &mut me.counts,
-                    &mut me.actions.task,
-                );
-            } else {
-                me.actions
-                    .ensure_not_idle(me.counts.peer(), id)
-                    .map_err(RecvError::Connection)?;
-            }
-        }
-
-        Ok(())
-    }
-
-    pub fn recv_push_promise(&mut self, frame: frame::PushPromise) -> Result<(), RecvError> {
-        let mut me = self.inner.lock().unwrap();
-        let me = &mut *me;
-
-        let id = frame.stream_id();
-        let promised_id = frame.promised_id();
-
-        // First, ensure that the initiating stream is still in a valid state.
-        let parent_key = match me.store.find_mut(&id) {
-            Some(stream) => {
-                // The GOAWAY process has begun. All streams with a greater ID
-                // than specified as part of GOAWAY should be ignored.
-                if id > me.actions.recv.max_stream_id() {
-                    tracing::trace!(
-                        "id ({:?}) > max_stream_id ({:?}), ignoring PUSH_PROMISE",
-                        id,
-                        me.actions.recv.max_stream_id()
-                    );
-                    return Ok(());
-                }
-
-                // The stream must be receive open
-                stream.state.ensure_recv_open()?;
-                stream.key()
-            }
-            None => {
-                proto_err!(conn: "recv_push_promise: initiating stream is in an invalid state");
-                return Err(RecvError::Connection(Reason::PROTOCOL_ERROR));
-            }
-        };
-
-        // TODO: Streams in the reserved states do not count towards the concurrency
-        // limit. However, it seems like there should be a cap otherwise this
-        // could grow in memory indefinitely.
-
-        // Ensure that we can reserve streams
-        me.actions.recv.ensure_can_reserve()?;
-
-        // Next, open the stream.
-        //
-        // If `None` is returned, then the stream is being refused. There is no
-        // further work to be done.
-        if me
-            .actions
-            .recv
-            .open(promised_id, Open::PushPromise, &mut me.counts)?
-            .is_none()
-        {
-            return Ok(());
-        }
-
-        // Try to handle the frame and create a corresponding key for the pushed stream
-        // this requires a bit of indirection to make the borrow checker happy.
-        let child_key: Option<store::Key> = {
-            // Create state for the stream
-            let stream = me.store.insert(promised_id, {
-                Stream::new(
-                    promised_id,
-                    me.actions.send.init_window_sz(),
-                    me.actions.recv.init_window_sz(),
-                )
-            });
-
-            let actions = &mut me.actions;
-
-            me.counts.transition(stream, |counts, stream| {
-                let stream_valid = actions.recv.recv_push_promise(frame, stream);
-
-                match stream_valid {
-                    Ok(()) => Ok(Some(stream.key())),
-                    _ => {
-                        let mut send_buffer = self.send_buffer.inner.lock().unwrap();
-                        actions
-                            .reset_on_recv_stream_err(
-                                &mut *send_buffer,
-                                stream,
-                                counts,
-                                stream_valid,
-                            )
-                            .map(|()| None)
-                    }
-                }
-            })?
-        };
-        // If we're successful, push the headers and stream...
-        if let Some(child) = child_key {
-            let mut ppp = me.store[parent_key].pending_push_promises.take();
-            ppp.push(&mut me.store.resolve(child));
-
-            let parent = &mut me.store.resolve(parent_key);
-            parent.pending_push_promises = ppp;
-            parent.notify_recv();
-        };
-
-        Ok(())
-    }
-
     pub fn next_incoming(&mut self) -> Option<StreamRef<B>> {
         let mut me = self.inner.lock().unwrap();
         let me = &mut *me;
@@ -604,30 +177,7 @@ where
         T: AsyncWrite + Unpin,
     {
         let mut me = self.inner.lock().unwrap();
-        let me = &mut *me;
-
-        let mut send_buffer = self.send_buffer.inner.lock().unwrap();
-        let send_buffer = &mut *send_buffer;
-
-        // Send WINDOW_UPDATE frames first
-        //
-        // TODO: It would probably be better to interleave updates w/ data
-        // frames.
-        ready!(me
-            .actions
-            .recv
-            .poll_complete(cx, &mut me.store, &mut me.counts, dst))?;
-
-        // Send any other pending frames
-        ready!(me
-            .actions
-            .send
-            .poll_complete(cx, send_buffer, &mut me.store, &mut me.counts, dst))?;
-
-        // Nothing else to do, track the task
-        me.actions.task = Some(cx.waker().clone());
-
-        Poll::Ready(Ok(()))
+        me.poll_complete(&self.send_buffer, cx, dst)
     }
 
     pub fn apply_remote_settings(&mut self, frame: &frame::Settings) -> Result<(), RecvError> {
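Every public method added on `DynStreams` in the next hunk follows the same shape: take the mutex once, then forward to a method on the unsynchronized `Inner` that holds the whole body. That keeps the locking discipline in one obvious place. A minimal sketch of the pattern:

use std::sync::{Arc, Mutex};

struct Inner {
    count: usize,
}

impl Inner {
    // All real logic lives on the lock-free type.
    fn recv_ping(&mut self) -> usize {
        self.count += 1;
        self.count
    }
}

struct Handle {
    inner: Arc<Mutex<Inner>>,
}

impl Handle {
    // The wrapper only acquires the lock and delegates.
    fn recv_ping(&self) -> usize {
        let mut me = self.inner.lock().unwrap();
        me.recv_ping()
    }
}

fn main() {
    let h = Handle { inner: Arc::new(Mutex::new(Inner { count: 0 })) };
    assert_eq!(h.recv_ping(), 1);
    assert_eq!(h.recv_ping(), 2);
}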
@@ -741,12 +291,587 @@ where
             send_buffer: self.send_buffer.clone(),
         })
     }
+}
+
+impl<B> DynStreams<'_, B> {
+    pub fn recv_headers(&mut self, frame: frame::Headers) -> Result<(), RecvError> {
+        let mut me = self.inner.lock().unwrap();
+
+        me.recv_headers(self.peer, &self.send_buffer, frame)
+    }
+
+    pub fn recv_data(&mut self, frame: frame::Data) -> Result<(), RecvError> {
+        let mut me = self.inner.lock().unwrap();
+        me.recv_data(self.peer, &self.send_buffer, frame)
+    }
+
+    pub fn recv_reset(&mut self, frame: frame::Reset) -> Result<(), RecvError> {
+        let mut me = self.inner.lock().unwrap();
+
+        me.recv_reset(&self.send_buffer, frame)
+    }
+
+    /// Handle a received error and return the ID of the last processed stream.
+    pub fn recv_err(&mut self, err: &proto::Error) -> StreamId {
+        let mut me = self.inner.lock().unwrap();
+        me.recv_err(&self.send_buffer, err)
+    }
+
+    pub fn recv_go_away(&mut self, frame: &frame::GoAway) -> Result<(), RecvError> {
+        let mut me = self.inner.lock().unwrap();
+        me.recv_go_away(&self.send_buffer, frame)
+    }
+
+    pub fn last_processed_id(&self) -> StreamId {
+        self.inner.lock().unwrap().actions.recv.last_processed_id()
+    }
+
+    pub fn recv_window_update(&mut self, frame: frame::WindowUpdate) -> Result<(), RecvError> {
+        let mut me = self.inner.lock().unwrap();
+        me.recv_window_update(&self.send_buffer, frame)
+    }
+
+    pub fn recv_push_promise(&mut self, frame: frame::PushPromise) -> Result<(), RecvError> {
+        let mut me = self.inner.lock().unwrap();
+        me.recv_push_promise(&self.send_buffer, frame)
+    }
+
+    pub fn recv_eof(&mut self, clear_pending_accept: bool) -> Result<(), ()> {
+        let mut me = self.inner.lock().map_err(|_| ())?;
+        me.recv_eof(&self.send_buffer, clear_pending_accept)
+    }
 
     pub fn send_reset(&mut self, id: StreamId, reason: Reason) {
         let mut me = self.inner.lock().unwrap();
-        let me = &mut *me;
+        me.send_reset(&self.send_buffer, id, reason)
+    }
 
-        let key = match me.store.find_entry(id) {
+    pub fn send_go_away(&mut self, last_processed_id: StreamId) {
+        let mut me = self.inner.lock().unwrap();
+        me.actions.recv.go_away(last_processed_id);
+    }
+}
+
+impl Inner {
+    fn new(peer: peer::Dyn, config: Config) -> Arc<Mutex<Self>> {
+        Arc::new(Mutex::new(Inner {
+            counts: Counts::new(peer, &config),
+            actions: Actions {
+                recv: Recv::new(peer, &config),
+                send: Send::new(&config),
+                task: None,
+                conn_error: None,
+            },
+            store: Store::new(),
+            refs: 1,
+        }))
+    }
+
+    fn recv_headers<B>(
+        &mut self,
+        peer: peer::Dyn,
+        send_buffer: &SendBuffer<B>,
+        frame: frame::Headers,
+    ) -> Result<(), RecvError> {
+        let id = frame.stream_id();
+
+        // The GOAWAY process has begun. All streams with a greater ID than
+        // specified as part of GOAWAY should be ignored.
+        if id > self.actions.recv.max_stream_id() {
+            tracing::trace!(
+                "id ({:?}) > max_stream_id ({:?}), ignoring HEADERS",
+                id,
+                self.actions.recv.max_stream_id()
+            );
+            return Ok(());
+        }
+
+        let key = match self.store.find_entry(id) {
+            Entry::Occupied(e) => e.key(),
+            Entry::Vacant(e) => {
+                // Client: it's possible to send a request, and then send
+                // a RST_STREAM while the response HEADERS were in transit.
+                //
+                // Server: we can't reset a stream before having received
+                // the request headers, so don't allow.
+                if !peer.is_server() {
+                    // This may be response headers for a stream we've already
+                    // forgotten about...
+                    if self.actions.may_have_forgotten_stream(peer, id) {
+                        tracing::debug!(
+                            "recv_headers for old stream={:?}, sending STREAM_CLOSED",
+                            id,
+                        );
+                        return Err(RecvError::Stream {
+                            id,
+                            reason: Reason::STREAM_CLOSED,
+                        });
+                    }
+                }
+
+                match self
+                    .actions
+                    .recv
+                    .open(id, Open::Headers, &mut self.counts)?
+                {
+                    Some(stream_id) => {
+                        let stream = Stream::new(
+                            stream_id,
+                            self.actions.send.init_window_sz(),
+                            self.actions.recv.init_window_sz(),
+                        );
+
+                        e.insert(stream)
+                    }
+                    None => return Ok(()),
+                }
+            }
+        };
+
+        let stream = self.store.resolve(key);
+
+        if stream.state.is_local_reset() {
+            // Locally reset streams must ignore frames "for some time".
+            // This is because the remote may have sent trailers before
+            // receiving the RST_STREAM frame.
+            tracing::trace!("recv_headers; ignoring trailers on {:?}", stream.id);
+            return Ok(());
+        }
+
+        let actions = &mut self.actions;
+        let mut send_buffer = send_buffer.inner.lock().unwrap();
+        let send_buffer = &mut *send_buffer;
+
+        self.counts.transition(stream, |counts, stream| {
+            tracing::trace!(
+                "recv_headers; stream={:?}; state={:?}",
+                stream.id,
+                stream.state
+            );
+
+            let res = if stream.state.is_recv_headers() {
+                match actions.recv.recv_headers(frame, stream, counts) {
+                    Ok(()) => Ok(()),
+                    Err(RecvHeaderBlockError::Oversize(resp)) => {
+                        if let Some(resp) = resp {
+                            let sent = actions.send.send_headers(
+                                resp, send_buffer, stream, counts, &mut actions.task);
+                            debug_assert!(sent.is_ok(), "oversize response should not fail");
+
+                            actions.send.schedule_implicit_reset(
+                                stream,
+                                Reason::REFUSED_STREAM,
+                                counts,
+                                &mut actions.task);
+
+                            actions.recv.enqueue_reset_expiration(stream, counts);
+
+                            Ok(())
+                        } else {
+                            Err(RecvError::Stream {
+                                id: stream.id,
+                                reason: Reason::REFUSED_STREAM,
+                            })
+                        }
+                    },
+                    Err(RecvHeaderBlockError::State(err)) => Err(err),
+                }
+            } else {
+                if !frame.is_end_stream() {
+                    // Receiving trailers that don't set EOS is a "malformed"
+                    // message. Malformed messages are a stream error.
+                    proto_err!(stream: "recv_headers: trailers frame was not EOS; stream={:?}", stream.id);
+                    return Err(RecvError::Stream {
+                        id: stream.id,
+                        reason: Reason::PROTOCOL_ERROR,
+                    });
+                }
+
+                actions.recv.recv_trailers(frame, stream)
+            };
+
+            actions.reset_on_recv_stream_err(send_buffer, stream, counts, res)
+        })
+    }
+
+    fn recv_data<B>(
+        &mut self,
+        peer: peer::Dyn,
+        send_buffer: &SendBuffer<B>,
+        frame: frame::Data,
+    ) -> Result<(), RecvError> {
+        let id = frame.stream_id();
+
+        let stream = match self.store.find_mut(&id) {
+            Some(stream) => stream,
+            None => {
+                // The GOAWAY process has begun. All streams with a greater ID
+                // than specified as part of GOAWAY should be ignored.
+                if id > self.actions.recv.max_stream_id() {
+                    tracing::trace!(
+                        "id ({:?}) > max_stream_id ({:?}), ignoring DATA",
+                        id,
+                        self.actions.recv.max_stream_id()
+                    );
+                    return Ok(());
+                }
+
+                if self.actions.may_have_forgotten_stream(peer, id) {
+                    tracing::debug!("recv_data for old stream={:?}, sending STREAM_CLOSED", id,);
+
+                    let sz = frame.payload().len();
+                    // This should have been enforced at the codec::FramedRead layer, so
+                    // this is just a sanity check.
+                    assert!(sz <= super::MAX_WINDOW_SIZE as usize);
+                    let sz = sz as WindowSize;
+
+                    self.actions.recv.ignore_data(sz)?;
+                    return Err(RecvError::Stream {
+                        id,
+                        reason: Reason::STREAM_CLOSED,
+                    });
+                }
+
+                proto_err!(conn: "recv_data: stream not found; id={:?}", id);
+                return Err(RecvError::Connection(Reason::PROTOCOL_ERROR));
+            }
+        };
+
+        let actions = &mut self.actions;
+        let mut send_buffer = send_buffer.inner.lock().unwrap();
+        let send_buffer = &mut *send_buffer;
+
+        self.counts.transition(stream, |counts, stream| {
+            let sz = frame.payload().len();
+            let res = actions.recv.recv_data(frame, stream);
+
+            // Any stream error after receiving a DATA frame means
+            // we won't give the data to the user, and so they can't
+            // release the capacity. We do it automatically.
+            if let Err(RecvError::Stream { .. }) = res {
+                actions
+                    .recv
+                    .release_connection_capacity(sz as WindowSize, &mut None);
+            }
+            actions.reset_on_recv_stream_err(send_buffer, stream, counts, res)
+        })
+    }
+
+    fn recv_reset<B>(
+        &mut self,
+        send_buffer: &SendBuffer<B>,
+        frame: frame::Reset,
+    ) -> Result<(), RecvError> {
+        let id = frame.stream_id();
+
+        if id.is_zero() {
+            proto_err!(conn: "recv_reset: invalid stream ID 0");
+            return Err(RecvError::Connection(Reason::PROTOCOL_ERROR));
+        }
+
+        // The GOAWAY process has begun. All streams with a greater ID than
+        // specified as part of GOAWAY should be ignored.
+        if id > self.actions.recv.max_stream_id() {
+            tracing::trace!(
+                "id ({:?}) > max_stream_id ({:?}), ignoring RST_STREAM",
+                id,
+                self.actions.recv.max_stream_id()
+            );
+            return Ok(());
+        }
+
+        let stream = match self.store.find_mut(&id) {
+            Some(stream) => stream,
+            None => {
+                // TODO: Are there other error cases?
+                self.actions
+                    .ensure_not_idle(self.counts.peer(), id)
+                    .map_err(RecvError::Connection)?;
+
+                return Ok(());
+            }
+        };
+
+        let mut send_buffer = send_buffer.inner.lock().unwrap();
+        let send_buffer = &mut *send_buffer;
+
+        let actions = &mut self.actions;
+
+        self.counts.transition(stream, |counts, stream| {
+            actions.recv.recv_reset(frame, stream);
+            actions.send.recv_err(send_buffer, stream, counts);
+            assert!(stream.state.is_closed());
+            Ok(())
+        })
+    }
+
+    fn recv_window_update<B>(
+        &mut self,
+        send_buffer: &SendBuffer<B>,
+        frame: frame::WindowUpdate,
+    ) -> Result<(), RecvError> {
+        let id = frame.stream_id();
+
+        let mut send_buffer = send_buffer.inner.lock().unwrap();
+        let send_buffer = &mut *send_buffer;
+
+        if id.is_zero() {
+            self.actions
+                .send
+                .recv_connection_window_update(frame, &mut self.store, &mut self.counts)
+                .map_err(RecvError::Connection)?;
+        } else {
+            // The remote may send window updates for streams that the local now
+            // considers closed. It's ok...
+            if let Some(mut stream) = self.store.find_mut(&id) {
+                // This result is ignored as there is nothing to do when there
+                // is an error. The stream is reset by the function on error and
+                // the error is informational.
+                let _ = self.actions.send.recv_stream_window_update(
+                    frame.size_increment(),
+                    send_buffer,
+                    &mut stream,
+                    &mut self.counts,
+                    &mut self.actions.task,
+                );
+            } else {
+                self.actions
+                    .ensure_not_idle(self.counts.peer(), id)
+                    .map_err(RecvError::Connection)?;
+            }
+        }
+
+        Ok(())
+    }
+
+    fn recv_err<B>(&mut self, send_buffer: &SendBuffer<B>, err: &proto::Error) -> StreamId {
+        let actions = &mut self.actions;
+        let counts = &mut self.counts;
+        let mut send_buffer = send_buffer.inner.lock().unwrap();
+        let send_buffer = &mut *send_buffer;
+
+        let last_processed_id = actions.recv.last_processed_id();
+
+        self.store
+            .for_each(|stream| {
+                counts.transition(stream, |counts, stream| {
+                    actions.recv.recv_err(err, &mut *stream);
+                    actions.send.recv_err(send_buffer, stream, counts);
+                    Ok::<_, ()>(())
+                })
+            })
+            .unwrap();
+
+        actions.conn_error = Some(err.shallow_clone());
+
+        last_processed_id
+    }
+
+    fn recv_go_away<B>(
+        &mut self,
+        send_buffer: &SendBuffer<B>,
+        frame: &frame::GoAway,
+    ) -> Result<(), RecvError> {
+        let actions = &mut self.actions;
+        let counts = &mut self.counts;
+        let mut send_buffer = send_buffer.inner.lock().unwrap();
+        let send_buffer = &mut *send_buffer;
+
+        let last_stream_id = frame.last_stream_id();
+
+        actions.send.recv_go_away(last_stream_id)?;
+
+        let err = frame.reason().into();
+
+        self.store
+            .for_each(|stream| {
+                if stream.id > last_stream_id {
+                    counts.transition(stream, |counts, stream| {
+                        actions.recv.recv_err(&err, &mut *stream);
+                        actions.send.recv_err(send_buffer, stream, counts);
+                        Ok::<_, ()>(())
+                    })
+                } else {
+                    Ok::<_, ()>(())
+                }
+            })
+            .unwrap();
+
+        actions.conn_error = Some(err);
+
+        Ok(())
+    }
+
+    fn recv_push_promise<B>(
+        &mut self,
+        send_buffer: &SendBuffer<B>,
+        frame: frame::PushPromise,
+    ) -> Result<(), RecvError> {
+        let id = frame.stream_id();
+        let promised_id = frame.promised_id();
+
+        // First, ensure that the initiating stream is still in a valid state.
+        let parent_key = match self.store.find_mut(&id) {
+            Some(stream) => {
+                // The GOAWAY process has begun. All streams with a greater ID
+                // than specified as part of GOAWAY should be ignored.
+                if id > self.actions.recv.max_stream_id() {
+                    tracing::trace!(
+                        "id ({:?}) > max_stream_id ({:?}), ignoring PUSH_PROMISE",
+                        id,
+                        self.actions.recv.max_stream_id()
+                    );
+                    return Ok(());
+                }
+
+                // The stream must be receive open
+                stream.state.ensure_recv_open()?;
+                stream.key()
+            }
+            None => {
+                proto_err!(conn: "recv_push_promise: initiating stream is in an invalid state");
+                return Err(RecvError::Connection(Reason::PROTOCOL_ERROR));
+            }
+        };
+
+        // TODO: Streams in the reserved states do not count towards the concurrency
+        // limit. However, it seems like there should be a cap otherwise this
+        // could grow in memory indefinitely.
+
+        // Ensure that we can reserve streams
+        self.actions.recv.ensure_can_reserve()?;
+
+        // Next, open the stream.
+        //
+        // If `None` is returned, then the stream is being refused. There is no
+        // further work to be done.
+        if self
+            .actions
+            .recv
+            .open(promised_id, Open::PushPromise, &mut self.counts)?
+            .is_none()
+        {
+            return Ok(());
+        }
+
+        // Try to handle the frame and create a corresponding key for the pushed stream
+        // this requires a bit of indirection to make the borrow checker happy.
+        let child_key: Option<store::Key> = {
+            // Create state for the stream
+            let stream = self.store.insert(promised_id, {
+                Stream::new(
+                    promised_id,
+                    self.actions.send.init_window_sz(),
+                    self.actions.recv.init_window_sz(),
+                )
+            });
+
+            let actions = &mut self.actions;
+
+            self.counts.transition(stream, |counts, stream| {
+                let stream_valid = actions.recv.recv_push_promise(frame, stream);
+
+                match stream_valid {
+                    Ok(()) => Ok(Some(stream.key())),
+                    _ => {
+                        let mut send_buffer = send_buffer.inner.lock().unwrap();
+                        actions
+                            .reset_on_recv_stream_err(
+                                &mut *send_buffer,
+                                stream,
+                                counts,
+                                stream_valid,
+                            )
+                            .map(|()| None)
+                    }
+                }
+            })?
+        };
+        // If we're successful, push the headers and stream...
+        if let Some(child) = child_key {
+            let mut ppp = self.store[parent_key].pending_push_promises.take();
+            ppp.push(&mut self.store.resolve(child));
+
+            let parent = &mut self.store.resolve(parent_key);
+            parent.pending_push_promises = ppp;
+            parent.notify_recv();
+        };
+
+        Ok(())
+    }
+
+    fn recv_eof<B>(
+        &mut self,
+        send_buffer: &SendBuffer<B>,
+        clear_pending_accept: bool,
+    ) -> Result<(), ()> {
+        let actions = &mut self.actions;
+        let counts = &mut self.counts;
+        let mut send_buffer = send_buffer.inner.lock().unwrap();
+        let send_buffer = &mut *send_buffer;
+
+        if actions.conn_error.is_none() {
+            actions.conn_error = Some(io::Error::from(io::ErrorKind::BrokenPipe).into());
+        }
+
+        tracing::trace!("Streams::recv_eof");
+
+        self.store
+            .for_each(|stream| {
+                counts.transition(stream, |counts, stream| {
+                    actions.recv.recv_eof(stream);
+
+                    // This handles resetting send state associated with the
+                    // stream
+                    actions.send.recv_err(send_buffer, stream, counts);
+                    Ok::<_, ()>(())
+                })
+            })
+            .expect("recv_eof");
+
+        actions.clear_queues(clear_pending_accept, &mut self.store, counts);
+        Ok(())
+    }
+
+    fn poll_complete<T, B>(
+        &mut self,
+        send_buffer: &SendBuffer<B>,
+        cx: &mut Context,
+        dst: &mut Codec<T, Prioritized<B>>,
+    ) -> Poll<io::Result<()>>
+    where
+        T: AsyncWrite + Unpin,
+        B: Buf,
+    {
+        let mut send_buffer = send_buffer.inner.lock().unwrap();
+        let send_buffer = &mut *send_buffer;
+
+        // Send WINDOW_UPDATE frames first
+        //
+        // TODO: It would probably be better to interleave updates w/ data
+        // frames.
+        ready!(self
+            .actions
+            .recv
+            .poll_complete(cx, &mut self.store, &mut self.counts, dst))?;
+
+        // Send any other pending frames
+        ready!(self.actions.send.poll_complete(
+            cx,
+            send_buffer,
+            &mut self.store,
+            &mut self.counts,
+            dst
+        ))?;
+
+        // Nothing else to do, track the task
+        self.actions.task = Some(cx.waker().clone());
+
+        Poll::Ready(Ok(()))
+    }
+
+    fn send_reset<B>(&mut self, send_buffer: &SendBuffer<B>, id: StreamId, reason: Reason) {
+        let key = match self.store.find_entry(id) {
             Entry::Occupied(e) => e.key(),
             Entry::Vacant(e) => {
                 let stream = Stream::new(id, 0, 0);
@@ -755,18 +880,11 @@ where
             }
         };
 
-        let stream = me.store.resolve(key);
-        let mut send_buffer = self.send_buffer.inner.lock().unwrap();
+        let stream = self.store.resolve(key);
+        let mut send_buffer = send_buffer.inner.lock().unwrap();
         let send_buffer = &mut *send_buffer;
 
-        me.actions
-            .send_reset(stream, reason, &mut me.counts, send_buffer);
-    }
-
-    pub fn send_go_away(&mut self, last_processed_id: StreamId) {
-        let mut me = self.inner.lock().unwrap();
-        let me = &mut *me;
-        let actions = &mut me.actions;
-        actions.recv.go_away(last_processed_id);
+        self.actions
+            .send_reset(stream, reason, &mut self.counts, send_buffer);
     }
 }
 
@@ -801,39 +919,24 @@ impl<B, P> Streams<B, P>
 where
     P: Peer,
 {
+    pub fn as_dyn(&self) -> DynStreams<B> {
+        let Self {
+            inner,
+            send_buffer,
+            _p,
+        } = self;
+        DynStreams {
+            inner,
+            send_buffer,
+            peer: P::r#dyn(),
+        }
+    }
+
     /// This function is safe to call multiple times.
     ///
     /// A `Result` is returned to avoid panicking if the mutex is poisoned.
     pub fn recv_eof(&mut self, clear_pending_accept: bool) -> Result<(), ()> {
-        let mut me = self.inner.lock().map_err(|_| ())?;
-        let me = &mut *me;
-
-        let actions = &mut me.actions;
-        let counts = &mut me.counts;
-        let mut send_buffer = self.send_buffer.inner.lock().unwrap();
-        let send_buffer = &mut *send_buffer;
-
-        if actions.conn_error.is_none() {
-            actions.conn_error = Some(io::Error::from(io::ErrorKind::BrokenPipe).into());
-        }
-
-        tracing::trace!("Streams::recv_eof");
-
-        me.store
-            .for_each(|stream| {
-                counts.transition(stream, |counts, stream| {
-                    actions.recv.recv_eof(stream);
-
-                    // This handles resetting send state associated with the
-                    // stream
-                    actions.send.recv_err(send_buffer, stream, counts);
-                    Ok::<_, ()>(())
-                })
-            })
-            .expect("recv_eof");
-
-        actions.clear_queues(clear_pending_accept, &mut me.store, counts);
-        Ok(())
+        self.as_dyn().recv_eof(clear_pending_accept)
     }
 
     pub(crate) fn max_send_streams(&self) -> usize {
@@ -1398,11 +1501,11 @@ impl Actions {
     /// is more likely to be latency/memory constraints that caused this,
     /// and not a bad actor. So be less catastrophic, the spec allows
     /// us to send another RST_STREAM of STREAM_CLOSED.
-    fn may_have_forgotten_stream<P: Peer>(&self, id: StreamId) -> bool {
+    fn may_have_forgotten_stream(&self, peer: peer::Dyn, id: StreamId) -> bool {
         if id.is_zero() {
             return false;
         }
-        if P::is_local_init(id) {
+        if peer.is_local_init(id) {
             self.send.may_have_created_stream(id)
         } else {
             self.recv.may_have_created_stream(id)
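The final hunk is the same erasure applied at the smallest scale: `may_have_forgotten_stream` was generic over `P: Peer` purely to call the associated `P::is_local_init`, and passing a `peer` value removes its last reason to be instantiated twice. A sketch under simplified stand-in types (the odd/even rule matches HTTP/2 stream-ID parity, where clients initiate odd IDs):

#[derive(Clone, Copy)]
enum Peer {
    Client,
    Server,
}

impl Peer {
    fn is_local_init(self, id: u32) -> bool {
        // Odd stream IDs are client-initiated; even IDs are server-initiated.
        let client_initiated = id % 2 == 1;
        match self {
            Peer::Client => client_initiated,
            Peer::Server => !client_initiated,
        }
    }
}

fn may_have_forgotten_stream(peer: Peer, id: u32) -> bool {
    if id == 0 {
        return false;
    }
    peer.is_local_init(id)
}

fn main() {
    assert!(may_have_forgotten_stream(Peer::Client, 1));
    assert!(!may_have_forgotten_stream(Peer::Client, 2));
    assert!(!may_have_forgotten_stream(Peer::Server, 0));
}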