Skip to content

Commit

Permalink
Merge pull request #37 from pennsignals/add/retry
Browse files Browse the repository at this point in the history
Add retry decorator
  • Loading branch information
mdbecker authored Jan 22, 2020
2 parents 5cf83cf + 50c0a2e commit acf5fa1
Show file tree
Hide file tree
Showing 2 changed files with 118 additions and 0 deletions.
53 changes: 53 additions & 0 deletions src/dsdk/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,10 @@
import pickle
from collections import OrderedDict
from datetime import datetime
from functools import wraps
from logging import NullHandler, getLogger
from time import sleep as default_sleep
from typing import Callable, Sequence
from warnings import warn

from configargparse import ArgParser
Expand All @@ -29,6 +33,10 @@
MongoClient = None


logger = getLogger(__name__)
logger.addHandler(NullHandler())


def get_base_config() -> ArgParser:
"""Get the base configuration parser."""
config_parser = ArgParser(
Expand Down Expand Up @@ -124,3 +132,48 @@ def __setitem__(self, key, value):
if key in self:
raise KeyError("{} has already been set".format(key))
super(WriteOnceDict, self).__setitem__(key, value)


def retry(
exceptions: Sequence[Exception],
retries: int = 5,
delay: float = 1.0,
backoff: float = 1.5,
sleep: Callable = default_sleep,
):
"""
Retry calling the decorated function using an exponential backoff.
Args:
exceptions: The exception to check. may be a tuple of
exceptions to check.
retries: Number of times to retry before giving up.
delay: Initial delay between retries in seconds.
backoff: Backoff multiplier (e.g. value of 2 will double the delay
each retry).
"""
delay = float(delay)
backoff = float(backoff)

def wrapper(func):
@wraps(func)
def wrapped(*args, **kwargs):
try:
return func(*args, **kwargs)
except exceptions as exception:
logger.exception(exception)
wait = delay
for _ in range(retries):
message = f"Retrying in {wait:.2f} seconds..."
logger.warning(message)
sleep(wait)
wait *= backoff
try:
return func(*args, **kwargs)
except exceptions as exception:
logger.exception(exception)
raise

return wrapped

return wrapper
65 changes: 65 additions & 0 deletions tests/test_dsdk.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import configargparse

from dsdk import BaseBatchJob, Block
from dsdk.utils import retry


def test_batch(monkeypatch):
Expand All @@ -23,3 +24,67 @@ def run(self):
batch.run()
assert len(batch.evidence) == 1
assert batch.evidence["test"] == 42


def test_retry_other_exception():
"""Test retry other exception."""

exceptions_in = [
RuntimeError("what?"),
NotImplementedError("how?"),
RuntimeError("no!"),
]
actual = []
expected = [1.0, 1.5, 2.25]

def sleep(wait: float):
actual.append(wait)

@retry(
(NotImplementedError, RuntimeError),
retries=4,
delay=1.0,
backoff=1.5,
sleep=sleep,
)
def explode():
raise exceptions_in.pop()

try:
explode()
raise AssertionError("IndexError expected")
except IndexError:
assert actual == expected


def test_retry_exhausted():
"""Test retry."""

exceptions_in = [
RuntimeError("what?"),
NotImplementedError("how?"),
RuntimeError("no!"),
NotImplementedError("when?"),
]
actual = []
expected = [1.0, 1.5]

def sleep(wait: float):
actual.append(wait)

@retry(
(NotImplementedError, RuntimeError),
retries=2,
delay=1.0,
backoff=1.5,
sleep=sleep,
)
def explode():
raise exceptions_in.pop()

try:
explode()
raise AssertionError("NotImplementedError expected")
except NotImplementedError as exception:
assert actual == expected
assert str(exception) == "when?"

0 comments on commit acf5fa1

Please sign in to comment.