
Commit

Merge branch 'main' of github.com:cwindolf/spike-psvae into main
cwindolf committed Oct 27, 2023
2 parents (28da396 + 49746f5), commit b637b19
Showing 7 changed files with 780 additions and 38 deletions.
requirements.txt: 3 changes (2 additions, 1 deletion)
@@ -1,4 +1,5 @@
 pytest
 ibl-neuropixel
 spikeinterface
-cloudpickle
+cloudpickle
+hdbscan
src/dartsort/cluster/split.py: 34 changes (9 additions, 25 deletions)
@@ -74,19 +74,15 @@ def split_clusters(
                 new_labels = split_result.new_labels
                 triaged = split_result.new_labels < 0
                 labels[in_unit[triaged]] = new_labels[triaged]
-                labels[in_unit[new_labels > 0]] = (
-                    cur_max_label + new_labels[new_labels > 0]
-                )
+                labels[in_unit[new_labels > 0]] = cur_max_label + new_labels[new_labels > 0]
                 new_untriaged_labels = labels[in_unit[new_labels >= 0]]
                 cur_max_label = new_untriaged_labels.max()
 
                 # submit recursive jobs to the pool, if any
                 if recursive:
                     new_units = np.unique(new_untriaged_labels)
                     for i in new_units:
-                        jobs.append(
-                            pool.submit(_split_job, np.flatnonzero(labels == i))
-                        )
+                        jobs.append(pool.submit(_split_job, np.flatnonzero(labels == i)))
                     if show_progress:
                         iterator.total += len(new_units)
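The collapsed one-liners above are pure formatting, but the relabeling logic they express is the heart of split_clusters: triaged spikes get negative labels, sub-cluster 0 keeps the unit's original label, and higher sub-clusters are shifted past the current maximum label. A self-contained sketch of that arithmetic, with made-up values that are not taken from the repository:

import numpy as np

labels = np.array([3, 3, 3, 3, 3])       # five spikes, all in unit 3
in_unit = np.flatnonzero(labels == 3)    # indices of this unit's spikes
new_labels = np.array([-1, 0, 0, 1, 1])  # split result; -1 means triaged
cur_max_label = 7                        # highest label assigned so far

triaged = new_labels < 0
labels[in_unit[triaged]] = new_labels[triaged]
labels[in_unit[new_labels > 0]] = cur_max_label + new_labels[new_labels > 0]
cur_max_label = labels[in_unit[new_labels >= 0]].max()

print(labels, cur_max_label)  # [-1  3  3  8  8] 8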

@@ -151,7 +147,7 @@ def __init__(
         min_cluster_size=25,
         min_samples=25,
         cluster_selection_epsilon=25,
-        reassign_outliers=True,
+        reassign_outliers=False,
         random_state=0,
         **dataset_name_kwargs,
     ):
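These constructor defaults mirror hdbscan.HDBSCAN's own keyword arguments (consistent with the new hdbscan entry in requirements.txt), so presumably they are forwarded to the clusterer. A minimal sketch of that usage on a hypothetical feature matrix, not the repository's actual call:

import numpy as np
from hdbscan import HDBSCAN

features = np.random.default_rng(0).normal(size=(500, 3))  # hypothetical features
clusterer = HDBSCAN(min_cluster_size=25, min_samples=25, cluster_selection_epsilon=25)
hdb_labels = clusterer.fit_predict(features)  # -1 marks low-density outliers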
@@ -241,18 +237,14 @@ def split_cluster(self, in_unit):
         is_split = np.setdiff1d(np.unique(hdb_labels), [-1]).size > 1
 
         if is_split and self.reassign_outliers:
-            hdb_labels = cluster_util.knn_reassign_outliers(
-                hdb_labels, features
-            )
+            hdb_labels = cluster_util.knn_reassign_outliers(hdb_labels, features)
 
         new_labels = None
         if is_split:
             new_labels = np.full(n_spikes, -1)
             new_labels[kept] = hdb_labels
 
-        return SplitResult(
-            is_split=is_split, in_unit=in_unit, new_labels=new_labels
-        )
+        return SplitResult(is_split=is_split, in_unit=in_unit, new_labels=new_labels)
 
     def pca_features(self, in_unit):
         """Compute relocated PCA features on a drift-invariant channel set"""
@@ -316,12 +308,8 @@ def pca_features(self, in_unit):
             return False, no_nan, None
 
         # fit pca and embed
-        pca = PCA(
-            self.n_pca_features, random_state=self.random_state, whiten=True
-        )
-        pca_projs = np.full(
-            (n, self.n_pca_features), np.nan, dtype=waveforms.dtype
-        )
+        pca = PCA(self.n_pca_features, random_state=self.random_state, whiten=True)
+        pca_projs = np.full((n, self.n_pca_features), np.nan, dtype=waveforms.dtype)
         pca_projs[no_nan] = pca.fit_transform(waveforms[no_nan])
 
         return True, no_nan, pca_projs
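The reflowed lines also show the NaN-masking pattern pca_features relies on: PCA is fit only on fully observed waveforms, and rows with missing channels stay NaN in the projections. The same pattern in isolation, with illustrative shapes and values rather than the repo's data:

import numpy as np
from sklearn.decomposition import PCA

rng = np.random.default_rng(0)
waveforms = rng.normal(size=(100, 40)).astype(np.float32)
waveforms[5, 0] = np.nan  # e.g. a channel that drifted off the probe

no_nan = ~np.isnan(waveforms).any(axis=1)
n_pca_features = 2

pca = PCA(n_pca_features, random_state=0, whiten=True)
pca_projs = np.full((len(waveforms), n_pca_features), np.nan, dtype=waveforms.dtype)
pca_projs[no_nan] = pca.fit_transform(waveforms[no_nan])  # row 5 stays all-NaN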
@@ -386,9 +374,7 @@ def initialize_from_h5(
 
 # this is to help split_clusters take a string argument
 all_split_strategies = [FeatureSplit]
-split_strategies_by_class_name = {
-    cls.__name__: cls for cls in all_split_strategies
-}
+split_strategies_by_class_name = {cls.__name__: cls for cls in all_split_strategies}
 
 # -- parallelism widgets
 
@@ -404,9 +390,7 @@ def __init__(self, split_strategy):
 def _split_job_init(split_strategy_class_name, split_strategy_kwargs):
     global _split_job_context
     split_strategy = split_strategies_by_class_name[split_strategy_class_name]
-    _split_job_context = SplitJobContext(
-        split_strategy(**split_strategy_kwargs)
-    )
+    _split_job_context = SplitJobContext(split_strategy(**split_strategy_kwargs))
 
 
 def _split_job(in_unit):
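_split_job_init and the name-to-class dict above implement the standard process-pool initializer pattern: each worker constructs its strategy once and parks it in a module global, so only the class name and kwargs (not the strategy object itself) need to be pickled. A hedged sketch of the wiring split_clusters presumably does; the kwargs and the labels array here are hypothetical stand-ins:

import numpy as np
from concurrent.futures import ProcessPoolExecutor

split_strategy_kwargs = {"peeling_hdf5_filename": "sub.h5"}  # hypothetical kwargs
in_unit = np.flatnonzero(labels == 0)  # labels: an existing spike-label vector

with ProcessPoolExecutor(
    max_workers=4,
    initializer=_split_job_init,
    initargs=("FeatureSplit", split_strategy_kwargs),
) as pool:
    result = pool.submit(_split_job, in_unit).result()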
