Skip to content

Commit 321e24a

Browse files
committed
Nearby unit vis and high level functions
1 parent 0a1ee31 commit 321e24a

File tree

7 files changed

+299
-25
lines changed

7 files changed

+299
-25
lines changed

src/dartsort/cluster/merge.py

Lines changed: 48 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ def merge_templates(
1818
motion_est=None,
1919
max_shift_samples=20,
2020
superres_linkage=np.max,
21+
sym_function=np.minimum,
2122
merge_distance_threshold=0.25,
2223
temporal_upsampling_factor=8,
2324
amplitude_scaling_variance=0.0,
@@ -68,6 +69,50 @@ def merge_templates(
6869
save_npz_name=template_npz_filename,
6970
)
7071

72+
units, dists, shifts, template_snrs = calculate_merge_distances(
73+
template_data,
74+
superres_linkage=superres_linkage,
75+
sym_function=sym_function,
76+
max_shift_samples=max_shift_samples,
77+
temporal_upsampling_factor=temporal_upsampling_factor,
78+
amplitude_scaling_variance=amplitude_scaling_variance,
79+
amplitude_scaling_boundary=amplitude_scaling_boundary,
80+
svd_compression_rank=svd_compression_rank,
81+
min_channel_amplitude=min_channel_amplitude,
82+
conv_batch_size=conv_batch_size,
83+
units_batch_size=units_batch_size,
84+
device=device,
85+
n_jobs=n_jobs,
86+
show_progress=show_progress,
87+
)
88+
89+
# now run hierarchical clustering
90+
return recluster(
91+
sorting,
92+
units,
93+
dists,
94+
shifts,
95+
template_snrs,
96+
merge_distance_threshold=merge_distance_threshold,
97+
)
98+
99+
100+
def calculate_merge_distances(
101+
template_data,
102+
superres_linkage=np.max,
103+
sym_function=np.minimum,
104+
max_shift_samples=20,
105+
temporal_upsampling_factor=8,
106+
amplitude_scaling_variance=0.0,
107+
amplitude_scaling_boundary=0.5,
108+
svd_compression_rank=10,
109+
min_channel_amplitude=0.0,
110+
conv_batch_size=128,
111+
units_batch_size=8,
112+
device=None,
113+
n_jobs=0,
114+
show_progress=True,
115+
):
71116
# allocate distance + shift matrices. shifts[i,j] is trough[j]-trough[i].
72117
n_templates = template_data.templates.shape[0]
73118
sup_dists = np.full((n_templates, n_templates), np.inf)
@@ -116,15 +161,9 @@ def merge_templates(
116161
template_data.templates.ptp(1).max(1) / template_data.spike_counts
117162
)
118163

119-
# now run hierarchical clustering
120-
return recluster(
121-
sorting,
122-
units,
123-
dists,
124-
shifts,
125-
template_snrs,
126-
merge_distance_threshold=merge_distance_threshold,
127-
)
164+
dists = sym_function(dists, dists.T)
165+
166+
return units, dists, shifts, template_snrs
128167

129168

130169
def recluster(
@@ -134,9 +173,7 @@ def recluster(
134173
shifts,
135174
template_snrs,
136175
merge_distance_threshold=0.25,
137-
sym_function=np.minimum,
138176
):
139-
dists = sym_function(dists, dists.T)
140177

141178
# upper triangle not including diagonal, aka condensed distance matrix in scipy
142179
pdist = dists[np.triu_indices(dists.shape[0], k=1)]

src/dartsort/config.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -242,3 +242,4 @@ class SplitMergeConfig:
242242
default_split_merge_config = SplitMergeConfig()
243243
coarse_template_config = TemplateConfig(superres_templates=False)
244244
default_matching_config = MatchingConfig()
245+
default_motion_estimation_config = MotionEstimationConfig()

src/dartsort/main.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
from dataclasses import asdict
12
from pathlib import Path
23

34
from dartsort.cluster.initial import ensemble_chunks
@@ -6,6 +7,7 @@
67
from dartsort.config import (default_clustering_config,
78
default_featurization_config,
89
default_matching_config,
10+
default_motion_estimation_config,
911
default_split_merge_config,
1012
default_subtraction_config,
1113
default_template_config)
@@ -14,6 +16,7 @@
1416
from dartsort.templates import TemplateData
1517
from dartsort.util.data_util import check_recording
1618
from dartsort.util.peel_util import run_peeler
19+
from dartsort.util.registration_util import estimate_motion
1720

1821

1922
def dartsort_from_config(
@@ -27,6 +30,7 @@ def dartsort(
2730
recording,
2831
output_directory,
2932
featurization_config=default_featurization_config,
33+
motion_estimation_config=default_motion_estimation_config,
3034
subtraction_config=default_subtraction_config,
3135
matching_config=default_subtraction_config,
3236
template_config=default_template_config,
@@ -50,8 +54,9 @@ def dartsort(
5054
device=device,
5155
)
5256
if motion_est is None:
53-
# TODO
54-
motion_est = estimate_motion()
57+
motion_est = estimate_motion(
58+
recording, sorting, **asdict(motion_estimation_config)
59+
)
5560
sorting = cluster(
5661
sub_h5,
5762
recording,

src/dartsort/util/analysis.py

Lines changed: 85 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
"""
1010
from dataclasses import dataclass, replace
1111
from pathlib import Path
12-
from typing import Optional
12+
from typing import Callable, Optional
1313

1414
import h5py
1515
import numpy as np
@@ -19,7 +19,7 @@
1919
from sklearn.decomposition import PCA
2020
from spikeinterface.comparison import GroundTruthComparison
2121

22-
from ..cluster import relocate
22+
from ..cluster import merge, relocate
2323
from ..templates import TemplateData
2424
from ..transform import WaveformPipeline
2525
from .data_util import DARTsortSorting
@@ -56,6 +56,11 @@ class DARTsortAnalysis:
5656
tpca_features_dataset = "collisioncleaned_tpca_features"
5757
template_indices_dataset = "collisioncleaned_tpca_features"
5858

59+
# configuration for analysis computations not included in above objects
60+
device: Optional[str, torch.device] = None
61+
merge_distance_templates_kind: str = "coarse"
62+
merge_superres_linkage: Callable[[np.ndarray], float] = np.max
63+
5964
# helper constructors
6065

6166
@classmethod
@@ -110,7 +115,9 @@ def from_peeling_paths(
110115
def __post_init__(self):
111116
if self.featurization_pipeline is not None:
112117
assert not self.featurization_pipeline.needs_fit()
113-
assert np.isin(self.template_data.unit_ids, np.unique(self.sorting.labels)).all()
118+
assert np.isin(
119+
self.template_data.unit_ids, np.unique(self.sorting.labels)
120+
).all()
114121

115122
assert self.hdf5_path.exists()
116123
self.coarse_template_data = self.template_data.coarsen()
@@ -127,6 +134,7 @@ def __post_init__(self):
127134
self.motion_est is not None
128135
and self.template_data.registered_geom is not None
129136
)
137+
assert self.coarse_template_data.unit_ids == self.unit_ids
130138

131139
# cached hdf5 pointer
132140
self._h5 = None
@@ -145,6 +153,7 @@ def clear_cache(self):
145153
self._sklearn_tpca = None
146154
self._unit_ids = None
147155
self._spike_counts = None
156+
self._merge_dist = None
148157
self._feats = {}
149158

150159
def __getstate__(self):
@@ -209,6 +218,12 @@ def sklearn_tpca(self):
209218
self._sklearn_tpca = tpca_feature[0].to_sklearn()
210219
return self._sklearn_tpca
211220

221+
@property
def merge_dist(self):
    """Merge distance matrix between units, computed lazily and cached.

    The matrix is produced by ``self._calc_merge_dist()`` the first time
    this property is read and then reused until ``self._merge_dist`` is
    reset to None.
    """
    if self._merge_dist is None:
        dists = self._calc_merge_dist()
        # _calc_merge_dist may store its result on self and return None;
        # only overwrite the cache when a matrix was actually returned, so
        # a None return value can never clobber the freshly stored result.
        if dists is not None:
            self._merge_dist = dists
    return self._merge_dist
226+
212227
# spike train helpers
213228

214229
@property
@@ -236,6 +251,16 @@ def in_template(self, template_index):
236251
def unit_template_indices(self, unit_id):
    """Return the indices of the templates that belong to ``unit_id``.

    The result is the positions in ``self.template_data.unit_ids`` whose
    value equals ``unit_id``.
    """
    # Compare against the argument; the instance has no `unit_id` attribute
    # to compare with.
    return np.flatnonzero(self.template_data.unit_ids == unit_id)
238253

254+
@property
def show_geom(self):
    """Geometry used for display.

    Returns ``self.template_data.registered_geom`` when it is set, and
    otherwise falls back to ``self.recording.get_channel_locations()``.
    """
    registered = self.template_data.registered_geom
    if registered is not None:
        return registered
    return self.recording.get_channel_locations()
260+
261+
def show_channel_index(self, radius_um=50):
    """Build a channel index over ``self.show_geom``.

    ``radius_um`` is forwarded unchanged to ``make_channel_index`` together
    with the display geometry.
    """
    geom = self.show_geom
    return make_channel_index(geom, radius_um)
263+
239264
# spike feature loading methods
240265

241266
def named_feature(self, name, which=slice(None)):
@@ -330,7 +355,12 @@ def unit_raw_waveforms(
330355
if not self.shifting:
331356
return which, waveforms
332357

333-
waveforms, max_chan, show_geom, show_channel_index = self.unit_shift_or_relocate_channels(
358+
(
359+
waveforms,
360+
max_chan,
361+
show_geom,
362+
show_channel_index,
363+
) = self.unit_shift_or_relocate_channels(
334364
unit_id,
335365
which,
336366
waveforms,
@@ -367,7 +397,12 @@ def unit_tpca_waveforms(
367397
t = waveforms.shape[1]
368398
waveforms = waveforms.reshape(n, c, t).transpose(0, 2, 1)
369399

370-
waveforms, max_chan, show_geom, show_channel_index = self.unit_shift_or_relocate_channels(
400+
(
401+
waveforms,
402+
max_chan,
403+
show_geom,
404+
show_channel_index,
405+
) = self.unit_shift_or_relocate_channels(
371406
unit_id,
372407
which,
373408
waveforms,
@@ -378,9 +413,21 @@ def unit_tpca_waveforms(
378413
return which, waveforms, max_chan, show_geom, show_channel_index
379414

380415
def unit_pca_features(
381-
self, unit_id, relocated=True, rank=2, pca_radius_um=75, random_seed=0, max_count=500
416+
self,
417+
unit_id,
418+
relocated=True,
419+
rank=2,
420+
pca_radius_um=75,
421+
random_seed=0,
422+
max_count=500,
382423
):
383-
which, waveforms, max_chan, show_geom, show_channel_index = self.unit_tpca_waveforms(
424+
(
425+
which,
426+
waveforms,
427+
max_chan,
428+
show_geom,
429+
show_channel_index,
430+
) = self.unit_tpca_waveforms(
384431
unit_id,
385432
relocated=relocated,
386433
show_radius_um=pca_radius_um,
@@ -439,7 +486,9 @@ def unit_shift_or_relocate_channels(
439486
show_channel_index = make_channel_index(show_geom, show_radius_um)
440487
show_chans = show_channel_index[max_chan]
441488
show_chans = show_chans[show_chans < len(show_geom)]
442-
show_channel_index = np.broadcast_to(show_chans[None], (len(show_geom), show_chans.size))
489+
show_channel_index = np.broadcast_to(
490+
show_chans[None], (len(show_geom), show_chans.size)
491+
)
443492

444493
if not self.shifting:
445494
return waveforms, max_chan, show_geom, show_channel_index
@@ -478,6 +527,34 @@ def unit_shift_or_relocate_channels(
478527

479528
return waveforms, max_chan, show_geom, show_channel_index
480529

530+
def nearby_coarse_templates(self, unit_id, n_neighbors=5):
    """Return the units closest to ``unit_id`` in merge distance.

    Parameters
    ----------
    unit_id
        Id of the query unit; looked up in ``self.unit_ids`` with
        ``np.searchsorted``, so ``self.unit_ids`` is assumed to be sorted.
    n_neighbors : int
        Number of units to return, including the query unit itself.

    Returns
    -------
    neighbor_ids
        Ids of the selected units, nearest first (the query unit is first).
    neighbor_dists
        Square sub-matrix of ``self.merge_dist`` restricted to those units.
    neighbor_coarse_templates
        Rows of ``self.coarse_template_data.templates`` for those units.
    """
    unit_ix = np.searchsorted(self.unit_ids, unit_id)
    unit_dists = self.merge_dist[unit_ix]
    distance_order = np.argsort(unit_dists)
    # sanity check: a unit should be its own nearest neighbor
    assert distance_order[0] == unit_ix
    neighbor_ixs = distance_order[:n_neighbors]
    # index by the neighbor positions, not by the first n_neighbors units,
    # so ids line up with the distances and templates returned below
    neighbor_ids = self.unit_ids[neighbor_ixs]
    neighbor_dists = self.merge_dist[neighbor_ixs[:, None], neighbor_ixs[None, :]]
    neighbor_coarse_templates = self.coarse_template_data.templates[neighbor_ixs]
    return neighbor_ids, neighbor_dists, neighbor_coarse_templates
540+
541+
# computation
542+
543+
def _calc_merge_dist(self):
    """Compute, cache, and return the merge distance matrix.

    Uses the coarse templates when ``self.merge_distance_templates_kind``
    is ``"coarse"`` and ``self.template_data`` otherwise. The distances come
    from ``merge.calculate_merge_distances``; the result is stored in
    ``self._merge_dist`` and also returned so callers can assign it.
    """
    if self.merge_distance_templates_kind == "coarse":
        merge_td = self.coarse_template_data
    else:
        merge_td = self.template_data

    # shifts and template SNRs are not needed here
    units, dists, _, _ = merge.calculate_merge_distances(
        merge_td,
        superres_linkage=self.merge_superres_linkage,
        device=self.device,
        n_jobs=1,
    )
    # rows/columns of dists must line up with self.unit_ids
    assert np.array_equal(units, self.unit_ids)
    self._merge_dist = dists
    return dists
557+
481558

482559
@dataclass
483560
class DARTsortGroundTruthComparison:
Lines changed: 53 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,65 @@
1+
from typing import Optional
2+
3+
import numpy as np
4+
15
try:
26
from dredge import dredge_ap
7+
38
have_dredge = True
49
except ImportError:
510
have_dredge = False
611
pass
712

813

9-
def estimate_motion(recording, sorting, motion_estimation_config=None, localizations_dataset_name="point_source_localizations"):
10-
if not motion_estimation_config.do_motion_estimation:
14+
def estimate_motion(
    recording,
    sorting,
    do_motion_estimation=True,
    probe_boundary_padding_um=100.0,
    spatial_bin_length_um: float = 1.0,
    temporal_bin_length_s: float = 1.0,
    window_step_um: float = 400.0,
    window_scale_um: float = 450.0,
    window_margin_um: Optional[float] = None,
    max_dt_s: float = 0.1,
    max_disp_um: Optional[float] = None,
    localizations_dataset_name="point_source_localizations",
    amplitudes_dataset_name="denoised_ptp_amplitudes",
):
    """Estimate motion from spike localizations with ``dredge_ap.register``.

    Spikes whose x or z localization lies more than
    ``probe_boundary_padding_um`` outside the range of the recording's
    channel locations are dropped; the amplitudes, z positions and times of
    the remaining spikes are passed to ``dredge_ap.register``.

    Returns None when ``do_motion_estimation`` is false, otherwise the
    motion estimate returned by ``dredge_ap.register`` (its second return
    value is discarded).

    Raises ValueError when the ``dredge`` import failed.
    """
    if not do_motion_estimation:
        return None

    if not have_dredge:
        raise ValueError("Please install DREDge to use motion estimation.")

    # NOTE(review): z is read from column 1 of the localizations dataset;
    # confirm this matches the dataset's column order.
    localizations = getattr(sorting, localizations_dataset_name)
    x = localizations[:, 0]
    z = localizations[:, 1]

    # padded bounding box of the probe
    geom = recording.get_channel_locations()
    pad = probe_boundary_padding_um
    x_lo = geom[:, 0].min() - pad
    x_hi = geom[:, 0].max() + pad
    z_lo = geom[:, 1].min() - pad
    z_hi = geom[:, 1].max() + pad

    # a value is inside the box exactly when clipping leaves it unchanged
    inside_x = x == np.clip(x, x_lo, x_hi)
    inside_z = z == np.clip(z, z_lo, z_hi)
    keep = np.flatnonzero(inside_x & inside_z)

    # features for registration
    depths = z[keep]
    times_s = sorting.times_seconds[keep]
    amplitudes = getattr(sorting, amplitudes_dataset_name)[keep]

    # run registration
    motion_est, _ = dredge_ap.register(
        amplitudes,
        depths,
        times_s,
        window_step_um=window_step_um,
        bin_um=spatial_bin_length_um,
        bin_s=temporal_bin_length_s,
        window_scale_um=window_scale_um,
        window_margin_um=window_margin_um,
        max_disp_um=max_disp_um,
        max_dt_s=max_dt_s,
    )

    return motion_est

src/dartsort/vis/analysis_plots.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
def distance_matrix_dendro():
    """Placeholder with no implementation yet; always returns None."""
    return None

0 commit comments

Comments
 (0)