From 86ef67fbc25318353699bc0eaa9d4995efdb5b4e Mon Sep 17 00:00:00 2001
From: Charlie Windolf
Date: Mon, 15 Jan 2024 13:24:59 -0500
Subject: [PATCH] Add grid_snap parcellation

---
 src/dartsort/cluster/cluster_util.py | 28 ++++++++++++++++++++++++--
 src/dartsort/cluster/initial.py      | 30 +++++++++++++++++++++++-----
 src/dartsort/config.py               |  8 +++++++-
 3 files changed, 58 insertions(+), 8 deletions(-)

diff --git a/src/dartsort/cluster/cluster_util.py b/src/dartsort/cluster/cluster_util.py
index e8cd76fd..12114f1e 100644
--- a/src/dartsort/cluster/cluster_util.py
+++ b/src/dartsort/cluster/cluster_util.py
@@ -23,6 +23,30 @@ def closest_registered_channels(times_seconds, x, z_abs, geom, motion_est=None):
     return reg_channels
 
 
+def grid_snap(times_seconds, x, z_abs, geom, grid_dx=15, grid_dz=15, motion_est=None):
+    if motion_est is None:
+        motion_est = IdentityMotionEstimate()
+    z_reg = motion_est.correct_s(times_seconds, z_abs)
+    reg_pos = np.c_[x, z_reg]
+
+    # make a grid inside the registered geom bounding box
+    registered_geom = drift_util.registered_geometry(geom, motion_est)
+    min_x, max_x = registered_geom[:, 0].min(), registered_geom[:, 0].max()
+    min_z, max_z = registered_geom[:, 1].min(), registered_geom[:, 1].max()
+    grid_x = np.arange(min_x, max_x, grid_dx)
+    grid_x += (min_x + max_x) / 2 - grid_x.mean()
+    grid_z = np.arange(min_z, max_z, grid_dz)
+    grid_z += (min_z + max_z) / 2 - grid_z.mean()
+    grid_xx, grid_zz = np.meshgrid(grid_x, grid_z, indexing="ij")
+    grid = np.c_[grid_xx.ravel(), grid_zz.ravel()]
+
+    # snap to closest grid point
+    registered_kdt = KDTree(grid)
+    distances, reg_channels = registered_kdt.query(reg_pos)
+
+    return reg_channels
+
+
 def hdbscan_clustering(
     recording,
     times_seconds,
@@ -38,7 +62,7 @@
     scales=(1, 1, 50),
     log_c=5,
     recursive=True,
-    do_remove_dups=True,
+    remove_duplicates=True,
     frames_dedup=12,
     frame_dedup_cluster=20,
 ):
@@ -60,7 +84,7 @@
     )
     clusterer.fit(features)
 
-    if do_remove_dups:
+    if remove_duplicates:
         (
             clusterer,
             duplicate_indices,
diff --git a/src/dartsort/cluster/initial.py b/src/dartsort/cluster/initial.py
index 3352d19d..897c99d8 100644
--- a/src/dartsort/cluster/initial.py
+++ b/src/dartsort/cluster/initial.py
@@ -57,6 +57,16 @@ def cluster_chunk(
         labels[in_chunk] = cluster_util.closest_registered_channels(
             times_s[in_chunk], xyza[in_chunk, 0], xyza[in_chunk, 2], geom, motion_est
         )
+    elif clustering_config.cluster_strategy == "grid_snap":
+        labels[in_chunk] = cluster_util.grid_snap(
+            times_s[in_chunk],
+            xyza[in_chunk, 0],
+            xyza[in_chunk, 2],
+            geom,
+            grid_dx=clustering_config.grid_dx,
+            grid_dz=clustering_config.grid_dz,
+            motion_est=motion_est,
+        )
     elif clustering_config.cluster_strategy == "hdbscan":
         labels[in_chunk] = cluster_util.hdbscan_clustering(
             times_s[in_chunk],
@@ -69,6 +79,8 @@
             min_samples=clustering_config.min_samples,
             cluster_selection_epsilon=clustering_config.cluster_selection_epsilon,
             scales=clustering_config.feature_scales,
+            recursive=clustering_config.recursive,
+            remove_duplicates=clustering_config.remove_duplicates,
         )
     else:
         assert False
@@ -100,17 +112,25 @@
     """
     chunk_samples = recording.sampling_frequency * clustering_config.chunk_size_s
 
-    # determine number of chunks, and we'll count the extra if it's at least 2/3
-    n_chunks = recording.get_num_samples() / chunk_samples
-    n_chunks = np.floor(n_chunks) + (n_chunks - np.floor(n_chunks) > 0.66)
-    n_chunks = int(max(1, n_chunks))
+    # determine number of chunks
+    # if we're not ensembling, that's 1 chunk.
+    if (
+        not clustering_config.ensemble_strategy
+        or clustering_config.ensemble_strategy.lower() == "none"
+    ):
+        n_chunks = 1
+    else:
+        n_chunks = recording.get_num_samples() / chunk_samples
+        # we'll count the remainder as a chunk if it's at least 2/3 of one
+        n_chunks = np.floor(n_chunks) + (n_chunks - np.floor(n_chunks) > 0.66)
+        n_chunks = int(max(1, n_chunks))
 
     # evenly divide the recording into chunks
     assert recording.get_num_segments() == 1
     start_time_s, end_time_s = recording._recording_segments[0].sample_index_to_time(
         [0, recording.get_num_samples() - 1]
     )
-    chunk_times_s = np.linspace(start_time_s, end_time_s, num=n_chunks)
+    chunk_times_s = np.linspace(start_time_s, end_time_s, num=n_chunks + 1)
     chunk_time_ranges_s = list(zip(chunk_times_s[:-1], chunk_times_s[1:]))
 
     # cluster each chunk. can be parallelized in the future.
diff --git a/src/dartsort/config.py b/src/dartsort/config.py
index cf478a6d..7a672cd5 100644
--- a/src/dartsort/config.py
+++ b/src/dartsort/config.py
@@ -181,10 +181,16 @@ class ClusteringConfig:
     cluster_selection_epsilon: int = 1
     feature_scales: Tuple[float] = (1.0, 1.0, 50.0)
     log_c: float = 5.0
+    recursive: bool = True
+    remove_duplicates: bool = True
+
+    # grid snap parameters
+    grid_dx: float = 15.0
+    grid_dz: float = 15.0
 
     # -- ensembling
     ensemble_strategy: Optional[str] = "forward_backward"
-    chunk_size_s: int = 150
+    chunk_size_s: float = 150.0
 
 
 default_featurization_config = FeaturizationConfig()
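-- 
Reviewer note, below the signature marker so git am ignores it: a minimal
usage sketch for the new strategy. Only the ClusteringConfig fields shown
come from this patch (cluster_strategy already existed in the class); how
the config object reaches cluster_chunks follows the existing call sites
and is assumed here.

    from dartsort.config import ClusteringConfig

    # Hypothetical example: parcellate spikes on a 15 x 15 grid (same units
    # as the probe geometry, typically um) over the registered probe. With
    # ensemble_strategy=None, cluster_chunks now places the whole recording
    # in a single chunk, per this patch's change to n_chunks.
    config = ClusteringConfig(
        cluster_strategy="grid_snap",
        grid_dx=15.0,  # grid pitch along x
        grid_dz=15.0,  # grid pitch along registered depth z
        ensemble_strategy=None,
    )

Design-wise, grid_snap is a gridded analogue of closest_registered_channels:
spike localizations are snapped to a uniform grid centered in the registered
geometry's bounding box, so parcel resolution is set by grid_dx/grid_dz
rather than by the probe's channel spacing.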