Skip to content
This repository has been archived by the owner on Dec 20, 2024. It is now read-only.

Commit

Permalink
add cli command
Browse files Browse the repository at this point in the history
  • Loading branch information
JPXKQX committed Jun 25, 2024
1 parent 9dc2cec commit f1fe18f
Show file tree
Hide file tree
Showing 15 changed files with 145 additions and 104 deletions.
2 changes: 1 addition & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -187,4 +187,4 @@ _build/
_version.py
*.code-workspace

/config*
/config*
8 changes: 8 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,14 @@ Install via `pip` with:
$ pip install anemoi-graphs
```

## Usage

Create you graph

```
$ anemoi-graphs create recipe.yaml my_graph.pt
```

## License

```
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -51,9 +51,9 @@ dynamic = [
]
dependencies = [
"anemoi-datasets[data]>=0.3.3",
"anemoi-utils>=0.3.6",
"torch>=2.2",
"torch-geometric>=2.3.1,<2.5",
"anemoi-utils>=0.1.3",
]

optional-dependencies.all = [
Expand Down
10 changes: 9 additions & 1 deletion src/anemoi/graphs/__init__.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,10 @@
earth_radius = 6371.0 # km
# (C) Copyright 2023 European Centre for Medium-Range Weather Forecasts.
# This software is licensed under the terms of the Apache Licence Version 2.0
# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
# In applying this licence, ECMWF does not waive the privileges and immunities
# granted to it by virtue of its status as an intergovernmental organisation
# nor does it submit to any jurisdiction.

from ._version import __version__

earth_radius = 6371.0 # km
28 changes: 28 additions & 0 deletions src/anemoi/graphs/commands/create.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
from anemoi.graphs.create import GraphCreator

from . import Command


class Create(Command):
"""Create a graph."""

internal = True
timestamp = True

def add_arguments(self, command_parser):
command_parser.add_argument(
"--overwrite",
action="store_true",
help="Overwrite existing files. This will delete the target graph if it already exists.",
)
command_parser.add_argument("config", help="Configuration yaml file defining the recipe to create the graph.")
command_parser.add_argument("path", help="Path to store the created graph.")

def run(self, args):
kwargs = vars(args)

c = GraphCreator(**kwargs)
c.create()


command = Create
32 changes: 0 additions & 32 deletions src/anemoi/graphs/commands/hello.py

This file was deleted.

68 changes: 68 additions & 0 deletions src/anemoi/graphs/create.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
import os
import torch
from anemoi.utils.config import DotDict
from hydra.utils import instantiate
from torch_geometric.data import HeteroData

import logging

logger = logging.getLogger(__name__)


def generate_graph(graph_config: DotDict) -> HeteroData:
graph = HeteroData()

for name, nodes_cfg in graph_config.nodes.items():
graph = instantiate(nodes_cfg.node_type).transform(graph, name, nodes_cfg.get("attributes", {}))

for edges_cfg in graph_config.edges:
graph = instantiate(edges_cfg.edge_type, **edges_cfg.nodes).transform(graph, edges_cfg.get("attributes", {}))

return graph


class GraphCreator:
def __init__(
self,
path,
config=None,
cache=None,
print=print,
overwrite=False,
**kwargs,
):
self.path = path # Output path
self.config = config
self.cache = cache
self.print = print
self.overwrite = overwrite

def init(self):
assert os.path.exists(self.config), f"Path {self.config} does not exist."

if self._path_readable() and not self.overwrite:
raise Exception(f"{self.path} already exists. Use overwrite=True to overwrite.")

def load(self) -> HeteroData:
config = DotDict.from_file(self.config)
graph = generate_graph(config)
return graph

def save(self, graph: HeteroData) -> None:
if not os.path.exists(self.path) or self.overwrite:
torch.save(graph, self.path)
self.print(f"Graph saved at {self.path}.")

def create(self):
self.init()
graph = self.load()
self.save(graph)

def _path_readable(self) -> bool:
import torch

try:
torch.load(self.path, "r")
return True
except FileNotFoundError:
return False
4 changes: 2 additions & 2 deletions src/anemoi/graphs/edges/attributes.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,14 @@
from abc import abstractmethod
from typing import Optional

import logging
import numpy as np
from torch_geometric.data import HeteroData

from anemoi.graphs.edges.directional import directional_edge_features
from anemoi.graphs.normalizer import NormalizerMixin
from anemoi.utils.logger import get_code_logger

logger = get_code_logger(__name__)
logger = logging.getLogger(__name__)


class BaseEdgeAttribute(ABC, NormalizerMixin):
Expand Down
17 changes: 9 additions & 8 deletions src/anemoi/graphs/edges/connections.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,7 @@
import logging
from abc import abstractmethod
from dataclasses import dataclass
from typing import Optional

import networkx as nx
import numpy as np
import torch
from anemoi.utils.config import DotDict
Expand All @@ -15,8 +14,6 @@
from anemoi.graphs import earth_radius
from anemoi.graphs.utils import get_grid_reference_distance

import logging

logger = logging.getLogger(__name__)


Expand All @@ -32,15 +29,19 @@ def __init__(self, src_name: str, dst_name: str):
def get_adj_matrix(self, src_nodes: NodeStorage, dst_nodes: NodeStorage): ...

def register_edges(self, graph, head_indices, tail_indices):
graph[(self.src_name, "to", self.dst_name)].edge_index = np.stack([head_indices, tail_indices], axis=0).astype(np.int32)
graph[(self.src_name, "to", self.dst_name)].edge_index = np.stack([head_indices, tail_indices], axis=0).astype(
np.int32
)
return graph

def register_edge_attribute(self, graph: HeteroData, name: str, values: np.ndarray):
num_edges = graph[(self.src_name, "to", self.dst_name)].num_edges
assert (
values.shape[0] == num_edges
), f"Number of edge features ({values.shape[0]}) must match number of edges ({num_edges})."
graph[self.src_name, "to", self.dst_name][name] = values.reshape(num_edges, -1) # TODO: Check the [name] part works
graph[self.src_name, "to", self.dst_name][name] = values.reshape(
num_edges, -1
) # TODO: Check the [name] part works
return graph

def prepare_node_data(self, graph: HeteroData):
Expand Down Expand Up @@ -111,7 +112,8 @@ def __init__(self, src_name: str, dst_name: str, cutoff_factor: float):
assert cutoff_factor > 0, "Cutoff factor must be positive"
self.cutoff_factor = cutoff_factor

def get_cutoff_radius(self, dst_nodes: NodeStorage, mask_attr: Optional[torch.Tensor] = None):
def get_cutoff_radius(self, graph: HeteroData, mask_attr: Optional[torch.Tensor] = None):
dst_nodes = graph[self.dst_name]
mask = dst_nodes[mask_attr] if mask_attr is not None else None
dst_grid_reference_distance = get_grid_reference_distance(dst_nodes.x, mask)
radius = dst_grid_reference_distance * self.cutoff_factor
Expand All @@ -128,4 +130,3 @@ def get_adj_matrix(self, src_nodes: NodeStorage, dst_nodes: NodeStorage):
nearest_neighbour.fit(src_nodes.x)
adj_matrix = nearest_neighbour.radius_neighbors_graph(dst_nodes.x, radius=self.radius).tocoo()
return adj_matrix

4 changes: 3 additions & 1 deletion src/anemoi/graphs/edges/directional.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,9 @@ def compute_directions(loc1: np.ndarray, loc2: np.ndarray, pole_vec: Optional[np
return direction / np.sqrt(np.power(direction, 2).sum(axis=0))


def directional_edge_features(loc1: np.ndarray, loc2: np.ndarray, relative_to_rotated_target: bool = True) -> np.ndarray:
def directional_edge_features(
loc1: np.ndarray, loc2: np.ndarray, relative_to_rotated_target: bool = True
) -> np.ndarray:
"""Compute features of the edge joining the nodes considered.
It computes the direction of the edge after rotating the north pole.
Expand Down
36 changes: 0 additions & 36 deletions src/anemoi/graphs/generate.py

This file was deleted.

12 changes: 2 additions & 10 deletions src/anemoi/graphs/nodes/nodes.py
Original file line number Diff line number Diff line change
@@ -1,23 +1,15 @@
import logging
from abc import ABC
from abc import abstractmethod
from pathlib import Path
from typing import Optional
from typing import Union

import h3
import numpy as np
import torch
from abc import ABC
from anemoi.datasets import open_dataset
from anemoi.utils.config import DotDict
from hydra.utils import instantiate
from sklearn.neighbors import NearestNeighbors
from torch_geometric.data import HeteroData

from aifs.graphs import GraphBuilder
from aifs.graphs.generate.hexagonal import create_hexagonal_nodes
from aifs.graphs.generate.icosahedral import create_icosahedral_nodes
import logging

logger = logging.getLogger(__name__)
earth_radius = 6371.0 # km

Expand Down
18 changes: 8 additions & 10 deletions src/anemoi/graphs/nodes/weights.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,19 @@
import logging
from abc import ABC
from abc import abstractmethod
from typing import Optional

import numpy as np
import torch
from scipy.spatial import SphericalVoronoi
from torch_geometric.data.storage import NodeStorage

from anemoi.graphs.generate.transforms import to_sphere_xyz
from scipy.spatial import SphericalVoronoi
from anemoi.graphs.normalizer import NormalizerMixin
import logging

logger = logging.getLogger(__name__)


class BaseWeights(ABC, NormalizerMixin):
"""Base class for the weights of the nodes."""

Expand All @@ -32,24 +33,21 @@ def get_weights(self, *args, **kwargs):
class UniformWeights(BaseWeights):
"""Implements a uniform weight for the nodes."""

def __init__(self, norm: str = "unit-max"):
self.norm = norm

def compute(self, nodes: NodeStorage) -> np.ndarray:
return torch.ones(nodes.num_nodes)


class AreaWeights(BaseWeights):
"""Implements the area of the nodes as the weights."""

def __init__(self, norm: str = "unit-max", radius: float = 1.0, centre: np.ndarray = np.array[0, 0, 0]):
def __init__(self, norm: str = "unit-max", radius: float = 1.0, centre: np.ndarray = np.array([0, 0, 0])):
super().__init__(norm=norm)

# Weighting of the nodes
self.norm: str = norm
self.radius: float = radius
self.centre: np.ndarray = centre
self.radius = radius
self.centre = centre

def compute(self, nodes: NodeStorage, *args, **kwargs) -> np.ndarray:
# TODO: Check if works
latitudes, longitudes = nodes.x[:, 0], nodes.x[:, 1]
points = to_sphere_xyz((latitudes, longitudes))
sv = SphericalVoronoi(points, self.radius, self.centre)
Expand Down
3 changes: 2 additions & 1 deletion src/anemoi/graphs/normalizer.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import numpy as np
import logging

import numpy as np

logger = logging.getLogger(__name__)


Expand Down
5 changes: 4 additions & 1 deletion src/anemoi/graphs/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,10 @@ def get_nearest_neighbour(coords_rad: torch.Tensor, mask: Optional[torch.Tensor]
NearestNeighbors
fitted NearestNeighbour object
"""
assert mask is None or mask.shape == (coords_rad.shape[0], 1), "Mask must have the same shape as the number of nodes."
assert mask is None or mask.shape == (
coords_rad.shape[0],
1,
), "Mask must have the same shape as the number of nodes."

nearest_neighbour = NearestNeighbors(metric="haversine", n_jobs=4)

Expand Down

0 comments on commit f1fe18f

Please sign in to comment.