Merge pull request #827 from hyperium/enc-dec-get-ref

add get_ref methods to Encoder and Decoder
This commit is contained in:
Sean McArthur
2016-06-14 20:18:28 +02:00
committed by GitHub
5 changed files with 36 additions and 298 deletions

View File

@@ -1,291 +0,0 @@
//! Client Connection Pooling
use std::borrow::ToOwned;
use std::collections::HashMap;
use std::io::{self, Read, Write};
use std::net::{SocketAddr, Shutdown};
use std::sync::{Arc, Mutex};
use std::time::Duration;
use net::{NetworkConnector, NetworkStream, DefaultConnector};
/// The `NetworkConnector` that behaves as a connection pool used by hyper's `Client`.
pub struct Pool<C: NetworkConnector> {
connector: C,
inner: Arc<Mutex<PoolImpl<<C as NetworkConnector>::Stream>>>
}
/// Config options for the `Pool`.
#[derive(Debug)]
pub struct Config {
/// The maximum idle connections *per host*.
pub max_idle: usize,
}
impl Default for Config {
#[inline]
fn default() -> Config {
Config {
max_idle: 5,
}
}
}
#[derive(Debug)]
struct PoolImpl<S> {
conns: HashMap<Key, Vec<PooledStreamInner<S>>>,
config: Config,
}
type Key = (String, u16, Scheme);
fn key<T: Into<Scheme>>(host: &str, port: u16, scheme: T) -> Key {
(host.to_owned(), port, scheme.into())
}
#[derive(Clone, PartialEq, Eq, Debug, Hash)]
enum Scheme {
Http,
Https,
Other(String)
}
impl<'a> From<&'a str> for Scheme {
fn from(s: &'a str) -> Scheme {
match s {
"http" => Scheme::Http,
"https" => Scheme::Https,
s => Scheme::Other(String::from(s))
}
}
}
impl Pool<DefaultConnector> {
/// Creates a `Pool` with a `DefaultConnector`.
#[inline]
pub fn new(config: Config) -> Pool<DefaultConnector> {
Pool::with_connector(config, DefaultConnector::default())
}
}
impl<C: NetworkConnector> Pool<C> {
/// Creates a `Pool` with a specified `NetworkConnector`.
#[inline]
pub fn with_connector(config: Config, connector: C) -> Pool<C> {
Pool {
connector: connector,
inner: Arc::new(Mutex::new(PoolImpl {
conns: HashMap::new(),
config: config,
}))
}
}
/// Clear all idle connections from the Pool, closing them.
#[inline]
pub fn clear_idle(&mut self) {
self.inner.lock().unwrap().conns.clear();
}
}
impl<S> PoolImpl<S> {
fn reuse(&mut self, key: Key, conn: PooledStreamInner<S>) {
trace!("reuse {:?}", key);
let conns = self.conns.entry(key).or_insert(vec![]);
if conns.len() < self.config.max_idle {
conns.push(conn);
}
}
}
impl<C: NetworkConnector<Stream=S>, S: NetworkStream + Send> NetworkConnector for Pool<C> {
type Stream = PooledStream<S>;
fn connect(&self, host: &str, port: u16, scheme: &str) -> ::Result<PooledStream<S>> {
let key = key(host, port, scheme);
let inner = {
// keep the mutex locked only in this block
let mut locked = self.inner.lock().unwrap();
let mut should_remove = false;
let inner = locked.conns.get_mut(&key).map(|vec| {
trace!("Pool had connection, using");
should_remove = vec.len() == 1;
vec.pop().unwrap()
});
if should_remove {
locked.conns.remove(&key);
}
inner
};
let inner = match inner {
Some(inner) => inner,
None => PooledStreamInner {
key: key.clone(),
stream: try!(self.connector.connect(host, port, scheme)),
previous_response_expected_no_content: false,
}
};
Ok(PooledStream {
inner: Some(inner),
is_closed: false,
pool: self.inner.clone(),
})
}
}
/// A Stream that will try to be returned to the Pool when dropped.
#[derive(Debug)]
pub struct PooledStream<S> {
inner: Option<PooledStreamInner<S>>,
is_closed: bool,
pool: Arc<Mutex<PoolImpl<S>>>,
}
impl<S: NetworkStream> PooledStream<S> {
/// Take the wrapped stream out of the pool completely.
pub fn into_inner(mut self) -> S {
self.inner.take().expect("PooledStream lost its inner stream").stream
}
}
#[derive(Debug)]
struct PooledStreamInner<S> {
key: Key,
stream: S,
previous_response_expected_no_content: bool,
}
impl<S: NetworkStream> Read for PooledStream<S> {
#[inline]
fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
match self.inner.as_mut().unwrap().stream.read(buf) {
Ok(0) => {
// if the wrapped stream returns EOF (Ok(0)), that means the
// server has closed the stream. we must be sure this stream
// is dropped and not put back into the pool.
self.is_closed = true;
Ok(0)
},
r => r
}
}
}
impl<S: NetworkStream> Write for PooledStream<S> {
#[inline]
fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
self.inner.as_mut().unwrap().stream.write(buf)
}
#[inline]
fn flush(&mut self) -> io::Result<()> {
self.inner.as_mut().unwrap().stream.flush()
}
}
impl<S: NetworkStream> NetworkStream for PooledStream<S> {
#[inline]
fn peer_addr(&mut self) -> io::Result<SocketAddr> {
self.inner.as_mut().unwrap().stream.peer_addr()
}
#[inline]
fn set_read_timeout(&self, dur: Option<Duration>) -> io::Result<()> {
self.inner.as_ref().unwrap().stream.set_read_timeout(dur)
}
#[inline]
fn set_write_timeout(&self, dur: Option<Duration>) -> io::Result<()> {
self.inner.as_ref().unwrap().stream.set_write_timeout(dur)
}
#[inline]
fn close(&mut self, how: Shutdown) -> io::Result<()> {
self.is_closed = true;
self.inner.as_mut().unwrap().stream.close(how)
}
#[inline]
fn set_previous_response_expected_no_content(&mut self, expected: bool) {
trace!("set_previous_response_expected_no_content {}", expected);
self.inner.as_mut().unwrap().previous_response_expected_no_content = expected;
}
#[inline]
fn previous_response_expected_no_content(&self) -> bool {
let answer = self.inner.as_ref().unwrap().previous_response_expected_no_content;
trace!("previous_response_expected_no_content {}", answer);
answer
}
}
impl<S> Drop for PooledStream<S> {
fn drop(&mut self) {
trace!("PooledStream.drop, is_closed={}", self.is_closed);
if !self.is_closed {
self.inner.take().map(|inner| {
if let Ok(mut pool) = self.pool.lock() {
pool.reuse(inner.key.clone(), inner);
}
// else poisoned, give up
});
}
}
}
#[cfg(test)]
mod tests {
use std::net::Shutdown;
use std::io::Read;
use mock::{MockConnector};
use net::{NetworkConnector, NetworkStream};
use super::{Pool, key};
macro_rules! mocked {
() => ({
Pool::with_connector(Default::default(), MockConnector)
})
}
#[test]
fn test_connect_and_drop() {
let pool = mocked!();
let key = key("127.0.0.1", 3000, "http");
pool.connect("127.0.0.1", 3000, "http").unwrap();
{
let locked = pool.inner.lock().unwrap();
assert_eq!(locked.conns.len(), 1);
assert_eq!(locked.conns.get(&key).unwrap().len(), 1);
}
pool.connect("127.0.0.1", 3000, "http").unwrap(); //reused
{
let locked = pool.inner.lock().unwrap();
assert_eq!(locked.conns.len(), 1);
assert_eq!(locked.conns.get(&key).unwrap().len(), 1);
}
}
#[test]
fn test_closed() {
let pool = mocked!();
let mut stream = pool.connect("127.0.0.1", 3000, "http").unwrap();
stream.close(Shutdown::Both).unwrap();
drop(stream);
let locked = pool.inner.lock().unwrap();
assert_eq!(locked.conns.len(), 0);
}
#[test]
fn test_eof_closes() {
let pool = mocked!();
let mut stream = pool.connect("127.0.0.1", 3000, "http").unwrap();
assert_eq!(stream.read(&mut [0]).unwrap(), 0);
drop(stream);
let locked = pool.inner.lock().unwrap();
assert_eq!(locked.conns.len(), 0);
}
}

View File

@@ -94,6 +94,12 @@ pub struct BufReader<'a, R: io::Read + 'a> {
reader: &'a mut R
}
impl<'a, R: io::Read + 'a> BufReader<'a, R> {
pub fn get_ref(&self) -> &R {
self.reader
}
}
impl<'a, R: io::Read> Read for BufReader<'a, R> {
fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
trace!("BufReader.read self={}, buf={}", self.buf.len(), buf.len());

View File

@@ -44,6 +44,15 @@ enum Trans<'a, T: Read + 'a> {
Buf(self::buffer::BufReader<'a, T>)
}
impl<'a, T: Read + 'a> Trans<'a, T> {
fn get_ref(&self) -> &T {
match *self {
Trans::Port(ref t) => &*t,
Trans::Buf(ref buf) => buf.get_ref()
}
}
}
impl<'a, T: Read + 'a> Read for Trans<'a, T> {
fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
match *self {
@@ -62,12 +71,27 @@ impl<'a, T: Read> Decoder<'a, T> {
fn h1(decoder: &'a mut h1::Decoder, transport: Trans<'a, T>) -> Decoder<'a, T> {
Decoder(DecoderImpl::H1(decoder, transport))
}
/// Get a reference to the transport.
pub fn get_ref(&self) -> &T {
match self.0 {
DecoderImpl::H1(_, ref transport) => transport.get_ref()
}
}
}
impl<'a, T: Transport> Encoder<'a, T> {
fn h1(encoder: &'a mut h1::Encoder, transport: &'a mut T) -> Encoder<'a, T> {
Encoder(EncoderImpl::H1(encoder, transport))
}
/// Get a reference to the transport.
pub fn get_ref(&self) -> &T {
match self.0 {
EncoderImpl::H1(_, ref transport) => &*transport
}
}
}
impl<'a, T: Read> Read for Decoder<'a, T> {

View File

@@ -53,11 +53,6 @@ pub enum Blocked {
Write,
}
impl Transport for HttpStream {
fn take_socket_error(&mut self) -> io::Result<()> {
self.0.take_socket_error()
}
}
/// Accepts sockets asynchronously.
pub trait Accept: Evented {
@@ -73,6 +68,12 @@ pub trait Accept: Evented {
#[derive(Debug)]
pub struct HttpStream(pub TcpStream);
impl Transport for HttpStream {
fn take_socket_error(&mut self) -> io::Result<()> {
self.0.take_socket_error()
}
}
impl Read for HttpStream {
#[inline]
fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {

View File

@@ -2,7 +2,6 @@
//!
//! These are requests that a `hyper::Server` receives, and include its method,
//! target URI, headers, and message body.
//use std::net::SocketAddr;
use version::HttpVersion;
use method::Method;
@@ -16,7 +15,6 @@ pub fn new(incoming: RequestHead) -> Request {
debug!("{:#?}", headers);
Request {
//remote_addr: addr,
method: method,
uri: uri,
headers: headers,