Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

SpinLifting (Graph to Pointcloud) #56

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
transform_type: 'lifting'
transform_name: "SpinLifting"
start_node: 0
3 changes: 3 additions & 0 deletions modules/transforms/data_transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from modules.transforms.liftings.graph2hypergraph.knn_lifting import (
HypergraphKNNLifting,
)
from modules.transforms.liftings.graph2pointcloud.spin_lifting import SpinLifting
from modules.transforms.liftings.graph2simplicial.clique_lifting import (
SimplicialCliqueLifting,
)
Expand All @@ -23,6 +24,8 @@
"SimplicialCliqueLifting": SimplicialCliqueLifting,
# Graph -> Cell Complex
"CellCycleLifting": CellCycleLifting,
# Graph -> Point Cloud
"SpinLifting": SpinLifting,
# Feature Liftings
"ProjectionSum": ProjectionSum,
# Data Manipulations
Expand Down
25 changes: 25 additions & 0 deletions modules/transforms/liftings/graph2pointcloud/base.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
import torch
import torch_geometric

from modules.transforms.liftings.lifting import GraphLifting


Expand All @@ -13,3 +16,25 @@ class Graph2PointcloudLifting(GraphLifting):
def __init__(self, **kwargs):
super().__init__(**kwargs)
self.type = "graph2pointcloud"
self.start_node = kwargs.get("start_node", None)

@staticmethod
def _get_lifted_topology(coords: dict) -> torch_geometric.data.Data:
r"""Returns the lifted topology.

Parameters
----------
coords : dict
The coordinates of the nodes.

Returns
-------
torch_geometric.data.Data
The lifted topology.
"""

# Sort the items by key to ensure correspondences between features/labels and points
items = sorted(coords.items(), key=lambda x: x[0])
# Convert the coordinates to tensors in order to create a torch_geometric.data.Data object
tensor_coords = {key: torch.tensor(value) for key, value in items}
return torch_geometric.data.Data(pos=torch.stack(list(tensor_coords.values())))
135 changes: 135 additions & 0 deletions modules/transforms/liftings/graph2pointcloud/spin_lifting.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,135 @@
import math

import torch_geometric

from modules.transforms.liftings.graph2pointcloud.base import Graph2PointcloudLifting


class SpinLifting(Graph2PointcloudLifting):
r"""Lifts graphs to point clouds domain by placing the nodes in a rotational manner

Parameters
----------
**kwargs : optional
Additional arguments for the class.
"""

def __init__(self, **kwargs):
super().__init__(**kwargs)

@staticmethod
def find_neighbors(graph, node):
return list(graph.neighbors(node))

@staticmethod
def calculate_coords_delta(angle):
radians = math.radians(angle)
x_delta = math.cos(radians)
y_delta = math.sin(radians)
return x_delta, y_delta

def assign_coordinates(self, center_coords, neighbors):
coords_dict = {}
angle_to_rotate = 30
current_angle = 0
for neighbor in neighbors:
if current_angle >= 360:
angle_to_rotate /= 2
current_angle = angle_to_rotate
delta = self.calculate_coords_delta(current_angle)
new_coords = (center_coords[0] + delta[0], center_coords[1] + delta[1])
while new_coords in coords_dict.values():
current_angle += angle_to_rotate
if current_angle >= 360:
angle_to_rotate /= 2
current_angle = angle_to_rotate
delta = self.calculate_coords_delta(current_angle)
new_coords = (center_coords[0] + delta[0], center_coords[1] + delta[1])
coords_dict[neighbor] = new_coords
current_angle += angle_to_rotate

return coords_dict

def lift(self, coords, graph, start_node):
old_coords = coords.copy()
neighbors = self.find_neighbors(graph, start_node)
coords.update(self.assign_coordinates(coords[start_node], neighbors))
# Do a breadth-first traversal of the remaining nodes
queue = neighbors
visited = set(neighbors)
while queue:
current_center = queue.pop(0)
neighbors = self.find_neighbors(graph, current_center)
# Remove neighbors that have coordinates already assigned
neighbors = [neighbor for neighbor in neighbors if neighbor not in coords]
coords.update(self.assign_coordinates(coords[current_center], neighbors))
for neighbor in neighbors:
if neighbor not in visited:
queue.append(neighbor)
visited.add(neighbor)

# Get the new coordinates generated
new_coords = {start_node: coords[start_node]}
new_coords.update(
{key: value for key, value in coords.items() if key not in old_coords}
)
# Find the max distance between the nodes in new_coords,
# which will be used as the separation distance between the disconnected parts of the graph
max_distance = 0
for node1 in new_coords:
for node2 in new_coords:
distance = math.sqrt(
(new_coords[node1][0] - new_coords[node2][0]) ** 2
+ (new_coords[node1][1] - new_coords[node2][1]) ** 2
)
if distance > max_distance:
max_distance = distance

return coords, max_distance

def lift_topology(
self, data: torch_geometric.data.Data
) -> torch_geometric.data.Data:
r"""Lifts the topology of a graph to a point cloud by placing the nodes in a rotational manner

Parameters
----------
data : torch_geometric.data.Data
The input data to be lifted.

Returns
-------
torch_geometric.data.Data
The lifted point cloud, with node names as keys and coordinates as values.
"""
graph = self._generate_graph_from_data(data)
coords = {}
node_list = list(graph.nodes)
# Assign the first node to (0, 0)
start_node = node_list[0] if self.start_node is None else self.start_node
coords[start_node] = (0.0, 0.0)
# Then spin around to assign coords to its neighbors
coords, max_distance = self.lift(coords, graph, start_node)

# If it's a graph with multiple disconnected parts, do the above for each part
remaining_nodes = set(node_list) - coords.keys()
max_separation_distance = max_distance
while remaining_nodes:
start_node = remaining_nodes.pop()
last_assigned_node = list(coords.keys())[-1]
new_start_coords = (
coords[last_assigned_node][0] + max_separation_distance,
coords[last_assigned_node][1],
)
while new_start_coords in coords.values():
new_start_coords = (
new_start_coords[0] + max_separation_distance,
new_start_coords[1],
)
coords[start_node] = new_start_coords
coords, max_distance = self.lift(coords, graph, start_node)
if max_distance > max_separation_distance:
max_separation_distance = max_distance
remaining_nodes = set(node_list) - coords.keys()

return self._get_lifted_topology(coords)
112 changes: 112 additions & 0 deletions test/transforms/liftings/graph2pointcloud/test_spin_lifting.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,112 @@
"""Test the message passing module."""
import math

import networkx as nx
import torch

from modules.data.utils.utils import load_manual_graph
from modules.transforms.liftings.graph2pointcloud.spin_lifting import SpinLifting


class TestSpinLifting:
"""Test the SimplicialCliqueLifting class."""

def setup_method(self):
# Load the graph
self.data = load_manual_graph()

# Initialise the SimplicialCliqueLifting class
self.spin_lifting = SpinLifting()

def test_find_neighbors(self):
"""Test the find_neighbors method."""

# Test the find_neighbors method
graph = nx.Graph()
graph.add_edge(0, 1)
graph.add_edge(1, 2)
graph.add_edge(1, 3)
graph.add_edge(2, 3)
neighbors = self.spin_lifting.find_neighbors(graph, 0)
assert neighbors == [1]
neighbors = self.spin_lifting.find_neighbors(graph, 1)
assert neighbors == [0, 2, 3]
neighbors = self.spin_lifting.find_neighbors(graph, 2)
assert neighbors == [1, 3]

def test_calculate_coords_delta(self):
"""Test the calculate_coords_delta method."""

# Test the calculate_coords_delta method
allowable_error = 1e-10
x_delta, y_delta = self.spin_lifting.calculate_coords_delta(30)
assert x_delta - math.sqrt(3) / 2 < allowable_error
assert y_delta - 0.5 < allowable_error
x_delta, y_delta = self.spin_lifting.calculate_coords_delta(45)
assert x_delta - math.sqrt(2) / 2 < allowable_error
assert x_delta - math.sqrt(2) / 2 < allowable_error

def test_assign_coordinates(self):
"""Test the assign_coordinates method."""

# Test the assign_coordinates method
allowable_error = 1e-10
center_coords = (0, 0)
neighbors = list(range(1, 14))
coords_dict = self.spin_lifting.assign_coordinates(center_coords, neighbors)

assert (
coords_dict[1][0] - 1 < allowable_error
and coords_dict[1][1] - 0.0 < allowable_error
)
assert (
coords_dict[4][0] - 0.0 < allowable_error
and coords_dict[4][1] - 1.0 < allowable_error
)
assert (
coords_dict[13][0] - math.cos(math.radians(15)) < allowable_error
and coords_dict[13][1] - math.sin(math.radians(15)) < allowable_error
)

def test_lift(self):
"""Test the lift method."""

# Test the lift method
allowable_error = 1e-10
coords = {0: (0, 0)}
graph = nx.Graph()
graph.add_edge(0, 1)
graph.add_edge(1, 2)
graph.add_edge(1, 3)
graph.add_edge(2, 3)
coords, max_distance = self.spin_lifting.lift(coords, graph, 0)
print(coords)
assert coords[1][0] - 1 < allowable_error and coords[1][1] - 0 < allowable_error
assert coords[2][0] - 2 < allowable_error and coords[2][1] - 0 < allowable_error
assert (
coords[3][0] - (1 + math.sqrt(3) / 2) < allowable_error
and coords[3][1] - (0 + 0.5) < allowable_error
)
assert max_distance - 2.0 < allowable_error

def test_lift_topology(self):
"""Test the lift_topology method."""

# Test the lift_topology method
lifted_data = self.spin_lifting.forward(self.data.clone())
assert lifted_data.x.shape[0] == 8
assert lifted_data.y.shape[0] == 8
assert lifted_data.pos.shape == (8, 2)
expected_pos = torch.tensor(
[
[0.0, 0.0],
[1.0, 0.0],
[math.sqrt(3) / 2, 0.5],
[1 + math.sqrt(3) / 2, 0.5],
[0.5, math.sqrt(3) / 2],
[math.sqrt(3), 1.0],
[2 + math.sqrt(3) / 2, 0.5],
[0.0, 1.0],
]
)
assert torch.allclose(lifted_data.pos, expected_pos)
Loading
Loading