@@ -834,34 +834,37 @@ def subtract_conv(
834
834
upsampling_indices ,
835
835
scalings ,
836
836
conv_pad_len = 0 ,
837
+ batch_size = 256 ,
837
838
):
838
- # TODO: may need to batch this.
839
- (
840
- template_indices_a ,
841
- template_indices_b ,
842
- times ,
843
- pconvs ,
844
- ) = self .pairwise_conv_db .query (
845
- template_indices_a = None ,
846
- template_indices_b = template_indices ,
847
- upsampling_indices_b = upsampling_indices ,
848
- scalings_b = scalings ,
849
- times_b = times ,
850
- grid = True ,
851
- device = conv .device ,
852
- shifts_a = self .shifts_a ,
853
- shifts_b = self .shifts_b [template_indices ]
854
- if self .shifts_b is not None
855
- else None ,
856
- )
857
- ix_template = template_indices_a [:, None ]
858
- ix_time = times [:, None ] + (conv_pad_len + self .conv_lags )[None , :]
859
- spiketorch .add_at_ (
860
- conv ,
861
- (ix_template , ix_time ),
862
- pconvs ,
863
- sign = - 1 ,
864
- )
839
+ n_spikes = times .shape [0 ]
840
+ for batch_start in range (0 , n_spikes , batch_size ):
841
+ batch_end = min (batch_start + batch_size , n_spikes )
842
+ (
843
+ template_indices_a ,
844
+ template_indices_b ,
845
+ times_sub ,
846
+ pconvs ,
847
+ ) = self .pairwise_conv_db .query (
848
+ template_indices_a = None ,
849
+ template_indices_b = template_indices [batch_start :batch_end ],
850
+ upsampling_indices_b = upsampling_indices [batch_start :batch_end ],
851
+ scalings_b = scalings [batch_start :batch_end ],
852
+ times_b = times [batch_start :batch_end ],
853
+ grid = True ,
854
+ device = conv .device ,
855
+ shifts_a = self .shifts_a ,
856
+ shifts_b = self .shifts_b [template_indices [batch_start :batch_end ]]
857
+ if self .shifts_b is not None
858
+ else None ,
859
+ )
860
+ ix_template = template_indices_a [:, None ]
861
+ ix_time = times_sub [:, None ] + (conv_pad_len + self .conv_lags )[None , :]
862
+ spiketorch .add_at_ (
863
+ conv ,
864
+ (ix_template , ix_time ),
865
+ pconvs ,
866
+ sign = - 1 ,
867
+ )
865
868
866
869
def fine_match (
867
870
self ,
0 commit comments