From 76e3e106079131f7ef0087246d321a9f6e5bb548 Mon Sep 17 00:00:00 2001 From: Charlie Windolf Date: Fri, 20 Dec 2024 11:56:38 -0500 Subject: [PATCH] Make universal scale to actual case --- src/dartsort/peel/matching.py | 67 ++--- src/dartsort/peel/peel_base.py | 2 +- src/dartsort/peel/universal.py | 45 ++-- src/dartsort/templates/pairwise.py | 303 ++++++++++++----------- src/dartsort/transform/temporal_pca.py | 18 +- src/dartsort/transform/transform_base.py | 4 +- src/dartsort/util/universal_util.py | 2 + 7 files changed, 217 insertions(+), 224 deletions(-) diff --git a/src/dartsort/peel/matching.py b/src/dartsort/peel/matching.py index a5480d8a..6bcc4249 100644 --- a/src/dartsort/peel/matching.py +++ b/src/dartsort/peel/matching.py @@ -36,6 +36,7 @@ def __init__( channel_index, featurization_pipeline, motion_est=None, + pairwise_conv_db=None, svd_compression_rank=10, coarse_objective=True, temporal_upsampling_factor=8, @@ -76,6 +77,7 @@ def __init__( # main properties self.template_data = template_data + self.pairwise_conv_db = pairwise_conv_db self.coarse_objective = coarse_objective self.temporal_upsampling_factor = temporal_upsampling_factor self.upsampling_peak_window_radius = upsampling_peak_window_radius @@ -320,21 +322,22 @@ def build_template_data( chunk_centers_s = self.recording._recording_segments[0].sample_index_to_time( chunk_centers_samples ) - self.pairwise_conv_db = CompressedPairwiseConv.from_template_data( - save_folder / "pconv.h5", - template_data=objective_temp_data, - low_rank_templates=objective_low_rank_temps, - template_data_b=template_data, - low_rank_templates_b=low_rank_templates, - compressed_upsampled_temporal=compressed_upsampled_temporal, - chunk_time_centers_s=chunk_centers_s, - motion_est=self.motion_est, - geom=self.geom, - overwrite=overwrite, - conv_ignore_threshold=self.conv_ignore_threshold, - coarse_approx_error_threshold=self.coarse_approx_error_threshold, - computation_config=computation_config, - ) + if 
self.pairwise_conv_db is None: + self.pairwise_conv_db = CompressedPairwiseConv.from_template_data( + save_folder / "pconv.h5", + template_data=objective_temp_data, + low_rank_templates=objective_low_rank_temps, + template_data_b=template_data, + low_rank_templates_b=low_rank_templates, + compressed_upsampled_temporal=compressed_upsampled_temporal, + chunk_time_centers_s=chunk_centers_s, + motion_est=self.motion_est, + geom=self.geom, + overwrite=overwrite, + conv_ignore_threshold=self.conv_ignore_threshold, + coarse_approx_error_threshold=self.coarse_approx_error_threshold, + computation_config=computation_config, + ) self.fixed_output_data += [ ("temporal_components", temporal_components), @@ -868,12 +871,7 @@ def subtract_conv( ) ix_template = template_indices_a[:, None] ix_time = times_sub[:, None] + (conv_pad_len + self.conv_lags)[None, :] - spiketorch.add_at_( - conv, - (ix_template, ix_time), - pconvs, - sign=-1, - ) + spiketorch.add_at_(conv, (ix_template, ix_time), pconvs, sign=-1) def fine_match( self, @@ -931,6 +929,7 @@ def fine_match( # ) if self.coarse_objective: + assert superres_index is not None # TODO best I came up with, but it still syncs superres_ix = superres_index[objective_template_indices] dup_ix, column_ix = (superres_ix < self.n_templates).nonzero(as_tuple=True) @@ -940,12 +939,6 @@ def fine_match( snips[dup_ix], self.spatial_singular[template_indices].mT, ).sum((1, 2)) - # convs = torch.einsum( - # "jtc,jrc,jtr->j", - # snips[dup_ix], - # self.spatial_singular[template_indices], - # self.temporal_components[template_indices], - # ) norms = self.template_norms_squared[template_indices] objs = torch.full(superres_ix.shape, -torch.inf, device=convs.device) objs[dup_ix, column_ix] = 2 * convs - norms @@ -980,12 +973,6 @@ def fine_match( comp_up_ix < self.n_compressed_upsampled_templates ).nonzero(as_tuple=True) comp_up_indices = comp_up_ix[dup_ix, column_ix] - # convs = torch.einsum( - # "jtcd,jrc,jtr->jd", - # snips_dt[dup_ix], - # 
self.spatial_singular[template_indices[dup_ix]], - # self.compressed_upsampled_temporal[comp_up_indices], - # ) temps = torch.bmm( self.compressed_upsampled_temporal[comp_up_indices], self.spatial_singular[template_indices[dup_ix]], @@ -993,20 +980,6 @@ def fine_match( convs = torch.linalg.vecdot(snips[dup_ix].view(len(temps), -1), temps) convs_prev = torch.linalg.vecdot(snips_prev[dup_ix].view(len(temps), -1), temps) - # convs_r = torch.round(convs).to(int).numpy() - # convs_prev_r = torch.round(convs_prev).to(int).numpy() - # convs = torch.einsum( - # "jtc,jrc,jtr->j", - # snips[dup_ix], - # self.spatial_singular[template_indices[dup_ix]], - # self.compressed_upsampled_temporal[comp_up_indices], - # ) - # convs_prev = torch.einsum( - # "jtc,jrc,jtr->j", - # snips_prev[dup_ix], - # self.spatial_singular[template_indices[dup_ix]], - # self.compressed_upsampled_temporal[comp_up_indices], - # ) better = convs >= convs_prev convs = torch.maximum(convs, convs_prev) diff --git a/src/dartsort/peel/peel_base.py b/src/dartsort/peel/peel_base.py index 8bb7dcd5..237e5b15 100644 --- a/src/dartsort/peel/peel_base.py +++ b/src/dartsort/peel/peel_base.py @@ -204,7 +204,7 @@ def peel( results, total=n_chunks_orig, initial=n_chunks_orig - len(chunks_to_do), - smoothing=0.01, + smoothing=0, desc=f"{task_name} {n_sec_chunk:.1f}s/it [spk/it=%%%]", ) diff --git a/src/dartsort/peel/universal.py b/src/dartsort/peel/universal.py index 803d58e7..9fad8bf8 100644 --- a/src/dartsort/peel/universal.py +++ b/src/dartsort/peel/universal.py @@ -3,6 +3,7 @@ from ..util import universal_util, waveform_util from ..transform import WaveformPipeline from .matching import ObjectiveUpdateTemplateMatchingPeeler +from ..templates.pairwise import SeparablePairwiseConv class UniversalTemplatesMatchingPeeler(ObjectiveUpdateTemplateMatchingPeeler): @@ -49,32 +50,36 @@ def __init__( fit_sampling="random", dtype=torch.float, ): - template_data = universal_util.universal_templates_from_data( - rec=recording, - 
detection_threshold=detection_threshold, - trough_offset_samples=trough_offset_samples, - spike_length_samples=spike_length_samples, - alignment_padding=alignment_padding, - n_centroids=n_centroids, - pca_rank=pca_rank, - n_waveforms_fit=n_waveforms_fit, - taper=taper, - taper_start=alignment_padding // 2, - taper_end=alignment_padding // 2, - random_seed=fit_subsampling_random_state, - n_sigmas=n_sigmas, - min_template_size=min_template_size, - max_distance=max_distance, - dx=dx, - # let's not worry about exposing these - deduplication_radius=150.0, - kmeanspp_initial="random", + shapes, footprints, template_data = ( + universal_util.universal_templates_from_data( + rec=recording, + detection_threshold=detection_threshold, + trough_offset_samples=trough_offset_samples, + spike_length_samples=spike_length_samples, + alignment_padding=alignment_padding, + n_centroids=n_centroids, + pca_rank=pca_rank, + n_waveforms_fit=n_waveforms_fit, + taper=taper, + taper_start=alignment_padding // 2, + taper_end=alignment_padding // 2, + random_seed=fit_subsampling_random_state, + n_sigmas=n_sigmas, + min_template_size=min_template_size, + max_distance=max_distance, + dx=dx, + # let's not worry about exposing these + deduplication_radius=150.0, + kmeanspp_initial="random", + ) ) + pairwise_conv_db = SeparablePairwiseConv(footprints, shapes) super().__init__( recording, template_data, channel_index, featurization_pipeline, + pairwise_conv_db=pairwise_conv_db, threshold=threshold, amplitude_scaling_variance=amplitude_scaling_variance, amplitude_scaling_boundary=amplitude_scaling_boundary, diff --git a/src/dartsort/templates/pairwise.py b/src/dartsort/templates/pairwise.py index 478c28a2..abf62a7d 100644 --- a/src/dartsort/templates/pairwise.py +++ b/src/dartsort/templates/pairwise.py @@ -4,6 +4,7 @@ import h5py import numpy as np import torch +import torch.nn.functional as F from .pairwise_util import compressed_convolve_to_h5 from .template_util import 
CompressedUpsampledTemplates, LowRankTemplates @@ -56,6 +57,106 @@ class CompressedPairwiseConv: in_memory: bool = False device: torch.device = torch.device("cpu") + def query( + self, + template_indices_a, + template_indices_b, + upsampling_indices_b=None, + shifts_a=None, + shifts_b=None, + scalings_b=None, + times_b=None, + return_zero_convs=False, + grid=False, + device=None, + ): + if template_indices_a is None: + template_indices_a = torch.arange( + len(self.shifted_template_index_a), device=self.device + ) + template_indices_a = torch.atleast_1d(torch.as_tensor(template_indices_a)) + template_indices_b = torch.atleast_1d(torch.as_tensor(template_indices_b)) + + # handle no shifting + no_shifting = shifts_a is None or shifts_b is None + shifted_template_index = self.shifted_template_index_a + upsampled_shifted_template_index = self.upsampled_shifted_template_index_b + if no_shifting: + assert shifts_a is None and shifts_b is None + assert self.shifts_a.shape == (1,) + assert self.shifts_b.shape == (1,) + a_ix = (template_indices_a,) + b_ix = (template_indices_b,) + shifted_template_index = shifted_template_index[:, 0] + upsampled_shifted_template_index = upsampled_shifted_template_index[:, 0] + else: + shift_indices_a = self._get_shift_ix_a(shifts_a) + shift_indices_b = self._get_shift_ix_a(shifts_b) + a_ix = (template_indices_a, shift_indices_a) + b_ix = (template_indices_b, shift_indices_b) + + # handle no upsampling + no_upsampling = upsampling_indices_b is None + if no_upsampling: + assert self.upsampled_shifted_template_index_b.shape[2] == 1 + upsampled_shifted_template_index = upsampled_shifted_template_index[..., 0] + else: + b_ix = b_ix + (torch.atleast_1d(torch.as_tensor(upsampling_indices_b)),) + + # get shifted template indices for A + shifted_temp_ix_a = shifted_template_index[a_ix] + + # upsampled shifted template indices for B + up_shifted_temp_ix_b = upsampled_shifted_template_index[b_ix] + + # return convolutions between all ai,bj or just 
ai,bi? + if grid: + pconv_indices = self.pconv_index[ + shifted_temp_ix_a[:, None], up_shifted_temp_ix_b[None, :] + ] + template_indices_a, template_indices_b = torch.cartesian_prod( + template_indices_a, template_indices_b + ).T + if scalings_b is not None: + scalings_b = torch.broadcast_to( + scalings_b[None], pconv_indices.shape + ).reshape(-1) + if times_b is not None: + times_b = torch.broadcast_to( + times_b[None], pconv_indices.shape + ).reshape(-1) + pconv_indices = pconv_indices.view(-1) + else: + pconv_indices = self.pconv_index[shifted_temp_ix_a, up_shifted_temp_ix_b] + + # most users will be happy not to get a bunch of zeros for pairs that don't overlap + if not return_zero_convs: + which = pconv_indices > 0 + pconv_indices = pconv_indices[which] + template_indices_a = template_indices_a[which] + template_indices_b = template_indices_b[which] + if scalings_b is not None: + scalings_b = scalings_b[which] + if times_b is not None: + times_b = times_b[which] + + if self.in_memory: + pconvs = self.pconv[pconv_indices.to(self.pconv.device)] + else: + pconvs = torch.from_numpy( + batched_h5_read(self.pconv, pconv_indices.numpy(force=True)) + ) + if device is not None: + pconvs = pconvs.to(device) + + if scalings_b is not None: + pconvs.mul_(scalings_b[:, None]) + + if times_b is not None: + return template_indices_a, template_indices_b, times_b, pconvs + + return template_indices_a, template_indices_b, pconvs + def __post_init__(self): assert self.shifts_a.ndim == self.shifts_b.ndim == 1 assert self.shifts_a.shape == (self.shifted_template_index_a.shape[1],) @@ -69,7 +170,7 @@ def __post_init__(self): self.shifts_b ) - def get_shift_ix_a(self, shifts_a): + def _get_shift_ix_a(self, shifts_a): """Map shift (an integer, signed) to a shift index A shift index can be used to index into axis=1 of shifted_template_index_a, @@ -173,6 +274,44 @@ def to(self, device=None, incl_pconv=False, pin=False): # self.pconv = self.pconv.pin_memory() return self + +class 
SeparablePairwiseConv(torch.nn.Module): + def __init__(self, spatial_footprints, temporal_shapes): + """Footprint-major rank 1 template convolution database + + Let Nf = len(spatial_footprints), Ns = len(temporal_shapes). Then + indexing is footprint-major, so that + + template[i] = spatial_footprints[i // Nf] * temporal_shapes[i - Nf * (i // Nf)] + + Let f(i) = i // Nf and s(i) = i - Nf * (i // Nf). Then the channel-summed + convolution of templates i and j is given by + + conv(t; i, j) = ( + dot(spatial_footprints[f(i)], spatial_footprints[f(j)]) + * conv1d(temporal_shapes[s(i)], temporal_shapes[s(j)])[t] + ) + + Note: need to be consistent with interpretation of the sign of the time lag + between here and CompressedPairwiseConv. + """ + self.register_buffer("spatial_footprints", torch.asarray(spatial_footprints)) + self.register_buffer("temporal_shapes", torch.asarray(temporal_shapes)) + self.Nf = len(spatial_footprints) + + # convolve all pairs of temporal shapes + # i is data, j is filter + nt = temporal_shapes.shape[1] + inp = self.temporal_shapes[:, None, :] + fil = self.temporal_shapes[:, None, :] + # Ns, Ns, 2*nt - 1 + self.register_buffer("tconv", F.conv1d(inp, fil, padding=nt - 1)) + + # spatial component + sdot = self.spatial_footprints @ self.spatial_footprints.T + self.register_buffer("sdot", sdot) + self.tia = torch.arange(len(temporal_shapes)) + def query( self, template_indices_a, @@ -186,84 +325,24 @@ def query( grid=False, device=None, ): + if device is not None and device != self.spatial_footprints.device: + self.to(device) + assert shifts_a is shifts_b is None + assert upsampling_indices_b is None + del return_zero_convs # choose not to implement this + assert grid # only this case here. can probably do the same above.
if template_indices_a is None: - template_indices_a = torch.arange( - len(self.shifted_template_index_a), device=self.device - ) - template_indices_a = torch.atleast_1d(torch.as_tensor(template_indices_a)) - template_indices_b = torch.atleast_1d(torch.as_tensor(template_indices_b)) - - # handle no shifting - no_shifting = shifts_a is None or shifts_b is None - shifted_template_index = self.shifted_template_index_a - upsampled_shifted_template_index = self.upsampled_shifted_template_index_b - if no_shifting: - assert shifts_a is None and shifts_b is None - assert self.shifts_a.shape == (1,) - assert self.shifts_b.shape == (1,) - a_ix = (template_indices_a,) - b_ix = (template_indices_b,) - shifted_template_index = shifted_template_index[:, 0] - upsampled_shifted_template_index = upsampled_shifted_template_index[:, 0] - else: - shift_indices_a = self.get_shift_ix_a(shifts_a) - shift_indices_b = self.get_shift_ix_a(shifts_b) - a_ix = (template_indices_a, shift_indices_a) - b_ix = (template_indices_b, shift_indices_b) + template_indices_a = self.tia.to(template_indices_b) - # handle no upsampling - no_upsampling = upsampling_indices_b is None - if no_upsampling: - assert self.upsampled_shifted_template_index_b.shape[2] == 1 - upsampled_shifted_template_index = upsampled_shifted_template_index[..., 0] - else: - b_ix = b_ix + (torch.atleast_1d(torch.as_tensor(upsampling_indices_b)),) + f_i = template_indices_b // self.Nf + f_j = template_indices_a // self.Nf + s_i = template_indices_b - self.Nf * f_i + s_j = template_indices_a - self.Nf * f_j - # get shifted template indices for A - shifted_temp_ix_a = shifted_template_index[a_ix] - - # upsampled shifted template indices for B - up_shifted_temp_ix_b = upsampled_shifted_template_index[b_ix] - - # return convolutions between all ai,bj or just ai,bi? 
- if grid: - pconv_indices = self.pconv_index[ - shifted_temp_ix_a[:, None], up_shifted_temp_ix_b[None, :] - ] - template_indices_a, template_indices_b = torch.cartesian_prod( - template_indices_a, template_indices_b - ).T - if scalings_b is not None: - scalings_b = torch.broadcast_to( - scalings_b[None], pconv_indices.shape - ).reshape(-1) - if times_b is not None: - times_b = torch.broadcast_to( - times_b[None], pconv_indices.shape - ).reshape(-1) - pconv_indices = pconv_indices.view(-1) - else: - pconv_indices = self.pconv_index[shifted_temp_ix_a, up_shifted_temp_ix_b] - - # most users will be happy not to get a bunch of zeros for pairs that don't overlap - if not return_zero_convs: - which = pconv_indices > 0 - pconv_indices = pconv_indices[which] - template_indices_a = template_indices_a[which] - template_indices_b = template_indices_b[which] - if scalings_b is not None: - scalings_b = scalings_b[which] - if times_b is not None: - times_b = times_b[which] - - if self.in_memory: - pconvs = self.pconv[pconv_indices.to(self.pconv.device)] - else: - pconvs = torch.from_numpy( - batched_h5_read(self.pconv, pconv_indices.numpy(force=True)) - ) - if device is not None: - pconvs = pconvs.to(device) + pconvs = ( + self.sdot[f_i[:, None], f_j[None, :]] + * self.tconv[s_i[:, None], s_j[None, :]] + ) if scalings_b is not None: pconvs.mul_(scalings_b[:, None]) @@ -273,78 +352,6 @@ def query( return template_indices_a, template_indices_b, pconvs - # def at_shifts(self, shifts_a=None, shifts_b=None, device=None): - # """Subset this database to one set of shifts. - - # The database becomes shiftless (not in the pejorative sense). 
- # """ - # if shifts_a is None or shifts_b is None: - # assert shifts_a is shifts_b - # assert self.shifts_a.shape == (1,) - # assert self.shifts_b.shape == (1,) - # return self - - # assert shifts_a.shape == (len(self.shifted_template_index_a),) - # assert shifts_b.shape == (len(self.upsampled_shifted_template_index_b),) - # n_shifted_temps_a, n_up_shifted_temps_b = self.pconv_index.shape - - # # active shifted and upsampled indices - # shift_ix_a = self.get_shift_ix_a(shifts_a) - # shift_ix_b = self.get_shift_ix_b(shifts_b) - # sub_shifted_temp_index_a = self.shifted_template_index_a[ - # torch.arange(len(self.shifted_template_index_a))[:, None], - # shift_ix_a[:, None], - # ] - # sub_up_shifted_temp_index_b = self.upsampled_shifted_template_index_b[ - # torch.arange(len(self.upsampled_shifted_template_index_b))[:, None], - # shift_ix_b[:, None], - # ] - - # # in flat form for indexing into pconv_index. also, reindex. - # valid_a = sub_shifted_temp_index_a < n_shifted_temps_a - # shifted_temp_ixs_a, new_shifted_temp_ixs_a = torch.unique( - # sub_shifted_temp_index_a[valid_a], return_inverse=True - # ) - # valid_b = sub_up_shifted_temp_index_b < n_up_shifted_temps_b - # up_shifted_temp_ixs_b, new_up_shifted_temp_ixs_b = torch.unique( - # sub_up_shifted_temp_index_b[valid_b], return_inverse=True - # ) - - # # get relevant pconv subset and reindex - # sub_pconv_indices, new_pconv_indices = torch.unique( - # self.pconv_index[ - # shifted_temp_ixs_a[:, None], - # up_shifted_temp_ixs_b.ravel()[None, :], - # ], - # return_inverse=True, - # ) - # if self.in_memory: - # sub_pconv = self.pconv[sub_pconv_indices.to(self.pconv.device)] - # else: - # sub_pconv = torch.from_numpy(batched_h5_read(self.pconv, sub_pconv_indices)) - # if device is not None: - # sub_pconv = sub_pconv.to(device) - - # # reindexing - # n_sub_shifted_temps_a = len(shifted_temp_ixs_a) - # n_sub_up_shifted_temps_b = len(up_shifted_temp_ixs_b) - # sub_pconv_index = new_pconv_indices.view( - # 
n_sub_shifted_temps_a, n_sub_up_shifted_temps_b - # ) - # sub_shifted_temp_index_a[valid_a] = new_shifted_temp_ixs_a - # sub_up_shifted_temp_index_b[valid_b] = new_up_shifted_temp_ixs_b - - # return self.__class__( - # shifts_a=torch.zeros(1), - # shifts_b=torch.zeros(1), - # shifted_template_index_a=sub_shifted_temp_index_a, - # upsampled_shifted_template_index_b=sub_up_shifted_temp_index_b, - # pconv_index=sub_pconv_index, - # pconv=sub_pconv, - # in_memory=True, - # device=self.device, - # ) - def _get_shift_indexer(shifts): assert torch.equal(shifts, torch.sort(shifts).values) diff --git a/src/dartsort/transform/temporal_pca.py b/src/dartsort/transform/temporal_pca.py index 6f2043bb..10be33f5 100644 --- a/src/dartsort/transform/temporal_pca.py +++ b/src/dartsort/transform/temporal_pca.py @@ -2,12 +2,18 @@ import torch from sklearn.decomposition import PCA -from dartsort.util.waveform_util import (channel_subset_by_radius, - get_channels_in_probe, - set_channels_in_probe) - -from .transform_base import (BaseWaveformAutoencoder, BaseWaveformDenoiser, - BaseWaveformFeaturizer, BaseWaveformModule) +from dartsort.util.waveform_util import ( + channel_subset_by_radius, + get_channels_in_probe, + set_channels_in_probe, +) + +from .transform_base import ( + BaseWaveformAutoencoder, + BaseWaveformDenoiser, + BaseWaveformFeaturizer, + BaseWaveformModule, +) class BaseTemporalPCA(BaseWaveformModule): diff --git a/src/dartsort/transform/transform_base.py b/src/dartsort/transform/transform_base.py index d34c32b9..772d6791 100644 --- a/src/dartsort/transform/transform_base.py +++ b/src/dartsort/transform/transform_base.py @@ -43,7 +43,7 @@ def extra_repr(self): class BaseWaveformDenoiser(BaseWaveformModule): is_denoiser = True - def forward(self, waveforms, max_channels=None): + def forward(self, waveforms, max_channels): raise NotImplementedError @@ -55,7 +55,7 @@ class BaseWaveformFeaturizer(BaseWaveformModule): # output dtye dtype = torch.float - def transform(self, 
waveforms, max_channels=None): + def transform(self, waveforms, max_channels): # returns dict {key=feat name, value=feature} raise NotImplementedError diff --git a/src/dartsort/util/universal_util.py b/src/dartsort/util/universal_util.py index 0df269f6..57773434 100644 --- a/src/dartsort/util/universal_util.py +++ b/src/dartsort/util/universal_util.py @@ -193,6 +193,8 @@ def singlechan_to_library( nsct, nt = singlechan_templates.shape templates = footprints[:, None, None, :] * singlechan_templates[None, :, :, None] assert templates.shape == (nf, nsct, nt, nc) + + # note: this is footprint-major templates = templates.reshape(nf * nsct, nt, nc) templates /= np.linalg.norm(templates, axis=(1, 2))