Fix poll_capacity to wake in combination with max_send_buffer_size

This commit is contained in:
Sean McArthur
2021-12-08 17:38:41 -08:00
parent 88037ae0ab
commit a5c60b24de
3 changed files with 90 additions and 0 deletions

View File

@@ -741,6 +741,11 @@ impl Prioritize {
stream.buffered_send_data -= len as usize;
stream.requested_send_capacity -= len;
// If the capacity was limited because of the
// max_send_buffer_size, then consider waking
// the send task again...
stream.notify_if_can_buffer_more();
// Assign the capacity back to the connection that
// was just consumed from the stream in the previous
// line.

View File

@@ -279,6 +279,17 @@ impl Stream {
}
}
/// If the capacity was limited because of the max_send_buffer_size,
/// then consider waking the send task again...
pub fn notify_if_can_buffer_more(&mut self) {
// Only notify if the capacity exceeds the amount of buffered data
if self.send_flow.available() > self.buffered_send_data {
self.send_capacity_inc = true;
tracing::trace!(" notifying task");
self.notify_send();
}
}
/// Returns `Err` when the decrement cannot be completed due to overflow.
pub fn dec_content_length(&mut self, len: usize) -> Result<(), ()> {
match self.content_length {

View File

@@ -1668,3 +1668,77 @@ async fn max_send_buffer_size_overflow() {
join(srv, client).await;
}
#[tokio::test]
async fn max_send_buffer_size_poll_capacity_wakes_task() {
h2_support::trace_init!();
let (io, mut srv) = mock::new();
let srv = async move {
let settings = srv.assert_client_handshake().await;
assert_default_settings!(settings);
srv.recv_frame(frames::headers(1).request("POST", "https://www.example.com/"))
.await;
srv.send_frame(frames::headers(1).response(200).eos()).await;
srv.recv_frame(frames::data(1, &[0; 5][..])).await;
srv.recv_frame(frames::data(1, &[0; 5][..])).await;
srv.recv_frame(frames::data(1, &[0; 5][..])).await;
srv.recv_frame(frames::data(1, &[0; 5][..])).await;
srv.recv_frame(frames::data(1, &[][..]).eos()).await;
};
let client = async move {
let (mut client, mut conn) = client::Builder::new()
.max_send_buffer_size(5)
.handshake::<_, Bytes>(io)
.await
.unwrap();
let request = Request::builder()
.method(Method::POST)
.uri("https://www.example.com/")
.body(())
.unwrap();
let (response, mut stream) = client.send_request(request, false).unwrap();
let response = conn.drive(response).await.unwrap();
assert_eq!(response.status(), StatusCode::OK);
assert_eq!(stream.capacity(), 0);
const TO_SEND: usize = 20;
stream.reserve_capacity(TO_SEND);
assert_eq!(
stream.capacity(),
5,
"polled capacity not over max buffer size"
);
let t1 = tokio::spawn(async move {
let mut sent = 0;
let buf = [0; TO_SEND];
loop {
match poll_fn(|cx| stream.poll_capacity(cx)).await {
None => panic!("no cap"),
Some(Err(e)) => panic!("cap error: {:?}", e),
Some(Ok(cap)) => {
stream
.send_data(buf[sent..(sent + cap)].to_vec().into(), false)
.unwrap();
sent += cap;
if sent >= TO_SEND {
break;
}
}
}
}
stream.send_data(Bytes::new(), true).unwrap();
});
// Wait for the connection to close
conn.await.unwrap();
t1.await.unwrap();
};
join(srv, client).await;
}