
Commit 8236157 ("Try at_shifts")
Parent: e225f30

File tree: 2 files changed, +106 -38 lines changed


src/dartsort/peel/matching.py

Lines changed: 31 additions & 18 deletions
@@ -426,7 +426,8 @@ def peel_chunk(
     def templates_at_time(self, t_s):
         """Handle drift -- grab the right spatial neighborhoods."""
         pconvdb = self.pairwise_conv_db
-        pitch_shifts_a=pitch_shifts_b=None
+        pitch_shifts_a = pitch_shifts_b = None
+        pconvdb.to(self.objective_spatial_components.device, pin=True)
         if self.is_drifting:
             pitch_shifts_b, cur_spatial = template_util.templates_at_time(
                 t_s,
@@ -464,17 +465,22 @@ def templates_at_time(self, t_s):
                 fill_value=0.0,
             )
             max_channels = cur_ampvecs[:, 0, :].argmax(1)
-            # pconvdb = pconvdb.at_shifts(pitch_shifts_a, pitch_shifts_b)
+            # pitch_shifts_a = torch.as_tensor(pitch_shifts_a)
+            # pitch_shifts_b = torch.as_tensor(pitch_shifts_b)
             pitch_shifts_a = torch.as_tensor(pitch_shifts_a, device=cur_obj_spatial.device)
             pitch_shifts_b = torch.as_tensor(pitch_shifts_b, device=cur_obj_spatial.device)
+            pconvdb = pconvdb.at_shifts(pitch_shifts_a, pitch_shifts_b)
+            # pitch_shifts_a = torch.as_tensor(pitch_shifts_a, device=cur_obj_spatial.device)
+            # pitch_shifts_b = torch.as_tensor(pitch_shifts_b, device=cur_obj_spatial.device)
         else:
             cur_spatial = self.spatial_components
             cur_obj_spatial = self.objective_spatial_components
             max_channels = self.registered_template_ampvecs.argmax(1)
 
         # if not pconvdb._is_torch:
-        # #     pconvdb.to("cpu")
-        pconvdb.to(cur_obj_spatial.device)
+        #     pconvdb.to("cpu")
+        # if cur_obj_spatial.device.type == "cuda" and not pconvdb.device.type == "cuda":
+        #     pconvdb.to(cur_obj_spatial.device, pin=True)
 
         return MatchingTemplateData(
             objective_spatial_components=cur_obj_spatial,
@@ -492,8 +498,10 @@ def templates_at_time(self, t_s):
             compressed_upsampled_temporal=self.compressed_upsampled_temporal,
             max_channels=torch.as_tensor(max_channels, device=cur_obj_spatial.device),
             pairwise_conv_db=pconvdb,
-            shifts_a=pitch_shifts_a,
-            shifts_b=pitch_shifts_b,
+            shifts_a=None,
+            shifts_b=None,
+            # shifts_a=pitch_shifts_a,
+            # shifts_b=pitch_shifts_b,
         )
 
     def match_chunk(
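
Note: the hunks above move shift handling from query time to chunk time. templates_at_time now subsets the pairwise convolution database once per chunk via at_shifts, so the MatchingTemplateData it returns carries shifts_a=None and shifts_b=None. A minimal sketch of the resulting flow (illustrative only; just the keyword arguments visible in this diff are shown):

    # once per chunk: build a shiftless view of the database
    db = pconvdb.at_shifts(pitch_shifts_a, pitch_shifts_b)

    # downstream queries then omit shifts entirely and take query()'s
    # shifts_a-is-None path, which reads column 0 of the shifted template
    # indices (see the pairwise.py hunk at old line 230)
    ix_a, ix_b, pconvs = db.query(
        template_indices_a=None,
        template_indices_b=template_indices_b,
        shifts_a=None,
        shifts_b=None,
    )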
@@ -560,20 +568,20 @@ def match_chunk(
 
         # subtract them
         # old_norm = torch.linalg.norm(residual) ** 2
-        compressed_template_data.subtract_conv(
-            padded_conv,
+        compressed_template_data.subtract(
+            residual_padded,
             new_peaks.times,
             new_peaks.template_indices,
             new_peaks.upsampling_indices,
             new_peaks.scalings,
-            conv_pad_len=self.obj_pad_len,
         )
-        compressed_template_data.subtract(
-            residual_padded,
+        compressed_template_data.subtract_conv(
+            padded_conv,
             new_peaks.times,
             new_peaks.template_indices,
             new_peaks.upsampling_indices,
             new_peaks.scalings,
+            conv_pad_len=self.obj_pad_len,
         )
 
         # new_norm = torch.linalg.norm(residual) ** 2
@@ -627,7 +635,7 @@ def find_peaks(
             alpha=2.0,
             out=padded_objective[:-1],
         )
-
+
         # first step: coarse peaks. not temporally upsampled or amplitude-scaled.
         objective = (padded_objective + refrac_mask)[
             :-1, self.obj_pad_len : -self.obj_pad_len
@@ -668,7 +676,7 @@ def find_peaks(
         )
         if time_shifts is not None:
             times += time_shifts
-
+
         return MatchingPeaks(
             n_spikes=times.numel(),
             times=times,
@@ -884,12 +892,17 @@ def fine_match(
         superres_ix = superres_index[objective_template_indices]
         dup_ix, column_ix = (superres_ix < self.n_templates).nonzero(as_tuple=True)
         template_indices = superres_ix[dup_ix, column_ix]
-        convs = torch.einsum(
-            "jtc,jrc,jtr->j",
-            snips[dup_ix],
-            self.spatial_singular[template_indices],
+        convs = torch.baddbmm(
             self.temporal_components[template_indices],
-        )
+            snips[dup_ix],
+            self.spatial_singular[template_indices].mT,
+        ).sum((1, 2))
+        # convs = torch.einsum(
+        #     "jtc,jrc,jtr->j",
+        #     snips[dup_ix],
+        #     self.spatial_singular[template_indices],
+        #     self.temporal_components[template_indices],
+        # )
         norms = self.template_norms_squared[template_indices]
         objs = torch.full(superres_ix.shape, -torch.inf, device=convs.device)
         objs[dup_ix, column_ix] = 2 * convs - norms
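
Note: torch.baddbmm(input, batch1, batch2) computes input + batch1 @ batch2, while the commented-out einsum contracts all three operands together. For reference, the "jtc,jrc,jtr->j" contraction factors into a batched matmul followed by an elementwise product and a sum; a self-contained check of that identity (sizes invented for illustration):

    import torch

    j, t, c, r = 4, 21, 7, 3  # batch, time, channels, rank
    snips = torch.randn(j, t, c)
    spatial = torch.randn(j, r, c)
    temporal = torch.randn(j, t, r)

    ref = torch.einsum("jtc,jrc,jtr->j", snips, spatial, temporal)
    # same contraction: batched matmul, elementwise product, then sum
    alt = (torch.bmm(snips, spatial.mT) * temporal).sum((1, 2))
    assert torch.allclose(ref, alt, atol=1e-4)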

src/dartsort/templates/pairwise.py

Lines changed: 75 additions & 20 deletions
@@ -52,20 +52,35 @@ class CompressedPairwiseConv:
     # the 0 index is special: pconv[0] === 0.
     pconv: np.ndarray
     in_memory: bool = False
+    device: torch.device = torch.device("cpu")
 
     def __post_init__(self):
         assert self.shifts_a.ndim == self.shifts_b.ndim == 1
         assert self.shifts_a.shape == (self.shifted_template_index_a.shape[1],)
-        assert self.shifts_b.shape == (self.upsampled_shifted_template_index_b.shape[1],)
+        assert self.shifts_b.shape == (
+            self.upsampled_shifted_template_index_b.shape[1],
+        )
+
+        self.a_shift_offset, self.offset_shift_a_to_ix = _get_shift_indexer(
+            self.shifts_a
+        )
+        self.b_shift_offset, self.offset_shift_b_to_ix = _get_shift_indexer(
+            self.shifts_b
+        )
+
+    def get_shift_ix_a(self, shifts_a):
+        return self.offset_shift_a_to_ix[shifts_a.to(int) + self.a_shift_offset]
+
+    def get_shift_ix_b(self, shifts_b):
+        return self.offset_shift_b_to_ix[shifts_b.to(int) + self.b_shift_offset]
 
     @classmethod
     def from_h5(cls, hdf5_filename, in_memory=True):
-        ff = [f for f in fields(cls) if not f.name == "in_memory"]
+        ff = [f for f in fields(cls) if f.name not in ("in_memory", "device")]
         if in_memory:
             with h5py.File(hdf5_filename, "r") as h5:
                 data = {f.name: torch.from_numpy(h5[f.name][:]) for f in ff}
             return cls(**data, in_memory=in_memory)
-
         _h5 = h5py.File(hdf5_filename, "r")
         data = {}
         for f in ff:
@@ -117,7 +132,7 @@ def from_template_data(
         )
         return cls.from_h5(hdf5_filename)
 
-    def at_shifts(self, shifts_a=None, shifts_b=None):
+    def at_shifts(self, shifts_a=None, shifts_b=None, device=None):
         """Subset this database to one set of shifts.
 
         The database becomes shiftless (not in the pejorative sense).
@@ -133,8 +148,8 @@ def at_shifts(self, shifts_a=None, shifts_b=None):
         n_shifted_temps_a, n_up_shifted_temps_b = self.pconv_index.shape
 
         # active shifted and upsampled indices
-        shift_ix_a = torch.searchsorted(self.shifts_a, shifts_a)
-        shift_ix_b = torch.searchsorted(self.shifts_b, shifts_b)
+        shift_ix_a = self.get_shift_ix_a(shifts_a)
+        shift_ix_b = self.get_shift_ix_b(shifts_b)
         sub_shifted_temp_index_a = self.shifted_template_index_a[
             torch.arange(len(self.shifted_template_index_a))[:, None],
             shift_ix_a[:, None],
@@ -166,6 +181,8 @@ def at_shifts(self, shifts_a=None, shifts_b=None):
             sub_pconv = self.pconv[sub_pconv_indices.to(self.pconv.device)]
         else:
             sub_pconv = torch.from_numpy(batched_h5_read(self.pconv, sub_pconv_indices))
+        if device is not None:
+            sub_pconv = sub_pconv.to(device)
 
         # reindexing
         n_sub_shifted_temps_a = len(shifted_temp_ixs_a)
@@ -184,17 +201,30 @@
             pconv_index=sub_pconv_index,
             pconv=sub_pconv,
             in_memory=True,
+            device=self.device,
         )
 
-    def to(self, device=None, incl_pconv=False):
+    def to(self, device=None, incl_pconv=False, pin=False):
         """Become torch tensors on device."""
-        for f in fields(self):
-            if f.name == "pconv":
+        print(f"to {device=}")
+        for name in ["offset_shift_a_to_ix", "offset_shift_b_to_ix"] + [
+            f.name for f in fields(self)
+        ]:
+            if name == "pconv" and not incl_pconv:
                 continue
-            v = getattr(self, f.name)
+            v = getattr(self, name)
             if isinstance(v, np.ndarray) or torch.is_tensor(v):
-                setattr(self, f.name, torch.as_tensor(v, device=device))
+                setattr(self, name, torch.as_tensor(v, device=device))
         self.device = device
+        if pin and self.device.type == "cuda" and torch.cuda.is_available() and not self.pconv.is_pinned():
+            # self.pconv.share_memory_()
+            print("pin")
+            torch.cuda.cudart().cudaHostRegister(
+                self.pconv.data_ptr(), self.pconv.numel() * self.pconv.element_size(), 0
+            )
+            # assert x.is_shared()
+            assert self.pconv.is_pinned()
+            # self.pconv = self.pconv.pin_memory()
         return self
 
     def query(
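
Note: the new pin=True path page-locks the existing pconv allocation in place with cudaHostRegister, rather than pin_memory(), which would allocate and copy a pinned duplicate; pinned host memory is what allows later host-to-device copies to run asynchronously. A standalone sketch of the same technique (requires a CUDA runtime; illustrative only):

    import torch

    x = torch.empty(1 << 20)  # a large CPU tensor we would rather not copy
    if torch.cuda.is_available() and not x.is_pinned():
        # register the existing allocation as page-locked host memory
        torch.cuda.cudart().cudaHostRegister(
            x.data_ptr(), x.numel() * x.element_size(), 0
        )
        assert x.is_pinned()
        # now x.to("cuda", non_blocking=True) can overlap with compute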
@@ -211,9 +241,9 @@
         device=None,
     ):
         if template_indices_a is None:
-                template_indices_a = torch.arange(
-                    len(self.shifted_template_index_a), device=self.device
-                )
+            template_indices_a = torch.arange(
+                len(self.shifted_template_index_a), device=self.device
+            )
         template_indices_a = torch.atleast_1d(template_indices_a)
         template_indices_b = torch.atleast_1d(template_indices_b)

@@ -230,8 +260,8 @@
             shifted_template_index = shifted_template_index[:, 0]
             upsampled_shifted_template_index = upsampled_shifted_template_index[:, 0]
         else:
-            shift_indices_a = torch.searchsorted(self.shifts_a, shifts_a)
-            shift_indices_b = torch.searchsorted(self.shifts_b, shifts_b)
+            shift_indices_a = self.get_shift_ix_a(shifts_a)
+            shift_indices_b = self.get_shift_ix_a(shifts_b)
         a_ix = (template_indices_a, shift_indices_a)
         b_ix = (template_indices_b, shift_indices_b)

@@ -250,6 +280,9 @@
         up_shifted_temp_ix_b = upsampled_shifted_template_index[b_ix]
 
         # return convolutions between all ai,bj or just ai,bi?
+        print(f"{shifted_temp_ix_a.device=} {up_shifted_temp_ix_b.device=}")
+        print(f"{self.device=} {self.shifts_a.device=}")
+        print(f"{template_indices_a.device=} {template_indices_b.device=}")
         if grid:
             pconv_indices = self.pconv_index[
                 shifted_temp_ix_a[:, None], up_shifted_temp_ix_b[None, :]
@@ -258,9 +291,13 @@
                 template_indices_a, template_indices_b
             ).T
             if scalings_b is not None:
-                scalings_b = torch.broadcast_to(scalings_b[None], pconv_indices.shape).reshape(-1)
+                scalings_b = torch.broadcast_to(
+                    scalings_b[None], pconv_indices.shape
+                ).reshape(-1)
             if times_b is not None:
-                times_b = torch.broadcast_to(times_b[None], pconv_indices.shape).reshape(-1)
+                times_b = torch.broadcast_to(
+                    times_b[None], pconv_indices.shape
+                ).reshape(-1)
             pconv_indices = pconv_indices.view(-1)
         else:
             pconv_indices = self.pconv_index[shifted_temp_ix_a, up_shifted_temp_ix_b]
@@ -279,7 +316,9 @@
         if self.in_memory:
            pconvs = self.pconv[pconv_indices.to(self.pconv.device)]
         else:
-            pconvs = torch.from_numpy(batched_h5_read(self.pconv, pconv_indices.numpy(force=True)))
+            pconvs = torch.from_numpy(
+                batched_h5_read(self.pconv, pconv_indices.numpy(force=True))
+            )
         if device is not None:
             pconvs = pconvs.to(device)

@@ -291,6 +330,7 @@
 
         return template_indices_a, template_indices_b, pconvs
 
+
 def batched_h5_read(dataset, indices, batch_size=1000):
     if indices.size < batch_size:
         return dataset[indices]
@@ -299,4 +339,19 @@
     for bs in range(0, indices.size, batch_size):
         be = min(indices.size, bs + batch_size)
         out[bs:be] = dataset[indices[bs:be]]
-    return out
+    return out
+
+
+def _get_shift_indexer(shifts):
+    assert torch.equal(shifts, torch.sort(shifts).values)
+    shift_offset = -int(shifts[0])
+    offset_shift_to_ix = []
+    for j, shift in enumerate(shifts):
+        ix = shift + shift_offset
+        assert len(offset_shift_to_ix) <= ix
+        assert 0 <= ix < len(shifts)
+        while len(offset_shift_to_ix) < ix:
+            offset_shift_to_ix.append(len(shifts))
+        offset_shift_to_ix.append(j)
+    offset_shift_to_ix = torch.tensor(offset_shift_to_ix, device=shifts.device)
+    return shift_offset, offset_shift_to_ix
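
Note: _get_shift_indexer trades torch.searchsorted for a dense lookup table indexed by shift + offset; __post_init__ builds one table per shift axis, and get_shift_ix_a / get_shift_ix_b do the constant-time lookups. A small usage sketch (values invented; the asserts in _get_shift_indexer require the sorted shifts to cover a contiguous integer range):

    import torch

    shifts = torch.tensor([-1, 0, 1])
    offset, lut = _get_shift_indexer(shifts)
    queries = torch.tensor([1, -1, 0])
    # lut[shift + offset] recovers each queried shift's column index...
    print(lut[queries + offset])  # tensor([2, 0, 1])
    # ...agreeing with the searchsorted call it replaces
    assert torch.equal(lut[queries + offset], torch.searchsorted(shifts, queries))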
