Skip to content

Commit 26141fa

Browse files
committed
Fix motion awareness in split; debug split/merge ensemble and plots
1 parent 0ec5612 commit 26141fa

File tree

9 files changed

+55
-15
lines changed

9 files changed

+55
-15
lines changed

src/dartsort/cluster/density.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -72,6 +72,7 @@ def nearest_higher_density_neighbor(
7272
distances, indices = distances[:, 1:].copy(), indices[:, 1:].copy()
7373

7474
# find lowest distance higher density neighbor
75+
print(f"nhdn {density.shape=}")
7576
density_padded = np.pad(density, (0, 1), constant_values=np.inf)
7677
is_lower_density = density_padded[indices] <= density[:, None]
7778
distances[is_lower_density] = np.inf
@@ -151,6 +152,8 @@ def density_peaks_clustering(
151152
density = get_smoothed_densities(X, inliers=inliers, sigmas=sigmas)
152153
if sigma_regional is not None:
153154
density = density[0] / density[1]
155+
else:
156+
density = density[0]
154157

155158
nhdn, distances, indices = nearest_higher_density_neighbor(
156159
kdtree,
@@ -182,6 +185,8 @@ def density_peaks_clustering(
182185
if remove_clusters_smaller_than:
183186
labels = decrumb(labels, min_size=remove_clusters_smaller_than)
184187

188+
print("dpc", np.unique(labels).size)
189+
185190
if not return_extra:
186191
return labels
187192

src/dartsort/cluster/ensemble_utils.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@ def split_merge_ensemble(
3232
split_strategy=split_merge_config.split_strategy,
3333
recursive=split_merge_config.recursive_split,
3434
n_jobs=n_jobs_split,
35+
motion_est=motion_est,
3536
show_progress=False
3637
)
3738
for sorting in tqdm(chunk_sortings, desc="Split within chunks")

src/dartsort/cluster/initial.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -226,10 +226,10 @@ def ensemble_chunks(
226226
recording,
227227
chunk_sortings,
228228
motion_est=motion_est,
229-
split_merge_config=clustering_config.split_merge_config,
229+
split_merge_config=clustering_config.split_merge_ensemble_config,
230230
n_jobs_split=computation_config.n_jobs_cpu,
231231
n_jobs_merge=computation_config.actual_n_jobs_gpu,
232-
device=None,
232+
device=computation_config.actual_device,
233233
show_progress=True,
234234
)
235235

src/dartsort/cluster/merge.py

Lines changed: 16 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -60,8 +60,9 @@ def merge_templates(
6060
-------
6161
A new DARTsortSorting
6262
"""
63+
print("merge input", np.unique(sorting.labels).size - 1)
6364
if template_data is None:
64-
template_data = TemplateData.from_config(
65+
template_data, sorting = TemplateData.from_config(
6566
recording,
6667
sorting,
6768
template_config,
@@ -71,6 +72,7 @@ def merge_templates(
7172
overwrite=overwrite_templates,
7273
device=device,
7374
save_npz_name=template_npz_filename,
75+
return_realigned_sorting=True,
7476
)
7577

7678
units, dists, shifts, template_snrs = calculate_merge_distances(
@@ -246,6 +248,10 @@ def calculate_merge_distances(
246248
show_progress=show_progress,
247249
)
248250
for res in dec_res_iter:
251+
if res is None:
252+
# all pairs in chunk were ignored for one reason or another
253+
continue
254+
249255
tixa = res.template_indices_a
250256
tixb = res.template_indices_b
251257
rms_ratio = res.deconv_resid_norms / res.template_a_norms
@@ -359,7 +365,12 @@ def recluster(
359365
pdist = dists[np.triu_indices(dists.shape[0], k=1)]
360366
# scipy hierarchical clustering only supports finite values, so let's just
361367
# drop in a huge value here
362-
pdist[~np.isfinite(pdist)] = 1_000_000 + pdist[np.isfinite(pdist)].max()
368+
finite = np.isfinite(pdist)
369+
if not finite.any():
370+
print("no merges")
371+
return sorting
372+
373+
pdist[~finite] = 1_000_000 + pdist[finite].max()
363374
# complete linkage: max dist between all pairs across clusters.
364375
Z = complete(pdist)
365376
# extract flat clustering using our max dist threshold
@@ -378,6 +389,7 @@ def recluster(
378389
clust_inverse = {i: [] for i in new_labels}
379390
for orig_label, new_label in enumerate(new_labels):
380391
clust_inverse[new_label].append(orig_label)
392+
print(sum(len(v) - 1 for v in clust_inverse.values()), "merges")
381393

382394
# align to best snr unit
383395
for new_label, orig_labels in clust_inverse.items():
@@ -409,8 +421,8 @@ def cross_match(
409421
units_b,
410422
merge_distance_threshold=0.5,
411423
):
412-
assert np.array_equal(units_a, sorting_a.units)
413-
assert np.array_equal(units_b, sorting_b.units)
424+
assert np.array_equal(units_a, sorting_a.unit_ids)
425+
assert np.array_equal(units_b, sorting_b.unit_ids)
414426

415427
ia, ib = np.nonzero(dists <= merge_distance_threshold)
416428
weights = coo_array((-dists[ia, ib], (ia.astype(np.intc), ib.astype(np.intc))))

src/dartsort/cluster/split.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from dataclasses import dataclass, replace
22
from typing import Optional
3+
from pathlib import Path
34

45
import h5py
56
import numpy as np
@@ -24,6 +25,7 @@ def split_clusters(
2425
recursive=False,
2526
split_big=False,
2627
split_big_kw=dict(dz=40, dx=48, min_size_split=50),
28+
motion_est=None,
2729
show_progress=True,
2830
n_jobs=0,
2931
):
@@ -50,6 +52,11 @@ def split_clusters(
5052
labels_to_process = list(labels_to_process[labels_to_process > 0])
5153
cur_max_label = max(labels_to_process)
5254

55+
if split_strategy_kwargs is None:
56+
split_strategy_kwargs = {}
57+
if motion_est is not None:
58+
split_strategy_kwargs["motion_est"] = motion_est
59+
5360
n_jobs, Executor, context = get_pool(n_jobs)
5461
with Executor(
5562
max_workers=n_jobs,
@@ -667,6 +674,7 @@ def initialize_from_h5(
667674
amplitudes_dataset_name="denoised_ptp_amplitudes",
668675
amplitude_vectors_dataset_name="denoised_ptp_amplitude_vectors",
669676
):
677+
peeling_hdf5_filename = Path(peeling_hdf5_filename)
670678
h5 = h5py.File(peeling_hdf5_filename, "r")
671679
self.geom = h5["geom"][:]
672680
self.channel_index = h5["channel_index"][:]
@@ -704,6 +712,11 @@ def initialize_from_h5(
704712
self.tpca_features = h5[tpca_features_dataset_name]
705713
self.match_distance = pdist(self.geom).min() / 2
706714

715+
if peeling_featurization_pt is None:
716+
mdir = peeling_hdf5_filename.parent / f"{peeling_hdf5_filename.stem}_models"
717+
peeling_featurization_pt = mdir / "featurization_pipeline.pt"
718+
assert peeling_featurization_pt.exists()
719+
707720
if self.n_pca_features and self.relocated:
708721
# load up featurization pipeline for tpca inversion
709722
assert peeling_featurization_pt is not None

src/dartsort/config.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -206,7 +206,7 @@ class SplitMergeConfig:
206206
merge_template_config: TemplateConfig = TemplateConfig(superres_templates=False)
207207
merge_distance_threshold: float = 0.25
208208
cross_merge_distance_threshold: float = 0.5
209-
min_spatial_cosine: float = 0.5
209+
min_spatial_cosine: float = 0.0
210210

211211

212212
@dataclass(frozen=True)
@@ -235,15 +235,17 @@ class ClusteringConfig:
235235
# density peaks parameters
236236
sigma_local: float = 5.0
237237
sigma_regional: Optional[float] = None
238-
n_neighbors_search: int = 10
238+
n_neighbors_search: int = 20
239239
radius_search: float = 5.0
240240
remove_clusters_smaller_than: int = 10
241241
noise_density: float = 0.0
242242

243243
# -- ensembling
244244
ensemble_strategy: Optional[str] = "forward_backward"
245245
chunk_size_s: float = 300.0
246-
split_merge_ensemble_config: SplitMergeConfig = SplitMergeConfig()
246+
split_merge_ensemble_config: SplitMergeConfig = SplitMergeConfig(
247+
merge_template_config=TemplateConfig(superres_templates=False, realign_peaks=False,)
248+
)
247249

248250

249251
@dataclass(frozen=True)
@@ -280,7 +282,7 @@ def __post_init__(self):
280282

281283
if self.actual_device.type == "cuda":
282284
self.actual_n_jobs_gpu = self.n_jobs_gpu
283-
if self.actual_device.type == "cuda":
285+
else:
284286
self.actual_n_jobs_gpu = self.n_jobs_cpu
285287

286288

src/dartsort/main.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -196,6 +196,7 @@ def split_merge(
196196
split_strategy=split_merge_config.split_strategy,
197197
recursive=split_merge_config.recursive_split,
198198
n_jobs=n_jobs_split,
199+
motion_est=motion_est,
199200
)
200201
if output_directory is not None:
201202
split_sorting.save(split_npz)
@@ -206,6 +207,7 @@ def split_merge(
206207
merge_sorting = merge_templates(
207208
split_sorting,
208209
recording,
210+
motion_est=motion_est,
209211
template_config=split_merge_config.merge_template_config,
210212
merge_distance_threshold=split_merge_config.merge_distance_threshold,
211213
min_spatial_cosine=split_merge_config.min_spatial_cosine,

src/dartsort/templates/templates.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -108,6 +108,7 @@ def from_config(
108108
device=None,
109109
trough_offset_samples=42,
110110
spike_length_samples=121,
111+
return_realigned_sorting=False,
111112
):
112113
if save_folder is not None:
113114
save_folder = Path(save_folder)
@@ -181,7 +182,7 @@ def from_config(
181182

182183
# handle superresolved templates
183184
if template_config.superres_templates:
184-
unit_ids, sorting = superres_sorting(
185+
unit_ids, superres_sorting = superres_sorting(
185186
sorting,
186187
sorting.times_seconds,
187188
spike_depths_um,
@@ -194,6 +195,7 @@ def from_config(
194195
adaptive_bin_size=template_config.adaptive_bin_size,
195196
)
196197
else:
198+
superres_sorting = sorting
197199
# we don't skip empty units
198200
unit_ids = np.arange(sorting.labels.max() + 1)
199201

@@ -203,7 +205,7 @@ def from_config(
203205
spike_counts[ix[ix >= 0]] = counts[ix >= 0]
204206

205207
# main!
206-
results = get_templates(recording, sorting, **kwargs)
208+
results = get_templates(recording, superres_sorting, **kwargs)
207209

208210
# handle registered templates
209211
if template_config.registered_templates and motion_est is not None:
@@ -233,4 +235,7 @@ def from_config(
233235
if save_folder is not None:
234236
obj.to_npz(npz_path)
235237

238+
if return_realigned_sorting:
239+
return obj, sorting
240+
236241
return obj

src/dartsort/vis/scatterplots.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -239,12 +239,12 @@ def scatter_time_vs_depth(
239239
if depths_um is None:
240240
depths_um = getattr(sorting, "point_source_localizations", None)
241241
if depths_um is not None:
242-
depths_um = x[:, 2]
242+
depths_um = depths_um[:, 2]
243243
if amplitudes is None:
244244
amplitudes = getattr(sorting, amplitudes_dataset_name, None)
245245
if hdf5_filename is None:
246246
hdf5_filename = sorting.parent_h5_path
247-
247+
248248
needs_load = any(v is None for v in (times_s, depths_um, amplitudes, geom))
249249
if needs_load and hdf5_filename is not None:
250250
with h5py.File(hdf5_filename, "r") as h5:
@@ -473,6 +473,7 @@ def scatter_feature_vs_depth(
473473

474474
if sorting is not None:
475475
labels = sorting.labels
476+
476477
if labels is None:
477478
c = np.clip(amplitudes, 0, amplitude_color_cutoff)
478479
cmap = amplitude_cmap
@@ -490,7 +491,6 @@ def scatter_feature_vs_depth(
490491
rasterized=rasterized,
491492
**scatter_kw,
492493
)
493-
to_show = to_show[kept]
494494
else:
495495
c = labels
496496
cmap = colorcet.m_glasbey_light

0 commit comments

Comments
 (0)