Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Node centrality Lifting (Graph to Hypergraph) #46

Open
wants to merge 19 commits into
base: main
Choose a base branch
from
Open
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
transform_type: 'lifting'
transform_name: "HypergraphNodeCentralityLifting"
network_type: 'weighted'
alpha: 0.85
th_percentile: 0.05
n_most_influential: 2
do_weight_hyperedge_influence: False
do_hyperedge_node_assignment_feature_lifting_passthrough: False
max_iter: 100
tol: 1e-06
feature_lifting: ProjectionSum
4 changes: 4 additions & 0 deletions modules/transforms/data_transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,13 +12,17 @@
from modules.transforms.liftings.graph2hypergraph.knn_lifting import (
HypergraphKNNLifting,
)
from modules.transforms.liftings.graph2hypergraph.node_centrality_lifting import (
HypergraphNodeCentralityLifting,
)
from modules.transforms.liftings.graph2simplicial.clique_lifting import (
SimplicialCliqueLifting,
)

TRANSFORMS = {
# Graph -> Hypergraph
"HypergraphKNNLifting": HypergraphKNNLifting,
"HypergraphNodeCentralityLifting": HypergraphNodeCentralityLifting,
# Graph -> Simplicial Complex
"SimplicialCliqueLifting": SimplicialCliqueLifting,
# Graph -> Cell Complex
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,162 @@
import networkx as nx
import numpy as np
import torch
import torch_geometric

from modules.transforms.liftings.graph2hypergraph.base import Graph2HypergraphLifting


class HypergraphNodeCentralityLifting(Graph2HypergraphLifting):
r"""Lifts graphs to hypergraph domain using node centrality.

This lifting creates hyperedges based on central, i.e. highly influential, nodes in the network. Mapping a connection between individual nodes to specific nodes in the network architecture that have a specific and potentially competing influence on them is a very convenient scenario to be modelled via hyperedges. Using shortest path distance to identify the most influential nodes on any given node even allows for placing weights on the hyperedge connection to individual, connected nodes (i.e. the inverse shortest path distance to the corresponding most influential node that the hyperedge represents). To define and identify influential nodes in the network, we refer to the variant of the Eigenvector Centrality with an additional jump probability (i.e. PageRank)

Parameters
----------
network_type : str
Network type may be weighted or unweighted. Default is "weighted".
alpha: float
jump probability, called dampening factor, which decides whether to continue following the transition matrix or teleport to random positions, default=0.85.
th_percentile: float
Fraction of most influential nodes in the network to consider, default=0.05.
n_most_influential: integer
Number of most influential nodes to assign a node to. default=2.
do_weight_hyperedge_influence: bool
add a weight to the hyperedge connections per node based on the inverse spath distance to influential node. default=False.
do_hyperedge_node_assignment_feature_lifting_passthrough: bool
assign features of most influential nodes to corresponding hyperedges and pass through feature lifting. default=False.
max_iter: integer
Maximum number of iterations in power method eigenvalue solver.
tol: float
Error tolerance used to check convergence in power method solver. The iteration will stop after a tolerance of len(G) * tol is reached.

**kwargs : optional
Additional arguments for the class.
"""

def __init__(
self,
network_type: str = "weighted",
alpha: float = 0.85,
th_percentile: float = 0.05,
n_most_influential: float = 2,
do_weight_hyperedge_influence: bool = False,
do_hyperedge_node_assignment_feature_lifting_passthrough: bool = False,
max_iter: int = 100,
tol: float = 1e-06,
**kwargs,
):
super().__init__(**kwargs)
self.network_type = network_type
self.alpha = alpha
self.max_iter = max_iter
self.tol = tol
self.th_percentile = th_percentile
self.n_most_influential = n_most_influential
self.do_weight_hyperedge_influence = do_weight_hyperedge_influence
self.do_hyperedge_node_assignment_feature_lifting_passthrough = (
do_hyperedge_node_assignment_feature_lifting_passthrough
)

def lift_topology(self, data: torch_geometric.data.Data) -> dict:
r"""Lifts the topology of a graph to hypergraph domain using node centrality.

Parameters
----------
data : torch_geometric.data.Data
The input data to be lifted.

Returns
-------
dict
The lifted topology.
"""

edge_list = data.edge_index.t().numpy()

# for unweighted graphs or higher-dimensional edge or node features revert to unweighted network structure
if (
data.edge_attr is None
or self.network_type == "unweighted"
or data.edge_attr.shape[1] > 1
):
edge_attr = np.ones(shape=(len(edge_list), 1))
elif isinstance(data.edge_attr, torch.Tensor):
edge_attr = data.edge_attr.numpy()
else:
edge_attr = data.edge_attr

if data.x is None or self.network_type == "unweighted" or data.x.shape[1] > 1:
node_attr = np.ones(shape=(data.num_nodes, 1))
elif isinstance(data.x, torch.Tensor):
node_attr = data.x.numpy()
else:
node_attr = data.x

# create directed networkx graph from pyg data
G = nx.Graph()
for v in range(len(node_attr)):
G.add_node(v)
G.nodes[v]["w"] = node_attr[v][0]

for e in range(len(edge_list)):
v1 = edge_list[e][0]
v2 = edge_list[e][1]
G.add_edge(v1, v2, w=edge_attr[e][0])

assert self.n_most_influential >= 1

# estimate distance between all nodes
if self.network_type == "unweighted":
sp = dict(nx.all_pairs_shortest_path_length(G))
elif self.network_type == "weighted":
sp = dict(nx.all_pairs_dijkstra_path_length(G))
else:
raise NotImplementedError(
f"network type {self.network_type} not implemented"
)

# estimate node centrality for all nodes
pr = nx.pagerank(
G, alpha=self.alpha, max_iter=self.max_iter, tol=self.tol, weight="w"
)

# estimate fraction of most influential nodes in the network to consider, i.e. the hyperedges
th_cutoff = np.quantile(list(pr.values()), (1 - self.th_percentile))
nodes_most_influential = [n for n, v in pr.items() if v >= th_cutoff]
num_hyperedges = len(nodes_most_influential)
hyperedge_map = {v: e for e, v in enumerate(nodes_most_influential)}

incidence_hyperedges = torch.zeros(data.num_nodes, num_hyperedges)

# assign each node to the hyeredges corresponding to the top "n_most_influential" most influential nodes
for v in list(G.nodes()):
if v in nodes_most_influential:
incidence_hyperedges[v, hyperedge_map[v]] = 1
else:
sp_v_influencial = {
k: v for k, v in sp[v].items() if k in nodes_most_influential
}
v_influencial = [
(k, v)
for i, (k, v) in enumerate(sp_v_influencial.items())
if i < self.n_most_influential
]
for k_infl, v_infl in v_influencial:
w = 1
if self.do_weight_hyperedge_influence:
w = max(1 / v_infl, 0.0001)
incidence_hyperedges[v, hyperedge_map[k_infl]] = w

incidence_hyperedges = incidence_hyperedges.to_sparse_coo()
lifted_data = {
"incidence_hyperedges": incidence_hyperedges,
"num_hyperedges": num_hyperedges,
"x_0": data.x,
}

if self.do_hyperedge_node_assignment_feature_lifting_passthrough:
# assign features of most influential nodes to corresponding hyperedges and pass through feature lifting.
lifted_data["x_hyperedges"] = data.x[nodes_most_influential]

return lifted_data
2 changes: 1 addition & 1 deletion modules/utils/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -233,7 +233,7 @@ def sort_vertices_ccw(vertices):
n_hyperedges = incidence.shape[1]
vertices += [i + n_vertices for i in range(n_hyperedges)]
indices = incidence.indices()
edges = np.array([indices[1].numpy(), indices[0].numpy() + n_vertices]).T
edges = np.array([indices[0].numpy(), indices[1].numpy() + n_vertices]).T
pos_n = [[i, 0] for i in range(n_vertices)]
pos_he = [[i, 1] for i in range(n_hyperedges)]
pos = pos_n + pos_he
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,98 @@
"""Test Page Rank Lifting."""

import pytest
import torch

from modules.data.utils.utils import load_manual_graph
from modules.transforms.liftings.graph2hypergraph.node_centrality_lifting import (
HypergraphNodeCentralityLifting,
)


class TestHypergraphNodeCentralityLifting:
"""Test the HypergraphNodeCentralityLifting class."""

def setup_method(self):
self.data = load_manual_graph()

self.lifting = HypergraphNodeCentralityLifting(
network_type="weighted",
th_percentile=0.2,
n_most_influential=1,
)

def test_lift_topology(self):
# Test the lift_topology method
lifted_data = self.lifting.forward(self.data.clone())

expected_n_hyperedges = 2

expected_incidence_1 = torch.tensor(
[
[1.0, 0.0],
[1.0, 0.0],
[0.0, 1.0],
[0.0, 1.0],
[1.0, 0.0],
[0.0, 1.0],
[0.0, 1.0],
[1.0, 0.0],
]
)

assert (
expected_incidence_1 == lifted_data.incidence_hyperedges.to_dense()
).all(), "Something is wrong with incidence_hyperedges."
assert (
expected_n_hyperedges == lifted_data.num_hyperedges
), "Something is wrong with the number of hyperedges."

self.lifting.network_type = "unweighted"
lifted_data = self.lifting.forward(self.data.clone())

assert (
expected_incidence_1 == lifted_data.incidence_hyperedges.to_dense()
).all(), "Something is wrong with incidence_hyperedges."
assert (
expected_n_hyperedges == lifted_data.num_hyperedges
), "Something is wrong with the number of hyperedges."

expected_incidence_1 = torch.tensor(
[
[1.0, 0.0],
[1.0, 0.0],
[0.0, 1.0],
[0.0, 1.0],
[1.0, 0.0],
[0.0, 1.0],
[0.0, 0.5],
[1.0, 0.0],
]
)

self.lifting.network_type = "unweighted"
self.lifting.do_weight_hyperedge_influence = True
lifted_data = self.lifting.forward(self.data.clone())

assert (
expected_incidence_1 == lifted_data.incidence_hyperedges.to_dense()
).all(), "Something is wrong with incidence_hyperedges."
assert (
expected_n_hyperedges == lifted_data.num_hyperedges
), "Something is wrong with the number of hyperedges."

assert (
lifted_data.x_hyperedges.to_dense() == torch.tensor([[5106.0], [1060.0]])
).all(), "Something is wrong with x_hyperedges."

self.lifting.do_hyperedge_node_assignment_feature_lifting_passthrough = True
lifted_data = self.lifting.forward(self.data.clone())

assert (
lifted_data.x_hyperedges.to_dense() == torch.tensor([[1.0], [10.0]])
).all(), "Something is wrong with x_hyperedges."

def test_validations(self):
with pytest.raises(NotImplementedError):
self.lifting.network_type = "mixed"
self.lifting.forward(self.data.clone())
2 changes: 1 addition & 1 deletion tutorials/graph2hypergraph/knn_lifting.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -339,7 +339,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.3"
"version": "3.11.7"
}
},
"nbformat": 4,
Expand Down
Loading
Loading