Tokio 0.3 Upgrade (#2319)

Co-authored-by: Urhengulas <johann.hemmann@code.berlin>
Co-authored-by: Eliza Weisman <eliza@buoyant.io>
This commit is contained in:
Sean McArthur
2020-11-05 17:17:21 -08:00
committed by GitHub
parent cc7d3058e8
commit 1b9af22fa0
24 changed files with 467 additions and 472 deletions

View File

@@ -809,9 +809,9 @@ where
type Output = Result<Connection<I, S, E>, FE>;
fn poll(self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll<Self::Output> {
let me = self.project();
let mut me = self.project();
let service = ready!(me.future.poll(cx))?;
let io = me.io.take().expect("polled after complete");
let io = Option::take(&mut me.io).expect("polled after complete");
Poll::Ready(Ok(me.protocol.serve_connection(io, service)))
}
}

View File

@@ -4,7 +4,7 @@ use std::net::{SocketAddr, TcpListener as StdTcpListener};
use std::time::Duration;
use tokio::net::TcpListener;
use tokio::time::Delay;
use tokio::time::Sleep;
use crate::common::{task, Future, Pin, Poll};
@@ -19,7 +19,7 @@ pub struct AddrIncoming {
sleep_on_errors: bool,
tcp_keepalive_timeout: Option<Duration>,
tcp_nodelay: bool,
timeout: Option<Delay>,
timeout: Option<Sleep>,
}
impl AddrIncoming {
@@ -30,6 +30,10 @@ impl AddrIncoming {
}
pub(super) fn from_std(std_listener: StdTcpListener) -> crate::Result<Self> {
// TcpListener::from_std doesn't set O_NONBLOCK
std_listener
.set_nonblocking(true)
.map_err(crate::Error::new_listen)?;
let listener = TcpListener::from_std(std_listener).map_err(crate::Error::new_listen)?;
let addr = listener.local_addr().map_err(crate::Error::new_listen)?;
Ok(AddrIncoming {
@@ -98,9 +102,46 @@ impl AddrIncoming {
match ready!(self.listener.poll_accept(cx)) {
Ok((socket, addr)) => {
if let Some(dur) = self.tcp_keepalive_timeout {
// Convert the Tokio `TcpStream` into a `socket2` socket
// so we can call `set_keepalive`.
// TODO(eliza): if Tokio's `TcpSocket` API grows a few
// more methods in the future, hopefully we shouldn't
// have to do the `from_raw_fd` dance any longer...
#[cfg(unix)]
let socket = unsafe {
// Safety: `socket2`'s socket will try to close the
// underlying fd when it's dropped. However, we
// can't take ownership of the fd from the tokio
// TcpStream, so instead we will call `into_raw_fd`
// on the socket2 socket before dropping it. This
// prevents it from trying to close the fd.
use std::os::unix::io::{AsRawFd, FromRawFd};
socket2::Socket::from_raw_fd(socket.as_raw_fd())
};
#[cfg(windows)]
let socket = unsafe {
// Safety: `socket2`'s socket will try to close the
// underlying SOCKET when it's dropped. However, we
// can't take ownership of the SOCKET from the tokio
// TcpStream, so instead we will call `into_raw_socket`
// on the socket2 socket before dropping it. This
// prevents it from trying to close the SOCKET.
use std::os::windows::io::{AsRawSocket, FromRawSocket};
socket2::Socket::from_raw_socket(socket.as_raw_socket())
};
// Actually set the TCP keepalive timeout.
if let Err(e) = socket.set_keepalive(Some(dur)) {
trace!("error trying to set TCP keepalive: {}", e);
}
// Take ownershop of the fd/socket back from the socket2
// `Socket`, so that socket2 doesn't try to close it
// when it's dropped.
#[cfg(unix)]
drop(std::os::unix::io::IntoRawFd::into_raw_fd(socket));
#[cfg(windows)]
drop(std::os::windows::io::IntoRawSocket::into_raw_socket(socket));
}
if let Err(e) = socket.set_nodelay(self.tcp_nodelay) {
trace!("error trying to set TCP nodelay: {}", e);
@@ -119,7 +160,7 @@ impl AddrIncoming {
error!("accept error: {}", e);
// Sleep 1s.
let mut timeout = tokio::time::delay_for(Duration::from_secs(1));
let mut timeout = tokio::time::sleep(Duration::from_secs(1));
match Pin::new(&mut timeout).poll(cx) {
Poll::Ready(()) => {
@@ -181,19 +222,20 @@ impl fmt::Debug for AddrIncoming {
}
mod addr_stream {
use bytes::{Buf, BufMut};
use std::io;
use std::net::SocketAddr;
#[cfg(unix)]
use std::os::unix::io::{AsRawFd, RawFd};
use tokio::io::{AsyncRead, AsyncWrite};
use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
use tokio::net::TcpStream;
use crate::common::{task, Pin, Poll};
/// A transport returned yieled by `AddrIncoming`.
#[pin_project::pin_project]
#[derive(Debug)]
pub struct AddrStream {
#[pin]
inner: TcpStream,
pub(super) remote_addr: SocketAddr,
}
@@ -231,49 +273,24 @@ mod addr_stream {
}
impl AsyncRead for AddrStream {
unsafe fn prepare_uninitialized_buffer(
&self,
buf: &mut [std::mem::MaybeUninit<u8>],
) -> bool {
self.inner.prepare_uninitialized_buffer(buf)
}
#[inline]
fn poll_read(
mut self: Pin<&mut Self>,
self: Pin<&mut Self>,
cx: &mut task::Context<'_>,
buf: &mut [u8],
) -> Poll<io::Result<usize>> {
Pin::new(&mut self.inner).poll_read(cx, buf)
}
#[inline]
fn poll_read_buf<B: BufMut>(
mut self: Pin<&mut Self>,
cx: &mut task::Context<'_>,
buf: &mut B,
) -> Poll<io::Result<usize>> {
Pin::new(&mut self.inner).poll_read_buf(cx, buf)
buf: &mut ReadBuf<'_>,
) -> Poll<io::Result<()>> {
self.project().inner.poll_read(cx, buf)
}
}
impl AsyncWrite for AddrStream {
#[inline]
fn poll_write(
mut self: Pin<&mut Self>,
self: Pin<&mut Self>,
cx: &mut task::Context<'_>,
buf: &[u8],
) -> Poll<io::Result<usize>> {
Pin::new(&mut self.inner).poll_write(cx, buf)
}
#[inline]
fn poll_write_buf<B: Buf>(
mut self: Pin<&mut Self>,
cx: &mut task::Context<'_>,
buf: &mut B,
) -> Poll<io::Result<usize>> {
Pin::new(&mut self.inner).poll_write_buf(cx, buf)
self.project().inner.poll_write(cx, buf)
}
#[inline]
@@ -283,11 +300,8 @@ mod addr_stream {
}
#[inline]
fn poll_shutdown(
mut self: Pin<&mut Self>,
cx: &mut task::Context<'_>,
) -> Poll<io::Result<()>> {
Pin::new(&mut self.inner).poll_shutdown(cx)
fn poll_shutdown(self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll<io::Result<()>> {
self.project().inner.poll_shutdown(cx)
}
}