diff --git a/src/dsdk/utils.py b/src/dsdk/utils.py index a9d68af..9debc0d 100644 --- a/src/dsdk/utils.py +++ b/src/dsdk/utils.py @@ -6,6 +6,10 @@ import pickle from collections import OrderedDict from datetime import datetime +from functools import wraps +from logging import NullHandler, getLogger +from time import sleep as default_sleep +from typing import Callable, Sequence from warnings import warn from configargparse import ArgParser @@ -29,6 +33,10 @@ MongoClient = None +logger = getLogger(__name__) +logger.addHandler(NullHandler()) + + def get_base_config() -> ArgParser: """Get the base configuration parser.""" config_parser = ArgParser( @@ -124,3 +132,48 @@ def __setitem__(self, key, value): if key in self: raise KeyError("{} has already been set".format(key)) super(WriteOnceDict, self).__setitem__(key, value) + + +def retry( + exceptions: Sequence[Exception], + retries: int = 5, + delay: float = 1.0, + backoff: float = 1.5, + sleep: Callable = default_sleep, +): + """ + Retry calling the decorated function using an exponential backoff. + + Args: + exceptions: The exception to check. may be a tuple of + exceptions to check. + retries: Number of times to retry before giving up. + delay: Initial delay between retries in seconds. + backoff: Backoff multiplier (e.g. value of 2 will double the delay + each retry). + """ + delay = float(delay) + backoff = float(backoff) + + def wrapper(func): + @wraps(func) + def wrapped(*args, **kwargs): + try: + return func(*args, **kwargs) + except exceptions as exception: + logger.exception(exception) + wait = delay + for _ in range(retries): + message = f"Retrying in {wait:.2f} seconds..." + logger.warning(message) + sleep(wait) + wait *= backoff + try: + return func(*args, **kwargs) + except exceptions as exception: + logger.exception(exception) + raise + + return wrapped + + return wrapper diff --git a/tests/test_dsdk.py b/tests/test_dsdk.py index a8399e2..c3b79a6 100644 --- a/tests/test_dsdk.py +++ b/tests/test_dsdk.py @@ -6,6 +6,7 @@ import configargparse from dsdk import BaseBatchJob, Block +from dsdk.utils import retry def test_batch(monkeypatch): @@ -23,3 +24,67 @@ def run(self): batch.run() assert len(batch.evidence) == 1 assert batch.evidence["test"] == 42 + + +def test_retry_other_exception(): + """Test retry other exception.""" + + exceptions_in = [ + RuntimeError("what?"), + NotImplementedError("how?"), + RuntimeError("no!"), + ] + actual = [] + expected = [1.0, 1.5, 2.25] + + def sleep(wait: float): + actual.append(wait) + + @retry( + (NotImplementedError, RuntimeError), + retries=4, + delay=1.0, + backoff=1.5, + sleep=sleep, + ) + def explode(): + raise exceptions_in.pop() + + try: + explode() + raise AssertionError("IndexError expected") + except IndexError: + assert actual == expected + + +def test_retry_exhausted(): + """Test retry.""" + + exceptions_in = [ + RuntimeError("what?"), + NotImplementedError("how?"), + RuntimeError("no!"), + NotImplementedError("when?"), + ] + actual = [] + expected = [1.0, 1.5] + + def sleep(wait: float): + actual.append(wait) + + @retry( + (NotImplementedError, RuntimeError), + retries=2, + delay=1.0, + backoff=1.5, + sleep=sleep, + ) + def explode(): + raise exceptions_in.pop() + + try: + explode() + raise AssertionError("NotImplementedError expected") + except NotImplementedError as exception: + assert actual == expected + assert str(exception) == "when?"