
Commit cec7743

Debugging
1 parent 5105e3c commit cec7743

5 files changed: +57 -28 lines changed

src/dartsort/cluster/gaussian_mixture.py

Lines changed: 26 additions & 10 deletions
@@ -683,7 +683,7 @@ def reassign(self, log_liks):
         # intersection
         n_units = max(log_liks.shape[0] - self.with_noise_unit, original.max() + 1)
         intersection = torch.zeros(n_units, dtype=int)
-        spiketorch.add_at_(intersection, assignments[kept], original[kept])
+        spiketorch.add_at_(intersection, assignments[kept], same[kept])

         # union by include/exclude
         union = torch.zeros_like(intersection)
@@ -783,6 +783,8 @@ def merge(self, log_liks=None, show_progress=True):
         new_labels, new_ids = self.merge_units(
             likelihoods=log_liks, show_progress=show_progress
         )
+        if new_labels is None:
+            return
         self.labels.copy_(torch.asarray(new_labels))

         unique_new_ids = np.unique(new_ids)
@@ -1324,6 +1326,8 @@ def kmeans_split_unit(self, unit_id, debug=False):
             debug=debug,
             debug_info=result,
         )
+        if split_labels is None:
+            return result
         split_ids, split_counts = np.unique(split_labels, return_counts=True)
         valid = split_ids >= 0
         if not valid.any():
@@ -1598,6 +1602,12 @@ def tree_merge(
         # heuristic unit groupings to investigate
         distances = sym_function(distances, distances.T)
         distances = distances[np.triu_indices(len(distances), k=1)]
+        finite = np.isfinite(distances)
+        if not finite.any():
+            return None, None, None
+        if not finite.all():
+            inf = max(0, distances[finite].max()) + max_distance + 1
+            distances[np.logical_not(finite)] = inf
         Z = linkage(distances)
         n_units = len(Z) + 1

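Note: SciPy's linkage rejects non-finite condensed distances, so tree_merge now screens them first; all-infinite means there is nothing to merge, and a partial mask is patched with a sentinel pushed past max_distance so unreachable pairs only join at the top of the tree. A minimal standalone sketch of that guard, with an invented 3-unit matrix:

    import numpy as np
    from scipy.cluster.hierarchy import linkage

    # toy distance matrix: units 0 and 2 have no finite distance (invented data)
    distances = np.array([[0.0, 1.0, np.inf],
                          [1.0, 0.0, 2.0],
                          [np.inf, 2.0, 0.0]])
    max_distance = 4.0

    pdist = distances[np.triu_indices(len(distances), k=1)]  # [1.0, inf, 2.0]
    finite = np.isfinite(pdist)
    # tree_merge returns None, None, None when nothing at all is finite
    if finite.any() and not finite.all():
        # sentinel above every finite distance and the merge threshold
        pdist[np.logical_not(finite)] = max(0, pdist[finite].max()) + max_distance + 1

    Z = linkage(pdist)  # linkage raises on non-finite input, hence the guard
    print(Z[:, 2])      # merge heights: the infinite pair now joins last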
@@ -1869,16 +1879,17 @@ def merge_criteria(
             units,
             merged_unit,
             spikes_core[keep],
-            self.data.neighborhoods(),
+            self.data.neighborhoods()[1],
             use_proportions=self.use_proportions,
             reduce=False,
         )
         class_w = class_w.cpu()
         if lik_weights is not None:
             lik_weights = lik_weights.cpu()
         labixs = labixs.cpu()
-        k_full = class_sum(labids, labixs, k_full, lik_weights) / class_w
-        k_merged = class_sum(labids, labixs, k_merged, lik_weights) / class_w
+        labids = labids.cpu()
+        k_full = class_sum(labids, labixs, k_full.cpu(), lik_weights) / class_w
+        k_merged = class_sum(labids, labixs, k_merged.cpu(), lik_weights) / class_w
         if class_balancing == "worst":
             k_full = k_full[worst_ix]
             k_merged = k_merged[worst_ix]
@@ -2144,6 +2155,9 @@ def merge_units(
             debug_info["distances"] = distances
         if distances.shape[0] == 1:
             return None, None
+        pdist = distances[np.triu_indices(len(distances), k=1)]
+        if not (pdist <= self.merge_distance_threshold).any():
+            return None, None

         if merge_kind == "hierarchical":
             return self.hierarchical_bimodality_merge(
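Note: the new check short-circuits merge_units before any tree is built when no pair of units lies within merge_distance_threshold. A standalone sketch of the test (matrix and threshold values invented):

    import numpy as np

    # toy 3-unit distance matrix; 1.0 stands in for self.merge_distance_threshold
    distances = np.array([[0.0, 3.0, 5.0], [3.0, 0.0, 4.0], [5.0, 4.0, 0.0]])
    merge_distance_threshold = 1.0

    pdist = distances[np.triu_indices(len(distances), k=1)]
    if not (pdist <= merge_distance_threshold).any():
        print("no pair within threshold; merge_units returns None, None")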
@@ -2397,10 +2411,10 @@ def fit(
         weights = weights[kept]

         if self.channels_strategy.endswith("fuzzcore"):
-            achans_full = occupied_chans(
+            achans_full, _ = occupied_chans(
                 features, self.n_channels, neighborhoods=neighborhoods
             )
-            achans = occupied_chans(
+            achans, _ = occupied_chans(
                 features,
                 neighborhood_ids=core_neighborhood_ids,
                 n_channels=self.n_channels,
@@ -2410,15 +2424,15 @@ def fit(
             achans = achans[spiketorch.isin_sorted(achans, achans_full)]
             needs_direct = True
         elif self.channels_strategy.endswith("core"):
-            achans = occupied_chans(
+            achans, _ = occupied_chans(
                 features,
                 neighborhood_ids=core_neighborhood_ids,
                 n_channels=self.n_channels,
                 neighborhoods=core_neighborhoods,
             )
             needs_direct = True
         else:
-            achans = occupied_chans(
+            achans, _ = occupied_chans(
                 features, self.n_channels, neighborhoods=neighborhoods
             )
             needs_direct = False
@@ -2428,7 +2442,7 @@ def fit(
         do_pca = self.cov_kind == "ppca" and self.ppca_rank

         active_mean = active_W = None
-        if hasattr(self, "mean"):
+        if hasattr(self, "mean") and self.ppca_warm_start:
             active_mean = self.mean[:, achans]
         if hasattr(self, "W") and self.ppca_warm_start:
             active_W = self.W[:, achans]
@@ -2538,6 +2552,8 @@ def logdet(self, channels=None):

     def log_likelihood(self, features, channels, neighborhood_id=None) -> torch.Tensor:
         """Log likelihood for spike features living on the same channels."""
+        if not len(features):
+            return features.new_zeros((0,))
         mean = self.noise.mean_full[:, channels]
         if self.mean_kind == "full":
             mean = mean + self.mean[:, channels]
@@ -2736,7 +2752,7 @@ def get_average_parameter_counts(
 def class_sum(classes, inverse_inds, x, weights=None):
     wsum = x.new_zeros(len(classes))
     x = x * weights if weights is not None else x
-    spiketorch.add_at_(wsum, inverse_inds, x)
+    spiketorch.add_at_(wsum, inverse_inds.to(x.device), x)
     return wsum

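Note: the .cpu() calls in merge_criteria and the .to(x.device) here address the same failure mode: scatter-style accumulation in torch requires the index tensor to live on the same device as the destination and source. spiketorch.add_at_ is dartsort's in-place scatter-add helper; in the sketch below torch's built-in index_add_, which has the same device requirement, stands in for it (toy inputs):

    import torch

    def class_sum(classes, inverse_inds, x, weights=None):
        # mirror of the fixed helper, with index_add_ standing in for
        # spiketorch.add_at_: move the index tensor to x's device first
        wsum = x.new_zeros(len(classes))
        x = x * weights if weights is not None else x
        wsum.index_add_(0, inverse_inds.to(x.device), x)
        return wsum

    classes = torch.tensor([0, 1, 2])
    inverse_inds = torch.tensor([0, 0, 2, 1])  # may be on CPU while x is on GPU
    x = torch.tensor([1.0, 2.0, 3.0, 4.0])
    print(class_sum(classes, inverse_inds, x))  # tensor([3., 4., 3.])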
src/dartsort/cluster/ppcalib.py

Lines changed: 22 additions & 14 deletions
@@ -264,11 +264,13 @@ def ppca_e_step(
         xc = nd.x - nu

         # we need these ones everywhere
-        Cooinvxc = nd.C_oo_chol.solve(xc.T).T
+        # Cooinvxc = nd.C_oo_chol.solve(xc.T).T
+        Cooinvxc = xc @ nd.C_oo_inv

         # pca-centered data
         if yes_pca and nd.have_missing:
             CooinvWo = nd.C_oo_chol.solve(W_o)
+            CooinvWo = nd.C_oo_inv @ W_o
             # xcc = torch.addmm(xc, ubar, W_o.T, alpha=-1)
             # Cooinvxcc = C_oochol.solve(xcc.T).T
             Cooinvxcc = Cooinvxc.addmm(ubar, CooinvWo.T, alpha=-1)
@@ -285,7 +287,9 @@ def ppca_e_step(
         if yes_pca:
             e_xcu = xc[:, :, None] * ubar[:, None, :]
         if yes_pca and nd.have_missing:
-            e_mxcu = (Cooinvxc @ nd.C_mo.T)[:, :, None] * ubar[:, None, :]
+            # e_mxcu = (Cooinvxc @ nd.C_mo.T)[:, :, None] * ubar[:, None, :]
+            # print(f"{e_mxcu.shape=}")
+            e_mxcu = torch.einsum("ij,kj,il->ikl", Cooinvxc, nd.C_mo, ubar)
             # CmoCooinvWo = C_mo @ CooinvWo
             Wm_less_CmoCooinvWo = W_m.addmm(nd.C_mo, CooinvWo, beta=-1)
             shp = Wm_less_CmoCooinvWo.shape
@@ -314,7 +318,7 @@ def ppca_e_step(
         wxcu = nd.w_norm @ e_xcu.view(nd.neighb_n_spikes, -1)
         wxcu = wxcu.view(e_xcu.shape[1:])
         if nd.have_missing and yes_pca:
-            wmxcu = nd.w_norm @ e_mxcu.view(nd.neighb_n_spikes, -1)
+            wmxcu = nd.w_norm @ e_mxcu.reshape(nd.neighb_n_spikes, -1)
             wmxcu = wmxcu.view(e_mxcu.shape[1:])
         ycubar = y.new_zeros((rank, nc, M))
         ycubar[:, nd.active_subset] = wxcu.view(rank, nd.neighb_nc, M)
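Note: the einsum fuses the old (Cooinvxc @ nd.C_mo.T)[:, :, None] * ubar[:, None, :] into one contraction over the shared observed-channel index, and since einsum output is not guaranteed contiguous, the flatten above switched from .view to .reshape. A toy-shape equivalence check (all sizes invented):

    import torch

    n, d_o, d_m, M = 7, 4, 3, 2  # toy sizes: spikes, observed, missing, ppca rank
    Cooinvxc = torch.randn(n, d_o)
    C_mo = torch.randn(d_m, d_o)
    ubar = torch.randn(n, M)

    # original formulation: matmul, then broadcasted outer product with ubar
    ref = (Cooinvxc @ C_mo.T)[:, :, None] * ubar[:, None, :]

    # the commit's formulation: contract over the shared index j in one shot
    e_mxcu = torch.einsum("ij,kj,il->ikl", Cooinvxc, C_mo, ubar)

    assert torch.allclose(ref, e_mxcu, atol=1e-6)

    # einsum output may be non-contiguous, so flatten with reshape, not view
    flat = e_mxcu.reshape(n, -1)
    print(flat.shape)  # torch.Size([7, 6])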
@@ -374,10 +378,12 @@ def embed(
         xc = nd.x - nu

         # we need these ones everywhere
-        Cooinvxc = nd.C_oo_chol.solve(xc.T).T
+        # Cooinvxc = nd.C_oo_chol.solve(xc.T).T
+        Cooinvxc = xc @ nd.C_oo_inv

         # moments of embeddings
-        T_inv = eye_M + W_o.T @ nd.C_oo_chol.solve(W_o)
+        # T_inv = eye_M + W_o.T @ nd.C_oo_chol.solve(W_o)
+        T_inv = eye_M + W_o.T @ nd.C_oo_inv @ W_o
         T = torch.linalg.inv(T_inv)
         ubar = Cooinvxc @ (W_o @ T)
         uubar = torch.baddbmm(T, ubar[:, :, None], ubar[:, None, :])
@@ -437,6 +443,7 @@ class NeighborhoodPPCAData:

     C_oo: linear_operator.LinearOperator
     C_oo_chol: CholLinearOperator
+    C_oo_inv: CholLinearOperator
     w: torch.Tensor
     w_norm: torch.Tensor
     x: torch.Tensor
@@ -482,13 +489,11 @@ def get_neighborhood_data(
         # subset of active chans which are in the neighborhood
         active_subset = spiketorch.isin_sorted(active_channels, neighb_chans)

-        w = weights[neighb_members]
         x = sp.features[neighb_members][:, :, neighb_subset]

         chans_tuple = tuple(active_channels[active_subset].tolist())
         if chans_tuple in dedup_data:
-            *info, ws, xs, mems = dedup_data[chans_tuple]
-            ws.append(w)
+            *info, xs, mems = dedup_data[chans_tuple]
             xs.append(x)
             mems.append(neighb_members)
         else:
@@ -499,23 +504,22 @@ def get_neighborhood_data(
                 active_subset,
                 can_cache_by_neighborhood,
                 have_missing,
-                [w],
                 [x],
                 [neighb_members],
             )

     neighborhood_data = []
     ess = weights.sum()
     for chans_tuple, chans_data in dedup_data.items():
-        *info, ws, xs, mems = chans_data
+        *info, xs, mems = chans_data
         nid, neighb_chans, active_subset, can_cache_by_neighborhood, have_missing = info
-        if len(ws) > 1:
-            w = torch.concatenate(ws)
+        if len(mems) > 1:
             x = torch.concatenate(xs)
             neighb_members = torch.concatenate(mems)
+            neighb_members, order = neighb_members.sort()
+            x = x[order]
             nid = None
         else:
-            w = ws[0]
             x = xs[0]
             neighb_members = mems[0]

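Note: with per-neighborhood weights dropped from the dedup cache, w = weights[neighb_members] is gathered once later (next hunk), so the concatenated member indices are put in sorted order and the features are permuted identically to stay aligned. A toy version of that sort-and-permute, with invented indices:

    import torch

    # two neighborhoods' spike indices and features (invented data)
    mems = [torch.tensor([4, 9]), torch.tensor([1, 6])]
    xs = [torch.randn(2, 5), torch.randn(2, 5)]

    neighb_members = torch.concatenate(mems)
    x = torch.concatenate(xs)

    # sort member indices; apply the same permutation to the features
    neighb_members, order = neighb_members.sort()
    x = x[order]
    print(neighb_members)  # tensor([1, 4, 6, 9]); rows of x follow the same order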
@@ -547,7 +551,10 @@ def get_neighborhood_data(
             channels=neighb_chans, device=device, **cache_kw
         )
         assert C_oo.shape == (D_neighb, D_neighb)
-        C_oo_chol = CholLinearOperator(C_oo.cholesky())
+        chol = C_oo.cholesky(upper=False)
+        C_oo_chol = CholLinearOperator(chol)
+        Linv = chol.inverse().to_dense()
+        C_oo_inv = Linv.T @ Linv
         w = weights[neighb_members]
         C_mo = None
         if have_missing:
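Note: caching C_oo_inv is what lets the E-step triangular solves become plain matmuls. If C = L @ L.T, then inv(C) = Linv.T @ Linv with Linv = inv(L). A plain-torch numerical check of that identity (the commit builds these via linear_operator's cholesky()/inverse(), for which torch.linalg stands in here):

    import torch

    torch.manual_seed(0)
    A = torch.randn(5, 5)
    C = A @ A.T + 5 * torch.eye(5)  # toy SPD covariance

    L = torch.linalg.cholesky(C)    # lower triangular, C = L @ L.T
    Linv = torch.linalg.inv(L)      # stand-in for chol.inverse().to_dense()
    C_inv = Linv.T @ Linv           # inv(L @ L.T) = Linv.T @ Linv

    assert torch.allclose(C_inv @ C, torch.eye(5), atol=1e-4)

    # once cached, xc @ C_inv replaces nd.C_oo_chol.solve(xc.T).T
    xc = torch.randn(10, 5)
    assert torch.allclose(xc @ C_inv, torch.linalg.solve(C, xc.T).T, atol=1e-4)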
@@ -565,6 +572,7 @@ def get_neighborhood_data(
                 have_missing=have_missing,
                 C_oo=C_oo,
                 C_oo_chol=C_oo_chol,
+                C_oo_inv=C_oo_inv,
                 w=w,
                 w_norm=w / ess,
                 x=x,

src/dartsort/cluster/refine.py

Lines changed: 5 additions & 2 deletions
@@ -71,8 +71,11 @@ def refine_clustering(
     gmm.cleanup()
     for it in range(refinement_config.n_total_iters):
         log_liks = gmm.em()
-        gmm.split()
-        log_liks = gmm.em()
+        if log_liks.shape[0] > refinement_config.max_avg_units * recording.get_num_channels():
+            print(f"{log_liks.shape=}, skipping split.")
+        else:
+            gmm.split()
+            log_liks = gmm.em()
         gmm.merge(log_liks)
     gmm.em(final_split="full")
     gmm.cpu()
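Note: this guard caps the mixture at max_avg_units units per recording channel before attempting another split round. For example, on a hypothetical 384-channel probe with the default max_avg_units = 8, splitting is skipped once log_liks has more than 3072 rows (one per unit, plus a noise row when present), and the iteration proceeds straight to merging.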

src/dartsort/cluster/stable_features.py

Lines changed: 3 additions & 2 deletions
@@ -764,13 +764,14 @@ def occupied_chans(
         neighborhood_ids = spike_data.neighborhood_ids
     ids = torch.unique(neighborhood_ids)
     chans = neighborhoods.neighborhoods[ids]
-    chans = torch.unique(chans)
+    chans, counts = torch.unique(chans, return_counts=True)
+    counts = counts[chans < n_channels]
     chans = chans[chans < n_channels]
     for _ in range(fuzz):
         chans = neighborhoods.channel_index[chans]
         chans = torch.unique(chans)
         chans = chans[chans < n_channels]
-    return chans
+    return chans, counts


 def interp_to_chans(
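Note: occupied_chans now also returns, for each surviving channel, how many of the unique spike neighborhoods include it; masking counts while chans still holds the full unique set keeps the two aligned, and the counts describe the pre-fuzz channel set. A toy run of the unique-and-mask step, with 4 standing in for n_channels and the out-of-range value acting as padding, as the chans < n_channels filter implies:

    import torch

    n_channels = 4
    # toy neighborhood-channel rows; the value 4 (== n_channels) is padding
    neighborhoods = torch.tensor([[0, 1, 4], [1, 2, 4]])

    chans, counts = torch.unique(neighborhoods, return_counts=True)
    counts = counts[chans < n_channels]  # mask counts first, while chans is intact
    chans = chans[chans < n_channels]
    print(chans)   # tensor([0, 1, 2])
    print(counts)  # tensor([1, 2, 1])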

src/dartsort/util/internal_config.py

Lines changed: 1 addition & 0 deletions
@@ -325,6 +325,7 @@ class RefinementConfig:
     interpolation_sigma: float = 20.0
     val_proportion: float = 0.25
     max_n_spikes: float | int = argfield(default=4_000_000, arg_type=int_or_inf)
+    max_avg_units: int = 8

     # model params
     channels_strategy: str = "count"
