feat(client): add HttpConnector::set_local_addresses to set both IPv6 and IPv4 local addrs (#2172)

Currently HttpConnector::set_local_address method accepts a single
argument. Server might not support IPv6 or IPv4. Therefore, the only
solution at the moment is to manually perform DNS resolution and pick
appropriate local address family. This is inefficient, as leads to
2 DNS lookups per request. This commit allows specifying both IPv4
and IPv6, so connector can decide which one to use based on DNS
resolution results.
This commit is contained in:
Ivan Nikulin
2020-10-14 00:02:16 +01:00
committed by GitHub
parent 02732bef0c
commit fb19f3a869
3 changed files with 208 additions and 53 deletions

View File

@@ -55,6 +55,9 @@ tokio-util = { version = "0.3", features = ["codec"] }
tower-util = "0.3"
url = "1.0"
[target.'cfg(any(target_os = "linux", target_os = "macos"))'.dev-dependencies]
pnet = "0.25.0"
[features]
default = [
"runtime",

View File

@@ -200,27 +200,33 @@ impl IpAddrs {
None
}
pub(super) fn split_by_preference(self, local_addr: Option<IpAddr>) -> (IpAddrs, IpAddrs) {
if let Some(local_addr) = local_addr {
let preferred = self
.iter
.filter(|addr| addr.is_ipv6() == local_addr.is_ipv6())
.collect();
#[inline]
fn filter(self, predicate: impl FnMut(&SocketAddr) -> bool) -> IpAddrs {
IpAddrs::new(self.iter.filter(predicate).collect())
}
(IpAddrs::new(preferred), IpAddrs::new(vec![]))
} else {
let preferring_v6 = self
.iter
.as_slice()
.first()
.map(SocketAddr::is_ipv6)
.unwrap_or(false);
pub(super) fn split_by_preference(
self,
local_addr_ipv4: Option<Ipv4Addr>,
local_addr_ipv6: Option<Ipv6Addr>,
) -> (IpAddrs, IpAddrs) {
match (local_addr_ipv4, local_addr_ipv6) {
(Some(_), None) => (self.filter(SocketAddr::is_ipv4), IpAddrs::new(vec![])),
(None, Some(_)) => (self.filter(SocketAddr::is_ipv6), IpAddrs::new(vec![])),
_ => {
let preferring_v6 = self
.iter
.as_slice()
.first()
.map(SocketAddr::is_ipv6)
.unwrap_or(false);
let (preferred, fallback) = self
.iter
.partition::<Vec<_>, _>(|addr| addr.is_ipv6() == preferring_v6);
let (preferred, fallback) = self
.iter
.partition::<Vec<_>, _>(|addr| addr.is_ipv6() == preferring_v6);
(IpAddrs::new(preferred), IpAddrs::new(fallback))
(IpAddrs::new(preferred), IpAddrs::new(fallback))
}
}
}
@@ -355,34 +361,50 @@ mod tests {
#[test]
fn test_ip_addrs_split_by_preference() {
let v4_addr = (Ipv4Addr::new(127, 0, 0, 1), 80).into();
let v6_addr = (Ipv6Addr::new(0, 0, 0, 0, 0, 0, 0, 1), 80).into();
let ip_v4 = Ipv4Addr::new(127, 0, 0, 1);
let ip_v6 = Ipv6Addr::new(0, 0, 0, 0, 0, 0, 0, 1);
let v4_addr = (ip_v4, 80).into();
let v6_addr = (ip_v6, 80).into();
let (mut preferred, mut fallback) = IpAddrs {
iter: vec![v4_addr, v6_addr].into_iter(),
}
.split_by_preference(None);
.split_by_preference(None, None);
assert!(preferred.next().unwrap().is_ipv4());
assert!(fallback.next().unwrap().is_ipv6());
let (mut preferred, mut fallback) = IpAddrs {
iter: vec![v6_addr, v4_addr].into_iter(),
}
.split_by_preference(None);
.split_by_preference(None, None);
assert!(preferred.next().unwrap().is_ipv6());
assert!(fallback.next().unwrap().is_ipv4());
let (mut preferred, mut fallback) = IpAddrs {
iter: vec![v4_addr, v6_addr].into_iter(),
}
.split_by_preference(Some(ip_v4), Some(ip_v6));
assert!(preferred.next().unwrap().is_ipv4());
assert!(fallback.next().unwrap().is_ipv6());
let (mut preferred, mut fallback) = IpAddrs {
iter: vec![v6_addr, v4_addr].into_iter(),
}
.split_by_preference(Some(ip_v4), Some(ip_v6));
assert!(preferred.next().unwrap().is_ipv6());
assert!(fallback.next().unwrap().is_ipv4());
let (mut preferred, fallback) = IpAddrs {
iter: vec![v4_addr, v6_addr].into_iter(),
}
.split_by_preference(Some(v4_addr.ip()));
.split_by_preference(Some(ip_v4), None);
assert!(preferred.next().unwrap().is_ipv4());
assert!(fallback.is_empty());
let (mut preferred, fallback) = IpAddrs {
iter: vec![v4_addr, v6_addr].into_iter(),
}
.split_by_preference(Some(v6_addr.ip()));
.split_by_preference(None, Some(ip_v6));
assert!(preferred.next().unwrap().is_ipv6());
assert!(fallback.is_empty());
}

View File

@@ -3,7 +3,7 @@ use std::fmt;
use std::future::Future;
use std::io;
use std::marker::PhantomData;
use std::net::{IpAddr, SocketAddr};
use std::net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr};
use std::pin::Pin;
use std::sync::Arc;
use std::task::{self, Poll};
@@ -72,7 +72,8 @@ struct Config {
enforce_http: bool,
happy_eyeballs_timeout: Option<Duration>,
keep_alive_timeout: Option<Duration>,
local_address: Option<IpAddr>,
local_address_ipv4: Option<Ipv4Addr>,
local_address_ipv6: Option<Ipv6Addr>,
nodelay: bool,
reuse_address: bool,
send_buffer_size: Option<usize>,
@@ -111,7 +112,8 @@ impl<R> HttpConnector<R> {
enforce_http: true,
happy_eyeballs_timeout: Some(Duration::from_millis(300)),
keep_alive_timeout: None,
local_address: None,
local_address_ipv4: None,
local_address_ipv6: None,
nodelay: false,
reuse_address: false,
send_buffer_size: None,
@@ -166,7 +168,26 @@ impl<R> HttpConnector<R> {
/// Default is `None`.
#[inline]
pub fn set_local_address(&mut self, addr: Option<IpAddr>) {
self.config_mut().local_address = addr;
let (v4, v6) = match addr {
Some(IpAddr::V4(a)) => (Some(a), None),
Some(IpAddr::V6(a)) => (None, Some(a)),
_ => (None, None),
};
let cfg = self.config_mut();
cfg.local_address_ipv4 = v4;
cfg.local_address_ipv6 = v6;
}
/// Set that all sockets are bound to the configured IPv4 or IPv6 address (depending on host's
/// preferences) before connection.
#[inline]
pub fn set_local_addresses(&mut self, addr_ipv4: Ipv4Addr, addr_ipv6: Ipv6Addr) {
let cfg = self.config_mut();
cfg.local_address_ipv4 = Some(addr_ipv4);
cfg.local_address_ipv6 = Some(addr_ipv6);
}
/// Set the connect timeout.
@@ -311,7 +332,8 @@ where
};
let c = ConnectingTcp::new(
config.local_address,
config.local_address_ipv4,
config.local_address_ipv6,
addrs,
config.connect_timeout,
config.happy_eyeballs_timeout,
@@ -454,7 +476,8 @@ impl StdError for ConnectError {
}
struct ConnectingTcp {
local_addr: Option<IpAddr>,
local_addr_ipv4: Option<Ipv4Addr>,
local_addr_ipv6: Option<Ipv6Addr>,
preferred: ConnectingTcpRemote,
fallback: Option<ConnectingTcpFallback>,
reuse_address: bool,
@@ -462,17 +485,20 @@ struct ConnectingTcp {
impl ConnectingTcp {
fn new(
local_addr: Option<IpAddr>,
local_addr_ipv4: Option<Ipv4Addr>,
local_addr_ipv6: Option<Ipv6Addr>,
remote_addrs: dns::IpAddrs,
connect_timeout: Option<Duration>,
fallback_timeout: Option<Duration>,
reuse_address: bool,
) -> ConnectingTcp {
if let Some(fallback_timeout) = fallback_timeout {
let (preferred_addrs, fallback_addrs) = remote_addrs.split_by_preference(local_addr);
let (preferred_addrs, fallback_addrs) =
remote_addrs.split_by_preference(local_addr_ipv4, local_addr_ipv6);
if fallback_addrs.is_empty() {
return ConnectingTcp {
local_addr,
local_addr_ipv4,
local_addr_ipv6,
preferred: ConnectingTcpRemote::new(preferred_addrs, connect_timeout),
fallback: None,
reuse_address,
@@ -480,7 +506,8 @@ impl ConnectingTcp {
}
ConnectingTcp {
local_addr,
local_addr_ipv4,
local_addr_ipv6,
preferred: ConnectingTcpRemote::new(preferred_addrs, connect_timeout),
fallback: Some(ConnectingTcpFallback {
delay: tokio::time::delay_for(fallback_timeout),
@@ -490,7 +517,8 @@ impl ConnectingTcp {
}
} else {
ConnectingTcp {
local_addr,
local_addr_ipv4,
local_addr_ipv6,
preferred: ConnectingTcpRemote::new(remote_addrs, connect_timeout),
fallback: None,
reuse_address,
@@ -523,13 +551,22 @@ impl ConnectingTcpRemote {
impl ConnectingTcpRemote {
async fn connect(
&mut self,
local_addr: &Option<IpAddr>,
local_addr_ipv4: &Option<Ipv4Addr>,
local_addr_ipv6: &Option<Ipv6Addr>,
reuse_address: bool,
) -> io::Result<TcpStream> {
let mut err = None;
for addr in &mut self.addrs {
debug!("connecting to {}", addr);
match connect(&addr, local_addr, reuse_address, self.connect_timeout)?.await {
match connect(
&addr,
local_addr_ipv4,
local_addr_ipv6,
reuse_address,
self.connect_timeout,
)?
.await
{
Ok(tcp) => {
debug!("connected to {}", addr);
return Ok(tcp);
@@ -551,9 +588,38 @@ impl ConnectingTcpRemote {
}
}
fn bind_local_address(
socket: &socket2::Socket,
dst_addr: &SocketAddr,
local_addr_ipv4: &Option<Ipv4Addr>,
local_addr_ipv6: &Option<Ipv6Addr>,
) -> io::Result<()> {
match (*dst_addr, local_addr_ipv4, local_addr_ipv6) {
(SocketAddr::V4(_), Some(addr), _) => {
socket.bind(&SocketAddr::new(addr.clone().into(), 0).into())?;
}
(SocketAddr::V6(_), _, Some(addr)) => {
socket.bind(&SocketAddr::new(addr.clone().into(), 0).into())?;
}
_ => {
if cfg!(windows) {
// Windows requires a socket be bound before calling connect
let any: SocketAddr = match *dst_addr {
SocketAddr::V4(_) => ([0, 0, 0, 0], 0).into(),
SocketAddr::V6(_) => ([0, 0, 0, 0, 0, 0, 0, 0], 0).into(),
};
socket.bind(&any.into())?;
}
}
}
Ok(())
}
fn connect(
addr: &SocketAddr,
local_addr: &Option<IpAddr>,
local_addr_ipv4: &Option<Ipv4Addr>,
local_addr_ipv6: &Option<Ipv6Addr>,
reuse_address: bool,
connect_timeout: Option<Duration>,
) -> io::Result<impl Future<Output = io::Result<TcpStream>>> {
@@ -568,17 +634,7 @@ fn connect(
socket.set_reuse_address(true)?;
}
if let Some(ref local_addr) = *local_addr {
// Caller has requested this socket be bound before calling connect
socket.bind(&SocketAddr::new(local_addr.clone(), 0).into())?;
} else if cfg!(windows) {
// Windows requires a socket be bound before calling connect
let any: SocketAddr = match *addr {
SocketAddr::V4(_) => ([0, 0, 0, 0], 0).into(),
SocketAddr::V6(_) => ([0, 0, 0, 0, 0, 0, 0, 0], 0).into(),
};
socket.bind(&any.into())?;
}
bind_local_address(&socket, addr, local_addr_ipv4, local_addr_ipv6)?;
let addr = *addr;
@@ -600,17 +656,27 @@ fn connect(
impl ConnectingTcp {
async fn connect(mut self) -> io::Result<TcpStream> {
let Self {
ref local_addr,
ref local_addr_ipv4,
ref local_addr_ipv6,
reuse_address,
..
} = self;
match self.fallback {
None => self.preferred.connect(local_addr, reuse_address).await,
None => {
self.preferred
.connect(local_addr_ipv4, local_addr_ipv6, reuse_address)
.await
}
Some(mut fallback) => {
let preferred_fut = self.preferred.connect(local_addr, reuse_address);
let preferred_fut =
self.preferred
.connect(local_addr_ipv4, local_addr_ipv6, reuse_address);
futures_util::pin_mut!(preferred_fut);
let fallback_fut = fallback.remote.connect(local_addr, reuse_address);
let fallback_fut =
fallback
.remote
.connect(local_addr_ipv4, local_addr_ipv6, reuse_address);
futures_util::pin_mut!(fallback_fut);
let (result, future) =
@@ -666,6 +732,32 @@ mod tests {
assert_eq!(&*err.msg, super::INVALID_NOT_HTTP);
}
#[cfg(any(target_os = "linux", target_os = "macos"))]
fn get_local_ips() -> (Option<std::net::Ipv4Addr>, Option<std::net::Ipv6Addr>) {
use std::net::{IpAddr, TcpListener};
let mut ip_v4 = None;
let mut ip_v6 = None;
let ips = pnet::datalink::interfaces()
.into_iter()
.flat_map(|i| i.ips.into_iter().map(|n| n.ip()));
for ip in ips {
match ip {
IpAddr::V4(ip) if TcpListener::bind((ip, 0)).is_ok() => ip_v4 = Some(ip),
IpAddr::V6(ip) if TcpListener::bind((ip, 0)).is_ok() => ip_v6 = Some(ip),
_ => (),
}
if ip_v4.is_some() && ip_v6.is_some() {
break;
}
}
(ip_v4, ip_v6)
}
#[tokio::test]
async fn test_errors_missing_scheme() {
let dst = "example.domain".parse().unwrap();
@@ -676,6 +768,43 @@ mod tests {
assert_eq!(&*err.msg, super::INVALID_MISSING_SCHEME);
}
// NOTE: pnet crate that we use in this test doesn't compile on Windows
#[cfg(any(target_os = "linux", target_os = "macos"))]
#[tokio::test]
async fn local_address() {
use std::net::{IpAddr, TcpListener};
let (bind_ip_v4, bind_ip_v6) = get_local_ips();
let server4 = TcpListener::bind("127.0.0.1:0").unwrap();
let port = server4.local_addr().unwrap().port();
let server6 = TcpListener::bind(&format!("[::1]:{}", port)).unwrap();
let assert_client_ip = |dst: String, server: TcpListener, expected_ip: IpAddr| async move {
let mut connector = HttpConnector::new();
match (bind_ip_v4, bind_ip_v6) {
(Some(v4), Some(v6)) => connector.set_local_addresses(v4, v6),
(Some(v4), None) => connector.set_local_address(Some(v4.into())),
(None, Some(v6)) => connector.set_local_address(Some(v6.into())),
_ => unreachable!(),
}
connect(connector, dst.parse().unwrap()).await.unwrap();
let (_, client_addr) = server.accept().unwrap();
assert_eq!(client_addr.ip(), expected_ip);
};
if let Some(ip) = bind_ip_v4 {
assert_client_ip(format!("http://127.0.0.1:{}", port), server4, ip.into()).await;
}
if let Some(ip) = bind_ip_v6 {
assert_client_ip(format!("http://[::1]:{}", port), server6, ip.into()).await;
}
}
#[test]
#[cfg_attr(not(feature = "__internal_happy_eyeballs_tests"), ignore)]
fn client_happy_eyeballs() {
@@ -797,6 +926,7 @@ mod tests {
.map(|host| (host.clone(), addr.port()).into())
.collect();
let connecting_tcp = ConnectingTcp::new(
None,
None,
dns::IpAddrs::new(addrs),
None,