Skip to content

Commit

Permalink
Add test for backoff, Update test_retry_saving to actually test assertions and requeue of failed items. Remove backoff from test_retry_reading
Browse files Browse the repository at this point in the history
  • Loading branch information
KaspariK committed Jan 22, 2025
1 parent b0e890b commit de2e2f9
Showing 1 changed file with 60 additions and 51 deletions.
111 changes: 60 additions & 51 deletions tests/serialize/runstate/dynamodb_state_store_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@
from moto.dynamodb2.responses import dynamo_json_dump

from testifycompat import assert_equal
from testifycompat.assertions import assert_in
from tron.serialize.runstate.dynamodb_state_store import DynamoDBStateStore
from tron.serialize.runstate.dynamodb_state_store import MAX_UNPROCESSED_KEYS_RETRIES

Expand Down Expand Up @@ -296,70 +295,80 @@ def test_delete_item_with_json_partitions(self, store, small_object, large_objec
vals = store.restore([key])
assert key not in vals

@mock.patch("time.sleep", return_value=None)
def test_retry_saving(self, mock_sleep, store, small_object, large_object):
    """Save one small object while moto's ``transact_write_items`` always raises.

    ``time.sleep`` is replaced by a mock for the duration of the test, and
    moto's ``DynamoHandler.transact_write_items`` is patched to raise
    ``KeyError("foo")`` on every call.

    NOTE(review): the call-count check lives in the ``except`` handler, so it
    only runs when ``store.save`` raises. If ``save`` returns normally, this
    test passes without asserting anything.
    """
    failing_write = mock.patch(
        "moto.dynamodb2.responses.DynamoHandler.transact_write_items",
        side_effect=KeyError("foo"),
    )
    with failing_write as mock_failed_write:
        # A single job_state key (index 0), paired with the small fixture object.
        state_keys = [store.build_key("job_state", 0)]
        pairs = zip(state_keys, [small_object] * len(state_keys))
        try:
            store.save(pairs)
        except Exception:
            assert_equal(mock_failed_write.call_count, 3)

@mock.patch("time.sleep")
@mock.patch("random.uniform")
def test_retry_reading(self, mock_random_uniform, mock_sleep, store, small_object, large_object):
@pytest.mark.parametrize(
    "test_object, side_effects, expected_save_errors, expected_queue_length",
    [
        # Every scripted write raises.
        ("small_object", [KeyError("foo")] * 3, 3, 1),
        ("large_object", [KeyError("foo")] * 3, 3, 1),
        # The first write raises, the second returns an empty response.
        ("small_object", [KeyError("foo"), {}], 0, 0),
        ("large_object", [KeyError("foo"), {}], 0, 0),
    ],
)
def test_retry_saving(
    self, test_object, side_effects, expected_save_errors, expected_queue_length, store, small_object, large_object
):
    """Save one item against a scripted ``transact_write_items`` and check the outcome.

    ``store.client.transact_write_items`` is patched so that successive calls
    follow ``side_effects``. After saving a single ``job_state`` item, the save
    queue is consumed once per scripted outcome. The test then checks:

    * the patched client method was called ``len(side_effects)`` times,
    * ``store.save_errors`` equals ``expected_save_errors``,
    * ``len(store.save_queue)`` equals ``expected_queue_length``.
    """
    # Parametrized cases refer to fixtures by name; resolve the name to the fixture value.
    fixtures_by_name = {"small_object": small_object, "large_object": large_object}
    payload = fixtures_by_name[test_object]
    attempts = len(side_effects)

    with mock.patch.object(store.client, "transact_write_items", side_effect=side_effects) as mock_transact_write:
        state_key = store.build_key("job_state", 0)
        store.save(zip([state_key], [payload]))

        # One pass over the save queue for each scripted write outcome.
        for _ in range(attempts):
            store._consume_save_queue()

        assert mock_transact_write.call_count == attempts
        assert store.save_errors == expected_save_errors
        assert len(store.save_queue) == expected_queue_length

@pytest.mark.parametrize(
    "attempt, expected_delay",
    [
        # Expected delays in this table double with each attempt, starting at 0.5 ...
        (1, 0.5),
        (2, 1.0),
        (3, 2.0),
        (4, 4.0),
        (5, 8.0),
        # ... and stay at 10.0 for attempts 6 and 7.
        (6, 10.0),
        (7, 10.0),
    ],
)
def test_calculate_backoff_delay(self, store, attempt, expected_delay):
    """Check ``store._calculate_backoff_delay(attempt)`` against the expected delay table."""
    assert_equal(store._calculate_backoff_delay(attempt), expected_delay)

def test_retry_reading(self, store):
unprocessed_value = {
"Responses": {},
"UnprocessedKeys": {
store.name: {
"Keys": [
{
"key": {"S": "job_state 0"},
"index": {"N": "0"},
}
],
"Keys": [{"key": {"S": store.build_key("job_state", 0)}, "index": {"N": "0"}}],
"ConsistentRead": True,
}
},
}
keys = [store.build_key("job_state", i) for i in range(1)]
value = small_object
pairs = zip(keys, [value] * len(keys))
store.save(pairs)
store._consume_save_queue()

# Mock random.uniform to return the upper limit of the range so that we are simulating max jitter
def side_effect_random_uniform(a, b):
return b

mock_random_uniform.side_effect = side_effect_random_uniform
keys = [store.build_key("job_state", 0)]

with mock.patch.object(
store.client,
"batch_get_item",
return_value=unprocessed_value,
) as mock_failed_read:
with pytest.raises(Exception) as exec_info, mock.patch(
"tron.config.static_config.load_yaml_file", autospec=True
), mock.patch("tron.config.static_config.build_configuration_watcher", autospec=True):
store.restore(keys)
assert_in("failed to retrieve items with keys", str(exec_info.value))
assert_equal(mock_failed_read.call_count, MAX_UNPROCESSED_KEYS_RETRIES)

# We also need to verify that sleep was called with expected delays
expected_delays = []
base_delay_seconds = 0.5
max_delay_seconds = 10
for attempt in range(1, MAX_UNPROCESSED_KEYS_RETRIES + 1):
expected_delay = min(base_delay_seconds * (2 ** (attempt - 1)), max_delay_seconds)
expected_delays.append(expected_delay)
actual_delays = [call.args[0] for call in mock_sleep.call_args_list]
assert_equal(actual_delays, expected_delays)
) as mock_batch_get_item, mock.patch("time.sleep") as mock_sleep, pytest.raises(Exception) as exec_info:
store.restore(keys)
assert "failed to retrieve items with keys" in str(exec_info.value)
assert mock_batch_get_item.call_count == MAX_UNPROCESSED_KEYS_RETRIES
assert mock_sleep.call_count == MAX_UNPROCESSED_KEYS_RETRIES

def test_restore_exception_propagation(self, store, small_object):
# This test is to ensure that restore propagates exceptions upwards: see DAR-2328
Expand Down

0 comments on commit de2e2f9

Please sign in to comment.