diff --git a/src/proto/streams/streams.rs b/src/proto/streams/streams.rs index d91921f..e92fea2 100644 --- a/src/proto/streams/streams.rs +++ b/src/proto/streams/streams.rs @@ -63,6 +63,9 @@ struct Inner { /// Stores stream state store: Store, + + /// The number of stream refs to this shared state. + refs: usize, } #[derive(Debug)] @@ -106,6 +109,7 @@ where conn_error: None, }, store: Store::new(), + refs: 1, })), send_buffer: Arc::new(SendBuffer::new()), _p: ::std::marker::PhantomData, @@ -485,7 +489,9 @@ where let mut me = self.inner.lock().unwrap(); let me = &mut *me; let key = me.actions.recv.next_incoming(&mut me.store); - + // TODO: ideally, OpaqueStreamRefs::new would do this, but we're holding + // the lock, so it can't. + me.refs += 1; key.map(|key| { let stream = &mut me.store.resolve(key); trace!("next_incoming; id={:?}, state={:?}", stream.id, stream.state); @@ -635,6 +641,10 @@ where // closed state. debug_assert!(!stream.state.is_closed()); + // TODO: ideally, OpaqueStreamRefs::new would do this, but we're holding + // the lock, so it can't. + me.refs += 1; + Ok(StreamRef { opaque: OpaqueStreamRef::new( self.inner.clone(), @@ -748,16 +758,8 @@ where } pub fn has_streams_or_other_references(&self) -> bool { - if Arc::strong_count(&self.inner) > 1 { - return true; - } - - if Arc::strong_count(&self.send_buffer) > 1 { - return true; - } - let me = self.inner.lock().unwrap(); - me.counts.has_streams() + me.counts.has_streams() || me.refs > 1 } #[cfg(feature = "unstable")] @@ -773,6 +775,7 @@ where P: Peer, { fn clone(&self) -> Self { + self.inner.lock().unwrap().refs += 1; Streams { inner: self.inner.clone(), send_buffer: self.send_buffer.clone(), @@ -781,6 +784,16 @@ where } } +impl Drop for Streams +where + P: Peer, +{ + fn drop(&mut self) { + let _ = self.inner.lock().map(|mut inner| inner.refs -= 1); + } +} + + // ===== impl StreamRef ===== impl StreamRef { @@ -978,6 +991,7 @@ impl OpaqueStreamRef { try_ready!(me.actions.recv.poll_pushed(&mut stream)) }; Ok(Async::Ready(res.map(|(h, key)| { + me.refs += 1; let opaque_ref = OpaqueStreamRef::new(self.inner.clone(), &mut me.store.resolve(key)); (h, opaque_ref) @@ -1070,7 +1084,9 @@ impl fmt::Debug for OpaqueStreamRef { impl Clone for OpaqueStreamRef { fn clone(&self) -> Self { // Increment the ref count - self.inner.lock().unwrap().store.resolve(self.key).ref_inc(); + let mut inner = self.inner.lock().unwrap(); + inner.store.resolve(self.key).ref_inc(); + inner.refs += 1; OpaqueStreamRef { inner: self.inner.clone(), @@ -1098,7 +1114,7 @@ fn drop_stream_ref(inner: &Mutex, key: store::Key) { }; let me = &mut *me; - + me.refs -= 1; let mut stream = me.store.resolve(key); trace!("drop_stream_ref; stream={:?}", stream); diff --git a/tests/h2-tests/Cargo.toml b/tests/h2-tests/Cargo.toml index 41642e7..7649d91 100644 --- a/tests/h2-tests/Cargo.toml +++ b/tests/h2-tests/Cargo.toml @@ -9,3 +9,4 @@ publish = false [dev-dependencies] h2-support = { path = "../h2-support" } log = "0.4.1" +tokio = "0.1.8" diff --git a/tests/h2-tests/tests/hammer.rs b/tests/h2-tests/tests/hammer.rs new file mode 100644 index 0000000..f83b8b9 --- /dev/null +++ b/tests/h2-tests/tests/hammer.rs @@ -0,0 +1,154 @@ +extern crate tokio; +#[macro_use] +extern crate h2_support; + +use h2_support::prelude::*; +use h2_support::futures::{Async, Poll}; + +use tokio::net::{TcpListener, TcpStream}; +use std::{net::SocketAddr, thread, sync::{atomic::{AtomicUsize, Ordering}, Arc}}; + +struct Server { + addr: SocketAddr, + reqs: Arc, + join: Option>, +} + +impl Server { + fn serve(mk_data: F) -> Self + where + F: Fn() -> Bytes, + F: Send + Sync + 'static, + { + let mk_data = Arc::new(mk_data); + + let listener = TcpListener::bind(&SocketAddr::from(([127, 0, 0, 1], 0))).unwrap(); + let addr = listener.local_addr().unwrap(); + let reqs = Arc::new(AtomicUsize::new(0)); + let reqs2 = reqs.clone(); + let join = thread::spawn(move || { + let server = listener.incoming().for_each(move |socket| { + let reqs = reqs2.clone(); + let mk_data = mk_data.clone(); + let connection = server::handshake(socket) + .and_then(move |conn| { + conn.for_each(move |(_, mut respond)| { + reqs.fetch_add(1, Ordering::Release); + let response = Response::builder().status(StatusCode::OK).body(()).unwrap(); + let mut send = respond.send_response(response, false)?; + send.send_data(mk_data(), true).map(|_|()) + }) + }) + .map_err(|e| eprintln!("serve conn error: {:?}", e)); + + tokio::spawn(Box::new(connection)); + Ok(()) + }) + .map_err(|e| eprintln!("serve error: {:?}", e)); + + tokio::run(server); + }); + + Self { + addr, + join: Some(join), + reqs + } + } + + fn addr(&self) -> SocketAddr { + self.addr.clone() + } + + fn request_count(&self) -> usize { + self.reqs.load(Ordering::Acquire) + } +} + + +struct Process { + body: RecvStream, + trailers: bool, +} + +impl Future for Process { + type Item = (); + type Error = h2::Error; + + fn poll(&mut self) -> Poll<(), h2::Error> { + loop { + if self.trailers { + return match self.body.poll_trailers()? { + Async::NotReady => Ok(Async::NotReady), + Async::Ready(_) => Ok(().into()), + }; + } else { + match self.body.poll()? { + Async::NotReady => return Ok(Async::NotReady), + Async::Ready(None) => { + self.trailers = true; + }, + _ => {}, + } + } + } + } +} + + +#[test] +fn hammer_client_concurrency() { + // This reproduces issue #326. + const N: usize = 5000; + + let server = Server::serve(|| Bytes::from_static(b"hello world!")); + + let addr = server.addr(); + let rsps = Arc::new(AtomicUsize::new(0)); + + for i in 0..N { + print!("sending {}", i); + let rsps = rsps.clone(); + let tcp = TcpStream::connect(&addr); + let tcp = tcp.then(|res| { + let tcp = res.unwrap(); + client::handshake(tcp) + }).then(move |res| { + let rsps = rsps; + let (mut client, h2) = res.unwrap(); + let request = Request::builder() + .uri("https://http2.akamai.com/") + .body(()) + .unwrap(); + + let (response, mut stream) = client.send_request(request, false).unwrap(); + stream.send_trailers(HeaderMap::new()).unwrap(); + + tokio::spawn(h2.map_err(|e| panic!("client conn error: {:?}", e))); + + response + .and_then(|response| { + let (_, body) = response.into_parts(); + + Process { + body, + trailers: false, + } + }) + .map_err(|e| { + panic!("client error: {:?}", e); + }) + .map(move |_| { + rsps.fetch_add(1, Ordering::Release); + }) + }); + + tokio::run(tcp); + println!("...done"); + } + + println!("all done"); + + assert_eq!(N, rsps.load(Ordering::Acquire)); + assert_eq!(N, server.request_count()); +}