diff --git a/src/client.rs b/src/client.rs index 97539fa..8c64fbf 100644 --- a/src/client.rs +++ b/src/client.rs @@ -1162,6 +1162,7 @@ where type Error = ::Error; fn poll(&mut self) -> Poll<(), ::Error> { + self.inner.maybe_close_connection_if_no_streams(); self.inner.poll().map_err(Into::into) } } diff --git a/src/proto/connection.rs b/src/proto/connection.rs index 08a58f3..7f23794 100644 --- a/src/proto/connection.rs +++ b/src/proto/connection.rs @@ -122,6 +122,27 @@ where Ok(().into()) } + fn transition_to_go_away(&mut self, id: StreamId, e: Reason) { + let goaway = frame::GoAway::new(id, e); + self.state = State::GoAway(goaway); + } + + /// Closes the connection by transitioning to a GOAWAY state + /// iff there are no streams or references + pub fn maybe_close_connection_if_no_streams(&mut self) { + // If we poll() and realize that there are no streams or references + // then we can close the connection by transitioning to GOAWAY + if self.streams.num_active_streams() == 0 && !self.streams.has_streams_or_other_references() { + self.close_connection(); + } + } + + /// Closes the connection by transitioning to a GOAWAY state + pub fn close_connection(&mut self) { + let last_processed_id = self.streams.last_processed_id(); + self.transition_to_go_away(last_processed_id, Reason::NO_ERROR); + } + /// Advances the internal state of the connection. pub fn poll(&mut self) -> Poll<(), proto::Error> { use codec::RecvError::*; @@ -143,9 +164,8 @@ where if self.error.is_some() { if self.streams.num_active_streams() == 0 { - let id = self.streams.last_processed_id(); - let goaway = frame::GoAway::new(id, Reason::NO_ERROR); - self.state = State::GoAway(goaway); + let last_processed_id = self.streams.last_processed_id(); + self.transition_to_go_away(last_processed_id, Reason::NO_ERROR); continue; } } @@ -160,12 +180,7 @@ where // Reset all active streams let last_processed_id = self.streams.recv_err(&e.into()); - - // Create the GO_AWAY frame with the last_processed_id - let frame = frame::GoAway::new(last_processed_id, e); - - // Transition to the going away state. - self.state = State::GoAway(frame); + self.transition_to_go_away(last_processed_id, e); }, // Attempting to read a frame resulted in a stream level error. // This is handled by resetting the frame then trying to read diff --git a/src/proto/peer.rs b/src/proto/peer.rs index 1603e52..8325b88 100644 --- a/src/proto/peer.rs +++ b/src/proto/peer.rs @@ -63,6 +63,7 @@ impl Dyn { /// Returns true if the remote peer can initiate a stream with the given ID. pub fn ensure_can_open(&self, id: StreamId) -> Result<(), RecvError> { if !self.is_server() { + trace!("Cannot open stream {:?} - not server, PROTOCOL_ERROR", id); // Remote is a server and cannot open streams. PushPromise is // registered by reserving, so does not go through this path. return Err(RecvError::Connection(Reason::PROTOCOL_ERROR)); @@ -70,6 +71,7 @@ impl Dyn { // Ensure that the ID is a valid server initiated ID if !id.is_client_initiated() { + trace!("Cannot open stream {:?} - not client initiated, PROTOCOL_ERROR", id); return Err(RecvError::Connection(Reason::PROTOCOL_ERROR)); } diff --git a/src/proto/streams/counts.rs b/src/proto/streams/counts.rs index 6aed6d2..0e82f15 100644 --- a/src/proto/streams/counts.rs +++ b/src/proto/streams/counts.rs @@ -46,6 +46,10 @@ impl Counts { self.peer } + pub fn has_streams(&self) -> bool { + self.num_send_streams != 0 || self.num_recv_streams != 0 + } + /// Returns true if the receive stream concurrency can be incremented pub fn can_inc_num_recv_streams(&self) -> bool { self.max_recv_streams > self.num_recv_streams diff --git a/src/proto/streams/recv.rs b/src/proto/streams/recv.rs index 28a263e..775598f 100644 --- a/src/proto/streams/recv.rs +++ b/src/proto/streams/recv.rs @@ -119,6 +119,7 @@ impl Recv { let next_id = self.next_stream_id()?; if id < next_id { + trace!("id ({:?}) < next_id ({:?}), PROTOCOL_ERROR", id, next_id); return Err(RecvError::Connection(Reason::PROTOCOL_ERROR)); } diff --git a/src/proto/streams/streams.rs b/src/proto/streams/streams.rs index eb97f63..eeac656 100644 --- a/src/proto/streams/streams.rs +++ b/src/proto/streams/streams.rs @@ -1,10 +1,10 @@ -use super::{Buffer, Config, Counts, Prioritized, Recv, Send, Stream, StreamId}; -use super::recv::RecvHeaderBlockError; -use super::store::{self, Entry, Resolve, Store}; use {client, proto, server}; use codec::{Codec, RecvError, SendError, UserError}; use frame::{self, Frame, Reason}; use proto::{peer, Peer, WindowSize}; +use super::{Buffer, Config, Counts, Prioritized, Recv, Send, Stream, StreamId}; +use super::recv::RecvHeaderBlockError; +use super::store::{self, Entry, Resolve, Store}; use bytes::{Buf, Bytes}; use futures::{task, Async, Poll}; @@ -520,8 +520,8 @@ where end_of_stream: bool, pending: Option<&store::Key>, ) -> Result, SendError> { - use super::stream::ContentLength; use http::Method; + use super::stream::ContentLength; // TODO: There is a hazard with assigning a stream ID before the // prioritize layer. If prioritization reorders new streams, this @@ -661,6 +661,19 @@ where me.store.num_active_streams() } + pub fn has_streams_or_other_references(&self) -> bool { + if Arc::strong_count(&self.inner) > 1 { + return true; + } + + if Arc::strong_count(&self.send_buffer) > 1 { + return true; + } + + let me = self.inner.lock().unwrap(); + me.counts.has_streams() + } + #[cfg(feature = "unstable")] pub fn num_wired_streams(&self) -> usize { let me = self.inner.lock().unwrap(); @@ -951,6 +964,16 @@ fn drop_stream_ref(inner: &Mutex, key: store::Key) { let actions = &mut me.actions; + // If the stream is not referenced and it is already + // closed (does not have to go through logic below + // of canceling the stream), we should notify the task + // (connection) so that it can close properly + if stream.ref_count == 0 && stream.is_closed() { + if let Some(task) = actions.task.take() { + task.notify(); + } + } + me.counts.transition(stream, |counts, stream| { maybe_cancel(stream, actions, counts); diff --git a/src/server.rs b/src/server.rs index 7a7752c..3a4c716 100644 --- a/src/server.rs +++ b/src/server.rs @@ -140,9 +140,9 @@ use proto::{self, Config, Prioritized}; use bytes::{Buf, Bytes, IntoBuf}; use futures::{self, Async, Future, Poll}; use http::{Request, Response}; -use tokio_io::{AsyncRead, AsyncWrite}; use std::{convert, fmt, mem}; use std::time::Duration; +use tokio_io::{AsyncRead, AsyncWrite}; /// In progress HTTP/2.0 connection handshake future. /// @@ -424,6 +424,14 @@ where pub fn poll_close(&mut self) -> Poll<(), ::Error> { self.connection.poll().map_err(Into::into) } + + /// Sets the connection to a GOAWAY state. Does not close connection immediately. + /// + /// This closes the stream after sending a GOAWAY frame + /// and flushing the codec. Must continue being polled to close connection. + pub fn close_connection(&mut self) { + self.connection.close_connection(); + } } impl futures::Stream for Connection diff --git a/tests/client_request.rs b/tests/client_request.rs index 1226b47..749d1cb 100644 --- a/tests/client_request.rs +++ b/tests/client_request.rs @@ -13,7 +13,7 @@ fn handshake() { .write(SETTINGS_ACK) .build(); - let (_, h2) = client::handshake(mock).wait().unwrap(); + let (_client, h2) = client::handshake(mock).wait().unwrap(); trace!("hands have been shook"); @@ -129,7 +129,12 @@ fn request_stream_id_overflows() { let err = client.send_request(request, true).unwrap_err(); assert_eq!(err.to_string(), "user error: stream ID overflowed"); - h2.expect("h2") + h2.expect("h2").map(|ret| { + // Hold on to the `client` handle to avoid sending a GO_AWAY + // frame. + drop(client); + ret + }) }) }); @@ -338,7 +343,11 @@ fn http_2_request_without_scheme_or_authority() { // first request is allowed assert!(client.send_request(request, true).is_err()); - h2.expect("h2") + h2.expect("h2").map(|ret| { + // Hold on to the `client` handle to avoid sending a GO_AWAY frame. + drop(client); + ret + }) }); h2.join(srv).wait().expect("wait"); @@ -614,6 +623,7 @@ fn recv_too_big_headers() { conn.drive(req1.join(req2)) .and_then(|(conn, _)| conn.expect("client")) + .map(|c| (c, client)) }); client.join(srv).wait().expect("wait"); diff --git a/tests/codec_read.rs b/tests/codec_read.rs index 4fee476..72e97dc 100644 --- a/tests/codec_read.rs +++ b/tests/codec_read.rs @@ -161,7 +161,7 @@ fn read_continuation_frames() { conn.drive(req) .and_then(move |(h2, _)| { h2.expect("client") - }) + }).map(|c| (client, c)) }); client.join(srv).wait().expect("wait"); diff --git a/tests/codec_write.rs b/tests/codec_write.rs index a0a3da5..507d91a 100644 --- a/tests/codec_write.rs +++ b/tests/codec_write.rs @@ -54,6 +54,8 @@ fn write_continuation_frames() { conn.drive(req) .and_then(move |(h2, _)| { h2.unwrap() + }).map(|c| { + (c, client) }) }); diff --git a/tests/flow_control.rs b/tests/flow_control.rs index 673e2da..fe975fb 100644 --- a/tests/flow_control.rs +++ b/tests/flow_control.rs @@ -307,7 +307,9 @@ fn recv_data_overflows_stream_window() { }) }); - conn.unwrap().join(req) + conn.unwrap() + .join(req) + .map(|c| (c, client)) }); h2.join(mock).wait().unwrap(); } @@ -385,6 +387,7 @@ fn stream_error_release_connection_capacity() { }); conn.drive(req.expect("response")) .and_then(|(conn, _)| conn.expect("client")) + .map(|c| (c, client)) }); srv.join(client).wait().unwrap(); @@ -620,19 +623,19 @@ fn reserved_capacity_assigned_in_multi_window_updates() { h2.drive( util::wait_for_capacity(stream, 5) - .map(|stream| (response, stream))) + .map(|stream| (response, client, stream))) }) - .and_then(|(h2, (response, mut stream))| { + .and_then(|(h2, (response, client, mut stream))| { stream.send_data("hello".into(), false).unwrap(); stream.send_data("world".into(), true).unwrap(); - h2.drive(response) + h2.drive(response).map(|c| (c, client)) }) - .and_then(|(h2, response)| { + .and_then(|((h2, response), client)| { assert_eq!(response.status(), StatusCode::NO_CONTENT); // Wait for the connection to close - h2.unwrap() + h2.unwrap().map(|c| (c, client)) }); let srv = srv.assert_client_handshake().unwrap() @@ -820,20 +823,21 @@ fn recv_settings_removes_available_capacity() { stream.reserve_capacity(11); - h2.drive(util::wait_for_capacity(stream, 11).map(|s| (response, s))) + h2.drive(util::wait_for_capacity(stream, 11).map(|s| (response, client, s))) }) - .and_then(|(h2, (response, mut stream))| { + .and_then(|(h2, (response, client, mut stream))| { assert_eq!(stream.capacity(), 11); stream.send_data("hello world".into(), true).unwrap(); - h2.drive(response) + h2.drive(response).map(|c| (c, client)) }) - .and_then(|(h2, response)| { + .and_then(|((h2, response), client)| { assert_eq!(response.status(), StatusCode::NO_CONTENT); // Wait for the connection to close - h2.unwrap() + // Hold on to the `client` handle to avoid sending a GO_AWAY frame. + h2.unwrap().map(|c| (c, client)) }); let _ = h2.join(srv) @@ -881,20 +885,21 @@ fn recv_no_init_window_then_receive_some_init_window() { stream.reserve_capacity(11); - h2.drive(util::wait_for_capacity(stream, 11).map(|s| (response, s))) + h2.drive(util::wait_for_capacity(stream, 11).map(|s| (response, client, s))) }) - .and_then(|(h2, (response, mut stream))| { + .and_then(|(h2, (response, client, mut stream))| { assert_eq!(stream.capacity(), 11); stream.send_data("hello world".into(), true).unwrap(); - h2.drive(response) + h2.drive(response).map(|c| (c, client)) }) - .and_then(|(h2, response)| { + .and_then(|((h2, response), client)| { assert_eq!(response.status(), StatusCode::NO_CONTENT); // Wait for the connection to close - h2.unwrap() + // Hold on to the `client` handle to avoid sending a GO_AWAY frame. + h2.unwrap().map(|c| (c, client)) }); let _ = h2.join(srv) @@ -903,8 +908,8 @@ fn recv_no_init_window_then_receive_some_init_window() { #[test] fn settings_lowered_capacity_returns_capacity_to_connection() { - use std::thread; use std::sync::mpsc; + use std::thread; let _ = ::env_logger::init(); let (io, srv) = mock::new(); @@ -1038,7 +1043,7 @@ fn client_increase_target_window_size() { .and_then(|(_client, mut conn)| { conn.set_target_window_size(2 << 20); - conn.unwrap() + conn.unwrap().map(|c| (c, _client)) }); srv.join(client).wait().unwrap(); @@ -1078,7 +1083,7 @@ fn increase_target_window_size_after_using_some() { .and_then(|(mut conn, _bytes)| { conn.set_target_window_size(2 << 20); conn.unwrap() - }) + }).map(|c| (c, client)) }); srv.join(client).wait().unwrap(); @@ -1113,18 +1118,19 @@ fn decrease_target_window_size() { .uri("https://http2.akamai.com/") .body(()).unwrap(); let (resp, _) = client.send_request(request, true).unwrap(); - conn.drive(resp.expect("response")) + conn.drive(resp.expect("response")).map(|c| (c, client)) }) - .and_then(|(mut conn, res)| { + .and_then(|((mut conn, res), client)| { conn.set_target_window_size(16_384); let mut body = res.into_parts().1; let mut cap = body.release_capacity().clone(); conn.drive(body.concat2().expect("concat")) - .and_then(move |(conn, bytes)| { + .map(|c| (c, client)) + .and_then(move |((conn, bytes), client)| { assert_eq!(bytes.len(), 65_535); cap.release_capacity(bytes.len()).unwrap(); - conn.expect("conn") + conn.expect("conn").map(|c| (c, client)) }) }); @@ -1188,10 +1194,10 @@ fn recv_settings_increase_window_size_after_using_some() { .body(()).unwrap(); let (resp, mut req_body) = client.send_request(request, false).unwrap(); req_body.send_data(vec![0; new_win_size].into(), true).unwrap(); - conn.drive(resp.expect("response")) + conn.drive(resp.expect("response")).map(|c| (c, client)) }) - .and_then(|(conn, _res)| { - conn.expect("client") + .and_then(|((conn, _res), client)| { + conn.expect("client").map(|c| (c, client)) }); srv.join(client).wait().unwrap(); diff --git a/tests/ping_pong.rs b/tests/ping_pong.rs index a7becd3..fe4e3c4 100644 --- a/tests/ping_pong.rs +++ b/tests/ping_pong.rs @@ -10,7 +10,10 @@ fn recv_single_ping() { // Create the handshake let h2 = client::handshake(m) .unwrap() - .and_then(|(_, conn)| conn.unwrap()); + .and_then(|(client, conn)| { + conn.unwrap() + .map(|c| (client, c)) + }); let mock = mock.assert_client_handshake() .unwrap() diff --git a/tests/server.rs b/tests/server.rs index 290fedf..36144b5 100644 --- a/tests/server.rs +++ b/tests/server.rs @@ -204,6 +204,78 @@ fn sends_reset_cancel_when_req_body_is_dropped() { srv.join(client).wait().expect("wait"); } +#[test] +fn sends_goaway_when_serv_closes_connection() { + let _ = ::env_logger::init(); + let (io, client) = mock::new(); + + let client = client + .assert_server_handshake() + .unwrap() + .recv_settings() + .send_frame( + frames::headers(1) + .request("POST", "https://example.com/") + ) + .recv_frame(frames::go_away(1)) + .close(); + + let srv = server::handshake(io).expect("handshake").and_then(|srv| { + srv.into_future().unwrap().and_then(|(_, mut srv)| { + srv.close_connection(); + srv.into_future().unwrap() + }) + }); + + srv.join(client).wait().expect("wait"); +} + +#[test] +fn serve_request_then_serv_closes_connection() { + let _ = ::env_logger::init(); + let (io, client) = mock::new(); + + let client = client + .assert_server_handshake() + .unwrap() + .recv_settings() + .send_frame( + frames::headers(1) + .request("GET", "https://example.com/"), + ) + .recv_frame(frames::headers(1).response(200).eos()) + .recv_frame(frames::reset(1).cancel()) + .send_frame( + frames::headers(3) + .request("GET", "https://example.com/"), + ) + .recv_frame(frames::go_away(3)) + // streams sent after GOAWAY receive no response + .send_frame( + frames::headers(5) + .request("GET", "https://example.com/"), + ) + .close(); + + let srv = server::handshake(io).expect("handshake").and_then(|srv| { + srv.into_future().unwrap().and_then(|(reqstream, srv)| { + let (req, mut stream) = reqstream.unwrap(); + + assert_eq!(req.method(), &http::Method::GET); + + let rsp = http::Response::builder().status(200).body(()).unwrap(); + stream.send_response(rsp, true).unwrap(); + + srv.into_future().unwrap().and_then(|(_reqstream, mut srv)| { + srv.close_connection(); + srv.into_future().unwrap() + }) + }) + }); + + srv.join(client).wait().expect("wait"); +} + #[test] fn sends_reset_cancel_when_res_body_is_dropped() { let _ = ::env_logger::init(); diff --git a/tests/stream_states.rs b/tests/stream_states.rs index 0238a57..23f2333 100644 --- a/tests/stream_states.rs +++ b/tests/stream_states.rs @@ -621,6 +621,7 @@ fn rst_stream_expires() { err.to_string(), "protocol error: unspecific protocol error detected" ); + drop(client); }) });