From 1c472a220ac4db27c126ef9f1cdea393186d971a Mon Sep 17 00:00:00 2001 From: Sean McArthur Date: Wed, 17 Sep 2014 17:19:07 -0700 Subject: [PATCH] adds HttpWriters --- src/client/request.rs | 48 ++++++++++--- src/header/common/content_length.rs | 11 ++- src/header/common/transfer_encoding.rs | 1 - src/http.rs | 98 +++++++++++++++++++++++++- src/lib.rs | 2 +- src/server/response.rs | 51 +++++++++++--- 6 files changed, 190 insertions(+), 21 deletions(-) diff --git a/src/client/request.rs b/src/client/request.rs index 231240b9..cad798e5 100644 --- a/src/client/request.rs +++ b/src/client/request.rs @@ -5,9 +5,9 @@ use url::Url; use method::{mod, Get, Post, Delete, Put, Patch, Head, Options}; use header::Headers; -use header::common::Host; +use header::common::{mod, Host}; use net::{NetworkStream, HttpStream, WriteStatus, Fresh, Streaming}; -use http::LINE_ENDING; +use http::{HttpWriter, ThroughWriter, ChunkedWriter, SizedWriter, LINE_ENDING}; use version; use {HttpResult, HttpUriError}; use client::Response; @@ -20,7 +20,7 @@ pub struct Request { /// The HTTP version of this request. pub version: version::HttpVersion, - body: BufferedWriter>, + body: HttpWriter>>, headers: Headers, method: method::Method, } @@ -56,7 +56,7 @@ impl Request { debug!("port={}", port); let stream: S = try_io!(NetworkStream::connect(host.as_slice(), port, url.scheme.as_slice())); - let stream = BufferedWriter::new(stream.abstract()); + let stream = ThroughWriter(BufferedWriter::new(stream.abstract())); let mut headers = Headers::new(); headers.set(Host(host)); @@ -107,6 +107,31 @@ impl Request { debug!("{}", self.headers); + let mut chunked = true; + let mut len = 0; + + match self.headers.get_ref::() { + Some(cl) => { + chunked = false; + len = cl.len(); + }, + None => () + }; + + // cant do in match above, thanks borrowck + if chunked { + //TODO: use CollectionViews (when implemented) to prevent double hash/lookup + let encodings = match self.headers.get::() { + Some(common::TransferEncoding(mut encodings)) => { + //TODO: check if chunked is already in encodings. use HashSet? + encodings.push(common::transfer_encoding::Chunked); + encodings + }, + None => vec![common::transfer_encoding::Chunked] + }; + self.headers.set(common::TransferEncoding(encodings)); + } + for (name, header) in self.headers.iter() { try_io!(write!(self.body, "{}: {}", name, header)); try_io!(self.body.write(LINE_ENDING)); @@ -114,12 +139,18 @@ impl Request { try_io!(self.body.write(LINE_ENDING)); + let stream = if chunked { + ChunkedWriter(self.body.unwrap()) + } else { + SizedWriter(self.body.unwrap(), len) + }; + Ok(Request { method: self.method, headers: self.headers, url: self.url, version: self.version, - body: self.body + body: stream }) } @@ -132,18 +163,19 @@ impl Request { /// Completes writing the request, and returns a response to read from. /// /// Consumes the Request. - pub fn send(mut self) -> HttpResult { - try_io!(self.flush()); - let raw = self.body.unwrap(); + pub fn send(self) -> HttpResult { + let raw = try_io!(self.body.end()).unwrap(); Response::new(raw) } } impl Writer for Request { + #[inline] fn write(&mut self, msg: &[u8]) -> IoResult<()> { self.body.write(msg) } + #[inline] fn flush(&mut self) -> IoResult<()> { self.body.flush() } diff --git a/src/header/common/content_length.rs b/src/header/common/content_length.rs index 3f5ba5b4..d76636cc 100644 --- a/src/header/common/content_length.rs +++ b/src/header/common/content_length.rs @@ -1,5 +1,6 @@ -use header::Header; use std::fmt::{mod, Show}; + +use header::Header; use super::from_one_raw_str; /// The `Content-Length` header. @@ -23,3 +24,11 @@ impl Header for ContentLength { } } +impl ContentLength { + /// Returns the wrapped length. + #[inline] + pub fn len(&self) -> uint { + let ContentLength(len) = *self; + len + } +} diff --git a/src/header/common/transfer_encoding.rs b/src/header/common/transfer_encoding.rs index 311dd598..2fbc5660 100644 --- a/src/header/common/transfer_encoding.rs +++ b/src/header/common/transfer_encoding.rs @@ -33,7 +33,6 @@ pub enum Encoding { /// The `chunked` encoding. Chunked, - // TODO: #2 implement this in `HttpReader`. /// The `gzip` encoding. Gzip, /// The `deflate` encoding. diff --git a/src/http.rs b/src/http.rs index 7ea5367b..560edb19 100644 --- a/src/http.rs +++ b/src/http.rs @@ -1,3 +1,4 @@ +//! Pieces pertaining to the HTTP message protocol. use std::cmp::min; use std::io::{mod, Reader, IoResult}; use std::u16; @@ -133,6 +134,77 @@ fn read_chunk_size(rdr: &mut R) -> IoResult { Ok(size) } +/// Writers to handle different Transfer-Encodings. +pub enum HttpWriter { + /// A no-op Writer, used initially before Transfer-Encoding is determined. + ThroughWriter(W), + /// A Writer for when Transfer-Encoding includes `chunked`. + ChunkedWriter(W), + /// A Writer for when Content-Length is set. + /// + /// Enforces that the body is not longer than the Content-Length header. + SizedWriter(W, uint), +} + +impl HttpWriter { + /// Unwraps the HttpWriter and returns the underlying Writer. + #[inline] + pub fn unwrap(self) -> W { + match self { + ThroughWriter(w) => w, + ChunkedWriter(w) => w, + SizedWriter(w, _) => w + } + } + + /// Ends the HttpWriter, and returns the underlying Writer. + /// + /// A final `write()` is called with an empty message, and then flushed. + /// The ChunkedWriter variant will use this to write the 0-sized last-chunk. + #[inline] + pub fn end(mut self) -> IoResult { + try!(self.write(&[])); + try!(self.flush()); + Ok(self.unwrap()) + } +} + +impl Writer for HttpWriter { + #[inline] + fn write(&mut self, msg: &[u8]) -> IoResult<()> { + match *self { + ThroughWriter(ref mut w) => w.write(msg), + ChunkedWriter(ref mut w) => { + let chunk_size = msg.len(); + try!(write!(w, "{:X}{}{}", chunk_size, CR as char, LF as char)); + try!(w.write(msg)); + w.write(LINE_ENDING) + }, + SizedWriter(ref mut w, ref mut remaining) => { + let len = msg.len(); + if len > *remaining { + let len = *remaining; + *remaining = 0; + try!(w.write(msg.slice_to(len))); // msg[..len] + Err(io::standard_error(io::ShortWrite(len))) + } else { + *remaining -= len; + w.write(msg) + } + } + } + } + + #[inline] + fn flush(&mut self) -> IoResult<()> { + match *self { + ThroughWriter(ref mut w) => w.flush(), + ChunkedWriter(ref mut w) => w.flush(), + SizedWriter(ref mut w, _) => w.flush(), + } + } +} + pub static SP: u8 = b' '; pub static CR: u8 = b'\r'; pub static LF: u8 = b'\n'; @@ -551,6 +623,7 @@ pub fn read_status_line(stream: &mut R) -> HttpResult { Ok((version, code)) } +/// Read the StatusCode from a stream. pub fn read_status(stream: &mut R) -> HttpResult { let code = [ try_io!(stream.read_byte()), @@ -591,7 +664,7 @@ fn expect(r: IoResult, expected: u8) -> HttpResult<()> { #[cfg(test)] mod tests { - use std::io::MemReader; + use std::io::{mod, MemReader, MemWriter}; use test::Bencher; use uri::{RequestUri, Star, AbsoluteUri, AbsolutePath, Authority}; use method; @@ -669,6 +742,29 @@ mod tests { "rust-lang.org".as_bytes().to_vec())))); } + #[test] + fn test_write_chunked() { + use std::str::from_utf8; + let mut w = super::ChunkedWriter(MemWriter::new()); + w.write(b"foo bar").unwrap(); + w.write(b"baz quux herp").unwrap(); + let buf = w.end().unwrap().unwrap(); + let s = from_utf8(buf.as_slice()).unwrap(); + assert_eq!(s, "7\r\nfoo bar\r\nD\r\nbaz quux herp\r\n0\r\n\r\n"); + } + + #[test] + fn test_write_sized() { + use std::str::from_utf8; + let mut w = super::SizedWriter(MemWriter::new(), 8); + w.write(b"foo bar").unwrap(); + assert_eq!(w.write(b"baz"), Err(io::standard_error(io::ShortWrite(1)))); + + let buf = w.end().unwrap().unwrap(); + let s = from_utf8(buf.as_slice()).unwrap(); + assert_eq!(s, "foo barb"); + } + #[bench] fn bench_read_method(b: &mut Bencher) { b.bytes = b"CONNECT ".len() as u64; diff --git a/src/lib.rs b/src/lib.rs index 9414d648..4dfda02f 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -56,13 +56,13 @@ macro_rules! trace( pub mod client; pub mod method; pub mod header; +pub mod http; pub mod net; pub mod server; pub mod status; pub mod uri; pub mod version; -mod http; mod mimewrapper { /// Re-exporting the mime crate, for convenience. diff --git a/src/server/response.rs b/src/server/response.rs index dcda84e0..76c650ed 100644 --- a/src/server/response.rs +++ b/src/server/response.rs @@ -8,7 +8,7 @@ use time::now_utc; use header; use header::common; -use http::{CR, LF, LINE_ENDING}; +use http::{CR, LF, LINE_ENDING, HttpWriter, ThroughWriter, ChunkedWriter, SizedWriter}; use status; use net::{NetworkStream, WriteStatus, Fresh, Streaming}; use version; @@ -18,7 +18,7 @@ pub struct Response { /// The HTTP version of this response. pub version: version::HttpVersion, // Stream the Response is writing to, not accessible through UnwrittenResponse - body: BufferedWriter>, // TODO: use a HttpWriter from http + body: HttpWriter>>, // The status code for the request. status: status::StatusCode, // The outgoing headers on this response. @@ -35,7 +35,7 @@ impl Response { /// Construct a Response from its constituent parts. pub fn construct(version: version::HttpVersion, - body: BufferedWriter>, + body: HttpWriter>>, status: status::StatusCode, headers: header::Headers) -> Response { Response { @@ -54,7 +54,7 @@ impl Response { status: status::Ok, version: version::Http11, headers: header::Headers::new(), - body: BufferedWriter::new(stream.abstract()) + body: ThroughWriter(BufferedWriter::new(stream.abstract())) } } @@ -67,18 +67,50 @@ impl Response { self.headers.set(common::Date(now_utc())); } + + let mut chunked = true; + let mut len = 0; + + match self.headers.get_ref::() { + Some(cl) => { + chunked = false; + len = cl.len(); + }, + None => () + }; + + // cant do in match above, thanks borrowck + if chunked { + //TODO: use CollectionViews (when implemented) to prevent double hash/lookup + let encodings = match self.headers.get::() { + Some(common::TransferEncoding(mut encodings)) => { + //TODO: check if chunked is already in encodings. use HashSet? + encodings.push(common::transfer_encoding::Chunked); + encodings + }, + None => vec![common::transfer_encoding::Chunked] + }; + self.headers.set(common::TransferEncoding(encodings)); + } + for (name, header) in self.headers.iter() { - debug!("headers {}: {}", name, header); + debug!("header {}: {}", name, header); try!(write!(self.body, "{}: {}", name, header)); try!(self.body.write(LINE_ENDING)); } try!(self.body.write(LINE_ENDING)); + let stream = if chunked { + ChunkedWriter(self.body.unwrap()) + } else { + SizedWriter(self.body.unwrap(), len) + }; + // "copy" to change the phantom type Ok(Response { version: self.version, - body: self.body, + body: stream, status: self.status, headers: self.headers }) @@ -92,7 +124,7 @@ impl Response { pub fn headers_mut(&mut self) -> &mut header::Headers { &mut self.headers } /// Deconstruct this Response into its constituent parts. - pub fn deconstruct(self) -> (version::HttpVersion, BufferedWriter>, + pub fn deconstruct(self) -> (version::HttpVersion, HttpWriter>>, status::StatusCode, header::Headers) { (self.version, self.body, self.status, self.headers) } @@ -100,9 +132,10 @@ impl Response { impl Response { /// Flushes all writing of a response to the client. - pub fn end(mut self) -> IoResult<()> { + pub fn end(self) -> IoResult<()> { debug!("ending"); - self.flush() + try!(self.body.end()); + Ok(()) } }