Skip to content

Commit 5105e3c

Browse files
committed
Labeling bug
1 parent 84d268e commit 5105e3c

File tree

1 file changed

+19 -12 lines changed

1 file changed

+19 -12 lines changed

src/dartsort/cluster/gaussian_mixture.py

Lines changed: 19 additions & 12 deletions
Original file line number | Diff line number | Diff line change
@@ -410,7 +410,7 @@ def em(
410410
unit_churn, reas_count, log_liks, spike_logliks = self.e_step(
411411
show_progress=step_progress, split=final_split
412412
)
413-
log_liks, _ = self.cleanup(log_liks)
413+
log_liks, _ = self.cleanup(log_liks, relabel_split=final_split)
414414
return log_liks
415415

416416
def e_step(
@@ -454,7 +454,9 @@ def m_step(
454454
if self.use_proportions and likelihoods is not None:
455455
self.update_proportions(likelihoods)
456456
if self.log_proportions is not None:
457-
assert len(self.log_proportions) == unit_ids.max() + 1 + self.with_noise_unit
457+
assert (
458+
len(self.log_proportions) == unit_ids.max() + 1 + self.with_noise_unit
459+
)
458460

459461
fit_full_indices, fit_split_indices = quick_indices(
460462
self.rg,
@@ -541,11 +543,11 @@ def log_likelihoods(
541543
unit_neighb_info.append((j, neighbs, ns_unit))
542544
else:
543545
assert previous_logliks is not None
544-
assert hasattr(previous_logliks, 'row_nnz')
545-
assert 'covered_neighbs' in unit.annotations
546+
assert hasattr(previous_logliks, "row_nnz")
547+
assert "covered_neighbs" in unit.annotations
546548
ns_unit = previous_logliks.row_nnz[j]
547549
unit_neighb_info.append((j, ns_unit))
548-
covered_neighbs = unit.annotations['covered_neighbs']
550+
covered_neighbs = unit.annotations["covered_neighbs"]
549551
core_overlaps[covered_neighbs] += 1
550552
nnz += ns_unit
551553

@@ -695,12 +697,18 @@ def reassign(self, log_liks):
695697
unit_churn = 1.0 - iou
696698

697699
# update labels
700+
self.labels.fill_(-1)
698701
self.labels[spike_ix] = assignments
699702

700703
return unit_churn, reassign_count, spike_logliks, log_liks_csc
701704

702705
def cleanup(
703-
self, log_liks=None, min_count=None, clean_props=None, split="train"
706+
self,
707+
log_liks=None,
708+
min_count=None,
709+
clean_props=None,
710+
split="train",
711+
relabel_split="train",
704712
) -> tuple[Optional[csc_array], Optional[dict]]:
705713
"""Remove too-small units, make label space contiguous, tidy all properties"""
706714
if min_count is None:
@@ -731,7 +739,7 @@ def cleanup(
731739
kept_ids = label_ids[big_enough]
732740
new_ids = torch.arange(kept_ids.numel())
733741
old2new = dict(zip(kept_ids, new_ids))
734-
self._relabel(kept_ids, split=split)
742+
self._relabel(kept_ids, split=relabel_split)
735743

736744
if self.log_proportions is not None:
737745
lps = self.log_proportions.numpy(force=True)
@@ -757,8 +765,6 @@ def cleanup(
757765

758766
keep_ll = keep_noise.numpy(force=True)
759767
assert keep_ll.size == log_liks.shape[0]
760-
if keep_ll.all():
761-
return log_liks, clean_props
762768

763769
if isinstance(log_liks, coo_array):
764770
log_liks = coo_sparse_mask_rows(log_liks, keep_ll)
@@ -1171,6 +1177,8 @@ def unit_log_likelihoods(
11711177
)
11721178
unit.annotations["covered_neighbs"] = cn
11731179
if not ns:
1180+
if inds_already:
1181+
return None
11741182
return None, None
11751183

11761184
if inds_already:
@@ -1330,7 +1338,7 @@ def kmeans_split_unit(self, unit_id, debug=False):
13301338

13311339
split_labels = torch.asarray(split_labels, device=self.labels.device)
13321340
n_new_units = split_ids.size - 1
1333-
if n_new_units < 1:
1341+
if n_new_units <= 1:
13341342
# quick case
13351343
with self.labels_lock:
13361344
self.labels[indices_full] = -1
@@ -2033,7 +2041,7 @@ def _relabel(self, old_labels, new_labels=None, flat=False, split=None):
20332041
label_indices = label_indices[kept]
20342042

20352043
unkept = torch.logical_not(kept)
2036-
if split is not None:
2044+
if split_indices != slice(None):
20372045
unkept = split_indices[unkept]
20382046
kept = split_indices[kept]
20392047

@@ -2158,7 +2166,6 @@ def merge_units(
21582166
sym_function=self.merge_sym_function,
21592167
show_progress=show_progress,
21602168
)
2161-
print(f"{group_ids.shape=} {distances.shape=}")
21622169
if debug_info is not None:
21632170
debug_info["Z"] = Z
21642171
debug_info["improvements"] = improvements

0 commit comments

Comments
 (0)