Skip to content

Commit f50c0a0

Browse files
author
julienboussard
committed
fix split + additional options
2 parents 2fa7d6a + 6b4d2c4 commit f50c0a0

File tree

1 file changed

+40
-13
lines changed

1 file changed

+40
-13
lines changed

src/dartsort/cluster/split.py

Lines changed: 40 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,6 @@
77
from dartsort.util import drift_util, waveform_util
88
from dartsort.util.multiprocessing_util import get_pool
99
from hdbscan import HDBSCAN
10-
from scipy.spatial import KDTree
1110
from scipy.spatial.distance import pdist
1211
from sklearn.decomposition import PCA
1312
from tqdm.auto import tqdm
@@ -86,17 +85,25 @@ def split_clusters(
8685
if recursive:
8786
new_units = np.unique(new_untriaged_labels)
8887
for i in new_units:
89-
jobs.append(pool.submit(_split_job, np.flatnonzero(labels == i), max_size_wfs))
88+
jobs.append(
89+
pool.submit(
90+
_split_job, np.flatnonzero(labels == i), max_size_wfs
91+
)
92+
)
9093
if show_progress:
9194
iterator.total += len(new_units)
9295
elif split_big:
9396
new_units = np.unique(new_untriaged_labels)
9497
for i in new_units:
9598
idx = np.flatnonzero(new_untriaged_labels == i)
96-
tall = split_result.x[idx].ptp() > split_big_kw['dx']
97-
wide = split_result.z_reg[idx].ptp() > split_big_kw['dx']
98-
if (tall or wide) and len(idx) > split_big_kw['min_size_split']:
99-
jobs.append(pool.submit(_split_job, np.flatnonzero(labels == i), max_size_wfs))
99+
tall = split_result.x[idx].ptp() > split_big_kw["dx"]
100+
wide = split_result.z_reg[idx].ptp() > split_big_kw["dx"]
101+
if (tall or wide) and len(idx) > split_big_kw["min_size_split"]:
102+
jobs.append(
103+
pool.submit(
104+
_split_job, np.flatnonzero(labels == i), max_size_wfs
105+
)
106+
)
100107
if show_progress:
101108
iterator.total += 1
102109

@@ -234,17 +241,17 @@ def __init__(
234241
)
235242

236243
def split_cluster(self, in_unit_all, max_size_wfs):
237-
238244
n_spikes = in_unit_all.size
239245
if max_size_wfs is not None and n_spikes > max_size_wfs:
240-
#TODO: max_size_wfs could be chosen automatically based on available memory and number of spikes
246+
# TODO: max_size_wfs could be chosen automatically based on available memory and number of spikes
241247
idx_subsample = np.random.choice(n_spikes, max_size_wfs, replace=False)
242248
idx_subsample.sort()
243249
in_unit = in_unit_all[idx_subsample]
244250
subsampling = True
245-
else:
251+
else:
246252
in_unit = in_unit_all
247253
subsampling = False
254+
248255
if n_spikes < self.min_cluster_size:
249256
return SplitResult()
250257

@@ -313,9 +320,17 @@ def split_cluster(self, in_unit_all, max_size_wfs):
313320
new_labels[idx_subsample[kept]] = hdb_labels
314321

315322
if self.use_localization_features:
316-
return SplitResult(is_split=is_split, in_unit=in_unit_all, new_labels=new_labels, x=loc_features[:, 0], z_reg=loc_features[:, 1])
323+
return SplitResult(
324+
is_split=is_split,
325+
in_unit=in_unit_all,
326+
new_labels=new_labels,
327+
x=loc_features[:, 0],
328+
z_reg=loc_features[:, 1],
329+
)
317330
else:
318-
return SplitResult(is_split=is_split, in_unit=in_unit_all, new_labels=new_labels)
331+
return SplitResult(
332+
is_split=is_split, in_unit=in_unit_all, new_labels=new_labels
333+
)
319334

320335
def get_registered_channels(self, in_unit):
321336
n_pitches_shift = drift_util.get_spike_pitch_shifts(
@@ -349,13 +364,23 @@ def get_registered_channels(self, in_unit):
349364
)
350365
kept = np.flatnonzero(~np.isnan(reloc_amp_vecs).any(axis=1))
351366
reloc_amplitudes = np.nanmax(reloc_amp_vecs[kept], axis=1)
367+
reloc_amplitudes = np.log(self.log_c + reloc_amplitudes)
352368
else:
353369
reloc_amplitudes = None
354370
kept = np.arange(in_unit.size)
355371

356372
return max_registered_channel, n_pitches_shift, reloc_amplitudes, kept
357373

358-
def pca_features(self, in_unit, max_registered_channel, n_pitches_shift, batch_size=1_000, max_samples_pca=50_000, amplitude_normalized=False, a=None):
374+
def pca_features(
375+
self,
376+
in_unit,
377+
max_registered_channel,
378+
n_pitches_shift,
379+
batch_size=1_000,
380+
max_samples_pca=50_000,
381+
amplitude_normalized=False,
382+
a=None,
383+
):
359384
"""Compute relocated PCA features on a drift-invariant channel set"""
360385
# figure out which set of channels to use
361386
# we use the stored amplitudes to do this rather than computing a
@@ -405,7 +430,9 @@ def pca_features(self, in_unit, max_registered_channel, n_pitches_shift, batch_s
405430
)
406431

407432
if waveforms is None:
408-
waveforms = np.empty((in_unit.size, t * pca_channels.size), dtype=batch.dtype) #POTENTIALLY WAY TOO BIG
433+
waveforms = np.empty(
434+
(in_unit.size, t * pca_channels.size), dtype=batch.dtype
435+
) # POTENTIALLY WAY TOO BIG
409436
waveforms[bs:be] = batch.reshape(n_batch, -1)
410437

411438
# figure out which waveforms actually overlap with the requested channels

0 commit comments

Comments (0)