fix(client): close connection when there is an Error
This commit is contained in:
		| @@ -101,11 +101,17 @@ impl Request<Fresh> { | |||||||
|     /// Consume a Fresh Request, writing the headers and method, |     /// Consume a Fresh Request, writing the headers and method, | ||||||
|     /// returning a Streaming Request. |     /// returning a Streaming Request. | ||||||
|     pub fn start(mut self) -> ::Result<Request<Streaming>> { |     pub fn start(mut self) -> ::Result<Request<Streaming>> { | ||||||
|         let head = try!(self.message.set_outgoing(RequestHead { |         let head = match self.message.set_outgoing(RequestHead { | ||||||
|             headers: self.headers, |             headers: self.headers, | ||||||
|             method: self.method, |             method: self.method, | ||||||
|             url: self.url, |             url: self.url, | ||||||
|         })); |         }) { | ||||||
|  |             Ok(head) => head, | ||||||
|  |             Err(e) => { | ||||||
|  |                 let _ = self.message.close_connection(); | ||||||
|  |                 return Err(From::from(e)); | ||||||
|  |             } | ||||||
|  |         }; | ||||||
|  |  | ||||||
|         Ok(Request { |         Ok(Request { | ||||||
|             method: head.method, |             method: head.method, | ||||||
| @@ -134,17 +140,30 @@ impl Request<Streaming> { | |||||||
| impl Write for Request<Streaming> { | impl Write for Request<Streaming> { | ||||||
|     #[inline] |     #[inline] | ||||||
|     fn write(&mut self, msg: &[u8]) -> io::Result<usize> { |     fn write(&mut self, msg: &[u8]) -> io::Result<usize> { | ||||||
|         self.message.write(msg) |         match self.message.write(msg) { | ||||||
|  |             Ok(n) => Ok(n), | ||||||
|  |             Err(e) => { | ||||||
|  |                 let _ = self.message.close_connection(); | ||||||
|  |                 Err(e) | ||||||
|  |             } | ||||||
|  |         } | ||||||
|     } |     } | ||||||
|  |  | ||||||
|     #[inline] |     #[inline] | ||||||
|     fn flush(&mut self) -> io::Result<()> { |     fn flush(&mut self) -> io::Result<()> { | ||||||
|         self.message.flush() |         match self.message.flush() { | ||||||
|  |             Ok(r) => Ok(r), | ||||||
|  |             Err(e) => { | ||||||
|  |                 let _ = self.message.close_connection(); | ||||||
|  |                 Err(e) | ||||||
|  |             } | ||||||
|  |         } | ||||||
|     } |     } | ||||||
| } | } | ||||||
|  |  | ||||||
| #[cfg(test)] | #[cfg(test)] | ||||||
| mod tests { | mod tests { | ||||||
|  |     use std::io::Write; | ||||||
|     use std::str::from_utf8; |     use std::str::from_utf8; | ||||||
|     use url::Url; |     use url::Url; | ||||||
|     use method::Method::{Get, Head, Post}; |     use method::Method::{Get, Head, Post}; | ||||||
| @@ -237,4 +256,24 @@ mod tests { | |||||||
|         assert!(!s.contains("Content-Length:")); |         assert!(!s.contains("Content-Length:")); | ||||||
|         assert!(s.contains("Transfer-Encoding:")); |         assert!(s.contains("Transfer-Encoding:")); | ||||||
|     } |     } | ||||||
|  |  | ||||||
|  |     #[test] | ||||||
|  |     fn test_write_error_closes() { | ||||||
|  |         let url = Url::parse("http://hyper.rs").unwrap(); | ||||||
|  |         let req = Request::with_connector( | ||||||
|  |             Get, url, &mut MockConnector | ||||||
|  |         ).unwrap(); | ||||||
|  |         let mut req = req.start().unwrap(); | ||||||
|  |  | ||||||
|  |         req.message.downcast_mut::<Http11Message>().unwrap() | ||||||
|  |             .get_mut().downcast_mut::<MockStream>().unwrap() | ||||||
|  |             .error_on_write = true; | ||||||
|  |  | ||||||
|  |         req.write(b"foo").unwrap(); | ||||||
|  |         assert!(req.flush().is_err()); | ||||||
|  |  | ||||||
|  |         assert!(req.message.downcast_ref::<Http11Message>().unwrap() | ||||||
|  |             .get_ref().downcast_ref::<MockStream>().unwrap() | ||||||
|  |             .is_closed); | ||||||
|  |     } | ||||||
| } | } | ||||||
|   | |||||||
| @@ -37,7 +37,13 @@ impl Response { | |||||||
|     /// Creates a new response received from the server on the given `HttpMessage`. |     /// Creates a new response received from the server on the given `HttpMessage`. | ||||||
|     pub fn with_message(url: Url, mut message: Box<HttpMessage>) -> ::Result<Response> { |     pub fn with_message(url: Url, mut message: Box<HttpMessage>) -> ::Result<Response> { | ||||||
|         trace!("Response::with_message"); |         trace!("Response::with_message"); | ||||||
|         let ResponseHead { headers, raw_status, version } = try!(message.get_incoming()); |         let ResponseHead { headers, raw_status, version } = match message.get_incoming() { | ||||||
|  |             Ok(head) => head, | ||||||
|  |             Err(e) => { | ||||||
|  |                 let _ = message.close_connection(); | ||||||
|  |                 return Err(From::from(e)); | ||||||
|  |             } | ||||||
|  |         }; | ||||||
|         let status = status::StatusCode::from_u16(raw_status.0); |         let status = status::StatusCode::from_u16(raw_status.0); | ||||||
|         debug!("version={:?}, status={:?}", version, status); |         debug!("version={:?}, status={:?}", version, status); | ||||||
|         debug!("headers={:?}", headers); |         debug!("headers={:?}", headers); | ||||||
| @@ -54,6 +60,7 @@ impl Response { | |||||||
|     } |     } | ||||||
|  |  | ||||||
|     /// Get the raw status code and reason. |     /// Get the raw status code and reason. | ||||||
|  |     #[inline] | ||||||
|     pub fn status_raw(&self) -> &RawStatus { |     pub fn status_raw(&self) -> &RawStatus { | ||||||
|         &self.status_raw |         &self.status_raw | ||||||
|     } |     } | ||||||
| @@ -68,6 +75,10 @@ impl Read for Response { | |||||||
|                 self.is_drained = true; |                 self.is_drained = true; | ||||||
|                 Ok(0) |                 Ok(0) | ||||||
|             }, |             }, | ||||||
|  |             Err(e) => { | ||||||
|  |                 let _ = self.message.close_connection(); | ||||||
|  |                 Err(e) | ||||||
|  |             } | ||||||
|             r => r |             r => r | ||||||
|         } |         } | ||||||
|     } |     } | ||||||
|   | |||||||
							
								
								
									
										32
									
								
								src/mock.rs
									
									
									
									
									
								
							
							
						
						
									
										32
									
								
								src/mock.rs
									
									
									
									
									
								
							| @@ -2,7 +2,7 @@ use std::fmt; | |||||||
| use std::ascii::AsciiExt; | use std::ascii::AsciiExt; | ||||||
| use std::io::{self, Read, Write, Cursor}; | use std::io::{self, Read, Write, Cursor}; | ||||||
| use std::cell::RefCell; | use std::cell::RefCell; | ||||||
| use std::net::SocketAddr; | use std::net::{SocketAddr, Shutdown}; | ||||||
| use std::sync::{Arc, Mutex}; | use std::sync::{Arc, Mutex}; | ||||||
| #[cfg(feature = "timeouts")] | #[cfg(feature = "timeouts")] | ||||||
| use std::time::Duration; | use std::time::Duration; | ||||||
| @@ -21,10 +21,13 @@ use net::{NetworkStream, NetworkConnector}; | |||||||
| pub struct MockStream { | pub struct MockStream { | ||||||
|     pub read: Cursor<Vec<u8>>, |     pub read: Cursor<Vec<u8>>, | ||||||
|     pub write: Vec<u8>, |     pub write: Vec<u8>, | ||||||
|  |     pub is_closed: bool, | ||||||
|  |     pub error_on_write: bool, | ||||||
|  |     pub error_on_read: bool, | ||||||
|     #[cfg(feature = "timeouts")] |     #[cfg(feature = "timeouts")] | ||||||
|     pub read_timeout: Cell<Option<Duration>>, |     pub read_timeout: Cell<Option<Duration>>, | ||||||
|     #[cfg(feature = "timeouts")] |     #[cfg(feature = "timeouts")] | ||||||
|     pub write_timeout: Cell<Option<Duration>> |     pub write_timeout: Cell<Option<Duration>>, | ||||||
| } | } | ||||||
|  |  | ||||||
| impl fmt::Debug for MockStream { | impl fmt::Debug for MockStream { | ||||||
| @@ -48,7 +51,10 @@ impl MockStream { | |||||||
|     pub fn with_input(input: &[u8]) -> MockStream { |     pub fn with_input(input: &[u8]) -> MockStream { | ||||||
|         MockStream { |         MockStream { | ||||||
|             read: Cursor::new(input.to_vec()), |             read: Cursor::new(input.to_vec()), | ||||||
|             write: vec![] |             write: vec![], | ||||||
|  |             is_closed: false, | ||||||
|  |             error_on_write: false, | ||||||
|  |             error_on_read: false, | ||||||
|         } |         } | ||||||
|     } |     } | ||||||
|  |  | ||||||
| @@ -57,6 +63,9 @@ impl MockStream { | |||||||
|         MockStream { |         MockStream { | ||||||
|             read: Cursor::new(input.to_vec()), |             read: Cursor::new(input.to_vec()), | ||||||
|             write: vec![], |             write: vec![], | ||||||
|  |             is_closed: false, | ||||||
|  |             error_on_write: false, | ||||||
|  |             error_on_read: false, | ||||||
|             read_timeout: Cell::new(None), |             read_timeout: Cell::new(None), | ||||||
|             write_timeout: Cell::new(None), |             write_timeout: Cell::new(None), | ||||||
|         } |         } | ||||||
| @@ -65,14 +74,22 @@ impl MockStream { | |||||||
|  |  | ||||||
| impl Read for MockStream { | impl Read for MockStream { | ||||||
|     fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> { |     fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> { | ||||||
|  |         if self.error_on_read { | ||||||
|  |             Err(io::Error::new(io::ErrorKind::Other, "mock error")) | ||||||
|  |         } else { | ||||||
|             self.read.read(buf) |             self.read.read(buf) | ||||||
|         } |         } | ||||||
|  |     } | ||||||
| } | } | ||||||
|  |  | ||||||
| impl Write for MockStream { | impl Write for MockStream { | ||||||
|     fn write(&mut self, msg: &[u8]) -> io::Result<usize> { |     fn write(&mut self, msg: &[u8]) -> io::Result<usize> { | ||||||
|  |         if self.error_on_write { | ||||||
|  |             Err(io::Error::new(io::ErrorKind::Other, "mock error")) | ||||||
|  |         } else { | ||||||
|             Write::write(&mut self.write, msg) |             Write::write(&mut self.write, msg) | ||||||
|         } |         } | ||||||
|  |     } | ||||||
|  |  | ||||||
|     fn flush(&mut self) -> io::Result<()> { |     fn flush(&mut self) -> io::Result<()> { | ||||||
|         Ok(()) |         Ok(()) | ||||||
| @@ -95,6 +112,11 @@ impl NetworkStream for MockStream { | |||||||
|         self.write_timeout.set(dur); |         self.write_timeout.set(dur); | ||||||
|         Ok(()) |         Ok(()) | ||||||
|     } |     } | ||||||
|  |  | ||||||
|  |     fn close(&mut self, _how: Shutdown) -> io::Result<()> { | ||||||
|  |         self.is_closed = true; | ||||||
|  |         Ok(()) | ||||||
|  |     } | ||||||
| } | } | ||||||
|  |  | ||||||
| /// A wrapper around a `MockStream` that allows one to clone it and keep an independent copy to the | /// A wrapper around a `MockStream` that allows one to clone it and keep an independent copy to the | ||||||
| @@ -144,6 +166,10 @@ impl NetworkStream for CloneableMockStream { | |||||||
|     fn set_write_timeout(&self, dur: Option<Duration>) -> io::Result<()> { |     fn set_write_timeout(&self, dur: Option<Duration>) -> io::Result<()> { | ||||||
|         self.inner.lock().unwrap().set_write_timeout(dur) |         self.inner.lock().unwrap().set_write_timeout(dur) | ||||||
|     } |     } | ||||||
|  |  | ||||||
|  |     fn close(&mut self, how: Shutdown) -> io::Result<()> { | ||||||
|  |         NetworkStream::close(&mut *self.inner.lock().unwrap(), how) | ||||||
|  |     } | ||||||
| } | } | ||||||
|  |  | ||||||
| impl CloneableMockStream { | impl CloneableMockStream { | ||||||
|   | |||||||
		Reference in New Issue
	
	Block a user