@@ -410,7 +410,7 @@ def em(
410
410
unit_churn , reas_count , log_liks , spike_logliks = self .e_step (
411
411
show_progress = step_progress , split = final_split
412
412
)
413
- log_liks , _ = self .cleanup (log_liks )
413
+ log_liks , _ = self .cleanup (log_liks , relabel_split = final_split )
414
414
return log_liks
415
415
416
416
def e_step (
@@ -454,7 +454,9 @@ def m_step(
454
454
if self .use_proportions and likelihoods is not None :
455
455
self .update_proportions (likelihoods )
456
456
if self .log_proportions is not None :
457
- assert len (self .log_proportions ) == unit_ids .max () + 1 + self .with_noise_unit
457
+ assert (
458
+ len (self .log_proportions ) == unit_ids .max () + 1 + self .with_noise_unit
459
+ )
458
460
459
461
fit_full_indices , fit_split_indices = quick_indices (
460
462
self .rg ,
@@ -541,11 +543,11 @@ def log_likelihoods(
541
543
unit_neighb_info .append ((j , neighbs , ns_unit ))
542
544
else :
543
545
assert previous_logliks is not None
544
- assert hasattr (previous_logliks , ' row_nnz' )
545
- assert ' covered_neighbs' in unit .annotations
546
+ assert hasattr (previous_logliks , " row_nnz" )
547
+ assert " covered_neighbs" in unit .annotations
546
548
ns_unit = previous_logliks .row_nnz [j ]
547
549
unit_neighb_info .append ((j , ns_unit ))
548
- covered_neighbs = unit .annotations [' covered_neighbs' ]
550
+ covered_neighbs = unit .annotations [" covered_neighbs" ]
549
551
core_overlaps [covered_neighbs ] += 1
550
552
nnz += ns_unit
551
553
@@ -695,12 +697,18 @@ def reassign(self, log_liks):
695
697
unit_churn = 1.0 - iou
696
698
697
699
# update labels
700
+ self .labels .fill_ (- 1 )
698
701
self .labels [spike_ix ] = assignments
699
702
700
703
return unit_churn , reassign_count , spike_logliks , log_liks_csc
701
704
702
705
def cleanup (
703
- self , log_liks = None , min_count = None , clean_props = None , split = "train"
706
+ self ,
707
+ log_liks = None ,
708
+ min_count = None ,
709
+ clean_props = None ,
710
+ split = "train" ,
711
+ relabel_split = "train" ,
704
712
) -> tuple [Optional [csc_array ], Optional [dict ]]:
705
713
"""Remove too-small units, make label space contiguous, tidy all properties"""
706
714
if min_count is None :
@@ -731,7 +739,7 @@ def cleanup(
731
739
kept_ids = label_ids [big_enough ]
732
740
new_ids = torch .arange (kept_ids .numel ())
733
741
old2new = dict (zip (kept_ids , new_ids ))
734
- self ._relabel (kept_ids , split = split )
742
+ self ._relabel (kept_ids , split = relabel_split )
735
743
736
744
if self .log_proportions is not None :
737
745
lps = self .log_proportions .numpy (force = True )
@@ -757,8 +765,6 @@ def cleanup(
757
765
758
766
keep_ll = keep_noise .numpy (force = True )
759
767
assert keep_ll .size == log_liks .shape [0 ]
760
- if keep_ll .all ():
761
- return log_liks , clean_props
762
768
763
769
if isinstance (log_liks , coo_array ):
764
770
log_liks = coo_sparse_mask_rows (log_liks , keep_ll )
@@ -1171,6 +1177,8 @@ def unit_log_likelihoods(
1171
1177
)
1172
1178
unit .annotations ["covered_neighbs" ] = cn
1173
1179
if not ns :
1180
+ if inds_already :
1181
+ return None
1174
1182
return None , None
1175
1183
1176
1184
if inds_already :
@@ -1330,7 +1338,7 @@ def kmeans_split_unit(self, unit_id, debug=False):
1330
1338
1331
1339
split_labels = torch .asarray (split_labels , device = self .labels .device )
1332
1340
n_new_units = split_ids .size - 1
1333
- if n_new_units < 1 :
1341
+ if n_new_units <= 1 :
1334
1342
# quick case
1335
1343
with self .labels_lock :
1336
1344
self .labels [indices_full ] = - 1
@@ -2033,7 +2041,7 @@ def _relabel(self, old_labels, new_labels=None, flat=False, split=None):
2033
2041
label_indices = label_indices [kept ]
2034
2042
2035
2043
unkept = torch .logical_not (kept )
2036
- if split is not None :
2044
+ if split_indices != slice ( None ) :
2037
2045
unkept = split_indices [unkept ]
2038
2046
kept = split_indices [kept ]
2039
2047
@@ -2158,7 +2166,6 @@ def merge_units(
2158
2166
sym_function = self .merge_sym_function ,
2159
2167
show_progress = show_progress ,
2160
2168
)
2161
- print (f"{ group_ids .shape = } { distances .shape = } " )
2162
2169
if debug_info is not None :
2163
2170
debug_info ["Z" ] = Z
2164
2171
debug_info ["improvements" ] = improvements
0 commit comments