Skip to content

Commit 1b11ab1

Browse files
committed
Needs chol wrapper
1 parent 7bcf40e commit 1b11ab1

File tree

1 file changed

+6
-5
lines changed

1 file changed

+6
-5
lines changed

src/dartsort/cluster/ppcalib.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import warnings
22

33
import linear_operator
4+
from linear_operator.operators import CholLinearOperator
45
import torch
56
import torch.nn.functional as F
67
from tqdm.auto import trange
@@ -251,8 +252,8 @@ def ppca_e_step(
251252
C_oo = noise.marginal_covariance(
252253
channels=neighb_chans, cache_prefix=cache_prefix, cache_key=nid, device=y.device
253254
)
254-
# TODO: genuinely confused about the need for this.
255-
C_oochol = C_oo.cholesky()
255+
# TODO: genuinely confused about the need for this. why doesn't solve() use this cached object?
256+
C_oochol = CholLinearOperator(C_oo.cholesky())
256257
nu = active_mean[:, active_subset].reshape(D_neighb)
257258
if have_missing:
258259
C_mo = noise.offdiag_covariance(
@@ -283,9 +284,9 @@ def ppca_e_step(
283284
# pca-centered data
284285
if yes_pca and have_missing:
285286
CooinvWo = C_oochol.solve(W_o)
286-
# xcc = torch.addmm(xc, ubar, W_o.T, alpha=-1)
287-
# Cooinvxcc = C_oochol.solve(xcc.T).T
288-
Cooinvxcc = Cooinvxc.addmm(ubar, CooinvWo.T, alpha=-1)
287+
xcc = torch.addmm(xc, ubar, W_o.T, alpha=-1)
288+
Cooinvxcc = C_oochol.solve(xcc.T).T
289+
# Cooinvxcc = Cooinvxc.addmm(ubar, CooinvWo.T, alpha=-1)
289290
else:
290291
Cooinvxcc = Cooinvxc
291292

0 commit comments

Comments
 (0)