Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Node Attribute Lifting (Graph to Hypergraph) #40

Open
wants to merge 15 commits into
base: main
Choose a base branch
from
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
transform_type: 'lifting'
transform_name: "HypergraphNodeAttributeLifting"
attribute_idx: 1
feature_lifting: ProjectionSum
16 changes: 8 additions & 8 deletions modules/data/utils/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,16 +50,16 @@ def get_complex_connectivity(complex, max_rank, signed=False):
)
except ValueError: # noqa: PERF203
if connectivity_info == "incidence":
connectivity[f"{connectivity_info}_{rank_idx}"] = (
generate_zero_sparse_connectivity(
m=practical_shape[rank_idx - 1], n=practical_shape[rank_idx]
)
connectivity[
f"{connectivity_info}_{rank_idx}"
] = generate_zero_sparse_connectivity(
m=practical_shape[rank_idx - 1], n=practical_shape[rank_idx]
)
else:
connectivity[f"{connectivity_info}_{rank_idx}"] = (
generate_zero_sparse_connectivity(
m=practical_shape[rank_idx], n=practical_shape[rank_idx]
)
connectivity[
f"{connectivity_info}_{rank_idx}"
] = generate_zero_sparse_connectivity(
m=practical_shape[rank_idx], n=practical_shape[rank_idx]
)
connectivity["shape"] = practical_shape
return connectivity
Expand Down
4 changes: 4 additions & 0 deletions modules/transforms/data_transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,9 @@
)
from modules.transforms.feature_liftings.feature_liftings import ProjectionSum
from modules.transforms.liftings.graph2cell.cycle_lifting import CellCycleLifting
from modules.transforms.liftings.graph2hypergraph.attribute_lifting import (
NodeAttributeLifting,
)
from modules.transforms.liftings.graph2hypergraph.knn_lifting import (
HypergraphKNNLifting,
)
Expand All @@ -19,6 +22,7 @@
TRANSFORMS = {
# Graph -> Hypergraph
"HypergraphKNNLifting": HypergraphKNNLifting,
"HypergraphNodeAttributeLifting": NodeAttributeLifting,
# Graph -> Simplicial Complex
"SimplicialCliqueLifting": SimplicialCliqueLifting,
# Graph -> Cell Complex
Expand Down
47 changes: 47 additions & 0 deletions modules/transforms/liftings/graph2hypergraph/attribute_lifting.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
import torch
import torch_geometric

from modules.transforms.liftings.graph2hypergraph.base import Graph2HypergraphLifting


class NodeAttributeLifting(Graph2HypergraphLifting):
r"""Lifts graphs to hypergraph domain by grouping nodes with the same attribute.

Parameters
----------
attribute_idx : int
The index of the node attribute to use for hyperedge construction.
"""

def __init__(self, attribute_idx: int, **kwargs):
super().__init__(**kwargs)
self.attribute_idx = attribute_idx

def lift_topology(self, data: torch_geometric.data.Data) -> dict:
r"""Lifts the topology of a graph to hypergraph domain by grouping nodes with the same attribute.

Parameters
----------
data : torch_geometric.data.Data
The input data to be lifted.

Returns
-------
dict
The lifted topology.
"""
attribute = data.x[:, self.attribute_idx]
unique_attributes = torch.unique(attribute)
num_hyperedges = unique_attributes.size(0)
# incidence matrix of the hypergraph
incidence_1 = torch.zeros(data.num_nodes, num_hyperedges)
for i, attr in enumerate(unique_attributes):
nodes_with_attr = torch.where(attribute == attr)[0]
incidence_1[nodes_with_attr, i] = 1

incidence_1 = incidence_1.to_sparse_coo()
return {
"incidence_hyperedges": incidence_1,
"num_hyperedges": num_hyperedges,
"x_0": data.x,
}
2 changes: 1 addition & 1 deletion modules/utils/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -233,7 +233,7 @@ def sort_vertices_ccw(vertices):
n_hyperedges = incidence.shape[1]
vertices += [i + n_vertices for i in range(n_hyperedges)]
indices = incidence.indices()
edges = np.array([indices[1].numpy(), indices[0].numpy() + n_vertices]).T
edges = np.array([indices[0].numpy(), indices[1].numpy() + n_vertices]).T
pos_n = [[i, 0] for i in range(n_vertices)]
pos_he = [[i, 1] for i in range(n_hyperedges)]
pos = pos_n + pos_he
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
import torch
import torch_geometric

from modules.transforms.liftings.graph2hypergraph.attribute_lifting import (
NodeAttributeLifting,
)


class TestNodeAttributeLifting:
"""Test the NodeAttributeLifting class."""

def setup_method(self):
# Set up a simple manual graph for testing
self.data = torch_geometric.data.Data(
x=torch.tensor(
[
[1, 0],
[1, 0],
[0, 1],
[0, 1],
],
dtype=torch.float,
),
edge_index=torch.tensor([[0, 1, 2, 3], [1, 0, 3, 2]], dtype=torch.long),
)
self.lifting = NodeAttributeLifting(attribute_idx=1)

def test_lift_topology(self):
# Test the lift_topology method
lifted_topology = self.lifting.lift_topology(self.data.clone())

expected_n_hyperedges = 2
expected_incidence_1 = torch.tensor(
[
[1.0, 0.0],
[1.0, 0.0],
[0.0, 1.0],
[0.0, 1.0],
]
).to_sparse_coo()

# Print for debugging
print("Expected Incidence Matrix (Dense):")
print(expected_incidence_1.to_dense())

print("Actual Incidence Matrix (Dense):")
print(lifted_topology["incidence_hyperedges"].to_dense())

assert torch.equal(
expected_incidence_1.to_dense(),
lifted_topology["incidence_hyperedges"].to_dense(),
), "Something is wrong with incidence_hyperedges."
assert (
expected_n_hyperedges == lifted_topology["num_hyperedges"]
), "Something is wrong with the number of hyperedges."


# Running the test manually for debugging purposes
if __name__ == "__main__":
test = TestNodeAttributeLifting()
test.setup_method()
test.test_lift_topology()
204 changes: 204 additions & 0 deletions tutorials/graph2hypergraph/attribute_lifting.ipynb

Large diffs are not rendered by default.

Loading