From de2e2f93faf753db7088ccc909d1f5d6604516e1 Mon Sep 17 00:00:00 2001
From: Kevin Kaspari
Date: Wed, 22 Jan 2025 09:06:14 -0800
Subject: [PATCH] Add test for backoff, Update test_retry_saving to actually
 test assertions and requeue of failed items. Remove backoff from
 test_retry_reading

---
 .../runstate/dynamodb_state_store_test.py | 111 ++++++++++--------
 1 file changed, 60 insertions(+), 51 deletions(-)

diff --git a/tests/serialize/runstate/dynamodb_state_store_test.py b/tests/serialize/runstate/dynamodb_state_store_test.py
index e727eb98d..9ef083659 100644
--- a/tests/serialize/runstate/dynamodb_state_store_test.py
+++ b/tests/serialize/runstate/dynamodb_state_store_test.py
@@ -8,7 +8,6 @@
 from moto.dynamodb2.responses import dynamo_json_dump
 from testifycompat import assert_equal
-from testifycompat.assertions import assert_in
 from tron.serialize.runstate.dynamodb_state_store import DynamoDBStateStore
 from tron.serialize.runstate.dynamodb_state_store import MAX_UNPROCESSED_KEYS_RETRIES
@@ -296,70 +295,80 @@ def test_delete_item_with_json_partitions(self, store, small_object, large_objec
         vals = store.restore([key])
         assert key not in vals
 
-    @mock.patch("time.sleep", return_value=None)
-    def test_retry_saving(self, mock_sleep, store, small_object, large_object):
-        with mock.patch(
-            "moto.dynamodb2.responses.DynamoHandler.transact_write_items",
-            side_effect=KeyError("foo"),
-        ) as mock_failed_write:
-            keys = [store.build_key("job_state", i) for i in range(1)]
-            value = small_object
-            pairs = zip(keys, (value for i in range(len(keys))))
-            try:
-                store.save(pairs)
-            except Exception:
-                assert_equal(mock_failed_write.call_count, 3)
-
-    @mock.patch("time.sleep")
-    @mock.patch("random.uniform")
-    def test_retry_reading(self, mock_random_uniform, mock_sleep, store, small_object, large_object):
+    @pytest.mark.parametrize(
+        "test_object, side_effects, expected_save_errors, expected_queue_length",
+        [
+            # All attempts fail
+            ("small_object", [KeyError("foo")] * 3, 3, 1),
+            ("large_object", [KeyError("foo")] * 3, 3, 1),
+            # Failure followed by success
+            ("small_object", [KeyError("foo"), {}], 0, 0),
+            ("large_object", [KeyError("foo"), {}], 0, 0),
+        ],
+    )
+    def test_retry_saving(
+        self, test_object, side_effects, expected_save_errors, expected_queue_length, store, small_object, large_object
+    ):
+        object_mapping = {
+            "small_object": small_object,
+            "large_object": large_object,
+        }
+        value = object_mapping[test_object]
+
+        with mock.patch.object(
+            store.client,
+            "transact_write_items",
+            side_effect=side_effects,
+        ) as mock_transact_write:
+            keys = [store.build_key("job_state", 0)]
+            pairs = zip(keys, [value])
+            store.save(pairs)
+
+            for _ in side_effects:
+                store._consume_save_queue()
+
+            assert mock_transact_write.call_count == len(side_effects)
+            assert store.save_errors == expected_save_errors
+            assert len(store.save_queue) == expected_queue_length
+
+    @pytest.mark.parametrize(
+        "attempt, expected_delay",
+        [
+            (1, 0.5),
+            (2, 1.0),
+            (3, 2.0),
+            (4, 4.0),
+            (5, 8.0),
+            (6, 10.0),
+            (7, 10.0),
+        ],
+    )
+    def test_calculate_backoff_delay(self, store, attempt, expected_delay):
+        delay = store._calculate_backoff_delay(attempt)
+        assert_equal(delay, expected_delay)
+
+    def test_retry_reading(self, store):
         unprocessed_value = {
             "Responses": {},
             "UnprocessedKeys": {
                 store.name: {
-                    "Keys": [
-                        {
-                            "key": {"S": "job_state 0"},
-                            "index": {"N": "0"},
-                        }
-                    ],
+                    "Keys": [{"key": {"S": store.build_key("job_state", 0)}, "index": {"N": "0"}}],
                     "ConsistentRead": True,
                 }
             },
         }
-        keys = [store.build_key("job_state", i) for i in range(1)]
-        value = small_object
-        pairs = zip(keys, [value] * len(keys))
-        store.save(pairs)
-        store._consume_save_queue()
-
-        # Mock random.uniform to return the upper limit of the range so that we are simulating max jitter
-        def side_effect_random_uniform(a, b):
-            return b
-
-        mock_random_uniform.side_effect = side_effect_random_uniform
+        keys = [store.build_key("job_state", 0)]
 
         with mock.patch.object(
             store.client,
             "batch_get_item",
             return_value=unprocessed_value,
-        ) as mock_failed_read:
-            with pytest.raises(Exception) as exec_info, mock.patch(
-                "tron.config.static_config.load_yaml_file", autospec=True
-            ), mock.patch("tron.config.static_config.build_configuration_watcher", autospec=True):
-                store.restore(keys)
-            assert_in("failed to retrieve items with keys", str(exec_info.value))
-            assert_equal(mock_failed_read.call_count, MAX_UNPROCESSED_KEYS_RETRIES)
-
-            # We also need to verify that sleep was called with expected delays
-            expected_delays = []
-            base_delay_seconds = 0.5
-            max_delay_seconds = 10
-            for attempt in range(1, MAX_UNPROCESSED_KEYS_RETRIES + 1):
-                expected_delay = min(base_delay_seconds * (2 ** (attempt - 1)), max_delay_seconds)
-                expected_delays.append(expected_delay)
-            actual_delays = [call.args[0] for call in mock_sleep.call_args_list]
-            assert_equal(actual_delays, expected_delays)
+        ) as mock_batch_get_item, mock.patch("time.sleep") as mock_sleep, pytest.raises(Exception) as exec_info:
+            store.restore(keys)
+        assert "failed to retrieve items with keys" in str(exec_info.value)
+        assert mock_batch_get_item.call_count == MAX_UNPROCESSED_KEYS_RETRIES
+        assert mock_sleep.call_count == MAX_UNPROCESSED_KEYS_RETRIES
 
     def test_restore_exception_propagation(self, store, small_object):
         # This test is to ensure that restore propagates exceptions upwards: see DAR-2328