Skip to content

Commit

Permalink
7.5/n: Set db_id on user-facing classes in update functions
Browse files Browse the repository at this point in the history
Summary: Set db ids in update, not just in save as done in D24601318

Reviewed By: ldworkin

Differential Revision: D24759873

fbshipit-source-id: f70b432abdc00eb50135ed09e8ad712c54ada73f
  • Loading branch information
lena-kashtelyan authored and facebook-github-bot committed Nov 12, 2020
1 parent 4faf98b commit 9bc0d13
Showing 1 changed file with 12 additions and 4 deletions.
16 changes: 12 additions & 4 deletions ax/storage/sqa_store/save.py
Original file line number Diff line number Diff line change
Expand Up @@ -237,6 +237,7 @@ def _update_trials(
"""Update trials and attach data."""
trial_sqa_class = encoder.config.class_to_sqa_class[Trial]
trial_indices = [trial.index for trial in trials]
obj_to_sqa = []
with session_scope() as session:
experiment_id = _get_experiment_id(
experiment=experiment, encoder=encoder, session=session
Expand All @@ -256,7 +257,8 @@ def _update_trials(
if existing_trial is None:
raise ValueError(f"Trial {trial.index} is not attached to the experiment.")

new_sqa_trial = encoder.trial_to_sqa(trial)
new_sqa_trial, _obj_to_sqa = encoder.trial_to_sqa(trial)
obj_to_sqa.extend(_obj_to_sqa)
existing_trial.update(new_sqa_trial)
updated_sqa_trials.append(existing_trial)

Expand All @@ -265,12 +267,16 @@ def _update_trials(
sqa_data = encoder.data_to_sqa(
data=data, trial_index=trial.index, timestamp=ts
)
obj_to_sqa.append((data, sqa_data))
sqa_data.experiment_id = experiment_id
new_sqa_data.append(sqa_data)

with session_scope() as session:
session.add_all(updated_sqa_trials)
session.add_all(new_sqa_data)
session.flush()

_set_db_ids(obj_to_sqa=obj_to_sqa)


def update_generation_strategy(
Expand All @@ -297,10 +303,11 @@ def _update_generation_strategy(
"""Update generation strategy's current step and attach generator runs."""
gs_sqa_class = encoder.config.class_to_sqa_class[GenerationStrategy]

gs_id = generation_strategy._db_id
gs_id = generation_strategy.db_id
if gs_id is None:
raise ValueError("GenerationStrategy must be saved before being updated.")

obj_to_sqa = []
with session_scope() as session:
experiment_id = _get_experiment_id(
experiment=generation_strategy.experiment, encoder=encoder, session=session
Expand All @@ -314,11 +321,12 @@ def _update_generation_strategy(

generator_runs_sqa = []
for generator_run in generator_runs:
gr_sqa = encoder.generator_run_to_sqa(generator_run=generator_run)
gr_sqa, _obj_to_sqa = encoder.generator_run_to_sqa(generator_run=generator_run)
obj_to_sqa.extend(_obj_to_sqa)
gr_sqa.generation_strategy_id = gs_id
generator_runs_sqa.append(gr_sqa)

with session_scope() as session:
session.add_all(generator_runs_sqa)

# TODO: db ids
_set_db_ids(obj_to_sqa=obj_to_sqa)

0 comments on commit 9bc0d13

Please sign in to comment.