Tokio 0.3 Upgrade (#2319)

Co-authored-by: Urhengulas <johann.hemmann@code.berlin>
Co-authored-by: Eliza Weisman <eliza@buoyant.io>
This commit is contained in:
Sean McArthur
2020-11-05 17:17:21 -08:00
committed by GitHub
parent cc7d3058e8
commit 1b9af22fa0
24 changed files with 467 additions and 472 deletions

View File

@@ -12,8 +12,8 @@ use std::time::Duration;
use futures_util::future::Either;
use http::uri::{Scheme, Uri};
use pin_project::pin_project;
use tokio::net::TcpStream;
use tokio::time::Delay;
use tokio::net::{TcpSocket, TcpStream};
use tokio::time::Sleep;
use super::dns::{self, resolve, GaiResolver, Resolve};
use super::{Connected, Connection};
@@ -331,34 +331,9 @@ where
dns::IpAddrs::new(addrs)
};
let c = ConnectingTcp::new(
config.local_address_ipv4,
config.local_address_ipv6,
addrs,
config.connect_timeout,
config.happy_eyeballs_timeout,
config.reuse_address,
);
let c = ConnectingTcp::new(addrs, config);
let sock = c
.connect()
.await
.map_err(ConnectError::m("tcp connect error"))?;
if let Some(dur) = config.keep_alive_timeout {
sock.set_keepalive(Some(dur))
.map_err(ConnectError::m("tcp set_keepalive error"))?;
}
if let Some(size) = config.send_buffer_size {
sock.set_send_buffer_size(size)
.map_err(ConnectError::m("tcp set_send_buffer_size error"))?;
}
if let Some(size) = config.recv_buffer_size {
sock.set_recv_buffer_size(size)
.map_err(ConnectError::m("tcp set_recv_buffer_size error"))?;
}
let sock = c.connect().await?;
sock.set_nodelay(config.nodelay)
.map_err(ConnectError::m("tcp set_nodelay error"))?;
@@ -475,60 +450,45 @@ impl StdError for ConnectError {
}
}
struct ConnectingTcp {
local_addr_ipv4: Option<Ipv4Addr>,
local_addr_ipv6: Option<Ipv6Addr>,
struct ConnectingTcp<'a> {
preferred: ConnectingTcpRemote,
fallback: Option<ConnectingTcpFallback>,
reuse_address: bool,
config: &'a Config,
}
impl ConnectingTcp {
fn new(
local_addr_ipv4: Option<Ipv4Addr>,
local_addr_ipv6: Option<Ipv6Addr>,
remote_addrs: dns::IpAddrs,
connect_timeout: Option<Duration>,
fallback_timeout: Option<Duration>,
reuse_address: bool,
) -> ConnectingTcp {
if let Some(fallback_timeout) = fallback_timeout {
let (preferred_addrs, fallback_addrs) =
remote_addrs.split_by_preference(local_addr_ipv4, local_addr_ipv6);
impl<'a> ConnectingTcp<'a> {
fn new(remote_addrs: dns::IpAddrs, config: &'a Config) -> Self {
if let Some(fallback_timeout) = config.happy_eyeballs_timeout {
let (preferred_addrs, fallback_addrs) = remote_addrs
.split_by_preference(config.local_address_ipv4, config.local_address_ipv6);
if fallback_addrs.is_empty() {
return ConnectingTcp {
local_addr_ipv4,
local_addr_ipv6,
preferred: ConnectingTcpRemote::new(preferred_addrs, connect_timeout),
preferred: ConnectingTcpRemote::new(preferred_addrs, config.connect_timeout),
fallback: None,
reuse_address,
config,
};
}
ConnectingTcp {
local_addr_ipv4,
local_addr_ipv6,
preferred: ConnectingTcpRemote::new(preferred_addrs, connect_timeout),
preferred: ConnectingTcpRemote::new(preferred_addrs, config.connect_timeout),
fallback: Some(ConnectingTcpFallback {
delay: tokio::time::delay_for(fallback_timeout),
remote: ConnectingTcpRemote::new(fallback_addrs, connect_timeout),
delay: tokio::time::sleep(fallback_timeout),
remote: ConnectingTcpRemote::new(fallback_addrs, config.connect_timeout),
}),
reuse_address,
config,
}
} else {
ConnectingTcp {
local_addr_ipv4,
local_addr_ipv6,
preferred: ConnectingTcpRemote::new(remote_addrs, connect_timeout),
preferred: ConnectingTcpRemote::new(remote_addrs, config.connect_timeout),
fallback: None,
reuse_address,
config,
}
}
}
}
struct ConnectingTcpFallback {
delay: Delay,
delay: Sleep,
remote: ConnectingTcpRemote,
}
@@ -549,24 +509,11 @@ impl ConnectingTcpRemote {
}
impl ConnectingTcpRemote {
async fn connect(
&mut self,
local_addr_ipv4: &Option<Ipv4Addr>,
local_addr_ipv6: &Option<Ipv6Addr>,
reuse_address: bool,
) -> io::Result<TcpStream> {
async fn connect(&mut self, config: &Config) -> Result<TcpStream, ConnectError> {
let mut err = None;
for addr in &mut self.addrs {
debug!("connecting to {}", addr);
match connect(
&addr,
local_addr_ipv4,
local_addr_ipv6,
reuse_address,
self.connect_timeout,
)?
.await
{
match connect(&addr, config, self.connect_timeout)?.await {
Ok(tcp) => {
debug!("connected to {}", addr);
return Ok(tcp);
@@ -580,9 +527,9 @@ impl ConnectingTcpRemote {
match err {
Some(e) => Err(e),
None => Err(std::io::Error::new(
std::io::ErrorKind::NotConnected,
"Network unreachable",
None => Err(ConnectError::new(
"tcp connect error",
std::io::Error::new(std::io::ErrorKind::NotConnected, "Network unreachable"),
)),
}
}
@@ -618,30 +565,79 @@ fn bind_local_address(
fn connect(
addr: &SocketAddr,
local_addr_ipv4: &Option<Ipv4Addr>,
local_addr_ipv6: &Option<Ipv6Addr>,
reuse_address: bool,
config: &Config,
connect_timeout: Option<Duration>,
) -> io::Result<impl Future<Output = io::Result<TcpStream>>> {
) -> Result<impl Future<Output = Result<TcpStream, ConnectError>>, ConnectError> {
// TODO(eliza): if Tokio's `TcpSocket` gains support for setting the
// keepalive timeout and send/recv buffer size, it would be nice to use that
// instead of socket2, and avoid the unsafe `into_raw_fd`/`from_raw_fd`
// dance...
use socket2::{Domain, Protocol, Socket, Type};
let domain = match *addr {
SocketAddr::V4(_) => Domain::ipv4(),
SocketAddr::V6(_) => Domain::ipv6(),
};
let socket = Socket::new(domain, Type::stream(), Some(Protocol::tcp()))?;
let socket = Socket::new(domain, Type::stream(), Some(Protocol::tcp()))
.map_err(ConnectError::m("tcp open error"))?;
if reuse_address {
socket.set_reuse_address(true)?;
if config.reuse_address {
socket
.set_reuse_address(true)
.map_err(ConnectError::m("tcp set_reuse_address error"))?;
}
bind_local_address(&socket, addr, local_addr_ipv4, local_addr_ipv6)?;
// When constructing a Tokio `TcpSocket` from a raw fd/socket, the user is
// responsible for ensuring O_NONBLOCK is set.
socket
.set_nonblocking(true)
.map_err(ConnectError::m("tcp set_nonblocking error"))?;
let addr = *addr;
bind_local_address(
&socket,
addr,
&config.local_address_ipv4,
&config.local_address_ipv6,
)
.map_err(ConnectError::m("tcp bind local error"))?;
let std_tcp = socket.into_tcp_stream();
if let Some(dur) = config.keep_alive_timeout {
socket
.set_keepalive(Some(dur))
.map_err(ConnectError::m("tcp set_keepalive error"))?;
}
if let Some(size) = config.send_buffer_size {
socket
.set_send_buffer_size(size)
.map_err(ConnectError::m("tcp set_send_buffer_size error"))?;
}
if let Some(size) = config.recv_buffer_size {
socket
.set_recv_buffer_size(size)
.map_err(ConnectError::m("tcp set_recv_buffer_size error"))?;
}
#[cfg(unix)]
let socket = unsafe {
// Safety: `from_raw_fd` is only safe to call if ownership of the raw
// file descriptor is transferred. Since we call `into_raw_fd` on the
// socket2 socket, it gives up ownership of the fd and will not close
// it, so this is safe.
use std::os::unix::io::{FromRawFd, IntoRawFd};
TcpSocket::from_raw_fd(socket.into_raw_fd())
};
#[cfg(windows)]
let socket = unsafe {
// Safety: `from_raw_socket` is only safe to call if ownership of the raw
// Windows SOCKET is transferred. Since we call `into_raw_socket` on the
// socket2 socket, it gives up ownership of the SOCKET and will not close
// it, so this is safe.
use std::os::windows::io::{FromRawSocket, IntoRawSocket};
TcpSocket::from_raw_socket(socket.into_raw_socket())
};
let connect = socket.connect(*addr);
Ok(async move {
let connect = TcpStream::connect_std(std_tcp, &addr);
match connect_timeout {
Some(dur) => match tokio::time::timeout(dur, connect).await {
Ok(Ok(s)) => Ok(s),
@@ -650,33 +646,19 @@ fn connect(
},
None => connect.await,
}
.map_err(ConnectError::m("tcp connect error"))
})
}
impl ConnectingTcp {
async fn connect(mut self) -> io::Result<TcpStream> {
let Self {
ref local_addr_ipv4,
ref local_addr_ipv6,
reuse_address,
..
} = self;
impl ConnectingTcp<'_> {
async fn connect(mut self) -> Result<TcpStream, ConnectError> {
match self.fallback {
None => {
self.preferred
.connect(local_addr_ipv4, local_addr_ipv6, reuse_address)
.await
}
None => self.preferred.connect(self.config).await,
Some(mut fallback) => {
let preferred_fut =
self.preferred
.connect(local_addr_ipv4, local_addr_ipv6, reuse_address);
let preferred_fut = self.preferred.connect(self.config);
futures_util::pin_mut!(preferred_fut);
let fallback_fut =
fallback
.remote
.connect(local_addr_ipv4, local_addr_ipv6, reuse_address);
let fallback_fut = fallback.remote.connect(self.config);
futures_util::pin_mut!(fallback_fut);
let (result, future) =
@@ -711,7 +693,7 @@ mod tests {
use ::http::Uri;
use super::super::sealed::{Connect, ConnectSvc};
use super::HttpConnector;
use super::{Config, ConnectError, HttpConnector};
async fn connect<C>(
connector: C,
@@ -773,6 +755,7 @@ mod tests {
#[tokio::test]
async fn local_address() {
use std::net::{IpAddr, TcpListener};
let _ = pretty_env_logger::try_init();
let (bind_ip_v4, bind_ip_v6) = get_local_ips();
let server4 = TcpListener::bind("127.0.0.1:0").unwrap();
@@ -818,10 +801,8 @@ mod tests {
let server4 = TcpListener::bind("127.0.0.1:0").unwrap();
let addr = server4.local_addr().unwrap();
let _server6 = TcpListener::bind(&format!("[::1]:{}", addr.port())).unwrap();
let mut rt = tokio::runtime::Builder::new()
.enable_io()
.enable_time()
.basic_scheduler()
let rt = tokio::runtime::Builder::new_current_thread()
.enable_all()
.build()
.unwrap();
@@ -925,16 +906,21 @@ mod tests {
.iter()
.map(|host| (host.clone(), addr.port()).into())
.collect();
let connecting_tcp = ConnectingTcp::new(
None,
None,
dns::IpAddrs::new(addrs),
None,
Some(fallback_timeout),
false,
);
let cfg = Config {
local_address_ipv4: None,
local_address_ipv6: None,
connect_timeout: None,
keep_alive_timeout: None,
happy_eyeballs_timeout: Some(fallback_timeout),
nodelay: false,
reuse_address: false,
enforce_http: false,
send_buffer_size: None,
recv_buffer_size: None,
};
let connecting_tcp = ConnectingTcp::new(dns::IpAddrs::new(addrs), &cfg);
let start = Instant::now();
Ok::<_, io::Error>((start, connecting_tcp.connect().await?))
Ok::<_, ConnectError>((start, ConnectingTcp::connect(connecting_tcp).await?))
})
.unwrap();
let res = if stream.peer_addr().unwrap().is_ipv4() {

View File

@@ -1,4 +1,5 @@
use futures_util::future;
use tokio::stream::Stream;
use tokio::sync::{mpsc, oneshot};
use crate::common::{task, Future, Pin, Poll};
@@ -131,22 +132,25 @@ impl<T, U> Clone for UnboundedSender<T, U> {
}
}
#[pin_project::pin_project(PinnedDrop)]
pub struct Receiver<T, U> {
#[pin]
inner: mpsc::UnboundedReceiver<Envelope<T, U>>,
taker: want::Taker,
}
impl<T, U> Receiver<T, U> {
pub(crate) fn poll_next(
&mut self,
self: Pin<&mut Self>,
cx: &mut task::Context<'_>,
) -> Poll<Option<(T, Callback<T, U>)>> {
match self.inner.poll_recv(cx) {
let this = self.project();
match this.inner.poll_next(cx) {
Poll::Ready(item) => {
Poll::Ready(item.map(|mut env| env.0.take().expect("envelope not dropped")))
}
Poll::Pending => {
self.taker.want();
this.taker.want();
Poll::Pending
}
}
@@ -165,11 +169,12 @@ impl<T, U> Receiver<T, U> {
}
}
impl<T, U> Drop for Receiver<T, U> {
fn drop(&mut self) {
#[pin_project::pinned_drop]
impl<T, U> PinnedDrop for Receiver<T, U> {
fn drop(mut self: Pin<&mut Self>) {
// Notify the giver about the closure first, before dropping
// the mpsc::Receiver.
self.taker.cancel();
self.as_mut().taker.cancel();
}
}
@@ -262,7 +267,7 @@ mod tests {
impl<T, U> Future for Receiver<T, U> {
type Output = Option<(T, Callback<T, U>)>;
fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
self.poll_next(cx)
}
}
@@ -344,9 +349,8 @@ mod tests {
fn giver_queue_throughput(b: &mut test::Bencher) {
use crate::{Body, Request, Response};
let mut rt = tokio::runtime::Builder::new()
let rt = tokio::runtime::Builder::new_current_thread()
.enable_all()
.basic_scheduler()
.build()
.unwrap();
let (mut tx, mut rx) = channel::<Request<Body>, Response<Body>>();
@@ -368,9 +372,8 @@ mod tests {
#[cfg(feature = "nightly")]
#[bench]
fn giver_queue_not_ready(b: &mut test::Bencher) {
let mut rt = tokio::runtime::Builder::new()
let rt = tokio::runtime::Builder::new_current_thread()
.enable_all()
.basic_scheduler()
.build()
.unwrap();
let (_tx, mut rx) = channel::<i32, ()>();

View File

@@ -706,12 +706,15 @@ impl Expiration {
}
#[cfg(feature = "runtime")]
#[pin_project::pin_project]
struct IdleTask<T> {
#[pin]
interval: Interval,
pool: WeakOpt<Mutex<PoolInner<T>>>,
// This allows the IdleTask to be notified as soon as the entire
// Pool is fully dropped, and shutdown. This channel is never sent on,
// but Err(Canceled) will be received when the Pool is dropped.
#[pin]
pool_drop_notifier: oneshot::Receiver<crate::common::Never>,
}
@@ -719,9 +722,11 @@ struct IdleTask<T> {
impl<T: Poolable + 'static> Future for IdleTask<T> {
type Output = ();
fn poll(mut self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll<Self::Output> {
fn poll(self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll<Self::Output> {
use tokio::stream::Stream;
let mut this = self.project();
loop {
match Pin::new(&mut self.pool_drop_notifier).poll(cx) {
match this.pool_drop_notifier.as_mut().poll(cx) {
Poll::Ready(Ok(n)) => match n {},
Poll::Pending => (),
Poll::Ready(Err(_canceled)) => {
@@ -730,9 +735,9 @@ impl<T: Poolable + 'static> Future for IdleTask<T> {
}
}
ready!(self.interval.poll_tick(cx));
ready!(this.interval.as_mut().poll_next(cx));
if let Some(inner) = self.pool.upgrade() {
if let Some(inner) = this.pool.upgrade() {
if let Ok(mut inner) = inner.lock() {
trace!("idle interval checking for expired");
inner.clear_expired();
@@ -850,7 +855,7 @@ mod tests {
let pooled = pool.pooled(c(key.clone()), Uniq(41));
drop(pooled);
tokio::time::delay_for(pool.locked().timeout.unwrap()).await;
tokio::time::sleep(pool.locked().timeout.unwrap()).await;
let mut checkout = pool.checkout(key);
let poll_once = PollOnce(&mut checkout);
let is_not_ready = poll_once.await.is_none();
@@ -871,7 +876,7 @@ mod tests {
pool.locked().idle.get(&key).map(|entries| entries.len()),
Some(3)
);
tokio::time::delay_for(pool.locked().timeout.unwrap()).await;
tokio::time::sleep(pool.locked().timeout.unwrap()).await;
let mut checkout = pool.checkout(key.clone());
let poll_once = PollOnce(&mut checkout);

View File

@@ -1,20 +1,13 @@
use std::mem;
use pin_project::pin_project;
use tokio::stream::Stream;
use tokio::sync::{mpsc, watch};
use super::{task, Future, Never, Pin, Poll};
// Sentinel value signaling that the watch is still open
#[derive(Clone, Copy)]
enum Action {
Open,
// Closed isn't sent via the `Action` type, but rather once
// the watch::Sender is dropped.
}
pub fn channel() -> (Signal, Watch) {
let (tx, rx) = watch::channel(Action::Open);
let (tx, rx) = watch::channel(());
let (drained_tx, drained_rx) = mpsc::channel(1);
(
Signal {
@@ -27,17 +20,19 @@ pub fn channel() -> (Signal, Watch) {
pub struct Signal {
drained_rx: mpsc::Receiver<Never>,
_tx: watch::Sender<Action>,
_tx: watch::Sender<()>,
}
#[pin_project::pin_project]
pub struct Draining {
#[pin]
drained_rx: mpsc::Receiver<Never>,
}
#[derive(Clone)]
pub struct Watch {
drained_tx: mpsc::Sender<Never>,
rx: watch::Receiver<Action>,
rx: watch::Receiver<()>,
}
#[allow(missing_debug_implementations)]
@@ -46,7 +41,8 @@ pub struct Watching<F, FN> {
#[pin]
future: F,
state: State<FN>,
watch: Watch,
watch: Pin<Box<dyn Future<Output = ()> + Send + Sync>>,
_drained_tx: mpsc::Sender<Never>,
}
enum State<F> {
@@ -66,8 +62,8 @@ impl Signal {
impl Future for Draining {
type Output = ();
fn poll(mut self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll<Self::Output> {
match ready!(self.drained_rx.poll_recv(cx)) {
fn poll(self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll<Self::Output> {
match ready!(self.project().drained_rx.poll_next(cx)) {
Some(never) => match never {},
None => Poll::Ready(()),
}
@@ -80,10 +76,14 @@ impl Watch {
F: Future,
FN: FnOnce(Pin<&mut F>),
{
let Self { drained_tx, mut rx } = self;
Watching {
future,
state: State::Watch(on_drain),
watch: self,
watch: Box::pin(async move {
let _ = rx.changed().await;
}),
_drained_tx: drained_tx,
}
}
}
@@ -100,12 +100,12 @@ where
loop {
match mem::replace(me.state, State::Draining) {
State::Watch(on_drain) => {
match me.watch.rx.poll_recv_ref(cx) {
Poll::Ready(None) => {
match Pin::new(&mut me.watch).poll(cx) {
Poll::Ready(()) => {
// Drain has been triggered!
on_drain(me.future.as_mut());
}
Poll::Ready(Some(_ /*State::Open*/)) | Poll::Pending => {
Poll::Pending => {
*me.state = State::Watch(on_drain);
return me.future.poll(cx);
}

View File

@@ -2,7 +2,7 @@ use std::marker::Unpin;
use std::{cmp, io};
use bytes::{Buf, Bytes};
use tokio::io::{AsyncRead, AsyncWrite};
use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
use crate::common::{task, Pin, Poll};
@@ -37,36 +37,33 @@ impl<T> Rewind<T> {
(self.inner, self.pre.unwrap_or_else(Bytes::new))
}
pub(crate) fn get_mut(&mut self) -> &mut T {
&mut self.inner
}
// pub(crate) fn get_mut(&mut self) -> &mut T {
// &mut self.inner
// }
}
impl<T> AsyncRead for Rewind<T>
where
T: AsyncRead + Unpin,
{
#[inline]
unsafe fn prepare_uninitialized_buffer(&self, buf: &mut [std::mem::MaybeUninit<u8>]) -> bool {
self.inner.prepare_uninitialized_buffer(buf)
}
fn poll_read(
mut self: Pin<&mut Self>,
cx: &mut task::Context<'_>,
buf: &mut [u8],
) -> Poll<io::Result<usize>> {
buf: &mut ReadBuf<'_>,
) -> Poll<io::Result<()>> {
if let Some(mut prefix) = self.pre.take() {
// If there are no remaining bytes, let the bytes get dropped.
if !prefix.is_empty() {
let copy_len = cmp::min(prefix.len(), buf.len());
prefix.copy_to_slice(&mut buf[..copy_len]);
let copy_len = cmp::min(prefix.len(), buf.remaining());
// TODO: There should be a way to do following two lines cleaner...
buf.put_slice(&prefix[..copy_len]);
prefix.advance(copy_len);
// Put back whats left
if !prefix.is_empty() {
self.pre = Some(prefix);
}
return Poll::Ready(Ok(copy_len));
return Poll::Ready(Ok(()));
}
}
Pin::new(&mut self.inner).poll_read(cx, buf)
@@ -92,15 +89,6 @@ where
fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll<io::Result<()>> {
Pin::new(&mut self.inner).poll_shutdown(cx)
}
#[inline]
fn poll_write_buf<B: Buf>(
mut self: Pin<&mut Self>,
cx: &mut task::Context<'_>,
buf: &mut B,
) -> Poll<io::Result<usize>> {
Pin::new(&mut self.inner).poll_write_buf(cx, buf)
}
}
#[cfg(test)]

View File

@@ -967,9 +967,8 @@ mod tests {
*conn.io.read_buf_mut() = ::bytes::BytesMut::from(&s[..]);
conn.state.cached_headers = Some(HeaderMap::with_capacity(2));
let mut rt = tokio::runtime::Builder::new()
let rt = tokio::runtime::Builder::new_current_thread()
.enable_all()
.basic_scheduler()
.build()
.unwrap();

View File

@@ -382,7 +382,7 @@ mod tests {
use super::*;
use std::pin::Pin;
use std::time::Duration;
use tokio::io::AsyncRead;
use tokio::io::{AsyncRead, ReadBuf};
impl<'a> MemRead for &'a [u8] {
fn read_mem(&mut self, _: &mut task::Context<'_>, len: usize) -> Poll<io::Result<Bytes>> {
@@ -401,8 +401,9 @@ mod tests {
impl<'a> MemRead for &'a mut (dyn AsyncRead + Unpin) {
fn read_mem(&mut self, cx: &mut task::Context<'_>, len: usize) -> Poll<io::Result<Bytes>> {
let mut v = vec![0; len];
let n = ready!(Pin::new(self).poll_read(cx, &mut v)?);
Poll::Ready(Ok(Bytes::copy_from_slice(&v[..n])))
let mut buf = ReadBuf::new(&mut v);
ready!(Pin::new(self).poll_read(cx, &mut buf)?);
Poll::Ready(Ok(Bytes::copy_from_slice(&buf.filled())))
}
}
@@ -623,7 +624,7 @@ mod tests {
#[cfg(feature = "nightly")]
#[bench]
fn bench_decode_chunked_1kb(b: &mut test::Bencher) {
let mut rt = new_runtime();
let rt = new_runtime();
const LEN: usize = 1024;
let mut vec = Vec::new();
@@ -647,7 +648,7 @@ mod tests {
#[cfg(feature = "nightly")]
#[bench]
fn bench_decode_length_1kb(b: &mut test::Bencher) {
let mut rt = new_runtime();
let rt = new_runtime();
const LEN: usize = 1024;
let content = Bytes::from(&[0; LEN][..]);
@@ -665,9 +666,8 @@ mod tests {
#[cfg(feature = "nightly")]
fn new_runtime() -> tokio::runtime::Runtime {
tokio::runtime::Builder::new()
tokio::runtime::Builder::new_current_thread()
.enable_all()
.basic_scheduler()
.build()
.expect("rt build")
}

View File

@@ -27,7 +27,7 @@ pub(crate) trait Dispatch {
type PollError;
type RecvItem;
fn poll_msg(
&mut self,
self: Pin<&mut Self>,
cx: &mut task::Context<'_>,
) -> Poll<Option<Result<(Self::PollItem, Self::PollBody), Self::PollError>>>;
fn recv_msg(&mut self, msg: crate::Result<(Self::RecvItem, Body)>) -> crate::Result<()>;
@@ -40,8 +40,10 @@ pub struct Server<S: HttpService<B>, B> {
pub(crate) service: S,
}
#[pin_project::pin_project]
pub struct Client<B> {
callback: Option<crate::client::dispatch::Callback<Request<B>, Response<Body>>>,
#[pin]
rx: ClientRx<B>,
rx_closed: bool,
}
@@ -281,7 +283,7 @@ where
&& self.conn.can_write_head()
&& self.dispatch.should_poll()
{
if let Some(msg) = ready!(self.dispatch.poll_msg(cx)) {
if let Some(msg) = ready!(Pin::new(&mut self.dispatch).poll_msg(cx)) {
let (head, mut body) = msg.map_err(crate::Error::new_user_service)?;
// Check if the body knows its full data immediately.
@@ -469,10 +471,11 @@ where
type RecvItem = RequestHead;
fn poll_msg(
&mut self,
mut self: Pin<&mut Self>,
cx: &mut task::Context<'_>,
) -> Poll<Option<Result<(Self::PollItem, Self::PollBody), Self::PollError>>> {
let ret = if let Some(ref mut fut) = self.in_flight.as_mut().as_pin_mut() {
let mut this = self.as_mut();
let ret = if let Some(ref mut fut) = this.in_flight.as_mut().as_pin_mut() {
let resp = ready!(fut.as_mut().poll(cx)?);
let (parts, body) = resp.into_parts();
let head = MessageHead {
@@ -486,7 +489,7 @@ where
};
// Since in_flight finished, remove it
self.in_flight.set(None);
this.in_flight.set(None);
ret
}
@@ -540,11 +543,12 @@ where
type RecvItem = ResponseHead;
fn poll_msg(
&mut self,
self: Pin<&mut Self>,
cx: &mut task::Context<'_>,
) -> Poll<Option<Result<(Self::PollItem, Self::PollBody), Never>>> {
debug_assert!(!self.rx_closed);
match self.rx.poll_next(cx) {
let this = self.project();
debug_assert!(!*this.rx_closed);
match this.rx.poll_next(cx) {
Poll::Ready(Some((req, mut cb))) => {
// check that future hasn't been canceled already
match cb.poll_canceled(cx) {
@@ -559,7 +563,7 @@ where
subject: RequestLine(parts.method, parts.uri),
headers: parts.headers,
};
self.callback = Some(cb);
*this.callback = Some(cb);
Poll::Ready(Some(Ok((head, body))))
}
}
@@ -567,7 +571,7 @@ where
Poll::Ready(None) => {
// user has dropped sender handle
trace!("client tx closed");
self.rx_closed = true;
*this.rx_closed = true;
Poll::Ready(None)
}
Poll::Pending => Poll::Pending,

View File

@@ -4,7 +4,7 @@ use std::fmt;
use std::io::{self, IoSlice};
use bytes::{Buf, BufMut, Bytes, BytesMut};
use tokio::io::{AsyncRead, AsyncWrite};
use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
use super::{Http1Transaction, ParseContext, ParsedMessage};
use crate::common::buf::BufList;
@@ -188,9 +188,16 @@ where
if self.read_buf_remaining_mut() < next {
self.read_buf.reserve(next);
}
match Pin::new(&mut self.io).poll_read_buf(cx, &mut self.read_buf) {
Poll::Ready(Ok(n)) => {
debug!("read {} bytes", n);
let mut buf = ReadBuf::uninit(&mut self.read_buf.bytes_mut()[..]);
match Pin::new(&mut self.io).poll_read(cx, &mut buf) {
Poll::Ready(Ok(_)) => {
let n = buf.filled().len();
unsafe {
// Safety: we just read that many bytes into the
// uninitialized part of the buffer, so this is okay.
// @tokio pls give me back `poll_read_buf` thanks
self.read_buf.advance_mut(n);
}
self.read_buf_strategy.record(n);
Poll::Ready(Ok(n))
}
@@ -224,8 +231,16 @@ where
return self.poll_flush_flattened(cx);
}
loop {
let n =
ready!(Pin::new(&mut self.io).poll_write_buf(cx, &mut self.write_buf.auto()))?;
// TODO(eliza): this basically ignores all of `WriteBuf`...put
// back vectored IO and `poll_write_buf` when the appropriate Tokio
// changes land...
let n = ready!(Pin::new(&mut self.io)
// .poll_write_buf(cx, &mut self.write_buf.auto()))?;
.poll_write(cx, self.write_buf.auto().bytes()))?;
// TODO(eliza): we have to do this manually because
// `poll_write_buf` doesn't exist in Tokio 0.3 yet...when
// `poll_write_buf` comes back, the manual advance will need to leave!
self.write_buf.advance(n);
debug!("flushed {} bytes", n);
if self.write_buf.remaining() == 0 {
break;
@@ -452,6 +467,7 @@ where
self.strategy = strategy;
}
// TODO(eliza): put back writev!
#[inline]
fn auto(&mut self) -> WriteBufAuto<'_, B> {
WriteBufAuto::new(self)
@@ -628,28 +644,31 @@ mod tests {
*/
#[tokio::test]
#[ignore]
async fn iobuf_write_empty_slice() {
// First, let's just check that the Mock would normally return an
// error on an unexpected write, even if the buffer is empty...
let mut mock = Mock::new().build();
futures_util::future::poll_fn(|cx| {
Pin::new(&mut mock).poll_write_buf(cx, &mut Cursor::new(&[]))
})
.await
.expect_err("should be a broken pipe");
// TODO(eliza): can i have writev back pls T_T
// // First, let's just check that the Mock would normally return an
// // error on an unexpected write, even if the buffer is empty...
// let mut mock = Mock::new().build();
// futures_util::future::poll_fn(|cx| {
// Pin::new(&mut mock).poll_write_buf(cx, &mut Cursor::new(&[]))
// })
// .await
// .expect_err("should be a broken pipe");
// underlying io will return the logic error upon write,
// so we are testing that the io_buf does not trigger a write
// when there is nothing to flush
let mock = Mock::new().build();
let mut io_buf = Buffered::<_, Cursor<Vec<u8>>>::new(mock);
io_buf.flush().await.expect("should short-circuit flush");
// // underlying io will return the logic error upon write,
// // so we are testing that the io_buf does not trigger a write
// // when there is nothing to flush
// let mock = Mock::new().build();
// let mut io_buf = Buffered::<_, Cursor<Vec<u8>>>::new(mock);
// io_buf.flush().await.expect("should short-circuit flush");
}
#[tokio::test]
async fn parse_reads_until_blocked() {
use crate::proto::h1::ClientTransaction;
let _ = pretty_env_logger::try_init();
let mock = Mock::new()
// Split over multiple reads will read all of it
.read(b"HTTP/1.1 200 OK\r\n")

View File

@@ -33,7 +33,7 @@ use std::time::Instant;
use h2::{Ping, PingPong};
#[cfg(feature = "runtime")]
use tokio::time::{Delay, Instant};
use tokio::time::{Instant, Sleep};
type WindowSize = u32;
@@ -60,7 +60,7 @@ pub(super) fn channel(ping_pong: PingPong, config: Config) -> (Recorder, Ponger)
interval,
timeout: config.keep_alive_timeout,
while_idle: config.keep_alive_while_idle,
timer: tokio::time::delay_for(interval),
timer: tokio::time::sleep(interval),
state: KeepAliveState::Init,
});
@@ -156,7 +156,7 @@ struct KeepAlive {
while_idle: bool,
state: KeepAliveState,
timer: Delay,
timer: Sleep,
}
#[cfg(feature = "runtime")]

View File

@@ -809,9 +809,9 @@ where
type Output = Result<Connection<I, S, E>, FE>;
fn poll(self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll<Self::Output> {
let me = self.project();
let mut me = self.project();
let service = ready!(me.future.poll(cx))?;
let io = me.io.take().expect("polled after complete");
let io = Option::take(&mut me.io).expect("polled after complete");
Poll::Ready(Ok(me.protocol.serve_connection(io, service)))
}
}

View File

@@ -4,7 +4,7 @@ use std::net::{SocketAddr, TcpListener as StdTcpListener};
use std::time::Duration;
use tokio::net::TcpListener;
use tokio::time::Delay;
use tokio::time::Sleep;
use crate::common::{task, Future, Pin, Poll};
@@ -19,7 +19,7 @@ pub struct AddrIncoming {
sleep_on_errors: bool,
tcp_keepalive_timeout: Option<Duration>,
tcp_nodelay: bool,
timeout: Option<Delay>,
timeout: Option<Sleep>,
}
impl AddrIncoming {
@@ -30,6 +30,10 @@ impl AddrIncoming {
}
pub(super) fn from_std(std_listener: StdTcpListener) -> crate::Result<Self> {
// TcpListener::from_std doesn't set O_NONBLOCK
std_listener
.set_nonblocking(true)
.map_err(crate::Error::new_listen)?;
let listener = TcpListener::from_std(std_listener).map_err(crate::Error::new_listen)?;
let addr = listener.local_addr().map_err(crate::Error::new_listen)?;
Ok(AddrIncoming {
@@ -98,9 +102,46 @@ impl AddrIncoming {
match ready!(self.listener.poll_accept(cx)) {
Ok((socket, addr)) => {
if let Some(dur) = self.tcp_keepalive_timeout {
// Convert the Tokio `TcpStream` into a `socket2` socket
// so we can call `set_keepalive`.
// TODO(eliza): if Tokio's `TcpSocket` API grows a few
// more methods in the future, hopefully we shouldn't
// have to do the `from_raw_fd` dance any longer...
#[cfg(unix)]
let socket = unsafe {
// Safety: `socket2`'s socket will try to close the
// underlying fd when it's dropped. However, we
// can't take ownership of the fd from the tokio
// TcpStream, so instead we will call `into_raw_fd`
// on the socket2 socket before dropping it. This
// prevents it from trying to close the fd.
use std::os::unix::io::{AsRawFd, FromRawFd};
socket2::Socket::from_raw_fd(socket.as_raw_fd())
};
#[cfg(windows)]
let socket = unsafe {
// Safety: `socket2`'s socket will try to close the
// underlying SOCKET when it's dropped. However, we
// can't take ownership of the SOCKET from the tokio
// TcpStream, so instead we will call `into_raw_socket`
// on the socket2 socket before dropping it. This
// prevents it from trying to close the SOCKET.
use std::os::windows::io::{AsRawSocket, FromRawSocket};
socket2::Socket::from_raw_socket(socket.as_raw_socket())
};
// Actually set the TCP keepalive timeout.
if let Err(e) = socket.set_keepalive(Some(dur)) {
trace!("error trying to set TCP keepalive: {}", e);
}
// Take ownershop of the fd/socket back from the socket2
// `Socket`, so that socket2 doesn't try to close it
// when it's dropped.
#[cfg(unix)]
drop(std::os::unix::io::IntoRawFd::into_raw_fd(socket));
#[cfg(windows)]
drop(std::os::windows::io::IntoRawSocket::into_raw_socket(socket));
}
if let Err(e) = socket.set_nodelay(self.tcp_nodelay) {
trace!("error trying to set TCP nodelay: {}", e);
@@ -119,7 +160,7 @@ impl AddrIncoming {
error!("accept error: {}", e);
// Sleep 1s.
let mut timeout = tokio::time::delay_for(Duration::from_secs(1));
let mut timeout = tokio::time::sleep(Duration::from_secs(1));
match Pin::new(&mut timeout).poll(cx) {
Poll::Ready(()) => {
@@ -181,19 +222,20 @@ impl fmt::Debug for AddrIncoming {
}
mod addr_stream {
use bytes::{Buf, BufMut};
use std::io;
use std::net::SocketAddr;
#[cfg(unix)]
use std::os::unix::io::{AsRawFd, RawFd};
use tokio::io::{AsyncRead, AsyncWrite};
use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
use tokio::net::TcpStream;
use crate::common::{task, Pin, Poll};
/// A transport returned yieled by `AddrIncoming`.
#[pin_project::pin_project]
#[derive(Debug)]
pub struct AddrStream {
#[pin]
inner: TcpStream,
pub(super) remote_addr: SocketAddr,
}
@@ -231,49 +273,24 @@ mod addr_stream {
}
impl AsyncRead for AddrStream {
unsafe fn prepare_uninitialized_buffer(
&self,
buf: &mut [std::mem::MaybeUninit<u8>],
) -> bool {
self.inner.prepare_uninitialized_buffer(buf)
}
#[inline]
fn poll_read(
mut self: Pin<&mut Self>,
self: Pin<&mut Self>,
cx: &mut task::Context<'_>,
buf: &mut [u8],
) -> Poll<io::Result<usize>> {
Pin::new(&mut self.inner).poll_read(cx, buf)
}
#[inline]
fn poll_read_buf<B: BufMut>(
mut self: Pin<&mut Self>,
cx: &mut task::Context<'_>,
buf: &mut B,
) -> Poll<io::Result<usize>> {
Pin::new(&mut self.inner).poll_read_buf(cx, buf)
buf: &mut ReadBuf<'_>,
) -> Poll<io::Result<()>> {
self.project().inner.poll_read(cx, buf)
}
}
impl AsyncWrite for AddrStream {
#[inline]
fn poll_write(
mut self: Pin<&mut Self>,
self: Pin<&mut Self>,
cx: &mut task::Context<'_>,
buf: &[u8],
) -> Poll<io::Result<usize>> {
Pin::new(&mut self.inner).poll_write(cx, buf)
}
#[inline]
fn poll_write_buf<B: Buf>(
mut self: Pin<&mut Self>,
cx: &mut task::Context<'_>,
buf: &mut B,
) -> Poll<io::Result<usize>> {
Pin::new(&mut self.inner).poll_write_buf(cx, buf)
self.project().inner.poll_write(cx, buf)
}
#[inline]
@@ -283,11 +300,8 @@ mod addr_stream {
}
#[inline]
fn poll_shutdown(
mut self: Pin<&mut Self>,
cx: &mut task::Context<'_>,
) -> Poll<io::Result<()>> {
Pin::new(&mut self.inner).poll_shutdown(cx)
fn poll_shutdown(self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll<io::Result<()>> {
self.project().inner.poll_shutdown(cx)
}
}

View File

@@ -12,7 +12,7 @@ use std::io;
use std::marker::Unpin;
use bytes::{Buf, Bytes};
use tokio::io::{AsyncRead, AsyncWrite};
use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
use tokio::sync::oneshot;
use crate::common::io::Rewind;
@@ -105,15 +105,11 @@ impl Upgraded {
}
impl AsyncRead for Upgraded {
unsafe fn prepare_uninitialized_buffer(&self, buf: &mut [std::mem::MaybeUninit<u8>]) -> bool {
self.io.prepare_uninitialized_buffer(buf)
}
fn poll_read(
mut self: Pin<&mut Self>,
cx: &mut task::Context<'_>,
buf: &mut [u8],
) -> Poll<io::Result<usize>> {
buf: &mut ReadBuf<'_>,
) -> Poll<io::Result<()>> {
Pin::new(&mut self.io).poll_read(cx, buf)
}
}
@@ -127,14 +123,6 @@ impl AsyncWrite for Upgraded {
Pin::new(&mut self.io).poll_write(cx, buf)
}
fn poll_write_buf<B: Buf>(
mut self: Pin<&mut Self>,
cx: &mut task::Context<'_>,
buf: &mut B,
) -> Poll<io::Result<usize>> {
Pin::new(self.io.get_mut()).poll_write_dyn_buf(cx, buf)
}
fn poll_flush(mut self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll<io::Result<()>> {
Pin::new(&mut self.io).poll_flush(cx)
}
@@ -247,15 +235,11 @@ impl dyn Io + Send {
}
impl<T: AsyncRead + Unpin> AsyncRead for ForwardsWriteBuf<T> {
unsafe fn prepare_uninitialized_buffer(&self, buf: &mut [std::mem::MaybeUninit<u8>]) -> bool {
self.0.prepare_uninitialized_buffer(buf)
}
fn poll_read(
mut self: Pin<&mut Self>,
cx: &mut task::Context<'_>,
buf: &mut [u8],
) -> Poll<io::Result<usize>> {
buf: &mut ReadBuf<'_>,
) -> Poll<io::Result<()>> {
Pin::new(&mut self.0).poll_read(cx, buf)
}
}
@@ -269,14 +253,6 @@ impl<T: AsyncWrite + Unpin> AsyncWrite for ForwardsWriteBuf<T> {
Pin::new(&mut self.0).poll_write(cx, buf)
}
fn poll_write_buf<B: Buf>(
mut self: Pin<&mut Self>,
cx: &mut task::Context<'_>,
buf: &mut B,
) -> Poll<io::Result<usize>> {
Pin::new(&mut self.0).poll_write_buf(cx, buf)
}
fn poll_flush(mut self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll<io::Result<()>> {
Pin::new(&mut self.0).poll_flush(cx)
}
@@ -290,9 +266,9 @@ impl<T: AsyncRead + AsyncWrite + Unpin + 'static> Io for ForwardsWriteBuf<T> {
fn poll_write_dyn_buf(
&mut self,
cx: &mut task::Context<'_>,
mut buf: &mut dyn Buf,
buf: &mut dyn Buf,
) -> Poll<io::Result<usize>> {
Pin::new(&mut self.0).poll_write_buf(cx, &mut buf)
Pin::new(&mut self.0).poll_write(cx, buf.bytes())
}
}
@@ -326,8 +302,8 @@ mod tests {
fn poll_read(
self: Pin<&mut Self>,
_cx: &mut task::Context<'_>,
_buf: &mut [u8],
) -> Poll<io::Result<usize>> {
_buf: &mut ReadBuf<'_>,
) -> Poll<io::Result<()>> {
unreachable!("Mock::poll_read")
}
}
@@ -335,21 +311,23 @@ mod tests {
impl AsyncWrite for Mock {
fn poll_write(
self: Pin<&mut Self>,
_cx: &mut task::Context<'_>,
_buf: &[u8],
_: &mut task::Context<'_>,
buf: &[u8],
) -> Poll<io::Result<usize>> {
panic!("poll_write shouldn't be called");
// panic!("poll_write shouldn't be called");
Poll::Ready(Ok(buf.len()))
}
fn poll_write_buf<B: Buf>(
self: Pin<&mut Self>,
_cx: &mut task::Context<'_>,
buf: &mut B,
) -> Poll<io::Result<usize>> {
let n = buf.remaining();
buf.advance(n);
Poll::Ready(Ok(n))
}
// TODO(eliza): :(
// fn poll_write_buf<B: Buf>(
// self: Pin<&mut Self>,
// _cx: &mut task::Context<'_>,
// buf: &mut B,
// ) -> Poll<io::Result<usize>> {
// let n = buf.remaining();
// buf.advance(n);
// Poll::Ready(Ok(n))
// }
fn poll_flush(self: Pin<&mut Self>, _cx: &mut task::Context<'_>) -> Poll<io::Result<()>> {
unreachable!("Mock::poll_flush")