fix(server): send 400 responses on parse errors before closing connection

This commit is contained in:
Sean McArthur
2018-01-23 15:31:26 -08:00
parent 44c34ce9ad
commit 7cb72d2019
4 changed files with 109 additions and 2 deletions

View File

@@ -186,7 +186,8 @@ where I: AsyncRead + AsyncWrite,
let was_mid_parse = !self.io.read_buf().is_empty();
return if was_mid_parse || must_error {
debug!("parse error ({}) with {} bytes", e, self.io.read_buf().len());
Err(e)
self.on_parse_error(e)
.map(|()| Async::NotReady)
} else {
debug!("read eof");
Ok(Async::Ready(None))
@@ -213,7 +214,8 @@ where I: AsyncRead + AsyncWrite,
Err(e) => {
debug!("decoder error = {:?}", e);
self.state.close_read();
return Err(e);
return self.on_parse_error(e)
.map(|()| Async::NotReady);
}
};
@@ -548,6 +550,27 @@ where I: AsyncRead + AsyncWrite,
Ok(AsyncSink::Ready)
}
// When we get a parse error, depending on what side we are, we might be able
// to write a response before closing the connection.
//
// - Client: there is nothing we can do
// - Server: if Response hasn't been written yet, we can send a 4xx response
fn on_parse_error(&mut self, err: ::Error) -> ::Result<()> {
match self.state.writing {
Writing::Init => {
if let Some(msg) = T::on_error(&err) {
self.write_head(msg, false);
self.state.error = Some(err);
return Ok(());
}
}
_ => (),
}
// fallback is pass the error back up
Err(err)
}
fn write_queued(&mut self) -> Poll<(), io::Error> {
trace!("Conn::write_queued()");
let state = match self.state.writing {

View File

@@ -150,6 +150,26 @@ impl Http1Transaction for ServerTransaction {
ret
}
fn on_error(err: &::Error) -> Option<MessageHead<Self::Outgoing>> {
let status = match err {
&::Error::Method |
&::Error::Version |
&::Error::Header |
&::Error::Uri(_) => {
StatusCode::BadRequest
},
&::Error::TooLarge => {
StatusCode::RequestHeaderFieldsTooLarge
}
_ => return None,
};
debug!("sending automatic response ({}) for parse error", status);
let mut msg = MessageHead::default();
msg.subject = status;
Some(msg)
}
fn should_error_on_parse_eof() -> bool {
false
}
@@ -317,6 +337,11 @@ impl Http1Transaction for ClientTransaction {
Ok(body)
}
fn on_error(_err: &::Error) -> Option<MessageHead<Self::Outgoing>> {
// we can't tell the server about any errors it creates
None
}
fn should_error_on_parse_eof() -> bool {
true
}

View File

@@ -149,6 +149,7 @@ pub trait Http1Transaction {
fn parse(bytes: &mut BytesMut) -> ParseResult<Self::Incoming>;
fn decoder(head: &MessageHead<Self::Incoming>, method: &mut Option<::Method>) -> ::Result<Option<h1::Decoder>>;
fn encode(head: MessageHead<Self::Outgoing>, has_body: bool, method: &mut Option<Method>, dst: &mut Vec<u8>) -> ::Result<h1::Encoder>;
fn on_error(err: &::Error) -> Option<MessageHead<Self::Outgoing>>;
fn should_error_on_parse_eof() -> bool;
fn should_read_first() -> bool;

View File

@@ -900,6 +900,64 @@ fn returning_1xx_response_is_error() {
core.run(fut).unwrap_err();
}
#[test]
fn parse_errors_send_4xx_response() {
let mut core = Core::new().unwrap();
let listener = TcpListener::bind(&"127.0.0.1:0".parse().unwrap(), &core.handle()).unwrap();
let addr = listener.local_addr().unwrap();
thread::spawn(move || {
let mut tcp = connect(&addr);
tcp.write_all(b"GE T / HTTP/1.1\r\n\r\n").unwrap();
let mut buf = [0; 256];
tcp.read(&mut buf).unwrap();
let expected = "HTTP/1.1 400 ";
assert_eq!(s(&buf[..expected.len()]), expected);
});
let fut = listener.incoming()
.into_future()
.map_err(|_| unreachable!())
.and_then(|(item, _incoming)| {
let (socket, _) = item.unwrap();
Http::<hyper::Chunk>::new()
.serve_connection(socket, HelloWorld)
.map(|_| ())
});
core.run(fut).unwrap_err();
}
#[test]
fn illegal_request_length_returns_400_response() {
let mut core = Core::new().unwrap();
let listener = TcpListener::bind(&"127.0.0.1:0".parse().unwrap(), &core.handle()).unwrap();
let addr = listener.local_addr().unwrap();
thread::spawn(move || {
let mut tcp = connect(&addr);
tcp.write_all(b"POST / HTTP/1.1\r\nContent-Length: foo\r\n\r\n").unwrap();
let mut buf = [0; 256];
tcp.read(&mut buf).unwrap();
let expected = "HTTP/1.1 400 ";
assert_eq!(s(&buf[..expected.len()]), expected);
});
let fut = listener.incoming()
.into_future()
.map_err(|_| unreachable!())
.and_then(|(item, _incoming)| {
let (socket, _) = item.unwrap();
Http::<hyper::Chunk>::new()
.serve_connection(socket, HelloWorld)
.map(|_| ())
});
core.run(fut).unwrap_err();
}
#[test]
fn remote_addr() {
let server = serve();