Start state transition verification + refactors
This commit is contained in:
		| @@ -45,15 +45,12 @@ impl Peer for Client { | ||||
|     type Send = http::request::Head; | ||||
|     type Poll = http::response::Head; | ||||
|  | ||||
|     fn check_initiating_id(id: StreamId) -> Result<(), ConnectionError> { | ||||
|         if id % 2 == 0 { | ||||
|             // Client stream identifiers must be odd | ||||
|             unimplemented!(); | ||||
|         } | ||||
|     fn is_valid_local_stream_id(id: StreamId) -> bool { | ||||
|         id.is_client_initiated() | ||||
|     } | ||||
|  | ||||
|         // TODO: Ensure the `id` doesn't overflow u31 | ||||
|  | ||||
|         Ok(()) | ||||
|     fn is_valid_remote_stream_id(id: StreamId) -> bool { | ||||
|         id.is_server_initiated() | ||||
|     } | ||||
|  | ||||
|     fn convert_send_message( | ||||
|   | ||||
| @@ -62,6 +62,12 @@ impl From<io::Error> for ConnectionError { | ||||
|     } | ||||
| } | ||||
|  | ||||
| impl From<Reason> for ConnectionError { | ||||
|     fn from(src: Reason) -> ConnectionError { | ||||
|         ConnectionError::Proto(src) | ||||
|     } | ||||
| } | ||||
|  | ||||
| impl From<ConnectionError> for io::Error { | ||||
|     fn from(src: ConnectionError) -> io::Error { | ||||
|         io::Error::new(io::ErrorKind::Other, src) | ||||
|   | ||||
| @@ -1,5 +1,4 @@ | ||||
| use frame::Error; | ||||
| use super::{head, StreamId}; | ||||
| use frame::{Error, StreamId}; | ||||
|  | ||||
| #[derive(Debug)] | ||||
| pub struct GoAway { | ||||
| @@ -15,7 +14,7 @@ impl GoAway { | ||||
|             unimplemented!(); | ||||
|         } | ||||
|  | ||||
|         let last_stream_id = head::parse_stream_id(&payload[..4]); | ||||
|         let last_stream_id = StreamId::parse(&payload[..4]); | ||||
|         let error_code = unpack_octets_4!(payload, 4, u32); | ||||
|  | ||||
|         Ok(GoAway { | ||||
|   | ||||
| @@ -1,3 +1,5 @@ | ||||
| use super::StreamId; | ||||
|  | ||||
| use bytes::{BufMut, BigEndian}; | ||||
|  | ||||
| #[derive(Debug, Copy, Clone, PartialEq, Eq)] | ||||
| @@ -23,10 +25,6 @@ pub enum Kind { | ||||
|     Unknown, | ||||
| } | ||||
|  | ||||
| pub type StreamId = u32; | ||||
|  | ||||
| const STREAM_ID_MASK: StreamId = 0x80000000; | ||||
|  | ||||
| // ===== impl Head ===== | ||||
|  | ||||
| impl Head { | ||||
| @@ -43,11 +41,11 @@ impl Head { | ||||
|         Head { | ||||
|             kind: Kind::new(header[3]), | ||||
|             flag: header[4], | ||||
|             stream_id: parse_stream_id(&header[5..]), | ||||
|             stream_id: StreamId::parse(&header[5..]), | ||||
|         } | ||||
|     } | ||||
|  | ||||
|     pub fn stream_id(&self) -> u32 { | ||||
|     pub fn stream_id(&self) -> StreamId { | ||||
|         self.stream_id | ||||
|     } | ||||
|  | ||||
| @@ -65,27 +63,14 @@ impl Head { | ||||
|  | ||||
|     pub fn encode<T: BufMut>(&self, payload_len: usize, dst: &mut T) { | ||||
|         debug_assert!(self.encode_len() <= dst.remaining_mut()); | ||||
|         debug_assert!(self.stream_id & STREAM_ID_MASK == 0); | ||||
|  | ||||
|         dst.put_uint::<BigEndian>(payload_len as u64, 3); | ||||
|         dst.put_u8(self.kind as u8); | ||||
|         dst.put_u8(self.flag); | ||||
|         dst.put_u32::<BigEndian>(self.stream_id); | ||||
|         dst.put_u32::<BigEndian>(self.stream_id.into()); | ||||
|     } | ||||
| } | ||||
|  | ||||
| /// Parse the next 4 octets in the given buffer, assuming they represent an | ||||
| /// HTTP/2 stream ID.  This means that the most significant bit of the first | ||||
| /// octet is ignored and the rest interpreted as a network-endian 31-bit | ||||
| /// integer. | ||||
| #[inline] | ||||
| pub fn parse_stream_id(buf: &[u8]) -> StreamId { | ||||
|     /// TODO: Move this onto the StreamId type? | ||||
|     let unpacked = unpack_octets_4!(buf, 0, u32); | ||||
|     // Now clear the most significant bit, as that is reserved and MUST be ignored when received. | ||||
|     unpacked & !STREAM_ID_MASK | ||||
| } | ||||
|  | ||||
| // ===== impl Kind ===== | ||||
|  | ||||
| impl Kind { | ||||
|   | ||||
| @@ -30,15 +30,17 @@ mod headers; | ||||
| mod ping; | ||||
| mod reset; | ||||
| mod settings; | ||||
| mod stream_id; | ||||
| mod util; | ||||
|  | ||||
| pub use self::data::Data; | ||||
| pub use self::go_away::GoAway; | ||||
| pub use self::head::{Head, Kind, StreamId}; | ||||
| pub use self::head::{Head, Kind}; | ||||
| pub use self::headers::{Headers, PushPromise, Continuation, Pseudo}; | ||||
| pub use self::ping::Ping; | ||||
| pub use self::reset::Reset; | ||||
| pub use self::settings::{Settings, SettingSet}; | ||||
| pub use self::stream_id::StreamId; | ||||
|  | ||||
| // Re-export some constants | ||||
| pub use self::settings::{ | ||||
|   | ||||
| @@ -1,5 +1,5 @@ | ||||
| use bytes::{Buf, BufMut, IntoBuf}; | ||||
| use frame::{Frame, Head, Kind, Error}; | ||||
| use frame::{Frame, Head, Kind, Error, StreamId}; | ||||
|  | ||||
| const ACK_FLAG: u8 = 0x1; | ||||
|  | ||||
| @@ -36,7 +36,7 @@ impl Ping { | ||||
|         // frame is received with a stream identifier field value other than | ||||
|         // 0x0, the recipient MUST respond with a connection error | ||||
|         // (Section 5.4.1) of type PROTOCOL_ERROR. | ||||
|         if head.stream_id() != 0 { | ||||
|         if !head.stream_id().is_zero() { | ||||
|             return Err(Error::InvalidStreamId); | ||||
|         } | ||||
|  | ||||
| @@ -45,6 +45,7 @@ impl Ping { | ||||
|         if bytes.len() != 8 { | ||||
|             return Err(Error::BadFrameSize); | ||||
|         } | ||||
|  | ||||
|         let mut payload = [0; 8]; | ||||
|         bytes.into_buf().copy_to_slice(&mut payload); | ||||
|  | ||||
| @@ -63,7 +64,7 @@ impl Ping { | ||||
|         trace!("encoding PING; ack={} len={}", self.ack, sz); | ||||
|  | ||||
|         let flags = if self.ack { ACK_FLAG } else { 0 }; | ||||
|         let head = Head::new(Kind::Ping, flags, 0); | ||||
|         let head = Head::new(Kind::Ping, flags, StreamId::zero()); | ||||
|  | ||||
|         head.encode(sz, dst); | ||||
|         dst.put_slice(&self.payload); | ||||
|   | ||||
| @@ -1,4 +1,4 @@ | ||||
| use frame::{Frame, Error, Head, Kind}; | ||||
| use frame::{Frame, Error, Head, Kind, StreamId}; | ||||
| use bytes::{BytesMut, BufMut, BigEndian}; | ||||
|  | ||||
| #[derive(Debug, Clone, Default, Eq, PartialEq)] | ||||
| @@ -71,7 +71,7 @@ impl Settings { | ||||
|  | ||||
|         debug_assert_eq!(head.kind(), ::frame::Kind::Settings); | ||||
|  | ||||
|         if head.stream_id() != 0 { | ||||
|         if !head.stream_id().is_zero() { | ||||
|             return Err(Error::InvalidStreamId); | ||||
|         } | ||||
|  | ||||
| @@ -132,7 +132,7 @@ impl Settings { | ||||
|  | ||||
|     pub fn encode(&self, dst: &mut BytesMut) { | ||||
|         // Create & encode an appropriate frame head | ||||
|         let head = Head::new(Kind::Settings, self.flags.into(), 0); | ||||
|         let head = Head::new(Kind::Settings, self.flags.into(), StreamId::zero()); | ||||
|         let payload_len = self.payload_len(); | ||||
|  | ||||
|         trace!("encoding SETTINGS; len={}", payload_len); | ||||
|   | ||||
							
								
								
									
										55
									
								
								src/frame/stream_id.rs
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										55
									
								
								src/frame/stream_id.rs
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,55 @@ | ||||
| use byteorder::{BigEndian, ByteOrder}; | ||||
| use std::u32; | ||||
|  | ||||
| #[derive(Debug, Copy, Clone, Eq, PartialEq, Ord, PartialOrd, Hash)] | ||||
| pub struct StreamId(u32); | ||||
|  | ||||
| const STREAM_ID_MASK: u32 = 1 << 31; | ||||
|  | ||||
| impl StreamId { | ||||
|     /// Parse the stream ID | ||||
|     #[inline] | ||||
|     pub fn parse(buf: &[u8]) -> StreamId { | ||||
|         let unpacked = BigEndian::read_u32(buf); | ||||
|         // Now clear the most significant bit, as that is reserved and MUST be | ||||
|         // ignored when received. | ||||
|         StreamId(unpacked & !STREAM_ID_MASK) | ||||
|     } | ||||
|  | ||||
|     pub fn is_client_initiated(&self) -> bool { | ||||
|         let id = self.0; | ||||
|         id != 0 && id % 2 == 1 | ||||
|     } | ||||
|  | ||||
|     pub fn is_server_initiated(&self) -> bool { | ||||
|         let id = self.0; | ||||
|         id != 0 && id % 2 == 0 | ||||
|     } | ||||
|  | ||||
|     #[inline] | ||||
|     pub fn zero() -> StreamId { | ||||
|         StreamId(0) | ||||
|     } | ||||
|  | ||||
|     #[inline] | ||||
|     pub fn max() -> StreamId { | ||||
|         StreamId(u32::MAX >> 1) | ||||
|     } | ||||
|  | ||||
|     pub fn is_zero(&self) -> bool { | ||||
|         self.0 == 0 | ||||
|     } | ||||
| } | ||||
|  | ||||
| impl From<u32> for StreamId { | ||||
|     fn from(src: u32) -> Self { | ||||
|         assert_eq!(src & STREAM_ID_MASK, 0, "invalid stream ID -- MSB is set"); | ||||
|         StreamId(src) | ||||
|     } | ||||
| } | ||||
|  | ||||
| impl From<StreamId> for u32 { | ||||
|     fn from(src: StreamId) -> Self { | ||||
|         src.0 | ||||
|     } | ||||
| } | ||||
| @@ -75,7 +75,13 @@ pub trait Peer { | ||||
|     /// Message type polled from the transport | ||||
|     type Poll; | ||||
|  | ||||
|     fn check_initiating_id(id: StreamId) -> Result<(), ConnectionError>; | ||||
|     /// Returns `true` if `id` is a valid StreamId for a stream initiated by the | ||||
|     /// local node. | ||||
|     fn is_valid_local_stream_id(id: StreamId) -> bool; | ||||
|  | ||||
|     /// Returns `true` if `id` is a valid StreamId for a stream initiated by the | ||||
|     /// remote node. | ||||
|     fn is_valid_remote_stream_id(id: StreamId) -> bool; | ||||
|  | ||||
|     #[doc(hidden)] | ||||
|     fn convert_send_message( | ||||
|   | ||||
| @@ -100,6 +100,20 @@ impl<T, P> Stream for Connection<T, P> | ||||
|                 let stream_id = v.stream_id(); | ||||
|                 let end_of_stream = v.is_end_stream(); | ||||
|  | ||||
|                 let stream_initialized = try!(self.streams.entry(stream_id) | ||||
|                      .or_insert(State::default()) | ||||
|                      .recv_headers::<P>(end_of_stream)); | ||||
|  | ||||
|                 if stream_initialized { | ||||
|                     // TODO: Ensure available capacity for a new stream | ||||
|                     // This won't be as simple as self.streams.len() as closed | ||||
|                     // connections should not be factored. | ||||
|  | ||||
|                     if !P::is_valid_remote_stream_id(stream_id) { | ||||
|                         unimplemented!(); | ||||
|                     } | ||||
|                 } | ||||
|  | ||||
|                 Frame::Headers { | ||||
|                     id: stream_id, | ||||
|                     headers: P::convert_poll_message(v), | ||||
| @@ -143,22 +157,23 @@ impl<T, P> Sink for Connection<T, P> | ||||
|  | ||||
|         match item { | ||||
|             Frame::Headers { id, headers, end_of_stream } => { | ||||
|                 // Ensure ID is valid | ||||
|                 // TODO: This check should only be done **if** this is a new | ||||
|                 // stream ID | ||||
|                 // try!(P::check_initiating_id(id)); | ||||
|  | ||||
|                 // TODO: Ensure available capacity for a new stream | ||||
|                 // This won't be as simple as self.streams.len() as closed | ||||
|                 // connections should not be factored. | ||||
|  | ||||
|                 // Transition the stream state, creating a new entry if needed | ||||
|                 // | ||||
|                 // TODO: Response can send multiple headers frames before body | ||||
|                 // (1xx responses). | ||||
|                 try!(self.streams.entry(id) | ||||
|                 let stream_initialized = try!(self.streams.entry(id) | ||||
|                      .or_insert(State::default()) | ||||
|                      .send_headers()); | ||||
|                      .send_headers::<P>(end_of_stream)); | ||||
|  | ||||
|                 if stream_initialized { | ||||
|                     // TODO: Ensure available capacity for a new stream | ||||
|                     // This won't be as simple as self.streams.len() as closed | ||||
|                     // connections should not be factored. | ||||
|                     // | ||||
|                     if !P::is_valid_local_stream_id(id) { | ||||
|                         unimplemented!(); | ||||
|                     } | ||||
|                 } | ||||
|  | ||||
|                 let frame = P::convert_send_message(id, headers, end_of_stream); | ||||
|  | ||||
|   | ||||
| @@ -1,4 +1,4 @@ | ||||
| use ConnectionError; | ||||
| use {ConnectionError, Reason, Peer}; | ||||
|  | ||||
| /// Represents the state of an H2 stream | ||||
| /// | ||||
| @@ -45,23 +45,134 @@ pub enum State { | ||||
|     Idle, | ||||
|     ReservedLocal, | ||||
|     ReservedRemote, | ||||
|     Open, | ||||
|     HalfClosedLocal, | ||||
|     HalfClosedRemote, | ||||
|     Open { | ||||
|         local: PeerState, | ||||
|         remote: PeerState, | ||||
|     }, | ||||
|     HalfClosedLocal(PeerState), | ||||
|     HalfClosedRemote(PeerState), | ||||
|     Closed, | ||||
| } | ||||
|  | ||||
| #[derive(Debug, Copy, Clone, Eq, PartialEq)] | ||||
| pub enum PeerState { | ||||
|     Headers, | ||||
|     Data, | ||||
| } | ||||
|  | ||||
| impl State { | ||||
|     /// Transition the state to represent headers being received. | ||||
|     /// | ||||
|     /// Returns true if this state transition results in iniitializing the | ||||
|     /// stream id. `Err` is returned if this is an invalid state transition. | ||||
|     pub fn recv_headers<P: Peer>(&mut self, eos: bool) -> Result<bool, ConnectionError> { | ||||
|         use self::State::*; | ||||
|         use self::PeerState::*; | ||||
|  | ||||
|         match *self { | ||||
|             Idle => { | ||||
|                 *self = if eos { | ||||
|                     HalfClosedRemote(Headers) | ||||
|                 } else { | ||||
|                     Open { | ||||
|                         local: Headers, | ||||
|                         remote: Data, | ||||
|                     } | ||||
|                 }; | ||||
|  | ||||
|                 Ok(true) | ||||
|             } | ||||
|             Open { local, remote } => { | ||||
|                 try!(remote.check_is_headers(Reason::ProtocolError)); | ||||
|  | ||||
|                 *self = if eos { | ||||
|                     HalfClosedRemote(local) | ||||
|                 } else { | ||||
|                     let remote = Data; | ||||
|                     Open { local, remote } | ||||
|                 }; | ||||
|  | ||||
|                 Ok(false) | ||||
|             } | ||||
|             HalfClosedLocal(remote) => { | ||||
|                 try!(remote.check_is_headers(Reason::ProtocolError)); | ||||
|  | ||||
|                 *self = if eos { | ||||
|                     Closed | ||||
|                 } else { | ||||
|                     HalfClosedLocal(Data) | ||||
|                 }; | ||||
|  | ||||
|                 Ok(false) | ||||
|             } | ||||
|             Closed | HalfClosedRemote(..) => { | ||||
|                 Err(Reason::ProtocolError.into()) | ||||
|             } | ||||
|             _ => unimplemented!(), | ||||
|         } | ||||
|     } | ||||
|  | ||||
|     /// Transition the state to represent headers being sent. | ||||
|     /// | ||||
|     /// Returns an error if this is an invalid state transition. | ||||
|     pub fn send_headers(&mut self) -> Result<(), ConnectionError> { | ||||
|         if *self != State::Idle { | ||||
|             unimplemented!(); | ||||
|         } | ||||
|     /// Returns true if this state transition results in initializing the stream | ||||
|     /// id. `Err` is returned if this is an invalid state transition. | ||||
|     pub fn send_headers<P: Peer>(&mut self, eos: bool) -> Result<bool, ConnectionError> { | ||||
|         use self::State::*; | ||||
|         use self::PeerState::*; | ||||
|  | ||||
|         *self = State::Open; | ||||
|         Ok(()) | ||||
|         match *self { | ||||
|             Idle => { | ||||
|                 *self = if eos { | ||||
|                     HalfClosedLocal(Headers) | ||||
|                 } else { | ||||
|                     Open { | ||||
|                         local: Data, | ||||
|                         remote: Headers, | ||||
|                     } | ||||
|                 }; | ||||
|  | ||||
|                 Ok(true) | ||||
|             } | ||||
|             Open { local, remote } => { | ||||
|                 try!(local.check_is_headers(Reason::InternalError)); | ||||
|  | ||||
|                 *self = if eos { | ||||
|                     HalfClosedLocal(remote) | ||||
|                 } else { | ||||
|                     let local = Data; | ||||
|                     Open { local, remote } | ||||
|                 }; | ||||
|  | ||||
|                 Ok(false) | ||||
|             } | ||||
|             HalfClosedRemote(local) => { | ||||
|                 try!(local.check_is_headers(Reason::InternalError)); | ||||
|  | ||||
|                 *self = if eos { | ||||
|                     Closed | ||||
|                 } else { | ||||
|                     HalfClosedRemote(Data) | ||||
|                 }; | ||||
|  | ||||
|                 Ok(false) | ||||
|             } | ||||
|             Closed | HalfClosedLocal(..) => { | ||||
|                 Err(Reason::InternalError.into()) | ||||
|             } | ||||
|             _ => unimplemented!(), | ||||
|         } | ||||
|     } | ||||
| } | ||||
|  | ||||
| impl PeerState { | ||||
|     #[inline] | ||||
|     fn check_is_headers(&self, err: Reason) -> Result<(), ConnectionError> { | ||||
|         use self::PeerState::*; | ||||
|  | ||||
|         match *self { | ||||
|             Headers => Ok(()), | ||||
|             _ => Err(err.into()), | ||||
|         } | ||||
|     } | ||||
| } | ||||
|  | ||||
|   | ||||
| @@ -102,15 +102,12 @@ impl Peer for Server { | ||||
|     type Send = http::response::Head; | ||||
|     type Poll = http::request::Head; | ||||
|  | ||||
|     fn check_initiating_id(id: StreamId) -> Result<(), ConnectionError> { | ||||
|         if id % 2 == 1 { | ||||
|             // Server stream identifiers must be even | ||||
|             unimplemented!(); | ||||
|         } | ||||
|     fn is_valid_local_stream_id(id: StreamId) -> bool { | ||||
|         id.is_server_initiated() | ||||
|     } | ||||
|  | ||||
|         // TODO: Ensure the `id` doesn't overflow u31 | ||||
|  | ||||
|         Ok(()) | ||||
|     fn is_valid_remote_stream_id(id: StreamId) -> bool { | ||||
|         id.is_client_initiated() | ||||
|     } | ||||
|  | ||||
|     fn convert_send_message( | ||||
|   | ||||
		Reference in New Issue
	
	Block a user