Skip to content

Commit

Permalink
Merge pull request #280 from pyt-team/frantzen/toponetx-imports
Browse files Browse the repository at this point in the history
Adapt to new TopoNetX import convention
  • Loading branch information
ffl096 authored Oct 21, 2024
2 parents 8881c71 + 27b0fc9 commit 106db16
Show file tree
Hide file tree
Showing 29 changed files with 83 additions and 80 deletions.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ from topomodelx.nn.simplicial.san import SAN
from topomodelx.utils.sparse import from_sparse

# Step 1: Load the Karate Club dataset
dataset = tnx.karate_club(complex_type="simplicial")
dataset = tnx.datasets.karate_club(complex_type="simplicial")

# Step 2: Prepare Laplacians and node/edge features
laplacian_down = from_sparse(dataset.down_laplacian_matrix(rank=1))
Expand Down
5 changes: 3 additions & 2 deletions test/nn/simplicial/test_dist2cycle.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
"""Unit tests for Dist2Cycke Model."""

import numpy as np
import toponetx as tnx
import torch
from toponetx.classes import SimplicialComplex

from topomodelx.nn.simplicial.dist2cycle import Dist2Cycle

Expand All @@ -15,7 +16,7 @@ def test_forward(self):
face_set = [[2, 3, 4], [2, 4, 5]]

torch.manual_seed(42)
simplicial_complex = SimplicialComplex(edge_set + face_set)
simplicial_complex = tnx.SimplicialComplex(edge_set + face_set)
laplacian_down_1 = simplicial_complex.down_laplacian_matrix(rank=1).todense()
adjacency_1 = simplicial_complex.adjacency_matrix(rank=1).todense()
laplacian_down_1_inv = np.linalg.pinv(
Expand Down
5 changes: 3 additions & 2 deletions test/nn/simplicial/test_hsn.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
"""Unit tests for HSN Model."""

import numpy as np
import toponetx as tnx
import torch
from toponetx.classes import SimplicialComplex

from topomodelx.nn.simplicial.hsn import HSN

Expand All @@ -15,7 +16,7 @@ def test_forward(self):
face_set = [[2, 3, 4], [2, 4, 5]]

torch.manual_seed(42)
simplicial_complex = SimplicialComplex(edge_set + face_set)
simplicial_complex = tnx.SimplicialComplex(edge_set + face_set)
laplacian_down_1 = simplicial_complex.down_laplacian_matrix(rank=1).todense()
adjacency_1 = simplicial_complex.adjacency_matrix(rank=1).todense()
laplacian_down_1_inv = np.linalg.pinv(
Expand Down
5 changes: 3 additions & 2 deletions test/nn/simplicial/test_san.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
"""Unit tests for SAN Model."""

import itertools
import random

import toponetx as tnx
import torch
from toponetx.classes import SimplicialComplex

from topomodelx.nn.simplicial.san import SAN
from topomodelx.utils.sparse import from_sparse
Expand All @@ -28,7 +29,7 @@ def test_forward(self):
)
random.shuffle(all_combinations)
selected_combinations = all_combinations[:faces]
simplicial_complex = SimplicialComplex()
simplicial_complex = tnx.SimplicialComplex()
for simplex in selected_combinations:
simplicial_complex.add_simplex(simplex)
x_1 = torch.randn(35, 2)
Expand Down
5 changes: 3 additions & 2 deletions test/nn/simplicial/test_sca_cmps.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
"""Unit tests for SCA Model."""

import itertools
import random

import toponetx as tnx
import torch
from toponetx.classes import SimplicialComplex

from topomodelx.nn.simplicial.sca_cmps import SCACMPS
from topomodelx.utils.sparse import from_sparse
Expand All @@ -28,7 +29,7 @@ def test_forward(self):
)
random.shuffle(all_combinations)
selected_combinations = all_combinations[:faces]
simplicial_complex = SimplicialComplex()
simplicial_complex = tnx.SimplicialComplex()
for simplex in selected_combinations:
simplicial_complex.add_simplex(simplex)

Expand Down
5 changes: 3 additions & 2 deletions test/nn/simplicial/test_sccn.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
"""Unit tests for SCCN Model."""

import itertools
import random

import toponetx as tnx
import torch
from toponetx.classes import SimplicialComplex

from topomodelx.nn.simplicial.sccn import SCCN
from topomodelx.utils.sparse import from_sparse
Expand All @@ -28,7 +29,7 @@ def test_forward(self):
)
random.shuffle(all_combinations)
selected_combinations = all_combinations[:faces]
simplicial_complex = SimplicialComplex()
simplicial_complex = tnx.SimplicialComplex()
for simplex in selected_combinations:
simplicial_complex.add_simplex(simplex)

Expand Down
5 changes: 3 additions & 2 deletions test/nn/simplicial/test_sccnn.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
"""Unit tests for SCCNN Model."""

import itertools
import random

import toponetx as tnx
import torch
from toponetx.classes import SimplicialComplex

from topomodelx.nn.simplicial.sccnn import SCCNN
from topomodelx.utils.sparse import from_sparse
Expand All @@ -28,7 +29,7 @@ def test_forward(self):
)
random.shuffle(all_combinations)
selected_combinations = all_combinations[:faces]
simplicial_complex = SimplicialComplex()
simplicial_complex = tnx.SimplicialComplex()
for simplex in selected_combinations:
simplicial_complex.add_simplex(simplex)
# Some nodes might not be selected at all in the combinations above
Expand Down
5 changes: 3 additions & 2 deletions test/nn/simplicial/test_scconv.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
"""Unit tests for SCCNN Model."""

import itertools
import random

import toponetx as tnx
import torch
from toponetx.classes import SimplicialComplex

from topomodelx.nn.simplicial.scconv import SCConv
from topomodelx.utils.sparse import from_sparse
Expand All @@ -28,7 +29,7 @@ def test_forward(self):
)
random.shuffle(all_combinations)
selected_combinations = all_combinations[:faces]
simplicial_complex = SimplicialComplex()
simplicial_complex = tnx.SimplicialComplex()
for simplex in selected_combinations:
simplicial_complex.add_simplex(simplex)
# Some nodes might not be selected at all in the combinations above
Expand Down
5 changes: 3 additions & 2 deletions test/nn/simplicial/test_scn2.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
"""Unit tests for SCN2 Model."""

import itertools
import random

import toponetx as tnx
import torch
from toponetx.classes import SimplicialComplex

from topomodelx.nn.simplicial.scn2 import SCN2
from topomodelx.utils.sparse import from_sparse
Expand All @@ -28,7 +29,7 @@ def test_forward(self):
)
random.shuffle(all_combinations)
selected_combinations = all_combinations[:faces]
simplicial_complex = SimplicialComplex()
simplicial_complex = tnx.SimplicialComplex()
for simplex in selected_combinations:
simplicial_complex.add_simplex(simplex)

Expand Down
5 changes: 3 additions & 2 deletions test/nn/simplicial/test_scnn.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
"""Unit tests for SCNN Model."""

import itertools
import random

import numpy as np
import toponetx as tnx
import torch
from toponetx.classes import SimplicialComplex

from topomodelx.nn.simplicial.scnn import SCNN
from topomodelx.utils.sparse import from_sparse
Expand All @@ -29,7 +30,7 @@ def test_forward(self):
)
random.shuffle(all_combinations)
selected_combinations = all_combinations[:faces]
simplicial_complex = SimplicialComplex()
simplicial_complex = tnx.SimplicialComplex()
for simplex in selected_combinations:
simplicial_complex.add_simplex(simplex)
x_1 = torch.randn(simplicial_complex.shape[1], 2)
Expand Down
13 changes: 8 additions & 5 deletions topomodelx/nn/simplicial/scone.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
"""Neural network implementation of classification using SCoNe."""

import random
from itertools import product

import networkx as nx
import numpy as np
import toponetx as tnx
import torch
from scipy.spatial import Delaunay, distance
from toponetx.classes.simplicial_complex import SimplicialComplex
from torch import nn
from torch.utils.data.dataset import Dataset

Expand All @@ -15,7 +16,7 @@

def generate_complex(
N: int = 100, *, rng: np.random.Generator | None = None
) -> tuple[SimplicialComplex, np.ndarray]:
) -> tuple[tnx.SimplicialComplex, np.ndarray]:
"""Generate a simplicial complex as described.
Generate a simplicial complex of dimension 2 as follows:
Expand Down Expand Up @@ -58,13 +59,13 @@ def generate_complex(
for j in range(3):
simplices[i][j] = idx_dict[simplices[i][j]]

sc = SimplicialComplex(simplices)
sc = tnx.SimplicialComplex(simplices)
coords = points[list(indices_included)]
return sc, coords


def generate_trajectories(
sc: SimplicialComplex, coords: np.ndarray, n_max: int = 1000
sc: tnx.SimplicialComplex, coords: np.ndarray, n_max: int = 1000
) -> list[list[int]]:
"""Generate trajectories from nodes in the lower left corner to the upper right corner connected through a node in the middle."""
# Get indices for start points in the lower left corner, mid points in the center region and end points in the upper right corner.
Expand Down Expand Up @@ -98,7 +99,9 @@ def generate_trajectories(
class TrajectoriesDataset(Dataset):
"""Create a dataset of trajectories."""

def __init__(self, sc: SimplicialComplex, trajectories: list[list[int]]) -> None:
def __init__(
self, sc: tnx.SimplicialComplex, trajectories: list[list[int]]
) -> None:
self.trajectories = trajectories
self.sc = sc
self.adjacency = torch.Tensor(sc.adjacency_matrix(0).toarray())
Expand Down
4 changes: 2 additions & 2 deletions tutorials/cell/can_train.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -126,9 +126,9 @@
],
"source": [
"import numpy as np\n",
"import toponetx as tnx\n",
"import torch\n",
"from sklearn.model_selection import train_test_split\n",
"from toponetx.classes.cell_complex import CellComplex\n",
"from torch_geometric.datasets import TUDataset\n",
"from torch_geometric.utils.convert import to_networkx\n",
"\n",
Expand Down Expand Up @@ -218,7 +218,7 @@
"x_1_list = []\n",
"y_list = []\n",
"for graph in dataset:\n",
" cell_complex = CellComplex(to_networkx(graph))\n",
" cell_complex = tnx.CellComplex(to_networkx(graph))\n",
" cc_list.append(cell_complex)\n",
" x_0_list.append(graph.x)\n",
" x_1_list.append(graph.edge_attr)\n",
Expand Down
4 changes: 2 additions & 2 deletions tutorials/cell/ccxn_train.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@
"outputs": [],
"source": [
"import numpy as np\n",
"import toponetx.datasets as datasets\n",
"import toponetx as tnx\n",
"import torch\n",
"from sklearn.model_selection import train_test_split\n",
"\n",
Expand Down Expand Up @@ -132,7 +132,7 @@
}
],
"source": [
"shrec, _ = datasets.mesh.shrec_16(size=\"small\")\n",
"shrec, _ = tnx.datasets.shrec_16(size=\"small\")\n",
"\n",
"shrec = {key: np.array(value) for key, value in shrec.items()}\n",
"x_0s = shrec[\"node_feat\"]\n",
Expand Down
4 changes: 2 additions & 2 deletions tutorials/cell/cwn_train.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@
],
"source": [
"import numpy as np\n",
"import toponetx.datasets as datasets\n",
"import toponetx as tnx\n",
"import torch\n",
"from sklearn.model_selection import train_test_split\n",
"\n",
Expand Down Expand Up @@ -152,7 +152,7 @@
}
],
"source": [
"shrec, _ = datasets.mesh.shrec_16(size=\"small\")\n",
"shrec, _ = tnx.datasets.shrec_16(size=\"small\")\n",
"\n",
"shrec = {key: np.array(value) for key, value in shrec.items()}\n",
"x_0s = shrec[\"node_feat\"]\n",
Expand Down
10 changes: 5 additions & 5 deletions tutorials/combinatorial/hmc_train.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -133,8 +133,8 @@
"outputs": [],
"source": [
"import numpy as np\n",
"import toponetx as tnx\n",
"import torch\n",
"from toponetx.datasets.mesh import shrec_16\n",
"from torch.utils.data import DataLoader, Dataset\n",
"\n",
"from topomodelx.nn.combinatorial.hmc import HMC"
Expand Down Expand Up @@ -338,7 +338,7 @@
}
],
"source": [
"shrec_training, shrec_testing = shrec_16()"
"shrec_training, shrec_testing = tnx.datasets.shrec_16()"
]
},
{
Expand Down Expand Up @@ -839,9 +839,9 @@
],
"metadata": {
"kernelspec": {
"display_name": "venv_modelx",
"display_name": "venv",
"language": "python",
"name": "venv_modelx"
"name": "python3"
},
"language_info": {
"codemirror_mode": {
Expand All @@ -853,7 +853,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.3"
"version": "3.11.8"
}
},
"nbformat": 4,
Expand Down
4 changes: 2 additions & 2 deletions tutorials/hypergraph/dhgcn_train.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
"outputs": [],
"source": [
"import numpy as np\n",
"import toponetx.datasets as datasets\n",
"import toponetx as tnx\n",
"import torch\n",
"from sklearn.model_selection import train_test_split\n",
"\n",
Expand Down Expand Up @@ -98,7 +98,7 @@
}
],
"source": [
"shrec, _ = datasets.mesh.shrec_16(size=\"small\")\n",
"shrec, _ = tnx.datasets.mesh.shrec_16(size=\"small\")\n",
"\n",
"shrec = {key: np.array(value) for key, value in shrec.items()}\n",
"x_0s = shrec[\"node_feat\"]\n",
Expand Down
4 changes: 2 additions & 2 deletions tutorials/hypergraph/hypergat_train.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@
"outputs": [],
"source": [
"import numpy as np\n",
"import toponetx.datasets as datasets\n",
"import toponetx as tnx\n",
"import torch\n",
"from sklearn.model_selection import train_test_split\n",
"\n",
Expand Down Expand Up @@ -112,7 +112,7 @@
}
],
"source": [
"shrec, _ = datasets.mesh.shrec_16(size=\"small\")\n",
"shrec, _ = tnx.datasets.shrec_16(size=\"small\")\n",
"\n",
"shrec = {key: np.array(value) for key, value in shrec.items()}\n",
"x_0s = shrec[\"node_feat\"]\n",
Expand Down
4 changes: 2 additions & 2 deletions tutorials/hypergraph/unigcn_train.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,9 @@
"metadata": {},
"outputs": [],
"source": [
"import toponetx as tnx\n",
"import torch\n",
"from sklearn.model_selection import train_test_split\n",
"from toponetx.classes.simplicial_complex import SimplicialComplex\n",
"from torch_geometric.datasets import TUDataset\n",
"from torch_geometric.utils.convert import to_networkx\n",
"\n",
Expand Down Expand Up @@ -83,7 +83,7 @@
"x_1_list = []\n",
"y_list = []\n",
"for graph in dataset:\n",
" hg = SimplicialComplex(to_networkx(graph)).to_hypergraph()\n",
" hg = tnx.SimplicialComplex(to_networkx(graph)).to_hypergraph()\n",
" hg_list.append(hg)\n",
" x_1_list.append(graph.x.to(device))\n",
" y_list.append(graph.y.to(device))\n",
Expand Down
Loading

0 comments on commit 106db16

Please sign in to comment.