Tokio 0.3 Upgrade (#2319)
Co-authored-by: Urhengulas <johann.hemmann@code.berlin> Co-authored-by: Eliza Weisman <eliza@buoyant.io>
This commit is contained in:
@@ -12,8 +12,8 @@ use std::time::Duration;
|
||||
use futures_util::future::Either;
|
||||
use http::uri::{Scheme, Uri};
|
||||
use pin_project::pin_project;
|
||||
use tokio::net::TcpStream;
|
||||
use tokio::time::Delay;
|
||||
use tokio::net::{TcpSocket, TcpStream};
|
||||
use tokio::time::Sleep;
|
||||
|
||||
use super::dns::{self, resolve, GaiResolver, Resolve};
|
||||
use super::{Connected, Connection};
|
||||
@@ -331,34 +331,9 @@ where
|
||||
dns::IpAddrs::new(addrs)
|
||||
};
|
||||
|
||||
let c = ConnectingTcp::new(
|
||||
config.local_address_ipv4,
|
||||
config.local_address_ipv6,
|
||||
addrs,
|
||||
config.connect_timeout,
|
||||
config.happy_eyeballs_timeout,
|
||||
config.reuse_address,
|
||||
);
|
||||
let c = ConnectingTcp::new(addrs, config);
|
||||
|
||||
let sock = c
|
||||
.connect()
|
||||
.await
|
||||
.map_err(ConnectError::m("tcp connect error"))?;
|
||||
|
||||
if let Some(dur) = config.keep_alive_timeout {
|
||||
sock.set_keepalive(Some(dur))
|
||||
.map_err(ConnectError::m("tcp set_keepalive error"))?;
|
||||
}
|
||||
|
||||
if let Some(size) = config.send_buffer_size {
|
||||
sock.set_send_buffer_size(size)
|
||||
.map_err(ConnectError::m("tcp set_send_buffer_size error"))?;
|
||||
}
|
||||
|
||||
if let Some(size) = config.recv_buffer_size {
|
||||
sock.set_recv_buffer_size(size)
|
||||
.map_err(ConnectError::m("tcp set_recv_buffer_size error"))?;
|
||||
}
|
||||
let sock = c.connect().await?;
|
||||
|
||||
sock.set_nodelay(config.nodelay)
|
||||
.map_err(ConnectError::m("tcp set_nodelay error"))?;
|
||||
@@ -475,60 +450,45 @@ impl StdError for ConnectError {
|
||||
}
|
||||
}
|
||||
|
||||
struct ConnectingTcp {
|
||||
local_addr_ipv4: Option<Ipv4Addr>,
|
||||
local_addr_ipv6: Option<Ipv6Addr>,
|
||||
struct ConnectingTcp<'a> {
|
||||
preferred: ConnectingTcpRemote,
|
||||
fallback: Option<ConnectingTcpFallback>,
|
||||
reuse_address: bool,
|
||||
config: &'a Config,
|
||||
}
|
||||
|
||||
impl ConnectingTcp {
|
||||
fn new(
|
||||
local_addr_ipv4: Option<Ipv4Addr>,
|
||||
local_addr_ipv6: Option<Ipv6Addr>,
|
||||
remote_addrs: dns::IpAddrs,
|
||||
connect_timeout: Option<Duration>,
|
||||
fallback_timeout: Option<Duration>,
|
||||
reuse_address: bool,
|
||||
) -> ConnectingTcp {
|
||||
if let Some(fallback_timeout) = fallback_timeout {
|
||||
let (preferred_addrs, fallback_addrs) =
|
||||
remote_addrs.split_by_preference(local_addr_ipv4, local_addr_ipv6);
|
||||
impl<'a> ConnectingTcp<'a> {
|
||||
fn new(remote_addrs: dns::IpAddrs, config: &'a Config) -> Self {
|
||||
if let Some(fallback_timeout) = config.happy_eyeballs_timeout {
|
||||
let (preferred_addrs, fallback_addrs) = remote_addrs
|
||||
.split_by_preference(config.local_address_ipv4, config.local_address_ipv6);
|
||||
if fallback_addrs.is_empty() {
|
||||
return ConnectingTcp {
|
||||
local_addr_ipv4,
|
||||
local_addr_ipv6,
|
||||
preferred: ConnectingTcpRemote::new(preferred_addrs, connect_timeout),
|
||||
preferred: ConnectingTcpRemote::new(preferred_addrs, config.connect_timeout),
|
||||
fallback: None,
|
||||
reuse_address,
|
||||
config,
|
||||
};
|
||||
}
|
||||
|
||||
ConnectingTcp {
|
||||
local_addr_ipv4,
|
||||
local_addr_ipv6,
|
||||
preferred: ConnectingTcpRemote::new(preferred_addrs, connect_timeout),
|
||||
preferred: ConnectingTcpRemote::new(preferred_addrs, config.connect_timeout),
|
||||
fallback: Some(ConnectingTcpFallback {
|
||||
delay: tokio::time::delay_for(fallback_timeout),
|
||||
remote: ConnectingTcpRemote::new(fallback_addrs, connect_timeout),
|
||||
delay: tokio::time::sleep(fallback_timeout),
|
||||
remote: ConnectingTcpRemote::new(fallback_addrs, config.connect_timeout),
|
||||
}),
|
||||
reuse_address,
|
||||
config,
|
||||
}
|
||||
} else {
|
||||
ConnectingTcp {
|
||||
local_addr_ipv4,
|
||||
local_addr_ipv6,
|
||||
preferred: ConnectingTcpRemote::new(remote_addrs, connect_timeout),
|
||||
preferred: ConnectingTcpRemote::new(remote_addrs, config.connect_timeout),
|
||||
fallback: None,
|
||||
reuse_address,
|
||||
config,
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
struct ConnectingTcpFallback {
|
||||
delay: Delay,
|
||||
delay: Sleep,
|
||||
remote: ConnectingTcpRemote,
|
||||
}
|
||||
|
||||
@@ -549,24 +509,11 @@ impl ConnectingTcpRemote {
|
||||
}
|
||||
|
||||
impl ConnectingTcpRemote {
|
||||
async fn connect(
|
||||
&mut self,
|
||||
local_addr_ipv4: &Option<Ipv4Addr>,
|
||||
local_addr_ipv6: &Option<Ipv6Addr>,
|
||||
reuse_address: bool,
|
||||
) -> io::Result<TcpStream> {
|
||||
async fn connect(&mut self, config: &Config) -> Result<TcpStream, ConnectError> {
|
||||
let mut err = None;
|
||||
for addr in &mut self.addrs {
|
||||
debug!("connecting to {}", addr);
|
||||
match connect(
|
||||
&addr,
|
||||
local_addr_ipv4,
|
||||
local_addr_ipv6,
|
||||
reuse_address,
|
||||
self.connect_timeout,
|
||||
)?
|
||||
.await
|
||||
{
|
||||
match connect(&addr, config, self.connect_timeout)?.await {
|
||||
Ok(tcp) => {
|
||||
debug!("connected to {}", addr);
|
||||
return Ok(tcp);
|
||||
@@ -580,9 +527,9 @@ impl ConnectingTcpRemote {
|
||||
|
||||
match err {
|
||||
Some(e) => Err(e),
|
||||
None => Err(std::io::Error::new(
|
||||
std::io::ErrorKind::NotConnected,
|
||||
"Network unreachable",
|
||||
None => Err(ConnectError::new(
|
||||
"tcp connect error",
|
||||
std::io::Error::new(std::io::ErrorKind::NotConnected, "Network unreachable"),
|
||||
)),
|
||||
}
|
||||
}
|
||||
@@ -618,30 +565,79 @@ fn bind_local_address(
|
||||
|
||||
fn connect(
|
||||
addr: &SocketAddr,
|
||||
local_addr_ipv4: &Option<Ipv4Addr>,
|
||||
local_addr_ipv6: &Option<Ipv6Addr>,
|
||||
reuse_address: bool,
|
||||
config: &Config,
|
||||
connect_timeout: Option<Duration>,
|
||||
) -> io::Result<impl Future<Output = io::Result<TcpStream>>> {
|
||||
) -> Result<impl Future<Output = Result<TcpStream, ConnectError>>, ConnectError> {
|
||||
// TODO(eliza): if Tokio's `TcpSocket` gains support for setting the
|
||||
// keepalive timeout and send/recv buffer size, it would be nice to use that
|
||||
// instead of socket2, and avoid the unsafe `into_raw_fd`/`from_raw_fd`
|
||||
// dance...
|
||||
use socket2::{Domain, Protocol, Socket, Type};
|
||||
let domain = match *addr {
|
||||
SocketAddr::V4(_) => Domain::ipv4(),
|
||||
SocketAddr::V6(_) => Domain::ipv6(),
|
||||
};
|
||||
let socket = Socket::new(domain, Type::stream(), Some(Protocol::tcp()))?;
|
||||
let socket = Socket::new(domain, Type::stream(), Some(Protocol::tcp()))
|
||||
.map_err(ConnectError::m("tcp open error"))?;
|
||||
|
||||
if reuse_address {
|
||||
socket.set_reuse_address(true)?;
|
||||
if config.reuse_address {
|
||||
socket
|
||||
.set_reuse_address(true)
|
||||
.map_err(ConnectError::m("tcp set_reuse_address error"))?;
|
||||
}
|
||||
|
||||
bind_local_address(&socket, addr, local_addr_ipv4, local_addr_ipv6)?;
|
||||
// When constructing a Tokio `TcpSocket` from a raw fd/socket, the user is
|
||||
// responsible for ensuring O_NONBLOCK is set.
|
||||
socket
|
||||
.set_nonblocking(true)
|
||||
.map_err(ConnectError::m("tcp set_nonblocking error"))?;
|
||||
|
||||
let addr = *addr;
|
||||
bind_local_address(
|
||||
&socket,
|
||||
addr,
|
||||
&config.local_address_ipv4,
|
||||
&config.local_address_ipv6,
|
||||
)
|
||||
.map_err(ConnectError::m("tcp bind local error"))?;
|
||||
|
||||
let std_tcp = socket.into_tcp_stream();
|
||||
if let Some(dur) = config.keep_alive_timeout {
|
||||
socket
|
||||
.set_keepalive(Some(dur))
|
||||
.map_err(ConnectError::m("tcp set_keepalive error"))?;
|
||||
}
|
||||
|
||||
if let Some(size) = config.send_buffer_size {
|
||||
socket
|
||||
.set_send_buffer_size(size)
|
||||
.map_err(ConnectError::m("tcp set_send_buffer_size error"))?;
|
||||
}
|
||||
|
||||
if let Some(size) = config.recv_buffer_size {
|
||||
socket
|
||||
.set_recv_buffer_size(size)
|
||||
.map_err(ConnectError::m("tcp set_recv_buffer_size error"))?;
|
||||
}
|
||||
|
||||
#[cfg(unix)]
|
||||
let socket = unsafe {
|
||||
// Safety: `from_raw_fd` is only safe to call if ownership of the raw
|
||||
// file descriptor is transferred. Since we call `into_raw_fd` on the
|
||||
// socket2 socket, it gives up ownership of the fd and will not close
|
||||
// it, so this is safe.
|
||||
use std::os::unix::io::{FromRawFd, IntoRawFd};
|
||||
TcpSocket::from_raw_fd(socket.into_raw_fd())
|
||||
};
|
||||
#[cfg(windows)]
|
||||
let socket = unsafe {
|
||||
// Safety: `from_raw_socket` is only safe to call if ownership of the raw
|
||||
// Windows SOCKET is transferred. Since we call `into_raw_socket` on the
|
||||
// socket2 socket, it gives up ownership of the SOCKET and will not close
|
||||
// it, so this is safe.
|
||||
use std::os::windows::io::{FromRawSocket, IntoRawSocket};
|
||||
TcpSocket::from_raw_socket(socket.into_raw_socket())
|
||||
};
|
||||
let connect = socket.connect(*addr);
|
||||
Ok(async move {
|
||||
let connect = TcpStream::connect_std(std_tcp, &addr);
|
||||
match connect_timeout {
|
||||
Some(dur) => match tokio::time::timeout(dur, connect).await {
|
||||
Ok(Ok(s)) => Ok(s),
|
||||
@@ -650,33 +646,19 @@ fn connect(
|
||||
},
|
||||
None => connect.await,
|
||||
}
|
||||
.map_err(ConnectError::m("tcp connect error"))
|
||||
})
|
||||
}
|
||||
|
||||
impl ConnectingTcp {
|
||||
async fn connect(mut self) -> io::Result<TcpStream> {
|
||||
let Self {
|
||||
ref local_addr_ipv4,
|
||||
ref local_addr_ipv6,
|
||||
reuse_address,
|
||||
..
|
||||
} = self;
|
||||
impl ConnectingTcp<'_> {
|
||||
async fn connect(mut self) -> Result<TcpStream, ConnectError> {
|
||||
match self.fallback {
|
||||
None => {
|
||||
self.preferred
|
||||
.connect(local_addr_ipv4, local_addr_ipv6, reuse_address)
|
||||
.await
|
||||
}
|
||||
None => self.preferred.connect(self.config).await,
|
||||
Some(mut fallback) => {
|
||||
let preferred_fut =
|
||||
self.preferred
|
||||
.connect(local_addr_ipv4, local_addr_ipv6, reuse_address);
|
||||
let preferred_fut = self.preferred.connect(self.config);
|
||||
futures_util::pin_mut!(preferred_fut);
|
||||
|
||||
let fallback_fut =
|
||||
fallback
|
||||
.remote
|
||||
.connect(local_addr_ipv4, local_addr_ipv6, reuse_address);
|
||||
let fallback_fut = fallback.remote.connect(self.config);
|
||||
futures_util::pin_mut!(fallback_fut);
|
||||
|
||||
let (result, future) =
|
||||
@@ -711,7 +693,7 @@ mod tests {
|
||||
use ::http::Uri;
|
||||
|
||||
use super::super::sealed::{Connect, ConnectSvc};
|
||||
use super::HttpConnector;
|
||||
use super::{Config, ConnectError, HttpConnector};
|
||||
|
||||
async fn connect<C>(
|
||||
connector: C,
|
||||
@@ -773,6 +755,7 @@ mod tests {
|
||||
#[tokio::test]
|
||||
async fn local_address() {
|
||||
use std::net::{IpAddr, TcpListener};
|
||||
let _ = pretty_env_logger::try_init();
|
||||
|
||||
let (bind_ip_v4, bind_ip_v6) = get_local_ips();
|
||||
let server4 = TcpListener::bind("127.0.0.1:0").unwrap();
|
||||
@@ -818,10 +801,8 @@ mod tests {
|
||||
let server4 = TcpListener::bind("127.0.0.1:0").unwrap();
|
||||
let addr = server4.local_addr().unwrap();
|
||||
let _server6 = TcpListener::bind(&format!("[::1]:{}", addr.port())).unwrap();
|
||||
let mut rt = tokio::runtime::Builder::new()
|
||||
.enable_io()
|
||||
.enable_time()
|
||||
.basic_scheduler()
|
||||
let rt = tokio::runtime::Builder::new_current_thread()
|
||||
.enable_all()
|
||||
.build()
|
||||
.unwrap();
|
||||
|
||||
@@ -925,16 +906,21 @@ mod tests {
|
||||
.iter()
|
||||
.map(|host| (host.clone(), addr.port()).into())
|
||||
.collect();
|
||||
let connecting_tcp = ConnectingTcp::new(
|
||||
None,
|
||||
None,
|
||||
dns::IpAddrs::new(addrs),
|
||||
None,
|
||||
Some(fallback_timeout),
|
||||
false,
|
||||
);
|
||||
let cfg = Config {
|
||||
local_address_ipv4: None,
|
||||
local_address_ipv6: None,
|
||||
connect_timeout: None,
|
||||
keep_alive_timeout: None,
|
||||
happy_eyeballs_timeout: Some(fallback_timeout),
|
||||
nodelay: false,
|
||||
reuse_address: false,
|
||||
enforce_http: false,
|
||||
send_buffer_size: None,
|
||||
recv_buffer_size: None,
|
||||
};
|
||||
let connecting_tcp = ConnectingTcp::new(dns::IpAddrs::new(addrs), &cfg);
|
||||
let start = Instant::now();
|
||||
Ok::<_, io::Error>((start, connecting_tcp.connect().await?))
|
||||
Ok::<_, ConnectError>((start, ConnectingTcp::connect(connecting_tcp).await?))
|
||||
})
|
||||
.unwrap();
|
||||
let res = if stream.peer_addr().unwrap().is_ipv4() {
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
use futures_util::future;
|
||||
use tokio::stream::Stream;
|
||||
use tokio::sync::{mpsc, oneshot};
|
||||
|
||||
use crate::common::{task, Future, Pin, Poll};
|
||||
@@ -131,22 +132,25 @@ impl<T, U> Clone for UnboundedSender<T, U> {
|
||||
}
|
||||
}
|
||||
|
||||
#[pin_project::pin_project(PinnedDrop)]
|
||||
pub struct Receiver<T, U> {
|
||||
#[pin]
|
||||
inner: mpsc::UnboundedReceiver<Envelope<T, U>>,
|
||||
taker: want::Taker,
|
||||
}
|
||||
|
||||
impl<T, U> Receiver<T, U> {
|
||||
pub(crate) fn poll_next(
|
||||
&mut self,
|
||||
self: Pin<&mut Self>,
|
||||
cx: &mut task::Context<'_>,
|
||||
) -> Poll<Option<(T, Callback<T, U>)>> {
|
||||
match self.inner.poll_recv(cx) {
|
||||
let this = self.project();
|
||||
match this.inner.poll_next(cx) {
|
||||
Poll::Ready(item) => {
|
||||
Poll::Ready(item.map(|mut env| env.0.take().expect("envelope not dropped")))
|
||||
}
|
||||
Poll::Pending => {
|
||||
self.taker.want();
|
||||
this.taker.want();
|
||||
Poll::Pending
|
||||
}
|
||||
}
|
||||
@@ -165,11 +169,12 @@ impl<T, U> Receiver<T, U> {
|
||||
}
|
||||
}
|
||||
|
||||
impl<T, U> Drop for Receiver<T, U> {
|
||||
fn drop(&mut self) {
|
||||
#[pin_project::pinned_drop]
|
||||
impl<T, U> PinnedDrop for Receiver<T, U> {
|
||||
fn drop(mut self: Pin<&mut Self>) {
|
||||
// Notify the giver about the closure first, before dropping
|
||||
// the mpsc::Receiver.
|
||||
self.taker.cancel();
|
||||
self.as_mut().taker.cancel();
|
||||
}
|
||||
}
|
||||
|
||||
@@ -262,7 +267,7 @@ mod tests {
|
||||
impl<T, U> Future for Receiver<T, U> {
|
||||
type Output = Option<(T, Callback<T, U>)>;
|
||||
|
||||
fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
|
||||
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
|
||||
self.poll_next(cx)
|
||||
}
|
||||
}
|
||||
@@ -344,9 +349,8 @@ mod tests {
|
||||
fn giver_queue_throughput(b: &mut test::Bencher) {
|
||||
use crate::{Body, Request, Response};
|
||||
|
||||
let mut rt = tokio::runtime::Builder::new()
|
||||
let rt = tokio::runtime::Builder::new_current_thread()
|
||||
.enable_all()
|
||||
.basic_scheduler()
|
||||
.build()
|
||||
.unwrap();
|
||||
let (mut tx, mut rx) = channel::<Request<Body>, Response<Body>>();
|
||||
@@ -368,9 +372,8 @@ mod tests {
|
||||
#[cfg(feature = "nightly")]
|
||||
#[bench]
|
||||
fn giver_queue_not_ready(b: &mut test::Bencher) {
|
||||
let mut rt = tokio::runtime::Builder::new()
|
||||
let rt = tokio::runtime::Builder::new_current_thread()
|
||||
.enable_all()
|
||||
.basic_scheduler()
|
||||
.build()
|
||||
.unwrap();
|
||||
let (_tx, mut rx) = channel::<i32, ()>();
|
||||
|
||||
@@ -706,12 +706,15 @@ impl Expiration {
|
||||
}
|
||||
|
||||
#[cfg(feature = "runtime")]
|
||||
#[pin_project::pin_project]
|
||||
struct IdleTask<T> {
|
||||
#[pin]
|
||||
interval: Interval,
|
||||
pool: WeakOpt<Mutex<PoolInner<T>>>,
|
||||
// This allows the IdleTask to be notified as soon as the entire
|
||||
// Pool is fully dropped, and shutdown. This channel is never sent on,
|
||||
// but Err(Canceled) will be received when the Pool is dropped.
|
||||
#[pin]
|
||||
pool_drop_notifier: oneshot::Receiver<crate::common::Never>,
|
||||
}
|
||||
|
||||
@@ -719,9 +722,11 @@ struct IdleTask<T> {
|
||||
impl<T: Poolable + 'static> Future for IdleTask<T> {
|
||||
type Output = ();
|
||||
|
||||
fn poll(mut self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll<Self::Output> {
|
||||
fn poll(self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll<Self::Output> {
|
||||
use tokio::stream::Stream;
|
||||
let mut this = self.project();
|
||||
loop {
|
||||
match Pin::new(&mut self.pool_drop_notifier).poll(cx) {
|
||||
match this.pool_drop_notifier.as_mut().poll(cx) {
|
||||
Poll::Ready(Ok(n)) => match n {},
|
||||
Poll::Pending => (),
|
||||
Poll::Ready(Err(_canceled)) => {
|
||||
@@ -730,9 +735,9 @@ impl<T: Poolable + 'static> Future for IdleTask<T> {
|
||||
}
|
||||
}
|
||||
|
||||
ready!(self.interval.poll_tick(cx));
|
||||
ready!(this.interval.as_mut().poll_next(cx));
|
||||
|
||||
if let Some(inner) = self.pool.upgrade() {
|
||||
if let Some(inner) = this.pool.upgrade() {
|
||||
if let Ok(mut inner) = inner.lock() {
|
||||
trace!("idle interval checking for expired");
|
||||
inner.clear_expired();
|
||||
@@ -850,7 +855,7 @@ mod tests {
|
||||
let pooled = pool.pooled(c(key.clone()), Uniq(41));
|
||||
|
||||
drop(pooled);
|
||||
tokio::time::delay_for(pool.locked().timeout.unwrap()).await;
|
||||
tokio::time::sleep(pool.locked().timeout.unwrap()).await;
|
||||
let mut checkout = pool.checkout(key);
|
||||
let poll_once = PollOnce(&mut checkout);
|
||||
let is_not_ready = poll_once.await.is_none();
|
||||
@@ -871,7 +876,7 @@ mod tests {
|
||||
pool.locked().idle.get(&key).map(|entries| entries.len()),
|
||||
Some(3)
|
||||
);
|
||||
tokio::time::delay_for(pool.locked().timeout.unwrap()).await;
|
||||
tokio::time::sleep(pool.locked().timeout.unwrap()).await;
|
||||
|
||||
let mut checkout = pool.checkout(key.clone());
|
||||
let poll_once = PollOnce(&mut checkout);
|
||||
|
||||
@@ -1,20 +1,13 @@
|
||||
use std::mem;
|
||||
|
||||
use pin_project::pin_project;
|
||||
use tokio::stream::Stream;
|
||||
use tokio::sync::{mpsc, watch};
|
||||
|
||||
use super::{task, Future, Never, Pin, Poll};
|
||||
|
||||
// Sentinel value signaling that the watch is still open
|
||||
#[derive(Clone, Copy)]
|
||||
enum Action {
|
||||
Open,
|
||||
// Closed isn't sent via the `Action` type, but rather once
|
||||
// the watch::Sender is dropped.
|
||||
}
|
||||
|
||||
pub fn channel() -> (Signal, Watch) {
|
||||
let (tx, rx) = watch::channel(Action::Open);
|
||||
let (tx, rx) = watch::channel(());
|
||||
let (drained_tx, drained_rx) = mpsc::channel(1);
|
||||
(
|
||||
Signal {
|
||||
@@ -27,17 +20,19 @@ pub fn channel() -> (Signal, Watch) {
|
||||
|
||||
pub struct Signal {
|
||||
drained_rx: mpsc::Receiver<Never>,
|
||||
_tx: watch::Sender<Action>,
|
||||
_tx: watch::Sender<()>,
|
||||
}
|
||||
|
||||
#[pin_project::pin_project]
|
||||
pub struct Draining {
|
||||
#[pin]
|
||||
drained_rx: mpsc::Receiver<Never>,
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct Watch {
|
||||
drained_tx: mpsc::Sender<Never>,
|
||||
rx: watch::Receiver<Action>,
|
||||
rx: watch::Receiver<()>,
|
||||
}
|
||||
|
||||
#[allow(missing_debug_implementations)]
|
||||
@@ -46,7 +41,8 @@ pub struct Watching<F, FN> {
|
||||
#[pin]
|
||||
future: F,
|
||||
state: State<FN>,
|
||||
watch: Watch,
|
||||
watch: Pin<Box<dyn Future<Output = ()> + Send + Sync>>,
|
||||
_drained_tx: mpsc::Sender<Never>,
|
||||
}
|
||||
|
||||
enum State<F> {
|
||||
@@ -66,8 +62,8 @@ impl Signal {
|
||||
impl Future for Draining {
|
||||
type Output = ();
|
||||
|
||||
fn poll(mut self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll<Self::Output> {
|
||||
match ready!(self.drained_rx.poll_recv(cx)) {
|
||||
fn poll(self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll<Self::Output> {
|
||||
match ready!(self.project().drained_rx.poll_next(cx)) {
|
||||
Some(never) => match never {},
|
||||
None => Poll::Ready(()),
|
||||
}
|
||||
@@ -80,10 +76,14 @@ impl Watch {
|
||||
F: Future,
|
||||
FN: FnOnce(Pin<&mut F>),
|
||||
{
|
||||
let Self { drained_tx, mut rx } = self;
|
||||
Watching {
|
||||
future,
|
||||
state: State::Watch(on_drain),
|
||||
watch: self,
|
||||
watch: Box::pin(async move {
|
||||
let _ = rx.changed().await;
|
||||
}),
|
||||
_drained_tx: drained_tx,
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -100,12 +100,12 @@ where
|
||||
loop {
|
||||
match mem::replace(me.state, State::Draining) {
|
||||
State::Watch(on_drain) => {
|
||||
match me.watch.rx.poll_recv_ref(cx) {
|
||||
Poll::Ready(None) => {
|
||||
match Pin::new(&mut me.watch).poll(cx) {
|
||||
Poll::Ready(()) => {
|
||||
// Drain has been triggered!
|
||||
on_drain(me.future.as_mut());
|
||||
}
|
||||
Poll::Ready(Some(_ /*State::Open*/)) | Poll::Pending => {
|
||||
Poll::Pending => {
|
||||
*me.state = State::Watch(on_drain);
|
||||
return me.future.poll(cx);
|
||||
}
|
||||
|
||||
@@ -2,7 +2,7 @@ use std::marker::Unpin;
|
||||
use std::{cmp, io};
|
||||
|
||||
use bytes::{Buf, Bytes};
|
||||
use tokio::io::{AsyncRead, AsyncWrite};
|
||||
use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
|
||||
|
||||
use crate::common::{task, Pin, Poll};
|
||||
|
||||
@@ -37,36 +37,33 @@ impl<T> Rewind<T> {
|
||||
(self.inner, self.pre.unwrap_or_else(Bytes::new))
|
||||
}
|
||||
|
||||
pub(crate) fn get_mut(&mut self) -> &mut T {
|
||||
&mut self.inner
|
||||
}
|
||||
// pub(crate) fn get_mut(&mut self) -> &mut T {
|
||||
// &mut self.inner
|
||||
// }
|
||||
}
|
||||
|
||||
impl<T> AsyncRead for Rewind<T>
|
||||
where
|
||||
T: AsyncRead + Unpin,
|
||||
{
|
||||
#[inline]
|
||||
unsafe fn prepare_uninitialized_buffer(&self, buf: &mut [std::mem::MaybeUninit<u8>]) -> bool {
|
||||
self.inner.prepare_uninitialized_buffer(buf)
|
||||
}
|
||||
|
||||
fn poll_read(
|
||||
mut self: Pin<&mut Self>,
|
||||
cx: &mut task::Context<'_>,
|
||||
buf: &mut [u8],
|
||||
) -> Poll<io::Result<usize>> {
|
||||
buf: &mut ReadBuf<'_>,
|
||||
) -> Poll<io::Result<()>> {
|
||||
if let Some(mut prefix) = self.pre.take() {
|
||||
// If there are no remaining bytes, let the bytes get dropped.
|
||||
if !prefix.is_empty() {
|
||||
let copy_len = cmp::min(prefix.len(), buf.len());
|
||||
prefix.copy_to_slice(&mut buf[..copy_len]);
|
||||
let copy_len = cmp::min(prefix.len(), buf.remaining());
|
||||
// TODO: There should be a way to do following two lines cleaner...
|
||||
buf.put_slice(&prefix[..copy_len]);
|
||||
prefix.advance(copy_len);
|
||||
// Put back whats left
|
||||
if !prefix.is_empty() {
|
||||
self.pre = Some(prefix);
|
||||
}
|
||||
|
||||
return Poll::Ready(Ok(copy_len));
|
||||
return Poll::Ready(Ok(()));
|
||||
}
|
||||
}
|
||||
Pin::new(&mut self.inner).poll_read(cx, buf)
|
||||
@@ -92,15 +89,6 @@ where
|
||||
fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll<io::Result<()>> {
|
||||
Pin::new(&mut self.inner).poll_shutdown(cx)
|
||||
}
|
||||
|
||||
#[inline]
|
||||
fn poll_write_buf<B: Buf>(
|
||||
mut self: Pin<&mut Self>,
|
||||
cx: &mut task::Context<'_>,
|
||||
buf: &mut B,
|
||||
) -> Poll<io::Result<usize>> {
|
||||
Pin::new(&mut self.inner).poll_write_buf(cx, buf)
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
|
||||
@@ -967,9 +967,8 @@ mod tests {
|
||||
*conn.io.read_buf_mut() = ::bytes::BytesMut::from(&s[..]);
|
||||
conn.state.cached_headers = Some(HeaderMap::with_capacity(2));
|
||||
|
||||
let mut rt = tokio::runtime::Builder::new()
|
||||
let rt = tokio::runtime::Builder::new_current_thread()
|
||||
.enable_all()
|
||||
.basic_scheduler()
|
||||
.build()
|
||||
.unwrap();
|
||||
|
||||
|
||||
@@ -382,7 +382,7 @@ mod tests {
|
||||
use super::*;
|
||||
use std::pin::Pin;
|
||||
use std::time::Duration;
|
||||
use tokio::io::AsyncRead;
|
||||
use tokio::io::{AsyncRead, ReadBuf};
|
||||
|
||||
impl<'a> MemRead for &'a [u8] {
|
||||
fn read_mem(&mut self, _: &mut task::Context<'_>, len: usize) -> Poll<io::Result<Bytes>> {
|
||||
@@ -401,8 +401,9 @@ mod tests {
|
||||
impl<'a> MemRead for &'a mut (dyn AsyncRead + Unpin) {
|
||||
fn read_mem(&mut self, cx: &mut task::Context<'_>, len: usize) -> Poll<io::Result<Bytes>> {
|
||||
let mut v = vec![0; len];
|
||||
let n = ready!(Pin::new(self).poll_read(cx, &mut v)?);
|
||||
Poll::Ready(Ok(Bytes::copy_from_slice(&v[..n])))
|
||||
let mut buf = ReadBuf::new(&mut v);
|
||||
ready!(Pin::new(self).poll_read(cx, &mut buf)?);
|
||||
Poll::Ready(Ok(Bytes::copy_from_slice(&buf.filled())))
|
||||
}
|
||||
}
|
||||
|
||||
@@ -623,7 +624,7 @@ mod tests {
|
||||
#[cfg(feature = "nightly")]
|
||||
#[bench]
|
||||
fn bench_decode_chunked_1kb(b: &mut test::Bencher) {
|
||||
let mut rt = new_runtime();
|
||||
let rt = new_runtime();
|
||||
|
||||
const LEN: usize = 1024;
|
||||
let mut vec = Vec::new();
|
||||
@@ -647,7 +648,7 @@ mod tests {
|
||||
#[cfg(feature = "nightly")]
|
||||
#[bench]
|
||||
fn bench_decode_length_1kb(b: &mut test::Bencher) {
|
||||
let mut rt = new_runtime();
|
||||
let rt = new_runtime();
|
||||
|
||||
const LEN: usize = 1024;
|
||||
let content = Bytes::from(&[0; LEN][..]);
|
||||
@@ -665,9 +666,8 @@ mod tests {
|
||||
|
||||
#[cfg(feature = "nightly")]
|
||||
fn new_runtime() -> tokio::runtime::Runtime {
|
||||
tokio::runtime::Builder::new()
|
||||
tokio::runtime::Builder::new_current_thread()
|
||||
.enable_all()
|
||||
.basic_scheduler()
|
||||
.build()
|
||||
.expect("rt build")
|
||||
}
|
||||
|
||||
@@ -27,7 +27,7 @@ pub(crate) trait Dispatch {
|
||||
type PollError;
|
||||
type RecvItem;
|
||||
fn poll_msg(
|
||||
&mut self,
|
||||
self: Pin<&mut Self>,
|
||||
cx: &mut task::Context<'_>,
|
||||
) -> Poll<Option<Result<(Self::PollItem, Self::PollBody), Self::PollError>>>;
|
||||
fn recv_msg(&mut self, msg: crate::Result<(Self::RecvItem, Body)>) -> crate::Result<()>;
|
||||
@@ -40,8 +40,10 @@ pub struct Server<S: HttpService<B>, B> {
|
||||
pub(crate) service: S,
|
||||
}
|
||||
|
||||
#[pin_project::pin_project]
|
||||
pub struct Client<B> {
|
||||
callback: Option<crate::client::dispatch::Callback<Request<B>, Response<Body>>>,
|
||||
#[pin]
|
||||
rx: ClientRx<B>,
|
||||
rx_closed: bool,
|
||||
}
|
||||
@@ -281,7 +283,7 @@ where
|
||||
&& self.conn.can_write_head()
|
||||
&& self.dispatch.should_poll()
|
||||
{
|
||||
if let Some(msg) = ready!(self.dispatch.poll_msg(cx)) {
|
||||
if let Some(msg) = ready!(Pin::new(&mut self.dispatch).poll_msg(cx)) {
|
||||
let (head, mut body) = msg.map_err(crate::Error::new_user_service)?;
|
||||
|
||||
// Check if the body knows its full data immediately.
|
||||
@@ -469,10 +471,11 @@ where
|
||||
type RecvItem = RequestHead;
|
||||
|
||||
fn poll_msg(
|
||||
&mut self,
|
||||
mut self: Pin<&mut Self>,
|
||||
cx: &mut task::Context<'_>,
|
||||
) -> Poll<Option<Result<(Self::PollItem, Self::PollBody), Self::PollError>>> {
|
||||
let ret = if let Some(ref mut fut) = self.in_flight.as_mut().as_pin_mut() {
|
||||
let mut this = self.as_mut();
|
||||
let ret = if let Some(ref mut fut) = this.in_flight.as_mut().as_pin_mut() {
|
||||
let resp = ready!(fut.as_mut().poll(cx)?);
|
||||
let (parts, body) = resp.into_parts();
|
||||
let head = MessageHead {
|
||||
@@ -486,7 +489,7 @@ where
|
||||
};
|
||||
|
||||
// Since in_flight finished, remove it
|
||||
self.in_flight.set(None);
|
||||
this.in_flight.set(None);
|
||||
ret
|
||||
}
|
||||
|
||||
@@ -540,11 +543,12 @@ where
|
||||
type RecvItem = ResponseHead;
|
||||
|
||||
fn poll_msg(
|
||||
&mut self,
|
||||
self: Pin<&mut Self>,
|
||||
cx: &mut task::Context<'_>,
|
||||
) -> Poll<Option<Result<(Self::PollItem, Self::PollBody), Never>>> {
|
||||
debug_assert!(!self.rx_closed);
|
||||
match self.rx.poll_next(cx) {
|
||||
let this = self.project();
|
||||
debug_assert!(!*this.rx_closed);
|
||||
match this.rx.poll_next(cx) {
|
||||
Poll::Ready(Some((req, mut cb))) => {
|
||||
// check that future hasn't been canceled already
|
||||
match cb.poll_canceled(cx) {
|
||||
@@ -559,7 +563,7 @@ where
|
||||
subject: RequestLine(parts.method, parts.uri),
|
||||
headers: parts.headers,
|
||||
};
|
||||
self.callback = Some(cb);
|
||||
*this.callback = Some(cb);
|
||||
Poll::Ready(Some(Ok((head, body))))
|
||||
}
|
||||
}
|
||||
@@ -567,7 +571,7 @@ where
|
||||
Poll::Ready(None) => {
|
||||
// user has dropped sender handle
|
||||
trace!("client tx closed");
|
||||
self.rx_closed = true;
|
||||
*this.rx_closed = true;
|
||||
Poll::Ready(None)
|
||||
}
|
||||
Poll::Pending => Poll::Pending,
|
||||
|
||||
@@ -4,7 +4,7 @@ use std::fmt;
|
||||
use std::io::{self, IoSlice};
|
||||
|
||||
use bytes::{Buf, BufMut, Bytes, BytesMut};
|
||||
use tokio::io::{AsyncRead, AsyncWrite};
|
||||
use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
|
||||
|
||||
use super::{Http1Transaction, ParseContext, ParsedMessage};
|
||||
use crate::common::buf::BufList;
|
||||
@@ -188,9 +188,16 @@ where
|
||||
if self.read_buf_remaining_mut() < next {
|
||||
self.read_buf.reserve(next);
|
||||
}
|
||||
match Pin::new(&mut self.io).poll_read_buf(cx, &mut self.read_buf) {
|
||||
Poll::Ready(Ok(n)) => {
|
||||
debug!("read {} bytes", n);
|
||||
let mut buf = ReadBuf::uninit(&mut self.read_buf.bytes_mut()[..]);
|
||||
match Pin::new(&mut self.io).poll_read(cx, &mut buf) {
|
||||
Poll::Ready(Ok(_)) => {
|
||||
let n = buf.filled().len();
|
||||
unsafe {
|
||||
// Safety: we just read that many bytes into the
|
||||
// uninitialized part of the buffer, so this is okay.
|
||||
// @tokio pls give me back `poll_read_buf` thanks
|
||||
self.read_buf.advance_mut(n);
|
||||
}
|
||||
self.read_buf_strategy.record(n);
|
||||
Poll::Ready(Ok(n))
|
||||
}
|
||||
@@ -224,8 +231,16 @@ where
|
||||
return self.poll_flush_flattened(cx);
|
||||
}
|
||||
loop {
|
||||
let n =
|
||||
ready!(Pin::new(&mut self.io).poll_write_buf(cx, &mut self.write_buf.auto()))?;
|
||||
// TODO(eliza): this basically ignores all of `WriteBuf`...put
|
||||
// back vectored IO and `poll_write_buf` when the appropriate Tokio
|
||||
// changes land...
|
||||
let n = ready!(Pin::new(&mut self.io)
|
||||
// .poll_write_buf(cx, &mut self.write_buf.auto()))?;
|
||||
.poll_write(cx, self.write_buf.auto().bytes()))?;
|
||||
// TODO(eliza): we have to do this manually because
|
||||
// `poll_write_buf` doesn't exist in Tokio 0.3 yet...when
|
||||
// `poll_write_buf` comes back, the manual advance will need to leave!
|
||||
self.write_buf.advance(n);
|
||||
debug!("flushed {} bytes", n);
|
||||
if self.write_buf.remaining() == 0 {
|
||||
break;
|
||||
@@ -452,6 +467,7 @@ where
|
||||
self.strategy = strategy;
|
||||
}
|
||||
|
||||
// TODO(eliza): put back writev!
|
||||
#[inline]
|
||||
fn auto(&mut self) -> WriteBufAuto<'_, B> {
|
||||
WriteBufAuto::new(self)
|
||||
@@ -628,28 +644,31 @@ mod tests {
|
||||
*/
|
||||
|
||||
#[tokio::test]
|
||||
#[ignore]
|
||||
async fn iobuf_write_empty_slice() {
|
||||
// First, let's just check that the Mock would normally return an
|
||||
// error on an unexpected write, even if the buffer is empty...
|
||||
let mut mock = Mock::new().build();
|
||||
futures_util::future::poll_fn(|cx| {
|
||||
Pin::new(&mut mock).poll_write_buf(cx, &mut Cursor::new(&[]))
|
||||
})
|
||||
.await
|
||||
.expect_err("should be a broken pipe");
|
||||
// TODO(eliza): can i have writev back pls T_T
|
||||
// // First, let's just check that the Mock would normally return an
|
||||
// // error on an unexpected write, even if the buffer is empty...
|
||||
// let mut mock = Mock::new().build();
|
||||
// futures_util::future::poll_fn(|cx| {
|
||||
// Pin::new(&mut mock).poll_write_buf(cx, &mut Cursor::new(&[]))
|
||||
// })
|
||||
// .await
|
||||
// .expect_err("should be a broken pipe");
|
||||
|
||||
// underlying io will return the logic error upon write,
|
||||
// so we are testing that the io_buf does not trigger a write
|
||||
// when there is nothing to flush
|
||||
let mock = Mock::new().build();
|
||||
let mut io_buf = Buffered::<_, Cursor<Vec<u8>>>::new(mock);
|
||||
io_buf.flush().await.expect("should short-circuit flush");
|
||||
// // underlying io will return the logic error upon write,
|
||||
// // so we are testing that the io_buf does not trigger a write
|
||||
// // when there is nothing to flush
|
||||
// let mock = Mock::new().build();
|
||||
// let mut io_buf = Buffered::<_, Cursor<Vec<u8>>>::new(mock);
|
||||
// io_buf.flush().await.expect("should short-circuit flush");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn parse_reads_until_blocked() {
|
||||
use crate::proto::h1::ClientTransaction;
|
||||
|
||||
let _ = pretty_env_logger::try_init();
|
||||
let mock = Mock::new()
|
||||
// Split over multiple reads will read all of it
|
||||
.read(b"HTTP/1.1 200 OK\r\n")
|
||||
|
||||
@@ -33,7 +33,7 @@ use std::time::Instant;
|
||||
|
||||
use h2::{Ping, PingPong};
|
||||
#[cfg(feature = "runtime")]
|
||||
use tokio::time::{Delay, Instant};
|
||||
use tokio::time::{Instant, Sleep};
|
||||
|
||||
type WindowSize = u32;
|
||||
|
||||
@@ -60,7 +60,7 @@ pub(super) fn channel(ping_pong: PingPong, config: Config) -> (Recorder, Ponger)
|
||||
interval,
|
||||
timeout: config.keep_alive_timeout,
|
||||
while_idle: config.keep_alive_while_idle,
|
||||
timer: tokio::time::delay_for(interval),
|
||||
timer: tokio::time::sleep(interval),
|
||||
state: KeepAliveState::Init,
|
||||
});
|
||||
|
||||
@@ -156,7 +156,7 @@ struct KeepAlive {
|
||||
while_idle: bool,
|
||||
|
||||
state: KeepAliveState,
|
||||
timer: Delay,
|
||||
timer: Sleep,
|
||||
}
|
||||
|
||||
#[cfg(feature = "runtime")]
|
||||
|
||||
@@ -809,9 +809,9 @@ where
|
||||
type Output = Result<Connection<I, S, E>, FE>;
|
||||
|
||||
fn poll(self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll<Self::Output> {
|
||||
let me = self.project();
|
||||
let mut me = self.project();
|
||||
let service = ready!(me.future.poll(cx))?;
|
||||
let io = me.io.take().expect("polled after complete");
|
||||
let io = Option::take(&mut me.io).expect("polled after complete");
|
||||
Poll::Ready(Ok(me.protocol.serve_connection(io, service)))
|
||||
}
|
||||
}
|
||||
|
||||
@@ -4,7 +4,7 @@ use std::net::{SocketAddr, TcpListener as StdTcpListener};
|
||||
use std::time::Duration;
|
||||
|
||||
use tokio::net::TcpListener;
|
||||
use tokio::time::Delay;
|
||||
use tokio::time::Sleep;
|
||||
|
||||
use crate::common::{task, Future, Pin, Poll};
|
||||
|
||||
@@ -19,7 +19,7 @@ pub struct AddrIncoming {
|
||||
sleep_on_errors: bool,
|
||||
tcp_keepalive_timeout: Option<Duration>,
|
||||
tcp_nodelay: bool,
|
||||
timeout: Option<Delay>,
|
||||
timeout: Option<Sleep>,
|
||||
}
|
||||
|
||||
impl AddrIncoming {
|
||||
@@ -30,6 +30,10 @@ impl AddrIncoming {
|
||||
}
|
||||
|
||||
pub(super) fn from_std(std_listener: StdTcpListener) -> crate::Result<Self> {
|
||||
// TcpListener::from_std doesn't set O_NONBLOCK
|
||||
std_listener
|
||||
.set_nonblocking(true)
|
||||
.map_err(crate::Error::new_listen)?;
|
||||
let listener = TcpListener::from_std(std_listener).map_err(crate::Error::new_listen)?;
|
||||
let addr = listener.local_addr().map_err(crate::Error::new_listen)?;
|
||||
Ok(AddrIncoming {
|
||||
@@ -98,9 +102,46 @@ impl AddrIncoming {
|
||||
match ready!(self.listener.poll_accept(cx)) {
|
||||
Ok((socket, addr)) => {
|
||||
if let Some(dur) = self.tcp_keepalive_timeout {
|
||||
// Convert the Tokio `TcpStream` into a `socket2` socket
|
||||
// so we can call `set_keepalive`.
|
||||
// TODO(eliza): if Tokio's `TcpSocket` API grows a few
|
||||
// more methods in the future, hopefully we shouldn't
|
||||
// have to do the `from_raw_fd` dance any longer...
|
||||
#[cfg(unix)]
|
||||
let socket = unsafe {
|
||||
// Safety: `socket2`'s socket will try to close the
|
||||
// underlying fd when it's dropped. However, we
|
||||
// can't take ownership of the fd from the tokio
|
||||
// TcpStream, so instead we will call `into_raw_fd`
|
||||
// on the socket2 socket before dropping it. This
|
||||
// prevents it from trying to close the fd.
|
||||
use std::os::unix::io::{AsRawFd, FromRawFd};
|
||||
socket2::Socket::from_raw_fd(socket.as_raw_fd())
|
||||
};
|
||||
#[cfg(windows)]
|
||||
let socket = unsafe {
|
||||
// Safety: `socket2`'s socket will try to close the
|
||||
// underlying SOCKET when it's dropped. However, we
|
||||
// can't take ownership of the SOCKET from the tokio
|
||||
// TcpStream, so instead we will call `into_raw_socket`
|
||||
// on the socket2 socket before dropping it. This
|
||||
// prevents it from trying to close the SOCKET.
|
||||
use std::os::windows::io::{AsRawSocket, FromRawSocket};
|
||||
socket2::Socket::from_raw_socket(socket.as_raw_socket())
|
||||
};
|
||||
|
||||
// Actually set the TCP keepalive timeout.
|
||||
if let Err(e) = socket.set_keepalive(Some(dur)) {
|
||||
trace!("error trying to set TCP keepalive: {}", e);
|
||||
}
|
||||
|
||||
// Take ownershop of the fd/socket back from the socket2
|
||||
// `Socket`, so that socket2 doesn't try to close it
|
||||
// when it's dropped.
|
||||
#[cfg(unix)]
|
||||
drop(std::os::unix::io::IntoRawFd::into_raw_fd(socket));
|
||||
#[cfg(windows)]
|
||||
drop(std::os::windows::io::IntoRawSocket::into_raw_socket(socket));
|
||||
}
|
||||
if let Err(e) = socket.set_nodelay(self.tcp_nodelay) {
|
||||
trace!("error trying to set TCP nodelay: {}", e);
|
||||
@@ -119,7 +160,7 @@ impl AddrIncoming {
|
||||
error!("accept error: {}", e);
|
||||
|
||||
// Sleep 1s.
|
||||
let mut timeout = tokio::time::delay_for(Duration::from_secs(1));
|
||||
let mut timeout = tokio::time::sleep(Duration::from_secs(1));
|
||||
|
||||
match Pin::new(&mut timeout).poll(cx) {
|
||||
Poll::Ready(()) => {
|
||||
@@ -181,19 +222,20 @@ impl fmt::Debug for AddrIncoming {
|
||||
}
|
||||
|
||||
mod addr_stream {
|
||||
use bytes::{Buf, BufMut};
|
||||
use std::io;
|
||||
use std::net::SocketAddr;
|
||||
#[cfg(unix)]
|
||||
use std::os::unix::io::{AsRawFd, RawFd};
|
||||
use tokio::io::{AsyncRead, AsyncWrite};
|
||||
use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
|
||||
use tokio::net::TcpStream;
|
||||
|
||||
use crate::common::{task, Pin, Poll};
|
||||
|
||||
/// A transport returned yieled by `AddrIncoming`.
|
||||
#[pin_project::pin_project]
|
||||
#[derive(Debug)]
|
||||
pub struct AddrStream {
|
||||
#[pin]
|
||||
inner: TcpStream,
|
||||
pub(super) remote_addr: SocketAddr,
|
||||
}
|
||||
@@ -231,49 +273,24 @@ mod addr_stream {
|
||||
}
|
||||
|
||||
impl AsyncRead for AddrStream {
|
||||
unsafe fn prepare_uninitialized_buffer(
|
||||
&self,
|
||||
buf: &mut [std::mem::MaybeUninit<u8>],
|
||||
) -> bool {
|
||||
self.inner.prepare_uninitialized_buffer(buf)
|
||||
}
|
||||
|
||||
#[inline]
|
||||
fn poll_read(
|
||||
mut self: Pin<&mut Self>,
|
||||
self: Pin<&mut Self>,
|
||||
cx: &mut task::Context<'_>,
|
||||
buf: &mut [u8],
|
||||
) -> Poll<io::Result<usize>> {
|
||||
Pin::new(&mut self.inner).poll_read(cx, buf)
|
||||
}
|
||||
|
||||
#[inline]
|
||||
fn poll_read_buf<B: BufMut>(
|
||||
mut self: Pin<&mut Self>,
|
||||
cx: &mut task::Context<'_>,
|
||||
buf: &mut B,
|
||||
) -> Poll<io::Result<usize>> {
|
||||
Pin::new(&mut self.inner).poll_read_buf(cx, buf)
|
||||
buf: &mut ReadBuf<'_>,
|
||||
) -> Poll<io::Result<()>> {
|
||||
self.project().inner.poll_read(cx, buf)
|
||||
}
|
||||
}
|
||||
|
||||
impl AsyncWrite for AddrStream {
|
||||
#[inline]
|
||||
fn poll_write(
|
||||
mut self: Pin<&mut Self>,
|
||||
self: Pin<&mut Self>,
|
||||
cx: &mut task::Context<'_>,
|
||||
buf: &[u8],
|
||||
) -> Poll<io::Result<usize>> {
|
||||
Pin::new(&mut self.inner).poll_write(cx, buf)
|
||||
}
|
||||
|
||||
#[inline]
|
||||
fn poll_write_buf<B: Buf>(
|
||||
mut self: Pin<&mut Self>,
|
||||
cx: &mut task::Context<'_>,
|
||||
buf: &mut B,
|
||||
) -> Poll<io::Result<usize>> {
|
||||
Pin::new(&mut self.inner).poll_write_buf(cx, buf)
|
||||
self.project().inner.poll_write(cx, buf)
|
||||
}
|
||||
|
||||
#[inline]
|
||||
@@ -283,11 +300,8 @@ mod addr_stream {
|
||||
}
|
||||
|
||||
#[inline]
|
||||
fn poll_shutdown(
|
||||
mut self: Pin<&mut Self>,
|
||||
cx: &mut task::Context<'_>,
|
||||
) -> Poll<io::Result<()>> {
|
||||
Pin::new(&mut self.inner).poll_shutdown(cx)
|
||||
fn poll_shutdown(self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll<io::Result<()>> {
|
||||
self.project().inner.poll_shutdown(cx)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -12,7 +12,7 @@ use std::io;
|
||||
use std::marker::Unpin;
|
||||
|
||||
use bytes::{Buf, Bytes};
|
||||
use tokio::io::{AsyncRead, AsyncWrite};
|
||||
use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
|
||||
use tokio::sync::oneshot;
|
||||
|
||||
use crate::common::io::Rewind;
|
||||
@@ -105,15 +105,11 @@ impl Upgraded {
|
||||
}
|
||||
|
||||
impl AsyncRead for Upgraded {
|
||||
unsafe fn prepare_uninitialized_buffer(&self, buf: &mut [std::mem::MaybeUninit<u8>]) -> bool {
|
||||
self.io.prepare_uninitialized_buffer(buf)
|
||||
}
|
||||
|
||||
fn poll_read(
|
||||
mut self: Pin<&mut Self>,
|
||||
cx: &mut task::Context<'_>,
|
||||
buf: &mut [u8],
|
||||
) -> Poll<io::Result<usize>> {
|
||||
buf: &mut ReadBuf<'_>,
|
||||
) -> Poll<io::Result<()>> {
|
||||
Pin::new(&mut self.io).poll_read(cx, buf)
|
||||
}
|
||||
}
|
||||
@@ -127,14 +123,6 @@ impl AsyncWrite for Upgraded {
|
||||
Pin::new(&mut self.io).poll_write(cx, buf)
|
||||
}
|
||||
|
||||
fn poll_write_buf<B: Buf>(
|
||||
mut self: Pin<&mut Self>,
|
||||
cx: &mut task::Context<'_>,
|
||||
buf: &mut B,
|
||||
) -> Poll<io::Result<usize>> {
|
||||
Pin::new(self.io.get_mut()).poll_write_dyn_buf(cx, buf)
|
||||
}
|
||||
|
||||
fn poll_flush(mut self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll<io::Result<()>> {
|
||||
Pin::new(&mut self.io).poll_flush(cx)
|
||||
}
|
||||
@@ -247,15 +235,11 @@ impl dyn Io + Send {
|
||||
}
|
||||
|
||||
impl<T: AsyncRead + Unpin> AsyncRead for ForwardsWriteBuf<T> {
|
||||
unsafe fn prepare_uninitialized_buffer(&self, buf: &mut [std::mem::MaybeUninit<u8>]) -> bool {
|
||||
self.0.prepare_uninitialized_buffer(buf)
|
||||
}
|
||||
|
||||
fn poll_read(
|
||||
mut self: Pin<&mut Self>,
|
||||
cx: &mut task::Context<'_>,
|
||||
buf: &mut [u8],
|
||||
) -> Poll<io::Result<usize>> {
|
||||
buf: &mut ReadBuf<'_>,
|
||||
) -> Poll<io::Result<()>> {
|
||||
Pin::new(&mut self.0).poll_read(cx, buf)
|
||||
}
|
||||
}
|
||||
@@ -269,14 +253,6 @@ impl<T: AsyncWrite + Unpin> AsyncWrite for ForwardsWriteBuf<T> {
|
||||
Pin::new(&mut self.0).poll_write(cx, buf)
|
||||
}
|
||||
|
||||
fn poll_write_buf<B: Buf>(
|
||||
mut self: Pin<&mut Self>,
|
||||
cx: &mut task::Context<'_>,
|
||||
buf: &mut B,
|
||||
) -> Poll<io::Result<usize>> {
|
||||
Pin::new(&mut self.0).poll_write_buf(cx, buf)
|
||||
}
|
||||
|
||||
fn poll_flush(mut self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll<io::Result<()>> {
|
||||
Pin::new(&mut self.0).poll_flush(cx)
|
||||
}
|
||||
@@ -290,9 +266,9 @@ impl<T: AsyncRead + AsyncWrite + Unpin + 'static> Io for ForwardsWriteBuf<T> {
|
||||
fn poll_write_dyn_buf(
|
||||
&mut self,
|
||||
cx: &mut task::Context<'_>,
|
||||
mut buf: &mut dyn Buf,
|
||||
buf: &mut dyn Buf,
|
||||
) -> Poll<io::Result<usize>> {
|
||||
Pin::new(&mut self.0).poll_write_buf(cx, &mut buf)
|
||||
Pin::new(&mut self.0).poll_write(cx, buf.bytes())
|
||||
}
|
||||
}
|
||||
|
||||
@@ -326,8 +302,8 @@ mod tests {
|
||||
fn poll_read(
|
||||
self: Pin<&mut Self>,
|
||||
_cx: &mut task::Context<'_>,
|
||||
_buf: &mut [u8],
|
||||
) -> Poll<io::Result<usize>> {
|
||||
_buf: &mut ReadBuf<'_>,
|
||||
) -> Poll<io::Result<()>> {
|
||||
unreachable!("Mock::poll_read")
|
||||
}
|
||||
}
|
||||
@@ -335,21 +311,23 @@ mod tests {
|
||||
impl AsyncWrite for Mock {
|
||||
fn poll_write(
|
||||
self: Pin<&mut Self>,
|
||||
_cx: &mut task::Context<'_>,
|
||||
_buf: &[u8],
|
||||
_: &mut task::Context<'_>,
|
||||
buf: &[u8],
|
||||
) -> Poll<io::Result<usize>> {
|
||||
panic!("poll_write shouldn't be called");
|
||||
// panic!("poll_write shouldn't be called");
|
||||
Poll::Ready(Ok(buf.len()))
|
||||
}
|
||||
|
||||
fn poll_write_buf<B: Buf>(
|
||||
self: Pin<&mut Self>,
|
||||
_cx: &mut task::Context<'_>,
|
||||
buf: &mut B,
|
||||
) -> Poll<io::Result<usize>> {
|
||||
let n = buf.remaining();
|
||||
buf.advance(n);
|
||||
Poll::Ready(Ok(n))
|
||||
}
|
||||
// TODO(eliza): :(
|
||||
// fn poll_write_buf<B: Buf>(
|
||||
// self: Pin<&mut Self>,
|
||||
// _cx: &mut task::Context<'_>,
|
||||
// buf: &mut B,
|
||||
// ) -> Poll<io::Result<usize>> {
|
||||
// let n = buf.remaining();
|
||||
// buf.advance(n);
|
||||
// Poll::Ready(Ok(n))
|
||||
// }
|
||||
|
||||
fn poll_flush(self: Pin<&mut Self>, _cx: &mut task::Context<'_>) -> Poll<io::Result<()>> {
|
||||
unreachable!("Mock::poll_flush")
|
||||
|
||||
Reference in New Issue
Block a user