Skip to content

Commit

Permalink
implement general purpose mask filter (#20)
Browse files Browse the repository at this point in the history
* add apply_mask filter
  • Loading branch information
NRaoult authored Dec 20, 2024
1 parent 36d5ca3 commit 7cbf5f3
Showing 1 changed file with 60 additions and 0 deletions.
60 changes: 60 additions & 0 deletions src/anemoi/transform/filters/apply_mask.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
# (C) Copyright 2024 Anemoi contributors.
#
# This software is licensed under the terms of the Apache Licence Version 2.0
# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
#
# In applying this licence, ECMWF does not waive the privileges and immunities
# granted to it by virtue of its status as an intergovernmental organisation
# nor does it submit to any jurisdiction.

import earthkit.data as ekd
import numpy as np

from ..fields import new_field_from_numpy
from ..fields import new_fieldlist_from_list
from ..filter import Filter
from . import filter_registry


@filter_registry.register("apply_mask")
class MaskVariable(Filter):
"""A filter to mask variables using external file."""

def __init__(
self,
*,
path,
mask_value=1,
threshold=None,
rename=None,
):

mask = ekd.from_source("file", path)[0].to_numpy().astype(bool)

if threshold is not None:
self._mask = mask > threshold
else:
self._mask = mask == mask_value

self._rename = rename

def forward(self, data):

result = []
extra = {}
for field in data:

values = field.to_numpy(flatten=True)
values[self._mask] = np.nan

if self._rename is not None:
param = field.metadata("param")
name = f"{param}_{self._rename}"
extra["param"] = name

result.append(new_field_from_numpy(values, template=field, **extra))

return new_fieldlist_from_list(result)

def backward(self, data):
raise NotImplementedError("`apply_mask` is not reversible")

0 comments on commit 7cbf5f3

Please sign in to comment.