diff --git a/src/multipart.rs b/src/multipart.rs index 6e445c6..8dc58dd 100644 --- a/src/multipart.rs +++ b/src/multipart.rs @@ -6,8 +6,7 @@ use std::io::{self, Cursor, Read}; use std::path::Path; use mime_guess::{self, Mime}; -use url::percent_encoding; -use url::percent_encoding::EncodeSet; +use url::percent_encoding::{self, EncodeSet, PATH_SEGMENT_ENCODE_SET}; use uuid::Uuid; use http::HeaderMap; @@ -18,6 +17,12 @@ pub struct Form { boundary: String, fields: Vec<(Cow<'static, str>, Part)>, headers: Vec>, + percent_encoding: PercentEncoding, +} + +enum PercentEncoding { + PathSegment, + AttrChar, } impl Form { @@ -27,6 +32,7 @@ impl Form { boundary: format!("{}", Uuid::new_v4().to_simple()), fields: Vec::new(), headers: Vec::new(), + percent_encoding: PercentEncoding::PathSegment, } } @@ -84,6 +90,18 @@ impl Form { self } + /// Configure this `Form` to percent-encode using the `path-segment` rules. + pub fn percent_encode_path_segment(mut self) -> Form { + self.percent_encoding = PercentEncoding::PathSegment; + self + } + + /// Configure this `Form` to percent-encode using the `attr-char` rules. + pub fn percent_encode_attr_chars(mut self) -> Form { + self.percent_encoding = PercentEncoding::AttrChar; + self + } + pub(crate) fn reader(self) -> Reader { Reader::new(self) } @@ -98,7 +116,7 @@ impl Form { Some(value_length) => { // We are constructing the header just to get its length. To not have to // construct it again when the request is sent we cache these headers. - let header = header(name, field); + let header = self.percent_encoding.encode_headers(name, field); let header_length = header.len(); self.headers.push(header); // The additions mimick the format string out of which the field is constructed @@ -277,7 +295,7 @@ impl Reader { let mut h = if self.form.headers.len() > 0 { self.form.headers.remove(0) } else { - header(&name, &field) + self.form.percent_encoding.encode_headers(&name, &field) }; h.extend_from_slice(b"\r\n\r\n"); h @@ -350,41 +368,51 @@ impl EncodeSet for AttrCharEncodeSet { } -fn header(name: &str, field: &Part) -> Vec { - let s = format!( - "Content-Disposition: form-data; {}{}{}", - format_parameter("name", name), - match field.file_name { - Some(ref file_name) => format!("; {}", format_parameter("filename", file_name)), - None => String::new(), - }, - match field.mime { - Some(ref mime) => format!("\r\nContent-Type: {}", mime), - None => "".to_string(), - }, - ); - field.headers.iter().fold(s.into_bytes(), |mut header, (k,v)| { - header.extend_from_slice(b"\r\n"); - header.extend_from_slice(k.as_str().as_bytes()); - header.extend_from_slice(b": "); - header.extend_from_slice(v.as_bytes()); - header - }) -} +impl PercentEncoding { + fn encode_headers(&self, name: &str, field: &Part) -> Vec { + let s = format!( + "Content-Disposition: form-data; {}{}{}", + self.format_parameter("name", name), + match field.file_name { + Some(ref file_name) => format!("; {}", self.format_parameter("filename", file_name)), + None => String::new(), + }, + match field.mime { + Some(ref mime) => format!("\r\nContent-Type: {}", mime), + None => "".to_string(), + }, + ); + field.headers.iter().fold(s.into_bytes(), |mut header, (k,v)| { + header.extend_from_slice(b"\r\n"); + header.extend_from_slice(k.as_str().as_bytes()); + header.extend_from_slice(b": "); + header.extend_from_slice(v.as_bytes()); + header + }) + } -fn format_parameter(name: &str, value: &str) -> String { - let legal_value = - percent_encoding::utf8_percent_encode(value, AttrCharEncodeSet) - .to_string(); - if value.len() == legal_value.len() { - // nothing has been percent encoded - format!("{}=\"{}\"", name, value) - } else { - // something has been percent encoded - format!("{}*=utf-8''{}", name, legal_value) + fn format_parameter(&self, name: &str, value: &str) -> String { + let legal_value = match *self { + PercentEncoding::PathSegment => { + percent_encoding::utf8_percent_encode(value, PATH_SEGMENT_ENCODE_SET) + .to_string() + }, + PercentEncoding::AttrChar => { + percent_encoding::utf8_percent_encode(value, AttrCharEncodeSet) + .to_string() + }, + }; + if value.len() == legal_value.len() { + // nothing has been percent encoded + format!("{}=\"{}\"", name, value) + } else { + // something has been percent encoded + format!("{}*=utf-8''{}", name, legal_value) + } } } + #[cfg(test)] mod tests { use super::*; @@ -507,8 +535,15 @@ mod tests { fn header_percent_encoding() { let name = "start%'\"\r\nßend"; let field = Part::text(""); - let expected = "Content-Disposition: form-data; name*=utf-8''start%25%27%22%0D%0A%C3%9Fend"; - assert_eq!(header(name, &field), expected.as_bytes()); + assert_eq!( + PercentEncoding::PathSegment.encode_headers(name, &field), + &b"Content-Disposition: form-data; name*=utf-8''start%25'%22%0D%0A%C3%9Fend"[..] + ); + + assert_eq!( + PercentEncoding::AttrChar.encode_headers(name, &field), + &b"Content-Disposition: form-data; name*=utf-8''start%25%27%22%0D%0A%C3%9Fend"[..] + ); } }