Skip to content

Commit cba9086

Browse files
committed
Batch a matching internal
1 parent e82b42f commit cba9086

File tree

1 file changed

+30
-27
lines changed

1 file changed

+30
-27
lines changed

src/dartsort/peel/matching.py

Lines changed: 30 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -834,34 +834,37 @@ def subtract_conv(
834834
upsampling_indices,
835835
scalings,
836836
conv_pad_len=0,
837+
batch_size=256,
837838
):
838-
# TODO: may need to batch this.
839-
(
840-
template_indices_a,
841-
template_indices_b,
842-
times,
843-
pconvs,
844-
) = self.pairwise_conv_db.query(
845-
template_indices_a=None,
846-
template_indices_b=template_indices,
847-
upsampling_indices_b=upsampling_indices,
848-
scalings_b=scalings,
849-
times_b=times,
850-
grid=True,
851-
device=conv.device,
852-
shifts_a=self.shifts_a,
853-
shifts_b=self.shifts_b[template_indices]
854-
if self.shifts_b is not None
855-
else None,
856-
)
857-
ix_template = template_indices_a[:, None]
858-
ix_time = times[:, None] + (conv_pad_len + self.conv_lags)[None, :]
859-
spiketorch.add_at_(
860-
conv,
861-
(ix_template, ix_time),
862-
pconvs,
863-
sign=-1,
864-
)
839+
n_spikes = times.shape[0]
840+
for batch_start in range(0, n_spikes, batch_size):
841+
batch_end = min(batch_start + batch_size, n_spikes)
842+
(
843+
template_indices_a,
844+
template_indices_b,
845+
times_sub,
846+
pconvs,
847+
) = self.pairwise_conv_db.query(
848+
template_indices_a=None,
849+
template_indices_b=template_indices[batch_start:batch_end],
850+
upsampling_indices_b=upsampling_indices[batch_start:batch_end],
851+
scalings_b=scalings[batch_start:batch_end],
852+
times_b=times[batch_start:batch_end],
853+
grid=True,
854+
device=conv.device,
855+
shifts_a=self.shifts_a,
856+
shifts_b=self.shifts_b[template_indices[batch_start:batch_end]]
857+
if self.shifts_b is not None
858+
else None,
859+
)
860+
ix_template = template_indices_a[:, None]
861+
ix_time = times_sub[:, None] + (conv_pad_len + self.conv_lags)[None, :]
862+
spiketorch.add_at_(
863+
conv,
864+
(ix_template, ix_time),
865+
pconvs,
866+
sign=-1,
867+
)
865868

866869
def fine_match(
867870
self,

0 commit comments

Comments
 (0)