diff --git a/src/client.rs b/src/client.rs index bea78b3..8574e3c 100644 --- a/src/client.rs +++ b/src/client.rs @@ -57,7 +57,7 @@ impl Peer for Client { } fn is_valid_remote_stream_id(id: StreamId) -> bool { - id.is_server_initiated() + false } fn convert_send_message( diff --git a/src/error.rs b/src/error.rs index 606c3d9..4b82bb4 100644 --- a/src/error.rs +++ b/src/error.rs @@ -56,6 +56,9 @@ pub enum User { /// The stream is not currently expecting a frame of this type. UnexpectedFrameType, + /// The connection state is corrupt and the connection should be dropped. + Corrupt, + // TODO: reserve additional variants } @@ -93,6 +96,7 @@ macro_rules! user_desc { InvalidStreamId => concat!($prefix, "invalid stream ID"), InactiveStreamId => concat!($prefix, "inactive stream ID"), UnexpectedFrameType => concat!($prefix, "unexpected frame type"), + Corrupt => concat!($prefix, "connection state corrupt"), } }); } diff --git a/src/proto/connection.rs b/src/proto/connection.rs index f480ebf..ca93c49 100644 --- a/src/proto/connection.rs +++ b/src/proto/connection.rs @@ -23,6 +23,8 @@ use std::hash::BuildHasherDefault; pub struct Connection { inner: proto::Inner, streams: StreamMap, + // Set to `true` as long as the connection is in a valid state. + active: bool, peer: PhantomData<(P, B)>, } @@ -36,6 +38,7 @@ pub fn new(transport: proto::Inner) -> Connection Connection { inner: transport, streams: StreamMap::default(), + active: true, peer: PhantomData, } } @@ -108,6 +111,10 @@ impl Stream for Connection trace!("Connection::poll"); + if !self.active { + return Err(error::User::Corrupt.into()); + } + let frame = match try!(self.inner.poll()) { Async::Ready(f) => f, Async::NotReady => { @@ -135,7 +142,8 @@ impl Stream for Connection // connections should not be factored. if !P::is_valid_remote_stream_id(stream_id) { - unimplemented!(); + self.active = false; + return Err(error::Reason::ProtocolError.into()); } } @@ -179,6 +187,10 @@ impl Sink for Connection fn start_send(&mut self, item: Self::SinkItem) -> StartSend { + if !self.active { + return Err(error::User::Corrupt.into()); + } + // First ensure that the upstream can process a new item if !try!(self.poll_ready()).is_ready() { return Ok(AsyncSink::NotReady(item)); diff --git a/src/server.rs b/src/server.rs index fe40040..ff8411d 100644 --- a/src/server.rs +++ b/src/server.rs @@ -111,7 +111,7 @@ impl Peer for Server { type Poll = http::request::Head; fn is_valid_local_stream_id(id: StreamId) -> bool { - id.is_server_initiated() + false } fn is_valid_remote_stream_id(id: StreamId) -> bool { diff --git a/tests/client_request.rs b/tests/client_request.rs index 47f198a..6baea71 100644 --- a/tests/client_request.rs +++ b/tests/client_request.rs @@ -23,6 +23,17 @@ macro_rules! assert_user_err { }}; } +macro_rules! assert_proto_err { + ($actual:expr, $err:ident) => {{ + use h2::error::{ConnectionError, Reason}; + + match $actual { + ConnectionError::Proto(e) => assert_eq!(e, Reason::$err), + _ => panic!("unexpected connection error type"), + } + }}; +} + #[test] fn handshake() { let _ = ::env_logger::init(); @@ -341,8 +352,32 @@ fn invalid_client_stream_id() { } #[test] -#[ignore] fn invalid_server_stream_id() { + let _ = ::env_logger::init(); + + let mock = mock_io::Builder::new() + .handshake() + // Write GET / + .write(&[ + 0, 0, 0x10, 1, 5, 0, 0, 0, 1, 0x82, 0x87, 0x41, 0x8B, 0x9D, 0x29, + 0xAC, 0x4B, 0x8F, 0xA8, 0xE9, 0x19, 0x97, 0x21, 0xE9, 0x84, + ]) + .write(SETTINGS_ACK) + // Read response + .read(&[0, 0, 1, 1, 5, 0, 0, 0, 2, 137]) + .build(); + + let h2 = client::handshake(mock) + .wait().unwrap(); + + // Send the request + let mut request = request::Head::default(); + request.uri = "https://http2.akamai.com/".parse().unwrap(); + let h2 = h2.send_request(1.into(), request, true).wait().unwrap(); + + // Get the response + let (err, _) = h2.into_future().wait().unwrap_err(); + assert_proto_err!(err, ProtocolError); } #[test]