From 0ab52c900975bbb26358dfd3e758f27cf1b6e957 Mon Sep 17 00:00:00 2001 From: Sean McArthur Date: Wed, 24 Sep 2014 17:31:56 -0700 Subject: [PATCH] add bits to deal with Upgrade requests --- src/client/response.rs | 35 +++++++++++ src/header/common/connection.rs | 47 +++++++++----- src/header/common/mod.rs | 73 ++++++++++++++++------ src/header/common/transfer_encoding.rs | 41 ++++++------- src/header/common/upgrade.rs | 51 ++++++++++++++++ src/http.rs | 12 ++++ src/lib.rs | 1 + src/mock.rs | 29 +++++++++ src/net.rs | 85 +++++++++++++++++++++++++- 9 files changed, 314 insertions(+), 60 deletions(-) create mode 100644 src/header/common/upgrade.rs create mode 100644 src/mock.rs diff --git a/src/client/response.rs b/src/client/response.rs index adc2f631..27a3fea9 100644 --- a/src/client/response.rs +++ b/src/client/response.rs @@ -65,6 +65,11 @@ impl Response { body: body, }) } + + /// Unwraps the Request to return the NetworkStream underneath. + pub fn unwrap(self) -> Box { + self.body.unwrap().unwrap() + } } impl Reader for Response { @@ -73,3 +78,33 @@ impl Reader for Response { self.body.read(buf) } } + +#[cfg(test)] +mod tests { + use std::boxed::BoxAny; + use std::io::BufferedReader; + + use header::Headers; + use http::EofReader; + use mock::MockStream; + use net::NetworkStream; + use status; + use version; + + use super::Response; + + + #[test] + fn test_unwrap() { + let res = Response { + status: status::Ok, + headers: Headers::new(), + version: version::Http11, + body: EofReader(BufferedReader::new(box MockStream as Box)) + }; + + let b = res.unwrap().downcast::().unwrap(); + assert_eq!(b, box MockStream); + + } +} diff --git a/src/header/common/connection.rs b/src/header/common/connection.rs index 63af0541..cb20fe4b 100644 --- a/src/header/common/connection.rs +++ b/src/header/common/connection.rs @@ -1,45 +1,62 @@ use header::Header; use std::fmt::{mod, Show}; -use super::util::from_one_raw_str; +use super::{from_comma_delimited, fmt_comma_delimited}; use std::from_str::FromStr; /// The `Connection` header. -/// -/// Describes whether the socket connection should be closed or reused after -/// this request/response is completed. #[deriving(Clone, PartialEq, Show)] -pub enum Connection { +pub struct Connection(Vec); + +/// Values that can be in the `Connection` header. +#[deriving(Clone, PartialEq)] +pub enum ConnectionOption { /// The `keep-alive` connection value. KeepAlive, /// The `close` connection value. - Close + Close, + /// Values in the Connection header that are supposed to be names of other Headers. + /// + /// > When a header field aside from Connection is used to supply control + /// > information for or about the current connection, the sender MUST list + /// > the corresponding field-name within the Connection header field. + // TODO: it would be nice if these "Strings" could be stronger types, since + // they are supposed to relate to other Header fields (which we have strong + // types for). + ConnectionHeader(String), } -impl FromStr for Connection { - fn from_str(s: &str) -> Option { - debug!("Connection::from_str =? {}", s); +impl FromStr for ConnectionOption { + fn from_str(s: &str) -> Option { match s { "keep-alive" => Some(KeepAlive), "close" => Some(Close), - _ => None + s => Some(ConnectionHeader(s.to_string())) } } } +impl fmt::Show for ConnectionOption { + fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result { + match *self { + KeepAlive => "keep-alive", + Close => "close", + ConnectionHeader(ref s) => s.as_slice() + }.fmt(fmt) + } +} + impl Header for Connection { fn header_name(_: Option) -> &'static str { "Connection" } fn parse_header(raw: &[Vec]) -> Option { - from_one_raw_str(raw) + from_comma_delimited(raw).map(|vec| Connection(vec)) } fn fmt_header(&self, fmt: &mut fmt::Formatter) -> fmt::Result { - match *self { - KeepAlive => "keep-alive", - Close => "close", - }.fmt(fmt) + let Connection(ref parts) = *self; + fmt_comma_delimited(fmt, parts[]) } } diff --git a/src/header/common/mod.rs b/src/header/common/mod.rs index ecb9a3fa..e7fad24f 100644 --- a/src/header/common/mod.rs +++ b/src/header/common/mod.rs @@ -6,25 +6,21 @@ //! strongly-typed theme, the [mime](http://seanmonstar.github.io/mime.rs) crate //! is used, such as `ContentType(pub Mime)`. -pub use self::host::Host; -pub use self::content_length::ContentLength; -pub use self::content_type::ContentType; pub use self::accept::Accept; pub use self::connection::Connection; +pub use self::content_length::ContentLength; +pub use self::content_type::ContentType; +pub use self::date::Date; +pub use self::host::Host; +pub use self::location::Location; pub use self::transfer_encoding::TransferEncoding; +pub use self::upgrade::Upgrade; pub use self::user_agent::UserAgent; pub use self::server::Server; -pub use self::date::Date; -pub use self::location::Location; -/// Exposes the Host header. -pub mod host; - -/// Exposes the ContentLength header. -pub mod content_length; - -/// Exposes the ContentType header. -pub mod content_type; +use std::fmt::{mod, Show}; +use std::from_str::FromStr; +use std::str::from_utf8; /// Exposes the Accept header. pub mod accept; @@ -32,17 +28,30 @@ pub mod accept; /// Exposes the Connection header. pub mod connection; -/// Exposes the TransferEncoding header. -pub mod transfer_encoding; +/// Exposes the ContentLength header. +pub mod content_length; -/// Exposes the UserAgent header. -pub mod user_agent; +/// Exposes the ContentType header. +pub mod content_type; + +/// Exposes the Date header. +pub mod date; + +/// Exposes the Host header. +pub mod host; /// Exposes the Server header. pub mod server; -/// Exposes the Date header. -pub mod date; +/// Exposes the TransferEncoding header. +pub mod transfer_encoding; + +/// Exposes the Upgrade header. +pub mod upgrade; + +/// Exposes the UserAgent header. +pub mod user_agent; + /// Exposes the Location header. pub mod location; @@ -50,3 +59,29 @@ pub mod location; pub mod util; +fn from_comma_delimited(raw: &[Vec]) -> Option> { + if raw.len() != 1 { + return None; + } + // we JUST checked that raw.len() == 1, so raw[0] WILL exist. + match from_utf8(unsafe { raw.as_slice().unsafe_get(0).as_slice() }) { + Some(s) => { + Some(s.as_slice() + .split([',', ' '].as_slice()) + .filter_map(from_str) + .collect()) + } + None => None + } +} + +fn fmt_comma_delimited(fmt: &mut fmt::Formatter, parts: &[T]) -> fmt::Result { + let last = parts.len() - 1; + for (i, part) in parts.iter().enumerate() { + try!(part.fmt(fmt)); + if i < last { + try!(", ".fmt(fmt)); + } + } + Ok(()) +} diff --git a/src/header/common/transfer_encoding.rs b/src/header/common/transfer_encoding.rs index 7731557d..703eb1fc 100644 --- a/src/header/common/transfer_encoding.rs +++ b/src/header/common/transfer_encoding.rs @@ -1,7 +1,7 @@ use header::Header; -use std::fmt::{mod, Show}; +use std::fmt; use std::from_str::FromStr; -use std::str::from_utf8; +use super::{from_comma_delimited, fmt_comma_delimited}; /// The `Transfer-Encoding` header. /// @@ -28,7 +28,7 @@ pub struct TransferEncoding(pub Vec); /// # use hyper::header::Headers; /// # let mut headers = Headers::new(); /// headers.set(TransferEncoding(vec![Gzip, Chunked])); -#[deriving(Clone, PartialEq, Show)] +#[deriving(Clone, PartialEq)] pub enum Encoding { /// The `chunked` encoding. Chunked, @@ -43,6 +43,18 @@ pub enum Encoding { EncodingExt(String) } +impl fmt::Show for Encoding { + fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result { + match *self { + Chunked => "chunked", + Gzip => "gzip", + Deflate => "deflate", + Compress => "compress", + EncodingExt(ref s) => s.as_slice() + }.fmt(fmt) + } +} + impl FromStr for Encoding { fn from_str(s: &str) -> Option { match s { @@ -61,31 +73,12 @@ impl Header for TransferEncoding { } fn parse_header(raw: &[Vec]) -> Option { - if raw.len() != 1 { - return None; - } - // we JUST checked that raw.len() == 1, so raw[0] WILL exist. - match from_utf8(unsafe { raw.as_slice().unsafe_get(0).as_slice() }) { - Some(s) => { - Some(TransferEncoding(s.as_slice() - .split([',', ' '].as_slice()) - .filter_map(from_str) - .collect())) - } - None => None - } + from_comma_delimited(raw).map(|vec| TransferEncoding(vec)) } fn fmt_header(&self, fmt: &mut fmt::Formatter) -> fmt::Result { let TransferEncoding(ref parts) = *self; - let last = parts.len() - 1; - for (i, part) in parts.iter().enumerate() { - try!(part.fmt(fmt)); - if i < last { - try!(", ".fmt(fmt)); - } - } - Ok(()) + fmt_comma_delimited(fmt, parts[]) } } diff --git a/src/header/common/upgrade.rs b/src/header/common/upgrade.rs new file mode 100644 index 00000000..fab3bc2c --- /dev/null +++ b/src/header/common/upgrade.rs @@ -0,0 +1,51 @@ +use header::Header; +use std::fmt::{mod, Show}; +use super::{from_comma_delimited, fmt_comma_delimited}; +use std::from_str::FromStr; + +/// The `Upgrade` header. +#[deriving(Clone, PartialEq, Show)] +pub struct Upgrade(Vec); + +/// Protocol values that can appear in the Upgrade header. +#[deriving(Clone, PartialEq)] +pub enum Protocol { + /// The websocket protocol. + WebSocket, + /// Some other less common protocol. + ProtocolExt(String), +} + +impl FromStr for Protocol { + fn from_str(s: &str) -> Option { + match s { + "websocket" => Some(WebSocket), + s => Some(ProtocolExt(s.to_string())) + } + } +} + +impl fmt::Show for Protocol { + fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result { + match *self { + WebSocket => "websocket", + ProtocolExt(ref s) => s.as_slice() + }.fmt(fmt) + } +} + +impl Header for Upgrade { + fn header_name(_: Option) -> &'static str { + "Upgrade" + } + + fn parse_header(raw: &[Vec]) -> Option { + from_comma_delimited(raw).map(|vec| Upgrade(vec)) + } + + fn fmt_header(&self, fmt: &mut fmt::Formatter) -> fmt::Result { + let Upgrade(ref parts) = *self; + fmt_comma_delimited(fmt, parts[]) + } +} + diff --git a/src/http.rs b/src/http.rs index 515483ac..b87b56ae 100644 --- a/src/http.rs +++ b/src/http.rs @@ -38,6 +38,18 @@ pub enum HttpReader { EofReader(R), } +impl HttpReader { + + /// Unwraps this HttpReader and returns the underlying Reader. + pub fn unwrap(self) -> R { + match self { + SizedReader(r, _) => r, + ChunkedReader(r, _) => r, + EofReader(r) => r, + } + } +} + impl Reader for HttpReader { fn read(&mut self, buf: &mut [u8]) -> IoResult { match *self { diff --git a/src/lib.rs b/src/lib.rs index 485c671f..0c3db3a9 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -184,6 +184,7 @@ pub mod status; pub mod uri; pub mod version; +#[cfg(test)] mod mock; mod mimewrapper { /// Re-exporting the mime crate, for convenience. diff --git a/src/mock.rs b/src/mock.rs new file mode 100644 index 00000000..12851991 --- /dev/null +++ b/src/mock.rs @@ -0,0 +1,29 @@ +use std::io::IoResult; +use std::io::net::ip::SocketAddr; + +use net::NetworkStream; + +#[deriving(Clone, PartialEq, Show)] +pub struct MockStream; + +impl Reader for MockStream { + fn read(&mut self, _buf: &mut [u8]) -> IoResult { + unimplemented!() + } +} + +impl Writer for MockStream { + fn write(&mut self, _msg: &[u8]) -> IoResult<()> { + unimplemented!() + } +} + +impl NetworkStream for MockStream { + fn connect(_host: &str, _port: u16, _scheme: &str) -> IoResult { + Ok(MockStream) + } + + fn peer_name(&mut self) -> IoResult { + Ok(from_str("127.0.0.1:1337").unwrap()) + } +} diff --git a/src/net.rs b/src/net.rs index 9465bd53..6a599f4f 100644 --- a/src/net.rs +++ b/src/net.rs @@ -1,10 +1,18 @@ //! A collection of traits abstracting over Listeners and Streams. +use std::any::{Any, AnyRefExt}; +use std::boxed::BoxAny; +use std::fmt; +use std::intrinsics::TypeId; use std::io::{IoResult, IoError, ConnectionAborted, InvalidInput, OtherIoError, Stream, Listener, Acceptor}; use std::io::net::ip::{SocketAddr, Port}; use std::io::net::tcp::{TcpStream, TcpListener, TcpAcceptor}; +use std::mem::{mod, transmute, transmute_copy}; +use std::raw::{mod, TraitObject}; use std::sync::{Arc, Mutex}; +use uany::UncheckedBoxAnyDowncast; +use typeable::Typeable; use openssl::ssl::{SslStream, SslContext, Sslv23}; use openssl::ssl::error::{SslError, StreamError, OpenSslErrors, SslSessionClosed}; @@ -15,7 +23,7 @@ pub struct Fresh; pub struct Streaming; /// An abstraction to listen for connections on a certain port. -pub trait NetworkListener>: Listener { +pub trait NetworkListener>: Listener + Typeable { /// Bind to a socket. /// /// Note: This does not start listening for connections. You must call @@ -33,7 +41,7 @@ pub trait NetworkAcceptor: Acceptor + Clone + Send { } /// An abstraction over streams that a Server can utilize. -pub trait NetworkStream: Stream + Clone + Send { +pub trait NetworkStream: Stream + Any + Clone + Send { /// Get the remote address of the underlying connection. fn peer_name(&mut self) -> IoResult; @@ -52,6 +60,12 @@ pub trait NetworkStream: Stream + Clone + Send { fn clone_box(&self) -> Box { self.clone().dynamic() } } +impl fmt::Show for Box { + fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result { + fmt.pad("Box") + } +} + impl Clone for Box { #[inline] fn clone(&self) -> Box { self.clone_box() } @@ -70,6 +84,46 @@ impl Writer for Box { fn flush(&mut self) -> IoResult<()> { (**self).flush() } } +impl UncheckedBoxAnyDowncast for Box { + unsafe fn downcast_unchecked(self) -> Box { + let to = *mem::transmute::<&Box, &raw::TraitObject>(&self); + // Prevent double-free. + mem::forget(self); + mem::transmute(to.data) + } +} + +impl<'a> AnyRefExt<'a> for &'a NetworkStream + 'a { + #[inline] + fn is(self) -> bool { + self.get_type_id() == TypeId::of::() + } + + #[inline] + fn downcast_ref(self) -> Option<&'a T> { + if self.is::() { + unsafe { + // Get the raw representation of the trait object + let to: TraitObject = transmute_copy(&self); + // Extract the data pointer + Some(transmute(to.data)) + } + } else { + None + } + } +} + +impl BoxAny for Box { + fn downcast(self) -> Result, Box> { + if self.is::() { + Ok(unsafe { self.downcast_unchecked() }) + } else { + Err(self) + } + } +} + /// A `NetworkListener` for `HttpStream`s. pub struct HttpListener { inner: TcpListener @@ -212,3 +266,30 @@ fn lift_ssl_error(ssl: SslError) -> IoError { } } +#[cfg(test)] +mod tests { + use std::boxed::BoxAny; + use uany::UncheckedBoxAnyDowncast; + + use mock::MockStream; + use super::NetworkStream; + + #[test] + fn test_downcast_box_stream() { + let stream = MockStream.dynamic(); + + let mock = stream.downcast::().unwrap(); + assert_eq!(mock, box MockStream); + + } + + #[test] + fn test_downcast_unchecked_box_stream() { + let stream = MockStream.dynamic(); + + let mock = unsafe { stream.downcast_unchecked::() }; + assert_eq!(mock, box MockStream); + + } + +}