Add unit of measurement to base_delay and max_delay. Expand the retry_reading test
KaspariK committed Jan 20, 2025
1 parent 6902ab0 commit b4e423d
Showing 2 changed files with 39 additions and 25 deletions.
58 changes: 36 additions & 22 deletions tests/serialize/runstate/dynamodb_state_store_test.py
@@ -8,7 +8,9 @@
from moto.dynamodb2.responses import dynamo_json_dump

from testifycompat import assert_equal
from testifycompat.assertions import assert_in
from tron.serialize.runstate.dynamodb_state_store import DynamoDBStateStore
from tron.serialize.runstate.dynamodb_state_store import MAX_UNPROCESSED_KEYS_RETRIES


def mock_transact_write_items(self):
@@ -294,7 +296,8 @@ def test_delete_item_with_json_partitions(self, store, small_object, large_object):
vals = store.restore([key])
assert key not in vals

def test_retry_saving(self, store, small_object, large_object):
@mock.patch("time.sleep", return_value=None)
def test_retry_saving(self, mock_sleep, store, small_object, large_object):
with mock.patch(
"moto.dynamodb2.responses.DynamoHandler.transact_write_items",
side_effect=KeyError("foo"),
@@ -307,45 +310,56 @@ def test_retry_saving(self, store, small_object, large_object):
except Exception:
assert_equal(mock_failed_write.call_count, 3)

def test_retry_reading(self, store, small_object, large_object):
@mock.patch("time.sleep")
@mock.patch("random.uniform")
def test_retry_reading(self, mock_random_uniform, mock_sleep, store, small_object, large_object):
unprocessed_value = {
"Responses": {
store.name: [
{
"index": {"N": "0"},
"key": {"S": "job_state 0"},
},
],
},
"Responses": {},
"UnprocessedKeys": {
store.name: {
"ConsistentRead": True,
"Keys": [
{
"index": {"N": "0"},
"key": {"S": "job_state 0"},
"index": {"N": "0"},
}
],
},
"ConsistentRead": True,
}
},
"ResponseMetadata": {},
}
keys = [store.build_key("job_state", i) for i in range(1)]
value = small_object
pairs = zip(keys, (value for i in range(len(keys))))
pairs = zip(keys, [value] * len(keys))
store.save(pairs)
store._consume_save_queue()

# Mock random.uniform to return the upper limit of the range so that we are simulating max jitter
def side_effect_random_uniform(a, b):
return b

mock_random_uniform.side_effect = side_effect_random_uniform

with mock.patch.object(
store.client,
"batch_get_item",
return_value=unprocessed_value,
) as mock_failed_read:
try:
with mock.patch("tron.config.static_config.load_yaml_file", autospec=True), mock.patch(
"tron.config.static_config.build_configuration_watcher", autospec=True
):
store.restore(keys)
except Exception:
assert_equal(mock_failed_read.call_count, 10)
with pytest.raises(Exception) as exec_info, mock.patch(
"tron.config.static_config.load_yaml_file", autospec=True
), mock.patch("tron.config.static_config.build_configuration_watcher", autospec=True):
store.restore(keys)
assert_in("failed to retrieve items with keys", str(exec_info.value))
assert_equal(mock_failed_read.call_count, MAX_UNPROCESSED_KEYS_RETRIES)

# We also need to verify that sleep was called with expected delays
expected_delays = []
base_delay_seconds = 0.5
max_delay_seconds = 10
for attempt in range(1, MAX_UNPROCESSED_KEYS_RETRIES + 1):
expected_delay = min(base_delay_seconds * (2 ** (attempt - 1)), max_delay_seconds)
expected_delays.append(expected_delay)
actual_delays = [call.args[0] for call in mock_sleep.call_args_list]
assert_equal(actual_delays, expected_delays)

def test_restore_exception_propagation(self, store, small_object):
# This test is to ensure that restore propagates exceptions upwards: see DAR-2328
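Side note (not part of the diff): the expanded test makes the full-jitter backoff deterministic by patching random.uniform to return the upper limit of its range, so the delays handed to time.sleep can be compared against the un-jittered exponential schedule. Below is a minimal, self-contained sketch of that technique; the retry_with_backoff helper and the MAX_RETRIES value are assumptions standing in for DynamoDBStateStore._get_items and MAX_UNPROCESSED_KEYS_RETRIES, not the repository's code.

import random
import time
from unittest import mock

MAX_RETRIES = 10  # assumption: stand-in for MAX_UNPROCESSED_KEYS_RETRIES

def retry_with_backoff(base_delay_seconds=0.5, max_delay_seconds=10):
    # Same shape as the loop in _get_items: capped exponential backoff with full jitter.
    for attempt in range(1, MAX_RETRIES + 1):
        exponential_delay = min(base_delay_seconds * (2 ** (attempt - 1)), max_delay_seconds)
        time.sleep(random.uniform(0, exponential_delay))

with mock.patch("time.sleep") as mock_sleep, mock.patch("random.uniform") as mock_uniform:
    mock_uniform.side_effect = lambda a, b: b  # always return the upper limit, i.e. max jitter
    retry_with_backoff()

expected_delays = [min(0.5 * (2 ** (attempt - 1)), 10) for attempt in range(1, MAX_RETRIES + 1)]
actual_delays = [call.args[0] for call in mock_sleep.call_args_list]
assert actual_delays == expected_delays  # [0.5, 1.0, 2.0, 4.0, 8.0, 10, 10, 10, 10, 10]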
6 changes: 3 additions & 3 deletions tron/serialize/runstate/dynamodb_state_store.py
@@ -107,8 +107,8 @@ def _get_items(self, table_keys: list) -> object:
# let's avoid potentially mutating our input :)
cand_keys_list = copy.copy(table_keys)
attempts = 0
base_delay = 0.5
max_delay = 10
base_delay_seconds = 0.5
max_delay_seconds = 10

while len(cand_keys_list) != 0 and attempts < MAX_UNPROCESSED_KEYS_RETRIES:
with concurrent.futures.ThreadPoolExecutor(max_workers=5) as executor:
@@ -141,7 +141,7 @@ def _get_items(self, table_keys: list) -> object:
if cand_keys_list:
attempts += 1
# Exponential backoff for retrying unprocessed keys
exponential_delay = min(base_delay * (2 ** (attempts - 1)), max_delay)
exponential_delay = min(base_delay_seconds * (2 ** (attempts - 1)), max_delay_seconds)
# Full jitter (i.e. from 0 to exponential_delay) will help minimize the number and length of calls
jitter = random.uniform(0, exponential_delay)
delay = jitter
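For reference (not part of the diff): with the renamed base_delay_seconds=0.5 and max_delay_seconds=10, the un-jittered ceiling doubles per attempt until the cap, and full jitter then draws the actual sleep uniformly from [0, ceiling]. A quick sketch of the first ten attempts (ten is an assumed retry count, not necessarily the value of MAX_UNPROCESSED_KEYS_RETRIES):

base_delay_seconds, max_delay_seconds = 0.5, 10
for attempt in range(1, 11):
    ceiling = min(base_delay_seconds * (2 ** (attempt - 1)), max_delay_seconds)
    print(attempt, ceiling)
# attempts 1-5 -> 0.5, 1.0, 2.0, 4.0, 8.0 seconds; attempts 6-10 -> capped at 10 seconds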
