activation functions added for bounded outputs (#14)
* activation functions added for bounded outputs

* generalised fraction normalisation

* precommit changes

* precommit changes

* chore: mv bounding to post-processors

* feat: make bounding strategies torch Module

* refactor: pre-build indices during init

* refactor: nn.module in place modification

* refactor: mv bounding to layers

* refactor: implement bounding ModuleList

* docs: add changelog

* refactor: mv bounding config to models

* feat: enable 1-1 variable remapping in preprocessors

* test: create tests and fix code

* test: add a test for hydra instantiation of bounding

* fix: naming

* refactor: reduce verboseness

* docs: add comments

* fix: inject name_to_index on initialization

* fixed reading of statistics remapping

* revert

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Update test_preprocessor_normalizer.py

* Update encoder_processor_decoder.py

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* docs: added a comment on special keys

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* docs: add docstrings

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix: changelog

---------

Co-authored-by: Jesper Dramsch <jesper.dramsch@ecmwf.int>
2 people authored and theissenhelen committed Sep 27, 2024
1 parent 20e5df2 commit 2b1d2ce
Showing 7 changed files with 319 additions and 9 deletions.
21 changes: 13 additions & 8 deletions CHANGELOG.md
@@ -10,26 +10,31 @@ Keep it human-readable, your future self will thank you!

## [Unreleased](https://github.com/ecmwf/anemoi-models/compare/0.3.0...HEAD)

## [0.3.0](https://github.com/ecmwf/anemoi-models/compare/0.2.1...0.3.0) - Remapping of (meteorological) Variables

### Added

- CI workflow to update the changelog on release
- configurability of the dropout probability in the MultiHeadSelfAttention module
- CI workflow to update the changelog on release
- Remapper: Preprocessor for remapping one variable to multiple ones. Includes changes to the data indices since the remapper changes the number of variables, with optional config keywords.
- Codeowners file
- Pygrep precommit hooks
- Docsig precommit hooks
- Changelog merge strategy
- configurability of the dropout probability in the MultiHeadSelfAttention module
- Variable Bounding as configurable model layers [#13](https://github.com/ecmwf/anemoi-models/issues/13)

### Changed
- Bugfixes for CI

### Removed

## [0.3.0](https://github.com/ecmwf/anemoi-models/compare/0.2.1...0.3.0) - Remapping of (meteorological) Variables

### Added

- CI workflow to update the changelog on release
- Remapper: Preprocessor for remapping one variable to multiple ones. Includes changes to the data indices since the remapper changes the number of variables, with optional config keywords.

### Changed

- Update CI to inherit from common infrastructure reusable workflows
- Run downstream-ci only when src and tests folders have changed
- New error messages for wrong graphs
- Bugfixes for CI

### Removed

115 changes: 115 additions & 0 deletions src/anemoi/models/layers/bounding.py
@@ -0,0 +1,115 @@
from __future__ import annotations

from abc import ABC
from abc import abstractmethod

import torch
from torch import nn

from anemoi.models.data_indices.tensor import InputTensorIndex


class BaseBounding(nn.Module, ABC):
"""Abstract base class for bounding strategies.
This class defines an interface for bounding strategies which are used to apply a specific
restriction to the predictions of a model.
"""

def __init__(
self,
*,
variables: list[str],
name_to_index: dict,
) -> None:
super().__init__()

self.name_to_index = name_to_index
self.variables = variables
self.data_index = self._create_index(variables=self.variables)

def _create_index(self, variables: list[str]) -> InputTensorIndex:
return InputTensorIndex(includes=variables, excludes=[], name_to_index=self.name_to_index)._only

@abstractmethod
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""Applies the bounding to the predictions.
Parameters
----------
x : torch.Tensor
The tensor containing the predictions that will be bounded.
Returns
-------
torch.Tensor
A tensor with the bounding applied.
"""
pass


class ReluBounding(BaseBounding):
"""Initializes the bounding with a ReLU activation / zero clamping."""

def forward(self, x: torch.Tensor) -> torch.Tensor:
x[..., self.data_index] = torch.nn.functional.relu(x[..., self.data_index])
return x


class HardtanhBounding(BaseBounding):
"""Initializes the bounding with specified minimum and maximum values for bounding.
Parameters
----------
variables : list[str]
A list of strings representing the variables that will be bounded.
name_to_index : dict
A dictionary mapping the variable names to their corresponding indices.
min_val : float
The minimum value for the HardTanh activation.
max_val : float
The maximum value for the HardTanh activation.
"""

def __init__(self, *, variables: list[str], name_to_index: dict, min_val: float, max_val: float) -> None:
super().__init__(variables=variables, name_to_index=name_to_index)
self.min_val = min_val
self.max_val = max_val

def forward(self, x: torch.Tensor) -> torch.Tensor:
x[..., self.data_index] = torch.nn.functional.hardtanh(
x[..., self.data_index], min_val=self.min_val, max_val=self.max_val
)
return x


class FractionBounding(HardtanhBounding):
"""Initializes the FractionBounding with specified parameters.
Parameters
----------
variables : list[str]
A list of strings representing the variables that will be bounded.
name_to_index : dict
A dictionary mapping the variable names to their corresponding indices.
min_val : float
The minimum value for the HardTanh activation.
max_val : float
The maximum value for the HardTanh activation.
total_var : str
The name of the total variable by which the bounded fraction is scaled. For
example, convective precipitation (cp) is expressed as a fraction of total precipitation (tp).
"""

def __init__(
self, *, variables: list[str], name_to_index: dict, min_val: float, max_val: float, total_var: str
) -> None:
super().__init__(variables=variables, name_to_index=name_to_index, min_val=min_val, max_val=max_val)
self.total_variable = self._create_index(variables=[total_var])

def forward(self, x: torch.Tensor) -> torch.Tensor:
# Apply the HardTanh bounding to the data_index variables
x = super().forward(x)
# Calculate the fraction of the total variable
x[..., self.data_index] *= x[..., self.total_variable]
return x
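As a quick illustration (not part of this commit; the variable names `cp` and `tp` and their index layout are hypothetical), a minimal usage sketch of `FractionBounding`: the bounded variable is clamped to [0, 1] and then rescaled by the total variable.

```python
import torch

from anemoi.models.layers.bounding import FractionBounding

# Hypothetical output layout: index 0 is convective precipitation (cp),
# index 1 is total precipitation (tp).
name_to_index = {"cp": 0, "tp": 1}
bounding = FractionBounding(
    variables=["cp"], name_to_index=name_to_index, min_val=0.0, max_val=1.0, total_var="tp"
)

x = torch.tensor([[1.5, 2.0]])  # raw predictions: cp=1.5, tp=2.0
y = bounding(x)                 # cp is clamped to 1.0, then multiplied by tp (in place)
assert torch.equal(y, torch.tensor([[2.0, 2.0]]))
```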
15 changes: 15 additions & 0 deletions src/anemoi/models/models/encoder_processor_decoder.py
@@ -67,6 +67,8 @@ def __init__(
self._register_latlon("data", self._graph_name_data)
self._register_latlon("hidden", self._graph_name_hidden)

self.data_indices = data_indices

self.num_channels = config.model.num_channels

input_dim = self.multi_step * self.num_input_channels + self.latlons_data.shape[1] + self.trainable_data_size
@@ -103,6 +105,14 @@ def __init__(
dst_grid_size=self._data_grid_size,
)

# Instantiation of model output bounding functions (e.g., to ensure outputs like TP are non-negative)
self.boundings = nn.ModuleList(
[
instantiate(cfg, name_to_index=self.data_indices.model.output.name_to_index)
for cfg in getattr(config.model, "bounding", [])
]
)

def _calculate_shapes_and_indices(self, data_indices: dict) -> None:
self.num_input_channels = len(data_indices.internal_model.input)
self.num_output_channels = len(data_indices.internal_model.output)
@@ -251,4 +261,9 @@ def forward(self, x: Tensor, model_comm_group: Optional[ProcessGroup] = None) ->

# residual connection (just for the prognostic variables)
x_out[..., self._internal_output_idx] += x[:, -1, :, :, self._internal_input_idx]

for bounding in self.boundings:
# Bounding is performed in the order specified in the config file
x_out = bounding(x_out)

return x_out
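For context, a sketch of how the `model.bounding` config section consumed above could look and how it becomes the ModuleList; this mirrors the hydra instantiation in the tests below, but the OmegaConf wrapping and the variable names are illustrative assumptions, not part of this commit.

```python
from hydra.utils import instantiate
from omegaconf import OmegaConf
from torch import nn

config = OmegaConf.create(
    {
        "model": {
            "bounding": [
                # Clamp total precipitation (tp) to be non-negative.
                {"_target_": "anemoi.models.layers.bounding.ReluBounding", "variables": ["tp"]},
                # Express convective precipitation (cp) as a fraction of tp.
                {
                    "_target_": "anemoi.models.layers.bounding.FractionBounding",
                    "variables": ["cp"],
                    "min_val": 0.0,
                    "max_val": 1.0,
                    "total_var": "tp",
                },
            ]
        }
    }
)

name_to_index = {"tp": 0, "cp": 1}  # hypothetical output-variable indices
boundings = nn.ModuleList(
    [instantiate(cfg, name_to_index=name_to_index) for cfg in config.model.bounding]
)
```

Because the boundings are applied in list order, placing `FractionBounding` after `ReluBounding` guarantees the fraction is taken of an already non-negative total.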
20 changes: 19 additions & 1 deletion src/anemoi/models/preprocessing/__init__.py
@@ -38,7 +38,23 @@ def __init__(
Data indices for input and output variables
statistics : dict
Data statistics dictionary
data_indices : dict
Data indices for input and output variables

Attributes
----------
default : str
Default method for variables not specified in the config
method_config : dict
Dictionary of the methods with lists of variables
methods : dict
Dictionary of the variables with methods
data_indices : IndexCollection
Data indices for input and output variables
remap : dict
Dictionary of the variables with remapped names in the config
"""

super().__init__()

self.default, self.method_config = self._process_config(config)
@@ -47,8 +63,10 @@ def __init__(
self.data_indices = data_indices

def _process_config(self, config):
_special_keys = ["default", "remap"] # Keys that do not contain a list of variables in a preprocessing method.
default = config.get("default", "none")
method_config = {k: v for k, v in config.items() if k != "default" and v is not None and v != "none"}
self.remap = config.get("remap", {})
method_config = {k: v for k, v in config.items() if k not in _special_keys and v is not None and v != "none"}

if not method_config:
LOGGER.warning(
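A small standalone sketch (the config keys and variable names are illustrative assumptions) of how `_process_config` splits the special keys out of the per-method variable lists:

```python
# Hypothetical preprocessor config: "default" and "remap" are special keys;
# every other key maps a method to a list of variables.
config = {
    "default": "none",
    "remap": {"cp": "tp"},       # cp reuses the statistics of tp
    "mean-std": ["t2m", "10u"],  # variables normalised with mean-std
}

_special_keys = ["default", "remap"]
default = config.get("default", "none")
remap = config.get("remap", {})
method_config = {k: v for k, v in config.items() if k not in _special_keys and v is not None and v != "none"}

assert default == "none" and remap == {"cp": "tp"}
assert method_config == {"mean-std": ["t2m", "10u"]}
```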
22 changes: 22 additions & 0 deletions src/anemoi/models/preprocessing/normalizer.py
@@ -49,20 +49,38 @@ def __init__(
mean = statistics["mean"]
stdev = statistics["stdev"]

# Optionally reuse the statistics of one variable for another variable
statistics_remap = {}
for remap, source in self.remap.items():
idx_src, idx_remap = name_to_index_training_input[source], name_to_index_training_input[remap]
statistics_remap[idx_remap] = (minimum[idx_src], maximum[idx_src], mean[idx_src], stdev[idx_src])

# Apply in a second step to avoid overwriting the original statistics in the loop (removing dependence on iteration order)
for idx, new_stats in statistics_remap.items():
minimum[idx], maximum[idx], mean[idx], stdev[idx] = new_stats

self._validate_normalization_inputs(name_to_index_training_input, minimum, maximum, mean, stdev)

_norm_add = np.zeros((minimum.size,), dtype=np.float32)
_norm_mul = np.ones((minimum.size,), dtype=np.float32)

for name, i in name_to_index_training_input.items():
method = self.methods.get(name, self.default)

if method == "mean-std":
LOGGER.debug(f"Normalizing: {name} is mean-std-normalised.")
if stdev[i] < (mean[i] * 1e-6):
warnings.warn(f"Normalizing: the field seems to have only one value {mean[i]}")
_norm_mul[i] = 1 / stdev[i]
_norm_add[i] = -mean[i] / stdev[i]

elif method == "std":
LOGGER.debug(f"Normalizing: {name} is std-normalised.")
if stdev[i] < (mean[i] * 1e-6):
warnings.warn(f"Normalizing: the field seems to have only one value {mean[i]}")
_norm_mul[i] = 1 / stdev[i]
_norm_add[i] = 0

elif method == "min-max":
LOGGER.debug(f"Normalizing: {name} is min-max-normalised to [0, 1].")
x = maximum[i] - minimum[i]
@@ -92,16 +110,20 @@ def _validate_normalization_inputs(self, name_to_index_training_input: dict, min
f"Error parsing methods in InputNormalizer methods ({len(self.methods)}) "
f"and entries in config ({sum(len(v) for v in self.method_config)}) do not match."
)

# Check that all sizes align
n = minimum.size
assert maximum.size == n, (maximum.size, n)
assert mean.size == n, (mean.size, n)
assert stdev.size == n, (stdev.size, n)

# Check for typos in method config
assert isinstance(self.methods, dict)
for name, method in self.methods.items():
assert name in name_to_index_training_input, f"{name} is not a valid variable name"
assert method in [
"mean-std",
"std",
# "robust",
"min-max",
"max",
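As a numeric sanity check of the mean-std factors computed above (standalone, with made-up statistics): `x * _norm_mul + _norm_add` reproduces `(x - mean) / stdev`.

```python
import numpy as np

mean, stdev = np.float32(2.0), np.float32(4.0)
_norm_mul = np.float32(1.0) / stdev  # 0.25
_norm_add = -mean / stdev            # -0.5

x = np.float32(10.0)
assert np.isclose(x * _norm_mul + _norm_add, (x - mean) / stdev)  # both equal 2.0
```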
92 changes: 92 additions & 0 deletions tests/layers/test_bounding.py
@@ -0,0 +1,92 @@
import pytest
import torch
from anemoi.utils.config import DotDict
from hydra.utils import instantiate

from anemoi.models.layers.bounding import FractionBounding
from anemoi.models.layers.bounding import HardtanhBounding
from anemoi.models.layers.bounding import ReluBounding


@pytest.fixture
def config():
return DotDict({"variables": ["var1", "var2"], "total_var": "total_var"})


@pytest.fixture
def name_to_index():
return {"var1": 0, "var2": 1, "total_var": 2}


@pytest.fixture
def input_tensor():
return torch.tensor([[-1.0, 2.0, 3.0], [4.0, -5.0, 6.0], [0.5, 0.5, 0.5]])


def test_relu_bounding(config, name_to_index, input_tensor):
bounding = ReluBounding(variables=config.variables, name_to_index=name_to_index)
output = bounding(input_tensor.clone())
expected_output = torch.tensor([[0.0, 2.0, 3.0], [4.0, 0.0, 6.0], [0.5, 0.5, 0.5]])
assert torch.equal(output, expected_output)


def test_hardtanh_bounding(config, name_to_index, input_tensor):
minimum, maximum = -1.0, 1.0
bounding = HardtanhBounding(
variables=config.variables, name_to_index=name_to_index, min_val=minimum, max_val=maximum
)
output = bounding(input_tensor.clone())
expected_output = torch.tensor([[minimum, maximum, 3.0], [maximum, minimum, 6.0], [0.5, 0.5, 0.5]])
assert torch.equal(output, expected_output)


def test_fraction_bounding(config, name_to_index, input_tensor):
bounding = FractionBounding(
variables=config.variables, name_to_index=name_to_index, min_val=0.0, max_val=1.0, total_var=config.total_var
)
output = bounding(input_tensor.clone())
expected_output = torch.tensor([[0.0, 3.0, 3.0], [6.0, 0.0, 6.0], [0.25, 0.25, 0.5]])

assert torch.equal(output, expected_output)


def test_multi_chained_bounding(config, name_to_index, input_tensor):
# Apply Relu first on the first variable only
bounding1 = ReluBounding(variables=config.variables[:-1], name_to_index=name_to_index)
expected_output = torch.tensor([[0.0, 2.0, 3.0], [4.0, -5.0, 6.0], [0.5, 0.5, 0.5]])
# Check intermediate result
assert torch.equal(bounding1(input_tensor.clone()), expected_output)
minimum, maximum = 0.5, 1.75
bounding2 = HardtanhBounding(
variables=config.variables, name_to_index=name_to_index, min_val=minimum, max_val=maximum
)
# Use full chaining on the input tensor
output = bounding2(bounding1(input_tensor.clone()))
# Data with Relu applied first and then Hardtanh
expected_output = torch.tensor([[minimum, maximum, 3.0], [maximum, minimum, 6.0], [0.5, 0.5, 0.5]])
assert torch.equal(output, expected_output)


def test_hydra_instantiate_bounding(config, name_to_index, input_tensor):
layer_definitions = [
{
"_target_": "anemoi.models.layers.bounding.ReluBounding",
"variables": config.variables,
},
{
"_target_": "anemoi.models.layers.bounding.HardtanhBounding",
"variables": config.variables,
"min_val": 0.0,
"max_val": 1.0,
},
{
"_target_": "anemoi.models.layers.bounding.FractionBounding",
"variables": config.variables,
"min_val": 0.0,
"max_val": 1.0,
"total_var": config.total_var,
},
]
for layer_definition in layer_definitions:
bounding = instantiate(layer_definition, name_to_index=name_to_index)
bounding(input_tensor.clone())