diff --git a/src/client/conn.rs b/src/client/conn.rs index 6072a866..c969d839 100644 --- a/src/client/conn.rs +++ b/src/client/conn.rs @@ -397,7 +397,7 @@ where }; let (io, buf, _) = h1.into_inner(); - pending.fulfill(Upgraded::new(Box::new(io), buf)); + pending.fulfill(Upgraded::new(io, buf)); Poll::Ready(Ok(())) } } diff --git a/src/common/io/rewind.rs b/src/common/io/rewind.rs index 9acbe325..322ee3f0 100644 --- a/src/common/io/rewind.rs +++ b/src/common/io/rewind.rs @@ -36,6 +36,10 @@ impl Rewind { pub(crate) fn into_inner(self) -> (T, Bytes) { (self.inner, self.pre.unwrap_or_else(Bytes::new)) } + + pub(crate) fn get_mut(&mut self) -> &mut T { + &mut self.inner + } } impl AsyncRead for Rewind diff --git a/src/server/conn.rs b/src/server/conn.rs index ad01fd22..a2c5c094 100644 --- a/src/server/conn.rs +++ b/src/server/conn.rs @@ -955,7 +955,7 @@ mod upgrades { }; let (io, buf, _) = h1.into_inner(); - pending.fulfill(Upgraded::new(Box::new(io), buf)); + pending.fulfill(Upgraded::new(io, buf)); return Poll::Ready(Ok(())); } Err(e) => match *e.kind() { diff --git a/src/upgrade.rs b/src/upgrade.rs index e8377322..3b47cf68 100644 --- a/src/upgrade.rs +++ b/src/upgrade.rs @@ -11,7 +11,7 @@ use std::fmt; use std::io; use std::marker::Unpin; -use bytes::Bytes; +use bytes::{Buf, Bytes}; use tokio::io::{AsyncRead, AsyncWrite}; use tokio::sync::oneshot; @@ -73,39 +73,15 @@ pub(crate) fn pending() -> (Pending, OnUpgrade) { (Pending { tx }, OnUpgrade { rx: Some(rx) }) } -pub(crate) trait Io: AsyncRead + AsyncWrite + Unpin + 'static { - fn __hyper_type_id(&self) -> TypeId { - TypeId::of::() - } -} - -impl dyn Io + Send { - fn __hyper_is(&self) -> bool { - let t = TypeId::of::(); - self.__hyper_type_id() == t - } - - fn __hyper_downcast(self: Box) -> Result, Box> { - if self.__hyper_is::() { - // Taken from `std::error::Error::downcast()`. - unsafe { - let raw: *mut dyn Io = Box::into_raw(self); - Ok(Box::from_raw(raw as *mut T)) - } - } else { - Err(self) - } - } -} - -impl Io for T {} - // ===== impl Upgraded ===== impl Upgraded { - pub(crate) fn new(io: Box, read_buf: Bytes) -> Self { + pub(crate) fn new(io: T, read_buf: Bytes) -> Self + where + T: AsyncRead + AsyncWrite + Unpin + Send + 'static, + { Upgraded { - io: Rewind::new_buffered(io, read_buf), + io: Rewind::new_buffered(Box::new(ForwardsWriteBuf(io)), read_buf), } } @@ -115,9 +91,9 @@ impl Upgraded { /// `Upgraded` back. pub fn downcast(self) -> Result, Self> { let (io, buf) = self.io.into_inner(); - match io.__hyper_downcast() { + match io.__hyper_downcast::>() { Ok(t) => Ok(Parts { - io: *t, + io: t.0, read_buf: buf, _inner: (), }), @@ -151,6 +127,14 @@ impl AsyncWrite for Upgraded { Pin::new(&mut self.io).poll_write(cx, buf) } + fn poll_write_buf( + mut self: Pin<&mut Self>, + cx: &mut task::Context<'_>, + buf: &mut B, + ) -> Poll> { + Pin::new(self.io.get_mut()).poll_write_dyn_buf(cx, buf) + } + fn poll_flush(mut self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll> { Pin::new(&mut self.io).poll_flush(cx) } @@ -230,3 +214,156 @@ impl StdError for UpgradeExpected { "upgrade expected but not completed" } } + +// ===== impl Io ===== + +struct ForwardsWriteBuf(T); + +pub(crate) trait Io: AsyncRead + AsyncWrite + Unpin + 'static { + fn poll_write_dyn_buf( + &mut self, + cx: &mut task::Context<'_>, + buf: &mut dyn Buf, + ) -> Poll>; + + fn __hyper_type_id(&self) -> TypeId { + TypeId::of::() + } +} + +impl dyn Io + Send { + fn __hyper_is(&self) -> bool { + let t = TypeId::of::(); + self.__hyper_type_id() == t + } + + fn __hyper_downcast(self: Box) -> Result, Box> { + if self.__hyper_is::() { + // Taken from `std::error::Error::downcast()`. + unsafe { + let raw: *mut dyn Io = Box::into_raw(self); + Ok(Box::from_raw(raw as *mut T)) + } + } else { + Err(self) + } + } +} + +impl AsyncRead for ForwardsWriteBuf { + unsafe fn prepare_uninitialized_buffer(&self, buf: &mut [std::mem::MaybeUninit]) -> bool { + self.0.prepare_uninitialized_buffer(buf) + } + + fn poll_read( + mut self: Pin<&mut Self>, + cx: &mut task::Context<'_>, + buf: &mut [u8], + ) -> Poll> { + Pin::new(&mut self.0).poll_read(cx, buf) + } +} + +impl AsyncWrite for ForwardsWriteBuf { + fn poll_write( + mut self: Pin<&mut Self>, + cx: &mut task::Context<'_>, + buf: &[u8], + ) -> Poll> { + Pin::new(&mut self.0).poll_write(cx, buf) + } + + fn poll_write_buf( + mut self: Pin<&mut Self>, + cx: &mut task::Context<'_>, + buf: &mut B, + ) -> Poll> { + Pin::new(&mut self.0).poll_write_buf(cx, buf) + } + + fn poll_flush(mut self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll> { + Pin::new(&mut self.0).poll_flush(cx) + } + + fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll> { + Pin::new(&mut self.0).poll_shutdown(cx) + } +} + +impl Io for ForwardsWriteBuf { + fn poll_write_dyn_buf( + &mut self, + cx: &mut task::Context<'_>, + mut buf: &mut dyn Buf, + ) -> Poll> { + Pin::new(&mut self.0).poll_write_buf(cx, &mut buf) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use tokio::io::AsyncWriteExt; + + #[test] + fn upgraded_downcast() { + let upgraded = Upgraded::new(Mock, Bytes::new()); + + let upgraded = upgraded.downcast::>>().unwrap_err(); + + upgraded.downcast::().unwrap(); + } + + #[tokio::test] + async fn upgraded_forwards_write_buf() { + // sanity check that the underlying IO implements write_buf + Mock.write_buf(&mut "hello".as_bytes()).await.unwrap(); + + let mut upgraded = Upgraded::new(Mock, Bytes::new()); + upgraded.write_buf(&mut "hello".as_bytes()).await.unwrap(); + } + + // TODO: replace with tokio_test::io when it can test write_buf + struct Mock; + + impl AsyncRead for Mock { + fn poll_read( + self: Pin<&mut Self>, + _cx: &mut task::Context<'_>, + _buf: &mut [u8], + ) -> Poll> { + unreachable!("Mock::poll_read") + } + } + + impl AsyncWrite for Mock { + fn poll_write( + self: Pin<&mut Self>, + _cx: &mut task::Context<'_>, + _buf: &[u8], + ) -> Poll> { + panic!("poll_write shouldn't be called"); + } + + fn poll_write_buf( + self: Pin<&mut Self>, + _cx: &mut task::Context<'_>, + buf: &mut B, + ) -> Poll> { + let n = buf.remaining(); + buf.advance(n); + Poll::Ready(Ok(n)) + } + + fn poll_flush(self: Pin<&mut Self>, _cx: &mut task::Context<'_>) -> Poll> { + unreachable!("Mock::poll_flush") + } + + fn poll_shutdown( + self: Pin<&mut Self>, + _cx: &mut task::Context<'_>, + ) -> Poll> { + unreachable!("Mock::poll_shutdown") + } + } +}