Commit 932a8b6

PPCA parens
1 parent 0caf553 commit 932a8b6

2 files changed: +52 -34 lines changed


src/dartsort/cluster/gaussian_mixture.py

Lines changed: 0 additions & 4 deletions
@@ -453,10 +453,6 @@ def m_step(
 
         if self.use_proportions and likelihoods is not None:
             self.update_proportions(likelihoods)
-            if self.log_proportions is not None:
-                assert (
-                    len(self.log_proportions) == unit_ids.max() + 1 + self.with_noise_unit
-                )
 
         fit_full_indices, fit_split_indices = quick_indices(
            self.rg,

src/dartsort/cluster/ppcalib.py

Lines changed: 52 additions & 30 deletions
@@ -2,6 +2,7 @@
 from typing import Optional
 
 import linear_operator
+from linear_operator import operators
 from linear_operator.operators import CholLinearOperator
 import torch
 import torch.nn.functional as F
@@ -10,7 +11,7 @@
 
 from ..util.noise_util import EmbeddedNoise
 from .stable_features import SpikeFeatures, SpikeNeighborhoods
-from ..util import spiketorch
+from ..util import spiketorch, more_operators
 
 vecdot = torch.linalg.vecdot
 
@@ -29,7 +30,7 @@ def ppca_em(
     mean_prior_pseudocount=0.0,
     show_progress=False,
     W_initialization="svd",
-    normalize=False,
+    normalize=True,
     em_converged_atol=0.1,
     prior_var=1.0,
     cache_global_direct=True,
@@ -145,6 +146,7 @@ def ppca_em(
             break
         if show_progress:
             iters.set_description(f"PPCA[{dmu=:.2g}, {dW=:.2g}]")
+            print(i, dmu, dW)
 
     if normalize and any_missing and state["W"] is not None:
         _, _, state["W"], state["mu"] = embed(
@@ -156,6 +158,7 @@ def ppca_em(
            state["W"],
            state["mu"],
            active_channels=active_channels,
+           ess=ess,
            active_cov_chol_factor=active_cov_chol_factor,
            prior_var=prior_var,
            normalize=normalize,
@@ -216,18 +219,18 @@ def ppca_e_step(
     D = rank * nc
 
     # get normalized weights
-    y = sp.features
+    new_zeros = sp.features.new_zeros
 
     # we will build our outputs by iterating over the unique
     # neighborhoods and adding weighted sums of moments in each
-    e_y = y.new_zeros((rank, nc))
+    e_y = new_zeros((rank, nc))
     yc = e_u = e_ycu = e_uu = None
     if return_yc:
-        yc = y.new_zeros((n, rank, nc))
+        yc = new_zeros((n, rank, nc))
     if yes_pca:
-        e_u = y.new_zeros((M,))
-        e_ycu = y.new_zeros((rank, nc, M))
-        e_uu = y.new_zeros((M, M))
+        e_u = new_zeros((M,))
+        e_ycu = new_zeros((rank, nc, M))
+        e_uu = new_zeros((M, M))
 
     # helpful tensors to keep around
     if yes_pca:
@@ -257,8 +260,8 @@ def ppca_e_step(
            W_m = active_W[:, nd.missing_subset].reshape(D - nd.D_neighb, M)
 
        if yes_pca:
-           ubar = full_ubar[nd.neighb_members]
-           uubar = full_uubar[nd.neighb_members]
+           ubar = full_ubar[nd.u_slice]
+           uubar = full_uubar[nd.u_slice]
 
        # actual data in neighborhood
        xcc = xc = nd.x - nu
@@ -312,7 +315,7 @@ def ppca_e_step(
        wx = nd.w_norm @ nd.x
        if nd.have_missing:
            wxbar_m = nd.w_norm @ xbar_m
-           ybar = y.new_zeros((rank, nc))
+           ybar = new_zeros((rank, nc))
            ybar[:, nd.active_subset] = wx.view(rank, nd.neighb_nc)
            ybar[:, nd.missing_subset] = wxbar_m.view(rank, nc - nd.neighb_nc)
        else:
@@ -324,7 +327,7 @@ def ppca_e_step(
        if nd.have_missing and yes_pca:
            wmxcu = nd.w_norm @ e_mxcu.reshape(nd.neighb_n_spikes, -1)
            wmxcu = wmxcu.view(e_mxcu.shape[1:])
-           ycubar = y.new_zeros((rank, nc, M))
+           ycubar = new_zeros((rank, nc, M))
            ycubar[:, nd.active_subset] = wxcu.view(rank, nd.neighb_nc, M)
            ycubar[:, nd.missing_subset] = wmxcu.view(rank, nc - nd.neighb_nc, M)
        elif yes_pca:
@@ -334,12 +337,12 @@ def ppca_e_step(
        if return_yc:
            if nd.have_missing:
                xc = xc.view(nd.neighb_n_spikes, rank, nd.neighb_nc).mT
-               yc[nd.neighb_members[:, None], :, nd.active_subset[None, :]] = xc
+               yc[nd.u_slice][:, :, nd.active_subset[None, :]] = xc
                xbar_m -= tnu
                txc = xbar_m.view(nd.neighb_n_spikes, rank, nc - nd.neighb_nc).mT
-               yc[nd.neighb_members[:, None], :, nd.missing_subset[None, :]] = txc
+               yc[nd.u_slice][:, :, nd.missing_subset[None, :]] = txc
            else:
-               yc[nd.neighb_members] = xc.view(nd.neighb_n_spikes, rank, nd.neighb_nc)
+               yc[nd.u_slice] = xc.view(nd.neighb_n_spikes, rank, nd.neighb_nc)
 
        # accumulate results
        e_y += ybar
@@ -367,14 +370,18 @@ def embed(
    scratch=None,
 ):
    N = len(sp)
+   new_zeros = sp.features.new_zeros
+   device = sp.features.device
+   dtype = sp.features.dtype
+
    if scratch is not None:
        _ubar, _uubar = scratch
    else:
-       _ubar = sp.features.new_zeros((N, M))
+       _ubar = features.new_zeros((N, M))
        # if not normalize:
-       _uubar = sp.features.new_zeros(N, M, M)
-       # _T = sp.features.new_zeros((N, M, M))
-   eye_M = prior_var * torch.eye(M, device=sp.features.device, dtype=sp.features.dtype)
+       _uubar = features.new_zeros(N, M, M)
+   eye_M_ = torch.eye(M, device=device, dtype=dtype)
+   eye_M = prior_var * eye_M_
 
    for nd in neighb_data:
        nu = active_mean[:, nd.active_subset].reshape(nd.D_neighb)
@@ -388,14 +395,22 @@ def embed(
        # moments of embeddings
        # T_inv = eye_M + W_o.T @ nd.C_oo_chol.solve(W_o)
        T_inv = eye_M + W_o.T @ nd.C_oo_inv @ W_o
-       T = torch.linalg.inv(T_inv)
-       u_proj = nd.C_oo_inv @ W_o @ T
+       # root = operators.LowRankRootLinearOperator(W_o.T @ nd.C_oo_cholinv)
+       # print(f"{root.shape=} {I_M.shape=}")
+       # helper = root + I_M
+       # helper = operators.LowRankRootSumLinearOperator(I_M
+       # print(f"{T_inv.shape=}")
+       # T = helper.solve(eye_M_)
+       T, info = torch.linalg.inv_ex(T_inv)
+       u_proj = nd.C_oo_inv @ (W_o @ T)
        # ubar = Cooinvxc @ (W_o @ T)
-       ubar = xc @ u_proj
-       uubar = torch.baddbmm(T, ubar[:, :, None], ubar[:, None, :])
+       # ubar = xc @ u_proj
+       # uubar = torch.baddbmm(T, ubar[:, :, None], ubar[:, None, :])
 
-       _ubar[nd.neighb_members] = ubar
-       _uubar[nd.neighb_members] = uubar
+       # _ubar[nd.u_slice] = ubar
+       # _uubar[nd.u_slice] = uubar
+       torch.mm(xc, u_proj, out=_ubar[nd.u_slice])
+       torch.baddbmm(T, _ubar[nd.u_slice].unsqueeze(2), _ubar[nd.u_slice].unsqueeze(1), out=_uubar[nd.u_slice])
 
    if normalize:
        if active_cov_chol_factor is None:
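For reference, the loop above computes what looks like the standard PPCA E-step posterior moments of the latent coordinates given the observed channels. Below is a minimal dense-tensor sketch of that update, with made-up shapes and dummy stand-ins for `W_o`, `nd.C_oo_inv`, and the centered features `xc`:

```python
import torch

# Hypothetical dense stand-ins for one neighborhood, not the dartsort objects.
torch.manual_seed(0)
N, D_o, M = 8, 12, 3                    # spikes, observed feature dims, latent dims
prior_var = 1.0

W_o = torch.randn(D_o, M)               # loadings restricted to observed dims
A = torch.randn(D_o, D_o)
C_oo = A @ A.T + D_o * torch.eye(D_o)   # SPD noise covariance on observed dims
C_oo_inv = torch.linalg.inv(C_oo)
xc = torch.randn(N, D_o)                # centered observed features, one row per spike

# T^{-1} = prior_var * I + W_o^T C_oo^{-1} W_o, as in the hunk above.
T_inv = prior_var * torch.eye(M) + W_o.T @ C_oo_inv @ W_o
T, info = torch.linalg.inv_ex(T_inv)

# E[u | x] row-wise: ubar = xc @ (C_oo^{-1} W_o T), same grouping as u_proj above.
u_proj = C_oo_inv @ (W_o @ T)
ubar = xc @ u_proj                       # (N, M)

# E[u u^T | x] = T + E[u] E[u]^T, which the diff writes into _uubar via baddbmm(..., out=...).
uubar = torch.baddbmm(T, ubar.unsqueeze(2), ubar.unsqueeze(1))   # (N, M, M)
print(ubar.shape, uubar.shape)
```

The added parentheses in `u_proj = nd.C_oo_inv @ (W_o @ T)` change only the grouping of the same product, and `torch.linalg.inv_ex` returns the inverse together with an info code rather than raising on singular input.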
@@ -412,14 +427,14 @@ def embed(
        # active_mean = active_mean + W @ um
 
        # whitening. need to do a GEVP to start...
-       S = (weights @ _uubar.view(N, M * M)).view(N, M, M)
+       S = (weights @ _uubar.view(N, M * M)).view(M, M)
        Dx, U = torch.linalg.eigh(S)
        Dx = Dx.flip(dims=(0,))
        U = U.flip(dims=(1,))
        U.mul_(sgn(U[0]))
        UDxrt = U * Dx.sqrt()
        rhs = Wflat @ UDxrt.T
-       gevp_W_right = torch.linalg.solve_triangular(active_cov_chol_factor, rhs)
+       gevp_W_right = torch.linalg.solve_triangular(active_cov_chol_factor, rhs, upper=False)
        gevp_W = gevp_W_right.T @ gevp_W_right
        # gevp_W = linear_operator.solve(lhs=rhs.T, input=active_cov, rhs=rhs)
        Dw, V = torch.linalg.eigh(gevp_W)
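The added `upper=False` tells `torch.linalg.solve_triangular`, where `upper` is keyword-only, to treat `active_cov_chol_factor` as a lower-triangular Cholesky factor. A standalone sketch with made-up stand-ins for the covariance and `rhs`, showing that the squared solve reproduces the commented-out `rhs.T @ cov^{-1} @ rhs` form:

```python
import torch

# Hypothetical stand-ins: an SPD "active covariance" and a right-hand side like `rhs`.
torch.manual_seed(0)
D, K = 6, 4
A = torch.randn(D, D)
cov = A @ A.T + D * torch.eye(D)
L = torch.linalg.cholesky(cov)           # lower-triangular factor, like active_cov_chol_factor
rhs = torch.randn(D, K)

# Solve L X = rhs; upper=False marks L as lower triangular.
X = torch.linalg.solve_triangular(L, rhs, upper=False)

# X^T X == rhs^T cov^{-1} rhs, which is what gevp_W = gevp_W_right.T @ gevp_W_right computes.
gevp_W = X.T @ X
direct = rhs.T @ torch.linalg.solve(cov, rhs)
print(torch.allclose(gevp_W, direct, atol=1e-4))
```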
@@ -435,7 +450,8 @@ def embed(
        W @= W_tf
        _ubar @= u_tf
        _uubar = torch.einsum("nij,ip,jq->npq", _uubar, u_tf, u_tf)
-       active_mean.addmm_(W, um)
+       active_mean += W @ um
+       # .addmm_(W.view(-1, M), um.unsqueeze(1))
 
    return _ubar, _uubar, W, active_mean
 
@@ -449,11 +465,13 @@ class NeighborhoodPPCAData:
 
    C_oo: linear_operator.LinearOperator
    C_oo_chol: CholLinearOperator
+   C_oo_cholinv: torch.Tensor
    C_oo_inv: CholLinearOperator
    w: torch.Tensor
    w_norm: torch.Tensor
    x: torch.Tensor
    neighb_members: torch.Tensor
+   u_slice: torch.Tensor
 
    C_mo: Optional[torch.Tensor]
    active_subset: Optional[torch.Tensor]
@@ -516,14 +534,15 @@ def get_neighborhood_data(
 
    neighborhood_data = []
    ess = weights.sum()
+   n_start = 0
    for chans_tuple, chans_data in dedup_data.items():
        *info, xs, mems = chans_data
        nid, neighb_chans, active_subset, can_cache_by_neighborhood, have_missing = info
        if len(mems) > 1:
            x = torch.concatenate(xs)
            neighb_members = torch.concatenate(mems)
-           neighb_members, order = neighb_members.sort()
-           x = x[order]
+           # neighb_members, order = neighb_members.sort()
+           # x = x[order]
            nid = None
        else:
            x = xs[0]
@@ -578,16 +597,19 @@ def get_neighborhood_data(
            have_missing=have_missing,
            C_oo=C_oo,
            C_oo_chol=C_oo_chol,
+           C_oo_cholinv=Linv,
            C_oo_inv=C_oo_inv,
            w=w,
            w_norm=w / ess,
            x=x,
            neighb_members=neighb_members,
+           u_slice=slice(n_start, n_start + n_neighb),
            C_mo=C_mo,
            active_subset=active_subset,
            missing_subset=missing_subset,
        )
        neighborhood_data.append(nd)
+       n_start += n_neighb
 
    return neighborhood_data
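The new `u_slice` bookkeeping, with `n_start` advancing by `n_neighb` per neighborhood, gives each neighborhood a contiguous row range in the per-spike buffers. That is what lets the `out=_ubar[nd.u_slice]` writes in `embed` land in the preallocated tensor: a basic slice is a view sharing storage, while indexing with a `neighb_members` tensor hands `out=` a copy. A small sketch of that distinction, with dummy shapes and hypothetical names:

```python
import torch

# Dummy buffers: 5 spikes total, latent dim 3; one neighborhood owns rows 1:3.
full_ubar = torch.zeros(5, 3)
xc = torch.randn(2, 4)
u_proj = torch.randn(4, 3)

# Contiguous slice -> view: the matmul result is written straight into full_ubar.
u_slice = slice(1, 3)
torch.mm(xc, u_proj, out=full_ubar[u_slice])
print(bool(full_ubar[1:3].abs().sum() > 0))    # True: rows 1:3 were filled in place

# Index-tensor ("fancy") indexing -> copy: the same pattern writes into a temporary.
members = torch.tensor([3, 4])
torch.mm(xc, u_proj, out=full_ubar[members])
print(bool(full_ubar[3:5].abs().sum() == 0))   # True: full_ubar rows 3:5 stay zero
```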