Ref count stream state and release when final (#73)

Previously, stream state was never released so that long-lived connections
leaked memory.

Now, stream states are reference-counted and freed from the stream slab
when complete.  Locally reset streams are retained so that received frames
may be ignored.
This commit is contained in:
Carl Lerche
2017-09-10 16:01:19 -07:00
committed by Oliver Gould
parent daa54b9512
commit 5c0efcf8c4
15 changed files with 542 additions and 190 deletions

View File

@@ -24,6 +24,7 @@ log = "0.3.8"
fnv = "1.0.5"
slab = "0.4.0"
string = { git = "https://github.com/carllerche/string" }
ordermap = "0.2"
[dev-dependencies]

View File

@@ -142,6 +142,29 @@ impl<T, B> fmt::Debug for Client<T, B>
}
}
#[cfg(feature = "unstable")]
impl<T, B> Client<T, B>
where T: AsyncRead + AsyncWrite,
B: IntoBuf
{
/// Returns the number of active streams.
///
/// An active stream is a stream that has not yet transitioned to a closed
/// state.
pub fn num_active_streams(&self) -> usize {
self.connection.num_active_streams()
}
/// Returns the number of streams that are held in memory.
///
/// A wired stream is a stream that is either active or is closed but must
/// stay in memory for some reason. For example, there are still outstanding
/// userspace handles pointing to the slot.
pub fn num_wired_streams(&self) -> usize {
self.connection.num_wired_streams()
}
}
// ===== impl Handshake =====
impl<T, B: IntoBuf> Future for Handshake<T, B>

View File

@@ -24,6 +24,8 @@ extern crate log;
extern crate string;
extern crate ordermap;
mod error;
mod codec;
mod hpack;

View File

@@ -253,3 +253,18 @@ impl<T, B> Connection<T, server::Peer, B>
self.streams.next_incoming()
}
}
#[cfg(feature = "unstable")]
impl<T, P, B> Connection<T, P, B>
where T: AsyncRead + AsyncWrite,
P: Peer,
B: IntoBuf,
{
pub fn num_active_streams(&self) -> usize {
self.streams.num_active_streams()
}
pub fn num_wired_streams(&self) -> usize {
self.streams.num_wired_streams()
}
}

View File

@@ -19,4 +19,9 @@ pub trait Peer {
end_of_stream: bool) -> Headers;
fn convert_poll_message(headers: Headers) -> Result<Self::Poll, RecvError>;
fn is_local_init(id: StreamId) -> bool {
assert!(!id.is_zero());
Self::is_server() == id.is_server_initiated()
}
}

152
src/proto/streams/counts.rs Normal file
View File

@@ -0,0 +1,152 @@
use client;
use super::*;
use std::usize;
use std::marker::PhantomData;
#[derive(Debug)]
pub(super) struct Counts<P>
where P: Peer,
{
/// Maximum number of locally initiated streams
max_send_streams: Option<usize>,
/// Current number of remote initiated streams
num_send_streams: usize,
/// Maximum number of remote initiated streams
max_recv_streams: Option<usize>,
/// Current number of locally initiated streams
num_recv_streams: usize,
/// Task awaiting notification to open a new stream.
blocked_open: Option<task::Task>,
_p: PhantomData<P>,
}
impl<P> Counts<P>
where P: Peer,
{
/// Create a new `Counts` using the provided configuration values.
pub fn new(config: &Config) -> Self {
Counts {
max_send_streams: config.max_local_initiated,
num_send_streams: 0,
max_recv_streams: config.max_remote_initiated,
num_recv_streams: 0,
blocked_open: None,
_p: PhantomData,
}
}
/// Returns true if the receive stream concurrency can be incremented
pub fn can_inc_num_recv_streams(&self) -> bool {
if let Some(max) = self.max_recv_streams {
max > self.num_recv_streams
} else {
true
}
}
/// Increments the number of concurrent receive streams.
///
/// # Panics
///
/// Panics on failure as this should have been validated before hand.
pub fn inc_num_recv_streams(&mut self) {
assert!(self.can_inc_num_recv_streams());
// Increment the number of remote initiated streams
self.num_recv_streams += 1;
}
/// Returns true if the send stream concurrency can be incremented
pub fn can_inc_num_send_streams(&self) -> bool {
if let Some(max) = self.max_send_streams {
max > self.num_send_streams
} else {
true
}
}
/// Increments the number of concurrent send streams.
///
/// # Panics
///
/// Panics on failure as this should have been validated before hand.
pub fn inc_num_send_streams(&mut self) {
assert!(self.can_inc_num_send_streams());
// Increment the number of remote initiated streams
self.num_send_streams += 1;
}
pub fn apply_remote_settings(&mut self, settings: &frame::Settings) {
if let Some(val) = settings.max_concurrent_streams() {
self.max_send_streams = Some(val as usize);
}
}
/// Run a block of code that could potentially transition a stream's state.
///
/// If the stream state transitions to closed, this function will perform
/// all necessary cleanup.
pub fn transition<F, B, U>(&mut self, mut stream: store::Ptr<B, P>, f: F) -> U
where F: FnOnce(&mut Self, &mut store::Ptr<B, P>) -> U
{
let is_counted = stream.state.is_counted();
// Run the action
let ret = f(self, &mut stream);
self.transition_after(stream, is_counted);
ret
}
// TODO: move this to macro?
pub fn transition_after<B>(&mut self, mut stream: store::Ptr<B, P>, is_counted: bool) {
if stream.is_closed() {
stream.unlink();
if is_counted {
// Decrement the number of active streams.
self.dec_num_streams(stream.id);
}
}
// Release the stream if it requires releasing
if stream.is_released() {
stream.remove();
}
}
fn dec_num_streams(&mut self, id: StreamId) {
use std::usize;
if P::is_local_init(id) {
self.num_send_streams -= 1;
if self.num_send_streams < self.max_send_streams.unwrap_or(usize::MAX) {
if let Some(task) = self.blocked_open.take() {
task.notify();
}
}
} else {
self.num_recv_streams -= 1;
}
}
}
impl Counts<client::Peer> {
pub fn poll_open_ready(&mut self) -> Async<()> {
if !self.can_inc_num_send_streams() {
self.blocked_open = Some(task::current());
return Async::NotReady;
}
return Async::Ready(());
}
}

View File

@@ -1,4 +1,5 @@
mod buffer;
mod counts;
mod flow_control;
mod prioritize;
mod recv;
@@ -12,6 +13,7 @@ pub(crate) use self::streams::{Streams, StreamRef};
pub(crate) use self::prioritize::Prioritized;
use self::buffer::Buffer;
use self::counts::Counts;
use self::flow_control::FlowControl;
use self::prioritize::Prioritize;
use self::recv::Recv;

View File

@@ -325,6 +325,7 @@ impl<B, P> Prioritize<B, P>
pub fn poll_complete<T>(&mut self,
store: &mut Store<B, P>,
counts: &mut Counts<P>,
dst: &mut Codec<T, Prioritized<B>>)
-> Poll<(), io::Error>
where T: AsyncWrite,
@@ -341,7 +342,7 @@ impl<B, P> Prioritize<B, P>
trace!("poll_complete");
loop {
match self.pop_frame(store, max_frame_len) {
match self.pop_frame(store, max_frame_len, counts) {
Some(frame) => {
trace!("writing frame={:?}", frame);
@@ -433,7 +434,7 @@ impl<B, P> Prioritize<B, P>
}
}
fn pop_frame(&mut self, store: &mut Store<B, P>, max_len: usize)
fn pop_frame(&mut self, store: &mut Store<B, P>, max_len: usize, counts: &mut Counts<P>)
-> Option<Frame<Prioritized<B>>>
{
trace!("pop_frame");
@@ -444,6 +445,8 @@ impl<B, P> Prioritize<B, P>
trace!("pop_frame; stream={:?}", stream.id);
debug_assert!(!stream.pending_send.is_empty());
let is_counted = stream.state.is_counted();
let frame = match stream.pending_send.pop_front(&mut self.buffer).unwrap() {
Frame::Data(mut frame) => {
// Get the amount of capacity remaining for stream's
@@ -541,6 +544,8 @@ impl<B, P> Prioritize<B, P>
self.pending_send.push(&mut stream);
}
counts.transition_after(stream, is_counted);
return Some(frame);
}
None => return None,

View File

@@ -13,12 +13,6 @@ use std::marker::PhantomData;
pub(super) struct Recv<B, P>
where P: Peer,
{
/// Maximum number of remote initiated streams
max_streams: Option<usize>,
/// Current number of remote initiated streams
num_streams: usize,
/// Initial window size of remote initiated streams
init_window_sz: WindowSize,
@@ -77,8 +71,6 @@ impl<B, P> Recv<B, P>
flow.assign_capacity(config.init_remote_window_sz);
Recv {
max_streams: config.max_remote_initiated,
num_streams: 0,
init_window_sz: config.init_remote_window_sz,
flow: flow,
next_stream_id: next_stream_id.into(),
@@ -104,14 +96,21 @@ impl<B, P> Recv<B, P>
/// Update state reflecting a new, remotely opened stream
///
/// Returns the stream state if successful. `None` if refused
pub fn open(&mut self, id: StreamId)
pub fn open(&mut self, id: StreamId, counts: &mut Counts<P>)
-> Result<Option<StreamId>, RecvError>
{
assert!(self.refused.is_none());
try!(self.ensure_can_open(id));
if !self.can_inc_num_streams() {
if id < self.next_stream_id {
return Err(RecvError::Connection(ProtocolError));
}
self.next_stream_id = id;
self.next_stream_id.increment();
if !counts.can_inc_num_recv_streams() {
self.refused = Some(id);
return Ok(None);
}
@@ -124,31 +123,21 @@ impl<B, P> Recv<B, P>
/// The caller ensures that the frame represents headers and not trailers.
pub fn recv_headers(&mut self,
frame: frame::Headers,
stream: &mut store::Ptr<B, P>)
stream: &mut store::Ptr<B, P>,
counts: &mut Counts<P>)
-> Result<(), RecvError>
{
trace!("opening stream; init_window={}", self.init_window_sz);
let is_initial = stream.state.recv_open(frame.is_end_stream())?;
if is_initial {
if !self.can_inc_num_streams() {
unimplemented!();
}
if frame.stream_id() >= self.next_stream_id {
self.next_stream_id = frame.stream_id();
self.next_stream_id.increment();
} else {
return Err(RecvError::Connection(ProtocolError));
}
// TODO: be smarter about this logic
if frame.stream_id() > self.last_processed_id {
self.last_processed_id = frame.stream_id();
}
// Increment the number of concurrent streams
self.inc_num_streams();
counts.inc_num_recv_streams();
}
if !stream.content_length.is_head() {
@@ -397,30 +386,6 @@ impl<B, P> Recv<B, P>
stream.notify_recv();
}
/// Returns true if the current stream concurrency can be incremetned
fn can_inc_num_streams(&self) -> bool {
if let Some(max) = self.max_streams {
max > self.num_streams
} else {
true
}
}
/// Increments the number of concurrenty streams. Panics on failure as this
/// should have been validated before hand.
fn inc_num_streams(&mut self) {
if !self.can_inc_num_streams() {
panic!();
}
// Increment the number of remote initiated streams
self.num_streams += 1;
}
pub fn dec_num_streams(&mut self) {
self.num_streams -= 1;
}
/// Returns true if the remote peer can initiate a stream with the given ID.
fn ensure_can_open(&self, id: StreamId)
-> Result<(), RecvError>

View File

@@ -1,4 +1,3 @@
use client;
use frame::{self, Reason};
use codec::{RecvError, UserError};
use codec::UserError::*;
@@ -14,21 +13,12 @@ use std::io;
pub(super) struct Send<B, P>
where P: Peer,
{
/// Maximum number of locally initiated streams
max_streams: Option<usize>,
/// Current number of locally initiated streams
num_streams: usize,
/// Stream identifier to use for next initialized stream.
next_stream_id: StreamId,
/// Initial window size of locally initiated streams
init_window_sz: WindowSize,
/// Task awaiting notification to open a new stream.
blocked_open: Option<task::Task>,
/// Prioritization layer
prioritize: Prioritize<B, P>,
}
@@ -42,11 +32,8 @@ where B: Buf,
let next_stream_id = if P::is_server() { 2 } else { 1 };
Send {
max_streams: config.max_local_initiated,
num_streams: 0,
next_stream_id: next_stream_id.into(),
init_window_sz: config.init_local_window_sz,
blocked_open: None,
prioritize: Prioritize::new(config),
}
}
@@ -59,22 +46,20 @@ where B: Buf,
/// Update state reflecting a new, locally opened stream
///
/// Returns the stream state if successful. `None` if refused
pub fn open(&mut self)
pub fn open(&mut self, counts: &mut Counts<P>)
-> Result<StreamId, UserError>
{
try!(self.ensure_can_open());
if let Some(max) = self.max_streams {
if max <= self.num_streams {
return Err(Rejected.into());
}
if !counts.can_inc_num_send_streams() {
return Err(Rejected.into());
}
let ret = self.next_stream_id;
self.next_stream_id.increment();
// Increment the number of locally initiated streams
self.num_streams += 1;
self.next_stream_id.increment();
counts.inc_num_send_streams();
Ok(ret)
}
@@ -167,11 +152,12 @@ where B: Buf,
pub fn poll_complete<T>(&mut self,
store: &mut Store<B, P>,
counts: &mut Counts<P>,
dst: &mut Codec<T, Prioritized<B>>)
-> Poll<(), io::Error>
where T: AsyncWrite,
{
self.prioritize.poll_complete(store, dst)
self.prioritize.poll_complete(store, counts, dst)
}
/// Request capacity to send data
@@ -237,10 +223,6 @@ where B: Buf,
task: &mut Option<Task>)
-> Result<(), RecvError>
{
if let Some(val) = settings.max_concurrent_streams() {
self.max_streams = Some(val as usize);
}
// Applies an update to the remote endpoint's initial window size.
//
// Per RFC 7540 §6.9.2:
@@ -304,16 +286,6 @@ where B: Buf,
Ok(())
}
pub fn dec_num_streams(&mut self) {
self.num_streams -= 1;
if self.num_streams < self.max_streams.unwrap_or(::std::usize::MAX) {
if let Some(task) = self.blocked_open.take() {
task.notify();
}
}
}
/// Returns true if the local actor can initiate a stream with the given ID.
fn ensure_can_open(&self) -> Result<(), UserError> {
if P::is_server() {
@@ -326,18 +298,3 @@ where B: Buf,
Ok(())
}
}
impl<B> Send<B, client::Peer>
where B: Buf,
{
pub fn poll_open_ready(&mut self) -> Async<()> {
if let Some(max) = self.max_streams {
if max <= self.num_streams {
self.blocked_open = Some(task::current());
return Async::NotReady;
}
}
return Async::Ready(());
}
}

View File

@@ -2,8 +2,9 @@ use super::*;
use slab;
use ordermap::{self, OrderMap};
use std::ops;
use std::collections::{HashMap, hash_map};
use std::marker::PhantomData;
/// Storage for streams
@@ -12,7 +13,7 @@ pub(super) struct Store<B, P>
where P: Peer,
{
slab: slab::Slab<Stream<B, P>>,
ids: HashMap<StreamId, usize>,
ids: OrderMap<StreamId, usize>,
}
/// "Pointer" to an entry in the store
@@ -20,7 +21,7 @@ pub(super) struct Ptr<'a, B: 'a, P>
where P: Peer + 'a,
{
key: Key,
slab: &'a mut slab::Slab<Stream<B, P>>,
store: &'a mut Store<B, P>,
}
/// References an entry in the store.
@@ -60,13 +61,13 @@ pub(super) enum Entry<'a, B: 'a, P: Peer + 'a> {
}
pub(super) struct OccupiedEntry<'a> {
ids: hash_map::OccupiedEntry<'a, StreamId, usize>,
ids: ordermap::OccupiedEntry<'a, StreamId, usize>,
}
pub(super) struct VacantEntry<'a, B: 'a, P>
where P: Peer + 'a,
{
ids: hash_map::VacantEntry<'a, StreamId, usize>,
ids: ordermap::VacantEntry<'a, StreamId, usize>,
slab: &'a mut slab::Slab<Stream<B, P>>,
}
@@ -84,19 +85,24 @@ impl<B, P> Store<B, P>
pub fn new() -> Self {
Store {
slab: slab::Slab::new(),
ids: HashMap::new(),
ids: OrderMap::new(),
}
}
pub fn contains_id(&self, id: &StreamId) -> bool {
self.ids.contains_key(id)
}
pub fn find_mut(&mut self, id: &StreamId) -> Option<Ptr<B, P>> {
if let Some(&key) = self.ids.get(id) {
Some(Ptr {
key: Key(key),
slab: &mut self.slab,
})
} else {
None
}
let key = match self.ids.get(id) {
Some(key) => *key,
None => return None,
};
Some(Ptr {
key: Key(key),
store: self,
})
}
pub fn insert(&mut self, id: StreamId, val: Stream<B, P>) -> Ptr<B, P> {
@@ -105,12 +111,12 @@ impl<B, P> Store<B, P>
Ptr {
key: Key(key),
slab: &mut self.slab,
store: self,
}
}
pub fn find_entry(&mut self, id: StreamId) -> Entry<B, P> {
use self::hash_map::Entry::*;
use self::ordermap::Entry::*;
match self.ids.entry(id) {
Occupied(e) => {
@@ -130,11 +136,27 @@ impl<B, P> Store<B, P>
pub fn for_each<F, E>(&mut self, mut f: F) -> Result<(), E>
where F: FnMut(Ptr<B, P>) -> Result<(), E>,
{
for &key in self.ids.values() {
let mut len = self.ids.len();
let mut i = 0;
while i < len {
// Get the key by index, this makes the borrow checker happy
let key = *self.ids.get_index(i).unwrap().1;
f(Ptr {
key: Key(key),
slab: &mut self.slab,
store: self,
})?;
// TODO: This logic probably could be better...
let new_len = self.ids.len();
if new_len < len {
debug_assert!(new_len == len - 1);
len -= 1;
} else {
i += 1;
}
}
Ok(())
@@ -147,7 +169,7 @@ impl<B, P> Resolve<B, P> for Store<B, P>
fn resolve(&mut self, key: Key) -> Ptr<B, P> {
Ptr {
key: key,
slab: &mut self.slab,
store: self,
}
}
}
@@ -170,6 +192,19 @@ impl<B, P> ops::IndexMut<Key> for Store<B, P>
}
}
#[cfg(feature = "unstable")]
impl<B, P> Store<B, P>
where P: Peer,
{
pub fn num_active_streams(&self) -> usize {
self.ids.len()
}
pub fn num_wired_streams(&self) -> usize {
self.slab.len()
}
}
// ===== impl Queue =====
impl<B, N, P> Queue<B, N, P>
@@ -263,9 +298,28 @@ impl<B, N, P> Queue<B, N, P>
impl<'a, B: 'a, P> Ptr<'a, B, P>
where P: Peer,
{
/// Returns the Key associated with the stream
pub fn key(&self) -> Key {
self.key
}
/// Remove the stream from the store
pub fn remove(self) -> StreamId {
// The stream must have been unlinked before this point
debug_assert!(!self.store.ids.contains_key(&self.id));
// Remove the stream state
self.store.slab.remove(self.key.0).id
}
/// Remove the StreamId -> stream state association.
///
/// This will effectively remove the stream as far as the H2 protocol is
/// concerned.
pub fn unlink(&mut self) {
let id = self.id;
self.store.ids.remove(&id);
}
}
impl<'a, B: 'a, P> Resolve<B, P> for Ptr<'a, B, P>
@@ -274,7 +328,7 @@ impl<'a, B: 'a, P> Resolve<B, P> for Ptr<'a, B, P>
fn resolve(&mut self, key: Key) -> Ptr<B, P> {
Ptr {
key: key,
slab: &mut *self.slab,
store: &mut *self.store,
}
}
}
@@ -285,7 +339,7 @@ impl<'a, B: 'a, P> ops::Deref for Ptr<'a, B, P>
type Target = Stream<B, P>;
fn deref(&self) -> &Stream<B, P> {
&self.slab[self.key.0]
&self.store.slab[self.key.0]
}
}
@@ -293,7 +347,7 @@ impl<'a, B: 'a, P> ops::DerefMut for Ptr<'a, B, P>
where P: Peer,
{
fn deref_mut(&mut self) -> &mut Stream<B, P> {
&mut self.slab[self.key.0]
&mut self.store.slab[self.key.0]
}
}

View File

@@ -1,5 +1,18 @@
use super::*;
use std::usize;
/// Tracks Stream related state
///
/// # Reference counting
///
/// There can be a number of outstanding handles to a single Stream. These are
/// tracked using reference counting. The `ref_count` field represents the
/// number of outstanding userspace handles that can reach this stream.
///
/// It's important to note that when the stream is placed in an internal queue
/// (such as an accept queue), this is **not** tracked by a reference count.
/// Thus, `ref_count` can be zero and the stream still has to be kept around.
#[derive(Debug)]
pub(super) struct Stream<B, P>
where P: Peer,
@@ -10,6 +23,9 @@ pub(super) struct Stream<B, P>
/// Current state of the stream
pub state: State,
/// Number of outstanding handles pointing to this stream
pub ref_count: usize,
// ===== Fields related to sending =====
/// Next node in the accept linked list
@@ -117,6 +133,7 @@ impl<B, P> Stream<B, P>
Stream {
id,
state: State::default(),
ref_count: 0,
// ===== Fields related to sending =====
@@ -146,6 +163,46 @@ impl<B, P> Stream<B, P>
}
}
/// Increment the stream's ref count
pub fn ref_inc(&mut self) {
assert!(self.ref_count < usize::MAX);
self.ref_count += 1;
}
/// Decrements the stream's ref count
pub fn ref_dec(&mut self) {
assert!(self.ref_count > 0);
self.ref_count -= 1;
}
/// Returns true if the stream is closed
pub fn is_closed(&self) -> bool {
// The state has fully transitioned to closed.
self.state.is_closed() &&
// Because outbound frames transition the stream state before being
// buffered, we have to ensure that all frames have been flushed.
self.pending_send.is_empty() &&
// Sometimes large data frames are sent out in chunks. After a chunk
// of the frame is sent, the remainder is pushed back onto the send
// queue to be rescheduled.
//
// Checking for additional buffered data lets us catch this case.
self.buffered_send_data == 0
}
/// Returns true if the stream is no longer in use
pub fn is_released(&self) -> bool {
// The stream is closed and fully flushed
self.is_closed() &&
// There are no more outstanding references to the stream
self.ref_count == 0 &&
// The stream is not in any queue
!self.is_pending_send &&
!self.is_pending_send_capacity &&
!self.is_pending_accept &&
!self.is_pending_window_update
}
pub fn assign_capacity(&mut self, capacity: WindowSize) {
debug_assert!(capacity > 0);
self.send_capacity_inc = true;

View File

@@ -34,6 +34,8 @@ pub(crate) struct StreamRef<B, P>
struct Inner<B, P>
where P: Peer,
{
/// Tracks send & recv stream concurrency.
counts: Counts<P>,
actions: Actions<B, P>,
store: Store<B, P>,
}
@@ -59,6 +61,7 @@ impl<B, P> Streams<B, P>
pub fn new(config: Config) -> Self {
Streams {
inner: Arc::new(Mutex::new(Inner {
counts: Counts::new(&config),
actions: Actions {
recv: Recv::new(&config),
send: Send::new(&config),
@@ -80,7 +83,7 @@ impl<B, P> Streams<B, P>
let key = match me.store.find_entry(id) {
Entry::Occupied(e) => e.key(),
Entry::Vacant(e) => {
match try!(me.actions.recv.open(id)) {
match try!(me.actions.recv.open(id, &mut me.counts)) {
Some(stream_id) => {
let stream = Stream::new(
stream_id,
@@ -95,10 +98,13 @@ impl<B, P> Streams<B, P>
};
let stream = me.store.resolve(key);
let actions = &mut me.actions;
me.counts.transition(stream, |counts, stream| {
trace!("recv_headers; stream={:?}; state={:?}", stream.id, stream.state);
me.actions.transition(stream, |actions, stream| {
let res = if stream.state.is_recv_headers() {
actions.recv.recv_headers(frame, stream)
actions.recv.recv_headers(frame, stream, counts)
} else {
if !frame.is_end_stream() {
// TODO: Is this the right error
@@ -133,7 +139,9 @@ impl<B, P> Streams<B, P>
None => return Err(RecvError::Connection(ProtocolError)),
};
me.actions.transition(stream, |actions, stream| {
let actions = &mut me.actions;
me.counts.transition(stream, |_, stream| {
match actions.recv.recv_data(frame, stream) {
Err(RecvError::Stream { reason, .. }) => {
// Reset the stream.
@@ -168,7 +176,9 @@ impl<B, P> Streams<B, P>
}
};
me.actions.transition(stream, |actions, stream| {
let actions = &mut me.actions;
me.counts.transition(stream, |_, stream| {
actions.recv.recv_reset(frame, stream)?;
assert!(stream.state.is_closed());
Ok(())
@@ -181,12 +191,16 @@ impl<B, P> Streams<B, P>
let me = &mut *me;
let actions = &mut me.actions;
let counts = &mut me.counts;
let last_processed_id = actions.recv.last_processed_id();
me.store.for_each(|mut stream| {
actions.recv.recv_err(err, &mut *stream);
Ok::<_, ()>(())
}).ok().expect("unexpected error processing error");
me.store.for_each(|stream| {
counts.transition(stream, |_, stream| {
actions.recv.recv_err(err, &mut *stream);
Ok::<_, ()>(())
})
}).unwrap();
last_processed_id
}
@@ -244,7 +258,16 @@ impl<B, P> Streams<B, P>
let mut me = self.inner.lock().unwrap();
let me = &mut *me;
me.actions.recv.next_incoming(&mut me.store)
match me.actions.recv.next_incoming(&mut me.store) {
Some(key) => {
// Increment the ref count
me.store.resolve(key).ref_inc();
// Return the key
Some(key)
}
None => None,
}
};
key.map(|key| {
@@ -278,7 +301,7 @@ impl<B, P> Streams<B, P>
try_ready!(me.actions.recv.poll_complete(&mut me.store, dst));
// Send any other pending frames
try_ready!(me.actions.send.poll_complete(&mut me.store, dst));
try_ready!(me.actions.send.poll_complete(&mut me.store, &mut me.counts, dst));
// Nothing else to do, track the task
me.actions.task = Some(task::current());
@@ -292,6 +315,8 @@ impl<B, P> Streams<B, P>
let mut me = self.inner.lock().unwrap();
let me = &mut *me;
me.counts.apply_remote_settings(frame);
me.actions.send.apply_remote_settings(
frame, &mut me.store, &mut me.actions.task)
}
@@ -312,7 +337,7 @@ impl<B, P> Streams<B, P>
let me = &mut *me;
// Initialize a new stream. This fails if the connection is at capacity.
let stream_id = me.actions.send.open()?;
let stream_id = me.actions.send.open(&mut me.counts)?;
let mut stream = Stream::new(
stream_id,
@@ -336,6 +361,9 @@ impl<B, P> Streams<B, P>
// closed state.
debug_assert!(!stream.state.is_closed());
// Increment the stream ref count as we will be returning a handle.
stream.ref_inc();
stream.key()
};
@@ -352,7 +380,7 @@ impl<B, P> Streams<B, P>
let key = match me.store.find_entry(id) {
Entry::Occupied(e) => e.key(),
Entry::Vacant(e) => {
match me.actions.recv.open(id) {
match me.actions.recv.open(id, &mut me.counts) {
Ok(Some(stream_id)) => {
let stream = Stream::new(
stream_id, 0, 0);
@@ -364,10 +392,10 @@ impl<B, P> Streams<B, P>
}
};
let stream = me.store.resolve(key);
let actions = &mut me.actions;
me.actions.transition(stream, move |actions, stream| {
me.counts.transition(stream, |_, stream| {
actions.send.send_reset(reason, stream, &mut actions.task)
})
}
@@ -380,7 +408,23 @@ impl<B> Streams<B, client::Peer>
let mut me = self.inner.lock().unwrap();
let me = &mut *me;
me.actions.send.poll_open_ready()
me.counts.poll_open_ready()
}
}
#[cfg(feature = "unstable")]
impl<B, P> Streams<B, P>
where B: Buf,
P: Peer,
{
pub fn num_active_streams(&self) -> usize {
let me = self.inner.lock().unwrap();
me.store.num_active_streams()
}
pub fn num_wired_streams(&self) -> usize {
let me = self.inner.lock().unwrap();
me.store.num_wired_streams()
}
}
@@ -397,12 +441,13 @@ impl<B, P> StreamRef<B, P>
let me = &mut *me;
let stream = me.store.resolve(self.key);
let actions = &mut me.actions;
// Create the data frame
let mut frame = frame::Data::new(stream.id, data);
frame.set_end_stream(end_stream);
me.counts.transition(stream, |_, stream| {
// Create the data frame
let mut frame = frame::Data::new(stream.id, data);
frame.set_end_stream(end_stream);
me.actions.transition(stream, |actions, stream| {
// Send the data frame
actions.send.send_data(frame, stream, &mut actions.task)
})
@@ -415,11 +460,12 @@ impl<B, P> StreamRef<B, P>
let me = &mut *me;
let stream = me.store.resolve(self.key);
let actions = &mut me.actions;
// Create the trailers frame
let frame = frame::Headers::trailers(stream.id, trailers);
me.counts.transition(stream, |_, stream| {
// Create the trailers frame
let frame = frame::Headers::trailers(stream.id, trailers);
me.actions.transition(stream, |actions, stream| {
// Send the trailers frame
actions.send.send_trailers(frame, stream, &mut actions.task)
})
@@ -430,7 +476,9 @@ impl<B, P> StreamRef<B, P>
let me = &mut *me;
let stream = me.store.resolve(self.key);
me.actions.transition(stream, move |actions, stream| {
let actions = &mut me.actions;
me.counts.transition(stream, |_, stream| {
actions.send.send_reset(reason, stream, &mut actions.task)
})
}
@@ -442,11 +490,12 @@ impl<B, P> StreamRef<B, P>
let me = &mut *me;
let stream = me.store.resolve(self.key);
let actions = &mut me.actions;
let frame = server::Peer::convert_send_message(
stream.id, response, end_of_stream);
me.counts.transition(stream, |_, stream| {
let frame = server::Peer::convert_send_message(
stream.id, response, end_of_stream);
me.actions.transition(stream, |actions, stream| {
actions.send.send_headers(frame, stream, &mut actions.task)
})
}
@@ -559,6 +608,11 @@ impl<B, P> Clone for StreamRef<B, P>
where P: Peer,
{
fn clone(&self) -> Self {
// Increment the ref count
self.inner.lock().unwrap()
.store.resolve(self.key)
.ref_inc();
StreamRef {
inner: self.inner.clone(),
key: self.key.clone(),
@@ -566,6 +620,29 @@ impl<B, P> Clone for StreamRef<B, P>
}
}
impl<B, P> Drop for StreamRef<B, P>
where P: Peer,
{
fn drop(&mut self) {
let mut me = self.inner.lock().unwrap();
let me = &mut *me;
let id = {
let mut stream = me.store.resolve(self.key);
stream.ref_dec();
if !stream.is_released() {
return;
}
stream.remove()
};
debug_assert!(!me.store.contains_id(&id));
}
}
// ===== impl Actions =====
impl<B, P> Actions<B, P>
@@ -575,37 +652,10 @@ impl<B, P> Actions<B, P>
fn ensure_not_idle(&mut self, id: StreamId)
-> Result<(), Reason>
{
if self.is_local_init(id) {
if P::is_local_init(id) {
self.send.ensure_not_idle(id)
} else {
self.recv.ensure_not_idle(id)
}
}
fn dec_num_streams(&mut self, id: StreamId) {
if self.is_local_init(id) {
self.send.dec_num_streams();
} else {
self.recv.dec_num_streams();
}
}
fn is_local_init(&self, id: StreamId) -> bool {
assert!(!id.is_zero());
P::is_server() == id.is_server_initiated()
}
fn transition<F, U>(&mut self, mut stream: store::Ptr<B, P>, f: F) -> U
where F: FnOnce(&mut Self, &mut store::Ptr<B, P>) -> U,
{
let is_counted = stream.state.is_counted();
let ret = f(self, &mut stream);
if is_counted && stream.state.is_closed() {
self.dec_num_streams(stream.id);
}
ret
}
}

View File

@@ -149,6 +149,58 @@ fn send_headers_recv_data_single_frame() {
h2.wait().unwrap();
}
#[test]
fn closed_streams_are_released() {
let _ = ::env_logger::init();
let (io, srv) = mock::new();
let h2 = Client::handshake(io).unwrap()
.and_then(|mut h2| {
let request = Request::get("https://example.com/")
.body(()).unwrap();
// Send request
let stream = h2.request(request, true).unwrap();
h2.drive(stream)
})
.and_then(|(h2, response)| {
assert_eq!(response.status(), StatusCode::NO_CONTENT);
// There are no active streams
assert_eq!(0, h2.num_active_streams());
// The response contains a handle for the body. This keeps the
// stream wired.
assert_eq!(1, h2.num_wired_streams());
drop(response);
// The stream state is now free
assert_eq!(0, h2.num_wired_streams());
Ok(())
})
;
let srv = srv.assert_client_handshake().unwrap()
.recv_settings()
.recv_frame(
frames::headers(1)
.request("GET", "https://example.com/")
.eos()
)
.send_frame(
frames::headers(1)
.response(204)
.eos()
)
.close()
;
let _ = h2.join(srv)
.wait().unwrap();
}
/*
#[test]
fn send_data_after_headers_eos() {

View File

@@ -71,23 +71,35 @@ impl<T, U> Future for Drive<T, U>
type Error = ();
fn poll(&mut self) -> Poll<Self::Item, Self::Error> {
match self.future.poll() {
Ok(Async::Ready(val)) => {
// Get the driver
let driver = self.driver.take().unwrap();
let mut looped = false;
return Ok((driver, val).into())
loop {
match self.future.poll() {
Ok(Async::Ready(val)) => {
// Get the driver
let driver = self.driver.take().unwrap();
return Ok((driver, val).into())
}
Ok(_) => {}
Err(e) => panic!("unexpected error; {:?}", e),
}
Ok(_) => {}
Err(e) => panic!("unexpected error; {:?}", e),
}
match self.driver.as_mut().unwrap().poll() {
Ok(Async::Ready(_)) => panic!("driver resolved before future"),
Ok(Async::NotReady) => {}
Err(e) => panic!("unexpected error; {:?}", e),
}
match self.driver.as_mut().unwrap().poll() {
Ok(Async::Ready(_)) => {
if looped {
// Try polling the future one last time
panic!("driver resolved before future")
} else {
looped = true;
continue;
}
}
Ok(Async::NotReady) => {}
Err(e) => panic!("unexpected error; {:?}", e),
}
Ok(Async::NotReady)
return Ok(Async::NotReady);
}
}
}