diff --git a/modules/data/utils/utils.py b/modules/data/utils/utils.py index ce6e922..5880c35 100755 --- a/modules/data/utils/utils.py +++ b/modules/data/utils/utils.py @@ -401,12 +401,8 @@ def load_pointcloud_dataset(cfg): file_path=osp.join(data_dir, "stanford_bunny.npy"), accept_license=False, ) - - num_points = cfg["num_points"] if "num_points" in cfg else len(pos) pos = torch.tensor(pos) - pos = pos[np.random.choice(pos.shape[0], num_points, replace=False)] - return CustomDataset( [ torch_geometric.data.Data( @@ -420,8 +416,10 @@ def load_pointcloud_dataset(cfg): def annulus_2d(D, N, R1=0.8, R2=1, A=0): n = 0 P = np.array([[0.0] * D] * N) + + rng = np.random.default_rng() while n < N: - p = np.random.uniform(-R2, R2, D) + p = rng.uniform(-R2, R2, D) if np.linalg.norm(p) > R2 or np.linalg.norm(p) < R1: continue if (p[0] > 0) and (np.abs(p[1]) < A / 2): diff --git a/modules/transforms/liftings/pointcloud2graph/cover_lifting.py b/modules/transforms/liftings/pointcloud2graph/cover_lifting.py index 03074fa..5bbaf39 100644 --- a/modules/transforms/liftings/pointcloud2graph/cover_lifting.py +++ b/modules/transforms/liftings/pointcloud2graph/cover_lifting.py @@ -6,34 +6,35 @@ import statsmodels.stats.multitest as mt import torch import torch_geometric -from gudhi import cover_complex from statsmodels.distributions.empirical_distribution import ECDF from torch_geometric.utils.convert import from_networkx from modules.transforms.liftings.pointcloud2graph.base import PointCloud2GraphLifting +rng = np.random.default_rng() -def persistent_homology(points: torch.Tensor, subcomplex_inds: list[int] = None): + +def persistent_homology(points: torch.Tensor, subcomplex_inds=None): st = gudhi.AlphaComplex(points=points).create_simplex_tree() if subcomplex_inds is not None: - subcomplex = [] - for simplex in st.get_simplices(): - if all(x in subcomplex_inds for x in simplex[0]): - subcomplex.append(simplex[0]) + subcomplex = [ + simplex + for simplex, _ in st.get_simplices() + if all(x in subcomplex_inds for x in simplex) + ] new_vertex = st.num_vertices() st.insert([new_vertex], 0) for simplex in subcomplex: - st.insert(simplex + [new_vertex], st.filtration(simplex)) + st.insert([*simplex, new_vertex], st.filtration(simplex)) persistence = st.persistence() - diagram = np.array( + + return np.array( [(birth, death) for (dim, (birth, death)) in persistence if dim == 1] ) - return diagram - def transform(diagram): b, d = diagram[:, 0], diagram[:, 1] @@ -43,7 +44,7 @@ def transform(diagram): def get_empirical_distribution(dim: int): """Generates empirical distribution of pi values for random pointcloud in R^{dim}""" - random_pc = np.random.uniform(size=(10000, dim)) + random_pc = rng.uniform(size=(10000, dim)) dgm_rand = persistent_homology(random_pc) return ECDF(transform(dgm_rand)) @@ -55,9 +56,7 @@ def test_weak_universality(emp_cdf: ECDF, diagram, alpha: float = 0.05): def sample_points(points: torch.Tensor, n=300): - return points[ - np.random.choice(points.shape[0], min(n, points.shape[0]), replace=False) - ] + return points[rng.choice(points.shape[0], min(n, points.shape[0]), replace=False)] class CoverLifting(PointCloud2GraphLifting): @@ -72,7 +71,6 @@ class CoverLifting(PointCloud2GraphLifting): def __init__( self, ambient_dim: int = 2, - cover_complex: gudhi.cover_complex.CoverComplex = None, **kwargs, ): super().__init__(**kwargs) diff --git a/test/transforms/liftings/pointcloud2graph/test_cover_lifting.py b/test/transforms/liftings/pointcloud2graph/test_cover_lifting.py index d8f28af..d1b2d66 100644 --- a/test/transforms/liftings/pointcloud2graph/test_cover_lifting.py +++ b/test/transforms/liftings/pointcloud2graph/test_cover_lifting.py @@ -1,7 +1,6 @@ """Test the message passing module.""" import networkx as nx -import torch from torch_geometric.utils.convert import to_networkx from modules.data.utils.utils import load_annulus @@ -21,12 +20,7 @@ def setup_method(self): def test_lift_topology(self): """Test the lift_topology method.""" # Test the lift_topology method - lifted_dataset = self.lifting(self.data) + lifted_data = self.lifting(self.data) - # g = nx.Graph() - # us, vs = lifted_dataset["edge_index"] - - # for u, v in zip(us, vs): - # g.add_edge(u, v) - - # nx.cycles.find_cycle(g) + g = to_networkx(lifted_data, to_undirected=True) + nx.find_cycle(g) diff --git a/tutorials/pointcloud2graph/cover_lifting.ipynb b/tutorials/pointcloud2graph/cover_lifting.ipynb index 06eb164..86ba9c6 100644 --- a/tutorials/pointcloud2graph/cover_lifting.ipynb +++ b/tutorials/pointcloud2graph/cover_lifting.ipynb @@ -45,7 +45,6 @@ "from modules.utils.utils import (\n", " describe_data,\n", " load_dataset_config,\n", - " load_model_config,\n", " load_transform_config,\n", ")" ] @@ -257,8 +256,6 @@ "metadata": {}, "outputs": [], "source": [ - "import torch\n", - "from torch_geometric.nn import global_mean_pool\n", "from torch_geometric.nn.models import GraphSAGE" ] },