diff --git a/src/mock.rs b/src/mock.rs index e44351ee..30bd4a6e 100644 --- a/src/mock.rs +++ b/src/mock.rs @@ -1,7 +1,8 @@ use std::cmp; use std::io::{self, Read, Write}; -use futures::Poll; +use bytes::Buf as BufTrait; +use futures::{Async, Poll}; use tokio_io::{AsyncRead, AsyncWrite}; #[derive(Debug)] @@ -11,10 +12,6 @@ pub struct Buf { } impl Buf { - pub fn new() -> Buf { - Buf::wrap(vec![]) - } - pub fn wrap(vec: Vec) -> Buf { Buf { vec: vec, @@ -63,23 +60,27 @@ impl Read for Buf { } } +const READ_VECS_CNT: usize = 64; + #[derive(Debug)] pub struct AsyncIo { - inner: T, + blocked: bool, bytes_until_block: usize, error: Option, - blocked: bool, flushed: bool, + inner: T, + max_read_vecs: usize, } impl AsyncIo { pub fn new(inner: T, bytes: usize) -> AsyncIo { AsyncIo { - inner: inner, + blocked: false, bytes_until_block: bytes, error: None, flushed: false, - blocked: false, + inner: inner, + max_read_vecs: READ_VECS_CNT, } } @@ -90,6 +91,22 @@ impl AsyncIo { pub fn error(&mut self, err: io::Error) { self.error = Some(err); } + + pub fn max_read_vecs(&mut self, cnt: usize) { + assert!(cnt <= READ_VECS_CNT); + self.max_read_vecs = cnt; + } + + #[cfg(feature = "tokio-proto")] + //TODO: fix proto::conn::tests to not use tokio-proto API, + //and then this cfg flag go away + pub fn flushed(&self) -> bool { + self.flushed + } + + pub fn blocked(&self) -> bool { + self.blocked + } } impl AsyncIo { @@ -103,16 +120,17 @@ impl AsyncIo { pub fn new_eof() -> AsyncIo { AsyncIo::new(Buf::wrap(Vec::new().into()), 1) } +} - #[cfg(feature = "tokio-proto")] - //TODO: fix proto::conn::tests to not use tokio-proto API, - //and then this cfg flag go away - pub fn flushed(&self) -> bool { - self.flushed - } +impl AsyncIo { + fn write_no_vecs(&mut self, buf: &mut B) -> Poll { + if !buf.has_remaining() { + return Ok(Async::Ready(0)); + } - pub fn blocked(&self) -> bool { - self.blocked + let n = try_nb!(self.write(buf.bytes())); + buf.advance(n); + Ok(Async::Ready(n)) } } @@ -170,12 +188,14 @@ impl AsyncWrite for AsyncIo { Ok(().into()) } - fn write_buf(&mut self, buf: &mut B) -> Poll { - use futures::Async; + fn write_buf(&mut self, buf: &mut B) -> Poll { + if self.max_read_vecs == 0 { + return self.write_no_vecs(buf); + } let r = { static DUMMY: &[u8] = &[0]; - let mut bufs = [From::from(DUMMY); 64]; - let i = ::bytes::Buf::bytes_vec(&buf, &mut bufs); + let mut bufs = [From::from(DUMMY); READ_VECS_CNT]; + let i = ::bytes::Buf::bytes_vec(&buf, &mut bufs[..self.max_read_vecs]); let mut n = 0; let mut ret = Ok(0); for iovec in &bufs[..i] { diff --git a/src/proto/h1/io.rs b/src/proto/h1/io.rs index 55ef12ce..0b1b6dba 100644 --- a/src/proto/h1/io.rs +++ b/src/proto/h1/io.rs @@ -440,6 +440,9 @@ impl Buf for BufDeque { return buf.bytes(); } } + if let Some(ref buf) = self.bufs.front() { + return buf.bytes(); + } &[] } @@ -483,42 +486,54 @@ impl Buf for BufDeque { } } -// TODO: Move tests to their own mod #[cfg(test)] -use std::io::Read; +mod tests { + use super::*; + use std::io::Read; + use mock::AsyncIo; -#[cfg(test)] -impl MemRead for ::mock::AsyncIo { - fn read_mem(&mut self, len: usize) -> Poll { - let mut v = vec![0; len]; - let n = try_nb!(self.read(v.as_mut_slice())); - Ok(Async::Ready(BytesMut::from(&v[..n]).freeze())) + #[cfg(test)] + impl MemRead for ::mock::AsyncIo { + fn read_mem(&mut self, len: usize) -> Poll { + let mut v = vec![0; len]; + let n = try_nb!(self.read(v.as_mut_slice())); + Ok(Async::Ready(BytesMut::from(&v[..n]).freeze())) + } + } + + #[test] + fn iobuf_write_empty_slice() { + let mut mock = AsyncIo::new_buf(vec![], 256); + mock.error(io::Error::new(io::ErrorKind::Other, "logic error")); + + let mut io_buf = Buffered::<_, Cursor>>::new(mock); + + // underlying io will return the logic error upon write, + // so we are testing that the io_buf does not trigger a write + // when there is nothing to flush + io_buf.flush().expect("should short-circuit flush"); + } + + #[test] + fn parse_reads_until_blocked() { + // missing last line ending + let raw = "HTTP/1.1 200 OK\r\n"; + + let mock = AsyncIo::new_buf(raw, raw.len()); + let mut buffered = Buffered::<_, Cursor>>::new(mock); + assert_eq!(buffered.parse::<::proto::ClientTransaction>().unwrap(), Async::NotReady); + assert!(buffered.io.blocked()); + } + + #[test] + fn write_buf_skips_empty_bufs() { + let mut mock = AsyncIo::new_buf(vec![], 1024); + mock.max_read_vecs(0); // disable vectored IO + let mut buffered = Buffered::<_, Cursor>>::new(mock); + + buffered.buffer(Cursor::new(Vec::new())); + buffered.buffer(Cursor::new(b"hello".to_vec())); + buffered.flush().unwrap(); + assert_eq!(buffered.io, b"hello"); } } - -#[test] -fn test_iobuf_write_empty_slice() { - use mock::{AsyncIo, Buf as MockBuf}; - - let mut mock = AsyncIo::new(MockBuf::new(), 256); - mock.error(io::Error::new(io::ErrorKind::Other, "logic error")); - - let mut io_buf = Buffered::<_, Cursor>>::new(mock); - - // underlying io will return the logic error upon write, - // so we are testing that the io_buf does not trigger a write - // when there is nothing to flush - io_buf.flush().expect("should short-circuit flush"); -} - -#[test] -fn test_parse_reads_until_blocked() { - use mock::{AsyncIo, Buf as MockBuf}; - // missing last line ending - let raw = "HTTP/1.1 200 OK\r\n"; - - let mock = AsyncIo::new(MockBuf::wrap(raw.into()), raw.len()); - let mut buffered = Buffered::<_, Cursor>>::new(mock); - assert_eq!(buffered.parse::<::proto::ClientTransaction>().unwrap(), Async::NotReady); - assert!(buffered.io.blocked()); -} diff --git a/tests/client.rs b/tests/client.rs index 8183a1d2..523f1bcf 100644 --- a/tests/client.rs +++ b/tests/client.rs @@ -1044,45 +1044,6 @@ mod dispatch_impl { assert_eq!(closes.load(Ordering::Relaxed), 1); } - #[test] - fn client_body_mpsc() { - use futures::Sink; - let _ = pretty_env_logger::try_init(); - let server = TcpListener::bind("127.0.0.1:0").unwrap(); - let addr = server.local_addr().unwrap(); - let mut core = Core::new().unwrap(); - let handle = core.handle(); - let closes = Arc::new(AtomicUsize::new(0)); - - let (tx1, rx1) = oneshot::channel(); - - thread::spawn(move || { - let mut sock = server.accept().unwrap().0; - sock.set_read_timeout(Some(Duration::from_secs(5))).unwrap(); - sock.set_write_timeout(Some(Duration::from_secs(5))).unwrap(); - let mut buf = [0; 4096]; - sock.read(&mut buf).expect("read 1"); - sock.write_all(b"HTTP/1.1 200 OK\r\nContent-Length: 0\r\n\r\n").unwrap(); - let _ = tx1.send(()); - }); - - let uri = format!("http://{}/a", addr).parse().unwrap(); - - let client = Client::configure() - .connector(DebugConnector(HttpConnector::new(1, &handle), closes.clone())) - .build(&handle); - let mut req = Request::new(Method::Post, uri); - let (tx, body) = hyper::Body::pair(); - req.set_body(body); - let res = client.request(req).and_then(move |res| { - assert_eq!(res.status(), hyper::StatusCode::Ok); - res.body().concat2() - }); - let rx = rx1.map_err(|_| hyper::Error::Io(io::Error::new(io::ErrorKind::Other, "thread panicked"))); - let send = tx.send_all(::futures::stream::iter_ok(vec!["hello"; 2]).map(hyper::Chunk::from).map(Ok)).then(|_| Ok(())); - core.run(res.join(send).join(rx).map(|r| r.0)).unwrap(); - } - struct DebugConnector(HttpConnector, Arc); impl Service for DebugConnector {