Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feature-Based Rips Complex (Point Cloud to Simplicial Complex) #35

Open
wants to merge 15 commits into
base: main
Choose a base branch
from
9 changes: 9 additions & 0 deletions configs/datasets/gudhi_bunny.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
data_domain: pointcloud
data_type: gudhi
data_name: gudhi_bunny
data_dir: datasets/${data_domain}/${data_type}

# Dataset parameters
task: regression
loss_type: mse
monitor_metric: mae
9 changes: 9 additions & 0 deletions configs/datasets/gudhi_daily_activities.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
data_domain: pointcloud
data_type: gudhi
data_name: gudhi_daily_activities
data_dir: datasets/${data_domain}/${data_type}

# Dataset parameters
task: regression
loss_type: mse
monitor_metric: mae
12 changes: 12 additions & 0 deletions configs/datasets/gudhi_sphere.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
data_domain: pointcloud
data_type: toy_dataset
data_name: gudhi_sphere
data_dir: datasets/${data_domain}/${data_type}

# Dataset parameters
ambient_dim: 3
sample: random # can also be 'grid'
n_samples: 1000
task: regression
loss_type: mse
monitor_metric: mae
9 changes: 9 additions & 0 deletions configs/datasets/gudhi_spiral_2d.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
data_domain: pointcloud
data_type: toy_dataset
data_name: gudhi_spiral_2d
data_dir: datasets/${data_domain}/${data_type}

# Dataset parameters
task: regression
loss_type: mse
monitor_metric: mae
12 changes: 12 additions & 0 deletions configs/datasets/gudhi_torus.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
data_domain: pointcloud
data_type: toy_dataset
data_name: gudhi_torus
data_dir: datasets/${data_domain}/${data_type}

# Dataset parameters
dim: 3 # The dimension of the *torus* - the ambient space has dimension 2 * dim
sample: random # can also be 'grid'
n_samples: 1000
task: regression
loss_type: mse
monitor_metric: mae
12 changes: 12 additions & 0 deletions configs/datasets/manual_points.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
data_domain: pointcloud
data_type: toy_dataset
data_name: manual_points
data_dir: datasets/${data_domain}/${data_type}

# Dataset parameters
dim: 2
num_classes: 2
num_samples: 7
num_features: 2
task: classification
loss_type: cross_entropy
12 changes: 12 additions & 0 deletions configs/datasets/random_points.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
data_domain: pointcloud
data_type: toy_dataset
data_name: random_points
data_dir: datasets/${data_domain}/${data_type}

# Dataset parameters
dim: 3
num_classes: 2
num_samples: 1000
num_features: 1
task: classification
loss_type: cross_entropy
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
transform_type: 'lifting'
transform_name: "FeatureRipsComplexLifting"
complex_dim: 3
feature_percent: 0.2
max_edge_length: 10.0
sparse: null
feature_lifting: ProjectionSum
68 changes: 68 additions & 0 deletions modules/data/load/loaders.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
import os
from collections.abc import Callable

import numpy as np
import rootutils
import torch
import torch_geometric
from omegaconf import DictConfig

Expand All @@ -10,8 +12,11 @@
from modules.data.utils.custom_dataset import CustomDataset
from modules.data.utils.utils import (
load_cell_complex_dataset,
load_gudhi_dataset,
load_hypergraph_pickle_dataset,
load_manual_graph,
load_manual_points,
load_random_points,
load_simplicial_dataset,
)

Expand Down Expand Up @@ -204,3 +209,66 @@ def load(
torch_geometric.data.Dataset object containing the loaded data.
"""
return load_hypergraph_pickle_dataset(self.parameters)


class PointCloudLoader(AbstractLoader):
    r"""Loader for point cloud datasets.

    Dispatches on ``parameters.data_name``: names prefixed with ``gudhi_``
    are fetched through :func:`load_gudhi_dataset`, while ``random_points``
    and ``manual_points`` are generated locally.

    Parameters
    ----------
    parameters : DictConfig
        Configuration parameters.
    feature_generator: Optional[Callable[[torch.Tensor], torch.Tensor]]
        Function to generate the dataset features. If None, no features added.
    target_generator: Optional[Callable[[torch.Tensor], torch.Tensor]]
        Function to generate the target variable. If None, no target added.
    """

    def __init__(
        self,
        parameters: DictConfig,
        feature_generator: Callable[[torch.Tensor], torch.Tensor] | None = None,
        target_generator: Callable[[torch.Tensor], torch.Tensor] | None = None,
    ):
        # Store the generators before the base-class initializer runs, in
        # case the base class touches loading machinery during construction.
        self.feature_generator = feature_generator
        self.target_generator = target_generator
        super().__init__(parameters)
        self.parameters = parameters

    def load(self) -> torch_geometric.data.Dataset:
        r"""Load point cloud dataset.

        Parameters
        ----------
        None

        Returns
        -------
        torch_geometric.data.Dataset
            torch_geometric.data.Dataset object containing the loaded data.
        """
        # Resolve <project root>/<data_dir>/<data_name> as the dataset folder.
        project_root = rootutils.find_root()
        base_dir = os.path.join(project_root, self.parameters["data_dir"])
        self.data_dir = os.path.join(base_dir, self.parameters["data_name"])

        dataset_name = self.parameters.data_name
        if dataset_name.startswith("gudhi_"):
            data = load_gudhi_dataset(
                self.parameters,
                feature_generator=self.feature_generator,
                target_generator=self.target_generator,
            )
        elif dataset_name == "random_points":
            data = load_random_points(
                dim=self.parameters["dim"],
                num_classes=self.parameters["num_classes"],
                num_samples=self.parameters["num_samples"],
            )
        elif dataset_name == "manual_points":
            data = load_manual_points()
        else:
            raise NotImplementedError(
                f"Dataset {self.parameters.data_name} not implemented"
            )

        return CustomDataset([data], self.data_dir)
125 changes: 117 additions & 8 deletions modules/data/utils/utils.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,23 @@
import hashlib
import os.path as osp
import pickle
from collections.abc import Callable

import networkx as nx
import numpy as np
import omegaconf
import rootutils
import toponetx.datasets.graph as graph
import torch
import torch_geometric
from gudhi.datasets.generators import points
from gudhi.datasets.remote import fetch_bunny, fetch_daily_activities, fetch_spiral_2d
from topomodelx.utils.sparse import from_sparse
from torch_geometric.data import Data
from torch_sparse import coalesce

rootutils.setup_root("./", indicator=".project-root", pythonpath=True)


def get_complex_connectivity(complex, max_rank, signed=False):
r"""Gets the connectivity matrices for the complex.
Expand Down Expand Up @@ -50,16 +56,16 @@ def get_complex_connectivity(complex, max_rank, signed=False):
)
except ValueError: # noqa: PERF203
if connectivity_info == "incidence":
connectivity[f"{connectivity_info}_{rank_idx}"] = (
generate_zero_sparse_connectivity(
m=practical_shape[rank_idx - 1], n=practical_shape[rank_idx]
)
connectivity[
f"{connectivity_info}_{rank_idx}"
] = generate_zero_sparse_connectivity(
m=practical_shape[rank_idx - 1], n=practical_shape[rank_idx]
)
else:
connectivity[f"{connectivity_info}_{rank_idx}"] = (
generate_zero_sparse_connectivity(
m=practical_shape[rank_idx], n=practical_shape[rank_idx]
)
connectivity[
f"{connectivity_info}_{rank_idx}"
] = generate_zero_sparse_connectivity(
m=practical_shape[rank_idx], n=practical_shape[rank_idx]
)
connectivity["shape"] = practical_shape
return connectivity
Expand Down Expand Up @@ -421,3 +427,106 @@ def make_hash(o):
hash_as_hex = sha1.hexdigest()
# Convert the hex back to int and restrict it to the relevant int range
return int(hash_as_hex, 16) % 4294967295


def load_gudhi_dataset(
    cfg: omegaconf.DictConfig,
    feature_generator: Callable[[torch.Tensor], torch.Tensor] | None = None,
    target_generator: Callable[[torch.Tensor], torch.Tensor] | None = None,
) -> torch_geometric.data.Data:
    """Load a dataset from the gudhi.datasets module.

    The dataset is selected by ``cfg.data_name`` (minus its ``gudhi_``
    prefix). Optional generators derive node features / targets from the
    point positions; they must keep the number of points unchanged.
    """
    if not cfg.data_name.startswith("gudhi_"):
        raise ValueError("This function should only be used with gudhi datasets")

    gudhi_dataset_name = cfg.data_name.removeprefix("gudhi_")

    def _cache_path(*parts: str) -> str:
        # Remote gudhi datasets are cached under <root>/<data_dir>/...
        return osp.join(rootutils.find_root(), cfg["data_dir"], *parts)

    activity_labels = None  # only populated for the daily_activities dataset
    if gudhi_dataset_name == "sphere":
        points_data = points.sphere(
            n_samples=cfg["n_samples"],
            ambient_dim=cfg["ambient_dim"],
            sample=cfg["sample"],
        )
    elif gudhi_dataset_name == "torus":
        points_data = points.torus(
            n_samples=cfg["n_samples"], dim=cfg["dim"], sample=cfg["sample"]
        )
    elif gudhi_dataset_name == "bunny":
        points_data = fetch_bunny(
            file_path=_cache_path("bunny", "bunny.npy"),
            accept_license=cfg.get("accept_license", False),
        )
    elif gudhi_dataset_name == "spiral_2d":
        points_data = fetch_spiral_2d(
            file_path=_cache_path("spiral_2d", "spiral_2d.npy")
        )
    elif gudhi_dataset_name == "daily_activities":
        raw = fetch_daily_activities(
            file_path=_cache_path("activities", "activities.npy")
        )
        # Columns 0-2 are spatial coordinates; the trailing column is the
        # activity code: 14. 'cross_training', 18. 'jumping', 13. 'stepper',
        # or 9. 'walking'.
        points_data = raw[:, :3]
        activity_labels = torch.tensor(raw[:, 3:], dtype=torch.float)
    else:
        raise ValueError(f"Gudhi dataset {gudhi_dataset_name} not recognized.")

    pos = torch.tensor(points_data, dtype=torch.float)

    x = feature_generator(pos) if feature_generator else None
    if x is not None and x.shape[0] != pos.shape[0]:
        raise ValueError(
            "feature_generator must not change first dimension of points data."
        )

    if target_generator:
        y = target_generator(pos)
        if y.shape[0] != pos.shape[0]:
            raise ValueError(
                "target_generator must not change first dimension of points data."
            )
    else:
        # Fall back to the built-in activity labels when available.
        y = activity_labels

    return torch_geometric.data.Data(x=x, y=y, pos=pos, complex_dim=0)


def load_random_points(
    dim: int,
    num_classes: int,
    num_samples: int,
    seed: int = 42,
    num_features: int = 1,
) -> torch_geometric.data.Data:
    """Create a random point cloud dataset.

    Parameters
    ----------
    dim : int
        Ambient dimension of the generated points.
    num_classes : int
        Number of distinct class labels drawn uniformly for ``y``.
    num_samples : int
        Number of points to generate.
    seed : int, optional
        Seed for the NumPy RNG, so the dataset is reproducible.
    num_features : int, optional
        Number of binary node features per point (default 1, matching the
        previous hard-coded behavior).

    Returns
    -------
    torch_geometric.data.Data
        Data object with ``pos`` (num_samples, dim), binary ``x``
        (num_samples, num_features), integer labels ``y`` and
        ``complex_dim=0``.
    """
    rng = np.random.default_rng(seed)

    # Draw positions, labels, then features — in this exact order — so the
    # default arguments reproduce the same samples as before.
    points = torch.tensor(rng.random((num_samples, dim)), dtype=torch.float)
    classes = torch.tensor(
        rng.integers(num_classes, size=num_samples), dtype=torch.long
    )
    features = torch.tensor(
        rng.integers(2, size=(num_samples, num_features)), dtype=torch.float
    )

    return torch_geometric.data.Data(x=features, y=classes, pos=points, complex_dim=0)


def load_manual_points(seed: int = 42) -> torch_geometric.data.Data:
    """Create a small fixed 2-D point cloud with binary features and labels.

    Parameters
    ----------
    seed : int, optional
        Seed for the label RNG. Previously the labels were drawn from the
        global (unseeded) torch RNG, so this "manual" fixture changed on
        every load; seeding makes it reproducible.

    Returns
    -------
    torch_geometric.data.Data
        Data object with 7 fixed 2-D points, binary features ``x`` and
        integer class labels ``y`` (``long``, as required by the
        cross-entropy loss configured for this dataset).
    """
    pos = torch.tensor(
        [
            [1.0, 1.0],
            [7.0, 0.0],
            [4.0, 6.0],
            [9.0, 6.0],
            [0.0, 14.0],
            [2.0, 19.0],
            [9.0, 17.0],
        ],
        dtype=torch.float,
    )
    # First four points get zero features, the rest ones; point 3 keeps a
    # single non-zero entry in its second coordinate.
    x = torch.ones_like(pos, dtype=torch.float)
    x[:4] = 0.0
    x[3, 1] = 1.0
    # Seeded generator so labels are stable across runs; long dtype because
    # the paired config uses cross_entropy, which needs class indices.
    generator = torch.Generator().manual_seed(seed)
    y = torch.randint(0, 2, (pos.shape[0],), generator=generator, dtype=torch.long)
    return torch_geometric.data.Data(x=x, y=y, pos=pos, complex_dim=0)
5 changes: 5 additions & 0 deletions modules/transforms/data_transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,9 @@
from modules.transforms.liftings.graph2simplicial.clique_lifting import (
SimplicialCliqueLifting,
)
from modules.transforms.liftings.pointcloud2simplicial.feature_rips_complex_lifting import (
FeatureRipsComplexLifting,
)

TRANSFORMS = {
# Graph -> Hypergraph
Expand All @@ -23,6 +26,8 @@
"SimplicialCliqueLifting": SimplicialCliqueLifting,
# Graph -> Cell Complex
"CellCycleLifting": CellCycleLifting,
# Point Cloud -> Simplicial Complex,
"FeatureRipsComplexLifting": FeatureRipsComplexLifting,
# Feature Liftings
"ProjectionSum": ProjectionSum,
# Data Manipulations
Expand Down
17 changes: 17 additions & 0 deletions modules/transforms/liftings/pointcloud2simplicial/base.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
from toponetx.classes import SimplicialComplex

from modules.data.utils.utils import get_complex_connectivity
from modules.transforms.liftings.lifting import PointCloudLifting


Expand All @@ -12,6 +15,20 @@ class PointCloud2SimplicialLifting(PointCloudLifting):
Additional arguments for the class.
"""

def _get_lifted_topology(self, simplicial_complex: SimplicialComplex) -> dict:
    r"""Compute the connectivity dictionary of the lifted complex.

    Parameters
    ----------
    simplicial_complex : SimplicialComplex
        The simplicial complex produced by the lifting.

    Returns
    -------
    dict
        The lifted topology: connectivity matrices of the complex up to
        rank ``self.complex_dim``.
    """
    lifted_topology = get_complex_connectivity(
        simplicial_complex, self.complex_dim
    )
    return lifted_topology

def __init__(self, complex_dim=2, **kwargs):
super().__init__(**kwargs)
self.complex_dim = complex_dim
Expand Down
Loading
Loading