@@ -11,7 +11,7 @@ use std::fmt;
|
||||
use std::io;
|
||||
use std::marker::Unpin;
|
||||
|
||||
use bytes::{Buf, Bytes};
|
||||
use bytes::Bytes;
|
||||
use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
|
||||
use tokio::sync::oneshot;
|
||||
|
||||
@@ -82,7 +82,7 @@ impl Upgraded {
|
||||
T: AsyncRead + AsyncWrite + Unpin + Send + 'static,
|
||||
{
|
||||
Upgraded {
|
||||
io: Rewind::new_buffered(Box::new(ForwardsWriteBuf(io)), read_buf),
|
||||
io: Rewind::new_buffered(Box::new(io), read_buf),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -92,9 +92,9 @@ impl Upgraded {
|
||||
/// `Upgraded` back.
|
||||
pub fn downcast<T: AsyncRead + AsyncWrite + Unpin + 'static>(self) -> Result<Parts<T>, Self> {
|
||||
let (io, buf) = self.io.into_inner();
|
||||
match io.__hyper_downcast::<ForwardsWriteBuf<T>>() {
|
||||
match io.__hyper_downcast() {
|
||||
Ok(t) => Ok(Parts {
|
||||
io: t.0,
|
||||
io: *t,
|
||||
read_buf: buf,
|
||||
_inner: (),
|
||||
}),
|
||||
@@ -221,20 +221,14 @@ impl StdError for UpgradeExpected {}
|
||||
|
||||
// ===== impl Io =====
|
||||
|
||||
struct ForwardsWriteBuf<T>(T);
|
||||
|
||||
pub(crate) trait Io: AsyncRead + AsyncWrite + Unpin + 'static {
|
||||
fn poll_write_dyn_buf(
|
||||
&mut self,
|
||||
cx: &mut task::Context<'_>,
|
||||
buf: &mut dyn Buf,
|
||||
) -> Poll<io::Result<usize>>;
|
||||
|
||||
fn __hyper_type_id(&self) -> TypeId {
|
||||
TypeId::of::<Self>()
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: AsyncRead + AsyncWrite + Unpin + 'static> Io for T {}
|
||||
|
||||
impl dyn Io + Send {
|
||||
fn __hyper_is<T: Io>(&self) -> bool {
|
||||
let t = TypeId::of::<T>();
|
||||
@@ -254,61 +248,6 @@ impl dyn Io + Send {
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: AsyncRead + Unpin> AsyncRead for ForwardsWriteBuf<T> {
|
||||
fn poll_read(
|
||||
mut self: Pin<&mut Self>,
|
||||
cx: &mut task::Context<'_>,
|
||||
buf: &mut ReadBuf<'_>,
|
||||
) -> Poll<io::Result<()>> {
|
||||
Pin::new(&mut self.0).poll_read(cx, buf)
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: AsyncWrite + Unpin> AsyncWrite for ForwardsWriteBuf<T> {
|
||||
fn poll_write(
|
||||
mut self: Pin<&mut Self>,
|
||||
cx: &mut task::Context<'_>,
|
||||
buf: &[u8],
|
||||
) -> Poll<io::Result<usize>> {
|
||||
Pin::new(&mut self.0).poll_write(cx, buf)
|
||||
}
|
||||
|
||||
fn poll_write_vectored(
|
||||
mut self: Pin<&mut Self>,
|
||||
cx: &mut task::Context<'_>,
|
||||
bufs: &[io::IoSlice<'_>],
|
||||
) -> Poll<io::Result<usize>> {
|
||||
Pin::new(&mut self.0).poll_write_vectored(cx, bufs)
|
||||
}
|
||||
|
||||
fn poll_flush(mut self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll<io::Result<()>> {
|
||||
Pin::new(&mut self.0).poll_flush(cx)
|
||||
}
|
||||
|
||||
fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll<io::Result<()>> {
|
||||
Pin::new(&mut self.0).poll_shutdown(cx)
|
||||
}
|
||||
|
||||
fn is_write_vectored(&self) -> bool {
|
||||
self.0.is_write_vectored()
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: AsyncRead + AsyncWrite + Unpin + 'static> Io for ForwardsWriteBuf<T> {
|
||||
fn poll_write_dyn_buf(
|
||||
&mut self,
|
||||
cx: &mut task::Context<'_>,
|
||||
buf: &mut dyn Buf,
|
||||
) -> Poll<io::Result<usize>> {
|
||||
if self.0.is_write_vectored() {
|
||||
let mut bufs = [io::IoSlice::new(&[]); crate::common::io::MAX_WRITEV_BUFS];
|
||||
let cnt = buf.bytes_vectored(&mut bufs);
|
||||
return Pin::new(&mut self.0).poll_write_vectored(cx, &bufs[..cnt]);
|
||||
}
|
||||
Pin::new(&mut self.0).poll_write(cx, buf.bytes())
|
||||
}
|
||||
}
|
||||
|
||||
mod sealed {
|
||||
use super::OnUpgrade;
|
||||
|
||||
@@ -352,7 +291,6 @@ mod sealed {
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use tokio::io::AsyncWriteExt;
|
||||
|
||||
#[test]
|
||||
fn upgraded_downcast() {
|
||||
@@ -363,15 +301,6 @@ mod tests {
|
||||
upgraded.downcast::<Mock>().unwrap();
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn upgraded_forwards_write_buf() {
|
||||
// sanity check that the underlying IO implements write_buf
|
||||
Mock.write_buf(&mut "hello".as_bytes()).await.unwrap();
|
||||
|
||||
let mut upgraded = Upgraded::new(Mock, Bytes::new());
|
||||
upgraded.write_buf(&mut "hello".as_bytes()).await.unwrap();
|
||||
}
|
||||
|
||||
// TODO: replace with tokio_test::io when it can test write_buf
|
||||
struct Mock;
|
||||
|
||||
@@ -395,17 +324,6 @@ mod tests {
|
||||
Poll::Ready(Ok(buf.len()))
|
||||
}
|
||||
|
||||
// TODO(eliza): :(
|
||||
// fn poll_write_buf<B: Buf>(
|
||||
// self: Pin<&mut Self>,
|
||||
// _cx: &mut task::Context<'_>,
|
||||
// buf: &mut B,
|
||||
// ) -> Poll<io::Result<usize>> {
|
||||
// let n = buf.remaining();
|
||||
// buf.advance(n);
|
||||
// Poll::Ready(Ok(n))
|
||||
// }
|
||||
|
||||
fn poll_flush(self: Pin<&mut Self>, _cx: &mut task::Context<'_>) -> Poll<io::Result<()>> {
|
||||
unreachable!("Mock::poll_flush")
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user