 from dartsort.util import drift_util, waveform_util
 from dartsort.util.multiprocessing_util import get_pool
 from hdbscan import HDBSCAN
-from scipy.spatial import KDTree
 from scipy.spatial.distance import pdist
 from sklearn.decomposition import PCA
 from tqdm.auto import tqdm
@@ -86,17 +85,25 @@ def split_clusters(
             if recursive:
                 new_units = np.unique(new_untriaged_labels)
                 for i in new_units:
-                    jobs.append(pool.submit(_split_job, np.flatnonzero(labels == i), max_size_wfs))
+                    jobs.append(
+                        pool.submit(
+                            _split_job, np.flatnonzero(labels == i), max_size_wfs
+                        )
+                    )
                 if show_progress:
                     iterator.total += len(new_units)
             elif split_big:
                 new_units = np.unique(new_untriaged_labels)
                 for i in new_units:
                     idx = np.flatnonzero(new_untriaged_labels == i)
-                    tall = split_result.x[idx].ptp() > split_big_kw['dx']
-                    wide = split_result.z_reg[idx].ptp() > split_big_kw['dx']
-                    if (tall or wide) and len(idx) > split_big_kw['min_size_split']:
-                        jobs.append(pool.submit(_split_job, np.flatnonzero(labels == i), max_size_wfs))
+                    tall = split_result.x[idx].ptp() > split_big_kw["dx"]
+                    wide = split_result.z_reg[idx].ptp() > split_big_kw["dx"]
+                    if (tall or wide) and len(idx) > split_big_kw["min_size_split"]:
+                        jobs.append(
+                            pool.submit(
+                                _split_job, np.flatnonzero(labels == i), max_size_wfs
+                            )
+                        )
                         if show_progress:
                             iterator.total += 1

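The split_big branch above re-queues a unit for another split pass when its localization footprint is large: the peak-to-peak extent of the spikes' x or registered-z positions must exceed split_big_kw["dx"] (both axes compare against "dx" here), and the unit must have more than split_big_kw["min_size_split"] spikes. A minimal standalone sketch of that predicate, assuming NumPy position arrays (the helper name and default values are illustrative, not part of the codebase):

import numpy as np

def should_resplit(x, z_reg, dx=20.0, min_size_split=50):
    """True when a unit's spatial footprint warrants another split pass.

    x, z_reg: per-spike x and registered-z localizations of one unit.
    Mirrors the test above: both extents are compared against dx.
    """
    tall = np.ptp(x) > dx
    wide = np.ptp(z_reg) > dx
    return (tall or wide) and x.size > min_size_split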
@@ -234,17 +241,17 @@ def __init__(
         )

     def split_cluster(self, in_unit_all, max_size_wfs):
-
         n_spikes = in_unit_all.size
         if max_size_wfs is not None and n_spikes > max_size_wfs:
-            #TODO: max_size_wfs could be chosen automatically based on available memory and number of spikes
+            # TODO: max_size_wfs could be chosen automatically based on available memory and number of spikes
             idx_subsample = np.random.choice(n_spikes, max_size_wfs, replace=False)
             idx_subsample.sort()
             in_unit = in_unit_all[idx_subsample]
             subsampling = True
-        else: 
+        else:
             in_unit = in_unit_all
             subsampling = False
+
         if n_spikes < self.min_cluster_size:
             return SplitResult()

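The TODO above suggests deriving max_size_wfs from available memory rather than hard-coding it. A rough sketch of that idea, assuming psutil is available and float32 waveforms of n_samples x n_channels values per spike (the function name, safety fraction, and example shapes are assumptions):

import psutil

def suggest_max_size_wfs(n_samples, n_channels, bytes_per_value=4, mem_fraction=0.25):
    """Largest waveform subsample whose flat buffer fits in a fraction of free RAM."""
    bytes_per_spike = n_samples * n_channels * bytes_per_value
    budget = psutil.virtual_memory().available * mem_fraction
    return int(budget // bytes_per_spike)

# e.g. 121-sample waveforms on 40 channels, float32:
# suggest_max_size_wfs(121, 40) -> a cap for the np.random.choice above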
@@ -313,9 +320,17 @@ def split_cluster(self, in_unit_all, max_size_wfs):
             new_labels[idx_subsample[kept]] = hdb_labels

         if self.use_localization_features:
-            return SplitResult(is_split=is_split, in_unit=in_unit_all, new_labels=new_labels, x=loc_features[:, 0], z_reg=loc_features[:, 1])
+            return SplitResult(
+                is_split=is_split,
+                in_unit=in_unit_all,
+                new_labels=new_labels,
+                x=loc_features[:, 0],
+                z_reg=loc_features[:, 1],
+            )
         else:
-            return SplitResult(is_split=is_split, in_unit=in_unit_all, new_labels=new_labels)
+            return SplitResult(
+                is_split=is_split, in_unit=in_unit_all, new_labels=new_labels
+            )

     def get_registered_channels(self, in_unit):
         n_pitches_shift = drift_util.get_spike_pitch_shifts(
@@ -349,13 +364,23 @@ def get_registered_channels(self, in_unit):
             )
             kept = np.flatnonzero(~np.isnan(reloc_amp_vecs).any(axis=1))
             reloc_amplitudes = np.nanmax(reloc_amp_vecs[kept], axis=1)
+            reloc_amplitudes = np.log(self.log_c + reloc_amplitudes)
         else:
             reloc_amplitudes = None
             kept = np.arange(in_unit.size)

         return max_registered_channel, n_pitches_shift, reloc_amplitudes, kept

-    def pca_features(self, in_unit, max_registered_channel, n_pitches_shift, batch_size=1_000, max_samples_pca=50_000, amplitude_normalized=False, a=None):
+    def pca_features(
+        self,
+        in_unit,
+        max_registered_channel,
+        n_pitches_shift,
+        batch_size=1_000,
+        max_samples_pca=50_000,
+        amplitude_normalized=False,
+        a=None,
+    ):
         """Compute relocated PCA features on a drift-invariant channel set"""
         # figure out which set of channels to use
         # we use the stored amplitudes to do this rather than computing a
@@ -405,7 +430,9 @@ def pca_features(self, in_unit, max_registered_channel, n_pitches_shift, batch_s
             )

             if waveforms is None:
-                waveforms = np.empty((in_unit.size, t * pca_channels.size), dtype=batch.dtype) #POTENTIALLY WAY TOO BIG
+                waveforms = np.empty(
+                    (in_unit.size, t * pca_channels.size), dtype=batch.dtype
+                )  # POTENTIALLY WAY TOO BIG
             waveforms[bs:be] = batch.reshape(n_batch, -1)

             # figure out which waveforms actually overlap with the requested channels