-
Notifications
You must be signed in to change notification settings - Fork 11
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
activation functions added for bounded outputs (#14)
* activation functions added for bounded outputs * generalised fraction normalisation * precommit changes * precommit changes * chore: mv bounding to post-processors * feat: make bounding strategies torch Module * refactor: pre-build indices during init * refactor: nn.module in place modification * refactor: mv bounding to layers * refactor: implement bounding ModuleList * docs: add changelog * refactor: mv bounding config to models * feat: enable 1-1 variable remapping in preprocessors * test: create tests and fix code * test: add a test for hydra instantiating of bounding * fix: naming * refactor: reduce verboseness * docs: add comments * fix: inject name_to_index on initiation * fixed reading of statistics remapping * revert * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Update test_preprocessor_normalizer.py * Update encoder_processor_decoder.py * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * docs: added a comment on special keys * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * docs: add docstrings * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix: changelog --------- Co-authored-by: Jesper Dramsch <jesper.dramsch@ecmwf.int>
- Loading branch information
1 parent
20e5df2
commit 2b1d2ce
Showing
7 changed files
with
319 additions
and
9 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,115 @@ | ||
from __future__ import annotations | ||
|
||
from abc import ABC | ||
from abc import abstractmethod | ||
|
||
import torch | ||
from torch import nn | ||
|
||
from anemoi.models.data_indices.tensor import InputTensorIndex | ||
|
||
|
||
class BaseBounding(nn.Module, ABC):
    """Common interface for output-bounding strategies.

    Subclasses implement :meth:`forward`, which restricts a selected subset
    of the model's predicted variables (e.g. clamping them to a range).

    Parameters
    ----------
    variables : list[str]
        Names of the variables the bounding is applied to.
    name_to_index : dict
        Mapping from variable name to its position in the last tensor
        dimension.
    """

    def __init__(
        self,
        *,
        variables: list[str],
        name_to_index: dict,
    ) -> None:
        super().__init__()

        self.variables = variables
        self.name_to_index = name_to_index
        # Resolve the variable names to tensor indices once, at construction
        # time, so forward() only needs to do fancy indexing.
        self.data_index = self._create_index(variables=self.variables)

    def _create_index(self, variables: list[str]) -> InputTensorIndex:
        # Delegate name -> index resolution to the shared tensor-index helper.
        selection = InputTensorIndex(includes=variables, excludes=[], name_to_index=self.name_to_index)
        return selection._only

    @abstractmethod
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the bounding to the predictions.

        Parameters
        ----------
        x : torch.Tensor
            Tensor containing the predictions to be bounded.

        Returns
        -------
        torch.Tensor
            A tensor with the bounding applied.
        """
        ...
|
||
|
||
class ReluBounding(BaseBounding):
    """Bounding that clamps the selected variables to be non-negative (ReLU)."""

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Only the bounded variables are rewritten; all others pass through.
        clamped = torch.nn.functional.relu(x[..., self.data_index])
        x[..., self.data_index] = clamped
        return x
|
||
|
||
class HardtanhBounding(BaseBounding):
    """Bounding that clamps the selected variables to ``[min_val, max_val]``.

    Parameters
    ----------
    variables : list[str]
        A list of strings representing the variables that will be bounded.
    name_to_index : dict
        A dictionary mapping the variable names to their corresponding indices.
    min_val : float
        The minimum value for the HardTanh activation.
    max_val : float
        The maximum value for the HardTanh activation.
    """

    def __init__(self, *, variables: list[str], name_to_index: dict, min_val: float, max_val: float) -> None:
        super().__init__(variables=variables, name_to_index=name_to_index)
        self.min_val = min_val
        self.max_val = max_val

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        selected = x[..., self.data_index]
        # hardtanh == element-wise clamp to the configured closed interval.
        x[..., self.data_index] = torch.nn.functional.hardtanh(selected, min_val=self.min_val, max_val=self.max_val)
        return x
|
||
|
||
class FractionBounding(HardtanhBounding):
    """Bounding that expresses variables as a fraction of a total variable.

    The selected variables are first clamped to ``[min_val, max_val]`` and
    then multiplied by the value of ``total_var``, turning them into that
    fraction of the total.

    Parameters
    ----------
    variables : list[str]
        A list of strings representing the variables that will be bounded.
    name_to_index : dict
        A dictionary mapping the variable names to their corresponding indices.
    min_val : float
        The minimum value for the HardTanh activation.
    max_val : float
        The maximum value for the HardTanh activation.
    total_var : str
        A string representing a variable from which a secondary variable is
        derived. For example, in the case of convective precipitation (Cp),
        total_var = Tp (total precipitation).
    """

    def __init__(
        self, *, variables: list[str], name_to_index: dict, min_val: float, max_val: float, total_var: str
    ) -> None:
        super().__init__(variables=variables, name_to_index=name_to_index, min_val=min_val, max_val=max_val)
        # Index of the reference ("total") variable, resolved once up front.
        self.total_variable = self._create_index(variables=[total_var])

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Step 1: clamp the fraction variables via the parent HardTanh bounding.
        x = super().forward(x)
        # Step 2: rescale them by the total variable's value.
        x[..., self.data_index] = x[..., self.data_index] * x[..., self.total_variable]
        return x
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,92 @@ | ||
import pytest | ||
import torch | ||
from anemoi.utils.config import DotDict | ||
from hydra.utils import instantiate | ||
|
||
from anemoi.models.layers.bounding import FractionBounding | ||
from anemoi.models.layers.bounding import HardtanhBounding | ||
from anemoi.models.layers.bounding import ReluBounding | ||
|
||
|
||
@pytest.fixture
def config():
    """Minimal bounding config: two bounded variables plus a total variable."""
    settings = {"variables": ["var1", "var2"], "total_var": "total_var"}
    return DotDict(settings)
|
||
|
||
@pytest.fixture
def name_to_index():
    """Map each test variable to its column in the input tensor."""
    return dict(var1=0, var2=1, total_var=2)
|
||
|
||
@pytest.fixture
def input_tensor():
    """3x3 sample batch; columns correspond to var1, var2 and total_var."""
    rows = [
        [-1.0, 2.0, 3.0],
        [4.0, -5.0, 6.0],
        [0.5, 0.5, 0.5],
    ]
    return torch.tensor(rows)
|
||
|
||
def test_relu_bounding(config, name_to_index, input_tensor):
    """Negative entries of the bounded variables are clamped to zero."""
    layer = ReluBounding(variables=config.variables, name_to_index=name_to_index)
    result = layer(input_tensor.clone())
    expected = torch.tensor([[0.0, 2.0, 3.0], [4.0, 0.0, 6.0], [0.5, 0.5, 0.5]])
    assert torch.equal(result, expected)
|
||
|
||
def test_hardtanh_bounding(config, name_to_index, input_tensor):
    """Bounded variables are clamped into [minimum, maximum]; the rest untouched."""
    minimum, maximum = -1.0, 1.0
    layer = HardtanhBounding(
        variables=config.variables, name_to_index=name_to_index, min_val=minimum, max_val=maximum
    )
    result = layer(input_tensor.clone())
    expected = torch.tensor([[minimum, maximum, 3.0], [maximum, minimum, 6.0], [0.5, 0.5, 0.5]])
    assert torch.equal(result, expected)
|
||
|
||
def test_fraction_bounding(config, name_to_index, input_tensor):
    """Variables are clamped to [0, 1] and then scaled by the total variable."""
    layer = FractionBounding(
        variables=config.variables, name_to_index=name_to_index, min_val=0.0, max_val=1.0, total_var=config.total_var
    )
    result = layer(input_tensor.clone())
    expected = torch.tensor([[0.0, 3.0, 3.0], [6.0, 0.0, 6.0], [0.25, 0.25, 0.5]])

    assert torch.equal(result, expected)
|
||
|
||
def test_multi_chained_bounding(config, name_to_index, input_tensor):
    """Chaining two boundings applies them in sequence on the same tensor."""
    # Apply ReLU first, on the first variable only.
    relu_layer = ReluBounding(variables=config.variables[:-1], name_to_index=name_to_index)
    after_relu = torch.tensor([[0.0, 2.0, 3.0], [4.0, -5.0, 6.0], [0.5, 0.5, 0.5]])
    # Check the intermediate result of the first stage.
    assert torch.equal(relu_layer(input_tensor.clone()), after_relu)

    minimum, maximum = 0.5, 1.75
    hardtanh_layer = HardtanhBounding(
        variables=config.variables, name_to_index=name_to_index, min_val=minimum, max_val=maximum
    )
    # Full chain on the input tensor: ReLU, then HardTanh.
    result = hardtanh_layer(relu_layer(input_tensor.clone()))
    expected = torch.tensor([[minimum, maximum, 3.0], [maximum, minimum, 6.0], [0.5, 0.5, 0.5]])
    assert torch.equal(result, expected)
|
||
|
||
def test_hydra_instantiate_bounding(config, name_to_index, input_tensor):
    """Each bounding layer can be built from a hydra config and run end to end."""
    relu_cfg = {
        "_target_": "anemoi.models.layers.bounding.ReluBounding",
        "variables": config.variables,
    }
    hardtanh_cfg = {
        "_target_": "anemoi.models.layers.bounding.HardtanhBounding",
        "variables": config.variables,
        "min_val": 0.0,
        "max_val": 1.0,
    }
    fraction_cfg = {
        "_target_": "anemoi.models.layers.bounding.FractionBounding",
        "variables": config.variables,
        "min_val": 0.0,
        "max_val": 1.0,
        "total_var": config.total_var,
    }
    for layer_definition in (relu_cfg, hardtanh_cfg, fraction_cfg):
        # name_to_index is injected at instantiation time, mirroring model setup.
        bounding = instantiate(layer_definition, name_to_index=name_to_index)
        bounding(input_tensor.clone())
Oops, something went wrong.