diff --git a/src/proto/h1/role.rs b/src/proto/h1/role.rs index 6252207b..983ef76e 100644 --- a/src/proto/h1/role.rs +++ b/src/proto/h1/role.rs @@ -990,14 +990,11 @@ impl Http1Transaction for Client { .h1_parser_config .obsolete_multiline_headers_in_responses_are_allowed() { - for header in &headers_indices[..headers_len] { + for header in &mut headers_indices[..headers_len] { // SAFETY: array is valid up to `headers_len` - let header = unsafe { &*header.as_ptr() }; - for b in &mut slice[header.value.0..header.value.1] { - if *b == b'\r' || *b == b'\n' { - *b = b' '; - } - } + let header = unsafe { &mut *header.as_mut_ptr() }; + Client::obs_fold_line(&mut slice, header); + } } @@ -1344,6 +1341,65 @@ impl Client { set_content_length(headers, len) } + + fn obs_fold_line(all: &mut [u8], idx: &mut HeaderIndices) { + // If the value has obs-folded text, then in-place shift the bytes out + // of here. + // + // https://httpwg.org/specs/rfc9112.html#line.folding + // + // > A user agent that receives an obs-fold MUST replace each received + // > obs-fold with one or more SP octets prior to interpreting the + // > field value. + // + // This means strings like "\r\n\t foo" must replace the "\r\n\t " with + // a single space. + + let buf = &mut all[idx.value.0..idx.value.1]; + + // look for a newline, otherwise bail out + let first_nl = match buf.iter().position(|b| *b == b'\n') { + Some(i) => i, + None => return, + }; + + // not on standard slices because whatever, sigh + fn trim_start(mut s: &[u8]) -> &[u8] { + while let [first, rest @ ..] = s { + if first.is_ascii_whitespace() { + s = rest; + } else { + break; + } + } + s + } + + fn trim_end(mut s: &[u8]) -> &[u8] { + while let [rest @ .., last] = s { + if last.is_ascii_whitespace() { + s = rest; + } else { + break; + } + } + s + } + + fn trim(s: &[u8]) -> &[u8] { + trim_start(trim_end(s)) + } + + // TODO(perf): we could do the moves in-place, but this is so uncommon + // that it shouldn't matter. + let mut unfolded = trim_end(&buf[..first_nl]).to_vec(); + for line in buf[first_nl + 1..].split(|b| *b == b'\n') { + unfolded.push(b' '); + unfolded.extend_from_slice(trim(line)); + } + buf[..unfolded.len()].copy_from_slice(&unfolded); + idx.value.1 = idx.value.0 + unfolded.len(); + } } fn set_content_length(headers: &mut HeaderMap, len: u64) -> Encoder { @@ -2384,6 +2440,30 @@ mod tests { ); } + #[cfg(feature = "client")] + #[test] + fn test_client_obs_fold_line() { + fn unfold(src: &str) -> String { + let mut buf = src.as_bytes().to_vec(); + let mut idx = HeaderIndices { + name: (0, 0), + value: (0, buf.len()), + }; + Client::obs_fold_line(&mut buf, &mut idx); + String::from_utf8(buf[idx.value.0 .. idx.value.1].to_vec()).unwrap() + } + + assert_eq!( + unfold("a normal line"), + "a normal line", + ); + + assert_eq!( + unfold("obs\r\n fold\r\n\t line"), + "obs fold line", + ); + } + #[test] fn test_client_request_encode_title_case() { use crate::proto::BodyLength; diff --git a/tests/client.rs b/tests/client.rs index 34b34907..2f9b3ab8 100644 --- a/tests/client.rs +++ b/tests/client.rs @@ -1124,6 +1124,38 @@ test! { body: &b"Mmmmh, baguettes."[..], } +test! { + name: client_obs_fold_headers, + + server: + expected: "\ + GET / HTTP/1.1\r\n\ + host: {addr}\r\n\ + \r\n\ + ", + reply: "\ + HTTP/1.1 200 OK\r\n\ + Content-Length: 0\r\n\ + Fold: just\r\n some\r\n\t folding\r\n\ + \r\n\ + ", + + client: + options: { + http1_allow_obsolete_multiline_headers_in_responses: true, + }, + request: { + method: GET, + url: "http://{addr}/", + }, + response: + status: OK, + headers: { + "fold" => "just some folding", + }, + body: None, +} + mod dispatch_impl { use super::*; use std::io::{self, Read, Write}; @@ -2232,63 +2264,6 @@ mod conn { future::join(server, client).await; } - #[tokio::test] - async fn get_obsolete_line_folding() { - let _ = ::pretty_env_logger::try_init(); - let listener = TkTcpListener::bind(SocketAddr::from(([127, 0, 0, 1], 0))) - .await - .unwrap(); - let addr = listener.local_addr().unwrap(); - - let server = async move { - let mut sock = listener.accept().await.unwrap().0; - let mut buf = [0; 4096]; - let n = sock.read(&mut buf).await.expect("read 1"); - - // Notably: - // - Just a path, since just a path was set - // - No host, since no host was set - let expected = "GET /a HTTP/1.1\r\n\r\n"; - assert_eq!(s(&buf[..n]), expected); - - sock.write_all(b"HTTP/1.1 200 OK\r\nContent-Length: \r\n 0\r\nLine-Folded-Header: hello\r\n world \r\n \r\n\r\n") - .await - .unwrap(); - }; - - let client = async move { - let tcp = tcp_connect(&addr).await.expect("connect"); - let (mut client, conn) = conn::Builder::new() - .http1_allow_obsolete_multiline_headers_in_responses(true) - .handshake::<_, Body>(tcp) - .await - .expect("handshake"); - - tokio::task::spawn(async move { - conn.await.expect("http conn"); - }); - - let req = Request::builder() - .uri("/a") - .body(Default::default()) - .unwrap(); - let mut res = client.send_request(req).await.expect("send_request"); - assert_eq!(res.status(), hyper::StatusCode::OK); - assert_eq!(res.headers().len(), 2); - assert_eq!( - res.headers().get(http::header::CONTENT_LENGTH).unwrap(), - "0" - ); - assert_eq!( - res.headers().get("line-folded-header").unwrap(), - "hello world" - ); - assert!(res.body_mut().data().await.is_none()); - }; - - future::join(server, client).await; - } - #[tokio::test] async fn get_custom_reason_phrase() { let _ = ::pretty_env_logger::try_init();