Merge pull request #621 from hyperium/timeout
feat(net): add socket timeouts to Server and Client
This commit is contained in:
		| @@ -59,13 +59,16 @@ use std::default::Default; | ||||
| use std::io::{self, copy, Read}; | ||||
| use std::iter::Extend; | ||||
|  | ||||
| #[cfg(feature = "timeouts")] | ||||
| use std::time::Duration; | ||||
|  | ||||
| use url::UrlParser; | ||||
| use url::ParseError as UrlError; | ||||
|  | ||||
| use header::{Headers, Header, HeaderFormat}; | ||||
| use header::{ContentLength, Location}; | ||||
| use method::Method; | ||||
| use net::{NetworkConnector, NetworkStream}; | ||||
| use net::{NetworkConnector, NetworkStream, Fresh}; | ||||
| use {Url}; | ||||
| use Error; | ||||
|  | ||||
| @@ -87,7 +90,9 @@ pub struct Client { | ||||
|     protocol: Box<Protocol + Send + Sync>, | ||||
|     redirect_policy: RedirectPolicy, | ||||
|     #[cfg(feature = "timeouts")] | ||||
|     read_timeout: Option<Duration> | ||||
|     read_timeout: Option<Duration>, | ||||
|     #[cfg(feature = "timeouts")] | ||||
|     write_timeout: Option<Duration>, | ||||
| } | ||||
|  | ||||
| impl Client { | ||||
| @@ -108,11 +113,23 @@ impl Client { | ||||
|         Client::with_protocol(Http11Protocol::with_connector(connector)) | ||||
|     } | ||||
|  | ||||
|     #[cfg(not(feature = "timeouts"))] | ||||
|     /// Create a new client with a specific `Protocol`. | ||||
|     pub fn with_protocol<P: Protocol + Send + Sync + 'static>(protocol: P) -> Client { | ||||
|         Client { | ||||
|             protocol: Box::new(protocol), | ||||
|             redirect_policy: Default::default() | ||||
|             redirect_policy: Default::default(), | ||||
|         } | ||||
|     } | ||||
|  | ||||
|     #[cfg(feature = "timeouts")] | ||||
|     /// Create a new client with a specific `Protocol`. | ||||
|     pub fn with_protocol<P: Protocol + Send + Sync + 'static>(protocol: P) -> Client { | ||||
|         Client { | ||||
|             protocol: Box::new(protocol), | ||||
|             redirect_policy: Default::default(), | ||||
|             read_timeout: None, | ||||
|             write_timeout: None, | ||||
|         } | ||||
|     } | ||||
|  | ||||
| @@ -127,6 +144,12 @@ impl Client { | ||||
|         self.read_timeout = dur; | ||||
|     } | ||||
|  | ||||
|     /// Set the write timeout value for all requests. | ||||
|     #[cfg(feature = "timeouts")] | ||||
|     pub fn set_write_timeout(&mut self, dur: Option<Duration>) { | ||||
|         self.write_timeout = dur; | ||||
|     } | ||||
|  | ||||
|     /// Build a Get request. | ||||
|     pub fn get<U: IntoUrl>(&self, url: U) -> RequestBuilder<U> { | ||||
|         self.request(Method::Get, url) | ||||
| @@ -236,6 +259,20 @@ impl<'a, U: IntoUrl> RequestBuilder<'a, U> { | ||||
|             let mut req = try!(Request::with_message(method.clone(), url.clone(), message)); | ||||
|             headers.as_ref().map(|headers| req.headers_mut().extend(headers.iter())); | ||||
|  | ||||
|             #[cfg(not(feature = "timeouts"))] | ||||
|             fn set_timeouts(_req: &mut Request<Fresh>, _client: &Client) -> ::Result<()> { | ||||
|                 Ok(()) | ||||
|             } | ||||
|  | ||||
|             #[cfg(feature = "timeouts")] | ||||
|             fn set_timeouts(req: &mut Request<Fresh>, client: &Client) -> ::Result<()> { | ||||
|                 try!(req.set_write_timeout(client.write_timeout)); | ||||
|                 try!(req.set_read_timeout(client.read_timeout)); | ||||
|                 Ok(()) | ||||
|             } | ||||
|  | ||||
|             try!(set_timeouts(&mut req, &client)); | ||||
|  | ||||
|             match (can_have_body, body.as_ref()) { | ||||
|                 (true, Some(body)) => match body.size() { | ||||
|                     Some(size) => req.headers_mut().set(ContentLength(size)), | ||||
|   | ||||
| @@ -5,6 +5,9 @@ use std::io::{self, Read, Write}; | ||||
| use std::net::{SocketAddr, Shutdown}; | ||||
| use std::sync::{Arc, Mutex}; | ||||
|  | ||||
| #[cfg(feature = "timeouts")] | ||||
| use std::time::Duration; | ||||
|  | ||||
| use net::{NetworkConnector, NetworkStream, DefaultConnector}; | ||||
|  | ||||
| /// The `NetworkConnector` that behaves as a connection pool used by hyper's `Client`. | ||||
| @@ -153,6 +156,18 @@ impl<S: NetworkStream> NetworkStream for PooledStream<S> { | ||||
|         self.inner.as_mut().unwrap().1.peer_addr() | ||||
|     } | ||||
|  | ||||
|     #[cfg(feature = "timeouts")] | ||||
|     #[inline] | ||||
|     fn set_read_timeout(&self, dur: Option<Duration>) -> io::Result<()> { | ||||
|         self.inner.as_ref().unwrap().1.set_read_timeout(dur) | ||||
|     } | ||||
|  | ||||
|     #[cfg(feature = "timeouts")] | ||||
|     #[inline] | ||||
|     fn set_write_timeout(&self, dur: Option<Duration>) -> io::Result<()> { | ||||
|         self.inner.as_ref().unwrap().1.set_write_timeout(dur) | ||||
|     } | ||||
|  | ||||
|     #[inline] | ||||
|     fn close(&mut self, how: Shutdown) -> io::Result<()> { | ||||
|         self.is_closed = true; | ||||
|   | ||||
| @@ -2,6 +2,9 @@ | ||||
| use std::marker::PhantomData; | ||||
| use std::io::{self, Write}; | ||||
|  | ||||
| #[cfg(feature = "timeouts")] | ||||
| use std::time::Duration; | ||||
|  | ||||
| use url::Url; | ||||
|  | ||||
| use method::{self, Method}; | ||||
| @@ -39,6 +42,20 @@ impl<W> Request<W> { | ||||
|     /// Read the Request method. | ||||
|     #[inline] | ||||
|     pub fn method(&self) -> method::Method { self.method.clone() } | ||||
|  | ||||
|     /// Set the write timeout. | ||||
|     #[cfg(feature = "timeouts")] | ||||
|     #[inline] | ||||
|     pub fn set_write_timeout(&self, dur: Option<Duration>) -> io::Result<()> { | ||||
|         self.message.set_write_timeout(dur) | ||||
|     } | ||||
|  | ||||
|     /// Set the read timeout. | ||||
|     #[cfg(feature = "timeouts")] | ||||
|     #[inline] | ||||
|     pub fn set_read_timeout(&self, dur: Option<Duration>) -> io::Result<()> { | ||||
|         self.message.set_read_timeout(dur) | ||||
|     } | ||||
| } | ||||
|  | ||||
| impl Request<Fresh> { | ||||
|   | ||||
| @@ -4,6 +4,8 @@ use std::cmp::min; | ||||
| use std::fmt; | ||||
| use std::io::{self, Write, BufWriter, BufRead, Read}; | ||||
| use std::net::Shutdown; | ||||
| #[cfg(feature = "timeouts")] | ||||
| use std::time::Duration; | ||||
|  | ||||
| use httparse; | ||||
|  | ||||
| @@ -192,6 +194,19 @@ impl HttpMessage for Http11Message { | ||||
|         }) | ||||
|     } | ||||
|  | ||||
|     #[cfg(feature = "timeouts")] | ||||
|     #[inline] | ||||
|     fn set_read_timeout(&self, dur: Option<Duration>) -> io::Result<()> { | ||||
|         self.get_ref().set_read_timeout(dur) | ||||
|     } | ||||
|  | ||||
|     #[cfg(feature = "timeouts")] | ||||
|     #[inline] | ||||
|     fn set_write_timeout(&self, dur: Option<Duration>) -> io::Result<()> { | ||||
|         self.get_ref().set_write_timeout(dur) | ||||
|     } | ||||
|  | ||||
|     #[inline] | ||||
|     fn close_connection(&mut self) -> ::Result<()> { | ||||
|         try!(self.get_mut().close(Shutdown::Both)); | ||||
|         Ok(()) | ||||
| @@ -214,13 +229,27 @@ impl Http11Message { | ||||
|  | ||||
|     /// Gets a mutable reference to the underlying `NetworkStream`, regardless of the state of the | ||||
|     /// `Http11Message`. | ||||
|     pub fn get_mut(&mut self) -> &mut Box<NetworkStream + Send> { | ||||
|     pub fn get_ref(&self) -> &(NetworkStream + Send) { | ||||
|         if self.stream.is_some() { | ||||
|             self.stream.as_mut().unwrap() | ||||
|             &**self.stream.as_ref().unwrap() | ||||
|         } else if self.writer.is_some() { | ||||
|             self.writer.as_mut().unwrap().get_mut().get_mut() | ||||
|             &**self.writer.as_ref().unwrap().get_ref().get_ref() | ||||
|         } else if self.reader.is_some() { | ||||
|             self.reader.as_mut().unwrap().get_mut().get_mut() | ||||
|             &**self.reader.as_ref().unwrap().get_ref().get_ref() | ||||
|         } else { | ||||
|             panic!("Http11Message lost its underlying stream somehow"); | ||||
|         } | ||||
|     } | ||||
|  | ||||
|     /// Gets a mutable reference to the underlying `NetworkStream`, regardless of the state of the | ||||
|     /// `Http11Message`. | ||||
|     pub fn get_mut(&mut self) -> &mut (NetworkStream + Send) { | ||||
|         if self.stream.is_some() { | ||||
|             &mut **self.stream.as_mut().unwrap() | ||||
|         } else if self.writer.is_some() { | ||||
|             &mut **self.writer.as_mut().unwrap().get_mut().get_mut() | ||||
|         } else if self.reader.is_some() { | ||||
|             &mut **self.reader.as_mut().unwrap().get_mut().get_mut() | ||||
|         } else { | ||||
|             panic!("Http11Message lost its underlying stream somehow"); | ||||
|         } | ||||
| @@ -344,6 +373,16 @@ impl<R: Read> HttpReader<R> { | ||||
|         } | ||||
|     } | ||||
|  | ||||
|     /// Gets a borrowed reference to the underlying Reader. | ||||
|     pub fn get_ref(&self) -> &R { | ||||
|         match *self { | ||||
|             SizedReader(ref r, _) => r, | ||||
|             ChunkedReader(ref r, _) => r, | ||||
|             EofReader(ref r) => r, | ||||
|             EmptyReader(ref r) => r, | ||||
|         } | ||||
|     } | ||||
|  | ||||
|     /// Gets a mutable reference to the underlying Reader. | ||||
|     pub fn get_mut(&mut self) -> &mut R { | ||||
|         match *self { | ||||
|   | ||||
| @@ -4,6 +4,8 @@ use std::io::{self, Write, Read, Cursor}; | ||||
| use std::net::Shutdown; | ||||
| use std::ascii::AsciiExt; | ||||
| use std::mem; | ||||
| #[cfg(feature = "timeouts")] | ||||
| use std::time::Duration; | ||||
|  | ||||
| use http::{ | ||||
|     Protocol, | ||||
| @@ -398,6 +400,19 @@ impl<S> HttpMessage for Http2Message<S> where S: CloneableStream { | ||||
|         Ok(head) | ||||
|     } | ||||
|  | ||||
|     #[cfg(feature = "timeouts")] | ||||
|     #[inline] | ||||
|     fn set_read_timeout(&self, _dur: Option<Duration>) -> io::Result<()> { | ||||
|         Ok(()) | ||||
|     } | ||||
|  | ||||
|     #[cfg(feature = "timeouts")] | ||||
|     #[inline] | ||||
|     fn set_write_timeout(&self, _dur: Option<Duration>) -> io::Result<()> { | ||||
|         Ok(()) | ||||
|     } | ||||
|  | ||||
|     #[inline] | ||||
|     fn close_connection(&mut self) -> ::Result<()> { | ||||
|         Ok(()) | ||||
|     } | ||||
|   | ||||
| @@ -1,12 +1,16 @@ | ||||
| //! Defines the `HttpMessage` trait that serves to encapsulate the operations of a single | ||||
| //! request-response cycle on any HTTP connection. | ||||
|  | ||||
| use std::fmt::Debug; | ||||
| use std::any::{Any, TypeId}; | ||||
| use std::fmt::Debug; | ||||
| use std::io::{Read, Write}; | ||||
|  | ||||
| use std::mem; | ||||
|  | ||||
| #[cfg(feature = "timeouts")] | ||||
| use std::io; | ||||
| #[cfg(feature = "timeouts")] | ||||
| use std::time::Duration; | ||||
|  | ||||
| use typeable::Typeable; | ||||
|  | ||||
| use header::Headers; | ||||
| @@ -62,7 +66,10 @@ pub trait HttpMessage: Write + Read + Send + Any + Typeable + Debug { | ||||
|     fn get_incoming(&mut self) -> ::Result<ResponseHead>; | ||||
|     /// Set the read timeout duration for this message. | ||||
|     #[cfg(feature = "timeouts")] | ||||
|     fn set_read_timeout(&self, dur: Option<Duration>) -> ::Result<()>; | ||||
|     fn set_read_timeout(&self, dur: Option<Duration>) -> io::Result<()>; | ||||
|     /// Set the write timeout duration for this message. | ||||
|     #[cfg(feature = "timeouts")] | ||||
|     fn set_write_timeout(&self, dur: Option<Duration>) -> io::Result<()>; | ||||
|     /// Closes the underlying HTTP connection. | ||||
|     fn close_connection(&mut self) -> ::Result<()>; | ||||
| } | ||||
|   | ||||
| @@ -2,6 +2,7 @@ | ||||
| #![cfg_attr(test, deny(missing_docs))] | ||||
| #![cfg_attr(test, deny(warnings))] | ||||
| #![cfg_attr(all(test, feature = "nightly"), feature(test))] | ||||
| #![cfg_attr(feature = "timeouts", feature(duration, socket_timeout))] | ||||
|  | ||||
| //! # Hyper | ||||
| //! | ||||
|   | ||||
							
								
								
									
										62
									
								
								src/mock.rs
									
									
									
									
									
								
							
							
						
						
									
										62
									
								
								src/mock.rs
									
									
									
									
									
								
							| @@ -4,6 +4,10 @@ use std::io::{self, Read, Write, Cursor}; | ||||
| use std::cell::RefCell; | ||||
| use std::net::SocketAddr; | ||||
| use std::sync::{Arc, Mutex}; | ||||
| #[cfg(feature = "timeouts")] | ||||
| use std::time::Duration; | ||||
| #[cfg(feature = "timeouts")] | ||||
| use std::cell::Cell; | ||||
|  | ||||
| use solicit::http::HttpScheme; | ||||
| use solicit::http::transport::TransportStream; | ||||
| @@ -13,18 +17,14 @@ use solicit::http::connection::{HttpConnection, EndStream, DataChunk}; | ||||
| use header::Headers; | ||||
| use net::{NetworkStream, NetworkConnector}; | ||||
|  | ||||
| #[derive(Clone)] | ||||
| pub struct MockStream { | ||||
|     pub read: Cursor<Vec<u8>>, | ||||
|     pub write: Vec<u8>, | ||||
| } | ||||
|  | ||||
| impl Clone for MockStream { | ||||
|     fn clone(&self) -> MockStream { | ||||
|         MockStream { | ||||
|             read: Cursor::new(self.read.get_ref().clone()), | ||||
|             write: self.write.clone() | ||||
|         } | ||||
|     } | ||||
|     #[cfg(feature = "timeouts")] | ||||
|     pub read_timeout: Cell<Option<Duration>>, | ||||
|     #[cfg(feature = "timeouts")] | ||||
|     pub write_timeout: Cell<Option<Duration>> | ||||
| } | ||||
|  | ||||
| impl fmt::Debug for MockStream { | ||||
| @@ -41,18 +41,26 @@ impl PartialEq for MockStream { | ||||
|  | ||||
| impl MockStream { | ||||
|     pub fn new() -> MockStream { | ||||
|         MockStream { | ||||
|             read: Cursor::new(vec![]), | ||||
|             write: vec![], | ||||
|         } | ||||
|         MockStream::with_input(b"") | ||||
|     } | ||||
|  | ||||
|     #[cfg(not(feature = "timeouts"))] | ||||
|     pub fn with_input(input: &[u8]) -> MockStream { | ||||
|         MockStream { | ||||
|             read: Cursor::new(input.to_vec()), | ||||
|             write: vec![] | ||||
|         } | ||||
|     } | ||||
|  | ||||
|     #[cfg(feature = "timeouts")] | ||||
|     pub fn with_input(input: &[u8]) -> MockStream { | ||||
|         MockStream { | ||||
|             read: Cursor::new(input.to_vec()), | ||||
|             write: vec![], | ||||
|             read_timeout: Cell::new(None), | ||||
|             write_timeout: Cell::new(None), | ||||
|         } | ||||
|     } | ||||
| } | ||||
|  | ||||
| impl Read for MockStream { | ||||
| @@ -75,6 +83,18 @@ impl NetworkStream for MockStream { | ||||
|     fn peer_addr(&mut self) -> io::Result<SocketAddr> { | ||||
|         Ok("127.0.0.1:1337".parse().unwrap()) | ||||
|     } | ||||
|  | ||||
|     #[cfg(feature = "timeouts")] | ||||
|     fn set_read_timeout(&self, dur: Option<Duration>) -> io::Result<()> { | ||||
|         self.read_timeout.set(dur); | ||||
|         Ok(()) | ||||
|     } | ||||
|  | ||||
|     #[cfg(feature = "timeouts")] | ||||
|     fn set_write_timeout(&self, dur: Option<Duration>) -> io::Result<()> { | ||||
|         self.write_timeout.set(dur); | ||||
|         Ok(()) | ||||
|     } | ||||
| } | ||||
|  | ||||
| /// A wrapper around a `MockStream` that allows one to clone it and keep an independent copy to the | ||||
| @@ -114,6 +134,16 @@ impl NetworkStream for CloneableMockStream { | ||||
|     fn peer_addr(&mut self) -> io::Result<SocketAddr> { | ||||
|         self.inner.lock().unwrap().peer_addr() | ||||
|     } | ||||
|  | ||||
|     #[cfg(feature = "timeouts")] | ||||
|     fn set_read_timeout(&self, dur: Option<Duration>) -> io::Result<()> { | ||||
|         self.inner.lock().unwrap().set_read_timeout(dur) | ||||
|     } | ||||
|  | ||||
|     #[cfg(feature = "timeouts")] | ||||
|     fn set_write_timeout(&self, dur: Option<Duration>) -> io::Result<()> { | ||||
|         self.inner.lock().unwrap().set_write_timeout(dur) | ||||
|     } | ||||
| } | ||||
|  | ||||
| impl CloneableMockStream { | ||||
| @@ -147,7 +177,6 @@ macro_rules! mock_connector ( | ||||
|             fn connect(&self, host: &str, port: u16, scheme: &str) | ||||
|                     -> $crate::Result<::mock::MockStream> { | ||||
|                 use std::collections::HashMap; | ||||
|                 use std::io::Cursor; | ||||
|                 debug!("MockStream::connect({:?}, {:?}, {:?})", host, port, scheme); | ||||
|                 let mut map = HashMap::new(); | ||||
|                 $(map.insert($url, $res);)* | ||||
| @@ -156,10 +185,7 @@ macro_rules! mock_connector ( | ||||
|                 let key = format!("{}://{}", scheme, host); | ||||
|                 // ignore port for now | ||||
|                 match map.get(&*key) { | ||||
|                     Some(&res) => Ok($crate::mock::MockStream { | ||||
|                         write: vec![], | ||||
|                         read: Cursor::new(res.to_owned().into_bytes()), | ||||
|                     }), | ||||
|                     Some(&res) => Ok($crate::mock::MockStream::with_input(res.as_bytes())), | ||||
|                     None => panic!("{:?} doesn't know url {}", stringify!($name), key) | ||||
|                 } | ||||
|             } | ||||
|   | ||||
							
								
								
									
										59
									
								
								src/net.rs
									
									
									
									
									
								
							
							
						
						
									
										59
									
								
								src/net.rs
									
									
									
									
									
								
							| @@ -8,6 +8,9 @@ use std::mem; | ||||
| #[cfg(feature = "openssl")] | ||||
| pub use self::openssl::Openssl; | ||||
|  | ||||
| #[cfg(feature = "timeouts")] | ||||
| use std::time::Duration; | ||||
|  | ||||
| use typeable::Typeable; | ||||
| use traitobject; | ||||
|  | ||||
| @@ -21,8 +24,6 @@ pub enum Streaming {} | ||||
| pub trait NetworkListener: Clone { | ||||
|     /// The stream produced for each connection. | ||||
|     type Stream: NetworkStream + Send + Clone; | ||||
|     /// Listens on a socket. | ||||
|     //fn listen<To: ToSocketAddrs>(&mut self, addr: To) -> io::Result<Self::Acceptor>; | ||||
|  | ||||
|     /// Returns an iterator of streams. | ||||
|     fn accept(&mut self) -> ::Result<Self::Stream>; | ||||
| @@ -30,9 +31,6 @@ pub trait NetworkListener: Clone { | ||||
|     /// Get the address this Listener ended up listening on. | ||||
|     fn local_addr(&mut self) -> io::Result<SocketAddr>; | ||||
|  | ||||
|     /// Closes the Acceptor, so no more incoming connections will be handled. | ||||
| //    fn close(&mut self) -> io::Result<()>; | ||||
|  | ||||
|     /// Returns an iterator over incoming connections. | ||||
|     fn incoming(&mut self) -> NetworkConnections<Self> { | ||||
|         NetworkConnections(self) | ||||
| @@ -53,6 +51,12 @@ impl<'a, N: NetworkListener + 'a> Iterator for NetworkConnections<'a, N> { | ||||
| pub trait NetworkStream: Read + Write + Any + Send + Typeable { | ||||
|     /// Get the remote address of the underlying connection. | ||||
|     fn peer_addr(&mut self) -> io::Result<SocketAddr>; | ||||
|     /// Set the maximum time to wait for a read to complete. | ||||
|     #[cfg(feature = "timeouts")] | ||||
|     fn set_read_timeout(&self, dur: Option<Duration>) -> io::Result<()>; | ||||
|     /// Set the maximum time to wait for a write to complete. | ||||
|     #[cfg(feature = "timeouts")] | ||||
|     fn set_write_timeout(&self, dur: Option<Duration>) -> io::Result<()>; | ||||
|     /// This will be called when Stream should no longer be kept alive. | ||||
|     #[inline] | ||||
|     fn close(&mut self, _how: Shutdown) -> io::Result<()> { | ||||
| @@ -222,6 +226,18 @@ impl NetworkStream for HttpStream { | ||||
|             self.0.peer_addr() | ||||
|     } | ||||
|  | ||||
|     #[cfg(feature = "timeouts")] | ||||
|     #[inline] | ||||
|     fn set_read_timeout(&self, dur: Option<Duration>) -> io::Result<()> { | ||||
|         self.0.set_read_timeout(dur) | ||||
|     } | ||||
|  | ||||
|     #[cfg(feature = "timeouts")] | ||||
|     #[inline] | ||||
|     fn set_write_timeout(&self, dur: Option<Duration>) -> io::Result<()> { | ||||
|         self.0.set_write_timeout(dur) | ||||
|     } | ||||
|  | ||||
|     #[inline] | ||||
|     fn close(&mut self, how: Shutdown) -> io::Result<()> { | ||||
|         match self.0.shutdown(how) { | ||||
| @@ -340,6 +356,24 @@ impl<S: NetworkStream> NetworkStream for HttpsStream<S> { | ||||
|         } | ||||
|     } | ||||
|  | ||||
|     #[cfg(feature = "timeouts")] | ||||
|     #[inline] | ||||
|     fn set_read_timeout(&self, dur: Option<Duration>) -> io::Result<()> { | ||||
|         match *self { | ||||
|             HttpsStream::Http(ref inner) => inner.0.set_read_timeout(dur), | ||||
|             HttpsStream::Https(ref inner) => inner.set_read_timeout(dur) | ||||
|         } | ||||
|     } | ||||
|  | ||||
|     #[cfg(feature = "timeouts")] | ||||
|     #[inline] | ||||
|     fn set_write_timeout(&self, dur: Option<Duration>) -> io::Result<()> { | ||||
|         match *self { | ||||
|             HttpsStream::Http(ref inner) => inner.0.set_read_timeout(dur), | ||||
|             HttpsStream::Https(ref inner) => inner.set_read_timeout(dur) | ||||
|         } | ||||
|     } | ||||
|  | ||||
|     #[inline] | ||||
|     fn close(&mut self, how: Shutdown) -> io::Result<()> { | ||||
|         match *self { | ||||
| @@ -425,6 +459,9 @@ mod openssl { | ||||
|     use std::net::{SocketAddr, Shutdown}; | ||||
|     use std::path::Path; | ||||
|     use std::sync::Arc; | ||||
|     #[cfg(feature = "timeouts")] | ||||
|     use std::time::Duration; | ||||
|  | ||||
|     use openssl::ssl::{Ssl, SslContext, SslStream, SslMethod, SSL_VERIFY_NONE}; | ||||
|     use openssl::ssl::error::StreamError as SslIoError; | ||||
|     use openssl::ssl::error::SslError; | ||||
| @@ -503,6 +540,18 @@ mod openssl { | ||||
|             self.get_mut().peer_addr() | ||||
|         } | ||||
|  | ||||
|         #[cfg(feature = "timeouts")] | ||||
|         #[inline] | ||||
|         fn set_read_timeout(&self, dur: Option<Duration>) -> io::Result<()> { | ||||
|             self.get_ref().set_read_timeout(dur) | ||||
|         } | ||||
|  | ||||
|         #[cfg(feature = "timeouts")] | ||||
|         #[inline] | ||||
|         fn set_write_timeout(&self, dur: Option<Duration>) -> io::Result<()> { | ||||
|             self.get_ref().set_write_timeout(dur) | ||||
|         } | ||||
|  | ||||
|         fn close(&mut self, how: Shutdown) -> io::Result<()> { | ||||
|             self.get_mut().close(how) | ||||
|         } | ||||
|   | ||||
| @@ -108,10 +108,13 @@ | ||||
| //! `Request<Streaming>` object, that no longer has `headers_mut()`, but does | ||||
| //! implement `Write`. | ||||
| use std::fmt; | ||||
| use std::io::{ErrorKind, BufWriter, Write}; | ||||
| use std::io::{self, ErrorKind, BufWriter, Write}; | ||||
| use std::net::{SocketAddr, ToSocketAddrs}; | ||||
| use std::thread::{self, JoinHandle}; | ||||
|  | ||||
| #[cfg(feature = "timeouts")] | ||||
| use std::time::Duration; | ||||
|  | ||||
| use num_cpus; | ||||
|  | ||||
| pub use self::request::Request; | ||||
| @@ -143,8 +146,20 @@ mod listener; | ||||
| #[derive(Debug)] | ||||
| pub struct Server<L = HttpListener> { | ||||
|     listener: L, | ||||
|     _timeouts: Timeouts, | ||||
| } | ||||
|  | ||||
| #[cfg(feature = "timeouts")] | ||||
| #[derive(Clone, Copy, Default, Debug)] | ||||
| struct Timeouts { | ||||
|     read: Option<Duration>, | ||||
|     write: Option<Duration>, | ||||
| } | ||||
|  | ||||
| #[cfg(not(feature = "timeouts"))] | ||||
| #[derive(Clone, Copy, Default, Debug)] | ||||
| struct Timeouts; | ||||
|  | ||||
| macro_rules! try_option( | ||||
|     ($e:expr) => {{ | ||||
|         match $e { | ||||
| @@ -159,9 +174,22 @@ impl<L: NetworkListener> Server<L> { | ||||
|     #[inline] | ||||
|     pub fn new(listener: L) -> Server<L> { | ||||
|         Server { | ||||
|             listener: listener | ||||
|             listener: listener, | ||||
|             _timeouts: Timeouts::default(), | ||||
|         } | ||||
|     } | ||||
|  | ||||
|     #[cfg(feature = "timeouts")] | ||||
|     pub fn set_read_timeout(&mut self, dur: Option<Duration>) { | ||||
|         self._timeouts.read = dur; | ||||
|     } | ||||
|  | ||||
|     #[cfg(feature = "timeouts")] | ||||
|     pub fn set_write_timeout(&mut self, dur: Option<Duration>) { | ||||
|         self._timeouts.write = dur; | ||||
|     } | ||||
|  | ||||
|  | ||||
| } | ||||
|  | ||||
| impl Server<HttpListener> { | ||||
| @@ -183,24 +211,25 @@ impl<S: Ssl + Clone + Send> Server<HttpsListener<S>> { | ||||
| impl<L: NetworkListener + Send + 'static> Server<L> { | ||||
|     /// Binds to a socket and starts handling connections. | ||||
|     pub fn handle<H: Handler + 'static>(self, handler: H) -> ::Result<Listening> { | ||||
|         with_listener(handler, self.listener, num_cpus::get() * 5 / 4) | ||||
|         self.handle_threads(handler, num_cpus::get() * 5 / 4) | ||||
|     } | ||||
|     /// Binds to a socket and starts handling connections with the provided | ||||
|     /// number of threads. | ||||
|     pub fn handle_threads<H: Handler + 'static>(self, handler: H, | ||||
|             threads: usize) -> ::Result<Listening> { | ||||
|         with_listener(handler, self.listener, threads) | ||||
|         handle(self, handler, threads) | ||||
|     } | ||||
| } | ||||
|  | ||||
| fn with_listener<H, L>(handler: H, mut listener: L, threads: usize) -> ::Result<Listening> | ||||
| fn handle<H, L>(mut server: Server<L>, handler: H, threads: usize) -> ::Result<Listening> | ||||
| where H: Handler + 'static, | ||||
| L: NetworkListener + Send + 'static { | ||||
|     let socket = try!(listener.local_addr()); | ||||
|     let socket = try!(server.listener.local_addr()); | ||||
|  | ||||
|     debug!("threads = {:?}", threads); | ||||
|     let pool = ListenerPool::new(listener); | ||||
|     let work = move |mut stream| Worker(&handler).handle_connection(&mut stream); | ||||
|     let pool = ListenerPool::new(server.listener); | ||||
|     let worker = Worker::new(handler, server._timeouts); | ||||
|     let work = move |mut stream| worker.handle_connection(&mut stream); | ||||
|  | ||||
|     let guard = thread::spawn(move || pool.accept(work, threads)); | ||||
|  | ||||
| @@ -210,12 +239,28 @@ L: NetworkListener + Send + 'static { | ||||
|     }) | ||||
| } | ||||
|  | ||||
| struct Worker<'a, H: Handler + 'static>(&'a H); | ||||
| struct Worker<H: Handler + 'static> { | ||||
|     handler: H, | ||||
|     _timeouts: Timeouts, | ||||
| } | ||||
|  | ||||
| impl<'a, H: Handler + 'static> Worker<'a, H> { | ||||
| impl<H: Handler + 'static> Worker<H> { | ||||
|  | ||||
|     fn new(handler: H, timeouts: Timeouts) -> Worker<H> { | ||||
|         Worker { | ||||
|             handler: handler, | ||||
|             _timeouts: timeouts, | ||||
|         } | ||||
|     } | ||||
|  | ||||
|     fn handle_connection<S>(&self, mut stream: &mut S) where S: NetworkStream + Clone { | ||||
|         debug!("Incoming stream"); | ||||
|  | ||||
|         if let Err(e) = self.set_timeouts(stream) { | ||||
|             error!("set_timeouts error: {:?}", e); | ||||
|             return; | ||||
|         } | ||||
|  | ||||
|         let addr = match stream.peer_addr() { | ||||
|             Ok(addr) => addr, | ||||
|             Err(e) => { | ||||
| @@ -233,6 +278,17 @@ impl<'a, H: Handler + 'static> Worker<'a, H> { | ||||
|         debug!("keep_alive loop ending for {}", addr); | ||||
|     } | ||||
|  | ||||
|     #[cfg(not(feature = "timeouts"))] | ||||
|     fn set_timeouts<S>(&self, _: &mut S) -> io::Result<()> where S: NetworkStream { | ||||
|         Ok(()) | ||||
|     } | ||||
|  | ||||
|     #[cfg(feature = "timeouts")] | ||||
|     fn set_timeouts<S>(&self, s: &mut S) -> io::Result<()> where S: NetworkStream { | ||||
|         try!(s.set_read_timeout(self._timeouts.read)); | ||||
|         s.set_write_timeout(self._timeouts.write) | ||||
|     } | ||||
|  | ||||
|     fn keep_alive_loop<W: Write>(&self, mut rdr: BufReader<&mut NetworkStream>, | ||||
|             mut wrt: W, addr: SocketAddr) { | ||||
|         let mut keep_alive = true; | ||||
| @@ -268,7 +324,7 @@ impl<'a, H: Handler + 'static> Worker<'a, H> { | ||||
|             { | ||||
|                 let mut res = Response::new(&mut wrt, &mut res_headers); | ||||
|                 res.version = version; | ||||
|                 self.0.handle(req, res); | ||||
|                 self.handler.handle(req, res); | ||||
|             } | ||||
|  | ||||
|             // if the request was keep-alive, we need to check that the server agrees | ||||
| @@ -284,7 +340,7 @@ impl<'a, H: Handler + 'static> Worker<'a, H> { | ||||
|  | ||||
|     fn handle_expect<W: Write>(&self, req: &Request, wrt: &mut W) -> bool { | ||||
|          if req.version == Http11 && req.headers.get() == Some(&Expect::Continue) { | ||||
|             let status = self.0.check_continue((&req.method, &req.uri, &req.headers)); | ||||
|             let status = self.handler.check_continue((&req.method, &req.uri, &req.headers)); | ||||
|             match write!(wrt, "{} {}\r\n\r\n", Http11, status) { | ||||
|                 Ok(..) => (), | ||||
|                 Err(e) => { | ||||
| @@ -327,7 +383,6 @@ impl Listening { | ||||
|     pub fn close(&mut self) -> ::Result<()> { | ||||
|         let _ = self._guard.take(); | ||||
|         debug!("closing server"); | ||||
|         //try!(self.acceptor.close()); | ||||
|         Ok(()) | ||||
|     } | ||||
| } | ||||
| @@ -379,7 +434,7 @@ mod tests { | ||||
|             res.start().unwrap().end().unwrap(); | ||||
|         } | ||||
|  | ||||
|         Worker(&handle).handle_connection(&mut mock); | ||||
|         Worker::new(handle, Default::default()).handle_connection(&mut mock); | ||||
|         let cont = b"HTTP/1.1 100 Continue\r\n\r\n"; | ||||
|         assert_eq!(&mock.write[..cont.len()], cont); | ||||
|         let res = b"HTTP/1.1 200 OK\r\n"; | ||||
| @@ -408,7 +463,7 @@ mod tests { | ||||
|             1234567890\ | ||||
|         "); | ||||
|  | ||||
|         Worker(&Reject).handle_connection(&mut mock); | ||||
|         Worker::new(Reject, Default::default()).handle_connection(&mut mock); | ||||
|         assert_eq!(mock.write, &b"HTTP/1.1 417 Expectation Failed\r\n\r\n"[..]); | ||||
|     } | ||||
| } | ||||
|   | ||||
		Reference in New Issue
	
	Block a user