Make universal scale to actual case
cwindolf committed Dec 20, 2024
1 parent 0cbdd00 commit 76e3e10
Showing 7 changed files with 217 additions and 224 deletions.
67 changes: 20 additions & 47 deletions src/dartsort/peel/matching.py
@@ -36,6 +36,7 @@ def __init__(
         channel_index,
         featurization_pipeline,
         motion_est=None,
+        pairwise_conv_db=None,
         svd_compression_rank=10,
         coarse_objective=True,
         temporal_upsampling_factor=8,
@@ -76,6 +77,7 @@ def __init__(

         # main properties
         self.template_data = template_data
+        self.pairwise_conv_db = pairwise_conv_db
         self.coarse_objective = coarse_objective
         self.temporal_upsampling_factor = temporal_upsampling_factor
         self.upsampling_peak_window_radius = upsampling_peak_window_radius
@@ -320,21 +322,22 @@ def build_template_data(
         chunk_centers_s = self.recording._recording_segments[0].sample_index_to_time(
             chunk_centers_samples
         )
-        self.pairwise_conv_db = CompressedPairwiseConv.from_template_data(
-            save_folder / "pconv.h5",
-            template_data=objective_temp_data,
-            low_rank_templates=objective_low_rank_temps,
-            template_data_b=template_data,
-            low_rank_templates_b=low_rank_templates,
-            compressed_upsampled_temporal=compressed_upsampled_temporal,
-            chunk_time_centers_s=chunk_centers_s,
-            motion_est=self.motion_est,
-            geom=self.geom,
-            overwrite=overwrite,
-            conv_ignore_threshold=self.conv_ignore_threshold,
-            coarse_approx_error_threshold=self.coarse_approx_error_threshold,
-            computation_config=computation_config,
-        )
+        if self.pairwise_conv_db is None:
+            self.pairwise_conv_db = CompressedPairwiseConv.from_template_data(
+                save_folder / "pconv.h5",
+                template_data=objective_temp_data,
+                low_rank_templates=objective_low_rank_temps,
+                template_data_b=template_data,
+                low_rank_templates_b=low_rank_templates,
+                compressed_upsampled_temporal=compressed_upsampled_temporal,
+                chunk_time_centers_s=chunk_centers_s,
+                motion_est=self.motion_est,
+                geom=self.geom,
+                overwrite=overwrite,
+                conv_ignore_threshold=self.conv_ignore_threshold,
+                coarse_approx_error_threshold=self.coarse_approx_error_threshold,
+                computation_config=computation_config,
+            )

         self.fixed_output_data += [
             ("temporal_components", temporal_components),
@@ -868,12 +871,7 @@ def subtract_conv(
         )
         ix_template = template_indices_a[:, None]
         ix_time = times_sub[:, None] + (conv_pad_len + self.conv_lags)[None, :]
-        spiketorch.add_at_(
-            conv,
-            (ix_template, ix_time),
-            pconvs,
-            sign=-1,
-        )
+        spiketorch.add_at_(conv, (ix_template, ix_time), pconvs, sign=-1)

     def fine_match(
         self,
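spiketorch.add_at_ is an in-place scatter-add; with sign=-1 it subtracts each detected spike's pairwise convolution from the running objective at the broadcast (template, time) indices. A rough equivalent in plain torch, assuming add_at_ wraps index_put_-style accumulation (an assumption, not dartsort's actual implementation):

    import torch

    conv = torch.zeros(3, 10)                       # (n_templates, n_times)
    ix_template = torch.tensor([[0], [2]])          # (n_spikes, 1)
    ix_time = torch.tensor([[1, 2, 3], [4, 5, 6]])  # (n_spikes, n_lags)
    pconvs = torch.ones(2, 3)                       # pairwise conv values

    # accumulate=True sums contributions at repeated indices, which a
    # plain indexed assignment would silently overwrite
    conv.index_put_((ix_template, ix_time), -pconvs, accumulate=True)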
@@ -931,6 +929,7 @@ def fine_match(
         # )

         if self.coarse_objective:
+            assert superres_index is not None
             # TODO best I came up with, but it still syncs
             superres_ix = superres_index[objective_template_indices]
             dup_ix, column_ix = (superres_ix < self.n_templates).nonzero(as_tuple=True)
@@ -940,12 +939,6 @@
                 snips[dup_ix],
                 self.spatial_singular[template_indices].mT,
             ).sum((1, 2))
-            # convs = torch.einsum(
-            #     "jtc,jrc,jtr->j",
-            #     snips[dup_ix],
-            #     self.spatial_singular[template_indices],
-            #     self.temporal_components[template_indices],
-            # )
             norms = self.template_norms_squared[template_indices]
             objs = torch.full(superres_ix.shape, -torch.inf, device=convs.device)
             objs[dup_ix, column_ix] = 2 * convs - norms
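The 2 * convs - norms objective is the decrease in residual energy from subtracting a template: for a snippet S and template T, ||S||^2 - ||S - T||^2 = 2<S, T> - ||T||^2, so the best match maximizes twice the convolution minus the template's squared norm. Entries of objs with no corresponding superres template stay at -inf and can never win the argmax. A quick numerical check of the identity:

    import torch

    S, T = torch.randn(121, 40), torch.randn(121, 40)
    gain = S.square().sum() - (S - T).square().sum()
    obj = 2 * (S * T).sum() - T.square().sum()
    assert torch.allclose(gain, obj, atol=1e-4)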
@@ -980,33 +973,13 @@
                 comp_up_ix < self.n_compressed_upsampled_templates
             ).nonzero(as_tuple=True)
             comp_up_indices = comp_up_ix[dup_ix, column_ix]
-            # convs = torch.einsum(
-            #     "jtcd,jrc,jtr->jd",
-            #     snips_dt[dup_ix],
-            #     self.spatial_singular[template_indices[dup_ix]],
-            #     self.compressed_upsampled_temporal[comp_up_indices],
-            # )
             temps = torch.bmm(
                 self.compressed_upsampled_temporal[comp_up_indices],
                 self.spatial_singular[template_indices[dup_ix]],
             ).view(len(comp_up_indices), -1)
             convs = torch.linalg.vecdot(snips[dup_ix].view(len(temps), -1), temps)
             convs_prev = torch.linalg.vecdot(snips_prev[dup_ix].view(len(temps), -1), temps)

-            # convs_r = torch.round(convs).to(int).numpy()
-            # convs_prev_r = torch.round(convs_prev).to(int).numpy()
-            # convs = torch.einsum(
-            #     "jtc,jrc,jtr->j",
-            #     snips[dup_ix],
-            #     self.spatial_singular[template_indices[dup_ix]],
-            #     self.compressed_upsampled_temporal[comp_up_indices],
-            # )
-            # convs_prev = torch.einsum(
-            #     "jtc,jrc,jtr->j",
-            #     snips_prev[dup_ix],
-            #     self.spatial_singular[template_indices[dup_ix]],
-            #     self.compressed_upsampled_temporal[comp_up_indices],
-            # )
             better = convs >= convs_prev
             convs = torch.maximum(convs, convs_prev)

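The bmm/vecdot pair computes the same contraction as the deleted einsum comments: with temporal factors of shape (n, t, r) and spatial factors of shape (n, r, c), torch.bmm rebuilds the low-rank templates as (n, t, c), and flattening plus torch.linalg.vecdot then matches einsum("jtc,jrc,jtr->j", ...), usually through faster batched-matmul kernels. A self-contained equivalence check with stand-in tensors:

    import torch

    n, t, r, c = 5, 121, 8, 40
    snips = torch.randn(n, t, c)
    temporal = torch.randn(n, t, r)  # stands in for compressed_upsampled_temporal
    spatial = torch.randn(n, r, c)   # stands in for spatial_singular

    temps = torch.bmm(temporal, spatial).view(n, -1)
    convs = torch.linalg.vecdot(snips.view(n, -1), temps)
    ref = torch.einsum("jtc,jrc,jtr->j", snips, spatial, temporal)
    assert torch.allclose(convs, ref, atol=1e-4)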
2 changes: 1 addition & 1 deletion src/dartsort/peel/peel_base.py
@@ -204,7 +204,7 @@ def peel(
             results,
             total=n_chunks_orig,
             initial=n_chunks_orig - len(chunks_to_do),
-            smoothing=0.01,
+            smoothing=0,
             desc=f"{task_name} {n_sec_chunk:.1f}s/it [spk/it=%%%]",
         )

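tqdm's smoothing parameter sets the exponential moving average used for the displayed rate: 1 weights only the most recent iteration, while 0 reports the overall average since the start. With long, uneven chunks, smoothing=0 keeps the s/it readout and ETA from swinging on every slow chunk:

    from tqdm import tqdm
    import time

    # smoothing=0: displayed rate = total iterations / total elapsed time
    for _ in tqdm(range(100), smoothing=0):
        time.sleep(0.01)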
45 changes: 25 additions & 20 deletions src/dartsort/peel/universal.py
@@ -3,6 +3,7 @@
 from ..util import universal_util, waveform_util
 from ..transform import WaveformPipeline
 from .matching import ObjectiveUpdateTemplateMatchingPeeler
+from ..templates.pairwise import SeparablePairwiseConv


 class UniversalTemplatesMatchingPeeler(ObjectiveUpdateTemplateMatchingPeeler):
@@ -49,32 +50,36 @@ def __init__(
         fit_sampling="random",
         dtype=torch.float,
     ):
-        template_data = universal_util.universal_templates_from_data(
-            rec=recording,
-            detection_threshold=detection_threshold,
-            trough_offset_samples=trough_offset_samples,
-            spike_length_samples=spike_length_samples,
-            alignment_padding=alignment_padding,
-            n_centroids=n_centroids,
-            pca_rank=pca_rank,
-            n_waveforms_fit=n_waveforms_fit,
-            taper=taper,
-            taper_start=alignment_padding // 2,
-            taper_end=alignment_padding // 2,
-            random_seed=fit_subsampling_random_state,
-            n_sigmas=n_sigmas,
-            min_template_size=min_template_size,
-            max_distance=max_distance,
-            dx=dx,
-            # let's not worry about exposing these
-            deduplication_radius=150.0,
-            kmeanspp_initial="random",
+        shapes, footprints, template_data = (
+            universal_util.universal_templates_from_data(
+                rec=recording,
+                detection_threshold=detection_threshold,
+                trough_offset_samples=trough_offset_samples,
+                spike_length_samples=spike_length_samples,
+                alignment_padding=alignment_padding,
+                n_centroids=n_centroids,
+                pca_rank=pca_rank,
+                n_waveforms_fit=n_waveforms_fit,
+                taper=taper,
+                taper_start=alignment_padding // 2,
+                taper_end=alignment_padding // 2,
+                random_seed=fit_subsampling_random_state,
+                n_sigmas=n_sigmas,
+                min_template_size=min_template_size,
+                max_distance=max_distance,
+                dx=dx,
+                # let's not worry about exposing these
+                deduplication_radius=150.0,
+                kmeanspp_initial="random",
+            )
         )
+        pairwise_conv_db = SeparablePairwiseConv(footprints, shapes)
         super().__init__(
             recording,
             template_data,
             channel_index,
             featurization_pipeline,
+            pairwise_conv_db=pairwise_conv_db,
             threshold=threshold,
             amplitude_scaling_variance=amplitude_scaling_variance,
             amplitude_scaling_boundary=amplitude_scaling_boundary,
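SeparablePairwiseConv can exploit the structure of the universal templates: each is an outer product of a temporal shape and a spatial footprint, so a pairwise convolution factorizes into a footprint inner product times a single 1-D correlation of shapes, rather than a dense spatiotemporal convolution per template pair. A sketch of the factorization being relied on (illustrative only; the class's actual API is not part of this diff):

    import torch
    import torch.nn.functional as F

    t, c = 121, 40
    shape_a, shape_b = torch.randn(t), torch.randn(t)  # temporal shapes
    foot_a, foot_b = torch.randn(c), torch.randn(c)    # spatial footprints
    temp_a = shape_a[:, None] * foot_a[None, :]        # rank-1 template, (t, c)
    temp_b = shape_b[:, None] * foot_b[None, :]

    # dense route: correlate the templates channel by channel, sum channels
    full = F.conv1d(
        temp_a.T[None], temp_b.T[:, None], padding=t - 1, groups=c
    ).sum(1)[0]

    # separable route: one 1-D correlation, scaled by the footprint overlap
    sep = (foot_a @ foot_b) * F.conv1d(
        shape_a[None, None], shape_b[None, None], padding=t - 1
    )[0, 0]

    assert torch.allclose(full, sep, atol=1e-3)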
(diffs for the remaining 4 changed files not shown)
