diff --git a/src/codec/error.rs b/src/codec/error.rs index 5a060f9..e2cc705 100644 --- a/src/codec/error.rs +++ b/src/codec/error.rs @@ -16,6 +16,9 @@ pub enum SendError { /// User error User(UserError), + /// Connection error prevents sending. + Connection(Reason), + /// I/O error Io(io::Error), } @@ -80,6 +83,7 @@ impl error::Error for SendError { match *self { User(ref e) => e.description(), + Connection(ref reason) => reason.description(), Io(ref e) => e.description(), } } diff --git a/src/error.rs b/src/error.rs index c004dcd..f9dac4b 100644 --- a/src/error.rs +++ b/src/error.rs @@ -62,6 +62,7 @@ impl From for Error { fn from(src: SendError) -> Error { match src { SendError::User(e) => e.into(), + SendError::Connection(reason) => reason.into(), SendError::Io(e) => e.into(), } } diff --git a/src/proto/error.rs b/src/proto/error.rs index f26fbcb..2ab9a37 100644 --- a/src/proto/error.rs +++ b/src/proto/error.rs @@ -1,4 +1,4 @@ -use codec::RecvError; +use codec::{RecvError, SendError}; use frame::Reason; use std::io; @@ -11,12 +11,13 @@ pub enum Error { } impl Error { - pub fn into_connection_recv_error(self) -> RecvError { - use self::Error::*; - - match self { - Proto(reason) => RecvError::Connection(reason), - Io(e) => RecvError::Io(e), + /// Clone the error for internal purposes. + /// + /// `io::Error` is not `Clone`, so we only copy the `ErrorKind`. + pub(super) fn shallow_clone(&self) -> Error { + match *self { + Error::Proto(reason) => Error::Proto(reason), + Error::Io(ref io) => Error::Io(io::Error::from(io.kind())), } } } @@ -32,3 +33,21 @@ impl From for Error { Error::Io(src) } } + +impl From for RecvError { + fn from(src: Error) -> RecvError { + match src { + Error::Proto(reason) => RecvError::Connection(reason), + Error::Io(e) => RecvError::Io(e), + } + } +} + +impl From for SendError { + fn from(src: Error) -> SendError { + match src { + Error::Proto(reason) => SendError::Connection(reason), + Error::Io(e) => SendError::Io(e), + } + } +} diff --git a/src/proto/streams/recv.rs b/src/proto/streams/recv.rs index 7b0ebdd..0f560f8 100644 --- a/src/proto/streams/recv.rs +++ b/src/proto/streams/recv.rs @@ -337,10 +337,7 @@ where self.ensure_can_reserve(frame.promised_id())?; // Make sure that the stream state is valid - store[stream] - .state - .ensure_recv_open() - .map_err(|e| e.into_connection_recv_error())?; + store[stream].state.ensure_recv_open()?; // TODO: Streams in the reserved states do not count towards the concurrency // limit. However, it seems like there should be a cap otherwise this diff --git a/src/proto/streams/streams.rs b/src/proto/streams/streams.rs index aa62927..8d6c3c4 100644 --- a/src/proto/streams/streams.rs +++ b/src/proto/streams/streams.rs @@ -56,6 +56,9 @@ where /// Task that calls `poll_complete`. task: Option, + + /// If the connection errors, a copy is kept for any StreamRefs. + conn_error: Option, } impl Streams @@ -71,6 +74,7 @@ where recv: Recv::new(&config), send: Send::new(&config), task: None, + conn_error: None, }, store: Store::new(), })), @@ -193,6 +197,8 @@ where }) .unwrap(); + actions.conn_error = Some(err.shallow_clone()); + last_processed_id } @@ -337,6 +343,7 @@ where let mut me = self.inner.lock().unwrap(); let me = &mut *me; + me.actions.ensure_no_conn_error()?; me.actions.send.ensure_next_stream_id()?; // The `pending` argument is provided by the `Client`, and holds @@ -424,6 +431,7 @@ where let mut me = self.inner.lock().unwrap(); let me = &mut *me; + me.actions.ensure_no_conn_error()?; me.actions.send.ensure_next_stream_id()?; if let Some(key) = key { @@ -733,4 +741,12 @@ where self.recv.ensure_not_idle(id) } } + + fn ensure_no_conn_error(&self) -> Result<(), proto::Error> { + if let Some(ref err) = self.conn_error { + Err(err.shallow_clone()) + } else { + Ok(()) + } + } } diff --git a/tests/client_request.rs b/tests/client_request.rs index 8552175..a1524f2 100644 --- a/tests/client_request.rs +++ b/tests/client_request.rs @@ -242,8 +242,66 @@ fn request_with_h1_version() {} #[test] -#[ignore] -fn sending_request_on_closed_soket() {} +fn sending_request_on_closed_connection() { + let _ = ::env_logger::init(); + let (io, srv) = mock::new(); + + let srv = srv.assert_client_handshake() + .unwrap() + .recv_settings() + .recv_frame( + frames::headers(1) + .request("GET", "https://http2.akamai.com/") + .eos(), + ) + .send_frame(frames::headers(1).response(200).eos()) + // a bad frame! + .send_frame(frames::headers(0).response(200).eos()) + .close(); + + let h2 = Client::handshake(io) + .expect("handshake") + .and_then(|(mut client, h2)| { + let request = Request::builder() + .uri("https://http2.akamai.com/") + .body(()) + .unwrap(); + // first request works + let req = client + .send_request(request, true) + .expect("send_request1") + .expect("response1") + .map(|_| ()); + // after finish request1, there should be a conn error + let h2 = h2.then(|res| { + res.expect_err("h2 error"); + Ok::<(), ()>(()) + }); + + h2.select(req) + .then(|res| match res { + Ok((_, next)) => next, + Err(_) => unreachable!("both selected futures cannot error"), + }) + .map(move |_| client) + }) + .and_then(|mut client| { + let poll_err = client.poll_ready().unwrap_err(); + let msg = "protocol error: unspecific protocol error detected"; + assert_eq!(poll_err.to_string(), msg); + + let request = Request::builder() + .uri("https://http2.akamai.com/") + .body(()) + .unwrap(); + let send_err = client.send_request(request, true).unwrap_err(); + assert_eq!(send_err.to_string(), msg); + + Ok(()) + }); + + h2.join(srv).wait().expect("wait"); +} const SETTINGS: &'static [u8] = &[0, 0, 0, 4, 0, 0, 0, 0, 0]; const SETTINGS_ACK: &'static [u8] = &[0, 0, 0, 4, 1, 0, 0, 0, 0];