Log relocated amplitudes as well
cwindolf committed Jan 29, 2024
1 parent 82e9306 commit 6b4d2c4
1 changed file: src/dartsort/cluster/split.py (44 additions, 16 deletions)
@@ -7,7 +7,6 @@
 from dartsort.util import drift_util, waveform_util
 from dartsort.util.multiprocessing_util import get_pool
 from hdbscan import HDBSCAN
-from scipy.spatial import KDTree
 from scipy.spatial.distance import pdist
 from sklearn.decomposition import PCA
 from tqdm.auto import tqdm
@@ -86,17 +85,25 @@ def split_clusters(
         if recursive:
             new_units = np.unique(new_untriaged_labels)
             for i in new_units:
-                jobs.append(pool.submit(_split_job, np.flatnonzero(labels == i), max_size_wfs))
+                jobs.append(
+                    pool.submit(
+                        _split_job, np.flatnonzero(labels == i), max_size_wfs
+                    )
+                )
             if show_progress:
                 iterator.total += len(new_units)
         elif split_big:
             new_units = np.unique(new_untriaged_labels)
             for i in new_units:
                 idx = np.flatnonzero(new_untriaged_labels == i)
-                tall = split_result.x[idx].ptp() > split_big_kw['dx']
-                wide = split_result.z_reg[idx].ptp() > split_big_kw['dx']
-                if (tall or wide) and len(idx) > split_big_kw['min_size_split']:
-                    jobs.append(pool.submit(_split_job, np.flatnonzero(labels == i), max_size_wfs))
+                tall = split_result.x[idx].ptp() > split_big_kw["dx"]
+                wide = split_result.z_reg[idx].ptp() > split_big_kw["dx"]
+                if (tall or wide) and len(idx) > split_big_kw["min_size_split"]:
+                    jobs.append(
+                        pool.submit(
+                            _split_job, np.flatnonzero(labels == i), max_size_wfs
+                        )
+                    )
             if show_progress:
                 iterator.total += 1

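For orientation, the `split_big` branch re-queues a unit whose localization cloud is spatially too extended. A minimal standalone version of that test, with hypothetical `split_big_kw` values and toy localizations (`np.ptp` replaces the `.ptp()` method, which newer NumPy removed):

```python
import numpy as np

# Hypothetical threshold config mirroring split_big_kw in the diff.
split_big_kw = dict(dx=40.0, min_size_split=50)

def needs_resplit(x, z_reg, split_big_kw):
    """Re-queue a unit whose localization footprint exceeds dx in x or z."""
    tall = np.ptp(x) > split_big_kw["dx"]
    wide = np.ptp(z_reg) > split_big_kw["dx"]
    return (tall or wide) and x.size > split_big_kw["min_size_split"]

rng = np.random.default_rng(0)
x = rng.normal(0.0, 30.0, size=200)    # broad cloud: ptp very likely > 40
z_reg = rng.normal(0.0, 5.0, size=200)
print(needs_resplit(x, z_reg, split_big_kw))  # True for this toy unit
```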
@@ -227,18 +234,17 @@ def __init__(
         )
 
     def split_cluster(self, in_unit_all, max_size_wfs):
-
-
         n_spikes = in_unit_all.size
         if max_size_wfs is not None and n_spikes > max_size_wfs:
-            #TODO: max_size_wfs could be chosen automatically based on available memory and number of spikes
+            # TODO: max_size_wfs could be chosen automatically based on available memory and number of spikes
            idx_subsample = np.random.choice(n_spikes, max_size_wfs, replace=False)
            idx_subsample.sort()
            in_unit = in_unit_all[idx_subsample]
            subsampling = True
-        else:
+        else:
            in_unit = in_unit_all
            subsampling = False
 
        if n_spikes < self.min_cluster_size:
            return SplitResult()

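The subsampling above clusters at most `max_size_wfs` spikes; labels are later scattered back through `idx_subsample`, as in the hunk that follows. A self-contained sketch of that bookkeeping, with invented sizes and stand-in clustering output:

```python
import numpy as np

rng = np.random.default_rng(0)
in_unit_all = np.arange(25_000)  # toy spike indices for one unit
max_size_wfs = 1_000

n_spikes = in_unit_all.size
if max_size_wfs is not None and n_spikes > max_size_wfs:
    idx_subsample = rng.choice(n_spikes, max_size_wfs, replace=False)
    idx_subsample.sort()  # keep the subset in time order
    in_unit = in_unit_all[idx_subsample]
    subsampling = True
else:
    in_unit = in_unit_all
    subsampling = False

# Stand-ins for clustering output: `kept` rows survived NaN filtering,
# and hdb_labels are their cluster labels.
kept = np.arange(in_unit.size)
hdb_labels = rng.integers(0, 3, size=kept.size)

new_labels = np.full(n_spikes, -1)  # -1 = unassigned / triaged
if subsampling:
    new_labels[idx_subsample[kept]] = hdb_labels  # as in the next hunk
else:
    new_labels[kept] = hdb_labels
```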
Expand Down Expand Up @@ -293,9 +299,17 @@ def split_cluster(self, in_unit_all, max_size_wfs):
new_labels[idx_subsample[kept]] = hdb_labels

if self.use_localization_features:
return SplitResult(is_split=is_split, in_unit=in_unit_all, new_labels=new_labels, x=loc_features[:, 0], z_reg=loc_features[:, 1])
return SplitResult(
is_split=is_split,
in_unit=in_unit_all,
new_labels=new_labels,
x=loc_features[:, 0],
z_reg=loc_features[:, 1],
)
else:
return SplitResult(is_split=is_split, in_unit=in_unit_all, new_labels=new_labels)
return SplitResult(
is_split=is_split, in_unit=in_unit_all, new_labels=new_labels
)

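From the two constructor calls above, `SplitResult` plausibly looks something like the sketch below. The field names come from the diff; everything else (all-optional fields, `is_split=False` so the bare `SplitResult()` early return works) is a guess, not the real dartsort definition:

```python
from dataclasses import dataclass
from typing import Optional

import numpy as np

@dataclass
class SplitResult:
    # Guessed field set, based only on the constructor calls in this diff.
    is_split: bool = False
    in_unit: Optional[np.ndarray] = None     # spike indices of the full unit
    new_labels: Optional[np.ndarray] = None  # sub-cluster label per spike
    x: Optional[np.ndarray] = None           # per-spike x localization
    z_reg: Optional[np.ndarray] = None       # drift-registered z localization
```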
     def get_registered_channels(self, in_unit):
         n_pitches_shift = drift_util.get_spike_pitch_shifts(
Expand Down Expand Up @@ -329,13 +343,21 @@ def get_registered_channels(self, in_unit):
)
kept = np.flatnonzero(~np.isnan(reloc_amp_vecs).any(axis=1))
reloc_amplitudes = np.nanmax(reloc_amp_vecs[kept], axis=1)
reloc_amplitudes = np.log(self.log_c + reloc_amplitudes)
else:
reloc_amplitudes = None
kept = np.arange(in_unit.size)

return max_registered_channel, n_pitches_shift, reloc_amplitudes, kept

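The added line above is the heart of the commit: relocated peak amplitudes now get a shifted-log compression, presumably to match the scaling of the other amplitude features ("as well" in the commit message). A toy rendering of those three lines, with `log_c` and the amplitude vectors invented:

```python
import numpy as np

# Toy rendering of the hunk above. NaN marks channels a spike
# never overlapped; such spikes are dropped by `kept`.
log_c = 5.0
reloc_amp_vecs = np.array([
    [10.0, 80.0, 20.0],
    [np.nan, 60.0, 15.0],  # dropped by `kept`: has a NaN entry
    [12.0, 90.0, 25.0],
])
kept = np.flatnonzero(~np.isnan(reloc_amp_vecs).any(axis=1))
reloc_amplitudes = np.nanmax(reloc_amp_vecs[kept], axis=1)  # per-spike peak
reloc_amplitudes = np.log(log_c + reloc_amplitudes)         # the new transform
print(kept, reloc_amplitudes)  # [0 2] [4.443 4.554]
```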
-    def pca_features(self, in_unit, max_registered_channel, n_pitches_shift, batch_size=1_000, max_samples_pca=50_000):
+    def pca_features(
+        self,
+        in_unit,
+        max_registered_channel,
+        n_pitches_shift,
+        batch_size=1_000,
+        max_samples_pca=50_000,
+    ):
         """Compute relocated PCA features on a drift-invariant channel set"""
         # figure out which set of channels to use
         # we use the stored amplitudes to do this rather than computing a
@@ -385,7 +407,9 @@ def pca_features(self, in_unit, max_registered_channel, n_pitches_shift, batch_size=1_000, max_samples_pca=50_000):
             )
 
             if waveforms is None:
-                waveforms = np.empty((in_unit.size, t * pca_channels.size), dtype=batch.dtype) #POTENTIALLY WAY TOO BIG
+                waveforms = np.empty(
+                    (in_unit.size, t * pca_channels.size), dtype=batch.dtype
+                )  # POTENTIALLY WAY TOO BIG
             waveforms[bs:be] = batch.reshape(n_batch, -1)
 
         # figure out which waveforms actually overlap with the requested channels
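The `# POTENTIALLY WAY TOO BIG` warning is easy to quantify: the buffer holds `in_unit.size` rows of `t * pca_channels.size` samples each. A back-of-envelope check with invented shapes:

```python
import numpy as np

# Rough size of the flattened buffer: n_spikes x (t * n_channels) float32.
n_spikes, t, n_channels = 500_000, 121, 20  # illustrative shapes
nbytes = n_spikes * t * n_channels * np.dtype(np.float32).itemsize
print(f"{nbytes / 1e9:.1f} GB")  # 4.8 GB for these shapes
```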
@@ -401,11 +425,15 @@
         )
         fit_indices = no_nan
         if fit_indices.size > max_samples_pca:
-            fit_indices = self.rg.choice(fit_indices, size=max_samples_pca, replace=False)
+            fit_indices = self.rg.choice(
+                fit_indices, size=max_samples_pca, replace=False
+            )
         pca.fit(waveforms[fit_indices])
 
         # embed into the cluster's PCA space
-        pca_projs = np.full((waveforms.shape[0], self.n_pca_features), np.nan, dtype=waveforms.dtype)
+        pca_projs = np.full(
+            (waveforms.shape[0], self.n_pca_features), np.nan, dtype=waveforms.dtype
+        )
         for bs in range(0, no_nan.size, batch_size):
             be = min(no_nan.size, bs + batch_size)
             pca_projs[no_nan[bs:be]] = pca.transform(waveforms[no_nan[bs:be]])
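Read together, the last two hunks implement a subsample-fit, batch-transform pattern: fit PCA on a capped random subset of fully observed rows, then embed all observed rows in batches, leaving never-observed rows as NaN. A standalone sketch under assumed shapes (the real code streams `waveforms` from disk and uses `self.rg`):

```python
import numpy as np
from sklearn.decomposition import PCA

rng = np.random.default_rng(0)
waveforms = rng.normal(size=(10_000, 242)).astype(np.float32)
waveforms[rng.random(10_000) < 0.1] = np.nan  # some rows unobserved
n_pca_features, batch_size, max_samples_pca = 5, 1_000, 4_000

# fit on a capped random subset of rows with no missing channels
no_nan = np.flatnonzero(~np.isnan(waveforms).any(axis=1))
fit_indices = no_nan
if fit_indices.size > max_samples_pca:
    fit_indices = rng.choice(fit_indices, size=max_samples_pca, replace=False)
pca = PCA(n_pca_features)
pca.fit(waveforms[fit_indices])

# embed all observed rows in batches; unobserved rows stay NaN
pca_projs = np.full((waveforms.shape[0], n_pca_features), np.nan,
                    dtype=waveforms.dtype)
for bs in range(0, no_nan.size, batch_size):
    be = min(no_nan.size, bs + batch_size)
    pca_projs[no_nan[bs:be]] = pca.transform(waveforms[no_nan[bs:be]])
```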