Skip to content

Commit 835f472

Browse files
committed
Fix shifts with gaps assert thing
1 parent dda7911 commit 835f472

File tree

1 file changed

+13
-2
lines changed

1 file changed

+13
-2
lines changed

src/dartsort/templates/pairwise.py

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,6 @@ def __post_init__(self):
6060
assert self.shifts_b.shape == (
6161
self.upsampled_shifted_template_index_b.shape[1],
6262
)
63-
6463
self.a_shift_offset, self.offset_shift_a_to_ix = _get_shift_indexer(
6564
self.shifts_a
6665
)
@@ -69,6 +68,15 @@ def __post_init__(self):
6968
)
7069

7170
def get_shift_ix_a(self, shifts_a):
71+
"""Map shift (an integer, signed) to a shift index
72+
73+
A shift index can be used to index into axis=1 of shifted_template_index_a,
74+
or self.shifts_a for that matter.
75+
It's an int in [0, n_shifts_a).
76+
It's equal to np.searchsorted(self.shifts_a, shifts_a).
77+
The thing is, searchsorted is slow, and we can pre-bake a lookup table.
78+
_get_shift_indexer does the baking for us above.
79+
"""
7280
shifts_a = torch.atleast_1d(torch.as_tensor(shifts_a))
7381
return self.offset_shift_a_to_ix[shifts_a.to(int) + self.a_shift_offset]
7482

@@ -328,6 +336,7 @@ def query(
328336
# device=self.device,
329337
# )
330338

339+
331340
def batched_h5_read(dataset, indices, batch_size=1000):
332341
if indices.size < batch_size:
333342
return dataset[indices]
@@ -341,14 +350,16 @@ def batched_h5_read(dataset, indices, batch_size=1000):
341350

342351
def _get_shift_indexer(shifts):
343352
assert torch.equal(shifts, torch.sort(shifts).values)
353+
# smallest shift (say, -5) becomes 5
344354
shift_offset = -int(shifts[0])
345355
offset_shift_to_ix = []
356+
346357
for j, shift in enumerate(shifts):
347358
ix = shift + shift_offset
348359
assert len(offset_shift_to_ix) <= ix
349-
assert 0 <= ix < len(shifts)
350360
while len(offset_shift_to_ix) < ix:
351361
offset_shift_to_ix.append(len(shifts))
352362
offset_shift_to_ix.append(j)
363+
353364
offset_shift_to_ix = torch.tensor(offset_shift_to_ix, device=shifts.device)
354365
return shift_offset, offset_shift_to_ix

0 commit comments

Comments
 (0)