diff --git a/src/proto/streams/state.rs b/src/proto/streams/state.rs index 3c0d903..7b3bdc3 100644 --- a/src/proto/streams/state.rs +++ b/src/proto/streams/state.rs @@ -243,6 +243,17 @@ impl State { } } + /// Returns true if a stream with the current state counts against the + /// concurrency limit. + pub fn is_counted(&self) -> bool { + match self.inner { + Open { .. } => true, + HalfClosedLocal(..) => true, + HalfClosedRemote(..) => true, + _ => false, + } + } + pub fn is_closed(&self) -> bool { match self.inner { Closed(_) => true, diff --git a/src/proto/streams/streams.rs b/src/proto/streams/streams.rs index bbfee6d..fd3a47c 100644 --- a/src/proto/streams/streams.rs +++ b/src/proto/streams/streams.rs @@ -88,28 +88,23 @@ impl Streams } }; - let mut stream = me.store.resolve(key); + let stream = me.store.resolve(key); - let ret = if frame.is_trailers() { - unimplemented!(); - /* - if !frame.is_end_stream() { - // TODO: What error should this return? + me.actions.transition(stream, |actions, stream| { + if frame.is_trailers() { unimplemented!(); + /* + if !frame.is_end_stream() { + // TODO: What error should this return? + unimplemented!(); + } + + try!(me.actions.recv.recv_eos(stream)); + */ + } else { + actions.recv.recv_headers(frame, stream) } - - try!(me.actions.recv.recv_eos(stream)); - */ - } else { - try!(me.actions.recv.recv_headers(frame, &mut stream)) - }; - - // TODO: move this into a fn - if stream.state.is_closed() { - me.actions.dec_num_streams(id); - } - - Ok(ret) + }) } pub fn recv_data(&mut self, frame: frame::Data) @@ -120,20 +115,14 @@ impl Streams let id = frame.stream_id(); - let mut stream = match me.store.find_mut(&id) { + let stream = match me.store.find_mut(&id) { Some(stream) => stream, None => return Err(ProtocolError.into()), }; - // Ensure there's enough capacity on the connection before acting on the - // stream. - try!(me.actions.recv.recv_data(frame, &mut stream)); - - if stream.state.is_closed() { - me.actions.dec_num_streams(id); - } - - Ok(()) + me.actions.transition(stream, |actions, stream| { + actions.recv.recv_data(frame, stream) + }) } pub fn recv_reset(&mut self, frame: frame::Reset) @@ -150,12 +139,11 @@ impl Streams None => return Ok(()), }; - me.actions.recv.recv_reset(frame, &mut stream)?; - - assert!(stream.state.is_closed()); - me.actions.dec_num_streams(id); - - Ok(()) + me.actions.transition(stream, |actions, stream| { + actions.recv.recv_reset(frame, stream)?; + assert!(stream.state.is_closed()); + Ok(()) + }) } pub fn recv_err(&mut self, err: &ConnectionError) { @@ -338,19 +326,15 @@ impl StreamRef let mut me = self.inner.lock().unwrap(); let me = &mut *me; - let mut stream = me.store.resolve(self.key); + let stream = me.store.resolve(self.key); // Create the data frame let frame = frame::Data::from_buf(stream.id, data, end_of_stream); - // Send the data frame - me.actions.send.send_data(frame, &mut stream)?; - - if stream.state.is_closed() { - me.actions.dec_num_streams(stream.id); - } - - Ok(()) + me.actions.transition(stream, |actions, stream| { + // Send the data frame + actions.send.send_data(frame, stream) + }) } pub fn poll_data(&mut self) -> Poll>, ConnectionError> { @@ -445,4 +429,18 @@ impl Actions assert!(!id.is_zero()); P::is_server() == id.is_server_initiated() } + + fn transition(&mut self, mut stream: store::Ptr, f: F) -> U + where F: FnOnce(&mut Self, &mut store::Ptr) -> U, + { + let is_counted = stream.state.is_counted(); + + let ret = f(self, &mut stream); + + if is_counted && stream.state.is_closed() { + self.dec_num_streams(stream.id); + } + + ret + } }