diff --git a/Cargo.toml b/Cargo.toml index 83ba178..d5f82e8 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -96,6 +96,7 @@ futures-core = { version = "0.3.0", default-features = false } futures-util = { version = "0.3.0", default-features = false } http-body = "0.4.0" hyper = { version = "0.14", default-features = false, features = ["tcp", "http1", "http2", "client", "runtime"] } +h2 = "0.3.10" lazy_static = "1.4" log = "0.4" mime = "0.3.16" diff --git a/src/async_impl/client.rs b/src/async_impl/client.rs index d1526d9..43772ae 100644 --- a/src/async_impl/client.rs +++ b/src/async_impl/client.rs @@ -22,7 +22,7 @@ use std::pin::Pin; use std::task::{Context, Poll}; use tokio::time::Sleep; -use log::debug; +use log::{debug, trace}; use super::decoder::Accepts; use super::request::{Request, RequestBuilder}; @@ -1418,6 +1418,8 @@ impl Client { urls: Vec::new(), + retry_count: 0, + client: self.inner.clone(), in_flight, @@ -1628,6 +1630,8 @@ pin_project! { urls: Vec, + retry_count: usize, + client: Arc, #[pin] @@ -1653,6 +1657,54 @@ impl PendingRequest { fn headers(self: Pin<&mut Self>) -> &mut HeaderMap { self.project().headers } + + fn retry_error(mut self: Pin<&mut Self>, err: &(dyn std::error::Error + 'static)) -> bool { + if !is_retryable_error(err) { + return false; + } + + trace!("can retry {:?}", err); + + let body = match self.body { + Some(Some(ref body)) => Body::reusable(body.clone()), + Some(None) => { + debug!("error was retryable, but body not reusable"); + return false; + } + None => Body::empty(), + }; + + if self.retry_count >= 2 { + trace!("retry count too high"); + return false; + } + self.retry_count += 1; + + let uri = expect_uri(&self.url); + let mut req = hyper::Request::builder() + .method(self.method.clone()) + .uri(uri) + .body(body.into_stream()) + .expect("valid request parts"); + + *req.headers_mut() = self.headers.clone(); + + *self.as_mut().in_flight().get_mut() = self.client.hyper.request(req); + + true + } +} + +fn is_retryable_error(err: &(dyn std::error::Error + 'static)) -> bool { + if let Some(cause) = err.source() { + if let Some(err) = cause.downcast_ref::() { + // They sent us a graceful shutdown, try with a new connection! + return err.is_go_away() + && err.is_remote() + && err.reason() == Some(h2::Reason::NO_ERROR); + } + } + false } impl Pending { @@ -1696,6 +1748,9 @@ impl Future for PendingRequest { loop { let res = match self.as_mut().in_flight().as_mut().poll(cx) { Poll::Ready(Err(e)) => { + if self.as_mut().retry_error(&e) { + continue; + } return Poll::Ready(Err(crate::error::request(e).with_url(self.url.clone()))); } Poll::Ready(Ok(res)) => res,