Skip to content

Commit

Permalink
7/n: Set db_id on user-facing classes in saving functions
Browse files Browse the repository at this point in the history
Summary:
NOTE: All diffs in this stack will need to land together. Logic is interdependent, and I'm splitting the diffs for ease of review. It's not possible to split these perfectly such that all tests pass on each diff, so bear with me please : )

In this diff, we actually set `db_id` on user-facing classes during `save_experiment`

Reviewed By: ldworkin

Differential Revision: D24601318

fbshipit-source-id: 8ccfdbe3f874ad72ecc3a957adcfa60c910755bc
  • Loading branch information
lena-kashtelyan authored and facebook-github-bot committed Nov 12, 2020
1 parent e5d4795 commit 4faf98b
Show file tree
Hide file tree
Showing 2 changed files with 29 additions and 14 deletions.
37 changes: 25 additions & 12 deletions ax/storage/sqa_store/save.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,19 +4,29 @@
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from typing import List, Optional
from typing import List, Optional, Tuple

from ax.core.base_trial import BaseTrial
from ax.core.experiment import Experiment
from ax.core.generator_run import GeneratorRun
from ax.core.trial import Trial
from ax.modelbridge.generation_strategy import GenerationStrategy
from ax.storage.sqa_store.db import optional_session_scope, session_scope
from ax.storage.sqa_store.db import SQABase, optional_session_scope, session_scope
from ax.storage.sqa_store.encoder import Encoder
from ax.storage.sqa_store.sqa_config import SQAConfig
from ax.utils.common.base import Base
from ax.utils.common.typeutils import not_none
from sqlalchemy.orm import Session


def _set_db_ids(obj_to_sqa: List[Tuple[Base, SQABase]]) -> None:
for obj, sqa_obj in obj_to_sqa:
# pyre-ignore[16]: `ax.storage.sqa_store.db.SQABase` has no attribute `id`
# (all classes in values of this mapping should, and if they don't, error
# should be raised)
obj.db_id = sqa_obj.id


def save_experiment(experiment: Experiment, config: Optional[SQAConfig] = None) -> None:
"""Save experiment (using default SQAConfig)."""
if not isinstance(experiment, Experiment):
Expand Down Expand Up @@ -53,7 +63,7 @@ def _save_experiment(experiment: Experiment, encoder: Encoder) -> None:
# got `Optional[ax.storage.sqa_store.db.SQABase]`.
existing_sqa_experiment=existing_sqa_experiment,
)
new_sqa_experiment = encoder.experiment_to_sqa(experiment)[0]
new_sqa_experiment, obj_to_sqa = encoder.experiment_to_sqa(experiment)

if existing_sqa_experiment is not None:
# Update the SQA object outside of session scope to avoid timeouts.
Expand All @@ -66,6 +76,9 @@ def _save_experiment(experiment: Experiment, encoder: Encoder) -> None:

with session_scope() as session:
session.add(new_sqa_experiment)
session.flush()

_set_db_ids(obj_to_sqa=obj_to_sqa)


def save_generation_strategy(
Expand Down Expand Up @@ -104,7 +117,7 @@ def _save_generation_strategy(
encoder=encoder,
)

gs_sqa, _ = encoder.generation_strategy_to_sqa(
gs_sqa, obj_to_sqa = encoder.generation_strategy_to_sqa(
generation_strategy=generation_strategy, experiment_id=experiment_id
)

Expand All @@ -129,10 +142,9 @@ def _save_generation_strategy(
session.add(gs_sqa)
session.flush() # Ensures generation strategy id is set.

# pyre-fixme[16]: `None` has no attribute `id`.
generation_strategy._db_id = gs_sqa.id
# pyre-fixme[7]: Expected `int` but got `Optional[int]`.
return generation_strategy._db_id
_set_db_ids(obj_to_sqa=obj_to_sqa)

return not_none(generation_strategy.db_id)


def _get_experiment_id(
Expand Down Expand Up @@ -170,6 +182,7 @@ def _save_new_trials(
) -> None:
"""Add new trials to the experiment."""
trial_sqa_class = encoder.config.class_to_sqa_class[Trial]
obj_to_sqa = []
with session_scope() as session:
experiment_id = _get_experiment_id(
experiment=experiment, encoder=encoder, session=session
Expand All @@ -191,15 +204,17 @@ def _save_new_trials(
if trial.index in new_trial_idcs:
raise ValueError(f"Trial {trial.index} appears in `trials` more than once.")

new_sqa_trial = encoder.trial_to_sqa(trial)
new_sqa_trial, _obj_to_sqa = encoder.trial_to_sqa(trial)
obj_to_sqa.extend(_obj_to_sqa)
new_sqa_trial.experiment_id = experiment_id
trials_sqa.append(new_sqa_trial)
new_trial_idcs.add(trial.index)

with session_scope() as session:
session.add_all(trials_sqa)
session.flush()

# TODO: db ids
_set_db_ids(obj_to_sqa=obj_to_sqa)


def update_trial(
Expand Down Expand Up @@ -257,8 +272,6 @@ def _update_trials(
session.add_all(updated_sqa_trials)
session.add_all(new_sqa_data)

# TODO: db ids


def update_generation_strategy(
generation_strategy: GenerationStrategy,
Expand Down
6 changes: 4 additions & 2 deletions ax/storage/sqa_store/tests/test_sqa_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -211,7 +211,9 @@ def testExperimentSaveAndLoad(self):
get_experiment_with_multi_objective(),
get_experiment_with_scalarized_objective(),
]:
self.assertIsNone(exp.db_id)
save_experiment(exp)
self.assertIsNotNone(exp.db_id)
loaded_experiment = load_experiment(exp.name)
self.assertEqual(loaded_experiment, exp)

Expand Down Expand Up @@ -1097,8 +1099,8 @@ def testEncodeDecodeGenerationStrategy(self):
generation_strategy = get_generation_strategy(with_callable_model_kwarg=False)
experiment.new_trial(generation_strategy.gen(experiment=experiment))
generation_strategy.gen(experiment, data=get_branin_data())
save_generation_strategy(generation_strategy=generation_strategy)
save_experiment(experiment)
save_generation_strategy(generation_strategy=generation_strategy)
# Try restoring the generation strategy using the experiment its
# attached to.
new_generation_strategy = load_generation_strategy_by_experiment_name(
Expand Down Expand Up @@ -1191,7 +1193,7 @@ def testEncodeDecodeGenerationStrategyReducedStateLoadExperiment(self):
# `generation_strategy` shares its generator runs with `experiment`,
# so adjusting the generator run on experiment above also adjusted it
# for the GS; now the reloaded and the original GS-s should be equal.
self.assertEqual(generation_strategy, new_generation_strategy)
self.assertEqual(new_generation_strategy, generation_strategy)
# Model should be successfully restored in generation strategy even with
# the reduced state.
self.assertIsInstance(new_generation_strategy._steps[0].model, Models)
Expand Down

0 comments on commit 4faf98b

Please sign in to comment.