Skip to content

Commit

Permalink
Add lambda filter for earthkit.data.Field objects (#16)
Browse files Browse the repository at this point in the history
* add earthkit.data.Field lambda filter and tests
  • Loading branch information
frazane authored Dec 18, 2024
1 parent 7b91806 commit 9e44cab
Show file tree
Hide file tree
Showing 3 changed files with 198 additions and 26 deletions.
114 changes: 114 additions & 0 deletions src/anemoi/transform/filters/lambda_filters.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,114 @@
# (C) Copyright 2024 Anemoi contributors.
#
# This software is licensed under the terms of the Apache Licence Version 2.0
# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
#
# In applying this licence, ECMWF does not waive the privileges and immunities
# granted to it by virtue of its status as an intergovernmental organisation
# nor does it submit to any jurisdiction.

import typing as tp
import importlib

from anemoi.transform.filters.base import SimpleFilter
from anemoi.transform.filters import filter_registry
from earthkit.data.core.fieldlist import Field, FieldList


@filter_registry.register("earthkitfieldlambda")
class EarthkitFieldLambdaFilter(SimpleFilter):
    """A filter to apply an arbitrary function to individual fields."""

    def __init__(
        self,
        fn: str | tp.Callable[[Field, tp.Any], Field],
        param: str | list[str],
        args: list | None = None,
        kwargs: dict[str, tp.Any] | None = None,
        backward_fn: str | tp.Callable[[Field, tp.Any], Field] | None = None,
    ):
        """Initialise the EarthkitFieldLambdaFilter.

        Parameters
        ----------
        fn: callable or str
            The lambda function as a callable with the general signature
            `fn(*earthkit.data.Field, *args, **kwargs) -> earthkit.data.Field` or
            a string path to the function, such as "package.module.function".
        param: list or str
            The parameter name or list of parameter names to apply the function to.
        args: list, optional
            The list of positional arguments to pass to the lambda function.
            Defaults to an empty list.
        kwargs: dict, optional
            The dictionary of keyword arguments to pass to the lambda function.
            Defaults to an empty dictionary.
        backward_fn: callable, str or None, optional
            The backward lambda function as a callable with the general signature
            `backward_fn(*earthkit.data.Field, *args, **kwargs) -> earthkit.data.Field` or
            a string path to the function, such as "package.module.function".
            When omitted, ``backward`` raises ``NotImplementedError``.

        Raises
        ------
        ValueError
            If ``args`` is not a list, if ``kwargs`` is not a dictionary, or if
            a function given as a string path cannot be imported.

        Examples
        --------
        >>> from anemoi.transform.filters.lambda_filters import EarthkitFieldLambdaFilter
        >>> import earthkit.data as ekd
        >>> fields = ekd.from_source(
        ...     "mars",{"param": ["2t"],
        ...     "levtype": "sfc",
        ...     "dates": ["2023-11-17 00:00:00"]})
        >>> kelvin_to_celsius = EarthkitFieldLambdaFilter(
        ...     fn=lambda x, s: x.clone(values=x.values - s),
        ...     param="2t",
        ...     args=[273.15],
        ... )
        >>> fields = kelvin_to_celsius.forward(fields)
        """
        # Use None sentinels instead of mutable defaults: a literal `[]` / `{}`
        # default is created once and would be shared by every instance.
        if args is None:
            args = []
        if kwargs is None:
            kwargs = {}

        if not isinstance(args, list):
            raise ValueError("Expected 'args' to be a list. "
                             f"Got {args} instead.")
        if not isinstance(kwargs, dict):
            raise ValueError("Expected 'kwargs' to be a dictionary. "
                             f"Got {kwargs} instead.")

        # Functions may be given either directly or as a dotted import path.
        self.fn = self._import_fn(fn) if isinstance(fn, str) else fn

        if isinstance(backward_fn, str):
            self.backward_fn = self._import_fn(backward_fn)
        else:
            self.backward_fn = backward_fn

        # Normalise a single parameter name to a one-element list.
        self.param = param if isinstance(param, list) else [param]
        self.args = args
        self.kwargs = kwargs

    def forward(self, data: FieldList) -> FieldList:
        """Apply ``fn`` to the fields of ``data`` selected by ``param``."""
        # `_transform` is not defined in this class; presumably inherited
        # from SimpleFilter.
        return self._transform(data, self.forward_transform, *self.param)

    def backward(self, data: FieldList) -> FieldList:
        """Apply ``backward_fn`` to the fields of ``data`` selected by ``param``.

        Raises
        ------
        NotImplementedError
            If no ``backward_fn`` was provided at construction time.
        """
        if self.backward_fn is None:
            raise NotImplementedError(f"{self} is not reversible.")
        return self._transform(data, self.backward_transform, *self.param)

    def forward_transform(self, *fields: Field) -> tp.Iterator[Field]:
        """Apply the lambda function to the field."""
        yield self.fn(*fields, *self.args, **self.kwargs)

    def backward_transform(self, *fields: Field) -> tp.Iterator[Field]:
        """Apply the backward lambda function to the field."""
        yield self.backward_fn(*fields, *self.args, **self.kwargs)

    def _import_fn(self, fn: str) -> tp.Callable[..., Field]:
        """Import and return the callable located at the dotted path ``fn``.

        Raises
        ------
        ValueError
            If the path has no module part, the module cannot be imported,
            the attribute does not exist, or the attribute is not callable.
        """
        try:
            module_name, fn_name = fn.rsplit(".", 1)
            module = importlib.import_module(module_name)
            imported = getattr(module, fn_name)
        except Exception as e:
            raise ValueError(f"Could not import function {fn}") from e
        # Fail early instead of with an obscure TypeError at transform time.
        if not callable(imported):
            raise ValueError(f"Could not import function {fn}: '{fn_name}' is not callable")
        return imported

    def __repr__(self):
        parts = [f"fn={self.fn},"]
        if self.backward_fn is not None:
            parts.append(f"backward_fn={self.backward_fn},")
        parts.append(f"param={self.param},")
        parts.append(f"args={self.args},")
        parts.append(f"kwargs={self.kwargs},")
        return f"{self.__class__.__name__}(" + "".join(parts) + ")"
Empty file added tests/__init__.py
Empty file.
110 changes: 84 additions & 26 deletions tests/test_filters.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,55 +7,113 @@
# granted to it by virtue of its status as an intergovernmental organisation
# nor does it submit to any jurisdiction.

import sys
from pathlib import Path

import numpy.testing as npt

from anemoi.transform.filters.rescale import Rescale, Convert
from anemoi.transform.filters.lambda_filters import EarthkitFieldLambdaFilter
import earthkit.data as ekd
from pytest import approx

# Add the directory two levels above this file (the parent of ``tests/``) to
# the import path so that the dotted path "tests.test_filters._do_something"
# used in test_singlefieldlambda can be imported.
sys.path.append(Path(__file__).parents[1].as_posix())

def test_rescale(fieldlist):
    """Check that Rescale converts "2t" K -> °C -> °F and back again.

    ``fieldlist`` must contain the "2t" parameter; only that parameter is used.
    """
    fieldlist = fieldlist.sel(param="2t")
    # rescale from K to °C
    k_to_deg = Rescale(scale=1.0, offset=-273.15, param="2t")
    rescaled = k_to_deg.forward(fieldlist)

    npt.assert_allclose(
        rescaled[0].to_numpy(),
        fieldlist[0].to_numpy() - 273.15,
    )
    # and back
    rescaled_back = k_to_deg.backward(rescaled)
    npt.assert_allclose(
        rescaled_back[0].to_numpy(),
        fieldlist[0].to_numpy(),
    )
    # rescale from °C to F
    deg_to_far = Rescale(scale=9 / 5, offset=32, param="2t")
    rescaled_fahrenheit = deg_to_far.forward(rescaled)
    npt.assert_allclose(
        rescaled_fahrenheit[0].to_numpy(),
        9 / 5 * rescaled[0].to_numpy() + 32,
    )
    # rescale from F to K
    rescaled_back = k_to_deg.backward(deg_to_far.backward(rescaled_fahrenheit))

    npt.assert_allclose(
        rescaled_back[0].to_numpy(),
        fieldlist[0].to_numpy(),
    )

def test_convert(fieldlist):
    """Check that Convert maps "2t" from K to degC and back again.

    ``fieldlist`` must contain the "2t" parameter; only that parameter is used.
    """
    # rescale from K to °C
    fieldlist = fieldlist.sel(param="2t")
    k_to_deg = Convert(unit_in="K", unit_out="degC", param="2t")
    rescaled = k_to_deg.forward(fieldlist)
    # Compare with a tolerance rather than exact float equality, as in
    # test_rescale.
    npt.assert_allclose(
        rescaled[0].to_numpy(),
        fieldlist[0].to_numpy() - 273.15,
    )
    # and back
    rescaled_back = k_to_deg.backward(rescaled)
    npt.assert_allclose(
        rescaled_back[0].to_numpy(),
        fieldlist[0].to_numpy(),
    )



# used in the test below
def _do_something(field, a):
    """Return a clone of ``field`` whose values are multiplied by ``a``."""
    scaled_values = field.values * a
    return field.clone(values=scaled_values)

def test_singlefieldlambda(fieldlist):
    """Round-trip the "sp" fields through an EarthkitFieldLambdaFilter.

    The forward function is given as a dotted import path and the backward
    function as a plain callable, so both ways of passing a function are
    exercised.
    """

    def undo_something(field, a):
        return field.clone(values=field.values / a)

    sp_fields = fieldlist.sel(param="sp")

    scale_filter = EarthkitFieldLambdaFilter(
        fn="tests.test_filters._do_something",
        param="sp",
        args=[10],
        backward_fn=undo_something,
    )

    # forward: values are multiplied by 10
    forward_result = scale_filter.forward(sp_fields)
    expected_forward = sp_fields[0].to_numpy() * 10
    npt.assert_allclose(forward_result[0].to_numpy(), expected_forward)

    # backward: the original values are recovered
    roundtrip = scale_filter.backward(forward_result)
    npt.assert_allclose(roundtrip[0].to_numpy(), sp_fields[0].to_numpy())



if __name__ == "__main__":
    # Retrieve the data once and share it between all tests: "2t" is used by
    # test_rescale / test_convert and "sp" by test_singlefieldlambda.
    fieldlist = ekd.from_source(
        "mars",
        {
            "param": ["2t", "sp"],
            "levtype": "sfc",
            "dates": ["2023-11-17 00:00:00"],
        },
    )

    test_rescale(fieldlist)
    try:
        test_convert(fieldlist)
    except FileNotFoundError:
        # Treated as a missing optional system dependency, not a test failure.
        print(
            "Skipping test_convert because of missing UNIDATA UDUNITS2 library, "
            "required by cfunits."
        )
    test_singlefieldlambda(fieldlist)

    print("All tests passed.")

0 comments on commit 9e44cab

Please sign in to comment.