From 5479d9e8b852f9988f4d4d0938ea1cdb827ec348 Mon Sep 17 00:00:00 2001 From: Sean McArthur Date: Tue, 25 Oct 2016 12:26:01 -0700 Subject: [PATCH] handle redirects --- Cargo.toml | 4 +- examples/post.rs | 40 ------------- src/client.rs | 148 ++++++++++++++++++++++++++++++++++++++--------- src/error.rs | 43 ++++++++++++++ src/lib.rs | 10 +++- 5 files changed, 174 insertions(+), 71 deletions(-) delete mode 100644 examples/post.rs diff --git a/Cargo.toml b/Cargo.toml index 4239d8f..ecf8c12 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -10,9 +10,10 @@ license = "MIT/Apache-2.0" [dependencies] hyper = { version = "0.9" , default-features = false } +log = "0.3" serde = "0.8" serde_json = "0.8" -log = "0.3" +url = "1.0" [dependencies.native-tls] git = "https://github.com/sfackler/rust-native-tls" @@ -25,4 +26,3 @@ tls = ["native-tls"] [dev-dependencies] env_logger = "0.3" -serde_derive = "0.8" diff --git a/examples/post.rs b/examples/post.rs deleted file mode 100644 index 7c6379c..0000000 --- a/examples/post.rs +++ /dev/null @@ -1,40 +0,0 @@ -//#![feature(proc_macro)] - -extern crate reqwest; -extern crate env_logger; -//#[macro_use] extern crate serde_derive; - -/* -#[derive(Serialize)] -struct Thingy { - a: i32, - b: bool, - c: String, -} -*/ - -fn main() { - env_logger::init().unwrap(); - - println!("POST https://httpbin.org/post"); - - /* - let thingy = Thingy { - a: 5, - b: true, - c: String::from("reqwest") - }; - */ - - let client = reqwest::Client::new(); - let mut res = client.post("https://httpbin.org/post") - .body("foo=bar") - .send().unwrap(); - - println!("Status: {}", res.status()); - println!("Headers:\n{}", res.headers()); - - ::std::io::copy(&mut res, &mut ::std::io::stdout()).unwrap(); - - println!("\n\nDone."); -} diff --git a/src/client.rs b/src/client.rs index 2e36612..6952367 100644 --- a/src/client.rs +++ b/src/client.rs @@ -1,6 +1,7 @@ use std::io::{self, Read}; -use hyper::header::{Headers, ContentType, UserAgent}; +use hyper::client::IntoUrl; +use hyper::header::{Headers, ContentType, Location, Referer, UserAgent}; use hyper::method::Method; use hyper::status::StatusCode; use hyper::version::HttpVersion; @@ -26,28 +27,35 @@ pub struct Client { impl Client { /// Constructs a new `Client`. - pub fn new() -> Client { - Client { - inner: new_hyper_client() - } + pub fn new() -> ::Result { + let mut client = try!(new_hyper_client()); + client.set_redirect_policy(::hyper::client::RedirectPolicy::FollowNone); + Ok(Client { + inner: client + }) } /// Convenience method to make a `GET` request to a URL. - pub fn get(&self, url: &str) -> RequestBuilder { - self.request(Method::Get, Url::parse(url).unwrap()) + pub fn get(&self, url: U) -> RequestBuilder { + self.request(Method::Get, url) } /// Convenience method to make a `POST` request to a URL. - pub fn post(&self, url: &str) -> RequestBuilder { - self.request(Method::Post, Url::parse(url).unwrap()) + pub fn post(&self, url: U) -> RequestBuilder { + self.request(Method::Post, url) + } + + /// Convenience method to make a `HEAD` request to a URL. + pub fn head(&self, url: U) -> RequestBuilder { + self.request(Method::Head, url) } /// Start building a `Request` with the `Method` and `Url`. /// /// Returns a `RequestBuilder`, which will allow setting headers and /// request body before sending. - pub fn request(&self, method: Method, url: Url) -> RequestBuilder { - debug!("request {:?} \"{}\"", method, url); + pub fn request(&self, method: Method, url: U) -> RequestBuilder { + let url = url.into_url(); RequestBuilder { client: self, method: method, @@ -61,19 +69,19 @@ impl Client { } #[cfg(not(feature = "tls"))] -fn new_hyper_client() -> ::hyper::Client { - ::hyper::Client::new() +fn new_hyper_client() -> ::Result<::hyper::Client> { + Ok(::hyper::Client::new()) } #[cfg(feature = "tls")] -fn new_hyper_client() -> ::hyper::Client { +fn new_hyper_client() -> ::Result<::hyper::Client> { use tls::TlsClient; - ::hyper::Client::with_connector( + Ok(::hyper::Client::with_connector( ::hyper::client::Pool::with_connector( Default::default(), - ::hyper::net::HttpsConnector::new(TlsClient::new().unwrap()) + ::hyper::net::HttpsConnector::new(try!(TlsClient::new())) ) - ) + )) } @@ -82,7 +90,7 @@ pub struct RequestBuilder<'a> { client: &'a Client, method: Method, - url: Url, + url: Result, _version: HttpVersion, headers: Headers, @@ -91,6 +99,15 @@ pub struct RequestBuilder<'a> { impl<'a> RequestBuilder<'a> { /// Add a `Header` to this Request. + /// + /// ```no_run + /// use reqwest::header::UserAgent; + /// let client = reqwest::Client::new().expect("client failed to construct"); + /// + /// let res = client.get("https://www.rust-lang.org") + /// .header(UserAgent("foo".to_string())) + /// .send(); + /// ``` pub fn header(mut self, header: H) -> RequestBuilder<'a> { self.headers.set(header); self @@ -109,6 +126,20 @@ impl<'a> RequestBuilder<'a> { self } + /// Send a JSON body. + /// + /// Sets the body to the JSON serialization of the passed value, and + /// also sets the `Content-Type: application/json` header. + /// + /// ```no_run + /// # use std::collections::HashMap; + /// let mut map = HashMap::new(); + /// map.insert("lang", "rust"); + /// + /// let res = reqwest::post("http://www.rust-lang.org") + /// .json(map) + /// .send(); + /// ``` pub fn json(mut self, json: T) -> RequestBuilder<'a> { let body = serde_json::to_vec(&json).expect("serde to_vec cannot fail"); self.headers.set(ContentType::json()); @@ -122,18 +153,79 @@ impl<'a> RequestBuilder<'a> { self.headers.set(UserAgent(DEFAULT_USER_AGENT.to_owned())); } - let mut req = self.client.inner.request(self.method, self.url) - .headers(self.headers); + let client = self.client; + let mut method = self.method; + let mut url = try!(self.url); + let mut headers = self.headers; + let mut body = self.body; - if let Some(ref b) = self.body { - let body = body::as_hyper_body(b); - req = req.body(body); + let mut redirect_count = 0; + + loop { + let res = { + debug!("request {:?} \"{}\"", method, url); + let mut req = client.inner.request(method.clone(), url.clone()) + .headers(headers.clone()); + + if let Some(ref b) = body { + let body = body::as_hyper_body(&b); + req = req.body(body); + } + + try!(req.send()) + }; + body.take(); + + match res.status { + StatusCode::MovedPermanently | + StatusCode::Found => { + + //TODO: turn this into self.redirect_policy.check() + if redirect_count > 10 { + return Err(::Error::TooManyRedirects); + } + redirect_count += 1; + + method = match method { + Method::Post | Method::Put => Method::Get, + m => m + }; + + headers.set(Referer(url.to_string())); + + let loc = { + let loc = res.headers.get::().map(|loc| url.join(loc)); + if let Some(loc) = loc { + loc + } else { + return Ok(Response { + inner: res + }); + } + }; + + url = match loc { + Ok(u) => u, + Err(e) => { + debug!("Location header had invalid URI: {:?}", e); + return Ok(Response { + inner: res + }) + } + }; + + debug!("redirecting to '{}'", url); + + //TODO: removeSensitiveHeaders(&mut headers, &url); + + }, + _ => { + return Ok(Response { + inner: res + }); + } + } } - - let res = try!(req.send()); - Ok(Response { - inner: res - }) } } diff --git a/src/error.rs b/src/error.rs index 6bb8e3b..00167b3 100644 --- a/src/error.rs +++ b/src/error.rs @@ -1,17 +1,60 @@ +use std::error::Error as StdError; +use std::fmt; + /// The Errors that may occur when processing a `Request`. #[derive(Debug)] pub enum Error { /// An HTTP error from the `hyper` crate. Http(::hyper::Error), + /// A request tried to redirect too many times. + TooManyRedirects, #[doc(hidden)] __DontMatchMe, } +impl fmt::Display for Error { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + match *self { + Error::Http(ref e) => fmt::Display::fmt(e, f), + Error::TooManyRedirects => { + f.pad("Too many redirects") + }, + Error::__DontMatchMe => unreachable!() + } + } +} + +impl StdError for Error { + fn description(&self) -> &str { + match *self { + Error::Http(ref e) => e.description(), + Error::TooManyRedirects => "Too many redirects", + Error::__DontMatchMe => unreachable!() + } + } + + fn cause(&self) -> Option<&StdError> { + match *self { + Error::Http(ref e) => Some(e), + Error::TooManyRedirects => None, + Error::__DontMatchMe => unreachable!() + } + } +} + impl From<::hyper::Error> for Error { fn from(err: ::hyper::Error) -> Error { Error::Http(err) } } +impl From<::url::ParseError> for Error { + fn from(err: ::url::ParseError) -> Error { + Error::Http(::hyper::Error::Uri(err)) + } +} + + + /// A `Result` alias where the `Err` case is `reqwest::Error`. pub type Result = ::std::result::Result; diff --git a/src/lib.rs b/src/lib.rs index 7559eed..7799a37 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -22,22 +22,30 @@ //! //! ## Making a GET request //! +//! For a single request, you can use the `get` shortcut method. +//! +//! //! ```no_run //! let resp = reqwest::get("https://www.rust-lang.org").unwrap(); //! assert!(resp.status().is_success()); //! ``` +//! +//! If you plan to perform multiple requests, it is best to create a [`Client`][client] +//! and reuse it, taking advantage of keep-alive connection pooling. extern crate hyper; #[macro_use] extern crate log; #[cfg(feature = "tls")] extern crate native_tls; extern crate serde; extern crate serde_json; +extern crate url; pub use hyper::header; pub use hyper::method::Method; pub use hyper::status::StatusCode; pub use hyper::version::HttpVersion; pub use hyper::Url; +pub use url::ParseError as UrlError; pub use self::client::{Client, Response}; pub use self::error::{Error, Result}; @@ -50,6 +58,6 @@ mod error; /// Shortcut method to quickly make a `GET` request. pub fn get(url: &str) -> ::Result { - let client = Client::new(); + let client = try!(Client::new()); client.get(url).send() }