From 0c59957d8842cef1a5db4cd4e142ca6a7bd881f1 Mon Sep 17 00:00:00 2001 From: Darren Tsung Date: Thu, 15 Feb 2018 13:14:18 -0800 Subject: [PATCH] When Streams are dropped, close Connection (#221) (#222) When all Streams are dropped / finished, the Connection was held open until the peer hangs up. Instead, the Connection should hang up once it knows that nothing more will be sent. To fix this, we notify the Connection when a stream is no longer referenced. On the Connection poll(), we check that there are no active, held, reset streams or any references to the Streams and transition to sending a GOAWAY if that is case. The specific behavior depends on if running as a client or server. --- src/client.rs | 1 + src/proto/connection.rs | 33 ++++++++++++----- src/proto/peer.rs | 2 + src/proto/streams/counts.rs | 4 ++ src/proto/streams/recv.rs | 1 + src/proto/streams/streams.rs | 31 ++++++++++++++-- src/server.rs | 10 ++++- tests/client_request.rs | 16 ++++++-- tests/codec_read.rs | 2 +- tests/codec_write.rs | 2 + tests/flow_control.rs | 58 ++++++++++++++++------------- tests/ping_pong.rs | 5 ++- tests/server.rs | 72 ++++++++++++++++++++++++++++++++++++ tests/stream_states.rs | 1 + 14 files changed, 193 insertions(+), 45 deletions(-) diff --git a/src/client.rs b/src/client.rs index 97539fa..8c64fbf 100644 --- a/src/client.rs +++ b/src/client.rs @@ -1162,6 +1162,7 @@ where type Error = ::Error; fn poll(&mut self) -> Poll<(), ::Error> { + self.inner.maybe_close_connection_if_no_streams(); self.inner.poll().map_err(Into::into) } } diff --git a/src/proto/connection.rs b/src/proto/connection.rs index 08a58f3..7f23794 100644 --- a/src/proto/connection.rs +++ b/src/proto/connection.rs @@ -122,6 +122,27 @@ where Ok(().into()) } + fn transition_to_go_away(&mut self, id: StreamId, e: Reason) { + let goaway = frame::GoAway::new(id, e); + self.state = State::GoAway(goaway); + } + + /// Closes the connection by transitioning to a GOAWAY state + /// iff there are no streams or references + pub fn maybe_close_connection_if_no_streams(&mut self) { + // If we poll() and realize that there are no streams or references + // then we can close the connection by transitioning to GOAWAY + if self.streams.num_active_streams() == 0 && !self.streams.has_streams_or_other_references() { + self.close_connection(); + } + } + + /// Closes the connection by transitioning to a GOAWAY state + pub fn close_connection(&mut self) { + let last_processed_id = self.streams.last_processed_id(); + self.transition_to_go_away(last_processed_id, Reason::NO_ERROR); + } + /// Advances the internal state of the connection. pub fn poll(&mut self) -> Poll<(), proto::Error> { use codec::RecvError::*; @@ -143,9 +164,8 @@ where if self.error.is_some() { if self.streams.num_active_streams() == 0 { - let id = self.streams.last_processed_id(); - let goaway = frame::GoAway::new(id, Reason::NO_ERROR); - self.state = State::GoAway(goaway); + let last_processed_id = self.streams.last_processed_id(); + self.transition_to_go_away(last_processed_id, Reason::NO_ERROR); continue; } } @@ -160,12 +180,7 @@ where // Reset all active streams let last_processed_id = self.streams.recv_err(&e.into()); - - // Create the GO_AWAY frame with the last_processed_id - let frame = frame::GoAway::new(last_processed_id, e); - - // Transition to the going away state. - self.state = State::GoAway(frame); + self.transition_to_go_away(last_processed_id, e); }, // Attempting to read a frame resulted in a stream level error. // This is handled by resetting the frame then trying to read diff --git a/src/proto/peer.rs b/src/proto/peer.rs index 1603e52..8325b88 100644 --- a/src/proto/peer.rs +++ b/src/proto/peer.rs @@ -63,6 +63,7 @@ impl Dyn { /// Returns true if the remote peer can initiate a stream with the given ID. pub fn ensure_can_open(&self, id: StreamId) -> Result<(), RecvError> { if !self.is_server() { + trace!("Cannot open stream {:?} - not server, PROTOCOL_ERROR", id); // Remote is a server and cannot open streams. PushPromise is // registered by reserving, so does not go through this path. return Err(RecvError::Connection(Reason::PROTOCOL_ERROR)); @@ -70,6 +71,7 @@ impl Dyn { // Ensure that the ID is a valid server initiated ID if !id.is_client_initiated() { + trace!("Cannot open stream {:?} - not client initiated, PROTOCOL_ERROR", id); return Err(RecvError::Connection(Reason::PROTOCOL_ERROR)); } diff --git a/src/proto/streams/counts.rs b/src/proto/streams/counts.rs index 6aed6d2..0e82f15 100644 --- a/src/proto/streams/counts.rs +++ b/src/proto/streams/counts.rs @@ -46,6 +46,10 @@ impl Counts { self.peer } + pub fn has_streams(&self) -> bool { + self.num_send_streams != 0 || self.num_recv_streams != 0 + } + /// Returns true if the receive stream concurrency can be incremented pub fn can_inc_num_recv_streams(&self) -> bool { self.max_recv_streams > self.num_recv_streams diff --git a/src/proto/streams/recv.rs b/src/proto/streams/recv.rs index 28a263e..775598f 100644 --- a/src/proto/streams/recv.rs +++ b/src/proto/streams/recv.rs @@ -119,6 +119,7 @@ impl Recv { let next_id = self.next_stream_id()?; if id < next_id { + trace!("id ({:?}) < next_id ({:?}), PROTOCOL_ERROR", id, next_id); return Err(RecvError::Connection(Reason::PROTOCOL_ERROR)); } diff --git a/src/proto/streams/streams.rs b/src/proto/streams/streams.rs index eb97f63..eeac656 100644 --- a/src/proto/streams/streams.rs +++ b/src/proto/streams/streams.rs @@ -1,10 +1,10 @@ -use super::{Buffer, Config, Counts, Prioritized, Recv, Send, Stream, StreamId}; -use super::recv::RecvHeaderBlockError; -use super::store::{self, Entry, Resolve, Store}; use {client, proto, server}; use codec::{Codec, RecvError, SendError, UserError}; use frame::{self, Frame, Reason}; use proto::{peer, Peer, WindowSize}; +use super::{Buffer, Config, Counts, Prioritized, Recv, Send, Stream, StreamId}; +use super::recv::RecvHeaderBlockError; +use super::store::{self, Entry, Resolve, Store}; use bytes::{Buf, Bytes}; use futures::{task, Async, Poll}; @@ -520,8 +520,8 @@ where end_of_stream: bool, pending: Option<&store::Key>, ) -> Result, SendError> { - use super::stream::ContentLength; use http::Method; + use super::stream::ContentLength; // TODO: There is a hazard with assigning a stream ID before the // prioritize layer. If prioritization reorders new streams, this @@ -661,6 +661,19 @@ where me.store.num_active_streams() } + pub fn has_streams_or_other_references(&self) -> bool { + if Arc::strong_count(&self.inner) > 1 { + return true; + } + + if Arc::strong_count(&self.send_buffer) > 1 { + return true; + } + + let me = self.inner.lock().unwrap(); + me.counts.has_streams() + } + #[cfg(feature = "unstable")] pub fn num_wired_streams(&self) -> usize { let me = self.inner.lock().unwrap(); @@ -951,6 +964,16 @@ fn drop_stream_ref(inner: &Mutex, key: store::Key) { let actions = &mut me.actions; + // If the stream is not referenced and it is already + // closed (does not have to go through logic below + // of canceling the stream), we should notify the task + // (connection) so that it can close properly + if stream.ref_count == 0 && stream.is_closed() { + if let Some(task) = actions.task.take() { + task.notify(); + } + } + me.counts.transition(stream, |counts, stream| { maybe_cancel(stream, actions, counts); diff --git a/src/server.rs b/src/server.rs index 7a7752c..3a4c716 100644 --- a/src/server.rs +++ b/src/server.rs @@ -140,9 +140,9 @@ use proto::{self, Config, Prioritized}; use bytes::{Buf, Bytes, IntoBuf}; use futures::{self, Async, Future, Poll}; use http::{Request, Response}; -use tokio_io::{AsyncRead, AsyncWrite}; use std::{convert, fmt, mem}; use std::time::Duration; +use tokio_io::{AsyncRead, AsyncWrite}; /// In progress HTTP/2.0 connection handshake future. /// @@ -424,6 +424,14 @@ where pub fn poll_close(&mut self) -> Poll<(), ::Error> { self.connection.poll().map_err(Into::into) } + + /// Sets the connection to a GOAWAY state. Does not close connection immediately. + /// + /// This closes the stream after sending a GOAWAY frame + /// and flushing the codec. Must continue being polled to close connection. + pub fn close_connection(&mut self) { + self.connection.close_connection(); + } } impl futures::Stream for Connection diff --git a/tests/client_request.rs b/tests/client_request.rs index 1226b47..749d1cb 100644 --- a/tests/client_request.rs +++ b/tests/client_request.rs @@ -13,7 +13,7 @@ fn handshake() { .write(SETTINGS_ACK) .build(); - let (_, h2) = client::handshake(mock).wait().unwrap(); + let (_client, h2) = client::handshake(mock).wait().unwrap(); trace!("hands have been shook"); @@ -129,7 +129,12 @@ fn request_stream_id_overflows() { let err = client.send_request(request, true).unwrap_err(); assert_eq!(err.to_string(), "user error: stream ID overflowed"); - h2.expect("h2") + h2.expect("h2").map(|ret| { + // Hold on to the `client` handle to avoid sending a GO_AWAY + // frame. + drop(client); + ret + }) }) }); @@ -338,7 +343,11 @@ fn http_2_request_without_scheme_or_authority() { // first request is allowed assert!(client.send_request(request, true).is_err()); - h2.expect("h2") + h2.expect("h2").map(|ret| { + // Hold on to the `client` handle to avoid sending a GO_AWAY frame. + drop(client); + ret + }) }); h2.join(srv).wait().expect("wait"); @@ -614,6 +623,7 @@ fn recv_too_big_headers() { conn.drive(req1.join(req2)) .and_then(|(conn, _)| conn.expect("client")) + .map(|c| (c, client)) }); client.join(srv).wait().expect("wait"); diff --git a/tests/codec_read.rs b/tests/codec_read.rs index 4fee476..72e97dc 100644 --- a/tests/codec_read.rs +++ b/tests/codec_read.rs @@ -161,7 +161,7 @@ fn read_continuation_frames() { conn.drive(req) .and_then(move |(h2, _)| { h2.expect("client") - }) + }).map(|c| (client, c)) }); client.join(srv).wait().expect("wait"); diff --git a/tests/codec_write.rs b/tests/codec_write.rs index a0a3da5..507d91a 100644 --- a/tests/codec_write.rs +++ b/tests/codec_write.rs @@ -54,6 +54,8 @@ fn write_continuation_frames() { conn.drive(req) .and_then(move |(h2, _)| { h2.unwrap() + }).map(|c| { + (c, client) }) }); diff --git a/tests/flow_control.rs b/tests/flow_control.rs index 673e2da..fe975fb 100644 --- a/tests/flow_control.rs +++ b/tests/flow_control.rs @@ -307,7 +307,9 @@ fn recv_data_overflows_stream_window() { }) }); - conn.unwrap().join(req) + conn.unwrap() + .join(req) + .map(|c| (c, client)) }); h2.join(mock).wait().unwrap(); } @@ -385,6 +387,7 @@ fn stream_error_release_connection_capacity() { }); conn.drive(req.expect("response")) .and_then(|(conn, _)| conn.expect("client")) + .map(|c| (c, client)) }); srv.join(client).wait().unwrap(); @@ -620,19 +623,19 @@ fn reserved_capacity_assigned_in_multi_window_updates() { h2.drive( util::wait_for_capacity(stream, 5) - .map(|stream| (response, stream))) + .map(|stream| (response, client, stream))) }) - .and_then(|(h2, (response, mut stream))| { + .and_then(|(h2, (response, client, mut stream))| { stream.send_data("hello".into(), false).unwrap(); stream.send_data("world".into(), true).unwrap(); - h2.drive(response) + h2.drive(response).map(|c| (c, client)) }) - .and_then(|(h2, response)| { + .and_then(|((h2, response), client)| { assert_eq!(response.status(), StatusCode::NO_CONTENT); // Wait for the connection to close - h2.unwrap() + h2.unwrap().map(|c| (c, client)) }); let srv = srv.assert_client_handshake().unwrap() @@ -820,20 +823,21 @@ fn recv_settings_removes_available_capacity() { stream.reserve_capacity(11); - h2.drive(util::wait_for_capacity(stream, 11).map(|s| (response, s))) + h2.drive(util::wait_for_capacity(stream, 11).map(|s| (response, client, s))) }) - .and_then(|(h2, (response, mut stream))| { + .and_then(|(h2, (response, client, mut stream))| { assert_eq!(stream.capacity(), 11); stream.send_data("hello world".into(), true).unwrap(); - h2.drive(response) + h2.drive(response).map(|c| (c, client)) }) - .and_then(|(h2, response)| { + .and_then(|((h2, response), client)| { assert_eq!(response.status(), StatusCode::NO_CONTENT); // Wait for the connection to close - h2.unwrap() + // Hold on to the `client` handle to avoid sending a GO_AWAY frame. + h2.unwrap().map(|c| (c, client)) }); let _ = h2.join(srv) @@ -881,20 +885,21 @@ fn recv_no_init_window_then_receive_some_init_window() { stream.reserve_capacity(11); - h2.drive(util::wait_for_capacity(stream, 11).map(|s| (response, s))) + h2.drive(util::wait_for_capacity(stream, 11).map(|s| (response, client, s))) }) - .and_then(|(h2, (response, mut stream))| { + .and_then(|(h2, (response, client, mut stream))| { assert_eq!(stream.capacity(), 11); stream.send_data("hello world".into(), true).unwrap(); - h2.drive(response) + h2.drive(response).map(|c| (c, client)) }) - .and_then(|(h2, response)| { + .and_then(|((h2, response), client)| { assert_eq!(response.status(), StatusCode::NO_CONTENT); // Wait for the connection to close - h2.unwrap() + // Hold on to the `client` handle to avoid sending a GO_AWAY frame. + h2.unwrap().map(|c| (c, client)) }); let _ = h2.join(srv) @@ -903,8 +908,8 @@ fn recv_no_init_window_then_receive_some_init_window() { #[test] fn settings_lowered_capacity_returns_capacity_to_connection() { - use std::thread; use std::sync::mpsc; + use std::thread; let _ = ::env_logger::init(); let (io, srv) = mock::new(); @@ -1038,7 +1043,7 @@ fn client_increase_target_window_size() { .and_then(|(_client, mut conn)| { conn.set_target_window_size(2 << 20); - conn.unwrap() + conn.unwrap().map(|c| (c, _client)) }); srv.join(client).wait().unwrap(); @@ -1078,7 +1083,7 @@ fn increase_target_window_size_after_using_some() { .and_then(|(mut conn, _bytes)| { conn.set_target_window_size(2 << 20); conn.unwrap() - }) + }).map(|c| (c, client)) }); srv.join(client).wait().unwrap(); @@ -1113,18 +1118,19 @@ fn decrease_target_window_size() { .uri("https://http2.akamai.com/") .body(()).unwrap(); let (resp, _) = client.send_request(request, true).unwrap(); - conn.drive(resp.expect("response")) + conn.drive(resp.expect("response")).map(|c| (c, client)) }) - .and_then(|(mut conn, res)| { + .and_then(|((mut conn, res), client)| { conn.set_target_window_size(16_384); let mut body = res.into_parts().1; let mut cap = body.release_capacity().clone(); conn.drive(body.concat2().expect("concat")) - .and_then(move |(conn, bytes)| { + .map(|c| (c, client)) + .and_then(move |((conn, bytes), client)| { assert_eq!(bytes.len(), 65_535); cap.release_capacity(bytes.len()).unwrap(); - conn.expect("conn") + conn.expect("conn").map(|c| (c, client)) }) }); @@ -1188,10 +1194,10 @@ fn recv_settings_increase_window_size_after_using_some() { .body(()).unwrap(); let (resp, mut req_body) = client.send_request(request, false).unwrap(); req_body.send_data(vec![0; new_win_size].into(), true).unwrap(); - conn.drive(resp.expect("response")) + conn.drive(resp.expect("response")).map(|c| (c, client)) }) - .and_then(|(conn, _res)| { - conn.expect("client") + .and_then(|((conn, _res), client)| { + conn.expect("client").map(|c| (c, client)) }); srv.join(client).wait().unwrap(); diff --git a/tests/ping_pong.rs b/tests/ping_pong.rs index a7becd3..fe4e3c4 100644 --- a/tests/ping_pong.rs +++ b/tests/ping_pong.rs @@ -10,7 +10,10 @@ fn recv_single_ping() { // Create the handshake let h2 = client::handshake(m) .unwrap() - .and_then(|(_, conn)| conn.unwrap()); + .and_then(|(client, conn)| { + conn.unwrap() + .map(|c| (client, c)) + }); let mock = mock.assert_client_handshake() .unwrap() diff --git a/tests/server.rs b/tests/server.rs index 290fedf..36144b5 100644 --- a/tests/server.rs +++ b/tests/server.rs @@ -204,6 +204,78 @@ fn sends_reset_cancel_when_req_body_is_dropped() { srv.join(client).wait().expect("wait"); } +#[test] +fn sends_goaway_when_serv_closes_connection() { + let _ = ::env_logger::init(); + let (io, client) = mock::new(); + + let client = client + .assert_server_handshake() + .unwrap() + .recv_settings() + .send_frame( + frames::headers(1) + .request("POST", "https://example.com/") + ) + .recv_frame(frames::go_away(1)) + .close(); + + let srv = server::handshake(io).expect("handshake").and_then(|srv| { + srv.into_future().unwrap().and_then(|(_, mut srv)| { + srv.close_connection(); + srv.into_future().unwrap() + }) + }); + + srv.join(client).wait().expect("wait"); +} + +#[test] +fn serve_request_then_serv_closes_connection() { + let _ = ::env_logger::init(); + let (io, client) = mock::new(); + + let client = client + .assert_server_handshake() + .unwrap() + .recv_settings() + .send_frame( + frames::headers(1) + .request("GET", "https://example.com/"), + ) + .recv_frame(frames::headers(1).response(200).eos()) + .recv_frame(frames::reset(1).cancel()) + .send_frame( + frames::headers(3) + .request("GET", "https://example.com/"), + ) + .recv_frame(frames::go_away(3)) + // streams sent after GOAWAY receive no response + .send_frame( + frames::headers(5) + .request("GET", "https://example.com/"), + ) + .close(); + + let srv = server::handshake(io).expect("handshake").and_then(|srv| { + srv.into_future().unwrap().and_then(|(reqstream, srv)| { + let (req, mut stream) = reqstream.unwrap(); + + assert_eq!(req.method(), &http::Method::GET); + + let rsp = http::Response::builder().status(200).body(()).unwrap(); + stream.send_response(rsp, true).unwrap(); + + srv.into_future().unwrap().and_then(|(_reqstream, mut srv)| { + srv.close_connection(); + srv.into_future().unwrap() + }) + }) + }); + + srv.join(client).wait().expect("wait"); +} + #[test] fn sends_reset_cancel_when_res_body_is_dropped() { let _ = ::env_logger::init(); diff --git a/tests/stream_states.rs b/tests/stream_states.rs index 0238a57..23f2333 100644 --- a/tests/stream_states.rs +++ b/tests/stream_states.rs @@ -621,6 +621,7 @@ fn rst_stream_expires() { err.to_string(), "protocol error: unspecific protocol error detected" ); + drop(client); }) });