refactor(ffi): Add HeaderCaseMap preserving http1 header casing
This commit is contained in:
		| @@ -44,6 +44,8 @@ where | ||||
|                 error: None, | ||||
|                 keep_alive: KA::Busy, | ||||
|                 method: None, | ||||
|                 #[cfg(feature = "ffi")] | ||||
|                 preserve_header_case: false, | ||||
|                 title_case_headers: false, | ||||
|                 notify_read: false, | ||||
|                 reading: Reading::Init, | ||||
| @@ -142,6 +144,8 @@ where | ||||
|             ParseContext { | ||||
|                 cached_headers: &mut self.state.cached_headers, | ||||
|                 req_method: &mut self.state.method, | ||||
|                 #[cfg(feature = "ffi")] | ||||
|                 preserve_header_case: self.state.preserve_header_case, | ||||
|             } | ||||
|         )) { | ||||
|             Ok(msg) => msg, | ||||
| @@ -474,6 +478,16 @@ where | ||||
|  | ||||
|         self.enforce_version(&mut head); | ||||
|  | ||||
|         // Maybe check if we should preserve header casing on received | ||||
|         // message headers... | ||||
|         #[cfg(feature = "ffi")] | ||||
|         { | ||||
|             if T::is_client() && !self.state.preserve_header_case { | ||||
|                 self.state.preserve_header_case = | ||||
|                     head.extensions.get::<crate::ffi::HeaderCaseMap>().is_some(); | ||||
|             } | ||||
|         } | ||||
|  | ||||
|         let buf = self.io.headers_buf(); | ||||
|         match super::role::encode_headers::<T>( | ||||
|             Encode { | ||||
| @@ -736,6 +750,8 @@ struct State { | ||||
|     /// This is used to know things such as if the message can include | ||||
|     /// a body or not. | ||||
|     method: Option<Method>, | ||||
|     #[cfg(feature = "ffi")] | ||||
|     preserve_header_case: bool, | ||||
|     title_case_headers: bool, | ||||
|     /// Set to true when the Dispatcher should poll read operations | ||||
|     /// again. See the `maybe_notify` method for more. | ||||
|   | ||||
| @@ -492,7 +492,7 @@ cfg_server! { | ||||
|                     version: parts.version, | ||||
|                     subject: parts.status, | ||||
|                     headers: parts.headers, | ||||
|                     extensions: http::Extensions::default(), | ||||
|                     extensions: parts.extensions, | ||||
|                 }; | ||||
|                 Poll::Ready(Some(Ok((head, body)))) | ||||
|             } else { | ||||
| @@ -576,7 +576,7 @@ cfg_client! { | ||||
|                                 version: parts.version, | ||||
|                                 subject: crate::proto::RequestLine(parts.method, parts.uri), | ||||
|                                 headers: parts.headers, | ||||
|                                 extensions: http::Extensions::default(), | ||||
|                                 extensions: parts.extensions, | ||||
|                             }; | ||||
|                             *this.callback = Some(cb); | ||||
|                             Poll::Ready(Some(Ok((head, body)))) | ||||
|   | ||||
| @@ -159,6 +159,8 @@ where | ||||
|                 ParseContext { | ||||
|                     cached_headers: parse_ctx.cached_headers, | ||||
|                     req_method: parse_ctx.req_method, | ||||
|                     #[cfg(feature = "ffi")] | ||||
|                     preserve_header_case: parse_ctx.preserve_header_case, | ||||
|                 }, | ||||
|             )? { | ||||
|                 Some(msg) => { | ||||
| @@ -636,6 +638,8 @@ mod tests { | ||||
|             let parse_ctx = ParseContext { | ||||
|                 cached_headers: &mut None, | ||||
|                 req_method: &mut None, | ||||
|                 #[cfg(feature = "ffi")] | ||||
|                 preserve_header_case: false, | ||||
|             }; | ||||
|             assert!(buffered | ||||
|                 .parse::<ClientTransaction>(cx, parse_ctx) | ||||
|   | ||||
| @@ -70,6 +70,8 @@ pub(crate) struct ParsedMessage<T> { | ||||
| pub(crate) struct ParseContext<'a> { | ||||
|     cached_headers: &'a mut Option<HeaderMap>, | ||||
|     req_method: &'a mut Option<Method>, | ||||
|     #[cfg(feature = "ffi")] | ||||
|     preserve_header_case: bool, | ||||
| } | ||||
|  | ||||
| /// Passed to Http1Transaction::encode | ||||
|   | ||||
| @@ -148,6 +148,7 @@ impl Http1Transaction for Server { | ||||
|                         is_http_11 = false; | ||||
|                         Version::HTTP_10 | ||||
|                     }; | ||||
|                     trace!("headers: {:?}", &req.headers); | ||||
|  | ||||
|                     record_header_indices(bytes, &req.headers, &mut headers_indices)?; | ||||
|                     headers_len = req.headers.len(); | ||||
| @@ -692,6 +693,9 @@ impl Http1Transaction for Client { | ||||
|  | ||||
|             let mut keep_alive = version == Version::HTTP_11; | ||||
|  | ||||
|             #[cfg(feature = "ffi")] | ||||
|             let mut header_case_map = crate::ffi::HeaderCaseMap::default(); | ||||
|  | ||||
|             headers.reserve(headers_len); | ||||
|             for header in &headers_indices[..headers_len] { | ||||
|                 let name = header_name!(&slice[header.name.0..header.name.1]); | ||||
| @@ -707,14 +711,28 @@ impl Http1Transaction for Client { | ||||
|                         keep_alive = headers::connection_keep_alive(&value); | ||||
|                     } | ||||
|                 } | ||||
|  | ||||
|                 #[cfg(feature = "ffi")] | ||||
|                 if ctx.preserve_header_case { | ||||
|                     header_case_map.append(&name, slice.slice(header.name.0..header.name.1)); | ||||
|                 } | ||||
|  | ||||
|                 headers.append(name, value); | ||||
|             } | ||||
|  | ||||
|             #[allow(unused_mut)] | ||||
|             let mut extensions = http::Extensions::default(); | ||||
|  | ||||
|             #[cfg(feature = "ffi")] | ||||
|             if ctx.preserve_header_case { | ||||
|                 extensions.insert(header_case_map); | ||||
|             } | ||||
|  | ||||
|             let head = MessageHead { | ||||
|                 version, | ||||
|                 subject: status, | ||||
|                 headers, | ||||
|                 extensions: http::Extensions::default(), | ||||
|                 extensions, | ||||
|             }; | ||||
|             if let Some((decode, is_upgrade)) = Client::decoder(&head, ctx.req_method)? { | ||||
|                 return Ok(Some(ParsedMessage { | ||||
| @@ -766,11 +784,28 @@ impl Http1Transaction for Client { | ||||
|         } | ||||
|         extend(dst, b"\r\n"); | ||||
|  | ||||
|         if msg.title_case_headers { | ||||
|             write_headers_title_case(&msg.head.headers, dst); | ||||
|         } else { | ||||
|             write_headers(&msg.head.headers, dst); | ||||
|         #[cfg(feature = "ffi")] | ||||
|         { | ||||
|             if msg.title_case_headers { | ||||
|                 write_headers_title_case(&msg.head.headers, dst); | ||||
|             } else if let Some(orig_headers) = | ||||
|                 msg.head.extensions.get::<crate::ffi::HeaderCaseMap>() | ||||
|             { | ||||
|                 write_headers_original_case(&msg.head.headers, orig_headers, dst); | ||||
|             } else { | ||||
|                 write_headers(&msg.head.headers, dst); | ||||
|             } | ||||
|         } | ||||
|  | ||||
|         #[cfg(not(feature = "ffi"))] | ||||
|         { | ||||
|             if msg.title_case_headers { | ||||
|                 write_headers_title_case(&msg.head.headers, dst); | ||||
|             } else { | ||||
|                 write_headers(&msg.head.headers, dst); | ||||
|             } | ||||
|         } | ||||
|  | ||||
|         extend(dst, b"\r\n"); | ||||
|         msg.head.headers.clear(); //TODO: remove when switching to drain() | ||||
|  | ||||
| @@ -1081,6 +1116,40 @@ fn write_headers(headers: &HeaderMap, dst: &mut Vec<u8>) { | ||||
|     } | ||||
| } | ||||
|  | ||||
| #[cfg(feature = "ffi")] | ||||
| #[cold] | ||||
| fn write_headers_original_case( | ||||
|     headers: &HeaderMap, | ||||
|     orig_case: &crate::ffi::HeaderCaseMap, | ||||
|     dst: &mut Vec<u8>, | ||||
| ) { | ||||
|     // For each header name/value pair, there may be a value in the casemap | ||||
|     // that corresponds to the HeaderValue. So, we iterator all the keys, | ||||
|     // and for each one, try to pair the originally cased name with the value. | ||||
|     // | ||||
|     // TODO: consider adding http::HeaderMap::entries() iterator | ||||
|     for name in headers.keys() { | ||||
|         let mut names = orig_case.get_all(name).iter(); | ||||
|  | ||||
|         for value in headers.get_all(name) { | ||||
|             if let Some(orig_name) = names.next() { | ||||
|                 extend(dst, orig_name); | ||||
|             } else { | ||||
|                 extend(dst, name.as_str().as_bytes()); | ||||
|             } | ||||
|  | ||||
|             // Wanted for curl test cases that send `X-Custom-Header:\r\n` | ||||
|             if value.is_empty() { | ||||
|                 extend(dst, b":\r\n"); | ||||
|             } else { | ||||
|                 extend(dst, b": "); | ||||
|                 extend(dst, value.as_bytes()); | ||||
|                 extend(dst, b"\r\n"); | ||||
|             } | ||||
|         } | ||||
|     } | ||||
| } | ||||
|  | ||||
| struct FastWrite<'a>(&'a mut Vec<u8>); | ||||
|  | ||||
| impl<'a> fmt::Write for FastWrite<'a> { | ||||
| @@ -1117,6 +1186,8 @@ mod tests { | ||||
|             ParseContext { | ||||
|                 cached_headers: &mut None, | ||||
|                 req_method: &mut method, | ||||
|                 #[cfg(feature = "ffi")] | ||||
|                 preserve_header_case: false, | ||||
|             }, | ||||
|         ) | ||||
|         .unwrap() | ||||
| @@ -1137,6 +1208,8 @@ mod tests { | ||||
|         let ctx = ParseContext { | ||||
|             cached_headers: &mut None, | ||||
|             req_method: &mut Some(crate::Method::GET), | ||||
|             #[cfg(feature = "ffi")] | ||||
|             preserve_header_case: false, | ||||
|         }; | ||||
|         let msg = Client::parse(&mut raw, ctx).unwrap().unwrap(); | ||||
|         assert_eq!(raw.len(), 0); | ||||
| @@ -1152,6 +1225,8 @@ mod tests { | ||||
|         let ctx = ParseContext { | ||||
|             cached_headers: &mut None, | ||||
|             req_method: &mut None, | ||||
|             #[cfg(feature = "ffi")] | ||||
|             preserve_header_case: false, | ||||
|         }; | ||||
|         Server::parse(&mut raw, ctx).unwrap_err(); | ||||
|     } | ||||
| @@ -1165,6 +1240,8 @@ mod tests { | ||||
|                 ParseContext { | ||||
|                     cached_headers: &mut None, | ||||
|                     req_method: &mut None, | ||||
|                     #[cfg(feature = "ffi")] | ||||
|                     preserve_header_case: false, | ||||
|                 }, | ||||
|             ) | ||||
|             .expect("parse ok") | ||||
| @@ -1178,6 +1255,8 @@ mod tests { | ||||
|                 ParseContext { | ||||
|                     cached_headers: &mut None, | ||||
|                     req_method: &mut None, | ||||
|                     #[cfg(feature = "ffi")] | ||||
|                     preserve_header_case: false, | ||||
|                 }, | ||||
|             ) | ||||
|             .expect_err(comment) | ||||
| @@ -1380,6 +1459,8 @@ mod tests { | ||||
|                 ParseContext { | ||||
|                     cached_headers: &mut None, | ||||
|                     req_method: &mut Some(Method::GET), | ||||
|                     #[cfg(feature = "ffi")] | ||||
|                     preserve_header_case: false, | ||||
|                 } | ||||
|             ) | ||||
|             .expect("parse ok") | ||||
| @@ -1393,6 +1474,8 @@ mod tests { | ||||
|                 ParseContext { | ||||
|                     cached_headers: &mut None, | ||||
|                     req_method: &mut Some(m), | ||||
|                     #[cfg(feature = "ffi")] | ||||
|                     preserve_header_case: false, | ||||
|                 }, | ||||
|             ) | ||||
|             .expect("parse ok") | ||||
| @@ -1406,6 +1489,8 @@ mod tests { | ||||
|                 ParseContext { | ||||
|                     cached_headers: &mut None, | ||||
|                     req_method: &mut Some(Method::GET), | ||||
|                     #[cfg(feature = "ffi")] | ||||
|                     preserve_header_case: false, | ||||
|                 }, | ||||
|             ) | ||||
|             .expect_err("parse should err") | ||||
| @@ -1719,6 +1804,8 @@ mod tests { | ||||
|             ParseContext { | ||||
|                 cached_headers: &mut None, | ||||
|                 req_method: &mut Some(Method::GET), | ||||
|                 #[cfg(feature = "ffi")] | ||||
|                 preserve_header_case: false, | ||||
|             }, | ||||
|         ) | ||||
|         .expect("parse ok") | ||||
| @@ -1727,6 +1814,42 @@ mod tests { | ||||
|         assert_eq!(parsed.head.headers["server"], "hello\tworld"); | ||||
|     } | ||||
|  | ||||
|     #[cfg(feature = "ffi")] | ||||
|     #[test] | ||||
|     fn test_write_headers_orig_case_empty_value() { | ||||
|         let mut headers = HeaderMap::new(); | ||||
|         let name = http::header::HeaderName::from_static("x-empty"); | ||||
|         headers.insert(&name, "".parse().expect("parse empty")); | ||||
|         let mut orig_cases = crate::ffi::HeaderCaseMap::default(); | ||||
|         orig_cases.insert(name, Bytes::from_static(b"X-EmptY")); | ||||
|  | ||||
|         let mut dst = Vec::new(); | ||||
|         super::write_headers_original_case(&headers, &orig_cases, &mut dst); | ||||
|  | ||||
|         assert_eq!( | ||||
|             dst, b"X-EmptY:\r\n", | ||||
|             "there should be no space between the colon and CRLF" | ||||
|         ); | ||||
|     } | ||||
|  | ||||
|     #[cfg(feature = "ffi")] | ||||
|     #[test] | ||||
|     fn test_write_headers_orig_case_multiple_entries() { | ||||
|         let mut headers = HeaderMap::new(); | ||||
|         let name = http::header::HeaderName::from_static("x-empty"); | ||||
|         headers.insert(&name, "a".parse().unwrap()); | ||||
|         headers.append(&name, "b".parse().unwrap()); | ||||
|  | ||||
|         let mut orig_cases = crate::ffi::HeaderCaseMap::default(); | ||||
|         orig_cases.insert(name.clone(), Bytes::from_static(b"X-Empty")); | ||||
|         orig_cases.append(name, Bytes::from_static(b"X-EMPTY")); | ||||
|  | ||||
|         let mut dst = Vec::new(); | ||||
|         super::write_headers_original_case(&headers, &orig_cases, &mut dst); | ||||
|  | ||||
|         assert_eq!(dst, b"X-Empty: a\r\nX-EMPTY: b\r\n"); | ||||
|     } | ||||
|  | ||||
|     #[cfg(feature = "nightly")] | ||||
|     use test::Bencher; | ||||
|  | ||||
| @@ -1762,6 +1885,8 @@ mod tests { | ||||
|                 ParseContext { | ||||
|                     cached_headers: &mut headers, | ||||
|                     req_method: &mut None, | ||||
|                     #[cfg(feature = "ffi")] | ||||
|                     preserve_header_case: false, | ||||
|                 }, | ||||
|             ) | ||||
|             .unwrap() | ||||
| @@ -1795,6 +1920,8 @@ mod tests { | ||||
|                 ParseContext { | ||||
|                     cached_headers: &mut headers, | ||||
|                     req_method: &mut None, | ||||
|                     #[cfg(feature = "ffi")] | ||||
|                     preserve_header_case: false, | ||||
|                 }, | ||||
|             ) | ||||
|             .unwrap() | ||||
|   | ||||
| @@ -24,7 +24,6 @@ pub struct MessageHead<S> { | ||||
|     pub subject: S, | ||||
|     /// Headers of the Incoming message. | ||||
|     pub headers: http::HeaderMap, | ||||
|  | ||||
|     /// Extensions. | ||||
|     extensions: http::Extensions, | ||||
| } | ||||
|   | ||||
		Reference in New Issue
	
	Block a user