diff --git a/src/proto/streams/streams.rs b/src/proto/streams/streams.rs index e92fea2..5f98d09 100644 --- a/src/proto/streams/streams.rs +++ b/src/proto/streams/streams.rs @@ -671,15 +671,9 @@ where }; let stream = me.store.resolve(key); - let actions = &mut me.actions; let mut send_buffer = self.send_buffer.inner.lock().unwrap(); let send_buffer = &mut *send_buffer; - - me.counts.transition(stream, |counts, stream| { - actions.send.send_reset( - reason, send_buffer, stream, counts, &mut actions.task); - actions.recv.enqueue_reset_expiration(stream, counts) - }) + me.actions.send_reset(stream, reason, &mut me.counts, send_buffer); } pub fn send_go_away(&mut self, last_processed_id: StreamId) { @@ -848,14 +842,10 @@ impl StreamRef { let me = &mut *me; let stream = me.store.resolve(self.opaque.key); - let actions = &mut me.actions; let mut send_buffer = self.send_buffer.inner.lock().unwrap(); let send_buffer = &mut *send_buffer; - me.counts.transition(stream, |counts, stream| { - actions.send.send_reset( - reason, send_buffer, stream, counts, &mut actions.task) - }) + me.actions.send_reset(stream, reason, &mut me.counts, send_buffer); } pub fn send_response( @@ -1178,6 +1168,22 @@ impl SendBuffer { // ===== impl Actions ===== impl Actions { + fn send_reset( + &mut self, + stream: store::Ptr, + reason: Reason, + counts: &mut Counts, + send_buffer: &mut Buffer>, + ) { + counts.transition(stream, |counts, stream| { + self.send.send_reset( + reason, send_buffer, stream, counts, &mut self.task); + self.recv.enqueue_reset_expiration(stream, counts); + // if a RecvStream is parked, ensure it's notified + stream.notify_recv(); + }); + } + fn reset_on_recv_stream_err( &mut self, buffer: &mut Buffer>, diff --git a/tests/h2-support/src/frames.rs b/tests/h2-support/src/frames.rs index 4a757f5..77cfa9b 100644 --- a/tests/h2-support/src/frames.rs +++ b/tests/h2-support/src/frames.rs @@ -321,6 +321,11 @@ impl Mock { let id = self.0.stream_id(); Mock(frame::Reset::new(id, frame::Reason::INTERNAL_ERROR)) } + + pub fn reason(self, reason: frame::Reason) -> Self { + let id = self.0.stream_id(); + Mock(frame::Reset::new(id, reason)) + } } impl From> for SendFrame { diff --git a/tests/h2-tests/tests/client_request.rs b/tests/h2-tests/tests/client_request.rs index 958fd6a..07627c6 100644 --- a/tests/h2-tests/tests/client_request.rs +++ b/tests/h2-tests/tests/client_request.rs @@ -360,6 +360,91 @@ fn send_request_poll_ready_when_connection_error() { h2.join(srv).wait().expect("wait"); } +#[test] +fn send_reset_notifies_recv_stream() { + let _ = ::env_logger::try_init(); + let (io, srv) = mock::new(); + + + let srv = srv.assert_client_handshake() + .unwrap() + .recv_settings() + .recv_frame( + frames::headers(1) + .request("POST", "https://example.com/") + ) + .send_frame(frames::headers(1).response(200)) + .recv_frame(frames::reset(1).refused()) + .recv_frame( + frames::headers(3) + .request("POST", "https://example.com/") + .eos() + ) + .send_frame( + frames::headers(3) + .response(200) + .eos() + ) + .close(); + + let client = client::handshake(io) + .expect("handshake") + .and_then(|(mut client, conn)| { + let request = Request::builder() + .method(Method::POST) + .uri("https://example.com/") + .body(()) + .unwrap(); + + // first request is allowed + let (resp1, tx) = client.send_request(request, false).unwrap(); + + conn.drive(resp1) + .map(move |(conn, res)| (client, conn, tx, res)) + }) + .and_then(|(client, conn, mut tx, res)| { + let tx = futures::future::poll_fn(move || { + tx.send_reset(h2::Reason::REFUSED_STREAM); + Ok(().into()) + }); + + + let rx = res + .into_body() + .for_each(|_| -> Result<(), _> { + unreachable!("no response body expected") + }); + // a FuturesUnordered is used on purpose! + // + // We don't want a join, since any of the other futures notifying + // will make the rx future polled again, but we are + // specifically testing that rx gets notified on its own. + let mut unordered = futures::stream::FuturesUnordered::>>::new(); + unordered.push(Box::new(rx.expect_err("RecvBody").then(|_| Ok(())))); + unordered.push(Box::new(tx)); + + conn.drive(unordered.for_each(|_| Ok(()))) + .map(move |(conn, _)| (client, conn)) + }) + .and_then(|(mut client, conn)| { + // send a second request just to keep the connection alive until + // we know the previous `RecvStream` was notified about the reset. + let request = Request::builder() + .method(Method::POST) + .uri("https://example.com/") + .body(()) + .unwrap(); + + let (resp2, _) = client.send_request(request, true).unwrap(); + let fut = resp2.map(|_res| ()); + + conn.drive(fut) + .and_then(|(conn, _)| conn.expect("client")) + }); + + client.join(srv).wait().expect("wait"); +} + #[test] fn http_11_request_without_scheme_or_authority() { let _ = ::env_logger::try_init();