diff --git a/tests/serialize/runstate/dynamodb_state_store_test.py b/tests/serialize/runstate/dynamodb_state_store_test.py
index d2cbd3475..e727eb98d 100644
--- a/tests/serialize/runstate/dynamodb_state_store_test.py
+++ b/tests/serialize/runstate/dynamodb_state_store_test.py
@@ -8,7 +8,9 @@
 from moto.dynamodb2.responses import dynamo_json_dump
 
 from testifycompat import assert_equal
+from testifycompat.assertions import assert_in
 from tron.serialize.runstate.dynamodb_state_store import DynamoDBStateStore
+from tron.serialize.runstate.dynamodb_state_store import MAX_UNPROCESSED_KEYS_RETRIES
 
 
 def mock_transact_write_items(self):
@@ -294,7 +296,8 @@ def test_delete_item_with_json_partitions(self, store, small_object, large_object):
         vals = store.restore([key])
         assert key not in vals
 
-    def test_retry_saving(self, store, small_object, large_object):
+    @mock.patch("time.sleep", return_value=None)
+    def test_retry_saving(self, mock_sleep, store, small_object, large_object):
         with mock.patch(
             "moto.dynamodb2.responses.DynamoHandler.transact_write_items",
             side_effect=KeyError("foo"),
@@ -307,45 +310,56 @@
             except Exception:
                 assert_equal(mock_failed_write.call_count, 3)
 
-    def test_retry_reading(self, store, small_object, large_object):
+    @mock.patch("time.sleep")
+    @mock.patch("random.uniform")
+    def test_retry_reading(self, mock_random_uniform, mock_sleep, store, small_object, large_object):
         unprocessed_value = {
-            "Responses": {
-                store.name: [
-                    {
-                        "index": {"N": "0"},
-                        "key": {"S": "job_state 0"},
-                    },
-                ],
-            },
+            "Responses": {},
             "UnprocessedKeys": {
                 store.name: {
-                    "ConsistentRead": True,
                     "Keys": [
                         {
-                            "index": {"N": "0"},
                             "key": {"S": "job_state 0"},
+                            "index": {"N": "0"},
                         }
                     ],
-                },
+                    "ConsistentRead": True,
+                }
             },
-            "ResponseMetadata": {},
         }
         keys = [store.build_key("job_state", i) for i in range(1)]
         value = small_object
-        pairs = zip(keys, (value for i in range(len(keys))))
+        pairs = zip(keys, [value] * len(keys))
         store.save(pairs)
+        store._consume_save_queue()
+
+        # Mock random.uniform to return the upper limit of the range so that we are simulating max jitter
+        def side_effect_random_uniform(a, b):
+            return b
+
+        mock_random_uniform.side_effect = side_effect_random_uniform
+
         with mock.patch.object(
             store.client,
             "batch_get_item",
             return_value=unprocessed_value,
         ) as mock_failed_read:
-            try:
-                with mock.patch("tron.config.static_config.load_yaml_file", autospec=True), mock.patch(
-                    "tron.config.static_config.build_configuration_watcher", autospec=True
-                ):
-                    store.restore(keys)
-            except Exception:
-                assert_equal(mock_failed_read.call_count, 10)
+            with pytest.raises(Exception) as exec_info, mock.patch(
+                "tron.config.static_config.load_yaml_file", autospec=True
+            ), mock.patch("tron.config.static_config.build_configuration_watcher", autospec=True):
+                store.restore(keys)
+            assert_in("failed to retrieve items with keys", str(exec_info.value))
+            assert_equal(mock_failed_read.call_count, MAX_UNPROCESSED_KEYS_RETRIES)
+
+        # We also need to verify that sleep was called with expected delays
+        expected_delays = []
+        base_delay_seconds = 0.5
+        max_delay_seconds = 10
+        for attempt in range(1, MAX_UNPROCESSED_KEYS_RETRIES + 1):
+            expected_delay = min(base_delay_seconds * (2 ** (attempt - 1)), max_delay_seconds)
+            expected_delays.append(expected_delay)
+        actual_delays = [call.args[0] for call in mock_sleep.call_args_list]
+        assert_equal(actual_delays, expected_delays)
 
     def test_restore_exception_propagation(self, store, small_object):
         # This test is to ensure that restore propagates exceptions upwards: see DAR-2328
diff --git a/tron/serialize/runstate/dynamodb_state_store.py b/tron/serialize/runstate/dynamodb_state_store.py
index fd0234346..2eec708b2 100644
--- a/tron/serialize/runstate/dynamodb_state_store.py
+++ b/tron/serialize/runstate/dynamodb_state_store.py
@@ -107,8 +107,8 @@ def _get_items(self, table_keys: list) -> object:
         # let's avoid potentially mutating our input :)
         cand_keys_list = copy.copy(table_keys)
         attempts = 0
-        base_delay = 0.5
-        max_delay = 10
+        base_delay_seconds = 0.5
+        max_delay_seconds = 10
 
         while len(cand_keys_list) != 0 and attempts < MAX_UNPROCESSED_KEYS_RETRIES:
             with concurrent.futures.ThreadPoolExecutor(max_workers=5) as executor:
@@ -141,7 +141,7 @@
                 if cand_keys_list:
                     attempts += 1
                     # Exponential backoff for retrying unprocessed keys
-                    exponential_delay = min(base_delay * (2 ** (attempts - 1)), max_delay)
+                    exponential_delay = min(base_delay_seconds * (2 ** (attempts - 1)), max_delay_seconds)
                     # Full jitter (i.e. from 0 to exponential_delay) will help minimize the number and length of calls
                     jitter = random.uniform(0, exponential_delay)
                     delay = jitter
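The core change above is an exponential-backoff-with-full-jitter schedule for retrying DynamoDB `UnprocessedKeys`. Below is a minimal standalone sketch of that schedule; the helper name `backoff_delay_seconds` is illustrative rather than part of tron, and the retry cap of 10 is an assumption mirroring the old hard-coded `call_count` of 10. Only the delay formula itself comes from the patch.

```python
import random

MAX_UNPROCESSED_KEYS_RETRIES = 10  # assumption: the old test asserted 10 read attempts


def backoff_delay_seconds(attempt: int, base_delay_seconds: float = 0.5, max_delay_seconds: float = 10) -> float:
    """Delay before retry number `attempt` (1-based).

    The deterministic cap doubles each attempt (0.5s, 1s, 2s, ...) until it
    saturates at max_delay_seconds; full jitter then draws uniformly from
    [0, cap], which spreads out concurrent clients retrying the same table.
    """
    exponential_delay = min(base_delay_seconds * (2 ** (attempt - 1)), max_delay_seconds)
    return random.uniform(0, exponential_delay)


if __name__ == "__main__":
    # Print the deterministic cap for each attempt: 0.5, 1, 2, 4, 8, 10, 10, ...
    for attempt in range(1, MAX_UNPROCESSED_KEYS_RETRIES + 1):
        print(f"attempt {attempt}: sleep up to {min(0.5 * 2 ** (attempt - 1), 10)}s")
```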
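The test side makes this schedule assertable by pinning `random.uniform` to the upper bound of its range, which is exactly what `side_effect_random_uniform` does in the patched test. Here is the same pattern in isolation against the illustrative helper above (the test name is an assumption, not a tron API):

```python
from unittest import mock


def test_backoff_schedule_is_deterministic_at_max_jitter():
    # With jitter pinned to its maximum, the delays collapse to the
    # deterministic cap sequence: 0.5, 1, 2, 4, 8, then flat at 10.
    with mock.patch("random.uniform", side_effect=lambda a, b: b):
        delays = [backoff_delay_seconds(attempt) for attempt in range(1, 8)]
    assert delays == [0.5, 1, 2, 4, 8, 10, 10]
```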