Handle malformed HEADERS
This commit is contained in:
		| @@ -186,14 +186,16 @@ impl Headers { | ||||
|         -> Result<(), Error> | ||||
|     { | ||||
|         let mut reg = false; | ||||
|         let mut err = false; | ||||
|         let mut malformed = false; | ||||
|  | ||||
|         macro_rules! set_pseudo { | ||||
|             ($field:ident, $val:expr) => {{ | ||||
|                 if reg { | ||||
|                     err = true; | ||||
|                     trace!("load_hpack; header malformed -- pseudo not at head of block"); | ||||
|                     malformed = true; | ||||
|                 } else if self.pseudo.$field.is_some() { | ||||
|                     err = true; | ||||
|                     trace!("load_hpack; header malformed -- repeated pseudo"); | ||||
|                     malformed = true; | ||||
|                 } else { | ||||
|                     self.pseudo.$field = Some($val); | ||||
|                 } | ||||
| @@ -212,8 +214,19 @@ impl Headers { | ||||
|  | ||||
|             match header { | ||||
|                 Field { name, value } => { | ||||
|                     reg = true; | ||||
|                     self.fields.append(name, value); | ||||
|                     // Connection level header fields are not supported and must | ||||
|                     // result in a protocol error. | ||||
|  | ||||
|                     if name == header::CONNECTION { | ||||
|                         trace!("load_hpack; connection level header"); | ||||
|                         malformed = true; | ||||
|                     } else if name == header::TE && value != "trailers" { | ||||
|                         trace!("load_hpack; TE header not set to trailers; val={:?}", value); | ||||
|                         malformed = true; | ||||
|                     } else { | ||||
|                         reg = true; | ||||
|                         self.fields.append(name, value); | ||||
|                     } | ||||
|                 } | ||||
|                 Authority(v) => set_pseudo!(authority, v), | ||||
|                 Method(v) => set_pseudo!(method, v), | ||||
| @@ -228,9 +241,9 @@ impl Headers { | ||||
|             return Err(e.into()); | ||||
|         } | ||||
|  | ||||
|         if err { | ||||
|             trace!("repeated pseudo"); | ||||
|             return Err(hpack::DecoderError::RepeatedPseudo.into()); | ||||
|         if malformed { | ||||
|             trace!("malformed message"); | ||||
|             return Err(Error::MalformedMessage.into()); | ||||
|         } | ||||
|  | ||||
|         Ok(()) | ||||
|   | ||||
| @@ -171,6 +171,9 @@ pub enum Error { | ||||
|     /// identifier other than zero. | ||||
|     InvalidStreamId, | ||||
|  | ||||
|     /// A request or response is malformed. | ||||
|     MalformedMessage, | ||||
|  | ||||
|     /// An invalid stream dependency ID was provided | ||||
|     /// | ||||
|     /// This is returend if a HEADERS or PRIORITY frame is received with an | ||||
|   | ||||
| @@ -35,7 +35,6 @@ pub enum DecoderError { | ||||
|     IntegerUnderflow, | ||||
|     IntegerOverflow, | ||||
|     StringUnderflow, | ||||
|     RepeatedPseudo, | ||||
|     UnexpectedEndOfStream, | ||||
| } | ||||
|  | ||||
|   | ||||
| @@ -53,9 +53,9 @@ impl<T, B> futures::Stream for Codec<T, B> | ||||
|     where T: AsyncRead, | ||||
| { | ||||
|     type Item = Frame; | ||||
|     type Error = ConnectionError; | ||||
|     type Error = ProtoError; | ||||
|  | ||||
|     fn poll(&mut self) -> Poll<Option<Frame>, ConnectionError> { | ||||
|     fn poll(&mut self) -> Poll<Option<Frame>, Self::Error> { | ||||
|         self.inner.poll() | ||||
|     } | ||||
| } | ||||
|   | ||||
| @@ -123,6 +123,7 @@ impl<T, P, B> Connection<T, P, B> | ||||
|  | ||||
|     fn poll2(&mut self) -> Poll<(), ConnectionError> { | ||||
|         use frame::Frame::*; | ||||
|         use proto::ProtoError::*; | ||||
|  | ||||
|         loop { | ||||
|             // First, ensure that the `Connection` is able to receive a frame | ||||
| @@ -130,13 +131,29 @@ impl<T, P, B> Connection<T, P, B> | ||||
|  | ||||
|             trace!("polling codec"); | ||||
|  | ||||
|             let frame = match try!(self.codec.poll()) { | ||||
|                 Async::Ready(frame) => frame, | ||||
|                 Async::NotReady => { | ||||
|             let frame = match self.codec.poll() { | ||||
|                 // Receive a frame | ||||
|                 Ok(Async::Ready(frame)) => frame, | ||||
|                 // Socket not ready, try to flush any pending data | ||||
|                 Ok(Async::NotReady) => { | ||||
|                     // Flush any pending writes | ||||
|                     let _ = try!(self.poll_complete()); | ||||
|                     return Ok(Async::NotReady); | ||||
|                 } | ||||
|                 // Connection level error, set GO_AWAY and close connection | ||||
|                 Err(Connection(reason)) => { | ||||
|                     return Err(ConnectionError::Proto(reason)); | ||||
|                 } | ||||
|                 // Stream level error, reset the stream | ||||
|                 Err(Stream { id, reason }) => { | ||||
|                     trace!("stream level error; id={:?}; reason={:?}", id, reason); | ||||
|                     self.streams.send_reset::<P>(id, reason); | ||||
|                     continue; | ||||
|                 } | ||||
|                 // I/O error, nothing more can be done | ||||
|                 Err(Io(err)) => { | ||||
|                     return Err(err.into()); | ||||
|                 } | ||||
|             }; | ||||
|  | ||||
|             debug!("recv; frame={:?}", frame); | ||||
|   | ||||
| @@ -36,6 +36,8 @@ struct Partial { | ||||
| #[derive(Debug)] | ||||
| enum Continuable { | ||||
|     Headers(frame::Headers), | ||||
|     // Decode the Continuation frame but ignore it... | ||||
|     // Ignore(StreamId), | ||||
|     // PushPromise(frame::PushPromise), | ||||
| } | ||||
|  | ||||
| @@ -52,14 +54,16 @@ impl<T> FramedRead<T> { | ||||
|         // TODO: Is this needed? | ||||
|     } | ||||
|  | ||||
|     fn decode_frame(&mut self, mut bytes: BytesMut) -> Result<Option<Frame>, ConnectionError> { | ||||
|     fn decode_frame(&mut self, mut bytes: BytesMut) -> Result<Option<Frame>, ProtoError> { | ||||
|         use self::ProtoError::*; | ||||
|  | ||||
|         trace!("decoding frame from {}B", bytes.len()); | ||||
|  | ||||
|         // Parse the head | ||||
|         let head = frame::Head::parse(&bytes); | ||||
|  | ||||
|         if self.partial.is_some() && head.kind() != Kind::Continuation { | ||||
|             return Err(ProtocolError.into()); | ||||
|             return Err(Connection(ProtocolError)); | ||||
|         } | ||||
|  | ||||
|         let kind = head.kind(); | ||||
| @@ -68,17 +72,26 @@ impl<T> FramedRead<T> { | ||||
|  | ||||
|         let frame = match kind { | ||||
|             Kind::Settings => { | ||||
|                 frame::Settings::load(head, &bytes[frame::HEADER_LEN..])?.into() | ||||
|                 let res = frame::Settings::load(head, &bytes[frame::HEADER_LEN..]); | ||||
|  | ||||
|                 res.map_err(|_| Connection(ProtocolError))?.into() | ||||
|             } | ||||
|             Kind::Ping => { | ||||
|                 frame::Ping::load(head, &bytes[frame::HEADER_LEN..])?.into() | ||||
|                 let res = frame::Ping::load(head, &bytes[frame::HEADER_LEN..]); | ||||
|  | ||||
|                 res.map_err(|_| Connection(ProtocolError))?.into() | ||||
|             } | ||||
|             Kind::WindowUpdate => { | ||||
|                 frame::WindowUpdate::load(head, &bytes[frame::HEADER_LEN..])?.into() | ||||
|                 let res = frame::WindowUpdate::load(head, &bytes[frame::HEADER_LEN..]); | ||||
|  | ||||
|                 res.map_err(|_| Connection(ProtocolError))?.into() | ||||
|             } | ||||
|             Kind::Data => { | ||||
|                 let _ = bytes.split_to(frame::HEADER_LEN); | ||||
|                 frame::Data::load(head, bytes.freeze())?.into() | ||||
|                 let res = frame::Data::load(head, bytes.freeze()); | ||||
|  | ||||
|                 // TODO: Should this always be connection level? Probably not... | ||||
|                 res.map_err(|_| Connection(ProtocolError))?.into() | ||||
|             } | ||||
|             Kind::Headers => { | ||||
|                 // Drop the frame header | ||||
| @@ -86,11 +99,24 @@ impl<T> FramedRead<T> { | ||||
|                 let _ = bytes.split_to(frame::HEADER_LEN); | ||||
|  | ||||
|                 // Parse the header frame w/o parsing the payload | ||||
|                 let (mut headers, payload) = frame::Headers::load(head, bytes)?; | ||||
|                 let (mut headers, payload) = match frame::Headers::load(head, bytes) { | ||||
|                     Ok(res) => res, | ||||
|                     Err(_) => unimplemented!(), | ||||
|                 }; | ||||
|  | ||||
|                 if headers.is_end_headers() { | ||||
|                     // Load the HPACK encoded headers & return the frame | ||||
|                     headers.load_hpack(payload, &mut self.hpack)?; | ||||
|                     match headers.load_hpack(payload, &mut self.hpack) { | ||||
|                         Ok(_) => {} | ||||
|                         Err(frame::Error::MalformedMessage) => { | ||||
|                             return Err(Stream { | ||||
|                                 id: head.stream_id(), | ||||
|                                 reason: ProtocolError, | ||||
|                             }); | ||||
|                         } | ||||
|                         Err(_) => return Err(Connection(ProtocolError)), | ||||
|                     } | ||||
|  | ||||
|                     headers.into() | ||||
|                 } else { | ||||
|                     // Defer loading the frame | ||||
| @@ -103,16 +129,20 @@ impl<T> FramedRead<T> { | ||||
|                 } | ||||
|             } | ||||
|             Kind::Reset => { | ||||
|                 frame::Reset::load(head, &bytes[frame::HEADER_LEN..])?.into() | ||||
|                 let res = frame::Reset::load(head, &bytes[frame::HEADER_LEN..]); | ||||
|                 res.map_err(|_| Connection(ProtocolError))?.into() | ||||
|             } | ||||
|             Kind::GoAway => { | ||||
|                 frame::GoAway::load(&bytes[frame::HEADER_LEN..])?.into() | ||||
|                 let res = frame::GoAway::load(&bytes[frame::HEADER_LEN..]); | ||||
|                 res.map_err(|_| Connection(ProtocolError))?.into() | ||||
|             } | ||||
|             Kind::PushPromise => { | ||||
|                 frame::PushPromise::load(head, &bytes[frame::HEADER_LEN..])?.into() | ||||
|                 let res = frame::PushPromise::load(head, &bytes[frame::HEADER_LEN..]); | ||||
|                 res.map_err(|_| Connection(ProtocolError))?.into() | ||||
|             } | ||||
|             Kind::Priority => { | ||||
|                 frame::Priority::load(head, &bytes[frame::HEADER_LEN..])?.into() | ||||
|                 let res = frame::Priority::load(head, &bytes[frame::HEADER_LEN..]); | ||||
|                 res.map_err(|_| Connection(ProtocolError))?.into() | ||||
|             } | ||||
|             Kind::Continuation => { | ||||
|                 // TODO: Un-hack this | ||||
| @@ -120,7 +150,7 @@ impl<T> FramedRead<T> { | ||||
|  | ||||
|                 let mut partial = match self.partial.take() { | ||||
|                     Some(partial) => partial, | ||||
|                     None => return Err(ProtocolError.into()), | ||||
|                     None => return Err(Connection(ProtocolError)), | ||||
|                 }; | ||||
|  | ||||
|                 // Extend the buf | ||||
| @@ -135,10 +165,20 @@ impl<T> FramedRead<T> { | ||||
|                     Continuable::Headers(mut frame) => { | ||||
|                         // The stream identifiers must match | ||||
|                         if frame.stream_id() != head.stream_id() { | ||||
|                             return Err(ProtocolError.into()); | ||||
|                             return Err(Connection(ProtocolError)); | ||||
|                         } | ||||
|  | ||||
|                         match frame.load_hpack(partial.buf, &mut self.hpack) { | ||||
|                             Ok(_) => {} | ||||
|                             Err(frame::Error::MalformedMessage) => { | ||||
|                                 return Err(Stream { | ||||
|                                     id: head.stream_id(), | ||||
|                                     reason: ProtocolError, | ||||
|                                 }); | ||||
|                             } | ||||
|                             Err(_) => return Err(Connection(ProtocolError)), | ||||
|                         } | ||||
|  | ||||
|                         frame.load_hpack(partial.buf, &mut self.hpack)?; | ||||
|                         frame.into() | ||||
|                     } | ||||
|                 } | ||||
| @@ -165,9 +205,9 @@ impl<T> futures::Stream for FramedRead<T> | ||||
|     where T: AsyncRead, | ||||
| { | ||||
|     type Item = Frame; | ||||
|     type Error = ConnectionError; | ||||
|     type Error = ProtoError; | ||||
|  | ||||
|     fn poll(&mut self) -> Poll<Option<Frame>, ConnectionError> { | ||||
|     fn poll(&mut self) -> Poll<Option<Frame>, Self::Error> { | ||||
|         loop { | ||||
|             trace!("poll"); | ||||
|             let bytes = match try_ready!(self.inner.poll()) { | ||||
|   | ||||
| @@ -26,6 +26,8 @@ use bytes::{Buf, IntoBuf}; | ||||
| use tokio_io::{AsyncRead, AsyncWrite}; | ||||
| use tokio_io::codec::length_delimited; | ||||
|  | ||||
| use std::io; | ||||
|  | ||||
| /// Either a Client or a Server | ||||
| pub trait Peer { | ||||
|     /// Message type sent into the transport | ||||
| @@ -50,6 +52,17 @@ pub type PingPayload = [u8; 8]; | ||||
|  | ||||
| pub type WindowSize = u32; | ||||
|  | ||||
| /// Errors that are received | ||||
| #[derive(Debug)] | ||||
| pub enum ProtoError { | ||||
|     Connection(Reason), | ||||
|     Stream { | ||||
|         id: StreamId, | ||||
|         reason: Reason, | ||||
|     }, | ||||
|     Io(io::Error), | ||||
| } | ||||
|  | ||||
| // Constants | ||||
| pub const DEFAULT_INITIAL_WINDOW_SIZE: WindowSize = 65_535; | ||||
| pub const MAX_WINDOW_SIZE: WindowSize = (1 << 31) - 1; | ||||
| @@ -88,3 +101,11 @@ pub(crate) fn from_framed_write<T, P, B>(framed_write: FramedWrite<T, Prioritize | ||||
|  | ||||
|     Connection::new(codec) | ||||
| } | ||||
|  | ||||
| // ===== impl ProtoError ===== | ||||
|  | ||||
| impl From<io::Error> for ProtoError { | ||||
|     fn from(src: io::Error) -> Self { | ||||
|         ProtoError::Io(src) | ||||
|     } | ||||
| } | ||||
|   | ||||
| @@ -1,4 +1,3 @@ | ||||
| use ConnectionError; | ||||
| use frame::Ping; | ||||
| use proto::*; | ||||
|  | ||||
|   | ||||
| @@ -1,4 +1,4 @@ | ||||
| use {frame, ConnectionError}; | ||||
| use frame; | ||||
| use proto::*; | ||||
|  | ||||
| use futures::Sink; | ||||
|   | ||||
| @@ -1,5 +1,4 @@ | ||||
| use {frame, ConnectionError}; | ||||
| use error::User::InactiveStreamId; | ||||
| use proto::*; | ||||
| use super::*; | ||||
|  | ||||
| @@ -101,27 +100,10 @@ impl<B> Send<B> where B: Buf { | ||||
|         Ok(()) | ||||
|     } | ||||
|  | ||||
|     /// This is called by the user to send a reset and should not be called | ||||
|     /// by internal state transitions. Use `reset_stream` for that. | ||||
|     pub fn send_reset(&mut self, | ||||
|                       reason: Reason, | ||||
|                       stream: &mut store::Ptr<B>, | ||||
|                       task: &mut Option<Task>) | ||||
|         -> Result<(), ConnectionError> | ||||
|     { | ||||
|         if stream.state.is_closed() { | ||||
|             debug!("send_reset; invalid stream ID"); | ||||
|             return Err(InactiveStreamId.into()) | ||||
|         } | ||||
|  | ||||
|         self.reset_stream(reason, stream, task); | ||||
|         Ok(()) | ||||
|     } | ||||
|  | ||||
|     fn reset_stream(&mut self, | ||||
|                     reason: Reason, | ||||
|                     stream: &mut store::Ptr<B>, | ||||
|                     task: &mut Option<Task>) | ||||
|     { | ||||
|         if stream.state.is_reset() { | ||||
|             // Don't double reset | ||||
| @@ -240,7 +222,7 @@ impl<B> Send<B> where B: Buf { | ||||
|     { | ||||
|         if let Err(e) = self.prioritize.recv_stream_window_update(sz, stream) { | ||||
|             debug!("recv_stream_window_update !!; err={:?}", e); | ||||
|             self.reset_stream(FlowControlError.into(), stream, task); | ||||
|             self.send_reset(FlowControlError.into(), stream, task); | ||||
|         } | ||||
|  | ||||
|         Ok(()) | ||||
|   | ||||
| @@ -312,6 +312,33 @@ impl<B> Streams<B> | ||||
|             key: key, | ||||
|         }) | ||||
|     } | ||||
|  | ||||
|     pub fn send_reset<P: Peer>(&mut self, id: StreamId, reason: Reason) { | ||||
|         let mut me = self.inner.lock().unwrap(); | ||||
|         let me = &mut *me; | ||||
|  | ||||
|         let key = match me.store.find_entry(id) { | ||||
|             Entry::Occupied(e) => e.key(), | ||||
|             Entry::Vacant(e) => { | ||||
|                 match me.actions.recv.open::<P>(id) { | ||||
|                     Ok(Some(stream_id)) => { | ||||
|                         let stream = Stream::new( | ||||
|                             stream_id, 0, 0); | ||||
|  | ||||
|                         e.insert(stream) | ||||
|                     } | ||||
|                     _ => return, | ||||
|                 } | ||||
|             } | ||||
|         }; | ||||
|  | ||||
|  | ||||
|         let stream = me.store.resolve(key); | ||||
|  | ||||
|         me.actions.transition::<P, _, _>(stream, move |actions, stream| { | ||||
|             actions.send.send_reset(reason, stream, &mut actions.task) | ||||
|         }) | ||||
|     } | ||||
| } | ||||
|  | ||||
| // ===== impl StreamRef ===== | ||||
| @@ -367,7 +394,7 @@ impl<B> StreamRef<B> | ||||
|         me.actions.recv.take_request(&mut stream) | ||||
|     } | ||||
|  | ||||
|     pub fn send_reset<P: Peer>(&mut self, reason: Reason) -> Result<(), ConnectionError> { | ||||
|     pub fn send_reset<P: Peer>(&mut self, reason: Reason) { | ||||
|         let mut me = self.inner.lock().unwrap(); | ||||
|         let me = &mut *me; | ||||
|  | ||||
|   | ||||
| @@ -191,7 +191,7 @@ impl<B: IntoBuf> Stream<B> { | ||||
|         self.inner.send_trailers::<Peer>(trailers) | ||||
|     } | ||||
|  | ||||
|     pub fn send_reset(mut self, reason: Reason) -> Result<(), ConnectionError> { | ||||
|     pub fn send_reset(mut self, reason: Reason) { | ||||
|         self.inner.send_reset::<Peer>(reason) | ||||
|     } | ||||
| } | ||||
|   | ||||
		Reference in New Issue
	
	Block a user