diff --git a/src/proto/h1/conn.rs b/src/proto/h1/conn.rs index e2e5a4b8..a4124fd6 100644 --- a/src/proto/h1/conn.rs +++ b/src/proto/h1/conn.rs @@ -451,43 +451,23 @@ where I: AsyncRead + AsyncWrite, } } - pub fn write_body(&mut self, chunk: Option) { + pub fn write_body(&mut self, chunk: B) { debug_assert!(self.can_write_body() && self.can_buffer_body()); + // empty chunks should be discarded at Dispatcher level + debug_assert!(chunk.remaining() != 0); let state = match self.state.writing { Writing::Body(ref mut encoder) => { - if let Some(chunk) = chunk { - if chunk.remaining() == 0 { - return; - } + self.io.buffer(encoder.encode(chunk)); - let encoded = encoder.encode(chunk); - self.io.buffer(encoded); - - if encoder.is_eof() { - if encoder.is_last() { - Writing::Closed - } else { - Writing::KeepAlive - } + if encoder.is_eof() { + if encoder.is_last() { + Writing::Closed } else { - return; + Writing::KeepAlive } } else { - // end of stream, that means we should try to eof - match encoder.end() { - Ok(end) => { - if let Some(end) = end { - self.io.buffer(end); - } - if encoder.is_last() { - Writing::Closed - } else { - Writing::KeepAlive - } - }, - Err(_not_eof) => Writing::Closed, - } + return; } }, _ => unreachable!("write_body invalid state: {:?}", self.state.writing), @@ -496,6 +476,61 @@ where I: AsyncRead + AsyncWrite, self.state.writing = state; } + pub fn write_body_and_end(&mut self, chunk: B) { + debug_assert!(self.can_write_body() && self.can_buffer_body()); + // empty chunks should be discarded at Dispatcher level + debug_assert!(chunk.remaining() != 0); + + let state = match self.state.writing { + Writing::Body(ref mut encoder) => { + self.io.buffer(encoder.encode(chunk)); + match encoder.end() { + Ok(end) => { + if let Some(end) = end { + self.io.buffer(end); + } + if encoder.is_last() { + Writing::Closed + } else { + Writing::KeepAlive + } + }, + Err(_not_eof) => Writing::Closed, + } + }, + _ => unreachable!("write_body invalid state: {:?}", self.state.writing), + }; + + self.state.writing = state; + } + + pub fn end_body(&mut self) { + debug_assert!(self.can_write_body()); + + let state = match self.state.writing { + Writing::Body(ref mut encoder) => { + // end of stream, that means we should try to eof + match encoder.end() { + Ok(end) => { + if let Some(end) = end { + self.io.buffer(end); + } + if encoder.is_last() { + Writing::Closed + } else { + Writing::KeepAlive + } + }, + Err(_not_eof) => Writing::Closed, + } + }, + _ => return, + }; + + self.state.writing = state; + + } + // When we get a parse error, depending on what side we are, we might be able // to write a response before closing the connection. // diff --git a/src/proto/h1/dispatch.rs b/src/proto/h1/dispatch.rs index 8a7421b3..1f0ceef7 100644 --- a/src/proto/h1/dispatch.rs +++ b/src/proto/h1/dispatch.rs @@ -235,30 +235,39 @@ where } else if !self.conn.can_buffer_body() { try_ready!(self.poll_flush()); } else if let Some(mut body) = self.body_rx.take() { - let chunk = match body.poll_data().map_err(::Error::new_user_body)? { + if !self.conn.can_write_body() { + trace!( + "no more write body allowed, user body is_end_stream = {}", + body.is_end_stream(), + ); + continue; + } + match body.poll_data().map_err(::Error::new_user_body)? { Async::Ready(Some(chunk)) => { - self.body_rx = Some(body); - chunk + let eos = body.is_end_stream(); + if eos { + if chunk.remaining() == 0 { + trace!("discarding empty chunk"); + self.conn.end_body(); + } else { + self.conn.write_body_and_end(chunk); + } + } else { + self.body_rx = Some(body); + if chunk.remaining() == 0 { + trace!("discarding empty chunk"); + continue; + } + self.conn.write_body(chunk); + } }, Async::Ready(None) => { - if self.conn.can_write_body() { - self.conn.write_body(None); - } - continue; + self.conn.end_body(); }, Async::NotReady => { self.body_rx = Some(body); return Ok(Async::NotReady); } - }; - - if self.conn.can_write_body() { - self.conn.write_body(Some(chunk)); - // This allows when chunk is `None`, or `Some([])`. - } else if chunk.remaining() == 0 { - // ok - } else { - warn!("unexpected chunk when body cannot write"); } } else { return Ok(Async::NotReady); @@ -531,4 +540,25 @@ mod tests { Ok::<(), ()>(()) }).wait().unwrap(); } + + #[test] + fn body_empty_chunks_ignored() { + let _ = pretty_env_logger::try_init(); + ::futures::lazy(|| { + let io = AsyncIo::new_buf(vec![], 0); + let (mut tx, rx) = ::client::dispatch::channel(); + let conn = Conn::<_, ::Chunk, ClientTransaction>::new(io); + let mut dispatcher = Dispatcher::new(Client::new(rx), conn); + + // First poll is needed to allow tx to send... + assert!(dispatcher.poll().expect("nothing is ready").is_not_ready()); + + let body = ::Body::wrap_stream(::futures::stream::once(Ok::<_, ::Error>(""))); + + let _res_rx = tx.try_send(::Request::new(body)).unwrap(); + + dispatcher.poll().expect("empty body shouldn't panic"); + Ok::<(), ()>(()) + }).wait().unwrap(); + } } diff --git a/src/proto/h1/encode.rs b/src/proto/h1/encode.rs index c78bd255..2072f3a5 100644 --- a/src/proto/h1/encode.rs +++ b/src/proto/h1/encode.rs @@ -32,7 +32,7 @@ enum Kind { /// /// This is mostly only used with HTTP/1.0 with a length. This kind requires /// the connection to be closed when the body is finished. - Eof + CloseDelimited, } #[derive(Debug)] @@ -58,8 +58,8 @@ impl Encoder { Encoder::new(Kind::Length(len)) } - pub fn eof() -> Encoder { - Encoder::new(Kind::Eof) + pub fn close_delimited() -> Encoder { + Encoder::new(Kind::CloseDelimited) } pub fn is_eof(&self) -> bool { @@ -112,8 +112,8 @@ impl Encoder { BufKind::Exact(msg) } }, - Kind::Eof => { - trace!("eof write {}B", len); + Kind::CloseDelimited => { + trace!("close delimited write {}B", len); BufKind::Exact(msg) } }; @@ -323,7 +323,7 @@ mod tests { #[test] fn eof() { - let mut encoder = Encoder::eof(); + let mut encoder = Encoder::close_delimited(); let mut dst = Vec::new(); diff --git a/src/proto/h1/role.rs b/src/proto/h1/role.rs index 3772bafb..cff7de82 100644 --- a/src/proto/h1/role.rs +++ b/src/proto/h1/role.rs @@ -514,13 +514,13 @@ fn set_length(headers: &mut HeaderMap, body: BodyLength, can_chunked: bool) -> E let encoder = if headers.remove(TRANSFER_ENCODING).is_some() { trace!("removing illegal transfer-encoding header"); should_remove_con_len = true; - Encoder::eof() + Encoder::close_delimited() } else if let Some(len) = existing_con_len { Encoder::length(len) } else if let BodyLength::Known(len) = body { set_content_length(headers, len) } else { - Encoder::eof() + Encoder::close_delimited() }; if should_remove_con_len && existing_con_len.is_some() {