Skip to content

Commit

Permalink
updated tests
Browse files Browse the repository at this point in the history
  • Loading branch information
levtelyatnikov committed May 1, 2024
1 parent 2c588b1 commit 09ffb92
Show file tree
Hide file tree
Showing 8 changed files with 31 additions and 75 deletions.
8 changes: 4 additions & 4 deletions modules/data/utils/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ def get_complex_connectivity(complex, max_rank, signed=False):
rank=rank_idx, signed=signed
)
)
except ValueError: # noqa: PERF203
except ValueError: # noqa: PERF203
if connectivity_info == "incidence":
connectivity[f"{connectivity_info}_{rank_idx}"] = (
generate_zero_sparse_connectivity(
Expand Down Expand Up @@ -122,7 +122,7 @@ def load_simplicial_dataset(cfg):
)
)
)
except ValueError: # noqa: PERF203
except ValueError: # noqa: PERF203
features[f"x_{rank_idx}"] = torch.tensor(
np.zeros((data.shape[rank_idx], 0))
)
Expand Down Expand Up @@ -310,7 +310,7 @@ def load_manual_graph():
for tetrahedron in tetrahedrons:
for i in range(len(tetrahedron)):
for j in range(i + 1, len(tetrahedron)):
edges.append([tetrahedron[i], tetrahedron[j]]) # noqa: PERF401
edges.append([tetrahedron[i], tetrahedron[j]]) # noqa: PERF401

# Create a graph
G = nx.Graph()
Expand Down Expand Up @@ -389,7 +389,7 @@ def ensure_serializable(obj):
for key, value in obj.items():
obj[key] = ensure_serializable(value)
return obj
elif isinstance(obj, list | tuple): # noqa: RET505
elif isinstance(obj, list | tuple): # noqa: RET505
return [ensure_serializable(item) for item in obj]
elif isinstance(obj, set):
return {ensure_serializable(item) for item in obj}
Expand Down
4 changes: 2 additions & 2 deletions modules/transforms/data_manipulations/manipulations.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ def forward(self, data: torch_geometric.data.Data):
for key in data:
for field_substring in self.parameters["selected_fields"]:
if field_substring in key and key != "incidence_0":
field_to_process.append(key) # noqa : PERF401
field_to_process.append(key) # noqa : PERF401

for field in field_to_process:
data = self.calculate_node_degrees(data, field)
Expand Down Expand Up @@ -299,7 +299,7 @@ def forward(self, data: torch_geometric.data.Data):
if len(self.parameters["keep_fields"]) == 1:
return data

for key, _ in data.items(): # noqa : PERF102
for key, _ in data.items(): # noqa : PERF102
if key not in self.parameters["keep_fields"]:
del data[key]

Expand Down
4 changes: 0 additions & 4 deletions modules/transforms/data_transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,6 @@
)
from modules.transforms.feature_liftings.feature_liftings import ProjectionSum
from modules.transforms.liftings.graph2cell.cycle_lifting import CellCycleLifting
from modules.transforms.liftings.graph2hypergraph.khop_lifting import (
HypergraphKHopLifting,
)
from modules.transforms.liftings.graph2hypergraph.knn_lifting import (
HypergraphKNNLifting,
)
Expand All @@ -21,7 +18,6 @@

TRANSFORMS = {
# Graph -> Hypergraph
"HypergraphKHopLifting": HypergraphKHopLifting,
"HypergraphKNNLifting": HypergraphKNNLifting,
# Graph -> Simplicial Complex
"SimplicialCliqueLifting": SimplicialCliqueLifting,
Expand Down
2 changes: 1 addition & 1 deletion modules/transforms/feature_liftings/feature_liftings.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ def lift_features(
-------
torch_geometric.data.Data | dict
The lifted data."""
keys = sorted([key.split("_")[1] for key in data.keys() if "incidence" in key]) # noqa : SIM118
keys = sorted([key.split("_")[1] for key in data.keys() if "incidence" in key]) # noqa : SIM118
for elem in keys:
if f"x_{elem}" not in data:
idx_to_project = 0 if elem == "hyperedges" else int(elem) - 1
Expand Down
2 changes: 1 addition & 1 deletion modules/utils/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,7 +171,7 @@ def describe_data(dataset: torch_geometric.data.Dataset, idx_sample: int = 0):
isolated_nodes = []
for i in range(data.x.shape[0]):
if i not in connected_nodes:
isolated_nodes.append(i) # noqa : PERF401
isolated_nodes.append(i) # noqa : PERF401
print(f" - There are {len(isolated_nodes)} isolated nodes.")
else:
for i, c_d in enumerate(complex_dim):
Expand Down
19 changes: 5 additions & 14 deletions test/transforms/feature_liftings/test_projection_sum.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,8 @@

from modules.data.utils.utils import load_manual_graph
from modules.transforms.feature_liftings.feature_liftings import ProjectionSum
from modules.transforms.liftings.graph2hypergraph.khop_lifting import (
HypergraphKHopLifting,
from modules.transforms.liftings.graph2hypergraph.knn_lifting import (
HypergraphKNNLifting,
)
from modules.transforms.liftings.graph2simplicial.clique_lifting import (
SimplicialCliqueLifting,
Expand All @@ -23,8 +23,8 @@ def setup_method(self):

# Initialize a simplicial/cell lifting class
self.lifting = SimplicialCliqueLifting(complex_dim=3)
# Initialize a hypergraph lifting class
self.lifting_h = HypergraphKHopLifting(k_value=1)

self.lifting_h = HypergraphKNNLifting(k_value=3)

def test_lift_features(self):
# Test the lift_features method for simplicial/cell lifting
Expand Down Expand Up @@ -78,16 +78,7 @@ def test_lift_features(self):
)

expected_x_hyperedges = torch.tensor(
[
[5116.0],
[116.0],
[5666.0],
[1060.0],
[116.0],
[6510.0],
[1550.0],
[5511.0],
]
[[16.0], [66.0], [166.0], [650.0], [1600.0], [6500.0], [6000.0], [5000.0]]
)

assert (
Expand Down
54 changes: 15 additions & 39 deletions test/transforms/liftings/graph2hypergraph/test_khop_lifting.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,8 @@
import torch

from modules.data.utils.utils import load_manual_graph
from modules.transforms.liftings.graph2hypergraph.khop_lifting import (
HypergraphKHopLifting,
from modules.transforms.liftings.graph2hypergraph.knn_lifting import (
HypergraphKNNLifting,
)


Expand All @@ -16,55 +16,31 @@ def setup_method(self):
self.data = load_manual_graph()

# Initialise the HypergraphKHopLifting class
self.lifting_k1 = HypergraphKHopLifting(k_value=1)
self.lifting_k2 = HypergraphKHopLifting(k_value=2)

self.lifting_k = HypergraphKNNLifting(k_value=3)

def test_lift_topology(self):
# Test the lift_topology method
lifted_data_k1 = self.lifting_k1.forward(self.data.clone())
lifted_data_k = self.lifting_k.forward(self.data.clone())

expected_n_hyperedges = 8

expected_incidence_1 = torch.tensor(
[
[1.0, 1.0, 1.0, 0.0, 1.0, 0.0, 0.0, 1.0],
[1.0, 1.0, 1.0, 0.0, 1.0, 0.0, 0.0, 0.0],
[1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0, 1.0],
[0.0, 0.0, 1.0, 1.0, 0.0, 0.0, 1.0, 0.0],
[1.0, 1.0, 1.0, 0.0, 1.0, 0.0, 0.0, 0.0],
[0.0, 0.0, 1.0, 0.0, 0.0, 1.0, 1.0, 1.0],
[0.0, 0.0, 0.0, 1.0, 0.0, 1.0, 1.0, 0.0],
[1.0, 0.0, 1.0, 0.0, 0.0, 1.0, 0.0, 1.0],
[1.0, 1.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0],
[1.0, 1.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0],
[1.0, 1.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0],
[0.0, 1.0, 1.0, 1.0, 0.0, 0.0, 0.0, 0.0],
[0.0, 0.0, 1.0, 1.0, 1.0, 0.0, 0.0, 0.0],
[0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 0.0, 0.0],
[0.0, 0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 0.0],
[0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 1.0, 1.0],
]
)

assert (
expected_incidence_1 == lifted_data_k1.incidence_hyperedges.to_dense()
expected_incidence_1 == lifted_data_k.incidence_hyperedges.to_dense()
).all(), "Something is wrong with incidence_hyperedges (k=1)."
assert (
expected_n_hyperedges == lifted_data_k1.num_hyperedges
expected_n_hyperedges == lifted_data_k.num_hyperedges
), "Something is wrong with the number of hyperedges (k=1)."

lifted_data_k2 = self.lifting_k2.forward(self.data.clone())

expected_n_hyperedges = 8

expected_incidence_1 = torch.tensor(
[
[1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0, 1.0],
[1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0, 1.0],
[1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0],
[1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0],
[1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0, 1.0],
[1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0],
[0.0, 0.0, 1.0, 1.0, 0.0, 1.0, 1.0, 1.0],
[1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0],
]
)

assert (
expected_incidence_1 == lifted_data_k2.incidence_hyperedges.to_dense()
).all(), "Something is wrong with incidence_hyperedges (k=2)."
assert (
expected_n_hyperedges == lifted_data_k2.num_hyperedges
), "Something is wrong with the number of hyperedges (k=2)."
13 changes: 3 additions & 10 deletions tutorials/graph2hypergraph/knn_lifting.ipynb

Large diffs are not rendered by default.

0 comments on commit 09ffb92

Please sign in to comment.