From 2904668105649e674aab7d0945a8f117a0dba32b Mon Sep 17 00:00:00 2001 From: Sean McArthur Date: Fri, 8 Jul 2016 10:07:02 -0700 Subject: [PATCH] feat(client): implement connection pooling for Client Closes #830 Closes #848 --- src/client/mod.rs | 95 +++++++++++++++++++++++++++++----------- src/http/conn.rs | 76 ++++++++++++++++++++++++-------- src/http/h1/parse.rs | 8 ++++ src/http/mod.rs | 38 ++++++++++------ tests/client.rs | 25 +++++++++++ tests/server.rs | 100 +++++++++++++++++++++++++++++++++++-------- 6 files changed, 268 insertions(+), 74 deletions(-) diff --git a/src/client/mod.rs b/src/client/mod.rs index 32656c55..fed4505d 100644 --- a/src/client/mod.rs +++ b/src/client/mod.rs @@ -5,6 +5,7 @@ use std::collections::HashMap; use std::fmt; +use std::io; use std::marker::PhantomData; use std::sync::mpsc; use std::thread; @@ -24,7 +25,6 @@ pub use self::response::Response; mod connect; mod dns; -//mod pool; mod request; mod response; @@ -116,6 +116,7 @@ impl Client { loop_.run(Context { connect_timeout: connect_timeout, keep_alive: keep_alive, + idle_conns: HashMap::new(), queue: HashMap::new(), }).unwrap() })); @@ -332,7 +333,7 @@ impl, T: Transport> http::MessageHandler for Message { struct Context { connect_timeout: Duration, keep_alive: bool, - // idle: HashMap>, + idle_conns: HashMap>, queue: HashMap>>, } @@ -352,6 +353,27 @@ impl Context { } queued } + + fn conn_response(&mut self, conn: Option<(http::Conn>, Option)>, time: rotor::Time) + -> rotor::Response, (C::Key, C::Output)> + where C: Connect, H: Handler { + match conn { + Some((conn, timeout)) => { + //TODO: HTTP2: a connection doesn't need to be idle to be used for a second stream + if conn.is_idle() { + self.idle_conns.entry(conn.key().clone()).or_insert_with(Vec::new) + .push(conn.control()); + } + match timeout { + Some(dur) => rotor::Response::ok(ClientFsm::Socket(conn)) + .deadline(time + dur), + None => rotor::Response::ok(ClientFsm::Socket(conn)), + } + + } + None => rotor::Response::done() + } + } } impl, T: Transport> http::MessageHandlerFactory for Context { @@ -414,14 +436,9 @@ where C: Connect, unreachable!("Connector can never be ready") }, ClientFsm::Socket(conn) => { - match conn.ready(events, scope) { - Some((conn, None)) => rotor::Response::ok(ClientFsm::Socket(conn)), - Some((conn, Some(dur))) => { - rotor::Response::ok(ClientFsm::Socket(conn)) - .deadline(scope.now() + dur) - } - None => rotor::Response::done() - } + let res = conn.ready(events, scope); + let now = scope.now(); + scope.conn_response(res, now) } } } @@ -461,14 +478,9 @@ where C: Connect, } } ClientFsm::Socket(conn) => { - match conn.timeout(scope) { - Some((conn, None)) => rotor::Response::ok(ClientFsm::Socket(conn)), - Some((conn, Some(dur))) => { - rotor::Response::ok(ClientFsm::Socket(conn)) - .deadline(scope.now() + dur) - } - None => rotor::Response::done() - } + let res = conn.timeout(scope); + let now = scope.now(); + scope.conn_response(res, now) } } } @@ -478,13 +490,10 @@ where C: Connect, ClientFsm::Connector(..) => { self.connect(scope) }, - ClientFsm::Socket(conn) => match conn.wakeup(scope) { - Some((conn, None)) => rotor::Response::ok(ClientFsm::Socket(conn)), - Some((conn, Some(dur))) => { - rotor::Response::ok(ClientFsm::Socket(conn)) - .deadline(scope.now() + dur) - } - None => rotor::Response::done() + ClientFsm::Socket(conn) => { + let res = conn.wakeup(scope); + let now = scope.now(); + scope.conn_response(res, now) } } } @@ -513,7 +522,41 @@ where C: Connect, loop { match rx.try_recv() { Ok(Notify::Connect(url, mut handler)) => { - // TODO: check pool for sockets to this domain + // check pool for sockets to this domain + if let Some(key) = connector.key(&url) { + let mut remove_idle = false; + let mut woke_up = false; + if let Some(mut idle) = scope.idle_conns.get_mut(&key) { + while !idle.is_empty() { + let ctrl = idle.remove(0); + // err means the socket has since died + if ctrl.ready(Next::write()).is_ok() { + woke_up = true; + break; + } + } + remove_idle = idle.is_empty(); + } + if remove_idle { + scope.idle_conns.remove(&key); + } + + if woke_up { + trace!("woke up idle conn for '{}'", url); + let deadline = scope.now() + scope.connect_timeout; + scope.queue.entry(key).or_insert_with(Vec::new).push(Queued { + deadline: deadline, + handler: handler, + url: url + }); + continue; + } + } else { + // this connector cannot handle this url anyways + let _ = handler.on_error(io::Error::new(io::ErrorKind::InvalidInput, "invalid url for connector").into()); + continue; + } + // no exist connection, call connector match connector.connect(&url) { Ok(key) => { let deadline = scope.now() + scope.connect_timeout; diff --git a/src/http/conn.rs b/src/http/conn.rs index 8c6c23e1..9d727e72 100644 --- a/src/http/conn.rs +++ b/src/http/conn.rs @@ -63,8 +63,8 @@ impl> ConnInner { fn interest(&self) -> Reg { match self.state { State::Closed => Reg::Remove, - State::Init => { - ::Message::initial_interest().interest() + State::Init { interest, .. } => { + interest.register() } State::Http1(Http1 { reading: Reading::Closed, writing: Writing::Closed, .. }) => { Reg::Remove @@ -142,12 +142,12 @@ impl> ConnInner { fn read>(&mut self, scope: &mut Scope, state: State) -> State { match state { - State::Init => { + State::Init { interest: Next_::Read, .. } => { let head = match self.parse() { Ok(head) => head, Err(::Error::Io(e)) => match e.kind() { io::ErrorKind::WouldBlock | - io::ErrorKind::Interrupted => return State::Init, + io::ErrorKind::Interrupted => return state, _ => { debug!("io error trying to parse {:?}", e); return State::Closed; @@ -219,6 +219,10 @@ impl> ConnInner { } } }, + State::Init { .. } => { + trace!("on_readable State::{:?}", state); + state + }, State::Http1(mut http1) => { let next = match http1.reading { Reading::Init => None, @@ -274,7 +278,7 @@ impl> ConnInner { if let Some(next) = next { s.update(next); } - trace!("Conn.on_readable State::Http1 completed, new state = {:?}", s); + trace!("Conn.on_readable State::Http1 completed, new state = State::{:?}", s); let again = match s { State::Http1(Http1 { reading: Reading::Body(ref encoder), .. }) => encoder.is_eof(), @@ -296,7 +300,7 @@ impl> ConnInner { fn write>(&mut self, scope: &mut Scope, mut state: State) -> State { let next = match state { - State::Init => { + State::Init { interest: Next_::Write, .. } => { // this could be a Client request, which writes first, so pay // attention to the version written here, which will adjust // our internal state to Http1 or Http2 @@ -336,6 +340,10 @@ impl> ConnInner { } Some(interest) } + State::Init { .. } => { + trace!("Conn.on_writable State::{:?}", state); + None + } State::Http1(Http1 { ref mut handler, ref mut writing, ref mut keep_alive, .. }) => { match *writing { Writing::Init => { @@ -426,7 +434,7 @@ impl> ConnInner { fn can_read_more(&self) -> bool { match self.state { - State::Init => false, + State::Init { .. } => false, _ => !self.buf.is_empty() } } @@ -435,7 +443,7 @@ impl> ConnInner { debug!("on_error err = {:?}", err); trace!("on_error state = {:?}", self.state); let next = match self.state { - State::Init => Next::remove(), + State::Init { .. } => Next::remove(), State::Http1(ref mut http1) => http1.handler.on_error(err), State::Closed => Next::remove(), }; @@ -461,7 +469,7 @@ impl> ConnInner { fn on_remove(self) { debug!("on_remove"); match self.state { - State::Init | State::Closed => (), + State::Init { .. } | State::Closed => (), State::Http1(http1) => http1.handler.on_remove(self.transport), } } @@ -475,7 +483,10 @@ impl> Conn { ctrl: channel::new(notify), keep_alive_enabled: true, key: key, - state: State::Init, + state: State::Init { + interest: H::Message::initial_interest().interest, + timeout: None, + }, transport: transport, })) } @@ -585,10 +596,30 @@ impl> Conn { self.0.on_remove() } + pub fn key(&self) -> &K { + &self.0.key + } + + pub fn control(&self) -> Control { + Control { + tx: self.0.ctrl.0.clone(), + } + } + + pub fn is_idle(&self) -> bool { + if let State::Init { interest: Next_::Wait, .. } = self.0.state { + true + } else { + false + } + } } enum State, T: Transport> { - Init, + Init { + interest: Next_, + timeout: Option, + }, /// Http1 will only ever use a connection to send and receive a single /// message at a time. Once a H1 status has been determined, we will either /// be reading or writing an H1 message, and optionally multiple if @@ -606,7 +637,7 @@ enum State, T: Transport> { impl, T: Transport> State { fn timeout(&self) -> Option { match *self { - State::Init => None, + State::Init { timeout, .. } => timeout, State::Http1(ref http1) => http1.timeout, State::Closed => None, } @@ -616,7 +647,10 @@ impl, T: Transport> State { impl, T: Transport> fmt::Debug for State { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { match *self { - State::Init => f.write_str("Init"), + State::Init { interest, timeout } => f.debug_struct("Init") + .field("interest", &interest) + .field("timeout", &timeout) + .finish(), State::Http1(ref h1) => f.debug_tuple("Http1") .field(h1) .finish(), @@ -632,10 +666,14 @@ impl, T: Transport> State { let new_state = match (state, next.interest) { (_, Next_::Remove) => State::Closed, (State::Closed, _) => State::Closed, - (State::Init, _) => State::Init, + (State::Init { timeout, .. }, e) => State::Init { + interest: e, + timeout: timeout, + }, (State::Http1(http1), Next_::End) => { let reading = match http1.reading { - Reading::Body(ref decoder) if decoder.is_eof() => { + Reading::Body(ref decoder) | + Reading::Wait(ref decoder) if decoder.is_eof() => { if http1.keep_alive { Reading::KeepAlive } else { @@ -646,6 +684,7 @@ impl, T: Transport> State { _ => Reading::Closed, }; let writing = match http1.writing { + Writing::Wait(encoder) | Writing::Ready(encoder) => { if encoder.is_eof() { if http1.keep_alive { @@ -691,8 +730,11 @@ impl, T: Transport> State { }; match (reading, writing) { (Reading::KeepAlive, Writing::KeepAlive) => { - //http1.handler.on_keep_alive(); - State::Init + //XXX keepalive + State::Init { + interest: H::Message::keep_alive_interest().interest, + timeout: None, + } }, (reading, Writing::Chunk(chunk)) => { State::Http1(Http1 { diff --git a/src/http/h1/parse.rs b/src/http/h1/parse.rs index 07867b3d..489342e5 100644 --- a/src/http/h1/parse.rs +++ b/src/http/h1/parse.rs @@ -31,6 +31,10 @@ impl Http1Message for ServerMessage { Next::new(Next_::Read) } + fn keep_alive_interest() -> Next { + Next::new(Next_::Read) + } + fn parse(buf: &[u8]) -> ParseResult { let mut headers = [httparse::EMPTY_HEADER; MAX_HEADERS]; trace!("Request.parse([Header; {}], [u8; {}])", headers.len(), buf.len()); @@ -114,6 +118,10 @@ impl Http1Message for ClientMessage { Next::new(Next_::Write) } + fn keep_alive_interest() -> Next { + Next::new(Next_::Wait) + } + fn parse(buf: &[u8]) -> ParseResult { let mut headers = [httparse::EMPTY_HEADER; MAX_HEADERS]; trace!("Response.parse([Header; {}], [u8; {}])", headers.len(), buf.len()); diff --git a/src/http/mod.rs b/src/http/mod.rs index 883cc7f0..b9fe7c5c 100644 --- a/src/http/mod.rs +++ b/src/http/mod.rs @@ -249,14 +249,16 @@ impl Deserialize for RawStatus { /// Checks if a connection should be kept alive. #[inline] pub fn should_keep_alive(version: HttpVersion, headers: &Headers) -> bool { - trace!("should_keep_alive( {:?}, {:?} )", version, headers.get::()); - match (version, headers.get::()) { + let ret = match (version, headers.get::()) { (Http10, None) => false, (Http10, Some(conn)) if !conn.contains(&KeepAlive) => false, (Http11, Some(conn)) if conn.contains(&Close) => false, _ => true - } + }; + trace!("should_keep_alive(version={:?}, header={:?}) = {:?}", version, headers.get::(), ret); + ret } + pub type ParseResult = ::Result, usize)>>; pub fn parse, I>(rdr: &[u8]) -> ParseResult { @@ -280,6 +282,7 @@ pub trait Http1Message { type Outgoing: Default; //TODO: replace with associated const when stable fn initial_interest() -> Next; + fn keep_alive_interest() -> Next; fn parse(bytes: &[u8]) -> ParseResult; fn decoder(head: &MessageHead) -> ::Result; fn encode(head: MessageHead, dst: &mut Vec) -> h1::Encoder; @@ -304,6 +307,7 @@ impl fmt::Debug for Next { } } +// Internal enum for `Next` #[derive(Debug, Clone, Copy)] enum Next_ { Read, @@ -314,6 +318,8 @@ enum Next_ { Remove, } +// An enum representing all the possible actions to taken when registering +// with the event loop. #[derive(Debug, Clone, Copy)] enum Reg { Read, @@ -361,16 +367,11 @@ impl Next { } } - fn interest(&self) -> Reg { - match self.interest { - Next_::Read => Reg::Read, - Next_::Write => Reg::Write, - Next_::ReadWrite => Reg::ReadWrite, - Next_::Wait => Reg::Wait, - Next_::End => Reg::Remove, - Next_::Remove => Reg::Remove, - } + /* + fn reg(&self) -> Reg { + self.interest.register() } + */ /// Signals the desire to read from the transport. pub fn read() -> Next { @@ -410,6 +411,19 @@ impl Next { } } +impl Next_ { + fn register(&self) -> Reg { + match *self { + Next_::Read => Reg::Read, + Next_::Write => Reg::Write, + Next_::ReadWrite => Reg::ReadWrite, + Next_::Wait => Reg::Wait, + Next_::End => Reg::Remove, + Next_::Remove => Reg::Remove, + } + } +} + #[test] fn test_should_keep_alive() { let mut headers = Headers::new(); diff --git a/tests/client.rs b/tests/client.rs index 563250be..ee0ccafa 100644 --- a/tests/client.rs +++ b/tests/client.rs @@ -363,3 +363,28 @@ fn client_read_timeout() { other => panic!("expected timeout, actual: {:?}", other) } } + +#[test] +fn client_keep_alive() { + let server = TcpListener::bind("127.0.0.1:0").unwrap(); + let addr = server.local_addr().unwrap(); + let client = client(); + let res = client.request(format!("http://{}/a", addr), opts()); + + let mut sock = server.accept().unwrap().0; + sock.set_read_timeout(Some(Duration::from_secs(5))).unwrap(); + sock.set_write_timeout(Some(Duration::from_secs(5))).unwrap(); + let mut buf = [0; 4096]; + sock.read(&mut buf).expect("read 1"); + sock.write_all(b"HTTP/1.1 200 OK\r\nContent-Length: 0\r\n\r\n").expect("write 1"); + + while let Ok(_) = res.recv() {} + + let res = client.request(format!("http://{}/b", addr), opts()); + sock.read(&mut buf).expect("read 2"); + let second_get = b"GET /b HTTP/1.1\r\n"; + assert_eq!(&buf[..second_get.len()], second_get); + sock.write_all(b"HTTP/1.1 200 OK\r\nContent-Length: 0\r\n\r\n").expect("write 2"); + + while let Ok(_) = res.recv() {} +} diff --git a/tests/server.rs b/tests/server.rs index a93ac3c1..d56493f5 100644 --- a/tests/server.rs +++ b/tests/server.rs @@ -72,7 +72,7 @@ impl Drop for Serve { struct TestHandler { tx: mpsc::Sender, - rx: mpsc::Receiver, + reply: Vec, peeked: Option>, timeout: Option, } @@ -123,28 +123,26 @@ impl Handler for TestHandler { } fn on_response(&mut self, res: &mut Response) -> Next { - loop { - match self.rx.try_recv() { - Ok(Reply::Status(s)) => { + for reply in self.reply.drain(..) { + match reply { + Reply::Status(s) => { res.set_status(s); }, - Ok(Reply::Headers(headers)) => { + Reply::Headers(headers) => { use std::iter::Extend; res.headers_mut().extend(headers.iter()); }, - Ok(Reply::Body(body)) => { + Reply::Body(body) => { self.peeked = Some(body); }, - Err(..) => { - return if self.peeked.is_some() { - self.next(Next::write()) - } else { - self.next(Next::end()) - }; - }, } } + if self.peeked.is_some() { + self.next(Next::write()) + } else { + self.next(Next::end()) + } } fn on_response_writable(&mut self, encoder: &mut Encoder) -> Next { @@ -167,13 +165,18 @@ fn serve_with_timeout(dur: Option) -> Serve { let (msg_tx, msg_rx) = mpsc::channel(); let (reply_tx, reply_rx) = mpsc::channel(); - let mut reply_rx = Some(reply_rx); let (listening, server) = Server::http(&"127.0.0.1:0".parse().unwrap()).unwrap() - .handle(move |_| TestHandler { - tx: msg_tx.clone(), - timeout: dur, - rx: reply_rx.take().unwrap(), - peeked: None, + .handle(move |_| { + let mut replies = Vec::new(); + while let Ok(reply) = reply_rx.try_recv() { + replies.push(reply); + } + TestHandler { + tx: msg_tx.clone(), + timeout: dur, + reply: replies, + peeked: None, + } }).unwrap(); @@ -377,3 +380,62 @@ fn server_empty_response_chunked_without_calling_write() { assert_eq!(lines.next(), Some("")); assert_eq!(lines.next(), None); } + +#[test] +fn server_keep_alive() { + extern crate env_logger; + env_logger::init().unwrap(); + + let foo_bar = b"foo bar baz"; + let server = serve(); + server.reply() + .status(hyper::Ok) + .header(hyper::header::ContentLength(foo_bar.len() as u64)) + .body(foo_bar); + let mut req = TcpStream::connect(server.addr()).unwrap(); + req.set_read_timeout(Some(Duration::from_secs(5))).unwrap(); + req.set_write_timeout(Some(Duration::from_secs(5))).unwrap(); + req.write_all(b"\ + GET / HTTP/1.1\r\n\ + Host: example.domain\r\n\ + Connection: keep-alive\r\n\ + \r\n\ + ").expect("writing 1"); + + let mut buf = [0; 1024 * 8]; + loop { + let n = req.read(&mut buf[..]).expect("reading 1"); + if n < buf.len() { + if &buf[n - foo_bar.len()..n] == foo_bar { + break; + } else { + println!("{:?}", ::std::str::from_utf8(&buf[..n])); + } + } + } + + // try again! + + let quux = b"zar quux"; + server.reply() + .status(hyper::Ok) + .header(hyper::header::ContentLength(quux.len() as u64)) + .body(quux); + req.write_all(b"\ + GET /quux HTTP/1.1\r\n\ + Host: example.domain\r\n\ + Connection: close\r\n\ + \r\n\ + ").expect("writing 2"); + + let mut buf = [0; 1024 * 8]; + loop { + let n = req.read(&mut buf[..]).expect("reading 2"); + assert!(n > 0, "n = {}", n); + if n < buf.len() { + if &buf[n - quux.len()..n] == quux { + break; + } + } + } +}