diff --git a/src/proto/connection.rs b/src/proto/connection.rs index dc67c4f..0ea3e2a 100644 --- a/src/proto/connection.rs +++ b/src/proto/connection.rs @@ -139,12 +139,21 @@ where self.go_away.go_away(frame); } - pub fn go_away_now(&mut self, e: Reason) { + fn go_away_now(&mut self, e: Reason) { let last_processed_id = self.streams.last_processed_id(); let frame = frame::GoAway::new(last_processed_id, e); self.go_away.go_away_now(frame); } + pub fn go_away_from_user(&mut self, e: Reason) { + let last_processed_id = self.streams.last_processed_id(); + let frame = frame::GoAway::new(last_processed_id, e); + self.go_away.go_away_from_user(frame); + + // Notify all streams of reason we're abruptly closing. + self.streams.recv_err(&proto::Error::Proto(e)); + } + fn take_error(&mut self, ours: Reason) -> Poll<(), proto::Error> { let reason = if let Some(theirs) = self.error.take() { match (ours, theirs) { @@ -214,7 +223,7 @@ where // error. This is handled by setting a GOAWAY frame followed by // terminating the connection. Err(Connection(e)) => { - debug!("Connection::poll; err={:?}", e); + debug!("Connection::poll; connection error={:?}", e); // We may have already sent a GOAWAY for this error, // if so, don't send another, just flush and close up. @@ -237,7 +246,7 @@ where id, reason, }) => { - trace!("stream level error; id={:?}; reason={:?}", id, reason); + trace!("stream error; id={:?}; reason={:?}", id, reason); self.streams.send_reset(id, reason); }, // Attempting to read a frame resulted in an I/O error. All @@ -245,6 +254,7 @@ where // // TODO: Are I/O errors recoverable? Err(Io(e)) => { + debug!("Connection::poll; IO error={:?}", e); let e = e.into(); // Reset all active streams @@ -256,7 +266,7 @@ where } } State::Closing(reason) => { - trace!("connection closing after flush, reason={:?}", reason); + trace!("connection closing after flush"); // Flush/shutdown the codec try_ready!(self.codec.shutdown()); @@ -284,7 +294,13 @@ where // - If it has, we've also added a PING to be sent in poll_ready if let Some(reason) = try_ready!(self.poll_go_away()) { if self.go_away.should_close_now() { - return Err(RecvError::Connection(reason)); + if self.go_away.is_user_initiated() { + // A user initiated abrupt shutdown shouldn't return + // the same error back to the user. + return Ok(Async::Ready(())); + } else { + return Err(RecvError::Connection(reason)); + } } // Only NO_ERROR should be waiting for idle debug_assert_eq!(reason, Reason::NO_ERROR, "graceful GOAWAY should be NO_ERROR"); diff --git a/src/proto/go_away.rs b/src/proto/go_away.rs index 5b4f856..bb9f541 100644 --- a/src/proto/go_away.rs +++ b/src/proto/go_away.rs @@ -13,7 +13,8 @@ pub(super) struct GoAway { close_now: bool, /// Records if we've sent any GOAWAY before. going_away: Option, - + /// Whether the user started the GOAWAY by calling `abrupt_shutdown`. + is_user_initiated: bool, /// A GOAWAY frame that must be buffered in the Codec immediately. pending: Option, } @@ -45,6 +46,7 @@ impl GoAway { GoAway { close_now: false, going_away: None, + is_user_initiated: false, pending: None, } } @@ -82,11 +84,20 @@ impl GoAway { self.go_away(f); } + pub fn go_away_from_user(&mut self, f: frame::GoAway) { + self.is_user_initiated = true; + self.go_away_now(f); + } + /// Return if a GOAWAY has ever been scheduled. pub fn is_going_away(&self) -> bool { self.going_away.is_some() } + pub fn is_user_initiated(&self) -> bool { + self.is_user_initiated + } + /// Return the last Reason we've sent. pub fn going_away_reason(&self) -> Option { self.going_away diff --git a/src/server.rs b/src/server.rs index 32eb4d1..1b30d94 100644 --- a/src/server.rs +++ b/src/server.rs @@ -443,7 +443,7 @@ where /// /// For graceful shutdowns, see [`graceful_shutdown`](Connection::graceful_shutdown). pub fn abrupt_shutdown(&mut self, reason: Reason) { - self.connection.go_away_now(reason); + self.connection.go_away_from_user(reason); } /// Starts a [graceful shutdown][1] process. diff --git a/tests/h2-support/src/frames.rs b/tests/h2-support/src/frames.rs index 77cfa9b..6fa4c7f 100644 --- a/tests/h2-support/src/frames.rs +++ b/tests/h2-support/src/frames.rs @@ -260,30 +260,29 @@ impl From> for SendFrame { impl Mock { pub fn protocol_error(self) -> Self { - Mock(frame::GoAway::new( - self.0.last_stream_id(), - frame::Reason::PROTOCOL_ERROR, - )) + self.reason(frame::Reason::PROTOCOL_ERROR) + } + + pub fn internal_error(self) -> Self { + self.reason(frame::Reason::INTERNAL_ERROR) } pub fn flow_control(self) -> Self { - Mock(frame::GoAway::new( - self.0.last_stream_id(), - frame::Reason::FLOW_CONTROL_ERROR, - )) + self.reason(frame::Reason::FLOW_CONTROL_ERROR) } pub fn frame_size(self) -> Self { - Mock(frame::GoAway::new( - self.0.last_stream_id(), - frame::Reason::FRAME_SIZE_ERROR, - )) + self.reason(frame::Reason::FRAME_SIZE_ERROR) } pub fn no_error(self) -> Self { + self.reason(frame::Reason::NO_ERROR) + } + + pub fn reason(self, reason: frame::Reason) -> Self { Mock(frame::GoAway::new( self.0.last_stream_id(), - frame::Reason::NO_ERROR, + reason, )) } } diff --git a/tests/h2-tests/tests/server.rs b/tests/h2-tests/tests/server.rs index d6f276c..15359ca 100644 --- a/tests/h2-tests/tests/server.rs +++ b/tests/h2-tests/tests/server.rs @@ -220,13 +220,33 @@ fn abrupt_shutdown() { frames::headers(1) .request("POST", "https://example.com/") ) - .recv_frame(frames::go_away(1)) - .close(); + .recv_frame(frames::go_away(1).internal_error()) + .recv_eof(); let srv = server::handshake(io).expect("handshake").and_then(|srv| { - srv.into_future().unwrap().and_then(|(_, mut srv)| { - srv.abrupt_shutdown(Reason::NO_ERROR); - srv.into_future().unwrap() + srv.into_future().unwrap().and_then(|(item, mut srv)| { + let (req, tx) = item.expect("server receives request"); + + let req_fut = req + .into_body() + .concat2() + .map(|_| drop(tx)) + .expect_err("request body should error") + .map(|err| { + assert_eq!( + err.reason(), + Some(Reason::INTERNAL_ERROR), + "streams should be also error with user's reason", + ); + }); + + srv.abrupt_shutdown(Reason::INTERNAL_ERROR); + + let srv_fut = futures::future::poll_fn(move || { + srv.poll_close() + }).expect("server"); + + req_fut.join(srv_fut) }) });