Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Path lifting (Graph to Hypergraph) #52

Open
wants to merge 9 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
# Transform configuration selecting the PathLifting graph-to-hypergraph lifting.
# NOTE(review): PathLifting.__init__ declares source_nodes, target_nodes and
# lengths without defaults, and none of them is set in this file -- confirm
# they are supplied elsewhere before this config is used.
transform_type: 'lifting'
# Presumably looked up in the name -> class mapping of
# modules/transforms/data_transform.py, which registers a "PathLifting" key.
transform_name: "PathLifting"
# Presumably the name of the feature-lifting strategy; its semantics are not
# visible from this file -- TODO confirm.
feature_lifting: ProjectionSum
2 changes: 2 additions & 0 deletions modules/transforms/data_transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from modules.transforms.liftings.graph2hypergraph.knn_lifting import (
HypergraphKNNLifting,
)
from modules.transforms.liftings.graph2hypergraph.path_lifting import PathLifting
from modules.transforms.liftings.graph2simplicial.clique_lifting import (
SimplicialCliqueLifting,
)
Expand All @@ -31,6 +32,7 @@
"OneHotDegreeFeatures": OneHotDegreeFeatures,
"NodeFeaturesToFloat": NodeFeaturesToFloat,
"KeepOnlyConnectedComponent": KeepOnlyConnectedComponent,
"PathLifting": PathLifting,
}


Expand Down
136 changes: 136 additions & 0 deletions modules/transforms/liftings/graph2hypergraph/path_lifting.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,136 @@
"""A module for the PathLifting class."""
import networkx as nx
import numpy as np
import torch
import torch_geometric

from modules.transforms.liftings.graph2hypergraph.base import Graph2HypergraphLifting


class PathLifting(Graph2HypergraphLifting):
    """Lift a graph to a hypergraph whose hyperedges are node sets of paths.

    For every ``(source, target, length)`` triple a set of simple paths is
    enumerated on the undirected version of the input graph, and the *set of
    nodes* of each path becomes one hyperedge. Distinct paths visiting the same
    nodes therefore collapse into a single hyperedge.

    Parameters
    ----------
    source_nodes : list[int]
        Start node of the paths, one entry per query.
    target_nodes : list[int | None] | None
        End node of the paths, aligned with ``source_nodes``. A ``None`` entry
        (or ``None`` for the whole argument) means "every simple path stemming
        from the source", with no constraint on the end node.
    lengths : list[int]
        Path length, counted in edges, aligned with ``source_nodes``.
    include_smaller_paths : bool, optional
        If ``True``, paths shorter than the requested length are also turned
        into hyperedges. Default is ``False``.
    **kwargs
        Forwarded unchanged to the parent class constructor.

    Raises
    ------
    ValueError
        If the argument lists are empty or have mismatched lengths.
    """

    def __init__(
        self,
        source_nodes: list[int],
        target_nodes: list[int | None] | None,
        lengths: list[int],
        include_smaller_paths=False,
        **kwargs,
    ):
        # Guard clauses: fail early, before the parent class is initialised.
        if len(source_nodes) != len(lengths):
            raise ValueError("source_nodes and lengths must have the same length")
        if target_nodes is not None and len(target_nodes) != len(source_nodes):
            # The two literals are concatenated, so the separating space is
            # required to keep the message readable.
            raise ValueError(
                "When target_nodes is not None, it must have the same length "
                "as source_nodes"
            )
        if len(source_nodes) == 0:
            raise ValueError(
                "source_nodes,target_nodes and lengths must have at least one element"
            )

        super().__init__(**kwargs)
        self.source_nodes = source_nodes
        self.target_nodes = target_nodes
        self.lengths = lengths
        self.include_smaller_paths = include_smaller_paths

    def find_hyperedges(self, data: torch_geometric.data.Data):
        """Return the set of hyperedges induced by the configured paths.

        Parameters
        ----------
        data : torch_geometric.data.Data
            Input graph; it is converted to an undirected networkx graph.

        Returns
        -------
        set[frozenset]
            One ``frozenset`` of node labels per distinct hyperedge.
        """
        G = torch_geometric.utils.convert.to_networkx(data, to_undirected=True)

        # A missing target list is equivalent to "no target" for every source.
        targets = self.target_nodes
        if targets is None:
            targets = [None] * len(self.source_nodes)

        s_hyperedges = set()
        for source, target, length in zip(
            self.source_nodes, targets, self.lengths, strict=True
        ):
            if target is None:
                s_hyperedges |= self._stemming_hyperedges(G, source, length)
            else:
                s_hyperedges |= self._source_target_hyperedges(
                    G, source, target, length
                )
        return s_hyperedges

    def _stemming_hyperedges(self, G, source, length):
        """Hyperedges from all simple paths of ``length`` edges leaving ``source``."""
        D, d_id2label, l_leafs = self.build_stemmingTree(G, source, length)
        return self.extract_hyperedgesFromStemmingTree(D, d_id2label, l_leafs)

    def _source_target_hyperedges(self, G, source, target, length):
        """Hyperedges from simple paths between ``source`` and ``target``."""
        paths = nx.all_simple_paths(G, source=source, target=target, cutoff=length)
        if not self.include_smaller_paths:
            # ``cutoff`` is an upper bound, so shorter paths must be dropped
            # explicitly; a path with k nodes has k - 1 edges.
            paths = (path for path in paths if len(path) - 1 == length)
        return {frozenset(path) for path in paths}

    def lift_topology(self, data: torch_geometric.data.Data):
        """Build the node-by-hyperedge incidence matrix for ``data``.

        Parameters
        ----------
        data : torch_geometric.data.Data
            Input graph. ``len(data.x)`` is used as the number of rows.

        Returns
        -------
        dict
            ``incidence_hyperedges`` (sparse COO tensor with one row per node
            and one column per hyperedge), ``num_hyperedges`` and ``x_0``
            (``data.x`` passed through unchanged).
        """
        s_hyperedges = self.find_hyperedges(data)
        node_ids, edge_ids = [], []
        for edge_id, hyperedge in enumerate(s_hyperedges):
            node_ids.extend(hyperedge)
            edge_ids.extend([edge_id] * len(hyperedge))
        incidence = torch.sparse_coo_tensor(
            [node_ids, edge_ids],
            torch.ones(len(node_ids)),
            (len(data.x), len(s_hyperedges)),
        )
        return {
            "incidence_hyperedges": incidence,
            "num_hyperedges": len(s_hyperedges),
            "x_0": data.x,
        }

    def build_stemmingTree(self, G, source_root, length, verbose=False):
        """Unfold the simple paths leaving ``source_root`` into a directed tree.

        Tree nodes carry fresh integer ids (the root is ``0``) because the same
        graph node can appear on several branches; ``d_id2label`` maps each id
        back to its graph node.

        Parameters
        ----------
        G : networkx.Graph
            Graph to explore.
        source_root : hashable
            Graph node used as root of the tree.
        length : int
            Number of edges of the paths to enumerate.
        verbose : bool, optional
            If ``True``, print the loop state after each expanded tree node.

        Returns
        -------
        tuple
            ``(D, d_id2label, l_leafs)``: the tree as ``networkx.DiGraph``, the
            id-to-label mapping, and the ids of the tree nodes lying exactly
            ``length`` edges away from the root.

        Raises
        ------
        ValueError
            If a path would have to grow past ``length`` (e.g. ``length < 1``
            while the root has a neighbour).
        """
        D = nx.DiGraph()
        D.add_node(0)
        d_id2label = {0: source_root}
        # Root-to-node id path of every tree node still waiting on the stack.
        # Tracking it here avoids a shortest-path search per expanded node.
        d_id2path = {0: [0]}
        stack = [0]
        l_leafs = []
        n_id = 1
        while stack:
            node = stack.pop()
            neighbors = list(G.neighbors(d_id2label[node]))
            visited_id = d_id2path.pop(node)
            visited_labels = [d_id2label[i] for i in visited_id]
            for neighbor in neighbors:
                if neighbor in visited_labels:
                    # Simple paths only: never revisit a node on this branch.
                    continue
                D.add_node(n_id)
                d_id2label[n_id] = neighbor
                if len(visited_labels) < length:
                    # Branch still too short: keep expanding from the child.
                    stack.append(n_id)
                    d_id2path[n_id] = [*visited_id, n_id]
                elif len(visited_labels) == length:
                    # Child closes a path of exactly ``length`` edges.
                    l_leafs.append(n_id)
                else:
                    raise ValueError("Visited labels length is greater than length")
                D.add_edge(node, n_id)
                n_id += 1
            if verbose:
                print("\nLoop Variables Summary:")
                print("nodes:", node)
                print("neighbors:", neighbors)
                print("visited_id:", visited_id)
                print("visited_labels:", visited_labels)
                print("stack:", stack)
                print("id2label:", d_id2label)
        return D, d_id2label, l_leafs

    def extract_hyperedgesFromStemmingTree(self, D, d_id2label, l_leafs):
        """Turn the root-to-leaf paths of a stemming tree into hyperedges.

        Parameters
        ----------
        D : networkx.DiGraph
            Tree whose root has id ``0``.
        d_id2label : dict
            Mapping from tree node id to graph node label.
        l_leafs : list
            Ids of the tree nodes that terminate a full-length path.

        Returns
        -------
        set[frozenset]
            Node sets of the root-to-leaf paths. When
            ``self.include_smaller_paths`` is true, the node sets of every
            prefix with at least two nodes are added as well. An empty
            ``l_leafs`` yields an empty set.

        Notes
        -----
        Only paths ending in ``l_leafs`` are considered, so a branch that
        dead-ends before reaching full length contributes no hyperedge, even
        with ``include_smaller_paths``.
        """
        s_hyperedges = set()  # a set because different paths can share nodes
        for leaf in l_leafs:
            labels = [d_id2label.get(i) for i in nx.shortest_path(D, 0, leaf)]
            s_hyperedges.add(frozenset(labels))
            if self.include_smaller_paths:
                # Prefixes start at two nodes: a lone source is not a path.
                for size in range(2, len(labels)):
                    s_hyperedges.add(frozenset(labels[:size]))
        return s_hyperedges
2 changes: 1 addition & 1 deletion modules/utils/utils.py
Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is not directly part of the submission but it fixes what I believe to be a bug in a plotting function. This issue was created by another participant of the challenge.

Original file line number Diff line number Diff line change
Expand Up @@ -233,7 +233,7 @@ def sort_vertices_ccw(vertices):
n_hyperedges = incidence.shape[1]
vertices += [i + n_vertices for i in range(n_hyperedges)]
indices = incidence.indices()
edges = np.array([indices[1].numpy(), indices[0].numpy() + n_vertices]).T
edges = np.array([indices[0].numpy(), indices[1].numpy() + n_vertices]).T
pos_n = [[i, 0] for i in range(n_vertices)]
pos_he = [[i, 1] for i in range(n_hyperedges)]
pos = pos_n + pos_he
Expand Down
156 changes: 156 additions & 0 deletions test/transforms/liftings/graph2hypergraph/test_path_lifting.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,156 @@
"""Test the path lifting module."""

import numpy as np

from modules.data.load.loaders import GraphLoader
from modules.transforms.liftings.graph2hypergraph.path_lifting import PathLifting
from modules.utils.utils import load_dataset_config


class TestHypergraphPathLifting:
    """Test the PathLifting class on the manual dataset."""

    def setup_method(self):
        """Load the manual dataset used by every test."""
        dataset_config = load_dataset_config("manual_dataset")
        loader = GraphLoader(dataset_config)
        self.dataset = loader.load()
        self.data = self.dataset._data

    def test_true(self):
        """Naive test to check if the test is running."""
        assert True

    def test_1(self):
        """Verifies setup_method is working."""
        assert self.dataset is not None

    def test_2(self):
        """Mixed targets with smaller paths yield the exact expected hyperedges."""
        source_nodes = [0, 2]
        target_nodes = [1, None]
        lengths = [2, 2]
        path_lifting = PathLifting(
            source_nodes,
            target_nodes,
            lengths,
            include_smaller_paths=True,
        )
        res = path_lifting.find_hyperedges(self.data)
        res_expected = [
            [0, 1],
            [0, 1, 2],
            [0, 4, 1],
            [2, 4],
            [2, 1],
            [2, 0],
            [2, 7],
            [2, 5],
            [2, 3],
            [2, 1, 4],
            [2, 4, 0],
            [2, 1, 0],
            [2, 0, 7],
            [2, 5, 7],
            [2, 3, 6],
            [2, 5, 6],
        ]
        assert {frozenset(x) for x in res_expected} == res

    def test_3(self):
        """include_smaller_paths=False drops the direct edge for length 2."""
        res = PathLifting(
            [0],
            [1],
            [2],
            include_smaller_paths=False,
        ).find_hyperedges(self.data)
        assert frozenset({0, 1}) not in res

    def test_4(self):
        """include_smaller_paths=True keeps the direct edge for length 2."""
        res = PathLifting(
            [0],
            [1],
            [2],
            include_smaller_paths=True,
        ).find_hyperedges(self.data)
        assert frozenset({0, 1}) in res

    def test_5(self):
        """When include_smaller_paths=False all paths have the length specified."""
        source_nodes = [0]
        target_nodes = [1]
        for k in range(1, 5):
            res = PathLifting(
                source_nodes,
                target_nodes,
                [k],
                include_smaller_paths=False,
            ).find_hyperedges(self.data)
            # A hyperedge built from a simple path of k edges holds k + 1 nodes.
            assert all(len(x) - 1 == k for x in res)

    def test_6(self):
        """A global target_nodes=None returns something."""
        res = PathLifting(
            [0, 1],
            None,
            [2, 2],
            include_smaller_paths=False,
        ).find_hyperedges(self.data)
        assert len(res) > 0

    def test_7(self):
        """Every hyperedge contains the source and target nodes when specified."""
        # Fixed seed so that a failure can be reproduced run after run.
        rng = np.random.default_rng(seed=0)
        a = rng.choice(np.arange(len(self.data.x)), 2, replace=False)
        source_nodes = [a[0]]
        target_nodes = [a[1]]
        lengths = [rng.integers(1, 5)]
        res = PathLifting(
            source_nodes,
            target_nodes,
            lengths,
            include_smaller_paths=False,
        ).find_hyperedges(self.data)
        # Vacuously true when no path of the drawn length exists.
        assert all(source_nodes[0] in x and target_nodes[0] in x for x in res)

    def test_8(self):
        """A None target for a single source node returns something."""
        res = PathLifting(
            [0, 2],
            [1, None],
            [2, 2],
            include_smaller_paths=False,
        ).find_hyperedges(self.data)
        assert len(res) > 0
Loading
Loading