Skip to content

Commit

Permalink
improved code readability
Browse files Browse the repository at this point in the history
  • Loading branch information
PierrickLeroy committed Jul 12, 2024
1 parent b26701f commit 5410453
Show file tree
Hide file tree
Showing 3 changed files with 338 additions and 32 deletions.
70 changes: 55 additions & 15 deletions modules/transforms/liftings/graph2hypergraph/path_lifting.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,22 @@ def __init__(
include_smaller_paths=False,
**kwargs,
):
"""Init function
Args:
source_nodes (list[int], optional): a list of nodes from which to start the paths.
Defaults to None in __init__ but is later valued in value_defaults().
target_nodes (list[int], optional): a list of nodes where the paths must end.
Defaults to None.
lengths (list[int], optional): a list of paths lenghts.
Defaults to None in __init__ but is later valued in value_defaults().
include_smaller_paths (bool, optional): whether or not to include paths from source
to target smaller than the length specified. Defaults to False.
Raises:
ValueError: when provided source_nodes and lengths must have the same length
ValueError: when provided target_nodes and source_nodes must have the same length
"""
# guard clauses
if (
lengths is not None
Expand All @@ -37,31 +53,31 @@ def __init__(
self.lengths = lengths
self.include_smaller_paths = include_smaller_paths

def value_defaults(self, data: torch_geometric.data.Data):
def _value_defaults(self, data: torch_geometric.data.Data):
"""Sets default values for source_nodes and lengths if not provided."""
if self.source_nodes is None:
self.source_nodes = np.arange(data.num_nodes)
if self.lengths is None:
self.lengths = [2] * len(self.source_nodes)

def find_hyperedges(self, data: torch_geometric.data.Data):
def _find_hyperedges(self, data: torch_geometric.data.Data):
"""Finds hyperedges from paths between nodes in a graph."""
G = torch_geometric.utils.convert.to_networkx(data, to_undirected=True)
s_hyperedges = set()

if self.target_nodes is None: # all paths stemming from source nodes only
for source, length in zip(self.source_nodes, self.lengths, strict=True):
D, d_id2label, l_leafs = self.build_stemmingTree(G, source, length)
s = self.extract_hyperedgesFromStemmingTree(D, d_id2label, l_leafs)
D, d_id2label, l_leafs = self._build_stemmingTree(G, source, length)
s = self._extract_hyperedgesFromStemmingTree(D, d_id2label, l_leafs)
s_hyperedges = s_hyperedges.union(s)

else: # paths from source_nodes to target_nodes or from source nodes only
for source, target, length in zip(
self.source_nodes, self.target_nodes, self.lengths, strict=True
):
if target is None:
D, d_id2label, l_leafs = self.build_stemmingTree(G, source, length)
s = self.extract_hyperedgesFromStemmingTree(D, d_id2label, l_leafs)
D, d_id2label, l_leafs = self._build_stemmingTree(G, source, length)
s = self._extract_hyperedgesFromStemmingTree(D, d_id2label, l_leafs)
s_hyperedges = s_hyperedges.union(s)
else:
paths = list(
Expand All @@ -75,9 +91,10 @@ def find_hyperedges(self, data: torch_geometric.data.Data):
return s_hyperedges

def lift_topology(self, data: torch_geometric.data.Data):
"""Lifts the graph data to a hypergraph by considering paths between nodes."""
if self.source_nodes is None or self.lengths is None:
self.value_defaults(data)
s_hyperedges = self.find_hyperedges(data)
self._value_defaults(data)
s_hyperedges = self._find_hyperedges(data)
indices = [[], []]
for edge_id, x in enumerate(s_hyperedges):
indices[1].extend([edge_id] * len(x))
Expand All @@ -91,8 +108,22 @@ def lift_topology(self, data: torch_geometric.data.Data):
"x_0": data.x,
}

def build_stemmingTree(self, G, source_root, length, verbose=False):
"""Creates a directed tree from a source node with paths of a given length."""
def _build_stemmingTree(self, G, source_root, length, verbose=False):
"""Creates a directed tree from a source node with paths of a given length.
This directed tree has as root the source node and paths stemming from it.
This tree is used to extract hyperedges from paths to leafs.
Args:
G (networkx.classes.graph.Graph): the original graph
source_root (int): the source node from which to start the paths
length (int): the length of the paths
verbose (bool, optional): Defaults to False.
Returns:
D (networkx.classes.graph.DiGraph): a directed tree stemming from source_root
d_id2label (dict): a dictionary mapping node ids to node labels
l_leafs (list): a list of leaf nodes ids
"""
d_id2label = {}
stack = []
D = nx.DiGraph()
Expand All @@ -115,11 +146,11 @@ def build_stemmingTree(self, G, source_root, length, verbose=False):
stack.append(n_id)
elif len(visited_labels) == length:
l_leafs.append(n_id)
else:
else: # security check
raise ValueError("Visited labels length is greater than length")
D.add_edge(node, n_id)
n_id += 1
if verbose:
if verbose: # output information during the process
print("\nLoop Variables Summary:")
print("nodes:", node)
print("neighbors:", neighbors)
Expand All @@ -129,15 +160,24 @@ def build_stemmingTree(self, G, source_root, length, verbose=False):
print("id2label:", d_id2label)
return D, d_id2label, l_leafs

def extract_hyperedgesFromStemmingTree(self, D, d_id2label, l_leafs):
def _extract_hyperedgesFromStemmingTree(self, D, d_id2label, l_leafs):
"""From the root of the directed tree D,
extract hyperedges from the paths to the leafs."""
extract hyperedges from the paths to the leafs.
Args:
D (networkx.classes.graph.DiGraph): a directed tree stemming from source_root
d_id2label (dict): a dictionary mapping node ids to node labels
l_leafs (list): a list of leaf nodes ids
Returns:
_type_: _description_
"""
a_paths = np.array(
[list(map(d_id2label.get, nx.shortest_path(D, 0, x))) for x in l_leafs]
)
s_hyperedges = {
(frozenset(x)) for x in a_paths
} # set bc != paths can be same hpedge
} # set because different paths can be same hyperedge
if self.include_smaller_paths:
for i in range(a_paths.shape[1] - 1, 1, -1):
a_paths = np.unique(a_paths[:, :i], axis=0)
Expand Down
14 changes: 7 additions & 7 deletions test/transforms/liftings/graph2hypergraph/test_path_lifting.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ def test_2(self):
lengths,
include_smaller_paths=include_smaller_paths,
)
res = path_lifting.find_hyperedges(self.data)
res = path_lifting._find_hyperedges(self.data)
res_expected = [
[0, 1],
[0, 1, 2],
Expand Down Expand Up @@ -74,7 +74,7 @@ def test_3(self):
target_nodes,
lengths,
include_smaller_paths=include_smaller_paths,
).find_hyperedges(self.data)
)._find_hyperedges(self.data)
assert frozenset({0, 1}) not in res

def test_4(self):
Expand All @@ -88,7 +88,7 @@ def test_4(self):
target_nodes,
lengths,
include_smaller_paths=include_smaller_paths,
).find_hyperedges(self.data)
)._find_hyperedges(self.data)
assert frozenset({0, 1}) in res

def test_5(self):
Expand All @@ -103,7 +103,7 @@ def test_5(self):
target_nodes,
lengths,
include_smaller_paths=include_smaller_paths,
).find_hyperedges(self.data)
)._find_hyperedges(self.data)
assert np.array([len(x) - 1 == k for x in res]).all()

def test_6(self):
Expand All @@ -117,7 +117,7 @@ def test_6(self):
target_nodes,
lengths,
include_smaller_paths=include_smaller_paths,
).find_hyperedges(self.data)
)._find_hyperedges(self.data)
assert len(res) > 0

def test_7(self):
Expand All @@ -134,7 +134,7 @@ def test_7(self):
target_nodes,
lengths,
include_smaller_paths=include_smaller_paths,
).find_hyperedges(self.data)
)._find_hyperedges(self.data)
if len(res) > 0:
assert (
np.array([source_nodes[0] in x for x in res]).all()
Expand All @@ -152,5 +152,5 @@ def test_8(self):
target_nodes,
lengths,
include_smaller_paths=include_smaller_paths,
).find_hyperedges(self.data)
)._find_hyperedges(self.data)
assert len(res) > 0
286 changes: 276 additions & 10 deletions tutorials/graph2hypergraph/path_lifting.ipynb

Large diffs are not rendered by default.

0 comments on commit 5410453

Please sign in to comment.