diff --git a/README.md b/README.md index 970cdb03..fef3a77c 100644 --- a/README.md +++ b/README.md @@ -42,18 +42,18 @@ Client: ```rust fn main() { + // Create a client. + let mut client = Client::new(); + // Creating an outgoing request. - let mut req = Request::get(Url::parse("http://www.gooogle.com/").unwrap()).unwrap(); + let mut res = client.get("http://www.gooogle.com/") + // set a header + .header(Connection(vec![Close])) + // let 'er go! + .send(); - // Setting a header. - req.headers_mut().set(Connection(vec![Close])); - - // Start the Request, writing headers and starting streaming. - let res = req.start().unwrap() - // Send the Request. - .send().unwrap() - // Read the Response. - .read_to_string().unwrap(); + // Read the Response. + let body = res.read_to_string().unwrap(); println!("Response: {}", res); } @@ -64,22 +64,20 @@ fn main() { [Client Bench:](./benches/client.rs) ``` - running 3 tests -test bench_curl ... bench: 298416 ns/iter (+/- 132455) -test bench_http ... bench: 292725 ns/iter (+/- 167575) -test bench_hyper ... bench: 222819 ns/iter (+/- 86615) +test bench_curl ... bench: 400253 ns/iter (+/- 143539) +test bench_hyper ... bench: 181703 ns/iter (+/- 46529) -test result: ok. 0 passed; 0 failed; 0 ignored; 3 measured +test result: ok. 0 passed; 0 failed; 0 ignored; 2 measured ``` [Mock Client Bench:](./benches/client_mock_tcp.rs) ``` running 3 tests -test bench_mock_curl ... bench: 25254 ns/iter (+/- 2113) -test bench_mock_http ... bench: 43585 ns/iter (+/- 1206) -test bench_mock_hyper ... bench: 27153 ns/iter (+/- 2227) +test bench_mock_curl ... bench: 53987 ns/iter (+/- 1735) +test bench_mock_http ... bench: 43569 ns/iter (+/- 1409) +test bench_mock_hyper ... bench: 20996 ns/iter (+/- 1742) test result: ok. 0 passed; 0 failed; 0 ignored; 3 measured ``` diff --git a/benches/client.rs b/benches/client.rs index a3938924..f6e9ef04 100644 --- a/benches/client.rs +++ b/benches/client.rs @@ -8,6 +8,10 @@ extern crate test; use std::fmt::{mod, Show}; use std::io::net::ip::Ipv4Addr; use hyper::server::{Request, Response, Server}; +use hyper::method::Method::Get; +use hyper::header::Headers; +use hyper::Client; +use hyper::client::RequestBuilder; fn listen() -> hyper::server::Listening { let server = Server::http(Ipv4Addr(127, 0, 0, 1), 0); @@ -22,9 +26,10 @@ macro_rules! try_return( } }}) -fn handle(_: Request, res: Response) { +fn handle(_r: Request, res: Response) { + static BODY: &'static [u8] = b"Benchmarking hyper vs others!"; let mut res = try_return!(res.start()); - try_return!(res.write(b"Benchmarking hyper vs others!")) + try_return!(res.write(BODY)) try_return!(res.end()); } @@ -41,7 +46,7 @@ fn bench_curl(b: &mut test::Bencher) { .exec() .unwrap() }); - listening.close().unwrap() + listening.close().unwrap(); } #[deriving(Clone)] @@ -67,17 +72,17 @@ fn bench_hyper(b: &mut test::Bencher) { let mut listening = listen(); let s = format!("http://{}/", listening.socket); let url = s.as_slice(); + let mut client = Client::new(); + let mut headers = Headers::new(); + headers.set(Foo); b.iter(|| { - let mut req = hyper::client::Request::get(hyper::Url::parse(url).unwrap()).unwrap(); - req.headers_mut().set(Foo); - - req.start().unwrap() - .send().unwrap() - .read_to_string().unwrap() + client.get(url).header(Foo).send().unwrap().read_to_string().unwrap(); }); listening.close().unwrap() } +/* +doesn't handle keep-alive properly... #[bench] fn bench_http(b: &mut test::Bencher) { let mut listening = listen(); @@ -92,9 +97,10 @@ fn bench_http(b: &mut test::Bencher) { // cant unwrap because Err contains RequestWriter, which does not implement Show let mut res = match req.read_response() { Ok(res) => res, - Err(..) => panic!("http response failed") + Err((_, ioe)) => panic!("http response failed = {}", ioe) }; res.read_to_string().unwrap(); }); listening.close().unwrap() } +*/ diff --git a/benches/server.rs b/benches/server.rs index 378ba4ec..d18bf596 100644 --- a/benches/server.rs +++ b/benches/server.rs @@ -9,12 +9,13 @@ use test::Bencher; use std::io::net::ip::{SocketAddr, Ipv4Addr}; use http::server::Server; +use hyper::method::Method::Get; use hyper::server::{Request, Response}; static PHRASE: &'static [u8] = b"Benchmarking hyper vs others!"; fn request(url: hyper::Url) { - let req = hyper::client::Request::get(url).unwrap(); + let req = hyper::client::Request::new(Get, url).unwrap(); req.start().unwrap().send().unwrap().read_to_string().unwrap(); } diff --git a/examples/client.rs b/examples/client.rs index 8bf8926b..02961ea4 100644 --- a/examples/client.rs +++ b/examples/client.rs @@ -4,8 +4,7 @@ use std::os; use std::io::stdout; use std::io::util::copy; -use hyper::Url; -use hyper::client::Request; +use hyper::Client; fn main() { let args = os::args(); @@ -17,26 +16,17 @@ fn main() { } }; - let url = match Url::parse(args[1].as_slice()) { - Ok(url) => { - println!("GET {}...", url) - url - }, - Err(e) => panic!("Invalid URL: {}", e) - }; + let url = &*args[1]; + let mut client = Client::new(); - let req = match Request::get(url) { - Ok(req) => req, + let mut res = match client.get(url).send() { + Ok(res) => res, Err(err) => panic!("Failed to connect: {}", err) }; - let mut res = req - .start().unwrap() // failure: Error writing Headers - .send().unwrap(); // failure: Error reading Response head. - println!("Response: {}", res.status); - println!("{}", res.headers); + println!("Headers:\n{}", res.headers); match copy(&mut res, &mut stdout()) { Ok(..) => (), Err(e) => panic!("Stream failure: {}", e) diff --git a/src/client/mod.rs b/src/client/mod.rs index 1b24c7b7..36fb6284 100644 --- a/src/client/mod.rs +++ b/src/client/mod.rs @@ -1,7 +1,399 @@ //! HTTP Client +//! +//! # Usage +//! +//! The `Client` API is designed for most people to make HTTP requests. +//! It utilizes the lower level `Request` API. +//! +//! ```no_run +//! use hyper::Client; +//! +//! let mut client = Client::new(); +//! +//! let mut res = client.get("http://example.domain").send().unwrap(); +//! assert_eq!(res.status, hyper::Ok); +//! ``` +//! +//! The returned value from is a `Response`, which provides easy access +//! to the `status`, the `headers`, and the response body via the `Writer` +//! trait. +use std::default::Default; +use std::io::IoResult; +use std::io::util::copy; +use std::iter::Extend; + +use url::UrlParser; +use url::ParseError as UrlError; + +use openssl::ssl::VerifyCallback; + +use header::{Headers, Header, HeaderFormat}; +use header::common::{ContentLength, Location}; +use method::Method; +use net::{NetworkConnector, NetworkStream, HttpConnector}; +use status::StatusClass::Redirection; +use {Url, Port, HttpResult}; +use HttpError::HttpUriError; + pub use self::request::Request; pub use self::response::Response; pub mod request; pub mod response; +/// A Client to use additional features with Requests. +/// +/// Clients can handle things such as: redirect policy. +pub struct Client { + connector: C, + redirect_policy: RedirectPolicy, +} + +impl Client { + + /// Create a new Client. + pub fn new() -> Client { + Client::with_connector(HttpConnector(None)) + } + + /// Set the SSL verifier callback for use with OpenSSL. + pub fn set_ssl_verifier(&mut self, verifier: VerifyCallback) { + self.connector = HttpConnector(Some(verifier)); + } + +} + +impl, S: NetworkStream> Client { + + /// Create a new client with a specific connector. + pub fn with_connector(connector: C) -> Client { + Client { + connector: connector, + redirect_policy: Default::default() + } + } + + /// Set the RedirectPolicy. + pub fn set_redirect_policy(&mut self, policy: RedirectPolicy) { + self.redirect_policy = policy; + } + + /// Execute a Get request. + pub fn get(&mut self, url: U) -> RequestBuilder { + self.request(Method::Get, url) + } + + /// Execute a Head request. + pub fn head(&mut self, url: U) -> RequestBuilder { + self.request(Method::Head, url) + } + + /// Execute a Post request. + pub fn post(&mut self, url: U) -> RequestBuilder { + self.request(Method::Post, url) + } + + /// Execute a Put request. + pub fn put(&mut self, url: U) -> RequestBuilder { + self.request(Method::Put, url) + } + + /// Execute a Delete request. + pub fn delete(&mut self, url: U) -> RequestBuilder { + self.request(Method::Delete, url) + } + + + /// Build a new request using this Client. + pub fn request(&mut self, method: Method, url: U) -> RequestBuilder { + RequestBuilder { + client: self, + method: method, + url: url, + body: None, + headers: None, + } + } +} + +/// Options for an individual Request. +/// +/// One of these will be built for you if you use one of the convenience +/// methods, such as `get()`, `post()`, etc. +pub struct RequestBuilder<'a, U: IntoUrl, C: NetworkConnector + 'a, S: NetworkStream> { + client: &'a mut Client, + url: U, + headers: Option, + method: Method, + body: Option>, +} + +impl<'a, U: IntoUrl, C: NetworkConnector, S: NetworkStream> RequestBuilder<'a, U, C, S> { + + /// Set a request body to be sent. + pub fn body>(mut self, body: B) -> RequestBuilder<'a, U, C, S> { + self.body = Some(body.into_body()); + self + } + + /// Add additional headers to the request. + pub fn headers(mut self, headers: Headers) -> RequestBuilder<'a, U, C, S> { + self.headers = Some(headers); + self + } + + /// Add an individual new header to the request. + pub fn header(mut self, header: H) -> RequestBuilder<'a, U, C, S> { + { + let mut headers = match self.headers { + Some(ref mut h) => h, + None => { + self.headers = Some(Headers::new()); + self.headers.as_mut().unwrap() + } + }; + + headers.set(header); + } + self + } + + /// Execute this request and receive a Response back. + pub fn send(self) -> HttpResult { + let RequestBuilder { client, method, url, headers, body } = self; + let mut url = try!(url.into_url()); + debug!("client.request {} {}", method, url); + + let can_have_body = match &method { + &Method::Get | &Method::Head => false, + _ => true + }; + + let mut body = if can_have_body { + body.map(|b| b.into_body()) + } else { + None + }; + + loop { + let mut req = try!(Request::with_connector(method.clone(), url.clone(), &mut client.connector)); + headers.as_ref().map(|headers| req.headers_mut().extend(headers.iter())); + + match (can_have_body, body.as_ref()) { + (true, Some(ref body)) => match body.size() { + Some(size) => req.headers_mut().set(ContentLength(size)), + None => (), // chunked, Request will add it automatically + }, + (true, None) => req.headers_mut().set(ContentLength(0)), + _ => () // neither + } + let mut streaming = try!(req.start()); + body.take().map(|mut rdr| copy(&mut rdr, &mut streaming)); + let res = try!(streaming.send()); + if res.status.class() != Redirection { + return Ok(res) + } + debug!("redirect code {} for {}", res.status, url); + + let loc = { + // punching borrowck here + let loc = match res.headers.get::() { + Some(&Location(ref loc)) => { + Some(UrlParser::new().base_url(&url).parse(loc[])) + } + None => { + debug!("no Location header"); + // could be 304 Not Modified? + None + } + }; + match loc { + Some(r) => r, + None => return Ok(res) + } + }; + url = match loc { + Ok(u) => { + inspect!("Location", u) + }, + Err(e) => { + debug!("Location header had invalid URI: {}", e); + return Ok(res); + } + }; + match client.redirect_policy { + // separate branches because they cant be one + RedirectPolicy::FollowAll => (), //continue + RedirectPolicy::FollowIf(cond) if cond(&url) => (), //continue + _ => return Ok(res), + } + } + } +} + +/// A helper trait to allow overloading of the body parameter. +pub trait IntoBody<'a> { + /// Consumes self into an instance of `Body`. + fn into_body(self) -> Body<'a>; +} + +/// The target enum for the IntoBody trait. +pub enum Body<'a> { + /// A Reader does not necessarily know it's size, so it is chunked. + ChunkedBody(&'a mut (Reader + 'a)), + /// For Readers that can know their size, like a `File`. + SizedBody(&'a mut (Reader + 'a), uint), + /// A String has a size, and uses Content-Length. + BufBody(&'a [u8] , uint), +} + +impl<'a> Body<'a> { + fn size(&self) -> Option { + match *self { + Body::SizedBody(_, len) | Body::BufBody(_, len) => Some(len), + _ => None + } + } +} + +impl<'a> Reader for Body<'a> { + #[inline] + fn read(&mut self, buf: &mut [u8]) -> IoResult { + match *self { + Body::ChunkedBody(ref mut r) => r.read(buf), + Body::SizedBody(ref mut r, _) => r.read(buf), + Body::BufBody(ref mut r, _) => r.read(buf), + } + } +} + +// To allow someone to pass a `Body::SizedBody()` themselves. +impl<'a> IntoBody<'a> for Body<'a> { + #[inline] + fn into_body(self) -> Body<'a> { + self + } +} + +impl<'a> IntoBody<'a> for &'a [u8] { + #[inline] + fn into_body(self) -> Body<'a> { + Body::BufBody(self, self.len()) + } +} + +impl<'a> IntoBody<'a> for &'a str { + #[inline] + fn into_body(self) -> Body<'a> { + self.as_bytes().into_body() + } +} + +impl<'a, R: Reader> IntoBody<'a> for &'a mut R { + #[inline] + fn into_body(self) -> Body<'a> { + Body::ChunkedBody(self) + } +} + +/// A helper trait to convert common objects into a Url. +pub trait IntoUrl { + /// Consumes the object, trying to return a Url. + fn into_url(self) -> Result; +} + +impl IntoUrl for Url { + fn into_url(self) -> Result { + Ok(self) + } +} + +impl<'a> IntoUrl for &'a str { + fn into_url(self) -> Result { + Url::parse(self) + } +} + +/// Behavior regarding how to handle redirects within a Client. +#[deriving(Copy, Clone)] +pub enum RedirectPolicy { + /// Don't follow any redirects. + FollowNone, + /// Follow all redirects. + FollowAll, + /// Follow a redirect if the contained function returns true. + FollowIf(fn(&Url) -> bool), +} + +impl Default for RedirectPolicy { + fn default() -> RedirectPolicy { + RedirectPolicy::FollowAll + } +} + +fn get_host_and_port(url: &Url) -> HttpResult<(String, Port)> { + let host = match url.serialize_host() { + Some(host) => host, + None => return Err(HttpUriError(UrlError::EmptyHost)) + }; + debug!("host={}", host); + let port = match url.port_or_default() { + Some(port) => port, + None => return Err(HttpUriError(UrlError::InvalidPort)) + }; + debug!("port={}", port); + Ok((host, port)) +} + +#[cfg(test)] +mod tests { + use header::common::Server; + use super::{Client, RedirectPolicy}; + use url::Url; + + mock_connector!(MockRedirectPolicy { + "http://127.0.0.1" => "HTTP/1.1 301 Redirect\r\n\ + Location: http://127.0.0.2\r\n\ + Server: mock1\r\n\ + \r\n\ + " + "http://127.0.0.2" => "HTTP/1.1 302 Found\r\n\ + Location: https://127.0.0.3\r\n\ + Server: mock2\r\n\ + \r\n\ + " + "https://127.0.0.3" => "HTTP/1.1 200 OK\r\n\ + Server: mock3\r\n\ + \r\n\ + " + }) + + #[test] + fn test_redirect_followall() { + let mut client = Client::with_connector(MockRedirectPolicy); + client.set_redirect_policy(RedirectPolicy::FollowAll); + + let res = client.get("http://127.0.0.1").send().unwrap(); + assert_eq!(res.headers.get(), Some(&Server("mock3".into_string()))); + } + + #[test] + fn test_redirect_dontfollow() { + let mut client = Client::with_connector(MockRedirectPolicy); + client.set_redirect_policy(RedirectPolicy::FollowNone); + let res = client.get("http://127.0.0.1").send().unwrap(); + assert_eq!(res.headers.get(), Some(&Server("mock1".into_string()))); + } + + #[test] + fn test_redirect_followif() { + fn follow_if(url: &Url) -> bool { + !url.serialize()[].contains("127.0.0.3") + } + let mut client = Client::with_connector(MockRedirectPolicy); + client.set_redirect_policy(RedirectPolicy::FollowIf(follow_if)); + let res = client.get("http://127.0.0.1").send().unwrap(); + assert_eq!(res.headers.get(), Some(&Server("mock2".into_string()))); + } + +} diff --git a/src/client/request.rs b/src/client/request.rs index 28125174..ff1fa9c1 100644 --- a/src/client/request.rs +++ b/src/client/request.rs @@ -8,12 +8,11 @@ use method::Method::{Get, Post, Delete, Put, Patch, Head, Options}; use header::Headers; use header::common::{mod, Host}; use net::{NetworkStream, NetworkConnector, HttpConnector, Fresh, Streaming}; -use HttpError::HttpUriError; use http::{HttpWriter, LINE_ENDING}; use http::HttpWriter::{ThroughWriter, ChunkedWriter, SizedWriter, EmptyWriter}; use version; use HttpResult; -use client::Response; +use client::{Response, get_host_and_port}; /// A client request to a remote server. @@ -42,23 +41,14 @@ impl Request { impl Request { /// Create a new client request. pub fn new(method: method::Method, url: Url) -> HttpResult> { - let mut conn = HttpConnector; + let mut conn = HttpConnector(None); Request::with_connector(method, url, &mut conn) } /// Create a new client request with a specific underlying NetworkStream. pub fn with_connector, S: NetworkStream>(method: method::Method, url: Url, connector: &mut C) -> HttpResult> { debug!("{} {}", method, url); - let host = match url.serialize_host() { - Some(host) => host, - None => return Err(HttpUriError) - }; - debug!("host={}", host); - let port = match url.port_or_default() { - Some(port) => port, - None => return Err(HttpUriError) - }; - debug!("port={}", port); + let (host, port) = try!(get_host_and_port(&url)); let stream: S = try!(connector.connect(host[], port, &*url.scheme)); let stream = ThroughWriter(BufferedWriter::new(box stream as Box)); @@ -80,30 +70,37 @@ impl Request { /// Create a new GET request. #[inline] + #[deprecated = "use hyper::Client"] pub fn get(url: Url) -> HttpResult> { Request::new(Get, url) } /// Create a new POST request. #[inline] + #[deprecated = "use hyper::Client"] pub fn post(url: Url) -> HttpResult> { Request::new(Post, url) } /// Create a new DELETE request. #[inline] + #[deprecated = "use hyper::Client"] pub fn delete(url: Url) -> HttpResult> { Request::new(Delete, url) } /// Create a new PUT request. #[inline] + #[deprecated = "use hyper::Client"] pub fn put(url: Url) -> HttpResult> { Request::new(Put, url) } /// Create a new PATCH request. #[inline] + #[deprecated = "use hyper::Client"] pub fn patch(url: Url) -> HttpResult> { Request::new(Patch, url) } /// Create a new HEAD request. #[inline] + #[deprecated = "use hyper::Client"] pub fn head(url: Url) -> HttpResult> { Request::new(Head, url) } /// Create a new OPTIONS request. #[inline] + #[deprecated = "use hyper::Client"] pub fn options(url: Url) -> HttpResult> { Request::new(Options, url) } /// Consume a Fresh Request, writing the headers and method, diff --git a/src/client/response.rs b/src/client/response.rs index f0be6559..7bffe22d 100644 --- a/src/client/response.rs +++ b/src/client/response.rs @@ -38,7 +38,7 @@ impl Response { debug!("{} {}", version, status); let headers = try!(header::Headers::from_raw(&mut stream)); - debug!("{}", headers); + debug!("Headers: [\n{}]", headers); let body = if headers.has::() { match headers.get::() { diff --git a/src/header/mod.rs b/src/header/mod.rs index 440d55b2..53da306f 100644 --- a/src/header/mod.rs +++ b/src/header/mod.rs @@ -30,7 +30,7 @@ pub mod common; /// /// This trait represents the construction and identification of headers, /// and contains trait-object unsafe methods. -pub trait Header: Any + Send + Sync { +pub trait Header: Clone + Any + Send + Sync { /// Returns the name of the header field this belongs to. /// /// The market `Option` is to hint to the type system which implementation diff --git a/src/http.rs b/src/http.rs index 90fbc40b..12657e6e 100644 --- a/src/http.rs +++ b/src/http.rs @@ -7,6 +7,7 @@ use std::num::from_u16; use std::str::{mod, SendStr}; use url::Url; +use url::ParseError as UrlError; use method; use status::StatusCode; @@ -234,6 +235,7 @@ impl Writer for HttpWriter { ThroughWriter(ref mut w) => w.write(msg), ChunkedWriter(ref mut w) => { let chunk_size = msg.len(); + debug!("chunked write, size = {}", chunk_size); try!(write!(w, "{:X}{}{}", chunk_size, CR as char, LF as char)); try!(w.write(msg)); w.write(LINE_ENDING) @@ -419,7 +421,7 @@ pub fn read_uri(stream: &mut R) -> HttpResult { break; }, CR | LF => { - return Err(HttpUriError) + return Err(HttpUriError(UrlError::InvalidCharacter)) }, b => s.push(b as char) } @@ -431,26 +433,13 @@ pub fn read_uri(stream: &mut R) -> HttpResult { if s.as_slice().starts_with("/") { Ok(AbsolutePath(s)) } else if s.as_slice().contains("/") { - match Url::parse(s.as_slice()) { - Ok(u) => Ok(AbsoluteUri(u)), - Err(_e) => { - debug!("URL err {}", _e); - Err(HttpUriError) - } - } + Ok(AbsoluteUri(try!(Url::parse(s.as_slice())))) } else { let mut temp = "http://".to_string(); temp.push_str(s.as_slice()); - match Url::parse(temp.as_slice()) { - Ok(_u) => { - todo!("compare vs u.authority()"); - Ok(Authority(s)) - } - Err(_e) => { - debug!("URL err {}", _e); - Err(HttpUriError) - } - } + try!(Url::parse(temp.as_slice())); + todo!("compare vs u.authority()"); + Ok(Authority(s)) } diff --git a/src/lib.rs b/src/lib.rs index 2c80be30..e993737c 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -138,6 +138,7 @@ extern crate mucell; pub use std::io::net::ip::{SocketAddr, IpAddr, Ipv4Addr, Ipv6Addr, Port}; pub use mimewrapper::mime; pub use url::Url; +pub use client::Client; pub use method::Method::{Get, Head, Post, Delete}; pub use status::StatusCode::{Ok, BadRequest, NotFound}; pub use server::Server; @@ -181,6 +182,10 @@ macro_rules! inspect( }) ) +#[cfg(test)] +#[macro_escape] +mod mock; + pub mod client; pub mod method; pub mod header; @@ -191,7 +196,6 @@ pub mod status; pub mod uri; pub mod version; -#[cfg(test)] mod mock; mod mimewrapper { /// Re-exporting the mime crate, for convenience. @@ -208,7 +212,7 @@ pub enum HttpError { /// An invalid `Method`, such as `GE,T`. HttpMethodError, /// An invalid `RequestUri`, such as `exam ple.domain`. - HttpUriError, + HttpUriError(url::ParseError), /// An invalid `HttpVersion`, such as `HTP/1.1` HttpVersionError, /// An invalid `Header`. @@ -223,7 +227,7 @@ impl Error for HttpError { fn description(&self) -> &str { match *self { HttpMethodError => "Invalid Method specified", - HttpUriError => "Invalid Request URI specified", + HttpUriError(_) => "Invalid Request URI specified", HttpVersionError => "Invalid HTTP version specified", HttpHeaderError => "Invalid Header provided", HttpStatusError => "Invalid Status provided", @@ -234,6 +238,7 @@ impl Error for HttpError { fn cause(&self) -> Option<&Error> { match *self { HttpIoError(ref error) => Some(error as &Error), + HttpUriError(ref error) => Some(error as &Error), _ => None, } } @@ -245,6 +250,12 @@ impl FromError for HttpError { } } +impl FromError for HttpError { + fn from_error(err: url::ParseError) -> HttpError { + HttpUriError(err) + } +} + //FIXME: when Opt-in Built-in Types becomes a thing, we can force these structs //to be Send. For now, this has the compiler do a static check. fn _assert_send() { diff --git a/src/mock.rs b/src/mock.rs index cc22a8a7..fdecd5e7 100644 --- a/src/mock.rs +++ b/src/mock.rs @@ -73,3 +73,35 @@ impl NetworkConnector for MockConnector { Ok(MockStream::new()) } } + +/// new connectors must be created if you wish to intercept requests. +macro_rules! mock_connector ( + ($name:ident { + $($url:expr => $res:expr)* + }) => ( + + struct $name; + + impl ::net::NetworkConnector<::mock::MockStream> for $name { + fn connect(&mut self, host: &str, port: u16, scheme: &str) -> ::std::io::IoResult<::mock::MockStream> { + use std::collections::HashMap; + debug!("MockStream::connect({}, {}, {})", host, port, scheme); + let mut map = HashMap::new(); + $(map.insert($url, $res);)* + + + let key = format!("{}://{}", scheme, host); + // ignore port for now + match map.find(&&*key) { + Some(res) => Ok(::mock::MockStream { + write: ::std::io::MemWriter::new(), + read: ::std::io::MemReader::new(res.to_string().into_bytes()) + }), + None => panic!("{} doesn't know url {}", stringify!($name), key) + } + } + + } + + ) +) diff --git a/src/net.rs b/src/net.rs index 5a9f3f1a..7e9e2a91 100644 --- a/src/net.rs +++ b/src/net.rs @@ -11,7 +11,8 @@ use std::mem::{mod, transmute, transmute_copy}; use std::raw::{mod, TraitObject}; use uany::UncheckedBoxAnyDowncast; -use openssl::ssl::{SslStream, SslContext, Ssl}; +use openssl::ssl::{Ssl, SslStream, SslContext, VerifyCallback}; +use openssl::ssl::SslVerifyMode::SslVerifyPeer; use openssl::ssl::SslMethod::Sslv23; use openssl::ssl::error::{SslError, StreamError, OpenSslErrors, SslSessionClosed}; @@ -239,7 +240,7 @@ impl NetworkStream for HttpStream { /// A connector that will produce HttpStreams. #[allow(missing_copy_implementations)] -pub struct HttpConnector; +pub struct HttpConnector(pub Option); impl NetworkConnector for HttpConnector { fn connect(&mut self, host: &str, port: Port, scheme: &str) -> IoResult { @@ -252,12 +253,11 @@ impl NetworkConnector for HttpConnector { "https" => { debug!("https scheme"); let stream = try!(TcpStream::connect(addr)); - let context = try!(SslContext::new(Sslv23).map_err(lift_ssl_error)); + let mut context = try!(SslContext::new(Sslv23).map_err(lift_ssl_error)); + self.0.as_ref().map(|cb| context.set_verify(SslVerifyPeer, Some(*cb))); let ssl = try!(Ssl::new(&context).map_err(lift_ssl_error)); - debug!("ssl set_hostname = {}", host); try!(ssl.set_hostname(host).map_err(lift_ssl_error)); - debug!("ssl set_hostname done"); - let stream = try!(SslStream::new_from(ssl, stream).map_err(lift_ssl_error)); + let stream = try!(SslStream::new(&context, stream).map_err(lift_ssl_error)); Ok(Https(stream)) }, _ => { diff --git a/src/server/mod.rs b/src/server/mod.rs index 10ae4660..c43c4460 100644 --- a/src/server/mod.rs +++ b/src/server/mod.rs @@ -105,8 +105,8 @@ impl, S: NetworkStream, A: NetworkAcceptor> Server()) { - (Http10, Some(conn)) if !conn.0.contains(&KeepAlive) => false, - (Http11, Some(conn)) if conn.0.contains(&Close) => false, + (Http10, Some(conn)) if !conn.contains(&KeepAlive) => false, + (Http11, Some(conn)) if conn.contains(&Close) => false, _ => true }; res.version = req.version;