Skip to content

Commit

Permalink
132 add direct initialization as replacement for from pars (#133)
Browse files Browse the repository at this point in the history
* changing constructor

* constructor

* code cleaning

* tests

* pr response, test dict remains unchanged

* pr response
  • Loading branch information
andped10 authored Apr 12, 2024
1 parent a255a00 commit 87ddabd
Show file tree
Hide file tree
Showing 5 changed files with 184 additions and 81 deletions.
75 changes: 16 additions & 59 deletions EasyReflectometry/experiment/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,16 @@

__author__ = 'github.com/arm61'

from copy import deepcopy
from numbers import Number
from typing import Callable
from typing import Union

import yaml
from easyCore import np
from easyCore.Objects.ObjectClasses import BaseObj
from easyCore.Objects.ObjectClasses import Parameter

from EasyReflectometry.parameter_utils import get_as_parameter
from EasyReflectometry.sample import BaseAssembly
from EasyReflectometry.sample import Layer
from EasyReflectometry.sample import LayerCollection
Expand Down Expand Up @@ -53,10 +55,10 @@ class Model(BaseObj):

def __init__(
self,
sample: Sample,
scale: Parameter,
background: Parameter,
resolution_function: Callable[[np.array], float],
sample: Union[Sample, None] = None,
scale: Union[Parameter, Number, None] = None,
background: Union[Parameter, Number, None] = None,
resolution_function: Union[Callable[[np.array], float], None] = None,
name: str = 'EasyModel',
interface=None,
):
Expand All @@ -69,6 +71,15 @@ def __init__(
:param interface: Calculator interface, defaults to `None`.
"""

if sample is None:
sample = Sample.default()
if resolution_function is None:
resolution_function = constant_resolution_function(MODEL_DETAILS['resolution']['value'])

scale = get_as_parameter(scale, 'scale', MODEL_DETAILS)
background = get_as_parameter(background, 'background', MODEL_DETAILS)

super().__init__(
name=name,
sample=sample,
Expand All @@ -78,60 +89,6 @@ def __init__(
self.interface = interface
self._resolution_function = resolution_function

# Class methods for instance creation
@classmethod
def default(cls, interface=None) -> Model:
"""Default instance of the reflectometry experiment model.
:param interface: Calculator interface, defaults to `None`.
"""
sample = Sample.default()
scale = Parameter('scale', **MODEL_DETAILS['scale'])
background = Parameter('background', **MODEL_DETAILS['background'])
resolution_function = constant_resolution_function(MODEL_DETAILS['resolution']['value'])

return cls(
sample=sample,
scale=scale,
background=background,
resolution_function=resolution_function,
interface=interface,
)

@classmethod
def from_pars(
cls,
sample: Sample,
scale: float,
background: float,
resolution_function: Callable[[np.array], float],
name: str = 'EasyModel',
interface=None,
) -> Model:
"""Instance of a reflectometry experiment model where the parameters are known.
:param sample: The sample being modelled.
:param scale: Scaling factor of profile.
:param background: Linear background magnitude.
:param name: Name of the layer, defaults to 'EasyModel'.
:param interface: Calculator interface, defaults to `None`.
"""
default_options = deepcopy(MODEL_DETAILS)
del default_options['scale']['value']
del default_options['background']['value']

scale = Parameter('scale', scale, **default_options['scale'])
background = Parameter('background', background, **default_options['background'])

return cls(
sample=sample,
scale=scale,
background=background,
resolution_function=resolution_function,
name=name,
interface=interface,
)

def add_item(self, *assemblies: list[BaseAssembly]) -> None:
"""Add a layer or item to the model sample.
Expand Down
30 changes: 30 additions & 0 deletions EasyReflectometry/parameter_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
from copy import deepcopy
from numbers import Number
from typing import Union

from easyCore.Objects.ObjectClasses import Parameter


def get_as_parameter(value: Union[Parameter, Number, None], name: str, default_dict: dict[str, str]) -> Parameter:
"""
This function creates a parameter for the variable `name`. A parameter has a value and metadata.
If the value already is a parameter, it is returned.
If the value is a number, a parameter is created with this value and metadata from the dictionary.
If the value is None, a parameter is created with the default value and metadata from the dictionary.
param value: The value to use for the parameter. If None, the default value in the dictionary is used.
param name: The name of the parameter
param default_dict: Dictionary with entry for `name` containing the default value and metadata for the parameter
"""
# Should leave the passed dictionary unchanged
local_dict = deepcopy(default_dict)
if value is None:
# Create a default parameter using both value and metadata from dictionary
return Parameter(name, **local_dict[name])
elif isinstance(value, Number):
# Create a parameter using provided value and metadata from dictionary
del local_dict[name]['value']
return Parameter(name, value, **local_dict[name])
elif not isinstance(value, Parameter):
raise ValueError(f'{name} must be a Parameter, a number, or None.')
return value
48 changes: 27 additions & 21 deletions tests/experiment/test_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@

class TestModel(unittest.TestCase):
def test_default(self):
p = Model.default()
p = Model()
assert_equal(p.name, 'EasyModel')
assert_equal(p.interface, None)
assert_equal(p.sample.name, 'EasySample')
Expand Down Expand Up @@ -58,7 +58,13 @@ def test_from_pars(self):
o2 = RepeatingMultilayer.from_pars(ls2, 1.0, 'oneLayerItem2')
d = Sample.from_pars(o1, o2, name='myModel')
resolution_function = constant_resolution_function(2.0)
mod = Model.from_pars(d, 2, 1e-5, resolution_function, 'newModel')
mod = Model(
sample=d,
scale=2,
background=1e-5,
resolution_function=resolution_function,
name='newModel',
)
assert_equal(mod.name, 'newModel')
assert_equal(mod.interface, None)
assert_equal(mod.sample.name, 'myModel')
Expand Down Expand Up @@ -90,7 +96,7 @@ def test_add_item(self):
multilayer = Multilayer.default()
d = Sample.from_pars(o1, name='myModel')
resolution_function = constant_resolution_function(2.0)
mod = Model.from_pars(d, 2, 1e-5, resolution_function, 'newModel')
mod = Model(d, 2, 1e-5, resolution_function, 'newModel')
assert_equal(len(mod.sample), 1)
mod.add_item(o2)
assert_equal(len(mod.sample), 2)
Expand All @@ -103,7 +109,7 @@ def test_add_item(self):

def test_add_item_exception(self):
# When
mod = Model.default()
mod = Model()

# Then Expect
with pytest.raises(ValueError):
Expand All @@ -121,7 +127,7 @@ def test_add_item_with_interface_refnx(self):
o2 = RepeatingMultilayer.from_pars(ls2, 1.0, 'oneLayerItem2')
d = Sample.from_pars(o1, name='myModel')
resolution_function = constant_resolution_function(2.0)
mod = Model.from_pars(d, 2, 1e-5, resolution_function, 'newModel', interface=interface)
mod = Model(d, 2, 1e-5, resolution_function, 'newModel', interface=interface)
assert_equal(len(mod.interface()._wrapper.storage['item']), 1)
assert_equal(len(mod.interface()._wrapper.storage['layer']), 2)
mod.add_item(o2)
Expand All @@ -141,7 +147,7 @@ def test_add_item_with_interface_refl1d(self):
o2 = RepeatingMultilayer.from_pars(ls2, 1.0, 'oneLayerItem2')
d = Sample.from_pars(o1, name='myModel')
resolution_function = constant_resolution_function(2.0)
mod = Model.from_pars(d, 2, 1e-5, resolution_function, 'newModel', interface=interface)
mod = Model(d, 2, 1e-5, resolution_function, 'newModel', interface=interface)
assert_equal(len(mod.interface()._wrapper.storage['item']), 1)
assert_equal(len(mod.interface()._wrapper.storage['layer']), 2)
mod.add_item(o2)
Expand All @@ -160,7 +166,7 @@ def test_add_item_with_interface_refl1d(self):
# o1 = RepeatingMultilayer.from_pars(ls1, 2.0, 'twoLayerItem1')
# o2 = RepeatingMultilayer.from_pars(ls2, 1.0, 'oneLayerItem2')
# d = Sample.from_pars(o1, name='myModel')
# mod = Model.from_pars(d, 2, 1e-5, 2.0, 'newModel', interface=interface)
# mod = Model(d, 2, 1e-5, 2.0, 'newModel', interface=interface)
# assert_equal(len(mod.interface()._wrapper.storage['item']), 1)
# assert_equal(len(mod.interface()._wrapper.storage['layer']), 2)
# mod.add_item(o2)
Expand All @@ -178,7 +184,7 @@ def test_duplicate_item(self):
o2 = RepeatingMultilayer.from_pars(ls2, 1.0, 'oneLayerItem2')
d = Sample.from_pars(o1, name='myModel')
resolution_function = constant_resolution_function(2.0)
mod = Model.from_pars(d, 2, 1e-5, resolution_function, 'newModel')
mod = Model(d, 2, 1e-5, resolution_function, 'newModel')
assert_equal(len(mod.sample), 1)
mod.add_item(o2)
assert_equal(len(mod.sample), 2)
Expand All @@ -199,7 +205,7 @@ def test_duplicate_item_with_interface_refnx(self):
o2 = RepeatingMultilayer.from_pars(ls2, 1.0, 'oneLayerItem2')
d = Sample.from_pars(o1, name='myModel')
resolution_function = constant_resolution_function(2.0)
mod = Model.from_pars(d, 2, 1e-5, resolution_function, 'newModel', interface=interface)
mod = Model(d, 2, 1e-5, resolution_function, 'newModel', interface=interface)
assert_equal(len(mod.interface()._wrapper.storage['item']), 1)
mod.add_item(o2)
assert_equal(len(mod.interface()._wrapper.storage['item']), 2)
Expand All @@ -219,7 +225,7 @@ def test_duplicate_item_with_interface_refl1d(self):
o2 = RepeatingMultilayer.from_pars(ls2, 1.0, 'oneLayerItem2')
d = Sample.from_pars(o1, name='myModel')
resolution_function = constant_resolution_function(2.0)
mod = Model.from_pars(d, 2, 1e-5, resolution_function, 'newModel', interface=interface)
mod = Model(d, 2, 1e-5, resolution_function, 'newModel', interface=interface)
assert_equal(len(mod.interface()._wrapper.storage['item']), 1)
mod.add_item(o2)
assert_equal(len(mod.interface()._wrapper.storage['item']), 2)
Expand All @@ -238,7 +244,7 @@ def test_duplicate_item_with_interface_refl1d(self):
# o1 = RepeatingMultilayer.from_pars(ls1, 2.0, 'twoLayerItem1')
# o2 = RepeatingMultilayer.from_pars(ls2, 1.0, 'oneLayerItem2')
# d = Sample.from_pars(o1, name='myModel')
# mod = Model.from_pars(d, 2, 1e-5, 2.0, 'newModel', interface=interface)
# mod = Model(d, 2, 1e-5, 2.0, 'newModel', interface=interface)
# assert_equal(len(mod.interface()._wrapper.storage['item']), 1)
# mod.add_item(o2)
# assert_equal(len(mod.interface()._wrapper.storage['item']), 2)
Expand All @@ -256,7 +262,7 @@ def test_remove_item(self):
o2 = RepeatingMultilayer.from_pars(ls2, 1.0, 'oneLayerItem2')
d = Sample.from_pars(o1, name='myModel')
resolution_function = constant_resolution_function(2.0)
mod = Model.from_pars(d, 2, 1e-5, resolution_function, 'newModel')
mod = Model(d, 2, 1e-5, resolution_function, 'newModel')
assert_equal(len(mod.sample), 1)
mod.add_item(o2)
assert_equal(len(mod.sample), 2)
Expand All @@ -275,7 +281,7 @@ def test_remove_item_with_interface_refnx(self):
o2 = RepeatingMultilayer.from_pars(ls2, 1.0, 'oneLayerItem2')
d = Sample.from_pars(o1, name='myModel')
resolution_function = constant_resolution_function(2.0)
mod = Model.from_pars(d, 2, 1e-5, resolution_function, 'newModel', interface=interface)
mod = Model(d, 2, 1e-5, resolution_function, 'newModel', interface=interface)
assert_equal(len(mod.interface()._wrapper.storage['item']), 1)
assert_equal(len(mod.interface()._wrapper.storage['layer']), 2)
mod.add_item(o2)
Expand All @@ -296,9 +302,9 @@ def test_remove_item_with_interface_refl1d(self):
ls2 = LayerCollection.from_pars(l2, l1, name='twoLayer2')
o1 = RepeatingMultilayer.from_pars(ls1, 2.0, 'twoLayerItem1')
o2 = RepeatingMultilayer.from_pars(ls2, 1.0, 'oneLayerItem2')
d = Sample.from_pars(o1, name='myModel')
d = Sample(o1, name='myModel')
resolution_function = constant_resolution_function(2.0)
mod = Model.from_pars(d, 2, 1e-5, resolution_function, 'newModel', interface=interface)
mod = Model(d, 2, 1e-5, resolution_function, 'newModel', interface=interface)
assert_equal(len(mod.interface()._wrapper.storage['item']), 1)
assert_equal(len(mod.interface()._wrapper.storage['layer']), 2)
mod.add_item(o2)
Expand All @@ -320,7 +326,7 @@ def test_remove_item_with_interface_refl1d(self):
# o1 = RepeatingMultilayer.from_pars(ls1, 2.0, 'twoLayerItem1')
# o2 = RepeatingMultilayer.from_pars(ls2, 1.0, 'oneLayerItem2')
# d = Sample.from_pars(o1, name='myModel')
# mod = Model.from_pars(d, 2, 1e-5, 2.0, 'newModel', interface=interface)
# mod = Model(d, 2, 1e-5, 2.0, 'newModel', interface=interface)
# assert_equal(len(mod.interface()._wrapper.storage['item']), 1)
# assert_equal(len(mod.interface()._wrapper.storage['layer']), 2)
# mod.add_item(o2)
Expand All @@ -331,17 +337,17 @@ def test_remove_item_with_interface_refl1d(self):
# assert_equal(len(mod.interface()._wrapper.storage['layer']), 2)

def test_uid(self):
p = Model.default()
p = Model()
assert_equal(p.uid, p._borg.map.convert_id_to_key(p))

def test_set_resolution_function(self):
mock_resolution_function = MagicMock()
model = Model.default()
model = Model()
model.set_resolution_function(mock_resolution_function)
assert model._resolution_function == mock_resolution_function

def test_repr(self):
model = Model.default()
model = Model()

assert (
model.__repr__()
Expand All @@ -350,7 +356,7 @@ def test_repr(self):

def test_repr_resolution_function(self):
resolution_function = linear_spline_resolution_function([0, 10], [0, 10])
model = Model.default()
model = Model()
model.set_resolution_function(resolution_function)
assert (
model.__repr__()
Expand All @@ -360,7 +366,7 @@ def test_repr_resolution_function(self):
def test_dict_round_trip(self):
resolution_function = linear_spline_resolution_function([0, 10], [0, 10])
interface = CalculatorFactory()
model = Model.default(interface)
model = Model(interface=interface)
model.set_resolution_function(resolution_function)
surfactant = SurfactantLayer.default()
model.add_item(surfactant)
Expand Down
2 changes: 1 addition & 1 deletion tests/test_fitting.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ def test_fitting(self):
superphase,
name='Film Structure',
)
model = Model.from_pars(sample, 1, 1e-6, 0.02, 'Film Model')
model = Model(sample, 1, 1e-6, 0.02, 'Film Model')
# Thicknesses
sio2_layer.thickness.bounds = (15, 50)
film_layer.thickness.bounds = (200, 300)
Expand Down
Loading

0 comments on commit 87ddabd

Please sign in to comment.