add support for defining RedirectPolicy for a Client
This commit is contained in:
@@ -1,6 +1,6 @@
|
|||||||
use std::fmt;
|
use std::fmt;
|
||||||
use std::io::{self, Read};
|
use std::io::{self, Read};
|
||||||
use std::sync::Arc;
|
use std::sync::{Arc, Mutex};
|
||||||
|
|
||||||
use hyper::client::IntoUrl;
|
use hyper::client::IntoUrl;
|
||||||
use hyper::header::{Headers, ContentType, Location, Referer, UserAgent};
|
use hyper::header::{Headers, ContentType, Location, Referer, UserAgent};
|
||||||
@@ -14,6 +14,7 @@ use serde_json;
|
|||||||
use serde_urlencoded;
|
use serde_urlencoded;
|
||||||
|
|
||||||
use ::body::{self, Body};
|
use ::body::{self, Body};
|
||||||
|
use ::redirect::{RedirectPolicy, check_redirect};
|
||||||
|
|
||||||
static DEFAULT_USER_AGENT: &'static str = concat!(env!("CARGO_PKG_NAME"), "/", env!("CARGO_PKG_VERSION"));
|
static DEFAULT_USER_AGENT: &'static str = concat!(env!("CARGO_PKG_NAME"), "/", env!("CARGO_PKG_VERSION"));
|
||||||
|
|
||||||
@@ -24,8 +25,9 @@ static DEFAULT_USER_AGENT: &'static str = concat!(env!("CARGO_PKG_NAME"), "/", e
|
|||||||
///
|
///
|
||||||
/// The `Client` holds a connection pool internally, so it is advised that
|
/// The `Client` holds a connection pool internally, so it is advised that
|
||||||
/// you create one and reuse it.
|
/// you create one and reuse it.
|
||||||
|
#[derive(Clone)]
|
||||||
pub struct Client {
|
pub struct Client {
|
||||||
inner: ClientRef, //::hyper::Client,
|
inner: Arc<ClientRef>, //::hyper::Client,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Client {
|
impl Client {
|
||||||
@@ -34,12 +36,18 @@ impl Client {
|
|||||||
let mut client = try!(new_hyper_client());
|
let mut client = try!(new_hyper_client());
|
||||||
client.set_redirect_policy(::hyper::client::RedirectPolicy::FollowNone);
|
client.set_redirect_policy(::hyper::client::RedirectPolicy::FollowNone);
|
||||||
Ok(Client {
|
Ok(Client {
|
||||||
inner: ClientRef {
|
inner: Arc::new(ClientRef {
|
||||||
hyper: Arc::new(client),
|
hyper: client,
|
||||||
}
|
redirect_policy: Mutex::new(RedirectPolicy::default()),
|
||||||
|
}),
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Set a `RedirectPolicy` for this client.
|
||||||
|
pub fn redirect(&mut self, policy: RedirectPolicy) {
|
||||||
|
*self.inner.redirect_policy.lock().unwrap() = policy;
|
||||||
|
}
|
||||||
|
|
||||||
/// Convenience method to make a `GET` request to a URL.
|
/// Convenience method to make a `GET` request to a URL.
|
||||||
pub fn get<U: IntoUrl>(&self, url: U) -> RequestBuilder {
|
pub fn get<U: IntoUrl>(&self, url: U) -> RequestBuilder {
|
||||||
self.request(Method::Get, url)
|
self.request(Method::Get, url)
|
||||||
@@ -75,13 +83,15 @@ impl Client {
|
|||||||
|
|
||||||
impl fmt::Debug for Client {
|
impl fmt::Debug for Client {
|
||||||
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
|
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
|
||||||
f.pad("Client")
|
f.debug_struct("Client")
|
||||||
|
.field("redirect_policy", &self.inner.redirect_policy)
|
||||||
|
.finish()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Clone)]
|
|
||||||
struct ClientRef {
|
struct ClientRef {
|
||||||
hyper: Arc<::hyper::Client>,
|
hyper: ::hyper::Client,
|
||||||
|
redirect_policy: Mutex<RedirectPolicy>,
|
||||||
}
|
}
|
||||||
|
|
||||||
fn new_hyper_client() -> ::Result<::hyper::Client> {
|
fn new_hyper_client() -> ::Result<::hyper::Client> {
|
||||||
@@ -97,7 +107,7 @@ fn new_hyper_client() -> ::Result<::hyper::Client> {
|
|||||||
|
|
||||||
/// A builder to construct the properties of a `Request`.
|
/// A builder to construct the properties of a `Request`.
|
||||||
pub struct RequestBuilder {
|
pub struct RequestBuilder {
|
||||||
client: ClientRef,
|
client: Arc<ClientRef>,
|
||||||
|
|
||||||
method: Method,
|
method: Method,
|
||||||
url: Result<Url, ::UrlError>,
|
url: Result<Url, ::UrlError>,
|
||||||
@@ -196,7 +206,7 @@ impl RequestBuilder {
|
|||||||
None => None,
|
None => None,
|
||||||
};
|
};
|
||||||
|
|
||||||
let mut redirect_count = 0;
|
let mut urls = Vec::new();
|
||||||
|
|
||||||
loop {
|
loop {
|
||||||
let res = {
|
let res = {
|
||||||
@@ -237,14 +247,6 @@ impl RequestBuilder {
|
|||||||
};
|
};
|
||||||
|
|
||||||
if should_redirect {
|
if should_redirect {
|
||||||
//TODO: turn this into self.redirect_policy.check()
|
|
||||||
if redirect_count > 10 {
|
|
||||||
return Err(::Error::TooManyRedirects);
|
|
||||||
}
|
|
||||||
redirect_count += 1;
|
|
||||||
|
|
||||||
headers.set(Referer(url.to_string()));
|
|
||||||
|
|
||||||
let loc = {
|
let loc = {
|
||||||
let loc = res.headers.get::<Location>().map(|loc| url.join(loc));
|
let loc = res.headers.get::<Location>().map(|loc| url.join(loc));
|
||||||
if let Some(loc) = loc {
|
if let Some(loc) = loc {
|
||||||
@@ -257,7 +259,18 @@ impl RequestBuilder {
|
|||||||
};
|
};
|
||||||
|
|
||||||
url = match loc {
|
url = match loc {
|
||||||
Ok(u) => u,
|
Ok(loc) => {
|
||||||
|
headers.set(Referer(url.to_string()));
|
||||||
|
urls.push(url);
|
||||||
|
if check_redirect(&client.redirect_policy.lock().unwrap(), &loc, &urls)? {
|
||||||
|
loc
|
||||||
|
} else {
|
||||||
|
debug!("redirect_policy disallowed redirection to '{}'", loc);
|
||||||
|
return Ok(Response {
|
||||||
|
inner: res
|
||||||
|
})
|
||||||
|
}
|
||||||
|
},
|
||||||
Err(e) => {
|
Err(e) => {
|
||||||
debug!("Location header had invalid URI: {:?}", e);
|
debug!("Location header had invalid URI: {:?}", e);
|
||||||
return Ok(Response {
|
return Ok(Response {
|
||||||
|
|||||||
@@ -13,6 +13,8 @@ pub enum Error {
|
|||||||
Serialize(Box<StdError + Send + Sync>),
|
Serialize(Box<StdError + Send + Sync>),
|
||||||
/// A request tried to redirect too many times.
|
/// A request tried to redirect too many times.
|
||||||
TooManyRedirects,
|
TooManyRedirects,
|
||||||
|
/// An infinite redirect loop was detected.
|
||||||
|
RedirectLoop,
|
||||||
#[doc(hidden)]
|
#[doc(hidden)]
|
||||||
__DontMatchMe,
|
__DontMatchMe,
|
||||||
}
|
}
|
||||||
@@ -22,9 +24,8 @@ impl fmt::Display for Error {
|
|||||||
match *self {
|
match *self {
|
||||||
Error::Http(ref e) => fmt::Display::fmt(e, f),
|
Error::Http(ref e) => fmt::Display::fmt(e, f),
|
||||||
Error::Serialize(ref e) => fmt::Display::fmt(e, f),
|
Error::Serialize(ref e) => fmt::Display::fmt(e, f),
|
||||||
Error::TooManyRedirects => {
|
Error::TooManyRedirects => f.pad("Too many redirects"),
|
||||||
f.pad("Too many redirects")
|
Error::RedirectLoop => f.pad("Infinite redirect loop"),
|
||||||
},
|
|
||||||
Error::__DontMatchMe => unreachable!()
|
Error::__DontMatchMe => unreachable!()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -36,6 +37,7 @@ impl StdError for Error {
|
|||||||
Error::Http(ref e) => e.description(),
|
Error::Http(ref e) => e.description(),
|
||||||
Error::Serialize(ref e) => e.description(),
|
Error::Serialize(ref e) => e.description(),
|
||||||
Error::TooManyRedirects => "Too many redirects",
|
Error::TooManyRedirects => "Too many redirects",
|
||||||
|
Error::RedirectLoop => "Infinite redirect loop",
|
||||||
Error::__DontMatchMe => unreachable!()
|
Error::__DontMatchMe => unreachable!()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -45,6 +47,7 @@ impl StdError for Error {
|
|||||||
Error::Http(ref e) => Some(e),
|
Error::Http(ref e) => Some(e),
|
||||||
Error::Serialize(ref e) => Some(&**e),
|
Error::Serialize(ref e) => Some(&**e),
|
||||||
Error::TooManyRedirects => None,
|
Error::TooManyRedirects => None,
|
||||||
|
Error::RedirectLoop => None,
|
||||||
Error::__DontMatchMe => unreachable!()
|
Error::__DontMatchMe => unreachable!()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -108,10 +108,12 @@ pub use url::ParseError as UrlError;
|
|||||||
pub use self::client::{Client, Response, RequestBuilder};
|
pub use self::client::{Client, Response, RequestBuilder};
|
||||||
pub use self::error::{Error, Result};
|
pub use self::error::{Error, Result};
|
||||||
pub use self::body::Body;
|
pub use self::body::Body;
|
||||||
|
pub use self::redirect::RedirectPolicy;
|
||||||
|
|
||||||
mod body;
|
mod body;
|
||||||
mod client;
|
mod client;
|
||||||
mod error;
|
mod error;
|
||||||
|
mod redirect;
|
||||||
mod tls;
|
mod tls;
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
157
src/redirect.rs
157
src/redirect.rs
@@ -1,8 +1,163 @@
|
|||||||
|
use std::fmt;
|
||||||
|
|
||||||
|
use ::Url;
|
||||||
|
|
||||||
|
/// A type that controls the policy on how to handle the following of redirects.
|
||||||
|
///
|
||||||
|
/// The default value will catch redirect loops, and has a maximum of 10
|
||||||
|
/// redirects it will follow in a chain before returning an error.
|
||||||
#[derive(Debug)]
|
#[derive(Debug)]
|
||||||
pub struct RedirectPolicy {
|
pub struct RedirectPolicy {
|
||||||
inner: ()
|
inner: Policy,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl RedirectPolicy {
|
impl RedirectPolicy {
|
||||||
|
/// Create a RedirectPolicy with a maximum number of redirects.
|
||||||
|
///
|
||||||
|
/// A `Error::TooManyRedirects` will be returned if the max is reached.
|
||||||
|
pub fn limited(max: usize) -> RedirectPolicy {
|
||||||
|
RedirectPolicy {
|
||||||
|
inner: Policy::Limit(max),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Create a RedirectPolicy that does not follow any redirect.
|
||||||
|
pub fn none() -> RedirectPolicy {
|
||||||
|
RedirectPolicy {
|
||||||
|
inner: Policy::None,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Create a custom RedirectPolicy using the passed function.
|
||||||
|
///
|
||||||
|
/// # Note
|
||||||
|
///
|
||||||
|
/// The default RedirectPolicy handles redirect loops and a maximum loop
|
||||||
|
/// chain, but the custom variant does not do that for you automatically.
|
||||||
|
/// The custom policy should hanve some way of handling those.
|
||||||
|
///
|
||||||
|
/// There are variants on `::Error` for both cases that can be used as
|
||||||
|
/// return values.
|
||||||
|
///
|
||||||
|
/// # Example
|
||||||
|
///
|
||||||
|
/// ```no_run
|
||||||
|
/// # use reqwest::RedirectPolicy;
|
||||||
|
/// # let mut client = reqwest::Client::new().unwrap();
|
||||||
|
/// client.redirect(RedirectPolicy::custom(|next, previous| {
|
||||||
|
/// if previous.len() > 5 {
|
||||||
|
/// Err(reqwest::Error::TooManyRedirects)
|
||||||
|
/// } else if next.host_str() == Some("example.domain") {
|
||||||
|
/// // prevent redirects to 'example.domain'
|
||||||
|
/// Ok(false)
|
||||||
|
/// } else {
|
||||||
|
/// Ok(true)
|
||||||
|
/// }
|
||||||
|
/// }));
|
||||||
|
/// ```
|
||||||
|
pub fn custom<T>(policy: T) -> RedirectPolicy
|
||||||
|
where T: Fn(&Url, &[Url]) -> ::Result<bool> + Send + Sync + 'static {
|
||||||
|
RedirectPolicy {
|
||||||
|
inner: Policy::Custom(Box::new(policy)),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn redirect(&self, next: &Url, previous: &[Url]) -> ::Result<bool> {
|
||||||
|
match self.inner {
|
||||||
|
Policy::Custom(ref custom) => custom(next, previous),
|
||||||
|
Policy::Limit(max) => {
|
||||||
|
if previous.len() == max {
|
||||||
|
Err(::Error::TooManyRedirects)
|
||||||
|
} else if previous.contains(next) {
|
||||||
|
Err(::Error::RedirectLoop)
|
||||||
|
} else {
|
||||||
|
Ok(true)
|
||||||
|
}
|
||||||
|
},
|
||||||
|
Policy::None => Ok(false),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Default for RedirectPolicy {
|
||||||
|
fn default() -> RedirectPolicy {
|
||||||
|
RedirectPolicy::limited(10)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
enum Policy {
|
||||||
|
Custom(Box<Fn(&Url, &[Url]) -> ::Result<bool> + Send + Sync + 'static>),
|
||||||
|
Limit(usize),
|
||||||
|
None,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl fmt::Debug for Policy {
|
||||||
|
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
|
||||||
|
match *self {
|
||||||
|
Policy::Custom(..) => f.pad("Custom"),
|
||||||
|
Policy::Limit(max) => f.debug_tuple("Limit").field(&max).finish(),
|
||||||
|
Policy::None => f.pad("None"),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn check_redirect(policy: &RedirectPolicy, next: &Url, previous: &[Url]) -> ::Result<bool> {
|
||||||
|
policy.redirect(next, previous)
|
||||||
|
}
|
||||||
|
|
||||||
|
/*
|
||||||
|
This was the desired way of doing it, but ran in to inference issues when
|
||||||
|
using closures, since the arguments received are references (&Url and &[Url]),
|
||||||
|
and the compiler could not infer the lifetimes of those references. That means
|
||||||
|
people would need to annotate the closure's argument types, which is garbase.
|
||||||
|
|
||||||
|
pub trait Redirect {
|
||||||
|
fn redirect(&self, next: &Url, previous: &[Url]) -> ::Result<bool>;
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<F> Redirect for F
|
||||||
|
where F: Fn(&Url, &[Url]) -> ::Result<bool> {
|
||||||
|
fn redirect(&self, next: &Url, previous: &[Url]) -> ::Result<bool> {
|
||||||
|
self(next, previous)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
*/
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_redirect_policy_limit() {
|
||||||
|
let policy = RedirectPolicy::default();
|
||||||
|
let next = Url::parse("http://x.y/z").unwrap();
|
||||||
|
let mut previous = (0..9)
|
||||||
|
.map(|i| Url::parse(&format!("http://a.b/c/{}", i)).unwrap())
|
||||||
|
.collect::<Vec<_>>();
|
||||||
|
|
||||||
|
|
||||||
|
match policy.redirect(&next, &previous) {
|
||||||
|
Ok(true) => {},
|
||||||
|
other => panic!("expected Ok(true), got: {:?}", other)
|
||||||
|
}
|
||||||
|
|
||||||
|
previous.push(Url::parse("http://a.b.d/e/33").unwrap());
|
||||||
|
|
||||||
|
match policy.redirect(&next, &previous) {
|
||||||
|
Err(::Error::TooManyRedirects) => {},
|
||||||
|
other => panic!("expected TooManyRedirects, got: {:?}", other)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_redirect_policy_custom() {
|
||||||
|
let policy = RedirectPolicy::custom(|next, _previous| {
|
||||||
|
if next.host_str() == Some("foo") {
|
||||||
|
Ok(false)
|
||||||
|
} else {
|
||||||
|
Ok(true)
|
||||||
|
}
|
||||||
|
});
|
||||||
|
|
||||||
|
let next = Url::parse("http://bar/baz").unwrap();
|
||||||
|
assert_eq!(policy.redirect(&next, &[]).unwrap(), true);
|
||||||
|
|
||||||
|
let next = Url::parse("http://foo/baz").unwrap();
|
||||||
|
assert_eq!(policy.redirect(&next, &[]).unwrap(), false);
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -160,3 +160,56 @@ fn test_redirect_307_does_not_try_if_reader_cannot_reset() {
|
|||||||
assert_eq!(res.status(), &reqwest::StatusCode::from_u16(code));
|
assert_eq!(res.status(), &reqwest::StatusCode::from_u16(code));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_redirect_policy_can_return_errors() {
|
||||||
|
let server = server! {
|
||||||
|
request: b"\
|
||||||
|
GET /loop HTTP/1.1\r\n\
|
||||||
|
Host: $HOST\r\n\
|
||||||
|
User-Agent: $USERAGENT\r\n\
|
||||||
|
\r\n\
|
||||||
|
",
|
||||||
|
response: b"\
|
||||||
|
HTTP/1.1 302 Found\r\n\
|
||||||
|
Server: test\r\n\
|
||||||
|
Location: /loop
|
||||||
|
Content-Length: 0\r\n\
|
||||||
|
\r\n\
|
||||||
|
"
|
||||||
|
};
|
||||||
|
|
||||||
|
let err = reqwest::get(&format!("http://{}/loop", server.addr())).unwrap_err();
|
||||||
|
match err {
|
||||||
|
reqwest::Error::RedirectLoop => (),
|
||||||
|
e => panic!("wrong error received: {:?}", e),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_redirect_policy_can_stop_redirects_without_an_error() {
|
||||||
|
let server = server! {
|
||||||
|
request: b"\
|
||||||
|
GET /no-redirect HTTP/1.1\r\n\
|
||||||
|
Host: $HOST\r\n\
|
||||||
|
User-Agent: $USERAGENT\r\n\
|
||||||
|
\r\n\
|
||||||
|
",
|
||||||
|
response: b"\
|
||||||
|
HTTP/1.1 302 Found\r\n\
|
||||||
|
Server: test-dont\r\n\
|
||||||
|
Location: /dont
|
||||||
|
Content-Length: 0\r\n\
|
||||||
|
\r\n\
|
||||||
|
"
|
||||||
|
};
|
||||||
|
let mut client = reqwest::Client::new().unwrap();
|
||||||
|
client.redirect(reqwest::RedirectPolicy::none());
|
||||||
|
|
||||||
|
let res = client.get(&format!("http://{}/no-redirect", server.addr()))
|
||||||
|
.send()
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
|
assert_eq!(res.status(), &reqwest::StatusCode::Found);
|
||||||
|
assert_eq!(res.headers().get(), Some(&reqwest::header::Server("test-dont".to_string())));
|
||||||
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user