diff --git a/src/connect.rs b/src/connect.rs index 351bd73..c4e7eb7 100644 --- a/src/connect.rs +++ b/src/connect.rs @@ -153,7 +153,7 @@ impl Connect for Connector { #[cfg(feature = "default-tls")] Inner::DefaultTls(http, tls) => if dst.scheme() == "https" { #[cfg(feature = "default-tls")] - use connect_async::TlsConnectorExt; + use self::native_tls_async::TlsConnectorExt; let host = dst.host().to_owned(); let port = dst.port().unwrap_or(443); @@ -297,6 +297,140 @@ fn tunnel_eof() -> io::Error { ) } +#[cfg(feature = "default-tls")] +mod native_tls_async { + use std::io::{self, Read, Write}; + + use futures::{Poll, Future, Async}; + use native_tls::{self, HandshakeError, Error, TlsConnector}; + use tokio_io::{AsyncRead, AsyncWrite}; + + /// A wrapper around an underlying raw stream which implements the TLS or SSL + /// protocol. + /// + /// A `TlsStream` represents a handshake that has been completed successfully + /// and both the server and the client are ready for receiving and sending + /// data. Bytes read from a `TlsStream` are decrypted from `S` and bytes written + /// to a `TlsStream` are encrypted when passing through to `S`. + #[derive(Debug)] + pub struct TlsStream { + inner: native_tls::TlsStream, + } + + /// Future returned from `TlsConnectorExt::connect_async` which will resolve + /// once the connection handshake has finished. + pub struct ConnectAsync { + inner: MidHandshake, + } + + struct MidHandshake { + inner: Option, HandshakeError>>, + } + + /// Extension trait for the `TlsConnector` type in the `native_tls` crate. + pub trait TlsConnectorExt: sealed::Sealed { + /// Connects the provided stream with this connector, assuming the provided + /// domain. + /// + /// This function will internally call `TlsConnector::connect` to connect + /// the stream and returns a future representing the resolution of the + /// connection operation. The returned future will resolve to either + /// `TlsStream` or `Error` depending if it's successful or not. + /// + /// This is typically used for clients who have already established, for + /// example, a TCP connection to a remote server. That stream is then + /// provided here to perform the client half of a connection to a + /// TLS-powered server. + /// + /// # Compatibility notes + /// + /// Note that this method currently requires `S: Read + Write` but it's + /// highly recommended to ensure that the object implements the `AsyncRead` + /// and `AsyncWrite` traits as well, otherwise this function will not work + /// properly. + fn connect_async(&self, domain: &str, stream: S) -> ConnectAsync + where S: Read + Write; // TODO: change to AsyncRead + AsyncWrite + } + + mod sealed { + pub trait Sealed {} + } + + impl Read for TlsStream { + fn read(&mut self, buf: &mut [u8]) -> io::Result { + self.inner.read(buf) + } + } + + impl Write for TlsStream { + fn write(&mut self, buf: &[u8]) -> io::Result { + self.inner.write(buf) + } + + fn flush(&mut self) -> io::Result<()> { + self.inner.flush() + } + } + + + impl AsyncRead for TlsStream { + } + + impl AsyncWrite for TlsStream { + fn shutdown(&mut self) -> Poll<(), io::Error> { + try_nb!(self.inner.shutdown()); + self.inner.get_mut().shutdown() + } + } + + impl TlsConnectorExt for TlsConnector { + fn connect_async(&self, domain: &str, stream: S) -> ConnectAsync + where S: Read + Write, + { + ConnectAsync { + inner: MidHandshake { + inner: Some(self.connect(domain, stream)), + }, + } + } + } + + impl sealed::Sealed for TlsConnector {} + + // TODO: change this to AsyncRead/AsyncWrite on next major version + impl Future for ConnectAsync { + type Item = TlsStream; + type Error = Error; + + fn poll(&mut self) -> Poll, Error> { + self.inner.poll() + } + } + + // TODO: change this to AsyncRead/AsyncWrite on next major version + impl Future for MidHandshake { + type Item = TlsStream; + type Error = Error; + + fn poll(&mut self) -> Poll, Error> { + match self.inner.take().expect("cannot poll MidHandshake twice") { + Ok(stream) => Ok(TlsStream { inner: stream }.into()), + Err(HandshakeError::Failure(e)) => Err(e), + Err(HandshakeError::WouldBlock(s)) => { + match s.handshake() { + Ok(stream) => Ok(TlsStream { inner: stream }.into()), + Err(HandshakeError::Failure(e)) => Err(e), + Err(HandshakeError::WouldBlock(s)) => { + self.inner = Some(Err(HandshakeError::WouldBlock(s))); + Ok(Async::NotReady) + } + } + } + } + } + } +} + #[cfg(feature = "tls")] #[cfg(test)] mod tests { diff --git a/src/connect_async.rs b/src/connect_async.rs deleted file mode 100644 index ee9c769..0000000 --- a/src/connect_async.rs +++ /dev/null @@ -1,130 +0,0 @@ -use std::io::{self, Read, Write}; - -use futures::{Poll, Future, Async}; -use native_tls::{self, HandshakeError, Error, TlsConnector}; -use tokio_io::{AsyncRead, AsyncWrite}; - -/// A wrapper around an underlying raw stream which implements the TLS or SSL -/// protocol. -/// -/// A `TlsStream` represents a handshake that has been completed successfully -/// and both the server and the client are ready for receiving and sending -/// data. Bytes read from a `TlsStream` are decrypted from `S` and bytes written -/// to a `TlsStream` are encrypted when passing through to `S`. -#[derive(Debug)] -pub struct TlsStream { - inner: native_tls::TlsStream, -} - -/// Future returned from `TlsConnectorExt::connect_async` which will resolve -/// once the connection handshake has finished. -pub struct ConnectAsync { - inner: MidHandshake, -} - -struct MidHandshake { - inner: Option, HandshakeError>>, -} - -/// Extension trait for the `TlsConnector` type in the `native_tls` crate. -pub trait TlsConnectorExt: sealed::Sealed { - /// Connects the provided stream with this connector, assuming the provided - /// domain. - /// - /// This function will internally call `TlsConnector::connect` to connect - /// the stream and returns a future representing the resolution of the - /// connection operation. The returned future will resolve to either - /// `TlsStream` or `Error` depending if it's successful or not. - /// - /// This is typically used for clients who have already established, for - /// example, a TCP connection to a remote server. That stream is then - /// provided here to perform the client half of a connection to a - /// TLS-powered server. - /// - /// # Compatibility notes - /// - /// Note that this method currently requires `S: Read + Write` but it's - /// highly recommended to ensure that the object implements the `AsyncRead` - /// and `AsyncWrite` traits as well, otherwise this function will not work - /// properly. - fn connect_async(&self, domain: &str, stream: S) -> ConnectAsync - where S: Read + Write; // TODO: change to AsyncRead + AsyncWrite -} - -mod sealed { - pub trait Sealed {} -} - -impl Read for TlsStream { - fn read(&mut self, buf: &mut [u8]) -> io::Result { - self.inner.read(buf) - } -} - -impl Write for TlsStream { - fn write(&mut self, buf: &[u8]) -> io::Result { - self.inner.write(buf) - } - - fn flush(&mut self) -> io::Result<()> { - self.inner.flush() - } -} - - -impl AsyncRead for TlsStream { -} - -impl AsyncWrite for TlsStream { - fn shutdown(&mut self) -> Poll<(), io::Error> { - try_nb!(self.inner.shutdown()); - self.inner.get_mut().shutdown() - } -} - -impl TlsConnectorExt for TlsConnector { - fn connect_async(&self, domain: &str, stream: S) -> ConnectAsync - where S: Read + Write, - { - ConnectAsync { - inner: MidHandshake { - inner: Some(self.connect(domain, stream)), - }, - } - } -} - -impl sealed::Sealed for TlsConnector {} - -// TODO: change this to AsyncRead/AsyncWrite on next major version -impl Future for ConnectAsync { - type Item = TlsStream; - type Error = Error; - - fn poll(&mut self) -> Poll, Error> { - self.inner.poll() - } -} - -// TODO: change this to AsyncRead/AsyncWrite on next major version -impl Future for MidHandshake { - type Item = TlsStream; - type Error = Error; - - fn poll(&mut self) -> Poll, Error> { - match self.inner.take().expect("cannot poll MidHandshake twice") { - Ok(stream) => Ok(TlsStream { inner: stream }.into()), - Err(HandshakeError::Failure(e)) => Err(e), - Err(HandshakeError::WouldBlock(s)) => { - match s.handshake() { - Ok(stream) => Ok(TlsStream { inner: stream }.into()), - Err(HandshakeError::Failure(e)) => Err(e), - Err(HandshakeError::WouldBlock(s)) => { - self.inner = Some(Err(HandshakeError::WouldBlock(s))); - Ok(Async::NotReady) - } - } - } - } - } -} diff --git a/src/lib.rs b/src/lib.rs index 788430c..0c9c2fc 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -236,8 +236,6 @@ mod error; mod async_impl; mod connect; -#[cfg(feature = "default-tls")] -mod connect_async; mod body; mod client; #[cfg(feature = "trust-dns")]