fix(server): send 400 responses on parse errors before closing connection
This commit is contained in:
		| @@ -186,7 +186,8 @@ where I: AsyncRead + AsyncWrite, | |||||||
|                     let was_mid_parse = !self.io.read_buf().is_empty(); |                     let was_mid_parse = !self.io.read_buf().is_empty(); | ||||||
|                     return if was_mid_parse || must_error { |                     return if was_mid_parse || must_error { | ||||||
|                         debug!("parse error ({}) with {} bytes", e, self.io.read_buf().len()); |                         debug!("parse error ({}) with {} bytes", e, self.io.read_buf().len()); | ||||||
|                         Err(e) |                         self.on_parse_error(e) | ||||||
|  |                             .map(|()| Async::NotReady) | ||||||
|                     } else { |                     } else { | ||||||
|                         debug!("read eof"); |                         debug!("read eof"); | ||||||
|                         Ok(Async::Ready(None)) |                         Ok(Async::Ready(None)) | ||||||
| @@ -213,7 +214,8 @@ where I: AsyncRead + AsyncWrite, | |||||||
|                 Err(e) => { |                 Err(e) => { | ||||||
|                     debug!("decoder error = {:?}", e); |                     debug!("decoder error = {:?}", e); | ||||||
|                     self.state.close_read(); |                     self.state.close_read(); | ||||||
|                     return Err(e); |                     return self.on_parse_error(e) | ||||||
|  |                         .map(|()| Async::NotReady); | ||||||
|                 } |                 } | ||||||
|             }; |             }; | ||||||
|  |  | ||||||
| @@ -548,6 +550,27 @@ where I: AsyncRead + AsyncWrite, | |||||||
|         Ok(AsyncSink::Ready) |         Ok(AsyncSink::Ready) | ||||||
|     } |     } | ||||||
|  |  | ||||||
|  |     // When we get a parse error, depending on what side we are, we might be able | ||||||
|  |     // to write a response before closing the connection. | ||||||
|  |     // | ||||||
|  |     // - Client: there is nothing we can do | ||||||
|  |     // - Server: if Response hasn't been written yet, we can send a 4xx response | ||||||
|  |     fn on_parse_error(&mut self, err: ::Error) -> ::Result<()> { | ||||||
|  |         match self.state.writing { | ||||||
|  |             Writing::Init => { | ||||||
|  |                 if let Some(msg) = T::on_error(&err) { | ||||||
|  |                     self.write_head(msg, false); | ||||||
|  |                     self.state.error = Some(err); | ||||||
|  |                     return Ok(()); | ||||||
|  |                 } | ||||||
|  |             } | ||||||
|  |             _ => (), | ||||||
|  |         } | ||||||
|  |  | ||||||
|  |         // fallback is pass the error back up | ||||||
|  |         Err(err) | ||||||
|  |     } | ||||||
|  |  | ||||||
|     fn write_queued(&mut self) -> Poll<(), io::Error> { |     fn write_queued(&mut self) -> Poll<(), io::Error> { | ||||||
|         trace!("Conn::write_queued()"); |         trace!("Conn::write_queued()"); | ||||||
|         let state = match self.state.writing { |         let state = match self.state.writing { | ||||||
|   | |||||||
| @@ -150,6 +150,26 @@ impl Http1Transaction for ServerTransaction { | |||||||
|         ret |         ret | ||||||
|     } |     } | ||||||
|  |  | ||||||
|  |     fn on_error(err: &::Error) -> Option<MessageHead<Self::Outgoing>> { | ||||||
|  |         let status = match err { | ||||||
|  |             &::Error::Method | | ||||||
|  |             &::Error::Version | | ||||||
|  |             &::Error::Header | | ||||||
|  |             &::Error::Uri(_) => { | ||||||
|  |                 StatusCode::BadRequest | ||||||
|  |             }, | ||||||
|  |             &::Error::TooLarge => { | ||||||
|  |                 StatusCode::RequestHeaderFieldsTooLarge | ||||||
|  |             } | ||||||
|  |             _ => return None, | ||||||
|  |         }; | ||||||
|  |  | ||||||
|  |         debug!("sending automatic response ({}) for parse error", status); | ||||||
|  |         let mut msg = MessageHead::default(); | ||||||
|  |         msg.subject = status; | ||||||
|  |         Some(msg) | ||||||
|  |     } | ||||||
|  |  | ||||||
|     fn should_error_on_parse_eof() -> bool { |     fn should_error_on_parse_eof() -> bool { | ||||||
|         false |         false | ||||||
|     } |     } | ||||||
| @@ -317,6 +337,11 @@ impl Http1Transaction for ClientTransaction { | |||||||
|         Ok(body) |         Ok(body) | ||||||
|     } |     } | ||||||
|  |  | ||||||
|  |     fn on_error(_err: &::Error) -> Option<MessageHead<Self::Outgoing>> { | ||||||
|  |         // we can't tell the server about any errors it creates | ||||||
|  |         None | ||||||
|  |     } | ||||||
|  |  | ||||||
|     fn should_error_on_parse_eof() -> bool { |     fn should_error_on_parse_eof() -> bool { | ||||||
|         true |         true | ||||||
|     } |     } | ||||||
|   | |||||||
| @@ -149,6 +149,7 @@ pub trait Http1Transaction { | |||||||
|     fn parse(bytes: &mut BytesMut) -> ParseResult<Self::Incoming>; |     fn parse(bytes: &mut BytesMut) -> ParseResult<Self::Incoming>; | ||||||
|     fn decoder(head: &MessageHead<Self::Incoming>, method: &mut Option<::Method>) -> ::Result<Option<h1::Decoder>>; |     fn decoder(head: &MessageHead<Self::Incoming>, method: &mut Option<::Method>) -> ::Result<Option<h1::Decoder>>; | ||||||
|     fn encode(head: MessageHead<Self::Outgoing>, has_body: bool, method: &mut Option<Method>, dst: &mut Vec<u8>) -> ::Result<h1::Encoder>; |     fn encode(head: MessageHead<Self::Outgoing>, has_body: bool, method: &mut Option<Method>, dst: &mut Vec<u8>) -> ::Result<h1::Encoder>; | ||||||
|  |     fn on_error(err: &::Error) -> Option<MessageHead<Self::Outgoing>>; | ||||||
|  |  | ||||||
|     fn should_error_on_parse_eof() -> bool; |     fn should_error_on_parse_eof() -> bool; | ||||||
|     fn should_read_first() -> bool; |     fn should_read_first() -> bool; | ||||||
|   | |||||||
| @@ -900,6 +900,64 @@ fn returning_1xx_response_is_error() { | |||||||
|     core.run(fut).unwrap_err(); |     core.run(fut).unwrap_err(); | ||||||
| } | } | ||||||
|  |  | ||||||
|  | #[test] | ||||||
|  | fn parse_errors_send_4xx_response() { | ||||||
|  |     let mut core = Core::new().unwrap(); | ||||||
|  |     let listener = TcpListener::bind(&"127.0.0.1:0".parse().unwrap(), &core.handle()).unwrap(); | ||||||
|  |     let addr = listener.local_addr().unwrap(); | ||||||
|  |  | ||||||
|  |     thread::spawn(move || { | ||||||
|  |         let mut tcp = connect(&addr); | ||||||
|  |         tcp.write_all(b"GE T / HTTP/1.1\r\n\r\n").unwrap(); | ||||||
|  |         let mut buf = [0; 256]; | ||||||
|  |         tcp.read(&mut buf).unwrap(); | ||||||
|  |  | ||||||
|  |         let expected = "HTTP/1.1 400 "; | ||||||
|  |         assert_eq!(s(&buf[..expected.len()]), expected); | ||||||
|  |     }); | ||||||
|  |  | ||||||
|  |     let fut = listener.incoming() | ||||||
|  |         .into_future() | ||||||
|  |         .map_err(|_| unreachable!()) | ||||||
|  |         .and_then(|(item, _incoming)| { | ||||||
|  |             let (socket, _) = item.unwrap(); | ||||||
|  |             Http::<hyper::Chunk>::new() | ||||||
|  |                 .serve_connection(socket, HelloWorld) | ||||||
|  |                 .map(|_| ()) | ||||||
|  |         }); | ||||||
|  |  | ||||||
|  |     core.run(fut).unwrap_err(); | ||||||
|  | } | ||||||
|  |  | ||||||
|  | #[test] | ||||||
|  | fn illegal_request_length_returns_400_response() { | ||||||
|  |     let mut core = Core::new().unwrap(); | ||||||
|  |     let listener = TcpListener::bind(&"127.0.0.1:0".parse().unwrap(), &core.handle()).unwrap(); | ||||||
|  |     let addr = listener.local_addr().unwrap(); | ||||||
|  |  | ||||||
|  |     thread::spawn(move || { | ||||||
|  |         let mut tcp = connect(&addr); | ||||||
|  |         tcp.write_all(b"POST / HTTP/1.1\r\nContent-Length: foo\r\n\r\n").unwrap(); | ||||||
|  |         let mut buf = [0; 256]; | ||||||
|  |         tcp.read(&mut buf).unwrap(); | ||||||
|  |  | ||||||
|  |         let expected = "HTTP/1.1 400 "; | ||||||
|  |         assert_eq!(s(&buf[..expected.len()]), expected); | ||||||
|  |     }); | ||||||
|  |  | ||||||
|  |     let fut = listener.incoming() | ||||||
|  |         .into_future() | ||||||
|  |         .map_err(|_| unreachable!()) | ||||||
|  |         .and_then(|(item, _incoming)| { | ||||||
|  |             let (socket, _) = item.unwrap(); | ||||||
|  |             Http::<hyper::Chunk>::new() | ||||||
|  |                 .serve_connection(socket, HelloWorld) | ||||||
|  |                 .map(|_| ()) | ||||||
|  |         }); | ||||||
|  |  | ||||||
|  |     core.run(fut).unwrap_err(); | ||||||
|  | } | ||||||
|  |  | ||||||
| #[test] | #[test] | ||||||
| fn remote_addr() { | fn remote_addr() { | ||||||
|     let server = serve(); |     let server = serve(); | ||||||
|   | |||||||
		Reference in New Issue
	
	Block a user