Skip to content
This repository has been archived by the owner on Dec 20, 2024. It is now read-only.

Commit

Permalink
feat: initial version of AttributeBuilder
Browse files Browse the repository at this point in the history
  • Loading branch information
theissenhelen committed Jun 26, 2024
1 parent b12272d commit cce5ea6
Showing 1 changed file with 33 additions and 0 deletions.
33 changes: 33 additions & 0 deletions src/anemoi/graphs/edges/attributes.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,47 @@
from abc import abstractmethod
from typing import Optional

import torch
from anemoi.utils.config import DotDict
import numpy as np
from torch_geometric.data import HeteroData
from hydra.utils import instantiate

from anemoi.graphs.edges.directional import directional_edge_features
from anemoi.graphs.normalizer import NormalizerMixin

logger = logging.getLogger(__name__)

class AttributeBuilder():

def transform(self, graph: HeteroData, graph_config: DotDict):

for name, nodes_cfg in graph_config.nodes.items():
graph = self.register_node_attributes(graph, name, nodes_cfg.get("attributes", {}))
for edges_cfg in graph_config.edges:
graph = self.register_edge_attributes(graph, edges_cfg.nodes.src_name, edges_cfg.nodes.dst_name, edges_cfg.get("attributes", {}))
return graph

def register_node_attributes(self, graph: HeteroData, node_name: str, node_config: DotDict):
assert node_name in graph.keys(), f"Node {node_name} does not exist in the graph."
for attr_name, attr_cfg in node_config.items():
graph[node_name][attr_name] = instantiate(attr_cfg).compute(graph, node_name)
return graph

def register_edge_attributes(self, graph: HeteroData, src_name: str, dst_name: str, edge_config: DotDict):

for attr_name, attr_cfg in edge_config.items():
attr_values = instantiate(attr_cfg).compute(graph, src_name, dst_name)
graph = self.register_edge_attribute(graph, src_name, dst_name, attr_name, attr_values)
return graph

def register_edge_attribute(self, graph: HeteroData, src_name: str, dst_name: str, attr_name: str, attr_values: torch.Tensor):
num_edges = graph[(src_name, "to", dst_name)].num_edges
assert ( attr_values.shape[0] == num_edges), f"Number of edge features ({attr_values.shape[0]}) must match number of edges ({num_edges})."

graph[(src_name, "to", dst_name)][attr_name] = attr_values
return graph


class BaseEdgeAttribute(ABC, NormalizerMixin):
norm: Optional[str] = None
Expand Down

0 comments on commit cce5ea6

Please sign in to comment.