diff --git a/src/header/mod.rs b/src/header/mod.rs index 5c56b123..764f9ec3 100644 --- a/src/header/mod.rs +++ b/src/header/mod.rs @@ -76,7 +76,7 @@ fn header_name() -> &'static str { /// A map of header fields on requests and responses. pub struct Headers { - data: HashMap> + data: HashMap, RWLock> } impl Headers { @@ -96,12 +96,12 @@ impl Headers { Some((name, value)) => { let name = CaseInsensitive(Owned(name)); let item = match headers.data.entry(name) { - Vacant(entry) => entry.set(RWLock::new(Raw(vec![]))), + Vacant(entry) => entry.set(RWLock::new(Item::raw(vec![]))), Occupied(entry) => entry.into_mut() }; - match &mut *item.write() { - &Raw(ref mut raw) => raw.push(value), + match &mut item.write().raw { + &Some(ref mut raw) => raw.push(value), // Unreachable _ => {} }; @@ -116,41 +116,55 @@ impl Headers { /// /// The field is determined by the type of the value being set. pub fn set(&mut self, value: H) { - self.data.insert(CaseInsensitive(Slice(header_name::())), RWLock::new(Typed(box value as Box
))); + self.data.insert(CaseInsensitive(Slice(header_name::())), + RWLock::new(Item::typed(box value as Box
))); } - /// Access the raw value of a header, if it exists and has not - /// been already parsed. + /// Access the raw value of a header. /// - /// If the header field has already been parsed into a typed header, - /// then you *must* access it through that representation. - /// - /// This operation is unsafe because the raw representation can be - /// invalidated by lasting too long or by the header being parsed - /// while you still have a reference to the data. + /// Prefer to use the typed getters instead. /// /// Example: + /// /// ``` /// # use hyper::header::Headers; /// # let mut headers = Headers::new(); - /// let raw_content_type = unsafe { headers.get_raw("content-type") }; + /// let raw_content_type = headers.get_raw("content-type"); /// ``` - pub unsafe fn get_raw(&self, name: &'static str) -> Option<*const [Vec]> { - self.data.find(&CaseInsensitive(Slice(name))).and_then(|item| { - match *item.read() { - Raw(ref raw) => Some(raw.as_slice() as *const [Vec]), - _ => None + pub fn get_raw(&self, name: &str) -> Option<&[Vec]> { + self.data.find_equiv(&CaseInsensitive(name)).and_then(|item| { + let lock = item.read(); + if let Some(ref raw) = lock.raw { + return unsafe { transmute(Some(raw[])) }; } + + let mut lock = item.write(); + let raw = vec![lock.typed.as_ref().unwrap().to_string().into_bytes()]; + lock.raw = Some(raw); + unsafe { transmute(Some(lock.raw.as_ref().unwrap()[])) } }) } + /// Set the raw value of a header, bypassing any typed headers. + /// + /// Example: + /// + /// ``` + /// # use hyper::header::Headers; + /// # let mut headers = Headers::new(); + /// headers.set_raw("content-length", vec!["5".as_bytes().to_vec()]); + /// ``` + pub fn set_raw>(&mut self, name: K, value: Vec>) { + self.data.insert(CaseInsensitive(name.into_maybe_owned()), RWLock::new(Item::raw(value))); + } + /// Get a reference to the header field's value, if it exists. pub fn get(&self) -> Option<&H> { self.get_or_parse::().map(|item| { let read = item.read(); debug!("downcasting {}", *read); - let ret = match *read { - Typed(ref val) => unsafe { val.downcast_ref_unchecked() }, + let ret = match read.typed { + Some(ref val) => unsafe { val.downcast_ref_unchecked() }, _ => unreachable!() }; unsafe { transmute::<&H, &H>(ret) } @@ -162,8 +176,8 @@ impl Headers { self.get_or_parse::().map(|item| { let mut write = item.write(); debug!("downcasting {}", *write); - let ret = match *&mut *write { - Typed(ref mut val) => unsafe { val.downcast_mut_unchecked() }, + let ret = match *&mut write.typed { + Some(ref mut val) => unsafe { val.downcast_mut_unchecked() }, _ => unreachable!() }; unsafe { transmute::<&mut H, &mut H>(ret) } @@ -172,44 +186,47 @@ impl Headers { fn get_or_parse(&self) -> Option<&RWLock> { self.data.find(&CaseInsensitive(Slice(header_name::()))).and_then(|item| { - let done = match *item.read() { - // Huge borrowck hack here, should be refactored to just return here. - Typed(ref typed) if typed.is::() => true, - - // Typed, wrong type. - Typed(_) => return None, - - // Raw, work to do. - Raw(_) => false, - }; - - // borrowck hack continued - if done { return Some(item); } + match item.read().typed { + Some(ref typed) if typed.is::() => return Some(item), + Some(ref typed) => { + warn!("attempted to access {} as wrong type", typed); + return None; + } + _ => () + } // Take out a write lock to do the parsing and mutation. let mut write = item.write(); - let header = match *write { - // Since this lock can queue, it's possible another thread just - // did the work for us. - // + // Since this lock can queue, it's possible another thread just + // did the work for us. + match write.typed { // Check they inserted the correct type and move on. - Typed(ref typed) if typed.is::() => return Some(item), + Some(ref typed) if typed.is::() => return Some(item), // Wrong type, another thread got here before us and parsed // as a different representation. - Typed(_) => return None, + Some(ref typed) => { + debug!("other thread was here first?") + warn!("attempted to access {} as wrong type", typed); + return None; + }, // We are first in the queue or the only ones, so do the actual // work of parsing and mutation. - Raw(ref raw) => match Header::parse_header(raw.as_slice()) { + _ => () + } + + let header = match write.raw { + Some(ref raw) => match Header::parse_header(raw[]) { Some::(h) => h, None => return None - } + }, + None => unreachable!() }; - // Mutate in the raw case. - *write = Typed(box header as Box
); + // Mutate! + write.typed = Some(box header as Box
); Some(item) }) } @@ -253,7 +270,7 @@ impl fmt::Show for Headers { /// An `Iterator` over the fields in a `Headers` map. pub struct HeadersItems<'a> { - inner: Entries<'a, CaseInsensitive, RWLock> + inner: Entries<'a, CaseInsensitive, RWLock> } impl<'a> Iterator<(&'a str, HeaderView<'a>)> for HeadersItems<'a> { @@ -287,28 +304,53 @@ impl Mutable for Headers { } } -enum Item { - Raw(Vec>), - Typed(Box
) +struct Item { + raw: Option>>, + typed: Option> } -impl fmt::Show for Item { - fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result { - match *self { - Typed(ref h) => h.fmt_header(fmt), - Raw(ref raw) => { - for part in raw.iter() { - try!(fmt.write(part.as_slice())); - } - Ok(()) - }, +impl Item { + fn raw(data: Vec>) -> Item { + Item { + raw: Some(data), + typed: None, + } + } + + fn typed(ty: Box
) -> Item { + Item { + raw: None, + typed: Some(ty), } } } -struct CaseInsensitive(SendStr); +impl fmt::Show for Item { + fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result { + match self.typed { + Some(ref h) => h.fmt_header(fmt), + None => match self.raw { + Some(ref raw) => { + for part in raw.iter() { + try!(fmt.write(part.as_slice())); + } + Ok(()) + }, + None => unreachable!() + } + } + } +} -impl Str for CaseInsensitive { +impl fmt::Show for Box
{ + fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result { + (**self).fmt_header(fmt) + } +} + +struct CaseInsensitive(S); + +impl Str for CaseInsensitive { fn as_slice(&self) -> &str { let CaseInsensitive(ref s) = *self; s.as_slice() @@ -316,21 +358,29 @@ impl Str for CaseInsensitive { } -impl fmt::Show for CaseInsensitive { +impl fmt::Show for CaseInsensitive { fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result { self.as_slice().fmt(fmt) } } -impl PartialEq for CaseInsensitive { - fn eq(&self, other: &CaseInsensitive) -> bool { +impl PartialEq for CaseInsensitive { + fn eq(&self, other: &CaseInsensitive) -> bool { self.as_slice().eq_ignore_ascii_case(other.as_slice()) } } -impl Eq for CaseInsensitive {} +impl Eq for CaseInsensitive {} -impl hash::Hash for CaseInsensitive { +impl Equiv> for CaseInsensitive { + fn equiv(&self, other: &CaseInsensitive) -> bool { + let left = CaseInsensitive(self.as_slice()); + let right = CaseInsensitive(other.as_slice()); + left == right + } +} + +impl hash::Hash for CaseInsensitive { #[inline] fn hash(&self, hasher: &mut H) { for byte in self.as_slice().bytes() { @@ -387,7 +437,7 @@ mod tests { assert_eq!(accept, Some(Accept(vec![application_vendor, text_plain]))); } - #[deriving(Clone)] + #[deriving(Clone, Show)] struct CrazyLength(Option, uint); impl Header for CrazyLength { @@ -408,9 +458,8 @@ mod tests { }.map(|u| CrazyLength(Some(false), u)) } fn fmt_header(&self, fmt: &mut fmt::Formatter) -> fmt::Result { - use std::fmt::Show; - let CrazyLength(_, ref value) = *self; - value.fmt(fmt) + let CrazyLength(ref opt, ref value) = *self; + write!(fmt, "{}, {}", opt, value) } } @@ -455,7 +504,15 @@ mod tests { pieces.sort(); let s = pieces.into_iter().rev().collect::>().connect("\r\n"); assert_eq!(s[], "Host: foo.bar\r\nContent-Length: 15\r\n"); + } + #[test] + fn test_set_raw() { + let mut headers = Headers::new(); + headers.set(ContentLength(10)); + headers.set_raw("content-LENGTH", vec![b"20".to_vec()]); + assert_eq!(headers.get_raw("Content-length").unwrap(), [b"20".to_vec()][]); + assert_eq!(headers.get(), Some(&ContentLength(20))); } }