Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Directed flag complex lifting (graph to simplicial complex) #44

Open
wants to merge 4 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
transform_type: 'lifting'
transform_name: "DirectedSimplicialCliqueLifting"
complex_dim: 3
preserve_edge_attr: False
signed: True
feature_lifting: ProjectionSum
4 changes: 4 additions & 0 deletions modules/transforms/data_transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,16 @@
from modules.transforms.liftings.graph2simplicial.clique_lifting import (
SimplicialCliqueLifting,
)
from modules.transforms.liftings.graph2simplicial.directed_clique_lifting import (
DirectedSimplicialCliqueLifting,
)

TRANSFORMS = {
# Graph -> Hypergraph
"HypergraphKNNLifting": HypergraphKNNLifting,
# Graph -> Simplicial Complex
"SimplicialCliqueLifting": SimplicialCliqueLifting,
"DirectedSimplicialCliqueLifting": DirectedSimplicialCliqueLifting,
# Graph -> Cell Complex
"CellCycleLifting": CellCycleLifting,
# Feature Liftings
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,105 @@
from itertools import combinations

import networkx as nx
import torch_geometric
from toponetx.classes import SimplicialComplex
from torch_geometric.utils.undirected import is_undirected

from modules.transforms.liftings.graph2simplicial.base import Graph2SimplicialLifting


class DirectedSimplicialCliqueLifting(Graph2SimplicialLifting):
r"""Lifts graphs to simplicial complex domain by identifying
the (k-1)-cliques as k-simplices if the clique has a single
source and sink.

See [Computing persistent homology of directed flag complexes](https://arxiv.org/abs/1906.10458)
for more details.

Parameters
----------
**kwargs : optional
Additional arguments for the class.
"""

def __init__(self, **kwargs):
super().__init__(**kwargs)

def _generate_graph_from_data(self, data: torch_geometric.data.Data) -> nx.Graph:
r"""Generates a NetworkX graph from the input data object.
Falls back to superclass method if data is not directed.

Parameters
----------
data : torch_geometric.data.Data
The input data.

Returns
-------
nx.DiGraph
The generated NetworkX graph.
"""
# Check if undirected and fall back to superclass method if so
if is_undirected(data.edge_index, data.edge_attr):
return super()._generate_graph_from_data(data)

# Check if data object have edge_attr, return list of tuples as [(node_id, {'features':data}, 'dim':1)] or ??
nodes = [(n, dict(features=data.x[n], dim=0)) for n in range(data.x.shape[0])]

if self.preserve_edge_attr and self._data_has_edge_attr(data):
# In case edge features are given, assign features to every edge
edge_index, edge_attr = (data.edge_index, data.edge_attr)
edges = [
(i.item(), j.item(), dict(features=edge_attr[edge_idx], dim=1))
for edge_idx, (i, j) in enumerate(
zip(edge_index[0], edge_index[1], strict=False)
)
]
self.contains_edge_attr = True
else:
# If edge_attr is not present, return list list of edges
edges = [
(i.item(), j.item())
for i, j in zip(data.edge_index[0], data.edge_index[1], strict=False)
]
self.contains_edge_attr = False
graph = nx.DiGraph()
graph.add_nodes_from(nodes)
graph.add_edges_from(edges)
return graph

def lift_topology(self, data: torch_geometric.data.Data) -> dict:
r"""Lifts the topology of a graph to a simplicial complex by identifying
the (k-1)-cliques of a graph as the k-simplices if the cliques have a single
source and sink.

Parameters
----------
data : torch_geometric.data.Data
The input data to be lifted.

Returns
-------
dict
The lifted topology.
"""
graph = self._generate_graph_from_data(data)
simplicial_complex = SimplicialComplex(graph)
# find cliques in the undirected graph
cliques = nx.find_cliques(graph.to_undirected())
simplices = [set() for _ in range(2, self.complex_dim + 1)]

for clique in cliques:
# locate the clique in the original directed graph
gs = graph.subgraph(clique)
# check if the clique has a single source and sink
# (i.e. is a DAG) and add as a simplex if so
if nx.is_directed_acyclic_graph(gs):
for i in range(2, self.complex_dim + 1):
for c in combinations(gs, i + 1):
simplices[i - 2].add(tuple(c))

for set_k_simplices in simplices:
simplicial_complex.add_simplices_from(list(set_k_simplices))

return self._get_lifted_topology(simplicial_complex, graph)
Original file line number Diff line number Diff line change
@@ -0,0 +1,174 @@
"""Test the message passing module."""

import networkx as nx
import torch
import torch_geometric

from modules.transforms.liftings.graph2simplicial.directed_clique_lifting import (
DirectedSimplicialCliqueLifting,
)


class TestDirectedSimplicialCliqueLifting:
"""Test the SimplicialCliqueLifting class."""

def setup_triangle_graph(self):
"""Sets up a test graph with 0 as a source node which
generates a 2-simplex amongst (0,1,2), equivalent to
the standard/undirected clique complex.
"""
edges = [
[0, 1],
[0, 2],
[2, 1],
]
g = nx.DiGraph()
g.add_edges_from(edges)
edge_list = torch.Tensor(list(g.edges())).T.long()
# Generate feature from 0 to 3
x = torch.tensor([1, 5, 10]).unsqueeze(1).float()
self.triangle_data = torch_geometric.data.Data(
x=x, edge_index=edge_list, num_nodes=len(g.nodes)
)

def setup_three_two_graph(self):
"""Sets up a test graph with a single source node (0)
with three edges emanating from it, and two sinks (1,2).

The directed clique complex should result in two 2-simplices
(0,2,3) and (0,1,3).
"""
edges = [
[0, 3],
[0, 2],
[0, 1],
[3, 2],
[3, 1],
]
g = nx.DiGraph()
g.add_edges_from(edges)
edge_list = torch.Tensor(list(g.edges())).T.long()
# Generate feature from 0 to 3
x = torch.tensor([1, 5, 10, 50]).unsqueeze(1).float()
self.three_two_data = torch_geometric.data.Data(
x=x, edge_index=edge_list, num_nodes=len(g.nodes)
)

def setup_missing_triangle_graph(self):
"""Sets up a test graph with one clique with a single source
and sink (0,1,3) and one without either (1,2,3).

The directed clique complex should result in only one 2-simplex
(0,1,3), the other clique is empty, illustrating the difference
between the directed clique complex and the undirected clique
complex.
"""
edges = [
[0, 3],
[0, 1],
[3, 2],
[2, 1],
[1, 3],
]
g = nx.DiGraph()
g.add_edges_from(edges)
edge_list = torch.Tensor(list(g.edges())).T.long()
# Generate feature from 0 to 3
x = torch.tensor([1, 5, 10, 50]).unsqueeze(1).float()
self.missing_triangle_data = torch_geometric.data.Data(
x=x, edge_index=edge_list, num_nodes=len(g.nodes)
)

def setup_method(self):
self.setup_triangle_graph()
self.triangle_lifting = DirectedSimplicialCliqueLifting(complex_dim=2)

self.setup_three_two_graph()
self.three_two_lifting = DirectedSimplicialCliqueLifting(complex_dim=2)

self.setup_missing_triangle_graph()
self.missing_triangle_lifting = DirectedSimplicialCliqueLifting(complex_dim=2)

def test_lift_topology(self):
"""Test the lift_topology method."""

# Test the triangle_graph lifting
triangle_lifted_data = self.triangle_lifting.forward(self.triangle_data.clone())

expected_triangle_incidence_1 = torch.tensor(
[[1.0, 1.0, 0.0], [1.0, 0.0, 1.0], [0.0, 1.0, 1.0]]
)

assert (
expected_triangle_incidence_1 == triangle_lifted_data.incidence_1.to_dense()
).all(), "Something is wrong with triangle incidence_1 (nodes to edges)."

# single triangle with all edges connected
expected_triangle_incidence_2 = torch.tensor([[1.0], [1.0], [1.0]])

assert (
expected_triangle_incidence_2 == triangle_lifted_data.incidence_2.to_dense()
).all(), "Something is wrong with triangle incidence_2 (edges to triangles)."

# Test the three_two_graph lifting
three_two_lifted_data = self.three_two_lifting.forward(
self.three_two_data.clone()
)

expected_three_two_incidence_1 = torch.tensor(
[
[1.0, 1.0, 1.0, 0.0, 0.0],
[1.0, 0.0, 0.0, 1.0, 0.0],
[0.0, 1.0, 0.0, 0.0, 1.0],
[0.0, 0.0, 1.0, 1.0, 1.0],
]
)

assert (
expected_three_two_incidence_1
== three_two_lifted_data.incidence_1.to_dense()
).all(), "Something is wrong with three_two incidence_1 (nodes to edges)."

# five edges incident to two triangles, and the edge
# connecting (0,3) is shared by both triangles.
expected_three_two_incidence_2 = torch.tensor(
[[1.0, 0.0], [0.0, 1.0], [1.0, 1.0], [1.0, 0.0], [0.0, 1.0]]
)

assert (
expected_three_two_incidence_2
== three_two_lifted_data.incidence_2.to_dense()
).all(), "Something is wrong with three_two incidence_2 (edges to triangles)."

# Test missing_triangle lifting
missing_triangle_lifted_data = self.missing_triangle_lifting.forward(
self.missing_triangle_data.clone()
)

expected_missing_triangle_incidence_1 = torch.tensor(
[
[1.0, 1.0, 0.0, 0.0, 0.0],
[1.0, 0.0, 1.0, 1.0, 0.0],
[0.0, 0.0, 1.0, 0.0, 1.0],
[0.0, 1.0, 0.0, 1.0, 1.0],
]
)

assert (
expected_missing_triangle_incidence_1
== missing_triangle_lifted_data.incidence_1.to_dense()
).all(), (
"Something is wrong with missing_triangle incidence_1 (nodes to edges)."
)

# only one triangle with the edges (3,2) and (2,1) ignored.
expected_missing_triangle_incidence_2 = torch.tensor(
[[1.0], [1.0], [0.0], [1.0], [0.0]]
)

assert (
expected_missing_triangle_incidence_2
== missing_triangle_lifted_data.incidence_2.to_dense()
).all(), (
"Something is wrong with missing_triangle incidence_2 (edges to triangles)."
)
Loading
Loading