diff --git a/Cargo.toml b/Cargo.toml index 2f7d41f6..f2ff919f 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -35,7 +35,7 @@ itoa = "0.4.1" tracing = { version = "0.1", default-features = false, features = ["log", "std"] } pin-project = "1.0" tower-service = "0.3" -tokio = { version = "0.3", features = ["sync", "stream"] } +tokio = { version = "0.3.4", features = ["sync", "stream"] } want = "0.3" # Optional diff --git a/src/client/connect/http.rs b/src/client/connect/http.rs index 639a21ac..7c17ace1 100644 --- a/src/client/connect/http.rs +++ b/src/client/connect/http.rs @@ -569,10 +569,11 @@ fn connect( connect_timeout: Option, ) -> Result>, ConnectError> { // TODO(eliza): if Tokio's `TcpSocket` gains support for setting the - // keepalive timeout and send/recv buffer size, it would be nice to use that - // instead of socket2, and avoid the unsafe `into_raw_fd`/`from_raw_fd` - // dance... + // keepalive timeout, it would be nice to use that instead of socket2, + // and avoid the unsafe `into_raw_fd`/`from_raw_fd` dance... use socket2::{Domain, Protocol, Socket, Type}; + use std::convert::TryInto; + let domain = match *addr { SocketAddr::V4(_) => Domain::ipv4(), SocketAddr::V6(_) => Domain::ipv6(), @@ -580,18 +581,18 @@ fn connect( let socket = Socket::new(domain, Type::stream(), Some(Protocol::tcp())) .map_err(ConnectError::m("tcp open error"))?; - if config.reuse_address { - socket - .set_reuse_address(true) - .map_err(ConnectError::m("tcp set_reuse_address error"))?; - } - // When constructing a Tokio `TcpSocket` from a raw fd/socket, the user is // responsible for ensuring O_NONBLOCK is set. socket .set_nonblocking(true) .map_err(ConnectError::m("tcp set_nonblocking error"))?; + if let Some(dur) = config.keep_alive_timeout { + socket + .set_keepalive(Some(dur)) + .map_err(ConnectError::m("tcp set_keepalive error"))?; + } + bind_local_address( &socket, addr, @@ -600,24 +601,6 @@ fn connect( ) .map_err(ConnectError::m("tcp bind local error"))?; - if let Some(dur) = config.keep_alive_timeout { - socket - .set_keepalive(Some(dur)) - .map_err(ConnectError::m("tcp set_keepalive error"))?; - } - - if let Some(size) = config.send_buffer_size { - socket - .set_send_buffer_size(size) - .map_err(ConnectError::m("tcp set_send_buffer_size error"))?; - } - - if let Some(size) = config.recv_buffer_size { - socket - .set_recv_buffer_size(size) - .map_err(ConnectError::m("tcp set_recv_buffer_size error"))?; - } - #[cfg(unix)] let socket = unsafe { // Safety: `from_raw_fd` is only safe to call if ownership of the raw @@ -636,6 +619,25 @@ fn connect( use std::os::windows::io::{FromRawSocket, IntoRawSocket}; TcpSocket::from_raw_socket(socket.into_raw_socket()) }; + + if config.reuse_address { + socket + .set_reuseaddr(true) + .map_err(ConnectError::m("tcp set_reuse_address error"))?; + } + + if let Some(size) = config.send_buffer_size { + socket + .set_send_buffer_size(size.try_into().unwrap_or(std::u32::MAX)) + .map_err(ConnectError::m("tcp set_send_buffer_size error"))?; + } + + if let Some(size) = config.recv_buffer_size { + socket + .set_recv_buffer_size(size.try_into().unwrap_or(std::u32::MAX)) + .map_err(ConnectError::m("tcp set_recv_buffer_size error"))?; + } + let connect = socket.connect(*addr); Ok(async move { match connect_timeout {