add support for defining RedirectPolicy for a Client

This commit is contained in:
Sean McArthur
2016-12-10 11:36:22 -08:00
parent 6ef73ae206
commit e92b3e862a
5 changed files with 250 additions and 24 deletions

View File

@@ -1,6 +1,6 @@
use std::fmt;
use std::io::{self, Read};
use std::sync::Arc;
use std::sync::{Arc, Mutex};
use hyper::client::IntoUrl;
use hyper::header::{Headers, ContentType, Location, Referer, UserAgent};
@@ -14,6 +14,7 @@ use serde_json;
use serde_urlencoded;
use ::body::{self, Body};
use ::redirect::{RedirectPolicy, check_redirect};
static DEFAULT_USER_AGENT: &'static str = concat!(env!("CARGO_PKG_NAME"), "/", env!("CARGO_PKG_VERSION"));
@@ -24,8 +25,9 @@ static DEFAULT_USER_AGENT: &'static str = concat!(env!("CARGO_PKG_NAME"), "/", e
///
/// The `Client` holds a connection pool internally, so it is advised that
/// you create one and reuse it.
#[derive(Clone)]
pub struct Client {
inner: ClientRef, //::hyper::Client,
inner: Arc<ClientRef>, //::hyper::Client,
}
impl Client {
@@ -34,12 +36,18 @@ impl Client {
let mut client = try!(new_hyper_client());
client.set_redirect_policy(::hyper::client::RedirectPolicy::FollowNone);
Ok(Client {
inner: ClientRef {
hyper: Arc::new(client),
}
inner: Arc::new(ClientRef {
hyper: client,
redirect_policy: Mutex::new(RedirectPolicy::default()),
}),
})
}
/// Set a `RedirectPolicy` for this client.
pub fn redirect(&mut self, policy: RedirectPolicy) {
*self.inner.redirect_policy.lock().unwrap() = policy;
}
/// Convenience method to make a `GET` request to a URL.
pub fn get<U: IntoUrl>(&self, url: U) -> RequestBuilder {
self.request(Method::Get, url)
@@ -75,13 +83,15 @@ impl Client {
impl fmt::Debug for Client {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
f.pad("Client")
f.debug_struct("Client")
.field("redirect_policy", &self.inner.redirect_policy)
.finish()
}
}
#[derive(Clone)]
struct ClientRef {
hyper: Arc<::hyper::Client>,
hyper: ::hyper::Client,
redirect_policy: Mutex<RedirectPolicy>,
}
fn new_hyper_client() -> ::Result<::hyper::Client> {
@@ -97,7 +107,7 @@ fn new_hyper_client() -> ::Result<::hyper::Client> {
/// A builder to construct the properties of a `Request`.
pub struct RequestBuilder {
client: ClientRef,
client: Arc<ClientRef>,
method: Method,
url: Result<Url, ::UrlError>,
@@ -196,7 +206,7 @@ impl RequestBuilder {
None => None,
};
let mut redirect_count = 0;
let mut urls = Vec::new();
loop {
let res = {
@@ -237,14 +247,6 @@ impl RequestBuilder {
};
if should_redirect {
//TODO: turn this into self.redirect_policy.check()
if redirect_count > 10 {
return Err(::Error::TooManyRedirects);
}
redirect_count += 1;
headers.set(Referer(url.to_string()));
let loc = {
let loc = res.headers.get::<Location>().map(|loc| url.join(loc));
if let Some(loc) = loc {
@@ -257,7 +259,18 @@ impl RequestBuilder {
};
url = match loc {
Ok(u) => u,
Ok(loc) => {
headers.set(Referer(url.to_string()));
urls.push(url);
if check_redirect(&client.redirect_policy.lock().unwrap(), &loc, &urls)? {
loc
} else {
debug!("redirect_policy disallowed redirection to '{}'", loc);
return Ok(Response {
inner: res
})
}
},
Err(e) => {
debug!("Location header had invalid URI: {:?}", e);
return Ok(Response {