feat(server): add Expect 100-continue support
Adds a new method to `Handler`, with a default implementation of always responding with a `100 Continue` when sent an expectation. Closes #369
This commit is contained in:
		
							
								
								
									
										37
									
								
								src/header/common/expect.rs
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										37
									
								
								src/header/common/expect.rs
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,37 @@ | ||||
| use std::fmt; | ||||
|  | ||||
| use header::{Header, HeaderFormat}; | ||||
|  | ||||
| /// The `Expect` header. | ||||
| /// | ||||
| /// > The "Expect" header field in a request indicates a certain set of | ||||
| /// > behaviors (expectations) that need to be supported by the server in | ||||
| /// > order to properly handle this request.  The only such expectation | ||||
| /// > defined by this specification is 100-continue. | ||||
| /// > | ||||
| /// >    Expect  = "100-continue" | ||||
| #[derive(Copy, Clone, PartialEq, Debug)] | ||||
| pub enum Expect { | ||||
|     /// The value `100-continue`. | ||||
|     Continue | ||||
| } | ||||
|  | ||||
| impl Header for Expect { | ||||
|     fn header_name() -> &'static str { | ||||
|         "Expect" | ||||
|     } | ||||
|  | ||||
|     fn parse_header(raw: &[Vec<u8>]) -> Option<Expect> { | ||||
|         if &[b"100-continue"] == raw { | ||||
|             Some(Expect::Continue) | ||||
|         } else { | ||||
|             None | ||||
|         } | ||||
|     } | ||||
| } | ||||
|  | ||||
| impl HeaderFormat for Expect { | ||||
|     fn fmt_header(&self, f: &mut fmt::Formatter) -> fmt::Result { | ||||
|         f.write_str("100-continue") | ||||
|     } | ||||
| } | ||||
| @@ -20,6 +20,7 @@ pub use self::content_type::ContentType; | ||||
| pub use self::cookie::Cookie; | ||||
| pub use self::date::Date; | ||||
| pub use self::etag::Etag; | ||||
| pub use self::expect::Expect; | ||||
| pub use self::expires::Expires; | ||||
| pub use self::host::Host; | ||||
| pub use self::if_match::IfMatch; | ||||
| @@ -160,6 +161,7 @@ mod content_length; | ||||
| mod content_type; | ||||
| mod date; | ||||
| mod etag; | ||||
| mod expect; | ||||
| mod expires; | ||||
| mod host; | ||||
| mod if_match; | ||||
|   | ||||
| @@ -1,5 +1,5 @@ | ||||
| //! HTTP Server | ||||
| use std::io::{BufReader, BufWriter}; | ||||
| use std::io::{BufReader, BufWriter, Write}; | ||||
| use std::marker::PhantomData; | ||||
| use std::net::{IpAddr, SocketAddr}; | ||||
| use std::path::Path; | ||||
| @@ -14,9 +14,12 @@ pub use net::{Fresh, Streaming}; | ||||
|  | ||||
| use HttpError::HttpIoError; | ||||
| use {HttpResult}; | ||||
| use header::Connection; | ||||
| use header::{Headers, Connection, Expect}; | ||||
| use header::ConnectionOption::{Close, KeepAlive}; | ||||
| use method::Method; | ||||
| use net::{NetworkListener, NetworkStream, HttpListener}; | ||||
| use status::StatusCode; | ||||
| use uri::RequestUri; | ||||
| use version::HttpVersion::{Http10, Http11}; | ||||
|  | ||||
| use self::listener::ListenerPool; | ||||
| @@ -99,7 +102,7 @@ S: NetworkStream + Clone + Send> Server<'a, H, L> { | ||||
|  | ||||
|         debug!("threads = {:?}", threads); | ||||
|         let pool = ListenerPool::new(listener.clone()); | ||||
|         let work = move |stream| keep_alive_loop(stream, &handler); | ||||
|         let work = move |mut stream| handle_connection(&mut stream, &handler); | ||||
|  | ||||
|         let guard = thread::scoped(move || pool.accept(work, threads)); | ||||
|  | ||||
| @@ -111,7 +114,7 @@ S: NetworkStream + Clone + Send> Server<'a, H, L> { | ||||
| } | ||||
|  | ||||
|  | ||||
| fn keep_alive_loop<'h, S, H>(mut stream: S, handler: &'h H) | ||||
| fn handle_connection<'h, S, H>(mut stream: &mut S, handler: &'h H) | ||||
| where S: NetworkStream + Clone, H: Handler { | ||||
|     debug!("Incoming stream"); | ||||
|     let addr = match stream.peer_addr() { | ||||
| @@ -128,41 +131,47 @@ where S: NetworkStream + Clone, H: Handler { | ||||
|  | ||||
|     let mut keep_alive = true; | ||||
|     while keep_alive { | ||||
|         keep_alive = handle_connection(addr, &mut rdr, &mut wrt, handler); | ||||
|         let req = match Request::new(&mut rdr, addr) { | ||||
|             Ok(req) => req, | ||||
|             Err(e@HttpIoError(_)) => { | ||||
|                 debug!("ioerror in keepalive loop = {:?}", e); | ||||
|                 break; | ||||
|             } | ||||
|             Err(e) => { | ||||
|                 //TODO: send a 400 response | ||||
|                 error!("request error = {:?}", e); | ||||
|                 break; | ||||
|             } | ||||
|         }; | ||||
|  | ||||
|         if req.version == Http11 && req.headers.get() == Some(&Expect::Continue) { | ||||
|             let status = handler.check_continue((&req.method, &req.uri, &req.headers)); | ||||
|             match write!(&mut wrt, "{} {}\r\n\r\n", Http11, status) { | ||||
|                 Ok(..) => (), | ||||
|                 Err(e) => { | ||||
|                     error!("error writing 100-continue: {:?}", e); | ||||
|                     break; | ||||
|                 } | ||||
|             } | ||||
|  | ||||
|             if status != StatusCode::Continue { | ||||
|                 debug!("non-100 status ({}) for Expect 100 request", status); | ||||
|                 break; | ||||
|             } | ||||
|         } | ||||
|  | ||||
|         keep_alive = match (req.version, req.headers.get::<Connection>()) { | ||||
|             (Http10, Some(conn)) if !conn.contains(&KeepAlive) => false, | ||||
|             (Http11, Some(conn)) if conn.contains(&Close)  => false, | ||||
|             _ => true | ||||
|         }; | ||||
|         let mut res = Response::new(&mut wrt); | ||||
|         res.version = req.version; | ||||
|         handler.handle(req, res); | ||||
|         debug!("keep_alive = {:?}", keep_alive); | ||||
|     } | ||||
| } | ||||
|  | ||||
| fn handle_connection<'a, 'aa, 'h, S, H>( | ||||
|     addr: SocketAddr, | ||||
|     rdr: &'a mut BufReader<&'aa mut NetworkStream>, | ||||
|     wrt: &mut BufWriter<S>, | ||||
|     handler: &'h H | ||||
| ) -> bool where 'aa: 'a, S: NetworkStream, H: Handler { | ||||
|     let mut res = Response::new(wrt); | ||||
|     let req = match Request::<'a, 'aa>::new(rdr, addr) { | ||||
|         Ok(req) => req, | ||||
|         Err(e@HttpIoError(_)) => { | ||||
|             debug!("ioerror in keepalive loop = {:?}", e); | ||||
|             return false; | ||||
|         } | ||||
|         Err(e) => { | ||||
|             //TODO: send a 400 response | ||||
|             error!("request error = {:?}", e); | ||||
|             return false; | ||||
|         } | ||||
|     }; | ||||
|  | ||||
|     let keep_alive = match (req.version, req.headers.get::<Connection>()) { | ||||
|         (Http10, Some(conn)) if !conn.contains(&KeepAlive) => false, | ||||
|         (Http11, Some(conn)) if conn.contains(&Close)  => false, | ||||
|         _ => true | ||||
|     }; | ||||
|     res.version = req.version; | ||||
|     handler.handle(req, res); | ||||
|     keep_alive | ||||
| } | ||||
|  | ||||
| /// A listening server, which can later be closed. | ||||
| pub struct Listening { | ||||
|     _guard: JoinGuard<'static, ()>, | ||||
| @@ -184,11 +193,78 @@ pub trait Handler: Sync + Send { | ||||
|     /// Receives a `Request`/`Response` pair, and should perform some action on them. | ||||
|     /// | ||||
|     /// This could reading from the request, and writing to the response. | ||||
|     fn handle<'a, 'aa, 'b, 's>(&'s self, Request<'aa, 'a>, Response<'b, Fresh>); | ||||
|     fn handle<'a, 'k>(&'a self, Request<'a, 'k>, Response<'a, Fresh>); | ||||
|  | ||||
|     /// Called when a Request includes a `Expect: 100-continue` header. | ||||
|     /// | ||||
|     /// By default, this will always immediately response with a `StatusCode::Continue`, | ||||
|     /// but can be overridden with custom behavior. | ||||
|     fn check_continue(&self, _: (&Method, &RequestUri, &Headers)) -> StatusCode { | ||||
|         StatusCode::Continue | ||||
|     } | ||||
| } | ||||
|  | ||||
| impl<F> Handler for F where F: Fn(Request, Response<Fresh>), F: Sync + Send { | ||||
|     fn handle<'a, 'aa, 'b, 's>(&'s self, req: Request<'a, 'aa>, res: Response<'b, Fresh>) { | ||||
|     fn handle<'a, 'k>(&'a self, req: Request<'a, 'k>, res: Response<'a, Fresh>) { | ||||
|         self(req, res) | ||||
|     } | ||||
| } | ||||
|  | ||||
| #[cfg(test)] | ||||
| mod tests { | ||||
|     use header::Headers; | ||||
|     use method::Method; | ||||
|     use mock::MockStream; | ||||
|     use status::StatusCode; | ||||
|     use uri::RequestUri; | ||||
|  | ||||
|     use super::{Request, Response, Fresh, Handler, handle_connection}; | ||||
|  | ||||
|     #[test] | ||||
|     fn test_check_continue_default() { | ||||
|         let mut mock = MockStream::with_input(b"\ | ||||
|             POST /upload HTTP/1.1\r\n\ | ||||
|             Host: example.domain\r\n\ | ||||
|             Expect: 100-continue\r\n\ | ||||
|             Content-Length: 10\r\n\ | ||||
|             \r\n\ | ||||
|             1234567890\ | ||||
|         "); | ||||
|  | ||||
|         fn handle(_: Request, res: Response<Fresh>) { | ||||
|             res.start().unwrap().end().unwrap(); | ||||
|         } | ||||
|  | ||||
|         handle_connection(&mut mock, &handle); | ||||
|         let cont = b"HTTP/1.1 100 Continue\r\n\r\n"; | ||||
|         assert_eq!(&mock.write[..cont.len()], cont); | ||||
|         let res = b"HTTP/1.1 200 OK\r\n"; | ||||
|         assert_eq!(&mock.write[cont.len()..cont.len() + res.len()], res); | ||||
|     } | ||||
|  | ||||
|     #[test] | ||||
|     fn test_check_continue_reject() { | ||||
|         struct Reject; | ||||
|         impl Handler for Reject { | ||||
|             fn handle<'a, 'k>(&'a self, _: Request<'a, 'k>, res: Response<'a, Fresh>) { | ||||
|                 res.start().unwrap().end().unwrap(); | ||||
|             } | ||||
|  | ||||
|             fn check_continue(&self, _: (&Method, &RequestUri, &Headers)) -> StatusCode { | ||||
|                 StatusCode::ExpectationFailed | ||||
|             } | ||||
|         } | ||||
|  | ||||
|         let mut mock = MockStream::with_input(b"\ | ||||
|             POST /upload HTTP/1.1\r\n\ | ||||
|             Host: example.domain\r\n\ | ||||
|             Expect: 100-continue\r\n\ | ||||
|             Content-Length: 10\r\n\ | ||||
|             \r\n\ | ||||
|             1234567890\ | ||||
|         "); | ||||
|  | ||||
|         handle_connection(&mut mock, &Reject); | ||||
|         assert_eq!(mock.write, b"HTTP/1.1 417 Expectation Failed\r\n\r\n"); | ||||
|     } | ||||
| } | ||||
|   | ||||
		Reference in New Issue
	
	Block a user