@@ -28,7 +28,7 @@ class UnitPlot:
28
28
width = 1
29
29
height = 1
30
30
31
- def draw (self , axis , sorting_analysis , unit_id ):
31
+ def draw (self , panel , sorting_analysis , unit_id ):
32
32
raise NotImplementedError
33
33
34
34
def notify_global_params (self , ** params ):
@@ -47,7 +47,8 @@ class TextInfo(UnitPlot):
47
47
kind = "text"
48
48
height = 0.5
49
49
50
- def draw (self , axis , sorting_analysis , unit_id ):
50
+ def draw (self , panel , sorting_analysis , unit_id ):
51
+ axis = panel .subplots ()
51
52
axis .axis ("off" )
52
53
msg = f"unit { unit_id } \n "
53
54
@@ -77,12 +78,13 @@ class ACG(UnitPlot):
77
78
def __init__ (self , max_lag = 50 ):
78
79
self .max_lag = max_lag
79
80
80
- def draw (self , axis , sorting_analysis , unit_id ):
81
+ def draw (self , panel , sorting_analysis , unit_id ):
82
+ axis = panel .subplots ()
81
83
times_samples = sorting_analysis .times_samples (
82
84
which = sorting_analysis .in_unit (unit_id )
83
85
)
84
86
lags , acg = correlogram (times_samples , max_lag = self .max_lag )
85
- axis . bar (lags , acg )
87
+ bar (axis , lags , acg , fill = True , color = "k" )
86
88
axis .set_xlabel ("lag (samples)" )
87
89
axis .set_ylabel ("acg" )
88
90
@@ -95,7 +97,8 @@ def __init__(self, bin_ms=0.1, max_ms=5):
95
97
self .bin_ms = bin_ms
96
98
self .max_ms = max_ms
97
99
98
- def draw (self , axis , sorting_analysis , unit_id ):
100
+ def draw (self , panel , sorting_analysis , unit_id ):
101
+ axis = panel .subplots ()
99
102
times_s = sorting_analysis .times_seconds (
100
103
which = sorting_analysis .in_unit (unit_id )
101
104
)
@@ -108,7 +111,7 @@ def draw(self, axis, sorting_analysis, unit_id):
108
111
# counts, _ = np.histogram(dt_ms, bin_edges)
109
112
# bin_centers = 0.5 * (bin_edges[1:] + bin_edges[:-1])
110
113
# axis.bar(bin_centers, counts)
111
- plt .hist (dt_ms , bin_edges )
114
+ plt .hist (dt_ms , bin_edges , color = "k" )
112
115
axis .set_xlabel ("isi (ms)" )
113
116
axis .set_ylabel (f"count (out of { dt_ms .size } total isis)" )
114
117
@@ -130,7 +133,8 @@ def __init__(
130
133
self .probe_margin_um = probe_margin_um
131
134
self .colorbar = colorbar
132
135
133
- def draw (self , axis , sorting_analysis , unit_id ):
136
+ def draw (self , panel , sorting_analysis , unit_id ):
137
+ axis = panel .subplots ()
134
138
in_unit = sorting_analysis .in_unit (unit_id )
135
139
x = sorting_analysis .x (which = in_unit )
136
140
z = sorting_analysis .z (which = in_unit , registered = self .registered )
@@ -160,15 +164,19 @@ class PCAScatter(UnitPlot):
160
164
kind = "scatter"
161
165
162
166
def __init__ (
163
- self , relocate_amplitudes = False , relocated = True , amplitude_color_cutoff = 15 ,
164
- colorbar = False
167
+ self ,
168
+ relocate_amplitudes = False ,
169
+ relocated = True ,
170
+ amplitude_color_cutoff = 15 ,
171
+ colorbar = False ,
165
172
):
166
173
self .relocated = relocated
167
174
self .relocate_amplitudes = relocate_amplitudes
168
175
self .amplitude_color_cutoff = amplitude_color_cutoff
169
176
self .colorbar = colorbar
170
177
171
- def draw (self , axis , sorting_analysis , unit_id ):
178
+ def draw (self , panel , sorting_analysis , unit_id ):
179
+ axis = panel .subplots ()
172
180
which , loadings = sorting_analysis .unit_pca_features (
173
181
unit_id = unit_id , relocated = self .relocated
174
182
)
@@ -205,7 +213,8 @@ def __init__(
205
213
self .amplitude_color_cutoff = amplitude_color_cutoff
206
214
self .probe_margin_um = probe_margin_um
207
215
208
- def draw (self , axis , sorting_analysis , unit_id ):
216
+ def draw (self , panel , sorting_analysis , unit_id ):
217
+ axis = panel .subplots ()
209
218
in_unit = sorting_analysis .in_unit (unit_id )
210
219
t = sorting_analysis .times_seconds (which = in_unit )
211
220
z = sorting_analysis .z (which = in_unit , registered = self .registered )
@@ -245,7 +254,8 @@ def __init__(
245
254
self .amplitude_color_cutoff = amplitude_color_cutoff
246
255
self .color_by_amplitude = color_by_amplitude
247
256
248
- def draw (self , axis , sorting_analysis , unit_id ):
257
+ def draw (self , panel , sorting_analysis , unit_id ):
258
+ axis = panel .subplots ()
249
259
in_unit = sorting_analysis .in_unit (unit_id )
250
260
t = sorting_analysis .times_seconds (which = in_unit )
251
261
feat = sorting_analysis .named_feature (self .feat_name , which = in_unit )
@@ -271,7 +281,8 @@ def __init__(self, relocate_amplitudes=False, amplitude_color_cutoff=15):
271
281
self .relocate_amplitudes = relocate_amplitudes
272
282
self .amplitude_color_cutoff = amplitude_color_cutoff
273
283
274
- def draw (self , axis , sorting_analysis , unit_id ):
284
+ def draw (self , panel , sorting_analysis , unit_id ):
285
+ axis = panel .subplots ()
275
286
in_unit = sorting_analysis .in_unit (unit_id )
276
287
t = sorting_analysis .times_seconds (which = in_unit )
277
288
amps = sorting_analysis .amplitudes (
@@ -327,7 +338,8 @@ def __init__(
327
338
def get_waveforms (self , sorting_analysis , unit_id ):
328
339
raise NotImplementedError
329
340
330
- def draw (self , axis , sorting_analysis , unit_id ):
341
+ def draw (self , panel , sorting_analysis , unit_id ):
342
+ axis = panel .subplots ()
331
343
which , waveforms , max_chan , geom , ci = self .get_waveforms (
332
344
sorting_analysis , unit_id
333
345
)
@@ -353,7 +365,7 @@ def draw(self, axis, sorting_analysis, unit_id):
353
365
new_offset = self .trough_offset_samples ,
354
366
new_length = self .spike_length_samples ,
355
367
)
356
- max_abs_amp = self .max_abs_template_scale * np .abs (templates ). max ( )
368
+ max_abs_amp = self .max_abs_template_scale * np .nanmax ( np . abs (templates ))
357
369
show_superres_templates = (
358
370
self .show_superres_templates and self .template_index is None
359
371
)
@@ -368,7 +380,7 @@ def draw(self, axis, sorting_analysis, unit_id):
368
380
new_length = self .spike_length_samples ,
369
381
)
370
382
show_superres_templates = suptemplates .shape [0 ] > 1
371
- max_abs_amp = self .max_abs_template_scale * np .abs (suptemplates ). max ( )
383
+ max_abs_amp = self .max_abs_template_scale * np .nanmax ( np . abs (suptemplates ))
372
384
373
385
ls = geomplot (
374
386
waveforms ,
@@ -434,6 +446,7 @@ def draw(self, axis, sorting_analysis, unit_id):
434
446
reg_str = "registered " * sorting_analysis .shifting
435
447
axis .set_ylabel (reg_str + "depth (um)" )
436
448
axis .set_xticks ([])
449
+ axis .set_yticks ([])
437
450
438
451
if self .legend :
439
452
axis .legend (
@@ -487,7 +500,8 @@ def __init__(self, channel_show_radius_um=50, n_neighbors=5, legend=True):
487
500
self .n_neighbors = n_neighbors
488
501
self .legend = legend
489
502
490
- def draw (self , axis , sorting_analysis , unit_id ):
503
+ def draw (self , panel , sorting_analysis , unit_id ):
504
+ axis = panel .subplots ()
491
505
(
492
506
neighbor_ids ,
493
507
neighbor_dists ,
@@ -529,7 +543,7 @@ def draw(self, axis, sorting_analysis, unit_id):
529
543
)
530
544
labels .append (str (uid ))
531
545
handles .append (lines [0 ])
532
- axis .legend (handles = handles , labels = labels , fancybox = False )
546
+ axis .legend (handles = handles , labels = labels , fancybox = False , loc = "upper left" )
533
547
axis .set_xticks ([])
534
548
axis .set_yticks ([])
535
549
axis .set_title (self .title )
@@ -539,15 +553,18 @@ class CoarseTemplateDistancePlot(UnitPlot):
539
553
title = "coarse template distance"
540
554
kind = "neighbors"
541
555
width = 2
542
- height = 2
556
+ height = 1.25
543
557
544
- def __init__ (self , channel_show_radius_um = 50 , n_neighbors = 5 , dist_vmax = 1.0 , show_values = True ):
558
+ def __init__ (
559
+ self , channel_show_radius_um = 50 , n_neighbors = 5 , dist_vmax = 1.0 , show_values = True
560
+ ):
545
561
self .channel_show_radius_um = channel_show_radius_um
546
562
self .n_neighbors = n_neighbors
547
563
self .dist_vmax = dist_vmax
548
564
self .show_values = show_values
549
565
550
- def draw (self , axis , sorting_analysis , unit_id ):
566
+ def draw (self , panel , sorting_analysis , unit_id ):
567
+ axis = panel .subplots ()
551
568
(
552
569
neighbor_ids ,
553
570
neighbor_dists ,
@@ -580,6 +597,52 @@ def draw(self, axis, sorting_analysis, unit_id):
580
597
axis .set_title (self .title )
581
598
582
599
600
+ class NeighborCCGPlot (UnitPlot ):
601
+ kind = "neighbors"
602
+ width = 2
603
+ height = 0.75
604
+
605
+ def __init__ (self , n_neighbors = 3 , max_lag = 50 ):
606
+ self .n_neighbors = n_neighbors
607
+ self .max_lag = max_lag
608
+
609
+ def draw (self , panel , sorting_analysis , unit_id ):
610
+ (
611
+ neighbor_ids ,
612
+ neighbor_dists ,
613
+ neighbor_coarse_templates ,
614
+ ) = sorting_analysis .nearby_coarse_templates (
615
+ unit_id , n_neighbors = self .n_neighbors + 1
616
+ )
617
+ colors = np .array (cc .glasbey_light )[neighbor_ids % len (cc .glasbey_light )]
618
+ # assert neighbor_ids[0] == unit_id
619
+ neighbor_ids = neighbor_ids [1 :]
620
+
621
+ my_st = sorting_analysis .times_samples (which = sorting_analysis .in_unit (unit_id ))
622
+ neighb_sts = [
623
+ sorting_analysis .times_samples (which = sorting_analysis .in_unit (nid ))
624
+ for nid in neighbor_ids
625
+ ]
626
+ ccgs = [correlogram (my_st , nst , max_lag = self .max_lag ) for nst in neighb_sts ]
627
+ acgs = [correlogram (my_st , nst , max_lag = self .max_lag ) for nst in neighb_sts ]
628
+
629
+ axes = panel .subplots (
630
+ nrows = 2 , sharey = "row" , sharex = True , ncols = self .n_neighbors
631
+ )
632
+ for j in range (self .n_neighbors ):
633
+ clags , ccg = correlogram (my_st , neighb_sts [j ], max_lag = self .max_lag )
634
+ merged_st = np .concatenate ((my_st , neighb_sts [j ]))
635
+ merged_st .sort ()
636
+ alags , acg = correlogram (merged_st , max_lag = self .max_lag )
637
+
638
+ bar (axes [0 , j ], clags , ccg , fill = True , fc = colors [j ]) # , ec="k", lw=1)
639
+ bar (axes [1 , j ], alags , acg , fill = True , fc = colors [j ]) # , ec="k", lw=1)
640
+ axes [1 , j ].set_xlabel ("lag (samples)" )
641
+ axes [0 , j ].set_title (f"unit { neighbor_ids [j ]} " )
642
+ axes [0 , 0 ].set_ylabel ("ccg" )
643
+ axes [1 , 0 ].set_ylabel ("merged acg" )
644
+
645
+
583
646
# -- multi plots
584
647
# these have multiple plots per unit, and we don't know in advance how many
585
648
# for instance, making separate plots of spikes belonging to each superres template
@@ -662,6 +725,7 @@ def unit_plots(self, sorting_analysis, unit_id):
662
725
TPCAWaveformPlot (relocated = True ),
663
726
NearbyCoarseTemplatesPlot (),
664
727
CoarseTemplateDistancePlot (),
728
+ NeighborCCGPlot (),
665
729
)
666
730
667
731
@@ -716,10 +780,11 @@ def make_unit_summary(
716
780
all_panels .extend (cardfigs )
717
781
718
782
for cardfig , card in zip (cardfigs , column ):
719
- axes = cardfig .subplots (nrows = len (card .plots ), ncols = 1 )
720
- axes = np .atleast_1d (axes )
721
- for plot , axis in zip (card .plots , axes ):
722
- plot .draw (axis , sorting_analysis , unit_id )
783
+ panels = cardfig .subfigures (nrows = len (card .plots ), ncols = 1 )
784
+ panels = np .atleast_1d (panels )
785
+ for plot , panel in zip (card .plots , panels ):
786
+ plot .draw (panel , sorting_analysis , unit_id )
787
+ all_panels .extend (panels )
723
788
724
789
# clean up the panels, or else things get clipped
725
790
for panel in all_panels :
@@ -953,3 +1018,9 @@ def _summary_job(unit_id):
953
1018
fig .savefig (tmp_out , dpi = _summary_job_context .dpi )
954
1019
tmp_out .rename (final_out )
955
1020
plt .close (fig )
1021
+
1022
+
1023
+ def bar (ax , x , y , ** kwargs ):
1024
+ dx = np .diff (x ).min ()
1025
+ x0 = np .concatenate ((x - dx , x [- 1 :] + dx ))
1026
+ ax .stairs (y , x0 , ** kwargs )
0 commit comments