diff --git a/Cargo.toml b/Cargo.toml index a3cfec0..b8c3b18 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -14,7 +14,7 @@ base64 = "0.9" bytes = "0.4" encoding_rs = "0.8" futures = "0.1.23" -http = "0.1.5" +http = "0.1.10" hyper = "0.12.7" hyper-old-types = { version = "0.11", optional = true, features = ["compat"] } hyper-tls = "0.3" diff --git a/src/body.rs b/src/body.rs index e7ca538..eee5cab 100644 --- a/src/body.rs +++ b/src/body.rs @@ -3,6 +3,7 @@ use std::fmt; use std::io::{self, Cursor, Read}; use bytes::Bytes; +use futures::Future; use hyper::{self}; use {async_impl}; @@ -77,6 +78,37 @@ impl Body { kind: Kind::Reader(Box::from(reader), Some(len)), } } + + pub(crate) fn len(&self) -> Option { + match self.kind { + Kind::Reader(_, len) => len, + Kind::Bytes(ref bytes) => Some(bytes.len() as u64), + } + } + + pub(crate) fn into_reader(self) -> Reader { + match self.kind { + Kind::Reader(r, _) => Reader::Reader(r), + Kind::Bytes(b) => Reader::Bytes(Cursor::new(b)), + } + } + + pub(crate) fn into_async(self) -> (Option, async_impl::Body, Option) { + match self.kind { + Kind::Reader(read, len) => { + let (tx, rx) = hyper::Body::channel(); + let tx = Sender { + body: (read, len), + tx: tx, + }; + (Some(tx), async_impl::body::wrap(rx), len) + }, + Kind::Bytes(chunk) => { + let len = chunk.len() as u64; + (None, async_impl::body::reusable(chunk), Some(len)) + } + } + } } @@ -150,29 +182,11 @@ impl<'a> fmt::Debug for DebugLength<'a> { } } - -// pub(crate) - -pub fn len(body: &Body) -> Option { - match body.kind { - Kind::Reader(_, len) => len, - Kind::Bytes(ref bytes) => Some(bytes.len() as u64), - } -} - -pub enum Reader { +pub(crate) enum Reader { Reader(Box), Bytes(Cursor), } -#[inline] -pub fn reader(body: Body) -> Reader { - match body.kind { - Kind::Reader(r, _) => Reader::Reader(r), - Kind::Bytes(b) => Reader::Bytes(Cursor::new(b)), - } -} - impl Read for Reader { fn read(&mut self, buf: &mut [u8]) -> io::Result { match *self { @@ -182,25 +196,37 @@ impl Read for Reader { } } -pub struct Sender { +pub(crate) struct Sender { body: (Box, Option), tx: hyper::body::Sender, } impl Sender { - pub fn send(self) -> ::Result<()> { + // A `Future` that may do blocking read calls. + // As a `Future`, this integrates easily with `wait::timeout`. + pub(crate) fn send(self) -> impl Future { use std::cmp; use bytes::{BufMut, BytesMut}; + use futures::future; let cap = cmp::min(self.body.1.unwrap_or(8192), 8192); let mut buf = BytesMut::with_capacity(cap as usize); let mut body = self.body.0; - let mut tx = self.tx; - loop { + // Put in an option so that it can be consumed on error to call abort() + let mut tx = Some(self.tx); + + future::poll_fn(move || loop { + try_ready!(tx + .as_mut() + .expect("tx only taken on error") + .poll_ready() + .map_err(::error::from)); + match body.read(unsafe { buf.bytes_mut() }) { - Ok(0) => return Ok(()), + Ok(0) => return Ok(().into()), Ok(n) => { unsafe { buf.advance_mut(n); } + let tx = tx.as_mut().expect("tx only taken on error"); if let Err(_) = tx.send_data(buf.take().freeze().into()) { return Err(::error::timedout(None)); } @@ -210,35 +236,20 @@ impl Sender { } Err(e) => { let ret = io::Error::new(e.kind(), e.to_string()); - tx.abort(); + tx + .take() + .expect("tx only taken on error") + .abort(); return Err(::error::from(ret)); } } - } - } -} - -#[inline] -pub fn async(body: Body) -> (Option, async_impl::Body, Option) { - match body.kind { - Kind::Reader(read, len) => { - let (tx, rx) = hyper::Body::channel(); - let tx = Sender { - body: (read, len), - tx: tx, - }; - (Some(tx), async_impl::body::wrap(rx), len) - }, - Kind::Bytes(chunk) => { - let len = chunk.len() as u64; - (None, async_impl::body::reusable(chunk), Some(len)) - } + }) } } // useful for tests, but not publicly exposed #[cfg(test)] -pub fn read_to_string(mut body: Body) -> io::Result { +pub(crate) fn read_to_string(mut body: Body) -> io::Result { let mut s = String::new(); match body.kind { Kind::Reader(ref mut reader, _) => reader.read_to_string(&mut s), diff --git a/src/client.rs b/src/client.rs index 7fd9c0e..3067c83 100644 --- a/src/client.rs +++ b/src/client.rs @@ -4,6 +4,7 @@ use std::time::Duration; use std::thread; use futures::{Future, Stream}; +use futures::future::{self, Either}; use futures::sync::{mpsc, oneshot}; use request::{Request, RequestBuilder}; @@ -466,20 +467,29 @@ impl ClientHandle { .unbounded_send((req, tx)) .expect("core thread panicked"); - if let Some(body) = body { - try_!(body.send(), &url); - } + let write = if let Some(body) = body { + Either::A(body.send()) + //try_!(body.send(self.timeout.0), &url); + } else { + Either::B(future::ok(())) + }; - let res = match wait::timeout(rx, self.timeout.0) { + let rx = rx.map_err(|_canceled| { + // The only possible reason there would be a Canceled error + // is if the thread running the event loop panicked. We could return + // an Err here, like a BrokenPipe, but the Client is not + // recoverable. Additionally, the panic in the other thread + // is not normal, and should likely be propagated. + panic!("event loop thread panicked"); + }); + + let fut = write.join(rx).map(|((), res)| res); + + let res = match wait::timeout(fut, self.timeout.0) { Ok(res) => res, Err(wait::Waited::TimedOut) => return Err(::error::timedout(Some(url))), - Err(wait::Waited::Err(_canceled)) => { - // The only possible reason there would be a Cancelled error - // is if the thread running the Core panicked. We could return - // an Err here, like a BrokenPipe, but the Client is not - // recoverable. Additionally, the panic in the other thread - // is not normal, and should likely be propagated. - panic!("core thread panicked"); + Err(wait::Waited::Err(err)) => { + return Err(err.with_url(url)); } }; res.map(|res| { diff --git a/src/error.rs b/src/error.rs index 27b53ea..5b49c84 100644 --- a/src/error.rs +++ b/src/error.rs @@ -82,6 +82,12 @@ impl Error { self.url.as_ref() } + pub(crate) fn with_url(mut self, url: Url) -> Error { + debug_assert_eq!(self.url, None, "with_url overriding existing url"); + self.url = Some(url); + self + } + /// Returns a reference to the internal error, if available. /// /// The `'static` bounds allows using `downcast_ref` to check the diff --git a/src/multipart.rs b/src/multipart.rs index 37679a2..1b1b7ff 100644 --- a/src/multipart.rs +++ b/src/multipart.rs @@ -93,7 +93,7 @@ impl Form { pub(crate) fn compute_length(&mut self) -> Option { let mut length = 0u64; for &(ref name, ref field) in self.fields.iter() { - match ::body::len(&field.value) { + match field.value.len() { Some(value_length) => { // We are constructing the header just to get its length. To not have to // construct it again when the request is sent we cache these headers. @@ -272,7 +272,7 @@ impl Reader { }); let reader = boundary .chain(header) - .chain(::body::reader(field.value)) + .chain(field.value.into_reader()) .chain(Cursor::new("\r\n")); // According to https://tools.ietf.org/html/rfc2046#section-5.1.1 // the very last field has a special boundary diff --git a/src/request.rs b/src/request.rs index cb11160..30af213 100644 --- a/src/request.rs +++ b/src/request.rs @@ -86,9 +86,9 @@ impl Request { let mut req_async = self.inner; let body = self.body.and_then(|body| { - let (tx, body, len) = body::async(body); + let (tx, body, len) = body.into_async(); if let Some(len) = len { - req_async.headers_mut().insert(CONTENT_LENGTH, HeaderValue::from_str(len.to_string().as_str()).expect("")); + req_async.headers_mut().insert(CONTENT_LENGTH, len.into()); } *req_async.body_mut() = Some(body); tx diff --git a/tests/multipart.rs b/tests/multipart.rs index bcfaa23..bf07613 100644 --- a/tests/multipart.rs +++ b/tests/multipart.rs @@ -49,3 +49,52 @@ fn test_multipart() { assert_eq!(res.url().as_str(), &url); assert_eq!(res.status(), reqwest::StatusCode::OK); } + +#[test] +fn file() { + let _ = env_logger::try_init(); + + let form = reqwest::multipart::Form::new() + .file("foo", "Cargo.lock").unwrap(); + + let fcontents = ::std::fs::read_to_string("Cargo.lock").unwrap(); + + let expected_body = format!("\ + --{0}\r\n\ + Content-Disposition: form-data; name=\"foo\"; filename=\"Cargo.lock\"\r\n\ + Content-Type: application/octet-stream\r\n\r\n\ + {1}\r\n\ + --{0}--\r\n\ + ", form.boundary(), fcontents); + + let server = server! { + request: format!("\ + POST /multipart/2 HTTP/1.1\r\n\ + user-agent: $USERAGENT\r\n\ + accept: */*\r\n\ + content-type: multipart/form-data; boundary={}\r\n\ + content-length: {}\r\n\ + accept-encoding: gzip\r\n\ + host: $HOST\r\n\ + \r\n\ + {}\ + ", form.boundary(), expected_body.len(), expected_body), + response: b"\ + HTTP/1.1 200 OK\r\n\ + Server: multipart\r\n\ + Content-Length: 0\r\n\ + \r\n\ + " + }; + + let url = format!("http://{}/multipart/2", server.addr()); + + let res = reqwest::Client::new() + .post(&url) + .multipart(form) + .send() + .unwrap(); + + assert_eq!(res.url().as_str(), &url); + assert_eq!(res.status(), reqwest::StatusCode::OK); +} diff --git a/tests/support/server.rs b/tests/support/server.rs index 145bbf0..03fecc5 100644 --- a/tests/support/server.rs +++ b/tests/support/server.rs @@ -32,7 +32,8 @@ static DEFAULT_USER_AGENT: &'static str = pub fn spawn(txns: Vec) -> Server { let listener = net::TcpListener::bind("127.0.0.1:0").unwrap(); let addr = listener.local_addr().unwrap(); - thread::spawn( + let tname = format!("test({})-support-server", thread::current().name().unwrap_or("")); + thread::Builder::new().name(tname).spawn( move || for txn in txns { let mut expected = txn.request; let reply = txn.response; @@ -46,20 +47,41 @@ pub fn spawn(txns: Vec) -> Server { thread::park_timeout(dur); } - let mut buf = [0; 4096]; - assert!(buf.len() >= expected.len()); + let mut buf = vec![0; expected.len() + 256]; let mut n = 0; while n < expected.len() { match socket.read(&mut buf[n..]) { - Ok(0) | Err(_) => break, + Ok(0) => break, Ok(nread) => n += nread, + Err(err) => { + println!("server read error: {}", err); + break; + } } } match (::std::str::from_utf8(&expected), ::std::str::from_utf8(&buf[..n])) { - (Ok(expected), Ok(received)) => assert_eq!(expected, received), - _ => assert_eq!(expected, &buf[..n]), + (Ok(expected), Ok(received)) => { + assert_eq!( + expected.len(), + received.len(), + "expected len = {}, received len = {}", + expected.len(), + received.len(), + ); + assert_eq!(expected, received) + }, + _ => { + assert_eq!( + expected.len(), + n, + "expected len = {}, received len = {}", + expected.len(), + n, + ); + assert_eq!(expected, &buf[..n]) + }, } if let Some(dur) = txn.response_timeout { @@ -86,7 +108,7 @@ pub fn spawn(txns: Vec) -> Server { socket.write_all(&reply).unwrap(); } } - ); + ).expect("server thread spawn"); Server { addr: addr,