diff --git a/src/async_impl/client.rs b/src/async_impl/client.rs index 3b3ef85..5230533 100644 --- a/src/async_impl/client.rs +++ b/src/async_impl/client.rs @@ -106,6 +106,7 @@ struct Config { cookie_store: Option, trust_dns: bool, error: Option, + https_only: bool, } impl Default for ClientBuilder { @@ -157,6 +158,7 @@ impl ClientBuilder { trust_dns: cfg!(feature = "trust-dns"), #[cfg(feature = "cookies")] cookie_store: None, + https_only: false, }, } } @@ -349,6 +351,7 @@ impl ClientBuilder { request_timeout: config.timeout, proxies, proxies_maybe_http_auth, + https_only: config.https_only, }), }) } @@ -917,6 +920,14 @@ impl ClientBuilder { self } } + + /// Restrict the Client to be used with HTTPS only requests. + /// + /// Defaults to false. + pub fn https_only(mut self, enabled: bool) -> ClientBuilder { + self.config.https_only = enabled; + self + } } type HyperClient = hyper::Client; @@ -1040,6 +1051,11 @@ impl Client { return Pending::new_err(error::url_bad_scheme(url)); } + // check if we're in https_only mode and check the scheme of the current URL + if self.inner.https_only && url.scheme() != "https" { + return Pending::new_err(error::url_bad_scheme(url)); + } + // insert default headers in the request headers // without overwriting already appended headers. for (key, value) in &self.inner.headers { @@ -1238,6 +1254,7 @@ struct ClientRef { request_timeout: Option, proxies: Arc>, proxies_maybe_http_auth: bool, + https_only: bool, } impl ClientRef { diff --git a/src/blocking/client.rs b/src/blocking/client.rs index 60f6992..d1fbe41 100644 --- a/src/blocking/client.rs +++ b/src/blocking/client.rs @@ -573,6 +573,13 @@ impl ClientBuilder { self.with_inner(|inner| inner.no_trust_dns()) } + /// Restrict the Client to be used with HTTPS only requests. + /// + /// Defaults to false. + pub fn https_only(self, enabled: bool) -> ClientBuilder { + self.with_inner(|inner| inner.https_only(enabled)) + } + // private fn with_inner(mut self, func: F) -> ClientBuilder diff --git a/tests/blocking.rs b/tests/blocking.rs index ccb5bdf..aa1c2a2 100644 --- a/tests/blocking.rs +++ b/tests/blocking.rs @@ -288,3 +288,25 @@ fn test_blocking_inside_a_runtime() { let _should_panic = reqwest::blocking::get(&url); }); } + +#[cfg(feature = "default-tls")] +#[test] +fn test_allowed_methods_blocking() { + let resp = reqwest::blocking::Client::builder() + .https_only(true) + .build() + .expect("client builder") + .get("https://google.com") + .send(); + + assert_eq!(resp.is_err(), false); + + let resp = reqwest::blocking::Client::builder() + .https_only(true) + .build() + .expect("client builder") + .get("http://google.com") + .send(); + + assert_eq!(resp.is_err(), true); +} diff --git a/tests/client.rs b/tests/client.rs index 40febef..19e90d6 100644 --- a/tests/client.rs +++ b/tests/client.rs @@ -199,3 +199,27 @@ fn use_preconfigured_rustls_default() { .build() .expect("preconfigured rustls tls"); } + +#[cfg(feature = "default-tls")] +#[tokio::test] +async fn test_allowed_methods() { + let resp = reqwest::Client::builder() + .https_only(true) + .build() + .expect("client builder") + .get("https://google.com") + .send() + .await; + + assert_eq!(resp.is_err(), false); + + let resp = reqwest::Client::builder() + .https_only(true) + .build() + .expect("client builder") + .get("http://google.com") + .send() + .await; + + assert_eq!(resp.is_err(), true); +}