diff --git a/src/client/request.rs b/src/client/request.rs index 4130e986..dc48aeb4 100644 --- a/src/client/request.rs +++ b/src/client/request.rs @@ -121,16 +121,19 @@ impl Request { // cant do in match above, thanks borrowck if chunked { - //TODO: use CollectionViews (when implemented) to prevent double hash/lookup - let encodings = match self.headers.get::().map(|h| h.clone()) { - Some(common::TransferEncoding(mut encodings)) => { + let encodings = match self.headers.get_mut::() { + Some(&common::TransferEncoding(ref mut encodings)) => { //TODO: check if chunked is already in encodings. use HashSet? encodings.push(common::transfer_encoding::Chunked); - encodings + false }, - None => vec![common::transfer_encoding::Chunked] + None => true }; - self.headers.set(common::TransferEncoding(encodings)); + + if encodings { + self.headers.set::( + common::TransferEncoding(vec![common::transfer_encoding::Chunked])) + } } for (name, header) in self.headers.iter() { diff --git a/src/header/mod.rs b/src/header/mod.rs index 6a9c7e98..7dc8e9a5 100644 --- a/src/header/mod.rs +++ b/src/header/mod.rs @@ -15,7 +15,7 @@ use std::string::raw; use std::collections::hashmap::{HashMap, Entries, Occupied, Vacant}; use std::sync::RWLock; -use uany::UncheckedAnyDowncast; +use uany::{UncheckedAnyDowncast, UncheckedAnyMutDowncast}; use typeable::Typeable; use http::read_header; @@ -62,6 +62,14 @@ impl<'a> UncheckedAnyDowncast<'a> for &'a Header { } } +impl<'a> UncheckedAnyMutDowncast<'a> for &'a mut Header { + #[inline] + unsafe fn downcast_mut_unchecked(self) -> &'a mut T { + let to: TraitObject = transmute_copy(&self); + transmute(to.data) + } +} + fn header_name() -> &'static str { let name = Header::header_name(None::); name @@ -145,6 +153,31 @@ impl Headers { /// Get a reference to the header field's value, if it exists. pub fn get(&self) -> Option<&H> { + self.get_or_parse::().map(|item| { + let read = item.read(); + debug!("downcasting {}", *read); + let ret = match *read { + Typed(ref val) => unsafe { val.downcast_ref_unchecked() }, + _ => unreachable!() + }; + unsafe { transmute::<&H, &H>(ret) } + }) + } + + /// Get a mutable reference to the header field's value, if it exists. + pub fn get_mut(&mut self) -> Option<&mut H> { + self.get_or_parse::().map(|item| { + let mut write = item.write(); + debug!("downcasting {}", *write); + let ret = match *&mut *write { + Typed(ref mut val) => unsafe { val.downcast_mut_unchecked() }, + _ => unreachable!() + }; + unsafe { transmute::<&mut H, &mut H>(ret) } + }) + } + + fn get_or_parse(&self) -> Option<&RWLock> { self.data.find(&CaseInsensitive(Slice(header_name::()))).and_then(|item| { let done = match *item.read() { // Huge borrowck hack here, should be refactored to just return here. @@ -185,14 +218,6 @@ impl Headers { // Mutate in the raw case. *write = Typed(box header as Box
); Some(item) - }).map(|item| { - let read = item.read(); - debug!("downcasting {}", *read); - let ret = match *read { - Typed(ref val) => unsafe { val.downcast_ref_unchecked() }, - _ => unreachable!() - }; - unsafe { transmute::<&H, &H>(ret) } }) } @@ -349,7 +374,7 @@ mod tests { #[test] fn test_from_raw() { let headers = Headers::from_raw(&mut mem("Content-Length: 10\r\n\r\n")).unwrap(); - assert_eq!(headers.get_ref(), Some(&ContentLength(10))); + assert_eq!(headers.get(), Some(&ContentLength(10))); } #[test] @@ -388,23 +413,30 @@ mod tests { #[test] fn test_different_structs_for_same_header() { let headers = Headers::from_raw(&mut mem("Content-Length: 10\r\n\r\n")).unwrap(); - let ContentLength(_) = headers.get::().unwrap(); + let ContentLength(_) = *headers.get::().unwrap(); assert!(headers.get::().is_none()); } #[test] fn test_multiple_reads() { let headers = Headers::from_raw(&mut mem("Content-Length: 10\r\n\r\n")).unwrap(); - let ContentLength(one) = headers.get::().unwrap(); - let ContentLength(two) = headers.get::().unwrap(); + let ContentLength(one) = *headers.get::().unwrap(); + let ContentLength(two) = *headers.get::().unwrap(); assert_eq!(one, two); } #[test] fn test_different_reads() { let headers = Headers::from_raw(&mut mem("Content-Length: 10\r\nContent-Type: text/plain\r\n\r\n")).unwrap(); - let ContentLength(_) = headers.get::().unwrap(); - let ContentType(_) = headers.get::().unwrap(); + let ContentLength(_) = *headers.get::().unwrap(); + let ContentType(_) = *headers.get::().unwrap(); + } + + #[test] + fn test_get_mutable() { + let mut headers = Headers::from_raw(&mut mem("Content-Length: 10\r\nContent-Type: text/plain\r\n\r\n")).unwrap(); + *headers.get_mut::().unwrap() = ContentLength(20); + assert_eq!(*headers.get::().unwrap(), ContentLength(20)); } } diff --git a/src/server/response.rs b/src/server/response.rs index 800ffc87..040d1796 100644 --- a/src/server/response.rs +++ b/src/server/response.rs @@ -82,16 +82,19 @@ impl Response { // cant do in match above, thanks borrowck if chunked { - //TODO: use CollectionViews (when implemented) to prevent double hash/lookup - let encodings = match self.headers.get::().map(|h| h.clone()) { - Some(common::TransferEncoding(mut encodings)) => { + let encodings = match self.headers.get_mut::() { + Some(&common::TransferEncoding(ref mut encodings)) => { //TODO: check if chunked is already in encodings. use HashSet? encodings.push(common::transfer_encoding::Chunked); - encodings + false }, - None => vec![common::transfer_encoding::Chunked] + None => true }; - self.headers.set(common::TransferEncoding(encodings)); + + if encodings { + self.headers.set::( + common::TransferEncoding(vec![common::transfer_encoding::Chunked])) + } } for (name, header) in self.headers.iter() {