From b0629c2b419ce91d50d8851705645e414055559f Mon Sep 17 00:00:00 2001 From: Andreas Pedersen Date: Fri, 12 Apr 2024 09:19:49 +0200 Subject: [PATCH] pr response --- EasyReflectometry/parameter_utils.py | 14 +++++++++++++- tests/experiment/test_model.py | 8 +++++++- 2 files changed, 20 insertions(+), 2 deletions(-) diff --git a/EasyReflectometry/parameter_utils.py b/EasyReflectometry/parameter_utils.py index 0805ef59..4dd0b085 100644 --- a/EasyReflectometry/parameter_utils.py +++ b/EasyReflectometry/parameter_utils.py @@ -5,12 +5,24 @@ from easyCore.Objects.ObjectClasses import Parameter -def get_as_parameter(value: Union[Parameter, Number, None], name, default_dict: dict[str, str]) -> Parameter: +def get_as_parameter(value: Union[Parameter, Number, None], name: str, default_dict: dict[str, str]) -> Parameter: + """ + This function creates a parameter for the variable `name`. A parameter has a value and metadata. + If the value already is a parameter, it is returned. + If the value is a number, a parameter is created with this value and metadata from the dictionary. + If the value is None, a parameter is created with the default value and metadata from the dictionary. + + param value: The value to use for the parameter. If None, the default value in the dictionary is used. + param name: The name of the parameter + param default_dict: Dictionary with entry for `name` containing the default value and metadata for the parameter + """ # Should leave the passed dictionary unchanged local_dict = deepcopy(default_dict) if value is None: + # Create a default parameter using both value and metadata from dictionary return Parameter(name, **local_dict[name]) elif isinstance(value, Number): + # Create a parameter using provided value and metadata from dictionary del local_dict[name]['value'] return Parameter(name, value, **local_dict[name]) elif not isinstance(value, Parameter): diff --git a/tests/experiment/test_model.py b/tests/experiment/test_model.py index c317f42c..ef966542 100644 --- a/tests/experiment/test_model.py +++ b/tests/experiment/test_model.py @@ -58,7 +58,13 @@ def test_from_pars(self): o2 = RepeatingMultilayer.from_pars(ls2, 1.0, 'oneLayerItem2') d = Sample.from_pars(o1, o2, name='myModel') resolution_function = constant_resolution_function(2.0) - mod = Model(d, 2, 1e-5, resolution_function, 'newModel') + mod = Model( + sample=d, + scale=2, + background=1e-5, + resolution_function=resolution_function, + name='newModel', + ) assert_equal(mod.name, 'newModel') assert_equal(mod.interface, None) assert_equal(mod.sample.name, 'myModel')