fix(client): close connection when there is an Error

This commit is contained in:
Sean McArthur
2015-08-19 14:37:30 -07:00
parent e305a2e9dc
commit d32d35bbea
3 changed files with 86 additions and 10 deletions

View File

@@ -101,11 +101,17 @@ impl Request<Fresh> {
/// Consume a Fresh Request, writing the headers and method, /// Consume a Fresh Request, writing the headers and method,
/// returning a Streaming Request. /// returning a Streaming Request.
pub fn start(mut self) -> ::Result<Request<Streaming>> { pub fn start(mut self) -> ::Result<Request<Streaming>> {
let head = try!(self.message.set_outgoing(RequestHead { let head = match self.message.set_outgoing(RequestHead {
headers: self.headers, headers: self.headers,
method: self.method, method: self.method,
url: self.url, url: self.url,
})); }) {
Ok(head) => head,
Err(e) => {
let _ = self.message.close_connection();
return Err(From::from(e));
}
};
Ok(Request { Ok(Request {
method: head.method, method: head.method,
@@ -134,17 +140,30 @@ impl Request<Streaming> {
impl Write for Request<Streaming> { impl Write for Request<Streaming> {
#[inline] #[inline]
fn write(&mut self, msg: &[u8]) -> io::Result<usize> { fn write(&mut self, msg: &[u8]) -> io::Result<usize> {
self.message.write(msg) match self.message.write(msg) {
Ok(n) => Ok(n),
Err(e) => {
let _ = self.message.close_connection();
Err(e)
}
}
} }
#[inline] #[inline]
fn flush(&mut self) -> io::Result<()> { fn flush(&mut self) -> io::Result<()> {
self.message.flush() match self.message.flush() {
Ok(r) => Ok(r),
Err(e) => {
let _ = self.message.close_connection();
Err(e)
}
}
} }
} }
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use std::io::Write;
use std::str::from_utf8; use std::str::from_utf8;
use url::Url; use url::Url;
use method::Method::{Get, Head, Post}; use method::Method::{Get, Head, Post};
@@ -237,4 +256,24 @@ mod tests {
assert!(!s.contains("Content-Length:")); assert!(!s.contains("Content-Length:"));
assert!(s.contains("Transfer-Encoding:")); assert!(s.contains("Transfer-Encoding:"));
} }
#[test]
fn test_write_error_closes() {
let url = Url::parse("http://hyper.rs").unwrap();
let req = Request::with_connector(
Get, url, &mut MockConnector
).unwrap();
let mut req = req.start().unwrap();
req.message.downcast_mut::<Http11Message>().unwrap()
.get_mut().downcast_mut::<MockStream>().unwrap()
.error_on_write = true;
req.write(b"foo").unwrap();
assert!(req.flush().is_err());
assert!(req.message.downcast_ref::<Http11Message>().unwrap()
.get_ref().downcast_ref::<MockStream>().unwrap()
.is_closed);
}
} }

View File

@@ -37,7 +37,13 @@ impl Response {
/// Creates a new response received from the server on the given `HttpMessage`. /// Creates a new response received from the server on the given `HttpMessage`.
pub fn with_message(url: Url, mut message: Box<HttpMessage>) -> ::Result<Response> { pub fn with_message(url: Url, mut message: Box<HttpMessage>) -> ::Result<Response> {
trace!("Response::with_message"); trace!("Response::with_message");
let ResponseHead { headers, raw_status, version } = try!(message.get_incoming()); let ResponseHead { headers, raw_status, version } = match message.get_incoming() {
Ok(head) => head,
Err(e) => {
let _ = message.close_connection();
return Err(From::from(e));
}
};
let status = status::StatusCode::from_u16(raw_status.0); let status = status::StatusCode::from_u16(raw_status.0);
debug!("version={:?}, status={:?}", version, status); debug!("version={:?}, status={:?}", version, status);
debug!("headers={:?}", headers); debug!("headers={:?}", headers);
@@ -54,6 +60,7 @@ impl Response {
} }
/// Get the raw status code and reason. /// Get the raw status code and reason.
#[inline]
pub fn status_raw(&self) -> &RawStatus { pub fn status_raw(&self) -> &RawStatus {
&self.status_raw &self.status_raw
} }
@@ -68,6 +75,10 @@ impl Read for Response {
self.is_drained = true; self.is_drained = true;
Ok(0) Ok(0)
}, },
Err(e) => {
let _ = self.message.close_connection();
Err(e)
}
r => r r => r
} }
} }

View File

@@ -2,7 +2,7 @@ use std::fmt;
use std::ascii::AsciiExt; use std::ascii::AsciiExt;
use std::io::{self, Read, Write, Cursor}; use std::io::{self, Read, Write, Cursor};
use std::cell::RefCell; use std::cell::RefCell;
use std::net::SocketAddr; use std::net::{SocketAddr, Shutdown};
use std::sync::{Arc, Mutex}; use std::sync::{Arc, Mutex};
#[cfg(feature = "timeouts")] #[cfg(feature = "timeouts")]
use std::time::Duration; use std::time::Duration;
@@ -21,10 +21,13 @@ use net::{NetworkStream, NetworkConnector};
pub struct MockStream { pub struct MockStream {
pub read: Cursor<Vec<u8>>, pub read: Cursor<Vec<u8>>,
pub write: Vec<u8>, pub write: Vec<u8>,
pub is_closed: bool,
pub error_on_write: bool,
pub error_on_read: bool,
#[cfg(feature = "timeouts")] #[cfg(feature = "timeouts")]
pub read_timeout: Cell<Option<Duration>>, pub read_timeout: Cell<Option<Duration>>,
#[cfg(feature = "timeouts")] #[cfg(feature = "timeouts")]
pub write_timeout: Cell<Option<Duration>> pub write_timeout: Cell<Option<Duration>>,
} }
impl fmt::Debug for MockStream { impl fmt::Debug for MockStream {
@@ -48,7 +51,10 @@ impl MockStream {
pub fn with_input(input: &[u8]) -> MockStream { pub fn with_input(input: &[u8]) -> MockStream {
MockStream { MockStream {
read: Cursor::new(input.to_vec()), read: Cursor::new(input.to_vec()),
write: vec![] write: vec![],
is_closed: false,
error_on_write: false,
error_on_read: false,
} }
} }
@@ -57,6 +63,9 @@ impl MockStream {
MockStream { MockStream {
read: Cursor::new(input.to_vec()), read: Cursor::new(input.to_vec()),
write: vec![], write: vec![],
is_closed: false,
error_on_write: false,
error_on_read: false,
read_timeout: Cell::new(None), read_timeout: Cell::new(None),
write_timeout: Cell::new(None), write_timeout: Cell::new(None),
} }
@@ -65,14 +74,22 @@ impl MockStream {
impl Read for MockStream { impl Read for MockStream {
fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> { fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
if self.error_on_read {
Err(io::Error::new(io::ErrorKind::Other, "mock error"))
} else {
self.read.read(buf) self.read.read(buf)
} }
}
} }
impl Write for MockStream { impl Write for MockStream {
fn write(&mut self, msg: &[u8]) -> io::Result<usize> { fn write(&mut self, msg: &[u8]) -> io::Result<usize> {
if self.error_on_write {
Err(io::Error::new(io::ErrorKind::Other, "mock error"))
} else {
Write::write(&mut self.write, msg) Write::write(&mut self.write, msg)
} }
}
fn flush(&mut self) -> io::Result<()> { fn flush(&mut self) -> io::Result<()> {
Ok(()) Ok(())
@@ -95,6 +112,11 @@ impl NetworkStream for MockStream {
self.write_timeout.set(dur); self.write_timeout.set(dur);
Ok(()) Ok(())
} }
fn close(&mut self, _how: Shutdown) -> io::Result<()> {
self.is_closed = true;
Ok(())
}
} }
/// A wrapper around a `MockStream` that allows one to clone it and keep an independent copy to the /// A wrapper around a `MockStream` that allows one to clone it and keep an independent copy to the
@@ -144,6 +166,10 @@ impl NetworkStream for CloneableMockStream {
fn set_write_timeout(&self, dur: Option<Duration>) -> io::Result<()> { fn set_write_timeout(&self, dur: Option<Duration>) -> io::Result<()> {
self.inner.lock().unwrap().set_write_timeout(dur) self.inner.lock().unwrap().set_write_timeout(dur)
} }
fn close(&mut self, how: Shutdown) -> io::Result<()> {
NetworkStream::close(&mut *self.inner.lock().unwrap(), how)
}
} }
impl CloneableMockStream { impl CloneableMockStream {