From 4c32daeea00b5ba6621a2ab9142c08f6ac9fe7ae Mon Sep 17 00:00:00 2001 From: Sean McArthur Date: Thu, 7 Jan 2021 17:22:12 -0800 Subject: [PATCH] refactor(ffi): Add HeaderCaseMap preserving http1 header casing --- capi/examples/client.c | 2 +- src/ffi/body.rs | 2 +- src/ffi/client.rs | 8 +- src/ffi/error.rs | 2 +- src/ffi/http_types.rs | 177 ++++++++++++++++++++++++++++++++++----- src/ffi/mod.rs | 1 + src/proto/h1/conn.rs | 16 ++++ src/proto/h1/dispatch.rs | 4 +- src/proto/h1/io.rs | 4 + src/proto/h1/mod.rs | 2 + src/proto/h1/role.rs | 137 ++++++++++++++++++++++++++++-- src/proto/mod.rs | 1 - 12 files changed, 321 insertions(+), 35 deletions(-) diff --git a/capi/examples/client.c b/capi/examples/client.c index 6ed66a46..f8f1805f 100644 --- a/capi/examples/client.c +++ b/capi/examples/client.c @@ -228,7 +228,7 @@ int main(int argc, char *argv[]) { } hyper_headers *req_headers = hyper_request_headers(req); - hyper_headers_set(req_headers, STR_ARG("host"), STR_ARG(host)); + hyper_headers_set(req_headers, STR_ARG("Host"), STR_ARG(host)); // Send it! hyper_task *send = hyper_clientconn_send(client, req); diff --git a/src/ffi/body.rs b/src/ffi/body.rs index 1c8f1a48..14013fc3 100644 --- a/src/ffi/body.rs +++ b/src/ffi/body.rs @@ -24,7 +24,7 @@ pub(crate) struct UserBody { type hyper_body_foreach_callback = extern "C" fn(*mut c_void, *const hyper_buf) -> c_int; type hyper_body_data_callback = - extern "C" fn(*mut c_void, *mut hyper_context, *mut *mut hyper_buf) -> c_int; + extern "C" fn(*mut c_void, *mut hyper_context<'_>, *mut *mut hyper_buf) -> c_int; ffi_fn! { /// Create a new "empty" body. diff --git a/src/ffi/client.rs b/src/ffi/client.rs index 2c2ef6b2..def46441 100644 --- a/src/ffi/client.rs +++ b/src/ffi/client.rs @@ -67,11 +67,15 @@ ffi_fn! { return std::ptr::null_mut(); } - let req = unsafe { Box::from_raw(req) }; + let mut req = unsafe { Box::from_raw(req) }; + + // Update request with original-case map of headers + req.finalize_request(); + let fut = unsafe { &mut *conn }.tx.send_request(req.0); let fut = async move { - fut.await.map(hyper_response) + fut.await.map(hyper_response::wrap) }; Box::into_raw(Task::boxed(fut)) diff --git a/src/ffi/error.rs b/src/ffi/error.rs index 8cd672fe..5dfca54e 100644 --- a/src/ffi/error.rs +++ b/src/ffi/error.rs @@ -33,7 +33,7 @@ impl hyper_error { ErrorKind::IncompleteMessage => hyper_code::HYPERE_UNEXPECTED_EOF, ErrorKind::User(User::AbortedByCallback) => hyper_code::HYPERE_ABORTED_BY_CALLBACK, // TODO: add more variants - _ => hyper_code::HYPERE_ERROR + _ => hyper_code::HYPERE_ERROR, } } diff --git a/src/ffi/http_types.rs b/src/ffi/http_types.rs index 49e2027c..fdf645ca 100644 --- a/src/ffi/http_types.rs +++ b/src/ffi/http_types.rs @@ -1,3 +1,4 @@ +use bytes::Bytes; use libc::{c_int, size_t}; use std::ffi::c_void; @@ -8,13 +9,21 @@ use super::HYPER_ITER_CONTINUE; use crate::header::{HeaderName, HeaderValue}; use crate::{Body, HeaderMap, Method, Request, Response, Uri}; -// ===== impl Request ===== - pub struct hyper_request(pub(super) Request); pub struct hyper_response(pub(super) Response); -pub struct hyper_headers(pub(super) HeaderMap); +#[derive(Default)] +pub struct hyper_headers { + pub(super) headers: HeaderMap, + orig_casing: HeaderCaseMap, +} + +// Will probably be moved to `hyper::ext::http1` +#[derive(Debug, Default)] +pub(crate) struct HeaderCaseMap(HeaderMap); + +// ===== impl hyper_request ===== ffi_fn! { /// Construct a new HTTP request. @@ -96,7 +105,7 @@ ffi_fn! { /// This is not an owned reference, so it should not be accessed after the /// `hyper_request` has been consumed. fn hyper_request_headers(req: *mut hyper_request) -> *mut hyper_headers { - hyper_headers::wrap(unsafe { &mut *req }.0.headers_mut()) + hyper_headers::get_or_default(unsafe { &mut *req }.0.extensions_mut()) } } @@ -114,7 +123,16 @@ ffi_fn! { } } -// ===== impl Response ===== +impl hyper_request { + pub(super) fn finalize_request(&mut self) { + if let Some(headers) = self.0.extensions_mut().remove::() { + *self.0.headers_mut() = headers.headers; + self.0.extensions_mut().insert(headers.orig_casing); + } + } +} + +// ===== impl hyper_response ===== ffi_fn! { /// Free an HTTP response after using it. @@ -159,7 +177,7 @@ ffi_fn! { /// This is not an owned reference, so it should not be accessed after the /// `hyper_response` has been freed. fn hyper_response_headers(resp: *mut hyper_response) -> *mut hyper_headers { - hyper_headers::wrap(unsafe { &mut *resp }.0.headers_mut()) + hyper_headers::get_or_default(unsafe { &mut *resp }.0.extensions_mut()) } } @@ -173,6 +191,22 @@ ffi_fn! { } } +impl hyper_response { + pub(super) fn wrap(mut resp: Response) -> hyper_response { + let headers = std::mem::take(resp.headers_mut()); + let orig_casing = resp + .extensions_mut() + .remove::() + .unwrap_or_default(); + resp.extensions_mut().insert(hyper_headers { + headers, + orig_casing, + }); + + hyper_response(resp) + } +} + unsafe impl AsTaskType for hyper_response { fn as_task_type(&self) -> hyper_task_return_type { hyper_task_return_type::HYPER_TASK_RESPONSE @@ -185,9 +219,15 @@ type hyper_headers_foreach_callback = extern "C" fn(*mut c_void, *const u8, size_t, *const u8, size_t) -> c_int; impl hyper_headers { - pub(crate) fn wrap(cx: &mut HeaderMap) -> &mut hyper_headers { - // A struct with only one field has the same layout as that field. - unsafe { std::mem::transmute::<&mut HeaderMap, &mut hyper_headers>(cx) } + pub(super) fn get_or_default(ext: &mut http::Extensions) -> &mut hyper_headers { + if let None = ext.get_mut::() { + ext.insert(hyper_headers { + headers: Default::default(), + orig_casing: Default::default(), + }); + } + + ext.get_mut::().unwrap() } } @@ -199,14 +239,31 @@ ffi_fn! { /// The callback should return `HYPER_ITER_CONTINUE` to keep iterating, or /// `HYPER_ITER_BREAK` to stop. fn hyper_headers_foreach(headers: *const hyper_headers, func: hyper_headers_foreach_callback, userdata: *mut c_void) { - for (name, value) in unsafe { &*headers }.0.iter() { - let name_ptr = name.as_str().as_bytes().as_ptr(); - let name_len = name.as_str().as_bytes().len(); - let val_ptr = value.as_bytes().as_ptr(); - let val_len = value.as_bytes().len(); + let headers = unsafe { &*headers }; + // For each header name/value pair, there may be a value in the casemap + // that corresponds to the HeaderValue. So, we iterator all the keys, + // and for each one, try to pair the originally cased name with the value. + // + // TODO: consider adding http::HeaderMap::entries() iterator + for name in headers.headers.keys() { + let mut names = headers.orig_casing.get_all(name).iter(); - if HYPER_ITER_CONTINUE != func(userdata, name_ptr, name_len, val_ptr, val_len) { - break; + for value in headers.headers.get_all(name) { + let (name_ptr, name_len) = if let Some(orig_name) = names.next() { + (orig_name.as_ptr(), orig_name.len()) + } else { + ( + name.as_str().as_bytes().as_ptr(), + name.as_str().as_bytes().len(), + ) + }; + + let val_ptr = value.as_bytes().as_ptr(); + let val_len = value.as_bytes().len(); + + if HYPER_ITER_CONTINUE != func(userdata, name_ptr, name_len, val_ptr, val_len) { + return; + } } } } @@ -219,8 +276,9 @@ ffi_fn! { fn hyper_headers_set(headers: *mut hyper_headers, name: *const u8, name_len: size_t, value: *const u8, value_len: size_t) -> hyper_code { let headers = unsafe { &mut *headers }; match unsafe { raw_name_value(name, name_len, value, value_len) } { - Ok((name, value)) => { - headers.0.insert(name, value); + Ok((name, value, orig_name)) => { + headers.headers.insert(&name, value); + headers.orig_casing.insert(name, orig_name); hyper_code::HYPERE_OK } Err(code) => code, @@ -237,8 +295,9 @@ ffi_fn! { let headers = unsafe { &mut *headers }; match unsafe { raw_name_value(name, name_len, value, value_len) } { - Ok((name, value)) => { - headers.0.append(name, value); + Ok((name, value, orig_name)) => { + headers.headers.append(&name, value); + headers.orig_casing.append(name, orig_name); hyper_code::HYPERE_OK } Err(code) => code, @@ -251,8 +310,9 @@ unsafe fn raw_name_value( name_len: size_t, value: *const u8, value_len: size_t, -) -> Result<(HeaderName, HeaderValue), hyper_code> { +) -> Result<(HeaderName, HeaderValue, Bytes), hyper_code> { let name = std::slice::from_raw_parts(name, name_len); + let orig_name = Bytes::copy_from_slice(name); let name = match HeaderName::from_bytes(name) { Ok(name) => name, Err(_) => return Err(hyper_code::HYPERE_INVALID_ARG), @@ -263,5 +323,78 @@ unsafe fn raw_name_value( Err(_) => return Err(hyper_code::HYPERE_INVALID_ARG), }; - Ok((name, value)) + Ok((name, value, orig_name)) +} + +// ===== impl HeaderCaseMap ===== + +impl HeaderCaseMap { + pub(crate) fn get_all(&self, name: &HeaderName) -> http::header::GetAll<'_, Bytes> { + self.0.get_all(name) + } + + pub(crate) fn insert(&mut self, name: HeaderName, orig: Bytes) { + self.0.insert(name, orig); + } + + pub(crate) fn append(&mut self, name: N, orig: Bytes) + where + N: http::header::IntoHeaderName, + { + self.0.append(name, orig); + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_headers_foreach_cases_preserved() { + let mut headers = hyper_headers::default(); + + let name1 = b"Set-CookiE"; + let value1 = b"a=b"; + hyper_headers_add( + &mut headers, + name1.as_ptr(), + name1.len(), + value1.as_ptr(), + value1.len(), + ); + + let name2 = b"SET-COOKIE"; + let value2 = b"c=d"; + hyper_headers_add( + &mut headers, + name2.as_ptr(), + name2.len(), + value2.as_ptr(), + value2.len(), + ); + + let mut vec = Vec::::new(); + hyper_headers_foreach(&headers, concat, &mut vec as *mut _ as *mut c_void); + + assert_eq!(vec, b"Set-CookiE: a=b\r\nSET-COOKIE: c=d\r\n"); + + extern "C" fn concat( + vec: *mut c_void, + name: *const u8, + name_len: usize, + value: *const u8, + value_len: usize, + ) -> c_int { + unsafe { + let vec = &mut *(vec as *mut Vec); + let name = std::slice::from_raw_parts(name, name_len); + let value = std::slice::from_raw_parts(value, value_len); + vec.extend(name); + vec.extend(b": "); + vec.extend(value); + vec.extend(b"\r\n"); + } + HYPER_ITER_CONTINUE + } + } } diff --git a/src/ffi/mod.rs b/src/ffi/mod.rs index cee653d7..ffa9d6b1 100644 --- a/src/ffi/mod.rs +++ b/src/ffi/mod.rs @@ -28,6 +28,7 @@ mod io; mod task; pub(crate) use self::body::UserBody; +pub(crate) use self::http_types::HeaderCaseMap; pub const HYPER_ITER_CONTINUE: libc::c_int = 0; #[allow(unused)] diff --git a/src/proto/h1/conn.rs b/src/proto/h1/conn.rs index 3226aaf8..9866e133 100644 --- a/src/proto/h1/conn.rs +++ b/src/proto/h1/conn.rs @@ -44,6 +44,8 @@ where error: None, keep_alive: KA::Busy, method: None, + #[cfg(feature = "ffi")] + preserve_header_case: false, title_case_headers: false, notify_read: false, reading: Reading::Init, @@ -142,6 +144,8 @@ where ParseContext { cached_headers: &mut self.state.cached_headers, req_method: &mut self.state.method, + #[cfg(feature = "ffi")] + preserve_header_case: self.state.preserve_header_case, } )) { Ok(msg) => msg, @@ -474,6 +478,16 @@ where self.enforce_version(&mut head); + // Maybe check if we should preserve header casing on received + // message headers... + #[cfg(feature = "ffi")] + { + if T::is_client() && !self.state.preserve_header_case { + self.state.preserve_header_case = + head.extensions.get::().is_some(); + } + } + let buf = self.io.headers_buf(); match super::role::encode_headers::( Encode { @@ -736,6 +750,8 @@ struct State { /// This is used to know things such as if the message can include /// a body or not. method: Option, + #[cfg(feature = "ffi")] + preserve_header_case: bool, title_case_headers: bool, /// Set to true when the Dispatcher should poll read operations /// again. See the `maybe_notify` method for more. diff --git a/src/proto/h1/dispatch.rs b/src/proto/h1/dispatch.rs index ab8616fe..8bbb0333 100644 --- a/src/proto/h1/dispatch.rs +++ b/src/proto/h1/dispatch.rs @@ -492,7 +492,7 @@ cfg_server! { version: parts.version, subject: parts.status, headers: parts.headers, - extensions: http::Extensions::default(), + extensions: parts.extensions, }; Poll::Ready(Some(Ok((head, body)))) } else { @@ -576,7 +576,7 @@ cfg_client! { version: parts.version, subject: crate::proto::RequestLine(parts.method, parts.uri), headers: parts.headers, - extensions: http::Extensions::default(), + extensions: parts.extensions, }; *this.callback = Some(cb); Poll::Ready(Some(Ok((head, body)))) diff --git a/src/proto/h1/io.rs b/src/proto/h1/io.rs index 85e4c016..da0ff820 100644 --- a/src/proto/h1/io.rs +++ b/src/proto/h1/io.rs @@ -159,6 +159,8 @@ where ParseContext { cached_headers: parse_ctx.cached_headers, req_method: parse_ctx.req_method, + #[cfg(feature = "ffi")] + preserve_header_case: parse_ctx.preserve_header_case, }, )? { Some(msg) => { @@ -636,6 +638,8 @@ mod tests { let parse_ctx = ParseContext { cached_headers: &mut None, req_method: &mut None, + #[cfg(feature = "ffi")] + preserve_header_case: false, }; assert!(buffered .parse::(cx, parse_ctx) diff --git a/src/proto/h1/mod.rs b/src/proto/h1/mod.rs index 4e1b1685..10aa0962 100644 --- a/src/proto/h1/mod.rs +++ b/src/proto/h1/mod.rs @@ -70,6 +70,8 @@ pub(crate) struct ParsedMessage { pub(crate) struct ParseContext<'a> { cached_headers: &'a mut Option, req_method: &'a mut Option, + #[cfg(feature = "ffi")] + preserve_header_case: bool, } /// Passed to Http1Transaction::encode diff --git a/src/proto/h1/role.rs b/src/proto/h1/role.rs index 2a3b1fdd..95015bff 100644 --- a/src/proto/h1/role.rs +++ b/src/proto/h1/role.rs @@ -148,6 +148,7 @@ impl Http1Transaction for Server { is_http_11 = false; Version::HTTP_10 }; + trace!("headers: {:?}", &req.headers); record_header_indices(bytes, &req.headers, &mut headers_indices)?; headers_len = req.headers.len(); @@ -692,6 +693,9 @@ impl Http1Transaction for Client { let mut keep_alive = version == Version::HTTP_11; + #[cfg(feature = "ffi")] + let mut header_case_map = crate::ffi::HeaderCaseMap::default(); + headers.reserve(headers_len); for header in &headers_indices[..headers_len] { let name = header_name!(&slice[header.name.0..header.name.1]); @@ -707,14 +711,28 @@ impl Http1Transaction for Client { keep_alive = headers::connection_keep_alive(&value); } } + + #[cfg(feature = "ffi")] + if ctx.preserve_header_case { + header_case_map.append(&name, slice.slice(header.name.0..header.name.1)); + } + headers.append(name, value); } + #[allow(unused_mut)] + let mut extensions = http::Extensions::default(); + + #[cfg(feature = "ffi")] + if ctx.preserve_header_case { + extensions.insert(header_case_map); + } + let head = MessageHead { version, subject: status, headers, - extensions: http::Extensions::default(), + extensions, }; if let Some((decode, is_upgrade)) = Client::decoder(&head, ctx.req_method)? { return Ok(Some(ParsedMessage { @@ -766,11 +784,28 @@ impl Http1Transaction for Client { } extend(dst, b"\r\n"); - if msg.title_case_headers { - write_headers_title_case(&msg.head.headers, dst); - } else { - write_headers(&msg.head.headers, dst); + #[cfg(feature = "ffi")] + { + if msg.title_case_headers { + write_headers_title_case(&msg.head.headers, dst); + } else if let Some(orig_headers) = + msg.head.extensions.get::() + { + write_headers_original_case(&msg.head.headers, orig_headers, dst); + } else { + write_headers(&msg.head.headers, dst); + } } + + #[cfg(not(feature = "ffi"))] + { + if msg.title_case_headers { + write_headers_title_case(&msg.head.headers, dst); + } else { + write_headers(&msg.head.headers, dst); + } + } + extend(dst, b"\r\n"); msg.head.headers.clear(); //TODO: remove when switching to drain() @@ -1081,6 +1116,40 @@ fn write_headers(headers: &HeaderMap, dst: &mut Vec) { } } +#[cfg(feature = "ffi")] +#[cold] +fn write_headers_original_case( + headers: &HeaderMap, + orig_case: &crate::ffi::HeaderCaseMap, + dst: &mut Vec, +) { + // For each header name/value pair, there may be a value in the casemap + // that corresponds to the HeaderValue. So, we iterator all the keys, + // and for each one, try to pair the originally cased name with the value. + // + // TODO: consider adding http::HeaderMap::entries() iterator + for name in headers.keys() { + let mut names = orig_case.get_all(name).iter(); + + for value in headers.get_all(name) { + if let Some(orig_name) = names.next() { + extend(dst, orig_name); + } else { + extend(dst, name.as_str().as_bytes()); + } + + // Wanted for curl test cases that send `X-Custom-Header:\r\n` + if value.is_empty() { + extend(dst, b":\r\n"); + } else { + extend(dst, b": "); + extend(dst, value.as_bytes()); + extend(dst, b"\r\n"); + } + } + } +} + struct FastWrite<'a>(&'a mut Vec); impl<'a> fmt::Write for FastWrite<'a> { @@ -1117,6 +1186,8 @@ mod tests { ParseContext { cached_headers: &mut None, req_method: &mut method, + #[cfg(feature = "ffi")] + preserve_header_case: false, }, ) .unwrap() @@ -1137,6 +1208,8 @@ mod tests { let ctx = ParseContext { cached_headers: &mut None, req_method: &mut Some(crate::Method::GET), + #[cfg(feature = "ffi")] + preserve_header_case: false, }; let msg = Client::parse(&mut raw, ctx).unwrap().unwrap(); assert_eq!(raw.len(), 0); @@ -1152,6 +1225,8 @@ mod tests { let ctx = ParseContext { cached_headers: &mut None, req_method: &mut None, + #[cfg(feature = "ffi")] + preserve_header_case: false, }; Server::parse(&mut raw, ctx).unwrap_err(); } @@ -1165,6 +1240,8 @@ mod tests { ParseContext { cached_headers: &mut None, req_method: &mut None, + #[cfg(feature = "ffi")] + preserve_header_case: false, }, ) .expect("parse ok") @@ -1178,6 +1255,8 @@ mod tests { ParseContext { cached_headers: &mut None, req_method: &mut None, + #[cfg(feature = "ffi")] + preserve_header_case: false, }, ) .expect_err(comment) @@ -1380,6 +1459,8 @@ mod tests { ParseContext { cached_headers: &mut None, req_method: &mut Some(Method::GET), + #[cfg(feature = "ffi")] + preserve_header_case: false, } ) .expect("parse ok") @@ -1393,6 +1474,8 @@ mod tests { ParseContext { cached_headers: &mut None, req_method: &mut Some(m), + #[cfg(feature = "ffi")] + preserve_header_case: false, }, ) .expect("parse ok") @@ -1406,6 +1489,8 @@ mod tests { ParseContext { cached_headers: &mut None, req_method: &mut Some(Method::GET), + #[cfg(feature = "ffi")] + preserve_header_case: false, }, ) .expect_err("parse should err") @@ -1719,6 +1804,8 @@ mod tests { ParseContext { cached_headers: &mut None, req_method: &mut Some(Method::GET), + #[cfg(feature = "ffi")] + preserve_header_case: false, }, ) .expect("parse ok") @@ -1727,6 +1814,42 @@ mod tests { assert_eq!(parsed.head.headers["server"], "hello\tworld"); } + #[cfg(feature = "ffi")] + #[test] + fn test_write_headers_orig_case_empty_value() { + let mut headers = HeaderMap::new(); + let name = http::header::HeaderName::from_static("x-empty"); + headers.insert(&name, "".parse().expect("parse empty")); + let mut orig_cases = crate::ffi::HeaderCaseMap::default(); + orig_cases.insert(name, Bytes::from_static(b"X-EmptY")); + + let mut dst = Vec::new(); + super::write_headers_original_case(&headers, &orig_cases, &mut dst); + + assert_eq!( + dst, b"X-EmptY:\r\n", + "there should be no space between the colon and CRLF" + ); + } + + #[cfg(feature = "ffi")] + #[test] + fn test_write_headers_orig_case_multiple_entries() { + let mut headers = HeaderMap::new(); + let name = http::header::HeaderName::from_static("x-empty"); + headers.insert(&name, "a".parse().unwrap()); + headers.append(&name, "b".parse().unwrap()); + + let mut orig_cases = crate::ffi::HeaderCaseMap::default(); + orig_cases.insert(name.clone(), Bytes::from_static(b"X-Empty")); + orig_cases.append(name, Bytes::from_static(b"X-EMPTY")); + + let mut dst = Vec::new(); + super::write_headers_original_case(&headers, &orig_cases, &mut dst); + + assert_eq!(dst, b"X-Empty: a\r\nX-EMPTY: b\r\n"); + } + #[cfg(feature = "nightly")] use test::Bencher; @@ -1762,6 +1885,8 @@ mod tests { ParseContext { cached_headers: &mut headers, req_method: &mut None, + #[cfg(feature = "ffi")] + preserve_header_case: false, }, ) .unwrap() @@ -1795,6 +1920,8 @@ mod tests { ParseContext { cached_headers: &mut headers, req_method: &mut None, + #[cfg(feature = "ffi")] + preserve_header_case: false, }, ) .unwrap() diff --git a/src/proto/mod.rs b/src/proto/mod.rs index 27b3ef6f..fe2e2e92 100644 --- a/src/proto/mod.rs +++ b/src/proto/mod.rs @@ -24,7 +24,6 @@ pub struct MessageHead { pub subject: S, /// Headers of the Incoming message. pub headers: http::HeaderMap, - /// Extensions. extensions: http::Extensions, }