Skip to content

Commit

Permalink
refactoring for ruff check
Browse files Browse the repository at this point in the history
  • Loading branch information
PierrickLeroy committed Jul 11, 2024
1 parent 5676e4e commit cb007ef
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 5 deletions.
4 changes: 2 additions & 2 deletions modules/transforms/liftings/graph2hypergraph/path_lifting.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,14 +43,14 @@ def find_hyperedges(self, data: torch_geometric.data.Data):
s_hyperedges = set()

if self.target_nodes is None: # all paths stemming from source nodes only
for source, length in zip(self.source_nodes, self.lengths):
for source, length in zip(self.source_nodes, self.lengths, strict=True):
D, d_id2label, l_leafs = self.build_stemmingTree(G, source, length)
s = self.extract_hyperedgesFromStemmingTree(D, d_id2label, l_leafs)
s_hyperedges = s_hyperedges.union(s)

else: # paths from source_nodes to target_nodes or from source nodes only
for source, target, length in zip(
self.source_nodes, self.target_nodes, self.lengths
self.source_nodes, self.target_nodes, self.lengths, strict=True
):
if target is None:
D, d_id2label, l_leafs = self.build_stemmingTree(G, source, length)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ def test_3(self):
lengths,
include_smaller_paths=include_smaller_paths,
).find_hyperedges(self.data)
assert not frozenset({0, 1}) in res
assert frozenset({0, 1}) not in res

def test_4(self):
"""test: include_smaller_paths=True"""
Expand Down Expand Up @@ -122,10 +122,12 @@ def test_6(self):

def test_7(self):
"""test: every hyperedge contains the source and target nodes when specified"""
a = np.random.choice(np.arange(len(self.data.x)), 2, replace=False)
a = np.random.default_rng().choice(
np.arange(len(self.data.x)), 2, replace=False
)
source_nodes = [a[0]]
target_nodes = [a[1]]
lengths = [np.random.randint(1, 5)]
lengths = [np.random.default_rng().integers(1, 5)]
include_smaller_paths = False
res = PathLifting(
source_nodes,
Expand Down

0 comments on commit cb007ef

Please sign in to comment.