diff --git a/ax/modelbridge/generation_strategy.py b/ax/modelbridge/generation_strategy.py index ee0f93b684f..c2807b29ea3 100644 --- a/ax/modelbridge/generation_strategy.py +++ b/ax/modelbridge/generation_strategy.py @@ -468,7 +468,12 @@ def _gen_multiple( ) model = not_none(self.model) - generator_runs = [] + # TODO[T79183560]: Cloning generator runs here is a temporary measure + # to ensure a 1-to-1 correspondence between user-facing generator runs + # and their stored SQL counterparts. This will be no longer needed soon + # as we move to use foreign keys to avoid storing generotor runs on both + # experiment and generation strategy like we do now. + generator_run_clones = [] for _ in range(num_generator_runs): try: generator_run = model.gen( @@ -479,17 +484,17 @@ def _gen_multiple( ), ) generator_run._generation_step_index = self._curr.index - generator_runs.append(generator_run) + self._generator_runs.append(generator_run) + generator_run_clones.append(generator_run.clone()) except DataRequiredError as err: # Model needs more data, so we log the error and return # as many generator runs as we were able to produce, unless # no trials were produced at all (in which case its safe to raise). - if len(generator_runs) == 0: + if len(generator_run_clones) == 0: raise logger.debug(f"Model required more data: {err}.") - self._generator_runs.extend(generator_runs) - return generator_runs + return generator_run_clones # ------------------------- Model selection logic helpers. ------------------------- diff --git a/ax/service/ax_client.py b/ax/service/ax_client.py index 7f999b877e3..34dc4b9d4a3 100644 --- a/ax/service/ax_client.py +++ b/ax/service/ax_client.py @@ -313,9 +313,11 @@ def get_next_trial( trial=trial, suppress_all_errors=self._suppress_storage_errors, ) + # TODO[T79183560]: Ensure correct handling of generator run when using + # foreign keys. self._update_generation_strategy_in_db_if_possible( generation_strategy=self.generation_strategy, - new_generator_runs=trial.generator_runs, + new_generator_runs=[self.generation_strategy._generator_runs[-1]], suppress_all_errors=self._suppress_storage_errors, ) return not_none(trial.arm).parameters, trial.index diff --git a/ax/storage/sqa_store/decoder.py b/ax/storage/sqa_store/decoder.py index 57fe20cbfb9..6fd67bb1ff3 100644 --- a/ax/storage/sqa_store/decoder.py +++ b/ax/storage/sqa_store/decoder.py @@ -846,7 +846,6 @@ def data_from_sqa(self, data_sqa: SQAData) -> Data: description=data_sqa.description, # NOTE: Need dtype=False, otherwise infers arm_names like # "4_1" should be int 41. - # pyre-fixme[16]: Module `pd` has no attribute `read_json`. df=pd.read_json(data_sqa.data_json, dtype=False), ) dat.db_id = data_sqa.id diff --git a/ax/storage/sqa_store/tests/test_sqa_store.py b/ax/storage/sqa_store/tests/test_sqa_store.py index 72da12700d7..e6376599ab6 100644 --- a/ax/storage/sqa_store/tests/test_sqa_store.py +++ b/ax/storage/sqa_store/tests/test_sqa_store.py @@ -297,8 +297,7 @@ def testExperimentSaveAndLoadReducedState( exp.trials.get(1).generator_run._model_state_after_gen = None exp.trials.get(1).generator_run._search_space = None exp.trials.get(1).generator_run._optimization_config = None - # TODO[D24786849]: bring back the check below - # self.assertEqual(loaded_experiment, exp) + self.assertEqual(loaded_experiment, exp) def testMTExperimentSaveAndLoad(self): experiment = get_multi_type_experiment(add_trials=True) @@ -1181,20 +1180,22 @@ def testEncodeDecodeGenerationStrategyReducedStateLoadExperiment(self): # Experiment should not be equal, since it would be loaded with reduced # state along with the generation strategy. self.assertNotEqual(new_generation_strategy.experiment, experiment) - # Adjust experiment to reduced state. + # Adjust experiment and GS to reduced state. experiment.trials.get(0).generator_run._model_kwargs = None experiment.trials.get(0).generator_run._bridge_kwargs = None experiment.trials.get(0).generator_run._gen_metadata = None experiment.trials.get(0).generator_run._model_state_after_gen = None experiment.trials.get(0).generator_run._search_space = None experiment.trials.get(0).generator_run._optimization_config = None + generation_strategy._generator_runs[0]._model_kwargs = None + generation_strategy._generator_runs[0]._bridge_kwargs = None + generation_strategy._generator_runs[0]._gen_metadata = None + generation_strategy._generator_runs[0]._model_state_after_gen = None + generation_strategy._generator_runs[0]._search_space = None + generation_strategy._generator_runs[0]._optimization_config = None # Now experiment on generation strategy should be equal to the original # experiment with reduced state. - # TODO[D24786849]: bring back the check below - # self.assertEqual(new_generation_strategy.experiment, experiment) - # `generation_strategy` shares its generator runs with `experiment`, - # so adjusting the generator run on experiment above also adjusted it - # for the GS; now the reloaded and the original GS-s should be equal. + self.assertEqual(new_generation_strategy.experiment, experiment) self.assertEqual(new_generation_strategy, generation_strategy) # Model should be successfully restored in generation strategy even with # the reduced state. @@ -1235,8 +1236,7 @@ def testUpdateGenerationStrategy(self): # some recently added trials, so we update the mappings to match and check # that the generation strategies are equal otherwise. generation_strategy._seen_trial_indices_by_status[TrialStatus.CANDIDATE].add(1) - # TODO[D24786849]: bring back the check below - # self.assertEqual(generation_strategy, loaded_generation_strategy) + self.assertEqual(generation_strategy, loaded_generation_strategy) # make sure that we can update the experiment too experiment.description = "foobar" @@ -1244,8 +1244,7 @@ def testUpdateGenerationStrategy(self): loaded_generation_strategy = load_generation_strategy_by_experiment_name( experiment_name=experiment.name ) - # TODO[D24786849]: bring back the check below - # self.assertEqual(generation_strategy, loaded_generation_strategy) + self.assertEqual(generation_strategy, loaded_generation_strategy) self.assertEqual( generation_strategy._experiment.description, experiment.description )