feat(client): add method to end a chunked body for a Request

Closes #831
This commit is contained in:
Sean McArthur
2016-06-17 03:53:30 -07:00
parent 1b4f857997
commit c856de0428
4 changed files with 172 additions and 36 deletions

View File

@@ -646,13 +646,6 @@ impl<H: MessageHandler<T>, T: Transport> State<H, T> {
_ => Reading::Closed, _ => Reading::Closed,
}; };
let writing = match http1.writing { let writing = match http1.writing {
Writing::Ready(ref encoder) if encoder.is_eof() => {
if http1.keep_alive {
Writing::KeepAlive
} else {
Writing::Closed
}
},
Writing::Ready(encoder) => { Writing::Ready(encoder) => {
if encoder.is_eof() { if encoder.is_eof() {
if http1.keep_alive { if http1.keep_alive {
@@ -660,7 +653,7 @@ impl<H: MessageHandler<T>, T: Transport> State<H, T> {
} else { } else {
Writing::Closed Writing::Closed
} }
} else if let Some(buf) = encoder.end() { } else if let Some(buf) = encoder.finish() {
Writing::Chunk(Chunk { Writing::Chunk(Chunk {
buf: buf.bytes, buf: buf.bytes,
pos: buf.pos, pos: buf.pos,
@@ -680,7 +673,7 @@ impl<H: MessageHandler<T>, T: Transport> State<H, T> {
} else { } else {
Writing::Closed Writing::Closed
} }
} else if let Some(buf) = encoder.end() { } else if let Some(buf) = encoder.finish() {
Writing::Chunk(Chunk { Writing::Chunk(Chunk {
buf: buf.bytes, buf: buf.bytes,
pos: buf.pos, pos: buf.pos,
@@ -719,14 +712,26 @@ impl<H: MessageHandler<T>, T: Transport> State<H, T> {
}; };
http1.writing = match http1.writing { http1.writing = match http1.writing {
Writing::Ready(encoder) => if encoder.is_eof() { Writing::Ready(encoder) => {
if http1.keep_alive { if encoder.is_eof() {
Writing::KeepAlive if http1.keep_alive {
Writing::KeepAlive
} else {
Writing::Closed
}
} else if encoder.is_closed() {
if let Some(buf) = encoder.finish() {
Writing::Chunk(Chunk {
buf: buf.bytes,
pos: buf.pos,
next: (h1::Encoder::length(0), Next::wait())
})
} else {
Writing::Closed
}
} else { } else {
Writing::Closed Writing::Wait(encoder)
} }
} else {
Writing::Wait(encoder)
}, },
Writing::Chunk(chunk) => if chunk.is_written() { Writing::Chunk(chunk) => if chunk.is_written() {
Writing::Wait(chunk.next.0) Writing::Wait(chunk.next.0)

View File

@@ -8,7 +8,8 @@ use http::internal::{AtomicWrite, WriteBuf};
#[derive(Debug, Clone)] #[derive(Debug, Clone)]
pub struct Encoder { pub struct Encoder {
kind: Kind, kind: Kind,
prefix: Prefix, //Option<WriteBuf<Vec<u8>>> prefix: Prefix,
is_closed: bool,
} }
#[derive(Debug, PartialEq, Clone)] #[derive(Debug, PartialEq, Clone)]
@@ -25,14 +26,16 @@ impl Encoder {
pub fn chunked() -> Encoder { pub fn chunked() -> Encoder {
Encoder { Encoder {
kind: Kind::Chunked(Chunked::Init), kind: Kind::Chunked(Chunked::Init),
prefix: Prefix(None) prefix: Prefix(None),
is_closed: false,
} }
} }
pub fn length(len: u64) -> Encoder { pub fn length(len: u64) -> Encoder {
Encoder { Encoder {
kind: Kind::Length(len), kind: Kind::Length(len),
prefix: Prefix(None) prefix: Prefix(None),
is_closed: false,
} }
} }
@@ -51,7 +54,16 @@ impl Encoder {
} }
} }
pub fn end(self) -> Option<WriteBuf<Cow<'static, [u8]>>> { /// User has called `encoder.close()` in a `Handler`.
pub fn is_closed(&self) -> bool {
self.is_closed
}
pub fn close(&mut self) {
self.is_closed = true;
}
pub fn finish(self) -> Option<WriteBuf<Cow<'static, [u8]>>> {
let trailer = self.trailer(); let trailer = self.trailer();
let buf = self.prefix.0; let buf = self.prefix.0;
@@ -335,7 +347,7 @@ mod tests {
use mock::{Async, Buf}; use mock::{Async, Buf};
#[test] #[test]
fn test_write_chunked_sync() { fn test_chunked_encode_sync() {
let mut dst = Buf::new(); let mut dst = Buf::new();
let mut encoder = Encoder::chunked(); let mut encoder = Encoder::chunked();
@@ -346,7 +358,7 @@ mod tests {
} }
#[test] #[test]
fn test_write_chunked_async() { fn test_chunked_encode_async() {
let mut dst = Async::new(Buf::new(), 7); let mut dst = Async::new(Buf::new(), 7);
let mut encoder = Encoder::chunked(); let mut encoder = Encoder::chunked();
@@ -360,7 +372,7 @@ mod tests {
} }
#[test] #[test]
fn test_write_sized() { fn test_sized_encode() {
let mut dst = Buf::new(); let mut dst = Buf::new();
let mut encoder = Encoder::length(8); let mut encoder = Encoder::length(8);
encoder.encode(&mut dst, b"foo bar").unwrap(); encoder.encode(&mut dst, b"foo bar").unwrap();

View File

@@ -72,6 +72,7 @@ impl<'a, T: Read> Decoder<'a, T> {
Decoder(DecoderImpl::H1(decoder, transport)) Decoder(DecoderImpl::H1(decoder, transport))
} }
/// Get a reference to the transport. /// Get a reference to the transport.
pub fn get_ref(&self) -> &T { pub fn get_ref(&self) -> &T {
match self.0 { match self.0 {
@@ -85,6 +86,17 @@ impl<'a, T: Transport> Encoder<'a, T> {
Encoder(EncoderImpl::H1(encoder, transport)) Encoder(EncoderImpl::H1(encoder, transport))
} }
/// Closes an encoder, signaling that no more writing will occur.
///
/// This is needed for encodings that don't know length of the content
/// beforehand. Most common instance would be usage of
/// `Transfer-Enciding: chunked`. You would call `close()` to signal
/// the `Encoder` should write the end chunk, or `0\r\n\r\n`.
pub fn close(&mut self) {
match self.0 {
EncoderImpl::H1(ref mut encoder, _) => encoder.close()
}
}
/// Get a reference to the transport. /// Get a reference to the transport.
pub fn get_ref(&self) -> &T { pub fn get_ref(&self) -> &T {
@@ -113,7 +125,11 @@ impl<'a, T: Transport> Write for Encoder<'a, T> {
} }
match self.0 { match self.0 {
EncoderImpl::H1(ref mut encoder, ref mut transport) => { EncoderImpl::H1(ref mut encoder, ref mut transport) => {
encoder.encode(*transport, data) if encoder.is_closed() {
Ok(0)
} else {
encoder.encode(*transport, data)
}
} }
} }
} }

View File

@@ -8,6 +8,7 @@ use std::time::Duration;
use hyper::client::{Handler, Request, Response, HttpConnector}; use hyper::client::{Handler, Request, Response, HttpConnector};
use hyper::{Method, StatusCode, Next, Encoder, Decoder}; use hyper::{Method, StatusCode, Next, Encoder, Decoder};
use hyper::header::Headers;
use hyper::net::HttpStream; use hyper::net::HttpStream;
fn s(bytes: &[u8]) -> &str { fn s(bytes: &[u8]) -> &str {
@@ -48,10 +49,24 @@ fn read(opts: &Opts) -> Next {
impl Handler<HttpStream> for TestHandler { impl Handler<HttpStream> for TestHandler {
fn on_request(&mut self, req: &mut Request) -> Next { fn on_request(&mut self, req: &mut Request) -> Next {
req.set_method(self.opts.method.clone()); req.set_method(self.opts.method.clone());
read(&self.opts) req.headers_mut().extend(self.opts.headers.iter());
if self.opts.body.is_some() {
Next::write()
} else {
read(&self.opts)
}
} }
fn on_request_writable(&mut self, _encoder: &mut Encoder<HttpStream>) -> Next { fn on_request_writable(&mut self, encoder: &mut Encoder<HttpStream>) -> Next {
if let Some(ref mut body) = self.opts.body {
let n = encoder.write(body).unwrap();
*body = &body[n..];
if !body.is_empty() {
return Next::write()
}
}
encoder.close();
read(&self.opts) read(&self.opts)
} }
@@ -103,14 +118,18 @@ struct Client {
#[derive(Debug)] #[derive(Debug)]
struct Opts { struct Opts {
body: Option<&'static [u8]>,
method: Method, method: Method,
headers: Headers,
read_timeout: Option<Duration>, read_timeout: Option<Duration>,
} }
impl Default for Opts { impl Default for Opts {
fn default() -> Opts { fn default() -> Opts {
Opts { Opts {
body: None,
method: Method::Get, method: Method::Get,
headers: Headers::new(),
read_timeout: None, read_timeout: None,
} }
} }
@@ -126,6 +145,16 @@ impl Opts {
self self
} }
fn header<H: ::hyper::header::Header>(mut self, header: H) -> Opts {
self.headers.set(header);
self
}
fn body(mut self, body: Option<&'static [u8]>) -> Opts {
self.body = body;
self
}
fn read_timeout(mut self, timeout: Duration) -> Opts { fn read_timeout(mut self, timeout: Duration) -> Opts {
self.read_timeout = Some(timeout); self.read_timeout = Some(timeout);
self self
@@ -167,33 +196,46 @@ macro_rules! test {
request: request:
method: $client_method:ident, method: $client_method:ident,
url: $client_url:expr, url: $client_url:expr,
headers: [ $($request_headers:expr,)* ],
body: $request_body:expr,
response: response:
status: $client_status:ident, status: $client_status:ident,
headers: [ $($client_headers:expr,)* ], headers: [ $($response_headers:expr,)* ],
body: $client_body:expr body: $response_body:expr,
) => ( ) => (
#[test] #[test]
fn $name() { fn $name() {
#[allow(unused)]
use hyper::header::*;
let server = TcpListener::bind("127.0.0.1:0").unwrap(); let server = TcpListener::bind("127.0.0.1:0").unwrap();
let addr = server.local_addr().unwrap(); let addr = server.local_addr().unwrap();
let client = client(); let client = client();
let res = client.request(format!($client_url, addr=addr), opts().method(Method::$client_method)); let opts = opts()
.method(Method::$client_method)
.body($request_body);
$(
let opts = opts.header($request_headers);
)*
let res = client.request(format!($client_url, addr=addr), opts);
let mut inc = server.accept().unwrap().0; let mut inc = server.accept().unwrap().0;
inc.set_read_timeout(Some(Duration::from_secs(5))).unwrap(); inc.set_read_timeout(Some(Duration::from_secs(5))).unwrap();
inc.set_write_timeout(Some(Duration::from_secs(5))).unwrap(); inc.set_write_timeout(Some(Duration::from_secs(5))).unwrap();
let mut buf = [0; 4096];
let n = inc.read(&mut buf).unwrap();
let expected = format!($server_expected, addr=addr); let expected = format!($server_expected, addr=addr);
let mut buf = [0; 4096];
let mut n = 0;
while n < buf.len() && n < expected.len() {
n += inc.read(&mut buf[n..]).unwrap();
}
assert_eq!(s(&buf[..n]), expected); assert_eq!(s(&buf[..n]), expected);
inc.write_all($server_reply.as_ref()).unwrap(); inc.write_all($server_reply.as_ref()).unwrap();
if let Msg::Head(head) = res.recv().unwrap() { if let Msg::Head(head) = res.recv().unwrap() {
use hyper::header::*;
assert_eq!(head.status(), &StatusCode::$client_status); assert_eq!(head.status(), &StatusCode::$client_status);
$( $(
assert_eq!(head.headers().get(), Some(&$client_headers)); assert_eq!(head.headers().get(), Some(&$response_headers));
)* )*
} else { } else {
panic!("we lost the head!"); panic!("we lost the head!");
@@ -205,23 +247,27 @@ macro_rules! test {
); );
} }
static REPLY_OK: &'static str = "HTTP/1.1 200 OK\r\nContent-Length: 0\r\n\r\n";
test! { test! {
name: client_get, name: client_get,
server: server:
expected: "GET / HTTP/1.1\r\nHost: {addr}\r\n\r\n", expected: "GET / HTTP/1.1\r\nHost: {addr}\r\n\r\n",
reply: "HTTP/1.1 200 OK\r\nContent-Length: 0\r\n\r\n", reply: REPLY_OK,
client: client:
request: request:
method: Get, method: Get,
url: "http://{addr}/", url: "http://{addr}/",
headers: [],
body: None,
response: response:
status: Ok, status: Ok,
headers: [ headers: [
ContentLength(0), ContentLength(0),
], ],
body: None body: None,
} }
test! { test! {
@@ -229,19 +275,76 @@ test! {
server: server:
expected: "GET /foo?key=val HTTP/1.1\r\nHost: {addr}\r\n\r\n", expected: "GET /foo?key=val HTTP/1.1\r\nHost: {addr}\r\n\r\n",
reply: "HTTP/1.1 200 OK\r\nContent-Length: 0\r\n\r\n", reply: REPLY_OK,
client: client:
request: request:
method: Get, method: Get,
url: "http://{addr}/foo?key=val#dont_send_me", url: "http://{addr}/foo?key=val#dont_send_me",
headers: [],
body: None,
response: response:
status: Ok, status: Ok,
headers: [ headers: [
ContentLength(0), ContentLength(0),
], ],
body: None body: None,
}
test! {
name: client_post_sized,
server:
expected: "\
POST /length HTTP/1.1\r\n\
Host: {addr}\r\n\
Content-Length: 7\r\n\
\r\n\
foo bar\
",
reply: REPLY_OK,
client:
request:
method: Post,
url: "http://{addr}/length",
headers: [
ContentLength(7),
],
body: Some(b"foo bar"),
response:
status: Ok,
headers: [],
body: None,
}
test! {
name: client_post_chunked,
server:
expected: "\
POST /chunks HTTP/1.1\r\n\
Host: {addr}\r\n\
Transfer-Encoding: chunked\r\n\
\r\n\
B\r\n\
foo bar baz\r\n\
0\r\n\r\n\
",
reply: REPLY_OK,
client:
request:
method: Post,
url: "http://{addr}/chunks",
headers: [
TransferEncoding::chunked(),
],
body: Some(b"foo bar baz"),
response:
status: Ok,
headers: [],
body: None,
} }
#[test] #[test]