From 1c34a05a85078421078f2cb266dccc5dfce8a9f0 Mon Sep 17 00:00:00 2001 From: Sean McArthur Date: Mon, 24 Apr 2017 11:39:02 -0700 Subject: [PATCH] feat(client): add `HttpConnector.enforce_http` This will make the `HttpConnector` require the `scheme` to be `http`, and return an error otherwise. This value is enabled by default, so any requests to URLs that aren't of scheme `http` will now see an error message stating the failure. When constructing a connector that wraps an `HttpConnector`, this enforcement can be disabled to allow connecting over TCP easily even when the scheme is not `http`. To do, call `connector.enforce_http(false)`. --- src/client/connect.rs | 77 ++++++++++++++++++++++++++++++++++++++++--- 1 file changed, 72 insertions(+), 5 deletions(-) diff --git a/src/client/connect.rs b/src/client/connect.rs index fb15e2e8..ec08d7df 100644 --- a/src/client/connect.rs +++ b/src/client/connect.rs @@ -1,3 +1,4 @@ +use std::error::Error as StdError; use std::fmt; use std::io; //use std::net::SocketAddr; @@ -42,6 +43,7 @@ where T: Service + 'static, #[derive(Clone)] pub struct HttpConnector { dns: dns::Dns, + enforce_http: bool, handle: Handle, } @@ -50,15 +52,26 @@ impl HttpConnector { /// Construct a new HttpConnector. /// /// Takes number of DNS worker threads. + #[inline] pub fn new(threads: usize, handle: &Handle) -> HttpConnector { HttpConnector { dns: dns::Dns::new(threads), + enforce_http: true, handle: handle.clone(), } } + + /// Option to enforce all `Uri`s have the `http` scheme. + /// + /// Enabled by default. + #[inline] + pub fn enforce_http(&mut self, is_enforced: bool) { + self.enforce_http = is_enforced; + } } impl fmt::Debug for HttpConnector { + #[inline] fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { f.debug_struct("HttpConnector") .finish() @@ -73,12 +86,18 @@ impl Service for HttpConnector { fn call(&self, uri: Uri) -> Self::Future { debug!("Http::connect({:?})", uri); + + if self.enforce_http { + if uri.scheme() != Some("http") { + return invalid_url(InvalidUrl::NotHttp, &self.handle); + } + } else if uri.scheme().is_none() { + return invalid_url(InvalidUrl::MissingScheme, &self.handle); + } + let host = match uri.host() { Some(s) => s, - None => return HttpConnecting { - state: State::Error(Some(io::Error::new(io::ErrorKind::InvalidInput, "invalid url"))), - handle: self.handle.clone(), - }, + None => return invalid_url(InvalidUrl::MissingAuthority, &self.handle), }; let port = match uri.port() { Some(port) => port, @@ -94,7 +113,37 @@ impl Service for HttpConnector { handle: self.handle.clone(), } } +} +#[inline] +fn invalid_url(err: InvalidUrl, handle: &Handle) -> HttpConnecting { + HttpConnecting { + state: State::Error(Some(io::Error::new(io::ErrorKind::InvalidInput, err))), + handle: handle.clone(), + } +} + +#[derive(Debug, Clone, Copy)] +enum InvalidUrl { + MissingScheme, + NotHttp, + MissingAuthority, +} + +impl fmt::Display for InvalidUrl { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + f.write_str(self.description()) + } +} + +impl StdError for InvalidUrl { + fn description(&self) -> &str { + match *self { + InvalidUrl::MissingScheme => "invalid URL, missing scheme", + InvalidUrl::NotHttp => "invalid URL, scheme must be http", + InvalidUrl::MissingAuthority => "invalid URL, missing domain", + } + } } /// A Future representing work to connect to a URL. @@ -195,7 +244,7 @@ mod tests { use super::{Connect, HttpConnector}; #[test] - fn test_non_http_url() { + fn test_errors_missing_authority() { let mut core = Core::new().unwrap(); let url = "/foo/bar?baz".parse().unwrap(); let connector = HttpConnector::new(1, &core.handle()); @@ -203,4 +252,22 @@ mod tests { assert_eq!(core.run(connector.connect(url)).unwrap_err().kind(), io::ErrorKind::InvalidInput); } + #[test] + fn test_errors_enforce_http() { + let mut core = Core::new().unwrap(); + let url = "https://example.domain/foo/bar?baz".parse().unwrap(); + let connector = HttpConnector::new(1, &core.handle()); + + assert_eq!(core.run(connector.connect(url)).unwrap_err().kind(), io::ErrorKind::InvalidInput); + } + + + #[test] + fn test_errors_missing_scheme() { + let mut core = Core::new().unwrap(); + let url = "example.domain".parse().unwrap(); + let connector = HttpConnector::new(1, &core.handle()); + + assert_eq!(core.run(connector.connect(url)).unwrap_err().kind(), io::ErrorKind::InvalidInput); + } }