Include new cookie header after a redirect (#514)

Closes #510
This commit is contained in:
WindSoilder
2019-05-01 06:15:41 +08:00
committed by Sean McArthur
parent 66a88d946b
commit e0a52dcf5d
2 changed files with 78 additions and 12 deletions

View File

@@ -403,10 +403,10 @@ impl ClientBuilder {
} }
/// Enable a persistent cookie store for the client. /// Enable a persistent cookie store for the client.
/// ///
/// Cookies received in responses will be preserved and included in /// Cookies received in responses will be preserved and included in
/// additional requests. /// additional requests.
/// ///
/// By default, no cookie store is used. /// By default, no cookie store is used.
pub fn cookie_store(mut self, enable: bool) -> ClientBuilder { pub fn cookie_store(mut self, enable: bool) -> ClientBuilder {
self.config.cookie_store = if enable { self.config.cookie_store = if enable {
@@ -549,15 +549,7 @@ impl Client {
if let Some(cookie_store_wrapper) = self.inner.cookie_store.as_ref() { if let Some(cookie_store_wrapper) = self.inner.cookie_store.as_ref() {
if headers.get(::header::COOKIE).is_none() { if headers.get(::header::COOKIE).is_none() {
let cookie_store = cookie_store_wrapper.read().unwrap(); let cookie_store = cookie_store_wrapper.read().unwrap();
let header = cookie_store add_cookie_header(&mut headers, &cookie_store, &url);
.0
.get_request_cookies(&url)
.map(|c| c.encoded().to_string())
.collect::<Vec<_>>()
.join("; ");
if !header.is_empty() {
headers.insert(::header::COOKIE, HeaderValue::from_bytes(header.as_bytes()).unwrap());
}
} }
} }
@@ -823,6 +815,12 @@ impl Future for PendingRequest {
.body(body) .body(body)
.expect("valid request parts"); .expect("valid request parts");
// Add cookies from the cookie store.
if let Some(cookie_store_wrapper) = self.client.cookie_store.as_ref() {
let cookie_store = cookie_store_wrapper.read().unwrap();
add_cookie_header(&mut self.headers, &cookie_store, &self.url);
}
*req.headers_mut() = self.headers.clone(); *req.headers_mut() = self.headers.clone();
self.in_flight = self.client.hyper.request(req); self.in_flight = self.client.hyper.request(req);
continue; continue;
@@ -874,3 +872,18 @@ fn make_referer(next: &Url, previous: &Url) -> Option<HeaderValue> {
referer.set_fragment(None); referer.set_fragment(None);
referer.as_str().parse().ok() referer.as_str().parse().ok()
} }
fn add_cookie_header(headers: &mut HeaderMap, cookie_store: &cookie::CookieStore, url: &Url) {
let header = cookie_store
.0
.get_request_cookies(url)
.map(|c| c.encoded().to_string())
.collect::<Vec<_>>()
.join("; ");
if !header.is_empty() {
headers.insert(
::header::COOKIE,
HeaderValue::from_bytes(header.as_bytes()).unwrap()
);
}
}

View File

@@ -387,3 +387,56 @@ fn test_invalid_location_stops_redirect_gh484() {
assert_eq!(res.status(), reqwest::StatusCode::FOUND); assert_eq!(res.status(), reqwest::StatusCode::FOUND);
assert_eq!(res.headers().get(reqwest::header::SERVER).unwrap(), &"test-yikes"); assert_eq!(res.headers().get(reqwest::header::SERVER).unwrap(), &"test-yikes");
} }
#[test]
fn test_redirect_302_with_set_cookies() {
let code = 302;
let client = reqwest::ClientBuilder::new().cookie_store(true).build().unwrap();
let server = server! {
request: format!("\
GET /{} HTTP/1.1\r\n\
user-agent: $USERAGENT\r\n\
accept: */*\r\n\
accept-encoding: gzip\r\n\
host: $HOST\r\n\
\r\n\
", code),
response: format!("\
HTTP/1.1 {} reason\r\n\
Server: test-redirect\r\n\
Content-Length: 0\r\n\
Location: /dst\r\n\
Connection: close\r\n\
Set-Cookie: key=value\r\n\
\r\n\
", code)
;
request: format!("\
GET /dst HTTP/1.1\r\n\
user-agent: $USERAGENT\r\n\
accept: */*\r\n\
accept-encoding: gzip\r\n\
referer: http://$HOST/{}\r\n\
cookie: key=value\r\n\
host: $HOST\r\n\
\r\n\
", code),
response: b"\
HTTP/1.1 200 OK\r\n\
Server: test-dst\r\n\
Content-Length: 0\r\n\
\r\n\
"
};
let url = format!("http://{}/{}", server.addr(), code);
let dst = format!("http://{}/{}", server.addr(), "dst");
let res = client.get(&url)
.send()
.unwrap();
assert_eq!(res.url().as_str(), dst);
assert_eq!(res.status(), reqwest::StatusCode::OK);
assert_eq!(res.headers().get(reqwest::header::SERVER).unwrap(), &"test-dst");
}