diff --git a/indico/config/config.py b/indico/config/config.py
index 04713351..f346c7f3 100644
--- a/indico/config/config.py
+++ b/indico/config/config.py
@@ -22,6 +22,10 @@ class IndicoConfig:
         api_token= (str, optional): The actual text of the API Token. Takes precedence over api_token_path
         verify_ssl= (bool, optional): Whether to verify the host's SSL certificate. Default=True
         requests_params= (dict, optional): Dictionary of requests. Session parameters to set
+        retry_count= (int, optional): Retry API calls this many times.
+        retry_wait= (float, optional): Wait this many seconds after the first error before retrying.
+        retry_backoff= (float, optional): Multiply the wait time by this amount for each additional error.
+        retry_jitter= (float, optional): Add a random amount of time (up to this percent as a decimal) to the wait time to prevent simultaneous retries.
 
     Returns:
         IndicoConfig object
@@ -42,6 +46,11 @@ def __init__(self, **kwargs: "Any"):
         self.requests_params: "Optional[AnyDict]" = None
         self._disable_cookie_domain: bool = False
 
+        self.retry_count: int = int(os.getenv("INDICO_RETRY_COUNT", "4"))
+        self.retry_wait: float = float(os.getenv("INDICO_RETRY_WAIT", "1"))
+        self.retry_backoff: float = float(os.getenv("INDICO_RETRY_BACKOFF", "4"))
+        self.retry_jitter: float = float(os.getenv("INDICO_RETRY_JITTER", "1"))
+
         for key, value in kwargs.items():
             if hasattr(self, key):
                 setattr(self, key, value)
diff --git a/indico/http/client.py b/indico/http/client.py
index a4ab3513..63ac820f 100644
--- a/indico/http/client.py
+++ b/indico/http/client.py
@@ -17,7 +17,7 @@
 )
 from indico.http.serialization import aio_deserialize, deserialize
 
-from .retry import aioretry
+from .retry import retry
 
 if TYPE_CHECKING:  # pragma: no cover
     from http.cookiejar import Cookie
@@ -50,6 +50,14 @@ class HTTPClient:
     def __init__(self, config: "Optional[IndicoConfig]" = None):
         self.config = config or IndicoConfig()
         self.base_url = f"{self.config.protocol}://{self.config.host}"
+        self._decorate_with_retry = retry(
+            requests.RequestException,
+            count=self.config.retry_count,
+            wait=self.config.retry_wait,
+            backoff=self.config.retry_backoff,
+            jitter=self.config.retry_jitter,
+        )
+        self._make_request = self._decorate_with_retry(self._make_request)  # type: ignore[method-assign]
 
         self.request_session = requests.Session()
         if isinstance(self.config.requests_params, dict):
@@ -169,6 +177,7 @@ def _make_request(
                 f"{self.base_url}{path}",
                 headers=headers,
                 stream=True,
+                timeout=(4, 64),
                 verify=False
                 if not self.config.verify_ssl or not self.request_session.verify
                 else True,
@@ -232,6 +241,14 @@ def __init__(self, config: "Optional[IndicoConfig]" = None):
         """
         self.config = config or IndicoConfig()
         self.base_url = f"{self.config.protocol}://{self.config.host}"
+        self._decorate_with_retry = retry(
+            aiohttp.ClientConnectionError,
+            count=self.config.retry_count,
+            wait=self.config.retry_wait,
+            backoff=self.config.retry_backoff,
+            jitter=self.config.retry_jitter,
+        )
+        self._make_request = self._decorate_with_retry(self._make_request)  # type: ignore[method-assign]
 
         self.request_session = aiohttp.ClientSession()
         if isinstance(self.config.requests_params, dict):
@@ -316,7 +333,6 @@ def _handle_files(
             for f in files:
                 f.close()
 
-    @aioretry(aiohttp.ClientConnectionError, aiohttp.ServerDisconnectedError)
     async def _make_request(
         self,
         method: str,
@@ -346,6 +362,7 @@ async def _make_request(
             async with getattr(self.request_session, method)(
                 f"{self.base_url}{path}",
                 headers=headers,
+                timeout=aiohttp.ClientTimeout(sock_connect=4, sock_read=64),
                 verify_ssl=self.config.verify_ssl,
                 **request_kwargs,
             ) as response:
diff --git a/indico/http/retry.py b/indico/http/retry.py
index 85ed1850..11c29055 100644
--- a/indico/http/retry.py
+++ b/indico/http/retry.py
@@ -1,85 +1,100 @@
 import asyncio
 import time
 from functools import wraps
-from random import randint
-from typing import TYPE_CHECKING
+from inspect import iscoroutinefunction
+from random import random
+from typing import TYPE_CHECKING, overload
 
-if TYPE_CHECKING:  # pragma: no cover
-    from typing import Awaitable, Callable, Optional, Tuple, Type, TypeVar, Union
+if TYPE_CHECKING:
+    import sys
+    from collections.abc import Awaitable, Callable
+    from typing import Type
 
-    from typing_extensions import ParamSpec
+    if sys.version_info >= (3, 10):
+        from typing import ParamSpec, TypeVar
+    else:
+        from typing_extensions import ParamSpec, TypeVar
 
-    P = ParamSpec("P")
-    T = TypeVar("T")
+    ArgumentsType = ParamSpec("ArgumentsType")
+    OuterReturnType = TypeVar("OuterReturnType")
+    InnerReturnType = TypeVar("InnerReturnType")
 
 
 def retry(
-    *ExceptionTypes: "Type[Exception]", tries: int = 3, delay: int = 1, backoff: int = 2
-) -> "Callable[[Callable[P, T]], Callable[P, T]]":
+    *errors: "Type[Exception]",
+    count: int,
+    wait: float,
+    backoff: float,
+    jitter: float,
+) -> "Callable[[Callable[ArgumentsType, OuterReturnType]], Callable[ArgumentsType, OuterReturnType]]":  # noqa: E501
     """
-    Retry with exponential backoff
+    Decorate a function or coroutine to retry when it raises specified errors,
+    apply exponential backoff and jitter to the wait time,
+    and raise the last error if it retries too many times.
 
-    Original from: http://wiki.python.org/moin/PythonDecoratorLibrary#Retry
+    Arguments:
+        errors: Retry the function when it raises one of these errors.
+        count: Retry the function this many times.
+        wait: Wait this many seconds after the first error before retrying.
+        backoff: Multiply the wait time by this amount for each additional error.
+        jitter: Add a random amount of time (up to this percent as a decimal)
+            to the wait time to prevent simultaneous retries.
     """
 
-    def retry_decorator(f: "Callable[P, T]") -> "Callable[P, T]":
-        @wraps(f)
-        def retry_fn(*args: "P.args", **kwargs: "P.kwargs") -> "T":
-            n_tries, n_delay = tries, delay
-            while n_tries > 1:
-                try:
-                    return f(*args, **kwargs)
-                except ExceptionTypes:
-                    time.sleep(n_delay)
-                    n_tries -= 1
-                    n_delay *= backoff
-            return f(*args, **kwargs)
-
-        return retry_fn
+    def wait_time(times_retried: int) -> float:
+        """
+        Calculate the sleep time based on number of times retried.
+        """
+        return wait * backoff**times_retried * (1 + jitter * random())
 
-    return retry_decorator
+    @overload
+    def retry_decorator(
+        decorated: "Callable[ArgumentsType, Awaitable[InnerReturnType]]",
+    ) -> "Callable[ArgumentsType, Awaitable[InnerReturnType]]": ...
 
+    @overload
+    def retry_decorator(
+        decorated: "Callable[ArgumentsType, InnerReturnType]",
+    ) -> "Callable[ArgumentsType, InnerReturnType]": ...
 
-def aioretry(
-    *ExceptionTypes: "Type[Exception]",
-    tries: int = 3,
-    delay: "Union[int, Tuple[int, int]]" = 1,
-    backoff: int = 2,
-    condition: "Optional[Callable[[Exception], bool]]" = None,
-) -> "Callable[[Callable[P, Awaitable[T]]], Callable[P, Awaitable[T]]]":
-    """
-    Retry with exponential backoff
-
-    Original from: http://wiki.python.org/moin/PythonDecoratorLibrary#Retry
-    Options:
-        condition: Callable to evaluate if an exception of a given type
-            is retryable for additional handling
-        delay: an initial time to wait (seconds). If a tuple, choose a random number
-            in that range to start. This can helps prevent retries at the exact
-            same time across multiple concurrent function calls
-    """
+    def retry_decorator(
+        decorated: "Callable[ArgumentsType, InnerReturnType]",
+    ) -> "Callable[ArgumentsType, Awaitable[InnerReturnType]] | Callable[ArgumentsType, InnerReturnType]":  # noqa: E501
+        """
+        Decorate either a function or coroutine as appropriate.
+        """
+        if iscoroutinefunction(decorated):
+
+            @wraps(decorated)
+            async def retrying_coroutine(  # type: ignore[return]
+                *args: "ArgumentsType.args", **kwargs: "ArgumentsType.kwargs"
+            ) -> "InnerReturnType":
+                for times_retried in range(count + 1):
+                    try:
+                        return await decorated(*args, **kwargs)  # type: ignore[no-any-return]
+                    except errors:
+                        if times_retried >= count:
+                            raise
+
+                        await asyncio.sleep(wait_time(times_retried))
+
+            return retrying_coroutine
+
+        else:
+
+            @wraps(decorated)
+            def retrying_function(  # type: ignore[return]
+                *args: "ArgumentsType.args", **kwargs: "ArgumentsType.kwargs"
+            ) -> "InnerReturnType":
+                for times_retried in range(count + 1):
+                    try:
+                        return decorated(*args, **kwargs)
+                    except errors:
+                        if times_retried >= count:
+                            raise
+
+                        time.sleep(wait_time(times_retried))
 
-    def retry_decorator(f: "Callable[P, Awaitable[T]]") -> "Callable[P, Awaitable[T]]":
-        @wraps(f)
-        async def retry_fn(*args: "P.args", **kwargs: "P.kwargs") -> "T":
-            n_tries = tries
-            if isinstance(delay, tuple):
-                # pick a random number to sleep
-                n_delay = randint(*delay)
-            else:
-                n_delay = delay
-            while True:
-                try:
-                    return await f(*args, **kwargs)
-                except ExceptionTypes as e:
-                    if condition and not condition(e):
-                        raise
-                    await asyncio.sleep(n_delay)
-                    n_tries -= 1
-                    n_delay *= backoff
-                    if n_tries <= 0:
-                        raise
-
-        return retry_fn
+            return retrying_function
 
     return retry_decorator
diff --git a/tests/unit/http/test_retry.py b/tests/unit/http/test_retry.py
new file mode 100644
index 00000000..618e9ebd
--- /dev/null
+++ b/tests/unit/http/test_retry.py
@@ -0,0 +1,73 @@
+import pytest
+
+from indico.http.retry import retry
+
+
+def test_no_errors() -> None:
+    @retry(Exception, count=0, wait=0, backoff=0, jitter=0)
+    def no_errors() -> bool:
+        return True
+
+    assert no_errors()
+
+
+def test_raises_errors() -> None:
+    calls = 0
+
+    @retry(RuntimeError, count=4, wait=0, backoff=0, jitter=0)
+    def raises_errors() -> None:
+        nonlocal calls
+        calls += 1
+        raise RuntimeError()
+
+    with pytest.raises(RuntimeError):
+        raises_errors()
+
+    assert calls == 5
+
+
+def test_raises_other_errors() -> None:
+    calls = 0
+
+    @retry(RuntimeError, count=4, wait=0, backoff=0, jitter=0)
+    def raises_errors() -> None:
+        nonlocal calls
+        calls += 1
+        raise ValueError()
+
+    with pytest.raises(ValueError):
+        raises_errors()
+
+    assert calls == 1
+
+
+@pytest.mark.asyncio
+async def test_raises_errors_async() -> None:
+    calls = 0
+
+    @retry(RuntimeError, count=4, wait=0, backoff=0, jitter=0)
+    async def raises_errors() -> None:
+        nonlocal calls
+        calls += 1
+        raise RuntimeError()
+
+    with pytest.raises(RuntimeError):
+        await raises_errors()
+
+    assert calls == 5
+
+
+@pytest.mark.asyncio
+async def test_raises_other_errors_async() -> None:
+    calls = 0
+
+    @retry(RuntimeError, count=4, wait=0, backoff=0, jitter=0)
+    async def raises_errors() -> None:
+        nonlocal calls
+        calls += 1
+        raise ValueError()
+
+    with pytest.raises(ValueError):
+        await raises_errors()
+
+    assert calls == 1