From cb1927553ee1db4aaf09f0b77f0d9aae174114e8 Mon Sep 17 00:00:00 2001 From: Sean McArthur Date: Wed, 5 Apr 2017 11:40:57 -0700 Subject: [PATCH] fix(uri): fix Uri to_origin_form to always include '/' Closes #1112 --- src/http/h1/parse.rs | 2 +- src/http/str.rs | 15 ++++++- src/uri.rs | 94 ++++++++++++++++++++++++++++++++++++++------ 3 files changed, 96 insertions(+), 15 deletions(-) diff --git a/src/http/h1/parse.rs b/src/http/h1/parse.rs index b6e85cb1..dc7e6a00 100644 --- a/src/http/h1/parse.rs +++ b/src/http/h1/parse.rs @@ -344,7 +344,7 @@ mod tests { let (req, len) = parse::(&mut raw).unwrap().unwrap(); assert_eq!(len, expected_len); assert_eq!(req.subject.0, ::Method::Get); - assert_eq!(req.subject.1, "/echo".parse().unwrap()); + assert_eq!(req.subject.1, "/echo"); assert_eq!(req.version, ::HttpVersion::Http11); assert_eq!(req.headers.len(), 1); assert_eq!(req.headers.get_raw("Host").map(|raw| &raw[0]), Some(b"hyper.rs".as_ref())); diff --git a/src/http/str.rs b/src/http/str.rs index 4479a01b..6fd22364 100644 --- a/src/http/str.rs +++ b/src/http/str.rs @@ -1,7 +1,7 @@ use std::ops::Deref; use std::str; -use bytes::Bytes; +use bytes::{Bytes, BytesMut}; #[derive(Debug, Clone, PartialEq, Eq, Hash)] pub struct ByteStr(Bytes); @@ -38,6 +38,19 @@ impl Deref for ByteStr { } } + +impl From for Bytes { + fn from(s: ByteStr) -> Bytes { + s.0 + } +} + +impl From for BytesMut { + fn from(s: ByteStr) -> BytesMut { + s.0.into() + } +} + impl<'a> From<&'a str> for ByteStr { fn from(s: &'a str) -> ByteStr { ByteStr(Bytes::from(s)) diff --git a/src/uri.rs b/src/uri.rs index c6821a0a..e2b57c9d 100644 --- a/src/uri.rs +++ b/src/uri.rs @@ -3,6 +3,7 @@ use std::fmt::{Display, self}; use std::str::{self, FromStr}; use http::ByteStr; +use bytes::{BufMut, BytesMut}; /// The Request-URI of a Request's StartLine. /// @@ -101,14 +102,8 @@ impl Uri { /// Get the path of this `Uri`. pub fn path(&self) -> &str { - let index = self.authority_end.unwrap_or(self.scheme_end.unwrap_or(0)); - let end = if let Some(query) = self.query_start { - query - } else if let Some(fragment) = self.fragment_start { - fragment - } else { - self.source.len() - }; + let index = self.path_start(); + let end = self.path_end(); if index >= end { if self.scheme().is_some() { "/" // absolute-form MUST have path @@ -120,6 +115,31 @@ impl Uri { } } + #[inline] + fn path_start(&self) -> usize { + self.authority_end.unwrap_or(self.scheme_end.unwrap_or(0)) + } + + #[inline] + fn path_end(&self) -> usize { + if let Some(query) = self.query_start { + query + } else if let Some(fragment) = self.fragment_start { + fragment + } else { + self.source.len() + } + } + + #[inline] + fn origin_form_end(&self) -> usize { + if let Some(fragment) = self.fragment_start { + fragment + } else { + self.source.len() + } + } + /// Get the scheme of this `Uri`. pub fn scheme(&self) -> Option<&str> { if let Some(end) = self.scheme_end { @@ -226,6 +246,18 @@ impl PartialEq for Uri { } } +impl<'a> PartialEq<&'a str> for Uri { + fn eq(&self, other: & &'a str) -> bool { + self.source.as_str() == *other + } +} + +impl<'a> PartialEq for &'a str{ + fn eq(&self, other: &Uri) -> bool { + *self == other.source.as_str() + } +} + impl Eq for Uri {} impl AsRef for Uri { @@ -277,14 +309,24 @@ pub fn scheme_and_authority(uri: &Uri) -> Option { } pub fn origin_form(uri: &Uri) -> Uri { - let start = uri.authority_end.unwrap_or(uri.scheme_end.unwrap_or(0)); - let end = if let Some(f) = uri.fragment_start { - f + let range = Range(uri.path_start(), uri.origin_form_end()); + + let clone = if range.len() == 0 { + ByteStr::from_static("/") + } else if uri.source.as_bytes()[range.0] != b'/' { + let mut new = BytesMut::with_capacity(range.1 - range.0 + 1); + new.put_u8(b'/'); + new.put_slice(&uri.source.as_bytes()[range.0..range.1]); + // safety: the bytes are '/' + previous utf8 str + unsafe { ByteStr::from_utf8_unchecked(new.freeze()) } + } else if range.0 == 0 && range.1 == uri.source.len() { + uri.source.clone() } else { - uri.source.len() + uri.source.slice(range.0, range.1) }; + Uri { - source: uri.source.slice(start, end), + source: clone, scheme_end: None, authority_end: None, query_start: uri.query_start, @@ -292,6 +334,14 @@ pub fn origin_form(uri: &Uri) -> Uri { } } +struct Range(usize, usize); + +impl Range { + fn len(&self) -> usize { + self.1 - self.0 + } +} + /// An error parsing a `Uri`. #[derive(Clone, Debug)] pub struct UriError(ErrorKind); @@ -480,3 +530,21 @@ fn test_uri_parse_error() { err("localhost/"); err("localhost?key=val"); } + +#[test] +fn test_uri_to_origin_form() { + let cases = vec![ + ("/", "/"), + ("/foo?bar", "/foo?bar"), + ("/foo?bar#nope", "/foo?bar"), + ("http://hyper.rs", "/"), + ("http://hyper.rs/", "/"), + ("http://hyper.rs/path", "/path"), + ("http://hyper.rs?query", "/?query"), + ]; + + for case in cases { + let uri = Uri::from_str(case.0).unwrap(); + assert_eq!(origin_form(&uri), case.1); + } +}