From e2871d92faeccf7bd4064c549b99e27fdc58bc1f Mon Sep 17 00:00:00 2001 From: Carl Lerche Date: Fri, 10 Mar 2017 13:02:04 -0800 Subject: [PATCH] Work --- Cargo.toml | 3 + src/error.rs | 152 ++++++++++++++++++++++++++++++++++++++ src/frame/head.rs | 29 +++++++- src/frame/mod.rs | 88 +++++++++++++++++----- src/frame/reader.rs | 40 ++++++++++ src/frame/settings.rs | 87 ++++++++++++++++++++-- src/frame/unknown.rs | 14 +++- src/frame/writer.rs | 1 + src/lib.rs | 47 +++++++++--- src/proto/framed_read.rs | 43 +++++++++++ src/proto/framed_write.rs | 89 ++++++++++++++++++++++ src/proto/mod.rs | 4 + 12 files changed, 559 insertions(+), 38 deletions(-) create mode 100644 src/error.rs create mode 100644 src/frame/reader.rs create mode 100644 src/frame/writer.rs create mode 100644 src/proto/framed_read.rs create mode 100644 src/proto/framed_write.rs create mode 100644 src/proto/mod.rs diff --git a/Cargo.toml b/Cargo.toml index 6a7dddd..26b5721 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -8,3 +8,6 @@ futures = "0.1" tokio-io = { git = "https://github.com/alexcrichton/tokio-io" } tokio-timer = { git = "https://github.com/tokio-rs/tokio-timer" } bytes = { git = "https://github.com/carllerche/bytes" } + +[replace] +"futures:0.1.10" = { git = "https://github.com/alexcrichton/futures-rs", branch = "close" } diff --git a/src/error.rs b/src/error.rs new file mode 100644 index 0000000..b6de333 --- /dev/null +++ b/src/error.rs @@ -0,0 +1,152 @@ +use std::{error, fmt, io}; + +/// The error type for HTTP/2 operations +#[derive(Debug)] +pub enum ConnectionError { + /// The HTTP/2 stream was reset + Proto(Reason), + /// An `io::Error` occurred while trying to read or write. + Io(io::Error), +} + +pub struct StreamError(Reason); + +#[derive(Debug, PartialEq, Eq, Clone, Copy)] +pub enum Reason { + NoError, + ProtocolError, + InternalError, + FlowControlError, + SettingsTimeout, + StreamClosed, + FrameSizeError, + RefusedStream, + Cancel, + CompressionError, + ConnectError, + EnhanceYourCalm, + InadequateSecurity, + Http11Required, + Other(u32), +} + +macro_rules! reason_desc { + ($reason:expr) => (reason_desc!($reason, "")); + ($reason:expr, $prefix:expr) => ({ + match $reason { + Reason::NoError => concat!($prefix, "not a result of an error"), + Reason::ProtocolError => concat!($prefix, "unspecific protocol error detected"), + Reason::InternalError => concat!($prefix, "unexpected internal error encountered"), + Reason::FlowControlError => concat!($prefix, "flow-control protocol violated"), + Reason::SettingsTimeout => concat!($prefix, "settings ACK not received in timely manner"), + Reason::StreamClosed => concat!($prefix, "received frame when stream half-closed"), + Reason::FrameSizeError => concat!($prefix, "frame sent with invalid size"), + Reason::RefusedStream => concat!($prefix, "refused stream before processing any application logic"), + Reason::Cancel => concat!($prefix, "stream no longer needed"), + Reason::CompressionError => concat!($prefix, "unable to maintain the header compression context"), + Reason::ConnectError => concat!($prefix, "connection established in response to a CONNECT request was reset or abnormally closed"), + Reason::EnhanceYourCalm => concat!($prefix, "detected excessive load generating behavior"), + Reason::InadequateSecurity => concat!($prefix, "security properties do not meet minimum requirements"), + Reason::Http11Required => concat!($prefix, "endpoint requires HTTP/1.1"), + Reason::Other(_) => concat!($prefix, "other reason"), + } + }); +} + +// ===== impl ConnectionError ===== + +impl From for ConnectionError { + fn from(src: io::Error) -> ConnectionError { + ConnectionError::Io(src) + } +} + +impl From for io::Error { + fn from(src: ConnectionError) -> io::Error { + io::Error::new(io::ErrorKind::Other, src) + } +} + +impl fmt::Display for ConnectionError { + fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result { + use self::ConnectionError::*; + + match *self { + Proto(reason) => write!(fmt, "protocol error: {}", reason), + Io(ref e) => fmt::Display::fmt(e, fmt), + } + } +} + +impl error::Error for ConnectionError { + fn description(&self) -> &str { + use self::ConnectionError::*; + + match *self { + Io(ref e) => error::Error::description(e), + Proto(reason) => reason_desc!(reason, "protocol error: "), + } + } +} + +// ===== impl Reason ===== + +impl Reason { + pub fn description(&self) -> &str { + reason_desc!(*self) + } +} + +impl From for Reason { + fn from(src: u32) -> Reason { + use self::Reason::*; + + match src { + 0x0 => NoError, + 0x1 => ProtocolError, + 0x2 => InternalError, + 0x3 => FlowControlError, + 0x4 => SettingsTimeout, + 0x5 => StreamClosed, + 0x6 => FrameSizeError, + 0x7 => RefusedStream, + 0x8 => Cancel, + 0x9 => CompressionError, + 0xa => ConnectError, + 0xb => EnhanceYourCalm, + 0xc => InadequateSecurity, + 0xd => Http11Required, + _ => Other(src), + } + } +} + +impl From for u32 { + fn from(src: Reason) -> u32 { + use self::Reason::*; + + match src { + NoError => 0x0, + ProtocolError => 0x1, + InternalError => 0x2, + FlowControlError => 0x3, + SettingsTimeout => 0x4, + StreamClosed => 0x5, + FrameSizeError => 0x6, + RefusedStream => 0x7, + Cancel => 0x8, + CompressionError => 0x9, + ConnectError => 0xa, + EnhanceYourCalm => 0xb, + InadequateSecurity => 0xc, + Http11Required => 0xd, + Other(v) => v, + } + } +} + +impl fmt::Display for Reason { + fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result { + write!(fmt, "{}", self.description()) + } +} diff --git a/src/frame/head.rs b/src/frame/head.rs index 49e8b42..202640b 100644 --- a/src/frame/head.rs +++ b/src/frame/head.rs @@ -1,5 +1,7 @@ use super::Error; +use bytes::{BufMut, BigEndian}; + #[derive(Debug, Copy, Clone, PartialEq, Eq)] pub struct Head { kind: Kind, @@ -25,9 +27,19 @@ pub enum Kind { pub type StreamId = u32; +const STREAM_ID_MASK: StreamId = 0x80000000; + // ===== impl Head ===== impl Head { + pub fn new(kind: Kind, flag: u8, stream_id: StreamId) -> Head { + Head { + kind: kind, + flag: flag, + stream_id: stream_id, + } + } + /// Parse an HTTP/2.0 frame header pub fn parse(header: &[u8]) -> Head { Head { @@ -48,6 +60,21 @@ impl Head { pub fn flag(&self) -> u8 { self.flag } + + pub fn encode_len(&self) -> usize { + super::FRAME_HEADER_LEN + } + + pub fn encode(&self, payload_len: usize, dst: &mut T) -> Result<(), Error> { + debug_assert_eq!(self.encode_len(), dst.remaining_mut()); + debug_assert!(self.stream_id & STREAM_ID_MASK == 0); + + dst.put_uint::(payload_len as u64, 3); + dst.put_u8(self.kind as u8); + dst.put_u8(self.flag); + dst.put_u32::(self.stream_id); + Ok(()) + } } /// Parse the next 4 octets in the given buffer, assuming they represent an @@ -58,7 +85,7 @@ impl Head { fn parse_stream_id(buf: &[u8]) -> StreamId { let unpacked = unpack_octets_4!(buf, 0, u32); // Now clear the most significant bit, as that is reserved and MUST be ignored when received. - unpacked & !0x80000000 + unpacked & !STREAM_ID_MASK } // ===== impl Kind ===== diff --git a/src/frame/mod.rs b/src/frame/mod.rs index 971b934..3769269 100644 --- a/src/frame/mod.rs +++ b/src/frame/mod.rs @@ -1,4 +1,7 @@ -use bytes::Bytes; +use error::{ConnectionError, Reason}; +use bytes::{Bytes, BytesMut, BufMut}; + +use std::io; /// A helper macro that unpacks a sequence of 4 bytes found in the buffer with /// the given identifier, starting at the given offset, into the given integer @@ -29,10 +32,27 @@ mod util; pub use self::data::Data; pub use self::head::{Head, Kind, StreamId}; +pub use self::settings::Settings; pub use self::unknown::Unknown; const FRAME_HEADER_LEN: usize = 9; +#[derive(Debug, Clone, PartialEq)] +pub enum Frame { + /* + Data(DataFrame<'a>), + HeadersFrame(HeadersFrame<'a>), + RstStreamFrame(RstStreamFrame), + SettingsFrame(SettingsFrame), + PingFrame(PingFrame), + GoawayFrame(GoawayFrame<'a>), + WindowUpdateFrame(WindowUpdateFrame), + UnknownFrame(RawFrame<'a>), + */ + Settings(Settings), + Unknown(Unknown), +} + /// Errors that can occur during parsing an HTTP/2 frame. #[derive(Debug, Clone, PartialEq, Eq)] pub enum Error { @@ -64,26 +84,22 @@ pub enum Error { /// The payload length specified by the frame header was not the /// value necessary for the specific frame type. - InvalidPayloadLength + InvalidPayloadLength, + + /// Received a payload with an ACK settings frame + InvalidPayloadAckSettings, + + /// An invalid stream identifier was provided. + /// + /// This is returned if a settings frame is received with a stream + /// identifier other than zero. + InvalidStreamId, } -#[derive(Debug, Clone, PartialEq)] -pub enum Frame { - /* - Data(DataFrame<'a>), - HeadersFrame(HeadersFrame<'a>), - RstStreamFrame(RstStreamFrame), - SettingsFrame(SettingsFrame), - PingFrame(PingFrame), - GoawayFrame(GoawayFrame<'a>), - WindowUpdateFrame(WindowUpdateFrame), - UnknownFrame(RawFrame<'a>), - */ - Unknown(Unknown), -} +// ===== impl Frame ====== impl Frame { - pub fn load(mut frame: Bytes) -> Frame { + pub fn load(mut frame: Bytes) -> Result { let head = Head::parse(&frame); // Extract the payload from the frame @@ -91,8 +107,44 @@ impl Frame { match head.kind() { - Kind::Unknown => Frame::Unknown(Unknown::new(head, frame)), + Kind::Unknown => { + let unknown = Unknown::new(head, frame); + Ok(Frame::Unknown(unknown)) + } _ => unimplemented!(), } } + + pub fn encode_len(&self) -> usize { + use self::Frame::*; + + match *self { + Settings(ref frame) => frame.encode_len(), + Unknown(ref frame) => frame.encode_len(), + } + } + + pub fn encode(&self, dst: &mut BytesMut) -> Result<(), Error> { + use self::Frame::*; + + debug_assert!(dst.remaining_mut() >= self.encode_len()); + + match *self { + Settings(ref frame) => frame.encode(dst), + Unknown(ref frame) => frame.encode(dst), + } + } +} + +// ===== impl Error ===== + +impl From for ConnectionError { + fn from(src: Error) -> ConnectionError { + use self::Error::*; + + match src { + // TODO: implement + _ => ConnectionError::Proto(Reason::ProtocolError), + } + } } diff --git a/src/frame/reader.rs b/src/frame/reader.rs new file mode 100644 index 0000000..6d5e11b --- /dev/null +++ b/src/frame/reader.rs @@ -0,0 +1,40 @@ +use ConnectionError; +use super::Frame; +use futures::*; +use bytes::BytesMut; +use std::io; + +pub struct Reader { + inner: T, +} + +impl Stream for Reader + where T: Stream, +{ + type Item = Frame; + type Error = ConnectionError; + + fn poll(&mut self) -> Poll, ConnectionError> { + match try_ready!(self.inner.poll()) { + Some(bytes) => { + Frame::load(bytes.freeze()) + .map(|frame| Async::Ready(Some(frame))) + .map_err(ConnectionError::from) + } + None => Ok(Async::Ready(None)), + } + } +} + +impl Sink for Reader { + type SinkItem = T::SinkItem; + type SinkError = T::SinkError; + + fn start_send(&mut self, item: T::SinkItem) -> StartSend { + self.inner.start_send(item) + } + + fn poll_complete(&mut self) -> Poll<(), T::SinkError> { + self.inner.poll_complete() + } +} diff --git a/src/frame/settings.rs b/src/frame/settings.rs index 257df5a..73b7f94 100644 --- a/src/frame/settings.rs +++ b/src/frame/settings.rs @@ -1,7 +1,7 @@ -use frame::{Error, Head}; -use bytes::Bytes; +use frame::{Error, Head, Kind}; +use bytes::{Bytes, BytesMut, BufMut, BigEndian}; -#[derive(Debug, Clone, Default)] +#[derive(Debug, Clone, Default, Eq, PartialEq)] pub struct Settings { flag: SettingsFlag, // Fields @@ -41,8 +41,7 @@ impl Settings { debug_assert_eq!(head.kind(), ::frame::Kind::Settings); if head.stream_id() != 0 { - // TODO: raise ProtocolError - unimplemented!(); + return Err(Error::InvalidStreamId); } // Load the flag @@ -51,8 +50,7 @@ impl Settings { if flag.is_ack() { // Ensure that the payload is empty if payload.len() > 0 { - // TODO: raise a FRAME_SIZE_ERROR - unimplemented!(); + return Err(Error::InvalidPayloadLength); } // Return the ACK frame @@ -64,7 +62,7 @@ impl Settings { // Ensure the payload length is correct, each setting is 6 bytes long. if payload.len() % 6 != 0 { - return Err(Error::PartialSettingLength); + return Err(Error::InvalidPayloadAckSettings); } let mut settings = Settings::default(); @@ -96,6 +94,57 @@ impl Settings { Ok(settings) } + + pub fn encode_len(&self) -> usize { + super::FRAME_HEADER_LEN + self.payload_len() + } + + fn payload_len(&self) -> usize { + let mut len = 0; + self.for_each(|_| len += 6); + len + } + + pub fn encode(&self, dst: &mut BytesMut) -> Result<(), Error> { + // Create & encode an appropriate frame head + let head = Head::new(Kind::Settings, self.flag.into(), 0); + let payload_len = self.payload_len(); + + try!(head.encode(payload_len, dst)); + + // Encode the settings + self.for_each(|setting| setting.encode(dst)); + + Ok(()) + } + + fn for_each(&self, mut f: F) { + use self::Setting::*; + + if let Some(v) = self.header_table_size { + f(HeaderTableSize(v)); + } + + if let Some(v) = self.enable_push { + f(EnablePush(if v { 1 } else { 0 })); + } + + if let Some(v) = self.max_concurrent_streams { + f(MaxConcurrentStreams(v)); + } + + if let Some(v) = self.initial_window_size { + f(InitialWindowSize(v)); + } + + if let Some(v) = self.max_frame_size { + f(MaxFrameSize(v)); + } + + if let Some(v) = self.max_header_list_size { + f(MaxHeaderListSize(v)); + } + } } // ===== impl Setting ===== @@ -134,6 +183,22 @@ impl Setting { Setting::from_id(id, val) } + + fn encode(&self, dst: &mut BytesMut) { + use self::Setting::*; + + let (kind, val) = match *self { + HeaderTableSize(v) => (1, v), + EnablePush(v) => (2, v), + MaxConcurrentStreams(v) => (3, v), + InitialWindowSize(v) => (4, v), + MaxFrameSize(v) => (5, v), + MaxHeaderListSize(v) => (6, v), + }; + + dst.put_u16::(kind); + dst.put_u32::(val); + } } // ===== impl SettingsFlag ===== @@ -151,3 +216,9 @@ impl SettingsFlag { self.0 & ACK == ACK } } + +impl From for u8 { + fn from(src: SettingsFlag) -> u8 { + src.0 + } +} diff --git a/src/frame/unknown.rs b/src/frame/unknown.rs index 200d2be..0531999 100644 --- a/src/frame/unknown.rs +++ b/src/frame/unknown.rs @@ -1,5 +1,5 @@ -use frame::Head; -use bytes::Bytes; +use frame::{Head, Error}; +use bytes::{Bytes, BytesMut, BufMut}; #[derive(Debug, Clone, PartialEq, Eq)] pub struct Unknown { @@ -14,4 +14,14 @@ impl Unknown { payload: payload, } } + + pub fn encode_len(&self) -> usize { + self.head.encode_len() + self.payload.len() + } + + pub fn encode(&self, dst: &mut BytesMut) -> Result<(), Error> { + try!(self.head.encode(self.payload.len(), dst)); + dst.put(&self.payload); + Ok(()) + } } diff --git a/src/frame/writer.rs b/src/frame/writer.rs new file mode 100644 index 0000000..5a13dca --- /dev/null +++ b/src/frame/writer.rs @@ -0,0 +1 @@ +pub struct Writer; diff --git a/src/lib.rs b/src/lib.rs index 66001b0..96a0ce1 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,12 +1,17 @@ #![allow(warnings)] extern crate futures; +#[macro_use] extern crate tokio_io; extern crate tokio_timer; extern crate bytes; +pub mod error; +pub mod proto; pub mod frame; +pub use error::{ConnectionError, StreamError, Reason}; + use tokio_io::{AsyncRead, AsyncWrite}; use tokio_io::codec::length_delimited; @@ -16,15 +21,39 @@ pub struct Transport { inner: length_delimited::FramedRead, } -pub fn bind(io: T) -> Transport { - let framed = length_delimited::Builder::new() - .big_endian() - .length_field_length(3) - .length_adjustment(6) - .num_skip(0) // Don't skip the header - .new_read(io); +impl Transport { + pub fn bind(io: T) -> Transport { + let framed = length_delimited::Builder::new() + .big_endian() + .length_field_length(3) + .length_adjustment(6) + .num_skip(0) // Don't skip the header + .new_read(io); - Transport { - inner: framed, + Transport { + inner: framed, + } + } +} + +impl Stream for Transport { + type Item = frame::Frame; + type Error = ConnectionError; + + fn poll(&mut self) -> Poll, ConnectionError> { + unimplemented!(); + } +} + +impl Sink for Transport { + type SinkItem = frame::Frame; + type SinkError = ConnectionError; + + fn start_send(&mut self, item: frame::Frame) -> StartSend { + unimplemented!(); + } + + fn poll_complete(&mut self) -> Poll<(), ConnectionError> { + unimplemented!(); } } diff --git a/src/proto/framed_read.rs b/src/proto/framed_read.rs new file mode 100644 index 0000000..cf67b89 --- /dev/null +++ b/src/proto/framed_read.rs @@ -0,0 +1,43 @@ +use ConnectionError; +use frame::Frame; + +use futures::*; +use bytes::BytesMut; + +use std::io; + +pub struct FramedRead { + inner: T, +} + +impl Stream for FramedRead + where T: Stream, +{ + type Item = Frame; + type Error = ConnectionError; + + fn poll(&mut self) -> Poll, ConnectionError> { + match try_ready!(self.inner.poll()) { + Some(bytes) => { + Frame::load(bytes.freeze()) + .map(|frame| Async::Ready(Some(frame))) + .map_err(ConnectionError::from) + } + None => Ok(Async::Ready(None)), + } + } +} + +impl Sink for FramedRead { + type SinkItem = T::SinkItem; + type SinkError = T::SinkError; + + fn start_send(&mut self, item: T::SinkItem) -> StartSend { + self.inner.start_send(item) + } + + fn poll_complete(&mut self) -> Poll<(), T::SinkError> { + self.inner.poll_complete() + } +} + diff --git a/src/proto/framed_write.rs b/src/proto/framed_write.rs new file mode 100644 index 0000000..4b95f2a --- /dev/null +++ b/src/proto/framed_write.rs @@ -0,0 +1,89 @@ +use {ConnectionError, Reason}; +use frame::{Frame, Error}; + +use tokio_io::AsyncWrite; +use futures::*; +use bytes::{BytesMut, Buf, BufMut}; + +use std::io::{self, Cursor}; + +#[derive(Debug)] +pub struct FramedWrite { + inner: T, + buf: Cursor, +} + +const DEFAULT_BUFFER_CAPACITY: usize = 8 * 1_024; +const MAX_BUFFER_CAPACITY: usize = 16 * 1_024; + +impl FramedWrite { + pub fn new(inner: T) -> FramedWrite { + let buf = BytesMut::with_capacity(DEFAULT_BUFFER_CAPACITY); + + FramedWrite { + inner: inner, + buf: Cursor::new(buf), + } + } + + fn write_buf(&mut self) -> &mut BytesMut { + self.buf.get_mut() + } +} + +impl Sink for FramedWrite { + type SinkItem = Frame; + type SinkError = ConnectionError; + + fn start_send(&mut self, item: Frame) -> StartSend { + let len = item.encode_len(); + + if len > MAX_BUFFER_CAPACITY { + // This case should never happen. Large frames should be chunked at + // a higher level, so this is an internal error. + return Err(ConnectionError::Proto(Reason::InternalError)); + } + + if self.write_buf().remaining_mut() <= len { + // Try flushing the buffer + try!(self.poll_complete()); + + let rem = self.write_buf().remaining_mut(); + let additional = len - rem; + + if self.write_buf().capacity() + additional > MAX_BUFFER_CAPACITY { + return Ok(AsyncSink::NotReady(item)); + } + + // Grow the buffer + self.write_buf().reserve(additional); + } + + // At this point, the buffer contains enough space + item.encode(self.write_buf()); + + Ok(AsyncSink::Ready) + } + + fn poll_complete(&mut self) -> Poll<(), ConnectionError> { + while self.buf.has_remaining() { + try_ready!(self.inner.write_buf(&mut self.buf)); + + if !self.buf.has_remaining() { + // Reset the buffer + self.write_buf().clear(); + self.buf.set_position(0); + } + } + + // Try flushing the underlying IO + try_nb!(self.inner.flush()); + + return Ok(Async::Ready(())); + } + + fn close(&mut self) -> Poll<(), ConnectionError> { + try_ready!(self.poll_complete()); + self.inner.shutdown().map_err(Into::into) + } +} diff --git a/src/proto/mod.rs b/src/proto/mod.rs new file mode 100644 index 0000000..daa3d73 --- /dev/null +++ b/src/proto/mod.rs @@ -0,0 +1,4 @@ +mod framed_read; +mod framed_write; + +pub use self::framed_read::FramedRead;