diff --git a/src/async_impl/client.rs b/src/async_impl/client.rs index 5259149..a57a7aa 100644 --- a/src/async_impl/client.rs +++ b/src/async_impl/client.rs @@ -66,6 +66,9 @@ struct Config { hostname_verification: bool, #[cfg(feature = "tls")] certs_verification: bool, + connect_timeout: Option, + #[cfg(feature = "tls")] + identity: Option, proxies: Vec, redirect_policy: RedirectPolicy, referer: bool, @@ -73,8 +76,6 @@ struct Config { #[cfg(feature = "tls")] root_certs: Vec, #[cfg(feature = "tls")] - identity: Option, - #[cfg(feature = "tls")] tls: TlsBackend, http2_only: bool, local_address: Option, @@ -97,6 +98,7 @@ impl ClientBuilder { hostname_verification: true, #[cfg(feature = "tls")] certs_verification: true, + connect_timeout: None, proxies: Vec::new(), redirect_policy: RedirectPolicy::default(), referer: true, @@ -123,7 +125,7 @@ impl ClientBuilder { let config = self.config; let proxies = Arc::new(config.proxies); - let connector = { + let mut connector = { #[cfg(feature = "tls")] match config.tls { #[cfg(feature = "default-tls")] @@ -177,6 +179,8 @@ impl ClientBuilder { Connector::new(proxies.clone(), config.local_address)? }; + connector.set_timeout(config.connect_timeout); + let mut builder = ::hyper::Client::builder(); if config.http2_only { builder.http2_only(true); @@ -312,7 +316,8 @@ impl ClientBuilder { self } - /// Set a timeout for both the read and write operations of a client. + // Currently not used, so hide from docs. + #[doc(hidden)] pub fn timeout(mut self, timeout: Duration) -> ClientBuilder { self.config.timeout = Some(timeout); self @@ -324,6 +329,19 @@ impl ClientBuilder { self } + /// Set a timeout for only the connect phase of a `Client`. + /// + /// Default is `None`. + /// + /// # Note + /// + /// This **requires** the futures be executed in a tokio runtime with + /// a tokio timer enabled. + pub fn connect_timeout(mut self, timeout: Duration) -> ClientBuilder { + self.config.connect_timeout = Some(timeout); + self + } + #[doc(hidden)] #[deprecated(note = "DNS no longer uses blocking threads")] pub fn dns_threads(self, _threads: usize) -> ClientBuilder { diff --git a/src/client.rs b/src/client.rs index 455ba39..16a60b7 100644 --- a/src/client.rs +++ b/src/client.rs @@ -281,12 +281,28 @@ impl ClientBuilder { /// /// Pass `None` to disable timeout. pub fn timeout(mut self, timeout: T) -> ClientBuilder - where T: Into>, + where + T: Into>, { self.timeout = Timeout(timeout.into()); self } + /// Set a timeout for only the connect phase of a `Client`. + /// + /// Default is `None`. + pub fn connect_timeout(self, timeout: T) -> ClientBuilder + where + T: Into>, + { + let timeout = timeout.into(); + if let Some(dur) = timeout { + self.with_inner(|inner| inner.connect_timeout(dur)) + } else { + self + } + } + fn with_inner(mut self, func: F) -> ClientBuilder where F: FnOnce(async_impl::ClientBuilder) -> async_impl::ClientBuilder, diff --git a/src/connect.rs b/src/connect.rs index c4e7eb7..e6ce468 100644 --- a/src/connect.rs +++ b/src/connect.rs @@ -2,6 +2,7 @@ use futures::Future; use http::uri::Scheme; use hyper::client::connect::{Connect, Connected, Destination}; use tokio_io::{AsyncRead, AsyncWrite}; +use tokio_timer::Timeout; #[cfg(feature = "default-tls")] @@ -14,6 +15,7 @@ use bytes::BufMut; use std::io; use std::sync::Arc; use std::net::IpAddr; +use std::time::Duration; #[cfg(feature = "trust-dns")] use dns::TrustDnsResolver; @@ -26,8 +28,9 @@ type HttpConnector = ::hyper::client::HttpConnector; pub(crate) struct Connector { + inner: Inner, proxies: Arc>, - inner: Inner + timeout: Option, } enum Inner { @@ -49,8 +52,9 @@ impl Connector { let mut http = http_connector()?; http.set_local_address(local_addr.into()); Ok(Connector { + inner: Inner::Http(http), proxies, - inner: Inner::Http(http) + timeout: None, }) } @@ -70,8 +74,9 @@ impl Connector { let http = ::hyper_tls::HttpsConnector::from((http, tls.clone())); Ok(Connector { + inner: Inner::DefaultTls(http, tls), proxies, - inner: Inner::DefaultTls(http, tls) + timeout: None, }) } @@ -89,10 +94,15 @@ impl Connector { let http = ::hyper_rustls::HttpsConnector::from((http, tls.clone())); Ok(Connector { + inner: Inner::RustlsTls(http, Arc::new(tls)), proxies, - inner: Inner::RustlsTls(http, Arc::new(tls)) + timeout: None, }) } + + pub(crate) fn set_timeout(&mut self, timeout: Option) { + self.timeout = timeout; + } } #[cfg(feature = "trust-dns")] @@ -113,9 +123,27 @@ impl Connect for Connector { type Future = Connecting; fn connect(&self, dst: Destination) -> Self::Future { + macro_rules! timeout { + ($future:expr) => { + if let Some(dur) = self.timeout { + Box::new(Timeout::new($future, dur).map_err(|err| { + if err.is_inner() { + err.into_inner().expect("is_inner") + } else if err.is_elapsed() { + io::Error::new(io::ErrorKind::TimedOut, "connect timed out") + } else { + io::Error::new(io::ErrorKind::Other, err) + } + })) + } else { + Box::new($future) + } + } + } + macro_rules! connect { ( $http:expr, $dst:expr, $proxy:expr ) => { - Box::new($http.connect($dst) + timeout!($http.connect($dst) .map(|(io, connected)| (Box::new(io) as Conn, connected.proxy($proxy)))) }; ( $dst:expr, $proxy:expr ) => { @@ -158,7 +186,7 @@ impl Connect for Connector { let host = dst.host().to_owned(); let port = dst.port().unwrap_or(443); let tls = tls.clone(); - return Box::new(http.connect(ndst).and_then(move |(conn, connected)| { + return timeout!(http.connect(ndst).and_then(move |(conn, connected)| { trace!("tunneling HTTPS over proxy"); tunnel(conn, host.clone(), port, auth) .and_then(move |tunneled| { @@ -178,7 +206,7 @@ impl Connect for Connector { let host = dst.host().to_owned(); let port = dst.port().unwrap_or(443); let tls = tls.clone(); - return Box::new(http.connect(ndst).and_then(move |(conn, connected)| { + return timeout!(http.connect(ndst).and_then(move |(conn, connected)| { trace!("tunneling HTTPS over proxy"); let maybe_dnsname = DNSNameRef::try_from_ascii_str(&host) .map(|dnsname| dnsname.to_owned()) diff --git a/src/lib.rs b/src/lib.rs index 0c9c2fc..c8bd38d 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -198,6 +198,7 @@ extern crate serde_urlencoded; extern crate tokio; #[cfg_attr(feature = "default-tls", macro_use)] extern crate tokio_io; +extern crate tokio_timer; #[cfg(feature = "trust-dns")] extern crate trust_dns_resolver; extern crate url;