Skip to content
This repository has been archived by the owner on Dec 20, 2024. It is now read-only.

Commit

Permalink
feat: edge_length moved to edges/attributes.py
Browse files Browse the repository at this point in the history
  • Loading branch information
JPXKQX committed Jun 27, 2024
1 parent 9ba0391 commit 9a47184
Show file tree
Hide file tree
Showing 8 changed files with 159 additions and 81 deletions.
2 changes: 1 addition & 1 deletion src/anemoi/graphs/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,4 +7,4 @@

from ._version import __version__

earth_radius = 6371.0 # km
EARTH_RADIUS = 6371.0 # km
21 changes: 19 additions & 2 deletions src/anemoi/graphs/create.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,13 +10,25 @@


def generate_graph(graph_config: DotDict) -> HeteroData:
"""Generate a graph from a configuration.
Parameters
----------
graph_config : DotDict
Configuration for the nodes and edges (and its attributes).
Returns
-------
HeteroData
Graph.
"""
graph = HeteroData()

for name, nodes_cfg in graph_config.nodes.items():
graph = instantiate(nodes_cfg.node_type).transform(graph, name, nodes_cfg.get("attributes", {}))
graph = instantiate(nodes_cfg.node_builder).transform(graph, name, nodes_cfg.get("attributes", {}))

for edges_cfg in graph_config.edges:
graph = instantiate(edges_cfg.edge_type, **edges_cfg.nodes).transform(graph, edges_cfg.get("attributes", {}))
graph = instantiate(edges_cfg.edge_builder, **edges_cfg.nodes).transform(graph, edges_cfg.get("attributes", {}))

return graph

Expand Down Expand Up @@ -66,3 +78,8 @@ def _path_readable(self) -> bool:
return True
except FileNotFoundError:
return False


if __name__ == "__main__":
    # Script entry point: build a graph from a recipe file and save it.
    # The recipe location is taken from the command line instead of being
    # hard-coded to a path that only exists on one developer's machine.
    import argparse

    parser = argparse.ArgumentParser(description="Create a graph from a recipe file.")
    parser.add_argument("config", help="Path to the graph recipe (YAML) file.")
    parser.add_argument(
        "path",
        nargs="?",
        default="graph.pt",
        help="Where to write the generated graph (default: graph.pt).",
    )
    args = parser.parse_args()

    creator = GraphCreator(config=args.config, path=args.path)
    creator.create()
93 changes: 50 additions & 43 deletions src/anemoi/graphs/edges/attributes.py
Original file line number Diff line number Diff line change
@@ -1,74 +1,81 @@
import logging
from abc import ABC
from abc import abstractmethod
from dataclasses import dataclass
from typing import Optional

import torch
from anemoi.utils.config import DotDict
import numpy as np
import torch
from scipy.sparse import coo_matrix
from sklearn.preprocessing import normalize
from torch_geometric.data import HeteroData
from hydra.utils import instantiate

from anemoi.graphs.edges.directional import directional_edge_features
from anemoi.graphs.normalizer import NormalizerMixin
from anemoi.graphs.utils import haversine_distance

logger = logging.getLogger(__name__)

class NodeAttributeBuilder():

def transform(self, graph: HeteroData, graph_config: DotDict):

for name, nodes_cfg in graph_config.nodes.items():
graph = self.register_node_attributes(graph, name, nodes_cfg.get("attributes", {}))

def register_node_attributes(self, graph: HeteroData, node_name: str, node_config: DotDict):
assert node_name in graph.keys(), f"Node {node_name} does not exist in the graph."
for attr_name, attr_cfg in node_config.items():
graph[node_name][attr_name] = instantiate(attr_cfg).compute(graph, node_name)
return graph

class EdgeAttributeBuilder():

def transform(self, graph: HeteroData, graph_config: DotDict):
for edges_cfg in graph_config.edges:
graph = self.register_edge_attributes(graph, edges_cfg.nodes.src_name, edges_cfg.nodes.dst_name, edges_cfg.get("attributes", {}))
return graph

def register_edge_attributes(self, graph: HeteroData, src_name: str, dst_name: str, edge_config: DotDict):

for attr_name, attr_cfg in edge_config.items():
attr_values = instantiate(attr_cfg).compute(graph, src_name, dst_name)
graph = self.register_edge_attribute(graph, src_name, dst_name, attr_name, attr_values)
return graph

def register_edge_attribute(self, graph: HeteroData, src_name: str, dst_name: str, attr_name: str, attr_values: torch.Tensor):
num_edges = graph[(src_name, "to", dst_name)].num_edges
assert ( attr_values.shape[0] == num_edges), f"Number of edge features ({attr_values.shape[0]}) must match number of edges ({num_edges})."

graph[(src_name, "to", dst_name)][attr_name] = attr_values
return graph


@dataclass
class BaseEdgeAttribute(ABC, NormalizerMixin):
norm: Optional[str] = None

@abstractmethod
def compute(self, graph: HeteroData, *args, **kwargs): ...
def compute(self, graph: HeteroData, *args, **kwargs) -> np.ndarray: ...

def post_process(self, values: np.ndarray) -> torch.Tensor:
return torch.tensor(values)

def __call__(self, *args, **kwargs):
def __call__(self, *args, **kwargs) -> torch.Tensor:
values = self.compute(*args, **kwargs)
if values.ndim == 1:
values = values[:, np.newaxis]
return self.normalize(values)
normed_values = self.normalize(values)
if normed_values.ndim == 1:
normed_values = normed_values[:, np.newaxis]
return self.post_process(normed_values)


@dataclass
class DirectionalFeatures(BaseEdgeAttribute):
"""Compute directional features for edges."""

norm: Optional[str] = None
luse_rotated_features: bool = False

def compute(self, graph: HeteroData, src_name: str, dst_name: str):
def compute(self, graph: HeteroData, src_name: str, dst_name: str) -> torch.Tensor:
edge_index = graph[(src_name, "to", dst_name)].edge_index
src_coords = graph[src_name].x.numpy()[edge_index[0]].T
dst_coords = graph[dst_name].x.numpy()[edge_index[1]].T
edge_dirs = directional_edge_features(src_coords, dst_coords, self.luse_rotated_features).T
return edge_dirs


@dataclass
class HaversineDistance(BaseEdgeAttribute):
    """Edge attribute built from the haversine separation of each edge's end points.

    The per-edge values are stored in a sparse matrix, L1-normalised per row,
    and optionally inverted (``1 - value``) before being turned into a tensor.
    """

    # NOTE(review): ``norm`` is declared but ``normalize`` below always uses
    # "l1"; confirm whether other norms are meant to be supported.
    norm: str = "l1"
    # When True, ``post_process`` returns ``1 - value`` instead of ``value``.
    invert: bool = True

    def compute(self, graph: HeteroData, src_name: str, dst_name: str):
        """Return the edge lengths as a sparse matrix.

        The values come straight from ``haversine_distance`` and are not scaled
        by any radius here. Matrix rows are taken from ``edge_index[1]`` and
        columns from ``edge_index[0]``.
        """
        for node_name in (src_name, dst_name):
            assert node_name in graph.node_types, f"Node {node_name} not found in graph."

        edge_index = graph[(src_name, "to", dst_name)].edge_index
        src_idx = edge_index[0]
        dst_idx = edge_index[1]

        lengths = haversine_distance(
            graph[src_name].x.numpy()[src_idx],
            graph[dst_name].x.numpy()[dst_idx],
        )
        return coo_matrix((lengths, (dst_idx, src_idx)))

    def normalize(self, values) -> np.ndarray:
        """Scale the edge lengths so that every matrix row has unit L1 norm.

        Rows correspond to ``edge_index[1]`` (see ``compute``), so the norm is
        taken over all edges that share the same row index.
        """
        # NOTE(review): only the stored values are returned; confirm that their
        # order after normalisation still matches the order of the graph's edges.
        row_normalised = normalize(values, norm="l1", axis=1)
        return row_normalised.data

    def post_process(self, values: np.ndarray) -> torch.Tensor:
        """Optionally invert the normalised values, then convert them to a tensor."""
        processed = 1 - values if self.invert else values
        return super().post_process(processed)
15 changes: 2 additions & 13 deletions src/anemoi/graphs/edges/connections.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,10 @@
from anemoi.utils.config import DotDict
from hydra.utils import instantiate
from sklearn.neighbors import NearestNeighbors
from sklearn.preprocessing import normalize
from torch_geometric.data import HeteroData
from torch_geometric.data.storage import NodeStorage

from anemoi.graphs import earth_radius
from anemoi.graphs import EARTH_RADIUS
from anemoi.graphs.utils import get_grid_reference_distance

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -54,26 +53,16 @@ def transform(self, graph: HeteroData, attrs_config: Optional[DotDict] = None) -
# Compute adjacency matrix.
adjmat = self.get_adj_matrix(src_nodes, dst_nodes)

# Normalize adjacency matrix.
adjmat_norm = self.normalize_adjmat(adjmat)

# Add edges to the graph and register normed distance.
graph = self.register_edges(graph, adjmat.col, adjmat.row)

self.register_edge_attribute(graph, "normed_dist", adjmat_norm.data)
if attrs_config is not None:
for attr_name, attr_cfg in attrs_config.items():
attr_values = instantiate(attr_cfg)(graph, self.src_name, self.dst_name)
graph = self.register_edge_attribute(graph, attr_name, attr_values)

return graph

def normalize_adjmat(self, adjmat):
"""Normalize a sparse adjacency matrix."""
adjmat_norm = normalize(adjmat, norm="l1", axis=1)
adjmat_norm.data = 1.0 - adjmat_norm.data
return adjmat_norm


class KNNEdgeBuilder(BaseEdgeBuilder):
"""Computes KNN based edges and adds them to the graph."""
Expand Down Expand Up @@ -124,7 +113,7 @@ def prepare_node_data(self, graph: HeteroData):
return super().prepare_node_data(graph)

def get_adj_matrix(self, src_nodes: NodeStorage, dst_nodes: NodeStorage):
logger.debug("Using cut-off radius of %.1f km.", self.radius * earth_radius)
logger.debug("Using cut-off radius of %.1f km.", self.radius * EARTH_RADIUS)

nearest_neighbour = NearestNeighbors(metric="haversine", n_jobs=4)
nearest_neighbour.fit(src_nodes.x)
Expand Down
2 changes: 2 additions & 0 deletions src/anemoi/graphs/normalizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@


class NormalizerMixin:
"""Mixin class for normalizing attributes."""

def normalize(self, values: np.ndarray) -> np.ndarray:
if self.norm is None:
logger.debug("Node weights are not normalized.")
Expand Down
22 changes: 22 additions & 0 deletions src/anemoi/graphs/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,3 +109,25 @@ def get_index_in_outer_join(vector: torch.Tensor, tensor: torch.Tensor) -> int:
if mask.any():
return int(torch.where(mask)[0])
return -1


def haversine_distance(src_coords: np.ndarray, dst_coords: np.ndarray) -> np.ndarray:
    """Haversine distance.

    Column 0 of each input is used as the latitude and column 1 as the
    longitude. The result is not multiplied by any radius, i.e. it is the
    central angle between the two points.

    Parameters
    ----------
    src_coords : np.ndarray of shape (N, 2)
        Source coordinates in radians.
    dst_coords : np.ndarray of shape (N, 2)
        Destination coordinates in radians.

    Returns
    -------
    np.ndarray of shape (N,)
        Haversine distance between source and destination coordinates.
    """
    dlat = dst_coords[:, 0] - src_coords[:, 0]
    dlon = dst_coords[:, 1] - src_coords[:, 1]
    a = np.sin(dlat / 2) ** 2 + np.cos(src_coords[:, 0]) * np.cos(dst_coords[:, 0]) * np.sin(dlon / 2) ** 2
    # Floating-point rounding can push ``a`` marginally outside [0, 1] (e.g.
    # for antipodal points), which would make ``sqrt(1 - a)`` return NaN.
    a = np.clip(a, 0.0, 1.0)
    c = 2 * np.arctan2(np.sqrt(a), np.sqrt(1 - a))
    return c
33 changes: 11 additions & 22 deletions tests/nodes/test_weights.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,16 @@
import numpy as np
import pytest
import torch
from hydra.utils import instantiate
from torch_geometric.data import HeteroData

from anemoi.graphs.nodes.weights import AreaWeights
from anemoi.graphs.nodes.weights import UniformWeights


@pytest.mark.parametrize("norm", [None, "l1", "l2", "unit-max", "unit-sum", "unit-std"])
def test_uniform_weights(graph_with_nodes: HeteroData, norm: str):
"""Test NPZNodes register correctly the weights."""
config = {"_target_": "anemoi.graphs.nodes.weights.UniformWeights", "norm": norm}

weights = instantiate(config).get_weights(graph_with_nodes["test_nodes"])
node_attr_builder = UniformWeights(norm=norm)
weights = node_attr_builder.get_weights(graph_with_nodes["test_nodes"])

assert weights is not None
assert isinstance(weights, torch.Tensor)
Expand All @@ -20,21 +20,15 @@ def test_uniform_weights(graph_with_nodes: HeteroData, norm: str):
@pytest.mark.parametrize("norm", ["l3", "invalide"])
def test_uniform_weights_fail(graph_with_nodes: HeteroData, norm: str):
"""Test NPZNodes register correctly the weights."""
config = {"_target_": "anemoi.graphs.nodes.weights.UniformWeights", "norm": norm}

with pytest.raises(ValueError):
instantiate(config).get_weights(graph_with_nodes["test_nodes"])
node_attr_builder = UniformWeights(norm=norm)
node_attr_builder.get_weights(graph_with_nodes["test_nodes"])


def test_area_weights(graph_with_nodes: HeteroData):
"""Test NPZNodes register correctly the weights."""
config = {
"_target_": "anemoi.graphs.nodes.weights.AreaWeights",
"radius": 1.0,
"centre": np.array([0, 0, 0]),
}

weights = instantiate(config).get_weights(graph_with_nodes["test_nodes"])
node_attr_builder = AreaWeights()
weights = node_attr_builder.get_weights(graph_with_nodes["test_nodes"])

assert weights is not None
assert isinstance(weights, torch.Tensor)
Expand All @@ -43,11 +37,6 @@ def test_area_weights(graph_with_nodes: HeteroData):

@pytest.mark.parametrize("radius", [-1.0, "hello", None])
def test_area_weights_fail(graph_with_nodes: HeteroData, radius: float):
config = {
"_target_": "anemoi.graphs.nodes.weights.AreaWeights",
"radius": radius,
"centre": np.array([0, 0, 0]),
}

with pytest.raises(ValueError):
instantiate(config).get_weights(graph_with_nodes["test_nodes"])
node_attr_builder = AreaWeights(radius=radius)
node_attr_builder.get_weights(graph_with_nodes["test_nodes"])
52 changes: 52 additions & 0 deletions tests/test_normalizer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
import numpy as np
import pytest

from anemoi.graphs.normalizer import NormalizerMixin


@pytest.mark.parametrize("norm", ["l1", "l2", "unit-max", "unit-sum", "unit-std"])
def test_normalizer(norm: str):
    """Check that each listed norm returns an ndarray with the input's shape."""

    class Normalizer(NormalizerMixin):
        # Minimal host class supplying the ``norm`` attribute for the mixin.
        def __init__(self, norm):
            self.norm = norm

        def __call__(self, data):
            return self.normalize(data)

    sample = np.random.rand(10, 5)
    result = Normalizer(norm=norm)(sample)

    assert isinstance(result, np.ndarray)
    assert result.shape == sample.shape


@pytest.mark.parametrize("norm", ["l3", "invalid"])
def test_normalizer_wrong_norm(norm: str):
    """Check that normalizing with an unsupported norm name raises ValueError."""

    class Normalizer(NormalizerMixin):
        # Minimal host class supplying the ``norm`` attribute for the mixin.
        def __init__(self, norm: str):
            self.norm = norm

        def __call__(self, data):
            return self.normalize(data)

    # Setup stays outside ``pytest.raises`` so that only the normalize call
    # itself can satisfy the expected exception.
    normalizer = Normalizer(norm=norm)
    data = np.random.rand(10, 5)

    with pytest.raises(ValueError):
        normalizer(data)


def test_normalizer_wrong_inheritance():
    """Check that a host class without a ``norm`` attribute raises AttributeError."""

    class Normalizer(NormalizerMixin):
        # Deliberately stores ``attr`` instead of ``norm``.
        def __init__(self, attr):
            self.attr = attr

        def __call__(self, data):
            return self.normalize(data)

    # Setup stays outside ``pytest.raises`` so that only the normalize call
    # itself can satisfy the expected exception.
    normalizer = Normalizer(attr="attr_name")
    data = np.random.rand(10, 5)

    with pytest.raises(AttributeError):
        normalizer(data)

0 comments on commit 9a47184

Please sign in to comment.