Add HTTP Upgrade support to Response. (#1376)
This commit is contained in:
@@ -13,3 +13,4 @@ pub mod decoder;
|
||||
pub mod multipart;
|
||||
pub(crate) mod request;
|
||||
mod response;
|
||||
mod upgrade;
|
||||
|
||||
@@ -24,14 +24,10 @@ use crate::response::ResponseUrl;
|
||||
|
||||
/// A Response to a submitted `Request`.
|
||||
pub struct Response {
|
||||
status: StatusCode,
|
||||
headers: HeaderMap,
|
||||
pub(super) res: hyper::Response<Decoder>,
|
||||
// Boxed to save space (11 words to 1 word), and it's not accessed
|
||||
// frequently internally.
|
||||
url: Box<Url>,
|
||||
body: Decoder,
|
||||
version: Version,
|
||||
extensions: http::Extensions,
|
||||
}
|
||||
|
||||
impl Response {
|
||||
@@ -41,46 +37,38 @@ impl Response {
|
||||
accepts: Accepts,
|
||||
timeout: Option<Pin<Box<Sleep>>>,
|
||||
) -> Response {
|
||||
let (parts, body) = res.into_parts();
|
||||
let status = parts.status;
|
||||
let version = parts.version;
|
||||
let extensions = parts.extensions;
|
||||
|
||||
let mut headers = parts.headers;
|
||||
let decoder = Decoder::detect(&mut headers, Body::response(body, timeout), accepts);
|
||||
let (mut parts, body) = res.into_parts();
|
||||
let decoder = Decoder::detect(&mut parts.headers, Body::response(body, timeout), accepts);
|
||||
let res = hyper::Response::from_parts(parts, decoder);
|
||||
|
||||
Response {
|
||||
status,
|
||||
headers,
|
||||
res,
|
||||
url: Box::new(url),
|
||||
body: decoder,
|
||||
version,
|
||||
extensions,
|
||||
}
|
||||
}
|
||||
|
||||
/// Get the `StatusCode` of this `Response`.
|
||||
#[inline]
|
||||
pub fn status(&self) -> StatusCode {
|
||||
self.status
|
||||
self.res.status()
|
||||
}
|
||||
|
||||
/// Get the HTTP `Version` of this `Response`.
|
||||
#[inline]
|
||||
pub fn version(&self) -> Version {
|
||||
self.version
|
||||
self.res.version()
|
||||
}
|
||||
|
||||
/// Get the `Headers` of this `Response`.
|
||||
#[inline]
|
||||
pub fn headers(&self) -> &HeaderMap {
|
||||
&self.headers
|
||||
self.res.headers()
|
||||
}
|
||||
|
||||
/// Get a mutable reference to the `Headers` of this `Response`.
|
||||
#[inline]
|
||||
pub fn headers_mut(&mut self) -> &mut HeaderMap {
|
||||
&mut self.headers
|
||||
self.res.headers_mut()
|
||||
}
|
||||
|
||||
/// Get the content-length of this response, if known.
|
||||
@@ -93,7 +81,7 @@ impl Response {
|
||||
pub fn content_length(&self) -> Option<u64> {
|
||||
use hyper::body::HttpBody;
|
||||
|
||||
HttpBody::size_hint(&self.body).exact()
|
||||
HttpBody::size_hint(self.res.body()).exact()
|
||||
}
|
||||
|
||||
/// Retrieve the cookies contained in the response.
|
||||
@@ -106,7 +94,7 @@ impl Response {
|
||||
#[cfg(feature = "cookies")]
|
||||
#[cfg_attr(docsrs, doc(cfg(feature = "cookies")))]
|
||||
pub fn cookies<'a>(&'a self) -> impl Iterator<Item = cookie::Cookie<'a>> + 'a {
|
||||
cookie::extract_response_cookies(&self.headers).filter_map(Result::ok)
|
||||
cookie::extract_response_cookies(self.res.headers()).filter_map(Result::ok)
|
||||
}
|
||||
|
||||
/// Get the final `Url` of this `Response`.
|
||||
@@ -117,19 +105,20 @@ impl Response {
|
||||
|
||||
/// Get the remote address used to get this `Response`.
|
||||
pub fn remote_addr(&self) -> Option<SocketAddr> {
|
||||
self.extensions
|
||||
self.res
|
||||
.extensions()
|
||||
.get::<HttpInfo>()
|
||||
.map(|info| info.remote_addr())
|
||||
}
|
||||
|
||||
/// Returns a reference to the associated extensions.
|
||||
pub fn extensions(&self) -> &http::Extensions {
|
||||
&self.extensions
|
||||
self.res.extensions()
|
||||
}
|
||||
|
||||
/// Returns a mutable reference to the associated extensions.
|
||||
pub fn extensions_mut(&mut self) -> &mut http::Extensions {
|
||||
&mut self.extensions
|
||||
self.res.extensions_mut()
|
||||
}
|
||||
|
||||
// body methods
|
||||
@@ -183,7 +172,7 @@ impl Response {
|
||||
/// ```
|
||||
pub async fn text_with_charset(self, default_encoding: &str) -> crate::Result<String> {
|
||||
let content_type = self
|
||||
.headers
|
||||
.headers()
|
||||
.get(crate::header::CONTENT_TYPE)
|
||||
.and_then(|value| value.to_str().ok())
|
||||
.and_then(|value| value.parse::<Mime>().ok());
|
||||
@@ -271,7 +260,7 @@ impl Response {
|
||||
/// # }
|
||||
/// ```
|
||||
pub async fn bytes(self) -> crate::Result<Bytes> {
|
||||
hyper::body::to_bytes(self.body).await
|
||||
hyper::body::to_bytes(self.res.into_body()).await
|
||||
}
|
||||
|
||||
/// Stream a chunk of the response body.
|
||||
@@ -291,7 +280,7 @@ impl Response {
|
||||
/// # }
|
||||
/// ```
|
||||
pub async fn chunk(&mut self) -> crate::Result<Option<Bytes>> {
|
||||
if let Some(item) = self.body.next().await {
|
||||
if let Some(item) = self.res.body_mut().next().await {
|
||||
Ok(Some(item?))
|
||||
} else {
|
||||
Ok(None)
|
||||
@@ -323,7 +312,7 @@ impl Response {
|
||||
#[cfg(feature = "stream")]
|
||||
#[cfg_attr(docsrs, doc(cfg(feature = "stream")))]
|
||||
pub fn bytes_stream(self) -> impl futures_core::Stream<Item = crate::Result<Bytes>> {
|
||||
self.body
|
||||
self.res.into_body()
|
||||
}
|
||||
|
||||
// util methods
|
||||
@@ -350,8 +339,9 @@ impl Response {
|
||||
/// # fn main() {}
|
||||
/// ```
|
||||
pub fn error_for_status(self) -> crate::Result<Self> {
|
||||
if self.status.is_client_error() || self.status.is_server_error() {
|
||||
Err(crate::error::status_code(*self.url, self.status))
|
||||
let status = self.status();
|
||||
if status.is_client_error() || status.is_server_error() {
|
||||
Err(crate::error::status_code(*self.url, status))
|
||||
} else {
|
||||
Ok(self)
|
||||
}
|
||||
@@ -379,8 +369,9 @@ impl Response {
|
||||
/// # fn main() {}
|
||||
/// ```
|
||||
pub fn error_for_status_ref(&self) -> crate::Result<&Self> {
|
||||
if self.status.is_client_error() || self.status.is_server_error() {
|
||||
Err(crate::error::status_code(*self.url.clone(), self.status))
|
||||
let status = self.status();
|
||||
if status.is_client_error() || status.is_server_error() {
|
||||
Err(crate::error::status_code(*self.url.clone(), status))
|
||||
} else {
|
||||
Ok(self)
|
||||
}
|
||||
@@ -395,7 +386,7 @@ impl Response {
|
||||
// This method is just used by the blocking API.
|
||||
#[cfg(feature = "blocking")]
|
||||
pub(crate) fn body_mut(&mut self) -> &mut Decoder {
|
||||
&mut self.body
|
||||
self.res.body_mut()
|
||||
}
|
||||
}
|
||||
|
||||
@@ -413,19 +404,16 @@ impl<T: Into<Body>> From<http::Response<T>> for Response {
|
||||
fn from(r: http::Response<T>) -> Response {
|
||||
let (mut parts, body) = r.into_parts();
|
||||
let body = body.into();
|
||||
let body = Decoder::detect(&mut parts.headers, body, Accepts::none());
|
||||
let decoder = Decoder::detect(&mut parts.headers, body, Accepts::none());
|
||||
let url = parts
|
||||
.extensions
|
||||
.remove::<ResponseUrl>()
|
||||
.unwrap_or_else(|| ResponseUrl(Url::parse("http://no.url.provided.local").unwrap()));
|
||||
let url = url.0;
|
||||
let res = hyper::Response::from_parts(parts, decoder);
|
||||
Response {
|
||||
status: parts.status,
|
||||
headers: parts.headers,
|
||||
res,
|
||||
url: Box::new(url),
|
||||
body,
|
||||
version: parts.version,
|
||||
extensions: parts.extensions,
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -433,7 +421,7 @@ impl<T: Into<Body>> From<http::Response<T>> for Response {
|
||||
/// A `Response` can be piped as the `Body` of another request.
|
||||
impl From<Response> for Body {
|
||||
fn from(r: Response) -> Body {
|
||||
Body::stream(r.body)
|
||||
Body::stream(r.res.into_body())
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
73
src/async_impl/upgrade.rs
Normal file
73
src/async_impl/upgrade.rs
Normal file
@@ -0,0 +1,73 @@
|
||||
use std::pin::Pin;
|
||||
use std::task::{self, Poll};
|
||||
use std::{fmt, io};
|
||||
|
||||
use futures_util::TryFutureExt;
|
||||
use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
|
||||
|
||||
/// An upgraded HTTP connection.
|
||||
pub struct Upgraded {
|
||||
inner: hyper::upgrade::Upgraded,
|
||||
}
|
||||
|
||||
impl AsyncRead for Upgraded {
|
||||
fn poll_read(
|
||||
mut self: Pin<&mut Self>,
|
||||
cx: &mut task::Context<'_>,
|
||||
buf: &mut ReadBuf<'_>,
|
||||
) -> Poll<io::Result<()>> {
|
||||
Pin::new(&mut self.inner).poll_read(cx, buf)
|
||||
}
|
||||
}
|
||||
|
||||
impl AsyncWrite for Upgraded {
|
||||
fn poll_write(
|
||||
mut self: Pin<&mut Self>,
|
||||
cx: &mut task::Context<'_>,
|
||||
buf: &[u8],
|
||||
) -> Poll<io::Result<usize>> {
|
||||
Pin::new(&mut self.inner).poll_write(cx, buf)
|
||||
}
|
||||
|
||||
fn poll_write_vectored(
|
||||
mut self: Pin<&mut Self>,
|
||||
cx: &mut task::Context<'_>,
|
||||
bufs: &[io::IoSlice<'_>],
|
||||
) -> Poll<io::Result<usize>> {
|
||||
Pin::new(&mut self.inner).poll_write_vectored(cx, bufs)
|
||||
}
|
||||
|
||||
fn poll_flush(mut self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll<io::Result<()>> {
|
||||
Pin::new(&mut self.inner).poll_flush(cx)
|
||||
}
|
||||
|
||||
fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll<io::Result<()>> {
|
||||
Pin::new(&mut self.inner).poll_shutdown(cx)
|
||||
}
|
||||
|
||||
fn is_write_vectored(&self) -> bool {
|
||||
self.inner.is_write_vectored()
|
||||
}
|
||||
}
|
||||
|
||||
impl fmt::Debug for Upgraded {
|
||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
f.debug_struct("Upgraded").finish()
|
||||
}
|
||||
}
|
||||
|
||||
impl From<hyper::upgrade::Upgraded> for Upgraded {
|
||||
fn from(inner: hyper::upgrade::Upgraded) -> Self {
|
||||
Upgraded { inner }
|
||||
}
|
||||
}
|
||||
|
||||
impl super::response::Response {
|
||||
/// Consumes the response and returns a future for a possible HTTP upgrade.
|
||||
pub async fn upgrade(self) -> crate::Result<Upgraded> {
|
||||
hyper::upgrade::on(self.res)
|
||||
.map_ok(Upgraded::from)
|
||||
.map_err(crate::error::upgrade)
|
||||
.await
|
||||
}
|
||||
}
|
||||
@@ -185,6 +185,7 @@ impl fmt::Display for Error {
|
||||
Kind::Body => f.write_str("request or response body error")?,
|
||||
Kind::Decode => f.write_str("error decoding response body")?,
|
||||
Kind::Redirect => f.write_str("error following redirect")?,
|
||||
Kind::Upgrade => f.write_str("error upgrading connection")?,
|
||||
Kind::Status(ref code) => {
|
||||
let prefix = if code.is_client_error() {
|
||||
"HTTP status client error"
|
||||
@@ -236,6 +237,7 @@ pub(crate) enum Kind {
|
||||
Status(StatusCode),
|
||||
Body,
|
||||
Decode,
|
||||
Upgrade,
|
||||
}
|
||||
|
||||
// constructors
|
||||
@@ -274,6 +276,10 @@ if_wasm! {
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn upgrade<E: Into<BoxError>>(e: E) -> Error {
|
||||
Error::new(Kind::Upgrade, Some(e))
|
||||
}
|
||||
|
||||
// io::Error helpers
|
||||
|
||||
#[allow(unused)]
|
||||
|
||||
51
tests/upgrade.rs
Normal file
51
tests/upgrade.rs
Normal file
@@ -0,0 +1,51 @@
|
||||
#![cfg(not(target_arch = "wasm32"))]
|
||||
mod support;
|
||||
use support::*;
|
||||
use tokio::io::{AsyncReadExt, AsyncWriteExt};
|
||||
|
||||
#[tokio::test]
|
||||
async fn http_upgrade() {
|
||||
let server = server::http(move |req| {
|
||||
assert_eq!(req.method(), "GET");
|
||||
assert_eq!(req.headers()["connection"], "upgrade");
|
||||
assert_eq!(req.headers()["upgrade"], "foobar");
|
||||
|
||||
tokio::spawn(async move {
|
||||
let mut upgraded = hyper::upgrade::on(req).await.unwrap();
|
||||
|
||||
let mut buf = vec![0; 7];
|
||||
upgraded.read_exact(&mut buf).await.unwrap();
|
||||
assert_eq!(buf, b"foo=bar");
|
||||
|
||||
upgraded.write_all(b"bar=foo").await.unwrap();
|
||||
});
|
||||
|
||||
async {
|
||||
http::Response::builder()
|
||||
.status(http::StatusCode::SWITCHING_PROTOCOLS)
|
||||
.header(http::header::CONNECTION, "upgrade")
|
||||
.header(http::header::UPGRADE, "foobar")
|
||||
.body(hyper::Body::empty())
|
||||
.unwrap()
|
||||
}
|
||||
});
|
||||
|
||||
let res = reqwest::Client::builder()
|
||||
.build()
|
||||
.unwrap()
|
||||
.get(format!("http://{}", server.addr()))
|
||||
.header(http::header::CONNECTION, "upgrade")
|
||||
.header(http::header::UPGRADE, "foobar")
|
||||
.send()
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
assert_eq!(res.status(), http::StatusCode::SWITCHING_PROTOCOLS);
|
||||
let mut upgraded = res.upgrade().await.unwrap();
|
||||
|
||||
upgraded.write_all(b"foo=bar").await.unwrap();
|
||||
|
||||
let mut buf = vec![];
|
||||
upgraded.read_to_end(&mut buf).await.unwrap();
|
||||
assert_eq!(buf, b"bar=foo");
|
||||
}
|
||||
Reference in New Issue
Block a user