From e92b3e862a1a94c0b4173a7d49a315bc121da31e Mon Sep 17 00:00:00 2001 From: Sean McArthur Date: Sat, 10 Dec 2016 11:36:22 -0800 Subject: [PATCH] add support for defining RedirectPolicy for a Client --- src/client.rs | 51 ++++++++++------ src/error.rs | 9 ++- src/lib.rs | 2 + src/redirect.rs | 159 +++++++++++++++++++++++++++++++++++++++++++++++- tests/client.rs | 53 ++++++++++++++++ 5 files changed, 250 insertions(+), 24 deletions(-) diff --git a/src/client.rs b/src/client.rs index 39113ef..ace7c89 100644 --- a/src/client.rs +++ b/src/client.rs @@ -1,6 +1,6 @@ use std::fmt; use std::io::{self, Read}; -use std::sync::Arc; +use std::sync::{Arc, Mutex}; use hyper::client::IntoUrl; use hyper::header::{Headers, ContentType, Location, Referer, UserAgent}; @@ -14,6 +14,7 @@ use serde_json; use serde_urlencoded; use ::body::{self, Body}; +use ::redirect::{RedirectPolicy, check_redirect}; static DEFAULT_USER_AGENT: &'static str = concat!(env!("CARGO_PKG_NAME"), "/", env!("CARGO_PKG_VERSION")); @@ -24,8 +25,9 @@ static DEFAULT_USER_AGENT: &'static str = concat!(env!("CARGO_PKG_NAME"), "/", e /// /// The `Client` holds a connection pool internally, so it is advised that /// you create one and reuse it. +#[derive(Clone)] pub struct Client { - inner: ClientRef, //::hyper::Client, + inner: Arc, //::hyper::Client, } impl Client { @@ -34,12 +36,18 @@ impl Client { let mut client = try!(new_hyper_client()); client.set_redirect_policy(::hyper::client::RedirectPolicy::FollowNone); Ok(Client { - inner: ClientRef { - hyper: Arc::new(client), - } + inner: Arc::new(ClientRef { + hyper: client, + redirect_policy: Mutex::new(RedirectPolicy::default()), + }), }) } + /// Set a `RedirectPolicy` for this client. + pub fn redirect(&mut self, policy: RedirectPolicy) { + *self.inner.redirect_policy.lock().unwrap() = policy; + } + /// Convenience method to make a `GET` request to a URL. pub fn get(&self, url: U) -> RequestBuilder { self.request(Method::Get, url) @@ -75,13 +83,15 @@ impl Client { impl fmt::Debug for Client { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - f.pad("Client") + f.debug_struct("Client") + .field("redirect_policy", &self.inner.redirect_policy) + .finish() } } -#[derive(Clone)] struct ClientRef { - hyper: Arc<::hyper::Client>, + hyper: ::hyper::Client, + redirect_policy: Mutex, } fn new_hyper_client() -> ::Result<::hyper::Client> { @@ -97,7 +107,7 @@ fn new_hyper_client() -> ::Result<::hyper::Client> { /// A builder to construct the properties of a `Request`. pub struct RequestBuilder { - client: ClientRef, + client: Arc, method: Method, url: Result, @@ -196,7 +206,7 @@ impl RequestBuilder { None => None, }; - let mut redirect_count = 0; + let mut urls = Vec::new(); loop { let res = { @@ -237,14 +247,6 @@ impl RequestBuilder { }; if should_redirect { - //TODO: turn this into self.redirect_policy.check() - if redirect_count > 10 { - return Err(::Error::TooManyRedirects); - } - redirect_count += 1; - - headers.set(Referer(url.to_string())); - let loc = { let loc = res.headers.get::().map(|loc| url.join(loc)); if let Some(loc) = loc { @@ -257,7 +259,18 @@ impl RequestBuilder { }; url = match loc { - Ok(u) => u, + Ok(loc) => { + headers.set(Referer(url.to_string())); + urls.push(url); + if check_redirect(&client.redirect_policy.lock().unwrap(), &loc, &urls)? { + loc + } else { + debug!("redirect_policy disallowed redirection to '{}'", loc); + return Ok(Response { + inner: res + }) + } + }, Err(e) => { debug!("Location header had invalid URI: {:?}", e); return Ok(Response { diff --git a/src/error.rs b/src/error.rs index 2a7896b..b624bce 100644 --- a/src/error.rs +++ b/src/error.rs @@ -13,6 +13,8 @@ pub enum Error { Serialize(Box), /// A request tried to redirect too many times. TooManyRedirects, + /// An infinite redirect loop was detected. + RedirectLoop, #[doc(hidden)] __DontMatchMe, } @@ -22,9 +24,8 @@ impl fmt::Display for Error { match *self { Error::Http(ref e) => fmt::Display::fmt(e, f), Error::Serialize(ref e) => fmt::Display::fmt(e, f), - Error::TooManyRedirects => { - f.pad("Too many redirects") - }, + Error::TooManyRedirects => f.pad("Too many redirects"), + Error::RedirectLoop => f.pad("Infinite redirect loop"), Error::__DontMatchMe => unreachable!() } } @@ -36,6 +37,7 @@ impl StdError for Error { Error::Http(ref e) => e.description(), Error::Serialize(ref e) => e.description(), Error::TooManyRedirects => "Too many redirects", + Error::RedirectLoop => "Infinite redirect loop", Error::__DontMatchMe => unreachable!() } } @@ -45,6 +47,7 @@ impl StdError for Error { Error::Http(ref e) => Some(e), Error::Serialize(ref e) => Some(&**e), Error::TooManyRedirects => None, + Error::RedirectLoop => None, Error::__DontMatchMe => unreachable!() } } diff --git a/src/lib.rs b/src/lib.rs index 936c7cc..d9b2824 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -108,10 +108,12 @@ pub use url::ParseError as UrlError; pub use self::client::{Client, Response, RequestBuilder}; pub use self::error::{Error, Result}; pub use self::body::Body; +pub use self::redirect::RedirectPolicy; mod body; mod client; mod error; +mod redirect; mod tls; diff --git a/src/redirect.rs b/src/redirect.rs index 4fbd22a..9208c1b 100644 --- a/src/redirect.rs +++ b/src/redirect.rs @@ -1,8 +1,163 @@ +use std::fmt; + +use ::Url; + +/// A type that controls the policy on how to handle the following of redirects. +/// +/// The default value will catch redirect loops, and has a maximum of 10 +/// redirects it will follow in a chain before returning an error. #[derive(Debug)] pub struct RedirectPolicy { - inner: () + inner: Policy, } impl RedirectPolicy { - + /// Create a RedirectPolicy with a maximum number of redirects. + /// + /// A `Error::TooManyRedirects` will be returned if the max is reached. + pub fn limited(max: usize) -> RedirectPolicy { + RedirectPolicy { + inner: Policy::Limit(max), + } + } + + /// Create a RedirectPolicy that does not follow any redirect. + pub fn none() -> RedirectPolicy { + RedirectPolicy { + inner: Policy::None, + } + } + + /// Create a custom RedirectPolicy using the passed function. + /// + /// # Note + /// + /// The default RedirectPolicy handles redirect loops and a maximum loop + /// chain, but the custom variant does not do that for you automatically. + /// The custom policy should hanve some way of handling those. + /// + /// There are variants on `::Error` for both cases that can be used as + /// return values. + /// + /// # Example + /// + /// ```no_run + /// # use reqwest::RedirectPolicy; + /// # let mut client = reqwest::Client::new().unwrap(); + /// client.redirect(RedirectPolicy::custom(|next, previous| { + /// if previous.len() > 5 { + /// Err(reqwest::Error::TooManyRedirects) + /// } else if next.host_str() == Some("example.domain") { + /// // prevent redirects to 'example.domain' + /// Ok(false) + /// } else { + /// Ok(true) + /// } + /// })); + /// ``` + pub fn custom(policy: T) -> RedirectPolicy + where T: Fn(&Url, &[Url]) -> ::Result + Send + Sync + 'static { + RedirectPolicy { + inner: Policy::Custom(Box::new(policy)), + } + } + + fn redirect(&self, next: &Url, previous: &[Url]) -> ::Result { + match self.inner { + Policy::Custom(ref custom) => custom(next, previous), + Policy::Limit(max) => { + if previous.len() == max { + Err(::Error::TooManyRedirects) + } else if previous.contains(next) { + Err(::Error::RedirectLoop) + } else { + Ok(true) + } + }, + Policy::None => Ok(false), + } + } +} + +impl Default for RedirectPolicy { + fn default() -> RedirectPolicy { + RedirectPolicy::limited(10) + } +} + +enum Policy { + Custom(Box ::Result + Send + Sync + 'static>), + Limit(usize), + None, +} + +impl fmt::Debug for Policy { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + match *self { + Policy::Custom(..) => f.pad("Custom"), + Policy::Limit(max) => f.debug_tuple("Limit").field(&max).finish(), + Policy::None => f.pad("None"), + } + } +} + +pub fn check_redirect(policy: &RedirectPolicy, next: &Url, previous: &[Url]) -> ::Result { + policy.redirect(next, previous) +} + +/* +This was the desired way of doing it, but ran in to inference issues when +using closures, since the arguments received are references (&Url and &[Url]), +and the compiler could not infer the lifetimes of those references. That means +people would need to annotate the closure's argument types, which is garbase. + +pub trait Redirect { + fn redirect(&self, next: &Url, previous: &[Url]) -> ::Result; +} + +impl Redirect for F +where F: Fn(&Url, &[Url]) -> ::Result { + fn redirect(&self, next: &Url, previous: &[Url]) -> ::Result { + self(next, previous) + } +} +*/ + +#[test] +fn test_redirect_policy_limit() { + let policy = RedirectPolicy::default(); + let next = Url::parse("http://x.y/z").unwrap(); + let mut previous = (0..9) + .map(|i| Url::parse(&format!("http://a.b/c/{}", i)).unwrap()) + .collect::>(); + + + match policy.redirect(&next, &previous) { + Ok(true) => {}, + other => panic!("expected Ok(true), got: {:?}", other) + } + + previous.push(Url::parse("http://a.b.d/e/33").unwrap()); + + match policy.redirect(&next, &previous) { + Err(::Error::TooManyRedirects) => {}, + other => panic!("expected TooManyRedirects, got: {:?}", other) + } +} + +#[test] +fn test_redirect_policy_custom() { + let policy = RedirectPolicy::custom(|next, _previous| { + if next.host_str() == Some("foo") { + Ok(false) + } else { + Ok(true) + } + }); + + let next = Url::parse("http://bar/baz").unwrap(); + assert_eq!(policy.redirect(&next, &[]).unwrap(), true); + + let next = Url::parse("http://foo/baz").unwrap(); + assert_eq!(policy.redirect(&next, &[]).unwrap(), false); } diff --git a/tests/client.rs b/tests/client.rs index 39b8080..1cc710e 100644 --- a/tests/client.rs +++ b/tests/client.rs @@ -160,3 +160,56 @@ fn test_redirect_307_does_not_try_if_reader_cannot_reset() { assert_eq!(res.status(), &reqwest::StatusCode::from_u16(code)); } } + +#[test] +fn test_redirect_policy_can_return_errors() { + let server = server! { + request: b"\ + GET /loop HTTP/1.1\r\n\ + Host: $HOST\r\n\ + User-Agent: $USERAGENT\r\n\ + \r\n\ + ", + response: b"\ + HTTP/1.1 302 Found\r\n\ + Server: test\r\n\ + Location: /loop + Content-Length: 0\r\n\ + \r\n\ + " + }; + + let err = reqwest::get(&format!("http://{}/loop", server.addr())).unwrap_err(); + match err { + reqwest::Error::RedirectLoop => (), + e => panic!("wrong error received: {:?}", e), + } +} + +#[test] +fn test_redirect_policy_can_stop_redirects_without_an_error() { + let server = server! { + request: b"\ + GET /no-redirect HTTP/1.1\r\n\ + Host: $HOST\r\n\ + User-Agent: $USERAGENT\r\n\ + \r\n\ + ", + response: b"\ + HTTP/1.1 302 Found\r\n\ + Server: test-dont\r\n\ + Location: /dont + Content-Length: 0\r\n\ + \r\n\ + " + }; + let mut client = reqwest::Client::new().unwrap(); + client.redirect(reqwest::RedirectPolicy::none()); + + let res = client.get(&format!("http://{}/no-redirect", server.addr())) + .send() + .unwrap(); + + assert_eq!(res.status(), &reqwest::StatusCode::Found); + assert_eq!(res.headers().get(), Some(&reqwest::header::Server("test-dont".to_string()))); +}