Skip to content

Commit

Permalink
Addressing more review comments and raising an exception when max_attempts is exceeded
Browse files: browse the repository at this point in the history
  • Loading branch information
EmanElsaban committed Apr 29, 2024
1 parent d057c31 commit ef981d3
Show file tree
Hide file tree
Showing 5 changed files with 20 additions and 25 deletions.
6 changes: 1 addition & 5 deletions tests/mcp_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,18 +156,14 @@ def teardown_mcp(self):
shutil.rmtree(self.working_dir)
shutil.rmtree(self.config_path)

@mock.patch("tron.mcp.MesosClusterRepository", autospec=True)
def test_restore_state(self, mock_cluster_repo):
def test_restore_state(self):
job_state_data = {"1": "things", "2": "things"}
mesos_state_data = {"3": "things", "4": "things"}
state_data = {
"mesos_state": mesos_state_data,
"job_state": job_state_data,
}
self.mcp.state_watcher.restore.return_value = state_data
action_runner = mock.Mock()
self.mcp.restore_state(action_runner)
mock_cluster_repo.restore_state.assert_called_with(mesos_state_data)
self.mcp.jobs.restore_state.assert_called_with(job_state_data, action_runner)


Expand Down
4 changes: 0 additions & 4 deletions tests/serialize/runstate/statemanager_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,23 +122,19 @@ def test_restore(self):
"one": {"key": "val1"},
"two": {"key": "val2"},
},
# _restore_dicts for MESOS_STATE
{"frameworks": "clusters"},
]

restored_state = self.manager.restore(job_names)
mock_restore_metadata.assert_called_once_with()
assert mock_restore_dicts.call_args_list == [
mock.call(runstate.JOB_STATE, job_names),
mock.call(runstate.MESOS_STATE, ["frameworks"]),
]
assert len(mock_restore_runs.call_args_list) == 2
assert restored_state == {
runstate.JOB_STATE: {
"one": {"key": "val1", "runs": mock_restore_runs.return_value},
"two": {"key": "val2", "runs": mock_restore_runs.return_value},
},
runstate.MESOS_STATE: {"frameworks": "clusters"},
}

def test_restore_runs_for_job(self):
Expand Down
1 change: 0 additions & 1 deletion tron/mcp.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,7 +174,6 @@ def restore_state(self, action_runner):
states = self.state_watcher.restore(self.jobs.get_names())
duration_restore = time.time() - start_time_restore_dynamodb
log.info(f"Time takes to state_watcher restore state directly from dynamodb: {duration_restore}")
MesosClusterRepository.restore_state(states.get("mesos_state", {}))
log.info(
f"Tron will start restoring state for the jobs and will start scheduling them! Time elapsed since Tron started {time.time() - self.boot_time}"
)
Expand Down
15 changes: 11 additions & 4 deletions tron/serialize/runstate/dynamodb_state_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,14 +9,15 @@
from collections import defaultdict
from collections import OrderedDict
from typing import DefaultDict
from typing import List

import boto3 # type: ignore

from tron.metrics import timer

OBJECT_SIZE = 400000
MAX_SAVE_QUEUE = 500
MAX_ATTEMPTS = 11
MAX_ATTEMPTS = 10
log = logging.getLogger(__name__)


Expand Down Expand Up @@ -50,7 +51,7 @@ def restore(self, keys) -> dict:
vals = self._merge_items(first_items, remaining_items)
return vals

def chunk_keys(self, keys: list) -> list:
def chunk_keys(self, keys: List[dict]) -> List[List[dict]]:
"""Generates a list of chunks of keys to be used to read from DynamoDB"""
# have a for loop here for all the key chunks we want to go over
cand_keys_chunks = []
Expand All @@ -64,7 +65,7 @@ def _get_items(self, table_keys: list) -> object:
# let's avoid potentially mutating our input :)
cand_keys_list = copy.copy(table_keys)
attempts_to_retrieve_keys = 0
while attempts_to_retrieve_keys < MAX_ATTEMPTS and len(cand_keys_list) != 0:
while len(cand_keys_list) != 0:
with concurrent.futures.ThreadPoolExecutor(max_workers=5) as executor:
responses = [
executor.submit(
Expand All @@ -83,8 +84,14 @@ def _get_items(self, table_keys: list) -> object:
for resp in concurrent.futures.as_completed(responses):
items.extend(resp.result()["Responses"][self.name])
# add any potential unprocessed keys to the thread pool
if resp.result()["UnprocessedKeys"].get(self.name):
if resp.result()["UnprocessedKeys"].get(self.name) and attempts_to_retrieve_keys < MAX_ATTEMPTS:
cand_keys_list.append(resp.result()["UnprocessedKeys"][self.name]["Keys"])
elif attempts_to_retrieve_keys >= MAX_ATTEMPTS:
failed_keys = resp.result()["UnprocessedKeys"][self.name]["Keys"]
error = Exception(
f"tron_dynamodb_restore_failure: failed to retrieve items with keys \n{failed_keys}\n from dynamodb\n{resp.result()}"
)
raise error
attempts_to_retrieve_keys += 1
return items

Expand Down
19 changes: 8 additions & 11 deletions tron/serialize/runstate/statemanager.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,12 +163,9 @@ def restore(self, job_names, skip_validation=False):
}
for result in concurrent.futures.as_completed(results):
jobs[results[result]]["runs"] = result.result()
# TODO: clean the Mesos code below
frameworks = self._restore_dicts(runstate.MESOS_STATE, ["frameworks"])

state = {
runstate.JOB_STATE: jobs,
runstate.MESOS_STATE: frameworks,
}
return state

Expand All @@ -177,16 +174,16 @@ def _restore_runs_for_job(self, job_name, job_state):
run_nums = job_state["run_nums"]
keys = [jobrun.get_job_run_id(job_name, run_num) for run_num in run_nums]
job_runs_restored_states = self._restore_dicts(runstate.JOB_RUN_STATE, keys)
run_state = copy.copy(job_runs_restored_states)
for key, value in run_state.items():
if value == {}:
log.error(f"Failed to restore {key}, no state found for it!")
job_runs_restored_states.pop(key)
runs = copy.copy(job_runs_restored_states)
for run_id, state in runs.items():
if state == {}:
log.error(f"Failed to restore {run_id}, no state found for it!")
job_runs_restored_states.pop(run_id)

run_state = list(job_runs_restored_states.values())
runs = list(job_runs_restored_states.values())
# We need to sort below otherwise the runs will not be in order
run_state.sort(key=lambda x: x["run_num"], reverse=True)
return run_state
runs.sort(key=lambda x: x["run_num"], reverse=True)
return runs

def _restore_metadata(self):
metadata = self._impl.restore([self.metadata_key])
Expand Down

0 comments on commit ef981d3

Please sign in to comment.