Validate & convert messages before buffering

Malformed requests and responses should immediately result in a
RST_STREAM. To support this, received header frames are validated and
converted to Request / Response values immediately on receipt and before
buffering.
This commit is contained in:
Carl Lerche
2017-08-30 18:00:32 -04:00
parent 9bb34d907a
commit 2452cc4423
10 changed files with 246 additions and 157 deletions

View File

@@ -1,6 +1,7 @@
use {frame, HeaderMap, ConnectionError}; use {frame, HeaderMap, ConnectionError};
use frame::StreamId; use frame::StreamId;
use proto::{self, Connection, WindowSize}; use proto::{self, Connection, WindowSize, ProtoError};
use error::Reason::*;
use http::{Request, Response}; use http::{Request, Response};
use futures::{Future, Poll, Sink, Async, AsyncSink}; use futures::{Future, Poll, Sink, Async, AsyncSink};
@@ -254,7 +255,30 @@ impl proto::Peer for Peer {
frame frame
} }
fn convert_poll_message(headers: frame::Headers) -> Result<Self::Poll, ConnectionError> { fn convert_poll_message(headers: frame::Headers) -> Result<Self::Poll, ProtoError> {
headers.into_response() let mut b = Response::builder();
let stream_id = headers.stream_id();
let (pseudo, fields) = headers.into_parts();
if let Some(status) = pseudo.status {
b.status(status);
}
let mut response = match b.body(()) {
Ok(response) => response,
Err(_) => {
// TODO: Should there be more specialized handling for different
// kinds of errors
return Err(ProtoError::Stream {
id: stream_id,
reason: ProtocolError,
});
}
};
*response.headers_mut() = fields;
Ok(response)
} }
} }

View File

@@ -1,11 +1,9 @@
use super::{StreamId, StreamDependency}; use super::{StreamId, StreamDependency};
use ConnectionError;
use hpack; use hpack;
use frame::{self, Frame, Head, Kind, Error}; use frame::{self, Frame, Head, Kind, Error};
use HeaderMap; use HeaderMap;
use error::Reason::*;
use http::{version, uri, Request, Response, Method, StatusCode, Uri}; use http::{uri, Method, StatusCode, Uri};
use http::header::{self, HeaderName, HeaderValue}; use http::header::{self, HeaderName, HeaderValue};
use bytes::{BytesMut, Bytes}; use bytes::{BytesMut, Bytes};
@@ -70,13 +68,13 @@ pub struct Continuation {
#[derive(Debug, Default)] #[derive(Debug, Default)]
pub struct Pseudo { pub struct Pseudo {
// Request // Request
method: Option<Method>, pub method: Option<Method>,
scheme: Option<String<Bytes>>, pub scheme: Option<String<Bytes>>,
authority: Option<String<Bytes>>, pub authority: Option<String<Bytes>>,
path: Option<String<Bytes>>, pub path: Option<String<Bytes>>,
// Response // Response
status: Option<StatusCode>, pub status: Option<StatusCode>,
} }
#[derive(Debug)] #[derive(Debug)]
@@ -265,57 +263,8 @@ impl Headers {
self.flags.set_end_stream() self.flags.set_end_stream()
} }
pub fn into_response(self) -> Result<Response<()>, ConnectionError> { pub fn into_parts(self) -> (Pseudo, HeaderMap) {
let mut b = Response::builder(); (self.pseudo, self.fields)
if let Some(status) = self.pseudo.status {
b.status(status);
}
let mut response = try!(b.body(()));
*response.headers_mut() = self.fields;
Ok(response)
}
pub fn into_request(self) -> Result<Request<()>, ConnectionError> {
let mut b = Request::builder();
b.version(version::HTTP_2);
if let Some(method) = self.pseudo.method {
b.method(method);
}
// Specifying :status for a request is a protocol error
if self.pseudo.status.is_some() {
return Err(ProtocolError.into());
}
// Convert the URI
let mut parts = uri::Parts::default();
if let Some(scheme) = self.pseudo.scheme {
// TODO: Don't unwrap
parts.scheme = Some(uri::Scheme::from_shared(scheme.into_inner()).unwrap());
}
if let Some(authority) = self.pseudo.authority {
// TODO: Don't unwrap
parts.authority = Some(uri::Authority::from_shared(authority.into_inner()).unwrap());
}
if let Some(path) = self.pseudo.path {
// TODO: Don't unwrap
parts.path_and_query = Some(uri::PathAndQuery::from_shared(path.into_inner()).unwrap());
}
b.uri(parts);
let mut request = try!(b.body(()));
*request.headers_mut() = self.fields;
Ok(request)
} }
pub fn into_fields(self) -> HeaderMap { pub fn into_fields(self) -> HeaderMap {

View File

@@ -73,16 +73,6 @@ pub enum Frame<T = Bytes> {
} }
impl<T> Frame<T> { impl<T> Frame<T> {
/// Returns true if the frame is a DATA frame.
pub fn is_data(&self) -> bool {
use self::Frame::*;
match *self {
Data(..) => true,
_ => false,
}
}
pub fn map<F, U>(self, f: F) -> Frame<U> pub fn map<F, U>(self, f: F) -> Frame<U>
where F: FnOnce(T) -> U where F: FnOnce(T) -> U
{ {

View File

@@ -1,4 +1,4 @@
#![deny(warnings, missing_debug_implementations)] // #![deny(warnings, missing_debug_implementations)]
#[macro_use] #[macro_use]
extern crate futures; extern crate futures;

View File

@@ -26,7 +26,7 @@ use bytes::{Buf, IntoBuf};
use tokio_io::{AsyncRead, AsyncWrite}; use tokio_io::{AsyncRead, AsyncWrite};
use tokio_io::codec::length_delimited; use tokio_io::codec::length_delimited;
use std::io; use std::{fmt, io};
/// Either a Client or a Server /// Either a Client or a Server
pub trait Peer { pub trait Peer {
@@ -34,7 +34,7 @@ pub trait Peer {
type Send; type Send;
/// Message type polled from the transport /// Message type polled from the transport
type Poll; type Poll: fmt::Debug;
fn is_server() -> bool; fn is_server() -> bool;
@@ -43,7 +43,7 @@ pub trait Peer {
headers: Self::Send, headers: Self::Send,
end_of_stream: bool) -> frame::Headers; end_of_stream: bool) -> frame::Headers;
fn convert_poll_message(headers: frame::Headers) -> Result<Self::Poll, ConnectionError>; fn convert_poll_message(headers: frame::Headers) -> Result<Self::Poll, ProtoError>;
} }
pub type PingPayload = [u8; 8]; pub type PingPayload = [u8; 8];

View File

@@ -36,7 +36,7 @@ pub(super) struct Recv<B, P>
pending_accept: store::Queue<B, stream::NextAccept, P>, pending_accept: store::Queue<B, stream::NextAccept, P>,
/// Holds frames that are waiting to be read /// Holds frames that are waiting to be read
buffer: Buffer<Frame<Bytes>>, buffer: Buffer<Event<P::Poll>>,
/// Refused StreamId, this represents a frame that must be sent out. /// Refused StreamId, this represents a frame that must be sent out.
refused: Option<StreamId>, refused: Option<StreamId>,
@@ -44,6 +44,13 @@ pub(super) struct Recv<B, P>
_p: PhantomData<(B)>, _p: PhantomData<(B)>,
} }
#[derive(Debug)]
pub(super) enum Event<T> {
Headers(T),
Data(Bytes),
Trailers(::HeaderMap),
}
#[derive(Debug, Clone, Copy)] #[derive(Debug, Clone, Copy)]
struct Indices { struct Indices {
head: store::Key, head: store::Key,
@@ -110,44 +117,13 @@ impl<B, P> Recv<B, P>
Ok(Some(id)) Ok(Some(id))
} }
pub fn take_request(&mut self, stream: &mut store::Ptr<B, P>)
-> Result<Request<()>, ConnectionError>
{
match stream.pending_recv.pop_front(&mut self.buffer) {
Some(Frame::Headers(frame)) => {
// TODO: This error should probably be caught on receipt of the
// frame vs. now.
Ok(server::Peer::convert_poll_message(frame)?)
}
_ => panic!(),
}
}
pub fn poll_response(&mut self, stream: &mut store::Ptr<B, P>)
-> Poll<Response<()>, ConnectionError> {
// If the buffer is not empty, then the first frame must be a HEADERS
// frame or the user violated the contract.
match stream.pending_recv.pop_front(&mut self.buffer) {
Some(Frame::Headers(v)) => {
// TODO: This error should probably be caught on receipt of the
// frame vs. now.
Ok(client::Peer::convert_poll_message(v)?.into())
}
Some(_) => unimplemented!(),
None => {
stream.state.ensure_recv_open()?;
stream.recv_task = Some(task::current());
Ok(Async::NotReady)
}
}
}
/// Transition the stream state based on receiving headers /// Transition the stream state based on receiving headers
///
/// The caller ensures that the frame represents headers and not trailers.
pub fn recv_headers(&mut self, pub fn recv_headers(&mut self,
frame: frame::Headers, frame: frame::Headers,
stream: &mut store::Ptr<B, P>) stream: &mut store::Ptr<B, P>)
-> Result<(), ConnectionError> -> Result<(), ProtoError>
{ {
trace!("opening stream; init_window={}", self.init_window_sz); trace!("opening stream; init_window={}", self.init_window_sz);
let is_initial = stream.state.recv_open(frame.is_end_stream())?; let is_initial = stream.state.recv_open(frame.is_end_stream())?;
@@ -161,7 +137,7 @@ impl<B, P> Recv<B, P>
self.next_stream_id = frame.stream_id(); self.next_stream_id = frame.stream_id();
self.next_stream_id.increment(); self.next_stream_id.increment();
} else { } else {
return Err(ProtocolError.into()); return Err(ProtoError::Connection(ProtocolError));
} }
// TODO: be smarter about this logic // TODO: be smarter about this logic
@@ -173,8 +149,10 @@ impl<B, P> Recv<B, P>
self.inc_num_streams(); self.inc_num_streams();
} }
let message = P::convert_poll_message(frame)?;
// Push the frame onto the stream's recv buffer // Push the frame onto the stream's recv buffer
stream.pending_recv.push_back(&mut self.buffer, frame.into()); stream.pending_recv.push_back(&mut self.buffer, Event::Headers(message));
stream.notify_recv(); stream.notify_recv();
// Only servers can receive a headers frame that initiates the stream. // Only servers can receive a headers frame that initiates the stream.
@@ -190,13 +168,15 @@ impl<B, P> Recv<B, P>
pub fn recv_trailers(&mut self, pub fn recv_trailers(&mut self,
frame: frame::Headers, frame: frame::Headers,
stream: &mut store::Ptr<B, P>) stream: &mut store::Ptr<B, P>)
-> Result<(), ConnectionError> -> Result<(), ProtoError>
{ {
// Transition the state // Transition the state
stream.state.recv_close()?; stream.state.recv_close()?;
let trailers = frame.into_fields();
// Push the frame onto the stream's recv buffer // Push the frame onto the stream's recv buffer
stream.pending_recv.push_back(&mut self.buffer, frame.into()); stream.pending_recv.push_back(&mut self.buffer, Event::Trailers(trailers));
stream.notify_recv(); stream.notify_recv();
Ok(()) Ok(())
@@ -236,7 +216,7 @@ impl<B, P> Recv<B, P>
} }
stream.pending_recv.peek_front(&self.buffer) stream.pending_recv.peek_front(&self.buffer)
.map(|frame| !frame.is_data()) .map(|event| !event.is_data())
.unwrap_or(true) .unwrap_or(true)
} }
@@ -278,11 +258,15 @@ impl<B, P> Recv<B, P>
stream.in_flight_recv_data += sz; stream.in_flight_recv_data += sz;
if frame.is_end_stream() { if frame.is_end_stream() {
try!(stream.state.recv_close()); if stream.state.recv_close().is_err() {
return Err(ProtocolError.into());
}
} }
let event = Event::Data(frame.into_payload());
// Push the frame onto the recv buffer // Push the frame onto the recv buffer
stream.pending_recv.push_back(&mut self.buffer, frame.into()); stream.pending_recv.push_back(&mut self.buffer, event);
stream.notify_recv(); stream.notify_recv();
Ok(()) Ok(())
@@ -530,12 +514,12 @@ impl<B, P> Recv<B, P>
-> Poll<Option<Bytes>, ConnectionError> -> Poll<Option<Bytes>, ConnectionError>
{ {
match stream.pending_recv.pop_front(&mut self.buffer) { match stream.pending_recv.pop_front(&mut self.buffer) {
Some(Frame::Data(frame)) => { Some(Event::Data(payload)) => {
Ok(Some(frame.into_payload()).into()) Ok(Some(payload).into())
} }
Some(frame) => { Some(event) => {
// Frame is trailer // Frame is trailer
stream.pending_recv.push_front(&mut self.buffer, frame); stream.pending_recv.push_front(&mut self.buffer, event);
// No more data frames // No more data frames
Ok(None.into()) Ok(None.into())
@@ -557,8 +541,8 @@ impl<B, P> Recv<B, P>
-> Poll<Option<HeaderMap>, ConnectionError> -> Poll<Option<HeaderMap>, ConnectionError>
{ {
match stream.pending_recv.pop_front(&mut self.buffer) { match stream.pending_recv.pop_front(&mut self.buffer) {
Some(Frame::Headers(frame)) => { Some(Event::Trailers(trailers)) => {
Ok(Some(frame.into_fields()).into()) Ok(Some(trailers).into())
} }
Some(_) => { Some(_) => {
// TODO: This is a user error. `poll_trailers` was called before // TODO: This is a user error. `poll_trailers` was called before
@@ -583,3 +567,55 @@ impl<B, P> Recv<B, P>
unimplemented!(); unimplemented!();
} }
} }
impl<B> Recv<B, server::Peer>
where B: Buf,
{
/// TODO: Should this fn return `Result`?
pub fn take_request(&mut self, stream: &mut store::Ptr<B, server::Peer>)
-> Result<Request<()>, ConnectionError>
{
match stream.pending_recv.pop_front(&mut self.buffer) {
Some(Event::Headers(request)) => Ok(request),
/*
// TODO: This error should probably be caught on receipt of the
// frame vs. now.
Ok(server::Peer::convert_poll_message(frame)?)
*/
_ => panic!(),
}
}
}
impl<B> Recv<B, client::Peer>
where B: Buf,
{
pub fn poll_response(&mut self, stream: &mut store::Ptr<B, client::Peer>)
-> Poll<Response<()>, ConnectionError> {
// If the buffer is not empty, then the first frame must be a HEADERS
// frame or the user violated the contract.
match stream.pending_recv.pop_front(&mut self.buffer) {
Some(Event::Headers(response)) => {
Ok(response.into())
}
Some(_) => unimplemented!(),
None => {
stream.state.ensure_recv_open()?;
stream.recv_task = Some(task::current());
Ok(Async::NotReady)
}
}
}
}
// ===== impl Event =====
impl<T> Event<T> {
fn is_data(&self) -> bool {
match *self {
Event::Data(..) => true,
_ => false,
}
}
}

View File

@@ -1,4 +1,5 @@
use ConnectionError; use ConnectionError;
use proto::ProtoError;
use error::Reason; use error::Reason;
use error::Reason::*; use error::Reason::*;
use error::User::*; use error::User::*;
@@ -125,7 +126,7 @@ impl State {
/// frame is received. /// frame is received.
/// ///
/// Returns true if this transitions the state to Open /// Returns true if this transitions the state to Open
pub fn recv_open(&mut self, eos: bool) -> Result<bool, ConnectionError> { pub fn recv_open(&mut self, eos: bool) -> Result<bool, ProtoError> {
let remote = Peer::Streaming; let remote = Peer::Streaming;
let mut initial = false; let mut initial = false;
@@ -173,7 +174,7 @@ impl State {
} }
_ => { _ => {
// All other transitions result in a protocol error // All other transitions result in a protocol error
return Err(ProtocolError.into()); return Err(ProtoError::Connection(ProtocolError));
} }
}; };
@@ -192,7 +193,7 @@ impl State {
} }
/// Indicates that the remote side will not send more data to the local. /// Indicates that the remote side will not send more data to the local.
pub fn recv_close(&mut self) -> Result<(), ConnectionError> { pub fn recv_close(&mut self) -> Result<(), ProtoError> {
match self.inner { match self.inner {
Open { local, .. } => { Open { local, .. } => {
// The remote side will continue to receive data. // The remote side will continue to receive data.
@@ -205,7 +206,7 @@ impl State {
self.inner = Closed(None); self.inner = Closed(None);
Ok(()) Ok(())
} }
_ => Err(ProtocolError.into()), _ => Err(ProtoError::Connection(ProtocolError)),
} }
} }

View File

@@ -64,7 +64,7 @@ pub(super) struct Stream<B, P>
pub is_pending_window_update: bool, pub is_pending_window_update: bool,
/// Frames pending for this stream to read /// Frames pending for this stream to read
pub pending_recv: buffer::Deque<Frame<Bytes>>, pub pending_recv: buffer::Deque<recv::Event<P::Poll>>,
/// Task tracking receiving frames /// Task tracking receiving frames
pub recv_task: Option<task::Task>, pub recv_task: Option<task::Task>,

View File

@@ -92,7 +92,7 @@ impl<B, P> Streams<B, P>
let stream = me.store.resolve(key); let stream = me.store.resolve(key);
me.actions.transition(stream, |actions, stream| { me.actions.transition(stream, |actions, stream| {
if stream.state.is_recv_headers() { let res = if stream.state.is_recv_headers() {
actions.recv.recv_headers(frame, stream) actions.recv.recv_headers(frame, stream)
} else { } else {
if !frame.is_end_stream() { if !frame.is_end_stream() {
@@ -101,6 +101,17 @@ impl<B, P> Streams<B, P>
} }
actions.recv.recv_trailers(frame, stream) actions.recv.recv_trailers(frame, stream)
};
match res {
Ok(()) => Ok(()),
Err(ProtoError::Connection(reason)) => Err(reason.into()),
Err(ProtoError::Stream { reason, .. }) => {
// Reset the stream.
actions.send.send_reset(reason, stream, &mut actions.task);
Ok(())
}
Err(ProtoError::Io(_)) => unreachable!(),
} }
}) })
} }
@@ -381,21 +392,6 @@ impl<B, P> StreamRef<B, P>
}) })
} }
/// Called by the server after the stream is accepted. Given that clients
/// initialize streams by sending HEADERS, the request will always be
/// available.
///
/// # Panics
///
/// This function panics if the request isn't present.
pub fn take_request(&self) -> Result<Request<()>, ConnectionError> {
let mut me = self.inner.lock().unwrap();
let me = &mut *me;
let mut stream = me.store.resolve(self.key);
me.actions.recv.take_request(&mut stream)
}
pub fn send_reset(&mut self, reason: Reason) { pub fn send_reset(&mut self, reason: Reason) {
let mut me = self.inner.lock().unwrap(); let mut me = self.inner.lock().unwrap();
let me = &mut *me; let me = &mut *me;
@@ -431,15 +427,6 @@ impl<B, P> StreamRef<B, P>
me.actions.recv.body_is_empty(&stream) me.actions.recv.body_is_empty(&stream)
} }
pub fn poll_response(&mut self) -> Poll<Response<()>, ConnectionError> {
let mut me = self.inner.lock().unwrap();
let me = &mut *me;
let mut stream = me.store.resolve(self.key);
me.actions.recv.poll_response(&mut stream)
}
pub fn poll_data(&mut self) -> Poll<Option<Bytes>, ConnectionError> { pub fn poll_data(&mut self) -> Poll<Option<Bytes>, ConnectionError> {
let mut me = self.inner.lock().unwrap(); let mut me = self.inner.lock().unwrap();
let me = &mut *me; let me = &mut *me;
@@ -503,6 +490,38 @@ impl<B, P> StreamRef<B, P>
} }
} }
impl<B> StreamRef<B, server::Peer>
where B: Buf,
{
/// Called by the server after the stream is accepted. Given that clients
/// initialize streams by sending HEADERS, the request will always be
/// available.
///
/// # Panics
///
/// This function panics if the request isn't present.
pub fn take_request(&self) -> Result<Request<()>, ConnectionError> {
let mut me = self.inner.lock().unwrap();
let me = &mut *me;
let mut stream = me.store.resolve(self.key);
me.actions.recv.take_request(&mut stream)
}
}
impl<B> StreamRef<B, client::Peer>
where B: Buf,
{
pub fn poll_response(&mut self) -> Poll<Response<()>, ConnectionError> {
let mut me = self.inner.lock().unwrap();
let me = &mut *me;
let mut stream = me.store.resolve(self.key);
me.actions.recv.poll_response(&mut stream)
}
}
impl<B, P> Clone for StreamRef<B, P> impl<B, P> Clone for StreamRef<B, P>
where P: Peer, where P: Peer,
{ {

View File

@@ -1,6 +1,6 @@
use {HeaderMap, ConnectionError}; use {HeaderMap, ConnectionError};
use frame::{self, StreamId}; use frame::{self, StreamId};
use proto::{self, Connection, WindowSize}; use proto::{self, Connection, WindowSize, ProtoError};
use error::Reason; use error::Reason;
use error::Reason::*; use error::Reason::*;
@@ -401,8 +401,78 @@ impl proto::Peer for Peer {
} }
fn convert_poll_message(headers: frame::Headers) fn convert_poll_message(headers: frame::Headers)
-> Result<Self::Poll, ConnectionError> -> Result<Self::Poll, ProtoError>
{ {
headers.into_request() use http::{version, uri};
let mut b = Request::builder();
let stream_id = headers.stream_id();
let (pseudo, fields) = headers.into_parts();
macro_rules! malformed {
() => {
return Err(ProtoError::Stream {
id: stream_id,
reason: ProtocolError,
});
}
};
b.version(version::HTTP_2);
if let Some(method) = pseudo.method {
b.method(method);
} else {
malformed!();
}
// Specifying :status for a request is a protocol error
if pseudo.status.is_some() {
return Err(ProtoError::Connection(ProtocolError));
}
// Convert the URI
let mut parts = uri::Parts::default();
if let Some(scheme) = pseudo.scheme {
// TODO: Don't unwrap
parts.scheme = Some(uri::Scheme::from_shared(scheme.into_inner()).unwrap());
} else {
malformed!();
}
if let Some(authority) = pseudo.authority {
// TODO: Don't unwrap
parts.authority = Some(uri::Authority::from_shared(authority.into_inner()).unwrap());
}
if let Some(path) = pseudo.path {
// This cannot be empty
if path.is_empty() {
malformed!();
}
// TODO: Don't unwrap
parts.path_and_query = Some(uri::PathAndQuery::from_shared(path.into_inner()).unwrap());
}
b.uri(parts);
let mut request = match b.body(()) {
Ok(request) => request,
Err(_) => {
// TODO: Should there be more specialized handling for different
// kinds of errors
return Err(ProtoError::Stream {
id: stream_id,
reason: ProtocolError,
});
}
};
*request.headers_mut() = fields;
Ok(request)
} }
} }