|
1 | 1 | import warnings
|
2 | 2 |
|
3 | 3 | import linear_operator
|
| 4 | +from linear_operator.operators import CholLinearOperator |
4 | 5 | import torch
|
5 | 6 | import torch.nn.functional as F
|
6 | 7 | from tqdm.auto import trange
|
@@ -251,8 +252,8 @@ def ppca_e_step(
|
251 | 252 | C_oo = noise.marginal_covariance(
|
252 | 253 | channels=neighb_chans, cache_prefix=cache_prefix, cache_key=nid, device=y.device
|
253 | 254 | )
|
254 |
| - # TODO: genuinely confused about the need for this. |
255 |
| - C_oochol = C_oo.cholesky() |
| 255 | + # TODO: genuinely confused about the need for this. why doesn't solve() use this cached object? |
| 256 | + C_oochol = CholLinearOperator(C_oo.cholesky()) |
256 | 257 | nu = active_mean[:, active_subset].reshape(D_neighb)
|
257 | 258 | if have_missing:
|
258 | 259 | C_mo = noise.offdiag_covariance(
|
@@ -283,9 +284,9 @@ def ppca_e_step(
|
283 | 284 | # pca-centered data
|
284 | 285 | if yes_pca and have_missing:
|
285 | 286 | CooinvWo = C_oochol.solve(W_o)
|
286 |
| - # xcc = torch.addmm(xc, ubar, W_o.T, alpha=-1) |
287 |
| - # Cooinvxcc = C_oochol.solve(xcc.T).T |
288 |
| - Cooinvxcc = Cooinvxc.addmm(ubar, CooinvWo.T, alpha=-1) |
| 287 | + xcc = torch.addmm(xc, ubar, W_o.T, alpha=-1) |
| 288 | + Cooinvxcc = C_oochol.solve(xcc.T).T |
| 289 | + # Cooinvxcc = Cooinvxc.addmm(ubar, CooinvWo.T, alpha=-1) |
289 | 290 | else:
|
290 | 291 | Cooinvxcc = Cooinvxc
|
291 | 292 |
|
|
0 commit comments