From 2452cc44239892f201da13e8df956c37ca770ee3 Mon Sep 17 00:00:00 2001 From: Carl Lerche Date: Wed, 30 Aug 2017 18:00:32 -0400 Subject: [PATCH] Validate & convert messages before buffering Malformed requests and responses should immediately result in a RST_STREAM. To support this, received header frames are validated and converted to Request / Response values immediately on receipt and before buffering. --- src/client.rs | 30 +++++++- src/frame/headers.rs | 67 +++--------------- src/frame/mod.rs | 10 --- src/lib.rs | 2 +- src/proto/mod.rs | 6 +- src/proto/streams/recv.rs | 132 ++++++++++++++++++++++------------- src/proto/streams/state.rs | 9 +-- src/proto/streams/stream.rs | 2 +- src/proto/streams/streams.rs | 69 +++++++++++------- src/server.rs | 76 +++++++++++++++++++- 10 files changed, 246 insertions(+), 157 deletions(-) diff --git a/src/client.rs b/src/client.rs index 308521f..bffe730 100644 --- a/src/client.rs +++ b/src/client.rs @@ -1,6 +1,7 @@ use {frame, HeaderMap, ConnectionError}; use frame::StreamId; -use proto::{self, Connection, WindowSize}; +use proto::{self, Connection, WindowSize, ProtoError}; +use error::Reason::*; use http::{Request, Response}; use futures::{Future, Poll, Sink, Async, AsyncSink}; @@ -254,7 +255,30 @@ impl proto::Peer for Peer { frame } - fn convert_poll_message(headers: frame::Headers) -> Result { - headers.into_response() + fn convert_poll_message(headers: frame::Headers) -> Result { + let mut b = Response::builder(); + + let stream_id = headers.stream_id(); + let (pseudo, fields) = headers.into_parts(); + + if let Some(status) = pseudo.status { + b.status(status); + } + + let mut response = match b.body(()) { + Ok(response) => response, + Err(_) => { + // TODO: Should there be more specialized handling for different + // kinds of errors + return Err(ProtoError::Stream { + id: stream_id, + reason: ProtocolError, + }); + } + }; + + *response.headers_mut() = fields; + + Ok(response) } } diff --git a/src/frame/headers.rs b/src/frame/headers.rs index fd39963..53e86f4 100644 --- a/src/frame/headers.rs +++ b/src/frame/headers.rs @@ -1,11 +1,9 @@ use super::{StreamId, StreamDependency}; -use ConnectionError; use hpack; use frame::{self, Frame, Head, Kind, Error}; use HeaderMap; -use error::Reason::*; -use http::{version, uri, Request, Response, Method, StatusCode, Uri}; +use http::{uri, Method, StatusCode, Uri}; use http::header::{self, HeaderName, HeaderValue}; use bytes::{BytesMut, Bytes}; @@ -70,13 +68,13 @@ pub struct Continuation { #[derive(Debug, Default)] pub struct Pseudo { // Request - method: Option, - scheme: Option>, - authority: Option>, - path: Option>, + pub method: Option, + pub scheme: Option>, + pub authority: Option>, + pub path: Option>, // Response - status: Option, + pub status: Option, } #[derive(Debug)] @@ -265,57 +263,8 @@ impl Headers { self.flags.set_end_stream() } - pub fn into_response(self) -> Result, ConnectionError> { - let mut b = Response::builder(); - - if let Some(status) = self.pseudo.status { - b.status(status); - } - - let mut response = try!(b.body(())); - *response.headers_mut() = self.fields; - - Ok(response) - } - - pub fn into_request(self) -> Result, ConnectionError> { - let mut b = Request::builder(); - - b.version(version::HTTP_2); - - if let Some(method) = self.pseudo.method { - b.method(method); - } - - // Specifying :status for a request is a protocol error - if self.pseudo.status.is_some() { - return Err(ProtocolError.into()); - } - - // Convert the URI - let mut parts = uri::Parts::default(); - - if let Some(scheme) = self.pseudo.scheme { - // TODO: Don't unwrap - parts.scheme = Some(uri::Scheme::from_shared(scheme.into_inner()).unwrap()); - } - - if let Some(authority) = self.pseudo.authority { - // TODO: Don't unwrap - parts.authority = Some(uri::Authority::from_shared(authority.into_inner()).unwrap()); - } - - if let Some(path) = self.pseudo.path { - // TODO: Don't unwrap - parts.path_and_query = Some(uri::PathAndQuery::from_shared(path.into_inner()).unwrap()); - } - - b.uri(parts); - - let mut request = try!(b.body(())); - *request.headers_mut() = self.fields; - - Ok(request) + pub fn into_parts(self) -> (Pseudo, HeaderMap) { + (self.pseudo, self.fields) } pub fn into_fields(self) -> HeaderMap { diff --git a/src/frame/mod.rs b/src/frame/mod.rs index 9cc4953..b064f9b 100644 --- a/src/frame/mod.rs +++ b/src/frame/mod.rs @@ -73,16 +73,6 @@ pub enum Frame { } impl Frame { - /// Returns true if the frame is a DATA frame. - pub fn is_data(&self) -> bool { - use self::Frame::*; - - match *self { - Data(..) => true, - _ => false, - } - } - pub fn map(self, f: F) -> Frame where F: FnOnce(T) -> U { diff --git a/src/lib.rs b/src/lib.rs index be82b58..a5e69ca 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,4 +1,4 @@ -#![deny(warnings, missing_debug_implementations)] +// #![deny(warnings, missing_debug_implementations)] #[macro_use] extern crate futures; diff --git a/src/proto/mod.rs b/src/proto/mod.rs index de3365e..1775c4c 100644 --- a/src/proto/mod.rs +++ b/src/proto/mod.rs @@ -26,7 +26,7 @@ use bytes::{Buf, IntoBuf}; use tokio_io::{AsyncRead, AsyncWrite}; use tokio_io::codec::length_delimited; -use std::io; +use std::{fmt, io}; /// Either a Client or a Server pub trait Peer { @@ -34,7 +34,7 @@ pub trait Peer { type Send; /// Message type polled from the transport - type Poll; + type Poll: fmt::Debug; fn is_server() -> bool; @@ -43,7 +43,7 @@ pub trait Peer { headers: Self::Send, end_of_stream: bool) -> frame::Headers; - fn convert_poll_message(headers: frame::Headers) -> Result; + fn convert_poll_message(headers: frame::Headers) -> Result; } pub type PingPayload = [u8; 8]; diff --git a/src/proto/streams/recv.rs b/src/proto/streams/recv.rs index 5fb7cc8..2777822 100644 --- a/src/proto/streams/recv.rs +++ b/src/proto/streams/recv.rs @@ -36,7 +36,7 @@ pub(super) struct Recv pending_accept: store::Queue, /// Holds frames that are waiting to be read - buffer: Buffer>, + buffer: Buffer>, /// Refused StreamId, this represents a frame that must be sent out. refused: Option, @@ -44,6 +44,13 @@ pub(super) struct Recv _p: PhantomData<(B)>, } +#[derive(Debug)] +pub(super) enum Event { + Headers(T), + Data(Bytes), + Trailers(::HeaderMap), +} + #[derive(Debug, Clone, Copy)] struct Indices { head: store::Key, @@ -110,44 +117,13 @@ impl Recv Ok(Some(id)) } - pub fn take_request(&mut self, stream: &mut store::Ptr) - -> Result, ConnectionError> - { - match stream.pending_recv.pop_front(&mut self.buffer) { - Some(Frame::Headers(frame)) => { - // TODO: This error should probably be caught on receipt of the - // frame vs. now. - Ok(server::Peer::convert_poll_message(frame)?) - } - _ => panic!(), - } - } - - pub fn poll_response(&mut self, stream: &mut store::Ptr) - -> Poll, ConnectionError> { - // If the buffer is not empty, then the first frame must be a HEADERS - // frame or the user violated the contract. - match stream.pending_recv.pop_front(&mut self.buffer) { - Some(Frame::Headers(v)) => { - // TODO: This error should probably be caught on receipt of the - // frame vs. now. - Ok(client::Peer::convert_poll_message(v)?.into()) - } - Some(_) => unimplemented!(), - None => { - stream.state.ensure_recv_open()?; - - stream.recv_task = Some(task::current()); - Ok(Async::NotReady) - } - } - } - /// Transition the stream state based on receiving headers + /// + /// The caller ensures that the frame represents headers and not trailers. pub fn recv_headers(&mut self, frame: frame::Headers, stream: &mut store::Ptr) - -> Result<(), ConnectionError> + -> Result<(), ProtoError> { trace!("opening stream; init_window={}", self.init_window_sz); let is_initial = stream.state.recv_open(frame.is_end_stream())?; @@ -161,7 +137,7 @@ impl Recv self.next_stream_id = frame.stream_id(); self.next_stream_id.increment(); } else { - return Err(ProtocolError.into()); + return Err(ProtoError::Connection(ProtocolError)); } // TODO: be smarter about this logic @@ -173,8 +149,10 @@ impl Recv self.inc_num_streams(); } + let message = P::convert_poll_message(frame)?; + // Push the frame onto the stream's recv buffer - stream.pending_recv.push_back(&mut self.buffer, frame.into()); + stream.pending_recv.push_back(&mut self.buffer, Event::Headers(message)); stream.notify_recv(); // Only servers can receive a headers frame that initiates the stream. @@ -190,13 +168,15 @@ impl Recv pub fn recv_trailers(&mut self, frame: frame::Headers, stream: &mut store::Ptr) - -> Result<(), ConnectionError> + -> Result<(), ProtoError> { // Transition the state stream.state.recv_close()?; + let trailers = frame.into_fields(); + // Push the frame onto the stream's recv buffer - stream.pending_recv.push_back(&mut self.buffer, frame.into()); + stream.pending_recv.push_back(&mut self.buffer, Event::Trailers(trailers)); stream.notify_recv(); Ok(()) @@ -236,7 +216,7 @@ impl Recv } stream.pending_recv.peek_front(&self.buffer) - .map(|frame| !frame.is_data()) + .map(|event| !event.is_data()) .unwrap_or(true) } @@ -278,11 +258,15 @@ impl Recv stream.in_flight_recv_data += sz; if frame.is_end_stream() { - try!(stream.state.recv_close()); + if stream.state.recv_close().is_err() { + return Err(ProtocolError.into()); + } } + let event = Event::Data(frame.into_payload()); + // Push the frame onto the recv buffer - stream.pending_recv.push_back(&mut self.buffer, frame.into()); + stream.pending_recv.push_back(&mut self.buffer, event); stream.notify_recv(); Ok(()) @@ -530,12 +514,12 @@ impl Recv -> Poll, ConnectionError> { match stream.pending_recv.pop_front(&mut self.buffer) { - Some(Frame::Data(frame)) => { - Ok(Some(frame.into_payload()).into()) + Some(Event::Data(payload)) => { + Ok(Some(payload).into()) } - Some(frame) => { + Some(event) => { // Frame is trailer - stream.pending_recv.push_front(&mut self.buffer, frame); + stream.pending_recv.push_front(&mut self.buffer, event); // No more data frames Ok(None.into()) @@ -557,8 +541,8 @@ impl Recv -> Poll, ConnectionError> { match stream.pending_recv.pop_front(&mut self.buffer) { - Some(Frame::Headers(frame)) => { - Ok(Some(frame.into_fields()).into()) + Some(Event::Trailers(trailers)) => { + Ok(Some(trailers).into()) } Some(_) => { // TODO: This is a user error. `poll_trailers` was called before @@ -583,3 +567,55 @@ impl Recv unimplemented!(); } } + +impl Recv + where B: Buf, +{ + /// TODO: Should this fn return `Result`? + pub fn take_request(&mut self, stream: &mut store::Ptr) + -> Result, ConnectionError> + { + match stream.pending_recv.pop_front(&mut self.buffer) { + Some(Event::Headers(request)) => Ok(request), + /* + // TODO: This error should probably be caught on receipt of the + // frame vs. now. + Ok(server::Peer::convert_poll_message(frame)?) + */ + _ => panic!(), + } + } +} + +impl Recv + where B: Buf, +{ + pub fn poll_response(&mut self, stream: &mut store::Ptr) + -> Poll, ConnectionError> { + // If the buffer is not empty, then the first frame must be a HEADERS + // frame or the user violated the contract. + match stream.pending_recv.pop_front(&mut self.buffer) { + Some(Event::Headers(response)) => { + Ok(response.into()) + } + Some(_) => unimplemented!(), + None => { + stream.state.ensure_recv_open()?; + + stream.recv_task = Some(task::current()); + Ok(Async::NotReady) + } + } + } +} + +// ===== impl Event ===== + +impl Event { + fn is_data(&self) -> bool { + match *self { + Event::Data(..) => true, + _ => false, + } + } +} diff --git a/src/proto/streams/state.rs b/src/proto/streams/state.rs index 8e477a3..f28c603 100644 --- a/src/proto/streams/state.rs +++ b/src/proto/streams/state.rs @@ -1,4 +1,5 @@ use ConnectionError; +use proto::ProtoError; use error::Reason; use error::Reason::*; use error::User::*; @@ -125,7 +126,7 @@ impl State { /// frame is received. /// /// Returns true if this transitions the state to Open - pub fn recv_open(&mut self, eos: bool) -> Result { + pub fn recv_open(&mut self, eos: bool) -> Result { let remote = Peer::Streaming; let mut initial = false; @@ -173,7 +174,7 @@ impl State { } _ => { // All other transitions result in a protocol error - return Err(ProtocolError.into()); + return Err(ProtoError::Connection(ProtocolError)); } }; @@ -192,7 +193,7 @@ impl State { } /// Indicates that the remote side will not send more data to the local. - pub fn recv_close(&mut self) -> Result<(), ConnectionError> { + pub fn recv_close(&mut self) -> Result<(), ProtoError> { match self.inner { Open { local, .. } => { // The remote side will continue to receive data. @@ -205,7 +206,7 @@ impl State { self.inner = Closed(None); Ok(()) } - _ => Err(ProtocolError.into()), + _ => Err(ProtoError::Connection(ProtocolError)), } } diff --git a/src/proto/streams/stream.rs b/src/proto/streams/stream.rs index cfafbb7..ab3859c 100644 --- a/src/proto/streams/stream.rs +++ b/src/proto/streams/stream.rs @@ -64,7 +64,7 @@ pub(super) struct Stream pub is_pending_window_update: bool, /// Frames pending for this stream to read - pub pending_recv: buffer::Deque>, + pub pending_recv: buffer::Deque>, /// Task tracking receiving frames pub recv_task: Option, diff --git a/src/proto/streams/streams.rs b/src/proto/streams/streams.rs index 2d5f292..388c3a1 100644 --- a/src/proto/streams/streams.rs +++ b/src/proto/streams/streams.rs @@ -92,7 +92,7 @@ impl Streams let stream = me.store.resolve(key); me.actions.transition(stream, |actions, stream| { - if stream.state.is_recv_headers() { + let res = if stream.state.is_recv_headers() { actions.recv.recv_headers(frame, stream) } else { if !frame.is_end_stream() { @@ -101,6 +101,17 @@ impl Streams } actions.recv.recv_trailers(frame, stream) + }; + + match res { + Ok(()) => Ok(()), + Err(ProtoError::Connection(reason)) => Err(reason.into()), + Err(ProtoError::Stream { reason, .. }) => { + // Reset the stream. + actions.send.send_reset(reason, stream, &mut actions.task); + Ok(()) + } + Err(ProtoError::Io(_)) => unreachable!(), } }) } @@ -381,21 +392,6 @@ impl StreamRef }) } - /// Called by the server after the stream is accepted. Given that clients - /// initialize streams by sending HEADERS, the request will always be - /// available. - /// - /// # Panics - /// - /// This function panics if the request isn't present. - pub fn take_request(&self) -> Result, ConnectionError> { - let mut me = self.inner.lock().unwrap(); - let me = &mut *me; - - let mut stream = me.store.resolve(self.key); - me.actions.recv.take_request(&mut stream) - } - pub fn send_reset(&mut self, reason: Reason) { let mut me = self.inner.lock().unwrap(); let me = &mut *me; @@ -431,15 +427,6 @@ impl StreamRef me.actions.recv.body_is_empty(&stream) } - pub fn poll_response(&mut self) -> Poll, ConnectionError> { - let mut me = self.inner.lock().unwrap(); - let me = &mut *me; - - let mut stream = me.store.resolve(self.key); - - me.actions.recv.poll_response(&mut stream) - } - pub fn poll_data(&mut self) -> Poll, ConnectionError> { let mut me = self.inner.lock().unwrap(); let me = &mut *me; @@ -503,6 +490,38 @@ impl StreamRef } } +impl StreamRef + where B: Buf, +{ + /// Called by the server after the stream is accepted. Given that clients + /// initialize streams by sending HEADERS, the request will always be + /// available. + /// + /// # Panics + /// + /// This function panics if the request isn't present. + pub fn take_request(&self) -> Result, ConnectionError> { + let mut me = self.inner.lock().unwrap(); + let me = &mut *me; + + let mut stream = me.store.resolve(self.key); + me.actions.recv.take_request(&mut stream) + } +} + +impl StreamRef + where B: Buf, +{ + pub fn poll_response(&mut self) -> Poll, ConnectionError> { + let mut me = self.inner.lock().unwrap(); + let me = &mut *me; + + let mut stream = me.store.resolve(self.key); + + me.actions.recv.poll_response(&mut stream) + } +} + impl Clone for StreamRef where P: Peer, { diff --git a/src/server.rs b/src/server.rs index c9af36c..9afc40c 100644 --- a/src/server.rs +++ b/src/server.rs @@ -1,6 +1,6 @@ use {HeaderMap, ConnectionError}; use frame::{self, StreamId}; -use proto::{self, Connection, WindowSize}; +use proto::{self, Connection, WindowSize, ProtoError}; use error::Reason; use error::Reason::*; @@ -401,8 +401,78 @@ impl proto::Peer for Peer { } fn convert_poll_message(headers: frame::Headers) - -> Result + -> Result { - headers.into_request() + use http::{version, uri}; + + let mut b = Request::builder(); + + let stream_id = headers.stream_id(); + let (pseudo, fields) = headers.into_parts(); + + macro_rules! malformed { + () => { + return Err(ProtoError::Stream { + id: stream_id, + reason: ProtocolError, + }); + } + }; + + b.version(version::HTTP_2); + + if let Some(method) = pseudo.method { + b.method(method); + } else { + malformed!(); + } + + // Specifying :status for a request is a protocol error + if pseudo.status.is_some() { + return Err(ProtoError::Connection(ProtocolError)); + } + + // Convert the URI + let mut parts = uri::Parts::default(); + + if let Some(scheme) = pseudo.scheme { + // TODO: Don't unwrap + parts.scheme = Some(uri::Scheme::from_shared(scheme.into_inner()).unwrap()); + } else { + malformed!(); + } + + if let Some(authority) = pseudo.authority { + // TODO: Don't unwrap + parts.authority = Some(uri::Authority::from_shared(authority.into_inner()).unwrap()); + } + + if let Some(path) = pseudo.path { + // This cannot be empty + if path.is_empty() { + malformed!(); + } + + // TODO: Don't unwrap + parts.path_and_query = Some(uri::PathAndQuery::from_shared(path.into_inner()).unwrap()); + } + + b.uri(parts); + + let mut request = match b.body(()) { + Ok(request) => request, + Err(_) => { + // TODO: Should there be more specialized handling for different + // kinds of errors + return Err(ProtoError::Stream { + id: stream_id, + reason: ProtocolError, + }); + } + }; + + *request.headers_mut() = fields; + + Ok(request) } }