Merge pull request #29 from seanmonstar/redirect-policy
add support for defining RedirectPolicy for a Client
This commit is contained in:
		| @@ -1,6 +1,6 @@ | |||||||
| use std::fmt; | use std::fmt; | ||||||
| use std::io::{self, Read}; | use std::io::{self, Read}; | ||||||
| use std::sync::Arc; | use std::sync::{Arc, Mutex}; | ||||||
|  |  | ||||||
| use hyper::client::IntoUrl; | use hyper::client::IntoUrl; | ||||||
| use hyper::header::{Headers, ContentType, Location, Referer, UserAgent}; | use hyper::header::{Headers, ContentType, Location, Referer, UserAgent}; | ||||||
| @@ -14,6 +14,7 @@ use serde_json; | |||||||
| use serde_urlencoded; | use serde_urlencoded; | ||||||
|  |  | ||||||
| use ::body::{self, Body}; | use ::body::{self, Body}; | ||||||
|  | use ::redirect::{RedirectPolicy, check_redirect}; | ||||||
|  |  | ||||||
| static DEFAULT_USER_AGENT: &'static str = concat!(env!("CARGO_PKG_NAME"), "/", env!("CARGO_PKG_VERSION")); | static DEFAULT_USER_AGENT: &'static str = concat!(env!("CARGO_PKG_NAME"), "/", env!("CARGO_PKG_VERSION")); | ||||||
|  |  | ||||||
| @@ -24,8 +25,9 @@ static DEFAULT_USER_AGENT: &'static str = concat!(env!("CARGO_PKG_NAME"), "/", e | |||||||
| /// | /// | ||||||
| /// The `Client` holds a connection pool internally, so it is advised that | /// The `Client` holds a connection pool internally, so it is advised that | ||||||
| /// you create one and reuse it. | /// you create one and reuse it. | ||||||
|  | #[derive(Clone)] | ||||||
| pub struct Client { | pub struct Client { | ||||||
|     inner: ClientRef,  //::hyper::Client, |     inner: Arc<ClientRef>,  //::hyper::Client, | ||||||
| } | } | ||||||
|  |  | ||||||
| impl Client { | impl Client { | ||||||
| @@ -34,12 +36,18 @@ impl Client { | |||||||
|         let mut client = try!(new_hyper_client()); |         let mut client = try!(new_hyper_client()); | ||||||
|         client.set_redirect_policy(::hyper::client::RedirectPolicy::FollowNone); |         client.set_redirect_policy(::hyper::client::RedirectPolicy::FollowNone); | ||||||
|         Ok(Client { |         Ok(Client { | ||||||
|             inner: ClientRef { |             inner: Arc::new(ClientRef { | ||||||
|                 hyper: Arc::new(client), |                 hyper: client, | ||||||
|             } |                 redirect_policy: Mutex::new(RedirectPolicy::default()), | ||||||
|  |             }), | ||||||
|         }) |         }) | ||||||
|     } |     } | ||||||
|  |  | ||||||
|  |     /// Set a `RedirectPolicy` for this client. | ||||||
|  |     pub fn redirect(&mut self, policy: RedirectPolicy) { | ||||||
|  |         *self.inner.redirect_policy.lock().unwrap() = policy; | ||||||
|  |     } | ||||||
|  |  | ||||||
|     /// Convenience method to make a `GET` request to a URL. |     /// Convenience method to make a `GET` request to a URL. | ||||||
|     pub fn get<U: IntoUrl>(&self, url: U) -> RequestBuilder { |     pub fn get<U: IntoUrl>(&self, url: U) -> RequestBuilder { | ||||||
|         self.request(Method::Get, url) |         self.request(Method::Get, url) | ||||||
| @@ -75,13 +83,15 @@ impl Client { | |||||||
|  |  | ||||||
| impl fmt::Debug for Client { | impl fmt::Debug for Client { | ||||||
|     fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { |     fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { | ||||||
|         f.pad("Client") |         f.debug_struct("Client") | ||||||
|  |             .field("redirect_policy", &self.inner.redirect_policy) | ||||||
|  |             .finish() | ||||||
|     } |     } | ||||||
| } | } | ||||||
|  |  | ||||||
| #[derive(Clone)] |  | ||||||
| struct ClientRef { | struct ClientRef { | ||||||
|     hyper: Arc<::hyper::Client>, |     hyper: ::hyper::Client, | ||||||
|  |     redirect_policy: Mutex<RedirectPolicy>, | ||||||
| } | } | ||||||
|  |  | ||||||
| fn new_hyper_client() -> ::Result<::hyper::Client> { | fn new_hyper_client() -> ::Result<::hyper::Client> { | ||||||
| @@ -97,7 +107,7 @@ fn new_hyper_client() -> ::Result<::hyper::Client> { | |||||||
|  |  | ||||||
| /// A builder to construct the properties of a `Request`. | /// A builder to construct the properties of a `Request`. | ||||||
| pub struct RequestBuilder { | pub struct RequestBuilder { | ||||||
|     client: ClientRef, |     client: Arc<ClientRef>, | ||||||
|  |  | ||||||
|     method: Method, |     method: Method, | ||||||
|     url: Result<Url, ::UrlError>, |     url: Result<Url, ::UrlError>, | ||||||
| @@ -196,7 +206,7 @@ impl RequestBuilder { | |||||||
|             None => None, |             None => None, | ||||||
|         }; |         }; | ||||||
|  |  | ||||||
|         let mut redirect_count = 0; |         let mut urls = Vec::new(); | ||||||
|  |  | ||||||
|         loop { |         loop { | ||||||
|             let res = { |             let res = { | ||||||
| @@ -237,14 +247,6 @@ impl RequestBuilder { | |||||||
|             }; |             }; | ||||||
|  |  | ||||||
|             if should_redirect { |             if should_redirect { | ||||||
|                 //TODO: turn this into self.redirect_policy.check() |  | ||||||
|                 if redirect_count > 10 { |  | ||||||
|                     return Err(::Error::TooManyRedirects); |  | ||||||
|                 } |  | ||||||
|                 redirect_count += 1; |  | ||||||
|  |  | ||||||
|                 headers.set(Referer(url.to_string())); |  | ||||||
|  |  | ||||||
|                 let loc = { |                 let loc = { | ||||||
|                     let loc = res.headers.get::<Location>().map(|loc| url.join(loc)); |                     let loc = res.headers.get::<Location>().map(|loc| url.join(loc)); | ||||||
|                     if let Some(loc) = loc { |                     if let Some(loc) = loc { | ||||||
| @@ -257,7 +259,18 @@ impl RequestBuilder { | |||||||
|                 }; |                 }; | ||||||
|  |  | ||||||
|                 url = match loc { |                 url = match loc { | ||||||
|                     Ok(u) => u, |                     Ok(loc) => { | ||||||
|  |                         headers.set(Referer(url.to_string())); | ||||||
|  |                         urls.push(url); | ||||||
|  |                         if check_redirect(&client.redirect_policy.lock().unwrap(), &loc, &urls)? { | ||||||
|  |                             loc | ||||||
|  |                         } else { | ||||||
|  |                             debug!("redirect_policy disallowed redirection to '{}'", loc); | ||||||
|  |                             return Ok(Response { | ||||||
|  |                                 inner: res | ||||||
|  |                             }) | ||||||
|  |                         } | ||||||
|  |                     }, | ||||||
|                     Err(e) => { |                     Err(e) => { | ||||||
|                         debug!("Location header had invalid URI: {:?}", e); |                         debug!("Location header had invalid URI: {:?}", e); | ||||||
|                         return Ok(Response { |                         return Ok(Response { | ||||||
|   | |||||||
| @@ -13,6 +13,8 @@ pub enum Error { | |||||||
|     Serialize(Box<StdError + Send + Sync>), |     Serialize(Box<StdError + Send + Sync>), | ||||||
|     /// A request tried to redirect too many times. |     /// A request tried to redirect too many times. | ||||||
|     TooManyRedirects, |     TooManyRedirects, | ||||||
|  |     /// An infinite redirect loop was detected. | ||||||
|  |     RedirectLoop, | ||||||
|     #[doc(hidden)] |     #[doc(hidden)] | ||||||
|     __DontMatchMe, |     __DontMatchMe, | ||||||
| } | } | ||||||
| @@ -22,9 +24,8 @@ impl fmt::Display for Error { | |||||||
|         match *self { |         match *self { | ||||||
|             Error::Http(ref e) => fmt::Display::fmt(e, f), |             Error::Http(ref e) => fmt::Display::fmt(e, f), | ||||||
|             Error::Serialize(ref e) => fmt::Display::fmt(e, f), |             Error::Serialize(ref e) => fmt::Display::fmt(e, f), | ||||||
|             Error::TooManyRedirects => { |             Error::TooManyRedirects => f.pad("Too many redirects"), | ||||||
|                 f.pad("Too many redirects") |             Error::RedirectLoop => f.pad("Infinite redirect loop"), | ||||||
|             }, |  | ||||||
|             Error::__DontMatchMe => unreachable!() |             Error::__DontMatchMe => unreachable!() | ||||||
|         } |         } | ||||||
|     } |     } | ||||||
| @@ -36,6 +37,7 @@ impl StdError for Error { | |||||||
|             Error::Http(ref e) => e.description(), |             Error::Http(ref e) => e.description(), | ||||||
|             Error::Serialize(ref e) => e.description(), |             Error::Serialize(ref e) => e.description(), | ||||||
|             Error::TooManyRedirects => "Too many redirects", |             Error::TooManyRedirects => "Too many redirects", | ||||||
|  |             Error::RedirectLoop => "Infinite redirect loop", | ||||||
|             Error::__DontMatchMe => unreachable!() |             Error::__DontMatchMe => unreachable!() | ||||||
|         } |         } | ||||||
|     } |     } | ||||||
| @@ -45,6 +47,7 @@ impl StdError for Error { | |||||||
|             Error::Http(ref e) => Some(e), |             Error::Http(ref e) => Some(e), | ||||||
|             Error::Serialize(ref e) => Some(&**e), |             Error::Serialize(ref e) => Some(&**e), | ||||||
|             Error::TooManyRedirects => None, |             Error::TooManyRedirects => None, | ||||||
|  |             Error::RedirectLoop => None, | ||||||
|             Error::__DontMatchMe => unreachable!() |             Error::__DontMatchMe => unreachable!() | ||||||
|         } |         } | ||||||
|     } |     } | ||||||
|   | |||||||
| @@ -108,10 +108,12 @@ pub use url::ParseError as UrlError; | |||||||
| pub use self::client::{Client, Response, RequestBuilder}; | pub use self::client::{Client, Response, RequestBuilder}; | ||||||
| pub use self::error::{Error, Result}; | pub use self::error::{Error, Result}; | ||||||
| pub use self::body::Body; | pub use self::body::Body; | ||||||
|  | pub use self::redirect::RedirectPolicy; | ||||||
|  |  | ||||||
| mod body; | mod body; | ||||||
| mod client; | mod client; | ||||||
| mod error; | mod error; | ||||||
|  | mod redirect; | ||||||
| mod tls; | mod tls; | ||||||
|  |  | ||||||
|  |  | ||||||
|   | |||||||
							
								
								
									
										159
									
								
								src/redirect.rs
									
									
									
									
									
								
							
							
						
						
									
										159
									
								
								src/redirect.rs
									
									
									
									
									
								
							| @@ -1,8 +1,163 @@ | |||||||
|  | use std::fmt; | ||||||
|  |  | ||||||
|  | use ::Url; | ||||||
|  |  | ||||||
|  | /// A type that controls the policy on how to handle the following of redirects. | ||||||
|  | /// | ||||||
|  | /// The default value will catch redirect loops, and has a maximum of 10 | ||||||
|  | /// redirects it will follow in a chain before returning an error. | ||||||
| #[derive(Debug)] | #[derive(Debug)] | ||||||
| pub struct RedirectPolicy { | pub struct RedirectPolicy { | ||||||
|     inner: () |     inner: Policy, | ||||||
| } | } | ||||||
|  |  | ||||||
| impl RedirectPolicy { | impl RedirectPolicy { | ||||||
|      |     /// Create a RedirectPolicy with a maximum number of redirects. | ||||||
|  |     /// | ||||||
|  |     /// A `Error::TooManyRedirects` will be returned if the max is reached. | ||||||
|  |     pub fn limited(max: usize) -> RedirectPolicy { | ||||||
|  |         RedirectPolicy { | ||||||
|  |             inner: Policy::Limit(max), | ||||||
|  |         } | ||||||
|  |     } | ||||||
|  |  | ||||||
|  |     /// Create a RedirectPolicy that does not follow any redirect. | ||||||
|  |     pub fn none() -> RedirectPolicy { | ||||||
|  |         RedirectPolicy { | ||||||
|  |             inner: Policy::None, | ||||||
|  |         } | ||||||
|  |     } | ||||||
|  |  | ||||||
|  |     /// Create a custom RedirectPolicy using the passed function. | ||||||
|  |     /// | ||||||
|  |     /// # Note | ||||||
|  |     /// | ||||||
|  |     /// The default RedirectPolicy handles redirect loops and a maximum loop | ||||||
|  |     /// chain, but the custom variant does not do that for you automatically. | ||||||
|  |     /// The custom policy should hanve some way of handling those. | ||||||
|  |     /// | ||||||
|  |     /// There are variants on `::Error` for both cases that can be used as | ||||||
|  |     /// return values. | ||||||
|  |     /// | ||||||
|  |     /// # Example | ||||||
|  |     /// | ||||||
|  |     /// ```no_run | ||||||
|  |     /// # use reqwest::RedirectPolicy; | ||||||
|  |     /// # let mut client = reqwest::Client::new().unwrap(); | ||||||
|  |     /// client.redirect(RedirectPolicy::custom(|next, previous| { | ||||||
|  |     ///     if previous.len() > 5 { | ||||||
|  |     ///         Err(reqwest::Error::TooManyRedirects) | ||||||
|  |     ///     } else if next.host_str() == Some("example.domain") { | ||||||
|  |     ///         // prevent redirects to 'example.domain' | ||||||
|  |     ///         Ok(false) | ||||||
|  |     ///     } else { | ||||||
|  |     ///         Ok(true) | ||||||
|  |     ///     } | ||||||
|  |     /// })); | ||||||
|  |     /// ``` | ||||||
|  |     pub fn custom<T>(policy: T) -> RedirectPolicy | ||||||
|  |     where T: Fn(&Url, &[Url]) -> ::Result<bool> + Send + Sync + 'static { | ||||||
|  |         RedirectPolicy { | ||||||
|  |             inner: Policy::Custom(Box::new(policy)), | ||||||
|  |         } | ||||||
|  |     } | ||||||
|  |  | ||||||
|  |     fn redirect(&self, next: &Url, previous: &[Url]) -> ::Result<bool> { | ||||||
|  |         match self.inner { | ||||||
|  |             Policy::Custom(ref custom) => custom(next, previous), | ||||||
|  |             Policy::Limit(max) => { | ||||||
|  |                 if previous.len() == max { | ||||||
|  |                     Err(::Error::TooManyRedirects) | ||||||
|  |                 } else if previous.contains(next) { | ||||||
|  |                     Err(::Error::RedirectLoop) | ||||||
|  |                 } else { | ||||||
|  |                     Ok(true) | ||||||
|  |                 } | ||||||
|  |             }, | ||||||
|  |             Policy::None => Ok(false), | ||||||
|  |         } | ||||||
|  |     } | ||||||
|  | } | ||||||
|  |  | ||||||
|  | impl Default for RedirectPolicy { | ||||||
|  |     fn default() -> RedirectPolicy { | ||||||
|  |         RedirectPolicy::limited(10) | ||||||
|  |     } | ||||||
|  | } | ||||||
|  |  | ||||||
|  | enum Policy { | ||||||
|  |     Custom(Box<Fn(&Url, &[Url]) -> ::Result<bool> + Send + Sync + 'static>), | ||||||
|  |     Limit(usize), | ||||||
|  |     None, | ||||||
|  | } | ||||||
|  |  | ||||||
|  | impl fmt::Debug for Policy { | ||||||
|  |     fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { | ||||||
|  |         match *self { | ||||||
|  |             Policy::Custom(..) => f.pad("Custom"), | ||||||
|  |             Policy::Limit(max) => f.debug_tuple("Limit").field(&max).finish(), | ||||||
|  |             Policy::None => f.pad("None"), | ||||||
|  |         } | ||||||
|  |     } | ||||||
|  | } | ||||||
|  |  | ||||||
|  | pub fn check_redirect(policy: &RedirectPolicy, next: &Url, previous: &[Url]) -> ::Result<bool> { | ||||||
|  |     policy.redirect(next, previous) | ||||||
|  | } | ||||||
|  |  | ||||||
|  | /* | ||||||
|  | This was the desired way of doing it, but ran in to inference issues when | ||||||
|  | using closures, since the arguments received are references (&Url and &[Url]), | ||||||
|  | and the compiler could not infer the lifetimes of those references. That means | ||||||
|  | people would need to annotate the closure's argument types, which is garbase. | ||||||
|  |  | ||||||
|  | pub trait Redirect { | ||||||
|  |     fn redirect(&self, next: &Url, previous: &[Url]) -> ::Result<bool>; | ||||||
|  | } | ||||||
|  |  | ||||||
|  | impl<F> Redirect for F | ||||||
|  | where F: Fn(&Url, &[Url]) -> ::Result<bool> { | ||||||
|  |     fn redirect(&self, next: &Url, previous: &[Url]) -> ::Result<bool> { | ||||||
|  |         self(next, previous) | ||||||
|  |     } | ||||||
|  | } | ||||||
|  | */ | ||||||
|  |  | ||||||
|  | #[test] | ||||||
|  | fn test_redirect_policy_limit() { | ||||||
|  |     let policy = RedirectPolicy::default(); | ||||||
|  |     let next = Url::parse("http://x.y/z").unwrap(); | ||||||
|  |     let mut previous = (0..9) | ||||||
|  |         .map(|i| Url::parse(&format!("http://a.b/c/{}", i)).unwrap()) | ||||||
|  |         .collect::<Vec<_>>(); | ||||||
|  |  | ||||||
|  |  | ||||||
|  |     match policy.redirect(&next, &previous) { | ||||||
|  |         Ok(true) => {}, | ||||||
|  |         other => panic!("expected Ok(true), got: {:?}", other) | ||||||
|  |     } | ||||||
|  |  | ||||||
|  |     previous.push(Url::parse("http://a.b.d/e/33").unwrap()); | ||||||
|  |  | ||||||
|  |     match policy.redirect(&next, &previous) { | ||||||
|  |         Err(::Error::TooManyRedirects) => {}, | ||||||
|  |         other => panic!("expected TooManyRedirects, got: {:?}", other) | ||||||
|  |     } | ||||||
|  | } | ||||||
|  |  | ||||||
|  | #[test] | ||||||
|  | fn test_redirect_policy_custom() { | ||||||
|  |     let policy = RedirectPolicy::custom(|next, _previous| { | ||||||
|  |         if next.host_str() == Some("foo") { | ||||||
|  |             Ok(false) | ||||||
|  |         } else { | ||||||
|  |             Ok(true) | ||||||
|  |         } | ||||||
|  |     }); | ||||||
|  |  | ||||||
|  |     let next = Url::parse("http://bar/baz").unwrap(); | ||||||
|  |     assert_eq!(policy.redirect(&next, &[]).unwrap(), true); | ||||||
|  |  | ||||||
|  |     let next = Url::parse("http://foo/baz").unwrap(); | ||||||
|  |     assert_eq!(policy.redirect(&next, &[]).unwrap(), false); | ||||||
| } | } | ||||||
|   | |||||||
| @@ -160,3 +160,56 @@ fn test_redirect_307_does_not_try_if_reader_cannot_reset() { | |||||||
|         assert_eq!(res.status(), &reqwest::StatusCode::from_u16(code)); |         assert_eq!(res.status(), &reqwest::StatusCode::from_u16(code)); | ||||||
|     } |     } | ||||||
| } | } | ||||||
|  |  | ||||||
|  | #[test] | ||||||
|  | fn test_redirect_policy_can_return_errors() { | ||||||
|  |     let server = server! { | ||||||
|  |         request: b"\ | ||||||
|  |             GET /loop HTTP/1.1\r\n\ | ||||||
|  |             Host: $HOST\r\n\ | ||||||
|  |             User-Agent: $USERAGENT\r\n\ | ||||||
|  |             \r\n\ | ||||||
|  |             ", | ||||||
|  |         response: b"\ | ||||||
|  |             HTTP/1.1 302 Found\r\n\ | ||||||
|  |             Server: test\r\n\ | ||||||
|  |             Location: /loop | ||||||
|  |             Content-Length: 0\r\n\ | ||||||
|  |             \r\n\ | ||||||
|  |             " | ||||||
|  |     }; | ||||||
|  |  | ||||||
|  |     let err = reqwest::get(&format!("http://{}/loop", server.addr())).unwrap_err(); | ||||||
|  |     match err { | ||||||
|  |         reqwest::Error::RedirectLoop => (), | ||||||
|  |         e => panic!("wrong error received: {:?}", e), | ||||||
|  |     } | ||||||
|  | } | ||||||
|  |  | ||||||
|  | #[test] | ||||||
|  | fn test_redirect_policy_can_stop_redirects_without_an_error() { | ||||||
|  |     let server = server! { | ||||||
|  |         request: b"\ | ||||||
|  |             GET /no-redirect HTTP/1.1\r\n\ | ||||||
|  |             Host: $HOST\r\n\ | ||||||
|  |             User-Agent: $USERAGENT\r\n\ | ||||||
|  |             \r\n\ | ||||||
|  |             ", | ||||||
|  |         response: b"\ | ||||||
|  |             HTTP/1.1 302 Found\r\n\ | ||||||
|  |             Server: test-dont\r\n\ | ||||||
|  |             Location: /dont | ||||||
|  |             Content-Length: 0\r\n\ | ||||||
|  |             \r\n\ | ||||||
|  |             " | ||||||
|  |     }; | ||||||
|  |     let mut client = reqwest::Client::new().unwrap(); | ||||||
|  |     client.redirect(reqwest::RedirectPolicy::none()); | ||||||
|  |  | ||||||
|  |     let res = client.get(&format!("http://{}/no-redirect", server.addr())) | ||||||
|  |         .send() | ||||||
|  |         .unwrap(); | ||||||
|  |  | ||||||
|  |     assert_eq!(res.status(), &reqwest::StatusCode::Found); | ||||||
|  |     assert_eq!(res.headers().get(), Some(&reqwest::header::Server("test-dont".to_string()))); | ||||||
|  | } | ||||||
|   | |||||||
		Reference in New Issue
	
	Block a user