fix(server): send 400 responses on parse errors before closing connection
This commit is contained in:
@@ -186,7 +186,8 @@ where I: AsyncRead + AsyncWrite,
|
|||||||
let was_mid_parse = !self.io.read_buf().is_empty();
|
let was_mid_parse = !self.io.read_buf().is_empty();
|
||||||
return if was_mid_parse || must_error {
|
return if was_mid_parse || must_error {
|
||||||
debug!("parse error ({}) with {} bytes", e, self.io.read_buf().len());
|
debug!("parse error ({}) with {} bytes", e, self.io.read_buf().len());
|
||||||
Err(e)
|
self.on_parse_error(e)
|
||||||
|
.map(|()| Async::NotReady)
|
||||||
} else {
|
} else {
|
||||||
debug!("read eof");
|
debug!("read eof");
|
||||||
Ok(Async::Ready(None))
|
Ok(Async::Ready(None))
|
||||||
@@ -213,7 +214,8 @@ where I: AsyncRead + AsyncWrite,
|
|||||||
Err(e) => {
|
Err(e) => {
|
||||||
debug!("decoder error = {:?}", e);
|
debug!("decoder error = {:?}", e);
|
||||||
self.state.close_read();
|
self.state.close_read();
|
||||||
return Err(e);
|
return self.on_parse_error(e)
|
||||||
|
.map(|()| Async::NotReady);
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
@@ -548,6 +550,27 @@ where I: AsyncRead + AsyncWrite,
|
|||||||
Ok(AsyncSink::Ready)
|
Ok(AsyncSink::Ready)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// When we get a parse error, depending on what side we are, we might be able
|
||||||
|
// to write a response before closing the connection.
|
||||||
|
//
|
||||||
|
// - Client: there is nothing we can do
|
||||||
|
// - Server: if Response hasn't been written yet, we can send a 4xx response
|
||||||
|
fn on_parse_error(&mut self, err: ::Error) -> ::Result<()> {
|
||||||
|
match self.state.writing {
|
||||||
|
Writing::Init => {
|
||||||
|
if let Some(msg) = T::on_error(&err) {
|
||||||
|
self.write_head(msg, false);
|
||||||
|
self.state.error = Some(err);
|
||||||
|
return Ok(());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
_ => (),
|
||||||
|
}
|
||||||
|
|
||||||
|
// fallback is pass the error back up
|
||||||
|
Err(err)
|
||||||
|
}
|
||||||
|
|
||||||
fn write_queued(&mut self) -> Poll<(), io::Error> {
|
fn write_queued(&mut self) -> Poll<(), io::Error> {
|
||||||
trace!("Conn::write_queued()");
|
trace!("Conn::write_queued()");
|
||||||
let state = match self.state.writing {
|
let state = match self.state.writing {
|
||||||
|
|||||||
@@ -150,6 +150,26 @@ impl Http1Transaction for ServerTransaction {
|
|||||||
ret
|
ret
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn on_error(err: &::Error) -> Option<MessageHead<Self::Outgoing>> {
|
||||||
|
let status = match err {
|
||||||
|
&::Error::Method |
|
||||||
|
&::Error::Version |
|
||||||
|
&::Error::Header |
|
||||||
|
&::Error::Uri(_) => {
|
||||||
|
StatusCode::BadRequest
|
||||||
|
},
|
||||||
|
&::Error::TooLarge => {
|
||||||
|
StatusCode::RequestHeaderFieldsTooLarge
|
||||||
|
}
|
||||||
|
_ => return None,
|
||||||
|
};
|
||||||
|
|
||||||
|
debug!("sending automatic response ({}) for parse error", status);
|
||||||
|
let mut msg = MessageHead::default();
|
||||||
|
msg.subject = status;
|
||||||
|
Some(msg)
|
||||||
|
}
|
||||||
|
|
||||||
fn should_error_on_parse_eof() -> bool {
|
fn should_error_on_parse_eof() -> bool {
|
||||||
false
|
false
|
||||||
}
|
}
|
||||||
@@ -317,6 +337,11 @@ impl Http1Transaction for ClientTransaction {
|
|||||||
Ok(body)
|
Ok(body)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn on_error(_err: &::Error) -> Option<MessageHead<Self::Outgoing>> {
|
||||||
|
// we can't tell the server about any errors it creates
|
||||||
|
None
|
||||||
|
}
|
||||||
|
|
||||||
fn should_error_on_parse_eof() -> bool {
|
fn should_error_on_parse_eof() -> bool {
|
||||||
true
|
true
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -149,6 +149,7 @@ pub trait Http1Transaction {
|
|||||||
fn parse(bytes: &mut BytesMut) -> ParseResult<Self::Incoming>;
|
fn parse(bytes: &mut BytesMut) -> ParseResult<Self::Incoming>;
|
||||||
fn decoder(head: &MessageHead<Self::Incoming>, method: &mut Option<::Method>) -> ::Result<Option<h1::Decoder>>;
|
fn decoder(head: &MessageHead<Self::Incoming>, method: &mut Option<::Method>) -> ::Result<Option<h1::Decoder>>;
|
||||||
fn encode(head: MessageHead<Self::Outgoing>, has_body: bool, method: &mut Option<Method>, dst: &mut Vec<u8>) -> ::Result<h1::Encoder>;
|
fn encode(head: MessageHead<Self::Outgoing>, has_body: bool, method: &mut Option<Method>, dst: &mut Vec<u8>) -> ::Result<h1::Encoder>;
|
||||||
|
fn on_error(err: &::Error) -> Option<MessageHead<Self::Outgoing>>;
|
||||||
|
|
||||||
fn should_error_on_parse_eof() -> bool;
|
fn should_error_on_parse_eof() -> bool;
|
||||||
fn should_read_first() -> bool;
|
fn should_read_first() -> bool;
|
||||||
|
|||||||
@@ -900,6 +900,64 @@ fn returning_1xx_response_is_error() {
|
|||||||
core.run(fut).unwrap_err();
|
core.run(fut).unwrap_err();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn parse_errors_send_4xx_response() {
|
||||||
|
let mut core = Core::new().unwrap();
|
||||||
|
let listener = TcpListener::bind(&"127.0.0.1:0".parse().unwrap(), &core.handle()).unwrap();
|
||||||
|
let addr = listener.local_addr().unwrap();
|
||||||
|
|
||||||
|
thread::spawn(move || {
|
||||||
|
let mut tcp = connect(&addr);
|
||||||
|
tcp.write_all(b"GE T / HTTP/1.1\r\n\r\n").unwrap();
|
||||||
|
let mut buf = [0; 256];
|
||||||
|
tcp.read(&mut buf).unwrap();
|
||||||
|
|
||||||
|
let expected = "HTTP/1.1 400 ";
|
||||||
|
assert_eq!(s(&buf[..expected.len()]), expected);
|
||||||
|
});
|
||||||
|
|
||||||
|
let fut = listener.incoming()
|
||||||
|
.into_future()
|
||||||
|
.map_err(|_| unreachable!())
|
||||||
|
.and_then(|(item, _incoming)| {
|
||||||
|
let (socket, _) = item.unwrap();
|
||||||
|
Http::<hyper::Chunk>::new()
|
||||||
|
.serve_connection(socket, HelloWorld)
|
||||||
|
.map(|_| ())
|
||||||
|
});
|
||||||
|
|
||||||
|
core.run(fut).unwrap_err();
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn illegal_request_length_returns_400_response() {
|
||||||
|
let mut core = Core::new().unwrap();
|
||||||
|
let listener = TcpListener::bind(&"127.0.0.1:0".parse().unwrap(), &core.handle()).unwrap();
|
||||||
|
let addr = listener.local_addr().unwrap();
|
||||||
|
|
||||||
|
thread::spawn(move || {
|
||||||
|
let mut tcp = connect(&addr);
|
||||||
|
tcp.write_all(b"POST / HTTP/1.1\r\nContent-Length: foo\r\n\r\n").unwrap();
|
||||||
|
let mut buf = [0; 256];
|
||||||
|
tcp.read(&mut buf).unwrap();
|
||||||
|
|
||||||
|
let expected = "HTTP/1.1 400 ";
|
||||||
|
assert_eq!(s(&buf[..expected.len()]), expected);
|
||||||
|
});
|
||||||
|
|
||||||
|
let fut = listener.incoming()
|
||||||
|
.into_future()
|
||||||
|
.map_err(|_| unreachable!())
|
||||||
|
.and_then(|(item, _incoming)| {
|
||||||
|
let (socket, _) = item.unwrap();
|
||||||
|
Http::<hyper::Chunk>::new()
|
||||||
|
.serve_connection(socket, HelloWorld)
|
||||||
|
.map(|_| ())
|
||||||
|
});
|
||||||
|
|
||||||
|
core.run(fut).unwrap_err();
|
||||||
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn remote_addr() {
|
fn remote_addr() {
|
||||||
let server = serve();
|
let server = serve();
|
||||||
|
|||||||
Reference in New Issue
Block a user