diff --git a/src/async_impl/multipart.rs b/src/async_impl/multipart.rs index 45b8269..526c68b 100644 --- a/src/async_impl/multipart.rs +++ b/src/async_impl/multipart.rs @@ -3,7 +3,7 @@ use std::borrow::Cow; use std::fmt; use std::pin::Pin; -use bytes::{Bytes}; +use bytes::Bytes; use http::HeaderMap; use mime_guess::Mime; use percent_encoding::{self, AsciiSet, NON_ALPHANUMERIC}; @@ -22,6 +22,7 @@ pub struct Form { pub struct Part { meta: PartMetadata, value: Body, + body_length: Option, } pub(crate) struct FormParts

{ @@ -190,7 +191,7 @@ impl Part { Cow::Borrowed(slice) => Body::from(slice), Cow::Owned(string) => Body::from(string), }; - Part::new(body) + Part::new(body, None) } /// Makes a new parameter from arbitrary bytes. @@ -202,18 +203,26 @@ impl Part { Cow::Borrowed(slice) => Body::from(slice), Cow::Owned(vec) => Body::from(vec), }; - Part::new(body) + Part::new(body, None) } /// Makes a new parameter from an arbitrary stream. pub fn stream>(value: T) -> Part { - Part::new(value.into()) + Part::new(value.into(), None) } - fn new(value: Body) -> Part { + /// Makes a new parameter from an arbitrary stream with a known length. This is particularly + /// useful when adding something like file contents as a stream, where you can know the content + /// length beforehand. + pub fn stream_with_length>(value: T, length: u64) -> Part { + Part::new(value.into(), Some(length)) + } + + fn new(value: Body, body_length: Option) -> Part { Part { meta: PartMetadata::new(), value, + body_length, } } @@ -241,7 +250,7 @@ impl Part { { Part { meta: func(self.meta), - value: self.value, + ..self } } } @@ -257,7 +266,11 @@ impl fmt::Debug for Part { impl PartProps for Part { fn value_len(&self) -> Option { - self.value.content_length() + if self.body_length.is_some() { + self.body_length + } else { + self.value.content_length() + } } fn metadata(&self) -> &PartMetadata { @@ -508,7 +521,11 @@ mod tests { fn form_empty() { let form = Form::new(); - let mut rt = runtime::Builder::new().basic_scheduler().enable_all().build().expect("new rt"); + let mut rt = runtime::Builder::new() + .basic_scheduler() + .enable_all() + .build() + .expect("new rt"); let body = form.stream().into_stream(); let s = body.map_ok(|try_c| try_c.to_vec()).try_concat(); @@ -524,7 +541,7 @@ mod tests { Part::stream(Body::stream(stream::once(future::ready::< Result, >(Ok( - "part1".to_owned(), + "part1".to_owned() ))))), ) .part("key1", Part::text("value1")) @@ -534,13 +551,12 @@ mod tests { Part::stream(Body::stream(stream::once(future::ready::< Result, >(Ok( - "part2".to_owned(), + "part2".to_owned() ))))), ) .part("key3", Part::text("value3").file_name("filename")); form.inner.boundary = "boundary".to_string(); - let expected = - "--boundary\r\n\ + let expected = "--boundary\r\n\ Content-Disposition: form-data; name=\"reader1\"\r\n\r\n\ part1\r\n\ --boundary\r\n\ @@ -556,7 +572,11 @@ mod tests { --boundary\r\n\ Content-Disposition: form-data; name=\"key3\"; filename=\"filename\"\r\n\r\n\ value3\r\n--boundary--\r\n"; - let mut rt = runtime::Builder::new().basic_scheduler().enable_all().build().expect("new rt"); + let mut rt = runtime::Builder::new() + .basic_scheduler() + .enable_all() + .build() + .expect("new rt"); let body = form.stream().into_stream(); let s = body.map(|try_c| try_c.map(|r| r.to_vec())).try_concat(); @@ -583,7 +603,11 @@ mod tests { \r\n\ value2\r\n\ --boundary--\r\n"; - let mut rt = runtime::Builder::new().basic_scheduler().enable_all().build().expect("new rt"); + let mut rt = runtime::Builder::new() + .basic_scheduler() + .enable_all() + .build() + .expect("new rt"); let body = form.stream().into_stream(); let s = body.map(|try_c| try_c.map(|r| r.to_vec())).try_concat(); @@ -597,6 +621,29 @@ mod tests { assert_eq!(std::str::from_utf8(&out).unwrap(), expected); } + #[test] + fn correct_content_length() { + // Setup an arbitrary data stream + let stream_data = b"just some stream data"; + let stream_len = stream_data.len(); + let stream_data = stream_data + .chunks(3) + .map(|c| Ok::<_, std::io::Error>(Bytes::from(c))); + let the_stream = futures_util::stream::iter(stream_data); + + let bytes_data = b"some bytes data".to_vec(); + let bytes_len = bytes_data.len(); + + let stream_part = Part::stream_with_length(Body::stream(the_stream), stream_len as u64); + let body_part = Part::bytes(bytes_data); + + // A simple check to make sure we get the configured body length + assert_eq!(stream_part.value_len().unwrap(), stream_len as u64); + + // Make sure it delegates to the underlying body if length is not specified + assert_eq!(body_part.value_len().unwrap(), bytes_len as u64); + } + #[test] fn header_percent_encoding() { let name = "start%'\"\r\nßend";