From e7eb19f880ef6bde4056d4f46eb93bc4173d2439 Mon Sep 17 00:00:00 2001 From: Charlie Windolf Date: Tue, 28 Nov 2023 17:38:05 -0500 Subject: [PATCH] Sorta functional --- src/dartsort/localize/localize_torch.py | 2 + src/dartsort/main.py | 2 +- src/dartsort/peel/matching.py | 103 +++++++++-------- src/dartsort/templates/pairwise.py | 19 +--- src/dartsort/templates/pairwise_util.py | 145 +++--------------------- src/dartsort/templates/template_util.py | 5 - src/dartsort/templates/templates.py | 8 +- src/dartsort/util/drift_util.py | 21 +--- src/dartsort/util/spiketorch.py | 6 +- 9 files changed, 77 insertions(+), 234 deletions(-) diff --git a/src/dartsort/localize/localize_torch.py b/src/dartsort/localize/localize_torch.py index 56a35024..42741034 100644 --- a/src/dartsort/localize/localize_torch.py +++ b/src/dartsort/localize/localize_torch.py @@ -103,6 +103,8 @@ def localize_amplitude_vectors( geom_pad = F.pad(geom, (0, 0, 0, 1)) local_geoms = geom_pad[channel_index[main_channels]] local_geoms[:, :, 1] -= geom[main_channels, 1][:, None] + print(f"{amplitude_vectors.shape=}") + print(f"{local_geoms.shape=}") # center of mass initialization com = torch.divide( diff --git a/src/dartsort/main.py b/src/dartsort/main.py index b441cc52..f05a328c 100644 --- a/src/dartsort/main.py +++ b/src/dartsort/main.py @@ -169,7 +169,7 @@ def _run_peeler( ) # do localization - if featurization_config.do_localization: + if not featurization_config.denoise_only and featurization_config.do_localization: wf_name = featurization_config.output_waveforms_name localize_hdf5( output_hdf5_filename, diff --git a/src/dartsort/peel/matching.py b/src/dartsort/peel/matching.py index 48abc064..24e56859 100644 --- a/src/dartsort/peel/matching.py +++ b/src/dartsort/peel/matching.py @@ -285,11 +285,16 @@ def build_template_data( self.handle_template_groups( coarse_template_data.unit_ids, self.template_data.unit_ids ) + convlen = self.chunk_length_samples + self.chunk_margin_samples + block_size, *_ = spiketorch._calc_oa_lens(convlen, self.spike_length_samples) + self.register_buffer("objective_temporalf", torch.fft.rfft(self.objective_temporal_components, dim=1, n=block_size)) half_chunk = self.chunk_length_samples // 2 - chunk_centers_samples = np.arange( - half_chunk, self.recording.get_num_samples(), self.chunk_length_samples + chunk_starts = np.arange( + 0, self.recording.get_num_samples(), self.chunk_length_samples ) + chunk_ends = np.minimum(chunk_starts + self.chunk_length_samples, self.recording.get_num_samples()) + chunk_centers_samples = (chunk_starts + chunk_ends) / 2 chunk_centers_s = self.recording._recording_segments[0].sample_index_to_time( chunk_centers_samples ) @@ -392,6 +397,7 @@ def peel_chunk( left_margin=0, right_margin=0, return_residual=False, + return_conv=False, ): # get current template set chunk_center_samples = chunk_start_samples + self.chunk_length_samples // 2 @@ -404,11 +410,12 @@ def peel_chunk( match_results = self.match_chunk( traces, compressed_template_data, - trough_offset_samples=42, - left_margin=0, - right_margin=0, - threshold=30, + trough_offset_samples=self.trough_offset_samples, + left_margin=left_margin, + right_margin=right_margin, + threshold=self.threshold, return_residual=return_residual, + return_conv=return_conv, ) # process spike times and create return result @@ -459,15 +466,17 @@ def templates_at_time(self, t_s): pconvdb = pconvdb.at_shifts(pitch_shifts_a, pitch_shifts_b) else: cur_spatial = self.spatial_components + cur_obj_spatial = self.objective_spatial_components 
max_channels = self.registered_template_ampvecs.argmax(1) if not pconvdb._is_torch: - pconvdb.to(cur_spatial.device) + pconvdb.to(cur_obj_spatial.device) return MatchingTemplateData( objective_spatial_components=cur_obj_spatial, objective_singular_values=self.objective_singular_values, objective_temporal_components=self.objective_temporal_components, + objective_temporalf=self.objective_temporalf, fine_to_coarse=self.fine_to_coarse, coarse_objective=self.coarse_objective, spatial_components=cur_spatial, @@ -477,7 +486,7 @@ def templates_at_time(self, t_s): compressed_upsampling_index=self.compressed_upsampling_index, compressed_index_to_upsampling_index=self.compressed_index_to_upsampling_index, compressed_upsampled_temporal=self.compressed_upsampled_temporal, - max_channels=max_channels, + max_channels=torch.as_tensor(max_channels, device=cur_obj_spatial.device), pairwise_conv_db=pconvdb, ) @@ -490,12 +499,12 @@ def match_chunk( right_margin=0, threshold=30, return_residual=False, + return_conv=False, ): """Core peeling routine for subtraction""" # initialize residual, it needs to be padded to support our channel # indexing convention (used later to extract small channel # neighborhoods). this copies the input. - print(f"match_chunk {traces.shape=}") residual_padded = F.pad(traces, (0, 1), value=torch.nan) residual = residual_padded[:, :-1] @@ -517,7 +526,6 @@ def match_chunk( refrac_mask = torch.zeros_like(padded_objective) # padded objective has an extra unit (for group_index) and refractory # padding (for easier implementation of enforce_refractory) - neg_temp_normsq = -compressed_template_data.objective_template_norms_squared[:, None] # manages buffers for spike train data (peak times, labels, etc) peaks = MatchingPeaks(device=traces.device) @@ -528,28 +536,13 @@ def match_chunk( ) # main loop - print("start") for it in range(self.max_iter): - # update the coarse objective - torch.add( - neg_temp_normsq, padded_conv, alpha=2.0, out=padded_objective[:-1] - ) - # find high-res peaks - print("before find") new_peaks = self.find_peaks( residual, padded_conv, padded_objective, refrac_mask, compressed_template_data ) if new_peaks is None: break - # print("----------") - # if not it % 1: - # if new_peaks.n_spikes > 1: - # print( - # f"{it=} {new_peaks.n_spikes=} {new_peaks.scores.mean().numpy(force=True)=} {torch.diff(new_peaks.times).min()=}" - # ) - # tq = new_peaks.times.numpy(force=True) - # print(f"{np.diff(tq).min()=} {tq=}") # enforce refractoriness self.enforce_refractory( @@ -579,7 +572,8 @@ def match_chunk( # new_norm = torch.linalg.norm(residual) ** 2 # print(f"{it=} {new_norm=}") # print(f"{(new_norm-old_norm)=}") - # print(f"{new_peaks.scores.sum().numpy(force=True)=}") + # print(f"{new_peaks.n_spikes=}") + # print(f"{new_peaks.scores.mean().numpy(force=True)=}") # print("----------") # update spike train @@ -607,6 +601,8 @@ def match_chunk( ) if return_residual: res["residual"] = residual + if return_conv: + res["conv"] = padded_conv return res def find_peaks( @@ -617,6 +613,14 @@ def find_peaks( refrac_mask, compressed_template_data, ): + # update the coarse objective + torch.add( + compressed_template_data.objective_template_norms_squared.neg()[:, None], + padded_conv, + alpha=2.0, + out=padded_objective[:-1], + ) + # first step: coarse peaks. not temporally upsampled or amplitude-scaled. objective = (padded_objective + refrac_mask)[ :-1, self.obj_pad_len : -self.obj_pad_len @@ -624,8 +628,6 @@ def find_peaks( # formerly used detect_and_deduplicate, but that was slow. 
objective_max, max_obj_template = objective.max(dim=0) times = argrelmax(objective_max, self.spike_length_samples, self.threshold) - # tt = times.numpy(force=True) - # print(f"{np.diff(tt).min()=} {tt=}") obj_template_indices = max_obj_template[times] # remove peaks inside the padding if not times.numel(): @@ -648,7 +650,7 @@ def find_peaks( template_indices, scores, ) = compressed_template_data.fine_match( - padded_conv[obj_template_indices, times], + padded_conv[obj_template_indices, times + self.obj_pad_len], objective_max[times], residual_snips, obj_template_indices, @@ -659,7 +661,7 @@ def find_peaks( ) if time_shifts is not None: times += time_shifts - + return MatchingPeaks( n_spikes=times.numel(), times=times, @@ -697,6 +699,7 @@ class MatchingTemplateData: objective_spatial_components: torch.Tensor objective_singular_values: torch.Tensor objective_temporal_components: torch.Tensor + objective_temporalf: torch.Tensor fine_to_coarse: torch.LongTensor coarse_objective: bool spatial_components: torch.Tensor @@ -758,24 +761,23 @@ def convolve(self, traces, padding=0, out=None): ) else: assert out.shape == (self.objective_n_templates, out_len) - print(f"convolve {traces.shape=} {out.shape=} {out_len=} {padding=}") for q in range(self.rank): # units x time rec_spatial = self.objective_spatial_singular[:, q, :] @ traces.T # convolve with temporal components -- units x time temporal = self.objective_temporal_components[:, :, q] + temporalf = self.objective_temporalf[:, :, q] # conv1d with groups! only convolve each unit with its own temporal filter. - # conv = F.conv1d( - # rec_spatial[None], - # temporal[:, None, :], - # groups=self.n_templates, - # padding=padding, - # )[0] - conv = spiketorch.depthwise_oaconv1d( - rec_spatial, temporal[:, :], padding=padding - ) - print(f"{conv.shape=}") + conv = F.conv1d( + rec_spatial[None], + temporal[:, None, :], + groups=self.objective_n_templates, + padding=padding, + )[0] + # conv = spiketorch.depthwise_oaconv1d( + # rec_spatial, temporal, padding=padding, f2=temporalf + # ) if q: out += conv else: @@ -803,7 +805,6 @@ def subtract_conv( ) ix_template = template_indices_a[:, None] ix_time = times[:, None] + (conv_pad_len + self.conv_lags)[None, :] - print(f"{ix_template.shape=} {ix_time.shape=} {scalings.shape=} {pconvs.shape=}") spiketorch.add_at_( conv, (ix_template, ix_time), @@ -858,13 +859,12 @@ def fine_match( spike_length_samples = window_length_samples - 1 # grab the current traces snips = residual_snips[:, 1:] - # unpack the current traces and the traces one step back - snips_dt = F.unfold( - residual_snips[:, None, :, :], (spike_length_samples, snips.shape[2]) - ) - snips_dt = snips_dt.reshape( - len(snips), spike_length_samples, snips.shape[2], 2 - ) + # snips_dt = F.unfold( + # residual_snips[:, None, :, :], (spike_length_samples, snips.shape[2]) + # ) + # snips_dt = snips_dt.reshape( + # len(snips), spike_length_samples, snips.shape[2], 2 + # ) if self.coarse_objective: # TODO best I came up with, but it still syncs @@ -900,6 +900,10 @@ def fine_match( objs = 2 * scalings * b - torch.square(scalings) * a - inv_lambda return None, None, scalings, template_indices, objs + # unpack the current traces and the traces one step back + snips_prev = residual_snips[:, :-1] + snips_dt = torch.stack((snips_prev, snips), dim=3) + # now, upsampling # repeat the superres logic, the comp up index acts the same comp_up_ix = self.compressed_upsampling_index[template_indices] @@ -924,7 +928,6 @@ def fine_match( 2 * scalings * b - torch.square(scalings) 
* a - inv_lambda ) else: - print(f"{objs.shape=} {objs[dup_ix, column_ix].shape=} {convs.shape=} {norms.shape=}") objs[dup_ix, column_ix] = 2 * convs - norms[:, None] scalings = None objs, best_column_dt_ix = objs.reshape(len(objs), comp_up_ix.shape[1] * 2).max(dim=1) diff --git a/src/dartsort/templates/pairwise.py b/src/dartsort/templates/pairwise.py index 67f689da..eba3dd86 100644 --- a/src/dartsort/templates/pairwise.py +++ b/src/dartsort/templates/pairwise.py @@ -86,7 +86,6 @@ def from_template_data( n_jobs=0, show_progress=True, ): - print(f"pairwise from_template_data {device=}") compressed_convolve_to_h5( hdf5_filename, template_data=template_data, @@ -124,22 +123,9 @@ def at_shifts(self, shifts_a=None, shifts_b=None): n_shifted_temps_a, n_up_shifted_temps_b = self.pconv_index.shape # active shifted and upsampled indices + print(f"{self.shifts_a=} {shifts_a.min()=} {shifts_a.max()=}") shift_ix_a = np.searchsorted(self.shifts_a, shifts_a) shift_ix_b = np.searchsorted(self.shifts_b, shifts_b) - print( - f"at_shifts {self.shifts_a.shape=} {self.shifts_a.min()=} {self.shifts_a.max()=}" - ) - print(f"at_shifts {shifts_a.shape=} {shifts_a.min()=} {shifts_a.max()=}") - print(f"{shift_ix_a.shape=} {shift_ix_a.min()=} {shift_ix_a.max()=}") - - print( - f"at_shifts {self.shifts_b.shape=} {self.shifts_b.min()=} {self.shifts_b.max()=}" - ) - print(f"at_shifts {shifts_b.shape=} {shifts_b.min()=} {shifts_b.max()=}") - print(f"at_shifts {shift_ix_b.shape=} {shift_ix_b.min()=} {shift_ix_b.max()=}") - - print(f"at_shifts {self.shifted_template_index_a.shape=}") - print(f"at_shifts {self.upsampled_shifted_template_index_b.shape=}") sub_shifted_temp_index_a = self.shifted_template_index_a[ np.arange(len(self.shifted_template_index_a))[:, None], shift_ix_a[:, None], @@ -148,8 +134,6 @@ def at_shifts(self, shifts_a=None, shifts_b=None): np.arange(len(self.upsampled_shifted_template_index_b))[:, None], shift_ix_b[:, None], ] - print(f"at_shifts {sub_shifted_temp_index_a.shape=}") - print(f"at_shifts {sub_up_shifted_temp_index_b.shape=}") # in flat form for indexing into pconv_index. also, reindex. valid_a = sub_shifted_temp_index_a < n_shifted_temps_a @@ -262,7 +246,6 @@ def query( template_indices_a, template_indices_b ).T if scalings_b is not None: - print(f"{scalings_b.shape=} {pconv_indices.shape=}") scalings_b = torch.broadcast_to(scalings_b[None], pconv_indices.shape).reshape(-1) if times_b is not None: times_b = torch.broadcast_to(times_b[None], pconv_indices.shape).reshape(-1) diff --git a/src/dartsort/templates/pairwise_util.py b/src/dartsort/templates/pairwise_util.py index d20e7a4c..9f0f45eb 100644 --- a/src/dartsort/templates/pairwise_util.py +++ b/src/dartsort/templates/pairwise_util.py @@ -67,7 +67,6 @@ def compressed_convolve_to_h5( template_data_b=template_data_b, motion_est=motion_est, ) - print(f"compressed_convolve_to_h5 {conv_batch_size=} {units_batch_size=} {device=}") chunk_res_iterator = iterate_compressed_pairwise_convolutions( template_data_a=template_data, @@ -193,9 +192,6 @@ def iterate_compressed_pairwise_convolutions( assert np.array_equal(reg_geom, template_data_b.registered_geom) # construct drift-related helper data if needed - print( - f"iterate_compressed_pairwise_convolutions {conv_batch_size=} {units_batch_size=} {device=}" - ) geom_kdtree = reg_geom_kdtree = match_distance = None if do_shifting: geom_kdtree = KDTree(geom) @@ -315,12 +311,6 @@ def compressed_convolve_pairs( shifts, superres templates, and upsamples. 
Some of these may be zero or may be duplicates, so the return value is a sparse representation. See below. """ - # print(f"compressed_convolve_pairs {device=}") - # print(f"{units_a.shape=}") - # print(f"{units_b.shape=}") - # print(f"{(units_a.size * units_b.size)=}") - # print(f"compressed_convolve_pairs {batch_size=} {units_a.size=} {device=}") - # what pairs, shifts, etc are we convolving? shifted_temp_ix_a, temp_ix_a, shift_a, unit_a = handle_shift_indices( units_a, template_data_a.unit_ids, template_shift_index_a @@ -328,8 +318,6 @@ def compressed_convolve_pairs( shifted_temp_ix_b, temp_ix_b, shift_b, unit_b = handle_shift_indices( units_b, template_data_b.unit_ids, template_shift_index_b ) - # print(f"0 {shifted_temp_ix_a.shape=} {(shifted_temp_ix_a.size / np.unique(unit_a).size)=}") - # print(f"0 {shifted_temp_ix_b.shape=} {(shifted_temp_ix_b.size / np.unique(unit_b).size)=}") # get (shifted) spatial components * singular values spatial_singular_a = get_shifted_spatial_singular( @@ -354,9 +342,6 @@ def compressed_convolve_pairs( match_distance=match_distance, device=device, ) - # print(f"{low_rank_templates.spatial_components.dtype=} {low_rank_templates.singular_values.dtype=}") - # print(f"{compressed_upsampled_temporal.compressed_upsampled_templates.dtype=}") - # print(f"{spatial_singular_a.dtype=} {spatial_singular_b.dtype=}") # figure out pairs of shifted templates to convolve in a deduplicated way pairs_ret = shift_deduplicated_pairs( @@ -378,17 +363,7 @@ def compressed_convolve_pairs( if pairs_ret is None: return None ix_a, ix_b, compression_index, conv_ix, spatial_shift_ids = pairs_ret - # print(f"A {ix_a.shape=}") - # print(f"A {ix_b.shape=}") - # print(f"A {compression_index.shape=}") - # print(f"A {conv_ix.shape=}") - # print(f"A {spatial_shift_ids.shape=}") - - # print(f"-----------") - # print(f"after pairs {conv_ix.shape=} {compression_index.shape=}") - # print(f"{compression_index.min()=} {compression_index.max()=}") - # print(f"{ix_a.shape=} {ix_b.shape=}") - + # handle upsampling # each pair will be duplicated by the b unit's number of upsampled copies ( @@ -409,30 +384,9 @@ def compressed_convolve_pairs( ) ix_a = ix_a[compression_dup_ix] spatial_shift_ids = spatial_shift_ids[compression_dup_ix] - # print(f"B {ix_a.shape=}") - # print(f"B {ix_b.shape=}") - # print(f"B {compression_index.shape=}") - # print(f"B {conv_ix.shape=}") - - # print(f"-----------") - # print(f"after up {conv_ix.shape=} {compression_index.shape=}") - # print(f"{compression_index.min()=} {compression_index.max()=}") - # print(f"{ix_a.shape=} {ix_b.shape=}") - - # # now, these arrays all have length n_pairs - # shifted_temp_ix_a = shifted_temp_ix_a[ix_a] - # temp_ix_a = temp_ix_a[ix_a] - # shift_a = shift_a[ix_a] - # shifted_temp_ix_b = shifted_temp_ix_b[ix_b] - # temp_ix_b = temp_ix_b[ix_b] - # shift_b = shift_b[ix_b] # run convolutions temporal_a = low_rank_templates_a.temporal_components[temp_ix_a] - # print(f"{spatial_singular_a[ix_a[conv_ix]].shape=}") - # print(f"{spatial_singular_b[ix_b[conv_ix]].shape=}") - # print(f"{temporal_a[ix_a[conv_ix]].shape=}") - # print(f"{conv_temporal_components_up_b.shape=}") pconv, kept = correlate_pairs_lowrank( torch.as_tensor(spatial_singular_a[ix_a[conv_ix]], device=device), torch.as_tensor(spatial_singular_b[ix_b[conv_ix]], device=device), @@ -442,8 +396,6 @@ def compressed_convolve_pairs( conv_ignore_threshold=conv_ignore_threshold, batch_size=batch_size, ) - # print(f"-----------") - # print(f"after corr {pconv.shape=} {conv_ix[kept].shape=}") if 
kept is not None: conv_ix = conv_ix[kept] if not conv_ix.shape[0]: @@ -455,29 +407,16 @@ def compressed_convolve_pairs( ix_b = ix_b[kept_pairs] spatial_shift_ids = spatial_shift_ids[kept_pairs] assert pconv.numel() > 0 - # print(f"-----------") - # print(f"after searchsorted {pconv.shape=} {conv_ix.shape=} {compression_index.shape=}") - # print(f"{compression_index.min()=} {compression_index.max()=}") - # print(f"{ix_a.shape=} {ix_b.shape=}") # coarse approx - # print(f"-----------") - # print(f"before approx {pconv.shape=} {conv_ix.shape=} {compression_index.shape=}") pconv, old_ix_to_new_ix = coarse_approximate( pconv, unit_a[ix_a[conv_ix]], unit_b[ix_b[conv_ix]], temp_ix_a[ix_a[conv_ix]], - # shift_a[ix_a[conv_ix]], - # shift_b[ix_b[conv_ix]], spatial_shift_ids[conv_ix], coarse_approx_error_threshold=coarse_approx_error_threshold, ) - # print(f"-----------") - # print(f"after approx") - # print(f"{pconv.shape=} {conv_ix.shape=} {old_ix_to_new_ix.shape=} {compression_index.shape=}") - # print(f"{compression_index.min()=} {compression_index.max()=}") - # print(f"{old_ix_to_new_ix.min()=} {old_ix_to_new_ix.max()=}") compression_index = old_ix_to_new_ix[compression_index] # above function invalidates the whole idea of conv_ix del conv_ix @@ -541,8 +480,7 @@ def correlate_pairs_lowrank( assert n_pairs == n_pairs_ assert t == t_ assert rank == rank_ - # print(f"compressed_convolve_pairs {batch_size=} {n_pairs=} {spatial_a.device=}") - + if max_shift == "full": max_shift = t - 1 elif max_shift == "valid": @@ -554,7 +492,6 @@ def correlate_pairs_lowrank( pconv = torch.zeros( (n_pairs, 2 * max_shift + 1), dtype=spatial_a.dtype, device=spatial_a.device ) - # print(f"compressed_convolve_pairs {pconv.shape=}") for istart in range(0, n_pairs, batch_size): iend = min(istart + batch_size, n_pairs) ix = slice(istart, iend) @@ -582,10 +519,8 @@ def correlate_pairs_lowrank( kept = max_val > conv_ignore_threshold pconv = pconv[kept] kept = np.flatnonzero(kept.numpy(force=True)) - # print(f"compressed_convolve_pairs {pconv.shape=} {kept.shape=}") else: kept = None - # print(f"compressed_convolve_pairs {pconv.shape=} {kept=}") return pconv, kept @@ -721,7 +656,6 @@ def shift_deduplicated_pairs( pair = chan_amp_a @ chan_amp_b.T pair = pair > conv_ignore_threshold pair = pair.cpu() - # print(f"___ after overlaps {pair.sum()=}") # co-occurrence cooccurrence = cooccurrence[ @@ -729,13 +663,11 @@ def shift_deduplicated_pairs( shifted_temp_ix_b[None, :], ] pair *= torch.as_tensor(cooccurrence, device=pair.device) - # print(f"___ after cooccur {pair.sum()=}") pair_ix_a, pair_ix_b = torch.nonzero(pair, as_tuple=True) nco = pair_ix_a.numel() if not nco: return None - # print(f"___ {nco=}") # if no shifting, deduplication is the identity do_shifting = reg_geom_kdtree is not None @@ -778,10 +710,6 @@ def shift_deduplicated_pairs( chanset_b, active_chan_ids_b = np.unique( active_chans_b, axis=0, return_inverse=True ) - # print(f"___ {chanset_a.sum(1)=}") - # print(f"___ {chanset_b.sum(1)=}") - # print(f"___ {active_chan_ids_a.shape=} {np.unique(active_chan_ids_a).shape=}") - # print(f"___ {active_chan_ids_b.shape=} {np.unique(active_chan_ids_b).shape=}") # 3 temp_ix_a = temp_ix_a[pair_ix_a] @@ -789,13 +717,6 @@ def shift_deduplicated_pairs( # get the relative shifts shift_a = shift_a[pair_ix_a] shift_b = shift_b[pair_ix_b] - # print(f"{temp_ix_a=}") - # print(f"{shift_a=}") - # print(f"{active_chan_ids_a[pair_ix_a]=}") - # print(f"{temp_ix_b=}") - # print(f"{shift_b=}") - # print(f"{active_chan_ids_b[pair_ix_b]=}") - # 
print(f"{shift_diff=}") # figure out combinations _, spatial_shift_ids = np.unique( @@ -812,7 +733,6 @@ def shift_deduplicated_pairs( temp_ix_b, spatial_shift_ids, ] - # print(f"{conv_determiners=}") # conv_ix: indices of unique determiners # compression_index: which representative does each pair belong to _, conv_ix, compression_index = np.unique( @@ -910,7 +830,8 @@ def compressed_upsampled_pairs( convolutions between templates ix_a[i], ix_b[i] equal that between templates ix_a[conv_ix[compression_index[i]]], ix_b[conv_ix[compression_index[i]]]. - We will upsample the templates in the RHS (b) in a compressed way. + We will upsample the templates in the RHS (b) in a compressed way, so that + each b index gets its own number of duplicates. """ up_factor = compressed_upsampled_temporal.compressed_upsampling_map.shape[1] compression_dup_ix = slice(None) @@ -930,61 +851,25 @@ def compressed_upsampled_pairs( ) conv_up_i, up_shift_up_i = np.nonzero(upsampling_mask) conv_compressed_upsampled_ix = ( - upsampled_shifted_template_index.up_shift_temp_ix_to_shift_temp_ix[ + upsampled_shifted_template_index.up_shift_temp_ix_to_comp_up_ix[ up_shift_up_i ] ) - conv_ix_up = conv_ix[conv_up_i] + conv_dup = conv_ix[conv_up_i] # And, all ix_{a,b}[i] such that compression_ix[i] lands in # that conv_ix need to be duplicated as well. - dup_mask = conv_ix[compression_index][:, None] == conv_ix_up[None, :] + dup_mask = conv_ix[compression_index][:, None] == conv_dup[None, :] if torch.is_tensor(dup_mask): dup_mask = dup_mask.numpy(force=True) compression_dup_ix, compression_index_up = np.nonzero(dup_mask) ix_b_up = ix_b[compression_dup_ix] - # ix_a_up = np.repeat(ix_a, ndups) - # ix_b_up = np.repeat(ix_b, ndups) - - # ix_a_up = np.zeros(len(ix_a) * up_factor, dtype=int) - # ix_b_up = np.zeros(len(ix_a) * up_factor, dtype=int) - # compression_index_up = np.zeros(len(ix_a) * up_factor, dtype=int) - # conv_ix_up = np.zeros(len(conv_ix) * up_factor, dtype=int) - # conv_compressed_upsampled_ix = np.zeros(len(conv_ix) * up_factor, dtype=int) - # cur_dedup_ix = 0 - # cur_pair_ix = 0 - # cur_conv_up_ix = 0 - # for i, convi in enumerate(conv_ix): - # # get b's shifted template ix - # conv_shifted_temp_ix_b = shifted_temp_ix_b[ix_b[convi]] - - # # which compressed upsampled indices match this? - # which_up = np.flatnonzero( - # upsampled_shifted_template_index.up_shift_temp_ix_to_shift_temp_ix - # == conv_shifted_temp_ix_b - # ) - # conv_comp_up_ix = ( - # upsampled_shifted_template_index.up_shift_temp_ix_to_comp_up_ix[which_up] - # ) - - # # which deduplication indices map ix_a,b to this convi? 
- # which_dedup = np.flatnonzero(compression_index == i) - - # # extend arrays with new indices - # nupi = conv_comp_up_ix.size - # n_new_pair = which_dedup.size * nupi - # ix_a_up[cur_pair_ix:cur_pair_ix+n_new_pair] = np.repeat(ix_a[which_dedup], nupi) - # ix_b_up[cur_pair_ix:cur_pair_ix+n_new_pair] = np.repeat(ix_b[which_dedup], nupi) - # conv_ix_up[cur_dedup_ix:cur_dedup_ix+nupi]=convi - # conv_compressed_upsampled_ix[cur_dedup_ix:cur_dedup_ix+nupi]=conv_comp_up_ix - # compression_index_up[cur_pair_ix:cur_pair_ix+n_new_pair]=np.tile(np.arange(cur_dedup_ix, cur_dedup_ix + nupi), which_dedup.size) - # cur_pair_ix += n_new_pair - # cur_dedup_ix += nupi - - # ix_a_up = ix_a_up[:cur_pair_ix] #np.array(ix_a_up) - # ix_b_up = ix_b_up[:cur_pair_ix] #np.array(ix_b_up) - # conv_compressed_upsampled_ix = conv_compressed_upsampled_ix[:cur_pair_ix] #np.array(conv_compressed_upsampled_ix) - # compression_index_up = compression_index_up[:cur_dedup_ix] #np.array(compression_index_up) - # conv_ix_up = conv_ix_up[:cur_dedup_ix] #np.array(conv_ix_up) + + # the conv ix need to be offset to keep the relation with the pairs + # ix_a[old i] + # offsets = np.cumsum((conv_ix[:, None] == conv_dup[None, :]).sum(0)) + # offsets -= offsets[0] + _, offsets = np.unique(compression_dup_ix, return_index=True) + conv_ix_up = offsets[conv_dup] # which upsamples and which templates? conv_upsampling_indices_b = ( @@ -997,7 +882,7 @@ def compressed_upsampled_pairs( conv_compressed_upsampled_ix ] ) - + return ( ix_b_up, compression_index_up, diff --git a/src/dartsort/templates/template_util.py b/src/dartsort/templates/template_util.py index af28a958..57f303bd 100644 --- a/src/dartsort/templates/template_util.py +++ b/src/dartsort/templates/template_util.py @@ -168,17 +168,12 @@ def templates_at_time( unregistered_depths_um = drift_util.invert_motion_estimate( motion_est, t_s, registered_template_depths_um ) - print(f"templates_at_time {registered_template_depths_um.min()=} {registered_template_depths_um.max()=}") - print(f"templates_at_time {unregistered_depths_um.min()=} {unregistered_depths_um.max()=}") - diff = unregistered_depths_um - registered_template_depths_um - print(f"templates_at_time {diff.min()=} {diff.max()=}") # reverse arguments to pitch shifts since we are going the other direction pitch_shifts = drift_util.get_spike_pitch_shifts( depths_um=registered_template_depths_um, geom=geom, registered_depths_um=unregistered_depths_um, ) - print(f"templates_at_time {pitch_shifts.min()=} {pitch_shifts.max()=}") # extract relevant channel neighborhoods, also by reversing args to a drift helper unregistered_templates = drift_util.get_waveforms_on_static_channels( registered_templates, diff --git a/src/dartsort/templates/templates.py b/src/dartsort/templates/templates.py index 5decccb8..1abeee9d 100644 --- a/src/dartsort/templates/templates.py +++ b/src/dartsort/templates/templates.py @@ -65,9 +65,6 @@ def coarsen(self): self.registered_geom, localization_radius_um=self.localization_radius_um, ) - print(f"coarsen {self.registered_geom[:,1].min()=} {self.registered_geom[:,1].max()=}") - print(f"coarsen {self.registered_template_depths_um.min()=} {self.registered_template_depths_um.max()=}") - print(f"coarsen {registered_template_depths_um.min()=} {registered_template_depths_um.max()=}") return replace( self, @@ -179,11 +176,10 @@ def from_config( # main! 
results = get_templates(recording, sorting, **kwargs) - print( - f"{[(k,v.dtype) for k,v in results.items() if (isinstance(v, np.ndarray))]=}" - ) # handle registered templates + print(f"{results['templates'].shape=}") + print(f"{kwargs['registered_geom'].shape=}") if template_config.registered_templates: registered_template_depths_um = get_template_depths( results["templates"], diff --git a/src/dartsort/util/drift_util.py b/src/dartsort/util/drift_util.py index 4013619f..0bf4464a 100644 --- a/src/dartsort/util/drift_util.py +++ b/src/dartsort/util/drift_util.py @@ -220,13 +220,9 @@ def get_spike_pitch_shifts( else: probe_displacement = registered_depths_um - depths_um - print(f"get_spike_pitch_shifts {pitch=}") - print(f"get_spike_pitch_shifts {probe_displacement.min()=} {probe_displacement.max()=}") - # if probe_displacement > 0, then the registered position is below the original # and, to be conservative, round towards 0 rather than using // n_pitches_shift = (probe_displacement / pitch).astype(int) - print(f"get_spike_pitch_shifts {n_pitches_shift.min()=} {n_pitches_shift.max()=}") return n_pitches_shift @@ -547,15 +543,6 @@ def get_shift_and_unit_pairs( reg_depths_um = np.concatenate((reg_depths_um_a, reg_depths_um_b)) # figure out all shifts for all units at all times - print(f"get_shift_and_unit_pairs {np.min(chunk_time_centers_s)=} {np.max(chunk_time_centers_s)=}") - print(f"get_shift_and_unit_pairs {reg_depths_um.min()=} {reg_depths_um.max()=}") - print(f"get_shift_and_unit_pairs {reg_depths_um_a.shape=} {reg_depths_um_b.shape=}") - print(f"get_shift_and_unit_pairs {reg_depths_um_a.min()=} {reg_depths_um_a.max()=}") - print(f"get_shift_and_unit_pairs {reg_depths_um_b.min()=} {reg_depths_um_b.max()=}") - print(f"get_shift_and_unit_pairs {motion_est.time_bin_centers_s.min()=}") - print(f"get_shift_and_unit_pairs {motion_est.time_bin_centers_s.max()=}") - print(f"get_shift_and_unit_pairs {motion_est.spatial_bin_centers_um.min()=}") - print(f"get_shift_and_unit_pairs {motion_est.spatial_bin_centers_um.max()=}") unreg_depths_um = np.stack( [ invert_motion_estimate( @@ -565,11 +552,8 @@ def get_shift_and_unit_pairs( ], axis=0, ) - print(f"get_shift_and_unit_pairs {reg_depths_um.shape=} {unreg_depths_um.shape=}") assert unreg_depths_um.shape == (len(chunk_time_centers_s), len(reg_depths_um)) diff = reg_depths_um - unreg_depths_um - print(f"get_shift_and_unit_pairs {unreg_depths_um.min()=} {unreg_depths_um.max()=}") - print(f"get_shift_and_unit_pairs {diff.min()=} {diff.max()=}") pitch_shifts = get_spike_pitch_shifts( depths_um=reg_depths_um, pitch=get_pitch(geom), @@ -580,9 +564,8 @@ def get_shift_and_unit_pairs( else: shifts_a = pitch_shifts[:, :na] shifts_b = pitch_shifts[:, na:] - print(f"get_shift_and_unit_pairs {shifts_a.min()=} {shifts_a.max()=}") - print(f"get_shift_and_unit_pairs {shifts_b.min()=} {shifts_b.max()=}") - print(f"get_shift_and_unit_pairs {shifts_a.shape=} {shifts_b.shape=}") + print(f"{shifts_a.min()=} {shifts_a.max()=}") + print(f"{shifts_b.min()=} {shifts_b.max()=}") # assign ids to pitch/shift pairs template_shift_index_a = TemplateShiftIndex.from_shift_matrix(shifts_a) diff --git a/src/dartsort/util/spiketorch.py b/src/dartsort/util/spiketorch.py index 533b4a07..aeb305b9 100644 --- a/src/dartsort/util/spiketorch.py +++ b/src/dartsort/util/spiketorch.py @@ -289,7 +289,7 @@ def depthwise_oaconv1d(input, weight, f2=None, padding=0): s2 = weight.shape[1] assert s1 >= s2 - shape_final = s1 + s2 - 1 + # shape_full = s1 + s2 - 1 block_size, overlap, in1_step, in2_step 
= _calc_oa_lens(s1, s2) nstep1, pad1, nstep2, pad2 = steps_and_pad( s1, in1_step, s2, in2_step, block_size, overlap @@ -320,15 +320,11 @@ def depthwise_oaconv1d(input, weight, f2=None, padding=0): oa = fold_res.reshape(n1, fold_out_len) # this is the full convolution - print(f"oaconv orig {oa.shape=}") # oa = oa[:, : shape_final] - print(f"oaconv full {oa.shape=}") # extract correct padding valid_len = s1 - s2 + 1 valid_start = s2 - 1 assert valid_start >= padding oa = oa[:, valid_start - padding:valid_start + valid_len + padding] - print(f"oaconv {oa.shape=} {valid_len=} {valid_start=} {padding=}") - print(f"oaconv {(valid_start - padding)=} {(valid_start + valid_len + padding)=}") return oa
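
Note on the convolution paths touched above: the matching.py hunk switches MatchingTemplateData.convolve back from spiketorch.depthwise_oaconv1d to F.conv1d with groups, while build_template_data registers objective_temporalf (the rfft of the temporal components at the overlap-add block size). A minimal standalone sketch of the equivalence these two paths rely on, using made-up shapes rather than the repo's buffers: each template's projected trace is correlated only with its own temporal filter, and the FFT route reproduces the grouped conv1d result once the full linear convolution is sliced down to the valid region.

    import torch
    import torch.nn.functional as F

    # Illustrative shapes only; the real code uses objective_n_templates, rank, etc.
    n_templates, t_rec, t_spike = 8, 4096, 121
    rec = torch.randn(n_templates, t_rec, dtype=torch.float64)     # per-template projected traces
    filt = torch.randn(n_templates, t_spike, dtype=torch.float64)  # per-template temporal filters

    # grouped conv1d: template i is correlated only with filter i
    ref = F.conv1d(rec[None], filt[:, None, :], groups=n_templates)[0]

    # FFT route: full linear convolution, then keep the valid part.
    # conv1d is cross-correlation, so flip the filter before convolving.
    n_fft = t_rec + t_spike - 1
    rf = torch.fft.rfft(rec, n=n_fft)
    ff = torch.fft.rfft(torch.flip(filt, dims=(1,)), n=n_fft)
    full = torch.fft.irfft(rf * ff, n=n_fft)
    valid = full[:, t_spike - 1 : t_rec]  # length t_rec - t_spike + 1

    assert torch.allclose(ref, valid)

The overlap-add variant computes the same valid slice blockwise, which is apparently why the filter FFTs are precomputed and registered as objective_temporalf: they can be reused across chunks instead of being recomputed on every call.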
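
Similarly, the find_peaks change moves the coarse objective update, objective = 2*conv - ||template||^2, into the peak-finding step as an in-place torch.add with alpha=2.0. A tiny self-contained check of that identity with illustrative shapes (not the matcher's padded buffers):

    import torch

    n_templates, n_times = 4, 100
    conv = torch.randn(n_templates, n_times)
    template_norms_sq = torch.rand(n_templates)

    objective = torch.empty(n_templates, n_times)
    # out = (-norms)[:, None] + 2.0 * conv, broadcast over time, written in place
    torch.add(template_norms_sq.neg()[:, None], conv, alpha=2.0, out=objective)

    assert torch.allclose(objective, 2 * conv - template_norms_sq[:, None])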