Merge branch 'http-connector-enforce-scheme'
This commit is contained in:
		| @@ -1,3 +1,4 @@ | ||||
| use std::error::Error as StdError; | ||||
| use std::fmt; | ||||
| use std::io; | ||||
| //use std::net::SocketAddr; | ||||
| @@ -42,6 +43,7 @@ where T: Service<Request=Uri, Error=io::Error> + 'static, | ||||
| #[derive(Clone)] | ||||
| pub struct HttpConnector { | ||||
|     dns: dns::Dns, | ||||
|     enforce_http: bool, | ||||
|     handle: Handle, | ||||
| } | ||||
|  | ||||
| @@ -50,15 +52,26 @@ impl HttpConnector { | ||||
|     /// Construct a new HttpConnector. | ||||
|     /// | ||||
|     /// Takes number of DNS worker threads. | ||||
|     #[inline] | ||||
|     pub fn new(threads: usize, handle: &Handle) -> HttpConnector { | ||||
|         HttpConnector { | ||||
|             dns: dns::Dns::new(threads), | ||||
|             enforce_http: true, | ||||
|             handle: handle.clone(), | ||||
|         } | ||||
|     } | ||||
|  | ||||
|     /// Option to enforce all `Uri`s have the `http` scheme. | ||||
|     /// | ||||
|     /// Enabled by default. | ||||
|     #[inline] | ||||
|     pub fn enforce_http(&mut self, is_enforced: bool) { | ||||
|         self.enforce_http = is_enforced; | ||||
|     } | ||||
| } | ||||
|  | ||||
| impl fmt::Debug for HttpConnector { | ||||
|     #[inline] | ||||
|     fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { | ||||
|         f.debug_struct("HttpConnector") | ||||
|             .finish() | ||||
| @@ -73,12 +86,18 @@ impl Service for HttpConnector { | ||||
|  | ||||
|     fn call(&self, uri: Uri) -> Self::Future { | ||||
|         debug!("Http::connect({:?})", uri); | ||||
|  | ||||
|         if self.enforce_http { | ||||
|             if uri.scheme() != Some("http") { | ||||
|                 return invalid_url(InvalidUrl::NotHttp, &self.handle); | ||||
|             } | ||||
|         } else if uri.scheme().is_none() { | ||||
|             return invalid_url(InvalidUrl::MissingScheme, &self.handle); | ||||
|         } | ||||
|  | ||||
|         let host = match uri.host() { | ||||
|             Some(s) => s, | ||||
|             None => return HttpConnecting { | ||||
|                 state: State::Error(Some(io::Error::new(io::ErrorKind::InvalidInput, "invalid url"))), | ||||
|                 handle: self.handle.clone(), | ||||
|             }, | ||||
|             None => return invalid_url(InvalidUrl::MissingAuthority, &self.handle), | ||||
|         }; | ||||
|         let port = match uri.port() { | ||||
|             Some(port) => port, | ||||
| @@ -94,7 +113,37 @@ impl Service for HttpConnector { | ||||
|             handle: self.handle.clone(), | ||||
|         } | ||||
|     } | ||||
| } | ||||
|  | ||||
| #[inline] | ||||
| fn invalid_url(err: InvalidUrl, handle: &Handle) -> HttpConnecting { | ||||
|     HttpConnecting { | ||||
|         state: State::Error(Some(io::Error::new(io::ErrorKind::InvalidInput, err))), | ||||
|         handle: handle.clone(), | ||||
|     } | ||||
| } | ||||
|  | ||||
| #[derive(Debug, Clone, Copy)] | ||||
| enum InvalidUrl { | ||||
|     MissingScheme, | ||||
|     NotHttp, | ||||
|     MissingAuthority, | ||||
| } | ||||
|  | ||||
| impl fmt::Display for InvalidUrl { | ||||
|     fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { | ||||
|         f.write_str(self.description()) | ||||
|     } | ||||
| } | ||||
|  | ||||
| impl StdError for InvalidUrl { | ||||
|     fn description(&self) -> &str { | ||||
|         match *self { | ||||
|             InvalidUrl::MissingScheme => "invalid URL, missing scheme", | ||||
|             InvalidUrl::NotHttp => "invalid URL, scheme must be http", | ||||
|             InvalidUrl::MissingAuthority => "invalid URL, missing domain", | ||||
|         } | ||||
|     } | ||||
| } | ||||
|  | ||||
| /// A Future representing work to connect to a URL. | ||||
| @@ -195,7 +244,7 @@ mod tests { | ||||
|     use super::{Connect, HttpConnector}; | ||||
|  | ||||
|     #[test] | ||||
|     fn test_non_http_url() { | ||||
|     fn test_errors_missing_authority() { | ||||
|         let mut core = Core::new().unwrap(); | ||||
|         let url = "/foo/bar?baz".parse().unwrap(); | ||||
|         let connector = HttpConnector::new(1, &core.handle()); | ||||
| @@ -203,4 +252,22 @@ mod tests { | ||||
|         assert_eq!(core.run(connector.connect(url)).unwrap_err().kind(), io::ErrorKind::InvalidInput); | ||||
|     } | ||||
|  | ||||
|     #[test] | ||||
|     fn test_errors_enforce_http() { | ||||
|         let mut core = Core::new().unwrap(); | ||||
|         let url = "https://example.domain/foo/bar?baz".parse().unwrap(); | ||||
|         let connector = HttpConnector::new(1, &core.handle()); | ||||
|  | ||||
|         assert_eq!(core.run(connector.connect(url)).unwrap_err().kind(), io::ErrorKind::InvalidInput); | ||||
|     } | ||||
|  | ||||
|  | ||||
|     #[test] | ||||
|     fn test_errors_missing_scheme() { | ||||
|         let mut core = Core::new().unwrap(); | ||||
|         let url = "example.domain".parse().unwrap(); | ||||
|         let connector = HttpConnector::new(1, &core.handle()); | ||||
|  | ||||
|         assert_eq!(core.run(connector.connect(url)).unwrap_err().kind(), io::ErrorKind::InvalidInput); | ||||
|     } | ||||
| } | ||||
|   | ||||
| @@ -95,6 +95,7 @@ impl Uri { | ||||
|     } | ||||
|  | ||||
|     /// Get the path of this `Uri`. | ||||
|     #[inline] | ||||
|     pub fn path(&self) -> &str { | ||||
|         let index = self.path_start(); | ||||
|         let end = self.path_end(); | ||||
| @@ -135,6 +136,7 @@ impl Uri { | ||||
|     } | ||||
|  | ||||
|     /// Get the scheme of this `Uri`. | ||||
|     #[inline] | ||||
|     pub fn scheme(&self) -> Option<&str> { | ||||
|         if let Some(end) = self.scheme_end { | ||||
|             Some(&self.source[..end]) | ||||
| @@ -144,6 +146,7 @@ impl Uri { | ||||
|     } | ||||
|  | ||||
|     /// Get the authority of this `Uri`. | ||||
|     #[inline] | ||||
|     pub fn authority(&self) -> Option<&str> { | ||||
|         if let Some(end) = self.authority_end { | ||||
|             let index = self.scheme_end.map(|i| i + 3).unwrap_or(0); | ||||
| @@ -155,6 +158,7 @@ impl Uri { | ||||
|     } | ||||
|  | ||||
|     /// Get the host of this `Uri`. | ||||
|     #[inline] | ||||
|     pub fn host(&self) -> Option<&str> { | ||||
|         if let Some(auth) = self.authority() { | ||||
|             auth.split(":").next() | ||||
| @@ -164,6 +168,7 @@ impl Uri { | ||||
|     } | ||||
|  | ||||
|     /// Get the port of this `Uri`. | ||||
|     #[inline] | ||||
|     pub fn port(&self) -> Option<u16> { | ||||
|         match self.authority() { | ||||
|             Some(auth) => auth.find(":").and_then(|i| u16::from_str(&auth[i+1..]).ok()), | ||||
| @@ -172,6 +177,7 @@ impl Uri { | ||||
|     } | ||||
|  | ||||
|     /// Get the query string of this `Uri`, starting after the `?`. | ||||
|     #[inline] | ||||
|     pub fn query(&self) -> Option<&str> { | ||||
|         self.query_start.map(|start| { | ||||
|             // +1 to remove '?' | ||||
|   | ||||
		Reference in New Issue
	
	Block a user