@@ -30,7 +30,7 @@ def compressed_convolve_to_h5(
     geom: Optional[np.ndarray] = None,
     conv_ignore_threshold=0.0,
     coarse_approx_error_threshold=0.0,
-    conv_batch_size=1024,
+    conv_batch_size=128,
     units_batch_size=8,
     overwrite=False,
     device=None,
@@ -174,7 +174,7 @@ def iterate_compressed_pairwise_convolutions(
     amplitude_scaling_variance=0.0,
     amplitude_scaling_boundary=0.5,
     reduce_deconv_resid_norm=False,
-    conv_batch_size=1024,
+    conv_batch_size=128,
     units_batch_size=8,
     device=None,
     n_jobs=0,
@@ -401,7 +401,7 @@ def compressed_convolve_pairs(
     amplitude_scaling_boundary=0.5,
     reduce_deconv_resid_norm=False,
     max_shift="full",
-    batch_size=1024,
+    batch_size=128,
     device=None,
 ) -> Optional[CompressedConvResult]:
     """Compute compressed pairwise convolutions between template pairs
@@ -469,12 +469,14 @@ def compressed_convolve_pairs(
 
     # handle upsampling
     # each pair will be duplicated by the b unit's number of upsampled copies
+
     (
         ix_b,
         compression_index,
         conv_ix,
         conv_upsampling_indices_b,
-        conv_temporal_components_up_b,
+        conv_temporal_components_up_b,  # full array; index with conv_compressed_upsampled_ix_b downstream
+        conv_compressed_upsampled_ix_b,
         compression_dup_ix,
     ) = compressed_upsampled_pairs(
         ix_b,
@@ -491,10 +493,14 @@ def compressed_convolve_pairs(
     # run convolutions
     temporal_a = low_rank_templates_a.temporal_components[temp_ix_a]
     pconv, kept = correlate_pairs_lowrank(
-        torch.as_tensor(spatial_singular_a[ix_a[conv_ix]], device=device),
-        torch.as_tensor(spatial_singular_b[ix_b[conv_ix]], device=device),
-        torch.as_tensor(temporal_a[ix_a[conv_ix]], device=device),
+        torch.as_tensor(spatial_singular_a, device=device),
+        torch.as_tensor(spatial_singular_b, device=device),
+        torch.as_tensor(temporal_a, device=device),
         torch.as_tensor(conv_temporal_components_up_b, device=device),
+        ix_a=ix_a,
+        ix_b=ix_b,
+        conv_ix=conv_ix,
+        conv_compressed_upsampled_ix_b=conv_compressed_upsampled_ix_b,
         max_shift=max_shift,
         conv_ignore_threshold=conv_ignore_threshold,
         batch_size=batch_size,
@@ -558,9 +564,13 @@ def correlate_pairs_lowrank(
     spatial_b,
     temporal_a,
     temporal_b,
+    ix_a,
+    ix_b,
+    conv_ix,
+    conv_compressed_upsampled_ix_b,
     max_shift="full",
     conv_ignore_threshold=0.0,
-    batch_size=1024,
+    batch_size=128,
 ):
     """Convolve pairs of low rank templates
 
@@ -580,15 +590,19 @@ def correlate_pairs_lowrank(
     -------
     pconv, kept
     """
-    n_pairs, rank, nchan = spatial_a.shape
-    n_pairs_, rank_, nchan_ = spatial_b.shape
+
+    # spatial_a, spatial_b, and temporal_a are indexed by ix_a/b[conv_ix] inside the batch loop
+    _, rank, nchan = spatial_a.shape
+    _, rank_, nchan_ = spatial_b.shape
+    n_pairs = conv_ix.shape[0]
     assert rank == rank_
     assert nchan == nchan_
-    assert n_pairs == n_pairs_
-    n_pairs_, t, rank_ = temporal_a.shape
-    assert n_pairs == n_pairs_
+    # assert n_pairs == n_pairs_
+    _, t, rank_ = temporal_a.shape
+    # assert n_pairs == n_pairs_
     assert rank_ == rank
-    n_pairs_, t_, rank_ = temporal_b.shape
+    _, t_, rank_ = temporal_b.shape
+    n_pairs_ = conv_compressed_upsampled_ix_b.shape[0]
     assert n_pairs == n_pairs_
     assert t == t_
     assert rank == rank_
@@ -609,12 +623,12 @@ def correlate_pairs_lowrank(
         ix = slice(istart, iend)
 
         # want conv filter: nco, 1, rank, t
-        template_a = torch.bmm(temporal_a[ix], spatial_a[ix])
-        conv_filt = torch.bmm(spatial_b[ix], template_a.mT)
+        template_a = torch.bmm(temporal_a[ix_a[conv_ix][ix]], spatial_a[ix_a[conv_ix][ix]])
+        conv_filt = torch.bmm(spatial_b[ix_b[conv_ix][ix]], template_a.mT)
         conv_filt = conv_filt[:, None]  # (nco, 1, rank, t)
 
         # 1, nco, rank, t
-        conv_in = temporal_b[ix].mT[None]
+        conv_in = temporal_b[conv_compressed_upsampled_ix_b[ix]].mT[None]
 
         # conv2d:
         # depthwise, chans=nco. batch=1. h=rank. w=t. out: nup=1, nco, 1, 2p+1.
@@ -951,10 +965,10 @@ def compressed_upsampled_pairs(
     compression_dup_ix = slice(None)
     if up_factor == 1:
         upinds = np.zeros(len(conv_ix), dtype=int)
-        temp_comps = compressed_upsampled_temporal.compressed_upsampled_templates[
-            np.atleast_1d(temp_ix_b[ix_b[conv_ix]])
-        ]
-        return ix_b, compression_index, conv_ix, upinds, temp_comps, compression_dup_ix
+        # temp_comps = compressed_upsampled_temporal.compressed_upsampled_templates[
+        #     np.atleast_1d(temp_ix_b[ix_b[conv_ix]])
+        # ]
+        return ix_b, compression_index, conv_ix, upinds, compressed_upsampled_temporal.compressed_upsampled_templates, np.atleast_1d(temp_ix_b[ix_b[conv_ix]]), compression_dup_ix
 
     # each conv_ix needs to be duplicated as many times as its b template has
     # upsampled copies
@@ -991,18 +1005,16 @@ def compressed_upsampled_pairs(
             conv_compressed_upsampled_ix
         ]
     )
-    conv_temporal_components_up_b = (
-        compressed_upsampled_temporal.compressed_upsampled_templates[
-            conv_compressed_upsampled_ix
-        ]
-    )
+
+    # conv_temporal_components_up_b = compressed_upsampled_temporal.compressed_upsampled_templates
 
     return (
         ix_b_up,
         compression_index_up,
         conv_ix_up,
         conv_upsampling_indices_b,
-        conv_temporal_components_up_b,
+        compressed_upsampled_temporal.compressed_upsampled_templates,
+        conv_compressed_upsampled_ix,
         compression_dup_ix,
     )
 
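The pattern behind these changes: instead of materializing spatial_singular_a[ix_a[conv_ix]] (and the other pre-indexed copies) before the call, correlate_pairs_lowrank now receives the full arrays plus the index vectors and applies the fancy indexing only to the current batch_size-sized slice inside its loop, so at most one batch of gathered pairs lives in memory at a time. Below is a minimal standalone sketch of that deferred-indexing idea; the function name, shapes, and variables are made up for illustration and are not the library's API.

import torch


def batched_pair_products(spatial, temporal, ix, conv_ix, batch_size=128):
    # Toy version of the deferred-indexing pattern: gather only the current
    # batch of pair indices instead of materializing spatial[ix[conv_ix]]
    # for every pair up front.
    out = []
    for istart in range(0, len(conv_ix), batch_size):
        batch = slice(istart, min(istart + batch_size, len(conv_ix)))
        rows = ix[conv_ix][batch]  # template row for each pair in this batch
        # (nco, t, rank) @ (nco, rank, nchan) -> (nco, t, nchan)
        out.append(torch.bmm(temporal[rows], spatial[rows]))
    return torch.cat(out)


# made-up sizes for illustration
spatial = torch.randn(20, 3, 10)   # (n_templates, rank, nchan)
temporal = torch.randn(20, 50, 3)  # (n_templates, t, rank)
ix = torch.arange(20).repeat(5)    # pair index -> template index
conv_ix = torch.arange(100)        # which pairs actually get convolved
print(batched_pair_products(spatial, temporal, ix, conv_ix).shape)  # (100, 50, 10)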