Addressing reviews and fixing unit tests
EmanElsaban committed Apr 18, 2024
1 parent a5c94e1 commit f4cd57e
Showing 3 changed files with 35 additions and 25 deletions.
tests/serialize/runstate/statemanager_test.py: 12 changes (4 additions, 8 deletions)
@@ -148,13 +148,10 @@ def test_restore_runs_for_job(self):
             "_restore_dicts",
             autospec=True,
         ) as mock_restore_dicts:
-            mock_restore_dicts.side_effect = [{"job_a.2": "two"}, {"job_a.3": "three"}]
+            mock_restore_dicts.side_effect = [{"job_a.2": "two", "job_a.3": "three"}]
             runs = self.manager._restore_runs_for_job("job_a", job_state)

-            assert mock_restore_dicts.call_args_list == [
-                mock.call(runstate.JOB_RUN_STATE, ["job_a.2"]),
-                mock.call(runstate.JOB_RUN_STATE, ["job_a.3"]),
-            ]
+            assert mock_restore_dicts.call_args_list == [mock.call(runstate.JOB_RUN_STATE, ["job_a.2", "job_a.3"])]
             assert runs == ["two", "three"]

     def test_restore_runs_for_job_one_missing(self):
@@ -164,12 +161,11 @@ def test_restore_runs_for_job_one_missing(self):
             "_restore_dicts",
             autospec=True,
         ) as mock_restore_dicts:
-            mock_restore_dicts.side_effect = [{}, {"job_a.3": "three"}]
+            mock_restore_dicts.side_effect = [{"job_a.3": "three", "job_b": {}}]
             runs = self.manager._restore_runs_for_job("job_a", job_state)

             assert mock_restore_dicts.call_args_list == [
-                mock.call(runstate.JOB_RUN_STATE, ["job_a.2"]),
-                mock.call(runstate.JOB_RUN_STATE, ["job_a.3"]),
+                mock.call(runstate.JOB_RUN_STATE, ["job_a.2", "job_a.3"]),
             ]
             assert runs == ["three"]

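For readers skimming the test changes, here is a small, self-contained illustration (hypothetical names, not part of the commit) of the behavior the updated assertions expect: a single `_restore_dicts`-style call that receives every run key at once and returns one dict, instead of one call per key.

```python
from unittest import mock

# Stand-in for the batched restore call: one call returns one dict of key -> state.
restore_dicts = mock.Mock(side_effect=[{"job_a.2": "two", "job_a.3": "three"}])

def restore_runs_for_job(job_name, run_nums):
    # Build every key first, then restore them in a single batched call.
    keys = [f"{job_name}.{n}" for n in run_nums]
    return list(restore_dicts("job_run_state", keys).values())

assert restore_runs_for_job("job_a", [2, 3]) == ["two", "three"]
# Exactly one call, carrying both keys, mirroring the updated test expectation.
assert restore_dicts.call_args_list == [mock.call("job_run_state", ["job_a.2", "job_a.3"])]
```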
tron/serialize/runstate/dynamodb_state_store.py: 28 changes (19 additions, 9 deletions)
@@ -1,4 +1,5 @@
 import concurrent.futures
+import copy
 import logging
 import math
 import os
@@ -15,6 +16,7 @@

 OBJECT_SIZE = 400000
 MAX_SAVE_QUEUE = 500
+MAX_ATTEMPTS = 11
 log = logging.getLogger(__name__)


@@ -49,33 +51,41 @@ def restore(self, keys) -> dict:
         return vals

     def chunk_keys(self, keys: list) -> list:
-        """Generates the cand keys list to be used to read from DynamoDB"""
+        """Generates a list of chunks of keys to be used to read from DynamoDB"""
         # have a for loop here for all the key chunks we want to go over
         cand_keys_chunks = []
         for i in range(0, len(keys), 100):
-            # chunks of 100 keys will be in this list
+            # chunks of at most 100 keys will be in this list as there could be smaller chunks
             cand_keys_chunks.append(keys[i : min(len(keys), i + 100)])
         return cand_keys_chunks

     def _get_items(self, table_keys: list) -> object:
         items = []
-        # precompute the cand_keys and then all we gotta do is submit stuff to the thread pool using the precomputed keys
-        cand_keys_list = self.chunk_keys(table_keys)
-        count = 0
-        while count < len(cand_keys_list):
+        # let's avoid potentially mutating our input :)
+        cand_keys_list = copy.copy(table_keys)
+        attempts_to_retrieve_keys = 0
+        while attempts_to_retrieve_keys < MAX_ATTEMPTS and len(cand_keys_list) != 0:
             with concurrent.futures.ThreadPoolExecutor(max_workers=5) as executor:
                 responses = [
                     executor.submit(
-                        self.client.batch_get_item, RequestItems={self.name: {"Keys": key, "ConsistentRead": True}}
+                        self.client.batch_get_item,
+                        RequestItems={self.name: {"Keys": chunked_keys, "ConsistentRead": True}},
                     )
-                    for key in cand_keys_list
+                    for chunked_keys in self.chunk_keys(cand_keys_list)
                 ]
+                # let's wipe the state so that we can loop back around
+                # if there are any un-processed keys
+                # NOTE: we'll re-chunk when submitting to the threadpool
+                # since it's possible that we've had several chunks fail
+                # enough keys that we'd otherwise send > 100 keys in a
+                # request otherwise
+                cand_keys_list = []
                 for resp in concurrent.futures.as_completed(responses):
                     items.extend(resp.result()["Responses"][self.name])
                     # add any potential unprocessed keys to the thread pool
                     if resp.result()["UnprocessedKeys"].get(self.name):
                         cand_keys_list.append(resp.result()["UnprocessedKeys"][self.name]["Keys"])
-            count += 1
+            attempts_to_retrieve_keys += 1
         return items

     def _get_first_partitions(self, keys: list):
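A simplified, single-threaded sketch of the retry loop introduced above (assumptions: a hypothetical table name and a standalone function rather than the class method, which runs its chunks through a ThreadPoolExecutor). The keys are re-chunked on every attempt so that carried-over UnprocessedKeys never push a single request past DynamoDB's 100-key BatchGetItem limit, and the loop gives up after MAX_ATTEMPTS rounds.

```python
import boto3

MAX_ATTEMPTS = 11
TABLE_NAME = "tron-state"  # hypothetical table name
client = boto3.client("dynamodb")

def get_items(table_keys: list) -> list:
    items = []
    remaining = list(table_keys)  # work on a copy so the caller's list is not mutated
    attempts = 0
    while remaining and attempts < MAX_ATTEMPTS:
        carried_over = []
        # Re-chunk every round: several partially failed chunks could otherwise
        # merge into a single request of more than 100 keys.
        for chunk in (remaining[i : i + 100] for i in range(0, len(remaining), 100)):
            resp = client.batch_get_item(
                RequestItems={TABLE_NAME: {"Keys": chunk, "ConsistentRead": True}}
            )
            items.extend(resp.get("Responses", {}).get(TABLE_NAME, []))
            unprocessed = resp.get("UnprocessedKeys", {}).get(TABLE_NAME)
            if unprocessed:
                carried_over.extend(unprocessed["Keys"])
        remaining = carried_over
        attempts += 1
    return items
```

Retrying is needed because batch_get_item can return partial results under throttling; a production version would typically also add exponential backoff between attempts.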
tron/serialize/runstate/statemanager.py: 20 changes (12 additions, 8 deletions)
@@ -1,4 +1,5 @@
 import concurrent.futures
+import copy
 import itertools
 import logging
 import time
@@ -165,6 +166,7 @@ def restore(self, job_names, skip_validation=False):
             }
             for result in concurrent.futures.as_completed(results):
                 jobs[results[result]]["runs"] = result.result()
+        # TODO: clean the Mesos code below
         frameworks = self._restore_dicts(runstate.MESOS_STATE, ["frameworks"])

         state = {
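As a side note on the context lines around the new TODO, the `results` mapping ties each submitted future back to its job name so that results can be attached as they complete, in whatever order they finish. A minimal standalone sketch of that pattern, with made-up job data and a stand-in worker function:

```python
import concurrent.futures

def restore_runs_for_job(job_name: str, job_state: dict) -> list:
    # Stand-in for the real per-job restore work.
    return [f"{job_name}.{num}" for num in job_state["run_nums"]]

jobs = {"job_a": {"run_nums": [2, 3]}, "job_b": {"run_nums": [1]}}
with concurrent.futures.ThreadPoolExecutor() as executor:
    # Map future -> job name so each result can be attached as it completes.
    results = {executor.submit(restore_runs_for_job, name, state): name for name, state in jobs.items()}
    for future in concurrent.futures.as_completed(results):
        jobs[results[future]]["runs"] = future.result()

assert jobs["job_a"]["runs"] == ["job_a.2", "job_a.3"]
```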
@@ -176,15 +178,17 @@ def _restore_runs_for_job(self, job_name, job_state):
     def _restore_runs_for_job(self, job_name, job_state):
         """Restore the state for the runs of each job"""
         run_nums = job_state["run_nums"]
-        runs = []
-        keys_ids_list = []
         # with self._lock:
-        for run_num in run_nums:
-            key = jobrun.get_job_run_id(job_name, run_num)
-            keys_ids_list.append(key)
-        run_state = list(self._restore_dicts(runstate.JOB_RUN_STATE, keys_ids_list).values())
-        runs.extend(run_state)
-        return runs
+        keys = [jobrun.get_job_run_id(job_name, run_num) for run_num in run_nums]
+
+        job_runs_restored_states = self._restore_dicts(runstate.JOB_RUN_STATE, keys)
+        run_state = copy.copy(job_runs_restored_states)
+        for key, value in run_state.items():
+            if value == {}:
+                log.error(f"Failed to restore {key}, no state found for it!")
+                job_runs_restored_states.pop(key)
+        run_state = list(job_runs_restored_states.values())
+        return run_state

     def _restore_metadata(self):
         metadata = self._impl.restore([self.metadata_key])
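One detail worth spelling out from the new `_restore_runs_for_job` body: the loop iterates over a copy of the restored mapping so entries can be popped from the original dict without tripping Python's "dictionary changed size during iteration" error. A tiny illustration with made-up data:

```python
import copy
import logging

log = logging.getLogger(__name__)

restored = {"job_a.2": {"run_num": 2}, "job_a.3": {}}  # {} means nothing came back for that key
# Iterate over a shallow copy; pop missing entries from the original mapping.
for key, value in copy.copy(restored).items():
    if value == {}:
        log.error(f"Failed to restore {key}, no state found for it!")
        restored.pop(key)

assert list(restored.values()) == [{"run_num": 2}]
```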
