@@ -162,7 +162,16 @@ def draw(self, panel, gmm, unit_id, axes=None):
162
162
(in_unit ,) = torch .nonzero (gmm .labels == unit_id , as_tuple = True )
163
163
if not in_unit .numel ():
164
164
return
165
- inds_ , liks = gmm .unit_log_likelihoods (unit_id , spike_indices = in_unit )
165
+ if hasattr (gmm , "log_liks" ):
166
+ liks_ = gmm .log_liks [:, in_unit ][[unit_id ]].tocoo ()
167
+ inds_ = None
168
+ if liks_ .nnz :
169
+ inds_ = in_unit
170
+ liks = np .full (in_unit .shape , - np .inf , dtype = np .float32 )
171
+ liks [liks_ .coords [1 ]] = liks_ .data
172
+ liks = torch .from_numpy (liks )
173
+ else :
174
+ inds_ , liks = gmm .unit_log_likelihoods (unit_id , spike_indices = in_unit )
166
175
if inds_ is None :
167
176
return
168
177
assert torch .equal (inds_ , in_unit )
@@ -215,6 +224,12 @@ def __init__(self, layout="vert"):
215
224
def draw (self , panel , gmm , unit_id , split_info = None ):
216
225
if split_info is None :
217
226
split_info = gmm .kmeans_split_unit (unit_id , debug = True )
227
+ if not split_info :
228
+ ax = panel .subplots ()
229
+ ax .text (.5 , .5 , "no channels!" , ha = "center" , transform = ax .transAxes )
230
+ ax .axis ("off" )
231
+ return
232
+
218
233
split_labels = split_info ["reas_labels" ]
219
234
split_ids = np .unique (split_labels )
220
235
@@ -370,8 +385,13 @@ def __init__(self, n_neighbors=5):
370
385
def draw (self , panel , gmm , unit_id ):
371
386
neighbors = gmm_helpers .get_neighbors (gmm , unit_id )
372
387
assert neighbors [0 ] == unit_id
373
- log_liks = gmm .log_likelihoods (unit_ids = neighbors )
374
- labels , spikells = gaussian_mixture .loglik_reassign (log_liks , has_noise_unit = True )
388
+ if hasattr (gmm , "log_liks" ):
389
+ neighbors_plus_noiseunit = np .concatenate ((neighbors , [gmm .log_liks .shape [0 ] - 1 ]))
390
+ log_liks = gmm .log_liks [neighbors_plus_noiseunit ]
391
+ else :
392
+ log_liks = gmm .log_likelihoods (unit_ids = neighbors )
393
+ labels , spikells , log_liks = gaussian_mixture .loglik_reassign (log_liks , has_noise_unit = True )
394
+ log_liks = log_liks .tocoo ()
375
395
log_liks = gaussian_mixture .coo_to_torch (log_liks , torch .float )
376
396
kept = labels >= 0
377
397
labels_ = np .full_like (labels , - 1 )
@@ -409,12 +429,12 @@ def draw(self, panel, gmm, unit_id):
409
429
bimod_ax .text (0 , 0 , f"too-small kept prop { bimod_info ['keep_prop' ]:.2f} " )
410
430
bimod_ax .axis ("off" )
411
431
continue
412
- bimod_ax .hist (bimod_info ["samples" ], color = "gray" , label = "unweighted hist" , ** histkw )
432
+ bimod_ax .hist (bimod_info ["samples" ], color = "gray" , label = "hist" , ** histkw )
413
433
bimod_ax .hist (
414
434
bimod_info ["samples" ],
415
435
weights = bimod_info ["sample_weights" ],
416
436
color = "k" ,
417
- label = "weighted hist " ,
437
+ label = "whist " ,
418
438
** histkw ,
419
439
)
420
440
bimod_ax .axvline (bimod_info ["cut" ], color = "k" , lw = 0.8 , ls = ":" )
@@ -450,7 +470,7 @@ def make_unit_gmm_summary(
450
470
unit_id ,
451
471
plots = default_gmm_plots ,
452
472
max_height = 9 ,
453
- figsize = (13 , 9 ),
473
+ figsize = (14 , 11 ),
454
474
hspace = 0.1 ,
455
475
figure = None ,
456
476
** other_global_params ,
@@ -479,7 +499,7 @@ def make_all_gmm_summaries(
479
499
save_folder ,
480
500
plots = default_gmm_plots ,
481
501
max_height = 9 ,
482
- figsize = (13 , 9 ),
502
+ figsize = (14 , 11 ),
483
503
hspace = 0.1 ,
484
504
dpi = 200 ,
485
505
image_ext = "png" ,
0 commit comments