diff --git a/src/client.rs b/src/client.rs index fbebc76..1a185a9 100644 --- a/src/client.rs +++ b/src/client.rs @@ -15,8 +15,8 @@ use std::marker::PhantomData; /// In progress H2 connection binding pub struct Handshake { + builder: Builder, inner: MapErr, fn(io::Error) -> ::Error>, - settings: Settings, _marker: PhantomData, } @@ -36,9 +36,10 @@ pub struct Body { } /// Build a Client. -#[derive(Clone, Debug, Default)] +#[derive(Clone, Debug)] pub struct Builder { settings: Settings, + stream_id: StreamId, } #[derive(Debug)] @@ -72,7 +73,7 @@ where T: AsyncRead + AsyncWrite, B: IntoBuf, { - fn handshake2(io: T, settings: Settings) -> Handshake { + fn handshake2(io: T, builder: Builder) -> Handshake { use tokio_io::io; debug!("binding client connection"); @@ -81,8 +82,8 @@ where let handshake = io::write_all(io, msg).map_err(::Error::from as _); Handshake { + builder, inner: handshake, - settings: settings, _marker: PhantomData, } } @@ -182,6 +183,14 @@ impl Builder { self } + /// Set the first stream ID to something other than 1. + #[cfg(feature = "unstable")] + pub fn initial_stream_id(&mut self, stream_id: u32) -> &mut Self { + self.stream_id = stream_id.into(); + assert!(self.stream_id.is_client_initiated(), "stream id must be odd"); + self + } + /// Bind an H2 client connection. /// /// Returns a future which resolves to the connection value once the H2 @@ -194,7 +203,16 @@ impl Builder { T: AsyncRead + AsyncWrite, B: IntoBuf, { - Client::handshake2(io, self.settings.clone()) + Client::handshake2(io, self.clone()) + } +} + +impl Default for Builder { + fn default() -> Builder { + Builder { + settings: Default::default(), + stream_id: 1.into(), + } } } @@ -215,16 +233,16 @@ where // Create the codec let mut codec = Codec::new(io); - if let Some(max) = self.settings.max_frame_size() { + if let Some(max) = self.builder.settings.max_frame_size() { codec.set_max_recv_frame_size(max as usize); } // Send initial settings frame codec - .buffer(self.settings.clone().into()) + .buffer(self.builder.settings.clone().into()) .expect("invalid SETTINGS frame"); - let connection = Connection::new(codec, &self.settings); + let connection = Connection::new(codec, &self.builder.settings, self.builder.stream_id); Ok(Async::Ready(Client { connection, })) diff --git a/src/codec/error.rs b/src/codec/error.rs index 33c753f..5a060f9 100644 --- a/src/codec/error.rs +++ b/src/codec/error.rs @@ -37,6 +37,10 @@ pub enum UserError { /// The released capacity is larger than claimed capacity. ReleaseCapacityTooBig, + /// The stream ID space is overflowed. + /// + /// A new connection is needed. + OverflowedStreamId, } // ===== impl RecvError ===== @@ -112,6 +116,7 @@ impl error::Error for UserError { PayloadTooBig => "payload too big", Rejected => "rejected", ReleaseCapacityTooBig => "release capacity too big", + OverflowedStreamId => "stream ID overflowed", } } } diff --git a/src/frame/mod.rs b/src/frame/mod.rs index cecb974..0799501 100644 --- a/src/frame/mod.rs +++ b/src/frame/mod.rs @@ -48,7 +48,7 @@ pub use self::priority::{Priority, StreamDependency}; pub use self::reason::Reason; pub use self::reset::Reset; pub use self::settings::Settings; -pub use self::stream_id::StreamId; +pub use self::stream_id::{StreamId, StreamIdOverflow}; pub use self::window_update::WindowUpdate; // Re-export some constants diff --git a/src/frame/stream_id.rs b/src/frame/stream_id.rs index 4f9b9fe..7406e70 100644 --- a/src/frame/stream_id.rs +++ b/src/frame/stream_id.rs @@ -4,9 +4,16 @@ use std::u32; #[derive(Debug, Copy, Clone, Eq, PartialEq, Ord, PartialOrd, Hash)] pub struct StreamId(u32); +#[derive(Debug, Copy, Clone)] +pub struct StreamIdOverflow; + const STREAM_ID_MASK: u32 = 1 << 31; impl StreamId { + pub const ZERO: StreamId = StreamId(0); + + pub const MAX: StreamId = StreamId(u32::MAX >> 1); + /// Parse the stream ID #[inline] pub fn parse(buf: &[u8]) -> (StreamId, bool) { @@ -30,20 +37,20 @@ impl StreamId { #[inline] pub fn zero() -> StreamId { - StreamId(0) - } - - #[inline] - pub fn max() -> StreamId { - StreamId(u32::MAX >> 1) + StreamId::ZERO } pub fn is_zero(&self) -> bool { self.0 == 0 } - pub fn increment(&mut self) { - self.0 += 2; + pub fn next_id(&self) -> Result { + let next = self.0 + 2; + if next > StreamId::MAX.0 { + Err(StreamIdOverflow) + } else { + Ok(StreamId(next)) + } } } diff --git a/src/proto/connection.rs b/src/proto/connection.rs index dec3e6c..f871f5e 100644 --- a/src/proto/connection.rs +++ b/src/proto/connection.rs @@ -61,13 +61,14 @@ where pub fn new( codec: Codec>, settings: &frame::Settings, + next_stream_id: frame::StreamId ) -> Connection { - // TODO: Actually configure let streams = Streams::new(streams::Config { local_init_window_sz: settings .initial_window_size() .unwrap_or(DEFAULT_INITIAL_WINDOW_SIZE), local_max_initiated: None, + local_next_stream_id: next_stream_id, local_push_enabled: settings.is_push_enabled(), remote_init_window_sz: DEFAULT_INITIAL_WINDOW_SIZE, remote_max_initiated: None, diff --git a/src/proto/streams/mod.rs b/src/proto/streams/mod.rs index 1bf15cf..f5b6f25 100644 --- a/src/proto/streams/mod.rs +++ b/src/proto/streams/mod.rs @@ -23,7 +23,7 @@ use self::store::{Entry, Store}; use self::stream::Stream; use error::Reason::*; -use frame::StreamId; +use frame::{StreamId, StreamIdOverflow}; use proto::*; use bytes::Bytes; @@ -37,6 +37,9 @@ pub struct Config { /// Maximum number of locally initiated streams pub local_max_initiated: Option, + /// The stream ID to start the next local stream with + pub local_next_stream_id: StreamId, + /// If the local peer is willing to receive push promises pub local_push_enabled: bool, diff --git a/src/proto/streams/recv.rs b/src/proto/streams/recv.rs index 8474cc7..985ac61 100644 --- a/src/proto/streams/recv.rs +++ b/src/proto/streams/recv.rs @@ -21,7 +21,7 @@ where flow: FlowControl, /// The lowest stream ID that is still idle - next_stream_id: StreamId, + next_stream_id: Result, /// The stream ID of the last processed stream last_processed_id: StreamId, @@ -76,7 +76,7 @@ where Recv { init_window_sz: config.local_init_window_sz, flow: flow, - next_stream_id: next_stream_id.into(), + next_stream_id: Ok(next_stream_id.into()), pending_window_updates: store::Queue::new(), last_processed_id: StreamId::zero(), pending_accept: store::Queue::new(), @@ -109,12 +109,12 @@ where self.ensure_can_open(id)?; - if id < self.next_stream_id { + let next_id = self.next_stream_id()?; + if id < next_id { return Err(RecvError::Connection(ProtocolError)); } - self.next_stream_id = id; - self.next_stream_id.increment(); + self.next_stream_id = id.next_id(); if !counts.can_inc_num_recv_streams() { self.refused = Some(id); @@ -137,6 +137,13 @@ where let is_initial = stream.state.recv_open(frame.is_end_stream())?; if is_initial { + let next_id = self.next_stream_id()?; + if frame.stream_id() >= next_id { + self.next_stream_id = frame.stream_id().next_id(); + } else { + return Err(RecvError::Connection(ProtocolError)); + } + // TODO: be smarter about this logic if frame.stream_id() > self.last_processed_id { self.last_processed_id = frame.stream_id(); @@ -383,9 +390,12 @@ where /// Ensures that `id` is not in the `Idle` state. pub fn ensure_not_idle(&self, id: StreamId) -> Result<(), Reason> { - if id >= self.next_stream_id { - return Err(ProtocolError); + if let Ok(next) = self.next_stream_id { + if id >= next { + return Err(ProtocolError); + } } + // if next_stream_id is overflowed, that's ok. Ok(()) } @@ -428,6 +438,14 @@ where Ok(()) } + fn next_stream_id(&self) -> Result { + if let Ok(id) = self.next_stream_id { + Ok(id) + } else { + Err(RecvError::Connection(ProtocolError)) + } + } + /// Returns true if the remote peer can reserve a stream with the given ID. fn ensure_can_reserve(&self, promised_id: StreamId) -> Result<(), RecvError> { // TODO: Are there other rules? diff --git a/src/proto/streams/send.rs b/src/proto/streams/send.rs index bae09bd..b8053bb 100644 --- a/src/proto/streams/send.rs +++ b/src/proto/streams/send.rs @@ -15,7 +15,7 @@ where P: Peer, { /// Stream identifier to use for next initialized stream. - next_stream_id: StreamId, + next_stream_id: Result, /// Initial window size of locally initiated streams init_window_sz: WindowSize, @@ -31,11 +31,9 @@ where { /// Create a new `Send` pub fn new(config: &Config) -> Self { - let next_stream_id = if P::is_server() { 2 } else { 1 }; - Send { - next_stream_id: next_stream_id.into(), init_window_sz: config.local_init_window_sz, + next_stream_id: Ok(config.local_next_stream_id), prioritize: Prioritize::new(config), } } @@ -49,19 +47,17 @@ where /// /// Returns the stream state if successful. `None` if refused pub fn open(&mut self, counts: &mut Counts

) -> Result { - self.ensure_can_open()?; - if !counts.can_inc_num_send_streams() { return Err(Rejected.into()); } - let ret = self.next_stream_id; - self.next_stream_id.increment(); + let stream_id = self.try_open()?; // Increment the number of locally initiated streams counts.inc_num_send_streams(); + self.next_stream_id = stream_id.next_id(); - Ok(ret) + Ok(stream_id) } pub fn send_headers( @@ -293,22 +289,23 @@ where } pub fn ensure_not_idle(&self, id: StreamId) -> Result<(), Reason> { - if id >= self.next_stream_id { - return Err(ProtocolError); + if let Ok(next) = self.next_stream_id { + if id >= next { + return Err(ProtocolError); + } } + // if next_stream_id is overflowed, that's ok. Ok(()) } - /// Returns true if the local actor can initiate a stream with the given ID. - fn ensure_can_open(&self) -> Result<(), UserError> { + /// Returns a new StreamId if the local actor can initiate a new stream. + fn try_open(&self) -> Result { if P::is_server() { // Servers cannot open streams. PushPromise must first be reserved. return Err(UnexpectedFrameType); } - // TODO: Handle StreamId overflow - - Ok(()) + self.next_stream_id.map_err(|_| OverflowedStreamId) } } diff --git a/src/server.rs b/src/server.rs index 1880c1e..df04226 100644 --- a/src/server.rs +++ b/src/server.rs @@ -106,7 +106,7 @@ where let handshake = Flush::new(codec) .and_then(ReadPreface::new) .map(move |codec| { - let connection = Connection::new(codec, &settings); + let connection = Connection::new(codec, &settings, 2.into()); Server { connection, } diff --git a/tests/client_request.rs b/tests/client_request.rs index ea68f0a..4f3b952 100644 --- a/tests/client_request.rs +++ b/tests/client_request.rs @@ -57,6 +57,55 @@ fn recv_invalid_server_stream_id() { assert!(stream.wait().is_err()); } +#[test] +fn request_stream_id_overflows() { + let _ = ::env_logger::init(); + let (io, srv) = mock::new(); + + + let h2 = Client::builder() + .initial_stream_id(::std::u32::MAX >> 1) + .handshake::<_, Bytes>(io) + .expect("handshake") + .and_then(|mut h2| { + let request = Request::builder() + .method(Method::GET) + .uri("https://example.com/") + .body(()) + .unwrap(); + + // first request is allowed + let req = h2.send_request(request, true) + .unwrap() + .unwrap(); + + let request = Request::builder() + .method(Method::GET) + .uri("https://example.com/") + .body(()) + .unwrap(); + + // second cant use the next stream id, it's over + let err = h2.send_request(request, true).unwrap_err(); + assert_eq!(err.to_string(), "user error: stream ID overflowed"); + + h2.expect("h2").join(req) + }); + + let srv = srv.assert_client_handshake() + .unwrap() + .recv_settings() + .recv_frame( + frames::headers(::std::u32::MAX >> 1) + .request("GET", "https://example.com/") + .eos(), + ) + .send_frame(frames::headers(::std::u32::MAX >> 1).response(200)) + .close(); + + h2.join(srv).wait().expect("wait"); +} + #[test] #[ignore] fn request_without_scheme() {}