Add ClientBuilder::local_address option to bind to a local IP address (#451)

Closes #414
This commit is contained in:
Michael Habib
2019-02-11 10:40:16 -08:00
committed by Sean McArthur
parent 8ed9e60351
commit 4dc679d535
3 changed files with 60 additions and 7 deletions

View File

@@ -1,6 +1,7 @@
use std::{fmt, str}; use std::{fmt, str};
use std::sync::Arc; use std::sync::Arc;
use std::time::Duration; use std::time::Duration;
use std::net::IpAddr;
use bytes::Bytes; use bytes::Bytes;
use futures::{Async, Future, Poll}; use futures::{Async, Future, Poll};
@@ -76,6 +77,7 @@ struct Config {
#[cfg(feature = "tls")] #[cfg(feature = "tls")]
tls: TlsBackend, tls: TlsBackend,
http2_only: bool, http2_only: bool,
local_address: Option<IpAddr>,
} }
impl ClientBuilder { impl ClientBuilder {
@@ -106,6 +108,7 @@ impl ClientBuilder {
#[cfg(feature = "tls")] #[cfg(feature = "tls")]
tls: TlsBackend::default(), tls: TlsBackend::default(),
http2_only: false, http2_only: false,
local_address: None,
}, },
} }
} }
@@ -137,7 +140,7 @@ impl ClientBuilder {
id.add_to_native_tls(&mut tls)?; id.add_to_native_tls(&mut tls)?;
} }
Connector::new_default_tls(tls, proxies.clone())? Connector::new_default_tls(tls, proxies.clone(), config.local_address)?
}, },
#[cfg(feature = "rustls-tls")] #[cfg(feature = "rustls-tls")]
TlsBackend::Rustls => { TlsBackend::Rustls => {
@@ -166,18 +169,19 @@ impl ClientBuilder {
id.add_to_rustls(&mut tls)?; id.add_to_rustls(&mut tls)?;
} }
Connector::new_rustls_tls(tls, proxies.clone())? Connector::new_rustls_tls(tls, proxies.clone(), config.local_address)?
} }
} }
#[cfg(not(feature = "tls"))] #[cfg(not(feature = "tls"))]
Connector::new(proxies.clone())? Connector::new(proxies.clone(), config.local_address)?
}; };
let mut builder = ::hyper::Client::builder(); let mut builder = ::hyper::Client::builder();
if config.http2_only { if config.http2_only {
builder.http2_only(true); builder.http2_only(true);
} }
let hyper_client = builder.build(connector); let hyper_client = builder.build(connector);
let proxies_maybe_http_auth = proxies let proxies_maybe_http_auth = proxies
@@ -325,6 +329,15 @@ impl ClientBuilder {
pub fn dns_threads(self, _threads: usize) -> ClientBuilder { pub fn dns_threads(self, _threads: usize) -> ClientBuilder {
self self
} }
/// Bind to a local IP Address
pub fn local_address<T>(mut self, addr: T) -> ClientBuilder
where
T: Into<Option<IpAddr>>,
{
self.config.local_address = addr.into();
self
}
} }
type HyperClient = ::hyper::Client<Connector>; type HyperClient = ::hyper::Client<Connector>;

View File

@@ -2,6 +2,7 @@ use std::fmt;
use std::sync::Arc; use std::sync::Arc;
use std::time::Duration; use std::time::Duration;
use std::thread; use std::thread;
use std::net::IpAddr;
use futures::{Async, Future, Stream}; use futures::{Async, Future, Stream};
use futures::future::{self, Either}; use futures::future::{self, Either};
@@ -306,6 +307,24 @@ impl ClientBuilder {
pub fn h2_prior_knowledge(self) -> ClientBuilder { pub fn h2_prior_knowledge(self) -> ClientBuilder {
self.with_inner(|inner| inner.h2_prior_knowledge()) self.with_inner(|inner| inner.h2_prior_knowledge())
} }
/// Bind to a local IP Address
///
/// # Example
///
/// ```
/// use std::net::IpAddr;
/// let local_addr = IpAddr::from([12, 4, 1, 8]);
/// let client = reqwest::Client::builder()
/// .local_address(local_addr)
/// .build().unwrap();
/// ```
pub fn local_address<T>(self, addr: T) -> ClientBuilder
where
T: Into<Option<IpAddr>>,
{
self.with_inner(move |inner| inner.local_address(addr))
}
} }

View File

@@ -3,6 +3,7 @@ use http::uri::Scheme;
use hyper::client::connect::{Connect, Connected, Destination}; use hyper::client::connect::{Connect, Connected, Destination};
use tokio_io::{AsyncRead, AsyncWrite}; use tokio_io::{AsyncRead, AsyncWrite};
#[cfg(feature = "default-tls")] #[cfg(feature = "default-tls")]
use native_tls::{TlsConnector, TlsConnectorBuilder}; use native_tls::{TlsConnector, TlsConnectorBuilder};
#[cfg(feature = "tls")] #[cfg(feature = "tls")]
@@ -12,6 +13,7 @@ use bytes::BufMut;
use std::io; use std::io;
use std::sync::Arc; use std::sync::Arc;
use std::net::IpAddr;
#[cfg(feature = "trust-dns")] #[cfg(feature = "trust-dns")]
use dns::TrustDnsResolver; use dns::TrustDnsResolver;
@@ -39,8 +41,13 @@ enum Inner {
impl Connector { impl Connector {
#[cfg(not(feature = "tls"))] #[cfg(not(feature = "tls"))]
pub(crate) fn new(proxies: Arc<Vec<Proxy>>) -> ::Result<Connector> { pub(crate) fn new<T>(proxies: Arc<Vec<Proxy>>, local_addr: T) -> ::Result<Connector>
let http = http_connector()?; where
T: Into<Option<IpAddr>>
{
let mut http = http_connector()?;
http.set_local_address(local_addr.into());
Ok(Connector { Ok(Connector {
proxies, proxies,
inner: Inner::Http(http) inner: Inner::Http(http)
@@ -48,10 +55,17 @@ impl Connector {
} }
#[cfg(feature = "default-tls")] #[cfg(feature = "default-tls")]
pub(crate) fn new_default_tls(tls: TlsConnectorBuilder, proxies: Arc<Vec<Proxy>>) -> ::Result<Connector> { pub(crate) fn new_default_tls<T>(
tls: TlsConnectorBuilder,
proxies: Arc<Vec<Proxy>>,
local_addr: T) -> ::Result<Connector>
where
T: Into<Option<IpAddr>>,
{
let tls = try_!(tls.build()); let tls = try_!(tls.build());
let mut http = http_connector()?; let mut http = http_connector()?;
http.set_local_address(local_addr.into());
http.enforce_http(false); http.enforce_http(false);
let http = ::hyper_tls::HttpsConnector::from((http, tls.clone())); let http = ::hyper_tls::HttpsConnector::from((http, tls.clone()));
@@ -62,8 +76,15 @@ impl Connector {
} }
#[cfg(feature = "rustls-tls")] #[cfg(feature = "rustls-tls")]
pub(crate) fn new_rustls_tls(tls: rustls::ClientConfig, proxies: Arc<Vec<Proxy>>) -> ::Result<Connector> { pub(crate) fn new_rustls_tls<T>(
tls: rustls::ClientConfig,
proxies: Arc<Vec<Proxy>>,
local_addr: T) -> ::Result<Connector>
where
T: Into<Option<IpAddr>>,
{
let mut http = http_connector()?; let mut http = http_connector()?;
http.set_local_address(local_addr.into());
http.enforce_http(false); http.enforce_http(false);
let http = ::hyper_rustls::HttpsConnector::from((http, tls.clone())); let http = ::hyper_rustls::HttpsConnector::from((http, tls.clone()));