From 2084cfd5e44003f0ed574723ae90f636a7a5e9b5 Mon Sep 17 00:00:00 2001
From: pb1623
Date: Fri, 12 Jul 2024 12:16:11 +0100
Subject: [PATCH 1/2] Add SpinLifting to Graph2Pointcloud category
---
.../graph2pointcloud/spin_lifting.yaml | 3 +
modules/transforms/data_transform.py | 3 +
.../liftings/graph2pointcloud/base.py | 25 +
.../liftings/graph2pointcloud/spin_lifting.py | 137 ++++++
.../graph2pointcloud/test_spin_lifting.py | 112 +++++
tutorials/graph2pointcloud/spin_lifting.ipynb | 438 ++++++++++++++++++
6 files changed, 718 insertions(+)
create mode 100644 configs/transforms/liftings/graph2pointcloud/spin_lifting.yaml
create mode 100644 modules/transforms/liftings/graph2pointcloud/spin_lifting.py
create mode 100644 test/transforms/liftings/graph2pointcloud/test_spin_lifting.py
create mode 100644 tutorials/graph2pointcloud/spin_lifting.ipynb
diff --git a/configs/transforms/liftings/graph2pointcloud/spin_lifting.yaml b/configs/transforms/liftings/graph2pointcloud/spin_lifting.yaml
new file mode 100644
index 00000000..ff7b5073
--- /dev/null
+++ b/configs/transforms/liftings/graph2pointcloud/spin_lifting.yaml
@@ -0,0 +1,3 @@
+transform_type: 'lifting'
+transform_name: "SpinLifting"
+start_node: 0
diff --git a/modules/transforms/data_transform.py b/modules/transforms/data_transform.py
index 59253ecf..14890315 100755
--- a/modules/transforms/data_transform.py
+++ b/modules/transforms/data_transform.py
@@ -12,6 +12,7 @@
from modules.transforms.liftings.graph2hypergraph.knn_lifting import (
HypergraphKNNLifting,
)
+from modules.transforms.liftings.graph2pointcloud.spin_lifting import SpinLifting
from modules.transforms.liftings.graph2simplicial.clique_lifting import (
SimplicialCliqueLifting,
)
@@ -23,6 +24,8 @@
"SimplicialCliqueLifting": SimplicialCliqueLifting,
# Graph -> Cell Complex
"CellCycleLifting": CellCycleLifting,
+ # Graph -> Point Cloud
+ "SpinLifting": SpinLifting,
# Feature Liftings
"ProjectionSum": ProjectionSum,
# Data Manipulations
diff --git a/modules/transforms/liftings/graph2pointcloud/base.py b/modules/transforms/liftings/graph2pointcloud/base.py
index dbe5e2cb..38dd9999 100755
--- a/modules/transforms/liftings/graph2pointcloud/base.py
+++ b/modules/transforms/liftings/graph2pointcloud/base.py
@@ -1,3 +1,6 @@
+import torch
+import torch_geometric
+
from modules.transforms.liftings.lifting import GraphLifting
@@ -13,3 +16,25 @@ class Graph2PointcloudLifting(GraphLifting):
def __init__(self, **kwargs):
super().__init__(**kwargs)
self.type = "graph2pointcloud"
+ self.start_node = kwargs.get("start_node", None)
+
+ @staticmethod
+ def _get_lifted_topology(coords: dict) -> torch_geometric.data.Data:
+ r"""Returns the lifted topology.
+
+ Parameters
+ ----------
+ coords : dict
+ The coordinates of the nodes.
+
+ Returns
+ -------
+ torch_geometric.data.Data
+ The lifted topology.
+ """
+
+ # Sort the items by key to ensure correspondences between features/labels and points
+ items = sorted(coords.items(), key=lambda x: x[0])
+ # Convert the coordinates to tensors in order to create a torch_geometric.data.Data object
+ tensor_coords = {key: torch.tensor(value) for key, value in items}
+ return torch_geometric.data.Data(pos=torch.stack(list(tensor_coords.values())))
diff --git a/modules/transforms/liftings/graph2pointcloud/spin_lifting.py b/modules/transforms/liftings/graph2pointcloud/spin_lifting.py
new file mode 100644
index 00000000..99a8064a
--- /dev/null
+++ b/modules/transforms/liftings/graph2pointcloud/spin_lifting.py
@@ -0,0 +1,137 @@
+import math
+
+import torch_geometric
+
+from modules.transforms.liftings.graph2pointcloud.base import Graph2PointcloudLifting
+
+
+class SpinLifting(Graph2PointcloudLifting):
+ r"""Lifts graphs to point clouds domain by placing the nodes in a rotational manner
+
+ Parameters
+ ----------
+ **kwargs : optional
+ Additional arguments for the class.
+ """
+
+ def __init__(self, **kwargs):
+ super().__init__(**kwargs)
+
+ @staticmethod
+ def find_neighbors(graph, node):
+ return list(graph.neighbors(node))
+
+ @staticmethod
+ def calculate_coords_delta(angle):
+ radians = math.radians(angle)
+ x_delta = math.cos(radians)
+ y_delta = math.sin(radians)
+ return x_delta, y_delta
+
+ def assign_coordinates(self, center_coords, neighbors):
+ coords_dict = {}
+ angle_to_rotate = 30
+ current_angle = 0
+ for neighbor in neighbors:
+ if current_angle >= 360:
+ angle_to_rotate /= 2
+ current_angle = angle_to_rotate
+ delta = self.calculate_coords_delta(current_angle)
+ new_coords = (center_coords[0] + delta[0], center_coords[1] + delta[1])
+ while new_coords in coords_dict.values():
+ current_angle += angle_to_rotate
+ if current_angle >= 360:
+ angle_to_rotate /= 2
+ current_angle = angle_to_rotate
+ delta = self.calculate_coords_delta(current_angle)
+ new_coords = (center_coords[0] + delta[0], center_coords[1] + delta[1])
+ coords_dict[neighbor] = new_coords
+ current_angle += angle_to_rotate
+
+ return coords_dict
+
+ def lift(self, coords, graph, start_node):
+ old_coords = coords.copy()
+ neighbors = self.find_neighbors(graph, start_node)
+ coords.update(self.assign_coordinates(coords[start_node], neighbors))
+ # Do a breadth-first traversal of the remaining nodes
+ queue = neighbors
+ visited = set(neighbors)
+ while queue:
+ current_center = queue.pop(0)
+ neighbors = self.find_neighbors(graph, current_center)
+ # Remove neighbors that have coordinates already assigned
+ neighbors = [neighbor for neighbor in neighbors if neighbor not in coords]
+ coords.update(self.assign_coordinates(coords[current_center], neighbors))
+ for neighbor in neighbors:
+ if neighbor not in visited:
+ queue.append(neighbor)
+ visited.add(neighbor)
+
+ # Get the new coordinates generated
+ new_coords = {start_node: coords[start_node]}
+ new_coords.update(
+ {key: value for key, value in coords.items() if key not in old_coords}
+ )
+ # Find the max distance between the nodes in new_coords,
+ # which will be used as the separation distance between the disconnected parts of the graph
+ max_distance = 0
+ for node1 in new_coords:
+ for node2 in new_coords:
+ distance = math.sqrt(
+ (new_coords[node1][0] - new_coords[node2][0]) ** 2
+ + (new_coords[node1][1] - new_coords[node2][1]) ** 2
+ )
+ if distance > max_distance:
+ max_distance = distance
+
+ return coords, max_distance
+
+ def lift_topology(
+ self, data: torch_geometric.data.Data
+ ) -> torch_geometric.data.Data:
+ r"""Lifts the topology of a graph to a point cloud by placing the nodes in a rotational manner
+
+ Parameters
+ ----------
+ data : torch_geometric.data.Data
+ The input data to be lifted.
+
+ Returns
+ -------
+ torch_geometric.data.Data
+ The lifted point cloud, with node names as keys and coordinates as values.
+ """
+ graph = self._generate_graph_from_data(data)
+ coords = {}
+ node_list = list(graph.nodes)
+ # Assign the first node to (0, 0)
+ start_node = node_list[0] if self.start_node is None else self.start_node
+ coords[start_node] = (0.0, 0.0)
+ # Then spin around to assign coords to its neighbors
+ coords, max_distance = self.lift(coords, graph, start_node)
+
+ # If it's a graph with multiple disconnected parts, do the above for each part
+ remaining_nodes = set(node_list) - coords.keys()
+ max_separation_distance = max_distance
+ while remaining_nodes:
+ start_node = remaining_nodes.pop()
+ last_assigned_node = list(coords.keys())[-1]
+ new_start_coords = (
+ coords[last_assigned_node][0] + max_separation_distance,
+ coords[last_assigned_node][1],
+ )
+ while new_start_coords in coords.values():
+ new_start_coords = (
+ new_start_coords[0] + max_separation_distance,
+ new_start_coords[1],
+ )
+ coords[start_node] = new_start_coords
+ coords, max_distance = self.lift(coords, graph, start_node)
+ if max_distance > max_separation_distance:
+ max_separation_distance = max_distance
+ remaining_nodes = set(node_list) - coords.keys()
+
+ topology = self._get_lifted_topology(coords)
+
+ return topology
diff --git a/test/transforms/liftings/graph2pointcloud/test_spin_lifting.py b/test/transforms/liftings/graph2pointcloud/test_spin_lifting.py
new file mode 100644
index 00000000..cd7382ca
--- /dev/null
+++ b/test/transforms/liftings/graph2pointcloud/test_spin_lifting.py
@@ -0,0 +1,112 @@
+"""Test the message passing module."""
+import math
+
+import networkx as nx
+import torch
+
+from modules.data.utils.utils import load_manual_graph
+from modules.transforms.liftings.graph2pointcloud.spin_lifting import SpinLifting
+
+
+class TestSpinLifting:
+ """Test the SimplicialCliqueLifting class."""
+
+ def setup_method(self):
+ # Load the graph
+ self.data = load_manual_graph()
+
+ # Initialise the SimplicialCliqueLifting class
+ self.spin_lifting = SpinLifting()
+
+ def test_find_neighbors(self):
+ """Test the find_neighbors method."""
+
+ # Test the find_neighbors method
+ graph = nx.Graph()
+ graph.add_edge(0, 1)
+ graph.add_edge(1, 2)
+ graph.add_edge(1, 3)
+ graph.add_edge(2, 3)
+ neighbors = self.spin_lifting.find_neighbors(graph, 0)
+ assert neighbors == [1]
+ neighbors = self.spin_lifting.find_neighbors(graph, 1)
+ assert neighbors == [0, 2, 3]
+ neighbors = self.spin_lifting.find_neighbors(graph, 2)
+ assert neighbors == [1, 3]
+
+ def test_calculate_coords_delta(self):
+ """Test the calculate_coords_delta method."""
+
+ # Test the calculate_coords_delta method
+ allowable_error = 1e-10
+ x_delta, y_delta = self.spin_lifting.calculate_coords_delta(30)
+ assert x_delta - math.sqrt(3) / 2 < allowable_error
+ assert y_delta - 0.5 < allowable_error
+ x_delta, y_delta = self.spin_lifting.calculate_coords_delta(45)
+ assert x_delta - math.sqrt(2) / 2 < allowable_error
+ assert x_delta - math.sqrt(2) / 2 < allowable_error
+
+ def test_assign_coordinates(self):
+ """Test the assign_coordinates method."""
+
+ # Test the assign_coordinates method
+ allowable_error = 1e-10
+ center_coords = (0, 0)
+ neighbors = list(range(1, 14))
+ coords_dict = self.spin_lifting.assign_coordinates(center_coords, neighbors)
+
+ assert (
+ coords_dict[1][0] - 1 < allowable_error
+ and coords_dict[1][1] - 0.0 < allowable_error
+ )
+ assert (
+ coords_dict[4][0] - 0.0 < allowable_error
+ and coords_dict[4][1] - 1.0 < allowable_error
+ )
+ assert (
+ coords_dict[13][0] - math.cos(math.radians(15)) < allowable_error
+ and coords_dict[13][1] - math.sin(math.radians(15)) < allowable_error
+ )
+
+ def test_lift(self):
+ """Test the lift method."""
+
+ # Test the lift method
+ allowable_error = 1e-10
+ coords = {0: (0, 0)}
+ graph = nx.Graph()
+ graph.add_edge(0, 1)
+ graph.add_edge(1, 2)
+ graph.add_edge(1, 3)
+ graph.add_edge(2, 3)
+ coords, max_distance = self.spin_lifting.lift(coords, graph, 0)
+ print(coords)
+ assert coords[1][0] - 1 < allowable_error and coords[1][1] - 0 < allowable_error
+ assert coords[2][0] - 2 < allowable_error and coords[2][1] - 0 < allowable_error
+ assert (
+ coords[3][0] - (1 + math.sqrt(3) / 2) < allowable_error
+ and coords[3][1] - (0 + 0.5) < allowable_error
+ )
+ assert max_distance - 2.0 < allowable_error
+
+ def test_lift_topology(self):
+ """Test the lift_topology method."""
+
+ # Test the lift_topology method
+ lifted_data = self.spin_lifting.forward(self.data.clone())
+ assert lifted_data.x.shape[0] == 8
+ assert lifted_data.y.shape[0] == 8
+ assert lifted_data.pos.shape == (8, 2)
+ expected_pos = torch.tensor(
+ [
+ [0.0, 0.0],
+ [1.0, 0.0],
+ [math.sqrt(3) / 2, 0.5],
+ [1 + math.sqrt(3) / 2, 0.5],
+ [0.5, math.sqrt(3) / 2],
+ [math.sqrt(3), 1.0],
+ [2 + math.sqrt(3) / 2, 0.5],
+ [0.0, 1.0],
+ ]
+ )
+ assert torch.allclose(lifted_data.pos, expected_pos)
diff --git a/tutorials/graph2pointcloud/spin_lifting.ipynb b/tutorials/graph2pointcloud/spin_lifting.ipynb
new file mode 100644
index 00000000..b3d4f9d5
--- /dev/null
+++ b/tutorials/graph2pointcloud/spin_lifting.ipynb
@@ -0,0 +1,438 @@
+{
+ "cells": [
+ {
+ "metadata": {},
+ "cell_type": "markdown",
+ "source": "# Graph-to-Point Cloud Spin Lifting Tutorial",
+ "id": "4af0d0e96a2e43fa"
+ },
+ {
+ "metadata": {},
+ "cell_type": "markdown",
+ "source": [
+ "***\n",
+ "The notebook is divided into sections:\n",
+ "\n",
+ "- [Loading the dataset](#loading-the-dataset) loads the config files for the data and the desired transformation, create a dataset object and visualizes it.\n",
+ "- [Loading and applying the lifting](#loading-and-applying-the-lifting) tests that the lifting creates the expected point cloud.\n",
+ "- [Create and run a nn model over the point cloud](#create-and-run-a-nn-model-over-the-point-cloud) simply runs a forward pass of the model to check that everything is working as expected.\n",
+ "\n",
+ "***"
+ ],
+ "id": "5242b0a7d221e4fe"
+ },
+ {
+ "metadata": {},
+ "cell_type": "markdown",
+ "source": "## Imports and utilities",
+ "id": "76ac0d85cdceb12"
+ },
+ {
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2024-07-12T11:10:23.960275Z",
+ "start_time": "2024-07-12T11:10:22.069764Z"
+ }
+ },
+ "cell_type": "code",
+ "source": [
+ "# With this cell any imported module is reloaded before each cell execution\n",
+ "%load_ext autoreload\n",
+ "%autoreload 2\n",
+ "from modules.data.load.loaders import GraphLoader\n",
+ "from modules.data.preprocess.preprocessor import PreProcessor\n",
+ "from modules.utils.utils import (\n",
+ " describe_data,\n",
+ " load_dataset_config,\n",
+ " load_transform_config,\n",
+ ")"
+ ],
+ "id": "e85258bdcef88c5a",
+ "outputs": [],
+ "execution_count": 1
+ },
+ {
+ "metadata": {},
+ "cell_type": "markdown",
+ "source": "",
+ "id": "4bd060736c12ec5b"
+ },
+ {
+ "metadata": {},
+ "cell_type": "markdown",
+ "source": "## Loading the Dataset",
+ "id": "eb2fbd8292458975"
+ },
+ {
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2024-07-12T11:10:26.285392Z",
+ "start_time": "2024-07-12T11:10:26.249461Z"
+ }
+ },
+ "cell_type": "code",
+ "source": [
+ "dataset_name = \"manual_dataset\"\n",
+ "dataset_config = load_dataset_config(dataset_name)\n",
+ "loader = GraphLoader(dataset_config)"
+ ],
+ "id": "b991906e42568edb",
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "\n",
+ "Dataset configuration for manual_dataset:\n",
+ "\n",
+ "{'data_domain': 'graph',\n",
+ " 'data_type': 'toy_dataset',\n",
+ " 'data_name': 'manual',\n",
+ " 'data_dir': 'datasets/graph/toy_dataset',\n",
+ " 'num_features': 1,\n",
+ " 'num_classes': 2,\n",
+ " 'task': 'classification',\n",
+ " 'loss_type': 'cross_entropy',\n",
+ " 'monitor_metric': 'accuracy',\n",
+ " 'task_level': 'node'}\n"
+ ]
+ }
+ ],
+ "execution_count": 2
+ },
+ {
+ "metadata": {},
+ "cell_type": "markdown",
+ "source": "We can then access to the data through the `load()`method:",
+ "id": "77f459b894319f3c"
+ },
+ {
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2024-07-12T11:10:28.923577Z",
+ "start_time": "2024-07-12T11:10:28.714105Z"
+ }
+ },
+ "cell_type": "code",
+ "source": [
+ "dataset = loader.load()\n",
+ "describe_data(dataset)"
+ ],
+ "id": "d970bfd090fd1280",
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "\n",
+ "Dataset only contains 1 sample:\n"
+ ]
+ },
+ {
+ "data": {
+ "text/plain": [
+ "