activation functions added for bounded outputs #14

Merged
35 commits merged on Sep 23, 2024
Changes from all commits

Commits (35)
da2ffcc
activation functions added for bounded outputs
gabrieloks Aug 1, 2024
2418ca4
generalised fraction normalisation
gabrieloks Aug 1, 2024
eea2c26
precommit changes
gabrieloks Aug 2, 2024
8eaf33f
precommit changes
gabrieloks Aug 2, 2024
1095042
chore: mv bounding to post-processors
JesperDramsch Aug 27, 2024
ce18cd0
feat: make bounding strategies torch Module
JesperDramsch Aug 27, 2024
2541529
refactor: pre-build indices during init
JesperDramsch Aug 28, 2024
f4fd930
refactor: nn.module in place modification
JesperDramsch Aug 28, 2024
c163dc0
refactor: mv bounding to layers
JesperDramsch Aug 28, 2024
a5c76b2
refactor: implement bounding ModuleList
JesperDramsch Aug 28, 2024
85d45ae
Merge branch 'develop' into feature/activation_function
JesperDramsch Aug 28, 2024
3fdc66d
docs: add changelog
JesperDramsch Aug 28, 2024
0cde121
refactor: mv bounding config to models
JesperDramsch Aug 28, 2024
19c8ad6
feat: enable 1-1 variable remapping in preprocessors
JesperDramsch Aug 29, 2024
703c8fc
test: create tests and fix code
JesperDramsch Aug 29, 2024
6c8ff12
test: add a test for hydra instantiating of bounding
JesperDramsch Aug 29, 2024
04eb1d0
fix: naming
JesperDramsch Aug 29, 2024
9bbb22a
refactor: reduce verboseness
JesperDramsch Aug 29, 2024
04419b2
docs: add comments
JesperDramsch Aug 29, 2024
929732d
fix: inject name_to_index on initiation
JesperDramsch Aug 29, 2024
16ade80
fixed reading of statistics remapping
gabrieloks Sep 6, 2024
24867f9
Merge branch 'develop' into feature/activation_function
gabrieloks Sep 6, 2024
6e6fa85
Merge branch 'develop' into feature/activation_function
gabrieloks Sep 6, 2024
65b142d
revert
gabrieloks Sep 6, 2024
1d68538
Merge branch 'develop' into feature/activation_function
gabrieloks Sep 10, 2024
68d7394
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Sep 10, 2024
58f1b7e
Update test_preprocessor_normalizer.py
gabrieloks Sep 10, 2024
5823c66
Update encoder_processor_decoder.py
gabrieloks Sep 11, 2024
1555b7b
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Sep 11, 2024
3d562cb
docs: added a comment on special keys
JesperDramsch Sep 23, 2024
d5616c9
Merge remote-tracking branch 'origin/develop' into feature/activation…
JesperDramsch Sep 23, 2024
7da5bbe
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Sep 23, 2024
9c28777
docs: add docstrings
JesperDramsch Sep 23, 2024
1c7a99c
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Sep 23, 2024
5180c8a
fix: changelog
JesperDramsch Sep 23, 2024
21 changes: 13 additions & 8 deletions CHANGELOG.md
@@ -10,26 +10,31 @@ Keep it human-readable, your future self will thank you!

## [Unreleased](https://github.com/ecmwf/anemoi-models/compare/0.3.0...HEAD)

## [0.3.0](https://github.com/ecmwf/anemoi-models/compare/0.2.1...0.3.0) - Remapping of (meteorological) Variables

### Added

- CI workflow to update the changelog on release
- configurability of the dropout probability in the MultiHeadSelfAttention module
- CI workflow to update the changelog on release
- Remapper: Preprocessor for remapping one variable to multiple ones. Includes changes to the data indices since the remapper changes the number of variables. With optional config keywords.
- Codeowners file
- Pygrep precommit hooks
- Docsig precommit hooks
- Changelog merge strategy
- configurability of the dropout probability in the MultiHeadSelfAttention module
- Variable Bounding as configurable model layers [#13](https://github.com/ecmwf/anemoi-models/issues/13)

### Changed
- Bugfixes for CI

### Removed

## [0.3.0](https://github.com/ecmwf/anemoi-models/compare/0.2.1...0.3.0) - Remapping of (meteorological) Variables

### Added

- CI workflow to update the changelog on release
- Remapper: Preprocessor for remapping one variable to multiple ones. Includes changes to the data indices since the remapper changes the number of variables. With optional config keywords.

### Changed

- Update CI to inherit from common infrastructure reusable workflows
- run downstream-ci only when src and tests folders have changed
- New error messages for wrong graphs.
- Bugfixes for CI

### Removed

115 changes: 115 additions & 0 deletions src/anemoi/models/layers/bounding.py
@@ -0,0 +1,115 @@
from __future__ import annotations

from abc import ABC
from abc import abstractmethod

import torch
from torch import nn

from anemoi.models.data_indices.tensor import InputTensorIndex


class BaseBounding(nn.Module, ABC):
"""Abstract base class for bounding strategies.

This class defines an interface for bounding strategies which are used to apply a specific
restriction to the predictions of a model.
"""

def __init__(
self,
*,
variables: list[str],
name_to_index: dict,
) -> None:
super().__init__()

self.name_to_index = name_to_index
self.variables = variables
self.data_index = self._create_index(variables=self.variables)

def _create_index(self, variables: list[str]) -> InputTensorIndex:
return InputTensorIndex(includes=variables, excludes=[], name_to_index=self.name_to_index)._only

@abstractmethod
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""Applies the bounding to the predictions.

Parameters
----------
x : torch.Tensor
The tensor containing the predictions that will be bounded.

Returns
-------
torch.Tensor
A tensor with the bounding applied.
"""
pass


class ReluBounding(BaseBounding):
"""Initializes the bounding with a ReLU activation / zero clamping."""

def forward(self, x: torch.Tensor) -> torch.Tensor:
x[..., self.data_index] = torch.nn.functional.relu(x[..., self.data_index])
return x


class HardtanhBounding(BaseBounding):
"""Initializes the bounding with specified minimum and maximum values for bounding.

Parameters
----------
variables : list[str]
A list of strings representing the variables that will be bounded.
name_to_index : dict
A dictionary mapping the variable names to their corresponding indices.
min_val : float
The minimum value for the HardTanh activation.
max_val : float
The maximum value for the HardTanh activation.
"""

def __init__(self, *, variables: list[str], name_to_index: dict, min_val: float, max_val: float) -> None:
super().__init__(variables=variables, name_to_index=name_to_index)
self.min_val = min_val
self.max_val = max_val

def forward(self, x: torch.Tensor) -> torch.Tensor:
x[..., self.data_index] = torch.nn.functional.hardtanh(
x[..., self.data_index], min_val=self.min_val, max_val=self.max_val
)
return x


class FractionBounding(HardtanhBounding):
"""Initializes the FractionBounding with specified parameters.

Parameters
----------
variables : list[str]
A list of strings representing the variables that will be bounded.
name_to_index : dict
A dictionary mapping the variable names to their corresponding indices.
min_val : float
The minimum value for the HardTanh activation.
max_val : float
The maximum value for the HardTanh activation.
total_var : str
A string representing a variable from which a secondary variable is derived. For
example, in the case of convective precipitation (Cp), total_var = Tp (total precipitation).
"""

def __init__(
self, *, variables: list[str], name_to_index: dict, min_val: float, max_val: float, total_var: str
) -> None:
super().__init__(variables=variables, name_to_index=name_to_index, min_val=min_val, max_val=max_val)
self.total_variable = self._create_index(variables=[total_var])

def forward(self, x: torch.Tensor) -> torch.Tensor:
# Apply the HardTanh bounding to the data_index variables
x = super().forward(x)
# Calculate the fraction of the total variable
x[..., self.data_index] *= x[..., self.total_variable]
return x
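For orientation, a minimal sketch of how these bounding layers behave when used directly (not part of the diff). The variable names and the name_to_index mapping below are illustrative; in the model they come from the output data indices.

import torch

from anemoi.models.layers.bounding import FractionBounding
from anemoi.models.layers.bounding import ReluBounding

# Illustrative variable layout; in the model this comes from data_indices.model.output.name_to_index.
name_to_index = {"tp": 0, "cp": 1, "t2m": 2}

# Clamp total precipitation to be non-negative.
relu = ReluBounding(variables=["tp"], name_to_index=name_to_index)

# Bound convective precipitation to [0, 1], then re-express it as a fraction of total precipitation.
fraction = FractionBounding(
    variables=["cp"], name_to_index=name_to_index, min_val=0.0, max_val=1.0, total_var="tp"
)

x = torch.tensor([[-0.2, 0.7, 285.0], [3.0, 1.4, 290.0]])
x = relu(x)      # tp: -0.2 -> 0.0, 3.0 unchanged
x = fraction(x)  # cp: 0.7 -> 0.7 * 0.0 = 0.0 and 1.4 -> 1.0 * 3.0 = 3.0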
15 changes: 15 additions & 0 deletions src/anemoi/models/models/encoder_processor_decoder.py
@@ -67,6 +67,8 @@ def __init__(
self._register_latlon("data", self._graph_name_data)
self._register_latlon("hidden", self._graph_name_hidden)

self.data_indices = data_indices

self.num_channels = config.model.num_channels

input_dim = self.multi_step * self.num_input_channels + self.latlons_data.shape[1] + self.trainable_data_size
@@ -103,6 +105,14 @@ def __init__(
dst_grid_size=self._data_grid_size,
)

# Instantiation of model output bounding functions (e.g., to ensure outputs like TP are positive definite)
self.boundings = nn.ModuleList(
[
instantiate(cfg, name_to_index=self.data_indices.model.output.name_to_index)
for cfg in getattr(config.model, "bounding", [])
]
)

def _calculate_shapes_and_indices(self, data_indices: dict) -> None:
self.num_input_channels = len(data_indices.internal_model.input)
self.num_output_channels = len(data_indices.internal_model.output)
@@ -251,4 +261,9 @@ def forward(self, x: Tensor, model_comm_group: Optional[ProcessGroup] = None) ->

# residual connection (just for the prognostic variables)
x_out[..., self._internal_output_idx] += x[:, -1, :, :, self._internal_input_idx]

for bounding in self.boundings:
# bounding performed in the order specified in the config file
x_out = bounding(x_out)

return x_out
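The boundings above are driven by config.model.bounding, a list of hydra-instantiable entries. A hedged sketch of such a list, written as plain Python dicts in the same spirit as the tests further down; the variable names and values are illustrative, not part of this PR.

from hydra.utils import instantiate
from torch import nn

# Illustrative config.model.bounding list; in practice these entries come from the model config.
bounding_config = [
    {"_target_": "anemoi.models.layers.bounding.ReluBounding", "variables": ["tp"]},
    {
        "_target_": "anemoi.models.layers.bounding.FractionBounding",
        "variables": ["cp"],
        "min_val": 0.0,
        "max_val": 1.0,
        "total_var": "tp",
    },
]

name_to_index = {"tp": 0, "cp": 1}  # stand-in for data_indices.model.output.name_to_index
boundings = nn.ModuleList(
    [instantiate(cfg, name_to_index=name_to_index) for cfg in bounding_config]
)

# Applied in config order, exactly as in forward():
#   for bounding in boundings:
#       x_out = bounding(x_out)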
20 changes: 19 additions & 1 deletion src/anemoi/models/preprocessing/__init__.py
@@ -38,7 +38,23 @@ def __init__(
Data indices for input and output variables
statistics : dict
Data statistics dictionary
data_indices : dict
Data indices for input and output variables

Attributes
----------
default : str
Default method for variables not specified in the config
method_config : dict
Dictionary of the methods with lists of variables
methods : dict
Dictionary of the variables with methods
data_indices : IndexCollection
Data indices for input and output variables
remap : dict
Dictionary of the variables with remapped names in the config
"""

super().__init__()

self.default, self.method_config = self._process_config(config)
@@ -47,8 +63,10 @@
self.data_indices = data_indices

def _process_config(self, config):
_special_keys = ["default", "remap"] # Keys that do not contain a list of variables in a preprocessing method.
default = config.get("default", "none")
method_config = {k: v for k, v in config.items() if k != "default" and v is not None and v != "none"}
self.remap = config.get("remap", {})
method_config = {k: v for k, v in config.items() if k not in _special_keys and v is not None and v != "none"}

if not method_config:
LOGGER.warning(
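A hedged, self-contained sketch of what the new special-key handling in _process_config does; the config values below are illustrative only.

# Illustrative normalizer config exercising the new special keys.
config = {
    "default": "mean-std",
    "min-max": ["swvl1"],
    "remap": {"cp": "tp"},  # normalise cp with tp's statistics
}

_special_keys = ["default", "remap"]  # keys that do not map a method to a list of variables
default = config.get("default", "none")
remap = config.get("remap", {})
method_config = {k: v for k, v in config.items() if k not in _special_keys and v is not None and v != "none"}

assert default == "mean-std"
assert remap == {"cp": "tp"}
assert method_config == {"min-max": ["swvl1"]}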
22 changes: 22 additions & 0 deletions src/anemoi/models/preprocessing/normalizer.py
@@ -49,20 +49,38 @@ def __init__(
mean = statistics["mean"]
stdev = statistics["stdev"]

# Optionally reuse statistic of one variable for another variable
statistics_remap = {}
for remap, source in self.remap.items():
idx_src, idx_remap = name_to_index_training_input[source], name_to_index_training_input[remap]
statistics_remap[idx_remap] = (minimum[idx_src], maximum[idx_src], mean[idx_src], stdev[idx_src])

# Two-step to avoid overwriting the original statistics in the loop (this reduces dependence on order)
for idx, new_stats in statistics_remap.items():
minimum[idx], maximum[idx], mean[idx], stdev[idx] = new_stats

self._validate_normalization_inputs(name_to_index_training_input, minimum, maximum, mean, stdev)

_norm_add = np.zeros((minimum.size,), dtype=np.float32)
_norm_mul = np.ones((minimum.size,), dtype=np.float32)

for name, i in name_to_index_training_input.items():
method = self.methods.get(name, self.default)

if method == "mean-std":
LOGGER.debug(f"Normalizing: {name} is mean-std-normalised.")
if stdev[i] < (mean[i] * 1e-6):
warnings.warn(f"Normalizing: the field seems to have only one value {mean[i]}")
_norm_mul[i] = 1 / stdev[i]
_norm_add[i] = -mean[i] / stdev[i]

elif method == "std":
LOGGER.debug(f"Normalizing: {name} is std-normalised.")
if stdev[i] < (mean[i] * 1e-6):
warnings.warn(f"Normalizing: the field seems to have only one value {mean[i]}")
_norm_mul[i] = 1 / stdev[i]
_norm_add[i] = 0

elif method == "min-max":
LOGGER.debug(f"Normalizing: {name} is min-max-normalised to [0, 1].")
x = maximum[i] - minimum[i]
@@ -92,16 +110,20 @@ def _validate_normalization_inputs(self, name_to_index_training_input: dict, min
f"Error parsing methods in InputNormalizer methods ({len(self.methods)}) "
f"and entries in config ({sum(len(v) for v in self.method_config)}) do not match."
)

# Check that all sizes align
n = minimum.size
assert maximum.size == n, (maximum.size, n)
assert mean.size == n, (mean.size, n)
assert stdev.size == n, (stdev.size, n)

# Check for typos in method config
assert isinstance(self.methods, dict)
for name, method in self.methods.items():
assert name in name_to_index_training_input, f"{name} is not a valid variable name"
assert method in [
"mean-std",
"std",
# "robust",
"min-max",
"max",
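A hedged, standalone sketch of the two-step statistics remapping added above, with made-up numbers; in the normalizer these arrays come from the dataset statistics.

import numpy as np

# Illustrative statistics for two variables; tp is the remap source, cp the target.
name_to_index_training_input = {"tp": 0, "cp": 1}
remap = {"cp": "tp"}

minimum = np.array([0.0, 0.0], dtype=np.float32)
maximum = np.array([0.10, 0.05], dtype=np.float32)
mean = np.array([1.0e-3, 5.0e-4], dtype=np.float32)
stdev = np.array([2.0e-3, 1.0e-3], dtype=np.float32)

# Step 1: collect the remapped statistics without touching the arrays yet,
# so the result does not depend on the iteration order of `remap`.
statistics_remap = {}
for remapped, source in remap.items():
    idx_src, idx_remap = name_to_index_training_input[source], name_to_index_training_input[remapped]
    statistics_remap[idx_remap] = (minimum[idx_src], maximum[idx_src], mean[idx_src], stdev[idx_src])

# Step 2: apply all remapped statistics in one pass; cp now carries tp's statistics.
for idx, new_stats in statistics_remap.items():
    minimum[idx], maximum[idx], mean[idx], stdev[idx] = new_stats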
92 changes: 92 additions & 0 deletions tests/layers/test_bounding.py
@@ -0,0 +1,92 @@
import pytest
import torch
from anemoi.utils.config import DotDict
from hydra.utils import instantiate

from anemoi.models.layers.bounding import FractionBounding
from anemoi.models.layers.bounding import HardtanhBounding
from anemoi.models.layers.bounding import ReluBounding


@pytest.fixture
def config():
return DotDict({"variables": ["var1", "var2"], "total_var": "total_var"})


@pytest.fixture
def name_to_index():
return {"var1": 0, "var2": 1, "total_var": 2}


@pytest.fixture
def input_tensor():
return torch.tensor([[-1.0, 2.0, 3.0], [4.0, -5.0, 6.0], [0.5, 0.5, 0.5]])


def test_relu_bounding(config, name_to_index, input_tensor):
bounding = ReluBounding(variables=config.variables, name_to_index=name_to_index)
output = bounding(input_tensor.clone())
expected_output = torch.tensor([[0.0, 2.0, 3.0], [4.0, 0.0, 6.0], [0.5, 0.5, 0.5]])
assert torch.equal(output, expected_output)


def test_hardtanh_bounding(config, name_to_index, input_tensor):
minimum, maximum = -1.0, 1.0
bounding = HardtanhBounding(
variables=config.variables, name_to_index=name_to_index, min_val=minimum, max_val=maximum
)
output = bounding(input_tensor.clone())
expected_output = torch.tensor([[minimum, maximum, 3.0], [maximum, minimum, 6.0], [0.5, 0.5, 0.5]])
assert torch.equal(output, expected_output)


def test_fraction_bounding(config, name_to_index, input_tensor):
bounding = FractionBounding(
variables=config.variables, name_to_index=name_to_index, min_val=0.0, max_val=1.0, total_var=config.total_var
)
output = bounding(input_tensor.clone())
expected_output = torch.tensor([[0.0, 3.0, 3.0], [6.0, 0.0, 6.0], [0.25, 0.25, 0.5]])

assert torch.equal(output, expected_output)


def test_multi_chained_bounding(config, name_to_index, input_tensor):
# Apply Relu first on the first variable only
bounding1 = ReluBounding(variables=config.variables[:-1], name_to_index=name_to_index)
expected_output = torch.tensor([[0.0, 2.0, 3.0], [4.0, -5.0, 6.0], [0.5, 0.5, 0.5]])
# Check intermediate result
assert torch.equal(bounding1(input_tensor.clone()), expected_output)
minimum, maximum = 0.5, 1.75
bounding2 = HardtanhBounding(
variables=config.variables, name_to_index=name_to_index, min_val=minimum, max_val=maximum
)
# Use full chaining on the input tensor
output = bounding2(bounding1(input_tensor.clone()))
# Data with Relu applied first and then Hardtanh
expected_output = torch.tensor([[minimum, maximum, 3.0], [maximum, minimum, 6.0], [0.5, 0.5, 0.5]])
assert torch.equal(output, expected_output)


def test_hydra_instantiate_bounding(config, name_to_index, input_tensor):
layer_definitions = [
{
"_target_": "anemoi.models.layers.bounding.ReluBounding",
"variables": config.variables,
},
{
"_target_": "anemoi.models.layers.bounding.HardtanhBounding",
"variables": config.variables,
"min_val": 0.0,
"max_val": 1.0,
},
{
"_target_": "anemoi.models.layers.bounding.FractionBounding",
"variables": config.variables,
"min_val": 0.0,
"max_val": 1.0,
"total_var": config.total_var,
},
]
for layer_definition in layer_definitions:
bounding = instantiate(layer_definition, name_to_index=name_to_index)
bounding(input_tensor.clone())