Skip to content

Commit 2ee1858

Browse files
committed
Vis debugging / perf for gmm
1 parent 88e349a commit 2ee1858

File tree

3 files changed

+29
-9
lines changed

3 files changed

+29
-9
lines changed

src/dartsort/vis/gmm.py

Lines changed: 27 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -162,7 +162,16 @@ def draw(self, panel, gmm, unit_id, axes=None):
162162
(in_unit,) = torch.nonzero(gmm.labels == unit_id, as_tuple=True)
163163
if not in_unit.numel():
164164
return
165-
inds_, liks = gmm.unit_log_likelihoods(unit_id, spike_indices=in_unit)
165+
if hasattr(gmm, "log_liks"):
166+
liks_ = gmm.log_liks[:, in_unit][[unit_id]].tocoo()
167+
inds_ = None
168+
if liks_.nnz:
169+
inds_ = in_unit
170+
liks = np.full(in_unit.shape, -np.inf, dtype=np.float32)
171+
liks[liks_.coords[1]] = liks_.data
172+
liks = torch.from_numpy(liks)
173+
else:
174+
inds_, liks = gmm.unit_log_likelihoods(unit_id, spike_indices=in_unit)
166175
if inds_ is None:
167176
return
168177
assert torch.equal(inds_, in_unit)
@@ -215,6 +224,12 @@ def __init__(self, layout="vert"):
215224
def draw(self, panel, gmm, unit_id, split_info=None):
216225
if split_info is None:
217226
split_info = gmm.kmeans_split_unit(unit_id, debug=True)
227+
if not split_info:
228+
ax = panel.subplots()
229+
ax.text(.5, .5, "no channels!", ha="center", transform=ax.transAxes)
230+
ax.axis("off")
231+
return
232+
218233
split_labels = split_info["reas_labels"]
219234
split_ids = np.unique(split_labels)
220235

@@ -370,8 +385,13 @@ def __init__(self, n_neighbors=5):
370385
def draw(self, panel, gmm, unit_id):
371386
neighbors = gmm_helpers.get_neighbors(gmm, unit_id)
372387
assert neighbors[0] == unit_id
373-
log_liks = gmm.log_likelihoods(unit_ids=neighbors)
374-
labels, spikells = gaussian_mixture.loglik_reassign(log_liks, has_noise_unit=True)
388+
if hasattr(gmm, "log_liks"):
389+
neighbors_plus_noiseunit = np.concatenate((neighbors, [gmm.log_liks.shape[0] - 1]))
390+
log_liks = gmm.log_liks[neighbors_plus_noiseunit]
391+
else:
392+
log_liks = gmm.log_likelihoods(unit_ids=neighbors)
393+
labels, spikells, log_liks = gaussian_mixture.loglik_reassign(log_liks, has_noise_unit=True)
394+
log_liks = log_liks.tocoo()
375395
log_liks = gaussian_mixture.coo_to_torch(log_liks, torch.float)
376396
kept = labels >= 0
377397
labels_ = np.full_like(labels, -1)
@@ -409,12 +429,12 @@ def draw(self, panel, gmm, unit_id):
409429
bimod_ax.text(0, 0, f"too-small kept prop {bimod_info['keep_prop']:.2f}")
410430
bimod_ax.axis("off")
411431
continue
412-
bimod_ax.hist(bimod_info["samples"], color="gray", label="unweighted hist", **histkw)
432+
bimod_ax.hist(bimod_info["samples"], color="gray", label="hist", **histkw)
413433
bimod_ax.hist(
414434
bimod_info["samples"],
415435
weights=bimod_info["sample_weights"],
416436
color="k",
417-
label="weighted hist",
437+
label="whist",
418438
**histkw,
419439
)
420440
bimod_ax.axvline(bimod_info["cut"], color="k", lw=0.8, ls=":")
@@ -450,7 +470,7 @@ def make_unit_gmm_summary(
450470
unit_id,
451471
plots=default_gmm_plots,
452472
max_height=9,
453-
figsize=(13, 9),
473+
figsize=(14, 11),
454474
hspace=0.1,
455475
figure=None,
456476
**other_global_params,
@@ -479,7 +499,7 @@ def make_all_gmm_summaries(
479499
save_folder,
480500
plots=default_gmm_plots,
481501
max_height=9,
482-
figsize=(13, 9),
502+
figsize=(14, 11),
483503
hspace=0.1,
484504
dpi=200,
485505
image_ext="png",

src/dartsort/vis/gmm_helpers.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77

88

99
def get_neighbors(gmm, unit_id, n_neighbors=5):
10-
means, covs, logdets = gmm.stack_units()
10+
means, covs, logdets = gmm.stack_units(use_cache=True)
1111
dists = gmm.units[unit_id].divergence(means, covs, logdets, kind=gmm.distance_metric)
1212
dists = dists.view(-1)
1313
order = torch.argsort(dists)
@@ -70,6 +70,7 @@ def plot_means(panel, prgeom, tpca, chans, units, labels, title="nearest neighbo
7070
show_zero=False,
7171
ax=ax,
7272
zlim=None,
73+
subar=True,
7374
)
7475
panel.legend(
7576
handles=[Line2D([0, 1], [0, 0], color=c) for c in colors],

src/dartsort/vis/gt.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -569,7 +569,6 @@ def draw(self, panel, comparison):
569569
# box and whisker between sorters
570570

571571

572-
573572
def make_gt_overview_summary(
574573
comparison,
575574
plots=gt_overview_plots,

0 commit comments

Comments
 (0)