From e5d47950e2d44613c9f6fe514ac742bc2e819b32 Mon Sep 17 00:00:00 2001 From: Lena Kashtelyan Date: Thu, 12 Nov 2020 08:09:01 -0800 Subject: [PATCH] 6/n: Adjust decoding functions to set db_ids on user-facing classes upon re-creation MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Summary: NOTE: All diffs in this stack will need to land together. Logic is interdependent, and I'm splitting the diffs for ease of review. It's not possible to split these perfectly such that all tests pass on each diff, so bear with me please : ) Specifically what is done in this diff I plan on replacing with a decorator once all these initial changes are in. I don't want to overcomplicate the work now by introducing that decorator –– I believe it will be easier to figure one out once `db_id` is correctly used and all tests are passing. Reviewed By: ldworkin Differential Revision: D24601249 fbshipit-source-id: 57f7ad421b8fe878e1755e3f14829edbaefa1ed2 --- ax/storage/sqa_store/decoder.py | 44 ++++++++++++++++++++++----------- 1 file changed, 30 insertions(+), 14 deletions(-) diff --git a/ax/storage/sqa_store/decoder.py b/ax/storage/sqa_store/decoder.py index 3f05231da39..57fe20cbfb9 100644 --- a/ax/storage/sqa_store/decoder.py +++ b/ax/storage/sqa_store/decoder.py @@ -263,7 +263,7 @@ def experiment_from_sqa( value=experiment_sqa.experiment_type, enum=self.config.experiment_type_enum ) experiment._data_by_trial = dict(data_by_trial) - + experiment.db_id = experiment_sqa.id return experiment def parameter_from_sqa(self, parameter_sqa: SQAParameter) -> Parameter: @@ -273,7 +273,7 @@ def parameter_from_sqa(self, parameter_sqa: SQAParameter) -> Parameter: raise SQADecodeError( # pragma: no cover "`lower` and `upper` must be set for RangeParameter." ) - return RangeParameter( + parameter = RangeParameter( name=parameter_sqa.name, parameter_type=parameter_sqa.parameter_type, # pyre-fixme[6]: Expected `float` for 3rd param but got @@ -290,7 +290,7 @@ def parameter_from_sqa(self, parameter_sqa: SQAParameter) -> Parameter: raise SQADecodeError( # pragma: no cover "`values` must be set for ChoiceParameter." ) - return ChoiceParameter( + parameter = ChoiceParameter( name=parameter_sqa.name, parameter_type=parameter_sqa.parameter_type, # pyre-fixme[6]: Expected `List[Optional[Union[bool, float, int, @@ -303,7 +303,7 @@ def parameter_from_sqa(self, parameter_sqa: SQAParameter) -> Parameter: elif parameter_sqa.domain_type == DomainType.FIXED: # Don't throw an error if parameter_sqa.fixed_value is None; # that might be the actual value! - return FixedParameter( + parameter = FixedParameter( name=parameter_sqa.name, parameter_type=parameter_sqa.parameter_type, value=parameter_sqa.fixed_value, @@ -316,6 +316,9 @@ def parameter_from_sqa(self, parameter_sqa: SQAParameter) -> Parameter: "is an invalid domain type." ) + parameter.db_id = parameter_sqa.id + return parameter + def parameter_constraint_from_sqa( self, parameter_constraint_sqa: SQAParameterConstraint, @@ -340,7 +343,7 @@ def parameter_constraint_from_sqa( lower_parameter = parameter_map[lower_name] # pyre-fixme[6]: Expected `str` for 1st param but got `None`. upper_parameter = parameter_map[upper_name] - return OrderConstraint( + constraint = OrderConstraint( lower_parameter=lower_parameter, upper_parameter=upper_parameter ) elif parameter_constraint_sqa.type == ParameterConstraintType.SUM: @@ -364,17 +367,20 @@ def parameter_constraint_from_sqa( a = a_values[0] is_upper_bound = a == 1 bound = parameter_constraint_sqa.bound * a - return SumConstraint( + constraint = SumConstraint( parameters=constraint_parameters, is_upper_bound=is_upper_bound, bound=bound, ) else: - return ParameterConstraint( + constraint = ParameterConstraint( constraint_dict=dict(parameter_constraint_sqa.constraint_dict), bound=parameter_constraint_sqa.bound, ) + constraint.db_id = parameter_constraint_sqa.id + return constraint + def search_space_from_sqa( self, parameters_sqa: List[SQAParameter], @@ -414,6 +420,7 @@ def metric_from_sqa_util(self, metric_sqa: SQAMetric) -> Metric: args["lower_is_better"] = metric_sqa.lower_is_better args = metric_class.deserialize_init_args(args=args) metric = metric_class(**args) + metric.db_id = metric_sqa.id return metric def metric_from_sqa( @@ -570,7 +577,9 @@ def opt_config_and_tracking_metrics_from_sqa( def arm_from_sqa(self, arm_sqa: SQAArm) -> Arm: """Convert SQLAlchemy Arm to Ax Arm.""" - return Arm(parameters=arm_sqa.parameters, name=arm_sqa.name) + arm = Arm(parameters=arm_sqa.parameters, name=arm_sqa.name) + arm.db_id = arm_sqa.id + return arm def abandoned_arm_from_sqa( self, abandoned_arm_sqa: SQAAbandonedArm @@ -671,6 +680,7 @@ def generator_run_from_sqa( enum=self.config.generator_run_type_enum, ) generator_run._index = generator_run_sqa.index + generator_run.db_id = generator_run_sqa.id return generator_run def generation_strategy_from_sqa( @@ -717,7 +727,7 @@ def generation_strategy_from_sqa( gs._restore_model_from_generator_run(models_enum=models_enum) else: gs._restore_model_from_generator_run() - gs._db_id = gs_sqa.id + gs.db_id = gs_sqa.id return gs def runner_from_sqa(self, runner_sqa: SQARunner) -> Runner: @@ -729,8 +739,10 @@ def runner_from_sqa(self, runner_sqa: SQARunner) -> Runner: f"is an invalid type." ) args = runner_class.deserialize_init_args(args=runner_sqa.properties or {}) - # pyre-fixme[45]: Cannot instantiate abstract class `Runner`. - return runner_class(**args) + # pyre-ignore[45]: Cannot instantiate abstract class `Runner`. + runner = runner_class(**args) + runner.db_id = runner_sqa.id + return runner def trial_from_sqa( self, trial_sqa: SQATrial, experiment: Experiment, reduced_state: bool = False @@ -825,13 +837,17 @@ def trial_from_sqa( ) trial._generation_step_index = trial_sqa.generation_step_index trial._properties = trial_sqa.properties or {} + trial.db_id = trial_sqa.id return trial def data_from_sqa(self, data_sqa: SQAData) -> Data: """Convert SQLAlchemy Data to AE Data.""" - - # Need dtype=False, otherwise infers arm_names like "4_1" should be int 41 - return Data( + dat = Data( description=data_sqa.description, + # NOTE: Need dtype=False, otherwise infers arm_names like + # "4_1" should be int 41. + # pyre-fixme[16]: Module `pd` has no attribute `read_json`. df=pd.read_json(data_sqa.data_json, dtype=False), ) + dat.db_id = data_sqa.id + return dat