Skip to content

Commit

Permalink
package: 3.10 compat
Browse files Browse the repository at this point in the history
  • Loading branch information
cwindolf committed Feb 21, 2025
1 parent c028ab3 commit c0ff3ea
Showing 1 changed file with 4 additions and 3 deletions.
7 changes: 4 additions & 3 deletions src/dartsort/cluster/gaussian_mixture.py
Original file line number Diff line number Diff line change
Expand Up @@ -422,6 +422,7 @@ def tem(
)
del tmm
import gc

print(f"done setting units")

gc.collect()
Expand Down Expand Up @@ -1894,7 +1895,7 @@ def merge_criteria(

data = full_logliks_sp.values()
full_logliks = data.new_full(full_logliks_sp.shape, -torch.inf)
full_logliks[*full_logliks_sp.indices()] = data
full_logliks[full_logliks_sp.indices()] = data
full_logliks = full_logliks[unit_ids].sub_(prop_correction)

lik_weights = torch.sparse.softmax(full_logliks_sp, dim=dim_units)
Expand Down Expand Up @@ -2079,7 +2080,7 @@ def get_log_likelihoods(
inds = liks.indices()
data = liks.values()
liks = data.new_full(liks.shape, -torch.inf)
liks[*inds] = data
liks[inds] = data

return liks

Expand Down Expand Up @@ -2469,7 +2470,7 @@ def from_parameters(
if channels is not None:
channels = torch.asarray(channels)
else:
channels = (mean.square().sum(dim=0).sqrt() > channels_amp)
channels = mean.square().sum(dim=0).sqrt() > channels_amp
(channels,) = channels.nonzero(as_tuple=True)
self.register_buffer("channels", channels.to(mean.device))
# TODO: neeed for vis. figure it out.
Expand Down

0 comments on commit c0ff3ea

Please sign in to comment.