Skip to content

Commit 60d2219

Browse files
committed
Merge remote-tracking branch 'origin/develop' into feature/regrid
2 parents f031721 + 9e44cab commit 60d2219

File tree

6 files changed

+334
-27
lines changed

6 files changed

+334
-27
lines changed

CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ Keep it human-readable, your future self will thank you!
1313
### Added
1414

1515
- Add regrid filter
16+
- Added repeat-member #18
1617

1718
## [0.1.0](https://github.com/ecmwf/anemoi-utils/transform/0.0.5...HEAD/compare/0.0.8...0.1.0) - 2024-11-18
1819

Lines changed: 114 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,114 @@
1+
# (C) Copyright 2024 Anemoi contributors.
2+
#
3+
# This software is licensed under the terms of the Apache Licence Version 2.0
4+
# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
5+
#
6+
# In applying this licence, ECMWF does not waive the privileges and immunities
7+
# granted to it by virtue of its status as an intergovernmental organisation
8+
# nor does it submit to any jurisdiction.
9+
10+
import importlib
11+
import typing as tp
12+
13+
from earthkit.data.core.fieldlist import Field
14+
from earthkit.data.core.fieldlist import FieldList
15+
16+
from anemoi.transform.filters import filter_registry
17+
from anemoi.transform.filters.base import SimpleFilter
18+
19+
20+
@filter_registry.register("earthkitfieldlambda")
21+
class EarthkitFieldLambdaFilter(SimpleFilter):
22+
"""A filter to apply an arbitrary function to individual fields."""
23+
24+
def __init__(
25+
self,
26+
fn: str | tp.Callable[[Field, tp.Any], Field],
27+
param: str | list[str],
28+
args: list = [],
29+
kwargs: dict[str, tp.Any] = {},
30+
backward_fn: str | tp.Callable[[Field, tp.Any], Field] | None = None,
31+
):
32+
"""Initialise the EarthkitFieldLambdaFilter.
33+
34+
Parameters
35+
----------
36+
fn: callable or str
37+
The lambda function as a callable with the general signature
38+
`fn(*earthkit.data.Field, *args, **kwargs) -> earthkit.data.Field` or
39+
a string path to the function, such as "package.module.function".
40+
param: list or str
41+
The parameter name or list of parameter names to apply the function to.
42+
args: list
43+
The list of arguments to pass to the lambda function.
44+
kwargs: dict
45+
The dictionary of keyword arguments to pass to the lambda function.
46+
backward_fn (optional): callable, str or None
47+
The backward lambda function as a callable with the general signature
48+
`backward_fn(*earthkit.data.Field, *args, **kwargs) -> earthkit.data.Field` or
49+
a string path to the function, such as "package.module.function".
50+
51+
Examples
52+
--------
53+
>>> from anemoi.transform.filters.lambda_filters import EarthkitFieldLambdaFilter
54+
>>> import earthkit.data as ekd
55+
>>> fields = ekd.from_source(
56+
... "mars",{"param": ["2t"],
57+
... "levtype": "sfc",
58+
... "dates": ["2023-11-17 00:00:00"]})
59+
>>> kelvin_to_celsius = EarthkitFieldLambdaFilter(
60+
... fn=lambda x, s: x.clone(values=x.values - s),
61+
... param="2t",
62+
... args=[273.15],
63+
... )
64+
>>> fields = kelvin_to_celsius.forward(fields)
65+
"""
66+
if not isinstance(args, list):
67+
raise ValueError("Expected 'args' to be a list. " f"Got {args} instead.")
68+
if not isinstance(kwargs, dict):
69+
raise ValueError("Expected 'kwargs' to be a dictionary. " f"Got {kwargs} instead.")
70+
71+
self.fn = self._import_fn(fn) if isinstance(fn, str) else fn
72+
73+
if isinstance(backward_fn, str):
74+
self.backward_fn = self._import_fn(backward_fn)
75+
else:
76+
self.backward_fn = backward_fn
77+
78+
self.param = param if isinstance(param, list) else [param]
79+
self.args = args
80+
self.kwargs = kwargs
81+
82+
def forward(self, data: FieldList) -> FieldList:
83+
return self._transform(data, self.forward_transform, *self.param)
84+
85+
def backward(self, data: FieldList) -> FieldList:
86+
if self.backward_fn:
87+
return self._transform(data, self.backward_transform, *self.param)
88+
raise NotImplementedError(f"{self} is not reversible.")
89+
90+
def forward_transform(self, *fields: Field) -> tp.Iterator[Field]:
91+
"""Apply the lambda function to the field."""
92+
yield self.fn(*fields, *self.args, **self.kwargs)
93+
94+
def backward_transform(self, *fields: Field) -> tp.Iterator[Field]:
95+
"""Apply the backward lambda function to the field."""
96+
yield self.backward_fn(*fields, *self.args, **self.kwargs)
97+
98+
def _import_fn(self, fn: str) -> tp.Callable[..., Field]:
99+
try:
100+
module_name, fn_name = fn.rsplit(".", 1)
101+
module = importlib.import_module(module_name)
102+
return getattr(module, fn_name)
103+
except Exception as e:
104+
raise ValueError(f"Could not import function {fn}") from e
105+
106+
def __repr__(self):
107+
out = f"{self.__class__.__name__}(fn={self.fn},"
108+
if self.backward_fn:
109+
out += f"backward_fn={self.backward_fn},"
110+
out += f"param={self.param},"
111+
out += f"args={self.args},"
112+
out += f"kwargs={self.kwargs},"
113+
out += ")"
114+
return out
Lines changed: 85 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,85 @@
1+
# (C) Copyright 2024 Anemoi contributors.
2+
#
3+
# This software is licensed under the terms of the Apache Licence Version 2.0
4+
# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
5+
#
6+
# In applying this licence, ECMWF does not waive the privileges and immunities
7+
# granted to it by virtue of its status as an intergovernmental organisation
8+
# nor does it submit to any jurisdiction.
9+
10+
11+
import logging
12+
13+
from ..fields import new_field_from_numpy
14+
from ..fields import new_fieldlist_from_list
15+
from . import filter_registry
16+
from .base import Filter
17+
18+
LOG = logging.getLogger(__name__)
19+
20+
21+
def make_list_int(value):
22+
if isinstance(value, str):
23+
if "/" not in value:
24+
return [value]
25+
bits = value.split("/")
26+
if len(bits) == 3 and bits[1].lower() == "to":
27+
value = list(range(int(bits[0]), int(bits[2]) + 1, 1))
28+
29+
elif len(bits) == 5 and bits[1].lower() == "to" and bits[3].lower() == "by":
30+
value = list(range(int(bits[0]), int(bits[2]) + int(bits[4]), int(bits[4])))
31+
32+
if isinstance(value, list):
33+
return value
34+
if isinstance(value, tuple):
35+
return value
36+
if isinstance(value, int):
37+
return [value]
38+
39+
raise ValueError(f"Cannot make list from {value}")
40+
41+
42+
@filter_registry.register("repeat_members")
43+
class RepeatMembers(Filter):
44+
"""The filter can be used to replicate non-ensembles fields into ensemble fields.
45+
46+
Args: (only one of the following)
47+
numbers: A list of numbers (1-based) of the fields to replicate.
48+
members: A list of 0-based indices of the fields to replicate.
49+
count: The number of times to replicate the fields.
50+
"""
51+
52+
def __init__(
53+
self,
54+
numbers=None, # 1-based
55+
members=None, # 0-based
56+
count=None,
57+
):
58+
if sum(x is not None for x in (members, count, numbers)) != 1:
59+
raise ValueError("Exactly one of members, count or numbers must be given")
60+
61+
if numbers is not None:
62+
numbers = make_list_int(numbers)
63+
members = [n - 1 for n in numbers]
64+
65+
if count is not None:
66+
members = list(range(count))
67+
68+
members = make_list_int(members)
69+
self.members = members
70+
assert isinstance(members, (tuple, list)), f"members must be a list or tuple, got {type(members)}"
71+
72+
def forward(self, data):
73+
result = []
74+
for f in data:
75+
array = f.to_numpy()
76+
for member in self.members:
77+
number = member + 1
78+
new_field = new_field_from_numpy(array, template=f, number=number)
79+
result.append(new_field)
80+
81+
return new_fieldlist_from_list(result)
82+
83+
def backward(self, data):
84+
# this could be implemented
85+
raise NotImplementedError("RepeatMembers is not reversible")

tests/__init__.py

Whitespace-only changes.

tests/test_filters.py

Lines changed: 63 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -7,51 +7,87 @@
77
# granted to it by virtue of its status as an intergovernmental organisation
88
# nor does it submit to any jurisdiction.
99

10+
import sys
11+
from pathlib import Path
1012

1113
import earthkit.data as ekd
14+
import numpy.testing as npt
1215
from pytest import approx
1316

17+
from anemoi.transform.filters.lambda_filters import EarthkitFieldLambdaFilter
1418
from anemoi.transform.filters.rescale import Convert
1519
from anemoi.transform.filters.rescale import Rescale
1620

21+
sys.path.append(Path(__file__).parents[1].as_posix())
1722

18-
def test_rescale():
23+
24+
def test_rescale(fieldlist):
25+
fieldlist = fieldlist.sel(param="2t")
1926
# rescale from K to °C
20-
temp = ekd.from_source("mars", {"param": "2t", "levtype": "sfc", "dates": ["2023-11-17 00:00:00"]})
21-
fieldlist = temp.to_fieldlist()
2227
k_to_deg = Rescale(scale=1.0, offset=-273.15, param="2t")
2328
rescaled = k_to_deg.forward(fieldlist)
24-
assert rescaled[0].values.min() == temp.values.min() - 273.15
25-
assert rescaled[0].values.std() == approx(temp.values.std())
29+
30+
npt.assert_allclose(rescaled[0].to_numpy(), fieldlist[0].to_numpy() - 273.15)
2631
# and back
2732
rescaled_back = k_to_deg.backward(rescaled)
28-
assert rescaled_back[0].values.min() == temp.values.min()
29-
assert rescaled_back[0].values.std() == approx(temp.values.std())
30-
# rescale from °C to F
31-
deg_to_far = Rescale(scale=9 / 5, offset=32, param="2t")
32-
rescaled_farheneit = deg_to_far.forward(rescaled)
33-
assert rescaled_farheneit[0].values.min() == 9 / 5 * rescaled[0].values.min() + 32
34-
assert rescaled_farheneit[0].values.std() == approx((9 / 5) * rescaled[0].values.std())
35-
# rescale from F to K
36-
rescaled_back = k_to_deg.backward(deg_to_far.backward(rescaled_farheneit))
37-
assert rescaled_back[0].values.min() == temp.values.min()
38-
assert rescaled_back[0].values.std() == approx(temp.values.std())
39-
40-
41-
def test_convert():
33+
npt.assert_allclose(rescaled_back[0].to_numpy(), fieldlist[0].to_numpy())
34+
35+
36+
def test_convert(fieldlist):
4237
# rescale from K to °C
43-
temp = ekd.from_source("mars", {"param": "2t", "levtype": "sfc", "dates": ["2023-11-17 00:00:00"]})
44-
fieldlist = temp.to_fieldlist()
38+
fieldlist = fieldlist.sel(param="2t")
4539
k_to_deg = Convert(unit_in="K", unit_out="degC", param="2t")
4640
rescaled = k_to_deg.forward(fieldlist)
47-
assert rescaled[0].values.min() == temp.values.min() - 273.15
48-
assert rescaled[0].values.std() == approx(temp.values.std())
41+
assert rescaled[0].values.min() == fieldlist.values.min() - 273.15
42+
assert rescaled[0].values.std() == approx(fieldlist.values.std())
4943
# and back
5044
rescaled_back = k_to_deg.backward(rescaled)
51-
assert rescaled_back[0].values.min() == temp.values.min()
52-
assert rescaled_back[0].values.std() == approx(temp.values.std())
45+
assert rescaled_back[0].values.min() == fieldlist.values.min()
46+
assert rescaled_back[0].values.std() == approx(fieldlist.values.std())
47+
48+
49+
# used in the test below
50+
def _do_something(field, a):
51+
return field.clone(values=field.values * a)
52+
53+
54+
def test_singlefieldlambda(fieldlist):
55+
56+
fieldlist = fieldlist.sel(param="sp")
57+
58+
def undo_something(field, a):
59+
return field.clone(values=field.values / a)
60+
61+
something = EarthkitFieldLambdaFilter(
62+
fn="tests.test_filters._do_something",
63+
param="sp",
64+
args=[10],
65+
backward_fn=undo_something,
66+
)
67+
68+
transformed = something.forward(fieldlist)
69+
npt.assert_allclose(transformed[0].to_numpy(), fieldlist[0].to_numpy() * 10)
70+
71+
untransformed = something.backward(transformed)
72+
npt.assert_allclose(untransformed[0].to_numpy(), fieldlist[0].to_numpy())
5373

5474

5575
if __name__ == "__main__":
56-
test_rescale()
57-
test_convert()
76+
77+
fieldlist = ekd.from_source(
78+
"mars",
79+
{
80+
"param": ["2t", "sp"],
81+
"levtype": "sfc",
82+
"dates": ["2023-11-17 00:00:00"],
83+
},
84+
)
85+
86+
test_rescale(fieldlist)
87+
try:
88+
test_convert(fieldlist)
89+
except FileNotFoundError:
90+
print("Skipping test_convert because of missing UNIDATA UDUNITS2 library, " "required by cfunits.")
91+
test_singlefieldlambda(fieldlist)
92+
93+
print("All tests passed.")

tests/test_repeat_members.py

Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,71 @@
1+
# (C) Copyright 2024 Anemoi contributors.
2+
#
3+
# This software is licensed under the terms of the Apache Licence Version 2.0
4+
# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
5+
#
6+
# In applying this licence, ECMWF does not waive the privileges and immunities
7+
# granted to it by virtue of its status as an intergovernmental organisation
8+
# nor does it submit to any jurisdiction.
9+
10+
import earthkit.data as ekd
11+
import numpy as np
12+
13+
from anemoi.transform.filters.repeat_members import RepeatMembers
14+
15+
16+
def _get_template():
17+
temp = ekd.from_source("mars", {"param": "2t", "levtype": "sfc", "dates": ["2023-11-17 00:00:00"]})
18+
fieldlist = temp.to_fieldlist()
19+
return fieldlist, fieldlist[0].values, fieldlist[0].metadata
20+
21+
22+
def test_repeat_members_using_numbers_1():
23+
fieldlist, values, metadata = _get_template()
24+
25+
repeat = RepeatMembers(numbers=[1, 2, 3])
26+
repeated = repeat.forward(fieldlist)
27+
assert len(repeated) == 3
28+
for i, f in enumerate(repeated):
29+
assert f.values.shape == values.shape
30+
assert np.all(f.values == values)
31+
assert f.metadata("number") == i + 1
32+
assert f.metadata("name") == metadata("name")
33+
34+
35+
def test_repeat_members_using_numbers_2():
36+
fieldlist, values, metadata = _get_template()
37+
38+
repeat = RepeatMembers(numbers="1/to/3")
39+
repeated = repeat.forward(fieldlist)
40+
assert len(repeated) == 3
41+
for i, f in enumerate(repeated):
42+
assert f.values.shape == values.shape
43+
assert np.all(f.values == values)
44+
assert f.metadata("number") == i + 1
45+
assert f.metadata("name") == metadata("name")
46+
47+
48+
def test_repeat_members_using_members():
49+
fieldlist, values, metadata = _get_template()
50+
51+
repeat = RepeatMembers(members=[0, 1, 2])
52+
repeated = repeat.forward(fieldlist)
53+
assert len(repeated) == 3
54+
for i, f in enumerate(repeated):
55+
assert f.values.shape == values.shape
56+
assert np.all(f.values == values)
57+
assert f.metadata("number") == i + 1
58+
assert f.metadata("name") == metadata("name")
59+
60+
61+
def test_repeat_members_using_count():
62+
fieldlist, values, metadata = _get_template()
63+
64+
repeat = RepeatMembers(count=3)
65+
repeated = repeat.forward(fieldlist)
66+
assert len(repeated) == 3
67+
for i, f in enumerate(repeated):
68+
assert f.values.shape == values.shape
69+
assert np.all(f.values == values)
70+
assert f.metadata("number") == i + 1
71+
assert f.metadata("name") == metadata("name")

0 commit comments

Comments
 (0)