Add get_mut for modifying the typed representation of Headers.

Also adds an associated test and updates code to use it instead
of cloning and setting when possible.
This commit is contained in:
Jonathan Reem
2014-09-20 05:52:41 -07:00
parent 858a09304a
commit d3a62fa0d5
3 changed files with 65 additions and 27 deletions

View File

@@ -121,16 +121,19 @@ impl Request<Fresh> {
// cant do in match above, thanks borrowck
if chunked {
//TODO: use CollectionViews (when implemented) to prevent double hash/lookup
let encodings = match self.headers.get::<common::TransferEncoding>().map(|h| h.clone()) {
Some(common::TransferEncoding(mut encodings)) => {
let encodings = match self.headers.get_mut::<common::TransferEncoding>() {
Some(&common::TransferEncoding(ref mut encodings)) => {
//TODO: check if chunked is already in encodings. use HashSet?
encodings.push(common::transfer_encoding::Chunked);
encodings
false
},
None => vec![common::transfer_encoding::Chunked]
None => true
};
self.headers.set(common::TransferEncoding(encodings));
if encodings {
self.headers.set::<common::TransferEncoding>(
common::TransferEncoding(vec![common::transfer_encoding::Chunked]))
}
}
for (name, header) in self.headers.iter() {

View File

@@ -15,7 +15,7 @@ use std::string::raw;
use std::collections::hashmap::{HashMap, Entries, Occupied, Vacant};
use std::sync::RWLock;
use uany::UncheckedAnyDowncast;
use uany::{UncheckedAnyDowncast, UncheckedAnyMutDowncast};
use typeable::Typeable;
use http::read_header;
@@ -62,6 +62,14 @@ impl<'a> UncheckedAnyDowncast<'a> for &'a Header {
}
}
impl<'a> UncheckedAnyMutDowncast<'a> for &'a mut Header {
#[inline]
unsafe fn downcast_mut_unchecked<T: 'static>(self) -> &'a mut T {
let to: TraitObject = transmute_copy(&self);
transmute(to.data)
}
}
fn header_name<T: Header>() -> &'static str {
let name = Header::header_name(None::<T>);
name
@@ -145,6 +153,31 @@ impl Headers {
/// Get a reference to the header field's value, if it exists.
pub fn get<H: Header>(&self) -> Option<&H> {
self.get_or_parse::<H>().map(|item| {
let read = item.read();
debug!("downcasting {}", *read);
let ret = match *read {
Typed(ref val) => unsafe { val.downcast_ref_unchecked() },
_ => unreachable!()
};
unsafe { transmute::<&H, &H>(ret) }
})
}
/// Get a mutable reference to the header field's value, if it exists.
pub fn get_mut<H: Header>(&mut self) -> Option<&mut H> {
self.get_or_parse::<H>().map(|item| {
let mut write = item.write();
debug!("downcasting {}", *write);
let ret = match *&mut *write {
Typed(ref mut val) => unsafe { val.downcast_mut_unchecked() },
_ => unreachable!()
};
unsafe { transmute::<&mut H, &mut H>(ret) }
})
}
fn get_or_parse<H: Header>(&self) -> Option<&RWLock<Item>> {
self.data.find(&CaseInsensitive(Slice(header_name::<H>()))).and_then(|item| {
let done = match *item.read() {
// Huge borrowck hack here, should be refactored to just return here.
@@ -185,14 +218,6 @@ impl Headers {
// Mutate in the raw case.
*write = Typed(box header as Box<Header + Send + Sync>);
Some(item)
}).map(|item| {
let read = item.read();
debug!("downcasting {}", *read);
let ret = match *read {
Typed(ref val) => unsafe { val.downcast_ref_unchecked() },
_ => unreachable!()
};
unsafe { transmute::<&H, &H>(ret) }
})
}
@@ -349,7 +374,7 @@ mod tests {
#[test]
fn test_from_raw() {
let headers = Headers::from_raw(&mut mem("Content-Length: 10\r\n\r\n")).unwrap();
assert_eq!(headers.get_ref(), Some(&ContentLength(10)));
assert_eq!(headers.get(), Some(&ContentLength(10)));
}
#[test]
@@ -388,23 +413,30 @@ mod tests {
#[test]
fn test_different_structs_for_same_header() {
let headers = Headers::from_raw(&mut mem("Content-Length: 10\r\n\r\n")).unwrap();
let ContentLength(_) = headers.get::<ContentLength>().unwrap();
let ContentLength(_) = *headers.get::<ContentLength>().unwrap();
assert!(headers.get::<CrazyLength>().is_none());
}
#[test]
fn test_multiple_reads() {
let headers = Headers::from_raw(&mut mem("Content-Length: 10\r\n\r\n")).unwrap();
let ContentLength(one) = headers.get::<ContentLength>().unwrap();
let ContentLength(two) = headers.get::<ContentLength>().unwrap();
let ContentLength(one) = *headers.get::<ContentLength>().unwrap();
let ContentLength(two) = *headers.get::<ContentLength>().unwrap();
assert_eq!(one, two);
}
#[test]
fn test_different_reads() {
let headers = Headers::from_raw(&mut mem("Content-Length: 10\r\nContent-Type: text/plain\r\n\r\n")).unwrap();
let ContentLength(_) = headers.get::<ContentLength>().unwrap();
let ContentType(_) = headers.get::<ContentType>().unwrap();
let ContentLength(_) = *headers.get::<ContentLength>().unwrap();
let ContentType(_) = *headers.get::<ContentType>().unwrap();
}
#[test]
fn test_get_mutable() {
let mut headers = Headers::from_raw(&mut mem("Content-Length: 10\r\nContent-Type: text/plain\r\n\r\n")).unwrap();
*headers.get_mut::<ContentLength>().unwrap() = ContentLength(20);
assert_eq!(*headers.get::<ContentLength>().unwrap(), ContentLength(20));
}
}

View File

@@ -82,16 +82,19 @@ impl Response<Fresh> {
// cant do in match above, thanks borrowck
if chunked {
//TODO: use CollectionViews (when implemented) to prevent double hash/lookup
let encodings = match self.headers.get::<common::TransferEncoding>().map(|h| h.clone()) {
Some(common::TransferEncoding(mut encodings)) => {
let encodings = match self.headers.get_mut::<common::TransferEncoding>() {
Some(&common::TransferEncoding(ref mut encodings)) => {
//TODO: check if chunked is already in encodings. use HashSet?
encodings.push(common::transfer_encoding::Chunked);
encodings
false
},
None => vec![common::transfer_encoding::Chunked]
None => true
};
self.headers.set(common::TransferEncoding(encodings));
if encodings {
self.headers.set::<common::TransferEncoding>(
common::TransferEncoding(vec![common::transfer_encoding::Chunked]))
}
}
for (name, header) in self.headers.iter() {