Skip to content

Commit 4b64155

Browse files
committed
Device bugs
1 parent 1c1dc58 commit 4b64155

File tree

1 file changed

+3
-2
lines changed

1 file changed

+3
-2
lines changed

src/dartsort/cluster/stable_features.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -723,13 +723,14 @@ def spike_neighborhoods(
723723
if spike_ids is None:
724724
spike_ids = self.neighborhood_ids[spike_indices]
725725
assert spike_ids is not None
726-
covered_ids = torch.unique(spike_ids).to(self.indicators.device)
726+
covered_ids = torch.unique(spike_ids)
727727
if min_coverage:
728+
covered_ids = covered_ids.to(self.indicators.device)
728729
inds = self.indicators[channels][:, covered_ids]
729730
coverage = inds.sum(0) / self.channel_counts[covered_ids]
730731
covered = coverage >= min_coverage
731732
covered_ids = covered_ids[covered].cpu()
732-
spike_ids = spike_ids.cpu()
733+
spike_ids = spike_ids.cpu()
733734
neighborhood_info = [
734735
(j, self.neighborhoods[j], *(spike_ids == j).nonzero(as_tuple=True), None)
735736
for j in covered_ids

0 commit comments

Comments
 (0)