From 9e44cab6e331c125cc6e8cac37b49e4bdd6d49c5 Mon Sep 17 00:00:00 2001
From: Francesco Zanetta <62377868+frazane@users.noreply.github.com>
Date: Wed, 18 Dec 2024 10:57:58 +0100
Subject: [PATCH] Add lambda filter for earthkit.data.Field objects (#16)

* add earthkit.data.Field lambda filter and tests
---
Review notes (ignored by `git am`, not part of the commit message):
- The line structure of this patch was restored from a whitespace-collapsed
  copy. Indentation and the position of the blank context lines in the
  tests/test_filters.py hunk are reconstructed, not verified against the
  original blobs; regenerate with `git format-patch` if `git am` rejects it.
- lambda_filters.py: `args` / `kwargs` now default to None instead of the
  shared mutable defaults `[]` / `{}`, and the file ends with a newline.
  The blob id on its `index` line predates these two edits.

 .../transform/filters/lambda_filters.py       | 118 ++++++++++++++++++
 tests/__init__.py                             |   0
 tests/test_filters.py                         | 110 +++++++++++++----
 3 files changed, 202 insertions(+), 26 deletions(-)
 create mode 100644 src/anemoi/transform/filters/lambda_filters.py
 create mode 100644 tests/__init__.py

diff --git a/src/anemoi/transform/filters/lambda_filters.py b/src/anemoi/transform/filters/lambda_filters.py
new file mode 100644
index 0000000..edb33cc
--- /dev/null
+++ b/src/anemoi/transform/filters/lambda_filters.py
@@ -0,0 +1,118 @@
+# (C) Copyright 2024 Anemoi contributors.
+#
+# This software is licensed under the terms of the Apache Licence Version 2.0
+# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
+#
+# In applying this licence, ECMWF does not waive the privileges and immunities
+# granted to it by virtue of its status as an intergovernmental organisation
+# nor does it submit to any jurisdiction.
+
+import typing as tp
+import importlib
+
+from anemoi.transform.filters.base import SimpleFilter
+from anemoi.transform.filters import filter_registry
+from earthkit.data.core.fieldlist import Field, FieldList
+
+
+@filter_registry.register("earthkitfieldlambda")
+class EarthkitFieldLambdaFilter(SimpleFilter):
+    """A filter to apply an arbitrary function to individual fields."""
+
+    def __init__(
+        self,
+        fn: str | tp.Callable[[Field, tp.Any], Field],
+        param: str | list[str],
+        args: list | None = None,
+        kwargs: dict[str, tp.Any] | None = None,
+        backward_fn: str | tp.Callable[[Field, tp.Any], Field] | None = None,
+    ):
+        """Initialise the EarthkitFieldLambdaFilter.
+
+        Parameters
+        ----------
+        fn: callable or str
+            The lambda function as a callable with the general signature
+            `fn(*earthkit.data.Field, *args, **kwargs) -> earthkit.data.Field` or
+            a string path to the function, such as "package.module.function".
+        param: list or str
+            The parameter name or list of parameter names to apply the function to.
+        args: list, optional
+            The list of arguments to pass to the lambda function (default: none).
+        kwargs: dict, optional
+            The dictionary of keyword arguments to pass to the lambda function (default: none).
+        backward_fn (optional): callable, str or None
+            The backward lambda function as a callable with the general signature
+            `backward_fn(*earthkit.data.Field, *args, **kwargs) -> earthkit.data.Field` or
+            a string path to the function, such as "package.module.function".
+
+        Examples
+        --------
+        >>> from anemoi.transform.filters.lambda_filters import EarthkitFieldLambdaFilter
+        >>> import earthkit.data as ekd
+        >>> fields = ekd.from_source(
+        ...             "mars",{"param": ["2t"],
+        ...                     "levtype": "sfc",
+        ...                     "dates": ["2023-11-17 00:00:00"]})
+        >>> kelvin_to_celsius = EarthkitFieldLambdaFilter(
+        ...                 fn=lambda x, s: x.clone(values=x.values - s),
+        ...                 param="2t",
+        ...                 args=[273.15],
+        ...             )
+        >>> fields = kelvin_to_celsius.forward(fields)
+        """
+        # None means "no extra arguments": building the container here gives
+        # each filter its own object instead of one shared mutable default.
+        args = [] if args is None else args
+        kwargs = {} if kwargs is None else kwargs
+        if not isinstance(args, list):
+            raise ValueError("Expected 'args' to be a list. "
+                             f"Got {args} instead.")
+        if not isinstance(kwargs, dict):
+            raise ValueError("Expected 'kwargs' to be a dictionary. "
+                             f"Got {kwargs} instead.")
+
+        self.fn = self._import_fn(fn) if isinstance(fn, str) else fn
+
+        if isinstance(backward_fn, str):
+            self.backward_fn = self._import_fn(backward_fn)
+        else:
+            self.backward_fn = backward_fn
+
+        self.param = param if isinstance(param, list) else [param]
+        self.args = args
+        self.kwargs = kwargs
+
+    def forward(self, data: FieldList) -> FieldList:
+        return self._transform(data, self.forward_transform, *self.param)
+
+    def backward(self, data: FieldList) -> FieldList:
+        if self.backward_fn:
+            return self._transform(data, self.backward_transform, *self.param)
+        raise NotImplementedError(f"{self} is not reversible.")
+
+    def forward_transform(self, *fields: Field) -> tp.Iterator[Field]:
+        """Apply the lambda function to the field."""
+        yield self.fn(*fields, *self.args, **self.kwargs)
+
+    def backward_transform(self, *fields: Field) -> tp.Iterator[Field]:
+        """Apply the backward lambda function to the field."""
+        yield self.backward_fn(*fields, *self.args, **self.kwargs)
+
+    def _import_fn(self, fn: str) -> tp.Callable[..., Field]:
+        try:
+            module_name, fn_name = fn.rsplit(".", 1)
+            module = importlib.import_module(module_name)
+            return getattr(module, fn_name)
+        except Exception as e:
+            raise ValueError(f"Could not import function {fn}") from e
+
+    def __repr__(self):
+        out = f"{self.__class__.__name__}(fn={self.fn},"
+        if self.backward_fn:
+            out += f"backward_fn={self.backward_fn},"
+        out += f"param={self.param},"
+        out += f"args={self.args},"
+        out += f"kwargs={self.kwargs},"
+        out += ")"
+        return out
diff --git a/tests/__init__.py b/tests/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/tests/test_filters.py b/tests/test_filters.py
index 0351907..ddcbb99 100644
--- a/tests/test_filters.py
+++ b/tests/test_filters.py
@@ -7,55 +7,113 @@
 # granted to it by virtue of its status as an intergovernmental organisation
 # nor does it submit to any jurisdiction.
 
+import sys
+from pathlib import Path
+
+import numpy.testing as npt
 from anemoi.transform.filters.rescale import Rescale, Convert
+from anemoi.transform.filters.lambda_filters import EarthkitFieldLambdaFilter
 import earthkit.data as ekd
 from pytest import approx
 
+sys.path.append(Path(__file__).parents[1].as_posix())
 
-def test_rescale():
+def test_rescale(fieldlist):
+    fieldlist = fieldlist.sel(param="2t")
     # rescale from K to °C
-    temp = ekd.from_source(
-        "mars", {"param": "2t", "levtype": "sfc", "dates": ["2023-11-17 00:00:00"]}
-    )
-    fieldlist = temp.to_fieldlist()
     k_to_deg = Rescale(scale=1.0, offset=-273.15, param="2t")
     rescaled = k_to_deg.forward(fieldlist)
-    assert rescaled[0].values.min() == temp.values.min() - 273.15
-    assert rescaled[0].values.std() == approx(temp.values.std())
+
+    npt.assert_allclose(
+        rescaled[0].to_numpy(),
+        fieldlist[0].to_numpy() - 273.15
+    )
     # and back
     rescaled_back = k_to_deg.backward(rescaled)
-    assert rescaled_back[0].values.min() == temp.values.min()
-    assert rescaled_back[0].values.std() == approx(temp.values.std())
+    npt.assert_allclose(
+        rescaled_back[0].to_numpy(),
+        fieldlist[0].to_numpy()
+    )
     # rescale from °C to F
     deg_to_far = Rescale(scale=9 / 5, offset=32, param="2t")
     rescaled_farheneit = deg_to_far.forward(rescaled)
-    assert rescaled_farheneit[0].values.min() == 9 / 5 * rescaled[0].values.min() + 32
-    assert rescaled_farheneit[0].values.std() == approx(
-        (9 / 5) * rescaled[0].values.std()
+    npt.assert_allclose(
+        rescaled_farheneit[0].to_numpy(),
+        9 / 5 * rescaled[0].to_numpy() + 32
     )
     # rescale from F to K
     rescaled_back = k_to_deg.backward(deg_to_far.backward(rescaled_farheneit))
-    assert rescaled_back[0].values.min() == temp.values.min()
-    assert rescaled_back[0].values.std() == approx(temp.values.std())
-
+    npt.assert_allclose(
+        rescaled_back[0].to_numpy(),
+        fieldlist[0].to_numpy()
+    )
 
-def test_convert():
+def test_convert(fieldlist):
     # rescale from K to °C
-    temp = ekd.from_source(
-        "mars", {"param": "2t", "levtype": "sfc", "dates": ["2023-11-17 00:00:00"]}
-    )
-    fieldlist = temp.to_fieldlist()
+    fieldlist = fieldlist.sel(param="2t")
     k_to_deg = Convert(unit_in="K", unit_out="degC", param="2t")
     rescaled = k_to_deg.forward(fieldlist)
-    assert rescaled[0].values.min() == temp.values.min() - 273.15
-    assert rescaled[0].values.std() == approx(temp.values.std())
+    assert rescaled[0].values.min() == fieldlist.values.min() - 273.15
+    assert rescaled[0].values.std() == approx(fieldlist.values.std())
     # and back
     rescaled_back = k_to_deg.backward(rescaled)
-    assert rescaled_back[0].values.min() == temp.values.min()
-    assert rescaled_back[0].values.std() == approx(temp.values.std())
+    assert rescaled_back[0].values.min() == fieldlist.values.min()
+    assert rescaled_back[0].values.std() == approx(fieldlist.values.std())
+
+
+
+# used in the test below
+def _do_something(field, a):
+    return field.clone(values=field.values * a)
+
+def test_singlefieldlambda(fieldlist):
+
+    fieldlist = fieldlist.sel(param="sp")
+
+    def undo_something(field, a):
+        return field.clone(values=field.values / a)
+
+    something = EarthkitFieldLambdaFilter(
+        fn="tests.test_filters._do_something",
+        param="sp",
+        args=[10],
+        backward_fn=undo_something,
+    )
+
+    transformed = something.forward(fieldlist)
+    npt.assert_allclose(
+        transformed[0].to_numpy(),
+        fieldlist[0].to_numpy() * 10
+    )
+
+    untransformed = something.backward(transformed)
+    npt.assert_allclose(
+        untransformed[0].to_numpy(),
+        fieldlist[0].to_numpy()
+    )
+
 
 
 
 if __name__ == "__main__":
-    test_rescale()
-    test_convert()
+
+    fieldlist = ekd.from_source(
+        "mars",
+        {
+            "param": ["2t", "sp"],
+            "levtype": "sfc",
+            "dates": ["2023-11-17 00:00:00"],
+        },
+    )
+
+    test_rescale(fieldlist)
+    try:
+        test_convert(fieldlist)
+    except FileNotFoundError:
+        print(
+            "Skipping test_convert because of missing UNIDATA UDUNITS2 library, "
+            "required by cfunits."
+        )
+    test_singlefieldlambda(fieldlist)
+
+    print("All tests passed.")