diff --git a/src/dartsort/cluster/split.py b/src/dartsort/cluster/split.py
index df271d64..1236eaf6 100644
--- a/src/dartsort/cluster/split.py
+++ b/src/dartsort/cluster/split.py
@@ -7,7 +7,6 @@
 from dartsort.util import drift_util, waveform_util
 from dartsort.util.multiprocessing_util import get_pool
 from hdbscan import HDBSCAN
-from scipy.spatial import KDTree
 from scipy.spatial.distance import pdist
 from sklearn.decomposition import PCA
 from tqdm.auto import tqdm
@@ -86,17 +85,25 @@ def split_clusters(
             if recursive:
                 new_units = np.unique(new_untriaged_labels)
                 for i in new_units:
-                    jobs.append(pool.submit(_split_job, np.flatnonzero(labels == i), max_size_wfs))
+                    jobs.append(
+                        pool.submit(
+                            _split_job, np.flatnonzero(labels == i), max_size_wfs
+                        )
+                    )
                 if show_progress:
                     iterator.total += len(new_units)
             elif split_big:
                 new_units = np.unique(new_untriaged_labels)
                 for i in new_units:
                     idx = np.flatnonzero(new_untriaged_labels == i)
-                    tall = split_result.x[idx].ptp() > split_big_kw['dx']
-                    wide = split_result.z_reg[idx].ptp() > split_big_kw['dx']
-                    if (tall or wide) and len(idx) > split_big_kw['min_size_split']:
-                        jobs.append(pool.submit(_split_job, np.flatnonzero(labels == i), max_size_wfs))
+                    tall = split_result.x[idx].ptp() > split_big_kw["dx"]
+                    wide = split_result.z_reg[idx].ptp() > split_big_kw["dx"]
+                    if (tall or wide) and len(idx) > split_big_kw["min_size_split"]:
+                        jobs.append(
+                            pool.submit(
+                                _split_job, np.flatnonzero(labels == i), max_size_wfs
+                            )
+                        )
                 if show_progress:
                     iterator.total += 1
 
@@ -227,18 +234,17 @@ def __init__(
         )
 
     def split_cluster(self, in_unit_all, max_size_wfs):
-
-
         n_spikes = in_unit_all.size
         if max_size_wfs is not None and n_spikes > max_size_wfs:
-            #TODO: max_size_wfs could be chosen automatically based on available memory and number of spikes
+            # TODO: max_size_wfs could be chosen automatically based on available memory and number of spikes
             idx_subsample = np.random.choice(n_spikes, max_size_wfs, replace=False)
             idx_subsample.sort()
             in_unit = in_unit_all[idx_subsample]
             subsampling = True
-        else: 
+        else:
             in_unit = in_unit_all
             subsampling = False
+
         if n_spikes < self.min_cluster_size:
             return SplitResult()
 
@@ -293,9 +299,17 @@ def split_cluster(self, in_unit_all, max_size_wfs):
             new_labels[idx_subsample[kept]] = hdb_labels
 
         if self.use_localization_features:
-            return SplitResult(is_split=is_split, in_unit=in_unit_all, new_labels=new_labels, x=loc_features[:, 0], z_reg=loc_features[:, 1])
+            return SplitResult(
+                is_split=is_split,
+                in_unit=in_unit_all,
+                new_labels=new_labels,
+                x=loc_features[:, 0],
+                z_reg=loc_features[:, 1],
+            )
         else:
-            return SplitResult(is_split=is_split, in_unit=in_unit_all, new_labels=new_labels)
+            return SplitResult(
+                is_split=is_split, in_unit=in_unit_all, new_labels=new_labels
+            )
 
     def get_registered_channels(self, in_unit):
         n_pitches_shift = drift_util.get_spike_pitch_shifts(
@@ -329,13 +343,21 @@ def get_registered_channels(self, in_unit):
             )
             kept = np.flatnonzero(~np.isnan(reloc_amp_vecs).any(axis=1))
             reloc_amplitudes = np.nanmax(reloc_amp_vecs[kept], axis=1)
+            reloc_amplitudes = np.log(self.log_c + reloc_amplitudes)
         else:
             reloc_amplitudes = None
             kept = np.arange(in_unit.size)
 
         return max_registered_channel, n_pitches_shift, reloc_amplitudes, kept
 
-    def pca_features(self, in_unit, max_registered_channel, n_pitches_shift, batch_size=1_000, max_samples_pca=50_000):
+    def pca_features(
+        self,
+        in_unit,
+        max_registered_channel,
+        n_pitches_shift,
+        batch_size=1_000,
+        max_samples_pca=50_000,
+    ):
         """Compute relocated PCA features on a drift-invariant channel set"""
         # figure out which set of channels to use
         # we use the stored amplitudes to do this rather than computing a
@@ -385,7 +407,9 @@ def pca_features(self, in_unit, max_registered_channel, n_pitches_shift, batch_s
             )
 
             if waveforms is None:
-                waveforms = np.empty((in_unit.size, t * pca_channels.size), dtype=batch.dtype) #POTENTIALLY WAY TOO BIG
+                waveforms = np.empty(
+                    (in_unit.size, t * pca_channels.size), dtype=batch.dtype
+                )  # POTENTIALLY WAY TOO BIG
             waveforms[bs:be] = batch.reshape(n_batch, -1)
 
         # figure out which waveforms actually overlap with the requested channels
@@ -401,11 +425,15 @@ def pca_features(self, in_unit, max_registered_channel, n_pitches_shift, batch_s
             )
             fit_indices = no_nan
             if fit_indices.size > max_samples_pca:
-                fit_indices = self.rg.choice(fit_indices, size=max_samples_pca, replace=False)
+                fit_indices = self.rg.choice(
+                    fit_indices, size=max_samples_pca, replace=False
+                )
             pca.fit(waveforms[fit_indices])
 
         # embed into the cluster's PCA space
-        pca_projs = np.full((waveforms.shape[0], self.n_pca_features), np.nan, dtype=waveforms.dtype)
+        pca_projs = np.full(
+            (waveforms.shape[0], self.n_pca_features), np.nan, dtype=waveforms.dtype
+        )
         for bs in range(0, no_nan.size, batch_size):
             be = min(no_nan.size, bs + batch_size)
             pca_projs[no_nan[bs:be]] = pca.transform(waveforms[no_nan[bs:be]])
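For reference, the batched, NaN-aware PCA fit/transform pattern that the last two hunks reformat can be exercised standalone as below. This is a minimal sketch with toy data: the `waveforms` shape, the RNG seed, and the shrunken `max_samples_pca`/`batch_size` constants are illustrative assumptions, not values from the patch.

    import numpy as np
    from sklearn.decomposition import PCA

    rng = np.random.default_rng(0)

    # toy stand-in for the flattened (n_spikes, time * channels) waveform matrix;
    # rows containing NaN are spikes that do not fully overlap the chosen channels
    waveforms = rng.normal(size=(1_000, 40)).astype(np.float32)
    waveforms[rng.random(1_000) < 0.1] = np.nan

    n_pca_features = 2
    max_samples_pca = 500
    batch_size = 100

    # only fully observed rows can be used to fit the PCA or be transformed
    no_nan = np.flatnonzero(~np.isnan(waveforms).any(axis=1))

    # fit on a random subset when the unit is large, as in the final hunk
    fit_indices = no_nan
    if fit_indices.size > max_samples_pca:
        fit_indices = rng.choice(fit_indices, size=max_samples_pca, replace=False)

    pca = PCA(n_pca_features)
    pca.fit(waveforms[fit_indices])

    # embed fully observed rows in batches; rows with missing channels stay NaN
    pca_projs = np.full(
        (waveforms.shape[0], n_pca_features), np.nan, dtype=waveforms.dtype
    )
    for bs in range(0, no_nan.size, batch_size):
        be = min(no_nan.size, bs + batch_size)
        pca_projs[no_nan[bs:be]] = pca.transform(waveforms[no_nan[bs:be]])

Keeping NaN rows in `pca_projs` (rather than dropping them) preserves row alignment with `in_unit`, which is what lets the caller scatter HDBSCAN labels back onto the full spike index set.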