9
9
"""
10
10
from dataclasses import dataclass , replace
11
11
from pathlib import Path
12
- from typing import Optional
12
+ from typing import Callable , Optional
13
13
14
14
import h5py
15
15
import numpy as np
19
19
from sklearn .decomposition import PCA
20
20
from spikeinterface .comparison import GroundTruthComparison
21
21
22
- from ..cluster import relocate
22
+ from ..cluster import merge , relocate
23
23
from ..templates import TemplateData
24
24
from ..transform import WaveformPipeline
25
25
from .data_util import DARTsortSorting
@@ -56,6 +56,11 @@ class DARTsortAnalysis:
56
56
tpca_features_dataset = "collisioncleaned_tpca_features"
57
57
template_indices_dataset = "collisioncleaned_tpca_features"
58
58
59
+ # configuration for analysis computations not included in above objects
60
+ device : Optional [str , torch .device ] = None
61
+ merge_distance_templates_kind : str = "coarse"
62
+ merge_superres_linkage : Callable [[np .ndarray ], float ] = np .max
63
+
59
64
# helper constructors
60
65
61
66
@classmethod
@@ -110,7 +115,9 @@ def from_peeling_paths(
110
115
def __post_init__ (self ):
111
116
if self .featurization_pipeline is not None :
112
117
assert not self .featurization_pipeline .needs_fit ()
113
- assert np .isin (self .template_data .unit_ids , np .unique (self .sorting .labels )).all ()
118
+ assert np .isin (
119
+ self .template_data .unit_ids , np .unique (self .sorting .labels )
120
+ ).all ()
114
121
115
122
assert self .hdf5_path .exists ()
116
123
self .coarse_template_data = self .template_data .coarsen ()
@@ -127,6 +134,7 @@ def __post_init__(self):
127
134
self .motion_est is not None
128
135
and self .template_data .registered_geom is not None
129
136
)
137
+ assert self .coarse_template_data .unit_ids == self .unit_ids
130
138
131
139
# cached hdf5 pointer
132
140
self ._h5 = None
@@ -145,6 +153,7 @@ def clear_cache(self):
145
153
self ._sklearn_tpca = None
146
154
self ._unit_ids = None
147
155
self ._spike_counts = None
156
+ self ._merge_dist = None
148
157
self ._feats = {}
149
158
150
159
def __getstate__ (self ):
@@ -209,6 +218,12 @@ def sklearn_tpca(self):
209
218
self ._sklearn_tpca = tpca_feature [0 ].to_sklearn ()
210
219
return self ._sklearn_tpca
211
220
221
+ @property
222
+ def merge_dist (self ):
223
+ if self ._merge_dist is None :
224
+ self ._merge_dist = self ._calc_merge_dist ()
225
+ return self ._merge_dist
226
+
212
227
# spike train helpers
213
228
214
229
@property
@@ -236,6 +251,16 @@ def in_template(self, template_index):
236
251
def unit_template_indices (self , unit_id ):
237
252
return np .flatnonzero (self .template_data .unit_ids == self .unit_id )
238
253
254
+ @property
255
+ def show_geom (self ):
256
+ show_geom = self .template_data .registered_geom
257
+ if show_geom is None :
258
+ show_geom = self .recording .get_channel_locations ()
259
+ return show_geom
260
+
261
+ def show_channel_index (self , radius_um = 50 ):
262
+ return make_channel_index (self .show_geom , radius_um )
263
+
239
264
# spike feature loading methods
240
265
241
266
def named_feature (self , name , which = slice (None )):
@@ -330,7 +355,12 @@ def unit_raw_waveforms(
330
355
if not self .shifting :
331
356
return which , waveforms
332
357
333
- waveforms , max_chan , show_geom , show_channel_index = self .unit_shift_or_relocate_channels (
358
+ (
359
+ waveforms ,
360
+ max_chan ,
361
+ show_geom ,
362
+ show_channel_index ,
363
+ ) = self .unit_shift_or_relocate_channels (
334
364
unit_id ,
335
365
which ,
336
366
waveforms ,
@@ -367,7 +397,12 @@ def unit_tpca_waveforms(
367
397
t = waveforms .shape [1 ]
368
398
waveforms = waveforms .reshape (n , c , t ).transpose (0 , 2 , 1 )
369
399
370
- waveforms , max_chan , show_geom , show_channel_index = self .unit_shift_or_relocate_channels (
400
+ (
401
+ waveforms ,
402
+ max_chan ,
403
+ show_geom ,
404
+ show_channel_index ,
405
+ ) = self .unit_shift_or_relocate_channels (
371
406
unit_id ,
372
407
which ,
373
408
waveforms ,
@@ -378,9 +413,21 @@ def unit_tpca_waveforms(
378
413
return which , waveforms , max_chan , show_geom , show_channel_index
379
414
380
415
def unit_pca_features (
381
- self , unit_id , relocated = True , rank = 2 , pca_radius_um = 75 , random_seed = 0 , max_count = 500
416
+ self ,
417
+ unit_id ,
418
+ relocated = True ,
419
+ rank = 2 ,
420
+ pca_radius_um = 75 ,
421
+ random_seed = 0 ,
422
+ max_count = 500 ,
382
423
):
383
- which , waveforms , max_chan , show_geom , show_channel_index = self .unit_tpca_waveforms (
424
+ (
425
+ which ,
426
+ waveforms ,
427
+ max_chan ,
428
+ show_geom ,
429
+ show_channel_index ,
430
+ ) = self .unit_tpca_waveforms (
384
431
unit_id ,
385
432
relocated = relocated ,
386
433
show_radius_um = pca_radius_um ,
@@ -439,7 +486,9 @@ def unit_shift_or_relocate_channels(
439
486
show_channel_index = make_channel_index (show_geom , show_radius_um )
440
487
show_chans = show_channel_index [max_chan ]
441
488
show_chans = show_chans [show_chans < len (show_geom )]
442
- show_channel_index = np .broadcast_to (show_chans [None ], (len (show_geom ), show_chans .size ))
489
+ show_channel_index = np .broadcast_to (
490
+ show_chans [None ], (len (show_geom ), show_chans .size )
491
+ )
443
492
444
493
if not self .shifting :
445
494
return waveforms , max_chan , show_geom , show_channel_index
@@ -478,6 +527,34 @@ def unit_shift_or_relocate_channels(
478
527
479
528
return waveforms , max_chan , show_geom , show_channel_index
480
529
530
+ def nearby_coarse_templates (self , unit_id , n_neighbors = 5 ):
531
+ unit_ix = np .searchsorted (self .unit_ids , unit_id )
532
+ unit_dists = self .merge_dist [unit_ix ]
533
+ distance_order = np .argsort (unit_dists )
534
+ assert distance_order [0 ] == unit_ix
535
+ neighbor_ixs = distance_order [:n_neighbors ]
536
+ neighbor_ids = self .unit_ids [:n_neighbors ]
537
+ neighbor_dists = self .merge_dist [neighbor_ixs [:, None ], neighbor_ixs [None , :]]
538
+ neighbor_coarse_templates = self .coarse_template_data .templates [neighbor_ixs ]
539
+ return neighbor_ids , neighbor_dists , neighbor_coarse_templates
540
+
541
+ # computation
542
+
543
+ def _calc_merge_dist (self ):
544
+ """Compute the merge distance matrix"""
545
+ merge_td = self .template_data
546
+ if self .merge_distance_templates_kind == "coarse" :
547
+ merge_td = self .coarse_template_data
548
+
549
+ units , dists , shifts , template_snrs = merge .calculate_merge_distances (
550
+ merge_td ,
551
+ superres_linkage = self .merge_superres_linkage ,
552
+ device = self .device ,
553
+ n_jobs = 1 ,
554
+ )
555
+ assert np .array_equal (units , self .unit_ids )
556
+ self ._merge_dist = dists
557
+
481
558
482
559
@dataclass
483
560
class DARTsortGroundTruthComparison :
0 commit comments