diff --git a/src/frame/headers.rs b/src/frame/headers.rs index 96f551c..dfacd0f 100644 --- a/src/frame/headers.rs +++ b/src/frame/headers.rs @@ -186,14 +186,16 @@ impl Headers { -> Result<(), Error> { let mut reg = false; - let mut err = false; + let mut malformed = false; macro_rules! set_pseudo { ($field:ident, $val:expr) => {{ if reg { - err = true; + trace!("load_hpack; header malformed -- pseudo not at head of block"); + malformed = true; } else if self.pseudo.$field.is_some() { - err = true; + trace!("load_hpack; header malformed -- repeated pseudo"); + malformed = true; } else { self.pseudo.$field = Some($val); } @@ -212,8 +214,19 @@ impl Headers { match header { Field { name, value } => { - reg = true; - self.fields.append(name, value); + // Connection level header fields are not supported and must + // result in a protocol error. + + if name == header::CONNECTION { + trace!("load_hpack; connection level header"); + malformed = true; + } else if name == header::TE && value != "trailers" { + trace!("load_hpack; TE header not set to trailers; val={:?}", value); + malformed = true; + } else { + reg = true; + self.fields.append(name, value); + } } Authority(v) => set_pseudo!(authority, v), Method(v) => set_pseudo!(method, v), @@ -228,9 +241,9 @@ impl Headers { return Err(e.into()); } - if err { - trace!("repeated pseudo"); - return Err(hpack::DecoderError::RepeatedPseudo.into()); + if malformed { + trace!("malformed message"); + return Err(Error::MalformedMessage.into()); } Ok(()) diff --git a/src/frame/mod.rs b/src/frame/mod.rs index 21f723e..9cc4953 100644 --- a/src/frame/mod.rs +++ b/src/frame/mod.rs @@ -171,6 +171,9 @@ pub enum Error { /// identifier other than zero. InvalidStreamId, + /// A request or response is malformed. + MalformedMessage, + /// An invalid stream dependency ID was provided /// /// This is returend if a HEADERS or PRIORITY frame is received with an diff --git a/src/hpack/decoder.rs b/src/hpack/decoder.rs index 257a263..44e1b0e 100644 --- a/src/hpack/decoder.rs +++ b/src/hpack/decoder.rs @@ -35,7 +35,6 @@ pub enum DecoderError { IntegerUnderflow, IntegerOverflow, StringUnderflow, - RepeatedPseudo, UnexpectedEndOfStream, } diff --git a/src/proto/codec.rs b/src/proto/codec.rs index daf0afb..699f170 100644 --- a/src/proto/codec.rs +++ b/src/proto/codec.rs @@ -53,9 +53,9 @@ impl futures::Stream for Codec where T: AsyncRead, { type Item = Frame; - type Error = ConnectionError; + type Error = ProtoError; - fn poll(&mut self) -> Poll, ConnectionError> { + fn poll(&mut self) -> Poll, Self::Error> { self.inner.poll() } } diff --git a/src/proto/connection.rs b/src/proto/connection.rs index 55616c0..1ee47af 100644 --- a/src/proto/connection.rs +++ b/src/proto/connection.rs @@ -123,6 +123,7 @@ impl Connection fn poll2(&mut self) -> Poll<(), ConnectionError> { use frame::Frame::*; + use proto::ProtoError::*; loop { // First, ensure that the `Connection` is able to receive a frame @@ -130,13 +131,29 @@ impl Connection trace!("polling codec"); - let frame = match try!(self.codec.poll()) { - Async::Ready(frame) => frame, - Async::NotReady => { + let frame = match self.codec.poll() { + // Receive a frame + Ok(Async::Ready(frame)) => frame, + // Socket not ready, try to flush any pending data + Ok(Async::NotReady) => { // Flush any pending writes let _ = try!(self.poll_complete()); return Ok(Async::NotReady); } + // Connection level error, set GO_AWAY and close connection + Err(Connection(reason)) => { + return Err(ConnectionError::Proto(reason)); + } + // Stream level error, reset the stream + Err(Stream { id, reason }) => { + trace!("stream level error; id={:?}; reason={:?}", id, reason); + self.streams.send_reset::

(id, reason); + continue; + } + // I/O error, nothing more can be done + Err(Io(err)) => { + return Err(err.into()); + } }; debug!("recv; frame={:?}", frame); diff --git a/src/proto/framed_read.rs b/src/proto/framed_read.rs index 8eef567..7146a76 100644 --- a/src/proto/framed_read.rs +++ b/src/proto/framed_read.rs @@ -36,6 +36,8 @@ struct Partial { #[derive(Debug)] enum Continuable { Headers(frame::Headers), + // Decode the Continuation frame but ignore it... + // Ignore(StreamId), // PushPromise(frame::PushPromise), } @@ -52,14 +54,16 @@ impl FramedRead { // TODO: Is this needed? } - fn decode_frame(&mut self, mut bytes: BytesMut) -> Result, ConnectionError> { + fn decode_frame(&mut self, mut bytes: BytesMut) -> Result, ProtoError> { + use self::ProtoError::*; + trace!("decoding frame from {}B", bytes.len()); // Parse the head let head = frame::Head::parse(&bytes); if self.partial.is_some() && head.kind() != Kind::Continuation { - return Err(ProtocolError.into()); + return Err(Connection(ProtocolError)); } let kind = head.kind(); @@ -68,17 +72,26 @@ impl FramedRead { let frame = match kind { Kind::Settings => { - frame::Settings::load(head, &bytes[frame::HEADER_LEN..])?.into() + let res = frame::Settings::load(head, &bytes[frame::HEADER_LEN..]); + + res.map_err(|_| Connection(ProtocolError))?.into() } Kind::Ping => { - frame::Ping::load(head, &bytes[frame::HEADER_LEN..])?.into() + let res = frame::Ping::load(head, &bytes[frame::HEADER_LEN..]); + + res.map_err(|_| Connection(ProtocolError))?.into() } Kind::WindowUpdate => { - frame::WindowUpdate::load(head, &bytes[frame::HEADER_LEN..])?.into() + let res = frame::WindowUpdate::load(head, &bytes[frame::HEADER_LEN..]); + + res.map_err(|_| Connection(ProtocolError))?.into() } Kind::Data => { let _ = bytes.split_to(frame::HEADER_LEN); - frame::Data::load(head, bytes.freeze())?.into() + let res = frame::Data::load(head, bytes.freeze()); + + // TODO: Should this always be connection level? Probably not... + res.map_err(|_| Connection(ProtocolError))?.into() } Kind::Headers => { // Drop the frame header @@ -86,11 +99,24 @@ impl FramedRead { let _ = bytes.split_to(frame::HEADER_LEN); // Parse the header frame w/o parsing the payload - let (mut headers, payload) = frame::Headers::load(head, bytes)?; + let (mut headers, payload) = match frame::Headers::load(head, bytes) { + Ok(res) => res, + Err(_) => unimplemented!(), + }; if headers.is_end_headers() { // Load the HPACK encoded headers & return the frame - headers.load_hpack(payload, &mut self.hpack)?; + match headers.load_hpack(payload, &mut self.hpack) { + Ok(_) => {} + Err(frame::Error::MalformedMessage) => { + return Err(Stream { + id: head.stream_id(), + reason: ProtocolError, + }); + } + Err(_) => return Err(Connection(ProtocolError)), + } + headers.into() } else { // Defer loading the frame @@ -103,16 +129,20 @@ impl FramedRead { } } Kind::Reset => { - frame::Reset::load(head, &bytes[frame::HEADER_LEN..])?.into() + let res = frame::Reset::load(head, &bytes[frame::HEADER_LEN..]); + res.map_err(|_| Connection(ProtocolError))?.into() } Kind::GoAway => { - frame::GoAway::load(&bytes[frame::HEADER_LEN..])?.into() + let res = frame::GoAway::load(&bytes[frame::HEADER_LEN..]); + res.map_err(|_| Connection(ProtocolError))?.into() } Kind::PushPromise => { - frame::PushPromise::load(head, &bytes[frame::HEADER_LEN..])?.into() + let res = frame::PushPromise::load(head, &bytes[frame::HEADER_LEN..]); + res.map_err(|_| Connection(ProtocolError))?.into() } Kind::Priority => { - frame::Priority::load(head, &bytes[frame::HEADER_LEN..])?.into() + let res = frame::Priority::load(head, &bytes[frame::HEADER_LEN..]); + res.map_err(|_| Connection(ProtocolError))?.into() } Kind::Continuation => { // TODO: Un-hack this @@ -120,7 +150,7 @@ impl FramedRead { let mut partial = match self.partial.take() { Some(partial) => partial, - None => return Err(ProtocolError.into()), + None => return Err(Connection(ProtocolError)), }; // Extend the buf @@ -135,10 +165,20 @@ impl FramedRead { Continuable::Headers(mut frame) => { // The stream identifiers must match if frame.stream_id() != head.stream_id() { - return Err(ProtocolError.into()); + return Err(Connection(ProtocolError)); + } + + match frame.load_hpack(partial.buf, &mut self.hpack) { + Ok(_) => {} + Err(frame::Error::MalformedMessage) => { + return Err(Stream { + id: head.stream_id(), + reason: ProtocolError, + }); + } + Err(_) => return Err(Connection(ProtocolError)), } - frame.load_hpack(partial.buf, &mut self.hpack)?; frame.into() } } @@ -165,9 +205,9 @@ impl futures::Stream for FramedRead where T: AsyncRead, { type Item = Frame; - type Error = ConnectionError; + type Error = ProtoError; - fn poll(&mut self) -> Poll, ConnectionError> { + fn poll(&mut self) -> Poll, Self::Error> { loop { trace!("poll"); let bytes = match try_ready!(self.inner.poll()) { diff --git a/src/proto/mod.rs b/src/proto/mod.rs index dff1321..8d27d9e 100644 --- a/src/proto/mod.rs +++ b/src/proto/mod.rs @@ -26,6 +26,8 @@ use bytes::{Buf, IntoBuf}; use tokio_io::{AsyncRead, AsyncWrite}; use tokio_io::codec::length_delimited; +use std::io; + /// Either a Client or a Server pub trait Peer { /// Message type sent into the transport @@ -50,6 +52,17 @@ pub type PingPayload = [u8; 8]; pub type WindowSize = u32; +/// Errors that are received +#[derive(Debug)] +pub enum ProtoError { + Connection(Reason), + Stream { + id: StreamId, + reason: Reason, + }, + Io(io::Error), +} + // Constants pub const DEFAULT_INITIAL_WINDOW_SIZE: WindowSize = 65_535; pub const MAX_WINDOW_SIZE: WindowSize = (1 << 31) - 1; @@ -88,3 +101,11 @@ pub(crate) fn from_framed_write(framed_write: FramedWrite for ProtoError { + fn from(src: io::Error) -> Self { + ProtoError::Io(src) + } +} diff --git a/src/proto/ping_pong.rs b/src/proto/ping_pong.rs index 88f296a..4baa1b8 100644 --- a/src/proto/ping_pong.rs +++ b/src/proto/ping_pong.rs @@ -1,4 +1,3 @@ -use ConnectionError; use frame::Ping; use proto::*; diff --git a/src/proto/settings.rs b/src/proto/settings.rs index b18d095..ee405f9 100644 --- a/src/proto/settings.rs +++ b/src/proto/settings.rs @@ -1,4 +1,4 @@ -use {frame, ConnectionError}; +use frame; use proto::*; use futures::Sink; diff --git a/src/proto/streams/send.rs b/src/proto/streams/send.rs index d893bb8..1390310 100644 --- a/src/proto/streams/send.rs +++ b/src/proto/streams/send.rs @@ -1,5 +1,4 @@ use {frame, ConnectionError}; -use error::User::InactiveStreamId; use proto::*; use super::*; @@ -101,27 +100,10 @@ impl Send where B: Buf { Ok(()) } - /// This is called by the user to send a reset and should not be called - /// by internal state transitions. Use `reset_stream` for that. pub fn send_reset(&mut self, reason: Reason, stream: &mut store::Ptr, task: &mut Option) - -> Result<(), ConnectionError> - { - if stream.state.is_closed() { - debug!("send_reset; invalid stream ID"); - return Err(InactiveStreamId.into()) - } - - self.reset_stream(reason, stream, task); - Ok(()) - } - - fn reset_stream(&mut self, - reason: Reason, - stream: &mut store::Ptr, - task: &mut Option) { if stream.state.is_reset() { // Don't double reset @@ -240,7 +222,7 @@ impl Send where B: Buf { { if let Err(e) = self.prioritize.recv_stream_window_update(sz, stream) { debug!("recv_stream_window_update !!; err={:?}", e); - self.reset_stream(FlowControlError.into(), stream, task); + self.send_reset(FlowControlError.into(), stream, task); } Ok(()) diff --git a/src/proto/streams/streams.rs b/src/proto/streams/streams.rs index a2d3cf3..bea5348 100644 --- a/src/proto/streams/streams.rs +++ b/src/proto/streams/streams.rs @@ -312,6 +312,33 @@ impl Streams key: key, }) } + + pub fn send_reset(&mut self, id: StreamId, reason: Reason) { + let mut me = self.inner.lock().unwrap(); + let me = &mut *me; + + let key = match me.store.find_entry(id) { + Entry::Occupied(e) => e.key(), + Entry::Vacant(e) => { + match me.actions.recv.open::

(id) { + Ok(Some(stream_id)) => { + let stream = Stream::new( + stream_id, 0, 0); + + e.insert(stream) + } + _ => return, + } + } + }; + + + let stream = me.store.resolve(key); + + me.actions.transition::(stream, move |actions, stream| { + actions.send.send_reset(reason, stream, &mut actions.task) + }) + } } // ===== impl StreamRef ===== @@ -367,7 +394,7 @@ impl StreamRef me.actions.recv.take_request(&mut stream) } - pub fn send_reset(&mut self, reason: Reason) -> Result<(), ConnectionError> { + pub fn send_reset(&mut self, reason: Reason) { let mut me = self.inner.lock().unwrap(); let me = &mut *me; diff --git a/src/server.rs b/src/server.rs index 36b49c9..de5aafa 100644 --- a/src/server.rs +++ b/src/server.rs @@ -191,7 +191,7 @@ impl Stream { self.inner.send_trailers::(trailers) } - pub fn send_reset(mut self, reason: Reason) -> Result<(), ConnectionError> { + pub fn send_reset(mut self, reason: Reason) { self.inner.send_reset::(reason) } }