@@ -60,7 +60,6 @@ def __post_init__(self):
60
60
assert self .shifts_b .shape == (
61
61
self .upsampled_shifted_template_index_b .shape [1 ],
62
62
)
63
-
64
63
self .a_shift_offset , self .offset_shift_a_to_ix = _get_shift_indexer (
65
64
self .shifts_a
66
65
)
@@ -69,6 +68,15 @@ def __post_init__(self):
69
68
)
70
69
71
70
def get_shift_ix_a(self, shifts_a):
    """Map signed integer shifts to shift indices for the "a" templates.

    A shift index lies in [0, n_shifts_a) and can index axis=1 of
    shifted_template_index_a, or self.shifts_a itself. The result equals
    np.searchsorted(self.shifts_a, shifts_a), but searchsorted is slow,
    so we use the lookup table pre-baked by _get_shift_indexer instead.
    """
    shifts_tensor = torch.as_tensor(shifts_a)
    shifts_tensor = torch.atleast_1d(shifts_tensor)
    # offsetting by a_shift_offset makes the smallest shift land at table slot 0
    table_positions = shifts_tensor.to(int) + self.a_shift_offset
    return self.offset_shift_a_to_ix[table_positions]
74
82
@@ -328,6 +336,7 @@ def query(
328
336
# device=self.device,
329
337
# )
330
338
339
+
331
340
def batched_h5_read (dataset , indices , batch_size = 1000 ):
332
341
if indices .size < batch_size :
333
342
return dataset [indices ]
@@ -341,14 +350,16 @@ def batched_h5_read(dataset, indices, batch_size=1000):
341
350
342
351
def _get_shift_indexer (shifts ):
343
352
assert torch .equal (shifts , torch .sort (shifts ).values )
353
+ # smallest shift (say, -5) becomes 5
344
354
shift_offset = - int (shifts [0 ])
345
355
offset_shift_to_ix = []
356
+
346
357
for j , shift in enumerate (shifts ):
347
358
ix = shift + shift_offset
348
359
assert len (offset_shift_to_ix ) <= ix
349
- assert 0 <= ix < len (shifts )
350
360
while len (offset_shift_to_ix ) < ix :
351
361
offset_shift_to_ix .append (len (shifts ))
352
362
offset_shift_to_ix .append (j )
363
+
353
364
offset_shift_to_ix = torch .tensor (offset_shift_to_ix , device = shifts .device )
354
365
return shift_offset , offset_shift_to_ix
0 commit comments