Commit 66ed74e: fix config conflicts
Julien Boussard committed Dec 6, 2023
2 parents: 4345448 + 96e1246
Showing 37 changed files with 4,120 additions and 545 deletions.
1 change: 1 addition & 0 deletions environment.yml
@@ -7,3 +7,4 @@ dependencies:
- h5py
- tqdm
- scikit-learn
- colorcet
2 changes: 1 addition & 1 deletion requirements.txt
@@ -1,4 +1,4 @@
pytest
ibl-neuropixel
spikeinterface
cloudpickle
cloudpickle
2 changes: 1 addition & 1 deletion scripts/uhd_pipeline.py
@@ -164,7 +164,7 @@
# Don't trust spikeinterface preprocessing :( ...
if preprocessing:
    print("Preprocessing...")
    preprocessing_dir = Path(output_all) / "preprocessing_test"
    preprocessing_dir = Path(output_all) / "preprocessing"
    Path(preprocessing_dir).mkdir(exist_ok=True)
    if t_end_preproc is None:
        t_end_preproc=rec_len_sec
236 changes: 236 additions & 0 deletions src/dartsort/cluster/merge.py
@@ -0,0 +1,236 @@
from dataclasses import replace
from typing import Optional

import numpy as np
from dartsort.config import TemplateConfig
from dartsort.templates import TemplateData, template_util
from dartsort.templates.pairwise_util import (
    construct_shift_indices, iterate_compressed_pairwise_convolutions)
from dartsort.util.data_util import DARTsortSorting
from scipy.cluster.hierarchy import complete, fcluster


def merge_templates(
    sorting: DARTsortSorting,
    recording,
    template_data: Optional[TemplateData] = None,
    template_config: Optional[TemplateConfig] = None,
    motion_est=None,
    max_shift_samples=20,
    superres_linkage=np.max,
    merge_distance_threshold=0.25,
    temporal_upsampling_factor=8,
    amplitude_scaling_variance=0.0,
    amplitude_scaling_boundary=0.5,
    svd_compression_rank=10,
    min_channel_amplitude=0.0,
    conv_batch_size=1024,
    units_batch_size=8,
    device=None,
    n_jobs=0,
    n_jobs_templates=0,
    template_save_folder=None,
    overwrite_templates=False,
    show_progress=True,
    template_npz_filename="template_data.npz",
) -> DARTsortSorting:
    """Template distance based merge

    Pass in a sorting, recording and template config to make templates,
    and this will merge them (with superres). Or, if you have templates
    already, pass them into template_data and we can skip the template
    construction.

    Arguments
    ---------
    max_shift_samples
        Max offset during matching
    superres_linkage
        How to combine distances between two units' superres templates
        By default, it's the max.
    amplitude_scaling_*
        Optionally allow scaling during matching

    Returns
    -------
    A new DARTsortSorting
    """
    if template_data is None:
        template_data = TemplateData.from_config(
            recording,
            sorting,
            template_config,
            motion_est=motion_est,
            n_jobs=n_jobs_templates,
            save_folder=template_save_folder,
            overwrite=overwrite_templates,
            device=device,
            save_npz_name=template_npz_filename,
        )

    # allocate distance + shift matrices. shifts[i,j] is trough[j]-trough[i].
    n_templates = template_data.templates.shape[0]
    sup_dists = np.full((n_templates, n_templates), np.inf)
    sup_shifts = np.zeros((n_templates, n_templates), dtype=int)

    # build distance matrix
    dec_res_iter = get_deconv_resid_norm_iter(
        template_data,
        max_shift_samples=max_shift_samples,
        temporal_upsampling_factor=temporal_upsampling_factor,
        amplitude_scaling_variance=amplitude_scaling_variance,
        amplitude_scaling_boundary=amplitude_scaling_boundary,
        svd_compression_rank=svd_compression_rank,
        min_channel_amplitude=min_channel_amplitude,
        conv_batch_size=conv_batch_size,
        units_batch_size=units_batch_size,
        device=device,
        n_jobs=n_jobs,
        show_progress=show_progress,
    )
    for res in dec_res_iter:
        tixa = res.template_indices_a
        tixb = res.template_indices_b
        rms_ratio = res.deconv_resid_norms / res.template_a_norms
        sup_dists[tixa, tixb] = rms_ratio
        sup_shifts[tixa, tixb] = res.shifts

    # apply linkage to reduce across superres templates
    units = np.unique(template_data.unit_ids)
    if units.size < n_templates:
        dists = np.full((units.size, units.size), np.inf)
        shifts = np.zeros((units.size, units.size), dtype=int)
        for ia, ua in enumerate(units):
            in_ua = np.flatnonzero(template_data.unit_ids == ua)
            for ib, ub in enumerate(units):
                in_ub = np.flatnonzero(template_data.unit_ids == ub)
                in_pair = (in_ua[:, None], in_ub[None, :])
                dists[ia, ib] = superres_linkage(sup_dists[in_pair])
                shifts[ia, ib] = np.median(sup_shifts[in_pair])
        coarse_td = template_data.coarsen(with_locs=False)
        template_snrs = coarse_td.templates.ptp(1).max(1) / coarse_td.spike_counts
    else:
        dists = sup_dists
        shifts = sup_shifts
        template_snrs = (
            template_data.templates.ptp(1).max(1) / template_data.spike_counts
        )

    # now run hierarchical clustering
    return recluster(
        sorting,
        units,
        dists,
        shifts,
        template_snrs,
        merge_distance_threshold=merge_distance_threshold,
    )


def recluster(sorting, units, dists, shifts, template_snrs, merge_distance_threshold=0.25):
    # upper triangle not including diagonal, aka condensed distance matrix in scipy
    pdist = dists[np.triu_indices(dists.shape[0], k=1)]
    # scipy hierarchical clustering only supports finite values, so let's just
    # drop in a huge value here
    pdist[~np.isfinite(pdist)] = 1_000_000 + pdist[np.isfinite(pdist)].max()
    # complete linkage: max dist between all pairs across clusters.
    Z = complete(pdist)
    # extract flat clustering using our max dist threshold
    new_labels = fcluster(Z, merge_distance_threshold, criterion="distance")

    # update labels
    labels_updated = sorting.labels.copy()
    kept = np.flatnonzero(np.isin(sorting.labels, units))
    _, flat_labels = np.unique(labels_updated[kept], return_inverse=True)
    labels_updated[kept] = new_labels[flat_labels]

    # update times according to shifts
    times_updated = sorting.times_samples.copy()

    # find original labels in each cluster
    clust_inverse = {i: [] for i in new_labels}
    for orig_label, new_label in enumerate(new_labels):
        clust_inverse[new_label].append(orig_label)

    # align to best snr unit
    for new_label, orig_labels in clust_inverse.items():
        # we don't need to realign clusters which didn't change
        if len(orig_labels) <= 1:
            continue

        orig_snrs = template_snrs[orig_labels]
        best_orig = orig_labels[orig_snrs.argmax()]
        for ogl in np.setdiff1d(orig_labels, [best_orig]):
            in_orig_unit = np.flatnonzero(sorting.labels == ogl)
            # this is like trough[best] - trough[ogl]
            shift_og_best = shifts[ogl, best_orig]
            # if >0, trough of og is behind trough of best.
            # subtracting will move trough of og to the right.
            times_updated[in_orig_unit] -= shift_og_best

    return replace(sorting, times_samples=times_updated, labels=labels_updated)


def get_deconv_resid_norm_iter(
    template_data,
    max_shift_samples=20,
    temporal_upsampling_factor=8,
    amplitude_scaling_variance=0.0,
    amplitude_scaling_boundary=0.5,
    svd_compression_rank=10,
    min_channel_amplitude=0.0,
    conv_batch_size=1024,
    units_batch_size=8,
    device=None,
    n_jobs=0,
    show_progress=True,
):
    # get template aux data
    low_rank_templates = template_util.svd_compress_templates(
        template_data.templates,
        min_channel_amplitude=min_channel_amplitude,
        rank=svd_compression_rank,
    )
    compressed_upsampled_temporal = template_util.compressed_upsampled_templates(
        low_rank_templates.temporal_components,
        ptps=template_data.templates.ptp(1).max(1),
        max_upsample=temporal_upsampling_factor,
    )

    # construct helper data and run pairwise convolutions
    (
        template_shift_index_a,
        template_shift_index_b,
        upsampled_shifted_template_index,
        cooccurrence,
    ) = construct_shift_indices(
        None,
        None,
        template_data,
        compressed_upsampled_temporal,
        motion_est=None,
    )
    yield from iterate_compressed_pairwise_convolutions(
        template_data,
        low_rank_templates,
        template_data,
        low_rank_templates,
        compressed_upsampled_temporal,
        template_shift_index_a,
        template_shift_index_b,
        cooccurrence,
        upsampled_shifted_template_index,
        do_shifting=False,
        reduce_deconv_resid_norm=True,
        geom=template_data.registered_geom,
        conv_ignore_threshold=0.0,
        coarse_approx_error_threshold=0.0,
        amplitude_scaling_variance=amplitude_scaling_variance,
        amplitude_scaling_boundary=amplitude_scaling_boundary,
        max_shift=max_shift_samples,
        conv_batch_size=conv_batch_size,
        units_batch_size=units_batch_size,
        device=device,
        n_jobs=n_jobs,
        show_progress=show_progress,
    )
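
A note on the hierarchical step in recluster above: it is ordinary scipy complete-linkage clustering on a condensed distance matrix. The short sketch below is only an illustration, with a hypothetical 4-unit distance matrix and the default 0.25 threshold; it is not part of this commit.

# Toy illustration (not from this commit): complete-linkage merge on a
# hypothetical 4-unit distance matrix, thresholded the way recluster does.
import numpy as np
from scipy.cluster.hierarchy import complete, fcluster

dists = np.array([
    [0.0, 0.1, 0.9, 0.8],
    [0.1, 0.0, 0.7, 0.9],
    [0.9, 0.7, 0.0, 0.2],
    [0.8, 0.9, 0.2, 0.0],
])
# condensed form: upper triangle without the diagonal
pdist = dists[np.triu_indices(dists.shape[0], k=1)]
Z = complete(pdist)
# units 0/1 and 2/3 merge below the 0.25 threshold; cross-group distances are too large
new_labels = fcluster(Z, 0.25, criterion="distance")
print(new_labels)  # e.g. [1, 1, 2, 2] (flat cluster ids start at 1)
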
36 changes: 10 additions & 26 deletions src/dartsort/cluster/split.py
@@ -74,19 +74,15 @@ def split_clusters(
            new_labels = split_result.new_labels
            triaged = split_result.new_labels < 0
            labels[in_unit[triaged]] = new_labels[triaged]
            labels[in_unit[new_labels > 0]] = (
                cur_max_label + new_labels[new_labels > 0]
            )
            labels[in_unit[new_labels > 0]] = cur_max_label + new_labels[new_labels > 0]
            new_untriaged_labels = labels[in_unit[new_labels >= 0]]
            cur_max_label = new_untriaged_labels.max()

            # submit recursive jobs to the pool, if any
            if recursive:
                new_units = np.unique(new_untriaged_labels)
                for i in new_units:
                    jobs.append(
                        pool.submit(_split_job, np.flatnonzero(labels == i))
                    )
                    jobs.append(pool.submit(_split_job, np.flatnonzero(labels == i)))
                if show_progress:
                    iterator.total += len(new_units)

@@ -151,7 +147,7 @@ def __init__(
        min_cluster_size=25,
        min_samples=25,
        cluster_selection_epsilon=25,
        reassign_outliers=True,
        reassign_outliers=False,
        random_state=0,
        **dataset_name_kwargs,
    ):
@@ -241,18 +237,14 @@ def split_cluster(self, in_unit):
        is_split = np.setdiff1d(np.unique(hdb_labels), [-1]).size > 1

        if is_split and self.reassign_outliers:
            hdb_labels = cluster_util.knn_reassign_outliers(
                hdb_labels, features
            )
            hdb_labels = cluster_util.knn_reassign_outliers(hdb_labels, features)

        new_labels = None
        if is_split:
            new_labels = np.full(n_spikes, -1)
            new_labels[kept] = hdb_labels

        return SplitResult(
            is_split=is_split, in_unit=in_unit, new_labels=new_labels
        )
        return SplitResult(is_split=is_split, in_unit=in_unit, new_labels=new_labels)

    def pca_features(self, in_unit):
        """Compute relocated PCA features on a drift-invariant channel set"""
@@ -316,12 +308,8 @@ def pca_features(self, in_unit):
            return False, no_nan, None

        # fit pca and embed
        pca = PCA(
            self.n_pca_features, random_state=self.random_state, whiten=True
        )
        pca_projs = np.full(
            (n, self.n_pca_features), np.nan, dtype=waveforms.dtype
        )
        pca = PCA(self.n_pca_features, random_state=self.random_state, whiten=True)
        pca_projs = np.full((n, self.n_pca_features), np.nan, dtype=waveforms.dtype)
        pca_projs[no_nan] = pca.fit_transform(waveforms[no_nan])

        return True, no_nan, pca_projs
@@ -335,7 +323,7 @@ def initialize_from_h5(
        amplitudes_dataset_name="denoised_amplitudes",
        amplitude_vectors_dataset_name="denoised_amplitude_vectors",
    ):
        h5 = h5py.File(peeling_hdf5_filename, "r")
        h5 = h5py.File(peeling_hdf5_filename, "r", locking=False)
        self.geom = h5["geom"][:]
        self.channel_index = h5["channel_index"][:]

@@ -386,9 +374,7 @@ def initialize_from_h5(

# this is to help split_clusters take a string argument
all_split_strategies = [FeatureSplit]
split_strategies_by_class_name = {
    cls.__name__: cls for cls in all_split_strategies
}
split_strategies_by_class_name = {cls.__name__: cls for cls in all_split_strategies}

# -- parallelism widgets

@@ -404,9 +390,7 @@ def __init__(self, split_strategy):
def _split_job_init(split_strategy_class_name, split_strategy_kwargs):
    global _split_job_context
    split_strategy = split_strategies_by_class_name[split_strategy_class_name]
    _split_job_context = SplitJobContext(
        split_strategy(**split_strategy_kwargs)
    )
    _split_job_context = SplitJobContext(split_strategy(**split_strategy_kwargs))


def _split_job(in_unit):
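
About the locking=False change above: it lets several split worker processes read the same peeling HDF5 file without contending for the HDF5 file lock. A minimal sketch of that read pattern, assuming h5py >= 3.5 (where the locking argument exists) and a hypothetical peeling.h5 path; this is not code from the commit.

import h5py

def read_amplitudes(path="peeling.h5", dataset="denoised_amplitudes"):
    # read-only and unlocked, so parallel split workers can read concurrently
    with h5py.File(path, "r", locking=False) as h5:
        return h5[dataset][:]
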
9 changes: 6 additions & 3 deletions src/dartsort/config.py
@@ -76,6 +76,7 @@ class FeaturizationConfig:
    localization_radius: float = 100.0
    # these are saved always if do_localization
    save_amplitude_vectors: bool = True
    localization_model = "dipole"

    # -- further info about denoising
    # in the future we may add multi-channel or other nns
@@ -96,7 +97,7 @@ class FeaturizationConfig:
class SubtractionConfig:
    trough_offset_samples: int = 42
    spike_length_samples: int = 121
    detection_thresholds: List[int] = (12, 10, 8, 6, 5, 4)
    detection_thresholds: List[int] = (12, 9, 6)
    chunk_length_samples: int = 30_000
    peak_sign: str = "neg"
    spatial_dedup_radius: float = 150.0
@@ -153,14 +154,16 @@ class MatchingConfig:
    fit_subsampling_random_state: int = 0

    # template matching parameters
    threshold: float = 30.0
    threshold: float = 50.0
    template_svd_compression_rank: int = 10
    template_temporal_upsampling_factor: int = 8
    template_min_channel_amplitude: float = 1.0
    refractory_radius_frames: int = 10
    amplitude_scaling_variance: float = 0.0
    amplitude_scaling_boundary: float = 0.5
    max_iter: int = 1000
    conv_ignore_threshold: float = 5.0
    coarse_approx_error_threshold: float = 5.0

@dataclass(frozen=True)
class ClusteringConfig:
@@ -174,4 +177,4 @@ class ClusteringConfig:
    # -- ensembling
    ensemble_strategy: Optional[str] = "forward_backward"
    chunk_size_s: int = 300
    # forward-backward
    # forward-backward
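
The config.py changes above only move defaults (the detection thresholds, the matching threshold, and the two new matching tolerances). As a rough sketch of setting those values explicitly, assuming dartsort.config exposes these frozen dataclasses as shown and that the remaining fields keep their defaults (this is not code from the commit):

from dataclasses import replace

from dartsort.config import MatchingConfig, SubtractionConfig

# echo the new defaults explicitly
subtraction_cfg = SubtractionConfig(detection_thresholds=(12, 9, 6))
matching_cfg = MatchingConfig(
    threshold=50.0,
    conv_ignore_threshold=5.0,
    coarse_approx_error_threshold=5.0,
)

# frozen dataclasses are immutable, so derive variants with dataclasses.replace
looser_matching = replace(matching_cfg, threshold=30.0)
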