This repository has been archived by the owner on Dec 20, 2024. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 10
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
feat: Initial implementation of global graphs
Co-authored by: Mario Santa Cruz <mario.santacruz@ecmwf.int>
- Loading branch information
1 parent
a654d21
commit 38f8d15
Showing
14 changed files
with
670 additions
and
8 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -186,3 +186,5 @@ _build/ | |
*.sync | ||
_version.py | ||
*.code-workspace | ||
|
||
/config* |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,9 +1,2 @@ | ||
# (C) Copyright 2024 European Centre for Medium-Range Weather Forecasts. | ||
# This software is licensed under the terms of the Apache Licence Version 2.0 | ||
# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. | ||
# In applying this licence, ECMWF does not waive the privileges and immunities | ||
# granted to it by virtue of its status as an intergovernmental organisation | ||
# nor does it submit to any jurisdiction. | ||
earth_radius = 6371.0 # km | ||
|
||
|
||
from ._version import __version__ as __version__ |
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,37 @@ | ||
from abc import ABC | ||
from abc import abstractmethod | ||
from typing import Optional | ||
|
||
import numpy as np | ||
from torch_geometric.data import HeteroData | ||
|
||
from anemoi.graphs.edges.directional import directional_edge_features | ||
from anemoi.graphs.normalizer import NormalizerMixin | ||
from anemoi.utils.logger import get_code_logger | ||
|
||
logger = get_code_logger(__name__) | ||
|
||
|
||
class BaseEdgeAttribute(ABC, NormalizerMixin): | ||
norm: Optional[str] = None | ||
|
||
@abstractmethod | ||
def compute(self, graph: HeteroData, *args, **kwargs): ... | ||
|
||
def __call__(self, *args, **kwargs): | ||
values = self.compute(*args, **kwargs) | ||
if values.ndim == 1: | ||
values = values[:, np.newaxis] | ||
return self.normalize(values) | ||
|
||
|
||
class DirectionalFeatures(BaseEdgeAttribute): | ||
norm: Optional[str] = None | ||
luse_rotated_features: bool = False | ||
|
||
def compute(self, graph: HeteroData, src_name: str, dst_name: str): | ||
edge_index = graph[(src_name, "to", dst_name)].edge_index | ||
src_coords = graph[src_name].x.numpy()[edge_index[0]].T | ||
dst_coords = graph[dst_name].x.numpy()[edge_index[1]].T | ||
edge_dirs = directional_edge_features(src_coords, dst_coords, self.luse_rotated_features).T | ||
return edge_dirs |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,131 @@ | ||
from abc import abstractmethod | ||
from dataclasses import dataclass | ||
from typing import Optional | ||
|
||
import networkx as nx | ||
import numpy as np | ||
import torch | ||
from anemoi.utils.config import DotDict | ||
from hydra.utils import instantiate | ||
from sklearn.neighbors import NearestNeighbors | ||
from sklearn.preprocessing import normalize | ||
from torch_geometric.data import HeteroData | ||
from torch_geometric.data.storage import NodeStorage | ||
|
||
from anemoi.graphs import earth_radius | ||
from anemoi.graphs.utils import get_grid_reference_distance | ||
|
||
import logging | ||
|
||
logger = logging.getLogger(__name__) | ||
|
||
|
||
class BaseEdgeBuilder: | ||
"""Base class for edge builders.""" | ||
|
||
def __init__(self, src_name: str, dst_name: str): | ||
super().__init__() | ||
self.src_name = src_name | ||
self.dst_name = dst_name | ||
|
||
@abstractmethod | ||
def get_adj_matrix(self, src_nodes: NodeStorage, dst_nodes: NodeStorage): ... | ||
|
||
def register_edges(self, graph, head_indices, tail_indices): | ||
graph[(self.src_name, "to", self.dst_name)].edge_index = np.stack([head_indices, tail_indices], axis=0).astype(np.int32) | ||
return graph | ||
|
||
def register_edge_attribute(self, graph: HeteroData, name: str, values: np.ndarray): | ||
num_edges = graph[(self.src_name, "to", self.dst_name)].num_edges | ||
assert ( | ||
values.shape[0] == num_edges | ||
), f"Number of edge features ({values.shape[0]}) must match number of edges ({num_edges})." | ||
graph[self.src_name, "to", self.dst_name][name] = values.reshape(num_edges, -1) # TODO: Check the [name] part works | ||
return graph | ||
|
||
def prepare_node_data(self, graph: HeteroData): | ||
return graph[self.src_name], graph[self.dst_name] | ||
|
||
def transform(self, graph: HeteroData, attrs_config: Optional[DotDict] = None) -> HeteroData: | ||
# Get source and destination nodes. | ||
src_nodes, dst_nodes = self.prepare_node_data(graph) | ||
|
||
# Compute adjacency matrix. | ||
adjmat = self.get_adj_matrix(src_nodes, dst_nodes) | ||
|
||
# Normalize adjacency matrix. | ||
adjmat_norm = self.normalize_adjmat(adjmat) | ||
|
||
# Add edges to the graph and register normed distance. | ||
graph = self.register_edges(graph, adjmat.col, adjmat.row) | ||
|
||
self.register_edge_attribute(graph, "normed_dist", adjmat_norm.data) | ||
if attrs_config is not None: | ||
for attr_name, attr_cfg in attrs_config.items(): | ||
attr_values = instantiate(attr_cfg)(graph, self.src_name, self.dst_name) | ||
graph = self.register_edge_attribute(graph, attr_name, attr_values) | ||
|
||
return graph | ||
|
||
def normalize_adjmat(self, adjmat): | ||
"""Normalize a sparse adjacency matrix.""" | ||
adjmat_norm = normalize(adjmat, norm="l1", axis=1) | ||
adjmat_norm.data = 1.0 - adjmat_norm.data | ||
return adjmat_norm | ||
|
||
|
||
class KNNEdgeBuilder(BaseEdgeBuilder): | ||
"""Computes KNN based edges and adds them to the graph.""" | ||
|
||
def __init__(self, src_name: str, dst_name: str, num_nearest_neighbours: int): | ||
super().__init__(src_name, dst_name) | ||
assert isinstance(num_nearest_neighbours, int), "Number of nearest neighbours must be an integer" | ||
assert num_nearest_neighbours > 0, "Number of nearest neighbours must be positive" | ||
self.num_nearest_neighbours = num_nearest_neighbours | ||
|
||
def get_adj_matrix(self, src_nodes: np.ndarray, dst_nodes: np.ndarray): | ||
assert self.num_nearest_neighbours is not None, "number of neighbors required for knn encoder" | ||
logger.debug( | ||
"Using %d nearest neighbours for KNN-Edges between %s and %s.", | ||
self.num_nearest_neighbours, | ||
self.src_name, | ||
self.dst_name, | ||
) | ||
|
||
nearest_neighbour = NearestNeighbors(metric="haversine", n_jobs=4) | ||
nearest_neighbour.fit(src_nodes.x.numpy()) | ||
adj_matrix = nearest_neighbour.kneighbors_graph( | ||
dst_nodes.x.numpy(), | ||
n_neighbors=self.num_nearest_neighbours, | ||
mode="distance", | ||
).tocoo() | ||
return adj_matrix | ||
|
||
|
||
class CutOffEdgeBuilder(BaseEdgeBuilder): | ||
"""Computes cut-off based edges and adds them to the graph.""" | ||
|
||
def __init__(self, src_name: str, dst_name: str, cutoff_factor: float): | ||
super().__init__(src_name, dst_name) | ||
assert isinstance(cutoff_factor, float), "Cutoff factor must be a float" | ||
assert cutoff_factor > 0, "Cutoff factor must be positive" | ||
self.cutoff_factor = cutoff_factor | ||
|
||
def get_cutoff_radius(self, dst_nodes: NodeStorage, mask_attr: Optional[torch.Tensor] = None): | ||
mask = dst_nodes[mask_attr] if mask_attr is not None else None | ||
dst_grid_reference_distance = get_grid_reference_distance(dst_nodes.x, mask) | ||
radius = dst_grid_reference_distance * self.cutoff_factor | ||
return radius | ||
|
||
def prepare_node_data(self, graph: HeteroData): | ||
self.radius = self.get_cutoff_radius(graph) | ||
return super().prepare_node_data(graph) | ||
|
||
def get_adj_matrix(self, src_nodes: NodeStorage, dst_nodes: NodeStorage): | ||
logger.debug("Using cut-off radius of %.1f km.", self.radius * earth_radius) | ||
|
||
nearest_neighbour = NearestNeighbors(metric="haversine", n_jobs=4) | ||
nearest_neighbour.fit(src_nodes.x) | ||
adj_matrix = nearest_neighbour.radius_neighbors_graph(dst_nodes.x, radius=self.radius).tocoo() | ||
return adj_matrix | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,83 @@ | ||
from typing import Optional | ||
|
||
import numpy as np | ||
from scipy.spatial.transform import Rotation | ||
|
||
from anemoi.graphs.generate.transforms import direction_vec | ||
from anemoi.graphs.generate.transforms import to_sphere_xyz | ||
|
||
|
||
def get_rotation_from_unit_vecs(points: np.ndarray, reference: np.ndarray) -> Rotation: | ||
"""Compute rotation matrix of a set of points with respect to a reference vector. | ||
Parameters | ||
---------- | ||
points : np.ndarray of shape (num_points, 3) | ||
The points to compute the direction vector. | ||
reference : np.ndarray of shape (3, ) | ||
The reference vector. | ||
Returns | ||
------- | ||
Rotation | ||
The rotation matrix that aligns the points with the reference vector. | ||
""" | ||
assert points.shape[1] == 3, "Points must be in 3D" | ||
v_unit = direction_vec(points, reference) | ||
theta = np.arccos(np.dot(points, reference)) | ||
return Rotation.from_rotvec(np.transpose(v_unit * theta)) | ||
|
||
|
||
def compute_directions(loc1: np.ndarray, loc2: np.ndarray, pole_vec: Optional[np.ndarray] = None) -> np.ndarray: | ||
"""Compute the direction of the edge joining the nodes considered. | ||
Parameters | ||
---------- | ||
loc1 : np.ndarray of shape (2, num_points) | ||
Location of the head nodes. | ||
loc2 : np.ndarray | ||
Location of the tail nodes. | ||
pole_vec : np.ndarray, optional | ||
The pole vector to rotate the points to. Defaults to the north pole. | ||
Returns | ||
------- | ||
np.ndarray of shape (3, num_points) | ||
The direction of the edge after rotating the north pole. | ||
""" | ||
if pole_vec is None: | ||
pole_vec = np.array([0, 0, 1]) | ||
|
||
# all will be rotated relative to destination node | ||
loc1_xyz = to_sphere_xyz(loc1, 1.0) | ||
loc2_xyz = to_sphere_xyz(loc2, 1.0) | ||
r = get_rotation_from_unit_vecs(loc2_xyz, pole_vec) | ||
direction = direction_vec(r.apply(loc1_xyz), pole_vec) | ||
return direction / np.sqrt(np.power(direction, 2).sum(axis=0)) | ||
|
||
|
||
def directional_edge_features(loc1: np.ndarray, loc2: np.ndarray, relative_to_rotated_target: bool = True) -> np.ndarray: | ||
"""Compute features of the edge joining the nodes considered. | ||
It computes the direction of the edge after rotating the north pole. | ||
Parameters | ||
---------- | ||
loc1 : np.ndarray of shpae (2, num_points) | ||
Location of the head node. | ||
loc2 : np.ndarray of shape (2, num_points) | ||
Location of the tail node. | ||
relative_to_rotated_target : bool, optional | ||
Whether to rotate the north pole to the target node. Defaults to True. | ||
Returns | ||
------- | ||
np.ndarray of shape of (2, num_points) | ||
Direction of the edge after rotation the north pole. | ||
""" | ||
if relative_to_rotated_target: | ||
rotation = compute_directions(loc1, loc2) | ||
assert np.allclose(rotation[2], 0), "Rotation should be aligned with the north pole" | ||
return rotation[:2] | ||
|
||
return loc2 - loc1 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,36 @@ | ||
from abc import ABC | ||
from abc import abstractmethod | ||
|
||
import hydra | ||
from anemoi.utils.config import DotDict | ||
from hydra.utils import instantiate | ||
from omegaconf import DictConfig | ||
from torch_geometric.data import HeteroData | ||
|
||
import logging | ||
|
||
logger = logging.getLogger(__name__) | ||
|
||
|
||
def generate_graph(graph_config): | ||
graph = HeteroData() | ||
|
||
for name, nodes_cfg in graph_config.nodes.items(): | ||
graph = instantiate(nodes_cfg.node_type).transform(graph, name, nodes_cfg.get("attributes", {})) | ||
|
||
for edges_cfg in graph_config.edges: | ||
graph = instantiate(edges_cfg.edge_type, **edges_cfg.nodes).transform(graph, edges_cfg.get("attributes", {})) | ||
|
||
return graph | ||
|
||
|
||
@hydra.main(version_base=None, config_path="../config", config_name="config") | ||
def main(config: DictConfig): | ||
|
||
graph = generate_graph(config) | ||
|
||
return graph | ||
|
||
|
||
if __name__ == "__main__": | ||
main() |
Empty file.
Oops, something went wrong.