Skip to content

Commit

Permalink
9/n: Temporarily clone generator runs in generation strategy to ensur…
Browse files Browse the repository at this point in the history
…e 1:1 relationship to SQL (#422)

Summary:
Pull Request resolved: #422

To land the stack of diffs that sets db_ids on all user-facing Ax objects that have a corresponding SQA object in the DB, we temporarily clone generator runs before returning them from the generation strategy, so that a single user-facing instance never points to two separate SQA objects. This will no longer be necessary after the next phase of the storage & perf project, taskified in T79183560.

Reviewed By: ldworkin

Differential Revision: D24786849

fbshipit-source-id: 4fd7cbf331fcbf1e871ed9f0f94b46f1385d551b
  • Loading branch information
lena-kashtelyan authored and facebook-github-bot committed Nov 12, 2020
1 parent 7943317 commit d5b7de1
Show file tree
Hide file tree
Showing 4 changed files with 24 additions and 19 deletions.
15 changes: 10 additions & 5 deletions ax/modelbridge/generation_strategy.py
Original file line number Diff line number Diff line change
Expand Up @@ -468,7 +468,12 @@ def _gen_multiple(
)

model = not_none(self.model)
generator_runs = []
# TODO[T79183560]: Cloning generator runs here is a temporary measure
# to ensure a 1-to-1 correspondence between user-facing generator runs
# and their stored SQL counterparts. This will no longer be needed soon
# as we move to use foreign keys to avoid storing generator runs on both
# experiment and generation strategy like we do now.
generator_run_clones = []
for _ in range(num_generator_runs):
try:
generator_run = model.gen(
Expand All @@ -479,17 +484,17 @@ def _gen_multiple(
),
)
generator_run._generation_step_index = self._curr.index
generator_runs.append(generator_run)
self._generator_runs.append(generator_run)
generator_run_clones.append(generator_run.clone())
except DataRequiredError as err:
# Model needs more data, so we log the error and return
# as many generator runs as we were able to produce, unless
# no trials were produced at all (in which case it's safe to raise).
if len(generator_runs) == 0:
if len(generator_run_clones) == 0:
raise
logger.debug(f"Model required more data: {err}.")

self._generator_runs.extend(generator_runs)
return generator_runs
return generator_run_clones

# ------------------------- Model selection logic helpers. -------------------------

Expand Down
4 changes: 3 additions & 1 deletion ax/service/ax_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -313,9 +313,11 @@ def get_next_trial(
trial=trial,
suppress_all_errors=self._suppress_storage_errors,
)
# TODO[T79183560]: Ensure correct handling of generator run when using
# foreign keys.
self._update_generation_strategy_in_db_if_possible(
generation_strategy=self.generation_strategy,
new_generator_runs=trial.generator_runs,
new_generator_runs=[self.generation_strategy._generator_runs[-1]],
suppress_all_errors=self._suppress_storage_errors,
)
return not_none(trial.arm).parameters, trial.index
Expand Down
1 change: 0 additions & 1 deletion ax/storage/sqa_store/decoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -846,7 +846,6 @@ def data_from_sqa(self, data_sqa: SQAData) -> Data:
description=data_sqa.description,
# NOTE: Need dtype=False, otherwise infers arm_names like
# "4_1" should be int 41.
# pyre-fixme[16]: Module `pd` has no attribute `read_json`.
df=pd.read_json(data_sqa.data_json, dtype=False),
)
dat.db_id = data_sqa.id
Expand Down
23 changes: 11 additions & 12 deletions ax/storage/sqa_store/tests/test_sqa_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -297,8 +297,7 @@ def testExperimentSaveAndLoadReducedState(
exp.trials.get(1).generator_run._model_state_after_gen = None
exp.trials.get(1).generator_run._search_space = None
exp.trials.get(1).generator_run._optimization_config = None
# TODO[D24786849]: bring back the check below
# self.assertEqual(loaded_experiment, exp)
self.assertEqual(loaded_experiment, exp)

def testMTExperimentSaveAndLoad(self):
experiment = get_multi_type_experiment(add_trials=True)
Expand Down Expand Up @@ -1181,20 +1180,22 @@ def testEncodeDecodeGenerationStrategyReducedStateLoadExperiment(self):
# Experiment should not be equal, since it would be loaded with reduced
# state along with the generation strategy.
self.assertNotEqual(new_generation_strategy.experiment, experiment)
# Adjust experiment to reduced state.
# Adjust experiment and GS to reduced state.
experiment.trials.get(0).generator_run._model_kwargs = None
experiment.trials.get(0).generator_run._bridge_kwargs = None
experiment.trials.get(0).generator_run._gen_metadata = None
experiment.trials.get(0).generator_run._model_state_after_gen = None
experiment.trials.get(0).generator_run._search_space = None
experiment.trials.get(0).generator_run._optimization_config = None
generation_strategy._generator_runs[0]._model_kwargs = None
generation_strategy._generator_runs[0]._bridge_kwargs = None
generation_strategy._generator_runs[0]._gen_metadata = None
generation_strategy._generator_runs[0]._model_state_after_gen = None
generation_strategy._generator_runs[0]._search_space = None
generation_strategy._generator_runs[0]._optimization_config = None
# Now experiment on generation strategy should be equal to the original
# experiment with reduced state.
# TODO[D24786849]: bring back the check below
# self.assertEqual(new_generation_strategy.experiment, experiment)
# `generation_strategy` shares its generator runs with `experiment`,
# so adjusting the generator run on experiment above also adjusted it
# for the GS; now the reloaded and the original GS-s should be equal.
self.assertEqual(new_generation_strategy.experiment, experiment)
self.assertEqual(new_generation_strategy, generation_strategy)
# Model should be successfully restored in generation strategy even with
# the reduced state.
Expand Down Expand Up @@ -1235,17 +1236,15 @@ def testUpdateGenerationStrategy(self):
# some recently added trials, so we update the mappings to match and check
# that the generation strategies are equal otherwise.
generation_strategy._seen_trial_indices_by_status[TrialStatus.CANDIDATE].add(1)
# TODO[D24786849]: bring back the check below
# self.assertEqual(generation_strategy, loaded_generation_strategy)
self.assertEqual(generation_strategy, loaded_generation_strategy)

# make sure that we can update the experiment too
experiment.description = "foobar"
save_experiment(experiment)
loaded_generation_strategy = load_generation_strategy_by_experiment_name(
experiment_name=experiment.name
)
# TODO[D24786849]: bring back the check below
# self.assertEqual(generation_strategy, loaded_generation_strategy)
self.assertEqual(generation_strategy, loaded_generation_strategy)
self.assertEqual(
generation_strategy._experiment.description, experiment.description
)
Expand Down

0 comments on commit d5b7de1

Please sign in to comment.