diff --git a/src/client.rs b/src/client.rs index 434e20b..f063953 100644 --- a/src/client.rs +++ b/src/client.rs @@ -187,6 +187,18 @@ impl Builder { self } + /// Set the maximum number of concurrent streams. + /// + /// Clients can only limit the maximum number of streams that that the + /// server can initiate. See [Section 5.1.2] in the HTTP/2 spec for more + /// details. + /// + /// [Section 5.1.2]: https://http2.github.io/http2-spec/#rfc.section.5.1.2 + pub fn max_concurrent_streams(&mut self, max: u32) -> &mut Self { + self.settings.set_max_concurrent_streams(Some(max)); + self + } + /// Enable or disable the server to send push promises. pub fn enable_push(&mut self, enabled: bool) -> &mut Self { self.settings.set_enable_push(enabled); diff --git a/src/frame/settings.rs b/src/frame/settings.rs index 9ac0e33..b130f43 100644 --- a/src/frame/settings.rs +++ b/src/frame/settings.rs @@ -74,7 +74,6 @@ impl Settings { self.max_concurrent_streams } - #[cfg(feature = "unstable")] pub fn set_max_concurrent_streams(&mut self, max: Option) { self.max_concurrent_streams = max; } diff --git a/src/proto/connection.rs b/src/proto/connection.rs index b3e399d..da3e125 100644 --- a/src/proto/connection.rs +++ b/src/proto/connection.rs @@ -76,7 +76,9 @@ where local_next_stream_id: next_stream_id, local_push_enabled: settings.is_push_enabled(), remote_init_window_sz: DEFAULT_INITIAL_WINDOW_SIZE, - remote_max_initiated: None, + remote_max_initiated: settings + .max_concurrent_streams() + .map(|max| max as usize), }); Connection { state: State::Open, diff --git a/src/server.rs b/src/server.rs index 329a57f..29c3359 100644 --- a/src/server.rs +++ b/src/server.rs @@ -194,6 +194,18 @@ impl Builder { self } + /// Set the maximum number of concurrent streams. + /// + /// Servers can only limit the maximum number of streams that that the + /// client can initiate. See [Section 5.1.2] in the HTTP/2 spec for more + /// details. + /// + /// [Section 5.1.2]: https://http2.github.io/http2-spec/#rfc.section.5.1.2 + pub fn max_concurrent_streams(&mut self, max: u32) -> &mut Self { + self.settings.set_max_concurrent_streams(Some(max)); + self + } + /// Bind an H2 server connection. /// /// Returns a future which resolves to the connection value once the H2 diff --git a/tests/client_request.rs b/tests/client_request.rs index a1524f2..c6a927d 100644 --- a/tests/client_request.rs +++ b/tests/client_request.rs @@ -148,6 +148,45 @@ fn request_stream_id_overflows() { h2.join(srv).wait().expect("wait"); } +#[test] +fn client_builder_max_concurrent_streams() { + let _ = ::env_logger::init(); + let (io, srv) = mock::new(); + + let mut settings = frame::Settings::default(); + settings.set_max_concurrent_streams(Some(1)); + + let srv = srv + .assert_client_handshake() + .unwrap() + .recv_custom_settings(settings) + .recv_frame( + frames::headers(1) + .request("GET", "https://example.com/") + .eos() + ) + .send_frame(frames::headers(1).response(200).eos()) + .close(); + + let mut builder = Client::builder(); + builder.max_concurrent_streams(1); + + let h2 = builder + .handshake::<_, Bytes>(io) + .expect("handshake") + .and_then(|(mut client, h2)| { + let request = Request::builder() + .method(Method::GET) + .uri("https://example.com/") + .body(()) + .unwrap(); + let req = client.send_request(request, true).unwrap().unwrap(); + h2.drive(req).map(move |(h2, _)| (client, h2)) + }); + + h2.join(srv).wait().expect("wait"); +} + #[test] fn request_over_max_concurrent_streams_errors() { let _ = ::env_logger::init(); diff --git a/tests/server.rs b/tests/server.rs index f579484..736e31e 100644 --- a/tests/server.rs +++ b/tests/server.rs @@ -22,6 +22,56 @@ fn read_preface_in_multiple_frames() { assert!(Stream::wait(h2).next().is_none()); } +#[test] +fn server_builder_set_max_concurrent_streams() { + let _ = ::env_logger::init(); + let (io, client) = mock::new(); + + let mut settings = frame::Settings::default(); + settings.set_max_concurrent_streams(Some(1)); + + let client = client + .assert_server_handshake() + .unwrap() + .recv_custom_settings(settings) + .send_frame( + frames::headers(1) + .request("GET", "https://example.com/"), + ) + .send_frame( + frames::headers(3) + .request("GET", "https://example.com/"), + ) + .send_frame(frames::data(1, &b"hello"[..]).eos(),) + .recv_frame(frames::reset(3).refused()) + .recv_frame(frames::headers(1).response(200).eos()) + .close(); + + let mut builder = Server::builder(); + builder.max_concurrent_streams(1); + + let h2 = builder + .handshake::<_, Bytes>(io) + .expect("handshake") + .and_then(|srv| { + srv.into_future().unwrap().and_then(|(reqstream, srv)| { + let (req, mut stream) = reqstream.unwrap(); + + assert_eq!(req.method(), &http::Method::GET); + + let rsp = + http::Response::builder() + .status(200).body(()) + .unwrap(); + stream.send_response(rsp, true).unwrap(); + + srv.into_future().unwrap() + }) + }); + + h2.join(client).wait().expect("wait"); +} + #[test] fn serve_request() { let _ = ::env_logger::init(); diff --git a/tests/support/frames.rs b/tests/support/frames.rs index c1e39d2..4c2ae8b 100644 --- a/tests/support/frames.rs +++ b/tests/support/frames.rs @@ -249,6 +249,11 @@ impl Mock { let id = self.0.stream_id(); Mock(frame::Reset::new(id, frame::Reason::FLOW_CONTROL_ERROR)) } + + pub fn refused(self) -> Self { + let id = self.0.stream_id(); + Mock(frame::Reset::new(id, frame::Reason::REFUSED_STREAM)) + } } impl From> for SendFrame { diff --git a/tests/support/mock.rs b/tests/support/mock.rs index 0d54a87..4fc3171 100644 --- a/tests/support/mock.rs +++ b/tests/support/mock.rs @@ -384,19 +384,32 @@ impl AsyncWrite for Pipe { } pub trait HandleFutureExt { - fn recv_settings(self) -> RecvFrame, Handle), Error = ()>>> + fn recv_settings(self) + -> RecvFrame, Handle), Error = ()>>> where Self: Sized + 'static, Self: Future, Self::Error: fmt::Debug, { - let map = self.map(|(settings, handle)| (Some(settings.into()), handle)) + self.recv_custom_settings(frame::Settings::default()) + } + + fn recv_custom_settings(self, settings: frame::Settings) + -> RecvFrame, Handle), Error = ()>>> + where + Self: Sized + 'static, + Self: Future, + Self::Error: fmt::Debug, + { + let map = self + .map(|(settings, handle)| (Some(settings.into()), handle)) .unwrap(); - let boxed: Box, Handle), Error = ()>> = Box::new(map); + let boxed: Box, Handle), Error = ()>> = + Box::new(map); RecvFrame { inner: boxed, - frame: frame::Settings::default().into(), + frame: settings.into(), } }