diff --git a/invertedai/utils.py b/invertedai/utils.py index 48af183c..fb862826 100644 --- a/invertedai/utils.py +++ b/invertedai/utils.py @@ -6,6 +6,7 @@ import csv import math import logging +import random import time from typing import Dict, Optional, List, Tuple @@ -76,6 +77,7 @@ def __init__(self): self._status_force_list = [408, 429, 500, 502, 503, 504] self._base_backoff = 1 # Base backoff time in seconds self._backoff_factor = 2 + self._jitter_factor = 0.5 self._current_backoff = self._base_backoff self._max_backoff = None @@ -135,6 +137,15 @@ def max_backoff(self): def max_backoff(self, value): self._max_backoff = value + @property + def jitter_factor(self): + return self._jitter_factor + + @jitter_factor.setter + def jitter_factor(self, value): + self._jitter_factor = value + + def should_log(self, retry_count): return retry_count == 0 or math.log2(retry_count).is_integer() @@ -245,26 +256,37 @@ def _request( try: retries = 0 while retries < self.max_retries: - response = self.session.request( - method=method, - params=params, - url=self.base_url + relative_path, - headers=headers, - data=data, - json=json_body, - ) - if response.status_code not in self.status_force_list: + try: + response = self.session.request( + method=method, + params=params, + url=self.base_url + relative_path, + headers=headers, + data=data, + json=json_body, + ) + except (requests.exceptions.Timeout, requests.exceptions.ConnectionError) as e: + logger.warning("Error communicating with IAI, will retry.") + response = None + if response is not None and response.status_code not in self.status_force_list: self.current_backoff = max( self.base_backoff, self.current_backoff / self.backoff_factor ) response.raise_for_status() break else: + if self.jitter_factor is not None: + jitter = random.uniform(-self.jitter_factor, self.jitter_factor) + else: + jitter = 0 if self.should_log(retries): - logger.warning( - f"Retrying {relative_path}: Status {response.status_code}, Message {STATUS_MESSAGE.get(response.status_code, response.text)} Retry #{retries + 1}, Backoff {self.current_backoff} seconds" - ) - time.sleep(self.current_backoff) + if response is not None: + logger.warning( + f"Retrying {relative_path}: Status {response.status_code}, Message {STATUS_MESSAGE.get(response.status_code, response.text)} Retry #{retries + 1}, Backoff {self.current_backoff} seconds" + ) + else: + logger.warning(f"Retrying {relative_path}: No response received, Retry #{retries + 1}, Backoff {self.current_backoff} seconds") + time.sleep(min(self.current_backoff * (1 + jitter), self.max_backoff if self.max_backoff is not None else float("inf"))) self.current_backoff *= self.backoff_factor if self.max_backoff is not None: self.current_backoff = min( @@ -272,7 +294,11 @@ def _request( ) retries += 1 else: - response.raise_for_status() + if response is not None: + response.raise_for_status() + else: + error.APIConnectionError( + "Error communicating with IAI", should_retry=True) except requests.exceptions.ConnectionError as e: diff --git a/invertedai_cpp/invertedai/session.cc b/invertedai_cpp/invertedai/session.cc index 165e150e..2af8fd9c 100644 --- a/invertedai_cpp/invertedai/session.cc +++ b/invertedai_cpp/invertedai/session.cc @@ -5,6 +5,7 @@ #include #include #include +#include #include #include @@ -97,9 +98,10 @@ const std::string Session::request( const std::string &url_query_string, double max_retries, const std::vector& status_force_list, - int base_backoff, - int backoff_factor, - int max_backoff + double base_backoff, + double backoff_factor, + double max_backoff, + double jitter_factor ) { std::string target = subdomain + mode + url_query_string; @@ -113,6 +115,7 @@ const std::string Session::request( req.set("accept", "application/json"); req.set("x-api-key", this->api_key_); req.set("x-client-version", INVERTEDAI_VERSION); + req.set("Connection","keep-alive"); if (debug_mode) { std::cout << "req body content:\n"; std::cout << body_str << std::endl; @@ -138,7 +141,11 @@ const std::string Session::request( else{ http::read(this->ssl_stream_, buffer, res, ec); } - if (!(res.result() == http::status::ok)) { + std::cout << mode << " " << res.result() << " "<< ec << " " << res.result_int() << std::endl; + if (!(res.result() == http::status::ok) || ec) { + if (res.result_int() == 500) { + this->connect(); + } if (std::find(status_force_list.begin(), status_force_list.end(), res.result_int()) != status_force_list.end() || ec) { int delay_seconds = base_backoff * std::pow(backoff_factor, retry_count); if (max_backoff > 0 && delay_seconds > max_backoff) { diff --git a/invertedai_cpp/invertedai/session.h b/invertedai_cpp/invertedai/session.h index 60b88053..aa287c97 100644 --- a/invertedai_cpp/invertedai/session.h +++ b/invertedai_cpp/invertedai/session.h @@ -30,10 +30,11 @@ class Session { const bool local_mode = iai_dev && (std::string(iai_dev) == "1" || std::string(iai_dev) == "True"); double max_retries = std::numeric_limits::infinity(); // Allows for infinite retries by default std::vector status_force_list = {408, 429, 500, 502, 503, 504}; - int base_backoff = 1; // Base backoff time in seconds - int backoff_factor = 2; - int current_backoff = base_backoff; - int max_backoff = 0; // No max backoff by default, 0 signifies no limit + double base_backoff = 1; // Base backoff time in seconds + double backoff_factor = 2; + double current_backoff = base_backoff; + double max_backoff = 0; // No max backoff by default, 0 signifies no limit + double jitter_factor = 0.5; public: const char* host_ = local_mode ? "localhost" : "api.inverted.ai"; @@ -42,7 +43,9 @@ class Session { const int version_ = 11; explicit Session(net::io_context &ioc, ssl::context &ctx) - : resolver_(ioc), ssl_stream_(ioc, ctx), tcp_stream_(ioc){}; + : resolver_(ioc), ssl_stream_(ioc, ctx), tcp_stream_(ioc){ + tcp_stream_.expires_never(); + }; /** * Set your own api key here. @@ -78,9 +81,10 @@ class Session { const std::string &url_params, double max_retries = std::numeric_limits::infinity(), const std::vector& status_force_list = {408, 429, 500, 502, 503, 504}, - int base_backoff = 1, - int backoff_factor = 2, - int max_backoff = 0 // No max by default + double base_backoff = 1, + double backoff_factor = 2, + double max_backoff = 0, // No max by default + double jitter_factor = 0.5 ); };