tem: Bit by linop again.
cwindolf committed Feb 18, 2025
1 parent 487264d · commit 4e67665
Showing 1 changed file with 5 additions and 5 deletions.
src/dartsort/cluster/_truncated_em_helpers.py — 5 additions, 5 deletions
@@ -536,15 +536,15 @@ def woodbury_kl_divergence(C, mu, W=None, out=None, batch_size=32, affine_ok=Tru
     out = mu.new_empty((n, n))
     out.fill_(0.0)

-    Cchol = C.cholesky()#.to_dense()
+    Cchol = C.cholesky()  # .to_dense()
     # Some weird issues with solve_triangular using the batched input...
     # it seems to allocate something big?
     # Ccholinv = torch.linalg.solve_triangular(
     #     Cchol, torch.eye(len(Cchol), out=torch.empty_like(Cchol)), upper=False
     # )
-    Ccholinv = Cchol.inverse()
+    Ccholinv = Cchol.inverse().to_dense()[None].broadcast_to((n, *Cchol.shape))
     # mu_ = torch.linalg.solve_triangular(Cchol, mu.unsqueeze(2), upper=False)
-    mu_ = Ccholinv @ mu.unsqueeze(2)
+    mu_ = Ccholinv.bmm(mu.unsqueeze(2))

     if W is None:
         # else, better to do this later
@@ -556,7 +556,7 @@ def woodbury_kl_divergence(C, mu, W=None, out=None, batch_size=32, affine_ok=Tru
     M = W.shape[2]
     assert W.shape == (n, d, M)
     # U = torch.linalg.solve_triangular(Cchol, W, upper=False)
-    U = Ccholinv @ W
+    U = Ccholinv.bmm(W)

     # first part of trace
     UTU = U.mT.bmm(U)
@@ -570,7 +570,7 @@ def woodbury_kl_divergence(C, mu, W=None, out=None, batch_size=32, affine_ok=Tru
     cap.diagonal(dim1=-2, dim2=-1).add_(1.0)
     capchol = psd_safe_cholesky(cap)
     # V is (n, d, M)
-    V = torch.linalg.solve_triangular(capchol, U, upper=False, left=False)
+    V = torch.linalg.solve_triangular(capchol.mT, U, upper=True, left=False)

     # log dets via Cholesky
     logdet = capchol.diagonal(dim1=-2, dim2=-1).log().sum(dim=1).mul_(2.0)
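For context on the change: the diff densifies the (d, d) inverse Cholesky factor, broadcasts it to an explicit (n, d, d) batch, and swaps broadcasting @ for bmm, then reads the capacitance log-determinant off a Cholesky diagonal. Below is a minimal dense-tensor sketch of that pattern; the shapes n, d, M, the random inputs, and the use of torch.linalg.cholesky in place of the repository's psd_safe_cholesky and linear_operator objects are illustrative assumptions, not the project's actual setup.

import torch

# Illustrative shapes mirroring the hunks above: n units, d features, M low-rank columns.
n, d, M = 4, 6, 3

# Dense stand-in for the noise covariance C and its Cholesky factor. In the
# repository C is a linear operator, hence the .to_dense() in the diff.
A = torch.randn(d, d)
C = A @ A.mT + d * torch.eye(d)
Cchol = torch.linalg.cholesky(C)

mu = torch.randn(n, d)
W = torch.randn(n, d, M)

# Densify the (d, d) inverse factor and broadcast it to an explicit (n, d, d)
# batch so that bmm can be used instead of broadcasting matmul, as in the diff.
Ccholinv = Cchol.inverse()[None].broadcast_to((n, d, d))
mu_ = Ccholinv.bmm(mu.unsqueeze(2))  # (n, d, 1), i.e. L^{-1} mu
U = Ccholinv.bmm(W)                  # (n, d, M), i.e. L^{-1} W

# Capacitance matrix I + U^T U and its Cholesky factor (plain
# torch.linalg.cholesky stands in for psd_safe_cholesky here).
cap = U.mT.bmm(U)
cap.diagonal(dim1=-2, dim2=-1).add_(1.0)
capchol = torch.linalg.cholesky(cap)

# Right-sided solve against the transposed factor: V = U @ capchol^{-T}.
V = torch.linalg.solve_triangular(capchol.mT, U, upper=True, left=False)

# log det(I + W^T C^{-1} W) from the Cholesky diagonal.
logdet = capchol.diagonal(dim1=-2, dim2=-1).log().sum(dim=1).mul_(2.0)
print(V.shape, logdet.shape)  # (n, d, M) and (n,)

With C = L Lᵀ and U = L⁻¹ W, the matrix I + Uᵀ U is the Woodbury capacitance matrix, so log det(C + W Wᵀ) = log det C + log det(I + Uᵀ U), which is what the last line recovers from the Cholesky diagonal. The solve against capchol.mT on the right yields V = U capchol⁻ᵀ rather than U capchol⁻¹, which is what the final hunk's replacement line computes.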
