From 4faf98bc8447328706c1d7d1c4eb241e5b937df4 Mon Sep 17 00:00:00 2001 From: Lena Kashtelyan Date: Thu, 12 Nov 2020 08:09:01 -0800 Subject: [PATCH] 7/n: Set `db_id` on user-facing classes in saving functions Summary: NOTE: All diffs in this stack will need to land together. Logic is interdependent, and I'm splitting the diffs for ease of review. It's not possible to split these perfectly such that all tests pass on each diff, so bear with me please : ) In this diff, we actually set `db_id` on user-facing classes during `save_experiment` Reviewed By: ldworkin Differential Revision: D24601318 fbshipit-source-id: 8ccfdbe3f874ad72ecc3a957adcfa60c910755bc --- ax/storage/sqa_store/save.py | 37 +++++++++++++------- ax/storage/sqa_store/tests/test_sqa_store.py | 6 ++-- 2 files changed, 29 insertions(+), 14 deletions(-) diff --git a/ax/storage/sqa_store/save.py b/ax/storage/sqa_store/save.py index e84c50c6131..dd1472f07da 100644 --- a/ax/storage/sqa_store/save.py +++ b/ax/storage/sqa_store/save.py @@ -4,19 +4,29 @@ # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. -from typing import List, Optional +from typing import List, Optional, Tuple from ax.core.base_trial import BaseTrial from ax.core.experiment import Experiment from ax.core.generator_run import GeneratorRun from ax.core.trial import Trial from ax.modelbridge.generation_strategy import GenerationStrategy -from ax.storage.sqa_store.db import optional_session_scope, session_scope +from ax.storage.sqa_store.db import SQABase, optional_session_scope, session_scope from ax.storage.sqa_store.encoder import Encoder from ax.storage.sqa_store.sqa_config import SQAConfig +from ax.utils.common.base import Base +from ax.utils.common.typeutils import not_none from sqlalchemy.orm import Session +def _set_db_ids(obj_to_sqa: List[Tuple[Base, SQABase]]) -> None: + for obj, sqa_obj in obj_to_sqa: + # pyre-ignore[16]: `ax.storage.sqa_store.db.SQABase` has no attribute `id` + # (all classes in values of this mapping should, and if they don't, error + # should be raised) + obj.db_id = sqa_obj.id + + def save_experiment(experiment: Experiment, config: Optional[SQAConfig] = None) -> None: """Save experiment (using default SQAConfig).""" if not isinstance(experiment, Experiment): @@ -53,7 +63,7 @@ def _save_experiment(experiment: Experiment, encoder: Encoder) -> None: # got `Optional[ax.storage.sqa_store.db.SQABase]`. existing_sqa_experiment=existing_sqa_experiment, ) - new_sqa_experiment = encoder.experiment_to_sqa(experiment)[0] + new_sqa_experiment, obj_to_sqa = encoder.experiment_to_sqa(experiment) if existing_sqa_experiment is not None: # Update the SQA object outside of session scope to avoid timeouts. @@ -66,6 +76,9 @@ def _save_experiment(experiment: Experiment, encoder: Encoder) -> None: with session_scope() as session: session.add(new_sqa_experiment) + session.flush() + + _set_db_ids(obj_to_sqa=obj_to_sqa) def save_generation_strategy( @@ -104,7 +117,7 @@ def _save_generation_strategy( encoder=encoder, ) - gs_sqa, _ = encoder.generation_strategy_to_sqa( + gs_sqa, obj_to_sqa = encoder.generation_strategy_to_sqa( generation_strategy=generation_strategy, experiment_id=experiment_id ) @@ -129,10 +142,9 @@ def _save_generation_strategy( session.add(gs_sqa) session.flush() # Ensures generation strategy id is set. - # pyre-fixme[16]: `None` has no attribute `id`. - generation_strategy._db_id = gs_sqa.id - # pyre-fixme[7]: Expected `int` but got `Optional[int]`. - return generation_strategy._db_id + _set_db_ids(obj_to_sqa=obj_to_sqa) + + return not_none(generation_strategy.db_id) def _get_experiment_id( @@ -170,6 +182,7 @@ def _save_new_trials( ) -> None: """Add new trials to the experiment.""" trial_sqa_class = encoder.config.class_to_sqa_class[Trial] + obj_to_sqa = [] with session_scope() as session: experiment_id = _get_experiment_id( experiment=experiment, encoder=encoder, session=session @@ -191,15 +204,17 @@ def _save_new_trials( if trial.index in new_trial_idcs: raise ValueError(f"Trial {trial.index} appears in `trials` more than once.") - new_sqa_trial = encoder.trial_to_sqa(trial) + new_sqa_trial, _obj_to_sqa = encoder.trial_to_sqa(trial) + obj_to_sqa.extend(_obj_to_sqa) new_sqa_trial.experiment_id = experiment_id trials_sqa.append(new_sqa_trial) new_trial_idcs.add(trial.index) with session_scope() as session: session.add_all(trials_sqa) + session.flush() - # TODO: db ids + _set_db_ids(obj_to_sqa=obj_to_sqa) def update_trial( @@ -257,8 +272,6 @@ def _update_trials( session.add_all(updated_sqa_trials) session.add_all(new_sqa_data) - # TODO: db ids - def update_generation_strategy( generation_strategy: GenerationStrategy, diff --git a/ax/storage/sqa_store/tests/test_sqa_store.py b/ax/storage/sqa_store/tests/test_sqa_store.py index 04b9b29d071..e9bcab6986f 100644 --- a/ax/storage/sqa_store/tests/test_sqa_store.py +++ b/ax/storage/sqa_store/tests/test_sqa_store.py @@ -211,7 +211,9 @@ def testExperimentSaveAndLoad(self): get_experiment_with_multi_objective(), get_experiment_with_scalarized_objective(), ]: + self.assertIsNone(exp.db_id) save_experiment(exp) + self.assertIsNotNone(exp.db_id) loaded_experiment = load_experiment(exp.name) self.assertEqual(loaded_experiment, exp) @@ -1097,8 +1099,8 @@ def testEncodeDecodeGenerationStrategy(self): generation_strategy = get_generation_strategy(with_callable_model_kwarg=False) experiment.new_trial(generation_strategy.gen(experiment=experiment)) generation_strategy.gen(experiment, data=get_branin_data()) - save_generation_strategy(generation_strategy=generation_strategy) save_experiment(experiment) + save_generation_strategy(generation_strategy=generation_strategy) # Try restoring the generation strategy using the experiment its # attached to. new_generation_strategy = load_generation_strategy_by_experiment_name( @@ -1191,7 +1193,7 @@ def testEncodeDecodeGenerationStrategyReducedStateLoadExperiment(self): # `generation_strategy` shares its generator runs with `experiment`, # so adjusting the generator run on experiment above also adjusted it # for the GS; now the reloaded and the original GS-s should be equal. - self.assertEqual(generation_strategy, new_generation_strategy) + self.assertEqual(new_generation_strategy, generation_strategy) # Model should be successfully restored in generation strategy even with # the reduced state. self.assertIsInstance(new_generation_strategy._steps[0].model, Models)