fix(uri): fix Uri to_origin_form to always include '/'

Closes #1112
This commit is contained in:
Sean McArthur
2017-04-05 11:40:57 -07:00
parent 1b49237e27
commit cb1927553e
3 changed files with 96 additions and 15 deletions

View File

@@ -344,7 +344,7 @@ mod tests {
let (req, len) = parse::<http::ServerTransaction, _>(&mut raw).unwrap().unwrap(); let (req, len) = parse::<http::ServerTransaction, _>(&mut raw).unwrap().unwrap();
assert_eq!(len, expected_len); assert_eq!(len, expected_len);
assert_eq!(req.subject.0, ::Method::Get); assert_eq!(req.subject.0, ::Method::Get);
assert_eq!(req.subject.1, "/echo".parse().unwrap()); assert_eq!(req.subject.1, "/echo");
assert_eq!(req.version, ::HttpVersion::Http11); assert_eq!(req.version, ::HttpVersion::Http11);
assert_eq!(req.headers.len(), 1); assert_eq!(req.headers.len(), 1);
assert_eq!(req.headers.get_raw("Host").map(|raw| &raw[0]), Some(b"hyper.rs".as_ref())); assert_eq!(req.headers.get_raw("Host").map(|raw| &raw[0]), Some(b"hyper.rs".as_ref()));

View File

@@ -1,7 +1,7 @@
use std::ops::Deref; use std::ops::Deref;
use std::str; use std::str;
use bytes::Bytes; use bytes::{Bytes, BytesMut};
#[derive(Debug, Clone, PartialEq, Eq, Hash)] #[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct ByteStr(Bytes); pub struct ByteStr(Bytes);
@@ -38,6 +38,19 @@ impl Deref for ByteStr {
} }
} }
impl From<ByteStr> for Bytes {
fn from(s: ByteStr) -> Bytes {
s.0
}
}
impl From<ByteStr> for BytesMut {
fn from(s: ByteStr) -> BytesMut {
s.0.into()
}
}
impl<'a> From<&'a str> for ByteStr { impl<'a> From<&'a str> for ByteStr {
fn from(s: &'a str) -> ByteStr { fn from(s: &'a str) -> ByteStr {
ByteStr(Bytes::from(s)) ByteStr(Bytes::from(s))

View File

@@ -3,6 +3,7 @@ use std::fmt::{Display, self};
use std::str::{self, FromStr}; use std::str::{self, FromStr};
use http::ByteStr; use http::ByteStr;
use bytes::{BufMut, BytesMut};
/// The Request-URI of a Request's StartLine. /// The Request-URI of a Request's StartLine.
/// ///
@@ -101,14 +102,8 @@ impl Uri {
/// Get the path of this `Uri`. /// Get the path of this `Uri`.
pub fn path(&self) -> &str { pub fn path(&self) -> &str {
let index = self.authority_end.unwrap_or(self.scheme_end.unwrap_or(0)); let index = self.path_start();
let end = if let Some(query) = self.query_start { let end = self.path_end();
query
} else if let Some(fragment) = self.fragment_start {
fragment
} else {
self.source.len()
};
if index >= end { if index >= end {
if self.scheme().is_some() { if self.scheme().is_some() {
"/" // absolute-form MUST have path "/" // absolute-form MUST have path
@@ -120,6 +115,31 @@ impl Uri {
} }
} }
#[inline]
fn path_start(&self) -> usize {
self.authority_end.unwrap_or(self.scheme_end.unwrap_or(0))
}
#[inline]
fn path_end(&self) -> usize {
if let Some(query) = self.query_start {
query
} else if let Some(fragment) = self.fragment_start {
fragment
} else {
self.source.len()
}
}
#[inline]
fn origin_form_end(&self) -> usize {
if let Some(fragment) = self.fragment_start {
fragment
} else {
self.source.len()
}
}
/// Get the scheme of this `Uri`. /// Get the scheme of this `Uri`.
pub fn scheme(&self) -> Option<&str> { pub fn scheme(&self) -> Option<&str> {
if let Some(end) = self.scheme_end { if let Some(end) = self.scheme_end {
@@ -226,6 +246,18 @@ impl PartialEq for Uri {
} }
} }
impl<'a> PartialEq<&'a str> for Uri {
fn eq(&self, other: & &'a str) -> bool {
self.source.as_str() == *other
}
}
impl<'a> PartialEq<Uri> for &'a str{
fn eq(&self, other: &Uri) -> bool {
*self == other.source.as_str()
}
}
impl Eq for Uri {} impl Eq for Uri {}
impl AsRef<str> for Uri { impl AsRef<str> for Uri {
@@ -277,14 +309,24 @@ pub fn scheme_and_authority(uri: &Uri) -> Option<Uri> {
} }
pub fn origin_form(uri: &Uri) -> Uri { pub fn origin_form(uri: &Uri) -> Uri {
let start = uri.authority_end.unwrap_or(uri.scheme_end.unwrap_or(0)); let range = Range(uri.path_start(), uri.origin_form_end());
let end = if let Some(f) = uri.fragment_start {
f let clone = if range.len() == 0 {
ByteStr::from_static("/")
} else if uri.source.as_bytes()[range.0] != b'/' {
let mut new = BytesMut::with_capacity(range.1 - range.0 + 1);
new.put_u8(b'/');
new.put_slice(&uri.source.as_bytes()[range.0..range.1]);
// safety: the bytes are '/' + previous utf8 str
unsafe { ByteStr::from_utf8_unchecked(new.freeze()) }
} else if range.0 == 0 && range.1 == uri.source.len() {
uri.source.clone()
} else { } else {
uri.source.len() uri.source.slice(range.0, range.1)
}; };
Uri { Uri {
source: uri.source.slice(start, end), source: clone,
scheme_end: None, scheme_end: None,
authority_end: None, authority_end: None,
query_start: uri.query_start, query_start: uri.query_start,
@@ -292,6 +334,14 @@ pub fn origin_form(uri: &Uri) -> Uri {
} }
} }
struct Range(usize, usize);
impl Range {
fn len(&self) -> usize {
self.1 - self.0
}
}
/// An error parsing a `Uri`. /// An error parsing a `Uri`.
#[derive(Clone, Debug)] #[derive(Clone, Debug)]
pub struct UriError(ErrorKind); pub struct UriError(ErrorKind);
@@ -480,3 +530,21 @@ fn test_uri_parse_error() {
err("localhost/"); err("localhost/");
err("localhost?key=val"); err("localhost?key=val");
} }
#[test]
fn test_uri_to_origin_form() {
let cases = vec![
("/", "/"),
("/foo?bar", "/foo?bar"),
("/foo?bar#nope", "/foo?bar"),
("http://hyper.rs", "/"),
("http://hyper.rs/", "/"),
("http://hyper.rs/path", "/path"),
("http://hyper.rs?query", "/?query"),
];
for case in cases {
let uri = Uri::from_str(case.0).unwrap();
assert_eq!(origin_form(&uri), case.1);
}
}