Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions indico/config/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,10 @@ class IndicoConfig:
api_token= (str, optional): The actual text of the API Token. Takes precedence over api_token_path
verify_ssl= (bool, optional): Whether to verify the host's SSL certificate. Default=True
requests_params= (dict, optional): Dictionary of requests. Session parameters to set
retry_count= (int, optional): Retry API calls this many times.
retry_wait= (float, optional): Wait this many seconds after the first error before retrying.
retry_backoff= (float, optional): Multiply the wait time by this amount for each additional error.
retry_jitter= (float, optional): Add a random amount of time (up to this percent as a decimal) to the wait time to prevent simultaneous retries.

Returns:
IndicoConfig object
Expand All @@ -42,6 +46,11 @@ def __init__(self, **kwargs: "Any"):
self.requests_params: "Optional[AnyDict]" = None
self._disable_cookie_domain: bool = False

self.retry_count: int = int(os.getenv("INDICO_RETRY_COUNT", "4"))
self.retry_wait: float = float(os.getenv("INDICO_RETRY_WAIT", "1"))
self.retry_backoff: float = float(os.getenv("INDICO_RETRY_BACKOFF", "4"))
self.retry_jitter: float = float(os.getenv("INDICO_RETRY_JITTER", "1"))

for key, value in kwargs.items():
if hasattr(self, key):
setattr(self, key, value)
Expand Down
21 changes: 19 additions & 2 deletions indico/http/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
)
from indico.http.serialization import aio_deserialize, deserialize

from .retry import aioretry
from .retry import retry

if TYPE_CHECKING: # pragma: no cover
from http.cookiejar import Cookie
Expand Down Expand Up @@ -50,6 +50,14 @@ class HTTPClient:
def __init__(self, config: "Optional[IndicoConfig]" = None):
self.config = config or IndicoConfig()
self.base_url = f"{self.config.protocol}://{self.config.host}"
self._decorate_with_retry = retry(
requests.RequestException,
count=self.config.retry_count,
wait=self.config.retry_wait,
backoff=self.config.retry_backoff,
jitter=self.config.retry_jitter,
)
self._make_request = self._decorate_with_retry(self._make_request) # type: ignore[method-assign]

self.request_session = requests.Session()
if isinstance(self.config.requests_params, dict):
Expand Down Expand Up @@ -169,6 +177,7 @@ def _make_request(
f"{self.base_url}{path}",
headers=headers,
stream=True,
timeout=(4, 64),
verify=False
if not self.config.verify_ssl or not self.request_session.verify
else True,
Expand Down Expand Up @@ -232,6 +241,14 @@ def __init__(self, config: "Optional[IndicoConfig]" = None):
"""
self.config = config or IndicoConfig()
self.base_url = f"{self.config.protocol}://{self.config.host}"
self._decorate_with_retry = retry(
aiohttp.ClientConnectionError,
count=self.config.retry_count,
wait=self.config.retry_wait,
backoff=self.config.retry_backoff,
jitter=self.config.retry_jitter,
)
self._make_request = self._decorate_with_retry(self._make_request) # type: ignore[method-assign]

self.request_session = aiohttp.ClientSession()
if isinstance(self.config.requests_params, dict):
Expand Down Expand Up @@ -316,7 +333,6 @@ def _handle_files(
for f in files:
f.close()

@aioretry(aiohttp.ClientConnectionError, aiohttp.ServerDisconnectedError)
async def _make_request(
self,
method: str,
Expand Down Expand Up @@ -346,6 +362,7 @@ async def _make_request(
async with getattr(self.request_session, method)(
f"{self.base_url}{path}",
headers=headers,
timeout=aiohttp.ClientTimeout(sock_connect=4, sock_read=64),
verify_ssl=self.config.verify_ssl,
**request_kwargs,
) as response:
Expand Down
147 changes: 81 additions & 66 deletions indico/http/retry.py
Original file line number Diff line number Diff line change
@@ -1,85 +1,100 @@
import asyncio
import time
from functools import wraps
from random import randint
from typing import TYPE_CHECKING
from inspect import iscoroutinefunction
from random import random
from typing import TYPE_CHECKING, overload

if TYPE_CHECKING: # pragma: no cover
from typing import Awaitable, Callable, Optional, Tuple, Type, TypeVar, Union
if TYPE_CHECKING:
import sys
from collections.abc import Awaitable, Callable
from typing import Type

from typing_extensions import ParamSpec
if sys.version_info >= (3, 10):
from typing import ParamSpec, TypeVar
else:
from typing_extensions import ParamSpec, TypeVar

P = ParamSpec("P")
T = TypeVar("T")
ArgumentsType = ParamSpec("ArgumentsType")
OuterReturnType = TypeVar("OuterReturnType")
InnerReturnType = TypeVar("InnerReturnType")


def retry(
*ExceptionTypes: "Type[Exception]", tries: int = 3, delay: int = 1, backoff: int = 2
) -> "Callable[[Callable[P, T]], Callable[P, T]]":
*errors: "Type[Exception]",
count: int,
wait: float,
backoff: float,
jitter: float,
) -> "Callable[[Callable[ArgumentsType, OuterReturnType]], Callable[ArgumentsType, OuterReturnType]]": # noqa: E501
"""
Retry with exponential backoff
Decorate a function or coroutine to retry when it raises specified errors,
apply exponential backoff and jitter to the wait time,
and raise the last error if it retries too many times.

Original from: http://wiki.python.org/moin/PythonDecoratorLibrary#Retry
Arguments:
errors: Retry the function when it raises one of these errors.
count: Retry the function this many times.
wait: Wait this many seconds after the first error before retrying.
backoff: Multiply the wait time by this amount for each additional error.
jitter: Add a random amount of time (up to this percent as a decimal)
to the wait time to prevent simultaneous retries.
"""

def retry_decorator(f: "Callable[P, T]") -> "Callable[P, T]":
@wraps(f)
def retry_fn(*args: "P.args", **kwargs: "P.kwargs") -> "T":
n_tries, n_delay = tries, delay
while n_tries > 1:
try:
return f(*args, **kwargs)
except ExceptionTypes:
time.sleep(n_delay)
n_tries -= 1
n_delay *= backoff
return f(*args, **kwargs)

return retry_fn
def wait_time(times_retried: int) -> float:
"""
Calculate the sleep time based on number of times retried.
"""
return wait * backoff**times_retried * (1 + jitter * random())

return retry_decorator
@overload
def retry_decorator(
decorated: "Callable[ArgumentsType, Awaitable[InnerReturnType]]",
) -> "Callable[ArgumentsType, Awaitable[InnerReturnType]]": ...

@overload
def retry_decorator(
decorated: "Callable[ArgumentsType, InnerReturnType]",
) -> "Callable[ArgumentsType, InnerReturnType]": ...

def aioretry(
*ExceptionTypes: "Type[Exception]",
tries: int = 3,
delay: "Union[int, Tuple[int, int]]" = 1,
backoff: int = 2,
condition: "Optional[Callable[[Exception], bool]]" = None,
) -> "Callable[[Callable[P, Awaitable[T]]], Callable[P, Awaitable[T]]]":
"""
Retry with exponential backoff

Original from: http://wiki.python.org/moin/PythonDecoratorLibrary#Retry
Options:
condition: Callable to evaluate if an exception of a given type
is retryable for additional handling
delay: an initial time to wait (seconds). If a tuple, choose a random number
in that range to start. This can helps prevent retries at the exact
same time across multiple concurrent function calls
"""
def retry_decorator(
decorated: "Callable[ArgumentsType, InnerReturnType]",
) -> "Callable[ArgumentsType, Awaitable[InnerReturnType]] | Callable[ArgumentsType, InnerReturnType]": # noqa: E501
"""
Decorate either a function or coroutine as appropriate.
"""
if iscoroutinefunction(decorated):

@wraps(decorated)
async def retrying_coroutine( # type: ignore[return]
*args: "ArgumentsType.args", **kwargs: "ArgumentsType.kwargs"
) -> "InnerReturnType":
for times_retried in range(count + 1):
try:
return await decorated(*args, **kwargs) # type: ignore[no-any-return]
except errors:
if times_retried >= count:
raise

await asyncio.sleep(wait_time(times_retried))

return retrying_coroutine

else:

@wraps(decorated)
def retrying_function( # type: ignore[return]
*args: "ArgumentsType.args", **kwargs: "ArgumentsType.kwargs"
) -> "InnerReturnType":
for times_retried in range(count + 1):
try:
return decorated(*args, **kwargs)
except errors:
if times_retried >= count:
raise

time.sleep(wait_time(times_retried))

def retry_decorator(f: "Callable[P, Awaitable[T]]") -> "Callable[P, Awaitable[T]]":
@wraps(f)
async def retry_fn(*args: "P.args", **kwargs: "P.kwargs") -> "T":
n_tries = tries
if isinstance(delay, tuple):
# pick a random number to sleep
n_delay = randint(*delay)
else:
n_delay = delay
while True:
try:
return await f(*args, **kwargs)
except ExceptionTypes as e:
if condition and not condition(e):
raise
await asyncio.sleep(n_delay)
n_tries -= 1
n_delay *= backoff
if n_tries <= 0:
raise

return retry_fn
return retrying_function

return retry_decorator
73 changes: 73 additions & 0 deletions tests/unit/http/test_retry.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
import pytest

from indico.http.retry import retry


def test_no_errors() -> None:
@retry(Exception, count=0, wait=0, backoff=0, jitter=0)
def no_errors() -> bool:
return True

assert no_errors()


def test_raises_errors() -> None:
calls = 0

@retry(RuntimeError, count=4, wait=0, backoff=0, jitter=0)
def raises_errors() -> None:
nonlocal calls
calls += 1
raise RuntimeError()

with pytest.raises(RuntimeError):
raises_errors()

assert calls == 5


def test_raises_other_errors() -> None:
calls = 0

@retry(RuntimeError, count=4, wait=0, backoff=0, jitter=0)
def raises_errors() -> None:
nonlocal calls
calls += 1
raise ValueError()

with pytest.raises(ValueError):
raises_errors()

assert calls == 1


@pytest.mark.asyncio
async def test_raises_errors_async() -> None:
calls = 0

@retry(RuntimeError, count=4, wait=0, backoff=0, jitter=0)
async def raises_errors() -> None:
nonlocal calls
calls += 1
raise RuntimeError()

with pytest.raises(RuntimeError):
await raises_errors()

assert calls == 5


@pytest.mark.asyncio
async def test_raises_other_errors_async() -> None:
calls = 0

@retry(RuntimeError, count=4, wait=0, backoff=0, jitter=0)
async def raises_errors() -> None:
nonlocal calls
calls += 1
raise ValueError()

with pytest.raises(ValueError):
await raises_errors()

assert calls == 1