Skip to content

Commit

Permalink
Add simple caching for dispersions
Browse files Browse the repository at this point in the history
  • Loading branch information
MarJMue committed Jul 18, 2024
1 parent e8a03ab commit d306edd
Show file tree
Hide file tree
Showing 2 changed files with 56 additions and 6 deletions.
50 changes: 50 additions & 0 deletions src/elli/dispersions/base_dispersion.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import numpy as np
import numpy.typing as npt
import pandas as pd
from lmfit import Parameter
from numpy.lib.scimath import sqrt

from .. import dispersions
Expand Down Expand Up @@ -44,6 +45,14 @@ def _guard_invalid_params(params1, params2):
missing_param_strings = ", ".join(f"{p}" for p in missing_params)
raise InvalidParameters(f"Invalid parameter(s): {missing_param_strings}")

@staticmethod
def _hash_params(params: dict | list[dict]) -> int:
"""Creates an single_params_dict or the repeating_params_list."""
if isinstance(params, list):
return hash(tuple([self._hash_params(dictionary) for dictionary in params]))
else:
return hash(tuple([item for _, item in params.items()]))

@staticmethod
def _fill_params_dict(template: dict, *args, **kwargs) -> dict:
BaseDispersion._guard_invalid_params(list(kwargs.keys()), list(template.keys()))
Expand All @@ -56,6 +65,8 @@ def _fill_params_dict(template: dict, *args, **kwargs) -> dict:

for i, val in enumerate(args):
key = list(template.keys())[i]
if isinstance(val, Parameter):
val = val.value
params[key] = val
pos_arguments.add(key)

Expand All @@ -64,6 +75,8 @@ def _fill_params_dict(template: dict, *args, **kwargs) -> dict:
raise InvalidParameters(
f"Parameter {key} already set by positional argument"
)
if isinstance(value, Parameter):
value = value.value
params[key] = value

return params
Expand All @@ -80,6 +93,10 @@ def __init__(self, *args, **kwargs):
if self.single_params[param] is None:
raise InvalidParameters(f"Please specify parameter {param}")

self.last_lbda = None
self.hash_single_params = None
self.hash_rep_params = None

@abstractmethod
def dielectric_function(self, lbda: npt.ArrayLike) -> npt.NDArray:
"""Calculates the dielectric function in a given wavelength window.
Expand Down Expand Up @@ -114,6 +131,39 @@ def get_dielectric(self, lbda: Optional[npt.ArrayLike] = None) -> npt.NDArray:
"""Returns the dielectric constant for wavelength 'lbda' default unit (nm)
in the convention ε1 + iε2."""
lbda = self.default_lbda_range if lbda is None else lbda

from .table_epsilon import TableEpsilon
from .table_index import Table

if not isinstance(self, (DispersionSum, IndexDispersionSum)):
if isinstance(self, (TableEpsilon, Table)):
if self.last_lbda is lbda:
return self.cached_diel
else:
self.last_lbda = lbda
self.cached_diel = np.asarray(
self.dielectric_function(lbda), dtype=np.complex128
)
return self.cached_diel
else:
new_single_hash = self._hash_params(self.single_params)
new_rep_hash = self._hash_params(self.rep_params)

if (
self.last_lbda is lbda
and self.hash_single_params == new_single_hash
and self.hash_rep_params == new_rep_hash
):
return self.cached_diel
else:
self.last_lbda = lbda
self.hash_single_params = new_single_hash
self.hash_rep_params = new_rep_hash
self.cached_diel = np.asarray(
self.dielectric_function(lbda), dtype=np.complex128
)
return self.cached_diel

return np.asarray(self.dielectric_function(lbda), dtype=np.complex128)

def get_refractive_index(self, lbda: Optional[npt.ArrayLike] = None) -> npt.NDArray:
Expand Down
12 changes: 6 additions & 6 deletions tests/benchmark_fitting.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,12 +94,12 @@ def test_fitting_structure_updates(benchmark, datadir):

@fit(psi_delta, params)
def model(lbda, params):
SiO2.single_params["n0"] = params["SiO2_n0"]
SiO2.single_params["n1"] = params["SiO2_n1"]
SiO2.single_params["n2"] = params["SiO2_n2"]
SiO2.single_params["k0"] = params["SiO2_k0"]
SiO2.single_params["k1"] = params["SiO2_k1"]
SiO2.single_params["k2"] = params["SiO2_k2"]
SiO2.single_params["n0"] = params["SiO2_n0"].value
SiO2.single_params["n1"] = params["SiO2_n1"].value
SiO2.single_params["n2"] = params["SiO2_n2"].value
SiO2.single_params["k0"] = params["SiO2_k0"].value
SiO2.single_params["k1"] = params["SiO2_k1"].value
SiO2.single_params["k2"] = params["SiO2_k2"].value

layer.set_thickness(params["SiO2_d"])

Expand Down

0 comments on commit d306edd

Please sign in to comment.