From 3b8c5cac1a5bea9b91bd4aa6eb9ad42d4d6568dd Mon Sep 17 00:00:00 2001 From: Sean McArthur Date: Sun, 9 Nov 2014 12:13:12 -0800 Subject: [PATCH] fix(client): GET and HEAD shouldn't add Transfer-Encoding Also adds an EmptyWriter, used for GET and HEAD requests, which will return an io::ShortWrite error if the user ever tries to write to a GET or HEAD request. Closes #77 --- src/client/request.rs | 110 ++++++++++++++++++++++++++++------------- src/client/response.rs | 4 +- src/http.rs | 18 ++++++- src/mock.rs | 51 ++++++++++++++++--- src/net.rs | 8 +-- 5 files changed, 142 insertions(+), 49 deletions(-) diff --git a/src/client/request.rs b/src/client/request.rs index 62b8b379..6f9bd66e 100644 --- a/src/client/request.rs +++ b/src/client/request.rs @@ -7,7 +7,7 @@ use method::{mod, Get, Post, Delete, Put, Patch, Head, Options}; use header::Headers; use header::common::{mod, Host}; use net::{NetworkStream, NetworkConnector, HttpStream, Fresh, Streaming}; -use http::{HttpWriter, ThroughWriter, ChunkedWriter, SizedWriter, LINE_ENDING}; +use http::{HttpWriter, ThroughWriter, ChunkedWriter, SizedWriter, EmptyWriter, LINE_ENDING}; use version; use {HttpResult, HttpUriError}; use client::Response; @@ -117,43 +117,50 @@ impl Request { try_io!(self.body.write(LINE_ENDING)); - let mut chunked = true; - let mut len = 0; - - match self.headers.get::() { - Some(cl) => { - chunked = false; - len = cl.len(); + let stream = match self.method { + Get | Head => { + EmptyWriter(self.body.unwrap()) }, - None => () - }; + _ => { + let mut chunked = true; + let mut len = 0; - // cant do in match above, thanks borrowck - if chunked { - let encodings = match self.headers.get_mut::() { - Some(&common::TransferEncoding(ref mut encodings)) => { - //TODO: check if chunked is already in encodings. use HashSet? - encodings.push(common::transfer_encoding::Chunked); - false - }, - None => true - }; + match self.headers.get::() { + Some(cl) => { + chunked = false; + len = cl.len(); + }, + None => () + }; - if encodings { - self.headers.set::( - common::TransferEncoding(vec![common::transfer_encoding::Chunked])) + // cant do in match above, thanks borrowck + if chunked { + let encodings = match self.headers.get_mut::() { + Some(&common::TransferEncoding(ref mut encodings)) => { + //TODO: check if chunked is already in encodings. use HashSet? + encodings.push(common::transfer_encoding::Chunked); + false + }, + None => true + }; + + if encodings { + self.headers.set::( + common::TransferEncoding(vec![common::transfer_encoding::Chunked])) + } + } + + debug!("headers [\n{}]", self.headers); + try_io!(write!(self.body, "{}", self.headers)); + + try_io!(self.body.write(LINE_ENDING)); + + if chunked { + ChunkedWriter(self.body.unwrap()) + } else { + SizedWriter(self.body.unwrap(), len) + } } - } - - debug!("headers [\n{}]", self.headers); - try_io!(write!(self.body, "{}", self.headers)); - - try_io!(self.body.write(LINE_ENDING)); - - let stream = if chunked { - ChunkedWriter(self.body.unwrap()) - } else { - SizedWriter(self.body.unwrap(), len) }; Ok(Request { @@ -192,3 +199,38 @@ impl Writer for Request { } } +#[cfg(test)] +mod tests { + use std::boxed::BoxAny; + use std::str::from_utf8; + use url::Url; + use method::{Get, Head}; + use mock::MockStream; + use super::Request; + + #[test] + fn test_get_empty_body() { + let req = Request::with_stream::( + Get, Url::parse("http://example.dom").unwrap() + ).unwrap(); + let req = req.start().unwrap(); + let stream = *req.body.end().unwrap().unwrap().downcast::().unwrap(); + let bytes = stream.write.unwrap(); + let s = from_utf8(bytes[]).unwrap(); + assert!(!s.contains("Content-Length:")); + assert!(!s.contains("Transfer-Encoding:")); + } + + #[test] + fn test_head_empty_body() { + let req = Request::with_stream::( + Head, Url::parse("http://example.dom").unwrap() + ).unwrap(); + let req = req.start().unwrap(); + let stream = *req.body.end().unwrap().unwrap().downcast::().unwrap(); + let bytes = stream.write.unwrap(); + let s = from_utf8(bytes[]).unwrap(); + assert!(!s.contains("Content-Length:")); + assert!(!s.contains("Transfer-Encoding:")); + } +} diff --git a/src/client/response.rs b/src/client/response.rs index 940169ad..b54fec49 100644 --- a/src/client/response.rs +++ b/src/client/response.rs @@ -100,11 +100,11 @@ mod tests { status: status::Ok, headers: Headers::new(), version: version::Http11, - body: EofReader(BufferedReader::new(box MockStream as Box)) + body: EofReader(BufferedReader::new(box MockStream::new() as Box)) }; let b = res.unwrap().downcast::().unwrap(); - assert_eq!(b, box MockStream); + assert_eq!(b, box MockStream::new()); } } diff --git a/src/http.rs b/src/http.rs index f17ab024..e00022a6 100644 --- a/src/http.rs +++ b/src/http.rs @@ -157,6 +157,8 @@ pub enum HttpWriter { /// /// Enforces that the body is not longer than the Content-Length header. SizedWriter(W, uint), + /// A writer that should not write any body. + EmptyWriter(W), } impl HttpWriter { @@ -166,7 +168,8 @@ impl HttpWriter { match self { ThroughWriter(w) => w, ChunkedWriter(w) => w, - SizedWriter(w, _) => w + SizedWriter(w, _) => w, + EmptyWriter(w) => w, } } @@ -204,6 +207,18 @@ impl Writer for HttpWriter { *remaining -= len; w.write(msg) } + }, + EmptyWriter(..) => { + let bytes = msg.len(); + if bytes == 0 { + Ok(()) + } else { + Err(io::IoError { + kind: io::ShortWrite(bytes), + desc: "EmptyWriter cannot write any bytes", + detail: Some("Cannot include a body with this kind of message".into_string()) + }) + } } } } @@ -214,6 +229,7 @@ impl Writer for HttpWriter { ThroughWriter(ref mut w) => w.flush(), ChunkedWriter(ref mut w) => w.flush(), SizedWriter(ref mut w, _) => w.flush(), + EmptyWriter(ref mut w) => w.flush(), } } } diff --git a/src/mock.rs b/src/mock.rs index 8c023ad8..5f7504e4 100644 --- a/src/mock.rs +++ b/src/mock.rs @@ -1,20 +1,55 @@ -use std::io::IoResult; +use std::fmt; +use std::io::{IoResult, MemReader, MemWriter}; use std::io::net::ip::{SocketAddr, ToSocketAddr}; use net::{NetworkStream, NetworkConnector}; -#[deriving(Clone, PartialEq, Show)] -pub struct MockStream; +pub struct MockStream { + pub read: MemReader, + pub write: MemWriter, +} +impl Clone for MockStream { + fn clone(&self) -> MockStream { + MockStream { + read: MemReader::new(self.read.get_ref().to_vec()), + write: MemWriter::from_vec(self.write.get_ref().to_vec()), + } + } +} + +impl PartialEq for MockStream { + fn eq(&self, other: &MockStream) -> bool { + self.read.get_ref() == other.read.get_ref() && + self.write.get_ref() == other.write.get_ref() + } +} + +impl fmt::Show for MockStream { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + write!(f, "MockStream {{ read: {}, write: {} }}", + self.read.get_ref(), self.write.get_ref()) + } + +} + +impl MockStream { + pub fn new() -> MockStream { + MockStream { + read: MemReader::new(vec![]), + write: MemWriter::new(), + } + } +} impl Reader for MockStream { - fn read(&mut self, _buf: &mut [u8]) -> IoResult { - unimplemented!() + fn read(&mut self, buf: &mut [u8]) -> IoResult { + self.read.read(buf) } } impl Writer for MockStream { - fn write(&mut self, _msg: &[u8]) -> IoResult<()> { - unimplemented!() + fn write(&mut self, msg: &[u8]) -> IoResult<()> { + self.write.write(msg) } } @@ -27,6 +62,6 @@ impl NetworkStream for MockStream { impl NetworkConnector for MockStream { fn connect(_addr: To, _scheme: &str) -> IoResult { - Ok(MockStream) + Ok(MockStream::new()) } } diff --git a/src/net.rs b/src/net.rs index 4313cfbc..68cb0ca3 100644 --- a/src/net.rs +++ b/src/net.rs @@ -274,19 +274,19 @@ mod tests { #[test] fn test_downcast_box_stream() { - let stream = box MockStream as Box; + let stream = box MockStream::new() as Box; let mock = stream.downcast::().unwrap(); - assert_eq!(mock, box MockStream); + assert_eq!(mock, box MockStream::new()); } #[test] fn test_downcast_unchecked_box_stream() { - let stream = box MockStream as Box; + let stream = box MockStream::new() as Box; let mock = unsafe { stream.downcast_unchecked::() }; - assert_eq!(mock, box MockStream); + assert_eq!(mock, box MockStream::new()); }