Merge pull request #377 from hyperium/check-continue
feat(server): add Expect 100-continue support
This commit is contained in:
37
src/header/common/expect.rs
Normal file
37
src/header/common/expect.rs
Normal file
@@ -0,0 +1,37 @@
|
||||
use std::fmt;
|
||||
|
||||
use header::{Header, HeaderFormat};
|
||||
|
||||
/// The `Expect` header.
|
||||
///
|
||||
/// > The "Expect" header field in a request indicates a certain set of
|
||||
/// > behaviors (expectations) that need to be supported by the server in
|
||||
/// > order to properly handle this request. The only such expectation
|
||||
/// > defined by this specification is 100-continue.
|
||||
/// >
|
||||
/// > Expect = "100-continue"
|
||||
#[derive(Copy, Clone, PartialEq, Debug)]
|
||||
pub enum Expect {
|
||||
/// The value `100-continue`.
|
||||
Continue
|
||||
}
|
||||
|
||||
impl Header for Expect {
|
||||
fn header_name() -> &'static str {
|
||||
"Expect"
|
||||
}
|
||||
|
||||
fn parse_header(raw: &[Vec<u8>]) -> Option<Expect> {
|
||||
if &[b"100-continue"] == raw {
|
||||
Some(Expect::Continue)
|
||||
} else {
|
||||
None
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl HeaderFormat for Expect {
|
||||
fn fmt_header(&self, f: &mut fmt::Formatter) -> fmt::Result {
|
||||
f.write_str("100-continue")
|
||||
}
|
||||
}
|
||||
@@ -20,6 +20,7 @@ pub use self::content_type::ContentType;
|
||||
pub use self::cookie::Cookie;
|
||||
pub use self::date::Date;
|
||||
pub use self::etag::Etag;
|
||||
pub use self::expect::Expect;
|
||||
pub use self::expires::Expires;
|
||||
pub use self::host::Host;
|
||||
pub use self::if_match::IfMatch;
|
||||
@@ -160,6 +161,7 @@ mod content_length;
|
||||
mod content_type;
|
||||
mod date;
|
||||
mod etag;
|
||||
mod expect;
|
||||
mod expires;
|
||||
mod host;
|
||||
mod if_match;
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
//! HTTP Server
|
||||
use std::io::{BufReader, BufWriter};
|
||||
use std::io::{BufReader, BufWriter, Write};
|
||||
use std::marker::PhantomData;
|
||||
use std::net::{IpAddr, SocketAddr};
|
||||
use std::path::Path;
|
||||
@@ -14,9 +14,12 @@ pub use net::{Fresh, Streaming};
|
||||
|
||||
use HttpError::HttpIoError;
|
||||
use {HttpResult};
|
||||
use header::Connection;
|
||||
use header::{Headers, Connection, Expect};
|
||||
use header::ConnectionOption::{Close, KeepAlive};
|
||||
use method::Method;
|
||||
use net::{NetworkListener, NetworkStream, HttpListener};
|
||||
use status::StatusCode;
|
||||
use uri::RequestUri;
|
||||
use version::HttpVersion::{Http10, Http11};
|
||||
|
||||
use self::listener::ListenerPool;
|
||||
@@ -99,7 +102,7 @@ S: NetworkStream + Clone + Send> Server<'a, H, L> {
|
||||
|
||||
debug!("threads = {:?}", threads);
|
||||
let pool = ListenerPool::new(listener.clone());
|
||||
let work = move |stream| keep_alive_loop(stream, &handler);
|
||||
let work = move |mut stream| handle_connection(&mut stream, &handler);
|
||||
|
||||
let guard = thread::scoped(move || pool.accept(work, threads));
|
||||
|
||||
@@ -111,7 +114,7 @@ S: NetworkStream + Clone + Send> Server<'a, H, L> {
|
||||
}
|
||||
|
||||
|
||||
fn keep_alive_loop<'h, S, H>(mut stream: S, handler: &'h H)
|
||||
fn handle_connection<'h, S, H>(mut stream: &mut S, handler: &'h H)
|
||||
where S: NetworkStream + Clone, H: Handler {
|
||||
debug!("Incoming stream");
|
||||
let addr = match stream.peer_addr() {
|
||||
@@ -128,41 +131,47 @@ where S: NetworkStream + Clone, H: Handler {
|
||||
|
||||
let mut keep_alive = true;
|
||||
while keep_alive {
|
||||
keep_alive = handle_connection(addr, &mut rdr, &mut wrt, handler);
|
||||
let req = match Request::new(&mut rdr, addr) {
|
||||
Ok(req) => req,
|
||||
Err(e@HttpIoError(_)) => {
|
||||
debug!("ioerror in keepalive loop = {:?}", e);
|
||||
break;
|
||||
}
|
||||
Err(e) => {
|
||||
//TODO: send a 400 response
|
||||
error!("request error = {:?}", e);
|
||||
break;
|
||||
}
|
||||
};
|
||||
|
||||
if req.version == Http11 && req.headers.get() == Some(&Expect::Continue) {
|
||||
let status = handler.check_continue((&req.method, &req.uri, &req.headers));
|
||||
match write!(&mut wrt, "{} {}\r\n\r\n", Http11, status) {
|
||||
Ok(..) => (),
|
||||
Err(e) => {
|
||||
error!("error writing 100-continue: {:?}", e);
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
if status != StatusCode::Continue {
|
||||
debug!("non-100 status ({}) for Expect 100 request", status);
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
keep_alive = match (req.version, req.headers.get::<Connection>()) {
|
||||
(Http10, Some(conn)) if !conn.contains(&KeepAlive) => false,
|
||||
(Http11, Some(conn)) if conn.contains(&Close) => false,
|
||||
_ => true
|
||||
};
|
||||
let mut res = Response::new(&mut wrt);
|
||||
res.version = req.version;
|
||||
handler.handle(req, res);
|
||||
debug!("keep_alive = {:?}", keep_alive);
|
||||
}
|
||||
}
|
||||
|
||||
fn handle_connection<'a, 'aa, 'h, S, H>(
|
||||
addr: SocketAddr,
|
||||
rdr: &'a mut BufReader<&'aa mut NetworkStream>,
|
||||
wrt: &mut BufWriter<S>,
|
||||
handler: &'h H
|
||||
) -> bool where 'aa: 'a, S: NetworkStream, H: Handler {
|
||||
let mut res = Response::new(wrt);
|
||||
let req = match Request::<'a, 'aa>::new(rdr, addr) {
|
||||
Ok(req) => req,
|
||||
Err(e@HttpIoError(_)) => {
|
||||
debug!("ioerror in keepalive loop = {:?}", e);
|
||||
return false;
|
||||
}
|
||||
Err(e) => {
|
||||
//TODO: send a 400 response
|
||||
error!("request error = {:?}", e);
|
||||
return false;
|
||||
}
|
||||
};
|
||||
|
||||
let keep_alive = match (req.version, req.headers.get::<Connection>()) {
|
||||
(Http10, Some(conn)) if !conn.contains(&KeepAlive) => false,
|
||||
(Http11, Some(conn)) if conn.contains(&Close) => false,
|
||||
_ => true
|
||||
};
|
||||
res.version = req.version;
|
||||
handler.handle(req, res);
|
||||
keep_alive
|
||||
}
|
||||
|
||||
/// A listening server, which can later be closed.
|
||||
pub struct Listening {
|
||||
_guard: JoinGuard<'static, ()>,
|
||||
@@ -184,11 +193,78 @@ pub trait Handler: Sync + Send {
|
||||
/// Receives a `Request`/`Response` pair, and should perform some action on them.
|
||||
///
|
||||
/// This could reading from the request, and writing to the response.
|
||||
fn handle<'a, 'aa, 'b, 's>(&'s self, Request<'aa, 'a>, Response<'b, Fresh>);
|
||||
fn handle<'a, 'k>(&'a self, Request<'a, 'k>, Response<'a, Fresh>);
|
||||
|
||||
/// Called when a Request includes a `Expect: 100-continue` header.
|
||||
///
|
||||
/// By default, this will always immediately response with a `StatusCode::Continue`,
|
||||
/// but can be overridden with custom behavior.
|
||||
fn check_continue(&self, _: (&Method, &RequestUri, &Headers)) -> StatusCode {
|
||||
StatusCode::Continue
|
||||
}
|
||||
}
|
||||
|
||||
impl<F> Handler for F where F: Fn(Request, Response<Fresh>), F: Sync + Send {
|
||||
fn handle<'a, 'aa, 'b, 's>(&'s self, req: Request<'a, 'aa>, res: Response<'b, Fresh>) {
|
||||
fn handle<'a, 'k>(&'a self, req: Request<'a, 'k>, res: Response<'a, Fresh>) {
|
||||
self(req, res)
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use header::Headers;
|
||||
use method::Method;
|
||||
use mock::MockStream;
|
||||
use status::StatusCode;
|
||||
use uri::RequestUri;
|
||||
|
||||
use super::{Request, Response, Fresh, Handler, handle_connection};
|
||||
|
||||
#[test]
|
||||
fn test_check_continue_default() {
|
||||
let mut mock = MockStream::with_input(b"\
|
||||
POST /upload HTTP/1.1\r\n\
|
||||
Host: example.domain\r\n\
|
||||
Expect: 100-continue\r\n\
|
||||
Content-Length: 10\r\n\
|
||||
\r\n\
|
||||
1234567890\
|
||||
");
|
||||
|
||||
fn handle(_: Request, res: Response<Fresh>) {
|
||||
res.start().unwrap().end().unwrap();
|
||||
}
|
||||
|
||||
handle_connection(&mut mock, &handle);
|
||||
let cont = b"HTTP/1.1 100 Continue\r\n\r\n";
|
||||
assert_eq!(&mock.write[..cont.len()], cont);
|
||||
let res = b"HTTP/1.1 200 OK\r\n";
|
||||
assert_eq!(&mock.write[cont.len()..cont.len() + res.len()], res);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_check_continue_reject() {
|
||||
struct Reject;
|
||||
impl Handler for Reject {
|
||||
fn handle<'a, 'k>(&'a self, _: Request<'a, 'k>, res: Response<'a, Fresh>) {
|
||||
res.start().unwrap().end().unwrap();
|
||||
}
|
||||
|
||||
fn check_continue(&self, _: (&Method, &RequestUri, &Headers)) -> StatusCode {
|
||||
StatusCode::ExpectationFailed
|
||||
}
|
||||
}
|
||||
|
||||
let mut mock = MockStream::with_input(b"\
|
||||
POST /upload HTTP/1.1\r\n\
|
||||
Host: example.domain\r\n\
|
||||
Expect: 100-continue\r\n\
|
||||
Content-Length: 10\r\n\
|
||||
\r\n\
|
||||
1234567890\
|
||||
");
|
||||
|
||||
handle_connection(&mut mock, &Reject);
|
||||
assert_eq!(mock.write, b"HTTP/1.1 417 Expectation Failed\r\n\r\n");
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user