Skip to content

Commit

Permalink
Add retry logic (#33)
Browse files Browse the repository at this point in the history
* Add retry logic
* windows_sys only on windows
* Fix build
  • Loading branch information
dpaoliello authored Dec 30, 2023
1 parent c0a5f8c commit 8134f51
Show file tree
Hide file tree
Showing 2 changed files with 110 additions and 14 deletions.
4 changes: 3 additions & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,10 @@ serde = "1.0"
serde_json = "1.0"
sys-info = "0.9"
tokio = { version = "1.35", features = ["rt", "net", "time", "rt-multi-thread"] }
windows-sys = {version = "0.52", features = ["Win32_Foundation", "Win32_Security_Credentials"] }

# Build openssl from source instead of linking it.
# Required for cross-compilation.
native-tls = { version = "0.2", features = ["vendored"] }

[target.'cfg(target_os = "windows")'.dependencies]
windows-sys = {version = "0.52", features = ["Win32_Foundation", "Win32_Networking_WinSock", "Win32_Security_Credentials"] }
120 changes: 107 additions & 13 deletions src/http.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
use std::{error::Error, time::Duration};

use anyhow::{Context, Result};
use bytes::Bytes;
use reqwest::{StatusCode, Url};
use reqwest::{RequestBuilder, Response, StatusCode, Url};

pub struct Client {
inner: reqwest::Client,
Expand All @@ -13,14 +15,65 @@ impl Client {
}
}

fn should_retry(response: &reqwest::Result<Response>) -> bool {
match response {
Ok(response) => {
// Retry on server error
response.status().is_server_error()
}
Err(err) => {
// Retry on timeout.
if err.is_timeout() {
return true;
}

if let Some(err) = err.source() {
if let Some(err) = err.downcast_ref::<std::io::Error>() {
match err.raw_os_error() {
// Retry on DNS lookup failure.
#[cfg(windows)]
Some(windows_sys::Win32::Networking::WinSock::WSAHOST_NOT_FOUND) => {
return true
}
_ => {}
}
}
}

false
}
}
}

async fn send_with_retry(
&self,
make_request: impl Fn(&reqwest::Client) -> RequestBuilder,
) -> reqwest::Result<Response> {
const MAX_RETRIES: u32 = 5;
const RETRY_DELAY: Duration = if cfg!(test) {
Duration::from_millis(5)
} else {
Duration::from_millis(500)
};
let mut retries = 0;

loop {
let response = make_request(&self.inner).send().await;

if retries < MAX_RETRIES && Client::should_retry(&response) {
tokio::time::sleep(RETRY_DELAY.saturating_mul(retries)).await;
retries += 1;
} else {
break response;
}
}
}

pub async fn get<T>(&self, token: &str, url: Url) -> Result<T>
where
T: serde::de::DeserializeOwned,
{
self.inner
.get(url)
.bearer_auth(token)
.send()
self.send_with_retry(|client| client.get(url.clone()).bearer_auth(token))
.await
.with_context(|| "Sending request failed")?
.error_for_status()?
Expand All @@ -39,10 +92,7 @@ impl Client {
T: serde::de::DeserializeOwned,
{
let response = self
.inner
.post(url)
.form(parameters)
.send()
.send_with_retry(|client| client.post(url.clone()).form(parameters))
.await
.with_context(|| "Sending request failed")?;

Expand All @@ -59,17 +109,61 @@ impl Client {

pub async fn download(&self, token: &str, url: Url) -> Result<Bytes> {
Ok(self
.inner
.get(url)
.bearer_auth(token)
.send()
.send_with_retry(|client| client.get(url.clone()).bearer_auth(token))
.await
.with_context(|| "Sending request failed")?
.bytes()
.await?)
}
}

#[tokio::test]
async fn retry_after_server_error() {
let mut server = mockito::Server::new();
let url = server.url();

let fail_mock = server
.mock("GET", "/error")
.with_status(500)
.expect(1)
.create();
let success_mock = server
.mock("GET", "/success")
.with_status(200)
.expect(1)
.create();

let client = Client::new();
let response = client
.send_with_retry(|client| {
let url = if !fail_mock.matched() {
format!("{url}/error")
} else {
format!("{url}/success")
};
client.get(url)
})
.await;

fail_mock.assert();
success_mock.assert();
assert_eq!(response.unwrap().status(), 200);
}

#[tokio::test]
async fn retry_always_error() {
let mut server = mockito::Server::new();
let url = server.url();

let mock = server.mock("GET", "/").with_status(500).expect(6).create();

let client = Client::new();
let response = client.send_with_retry(|client| client.get(&url)).await;

assert_eq!(response.unwrap().status(), 500);
mock.assert();
}

pub trait AppendPaths {
fn append_path(&self, path: &str) -> Self;
fn append_paths(&self, paths: &[&str]) -> Self;
Expand Down

0 comments on commit 8134f51

Please sign in to comment.