diff --git a/src/anemoi/graphs/__init__.py b/src/anemoi/graphs/__init__.py index 80d19fb..715b8a4 100644 --- a/src/anemoi/graphs/__init__.py +++ b/src/anemoi/graphs/__init__.py @@ -7,4 +7,4 @@ from ._version import __version__ -earth_radius = 6371.0 # km +EARTH_RADIUS = 6371.0 # km diff --git a/src/anemoi/graphs/create.py b/src/anemoi/graphs/create.py index d2b803d..da1692e 100644 --- a/src/anemoi/graphs/create.py +++ b/src/anemoi/graphs/create.py @@ -10,13 +10,25 @@ def generate_graph(graph_config: DotDict) -> HeteroData: + """Generate a graph from a configuration. + + Parameters + ---------- + graph_config : DotDict + Configuration for the nodes and edges (and its attributes). + + Returns + ------- + HeteroData + Graph. + """ graph = HeteroData() for name, nodes_cfg in graph_config.nodes.items(): - graph = instantiate(nodes_cfg.node_type).transform(graph, name, nodes_cfg.get("attributes", {})) + graph = instantiate(nodes_cfg.node_builder).transform(graph, name, nodes_cfg.get("attributes", {})) for edges_cfg in graph_config.edges: - graph = instantiate(edges_cfg.edge_type, **edges_cfg.nodes).transform(graph, edges_cfg.get("attributes", {})) + graph = instantiate(edges_cfg.edge_builder, **edges_cfg.nodes).transform(graph, edges_cfg.get("attributes", {})) return graph @@ -66,3 +78,8 @@ def _path_readable(self) -> bool: return True except FileNotFoundError: return False + + +if __name__ == "__main__": + creator = GraphCreator(config="/home/ecm1924/GitRepos/anemoi-graphs/recipe.yaml", path="graph.pt") + creator.create() diff --git a/src/anemoi/graphs/edges/attributes.py b/src/anemoi/graphs/edges/attributes.py index 6b57e70..9e7509f 100644 --- a/src/anemoi/graphs/edges/attributes.py +++ b/src/anemoi/graphs/edges/attributes.py @@ -1,74 +1,81 @@ import logging from abc import ABC from abc import abstractmethod +from dataclasses import dataclass from typing import Optional -import torch -from anemoi.utils.config import DotDict import numpy as np +import torch +from scipy.sparse import coo_matrix +from sklearn.preprocessing import normalize from torch_geometric.data import HeteroData -from hydra.utils import instantiate from anemoi.graphs.edges.directional import directional_edge_features from anemoi.graphs.normalizer import NormalizerMixin +from anemoi.graphs.utils import haversine_distance logger = logging.getLogger(__name__) -class NodeAttributeBuilder(): - - def transform(self, graph: HeteroData, graph_config: DotDict): - - for name, nodes_cfg in graph_config.nodes.items(): - graph = self.register_node_attributes(graph, name, nodes_cfg.get("attributes", {})) - - def register_node_attributes(self, graph: HeteroData, node_name: str, node_config: DotDict): - assert node_name in graph.keys(), f"Node {node_name} does not exist in the graph." - for attr_name, attr_cfg in node_config.items(): - graph[node_name][attr_name] = instantiate(attr_cfg).compute(graph, node_name) - return graph - -class EdgeAttributeBuilder(): - - def transform(self, graph: HeteroData, graph_config: DotDict): - for edges_cfg in graph_config.edges: - graph = self.register_edge_attributes(graph, edges_cfg.nodes.src_name, edges_cfg.nodes.dst_name, edges_cfg.get("attributes", {})) - return graph - - def register_edge_attributes(self, graph: HeteroData, src_name: str, dst_name: str, edge_config: DotDict): - - for attr_name, attr_cfg in edge_config.items(): - attr_values = instantiate(attr_cfg).compute(graph, src_name, dst_name) - graph = self.register_edge_attribute(graph, src_name, dst_name, attr_name, attr_values) - return graph - - def register_edge_attribute(self, graph: HeteroData, src_name: str, dst_name: str, attr_name: str, attr_values: torch.Tensor): - num_edges = graph[(src_name, "to", dst_name)].num_edges - assert ( attr_values.shape[0] == num_edges), f"Number of edge features ({attr_values.shape[0]}) must match number of edges ({num_edges})." - - graph[(src_name, "to", dst_name)][attr_name] = attr_values - return graph - +@dataclass class BaseEdgeAttribute(ABC, NormalizerMixin): norm: Optional[str] = None @abstractmethod - def compute(self, graph: HeteroData, *args, **kwargs): ... + def compute(self, graph: HeteroData, *args, **kwargs) -> np.ndarray: ... + + def post_process(self, values: np.ndarray) -> torch.Tensor: + return torch.tensor(values) - def __call__(self, *args, **kwargs): + def __call__(self, *args, **kwargs) -> torch.Tensor: values = self.compute(*args, **kwargs) - if values.ndim == 1: - values = values[:, np.newaxis] - return self.normalize(values) + normed_values = self.normalize(values) + if normed_values.ndim == 1: + normed_values = normed_values[:, np.newaxis] + return self.post_process(normed_values) +@dataclass class DirectionalFeatures(BaseEdgeAttribute): + """Compute directional features for edges.""" + norm: Optional[str] = None luse_rotated_features: bool = False - def compute(self, graph: HeteroData, src_name: str, dst_name: str): + def compute(self, graph: HeteroData, src_name: str, dst_name: str) -> torch.Tensor: edge_index = graph[(src_name, "to", dst_name)].edge_index src_coords = graph[src_name].x.numpy()[edge_index[0]].T dst_coords = graph[dst_name].x.numpy()[edge_index[1]].T edge_dirs = directional_edge_features(src_coords, dst_coords, self.luse_rotated_features).T return edge_dirs + + +@dataclass +class HaversineDistance(BaseEdgeAttribute): + """Edge length feature.""" + + norm: str = "l1" + invert: bool = True + + def compute(self, graph: HeteroData, src_name: str, dst_name: str): + """Compute haversine distance (in kilometers) between nodes connected by edges.""" + assert src_name in graph.node_types, f"Node {src_name} not found in graph." + assert dst_name in graph.node_types, f"Node {dst_name} not found in graph." + edge_index = graph[(src_name, "to", dst_name)].edge_index + src_coords = graph[src_name].x.numpy()[edge_index[0]] + dst_coords = graph[dst_name].x.numpy()[edge_index[1]] + edge_lengths = haversine_distance(src_coords, dst_coords) + return coo_matrix((edge_lengths, (edge_index[1], edge_index[0]))) + + def normalize(self, values) -> np.ndarray: + """Normalize the edge length. + + This method scales the edge lengths to a unit norm, computing the norms + for each source node (axis=1). + """ + return normalize(values, norm="l1", axis=1).data + + def post_process(self, values: np.ndarray) -> torch.Tensor: + if self.invert: + values = 1 - values + return super().post_process(values) diff --git a/src/anemoi/graphs/edges/connections.py b/src/anemoi/graphs/edges/connections.py index 6bf057e..49080ce 100644 --- a/src/anemoi/graphs/edges/connections.py +++ b/src/anemoi/graphs/edges/connections.py @@ -7,11 +7,10 @@ from anemoi.utils.config import DotDict from hydra.utils import instantiate from sklearn.neighbors import NearestNeighbors -from sklearn.preprocessing import normalize from torch_geometric.data import HeteroData from torch_geometric.data.storage import NodeStorage -from anemoi.graphs import earth_radius +from anemoi.graphs import EARTH_RADIUS from anemoi.graphs.utils import get_grid_reference_distance logger = logging.getLogger(__name__) @@ -54,13 +53,9 @@ def transform(self, graph: HeteroData, attrs_config: Optional[DotDict] = None) - # Compute adjacency matrix. adjmat = self.get_adj_matrix(src_nodes, dst_nodes) - # Normalize adjacency matrix. - adjmat_norm = self.normalize_adjmat(adjmat) - # Add edges to the graph and register normed distance. graph = self.register_edges(graph, adjmat.col, adjmat.row) - self.register_edge_attribute(graph, "normed_dist", adjmat_norm.data) if attrs_config is not None: for attr_name, attr_cfg in attrs_config.items(): attr_values = instantiate(attr_cfg)(graph, self.src_name, self.dst_name) @@ -68,12 +63,6 @@ def transform(self, graph: HeteroData, attrs_config: Optional[DotDict] = None) - return graph - def normalize_adjmat(self, adjmat): - """Normalize a sparse adjacency matrix.""" - adjmat_norm = normalize(adjmat, norm="l1", axis=1) - adjmat_norm.data = 1.0 - adjmat_norm.data - return adjmat_norm - class KNNEdgeBuilder(BaseEdgeBuilder): """Computes KNN based edges and adds them to the graph.""" @@ -124,7 +113,7 @@ def prepare_node_data(self, graph: HeteroData): return super().prepare_node_data(graph) def get_adj_matrix(self, src_nodes: NodeStorage, dst_nodes: NodeStorage): - logger.debug("Using cut-off radius of %.1f km.", self.radius * earth_radius) + logger.debug("Using cut-off radius of %.1f km.", self.radius * EARTH_RADIUS) nearest_neighbour = NearestNeighbors(metric="haversine", n_jobs=4) nearest_neighbour.fit(src_nodes.x) diff --git a/src/anemoi/graphs/normalizer.py b/src/anemoi/graphs/normalizer.py index 98820c0..5b3edcd 100644 --- a/src/anemoi/graphs/normalizer.py +++ b/src/anemoi/graphs/normalizer.py @@ -6,6 +6,8 @@ class NormalizerMixin: + """Mixin class for normalizing attributes.""" + def normalize(self, values: np.ndarray) -> np.ndarray: if self.norm is None: logger.debug("Node weights are not normalized.") diff --git a/src/anemoi/graphs/utils.py b/src/anemoi/graphs/utils.py index 1a25134..f655e8d 100644 --- a/src/anemoi/graphs/utils.py +++ b/src/anemoi/graphs/utils.py @@ -109,3 +109,25 @@ def get_index_in_outer_join(vector: torch.Tensor, tensor: torch.Tensor) -> int: if mask.any(): return int(torch.where(mask)[0]) return -1 + + +def haversine_distance(src_coords: np.ndarray, dst_coords: np.ndarray) -> np.ndarray: + """Haversine distance. + + Parameters + ---------- + src_coords : np.ndarray of shape (N, 2) + Source coordinates in radians. + dst_coords : np.ndarray of shape (N, 2) + Destination coordinates in radians. + + Returns + ------- + np.ndarray of shape (N,) + Haversine distance between source and destination coordinates. + """ + dlat = dst_coords[:, 0] - src_coords[:, 0] + dlon = dst_coords[:, 1] - src_coords[:, 1] + a = np.sin(dlat / 2) ** 2 + np.cos(src_coords[:, 0]) * np.cos(dst_coords[:, 0]) * np.sin(dlon / 2) ** 2 + c = 2 * np.arctan2(np.sqrt(a), np.sqrt(1 - a)) + return c diff --git a/tests/nodes/test_weights.py b/tests/nodes/test_weights.py index db80dce..71e54fa 100644 --- a/tests/nodes/test_weights.py +++ b/tests/nodes/test_weights.py @@ -1,16 +1,16 @@ -import numpy as np import pytest import torch -from hydra.utils import instantiate from torch_geometric.data import HeteroData +from anemoi.graphs.nodes.weights import AreaWeights +from anemoi.graphs.nodes.weights import UniformWeights + @pytest.mark.parametrize("norm", [None, "l1", "l2", "unit-max", "unit-sum", "unit-std"]) def test_uniform_weights(graph_with_nodes: HeteroData, norm: str): """Test NPZNodes register correctly the weights.""" - config = {"_target_": "anemoi.graphs.nodes.weights.UniformWeights", "norm": norm} - - weights = instantiate(config).get_weights(graph_with_nodes["test_nodes"]) + node_attr_builder = UniformWeights(norm=norm) + weights = node_attr_builder.get_weights(graph_with_nodes["test_nodes"]) assert weights is not None assert isinstance(weights, torch.Tensor) @@ -20,21 +20,15 @@ def test_uniform_weights(graph_with_nodes: HeteroData, norm: str): @pytest.mark.parametrize("norm", ["l3", "invalide"]) def test_uniform_weights_fail(graph_with_nodes: HeteroData, norm: str): """Test NPZNodes register correctly the weights.""" - config = {"_target_": "anemoi.graphs.nodes.weights.UniformWeights", "norm": norm} - with pytest.raises(ValueError): - instantiate(config).get_weights(graph_with_nodes["test_nodes"]) + node_attr_builder = UniformWeights(norm=norm) + node_attr_builder.get_weights(graph_with_nodes["test_nodes"]) def test_area_weights(graph_with_nodes: HeteroData): """Test NPZNodes register correctly the weights.""" - config = { - "_target_": "anemoi.graphs.nodes.weights.AreaWeights", - "radius": 1.0, - "centre": np.array([0, 0, 0]), - } - - weights = instantiate(config).get_weights(graph_with_nodes["test_nodes"]) + node_attr_builder = AreaWeights() + weights = node_attr_builder.get_weights(graph_with_nodes["test_nodes"]) assert weights is not None assert isinstance(weights, torch.Tensor) @@ -43,11 +37,6 @@ def test_area_weights(graph_with_nodes: HeteroData): @pytest.mark.parametrize("radius", [-1.0, "hello", None]) def test_area_weights_fail(graph_with_nodes: HeteroData, radius: float): - config = { - "_target_": "anemoi.graphs.nodes.weights.AreaWeights", - "radius": radius, - "centre": np.array([0, 0, 0]), - } - with pytest.raises(ValueError): - instantiate(config).get_weights(graph_with_nodes["test_nodes"]) + node_attr_builder = AreaWeights(radius=radius) + node_attr_builder.get_weights(graph_with_nodes["test_nodes"]) diff --git a/tests/test_normalizer.py b/tests/test_normalizer.py new file mode 100644 index 0000000..2654c0c --- /dev/null +++ b/tests/test_normalizer.py @@ -0,0 +1,52 @@ +import numpy as np +import pytest + +from anemoi.graphs.normalizer import NormalizerMixin + + +@pytest.mark.parametrize("norm", ["l1", "l2", "unit-max", "unit-sum", "unit-std"]) +def test_normalizer(norm: str): + """Test NormalizerMixin normalize method.""" + class Normalizer(NormalizerMixin): + def __init__(self, norm): + self.norm = norm + + def __call__(self, data): + return self.normalize(data) + + normalizer = Normalizer(norm=norm) + data = np.random.rand(10, 5) + normalized_data = normalizer(data) + assert isinstance(normalized_data, np.ndarray) + assert normalized_data.shape == data.shape + + +@pytest.mark.parametrize("norm", ["l3", "invalid"]) +def test_normalizer_wrong_norm(norm: str): + """Test NormalizerMixin normalize method.""" + class Normalizer(NormalizerMixin): + def __init__(self, norm: str): + self.norm = norm + + def __call__(self, data): + return self.normalize(data) + + with pytest.raises(ValueError): + normalizer = Normalizer(norm=norm) + data = np.random.rand(10, 5) + normalizer(data) + + +def test_normalizer_wrong_inheritance(): + """Test NormalizerMixin normalize method.""" + class Normalizer(NormalizerMixin): + def __init__(self, attr): + self.attr = attr + + def __call__(self, data): + return self.normalize(data) + + with pytest.raises(AttributeError): + normalizer = Normalizer(attr="attr_name") + data = np.random.rand(10, 5) + normalizer(data)