Skip to content

Commit e3d385f

Browse files
committed
Add retry tests
1 parent 74aa7e2 commit e3d385f

File tree

2 files changed

+68
-1
lines changed

2 files changed

+68
-1
lines changed

src/dsdk/utils.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -137,7 +137,7 @@ def __setitem__(self, key, value):
137137
def retry(
138138
exceptions: Sequence[Exception],
139139
retries: int = 5,
140-
delay: int = 1,
140+
delay: float = 1.0,
141141
backoff: float = 1.5,
142142
sleep: Callable = default_sleep,
143143
):
@@ -152,6 +152,8 @@ def retry(
152152
backoff: Backoff multiplier (e.g. value of 2 will double the delay
153153
each retry).
154154
"""
155+
delay = float(delay)
156+
backoff = float(backoff)
155157

156158
def wrapper(func):
157159
@wraps(func)

tests/test_dsdk.py

Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
import configargparse
77

88
from dsdk import BaseBatchJob, Block
9+
from dsdk.utils import retry
910

1011

1112
def test_batch(monkeypatch):
@@ -23,3 +24,67 @@ def run(self):
2324
batch.run()
2425
assert len(batch.evidence) == 1
2526
assert batch.evidence["test"] == 42
27+
28+
29+
def test_retry_other_exception():
30+
"""Test retry other exception."""
31+
32+
exceptions_in = [
33+
RuntimeError("what?"),
34+
NotImplementedError("how?"),
35+
RuntimeError("no!"),
36+
]
37+
actual = []
38+
expected = [1.0, 1.5, 2.25]
39+
40+
def sleep(wait: float):
41+
actual.append(wait)
42+
43+
@retry(
44+
(NotImplementedError, RuntimeError),
45+
retries=4,
46+
delay=1.0,
47+
backoff=1.5,
48+
sleep=sleep,
49+
)
50+
def explode():
51+
raise exceptions_in.pop()
52+
53+
try:
54+
explode()
55+
raise AssertionError("IndexError expected")
56+
except IndexError:
57+
assert actual == expected
58+
59+
60+
def test_retry_exhausted():
61+
"""Test retry."""
62+
63+
exceptions_in = [
64+
RuntimeError("what?"),
65+
NotImplementedError("how?"),
66+
RuntimeError("no!"),
67+
NotImplementedError("when?"),
68+
]
69+
actual = []
70+
expected = [1.0, 1.5]
71+
72+
def sleep(wait: float):
73+
actual.append(wait)
74+
75+
@retry(
76+
(NotImplementedError, RuntimeError),
77+
retries=2,
78+
delay=1.0,
79+
backoff=1.5,
80+
sleep=sleep,
81+
)
82+
def explode():
83+
raise exceptions_in.pop()
84+
85+
try:
86+
explode()
87+
raise AssertionError("NotImplementedError expected")
88+
except NotImplementedError as exception:
89+
assert actual == expected
90+
assert str(exception) == "when?"

0 commit comments

Comments
 (0)