This repository has been archived by the owner on Dec 20, 2024. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 10
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
feat: edge_length moved to edges/attributes.py
- Loading branch information
Showing
8 changed files
with
159 additions
and
81 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -7,4 +7,4 @@ | |
|
||
from ._version import __version__ | ||
|
||
earth_radius = 6371.0 # km | ||
EARTH_RADIUS = 6371.0 # km |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,74 +1,81 @@ | ||
import logging | ||
from abc import ABC | ||
from abc import abstractmethod | ||
from dataclasses import dataclass | ||
from typing import Optional | ||
|
||
import torch | ||
from anemoi.utils.config import DotDict | ||
import numpy as np | ||
import torch | ||
from scipy.sparse import coo_matrix | ||
from sklearn.preprocessing import normalize | ||
from torch_geometric.data import HeteroData | ||
from hydra.utils import instantiate | ||
|
||
from anemoi.graphs.edges.directional import directional_edge_features | ||
from anemoi.graphs.normalizer import NormalizerMixin | ||
from anemoi.graphs.utils import haversine_distance | ||
|
||
logger = logging.getLogger(__name__) | ||
|
||
class NodeAttributeBuilder(): | ||
|
||
def transform(self, graph: HeteroData, graph_config: DotDict): | ||
|
||
for name, nodes_cfg in graph_config.nodes.items(): | ||
graph = self.register_node_attributes(graph, name, nodes_cfg.get("attributes", {})) | ||
|
||
def register_node_attributes(self, graph: HeteroData, node_name: str, node_config: DotDict): | ||
assert node_name in graph.keys(), f"Node {node_name} does not exist in the graph." | ||
for attr_name, attr_cfg in node_config.items(): | ||
graph[node_name][attr_name] = instantiate(attr_cfg).compute(graph, node_name) | ||
return graph | ||
|
||
class EdgeAttributeBuilder(): | ||
|
||
def transform(self, graph: HeteroData, graph_config: DotDict): | ||
for edges_cfg in graph_config.edges: | ||
graph = self.register_edge_attributes(graph, edges_cfg.nodes.src_name, edges_cfg.nodes.dst_name, edges_cfg.get("attributes", {})) | ||
return graph | ||
|
||
def register_edge_attributes(self, graph: HeteroData, src_name: str, dst_name: str, edge_config: DotDict): | ||
|
||
for attr_name, attr_cfg in edge_config.items(): | ||
attr_values = instantiate(attr_cfg).compute(graph, src_name, dst_name) | ||
graph = self.register_edge_attribute(graph, src_name, dst_name, attr_name, attr_values) | ||
return graph | ||
|
||
def register_edge_attribute(self, graph: HeteroData, src_name: str, dst_name: str, attr_name: str, attr_values: torch.Tensor): | ||
num_edges = graph[(src_name, "to", dst_name)].num_edges | ||
assert ( attr_values.shape[0] == num_edges), f"Number of edge features ({attr_values.shape[0]}) must match number of edges ({num_edges})." | ||
|
||
graph[(src_name, "to", dst_name)][attr_name] = attr_values | ||
return graph | ||
|
||
|
||
@dataclass | ||
class BaseEdgeAttribute(ABC, NormalizerMixin): | ||
norm: Optional[str] = None | ||
|
||
@abstractmethod | ||
def compute(self, graph: HeteroData, *args, **kwargs): ... | ||
def compute(self, graph: HeteroData, *args, **kwargs) -> np.ndarray: ... | ||
|
||
def post_process(self, values: np.ndarray) -> torch.Tensor: | ||
return torch.tensor(values) | ||
|
||
def __call__(self, *args, **kwargs): | ||
def __call__(self, *args, **kwargs) -> torch.Tensor: | ||
values = self.compute(*args, **kwargs) | ||
if values.ndim == 1: | ||
values = values[:, np.newaxis] | ||
return self.normalize(values) | ||
normed_values = self.normalize(values) | ||
if normed_values.ndim == 1: | ||
normed_values = normed_values[:, np.newaxis] | ||
return self.post_process(normed_values) | ||
|
||
|
||
@dataclass | ||
class DirectionalFeatures(BaseEdgeAttribute): | ||
"""Compute directional features for edges.""" | ||
|
||
norm: Optional[str] = None | ||
luse_rotated_features: bool = False | ||
|
||
def compute(self, graph: HeteroData, src_name: str, dst_name: str): | ||
def compute(self, graph: HeteroData, src_name: str, dst_name: str) -> torch.Tensor: | ||
edge_index = graph[(src_name, "to", dst_name)].edge_index | ||
src_coords = graph[src_name].x.numpy()[edge_index[0]].T | ||
dst_coords = graph[dst_name].x.numpy()[edge_index[1]].T | ||
edge_dirs = directional_edge_features(src_coords, dst_coords, self.luse_rotated_features).T | ||
return edge_dirs | ||
|
||
|
||
@dataclass | ||
class HaversineDistance(BaseEdgeAttribute): | ||
"""Edge length feature.""" | ||
|
||
norm: str = "l1" | ||
invert: bool = True | ||
|
||
def compute(self, graph: HeteroData, src_name: str, dst_name: str): | ||
"""Compute haversine distance (in kilometers) between nodes connected by edges.""" | ||
assert src_name in graph.node_types, f"Node {src_name} not found in graph." | ||
assert dst_name in graph.node_types, f"Node {dst_name} not found in graph." | ||
edge_index = graph[(src_name, "to", dst_name)].edge_index | ||
src_coords = graph[src_name].x.numpy()[edge_index[0]] | ||
dst_coords = graph[dst_name].x.numpy()[edge_index[1]] | ||
edge_lengths = haversine_distance(src_coords, dst_coords) | ||
return coo_matrix((edge_lengths, (edge_index[1], edge_index[0]))) | ||
|
||
def normalize(self, values) -> np.ndarray: | ||
"""Normalize the edge length. | ||
This method scales the edge lengths to a unit norm, computing the norms | ||
for each source node (axis=1). | ||
""" | ||
return normalize(values, norm="l1", axis=1).data | ||
|
||
def post_process(self, values: np.ndarray) -> torch.Tensor: | ||
if self.invert: | ||
values = 1 - values | ||
return super().post_process(values) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,52 @@ | ||
import numpy as np | ||
import pytest | ||
|
||
from anemoi.graphs.normalizer import NormalizerMixin | ||
|
||
|
||
@pytest.mark.parametrize("norm", ["l1", "l2", "unit-max", "unit-sum", "unit-std"]) | ||
def test_normalizer(norm: str): | ||
"""Test NormalizerMixin normalize method.""" | ||
class Normalizer(NormalizerMixin): | ||
def __init__(self, norm): | ||
self.norm = norm | ||
|
||
def __call__(self, data): | ||
return self.normalize(data) | ||
|
||
normalizer = Normalizer(norm=norm) | ||
data = np.random.rand(10, 5) | ||
normalized_data = normalizer(data) | ||
assert isinstance(normalized_data, np.ndarray) | ||
assert normalized_data.shape == data.shape | ||
|
||
|
||
@pytest.mark.parametrize("norm", ["l3", "invalid"]) | ||
def test_normalizer_wrong_norm(norm: str): | ||
"""Test NormalizerMixin normalize method.""" | ||
class Normalizer(NormalizerMixin): | ||
def __init__(self, norm: str): | ||
self.norm = norm | ||
|
||
def __call__(self, data): | ||
return self.normalize(data) | ||
|
||
with pytest.raises(ValueError): | ||
normalizer = Normalizer(norm=norm) | ||
data = np.random.rand(10, 5) | ||
normalizer(data) | ||
|
||
|
||
def test_normalizer_wrong_inheritance(): | ||
"""Test NormalizerMixin normalize method.""" | ||
class Normalizer(NormalizerMixin): | ||
def __init__(self, attr): | ||
self.attr = attr | ||
|
||
def __call__(self, data): | ||
return self.normalize(data) | ||
|
||
with pytest.raises(AttributeError): | ||
normalizer = Normalizer(attr="attr_name") | ||
data = np.random.rand(10, 5) | ||
normalizer(data) |