Skip to content
This repository has been archived by the owner on Dec 20, 2024. It is now read-only.

Commit

Permalink
feat: Initial implementation of global graphs
Browse files Browse the repository at this point in the history
Co-authored by: Mario Santa Cruz <mario.santacruz@ecmwf.int>
  • Loading branch information
theissenhelen committed Jun 24, 2024
1 parent a654d21 commit 38f8d15
Show file tree
Hide file tree
Showing 14 changed files with 670 additions and 8 deletions.
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -186,3 +186,5 @@ _build/
*.sync
_version.py
*.code-workspace

/config*
9 changes: 1 addition & 8 deletions src/anemoi/graphs/__init__.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,2 @@
# (C) Copyright 2024 European Centre for Medium-Range Weather Forecasts.
# This software is licensed under the terms of the Apache Licence Version 2.0
# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
# In applying this licence, ECMWF does not waive the privileges and immunities
# granted to it by virtue of its status as an intergovernmental organisation
# nor does it submit to any jurisdiction.
earth_radius = 6371.0 # km


from ._version import __version__ as __version__
Empty file.
37 changes: 37 additions & 0 deletions src/anemoi/graphs/edges/attributes.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
from abc import ABC
from abc import abstractmethod
from typing import Optional

import numpy as np
from torch_geometric.data import HeteroData

from anemoi.graphs.edges.directional import directional_edge_features
from anemoi.graphs.normalizer import NormalizerMixin
from anemoi.utils.logger import get_code_logger

logger = get_code_logger(__name__)


class BaseEdgeAttribute(ABC, NormalizerMixin):
norm: Optional[str] = None

@abstractmethod
def compute(self, graph: HeteroData, *args, **kwargs): ...

def __call__(self, *args, **kwargs):
values = self.compute(*args, **kwargs)
if values.ndim == 1:
values = values[:, np.newaxis]
return self.normalize(values)


class DirectionalFeatures(BaseEdgeAttribute):
norm: Optional[str] = None
luse_rotated_features: bool = False

def compute(self, graph: HeteroData, src_name: str, dst_name: str):
edge_index = graph[(src_name, "to", dst_name)].edge_index
src_coords = graph[src_name].x.numpy()[edge_index[0]].T
dst_coords = graph[dst_name].x.numpy()[edge_index[1]].T
edge_dirs = directional_edge_features(src_coords, dst_coords, self.luse_rotated_features).T
return edge_dirs
131 changes: 131 additions & 0 deletions src/anemoi/graphs/edges/connections.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,131 @@
from abc import abstractmethod
from dataclasses import dataclass
from typing import Optional

import networkx as nx
import numpy as np
import torch
from anemoi.utils.config import DotDict
from hydra.utils import instantiate
from sklearn.neighbors import NearestNeighbors
from sklearn.preprocessing import normalize
from torch_geometric.data import HeteroData
from torch_geometric.data.storage import NodeStorage

from anemoi.graphs import earth_radius
from anemoi.graphs.utils import get_grid_reference_distance

import logging

logger = logging.getLogger(__name__)


class BaseEdgeBuilder:
"""Base class for edge builders."""

def __init__(self, src_name: str, dst_name: str):
super().__init__()
self.src_name = src_name
self.dst_name = dst_name

@abstractmethod
def get_adj_matrix(self, src_nodes: NodeStorage, dst_nodes: NodeStorage): ...

def register_edges(self, graph, head_indices, tail_indices):
graph[(self.src_name, "to", self.dst_name)].edge_index = np.stack([head_indices, tail_indices], axis=0).astype(np.int32)
return graph

def register_edge_attribute(self, graph: HeteroData, name: str, values: np.ndarray):
num_edges = graph[(self.src_name, "to", self.dst_name)].num_edges
assert (
values.shape[0] == num_edges
), f"Number of edge features ({values.shape[0]}) must match number of edges ({num_edges})."
graph[self.src_name, "to", self.dst_name][name] = values.reshape(num_edges, -1) # TODO: Check the [name] part works
return graph

def prepare_node_data(self, graph: HeteroData):
return graph[self.src_name], graph[self.dst_name]

def transform(self, graph: HeteroData, attrs_config: Optional[DotDict] = None) -> HeteroData:
# Get source and destination nodes.
src_nodes, dst_nodes = self.prepare_node_data(graph)

# Compute adjacency matrix.
adjmat = self.get_adj_matrix(src_nodes, dst_nodes)

# Normalize adjacency matrix.
adjmat_norm = self.normalize_adjmat(adjmat)

# Add edges to the graph and register normed distance.
graph = self.register_edges(graph, adjmat.col, adjmat.row)

self.register_edge_attribute(graph, "normed_dist", adjmat_norm.data)
if attrs_config is not None:
for attr_name, attr_cfg in attrs_config.items():
attr_values = instantiate(attr_cfg)(graph, self.src_name, self.dst_name)
graph = self.register_edge_attribute(graph, attr_name, attr_values)

return graph

def normalize_adjmat(self, adjmat):
"""Normalize a sparse adjacency matrix."""
adjmat_norm = normalize(adjmat, norm="l1", axis=1)
adjmat_norm.data = 1.0 - adjmat_norm.data
return adjmat_norm


class KNNEdgeBuilder(BaseEdgeBuilder):
"""Computes KNN based edges and adds them to the graph."""

def __init__(self, src_name: str, dst_name: str, num_nearest_neighbours: int):
super().__init__(src_name, dst_name)
assert isinstance(num_nearest_neighbours, int), "Number of nearest neighbours must be an integer"
assert num_nearest_neighbours > 0, "Number of nearest neighbours must be positive"
self.num_nearest_neighbours = num_nearest_neighbours

def get_adj_matrix(self, src_nodes: np.ndarray, dst_nodes: np.ndarray):
assert self.num_nearest_neighbours is not None, "number of neighbors required for knn encoder"
logger.debug(
"Using %d nearest neighbours for KNN-Edges between %s and %s.",
self.num_nearest_neighbours,
self.src_name,
self.dst_name,
)

nearest_neighbour = NearestNeighbors(metric="haversine", n_jobs=4)
nearest_neighbour.fit(src_nodes.x.numpy())
adj_matrix = nearest_neighbour.kneighbors_graph(
dst_nodes.x.numpy(),
n_neighbors=self.num_nearest_neighbours,
mode="distance",
).tocoo()
return adj_matrix


class CutOffEdgeBuilder(BaseEdgeBuilder):
"""Computes cut-off based edges and adds them to the graph."""

def __init__(self, src_name: str, dst_name: str, cutoff_factor: float):
super().__init__(src_name, dst_name)
assert isinstance(cutoff_factor, float), "Cutoff factor must be a float"
assert cutoff_factor > 0, "Cutoff factor must be positive"
self.cutoff_factor = cutoff_factor

def get_cutoff_radius(self, dst_nodes: NodeStorage, mask_attr: Optional[torch.Tensor] = None):
mask = dst_nodes[mask_attr] if mask_attr is not None else None
dst_grid_reference_distance = get_grid_reference_distance(dst_nodes.x, mask)
radius = dst_grid_reference_distance * self.cutoff_factor
return radius

def prepare_node_data(self, graph: HeteroData):
self.radius = self.get_cutoff_radius(graph)
return super().prepare_node_data(graph)

def get_adj_matrix(self, src_nodes: NodeStorage, dst_nodes: NodeStorage):
logger.debug("Using cut-off radius of %.1f km.", self.radius * earth_radius)

nearest_neighbour = NearestNeighbors(metric="haversine", n_jobs=4)
nearest_neighbour.fit(src_nodes.x)
adj_matrix = nearest_neighbour.radius_neighbors_graph(dst_nodes.x, radius=self.radius).tocoo()
return adj_matrix

83 changes: 83 additions & 0 deletions src/anemoi/graphs/edges/directional.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
from typing import Optional

import numpy as np
from scipy.spatial.transform import Rotation

from anemoi.graphs.generate.transforms import direction_vec
from anemoi.graphs.generate.transforms import to_sphere_xyz


def get_rotation_from_unit_vecs(points: np.ndarray, reference: np.ndarray) -> Rotation:
"""Compute rotation matrix of a set of points with respect to a reference vector.
Parameters
----------
points : np.ndarray of shape (num_points, 3)
The points to compute the direction vector.
reference : np.ndarray of shape (3, )
The reference vector.
Returns
-------
Rotation
The rotation matrix that aligns the points with the reference vector.
"""
assert points.shape[1] == 3, "Points must be in 3D"
v_unit = direction_vec(points, reference)
theta = np.arccos(np.dot(points, reference))
return Rotation.from_rotvec(np.transpose(v_unit * theta))


def compute_directions(loc1: np.ndarray, loc2: np.ndarray, pole_vec: Optional[np.ndarray] = None) -> np.ndarray:
"""Compute the direction of the edge joining the nodes considered.
Parameters
----------
loc1 : np.ndarray of shape (2, num_points)
Location of the head nodes.
loc2 : np.ndarray
Location of the tail nodes.
pole_vec : np.ndarray, optional
The pole vector to rotate the points to. Defaults to the north pole.
Returns
-------
np.ndarray of shape (3, num_points)
The direction of the edge after rotating the north pole.
"""
if pole_vec is None:
pole_vec = np.array([0, 0, 1])

# all will be rotated relative to destination node
loc1_xyz = to_sphere_xyz(loc1, 1.0)
loc2_xyz = to_sphere_xyz(loc2, 1.0)
r = get_rotation_from_unit_vecs(loc2_xyz, pole_vec)
direction = direction_vec(r.apply(loc1_xyz), pole_vec)
return direction / np.sqrt(np.power(direction, 2).sum(axis=0))


def directional_edge_features(loc1: np.ndarray, loc2: np.ndarray, relative_to_rotated_target: bool = True) -> np.ndarray:
"""Compute features of the edge joining the nodes considered.
It computes the direction of the edge after rotating the north pole.
Parameters
----------
loc1 : np.ndarray of shpae (2, num_points)
Location of the head node.
loc2 : np.ndarray of shape (2, num_points)
Location of the tail node.
relative_to_rotated_target : bool, optional
Whether to rotate the north pole to the target node. Defaults to True.
Returns
-------
np.ndarray of shape of (2, num_points)
Direction of the edge after rotation the north pole.
"""
if relative_to_rotated_target:
rotation = compute_directions(loc1, loc2)
assert np.allclose(rotation[2], 0), "Rotation should be aligned with the north pole"
return rotation[:2]

return loc2 - loc1
36 changes: 36 additions & 0 deletions src/anemoi/graphs/generate.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
from abc import ABC
from abc import abstractmethod

import hydra
from anemoi.utils.config import DotDict
from hydra.utils import instantiate
from omegaconf import DictConfig
from torch_geometric.data import HeteroData

import logging

logger = logging.getLogger(__name__)


def generate_graph(graph_config):
graph = HeteroData()

for name, nodes_cfg in graph_config.nodes.items():
graph = instantiate(nodes_cfg.node_type).transform(graph, name, nodes_cfg.get("attributes", {}))

for edges_cfg in graph_config.edges:
graph = instantiate(edges_cfg.edge_type, **edges_cfg.nodes).transform(graph, edges_cfg.get("attributes", {}))

return graph


@hydra.main(version_base=None, config_path="../config", config_name="config")
def main(config: DictConfig):

graph = generate_graph(config)

return graph


if __name__ == "__main__":
main()
Empty file.
Loading

0 comments on commit 38f8d15

Please sign in to comment.