diff --git a/src/proto/h1/conn.rs b/src/proto/h1/conn.rs index 4ace6726..94ce2254 100644 --- a/src/proto/h1/conn.rs +++ b/src/proto/h1/conn.rs @@ -637,12 +637,6 @@ where I: AsyncRead + AsyncWrite + Unpin, trace!("{}: prepare possible HTTP upgrade", T::LOG); self.state.prepare_upgrade() } - - // Used in h1::dispatch tests - #[cfg(test)] - pub(super) fn io_mut(&mut self) -> &mut I { - self.io.io_mut() - } } impl fmt::Debug for Conn { diff --git a/src/proto/h1/decode.rs b/src/proto/h1/decode.rs index 3c1247e2..1215e8e6 100644 --- a/src/proto/h1/decode.rs +++ b/src/proto/h1/decode.rs @@ -143,6 +143,13 @@ impl Decoder { } } } + + #[cfg(test)] + async fn decode_fut(&mut self, body: &mut R) -> Result { + futures_util::future::poll_fn(move |cx| { + self.decode(cx, body) + }).await + } } @@ -319,65 +326,58 @@ impl StdError for IncompleteBody { #[cfg(test)] mod tests { - // FIXME: re-implement tests with `async/await`, this import should - // trigger a warning to remind us - use crate::Error; - /* - use std::io; - use std::io::Write; - use super::Decoder; - use super::ChunkedState; - use super::super::io::MemRead; - use futures::{Async, Poll}; - use bytes::{BytesMut, Bytes}; - use crate::mock::AsyncIo; + use std::time::Duration; + use std::pin::Pin; + use tokio_io::AsyncRead; + use super::*; impl<'a> MemRead for &'a [u8] { - fn read_mem(&mut self, len: usize) -> Poll> { + fn read_mem(&mut self, _: &mut task::Context<'_>, len: usize) -> Poll> { let n = ::std::cmp::min(len, self.len()); if n > 0 { let (a, b) = self.split_at(n); - let mut buf = BytesMut::from(a); + let buf = Bytes::from(a); *self = b; - Poll::Ready(Ok(buf.split_to(n).freeze())) + Poll::Ready(Ok(buf)) } else { Poll::Ready(Ok(Bytes::new())) } } } - trait HelpUnwrap { - fn unwrap(self) -> T; - } - impl HelpUnwrap for Async { - fn unwrap(self) -> Bytes { - match self { - Async::Ready(bytes) => bytes, - Async::NotReady => panic!(), - } - } - } - impl HelpUnwrap for Async { - fn unwrap(self) -> ChunkedState { - match self { - Async::Ready(state) => state, - Async::NotReady => panic!(), - } + impl<'a> MemRead for &'a mut (dyn AsyncRead + Unpin) { + fn read_mem(&mut self, cx: &mut task::Context<'_>, len: usize) -> Poll> { + let mut v = vec![0; len]; + let n = ready!(Pin::new(self).poll_read(cx, &mut v)?); + Poll::Ready(Ok(Bytes::from(&v[..n]))) } } - #[test] - fn test_read_chunk_size() { + /* + use std::io; + use std::io::Write; + use super::Decoder; + use super::ChunkedState; + use futures::{Async, Poll}; + use bytes::{BytesMut, Bytes}; + use crate::mock::AsyncIo; + */ + + + #[tokio::test] + async fn test_read_chunk_size() { use std::io::ErrorKind::{UnexpectedEof, InvalidInput}; - fn read(s: &str) -> u64 { + async fn read(s: &str) -> u64 { let mut state = ChunkedState::Size; let rdr = &mut s.as_bytes(); let mut size = 0; loop { - let result = state.step(rdr, &mut size, &mut None); + let result = futures_util::future::poll_fn(|cx| { + state.step(cx, rdr, &mut size, &mut None) + }).await; let desc = format!("read_size failed for {:?}", s); - state = result.expect(desc.as_str()).unwrap(); + state = result.expect(desc.as_str()); if state == ChunkedState::Body || state == ChunkedState::EndCr { break; } @@ -385,14 +385,16 @@ mod tests { size } - fn read_err(s: &str, expected_err: io::ErrorKind) { + async fn read_err(s: &str, expected_err: io::ErrorKind) { let mut state = ChunkedState::Size; let rdr = &mut s.as_bytes(); let mut size = 0; loop { - let result = state.step(rdr, &mut size, &mut None); + let result = futures_util::future::poll_fn(|cx| { + state.step(cx, rdr, &mut size, &mut None) + }).await; state = match result { - Ok(s) => s.unwrap(), + Ok(s) => s, Err(e) => { assert!(expected_err == e.kind(), "Reading {:?}, expected {:?}, but got {:?}", s, expected_err, e.kind()); @@ -405,139 +407,150 @@ mod tests { } } - assert_eq!(1, read("1\r\n")); - assert_eq!(1, read("01\r\n")); - assert_eq!(0, read("0\r\n")); - assert_eq!(0, read("00\r\n")); - assert_eq!(10, read("A\r\n")); - assert_eq!(10, read("a\r\n")); - assert_eq!(255, read("Ff\r\n")); - assert_eq!(255, read("Ff \r\n")); + assert_eq!(1, read("1\r\n").await); + assert_eq!(1, read("01\r\n").await); + assert_eq!(0, read("0\r\n").await); + assert_eq!(0, read("00\r\n").await); + assert_eq!(10, read("A\r\n").await); + assert_eq!(10, read("a\r\n").await); + assert_eq!(255, read("Ff\r\n").await); + assert_eq!(255, read("Ff \r\n").await); // Missing LF or CRLF - read_err("F\rF", InvalidInput); - read_err("F", UnexpectedEof); + read_err("F\rF", InvalidInput).await; + read_err("F", UnexpectedEof).await; // Invalid hex digit - read_err("X\r\n", InvalidInput); - read_err("1X\r\n", InvalidInput); - read_err("-\r\n", InvalidInput); - read_err("-1\r\n", InvalidInput); + read_err("X\r\n", InvalidInput).await; + read_err("1X\r\n", InvalidInput).await; + read_err("-\r\n", InvalidInput).await; + read_err("-1\r\n", InvalidInput).await; // Acceptable (if not fully valid) extensions do not influence the size - assert_eq!(1, read("1;extension\r\n")); - assert_eq!(10, read("a;ext name=value\r\n")); - assert_eq!(1, read("1;extension;extension2\r\n")); - assert_eq!(1, read("1;;; ;\r\n")); - assert_eq!(2, read("2; extension...\r\n")); - assert_eq!(3, read("3 ; extension=123\r\n")); - assert_eq!(3, read("3 ;\r\n")); - assert_eq!(3, read("3 ; \r\n")); + assert_eq!(1, read("1;extension\r\n").await); + assert_eq!(10, read("a;ext name=value\r\n").await); + assert_eq!(1, read("1;extension;extension2\r\n").await); + assert_eq!(1, read("1;;; ;\r\n").await); + assert_eq!(2, read("2; extension...\r\n").await); + assert_eq!(3, read("3 ; extension=123\r\n").await); + assert_eq!(3, read("3 ;\r\n").await); + assert_eq!(3, read("3 ; \r\n").await); // Invalid extensions cause an error - read_err("1 invalid extension\r\n", InvalidInput); - read_err("1 A\r\n", InvalidInput); - read_err("1;no CRLF", UnexpectedEof); + read_err("1 invalid extension\r\n", InvalidInput).await; + read_err("1 A\r\n", InvalidInput).await; + read_err("1;no CRLF", UnexpectedEof).await; } - #[test] - fn test_read_sized_early_eof() { + #[tokio::test] + async fn test_read_sized_early_eof() { let mut bytes = &b"foo bar"[..]; let mut decoder = Decoder::length(10); - assert_eq!(decoder.decode(&mut bytes).unwrap().unwrap().len(), 7); - let e = decoder.decode(&mut bytes).unwrap_err(); + assert_eq!(decoder.decode_fut(&mut bytes).await.unwrap().len(), 7); + let e = decoder.decode_fut(&mut bytes).await.unwrap_err(); assert_eq!(e.kind(), io::ErrorKind::UnexpectedEof); } - #[test] - fn test_read_chunked_early_eof() { + #[tokio::test] + async fn test_read_chunked_early_eof() { let mut bytes = &b"\ 9\r\n\ foo bar\ "[..]; let mut decoder = Decoder::chunked(); - assert_eq!(decoder.decode(&mut bytes).unwrap().unwrap().len(), 7); - let e = decoder.decode(&mut bytes).unwrap_err(); + assert_eq!(decoder.decode_fut(&mut bytes).await.unwrap().len(), 7); + let e = decoder.decode_fut(&mut bytes).await.unwrap_err(); assert_eq!(e.kind(), io::ErrorKind::UnexpectedEof); } - #[test] - fn test_read_chunked_single_read() { + #[tokio::test] + async fn test_read_chunked_single_read() { let mut mock_buf = &b"10\r\n1234567890abcdef\r\n0\r\n"[..]; - let buf = Decoder::chunked().decode(&mut mock_buf).expect("decode").unwrap(); + let buf = Decoder::chunked().decode_fut(&mut mock_buf).await.expect("decode"); assert_eq!(16, buf.len()); let result = String::from_utf8(buf.as_ref().to_vec()).expect("decode String"); assert_eq!("1234567890abcdef", &result); } - #[test] - fn test_read_chunked_after_eof() { + #[tokio::test] + async fn test_read_chunked_after_eof() { let mut mock_buf = &b"10\r\n1234567890abcdef\r\n0\r\n\r\n"[..]; let mut decoder = Decoder::chunked(); // normal read - let buf = decoder.decode(&mut mock_buf).expect("decode").unwrap(); + let buf = decoder.decode_fut(&mut mock_buf).await.unwrap(); assert_eq!(16, buf.len()); let result = String::from_utf8(buf.as_ref().to_vec()).expect("decode String"); assert_eq!("1234567890abcdef", &result); // eof read - let buf = decoder.decode(&mut mock_buf).expect("decode").unwrap(); + let buf = decoder.decode_fut(&mut mock_buf).await.expect("decode"); assert_eq!(0, buf.len()); // ensure read after eof also returns eof - let buf = decoder.decode(&mut mock_buf).expect("decode").unwrap(); + let buf = decoder.decode_fut(&mut mock_buf).await.expect("decode"); assert_eq!(0, buf.len()); } // perform an async read using a custom buffer size and causing a blocking // read at the specified byte - fn read_async(mut decoder: Decoder, + async fn read_async(mut decoder: Decoder, content: &[u8], block_at: usize) -> String { - let content_len = content.len(); - let mut ins = AsyncIo::new(content, block_at); let mut outs = Vec::new(); + + let mut ins = if block_at == 0 { + tokio_test::io::Builder::new() + .wait(Duration::from_millis(10)) + .read(content) + .build() + } else { + tokio_test::io::Builder::new() + .read(&content[..block_at]) + .wait(Duration::from_millis(10)) + .read(&content[block_at..]) + .build() + }; + + let mut ins = &mut ins as &mut (dyn AsyncRead + Unpin); + loop { - match decoder.decode(&mut ins).expect("unexpected decode error: {}") { - Async::Ready(buf) => { - if buf.is_empty() { - break; // eof - } - outs.write(buf.as_ref()).expect("write buffer"); - }, - Async::NotReady => { - ins.block_in(content_len); // we only block once - } - }; + let buf = decoder + .decode_fut(&mut ins) + .await + .expect("unexpected decode error"); + if buf.is_empty() { + break; // eof + } + outs.extend(buf.as_ref()); } + String::from_utf8(outs).expect("decode String") } // iterate over the different ways that this async read could go. // tests blocking a read at each byte along the content - The shotgun approach - fn all_async_cases(content: &str, expected: &str, decoder: Decoder) { + async fn all_async_cases(content: &str, expected: &str, decoder: Decoder) { let content_len = content.len(); for block_at in 0..content_len { - let actual = read_async(decoder.clone(), content.as_bytes(), block_at); + let actual = read_async(decoder.clone(), content.as_bytes(), block_at).await; assert_eq!(expected, &actual) //, "Failed async. Blocking at {}", block_at); } } - #[test] - fn test_read_length_async() { + #[tokio::test] + async fn test_read_length_async() { let content = "foobar"; - all_async_cases(content, content, Decoder::length(content.len() as u64)); + all_async_cases(content, content, Decoder::length(content.len() as u64)).await; } - #[test] - fn test_read_chunked_async() { + #[tokio::test] + async fn test_read_chunked_async() { let content = "3\r\nfoo\r\n3\r\nbar\r\n0\r\n\r\n"; let expected = "foobar"; - all_async_cases(content, expected, Decoder::chunked()); + all_async_cases(content, expected, Decoder::chunked()).await; } - #[test] - fn test_read_eof_async() { + #[tokio::test] + async fn test_read_eof_async() { let content = "foobar"; - all_async_cases(content, content, Decoder::eof()); + all_async_cases(content, content, Decoder::eof()).await; } - */ }