diff --git a/src/client/request.rs b/src/client/request.rs index 9c9afc57..dc48aeb4 100644 --- a/src/client/request.rs +++ b/src/client/request.rs @@ -111,7 +111,7 @@ impl Request { let mut chunked = true; let mut len = 0; - match self.headers.get_ref::() { + match self.headers.get::() { Some(cl) => { chunked = false; len = cl.len(); @@ -121,16 +121,19 @@ impl Request { // cant do in match above, thanks borrowck if chunked { - //TODO: use CollectionViews (when implemented) to prevent double hash/lookup - let encodings = match self.headers.get::() { - Some(common::TransferEncoding(mut encodings)) => { + let encodings = match self.headers.get_mut::() { + Some(&common::TransferEncoding(ref mut encodings)) => { //TODO: check if chunked is already in encodings. use HashSet? encodings.push(common::transfer_encoding::Chunked); - encodings + false }, - None => vec![common::transfer_encoding::Chunked] + None => true }; - self.headers.set(common::TransferEncoding(encodings)); + + if encodings { + self.headers.set::( + common::TransferEncoding(vec![common::transfer_encoding::Chunked])) + } } for (name, header) in self.headers.iter() { diff --git a/src/client/response.rs b/src/client/response.rs index 3474ac06..7124cf6c 100644 --- a/src/client/response.rs +++ b/src/client/response.rs @@ -27,13 +27,13 @@ impl Response { pub fn new(stream: Box) -> HttpResult { let mut stream = BufferedReader::new(stream.abstract()); let (version, status) = try!(read_status_line(&mut stream)); - let mut headers = try!(header::Headers::from_raw(&mut stream)); + let headers = try!(header::Headers::from_raw(&mut stream)); debug!("{} {}", version, status); debug!("{}", headers); let body = if headers.has::() { - match headers.get_ref::() { + match headers.get::() { Some(&TransferEncoding(ref codings)) => { if codings.len() > 1 { debug!("TODO: #2 handle other codings: {}", codings); @@ -49,7 +49,7 @@ impl Response { None => unreachable!() } } else if headers.has::() { - match headers.get_ref::() { + match headers.get::() { Some(&ContentLength(len)) => SizedReader(stream, len), None => unreachable!() } diff --git a/src/header/mod.rs b/src/header/mod.rs index c4ce71fc..7dc8e9a5 100644 --- a/src/header/mod.rs +++ b/src/header/mod.rs @@ -13,8 +13,9 @@ use std::raw::TraitObject; use std::str::{from_utf8, SendStr, Slice, Owned}; use std::string::raw; use std::collections::hashmap::{HashMap, Entries, Occupied, Vacant}; +use std::sync::RWLock; -use uany::UncheckedAnyDowncast; +use uany::{UncheckedAnyDowncast, UncheckedAnyMutDowncast}; use typeable::Typeable; use http::read_header; @@ -24,7 +25,7 @@ use {HttpResult}; pub mod common; /// A trait for any object that will represent a header field and value. -pub trait Header: Typeable { +pub trait Header: Typeable + Send + Sync { /// Returns the name of the header field this belongs to. /// /// The market `Option` is to hint to the type system which implementation @@ -61,6 +62,14 @@ impl<'a> UncheckedAnyDowncast<'a> for &'a Header { } } +impl<'a> UncheckedAnyMutDowncast<'a> for &'a mut Header { + #[inline] + unsafe fn downcast_mut_unchecked(self) -> &'a mut T { + let to: TraitObject = transmute_copy(&self); + transmute(to.data) + } +} + fn header_name() -> &'static str { let name = Header::header_name(None::); name @@ -68,7 +77,7 @@ fn header_name() -> &'static str { /// A map of header fields on requests and responses. pub struct Headers { - data: HashMap + data: HashMap> } impl Headers { @@ -92,13 +101,14 @@ impl Headers { raw::from_utf8(name) }; - let item = match headers.data.entry(CaseInsensitive(Owned(name))) { - Vacant(entry) => entry.set(Raw(vec![])), + let name = CaseInsensitive(Owned(name)); + let item = match headers.data.entry(name) { + Vacant(entry) => entry.set(RWLock::new(Raw(vec![]))), Occupied(entry) => entry.into_mut() }; - match *item { - Raw(ref mut raw) => raw.push(value), + match &mut *item.write() { + &Raw(ref mut raw) => raw.push(value), // Unreachable _ => {} }; @@ -113,21 +123,7 @@ impl Headers { /// /// The field is determined by the type of the value being set. pub fn set(&mut self, value: H) { - self.data.insert(CaseInsensitive(Slice(header_name::())), Typed(box value as Box
)); - } - - /// Get a clone of the header field's value, if it exists. - /// - /// Example: - /// - /// ``` - /// # use hyper::header::Headers; - /// # use hyper::header::common::ContentType; - /// # let mut headers = Headers::new(); - /// let content_type = headers.get::(); - /// ``` - pub fn get(&mut self) -> Option { - self.get_ref().map(|v: &H| v.clone()) + self.data.insert(CaseInsensitive(Slice(header_name::())), RWLock::new(Typed(box value as Box
))); } /// Access the raw value of a header, if it exists and has not @@ -136,56 +132,92 @@ impl Headers { /// If the header field has already been parsed into a typed header, /// then you *must* access it through that representation. /// + /// This operation is unsafe because the raw representation can be + /// invalidated by lasting too long or by the header being parsed + /// while you still have a reference to the data. + /// /// Example: /// ``` /// # use hyper::header::Headers; /// # let mut headers = Headers::new(); /// let raw_content_type = unsafe { headers.get_raw("content-type") }; /// ``` - pub unsafe fn get_raw(&self, name: &'static str) -> Option<&[Vec]> { + pub unsafe fn get_raw(&self, name: &'static str) -> Option<*const [Vec]> { self.data.find(&CaseInsensitive(Slice(name))).and_then(|item| { - match *item { - Raw(ref raw) => Some(raw.as_slice()), + match *item.read() { + Raw(ref raw) => Some(raw.as_slice() as *const [Vec]), _ => None } }) } /// Get a reference to the header field's value, if it exists. - pub fn get_ref(&mut self) -> Option<&H> { - self.data.find_mut(&CaseInsensitive(Slice(header_name::()))).and_then(|item| { - debug!("get_ref, name={}, val={}", header_name::(), item); - let header = match *item { - // Huge borrowck hack here, should be refactored to just return here. - Typed(ref typed) if typed.is::() => None, - // Typed, wrong type - Typed(_) => return None, - Raw(ref raw) => match Header::parse_header(raw.as_slice()) { - Some::(h) => { - Some(h) - }, - None => return None - }, - }; - - match header { - Some(header) => { - *item = Typed(box header as Box
); - Some(item) - }, - None => { - Some(item) - } - } - }).and_then(|item| { - debug!("downcasting {}", item); - let ret = match *item { - Typed(ref val) => { - unsafe { Some(val.downcast_ref_unchecked()) } - }, + pub fn get(&self) -> Option<&H> { + self.get_or_parse::().map(|item| { + let read = item.read(); + debug!("downcasting {}", *read); + let ret = match *read { + Typed(ref val) => unsafe { val.downcast_ref_unchecked() }, _ => unreachable!() }; - ret + unsafe { transmute::<&H, &H>(ret) } + }) + } + + /// Get a mutable reference to the header field's value, if it exists. + pub fn get_mut(&mut self) -> Option<&mut H> { + self.get_or_parse::().map(|item| { + let mut write = item.write(); + debug!("downcasting {}", *write); + let ret = match *&mut *write { + Typed(ref mut val) => unsafe { val.downcast_mut_unchecked() }, + _ => unreachable!() + }; + unsafe { transmute::<&mut H, &mut H>(ret) } + }) + } + + fn get_or_parse(&self) -> Option<&RWLock> { + self.data.find(&CaseInsensitive(Slice(header_name::()))).and_then(|item| { + let done = match *item.read() { + // Huge borrowck hack here, should be refactored to just return here. + Typed(ref typed) if typed.is::() => true, + + // Typed, wrong type. + Typed(_) => return None, + + // Raw, work to do. + Raw(_) => false, + }; + + // borrowck hack continued + if done { return Some(item); } + + // Take out a write lock to do the parsing and mutation. + let mut write = item.write(); + + let header = match *write { + // Since this lock can queue, it's possible another thread just + // did the work for us. + // + // Check they inserted the correct type and move on. + Typed(ref typed) if typed.is::() => return Some(item), + + // Wrong type, another thread got here before us and parsed + // as a different representation. + Typed(_) => return None, + + // We are first in the queue or the only ones, so do the actual + // work of parsing and mutation. + Raw(ref raw) => match Header::parse_header(raw.as_slice()) { + Some::(h) => h, + None => return None + } + }; + + // Mutate in the raw case. + *write = Typed(box header as Box
); + Some(item) }) } @@ -229,7 +261,7 @@ impl fmt::Show for Headers { /// An `Iterator` over the fields in a `Headers` map. pub struct HeadersItems<'a> { - inner: Entries<'a, CaseInsensitive, Item> + inner: Entries<'a, CaseInsensitive, RWLock> } impl<'a> Iterator<(&'a str, HeaderView<'a>)> for HeadersItems<'a> { @@ -242,12 +274,12 @@ impl<'a> Iterator<(&'a str, HeaderView<'a>)> for HeadersItems<'a> { } /// Returned with the `HeadersItems` iterator. -pub struct HeaderView<'a>(&'a Item); +pub struct HeaderView<'a>(&'a RWLock); impl<'a> fmt::Show for HeaderView<'a> { fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result { let HeaderView(item) = *self; - item.fmt(fmt) + item.read().fmt(fmt) } } @@ -265,7 +297,7 @@ impl Mutable for Headers { enum Item { Raw(Vec>), - Typed(Box
) + Typed(Box
) } impl fmt::Show for Item { @@ -341,8 +373,8 @@ mod tests { #[test] fn test_from_raw() { - let mut headers = Headers::from_raw(&mut mem("Content-Length: 10\r\n\r\n")).unwrap(); - assert_eq!(headers.get_ref(), Some(&ContentLength(10))); + let headers = Headers::from_raw(&mut mem("Content-Length: 10\r\n\r\n")).unwrap(); + assert_eq!(headers.get(), Some(&ContentLength(10))); } #[test] @@ -380,8 +412,31 @@ mod tests { #[test] fn test_different_structs_for_same_header() { - let mut headers = Headers::from_raw(&mut mem("Content-Length: 10\r\n\r\n")).unwrap(); - let ContentLength(_) = headers.get::().unwrap(); + let headers = Headers::from_raw(&mut mem("Content-Length: 10\r\n\r\n")).unwrap(); + let ContentLength(_) = *headers.get::().unwrap(); assert!(headers.get::().is_none()); } + + #[test] + fn test_multiple_reads() { + let headers = Headers::from_raw(&mut mem("Content-Length: 10\r\n\r\n")).unwrap(); + let ContentLength(one) = *headers.get::().unwrap(); + let ContentLength(two) = *headers.get::().unwrap(); + assert_eq!(one, two); + } + + #[test] + fn test_different_reads() { + let headers = Headers::from_raw(&mut mem("Content-Length: 10\r\nContent-Type: text/plain\r\n\r\n")).unwrap(); + let ContentLength(_) = *headers.get::().unwrap(); + let ContentType(_) = *headers.get::().unwrap(); + } + + #[test] + fn test_get_mutable() { + let mut headers = Headers::from_raw(&mut mem("Content-Length: 10\r\nContent-Type: text/plain\r\n\r\n")).unwrap(); + *headers.get_mut::().unwrap() = ContentLength(20); + assert_eq!(*headers.get::().unwrap(), ContentLength(20)); + } } + diff --git a/src/server/request.rs b/src/server/request.rs index 39f81d58..91989dcd 100644 --- a/src/server/request.rs +++ b/src/server/request.rs @@ -39,14 +39,14 @@ impl Request { let remote_addr = try_io!(stream.peer_name()); let mut stream = BufferedReader::new(stream.abstract()); let (method, uri, version) = try!(read_request_line(&mut stream)); - let mut headers = try!(Headers::from_raw(&mut stream)); + let headers = try!(Headers::from_raw(&mut stream)); debug!("{} {} {}", method, uri, version); debug!("{}", headers); let body = if headers.has::() { - match headers.get_ref::() { + match headers.get::() { Some(&ContentLength(len)) => SizedReader(stream, len), None => unreachable!() } diff --git a/src/server/response.rs b/src/server/response.rs index aa779a24..040d1796 100644 --- a/src/server/response.rs +++ b/src/server/response.rs @@ -72,7 +72,7 @@ impl Response { let mut chunked = true; let mut len = 0; - match self.headers.get_ref::() { + match self.headers.get::() { Some(cl) => { chunked = false; len = cl.len(); @@ -82,16 +82,19 @@ impl Response { // cant do in match above, thanks borrowck if chunked { - //TODO: use CollectionViews (when implemented) to prevent double hash/lookup - let encodings = match self.headers.get::() { - Some(common::TransferEncoding(mut encodings)) => { + let encodings = match self.headers.get_mut::() { + Some(&common::TransferEncoding(ref mut encodings)) => { //TODO: check if chunked is already in encodings. use HashSet? encodings.push(common::transfer_encoding::Chunked); - encodings + false }, - None => vec![common::transfer_encoding::Chunked] + None => true }; - self.headers.set(common::TransferEncoding(encodings)); + + if encodings { + self.headers.set::( + common::TransferEncoding(vec![common::transfer_encoding::Chunked])) + } } for (name, header) in self.headers.iter() {