From 6b4d2c45fda238c1586e291b95901040b80b2894 Mon Sep 17 00:00:00 2001
From: Charlie Windolf <cwindolf95@gmail.com>
Date: Mon, 29 Jan 2024 14:37:25 -0500
Subject: [PATCH] Log relocated amplitudes as well

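The other amplitude features already live on a log scale (hence the "as
well" in the subject), so apply the same log(log_c + amp) transform to
the relocated max amplitudes before they enter the split features. Also
reformat the touched code and drop the unused KDTree import.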
---
 src/dartsort/cluster/split.py | 60 +++++++++++++++++++++++++----------
 1 file changed, 44 insertions(+), 16 deletions(-)

diff --git a/src/dartsort/cluster/split.py b/src/dartsort/cluster/split.py
index df271d64..1236eaf6 100644
--- a/src/dartsort/cluster/split.py
+++ b/src/dartsort/cluster/split.py
@@ -7,7 +7,6 @@
 from dartsort.util import drift_util, waveform_util
 from dartsort.util.multiprocessing_util import get_pool
 from hdbscan import HDBSCAN
-from scipy.spatial import KDTree
 from scipy.spatial.distance import pdist
 from sklearn.decomposition import PCA
 from tqdm.auto import tqdm
@@ -86,17 +85,25 @@ def split_clusters(
             if recursive:
                 new_units = np.unique(new_untriaged_labels)
                 for i in new_units:
-                    jobs.append(pool.submit(_split_job, np.flatnonzero(labels == i), max_size_wfs))
+                    jobs.append(
+                        pool.submit(
+                            _split_job, np.flatnonzero(labels == i), max_size_wfs
+                        )
+                    )
                 if show_progress:
                     iterator.total += len(new_units)
             elif split_big:
                 new_units = np.unique(new_untriaged_labels)
                 for i in new_units:
                     idx = np.flatnonzero(new_untriaged_labels == i)
-                    tall = split_result.x[idx].ptp() > split_big_kw['dx']
-                    wide = split_result.z_reg[idx].ptp() > split_big_kw['dx']
-                    if (tall or wide) and len(idx) > split_big_kw['min_size_split']:
-                        jobs.append(pool.submit(_split_job, np.flatnonzero(labels == i), max_size_wfs))
+                    tall = split_result.x[idx].ptp() > split_big_kw["dx"]
+                    wide = split_result.z_reg[idx].ptp() > split_big_kw["dx"]
+                    if (tall or wide) and len(idx) > split_big_kw["min_size_split"]:
+                        jobs.append(
+                            pool.submit(
+                                _split_job, np.flatnonzero(labels == i), max_size_wfs
+                            )
+                        )
                         if show_progress:
                             iterator.total += 1
 
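Note on the split_big branch above: a new unit is re-queued for splitting when
its localization spread exceeds the dx threshold in either dimension and it has
enough spikes. A minimal sketch of that criterion, assuming split_big_kw carries
"dx" and "min_size_split" keys as in the diff (the tall/wide names are assigned
the other way around in the diff; behavior is identical since they are or'ed):

    import numpy as np

    def should_resplit(x, z_reg, split_big_kw):
        # re-split when the x extent or registered-z extent exceeds dx
        # and the unit has more than min_size_split spikes
        wide = np.ptp(x) > split_big_kw["dx"]
        tall = np.ptp(z_reg) > split_big_kw["dx"]
        return (tall or wide) and x.size > split_big_kw["min_size_split"]
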
@@ -227,18 +234,17 @@ def __init__(
         )
 
     def split_cluster(self, in_unit_all, max_size_wfs):
-        
-        
         n_spikes = in_unit_all.size
         if max_size_wfs is not None and n_spikes > max_size_wfs:
-            #TODO: max_size_wfs could be chosen automatically based on available memory and number of spikes
+            # TODO: max_size_wfs could be chosen automatically based on available memory and number of spikes
             idx_subsample = np.random.choice(n_spikes, max_size_wfs, replace=False)
             idx_subsample.sort()
             in_unit = in_unit_all[idx_subsample]
             subsampling = True
-        else: 
+        else:
             in_unit = in_unit_all
             subsampling = False
+
         if n_spikes < self.min_cluster_size:
             return SplitResult()
 
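The subsampling block above caps how many waveforms one split job loads. A
self-contained sketch of the same logic (sizes here are illustrative):

    import numpy as np

    max_size_wfs = 4  # illustrative cap
    in_unit_all = np.array([3, 7, 9, 20, 31, 44])  # spike indices of one unit
    n_spikes = in_unit_all.size
    if max_size_wfs is not None and n_spikes > max_size_wfs:
        idx_subsample = np.random.choice(n_spikes, max_size_wfs, replace=False)
        idx_subsample.sort()  # keep the subsampled spikes in time order
        in_unit = in_unit_all[idx_subsample]
        subsampling = True
    else:
        in_unit = in_unit_all
        subsampling = False
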
@@ -293,9 +299,17 @@ def split_cluster(self, in_unit_all, max_size_wfs):
                 new_labels[idx_subsample[kept]] = hdb_labels
 
         if self.use_localization_features:
-            return SplitResult(is_split=is_split, in_unit=in_unit_all, new_labels=new_labels, x=loc_features[:, 0], z_reg=loc_features[:, 1])
+            return SplitResult(
+                is_split=is_split,
+                in_unit=in_unit_all,
+                new_labels=new_labels,
+                x=loc_features[:, 0],
+                z_reg=loc_features[:, 1],
+            )
         else:
-            return SplitResult(is_split=is_split, in_unit=in_unit_all, new_labels=new_labels)
+            return SplitResult(
+                is_split=is_split, in_unit=in_unit_all, new_labels=new_labels
+            )
 
     def get_registered_channels(self, in_unit):
         n_pitches_shift = drift_util.get_spike_pitch_shifts(
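The SplitResult calls above now also carry x and z_reg when localization
features are in use. A hedged sketch of the container those calls imply (the
real class is defined elsewhere in split.py; fields and defaults are assumed
from the call sites):

    from dataclasses import dataclass
    from typing import Optional
    import numpy as np

    @dataclass
    class SplitResult:
        is_split: bool = False
        in_unit: Optional[np.ndarray] = None
        new_labels: Optional[np.ndarray] = None
        x: Optional[np.ndarray] = None
        z_reg: Optional[np.ndarray] = None
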
@@ -329,13 +343,21 @@ def get_registered_channels(self, in_unit):
             )
             kept = np.flatnonzero(~np.isnan(reloc_amp_vecs).any(axis=1))
             reloc_amplitudes = np.nanmax(reloc_amp_vecs[kept], axis=1)
+            reloc_amplitudes = np.log(self.log_c + reloc_amplitudes)
         else:
             reloc_amplitudes = None
             kept = np.arange(in_unit.size)
 
         return max_registered_channel, n_pitches_shift, reloc_amplitudes, kept
 
-    def pca_features(self, in_unit, max_registered_channel, n_pitches_shift, batch_size=1_000, max_samples_pca=50_000):
+    def pca_features(
+        self,
+        in_unit,
+        max_registered_channel,
+        n_pitches_shift,
+        batch_size=1_000,
+        max_samples_pca=50_000,
+    ):
         """Compute relocated PCA features on a drift-invariant channel set"""
         # figure out which set of channels to use
         # we use the stored amplitudes to do this rather than computing a
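The hunk above is the substantive change: relocated peak amplitudes now pass
through the log transform named in the subject. A toy illustration (the log_c
value here is assumed; the splitter reads it from self.log_c):

    import numpy as np

    log_c = 5.0  # assumed offset; keeps the log finite for small amplitudes
    reloc_amp_vecs = np.array([[8.0, 40.0, 12.0],
                               [3.0, np.nan, 1.0]])
    kept = np.flatnonzero(~np.isnan(reloc_amp_vecs).any(axis=1))  # row 1 dropped
    reloc_amplitudes = np.nanmax(reloc_amp_vecs[kept], axis=1)    # [40.0]
    reloc_amplitudes = np.log(log_c + reloc_amplitudes)           # [~3.81]
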
@@ -385,7 +407,9 @@ def pca_features(self, in_unit, max_registered_channel, n_pitches_shift, batch_s
                 )
 
             if waveforms is None:
-                waveforms = np.empty((in_unit.size, t * pca_channels.size), dtype=batch.dtype) #POTENTIALLY WAY TOO BIG
+                waveforms = np.empty(
+                    (in_unit.size, t * pca_channels.size), dtype=batch.dtype
+                )  # POTENTIALLY WAY TOO BIG
             waveforms[bs:be] = batch.reshape(n_batch, -1)
 
         # figure out which waveforms actually overlap with the requested channels
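On the "POTENTIALLY WAY TOO BIG" note above: the preallocated matrix is
(in_unit.size, t * pca_channels.size). A back-of-envelope check under assumed,
illustrative sizes:

    n_spikes, t, n_channels = 200_000, 121, 40  # assumed sizes
    bytes_needed = n_spikes * t * n_channels * 4  # float32
    print(f"{bytes_needed / 1e9:.1f} GB")  # 3.9 GB for this example
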
@@ -401,11 +425,15 @@ def pca_features(self, in_unit, max_registered_channel, n_pitches_shift, batch_s
         )
         fit_indices = no_nan
         if fit_indices.size > max_samples_pca:
-            fit_indices = self.rg.choice(fit_indices, size=max_samples_pca, replace=False)
+            fit_indices = self.rg.choice(
+                fit_indices, size=max_samples_pca, replace=False
+            )
         pca.fit(waveforms[fit_indices])
 
         # embed into the cluster's PCA space
-        pca_projs = np.full((waveforms.shape[0], self.n_pca_features), np.nan, dtype=waveforms.dtype)
+        pca_projs = np.full(
+            (waveforms.shape[0], self.n_pca_features), np.nan, dtype=waveforms.dtype
+        )
         for bs in range(0, no_nan.size, batch_size):
             be = min(no_nan.size, bs + batch_size)
             pca_projs[no_nan[bs:be]] = pca.transform(waveforms[no_nan[bs:be]])
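
The last two hunks reformat the fit-on-subsample / transform-in-batches
pattern. A self-contained sketch of that pattern with sklearn's PCA (the sizes
and the rg generator are stand-ins for the splitter's own state):

    import numpy as np
    from sklearn.decomposition import PCA

    rg = np.random.default_rng(0)
    waveforms = rg.normal(size=(10_000, 64)).astype(np.float32)
    no_nan = np.flatnonzero(~np.isnan(waveforms).any(axis=1))

    # fit on at most max_samples_pca rows, chosen at random
    max_samples_pca, batch_size, n_pca_features = 5_000, 1_000, 5
    fit_indices = no_nan
    if fit_indices.size > max_samples_pca:
        fit_indices = rg.choice(fit_indices, size=max_samples_pca, replace=False)
    pca = PCA(n_pca_features)
    pca.fit(waveforms[fit_indices])

    # embed everything in batches, leaving all-NaN rows as NaN
    pca_projs = np.full((waveforms.shape[0], n_pca_features), np.nan,
                        dtype=waveforms.dtype)
    for bs in range(0, no_nan.size, batch_size):
        be = min(no_nan.size, bs + batch_size)
        pca_projs[no_nan[bs:be]] = pca.transform(waveforms[no_nan[bs:be]])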