perf(upgrade): forward vectored writes on the Upgraded type
This commit is contained in:
@@ -397,7 +397,7 @@ where
|
|||||||
};
|
};
|
||||||
|
|
||||||
let (io, buf, _) = h1.into_inner();
|
let (io, buf, _) = h1.into_inner();
|
||||||
pending.fulfill(Upgraded::new(Box::new(io), buf));
|
pending.fulfill(Upgraded::new(io, buf));
|
||||||
Poll::Ready(Ok(()))
|
Poll::Ready(Ok(()))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -36,6 +36,10 @@ impl<T> Rewind<T> {
|
|||||||
pub(crate) fn into_inner(self) -> (T, Bytes) {
|
pub(crate) fn into_inner(self) -> (T, Bytes) {
|
||||||
(self.inner, self.pre.unwrap_or_else(Bytes::new))
|
(self.inner, self.pre.unwrap_or_else(Bytes::new))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub(crate) fn get_mut(&mut self) -> &mut T {
|
||||||
|
&mut self.inner
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<T> AsyncRead for Rewind<T>
|
impl<T> AsyncRead for Rewind<T>
|
||||||
|
|||||||
@@ -955,7 +955,7 @@ mod upgrades {
|
|||||||
};
|
};
|
||||||
|
|
||||||
let (io, buf, _) = h1.into_inner();
|
let (io, buf, _) = h1.into_inner();
|
||||||
pending.fulfill(Upgraded::new(Box::new(io), buf));
|
pending.fulfill(Upgraded::new(io, buf));
|
||||||
return Poll::Ready(Ok(()));
|
return Poll::Ready(Ok(()));
|
||||||
}
|
}
|
||||||
Err(e) => match *e.kind() {
|
Err(e) => match *e.kind() {
|
||||||
|
|||||||
201
src/upgrade.rs
201
src/upgrade.rs
@@ -11,7 +11,7 @@ use std::fmt;
|
|||||||
use std::io;
|
use std::io;
|
||||||
use std::marker::Unpin;
|
use std::marker::Unpin;
|
||||||
|
|
||||||
use bytes::Bytes;
|
use bytes::{Buf, Bytes};
|
||||||
use tokio::io::{AsyncRead, AsyncWrite};
|
use tokio::io::{AsyncRead, AsyncWrite};
|
||||||
use tokio::sync::oneshot;
|
use tokio::sync::oneshot;
|
||||||
|
|
||||||
@@ -73,39 +73,15 @@ pub(crate) fn pending() -> (Pending, OnUpgrade) {
|
|||||||
(Pending { tx }, OnUpgrade { rx: Some(rx) })
|
(Pending { tx }, OnUpgrade { rx: Some(rx) })
|
||||||
}
|
}
|
||||||
|
|
||||||
pub(crate) trait Io: AsyncRead + AsyncWrite + Unpin + 'static {
|
|
||||||
fn __hyper_type_id(&self) -> TypeId {
|
|
||||||
TypeId::of::<Self>()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl dyn Io + Send {
|
|
||||||
fn __hyper_is<T: Io>(&self) -> bool {
|
|
||||||
let t = TypeId::of::<T>();
|
|
||||||
self.__hyper_type_id() == t
|
|
||||||
}
|
|
||||||
|
|
||||||
fn __hyper_downcast<T: Io>(self: Box<Self>) -> Result<Box<T>, Box<Self>> {
|
|
||||||
if self.__hyper_is::<T>() {
|
|
||||||
// Taken from `std::error::Error::downcast()`.
|
|
||||||
unsafe {
|
|
||||||
let raw: *mut dyn Io = Box::into_raw(self);
|
|
||||||
Ok(Box::from_raw(raw as *mut T))
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
Err(self)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl<T: AsyncRead + AsyncWrite + Unpin + 'static> Io for T {}
|
|
||||||
|
|
||||||
// ===== impl Upgraded =====
|
// ===== impl Upgraded =====
|
||||||
|
|
||||||
impl Upgraded {
|
impl Upgraded {
|
||||||
pub(crate) fn new(io: Box<dyn Io + Send>, read_buf: Bytes) -> Self {
|
pub(crate) fn new<T>(io: T, read_buf: Bytes) -> Self
|
||||||
|
where
|
||||||
|
T: AsyncRead + AsyncWrite + Unpin + Send + 'static,
|
||||||
|
{
|
||||||
Upgraded {
|
Upgraded {
|
||||||
io: Rewind::new_buffered(io, read_buf),
|
io: Rewind::new_buffered(Box::new(ForwardsWriteBuf(io)), read_buf),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -115,9 +91,9 @@ impl Upgraded {
|
|||||||
/// `Upgraded` back.
|
/// `Upgraded` back.
|
||||||
pub fn downcast<T: AsyncRead + AsyncWrite + Unpin + 'static>(self) -> Result<Parts<T>, Self> {
|
pub fn downcast<T: AsyncRead + AsyncWrite + Unpin + 'static>(self) -> Result<Parts<T>, Self> {
|
||||||
let (io, buf) = self.io.into_inner();
|
let (io, buf) = self.io.into_inner();
|
||||||
match io.__hyper_downcast() {
|
match io.__hyper_downcast::<ForwardsWriteBuf<T>>() {
|
||||||
Ok(t) => Ok(Parts {
|
Ok(t) => Ok(Parts {
|
||||||
io: *t,
|
io: t.0,
|
||||||
read_buf: buf,
|
read_buf: buf,
|
||||||
_inner: (),
|
_inner: (),
|
||||||
}),
|
}),
|
||||||
@@ -151,6 +127,14 @@ impl AsyncWrite for Upgraded {
|
|||||||
Pin::new(&mut self.io).poll_write(cx, buf)
|
Pin::new(&mut self.io).poll_write(cx, buf)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn poll_write_buf<B: Buf>(
|
||||||
|
mut self: Pin<&mut Self>,
|
||||||
|
cx: &mut task::Context<'_>,
|
||||||
|
buf: &mut B,
|
||||||
|
) -> Poll<io::Result<usize>> {
|
||||||
|
Pin::new(self.io.get_mut()).poll_write_dyn_buf(cx, buf)
|
||||||
|
}
|
||||||
|
|
||||||
fn poll_flush(mut self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll<io::Result<()>> {
|
fn poll_flush(mut self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll<io::Result<()>> {
|
||||||
Pin::new(&mut self.io).poll_flush(cx)
|
Pin::new(&mut self.io).poll_flush(cx)
|
||||||
}
|
}
|
||||||
@@ -230,3 +214,156 @@ impl StdError for UpgradeExpected {
|
|||||||
"upgrade expected but not completed"
|
"upgrade expected but not completed"
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ===== impl Io =====
|
||||||
|
|
||||||
|
struct ForwardsWriteBuf<T>(T);
|
||||||
|
|
||||||
|
pub(crate) trait Io: AsyncRead + AsyncWrite + Unpin + 'static {
|
||||||
|
fn poll_write_dyn_buf(
|
||||||
|
&mut self,
|
||||||
|
cx: &mut task::Context<'_>,
|
||||||
|
buf: &mut dyn Buf,
|
||||||
|
) -> Poll<io::Result<usize>>;
|
||||||
|
|
||||||
|
fn __hyper_type_id(&self) -> TypeId {
|
||||||
|
TypeId::of::<Self>()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl dyn Io + Send {
|
||||||
|
fn __hyper_is<T: Io>(&self) -> bool {
|
||||||
|
let t = TypeId::of::<T>();
|
||||||
|
self.__hyper_type_id() == t
|
||||||
|
}
|
||||||
|
|
||||||
|
fn __hyper_downcast<T: Io>(self: Box<Self>) -> Result<Box<T>, Box<Self>> {
|
||||||
|
if self.__hyper_is::<T>() {
|
||||||
|
// Taken from `std::error::Error::downcast()`.
|
||||||
|
unsafe {
|
||||||
|
let raw: *mut dyn Io = Box::into_raw(self);
|
||||||
|
Ok(Box::from_raw(raw as *mut T))
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
Err(self)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<T: AsyncRead + Unpin> AsyncRead for ForwardsWriteBuf<T> {
|
||||||
|
unsafe fn prepare_uninitialized_buffer(&self, buf: &mut [std::mem::MaybeUninit<u8>]) -> bool {
|
||||||
|
self.0.prepare_uninitialized_buffer(buf)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn poll_read(
|
||||||
|
mut self: Pin<&mut Self>,
|
||||||
|
cx: &mut task::Context<'_>,
|
||||||
|
buf: &mut [u8],
|
||||||
|
) -> Poll<io::Result<usize>> {
|
||||||
|
Pin::new(&mut self.0).poll_read(cx, buf)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<T: AsyncWrite + Unpin> AsyncWrite for ForwardsWriteBuf<T> {
|
||||||
|
fn poll_write(
|
||||||
|
mut self: Pin<&mut Self>,
|
||||||
|
cx: &mut task::Context<'_>,
|
||||||
|
buf: &[u8],
|
||||||
|
) -> Poll<io::Result<usize>> {
|
||||||
|
Pin::new(&mut self.0).poll_write(cx, buf)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn poll_write_buf<B: Buf>(
|
||||||
|
mut self: Pin<&mut Self>,
|
||||||
|
cx: &mut task::Context<'_>,
|
||||||
|
buf: &mut B,
|
||||||
|
) -> Poll<io::Result<usize>> {
|
||||||
|
Pin::new(&mut self.0).poll_write_buf(cx, buf)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn poll_flush(mut self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll<io::Result<()>> {
|
||||||
|
Pin::new(&mut self.0).poll_flush(cx)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll<io::Result<()>> {
|
||||||
|
Pin::new(&mut self.0).poll_shutdown(cx)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<T: AsyncRead + AsyncWrite + Unpin + 'static> Io for ForwardsWriteBuf<T> {
|
||||||
|
fn poll_write_dyn_buf(
|
||||||
|
&mut self,
|
||||||
|
cx: &mut task::Context<'_>,
|
||||||
|
mut buf: &mut dyn Buf,
|
||||||
|
) -> Poll<io::Result<usize>> {
|
||||||
|
Pin::new(&mut self.0).poll_write_buf(cx, &mut buf)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
mod tests {
|
||||||
|
use super::*;
|
||||||
|
use tokio::io::AsyncWriteExt;
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn upgraded_downcast() {
|
||||||
|
let upgraded = Upgraded::new(Mock, Bytes::new());
|
||||||
|
|
||||||
|
let upgraded = upgraded.downcast::<std::io::Cursor<Vec<u8>>>().unwrap_err();
|
||||||
|
|
||||||
|
upgraded.downcast::<Mock>().unwrap();
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn upgraded_forwards_write_buf() {
|
||||||
|
// sanity check that the underlying IO implements write_buf
|
||||||
|
Mock.write_buf(&mut "hello".as_bytes()).await.unwrap();
|
||||||
|
|
||||||
|
let mut upgraded = Upgraded::new(Mock, Bytes::new());
|
||||||
|
upgraded.write_buf(&mut "hello".as_bytes()).await.unwrap();
|
||||||
|
}
|
||||||
|
|
||||||
|
// TODO: replace with tokio_test::io when it can test write_buf
|
||||||
|
struct Mock;
|
||||||
|
|
||||||
|
impl AsyncRead for Mock {
|
||||||
|
fn poll_read(
|
||||||
|
self: Pin<&mut Self>,
|
||||||
|
_cx: &mut task::Context<'_>,
|
||||||
|
_buf: &mut [u8],
|
||||||
|
) -> Poll<io::Result<usize>> {
|
||||||
|
unreachable!("Mock::poll_read")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl AsyncWrite for Mock {
|
||||||
|
fn poll_write(
|
||||||
|
self: Pin<&mut Self>,
|
||||||
|
_cx: &mut task::Context<'_>,
|
||||||
|
_buf: &[u8],
|
||||||
|
) -> Poll<io::Result<usize>> {
|
||||||
|
panic!("poll_write shouldn't be called");
|
||||||
|
}
|
||||||
|
|
||||||
|
fn poll_write_buf<B: Buf>(
|
||||||
|
self: Pin<&mut Self>,
|
||||||
|
_cx: &mut task::Context<'_>,
|
||||||
|
buf: &mut B,
|
||||||
|
) -> Poll<io::Result<usize>> {
|
||||||
|
let n = buf.remaining();
|
||||||
|
buf.advance(n);
|
||||||
|
Poll::Ready(Ok(n))
|
||||||
|
}
|
||||||
|
|
||||||
|
fn poll_flush(self: Pin<&mut Self>, _cx: &mut task::Context<'_>) -> Poll<io::Result<()>> {
|
||||||
|
unreachable!("Mock::poll_flush")
|
||||||
|
}
|
||||||
|
|
||||||
|
fn poll_shutdown(
|
||||||
|
self: Pin<&mut Self>,
|
||||||
|
_cx: &mut task::Context<'_>,
|
||||||
|
) -> Poll<io::Result<()>> {
|
||||||
|
unreachable!("Mock::poll_shutdown")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user