Skip to content

Commit

Permalink
added default behavior of path lifting
Browse files Browse the repository at this point in the history
  • Loading branch information
PierrickLeroy committed Jul 12, 2024
1 parent cb007ef commit b26701f
Showing 1 changed file with 17 additions and 8 deletions.
25 changes: 17 additions & 8 deletions modules/transforms/liftings/graph2hypergraph/path_lifting.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,31 +12,38 @@ class PathLifting(Graph2HypergraphLifting):

def __init__(
self,
source_nodes: list[int],
target_nodes: list[int],
lengths: list[int],
source_nodes: list[int] = None,
target_nodes: list[int] = None,
lengths: list[int] = None,
include_smaller_paths=False,
**kwargs,
):
# guard clauses
if len(source_nodes) != len(lengths):
if (
lengths is not None
and source_nodes is not None
and len(source_nodes) != len(lengths)
):
raise ValueError("source_nodes and lengths must have the same length")
if target_nodes is not None and len(target_nodes) != len(source_nodes):
raise ValueError(
"When target_nodes is not None, it must have the same length"
"as source_nodes"
)
if len(source_nodes) == 0:
raise ValueError(
"source_nodes,target_nodes and lengths must have at least one element"
)

super().__init__(**kwargs)
self.source_nodes = source_nodes
self.target_nodes = target_nodes
self.lengths = lengths
self.include_smaller_paths = include_smaller_paths

def value_defaults(self, data: torch_geometric.data.Data):
"""Sets default values for source_nodes and lengths if not provided."""
if self.source_nodes is None:
self.source_nodes = np.arange(data.num_nodes)
if self.lengths is None:
self.lengths = [2] * len(self.source_nodes)

def find_hyperedges(self, data: torch_geometric.data.Data):
"""Finds hyperedges from paths between nodes in a graph."""
G = torch_geometric.utils.convert.to_networkx(data, to_undirected=True)
Expand Down Expand Up @@ -68,6 +75,8 @@ def find_hyperedges(self, data: torch_geometric.data.Data):
return s_hyperedges

def lift_topology(self, data: torch_geometric.data.Data):
if self.source_nodes is None or self.lengths is None:
self.value_defaults(data)
s_hyperedges = self.find_hyperedges(data)
indices = [[], []]
for edge_id, x in enumerate(s_hyperedges):
Expand Down

0 comments on commit b26701f

Please sign in to comment.