diff --git a/src/dartsort/cluster/ensemble_utils.py b/src/dartsort/cluster/ensemble_utils.py
index 3e36fa74..e34f7b13 100644
--- a/src/dartsort/cluster/ensemble_utils.py
+++ b/src/dartsort/cluster/ensemble_utils.py
@@ -10,6 +10,7 @@ def forward_backward(
     feature_scales=(1, 1, 50),
     adaptive_feature_scales=False,
     motion_est=None,
+    verbose=True,
 ):
     """
     Ensemble over HDBSCAN clustering
@@ -19,21 +20,23 @@ def forward_backward(
         return chunk_sortings[0]
 
     times_seconds = chunk_sortings[0].times_seconds
-    times_samples = chunk_sortings[0].times_samples
+
     min_time_s = chunk_time_ranges_s[0][0]
 
     idx_all_chunks = [get_indices_in_chunk(times_seconds, chunk_range) for chunk_range in chunk_time_ranges_s]
 
     # put all labels into one array
     # TODO: this does not allow for overlapping chunks.
-    labels_all = np.full_like(times_samples, -1)
+    labels_all = np.full_like(times_seconds, -1)
     for ix, sorting in zip(idx_all_chunks, chunk_sortings):
-        assert labels_all[ix].max() < 0  # assert non-overlapping
-        labels_all[ix] = sorting.labels[ix]
+        if len(ix):
+            assert labels_all[ix].max() < 0  # assert non-overlapping
+            labels_all[ix] = sorting.labels[ix]
 
     # load features that we will need
     # needs to be all features here
     amps = chunk_sortings[0].denoised_ptp_amplitudes
     xyza = chunk_sortings[0].point_source_localizations
+
     x = xyza[:, 0]
     z_reg = xyza[:, 2]
@@ -44,7 +47,11 @@ def forward_backward(
     if motion_est is not None:
         z_reg = motion_est.correct_s(times_seconds, z_reg)
 
-    for k in trange(len(chunk_sortings) - 1, desc="Ensembling chunks"):
+    if verbose is True:
+        tbar = trange(len(chunk_sortings) - 1, desc="Ensembling chunks")
+    else:
+        tbar = range(len(chunk_sortings) - 1)
+    for k in tbar:
         # CHANGE THE 1 ---
         # idx_1 = np.flatnonzero(np.logical_and(times_seconds>=min_time_s, times_seconds
-        labels_2[labels_2 > -1] += unit_label_shift
         units_1 = np.unique(labels_1)
         units_1 = units_1[units_1 > -1]
         units_2 = np.unique(labels_2)
         units_2 = units_2[units_2 > -1]
-        # FORWARD PASS
-
-        dist_matrix = np.zeros((units_1.shape[0], units_2.shape[0]))
-
-        # Speed up this code - this matrix can be sparse (only compute distance for "neighboring" units) - OK for now, still pretty fast
-        for i in range(units_1.shape[0]):
-            unit_1 = units_1[i]
-            for j in range(units_2.shape[0]):
-                unit_2 = units_2[j]
-                feat_1 = np.c_[
-                    np.median(x_1[labels_1 == unit_1]),
-                    np.median(z_1[labels_1 == unit_1]),
-                    np.median(amps_1[labels_1 == unit_1]),
-                ]
-                feat_2 = np.c_[
-                    np.median(x_2[labels_2 == unit_2]),
-                    np.median(z_2[labels_2 == unit_2]),
-                    np.median(amps_2[labels_2 == unit_2]),
-                ]
-                dist_matrix[i, j] = ((feat_1 - feat_2) ** 2).sum()
-
-        # find for chunk 2 units the closest units in chunk 1 and split chunk 1 units
-        dist_forward = dist_matrix.argmin(0)
-        units_, counts_ = np.unique(dist_forward, return_counts=True)
-
-        for unit_to_split in units_[counts_ > 1]:
-            units_to_match_to = (
-                np.flatnonzero(dist_forward == unit_to_split) + unit_label_shift
-            )
-            features_to_match_to = np.c_[
-                np.median(x_2[labels_2 == units_to_match_to[0]]),
-                np.median(z_2[labels_2 == units_to_match_to[0]]),
-                np.median(amps_2[labels_2 == units_to_match_to[0]]),
-            ]
-            for u in units_to_match_to[1:]:
-                features_to_match_to = np.concatenate(
-                    (
-                        features_to_match_to,
-                        np.c_[
-                            np.median(x_2[labels_2 == u]),
-                            np.median(z_2[labels_2 == u]),
-                            np.median(amps_2[labels_2 == u]),
-                        ],
-                    )
-                )
-            spikes_to_update = np.flatnonzero(labels_1 == unit_to_split)
-            x_s_to_update = x_1[spikes_to_update]
-            z_s_to_update = z_1[spikes_to_update]
-            amps_s_to_update = amps_1[spikes_to_update]
-            for j, s in enumerate(spikes_to_update):
-                # Don't update if new distance is too high?
-                feat_s = np.c_[
-                    x_s_to_update[j], z_s_to_update[j], amps_s_to_update[j]
-                ]
-                labels_1[s] = units_to_match_to[
-                    ((feat_s - features_to_match_to) ** 2).sum(1).argmin()
-                ]
-
-        # Relabel labels_1 and labels_2
-        for unit_to_relabel in units_:
-            if counts_[np.flatnonzero(units_ == unit_to_relabel)][0] == 1:
-                idx_to_relabel = np.flatnonzero(labels_1 == unit_to_relabel)
-                labels_1[idx_to_relabel] = units_2[dist_forward == unit_to_relabel]
-
-        # BACKWARD PASS
-
-        units_not_matched = np.unique(labels_1)
-        units_not_matched = units_not_matched[units_not_matched > -1]
-        units_not_matched = units_not_matched[units_not_matched < unit_label_shift]
-
-        if len(units_not_matched):
-            all_units_to_match_to = (
-                dist_matrix[units_not_matched].argmin(1) + unit_label_shift
-            )
-            for unit_to_split in np.unique(all_units_to_match_to):
-                units_to_match_to = np.concatenate(
-                    (
-                        units_not_matched[all_units_to_match_to == unit_to_split],
-                        [unit_to_split],
-                    )
+        if len(units_2) and len(units_1):
+            unit_label_shift = int(labels_1.max() + 1)
+            labels_2[labels_2 > -1] += unit_label_shift
+            units_2 += unit_label_shift
+
+            # FORWARD PASS
+            dist_matrix = np.zeros((units_1.shape[0], units_2.shape[0]))
+
+            # Speed up this code - this matrix can be sparse (only compute distance for "neighboring" units) - OK for now, still pretty fast
+            for i in range(units_1.shape[0]):
+                unit_1 = units_1[i]
+                for j in range(units_2.shape[0]):
+                    unit_2 = units_2[j]
+                    feat_1 = np.c_[
+                        np.median(x_1[labels_1 == unit_1]),
+                        np.median(z_1[labels_1 == unit_1]),
+                        np.median(amps_1[labels_1 == unit_1]),
+                    ]
+                    feat_2 = np.c_[
+                        np.median(x_2[labels_2 == unit_2]),
+                        np.median(z_2[labels_2 == unit_2]),
+                        np.median(amps_2[labels_2 == unit_2]),
+                    ]
+                    dist_matrix[i, j] = ((feat_1 - feat_2) ** 2).sum()
+
+            # find for chunk 2 units the closest units in chunk 1 and split chunk 1 units
+            dist_forward = dist_matrix.argmin(0)
+            units_, counts_ = np.unique(dist_forward, return_counts=True)
+
+            for unit_to_split in units_[counts_ > 1]:
+                units_to_match_to = (
+                    np.flatnonzero(dist_forward == unit_to_split) + unit_label_shift
                 )
-
                 features_to_match_to = np.c_[
-                    np.median(x_1[labels_1 == units_to_match_to[0]]),
-                    np.median(z_1[labels_1 == units_to_match_to[0]]),
-                    np.median(amps_1[labels_1 == units_to_match_to[0]]),
+                    np.median(x_2[labels_2 == units_to_match_to[0]]),
+                    np.median(z_2[labels_2 == units_to_match_to[0]]),
+                    np.median(amps_2[labels_2 == units_to_match_to[0]]),
                 ]
                 for u in units_to_match_to[1:]:
                     features_to_match_to = np.concatenate(
                         (
                             features_to_match_to,
                             np.c_[
-                                np.median(x_1[labels_1 == u]),
-                                np.median(z_1[labels_1 == u]),
-                                np.median(amps_1[labels_1 == u]),
+                                np.median(x_2[labels_2 == u]),
+                                np.median(z_2[labels_2 == u]),
+                                np.median(amps_2[labels_2 == u]),
                             ],
                         )
                     )
-                spikes_to_update = np.flatnonzero(labels_2 == unit_to_split)
-                x_s_to_update = x_2[spikes_to_update]
-                z_s_to_update = z_2[spikes_to_update]
-                amps_s_to_update = amps_2[spikes_to_update]
+                spikes_to_update = np.flatnonzero(labels_1 == unit_to_split)
+                x_s_to_update = x_1[spikes_to_update]
+                z_s_to_update = z_1[spikes_to_update]
+                amps_s_to_update = amps_1[spikes_to_update]
                 for j, s in enumerate(spikes_to_update):
+                    # Don't update if new distance is too high?
                     feat_s = np.c_[
                         x_s_to_update[j], z_s_to_update[j], amps_s_to_update[j]
                     ]
-                    labels_2[s] = units_to_match_to[
+                    labels_1[s] = units_to_match_to[
                         ((feat_s - features_to_match_to) ** 2).sum(1).argmin()
                     ]
-
-        # Do we need to "regularize" and make sure the distance intra units after merging is smaller than the distance inter units before merging
-        all_labels_1 = np.unique(labels_1)
-        all_labels_1 = all_labels_1[all_labels_1 > -1]
-
-        features_all_1 = np.c_[
-            np.median(x_1[labels_1 == all_labels_1[0]]),
-            np.median(z_1[labels_1 == all_labels_1[0]]),
-            np.median(amps_1[labels_1 == all_labels_1[0]]),
-        ]
-        for u in all_labels_1[1:]:
-            features_all_1 = np.concatenate(
-                (
-                    features_all_1,
-                    np.c_[
-                        np.median(x_1[labels_1 == u]),
-                        np.median(z_1[labels_1 == u]),
-                        np.median(amps_1[labels_1 == u]),
-                    ],
+
+            # Relabel labels_1 and labels_2
+            for unit_to_relabel in units_:
+                if counts_[np.flatnonzero(units_ == unit_to_relabel)][0] == 1:
+                    idx_to_relabel = np.flatnonzero(labels_1 == unit_to_relabel)
+                    labels_1[idx_to_relabel] = units_2[dist_forward == unit_to_relabel]
+
+            # BACKWARD PASS
+
+            units_not_matched = np.unique(labels_1)
+            units_not_matched = units_not_matched[units_not_matched > -1]
+            units_not_matched = units_not_matched[units_not_matched < unit_label_shift]
+
+            if len(units_not_matched):
+                all_units_to_match_to = (
+                    dist_matrix[units_not_matched].argmin(1) + unit_label_shift
                 )
+                for unit_to_split in np.unique(all_units_to_match_to):
+                    units_to_match_to = np.concatenate(
+                        (
+                            units_not_matched[all_units_to_match_to == unit_to_split],
+                            [unit_to_split],
+                        )
+                    )
+
+                    features_to_match_to = np.c_[
+                        np.median(x_1[labels_1 == units_to_match_to[0]]),
+                        np.median(z_1[labels_1 == units_to_match_to[0]]),
+                        np.median(amps_1[labels_1 == units_to_match_to[0]]),
+                    ]
+                    for u in units_to_match_to[1:]:
+                        features_to_match_to = np.concatenate(
+                            (
+                                features_to_match_to,
+                                np.c_[
+                                    np.median(x_1[labels_1 == u]),
+                                    np.median(z_1[labels_1 == u]),
+                                    np.median(amps_1[labels_1 == u]),
+                                ],
+                            )
+                        )
+                    spikes_to_update = np.flatnonzero(labels_2 == unit_to_split)
+                    x_s_to_update = x_2[spikes_to_update]
+                    z_s_to_update = z_2[spikes_to_update]
+                    amps_s_to_update = amps_2[spikes_to_update]
+                    for j, s in enumerate(spikes_to_update):
+                        feat_s = np.c_[
+                            x_s_to_update[j], z_s_to_update[j], amps_s_to_update[j]
+                        ]
+                        labels_2[s] = units_to_match_to[
+                            ((feat_s - features_to_match_to) ** 2).sum(1).argmin()
+                        ]
+
+            # Do we need to "regularize" and make sure the distance intra units after merging is smaller than the distance inter units before merging
+            # all_labels_1 = np.unique(labels_1)
+            # all_labels_1 = all_labels_1[all_labels_1 > -1]
+
+            # features_all_1 = np.c_[
+            #     np.median(x_1[labels_1 == all_labels_1[0]]),
+            #     np.median(z_1[labels_1 == all_labels_1[0]]),
+            #     np.median(amps_1[labels_1 == all_labels_1[0]]),
+            # ]
+            # for u in all_labels_1[1:]:
+            #     features_all_1 = np.concatenate(
+            #         (
+            #             features_all_1,
+            #             np.c_[
+            #                 np.median(x_1[labels_1 == u]),
+            #                 np.median(z_1[labels_1 == u]),
+            #                 np.median(amps_1[labels_1 == u]),
+            #             ],
+            #         )
+            #     )
+
+            # distance_inter = (
+            #     (features_all_1[:, :, None] - features_all_1.T[None]) ** 2
+            # ).sum(1)
+
+            labels_12 = np.concatenate((labels_1, labels_2))
+            _, labels_12[labels_12 > -1] = np.unique(
+                labels_12[labels_12 > -1], return_inverse=True
+            )  # Make contiguous
+            idx_all = np.flatnonzero(
+                times_seconds < min_time_s + chunk_time_ranges_s[k + 1][1]
             )
-
-        distance_inter = (
-            (features_all_1[:, :, None] - features_all_1.T[None]) ** 2
-        ).sum(1)
-
-        labels_12 = np.concatenate((labels_1, labels_2))
-        _, labels_12[labels_12 > -1] = np.unique(
-            labels_12[labels_12 > -1], return_inverse=True
-        )  # Make contiguous
-        idx_all = np.flatnonzero(
-            times_seconds < min_time_s + chunk_time_ranges_s[k + 1][1]
-        )
-        labels_all = -1 * np.ones(
-            times_seconds.shape[0]
-        )  # discard all spikes at the end for now
-        labels_all[idx_all] = labels_12.astype("int")
+            labels_all = -1 * np.ones(
+                times_seconds.shape[0]
+            )  # discard all spikes at the end for now
+            labels_all[idx_all] = labels_12.astype("int")
 
     return labels_all
diff --git a/src/dartsort/cluster/split.py b/src/dartsort/cluster/split.py
index ccd0b571..47956580 100644
--- a/src/dartsort/cluster/split.py
+++ b/src/dartsort/cluster/split.py
@@ -6,12 +6,15 @@
 import torch
 
 from dartsort.util import drift_util, waveform_util
 from dartsort.util.multiprocessing_util import get_pool
+from dartsort.util.data_util import DARTsortSorting
+
 from hdbscan import HDBSCAN
 from scipy.spatial.distance import pdist
 from sklearn.decomposition import PCA
 from tqdm.auto import tqdm
 
 from . import cluster_util, relocate
+from .ensemble_utils import forward_backward
 
 
 def split_clusters(
@@ -23,7 +26,6 @@ def split_clusters(
     split_big_kw=dict(dz=40, dx=48, min_size_split=50),
     show_progress=True,
     n_jobs=0,
-    max_size_wfs=None,
 ):
     """Parallel split step runner function
 
@@ -56,7 +58,7 @@ def split_clusters(
         initargs=(split_strategy, split_strategy_kwargs),
     ) as pool:
         iterator = jobs = [
-            pool.submit(_split_job, np.flatnonzero(labels == i), max_size_wfs)
+            pool.submit(_split_job, np.flatnonzero(labels == i))
             for i in labels_to_process
         ]
         if show_progress:
@@ -87,7 +89,7 @@ def split_clusters(
                     for i in new_units:
                         jobs.append(
                             pool.submit(
-                                _split_job, np.flatnonzero(labels == i), max_size_wfs
+                                _split_job, np.flatnonzero(labels == i)
                             )
                         )
                         if show_progress:
@@ -101,7 +103,7 @@ def split_clusters(
                 if (tall or wide) and len(idx) > split_big_kw["min_size_split"]:
                     jobs.append(
                         pool.submit(
-                            _split_job, np.flatnonzero(labels == i), max_size_wfs
+                            _split_job, np.flatnonzero(labels == i)
                         )
                     )
                     if show_progress:
@@ -160,6 +162,9 @@ class FeatureSplit(SplitStrategy):
     def __init__(
         self,
         peeling_hdf5_filename,
+        recording=None,
+        ensemble_over_chunks=False,
+        chunk_size_s=300,
         peeling_featurization_pt=None,
         motion_est=None,
         use_localization_features=True,
@@ -177,6 +182,7 @@ def __init__(
         use_ptp=True,
         amplitude_normalized=False,
         use_spread=False,
+        max_size_wfs=None,
         **dataset_name_kwargs,
     ):
         """Split clusters based on per-cluster PCA and localization features
@@ -221,6 +227,7 @@ def __init__(
         self.log_c = log_c
         self.rg = np.random.default_rng(random_state)
         self.reassign_outliers = reassign_outliers
+        self.max_size_wfs = max_size_wfs
 
         # hdbscan parameters
         self.min_cluster_size = min_cluster_size
@@ -232,6 +239,15 @@ def __init__(
         self.amplitude_normalized = amplitude_normalized
         self.use_spread = use_spread
 
+        #Check for ensembling
+        self.ensemble_over_chunks = ensemble_over_chunks
+        self.recording = recording
+        self.chunk_size_s = chunk_size_s
+        if self.ensemble_over_chunks:
+            assert self.use_localization_features, "Need to use loc features for ensembling over chunks"
+            assert self.recording is not None, "Need to input recording for ensembling over chunks"
+            assert self.chunk_size_s is not None, "Need to input chunk size for ensembling over chunks"
+
         # load up the required h5 datasets
         self.initialize_from_h5(
             peeling_hdf5_filename,
@@ -239,11 +255,101 @@ def __init__(
             **dataset_name_kwargs,
         )
 
-    def split_cluster(self, in_unit_all, max_size_wfs):
+    def split_cluster_with_strategy(self, in_unit):
+        if self.ensemble_over_chunks:
+            return self.split_cluster_chunks(in_unit)
+        else:
+            return self.split_cluster(in_unit)
+
+    def split_cluster_chunks(self,
+                             in_unit):
+
+        chunk_samples = self.recording.sampling_frequency * self.chunk_size_s
+
+        n_chunks = self.recording.get_num_samples() / chunk_samples
+        # we'll count the remainder as a chunk if it's at least 2/3 of one
+        n_chunks = np.floor(n_chunks) + (n_chunks - np.floor(n_chunks) > 0.66)
+        n_chunks = int(max(1, n_chunks))
+
+        # evenly divide the recording into chunks
+        assert self.recording.get_num_segments() == 1
+        start_time_s, end_time_s = self.recording._recording_segments[0].sample_index_to_time(
+            np.array([0, self.recording.get_num_samples() - 1])
+        )
+        chunk_times_s = np.linspace(start_time_s, end_time_s, num=n_chunks + 1)
+        chunk_time_ranges_s = list(zip(chunk_times_s[:-1], chunk_times_s[1:]))
+
+        chunks_indices_unit = [
+            np.flatnonzero(np.logical_and(
+                self.t_s[in_unit]>=chunk_range[0], self.t_s[in_unit] 1
+
+        return SplitResult(
+            is_split=is_split,
+            in_unit=in_unit,
+            new_labels=labels_ensembled,
+            x=self.localization_features[in_unit, 0],
+            z_reg=self.localization_features[in_unit, 2],
+        )
+
+
+    def split_cluster(self, in_unit_all):
         n_spikes = in_unit_all.size
-        if max_size_wfs and n_spikes > max_size_wfs:
+        if self.max_size_wfs and n_spikes > self.max_size_wfs:
             # TODO: max_size_wfs could be chosen automatically based on available memory and number of spikes
-            idx_subsample = np.random.choice(n_spikes, max_size_wfs, replace=False)
+            idx_subsample = np.random.choice(n_spikes, self.max_size_wfs, replace=False)
             idx_subsample.sort()
             in_unit = in_unit_all[idx_subsample]
             subsampling = True
@@ -270,7 +376,6 @@ def split_cluster(self, in_unit_all, max_size_wfs):
             loc_features = self.localization_features[in_unit]
             if self.relocated:
                 loc_features[kept, 2] = np.log(self.log_c + reloc_amplitudes)
-            features.append(loc_features)
             if self.rescale_all_features:
                 loc_features[kept, 1] *= np.median(
                     np.abs(loc_features[kept, 0] - np.median(loc_features[kept, 0]))
@@ -282,8 +387,9 @@ def split_cluster(self, in_unit_all, max_size_wfs):
                 ) / np.median(
                     np.abs(loc_features[kept, 2] - np.median(loc_features[kept, 2]))
                 )
-        if not self.use_ptp:
-            loc_features = loc_features[:, :2]
+            if not self.use_ptp:
+                loc_features = loc_features[:, :2]
+            features.append(loc_features)
 
         if self.n_pca_features > 0:
             enough_good_spikes, kept, pca_embeds = self.pca_features(
@@ -350,6 +456,8 @@ def split_cluster(self, in_unit_all, max_size_wfs):
                 new_labels[kept] = hdb_labels
             else:
                 new_labels[idx_subsample[kept]] = hdb_labels
+        else:
+            new_labels = np.full(n_spikes, 0) #Needed for split ensembling over chunks, cannot be none
 
         if self.use_localization_features:
             return SplitResult(
@@ -627,8 +735,8 @@ def _split_job_init(split_strategy_class_name, split_strategy_kwargs):
     _split_job_context = SplitJobContext(split_strategy(**split_strategy_kwargs))
 
 
-def _split_job(in_unit, max_size_wfs):
-    return _split_job_context.split_strategy.split_cluster(in_unit, max_size_wfs)
+def _split_job(in_unit):
+    return _split_job_context.split_strategy.split_cluster_with_strategy(in_unit)
 
 
 # -- h5 helper... slow reading...
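
Note on the chunking arithmetic added in split_cluster_chunks above: the recording is cut into equal time ranges, and a trailing remainder only counts as its own chunk when it is at least roughly two thirds of a full chunk. Below is a minimal standalone sketch of that arithmetic, using hypothetical numbers (30 kHz sampling, a 1000 s recording, and the new chunk_size_s default of 300); the variable names are illustrative, not the dartsort API.

import numpy as np

sampling_frequency = 30_000
num_samples = 30_000 * 1_000          # hypothetical 1000 s recording
chunk_size_s = 300

chunk_samples = sampling_frequency * chunk_size_s
n_chunks = num_samples / chunk_samples                      # 3.33... chunks
# the remainder becomes its own chunk only if it is at least ~2/3 of one
n_chunks = np.floor(n_chunks) + (n_chunks - np.floor(n_chunks) > 0.66)
n_chunks = int(max(1, n_chunks))                            # -> 3 chunks here

# evenly divide [start, end] into n_chunks equal time ranges
start_time_s, end_time_s = 0.0, num_samples / sampling_frequency
chunk_times_s = np.linspace(start_time_s, end_time_s, num=n_chunks + 1)
chunk_time_ranges_s = list(zip(chunk_times_s[:-1], chunk_times_s[1:]))
print(chunk_time_ranges_s)  # [(0.0, 333.3...), (333.3..., 666.6...), (666.6..., 1000.0)]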
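Note on the forward pass in forward_backward above: units from two adjacent chunks are compared by the squared distance between their per-unit median (x, z, amplitude) features, and each chunk-2 unit is matched to its nearest chunk-1 unit; a chunk-1 unit that attracts several chunk-2 units gets split. The toy sketch below illustrates just that matching step on made-up arrays; it is not the dartsort data model.

import numpy as np

rng = np.random.default_rng(0)

def unit_medians(x, z, amps, labels, unit):
    # median (x, z, amplitude) feature vector for one unit
    sel = labels == unit
    return np.array([np.median(x[sel]), np.median(z[sel]), np.median(amps[sel])])

# two fake chunks with 2 and 3 units respectively
x_1, z_1, amps_1 = rng.normal(size=(3, 200))
labels_1 = rng.integers(0, 2, size=200)
x_2, z_2, amps_2 = rng.normal(size=(3, 300))
labels_2 = rng.integers(0, 3, size=300)

units_1 = np.unique(labels_1)
units_2 = np.unique(labels_2)
dist_matrix = np.zeros((units_1.size, units_2.size))
for i, u1 in enumerate(units_1):
    f1 = unit_medians(x_1, z_1, amps_1, labels_1, u1)
    for j, u2 in enumerate(units_2):
        f2 = unit_medians(x_2, z_2, amps_2, labels_2, u2)
        dist_matrix[i, j] = ((f1 - f2) ** 2).sum()

# for each chunk-2 unit, the closest chunk-1 unit (argmin over rows)
dist_forward = dist_matrix.argmin(0)
print(dist_forward)  # e.g. [0, 1, 0]: chunk-1 unit 0 is matched twice, so it would be split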