Add get_mut for modifying the typed representation of Headers.

Also adds an associated test and updates code to use it instead
of cloning and setting when possible.
This commit is contained in:
Jonathan Reem
2014-09-20 05:52:41 -07:00
parent 858a09304a
commit d3a62fa0d5
3 changed files with 65 additions and 27 deletions

View File

@@ -121,16 +121,19 @@ impl Request<Fresh> {
// cant do in match above, thanks borrowck // cant do in match above, thanks borrowck
if chunked { if chunked {
//TODO: use CollectionViews (when implemented) to prevent double hash/lookup let encodings = match self.headers.get_mut::<common::TransferEncoding>() {
let encodings = match self.headers.get::<common::TransferEncoding>().map(|h| h.clone()) { Some(&common::TransferEncoding(ref mut encodings)) => {
Some(common::TransferEncoding(mut encodings)) => {
//TODO: check if chunked is already in encodings. use HashSet? //TODO: check if chunked is already in encodings. use HashSet?
encodings.push(common::transfer_encoding::Chunked); encodings.push(common::transfer_encoding::Chunked);
encodings false
}, },
None => vec![common::transfer_encoding::Chunked] None => true
}; };
self.headers.set(common::TransferEncoding(encodings));
if encodings {
self.headers.set::<common::TransferEncoding>(
common::TransferEncoding(vec![common::transfer_encoding::Chunked]))
}
} }
for (name, header) in self.headers.iter() { for (name, header) in self.headers.iter() {

View File

@@ -15,7 +15,7 @@ use std::string::raw;
use std::collections::hashmap::{HashMap, Entries, Occupied, Vacant}; use std::collections::hashmap::{HashMap, Entries, Occupied, Vacant};
use std::sync::RWLock; use std::sync::RWLock;
use uany::UncheckedAnyDowncast; use uany::{UncheckedAnyDowncast, UncheckedAnyMutDowncast};
use typeable::Typeable; use typeable::Typeable;
use http::read_header; use http::read_header;
@@ -62,6 +62,14 @@ impl<'a> UncheckedAnyDowncast<'a> for &'a Header {
} }
} }
impl<'a> UncheckedAnyMutDowncast<'a> for &'a mut Header {
#[inline]
unsafe fn downcast_mut_unchecked<T: 'static>(self) -> &'a mut T {
let to: TraitObject = transmute_copy(&self);
transmute(to.data)
}
}
fn header_name<T: Header>() -> &'static str { fn header_name<T: Header>() -> &'static str {
let name = Header::header_name(None::<T>); let name = Header::header_name(None::<T>);
name name
@@ -145,6 +153,31 @@ impl Headers {
/// Get a reference to the header field's value, if it exists. /// Get a reference to the header field's value, if it exists.
pub fn get<H: Header>(&self) -> Option<&H> { pub fn get<H: Header>(&self) -> Option<&H> {
self.get_or_parse::<H>().map(|item| {
let read = item.read();
debug!("downcasting {}", *read);
let ret = match *read {
Typed(ref val) => unsafe { val.downcast_ref_unchecked() },
_ => unreachable!()
};
unsafe { transmute::<&H, &H>(ret) }
})
}
/// Get a mutable reference to the header field's value, if it exists.
pub fn get_mut<H: Header>(&mut self) -> Option<&mut H> {
self.get_or_parse::<H>().map(|item| {
let mut write = item.write();
debug!("downcasting {}", *write);
let ret = match *&mut *write {
Typed(ref mut val) => unsafe { val.downcast_mut_unchecked() },
_ => unreachable!()
};
unsafe { transmute::<&mut H, &mut H>(ret) }
})
}
fn get_or_parse<H: Header>(&self) -> Option<&RWLock<Item>> {
self.data.find(&CaseInsensitive(Slice(header_name::<H>()))).and_then(|item| { self.data.find(&CaseInsensitive(Slice(header_name::<H>()))).and_then(|item| {
let done = match *item.read() { let done = match *item.read() {
// Huge borrowck hack here, should be refactored to just return here. // Huge borrowck hack here, should be refactored to just return here.
@@ -185,14 +218,6 @@ impl Headers {
// Mutate in the raw case. // Mutate in the raw case.
*write = Typed(box header as Box<Header + Send + Sync>); *write = Typed(box header as Box<Header + Send + Sync>);
Some(item) Some(item)
}).map(|item| {
let read = item.read();
debug!("downcasting {}", *read);
let ret = match *read {
Typed(ref val) => unsafe { val.downcast_ref_unchecked() },
_ => unreachable!()
};
unsafe { transmute::<&H, &H>(ret) }
}) })
} }
@@ -349,7 +374,7 @@ mod tests {
#[test] #[test]
fn test_from_raw() { fn test_from_raw() {
let headers = Headers::from_raw(&mut mem("Content-Length: 10\r\n\r\n")).unwrap(); let headers = Headers::from_raw(&mut mem("Content-Length: 10\r\n\r\n")).unwrap();
assert_eq!(headers.get_ref(), Some(&ContentLength(10))); assert_eq!(headers.get(), Some(&ContentLength(10)));
} }
#[test] #[test]
@@ -388,23 +413,30 @@ mod tests {
#[test] #[test]
fn test_different_structs_for_same_header() { fn test_different_structs_for_same_header() {
let headers = Headers::from_raw(&mut mem("Content-Length: 10\r\n\r\n")).unwrap(); let headers = Headers::from_raw(&mut mem("Content-Length: 10\r\n\r\n")).unwrap();
let ContentLength(_) = headers.get::<ContentLength>().unwrap(); let ContentLength(_) = *headers.get::<ContentLength>().unwrap();
assert!(headers.get::<CrazyLength>().is_none()); assert!(headers.get::<CrazyLength>().is_none());
} }
#[test] #[test]
fn test_multiple_reads() { fn test_multiple_reads() {
let headers = Headers::from_raw(&mut mem("Content-Length: 10\r\n\r\n")).unwrap(); let headers = Headers::from_raw(&mut mem("Content-Length: 10\r\n\r\n")).unwrap();
let ContentLength(one) = headers.get::<ContentLength>().unwrap(); let ContentLength(one) = *headers.get::<ContentLength>().unwrap();
let ContentLength(two) = headers.get::<ContentLength>().unwrap(); let ContentLength(two) = *headers.get::<ContentLength>().unwrap();
assert_eq!(one, two); assert_eq!(one, two);
} }
#[test] #[test]
fn test_different_reads() { fn test_different_reads() {
let headers = Headers::from_raw(&mut mem("Content-Length: 10\r\nContent-Type: text/plain\r\n\r\n")).unwrap(); let headers = Headers::from_raw(&mut mem("Content-Length: 10\r\nContent-Type: text/plain\r\n\r\n")).unwrap();
let ContentLength(_) = headers.get::<ContentLength>().unwrap(); let ContentLength(_) = *headers.get::<ContentLength>().unwrap();
let ContentType(_) = headers.get::<ContentType>().unwrap(); let ContentType(_) = *headers.get::<ContentType>().unwrap();
}
#[test]
fn test_get_mutable() {
let mut headers = Headers::from_raw(&mut mem("Content-Length: 10\r\nContent-Type: text/plain\r\n\r\n")).unwrap();
*headers.get_mut::<ContentLength>().unwrap() = ContentLength(20);
assert_eq!(*headers.get::<ContentLength>().unwrap(), ContentLength(20));
} }
} }

View File

@@ -82,16 +82,19 @@ impl Response<Fresh> {
// cant do in match above, thanks borrowck // cant do in match above, thanks borrowck
if chunked { if chunked {
//TODO: use CollectionViews (when implemented) to prevent double hash/lookup let encodings = match self.headers.get_mut::<common::TransferEncoding>() {
let encodings = match self.headers.get::<common::TransferEncoding>().map(|h| h.clone()) { Some(&common::TransferEncoding(ref mut encodings)) => {
Some(common::TransferEncoding(mut encodings)) => {
//TODO: check if chunked is already in encodings. use HashSet? //TODO: check if chunked is already in encodings. use HashSet?
encodings.push(common::transfer_encoding::Chunked); encodings.push(common::transfer_encoding::Chunked);
encodings false
}, },
None => vec![common::transfer_encoding::Chunked] None => true
}; };
self.headers.set(common::TransferEncoding(encodings));
if encodings {
self.headers.set::<common::TransferEncoding>(
common::TransferEncoding(vec![common::transfer_encoding::Chunked]))
}
} }
for (name, header) in self.headers.iter() { for (name, header) in self.headers.iter() {