diff --git a/modules/transforms/liftings/graph2hypergraph/path_lifting.py b/modules/transforms/liftings/graph2hypergraph/path_lifting.py index 634a182e..0f1811c2 100644 --- a/modules/transforms/liftings/graph2hypergraph/path_lifting.py +++ b/modules/transforms/liftings/graph2hypergraph/path_lifting.py @@ -12,24 +12,24 @@ class PathLifting(Graph2HypergraphLifting): def __init__( self, - source_nodes: list[int], - target_nodes: list[int], - lengths: list[int], + source_nodes: list[int] = None, + target_nodes: list[int] = None, + lengths: list[int] = None, include_smaller_paths=False, **kwargs, ): # guard clauses - if len(source_nodes) != len(lengths): + if ( + lengths is not None + and source_nodes is not None + and len(source_nodes) != len(lengths) + ): raise ValueError("source_nodes and lengths must have the same length") if target_nodes is not None and len(target_nodes) != len(source_nodes): raise ValueError( "When target_nodes is not None, it must have the same length" "as source_nodes" ) - if len(source_nodes) == 0: - raise ValueError( - "source_nodes,target_nodes and lengths must have at least one element" - ) super().__init__(**kwargs) self.source_nodes = source_nodes @@ -37,6 +37,13 @@ def __init__( self.lengths = lengths self.include_smaller_paths = include_smaller_paths + def value_defaults(self, data: torch_geometric.data.Data): + """Sets default values for source_nodes and lengths if not provided.""" + if self.source_nodes is None: + self.source_nodes = np.arange(data.num_nodes) + if self.lengths is None: + self.lengths = [2] * len(self.source_nodes) + def find_hyperedges(self, data: torch_geometric.data.Data): """Finds hyperedges from paths between nodes in a graph.""" G = torch_geometric.utils.convert.to_networkx(data, to_undirected=True) @@ -68,6 +75,8 @@ def find_hyperedges(self, data: torch_geometric.data.Data): return s_hyperedges def lift_topology(self, data: torch_geometric.data.Data): + if self.source_nodes is None or self.lengths is None: + self.value_defaults(data) s_hyperedges = self.find_hyperedges(data) indices = [[], []] for edge_id, x in enumerate(s_hyperedges):