Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

CellEncoding Lifting (Cell to Graph) #12

Open
wants to merge 14 commits into
base: main
Choose a base branch
from
12 changes: 12 additions & 0 deletions configs/datasets/manual_cell_complex.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
data_domain: cell_complex
data_type: toy_dataset
data_name: manual_cell_complex
data_dir: datasets/${data_domain}/${data_type}

# Dataset parameters
num_features: 1
num_classes: 2
task: classification
loss_type: cross_entropy
monitor_metric: accuracy
task_level: node
3 changes: 3 additions & 0 deletions configs/models/graph/gcn.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
in_channels: null # This will be set by the dataset
hidden_channels: 32
out_channels: null # This will be set by the dataset
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
transform_type: 'lifting'
transform_name: "CellEncodingLifting"
feature_lifting: ProjectionSum
14 changes: 13 additions & 1 deletion modules/data/load/loaders.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from modules.data.utils.utils import (
load_cell_complex_dataset,
load_hypergraph_pickle_dataset,
load_manual_cell_complex,
load_manual_graph,
load_simplicial_dataset,
)
Expand Down Expand Up @@ -143,7 +144,18 @@ def load(
torch_geometric.data.Dataset
torch_geometric.data.Dataset object containing the loaded data.
"""
return load_cell_complex_dataset(self.parameters)
# Define the path to the data directory
root_folder = rootutils.find_root()
root_data_dir = os.path.join(root_folder, self.parameters["data_dir"])

self.data_dir = os.path.join(root_data_dir, self.parameters["data_name"])

if self.parameters.data_name in ["manual_cell_complex"]:
data = load_manual_cell_complex()
dataset = CustomDataset([data], self.data_dir)
else:
dataset = load_cell_complex_dataset(self.parameters)
return dataset


class SimplicialLoader(AbstractLoader):
Expand Down
29 changes: 21 additions & 8 deletions modules/data/utils/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,16 +50,16 @@ def get_complex_connectivity(complex, max_rank, signed=False):
)
except ValueError: # noqa: PERF203
if connectivity_info == "incidence":
connectivity[f"{connectivity_info}_{rank_idx}"] = (
generate_zero_sparse_connectivity(
m=practical_shape[rank_idx - 1], n=practical_shape[rank_idx]
)
connectivity[
f"{connectivity_info}_{rank_idx}"
] = generate_zero_sparse_connectivity(
m=practical_shape[rank_idx - 1], n=practical_shape[rank_idx]
)
else:
connectivity[f"{connectivity_info}_{rank_idx}"] = (
generate_zero_sparse_connectivity(
m=practical_shape[rank_idx], n=practical_shape[rank_idx]
)
connectivity[
f"{connectivity_info}_{rank_idx}"
] = generate_zero_sparse_connectivity(
m=practical_shape[rank_idx], n=practical_shape[rank_idx]
)
connectivity["shape"] = practical_shape
return connectivity
Expand Down Expand Up @@ -334,6 +334,19 @@ def load_manual_graph():
)


def load_manual_cell_complex():
"""Create a manual cell complex for testing purposes."""
return torch_geometric.data.Data(
x=torch.tensor([0, 1, 2, 3]),
x_0=torch.zeros(4, 1, dtype=torch.float32),
x_1=torch.zeros(4, 1, dtype=torch.float32),
x_2=torch.zeros(1, 1, dtype=torch.float32),
num_nodes=4,
edge_index=torch.tensor([[0, 0, 1, 2], [1, 2, 2, 3]]),
incidence_2=torch.tensor([[1], [-1], [1], [0]], dtype=torch.float32),
)


def get_Planetoid_pyg(cfg):
r"""Loads Planetoid graph datasets from torch_geometric.

Expand Down
23 changes: 23 additions & 0 deletions modules/models/graph/gcn.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
import torch
from torch_geometric.nn import GCNConv


class GCNModel(torch.nn.Module):
def __init__(self, model_config, dataset_config):
in_channels = (
dataset_config["num_features"]
if isinstance(dataset_config["num_features"], int)
else dataset_config["num_features"][0]
)
hidden_channels = model_config["hidden_channels"]
out_channels = dataset_config["num_classes"]
super().__init__()
self.conv = GCNConv(in_channels, hidden_channels)
self.linear = torch.nn.Linear(hidden_channels, out_channels)

def forward(self, data):
x, edge_index = data.x, data.edge_index

x = self.conv(x, edge_index)
x = torch.relu(x)
return self.linear(x)
5 changes: 5 additions & 0 deletions modules/transforms/data_transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,9 @@
OneHotDegreeFeatures,
)
from modules.transforms.feature_liftings.feature_liftings import ProjectionSum
from modules.transforms.liftings.cell2graph.cell_encoding_lifting import (
CellEncodingLifting,
)
from modules.transforms.liftings.graph2cell.cycle_lifting import CellCycleLifting
from modules.transforms.liftings.graph2hypergraph.knn_lifting import (
HypergraphKNNLifting,
Expand All @@ -23,6 +26,8 @@
"SimplicialCliqueLifting": SimplicialCliqueLifting,
# Graph -> Cell Complex
"CellCycleLifting": CellCycleLifting,
# Cell Complex -> Graph
"CellEncodingLifting": CellEncodingLifting,
# Feature Liftings
"ProjectionSum": ProjectionSum,
# Data Manipulations
Expand Down
15 changes: 15 additions & 0 deletions modules/transforms/liftings/cell2graph/base.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
from modules.transforms.liftings.lifting import CellComplexLifting


class Cell2GraphLifting(CellComplexLifting):
r"""Abstract class for lifting cell complexes to graphs.

Parameters
----------
**kwargs : optional
Additional arguments for the class.
"""

def __init__(self, **kwargs):
super().__init__(**kwargs)
self.type = "cell2graph"
109 changes: 109 additions & 0 deletions modules/transforms/liftings/cell2graph/cell_encoding_lifting.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,109 @@
import networkx as nx
import torch
from toponetx.classes import CellComplex
from torch_geometric.data import Data
from torch_geometric.utils.convert import from_networkx

from modules.transforms.liftings.cell2graph.base import Cell2GraphLifting


class CellEncodingLifting(Cell2GraphLifting):
r"""Lifts cell complex data to graph by using CellEncoding
Parameters
----------
**kwargs : optional
Additional arguments for the class
"""

def __init__(self, **kwargs):
super().__init__(**kwargs)
self.encodings = {
0: torch.tensor([1, 0, 0], dtype=torch.float),
1: torch.tensor([0, 1, 0], dtype=torch.float),
2: torch.tensor([0, 0, 1], dtype=torch.float),
}

def data2cell_complex(self, data: Data) -> CellComplex:
r"""Helper function to transform a torch_geometric.data.Data
object to a toponetx.classes.CellComplex object. E.g. previous
liftings might return a Data instead of a CellComplex object.
Parameters
----------
data : torch_geometric.data.Data
The data to be transformed to a CellComplex
Returns
-------
CellComplex
The transformed object

"""
cc = CellComplex()

# Add 0-cells (vertices)
for i in range(data.num_nodes):
cc.add_node(i)

# Add 1-cells (edges)
edge_index = data.edge_index.t().tolist()
for u, v in edge_index:
cc.add_edge(u, v)

# Add 2-cells (faces)
incidence_2 = data.incidence_2.t()
for i in range(incidence_2.shape[0]):
boundary = incidence_2[i].to_dense().nonzero().flatten().tolist()
cc.add_cell(boundary, rank=2)

return cc

def lift_topology(self, cell_complex: CellComplex | Data) -> dict:
r"""Lifts a cell complex dataset to a graph by using CellEncoding
as described in 'Reducing learning on cell complexes to graphs' by
Fabian Jogl, Maximilian Thiessen and Thomas Gaertner.
Parameters
----------
cell_complex : toponetx.classes.CellComplex or torch_geometric.data.Data
The input data to be lifted
Returns
-------
dict
The lifted topology
"""
# Transform input to CellComplex if necessary
if type(cell_complex) == Data:
cell_complex = self.data2cell_complex(cell_complex)

G = nx.Graph()

# Add 0-cells as nodes
G.add_nodes_from(cell_complex.nodes, cell_dim=self.encodings[0])
G.add_edges_from(cell_complex.edges)

# Add 1-cells
for u, v in cell_complex.edges:
min_e = min(u, v)
max_e = max(u, v)
G.add_node((min_e, max_e), cell_dim=self.encodings[1])
G.add_edge(u, (min_e, max_e))
G.add_edge(v, (min_e, max_e))

# Add 2-cells
for c in cell_complex.cells:
G.add_node(c.elements, cell_dim=self.encodings[2])

previous_boundaries = []
for b0, b1 in c.boundary:
min_b, max_b = min(b0, b1), max(b0, b1)
G.add_edge((min_b, max_b), c.elements)
for pre_b in previous_boundaries:
G.add_edge(pre_b, (min_b, max_b))
previous_boundaries.append((min_b, max_b))

data = from_networkx(G, group_node_attrs="cell_dim")

return {
"x": data.x,
"shape": [data.x.shape[0], data.edge_index.shape[1]],
"edge_index": data.edge_index,
"num_nodes": data.x.shape[0],
}
19 changes: 19 additions & 0 deletions modules/transforms/liftings/lifting.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,6 +171,25 @@ class CellComplexLifting(AbstractLifting):
def __init__(self, feature_lifting="ProjectionSum", **kwargs):
super().__init__(feature_lifting=feature_lifting, **kwargs)

def forward(self, data: torch_geometric.data.Data) -> torch_geometric.data.Data:
r"""Applies the full lifting (topology + features) to the input data.

Parameters
----------
data : torch_geometric.data.Data
The input data to be lifted.

Returns
-------
torch_geometric.data.Data
The lifted data.
"""
lifted_topology = self.lift_topology(data)
lifted_topology = self.feature_lifting(lifted_topology)

# use only lifted topology, since we want to drop cells and their features from initial_data
return torch_geometric.data.Data(**lifted_topology)


class SimplicialLifting(AbstractLifting):
r"""Abstract class for lifting simplicial complexes to other (topological) domains.
Expand Down
70 changes: 70 additions & 0 deletions test/transforms/liftings/cell2graph/test_cell_encoding_lifting.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
"""Test the cell encoding module."""

import networkx as nx
from torch_geometric.utils import to_networkx

from modules.data.utils.utils import load_manual_cell_complex
from modules.transforms.liftings.cell2graph.cell_encoding_lifting import (
CellEncodingLifting,
)


class TestCellEncodingLifting:
"""Test the CellEncoding class."""

def setup_method(self):
# Load the cell complex
self.data = load_manual_cell_complex()

# Initialize the CellEncodingLifting class
self.lifting = CellEncodingLifting()

def test_lift_topology(self):
# test the lift topology method
lifted_data = self.lifting.forward(self.data.clone())
lifted_graph = to_networkx(lifted_data, node_attrs=["x"], to_undirected=True)

expected_graph = nx.Graph()
expected_graph.add_nodes_from(range(8))
expected_graph.add_edges_from(
[
(0, 1),
(0, 2),
(0, 4),
(0, 5),
(1, 2),
(1, 4),
(1, 6),
(2, 3),
(2, 5),
(2, 6),
(2, 7),
(3, 7),
(4, 5),
(4, 6),
(4, 8),
(5, 6),
(5, 8),
(6, 8),
]
)

assert nx.is_isomorphic(
lifted_graph, expected_graph
), "Something in the lifted graph structure is wrong."

expected_node_features = {
0: [1.0, 0.0, 0.0],
1: [1.0, 0.0, 0.0],
2: [1.0, 0.0, 0.0],
3: [1.0, 0.0, 0.0],
4: [0.0, 1.0, 0.0],
5: [0.0, 1.0, 0.0],
6: [0.0, 1.0, 0.0],
7: [0.0, 1.0, 0.0],
8: [0.0, 0.0, 1.0],
}

assert (
nx.get_node_attributes(lifted_graph, "x") == expected_node_features
), "Something in the lifted graph structure is wrong."
Loading
Loading