From d67dbc602833570ac3dfcda9f2becb5daa11d785 Mon Sep 17 00:00:00 2001 From: Ed Barnard Date: Mon, 18 Jul 2016 23:32:38 +0100 Subject: [PATCH] feat(server): Server::new can take one or more listeners Closes #859 --- src/net.rs | 19 ++++++++++++ src/server/mod.rs | 79 ++++++++++++++++++++++++++++++----------------- tests/server.rs | 48 +++++++++++++++++++++++++--- 3 files changed, 114 insertions(+), 32 deletions(-) diff --git a/src/net.rs b/src/net.rs index 2a4c6a3c..d123a70a 100644 --- a/src/net.rs +++ b/src/net.rs @@ -1,6 +1,7 @@ //! A collection of traits abstracting over Listeners and Streams. use std::io::{self, Read, Write}; use std::net::{SocketAddr}; +use std::option; use rotor::mio::tcp::{TcpStream, TcpListener}; use rotor::mio::{Selector, Token, Evented, EventSet, PollOpt, TryAccept}; @@ -168,6 +169,15 @@ impl Evented for HttpListener { } } +impl IntoIterator for HttpListener { + type Item = Self; + type IntoIter = option::IntoIter; + + fn into_iter(self) -> Self::IntoIter { + Some(self).into_iter() + } +} + /// Deprecated /// /// Use `SslClient` and `SslServer` instead. @@ -390,6 +400,15 @@ impl Evented for HttpsListener { } } +impl IntoIterator for HttpsListener { + type Item = Self; + type IntoIter = option::IntoIter; + + fn into_iter(self) -> Self::IntoIter { + Some(self).into_iter() + } +} + fn _assert_transport() { fn _assert() {} _assert::>(); diff --git a/src/server/mod.rs b/src/server/mod.rs index c3a5dbd1..fbb53770 100644 --- a/src/server/mod.rs +++ b/src/server/mod.rs @@ -37,19 +37,26 @@ impl> fmt::Debug for ServerLoop { /// A Server that can accept incoming network requests. #[derive(Debug)] -pub struct Server { - listener: T, +pub struct Server { + lead_listener: A, + other_listeners: Vec, keep_alive: bool, idle_timeout: Option, max_sockets: usize, } -impl Server where T: Accept, T::Output: Transport { - /// Creates a new server with the provided Listener. - #[inline] - pub fn new(listener: T) -> Server { +impl Server { + /// Creates a new Server from one or more Listeners. + /// + /// Panics if listeners is an empty iterator. + pub fn new>(listeners: I) -> Server { + let mut listeners = listeners.into_iter(); + let lead_listener = listeners.next().expect("Server::new requires at least 1 listener"); + let other_listeners = listeners.collect::>(); + Server { - listener: listener, + lead_listener: lead_listener, + other_listeners: other_listeners, keep_alive: true, idle_timeout: Some(Duration::from_secs(10)), max_sockets: 4096, @@ -59,7 +66,7 @@ impl Server where T: Accept, T::Output: Transport { /// Enables or disables HTTP keep-alive. /// /// Default is true. - pub fn keep_alive(mut self, val: bool) -> Server { + pub fn keep_alive(mut self, val: bool) -> Server { self.keep_alive = val; self } @@ -67,7 +74,7 @@ impl Server where T: Accept, T::Output: Transport { /// Sets how long an idle connection will be kept before closing. /// /// Default is 10 seconds. - pub fn idle_timeout(mut self, val: Option) -> Server { + pub fn idle_timeout(mut self, val: Option) -> Server { self.idle_timeout = val; self } @@ -75,7 +82,7 @@ impl Server where T: Accept, T::Output: Transport { /// Sets the maximum open sockets for this Server. /// /// Default is 4096, but most servers can handle much more than this. - pub fn max_sockets(mut self, val: usize) -> Server { + pub fn max_sockets(mut self, val: usize) -> Server { self.max_sockets = val; self } @@ -105,33 +112,48 @@ impl Server> { } -impl Server where A::Output: Transport { +impl Server { /// Binds to a socket and starts handling connections. pub fn handle(self, factory: H) -> ::Result<(Listening, ServerLoop)> where H: HandlerFactory { - let addr = try!(self.listener.local_addr()); let shutdown = Arc::new(AtomicBool::new(false)); - let shutdown_rx = shutdown.clone(); - + let mut config = rotor::Config::new(); config.slab_capacity(self.max_sockets); config.mio().notify_capacity(self.max_sockets); let keep_alive = self.keep_alive; let idle_timeout = self.idle_timeout; let mut loop_ = rotor::Loop::new(&config).unwrap(); + + let mut addrs = Vec::with_capacity(1 + self.other_listeners.len()); + + // Add the lead listener. This one handles shutdown messages. let mut notifier = None; { let notifier = &mut notifier; + let listener = self.lead_listener; + addrs.push(try!(listener.local_addr())); + let shutdown_rx = shutdown.clone(); loop_.add_machine_with(move |scope| { *notifier = Some(scope.notifier()); - rotor_try!(scope.register(&self.listener, EventSet::readable(), PollOpt::level())); - rotor::Response::ok(ServerFsm::Listener::(self.listener, shutdown_rx)) + rotor_try!(scope.register(&listener, EventSet::readable(), PollOpt::level())); + rotor::Response::ok(ServerFsm::Listener(listener, shutdown_rx)) }).unwrap(); } let notifier = notifier.expect("loop.add_machine failed"); + // Add the other listeners. + for listener in self.other_listeners { + addrs.push(try!(listener.local_addr())); + let shutdown_rx = shutdown.clone(); + loop_.add_machine_with(move |scope| { + rotor_try!(scope.register(&listener, EventSet::readable(), PollOpt::level())); + rotor::Response::ok(ServerFsm::Listener(listener, shutdown_rx)) + }).unwrap(); + } + let listening = Listening { - addr: addr, + addrs: addrs, shutdown: (shutdown, notifier), }; let server = ServerLoop { @@ -299,14 +321,14 @@ where A: Accept, /// A handle of the running server. pub struct Listening { - addr: SocketAddr, + addrs: Vec, shutdown: (Arc, rotor::Notifier), } impl fmt::Debug for Listening { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { f.debug_struct("Listening") - .field("addr", &self.addr) + .field("addrs", &self.addrs) .field("closed", &self.shutdown.0.load(Ordering::Relaxed)) .finish() } @@ -314,14 +336,20 @@ impl fmt::Debug for Listening { impl fmt::Display for Listening { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - fmt::Display::fmt(&self.addr, f) + for (i, addr) in self.addrs().iter().enumerate() { + if i > 1 { + try!(f.write_str(", ")); + } + try!(fmt::Display::fmt(addr, f)); + } + Ok(()) } } impl Listening { - /// The address this server is listening on. - pub fn addr(&self) -> &SocketAddr { - &self.addr + /// The addresses this server is listening on. + pub fn addrs(&self) -> &[SocketAddr] { + &self.addrs } /// Stop the server from listening to its socket address. @@ -375,8 +403,3 @@ where F: FnMut(http::Control) -> H, H: Handler, T: Transport { self(ctrl) } } - -#[cfg(test)] -mod tests { - -} diff --git a/tests/server.rs b/tests/server.rs index d56493f5..974af21e 100644 --- a/tests/server.rs +++ b/tests/server.rs @@ -7,7 +7,7 @@ use std::sync::mpsc; use std::time::Duration; use hyper::{Next, Encoder, Decoder}; -use hyper::net::HttpStream; +use hyper::net::{HttpListener, HttpStream}; use hyper::server::{Server, Handler, Request, Response}; struct Serve { @@ -17,8 +17,14 @@ struct Serve { } impl Serve { + fn addrs(&self) -> &[SocketAddr] { + self.listening.as_ref().unwrap().addrs() + } + fn addr(&self) -> &SocketAddr { - self.listening.as_ref().unwrap().addr() + let addrs = self.addrs(); + assert!(addrs.len() == 1); + &addrs[0] } /* @@ -161,11 +167,22 @@ fn serve() -> Serve { } fn serve_with_timeout(dur: Option) -> Serve { + serve_n_with_timeout(1, dur) +} + +fn serve_n(n: u32) -> Serve { + serve_n_with_timeout(n, None) +} + +fn serve_n_with_timeout(n: u32, dur: Option) -> Serve { use std::thread; let (msg_tx, msg_rx) = mpsc::channel(); let (reply_tx, reply_rx) = mpsc::channel(); - let (listening, server) = Server::http(&"127.0.0.1:0".parse().unwrap()).unwrap() + + let addr = "127.0.0.1:0".parse().unwrap(); + let listeners = (0..n).map(|_| HttpListener::bind(&addr).unwrap()); + let (listening, server) = Server::new(listeners) .handle(move |_| { let mut replies = Vec::new(); while let Ok(reply) = reply_rx.try_recv() { @@ -180,7 +197,7 @@ fn serve_with_timeout(dur: Option) -> Serve { }).unwrap(); - let thread_name = format!("test-server-{}: {:?}", listening.addr(), dur); + let thread_name = format!("test-server-{}: {:?}", listening, dur); thread::Builder::new().name(thread_name).spawn(move || { server.run(); }).unwrap(); @@ -439,3 +456,26 @@ fn server_keep_alive() { } } } + +#[test] +fn server_get_with_body_three_listeners() { + let server = serve_n(3); + let addrs = server.addrs(); + assert_eq!(addrs.len(), 3); + + for (i, addr) in addrs.iter().enumerate() { + let mut req = TcpStream::connect(addr).unwrap(); + write!(req, "\ + GET / HTTP/1.1\r\n\ + Host: example.domain\r\n\ + Content-Length: 17\r\n\ + \r\n\ + I'm sending to {}.\r\n\ + ", i).unwrap(); + req.read(&mut [0; 256]).unwrap(); + + // note: doesnt include trailing \r\n, cause Content-Length wasn't 19 + let comparison = format!("I'm sending to {}.", i).into_bytes(); + assert_eq!(server.body(), comparison); + } +}