diff --git a/src/proto/streams/streams.rs b/src/proto/streams/streams.rs index 8606654..e3e02c2 100644 --- a/src/proto/streams/streams.rs +++ b/src/proto/streams/streams.rs @@ -207,13 +207,16 @@ where pub fn send_request( &mut self, - request: Request<()>, + mut request: Request<()>, end_of_stream: bool, pending: Option<&OpaqueStreamRef>, ) -> Result, SendError> { use super::stream::ContentLength; use http::Method; + // Clear before taking lock, incase extensions contain a StreamRef. + request.extensions_mut().clear(); + // TODO: There is a hazard with assigning a stream ID before the // prioritize layer. If prioritization reorders new streams, this // implicitly closes the earlier stream IDs. @@ -1062,9 +1065,11 @@ impl StreamRef { pub fn send_response( &mut self, - response: Response<()>, + mut response: Response<()>, end_of_stream: bool, ) -> Result<(), UserError> { + // Clear before taking lock, incase extensions contain a StreamRef. + response.extensions_mut().clear(); let mut me = self.opaque.inner.lock().unwrap(); let me = &mut *me; @@ -1082,7 +1087,12 @@ impl StreamRef { }) } - pub fn send_push_promise(&mut self, request: Request<()>) -> Result, UserError> { + pub fn send_push_promise( + &mut self, + mut request: Request<()>, + ) -> Result, UserError> { + // Clear before taking lock, incase extensions contain a StreamRef. + request.extensions_mut().clear(); let mut me = self.opaque.inner.lock().unwrap(); let me = &mut *me; diff --git a/tests/h2-tests/tests/server.rs b/tests/h2-tests/tests/server.rs index 03ce43f..556b53c 100644 --- a/tests/h2-tests/tests/server.rs +++ b/tests/h2-tests/tests/server.rs @@ -1059,3 +1059,37 @@ async fn request_without_authority() { join(client, srv).await; } + +#[tokio::test] +async fn serve_when_request_in_response_extensions() { + h2_support::trace_init!(); + let (io, mut client) = mock::new(); + + let client = async move { + let settings = client.assert_server_handshake().await; + assert_default_settings!(settings); + client + .send_frame( + frames::headers(1) + .request("GET", "https://example.com/") + .eos(), + ) + .await; + client + .recv_frame(frames::headers(1).response(200).eos()) + .await; + }; + + let srv = async move { + let mut srv = server::handshake(io).await.expect("handshake"); + let (req, mut stream) = srv.next().await.unwrap().unwrap(); + + let mut rsp = http::Response::new(()); + rsp.extensions_mut().insert(req); + stream.send_response(rsp, true).unwrap(); + + assert!(srv.next().await.is_none()); + }; + + join(client, srv).await; +}