Skip to content

Commit

Permalink
Sorta functional
Browse files Browse the repository at this point in the history
  • Loading branch information
cwindolf committed Nov 28, 2023
1 parent 4b1d1cb commit e7eb19f
Show file tree
Hide file tree
Showing 9 changed files with 77 additions and 234 deletions.
2 changes: 2 additions & 0 deletions src/dartsort/localize/localize_torch.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,8 @@ def localize_amplitude_vectors(
geom_pad = F.pad(geom, (0, 0, 0, 1))
local_geoms = geom_pad[channel_index[main_channels]]
local_geoms[:, :, 1] -= geom[main_channels, 1][:, None]
print(f"{amplitude_vectors.shape=}")
print(f"{local_geoms.shape=}")

# center of mass initialization
com = torch.divide(
Expand Down
2 changes: 1 addition & 1 deletion src/dartsort/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,7 +169,7 @@ def _run_peeler(
)

# do localization
if featurization_config.do_localization:
if not featurization_config.denoise_only and featurization_config.do_localization:
wf_name = featurization_config.output_waveforms_name
localize_hdf5(
output_hdf5_filename,
Expand Down
103 changes: 53 additions & 50 deletions src/dartsort/peel/matching.py
Original file line number Diff line number Diff line change
Expand Up @@ -285,11 +285,16 @@ def build_template_data(
self.handle_template_groups(
coarse_template_data.unit_ids, self.template_data.unit_ids
)
convlen = self.chunk_length_samples + self.chunk_margin_samples
block_size, *_ = spiketorch._calc_oa_lens(convlen, self.spike_length_samples)
self.register_buffer("objective_temporalf", torch.fft.rfft(self.objective_temporal_components, dim=1, n=block_size))

half_chunk = self.chunk_length_samples // 2
chunk_centers_samples = np.arange(
half_chunk, self.recording.get_num_samples(), self.chunk_length_samples
chunk_starts = np.arange(
0, self.recording.get_num_samples(), self.chunk_length_samples
)
chunk_ends = np.minimum(chunk_starts + self.chunk_length_samples, self.recording.get_num_samples())
chunk_centers_samples = (chunk_starts + chunk_ends) / 2
chunk_centers_s = self.recording._recording_segments[0].sample_index_to_time(
chunk_centers_samples
)
Expand Down Expand Up @@ -392,6 +397,7 @@ def peel_chunk(
left_margin=0,
right_margin=0,
return_residual=False,
return_conv=False,
):
# get current template set
chunk_center_samples = chunk_start_samples + self.chunk_length_samples // 2
Expand All @@ -404,11 +410,12 @@ def peel_chunk(
match_results = self.match_chunk(
traces,
compressed_template_data,
trough_offset_samples=42,
left_margin=0,
right_margin=0,
threshold=30,
trough_offset_samples=self.trough_offset_samples,
left_margin=left_margin,
right_margin=right_margin,
threshold=self.threshold,
return_residual=return_residual,
return_conv=return_conv,
)

# process spike times and create return result
Expand Down Expand Up @@ -459,15 +466,17 @@ def templates_at_time(self, t_s):
pconvdb = pconvdb.at_shifts(pitch_shifts_a, pitch_shifts_b)
else:
cur_spatial = self.spatial_components
cur_obj_spatial = self.objective_spatial_components
max_channels = self.registered_template_ampvecs.argmax(1)

if not pconvdb._is_torch:
pconvdb.to(cur_spatial.device)
pconvdb.to(cur_obj_spatial.device)

return MatchingTemplateData(
objective_spatial_components=cur_obj_spatial,
objective_singular_values=self.objective_singular_values,
objective_temporal_components=self.objective_temporal_components,
objective_temporalf=self.objective_temporalf,
fine_to_coarse=self.fine_to_coarse,
coarse_objective=self.coarse_objective,
spatial_components=cur_spatial,
Expand All @@ -477,7 +486,7 @@ def templates_at_time(self, t_s):
compressed_upsampling_index=self.compressed_upsampling_index,
compressed_index_to_upsampling_index=self.compressed_index_to_upsampling_index,
compressed_upsampled_temporal=self.compressed_upsampled_temporal,
max_channels=max_channels,
max_channels=torch.as_tensor(max_channels, device=cur_obj_spatial.device),
pairwise_conv_db=pconvdb,
)

Expand All @@ -490,12 +499,12 @@ def match_chunk(
right_margin=0,
threshold=30,
return_residual=False,
return_conv=False,
):
"""Core peeling routine for subtraction"""
# initialize residual, it needs to be padded to support our channel
# indexing convention (used later to extract small channel
# neighborhoods). this copies the input.
print(f"match_chunk {traces.shape=}")
residual_padded = F.pad(traces, (0, 1), value=torch.nan)
residual = residual_padded[:, :-1]

Expand All @@ -517,7 +526,6 @@ def match_chunk(
refrac_mask = torch.zeros_like(padded_objective)
# padded objective has an extra unit (for group_index) and refractory
# padding (for easier implementation of enforce_refractory)
neg_temp_normsq = -compressed_template_data.objective_template_norms_squared[:, None]

# manages buffers for spike train data (peak times, labels, etc)
peaks = MatchingPeaks(device=traces.device)
Expand All @@ -528,28 +536,13 @@ def match_chunk(
)

# main loop
print("start")
for it in range(self.max_iter):
# update the coarse objective
torch.add(
neg_temp_normsq, padded_conv, alpha=2.0, out=padded_objective[:-1]
)

# find high-res peaks
print("before find")
new_peaks = self.find_peaks(
residual, padded_conv, padded_objective, refrac_mask, compressed_template_data
)
if new_peaks is None:
break
# print("----------")
# if not it % 1:
# if new_peaks.n_spikes > 1:
# print(
# f"{it=} {new_peaks.n_spikes=} {new_peaks.scores.mean().numpy(force=True)=} {torch.diff(new_peaks.times).min()=}"
# )
# tq = new_peaks.times.numpy(force=True)
# print(f"{np.diff(tq).min()=} {tq=}")

# enforce refractoriness
self.enforce_refractory(
Expand Down Expand Up @@ -579,7 +572,8 @@ def match_chunk(
# new_norm = torch.linalg.norm(residual) ** 2
# print(f"{it=} {new_norm=}")
# print(f"{(new_norm-old_norm)=}")
# print(f"{new_peaks.scores.sum().numpy(force=True)=}")
# print(f"{new_peaks.n_spikes=}")
# print(f"{new_peaks.scores.mean().numpy(force=True)=}")
# print("----------")

# update spike train
Expand Down Expand Up @@ -607,6 +601,8 @@ def match_chunk(
)
if return_residual:
res["residual"] = residual
if return_conv:
res["conv"] = padded_conv
return res

def find_peaks(
Expand All @@ -617,15 +613,21 @@ def find_peaks(
refrac_mask,
compressed_template_data,
):
# update the coarse objective
torch.add(
compressed_template_data.objective_template_norms_squared.neg()[:, None],
padded_conv,
alpha=2.0,
out=padded_objective[:-1],
)

# first step: coarse peaks. not temporally upsampled or amplitude-scaled.
objective = (padded_objective + refrac_mask)[
:-1, self.obj_pad_len : -self.obj_pad_len
]
# formerly used detect_and_deduplicate, but that was slow.
objective_max, max_obj_template = objective.max(dim=0)
times = argrelmax(objective_max, self.spike_length_samples, self.threshold)
# tt = times.numpy(force=True)
# print(f"{np.diff(tt).min()=} {tt=}")
obj_template_indices = max_obj_template[times]
# remove peaks inside the padding
if not times.numel():
Expand All @@ -648,7 +650,7 @@ def find_peaks(
template_indices,
scores,
) = compressed_template_data.fine_match(
padded_conv[obj_template_indices, times],
padded_conv[obj_template_indices, times + self.obj_pad_len],
objective_max[times],
residual_snips,
obj_template_indices,
Expand All @@ -659,7 +661,7 @@ def find_peaks(
)
if time_shifts is not None:
times += time_shifts

return MatchingPeaks(
n_spikes=times.numel(),
times=times,
Expand Down Expand Up @@ -697,6 +699,7 @@ class MatchingTemplateData:
objective_spatial_components: torch.Tensor
objective_singular_values: torch.Tensor
objective_temporal_components: torch.Tensor
objective_temporalf: torch.Tensor
fine_to_coarse: torch.LongTensor
coarse_objective: bool
spatial_components: torch.Tensor
Expand Down Expand Up @@ -758,24 +761,23 @@ def convolve(self, traces, padding=0, out=None):
)
else:
assert out.shape == (self.objective_n_templates, out_len)
print(f"convolve {traces.shape=} {out.shape=} {out_len=} {padding=}")

for q in range(self.rank):
# units x time
rec_spatial = self.objective_spatial_singular[:, q, :] @ traces.T
# convolve with temporal components -- units x time
temporal = self.objective_temporal_components[:, :, q]
temporalf = self.objective_temporalf[:, :, q]
# conv1d with groups! only convolve each unit with its own temporal filter.
# conv = F.conv1d(
# rec_spatial[None],
# temporal[:, None, :],
# groups=self.n_templates,
# padding=padding,
# )[0]
conv = spiketorch.depthwise_oaconv1d(
rec_spatial, temporal[:, :], padding=padding
)
print(f"{conv.shape=}")
conv = F.conv1d(
rec_spatial[None],
temporal[:, None, :],
groups=self.objective_n_templates,
padding=padding,
)[0]
# conv = spiketorch.depthwise_oaconv1d(
# rec_spatial, temporal, padding=padding, f2=temporalf
# )
if q:
out += conv
else:
Expand Down Expand Up @@ -803,7 +805,6 @@ def subtract_conv(
)
ix_template = template_indices_a[:, None]
ix_time = times[:, None] + (conv_pad_len + self.conv_lags)[None, :]
print(f"{ix_template.shape=} {ix_time.shape=} {scalings.shape=} {pconvs.shape=}")
spiketorch.add_at_(
conv,
(ix_template, ix_time),
Expand Down Expand Up @@ -858,13 +859,12 @@ def fine_match(
spike_length_samples = window_length_samples - 1
# grab the current traces
snips = residual_snips[:, 1:]
# unpack the current traces and the traces one step back
snips_dt = F.unfold(
residual_snips[:, None, :, :], (spike_length_samples, snips.shape[2])
)
snips_dt = snips_dt.reshape(
len(snips), spike_length_samples, snips.shape[2], 2
)
# snips_dt = F.unfold(
# residual_snips[:, None, :, :], (spike_length_samples, snips.shape[2])
# )
# snips_dt = snips_dt.reshape(
# len(snips), spike_length_samples, snips.shape[2], 2
# )

if self.coarse_objective:
# TODO best I came up with, but it still syncs
Expand Down Expand Up @@ -900,6 +900,10 @@ def fine_match(
objs = 2 * scalings * b - torch.square(scalings) * a - inv_lambda
return None, None, scalings, template_indices, objs

# unpack the current traces and the traces one step back
snips_prev = residual_snips[:, :-1]
snips_dt = torch.stack((snips_prev, snips), dim=3)

# now, upsampling
# repeat the superres logic, the comp up index acts the same
comp_up_ix = self.compressed_upsampling_index[template_indices]
Expand All @@ -924,7 +928,6 @@ def fine_match(
2 * scalings * b - torch.square(scalings) * a - inv_lambda
)
else:
print(f"{objs.shape=} {objs[dup_ix, column_ix].shape=} {convs.shape=} {norms.shape=}")
objs[dup_ix, column_ix] = 2 * convs - norms[:, None]
scalings = None
objs, best_column_dt_ix = objs.reshape(len(objs), comp_up_ix.shape[1] * 2).max(dim=1)
Expand Down
19 changes: 1 addition & 18 deletions src/dartsort/templates/pairwise.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,6 @@ def from_template_data(
n_jobs=0,
show_progress=True,
):
print(f"pairwise from_template_data {device=}")
compressed_convolve_to_h5(
hdf5_filename,
template_data=template_data,
Expand Down Expand Up @@ -124,22 +123,9 @@ def at_shifts(self, shifts_a=None, shifts_b=None):
n_shifted_temps_a, n_up_shifted_temps_b = self.pconv_index.shape

# active shifted and upsampled indices
print(f"{self.shifts_a=} {shifts_a.min()=} {shifts_a.max()=}")
shift_ix_a = np.searchsorted(self.shifts_a, shifts_a)
shift_ix_b = np.searchsorted(self.shifts_b, shifts_b)
print(
f"at_shifts {self.shifts_a.shape=} {self.shifts_a.min()=} {self.shifts_a.max()=}"
)
print(f"at_shifts {shifts_a.shape=} {shifts_a.min()=} {shifts_a.max()=}")
print(f"{shift_ix_a.shape=} {shift_ix_a.min()=} {shift_ix_a.max()=}")

print(
f"at_shifts {self.shifts_b.shape=} {self.shifts_b.min()=} {self.shifts_b.max()=}"
)
print(f"at_shifts {shifts_b.shape=} {shifts_b.min()=} {shifts_b.max()=}")
print(f"at_shifts {shift_ix_b.shape=} {shift_ix_b.min()=} {shift_ix_b.max()=}")

print(f"at_shifts {self.shifted_template_index_a.shape=}")
print(f"at_shifts {self.upsampled_shifted_template_index_b.shape=}")
sub_shifted_temp_index_a = self.shifted_template_index_a[
np.arange(len(self.shifted_template_index_a))[:, None],
shift_ix_a[:, None],
Expand All @@ -148,8 +134,6 @@ def at_shifts(self, shifts_a=None, shifts_b=None):
np.arange(len(self.upsampled_shifted_template_index_b))[:, None],
shift_ix_b[:, None],
]
print(f"at_shifts {sub_shifted_temp_index_a.shape=}")
print(f"at_shifts {sub_up_shifted_temp_index_b.shape=}")

# in flat form for indexing into pconv_index. also, reindex.
valid_a = sub_shifted_temp_index_a < n_shifted_temps_a
Expand Down Expand Up @@ -262,7 +246,6 @@ def query(
template_indices_a, template_indices_b
).T
if scalings_b is not None:
print(f"{scalings_b.shape=} {pconv_indices.shape=}")
scalings_b = torch.broadcast_to(scalings_b[None], pconv_indices.shape).reshape(-1)
if times_b is not None:
times_b = torch.broadcast_to(times_b[None], pconv_indices.shape).reshape(-1)
Expand Down
Loading

0 comments on commit e7eb19f

Please sign in to comment.