diff --git a/src/proto/h1/conn.rs b/src/proto/h1/conn.rs index edb55ab7..a8fa7e0e 100644 --- a/src/proto/h1/conn.rs +++ b/src/proto/h1/conn.rs @@ -239,7 +239,9 @@ where I: AsyncRead + AsyncWrite + Unpin, pub fn poll_read_keep_alive(&mut self, cx: &mut task::Context<'_>) -> Poll> { debug_assert!(!self.can_read_head() && !self.can_read_body()); - if self.is_mid_message() { + if self.is_read_closed() { + Poll::Pending + } else if self.is_mid_message() { self.mid_message_detect_eof(cx) } else { self.require_empty_read(cx) @@ -258,7 +260,7 @@ where I: AsyncRead + AsyncWrite + Unpin, // This should only be called for Clients wanting to enter the idle // state. fn require_empty_read(&mut self, cx: &mut task::Context<'_>) -> Poll> { - debug_assert!(!self.can_read_head() && !self.can_read_body()); + debug_assert!(!self.can_read_head() && !self.can_read_body() && !self.is_read_closed()); debug_assert!(!self.is_mid_message()); debug_assert!(T::is_client()); @@ -288,17 +290,13 @@ where I: AsyncRead + AsyncWrite + Unpin, } fn mid_message_detect_eof(&mut self, cx: &mut task::Context<'_>) -> Poll> { - debug_assert!(!self.can_read_head() && !self.can_read_body()); + debug_assert!(!self.can_read_head() && !self.can_read_body() && !self.is_read_closed()); debug_assert!(self.is_mid_message()); if self.state.allow_half_close || !self.io.read_buf().is_empty() { return Poll::Pending; } - if self.state.is_read_closed() { - return Poll::Ready(Err(crate::Error::new_incomplete())); - } - let num_read = ready!(self.force_io_read(cx)).map_err(crate::Error::new_io)?; if num_read == 0 { @@ -347,7 +345,17 @@ where I: AsyncRead + AsyncWrite + Unpin, if !self.io.is_read_blocked() { if self.io.read_buf().is_empty() { match self.io.poll_read_from_io(cx) { - Poll::Ready(Ok(_)) => (), + Poll::Ready(Ok(n)) => { + if n == 0 { + trace!("maybe_notify; read eof"); + if self.state.is_idle() { + self.state.close(); + } else { + self.close_read() + } + return; + } + }, Poll::Pending => { trace!("maybe_notify; read_from_io blocked"); return @@ -355,6 +363,7 @@ where I: AsyncRead + AsyncWrite + Unpin, Poll::Ready(Err(e)) => { trace!("maybe_notify; read_from_io error: {}", e); self.state.close(); + self.state.error = Some(crate::Error::new_io(e)); } } } @@ -716,6 +725,10 @@ impl fmt::Debug for State { builder.field("error", error); } + if self.allow_half_close { + builder.field("allow_half_close", &true); + } + // Purposefully leaving off other fields.. builder.finish() diff --git a/tests/client.rs b/tests/client.rs index f138512f..7b125843 100644 --- a/tests/client.rs +++ b/tests/client.rs @@ -1790,28 +1790,24 @@ mod conn { use futures_util::try_future::TryFutureExt; use futures_util::try_stream::TryStreamExt; use tokio::runtime::current_thread::Runtime; - use tokio_io::{AsyncRead, AsyncWrite}; - use tokio_net::tcp::TcpStream; + use tokio_io::{AsyncRead, AsyncReadExt as _, AsyncWrite, AsyncWriteExt as _}; + use tokio_net::tcp::{TcpListener as TkTcpListener, TcpStream}; use hyper::{self, Request, Body, Method}; use hyper::client::conn; use super::{s, tcp_connect, FutureHyperExt}; - #[test] - fn get() { - let server = TcpListener::bind("127.0.0.1:0").unwrap(); - let addr = server.local_addr().unwrap(); - let mut rt = Runtime::new().unwrap(); + #[tokio::test] + async fn get() { + let _ = ::pretty_env_logger::try_init(); + let mut listener = TkTcpListener::bind("127.0.0.1:0").await.unwrap(); + let addr = listener.local_addr().unwrap(); - let (tx1, rx1) = oneshot::channel(); - - thread::spawn(move || { - let mut sock = server.accept().unwrap().0; - sock.set_read_timeout(Some(Duration::from_secs(5))).unwrap(); - sock.set_write_timeout(Some(Duration::from_secs(5))).unwrap(); + let server = async move { + let mut sock = listener.accept().await.unwrap().0; let mut buf = [0; 4096]; - let n = sock.read(&mut buf).expect("read 1"); + let n = sock.read(&mut buf).await.expect("read 1"); // Notably: // - Just a path, since just a path was set @@ -1819,27 +1815,27 @@ mod conn { let expected = "GET /a HTTP/1.1\r\n\r\n"; assert_eq!(s(&buf[..n]), expected); - sock.write_all(b"HTTP/1.1 200 OK\r\nContent-Length: 0\r\n\r\n").unwrap(); - let _ = tx1.send(()); - }); + sock.write_all(b"HTTP/1.1 200 OK\r\nContent-Length: 0\r\n\r\n").await.unwrap(); + }; - let tcp = rt.block_on(tcp_connect(&addr)).unwrap(); + let client = async move { + let tcp = tcp_connect(&addr).await.expect("connect"); + let (mut client, conn) = conn::handshake(tcp).await.expect("handshake"); - let (mut client, conn) = rt.block_on(conn::handshake(tcp)).unwrap(); + hyper::rt::spawn(async move { + conn.await.expect("http conn"); + }); - rt.spawn(conn.map_err(|e| panic!("conn error: {}", e)).map(|_| ())); - - let req = Request::builder() - .uri("/a") - .body(Default::default()) - .unwrap(); - let res = client.send_request(req).and_then(move |res| { + let req = Request::builder() + .uri("/a") + .body(Default::default()) + .unwrap(); + let mut res = client.send_request(req).await.expect("send_request"); assert_eq!(res.status(), hyper::StatusCode::OK); - res.into_body().try_concat() - }); - let rx = rx1.expect("thread panicked"); - let rx = rx.then(|_| tokio_timer::delay_for(Duration::from_millis(200))); - rt.block_on(future::join(res, rx).map(|r| r.0)).unwrap(); + assert!(res.body_mut().next().await.is_none()); + }; + + future::join(server, client).await; } #[test]