Skip to content

Commit 2d548c9

Browse files
authored
Merge branch 'ecmwf:develop' into feature/lambda-filter
2 parents dfd8030 + 7b91806 commit 2d548c9

File tree

3 files changed

+161
-1
lines changed

3 files changed

+161
-1
lines changed

CHANGELOG.md

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,11 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
88
Please add your functional changes to the appropriate section in the PR.
99
Keep it human-readable, your future self will thank you!
1010

11-
## [Unreleased](https://github.com/ecmwf/anemoi-utils/transform/0.0.5...HEAD/compare/0.0.8...HEAD)
11+
## [Unreleased](https://github.com/ecmwf/anemoi-utils/transform/0.0.5...HEAD/compare/0.1.0...HEAD)
12+
13+
- Added repeat-member #18
14+
15+
## [0.1.0](https://github.com/ecmwf/anemoi-utils/transform/0.0.5...HEAD/compare/0.0.8...0.1.0) - 2024-11-18
1216

1317
## [0.0.8](https://github.com/ecmwf/anemoi-utils/transform/0.0.5...HEAD/compare/0.0.5...0.0.8) - 2024-10-26
1418

Lines changed: 85 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,85 @@
1+
# (C) Copyright 2024 Anemoi contributors.
2+
#
3+
# This software is licensed under the terms of the Apache Licence Version 2.0
4+
# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
5+
#
6+
# In applying this licence, ECMWF does not waive the privileges and immunities
7+
# granted to it by virtue of its status as an intergovernmental organisation
8+
# nor does it submit to any jurisdiction.
9+
10+
11+
import logging
12+
13+
from ..fields import new_field_from_numpy
14+
from ..fields import new_fieldlist_from_list
15+
from . import filter_registry
16+
from .base import Filter
17+
18+
LOG = logging.getLogger(__name__)
19+
20+
21+
def make_list_int(value):
22+
if isinstance(value, str):
23+
if "/" not in value:
24+
return [value]
25+
bits = value.split("/")
26+
if len(bits) == 3 and bits[1].lower() == "to":
27+
value = list(range(int(bits[0]), int(bits[2]) + 1, 1))
28+
29+
elif len(bits) == 5 and bits[1].lower() == "to" and bits[3].lower() == "by":
30+
value = list(range(int(bits[0]), int(bits[2]) + int(bits[4]), int(bits[4])))
31+
32+
if isinstance(value, list):
33+
return value
34+
if isinstance(value, tuple):
35+
return value
36+
if isinstance(value, int):
37+
return [value]
38+
39+
raise ValueError(f"Cannot make list from {value}")
40+
41+
42+
@filter_registry.register("repeat_members")
43+
class RepeatMembers(Filter):
44+
"""The filter can be used to replicate non-ensembles fields into ensemble fields.
45+
46+
Args: (only one of the following)
47+
numbers: A list of numbers (1-based) of the fields to replicate.
48+
members: A list of 0-based indices of the fields to replicate.
49+
count: The number of times to replicate the fields.
50+
"""
51+
52+
def __init__(
53+
self,
54+
numbers=None, # 1-based
55+
members=None, # 0-based
56+
count=None,
57+
):
58+
if sum(x is not None for x in (members, count, numbers)) != 1:
59+
raise ValueError("Exactly one of members, count or numbers must be given")
60+
61+
if numbers is not None:
62+
numbers = make_list_int(numbers)
63+
members = [n - 1 for n in numbers]
64+
65+
if count is not None:
66+
members = list(range(count))
67+
68+
members = make_list_int(members)
69+
self.members = members
70+
assert isinstance(members, (tuple, list)), f"members must be a list or tuple, got {type(members)}"
71+
72+
def forward(self, data):
73+
result = []
74+
for f in data:
75+
array = f.to_numpy()
76+
for member in self.members:
77+
number = member + 1
78+
new_field = new_field_from_numpy(array, template=f, number=number)
79+
result.append(new_field)
80+
81+
return new_fieldlist_from_list(result)
82+
83+
def backward(self, data):
84+
# this could be implemented
85+
raise NotImplementedError("RepeatMembers is not reversible")

tests/test_repeat_members.py

Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,71 @@
1+
# (C) Copyright 2024 Anemoi contributors.
2+
#
3+
# This software is licensed under the terms of the Apache Licence Version 2.0
4+
# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
5+
#
6+
# In applying this licence, ECMWF does not waive the privileges and immunities
7+
# granted to it by virtue of its status as an intergovernmental organisation
8+
# nor does it submit to any jurisdiction.
9+
10+
import earthkit.data as ekd
11+
import numpy as np
12+
13+
from anemoi.transform.filters.repeat_members import RepeatMembers
14+
15+
16+
def _get_template():
17+
temp = ekd.from_source("mars", {"param": "2t", "levtype": "sfc", "dates": ["2023-11-17 00:00:00"]})
18+
fieldlist = temp.to_fieldlist()
19+
return fieldlist, fieldlist[0].values, fieldlist[0].metadata
20+
21+
22+
def test_repeat_members_using_numbers_1():
23+
fieldlist, values, metadata = _get_template()
24+
25+
repeat = RepeatMembers(numbers=[1, 2, 3])
26+
repeated = repeat.forward(fieldlist)
27+
assert len(repeated) == 3
28+
for i, f in enumerate(repeated):
29+
assert f.values.shape == values.shape
30+
assert np.all(f.values == values)
31+
assert f.metadata("number") == i + 1
32+
assert f.metadata("name") == metadata("name")
33+
34+
35+
def test_repeat_members_using_numbers_2():
36+
fieldlist, values, metadata = _get_template()
37+
38+
repeat = RepeatMembers(numbers="1/to/3")
39+
repeated = repeat.forward(fieldlist)
40+
assert len(repeated) == 3
41+
for i, f in enumerate(repeated):
42+
assert f.values.shape == values.shape
43+
assert np.all(f.values == values)
44+
assert f.metadata("number") == i + 1
45+
assert f.metadata("name") == metadata("name")
46+
47+
48+
def test_repeat_members_using_members():
49+
fieldlist, values, metadata = _get_template()
50+
51+
repeat = RepeatMembers(members=[0, 1, 2])
52+
repeated = repeat.forward(fieldlist)
53+
assert len(repeated) == 3
54+
for i, f in enumerate(repeated):
55+
assert f.values.shape == values.shape
56+
assert np.all(f.values == values)
57+
assert f.metadata("number") == i + 1
58+
assert f.metadata("name") == metadata("name")
59+
60+
61+
def test_repeat_members_using_count():
62+
fieldlist, values, metadata = _get_template()
63+
64+
repeat = RepeatMembers(count=3)
65+
repeated = repeat.forward(fieldlist)
66+
assert len(repeated) == 3
67+
for i, f in enumerate(repeated):
68+
assert f.values.shape == values.shape
69+
assert np.all(f.values == values)
70+
assert f.metadata("number") == i + 1
71+
assert f.metadata("name") == metadata("name")

0 commit comments

Comments
 (0)