diff --git a/src/http/conn.rs b/src/http/conn.rs index 17d09b1a..1e335b39 100644 --- a/src/http/conn.rs +++ b/src/http/conn.rs @@ -646,13 +646,6 @@ impl, T: Transport> State { _ => Reading::Closed, }; let writing = match http1.writing { - Writing::Ready(ref encoder) if encoder.is_eof() => { - if http1.keep_alive { - Writing::KeepAlive - } else { - Writing::Closed - } - }, Writing::Ready(encoder) => { if encoder.is_eof() { if http1.keep_alive { @@ -660,7 +653,7 @@ impl, T: Transport> State { } else { Writing::Closed } - } else if let Some(buf) = encoder.end() { + } else if let Some(buf) = encoder.finish() { Writing::Chunk(Chunk { buf: buf.bytes, pos: buf.pos, @@ -680,7 +673,7 @@ impl, T: Transport> State { } else { Writing::Closed } - } else if let Some(buf) = encoder.end() { + } else if let Some(buf) = encoder.finish() { Writing::Chunk(Chunk { buf: buf.bytes, pos: buf.pos, @@ -719,14 +712,26 @@ impl, T: Transport> State { }; http1.writing = match http1.writing { - Writing::Ready(encoder) => if encoder.is_eof() { - if http1.keep_alive { - Writing::KeepAlive + Writing::Ready(encoder) => { + if encoder.is_eof() { + if http1.keep_alive { + Writing::KeepAlive + } else { + Writing::Closed + } + } else if encoder.is_closed() { + if let Some(buf) = encoder.finish() { + Writing::Chunk(Chunk { + buf: buf.bytes, + pos: buf.pos, + next: (h1::Encoder::length(0), Next::wait()) + }) + } else { + Writing::Closed + } } else { - Writing::Closed + Writing::Wait(encoder) } - } else { - Writing::Wait(encoder) }, Writing::Chunk(chunk) => if chunk.is_written() { Writing::Wait(chunk.next.0) diff --git a/src/http/h1/encode.rs b/src/http/h1/encode.rs index 6cc7772a..4fd4e299 100644 --- a/src/http/h1/encode.rs +++ b/src/http/h1/encode.rs @@ -8,7 +8,8 @@ use http::internal::{AtomicWrite, WriteBuf}; #[derive(Debug, Clone)] pub struct Encoder { kind: Kind, - prefix: Prefix, //Option>> + prefix: Prefix, + is_closed: bool, } #[derive(Debug, PartialEq, Clone)] @@ -25,14 +26,16 @@ impl Encoder { pub fn chunked() -> Encoder { Encoder { kind: Kind::Chunked(Chunked::Init), - prefix: Prefix(None) + prefix: Prefix(None), + is_closed: false, } } pub fn length(len: u64) -> Encoder { Encoder { kind: Kind::Length(len), - prefix: Prefix(None) + prefix: Prefix(None), + is_closed: false, } } @@ -51,7 +54,16 @@ impl Encoder { } } - pub fn end(self) -> Option>> { + /// User has called `encoder.close()` in a `Handler`. + pub fn is_closed(&self) -> bool { + self.is_closed + } + + pub fn close(&mut self) { + self.is_closed = true; + } + + pub fn finish(self) -> Option>> { let trailer = self.trailer(); let buf = self.prefix.0; @@ -335,7 +347,7 @@ mod tests { use mock::{Async, Buf}; #[test] - fn test_write_chunked_sync() { + fn test_chunked_encode_sync() { let mut dst = Buf::new(); let mut encoder = Encoder::chunked(); @@ -346,7 +358,7 @@ mod tests { } #[test] - fn test_write_chunked_async() { + fn test_chunked_encode_async() { let mut dst = Async::new(Buf::new(), 7); let mut encoder = Encoder::chunked(); @@ -360,7 +372,7 @@ mod tests { } #[test] - fn test_write_sized() { + fn test_sized_encode() { let mut dst = Buf::new(); let mut encoder = Encoder::length(8); encoder.encode(&mut dst, b"foo bar").unwrap(); diff --git a/src/http/mod.rs b/src/http/mod.rs index 030b17b4..883cc7f0 100644 --- a/src/http/mod.rs +++ b/src/http/mod.rs @@ -72,6 +72,7 @@ impl<'a, T: Read> Decoder<'a, T> { Decoder(DecoderImpl::H1(decoder, transport)) } + /// Get a reference to the transport. pub fn get_ref(&self) -> &T { match self.0 { @@ -85,6 +86,17 @@ impl<'a, T: Transport> Encoder<'a, T> { Encoder(EncoderImpl::H1(encoder, transport)) } + /// Closes an encoder, signaling that no more writing will occur. + /// + /// This is needed for encodings that don't know length of the content + /// beforehand. Most common instance would be usage of + /// `Transfer-Enciding: chunked`. You would call `close()` to signal + /// the `Encoder` should write the end chunk, or `0\r\n\r\n`. + pub fn close(&mut self) { + match self.0 { + EncoderImpl::H1(ref mut encoder, _) => encoder.close() + } + } /// Get a reference to the transport. pub fn get_ref(&self) -> &T { @@ -113,7 +125,11 @@ impl<'a, T: Transport> Write for Encoder<'a, T> { } match self.0 { EncoderImpl::H1(ref mut encoder, ref mut transport) => { - encoder.encode(*transport, data) + if encoder.is_closed() { + Ok(0) + } else { + encoder.encode(*transport, data) + } } } } diff --git a/tests/client.rs b/tests/client.rs index 0c34ee07..563250be 100644 --- a/tests/client.rs +++ b/tests/client.rs @@ -8,6 +8,7 @@ use std::time::Duration; use hyper::client::{Handler, Request, Response, HttpConnector}; use hyper::{Method, StatusCode, Next, Encoder, Decoder}; +use hyper::header::Headers; use hyper::net::HttpStream; fn s(bytes: &[u8]) -> &str { @@ -48,10 +49,24 @@ fn read(opts: &Opts) -> Next { impl Handler for TestHandler { fn on_request(&mut self, req: &mut Request) -> Next { req.set_method(self.opts.method.clone()); - read(&self.opts) + req.headers_mut().extend(self.opts.headers.iter()); + if self.opts.body.is_some() { + Next::write() + } else { + read(&self.opts) + } } - fn on_request_writable(&mut self, _encoder: &mut Encoder) -> Next { + fn on_request_writable(&mut self, encoder: &mut Encoder) -> Next { + if let Some(ref mut body) = self.opts.body { + let n = encoder.write(body).unwrap(); + *body = &body[n..]; + + if !body.is_empty() { + return Next::write() + } + } + encoder.close(); read(&self.opts) } @@ -103,14 +118,18 @@ struct Client { #[derive(Debug)] struct Opts { + body: Option<&'static [u8]>, method: Method, + headers: Headers, read_timeout: Option, } impl Default for Opts { fn default() -> Opts { Opts { + body: None, method: Method::Get, + headers: Headers::new(), read_timeout: None, } } @@ -126,6 +145,16 @@ impl Opts { self } + fn header(mut self, header: H) -> Opts { + self.headers.set(header); + self + } + + fn body(mut self, body: Option<&'static [u8]>) -> Opts { + self.body = body; + self + } + fn read_timeout(mut self, timeout: Duration) -> Opts { self.read_timeout = Some(timeout); self @@ -167,33 +196,46 @@ macro_rules! test { request: method: $client_method:ident, url: $client_url:expr, + headers: [ $($request_headers:expr,)* ], + body: $request_body:expr, + response: status: $client_status:ident, - headers: [ $($client_headers:expr,)* ], - body: $client_body:expr + headers: [ $($response_headers:expr,)* ], + body: $response_body:expr, ) => ( #[test] fn $name() { + #[allow(unused)] + use hyper::header::*; let server = TcpListener::bind("127.0.0.1:0").unwrap(); let addr = server.local_addr().unwrap(); let client = client(); - let res = client.request(format!($client_url, addr=addr), opts().method(Method::$client_method)); + let opts = opts() + .method(Method::$client_method) + .body($request_body); + $( + let opts = opts.header($request_headers); + )* + let res = client.request(format!($client_url, addr=addr), opts); let mut inc = server.accept().unwrap().0; inc.set_read_timeout(Some(Duration::from_secs(5))).unwrap(); inc.set_write_timeout(Some(Duration::from_secs(5))).unwrap(); - let mut buf = [0; 4096]; - let n = inc.read(&mut buf).unwrap(); let expected = format!($server_expected, addr=addr); + let mut buf = [0; 4096]; + let mut n = 0; + while n < buf.len() && n < expected.len() { + n += inc.read(&mut buf[n..]).unwrap(); + } assert_eq!(s(&buf[..n]), expected); inc.write_all($server_reply.as_ref()).unwrap(); if let Msg::Head(head) = res.recv().unwrap() { - use hyper::header::*; assert_eq!(head.status(), &StatusCode::$client_status); $( - assert_eq!(head.headers().get(), Some(&$client_headers)); + assert_eq!(head.headers().get(), Some(&$response_headers)); )* } else { panic!("we lost the head!"); @@ -205,23 +247,27 @@ macro_rules! test { ); } +static REPLY_OK: &'static str = "HTTP/1.1 200 OK\r\nContent-Length: 0\r\n\r\n"; + test! { name: client_get, server: expected: "GET / HTTP/1.1\r\nHost: {addr}\r\n\r\n", - reply: "HTTP/1.1 200 OK\r\nContent-Length: 0\r\n\r\n", + reply: REPLY_OK, client: request: method: Get, url: "http://{addr}/", + headers: [], + body: None, response: status: Ok, headers: [ ContentLength(0), ], - body: None + body: None, } test! { @@ -229,19 +275,76 @@ test! { server: expected: "GET /foo?key=val HTTP/1.1\r\nHost: {addr}\r\n\r\n", - reply: "HTTP/1.1 200 OK\r\nContent-Length: 0\r\n\r\n", + reply: REPLY_OK, client: request: method: Get, url: "http://{addr}/foo?key=val#dont_send_me", + headers: [], + body: None, response: status: Ok, headers: [ ContentLength(0), ], - body: None + body: None, +} +test! { + name: client_post_sized, + + server: + expected: "\ + POST /length HTTP/1.1\r\n\ + Host: {addr}\r\n\ + Content-Length: 7\r\n\ + \r\n\ + foo bar\ + ", + reply: REPLY_OK, + + client: + request: + method: Post, + url: "http://{addr}/length", + headers: [ + ContentLength(7), + ], + body: Some(b"foo bar"), + response: + status: Ok, + headers: [], + body: None, +} + +test! { + name: client_post_chunked, + + server: + expected: "\ + POST /chunks HTTP/1.1\r\n\ + Host: {addr}\r\n\ + Transfer-Encoding: chunked\r\n\ + \r\n\ + B\r\n\ + foo bar baz\r\n\ + 0\r\n\r\n\ + ", + reply: REPLY_OK, + + client: + request: + method: Post, + url: "http://{addr}/chunks", + headers: [ + TransferEncoding::chunked(), + ], + body: Some(b"foo bar baz"), + response: + status: Ok, + headers: [], + body: None, } #[test]