Skip to content

Commit

Permalink
6/n: Adjust decoding functions to set db_ids on user-facing classes u…
Browse files Browse the repository at this point in the history
…pon re-creation

Summary:
NOTE: All diffs in this stack will need to land together. Logic is interdependent, and I'm splitting the diffs for ease of review. It's not possible to split these perfectly such that all tests pass on each diff, so bear with me please : )

Specifically what is done in this diff I plan on replacing with a decorator once all these initial changes are in. I don't want to overcomplicate the work now by introducing that decorator –– I believe it will be easier to figure one out once `db_id` is correctly used and all tests are passing.

Reviewed By: ldworkin

Differential Revision: D24601249

fbshipit-source-id: 57f7ad421b8fe878e1755e3f14829edbaefa1ed2
  • Loading branch information
lena-kashtelyan authored and facebook-github-bot committed Nov 12, 2020
1 parent 0421acd commit e5d4795
Showing 1 changed file with 30 additions and 14 deletions.
44 changes: 30 additions & 14 deletions ax/storage/sqa_store/decoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -263,7 +263,7 @@ def experiment_from_sqa(
value=experiment_sqa.experiment_type, enum=self.config.experiment_type_enum
)
experiment._data_by_trial = dict(data_by_trial)

experiment.db_id = experiment_sqa.id
return experiment

def parameter_from_sqa(self, parameter_sqa: SQAParameter) -> Parameter:
Expand All @@ -273,7 +273,7 @@ def parameter_from_sqa(self, parameter_sqa: SQAParameter) -> Parameter:
raise SQADecodeError( # pragma: no cover
"`lower` and `upper` must be set for RangeParameter."
)
return RangeParameter(
parameter = RangeParameter(
name=parameter_sqa.name,
parameter_type=parameter_sqa.parameter_type,
# pyre-fixme[6]: Expected `float` for 3rd param but got
Expand All @@ -290,7 +290,7 @@ def parameter_from_sqa(self, parameter_sqa: SQAParameter) -> Parameter:
raise SQADecodeError( # pragma: no cover
"`values` must be set for ChoiceParameter."
)
return ChoiceParameter(
parameter = ChoiceParameter(
name=parameter_sqa.name,
parameter_type=parameter_sqa.parameter_type,
# pyre-fixme[6]: Expected `List[Optional[Union[bool, float, int,
Expand All @@ -303,7 +303,7 @@ def parameter_from_sqa(self, parameter_sqa: SQAParameter) -> Parameter:
elif parameter_sqa.domain_type == DomainType.FIXED:
# Don't throw an error if parameter_sqa.fixed_value is None;
# that might be the actual value!
return FixedParameter(
parameter = FixedParameter(
name=parameter_sqa.name,
parameter_type=parameter_sqa.parameter_type,
value=parameter_sqa.fixed_value,
Expand All @@ -316,6 +316,9 @@ def parameter_from_sqa(self, parameter_sqa: SQAParameter) -> Parameter:
"is an invalid domain type."
)

parameter.db_id = parameter_sqa.id
return parameter

def parameter_constraint_from_sqa(
self,
parameter_constraint_sqa: SQAParameterConstraint,
Expand All @@ -340,7 +343,7 @@ def parameter_constraint_from_sqa(
lower_parameter = parameter_map[lower_name]
# pyre-fixme[6]: Expected `str` for 1st param but got `None`.
upper_parameter = parameter_map[upper_name]
return OrderConstraint(
constraint = OrderConstraint(
lower_parameter=lower_parameter, upper_parameter=upper_parameter
)
elif parameter_constraint_sqa.type == ParameterConstraintType.SUM:
Expand All @@ -364,17 +367,20 @@ def parameter_constraint_from_sqa(
a = a_values[0]
is_upper_bound = a == 1
bound = parameter_constraint_sqa.bound * a
return SumConstraint(
constraint = SumConstraint(
parameters=constraint_parameters,
is_upper_bound=is_upper_bound,
bound=bound,
)
else:
return ParameterConstraint(
constraint = ParameterConstraint(
constraint_dict=dict(parameter_constraint_sqa.constraint_dict),
bound=parameter_constraint_sqa.bound,
)

constraint.db_id = parameter_constraint_sqa.id
return constraint

def search_space_from_sqa(
self,
parameters_sqa: List[SQAParameter],
Expand Down Expand Up @@ -414,6 +420,7 @@ def metric_from_sqa_util(self, metric_sqa: SQAMetric) -> Metric:
args["lower_is_better"] = metric_sqa.lower_is_better
args = metric_class.deserialize_init_args(args=args)
metric = metric_class(**args)
metric.db_id = metric_sqa.id
return metric

def metric_from_sqa(
Expand Down Expand Up @@ -570,7 +577,9 @@ def opt_config_and_tracking_metrics_from_sqa(

def arm_from_sqa(self, arm_sqa: SQAArm) -> Arm:
"""Convert SQLAlchemy Arm to Ax Arm."""
return Arm(parameters=arm_sqa.parameters, name=arm_sqa.name)
arm = Arm(parameters=arm_sqa.parameters, name=arm_sqa.name)
arm.db_id = arm_sqa.id
return arm

def abandoned_arm_from_sqa(
self, abandoned_arm_sqa: SQAAbandonedArm
Expand Down Expand Up @@ -671,6 +680,7 @@ def generator_run_from_sqa(
enum=self.config.generator_run_type_enum,
)
generator_run._index = generator_run_sqa.index
generator_run.db_id = generator_run_sqa.id
return generator_run

def generation_strategy_from_sqa(
Expand Down Expand Up @@ -717,7 +727,7 @@ def generation_strategy_from_sqa(
gs._restore_model_from_generator_run(models_enum=models_enum)
else:
gs._restore_model_from_generator_run()
gs._db_id = gs_sqa.id
gs.db_id = gs_sqa.id
return gs

def runner_from_sqa(self, runner_sqa: SQARunner) -> Runner:
Expand All @@ -729,8 +739,10 @@ def runner_from_sqa(self, runner_sqa: SQARunner) -> Runner:
f"is an invalid type."
)
args = runner_class.deserialize_init_args(args=runner_sqa.properties or {})
# pyre-fixme[45]: Cannot instantiate abstract class `Runner`.
return runner_class(**args)
# pyre-ignore[45]: Cannot instantiate abstract class `Runner`.
runner = runner_class(**args)
runner.db_id = runner_sqa.id
return runner

def trial_from_sqa(
self, trial_sqa: SQATrial, experiment: Experiment, reduced_state: bool = False
Expand Down Expand Up @@ -825,13 +837,17 @@ def trial_from_sqa(
)
trial._generation_step_index = trial_sqa.generation_step_index
trial._properties = trial_sqa.properties or {}
trial.db_id = trial_sqa.id
return trial

def data_from_sqa(self, data_sqa: SQAData) -> Data:
"""Convert SQLAlchemy Data to AE Data."""

# Need dtype=False, otherwise infers arm_names like "4_1" should be int 41
return Data(
dat = Data(
description=data_sqa.description,
# NOTE: Need dtype=False, otherwise infers arm_names like
# "4_1" should be int 41.
# pyre-fixme[16]: Module `pd` has no attribute `read_json`.
df=pd.read_json(data_sqa.data_json, dtype=False),
)
dat.db_id = data_sqa.id
return dat

0 comments on commit e5d4795

Please sign in to comment.