Skip to content

Commit 68dca10

Browse files
committed
np<2 compat
1 parent 457576c commit 68dca10

File tree

1 file changed

+2
-0
lines changed

1 file changed

+2
-0
lines changed

src/dartsort/cluster/gaussian_mixture.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1879,6 +1879,8 @@ def get_log_likelihoods(
18791879
# torch's index_select is painfully slow
18801880
# weights = torch.index_select(likelihoods, 1, features.indices)
18811881
# here we have weights as a csc_array
1882+
if torch.is_tensor(indices):
1883+
indices = indices.numpy(force=True)
18821884
liks = likelihoods[:, indices]
18831885
if unit_ids is not None:
18841886
liks = liks[unit_ids]

0 commit comments

Comments
 (0)