check for StreamId overflow (#68)
This commit is contained in:
@@ -15,8 +15,8 @@ use std::marker::PhantomData;
|
|||||||
|
|
||||||
/// In progress H2 connection binding
|
/// In progress H2 connection binding
|
||||||
pub struct Handshake<T: AsyncRead + AsyncWrite, B: IntoBuf = Bytes> {
|
pub struct Handshake<T: AsyncRead + AsyncWrite, B: IntoBuf = Bytes> {
|
||||||
|
builder: Builder,
|
||||||
inner: MapErr<WriteAll<T, &'static [u8]>, fn(io::Error) -> ::Error>,
|
inner: MapErr<WriteAll<T, &'static [u8]>, fn(io::Error) -> ::Error>,
|
||||||
settings: Settings,
|
|
||||||
_marker: PhantomData<B>,
|
_marker: PhantomData<B>,
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -36,9 +36,10 @@ pub struct Body<B: IntoBuf> {
|
|||||||
}
|
}
|
||||||
|
|
||||||
/// Build a Client.
|
/// Build a Client.
|
||||||
#[derive(Clone, Debug, Default)]
|
#[derive(Clone, Debug)]
|
||||||
pub struct Builder {
|
pub struct Builder {
|
||||||
settings: Settings,
|
settings: Settings,
|
||||||
|
stream_id: StreamId,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug)]
|
#[derive(Debug)]
|
||||||
@@ -72,7 +73,7 @@ where
|
|||||||
T: AsyncRead + AsyncWrite,
|
T: AsyncRead + AsyncWrite,
|
||||||
B: IntoBuf,
|
B: IntoBuf,
|
||||||
{
|
{
|
||||||
fn handshake2(io: T, settings: Settings) -> Handshake<T, B> {
|
fn handshake2(io: T, builder: Builder) -> Handshake<T, B> {
|
||||||
use tokio_io::io;
|
use tokio_io::io;
|
||||||
|
|
||||||
debug!("binding client connection");
|
debug!("binding client connection");
|
||||||
@@ -81,8 +82,8 @@ where
|
|||||||
let handshake = io::write_all(io, msg).map_err(::Error::from as _);
|
let handshake = io::write_all(io, msg).map_err(::Error::from as _);
|
||||||
|
|
||||||
Handshake {
|
Handshake {
|
||||||
|
builder,
|
||||||
inner: handshake,
|
inner: handshake,
|
||||||
settings: settings,
|
|
||||||
_marker: PhantomData,
|
_marker: PhantomData,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -182,6 +183,14 @@ impl Builder {
|
|||||||
self
|
self
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Set the first stream ID to something other than 1.
|
||||||
|
#[cfg(feature = "unstable")]
|
||||||
|
pub fn initial_stream_id(&mut self, stream_id: u32) -> &mut Self {
|
||||||
|
self.stream_id = stream_id.into();
|
||||||
|
assert!(self.stream_id.is_client_initiated(), "stream id must be odd");
|
||||||
|
self
|
||||||
|
}
|
||||||
|
|
||||||
/// Bind an H2 client connection.
|
/// Bind an H2 client connection.
|
||||||
///
|
///
|
||||||
/// Returns a future which resolves to the connection value once the H2
|
/// Returns a future which resolves to the connection value once the H2
|
||||||
@@ -194,7 +203,16 @@ impl Builder {
|
|||||||
T: AsyncRead + AsyncWrite,
|
T: AsyncRead + AsyncWrite,
|
||||||
B: IntoBuf,
|
B: IntoBuf,
|
||||||
{
|
{
|
||||||
Client::handshake2(io, self.settings.clone())
|
Client::handshake2(io, self.clone())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Default for Builder {
|
||||||
|
fn default() -> Builder {
|
||||||
|
Builder {
|
||||||
|
settings: Default::default(),
|
||||||
|
stream_id: 1.into(),
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -215,16 +233,16 @@ where
|
|||||||
// Create the codec
|
// Create the codec
|
||||||
let mut codec = Codec::new(io);
|
let mut codec = Codec::new(io);
|
||||||
|
|
||||||
if let Some(max) = self.settings.max_frame_size() {
|
if let Some(max) = self.builder.settings.max_frame_size() {
|
||||||
codec.set_max_recv_frame_size(max as usize);
|
codec.set_max_recv_frame_size(max as usize);
|
||||||
}
|
}
|
||||||
|
|
||||||
// Send initial settings frame
|
// Send initial settings frame
|
||||||
codec
|
codec
|
||||||
.buffer(self.settings.clone().into())
|
.buffer(self.builder.settings.clone().into())
|
||||||
.expect("invalid SETTINGS frame");
|
.expect("invalid SETTINGS frame");
|
||||||
|
|
||||||
let connection = Connection::new(codec, &self.settings);
|
let connection = Connection::new(codec, &self.builder.settings, self.builder.stream_id);
|
||||||
Ok(Async::Ready(Client {
|
Ok(Async::Ready(Client {
|
||||||
connection,
|
connection,
|
||||||
}))
|
}))
|
||||||
|
|||||||
@@ -37,6 +37,10 @@ pub enum UserError {
|
|||||||
|
|
||||||
/// The released capacity is larger than claimed capacity.
|
/// The released capacity is larger than claimed capacity.
|
||||||
ReleaseCapacityTooBig,
|
ReleaseCapacityTooBig,
|
||||||
|
/// The stream ID space is overflowed.
|
||||||
|
///
|
||||||
|
/// A new connection is needed.
|
||||||
|
OverflowedStreamId,
|
||||||
}
|
}
|
||||||
|
|
||||||
// ===== impl RecvError =====
|
// ===== impl RecvError =====
|
||||||
@@ -112,6 +116,7 @@ impl error::Error for UserError {
|
|||||||
PayloadTooBig => "payload too big",
|
PayloadTooBig => "payload too big",
|
||||||
Rejected => "rejected",
|
Rejected => "rejected",
|
||||||
ReleaseCapacityTooBig => "release capacity too big",
|
ReleaseCapacityTooBig => "release capacity too big",
|
||||||
|
OverflowedStreamId => "stream ID overflowed",
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -48,7 +48,7 @@ pub use self::priority::{Priority, StreamDependency};
|
|||||||
pub use self::reason::Reason;
|
pub use self::reason::Reason;
|
||||||
pub use self::reset::Reset;
|
pub use self::reset::Reset;
|
||||||
pub use self::settings::Settings;
|
pub use self::settings::Settings;
|
||||||
pub use self::stream_id::StreamId;
|
pub use self::stream_id::{StreamId, StreamIdOverflow};
|
||||||
pub use self::window_update::WindowUpdate;
|
pub use self::window_update::WindowUpdate;
|
||||||
|
|
||||||
// Re-export some constants
|
// Re-export some constants
|
||||||
|
|||||||
@@ -4,9 +4,16 @@ use std::u32;
|
|||||||
#[derive(Debug, Copy, Clone, Eq, PartialEq, Ord, PartialOrd, Hash)]
|
#[derive(Debug, Copy, Clone, Eq, PartialEq, Ord, PartialOrd, Hash)]
|
||||||
pub struct StreamId(u32);
|
pub struct StreamId(u32);
|
||||||
|
|
||||||
|
#[derive(Debug, Copy, Clone)]
|
||||||
|
pub struct StreamIdOverflow;
|
||||||
|
|
||||||
const STREAM_ID_MASK: u32 = 1 << 31;
|
const STREAM_ID_MASK: u32 = 1 << 31;
|
||||||
|
|
||||||
impl StreamId {
|
impl StreamId {
|
||||||
|
pub const ZERO: StreamId = StreamId(0);
|
||||||
|
|
||||||
|
pub const MAX: StreamId = StreamId(u32::MAX >> 1);
|
||||||
|
|
||||||
/// Parse the stream ID
|
/// Parse the stream ID
|
||||||
#[inline]
|
#[inline]
|
||||||
pub fn parse(buf: &[u8]) -> (StreamId, bool) {
|
pub fn parse(buf: &[u8]) -> (StreamId, bool) {
|
||||||
@@ -30,20 +37,20 @@ impl StreamId {
|
|||||||
|
|
||||||
#[inline]
|
#[inline]
|
||||||
pub fn zero() -> StreamId {
|
pub fn zero() -> StreamId {
|
||||||
StreamId(0)
|
StreamId::ZERO
|
||||||
}
|
|
||||||
|
|
||||||
#[inline]
|
|
||||||
pub fn max() -> StreamId {
|
|
||||||
StreamId(u32::MAX >> 1)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn is_zero(&self) -> bool {
|
pub fn is_zero(&self) -> bool {
|
||||||
self.0 == 0
|
self.0 == 0
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn increment(&mut self) {
|
pub fn next_id(&self) -> Result<StreamId, StreamIdOverflow> {
|
||||||
self.0 += 2;
|
let next = self.0 + 2;
|
||||||
|
if next > StreamId::MAX.0 {
|
||||||
|
Err(StreamIdOverflow)
|
||||||
|
} else {
|
||||||
|
Ok(StreamId(next))
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -61,13 +61,14 @@ where
|
|||||||
pub fn new(
|
pub fn new(
|
||||||
codec: Codec<T, Prioritized<B::Buf>>,
|
codec: Codec<T, Prioritized<B::Buf>>,
|
||||||
settings: &frame::Settings,
|
settings: &frame::Settings,
|
||||||
|
next_stream_id: frame::StreamId
|
||||||
) -> Connection<T, P, B> {
|
) -> Connection<T, P, B> {
|
||||||
// TODO: Actually configure
|
|
||||||
let streams = Streams::new(streams::Config {
|
let streams = Streams::new(streams::Config {
|
||||||
local_init_window_sz: settings
|
local_init_window_sz: settings
|
||||||
.initial_window_size()
|
.initial_window_size()
|
||||||
.unwrap_or(DEFAULT_INITIAL_WINDOW_SIZE),
|
.unwrap_or(DEFAULT_INITIAL_WINDOW_SIZE),
|
||||||
local_max_initiated: None,
|
local_max_initiated: None,
|
||||||
|
local_next_stream_id: next_stream_id,
|
||||||
local_push_enabled: settings.is_push_enabled(),
|
local_push_enabled: settings.is_push_enabled(),
|
||||||
remote_init_window_sz: DEFAULT_INITIAL_WINDOW_SIZE,
|
remote_init_window_sz: DEFAULT_INITIAL_WINDOW_SIZE,
|
||||||
remote_max_initiated: None,
|
remote_max_initiated: None,
|
||||||
|
|||||||
@@ -23,7 +23,7 @@ use self::store::{Entry, Store};
|
|||||||
use self::stream::Stream;
|
use self::stream::Stream;
|
||||||
|
|
||||||
use error::Reason::*;
|
use error::Reason::*;
|
||||||
use frame::StreamId;
|
use frame::{StreamId, StreamIdOverflow};
|
||||||
use proto::*;
|
use proto::*;
|
||||||
|
|
||||||
use bytes::Bytes;
|
use bytes::Bytes;
|
||||||
@@ -37,6 +37,9 @@ pub struct Config {
|
|||||||
/// Maximum number of locally initiated streams
|
/// Maximum number of locally initiated streams
|
||||||
pub local_max_initiated: Option<usize>,
|
pub local_max_initiated: Option<usize>,
|
||||||
|
|
||||||
|
/// The stream ID to start the next local stream with
|
||||||
|
pub local_next_stream_id: StreamId,
|
||||||
|
|
||||||
/// If the local peer is willing to receive push promises
|
/// If the local peer is willing to receive push promises
|
||||||
pub local_push_enabled: bool,
|
pub local_push_enabled: bool,
|
||||||
|
|
||||||
|
|||||||
@@ -21,7 +21,7 @@ where
|
|||||||
flow: FlowControl,
|
flow: FlowControl,
|
||||||
|
|
||||||
/// The lowest stream ID that is still idle
|
/// The lowest stream ID that is still idle
|
||||||
next_stream_id: StreamId,
|
next_stream_id: Result<StreamId, StreamIdOverflow>,
|
||||||
|
|
||||||
/// The stream ID of the last processed stream
|
/// The stream ID of the last processed stream
|
||||||
last_processed_id: StreamId,
|
last_processed_id: StreamId,
|
||||||
@@ -76,7 +76,7 @@ where
|
|||||||
Recv {
|
Recv {
|
||||||
init_window_sz: config.local_init_window_sz,
|
init_window_sz: config.local_init_window_sz,
|
||||||
flow: flow,
|
flow: flow,
|
||||||
next_stream_id: next_stream_id.into(),
|
next_stream_id: Ok(next_stream_id.into()),
|
||||||
pending_window_updates: store::Queue::new(),
|
pending_window_updates: store::Queue::new(),
|
||||||
last_processed_id: StreamId::zero(),
|
last_processed_id: StreamId::zero(),
|
||||||
pending_accept: store::Queue::new(),
|
pending_accept: store::Queue::new(),
|
||||||
@@ -109,12 +109,12 @@ where
|
|||||||
|
|
||||||
self.ensure_can_open(id)?;
|
self.ensure_can_open(id)?;
|
||||||
|
|
||||||
if id < self.next_stream_id {
|
let next_id = self.next_stream_id()?;
|
||||||
|
if id < next_id {
|
||||||
return Err(RecvError::Connection(ProtocolError));
|
return Err(RecvError::Connection(ProtocolError));
|
||||||
}
|
}
|
||||||
|
|
||||||
self.next_stream_id = id;
|
self.next_stream_id = id.next_id();
|
||||||
self.next_stream_id.increment();
|
|
||||||
|
|
||||||
if !counts.can_inc_num_recv_streams() {
|
if !counts.can_inc_num_recv_streams() {
|
||||||
self.refused = Some(id);
|
self.refused = Some(id);
|
||||||
@@ -137,6 +137,13 @@ where
|
|||||||
let is_initial = stream.state.recv_open(frame.is_end_stream())?;
|
let is_initial = stream.state.recv_open(frame.is_end_stream())?;
|
||||||
|
|
||||||
if is_initial {
|
if is_initial {
|
||||||
|
let next_id = self.next_stream_id()?;
|
||||||
|
if frame.stream_id() >= next_id {
|
||||||
|
self.next_stream_id = frame.stream_id().next_id();
|
||||||
|
} else {
|
||||||
|
return Err(RecvError::Connection(ProtocolError));
|
||||||
|
}
|
||||||
|
|
||||||
// TODO: be smarter about this logic
|
// TODO: be smarter about this logic
|
||||||
if frame.stream_id() > self.last_processed_id {
|
if frame.stream_id() > self.last_processed_id {
|
||||||
self.last_processed_id = frame.stream_id();
|
self.last_processed_id = frame.stream_id();
|
||||||
@@ -383,9 +390,12 @@ where
|
|||||||
|
|
||||||
/// Ensures that `id` is not in the `Idle` state.
|
/// Ensures that `id` is not in the `Idle` state.
|
||||||
pub fn ensure_not_idle(&self, id: StreamId) -> Result<(), Reason> {
|
pub fn ensure_not_idle(&self, id: StreamId) -> Result<(), Reason> {
|
||||||
if id >= self.next_stream_id {
|
if let Ok(next) = self.next_stream_id {
|
||||||
return Err(ProtocolError);
|
if id >= next {
|
||||||
|
return Err(ProtocolError);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
// if next_stream_id is overflowed, that's ok.
|
||||||
|
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
@@ -428,6 +438,14 @@ where
|
|||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn next_stream_id(&self) -> Result<StreamId, RecvError> {
|
||||||
|
if let Ok(id) = self.next_stream_id {
|
||||||
|
Ok(id)
|
||||||
|
} else {
|
||||||
|
Err(RecvError::Connection(ProtocolError))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
/// Returns true if the remote peer can reserve a stream with the given ID.
|
/// Returns true if the remote peer can reserve a stream with the given ID.
|
||||||
fn ensure_can_reserve(&self, promised_id: StreamId) -> Result<(), RecvError> {
|
fn ensure_can_reserve(&self, promised_id: StreamId) -> Result<(), RecvError> {
|
||||||
// TODO: Are there other rules?
|
// TODO: Are there other rules?
|
||||||
|
|||||||
@@ -15,7 +15,7 @@ where
|
|||||||
P: Peer,
|
P: Peer,
|
||||||
{
|
{
|
||||||
/// Stream identifier to use for next initialized stream.
|
/// Stream identifier to use for next initialized stream.
|
||||||
next_stream_id: StreamId,
|
next_stream_id: Result<StreamId, StreamIdOverflow>,
|
||||||
|
|
||||||
/// Initial window size of locally initiated streams
|
/// Initial window size of locally initiated streams
|
||||||
init_window_sz: WindowSize,
|
init_window_sz: WindowSize,
|
||||||
@@ -31,11 +31,9 @@ where
|
|||||||
{
|
{
|
||||||
/// Create a new `Send`
|
/// Create a new `Send`
|
||||||
pub fn new(config: &Config) -> Self {
|
pub fn new(config: &Config) -> Self {
|
||||||
let next_stream_id = if P::is_server() { 2 } else { 1 };
|
|
||||||
|
|
||||||
Send {
|
Send {
|
||||||
next_stream_id: next_stream_id.into(),
|
|
||||||
init_window_sz: config.local_init_window_sz,
|
init_window_sz: config.local_init_window_sz,
|
||||||
|
next_stream_id: Ok(config.local_next_stream_id),
|
||||||
prioritize: Prioritize::new(config),
|
prioritize: Prioritize::new(config),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -49,19 +47,17 @@ where
|
|||||||
///
|
///
|
||||||
/// Returns the stream state if successful. `None` if refused
|
/// Returns the stream state if successful. `None` if refused
|
||||||
pub fn open(&mut self, counts: &mut Counts<P>) -> Result<StreamId, UserError> {
|
pub fn open(&mut self, counts: &mut Counts<P>) -> Result<StreamId, UserError> {
|
||||||
self.ensure_can_open()?;
|
|
||||||
|
|
||||||
if !counts.can_inc_num_send_streams() {
|
if !counts.can_inc_num_send_streams() {
|
||||||
return Err(Rejected.into());
|
return Err(Rejected.into());
|
||||||
}
|
}
|
||||||
|
|
||||||
let ret = self.next_stream_id;
|
let stream_id = self.try_open()?;
|
||||||
self.next_stream_id.increment();
|
|
||||||
|
|
||||||
// Increment the number of locally initiated streams
|
// Increment the number of locally initiated streams
|
||||||
counts.inc_num_send_streams();
|
counts.inc_num_send_streams();
|
||||||
|
self.next_stream_id = stream_id.next_id();
|
||||||
|
|
||||||
Ok(ret)
|
Ok(stream_id)
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn send_headers(
|
pub fn send_headers(
|
||||||
@@ -293,22 +289,23 @@ where
|
|||||||
}
|
}
|
||||||
|
|
||||||
pub fn ensure_not_idle(&self, id: StreamId) -> Result<(), Reason> {
|
pub fn ensure_not_idle(&self, id: StreamId) -> Result<(), Reason> {
|
||||||
if id >= self.next_stream_id {
|
if let Ok(next) = self.next_stream_id {
|
||||||
return Err(ProtocolError);
|
if id >= next {
|
||||||
|
return Err(ProtocolError);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
// if next_stream_id is overflowed, that's ok.
|
||||||
|
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Returns true if the local actor can initiate a stream with the given ID.
|
/// Returns a new StreamId if the local actor can initiate a new stream.
|
||||||
fn ensure_can_open(&self) -> Result<(), UserError> {
|
fn try_open(&self) -> Result<StreamId, UserError> {
|
||||||
if P::is_server() {
|
if P::is_server() {
|
||||||
// Servers cannot open streams. PushPromise must first be reserved.
|
// Servers cannot open streams. PushPromise must first be reserved.
|
||||||
return Err(UnexpectedFrameType);
|
return Err(UnexpectedFrameType);
|
||||||
}
|
}
|
||||||
|
|
||||||
// TODO: Handle StreamId overflow
|
self.next_stream_id.map_err(|_| OverflowedStreamId)
|
||||||
|
|
||||||
Ok(())
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -106,7 +106,7 @@ where
|
|||||||
let handshake = Flush::new(codec)
|
let handshake = Flush::new(codec)
|
||||||
.and_then(ReadPreface::new)
|
.and_then(ReadPreface::new)
|
||||||
.map(move |codec| {
|
.map(move |codec| {
|
||||||
let connection = Connection::new(codec, &settings);
|
let connection = Connection::new(codec, &settings, 2.into());
|
||||||
Server {
|
Server {
|
||||||
connection,
|
connection,
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -57,6 +57,55 @@ fn recv_invalid_server_stream_id() {
|
|||||||
assert!(stream.wait().is_err());
|
assert!(stream.wait().is_err());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn request_stream_id_overflows() {
|
||||||
|
let _ = ::env_logger::init();
|
||||||
|
let (io, srv) = mock::new();
|
||||||
|
|
||||||
|
|
||||||
|
let h2 = Client::builder()
|
||||||
|
.initial_stream_id(::std::u32::MAX >> 1)
|
||||||
|
.handshake::<_, Bytes>(io)
|
||||||
|
.expect("handshake")
|
||||||
|
.and_then(|mut h2| {
|
||||||
|
let request = Request::builder()
|
||||||
|
.method(Method::GET)
|
||||||
|
.uri("https://example.com/")
|
||||||
|
.body(())
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
|
// first request is allowed
|
||||||
|
let req = h2.send_request(request, true)
|
||||||
|
.unwrap()
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
|
let request = Request::builder()
|
||||||
|
.method(Method::GET)
|
||||||
|
.uri("https://example.com/")
|
||||||
|
.body(())
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
|
// second cant use the next stream id, it's over
|
||||||
|
let err = h2.send_request(request, true).unwrap_err();
|
||||||
|
assert_eq!(err.to_string(), "user error: stream ID overflowed");
|
||||||
|
|
||||||
|
h2.expect("h2").join(req)
|
||||||
|
});
|
||||||
|
|
||||||
|
let srv = srv.assert_client_handshake()
|
||||||
|
.unwrap()
|
||||||
|
.recv_settings()
|
||||||
|
.recv_frame(
|
||||||
|
frames::headers(::std::u32::MAX >> 1)
|
||||||
|
.request("GET", "https://example.com/")
|
||||||
|
.eos(),
|
||||||
|
)
|
||||||
|
.send_frame(frames::headers(::std::u32::MAX >> 1).response(200))
|
||||||
|
.close();
|
||||||
|
|
||||||
|
h2.join(srv).wait().expect("wait");
|
||||||
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
#[ignore]
|
#[ignore]
|
||||||
fn request_without_scheme() {}
|
fn request_without_scheme() {}
|
||||||
|
|||||||
Reference in New Issue
Block a user