Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Probabilistic Clique Lifting (Graph to Combinatorial) #62

Open
wants to merge 14 commits into
base: main
Choose a base branch
from
12 changes: 12 additions & 0 deletions configs/datasets/manual_cliques_dataset.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
data_domain: graph
data_type: toy_dataset
data_name: manual_cliques
data_dir: datasets/${data_domain}/${data_type}

# Dataset parameters
num_features: 1
num_classes: 2
task: classification
loss_type: cross_entropy
monitor_metric: accuracy
task_level: node
5 changes: 5 additions & 0 deletions configs/models/combinatorial/hmc.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
in_channels: null # This will be set by the dataset
hidden_channels: 32
out_channels: null # This will be set by the dataset
n_layers: 2
negative_slope: 0.2
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
transform_type: "lifting"
transform_name: "ProbabilisticCliqueLifting"
max_cell_length: null
preserve_edge_attr: False
feature_lifting: ProjectionSum
probability: 0.3
5 changes: 5 additions & 0 deletions modules/data/load/loaders.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from modules.data.utils.concat2geometric_dataset import ConcatToGeometricDataset
from modules.data.utils.custom_dataset import CustomDataset
from modules.data.utils.utils import (
load_almost_cliques_graph,
load_cell_complex_dataset,
load_hypergraph_pickle_dataset,
load_manual_graph,
Expand Down Expand Up @@ -108,6 +109,10 @@ def load(self) -> torch_geometric.data.Dataset:
data = load_manual_graph()
dataset = CustomDataset([data], self.data_dir)

elif self.parameters.data_name in ["manual_cliques"]:
data = load_almost_cliques_graph()
dataset = CustomDataset([data], self.data_dir)

else:
raise NotImplementedError(
f"Dataset {self.parameters.data_name} not implemented"
Expand Down
103 changes: 95 additions & 8 deletions modules/data/utils/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,21 +50,71 @@ def get_complex_connectivity(complex, max_rank, signed=False):
)
except ValueError: # noqa: PERF203
if connectivity_info == "incidence":
connectivity[f"{connectivity_info}_{rank_idx}"] = (
generate_zero_sparse_connectivity(
m=practical_shape[rank_idx - 1], n=practical_shape[rank_idx]
)
connectivity[
f"{connectivity_info}_{rank_idx}"
] = generate_zero_sparse_connectivity(
m=practical_shape[rank_idx - 1], n=practical_shape[rank_idx]
)
else:
connectivity[f"{connectivity_info}_{rank_idx}"] = (
generate_zero_sparse_connectivity(
m=practical_shape[rank_idx], n=practical_shape[rank_idx]
)
connectivity[
f"{connectivity_info}_{rank_idx}"
] = generate_zero_sparse_connectivity(
m=practical_shape[rank_idx], n=practical_shape[rank_idx]
)
connectivity["shape"] = practical_shape
return connectivity


def get_combinatorial_complex_connectivity(complex, max_rank=None):
r"""Gets the connectivity matrices for the combinatorial complex.

Parameters
----------
complex : topnetx.CombinatorialComplex
Combinatorial complex.
max_rank : int
Maximum rank of the complex.

Returns
-------
dict
Dictionary containing the connectivity matrices.
"""
if max_rank is None:
max_rank = complex.dim
practical_shape = list(
np.pad(list(complex.shape), (0, max_rank + 1 - len(complex.shape)))
)

connectivity = {}

for rank_idx in range(max_rank + 1):
if rank_idx > 0:
try:
connectivity[f"incidence_{rank_idx}"] = from_sparse(
complex.incidence_matrix(rank=rank_idx - 1, to_rank=rank_idx)
)
except ValueError:
connectivity[
f"incidence_{rank_idx}"
] = generate_zero_sparse_connectivity(
m=practical_shape[rank_idx], n=practical_shape[rank_idx]
)

try:
connectivity[f"adjacency_{rank_idx}"] = from_sparse(
complex.adjacency_matrix(rank=rank_idx, via_rank=rank_idx + 1)
)
except ValueError:
connectivity[f"adjacency_{rank_idx}"] = generate_zero_sparse_connectivity(
m=practical_shape[rank_idx], n=practical_shape[rank_idx]
)

connectivity["shape"] = practical_shape

return connectivity


def generate_zero_sparse_connectivity(m, n):
r"""Generates a zero sparse connectivity matrix.

Expand Down Expand Up @@ -334,6 +384,43 @@ def load_manual_graph():
)


def load_almost_cliques_graph():
"""Create a manual graph featuring almost-cliques for testing purposes."""
# Define the vertices (just 9 vertices)
vertices = [i for i in range(9)]
y = [0, 1, 1, 1, 0, 0, 0, 0, 1]
# Define the edges

almost_5_clique = [[i, j] for i in range(5) for j in range(i + 1, 5)]
almost_4_clique = [[i + 5, j + 5] for i in range(4) for j in range(i + 1, 4)]

almost_5_clique.remove([0, 1])
almost_4_clique.remove([7, 8])

edges = [[4, 5], [0, 8], *almost_5_clique, *almost_4_clique]

# Create a graph
G = nx.Graph()

# Add vertices
G.add_nodes_from(vertices)

# Add edges
G.add_edges_from(edges)
G.to_undirected()
edge_list = torch.Tensor(list(G.edges())).T.long()

# Generate feature from 0 to 9
x = torch.tensor([1, 5, 10, 50, 100, 500, 1000, 5000, 10000]).unsqueeze(1).float()

return torch_geometric.data.Data(
x=x,
edge_index=edge_list,
num_nodes=len(vertices),
y=torch.tensor(y),
)


def get_Planetoid_pyg(cfg):
r"""Loads Planetoid graph datasets from torch_geometric.

Expand Down
78 changes: 78 additions & 0 deletions modules/models/combinatorial/hmc.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
import torch
from topomodelx.nn.combinatorial.hmc import HMC


class HMCModel(torch.nn.Module):
r"""A simple HMC model that runs over combinatorial complex data.
Note that some parameters are defined by the considered dataset.

Parameters
----------
model_config : Dict | DictConfig
Model configuration.
dataset_config : Dict | DictConfig
Dataset configuration.
"""

def __init__(self, model_config, dataset_config):
in_channels = (
dataset_config["num_features"]
if isinstance(dataset_config["num_features"], int)
else dataset_config["num_features"][0]
)
hidden_channels = model_config["hidden_channels"]
out_channels = dataset_config["num_classes"]
n_layers = model_config["n_layers"]
negative_slope = model_config["negative_slope"]

super().__init__()

in_channels_layer = [in_channels, in_channels, in_channels]
int_channels_layer = [hidden_channels, hidden_channels, hidden_channels]
out_channels_layer = [hidden_channels, hidden_channels, hidden_channels]

channels_per_layer = [
[in_channels_layer, int_channels_layer, out_channels_layer]
]

for _ in range(1, n_layers):
in_channels_layer = [hidden_channels, hidden_channels, hidden_channels]
int_channels_layer = [hidden_channels, hidden_channels, hidden_channels]
out_channels_layer = [hidden_channels, hidden_channels, hidden_channels]

channels_per_layer.append(
[in_channels_layer, int_channels_layer, out_channels_layer]
)

self.base_model = HMC(
channels_per_layer=channels_per_layer, negative_slope=negative_slope
)
self.linear = torch.nn.Linear(hidden_channels, out_channels)

def forward(self, data):
r"""Forward pass of the model.

Parameters
----------
data : torch_geometric.data.Data
Input data.

Returns
-------
torch.Tensor
Output tensor.
"""
x = self.base_model(
data.x_0,
data.x_1,
data.x_2,
data.adjacency_0,
data.adjacency_1,
data.adjacency_2,
data.incidence_1,
data.incidence_2,
)[1]

x = self.linear(x)

return torch.sigmoid(x)
5 changes: 5 additions & 0 deletions modules/transforms/data_transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,9 @@
)
from modules.transforms.feature_liftings.feature_liftings import ProjectionSum
from modules.transforms.liftings.graph2cell.cycle_lifting import CellCycleLifting
from modules.transforms.liftings.graph2combinatorial.probabilistic_clique_lifting import (
ProbabilisticCliqueLifting,
)
from modules.transforms.liftings.graph2hypergraph.knn_lifting import (
HypergraphKNNLifting,
)
Expand All @@ -23,6 +26,8 @@
"SimplicialCliqueLifting": SimplicialCliqueLifting,
# Graph -> Cell Complex
"CellCycleLifting": CellCycleLifting,
# Graph -> Combinatorial Complex
"ProbabilisticCliqueLifting": ProbabilisticCliqueLifting,
# Feature Liftings
"ProjectionSum": ProjectionSum,
# Data Manipulations
Expand Down
37 changes: 37 additions & 0 deletions modules/transforms/liftings/graph2combinatorial/base.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,7 @@
import torch
from toponetx import CombinatorialComplex

from modules.data.utils.utils import get_combinatorial_complex_connectivity
from modules.transforms.liftings.lifting import GraphLifting


Expand All @@ -13,3 +17,36 @@ class Graph2CombinatorialLifting(GraphLifting):
def __init__(self, **kwargs):
super().__init__(**kwargs)
self.type = "graph2combinatorial"

def _get_lifted_topology(self, combinatorial_complex: CombinatorialComplex) -> dict:
r"""Returns the lifted topology.

Parameters
----------
combinatorial_complex : CombinatorialComplex
The combinatorial complex.

Returns
-------
dict
The lifted topology.
"""
lifted_topology = get_combinatorial_complex_connectivity(combinatorial_complex)

# Feature liftings

features = combinatorial_complex.get_cell_attributes("features")

for i in range(combinatorial_complex.dim + 1):
x = [
feat
for cell, feat in features
if combinatorial_complex.cells.get_rank(cell) == i
]
if x:
lifted_topology[f"x_{i}"] = torch.stack(x)
else:
num_cells = len(combinatorial_complex.skeleton(i))
lifted_topology[f"x_{i}"] = torch.zeros(num_cells, 1)

return lifted_topology
Loading
Loading