From b039924aacf873a5dce1d01e69bd3ff633dff7be Mon Sep 17 00:00:00 2001 From: Theo Long Date: Sat, 29 Jun 2024 15:46:58 -0400 Subject: [PATCH 01/15] add point cloud datasets from gudhi --- configs/datasets/gudhi_bunny.yaml | 9 + configs/datasets/gudhi_daily_activities.yaml | 9 + configs/datasets/gudhi_sphere.yaml | 12 + configs/datasets/gudhi_spiral_2d.yaml | 9 + configs/datasets/gudhi_torus.yaml | 12 + configs/datasets/random_points.yaml | 11 + modules/data/load/loaders.py | 65 ++++ modules/data/utils/utils.py | 107 +++++- modules/utils/utils.py | 60 ++- .../point_cloud_complex_lifting.ipynb | 352 ++++++++++++++++++ 10 files changed, 635 insertions(+), 11 deletions(-) create mode 100644 configs/datasets/gudhi_bunny.yaml create mode 100644 configs/datasets/gudhi_daily_activities.yaml create mode 100644 configs/datasets/gudhi_sphere.yaml create mode 100644 configs/datasets/gudhi_spiral_2d.yaml create mode 100644 configs/datasets/gudhi_torus.yaml create mode 100644 configs/datasets/random_points.yaml create mode 100644 tutorials/pointcloud2simplicial/point_cloud_complex_lifting.ipynb diff --git a/configs/datasets/gudhi_bunny.yaml b/configs/datasets/gudhi_bunny.yaml new file mode 100644 index 00000000..3c7b41b4 --- /dev/null +++ b/configs/datasets/gudhi_bunny.yaml @@ -0,0 +1,9 @@ +data_domain: pointcloud +data_type: gudhi +data_name: gudhi_bunny +data_dir: datasets/${data_domain}/${data_type} + +# Dataset parameters +task: regression +loss_type: mse +monitor_metric: mae \ No newline at end of file diff --git a/configs/datasets/gudhi_daily_activities.yaml b/configs/datasets/gudhi_daily_activities.yaml new file mode 100644 index 00000000..7b6bdd97 --- /dev/null +++ b/configs/datasets/gudhi_daily_activities.yaml @@ -0,0 +1,9 @@ +data_domain: pointcloud +data_type: gudhi +data_name: gudhi_daily_activities +data_dir: datasets/${data_domain}/${data_type} + +# Dataset parameters +task: regression +loss_type: mse +monitor_metric: mae \ No newline at end of file diff --git 
a/configs/datasets/gudhi_sphere.yaml b/configs/datasets/gudhi_sphere.yaml new file mode 100644 index 00000000..6bff480b --- /dev/null +++ b/configs/datasets/gudhi_sphere.yaml @@ -0,0 +1,12 @@ +data_domain: pointcloud +data_type: toy_dataset +data_name: gudhi_sphere +data_dir: datasets/${data_domain}/${data_type} + +# Dataset parameters +ambient_dim: 3 +sample: random # can also be 'grid' +n_samples: 1000 +task: regression +loss_type: mse +monitor_metric: mae \ No newline at end of file diff --git a/configs/datasets/gudhi_spiral_2d.yaml b/configs/datasets/gudhi_spiral_2d.yaml new file mode 100644 index 00000000..a101b4a2 --- /dev/null +++ b/configs/datasets/gudhi_spiral_2d.yaml @@ -0,0 +1,9 @@ +data_domain: pointcloud +data_type: toy_dataset +data_name: gudhi_spiral_2d +data_dir: datasets/${data_domain}/${data_type} + +# Dataset parameters +task: regression +loss_type: mse +monitor_metric: mae \ No newline at end of file diff --git a/configs/datasets/gudhi_torus.yaml b/configs/datasets/gudhi_torus.yaml new file mode 100644 index 00000000..5020ce65 --- /dev/null +++ b/configs/datasets/gudhi_torus.yaml @@ -0,0 +1,12 @@ +data_domain: pointcloud +data_type: toy_dataset +data_name: gudhi_torus +data_dir: datasets/${data_domain}/${data_type} + +# Dataset parameters +dim: 3 # The dimension of the *torus* - the ambient space has dimension 2 * dim +sample: random # can also be 'grid' +n_samples: 1000 +task: regression +loss_type: mse +monitor_metric: mae \ No newline at end of file diff --git a/configs/datasets/random_points.yaml b/configs/datasets/random_points.yaml new file mode 100644 index 00000000..ff1dbb93 --- /dev/null +++ b/configs/datasets/random_points.yaml @@ -0,0 +1,11 @@ +data_domain: pointcloud +data_type: toy_dataset +data_name: random_points +data_dir: datasets/${data_domain}/${data_type} + +# Dataset parameters +dim: 3 +num_classes: 2 +num_samples: 1000 +task: classification +loss_type: cross_entropy \ No newline at end of file diff --git 
a/modules/data/load/loaders.py b/modules/data/load/loaders.py index 8ccafb11..ef3d2705 100755 --- a/modules/data/load/loaders.py +++ b/modules/data/load/loaders.py @@ -1,7 +1,9 @@ import os +from typing import Callable, Optional import numpy as np import rootutils +import torch import torch_geometric from omegaconf import DictConfig @@ -10,8 +12,10 @@ from modules.data.utils.custom_dataset import CustomDataset from modules.data.utils.utils import ( load_cell_complex_dataset, + load_gudhi_dataset, load_hypergraph_pickle_dataset, load_manual_graph, + load_random_points, load_simplicial_dataset, ) @@ -204,3 +208,64 @@ def load( torch_geometric.data.Dataset object containing the loaded data. """ return load_hypergraph_pickle_dataset(self.parameters) + + +class PointCloudLoader(AbstractLoader): + r"""Loader for point cloud datasets. + + Parameters + ---------- + parameters : DictConfig + Configuration parameters. + feature_generator: Optional[Callable[[torch.Tensor], torch.Tensor]] + Function to generate the dataset features. If None, no features added. + target_generator: Optional[Callable[[torch.Tensor], torch.Tensor]] + Function to generate the target variable. If None, no target added. + """ + + def __init__( + self, + parameters: DictConfig, + feature_generator: Optional[Callable[[torch.Tensor], torch.Tensor]] = None, + target_generator: Optional[Callable[[torch.Tensor], torch.Tensor]] = None, + ): + self.feature_generator = feature_generator + self.target_generator = target_generator + super().__init__(parameters) + self.parameters = parameters + + def load(self) -> torch_geometric.data.Dataset: + r"""Load point cloud dataset. + + Parameters + ---------- + None + + Returns + ------- + torch_geometric.data.Dataset + torch_geometric.data.Dataset object containing the loaded data. 
+ """ + # Define the path to the data directory + root_folder = rootutils.find_root() + root_data_dir = os.path.join(root_folder, self.parameters["data_dir"]) + self.data_dir = os.path.join(root_data_dir, self.parameters["data_name"]) + + if self.parameters.data_name.startswith("gudhi_"): + data = load_gudhi_dataset( + self.parameters, + feature_generator=self.feature_generator, + target_generator=self.target_generator, + ) + elif self.parameters.data_name == "random_points": + data = load_random_points( + dim=self.parameters["dim"], + num_classes=self.parameters["num_classes"], + num_samples=self.parameters["num_samples"], + ) + else: + raise NotImplementedError( + f"Dataset {self.parameters.data_name} not implemented" + ) + + return CustomDataset([data], self.data_dir) diff --git a/modules/data/utils/utils.py b/modules/data/utils/utils.py index 93ab5021..c61342c3 100755 --- a/modules/data/utils/utils.py +++ b/modules/data/utils/utils.py @@ -1,17 +1,23 @@ import hashlib import os.path as osp import pickle +from typing import Callable, Optional import networkx as nx import numpy as np import omegaconf +import rootutils import toponetx.datasets.graph as graph import torch import torch_geometric +from gudhi.datasets.generators import points +from gudhi.datasets.remote import fetch_bunny, fetch_daily_activities, fetch_spiral_2d from topomodelx.utils.sparse import from_sparse from torch_geometric.data import Data from torch_sparse import coalesce +rootutils.setup_root("./", indicator=".project-root", pythonpath=True) + def get_complex_connectivity(complex, max_rank, signed=False): r"""Gets the connectivity matrices for the complex. 
@@ -50,16 +56,16 @@ def get_complex_connectivity(complex, max_rank, signed=False): ) except ValueError: # noqa: PERF203 if connectivity_info == "incidence": - connectivity[f"{connectivity_info}_{rank_idx}"] = ( - generate_zero_sparse_connectivity( - m=practical_shape[rank_idx - 1], n=practical_shape[rank_idx] - ) + connectivity[ + f"{connectivity_info}_{rank_idx}" + ] = generate_zero_sparse_connectivity( + m=practical_shape[rank_idx - 1], n=practical_shape[rank_idx] ) else: - connectivity[f"{connectivity_info}_{rank_idx}"] = ( - generate_zero_sparse_connectivity( - m=practical_shape[rank_idx], n=practical_shape[rank_idx] - ) + connectivity[ + f"{connectivity_info}_{rank_idx}" + ] = generate_zero_sparse_connectivity( + m=practical_shape[rank_idx], n=practical_shape[rank_idx] ) connectivity["shape"] = practical_shape return connectivity @@ -421,3 +427,88 @@ def make_hash(o): hash_as_hex = sha1.hexdigest() # Convert the hex back to int and restrict it to the relevant int range return int(hash_as_hex, 16) % 4294967295 + + +def load_gudhi_dataset( + cfg: omegaconf.DictConfig, + feature_generator: Optional[Callable[[torch.Tensor], torch.Tensor]] = None, + target_generator: Optional[Callable[[torch.Tensor], torch.Tensor]] = None, +) -> torch_geometric.data.Data: + """Load a dataset from the gudhi.datasets module.""" + if not cfg.data_name.startswith("gudhi_"): + raise ValueError("This function should only be used with gudhi datasets") + + gudhi_dataset_name = cfg.data_name.removeprefix("gudhi_") + + if gudhi_dataset_name == "sphere": + points_data = points.sphere( + n_samples=cfg["n_samples"], + ambient_dim=cfg["ambient_dim"], + sample=cfg["sample"], + ) + elif gudhi_dataset_name == "torus": + points_data = points.torus( + n_samples=cfg["n_samples"], dim=cfg["dim"], sample=cfg["sample"] + ) + elif gudhi_dataset_name == "bunny": + file_path = osp.join( + rootutils.find_root(), cfg["data_dir"], "bunny", "bunny.npy" + ) + points_data = fetch_bunny( + file_path=file_path, 
accept_license=cfg.get("accept_license", False) + ) + elif gudhi_dataset_name == "spiral_2d": + file_path = osp.join( + rootutils.find_root(), cfg["data_dir"], "spiral_2d", "spiral_2d.npy" + ) + points_data = fetch_spiral_2d(file_path=file_path) + elif gudhi_dataset_name == "daily_activities": + file_path = osp.join( + rootutils.find_root(), cfg["data_dir"], "activities", "activities.npy" + ) + data = fetch_daily_activities(file_path=file_path) + points_data = data[:, :3] + else: + raise ValueError(f"Gudhi dataset {gudhi_dataset_name} not recognized.") + + pos = torch.tensor(points_data, dtype=torch.float) + if feature_generator: + x = feature_generator(points) + if x.shape[0] != points.shape[0]: + raise ValueError( + "feature_generator must not change first dimension of points data." + ) + else: + x = None + + if target_generator: + y = target_generator(points) + if y.shape[0] != points.shape[0]: + raise ValueError( + "target_generator must not change first dimension of points data." + ) + elif gudhi_dataset_name == "daily_activities": + # Target is the activity type + # 14. for ‘cross_training’, 18. for ‘jumping’, 13. for ‘stepper’, or 9. 
for ‘walking’ + y = torch.tensor(data[:, 3:], dtype=torch.float) + else: + y = None + + data = torch_geometric.data.Data(x=x, y=y, pos=pos, complex_dim=0) + return data + + +def load_random_points( + dim: int, num_classes: int, num_samples: int, seed: int = 42 +) -> torch_geometric.data.Data: + """Create a random point cloud dataset.""" + rng = np.random.default_rng(seed) + + points = torch.tensor(rng.random((num_samples, dim)), dtype=torch.float) + classes = torch.tensor( + rng.integers(num_classes, size=num_samples), dtype=torch.long + ) + features = torch.tensor(rng.integers(2, size=(num_samples, 1)), dtype=torch.float) + + data = torch_geometric.data.Data(x=features, y=classes, pos=points, complex_dim=0) + return data diff --git a/modules/utils/utils.py b/modules/utils/utils.py index 1dfcdc2e..ca590307 100644 --- a/modules/utils/utils.py +++ b/modules/utils/utils.py @@ -133,10 +133,10 @@ def describe_data(dataset: torch_geometric.data.Dataset, idx_sample: int = 0): features_dim.append(data.x.shape[1]) else: raise ValueError("Data object does not contain any vertices/points.") - if hasattr(data, "num_edges"): + if hasattr(data, "num_edges") and data.num_edges: complex_dim.append(data.num_edges) features_dim.append(data.num_edge_features) - elif hasattr(data, "edge_index"): + elif hasattr(data, "edge_index") and data.edge_index: complex_dim.append(data.edge_index.shape[1]) features_dim.append(data.edge_attr.shape[1]) # Check if the data object contains hyperedges @@ -149,6 +149,15 @@ def describe_data(dataset: torch_geometric.data.Dataset, idx_sample: int = 0): if complex_dim[0] < 50: plot_manual_graph(data) + # Plot point cloud if it is not too large + if ( + complex_dim[0] < 10_000 + and len(complex_dim) == 1 + and not hyperedges + and data.pos.shape[1] in [2, 3] + ): + plot_point_cloud(data) + if hyperedges: print( f" - Hypergraph with {complex_dim[0]} vertices and {hyperedges} hyperedges." 
@@ -166,7 +175,7 @@ def describe_data(dataset: torch_geometric.data.Dataset, idx_sample: int = 0): ) print(f" - Features dimensions: {features_dim}") # Check if there are isolated nodes - if hasattr(data, "edge_index") and hasattr(data, "x"): + if hasattr(data, "edge_index") and hasattr(data, "x") and data.edge_index: connected_nodes = torch.unique(data.edge_index) isolated_nodes = [] for i in range(data.x.shape[0]): @@ -180,6 +189,51 @@ def describe_data(dataset: torch_geometric.data.Dataset, idx_sample: int = 0): print("") +def plot_point_cloud(data, title=None): + """Plot point cloud data. + + Parameters + ---------- + data : torch_geometric.data.Data + Data object containing the point cloud. + title: str + Title for the plot. + """ + + if not hasattr(data, "pos"): + raise ValueError("Must have a pos attribute to plot point cloud data.") + + if len(data.pos.shape) != 2: + raise ValueError( + f"pos tensor should have 2 dimensions, found {len(data.pos.shape)}" + ) + + if data.pos.shape[1] == 3: + dim = 2 + x = data.pos[:, 0] + y = data.pos[:, 1] + z = data.pos[:, 2] + fig = plt.figure(figsize=(8, 8)) + ax = fig.add_subplot(111, projection="3d") + ax.scatter(x, y, z) + plt.show() + elif data.pos.shape[1] == 2: + dim = 3 + x = data.pos[:, 0] + y = data.pos[:, 1] + fig = plt.figure(figsize=(8, 8)) + ax = fig.add_subplot(111) + ax.scatter(x, y) + else: + raise ValueError("Only 2 and 3 dimensional point cloud data can be plotted") + + if title is not None: + ax.set_title(title) + else: + ax.set_title(f"{dim}D Point Cloud") + plt.show() + + def plot_manual_graph(data, title=None): r"""Plot a manual graph. 
If lifted, the plot shows the inferred higher-order structures (bipartite graph for hyperedges, diff --git a/tutorials/pointcloud2simplicial/point_cloud_complex_lifting.ipynb b/tutorials/pointcloud2simplicial/point_cloud_complex_lifting.ipynb new file mode 100644 index 00000000..42a983b6 --- /dev/null +++ b/tutorials/pointcloud2simplicial/point_cloud_complex_lifting.ipynb @@ -0,0 +1,352 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Point Cloud-to-Simplicial Alpha Complex Lifting Tutorial" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "***\n", + "This notebook shows how to import a dataset, with the desired lifting, and how to run a neural network using the loaded data.\n", + "\n", + "The notebook is divided into sections:\n", + "\n", + "- [Loading the dataset](#loading-the-dataset) loads the config files for the data and the desired tranformation, createsa a dataset object and visualizes it.\n", + "- [Loading and applying the lifting](#loading-and-applying-the-lifting) defines a simple neural network to test that the lifting creates the expected incidence matrices.\n", + "- [Create and run a simplicial nn model](#create-and-run-a-simplicial-nn-model) simply runs a forward pass of the model to check that everything is working as expected.\n", + "\n", + "***\n", + "***\n", + "\n", + "For simplicity the notebook is setup to use a random point cloud. 
However, there is a set of available datasets that you can play with from the gudhi python library.\n", + "\n", + "To switch to one of the available datasets, simply change the *dataset_name* variable in [Dataset config](#dataset-config) to one of the following names:\n", + "\n", + "* gudhi_sphere\n", + "* gudhi_torus\n", + "* gudhi_bunny\n", + "* gudhi_spiral_2d\n", + "* gudhi_daily_activities\n", + "\n", + "Please see the gudhi documentation [1] for a description of these datasets and the relevant config options\n", + "\n", + "[[1]](https://gudhi.inria.fr/python/latest/index.html) GUDHI Python documentation\n", + "***" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Imports and utilities" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [], + "source": [ + "# With this cell any imported module is reloaded before each cell execution\n", + "%load_ext autoreload\n", + "%autoreload 2\n", + "from modules.data.load.loaders import PointCloudLoader\n", + "from modules.data.preprocess.preprocessor import PreProcessor\n", + "from modules.utils.utils import (\n", + " describe_data,\n", + " load_dataset_config,\n", + " load_model_config,\n", + " load_transform_config,\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Loading the Dataset" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Here we just need to spicify the name of the available dataset that we want to load. 
First, the dataset config is read from the corresponding yaml file (located at `/configs/datasets/` directory), and then the data is loaded via the implemented `Loaders`.\n" + ] + }, + { + "cell_type": "code", + "execution_count": 38, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n", + "Dataset configuration for gudhi_daily_activities:\n", + "\n", + "{'data_domain': 'pointcloud',\n", + " 'data_type': 'gudhi',\n", + " 'data_name': 'gudhi_daily_activities',\n", + " 'data_dir': 'datasets/pointcloud/gudhi',\n", + " 'task': 'regression',\n", + " 'loss_type': 'mse',\n", + " 'monitor_metric': 'mae'}\n" + ] + } + ], + "source": [ + "dataset_name = \"gudhi_daily_activities\"\n", + "dataset_config = load_dataset_config(dataset_name)\n", + "dataset_config[\"n_samples\"] = 900\n", + "loader = PointCloudLoader(dataset_config)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "We can then access to the data through the `load()`method:" + ] + }, + { + "cell_type": "code", + "execution_count": 39, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "The 'daily and sports activities' dataset comes from the UC Irvine Machine Learning Repository https://archive.ics.uci.edu/ml/datasets/daily+and+sports+activities\n", + "This dataset is licensed under a Creative Commons Attribution 4.0 International (CC BY 4.0) license.\n", + "\n", + "\n", + "Dataset only contains 1 sample:\n", + " - Set with 30000 points.\n", + " - Features dimension: 0\n", + "\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Processing...\n", + "Done!\n" + ] + } + ], + "source": [ + "dataset = loader.load()\n", + "describe_data(dataset)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Loading and Applying the Lifting" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "In this section we will instantiate the 
lifting we want to apply to the data. For this example the clique lifting was chosen. For a clique of n nodes the algorithm for $m=3,...,max(n, complex\\_dim)$ will create simplicials for every possible combinations containing m nodes of the clique. $complex\\_dim$ is a parameter of the lifting. This is a deterministic lifting, based on connectivity, that does not modify the initial connectivity of the graph. The problem of extracting all the cliques in a graph is NP-hard, on in some formulaitons NP-complete (clique decision problem). The computational complexity of this algorithm is $O(n^k k^2)$[[1]](https://www.sciencedirect.com/science/article/pii/S0019995885800413), where $n$ is the number of nodes in the graph and $k$ is the highest clique dimension considered.\n", + "\n", + "***\n", + "[[1]](https://www.sciencedirect.com/science/article/pii/S0019995885800413) Cook, S. A. (1985). A taxonomy of problems with fast parallel algorithms. Information and control, 64(1-3), 2-22.\n", + "***\n", + "\n", + "For simplicial complexes creating a lifting involves creating a `SimplicialComplex` object from topomodelx and adding simplices to it using the method `add_simplices_from`. The `SimplicialComplex` class then takes care of creating all the needed matrices.\n", + "\n", + "Similarly to before, we can specify the transformation we want to apply through its type and id --the correxponding config files located at `/configs/transforms`. \n", + "\n", + "Note that the *tranform_config* dictionary generated below can contain a sequence of tranforms if it is needed.\n", + "\n", + "This can also be used to explore liftings from one topological domain to another, for example using two liftings it is possible to achieve a sequence such as: graph -> simplicial complex -> hypergraph. 
" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n", + "Transform configuration for graph2simplicial/clique_lifting:\n", + "\n", + "{'transform_type': 'lifting',\n", + " 'transform_name': 'SimplicialCliqueLifting',\n", + " 'complex_dim': 3,\n", + " 'preserve_edge_attr': False,\n", + " 'signed': True,\n", + " 'feature_lifting': 'ProjectionSum'}\n" + ] + } + ], + "source": [ + "# Define transformation type and id\n", + "transform_type = \"liftings\"\n", + "# If the transform is a topological lifting, it should include both the type of the lifting and the identifier\n", + "transform_id = \"graph2simplicial/clique_lifting\"\n", + "\n", + "# Read yaml file\n", + "transform_config = {\n", + " \"lifting\": load_transform_config(transform_type, transform_id)\n", + " # other transforms (e.g. data manipulations, feature liftings) can be added here\n", + "}" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "We than apply the transform via our `PreProcesor`:" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Transform parameters are the same, using existing data_dir: /Users/leone/Desktop/PhD-S/projects/challenge-icml-2024/datasets/graph/toy_dataset/manual/lifting/2744620725\n", + "\n", + "Dataset only contains 1 sample:\n" + ] + }, + { + "data": { + "image/png": "", + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + " - The complex has 8 0-cells.\n", + " - The 0-cells have features dimension 1\n", + " - The complex has 13 1-cells.\n", + " - The 1-cells have features dimension 1\n", + " - The complex has 6 2-cells.\n", + " - The 2-cells have features dimension 1\n", + " - The complex has 1 3-cells.\n", + " - The 3-cells have features dimension 1\n", + "\n" + ] + } + ], + "source": [ + "lifted_dataset = PreProcessor(dataset, transform_config, loader.data_dir)\n", + "describe_data(lifted_dataset)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Create and Run a Simplicial NN Model" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "In this section a simple model is created to test that the used lifting works as intended. In this case the model uses the `up_laplacian_1` and the `down_laplacian_1` so the lifting should make sure to add them to the data." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 12, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n", + "Model configuration for simplicial SAN:\n", + "\n", + "{'in_channels': None,\n", + " 'hidden_channels': 32,\n", + " 'out_channels': None,\n", + " 'n_layers': 2,\n", + " 'n_filters': 2,\n", + " 'order_harmonic': 5,\n", + " 'epsilon_harmonic': 0.1}\n" + ] + } + ], + "source": [ + "from modules.models.simplicial.san import SANModel\n", + "\n", + "model_type = \"simplicial\"\n", + "model_id = \"san\"\n", + "model_config = load_model_config(model_type, model_id)\n", + "\n", + "model = SANModel(model_config, dataset_config)" + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "metadata": {}, + "outputs": [], + "source": [ + "y_hat = model(lifted_dataset.get(0))" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "If everything is correct the cell above should execute without errors. " + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3.11.3 ('topox')", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.11.3" + }, + "vscode": { + "interpreter": { + "hash": "5209ee787340d6caf238f8c0093dc78889cb331b3f459734c35c70f07b690b2a" + } + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} From 3acd1a2e391b2436d36a2d5562e29de7abe6de78 Mon Sep 17 00:00:00 2001 From: Theo Long Date: Sat, 29 Jun 2024 16:37:55 -0400 Subject: [PATCH 02/15] fix --- modules/data/utils/utils.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/modules/data/utils/utils.py b/modules/data/utils/utils.py index c61342c3..bdea1334 100755 --- a/modules/data/utils/utils.py +++ b/modules/data/utils/utils.py @@ -473,8 +473,8 @@ def 
load_gudhi_dataset( pos = torch.tensor(points_data, dtype=torch.float) if feature_generator: - x = feature_generator(points) - if x.shape[0] != points.shape[0]: + x = feature_generator(pos) + if x.shape[0] != pos.shape[0]: raise ValueError( "feature_generator must not change first dimension of points data." ) @@ -482,8 +482,8 @@ def load_gudhi_dataset( x = None if target_generator: - y = target_generator(points) - if y.shape[0] != points.shape[0]: + y = target_generator(pos) + if y.shape[0] != pos.shape[0]: raise ValueError( "target_generator must not change first dimension of points data." ) From 447074895f4dd8ed47692b7a6e99ccc937f69e96 Mon Sep 17 00:00:00 2001 From: Theo Long Date: Sat, 29 Jun 2024 17:35:12 -0400 Subject: [PATCH 03/15] add manual dataset and fix configs --- configs/datasets/manual_points.yaml | 12 ++++++++++++ configs/datasets/random_points.yaml | 1 + .../alpha_complex_lifting.yaml | 7 +++++++ modules/data/load/loaders.py | 3 +++ modules/data/utils/utils.py | 18 ++++++++++++++++++ modules/utils/utils.py | 6 +++--- 6 files changed, 44 insertions(+), 3 deletions(-) create mode 100644 configs/datasets/manual_points.yaml create mode 100644 configs/transforms/liftings/pointcloud2simplicial/alpha_complex_lifting.yaml diff --git a/configs/datasets/manual_points.yaml b/configs/datasets/manual_points.yaml new file mode 100644 index 00000000..be137fc8 --- /dev/null +++ b/configs/datasets/manual_points.yaml @@ -0,0 +1,12 @@ +data_domain: pointcloud +data_type: toy_dataset +data_name: manual_points +data_dir: datasets/${data_domain}/${data_type} + +# Dataset parameters +dim: 2 +num_classes: 2 +num_samples: 7 +num_features: 2 +task: classification +loss_type: cross_entropy \ No newline at end of file diff --git a/configs/datasets/random_points.yaml b/configs/datasets/random_points.yaml index ff1dbb93..f6f22524 100644 --- a/configs/datasets/random_points.yaml +++ b/configs/datasets/random_points.yaml @@ -7,5 +7,6 @@ data_dir: 
datasets/${data_domain}/${data_type} dim: 3 num_classes: 2 num_samples: 1000 +num_features: 1 task: classification loss_type: cross_entropy \ No newline at end of file diff --git a/configs/transforms/liftings/pointcloud2simplicial/alpha_complex_lifting.yaml b/configs/transforms/liftings/pointcloud2simplicial/alpha_complex_lifting.yaml new file mode 100644 index 00000000..6b8771a5 --- /dev/null +++ b/configs/transforms/liftings/pointcloud2simplicial/alpha_complex_lifting.yaml @@ -0,0 +1,7 @@ +transform_type: 'lifting' +transform_name: "AlphaComplexLifting" +complex_dim: 3 +alpha: 25.0 +preserve_edge_attr: False +signed: True +feature_lifting: ProjectionSum diff --git a/modules/data/load/loaders.py b/modules/data/load/loaders.py index ef3d2705..6e8a7b7e 100755 --- a/modules/data/load/loaders.py +++ b/modules/data/load/loaders.py @@ -15,6 +15,7 @@ load_gudhi_dataset, load_hypergraph_pickle_dataset, load_manual_graph, + load_manual_points, load_random_points, load_simplicial_dataset, ) @@ -263,6 +264,8 @@ def load(self) -> torch_geometric.data.Dataset: num_classes=self.parameters["num_classes"], num_samples=self.parameters["num_samples"], ) + elif self.parameters.data_name == "manual_points": + data = load_manual_points() else: raise NotImplementedError( f"Dataset {self.parameters.data_name} not implemented" diff --git a/modules/data/utils/utils.py b/modules/data/utils/utils.py index bdea1334..d8c2aa9e 100755 --- a/modules/data/utils/utils.py +++ b/modules/data/utils/utils.py @@ -512,3 +512,21 @@ def load_random_points( data = torch_geometric.data.Data(x=features, y=classes, pos=points, complex_dim=0) return data + + +def load_manual_points(): + pos = torch.tensor( + [ + [1.0, 1.0], + [7.0, 0.0], + [4.0, 6.0], + [9.0, 6.0], + [0.0, 14.0], + [2.0, 19.0], + [9.0, 17.0], + ], + dtype=torch.float, + ) + x = torch.ones_like(pos, dtype=torch.float) + y = torch.randint(0, 2, (pos.shape[0],), dtype=torch.float) + return torch_geometric.data.Data(x=x, y=y, pos=pos, 
complex_dim=0) diff --git a/modules/utils/utils.py b/modules/utils/utils.py index ca590307..7a1e9275 100644 --- a/modules/utils/utils.py +++ b/modules/utils/utils.py @@ -146,7 +146,7 @@ def describe_data(dataset: torch_geometric.data.Dataset, idx_sample: int = 0): hyperedges_features_dim = data.x_hyperedges.shape[1] # Plot the graph if it is not too large - if complex_dim[0] < 50: + if complex_dim[0] < 50 and len(complex_dim) != 1: plot_manual_graph(data) # Plot point cloud if it is not too large @@ -209,7 +209,7 @@ def plot_point_cloud(data, title=None): ) if data.pos.shape[1] == 3: - dim = 2 + dim = 3 x = data.pos[:, 0] y = data.pos[:, 1] z = data.pos[:, 2] @@ -218,7 +218,7 @@ def plot_point_cloud(data, title=None): ax.scatter(x, y, z) plt.show() elif data.pos.shape[1] == 2: - dim = 3 + dim = 2 x = data.pos[:, 0] y = data.pos[:, 1] fig = plt.figure(figsize=(8, 8)) From 59eaca5e0031dffbade654850ecffd4e230825e5 Mon Sep 17 00:00:00 2001 From: Theo Long Date: Sat, 29 Jun 2024 17:35:37 -0400 Subject: [PATCH 04/15] add alpha complex --- modules/transforms/data_transform.py | 5 + .../alpha_complex_lifting.py | 43 +++ .../liftings/pointcloud2simplicial/base.py | 18 + .../alpha_complex_lifting.ipynb | 364 ++++++++++++++++++ .../point_cloud_complex_lifting.ipynb | 352 ----------------- 5 files changed, 430 insertions(+), 352 deletions(-) create mode 100644 modules/transforms/liftings/pointcloud2simplicial/alpha_complex_lifting.py create mode 100644 tutorials/pointcloud2simplicial/alpha_complex_lifting.ipynb delete mode 100644 tutorials/pointcloud2simplicial/point_cloud_complex_lifting.ipynb diff --git a/modules/transforms/data_transform.py b/modules/transforms/data_transform.py index 59253ecf..4376966c 100755 --- a/modules/transforms/data_transform.py +++ b/modules/transforms/data_transform.py @@ -15,6 +15,9 @@ from modules.transforms.liftings.graph2simplicial.clique_lifting import ( SimplicialCliqueLifting, ) +from 
modules.transforms.liftings.pointcloud2simplicial.alpha_complex_lifting import ( + AlphaComplexLifting, +) TRANSFORMS = { # Graph -> Hypergraph @@ -23,6 +26,8 @@ "SimplicialCliqueLifting": SimplicialCliqueLifting, # Graph -> Cell Complex "CellCycleLifting": CellCycleLifting, + # Point Cloud -> Simplicial Complex, + "AlphaComplexLifting": AlphaComplexLifting, # Feature Liftings "ProjectionSum": ProjectionSum, # Data Manipulations diff --git a/modules/transforms/liftings/pointcloud2simplicial/alpha_complex_lifting.py b/modules/transforms/liftings/pointcloud2simplicial/alpha_complex_lifting.py new file mode 100644 index 00000000..5dc06516 --- /dev/null +++ b/modules/transforms/liftings/pointcloud2simplicial/alpha_complex_lifting.py @@ -0,0 +1,43 @@ +import gudhi +import torch_geometric +from toponetx.classes import SimplicialComplex + +from modules.transforms.liftings.pointcloud2simplicial.base import ( + PointCloud2SimplicialLifting, +) + + +class AlphaComplexLifting(PointCloud2SimplicialLifting): + r"""Lifts point clouds to simplicial complex domain by generating the alpha complex using the Gudhi library. The alpha complex is a simplicial complex constructed from the finite cells of a Delaunay Triangulation. It has the same persistent homology as the Čech complex and is significantly smaller. + + Parameters + ---------- + **kwargs : optional + Additional arguments for the class. + """ + + def __init__(self, alpha: float, **kwargs): + self.alpha = alpha + super().__init__(**kwargs) + + def lift_topology(self, data: torch_geometric.data.Data) -> dict: + r"""Lifts the topology of a point cloud to the alpha complex. + + Parameters + ---------- + data : torch_geometric.data.Data + The input data to be lifted. + + Returns + ------- + dict + The lifted topology. 
+ """ + ac = gudhi.AlphaComplex(data.pos) + stree = ac.create_simplex_tree() + stree.prune_above_filtration(self.alpha) + stree.prune_above_dimension(self.complex_dim) + sc = SimplicialComplex(s for s, filtration_value in stree.get_simplices()) + lifted_topolgy = self._get_lifted_topology(sc) + lifted_topolgy["x_0"] = data.x + return lifted_topolgy diff --git a/modules/transforms/liftings/pointcloud2simplicial/base.py b/modules/transforms/liftings/pointcloud2simplicial/base.py index fddfaf17..88c3137f 100755 --- a/modules/transforms/liftings/pointcloud2simplicial/base.py +++ b/modules/transforms/liftings/pointcloud2simplicial/base.py @@ -1,3 +1,6 @@ +from toponetx.classes import SimplicialComplex + +from modules.data.utils.utils import get_complex_connectivity from modules.transforms.liftings.lifting import PointCloudLifting @@ -12,6 +15,21 @@ class PointCloud2SimplicialLifting(PointCloudLifting): Additional arguments for the class. """ + def _get_lifted_topology(self, simplicial_complex: SimplicialComplex) -> dict: + r"""Returns the lifted topology. + + Parameters + ---------- + simplicial_complex : SimplicialComplex + The simplicial complex. + Returns + --------- + dict + The lifted topology. 
+ """ + lifted_topology = get_complex_connectivity(simplicial_complex, self.complex_dim) + return lifted_topology + def __init__(self, complex_dim=2, **kwargs): super().__init__(**kwargs) self.complex_dim = complex_dim diff --git a/tutorials/pointcloud2simplicial/alpha_complex_lifting.ipynb b/tutorials/pointcloud2simplicial/alpha_complex_lifting.ipynb new file mode 100644 index 00000000..603674b6 --- /dev/null +++ b/tutorials/pointcloud2simplicial/alpha_complex_lifting.ipynb @@ -0,0 +1,364 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Point Cloud-to-Simplicial Alpha Complex Lifting Tutorial" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "***\n", + "This notebook shows how to import a dataset, with the desired lifting, and how to run a neural network using the loaded data.\n", + "\n", + "The notebook is divided into sections:\n", + "\n", + "- [Loading the dataset](#loading-the-dataset) loads the config files for the data and the desired tranformation, createsa a dataset object and visualizes it.\n", + "- [Loading and applying the lifting](#loading-and-applying-the-lifting) defines a simple neural network to test that the lifting creates the expected incidence matrices.\n", + "- [Create and run a simplicial nn model](#create-and-run-a-simplicial-nn-model) simply runs a forward pass of the model to check that everything is working as expected.\n", + "\n", + "***\n", + "***\n", + "\n", + "There is a set of available datasets that you can play with from the gudhi python library.\n", + "\n", + "To switch to one of the available datasets, simply change the *dataset_name* variable in [Dataset config](#dataset-config) to one of the following names:\n", + "\n", + "* gudhi_sphere\n", + "* gudhi_torus\n", + "* gudhi_bunny\n", + "* gudhi_spiral_2d\n", + "* gudhi_daily_activities\n", + "\n", + "Please see the gudhi documentation [1] for a description of these datasets and the relevant config options. 
Note that *all datasets except gudhi_daily_activities lack features and targets*. You must instead provide a feature_generator and a target_generator to the PointCloudLoader\n", + "\n", + "[[1]](https://gudhi.inria.fr/python/latest/index.html) GUDHI Python documentation\n", + "***" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Imports and utilities" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [], + "source": [ + "# With this cell any imported module is reloaded before each cell execution\n", + "%load_ext autoreload\n", + "%autoreload 2\n", + "from modules.data.load.loaders import PointCloudLoader\n", + "from modules.data.preprocess.preprocessor import PreProcessor\n", + "from modules.utils.utils import (\n", + " describe_data,\n", + " load_dataset_config,\n", + " load_model_config,\n", + " load_transform_config,\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Loading the Dataset" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Here we just need to specify the name of the available dataset that we want to load. 
First, the dataset config is read from the corresponding yaml file (located in the `/configs/datasets/` directory), and then the data is loaded via the implemented `Loaders`.\n" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n", + "Dataset configuration for manual_points:\n", + "\n", + "{'data_domain': 'pointcloud',\n", + " 'data_type': 'toy_dataset',\n", + " 'data_name': 'manual_points',\n", + " 'data_dir': 'datasets/pointcloud/toy_dataset',\n", + " 'dim': 2,\n", + " 'num_classes': 2,\n", + " 'num_samples': 7,\n", + " 'num_features': 2,\n", + " 'task': 'classification',\n", + " 'loss_type': 'cross_entropy'}\n" + ] + } + ], + "source": [ + "import torch\n", + "\n", + "dataset_name = \"manual_points\"\n", + "dataset_config = load_dataset_config(dataset_name)\n", + "\n", + "# Note that some point cloud datasets may not have features or targets\n", + "# In this case you must provide the feature_generator and target_generator methods yourself!\n", + "loader = PointCloudLoader(dataset_config)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "We can then access the data through the `load()` method:" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n", + "Dataset only contains 1 sample:\n" + ] + }, + { + "data": { + "image/png": "", + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + " - Set with 7 points.\n", + " - Features dimension: 2\n", + "\n" + ] + } + ], + "source": [ + "dataset = loader.load()\n", + "describe_data(dataset)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Loading and Applying the Lifting" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "In this section we will instantiate the lifting we want to apply to the data. For this example the alpha complex lifting was chosen. The alpha complex is a subcomplex of the Delaunay triangulation [1]. It is generated by filtering the simplices of the Delaunay triangulation such that only those with diameter less than $\\alpha^2$ are retained. The GUDHI library is used to compute the complex [2].\n", + "\n", + "---\n", + "[[1]](https://en.wikipedia.org/wiki/Delaunay_triangulation) Delauny Triangulation Wikipedia\n", + "[[2]](https://gudhi.inria.fr/python/latest/alpha_complex_user.html#) Gudhi Alpha Complex User Manual\n", + "\n", + "---\n", + "\n", + "For simplicial complexes creating a lifting involves creating a `SimplicialComplex` object from topomodelx and adding simplices to it using the method `add_simplices_from`. The `SimplicialComplex` class then takes care of creating all the needed matrices.\n", + "\n", + "Similarly to before, we can specify the transformation we want to apply through its type and id --the correxponding config files located at `/configs/transforms`. \n", + "\n", + "Note that the *tranform_config* dictionary generated below can contain a sequence of tranforms if it is needed.\n", + "\n", + "This can also be used to explore liftings from one topological domain to another, for example using two liftings it is possible to achieve a sequence such as: graph -> simplicial complex -> hypergraph. 
" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n", + "Transform configuration for pointcloud2simplicial/alpha_complex_lifting:\n", + "\n", + "{'transform_type': 'lifting',\n", + " 'transform_name': 'AlphaComplexLifting',\n", + " 'complex_dim': 3,\n", + " 'alpha': 25.0,\n", + " 'preserve_edge_attr': False,\n", + " 'signed': True,\n", + " 'feature_lifting': 'ProjectionSum'}\n" + ] + } + ], + "source": [ + "# Define transformation type and id\n", + "transform_type = \"liftings\"\n", + "# If the transform is a topological lifting, it should include both the type of the lifting and the identifier\n", + "transform_id = \"pointcloud2simplicial/alpha_complex_lifting\"\n", + "\n", + "# Read yaml file\n", + "transform_config = {\n", + " \"lifting\": load_transform_config(transform_type, transform_id)\n", + " # other transforms (e.g. data manipulations, feature liftings) can be added here\n", + "}" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "We than apply the transform via our `PreProcesor`:" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Transform parameters are the same, using existing data_dir: /Users/tlong/Documents/code/challenge-icml-2024/datasets/pointcloud/toy_dataset/manual_points/lifting/3217688758\n", + "\n", + "Dataset only contains 1 sample:\n" + ] + }, + { + "data": { + "image/png": "", + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + " - The complex has 7 0-cells.\n", + " - The 0-cells have features dimension 2\n", + " - The complex has 9 1-cells.\n", + " - The 1-cells have features dimension 2\n", + " - The complex has 3 2-cells.\n", + " - The 2-cells have features dimension 2\n", + "\n" + ] + } + ], + "source": [ + "lifted_dataset = PreProcessor(dataset, transform_config, loader.data_dir)\n", + "describe_data(lifted_dataset)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Create and Run a Simplicial NN Model" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "In this section a simple model is created to test that the used lifting works as intended. In this case the model uses the `up_laplacian_1` and the `down_laplacian_1` so the lifting should make sure to add them to the data." + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n", + "Model configuration for simplicial SAN:\n", + "\n", + "{'in_channels': None,\n", + " 'hidden_channels': 32,\n", + " 'out_channels': None,\n", + " 'n_layers': 2,\n", + " 'n_filters': 2,\n", + " 'order_harmonic': 5,\n", + " 'epsilon_harmonic': 0.1}\n" + ] + } + ], + "source": [ + "from modules.models.simplicial.san import SANModel\n", + "\n", + "model_type = \"simplicial\"\n", + "model_id = \"san\"\n", + "model_config = load_model_config(model_type, model_id)\n", + "model = SANModel(model_config, dataset_config)" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": {}, + "outputs": [], + "source": [ + "y_hat = model(lifted_dataset.get(0))" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "If everything is correct the cell above should execute without errors. 
" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3.11.3 ('topox')", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.11.3" + }, + "vscode": { + "interpreter": { + "hash": "5209ee787340d6caf238f8c0093dc78889cb331b3f459734c35c70f07b690b2a" + } + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/tutorials/pointcloud2simplicial/point_cloud_complex_lifting.ipynb b/tutorials/pointcloud2simplicial/point_cloud_complex_lifting.ipynb deleted file mode 100644 index 42a983b6..00000000 --- a/tutorials/pointcloud2simplicial/point_cloud_complex_lifting.ipynb +++ /dev/null @@ -1,352 +0,0 @@ -{ - "cells": [ - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "# Point Cloud-to-Simplicial Alpha Complex Lifting Tutorial" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "***\n", - "This notebook shows how to import a dataset, with the desired lifting, and how to run a neural network using the loaded data.\n", - "\n", - "The notebook is divided into sections:\n", - "\n", - "- [Loading the dataset](#loading-the-dataset) loads the config files for the data and the desired tranformation, createsa a dataset object and visualizes it.\n", - "- [Loading and applying the lifting](#loading-and-applying-the-lifting) defines a simple neural network to test that the lifting creates the expected incidence matrices.\n", - "- [Create and run a simplicial nn model](#create-and-run-a-simplicial-nn-model) simply runs a forward pass of the model to check that everything is working as expected.\n", - "\n", - "***\n", - "***\n", - "\n", - "For simplicity the notebook is setup to use a random point cloud. 
However, there is a set of available datasets that you can play with from the gudhi python library.\n", - "\n", - "To switch to one of the available datasets, simply change the *dataset_name* variable in [Dataset config](#dataset-config) to one of the following names:\n", - "\n", - "* gudhi_sphere\n", - "* gudhi_torus\n", - "* gudhi_bunny\n", - "* gudhi_spiral_2d\n", - "* gudhi_daily_activities\n", - "\n", - "Please see the gudhi documentation [1] for a description of these datasets and the relevant config options\n", - "\n", - "[[1]](https://gudhi.inria.fr/python/latest/index.html) GUDHI Python documentation\n", - "***" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### Imports and utilities" - ] - }, - { - "cell_type": "code", - "execution_count": 1, - "metadata": {}, - "outputs": [], - "source": [ - "# With this cell any imported module is reloaded before each cell execution\n", - "%load_ext autoreload\n", - "%autoreload 2\n", - "from modules.data.load.loaders import PointCloudLoader\n", - "from modules.data.preprocess.preprocessor import PreProcessor\n", - "from modules.utils.utils import (\n", - " describe_data,\n", - " load_dataset_config,\n", - " load_model_config,\n", - " load_transform_config,\n", - ")" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Loading the Dataset" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "Here we just need to spicify the name of the available dataset that we want to load. 
First, the dataset config is read from the corresponding yaml file (located at `/configs/datasets/` directory), and then the data is loaded via the implemented `Loaders`.\n" - ] - }, - { - "cell_type": "code", - "execution_count": 38, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "\n", - "Dataset configuration for gudhi_daily_activities:\n", - "\n", - "{'data_domain': 'pointcloud',\n", - " 'data_type': 'gudhi',\n", - " 'data_name': 'gudhi_daily_activities',\n", - " 'data_dir': 'datasets/pointcloud/gudhi',\n", - " 'task': 'regression',\n", - " 'loss_type': 'mse',\n", - " 'monitor_metric': 'mae'}\n" - ] - } - ], - "source": [ - "dataset_name = \"gudhi_daily_activities\"\n", - "dataset_config = load_dataset_config(dataset_name)\n", - "dataset_config[\"n_samples\"] = 900\n", - "loader = PointCloudLoader(dataset_config)" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "We can then access to the data through the `load()`method:" - ] - }, - { - "cell_type": "code", - "execution_count": 39, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "The 'daily and sports activities' dataset comes from the UC Irvine Machine Learning Repository https://archive.ics.uci.edu/ml/datasets/daily+and+sports+activities\n", - "This dataset is licensed under a Creative Commons Attribution 4.0 International (CC BY 4.0) license.\n", - "\n", - "\n", - "Dataset only contains 1 sample:\n", - " - Set with 30000 points.\n", - " - Features dimension: 0\n", - "\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "Processing...\n", - "Done!\n" - ] - } - ], - "source": [ - "dataset = loader.load()\n", - "describe_data(dataset)" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Loading and Applying the Lifting" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "In this section we will instantiate the 
lifting we want to apply to the data. For this example the clique lifting was chosen. For a clique of n nodes the algorithm for $m=3,...,max(n, complex\\_dim)$ will create simplicials for every possible combinations containing m nodes of the clique. $complex\\_dim$ is a parameter of the lifting. This is a deterministic lifting, based on connectivity, that does not modify the initial connectivity of the graph. The problem of extracting all the cliques in a graph is NP-hard, on in some formulaitons NP-complete (clique decision problem). The computational complexity of this algorithm is $O(n^k k^2)$[[1]](https://www.sciencedirect.com/science/article/pii/S0019995885800413), where $n$ is the number of nodes in the graph and $k$ is the highest clique dimension considered.\n", - "\n", - "***\n", - "[[1]](https://www.sciencedirect.com/science/article/pii/S0019995885800413) Cook, S. A. (1985). A taxonomy of problems with fast parallel algorithms. Information and control, 64(1-3), 2-22.\n", - "***\n", - "\n", - "For simplicial complexes creating a lifting involves creating a `SimplicialComplex` object from topomodelx and adding simplices to it using the method `add_simplices_from`. The `SimplicialComplex` class then takes care of creating all the needed matrices.\n", - "\n", - "Similarly to before, we can specify the transformation we want to apply through its type and id --the correxponding config files located at `/configs/transforms`. \n", - "\n", - "Note that the *tranform_config* dictionary generated below can contain a sequence of tranforms if it is needed.\n", - "\n", - "This can also be used to explore liftings from one topological domain to another, for example using two liftings it is possible to achieve a sequence such as: graph -> simplicial complex -> hypergraph. 
" - ] - }, - { - "cell_type": "code", - "execution_count": 10, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "\n", - "Transform configuration for graph2simplicial/clique_lifting:\n", - "\n", - "{'transform_type': 'lifting',\n", - " 'transform_name': 'SimplicialCliqueLifting',\n", - " 'complex_dim': 3,\n", - " 'preserve_edge_attr': False,\n", - " 'signed': True,\n", - " 'feature_lifting': 'ProjectionSum'}\n" - ] - } - ], - "source": [ - "# Define transformation type and id\n", - "transform_type = \"liftings\"\n", - "# If the transform is a topological lifting, it should include both the type of the lifting and the identifier\n", - "transform_id = \"graph2simplicial/clique_lifting\"\n", - "\n", - "# Read yaml file\n", - "transform_config = {\n", - " \"lifting\": load_transform_config(transform_type, transform_id)\n", - " # other transforms (e.g. data manipulations, feature liftings) can be added here\n", - "}" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "We than apply the transform via our `PreProcesor`:" - ] - }, - { - "cell_type": "code", - "execution_count": 11, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Transform parameters are the same, using existing data_dir: /Users/leone/Desktop/PhD-S/projects/challenge-icml-2024/datasets/graph/toy_dataset/manual/lifting/2744620725\n", - "\n", - "Dataset only contains 1 sample:\n" - ] - }, - { - "data": { - "image/png": "", - "text/plain": [ - "
" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - " - The complex has 8 0-cells.\n", - " - The 0-cells have features dimension 1\n", - " - The complex has 13 1-cells.\n", - " - The 1-cells have features dimension 1\n", - " - The complex has 6 2-cells.\n", - " - The 2-cells have features dimension 1\n", - " - The complex has 1 3-cells.\n", - " - The 3-cells have features dimension 1\n", - "\n" - ] - } - ], - "source": [ - "lifted_dataset = PreProcessor(dataset, transform_config, loader.data_dir)\n", - "describe_data(lifted_dataset)" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Create and Run a Simplicial NN Model" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "In this section a simple model is created to test that the used lifting works as intended. In this case the model uses the `up_laplacian_1` and the `down_laplacian_1` so the lifting should make sure to add them to the data." 
- ] - }, - { - "cell_type": "code", - "execution_count": 12, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "\n", - "Model configuration for simplicial SAN:\n", - "\n", - "{'in_channels': None,\n", - " 'hidden_channels': 32,\n", - " 'out_channels': None,\n", - " 'n_layers': 2,\n", - " 'n_filters': 2,\n", - " 'order_harmonic': 5,\n", - " 'epsilon_harmonic': 0.1}\n" - ] - } - ], - "source": [ - "from modules.models.simplicial.san import SANModel\n", - "\n", - "model_type = \"simplicial\"\n", - "model_id = \"san\"\n", - "model_config = load_model_config(model_type, model_id)\n", - "\n", - "model = SANModel(model_config, dataset_config)" - ] - }, - { - "cell_type": "code", - "execution_count": 13, - "metadata": {}, - "outputs": [], - "source": [ - "y_hat = model(lifted_dataset.get(0))" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "If everything is correct the cell above should execute without errors. " - ] - } - ], - "metadata": { - "kernelspec": { - "display_name": "Python 3.11.3 ('topox')", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.11.3" - }, - "vscode": { - "interpreter": { - "hash": "5209ee787340d6caf238f8c0093dc78889cb331b3f459734c35c70f07b690b2a" - } - } - }, - "nbformat": 4, - "nbformat_minor": 2 -} From cfa47e4d4aaf7aa79868e9e3efcf274b9a4d7ea7 Mon Sep 17 00:00:00 2001 From: Theo Long Date: Sat, 29 Jun 2024 17:35:43 -0400 Subject: [PATCH 05/15] add tests --- .../test_alpha_complex_lifting.py | 65 +++++++++++++++++++ 1 file changed, 65 insertions(+) create mode 100644 test/transforms/liftings/pointcloud2simplicial/test_alpha_complex_lifting.py diff --git a/test/transforms/liftings/pointcloud2simplicial/test_alpha_complex_lifting.py 
b/test/transforms/liftings/pointcloud2simplicial/test_alpha_complex_lifting.py new file mode 100644 index 00000000..0fe8af35 --- /dev/null +++ b/test/transforms/liftings/pointcloud2simplicial/test_alpha_complex_lifting.py @@ -0,0 +1,65 @@ +"""Test the alpha complex lifting.""" + +import torch + +from modules.data.utils.utils import load_manual_points +from modules.transforms.liftings.pointcloud2simplicial.alpha_complex_lifting import ( + AlphaComplexLifting, +) + + +class TestSimplicialCliqueLifting: + """Test the SimplicialCliqueLifting class.""" + + def setup_method(self): + # Load the graph + self.data = load_manual_points() + + # Initialise the SimplicialCliqueLifting class + self.lifting = AlphaComplexLifting(complex_dim=3, alpha=25.0) + + def test_lift_topology(self): + """Test the lift_topology method.""" + + # Test the lift_topology method + lifted_data = self.lifting.forward(self.data.clone()) + + expected_incidence_1 = torch.tensor( + [ + [1.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + [1.0, 0.0, 1.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 1.0, 1.0, 0.0, 1.0, 1.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 1.0, 1.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 1.0], + [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 1.0], + ] + ) + + assert ( + expected_incidence_1 == lifted_data.incidence_1.to_dense() + ).all(), "Something is wrong with incidence_1 (nodes to edges)." + + expected_incidence_2 = torch.tensor( + [ + [1.0, 0.0, 0.0], + [1.0, 0.0, 0.0], + [1.0, 1.0, 0.0], + [0.0, 1.0, 0.0], + [0.0, 1.0, 0.0], + [0.0, 0.0, 0.0], + [0.0, 0.0, 1.0], + [0.0, 0.0, 1.0], + [0.0, 0.0, 1.0], + ] + ) + + assert ( + abs(expected_incidence_2) == lifted_data.incidence_2.to_dense() + ).all(), "Something is wrong with incidence_2 (edges to triangles)." 
+ + expected_incidence_3 = torch.tensor([]) + + assert ( + abs(expected_incidence_3) == lifted_data.incidence_3.to_dense() + ).all(), "Something is wrong with incidence_3 (triangles to tetrahedrons)." From 2b377677d92229e3c71ba1a7f16172df10845a35 Mon Sep 17 00:00:00 2001 From: Theo Long Date: Sat, 29 Jun 2024 17:45:46 -0400 Subject: [PATCH 06/15] ruff fixes --- modules/data/load/loaders.py | 6 +++--- modules/data/utils/utils.py | 14 ++++++-------- .../liftings/pointcloud2simplicial/base.py | 3 +-- .../alpha_complex_lifting.ipynb | 2 -- 4 files changed, 10 insertions(+), 15 deletions(-) diff --git a/modules/data/load/loaders.py b/modules/data/load/loaders.py index 6e8a7b7e..291fe023 100755 --- a/modules/data/load/loaders.py +++ b/modules/data/load/loaders.py @@ -1,5 +1,5 @@ import os -from typing import Callable, Optional +from collections.abc import Callable import numpy as np import rootutils @@ -227,8 +227,8 @@ class PointCloudLoader(AbstractLoader): def __init__( self, parameters: DictConfig, - feature_generator: Optional[Callable[[torch.Tensor], torch.Tensor]] = None, - target_generator: Optional[Callable[[torch.Tensor], torch.Tensor]] = None, + feature_generator: Callable[[torch.Tensor], torch.Tensor] | None = None, + target_generator: Callable[[torch.Tensor], torch.Tensor] | None = None, ): self.feature_generator = feature_generator self.target_generator = target_generator diff --git a/modules/data/utils/utils.py b/modules/data/utils/utils.py index d8c2aa9e..fd2aefec 100755 --- a/modules/data/utils/utils.py +++ b/modules/data/utils/utils.py @@ -1,7 +1,7 @@ import hashlib import os.path as osp import pickle -from typing import Callable, Optional +from collections.abc import Callable import networkx as nx import numpy as np @@ -431,8 +431,8 @@ def make_hash(o): def load_gudhi_dataset( cfg: omegaconf.DictConfig, - feature_generator: Optional[Callable[[torch.Tensor], torch.Tensor]] = None, - target_generator: Optional[Callable[[torch.Tensor], torch.Tensor]] = 
None, + feature_generator: Callable[[torch.Tensor], torch.Tensor] | None = None, + target_generator: Callable[[torch.Tensor], torch.Tensor] | None = None, ) -> torch_geometric.data.Data: """Load a dataset from the gudhi.datasets module.""" if not cfg.data_name.startswith("gudhi_"): @@ -489,13 +489,12 @@ def load_gudhi_dataset( ) elif gudhi_dataset_name == "daily_activities": # Target is the activity type - # 14. for ‘cross_training’, 18. for ‘jumping’, 13. for ‘stepper’, or 9. for ‘walking’ + # 14. for 'cross_training', 18. for 'jumping', 13. for 'stepper', or 9. for 'walking' y = torch.tensor(data[:, 3:], dtype=torch.float) else: y = None - data = torch_geometric.data.Data(x=x, y=y, pos=pos, complex_dim=0) - return data + return torch_geometric.data.Data(x=x, y=y, pos=pos, complex_dim=0) def load_random_points( @@ -510,8 +509,7 @@ def load_random_points( ) features = torch.tensor(rng.integers(2, size=(num_samples, 1)), dtype=torch.float) - data = torch_geometric.data.Data(x=features, y=classes, pos=points, complex_dim=0) - return data + return torch_geometric.data.Data(x=features, y=classes, pos=points, complex_dim=0) def load_manual_points(): diff --git a/modules/transforms/liftings/pointcloud2simplicial/base.py b/modules/transforms/liftings/pointcloud2simplicial/base.py index 88c3137f..0d1e5031 100755 --- a/modules/transforms/liftings/pointcloud2simplicial/base.py +++ b/modules/transforms/liftings/pointcloud2simplicial/base.py @@ -27,8 +27,7 @@ def _get_lifted_topology(self, simplicial_complex: SimplicialComplex) -> dict: dict The lifted topology. 
""" - lifted_topology = get_complex_connectivity(simplicial_complex, self.complex_dim) - return lifted_topology + return get_complex_connectivity(simplicial_complex, self.complex_dim) def __init__(self, complex_dim=2, **kwargs): super().__init__(**kwargs) diff --git a/tutorials/pointcloud2simplicial/alpha_complex_lifting.ipynb b/tutorials/pointcloud2simplicial/alpha_complex_lifting.ipynb index 603674b6..1c4f9452 100644 --- a/tutorials/pointcloud2simplicial/alpha_complex_lifting.ipynb +++ b/tutorials/pointcloud2simplicial/alpha_complex_lifting.ipynb @@ -105,8 +105,6 @@ } ], "source": [ - "import torch\n", - "\n", "dataset_name = \"manual_points\"\n", "dataset_config = load_dataset_config(dataset_name)\n", "\n", From 97c37604446ca31a81337694afcb78068d59edd7 Mon Sep 17 00:00:00 2001 From: Theo Long Date: Sat, 29 Jun 2024 17:55:14 -0400 Subject: [PATCH 07/15] handle tensor values --- modules/utils/utils.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/modules/utils/utils.py b/modules/utils/utils.py index 7a1e9275..3bbdb385 100644 --- a/modules/utils/utils.py +++ b/modules/utils/utils.py @@ -136,7 +136,7 @@ def describe_data(dataset: torch_geometric.data.Dataset, idx_sample: int = 0): if hasattr(data, "num_edges") and data.num_edges: complex_dim.append(data.num_edges) features_dim.append(data.num_edge_features) - elif hasattr(data, "edge_index") and data.edge_index: + elif hasattr(data, "edge_index") and (data.edge_index is not None): complex_dim.append(data.edge_index.shape[1]) features_dim.append(data.edge_attr.shape[1]) # Check if the data object contains hyperedges @@ -175,7 +175,11 @@ def describe_data(dataset: torch_geometric.data.Dataset, idx_sample: int = 0): ) print(f" - Features dimensions: {features_dim}") # Check if there are isolated nodes - if hasattr(data, "edge_index") and hasattr(data, "x") and data.edge_index: + if ( + hasattr(data, "edge_index") + and hasattr(data, "x") + and (data.edge_index is not None) + ): 
connected_nodes = torch.unique(data.edge_index) isolated_nodes = [] for i in range(data.x.shape[0]): From 2777da1f6498c7bdb3c886ee6928b3b6f5c0e2b2 Mon Sep 17 00:00:00 2001 From: Theo Long Date: Sat, 29 Jun 2024 18:53:43 -0400 Subject: [PATCH 08/15] wip feature-based rips lifting --- .../alpha_complex_lifting.yaml | 7 +- modules/transforms/data_transform.py | 6 +- .../alpha_complex_lifting.py | 43 ---------- .../feature_rips_complex_lifting.py | 80 +++++++++++++++++++ ...y => test_feature_rips_complex_lifting.py} | 12 +-- ...ynb => feature_rips_complex_lifting.ipynb} | 15 ++-- 6 files changed, 103 insertions(+), 60 deletions(-) delete mode 100644 modules/transforms/liftings/pointcloud2simplicial/alpha_complex_lifting.py create mode 100644 modules/transforms/liftings/pointcloud2simplicial/feature_rips_complex_lifting.py rename test/transforms/liftings/pointcloud2simplicial/{test_alpha_complex_lifting.py => test_feature_rips_complex_lifting.py} (84%) rename tutorials/pointcloud2simplicial/{alpha_complex_lifting.ipynb => feature_rips_complex_lifting.ipynb} (97%) diff --git a/configs/transforms/liftings/pointcloud2simplicial/alpha_complex_lifting.yaml b/configs/transforms/liftings/pointcloud2simplicial/alpha_complex_lifting.yaml index 6b8771a5..9aa266c3 100644 --- a/configs/transforms/liftings/pointcloud2simplicial/alpha_complex_lifting.yaml +++ b/configs/transforms/liftings/pointcloud2simplicial/alpha_complex_lifting.yaml @@ -1,7 +1,8 @@ transform_type: 'lifting' -transform_name: "AlphaComplexLifting" +transform_name: "FeatureRipsComplexLifting" complex_dim: 3 -alpha: 25.0 +feature_percent: 0.2 +max_edge_length: 10.0 +sparse: null preserve_edge_attr: False -signed: True feature_lifting: ProjectionSum diff --git a/modules/transforms/data_transform.py b/modules/transforms/data_transform.py index 4376966c..71c0b552 100755 --- a/modules/transforms/data_transform.py +++ b/modules/transforms/data_transform.py @@ -15,8 +15,8 @@ from 
modules.transforms.liftings.graph2simplicial.clique_lifting import ( SimplicialCliqueLifting, ) -from modules.transforms.liftings.pointcloud2simplicial.alpha_complex_lifting import ( - AlphaComplexLifting, +from modules.transforms.liftings.pointcloud2simplicial.feature_rips_complex_lifting import ( + FeatureRipsComplexLifting, ) TRANSFORMS = { @@ -27,7 +27,7 @@ # Graph -> Cell Complex "CellCycleLifting": CellCycleLifting, # Point Cloud -> Simplicial Complex, - "AlphaComplexLifting": AlphaComplexLifting, + "FeatureRipsComplexLifting": FeatureRipsComplexLifting, # Feature Liftings "ProjectionSum": ProjectionSum, # Data Manipulations diff --git a/modules/transforms/liftings/pointcloud2simplicial/alpha_complex_lifting.py b/modules/transforms/liftings/pointcloud2simplicial/alpha_complex_lifting.py deleted file mode 100644 index 5dc06516..00000000 --- a/modules/transforms/liftings/pointcloud2simplicial/alpha_complex_lifting.py +++ /dev/null @@ -1,43 +0,0 @@ -import gudhi -import torch_geometric -from toponetx.classes import SimplicialComplex - -from modules.transforms.liftings.pointcloud2simplicial.base import ( - PointCloud2SimplicialLifting, -) - - -class AlphaComplexLifting(PointCloud2SimplicialLifting): - r"""Lifts point clouds to simplicial complex domain by generating the alpha complex using the Gudhi library. The alpha complex is a simplicial complex constructed from the finite cells of a Delaunay Triangulation. It has the same persistent homology as the Čech complex and is significantly smaller. - - Parameters - ---------- - **kwargs : optional - Additional arguments for the class. - """ - - def __init__(self, alpha: float, **kwargs): - self.alpha = alpha - super().__init__(**kwargs) - - def lift_topology(self, data: torch_geometric.data.Data) -> dict: - r"""Lifts the topology of a point cloud to the alpha complex. - - Parameters - ---------- - data : torch_geometric.data.Data - The input data to be lifted. - - Returns - ------- - dict - The lifted topology. 
- """ - ac = gudhi.AlphaComplex(data.pos) - stree = ac.create_simplex_tree() - stree.prune_above_filtration(self.alpha) - stree.prune_above_dimension(self.complex_dim) - sc = SimplicialComplex(s for s, filtration_value in stree.get_simplices()) - lifted_topolgy = self._get_lifted_topology(sc) - lifted_topolgy["x_0"] = data.x - return lifted_topolgy diff --git a/modules/transforms/liftings/pointcloud2simplicial/feature_rips_complex_lifting.py b/modules/transforms/liftings/pointcloud2simplicial/feature_rips_complex_lifting.py new file mode 100644 index 00000000..36ff4152 --- /dev/null +++ b/modules/transforms/liftings/pointcloud2simplicial/feature_rips_complex_lifting.py @@ -0,0 +1,80 @@ +import gudhi +import torch +import torch_geometric +from toponetx.classes import SimplicialComplex + +from modules.transforms.liftings.pointcloud2simplicial.base import ( + PointCloud2SimplicialLifting, +) + + +class FeatureRipsComplexLifting(PointCloud2SimplicialLifting): + r"""Lifts point clouds to simplicial complex domain by generating the Vietoris-Rips complex using the Gudhi library. This complex is constructed in two steps - first add edges for all pairs of vertices a distance < d away from each other. Then generate the clique complex of the graph. Note that this implementation allows for *feature-based* distances as well - the distance of two nodes is a function of both the position and the features. + + If using feature-based distances, it is recommended that the features be normalized or standardized beforehand. This is because the Euclidean distance of the features will be used directly, so if the scales are wildly different the distances may be dominated by a small number of features with large magnitudes. + + Parameters + ---------- + max_edge_length: float + The maximum pairwise distance to add an edge to the graph + + feature_percent: float + The percentage weight to give the feature-based distance (should be between 0 and 1) + + sparse: float or None.
+ If float, uses a sparse approximation to the Rips complex to speed up computation. + + **kwargs : optional + Additional arguments for the class. + """ + + def __init__(self, max_edge_length: float, feature_percent: float, **kwargs): + self.feature_percent = feature_percent + self.max_edge_length = max_edge_length + self.sparse = self.sparse + super().__init__(**kwargs) + + def generate_distance_matrix(self, data): + """Generate the pairwise distance matrix of point cloud data, based on both point distances and feature-based distance. + + Parameters + ---------- + data : torch_geometric.data.Data + The input data to be lifted. + + Returns + ------- + torch.tensor + The pairwise distances. + """ + x_expanded_1 = data.x.unsqueeze(1) + x_expanded_2 = data.x.unsqueeze(0) + + # Calculate pairwise differences + feature_differences = x_expanded_1 - x_expanded_2 + + pass + + def lift_topology(self, data: torch_geometric.data.Data) -> dict: + r"""Lifts the topology of a point cloud to the Rips complex based on point-wise and feature-based distances. + + Parameters + ---------- + data : torch_geometric.data.Data + The input data to be lifted. + + Returns + ------- + dict + The lifted topology. 
+ """ + dm = self.generate_distance_matrix(data.pos, data.x) + sc = gudhi.RipsComplex( + distance_matrix=dm, sparse=self.sparse, max_edge_length=self.max_edge_length + ) + stree = sc.create_simplex_tree() + stree.prune_above_dimension(self.complex_dim) + sc = SimplicialComplex(s for s, filtration_value in stree.get_simplices()) + lifted_topolgy = self._get_lifted_topology(sc) + lifted_topolgy["x_0"] = data.x + return lifted_topolgy diff --git a/test/transforms/liftings/pointcloud2simplicial/test_alpha_complex_lifting.py b/test/transforms/liftings/pointcloud2simplicial/test_feature_rips_complex_lifting.py similarity index 84% rename from test/transforms/liftings/pointcloud2simplicial/test_alpha_complex_lifting.py rename to test/transforms/liftings/pointcloud2simplicial/test_feature_rips_complex_lifting.py index 0fe8af35..5ef8ec4d 100644 --- a/test/transforms/liftings/pointcloud2simplicial/test_alpha_complex_lifting.py +++ b/test/transforms/liftings/pointcloud2simplicial/test_feature_rips_complex_lifting.py @@ -1,22 +1,22 @@ -"""Test the alpha complex lifting.""" +"""Test the feature-based Rips complex lifting.""" import torch from modules.data.utils.utils import load_manual_points -from modules.transforms.liftings.pointcloud2simplicial.alpha_complex_lifting import ( - AlphaComplexLifting, +from modules.transforms.liftings.pointcloud2simplicial.feature_rips_complex_lifting import ( + FeatureRipsComplexLifting, ) -class TestSimplicialCliqueLifting: - """Test the SimplicialCliqueLifting class.""" +class TestFeatureRipsComplexLifting: + """Test the FeatureRipsComplexLifting class.""" def setup_method(self): # Load the graph self.data = load_manual_points() # Initialise the SimplicialCliqueLifting class - self.lifting = AlphaComplexLifting(complex_dim=3, alpha=25.0) + self.lifting = FeatureRipsComplexLifting(complex_dim=3, alpha=25.0) def test_lift_topology(self): """Test the lift_topology method.""" diff --git 
a/tutorials/pointcloud2simplicial/alpha_complex_lifting.ipynb b/tutorials/pointcloud2simplicial/feature_rips_complex_lifting.ipynb similarity index 97% rename from tutorials/pointcloud2simplicial/alpha_complex_lifting.ipynb rename to tutorials/pointcloud2simplicial/feature_rips_complex_lifting.ipynb index 1c4f9452..948da7fa 100644 --- a/tutorials/pointcloud2simplicial/alpha_complex_lifting.ipynb +++ b/tutorials/pointcloud2simplicial/feature_rips_complex_lifting.ipynb @@ -4,7 +4,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "# Point Cloud-to-Simplicial Alpha Complex Lifting Tutorial" + "# Point Cloud-to-Simplicial Feature-Based Vietoris Rips Complex Lifting Tutorial" ] }, { @@ -169,11 +169,16 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "In this section we will instantiate the lifting we want to apply to the data. For this example the alpha complex lifting was chosen. The alpha complex is a subcomplex of the Delaunay triangulation [1]. It is generated by filtering the simplices of the Delaunay triangulation such that only those with diameter less than $\\alpha^2$ are retained. The GUDHI library is used to compute the complex [2].\n", + "In this section we will instantiate the lifting we want to apply to the data. For this example the Vietoris-Rips (or Rips) complex lifting was chosen. It is generated in two steps - first pairwise distances between points are computed, and for any pair with distance $\\le d$, an edge is added. The simplicial complex is then the clique complex of this graph. \n", + "\n", + "Note that this implementation *includes features* in the pairwise distance. More precisely, the parameter `feature_percent` allows the user to adjust how much weight to give to feature differences in the overall distance. 
Currently the feature distance is simply the Euclidean distance between feature vectors, which is combined with the usual Euclidean distance to generate the final pairwise distances.\n", + "\n", + "This can be useful in cases where the features correspond to a *signal* on a surface. For example, temperature measurements on the surface of a human body. If your arms are by your side, a naïve point cloud complex might interpret the points in your hand as being very close to those of your torso. However, your hand is in fact quite far since the 'true' distance goes up your arm. The feature-based lifting may be able to distinguish this since the feature values are likely quite different, hence the points will not be connected in the lifting.\n", + "\n", + "The GUDHI library is used to compute the complex [1].\n", "\n", "---\n", - "[[1]](https://en.wikipedia.org/wiki/Delaunay_triangulation) Delauny Triangulation Wikipedia\n", - "[[2]](https://gudhi.inria.fr/python/latest/alpha_complex_user.html#) Gudhi Alpha Complex User Manual\n", + "[[1]](https://gudhi.inria.fr/python/latest/rips_complex_user.html) Gudhi Rips Complex User Manual\n", "\n", "---\n", "\n", @@ -212,7 +217,7 @@ "# Define transformation type and id\n", "transform_type = \"liftings\"\n", "# If the transform is a topological lifting, it should include both the type of the lifting and the identifier\n", - "transform_id = \"pointcloud2simplicial/alpha_complex_lifting\"\n", + "transform_id = \"pointcloud2simplicial/feature_rips_complex_lifting\"\n", "\n", "# Read yaml file\n", "transform_config = {\n", From 53818eb28616d4b06caa99a0925adcce15a3aa9a Mon Sep 17 00:00:00 2001 From: Theo Long Date: Sat, 6 Jul 2024 19:40:22 -0400 Subject: [PATCH 09/15] remove old config --- .../pointcloud2simplicial/alpha_complex_lifting.yaml | 8 -------- 1 file changed, 8 deletions(-) delete mode 100644 configs/transforms/liftings/pointcloud2simplicial/alpha_complex_lifting.yaml diff --git
a/configs/transforms/liftings/pointcloud2simplicial/alpha_complex_lifting.yaml b/configs/transforms/liftings/pointcloud2simplicial/alpha_complex_lifting.yaml deleted file mode 100644 index 9aa266c3..00000000 --- a/configs/transforms/liftings/pointcloud2simplicial/alpha_complex_lifting.yaml +++ /dev/null @@ -1,8 +0,0 @@ -transform_type: 'lifting' -transform_name: "FeatureRipsComplexLifting" -complex_dim: 3 -feature_percent: 0.2 -max_edge_length: 10.0 -sparse: null -preserve_edge_attr: False -feature_lifting: ProjectionSum From 118a3dc69d77fd301102c7a05cf6585ce322af3f Mon Sep 17 00:00:00 2001 From: Theo Long Date: Sat, 6 Jul 2024 19:40:28 -0400 Subject: [PATCH 10/15] add new config --- .../pointcloud2simplicial/feature_rips_lifting.yaml | 7 +++++++ 1 file changed, 7 insertions(+) create mode 100644 configs/transforms/liftings/pointcloud2simplicial/feature_rips_lifting.yaml diff --git a/configs/transforms/liftings/pointcloud2simplicial/feature_rips_lifting.yaml b/configs/transforms/liftings/pointcloud2simplicial/feature_rips_lifting.yaml new file mode 100644 index 00000000..14aa2d21 --- /dev/null +++ b/configs/transforms/liftings/pointcloud2simplicial/feature_rips_lifting.yaml @@ -0,0 +1,7 @@ +transform_type: 'lifting' +transform_name: "FeatureRipsComplexLifting" +complex_dim: 3 +feature_percent: 0.2 +max_edge_length: 10.0 +sparse: null +feature_lifting: ProjectionSum From afbc8e16149d56199851e410fa32909ce146ea07 Mon Sep 17 00:00:00 2001 From: Theo Long Date: Sat, 6 Jul 2024 19:41:02 -0400 Subject: [PATCH 11/15] rips lifting --- .../feature_rips_complex_lifting.py | 52 ++++++++++++++----- modules/utils/utils.py | 36 +++++++++++++ 2 files changed, 76 insertions(+), 12 deletions(-) diff --git a/modules/transforms/liftings/pointcloud2simplicial/feature_rips_complex_lifting.py b/modules/transforms/liftings/pointcloud2simplicial/feature_rips_complex_lifting.py index 36ff4152..817b4a2f 100644 --- 
a/modules/transforms/liftings/pointcloud2simplicial/feature_rips_complex_lifting.py +++ b/modules/transforms/liftings/pointcloud2simplicial/feature_rips_complex_lifting.py @@ -6,6 +6,7 @@ from modules.transforms.liftings.pointcloud2simplicial.base import ( PointCloud2SimplicialLifting, ) +from modules.utils.utils import add_epsilon_to_zeros, calculate_pairwise_differences class FeatureRipsComplexLifting(PointCloud2SimplicialLifting): @@ -19,19 +20,34 @@ class FeatureRipsComplexLifting(PointCloud2SimplicialLifting): The maximum pairwise distance to add an edge to the graph feature_percent: float - The percentage weight to give the feature-based distance (should be between 0 and 1) + The percentage weight to give the feature-based distance (should be between 0 and 1). 0 corresponds to the usual Rips Complex, while 1 corresponds to a complex generated using only feature-based distances. sparse: float or None. If float, uses a sparse approximation to the Rips complex to speed up computation. + epsilon: float + A small value that gets added to 0 values in the pairwise distance matrix. This is only used to handle an edge case in gudhi where it treats points with distance 0 as the same, and should rarely need to be modified. + **kwargs : optional + Additional arguments for the class. """ - def __init__(self, max_edge_length: float, feature_percent: float, **kwargs): + def __init__( + self, + max_edge_length: float, + feature_percent: float, + sparse: bool = False, + epsilon: float = 1e-8, + **kwargs, + ): + if feature_percent < 0 or feature_percent > 1: + raise ValueError( + "feature_percent must be a value between 0 and 1 inclusive." + ) self.feature_percent = feature_percent self.max_edge_length = max_edge_length - self.sparse = self.sparse + self.sparse = sparse + self.epsilon = epsilon super().__init__(**kwargs) def generate_distance_matrix(self, data): @@ -47,13 +63,26 @@ def generate_distance_matrix(self, data): torch.tensor The pairwise distances.
""" - x_expanded_1 = data.x.unsqueeze(1) - x_expanded_2 = data.x.unsqueeze(0) + pairwise_distances = torch.zeros( + ( + data.pos.shape[0], + data.pos.shape[0], + ) + ) + + if self.feature_percent > 0: + feature_differences = calculate_pairwise_differences(data.x) + pairwise_distances += self.feature_percent * torch.linalg.norm( + feature_differences, dim=-1 + ) - # Calculate pairwise differences - feature_differences = x_expanded_1 - x_expanded_2 + if self.feature_percent < 1.0: + position_differences = calculate_pairwise_differences(data.pos) + pairwise_distances += (1 - self.feature_percent) * torch.linalg.norm( + position_differences, dim=-1 + ) - pass + return add_epsilon_to_zeros(pairwise_distances, self.epsilon) def lift_topology(self, data: torch_geometric.data.Data) -> dict: r"""Lifts the topology of a point cloud to the Rips complex based on point-wise and feature-based distances. @@ -68,12 +97,11 @@ def lift_topology(self, data: torch_geometric.data.Data) -> dict: dict The lifted topology. """ - dm = self.generate_distance_matrix(data.pos, data.x) - sc = gudhi.RipsComplex( + dm = self.generate_distance_matrix(data) + gudhi_sc = gudhi.RipsComplex( distance_matrix=dm, sparse=self.sparse, max_edge_length=self.max_edge_length ) - stree = sc.create_simplex_tree() - stree.prune_above_dimension(self.complex_dim) + stree = gudhi_sc.create_simplex_tree(max_dimension=self.complex_dim) sc = SimplicialComplex(s for s, filtration_value in stree.get_simplices()) lifted_topolgy = self._get_lifted_topology(sc) lifted_topolgy["x_0"] = data.x diff --git a/modules/utils/utils.py b/modules/utils/utils.py index 3bbdb385..1d17e17b 100644 --- a/modules/utils/utils.py +++ b/modules/utils/utils.py @@ -553,3 +553,39 @@ def describe_hypergraph(data: torch_geometric.data.Data): if he_idx >= 10: print("...") break + + +def calculate_pairwise_differences(x: torch.Tensor): + r"""Generate tensor of pairwise differences between each row of x. 
+ + Parameters + ---------- + x : torch.Tensor + 2-dimensional tensor. + + Returns + ------- + torch.Tensor + Tensor of pairwise differences of rows of x. + """ + x_expanded_1 = x.unsqueeze(1) + x_expanded_2 = x.unsqueeze(0) + return x_expanded_1 - x_expanded_2 + + +def add_epsilon_to_zeros(tensor, epsilon=1e-8): + """Add a small epsilon value to off-diagonal elements which are zero. This is useful for using the gudhi library since it treats pairwise distances which are 0 as the same point.""" + + # Create a mask for non-diagonal elements + non_diagonal_mask = ~torch.eye( + tensor.shape[0], dtype=torch.bool, device=tensor.device + ) + + # Create a tensor with epsilon values where the original tensor is 0 + epsilon_tensor = torch.zeros_like(tensor) + epsilon_tensor[non_diagonal_mask & (tensor == 0)] = epsilon + + # Add the epsilon tensor to the original tensor + result = tensor + epsilon_tensor + + return result From 00f744dd2a45a5aaa0fa964bbbf869c0eaedc1be Mon Sep 17 00:00:00 2001 From: Theo Long Date: Sat, 6 Jul 2024 19:41:27 -0400 Subject: [PATCH 12/15] update notebook and dataset --- modules/data/utils/utils.py | 2 + .../feature_rips_complex_lifting.ipynb | 292 +++++++++++++++++- 2 files changed, 281 insertions(+), 13 deletions(-) diff --git a/modules/data/utils/utils.py b/modules/data/utils/utils.py index fd2aefec..9982cb1b 100755 --- a/modules/data/utils/utils.py +++ b/modules/data/utils/utils.py @@ -526,5 +526,7 @@ def load_manual_points(): dtype=torch.float, ) x = torch.ones_like(pos, dtype=torch.float) + x[:4] = 0.0 + x[3, 1] = 1.0 y = torch.randint(0, 2, (pos.shape[0],), dtype=torch.float) return torch_geometric.data.Data(x=x, y=y, pos=pos, complex_dim=0) diff --git a/tutorials/pointcloud2simplicial/feature_rips_complex_lifting.ipynb b/tutorials/pointcloud2simplicial/feature_rips_complex_lifting.ipynb index 948da7fa..d8021b57 100644 --- a/tutorials/pointcloud2simplicial/feature_rips_complex_lifting.ipynb +++ 
b/tutorials/pointcloud2simplicial/feature_rips_complex_lifting.ipynb @@ -33,7 +33,7 @@ "* gudhi_spiral_2d\n", "* gudhi_daily_activities\n", "\n", - "Please see the gudhi documentation [1] for a description of these datasets and the relevant config options. Note that *all datasets except gudhi_daily_activities lack features and targets*. You must instead provide a feature_generator target_generator in the PointCloudLoader\n", + "Please see the gudhi documentation [1] for a description of these datasets and the relevant config options. Note that *all datasets except gudhi_daily_activities lack features and targets*. You must instead provide a feature_generator and target_generator in the PointCloudLoader\n", "\n", "[[1]](https://gudhi.inria.fr/python/latest/index.html) GUDHI Python documentation\n", "***" @@ -133,6 +133,14 @@ "Dataset only contains 1 sample:\n" ] }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Processing...\n", + "Done!\n" + ] + }, { "data": { "image/png": 
"iVBORw0KGgoAAAANSUhEUgAAAqIAAAKlCAYAAAAQBI5LAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjkuMCwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy80BEi2AAAACXBIWXMAAA9hAAAPYQGoP6dpAAAnZklEQVR4nO3dv2+T977A8Y9ppbAQGx+pKkcnAw8b6uQklc5MInVhS3qXriTqeodETIgpSv6BKrB2gXhjqRSffwDKM1VseRiQDkeVbuKEhQzFd+DGl4ADCRg++fF6SdaJH3/tfKk57rvf54drvV6vFwAA8IWdy54AAABnkxAFACCFEAUAIIUQBQAghRAFACCFEAUAIIUQBQAghRAFACCFEAUAIIUQBU6dlZWVqNVqcfHixbh48WLUarW4cuVKLC4ufvJrrKysHGku4+PjMT8/f9Q/wpF0u91YXFyM8fHx/pxnZ2ejqqqIiJienj7Sn/2oyrKMWq322V4fOL2EKHDidLvdmJ2djYsXLx4YmI1GI7a2tmJrayt6vV6sra1FWZZx5cqVfqB9yKDXuHfvXszOzh56rjdv3jzS+KOqqirGx8ejLMtYXl6Ora2t+Ne//hXNZjO63e5n+70AwyBEgRPnxo0bMT09HU+fPo3V1dVot9sfjL1WqxXr6+tRFMVHr1C2Wq1YW1uLdrsdZVke6jkzMzMxNTV16N/R6XTiypUrhx4/PT3d/7NNTU1Fo9GIVqsVq6ur0Wq1Dv06ABmEKHCiVFUV3W435ubmotFoxNTUVD9GD7MCuLq6Gp1OJzqdzkf9/qIootFoxO+///5Rzx+mO3fuRFVVcffu3eypAHwUIQqcKEVRxOrq6r5tExMTERGHisOiKKIoilhfX/+o378Xwnu/MyJicXExrly5EhcvXnxntfXt4zNnZ2djZWUl5ufn+4cW7EXx7OxsTE9PR1VVUavVolarvTeu31wFPaqjzPntY0C73W5MT09HrVaL8fHxj456ACEKnDhFUey7vxegb8bhh55/2ONE39TpdGJ6ejpmZmb6u71nZ2ejLMtYX1+Pp0+fxubmZkxPTx/4GnsnFs3OzsbTp0+j1Wr1Q3BtbS3W1taiKIro9XrR6/XeG5llWb7zz+IwjjrnQc/f3NyMjY2N+Ne//hWPHj068hwAIiK+zp4AwKdaXl6OhYWFI60MHmY3frfb3bcSuBeNCwsLEfE6BNvtdmxtbfV/99raWly8eDE6nc6Bx4a2Wq3+Y/Pz80eKwE/1sXPeU1VVdDqd2NjY6EfwzZs3o91uf+6pA6eQFVHgRJudnY1WqxXLy8uHfk5VVYc6kafRaPRXJnu9Xjx+/LgfoRGvV2L3jhl908TExHt3/b+5cttsNg8977e1Wq0jr+x+7Jz3lGUZjUbjo1ZiAd4mRIETa3Z2NoqiOHKEVlU1lFXIj7080scc0znI9PR0dDqdI83DJZ2A40SIAifS3ok9R4nQiNcn6by5a/xTTE1N9U9eetPvv/8ek5OTn/z6HzI3NxdFUcSNGzcO/ZyPmfPm5mb/56IootvtftQxtgBvE6LAiTM7OxuTk5Px448/Rrfb7d/eZ28VtCzLWFtbG8o89oL22rVr/bjbW6WdmZn5qNfcO5Gq2+1Gp9P5YPCtr69Hp9Ppn4AU8frPuri4OPBboA4z56Io3nmtN5/farVidna2H6RHCWGANwlR4ESpqira7XYsLi72v35z7/bmZYS63e6+r+ecnp6Ooiji8ePHQz2+ce8SSuPj43H58uVoNpvx+PHjj369vdC7fPnyoVZ7i6KIp0+fRrPZjNnZ2f4llbrd7oEx/KE5z8/Px++//96/tNP8/Py+f2Z739x00OMAh1Xr9Xq97EkAAHD2WBEFACCFEAUAIIUQBQAghRAFACCFEAUAIMWJ+q75V69exb///e+4cOHCvu9/BgDgeOj1evHixYv4+9//HufOvX/N80SF6L///e8YGxvLngYAAB/w7Nmz+Mc//vH
eMScqRC9cuBARr/9go6OjybMBAOBtOzs7MTY21u+29zlRIbq3O350dFSIAgAcY4c5jNLJSgAApBCiAACkEKIAAKQQogAApBCiAACkEKIAAKQQogAApBCiAACkEKIAAKQQogAApBCiAACkEKIAAKQQogAApBCiAACkEKIAAKQQogAApBCiAACkEKIAAKQQogAApBCiAACk+Dp7AnCc/PWqFw+fbsafL17GNxfOx/eXm/HVuVr2tADgVBKi8H9+++N53H7wJJ5vv+xvu1Q/H7euX40fvruUODMAOJ3smod4HaE//1rui9CIiP9sv4yffy3jtz+eJ80MAE4vIcqZ99erXtx+8CR6Ax7b23b7wZP469WgEQDAxxKinHkPn26+sxL6pl5EPN9+GQ+fbn65SQHAGSBEOfP+fHFwhH7MOADgcIQoZ943F84PdRwAcDhClDPv+8vNuFQ/HwddpKkWr8+e//5y80tOCwBOPSHKmffVuVrcun41IuKdGN27f+v6VdcTBYAhE6IQET98dyl++akV39b3737/tn4+fvmp5TqiAPAZuKA9/J8fvrsU01e/9c1KAPCFCFF4w1fnavHPK3/LngYAnAl2zQMAkEKIAgCQQogCAJBCiAIAkEKIAgCQQogCAJBCiAIAkEKIAgCQQogCAJBCiAIAkEKIAgCQQogCAJBCiAIAkEKIAgCQQogCAJBCiAIAkEKIAgCQQogCAJBCiAIAkEKIAgCQQogCAJBCiAIAkEKIAgCQQogCAJBCiAIAkEKIAgCQQogCAJBCiAIAkEKIAgCQQogCAJBCiAIAkEKIAgCQQogCAJDiSCFalmWMj4+/s73dbke3241ut3uo1yjLMiIiqqrq/wwAwNly6BBtt9sREQPDcXZ2Ni5evBgXL16MWq0WtVotVlZWBr7O6upqjI+PR61Wi/n5+SiK4iOnDgDASfb1YQfOzMwM3N7tdmNtbW3f4ysrK7GwsDBw/Pj4eGxtbUVERKPROMJUAQA4TQ4dou/zZoS22+0Do3XPYQN0d3c3dnd3+/d3dnY+an4AABw/n3yy0ptR2e12Y3Nz872727vdbrTb7Wi327G4uBhVVR04dmlpKer1ev82Njb2qdMFAOCYqPV6vd6RnlCrxUFPmZ+fj+Xl5feueHa73f7jZVnG7OxsbGxsDBw7aEV0bGwstre3Y3R09CjTBgDgC9jZ2Yl6vX6oXhva5Zu63W50Op0P7nZ/cwW0KIqoqurAVdGRkZEYHR3ddwMA4HQYWoj+/vvvH4zQsizj2rVr72xvNpvDmgYAACfER4XooOuFlmU5MCjLsuyveBZFEcvLy/3HOp1OzMzMOHseAOAMOvRZ851OJ9bX1yPi9UlEk5OT75wdP+gkpb2xCwsL0Wg0YmJiIlZWVqLRaMTGxkasra194h8BAICT6MgnK2U6ysGvAAB8eSknKwEAwFEIUQAAUghRAABSDOUrPgEAOH7+etWLh083488XL+ObC+fj+8vN+OpcLXtafUIUAOAU+u2P53H7wZN4vv2yv+1S/Xzcun41fvjuUuLM/p9d8wAAp8xvfzyPn38t90VoRMR/tl/Gz7+W8dsfz5Nmtp8QBQA4Rf561YvbD57EoOtz7m27/eBJ/PUq/wqeQhQA4BR5+HTznZXQN/Ui4vn2y3j4dPPLTeoAQhQA4BT588XBEfox4z4nIQoAcIp8c+H8UMd9TkIUAOAU+f5yMy7Vz8dBF2mqxeuz57+/3PyS0xpIiAIAnCJfnavFretXIyLeidG9+7euXz0W1xMVogAAp8wP312KX35qxbf1/bvfv62fj19+ah2b64i6oD0AwCn0w3eXYvrqt75ZCQCAL++rc7X455W/ZU/jQHbNAwCQQogCAJBCiAIAkEKIAgCQQogCAJBCiAIAkEKIAgCQQogCAJBCiAIAkEKIAgCQQogCAJBCiAIAkEKIAgCQQogCAJBCiAIAkEKIAgCQQogCAJBCiAIAkEKIAgCQQogCAJBCiAI
AkEKIAgCQQogCAJBCiAIAkEKIAgCQQogCAJBCiAIAkEKIAgCQQogCAJBCiAIAkEKIAgCQQogCAJBCiAIAkEKIAgCQQogCAJBCiAIAkEKIAgCQQogCAJBCiAIAkEKIAgCQQogCAJBCiAIAkEKIAgCQQogCAJBCiAIAkEKIAgCQQogCAJBCiAIAkOJIIVqWZYyPjw/cXpZlRERUVdX/eZCqqmJlZSXa7XasrKxEt9s92owBADgVvj7swHa7HUVRDIzM1dXVuHPnTkRETE1Nxdra2oGvMzs7G48fP46I11F648aN944HAOB0OnSIzszMHPjY+Ph4bG1tRUREo9E4cFxVVfvuF0URnU7nsFMAAOAUOXSIfsj7AnRPp9OJZrO5b1uz2YyyLKPVar0zfnd3N3Z3d/v3d3Z2PnmeAAAcD0M5Wanb7Ua73Y52ux2Li4vvrHy+OW6Qzc3NgduXlpaiXq/3b2NjY8OYLgAAx8BQVkTn5ub6K6JFUcT09HRsbGwc+vkHBerNmzfjv//7v/v3d3Z2xCgAwCkxlBXRN1dAi6KIqqoGroo2Go13Vj83NzcP3K0/MjISo6Oj+24AAJwOnxyiZVnGtWvX3tn+9rGgEa/PqB9kYmLiU6cBAMAJ81Eh+uau9KIoYnl5uX+/0+nEzMxMf5WzLMv+6mhRFPtep6qqmJiYONSJTgAAnC6HPka00+nE+vp6RLw+iWhycrIfnBMTE7GyshKNRiM2Njb2XRd0b+zCwkJERKytrcXi4mJMTk7Go0ePXEMUAOCMqvV6vV72JA5rZ2cn6vV6bG9vO14UAOAYOkqv+a55AABSCFEAAFIIUQAAUghRAABSCFEAAFIIUQAAUghRAABSCFEAAFIIUQAAUghRAABSCFEAAFIIUQAAUghRAABSCFEAAFIIUQAAUghRAABSCFEAAFJ8nT2B4+yvV714+HQz/nzxMr65cD6+v9yMr87VsqcFAHAqCNED/PbH87j94Ek8337Z33apfj5uXb8aP3x3KXFmAACng13zA/z2x/P4+ddyX4RGRPxn+2X8/GsZv/3xPGlmAACnhxB9y1+venH7wZPoDXhsb9vtB0/ir1eDRgAAcFhC9C0Pn26+sxL6pl5EPN9+GQ+fbn65SQEAnEJC9C1/vjg4Qj9mHAAAgwnRt3xz4fxQxwEAMJgQfcv3l5txqX4+DrpIUy1enz3//eXml5wWAMCpI0Tf8tW5Wty6fjUi4p0Y3bt/6/pV1xMFAPhEQnSAH767FL/81Ipv6/t3v39bPx+//NRyHVEAgCFwQfsD/PDdpZi++q1vVgIA+EyE6Ht8da4W/7zyt+xpAACcSnbNAwCQQogCAJBCiAIAkEKIAgCQQogCAJBCiAIAkEKIAgCQQogCAJBCiAIAkEKIAgCQQogCAJBCiAIAkEKIAgCQQogCAJBCiAIAkEKIAgCQQogCAJBCiAIAkEKIAgCQQogCAJBCiAIAkEKIAgCQQogCAJBCiAIAkEKIAgCQQogCAJBCiAIAkEKIAgCQQogCAJBCiAIAkEKIAgCQQogCAJBCiAIAkOJIIVqWZYyPjw/cvrKyEisrKzE7Oxvdbve9r1GWZUREVFXV/xkAgLPl0CHabrcjIgaGY6fTiYWFhVhYWIjJycm4du3aga+zuroa4+PjUavVYn5+Poqi+IhpAwBw0tV6vV7vSE+o1eLNp5RlGdeuXYutra2IeL3KeeXKldjY2BgYmXfu3Ikff/wxIiIajcaRJruzsxP1ej22t7djdHT0SM8FAODzO0qvff2pv6zVasXdu3f79/d2yzebzQOfc9QABQDg9PnkEI2ImJmZ6f987969mJqaOjA2u91ufzf/o0eP3rt7fnd3N3Z3d/v3d3Z2hjFdAACOgaGE6J69yHz8+PGBY+bm5vqRWhRFTE9Px8bGxsCxS0tLcfv27WFOEQCAY2Kol29aXFyM9fX19+56r6qq/3NRFFFV1b5tb7p582Zsb2/3b8+ePRvmdAEASDS
0FdGVlZVYXFyMoij6x4m+HaRvn9i056DjSUdGRmJkZGRYUwQA4Bj5qBXRt68T2m63o9Vq9SP0/v37/Qgty7K/4lkURSwvL/ef1+l0YmZmxslLAABn0KFXRDudTqyvr0fE62M3JycnY2ZmJqqqitnZ2X1jG41GzM3N7Ru7sLAQjUYjJiYmYmVlJRqNRmxsbMTa2toQ/zgAAJwUR76OaCbXEQUAON6O0mu+ax4AgBRCFACAFEIUAIAUQhQAgBRCFACAFEIUAIAUQhQAgBRCFACAFEIUAIAUQhQAgBRCFACAFEIUAIAUQhQAgBRCFACAFEIUAIAUQhQAgBRCFACAFEIUAIAUQhQAgBRCFACAFEIUAIAUQhQAgBRCFACAFEIUAIAUQhQAgBRCFACAFEIUAIAUQhQAgBRCFACAFEIUAIAUQhQAgBRCFACAFEIUAIAUQhQAgBRCFACAFEIUAIAUQhQAgBRCFACAFEIUAIAUQhQAgBRCFACAFEIUAIAUQhQAgBRCFACAFEIUAIAUQhQAgBRCFACAFEIUAIAUQhQAgBRCFACAFEIUAIAUQhQAgBRCFACAFEIUAIAUQhQAgBRCFACAFEIUAIAUQhQAgBRCFACAFEIUAIAUQhQAgBRCFACAFEIUAIAUXx9lcFmWcePGjXj8+PG+7VVVRbvdjqIooqqqmJubi0ajMfA1jjIWAIDT69AhuhePZVm+89js7Gw/Tquqihs3bsTa2trA1znKWAAATq9Dh+jMzMzA7VVV7btfFEV0Op1PHgsAwOn2yceIdjqdaDab+7Y1m82BK6dHGRsRsbu7Gzs7O/tuAACcDp8cot1ud+D2zc3NTxobEbG0tBT1er1/Gxsb+9hpAgBwzHy2s+YPis6jjL1582Zsb2/3b8+ePRvO5AAASHeks+YHaTQa76xobm5uDjwT/ihjIyJGRkZiZGTkU6cIAMAx9MkrolNTUwO3T0xMfNJYAABOt48K0Td3pRdFse+xqqpiYmKiv8pZlmX/bPkPjQUA4Ow49K75TqcT6+vrEfH6JKLJycn+JZ3W1tZicXExJicn49GjR/uuC7o3dmFh4YNjAQA4O2q9Xq+XPYnD2tnZiXq9Htvb2zE6Opo9HQAA3nKUXvNd8wAApBCiAACkEKIAAKQQogAApBCiAACkEKIAAKQQogAApBCiAACkEKIAAKQQogAApBCiAACkEKIAAKQQogAApBCiAACkEKIAAKQQogAApBCiAACkEKIAAKQQogAApBCiAACkEKIAAKQQogAApBCiAACkEKIAAKQQogAApBCiAACkEKIAAKQQogAApBCiAACkEKIAAKQQogAApBCiAACkEKIAAKQQogAApBCiAACkEKIAAKQQogAApBCiAACkEKIAAKQQogAApBCiAACkEKIAAKQQogAApBCiAACkEKIAAKQQogAApBCiAACkEKIAAKQQogAApBCiAACkEKIAAKQQogAApBCiAACkEKIAAKQQogAApBCiAACkEKIAAKQQogAApBCiAACkEKIAAKQQogAApBCiAACkEKIAAKQYWoi22+3odrvR7XY/OLYsyyjLMiIiqqrq/wwAwNkxtBCdnZ2NixcvxsWLF6NWq0WtVouVlZWBY1dXV2N8fDxqtVrMz89HURTDmgYAACfE18N4kW63G2trazEzM9PftrKyEgsLCwPHj4+Px9bWVkRENBqNYUwBAIATZighGhH7IrTdbu+7P8hhAnR3dzd2d3f793d2dj56fgAAHC9D2TX/ZlR2u93Y3Nx87+72brcb7XY72u12LC4uRlVVA8ctLS1FvV7v38bGxoYxXQAAjoFar9frDfMF5+fnY3l5+b0rnt1ut/94WZYxOzsbGxsb74wbtCI6NjYW29vbMTo6OsxpAwAwBDs7O1Gv1w/Va0O9fFO3241Op/PB3e5vroAWRRFVVQ1cFR0ZGYnR0dF9NwAAToehhujvv//+wQgtyzKuXbv2zvZmsznMqQAAcMwNNUTLshw
YlGVZ9lc8i6KI5eXl/mOdTidmZmacPQ8AcMYM7az5PYNOUlpaWorJyclYWFiIRqMRExMTsbKyEo1GIzY2NmJtbW3Y0wAA4Jgb+slKn9NRDn4FAODLSztZCQAADkuIAgCQQogCAJBCiAIAkEKIAgCQQogCAJBCiAIAkEKIAgCQQogCAJBCiAIAkEKIAgCQQogCAJBCiAIAkEKIAgCQQogCAJBCiAIAkEKIAgCQQogCAJBCiAIAkEKIAgCQQogCAJBCiAIAkEKIAgCQQogCAJBCiAIAkEKIAgCQQogCAJBCiAIAkEKIAgCQQogCAJBCiAIAkEKIAgCQ4uvsCQBk+utVLx4+3Yw/X7yMby6cj+8vN+Orc7XsaQEMxXH/jBOiwJn12x/P4/aDJ/F8+2V/26X6+bh1/Wr88N2lxJkBfLqT8Bln1zxwJv32x/P4+ddy3wd0RMR/tl/Gz7+W8dsfz5NmBvDpTspnnBAFzpy/XvXi9oMn0Rvw2N622w+exF+vBo0AON5O0mecEAXOnIdPN99ZJXhTLyKeb7+Mh083v9ykAIbkJH3GCVHgzPnzxcEf0B8zDuA4OUmfcUIUOHO+uXB+qOMAjpOT9BknRIEz5/vLzbhUPx8HXcCkFq/PLP3+cvNLTgtgKE7SZ5wQBc6cr87V4tb1qxER73xQ792/df3qsbrWHsBhnaTPOCEKnEk/fHcpfvmpFd/W9++a+rZ+Pn75qXVsrrEH8DFOymdcrdfr5Z+7f0g7OztRr9dje3s7RkdHs6cDnALH/VtHAD5FxmfcUXrNNysBZ9pX52rxzyt/y54GwGdx3D/j7JoHACCFEAUAIIUQBQAghRAFACCFEAUAIIUQBQAghRAFACCFEAUAIIUQBQAghRAFACCFEAUAIIUQBQAghRAFACCFEAUAIIUQBQAghRAFACCFEAUAIIUQBQAgxdfDeqGyLCMiotVqRVVV0e12o9VqDRxbVVW02+0oiiKqqoq5ubloNBrDmgoAACfA0EJ0dXU17ty5ExERU1NTsba2duDY2dnZePz4cUS8jtIbN268dzwAAKfP0EJ0fHw8tra2IiLeu7pZVdW++0VRRKfTGdY0AAA4IYZ6jGij0fjgLvZOpxPNZnPftmaz2d+1DwDA2TC0FdFutxvtdjsiIh49ehTz8/NRFMXAcYNsbm6+s213dzd2d3f793d2doYzWQAA0g0tRN884agoipieno6NjY1DP39QoC4tLcXt27eHNEMAAI6Toe2af/PYz72z4d8+HjTi9e77t1c/Nzc3B+7Sv3nzZmxvb/dvz549G9Z0AQBINpQQLcsyrl279s72t48FjXh9Rv0gExMT72wbGRmJ0dHRfTcAAE6HoYRoURSxvLzcv9/pdGJmZqa/ylmWZX919O3jRquqiomJCdcRBQA4Y4ZyjGij0YiJiYlYWVmJRqMRGxsb+64LurS0FJOTk7GwsBAREWtra7G4uBiTk5Px6NEj1xAFADiDar1er5c9icPa2dmJer0e29vbdtMDABxDR+k13zUPAEAKIQoAQAohCgBACiEKAEAKIQoAQAohCgBACiEKAEAKIQoAQAohCgBACiEKAEAKIQoAQAohCgBACiEKAEAKIQoAQAohCgBACiEKAEAKIQoAQAohCgBACiEKAEAKIQoAQAohCgBACiEKAEAKIQoAQAohCgBACiEKAEAKIQoAQAohCgBACiEKAEAKIQoAQAohCgBACiEKAEAKIQoAQAohCgBACiEKAEAKIQoAQAohCgBACiEKAEAKIQoAQAohCgBACiEKAEAKIQoAQAohCgBACiEKAEAKIQoAQAohCgBACiEKAEAKIQoAQAohCgBACiEKAEAKIQoAQAohCgBACiEKAEAKIQoAQAohCgBACiEKAEAKIQoAQAohCgBACiEKAEAKIQoAQAohCgBACiEKAEAKIQoAQAohCgBAiq+H9UJlWUan04mIiEePHsXdu3ej0WgcODYiotVqRVVV0e12o9VqDWs
qAACcAENbEe10OrGwsBALCwsxOTkZ165dO3Ds6upqjI+PR61Wi/n5+SiKYljTAADghBhKiJZlGUtLS/37MzMzUZZlVFU1cPz4+HhsbW3F1tZWrK+vH7hyCgDA6TWUXfOtVivu3r3bv9/tdiMiotlsHvicw8Tn7u5u7O7u9u/v7Ox89BwBADhehrZrfmZmpv/zvXv3Ympq6sDY7Ha70W63o91ux+Li4oErp0tLS1Gv1/u3sbGxYU0XAIBktV6v1xvmC3a73RgfH4/Hjx+/N0T3HivLMmZnZ2NjY+OdcYNWRMfGxmJ7eztGR0eHOW0AAIZgZ2cn6vX6oXpt6JdvWlxc/OBxn2+ugBZFEVVVDVwVHRkZidHR0X03AABOh6GG6MrKSiwuLkZRFNHtdvvHir6pLMuBZ9S/73hSAABOn6GFaLvdjlar1Y/Q+/fv79v9vrfiWRRFLC8v95/X6XRiZmbGmfMAAGfMUI4Rraoqrly5sm9bo9GIra2tiIiYnZ2NycnJWFhYiIj/v/h9o9GIjY2NfWH6Pkc55gAAgC/vKL029JOVPichCgBwvKWerAQAAIchRAEASCFEAQBIIUQBAEghRAEASCFEAQBIIUQBAEghRAEASCFEAQBIIUQBAEghRAEASCFEAQBIIUQBAEghRAEASCFEAQBIIUQBAEghRAEASCFEAQBIIUQBAEghRAEASCFEAQBIIUQBAEghRAEASCFEAQBIIUQBAEghRAEASCFEAQBIIUQBAEghRAEASCFEAQBIIUQBAEghRAEASCFEAQBI8XX2BDh+/nrVi4dPN+PPFy/jmwvn4/vLzfjqXC17WgDAKSNE2ee3P57H7QdP4vn2y/62S/Xzcev61fjhu0uJMwMAThu75un77Y/n8fOv5b4IjYj4z/bL+PnXMn7743nSzACA00iIEhGvd8fffvAkegMe29t2+8GT+OvVoBEAAEcnRImIiIdPN99ZCX1TLyKeb7+Mh083v9ykAIBTTYgSERF/vjg4Qj9mHADAhwhRIiLimwvnhzoOAOBDhCgREfH95WZcqp+Pgy7SVIvXZ89/f7n5JacFAJxiQpSIiPjqXC1uXb8aEfFOjO7dv3X9quuJAgBDI0Tp++G7S/HLT634tr5/9/u39fPxy08t1xEFAIbKBe3Z54fvLsX01W99sxIA8NkJUd7x1bla/PPK37KnAQCccnbNAwCQQogCAJBCiAIAkEKIAgCQQogCAJBCiAIAkEKIAgCQQogCAJBCiAIAkEKIAgCQQogCAJBCiAIAkEKIAgCQQogCAJBCiAIAkEKIAgCQQogCAJDi6+wJAABfxl+vevHw6Wb8+eJlfHPhfHx/uRlfnatlT4szbGghWlVVtNvtKIoiqqqKubm5aDQanzwWAPh0v/3xPG4/eBLPt1/2t12qn49b16/GD99dSpwZZ1mt1+v1hvFC4+Pj8fjx44h4HZqLi4uxtrb2yWPftLOzE/V6Pba3t2N0dHQY0waAU++3P57Hz7+W8fa/8PfWQn/5qSVGGZqj9NpQjhGtqmrf/aIootPpfPJYAODT/PWqF7cfPHknQiOiv+32gyfx16uhrEvBkQwlRDudTjSbzX3bms1mlGX5SWN3d3djZ2dn3w0AOLyHTzf37Y5/Wy8inm+/jIdPN7/cpOD/DCVEu93uwO2bm+/+pT7K2KWlpajX6/3b2NjYp0wTAM6cP18cHKEfMw6G6bNevumg6Dzs2Js3b8b29nb/9uzZs+FNDgDOgG8unB/qOBimoZw132g03lnR3NzcHHgm/FHGjoyMxMjIyDCmCABn0veXm3Gpfj7+s/1y4HGitYj4tv76Uk7wpQ1lRXRqamrg9omJiU8aCwB8mq/O1eLW9asR8f9nye/Zu3/r+lXXEyXFUEK0KIp996uqiomJif4qZ1mW/bPlPzQWABiuH767FL/81Ipv6/t3v39bP+/STaQa2nVEq6qK1dXVmJycjEePHsXNmzf7cTk7OxuTk5OxsLDwwbH
v4zqiAPDxfLMSX8JRem1oIfolCFEAgOPti1/QHgAAjkqIAgCQQogCAJBCiAIAkEKIAgCQQogCAJBCiAIAkEKIAgCQQogCAJBCiAIAkEKIAgCQQogCAJBCiAIAkEKIAgCQQogCAJBCiAIAkEKIAgCQQogCAJBCiAIAkEKIAgCQ4uvsCRxFr9eLiIidnZ3kmQAAMMhep+112/ucqBB98eJFRESMjY0lzwQAgPd58eJF1Ov1946p9Q6Tq8fEq1ev4t///ndcuHAharXaF/mdOzs7MTY2Fs+ePYvR0dEv8jvJ5T0/e7znZ4/3/Gzyvn8ZvV4vXrx4EX//+9/j3Ln3HwV6olZEz507F//4xz9Sfvfo6Ki/tGeM9/zs8Z6fPd7zs8n7/vl9aCV0j5OVAABIIUQBAEghRD9gZGQkbt26FSMjI9lT4Qvxnp893vOzx3t+Nnnfj58TdbISAACnhxVRAABSCFEAAFIIUQAAUpyo64h+SVVVRbvdjqIooqqqmJubi0ajkT0tPqOyLKPT6URExKNHj+Lu3bve8zNmcXExbt686X0/AzqdTlRVFUVRRETE1NRU8oz4nKqqik6nE81mM6qqipmZmf57Ty4nKx1gfHw8Hj9+HBGv/wIvLi7G2tpa8qz4nFZWVmJhYaH/87179/p/Bzj9yrKM8fHx2NraEqKnXKfTibW1tVhdXY2qqmJ6ejo2Njayp8Vn9Obne0TE/Px8rK6uJs6IPXbND1BV1b77RVH0V8o4ncqyjKWlpf79mZmZKMvynb8LnF5vro5xus3Pz8fy8nJEvP58X19fT54Rn9u9e/eyp8ABhOgAe8v3b2o2m1GWZdKM+NxarVbcvXu3f7/b7UZEvPP3gNOp3W7HzMxM9jT4Aqqqis3NzWg0GlGWZXS7Xf8BcgY0m80YHx/v76Kfnp7OnhL/R4gOsBchb9vc3PyyE+GLejNE7t27F1NTU3bRngHdbtf7fIaUZRnNZrN/DsCdO3ei3W5nT4vPbO/QuitXrsTa2pr/8DxGnKx0BAcFKqdLt9uNdrvt+NAz4v79+zE3N5c9Db6Qzc3NqKqq/x+ac3NzcfHixXC6xOnW6XRieXk5qqqK+fn5iAjHiB4TVkQHaDQa76x+7u3K4fRbXFyM9fV17/cZ0Ol04scff8yeBl9QURTRaDT6///e+1+HXp1eVVXFo0ePYmpqKubm5mJjYyPu37/vHIBjworoAFNTUwP/S2liYiJhNnxJKysrsbi4GEVR9FfABenpdv/+/f7PVVXF0tJS/Nd//Ve0Wq3EWfG5OB707CnLMiYnJ/v3i6KImzdv2st5TFgRHeDtD6qqqmJiYkKQnHLtdjtarVY/Qu/fv+89P+X2Vkj2bhGvz6gWoadXURQxMTHRj5C9qyV4z0+vVqsVjx492rftf/7nf7znx4TriB6gqqpYXV2NycnJePTokYtcn3JVVcWVK1f2bWs0GrG1tZU0I76kbrcbd+7cicXFxZibmxOjp1y3243FxcX+9aL39oJwenU6nSjLsv/v8ampKe/5MSFEAQBIYdc8AAAphCgAACmEKAAAKYQoAAAphCgAACmEKAAAKYQoAAAphCgAACmEKAAAKYQoAAAphCgAACn+F8Y/4jGG/xFaAAAAAElFTkSuQmCC", @@ -201,14 +209,14 @@ "output_type": "stream", "text": [ "\n", - "Transform configuration for pointcloud2simplicial/alpha_complex_lifting:\n", + "Transform configuration for pointcloud2simplicial/feature_rips_lifting:\n", "\n", "{'transform_type': 'lifting',\n", - " 'transform_name': 'AlphaComplexLifting',\n", + " 'transform_name': 
'FeatureRipsComplexLifting',\n", " 'complex_dim': 3,\n", - " 'alpha': 25.0,\n", - " 'preserve_edge_attr': False,\n", - " 'signed': True,\n", + " 'feature_percent': 0.2,\n", + " 'max_edge_length': 10.0,\n", + " 'sparse': None,\n", " 'feature_lifting': 'ProjectionSum'}\n" ] } @@ -217,7 +225,7 @@ "# Define transformation type and id\n", "transform_type = \"liftings\"\n", "# If the transform is a topological lifting, it should include both the type of the lifting and the identifier\n", - "transform_id = \"pointcloud2simplicial/feature_rips_complex_lifting\"\n", + "transform_id = \"pointcloud2simplicial/feature_rips_lifting\"\n", "\n", "# Read yaml file\n", "transform_config = {\n", @@ -242,14 +250,23 @@ "name": "stdout", "output_type": "stream", "text": [ - "Transform parameters are the same, using existing data_dir: /Users/tlong/Documents/code/challenge-icml-2024/datasets/pointcloud/toy_dataset/manual_points/lifting/3217688758\n", "\n", "Dataset only contains 1 sample:\n" ] }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Processing...\n", + "/Users/tlong/anaconda3/envs/topox/lib/python3.11/site-packages/scipy/sparse/_index.py:143: SparseEfficiencyWarning: Changing the sparsity structure of a csr_matrix is expensive. lil_matrix is more efficient.\n", + " self._set_arrayXarray(i, j, x)\n", + "Done!\n" + ] + }, { "data": { - "image/png": "", + "image/png": "", "text/plain": [ "
" ] @@ -263,10 +280,12 @@ "text": [ " - The complex has 7 0-cells.\n", " - The 0-cells have features dimension 2\n", - " - The complex has 9 1-cells.\n", + " - The complex has 13 1-cells.\n", " - The 1-cells have features dimension 2\n", - " - The complex has 3 2-cells.\n", + " - The complex has 9 2-cells.\n", " - The 2-cells have features dimension 2\n", + " - The complex has 2 3-cells.\n", + " - The 3-cells have features dimension 2\n", "\n" ] } @@ -276,6 +295,253 @@ "describe_data(lifted_dataset)" ] }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Illustrating how the Feature-based Rips Complex differs from the standard Rips complex\n", + "\n", + "Let's see what the effect of the `feature_percent` parameter is. The Rips complex is generated based on **pairwise distances** between points. Usually these are calculated using the Euclidean distance between point positions; however, this implementation also lets us linearly combine this with the distance between **feature vectors**." + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n", + "Transform configuration for pointcloud2simplicial/feature_rips_lifting:\n", + "\n", + "{'transform_type': 'lifting',\n", + " 'transform_name': 'FeatureRipsComplexLifting',\n", + " 'complex_dim': 3,\n", + " 'feature_percent': 0.2,\n", + " 'max_edge_length': 10.0,\n", + " 'sparse': None,\n", + " 'feature_lifting': 'ProjectionSum'}\n" + ] + } + ], + "source": [ + "transform_config = {\n", + " \"lifting\": load_transform_config(transform_type, transform_id)\n", + " # other transforms (e.g. data manipulations, feature liftings) can be added here\n", + "}" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "With `feature_percent=0.0`, we obtain the 'standard' Rips Complex, which is based only on the distances between points." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n", + "Dataset only contains 1 sample:\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Processing...\n", + "Done!\n" + ] + }, + { + "data": { + "image/png": "", + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + " - The complex has 7 0-cells.\n", + " - The 0-cells have features dimension 2\n", + " - The complex has 10 1-cells.\n", + " - The 1-cells have features dimension 2\n", + " - The complex has 5 2-cells.\n", + " - The 2-cells have features dimension 2\n", + " - The complex has 1 3-cells.\n", + " - The 3-cells have features dimension 2\n", + "\n" + ] + } + ], + "source": [ + "transform_config[\"lifting\"][\"feature_percent\"] = 0.0\n", + "lifted_dataset = PreProcessor(dataset, transform_config, loader.data_dir)\n", + "describe_data(lifted_dataset)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "With `feature_percent=1.0`, the distances between points are **entirely** based on the features. There are only 3 unique feature vectors in this dataset - 3 vertices with feature vector `[0.0, 0.0]`, 3 with `[1.0, 1.0]`, and 1 with `[1.0, 0.0]`. This is reflected in the geometry of the resulting complex - there are two simplexes of 4 vertices with a 'bridge' corresponding to the single vertex with an intermediate feature vector.\n", + "\n", + "Note that we also set `max_edge_length=1.0` since the scale of the feature vectors is much smaller than the position vectors. When using this lifting, one should be careful to normalize or standardize both features and positions to ensure that the scales match appropriately. " + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Transform parameters are the same, using existing data_dir: /Users/tlong/Documents/code/challenge-icml-2024/datasets/pointcloud/toy_dataset/manual_points/lifting/1869000105\n", + "\n", + "Dataset only contains 1 sample:\n" + ] + }, + { + "data": { + "image/png": "", + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + " - The complex has 7 0-cells.\n", + " - The 0-cells have features dimension 2\n", + " - The complex has 12 1-cells.\n", + " - The 1-cells have features dimension 2\n", + " - The complex has 8 2-cells.\n", + " - The 2-cells have features dimension 2\n", + " - The complex has 2 3-cells.\n", + " - The 3-cells have features dimension 2\n", + "\n" + ] + } + ], + "source": [ + "transform_config[\"lifting\"][\"feature_percent\"] = 1.0\n", + "transform_config[\"lifting\"][\"max_edge_length\"] = 1.0\n", + "lifted_dataset = PreProcessor(dataset, transform_config, loader.data_dir)\n", + "describe_data(lifted_dataset)" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "tensor(indices=tensor([[0, 1, 2, 3, 4, 5, 6, 7],\n", + " [0, 0, 0, 0, 1, 1, 1, 1]]),\n", + " values=tensor([1., 1., 1., 1., 1., 1., 1., 1.]),\n", + " size=(8, 2), nnz=8, layout=torch.sparse_coo)" + ] + }, + "execution_count": 9, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "lifted_dataset.get(0).incidence_3" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Finally we test `feature_percent=0.2`, an intermediate value where the distances are mostly determined by the positions, with a small influence from the features. We see that the geometry is now different from both of the other two cases." + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n", + "Dataset only contains 1 sample:\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Processing...\n", + "Done!\n" + ] + }, + { + "data": { + "image/png": "", + "text/plain": [ + "
" ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + " - The complex has 7 0-cells.\n", + " - The 0-cells have features dimension 2\n", + " - The complex has 13 1-cells.\n", + " - The 1-cells have features dimension 2\n", + " - The complex has 9 2-cells.\n", + " - The 2-cells have features dimension 2\n", + " - The complex has 2 3-cells.\n", + " - The 3-cells have features dimension 2\n", + "\n" + ] + } + ], + "source": [ + "transform_config[\"lifting\"][\"feature_percent\"] = 0.2\n", + "transform_config[\"lifting\"][\"max_edge_length\"] = 10\n", + "lifted_dataset = PreProcessor(dataset, transform_config, loader.data_dir)\n", + "describe_data(lifted_dataset)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "When might feature-based Rips Complexes be useful? Suppose we have a point cloud dataset that actually consists of samples from multiple **independent objects**. For example, disjoint spheres. If two of the spheres are very close to one another, the standard Rips complex may generate a topology which incorrectly connects these two spheres. If a feature contribution is added, these spheres may have quite different feature vectors allowing the 'correct' geometry to be inferred.\n", + "\n", + "Another use case is to generate more expressive message-passing topologies. One might want to have two message-passing networks, one which operates over the usual Euclidean metric topology, while the other uses the feature-based topology. This can allow for more complex interactions to be embedded in the topology of the lifted point cloud." 
+ ] + }, { "cell_type": "markdown", "metadata": {}, @@ -292,7 +558,7 @@ }, { "cell_type": "code", - "execution_count": 6, + "execution_count": 11, "metadata": {}, "outputs": [ { @@ -323,7 +589,7 @@ }, { "cell_type": "code", - "execution_count": 7, + "execution_count": 12, "metadata": {}, "outputs": [], "source": [ From 46d79cee11f686a8d0c71a4821532972646aa71e Mon Sep 17 00:00:00 2001 From: Theo Long Date: Sat, 6 Jul 2024 19:41:33 -0400 Subject: [PATCH 13/15] add testing --- .../test_feature_rips_complex_lifting.py | 192 ++++++++++++++---- 1 file changed, 157 insertions(+), 35 deletions(-) diff --git a/test/transforms/liftings/pointcloud2simplicial/test_feature_rips_complex_lifting.py b/test/transforms/liftings/pointcloud2simplicial/test_feature_rips_complex_lifting.py index 5ef8ec4d..85ecb134 100644 --- a/test/transforms/liftings/pointcloud2simplicial/test_feature_rips_complex_lifting.py +++ b/test/transforms/liftings/pointcloud2simplicial/test_feature_rips_complex_lifting.py @@ -12,54 +12,176 @@ class TestFeatureRipsComplexLifting: """Test the FeatureRipsComplexLifting class.""" def setup_method(self): - # Load the graph + # Load the point cloud self.data = load_manual_points() - # Initialise the SimplicialCliqueLifting class - self.lifting = FeatureRipsComplexLifting(complex_dim=3, alpha=25.0) + # Initialise the FeatureRipsLifting class + self.position_lifting = FeatureRipsComplexLifting( + complex_dim=3, feature_percent=0.0, max_edge_length=10.0 + ) + self.feature_lifting = FeatureRipsComplexLifting( + complex_dim=3, feature_percent=1.0, max_edge_length=1.0 + ) + self.mixed_lifting = FeatureRipsComplexLifting( + complex_dim=3, feature_percent=0.2, max_edge_length=10.0 + ) def test_lift_topology(self): """Test the lift_topology method.""" # Test the lift_topology method - lifted_data = self.lifting.forward(self.data.clone()) - - expected_incidence_1 = torch.tensor( - [ - [1.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], - [1.0, 0.0, 1.0, 1.0, 0.0, 0.0, 0.0, 0.0, 
0.0], - [0.0, 1.0, 1.0, 0.0, 1.0, 1.0, 0.0, 0.0, 0.0], - [0.0, 0.0, 0.0, 1.0, 1.0, 0.0, 0.0, 0.0, 0.0], - [0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 0.0], - [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 1.0], - [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 1.0], - ] + lifted_data = self.position_lifting.forward(self.data.clone()) + + expected_incidences = ( + torch.tensor( + [ + [1.0, 1.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + [1.0, 0.0, 0.0, 1.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 1.0, 0.0, 1.0, 0.0, 1.0, 1.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 1.0, 0.0, 1.0, 1.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 1.0], + [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 1.0], + ] + ), + torch.tensor( + [ + [1.0, 1.0, 0.0, 0.0, 0.0], + [1.0, 0.0, 1.0, 0.0, 0.0], + [0.0, 1.0, 1.0, 0.0, 0.0], + [1.0, 0.0, 0.0, 1.0, 0.0], + [0.0, 1.0, 0.0, 1.0, 0.0], + [0.0, 0.0, 1.0, 1.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 1.0], + [0.0, 0.0, 0.0, 0.0, 1.0], + [0.0, 0.0, 0.0, 0.0, 1.0], + ] + ), + torch.tensor([[1.0], [1.0], [1.0], [1.0], [0.0]]), ) assert ( - expected_incidence_1 == lifted_data.incidence_1.to_dense() - ).all(), "Something is wrong with incidence_1 (nodes to edges)." - - expected_incidence_2 = torch.tensor( - [ - [1.0, 0.0, 0.0], - [1.0, 0.0, 0.0], - [1.0, 1.0, 0.0], - [0.0, 1.0, 0.0], - [0.0, 1.0, 0.0], - [0.0, 0.0, 0.0], - [0.0, 0.0, 1.0], - [0.0, 0.0, 1.0], - [0.0, 0.0, 1.0], - ] + expected_incidences[0] == lifted_data.incidence_1.to_dense() + ).all(), "Something is wrong with incidence_1 (nodes to edges) for feature_percent=0.0." + + assert ( + abs(expected_incidences[1]) == lifted_data.incidence_2.to_dense() + ).all(), "Something is wrong with incidence_2 (edges to triangles) for feature_percent=0.0." + + assert ( + abs(expected_incidences[2]) == lifted_data.incidence_3.to_dense() + ).all(), "Something is wrong with incidence_3 (triangles to tetrahedrons) for feature_percent=0.0." 
+ + expected_incidences = ( + torch.tensor( + [ + [1.0, 1.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + [1.0, 0.0, 0.0, 1.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 1.0, 0.0, 1.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 1.0, 0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 1.0, 1.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 1.0, 0.0, 1.0], + [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 1.0, 1.0], + ] + ), + torch.tensor( + [ + [1.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + [1.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 1.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0], + [1.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 1.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 1.0, 1.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 1.0, 1.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 1.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 1.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 1.0], + [0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 1.0], + [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 1.0], + ] + ), + torch.tensor( + [ + [1.0, 0.0], + [1.0, 0.0], + [1.0, 0.0], + [1.0, 0.0], + [0.0, 1.0], + [0.0, 1.0], + [0.0, 1.0], + [0.0, 1.0], + ] + ), ) + lifted_data = self.feature_lifting.forward(self.data.clone()) assert ( - abs(expected_incidence_2) == lifted_data.incidence_2.to_dense() - ).all(), "Something is wrong with incidence_2 (edges to triangles)." + expected_incidences[0] == lifted_data.incidence_1.to_dense() + ).all(), "Something is wrong with incidence_1 (nodes to edges) for feature_percent=1.0." - expected_incidence_3 = torch.tensor([]) + assert ( + abs(expected_incidences[1]) == lifted_data.incidence_2.to_dense() + ).all(), "Something is wrong with incidence_2 (edges to triangles) for feature_percent=1.0." + + assert ( + abs(expected_incidences[2]) == lifted_data.incidence_3.to_dense() + ).all(), "Something is wrong with incidence_3 (triangles to tetrahedrons) for feature_percent=1.0." 
+ + expected_incidences = ( + torch.tensor( + [ + [1.0, 1.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + [1.0, 0.0, 0.0, 1.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 1.0, 0.0, 1.0, 0.0, 1.0, 1.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 1.0, 0.0, 1.0, 1.0, 0.0, 0.0, 1.0, 1.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 1.0, 0.0, 1.0, 1.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 1.0], + [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 1.0, 0.0, 1.0, 1.0], + ] + ), + torch.tensor( + [ + [1.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + [1.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 1.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + [1.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 1.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 1.0, 1.0, 1.0, 1.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 1.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 1.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 1.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 1.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0], + [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 1.0, 1.0], + [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0], + ] + ), + torch.tensor( + [ + [1.0, 0.0], + [1.0, 0.0], + [1.0, 0.0], + [1.0, 0.0], + [0.0, 1.0], + [0.0, 1.0], + [0.0, 1.0], + [0.0, 1.0], + [0.0, 0.0], + ] + ), + ) + + lifted_data = self.mixed_lifting.forward(self.data.clone()) + + assert ( + expected_incidences[0] == lifted_data.incidence_1.to_dense() + ).all(), "Something is wrong with incidence_1 (nodes to edges) for feature_percent=0.2 ." + + assert ( + abs(expected_incidences[1]) == lifted_data.incidence_2.to_dense() + ).all(), "Something is wrong with incidence_2 (edges to triangles) for feature_percent=0.2 ." assert ( - abs(expected_incidence_3) == lifted_data.incidence_3.to_dense() - ).all(), "Something is wrong with incidence_3 (triangles to tetrahedrons)." 
+ abs(expected_incidences[2]) == lifted_data.incidence_3.to_dense() + ).all(), "Something is wrong with incidence_3 (triangles to tetrahedrons) for feature_percent=0.2 ." From 83121ceabc4a9ed441fc2ab25ea0b801181aa150 Mon Sep 17 00:00:00 2001 From: Theo Long Date: Sat, 6 Jul 2024 20:54:03 -0400 Subject: [PATCH 14/15] update function (linting + docstring) --- modules/utils/utils.py | 27 +++++++++++++++++---------- 1 file changed, 17 insertions(+), 10 deletions(-) diff --git a/modules/utils/utils.py b/modules/utils/utils.py index 1d17e17b..be012115 100644 --- a/modules/utils/utils.py +++ b/modules/utils/utils.py @@ -573,19 +573,26 @@ def calculate_pairwise_differences(x: torch.Tensor): return x_expanded_1 - x_expanded_2 -def add_epsilon_to_zeros(tensor, epsilon=1e-8): - """Add a small epsilon value to off-diagonal elements which are zero. This is useful for using the gudhi library since it treats pairwise distances which are 0 as the same point.""" +def add_epsilon_to_zeros(t: torch.Tensor, epsilon=1e-8): + """Add a small epsilon value to off-diagonal elements which are zero. This is useful for using the gudhi library since it treats pairwise distances which are 0 as the same point. + + Parameters + ---------- + t : torch.Tensor + 2-dimensional tensor. + + Returns + ------- + torch.Tensor + Tensor with non-diagonal elements that are zero shifted slightly by a small epsilon. 
+ """ # Create a mask for non-diagonal elements - non_diagonal_mask = ~torch.eye( - tensor.shape[0], dtype=torch.bool, device=tensor.device - ) + non_diagonal_mask = ~torch.eye(t.shape[0], dtype=torch.bool, device=t.device) # Create a tensor with epsilon values where the original tensor is 0 - epsilon_tensor = torch.zeros_like(tensor) - epsilon_tensor[non_diagonal_mask & (tensor == 0)] = epsilon + epsilon_tensor = torch.zeros_like(t) + epsilon_tensor[non_diagonal_mask & (t == 0)] = epsilon # Add the epsilon tensor to the original tensor - result = tensor + epsilon_tensor - - return result + return t + epsilon_tensor From 617993c8956f23363d078e01f6a7fd0e35e0feaf Mon Sep 17 00:00:00 2001 From: Theo Long Date: Sat, 13 Jul 2024 01:13:01 -0400 Subject: [PATCH 15/15] add test --- .../test_feature_rips_complex_lifting.py | 22 +++++++++++++++++++ 1 file changed, 22 insertions(+) diff --git a/test/transforms/liftings/pointcloud2simplicial/test_feature_rips_complex_lifting.py b/test/transforms/liftings/pointcloud2simplicial/test_feature_rips_complex_lifting.py index 85ecb134..3c037c17 100644 --- a/test/transforms/liftings/pointcloud2simplicial/test_feature_rips_complex_lifting.py +++ b/test/transforms/liftings/pointcloud2simplicial/test_feature_rips_complex_lifting.py @@ -1,6 +1,7 @@ """Test the feature-based Rips complex lifting.""" import torch +import torch_geometric from modules.data.utils.utils import load_manual_points from modules.transforms.liftings.pointcloud2simplicial.feature_rips_complex_lifting import ( @@ -25,6 +26,27 @@ def setup_method(self): self.mixed_lifting = FeatureRipsComplexLifting( complex_dim=3, feature_percent=0.2, max_edge_length=10.0 ) + self.no_epsilon_lifting = FeatureRipsComplexLifting( + complex_dim=3, + feature_percent=0.5, + max_edge_length=10.0, + epsilon=0.0, + ) + + def test_generate_distance_matrix(self): + # Test the generate_distance_matrix method + data = torch_geometric.data.Data( + pos=torch.tensor([[-1], [1], [2]]).float(), + 
x=torch.tensor([[2], [4], [6]]).float(), + ) + expected_pairwise_distances = torch.tensor( + [[0, 2, 3.5], [2, 0, 1.5], [3.5, 1.5, 0]] + ) + + assert ( + self.no_epsilon_lifting.generate_distance_matrix(data) + == expected_pairwise_distances + ).all(), "generate_distance_matrix not working as expected" def test_lift_topology(self): """Test the lift_topology method."""