diff --git a/src/proto/h1/role.rs b/src/proto/h1/role.rs index 86e9bc95..dd21b623 100644 --- a/src/proto/h1/role.rs +++ b/src/proto/h1/role.rs @@ -191,7 +191,9 @@ where // This is because Service only allows returning a single Response, and // so if you try to reply with a e.g. 100 Continue, you have no way of // replying with the latter status code response. - let (ret, mut is_last) = if StatusCode::SWITCHING_PROTOCOLS == msg.head.subject { + let is_upgrade = msg.head.subject == StatusCode::SWITCHING_PROTOCOLS + || (msg.req_method == &Some(Method::CONNECT) && msg.head.subject.is_success()); + let (ret, mut is_last) = if is_upgrade { (T::on_encode_upgrade(&mut msg), true) } else if msg.head.subject.is_informational() { error!("response with 1xx status code not supported"); @@ -851,12 +853,20 @@ impl OnUpgrade for YesUpgrades { impl OnUpgrade for NoUpgrades { fn on_encode_upgrade(msg: &mut Encode) -> ::Result<()> { - error!("response with 101 status code not supported"); *msg.head = MessageHead::default(); msg.head.subject = ::StatusCode::INTERNAL_SERVER_ERROR; msg.body = None; - //TODO: replace with more descriptive error - Err(::Error::new_status()) + + if msg.head.subject == StatusCode::SWITCHING_PROTOCOLS { + error!("response with 101 status code not supported"); + Err(Parse::UpgradeNotSupported.into()) + } else if msg.req_method == &Some(Method::CONNECT) { + error!("200 response to CONNECT request not supported"); + Err(::Error::new_user_unsupported_request_method()) + } else { + debug_assert!(false, "upgrade incorrectly detected"); + Err(::Error::new_status()) + } } fn on_decode_upgrade() -> Result { @@ -1309,6 +1319,39 @@ mod tests { assert_eq!(vec, b"GET / HTTP/1.1\r\nContent-Length: 10\r\nContent-Type: application/json\r\n\r\n".to_vec()); } + #[test] + fn test_server_no_upgrades_connect_method() { + let mut head = MessageHead::default(); + + let mut vec = Vec::new(); + let err = Server::encode(Encode { + head: &mut head, + body: None, + keep_alive: true, + req_method: &mut Some(Method::CONNECT), + title_case_headers: false, + }, &mut vec).unwrap_err(); + + assert!(err.is_user()); + assert_eq!(err.kind(), &::error::Kind::UnsupportedRequestMethod); + } + + #[test] + fn test_server_yes_upgrades_connect_method() { + let mut head = MessageHead::default(); + + let mut vec = Vec::new(); + let encoder = S::::encode(Encode { + head: &mut head, + body: None, + keep_alive: true, + req_method: &mut Some(Method::CONNECT), + title_case_headers: false, + }, &mut vec).unwrap(); + + assert!(encoder.is_last()); + } + #[cfg(feature = "nightly")] use test::Bencher; diff --git a/tests/server.rs b/tests/server.rs index 09c2e9f5..000c7098 100644 --- a/tests/server.rs +++ b/tests/server.rs @@ -1147,6 +1147,70 @@ fn upgrades() { assert_eq!(vec, b"bar=foo"); } +#[test] +fn http_connect() { + use tokio_io::io::{read_to_end, write_all}; + let _ = pretty_env_logger::try_init(); + let runtime = Runtime::new().unwrap(); + let listener = tcp_bind(&"127.0.0.1:0".parse().unwrap(), &runtime.reactor()).unwrap(); + let addr = listener.local_addr().unwrap(); + let (tx, rx) = oneshot::channel(); + + thread::spawn(move || { + let mut tcp = connect(&addr); + tcp.write_all(b"\ + CONNECT localhost:80 HTTP/1.1\r\n\ + \r\n\ + eagerly optimistic\ + ").expect("write 1"); + let mut buf = [0; 256]; + tcp.read(&mut buf).expect("read 1"); + + let expected = "HTTP/1.1 200 OK\r\n"; + assert_eq!(s(&buf[..expected.len()]), expected); + let _ = tx.send(()); + + let n = tcp.read(&mut buf).expect("read 2"); + assert_eq!(s(&buf[..n]), "foo=bar"); + tcp.write_all(b"bar=foo").expect("write 2"); + }); + + let fut = listener.incoming() + .into_future() + .map_err(|_| -> hyper::Error { unreachable!() }) + .and_then(|(item, _incoming)| { + let socket = item.unwrap(); + let conn = Http::new() + .serve_connection(socket, service_fn(|_| { + let res = Response::builder() + .status(200) + .body(hyper::Body::empty()) + .unwrap(); + Ok::<_, hyper::Error>(res) + })); + + let mut conn_opt = Some(conn); + future::poll_fn(move || { + try_ready!(conn_opt.as_mut().unwrap().poll_without_shutdown()); + // conn is done with HTTP now + Ok(conn_opt.take().unwrap().into()) + }) + }); + + let conn = fut.wait().unwrap(); + + // wait so that we don't write until other side saw 101 response + rx.wait().unwrap(); + + let parts = conn.into_parts(); + let io = parts.io; + assert_eq!(parts.read_buf, "eagerly optimistic"); + + let io = write_all(io, b"foo=bar").wait().unwrap().0; + let vec = read_to_end(io, vec![]).wait().unwrap().1; + assert_eq!(vec, b"bar=foo"); +} + #[test] fn parse_errors_send_4xx_response() { let runtime = Runtime::new().unwrap();