Skip to content

Commit e2de16d

Browse files
committed
Un-logged amplitudes in amplitude-normalized PCA
1 parent 614f7e6 commit e2de16d

File tree

1 file changed

+3
-4
lines changed

1 file changed

+3
-4
lines changed

src/dartsort/cluster/split.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -291,7 +291,6 @@ def split_cluster(self, in_unit_all, max_size_wfs):
291291
max_registered_channel,
292292
n_pitches_shift,
293293
amplitude_normalized=self.amplitude_normalized,
294-
a=self.localization_features[in_unit, 2],
295294
)
296295

297296
if not enough_good_spikes:
@@ -448,7 +447,6 @@ def pca_features(
448447
batch_size=1_000,
449448
max_samples_pca=50_000,
450449
amplitude_normalized=False,
451-
a=None,
452450
):
453451
"""Compute relocated PCA features on a drift-invariant channel set"""
454452
# figure out which set of channels to use
@@ -527,7 +525,8 @@ def pca_features(
527525
fit_indices, size=max_samples_pca, replace=False
528526
)
529527
if amplitude_normalized:
530-
pca.fit(waveforms[fit_indices] / a[fit_indices, None])
528+
amplitudes = self.amplitudes[in_unit][:, None]
529+
pca.fit(waveforms[fit_indices] / amplitudes[fit_indices])
531530
else:
532531
pca.fit(waveforms[fit_indices])
533532

@@ -539,7 +538,7 @@ def pca_features(
539538
for bs in range(0, no_nan.size, batch_size):
540539
be = min(no_nan.size, bs + batch_size)
541540
pca_projs[no_nan[bs:be]] = pca.transform(
542-
waveforms[no_nan[bs:be]] / a[no_nan[bs:be], None]
541+
waveforms[no_nan[bs:be]] / amplitudes[no_nan[bs:be], None]
543542
)
544543
else:
545544
for bs in range(0, no_nan.size, batch_size):

0 commit comments

Comments
 (0)