Skip to content

Commit

Permalink
tem: Debug chols
Browse files Browse the repository at this point in the history
  • Loading branch information
cwindolf committed Feb 18, 2025
1 parent ecc4afc commit 487264d
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 5 deletions.
14 changes: 11 additions & 3 deletions src/dartsort/cluster/_truncated_em_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -536,8 +536,15 @@ def woodbury_kl_divergence(C, mu, W=None, out=None, batch_size=32, affine_ok=Tru
out = mu.new_empty((n, n))
out.fill_(0.0)

Cchol = C.cholesky().to_dense()
mu_ = torch.linalg.solve_triangular(Cchol, mu.unsqueeze(2), upper=False)
Cchol = C.cholesky()#.to_dense()
# Some weird issues with solve_triangular using the batched input...
# it seems to allocate something big?
# Ccholinv = torch.linalg.solve_triangular(
# Cchol, torch.eye(len(Cchol), out=torch.empty_like(Cchol)), upper=False
# )
Ccholinv = Cchol.inverse()
# mu_ = torch.linalg.solve_triangular(Cchol, mu.unsqueeze(2), upper=False)
mu_ = Ccholinv @ mu.unsqueeze(2)

if W is None:
# else, better to do this later
Expand All @@ -548,7 +555,8 @@ def woodbury_kl_divergence(C, mu, W=None, out=None, batch_size=32, affine_ok=Tru

M = W.shape[2]
assert W.shape == (n, d, M)
U = torch.linalg.solve_triangular(Cchol, W, upper=False)
# U = torch.linalg.solve_triangular(Cchol, W, upper=False)
U = Ccholinv @ W

# first part of trace
UTU = U.mT.bmm(U)
Expand Down
4 changes: 2 additions & 2 deletions src/dartsort/cluster/truncated_mixture.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,13 +29,13 @@ def __init__(
noise: EmbeddedNoise,
M: int = 0,
n_candidates: int = 3,
n_search: int = None,
n_search: int = 5,
n_explore: int = None,
covariance_radius: Optional[float] = 500.0,
random_seed=0,
n_threads: int = 0,
batch_size=2048,
exact_kl=False,
exact_kl=True,
):
super().__init__()
if n_search is None:
Expand Down

0 comments on commit 487264d

Please sign in to comment.