Skip to content

Commit 67b37b5

Browse files
committed
Merge branch 'main' of github.com:cwindolf/spike-psvae
2 parents dd8f50b + b4cc282 commit 67b37b5

File tree

11 files changed

+100
-53
lines changed

11 files changed

+100
-53
lines changed

src/dartsort/peel/matching.py

Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -218,6 +218,10 @@ def build_template_data(
218218
temporal_components = low_rank_templates.temporal_components.astype(dtype)
219219
singular_values = low_rank_templates.singular_values.astype(dtype)
220220
spatial_components = low_rank_templates.spatial_components.astype(dtype)
221+
print(f"{template_data.templates.dtype=}")
222+
print(f"{temporal_components.dtype=}")
223+
print(f"{singular_values.dtype=}")
224+
print(f"{spatial_components.dtype=}")
221225
self.register_buffer("temporal_components", torch.tensor(temporal_components))
222226
self.register_buffer("singular_values", torch.tensor(singular_values))
223227
self.register_buffer("spatial_components", torch.tensor(spatial_components))
@@ -236,16 +240,20 @@ def build_template_data(
236240
chunk_centers_s = self.recording._recording_segments[0].sample_index_to_time(
237241
chunk_centers_samples
238242
)
243+
print(f"build_template_data {device=}")
244+
print(f"{chunk_centers_s.shape=} {chunk_centers_s[:10]=}")
239245
self.pairwise_conv_db = CompressedPairwiseConv.from_template_data(
240246
save_folder / "pconv.h5",
241247
template_data=template_data,
242248
low_rank_templates=low_rank_templates,
243249
compressed_upsampled_temporal=compressed_upsampled_temporal,
244250
chunk_time_centers_s=chunk_centers_s,
245-
motion_est=motion_est,
251+
motion_est=self.motion_est,
246252
geom=self.geom,
247253
conv_ignore_threshold=self.conv_ignore_threshold,
248254
coarse_approx_error_threshold=self.coarse_approx_error_threshold,
255+
device=device,
256+
n_jobs=n_jobs,
249257
)
250258

251259
self.fixed_output_data += [
@@ -258,7 +266,7 @@ def build_template_data(
258266
),
259267
(
260268
"compressed_upsampled_temporal",
261-
compressed_upsampled_temporal.compressed_upsampled_temporal,
269+
compressed_upsampled_temporal.compressed_upsampled_templates,
262270
),
263271
]
264272

@@ -274,13 +282,14 @@ def handle_upsampling(
274282
ptps=ptps,
275283
max_upsample=temporal_upsampling_factor,
276284
)
285+
print(f"{compressed_upsampled_temporal.compressed_upsampled_templates.dtype=}")
277286
self.register_buffer(
278287
"compressed_upsampling_map",
279-
compressed_upsampled_temporal.compressed_upsampling_map,
288+
torch.tensor(compressed_upsampled_temporal.compressed_upsampling_map),
280289
)
281290
self.register_buffer(
282291
"compressed_upsampled_temporal",
283-
compressed_upsampled_temporal.compressed_upsampled_temporal,
292+
torch.tensor(compressed_upsampled_temporal.compressed_upsampled_templates),
284293
)
285294
if temporal_upsampling_factor == 1:
286295
return compressed_upsampled_temporal

src/dartsort/peel/peel_base.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -213,7 +213,7 @@ def peeling_needs_fit(self):
213213
def precompute_peeling_data(self, save_folder, n_jobs=0, device=None):
214214
# subclasses should override if they need to cache data for peeling
215215
# runs before fit_peeler_models()
216-
assert not self.peeling_needs_fit()
216+
pass
217217

218218
def fit_peeler_models(self, save_folder):
219219
# subclasses should override if they need to fit models for peeling
@@ -324,7 +324,9 @@ def needs_fit(self):
324324
def fit_models(self, save_folder, n_jobs=0, device=None):
325325
with torch.no_grad():
326326
if self.peeling_needs_fit():
327-
self.precompute_peeling_data()
327+
self.precompute_peeling_data(
328+
save_folder=save_folder, n_jobs=n_jobs, device=device
329+
)
328330
self.fit_peeler_models(
329331
save_folder=save_folder, n_jobs=n_jobs, device=device
330332
)

src/dartsort/templates/get_templates.py

Lines changed: 21 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -181,6 +181,7 @@ def get_templates(
181181
snr_threshold=denoising_snr_threshold,
182182
)
183183
templates = weights * raw_templates + (1 - weights) * low_rank_templates
184+
templates = templates.astype(recording.dtype)
184185

185186
return dict(
186187
sorting=sorting,
@@ -379,13 +380,16 @@ def get_all_shifted_raw_and_low_rank_templates(
379380
registered_kdtree = KDTree(registered_geom)
380381

381382
n_units = sorting.labels.max() + 1
382-
raw_templates = np.zeros((n_units, spike_length_samples, n_template_channels))
383+
raw_templates = np.zeros(
384+
(n_units, spike_length_samples, n_template_channels), dtype=recording.dtype
385+
)
383386
low_rank_templates = None
384387
if not raw:
385388
low_rank_templates = np.zeros(
386-
(n_units, spike_length_samples, n_template_channels)
389+
(n_units, spike_length_samples, n_template_channels),
390+
dtype=recording.dtype,
387391
)
388-
snrs_by_channel = np.zeros((n_units, n_template_channels))
392+
snrs_by_channel = np.zeros((n_units, n_template_channels), dtype=recording.dtype)
389393

390394
unit_id_chunks = [
391395
unit_ids[i : i + units_per_job] for i in range(0, n_units, units_per_job)
@@ -421,6 +425,8 @@ def get_all_shifted_raw_and_low_rank_templates(
421425
unit="template",
422426
)
423427
for res in results:
428+
if res is None:
429+
continue
424430
units_chunk, raw_temps_chunk, low_rank_temps_chunk, snrs_chunk = res
425431
raw_templates[units_chunk] = raw_temps_chunk
426432
if not raw:
@@ -477,12 +483,14 @@ def __init__(
477483
dtype=torch.from_numpy(np.zeros(1, dtype=recording.dtype)).dtype,
478484
)
479485

486+
self.n_template_channels = self.n_channels
480487
if self.registered:
481488
self.geom = recording.get_channel_locations()
482489
self.match_distance = pdist(self.geom).min() / 2
483490
self.registered_geom = registered_kdtree.data
484491
self.registered_kdtree = registered_kdtree
485492
self.pitch_shifts = pitch_shifts
493+
self.n_template_channels = len(self.registered_geom)
486494

487495

488496
_template_process_context = None
@@ -535,6 +543,8 @@ def _template_job(unit_ids):
535543
p = _template_process_context
536544

537545
in_units_full = np.flatnonzero(np.isin(p.sorting.labels, unit_ids))
546+
if not in_units_full.size:
547+
return
538548
labels_full = p.sorting.labels[in_units_full]
539549

540550
# only so many spikes per unit
@@ -564,7 +574,7 @@ def _template_job(unit_ids):
564574
(times >= p.trough_offset_samples) & (times < p.max_spike_time)
565575
)
566576
if not valid.size:
567-
return uids, 0, 0, 0
577+
return
568578
in_units = in_units[valid]
569579
labels = labels[valid]
570580
times = times[valid]
@@ -581,12 +591,12 @@ def _template_job(unit_ids):
581591
# compute raw templates and spike counts per channel
582592
raw_templates = []
583593
counts = []
594+
units_chunk = []
584595
for u in uids:
585596
in_unit = np.flatnonzero(labels == u)
586597
if not in_unit.size:
587-
raw_templates.append(np.zeros(1))
588-
counts.append(0)
589598
continue
599+
units_chunk.append(u)
590600
in_unit_orig = in_units[labels == u]
591601
if p.registered:
592602
raw_templates.append(
@@ -617,9 +627,10 @@ def _template_job(unit_ids):
617627
)
618628
counts.append(in_unit.size)
619629
snrs_by_chan = [ptp(rt, 0) * c for rt, c in zip(raw_templates, counts)]
630+
raw_templates = np.array(raw_templates)
620631

621632
if p.denoising_tsvd is None:
622-
return uids, raw_templates, None, snrs_by_chan
633+
return units_chunk, raw_templates, None, snrs_by_chan
623634

624635
# apply denoising
625636
waveforms = waveforms.permute(0, 2, 1).reshape(n * c, t)
@@ -628,11 +639,8 @@ def _template_job(unit_ids):
628639

629640
# get low rank templates
630641
low_rank_templates = []
631-
for u in uids:
642+
for u in units_chunk:
632643
in_unit = np.flatnonzero(labels == u)
633-
if not in_unit.size:
634-
low_rank_templates.append(0)
635-
continue
636644
in_unit_orig = in_units[labels == u]
637645
if p.registered:
638646
low_rank_templates.append(
@@ -650,8 +658,9 @@ def _template_job(unit_ids):
650658
low_rank_templates.append(
651659
p.reducer(waveforms[in_unit], axis=0).numpy(force=True)
652660
)
661+
low_rank_templates = np.array(low_rank_templates)
653662

654-
return uids, raw_templates, low_rank_templates, snrs_by_chan
663+
return units_chunk, raw_templates, low_rank_templates, snrs_by_chan
655664

656665

657666
class TorchSVDProjector(torch.nn.Module):

src/dartsort/templates/pairwise.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -79,13 +79,14 @@ def from_template_data(
7979
geom: Optional[np.ndarray] = None,
8080
conv_ignore_threshold=0.0,
8181
coarse_approx_error_threshold=0.0,
82-
conv_batch_size=128,
82+
conv_batch_size=1024,
8383
units_batch_size=8,
8484
overwrite=False,
8585
device=None,
8686
n_jobs=0,
8787
show_progress=True,
8888
):
89+
print(f"pairwise from_template_data {device=}")
8990
compressed_convolve_to_h5(
9091
hdf5_filename,
9192
template_data=template_data,

src/dartsort/templates/pairwise_util.py

Lines changed: 26 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ def compressed_convolve_to_h5(
2828
geom: Optional[np.ndarray] = None,
2929
conv_ignore_threshold=0.0,
3030
coarse_approx_error_threshold=0.0,
31-
conv_batch_size=128,
31+
conv_batch_size=1024,
3232
units_batch_size=8,
3333
overwrite=False,
3434
device=None,
@@ -57,6 +57,7 @@ def compressed_convolve_to_h5(
5757
upsampled_shifted_template_index = get_upsampled_shifted_template_index(
5858
template_shift_index, compressed_upsampled_temporal
5959
)
60+
print(f"compressed_convolve_to_h5 {conv_batch_size=} {units_batch_size=} {device=}")
6061

6162
chunk_res_iterator = iterate_compressed_pairwise_convolutions(
6263
template_data=template_data,
@@ -148,7 +149,7 @@ def iterate_compressed_pairwise_convolutions(
148149
conv_ignore_threshold=0.0,
149150
coarse_approx_error_threshold=0.0,
150151
max_shift="full",
151-
conv_batch_size=128,
152+
conv_batch_size=1024,
152153
units_batch_size=8,
153154
device=None,
154155
n_jobs=0,
@@ -165,6 +166,7 @@ def iterate_compressed_pairwise_convolutions(
165166
process the results differently.
166167
"""
167168
# construct drift-related helper data if needed
169+
print(f"iterate_compressed_pairwise_convolutions {conv_batch_size=} {units_batch_size=} {device=}")
168170
n_shifts = template_shift_index.all_pitch_shifts.size
169171
do_shifting = n_shifts > 1
170172
geom_kdtree = reg_geom_kdtree = match_distance = None
@@ -267,7 +269,7 @@ def compressed_convolve_pairs(
267269
conv_ignore_threshold=0.0,
268270
coarse_approx_error_threshold=0.0,
269271
max_shift="full",
270-
batch_size=128,
272+
batch_size=1024,
271273
device=None,
272274
) -> Optional[CompressedConvResult]:
273275
"""Compute compressed pairwise convolutions between template pairs
@@ -280,9 +282,11 @@ def compressed_convolve_pairs(
280282
shifts, superres templates, and upsamples. Some of these may be zero or may
281283
be duplicates, so the return value is a sparse representation. See below.
282284
"""
285+
# print(f"compressed_convolve_pairs {device=}")
283286
# print(f"{units_a.shape=}")
284287
# print(f"{units_b.shape=}")
285288
# print(f"{(units_a.size * units_b.size)=}")
289+
# print(f"compressed_convolve_pairs {batch_size=} {units_a.size=} {device=}")
286290

287291
# what pairs, shifts, etc are we convolving?
288292
shifted_temp_ix_a, temp_ix_a, shift_a, unit_a = handle_shift_indices(
@@ -317,6 +321,9 @@ def compressed_convolve_pairs(
317321
match_distance=match_distance,
318322
device=device,
319323
)
324+
# print(f"{low_rank_templates.spatial_components.dtype=} {low_rank_templates.singular_values.dtype=}")
325+
# print(f"{compressed_upsampled_temporal.compressed_upsampled_templates.dtype=}")
326+
# print(f"{spatial_singular_a.dtype=} {spatial_singular_b.dtype=}")
320327

321328
# figure out pairs of shifted templates to convolve in a deduplicated way
322329
pairs_ret = shift_deduplicated_pairs(
@@ -392,27 +399,27 @@ def compressed_convolve_pairs(
392399
# print(f"{temporal_a[ix_a[conv_ix]].shape=}")
393400
# print(f"{conv_temporal_components_up_b.shape=}")
394401
pconv, kept = correlate_pairs_lowrank(
395-
torch.as_tensor(spatial_singular_a[ix_a[conv_ix]]).to(device),
396-
torch.as_tensor(spatial_singular_b[ix_b[conv_ix]]).to(device),
397-
torch.as_tensor(temporal_a[ix_a[conv_ix]]).to(device),
398-
torch.as_tensor(conv_temporal_components_up_b).to(device),
402+
torch.as_tensor(spatial_singular_a[ix_a[conv_ix]], device=device),
403+
torch.as_tensor(spatial_singular_b[ix_b[conv_ix]], device=device),
404+
torch.as_tensor(temporal_a[ix_a[conv_ix]], device=device),
405+
torch.as_tensor(conv_temporal_components_up_b, device=device),
399406
max_shift=max_shift,
400407
conv_ignore_threshold=conv_ignore_threshold,
401408
batch_size=batch_size,
402409
)
403-
print(f"-----------")
404-
print(f"after corr {pconv.shape=} {conv_ix[kept].shape=}")
410+
# print(f"-----------")
411+
# print(f"after corr {pconv.shape=} {conv_ix[kept].shape=}")
405412
conv_ix = conv_ix[kept]
406413
if not conv_ix.size:
407414
return None
408415
kept_pairs = np.flatnonzero(np.isin(compression_index, kept))
409-
print(f"-----------")
410-
print(f"kept {pconv.shape=} {conv_ix.shape=} {compression_index.shape=}")
411-
print(f"{compression_index.min()=} {compression_index.max()=}")
412-
print(f"{compression_index[kept_pairs].min()=} {compression_index[kept_pairs].max()=}")
413-
print(f"{ix_a.shape=} {ix_b.shape=}")
414-
print(f"{kept.shape=} {kept.dtype=} {kept.min()=} {kept.max()=}")
415-
print(f"{kept_pairs.shape=} {kept_pairs.dtype=} {kept_pairs.min()=} {kept_pairs.max()=}")
416+
# print(f"-----------")
417+
# print(f"kept {pconv.shape=} {conv_ix.shape=} {compression_index.shape=}")
418+
# print(f"{compression_index.min()=} {compression_index.max()=}")
419+
# print(f"{compression_index[kept_pairs].min()=} {compression_index[kept_pairs].max()=}")
420+
# print(f"{ix_a.shape=} {ix_b.shape=}")
421+
# print(f"{kept.shape=} {kept.dtype=} {kept.min()=} {kept.max()=}")
422+
# print(f"{kept_pairs.shape=} {kept_pairs.dtype=} {kept_pairs.min()=} {kept_pairs.max()=}")
416423
compression_index = np.searchsorted(kept, compression_index[kept_pairs])
417424
conv_ix = np.searchsorted(kept_pairs, conv_ix)
418425
ix_a = ix_a[kept_pairs]
@@ -472,7 +479,7 @@ def correlate_pairs_lowrank(
472479
temporal_b,
473480
max_shift="full",
474481
conv_ignore_threshold=0.0,
475-
batch_size=128,
482+
batch_size=1024,
476483
):
477484
"""Convolve pairs of low rank templates
478485
@@ -504,6 +511,8 @@ def correlate_pairs_lowrank(
504511
assert n_pairs == n_pairs_
505512
assert t == t_
506513
assert rank == rank_
514+
# print(f"{spatial_a.device=} {spatial_b.device=} {temporal_a.device=} {temporal_b.device=}")
515+
# print(f"compressed_convolve_pairs {batch_size=} {n_pairs=} {spatial_a.device=}")
507516

508517
if max_shift == "full":
509518
max_shift = t - 1

0 commit comments

Comments
 (0)