| @@ -62,6 +62,13 @@ where I: AsyncRead + AsyncWrite, | |||||||
|         } |         } | ||||||
|     } |     } | ||||||
|  |  | ||||||
|  |     fn can_write_continue(&self) -> bool { | ||||||
|  |         match self.state.writing { | ||||||
|  |             Writing::Continue(..) => true, | ||||||
|  |             _ => false, | ||||||
|  |         } | ||||||
|  |     } | ||||||
|  |  | ||||||
|     fn can_read_body(&self) -> bool { |     fn can_read_body(&self) -> bool { | ||||||
|         match self.state.reading { |         match self.state.reading { | ||||||
|             Reading::Body(..) => true, |             Reading::Body(..) => true, | ||||||
| @@ -105,6 +112,10 @@ where I: AsyncRead + AsyncWrite, | |||||||
|                     } |                     } | ||||||
|                 }; |                 }; | ||||||
|                 self.state.busy(); |                 self.state.busy(); | ||||||
|  |                 if head.expecting_continue() { | ||||||
|  |                     let msg = b"HTTP/1.1 100 Continue\r\n\r\n"; | ||||||
|  |                     self.state.writing = Writing::Continue(Cursor::new(msg)); | ||||||
|  |                 } | ||||||
|                 let wants_keep_alive = head.should_keep_alive(); |                 let wants_keep_alive = head.should_keep_alive(); | ||||||
|                 self.state.keep_alive &= wants_keep_alive; |                 self.state.keep_alive &= wants_keep_alive; | ||||||
|                 let (body, reading) = if decoder.is_eof() { |                 let (body, reading) = if decoder.is_eof() { | ||||||
| @@ -172,6 +183,7 @@ where I: AsyncRead + AsyncWrite, | |||||||
|         } |         } | ||||||
|  |  | ||||||
|         match self.state.writing { |         match self.state.writing { | ||||||
|  |             Writing::Continue(..) | | ||||||
|             Writing::Body(..) | |             Writing::Body(..) | | ||||||
|             Writing::Ending(..) => return, |             Writing::Ending(..) => return, | ||||||
|             Writing::Init | |             Writing::Init | | ||||||
| @@ -191,7 +203,7 @@ where I: AsyncRead + AsyncWrite, | |||||||
|  |  | ||||||
|     fn can_write_head(&self) -> bool { |     fn can_write_head(&self) -> bool { | ||||||
|         match self.state.writing { |         match self.state.writing { | ||||||
|             Writing::Init => true, |             Writing::Continue(..) | Writing::Init => true, | ||||||
|             _ => false |             _ => false | ||||||
|         } |         } | ||||||
|     } |     } | ||||||
| @@ -199,6 +211,7 @@ where I: AsyncRead + AsyncWrite, | |||||||
|     fn can_write_body(&self) -> bool { |     fn can_write_body(&self) -> bool { | ||||||
|         match self.state.writing { |         match self.state.writing { | ||||||
|             Writing::Body(..) => true, |             Writing::Body(..) => true, | ||||||
|  |             Writing::Continue(..) | | ||||||
|             Writing::Init | |             Writing::Init | | ||||||
|             Writing::Ending(..) | |             Writing::Ending(..) | | ||||||
|             Writing::KeepAlive | |             Writing::KeepAlive | | ||||||
| @@ -227,6 +240,13 @@ where I: AsyncRead + AsyncWrite, | |||||||
|  |  | ||||||
|         let wants_keep_alive = head.should_keep_alive(); |         let wants_keep_alive = head.should_keep_alive(); | ||||||
|         self.state.keep_alive &= wants_keep_alive; |         self.state.keep_alive &= wants_keep_alive; | ||||||
|  |         // if a 100-continue has started but not finished sending, tack the | ||||||
|  |         // remainder on to the start of the buffer. | ||||||
|  |         if let Writing::Continue(ref pending) = self.state.writing { | ||||||
|  |             if pending.has_started() { | ||||||
|  |                 self.io.write_buf_mut().extend_from_slice(pending.buf()); | ||||||
|  |             } | ||||||
|  |         } | ||||||
|         let encoder = T::encode(head, self.io.write_buf_mut()); |         let encoder = T::encode(head, self.io.write_buf_mut()); | ||||||
|         self.state.writing = if body { |         self.state.writing = if body { | ||||||
|             Writing::Body(encoder, None) |             Writing::Body(encoder, None) | ||||||
| @@ -290,6 +310,15 @@ where I: AsyncRead + AsyncWrite, | |||||||
|     fn write_queued(&mut self) -> Poll<(), io::Error> { |     fn write_queued(&mut self) -> Poll<(), io::Error> { | ||||||
|         trace!("Conn::write_queued()"); |         trace!("Conn::write_queued()"); | ||||||
|         let state = match self.state.writing { |         let state = match self.state.writing { | ||||||
|  |             Writing::Continue(ref mut queued) => { | ||||||
|  |                 let n = self.io.buffer(queued.buf()); | ||||||
|  |                 queued.consume(n); | ||||||
|  |                 if queued.is_written() { | ||||||
|  |                     Writing::Init | ||||||
|  |                 } else { | ||||||
|  |                     return Ok(Async::NotReady); | ||||||
|  |                 } | ||||||
|  |             } | ||||||
|             Writing::Body(ref mut encoder, ref mut queued) => { |             Writing::Body(ref mut encoder, ref mut queued) => { | ||||||
|                 let complete = if let Some(chunk) = queued.as_mut() { |                 let complete = if let Some(chunk) = queued.as_mut() { | ||||||
|                     let n = try_nb!(encoder.encode(&mut self.io, chunk.buf())); |                     let n = try_nb!(encoder.encode(&mut self.io, chunk.buf())); | ||||||
| @@ -349,24 +378,28 @@ where I: AsyncRead + AsyncWrite, | |||||||
|         trace!("Conn::poll()"); |         trace!("Conn::poll()"); | ||||||
|         self.state.read_task.take(); |         self.state.read_task.take(); | ||||||
|  |  | ||||||
|  |         loop { | ||||||
|             if self.is_read_closed() { |             if self.is_read_closed() { | ||||||
|                 trace!("Conn::poll when closed"); |                 trace!("Conn::poll when closed"); | ||||||
|             Ok(Async::Ready(None)) |                 return Ok(Async::Ready(None)); | ||||||
|             } else if self.can_read_head() { |             } else if self.can_read_head() { | ||||||
|             self.read_head() |                 return self.read_head(); | ||||||
|  |             } else if self.can_write_continue() { | ||||||
|  |                 try_nb!(self.flush()); | ||||||
|             } else if self.can_read_body() { |             } else if self.can_read_body() { | ||||||
|             self.read_body() |                 return self.read_body() | ||||||
|                     .map(|async| async.map(|chunk| Some(Frame::Body { |                     .map(|async| async.map(|chunk| Some(Frame::Body { | ||||||
|                         chunk: chunk |                         chunk: chunk | ||||||
|                     }))) |                     }))) | ||||||
|                     .or_else(|err| { |                     .or_else(|err| { | ||||||
|                         self.state.close_read(); |                         self.state.close_read(); | ||||||
|                         Ok(Async::Ready(Some(Frame::Error { error: err.into() }))) |                         Ok(Async::Ready(Some(Frame::Error { error: err.into() }))) | ||||||
|                 }) |                     }); | ||||||
|             } else { |             } else { | ||||||
|                 trace!("poll when on keep-alive"); |                 trace!("poll when on keep-alive"); | ||||||
|                 self.maybe_park_read(); |                 self.maybe_park_read(); | ||||||
|             Ok(Async::NotReady) |                 return Ok(Async::NotReady); | ||||||
|  |             } | ||||||
|         } |         } | ||||||
|     } |     } | ||||||
| } | } | ||||||
| @@ -467,6 +500,7 @@ enum Reading { | |||||||
| } | } | ||||||
|  |  | ||||||
| enum Writing<B> { | enum Writing<B> { | ||||||
|  |     Continue(Cursor<&'static [u8]>), | ||||||
|     Init, |     Init, | ||||||
|     Body(Encoder, Option<Cursor<B>>), |     Body(Encoder, Option<Cursor<B>>), | ||||||
|     Ending(Cursor<&'static [u8]>), |     Ending(Cursor<&'static [u8]>), | ||||||
| @@ -488,6 +522,9 @@ impl<B: AsRef<[u8]>, K: fmt::Debug> fmt::Debug for State<B, K> { | |||||||
| impl<B: AsRef<[u8]>> fmt::Debug for Writing<B> { | impl<B: AsRef<[u8]>> fmt::Debug for Writing<B> { | ||||||
|     fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { |     fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { | ||||||
|         match *self { |         match *self { | ||||||
|  |             Writing::Continue(ref buf) => f.debug_tuple("Continue") | ||||||
|  |                 .field(buf) | ||||||
|  |                 .finish(), | ||||||
|             Writing::Init => f.write_str("Init"), |             Writing::Init => f.write_str("Init"), | ||||||
|             Writing::Body(ref enc, ref queued) => f.debug_tuple("Body") |             Writing::Body(ref enc, ref queued) => f.debug_tuple("Body") | ||||||
|                 .field(enc) |                 .field(enc) | ||||||
|   | |||||||
| @@ -181,6 +181,10 @@ impl<T: AsRef<[u8]>> Cursor<T> { | |||||||
|         } |         } | ||||||
|     } |     } | ||||||
|  |  | ||||||
|  |     pub fn has_started(&self) -> bool { | ||||||
|  |         self.pos != 0 | ||||||
|  |     } | ||||||
|  |  | ||||||
|     pub fn is_written(&self) -> bool { |     pub fn is_written(&self) -> bool { | ||||||
|         trace!("Cursor::is_written pos = {}, len = {}", self.pos, self.bytes.as_ref().len()); |         trace!("Cursor::is_written pos = {}, len = {}", self.pos, self.bytes.as_ref().len()); | ||||||
|         self.pos >= self.bytes.as_ref().len() |         self.pos >= self.bytes.as_ref().len() | ||||||
|   | |||||||
| @@ -4,7 +4,7 @@ use std::fmt; | |||||||
|  |  | ||||||
| use bytes::BytesMut; | use bytes::BytesMut; | ||||||
|  |  | ||||||
| use header::{Connection, ConnectionOption}; | use header::{Connection, ConnectionOption, Expect}; | ||||||
| use header::Headers; | use header::Headers; | ||||||
| use method::Method; | use method::Method; | ||||||
| use status::StatusCode; | use status::StatusCode; | ||||||
| @@ -56,6 +56,10 @@ impl<S> MessageHead<S> { | |||||||
|     pub fn should_keep_alive(&self) -> bool { |     pub fn should_keep_alive(&self) -> bool { | ||||||
|         should_keep_alive(self.version, &self.headers) |         should_keep_alive(self.version, &self.headers) | ||||||
|     } |     } | ||||||
|  |  | ||||||
|  |     pub fn expecting_continue(&self) -> bool { | ||||||
|  |         expecting_continue(self.version, &self.headers) | ||||||
|  |     } | ||||||
| } | } | ||||||
|  |  | ||||||
| impl ResponseHead { | impl ResponseHead { | ||||||
| @@ -119,6 +123,17 @@ pub fn should_keep_alive(version: HttpVersion, headers: &Headers) -> bool { | |||||||
|     ret |     ret | ||||||
| } | } | ||||||
|  |  | ||||||
|  | /// Checks if a connection is expecting a `100 Continue` before sending its body. | ||||||
|  | #[inline] | ||||||
|  | pub fn expecting_continue(version: HttpVersion, headers: &Headers) -> bool { | ||||||
|  |     let ret = match (version, headers.get::<Expect>()) { | ||||||
|  |         (Http11, Some(&Expect::Continue)) => true, | ||||||
|  |         _ => false | ||||||
|  |     }; | ||||||
|  |     trace!("expecting_continue(version={:?}, header={:?}) = {:?}", version, headers.get::<Expect>(), ret); | ||||||
|  |     ret | ||||||
|  | } | ||||||
|  |  | ||||||
| #[derive(Debug)] | #[derive(Debug)] | ||||||
| pub enum ServerTransaction {} | pub enum ServerTransaction {} | ||||||
|  |  | ||||||
| @@ -168,3 +183,15 @@ fn test_should_keep_alive() { | |||||||
|     assert!(should_keep_alive(Http10, &headers)); |     assert!(should_keep_alive(Http10, &headers)); | ||||||
|     assert!(should_keep_alive(Http11, &headers)); |     assert!(should_keep_alive(Http11, &headers)); | ||||||
| } | } | ||||||
|  |  | ||||||
|  | #[test] | ||||||
|  | fn test_expecting_continue() { | ||||||
|  |     let mut headers = Headers::new(); | ||||||
|  |  | ||||||
|  |     assert!(!expecting_continue(Http10, &headers)); | ||||||
|  |     assert!(!expecting_continue(Http11, &headers)); | ||||||
|  |  | ||||||
|  |     headers.set(Expect::Continue); | ||||||
|  |     assert!(!expecting_continue(Http10, &headers)); | ||||||
|  |     assert!(expecting_continue(Http11, &headers)); | ||||||
|  | } | ||||||
|   | |||||||
| @@ -523,3 +523,33 @@ fn test_server_disable_keep_alive() { | |||||||
|         } |         } | ||||||
|     } |     } | ||||||
| } | } | ||||||
|  |  | ||||||
|  | #[test] | ||||||
|  | fn expect_continue() { | ||||||
|  |     let server = serve(); | ||||||
|  |     let mut req = connect(server.addr()); | ||||||
|  |     server.reply().status(hyper::Ok); | ||||||
|  |  | ||||||
|  |     req.write_all(b"\ | ||||||
|  |         POST /foo HTTP/1.1\r\n\ | ||||||
|  |         Host: example.domain\r\n\ | ||||||
|  |         Expect: 100-continue\r\n\ | ||||||
|  |         Content-Length: 5\r\n\ | ||||||
|  |         Connection: Close\r\n\ | ||||||
|  |         \r\n\ | ||||||
|  |     ").expect("write 1"); | ||||||
|  |  | ||||||
|  |     let msg = b"HTTP/1.1 100 Continue\r\n\r\n"; | ||||||
|  |     let mut buf = vec![0; msg.len()]; | ||||||
|  |     req.read_exact(&mut buf).expect("read 1"); | ||||||
|  |     assert_eq!(buf, msg); | ||||||
|  |  | ||||||
|  |     let msg = b"hello"; | ||||||
|  |     req.write_all(msg).expect("write 2"); | ||||||
|  |  | ||||||
|  |     let mut body = String::new(); | ||||||
|  |     req.read_to_string(&mut body).expect("read 2"); | ||||||
|  |  | ||||||
|  |     let body = server.body(); | ||||||
|  |     assert_eq!(body, msg); | ||||||
|  | } | ||||||
|   | |||||||
		Reference in New Issue
	
	Block a user