refactor(http): Decoder::decode and MemRead in io now return Poll

This commit is contained in:
Yazad Daruvala
2017-06-18 20:17:01 -07:00
parent ca1fa81ce0
commit 80f16f1917
3 changed files with 100 additions and 84 deletions

View File

@@ -136,7 +136,7 @@ where I: AsyncRead + AsyncWrite,
let (reading, ret) = match self.state.reading { let (reading, ret) = match self.state.reading {
Reading::Body(ref mut decoder) => { Reading::Body(ref mut decoder) => {
let slice = try_nb!(decoder.decode(&mut self.io)); let slice = try_ready!(decoder.decode(&mut self.io));
if !slice.is_empty() { if !slice.is_empty() {
return Ok(Async::Ready(Some(http::Chunk::from(slice)))); return Ok(Async::Ready(Some(http::Chunk::from(slice))));
} else if decoder.is_eof() { } else if decoder.is_eof() {

View File

@@ -1,6 +1,7 @@
use std::usize; use std::usize;
use std::io; use std::io;
use futures::{Async, Poll};
use bytes::Bytes; use bytes::Bytes;
use http::io::MemRead; use http::io::MemRead;
@@ -79,15 +80,15 @@ impl Decoder {
} }
impl Decoder { impl Decoder {
pub fn decode<R: MemRead>(&mut self, body: &mut R) -> io::Result<Bytes> { pub fn decode<R: MemRead>(&mut self, body: &mut R) -> Poll<Bytes, io::Error> {
match self.kind { match self.kind {
Length(ref mut remaining) => { Length(ref mut remaining) => {
trace!("Sized read, remaining={:?}", remaining); trace!("Sized read, remaining={:?}", remaining);
if *remaining == 0 { if *remaining == 0 {
Ok(Bytes::new()) Ok(Async::Ready(Bytes::new()))
} else { } else {
let to_read = *remaining as usize; let to_read = *remaining as usize;
let buf = try!(body.read_mem(to_read)); let buf = try_ready!(body.read_mem(to_read));
let num = buf.as_ref().len() as u64; let num = buf.as_ref().len() as u64;
trace!("Length read: {}", num); trace!("Length read: {}", num);
if num > *remaining { if num > *remaining {
@@ -97,37 +98,33 @@ impl Decoder {
} else { } else {
*remaining -= num; *remaining -= num;
} }
Ok(buf) Ok(Async::Ready(buf))
} }
} }
Chunked(ref mut state, ref mut size) => { Chunked(ref mut state, ref mut size) => {
loop { loop {
let mut buf = None; let mut buf = None;
// advances the chunked state // advances the chunked state
*state = try!(state.step(body, size, &mut buf)); *state = try_ready!(state.step(body, size, &mut buf));
if *state == ChunkedState::End { if *state == ChunkedState::End {
trace!("end of chunked"); trace!("end of chunked");
return Ok(Bytes::new()); return Ok(Async::Ready(Bytes::new()));
} }
if let Some(buf) = buf { if let Some(buf) = buf {
return Ok(buf); return Ok(Async::Ready(buf));
} }
} }
} }
Eof(ref mut is_eof) => { Eof(ref mut is_eof) => {
if *is_eof { if *is_eof {
Ok(Bytes::new()) Ok(Async::Ready(Bytes::new()))
} else { } else {
// 8192 chosen because its about 2 packets, there probably // 8192 chosen because its about 2 packets, there probably
// won't be that much available, so don't have MemReaders // won't be that much available, so don't have MemReaders
// allocate buffers to big // allocate buffers to big
match body.read_mem(8192) { let slice = try_ready!(body.read_mem(8192));
Ok(slice) => { *is_eof = slice.is_empty();
*is_eof = slice.is_empty(); Ok(Async::Ready(slice))
Ok(slice)
}
other => other,
}
} }
} }
} }
@@ -136,7 +133,7 @@ impl Decoder {
macro_rules! byte ( macro_rules! byte (
($rdr:ident) => ({ ($rdr:ident) => ({
let buf = try!($rdr.read_mem(1)); let buf = try_ready!($rdr.read_mem(1));
if !buf.is_empty() { if !buf.is_empty() {
buf[0] buf[0]
} else { } else {
@@ -151,22 +148,22 @@ impl ChunkedState {
body: &mut R, body: &mut R,
size: &mut u64, size: &mut u64,
buf: &mut Option<Bytes>) buf: &mut Option<Bytes>)
-> io::Result<ChunkedState> { -> Poll<ChunkedState, io::Error> {
use self::ChunkedState::*; use self::ChunkedState::*;
Ok(match *self { match *self {
Size => try!(ChunkedState::read_size(body, size)), Size => ChunkedState::read_size(body, size),
SizeLws => try!(ChunkedState::read_size_lws(body)), SizeLws => ChunkedState::read_size_lws(body),
Extension => try!(ChunkedState::read_extension(body)), Extension => ChunkedState::read_extension(body),
SizeLf => try!(ChunkedState::read_size_lf(body, size)), SizeLf => ChunkedState::read_size_lf(body, size),
Body => try!(ChunkedState::read_body(body, size, buf)), Body => ChunkedState::read_body(body, size, buf),
BodyCr => try!(ChunkedState::read_body_cr(body)), BodyCr => ChunkedState::read_body_cr(body),
BodyLf => try!(ChunkedState::read_body_lf(body)), BodyLf => ChunkedState::read_body_lf(body),
EndCr => try!(ChunkedState::read_end_cr(body)), EndCr => ChunkedState::read_end_cr(body),
EndLf => try!(ChunkedState::read_end_lf(body)), EndLf => ChunkedState::read_end_lf(body),
End => ChunkedState::End, End => Ok(Async::Ready(ChunkedState::End)),
}) }
} }
fn read_size<R: MemRead>(rdr: &mut R, size: &mut u64) -> io::Result<ChunkedState> { fn read_size<R: MemRead>(rdr: &mut R, size: &mut u64) -> Poll<ChunkedState, io::Error> {
trace!("Read chunk hex size"); trace!("Read chunk hex size");
let radix = 16; let radix = 16;
match byte!(rdr) { match byte!(rdr) {
@@ -182,41 +179,41 @@ impl ChunkedState {
*size *= radix; *size *= radix;
*size += (b + 10 - b'A') as u64; *size += (b + 10 - b'A') as u64;
} }
b'\t' | b' ' => return Ok(ChunkedState::SizeLws), b'\t' | b' ' => return Ok(Async::Ready(ChunkedState::SizeLws)),
b';' => return Ok(ChunkedState::Extension), b';' => return Ok(Async::Ready(ChunkedState::Extension)),
b'\r' => return Ok(ChunkedState::SizeLf), b'\r' => return Ok(Async::Ready(ChunkedState::SizeLf)),
_ => { _ => {
return Err(io::Error::new(io::ErrorKind::InvalidInput, return Err(io::Error::new(io::ErrorKind::InvalidInput,
"Invalid chunk size line: Invalid Size")); "Invalid chunk size line: Invalid Size"));
} }
} }
Ok(ChunkedState::Size) Ok(Async::Ready(ChunkedState::Size))
} }
fn read_size_lws<R: MemRead>(rdr: &mut R) -> io::Result<ChunkedState> { fn read_size_lws<R: MemRead>(rdr: &mut R) -> Poll<ChunkedState, io::Error> {
trace!("read_size_lws"); trace!("read_size_lws");
match byte!(rdr) { match byte!(rdr) {
// LWS can follow the chunk size, but no more digits can come // LWS can follow the chunk size, but no more digits can come
b'\t' | b' ' => Ok(ChunkedState::SizeLws), b'\t' | b' ' => Ok(Async::Ready(ChunkedState::SizeLws)),
b';' => Ok(ChunkedState::Extension), b';' => Ok(Async::Ready(ChunkedState::Extension)),
b'\r' => Ok(ChunkedState::SizeLf), b'\r' => Ok(Async::Ready(ChunkedState::SizeLf)),
_ => { _ => {
Err(io::Error::new(io::ErrorKind::InvalidInput, Err(io::Error::new(io::ErrorKind::InvalidInput,
"Invalid chunk size linear white space")) "Invalid chunk size linear white space"))
} }
} }
} }
fn read_extension<R: MemRead>(rdr: &mut R) -> io::Result<ChunkedState> { fn read_extension<R: MemRead>(rdr: &mut R) -> Poll<ChunkedState, io::Error> {
trace!("read_extension"); trace!("read_extension");
match byte!(rdr) { match byte!(rdr) {
b'\r' => Ok(ChunkedState::SizeLf), b'\r' => Ok(Async::Ready(ChunkedState::SizeLf)),
_ => Ok(ChunkedState::Extension), // no supported extensions _ => Ok(Async::Ready(ChunkedState::Extension)), // no supported extensions
} }
} }
fn read_size_lf<R: MemRead>(rdr: &mut R, size: &mut u64) -> io::Result<ChunkedState> { fn read_size_lf<R: MemRead>(rdr: &mut R, size: &mut u64) -> Poll<ChunkedState, io::Error> {
trace!("Chunk size is {:?}", size); trace!("Chunk size is {:?}", size);
match byte!(rdr) { match byte!(rdr) {
b'\n' if *size > 0 => Ok(ChunkedState::Body), b'\n' if *size > 0 => Ok(Async::Ready(ChunkedState::Body)),
b'\n' if *size == 0 => Ok(ChunkedState::EndCr), b'\n' if *size == 0 => Ok(Async::Ready(ChunkedState::EndCr)),
_ => Err(io::Error::new(io::ErrorKind::InvalidInput, "Invalid chunk size LF")), _ => Err(io::Error::new(io::ErrorKind::InvalidInput, "Invalid chunk size LF")),
} }
} }
@@ -224,7 +221,7 @@ impl ChunkedState {
fn read_body<R: MemRead>(rdr: &mut R, fn read_body<R: MemRead>(rdr: &mut R,
rem: &mut u64, rem: &mut u64,
buf: &mut Option<Bytes>) buf: &mut Option<Bytes>)
-> io::Result<ChunkedState> { -> Poll<ChunkedState, io::Error> {
trace!("Chunked read, remaining={:?}", rem); trace!("Chunked read, remaining={:?}", rem);
// cap remaining bytes at the max capacity of usize // cap remaining bytes at the max capacity of usize
@@ -234,7 +231,7 @@ impl ChunkedState {
}; };
let to_read = rem_cap; let to_read = rem_cap;
let slice = try!(rdr.read_mem(to_read)); let slice = try_ready!(rdr.read_mem(to_read));
let count = slice.len(); let count = slice.len();
if count == 0 { if count == 0 {
@@ -245,33 +242,33 @@ impl ChunkedState {
*rem -= count as u64; *rem -= count as u64;
if *rem > 0 { if *rem > 0 {
Ok(ChunkedState::Body) Ok(Async::Ready(ChunkedState::Body))
} else { } else {
Ok(ChunkedState::BodyCr) Ok(Async::Ready(ChunkedState::BodyCr))
} }
} }
fn read_body_cr<R: MemRead>(rdr: &mut R) -> io::Result<ChunkedState> { fn read_body_cr<R: MemRead>(rdr: &mut R) -> Poll<ChunkedState, io::Error> {
match byte!(rdr) { match byte!(rdr) {
b'\r' => Ok(ChunkedState::BodyLf), b'\r' => Ok(Async::Ready(ChunkedState::BodyLf)),
_ => Err(io::Error::new(io::ErrorKind::InvalidInput, "Invalid chunk body CR")), _ => Err(io::Error::new(io::ErrorKind::InvalidInput, "Invalid chunk body CR")),
} }
} }
fn read_body_lf<R: MemRead>(rdr: &mut R) -> io::Result<ChunkedState> { fn read_body_lf<R: MemRead>(rdr: &mut R) -> Poll<ChunkedState, io::Error> {
match byte!(rdr) { match byte!(rdr) {
b'\n' => Ok(ChunkedState::Size), b'\n' => Ok(Async::Ready(ChunkedState::Size)),
_ => Err(io::Error::new(io::ErrorKind::InvalidInput, "Invalid chunk body LF")), _ => Err(io::Error::new(io::ErrorKind::InvalidInput, "Invalid chunk body LF")),
} }
} }
fn read_end_cr<R: MemRead>(rdr: &mut R) -> io::Result<ChunkedState> { fn read_end_cr<R: MemRead>(rdr: &mut R) -> Poll<ChunkedState, io::Error> {
match byte!(rdr) { match byte!(rdr) {
b'\r' => Ok(ChunkedState::EndLf), b'\r' => Ok(Async::Ready(ChunkedState::EndLf)),
_ => Err(io::Error::new(io::ErrorKind::InvalidInput, "Invalid chunk end CR")), _ => Err(io::Error::new(io::ErrorKind::InvalidInput, "Invalid chunk end CR")),
} }
} }
fn read_end_lf<R: MemRead>(rdr: &mut R) -> io::Result<ChunkedState> { fn read_end_lf<R: MemRead>(rdr: &mut R) -> Poll<ChunkedState, io::Error> {
match byte!(rdr) { match byte!(rdr) {
b'\n' => Ok(ChunkedState::End), b'\n' => Ok(Async::Ready(ChunkedState::End)),
_ => Err(io::Error::new(io::ErrorKind::InvalidInput, "Invalid chunk end LF")), _ => Err(io::Error::new(io::ErrorKind::InvalidInput, "Invalid chunk end LF")),
} }
} }
@@ -285,19 +282,40 @@ mod tests {
use super::Decoder; use super::Decoder;
use super::ChunkedState; use super::ChunkedState;
use http::io::MemRead; use http::io::MemRead;
use futures::{Async, Poll};
use bytes::{BytesMut, Bytes}; use bytes::{BytesMut, Bytes};
use mock::AsyncIo; use mock::AsyncIo;
impl<'a> MemRead for &'a [u8] { impl<'a> MemRead for &'a [u8] {
fn read_mem(&mut self, len: usize) -> io::Result<Bytes> { fn read_mem(&mut self, len: usize) -> Poll<Bytes, io::Error> {
let n = ::std::cmp::min(len, self.len()); let n = ::std::cmp::min(len, self.len());
if n > 0 { if n > 0 {
let (a, b) = self.split_at(n); let (a, b) = self.split_at(n);
let mut buf = BytesMut::from(a); let mut buf = BytesMut::from(a);
*self = b; *self = b;
Ok(buf.split_to(n).freeze()) Ok(Async::Ready(buf.split_to(n).freeze()))
} else { } else {
Ok(Bytes::new()) Ok(Async::Ready(Bytes::new()))
}
}
}
trait HelpUnwrap<T> {
fn unwrap(self) -> T;
}
impl HelpUnwrap<Bytes> for Async<Bytes> {
fn unwrap(self) -> Bytes {
match self {
Async::Ready(bytes) => bytes,
Async::NotReady => panic!(),
}
}
}
impl HelpUnwrap<ChunkedState> for Async<ChunkedState> {
fn unwrap(self) -> ChunkedState {
match self {
Async::Ready(state) => state,
Async::NotReady => panic!(),
} }
} }
} }
@@ -313,7 +331,7 @@ mod tests {
loop { loop {
let result = state.step(rdr, &mut size, &mut None); let result = state.step(rdr, &mut size, &mut None);
let desc = format!("read_size failed for {:?}", s); let desc = format!("read_size failed for {:?}", s);
state = result.expect(desc.as_str()); state = result.expect(desc.as_str()).unwrap();
if state == ChunkedState::Body || state == ChunkedState::EndCr { if state == ChunkedState::Body || state == ChunkedState::EndCr {
break; break;
} }
@@ -328,7 +346,7 @@ mod tests {
loop { loop {
let result = state.step(rdr, &mut size, &mut None); let result = state.step(rdr, &mut size, &mut None);
state = match result { state = match result {
Ok(s) => s, Ok(s) => s.unwrap(),
Err(e) => { Err(e) => {
assert!(expected_err == e.kind(), "Reading {:?}, expected {:?}, but got {:?}", assert!(expected_err == e.kind(), "Reading {:?}, expected {:?}, but got {:?}",
s, expected_err, e.kind()); s, expected_err, e.kind());
@@ -376,7 +394,7 @@ mod tests {
fn test_read_sized_early_eof() { fn test_read_sized_early_eof() {
let mut bytes = &b"foo bar"[..]; let mut bytes = &b"foo bar"[..];
let mut decoder = Decoder::length(10); let mut decoder = Decoder::length(10);
assert_eq!(decoder.decode(&mut bytes).unwrap().len(), 7); assert_eq!(decoder.decode(&mut bytes).unwrap().unwrap().len(), 7);
let e = decoder.decode(&mut bytes).unwrap_err(); let e = decoder.decode(&mut bytes).unwrap_err();
assert_eq!(e.kind(), io::ErrorKind::Other); assert_eq!(e.kind(), io::ErrorKind::Other);
assert_eq!(e.description(), "early eof"); assert_eq!(e.description(), "early eof");
@@ -389,7 +407,7 @@ mod tests {
foo bar\ foo bar\
"[..]; "[..];
let mut decoder = Decoder::chunked(); let mut decoder = Decoder::chunked();
assert_eq!(decoder.decode(&mut bytes).unwrap().len(), 7); assert_eq!(decoder.decode(&mut bytes).unwrap().unwrap().len(), 7);
let e = decoder.decode(&mut bytes).unwrap_err(); let e = decoder.decode(&mut bytes).unwrap_err();
assert_eq!(e.kind(), io::ErrorKind::UnexpectedEof); assert_eq!(e.kind(), io::ErrorKind::UnexpectedEof);
assert_eq!(e.description(), "early eof"); assert_eq!(e.description(), "early eof");
@@ -398,7 +416,7 @@ mod tests {
#[test] #[test]
fn test_read_chunked_single_read() { fn test_read_chunked_single_read() {
let mut mock_buf = &b"10\r\n1234567890abcdef\r\n0\r\n"[..]; let mut mock_buf = &b"10\r\n1234567890abcdef\r\n0\r\n"[..];
let buf = Decoder::chunked().decode(&mut mock_buf).expect("decode"); let buf = Decoder::chunked().decode(&mut mock_buf).expect("decode").unwrap();
assert_eq!(16, buf.len()); assert_eq!(16, buf.len());
let result = String::from_utf8(buf.as_ref().to_vec()).expect("decode String"); let result = String::from_utf8(buf.as_ref().to_vec()).expect("decode String");
assert_eq!("1234567890abcdef", &result); assert_eq!("1234567890abcdef", &result);
@@ -410,17 +428,17 @@ mod tests {
let mut decoder = Decoder::chunked(); let mut decoder = Decoder::chunked();
// normal read // normal read
let buf = decoder.decode(&mut mock_buf).expect("decode"); let buf = decoder.decode(&mut mock_buf).expect("decode").unwrap();
assert_eq!(16, buf.len()); assert_eq!(16, buf.len());
let result = String::from_utf8(buf.as_ref().to_vec()).expect("decode String"); let result = String::from_utf8(buf.as_ref().to_vec()).expect("decode String");
assert_eq!("1234567890abcdef", &result); assert_eq!("1234567890abcdef", &result);
// eof read // eof read
let buf = decoder.decode(&mut mock_buf).expect("decode"); let buf = decoder.decode(&mut mock_buf).expect("decode").unwrap();
assert_eq!(0, buf.len()); assert_eq!(0, buf.len());
// ensure read after eof also returns eof // ensure read after eof also returns eof
let buf = decoder.decode(&mut mock_buf).expect("decode"); let buf = decoder.decode(&mut mock_buf).expect("decode").unwrap();
assert_eq!(0, buf.len()); assert_eq!(0, buf.len());
} }
@@ -434,18 +452,15 @@ mod tests {
let mut ins = AsyncIo::new(content, block_at); let mut ins = AsyncIo::new(content, block_at);
let mut outs = Vec::new(); let mut outs = Vec::new();
loop { loop {
match decoder.decode(&mut ins) { match decoder.decode(&mut ins).expect("unexpected decode error: {}") {
Ok(buf) => { Async::Ready(buf) => {
if buf.is_empty() { if buf.is_empty() {
break; // eof break; // eof
} }
outs.write(buf.as_ref()).expect("write buffer"); outs.write(buf.as_ref()).expect("write buffer");
} },
Err(e) => match e.kind() { Async::NotReady => {
io::ErrorKind::WouldBlock => { ins.block_in(content_len); // we only block once
ins.block_in(content_len); // we only block once
},
_ => panic!("unexpected decode error: {}", e),
} }
}; };
} }

View File

@@ -3,6 +3,7 @@ use std::fmt;
use std::io::{self, Write}; use std::io::{self, Write};
use std::ptr; use std::ptr;
use futures::{Async, Poll};
use tokio_io::{AsyncRead, AsyncWrite}; use tokio_io::{AsyncRead, AsyncWrite};
use http::{Http1Transaction, MessageHead, DebugTruncate}; use http::{Http1Transaction, MessageHead, DebugTruncate};
@@ -147,19 +148,19 @@ impl<T: Write> Write for Buffered<T> {
} }
pub trait MemRead { pub trait MemRead {
fn read_mem(&mut self, len: usize) -> io::Result<Bytes>; fn read_mem(&mut self, len: usize) -> Poll<Bytes, io::Error>;
} }
impl<T: AsyncRead + AsyncWrite> MemRead for Buffered<T> { impl<T: AsyncRead + AsyncWrite> MemRead for Buffered<T> {
fn read_mem(&mut self, len: usize) -> io::Result<Bytes> { fn read_mem(&mut self, len: usize) -> Poll<Bytes, io::Error> {
trace!("Buffered.read_mem read_buf={}, wanted={}", self.read_buf.len(), len); trace!("Buffered.read_mem read_buf={}, wanted={}", self.read_buf.len(), len);
if !self.read_buf.is_empty() { if !self.read_buf.is_empty() {
let n = ::std::cmp::min(len, self.read_buf.len()); let n = ::std::cmp::min(len, self.read_buf.len());
trace!("Buffered.read_mem read_buf is not empty, slicing {}", n); trace!("Buffered.read_mem read_buf is not empty, slicing {}", n);
Ok(self.read_buf.split_to(n).freeze()) Ok(Async::Ready(self.read_buf.split_to(n).freeze()))
} else { } else {
let n = try!(self.read_from_io()); let n = try_nb!(self.read_from_io());
Ok(self.read_buf.split_to(::std::cmp::min(len, n)).freeze()) Ok(Async::Ready(self.read_buf.split_to(::std::cmp::min(len, n)).freeze()))
} }
} }
} }
@@ -327,10 +328,10 @@ use std::io::Read;
#[cfg(test)] #[cfg(test)]
impl<T: Read> MemRead for ::mock::AsyncIo<T> { impl<T: Read> MemRead for ::mock::AsyncIo<T> {
fn read_mem(&mut self, len: usize) -> io::Result<Bytes> { fn read_mem(&mut self, len: usize) -> Poll<Bytes, io::Error> {
let mut v = vec![0; len]; let mut v = vec![0; len];
let n = try!(self.read(v.as_mut_slice())); let n = try_nb!(self.read(v.as_mut_slice()));
Ok(BytesMut::from(&v[..n]).freeze()) Ok(Async::Ready(BytesMut::from(&v[..n]).freeze()))
} }
} }