Skip to content

Commit

Permalink
Fix ruff errors
Browse files Browse the repository at this point in the history
  • Loading branch information
pzajec committed Jul 13, 2024
1 parent 6265cf4 commit a028c81
Show file tree
Hide file tree
Showing 4 changed files with 19 additions and 32 deletions.
8 changes: 3 additions & 5 deletions modules/data/utils/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -401,12 +401,8 @@ def load_pointcloud_dataset(cfg):
file_path=osp.join(data_dir, "stanford_bunny.npy"),
accept_license=False,
)

num_points = cfg["num_points"] if "num_points" in cfg else len(pos)
pos = torch.tensor(pos)

pos = pos[np.random.choice(pos.shape[0], num_points, replace=False)]

return CustomDataset(
[
torch_geometric.data.Data(
Expand All @@ -420,8 +416,10 @@ def load_pointcloud_dataset(cfg):
def annulus_2d(D, N, R1=0.8, R2=1, A=0):
n = 0
P = np.array([[0.0] * D] * N)

rng = np.random.default_rng()
while n < N:
p = np.random.uniform(-R2, R2, D)
p = rng.uniform(-R2, R2, D)
if np.linalg.norm(p) > R2 or np.linalg.norm(p) < R1:
continue
if (p[0] > 0) and (np.abs(p[1]) < A / 2):
Expand Down
28 changes: 13 additions & 15 deletions modules/transforms/liftings/pointcloud2graph/cover_lifting.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,34 +6,35 @@
import statsmodels.stats.multitest as mt
import torch
import torch_geometric
from gudhi import cover_complex
from statsmodels.distributions.empirical_distribution import ECDF
from torch_geometric.utils.convert import from_networkx

from modules.transforms.liftings.pointcloud2graph.base import PointCloud2GraphLifting

rng = np.random.default_rng()

def persistent_homology(points: torch.Tensor, subcomplex_inds: list[int] = None):

def persistent_homology(points: torch.Tensor, subcomplex_inds=None):
st = gudhi.AlphaComplex(points=points).create_simplex_tree()

if subcomplex_inds is not None:
subcomplex = []
for simplex in st.get_simplices():
if all(x in subcomplex_inds for x in simplex[0]):
subcomplex.append(simplex[0])
subcomplex = [
simplex
for simplex, _ in st.get_simplices()
if all(x in subcomplex_inds for x in simplex)
]

new_vertex = st.num_vertices()
st.insert([new_vertex], 0)
for simplex in subcomplex:
st.insert(simplex + [new_vertex], st.filtration(simplex))
st.insert([*simplex, new_vertex], st.filtration(simplex))

persistence = st.persistence()
diagram = np.array(

return np.array(
[(birth, death) for (dim, (birth, death)) in persistence if dim == 1]
)

return diagram


def transform(diagram):
b, d = diagram[:, 0], diagram[:, 1]
Expand All @@ -43,7 +44,7 @@ def transform(diagram):

def get_empirical_distribution(dim: int):
"""Generates empirical distribution of pi values for random pointcloud in R^{dim}"""
random_pc = np.random.uniform(size=(10000, dim))
random_pc = rng.uniform(size=(10000, dim))
dgm_rand = persistent_homology(random_pc)
return ECDF(transform(dgm_rand))

Expand All @@ -55,9 +56,7 @@ def test_weak_universality(emp_cdf: ECDF, diagram, alpha: float = 0.05):


def sample_points(points: torch.Tensor, n=300):
return points[
np.random.choice(points.shape[0], min(n, points.shape[0]), replace=False)
]
return points[rng.choice(points.shape[0], min(n, points.shape[0]), replace=False)]


class CoverLifting(PointCloud2GraphLifting):
Expand All @@ -72,7 +71,6 @@ class CoverLifting(PointCloud2GraphLifting):
def __init__(
self,
ambient_dim: int = 2,
cover_complex: gudhi.cover_complex.CoverComplex = None,
**kwargs,
):
super().__init__(**kwargs)
Expand Down
12 changes: 3 additions & 9 deletions test/transforms/liftings/pointcloud2graph/test_cover_lifting.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
"""Test the message passing module."""

import networkx as nx
import torch
from torch_geometric.utils.convert import to_networkx

from modules.data.utils.utils import load_annulus
Expand All @@ -21,12 +20,7 @@ def setup_method(self):
def test_lift_topology(self):
"""Test the lift_topology method."""
# Test the lift_topology method
lifted_dataset = self.lifting(self.data)
lifted_data = self.lifting(self.data)

# g = nx.Graph()
# us, vs = lifted_dataset["edge_index"]

# for u, v in zip(us, vs):
# g.add_edge(u, v)

# nx.cycles.find_cycle(g)
g = to_networkx(lifted_data, to_undirected=True)
nx.find_cycle(g)
3 changes: 0 additions & 3 deletions tutorials/pointcloud2graph/cover_lifting.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,6 @@
"from modules.utils.utils import (\n",
" describe_data,\n",
" load_dataset_config,\n",
" load_model_config,\n",
" load_transform_config,\n",
")"
]
Expand Down Expand Up @@ -257,8 +256,6 @@
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"from torch_geometric.nn import global_mean_pool\n",
"from torch_geometric.nn.models import GraphSAGE"
]
},
Expand Down

0 comments on commit a028c81

Please sign in to comment.