Commit 3493143: Debug matching

cwindolf committed Dec 8, 2023
1 parent 96f76c5 commit 3493143
Showing 3 changed files with 94 additions and 89 deletions.
28 changes: 15 additions & 13 deletions src/dartsort/peel/matching.py
@@ -289,7 +289,6 @@ def build_template_data(
block_size, *_ = spiketorch._calc_oa_lens(convlen, self.spike_length_samples)
self.register_buffer("objective_temporalf", torch.fft.rfft(self.objective_temporal_components, dim=1, n=block_size))

half_chunk = self.chunk_length_samples // 2
chunk_starts = np.arange(
0, self.recording.get_num_samples(), self.chunk_length_samples
)
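
For context, the hunk above precomputes `objective_temporalf`, the templates' temporal components in the frequency domain, so per-chunk convolutions reduce to pointwise products. A minimal sketch of that FFT-convolution pattern, with made-up shapes (`templates`, `chunk`, and `n` here are illustrative, not dartsort's internals):

import torch

# Hypothetical sizes: 4 templates of 121 samples, a 2048-sample chunk.
templates = torch.randn(4, 121)
chunk = torch.randn(2048)

# The FFT length must cover the full linear convolution: 2048 + 121 - 1.
n = 2048 + 121 - 1
templates_f = torch.fft.rfft(templates, dim=1, n=n)  # precomputed once, like the buffer above
chunk_f = torch.fft.rfft(chunk, n=n)

# Pointwise product in the frequency domain == time-domain convolution.
convs = torch.fft.irfft(templates_f * chunk_f, n=n)  # shape (4, n)

# Matches direct convolution (flipped-kernel conv1d) up to float error.
ref = torch.nn.functional.conv1d(
    chunk.view(1, 1, -1), templates.flip(1).unsqueeze(1), padding=120
)[0]
assert torch.allclose(convs, ref, atol=1e-3)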
@@ -470,9 +469,7 @@ def templates_at_time(self, t_s):
# pitch_shifts_b = torch.as_tensor(pitch_shifts_b)
pitch_shifts_a = torch.as_tensor(pitch_shifts_a, device=cur_obj_spatial.device)
pitch_shifts_b = torch.as_tensor(pitch_shifts_b, device=cur_obj_spatial.device)
pconvdb = pconvdb.at_shifts(pitch_shifts_a, pitch_shifts_b)
# pitch_shifts_a = torch.as_tensor(pitch_shifts_a, device=cur_obj_spatial.device)
# pitch_shifts_b = torch.as_tensor(pitch_shifts_b, device=cur_obj_spatial.device)
# pconvdb = pconvdb.at_shifts(pitch_shifts_a, pitch_shifts_b)
else:
cur_spatial = self.spatial_components
cur_obj_spatial = self.objective_spatial_components
@@ -499,10 +496,10 @@ def templates_at_time(self, t_s):
compressed_upsampled_temporal=self.compressed_upsampled_temporal,
max_channels=torch.as_tensor(max_channels, device=cur_obj_spatial.device),
pairwise_conv_db=pconvdb,
shifts_a=None,
shifts_b=None,
# shifts_a=pitch_shifts_a,
# shifts_b=pitch_shifts_b,
# shifts_a=None,
# shifts_b=None,
shifts_a=pitch_shifts_a,
shifts_b=pitch_shifts_b,
)

def match_chunk(
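
The net change in `templates_at_time` is that the pairwise-conv database is no longer collapsed with `at_shifts` up front; instead, the per-template pitch shifts (`shifts_a`, `shifts_b`) are handed to the matcher to resolve at query time. A rough sketch of what such a shift-aware lookup could look like, with hypothetical names (`shift_to_ix`, `shifted_index`, `pconv_index`) standing in for the database internals:

import torch

# Toy database: pconv[pconv_index[sa, sb]] is the pairwise convolution of
# template a at one registered shift with template b at another.
n_temps, n_shifts = 3, 2
shift_to_ix = {-1: 0, 0: 1}  # registered shift value -> shift column
shifted_index = torch.arange(n_temps * n_shifts).view(n_temps, n_shifts)
pconv_index = torch.randint(0, 10, (n_temps * n_shifts, n_temps * n_shifts))
pconv = torch.randn(10, 121)

def lookup(temp_a, shift_a, temp_b, shift_b):
    # resolve the shifted template ids, then fetch the stored conv
    sa = shifted_index[temp_a, shift_to_ix[shift_a]]
    sb = shifted_index[temp_b, shift_to_ix[shift_b]]
    return pconv[pconv_index[sa, sb]]

print(lookup(0, -1, 2, 0).shape)  # torch.Size([121])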
@@ -653,7 +650,7 @@ def find_peaks(
if self.coarse_objective or self.temporal_upsampling_factor > 1:
residual_snips = spiketorch.grab_spikes_full(
residual,
times - 1,
times,
trough_offset=0,
spike_length_samples=self.spike_length_samples + 1,
)
@@ -670,6 +667,7 @@
objective_max[times],
residual_snips,
obj_template_indices,
times,
amp_scale_variance=self.amplitude_scaling_variance,
amp_scale_min=self.amp_scale_min,
amp_scale_max=self.amp_scale_max,
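
Why the snippet grab changed: starting at `times` (rather than `times - 1`) with `spike_length_samples + 1` samples gives two overlapping length-`L` views of the residual, one at the detected time and one a sample over, so `fine_match` can pick whichever aligns better with the template; this pairs with the `time_shifts` base change from -1 to 0 below. A toy illustration of the windowing, with illustrative sizes:

import torch

L = 5                                # stand-in for spike_length_samples
grab = torch.arange(float(L + 1))    # the (L + 1)-sample snippet starting at the peak time
win_a = grab[:-1]                    # alignment at the detected sample
win_b = grab[1:]                     # alignment one sample over

template = torch.randn(L)
better = torch.dot(win_b, template) > torch.dot(win_a, template)
time_shift = int(better)             # 0 keeps the detected time, 1 nudges it by one sample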
@@ -840,6 +838,7 @@ def fine_match(
objs,
residual_snips,
objective_template_indices,
times,
amp_scale_variance=0.0,
amp_scale_min=None,
amp_scale_max=None,
@@ -950,6 +949,9 @@ def fine_match(
).view(len(comp_up_indices), -1)
convs = torch.linalg.vecdot(snips[dup_ix].view(len(temps), -1), temps)
convs_prev = torch.linalg.vecdot(snips_prev[dup_ix].view(len(temps), -1), temps)

convs_r = torch.round(convs).to(int).numpy()
convs_prev_r = torch.round(convs_prev).to(int).numpy()
# convs = torch.einsum(
# "jtc,jrc,jtr->j",
# snips[dup_ix],
Expand Down Expand Up @@ -985,11 +987,11 @@ def fine_match(
upsampling_indices = self.compressed_index_to_upsampling_index[comp_up_indices]

# prev convs were one step earlier
time_shifts = torch.full(comp_up_ix.shape, -1, device=convs.device)
time_shifts[dup_ix, column_ix] += better
# time_shifts = torch.full(comp_up_ix.shape, -1, device=convs.device)
# time_shifts[dup_ix, column_ix] += better
time_shifts = torch.full(comp_up_ix.shape, 0, device=convs.device)
time_shifts[dup_ix, column_ix] += better.to(int)
time_shifts = time_shifts[row_ix, best_column_ix]
print(f"{better=}")
print(f"{time_shifts=}")

return time_shifts, upsampling_indices, scalings, template_indices, objs
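
For reference, the `vecdot` calls in this hunk compute a per-spike inner product between a flattened residual snippet and a flattened upsampled template, i.e. the objective gain of that assignment. A self-contained sketch with random stand-ins for `snips` and `temps` (the commented-out einsum in the real code folds in a temporal-spatial factorization; the dense variant below is the same idea in its simplest form):

import torch

n, t, c = 7, 121, 4                  # hypothetical: 7 spikes, 121 samples, 4 channels
snips = torch.randn(n, t, c)
temps = torch.randn(n, t, c)

# batched inner product over the flattened (time, channel) axes
convs = torch.linalg.vecdot(snips.view(n, -1), temps.view(n, -1))

# equivalent dense einsum, for comparison
convs_ref = torch.einsum("jtc,jtc->j", snips, temps)
assert torch.allclose(convs, convs_ref, atol=1e-3)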

142 changes: 71 additions & 71 deletions src/dartsort/templates/pairwise.py
@@ -134,77 +134,77 @@ def from_template_data(
)
return cls.from_h5(hdf5_filename)

def at_shifts(self, shifts_a=None, shifts_b=None, device=None):
"""Subset this database to one set of shifts.
The database becomes shiftless (not in the pejorative sense).
"""
if shifts_a is None or shifts_b is None:
assert shifts_a is shifts_b
assert self.shifts_a.shape == (1,)
assert self.shifts_b.shape == (1,)
return self

assert shifts_a.shape == (len(self.shifted_template_index_a),)
assert shifts_b.shape == (len(self.upsampled_shifted_template_index_b),)
n_shifted_temps_a, n_up_shifted_temps_b = self.pconv_index.shape

# active shifted and upsampled indices
shift_ix_a = self.get_shift_ix_a(shifts_a)
shift_ix_b = self.get_shift_ix_b(shifts_b)
sub_shifted_temp_index_a = self.shifted_template_index_a[
torch.arange(len(self.shifted_template_index_a))[:, None],
shift_ix_a[:, None],
]
sub_up_shifted_temp_index_b = self.upsampled_shifted_template_index_b[
torch.arange(len(self.upsampled_shifted_template_index_b))[:, None],
shift_ix_b[:, None],
]

# in flat form for indexing into pconv_index. also, reindex.
valid_a = sub_shifted_temp_index_a < n_shifted_temps_a
shifted_temp_ixs_a, new_shifted_temp_ixs_a = torch.unique(
sub_shifted_temp_index_a[valid_a], return_inverse=True
)
valid_b = sub_up_shifted_temp_index_b < n_up_shifted_temps_b
up_shifted_temp_ixs_b, new_up_shifted_temp_ixs_b = torch.unique(
sub_up_shifted_temp_index_b[valid_b], return_inverse=True
)

# get relevant pconv subset and reindex
sub_pconv_indices, new_pconv_indices = torch.unique(
self.pconv_index[
shifted_temp_ixs_a[:, None],
up_shifted_temp_ixs_b.ravel()[None, :],
],
return_inverse=True,
)
if self.in_memory:
sub_pconv = self.pconv[sub_pconv_indices.to(self.pconv.device)]
else:
sub_pconv = torch.from_numpy(batched_h5_read(self.pconv, sub_pconv_indices))
if device is not None:
sub_pconv = sub_pconv.to(device)

# reindexing
n_sub_shifted_temps_a = len(shifted_temp_ixs_a)
n_sub_up_shifted_temps_b = len(up_shifted_temp_ixs_b)
sub_pconv_index = new_pconv_indices.view(
n_sub_shifted_temps_a, n_sub_up_shifted_temps_b
)
sub_shifted_temp_index_a[valid_a] = new_shifted_temp_ixs_a
sub_up_shifted_temp_index_b[valid_b] = new_up_shifted_temp_ixs_b

return self.__class__(
shifts_a=torch.zeros(1),
shifts_b=torch.zeros(1),
shifted_template_index_a=sub_shifted_temp_index_a,
upsampled_shifted_template_index_b=sub_up_shifted_temp_index_b,
pconv_index=sub_pconv_index,
pconv=sub_pconv,
in_memory=True,
device=self.device,
)
# def at_shifts(self, shifts_a=None, shifts_b=None, device=None):
# """Subset this database to one set of shifts.

# The database becomes shiftless (not in the pejorative sense).
# """
# if shifts_a is None or shifts_b is None:
# assert shifts_a is shifts_b
# assert self.shifts_a.shape == (1,)
# assert self.shifts_b.shape == (1,)
# return self

# assert shifts_a.shape == (len(self.shifted_template_index_a),)
# assert shifts_b.shape == (len(self.upsampled_shifted_template_index_b),)
# n_shifted_temps_a, n_up_shifted_temps_b = self.pconv_index.shape

# # active shifted and upsampled indices
# shift_ix_a = self.get_shift_ix_a(shifts_a)
# shift_ix_b = self.get_shift_ix_b(shifts_b)
# sub_shifted_temp_index_a = self.shifted_template_index_a[
# torch.arange(len(self.shifted_template_index_a))[:, None],
# shift_ix_a[:, None],
# ]
# sub_up_shifted_temp_index_b = self.upsampled_shifted_template_index_b[
# torch.arange(len(self.upsampled_shifted_template_index_b))[:, None],
# shift_ix_b[:, None],
# ]

# # in flat form for indexing into pconv_index. also, reindex.
# valid_a = sub_shifted_temp_index_a < n_shifted_temps_a
# shifted_temp_ixs_a, new_shifted_temp_ixs_a = torch.unique(
# sub_shifted_temp_index_a[valid_a], return_inverse=True
# )
# valid_b = sub_up_shifted_temp_index_b < n_up_shifted_temps_b
# up_shifted_temp_ixs_b, new_up_shifted_temp_ixs_b = torch.unique(
# sub_up_shifted_temp_index_b[valid_b], return_inverse=True
# )

# # get relevant pconv subset and reindex
# sub_pconv_indices, new_pconv_indices = torch.unique(
# self.pconv_index[
# shifted_temp_ixs_a[:, None],
# up_shifted_temp_ixs_b.ravel()[None, :],
# ],
# return_inverse=True,
# )
# if self.in_memory:
# sub_pconv = self.pconv[sub_pconv_indices.to(self.pconv.device)]
# else:
# sub_pconv = torch.from_numpy(batched_h5_read(self.pconv, sub_pconv_indices))
# if device is not None:
# sub_pconv = sub_pconv.to(device)

# # reindexing
# n_sub_shifted_temps_a = len(shifted_temp_ixs_a)
# n_sub_up_shifted_temps_b = len(up_shifted_temp_ixs_b)
# sub_pconv_index = new_pconv_indices.view(
# n_sub_shifted_temps_a, n_sub_up_shifted_temps_b
# )
# sub_shifted_temp_index_a[valid_a] = new_shifted_temp_ixs_a
# sub_up_shifted_temp_index_b[valid_b] = new_up_shifted_temp_ixs_b

# return self.__class__(
# shifts_a=torch.zeros(1),
# shifts_b=torch.zeros(1),
# shifted_template_index_a=sub_shifted_temp_index_a,
# upsampled_shifted_template_index_b=sub_up_shifted_temp_index_b,
# pconv_index=sub_pconv_index,
# pconv=sub_pconv,
# in_memory=True,
# device=self.device,
# )

def to(self, device=None, incl_pconv=False, pin=False):
"""Become torch tensors on device."""
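A note on the now-disabled `at_shifts`: its core move was `torch.unique(..., return_inverse=True)`, which compacts the sparse set of pairwise-conv rows that survive a shift subset into a dense, reindexed table. The trick in isolation, on toy values:

import torch

# suppose only these pconv rows remain referenced after fixing the shifts
referenced = torch.tensor([[7, 2], [7, 9], [2, 9]])

# unique row ids, plus the same table rewritten over 0..n_unique-1
uniq, new_ids = torch.unique(referenced, return_inverse=True)
print(uniq)     # tensor([2, 7, 9])
print(new_ids)  # tensor([[1, 0], [1, 2], [0, 2]])

# gather the needed rows once; new_ids then indexes the compacted array
pconv = torch.randn(10, 121)
sub_pconv = pconv[uniq]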
13 changes: 8 additions & 5 deletions tests/test_matching.py
@@ -154,12 +154,11 @@ def test_tiny_up(tmp_path, up_factor=8):

# spike train
# fmt: off
# start = 50
start = 50
# tclu = []
# for i in range(up_factor):
# tclu.extend((start + 200 * i, 0, 0, i))
# tclu.extend((start + 1 + 200 * i, 0, 1, i))
tclu = [50, 0, 0, 0]
tclu = [50, 0, 0, 7]
# fmt: on
times, channels, labels, upsampling_indices = np.array(tclu).reshape(-1, 4).T
rec = np.zeros((recording_length_samples, n_channels), dtype="float32")
@@ -256,8 +255,8 @@ def test_tiny_up(tmp_path, up_factor=8):
print(f'{res["n_spikes"]=} {len(times)=}')
print(f"{cupts.compressed_upsampled_templates.ptp(1).max(1)=}")
print(f'{res["collisioncleaned_waveforms"].numpy(force=True).ptp(1).max(1)=}')
print(f'{np.c_[res["times_samples"], res["labels"]]=}')
print(f"{np.c_[times, labels]=}")
print(f'{np.c_[res["times_samples"], res["labels"], res["upsampling_indices"]]=}')
print(f"{np.c_[times, labels, upsampling_indices]=}")
print(f'{torch.square(res["residual"]).mean()=}')
print(f'{torch.square(res["conv"]).mean()=}')
assert res["n_spikes"] == len(times)
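
The test now plants its one spike with upsampling index 7 (of `up_factor=8`) and prints recovered versus true `upsampling_indices`, exercising the sub-sample alignment that `fine_match` resolves. As a reminder of what an upsampling index means, a toy bank of sub-sample-offset templates (linear interpolation is just a stand-in for whatever upsampling the real code uses):

import numpy as np

up_factor = 8
t = np.arange(121.0)
template = np.exp(-0.5 * ((t - 60) / 5.0) ** 2)  # toy spike shape

# Upsample by 8, then take every 8th sample at offset k:
# bank[k] is the template sampled at sub-sample offset k / 8.
fine = np.interp(np.arange(0, 121, 1 / up_factor), t, template)
bank = np.stack([fine[k::up_factor][:120] for k in range(up_factor)])
print(bank.shape)  # (8, 120)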
@@ -601,20 +600,24 @@ def drifting_tester(tmp_path, up_factor=1):
import tempfile
from pathlib import Path

print("\n"*5)
print("test tiny")
with tempfile.TemporaryDirectory() as tdir:
test_tiny(Path(tdir))

print("\n"*5)
print("test tiny_up")
with tempfile.TemporaryDirectory() as tdir:
test_tiny_up(Path(tdir))

print()
print("\n"*5)
print("test test_static_noup")
with tempfile.TemporaryDirectory() as tdir:
test_static_noup(Path(tdir))

print()
print("\n"*5)
print("test test_static_up")
with tempfile.TemporaryDirectory() as tdir:
test_static_up(Path(tdir))
