Skip to content

Commit

Permalink
PPCA perf
Browse files: browse the repository at this point in the history
  • Loading branch information
cwindolf committed Jan 15, 2025
1 parent cec7743 commit 0caf553
Show file tree
Hide file tree
Showing 2 changed files with 37 additions and 26 deletions.
58 changes: 32 additions & 26 deletions src/dartsort/cluster/ppcalib.py
Original file line number Diff line number Diff line change
Expand Up @@ -261,43 +261,47 @@ def ppca_e_step(
uubar = full_uubar[nd.neighb_members]

# actual data in neighborhood
xc = nd.x - nu
xcc = xc = nd.x - nu
if yes_pca:
Woubar = ubar @ W_o.T
if nd.have_missing:
Wmubar = ubar @ W_m.T
xcc = xcc - Woubar

# we need these ones everywhere
# Cooinvxc = nd.C_oo_chol.solve(xc.T).T
Cooinvxc = xc @ nd.C_oo_inv

# pca-centered data
if yes_pca and nd.have_missing:
CooinvWo = nd.C_oo_chol.solve(W_o)
CooinvWo = nd.C_oo_inv @ W_o
# xcc = torch.addmm(xc, ubar, W_o.T, alpha=-1)
# Cooinvxcc = C_oochol.solve(xcc.T).T
Cooinvxcc = Cooinvxc.addmm(ubar, CooinvWo.T, alpha=-1)
else:
Cooinvxcc = Cooinvxc
# Cooinvxc = xc @ nd.C_oo_inv

# # pca-centered data
# if yes_pca and nd.have_missing:
# CooinvWo = nd.C_oo_chol.solve(W_o)
# CooinvWo = nd.C_oo_inv @ W_o
# # xcc = torch.addmm(xc, ubar, W_o.T, alpha=-1)
# # Cooinvxcc = C_oochol.solve(xcc.T).T
# Cooinvxcc = Cooinvxc.addmm(ubar, CooinvWo.T, alpha=-1)
# else:
# Cooinvxcc = Cooinvxc

# first data moment
if nd.have_missing:
xbar_m = torch.addmm(tnu, Cooinvxcc, nd.C_mo.T)
CooinvCom = nd.C_oo_inv @ nd.C_mo.T
xbar_m = torch.addmm(tnu, xcc, CooinvCom)
if yes_pca:
xbar_m.addmm_(ubar, W_m.T)
xbar_m.add_(Wmubar)

# cross moment
if yes_pca:
e_xcu = xc[:, :, None] * ubar[:, None, :]
if yes_pca and nd.have_missing:
# e_mxcu = (Cooinvxc @ nd.C_mo.T)[:, :, None] * ubar[:, None, :]
# print(f"{e_mxcu.shape=}")
e_mxcu = torch.einsum("ij,kj,il->ikl", Cooinvxc, nd.C_mo, ubar)
# CmoCooinvWo = C_mo @ CooinvWo
# R tilde.
CooinvWo = nd.C_oo_inv @ W_o
Wm_less_CmoCooinvWo = W_m.addmm(nd.C_mo, CooinvWo, beta=-1)
shp = Wm_less_CmoCooinvWo.shape
# e_mxcu += (uubar @ (W_m - CmoCooinvWo).T).mT
Wm_less_CmoCooinvWo = Wm_less_CmoCooinvWo.unsqueeze(0)
Wm_less_CmoCooinvWo = Wm_less_CmoCooinvWo.broadcast_to((len(uubar), *shp))
# e_mxcu += uubar.mT @ Wm_less_CmoCooinvWo
e_mxcu.baddbmm_(Wm_less_CmoCooinvWo, uubar)
coefts = Wm_less_CmoCooinvWo[None].broadcast_to(
(len(uubar), *Wm_less_CmoCooinvWo.shape)
)
e_mxcu = torch.bmm(coefts, uubar)
coefts = CooinvCom.T[None].broadcast_to((len(e_mxcu), *CooinvCom.T.shape))
e_mxcu.baddbmm_(coefts, e_xcu)

# take weighted averages
if yes_pca:
Expand Down Expand Up @@ -379,13 +383,15 @@ def embed(

# we need these ones everywhere
# Cooinvxc = nd.C_oo_chol.solve(xc.T).T
Cooinvxc = xc @ nd.C_oo_inv
# Cooinvxc = xc @ nd.C_oo_inv

# moments of embeddings
# T_inv = eye_M + W_o.T @ nd.C_oo_chol.solve(W_o)
T_inv = eye_M + W_o.T @ nd.C_oo_inv @ W_o
T = torch.linalg.inv(T_inv)
ubar = Cooinvxc @ (W_o @ T)
u_proj = nd.C_oo_inv @ W_o @ T
# ubar = Cooinvxc @ (W_o @ T)
ubar = xc @ u_proj
uubar = torch.baddbmm(T, ubar[:, :, None], ubar[:, None, :])

_ubar[nd.neighb_members] = ubar
Expand Down
5 changes: 5 additions & 0 deletions tests/test_modeling.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,3 +113,8 @@ def test_mixture():
# mixture_testing_util.test_moppcas(**kw)

# assert False


# Script entry point: when this test module is executed directly (rather
# than imported), call test_ppca and then test_mixture, in that order.
if __name__ == "__main__":
    test_ppca()
    test_mixture()

0 comments on commit 0caf553

Please sign in to comment.