diff --git a/tron/serialize/runstate/dynamodb_state_store.py b/tron/serialize/runstate/dynamodb_state_store.py
index 2eec708b2..bfc9f8592 100644
--- a/tron/serialize/runstate/dynamodb_state_store.py
+++ b/tron/serialize/runstate/dynamodb_state_store.py
@@ -4,7 +4,6 @@
 import math
 import os
 import pickle
-import random
 import threading
 import time
 from collections import defaultdict
@@ -102,13 +101,17 @@ def chunk_keys(self, keys: Sequence[T]) -> List[Sequence[T]]:
             cand_keys_chunks.append(keys[i : min(len(keys), i + 100)])
         return cand_keys_chunks
 
+    def _calculate_backoff_delay(self, attempt: int) -> float:
+        base_delay_seconds = 0.5
+        max_delay_seconds = 10
+        delay = min(base_delay_seconds * (2 ** (attempt - 1)), max_delay_seconds)
+        return delay
+
     def _get_items(self, table_keys: list) -> object:
         items = []
         # let's avoid potentially mutating our input :)
         cand_keys_list = copy.copy(table_keys)
         attempts = 0
-        base_delay_seconds = 0.5
-        max_delay_seconds = 10
 
         while len(cand_keys_list) != 0 and attempts < MAX_UNPROCESSED_KEYS_RETRIES:
             with concurrent.futures.ThreadPoolExecutor(max_workers=5) as executor:
@@ -140,13 +143,10 @@ def _get_items(self, table_keys: list) -> object:
                         raise e
             if cand_keys_list:
                 attempts += 1
-                # Exponential backoff for retrying unprocessed keys
-                exponential_delay = min(base_delay_seconds * (2 ** (attempts - 1)), max_delay_seconds)
-                # Full jitter (i.e. from 0 to exponential_delay) will help minimize the number and length of calls
-                jitter = random.uniform(0, exponential_delay)
-                delay = jitter
+                delay = self._calculate_backoff_delay(attempts)
                 log.warning(
-                    f"Attempt {attempts}/{MAX_UNPROCESSED_KEYS_RETRIES} - Retrying {len(cand_keys_list)} unprocessed keys after {delay:.2f}s delay."
+                    f"Attempt {attempts}/{MAX_UNPROCESSED_KEYS_RETRIES} - "
+                    f"Retrying {len(cand_keys_list)} unprocessed keys after {delay:.2f}s delay."
                 )
                 time.sleep(delay)
         if cand_keys_list:
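
For reviewers: a standalone sketch of the arithmetic the extracted helper implements, useful for eyeballing the delay schedule. The free function below is a hypothetical stand-in for `_calculate_backoff_delay` (which lives on the state store class) and mirrors only the formula visible in the diff above:

# Hypothetical stand-in for the _calculate_backoff_delay method,
# reproducing only the arithmetic from the diff above.
def calculate_backoff_delay(attempt: int) -> float:
    base_delay_seconds = 0.5
    max_delay_seconds = 10
    # Double the delay on each attempt (0.5s, 1s, 2s, ...), capped at 10s.
    return min(base_delay_seconds * (2 ** (attempt - 1)), max_delay_seconds)


for attempt in range(1, 7):
    print(f"attempt {attempt}: {calculate_backoff_delay(attempt):.2f}s")
# attempt 1: 0.50s, 2: 1.00s, 3: 2.00s, 4: 4.00s, 5: 8.00s, 6: 10.00s (capped)

Note that, unlike the inline code it replaces, the helper is deterministic (the full jitter via `random.uniform` is gone), so the schedule above is exactly reproducible in unit tests.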