From 5c0efcf8c441fbdb1955ec2ac40809011f56281f Mon Sep 17 00:00:00 2001 From: Carl Lerche Date: Sun, 10 Sep 2017 16:01:19 -0700 Subject: [PATCH] Ref count stream state and release when final (#73) Previously, stream state was never released so that long-lived connections leaked memory. Now, stream states are reference-counted and freed from the stream slab when complete. Locally reset streams are retained so that received frames may be ignored. --- Cargo.toml | 1 + src/client.rs | 23 +++++ src/lib.rs | 2 + src/proto/connection.rs | 15 +++ src/proto/peer.rs | 5 + src/proto/streams/counts.rs | 152 ++++++++++++++++++++++++++++++ src/proto/streams/mod.rs | 2 + src/proto/streams/prioritize.rs | 9 +- src/proto/streams/recv.rs | 59 +++--------- src/proto/streams/send.rs | 57 ++---------- src/proto/streams/store.rs | 98 ++++++++++++++----- src/proto/streams/stream.rs | 57 ++++++++++++ src/proto/streams/streams.rs | 160 +++++++++++++++++++++----------- tests/stream_states.rs | 52 +++++++++++ tests/support/src/future_ext.rs | 40 +++++--- 15 files changed, 542 insertions(+), 190 deletions(-) create mode 100644 src/proto/streams/counts.rs diff --git a/Cargo.toml b/Cargo.toml index b8e7e05..6ee36b8 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -24,6 +24,7 @@ log = "0.3.8" fnv = "1.0.5" slab = "0.4.0" string = { git = "https://github.com/carllerche/string" } +ordermap = "0.2" [dev-dependencies] diff --git a/src/client.rs b/src/client.rs index a7d8721..652cb9e 100644 --- a/src/client.rs +++ b/src/client.rs @@ -142,6 +142,29 @@ impl fmt::Debug for Client } } +#[cfg(feature = "unstable")] +impl Client + where T: AsyncRead + AsyncWrite, + B: IntoBuf +{ + /// Returns the number of active streams. + /// + /// An active stream is a stream that has not yet transitioned to a closed + /// state. + pub fn num_active_streams(&self) -> usize { + self.connection.num_active_streams() + } + + /// Returns the number of streams that are held in memory. + /// + /// A wired stream is a stream that is either active or is closed but must + /// stay in memory for some reason. For example, there are still outstanding + /// userspace handles pointing to the slot. 
+ pub fn num_wired_streams(&self) -> usize { + self.connection.num_wired_streams() + } +} + // ===== impl Handshake ===== impl Future for Handshake diff --git a/src/lib.rs b/src/lib.rs index d5030fa..e8e6e91 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -24,6 +24,8 @@ extern crate log; extern crate string; +extern crate ordermap; + mod error; mod codec; mod hpack; diff --git a/src/proto/connection.rs b/src/proto/connection.rs index 7617890..0334b70 100644 --- a/src/proto/connection.rs +++ b/src/proto/connection.rs @@ -253,3 +253,18 @@ impl Connection self.streams.next_incoming() } } + +#[cfg(feature = "unstable")] +impl Connection + where T: AsyncRead + AsyncWrite, + P: Peer, + B: IntoBuf, +{ + pub fn num_active_streams(&self) -> usize { + self.streams.num_active_streams() + } + + pub fn num_wired_streams(&self) -> usize { + self.streams.num_wired_streams() + } +} diff --git a/src/proto/peer.rs b/src/proto/peer.rs index 2ab847b..2925bd3 100644 --- a/src/proto/peer.rs +++ b/src/proto/peer.rs @@ -19,4 +19,9 @@ pub trait Peer { end_of_stream: bool) -> Headers; fn convert_poll_message(headers: Headers) -> Result; + + fn is_local_init(id: StreamId) -> bool { + assert!(!id.is_zero()); + Self::is_server() == id.is_server_initiated() + } } diff --git a/src/proto/streams/counts.rs b/src/proto/streams/counts.rs new file mode 100644 index 0000000..7e69969 --- /dev/null +++ b/src/proto/streams/counts.rs @@ -0,0 +1,152 @@ +use client; +use super::*; + +use std::usize; +use std::marker::PhantomData; + +#[derive(Debug)] +pub(super) struct Counts
<P>
+ where P: Peer, +{ + /// Maximum number of locally initiated streams + max_send_streams: Option, + + /// Current number of remote initiated streams + num_send_streams: usize, + + /// Maximum number of remote initiated streams + max_recv_streams: Option, + + /// Current number of locally initiated streams + num_recv_streams: usize, + + /// Task awaiting notification to open a new stream. + blocked_open: Option, + + _p: PhantomData
<P>
, +} + +impl
<P>
Counts
<P>
+ where P: Peer, +{ + /// Create a new `Counts` using the provided configuration values. + pub fn new(config: &Config) -> Self { + Counts { + max_send_streams: config.max_local_initiated, + num_send_streams: 0, + max_recv_streams: config.max_remote_initiated, + num_recv_streams: 0, + blocked_open: None, + _p: PhantomData, + } + } + + /// Returns true if the receive stream concurrency can be incremented + pub fn can_inc_num_recv_streams(&self) -> bool { + if let Some(max) = self.max_recv_streams { + max > self.num_recv_streams + } else { + true + } + } + + /// Increments the number of concurrent receive streams. + /// + /// # Panics + /// + /// Panics on failure as this should have been validated before hand. + pub fn inc_num_recv_streams(&mut self) { + assert!(self.can_inc_num_recv_streams()); + + // Increment the number of remote initiated streams + self.num_recv_streams += 1; + } + + /// Returns true if the send stream concurrency can be incremented + pub fn can_inc_num_send_streams(&self) -> bool { + if let Some(max) = self.max_send_streams { + max > self.num_send_streams + } else { + true + } + } + + /// Increments the number of concurrent send streams. + /// + /// # Panics + /// + /// Panics on failure as this should have been validated before hand. + pub fn inc_num_send_streams(&mut self) { + assert!(self.can_inc_num_send_streams()); + + // Increment the number of remote initiated streams + self.num_send_streams += 1; + } + + pub fn apply_remote_settings(&mut self, settings: &frame::Settings) { + if let Some(val) = settings.max_concurrent_streams() { + self.max_send_streams = Some(val as usize); + } + } + + /// Run a block of code that could potentially transition a stream's state. + /// + /// If the stream state transitions to closed, this function will perform + /// all necessary cleanup. + pub fn transition(&mut self, mut stream: store::Ptr, f: F) -> U + where F: FnOnce(&mut Self, &mut store::Ptr) -> U + { + let is_counted = stream.state.is_counted(); + + // Run the action + let ret = f(self, &mut stream); + + self.transition_after(stream, is_counted); + + ret + } + + // TODO: move this to macro? + pub fn transition_after(&mut self, mut stream: store::Ptr, is_counted: bool) { + if stream.is_closed() { + stream.unlink(); + + if is_counted { + // Decrement the number of active streams. 
+ self.dec_num_streams(stream.id); + } + } + + // Release the stream if it requires releasing + if stream.is_released() { + stream.remove(); + } + } + + fn dec_num_streams(&mut self, id: StreamId) { + use std::usize; + + if P::is_local_init(id) { + self.num_send_streams -= 1; + + if self.num_send_streams < self.max_send_streams.unwrap_or(usize::MAX) { + if let Some(task) = self.blocked_open.take() { + task.notify(); + } + } + } else { + self.num_recv_streams -= 1; + } + } +} + +impl Counts { + pub fn poll_open_ready(&mut self) -> Async<()> { + if !self.can_inc_num_send_streams() { + self.blocked_open = Some(task::current()); + return Async::NotReady; + } + + return Async::Ready(()); + } +} diff --git a/src/proto/streams/mod.rs b/src/proto/streams/mod.rs index c875c0f..681a85f 100644 --- a/src/proto/streams/mod.rs +++ b/src/proto/streams/mod.rs @@ -1,4 +1,5 @@ mod buffer; +mod counts; mod flow_control; mod prioritize; mod recv; @@ -12,6 +13,7 @@ pub(crate) use self::streams::{Streams, StreamRef}; pub(crate) use self::prioritize::Prioritized; use self::buffer::Buffer; +use self::counts::Counts; use self::flow_control::FlowControl; use self::prioritize::Prioritize; use self::recv::Recv; diff --git a/src/proto/streams/prioritize.rs b/src/proto/streams/prioritize.rs index c92d19c..915b021 100644 --- a/src/proto/streams/prioritize.rs +++ b/src/proto/streams/prioritize.rs @@ -325,6 +325,7 @@ impl Prioritize pub fn poll_complete(&mut self, store: &mut Store, + counts: &mut Counts
<P>
, dst: &mut Codec>) -> Poll<(), io::Error> where T: AsyncWrite, @@ -341,7 +342,7 @@ impl Prioritize trace!("poll_complete"); loop { - match self.pop_frame(store, max_frame_len) { + match self.pop_frame(store, max_frame_len, counts) { Some(frame) => { trace!("writing frame={:?}", frame); @@ -433,7 +434,7 @@ impl Prioritize } } - fn pop_frame(&mut self, store: &mut Store, max_len: usize) + fn pop_frame(&mut self, store: &mut Store, max_len: usize, counts: &mut Counts
<P>
) -> Option>> { trace!("pop_frame"); @@ -444,6 +445,8 @@ impl Prioritize trace!("pop_frame; stream={:?}", stream.id); debug_assert!(!stream.pending_send.is_empty()); + let is_counted = stream.state.is_counted(); + let frame = match stream.pending_send.pop_front(&mut self.buffer).unwrap() { Frame::Data(mut frame) => { // Get the amount of capacity remaining for stream's @@ -541,6 +544,8 @@ impl Prioritize self.pending_send.push(&mut stream); } + counts.transition_after(stream, is_counted); + return Some(frame); } None => return None, diff --git a/src/proto/streams/recv.rs b/src/proto/streams/recv.rs index d814d22..551845a 100644 --- a/src/proto/streams/recv.rs +++ b/src/proto/streams/recv.rs @@ -13,12 +13,6 @@ use std::marker::PhantomData; pub(super) struct Recv where P: Peer, { - /// Maximum number of remote initiated streams - max_streams: Option, - - /// Current number of remote initiated streams - num_streams: usize, - /// Initial window size of remote initiated streams init_window_sz: WindowSize, @@ -77,8 +71,6 @@ impl Recv flow.assign_capacity(config.init_remote_window_sz); Recv { - max_streams: config.max_remote_initiated, - num_streams: 0, init_window_sz: config.init_remote_window_sz, flow: flow, next_stream_id: next_stream_id.into(), @@ -104,14 +96,21 @@ impl Recv /// Update state reflecting a new, remotely opened stream /// /// Returns the stream state if successful. `None` if refused - pub fn open(&mut self, id: StreamId) + pub fn open(&mut self, id: StreamId, counts: &mut Counts
<P>
) -> Result, RecvError> { assert!(self.refused.is_none()); try!(self.ensure_can_open(id)); - if !self.can_inc_num_streams() { + if id < self.next_stream_id { + return Err(RecvError::Connection(ProtocolError)); + } + + self.next_stream_id = id; + self.next_stream_id.increment(); + + if !counts.can_inc_num_recv_streams() { self.refused = Some(id); return Ok(None); } @@ -124,31 +123,21 @@ impl Recv /// The caller ensures that the frame represents headers and not trailers. pub fn recv_headers(&mut self, frame: frame::Headers, - stream: &mut store::Ptr) + stream: &mut store::Ptr, + counts: &mut Counts
<P>
) -> Result<(), RecvError> { trace!("opening stream; init_window={}", self.init_window_sz); let is_initial = stream.state.recv_open(frame.is_end_stream())?; if is_initial { - if !self.can_inc_num_streams() { - unimplemented!(); - } - - if frame.stream_id() >= self.next_stream_id { - self.next_stream_id = frame.stream_id(); - self.next_stream_id.increment(); - } else { - return Err(RecvError::Connection(ProtocolError)); - } - // TODO: be smarter about this logic if frame.stream_id() > self.last_processed_id { self.last_processed_id = frame.stream_id(); } // Increment the number of concurrent streams - self.inc_num_streams(); + counts.inc_num_recv_streams(); } if !stream.content_length.is_head() { @@ -397,30 +386,6 @@ impl Recv stream.notify_recv(); } - /// Returns true if the current stream concurrency can be incremetned - fn can_inc_num_streams(&self) -> bool { - if let Some(max) = self.max_streams { - max > self.num_streams - } else { - true - } - } - - /// Increments the number of concurrenty streams. Panics on failure as this - /// should have been validated before hand. - fn inc_num_streams(&mut self) { - if !self.can_inc_num_streams() { - panic!(); - } - - // Increment the number of remote initiated streams - self.num_streams += 1; - } - - pub fn dec_num_streams(&mut self) { - self.num_streams -= 1; - } - /// Returns true if the remote peer can initiate a stream with the given ID. fn ensure_can_open(&self, id: StreamId) -> Result<(), RecvError> diff --git a/src/proto/streams/send.rs b/src/proto/streams/send.rs index 431b946..f61f61a 100644 --- a/src/proto/streams/send.rs +++ b/src/proto/streams/send.rs @@ -1,4 +1,3 @@ -use client; use frame::{self, Reason}; use codec::{RecvError, UserError}; use codec::UserError::*; @@ -14,21 +13,12 @@ use std::io; pub(super) struct Send where P: Peer, { - /// Maximum number of locally initiated streams - max_streams: Option, - - /// Current number of locally initiated streams - num_streams: usize, - /// Stream identifier to use for next initialized stream. next_stream_id: StreamId, /// Initial window size of locally initiated streams init_window_sz: WindowSize, - /// Task awaiting notification to open a new stream. - blocked_open: Option, - /// Prioritization layer prioritize: Prioritize, } @@ -42,11 +32,8 @@ where B: Buf, let next_stream_id = if P::is_server() { 2 } else { 1 }; Send { - max_streams: config.max_local_initiated, - num_streams: 0, next_stream_id: next_stream_id.into(), init_window_sz: config.init_local_window_sz, - blocked_open: None, prioritize: Prioritize::new(config), } } @@ -59,22 +46,20 @@ where B: Buf, /// Update state reflecting a new, locally opened stream /// /// Returns the stream state if successful. `None` if refused - pub fn open(&mut self) + pub fn open(&mut self, counts: &mut Counts
<P>
) -> Result { try!(self.ensure_can_open()); - if let Some(max) = self.max_streams { - if max <= self.num_streams { - return Err(Rejected.into()); - } + if !counts.can_inc_num_send_streams() { + return Err(Rejected.into()); } let ret = self.next_stream_id; + self.next_stream_id.increment(); // Increment the number of locally initiated streams - self.num_streams += 1; - self.next_stream_id.increment(); + counts.inc_num_send_streams(); Ok(ret) } @@ -167,11 +152,12 @@ where B: Buf, pub fn poll_complete(&mut self, store: &mut Store, + counts: &mut Counts
<P>
, dst: &mut Codec>) -> Poll<(), io::Error> where T: AsyncWrite, { - self.prioritize.poll_complete(store, dst) + self.prioritize.poll_complete(store, counts, dst) } /// Request capacity to send data @@ -237,10 +223,6 @@ where B: Buf, task: &mut Option) -> Result<(), RecvError> { - if let Some(val) = settings.max_concurrent_streams() { - self.max_streams = Some(val as usize); - } - // Applies an update to the remote endpoint's initial window size. // // Per RFC 7540 ยง6.9.2: @@ -304,16 +286,6 @@ where B: Buf, Ok(()) } - pub fn dec_num_streams(&mut self) { - self.num_streams -= 1; - - if self.num_streams < self.max_streams.unwrap_or(::std::usize::MAX) { - if let Some(task) = self.blocked_open.take() { - task.notify(); - } - } - } - /// Returns true if the local actor can initiate a stream with the given ID. fn ensure_can_open(&self) -> Result<(), UserError> { if P::is_server() { @@ -326,18 +298,3 @@ where B: Buf, Ok(()) } } - -impl Send -where B: Buf, -{ - pub fn poll_open_ready(&mut self) -> Async<()> { - if let Some(max) = self.max_streams { - if max <= self.num_streams { - self.blocked_open = Some(task::current()); - return Async::NotReady; - } - } - - return Async::Ready(()); - } -} diff --git a/src/proto/streams/store.rs b/src/proto/streams/store.rs index 1eb3e56..1446fdc 100644 --- a/src/proto/streams/store.rs +++ b/src/proto/streams/store.rs @@ -2,8 +2,9 @@ use super::*; use slab; +use ordermap::{self, OrderMap}; + use std::ops; -use std::collections::{HashMap, hash_map}; use std::marker::PhantomData; /// Storage for streams @@ -12,7 +13,7 @@ pub(super) struct Store where P: Peer, { slab: slab::Slab>, - ids: HashMap, + ids: OrderMap, } /// "Pointer" to an entry in the store @@ -20,7 +21,7 @@ pub(super) struct Ptr<'a, B: 'a, P> where P: Peer + 'a, { key: Key, - slab: &'a mut slab::Slab>, + store: &'a mut Store, } /// References an entry in the store. 
@@ -60,13 +61,13 @@ pub(super) enum Entry<'a, B: 'a, P: Peer + 'a> { } pub(super) struct OccupiedEntry<'a> { - ids: hash_map::OccupiedEntry<'a, StreamId, usize>, + ids: ordermap::OccupiedEntry<'a, StreamId, usize>, } pub(super) struct VacantEntry<'a, B: 'a, P> where P: Peer + 'a, { - ids: hash_map::VacantEntry<'a, StreamId, usize>, + ids: ordermap::VacantEntry<'a, StreamId, usize>, slab: &'a mut slab::Slab>, } @@ -84,19 +85,24 @@ impl Store pub fn new() -> Self { Store { slab: slab::Slab::new(), - ids: HashMap::new(), + ids: OrderMap::new(), } } + pub fn contains_id(&self, id: &StreamId) -> bool { + self.ids.contains_key(id) + } + pub fn find_mut(&mut self, id: &StreamId) -> Option> { - if let Some(&key) = self.ids.get(id) { - Some(Ptr { - key: Key(key), - slab: &mut self.slab, - }) - } else { - None - } + let key = match self.ids.get(id) { + Some(key) => *key, + None => return None, + }; + + Some(Ptr { + key: Key(key), + store: self, + }) } pub fn insert(&mut self, id: StreamId, val: Stream) -> Ptr { @@ -105,12 +111,12 @@ impl Store Ptr { key: Key(key), - slab: &mut self.slab, + store: self, } } pub fn find_entry(&mut self, id: StreamId) -> Entry { - use self::hash_map::Entry::*; + use self::ordermap::Entry::*; match self.ids.entry(id) { Occupied(e) => { @@ -130,11 +136,27 @@ impl Store pub fn for_each(&mut self, mut f: F) -> Result<(), E> where F: FnMut(Ptr) -> Result<(), E>, { - for &key in self.ids.values() { + let mut len = self.ids.len(); + let mut i = 0; + + while i < len { + // Get the key by index, this makes the borrow checker happy + let key = *self.ids.get_index(i).unwrap().1; + f(Ptr { key: Key(key), - slab: &mut self.slab, + store: self, })?; + + // TODO: This logic probably could be better... + let new_len = self.ids.len(); + + if new_len < len { + debug_assert!(new_len == len - 1); + len -= 1; + } else { + i += 1; + } } Ok(()) @@ -147,7 +169,7 @@ impl Resolve for Store fn resolve(&mut self, key: Key) -> Ptr { Ptr { key: key, - slab: &mut self.slab, + store: self, } } } @@ -170,6 +192,19 @@ impl ops::IndexMut for Store } } +#[cfg(feature = "unstable")] +impl Store + where P: Peer, +{ + pub fn num_active_streams(&self) -> usize { + self.ids.len() + } + + pub fn num_wired_streams(&self) -> usize { + self.slab.len() + } +} + // ===== impl Queue ===== impl Queue @@ -263,9 +298,28 @@ impl Queue impl<'a, B: 'a, P> Ptr<'a, B, P> where P: Peer, { + /// Returns the Key associated with the stream pub fn key(&self) -> Key { self.key } + + /// Remove the stream from the store + pub fn remove(self) -> StreamId { + // The stream must have been unlinked before this point + debug_assert!(!self.store.ids.contains_key(&self.id)); + + // Remove the stream state + self.store.slab.remove(self.key.0).id + } + + /// Remove the StreamId -> stream state association. + /// + /// This will effectively remove the stream as far as the H2 protocol is + /// concerned. 
+ pub fn unlink(&mut self) { + let id = self.id; + self.store.ids.remove(&id); + } } impl<'a, B: 'a, P> Resolve for Ptr<'a, B, P> @@ -274,7 +328,7 @@ impl<'a, B: 'a, P> Resolve for Ptr<'a, B, P> fn resolve(&mut self, key: Key) -> Ptr { Ptr { key: key, - slab: &mut *self.slab, + store: &mut *self.store, } } } @@ -285,7 +339,7 @@ impl<'a, B: 'a, P> ops::Deref for Ptr<'a, B, P> type Target = Stream; fn deref(&self) -> &Stream { - &self.slab[self.key.0] + &self.store.slab[self.key.0] } } @@ -293,7 +347,7 @@ impl<'a, B: 'a, P> ops::DerefMut for Ptr<'a, B, P> where P: Peer, { fn deref_mut(&mut self) -> &mut Stream { - &mut self.slab[self.key.0] + &mut self.store.slab[self.key.0] } } diff --git a/src/proto/streams/stream.rs b/src/proto/streams/stream.rs index 57e0d8c..b341a62 100644 --- a/src/proto/streams/stream.rs +++ b/src/proto/streams/stream.rs @@ -1,5 +1,18 @@ use super::*; +use std::usize; + +/// Tracks Stream related state +/// +/// # Reference counting +/// +/// There can be a number of outstanding handles to a single Stream. These are +/// tracked using reference counting. The `ref_count` field represents the +/// number of outstanding userspace handles that can reach this stream. +/// +/// It's important to note that when the stream is placed in an internal queue +/// (such as an accept queue), this is **not** tracked by a reference count. +/// Thus, `ref_count` can be zero and the stream still has to be kept around. #[derive(Debug)] pub(super) struct Stream where P: Peer, @@ -10,6 +23,9 @@ pub(super) struct Stream /// Current state of the stream pub state: State, + /// Number of outstanding handles pointing to this stream + pub ref_count: usize, + // ===== Fields related to sending ===== /// Next node in the accept linked list @@ -117,6 +133,7 @@ impl Stream Stream { id, state: State::default(), + ref_count: 0, // ===== Fields related to sending ===== @@ -146,6 +163,46 @@ impl Stream } } + /// Increment the stream's ref count + pub fn ref_inc(&mut self) { + assert!(self.ref_count < usize::MAX); + self.ref_count += 1; + } + + /// Decrements the stream's ref count + pub fn ref_dec(&mut self) { + assert!(self.ref_count > 0); + self.ref_count -= 1; + } + + /// Returns true if the stream is closed + pub fn is_closed(&self) -> bool { + // The state has fully transitioned to closed. + self.state.is_closed() && + // Because outbound frames transition the stream state before being + // buffered, we have to ensure that all frames have been flushed. + self.pending_send.is_empty() && + // Sometimes large data frames are sent out in chunks. After a chunk + // of the frame is sent, the remainder is pushed back onto the send + // queue to be rescheduled. + // + // Checking for additional buffered data lets us catch this case. 
+ self.buffered_send_data == 0 + } + + /// Returns true if the stream is no longer in use + pub fn is_released(&self) -> bool { + // The stream is closed and fully flushed + self.is_closed() && + // There are no more outstanding references to the stream + self.ref_count == 0 && + // The stream is not in any queue + !self.is_pending_send && + !self.is_pending_send_capacity && + !self.is_pending_accept && + !self.is_pending_window_update + } + pub fn assign_capacity(&mut self, capacity: WindowSize) { debug_assert!(capacity > 0); self.send_capacity_inc = true; diff --git a/src/proto/streams/streams.rs b/src/proto/streams/streams.rs index 3aac220..4a00a9c 100644 --- a/src/proto/streams/streams.rs +++ b/src/proto/streams/streams.rs @@ -34,6 +34,8 @@ pub(crate) struct StreamRef struct Inner where P: Peer, { + /// Tracks send & recv stream concurrency. + counts: Counts
<P>
, actions: Actions, store: Store, } @@ -59,6 +61,7 @@ impl Streams pub fn new(config: Config) -> Self { Streams { inner: Arc::new(Mutex::new(Inner { + counts: Counts::new(&config), actions: Actions { recv: Recv::new(&config), send: Send::new(&config), @@ -80,7 +83,7 @@ impl Streams let key = match me.store.find_entry(id) { Entry::Occupied(e) => e.key(), Entry::Vacant(e) => { - match try!(me.actions.recv.open(id)) { + match try!(me.actions.recv.open(id, &mut me.counts)) { Some(stream_id) => { let stream = Stream::new( stream_id, @@ -95,10 +98,13 @@ impl Streams }; let stream = me.store.resolve(key); + let actions = &mut me.actions; + + me.counts.transition(stream, |counts, stream| { + trace!("recv_headers; stream={:?}; state={:?}", stream.id, stream.state); - me.actions.transition(stream, |actions, stream| { let res = if stream.state.is_recv_headers() { - actions.recv.recv_headers(frame, stream) + actions.recv.recv_headers(frame, stream, counts) } else { if !frame.is_end_stream() { // TODO: Is this the right error @@ -133,7 +139,9 @@ impl Streams None => return Err(RecvError::Connection(ProtocolError)), }; - me.actions.transition(stream, |actions, stream| { + let actions = &mut me.actions; + + me.counts.transition(stream, |_, stream| { match actions.recv.recv_data(frame, stream) { Err(RecvError::Stream { reason, .. }) => { // Reset the stream. @@ -168,7 +176,9 @@ impl Streams } }; - me.actions.transition(stream, |actions, stream| { + let actions = &mut me.actions; + + me.counts.transition(stream, |_, stream| { actions.recv.recv_reset(frame, stream)?; assert!(stream.state.is_closed()); Ok(()) @@ -181,12 +191,16 @@ impl Streams let me = &mut *me; let actions = &mut me.actions; + let counts = &mut me.counts; + let last_processed_id = actions.recv.last_processed_id(); - me.store.for_each(|mut stream| { - actions.recv.recv_err(err, &mut *stream); - Ok::<_, ()>(()) - }).ok().expect("unexpected error processing error"); + me.store.for_each(|stream| { + counts.transition(stream, |_, stream| { + actions.recv.recv_err(err, &mut *stream); + Ok::<_, ()>(()) + }) + }).unwrap(); last_processed_id } @@ -244,7 +258,16 @@ impl Streams let mut me = self.inner.lock().unwrap(); let me = &mut *me; - me.actions.recv.next_incoming(&mut me.store) + match me.actions.recv.next_incoming(&mut me.store) { + Some(key) => { + // Increment the ref count + me.store.resolve(key).ref_inc(); + + // Return the key + Some(key) + } + None => None, + } }; key.map(|key| { @@ -278,7 +301,7 @@ impl Streams try_ready!(me.actions.recv.poll_complete(&mut me.store, dst)); // Send any other pending frames - try_ready!(me.actions.send.poll_complete(&mut me.store, dst)); + try_ready!(me.actions.send.poll_complete(&mut me.store, &mut me.counts, dst)); // Nothing else to do, track the task me.actions.task = Some(task::current()); @@ -292,6 +315,8 @@ impl Streams let mut me = self.inner.lock().unwrap(); let me = &mut *me; + me.counts.apply_remote_settings(frame); + me.actions.send.apply_remote_settings( frame, &mut me.store, &mut me.actions.task) } @@ -312,7 +337,7 @@ impl Streams let me = &mut *me; // Initialize a new stream. This fails if the connection is at capacity. - let stream_id = me.actions.send.open()?; + let stream_id = me.actions.send.open(&mut me.counts)?; let mut stream = Stream::new( stream_id, @@ -336,6 +361,9 @@ impl Streams // closed state. debug_assert!(!stream.state.is_closed()); + // Increment the stream ref count as we will be returning a handle. 
+ stream.ref_inc(); + stream.key() }; @@ -352,7 +380,7 @@ impl Streams let key = match me.store.find_entry(id) { Entry::Occupied(e) => e.key(), Entry::Vacant(e) => { - match me.actions.recv.open(id) { + match me.actions.recv.open(id, &mut me.counts) { Ok(Some(stream_id)) => { let stream = Stream::new( stream_id, 0, 0); @@ -364,10 +392,10 @@ impl Streams } }; - let stream = me.store.resolve(key); + let actions = &mut me.actions; - me.actions.transition(stream, move |actions, stream| { + me.counts.transition(stream, |_, stream| { actions.send.send_reset(reason, stream, &mut actions.task) }) } @@ -380,7 +408,23 @@ impl Streams let mut me = self.inner.lock().unwrap(); let me = &mut *me; - me.actions.send.poll_open_ready() + me.counts.poll_open_ready() + } +} + +#[cfg(feature = "unstable")] +impl Streams + where B: Buf, + P: Peer, +{ + pub fn num_active_streams(&self) -> usize { + let me = self.inner.lock().unwrap(); + me.store.num_active_streams() + } + + pub fn num_wired_streams(&self) -> usize { + let me = self.inner.lock().unwrap(); + me.store.num_wired_streams() } } @@ -397,12 +441,13 @@ impl StreamRef let me = &mut *me; let stream = me.store.resolve(self.key); + let actions = &mut me.actions; - // Create the data frame - let mut frame = frame::Data::new(stream.id, data); - frame.set_end_stream(end_stream); + me.counts.transition(stream, |_, stream| { + // Create the data frame + let mut frame = frame::Data::new(stream.id, data); + frame.set_end_stream(end_stream); - me.actions.transition(stream, |actions, stream| { // Send the data frame actions.send.send_data(frame, stream, &mut actions.task) }) @@ -415,11 +460,12 @@ impl StreamRef let me = &mut *me; let stream = me.store.resolve(self.key); + let actions = &mut me.actions; - // Create the trailers frame - let frame = frame::Headers::trailers(stream.id, trailers); + me.counts.transition(stream, |_, stream| { + // Create the trailers frame + let frame = frame::Headers::trailers(stream.id, trailers); - me.actions.transition(stream, |actions, stream| { // Send the trailers frame actions.send.send_trailers(frame, stream, &mut actions.task) }) @@ -430,7 +476,9 @@ impl StreamRef let me = &mut *me; let stream = me.store.resolve(self.key); - me.actions.transition(stream, move |actions, stream| { + let actions = &mut me.actions; + + me.counts.transition(stream, |_, stream| { actions.send.send_reset(reason, stream, &mut actions.task) }) } @@ -442,11 +490,12 @@ impl StreamRef let me = &mut *me; let stream = me.store.resolve(self.key); + let actions = &mut me.actions; - let frame = server::Peer::convert_send_message( - stream.id, response, end_of_stream); + me.counts.transition(stream, |_, stream| { + let frame = server::Peer::convert_send_message( + stream.id, response, end_of_stream); - me.actions.transition(stream, |actions, stream| { actions.send.send_headers(frame, stream, &mut actions.task) }) } @@ -559,6 +608,11 @@ impl Clone for StreamRef where P: Peer, { fn clone(&self) -> Self { + // Increment the ref count + self.inner.lock().unwrap() + .store.resolve(self.key) + .ref_inc(); + StreamRef { inner: self.inner.clone(), key: self.key.clone(), @@ -566,6 +620,29 @@ impl Clone for StreamRef } } +impl Drop for StreamRef + where P: Peer, +{ + fn drop(&mut self) { + let mut me = self.inner.lock().unwrap(); + + let me = &mut *me; + + let id = { + let mut stream = me.store.resolve(self.key); + stream.ref_dec(); + + if !stream.is_released() { + return; + } + + stream.remove() + }; + + debug_assert!(!me.store.contains_id(&id)); + } +} + // ===== impl 
Actions ===== impl Actions @@ -575,37 +652,10 @@ impl Actions fn ensure_not_idle(&mut self, id: StreamId) -> Result<(), Reason> { - if self.is_local_init(id) { + if P::is_local_init(id) { self.send.ensure_not_idle(id) } else { self.recv.ensure_not_idle(id) } } - - fn dec_num_streams(&mut self, id: StreamId) { - if self.is_local_init(id) { - self.send.dec_num_streams(); - } else { - self.recv.dec_num_streams(); - } - } - - fn is_local_init(&self, id: StreamId) -> bool { - assert!(!id.is_zero()); - P::is_server() == id.is_server_initiated() - } - - fn transition(&mut self, mut stream: store::Ptr, f: F) -> U - where F: FnOnce(&mut Self, &mut store::Ptr) -> U, - { - let is_counted = stream.state.is_counted(); - - let ret = f(self, &mut stream); - - if is_counted && stream.state.is_closed() { - self.dec_num_streams(stream.id); - } - - ret - } } diff --git a/tests/stream_states.rs b/tests/stream_states.rs index 2a2eee9..6e7cb7e 100644 --- a/tests/stream_states.rs +++ b/tests/stream_states.rs @@ -149,6 +149,58 @@ fn send_headers_recv_data_single_frame() { h2.wait().unwrap(); } +#[test] +fn closed_streams_are_released() { + let _ = ::env_logger::init(); + let (io, srv) = mock::new(); + + let h2 = Client::handshake(io).unwrap() + .and_then(|mut h2| { + let request = Request::get("https://example.com/") + .body(()).unwrap(); + + // Send request + let stream = h2.request(request, true).unwrap(); + h2.drive(stream) + }) + .and_then(|(h2, response)| { + assert_eq!(response.status(), StatusCode::NO_CONTENT); + + // There are no active streams + assert_eq!(0, h2.num_active_streams()); + + // The response contains a handle for the body. This keeps the + // stream wired. + assert_eq!(1, h2.num_wired_streams()); + + drop(response); + + // The stream state is now free + assert_eq!(0, h2.num_wired_streams()); + + Ok(()) + }) + ; + + let srv = srv.assert_client_handshake().unwrap() + .recv_settings() + .recv_frame( + frames::headers(1) + .request("GET", "https://example.com/") + .eos() + ) + .send_frame( + frames::headers(1) + .response(204) + .eos() + ) + .close() + ; + + let _ = h2.join(srv) + .wait().unwrap(); +} + /* #[test] fn send_data_after_headers_eos() { diff --git a/tests/support/src/future_ext.rs b/tests/support/src/future_ext.rs index 2794e51..f4397c7 100644 --- a/tests/support/src/future_ext.rs +++ b/tests/support/src/future_ext.rs @@ -71,23 +71,35 @@ impl Future for Drive type Error = (); fn poll(&mut self) -> Poll { - match self.future.poll() { - Ok(Async::Ready(val)) => { - // Get the driver - let driver = self.driver.take().unwrap(); + let mut looped = false; - return Ok((driver, val).into()) + loop { + match self.future.poll() { + Ok(Async::Ready(val)) => { + // Get the driver + let driver = self.driver.take().unwrap(); + + return Ok((driver, val).into()) + } + Ok(_) => {} + Err(e) => panic!("unexpected error; {:?}", e), } - Ok(_) => {} - Err(e) => panic!("unexpected error; {:?}", e), - } - match self.driver.as_mut().unwrap().poll() { - Ok(Async::Ready(_)) => panic!("driver resolved before future"), - Ok(Async::NotReady) => {} - Err(e) => panic!("unexpected error; {:?}", e), - } + match self.driver.as_mut().unwrap().poll() { + Ok(Async::Ready(_)) => { + if looped { + // Try polling the future one last time + panic!("driver resolved before future") + } else { + looped = true; + continue; + } + } + Ok(Async::NotReady) => {} + Err(e) => panic!("unexpected error; {:?}", e), + } - Ok(Async::NotReady) + return Ok(Async::NotReady); + } } }
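
For reference, a minimal, self-contained sketch of the release strategy described in the commit message: a stream's slot is freed from the store only once the stream state has closed and the last userspace handle has dropped its reference. The `Store` and `StreamSlot` types below are simplified, hypothetical stand-ins for the generic `Counts`, `Store`, and `Stream` machinery in the patch.

use std::collections::HashMap;

/// Simplified stand-in for h2's per-stream state.
struct StreamSlot {
    /// The stream has fully transitioned to a closed state.
    closed: bool,
    /// Number of outstanding userspace handles (the patch's `ref_count`).
    ref_count: usize,
}

/// Toy store mapping stream id -> slot, mirroring the slab + id map in
/// `store.rs`.
struct Store {
    slots: HashMap<u32, StreamSlot>,
}

impl Store {
    fn new() -> Store {
        Store { slots: HashMap::new() }
    }

    /// Open a stream with one live handle (the handle returned to the user).
    fn open(&mut self, id: u32) {
        self.slots.insert(id, StreamSlot { closed: false, ref_count: 1 });
    }

    /// Called when a userspace handle is dropped (cf. `Drop for StreamRef`).
    fn ref_dec(&mut self, id: u32) {
        if let Some(slot) = self.slots.get_mut(&id) {
            slot.ref_count -= 1;
        }
        self.release_if_done(id);
    }

    /// Called when the protocol state machine reaches the closed state.
    fn close(&mut self, id: u32) {
        if let Some(slot) = self.slots.get_mut(&id) {
            slot.closed = true;
        }
        self.release_if_done(id);
    }

    /// Free the slot only when both conditions hold, so neither a live
    /// handle nor in-flight protocol processing can observe a freed slot.
    fn release_if_done(&mut self, id: u32) {
        let done = self.slots
            .get(&id)
            .map(|slot| slot.closed && slot.ref_count == 0)
            .unwrap_or(false);

        if done {
            self.slots.remove(&id);
        }
    }
}

fn main() {
    let mut store = Store::new();

    store.open(1);
    store.close(1);                   // closed, but a handle is still alive
    assert_eq!(store.slots.len(), 1); // state retained ("wired")

    store.ref_dec(1);                 // last handle dropped
    assert!(store.slots.is_empty());  // slot released, no leak
}

In the patch itself, `Counts::transition_after` does roughly what `release_if_done` does here after every state transition, and the new `Drop for StreamRef` impl performs the handle decrement and final removal.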