diff --git a/MARBLE/preprocessing.py b/MARBLE/preprocessing.py index 5ab11957..d05844eb 100644 --- a/MARBLE/preprocessing.py +++ b/MARBLE/preprocessing.py @@ -24,6 +24,7 @@ def construct_dataset( local_gauges=False, seed=None, metric="euclidean", + eigendecomposition=True ): """Construct PyG dataset from node positions and features. @@ -44,6 +45,8 @@ def construct_dataset( embedding dimension is > 2 or dim embedding is not dim of manifold) seed: Specify for reproducibility in the furthest point sampling. The default is None, which means a random starting vertex. + metric: metric used to fit proximity graph + eigendecomposition: perform eigendecomposition (needed for diffusion). """ anchor = [torch.tensor(a).float() for a in utils.to_list(anchor)] @@ -112,6 +115,7 @@ def construct_dataset( local_gauges=local_gauges, n_geodesic_nb=k * frac_geodesic_nb, var_explained=var_explained, + eigendecomposition=eigendecomposition ) @@ -120,6 +124,7 @@ def _compute_geometric_objects( n_geodesic_nb=10, var_explained=0.9, local_gauges=False, + eigendecomposition=True ): """ Compute geometric objects used later: local gauges, Levi-Civita connections @@ -189,9 +194,10 @@ def _compute_geometric_objects( kernels = g.gradient_op(data.pos, data.edge_index, gauges) Lc = None - print("\n---- Computing eigendecomposition ... ", end="") - L = g.compute_eigendecomposition(L) - Lc = g.compute_eigendecomposition(Lc) + if eigendecomposition: + print("\n---- Computing eigendecomposition ... ", end="") + L = g.compute_eigendecomposition(L) + Lc = g.compute_eigendecomposition(Lc) data.kernels = [ utils.to_SparseTensor(K.coalesce().indices(), value=K.coalesce().values()) for K in kernels