diff --git a/.gitignore b/.gitignore index a9d37c5..d4f917d 100644 --- a/.gitignore +++ b/.gitignore @@ -1,2 +1,3 @@ target Cargo.lock +*.swp diff --git a/src/async_impl/client.rs b/src/async_impl/client.rs index 32ab23d..9d1d57c 100644 --- a/src/async_impl/client.rs +++ b/src/async_impl/client.rs @@ -63,6 +63,7 @@ pub struct ClientBuilder { struct Config { gzip: bool, + headers: Headers, hostname_verification: bool, proxies: Vec, redirect_policy: RedirectPolicy, @@ -76,18 +77,25 @@ impl ClientBuilder { /// Constructs a new `ClientBuilder` pub fn new() -> ClientBuilder { match TlsConnector::builder() { - Ok(tls_connector_builder) => ClientBuilder { - config: Some(Config { - gzip: true, - hostname_verification: true, - proxies: Vec::new(), - redirect_policy: RedirectPolicy::default(), - referer: true, - timeout: None, - tls: tls_connector_builder, - dns_threads: 4, - }), - err: None, + Ok(tls_connector_builder) => { + let mut headers = Headers::with_capacity(2); + headers.set(UserAgent::new(DEFAULT_USER_AGENT)); + headers.set(Accept::star()); + + ClientBuilder { + config: Some(Config { + gzip: true, + headers: headers, + hostname_verification: true, + proxies: Vec::new(), + redirect_policy: RedirectPolicy::default(), + referer: true, + timeout: None, + tls: tls_connector_builder, + dns_threads: 4, + }), + err: None, + } }, Err(e) => ClientBuilder { config: None, @@ -131,6 +139,7 @@ impl ClientBuilder { inner: Arc::new(ClientRef { gzip: config.gzip, hyper: hyper_client, + headers: config.headers, proxies: proxies, redirect_policy: config.redirect_policy, referer: config.referer, @@ -189,6 +198,15 @@ impl ClientBuilder { self } + /// Sets the default headers for every request. + #[inline] + pub fn default_headers(&mut self, headers: Headers) -> &mut ClientBuilder { + if let Some(config) = config_mut(&mut self.config, &self.err) { + config.headers.extend(headers.iter()); + } + self + } + /// Enable auto gzip decompression by checking the ContentEncoding response header. /// /// Default is enabled. @@ -372,17 +390,13 @@ impl Client { let ( method, url, - mut headers, + user_headers, body ) = request::pieces(req); - if !headers.has::() { - headers.set(UserAgent::new(DEFAULT_USER_AGENT)); - } + let mut headers = self.inner.headers.clone(); // default headers + headers.extend(user_headers.iter()); - if !headers.has::() { - headers.set(Accept::star()); - } if self.inner.gzip && !headers.has::() && !headers.has::() { @@ -442,6 +456,7 @@ impl fmt::Debug for ClientBuilder { struct ClientRef { gzip: bool, + headers: Headers, hyper: HyperClient, proxies: Arc>, redirect_policy: RedirectPolicy, diff --git a/src/client.rs b/src/client.rs index 480c8d7..11f154d 100644 --- a/src/client.rs +++ b/src/client.rs @@ -8,7 +8,7 @@ use futures::sync::{mpsc, oneshot}; use request::{self, Request, RequestBuilder}; use response::{self, Response}; -use {async_impl, Certificate, Identity, Method, IntoUrl, Proxy, RedirectPolicy, wait}; +use {async_impl, header, Certificate, Identity, Method, IntoUrl, Proxy, RedirectPolicy, wait}; /// A `Client` to make Requests with. /// @@ -167,6 +167,50 @@ impl ClientBuilder { self } + /// Sets the default headers for every request. + /// + /// # Example + /// + /// ```rust + /// use reqwest::header; + /// # fn build_client() -> Result<(), Box> { + /// let mut headers = header::Headers::new(); + /// headers.set(header::Authorization("secret".to_string())); + /// + /// // get a client builder + /// let client = reqwest::Client::builder() + /// .default_headers(headers) + /// .build()?; + /// let res = client.get("https://www.rust-lang.org").send()?; + /// # Ok(()) + /// # } + /// ``` + /// + /// Override the default headers: + /// + /// ```rust + /// use reqwest::header; + /// # fn build_client() -> Result<(), Box> { + /// let mut headers = header::Headers::new(); + /// headers.set(header::Authorization("secret".to_string())); + /// + /// // get a client builder + /// let client = reqwest::Client::builder() + /// .default_headers(headers) + /// .build()?; + /// let res = client + /// .get("https://www.rust-lang.org") + /// .header(header::Authorization("token".to_string())) + /// .send()?; + /// # Ok(()) + /// # } + /// ``` + #[inline] + pub fn default_headers(&mut self, headers: header::Headers) -> &mut ClientBuilder { + self.inner.default_headers(headers); + self + } + /// Enable auto gzip decompression by checking the ContentEncoding response header. /// /// Default is enabled. diff --git a/tests/client.rs b/tests/client.rs index 541f1ba..a302a2a 100644 --- a/tests/client.rs +++ b/tests/client.rs @@ -79,9 +79,9 @@ fn test_post() { request: b"\ POST /2 HTTP/1.1\r\n\ Host: $HOST\r\n\ - Content-Length: 5\r\n\ User-Agent: $USERAGENT\r\n\ Accept: */*\r\n\ + Content-Length: 5\r\n\ Accept-Encoding: gzip\r\n\ \r\n\ Hello\ @@ -170,3 +170,111 @@ fn test_error_for_status_5xx() { assert!(err.is_server_error()); assert_eq!(err.status(), Some(reqwest::StatusCode::InternalServerError)); } + +#[test] +fn test_default_headers() { + use reqwest::header; + let mut headers = header::Headers::with_capacity(1); + let mut cookies = header::Cookie::new(); + cookies.set("a", "b"); + cookies.set("c", "d"); + headers.set(cookies); + let client = reqwest::Client::builder() + .default_headers(headers) + .build().unwrap(); + + let server = server! { + request: b"\ + GET /1 HTTP/1.1\r\n\ + Host: $HOST\r\n\ + User-Agent: $USERAGENT\r\n\ + Accept: */*\r\n\ + Cookie: a=b; c=d\r\n\ + Accept-Encoding: gzip\r\n\ + \r\n\ + ", + response: b"\ + HTTP/1.1 200 OK\r\n\ + Server: test\r\n\ + Content-Length: 0\r\n\ + \r\n\ + " + }; + + let url = format!("http://{}/1", server.addr()); + let res = client.get(&url).send().unwrap(); + + assert_eq!(res.url().as_str(), &url); + assert_eq!(res.status(), reqwest::StatusCode::Ok); + assert_eq!(res.headers().get(), + Some(&reqwest::header::Server::new("test"))); + assert_eq!(res.headers().get(), + Some(&reqwest::header::ContentLength(0))); + + let server = server! { + request: b"\ + GET /2 HTTP/1.1\r\n\ + Host: $HOST\r\n\ + User-Agent: $USERAGENT\r\n\ + Accept: */*\r\n\ + Cookie: a=b; c=d\r\n\ + Accept-Encoding: gzip\r\n\ + \r\n\ + ", + response: b"\ + HTTP/1.1 200 OK\r\n\ + Server: test\r\n\ + Content-Length: 0\r\n\ + \r\n\ + " + }; + + let url = format!("http://{}/2", server.addr()); + let res = client.get(&url).send().unwrap(); + + assert_eq!(res.url().as_str(), &url); + assert_eq!(res.status(), reqwest::StatusCode::Ok); + assert_eq!(res.headers().get(), + Some(&reqwest::header::Server::new("test"))); + assert_eq!(res.headers().get(), + Some(&reqwest::header::ContentLength(0))); +} + +#[test] +fn test_override_default_headers() { + use reqwest::header; + let mut headers = header::Headers::with_capacity(1); + headers.set(header::Authorization("iamatoken".to_string())); + let client = reqwest::Client::builder() + .default_headers(headers) + .build().unwrap(); + + let server = server! { + request: b"\ + GET /3 HTTP/1.1\r\n\ + Host: $HOST\r\n\ + User-Agent: $USERAGENT\r\n\ + Accept: */*\r\n\ + Authorization: secret\r\n\ + Accept-Encoding: gzip\r\n\ + \r\n\ + ", + response: b"\ + HTTP/1.1 200 OK\r\n\ + Server: test\r\n\ + Content-Length: 0\r\n\ + \r\n\ + " + }; + + let url = format!("http://{}/3", server.addr()); + let res = client.get(&url).header(header::Authorization("secret".to_string())).send().unwrap(); + + assert_eq!(res.url().as_str(), &url); + assert_eq!(res.status(), reqwest::StatusCode::Ok); + assert_eq!(res.headers().get(), + Some(&reqwest::header::Server::new("test"))); + assert_eq!(res.headers().get(), + Some(&reqwest::header::ContentLength(0))); + +} diff --git a/tests/gzip.rs b/tests/gzip.rs index 15715c3..02a3083 100644 --- a/tests/gzip.rs +++ b/tests/gzip.rs @@ -114,8 +114,8 @@ fn test_accept_header_is_not_changed_if_set() { request: b"\ GET /accept HTTP/1.1\r\n\ Host: $HOST\r\n\ - Accept: application/json\r\n\ User-Agent: $USERAGENT\r\n\ + Accept: application/json\r\n\ Accept-Encoding: gzip\r\n\ \r\n\ ", @@ -143,9 +143,9 @@ fn test_accept_encoding_header_is_not_changed_if_set() { request: b"\ GET /accept-encoding HTTP/1.1\r\n\ Host: $HOST\r\n\ - Accept-Encoding: identity\r\n\ User-Agent: $USERAGENT\r\n\ Accept: */*\r\n\ + Accept-Encoding: identity\r\n\ \r\n\ ", response: b"\ diff --git a/tests/multipart.rs b/tests/multipart.rs index 247ec7d..16c6207 100644 --- a/tests/multipart.rs +++ b/tests/multipart.rs @@ -22,10 +22,10 @@ fn test_multipart() { request: format!("\ POST /multipart/1 HTTP/1.1\r\n\ Host: $HOST\r\n\ - Content-Type: multipart/form-data; boundary={}\r\n\ - Content-Length: 123\r\n\ User-Agent: $USERAGENT\r\n\ Accept: */*\r\n\ + Content-Type: multipart/form-data; boundary={}\r\n\ + Content-Length: 123\r\n\ Accept-Encoding: gzip\r\n\ \r\n\ {}\ diff --git a/tests/redirect.rs b/tests/redirect.rs index ff46d5a..577f3e4 100644 --- a/tests/redirect.rs +++ b/tests/redirect.rs @@ -117,9 +117,9 @@ fn test_redirect_307_and_308_tries_to_post_again() { request: format!("\ POST /{} HTTP/1.1\r\n\ Host: $HOST\r\n\ - Content-Length: 5\r\n\ User-Agent: $USERAGENT\r\n\ Accept: */*\r\n\ + Content-Length: 5\r\n\ Accept-Encoding: gzip\r\n\ \r\n\ Hello\ @@ -136,9 +136,9 @@ fn test_redirect_307_and_308_tries_to_post_again() { request: format!("\ POST /dst HTTP/1.1\r\n\ Host: $HOST\r\n\ - Content-Length: 5\r\n\ User-Agent: $USERAGENT\r\n\ Accept: */*\r\n\ + Content-Length: 5\r\n\ Accept-Encoding: gzip\r\n\ Referer: http://$HOST/{}\r\n\ \r\n\ @@ -229,9 +229,9 @@ fn test_redirect_removes_sensitive_headers() { request: b"\ GET /sensitive HTTP/1.1\r\n\ Host: $HOST\r\n\ - Cookie: foo=bar\r\n\ User-Agent: $USERAGENT\r\n\ Accept: */*\r\n\ + Cookie: foo=bar\r\n\ Accept-Encoding: gzip\r\n\ \r\n\ ", diff --git a/tests/support/server.rs b/tests/support/server.rs index 9e1a7f3..145bbf0 100644 --- a/tests/support/server.rs +++ b/tests/support/server.rs @@ -51,7 +51,10 @@ pub fn spawn(txns: Vec) -> Server { let mut n = 0; while n < expected.len() { - n += socket.read(&mut buf).unwrap(); + match socket.read(&mut buf[n..]) { + Ok(0) | Err(_) => break, + Ok(nread) => n += nread, + } } match (::std::str::from_utf8(&expected), ::std::str::from_utf8(&buf[..n])) {