From d306edd51c2f515ea4cc67f0af4bd98b3a8d57db Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Marius=20M=C3=BCller?= <49639740+MarJMue@users.noreply.github.com> Date: Thu, 18 Jul 2024 14:19:17 +0200 Subject: [PATCH] Add simple caching for dispersions --- src/elli/dispersions/base_dispersion.py | 50 +++++++++++++++++++++++++ tests/benchmark_fitting.py | 12 +++--- 2 files changed, 56 insertions(+), 6 deletions(-) diff --git a/src/elli/dispersions/base_dispersion.py b/src/elli/dispersions/base_dispersion.py index baa2a36f..22ea9846 100644 --- a/src/elli/dispersions/base_dispersion.py +++ b/src/elli/dispersions/base_dispersion.py @@ -8,6 +8,7 @@ import numpy as np import numpy.typing as npt import pandas as pd +from lmfit import Parameter from numpy.lib.scimath import sqrt from .. import dispersions @@ -44,6 +45,14 @@ def _guard_invalid_params(params1, params2): missing_param_strings = ", ".join(f"{p}" for p in missing_params) raise InvalidParameters(f"Invalid parameter(s): {missing_param_strings}") + @staticmethod + def _hash_params(params: dict | list[dict]) -> int: + """Creates an single_params_dict or the repeating_params_list.""" + if isinstance(params, list): + return hash(tuple([self._hash_params(dictionary) for dictionary in params])) + else: + return hash(tuple([item for _, item in params.items()])) + @staticmethod def _fill_params_dict(template: dict, *args, **kwargs) -> dict: BaseDispersion._guard_invalid_params(list(kwargs.keys()), list(template.keys())) @@ -56,6 +65,8 @@ def _fill_params_dict(template: dict, *args, **kwargs) -> dict: for i, val in enumerate(args): key = list(template.keys())[i] + if isinstance(val, Parameter): + val = val.value params[key] = val pos_arguments.add(key) @@ -64,6 +75,8 @@ def _fill_params_dict(template: dict, *args, **kwargs) -> dict: raise InvalidParameters( f"Parameter {key} already set by positional argument" ) + if isinstance(value, Parameter): + value = value.value params[key] = value return params @@ -80,6 +93,10 @@ def __init__(self, *args, **kwargs): if self.single_params[param] is None: raise InvalidParameters(f"Please specify parameter {param}") + self.last_lbda = None + self.hash_single_params = None + self.hash_rep_params = None + @abstractmethod def dielectric_function(self, lbda: npt.ArrayLike) -> npt.NDArray: """Calculates the dielectric function in a given wavelength window. @@ -114,6 +131,39 @@ def get_dielectric(self, lbda: Optional[npt.ArrayLike] = None) -> npt.NDArray: """Returns the dielectric constant for wavelength 'lbda' default unit (nm) in the convention ε1 + iε2.""" lbda = self.default_lbda_range if lbda is None else lbda + + from .table_epsilon import TableEpsilon + from .table_index import Table + + if not isinstance(self, (DispersionSum, IndexDispersionSum)): + if isinstance(self, (TableEpsilon, Table)): + if self.last_lbda is lbda: + return self.cached_diel + else: + self.last_lbda = lbda + self.cached_diel = np.asarray( + self.dielectric_function(lbda), dtype=np.complex128 + ) + return self.cached_diel + else: + new_single_hash = self._hash_params(self.single_params) + new_rep_hash = self._hash_params(self.rep_params) + + if ( + self.last_lbda is lbda + and self.hash_single_params == new_single_hash + and self.hash_rep_params == new_rep_hash + ): + return self.cached_diel + else: + self.last_lbda = lbda + self.hash_single_params = new_single_hash + self.hash_rep_params = new_rep_hash + self.cached_diel = np.asarray( + self.dielectric_function(lbda), dtype=np.complex128 + ) + return self.cached_diel + return np.asarray(self.dielectric_function(lbda), dtype=np.complex128) def get_refractive_index(self, lbda: Optional[npt.ArrayLike] = None) -> npt.NDArray: diff --git a/tests/benchmark_fitting.py b/tests/benchmark_fitting.py index fd24c0a3..8d924131 100644 --- a/tests/benchmark_fitting.py +++ b/tests/benchmark_fitting.py @@ -94,12 +94,12 @@ def test_fitting_structure_updates(benchmark, datadir): @fit(psi_delta, params) def model(lbda, params): - SiO2.single_params["n0"] = params["SiO2_n0"] - SiO2.single_params["n1"] = params["SiO2_n1"] - SiO2.single_params["n2"] = params["SiO2_n2"] - SiO2.single_params["k0"] = params["SiO2_k0"] - SiO2.single_params["k1"] = params["SiO2_k1"] - SiO2.single_params["k2"] = params["SiO2_k2"] + SiO2.single_params["n0"] = params["SiO2_n0"].value + SiO2.single_params["n1"] = params["SiO2_n1"].value + SiO2.single_params["n2"] = params["SiO2_n2"].value + SiO2.single_params["k0"] = params["SiO2_k0"].value + SiO2.single_params["k1"] = params["SiO2_k1"].value + SiO2.single_params["k2"] = params["SiO2_k2"].value layer.set_thickness(params["SiO2_d"])