From 54a031e866256d7556d562f04128069d5ba7a9f2 Mon Sep 17 00:00:00 2001 From: Charlie Windolf Date: Fri, 4 Aug 2023 10:12:56 -0700 Subject: [PATCH 01/49] SI ... --- spike_psvae/subtract.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/spike_psvae/subtract.py b/spike_psvae/subtract.py index b3bbb11a..b216a333 100644 --- a/spike_psvae/subtract.py +++ b/spike_psvae/subtract.py @@ -668,8 +668,8 @@ def subtraction_binary( recording = sc.read_binary( standardized_bin, sampling_rate, - n_channels, - binary_dtype, + num_channels=n_channels, + dtype=binary_dtype, time_axis=time_axis, is_filtered=True, ) From 132e1de4d98798263a6357f23b642925b517d34b Mon Sep 17 00:00:00 2001 From: Charlie Windolf Date: Fri, 4 Aug 2023 10:13:15 -0700 Subject: [PATCH 02/49] new snr --- spike_psvae/hybrid_analysis.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/spike_psvae/hybrid_analysis.py b/spike_psvae/hybrid_analysis.py index 8305b3c6..57a0d48c 100644 --- a/spike_psvae/hybrid_analysis.py +++ b/spike_psvae/hybrid_analysis.py @@ -1836,7 +1836,7 @@ def calc_template_snrs( spike_length_samples=spike_length_samples, buffer=wf_buffer, ) - denominator = np.abs(np.einsum("ij,nij->n", t, noise) / C).mean() + denominator = np.abs(np.einsum("ij,nij->n", t, noise) / C).std()#.mean() snrs.append(numerator / denominator) return np.array(snrs) From 267f35cab940e197fe1549c385d024f5430d9d8f Mon Sep 17 00:00:00 2001 From: Charlie Windolf Date: Fri, 4 Aug 2023 10:13:31 -0700 Subject: [PATCH 03/49] super init --- dartsort/transform/base.py | 1 + 1 file changed, 1 insertion(+) diff --git a/dartsort/transform/base.py b/dartsort/transform/base.py index 25864ab6..a6739b95 100644 --- a/dartsort/transform/base.py +++ b/dartsort/transform/base.py @@ -10,6 +10,7 @@ class BaseWaveformModule(torch.nn.Module): default_name = "" def __init__(self, name): + super().__init__() self.name = name if name is None: name = self.default_name From 5e016c501f1439090c1c72edb84200f0ae570a76 Mon Sep 17 00:00:00 2001 From: Charlie Windolf Date: Tue, 17 Oct 2023 16:40:05 -0400 Subject: [PATCH 04/49] Checking in working ResidualUpdate algorithm before converting to ObjectiveUpdate --- src/dartsort/config.py | 2 +- src/dartsort/detect/detect.py | 15 +- src/dartsort/localize/localize_torch.py | 9 + src/dartsort/main.py | 5 + src/dartsort/peel/matching.py | 341 ++++++++++++++---------- src/dartsort/templates/pairwise_conv.py | 65 +++++ src/dartsort/templates/superres_util.py | 2 + src/dartsort/templates/template_util.py | 29 +- src/dartsort/templates/templates.py | 48 +++- src/dartsort/util/drift_util.py | 39 ++- src/dartsort/util/spiketorch.py | 51 +++- 11 files changed, 436 insertions(+), 170 deletions(-) create mode 100644 src/dartsort/templates/pairwise_conv.py diff --git a/src/dartsort/config.py b/src/dartsort/config.py index 07b54864..0aa103e5 100644 --- a/src/dartsort/config.py +++ b/src/dartsort/config.py @@ -153,7 +153,7 @@ class MatchingConfig: fit_subsampling_random_state: int = 0 # template matching parameters - threshold: float = 30.0 + threshold: float = 50.0 template_svd_compression_rank: int = 10 template_temporal_upsampling_factor: int = 8 template_min_channel_amplitude: float = 1.0 diff --git a/src/dartsort/detect/detect.py b/src/dartsort/detect/detect.py index 10e46074..59fc5de5 100644 --- a/src/dartsort/detect/detect.py +++ b/src/dartsort/detect/detect.py @@ -41,7 +41,9 @@ def detect_and_deduplicate( with corresponding channels """ nsamples, nchans = traces.shape - if dedup_channel_index is not None: + 
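+    # the new "all" mode skips the channel-index shape check here; in the spatial
+    # dedup step below it replaces each time step's energies with the max over
+    # every channel, so only the single largest peak on the probe at each time
+    # step survives deduplication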
if dedup_channel_index == "all": + pass + elif dedup_channel_index is not None: assert dedup_channel_index.shape[0] == nchans # -- handle peak sign. we use max pool below, so make peaks positive @@ -79,14 +81,17 @@ def detect_and_deduplicate( F.threshold_(energies, threshold, 0.0) # -- temporal deduplication - max_energies = energies if dedup_temporal_radius > 0: - max_energies, indices = F.max_pool2d_with_indices( + max_energies = F.max_pool2d( energies, kernel_size=[2 * dedup_temporal_radius + 1, 1], stride=1, padding=[dedup_temporal_radius, 0], ) + elif dedup_channel_index is not None: + max_energies = energies.clone() + else: + max_energies = energies # back to TC energies = energies[0, 0] max_energies = max_energies[0, 0] @@ -94,7 +99,9 @@ def detect_and_deduplicate( # -- spatial deduplication # we would like to max pool again on the other axis, # but that doesn't support any old radial neighborhood - if dedup_channel_index is not None: + if dedup_channel_index == "all": + max_energies[:] = max_energies.max(dim=1, keepdim=True).values + elif dedup_channel_index is not None: # pad channel axis with extra chan of 0s max_energies = F.pad(max_energies, (0, 1)) for batch_start in range(0, nsamples, spatial_dedup_batch_size): diff --git a/src/dartsort/localize/localize_torch.py b/src/dartsort/localize/localize_torch.py index d7e4d20b..d13ee78a 100644 --- a/src/dartsort/localize/localize_torch.py +++ b/src/dartsort/localize/localize_torch.py @@ -67,6 +67,8 @@ def localize_amplitude_vectors( channel_index = full_channel_index(n_channels_tot) assert channel_index.shape == (n_channels_tot, c) assert main_channels.shape == (n_spikes,) + # we'll return numpy if user sent numpy + is_numpy = not torch.is_tensor(amplitude_vectors) # handle channel subsetting if radius is not None or n_channels_subset is not None: @@ -142,6 +144,13 @@ def localize_amplitude_vectors( ) z_abs = z_rel + geom[main_channels, 1] + if is_numpy: + x = x.numpy(force=True) + y = y.numpy(force=True) + z_rel = z_rel.numpy(force=True) + z_abs = z_abs.numpy(force=True) + alpha = alpha.numpy(force=True) + return dict(x=x, y=y, z_rel=z_rel, z_abs=z_abs, alpha=alpha) diff --git a/src/dartsort/main.py b/src/dartsort/main.py index 5d84c488..c0bf73da 100644 --- a/src/dartsort/main.py +++ b/src/dartsort/main.py @@ -86,6 +86,7 @@ def match( residual_filename=None, show_progress=True, device=None, + template_npz_filename="matching0_templates.npz", hdf5_filename="matching0.h5", model_subdir="matching0_models", ): @@ -98,7 +99,10 @@ def match( n_jobs=n_jobs_templates, save_folder=output_directory, overwrite=overwrite, + device=device, + save_npz_name=template_npz_filename, ) + # instantiate peeler matching_peeler = ResidualUpdateTemplateMatchingPeeler.from_config( recording, @@ -157,6 +161,7 @@ def _run_peeler( overwrite=overwrite, residual_filename=residual_filename, show_progress=show_progress, + device=device, ) # do localization diff --git a/src/dartsort/peel/matching.py b/src/dartsort/peel/matching.py index b5cfb32f..e1f513fa 100644 --- a/src/dartsort/peel/matching.py +++ b/src/dartsort/peel/matching.py @@ -40,7 +40,7 @@ def __init__( amplitude_scaling_variance=0.0, amplitude_scaling_boundary=0.5, trough_offset_samples=42, - threshold=30.0, + threshold=50.0, chunk_length_samples=30_000, n_chunks_fit=40, fit_subsampling_random_state=0, @@ -67,6 +67,9 @@ def __init__( min_channel_amplitude=min_channel_amplitude, rank=svd_compression_rank, ) + temporal_components = temporal_components.astype(recording.dtype) + singular_values = 
singular_values.astype(recording.dtype) + spatial_components = spatial_components.astype(recording.dtype) self.handle_upsampling( temporal_components, temporal_upsampling_factor=temporal_upsampling_factor, @@ -78,6 +81,7 @@ def __init__( self.refractory_radius_frames = refractory_radius_frames self.max_iter = max_iter self.n_templates = n_templates + self.trough_offset_samples = trough_offset_samples self.spike_length_samples = spike_length_samples self.geom = recording.get_channel_locations() self.svd_compression_rank = svd_compression_rank @@ -101,6 +105,7 @@ def __init__( "_refrac_ix", torch.arange(-refractory_radius_frames, refractory_radius_frames + 1), ) + self.register_buffer("_rank_ix", torch.arange(svd_compression_rank)) # amplitude scaling properties self.is_scaling = bool(amplitude_scaling_variance) @@ -166,7 +171,7 @@ def check_shapes(self): assert self.unit_ids.shape == (self.n_templates,) def handle_template_groups(self, unit_ids): - self.unit_ids = unit_ids + self.register_buffer("unit_ids", torch.from_numpy(unit_ids)) self.grouped_temps = True unique_units = np.unique(unit_ids) if unique_units.size == unit_ids.size: @@ -188,7 +193,7 @@ def handle_template_groups(self, unit_ids): group_index = np.full((self.n_templates, max_group_size), -1) for j, row in enumerate(group_index): group_index[j, : len(row)] = row - self.group_index = torch.tensor(group_index) + self.register_buffer("group_index", torch.from_numpy(group_index)) def handle_upsampling( self, @@ -197,41 +202,46 @@ def handle_upsampling( upsampling_peak_window_radius=8, ): self.temporal_upsampling_factor = temporal_upsampling_factor - upsampled_temporal_components = temporal_components - if temporal_upsampling_factor > 1: - upsampled_temporal_components = template_util.temporally_upsample_templates( - temporal_components, - temporal_upsampling_factor=temporal_upsampling_factor, - ) + if temporal_upsampling_factor == 1: + upsampled_temporal_components = temporal_components[:, :, None, :] self.register_buffer( "upsampled_temporal_components", torch.tensor(upsampled_temporal_components), ) - self.register_buffer( - "upsampling_window", - torch.arange( - -upsampling_peak_window_radius, upsampling_peak_window_radius + 1 - ), - ) - self.upsampling_window_len = 2 * upsampling_peak_window_radius - center = upsampling_peak_window_radius * temporal_upsampling_factor - radius = temporal_upsampling_factor // 2 + temporal_upsampling_factor % 2 - self.register_buffer( - "upsampled_peak_search_window", - torch.arange(center - radius, center + radius + 1), - ) - self.register_buffer( - "peak_to_upsampling_index", - torch.concatenate( - [ - torch.arange(radius, -1, -1), - (temporal_upsampling_factor - 1) - torch.arange(radius), - ] - ), - ) - self.register_buffer( - "peak_to_time_shift", torch.tensor([0] * (radius + 1) + [1] * radius) - ) + return + + upsampled_temporal_components = template_util.temporally_upsample_templates( + temporal_components, + temporal_upsampling_factor=temporal_upsampling_factor, + ) + self.register_buffer( + "upsampled_temporal_components", torch.tensor(upsampled_temporal_components) + ) + self.register_buffer( + "upsampling_window", + torch.arange( + -upsampling_peak_window_radius, upsampling_peak_window_radius + 1 + ), + ) + self.upsampling_window_len = 2 * upsampling_peak_window_radius + center = upsampling_peak_window_radius * temporal_upsampling_factor + radius = temporal_upsampling_factor // 2 + temporal_upsampling_factor % 2 + self.register_buffer( + "upsampled_peak_search_window", + 
torch.arange(center - radius, center + radius + 1), + ) + self.register_buffer( + "peak_to_upsampling_index", + torch.concatenate( + [ + torch.arange(radius, -1, -1), + (temporal_upsampling_factor - 1) - torch.arange(radius), + ] + ), + ) + self.register_buffer( + "peak_to_time_shift", torch.tensor([0] * (radius + 1) + [1] * radius) + ) @classmethod def from_config( @@ -279,7 +289,7 @@ def peel_chunk( ): # get current template set chunk_center_samples = chunk_start_samples + self.chunk_length_samples // 2 - + segment = self.recording._recording_segments[0] chunk_center_seconds = segment.sample_index_to_time(chunk_center_samples) compressed_template_data = self.templates_at_time(chunk_center_seconds) @@ -302,21 +312,21 @@ def peel_chunk( def templates_at_time(self, t_s): """Extract the right spatial components for each unit.""" if self.is_drifting: - cur_spatial = template_util.templates_at_time( + pitch_shifts, cur_spatial = template_util.templates_at_time( t_s, self.spatial_components, self.geom, registered_template_depths_um=self.registered_template_depths_um, registered_geom=self.registered_geom, motion_est=self.motion_est, + return_pitch_shifts=True, ) - cur_ampvecs = template_util.templates_at_time( - t_s, + cur_ampvecs = drift_util.get_waveforms_on_static_channels( self.registered_template_ampvecs[:, None, :], - self.geom, - registered_template_depths_um=self.registered_template_depths_um, - registered_geom=self.registered_geom, - motion_est=self.motion_est, + self.registered_geom, + n_pitches_shift=pitch_shifts, + registered_geom=self.geom, + fill_value=0.0, ) max_channels = cur_ampvecs[:, 0, :].argmax(1) else: @@ -328,7 +338,7 @@ def templates_at_time(self, t_s): self.singular_values, self.temporal_components, self.upsampled_temporal_components, - max_channels, + torch.tensor(max_channels, device=cur_spatial.device), ) def match_chunk( @@ -348,57 +358,89 @@ def match_chunk( residual = residual_padded[:, :-1] # name objective variables so that we can update them in-place later - conv = None conv_len = traces.shape[0] - self.spike_length_samples + 1 padded_obj_len = conv_len + 2 * self.obj_pad_len + padded_conv = torch.zeros( + self.n_templates, + padded_obj_len, + dtype=traces.dtype, + device=traces.device, + ) padded_objective = torch.zeros( self.n_templates + 1, padded_obj_len, dtype=traces.dtype, device=traces.device, ) + refrac_mask = torch.zeros_like(padded_objective) # padded objective has an extra unit (for group_index) and refractory # padding (for easier implementation of enforce_refractory) - objective = padded_objective[ - :-1, self.refractory_radius_frames : -self.refractory_radius_frames - ] neg_temp_normsq = -compressed_template_data.template_norms_squared[:, None] # manages buffers for spike train data (peak times, labels, etc) - peaks = MatchingPeaks() - + peaks = MatchingPeaks(device=traces.device) # main loop - for _ in range(self.max_iter): + print("start") + for it in range(self.max_iter): # update objective - conv = compressed_template_data.convolve(residual, out=conv) + compressed_template_data.convolve( + residual, padding=self.obj_pad_len, out=padded_conv + ) # unscaled objective for coarse peaks, scaled when finding high res peak - torch.add(neg_temp_normsq, conv, alpha=2.0, out=objective) + torch.add( + neg_temp_normsq, padded_conv, alpha=2.0, out=padded_objective[:-1] + ) # find high-res peaks - new_peaks = self.find_peaks(conv, padded_objective, peaks, neg_temp_normsq) + print('before find') + new_peaks = self.find_peaks( + padded_conv, padded_objective, 
refrac_mask, neg_temp_normsq + ) if new_peaks is None: break + # print("----------") + # if not it % 1: + # if new_peaks.n_spikes > 1: + # print( + # f"{it=} {new_peaks.n_spikes=} {new_peaks.scores.mean().numpy(force=True)=} {torch.diff(new_peaks.times).min()=}" + # ) + # tq = new_peaks.times.numpy(force=True) + # print(f"{np.diff(tq).min()=} {tq=}") + + # enforce refractoriness + self.enforce_refractory( + refrac_mask, + new_peaks.times + self.obj_pad_len, + new_peaks.template_indices, + ) # subtract them - # offset times: conv result peaks with valid padding are offset - # by spike len - 1 samples from the corresponding trace peaks - sample_times = new_peaks.times + self.spike_length_samples - 1 + # old_norm = torch.linalg.norm(residual) ** 2 compressed_template_data.subtract( - residual, - sample_times, + residual_padded, + new_peaks.times, new_peaks.template_indices, new_peaks.upsampling_indices, new_peaks.scalings, ) + # new_norm = torch.linalg.norm(residual) ** 2 + # print(f"{it=} {new_norm=}") + # print(f"{(new_norm-old_norm)=}") + # print(f"{new_peaks.scores.sum().numpy(force=True)=}") + # print("----------") + # update spike train peaks.extend(new_peaks) peaks.sort() # extract collision-cleaned waveforms on small neighborhoods - channels, waveforms = self.get_collisioncleaned_waveforms() + channels, waveforms = self.get_collisioncleaned_waveforms( + residual_padded, peaks, compressed_template_data + ) return dict( + n_spikes=peaks.n_spikes, times_samples=peaks.times + self.trough_offset_samples, channels=channels, labels=self.unit_ids[peaks.template_indices], @@ -409,38 +451,35 @@ def match_chunk( collisioncleaned_waveforms=waveforms, ) - def find_peaks(self, conv, padded_objective, peaks, neg_temp_normsq): - # zeroth step: enforce refractoriness. - self.enforce_refractory( - padded_objective, - peaks.times + self.obj_pad_len, - peaks.template_indices, - ) - + def find_peaks(self, padded_conv, padded_objective, refrac_mask, neg_temp_normsq): # first step: coarse peaks. not temporally upsampled or amplitude-scaled. - objective = padded_objective[:-1, self.obj_pad_len : -self.obj_pad_len] - times, template_indices = detect_and_deduplicate( - objective.T, - self.threshold, - dedup_channel_index=None, - peak_sign="pos", - # add 1 here to account for possible time_shifts later - relative_peak_radius=self.spike_length_samples + 1, - dedup_temporal_radius=0, - # dedup_temporal_radius=self.spike_length_samples + 1, - ) + padded_obj_len = padded_objective.shape[1] + objective = (padded_objective + refrac_mask)[ + :-1, self.obj_pad_len : -self.obj_pad_len + ] + # formerly used detect_and_deduplicate, but that was slow. 
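+        # instead: take the max over templates at each time step, then keep the
+        # samples that are local maxima within +/- spike_length_samples and that
+        # exceed the threshold (see the argrelmax helper at the bottom of this file)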
+ objective_max, max_template = objective.max(dim=0) + times = argrelmax(objective_max, self.spike_length_samples, self.threshold) + # tt = times.numpy(force=True) + # print(f"{np.diff(tt).min()=} {tt=}") + template_indices = max_template[times] + # remove peaks inside the padding if not times.numel(): return None # second step: high-res peaks (upsampled and/or amp-scaled) time_shifts, upsampling_indices, scalings, scores = self.find_fancy_peaks( - conv, objective, times, template_indices, neg_temp_normsq + padded_conv, + padded_objective, + times + self.obj_pad_len, + template_indices, + neg_temp_normsq, ) if time_shifts is not None: times += time_shifts return MatchingPeaks( - n_spikes=times.size, + n_spikes=times.numel(), times=times, template_indices=template_indices, upsampling_indices=upsampling_indices, @@ -449,13 +488,15 @@ def find_peaks(self, conv, padded_objective, peaks, neg_temp_normsq): ) def enforce_refractory(self, objective, times, template_indices): + if not times.numel(): + return # overwrite objective with -inf to enforce refractoriness - time_ix = times[None, :] + self._refrac_ix[:, None] + time_ix = times[:, None] + self._refrac_ix[None, :] if self.grouped_temps: unit_ix = self.group_index[template_indices] else: - unit_ix = template_indices[:, None] - objective[unit_ix, time_ix] = -torch.inf + unit_ix = template_indices[:, None, None] + objective[unit_ix[:, :, None], time_ix[:, None, :]] = -torch.inf def find_fancy_peaks( self, conv, objective, times, template_indices, neg_temp_normsq @@ -466,10 +507,10 @@ def find_fancy_peaks( # use one of the upsampled templates, no problem. when the peak # comes to the left, it's different: it came from one of the upsampled # templates shifted one sample (spike time += 1). - if self.up_factor == 1 and not self.is_scaling: + if self.temporal_upsampling_factor == 1 and not self.is_scaling: return None, None, None, objective[template_indices, times] - if self.is_scaling and self.up_factor == 1: + if self.is_scaling and self.temporal_upsampling_factor == 1: inv_lambda = 1 / self.amplitude_scaling_variance b = conv[times, template_indices] + inv_lambda a = neg_temp_normsq[template_indices] + inv_lambda @@ -509,7 +550,7 @@ def find_fancy_peaks( high_res_obj[:, self.upsampled_peak_search_window], dim=1 ) upsampling_indices = self.peak_to_upsampling_index[zoom_peak] - time_shifts = self.peak_to_time_shifts[zoom_peak] + time_shifts = self.peak_to_time_shift[zoom_peak] return time_shifts, upsampling_indices, scalings, scores @@ -527,14 +568,14 @@ def get_collisioncleaned_waveforms( buffer=0, already_padded=True, ) - spatial = compressed_template_data.spatial_singular[ + padded_spatial = F.pad(compressed_template_data.spatial_singular, (0, 1)) + spatial = padded_spatial[ peaks.template_indices[:, None, None], - :, + self._rank_ix[None, :, None], self.channel_index[channels][:, None, :], ] temporal = compressed_template_data.upsampled_temporal_components[ - peaks.template_indices, - peaks.upsampling_indices, + peaks.template_indices, :, peaks.upsampling_indices ] torch.baddbmm(waveforms, temporal, spatial, out=waveforms) return channels, waveforms @@ -556,6 +597,14 @@ def __post_init__(self): self.spike_length_samples, self.rank, ) = self.temporal_components.shape + assert self.spatial_components.shape[:2] == (self.n_templates, self.rank) + assert self.upsampled_temporal_components.shape == ( + self.n_templates, + self.spike_length_samples, + self.upsampled_temporal_components.shape[2], + self.rank, + ) + assert self.singular_values.shape == 
(self.n_templates, self.rank) # squared l2 norms are the sums of squared singular values self.template_norms_squared = torch.square(self.singular_values).sum(1) self.spatial_singular = ( @@ -568,21 +617,15 @@ def __post_init__(self): self.spike_length_samples, device=self.spatial_components.device ) - def convolve(self, traces, out=None): + def convolve(self, traces, padding=0, out=None): """This is not the fastest strategy on GPU, but it's low-memory and fast on CPU.""" + out_len = traces.shape[0] + 2 * padding - self.spike_length_samples + 1 if out is None: out = torch.zeros( - 1, - self.n_templates, - traces.shape[0] - self.spike_length_samples + 1, - dtype=traces.dtype, - device=traces.device, + 1, self.n_templates, out_len, dtype=traces.dtype, device=traces.device ) else: - assert out.shape == ( - self.n_templates, - traces.shape[0] - self.spike_length_samples + 1, - ) + assert out.shape == (self.n_templates, out_len) out = out[None] for q in range(self.rank): @@ -591,9 +634,16 @@ def convolve(self, traces, out=None): # convolve with temporal components -- units x time temporal = self.temporal_components[:, :, q] # conv1d with groups! only convolve each unit with its own temporal filter. - out += F.conv1d( - rec_spatial[None], temporal[:, None, :], groups=self.n_templates + conv = F.conv1d( + rec_spatial[None], + temporal[:, None, :], + groups=self.n_templates, + padding=padding, ) + if q: + out += conv + else: + out.copy_(conv) # back to units x time (remove extra dim used for conv1d) return out[0] @@ -607,10 +657,10 @@ def subtract( scalings, ): batch_templates = torch.einsum( - "n,nrc,ntr", + "n,nrc,ntr->ntc", scalings, self.spatial_singular[template_indices], - self.upsampled_temporal_components[template_indices, upsampling_indices], + self.upsampled_temporal_components[template_indices, :, upsampling_indices], ) time_ix = times[:, None, None] + self.time_ix[None, :, None] spiketorch.add_at_( @@ -619,7 +669,7 @@ def subtract( class MatchingPeaks: - BUFFER_INIT: int = 1000 + BUFFER_INIT: int = 1500 BUFFER_GROWTH: float = 1.5 def __init__( @@ -630,22 +680,35 @@ def __init__( upsampling_indices: Optional[torch.LongTensor] = None, scalings: Optional[torch.Tensor] = None, scores: Optional[torch.Tensor] = None, + device=None, ): self.n_spikes = n_spikes + self._times = times + self._template_indices = template_indices + self._upsampling_indices = upsampling_indices + self._scalings = scalings + self._scores = scores + + if device is None and times is not None: + device = times.device if times is None: - cur_buf_size = self.BUFFER_INIT - self._times = torch.zeros(cur_buf_size, dtype=int) + self.cur_buf_size = self.BUFFER_INIT + self._times = torch.zeros(self.cur_buf_size, dtype=int, device=device) else: - cur_buf_size = times.size - assert cur_buf_size == n_spikes + self.cur_buf_size = times.numel() + assert self.cur_buf_size == n_spikes if template_indices is None: - self._template_indices = torch.zeros(cur_buf_size, dtype=int) + self._template_indices = torch.zeros( + self.cur_buf_size, dtype=int, device=device + ) if scalings is None: - self._scalings = torch.zeros(cur_buf_size) + self._scalings = torch.ones(self.cur_buf_size, device=device) if upsampling_indices is None: - self._upsampling_indices = torch.zeros(cur_buf_size, dtype=int) + self._upsampling_indices = torch.zeros( + self.cur_buf_size, dtype=int, device=device + ) if scores is None: - self._scores = torch.zeros(cur_buf_size) + self._scores = torch.zeros(self.cur_buf_size, device=device) @property def times(self): @@ 
-668,29 +731,14 @@ def scores(self): return self._scores[: self.n_spikes] def grow_buffers(self, min_size=0): - new_buf_size = max(min_size, int(self.cur_buf_size * self.BUFFER_GROWTH)) - new_times = torch.zeros(new_buf_size, dtype=self._times.dtype) - new_template_indices = torch.zeros( - new_buf_size, dtype=self._template_indices.dtype - ) - new_upsampling_indices = torch.zeros( - new_buf_size, dtype=self._upsampling_indices.dtype - ) - new_scalings = torch.zeros(new_buf_size, dtype=self._scalings.dtype) - new_scores = torch.zeros(new_buf_size, dtype=self._scores.dtype) - - new_times[: self.n_spikes] = self.times - new_template_indices[: self.n_spikes] = self.template_indices - new_upsampling_indices[: self.n_spikes] = self.upsampling_indices - new_scalings[: self.n_spikes] = self.scalings - new_scores[: self.n_spikes] = self.scores - - self.cur_buf_size = new_buf_size - self._times = new_times - self._template_indices = new_template_indices - self._upsampling_indices = new_upsampling_indices - self._scalings = new_scalings - self._scores = new_scores + sz = max(min_size, int(self.cur_buf_size * self.BUFFER_GROWTH)) + k = self.n_spikes + self._times = _grow_buffer(self._times, k, sz) + self._template_indices = _grow_buffer(self._template_indices, k, sz) + self._upsampling_indices = _grow_buffer(self._upsampling_indices, k, sz) + self._scalings = _grow_buffer(self._scalings, k, sz) + self._scores = _grow_buffer(self._scores, k, sz) + self.cur_buf_size = sz def sort(self): order = torch.argsort(self.times[: self.n_spikes]) @@ -712,3 +760,24 @@ def extend(self, other): self._scalings[self.n_spikes : new_n_spikes] = other.scalings self._scores[self.n_spikes : new_n_spikes] = other.scores self.n_spikes = new_n_spikes + + +def _grow_buffer(x, old_length, new_size): + new = torch.empty(new_size, dtype=x.dtype, device=x.device) + new[:old_length] = x[:old_length] + return new + + +def argrelmax(x, radius, threshold, exclude_edge=True): + x1 = F.max_pool1d( + x[None, None], + kernel_size=2 * radius + 1, + padding=radius, + stride=1, + )[0, 0] + x1[x < x1] = 0 + F.threshold_(x1, threshold, 0.0) + ix = torch.nonzero(x1)[:, 0] + if exclude_edge: + return ix[(ix > 0) & (ix < x.numel() - 1)] + return ix diff --git a/src/dartsort/templates/pairwise_conv.py b/src/dartsort/templates/pairwise_conv.py new file mode 100644 index 00000000..be23ba67 --- /dev/null +++ b/src/dartsort/templates/pairwise_conv.py @@ -0,0 +1,65 @@ +from dataclasses import dataclass + + +def sparse_pairwise_conv( + sorting, + template_temporal_components, + template_upsampled_temporal_components, + template_singular_values, + template_spatial_components, + conv_ignore_threshold: float = 0.0, + coarse_approx_error_threshold: float = 0.0, +): + """ + + Arguments + --------- + sorting : DARTsortSorting + original (non-superres) sorting. 
its labels should appear in + template_data.unit_ids + template_* : tensors or arrays + template SVD approximations + conv_ignore_threshold: float = 0.0 + pairs will be ignored (i.e., pconv set to 0) if their pconv + does not exceed this value + coarse_approx_error_threshold: float = 0.0 + superres will not be used if coarse pconv and superres pconv + are uniformly closer than this threshold value + + Returns + ------- + pitch_shifts : array + array of all the pitch shifts + use searchsorted to find the pitch shift ix for a pitch shift + index_table: torch sparse tensor + index_table[(pitch shift ix a, superres label a, pitch shift ix b, superres label b)] = ( + -1 + if superres pconv a,b at these shifts was below the conv_ignore_threshold + else pconv_index) + pconvs: np.ndarray + pconv[pconv_index] is a cross-correlation of two templates, summed over chans + """ + + + + +def _pairwise_conv_job( + units_a, + units_b, +): + """units_a,b are chunks of original (non-superres) unit labels""" + # determine co-visibility + # get all coarse templates + # get all superres templates + # compute all coarse and superres pconvs + + # returns + # list of tuples containing: + # - pitch shift ix a + # - pitch shift ix b + # - superres label a + # - superres label b + # list of the same length containing: + # - -1 or an index into the next list + # list of pconvs, indexed by previous list + \ No newline at end of file diff --git a/src/dartsort/templates/superres_util.py b/src/dartsort/templates/superres_util.py index 80d4c6a1..8476cf37 100644 --- a/src/dartsort/templates/superres_util.py +++ b/src/dartsort/templates/superres_util.py @@ -108,6 +108,7 @@ def motion_estimate_strategy( displacements = motion_est.disp_at_s(spike_times_s, spike_depths_um) mod_positions = displacements % pitch bin_ids = mod_positions // superres_bin_size_um + bin_ids = bin_ids.astype(int) orig_label_and_bin, superres_labels = np.unique( np.c_[original_labels, bin_ids], axis=0, return_inverse=True ) @@ -128,6 +129,7 @@ def drift_pitch_loc_bin_strategy( ) coarse_reg_depths = spike_depths_um + n_pitches_shift * pitch bin_ids = coarse_reg_depths // superres_bin_size_um + bin_ids = bin_ids.astype(int) orig_label_and_bin, superres_labels = np.unique( np.c_[original_labels, bin_ids], axis=0, return_inverse=True ) diff --git a/src/dartsort/templates/template_util.py b/src/dartsort/templates/template_util.py index 7ee3d868..b4fea070 100644 --- a/src/dartsort/templates/template_util.py +++ b/src/dartsort/templates/template_util.py @@ -118,6 +118,22 @@ def get_realigned_sorting( return results["sorting"] +def weighted_average(unit_ids, templates, weights): + n_out = unit_ids.max() + 1 + n_in, t, c = templates.shape + out = np.zeros((n_out, t, c), dtype=templates.dtype) + for i in range(n_out): + which_in = np.flatnonzero(unit_ids == i) + if not which_in.size: + continue + + w = weights[which_in][:, None, None] + w /= w.sum() + out[i] = (w * templates[which_in]).sum(0) + + return out + + # -- template drift handling @@ -126,6 +142,7 @@ def get_template_depths(templates, geom, localization_radius_um=100): templates, geom=geom, radius=localization_radius_um ) template_depths_um = template_locs["z_abs"] + return template_depths_um @@ -136,6 +153,7 @@ def templates_at_time( registered_template_depths_um=None, registered_geom=None, motion_est=None, + return_pitch_shifts=False, ): if registered_geom is None: return registered_templates @@ -161,8 +179,8 @@ def templates_at_time( registered_geom=geom, fill_value=np.nan, ) - assert not 
np.isnan(unregistered_templates).any() - + if return_pitch_shifts: + return pitch_shifts, unregistered_templates return unregistered_templates @@ -180,9 +198,9 @@ def svd_compress_templates(templates, min_channel_amplitude=1.0, rank=5): vis_templates = templates * vis_mask U, s, Vh = np.linalg.svd(vis_templates, full_matrices=False) # s is descending. - temporal_components = U[..., :, :rank] - singular_values = s[..., :rank] - spatial_components = Vh[..., :rank, :] + temporal_components = U[:, :, :rank] + singular_values = s[:, :rank] + spatial_components = Vh[:, :rank, :] return temporal_components, singular_values, spatial_components @@ -197,4 +215,5 @@ def temporally_upsample_templates( tup.clip(0, t - 1, out=tup) upsampled_templates = erp(tup) upsampled_templates = upsampled_templates.reshape(n, t, temporal_upsampling_factor, c) + upsampled_templates = upsampled_templates.astype(templates.dtype) return upsampled_templates diff --git a/src/dartsort/templates/templates.py b/src/dartsort/templates/templates.py index d36ff123..4598e8da 100644 --- a/src/dartsort/templates/templates.py +++ b/src/dartsort/templates/templates.py @@ -22,31 +22,49 @@ @dataclass class TemplateData: + # (n_templates, spike_length_samples, n_registered_channels or n_channels) templates: np.ndarray + # (n_templates,) maps template index to unit index (multiple templates can share a unit index) unit_ids: np.ndarray + # (n_templates,) spike count for each template + spike_counts: np.ndarray + registered_geom: Optional[np.ndarray] = None registered_template_depths_um: Optional[np.ndarray] = None - + @classmethod def from_npz(cls, npz_path): with np.load(npz_path) as npz: templates = npz["templates"] unit_ids = npz["unit_ids"] + spike_counts = npz["spike_counts"] registered_geom = registered_template_depths_um = None if "registered_geom" in npz: registered_geom = npz["registered_geom"] if "registered_template_depths_um" in npz: registered_template_depths_um = npz["registered_template_depths_um"] - return cls(templates, unit_ids, registered_geom, registered_template_depths_um) - + return cls( + templates, + unit_ids, + spike_counts, + registered_geom, + registered_template_depths_um, + ) + def to_npz(self, npz_path): - to_save = dict(templates=self.templates, unit_ids=self.unit_ids) + to_save = dict( + templates=self.templates, + unit_ids=self.unit_ids, + spike_counts=self.spike_counts, + ) if self.registered_geom is not None: to_save["registered_geom"] = self.registered_geom if self.registered_template_depths_um is not None: - to_save["registered_template_depths_um"] = self.registered_template_depths_um + to_save[ + "registered_template_depths_um" + ] = self.registered_template_depths_um np.savez(npz_path, **to_save) - + @classmethod def from_config( cls, @@ -59,7 +77,7 @@ def from_config( save_npz_name="template_data.npz", localizations_dataset_name="point_source_localizations", n_jobs=0, - device=None, + device=None, ): if save_folder is not None: save_folder = Path(save_folder) @@ -67,8 +85,8 @@ def from_config( save_folder.mkdir() npz_path = save_folder / save_npz_name if npz_path.exists() and not overwrite: - return cls.from_npz(npz_path) - + return cls.from_npz(npz_path) + motion_aware = ( template_config.registered_templates or template_config.superres_templates ) @@ -141,6 +159,12 @@ def from_config( else: unit_ids = np.arange(sorting.labels.max() + 1) + # count spikes in each template + spike_counts = np.zeros_like(unit_ids) + ix, counts = np.unique(sorting.labels, return_counts=True) + spike_counts[ix[ix >= 
0]] = counts[ix >= 0] + + # main! results = get_templates(recording, sorting, **kwargs) # handle registered templates @@ -153,15 +177,17 @@ def from_config( obj = cls( results["templates"], unit_ids, + spike_counts, kwargs["registered_geom"], registered_template_depths_um, ) else: - obj = cls( + obj = cls( results["templates"], unit_ids, + spike_counts, ) - + if save_folder is not None: obj.to_npz(npz_path) diff --git a/src/dartsort/util/drift_util.py b/src/dartsort/util/drift_util.py index 76748c7a..c2e53a96 100644 --- a/src/dartsort/util/drift_util.py +++ b/src/dartsort/util/drift_util.py @@ -231,6 +231,7 @@ def invert_motion_estimate(motion_est, t_s, registered_depths_um): and motion_est.spatial_bin_centers_um is not None ): bin_centers = motion_est.spatial_bin_centers_um + t_s = np.full(bin_centers.shape, t_s) bin_center_disps = motion_est.disp_at_s(t_s, depth_um=bin_centers) # registered_bin_centers = motion_est.correct_s(t_s, depths_um=bin_centers) registered_bin_centers = bin_centers - bin_center_disps @@ -374,11 +375,19 @@ def get_waveforms_on_static_channels( # scatter the waveforms into their static channel neighborhoods if out is None: - static_waveforms = np.full( - (n_spikes, t, n_static_channels + 1), - fill_value=fill_value, - dtype=waveforms.dtype, - ) + if torch.is_tensor(waveforms): + static_waveforms = torch.full( + (n_spikes, t, n_static_channels + 1), + fill_value=fill_value, + dtype=waveforms.dtype, + device=waveforms.device, + ) + else: + static_waveforms = np.full( + (n_spikes, t, n_static_channels + 1), + fill_value=fill_value, + dtype=waveforms.dtype, + ) else: assert out.shape == (n_spikes, t, n_static_channels + 1) out.fill(fill_value) @@ -404,12 +413,22 @@ def _full_probe_shifting_fast( fill_value, out=None, ): + is_tensor = torch.is_tensor(waveforms) + if out is None: - static_waveforms = np.full( - (*waveforms.shape[:2], target_kdtree.n + 1), - fill_value=fill_value, - dtype=waveforms.dtype, - ) + if is_tensor: + static_waveforms = torch.full( + (*waveforms.shape[:2], target_kdtree.n + 1), + fill_value=fill_value, + dtype=waveforms.dtype, + device=waveforms.device, + ) + else: + static_waveforms = np.full( + (*waveforms.shape[:2], target_kdtree.n + 1), + fill_value=fill_value, + dtype=waveforms.dtype, + ) else: assert out.shape == (*waveforms.shape[:2], target_kdtree.n + 1) out.fill(fill_value) diff --git a/src/dartsort/util/spiketorch.py b/src/dartsort/util/spiketorch.py index b892958f..eb97b64d 100644 --- a/src/dartsort/util/spiketorch.py +++ b/src/dartsort/util/spiketorch.py @@ -24,8 +24,6 @@ def ravel_multi_index(multi_index, dims): Indices into the flattened tensor of shape `dims` """ assert len(multi_index) == len(dims) - if any(torch.any((ix < 0) | (ix >= d)) for ix, d in zip(multi_index, dims)): - raise ValueError("Out of bounds indices in ravel_multi_index") # collect multi indices multi_index = torch.broadcast_tensors(*multi_index) @@ -53,9 +51,10 @@ def add_at_(dest, ix, src, sign=1): src = src.neg() elif sign != 1: src = sign * src + flat_ix = ravel_multi_index(ix, dest.shape) dest.view(-1).scatter_add_( 0, - ravel_multi_index(ix, dest.shape), + flat_ix, src.reshape(-1), ) @@ -218,3 +217,49 @@ def real_resample(x, num, dim=0): y *= (float(num) / float(Nx)) return y + + +def depthwise_oaconv1d(input, weight, f2=None, padding=0): + """Depthwise correlation (F.conv1d with groups=in_chans) with overlap-add + """ + # conv on last axis + # assert input.ndim == weight.ndim == 2 + n1 = input.shape[0] + n2 = weight.shape[0] + # assert n1 == n2 + s1 = 
input.shape[1] + s2 = weight.shape[1] + # assert s1 >= s2 + + shape_final = s1 + s2 - 1 + block_size, overlap, in1_step, in2_step = _calc_oa_lens(s1, s2) + nstep1, pad1, nstep2, pad2 = steps_and_pad(s1, in1_step, s2, in2_step, block_size, overlap) + + if pad1 > 0: + input = F.pad(input, (0, pad1)) + input = input.reshape(n1, nstep1, in1_step) + + # freq domain correlation + f1 = torch.fft.rfft(input, n=block_size) + if f2 is None: + f2 = torch.fft.rfft(weight, n=block_size) + # .conj() here to do cross-correlation instead of convolution (time reversal property of rfft) + f1.mul_(f2.conj()[:, None, :]) + res = torch.fft.irfft(f1, n=block_size) + + # overlap add part with torch + fold_input = res.reshape(n1, nstep1, block_size).permute(0, 2, 1) + fold_out_len = nstep1 * in1_step + overlap + fold_res = F.fold( + fold_input, + output_size=(1, fold_out_len), + kernel_size=(1, block_size), + stride=(1, in1_step), + ) + assert fold_res.shape == (n1, 1, 1, fold_out_len) + + oa = fold_res.reshape(n1, fold_out_len) + # this is the full convolution + oa = oa[:, :shape_final - pad1] + + return oa \ No newline at end of file From 1f8066a2f6ad1eae87a141b20921c3bc7f0b408b Mon Sep 17 00:00:00 2001 From: julienboussard Date: Thu, 19 Oct 2023 15:27:01 -0400 Subject: [PATCH 05/49] Update requirements.txt --- requirements.txt | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index e18ff27a..b2d0197f 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,4 +1,5 @@ pytest ibl-neuropixel spikeinterface -cloudpickle \ No newline at end of file +cloudpickle +hdbscan From 49746f5d8001b3caf1f1dfc1da36bcbf4396b00b Mon Sep 17 00:00:00 2001 From: Charlie Windolf Date: Fri, 27 Oct 2023 11:55:28 -0400 Subject: [PATCH 06/49] Pairwise convolution checkin --- src/dartsort/cluster/split.py | 34 +- src/dartsort/templates/pairwise.py | 760 ++++++++++++++++++++++++ src/dartsort/templates/template_util.py | 2 +- src/dartsort/templates/templates.py | 10 +- src/dartsort/util/data_util.py | 5 +- src/dartsort/vis/scatterplots.py | 4 +- 6 files changed, 778 insertions(+), 37 deletions(-) create mode 100644 src/dartsort/templates/pairwise.py diff --git a/src/dartsort/cluster/split.py b/src/dartsort/cluster/split.py index 4b4c7f66..bc3fdc90 100644 --- a/src/dartsort/cluster/split.py +++ b/src/dartsort/cluster/split.py @@ -74,9 +74,7 @@ def split_clusters( new_labels = split_result.new_labels triaged = split_result.new_labels < 0 labels[in_unit[triaged]] = new_labels[triaged] - labels[in_unit[new_labels > 0]] = ( - cur_max_label + new_labels[new_labels > 0] - ) + labels[in_unit[new_labels > 0]] = cur_max_label + new_labels[new_labels > 0] new_untriaged_labels = labels[in_unit[new_labels >= 0]] cur_max_label = new_untriaged_labels.max() @@ -84,9 +82,7 @@ def split_clusters( if recursive: new_units = np.unique(new_untriaged_labels) for i in new_units: - jobs.append( - pool.submit(_split_job, np.flatnonzero(labels == i)) - ) + jobs.append(pool.submit(_split_job, np.flatnonzero(labels == i))) if show_progress: iterator.total += len(new_units) @@ -151,7 +147,7 @@ def __init__( min_cluster_size=25, min_samples=25, cluster_selection_epsilon=25, - reassign_outliers=True, + reassign_outliers=False, random_state=0, **dataset_name_kwargs, ): @@ -241,18 +237,14 @@ def split_cluster(self, in_unit): is_split = np.setdiff1d(np.unique(hdb_labels), [-1]).size > 1 if is_split and self.reassign_outliers: - hdb_labels = cluster_util.knn_reassign_outliers( - hdb_labels, features - ) + hdb_labels = 
cluster_util.knn_reassign_outliers(hdb_labels, features) new_labels = None if is_split: new_labels = np.full(n_spikes, -1) new_labels[kept] = hdb_labels - return SplitResult( - is_split=is_split, in_unit=in_unit, new_labels=new_labels - ) + return SplitResult(is_split=is_split, in_unit=in_unit, new_labels=new_labels) def pca_features(self, in_unit): """Compute relocated PCA features on a drift-invariant channel set""" @@ -316,12 +308,8 @@ def pca_features(self, in_unit): return False, no_nan, None # fit pca and embed - pca = PCA( - self.n_pca_features, random_state=self.random_state, whiten=True - ) - pca_projs = np.full( - (n, self.n_pca_features), np.nan, dtype=waveforms.dtype - ) + pca = PCA(self.n_pca_features, random_state=self.random_state, whiten=True) + pca_projs = np.full((n, self.n_pca_features), np.nan, dtype=waveforms.dtype) pca_projs[no_nan] = pca.fit_transform(waveforms[no_nan]) return True, no_nan, pca_projs @@ -386,9 +374,7 @@ def initialize_from_h5( # this is to help split_clusters take a string argument all_split_strategies = [FeatureSplit] -split_strategies_by_class_name = { - cls.__name__: cls for cls in all_split_strategies -} +split_strategies_by_class_name = {cls.__name__: cls for cls in all_split_strategies} # -- parallelism widgets @@ -404,9 +390,7 @@ def __init__(self, split_strategy): def _split_job_init(split_strategy_class_name, split_strategy_kwargs): global _split_job_context split_strategy = split_strategies_by_class_name[split_strategy_class_name] - _split_job_context = SplitJobContext( - split_strategy(**split_strategy_kwargs) - ) + _split_job_context = SplitJobContext(split_strategy(**split_strategy_kwargs)) def _split_job(in_unit): diff --git a/src/dartsort/templates/pairwise.py b/src/dartsort/templates/pairwise.py new file mode 100644 index 00000000..cbfab55f --- /dev/null +++ b/src/dartsort/templates/pairwise.py @@ -0,0 +1,760 @@ +from dataclasses import dataclass, fields +from typing import Optional + +import h5py +import numpy as np +import torch +import torch.nn.functional as F +from dartsort.templates import template_util +from dartsort.util import drift_util +from dartsort.util.multiprocessing_util import get_pool +from scipy.spatial import KDTree +from scipy.spatial.distance import pdist +from tqdm.auto import tqdm + +# todo: extend this code to also handle computing pairwise +# stuff necessary for the merge! ie shifts, scaling, +# residnorm(a,b) (or min of rn(a,b),rn(b,a)???) 
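+# Illustrative usage sketch (argument names follow the signatures below; the
+# template SVD arrays would come from template_util.svd_compress_templates and
+# temporally_upsample_templates, the remaining variable names are hypothetical):
+#
+#   pairwise = sparse_pairwise_conv(
+#       "pconv.h5", geom, template_data,
+#       temporal, temporal_up, singular, spatial,
+#       chunk_time_centers_s=chunk_centers, motion_est=motion_est,
+#   )
+#   ix_a, ix_b, pair_convs = pairwise.query(
+#       template_indices_a, template_indices_b,
+#       upsampling_indices_b=up_b, shifts_a=shifts_a, shifts_b=shifts_b,
+#   )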
+ + +def sparse_pairwise_conv( + output_hdf5_filename, + geom, + template_data, + template_temporal_components, + template_upsampled_temporal_components, + template_singular_values, + template_spatial_components, + chunk_time_centers_s=None, + motion_est=None, + conv_ignore_threshold: float = 0.0, + coarse_approx_error_threshold: float = 0.0, + min_channel_amplitude: float = 1.0, + units_per_chunk=8, + overwrite=False, + show_progress=True, + device=None, + n_jobs=0, +): + """ + + Arguments + --------- + template_* : tensors or arrays + template SVD approximations + conv_ignore_threshold: float = 0.0 + pairs will be ignored (i.e., pconv set to 0) if their pconv + does not exceed this value + coarse_approx_error_threshold: float = 0.0 + superres will not be used if coarse pconv and superres pconv + are uniformly closer than this threshold value + + Returns + ------- + pitch_shifts : array + array of all the pitch shifts + use searchsorted to find the pitch shift ix for a pitch shift + index_table: torch sparse tensor + index_table[(pitch shift ix a, superres label a, pitch shift ix b, superres label b)] = ( + 0 + if superres pconv a,b at these shifts was below the conv_ignore_threshold + else pconv_index) + pconvs: np.ndarray + pconv[pconv_index] is a cross-correlation of two templates, summed over chans + """ + if overwrite: + pass + + ( + n_templates, + spike_length_samples, + upsampling_factor, + ) = template_upsampled_temporal_components.shape[:3] + + # find all of the co-occurring pitch shift and template pairs + temp_shift_index = get_shift_and_unit_pairs( + chunk_time_centers_s, + geom, + template_data, + motion_est=motion_est, + ) + + # check if the convolutions need to be drift-aware + # they do if we need to do any channel selection + is_drifting = np.array_equal(temp_shift_index.all_pitch_shifts, [0]) + if template_data.registered_geom is not None: + is_drifting &= np.array_equal(geom, template_data.registered_geom) + + # initialize pairwise conv data structures + # index_table[shifted_temp_ix(i), shifted_temp_ix(j)] = pconvix(i,j) + pconv_index_table = np.zeros( + (temp_shift_index.n_shifted_templates, temp_shift_index.n_shifted_templates), + dtype=int, + ) + # pconvs[pconvix(i,j)] = (2*spikelen-1, upsampling_factor) arr of pconv(shifted_temp(i), shifted_temp(j)) + + cur_pconv_ix = 1 + with h5py.File(output_hdf5_filename, "w") as h5: + # resizeable pconv dataset + pconv = h5.create_dataset( + "pconv", + dtype=np.float32, + maxshape=(None, upsampling_factor, 2 * spike_length_samples - 1), + chunks=(128, upsampling_factor, 2 * spike_length_samples - 1), + ) + + # pconv[0] is special -- it is 0. 
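+        # (index 0 is the sentinel stored in pconv_index_table for pairs whose
+        # correlation was ignored, so query() can drop them cheaply)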
+ pconv[0] = 0.0 + + # res is a ConvBatchResult + for res in compute_pairwise_convs( + template_data, + template_spatial_components, + template_singular_values, + template_temporal_components, + template_upsampled_temporal_components, + temp_shift_index.shifted_temp_ix_to_temp_ix, + temp_shift_index.shifted_temp_ix_to_shift, + geom, + conv_ignore_threshold=conv_ignore_threshold, + coarse_approx_error_threshold=coarse_approx_error_threshold, + min_channel_amplitude=min_channel_amplitude, + is_drifting=is_drifting, + units_per_chunk=units_per_chunk, + n_jobs=n_jobs, + device=device, + show_progress=show_progress, + max_shift="full", + store_conv=True, + compute_max=False, + ): + new_conv_ix = res.cconv_ix + ixnz = new_conv_ix > 0 + new_conv_ix[ixnz] += cur_pconv_ix + pconv_index_table[ + res.shifted_temp_ix_a, res.shifted_temp_ix_b + ] = new_conv_ix + pconv.resize(cur_pconv_ix + ixnz.size, axis=0) + pconv[new_conv_ix[ixnz]] = res.pconv[ixnz] + cur_pconv_ix += ixnz.size + + # smaller datasets all at once + h5.create_dataset( + "template_shift_index", data=temp_shift_index.template_shift_index + ) + h5.create_dataset("pconv_index_table", data=pconv_index_table) + h5.create_dataset("shifts", data=temp_shift_index.all_pitch_shifts) + h5.create_dataset( + "shifted_temp_ix_to_temp_ix", + data=temp_shift_index.shifted_temp_ix_to_temp_ix, + ) + h5.create_dataset( + "shifted_temp_ix_to_shift", data=temp_shift_index.shifted_temp_ix_to_shift + ) + h5.create_dataset( + "shifted_temp_ix_to_unit", + data=template_data.unit_ids[temp_shift_index.shifted_temp_ix_to_temp_ix], + ) + + return SparsePairwiseConv.from_h5(output_hdf5_filename) + + +@dataclass +class SparsePairwiseConv: + # shift_ix -> shift + shifts: np.ndarray + # (temp_ix, shift_ix) -> shifted_temp_ix + template_shift_index: torch.LongTensor + # (shifted_temp_ix a, shifted_temp_ix b) -> pconv index + pconv_index_table: torch.LongTensor + # pconv index -> pconv (upsampling, 2 * spike len - 1) + # the zero index lands you at an all 0 pconv + pconv: torch.Tensor + + # metadata: map shifted template index to original template ix and shift + shifted_temp_ix_to_temp_ix: np.ndarray + shifted_temp_ix_to_shift: np.ndarray + shifted_temp_ix_to_unit: np.ndarray + + @classmethod + def from_h5(cls, hdf5_filename): + ff = fields(cls) + with h5py.File(hdf5_filename, "r") as h5: + data = {f.name: h5[f.name][:] for f in ff} + return cls(**data) + + def query( + self, + template_indices_a, + template_indices_b, + upsampling_indices_b=None, + shifts_a=None, + shifts_b=None, + return_zero_convs=False, + ): + """Get cross-correlations of pairs of units A and B + + This passes through the series of lookup tables to recover (upsampled) + cross-correlations from this sparse database. 
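+        The lookup chain is: (template index, shift) -> shifted template index
+        via template_shift_index; (shifted index a, shifted index b) -> pconv
+        index via pconv_index_table; and that index (plus the optional
+        upsampling index) selects a row of pconv. Index 0 means no stored conv.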
+ + Returns + ------- + template_indices_a, template_indices_b, pair_convs + """ + # get shifted template indices + pconv = self.pconv + if upsampling_indices_b is None: + assert self.pconv.shape[1] == 1 + pconv = pconv[:, 0, :] + if shifts_a is None or shifts_b is None: + assert np.array_equal(self.shifts, [0.0]) + shifted_temp_ix_a = template_indices_a + shifted_temp_ix_b = template_indices_b + else: + shift_ix_a = np.searchsorted(self.shifts, shifts_a) + assert np.array_equal(self.shifts[shift_ix_a], shifts_a) + shift_ix_b = np.searchsorted(self.shifts, shifts_b) + assert np.array_equal(self.shifts[shift_ix_b], shifts_b) + shifted_temp_ix_a = self.template_shift_index[ + template_indices_a, shift_ix_a + ] + shifted_temp_ix_b = self.template_shift_index[ + template_indices_b, shift_ix_b + ] + + pconv_indices = self.pconv_index_table[shifted_temp_ix_a, shifted_temp_ix_b] + + # most users will be happy not to get a bunch of zeros for pairs that don't overlap + if not return_zero_convs: + which = np.flatnonzero(pconv_indices > 0) + pconv_indices = pconv_indices[which] + template_indices_a = template_indices_a[which] + template_indices_b = template_indices_b[which] + if upsampling_indices_b is not None: + upsampling_indices_b = upsampling_indices_b[which] + + if upsampling_indices_b is None: + pair_convs = pconv[pconv_indices] + else: + pair_convs = pconv[pconv_indices, upsampling_indices_b] + + return template_indices_a, template_indices_b, pair_convs + + +def compute_pairwise_convs( + template_data, + spatial, + singular, + temporal, + temporal_up, + shifted_temp_ix_to_temp_ix, + shifted_temp_ix_to_shift, + geom, + conv_ignore_threshold=0.0, + coarse_approx_error_threshold=0.0, + min_channel_amplitude=1.0, + max_shift="full", + is_drifting=True, + store_conv=True, + compute_max=False, + units_per_chunk=8, + n_jobs=0, + device=None, + show_progress=True, +): + # chunk up coarse unit ids, go by pairs of chunks + units = np.unique(template_data.unit_ids) + jobs = [] + for start_a in range(0, units.size, units_per_chunk): + end_a = min(start_a + units_per_chunk, units.size) + for start_b in range(start_a + 1, units.size, units_per_chunk): + end_b = min(start_b + units_per_chunk, units.size) + jobs.append((units[start_a:end_a], units[start_b:end_b])) + if show_progress: + jobs = tqdm(jobs) + + # compute the coarse templates if needed + if units.size == template_data.unit_ids.size: + # coarse templates are original templates + coarse_approx_error_threshold = 0 + if coarse_approx_error_threshold > 0: + coarse_templates = template_util.weighted_average( + template_data.unit_ids, template_data.templates, template_data.spike_counts + ) + ( + coarse_spatial, + coarse_singular, + coarse_temporal, + ) = template_util.svd_compress_templates( + coarse_templates, + rank=spatial.shape[2], + min_channel_amplitude=min_channel_amplitude, + ) + + # template data to torch + spatial_singular = torch.as_tensor(spatial * singular[:, None, :]) + temporal = torch.as_tensor(temporal) + temporal_up = torch.as_tensor(temporal_up) + if coarse_approx_error_threshold > 0: + coarse_spatial_singular = torch.as_tensor( + coarse_spatial * coarse_singular[:, None, :] + ) + coarse_temporal = torch.as_tensor(coarse_temporal) + else: + coarse_spatial_singular = None + coarse_temporal = None + + n_jobs, Executor, context, rank_queue = get_pool(n_jobs, with_rank_queue=True) + + pconv_params = dict( + store_conv=store_conv, + compute_max=compute_max, + is_drifting=is_drifting, + max_shift=max_shift, + 
conv_ignore_threshold=conv_ignore_threshold, + coarse_approx_error_threshold=coarse_approx_error_threshold, + spatial_singular=spatial_singular, + temporal=temporal, + temporal_up=temporal_up, + coarse_spatial_singular=coarse_spatial_singular, + coarse_temporal=coarse_temporal, + unit_ids=template_data.unit_ids, + shifted_temp_ix_to_shift=shifted_temp_ix_to_shift, + shifted_temp_ix_to_temp_ix=shifted_temp_ix_to_temp_ix, + geom=geom, + registered_geom=template_data.registered_geom, + ) + + with Executor( + n_jobs, + mp_context=context, + initializer=_pairwise_conv_init, + initargs=(device, rank_queue, pconv_params), + ) as pool: + yield from pool.map(_pairwise_conv_job, jobs) + + +# -- parallel job code + + +# helper class which stores parameters for _pairwise_conv_job +@dataclass +class PairwiseConvContext: + device: torch.device + + # parameters + store_conv: bool + compute_max: bool + is_drifting: bool + max_shift: int + conv_ignore_threshold: float + coarse_approx_error_threshold: float + + # superres registered templates + spatial_singular: torch.Tensor + temporal: torch.Tensor + temporal_up: torch.Tensor + coarse_spatial_singular: Optional[torch.Tensor] + coarse_temporal: Optional[torch.Tensor] + + # template indexing helper arrays + unit_ids: np.ndarray + shifted_temp_ix_to_temp_ix: np.ndarray + shifted_temp_ix_to_shift: np.ndarray + + # only needed if is_drifting + geom: np.ndarray + registered_geom: np.ndarray + target_kdtree: Optional[KDTree] + match_distance: Optional[float] + + +_pairwise_conv_context = None + + +def _pairwise_conv_init( + device, + rank_queue, + kwargs, +): + global _pairwise_conv_context + + # figure out what device to work on + my_rank = rank_queue.get() + if device is None: + device = "cuda" if torch.cuda.is_available() else "cpu" + device = torch.device(device) + if device.type == "cuda" and device.index is None: + if torch.cuda.device_count() > 1: + device = torch.device("cuda", index=my_rank % torch.cuda.device_count()) + + # handle string max_shift + max_shift = kwargs.pop("max_shift", "full") + t = kwargs["temporal"].shape[1] + if max_shift == "full": + max_shift = t - 1 + elif max_shift == "valid": + max_shift = 0 + elif max_shift == "same": + max_shift = t // 2 + kwargs["max_shift"] = max_shift + + kwargs["target_kdtree"] = kwargs["match_distance"] = None + if kwargs["is_drifting"]: + kwargs["target_kdtree"] = KDTree(kwargs["geom"]) + kwargs["match_distance"] = pdist(kwargs["geom"]).min() / 2 + + _pairwise_conv_context = PairwiseConvContext(device=device, **kwargs) + + +@dataclass +class ConvBatchResult: + # arrays of length + shifted_temp_ix_a: np.ndarray + shifted_temp_ix_b: np.ndarray + # array of length such that the ith + # pair's array of upsampled convs is cconv_up[cconv_ix[i]] + cconv_ix: np.ndarray + cconv_up: Optional[np.ndarray] + max_conv: Optional[float] + best_shift: Optional[int] + + +def _pairwise_conv_job( + units_a, + units_b, +): + """units_a,b are chunks of original (non-superres) unit labels""" + global _pairwise_conv_context + p = _pairwise_conv_context + + # this job consists of pairs of coarse units + # lets get all shifted superres template indices corresponding to those pairs, + # and the template indices, pitch shifts, and coarse units while we're at it + shifted_temp_ix_a = np.flatnonzero(np.isin(p.shifted_temp_ix_to_unit, units_a)) + shifted_temp_ix_b = np.flatnonzero(np.isin(p.shifted_temp_ix_to_unit, units_b)) + temp_ix_a = p.shifted_temp_ix_to_temp_ix[shifted_temp_ix_a] + temp_ix_b = 
p.shifted_temp_ix_to_temp_ix[shifted_temp_ix_b] + shift_a = p.shifted_temp_ix_to_shift[shifted_temp_ix_a] + shift_b = p.shifted_temp_ix_to_shift[shifted_temp_ix_b] + unit_a = p.unit_ids[temp_ix_a] + unit_b = p.unit_ids[temp_ix_b] + + # get shifted spatial components + spatial_a = p.spatial_singular[temp_ix_a] + spatial_b = p.spatial_singular[temp_ix_b] + if p.is_drifting: + spatial_a = drift_util.get_waveforms_on_static_channels( + spatial_a, + p.registered_geom, + n_pitches_shift=shift_a, + registered_geom=p.geom, + target_kdtree=p.target_kdtree, + match_distance=p.match_distance, + fill_value=0.0, + ) + spatial_b = drift_util.get_waveforms_on_static_channels( + spatial_b, + p.registered_geom, + n_pitches_shift=shift_b, + registered_geom=p.geom, + target_kdtree=p.target_kdtree, + match_distance=p.match_distance, + fill_value=0.0, + ) + + # to device + spatial_a = spatial_a.to(p.device) + spatial_b = spatial_b.to(p.device) + temporal_a = p.temporal[temp_ix_a].to(p.device) + temporal_up_b = p.temporal_up[temp_ix_b].to(p.device) + + # convolve valid pairs + pair_mask = p.cooccurence[shifted_temp_ix_a[:, None], shifted_temp_ix_b[None, :]] + pair_mask = pair_mask * (shifted_temp_ix_a[:, None] <= shifted_temp_ix_b[None, :]) + pair_mask = torch.as_tensor(pair_mask, device=p.device) + conv_ix_a, conv_ix_b, cconv = ccorrelate_up( + spatial_a, + temporal_a, + spatial_b, + temporal_up_b, + conv_ignore_threshold=p.conv_ignore_threshold, + max_shift=p.max_shift, + covisible_mask=pair_mask, + ) + nco = conv_ix_a.numel() + cconv_ix = np.arange(nco) + + # summarize units by coarse pconv when possible + if p.coarse_approx_error_threshold > 0: + # figure out coarse templates to correlate + conv_unit_a = unit_a[conv_ix_a] + conv_unit_b = unit_b[conv_ix_b] + coarse_units_a, coarse_units_b = np.unique( + np.c_[conv_unit_a, conv_unit_b], + axis=0, + ).T + + # correlate them + coarse_ix_a, coarse_ix_b, coarse_cconv = ccorrelate_up( + p.coarse_spatial_singular[coarse_units_a], + p.coarse_temporal[temp_ix_a], + p.coarse_spatial_singular[coarse_units_b], + p.coarse_temporal[temp_ix_b].unsqueeze(2), + conv_ignore_threshold=p.conv_ignore_threshold, + max_shift=p.max_shift, + ) + # i feel like this should hold so assert for now + assert coarse_ix_a.size == coarse_units_a.size + + # find coarse units which well summarize the fine cconvs + for coarse_unit_a, coarse_unit_b, conv in zip( + coarse_units_a, coarse_units_b, coarse_cconv + ): + # check good approx. if not, continue + in_pair = np.flatnonzero( + (conv_unit_a == coarse_unit_a) & (conv_unit_b == coarse_unit_b) + ) + assert in_pair.size + fine_cconvs = cconv[in_pair] + approx_err = (fine_cconvs - conv[None]).abs().max() + if not approx_err < p.coarse_approx_error_threshold: + continue + + # replace first fine cconv with the coarse cconv + fine_cconvs[in_pair[0]] = conv + # set all fine cconv ix to the index of that first one + cconv_ix[in_pair] = cconv_ix[in_pair[0]] + + # re-index and subset cconvs + cconv_ix = np.unique(cconv_ix) + conv_ix_a = conv_ix_a[cconv_ix] + conv_ix_b = conv_ix_b[cconv_ix] + cconv = cconv[cconv_ix] + + # for use in deconv residual distance merge + # TODO: actually probably need to do the real objective here with + # scaling. only need to do that bc of scaling right? 
+ # makes it kind of a pain, because then we need to go pairwise + # (deconv objective is not symmetric) + max_conv = best_shift = None + if p.compute_max: + cconv_ = cconv.reshape(nco, cconv.shape[1] * cconv.shape[2]) + max_conv, max_index = cconv_.max(dim=1) + max_up, max_sample = np.unravel_index( + max_index.numpy(force=True), shape=cconv.shape[1:] + ) + best_shift = max_sample - (p.max_shift + 1) + # if upsample>half nup, round max shift up + best_shift += np.rint(max_up / cconv.shape[1]).astype(int) + + return ConvBatchResult( + shifted_temp_ix_a[conv_ix_a].numpy(force=True), + shifted_temp_ix_b[conv_ix_b].numpy(force=True), + cconv_ix, + cconv.numpy(force=True), + max_conv.numpy(force=True), + best_shift, + ) + + +# -- library code +# template index and shift pairs +# pairwise low-rank cross-correlation + + +# this dtype lets us use np.union1d to find unique +# template index + pitch shift pairs below +template_shift_pair = np.dtype([("template_ix", int), ("shift", int)]) + + +@dataclass +class TemplateShiftIndex: + """Return value for get_shift_and_unit_pairs""" + + n_shifted_templates: int + # shift index -> shift + all_pitch_shifts: np.ndarray + # (template ix, shift index) -> shifted template index + template_shift_index: np.ndarray + # (shifted temp ix, shifted temp ix) -> did these appear at the same time + cooccurence: np.ndarray + shifted_temp_ix_to_temp_ix: np.ndarray + shifted_temp_ix_to_shift: np.ndarray + + +def static_template_shift_index(n_templates): + temp_ixs = np.arange(n_templates) + return TemplateShiftIndex( + n_templates, + np.zeros(1), + temp_ixs[:, None], + np.ones((n_templates, n_templates), dtype=bool), + temp_ixs, + np.zeros_like(temp_ixs), + ) + + +def get_shift_and_unit_pairs( + chunk_time_centers_s, + geom, + template_data, + motion_est=None, +): + n_templates = len(template_data.templates) + if motion_est is None: + # no motion case + return static_template_shift_index(n_templates) + + # all observed pitch shift values + all_pitch_shifts = np.empty(shape=(), dtype=int) + temp_ixs = np.arange(n_templates) + # set of (template idx, shift) + template_shift_pairs = np.empty(shape=(), dtype=template_shift_pair) + + for t_s in chunk_time_centers_s: + # see the fn `templates_at_time` + unregistered_depths_um = drift_util.invert_motion_estimate( + motion_est, t_s, template_data.registered_template_depths_um + ) + pitch_shifts = drift_util.get_spike_pitch_shifts( + depths_um=template_data.registered_template_depths_um, + geom=geom, + registered_depths_um=unregistered_depths_um, + ) + pitch_shifts = pitch_shifts.astype(int) + + # get unique pitch/unit shift pairs in chunk + template_shift = np.c_[temp_ixs, pitch_shifts] + template_shift = template_shift.view(template_shift_pair)[:, 0] + assert template_shift.shape == (n_templates,) + + # update full set + all_pitch_shifts = np.union1d(all_pitch_shifts, pitch_shifts) + template_shift_pairs = np.union1d(template_shift_pairs, template_shift) + + n_shifts = len(all_pitch_shifts) + n_template_shift_pairs = len(template_shift_pairs) + + # index template/shift pairs: template_shift_index[template_ix, shift_ix] = shifted template index + # fill with an invalid index + template_shift_index = np.full((n_templates, n_shifts), n_template_shift_pairs + 1) + template_shift_index[ + template_shift_pairs["template_ix"], template_shift_pairs["shift"] + ] = np.arange(n_template_shift_pairs) + shifted_temp_ix_to_temp_ix = template_shift_pairs["template_ix"] + shifted_temp_ix_to_shift = template_shift_pairs["shift"] + + # 
co-occurrence matrix: do these shifted templates appear together? + cooccurence = np.zeros((n_template_shift_pairs, n_template_shift_pairs), dtype=bool) + for t_s in chunk_time_centers_s: + # see the fn `templates_at_time` + unregistered_depths_um = drift_util.invert_motion_estimate( + motion_est, t_s, template_data.registered_template_depths_um + ) + pitch_shifts = drift_util.get_spike_pitch_shifts( + depths_um=template_data.registered_template_depths_um, + geom=geom, + registered_depths_um=unregistered_depths_um, + ) + pitch_shifts = pitch_shifts.astype(int) + + shifted_temp_ixs = template_shift_index[temp_ixs, pitch_shifts] + cooccurence[shifted_temp_ixs[:, None], shifted_temp_ixs[None, :]] = 1 + + return TemplateShiftIndex( + n_template_shift_pairs, + all_pitch_shifts, + template_shift_index, + cooccurence, + shifted_temp_ix_to_temp_ix, + shifted_temp_ix_to_shift, + ) + + +def ccorrelate_up( + spatial_a, + temporal_a, + spatial_b, + temporal_b, + conv_ignore_threshold=0.0, + max_shift="full", + covisible_mask=None, +): + """Convolve all pairs of low-rank templates + + This uses too much memory to run on all pairs at once. + + Templates Ka = Sa Ta, Kb = Sb Tb. The channel-summed convolution is + (Ka (*) Kb) = sum_c Ka(c) * Kb(c) + = (Sb.T @ Ka) (*) Tb + = (Sb.T @ Sa @ Ta) (*) Tb + where * is cross-correlation, and (*) is channel (or rank) summed. + + We use full-height conv2d to do rank-summed convs. + + Returns + ------- + covisible_a, covisible_b : tensors of indices + Both have shape (nco,), where nco is the number of templates + whose pairwise conv exceeds conv_ignore_threshold. + So, zip(covisible_a, covisible_b) is the set of co-visible pairs. + cconv : torch.Tensor + Shape is (nco, nup, 2 * max_shift + 1) + All cross-correlations for pairs of templates (templates in b + can be upsampled.) + If max_shift is full, then 2*max_shift+1=2t-1. + """ + na, rank, nchan = spatial_a.shape + nb, rank_, nchan_ = spatial_b.shape + assert rank == rank_ + assert nchan == nchan_ + na_, t, rank_ = temporal_a.shape + assert na == na_ + assert rank_ == rank + nb_, t_, nup, rank_ = temporal_b.shape + assert nb == nb_ + assert t == t_ + assert rank == rank_ + if covisible_mask is not None: + assert covisible_mask.shape == (na, nb) + + # this is covisible with ignore threshold 0 + # no need to convolve templates which do not overlap + covisible = spatial_a.max(1).values @ spatial_b.max(1).values.T + if covisible_mask is not None: + covisible *= covisible_mask + covisible_a, covisible_b = torch.nonzero(covisible, as_tuple=True) + nco = covisible_a.numel() + # TODO: can batch over nco dims below if memory issues arise + + Sa = spatial_a[covisible_a].reshape(nco * rank, nchan) + Sb = spatial_b[covisible_b].reshape(nco * rank, nchan) + spatial_outer = torch.vecdot(Sa, Sb) + spatial_outer = spatial_outer.reshape(nco, rank) + assert spatial_outer.shape == (nco, rank) + + # want conv filter: nco, rank, t + spatial_outer_co = spatial_outer[covisible_a, covisible_b] + conv_filt = spatial_outer_co[:, None, :] * temporal_a.permute(0, 2, 1)[None] + assert conv_filt.shape == (nco, rank, t) + + # nup, nco, rank, t + conv_in = temporal_b[covisible_b].permute(2, 0, 3, 1) + + # conv2d: + # depthwise, chans=nco. batch=1. h=rank. w=t. out: nup, nco, 1, 2p+1. + # input (conv_left): nup, nco, rank, t. + # filters (conv_right): nco, 1, rank, t. (groups=nco). 
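# A quick numpy check of the low-rank identity in the docstring above
# (illustrative shapes, not taken from this file): the channel-summed
# cross-correlation of two dense templates equals the rank-summed
# cross-correlation after projecting template a onto b's spatial
# components, which is the quantity the grouped conv2d below is set up
# to compute in batch.
import numpy as np

rng = np.random.default_rng(0)
rank, t, nchan = 3, 21, 7
Sa, Sb = rng.normal(size=(rank, nchan)), rng.normal(size=(rank, nchan))
Ta, Tb = rng.normal(size=(t, rank)), rng.normal(size=(t, rank))
Ka, Kb = Ta @ Sa, Tb @ Sb  # dense templates, shape (t, nchan)

# channel-summed full cross-correlation, lags -(t-1) .. (t-1)
direct = sum(np.correlate(Ka[:, c], Kb[:, c], mode="full") for c in range(nchan))

# rank-space version: project Ka onto Sb's components, correlate per rank with Tb
proj = Ka @ Sb.T  # (t, rank)
lowrank = sum(np.correlate(proj[:, r], Tb[:, r], mode="full") for r in range(rank))

assert np.allclose(direct, lowrank)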
+ cconv = F.conv2d(conv_in, conv_filt, padding=max_shift, groups=nco) + assert cconv.shape == (nup, nco, 1, 2 * max_shift + 1) + cconv = cconv[:, :, 0, :].permute(1, 0, 2) + + # more stringent covisibility + if conv_ignore_threshold > 0: + vis = cconv.abs().max(dim=(0, 2)).values > conv_ignore_threshold + cconv = cconv[vis] + covisible_a = covisible_a[vis] + covisible_b = covisible_b[vis] + + return covisible_a, covisible_b, cconv diff --git a/src/dartsort/templates/template_util.py b/src/dartsort/templates/template_util.py index b4fea070..08d0b470 100644 --- a/src/dartsort/templates/template_util.py +++ b/src/dartsort/templates/template_util.py @@ -130,7 +130,7 @@ def weighted_average(unit_ids, templates, weights): w = weights[which_in][:, None, None] w /= w.sum() out[i] = (w * templates[which_in]).sum(0) - + return out diff --git a/src/dartsort/templates/templates.py b/src/dartsort/templates/templates.py index 4598e8da..955c9cee 100644 --- a/src/dartsort/templates/templates.py +++ b/src/dartsort/templates/templates.py @@ -1,17 +1,13 @@ from dataclasses import dataclass +from pathlib import Path from typing import Optional import numpy as np -from pathlib import Path +from dartsort.util import drift_util from .get_templates import get_templates from .superres_util import superres_sorting -from .template_util import ( - get_registered_templates, - get_realigned_sorting, - get_template_depths, -) -from dartsort.util import drift_util +from .template_util import get_realigned_sorting, get_template_depths _motion_error_prefix = ( "If template_config has registered_templates==True " diff --git a/src/dartsort/util/data_util.py b/src/dartsort/util/data_util.py index 16c8969c..66367999 100644 --- a/src/dartsort/util/data_util.py +++ b/src/dartsort/util/data_util.py @@ -72,9 +72,8 @@ def to_numpy_sorting(self): def __str__(self): name = self.__class__.__name__ nspikes = self.times_samples.size - units = np.unique(self.labels) - units = units[units >= 0] - unit_str = f"{units.size} unit" + ("s" if units.size > 1 else "") + nunits = (np.unique(self.labels) >= 0).sum() + unit_str = f"{nunits} unit" + "s" * (nunits > 1) feat_str = "" if self.extra_features: feat_str = ", ".join(self.extra_features.keys()) diff --git a/src/dartsort/vis/scatterplots.py b/src/dartsort/vis/scatterplots.py index f639394b..8d1258c3 100644 --- a/src/dartsort/vis/scatterplots.py +++ b/src/dartsort/vis/scatterplots.py @@ -128,7 +128,7 @@ def scatter_spike_features( to_show=to_show, **scatter_kw, ) - + if label_axes: axes[0].set_ylabel("depth (um)") axes[0].set_xlabel("x (um)") @@ -166,6 +166,8 @@ def scatter_time_vs_depth( the times_s, depths_um, and (one of) amplitudes or labels as arrays, or alternatively, these can be left unset and they will be loaded from hdf5_filename when it is supplied. 
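    For instance, assuming the argument names described above (hypothetical call):

        ax, s = scatter_time_vs_depth(times_s=times_s, depths_um=depths_um, amplitudes=amps)
        # or, let the features be loaded from a peeler's output file
        ax, s = scatter_time_vs_depth(hdf5_filename="subtraction.h5")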
+ + Returns: axis, scatter """ if hdf5_filename is not None: with h5py.File(hdf5_filename, "r") as h5: From 28da3961d30392df5da6df0b2a482e0b8ba9def5 Mon Sep 17 00:00:00 2001 From: Charlie Windolf Date: Fri, 27 Oct 2023 11:57:01 -0400 Subject: [PATCH 07/49] todo --- src/dartsort/templates/template_util.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/dartsort/templates/template_util.py b/src/dartsort/templates/template_util.py index b4fea070..0eb50c11 100644 --- a/src/dartsort/templates/template_util.py +++ b/src/dartsort/templates/template_util.py @@ -154,6 +154,7 @@ def templates_at_time( registered_geom=None, motion_est=None, return_pitch_shifts=False, + # TODO: geom kdtree ): if registered_geom is None: return registered_templates From 1b3be9e5396dd95235a47c7dfdea07b86c782a29 Mon Sep 17 00:00:00 2001 From: Charlie Windolf Date: Tue, 31 Oct 2023 10:58:29 -0400 Subject: [PATCH 08/49] Update environment.yml --- environment.yml | 1 + 1 file changed, 1 insertion(+) diff --git a/environment.yml b/environment.yml index cf7f34c6..d8123cce 100644 --- a/environment.yml +++ b/environment.yml @@ -7,3 +7,4 @@ dependencies: - h5py - tqdm - scikit-learn + - colorcet From e3dc31b88ef5a90667fb7b7c29e8d1833c39540a Mon Sep 17 00:00:00 2001 From: Charlie Windolf Date: Tue, 31 Oct 2023 11:27:31 -0400 Subject: [PATCH 09/49] Pairwise progress --- src/dartsort/cluster/split.py | 2 +- src/dartsort/templates/pairwise.py | 355 ++++++++++++++++-------- src/dartsort/templates/pairwise_conv.py | 123 ++++++-- src/dartsort/templates/superres_util.py | 3 + src/dartsort/templates/template_util.py | 37 ++- src/dartsort/util/drift_util.py | 4 +- 6 files changed, 379 insertions(+), 145 deletions(-) diff --git a/src/dartsort/cluster/split.py b/src/dartsort/cluster/split.py index bc3fdc90..479b4232 100644 --- a/src/dartsort/cluster/split.py +++ b/src/dartsort/cluster/split.py @@ -323,7 +323,7 @@ def initialize_from_h5( amplitudes_dataset_name="denoised_amplitudes", amplitude_vectors_dataset_name="denoised_amplitude_vectors", ): - h5 = h5py.File(peeling_hdf5_filename, "r") + h5 = h5py.File(peeling_hdf5_filename, "r", locking=False) self.geom = h5["geom"][:] self.channel_index = h5["channel_index"][:] diff --git a/src/dartsort/templates/pairwise.py b/src/dartsort/templates/pairwise.py index cbfab55f..e1952d98 100644 --- a/src/dartsort/templates/pairwise.py +++ b/src/dartsort/templates/pairwise.py @@ -81,9 +81,13 @@ def sparse_pairwise_conv( # check if the convolutions need to be drift-aware # they do if we need to do any channel selection - is_drifting = np.array_equal(temp_shift_index.all_pitch_shifts, [0]) + print(f"{temp_shift_index.all_pitch_shifts=}") + is_drifting = not np.array_equal(temp_shift_index.all_pitch_shifts, [0]) + print(f"A {is_drifting=}") if template_data.registered_geom is not None: - is_drifting &= np.array_equal(geom, template_data.registered_geom) + print(f"{np.array_equal(geom, template_data.registered_geom)=}") + is_drifting |= not np.array_equal(geom, template_data.registered_geom) + print(f"B {is_drifting=}") # initialize pairwise conv data structures # index_table[shifted_temp_ix(i), shifted_temp_ix(j)] = pconvix(i,j) @@ -99,6 +103,7 @@ def sparse_pairwise_conv( pconv = h5.create_dataset( "pconv", dtype=np.float32, + shape=(1, upsampling_factor, 2 * spike_length_samples - 1), maxshape=(None, upsampling_factor, 2 * spike_length_samples - 1), chunks=(128, upsampling_factor, 2 * spike_length_samples - 1), ) @@ -116,6 +121,7 @@ def sparse_pairwise_conv( 
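For reference, the resizable "pconv" dataset created above relies on h5py's maxshape and chunks arguments and is grown with resize() as results arrive; a minimal standalone sketch of that pattern, with a hypothetical file name and toy sizes:

import h5py
import numpy as np

with h5py.File("pconv_scratch.h5", "w") as h5:
    # one placeholder row to start; axis 0 is unlimited (maxshape None)
    dset = h5.create_dataset(
        "pconv",
        dtype=np.float32,
        shape=(1, 8, 121),
        maxshape=(None, 8, 121),
        chunks=(128, 8, 121),
    )
    cur = 0
    for batch in (
        np.zeros((5, 8, 121), dtype=np.float32),
        np.ones((3, 8, 121), dtype=np.float32),
    ):
        dset.resize(cur + len(batch), axis=0)
        dset[cur:cur + len(batch)] = batch
        cur += len(batch)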
temp_shift_index.shifted_temp_ix_to_temp_ix, temp_shift_index.shifted_temp_ix_to_shift, geom, + cooccurrence=temp_shift_index.cooccurrence, conv_ignore_threshold=conv_ignore_threshold, coarse_approx_error_threshold=coarse_approx_error_threshold, min_channel_amplitude=min_channel_amplitude, @@ -128,15 +134,16 @@ def sparse_pairwise_conv( store_conv=True, compute_max=False, ): + if res is None: + continue new_conv_ix = res.cconv_ix - ixnz = new_conv_ix > 0 - new_conv_ix[ixnz] += cur_pconv_ix + new_conv_ix += cur_pconv_ix pconv_index_table[ res.shifted_temp_ix_a, res.shifted_temp_ix_b ] = new_conv_ix - pconv.resize(cur_pconv_ix + ixnz.size, axis=0) - pconv[new_conv_ix[ixnz]] = res.pconv[ixnz] - cur_pconv_ix += ixnz.size + pconv.resize(cur_pconv_ix + new_conv_ix.size, axis=0) + pconv[new_conv_ix] = res.cconv_up + cur_pconv_ix += new_conv_ix.size # smaller datasets all at once h5.create_dataset( @@ -156,7 +163,7 @@ def sparse_pairwise_conv( data=template_data.unit_ids[temp_shift_index.shifted_temp_ix_to_temp_ix], ) - return SparsePairwiseConv.from_h5(output_hdf5_filename) + return output_hdf5_filename # SparsePairwiseConv.from_h5(output_hdf5_filename) @dataclass @@ -250,6 +257,7 @@ def compute_pairwise_convs( shifted_temp_ix_to_temp_ix, shifted_temp_ix_to_shift, geom, + cooccurrence, conv_ignore_threshold=0.0, coarse_approx_error_threshold=0.0, min_channel_amplitude=1.0, @@ -271,7 +279,9 @@ def compute_pairwise_convs( end_b = min(start_b + units_per_chunk, units.size) jobs.append((units[start_a:end_a], units[start_b:end_b])) if show_progress: - jobs = tqdm(jobs) + jobs = tqdm( + jobs, smoothing=0.01, desc="Pairwise convolution", unit="pair block" + ) # compute the coarse templates if needed if units.size == template_data.unit_ids.size: @@ -281,23 +291,24 @@ def compute_pairwise_convs( coarse_templates = template_util.weighted_average( template_data.unit_ids, template_data.templates, template_data.spike_counts ) + print(f"{coarse_templates.shape=}") ( - coarse_spatial, - coarse_singular, coarse_temporal, + coarse_singular, + coarse_spatial, ) = template_util.svd_compress_templates( coarse_templates, - rank=spatial.shape[2], + rank=singular.shape[1], min_channel_amplitude=min_channel_amplitude, ) # template data to torch - spatial_singular = torch.as_tensor(spatial * singular[:, None, :]) + spatial_singular = torch.as_tensor(spatial * singular[:, :, None]) temporal = torch.as_tensor(temporal) temporal_up = torch.as_tensor(temporal_up) if coarse_approx_error_threshold > 0: coarse_spatial_singular = torch.as_tensor( - coarse_spatial * coarse_singular[:, None, :] + coarse_spatial * coarse_singular[:, :, None] ) coarse_temporal = torch.as_tensor(coarse_temporal) else: @@ -321,6 +332,8 @@ def compute_pairwise_convs( unit_ids=template_data.unit_ids, shifted_temp_ix_to_shift=shifted_temp_ix_to_shift, shifted_temp_ix_to_temp_ix=shifted_temp_ix_to_temp_ix, + shifted_temp_ix_to_unit=template_data.unit_ids[shifted_temp_ix_to_temp_ix], + cooccurrence=cooccurrence, geom=geom, registered_geom=template_data.registered_geom, ) @@ -356,11 +369,13 @@ class PairwiseConvContext: temporal_up: torch.Tensor coarse_spatial_singular: Optional[torch.Tensor] coarse_temporal: Optional[torch.Tensor] + cooccurrence: torch.Tensor # template indexing helper arrays unit_ids: np.ndarray shifted_temp_ix_to_temp_ix: np.ndarray shifted_temp_ix_to_shift: np.ndarray + shifted_temp_ix_to_unit: np.ndarray # only needed if is_drifting geom: np.ndarray @@ -420,14 +435,12 @@ class ConvBatchResult: best_shift: Optional[int] -def 
_pairwise_conv_job( - units_a, - units_b, -): - """units_a,b are chunks of original (non-superres) unit labels""" +def _pairwise_conv_job(unit_chunk): global _pairwise_conv_context p = _pairwise_conv_context + units_a, units_b = unit_chunk + # this job consists of pairs of coarse units # lets get all shifted superres template indices corresponding to those pairs, # and the template indices, pitch shifts, and coarse units while we're at it @@ -443,6 +456,7 @@ def _pairwise_conv_job( # get shifted spatial components spatial_a = p.spatial_singular[temp_ix_a] spatial_b = p.spatial_singular[temp_ix_b] + # print(f"{p.is_drifting=} old {spatial_a.shape=}") if p.is_drifting: spatial_a = drift_util.get_waveforms_on_static_channels( spatial_a, @@ -453,6 +467,7 @@ def _pairwise_conv_job( match_distance=p.match_distance, fill_value=0.0, ) + # print(f"new {spatial_a.shape=} {p.target_kdtree=}") spatial_b = drift_util.get_waveforms_on_static_channels( spatial_b, p.registered_geom, @@ -470,7 +485,7 @@ def _pairwise_conv_job( temporal_up_b = p.temporal_up[temp_ix_b].to(p.device) # convolve valid pairs - pair_mask = p.cooccurence[shifted_temp_ix_a[:, None], shifted_temp_ix_b[None, :]] + pair_mask = p.cooccurrence[shifted_temp_ix_a[:, None], shifted_temp_ix_b[None, :]] pair_mask = pair_mask * (shifted_temp_ix_a[:, None] <= shifted_temp_ix_b[None, :]) pair_mask = torch.as_tensor(pair_mask, device=p.device) conv_ix_a, conv_ix_b, cconv = ccorrelate_up( @@ -482,55 +497,35 @@ def _pairwise_conv_job( max_shift=p.max_shift, covisible_mask=pair_mask, ) + if conv_ix_a is None: + return None nco = conv_ix_a.numel() + if not nco: + return None cconv_ix = np.arange(nco) + + # shifts may not matter + if p.is_drifting: + cconv, cconv_ix_subset = _shift_normalize( + cconv, + cconv_ix, + temp_ix_a[conv_ix_a.cpu()], + shift_a[conv_ix_a.cpu()], + temp_ix_b[conv_ix_b.cpu()], + shift_b[conv_ix_b.cpu()], + ) + conv_ix_a = conv_ix_a[cconv_ix_subset] + conv_ix_b = conv_ix_b[cconv_ix_subset] + cconv_ix = np.arange(len(cconv_ix_subset)) # summarize units by coarse pconv when possible if p.coarse_approx_error_threshold > 0: - # figure out coarse templates to correlate - conv_unit_a = unit_a[conv_ix_a] - conv_unit_b = unit_b[conv_ix_b] - coarse_units_a, coarse_units_b = np.unique( - np.c_[conv_unit_a, conv_unit_b], - axis=0, - ).T - - # correlate them - coarse_ix_a, coarse_ix_b, coarse_cconv = ccorrelate_up( - p.coarse_spatial_singular[coarse_units_a], - p.coarse_temporal[temp_ix_a], - p.coarse_spatial_singular[coarse_units_b], - p.coarse_temporal[temp_ix_b].unsqueeze(2), - conv_ignore_threshold=p.conv_ignore_threshold, - max_shift=p.max_shift, + cconv, cconv_ix_subset = _coarse_approx( + cconv, cconv_ix, conv_ix_a, conv_ix_b, unit_a, unit_b, p ) - # i feel like this should hold so assert for now - assert coarse_ix_a.size == coarse_units_a.size - - # find coarse units which well summarize the fine cconvs - for coarse_unit_a, coarse_unit_b, conv in zip( - coarse_units_a, coarse_units_b, coarse_cconv - ): - # check good approx. 
if not, continue - in_pair = np.flatnonzero( - (conv_unit_a == coarse_unit_a) & (conv_unit_b == coarse_unit_b) - ) - assert in_pair.size - fine_cconvs = cconv[in_pair] - approx_err = (fine_cconvs - conv[None]).abs().max() - if not approx_err < p.coarse_approx_error_threshold: - continue - - # replace first fine cconv with the coarse cconv - fine_cconvs[in_pair[0]] = conv - # set all fine cconv ix to the index of that first one - cconv_ix[in_pair] = cconv_ix[in_pair[0]] - - # re-index and subset cconvs - cconv_ix = np.unique(cconv_ix) - conv_ix_a = conv_ix_a[cconv_ix] - conv_ix_b = conv_ix_b[cconv_ix] - cconv = cconv[cconv_ix] + conv_ix_a = conv_ix_a[cconv_ix_subset] + conv_ix_b = conv_ix_b[cconv_ix_subset] + cconv_ix = np.arange(len(cconv_ix_subset)) # for use in deconv residual distance merge # TODO: actually probably need to do the real objective here with @@ -549,11 +544,11 @@ def _pairwise_conv_job( best_shift += np.rint(max_up / cconv.shape[1]).astype(int) return ConvBatchResult( - shifted_temp_ix_a[conv_ix_a].numpy(force=True), - shifted_temp_ix_b[conv_ix_b].numpy(force=True), + shifted_temp_ix_a[conv_ix_a.numpy(force=True)], + shifted_temp_ix_b[conv_ix_b.numpy(force=True)], cconv_ix, - cconv.numpy(force=True), - max_conv.numpy(force=True), + cconv.numpy(force=True) if cconv is not None else None, + max_conv.numpy(force=True) if max_conv is not None else None, best_shift, ) @@ -563,11 +558,6 @@ def _pairwise_conv_job( # pairwise low-rank cross-correlation -# this dtype lets us use np.union1d to find unique -# template index + pitch shift pairs below -template_shift_pair = np.dtype([("template_ix", int), ("shift", int)]) - - @dataclass class TemplateShiftIndex: """Return value for get_shift_and_unit_pairs""" @@ -578,7 +568,7 @@ class TemplateShiftIndex: # (template ix, shift index) -> shifted template index template_shift_index: np.ndarray # (shifted temp ix, shifted temp ix) -> did these appear at the same time - cooccurence: np.ndarray + cooccurrence: np.ndarray shifted_temp_ix_to_temp_ix: np.ndarray shifted_temp_ix_to_shift: np.ndarray @@ -602,6 +592,7 @@ def get_shift_and_unit_pairs( motion_est=None, ): n_templates = len(template_data.templates) + print(f"get_shift_and_unit_pairs {motion_est=}") if motion_est is None: # no motion case return static_template_shift_index(n_templates) @@ -610,63 +601,73 @@ def get_shift_and_unit_pairs( all_pitch_shifts = np.empty(shape=(), dtype=int) temp_ixs = np.arange(n_templates) # set of (template idx, shift) - template_shift_pairs = np.empty(shape=(), dtype=template_shift_pair) + template_shift_pairs = np.empty(shape=(0, 2), dtype=int) + pitch = drift_util.get_pitch(geom) for t_s in chunk_time_centers_s: # see the fn `templates_at_time` unregistered_depths_um = drift_util.invert_motion_estimate( motion_est, t_s, template_data.registered_template_depths_um ) + diff = np.abs( + unregistered_depths_um - template_data.registered_template_depths_um + ) pitch_shifts = drift_util.get_spike_pitch_shifts( depths_um=template_data.registered_template_depths_um, - geom=geom, + pitch=pitch, registered_depths_um=unregistered_depths_um, ) pitch_shifts = pitch_shifts.astype(int) # get unique pitch/unit shift pairs in chunk template_shift = np.c_[temp_ixs, pitch_shifts] - template_shift = template_shift.view(template_shift_pair)[:, 0] - assert template_shift.shape == (n_templates,) # update full set all_pitch_shifts = np.union1d(all_pitch_shifts, pitch_shifts) - template_shift_pairs = np.union1d(template_shift_pairs, template_shift) + template_shift_pairs = 
np.unique( + np.concatenate((template_shift_pairs, template_shift), axis=0), axis=0 + ) + print(f"get_shift_and_unit_pairs {all_pitch_shifts=}") n_shifts = len(all_pitch_shifts) n_template_shift_pairs = len(template_shift_pairs) # index template/shift pairs: template_shift_index[template_ix, shift_ix] = shifted template index # fill with an invalid index - template_shift_index = np.full((n_templates, n_shifts), n_template_shift_pairs + 1) - template_shift_index[ - template_shift_pairs["template_ix"], template_shift_pairs["shift"] - ] = np.arange(n_template_shift_pairs) - shifted_temp_ix_to_temp_ix = template_shift_pairs["template_ix"] - shifted_temp_ix_to_shift = template_shift_pairs["shift"] + template_shift_index = np.full((n_templates, n_shifts), n_template_shift_pairs) + shift_ix = np.searchsorted(all_pitch_shifts, template_shift_pairs[:, 1]) + assert np.array_equal(all_pitch_shifts[shift_ix], template_shift_pairs[:, 1]) + print(f"{template_shift_pairs[:, 0]=}, {shift_ix=} {np.unique(shift_ix)=}") + template_shift_index[template_shift_pairs[:, 0], shift_ix] = np.arange( + n_template_shift_pairs + ) + shifted_temp_ix_to_temp_ix = template_shift_pairs[:, 0] + shifted_temp_ix_to_shift = template_shift_pairs[:, 1] # co-occurrence matrix: do these shifted templates appear together? - cooccurence = np.zeros((n_template_shift_pairs, n_template_shift_pairs), dtype=bool) + cooccurrence = np.zeros( + (n_template_shift_pairs, n_template_shift_pairs), dtype=bool + ) for t_s in chunk_time_centers_s: - # see the fn `templates_at_time` unregistered_depths_um = drift_util.invert_motion_estimate( motion_est, t_s, template_data.registered_template_depths_um ) pitch_shifts = drift_util.get_spike_pitch_shifts( depths_um=template_data.registered_template_depths_um, - geom=geom, + pitch=pitch, registered_depths_um=unregistered_depths_um, ) pitch_shifts = pitch_shifts.astype(int) + pitch_shift_ix = np.searchsorted(all_pitch_shifts, pitch_shifts) - shifted_temp_ixs = template_shift_index[temp_ixs, pitch_shifts] - cooccurence[shifted_temp_ixs[:, None], shifted_temp_ixs[None, :]] = 1 + shifted_temp_ixs = template_shift_index[temp_ixs, pitch_shift_ix] + cooccurrence[shifted_temp_ixs[:, None], shifted_temp_ixs[None, :]] = 1 return TemplateShiftIndex( n_template_shift_pairs, all_pitch_shifts, template_shift_index, - cooccurence, + cooccurrence, shifted_temp_ix_to_temp_ix, shifted_temp_ix_to_shift, ) @@ -680,6 +681,7 @@ def ccorrelate_up( conv_ignore_threshold=0.0, max_shift="full", covisible_mask=None, + batch_size=128, ): """Convolve all pairs of low-rank templates @@ -719,42 +721,161 @@ def ccorrelate_up( if covisible_mask is not None: assert covisible_mask.shape == (na, nb) - # this is covisible with ignore threshold 0 - # no need to convolve templates which do not overlap - covisible = spatial_a.max(1).values @ spatial_b.max(1).values.T + # no need to convolve templates which do not overlap enough + covisible = ( + torch.sqrt(torch.square(spatial_a).sum(1)) + @ torch.sqrt(torch.square(spatial_b).sum(1)).T + ) + covisible = covisible > conv_ignore_threshold + # print(f"{covisible.shape=}") if covisible_mask is not None: covisible *= covisible_mask covisible_a, covisible_b = torch.nonzero(covisible, as_tuple=True) nco = covisible_a.numel() - # TODO: can batch over nco dims below if memory issues arise - - Sa = spatial_a[covisible_a].reshape(nco * rank, nchan) - Sb = spatial_b[covisible_b].reshape(nco * rank, nchan) - spatial_outer = torch.vecdot(Sa, Sb) - spatial_outer = spatial_outer.reshape(nco, rank) - assert 
spatial_outer.shape == (nco, rank) - - # want conv filter: nco, rank, t - spatial_outer_co = spatial_outer[covisible_a, covisible_b] - conv_filt = spatial_outer_co[:, None, :] * temporal_a.permute(0, 2, 1)[None] - assert conv_filt.shape == (nco, rank, t) - - # nup, nco, rank, t - conv_in = temporal_b[covisible_b].permute(2, 0, 3, 1) - - # conv2d: - # depthwise, chans=nco. batch=1. h=rank. w=t. out: nup, nco, 1, 2p+1. - # input (conv_left): nup, nco, rank, t. - # filters (conv_right): nco, 1, rank, t. (groups=nco). - cconv = F.conv2d(conv_in, conv_filt, padding=max_shift, groups=nco) - assert cconv.shape == (nup, nco, 1, 2 * max_shift + 1) - cconv = cconv[:, :, 0, :].permute(1, 0, 2) + if not nco: + return None, None, None + # print(f"{(nco/covisible.numel())=}") + + # batch over nco for memory reasons + cconv = torch.zeros( + (nco, nup, 2 * max_shift + 1), dtype=spatial_a.dtype, device=spatial_a.device + ) + for istart in range(0, nco, batch_size): + iend = min(istart + batch_size, nco) + co_a = covisible_a[istart:iend] + co_b = covisible_b[istart:iend] + nco_ = iend - istart + + # want conv filter: nco, 1, rank, t + template_a = torch.bmm(temporal_a, spatial_a) + conv_filt = torch.bmm(spatial_b[co_b], template_a[co_a].mT) + conv_filt = conv_filt[:, None] # (nco, 1, rank, t) + + # nup, nco, rank, t + conv_in = temporal_b[co_b].permute(2, 0, 3, 1) + + # conv2d: + # depthwise, chans=nco. batch=1. h=rank. w=t. out: nup, nco, 1, 2p+1. + # input (conv_in): nup, nco, rank, t. + # filters (conv_filt): nco, 1, rank, t. (groups=nco). + cconv_ = F.conv2d(conv_in, conv_filt, padding=(0, max_shift), groups=nco_) + cconv[istart:iend] = cconv_[:, :, 0, :].permute(1, 0, 2) # nco, nup, time # more stringent covisibility if conv_ignore_threshold > 0: - vis = cconv.abs().max(dim=(0, 2)).values > conv_ignore_threshold + max_val = cconv.reshape(nco, -1).abs().max(dim=1).values + vis = max_val > conv_ignore_threshold cconv = cconv[vis] covisible_a = covisible_a[vis] covisible_b = covisible_b[vis] + # print(f"{(covisible_b.numel()/covisible.numel())=}") return covisible_a, covisible_b, cconv + + +# -- helpers + + +def _coarse_approx(cconv, cconv_ix, conv_ix_a, conv_ix_b, unit_a, unit_b, p): + # figure out coarse templates to correlate + conv_ix_a = conv_ix_a.cpu() + conv_ix_b = conv_ix_b.cpu() + conv_unit_a = unit_a[conv_ix_a] + conv_unit_b = unit_b[conv_ix_b] + coarse_units_a = np.unique(conv_unit_a) + coarse_units_b = np.unique(conv_unit_b) + coarsecovis = np.zeros((coarse_units_a.size, coarse_units_b.size), dtype=bool) + coarsecovis[ + np.searchsorted(coarse_units_a, conv_unit_a), + np.searchsorted(coarse_units_b, conv_unit_b), + ] = True + + # correlate them + coarse_ix_a, coarse_ix_b, coarse_cconv = ccorrelate_up( + p.coarse_spatial_singular[coarse_units_a].to(p.device), + p.coarse_temporal[coarse_units_a].to(p.device), + p.coarse_spatial_singular[coarse_units_b].to(p.device), + p.coarse_temporal[coarse_units_b].unsqueeze(2).to(p.device), + conv_ignore_threshold=p.conv_ignore_threshold, + max_shift=p.max_shift, + covisible_mask=torch.as_tensor(coarsecovis, device=p.device), + ) + if coarse_ix_a is None: + return cconv, cconv_ix + + coarse_units_a = np.atleast_1d(coarse_units_a[coarse_ix_a.cpu()]) + coarse_units_b = np.atleast_1d(coarse_units_b[coarse_ix_b.cpu()]) + + # find coarse units which well summarize the fine cconvs + for coarse_unit_a, coarse_unit_b, conv in zip( + coarse_units_a, coarse_units_b, coarse_cconv + ): + # check good approx. 
if not, continue + in_pair = np.flatnonzero( + (conv_unit_a == coarse_unit_a) & (conv_unit_b == coarse_unit_b) + ) + assert in_pair.size + fine_cconvs = cconv[cconv_ix[in_pair]] + approx_err = (fine_cconvs - conv[None]).abs().max() + if not approx_err < p.coarse_approx_error_threshold: + continue + + # replace first fine cconv with the coarse cconv + cconv[cconv_ix[in_pair[0]]] = conv + # set all fine cconv ix to the index of that first one + cconv_ix[in_pair] = cconv_ix[in_pair[0]] + + # re-index and subset cconvs + cconv_ix_subset = np.unique(cconv_ix) + conv_ix_a = conv_ix_a[cconv_ix_subset] + conv_ix_b = conv_ix_b[cconv_ix_subset] + cconv = cconv[cconv_ix_subset] + return cconv, cconv_ix_subset + +def _shift_normalize(cconv, cconv_ix, temp_ix_a, shift_a, temp_ix_b, shift_b, atol=1e-1): + pairs_done = set() + for ua, ub in zip(temp_ix_a, temp_ix_b): + if (ua, ub) in pairs_done: + continue + pairs_done.add((ua, ub)) + + in_pair = np.flatnonzero( + (temp_ix_a == ua) & (temp_ix_b == ub) + ) + diffs = shift_a[in_pair] - shift_b[in_pair] + changed = False + for diff in np.unique(diffs): + in_diff = in_pair[diffs == diff] + + cconvs = cconv[cconv_ix[in_diff]] + meanconv = cconvs.mean(0, keepdims=True) + err = (cconvs - meanconv).abs().max() + if err > atol: + continue + changed = True + cconv[cconv_ix[in_diff[0]]] = meanconv + cconv_ix[in_diff] = cconv_ix[in_diff[0]] + if changed: + pairs_done.remove((ua, ub)) + + for ua, ub in zip(temp_ix_a, temp_ix_b): + if (ua, ub) in pairs_done: + continue + pairs_done.add((ua, ub)) + + in_pair = np.flatnonzero( + (temp_ix_a == ua) & (temp_ix_b == ub) + ) + cconvs = cconv[cconv_ix[in_pair]] + meanconv = cconvs.mean(0, keepdims=True) + err = (cconvs - meanconv).abs().max() + if err > atol: + continue + + cconv[cconv_ix[in_pair[0]]] = meanconv + cconv_ix[in_pair] = cconv_ix[in_pair[0]] + + # re-index and subset cconvs + cconv_ix_subset = np.unique(cconv_ix) + cconv = cconv[cconv_ix_subset] + return cconv, cconv_ix_subset \ No newline at end of file diff --git a/src/dartsort/templates/pairwise_conv.py b/src/dartsort/templates/pairwise_conv.py index be23ba67..26b5b76a 100644 --- a/src/dartsort/templates/pairwise_conv.py +++ b/src/dartsort/templates/pairwise_conv.py @@ -2,21 +2,20 @@ def sparse_pairwise_conv( - sorting, + template_data, template_temporal_components, template_upsampled_temporal_components, template_singular_values, template_spatial_components, + chunk_time_centers_s=None, + motion_est=None, conv_ignore_threshold: float = 0.0, coarse_approx_error_threshold: float = 0.0, ): """ - + Arguments --------- - sorting : DARTsortSorting - original (non-superres) sorting. 
its labels should appear in - template_data.unit_ids template_* : tensors or arrays template SVD approximations conv_ignore_threshold: float = 0.0 @@ -25,7 +24,7 @@ def sparse_pairwise_conv( coarse_approx_error_threshold: float = 0.0 superres will not be used if coarse pconv and superres pconv are uniformly closer than this threshold value - + Returns ------- pitch_shifts : array @@ -39,27 +38,111 @@ def sparse_pairwise_conv( pconvs: np.ndarray pconv[pconv_index] is a cross-correlation of two templates, summed over chans """ - - - + # find all of the co-occurring pitch shift and unit id pairs + all_pitch_shifts, shift_unit_pairs = get_shift_and_unit_pairs( + chunk_time_centers_s, + geom, + template_data, + motion_est=motion_est, + ) + + +# defining this dtype, which represents a pair of units and shifts, +# allows us to use numpy's 1d set functions on these pairs +shift_unit_pair_dtype = np.dtype( + [("unita", int), ("shifta", int), ("unitb", int), ("shiftb", int)] +) + + +class PairwiseConvContext: + def __init__( + self, + coarse_spatial, + coarse_singular, + coarse_f_temporal, + spatial, + singular, + f_temporal, + f_temporal_up, + geom, + registered_geom, + ): + def _pairwise_conv_job( units_a, units_b, ): """units_a,b are chunks of original (non-superres) unit labels""" - # determine co-visibility - # get all coarse templates - # get all superres templates - # compute all coarse and superres pconvs # returns - # list of tuples containing: - # - pitch shift ix a - # - pitch shift ix b - # - superres label a - # - superres label b - # list of the same length containing: + # array of type shift_unit_pair_dtype + # array of the same length containing # - -1 or an index into the next list # list of pconvs, indexed by previous list - \ No newline at end of file + + # determine co-visible shift/unit pairs + + # extract template data for left and right entries of each + # pair into npairs-len structures + # "depthwise" convolve these two structures + + # same for the coarse templates + + # when max pconv is < co-correlation threshold: + # - key list entry gets -1 + + # now, the coarse part + # for each pair of coarse units, check if the max difference + # of coarse pconv and all superres pconvs is small enough, + # and use an id for the (temporally upsampled) coarse pconv if so + + + pass + + + +def get_shift_and_unit_pairs( + chunk_time_centers_s, + geom, + template_data, + motion_est=None, +): + if motion_est is None: + return None, None + + # all observed pitch shift values + all_pitch_shifts = [] + # set of (unit a, shift a, unit b, shift b) + # units are unit ids, not (superres) template indices + shift_unit_pairs = [] + + for t_s in chunk_time_centers_s: + # see the fn `templates_at_time` + unregistered_depths_um = drift_util.invert_motion_estimate( + motion_est, t_s, template_data.registered_template_depths_um + ) + pitch_shifts = drift_util.get_spike_pitch_shifts( + depths_um=template_data.registered_template_depths_um, + geom=geom, + registered_depths_um=unregistered_depths_um, + ) + + # get unique pitch/unit shift pairs in chunk + pitch_and_unit = np.c_[td.unit_ids, pitch_shifts.astype(int)] + pairs = np.concatenate( + np.broadcast_arrays( + pitch_and_unit[:, None, :], + pitch_and_unit[None, :, :], + ), + axis=2, + ) + pairs = pairs.reshape(len(td.unit_ids) ** 2, 4) + pairs = np.ascontiguousarray(pairs).view(shift_unit_pair_dtype) + unique_pairs_in_chunk = np.unique(pairs) + + # update full set + all_pitch_shifts = np.union1d(all_pitch_shifts, pitch_shifts) + shift_unit_pairs = 
np.union1d(shift_unit_pairs, unique_pairs_in_chunk) + + return all_pitch_shifts, shift_unit_pairs diff --git a/src/dartsort/templates/superres_util.py b/src/dartsort/templates/superres_util.py index 8476cf37..579f8cbd 100644 --- a/src/dartsort/templates/superres_util.py +++ b/src/dartsort/templates/superres_util.py @@ -129,6 +129,9 @@ def drift_pitch_loc_bin_strategy( ) coarse_reg_depths = spike_depths_um + n_pitches_shift * pitch bin_ids = coarse_reg_depths // superres_bin_size_um + print(f"{np.isnan(n_pitches_shift).any()=} {np.isfinite(bin_ids).all()=} {superres_bin_size_um=}") + print(f"{bin_ids.min()=} {bin_ids.max()=} {bin_ids.shape=}") + print(f"{original_labels.min()=} {original_labels.max()=} {original_labels.shape=}") bin_ids = bin_ids.astype(int) orig_label_and_bin, superres_labels = np.unique( np.c_[original_labels, bin_ids], axis=0, return_inverse=True diff --git a/src/dartsort/templates/template_util.py b/src/dartsort/templates/template_util.py index f3f217b9..77c49ffa 100644 --- a/src/dartsort/templates/template_util.py +++ b/src/dartsort/templates/template_util.py @@ -122,6 +122,7 @@ def weighted_average(unit_ids, templates, weights): n_out = unit_ids.max() + 1 n_in, t, c = templates.shape out = np.zeros((n_out, t, c), dtype=templates.dtype) + weights = weights.astype(float) for i in range(n_out): which_in = np.flatnonzero(unit_ids == i) if not which_in.size: @@ -188,7 +189,7 @@ def templates_at_time( # -- template numerical processing -def svd_compress_templates(templates, min_channel_amplitude=1.0, rank=5): +def svd_compress_templates(templates, min_channel_amplitude=1.0, rank=5, channel_sparse=True): """ Returns: temporal_components: n_units, spike_length_samples, rank @@ -197,11 +198,35 @@ def svd_compress_templates(templates, min_channel_amplitude=1.0, rank=5): """ vis_mask = templates.ptp(axis=1, keepdims=True) > min_channel_amplitude vis_templates = templates * vis_mask - U, s, Vh = np.linalg.svd(vis_templates, full_matrices=False) - # s is descending. - temporal_components = U[:, :, :rank] - singular_values = s[:, :rank] - spatial_components = Vh[:, :rank, :] + dtype = templates.dtype + + if not channel_sparse: + U, s, Vh = np.linalg.svd(vis_templates, full_matrices=False) + # s is descending. + temporal_components = U[:, :, :rank].astype(dtype) + singular_values = s[:, :rank].astype(dtype) + spatial_components = Vh[:, :rank, :].astype(dtype) + return temporal_components, singular_values, spatial_components + + # channel sparse: only SVD the nonzero channels + # this encodes the same exact subspace as above, and the reconstruction + # error is the same as above as a function of rank. 
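# (A standalone check of this claim with toy shapes, not taken from this file:
# when the invisible channels of a template are exactly zero, the rank-k
# reconstruction from an SVD of the visible channels alone matches the rank-k
# reconstruction from a dense SVD of the masked template.)
import numpy as np

rng = np.random.default_rng(0)
t, c, rank = 50, 16, 3
template = rng.normal(size=(t, c))
template[:, 10:] = 0.0  # pretend these channels fall below the visibility threshold
vis = np.flatnonzero(np.ptp(template, axis=0) > 0)

# channel-sparse: factor only the visible columns, leave the rest at zero
U, s, Vh = np.linalg.svd(template[:, vis], full_matrices=False)
sparse_recon = np.zeros_like(template)
sparse_recon[:, vis] = U[:, :rank] @ np.diag(s[:rank]) @ Vh[:rank]

# dense SVD of the full masked template
Ud, sd, Vhd = np.linalg.svd(template, full_matrices=False)
dense_recon = Ud[:, :rank] @ np.diag(sd[:rank]) @ Vhd[:rank]

assert np.allclose(sparse_recon, dense_recon)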
it's just that + # we can zero out some spatial components, which is a useful property + # (used in pairwise convolutions for instance) + n, t, c = templates.shape + temporal_components = np.zeros((n, t, rank), dtype=dtype) + singular_values = np.zeros((n, rank), dtype=dtype) + spatial_components = np.zeros((n, rank, c), dtype=dtype) + for i in range(len(templates)): + template = templates[i] + mask = np.flatnonzero(vis_mask[i, 0]) + k = min(rank, mask.size) + if not k: + continue + U, s, Vh = np.linalg.svd(template[:, mask], full_matrices=False) + temporal_components[i, :, :k] = U[:, :rank] + singular_values[i, :k] = s[:rank] + spatial_components[i, :k, mask] = Vh[:rank].T return temporal_components, singular_values, spatial_components diff --git a/src/dartsort/util/drift_util.py b/src/dartsort/util/drift_util.py index c2e53a96..d527432a 100644 --- a/src/dartsort/util/drift_util.py +++ b/src/dartsort/util/drift_util.py @@ -230,6 +230,7 @@ def invert_motion_estimate(motion_est, t_s, registered_depths_um): hasattr(motion_est, "spatial_bin_centers_um") and motion_est.spatial_bin_centers_um is not None ): + # nonrigid motion bin_centers = motion_est.spatial_bin_centers_um t_s = np.full(bin_centers.shape, t_s) bin_center_disps = motion_est.disp_at_s(t_s, depth_um=bin_centers) @@ -240,6 +241,7 @@ def invert_motion_estimate(motion_est, t_s, registered_depths_um): registered_depths_um, registered_bin_centers, bin_center_disps ) else: + # rigid motion disps = motion_est.disp_at_s(t_s) return registered_depths_um + disps @@ -414,7 +416,7 @@ def _full_probe_shifting_fast( out=None, ): is_tensor = torch.is_tensor(waveforms) - + if out is None: if is_tensor: static_waveforms = torch.full( From 8d88b1fe00d89a1658a00d7d0021b37b359c7754 Mon Sep 17 00:00:00 2001 From: Charlie Windolf Date: Tue, 31 Oct 2023 18:43:13 -0400 Subject: [PATCH 10/49] Fix up tests --- src/dartsort/detect/detect.py | 2 +- src/dartsort/main.py | 4 ++-- src/dartsort/peel/__init__.py | 2 +- src/dartsort/peel/matching.py | 5 ++--- src/dartsort/templates/get_templates.py | 19 ++++++++++++------- src/dartsort/templates/template_util.py | 4 +--- src/dartsort/util/data_util.py | 1 + src/dartsort/util/drift_util.py | 3 ++- src/dartsort/util/spiketorch.py | 14 +++++++++++++- src/dartsort/util/waveform_util.py | 11 ----------- tests/test_templates.py | 16 ++++++++++++---- 11 files changed, 47 insertions(+), 34 deletions(-) diff --git a/src/dartsort/detect/detect.py b/src/dartsort/detect/detect.py index 59fc5de5..cd705191 100644 --- a/src/dartsort/detect/detect.py +++ b/src/dartsort/detect/detect.py @@ -41,7 +41,7 @@ def detect_and_deduplicate( with corresponding channels """ nsamples, nchans = traces.shape - if dedup_channel_index == "all": + if isinstance(dedup_channel_index, str) and dedup_channel_index == "all": pass elif dedup_channel_index is not None: assert dedup_channel_index.shape[0] == nchans diff --git a/src/dartsort/main.py b/src/dartsort/main.py index c0bf73da..4cb5ecf9 100644 --- a/src/dartsort/main.py +++ b/src/dartsort/main.py @@ -3,7 +3,7 @@ from dartsort.config import (FeaturizationConfig, MatchingConfig, SubtractionConfig, TemplateConfig) from dartsort.localize.localize_util import localize_hdf5 -from dartsort.peel import (ResidualUpdateTemplateMatchingPeeler, +from dartsort.peel import (ObjectiveUpdateTemplateMatchingPeeler, SubtractionPeeler) from dartsort.templates import TemplateData from dartsort.util.data_util import DARTsortSorting, check_recording @@ -104,7 +104,7 @@ def match( ) # instantiate peeler - 
matching_peeler = ResidualUpdateTemplateMatchingPeeler.from_config( + matching_peeler = ObjectiveUpdateTemplateMatchingPeeler.from_config( recording, matching_config, featurization_config, diff --git a/src/dartsort/peel/__init__.py b/src/dartsort/peel/__init__.py index aa03dd24..811f3f96 100644 --- a/src/dartsort/peel/__init__.py +++ b/src/dartsort/peel/__init__.py @@ -1,2 +1,2 @@ -from .matching import ResidualUpdateTemplateMatchingPeeler +from .matching import ObjectiveUpdateTemplateMatchingPeeler from .subtract import SubtractionPeeler, subtract_chunk diff --git a/src/dartsort/peel/matching.py b/src/dartsort/peel/matching.py index e1f513fa..c04c694e 100644 --- a/src/dartsort/peel/matching.py +++ b/src/dartsort/peel/matching.py @@ -12,7 +12,6 @@ import numpy as np import torch import torch.nn.functional as F -from dartsort.detect import detect_and_deduplicate from dartsort.templates import template_util from dartsort.transform import WaveformPipeline from dartsort.util import spiketorch @@ -22,7 +21,7 @@ from .peel_base import BasePeeler -class ResidualUpdateTemplateMatchingPeeler(BasePeeler): +class ObjectiveUpdateTemplateMatchingPeeler(BasePeeler): peel_kind = "TemplateMatching" def __init__( @@ -583,7 +582,7 @@ def get_collisioncleaned_waveforms( @dataclass class CompressedTemplateData: - """Objects of this class are returned by ResidualUpdateTemplateMatchingPeeler.templates_at_time()""" + """Objects of this class are returned by ObjectiveUpdateTemplateMatchingPeeler.templates_at_time()""" spatial_components: torch.Tensor singular_values: torch.Tensor diff --git a/src/dartsort/templates/get_templates.py b/src/dartsort/templates/get_templates.py index d8b1f993..fdb49cd7 100644 --- a/src/dartsort/templates/get_templates.py +++ b/src/dartsort/templates/get_templates.py @@ -10,7 +10,8 @@ from dartsort.util import spikeio from dartsort.util.drift_util import registered_template from dartsort.util.multiprocessing_util import get_pool -from dartsort.util.waveform_util import fast_nanmedian, make_channel_index +from dartsort.util.spiketorch import fast_nanmedian, ptp +from dartsort.util.waveform_util import make_channel_index from scipy.spatial import KDTree from scipy.spatial.distance import pdist from sklearn.decomposition import TruncatedSVD @@ -469,7 +470,7 @@ def __init__( self.max_spike_time = recording.get_num_samples() - ( spike_length_samples - trough_offset_samples ) - + self.spike_buffer = torch.zeros( (spikes_per_unit * units_per_job, spike_length_samples, self.n_channels), device=device, @@ -573,8 +574,8 @@ def _template_job(unit_ids): trough_offset_samples=p.trough_offset_samples, spike_length_samples=p.spike_length_samples, ) - p.spike_buffer[:times.size] = torch.from_numpy(waveforms) - waveforms = p.spike_buffer[:times.size] + p.spike_buffer[: times.size] = torch.from_numpy(waveforms) + waveforms = p.spike_buffer[: times.size] n, t, c = waveforms.shape # compute raw templates and spike counts per channel @@ -611,9 +612,11 @@ def _template_job(unit_ids): ) ) else: - raw_templates.append(p.reducer(waveforms[in_unit], axis=0)) + raw_templates.append( + p.reducer(waveforms[in_unit], axis=0).numpy(force=True) + ) counts.append(in_unit.size) - snrs_by_chan = [rt.ptp(0) * c for rt, c in zip(raw_templates, counts)] + snrs_by_chan = [ptp(rt, 0) * c for rt, c in zip(raw_templates, counts)] if p.denoising_tsvd is None: return uids, raw_templates, None, snrs_by_chan @@ -644,7 +647,9 @@ def _template_job(unit_ids): ) ) else: - low_rank_templates.append(p.reducer(waveforms[in_unit], 
axis=0)) + low_rank_templates.append( + p.reducer(waveforms[in_unit], axis=0).numpy(force=True) + ) return uids, raw_templates, low_rank_templates, snrs_by_chan diff --git a/src/dartsort/templates/template_util.py b/src/dartsort/templates/template_util.py index 77c49ffa..dbdc4d9e 100644 --- a/src/dartsort/templates/template_util.py +++ b/src/dartsort/templates/template_util.py @@ -2,7 +2,7 @@ from dartsort.localize.localize_util import localize_waveforms from dartsort.util import drift_util from dartsort.util.data_util import DARTsortSorting -from dartsort.util.waveform_util import fast_nanmedian +from dartsort.util.spiketorch import fast_nanmedian from scipy.interpolate import interp1d from .get_templates import get_raw_templates, get_templates @@ -55,7 +55,6 @@ def get_registered_templates( denoising_fit_radius=75, denoising_spikes_fit=50_000, denoising_snr_threshold=50.0, - zero_radius_um=None, reducer=fast_nanmedian, random_seed=0, n_jobs=0, @@ -84,7 +83,6 @@ def get_registered_templates( denoising_fit_radius=denoising_fit_radius, denoising_spikes_fit=denoising_spikes_fit, denoising_snr_threshold=denoising_snr_threshold, - zero_radius_um=zero_radius_um, reducer=reducer, random_seed=random_seed, n_jobs=n_jobs, diff --git a/src/dartsort/util/data_util.py b/src/dartsort/util/data_util.py index 66367999..c7b29b9f 100644 --- a/src/dartsort/util/data_util.py +++ b/src/dartsort/util/data_util.py @@ -73,6 +73,7 @@ def __str__(self): name = self.__class__.__name__ nspikes = self.times_samples.size nunits = (np.unique(self.labels) >= 0).sum() + print(f"{nunits=}") unit_str = f"{nunits} unit" + "s" * (nunits > 1) feat_str = "" if self.extra_features: diff --git a/src/dartsort/util/drift_util.py b/src/dartsort/util/drift_util.py index d527432a..8c698490 100644 --- a/src/dartsort/util/drift_util.py +++ b/src/dartsort/util/drift_util.py @@ -15,7 +15,8 @@ from scipy.spatial import KDTree from scipy.spatial.distance import pdist -from .waveform_util import fast_nanmedian, get_pitch +from .spiketorch import fast_nanmedian +from .waveform_util import get_pitch # -- registered geometry and templates helpers diff --git a/src/dartsort/util/spiketorch.py b/src/dartsort/util/spiketorch.py index eb97b64d..439413fe 100644 --- a/src/dartsort/util/spiketorch.py +++ b/src/dartsort/util/spiketorch.py @@ -3,7 +3,19 @@ from torch.fft import irfft, rfft +def fast_nanmedian(x, axis=-1): + is_tensor = torch.is_tensor(x) + x = torch.nanmedian(torch.as_tensor(x), dim=axis).values + if is_tensor: + return x + else: + return x.numpy() + + def ptp(waveforms, dim=1): + is_tensor = torch.is_tensor(waveforms) + if not is_tensor: + return waveforms.ptp(axis=dim) return waveforms.max(dim=dim).values - waveforms.min(dim=dim).values @@ -262,4 +274,4 @@ def depthwise_oaconv1d(input, weight, f2=None, padding=0): # this is the full convolution oa = oa[:, :shape_final - pad1] - return oa \ No newline at end of file + return oa diff --git a/src/dartsort/util/waveform_util.py b/src/dartsort/util/waveform_util.py index 4acee974..a1394734 100644 --- a/src/dartsort/util/waveform_util.py +++ b/src/dartsort/util/waveform_util.py @@ -335,14 +335,3 @@ def get_channel_subset( npx.arange(N)[:, None], rel_sub_channel_index[max_channels][:, :], ] - - -# -- general util - -def fast_nanmedian(x, axis=-1): - is_tensor = torch.is_tensor(x) - x = torch.nanmedian(torch.as_tensor(x), dim=axis).values - if is_tensor: - return x - else: - return x.numpy() diff --git a/tests/test_templates.py b/tests/test_templates.py index f6faad40..63cbd357 100644 --- 
a/tests/test_templates.py +++ b/tests/test_templates.py @@ -130,17 +130,25 @@ def test_main_object(): labels=[0, 0, 1, 1], channels=[0, 0, 0, 0], sampling_frequency=1, - extra_features=dict(point_source_localizations=np.zeros((4, 4)), times_seconds=[0, 2, 6, 8]), + extra_features=dict( + point_source_localizations=np.zeros((4, 4)), times_seconds=[0, 2, 6, 8] + ), ) tdata = templates.TemplateData.from_config( rec, sorting, - config.TemplateConfig(trough_offset_samples=0, spike_length_samples=2, realign_peaks=False), + config.TemplateConfig( + trough_offset_samples=0, + spike_length_samples=2, + realign_peaks=False, + superres_templates=False, + denoising_rank=2, + ), motion_est=me, ) if __name__ == "__main__": - test_static_templates() - test_drifting_templates() + # test_static_templates() + # test_drifting_templates() test_main_object() From eb2a16984959fa81581fe7db35dc15b47a594ce6 Mon Sep 17 00:00:00 2001 From: Charlie Windolf Date: Wed, 1 Nov 2023 15:54:13 -0400 Subject: [PATCH 11/49] Debug --- src/dartsort/detect/detect.py | 7 +- src/dartsort/templates/pairwise.py | 104 ++++++-------- src/dartsort/templates/template_util.py | 16 ++- src/dartsort/templates/templates.py | 1 + src/dartsort/util/data_util.py | 1 - tests/test_templates.py | 182 +++++++++++++++++++++++- 6 files changed, 239 insertions(+), 72 deletions(-) diff --git a/src/dartsort/detect/detect.py b/src/dartsort/detect/detect.py index cd705191..7700e2f0 100644 --- a/src/dartsort/detect/detect.py +++ b/src/dartsort/detect/detect.py @@ -41,9 +41,8 @@ def detect_and_deduplicate( with corresponding channels """ nsamples, nchans = traces.shape - if isinstance(dedup_channel_index, str) and dedup_channel_index == "all": - pass - elif dedup_channel_index is not None: + all_dedup = isinstance(dedup_channel_index, str) and dedup_channel_index == "all" + if not all_dedup and dedup_channel_index is not None: assert dedup_channel_index.shape[0] == nchans # -- handle peak sign. 
we use max pool below, so make peaks positive @@ -99,7 +98,7 @@ def detect_and_deduplicate( # -- spatial deduplication # we would like to max pool again on the other axis, # but that doesn't support any old radial neighborhood - if dedup_channel_index == "all": + if all_dedup: max_energies[:] = max_energies.max(dim=1, keepdim=True).values elif dedup_channel_index is not None: # pad channel axis with extra chan of 0s diff --git a/src/dartsort/templates/pairwise.py b/src/dartsort/templates/pairwise.py index e1952d98..bb3e0574 100644 --- a/src/dartsort/templates/pairwise.py +++ b/src/dartsort/templates/pairwise.py @@ -81,13 +81,9 @@ def sparse_pairwise_conv( # check if the convolutions need to be drift-aware # they do if we need to do any channel selection - print(f"{temp_shift_index.all_pitch_shifts=}") is_drifting = not np.array_equal(temp_shift_index.all_pitch_shifts, [0]) - print(f"A {is_drifting=}") if template_data.registered_geom is not None: - print(f"{np.array_equal(geom, template_data.registered_geom)=}") is_drifting |= not np.array_equal(geom, template_data.registered_geom) - print(f"B {is_drifting=}") # initialize pairwise conv data structures # index_table[shifted_temp_ix(i), shifted_temp_ix(j)] = pconvix(i,j) @@ -141,9 +137,9 @@ def sparse_pairwise_conv( pconv_index_table[ res.shifted_temp_ix_a, res.shifted_temp_ix_b ] = new_conv_ix - pconv.resize(cur_pconv_ix + new_conv_ix.size, axis=0) - pconv[new_conv_ix] = res.cconv_up - cur_pconv_ix += new_conv_ix.size + pconv.resize(cur_pconv_ix + res.cconv_up.shape[0], axis=0) + pconv[cur_pconv_ix:] = res.cconv_up + cur_pconv_ix += res.cconv_up.shape[0] # smaller datasets all at once h5.create_dataset( @@ -208,16 +204,25 @@ def query( ------- template_indices_a, template_indices_b, pair_convs """ - # get shifted template indices + template_indices_a = np.atleast_1d(template_indices_a) + template_indices_b = np.atleast_1d(template_indices_b) + shifted = shifts_a is not None + if shifted: + assert shifts_b is not None + shifts_a = np.atleast_1d(shifts_a) + shifts_b = np.atleast_1d(shifts_b) + else: + assert np.array_equal(self.shifts, [0.0]) + + # handle upsampling pconv = self.pconv - if upsampling_indices_b is None: + upsampled = upsampling_indices_b is not None + if not upsampled: assert self.pconv.shape[1] == 1 pconv = pconv[:, 0, :] - if shifts_a is None or shifts_b is None: - assert np.array_equal(self.shifts, [0.0]) - shifted_temp_ix_a = template_indices_a - shifted_temp_ix_b = template_indices_b - else: + + # get shifted template indices + if shifted: shift_ix_a = np.searchsorted(self.shifts, shifts_a) assert np.array_equal(self.shifts[shift_ix_a], shifts_a) shift_ix_b = np.searchsorted(self.shifts, shifts_b) @@ -228,8 +233,14 @@ def query( shifted_temp_ix_b = self.template_shift_index[ template_indices_b, shift_ix_b ] + else: + shifted_temp_ix_a = template_indices_a + shifted_temp_ix_b = template_indices_b - pconv_indices = self.pconv_index_table[shifted_temp_ix_a, shifted_temp_ix_b] + # we only store the upper triangle of this symmetric object + min_ = np.minimum(shifted_temp_ix_a, shifted_temp_ix_b) + max_ = np.maximum(shifted_temp_ix_a, shifted_temp_ix_b) + pconv_indices = self.pconv_index_table[min_, max_] # most users will be happy not to get a bunch of zeros for pairs that don't overlap if not return_zero_convs: @@ -240,10 +251,10 @@ def query( if upsampling_indices_b is not None: upsampling_indices_b = upsampling_indices_b[which] - if upsampling_indices_b is None: - pair_convs = pconv[pconv_indices] - else: + if upsampled: 
pair_convs = pconv[pconv_indices, upsampling_indices_b] + else: + pair_convs = pconv[pconv_indices] return template_indices_a, template_indices_b, pair_convs @@ -275,7 +286,7 @@ def compute_pairwise_convs( jobs = [] for start_a in range(0, units.size, units_per_chunk): end_a = min(start_a + units_per_chunk, units.size) - for start_b in range(start_a + 1, units.size, units_per_chunk): + for start_b in range(start_a, units.size, units_per_chunk): end_b = min(start_b + units_per_chunk, units.size) jobs.append((units[start_a:end_a], units[start_b:end_b])) if show_progress: @@ -291,7 +302,6 @@ def compute_pairwise_convs( coarse_templates = template_util.weighted_average( template_data.unit_ids, template_data.templates, template_data.spike_counts ) - print(f"{coarse_templates.shape=}") ( coarse_temporal, coarse_singular, @@ -456,7 +466,6 @@ def _pairwise_conv_job(unit_chunk): # get shifted spatial components spatial_a = p.spatial_singular[temp_ix_a] spatial_b = p.spatial_singular[temp_ix_b] - # print(f"{p.is_drifting=} old {spatial_a.shape=}") if p.is_drifting: spatial_a = drift_util.get_waveforms_on_static_channels( spatial_a, @@ -467,7 +476,6 @@ def _pairwise_conv_job(unit_chunk): match_distance=p.match_distance, fill_value=0.0, ) - # print(f"new {spatial_a.shape=} {p.target_kdtree=}") spatial_b = drift_util.get_waveforms_on_static_channels( spatial_b, p.registered_geom, @@ -503,10 +511,10 @@ def _pairwise_conv_job(unit_chunk): if not nco: return None cconv_ix = np.arange(nco) - + # shifts may not matter if p.is_drifting: - cconv, cconv_ix_subset = _shift_normalize( + cconv, cconv_ix = _shift_normalize( cconv, cconv_ix, temp_ix_a[conv_ix_a.cpu()], @@ -514,18 +522,12 @@ def _pairwise_conv_job(unit_chunk): temp_ix_b[conv_ix_b.cpu()], shift_b[conv_ix_b.cpu()], ) - conv_ix_a = conv_ix_a[cconv_ix_subset] - conv_ix_b = conv_ix_b[cconv_ix_subset] - cconv_ix = np.arange(len(cconv_ix_subset)) # summarize units by coarse pconv when possible if p.coarse_approx_error_threshold > 0: - cconv, cconv_ix_subset = _coarse_approx( + cconv, cconv_ix = _coarse_approx( cconv, cconv_ix, conv_ix_a, conv_ix_b, unit_a, unit_b, p ) - conv_ix_a = conv_ix_a[cconv_ix_subset] - conv_ix_b = conv_ix_b[cconv_ix_subset] - cconv_ix = np.arange(len(cconv_ix_subset)) # for use in deconv residual distance merge # TODO: actually probably need to do the real objective here with @@ -592,13 +594,12 @@ def get_shift_and_unit_pairs( motion_est=None, ): n_templates = len(template_data.templates) - print(f"get_shift_and_unit_pairs {motion_est=}") if motion_est is None: # no motion case return static_template_shift_index(n_templates) # all observed pitch shift values - all_pitch_shifts = np.empty(shape=(), dtype=int) + all_pitch_shifts = np.empty(shape=(0,), dtype=int) temp_ixs = np.arange(n_templates) # set of (template idx, shift) template_shift_pairs = np.empty(shape=(0, 2), dtype=int) @@ -609,9 +610,6 @@ def get_shift_and_unit_pairs( unregistered_depths_um = drift_util.invert_motion_estimate( motion_est, t_s, template_data.registered_template_depths_um ) - diff = np.abs( - unregistered_depths_um - template_data.registered_template_depths_um - ) pitch_shifts = drift_util.get_spike_pitch_shifts( depths_um=template_data.registered_template_depths_um, pitch=pitch, @@ -627,7 +625,6 @@ def get_shift_and_unit_pairs( template_shift_pairs = np.unique( np.concatenate((template_shift_pairs, template_shift), axis=0), axis=0 ) - print(f"get_shift_and_unit_pairs {all_pitch_shifts=}") n_shifts = len(all_pitch_shifts) n_template_shift_pairs = 
len(template_shift_pairs) @@ -637,7 +634,6 @@ def get_shift_and_unit_pairs( template_shift_index = np.full((n_templates, n_shifts), n_template_shift_pairs) shift_ix = np.searchsorted(all_pitch_shifts, template_shift_pairs[:, 1]) assert np.array_equal(all_pitch_shifts[shift_ix], template_shift_pairs[:, 1]) - print(f"{template_shift_pairs[:, 0]=}, {shift_ix=} {np.unique(shift_ix)=}") template_shift_index[template_shift_pairs[:, 0], shift_ix] = np.arange( n_template_shift_pairs ) @@ -645,9 +641,7 @@ def get_shift_and_unit_pairs( shifted_temp_ix_to_shift = template_shift_pairs[:, 1] # co-occurrence matrix: do these shifted templates appear together? - cooccurrence = np.zeros( - (n_template_shift_pairs, n_template_shift_pairs), dtype=bool - ) + cooccurrence = np.eye(n_template_shift_pairs, dtype=bool) for t_s in chunk_time_centers_s: unregistered_depths_um = drift_util.invert_motion_estimate( motion_est, t_s, template_data.registered_template_depths_um @@ -727,14 +721,12 @@ def ccorrelate_up( @ torch.sqrt(torch.square(spatial_b).sum(1)).T ) covisible = covisible > conv_ignore_threshold - # print(f"{covisible.shape=}") if covisible_mask is not None: covisible *= covisible_mask covisible_a, covisible_b = torch.nonzero(covisible, as_tuple=True) nco = covisible_a.numel() if not nco: return None, None, None - # print(f"{(nco/covisible.numel())=}") # batch over nco for memory reasons cconv = torch.zeros( @@ -768,7 +760,6 @@ def ccorrelate_up( cconv = cconv[vis] covisible_a = covisible_a[vis] covisible_b = covisible_b[vis] - # print(f"{(covisible_b.numel()/covisible.numel())=}") return covisible_a, covisible_b, cconv @@ -826,27 +817,26 @@ def _coarse_approx(cconv, cconv_ix, conv_ix_a, conv_ix_b, unit_a, unit_b, p): cconv_ix[in_pair] = cconv_ix[in_pair[0]] # re-index and subset cconvs - cconv_ix_subset = np.unique(cconv_ix) - conv_ix_a = conv_ix_a[cconv_ix_subset] - conv_ix_b = conv_ix_b[cconv_ix_subset] + cconv_ix_subset, new_cconv_ix = np.unique(cconv_ix, return_inverse=True) cconv = cconv[cconv_ix_subset] - return cconv, cconv_ix_subset + return cconv, new_cconv_ix + -def _shift_normalize(cconv, cconv_ix, temp_ix_a, shift_a, temp_ix_b, shift_b, atol=1e-1): +def _shift_normalize( + cconv, cconv_ix, temp_ix_a, shift_a, temp_ix_b, shift_b, atol=1e-1 +): pairs_done = set() for ua, ub in zip(temp_ix_a, temp_ix_b): if (ua, ub) in pairs_done: continue pairs_done.add((ua, ub)) - in_pair = np.flatnonzero( - (temp_ix_a == ua) & (temp_ix_b == ub) - ) + in_pair = np.flatnonzero((temp_ix_a == ua) & (temp_ix_b == ub)) diffs = shift_a[in_pair] - shift_b[in_pair] changed = False for diff in np.unique(diffs): in_diff = in_pair[diffs == diff] - + cconvs = cconv[cconv_ix[in_diff]] meanconv = cconvs.mean(0, keepdims=True) err = (cconvs - meanconv).abs().max() @@ -863,9 +853,7 @@ def _shift_normalize(cconv, cconv_ix, temp_ix_a, shift_a, temp_ix_b, shift_b, at continue pairs_done.add((ua, ub)) - in_pair = np.flatnonzero( - (temp_ix_a == ua) & (temp_ix_b == ub) - ) + in_pair = np.flatnonzero((temp_ix_a == ua) & (temp_ix_b == ub)) cconvs = cconv[cconv_ix[in_pair]] meanconv = cconvs.mean(0, keepdims=True) err = (cconvs - meanconv).abs().max() @@ -876,6 +864,6 @@ def _shift_normalize(cconv, cconv_ix, temp_ix_a, shift_a, temp_ix_b, shift_b, at cconv_ix[in_pair] = cconv_ix[in_pair[0]] # re-index and subset cconvs - cconv_ix_subset = np.unique(cconv_ix) + cconv_ix_subset, new_cconv_ix = np.unique(cconv_ix, return_inverse=True) cconv = cconv[cconv_ix_subset] - return cconv, cconv_ix_subset \ No newline at end of file + return 
cconv, new_cconv_ix diff --git a/src/dartsort/templates/template_util.py b/src/dartsort/templates/template_util.py index dbdc4d9e..780fc7d7 100644 --- a/src/dartsort/templates/template_util.py +++ b/src/dartsort/templates/template_util.py @@ -187,7 +187,9 @@ def templates_at_time( # -- template numerical processing -def svd_compress_templates(templates, min_channel_amplitude=1.0, rank=5, channel_sparse=True): +def svd_compress_templates( + templates, min_channel_amplitude=1.0, rank=5, channel_sparse=True +): """ Returns: temporal_components: n_units, spike_length_samples, rank @@ -197,7 +199,7 @@ def svd_compress_templates(templates, min_channel_amplitude=1.0, rank=5, channel vis_mask = templates.ptp(axis=1, keepdims=True) > min_channel_amplitude vis_templates = templates * vis_mask dtype = templates.dtype - + if not channel_sparse: U, s, Vh = np.linalg.svd(vis_templates, full_matrices=False) # s is descending. @@ -205,7 +207,7 @@ def svd_compress_templates(templates, min_channel_amplitude=1.0, rank=5, channel singular_values = s[:, :rank].astype(dtype) spatial_components = Vh[:, :rank, :].astype(dtype) return temporal_components, singular_values, spatial_components - + # channel sparse: only SVD the nonzero channels # this encodes the same exact subspace as above, and the reconstruction # error is the same as above as a function of rank. it's just that @@ -234,10 +236,12 @@ def temporally_upsample_templates( """Note, also works on temporal components thanks to compatible shape.""" n, t, c = templates.shape tp = np.arange(t).astype(float) - erp = interp1d(tp, templates, axis=1, bounds_error=True) - tup = np.arange(t, step=1. / temporal_upsampling_factor) + erp = interp1d(tp, templates, axis=1, bounds_error=True, kind=kind) + tup = np.arange(t, step=1.0 / temporal_upsampling_factor) tup.clip(0, t - 1, out=tup) upsampled_templates = erp(tup) - upsampled_templates = upsampled_templates.reshape(n, t, temporal_upsampling_factor, c) + upsampled_templates = upsampled_templates.reshape( + n, t, temporal_upsampling_factor, c + ) upsampled_templates = upsampled_templates.astype(templates.dtype) return upsampled_templates diff --git a/src/dartsort/templates/templates.py b/src/dartsort/templates/templates.py index 955c9cee..28e7710b 100644 --- a/src/dartsort/templates/templates.py +++ b/src/dartsort/templates/templates.py @@ -109,6 +109,7 @@ def from_config( trough_offset_samples=template_config.trough_offset_samples, spike_length_samples=template_config.spike_length_samples, spikes_per_unit=template_config.spikes_per_unit, + # realign handled in advance below, not needed in kwargs # realign_peaks=template_config.realign_peaks, realign_max_sample_shift=template_config.realign_max_sample_shift, denoising_rank=template_config.denoising_rank, diff --git a/src/dartsort/util/data_util.py b/src/dartsort/util/data_util.py index c7b29b9f..66367999 100644 --- a/src/dartsort/util/data_util.py +++ b/src/dartsort/util/data_util.py @@ -73,7 +73,6 @@ def __str__(self): name = self.__class__.__name__ nspikes = self.times_samples.size nunits = (np.unique(self.labels) >= 0).sum() - print(f"{nunits=}") unit_str = f"{nunits} unit" + "s" * (nunits > 1) feat_str = "" if self.extra_features: diff --git a/tests/test_templates.py b/tests/test_templates.py index 63cbd357..71fc44e9 100644 --- a/tests/test_templates.py +++ b/tests/test_templates.py @@ -1,7 +1,12 @@ +import tempfile +from pathlib import Path + import numpy as np import spikeinterface.core as sc from dartsort import config -from dartsort.templates import 
get_templates, template_util, templates +from dartsort.templates import (get_templates, pairwise, template_util, + templates) +from dartsort.util import drift_util from dartsort.util.data_util import DARTsortSorting from dredge.motion_util import get_motion_estimate @@ -148,7 +153,178 @@ def test_main_object(): ) +def test_pconv(): + # want to make sure drift handling is as expected + # design an experiment + + # 4 chans, no drift + # 3 units (superres): 0 (0,1), 1 (2,3), 3 (4) + # temps overlap like: + # 0 chan=0 z=0 + # 12 1 1 + # 23 2 2 + # 4 3 3 + t = 2 + c = 4 + temps = np.zeros((5, t, c), dtype=np.float32) + temps[0, 0, 0] = 2 + temps[1, 0, 1] = 2 + temps[2, 0, [1, 2]] = 2 + temps[3, 0, 2] = 2 + temps[4, 0, 3] = 2 + geom = np.c_[np.zeros(c), np.arange(c).astype(float)] + overlaps = {(i, i): np.square(temps[i]).sum() for i in range(5)} + overlaps[(1, 2)] = overlaps[(2, 1)] = (temps[1] * temps[2]).sum() + overlaps[(2, 3)] = overlaps[(3, 2)] = (temps[3] * temps[2]).sum() + + tdata = templates.TemplateData( + templates=temps, + unit_ids=np.array([0, 0, 1, 1, 2]), + spike_counts=np.ones(5), + registered_geom=None, + registered_template_depths_um=None, + ) + temp, sv, spat = template_util.svd_compress_templates(temps, rank=1) + print(f"{temp=} {sv=} {spat=}") + tempup = temp.reshape(5, t, 1, 1) + + with tempfile.TemporaryDirectory() as tdir: + pconvdb_path = pairwise.sparse_pairwise_conv( + Path(tdir) / "test.h5", + geom, + tdata, + temp, + tempup, + sv, + spat, + ) + pconvdb = pairwise.SparsePairwiseConv.from_h5(pconvdb_path) + assert np.all(pconvdb.pconv[0] == 0) + + for tixa in range(5): + for tixb in range(5): + ixa, ixb, pconv = pconvdb.query(tixa, tixb) + if (tixa, tixb) not in overlaps: + assert not ixa.size + assert not ixb.size + assert not pconv.size + continue + + olap = overlaps[tixa, tixb] + assert (ixa, ixb) == (tixa, tixb) + assert np.isclose(pconv.max(), olap) + + # drifting version + # rigid drift from -1 to 0 to 1, note pitch=1 + # same templates but padded + tempspad = np.pad(temps, [(0, 0), (0, 0), (1, 1)]) + print(f"{tempspad.shape=}") + temp, sv, spat = template_util.svd_compress_templates(tempspad, rank=1) + print(f"{temp.shape=} {sv.shape=} {spat.shape=}") + reg_geom = np.c_[np.zeros(c + 2), np.arange(c + 2).astype(float)] + tdata = templates.TemplateData( + templates=tempspad, + unit_ids=np.array([0, 0, 1, 1, 2]), + spike_counts=np.ones(5), + registered_geom=reg_geom, + registered_template_depths_um=np.zeros(5), + ) + geom = np.c_[np.zeros(c), np.arange(1, c + 1).astype(float)] + motion_est = get_motion_estimate(time_bin_centers_s=np.array([0., 1, 2]), displacement=[-1., 0, 1]) + + # visualize shifted temps + for tix in range(5): + print("------------------") + print(f"{tix=}") + for shift in (-1, 0, 1): + spatial_shifted = drift_util.get_waveforms_on_static_channels( + spat[tix][None], + reg_geom, + n_pitches_shift=np.array([shift]), + registered_geom=geom, + fill_value=0.0, + ) + print(f"{shift=}") + print(f"{spatial_shifted=}") + + print() + print() + print('-=' * 30) + print('=-' * 30) + print('-=' * 30) + print('=-' * 30) + print() + print() + + with tempfile.TemporaryDirectory() as tdir: + pconvdb_path = pairwise.sparse_pairwise_conv( + Path(tdir) / "test.h5", + geom, + tdata, + temp, + tempup, + sv, + spat, + motion_est=motion_est, + chunk_time_centers_s=[0, 1, 2], + ) + pconvdb = pairwise.SparsePairwiseConv.from_h5(pconvdb_path) + assert np.all(pconvdb.pconv[0] == 0) + + print(f"{pconvdb.template_shift_index=}") + + for tixa in range(5): + for tixb in 
range(5): + ixa, ixb, pconv = pconvdb.query(tixa, tixb, shifts_a=0, shifts_b=0) + + if (tixa, tixb) not in overlaps: + assert not ixa.size + assert not ixb.size + assert not pconv.size + continue + + olap = overlaps[tixa, tixb] + assert (ixa, ixb) == (tixa, tixb) + assert np.isclose(pconv.max(), olap) + + for tixb in range(5): + for shiftb in (-1, 0, 1): + ixa, ixb, pconv = pconvdb.query(0, tixb, shifts_a=-1, shifts_b=shiftb) + assert not ixa.size + assert not ixb.size + assert not pconv.size + + for tixb in range(5): + for shift in (-1, 0, 1): + ixa, ixb, pconv = pconvdb.query(4, tixb, shifts_a=shift, shifts_b=shift) + if tixb != 4 or shift == 1: + assert not ixa.size + assert not ixb.size + assert not pconv.size + else: + assert np.isclose(pconv.max(), 4 if shift < 1 else 0) + ixa, ixb, pconv = pconvdb.query(tixb, 4, shifts_a=shift, shifts_b=shift) + if tixb != 4 or shift == 1: + assert not ixa.size + assert not ixb.size + assert not pconv.size + else: + assert np.isclose(pconv.max(), 4) + + for shifta in (-1, 0, 1): + for shiftb in (-1, 0, 1): + for tixa in range(5): + for tixb in range(5): + ixa, ixb, pconv = pconvdb.query(tixa, tixb, shifts_a=shifta, shifts_b=shiftb) + if shifta != shiftb: + # this is because we are rigid here + assert not ixa.size + assert not ixb.size + assert not pconv.size + + if __name__ == "__main__": - # test_static_templates() - # test_drifting_templates() + test_static_templates() + test_drifting_templates() test_main_object() + test_pconv() From d0f24600b81287d45ec6e2f93c01691a24256af2 Mon Sep 17 00:00:00 2001 From: Charlie Windolf Date: Fri, 3 Nov 2023 12:22:17 -0400 Subject: [PATCH 12/49] Checking in before rewriting pairwise a bit --- src/dartsort/peel/matching.py | 4 - src/dartsort/templates/pairwise.py | 806 +-------------------- src/dartsort/templates/pairwise_conv.py | 148 ---- src/dartsort/templates/pairwise_util.py | 902 ++++++++++++++++++++++++ src/dartsort/templates/template_util.py | 61 ++ 5 files changed, 979 insertions(+), 942 deletions(-) delete mode 100644 src/dartsort/templates/pairwise_conv.py create mode 100644 src/dartsort/templates/pairwise_util.py diff --git a/src/dartsort/peel/matching.py b/src/dartsort/peel/matching.py index c04c694e..bc22eacb 100644 --- a/src/dartsort/peel/matching.py +++ b/src/dartsort/peel/matching.py @@ -125,10 +125,6 @@ def __init__( ("temporal_components", temporal_components), ("singular_values", singular_values), ("spatial_components", spatial_components), - ( - "upsampled_temporal_components", - self.upsampled_temporal_components.numpy(force=True).copy(), - ), ] if self.is_drifting: self.fixed_output_data.append( diff --git a/src/dartsort/templates/pairwise.py b/src/dartsort/templates/pairwise.py index bb3e0574..330b824a 100644 --- a/src/dartsort/templates/pairwise.py +++ b/src/dartsort/templates/pairwise.py @@ -1,166 +1,3 @@ -from dataclasses import dataclass, fields -from typing import Optional - -import h5py -import numpy as np -import torch -import torch.nn.functional as F -from dartsort.templates import template_util -from dartsort.util import drift_util -from dartsort.util.multiprocessing_util import get_pool -from scipy.spatial import KDTree -from scipy.spatial.distance import pdist -from tqdm.auto import tqdm - -# todo: extend this code to also handle computing pairwise -# stuff necessary for the merge! ie shifts, scaling, -# residnorm(a,b) (or min of rn(a,b),rn(b,a)???) 
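
For orientation before the rewrite of this file: the sparse pconv store is navigated in three hops, pitch shift -> shifted template index -> pconv row, which is what SparsePairwiseConv.query above does with searchsorted and the index tables. A minimal sketch of that chain under toy shapes (the array names mirror the HDF5 datasets and dataclass fields in this file; query_one is a hypothetical helper, not part of the code):

import numpy as np

# toy store: 3 templates, pitch shifts {-1, 0}, 8 upsampling offsets, spike length 121
shifts = np.array([-1, 0])                       # sorted pitch shift values
template_shift_index = np.array([[0, 1],         # (template ix, shift ix) -> shifted template ix
                                 [2, 3],
                                 [4, 5]])
n_shifted = 6
pconv_index_table = np.zeros((n_shifted, n_shifted), dtype=int)  # 0 = pair never overlaps
pconv = np.zeros((1, 8, 2 * 121 - 1), dtype=np.float32)          # row 0 is the all-zero pconv

def query_one(temp_a, shift_a, temp_b, shift_b, up_b=0):
    # same hops as query(): shift -> shift ix -> shifted temp ix -> pconv row
    sta = template_shift_index[temp_a, np.searchsorted(shifts, shift_a)]
    stb = template_shift_index[temp_b, np.searchsorted(shifts, shift_b)]
    return pconv[pconv_index_table[sta, stb], up_b]

assert query_one(0, -1, 2, 0).shape == (241,)    # non-overlapping pair -> the zero row
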
- - -def sparse_pairwise_conv( - output_hdf5_filename, - geom, - template_data, - template_temporal_components, - template_upsampled_temporal_components, - template_singular_values, - template_spatial_components, - chunk_time_centers_s=None, - motion_est=None, - conv_ignore_threshold: float = 0.0, - coarse_approx_error_threshold: float = 0.0, - min_channel_amplitude: float = 1.0, - units_per_chunk=8, - overwrite=False, - show_progress=True, - device=None, - n_jobs=0, -): - """ - - Arguments - --------- - template_* : tensors or arrays - template SVD approximations - conv_ignore_threshold: float = 0.0 - pairs will be ignored (i.e., pconv set to 0) if their pconv - does not exceed this value - coarse_approx_error_threshold: float = 0.0 - superres will not be used if coarse pconv and superres pconv - are uniformly closer than this threshold value - - Returns - ------- - pitch_shifts : array - array of all the pitch shifts - use searchsorted to find the pitch shift ix for a pitch shift - index_table: torch sparse tensor - index_table[(pitch shift ix a, superres label a, pitch shift ix b, superres label b)] = ( - 0 - if superres pconv a,b at these shifts was below the conv_ignore_threshold - else pconv_index) - pconvs: np.ndarray - pconv[pconv_index] is a cross-correlation of two templates, summed over chans - """ - if overwrite: - pass - - ( - n_templates, - spike_length_samples, - upsampling_factor, - ) = template_upsampled_temporal_components.shape[:3] - - # find all of the co-occurring pitch shift and template pairs - temp_shift_index = get_shift_and_unit_pairs( - chunk_time_centers_s, - geom, - template_data, - motion_est=motion_est, - ) - - # check if the convolutions need to be drift-aware - # they do if we need to do any channel selection - is_drifting = not np.array_equal(temp_shift_index.all_pitch_shifts, [0]) - if template_data.registered_geom is not None: - is_drifting |= not np.array_equal(geom, template_data.registered_geom) - - # initialize pairwise conv data structures - # index_table[shifted_temp_ix(i), shifted_temp_ix(j)] = pconvix(i,j) - pconv_index_table = np.zeros( - (temp_shift_index.n_shifted_templates, temp_shift_index.n_shifted_templates), - dtype=int, - ) - # pconvs[pconvix(i,j)] = (2*spikelen-1, upsampling_factor) arr of pconv(shifted_temp(i), shifted_temp(j)) - - cur_pconv_ix = 1 - with h5py.File(output_hdf5_filename, "w") as h5: - # resizeable pconv dataset - pconv = h5.create_dataset( - "pconv", - dtype=np.float32, - shape=(1, upsampling_factor, 2 * spike_length_samples - 1), - maxshape=(None, upsampling_factor, 2 * spike_length_samples - 1), - chunks=(128, upsampling_factor, 2 * spike_length_samples - 1), - ) - - # pconv[0] is special -- it is 0. 
- pconv[0] = 0.0 - - # res is a ConvBatchResult - for res in compute_pairwise_convs( - template_data, - template_spatial_components, - template_singular_values, - template_temporal_components, - template_upsampled_temporal_components, - temp_shift_index.shifted_temp_ix_to_temp_ix, - temp_shift_index.shifted_temp_ix_to_shift, - geom, - cooccurrence=temp_shift_index.cooccurrence, - conv_ignore_threshold=conv_ignore_threshold, - coarse_approx_error_threshold=coarse_approx_error_threshold, - min_channel_amplitude=min_channel_amplitude, - is_drifting=is_drifting, - units_per_chunk=units_per_chunk, - n_jobs=n_jobs, - device=device, - show_progress=show_progress, - max_shift="full", - store_conv=True, - compute_max=False, - ): - if res is None: - continue - new_conv_ix = res.cconv_ix - new_conv_ix += cur_pconv_ix - pconv_index_table[ - res.shifted_temp_ix_a, res.shifted_temp_ix_b - ] = new_conv_ix - pconv.resize(cur_pconv_ix + res.cconv_up.shape[0], axis=0) - pconv[cur_pconv_ix:] = res.cconv_up - cur_pconv_ix += res.cconv_up.shape[0] - - # smaller datasets all at once - h5.create_dataset( - "template_shift_index", data=temp_shift_index.template_shift_index - ) - h5.create_dataset("pconv_index_table", data=pconv_index_table) - h5.create_dataset("shifts", data=temp_shift_index.all_pitch_shifts) - h5.create_dataset( - "shifted_temp_ix_to_temp_ix", - data=temp_shift_index.shifted_temp_ix_to_temp_ix, - ) - h5.create_dataset( - "shifted_temp_ix_to_shift", data=temp_shift_index.shifted_temp_ix_to_shift - ) - h5.create_dataset( - "shifted_temp_ix_to_unit", - data=template_data.unit_ids[temp_shift_index.shifted_temp_ix_to_temp_ix], - ) - - return output_hdf5_filename # SparsePairwiseConv.from_h5(output_hdf5_filename) - @dataclass class SparsePairwiseConv: @@ -168,9 +5,11 @@ class SparsePairwiseConv: shifts: np.ndarray # (temp_ix, shift_ix) -> shifted_temp_ix template_shift_index: torch.LongTensor - # (shifted_temp_ix a, shifted_temp_ix b) -> pconv index - pconv_index_table: torch.LongTensor - # pconv index -> pconv (upsampling, 2 * spike len - 1) + # (shifted_temp_ix a, shifted_temp_ix b) -> pair index + pair_index_table: torch.LongTensor + # (pair index, upsampling index) -> pconv index + upsampling_index_table: torch.LongTensor + # pconv index -> pconv (2 * spike len - 1,) # the zero index lands you at an all 0 pconv pconv: torch.Tensor @@ -214,13 +53,6 @@ def query( else: assert np.array_equal(self.shifts, [0.0]) - # handle upsampling - pconv = self.pconv - upsampled = upsampling_indices_b is not None - if not upsampled: - assert self.pconv.shape[1] == 1 - pconv = pconv[:, 0, :] - # get shifted template indices if shifted: shift_ix_a = np.searchsorted(self.shifts, shifts_a) @@ -240,7 +72,16 @@ def query( # we only store the upper triangle of this symmetric object min_ = np.minimum(shifted_temp_ix_a, shifted_temp_ix_b) max_ = np.maximum(shifted_temp_ix_a, shifted_temp_ix_b) - pconv_indices = self.pconv_index_table[min_, max_] + pair_indices = self.pair_index_table[min_, max_] + + # handle upsampling + if upsampling_indices_b is None: + assert self.upsampling_index_table.shape[1] == 1 + pconv_indices = self.upsampling_index_table[pair_indices, 0] + else: + pconv_indices = self.upsampling_index_table[ + pair_indices, upsampling_indices_b + ] # most users will be happy not to get a bunch of zeros for pairs that don't overlap if not return_zero_convs: @@ -248,622 +89,7 @@ def query( pconv_indices = pconv_indices[which] template_indices_a = template_indices_a[which] template_indices_b = 
template_indices_b[which] - if upsampling_indices_b is not None: - upsampling_indices_b = upsampling_indices_b[which] - if upsampled: - pair_convs = pconv[pconv_indices, upsampling_indices_b] - else: - pair_convs = pconv[pconv_indices] + pair_convs = self.pconv[pconv_indices] return template_indices_a, template_indices_b, pair_convs - - -def compute_pairwise_convs( - template_data, - spatial, - singular, - temporal, - temporal_up, - shifted_temp_ix_to_temp_ix, - shifted_temp_ix_to_shift, - geom, - cooccurrence, - conv_ignore_threshold=0.0, - coarse_approx_error_threshold=0.0, - min_channel_amplitude=1.0, - max_shift="full", - is_drifting=True, - store_conv=True, - compute_max=False, - units_per_chunk=8, - n_jobs=0, - device=None, - show_progress=True, -): - # chunk up coarse unit ids, go by pairs of chunks - units = np.unique(template_data.unit_ids) - jobs = [] - for start_a in range(0, units.size, units_per_chunk): - end_a = min(start_a + units_per_chunk, units.size) - for start_b in range(start_a, units.size, units_per_chunk): - end_b = min(start_b + units_per_chunk, units.size) - jobs.append((units[start_a:end_a], units[start_b:end_b])) - if show_progress: - jobs = tqdm( - jobs, smoothing=0.01, desc="Pairwise convolution", unit="pair block" - ) - - # compute the coarse templates if needed - if units.size == template_data.unit_ids.size: - # coarse templates are original templates - coarse_approx_error_threshold = 0 - if coarse_approx_error_threshold > 0: - coarse_templates = template_util.weighted_average( - template_data.unit_ids, template_data.templates, template_data.spike_counts - ) - ( - coarse_temporal, - coarse_singular, - coarse_spatial, - ) = template_util.svd_compress_templates( - coarse_templates, - rank=singular.shape[1], - min_channel_amplitude=min_channel_amplitude, - ) - - # template data to torch - spatial_singular = torch.as_tensor(spatial * singular[:, :, None]) - temporal = torch.as_tensor(temporal) - temporal_up = torch.as_tensor(temporal_up) - if coarse_approx_error_threshold > 0: - coarse_spatial_singular = torch.as_tensor( - coarse_spatial * coarse_singular[:, :, None] - ) - coarse_temporal = torch.as_tensor(coarse_temporal) - else: - coarse_spatial_singular = None - coarse_temporal = None - - n_jobs, Executor, context, rank_queue = get_pool(n_jobs, with_rank_queue=True) - - pconv_params = dict( - store_conv=store_conv, - compute_max=compute_max, - is_drifting=is_drifting, - max_shift=max_shift, - conv_ignore_threshold=conv_ignore_threshold, - coarse_approx_error_threshold=coarse_approx_error_threshold, - spatial_singular=spatial_singular, - temporal=temporal, - temporal_up=temporal_up, - coarse_spatial_singular=coarse_spatial_singular, - coarse_temporal=coarse_temporal, - unit_ids=template_data.unit_ids, - shifted_temp_ix_to_shift=shifted_temp_ix_to_shift, - shifted_temp_ix_to_temp_ix=shifted_temp_ix_to_temp_ix, - shifted_temp_ix_to_unit=template_data.unit_ids[shifted_temp_ix_to_temp_ix], - cooccurrence=cooccurrence, - geom=geom, - registered_geom=template_data.registered_geom, - ) - - with Executor( - n_jobs, - mp_context=context, - initializer=_pairwise_conv_init, - initargs=(device, rank_queue, pconv_params), - ) as pool: - yield from pool.map(_pairwise_conv_job, jobs) - - -# -- parallel job code - - -# helper class which stores parameters for _pairwise_conv_job -@dataclass -class PairwiseConvContext: - device: torch.device - - # parameters - store_conv: bool - compute_max: bool - is_drifting: bool - max_shift: int - conv_ignore_threshold: float - 
coarse_approx_error_threshold: float - - # superres registered templates - spatial_singular: torch.Tensor - temporal: torch.Tensor - temporal_up: torch.Tensor - coarse_spatial_singular: Optional[torch.Tensor] - coarse_temporal: Optional[torch.Tensor] - cooccurrence: torch.Tensor - - # template indexing helper arrays - unit_ids: np.ndarray - shifted_temp_ix_to_temp_ix: np.ndarray - shifted_temp_ix_to_shift: np.ndarray - shifted_temp_ix_to_unit: np.ndarray - - # only needed if is_drifting - geom: np.ndarray - registered_geom: np.ndarray - target_kdtree: Optional[KDTree] - match_distance: Optional[float] - - -_pairwise_conv_context = None - - -def _pairwise_conv_init( - device, - rank_queue, - kwargs, -): - global _pairwise_conv_context - - # figure out what device to work on - my_rank = rank_queue.get() - if device is None: - device = "cuda" if torch.cuda.is_available() else "cpu" - device = torch.device(device) - if device.type == "cuda" and device.index is None: - if torch.cuda.device_count() > 1: - device = torch.device("cuda", index=my_rank % torch.cuda.device_count()) - - # handle string max_shift - max_shift = kwargs.pop("max_shift", "full") - t = kwargs["temporal"].shape[1] - if max_shift == "full": - max_shift = t - 1 - elif max_shift == "valid": - max_shift = 0 - elif max_shift == "same": - max_shift = t // 2 - kwargs["max_shift"] = max_shift - - kwargs["target_kdtree"] = kwargs["match_distance"] = None - if kwargs["is_drifting"]: - kwargs["target_kdtree"] = KDTree(kwargs["geom"]) - kwargs["match_distance"] = pdist(kwargs["geom"]).min() / 2 - - _pairwise_conv_context = PairwiseConvContext(device=device, **kwargs) - - -@dataclass -class ConvBatchResult: - # arrays of length - shifted_temp_ix_a: np.ndarray - shifted_temp_ix_b: np.ndarray - # array of length such that the ith - # pair's array of upsampled convs is cconv_up[cconv_ix[i]] - cconv_ix: np.ndarray - cconv_up: Optional[np.ndarray] - max_conv: Optional[float] - best_shift: Optional[int] - - -def _pairwise_conv_job(unit_chunk): - global _pairwise_conv_context - p = _pairwise_conv_context - - units_a, units_b = unit_chunk - - # this job consists of pairs of coarse units - # lets get all shifted superres template indices corresponding to those pairs, - # and the template indices, pitch shifts, and coarse units while we're at it - shifted_temp_ix_a = np.flatnonzero(np.isin(p.shifted_temp_ix_to_unit, units_a)) - shifted_temp_ix_b = np.flatnonzero(np.isin(p.shifted_temp_ix_to_unit, units_b)) - temp_ix_a = p.shifted_temp_ix_to_temp_ix[shifted_temp_ix_a] - temp_ix_b = p.shifted_temp_ix_to_temp_ix[shifted_temp_ix_b] - shift_a = p.shifted_temp_ix_to_shift[shifted_temp_ix_a] - shift_b = p.shifted_temp_ix_to_shift[shifted_temp_ix_b] - unit_a = p.unit_ids[temp_ix_a] - unit_b = p.unit_ids[temp_ix_b] - - # get shifted spatial components - spatial_a = p.spatial_singular[temp_ix_a] - spatial_b = p.spatial_singular[temp_ix_b] - if p.is_drifting: - spatial_a = drift_util.get_waveforms_on_static_channels( - spatial_a, - p.registered_geom, - n_pitches_shift=shift_a, - registered_geom=p.geom, - target_kdtree=p.target_kdtree, - match_distance=p.match_distance, - fill_value=0.0, - ) - spatial_b = drift_util.get_waveforms_on_static_channels( - spatial_b, - p.registered_geom, - n_pitches_shift=shift_b, - registered_geom=p.geom, - target_kdtree=p.target_kdtree, - match_distance=p.match_distance, - fill_value=0.0, - ) - - # to device - spatial_a = spatial_a.to(p.device) - spatial_b = spatial_b.to(p.device) - temporal_a = 
p.temporal[temp_ix_a].to(p.device) - temporal_up_b = p.temporal_up[temp_ix_b].to(p.device) - - # convolve valid pairs - pair_mask = p.cooccurrence[shifted_temp_ix_a[:, None], shifted_temp_ix_b[None, :]] - pair_mask = pair_mask * (shifted_temp_ix_a[:, None] <= shifted_temp_ix_b[None, :]) - pair_mask = torch.as_tensor(pair_mask, device=p.device) - conv_ix_a, conv_ix_b, cconv = ccorrelate_up( - spatial_a, - temporal_a, - spatial_b, - temporal_up_b, - conv_ignore_threshold=p.conv_ignore_threshold, - max_shift=p.max_shift, - covisible_mask=pair_mask, - ) - if conv_ix_a is None: - return None - nco = conv_ix_a.numel() - if not nco: - return None - cconv_ix = np.arange(nco) - - # shifts may not matter - if p.is_drifting: - cconv, cconv_ix = _shift_normalize( - cconv, - cconv_ix, - temp_ix_a[conv_ix_a.cpu()], - shift_a[conv_ix_a.cpu()], - temp_ix_b[conv_ix_b.cpu()], - shift_b[conv_ix_b.cpu()], - ) - - # summarize units by coarse pconv when possible - if p.coarse_approx_error_threshold > 0: - cconv, cconv_ix = _coarse_approx( - cconv, cconv_ix, conv_ix_a, conv_ix_b, unit_a, unit_b, p - ) - - # for use in deconv residual distance merge - # TODO: actually probably need to do the real objective here with - # scaling. only need to do that bc of scaling right? - # makes it kind of a pain, because then we need to go pairwise - # (deconv objective is not symmetric) - max_conv = best_shift = None - if p.compute_max: - cconv_ = cconv.reshape(nco, cconv.shape[1] * cconv.shape[2]) - max_conv, max_index = cconv_.max(dim=1) - max_up, max_sample = np.unravel_index( - max_index.numpy(force=True), shape=cconv.shape[1:] - ) - best_shift = max_sample - (p.max_shift + 1) - # if upsample>half nup, round max shift up - best_shift += np.rint(max_up / cconv.shape[1]).astype(int) - - return ConvBatchResult( - shifted_temp_ix_a[conv_ix_a.numpy(force=True)], - shifted_temp_ix_b[conv_ix_b.numpy(force=True)], - cconv_ix, - cconv.numpy(force=True) if cconv is not None else None, - max_conv.numpy(force=True) if max_conv is not None else None, - best_shift, - ) - - -# -- library code -# template index and shift pairs -# pairwise low-rank cross-correlation - - -@dataclass -class TemplateShiftIndex: - """Return value for get_shift_and_unit_pairs""" - - n_shifted_templates: int - # shift index -> shift - all_pitch_shifts: np.ndarray - # (template ix, shift index) -> shifted template index - template_shift_index: np.ndarray - # (shifted temp ix, shifted temp ix) -> did these appear at the same time - cooccurrence: np.ndarray - shifted_temp_ix_to_temp_ix: np.ndarray - shifted_temp_ix_to_shift: np.ndarray - - -def static_template_shift_index(n_templates): - temp_ixs = np.arange(n_templates) - return TemplateShiftIndex( - n_templates, - np.zeros(1), - temp_ixs[:, None], - np.ones((n_templates, n_templates), dtype=bool), - temp_ixs, - np.zeros_like(temp_ixs), - ) - - -def get_shift_and_unit_pairs( - chunk_time_centers_s, - geom, - template_data, - motion_est=None, -): - n_templates = len(template_data.templates) - if motion_est is None: - # no motion case - return static_template_shift_index(n_templates) - - # all observed pitch shift values - all_pitch_shifts = np.empty(shape=(0,), dtype=int) - temp_ixs = np.arange(n_templates) - # set of (template idx, shift) - template_shift_pairs = np.empty(shape=(0, 2), dtype=int) - pitch = drift_util.get_pitch(geom) - - for t_s in chunk_time_centers_s: - # see the fn `templates_at_time` - unregistered_depths_um = drift_util.invert_motion_estimate( - motion_est, t_s, 
template_data.registered_template_depths_um - ) - pitch_shifts = drift_util.get_spike_pitch_shifts( - depths_um=template_data.registered_template_depths_um, - pitch=pitch, - registered_depths_um=unregistered_depths_um, - ) - pitch_shifts = pitch_shifts.astype(int) - - # get unique pitch/unit shift pairs in chunk - template_shift = np.c_[temp_ixs, pitch_shifts] - - # update full set - all_pitch_shifts = np.union1d(all_pitch_shifts, pitch_shifts) - template_shift_pairs = np.unique( - np.concatenate((template_shift_pairs, template_shift), axis=0), axis=0 - ) - - n_shifts = len(all_pitch_shifts) - n_template_shift_pairs = len(template_shift_pairs) - - # index template/shift pairs: template_shift_index[template_ix, shift_ix] = shifted template index - # fill with an invalid index - template_shift_index = np.full((n_templates, n_shifts), n_template_shift_pairs) - shift_ix = np.searchsorted(all_pitch_shifts, template_shift_pairs[:, 1]) - assert np.array_equal(all_pitch_shifts[shift_ix], template_shift_pairs[:, 1]) - template_shift_index[template_shift_pairs[:, 0], shift_ix] = np.arange( - n_template_shift_pairs - ) - shifted_temp_ix_to_temp_ix = template_shift_pairs[:, 0] - shifted_temp_ix_to_shift = template_shift_pairs[:, 1] - - # co-occurrence matrix: do these shifted templates appear together? - cooccurrence = np.eye(n_template_shift_pairs, dtype=bool) - for t_s in chunk_time_centers_s: - unregistered_depths_um = drift_util.invert_motion_estimate( - motion_est, t_s, template_data.registered_template_depths_um - ) - pitch_shifts = drift_util.get_spike_pitch_shifts( - depths_um=template_data.registered_template_depths_um, - pitch=pitch, - registered_depths_um=unregistered_depths_um, - ) - pitch_shifts = pitch_shifts.astype(int) - pitch_shift_ix = np.searchsorted(all_pitch_shifts, pitch_shifts) - - shifted_temp_ixs = template_shift_index[temp_ixs, pitch_shift_ix] - cooccurrence[shifted_temp_ixs[:, None], shifted_temp_ixs[None, :]] = 1 - - return TemplateShiftIndex( - n_template_shift_pairs, - all_pitch_shifts, - template_shift_index, - cooccurrence, - shifted_temp_ix_to_temp_ix, - shifted_temp_ix_to_shift, - ) - - -def ccorrelate_up( - spatial_a, - temporal_a, - spatial_b, - temporal_b, - conv_ignore_threshold=0.0, - max_shift="full", - covisible_mask=None, - batch_size=128, -): - """Convolve all pairs of low-rank templates - - This uses too much memory to run on all pairs at once. - - Templates Ka = Sa Ta, Kb = Sb Tb. The channel-summed convolution is - (Ka (*) Kb) = sum_c Ka(c) * Kb(c) - = (Sb.T @ Ka) (*) Tb - = (Sb.T @ Sa @ Ta) (*) Tb - where * is cross-correlation, and (*) is channel (or rank) summed. - - We use full-height conv2d to do rank-summed convs. - - Returns - ------- - covisible_a, covisible_b : tensors of indices - Both have shape (nco,), where nco is the number of templates - whose pairwise conv exceeds conv_ignore_threshold. - So, zip(covisible_a, covisible_b) is the set of co-visible pairs. - cconv : torch.Tensor - Shape is (nco, nup, 2 * max_shift + 1) - All cross-correlations for pairs of templates (templates in b - can be upsampled.) - If max_shift is full, then 2*max_shift+1=2t-1. 
- """ - na, rank, nchan = spatial_a.shape - nb, rank_, nchan_ = spatial_b.shape - assert rank == rank_ - assert nchan == nchan_ - na_, t, rank_ = temporal_a.shape - assert na == na_ - assert rank_ == rank - nb_, t_, nup, rank_ = temporal_b.shape - assert nb == nb_ - assert t == t_ - assert rank == rank_ - if covisible_mask is not None: - assert covisible_mask.shape == (na, nb) - - # no need to convolve templates which do not overlap enough - covisible = ( - torch.sqrt(torch.square(spatial_a).sum(1)) - @ torch.sqrt(torch.square(spatial_b).sum(1)).T - ) - covisible = covisible > conv_ignore_threshold - if covisible_mask is not None: - covisible *= covisible_mask - covisible_a, covisible_b = torch.nonzero(covisible, as_tuple=True) - nco = covisible_a.numel() - if not nco: - return None, None, None - - # batch over nco for memory reasons - cconv = torch.zeros( - (nco, nup, 2 * max_shift + 1), dtype=spatial_a.dtype, device=spatial_a.device - ) - for istart in range(0, nco, batch_size): - iend = min(istart + batch_size, nco) - co_a = covisible_a[istart:iend] - co_b = covisible_b[istart:iend] - nco_ = iend - istart - - # want conv filter: nco, 1, rank, t - template_a = torch.bmm(temporal_a, spatial_a) - conv_filt = torch.bmm(spatial_b[co_b], template_a[co_a].mT) - conv_filt = conv_filt[:, None] # (nco, 1, rank, t) - - # nup, nco, rank, t - conv_in = temporal_b[co_b].permute(2, 0, 3, 1) - - # conv2d: - # depthwise, chans=nco. batch=1. h=rank. w=t. out: nup, nco, 1, 2p+1. - # input (conv_in): nup, nco, rank, t. - # filters (conv_filt): nco, 1, rank, t. (groups=nco). - cconv_ = F.conv2d(conv_in, conv_filt, padding=(0, max_shift), groups=nco_) - cconv[istart:iend] = cconv_[:, :, 0, :].permute(1, 0, 2) # nco, nup, time - - # more stringent covisibility - if conv_ignore_threshold > 0: - max_val = cconv.reshape(nco, -1).abs().max(dim=1).values - vis = max_val > conv_ignore_threshold - cconv = cconv[vis] - covisible_a = covisible_a[vis] - covisible_b = covisible_b[vis] - - return covisible_a, covisible_b, cconv - - -# -- helpers - - -def _coarse_approx(cconv, cconv_ix, conv_ix_a, conv_ix_b, unit_a, unit_b, p): - # figure out coarse templates to correlate - conv_ix_a = conv_ix_a.cpu() - conv_ix_b = conv_ix_b.cpu() - conv_unit_a = unit_a[conv_ix_a] - conv_unit_b = unit_b[conv_ix_b] - coarse_units_a = np.unique(conv_unit_a) - coarse_units_b = np.unique(conv_unit_b) - coarsecovis = np.zeros((coarse_units_a.size, coarse_units_b.size), dtype=bool) - coarsecovis[ - np.searchsorted(coarse_units_a, conv_unit_a), - np.searchsorted(coarse_units_b, conv_unit_b), - ] = True - - # correlate them - coarse_ix_a, coarse_ix_b, coarse_cconv = ccorrelate_up( - p.coarse_spatial_singular[coarse_units_a].to(p.device), - p.coarse_temporal[coarse_units_a].to(p.device), - p.coarse_spatial_singular[coarse_units_b].to(p.device), - p.coarse_temporal[coarse_units_b].unsqueeze(2).to(p.device), - conv_ignore_threshold=p.conv_ignore_threshold, - max_shift=p.max_shift, - covisible_mask=torch.as_tensor(coarsecovis, device=p.device), - ) - if coarse_ix_a is None: - return cconv, cconv_ix - - coarse_units_a = np.atleast_1d(coarse_units_a[coarse_ix_a.cpu()]) - coarse_units_b = np.atleast_1d(coarse_units_b[coarse_ix_b.cpu()]) - - # find coarse units which well summarize the fine cconvs - for coarse_unit_a, coarse_unit_b, conv in zip( - coarse_units_a, coarse_units_b, coarse_cconv - ): - # check good approx. 
if not, continue - in_pair = np.flatnonzero( - (conv_unit_a == coarse_unit_a) & (conv_unit_b == coarse_unit_b) - ) - assert in_pair.size - fine_cconvs = cconv[cconv_ix[in_pair]] - approx_err = (fine_cconvs - conv[None]).abs().max() - if not approx_err < p.coarse_approx_error_threshold: - continue - - # replace first fine cconv with the coarse cconv - cconv[cconv_ix[in_pair[0]]] = conv - # set all fine cconv ix to the index of that first one - cconv_ix[in_pair] = cconv_ix[in_pair[0]] - - # re-index and subset cconvs - cconv_ix_subset, new_cconv_ix = np.unique(cconv_ix, return_inverse=True) - cconv = cconv[cconv_ix_subset] - return cconv, new_cconv_ix - - -def _shift_normalize( - cconv, cconv_ix, temp_ix_a, shift_a, temp_ix_b, shift_b, atol=1e-1 -): - pairs_done = set() - for ua, ub in zip(temp_ix_a, temp_ix_b): - if (ua, ub) in pairs_done: - continue - pairs_done.add((ua, ub)) - - in_pair = np.flatnonzero((temp_ix_a == ua) & (temp_ix_b == ub)) - diffs = shift_a[in_pair] - shift_b[in_pair] - changed = False - for diff in np.unique(diffs): - in_diff = in_pair[diffs == diff] - - cconvs = cconv[cconv_ix[in_diff]] - meanconv = cconvs.mean(0, keepdims=True) - err = (cconvs - meanconv).abs().max() - if err > atol: - continue - changed = True - cconv[cconv_ix[in_diff[0]]] = meanconv - cconv_ix[in_diff] = cconv_ix[in_diff[0]] - if changed: - pairs_done.remove((ua, ub)) - - for ua, ub in zip(temp_ix_a, temp_ix_b): - if (ua, ub) in pairs_done: - continue - pairs_done.add((ua, ub)) - - in_pair = np.flatnonzero((temp_ix_a == ua) & (temp_ix_b == ub)) - cconvs = cconv[cconv_ix[in_pair]] - meanconv = cconvs.mean(0, keepdims=True) - err = (cconvs - meanconv).abs().max() - if err > atol: - continue - - cconv[cconv_ix[in_pair[0]]] = meanconv - cconv_ix[in_pair] = cconv_ix[in_pair[0]] - - # re-index and subset cconvs - cconv_ix_subset, new_cconv_ix = np.unique(cconv_ix, return_inverse=True) - cconv = cconv[cconv_ix_subset] - return cconv, new_cconv_ix diff --git a/src/dartsort/templates/pairwise_conv.py b/src/dartsort/templates/pairwise_conv.py deleted file mode 100644 index 26b5b76a..00000000 --- a/src/dartsort/templates/pairwise_conv.py +++ /dev/null @@ -1,148 +0,0 @@ -from dataclasses import dataclass - - -def sparse_pairwise_conv( - template_data, - template_temporal_components, - template_upsampled_temporal_components, - template_singular_values, - template_spatial_components, - chunk_time_centers_s=None, - motion_est=None, - conv_ignore_threshold: float = 0.0, - coarse_approx_error_threshold: float = 0.0, -): - """ - - Arguments - --------- - template_* : tensors or arrays - template SVD approximations - conv_ignore_threshold: float = 0.0 - pairs will be ignored (i.e., pconv set to 0) if their pconv - does not exceed this value - coarse_approx_error_threshold: float = 0.0 - superres will not be used if coarse pconv and superres pconv - are uniformly closer than this threshold value - - Returns - ------- - pitch_shifts : array - array of all the pitch shifts - use searchsorted to find the pitch shift ix for a pitch shift - index_table: torch sparse tensor - index_table[(pitch shift ix a, superres label a, pitch shift ix b, superres label b)] = ( - -1 - if superres pconv a,b at these shifts was below the conv_ignore_threshold - else pconv_index) - pconvs: np.ndarray - pconv[pconv_index] is a cross-correlation of two templates, summed over chans - """ - # find all of the co-occurring pitch shift and unit id pairs - all_pitch_shifts, shift_unit_pairs = get_shift_and_unit_pairs( - chunk_time_centers_s, - 
geom, - template_data, - motion_est=motion_est, - ) - - -# defining this dtype, which represents a pair of units and shifts, -# allows us to use numpy's 1d set functions on these pairs -shift_unit_pair_dtype = np.dtype( - [("unita", int), ("shifta", int), ("unitb", int), ("shiftb", int)] -) - - -class PairwiseConvContext: - def __init__( - self, - coarse_spatial, - coarse_singular, - coarse_f_temporal, - spatial, - singular, - f_temporal, - f_temporal_up, - geom, - registered_geom, - ): - - -def _pairwise_conv_job( - units_a, - units_b, -): - """units_a,b are chunks of original (non-superres) unit labels""" - - # returns - # array of type shift_unit_pair_dtype - # array of the same length containing - # - -1 or an index into the next list - # list of pconvs, indexed by previous list - - # determine co-visible shift/unit pairs - - # extract template data for left and right entries of each - # pair into npairs-len structures - # "depthwise" convolve these two structures - - # same for the coarse templates - - # when max pconv is < co-correlation threshold: - # - key list entry gets -1 - - # now, the coarse part - # for each pair of coarse units, check if the max difference - # of coarse pconv and all superres pconvs is small enough, - # and use an id for the (temporally upsampled) coarse pconv if so - - - pass - - - -def get_shift_and_unit_pairs( - chunk_time_centers_s, - geom, - template_data, - motion_est=None, -): - if motion_est is None: - return None, None - - # all observed pitch shift values - all_pitch_shifts = [] - # set of (unit a, shift a, unit b, shift b) - # units are unit ids, not (superres) template indices - shift_unit_pairs = [] - - for t_s in chunk_time_centers_s: - # see the fn `templates_at_time` - unregistered_depths_um = drift_util.invert_motion_estimate( - motion_est, t_s, template_data.registered_template_depths_um - ) - pitch_shifts = drift_util.get_spike_pitch_shifts( - depths_um=template_data.registered_template_depths_um, - geom=geom, - registered_depths_um=unregistered_depths_um, - ) - - # get unique pitch/unit shift pairs in chunk - pitch_and_unit = np.c_[td.unit_ids, pitch_shifts.astype(int)] - pairs = np.concatenate( - np.broadcast_arrays( - pitch_and_unit[:, None, :], - pitch_and_unit[None, :, :], - ), - axis=2, - ) - pairs = pairs.reshape(len(td.unit_ids) ** 2, 4) - pairs = np.ascontiguousarray(pairs).view(shift_unit_pair_dtype) - unique_pairs_in_chunk = np.unique(pairs) - - # update full set - all_pitch_shifts = np.union1d(all_pitch_shifts, pitch_shifts) - shift_unit_pairs = np.union1d(shift_unit_pairs, unique_pairs_in_chunk) - - return all_pitch_shifts, shift_unit_pairs diff --git a/src/dartsort/templates/pairwise_util.py b/src/dartsort/templates/pairwise_util.py new file mode 100644 index 00000000..0bf992ea --- /dev/null +++ b/src/dartsort/templates/pairwise_util.py @@ -0,0 +1,902 @@ +from dataclasses import dataclass, fields +from typing import Optional + +import h5py +import numpy as np +import torch +import torch.nn.functional as F +from dartsort.templates import template_util +from dartsort.util import drift_util +from dartsort.util.multiprocessing_util import get_pool +from scipy.spatial import KDTree +from scipy.spatial.distance import pdist +from tqdm.auto import tqdm + +# todo: extend this code to also handle computing pairwise +# stuff necessary for the merge! ie shifts, scaling, +# residnorm(a,b) (or min of rn(a,b),rn(b,a)???) 
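
The convolution code later in this new module (ccorrelate_up) leans on the low-rank identity stated in its docstring: for templates with temporal factors Ta, Tb and spatial factors Sa, Sb, the channel-summed cross-correlation of Ka = Ta @ Sa and Kb = Tb @ Sb equals a rank-summed cross-correlation of Tb against the filter Ka @ Sb.T. A minimal numpy check of that identity with toy shapes (names here are illustrative, not the module's API):

import numpy as np

rng = np.random.default_rng(0)
rank, t, nchan = 3, 21, 7

# low-rank factors: temporal (t, rank), spatial (rank, nchan)
Ta, Sa = rng.normal(size=(t, rank)), rng.normal(size=(rank, nchan))
Tb, Sb = rng.normal(size=(t, rank)), rng.normal(size=(rank, nchan))
Ka, Kb = Ta @ Sa, Tb @ Sb    # dense (t, nchan) templates

# channel-summed full cross-correlation, the expensive way
dense = sum(np.correlate(Kb[:, c], Ka[:, c], mode="full") for c in range(nchan))

# rank-summed way: correlate Tb against the (t, rank) filter Ka @ Sb.T
filt = Ka @ Sb.T
lowrank = sum(np.correlate(Tb[:, r], filt[:, r], mode="full") for r in range(rank))

# same 2t-1 lags; this is the sum the full-height grouped conv2d computes per pair
assert np.allclose(dense, lowrank)
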
+ + +def sparse_pairwise_conv( + output_hdf5_filename, + geom, + template_data, + template_temporal_components, + template_upsampled_temporal_components, + template_singular_values, + template_spatial_components, + chunk_time_centers_s=None, + motion_est=None, + conv_ignore_threshold: float = 0.0, + coarse_approx_error_threshold: float = 0.0, + min_channel_amplitude: float = 1.0, + units_per_chunk=8, + overwrite=False, + show_progress=True, + device=None, + n_jobs=0, +): + """ + + Arguments + --------- + template_* : tensors or arrays + template SVD approximations + conv_ignore_threshold: float = 0.0 + pairs will be ignored (i.e., pconv set to 0) if their pconv + does not exceed this value + coarse_approx_error_threshold: float = 0.0 + superres will not be used if coarse pconv and superres pconv + are uniformly closer than this threshold value + + Returns + ------- + pitch_shifts : array + array of all the pitch shifts + use searchsorted to find the pitch shift ix for a pitch shift + index_table: torch sparse tensor + index_table[(pitch shift ix a, superres label a, pitch shift ix b, superres label b)] = ( + 0 + if superres pconv a,b at these shifts was below the conv_ignore_threshold + else pconv_index) + pconvs: np.ndarray + pconv[pconv_index] is a cross-correlation of two templates, summed over chans + """ + if overwrite: + pass + + ( + n_templates, + spike_length_samples, + upsampling_factor, + ) = template_upsampled_temporal_components.shape[:3] + + # find all of the co-occurring pitch shift and template pairs + temp_shift_index = get_shift_and_unit_pairs( + chunk_time_centers_s, + geom, + template_data, + motion_est=motion_est, + ) + + # check if the convolutions need to be drift-aware + # they do if we need to do any channel selection + is_drifting = not np.array_equal(temp_shift_index.all_pitch_shifts, [0]) + if template_data.registered_geom is not None: + is_drifting |= not np.array_equal(geom, template_data.registered_geom) + + # initialize pairwise conv data structures + # index_table[shifted_temp_ix(i), shifted_temp_ix(j)] = pconvix(i,j) + pair_index_table = np.zeros( + (temp_shift_index.n_shifted_templates, temp_shift_index.n_shifted_templates), + dtype=int, + ) + upsampling_index_table = np.zeros( + (temp_shift_index.n_shifted_templates, temp_shift_index.n_shifted_templates), + dtype=int, + ) + # pconvs[pconvix(i,j)] = (2*spikelen-1, upsampling_factor) arr of pconv(shifted_temp(i), shifted_temp(j)) + + cur_pair_ix = 1 + cur_pconv_ix = 1 + with h5py.File(output_hdf5_filename, "w") as h5: + # resizeable pconv dataset + pconv = h5.create_dataset( + "pconv", + dtype=np.float32, + shape=(1, upsampling_factor, 2 * spike_length_samples - 1), + maxshape=(None, upsampling_factor, 2 * spike_length_samples - 1), + chunks=(128, upsampling_factor, 2 * spike_length_samples - 1), + ) + + # pconv[0] is special -- it is 0. 
+ pconv[0] = 0.0 + + # res is a ConvBatchResult + for res in compute_pairwise_convs( + template_data, + template_spatial_components, + template_singular_values, + template_temporal_components, + template_upsampled_temporal_components, + temp_shift_index.shifted_temp_ix_to_temp_ix, + temp_shift_index.shifted_temp_ix_to_shift, + geom, + cooccurrence=temp_shift_index.cooccurrence, + conv_ignore_threshold=conv_ignore_threshold, + coarse_approx_error_threshold=coarse_approx_error_threshold, + min_channel_amplitude=min_channel_amplitude, + is_drifting=is_drifting, + units_per_chunk=units_per_chunk, + n_jobs=n_jobs, + device=device, + show_progress=show_progress, + max_shift="full", + store_conv=True, + compute_max=False, + ): + if res is None: + continue + new_pair_ix = res.pair_ix + cur_pair_ix + pair_index_table[res.shifted_temp_ix_a, res.shifted_temp_ix_b] = new_pair_ix + + new_pconv_ix = res.pconv_ix + cur_pconv_ix + upsampling_index_table[new_pair_ix, res.upsampling_ix] = new_pconv_ix + + pconv.resize(cur_pconv_ix + res.cconv_up.shape[0], axis=0) + pconv[cur_pconv_ix:] = res.cconv_up + cur_pconv_ix += res.cconv_up.shape[0] + + # smaller datasets all at once + h5.create_dataset( + "template_shift_index", data=temp_shift_index.template_shift_index + ) + h5.create_dataset("pconv_index_table", data=pconv_index_table) + h5.create_dataset("shifts", data=temp_shift_index.all_pitch_shifts) + h5.create_dataset( + "shifted_temp_ix_to_temp_ix", + data=temp_shift_index.shifted_temp_ix_to_temp_ix, + ) + h5.create_dataset( + "shifted_temp_ix_to_shift", data=temp_shift_index.shifted_temp_ix_to_shift + ) + h5.create_dataset( + "shifted_temp_ix_to_unit", + data=template_data.unit_ids[temp_shift_index.shifted_temp_ix_to_temp_ix], + ) + + return output_hdf5_filename + + +# -- main general worker function + + +def compute_pairwise_convs( + template_data, + spatial, + singular, + temporal, + temporal_up, + shifted_temp_ix_to_temp_ix, + shifted_temp_ix_to_shift, + geom, + cooccurrence, + conv_ignore_threshold=0.0, + coarse_approx_error_threshold=0.0, + min_channel_amplitude=1.0, + max_shift="full", + is_drifting=True, + store_conv=True, + compute_max=False, + units_per_chunk=8, + n_jobs=0, + device=None, + show_progress=True, +): + # chunk up coarse unit ids, go by pairs of chunks + units = np.unique(template_data.unit_ids) + jobs = [] + for start_a in range(0, units.size, units_per_chunk): + end_a = min(start_a + units_per_chunk, units.size) + for start_b in range(start_a, units.size, units_per_chunk): + end_b = min(start_b + units_per_chunk, units.size) + jobs.append((units[start_a:end_a], units[start_b:end_b])) + if show_progress: + jobs = tqdm( + jobs, smoothing=0.01, desc="Pairwise convolution", unit="pair block" + ) + + # compute the coarse templates if needed + if units.size == template_data.unit_ids.size: + # coarse templates are original templates + coarse_approx_error_threshold = 0 + if coarse_approx_error_threshold > 0: + coarse_templates = template_util.weighted_average( + template_data.unit_ids, template_data.templates, template_data.spike_counts + ) + ( + coarse_temporal, + coarse_singular, + coarse_spatial, + ) = template_util.svd_compress_templates( + coarse_templates, + rank=singular.shape[1], + min_channel_amplitude=min_channel_amplitude, + ) + + # template data to torch + spatial_singular = torch.as_tensor(spatial * singular[:, :, None]) + temporal = torch.as_tensor(temporal) + temporal_up = torch.as_tensor(temporal_up) + if coarse_approx_error_threshold > 0: + coarse_spatial_singular = 
torch.as_tensor( + coarse_spatial * coarse_singular[:, :, None] + ) + coarse_temporal = torch.as_tensor(coarse_temporal) + else: + coarse_spatial_singular = None + coarse_temporal = None + + n_jobs, Executor, context, rank_queue = get_pool(n_jobs, with_rank_queue=True) + + pconv_params = dict( + store_conv=store_conv, + compute_max=compute_max, + is_drifting=is_drifting, + max_shift=max_shift, + conv_ignore_threshold=conv_ignore_threshold, + coarse_approx_error_threshold=coarse_approx_error_threshold, + spatial_singular=spatial_singular, + temporal=temporal, + temporal_up=temporal_up, + coarse_spatial_singular=coarse_spatial_singular, + coarse_temporal=coarse_temporal, + unit_ids=template_data.unit_ids, + shifted_temp_ix_to_shift=shifted_temp_ix_to_shift, + shifted_temp_ix_to_temp_ix=shifted_temp_ix_to_temp_ix, + shifted_temp_ix_to_unit=template_data.unit_ids[shifted_temp_ix_to_temp_ix], + cooccurrence=cooccurrence, + geom=geom, + registered_geom=template_data.registered_geom, + ) + + with Executor( + n_jobs, + mp_context=context, + initializer=_pairwise_conv_init, + initargs=(device, rank_queue, pconv_params), + ) as pool: + yield from pool.map(_pairwise_conv_job, jobs) + + +# -- parallel job code + + +# helper class which stores parameters for _pairwise_conv_job +@dataclass +class PairwiseConvContext: + device: torch.device + + # parameters + store_conv: bool + compute_max: bool + is_drifting: bool + max_shift: int + conv_ignore_threshold: float + coarse_approx_error_threshold: float + + # superres registered templates + spatial_singular: torch.Tensor + temporal: torch.Tensor + temporal_up: torch.Tensor + coarse_spatial_singular: Optional[torch.Tensor] + coarse_temporal: Optional[torch.Tensor] + cooccurrence: torch.Tensor + + # template indexing helper arrays + unit_ids: np.ndarray + shifted_temp_ix_to_temp_ix: np.ndarray + shifted_temp_ix_to_shift: np.ndarray + shifted_temp_ix_to_unit: np.ndarray + + # only needed if is_drifting + geom: np.ndarray + registered_geom: np.ndarray + geom_kdtree: Optional[KDTree] + reg_geom_kdtree: Optional[KDTree] + match_distance: Optional[float] + + +_pairwise_conv_context = None + + +def _pairwise_conv_init( + device, + rank_queue, + kwargs, +): + global _pairwise_conv_context + + # figure out what device to work on + my_rank = rank_queue.get() + if device is None: + device = "cuda" if torch.cuda.is_available() else "cpu" + device = torch.device(device) + if device.type == "cuda" and device.index is None: + if torch.cuda.device_count() > 1: + device = torch.device("cuda", index=my_rank % torch.cuda.device_count()) + + # handle string max_shift + max_shift = kwargs.pop("max_shift", "full") + t = kwargs["temporal"].shape[1] + if max_shift == "full": + max_shift = t - 1 + elif max_shift == "valid": + max_shift = 0 + elif max_shift == "same": + max_shift = t // 2 + kwargs["max_shift"] = max_shift + + kwargs["geom_kdtree"] = kwargs["reg_geom_kdtree"] = kwargs["match_distance"] = None + if kwargs["is_drifting"]: + kwargs["geom_kdtree"] = KDTree(kwargs["geom"]) + kwargs["reg_geom_kdtree"] = KDTree(kwargs["registered_geom"]) + kwargs["match_distance"] = pdist(kwargs["geom"]).min() / 2 + + _pairwise_conv_context = PairwiseConvContext(device=device, **kwargs) + + +@dataclass +class ConvBatchResult: + # arrays of length + shifted_temp_ix_a: np.ndarray + shifted_temp_ix_b: np.ndarray + # array of length such that the ith + # pair's array of upsampled convs is cconv_up[cconv_ix[i]] + cconv_ix: np.ndarray + cconv_up: Optional[np.ndarray] + max_conv: Optional[float] + 
best_shift: Optional[int] + + +def _pairwise_conv_job(unit_chunk): + global _pairwise_conv_context + p = _pairwise_conv_context + + units_a, units_b = unit_chunk + + # this job consists of pairs of coarse units + # lets get all shifted superres template indices corresponding to those pairs, + # and the template indices, pitch shifts, and coarse units while we're at it + shifted_temp_ix_a = np.flatnonzero(np.isin(p.shifted_temp_ix_to_unit, units_a)) + shifted_temp_ix_b = np.flatnonzero(np.isin(p.shifted_temp_ix_to_unit, units_b)) + temp_ix_a = p.shifted_temp_ix_to_temp_ix[shifted_temp_ix_a] + temp_ix_b = p.shifted_temp_ix_to_temp_ix[shifted_temp_ix_b] + shift_a = p.shifted_temp_ix_to_shift[shifted_temp_ix_a] + shift_b = p.shifted_temp_ix_to_shift[shifted_temp_ix_b] + unit_a = p.unit_ids[temp_ix_a] + unit_b = p.unit_ids[temp_ix_b] + spatial_a = p.spatial_singular[temp_ix_a] + spatial_b = p.spatial_singular[temp_ix_b] + + # get shifted spatial components + if p.is_drifting: + spatial_a = drift_util.get_waveforms_on_static_channels( + spatial_a, + p.registered_geom, + n_pitches_shift=shift_a, + registered_geom=p.geom, + target_kdtree=p.geom_kdtree, + match_distance=p.match_distance, + fill_value=0.0, + ) + spatial_b = drift_util.get_waveforms_on_static_channels( + spatial_b, + p.registered_geom, + n_pitches_shift=shift_b, + registered_geom=p.geom, + target_kdtree=p.geom_kdtree, + match_distance=p.match_distance, + fill_value=0.0, + ) + + # Explanation of all of the indexing going on below. + # - pair_ix_a,b index pairs of templates with nonzero cross-correlation + # so, there are i=1,...,N pairs pair_ix_a[i], pair_ix_b[i] + # - pair_ix is a N-array of indices such that... + # - up_pconv[pair_ix[i]] = set of upsampled pconv(pair_ix_a[i], pair_ix_b[i]) + # + # these are normalized/deduplicated: pair_ix contains duplicate entries. + # conv_ix contains the unique entries. in particular, the pconv between + # pair_ix_a[i], pair_ix_b[i] is being computed as that between + # pair_ix_a[conv_ix[pair_ix[i]]] and pair_ix_b[conv_ix[pair_ix[i]]]. + # + # then, we also sparsify the temporal upsampling. 
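
The comment above describes the deduplication pattern that recurs in this series (compare the np.unique(..., return_inverse=True) calls in _coarse_approx and _shift_normalize): many requested pairs point into a smaller set of convolutions that actually get computed. A toy sketch of that index relationship (hypothetical ids, not the function's real variables):

import numpy as np

# 5 requested pairs, but pairs 0, 2, 4 can share one pconv and pairs 1, 3 another
# (e.g. same shift difference, or well-approximated by the coarse template pair)
requested = np.array([7, 9, 7, 9, 7])

# conv entries = unique pconvs to compute; pair_ix maps each requested pair to its row
unique_ids, pair_ix = np.unique(requested, return_inverse=True)
conv_ix = np.array([np.flatnonzero(requested == u)[0] for u in unique_ids])

# after computing pconvs only for conv_ix, requested pair i reads row pair_ix[i]
assert np.array_equal(requested[conv_ix][pair_ix], requested)
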
+ + # figure out pairs of templates to convolve + pair_ix_a, pair_ix_b, pair_ix, conv_ix, shift_diff = deduplicated_pairs( + shifted_temp_ix_a, + shifted_temp_ix_b, + spatial_a, + spatial_b, + temp_ix_a, + temp_ix_b, + shift_a=shift_a, + shift_b=shift_b, + cooccurrence=p.cooccurrence, + conv_ignore_threshold=p.conv_ignore_threshold, + geom=p.geom, + registered_geom=p.registered_geom, + reg_geom_kdtree=p.reg_geom_kdtree, + match_distance=p.match_distance, + ) + + # to device + spatial_a = spatial_a.to(p.device) + spatial_b = spatial_b.to(p.device) + temporal_a = p.temporal[temp_ix_a].to(p.device) + temporal_up_b = p.temporal_up[temp_ix_b].to(p.device) + + # convolve valid pairs + conv_ix_a, conv_ix_b, up_pconv, pair_survived = ccorrelate_up( + spatial_a, + temporal_a, + spatial_b, + temporal_up_b, + conv_ignore_threshold=p.conv_ignore_threshold, + max_shift=p.max_shift, + pair_ix_a=pair_ix_a[conv_ix], + pair_ix_b=pair_ix_b[conv_ix], + ) + if conv_ix_a is None: + return None + nco = conv_ix_a.numel() + if not nco: + return None + + pair_ix = pair_ix[pair_survived] + pair_ix_a = pair_ix_a[pair_survived] + pair_ix_b = pair_ix_b[pair_survived] + + # # summarize units by coarse pconv when possible + # if p.coarse_approx_error_threshold > 0: + # pconv, cconv_ix = _coarse_approx( + # pconv, cconv_ix, conv_ix_a, conv_ix_b, unit_a, unit_b, p + # ) + + # for use in deconv residual distance merge + # TODO: actually probably need to do the real objective here with + # scaling. only need to do that bc of scaling right? + # makes it kind of a pain, because then we need to go pairwise + # (deconv objective is not symmetric) + max_conv = best_shift = None + if p.compute_max: + cconv_ = pconv.reshape(nco, pconv.shape[1] * pconv.shape[2]) + max_conv, max_index = cconv_.max(dim=1) + max_up, max_sample = np.unravel_index( + max_index.numpy(force=True), shape=pconv.shape[1:] + ) + best_shift = max_sample - (p.max_shift + 1) + # if upsample>half nup, round max shift up + best_shift += np.rint(max_up / pconv.shape[1]).astype(int) + + print(f"end {conv_ix_a.shape=}") + print(f"end {conv_ix_b.shape=}") + print(f"end {cconv_ix.shape=}") + print(f"end {pconv.shape=}") + + return ConvBatchResult( + shifted_temp_ix_a[pair_ix_a.numpy(force=True)], + shifted_temp_ix_b[pair_ix_b.numpy(force=True)], + cconv_ix, + pconv.numpy(force=True) if pconv is not None else None, + max_conv.numpy(force=True) if max_conv is not None else None, + best_shift, + ) + + +# -- library code +# template index and shift pairs +# pairwise low-rank cross-correlation + + +@dataclass +class TemplateShiftIndex: + """Return value for get_shift_and_unit_pairs""" + + n_shifted_templates: int + # shift index -> shift + all_pitch_shifts: np.ndarray + # (template ix, shift index) -> shifted template index + template_shift_index: np.ndarray + # (shifted temp ix, shifted temp ix) -> did these appear at the same time + cooccurrence: np.ndarray + shifted_temp_ix_to_temp_ix: np.ndarray + shifted_temp_ix_to_shift: np.ndarray + + +def static_template_shift_index(n_templates): + temp_ixs = np.arange(n_templates) + return TemplateShiftIndex( + n_templates, + np.zeros(1), + temp_ixs[:, None], + np.ones((n_templates, n_templates), dtype=bool), + temp_ixs, + np.zeros_like(temp_ixs), + ) + + +def get_shift_and_unit_pairs( + chunk_time_centers_s, + geom, + template_data, + motion_est=None, +): + n_templates = len(template_data.templates) + if motion_est is None: + # no motion case + return static_template_shift_index(n_templates) + + # all observed pitch shift values + 
all_pitch_shifts = np.empty(shape=(0,), dtype=int) + temp_ixs = np.arange(n_templates) + # set of (template idx, shift) + template_shift_pairs = np.empty(shape=(0, 2), dtype=int) + pitch = drift_util.get_pitch(geom) + + for t_s in chunk_time_centers_s: + # see the fn `templates_at_time` + unregistered_depths_um = drift_util.invert_motion_estimate( + motion_est, t_s, template_data.registered_template_depths_um + ) + pitch_shifts = drift_util.get_spike_pitch_shifts( + depths_um=template_data.registered_template_depths_um, + pitch=pitch, + registered_depths_um=unregistered_depths_um, + ) + pitch_shifts = pitch_shifts.astype(int) + + # get unique pitch/unit shift pairs in chunk + template_shift = np.c_[temp_ixs, pitch_shifts] + + # update full set + all_pitch_shifts = np.union1d(all_pitch_shifts, pitch_shifts) + template_shift_pairs = np.unique( + np.concatenate((template_shift_pairs, template_shift), axis=0), axis=0 + ) + + n_shifts = len(all_pitch_shifts) + n_template_shift_pairs = len(template_shift_pairs) + + # index template/shift pairs: template_shift_index[template_ix, shift_ix] = shifted template index + # fill with an invalid index + template_shift_index = np.full((n_templates, n_shifts), n_template_shift_pairs) + shift_ix = np.searchsorted(all_pitch_shifts, template_shift_pairs[:, 1]) + assert np.array_equal(all_pitch_shifts[shift_ix], template_shift_pairs[:, 1]) + template_shift_index[template_shift_pairs[:, 0], shift_ix] = np.arange( + n_template_shift_pairs + ) + shifted_temp_ix_to_temp_ix = template_shift_pairs[:, 0] + shifted_temp_ix_to_shift = template_shift_pairs[:, 1] + + # co-occurrence matrix: do these shifted templates appear together? + cooccurrence = np.eye(n_template_shift_pairs, dtype=bool) + for t_s in chunk_time_centers_s: + unregistered_depths_um = drift_util.invert_motion_estimate( + motion_est, t_s, template_data.registered_template_depths_um + ) + pitch_shifts = drift_util.get_spike_pitch_shifts( + depths_um=template_data.registered_template_depths_um, + pitch=pitch, + registered_depths_um=unregistered_depths_um, + ) + pitch_shifts = pitch_shifts.astype(int) + pitch_shift_ix = np.searchsorted(all_pitch_shifts, pitch_shifts) + + shifted_temp_ixs = template_shift_index[temp_ixs, pitch_shift_ix] + cooccurrence[shifted_temp_ixs[:, None], shifted_temp_ixs[None, :]] = 1 + + return TemplateShiftIndex( + n_template_shift_pairs, + all_pitch_shifts, + template_shift_index, + cooccurrence, + shifted_temp_ix_to_temp_ix, + shifted_temp_ix_to_shift, + ) + + +def ccorrelate_up( + spatial_a, + temporal_a, + spatial_b, + temporal_b, + upsampling_compression_map=None, + conv_ignore_threshold=0.0, + max_shift="full", + covisible_mask=None, + pair_ix_a=None, + pair_ix_b=None, + batch_size=128, +): + """Convolve all pairs of low-rank templates + + This uses too much memory to run on all pairs at once. + + Templates Ka = Sa Ta, Kb = Sb Tb. The channel-summed convolution is + (Ka (*) Kb) = sum_c Ka(c) * Kb(c) + = (Sb.T @ Ka) (*) Tb + = (Sb.T @ Sa @ Ta) (*) Tb + where * is cross-correlation, and (*) is channel (or rank) summed. + + We use full-height conv2d to do rank-summed convs. + + upsampling_compression_map + (n_templates, upsampling_factor) + + + Returns + ------- + covisible_a, covisible_b : tensors of indices + Both have shape (nco,), where nco is the number of templates + whose pairwise conv exceeds conv_ignore_threshold. + So, zip(covisible_a, covisible_b) is the set of co-visible pairs. 
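A small self-contained sketch of the (template, shift) pair accumulation performed in get_shift_and_unit_pairs above; the made-up per-chunk pitch shifts stand in for the drift_util calls.

    import numpy as np

    # made-up per-template pitch shifts for two chunks (3 templates)
    shifts_per_chunk = [np.array([0, 0, -1]), np.array([1, 0, -1])]
    temp_ixs = np.arange(3)

    all_pitch_shifts = np.empty(0, dtype=int)
    template_shift_pairs = np.empty((0, 2), dtype=int)
    for pitch_shifts in shifts_per_chunk:
        all_pitch_shifts = np.union1d(all_pitch_shifts, pitch_shifts)
        template_shift_pairs = np.unique(
            np.concatenate((template_shift_pairs, np.c_[temp_ixs, pitch_shifts]), axis=0),
            axis=0,
        )

    print(all_pitch_shifts)      # [-1  0  1]
    print(template_shift_pairs)  # [[ 0  0] [ 0  1] [ 1  0] [ 2 -1]]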
+ cconv : torch.Tensor + Shape is (nco, nup, 2 * max_shift + 1) + All cross-correlations for pairs of templates (templates in b + can be upsampled.) + If max_shift is full, then 2*max_shift+1=2t-1. + """ + na, rank, nchan = spatial_a.shape + nb, rank_, nchan_ = spatial_b.shape + assert rank == rank_ + assert nchan == nchan_ + na_, t, rank_ = temporal_a.shape + assert na == na_ + assert rank_ == rank + nb_, t_, nup, rank_ = temporal_b.shape + assert nb == nb_ + assert t == t_ + assert rank == rank_ + if covisible_mask is not None: + assert covisible_mask.shape == (na, nb) + + # no need to convolve templates which do not overlap enough + if pair_ix_a is None: + covisible = ( + torch.sqrt(torch.square(spatial_a).sum(1)) + @ torch.sqrt(torch.square(spatial_b).sum(1)).T + ) + covisible = covisible > conv_ignore_threshold + if covisible_mask is not None: + covisible *= covisible_mask + covisible_a, covisible_b = torch.nonzero(covisible, as_tuple=True) + else: + covisible_a, covisible_b = pair_ix_a, pair_ix_b + nco = covisible_a.numel() + if not nco: + return None, None, None + + # batch over nco for memory reasons + cconv = torch.zeros( + (nco, nup, 2 * max_shift + 1), dtype=spatial_a.dtype, device=spatial_a.device + ) + for istart in range(0, nco, batch_size): + iend = min(istart + batch_size, nco) + co_a = covisible_a[istart:iend] + co_b = covisible_b[istart:iend] + nco_ = iend - istart + + # want conv filter: nco, 1, rank, t + template_a = torch.bmm(temporal_a, spatial_a) + conv_filt = torch.bmm(spatial_b[co_b], template_a[co_a].mT) + conv_filt = conv_filt[:, None] # (nco, 1, rank, t) + + # nup, nco, rank, t + conv_in = temporal_b[co_b].permute(2, 0, 3, 1) + + # conv2d: + # depthwise, chans=nco. batch=1. h=rank. w=t. out: nup, nco, 1, 2p+1. + # input (conv_in): nup, nco, rank, t. + # filters (conv_filt): nco, 1, rank, t. (groups=nco). + cconv_ = F.conv2d(conv_in, conv_filt, padding=(0, max_shift), groups=nco_) + cconv[istart:iend] = cconv_[:, :, 0, :].permute(1, 0, 2) # nco, nup, time + + # more stringent covisibility + pair_survived = slice(None) + if conv_ignore_threshold > 0: + max_val = cconv.reshape(nco, -1).abs().max(dim=1).values + pair_survived = max_val > conv_ignore_threshold + cconv = cconv[pair_survived] + covisible_a = covisible_a[pair_survived] + covisible_b = covisible_b[pair_survived] + + return covisible_a, covisible_b, cconv, pair_survived + + +# -- helpers + + +def deduplicated_pairs( + shifted_temp_ix_a, + shifted_temp_ix_b, + spatialsing_a, + spatialsing_b, + temp_ix_a, + temp_ix_b, + shift_a=None, + shift_b=None, + cooccurrence=None, + conv_ignore_threshold=0.0, + geom=None, + registered_geom=None, + reg_geom_kdtree=None, + match_distance=None, +): + """Choose a set of pairs of indices from group A and B to convolve + + Some pairs of shifted templates don't overlap, so we don't need to convolve them. + Some pairs of shifted templates never show up in the recording at the same time + (what this code calls "cooccurrence"), so we don't need to convolve them. + We don't need to convolve the same pair of templates twice, just where the indices + are ordered (shifted_temp_ix_a <= shifted_temp_ix_b). + + More complicated: for each shift, a certain set of registered template channels + survives. 
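The low-rank identity quoted in the ccorrelate_up docstring above can be checked directly in numpy; this sketch uses random factors and np.correlate instead of the batched grouped conv2d in the code.

    import numpy as np

    rng = np.random.default_rng(0)
    t, rank, nchan = 21, 3, 7
    Ta, Sa = rng.normal(size=(t, rank)), rng.normal(size=(rank, nchan))
    Tb, Sb = rng.normal(size=(t, rank)), rng.normal(size=(rank, nchan))
    A, B = Ta @ Sa, Tb @ Sb  # dense (t, nchan) templates

    # direct channel-summed cross-correlation at all lags
    direct = sum(np.correlate(B[:, c], A[:, c], mode="full") for c in range(nchan))

    # low-rank route: project template a through b's spatial components, then
    # correlate in rank space (rank is much smaller than nchan in practice)
    A_in_b = A @ Sb.T  # (t, rank)
    lowrank = sum(np.correlate(Tb[:, r], A_in_b[:, r], mode="full") for r in range(rank))

    assert np.allclose(direct, lowrank)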
Given that the some set of visible channels has survived for a pair of + templates at shifts shift_a and shift_b, their cross-correlation at these shifts + is the same as the one at shift_a_prime and shift_b_prime if the same exact channels + survived at shift_a_prime and shift_b_prime and if + shift_a-shift_b == shift_a_prime-shift_b_prime. + + Returns + ------- + pair_ix_a, pair_ix_b + Size < original number of shifted templates a,b + The indices of shifted templates which overlap enough to be + co-visible. So, these are subsets of shifted_temp_ix_a,b + dedup_ix + Size == pair_ix_a,b size + Subsets of conv_ix_a,b, so that the xcorr of templates + shifted_temp_ix_a[pair_ix_a[i]], shifted_temp_ix_b[pair_ix_b[i]] + is the same as that of + shifted_temp_ix_a[conv_ix[dedup_ix[i]], conv_ix[dedup_ix[i]]] + conv_ix + Size < original number of shifted templates a,b + Pairs of templates which should actually be convolved + shift_diff + Optional. if not None, same size as pair_ix_a + shift_a - shift_b for this pair + """ + # check spatially overlapping + chan_amp_a = torch.sqrt(torch.square(spatialsing_a).sum(1)) + chan_amp_b = torch.sqrt(torch.square(spatialsing_b).sum(1)) + pair = chan_amp_a @ chan_amp_b.T + pair = pair > conv_ignore_threshold + + # co-occurrence + if cooccurrence is not None: + pair *= cooccurrence + + # mask out lower triangle + pair *= shifted_temp_ix_a[:, None] <= shifted_temp_ix_b[None, :] + pair_ix_a, pair_ix_b = torch.nonzero(pair, as_tuple=True) + nco = pair_ix_a.numel() + + # if no shifting, deduplication is the identity + if shift_a is None: + assert shift_b is None + nco_range = torch.arange(nco, device=pair_ix_a.device) + return pair_ix_a, pair_ix_b, nco_range, nco_range, None + + # shift deduplication. algorithm: + # 1 for each shifted template, determine the set of registered channels + # which it occupies + # 2 assign each such set an ID (an "active channel ID") + # - // then a pair of shifted templates' xcorr is a function of the pair + # // of active channel IDs and the difference of shifts + # 3 figure out the set of unique (active chan id a, active chan id b, shift diff a,b) + # combinations in each pair of units + + # 1: get active channel neighborhoods as many-hot len(reg_geom)-vectors + active_chans_a = drift_util.get_waveforms_on_static_channels( + (chan_amp_a > 0).numpy(force=True), + geom, + n_pitches_shift=-shift_a, + registered_geom=registered_geom, + target_kdtree=reg_geom_kdtree, + match_distance=match_distance, + fill_value=0, + ) + active_chans_b = drift_util.get_waveforms_on_static_channels( + (chan_amp_b > 0).numpy(force=True), + geom, + n_pitches_shift=-shift_b, + registered_geom=registered_geom, + target_kdtree=reg_geom_kdtree, + match_distance=match_distance, + fill_value=0, + ) + # 2: assign IDs to each such vector + _, active_chan_ids_a = np.unique(active_chans_a, axis=0, return_inverse=True) + _, active_chan_ids_b = np.unique(active_chans_b, axis=0, return_inverse=True) + + # 3 + temp_ix_a = temp_ix_a[pair_ix_a] + temp_ix_b = temp_ix_b[pair_ix_b] + # get the relative shifts + shift_a = shift_a[pair_ix_a] + shift_b = shift_b[pair_ix_b] + shift_diff = shift_a - shift_b + + # figure out combinations + conv_determiners = np.c_[ + temp_ix_a, + active_chan_ids_a[pair_ix_a], + temp_ix_b, + active_chan_ids_b[pair_ix_b], + shift_diff, + ] + # conv_ix: indices of unique determiners + # dedup_ix: which representative does each pair belong to + _, conv_ix, dedup_ix = np.unique( + conv_determiners, axis=0, return_index=True, return_inverse=True + ) + + 
return pair_ix_a, pair_ix_b, dedup_ix, conv_ix, shift_diff + + + +def _coarse_approx(cconv, cconv_ix, conv_ix_a, conv_ix_b, unit_a, unit_b, p): + # figure out coarse templates to correlate + conv_ix_a = conv_ix_a.cpu() + conv_ix_b = conv_ix_b.cpu() + conv_unit_a = unit_a[conv_ix_a] + conv_unit_b = unit_b[conv_ix_b] + coarse_units_a = np.unique(conv_unit_a) + coarse_units_b = np.unique(conv_unit_b) + coarsecovis = np.zeros((coarse_units_a.size, coarse_units_b.size), dtype=bool) + coarsecovis[ + np.searchsorted(coarse_units_a, conv_unit_a), + np.searchsorted(coarse_units_b, conv_unit_b), + ] = True + + # correlate them + coarse_ix_a, coarse_ix_b, coarse_cconv = ccorrelate_up( + p.coarse_spatial_singular[coarse_units_a].to(p.device), + p.coarse_temporal[coarse_units_a].to(p.device), + p.coarse_spatial_singular[coarse_units_b].to(p.device), + p.coarse_temporal[coarse_units_b].unsqueeze(2).to(p.device), + conv_ignore_threshold=p.conv_ignore_threshold, + max_shift=p.max_shift, + covisible_mask=torch.as_tensor(coarsecovis, device=p.device), + ) + if coarse_ix_a is None: + return cconv, cconv_ix + + coarse_units_a = np.atleast_1d(coarse_units_a[coarse_ix_a.cpu()]) + coarse_units_b = np.atleast_1d(coarse_units_b[coarse_ix_b.cpu()]) + + # find coarse units which well summarize the fine cconvs + for coarse_unit_a, coarse_unit_b, conv in zip( + coarse_units_a, coarse_units_b, coarse_cconv + ): + # check good approx. if not, continue + in_pair = np.flatnonzero( + (conv_unit_a == coarse_unit_a) & (conv_unit_b == coarse_unit_b) + ) + assert in_pair.size + fine_cconvs = cconv[cconv_ix[in_pair]] + approx_err = (fine_cconvs - conv[None]).abs().max() + if not approx_err < p.coarse_approx_error_threshold: + continue + + # replace first fine cconv with the coarse cconv + cconv[cconv_ix[in_pair[0]]] = conv + # set all fine cconv ix to the index of that first one + cconv_ix[in_pair] = cconv_ix[in_pair[0]] + + # re-index and subset cconvs + cconv_ix_subset, new_cconv_ix = np.unique(cconv_ix, return_inverse=True) + cconv = cconv[cconv_ix_subset] + return cconv, new_cconv_ix diff --git a/src/dartsort/templates/template_util.py b/src/dartsort/templates/template_util.py index 780fc7d7..55320054 100644 --- a/src/dartsort/templates/template_util.py +++ b/src/dartsort/templates/template_util.py @@ -245,3 +245,64 @@ def temporally_upsample_templates( ) upsampled_templates = upsampled_templates.astype(templates.dtype) return upsampled_templates + + +def default_n_upsamples_map(ptps): + return 4 ** (ptps // 2) + + +def sparse_upsampled_templates( + templates, + ptps=None, + max_upsample=8, + n_upsamples_map=default_n_upsamples_map, + kind="cubic", +): + """Sparsely store fewer temporally upsampled copies of lower amplitude templates + + Returns + ------- + sparse_upsampled_templates : array (n_sparse_upsampled_templates, spike_length_samples) + sparse_upsampling_map : array (n_templates, max_upsample) + sparse_upsampled_templates[sparse_upsampling_map[unit, j]] is an approximation + of the jth upsampled template for this unit. for low-amplitude units, + sparse_upsampling_map[unit] will have fewer unique entries, corresponding + to fewer saved upsampled copies for that unit. + """ + n_templates = templates.shape[0] + + # how many copies should each unit get? 
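The conv_ix / dedup_ix convention returned by deduplicated_pairs above comes directly from np.unique over the determiner rows; a toy example with made-up rows:

    import numpy as np

    # identical rows stand for pairs whose cross-correlation would be identical
    determiners = np.array([[0, 5], [1, 3], [0, 5], [1, 3], [2, 2]])
    _, conv_ix, dedup_ix = np.unique(
        determiners, axis=0, return_index=True, return_inverse=True
    )
    print(conv_ix)   # [0 1 4]      -> the pairs that actually get convolved
    print(dedup_ix)  # [0 1 0 1 2]  -> which stored conv each pair points to
    assert np.array_equal(determiners, determiners[conv_ix][dedup_ix.ravel()])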
+ # sometimes users may pass temporal SVD components in instead of templates, + # so we allow them to pass in the amplitudes of the actual templates + if ptps is None: + ptps = templates.ptp(1).max(1) + assert ptps.shape == (n_templates,) + if n_upsamples_map is None: + n_upsamples = np.full(n_templates, max_upsample) + else: + n_upsamples = np.clip(n_upsamples_map(ptps), 1, max_upsample).astype(int) + + # build the sparse upsampling map + sparse_upsampling_map = np.zeros((n_templates, max_upsample), dtype=int) + upsampling_indices = [] + template_indices = [] + current_sparse_index = 0 + for i, nup in enumerate(n_upsamples): + compression = max_upsample // nup + nup = max_upsample // compression # handle divisibility failure + + # new sparse indices + sparse_upsampling_map[i] = current_sparse_index + np.arange(nup).repeat(compression) + current_sparse_index += nup + + # indices of the templates to keep in the full array of upsampled templates + upsampling_indices.extend(compression * np.arange(nup)) + template_indices.extend([i] * nup) + + # get the upsampled templates + all_upsampled_templates = temporally_upsample_templates( + templates, temporal_upsampling_factor=max_upsample, kind=kind + ) + sparse_upsampled_templates = all_upsampled_templates[template_indices, upsampling_indices] + + return sparse_upsampled_templates, sparse_upsampling_map From 21ea5925521670e01493fa494de3ef28281d9e7a Mon Sep 17 00:00:00 2001 From: julien Date: Mon, 6 Nov 2023 11:52:49 -0500 Subject: [PATCH 13/49] dipole localization --- src/dartsort/localize/localize_torch.py | 126 ++++++++++++++++++++---- src/dartsort/transform/localize.py | 62 ++++++++++++ src/spike_psvae/chunk_features.py | 3 +- src/spike_psvae/denoise.py | 16 ++- src/spike_psvae/localize_index.py | 84 +++++++++++----- src/spike_psvae/subtract.py | 12 +-- 6 files changed, 244 insertions(+), 59 deletions(-) diff --git a/src/dartsort/localize/localize_torch.py b/src/dartsort/localize/localize_torch.py index d7e4d20b..c08a615f 100644 --- a/src/dartsort/localize/localize_torch.py +++ b/src/dartsort/localize/localize_torch.py @@ -19,6 +19,7 @@ def localize_amplitude_vectors( dtype=torch.double, y0=1.0, levenberg_marquardt_kwargs=None, + th_dipole_proj_dist=250.0, ): """Localize a bunch of amplitude vectors with torch @@ -59,7 +60,7 @@ def localize_amplitude_vectors( # maybe this will become a wrapper function if we want more models. 
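To see what the amplitude-dependent rule above produces, here is a toy run of the sparse upsampling map construction with made-up PTP values.

    import numpy as np

    # default rule: n_upsamples = clip(4 ** (ptp // 2), 1, max_upsample)
    ptps = np.array([1.5, 3.0, 5.0, 9.0])  # made-up template amplitudes
    max_upsample = 8
    n_upsamples = np.clip(4 ** (ptps // 2), 1, max_upsample).astype(int)
    print(n_upsamples)  # [1 4 8 8]

    # within-template pattern of the sparse upsampling map (offsets omitted):
    for i, nup in enumerate(n_upsamples):
        compression = max_upsample // nup
        print(i, np.arange(max_upsample // compression).repeat(compression))
    # 0 [0 0 0 0 0 0 0 0]
    # 1 [0 0 1 1 2 2 3 3]
    # 2 [0 1 2 3 4 5 6 7]
    # 3 [0 1 2 3 4 5 6 7]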
# and, this is why we return a dict, different models will have different # parameters - assert model in ("com", "pointsource") + assert model in ("com", "pointsource", "dipole") n_spikes, c = amplitude_vectors.shape n_channels_tot = len(geom) if channel_index is None: @@ -120,35 +121,76 @@ def localize_amplitude_vectors( # fixed constants in regularizers like the log barrier max_amplitudes = torch.max(amplitude_vectors, dim=1).values normalized_amp_vecs = amplitude_vectors / max_amplitudes[:, None] - + # -- torch optimize # initialize with center of mass locs = torch.column_stack((xcom, torch.full_like(xcom, y0), zcom)) - if levenberg_marquardt_kwargs is None: - levenberg_marquardt_kwargs = {} - locs, i = batched_levenberg_marquardt( - locs, - vmap_point_source_grad_and_mse, - vmap_point_source_hessian, - extra_args=(normalized_amp_vecs, in_probe_mask, local_geoms), - **levenberg_marquardt_kwargs, - ) - # finish: get alpha closed form - x, y0, z_rel = locs.T - y = F.softplus(y0) - alpha = vmap_point_source_find_alpha( - amplitude_vectors, in_probe_mask, x, y, z_rel, local_geoms - ) - z_abs = z_rel + geom[main_channels, 1] + + if model == "pointsource": + + if levenberg_marquardt_kwargs is None: + levenberg_marquardt_kwargs = {} + locs, i = batched_levenberg_marquardt( + locs, + vmap_point_source_grad_and_mse, + vmap_point_source_hessian, + extra_args=(normalized_amp_vecs, in_probe_mask, local_geoms), + **levenberg_marquardt_kwargs, + ) - return dict(x=x, y=y, z_rel=z_rel, z_abs=z_abs, alpha=alpha) + # finish: get alpha closed form + x, y0, z_rel = locs.T + y = F.softplus(y0) + alpha = vmap_point_source_find_alpha( + amplitude_vectors, in_probe_mask, x, y, z_rel, local_geoms + ) + z_abs = z_rel + geom[main_channels, 1] + return dict(x=x, y=y, z_rel=z_rel, z_abs=z_abs, alpha=alpha) + + if model == "dipole": + if levenberg_marquardt_kwargs is None: + levenberg_marquardt_kwargs = {} + locs, i = batched_levenberg_marquardt( + locs, + vmap_dipole_grad_and_mse, + vmap_dipole_hessian, + extra_args=(normalized_amp_vecs, local_geoms), + **levenberg_marquardt_kwargs, + ) + + x, y0, z_rel = locs.T + y = F.softplus(y0) + projected_dist = vmap_dipole_find_projection_distance( + normalized_amp_vecs, x, y, z_rel, local_geoms + ) + + # if projected_dist>th_dipole_proj_dist: return the loc values from pointsource -# -- point source model library functions + pointsource_spikes = torch.nonzero(projected_dist>th_dipole_proj_dist, as_tuple=True) + + locs_pointsource_spikes, i = batched_levenberg_marquardt( + locs[pointsource_spikes], + vmap_point_source_grad_and_mse, + vmap_point_source_hessian, + extra_args=(normalized_amp_vecs[pointsource_spikes], in_probe_mask, local_geoms[pointsource_spikes]), + **levenberg_marquardt_kwargs, + ) + x_pointsource_spikes, y0_pointsource_spikes, z_rel_pointsource_spikes = locs.T + y_pointsource_spikes = F.softplus(y0_pointsource_spikes) + + x[pointsource_spikes] = x_pointsource_spikes + y[pointsource_spikes] = y_pointsource_spikes + z_rel[pointsource_spikes] = z_rel_pointsource_spikes + + z_abs = z_rel + geom[main_channels, 1] + + return dict(x=x, y=y, z_rel=z_rel, z_abs=z_abs, alpha=projected_dist) -def point_source_amplitude_at(x, y, z, alpha, local_geom): +# -- point source / dipole model library functions +def point_source_amplitude_at(x, y, z, local_geom): """Point source model predicted amplitude at local_geom given location""" dxs = torch.square(x - local_geom[:, 0]) dzs = torch.square(z - local_geom[:, 1]) @@ -165,6 +207,21 @@ def point_source_find_alpha(amp_vec, 
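For the point-source branch, the closed-form brightness estimate used by point_source_find_alpha can be sanity-checked on synthetic amplitudes; the geometry and values below are made up.

    import numpy as np

    # toy local geometry (x, z offsets in um) and a known source
    local_geom = np.array([[0.0, -20.0], [0.0, 0.0], [0.0, 20.0], [32.0, 0.0]])
    x, y, z, alpha_true = 10.0, 15.0, 5.0, 30.0

    r = np.sqrt((x - local_geom[:, 0]) ** 2 + y ** 2 + (z - local_geom[:, 1]) ** 2)
    ptp = alpha_true / r  # point-source amplitudes
    q = 1.0 / r           # unit-brightness prediction

    alpha_hat = (q * ptp).sum() / (q * q).sum()
    assert np.isclose(alpha_hat, alpha_true)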
channel_mask, x, y, z, local_geoms):
     )
     return alpha
 
+def dipole_find_projection_distance(normalized_amp_vec, x, y, z, local_geom):
+    """Closed-form dipole fit; returns the projected distance used to gate the dipole model"""
+
+    dxs = x - local_geom[:, 0]
+    dzs = z - local_geom[:, 1]
+    dys = y * torch.ones_like(dxs)
+    duv = torch.column_stack((dxs, dys, dzs))
+    X = duv / torch.pow(torch.square(duv).sum(dim=1, keepdim=True), 3/2)
+    beta = torch.linalg.solve(torch.matmul(X.T, X), torch.matmul(X.T, normalized_amp_vec))
+    beta /= torch.sqrt(torch.square(beta).sum())
+    dipole_planar_direction = torch.sqrt(torch.square(beta[0]) + torch.square(beta[2]))
+    closest_chan = torch.square(duv).sum(1).argmin()
+    min_duv = duv[closest_chan]
+    val_th = torch.sqrt(torch.square(min_duv).sum()) / dipole_planar_direction
+    return val_th
 
 def point_source_mse(
     loc, amplitude_vector, channel_mask, local_geom, logbarrier=True
@@ -204,8 +261,35 @@ def point_source_mse(
     #     obj -= torch.log(1000.0 - torch.sqrt(torch.square(x) + torch.square(z))).sum() / 10000.0
     return obj
 
+def dipole_mse(loc, amplitude_vector, local_geom, logbarrier=True):
+    """Dipole model objective (MSE of normalized amplitudes) at local_geom given location"""
+
+    x, y0, z = loc
+    y = F.softplus(y0)
+
+    dxs = x - local_geom[:, 0]
+    dzs = z - local_geom[:, 1]
+    dys = y * torch.ones_like(dxs)
+    # amplitude_vector is expected to be max-normalized (normalized_amp_vecs above)
+    duv = torch.column_stack((dxs, dys, dzs))
+
+    X = duv / torch.pow(torch.square(duv).sum(dim=1, keepdim=True), 3/2)
+
+    beta = torch.linalg.solve(torch.matmul(X.T, X), torch.matmul(X.T, amplitude_vector))
+    qtq = torch.matmul(X, beta)
+
+    obj = torch.square(amplitude_vector - qtq).mean()
+    if logbarrier:
+        obj -= torch.log(10.0 * y) / 10000.0
+
+    return obj
+
 
 # vmapped functions for use in the optimizer, and might be handy for users too
 vmap_point_source_grad_and_mse = vmap(grad_and_value(point_source_mse))
 vmap_point_source_hessian = vmap(hessian(point_source_mse))
 vmap_point_source_find_alpha = vmap(point_source_find_alpha)
+
+vmap_dipole_grad_and_mse = vmap(grad_and_value(dipole_mse))
+vmap_dipole_hessian = vmap(hessian(dipole_mse))
+vmap_dipole_find_projection_distance = vmap(dipole_find_projection_distance)
diff --git a/src/dartsort/transform/localize.py b/src/dartsort/transform/localize.py
index 29786441..083318fe 100644
--- a/src/dartsort/transform/localize.py
+++ b/src/dartsort/transform/localize.py
@@ -63,3 +63,65 @@ def transform(self, waveforms, max_channels=None):
             ]
         )
         return localizations
+
+class DipoleLocalization(BaseWaveformFeaturizer):
+    """Order of output columns: x, y, z_abs, alpha"""
+
+    default_name = "dipole_localizations"
+    shape = (4,)
+    dtype = torch.double
+
+    def __init__(
+        self,
+        channel_index,
+        geom,
+        radius=None,
+        n_channels_subset=None,
+        logbarrier=True,
+        amplitude_kind="peak",
+        model="dipole",
+        name=None,
+        name_prefix="",
+    ):
+        assert amplitude_kind in ("peak", "ptp")
+        super().__init__(
+            geom=geom,
+            channel_index=channel_index,
+            name=name,
+            name_prefix=name_prefix,
+        )
+        self.amplitude_kind = amplitude_kind
+        self.radius = radius
+        self.n_channels_subset = n_channels_subset
+        self.logbarrier = logbarrier
+        self.model = model
+
+    def transform(self, waveforms, max_channels=None):
+        # get amplitude vectors
+        if self.amplitude_kind == "peak":
+            ampvecs = waveforms.abs().max(dim=1).values
+        elif self.amplitude_kind == "ptp":
+            ampvecs = ptp(waveforms, dim=1)
+
+        with torch.enable_grad():
+            loc_result = localize_amplitude_vectors(
+                ampvecs,
+                self.geom,
+                max_channels,
+                channel_index=self.channel_index,
+                radius=self.radius,
+                n_channels_subset=self.n_channels_subset,
+                logbarrier=self.logbarrier,
+                model=self.model,
+                dtype=self.dtype,
+            )
+
+        
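The dipole fit above reduces to an ordinary least-squares solve for the moment; this numpy sketch with a synthetic source (all values made up) recovers the moment exactly.

    import numpy as np

    # synthetic dipole: amplitude at channel c is (u_c . m) / ||u_c||^3, where u_c
    # is the displacement from the source to channel c and m is the moment
    local_geom = np.c_[np.tile([0.0, 32.0], 5), np.repeat(np.arange(5) * 20.0, 2)]
    x, y, z = 10.0, 20.0, 40.0
    m_true = np.array([0.3, 1.0, -0.5])

    duv = np.c_[x - local_geom[:, 0], np.full(len(local_geom), y), z - local_geom[:, 1]]
    X = duv / np.power(np.square(duv).sum(1, keepdims=True), 1.5)
    amps = X @ m_true

    # least-squares moment, as in dipole_mse above
    m_hat = np.linalg.solve(X.T @ X, X.T @ amps)
    assert np.allclose(m_hat, m_true)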
localizations = torch.column_stack( + [ + loc_result["x"], + loc_result["y"], + loc_result["z_abs"], + loc_result["alpha"], + ] + ) + return localizations diff --git a/src/spike_psvae/chunk_features.py b/src/spike_psvae/chunk_features.py index e7940148..a0defb89 100644 --- a/src/spike_psvae/chunk_features.py +++ b/src/spike_psvae/chunk_features.py @@ -392,6 +392,7 @@ def transform( else: if self.ptp_precision_decimals is not None: ptps = np.round(ptps, decimals=self.ptp_precision_decimals) + ( xs, ys, @@ -487,7 +488,7 @@ def raw_fit(self, wfs, max_channels): self.needs_fit = False self.dtype = self.tpca.components_.dtype - self.n_components = self.tpca.n_components + self.n_components = self.n_components self.components_ = self.tpca.components_ self.mean_ = self.tpca.mean_ if self.centered: # otherwise SVD diff --git a/src/spike_psvae/denoise.py b/src/spike_psvae/denoise.py index a42e992f..4ecf62cf 100644 --- a/src/spike_psvae/denoise.py +++ b/src/spike_psvae/denoise.py @@ -95,7 +95,9 @@ def phase_shift_and_hallucination_idx_preshift(waveforms_roll_denoise, waveforms which = slice(offset-10, offset+10) - d_s_corr = wfs_corr(waveforms_roll_denoise[:, which], waveforms_roll[:, which])#torch.sum(wfs_denoised[which]*chan_wfs[which], 1)/torch.sqrt(torch.sum(chan_wfs[which]*chan_wfs[which],1) * torch.sum(wfs_denoised[which]*wfs_denoised[which],1)) ## didn't use which at the beginning! check whether this changes the results + d_s_corr = wfs_corr(waveforms_roll_denoise[:, which], waveforms_roll[:, which]) + # torch.sum(wfs_denoised[which]*chan_wfs[which], 1)/torch.sqrt(torch.sum(chan_wfs[which]*chan_wfs[which],1) * torch.sum(wfs_denoised[which]*wfs_denoised[which],1)) + # didn't use which at the beginning! check whether this changes the results halu_idx = (ptp(waveforms_roll_denoise, 1) decr_ptp[parents_rel].max(): decr_ptp[c] *= decr_ptp[parents_rel].max() / decr_ptp[c] + # decreasing_ptps[i] = decr_ptp # apply decreasing ptps to the original waveforms rescale = (decreasing_ptps / orig_ptps)[:, None, :] + if is_torch: rescale = torch.as_tensor(rescale, device=waveforms.device) if in_place: diff --git a/src/spike_psvae/localize_index.py b/src/spike_psvae/localize_index.py index ea8708a9..b90f8d73 100644 --- a/src/spike_psvae/localize_index.py +++ b/src/spike_psvae/localize_index.py @@ -65,9 +65,9 @@ def ptp_at_dipole(x1, y1, z1, alpha, x2, y2, z2): ) - 1 / np.sqrt( - np.square(x2 - local_geom[:, 0]) - + np.square(z2 - local_geom[:, 1]) - + np.square(y2) + np.square(x2 + x1 - local_geom[:, 0]) + + np.square(z2 + z1 - local_geom[:, 1]) + + np.square(y2 + y1) ) ) return ptp_dipole_out @@ -107,18 +107,24 @@ def mse(loc): # - (np.log1p(10.0 * y) / 10000.0 if logbarrier else 0) # ) - def mse_dipole(x_in): - x1 = x_in[0] - y1 = x_in[1] - z1 = x_in[2] - x2 = x_in[3] - y2 = x_in[4] - z2 = x_in[5] - q = ptp_at_dipole(x1, y1, z1, 1.0, x2, y2, z2) - alpha = (q * ptp).sum() / (q * q).sum() - return np.square( - ptp - ptp_at_dipole(x1, y1, z1, alpha, x2, y2, z2) - ).mean() - (np.log1p(10.0 * y1) / 10000.0 if logbarrier else 0) + def mse_dipole(loc): + x, y, z = loc + # q = ptp_at(x, y, z, 1.0) + # alpha = (q * (ptp / maxptp - delta)).sum() / (q * q).sum() + duv = np.c_[ + x - local_geom[:, 0], + np.broadcast_to(y, ptp.shape), + z - local_geom[:, 1], + ] + X = duv / np.power(np.square(duv).sum(axis=1, keepdims=True), 3/2) + beta = np.linalg.solve(X.T @ X, X.T @ (ptp / maxptp)) + qtq = X @ beta + return ( + np.square(ptp / maxptp - qtq).mean() + # np.square(ptp / maxptp - delta - ptp_at(x, y, z, alpha)).mean() + 
# np.square(np.maximum(0, ptp / maxptp - ptp_at(x, y, z, alpha))).mean() + - np.log1p(10.0 * y) / 10000.0 + ) if model == "pointsource": result = minimize( @@ -146,24 +152,51 @@ def mse_dipole(x_in): result = minimize( mse_dipole, - x0=[xcom, Y0, zcom, xcom + 1, Y0 + 1, zcom + 1], + x0=[xcom, Y0, zcom], bounds=[ (local_geom[:, 0].min() - DX, local_geom[:, 0].max() + DX), (1e-4, 250), (-DZ, DZ), - (-100, 100), - (-100, 100), - (-100, 100), ], ) # print(result) - bx, by, bz_rel, bpx, bpy, bpz = result.x - - q = ptp_at_dipole(bx, by, bz_rel, 1.0, bpx, bpy, bpz) - - balpha = (q * ptp).sum() / (q * q).sum() - return bx, by, bz_rel, balpha + bx, by, bz_rel = result.x + + duv = np.c_[ + bx - local_geom[:, 0], + np.broadcast_to(by, ptp.shape), + bz_rel - local_geom[:, 1], + ] + X = duv / np.power(np.square(duv).sum(axis=1, keepdims=True), 3/2) + beta = np.linalg.solve(X.T @ X, X.T @ (ptp / maxptp)) + beta /= np.sqrt(np.square(beta).sum()) + dipole_planar_direction = np.sqrt(np.square(beta[[0, 2]]).sum()) + closest_chan = np.square(duv).sum(1).argmin() + min_duv = duv[closest_chan] + + val_th = np.sqrt(np.square(min_duv).sum())/dipole_planar_direction + + # reparameterized_dist = np.sqrt(np.square(min_duv[0]/beta[2]) + np.square(min_duv[2]/beta[0]) + # + np.square(min_duv[1]/beta[1])) + + if val_th<250: + return bx, by, bz_rel, val_th + else: + result = minimize( + mse, + x0=[xcom, Y0, zcom], + bounds=[ + (local_geom[:, 0].min() - DX, local_geom[:, 0].max() + DX), + (1e-4, 250), + (-DZ, DZ), + ], + ) + # print(result) + bx, by, bz_rel = result.x + q = ptp_at(bx, by, bz_rel, 1.0) + balpha = (ptp * q).sum() / np.square(q).sum() + return bx, by, bz_rel, val_th else: raise NameError("Wrong localization model") @@ -230,6 +263,5 @@ def localize_ptps_index( ys[n] = y z_rels[n] = z_rel alphas[n] = alpha - z_abss = z_rels + geom[maxchans, 1] return xs, ys, z_rels, z_abss, alphas diff --git a/src/spike_psvae/subtract.py b/src/spike_psvae/subtract.py index d1e3d62a..0cc936e9 100644 --- a/src/spike_psvae/subtract.py +++ b/src/spike_psvae/subtract.py @@ -668,11 +668,11 @@ def subtraction_binary( n_channels = geom.shape[0] recording = sc.read_binary( - standardized_bin, - sampling_rate, - n_channels, - binary_dtype, - time_axis=time_axis, + file_paths=standardized_bin, + sampling_frequency=sampling_rate, + num_channels=n_channels, + dtype=binary_dtype, + time_axis=0, is_filtered=True, ) @@ -1077,7 +1077,7 @@ def subtraction_batch( batch_data_folder / f"{prefix}{f.name}.npy", feat, ) - + denoised_wfs = full_denoising( cleaned_wfs, spike_index[:, 1], From 6d9e0c0c9c49e29c688d4d804ce24feb994cf8cc Mon Sep 17 00:00:00 2001 From: Charlie Windolf Date: Mon, 6 Nov 2023 14:19:04 -0500 Subject: [PATCH 14/49] Check in new code before removing old --- src/dartsort/templates/new_pairwise_util.py | 958 ++++++++++++++++++++ src/dartsort/templates/pairwise.py | 105 +++ src/dartsort/templates/template_util.py | 80 +- src/dartsort/util/drift_util.py | 113 +++ 4 files changed, 1236 insertions(+), 20 deletions(-) create mode 100644 src/dartsort/templates/new_pairwise_util.py diff --git a/src/dartsort/templates/new_pairwise_util.py b/src/dartsort/templates/new_pairwise_util.py new file mode 100644 index 00000000..a3cf3516 --- /dev/null +++ b/src/dartsort/templates/new_pairwise_util.py @@ -0,0 +1,958 @@ +from __future__ import annotations # allow forward type references + +from collections import namedtuple +from dataclasses import dataclass, fields +from pathlib import Path +from typing import Iterator, Optional + +import h5py 
+import numpy as np +import torch +import torch.nn.functional as F +from dartsort.util import drift_util +from dartsort.util.multiprocessing_util import get_pool +from scipy.spatial import KDTree +from scipy.spatial.distance import pdist +from tqdm.auto import tqdm + +from . import template_util, templates + + +def compressed_convolve_to_h5( + output_hdf5_filename, + template_data: templates.TemplateData, + low_rank_templates: template_util.LowRankTemplates, + compressed_upsampled_temporal: template_util.CompressedUpsampledTemplates, + chunk_time_centers_s: Optional[np.ndarray] = None, + motion_est=None, + geom: Optional[np.ndarray] = None, + reg_geom: Optional[np.ndarray] = None, + conv_ignore_threshold=0.0, + coarse_approx_error_threshold=0.0, + conv_batch_size=128, + units_batch_size=8, + overwrite=False, + device=None, + n_jobs=0, + show_progress=True, +): + """Convolve all pairs of templates and store result in a .h5 + + See pairwise.CompressedPairwiseConvDB for how to read the + resulting convolutions back. + + This runs compressed_convolve_pairs in a loop over chunks + of unit pairs, so that it's not all done in memory at one time, + and so that it can be done in parallel. + """ + if overwrite: + pass # TODO + + # construct indexing helpers + template_shift_index = drift_util.get_shift_and_unit_pairs( + chunk_time_centers_s, + geom, + template_data, + motion_est=motion_est, + ) + upsampled_shifted_template_index = get_upsampled_shifted_template_index( + template_shift_index, compressed_upsampled_temporal + ) + + chunk_res_iterator = iterate_compressed_pairwise_convolutions( + template_data=template_data, + low_rank_templates=low_rank_templates, + compressed_upsampled_temporal=compressed_upsampled_temporal, + geom=geom, + reg_geom=reg_geom, + conv_ignore_threshold=conv_ignore_threshold, + coarse_approx_error_threshold=coarse_approx_error_threshold, + max_shift="full", + conv_batch_size=conv_batch_size, + units_batch_size=units_batch_size, + device=device, + n_jobs=n_jobs, + show_progress=show_progress, + ) + + pconv_index = np.zeros( + ( + template_shift_index.n_shifted_templates, + upsampled_shifted_template_index.n_upsampled_shifted_templates, + ), + dtype=int, + ) + n_pconvs = 1 + with h5py.File(output_hdf5_filename, "w") as h5: + # resizeable pconv dataset + spike_length_samples = template_data.templates.shape[1] + pconv = h5.create_dataset( + "pconv", + dtype=np.float32, + shape=(1, 2 * spike_length_samples - 1), + maxshape=(None, 2 * spike_length_samples - 1), + chunks=(128, 2 * spike_length_samples - 1), + ) + + for chunk_res in chunk_res_iterator: + if chunk_res is None: + continue + + # get shifted template indices for A + shifted_temp_ix_a = template_shift_index.template_shift_index[ + chunk_res.template_indices_a, + chunk_res.shift_indices_a, + ] + + # upsampled shifted template indices for B + up_shifted_temp_ix_b = upsampled_shifted_template_index.upsampled_shifted_template_index[ + chunk_res.template_indices_b, + chunk_res.shift_indices_b, + chunk_res.upsampling_indices_b, + ] + + # store new set of indices + new_pconv_indices = chunk_res.compression_index + n_pconvs + pconv_index[shifted_temp_ix_a, up_shifted_temp_ix_b] = new_pconv_indices + + # store new pconvs + n_new_pconvs = chunk_res.compressed_conv.shape[0] + pconv.resize(n_pconvs + n_new_pconvs, axis=0) + pconv[n_pconvs:] = chunk_res.pconv + + n_pconvs += n_new_pconvs + + # write fixed size outputs + h5.create_dataset("shifts", data=template_shift_index.all_pitch_shifts) + 
h5.create_dataset("shifted_template_index", data=template_shift_index.template_shift_index) + h5.create_dataset("upsampled_shifted_template_index", data=upsampled_shifted_template_index.upsampled_shifted_template_index) + h5.create_dataset("pconv_index", data=pconv_index) + + return output_hdf5_filename + + +def iterate_compressed_pairwise_convolutions( + template_data: templates.TemplateData, + low_rank_templates: template_util.LowRankTemplates, + compressed_upsampled_temporal: template_util.CompressedUpsampledTemplates, + template_shift_index: drift_util.TemplateShiftIndex, + upsampled_shifted_template_index: UpsampledShiftedTemplateIndex, + geom: Optional[np.ndarray] = None, + reg_geom: Optional[np.ndarray] = None, + conv_ignore_threshold=0.0, + coarse_approx_error_threshold=0.0, + max_shift="full", + conv_batch_size=128, + units_batch_size=8, + device=None, + n_jobs=0, + show_progress=True, +) -> Iterator[Optional[CompressedConvResult]]: + """A generator of CompressedConvResults capturing all pairs of templates + + + Runs the function compressed_convolve_pairs on chunks of units. + + This is a helper function for parallelizing computation of cross correlations + between pairs of templates. There are too many to store all the results in + memory, so this is a generator yielding a chunk at a time. Callers may + process the results differently. + """ + # construct drift-related helper data if needed + n_shifts = template_shift_index.all_pitch_shifts.size + do_shifting = n_shifts > 1 + geom_kdtree = reg_geom_kdtree = match_distance = None + if do_shifting: + geom_kdtree = KDTree(geom) + reg_geom_kdtree = KDTree(reg_geom) + match_distance = pdist(geom).min() / 2 + + # make chunks + units = np.unique(template_data.unit_ids) + jobs = [] + for start_a in range(0, units.size, units_batch_size): + end_a = min(start_a + units_batch_size, units.size) + for start_b in range(start_a, units.size, units_batch_size): + end_b = min(start_b + units_batch_size, units.size) + jobs.append((units[start_a:end_a], units[start_b:end_b])) + if show_progress: + jobs = tqdm( + jobs, smoothing=0.01, desc="Pairwise convolution", unit="pair block" + ) + + # worker kwargs + kwargs = dict( + template_data=template_data, + low_rank_templates=low_rank_templates, + compressed_upsampled_temporal=compressed_upsampled_temporal, + template_shift_index=template_shift_index, + upsampled_shifted_template_index=upsampled_shifted_template_index, + geom=geom, + reg_geom=reg_geom, + geom_kdtree=geom_kdtree, + reg_geom_kdtree=reg_geom_kdtree, + match_distance=match_distance, + conv_ignore_threshold=conv_ignore_threshold, + coarse_approx_error_threshold=coarse_approx_error_threshold, + max_shift=max_shift, + batch_size=conv_batch_size, + device=device, + ) + + n_jobs, Executor, context, rank_queue = get_pool(n_jobs, with_rank_queue=True) + with Executor( + n_jobs, + mp_context=context, + initializer=_conv_worker_init, + initargs=(rank_queue, device, kwargs), + ) as pool: + yield from pool.map(_conv_job, jobs) + + +@dataclass +class CompressedConvResult: + """Return type of compressed_convolve_pairs + + After convolving a bunch of template pairs, some convolutions + may be zero. Let n_pairs be the number of nonzero convolutions. + We don't store the zero ones. 
+ """ + + # arrays of shape n_pairs, + # For each convolved pair, these document which templates were + # in the pair, what their relative shifts were, and what the + # upsampling was (we only upsample the RHS) + template_indices_a: np.ndarray + template_indices_b: np.ndarray + shift_indices_a: np.ndarray + shift_indices_b: np.ndarray + upsampling_indices_b: np.ndarray + + # another one of shape n_pairs + # maps a pair index to the corresponding convolution index + # some convolutions are duplicates, so this array contains + # many duplicate entries in the range 0, ..., n_convs-1 + compression_index: np.ndarray + + # this one has shape (n_convs, 2 * spike_length_samples - 1) + compressed_conv: np.ndarray + + +def compressed_convolve_pairs( + template_data: templates.TemplateData, + low_rank_templates: template_util.LowRankTemplates, + compressed_upsampled_temporal: template_util.CompressedUpsampledTemplates, + template_shift_index: drift_util.TemplateShiftIndex, + upsampled_shifted_template_index: UpsampledShiftedTemplateIndex, + geom: Optional[np.ndarray] = None, + reg_geom: Optional[np.ndarray] = None, + geom_kdtree: Optional[KDTree] = None, + reg_geom_kdtree: Optional[KDTree] = None, + match_distance: Optional[float] = None, + units_a: Optional[np.ndarray] = None, + units_b: Optional[np.ndarray] = None, + conv_ignore_threshold=0.0, + coarse_approx_error_threshold=0.0, + max_shift="full", + batch_size=128, + device=None, +) -> Optional[CompressedConvResult]: + """Compute compressed pairwise convolutions between template pairs + + Takes as input all the template data and groups of pairs of units to convolve + (units_a,b). units_a,b are unit indices, not template indices (i.e., coarse + units, not superresolved bin indices). + + Returns compressed convolutions between all units_a[i], units_b[j], for all + shifts, superres templates, and upsamples. Some of these may be zero or may + be duplicates, so the return value is a sparse representation. See below. + """ + # what pairs, shifts, etc are we convolving? 
+ shifted_temp_ix_a, temp_ix_a, shift_a, unit_a = handle_shift_indices( + units_a, template_data.unit_ids, template_shift_index + ) + shifted_temp_ix_b, temp_ix_b, shift_b, unit_b = handle_shift_indices( + units_b, template_data.unit_ids, template_shift_index + ) + + # get (shifted) spatial components * singular values + spatial_singular_a = get_shifted_spatial_singular( + temp_ix_a, + shift_a, + template_shift_index, + low_rank_templates, + geom=geom, + registered_geom=reg_geom, + geom_kdtree=geom_kdtree, + match_distance=match_distance, + device=device, + ) + spatial_singular_b = get_shifted_spatial_singular( + temp_ix_b, + shift_b, + template_shift_index, + low_rank_templates, + geom=geom, + registered_geom=reg_geom, + geom_kdtree=geom_kdtree, + match_distance=match_distance, + device=device, + ) + + # figure out pairs of shifted templates to convolve in a deduplicated way + pairs_ret = shift_deduplicated_pairs( + shifted_temp_ix_a, + shifted_temp_ix_b, + spatial_singular_a, + spatial_singular_b, + temp_ix_a, + temp_ix_b, + shift_a=shift_a, + shift_b=shift_b, + template_shift_index=template_shift_index, + conv_ignore_threshold=conv_ignore_threshold, + geom=geom, + registered_geom=reg_geom, + reg_geom_kdtree=reg_geom_kdtree, + match_distance=match_distance, + ) + if pairs_ret is None: + return None + ix_a, ix_b, compression_index, conv_ix = pairs_ret + + # handle upsampling + # each pair will be duplicated by the b unit's number of upsampled copies + ( + ix_a, + ix_b, + compression_index, + conv_ix, + conv_upsampling_indices_b, + conv_temporal_components_up_b, + ) = compressed_upsampled_pairs( + ix_a, + ix_b, + compression_index, + conv_ix, + temp_ix_b, + shifted_temp_ix_b, + upsampled_shifted_template_index, + compressed_upsampled_temporal, + ) + + # # now, these arrays all have length n_pairs + # shifted_temp_ix_a = shifted_temp_ix_a[ix_a] + # temp_ix_a = temp_ix_a[ix_a] + # shift_a = shift_a[ix_a] + # shifted_temp_ix_b = shifted_temp_ix_b[ix_b] + # temp_ix_b = temp_ix_b[ix_b] + # shift_b = shift_b[ix_b] + + # run convolutions + temporal_a = low_rank_templates.temporal_components[temp_ix_a] + pconv, kept = correlate_pairs_lowrank( + spatial_singular_a[ix_a[conv_ix]].to(device), + spatial_singular_b[ix_b[conv_ix]].to(device), + temporal_a[ix_a[conv_ix]].to(device), + conv_temporal_components_up_b.to(device), + max_shift=max_shift, + conv_ignore_threshold=conv_ignore_threshold, + batch_size=batch_size, + ) + if not kept.size: + return None + kept_pairs = np.isin(conv_ix[compression_index], conv_ix[kept]) + conv_ix = conv_ix[kept] + compression_index = compression_index[kept_pairs] + ix_a = ix_a[kept_pairs] + ix_b = ix_b[kept_pairs] + # compression_index = compression_index[kept] + pconv = pconv.cpu() + + # coarse approx + pconv, old_ix_to_new_ix = coarse_approximate( + pconv, + unit_a[ix_a[conv_ix]], + unit_b[ix_b[conv_ix]], + temp_ix_a[ix_a[conv_ix]], + shift_a[ix_a[conv_ix]], + shift_b[ix_b[conv_ix]], + coarse_approx_error_threshold=coarse_approx_error_threshold, + ) + # above function invalidates the whole idea of conv_ix + del conv_ix + compression_index = old_ix_to_new_ix[compression_index] + + # recover metadata + temp_ix_a = temp_ix_a[ix_a] + shift_ix_a = np.searchsorted(template_shift_index.all_pitch_shifts, shift_a[ix_a]) + temp_ix_b = temp_ix_b[ix_b] + shift_ix_b = np.searchsorted(template_shift_index.all_pitch_shifts, shift_b[ix_b]) + + return CompressedConvResult( + template_indices_a=temp_ix_a, + template_indices_b=temp_ix_b, + shift_indices_a=shift_ix_a, + 
shift_indices_b=shift_ix_b, + upsampling_indices_b=conv_upsampling_indices_b[compression_index], + compression_index=compression_index, + compressed_conv=pconv.numpy(), + ) + + +# -- helpers + + +def correlate_pairs_lowrank( + spatial_a, + spatial_b, + temporal_a, + temporal_b, + max_shift="full", + conv_ignore_threshold=0.0, + batch_size=128, +): + """Convolve pairs of low rank templates + + For each i, we want to convolve (temporal_a[i] @ spatial_a[i]) with + (temporal_b[i] @ spatial_b[i]). So, spatial_{a,b} and temporal_{a,b} + should contain lots of duplicates, since they are already representing + pairs. + + Templates Ka = Sa Ta, Kb = Sb Tb. The channel-summed convolution is + (Ka (*) Kb) = sum_c Ka(c) * Kb(c) + = (Sb.T @ Ka) (*) Tb + = (Sb.T @ Sa @ Ta) (*) Tb + where * is cross-correlation, and (*) is channel (or rank) summed. + We use full-height conv2d to do rank-summed convs. + + Returns + ------- + pconv, kept + """ + n_pairs, rank, nchan = spatial_a.shape + n_pairs_, rank_, nchan_ = spatial_b.shape + assert rank == rank_ + assert nchan == nchan_ + assert n_pairs == n_pairs_ + n_pairs_, t, rank_ = temporal_a.shape + assert n_pairs == n_pairs_ + assert rank_ == rank + n_pairs_, t_, rank_ = temporal_b.shape + assert n_pairs == n_pairs_ + assert t == t_ + assert rank == rank_ + + if max_shift == "full": + max_shift = t - 1 + elif max_shift == "valid": + max_shift = 0 + elif max_shift == "same": + max_shift = t // 2 + + # batch over n_pairs for memory reasons + pconv = torch.zeros( + (n_pairs, 2 * max_shift + 1), dtype=spatial_a.dtype, device=spatial_a.device + ) + for istart in range(0, n_pairs, batch_size): + iend = min(istart + batch_size, n_pairs) + ix = slice(istart, iend) + + # want conv filter: nco, 1, rank, t + template_a = torch.bmm(temporal_a[ix], spatial_a[ix]) + conv_filt = torch.bmm(spatial_b[ix], template_a.mT) + conv_filt = conv_filt[:, None] # (nco, 1, rank, t) + + # nup, nco, rank, t + conv_in = temporal_b[ix].permute(2, 0, 3, 1) + + # conv2d: + # depthwise, chans=nco. batch=1. h=rank. w=t. out: nup, nco, 1, 2p+1. + # input (conv_in): nup, nco, rank, t. + # filters (conv_filt): nco, 1, rank, t. (groups=nco). + pconv_ = F.conv2d( + conv_in, conv_filt, padding=(0, max_shift), groups=iend - istart + ) + pconv[istart:iend] = pconv_[:, :, 0, :].permute(1, 0, 2) # nco, nup, time + + # more stringent covisibility + kept = slice(None) + if conv_ignore_threshold > 0: + max_val = pconv.reshape(n_pairs, -1).abs().max(dim=1).values + kept = max_val > conv_ignore_threshold + pconv = pconv[kept] + kept = np.flatnonzero(kept.numpy(force=True)) + + return pconv, kept + + +def handle_shift_indices(units, unit_ids, template_shift_index): + shifted_temp_ix_to_unit = unit_ids[template_shift_index.shifted_temp_ix_to_temp_ix] + if units is None: + shifted_temp_ix = np.arange(template_shift_index.n_shifted_templates) + else: + shifted_temp_ix = np.flatnonzero(np.isin(shifted_temp_ix_to_unit, units)) + + shift = template_shift_index.shifted_temp_ix_to_shift[shifted_temp_ix] + temp_ix = template_shift_index.shifted_temp_ix_to_temp_ix[shifted_temp_ix] + unit = unit_ids[temp_ix] + + return shifted_temp_ix, temp_ix, shift, unit + + +def get_shifted_spatial_singular( + temp_ix, + shift, + template_shift_index, + low_rank_templates, + geom=None, + registered_geom=None, + geom_kdtree=None, + match_distance=None, + device=None, +): + # do we need to shift the templates? 
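correlate_pairs_lowrank above and get_shifted_spatial_singular below assume templates factored into temporal components and singular-value-scaled spatial components; this sketch shows one way such factors could be produced with numpy's SVD (the field names follow the code, the construction itself is only illustrative).

    import numpy as np

    rng = np.random.default_rng(0)
    templates = rng.normal(size=(3, 121, 40))  # (n_templates, time, channels)
    rank = 5

    U, s, Vh = np.linalg.svd(templates, full_matrices=False)
    temporal_components = U[:, :, :rank]   # (n, t, rank)
    singular_values = s[:, :rank]          # (n, rank)
    spatial_components = Vh[:, :rank, :]   # (n, rank, chans)

    # the quantity the pairwise code works with
    spatial_singular = singular_values[:, :, None] * spatial_components
    low_rank_reconstruction = temporal_components @ spatial_singular  # (n, t, chans)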
+ n_shifts = template_shift_index.all_pitch_shifts.size + do_shifting = n_shifts > 1 + + spatial_singular = ( + low_rank_templates.spatial_components[temp_ix] + * low_rank_templates.singular_values[temp_ix][..., None] + ) + if do_shifting: + spatial_singular = drift_util.get_waveforms_on_static_channels( + spatial_singular, + registered_geom, + n_pitches_shift=shift, + registered_geom=geom, + target_kdtree=geom_kdtree, + match_distance=match_distance, + fill_value=0.0, + ) + spatial_singular = torch.as_tensor(spatial_singular, device=device) + + return spatial_singular + + +def shift_deduplicated_pairs( + shifted_temp_ix_a, + shifted_temp_ix_b, + spatialsing_a, + spatialsing_b, + temp_ix_a, + temp_ix_b, + shift_a=None, + shift_b=None, + template_shift_index=None, + conv_ignore_threshold=0.0, + geom=None, + registered_geom=None, + reg_geom_kdtree=None, + match_distance=None, +): + """Choose a set of pairs of indices from group A and B to convolve + + Some pairs of shifted templates don't overlap, so we don't need to convolve them. + Some pairs of shifted templates never show up in the recording at the same time + (what this code calls "cooccurrence"), so we don't need to convolve them. + We don't need to convolve the same pair of templates twice, just where the indices + are ordered (shifted_temp_ix_a <= shifted_temp_ix_b). + + More complicated: for each shift, a certain set of registered template channels + survives. Given that the some set of visible channels has survived for a pair of + templates at shifts shift_a and shift_b, their cross-correlation at these shifts + is the same as the one at shift_a_prime and shift_b_prime if the same exact channels + survived at shift_a_prime and shift_b_prime and if + shift_a-shift_b == shift_a_prime-shift_b_prime. + + Returns + ------- + pair_ix_a, pair_ix_b + Size < original number of shifted templates a,b + The indices of shifted templates which overlap enough to be + co-visible. So, these are subsets of shifted_temp_ix_a,b + compression_index + Size == pair_ix_a,b size + Subsets of conv_ix_a,b, so that the xcorr of templates + shifted_temp_ix_a[pair_ix_a[i]], shifted_temp_ix_b[pair_ix_b[i]] + is the same as that of + shifted_temp_ix_a[pair_ix_a[conv_ix[compression_index[i]]], + pair_ix_b[conv_ix[compression_index[i]]] + conv_ix + Size < original number of shifted templates a,b + Pairs of templates which should actually be convolved + """ + # check spatially overlapping + chan_amp_a = torch.sqrt(torch.square(spatialsing_a).sum(1)) + chan_amp_b = torch.sqrt(torch.square(spatialsing_b).sum(1)) + pair = chan_amp_a @ chan_amp_b.T + pair = pair > conv_ignore_threshold + + # co-occurrence + pair *= template_shift_index.cooccurrence + + # mask out lower triangle + pair *= shifted_temp_ix_a[:, None] <= shifted_temp_ix_b[None, :] + pair_ix_a, pair_ix_b = torch.nonzero(pair, as_tuple=True) + nco = pair_ix_a.numel() + if not nco: + return None + + # if no shifting, deduplication is the identity + do_shifting = template_shift_index.all_pitch_shifts.size > 1 + if not do_shifting: + assert shift_b is None + nco_range = torch.arange(nco, device=pair_ix_a.device) + return pair_ix_a, pair_ix_b, nco_range, nco_range + + # shift deduplication. 
algorithm: + # 1 for each shifted template, determine the set of registered channels + # which it occupies + # 2 assign each such set an ID (an "active channel ID") + # - // then a pair of shifted templates' xcorr is a function of the pair + # // of active channel IDs and the difference of shifts + # 3 figure out the set of unique (active chan id a, active chan id b, shift diff a,b) + # combinations in each pair of units + + # 1: get active channel neighborhoods as many-hot len(reg_geom)-vectors + active_chans_a = drift_util.get_waveforms_on_static_channels( + (chan_amp_a > 0).numpy(force=True), + geom, + n_pitches_shift=-shift_a, + registered_geom=registered_geom, + target_kdtree=reg_geom_kdtree, + match_distance=match_distance, + fill_value=0, + ) + active_chans_b = drift_util.get_waveforms_on_static_channels( + (chan_amp_b > 0).numpy(force=True), + geom, + n_pitches_shift=-shift_b, + registered_geom=registered_geom, + target_kdtree=reg_geom_kdtree, + match_distance=match_distance, + fill_value=0, + ) + # 2: assign IDs to each such vector + _, active_chan_ids_a = np.unique(active_chans_a, axis=0, return_inverse=True) + _, active_chan_ids_b = np.unique(active_chans_b, axis=0, return_inverse=True) + + # 3 + temp_ix_a = temp_ix_a[pair_ix_a] + temp_ix_b = temp_ix_b[pair_ix_b] + # get the relative shifts + shift_a = shift_a[pair_ix_a] + shift_b = shift_b[pair_ix_b] + shift_diff = shift_a - shift_b + + # figure out combinations + conv_determiners = np.c_[ + temp_ix_a, + active_chan_ids_a[pair_ix_a], + temp_ix_b, + active_chan_ids_b[pair_ix_b], + shift_diff, + ] + # conv_ix: indices of unique determiners + # compression_index: which representative does each pair belong to + _, conv_ix, compression_index = np.unique( + conv_determiners, axis=0, return_index=True, return_inverse=True + ) + + return pair_ix_a, pair_ix_b, compression_index, conv_ix + + +UpsampledShiftedTemplateIndex = namedtuple( + "UpsampledShiftedTemplateIndex", + [ + "n_upsampled_shifted_templates", + "upsampled_shifted_template_index", + "up_shift_temp_ix_to_shift_temp_ix", + "up_shift_temp_ix_to_temp_ix", + "up_shift_temp_ix_to_comp_up_ix", + ], +) + + +def get_upsampled_shifted_template_index( + template_shift_index, compressed_upsampled_temporal +): + """Make a compressed index space for upsampled shifted templates + + See also: template_util.{compressed_upsampled_templates,ComptessedUpsampledTemplates}. + + The comp_up_ix / compressed upsampled template indices here are indices into that + structure. 
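Step 2 of the shift-deduplication algorithm above assigns an ID to each occupied-channel footprint with np.unique over rows; a toy example:

    import numpy as np

    # many-hot occupancy vectors for 4 shifted templates over 4 registered channels
    active = np.array([
        [1, 1, 0, 0],
        [0, 1, 1, 0],
        [1, 1, 0, 0],  # same footprint as the first row
        [0, 0, 1, 1],
    ], dtype=bool)
    _, active_chan_ids = np.unique(active, axis=0, return_inverse=True)
    print(active_chan_ids)  # [2 1 2 0]: rows 0 and 2 share an active-channel ID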
+ + Returns + ------- + UpsampledShiftedTemplateIndex + named tuple with fields: + upsampled_shifted_template_index : (n_templates, n_shifts, up_factor) + Maps template_ix, shift_ix, up_ix -> compressed upsampled template index + up_shift_temp_ix_to_shift_temp_ix + up_shift_temp_ix_to_temp_ix + up_shift_temp_ix_to_comp_up_ix + """ + n_shifted_templates = template_shift_index.n_shifted_templates + n_templates, n_shifts = template_shift_index.template_shift_index.shape + max_upsample = compressed_upsampled_temporal.compressed_usampling_map.shape[1] + + cur_up_shift_temp_ix = 0 + # fill with an invalid index + upsampled_shifted_template_index = np.full( + (n_templates, n_shifts, max_upsample), n_shifted_templates * max_upsample + ) + usti2sti = [] + usti2ti = [] + usti2cui = [] + for i in range(n_templates): + shifted_temps = template_shift_index.template_shift_index[i] + valid_shifts = np.flatnonzero(shifted_temps < n_shifted_templates) + + upsampled_temps = compressed_upsampled_temporal.compressed_usampling_map[i] + unique_comp_up_inds, inverse = np.unique(upsampled_temps, return_inverse=True) + + for j in valid_shifts: + up_shift_inds = unique_comp_up_inds + cur_up_shift_temp_ix + upsampled_shifted_template_index[i, j] = up_shift_inds[inverse] + cur_up_shift_temp_ix += up_shift_inds.size + + usti2sti.extend([shifted_temps[j]] * up_shift_inds.size) + usti2ti.extend([i] * up_shift_inds.size) + usti2cui.extend(unique_comp_up_inds) + + up_shift_temp_ix_to_shift_temp_ix = np.array(usti2sti) + up_shift_temp_ix_to_temp_ix = np.array(usti2ti) + up_shift_temp_ix_to_comp_up_ix = np.array(usti2cui) + + return UpsampledShiftedTemplateIndex( + up_shift_temp_ix_to_shift_temp_ix.size, + upsampled_shifted_template_index, + up_shift_temp_ix_to_shift_temp_ix, + up_shift_temp_ix_to_temp_ix, + up_shift_temp_ix_to_comp_up_ix, + ) + + +def compressed_upsampled_pairs( + ix_a, + ix_b, + compression_index, + conv_ix, + temp_ix_b, + shifted_temp_ix_b, + upsampled_shifted_template_index, + compressed_upsampled_temporal, +): + """Add in upsampling to the set of pairs that need to be convolved + + So far, ix_a,b, compression_index, and conv_ix are such that non-upsampled + convolutions between templates ix_a[i], ix_b[i] equal that between templates + ix_a[conv_ix[compression_index[i]]], ix_b[conv_ix[compression_index[i]]]. + + We will upsample the templates in the RHS (b) in a compressed way. + """ + up_factor = compressed_upsampled_temporal.compressed_usampling_map.shape[1] + if up_factor == 1: + upinds = np.zeros(conv_ix.size, dtype=int) + temp_comps = compressed_upsampled_temporal.compressed_upsampled_templates[ + temp_ix_b[ix_b[conv_ix]] + ] + return ix_a, ix_b, compression_index, conv_ix, upinds, temp_comps + + # each conv_ix needs to be duplicated as many times as its b template has + # upsampled copies. And, all ix_{a,b}[i] such that compression_ix[i] lands in + # that conv_ix need to be duplicated as well. + ix_a_up = [] + ix_b_up = [] + compression_index_up = [] + conv_ix_up = [] + conv_compressed_upsampled_ix = [] + cur_dedup_ix = 0 + for i, convi in enumerate(conv_ix): + # get b's shifted template ix + conv_shifted_temp_ix_b = shifted_temp_ix_b[ix_b[convi]] + + # which compressed upsampled indices match this? + which_up = np.flatnonzero( + upsampled_shifted_template_index.up_shift_temp_ix_to_shift_temp_ix + == conv_shifted_temp_ix_b + ) + conv_comp_up_ix = ( + upsampled_shifted_template_index.up_shift_temp_ix_to_comp_up_ix[which_up] + ) + + # which deduplication indices map ix_a,b to this convi? 
+ which_dedup = np.flatnonzero(compression_index == i) + + # extend arrays with new indices + nupi = conv_comp_up_ix.size + ix_a_up.extend(np.repeat(ix_a[which_dedup], nupi)) + ix_b_up.extend(np.repeat(ix_b[which_dedup], nupi)) + conv_ix_up.extend([convi] * nupi) + compression_index_up.extend( + np.tile(np.arange(cur_dedup_ix, cur_dedup_ix + nupi), which_dedup.size) + ) + cur_dedup_ix += nupi + conv_compressed_upsampled_ix.extend(conv_comp_up_ix) + + ix_a_up = np.array(ix_a_up) + ix_b_up = np.array(ix_b_up) + compression_index_up = np.array(compression_index_up) + conv_ix_up = np.array(conv_ix_up) + conv_compressed_upsampled_ix = np.array(conv_compressed_upsampled_ix) + + # which upsamples and which templates? + conv_upsampling_indices_b = ( + compressed_upsampled_temporal.compressed_index_to_upsampling_index[ + conv_compressed_upsampled_ix + ] + ) + conv_temporal_components_up_b = ( + compressed_upsampled_temporal.compressed_index_to_upsampling_index[ + conv_compressed_upsampled_ix + ] + ) + + return ( + ix_a_up, + ix_b_up, + compression_index_up, + conv_ix_up, + conv_upsampling_indices_b, + conv_temporal_components_up_b, + ) + + +def coarse_approximate( + pconv, + units_a, + units_b, + temp_ix_a, + shift_a, + shift_b, + coarse_approx_error_threshold=0.0, +): + """Try to replace fine (superres+temporally upsampled) convs with coarse ones + + For each pair of convolved units, we first try to replace all of the pairwise + convolutions between these units with their mean, respecting the shifts. + + If that fails, we try to do this in a factorized way: for each superres unit a, + try to replace all of its convolutions with unit b with their mean, respecting + the shifts. + + Above, "respecting the shifts" means we only do this within each shift-deduplication + class, since changes in the sets of channels being convolved cause large changes + in the cross correlation. pconv has already been deduplicated with respect to + equivalent channel neighborhoods, so all that matters for that purpose is the + shift difference. + + This needs to tell the caller how to update its bookkeeping. 
+ """ + new_pconv = [] + old_ix_to_new_ix = np.full(len(pconv), -1) + cur_new_ix = 0 + shift_diff = shift_a - shift_b + for ua in np.unique(units_a): + ina = np.flatnonzero(units_a == ua) + partners_b = np.unique(units_b[ina]) + for ub in partners_b: + inab = ina[units_b[ina] == ub] + dshift = shift_diff[inab] + for shift in np.unique(dshift): + inshift = inab[dshift == shift] + + convs = pconv[inshift] + meanconv = convs.mean(dim=0, keepdims=True) + if (convs - meanconv).abs().max() < coarse_approx_error_threshold: + # do something + new_pconv.append(meanconv) + old_ix_to_new_ix[inshift] = cur_new_ix + cur_new_ix += 1 + continue + # else: + # new_pconv.append(convs) + # old_ix_to_new_ix[inshift] = np.arange(cur_new_ix, cur_new_ix + inshift.size) + # cur_new_ix += inshift.size + + active_temp_a = temp_ix_a[inshift] + unique_active_temp_a = np.unique(active_temp_a) + if unique_active_temp_a.size == 1: + new_pconv.append(convs) + old_ix_to_new_ix[inshift] = np.arange( + cur_new_ix, cur_new_ix + inshift.size + ) + cur_new_ix += inshift.size + continue + + for tixa in unique_active_temp_a: + insup = active_temp_a == tixa + supconvs = convs[insup] + + meanconv = supconvs.mean(dim=0, keepdims=True) + if (convs - meanconv).abs().max() < coarse_approx_error_threshold: + new_pconv.append(meanconv) + old_ix_to_new_ix[insup] = cur_new_ix + cur_new_ix += 1 + else: + new_pconv.append(supconvs) + old_ix_to_new_ix[insup] = np.arange( + cur_new_ix, cur_new_ix + insup.size + ) + cur_new_ix += insup.size + + new_pconv = torch.cat(new_pconv) + return new_pconv, old_ix_to_new_ix + + +# -- parallelism helpers + + +@dataclass +class ConvWorkerContext: + template_data: templates.TemplateData + low_rank_templates: template_util.LowRankTemplates + compressed_upsampled_temporal: template_util.CompressedUpsampledTemplates + template_shift_index: drift_util.TemplateShiftIndex + upsampled_shifted_template_index: UpsampledShiftedTemplateIndex + geom: Optional[np.ndarray] = None + reg_geom: Optional[np.ndarray] = None + geom_kdtree: Optional[KDTree] = None + reg_geom_kdtree: Optional[KDTree] = None + match_distance: Optional[float] = None + conv_ignore_threshold = 0.0 + coarse_approx_error_threshold = 0.0 + max_shift = "full" + batch_size = 128 + device = None + + +_conv_worker_context = None + + +def _conv_worker_init(rank_queue, device, kwargs): + global _conv_worker_context + + my_rank = rank_queue.get() + if device is None: + device = "cuda" if torch.cuda.is_available() else "cpu" + device = torch.device(device) + if device.type == "cuda" and device.index is None: + if torch.cuda.device_count() > 1: + device = torch.device("cuda", index=my_rank % torch.cuda.device_count()) + + _conv_worker_context = ConvWorkerContext(device=device, **kwargs) + + +def _conv_job(unit_chunk): + global _pairwise_conv_context + units_a, units_b = unit_chunk + return compressed_convolve_pairs( + units_a=units_a, units_b=units_b, **asdict_shallow(_pairwise_conv_context) + ) + + +def asdict_shallow(obj): + return {field.name: getattr(obj, field.name) for field in fields(obj)} diff --git a/src/dartsort/templates/pairwise.py b/src/dartsort/templates/pairwise.py index 330b824a..32c554a0 100644 --- a/src/dartsort/templates/pairwise.py +++ b/src/dartsort/templates/pairwise.py @@ -1,3 +1,108 @@ +from dataclasses import dataclass, fields + +import h5py +import numpy as np + + +@dataclass +class CompressedPairwiseConv: + """A database of channel-summed cross-correlations between template pairs + + There are too many templates to store all of these, 
especially after + superres binning, temporal upsampling, and pitch shifting. We compress + this as much as possible, first by deduplication (many convolutions of + templates at different shifts are identical), next by not wasting space + (no need to compute as many upsampled copies of small templates), and + finally by approximation (for pairs of far-away units, correlations of + superres templates are very close to correlations of the non-superres + template). + + This database holds some indexing structures that help us store these + correlations sparsely. .query() grabs the actual correlations for the + user. + """ + # shape: (n_shifts,) + # shift_ix -> shift (pitch shift, an integer) + shifts: np.ndarray + + # shape: (n_templates, n_shifts) + # (template_ix, shift_ix) -> shifted_template_ix + # shifted_template_ix can be either invalid (this template does not occur + # at this shift), or it can range from 0, ..., n_shifted_templates-1 + shifted_template_index: np.ndarray + + # shape: (n_templates, n_shifts, upsampling_factor) + # (template_ix, shift_ix, upsampling_ix) -> upsampled_shifted_template_ix + upsampled_shifted_template_index: np.ndarray + + # shape: (n_shifted_templates, n_upsampled_shifted_templates) + # (shifted_template_ix, upsampled_shifted_template_ix) -> pconv_ix + pconv_index: np.ndarray + + # shape: (n_pconvs, 2 * spike_length_samples - 1) + # pconv_ix -> a cross-correlation array + # the 0 index is special: pconv[0] === 0. + pconv: np.ndarray + + @classmethod + def from_h5(cls, hdf5_filename): + ff = fields(cls) + with h5py.File(hdf5_filename, "r") as h5: + data = {f.name: h5[f.name][:] for f in ff} + return cls(**data) + + def query( + self, + template_indices_a, + template_indices_b, + upsampling_indices_b=None, + shifts_a=None, + shifts_b=None, + return_zero_convs=False, + ): + # handle no shifting + no_shifting = shifts_a is None or shifts_b is None + shifted_template_index = self.shifted_template_index + upsampled_shifted_template_index = self.upsampled_shifted_template_index + if no_shifting: + assert shifts_a is None and shifts_b is None + assert self.shifts.shape == (1,) + a_ix = (template_indices_a,) + b_ix = (template_indices_b,) + shifted_template_index = shifted_template_index[:, 0] + upsampled_shifted_template_index = upsampled_shifted_template_index[:, 0] + else: + shift_indices_a = np.searchsorted(self.shifts, shifts_a) + shift_indices_b = np.searchsorted(self.shifts, shifts_b) + a_ix = (template_indices_a, shift_indices_a) + b_ix = (template_indices_b, shift_indices_b) + + # handle no upsampling + no_upsampling = upsampling_indices_b is None + if no_upsampling: + assert self.upsampled_shifted_template_index.shape[2] == 1 + upsampled_shifted_template_index = self.upsampled_shifted_template_index[..., 0] + else: + b_ix = b_ix + (upsampling_indices_b,) + + # get shifted template indices for A + shifted_temp_ix_a = shifted_template_index[a_ix] + + # upsampled shifted template indices for B + up_shifted_temp_ix_b = upsampled_shifted_template_index[b_ix] + + pconv_indices = self.pconv_index[shifted_temp_ix_a, up_shifted_temp_ix_b] + + # most users will be happy not to get a bunch of zeros for pairs that don't overlap + if not return_zero_convs: + which = np.flatnonzero(pconv_indices > 0) + pconv_indices = pconv_indices[which] + template_indices_a = template_indices_a[which] + template_indices_b = template_indices_b[which] + + return template_indices_a, template_indices_b, self.pconv[pconv_indices] + + @dataclass class SparsePairwiseConv: diff --git 
a/src/dartsort/templates/template_util.py b/src/dartsort/templates/template_util.py index 55320054..8114f222 100644 --- a/src/dartsort/templates/template_util.py +++ b/src/dartsort/templates/template_util.py @@ -1,3 +1,5 @@ +from collections import namedtuple + import numpy as np from dartsort.localize.localize_util import localize_waveforms from dartsort.util import drift_util @@ -187,6 +189,11 @@ def templates_at_time( # -- template numerical processing +LowRankTemplates = namedtuple( + "LowRankTemplates", ["temporal_components", "singular_values", "spatial_components"] +) + + def svd_compress_templates( templates, min_channel_amplitude=1.0, rank=5, channel_sparse=True ): @@ -227,7 +234,8 @@ def svd_compress_templates( temporal_components[i, :, :k] = U[:, :rank] singular_values[i, :k] = s[:rank] spatial_components[i, :k, mask] = Vh[:rank].T - return temporal_components, singular_values, spatial_components + + return LowRankTemplates(temporal_components, singular_values, spatial_components) def temporally_upsample_templates( @@ -247,29 +255,50 @@ def temporally_upsample_templates( return upsampled_templates +CompressedUpsampledTemplates = namedtuple( + "CompressedUpsampledTemplates", + [ + "compressed_upsampled_templates", + "compressed_upsampling_map", + "compressed_index_to_template_index", + "compressed_index_to_upsampling_index", + ], +) + + def default_n_upsamples_map(ptps): return 4 ** (ptps // 2) -def sparse_upsampled_templates( +def compressed_upsampled_templates( templates, ptps=None, max_upsample=8, n_upsamples_map=default_n_upsamples_map, kind="cubic", ): - """Sparsely store fewer temporally upsampled copies of lower amplitude templates + """compressedly store fewer temporally upsampled copies of lower amplitude templates Returns ------- - sparse_upsampled_templates : array (n_sparse_upsampled_templates, spike_length_samples) - sparse_upsampling_map : array (n_templates, max_upsample) - sparse_upsampled_templates[sparse_upsampling_map[unit, j]] is an approximation - of the jth upsampled template for this unit. for low-amplitude units, - sparse_upsampling_map[unit] will have fewer unique entries, corresponding - to fewer saved upsampled copies for that unit. + A CompressedUpsampledTemplates object with fields: + compressed_upsampled_templates : array (n_compressed_upsampled_templates, spike_length_samples) + compressed_upsampling_map : array (n_templates, max_upsample) + compressed_upsampled_templates[compressed_upsampling_map[unit, j]] is an approximation + of the jth upsampled template for this unit. for low-amplitude units, + compressed_upsampling_map[unit] will have fewer unique entries, corresponding + to fewer saved upsampled copies for that unit. + compressed_index_to_template_index + compressed_index_to_upsampling_index """ n_templates = templates.shape[0] + if max_upsample == 1: + return CompressedUpsampledTemplates( + templates, + np.arange(n_templates)[:, None], + np.arange(n_templates), + np.zeros(n_templates, dtype=int) + ) # how many copies should each unit get? 
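    # (For example, with the default n_upsamples_map and max_upsample = 8:
    # ptp < 2 -> 4 ** 0 = 1 copy, 2 <= ptp < 4 -> 4 ** 1 = 4 copies, and
    # ptp >= 4 -> 16, clipped down to 8 copies, so low-amplitude templates
    # store far fewer upsampled waveforms.)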
# sometimes users may pass temporal SVD components in instead of templates, @@ -282,27 +311,38 @@ def sparse_upsampled_templates( else: n_upsamples = np.clip(n_upsamples_map(ptps), 1, max_upsample).astype(int) - # build the sparse upsampling map - sparse_upsampling_map = np.zeros((n_templates, max_upsample), dtype=int) - upsampling_indices = [] + # build the compressed upsampling map + compressed_upsampling_map = np.zeros((n_templates, max_upsample), dtype=int) template_indices = [] - current_sparse_index = 0 + upsampling_indices = [] + current_compressed_index = 0 for i, nup in enumerate(n_upsamples): compression = max_upsample // nup nup = max_upsample // compression # handle divisibility failure - # new sparse indices - sparse_upsampling_map[i] = current_sparse_index + np.arange(nup).repeat(compression) - current_sparse_index += nup + # new compressed indices + compressed_upsampling_map[i] = current_compressed_index + np.arange(nup).repeat( + compression + ) + current_compressed_index += nup # indices of the templates to keep in the full array of upsampled templates - upsampling_indices.extend(compression * np.arange(nup)) template_indices.extend([i] * nup) + upsampling_indices.extend(compression * np.arange(nup)) + template_indices = np.array(template_indices) + upsampling_indices = np.array(upsampling_indices) # get the upsampled templates all_upsampled_templates = temporally_upsample_templates( templates, temporal_upsampling_factor=max_upsample, kind=kind ) - sparse_upsampled_templates = all_upsampled_templates[template_indices, upsampling_indices] - - return sparse_upsampled_templates, sparse_upsampling_map + compressed_upsampled_templates = all_upsampled_templates[ + template_indices, upsampling_indices + ] + + return CompressedUpsampledTemplates( + compressed_upsampled_templates, + compressed_upsampling_map, + template_indices, + upsampling_indices, + ) diff --git a/src/dartsort/util/drift_util.py b/src/dartsort/util/drift_util.py index 8c698490..e3761924 100644 --- a/src/dartsort/util/drift_util.py +++ b/src/dartsort/util/drift_util.py @@ -10,6 +10,8 @@ by integer numbers of pitches. As many shifted copies are created as needed to capture all the drift. """ +from dataclasses import dataclass + import numpy as np import torch from scipy.spatial import KDTree @@ -456,3 +458,114 @@ def _full_probe_shifting_fast( shifted_channels[shift_inverse][:, None, :], ] = waveforms return static_waveforms[:, :, : target_kdtree.n] + + +# -- which templates appear at which shifts in a recording? +# and, which pairs of shifted templates appear together? 
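# A sketch of how the index defined below is typically consulted (variable
# names in this sketch are hypothetical, not part of this module):
#
#     shift_index = get_shift_and_unit_pairs(
#         chunk_time_centers_s, geom, template_data, motion_est=motion_est
#     )
#     shift_ix = np.searchsorted(shift_index.all_pitch_shifts, pitch_shift)
#     shifted_temp_ix = shift_index.template_shift_index[temp_ix, shift_ix]
#     # equals shift_index.n_shifted_templates when (temp_ix, pitch_shift) never
#     # occurs in the recording; otherwise it is a valid shifted-template index,
#     # and shift_index.shifted_temp_ix_to_shift[shifted_temp_ix] == pitch_shift.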
+ + +@dataclass +class TemplateShiftIndex: + """Return value for get_shift_and_unit_pairs""" + + n_shifted_templates: int + # shift index -> shift + all_pitch_shifts: np.ndarray + # (template ix, shift index) -> shifted template index + template_shift_index: np.ndarray + # (shifted temp ix, shifted temp ix) -> did these appear at the same time + cooccurrence: np.ndarray + shifted_temp_ix_to_temp_ix: np.ndarray + shifted_temp_ix_to_shift: np.ndarray + + +def static_template_shift_index(n_templates): + temp_ixs = np.arange(n_templates) + return TemplateShiftIndex( + n_templates, + np.zeros(1), + temp_ixs[:, None], + np.ones((n_templates, n_templates), dtype=bool), + temp_ixs, + np.zeros_like(temp_ixs), + ) + + +def get_shift_and_unit_pairs( + chunk_time_centers_s, + geom, + template_data, + motion_est=None, +): + n_templates = len(template_data.templates) + if motion_est is None: + # no motion case + return static_template_shift_index(n_templates) + + # all observed pitch shift values + all_pitch_shifts = np.empty(shape=(0,), dtype=int) + temp_ixs = np.arange(n_templates) + # set of (template idx, shift) + template_shift_pairs = np.empty(shape=(0, 2), dtype=int) + pitch = get_pitch(geom) + + for t_s in chunk_time_centers_s: + # see the fn `templates_at_time` + unregistered_depths_um = invert_motion_estimate( + motion_est, t_s, template_data.registered_template_depths_um + ) + pitch_shifts = get_spike_pitch_shifts( + depths_um=template_data.registered_template_depths_um, + pitch=pitch, + registered_depths_um=unregistered_depths_um, + ) + pitch_shifts = pitch_shifts.astype(int) + + # get unique pitch/unit shift pairs in chunk + template_shift = np.c_[temp_ixs, pitch_shifts] + + # update full set + all_pitch_shifts = np.union1d(all_pitch_shifts, pitch_shifts) + template_shift_pairs = np.unique( + np.concatenate((template_shift_pairs, template_shift), axis=0), axis=0 + ) + + n_shifts = len(all_pitch_shifts) + n_template_shift_pairs = len(template_shift_pairs) + + # index template/shift pairs: template_shift_index[template_ix, shift_ix] = shifted template index + # fill with an invalid index + template_shift_index = np.full((n_templates, n_shifts), n_template_shift_pairs) + shift_ix = np.searchsorted(all_pitch_shifts, template_shift_pairs[:, 1]) + assert np.array_equal(all_pitch_shifts[shift_ix], template_shift_pairs[:, 1]) + template_shift_index[template_shift_pairs[:, 0], shift_ix] = np.arange( + n_template_shift_pairs + ) + shifted_temp_ix_to_temp_ix = template_shift_pairs[:, 0] + shifted_temp_ix_to_shift = template_shift_pairs[:, 1] + + # co-occurrence matrix: do these shifted templates appear together? 
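    # (cooccurrence[i, j] ends up True exactly when shifted templates i and j
    # are both present in at least one chunk, so downstream pairwise convolution
    # can skip pairs that never appear at the same time.)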
+ cooccurrence = np.eye(n_template_shift_pairs, dtype=bool) + for t_s in chunk_time_centers_s: + unregistered_depths_um = invert_motion_estimate( + motion_est, t_s, template_data.registered_template_depths_um + ) + pitch_shifts = get_spike_pitch_shifts( + depths_um=template_data.registered_template_depths_um, + pitch=pitch, + registered_depths_um=unregistered_depths_um, + ) + pitch_shifts = pitch_shifts.astype(int) + pitch_shift_ix = np.searchsorted(all_pitch_shifts, pitch_shifts) + + shifted_temp_ixs = template_shift_index[temp_ixs, pitch_shift_ix] + cooccurrence[shifted_temp_ixs[:, None], shifted_temp_ixs[None, :]] = 1 + + return TemplateShiftIndex( + n_template_shift_pairs, + all_pitch_shifts, + template_shift_index, + cooccurrence, + shifted_temp_ix_to_temp_ix, + shifted_temp_ix_to_shift, + ) From 0892e1d627f154a4bd7b98a71dac1003e095f12f Mon Sep 17 00:00:00 2001 From: Charlie Windolf Date: Mon, 6 Nov 2023 14:19:44 -0500 Subject: [PATCH 15/49] Overwrite --- src/dartsort/templates/new_pairwise_util.py | 958 ------------- src/dartsort/templates/pairwise.py | 97 -- src/dartsort/templates/pairwise_util.py | 1352 ++++++++++--------- 3 files changed, 704 insertions(+), 1703 deletions(-) delete mode 100644 src/dartsort/templates/new_pairwise_util.py diff --git a/src/dartsort/templates/new_pairwise_util.py b/src/dartsort/templates/new_pairwise_util.py deleted file mode 100644 index a3cf3516..00000000 --- a/src/dartsort/templates/new_pairwise_util.py +++ /dev/null @@ -1,958 +0,0 @@ -from __future__ import annotations # allow forward type references - -from collections import namedtuple -from dataclasses import dataclass, fields -from pathlib import Path -from typing import Iterator, Optional - -import h5py -import numpy as np -import torch -import torch.nn.functional as F -from dartsort.util import drift_util -from dartsort.util.multiprocessing_util import get_pool -from scipy.spatial import KDTree -from scipy.spatial.distance import pdist -from tqdm.auto import tqdm - -from . import template_util, templates - - -def compressed_convolve_to_h5( - output_hdf5_filename, - template_data: templates.TemplateData, - low_rank_templates: template_util.LowRankTemplates, - compressed_upsampled_temporal: template_util.CompressedUpsampledTemplates, - chunk_time_centers_s: Optional[np.ndarray] = None, - motion_est=None, - geom: Optional[np.ndarray] = None, - reg_geom: Optional[np.ndarray] = None, - conv_ignore_threshold=0.0, - coarse_approx_error_threshold=0.0, - conv_batch_size=128, - units_batch_size=8, - overwrite=False, - device=None, - n_jobs=0, - show_progress=True, -): - """Convolve all pairs of templates and store result in a .h5 - - See pairwise.CompressedPairwiseConvDB for how to read the - resulting convolutions back. - - This runs compressed_convolve_pairs in a loop over chunks - of unit pairs, so that it's not all done in memory at one time, - and so that it can be done in parallel. 
- """ - if overwrite: - pass # TODO - - # construct indexing helpers - template_shift_index = drift_util.get_shift_and_unit_pairs( - chunk_time_centers_s, - geom, - template_data, - motion_est=motion_est, - ) - upsampled_shifted_template_index = get_upsampled_shifted_template_index( - template_shift_index, compressed_upsampled_temporal - ) - - chunk_res_iterator = iterate_compressed_pairwise_convolutions( - template_data=template_data, - low_rank_templates=low_rank_templates, - compressed_upsampled_temporal=compressed_upsampled_temporal, - geom=geom, - reg_geom=reg_geom, - conv_ignore_threshold=conv_ignore_threshold, - coarse_approx_error_threshold=coarse_approx_error_threshold, - max_shift="full", - conv_batch_size=conv_batch_size, - units_batch_size=units_batch_size, - device=device, - n_jobs=n_jobs, - show_progress=show_progress, - ) - - pconv_index = np.zeros( - ( - template_shift_index.n_shifted_templates, - upsampled_shifted_template_index.n_upsampled_shifted_templates, - ), - dtype=int, - ) - n_pconvs = 1 - with h5py.File(output_hdf5_filename, "w") as h5: - # resizeable pconv dataset - spike_length_samples = template_data.templates.shape[1] - pconv = h5.create_dataset( - "pconv", - dtype=np.float32, - shape=(1, 2 * spike_length_samples - 1), - maxshape=(None, 2 * spike_length_samples - 1), - chunks=(128, 2 * spike_length_samples - 1), - ) - - for chunk_res in chunk_res_iterator: - if chunk_res is None: - continue - - # get shifted template indices for A - shifted_temp_ix_a = template_shift_index.template_shift_index[ - chunk_res.template_indices_a, - chunk_res.shift_indices_a, - ] - - # upsampled shifted template indices for B - up_shifted_temp_ix_b = upsampled_shifted_template_index.upsampled_shifted_template_index[ - chunk_res.template_indices_b, - chunk_res.shift_indices_b, - chunk_res.upsampling_indices_b, - ] - - # store new set of indices - new_pconv_indices = chunk_res.compression_index + n_pconvs - pconv_index[shifted_temp_ix_a, up_shifted_temp_ix_b] = new_pconv_indices - - # store new pconvs - n_new_pconvs = chunk_res.compressed_conv.shape[0] - pconv.resize(n_pconvs + n_new_pconvs, axis=0) - pconv[n_pconvs:] = chunk_res.pconv - - n_pconvs += n_new_pconvs - - # write fixed size outputs - h5.create_dataset("shifts", data=template_shift_index.all_pitch_shifts) - h5.create_dataset("shifted_template_index", data=template_shift_index.template_shift_index) - h5.create_dataset("upsampled_shifted_template_index", data=upsampled_shifted_template_index.upsampled_shifted_template_index) - h5.create_dataset("pconv_index", data=pconv_index) - - return output_hdf5_filename - - -def iterate_compressed_pairwise_convolutions( - template_data: templates.TemplateData, - low_rank_templates: template_util.LowRankTemplates, - compressed_upsampled_temporal: template_util.CompressedUpsampledTemplates, - template_shift_index: drift_util.TemplateShiftIndex, - upsampled_shifted_template_index: UpsampledShiftedTemplateIndex, - geom: Optional[np.ndarray] = None, - reg_geom: Optional[np.ndarray] = None, - conv_ignore_threshold=0.0, - coarse_approx_error_threshold=0.0, - max_shift="full", - conv_batch_size=128, - units_batch_size=8, - device=None, - n_jobs=0, - show_progress=True, -) -> Iterator[Optional[CompressedConvResult]]: - """A generator of CompressedConvResults capturing all pairs of templates - - - Runs the function compressed_convolve_pairs on chunks of units. - - This is a helper function for parallelizing computation of cross correlations - between pairs of templates. 
There are too many to store all the results in - memory, so this is a generator yielding a chunk at a time. Callers may - process the results differently. - """ - # construct drift-related helper data if needed - n_shifts = template_shift_index.all_pitch_shifts.size - do_shifting = n_shifts > 1 - geom_kdtree = reg_geom_kdtree = match_distance = None - if do_shifting: - geom_kdtree = KDTree(geom) - reg_geom_kdtree = KDTree(reg_geom) - match_distance = pdist(geom).min() / 2 - - # make chunks - units = np.unique(template_data.unit_ids) - jobs = [] - for start_a in range(0, units.size, units_batch_size): - end_a = min(start_a + units_batch_size, units.size) - for start_b in range(start_a, units.size, units_batch_size): - end_b = min(start_b + units_batch_size, units.size) - jobs.append((units[start_a:end_a], units[start_b:end_b])) - if show_progress: - jobs = tqdm( - jobs, smoothing=0.01, desc="Pairwise convolution", unit="pair block" - ) - - # worker kwargs - kwargs = dict( - template_data=template_data, - low_rank_templates=low_rank_templates, - compressed_upsampled_temporal=compressed_upsampled_temporal, - template_shift_index=template_shift_index, - upsampled_shifted_template_index=upsampled_shifted_template_index, - geom=geom, - reg_geom=reg_geom, - geom_kdtree=geom_kdtree, - reg_geom_kdtree=reg_geom_kdtree, - match_distance=match_distance, - conv_ignore_threshold=conv_ignore_threshold, - coarse_approx_error_threshold=coarse_approx_error_threshold, - max_shift=max_shift, - batch_size=conv_batch_size, - device=device, - ) - - n_jobs, Executor, context, rank_queue = get_pool(n_jobs, with_rank_queue=True) - with Executor( - n_jobs, - mp_context=context, - initializer=_conv_worker_init, - initargs=(rank_queue, device, kwargs), - ) as pool: - yield from pool.map(_conv_job, jobs) - - -@dataclass -class CompressedConvResult: - """Return type of compressed_convolve_pairs - - After convolving a bunch of template pairs, some convolutions - may be zero. Let n_pairs be the number of nonzero convolutions. - We don't store the zero ones. 
- """ - - # arrays of shape n_pairs, - # For each convolved pair, these document which templates were - # in the pair, what their relative shifts were, and what the - # upsampling was (we only upsample the RHS) - template_indices_a: np.ndarray - template_indices_b: np.ndarray - shift_indices_a: np.ndarray - shift_indices_b: np.ndarray - upsampling_indices_b: np.ndarray - - # another one of shape n_pairs - # maps a pair index to the corresponding convolution index - # some convolutions are duplicates, so this array contains - # many duplicate entries in the range 0, ..., n_convs-1 - compression_index: np.ndarray - - # this one has shape (n_convs, 2 * spike_length_samples - 1) - compressed_conv: np.ndarray - - -def compressed_convolve_pairs( - template_data: templates.TemplateData, - low_rank_templates: template_util.LowRankTemplates, - compressed_upsampled_temporal: template_util.CompressedUpsampledTemplates, - template_shift_index: drift_util.TemplateShiftIndex, - upsampled_shifted_template_index: UpsampledShiftedTemplateIndex, - geom: Optional[np.ndarray] = None, - reg_geom: Optional[np.ndarray] = None, - geom_kdtree: Optional[KDTree] = None, - reg_geom_kdtree: Optional[KDTree] = None, - match_distance: Optional[float] = None, - units_a: Optional[np.ndarray] = None, - units_b: Optional[np.ndarray] = None, - conv_ignore_threshold=0.0, - coarse_approx_error_threshold=0.0, - max_shift="full", - batch_size=128, - device=None, -) -> Optional[CompressedConvResult]: - """Compute compressed pairwise convolutions between template pairs - - Takes as input all the template data and groups of pairs of units to convolve - (units_a,b). units_a,b are unit indices, not template indices (i.e., coarse - units, not superresolved bin indices). - - Returns compressed convolutions between all units_a[i], units_b[j], for all - shifts, superres templates, and upsamples. Some of these may be zero or may - be duplicates, so the return value is a sparse representation. See below. - """ - # what pairs, shifts, etc are we convolving? 
- shifted_temp_ix_a, temp_ix_a, shift_a, unit_a = handle_shift_indices( - units_a, template_data.unit_ids, template_shift_index - ) - shifted_temp_ix_b, temp_ix_b, shift_b, unit_b = handle_shift_indices( - units_b, template_data.unit_ids, template_shift_index - ) - - # get (shifted) spatial components * singular values - spatial_singular_a = get_shifted_spatial_singular( - temp_ix_a, - shift_a, - template_shift_index, - low_rank_templates, - geom=geom, - registered_geom=reg_geom, - geom_kdtree=geom_kdtree, - match_distance=match_distance, - device=device, - ) - spatial_singular_b = get_shifted_spatial_singular( - temp_ix_b, - shift_b, - template_shift_index, - low_rank_templates, - geom=geom, - registered_geom=reg_geom, - geom_kdtree=geom_kdtree, - match_distance=match_distance, - device=device, - ) - - # figure out pairs of shifted templates to convolve in a deduplicated way - pairs_ret = shift_deduplicated_pairs( - shifted_temp_ix_a, - shifted_temp_ix_b, - spatial_singular_a, - spatial_singular_b, - temp_ix_a, - temp_ix_b, - shift_a=shift_a, - shift_b=shift_b, - template_shift_index=template_shift_index, - conv_ignore_threshold=conv_ignore_threshold, - geom=geom, - registered_geom=reg_geom, - reg_geom_kdtree=reg_geom_kdtree, - match_distance=match_distance, - ) - if pairs_ret is None: - return None - ix_a, ix_b, compression_index, conv_ix = pairs_ret - - # handle upsampling - # each pair will be duplicated by the b unit's number of upsampled copies - ( - ix_a, - ix_b, - compression_index, - conv_ix, - conv_upsampling_indices_b, - conv_temporal_components_up_b, - ) = compressed_upsampled_pairs( - ix_a, - ix_b, - compression_index, - conv_ix, - temp_ix_b, - shifted_temp_ix_b, - upsampled_shifted_template_index, - compressed_upsampled_temporal, - ) - - # # now, these arrays all have length n_pairs - # shifted_temp_ix_a = shifted_temp_ix_a[ix_a] - # temp_ix_a = temp_ix_a[ix_a] - # shift_a = shift_a[ix_a] - # shifted_temp_ix_b = shifted_temp_ix_b[ix_b] - # temp_ix_b = temp_ix_b[ix_b] - # shift_b = shift_b[ix_b] - - # run convolutions - temporal_a = low_rank_templates.temporal_components[temp_ix_a] - pconv, kept = correlate_pairs_lowrank( - spatial_singular_a[ix_a[conv_ix]].to(device), - spatial_singular_b[ix_b[conv_ix]].to(device), - temporal_a[ix_a[conv_ix]].to(device), - conv_temporal_components_up_b.to(device), - max_shift=max_shift, - conv_ignore_threshold=conv_ignore_threshold, - batch_size=batch_size, - ) - if not kept.size: - return None - kept_pairs = np.isin(conv_ix[compression_index], conv_ix[kept]) - conv_ix = conv_ix[kept] - compression_index = compression_index[kept_pairs] - ix_a = ix_a[kept_pairs] - ix_b = ix_b[kept_pairs] - # compression_index = compression_index[kept] - pconv = pconv.cpu() - - # coarse approx - pconv, old_ix_to_new_ix = coarse_approximate( - pconv, - unit_a[ix_a[conv_ix]], - unit_b[ix_b[conv_ix]], - temp_ix_a[ix_a[conv_ix]], - shift_a[ix_a[conv_ix]], - shift_b[ix_b[conv_ix]], - coarse_approx_error_threshold=coarse_approx_error_threshold, - ) - # above function invalidates the whole idea of conv_ix - del conv_ix - compression_index = old_ix_to_new_ix[compression_index] - - # recover metadata - temp_ix_a = temp_ix_a[ix_a] - shift_ix_a = np.searchsorted(template_shift_index.all_pitch_shifts, shift_a[ix_a]) - temp_ix_b = temp_ix_b[ix_b] - shift_ix_b = np.searchsorted(template_shift_index.all_pitch_shifts, shift_b[ix_b]) - - return CompressedConvResult( - template_indices_a=temp_ix_a, - template_indices_b=temp_ix_b, - shift_indices_a=shift_ix_a, - 
shift_indices_b=shift_ix_b, - upsampling_indices_b=conv_upsampling_indices_b[compression_index], - compression_index=compression_index, - compressed_conv=pconv.numpy(), - ) - - -# -- helpers - - -def correlate_pairs_lowrank( - spatial_a, - spatial_b, - temporal_a, - temporal_b, - max_shift="full", - conv_ignore_threshold=0.0, - batch_size=128, -): - """Convolve pairs of low rank templates - - For each i, we want to convolve (temporal_a[i] @ spatial_a[i]) with - (temporal_b[i] @ spatial_b[i]). So, spatial_{a,b} and temporal_{a,b} - should contain lots of duplicates, since they are already representing - pairs. - - Templates Ka = Sa Ta, Kb = Sb Tb. The channel-summed convolution is - (Ka (*) Kb) = sum_c Ka(c) * Kb(c) - = (Sb.T @ Ka) (*) Tb - = (Sb.T @ Sa @ Ta) (*) Tb - where * is cross-correlation, and (*) is channel (or rank) summed. - We use full-height conv2d to do rank-summed convs. - - Returns - ------- - pconv, kept - """ - n_pairs, rank, nchan = spatial_a.shape - n_pairs_, rank_, nchan_ = spatial_b.shape - assert rank == rank_ - assert nchan == nchan_ - assert n_pairs == n_pairs_ - n_pairs_, t, rank_ = temporal_a.shape - assert n_pairs == n_pairs_ - assert rank_ == rank - n_pairs_, t_, rank_ = temporal_b.shape - assert n_pairs == n_pairs_ - assert t == t_ - assert rank == rank_ - - if max_shift == "full": - max_shift = t - 1 - elif max_shift == "valid": - max_shift = 0 - elif max_shift == "same": - max_shift = t // 2 - - # batch over n_pairs for memory reasons - pconv = torch.zeros( - (n_pairs, 2 * max_shift + 1), dtype=spatial_a.dtype, device=spatial_a.device - ) - for istart in range(0, n_pairs, batch_size): - iend = min(istart + batch_size, n_pairs) - ix = slice(istart, iend) - - # want conv filter: nco, 1, rank, t - template_a = torch.bmm(temporal_a[ix], spatial_a[ix]) - conv_filt = torch.bmm(spatial_b[ix], template_a.mT) - conv_filt = conv_filt[:, None] # (nco, 1, rank, t) - - # nup, nco, rank, t - conv_in = temporal_b[ix].permute(2, 0, 3, 1) - - # conv2d: - # depthwise, chans=nco. batch=1. h=rank. w=t. out: nup, nco, 1, 2p+1. - # input (conv_in): nup, nco, rank, t. - # filters (conv_filt): nco, 1, rank, t. (groups=nco). - pconv_ = F.conv2d( - conv_in, conv_filt, padding=(0, max_shift), groups=iend - istart - ) - pconv[istart:iend] = pconv_[:, :, 0, :].permute(1, 0, 2) # nco, nup, time - - # more stringent covisibility - kept = slice(None) - if conv_ignore_threshold > 0: - max_val = pconv.reshape(n_pairs, -1).abs().max(dim=1).values - kept = max_val > conv_ignore_threshold - pconv = pconv[kept] - kept = np.flatnonzero(kept.numpy(force=True)) - - return pconv, kept - - -def handle_shift_indices(units, unit_ids, template_shift_index): - shifted_temp_ix_to_unit = unit_ids[template_shift_index.shifted_temp_ix_to_temp_ix] - if units is None: - shifted_temp_ix = np.arange(template_shift_index.n_shifted_templates) - else: - shifted_temp_ix = np.flatnonzero(np.isin(shifted_temp_ix_to_unit, units)) - - shift = template_shift_index.shifted_temp_ix_to_shift[shifted_temp_ix] - temp_ix = template_shift_index.shifted_temp_ix_to_temp_ix[shifted_temp_ix] - unit = unit_ids[temp_ix] - - return shifted_temp_ix, temp_ix, shift, unit - - -def get_shifted_spatial_singular( - temp_ix, - shift, - template_shift_index, - low_rank_templates, - geom=None, - registered_geom=None, - geom_kdtree=None, - match_distance=None, - device=None, -): - # do we need to shift the templates? 
- n_shifts = template_shift_index.all_pitch_shifts.size - do_shifting = n_shifts > 1 - - spatial_singular = ( - low_rank_templates.spatial_components[temp_ix] - * low_rank_templates.singular_values[temp_ix][..., None] - ) - if do_shifting: - spatial_singular = drift_util.get_waveforms_on_static_channels( - spatial_singular, - registered_geom, - n_pitches_shift=shift, - registered_geom=geom, - target_kdtree=geom_kdtree, - match_distance=match_distance, - fill_value=0.0, - ) - spatial_singular = torch.as_tensor(spatial_singular, device=device) - - return spatial_singular - - -def shift_deduplicated_pairs( - shifted_temp_ix_a, - shifted_temp_ix_b, - spatialsing_a, - spatialsing_b, - temp_ix_a, - temp_ix_b, - shift_a=None, - shift_b=None, - template_shift_index=None, - conv_ignore_threshold=0.0, - geom=None, - registered_geom=None, - reg_geom_kdtree=None, - match_distance=None, -): - """Choose a set of pairs of indices from group A and B to convolve - - Some pairs of shifted templates don't overlap, so we don't need to convolve them. - Some pairs of shifted templates never show up in the recording at the same time - (what this code calls "cooccurrence"), so we don't need to convolve them. - We don't need to convolve the same pair of templates twice, just where the indices - are ordered (shifted_temp_ix_a <= shifted_temp_ix_b). - - More complicated: for each shift, a certain set of registered template channels - survives. Given that the some set of visible channels has survived for a pair of - templates at shifts shift_a and shift_b, their cross-correlation at these shifts - is the same as the one at shift_a_prime and shift_b_prime if the same exact channels - survived at shift_a_prime and shift_b_prime and if - shift_a-shift_b == shift_a_prime-shift_b_prime. - - Returns - ------- - pair_ix_a, pair_ix_b - Size < original number of shifted templates a,b - The indices of shifted templates which overlap enough to be - co-visible. So, these are subsets of shifted_temp_ix_a,b - compression_index - Size == pair_ix_a,b size - Subsets of conv_ix_a,b, so that the xcorr of templates - shifted_temp_ix_a[pair_ix_a[i]], shifted_temp_ix_b[pair_ix_b[i]] - is the same as that of - shifted_temp_ix_a[pair_ix_a[conv_ix[compression_index[i]]], - pair_ix_b[conv_ix[compression_index[i]]] - conv_ix - Size < original number of shifted templates a,b - Pairs of templates which should actually be convolved - """ - # check spatially overlapping - chan_amp_a = torch.sqrt(torch.square(spatialsing_a).sum(1)) - chan_amp_b = torch.sqrt(torch.square(spatialsing_b).sum(1)) - pair = chan_amp_a @ chan_amp_b.T - pair = pair > conv_ignore_threshold - - # co-occurrence - pair *= template_shift_index.cooccurrence - - # mask out lower triangle - pair *= shifted_temp_ix_a[:, None] <= shifted_temp_ix_b[None, :] - pair_ix_a, pair_ix_b = torch.nonzero(pair, as_tuple=True) - nco = pair_ix_a.numel() - if not nco: - return None - - # if no shifting, deduplication is the identity - do_shifting = template_shift_index.all_pitch_shifts.size > 1 - if not do_shifting: - assert shift_b is None - nco_range = torch.arange(nco, device=pair_ix_a.device) - return pair_ix_a, pair_ix_b, nco_range, nco_range - - # shift deduplication. 
algorithm: - # 1 for each shifted template, determine the set of registered channels - # which it occupies - # 2 assign each such set an ID (an "active channel ID") - # - // then a pair of shifted templates' xcorr is a function of the pair - # // of active channel IDs and the difference of shifts - # 3 figure out the set of unique (active chan id a, active chan id b, shift diff a,b) - # combinations in each pair of units - - # 1: get active channel neighborhoods as many-hot len(reg_geom)-vectors - active_chans_a = drift_util.get_waveforms_on_static_channels( - (chan_amp_a > 0).numpy(force=True), - geom, - n_pitches_shift=-shift_a, - registered_geom=registered_geom, - target_kdtree=reg_geom_kdtree, - match_distance=match_distance, - fill_value=0, - ) - active_chans_b = drift_util.get_waveforms_on_static_channels( - (chan_amp_b > 0).numpy(force=True), - geom, - n_pitches_shift=-shift_b, - registered_geom=registered_geom, - target_kdtree=reg_geom_kdtree, - match_distance=match_distance, - fill_value=0, - ) - # 2: assign IDs to each such vector - _, active_chan_ids_a = np.unique(active_chans_a, axis=0, return_inverse=True) - _, active_chan_ids_b = np.unique(active_chans_b, axis=0, return_inverse=True) - - # 3 - temp_ix_a = temp_ix_a[pair_ix_a] - temp_ix_b = temp_ix_b[pair_ix_b] - # get the relative shifts - shift_a = shift_a[pair_ix_a] - shift_b = shift_b[pair_ix_b] - shift_diff = shift_a - shift_b - - # figure out combinations - conv_determiners = np.c_[ - temp_ix_a, - active_chan_ids_a[pair_ix_a], - temp_ix_b, - active_chan_ids_b[pair_ix_b], - shift_diff, - ] - # conv_ix: indices of unique determiners - # compression_index: which representative does each pair belong to - _, conv_ix, compression_index = np.unique( - conv_determiners, axis=0, return_index=True, return_inverse=True - ) - - return pair_ix_a, pair_ix_b, compression_index, conv_ix - - -UpsampledShiftedTemplateIndex = namedtuple( - "UpsampledShiftedTemplateIndex", - [ - "n_upsampled_shifted_templates", - "upsampled_shifted_template_index", - "up_shift_temp_ix_to_shift_temp_ix", - "up_shift_temp_ix_to_temp_ix", - "up_shift_temp_ix_to_comp_up_ix", - ], -) - - -def get_upsampled_shifted_template_index( - template_shift_index, compressed_upsampled_temporal -): - """Make a compressed index space for upsampled shifted templates - - See also: template_util.{compressed_upsampled_templates,ComptessedUpsampledTemplates}. - - The comp_up_ix / compressed upsampled template indices here are indices into that - structure. 
- - Returns - ------- - UpsampledShiftedTemplateIndex - named tuple with fields: - upsampled_shifted_template_index : (n_templates, n_shifts, up_factor) - Maps template_ix, shift_ix, up_ix -> compressed upsampled template index - up_shift_temp_ix_to_shift_temp_ix - up_shift_temp_ix_to_temp_ix - up_shift_temp_ix_to_comp_up_ix - """ - n_shifted_templates = template_shift_index.n_shifted_templates - n_templates, n_shifts = template_shift_index.template_shift_index.shape - max_upsample = compressed_upsampled_temporal.compressed_usampling_map.shape[1] - - cur_up_shift_temp_ix = 0 - # fill with an invalid index - upsampled_shifted_template_index = np.full( - (n_templates, n_shifts, max_upsample), n_shifted_templates * max_upsample - ) - usti2sti = [] - usti2ti = [] - usti2cui = [] - for i in range(n_templates): - shifted_temps = template_shift_index.template_shift_index[i] - valid_shifts = np.flatnonzero(shifted_temps < n_shifted_templates) - - upsampled_temps = compressed_upsampled_temporal.compressed_usampling_map[i] - unique_comp_up_inds, inverse = np.unique(upsampled_temps, return_inverse=True) - - for j in valid_shifts: - up_shift_inds = unique_comp_up_inds + cur_up_shift_temp_ix - upsampled_shifted_template_index[i, j] = up_shift_inds[inverse] - cur_up_shift_temp_ix += up_shift_inds.size - - usti2sti.extend([shifted_temps[j]] * up_shift_inds.size) - usti2ti.extend([i] * up_shift_inds.size) - usti2cui.extend(unique_comp_up_inds) - - up_shift_temp_ix_to_shift_temp_ix = np.array(usti2sti) - up_shift_temp_ix_to_temp_ix = np.array(usti2ti) - up_shift_temp_ix_to_comp_up_ix = np.array(usti2cui) - - return UpsampledShiftedTemplateIndex( - up_shift_temp_ix_to_shift_temp_ix.size, - upsampled_shifted_template_index, - up_shift_temp_ix_to_shift_temp_ix, - up_shift_temp_ix_to_temp_ix, - up_shift_temp_ix_to_comp_up_ix, - ) - - -def compressed_upsampled_pairs( - ix_a, - ix_b, - compression_index, - conv_ix, - temp_ix_b, - shifted_temp_ix_b, - upsampled_shifted_template_index, - compressed_upsampled_temporal, -): - """Add in upsampling to the set of pairs that need to be convolved - - So far, ix_a,b, compression_index, and conv_ix are such that non-upsampled - convolutions between templates ix_a[i], ix_b[i] equal that between templates - ix_a[conv_ix[compression_index[i]]], ix_b[conv_ix[compression_index[i]]]. - - We will upsample the templates in the RHS (b) in a compressed way. - """ - up_factor = compressed_upsampled_temporal.compressed_usampling_map.shape[1] - if up_factor == 1: - upinds = np.zeros(conv_ix.size, dtype=int) - temp_comps = compressed_upsampled_temporal.compressed_upsampled_templates[ - temp_ix_b[ix_b[conv_ix]] - ] - return ix_a, ix_b, compression_index, conv_ix, upinds, temp_comps - - # each conv_ix needs to be duplicated as many times as its b template has - # upsampled copies. And, all ix_{a,b}[i] such that compression_ix[i] lands in - # that conv_ix need to be duplicated as well. - ix_a_up = [] - ix_b_up = [] - compression_index_up = [] - conv_ix_up = [] - conv_compressed_upsampled_ix = [] - cur_dedup_ix = 0 - for i, convi in enumerate(conv_ix): - # get b's shifted template ix - conv_shifted_temp_ix_b = shifted_temp_ix_b[ix_b[convi]] - - # which compressed upsampled indices match this? - which_up = np.flatnonzero( - upsampled_shifted_template_index.up_shift_temp_ix_to_shift_temp_ix - == conv_shifted_temp_ix_b - ) - conv_comp_up_ix = ( - upsampled_shifted_template_index.up_shift_temp_ix_to_comp_up_ix[which_up] - ) - - # which deduplication indices map ix_a,b to this convi? 
- which_dedup = np.flatnonzero(compression_index == i) - - # extend arrays with new indices - nupi = conv_comp_up_ix.size - ix_a_up.extend(np.repeat(ix_a[which_dedup], nupi)) - ix_b_up.extend(np.repeat(ix_b[which_dedup], nupi)) - conv_ix_up.extend([convi] * nupi) - compression_index_up.extend( - np.tile(np.arange(cur_dedup_ix, cur_dedup_ix + nupi), which_dedup.size) - ) - cur_dedup_ix += nupi - conv_compressed_upsampled_ix.extend(conv_comp_up_ix) - - ix_a_up = np.array(ix_a_up) - ix_b_up = np.array(ix_b_up) - compression_index_up = np.array(compression_index_up) - conv_ix_up = np.array(conv_ix_up) - conv_compressed_upsampled_ix = np.array(conv_compressed_upsampled_ix) - - # which upsamples and which templates? - conv_upsampling_indices_b = ( - compressed_upsampled_temporal.compressed_index_to_upsampling_index[ - conv_compressed_upsampled_ix - ] - ) - conv_temporal_components_up_b = ( - compressed_upsampled_temporal.compressed_index_to_upsampling_index[ - conv_compressed_upsampled_ix - ] - ) - - return ( - ix_a_up, - ix_b_up, - compression_index_up, - conv_ix_up, - conv_upsampling_indices_b, - conv_temporal_components_up_b, - ) - - -def coarse_approximate( - pconv, - units_a, - units_b, - temp_ix_a, - shift_a, - shift_b, - coarse_approx_error_threshold=0.0, -): - """Try to replace fine (superres+temporally upsampled) convs with coarse ones - - For each pair of convolved units, we first try to replace all of the pairwise - convolutions between these units with their mean, respecting the shifts. - - If that fails, we try to do this in a factorized way: for each superres unit a, - try to replace all of its convolutions with unit b with their mean, respecting - the shifts. - - Above, "respecting the shifts" means we only do this within each shift-deduplication - class, since changes in the sets of channels being convolved cause large changes - in the cross correlation. pconv has already been deduplicated with respect to - equivalent channel neighborhoods, so all that matters for that purpose is the - shift difference. - - This needs to tell the caller how to update its bookkeeping. 
- """ - new_pconv = [] - old_ix_to_new_ix = np.full(len(pconv), -1) - cur_new_ix = 0 - shift_diff = shift_a - shift_b - for ua in np.unique(units_a): - ina = np.flatnonzero(units_a == ua) - partners_b = np.unique(units_b[ina]) - for ub in partners_b: - inab = ina[units_b[ina] == ub] - dshift = shift_diff[inab] - for shift in np.unique(dshift): - inshift = inab[dshift == shift] - - convs = pconv[inshift] - meanconv = convs.mean(dim=0, keepdims=True) - if (convs - meanconv).abs().max() < coarse_approx_error_threshold: - # do something - new_pconv.append(meanconv) - old_ix_to_new_ix[inshift] = cur_new_ix - cur_new_ix += 1 - continue - # else: - # new_pconv.append(convs) - # old_ix_to_new_ix[inshift] = np.arange(cur_new_ix, cur_new_ix + inshift.size) - # cur_new_ix += inshift.size - - active_temp_a = temp_ix_a[inshift] - unique_active_temp_a = np.unique(active_temp_a) - if unique_active_temp_a.size == 1: - new_pconv.append(convs) - old_ix_to_new_ix[inshift] = np.arange( - cur_new_ix, cur_new_ix + inshift.size - ) - cur_new_ix += inshift.size - continue - - for tixa in unique_active_temp_a: - insup = active_temp_a == tixa - supconvs = convs[insup] - - meanconv = supconvs.mean(dim=0, keepdims=True) - if (convs - meanconv).abs().max() < coarse_approx_error_threshold: - new_pconv.append(meanconv) - old_ix_to_new_ix[insup] = cur_new_ix - cur_new_ix += 1 - else: - new_pconv.append(supconvs) - old_ix_to_new_ix[insup] = np.arange( - cur_new_ix, cur_new_ix + insup.size - ) - cur_new_ix += insup.size - - new_pconv = torch.cat(new_pconv) - return new_pconv, old_ix_to_new_ix - - -# -- parallelism helpers - - -@dataclass -class ConvWorkerContext: - template_data: templates.TemplateData - low_rank_templates: template_util.LowRankTemplates - compressed_upsampled_temporal: template_util.CompressedUpsampledTemplates - template_shift_index: drift_util.TemplateShiftIndex - upsampled_shifted_template_index: UpsampledShiftedTemplateIndex - geom: Optional[np.ndarray] = None - reg_geom: Optional[np.ndarray] = None - geom_kdtree: Optional[KDTree] = None - reg_geom_kdtree: Optional[KDTree] = None - match_distance: Optional[float] = None - conv_ignore_threshold = 0.0 - coarse_approx_error_threshold = 0.0 - max_shift = "full" - batch_size = 128 - device = None - - -_conv_worker_context = None - - -def _conv_worker_init(rank_queue, device, kwargs): - global _conv_worker_context - - my_rank = rank_queue.get() - if device is None: - device = "cuda" if torch.cuda.is_available() else "cpu" - device = torch.device(device) - if device.type == "cuda" and device.index is None: - if torch.cuda.device_count() > 1: - device = torch.device("cuda", index=my_rank % torch.cuda.device_count()) - - _conv_worker_context = ConvWorkerContext(device=device, **kwargs) - - -def _conv_job(unit_chunk): - global _pairwise_conv_context - units_a, units_b = unit_chunk - return compressed_convolve_pairs( - units_a=units_a, units_b=units_b, **asdict_shallow(_pairwise_conv_context) - ) - - -def asdict_shallow(obj): - return {field.name: getattr(obj, field.name) for field in fields(obj)} diff --git a/src/dartsort/templates/pairwise.py b/src/dartsort/templates/pairwise.py index 32c554a0..8ff9e534 100644 --- a/src/dartsort/templates/pairwise.py +++ b/src/dartsort/templates/pairwise.py @@ -101,100 +101,3 @@ def query( template_indices_b = template_indices_b[which] return template_indices_a, template_indices_b, self.pconv[pconv_indices] - - - -@dataclass -class SparsePairwiseConv: - # shift_ix -> shift - shifts: np.ndarray - # (temp_ix, shift_ix) -> 
shifted_temp_ix - template_shift_index: torch.LongTensor - # (shifted_temp_ix a, shifted_temp_ix b) -> pair index - pair_index_table: torch.LongTensor - # (pair index, upsampling index) -> pconv index - upsampling_index_table: torch.LongTensor - # pconv index -> pconv (2 * spike len - 1,) - # the zero index lands you at an all 0 pconv - pconv: torch.Tensor - - # metadata: map shifted template index to original template ix and shift - shifted_temp_ix_to_temp_ix: np.ndarray - shifted_temp_ix_to_shift: np.ndarray - shifted_temp_ix_to_unit: np.ndarray - - @classmethod - def from_h5(cls, hdf5_filename): - ff = fields(cls) - with h5py.File(hdf5_filename, "r") as h5: - data = {f.name: h5[f.name][:] for f in ff} - return cls(**data) - - def query( - self, - template_indices_a, - template_indices_b, - upsampling_indices_b=None, - shifts_a=None, - shifts_b=None, - return_zero_convs=False, - ): - """Get cross-correlations of pairs of units A and B - - This passes through the series of lookup tables to recover (upsampled) - cross-correlations from this sparse database. - - Returns - ------- - template_indices_a, template_indices_b, pair_convs - """ - template_indices_a = np.atleast_1d(template_indices_a) - template_indices_b = np.atleast_1d(template_indices_b) - shifted = shifts_a is not None - if shifted: - assert shifts_b is not None - shifts_a = np.atleast_1d(shifts_a) - shifts_b = np.atleast_1d(shifts_b) - else: - assert np.array_equal(self.shifts, [0.0]) - - # get shifted template indices - if shifted: - shift_ix_a = np.searchsorted(self.shifts, shifts_a) - assert np.array_equal(self.shifts[shift_ix_a], shifts_a) - shift_ix_b = np.searchsorted(self.shifts, shifts_b) - assert np.array_equal(self.shifts[shift_ix_b], shifts_b) - shifted_temp_ix_a = self.template_shift_index[ - template_indices_a, shift_ix_a - ] - shifted_temp_ix_b = self.template_shift_index[ - template_indices_b, shift_ix_b - ] - else: - shifted_temp_ix_a = template_indices_a - shifted_temp_ix_b = template_indices_b - - # we only store the upper triangle of this symmetric object - min_ = np.minimum(shifted_temp_ix_a, shifted_temp_ix_b) - max_ = np.maximum(shifted_temp_ix_a, shifted_temp_ix_b) - pair_indices = self.pair_index_table[min_, max_] - - # handle upsampling - if upsampling_indices_b is None: - assert self.upsampling_index_table.shape[1] == 1 - pconv_indices = self.upsampling_index_table[pair_indices, 0] - else: - pconv_indices = self.upsampling_index_table[ - pair_indices, upsampling_indices_b - ] - - # most users will be happy not to get a bunch of zeros for pairs that don't overlap - if not return_zero_convs: - which = np.flatnonzero(pconv_indices > 0) - pconv_indices = pconv_indices[which] - template_indices_a = template_indices_a[which] - template_indices_b = template_indices_b[which] - - pair_convs = self.pconv[pconv_indices] - - return template_indices_a, template_indices_b, pair_convs diff --git a/src/dartsort/templates/pairwise_util.py b/src/dartsort/templates/pairwise_util.py index 0bf992ea..a3cf3516 100644 --- a/src/dartsort/templates/pairwise_util.py +++ b/src/dartsort/templates/pairwise_util.py @@ -1,724 +1,540 @@ +from __future__ import annotations # allow forward type references + +from collections import namedtuple from dataclasses import dataclass, fields -from typing import Optional +from pathlib import Path +from typing import Iterator, Optional import h5py import numpy as np import torch import torch.nn.functional as F -from dartsort.templates import template_util from dartsort.util import drift_util 
from dartsort.util.multiprocessing_util import get_pool from scipy.spatial import KDTree from scipy.spatial.distance import pdist from tqdm.auto import tqdm -# todo: extend this code to also handle computing pairwise -# stuff necessary for the merge! ie shifts, scaling, -# residnorm(a,b) (or min of rn(a,b),rn(b,a)???) +from . import template_util, templates -def sparse_pairwise_conv( +def compressed_convolve_to_h5( output_hdf5_filename, - geom, - template_data, - template_temporal_components, - template_upsampled_temporal_components, - template_singular_values, - template_spatial_components, - chunk_time_centers_s=None, + template_data: templates.TemplateData, + low_rank_templates: template_util.LowRankTemplates, + compressed_upsampled_temporal: template_util.CompressedUpsampledTemplates, + chunk_time_centers_s: Optional[np.ndarray] = None, motion_est=None, - conv_ignore_threshold: float = 0.0, - coarse_approx_error_threshold: float = 0.0, - min_channel_amplitude: float = 1.0, - units_per_chunk=8, + geom: Optional[np.ndarray] = None, + reg_geom: Optional[np.ndarray] = None, + conv_ignore_threshold=0.0, + coarse_approx_error_threshold=0.0, + conv_batch_size=128, + units_batch_size=8, overwrite=False, - show_progress=True, device=None, n_jobs=0, + show_progress=True, ): - """ + """Convolve all pairs of templates and store result in a .h5 - Arguments - --------- - template_* : tensors or arrays - template SVD approximations - conv_ignore_threshold: float = 0.0 - pairs will be ignored (i.e., pconv set to 0) if their pconv - does not exceed this value - coarse_approx_error_threshold: float = 0.0 - superres will not be used if coarse pconv and superres pconv - are uniformly closer than this threshold value + See pairwise.CompressedPairwiseConvDB for how to read the + resulting convolutions back. - Returns - ------- - pitch_shifts : array - array of all the pitch shifts - use searchsorted to find the pitch shift ix for a pitch shift - index_table: torch sparse tensor - index_table[(pitch shift ix a, superres label a, pitch shift ix b, superres label b)] = ( - 0 - if superres pconv a,b at these shifts was below the conv_ignore_threshold - else pconv_index) - pconvs: np.ndarray - pconv[pconv_index] is a cross-correlation of two templates, summed over chans + This runs compressed_convolve_pairs in a loop over chunks + of unit pairs, so that it's not all done in memory at one time, + and so that it can be done in parallel. 
""" if overwrite: - pass + pass # TODO - ( - n_templates, - spike_length_samples, - upsampling_factor, - ) = template_upsampled_temporal_components.shape[:3] - - # find all of the co-occurring pitch shift and template pairs - temp_shift_index = get_shift_and_unit_pairs( + # construct indexing helpers + template_shift_index = drift_util.get_shift_and_unit_pairs( chunk_time_centers_s, geom, template_data, motion_est=motion_est, ) + upsampled_shifted_template_index = get_upsampled_shifted_template_index( + template_shift_index, compressed_upsampled_temporal + ) - # check if the convolutions need to be drift-aware - # they do if we need to do any channel selection - is_drifting = not np.array_equal(temp_shift_index.all_pitch_shifts, [0]) - if template_data.registered_geom is not None: - is_drifting |= not np.array_equal(geom, template_data.registered_geom) - - # initialize pairwise conv data structures - # index_table[shifted_temp_ix(i), shifted_temp_ix(j)] = pconvix(i,j) - pair_index_table = np.zeros( - (temp_shift_index.n_shifted_templates, temp_shift_index.n_shifted_templates), - dtype=int, + chunk_res_iterator = iterate_compressed_pairwise_convolutions( + template_data=template_data, + low_rank_templates=low_rank_templates, + compressed_upsampled_temporal=compressed_upsampled_temporal, + geom=geom, + reg_geom=reg_geom, + conv_ignore_threshold=conv_ignore_threshold, + coarse_approx_error_threshold=coarse_approx_error_threshold, + max_shift="full", + conv_batch_size=conv_batch_size, + units_batch_size=units_batch_size, + device=device, + n_jobs=n_jobs, + show_progress=show_progress, ) - upsampling_index_table = np.zeros( - (temp_shift_index.n_shifted_templates, temp_shift_index.n_shifted_templates), + + pconv_index = np.zeros( + ( + template_shift_index.n_shifted_templates, + upsampled_shifted_template_index.n_upsampled_shifted_templates, + ), dtype=int, ) - # pconvs[pconvix(i,j)] = (2*spikelen-1, upsampling_factor) arr of pconv(shifted_temp(i), shifted_temp(j)) - - cur_pair_ix = 1 - cur_pconv_ix = 1 + n_pconvs = 1 with h5py.File(output_hdf5_filename, "w") as h5: # resizeable pconv dataset + spike_length_samples = template_data.templates.shape[1] pconv = h5.create_dataset( "pconv", dtype=np.float32, - shape=(1, upsampling_factor, 2 * spike_length_samples - 1), - maxshape=(None, upsampling_factor, 2 * spike_length_samples - 1), - chunks=(128, upsampling_factor, 2 * spike_length_samples - 1), + shape=(1, 2 * spike_length_samples - 1), + maxshape=(None, 2 * spike_length_samples - 1), + chunks=(128, 2 * spike_length_samples - 1), ) - # pconv[0] is special -- it is 0. 
- pconv[0] = 0.0 - - # res is a ConvBatchResult - for res in compute_pairwise_convs( - template_data, - template_spatial_components, - template_singular_values, - template_temporal_components, - template_upsampled_temporal_components, - temp_shift_index.shifted_temp_ix_to_temp_ix, - temp_shift_index.shifted_temp_ix_to_shift, - geom, - cooccurrence=temp_shift_index.cooccurrence, - conv_ignore_threshold=conv_ignore_threshold, - coarse_approx_error_threshold=coarse_approx_error_threshold, - min_channel_amplitude=min_channel_amplitude, - is_drifting=is_drifting, - units_per_chunk=units_per_chunk, - n_jobs=n_jobs, - device=device, - show_progress=show_progress, - max_shift="full", - store_conv=True, - compute_max=False, - ): - if res is None: + for chunk_res in chunk_res_iterator: + if chunk_res is None: continue - new_pair_ix = res.pair_ix + cur_pair_ix - pair_index_table[res.shifted_temp_ix_a, res.shifted_temp_ix_b] = new_pair_ix - new_pconv_ix = res.pconv_ix + cur_pconv_ix - upsampling_index_table[new_pair_ix, res.upsampling_ix] = new_pconv_ix + # get shifted template indices for A + shifted_temp_ix_a = template_shift_index.template_shift_index[ + chunk_res.template_indices_a, + chunk_res.shift_indices_a, + ] - pconv.resize(cur_pconv_ix + res.cconv_up.shape[0], axis=0) - pconv[cur_pconv_ix:] = res.cconv_up - cur_pconv_ix += res.cconv_up.shape[0] + # upsampled shifted template indices for B + up_shifted_temp_ix_b = upsampled_shifted_template_index.upsampled_shifted_template_index[ + chunk_res.template_indices_b, + chunk_res.shift_indices_b, + chunk_res.upsampling_indices_b, + ] - # smaller datasets all at once - h5.create_dataset( - "template_shift_index", data=temp_shift_index.template_shift_index - ) - h5.create_dataset("pconv_index_table", data=pconv_index_table) - h5.create_dataset("shifts", data=temp_shift_index.all_pitch_shifts) - h5.create_dataset( - "shifted_temp_ix_to_temp_ix", - data=temp_shift_index.shifted_temp_ix_to_temp_ix, - ) - h5.create_dataset( - "shifted_temp_ix_to_shift", data=temp_shift_index.shifted_temp_ix_to_shift - ) - h5.create_dataset( - "shifted_temp_ix_to_unit", - data=template_data.unit_ids[temp_shift_index.shifted_temp_ix_to_temp_ix], - ) + # store new set of indices + new_pconv_indices = chunk_res.compression_index + n_pconvs + pconv_index[shifted_temp_ix_a, up_shifted_temp_ix_b] = new_pconv_indices - return output_hdf5_filename + # store new pconvs + n_new_pconvs = chunk_res.compressed_conv.shape[0] + pconv.resize(n_pconvs + n_new_pconvs, axis=0) + pconv[n_pconvs:] = chunk_res.pconv + n_pconvs += n_new_pconvs -# -- main general worker function + # write fixed size outputs + h5.create_dataset("shifts", data=template_shift_index.all_pitch_shifts) + h5.create_dataset("shifted_template_index", data=template_shift_index.template_shift_index) + h5.create_dataset("upsampled_shifted_template_index", data=upsampled_shifted_template_index.upsampled_shifted_template_index) + h5.create_dataset("pconv_index", data=pconv_index) + + return output_hdf5_filename -def compute_pairwise_convs( - template_data, - spatial, - singular, - temporal, - temporal_up, - shifted_temp_ix_to_temp_ix, - shifted_temp_ix_to_shift, - geom, - cooccurrence, +def iterate_compressed_pairwise_convolutions( + template_data: templates.TemplateData, + low_rank_templates: template_util.LowRankTemplates, + compressed_upsampled_temporal: template_util.CompressedUpsampledTemplates, + template_shift_index: drift_util.TemplateShiftIndex, + upsampled_shifted_template_index: UpsampledShiftedTemplateIndex, + 
geom: Optional[np.ndarray] = None, + reg_geom: Optional[np.ndarray] = None, conv_ignore_threshold=0.0, coarse_approx_error_threshold=0.0, - min_channel_amplitude=1.0, max_shift="full", - is_drifting=True, - store_conv=True, - compute_max=False, - units_per_chunk=8, - n_jobs=0, + conv_batch_size=128, + units_batch_size=8, device=None, + n_jobs=0, show_progress=True, -): - # chunk up coarse unit ids, go by pairs of chunks +) -> Iterator[Optional[CompressedConvResult]]: + """A generator of CompressedConvResults capturing all pairs of templates + + + Runs the function compressed_convolve_pairs on chunks of units. + + This is a helper function for parallelizing computation of cross correlations + between pairs of templates. There are too many to store all the results in + memory, so this is a generator yielding a chunk at a time. Callers may + process the results differently. + """ + # construct drift-related helper data if needed + n_shifts = template_shift_index.all_pitch_shifts.size + do_shifting = n_shifts > 1 + geom_kdtree = reg_geom_kdtree = match_distance = None + if do_shifting: + geom_kdtree = KDTree(geom) + reg_geom_kdtree = KDTree(reg_geom) + match_distance = pdist(geom).min() / 2 + + # make chunks units = np.unique(template_data.unit_ids) jobs = [] - for start_a in range(0, units.size, units_per_chunk): - end_a = min(start_a + units_per_chunk, units.size) - for start_b in range(start_a, units.size, units_per_chunk): - end_b = min(start_b + units_per_chunk, units.size) + for start_a in range(0, units.size, units_batch_size): + end_a = min(start_a + units_batch_size, units.size) + for start_b in range(start_a, units.size, units_batch_size): + end_b = min(start_b + units_batch_size, units.size) jobs.append((units[start_a:end_a], units[start_b:end_b])) if show_progress: jobs = tqdm( jobs, smoothing=0.01, desc="Pairwise convolution", unit="pair block" ) - # compute the coarse templates if needed - if units.size == template_data.unit_ids.size: - # coarse templates are original templates - coarse_approx_error_threshold = 0 - if coarse_approx_error_threshold > 0: - coarse_templates = template_util.weighted_average( - template_data.unit_ids, template_data.templates, template_data.spike_counts - ) - ( - coarse_temporal, - coarse_singular, - coarse_spatial, - ) = template_util.svd_compress_templates( - coarse_templates, - rank=singular.shape[1], - min_channel_amplitude=min_channel_amplitude, - ) - - # template data to torch - spatial_singular = torch.as_tensor(spatial * singular[:, :, None]) - temporal = torch.as_tensor(temporal) - temporal_up = torch.as_tensor(temporal_up) - if coarse_approx_error_threshold > 0: - coarse_spatial_singular = torch.as_tensor( - coarse_spatial * coarse_singular[:, :, None] - ) - coarse_temporal = torch.as_tensor(coarse_temporal) - else: - coarse_spatial_singular = None - coarse_temporal = None - - n_jobs, Executor, context, rank_queue = get_pool(n_jobs, with_rank_queue=True) - - pconv_params = dict( - store_conv=store_conv, - compute_max=compute_max, - is_drifting=is_drifting, - max_shift=max_shift, + # worker kwargs + kwargs = dict( + template_data=template_data, + low_rank_templates=low_rank_templates, + compressed_upsampled_temporal=compressed_upsampled_temporal, + template_shift_index=template_shift_index, + upsampled_shifted_template_index=upsampled_shifted_template_index, + geom=geom, + reg_geom=reg_geom, + geom_kdtree=geom_kdtree, + reg_geom_kdtree=reg_geom_kdtree, + match_distance=match_distance, conv_ignore_threshold=conv_ignore_threshold, 
coarse_approx_error_threshold=coarse_approx_error_threshold, - spatial_singular=spatial_singular, - temporal=temporal, - temporal_up=temporal_up, - coarse_spatial_singular=coarse_spatial_singular, - coarse_temporal=coarse_temporal, - unit_ids=template_data.unit_ids, - shifted_temp_ix_to_shift=shifted_temp_ix_to_shift, - shifted_temp_ix_to_temp_ix=shifted_temp_ix_to_temp_ix, - shifted_temp_ix_to_unit=template_data.unit_ids[shifted_temp_ix_to_temp_ix], - cooccurrence=cooccurrence, - geom=geom, - registered_geom=template_data.registered_geom, + max_shift=max_shift, + batch_size=conv_batch_size, + device=device, ) + n_jobs, Executor, context, rank_queue = get_pool(n_jobs, with_rank_queue=True) with Executor( n_jobs, mp_context=context, - initializer=_pairwise_conv_init, - initargs=(device, rank_queue, pconv_params), + initializer=_conv_worker_init, + initargs=(rank_queue, device, kwargs), ) as pool: - yield from pool.map(_pairwise_conv_job, jobs) - - -# -- parallel job code + yield from pool.map(_conv_job, jobs) -# helper class which stores parameters for _pairwise_conv_job @dataclass -class PairwiseConvContext: - device: torch.device - - # parameters - store_conv: bool - compute_max: bool - is_drifting: bool - max_shift: int - conv_ignore_threshold: float - coarse_approx_error_threshold: float - - # superres registered templates - spatial_singular: torch.Tensor - temporal: torch.Tensor - temporal_up: torch.Tensor - coarse_spatial_singular: Optional[torch.Tensor] - coarse_temporal: Optional[torch.Tensor] - cooccurrence: torch.Tensor - - # template indexing helper arrays - unit_ids: np.ndarray - shifted_temp_ix_to_temp_ix: np.ndarray - shifted_temp_ix_to_shift: np.ndarray - shifted_temp_ix_to_unit: np.ndarray - - # only needed if is_drifting - geom: np.ndarray - registered_geom: np.ndarray - geom_kdtree: Optional[KDTree] - reg_geom_kdtree: Optional[KDTree] - match_distance: Optional[float] - - -_pairwise_conv_context = None - - -def _pairwise_conv_init( - device, - rank_queue, - kwargs, -): - global _pairwise_conv_context - - # figure out what device to work on - my_rank = rank_queue.get() - if device is None: - device = "cuda" if torch.cuda.is_available() else "cpu" - device = torch.device(device) - if device.type == "cuda" and device.index is None: - if torch.cuda.device_count() > 1: - device = torch.device("cuda", index=my_rank % torch.cuda.device_count()) +class CompressedConvResult: + """Return type of compressed_convolve_pairs - # handle string max_shift - max_shift = kwargs.pop("max_shift", "full") - t = kwargs["temporal"].shape[1] - if max_shift == "full": - max_shift = t - 1 - elif max_shift == "valid": - max_shift = 0 - elif max_shift == "same": - max_shift = t // 2 - kwargs["max_shift"] = max_shift - - kwargs["geom_kdtree"] = kwargs["reg_geom_kdtree"] = kwargs["match_distance"] = None - if kwargs["is_drifting"]: - kwargs["geom_kdtree"] = KDTree(kwargs["geom"]) - kwargs["reg_geom_kdtree"] = KDTree(kwargs["registered_geom"]) - kwargs["match_distance"] = pdist(kwargs["geom"]).min() / 2 - - _pairwise_conv_context = PairwiseConvContext(device=device, **kwargs) + After convolving a bunch of template pairs, some convolutions + may be zero. Let n_pairs be the number of nonzero convolutions. + We don't store the zero ones. 
+ """ + # arrays of shape n_pairs, + # For each convolved pair, these document which templates were + # in the pair, what their relative shifts were, and what the + # upsampling was (we only upsample the RHS) + template_indices_a: np.ndarray + template_indices_b: np.ndarray + shift_indices_a: np.ndarray + shift_indices_b: np.ndarray + upsampling_indices_b: np.ndarray + + # another one of shape n_pairs + # maps a pair index to the corresponding convolution index + # some convolutions are duplicates, so this array contains + # many duplicate entries in the range 0, ..., n_convs-1 + compression_index: np.ndarray + + # this one has shape (n_convs, 2 * spike_length_samples - 1) + compressed_conv: np.ndarray + + +def compressed_convolve_pairs( + template_data: templates.TemplateData, + low_rank_templates: template_util.LowRankTemplates, + compressed_upsampled_temporal: template_util.CompressedUpsampledTemplates, + template_shift_index: drift_util.TemplateShiftIndex, + upsampled_shifted_template_index: UpsampledShiftedTemplateIndex, + geom: Optional[np.ndarray] = None, + reg_geom: Optional[np.ndarray] = None, + geom_kdtree: Optional[KDTree] = None, + reg_geom_kdtree: Optional[KDTree] = None, + match_distance: Optional[float] = None, + units_a: Optional[np.ndarray] = None, + units_b: Optional[np.ndarray] = None, + conv_ignore_threshold=0.0, + coarse_approx_error_threshold=0.0, + max_shift="full", + batch_size=128, + device=None, +) -> Optional[CompressedConvResult]: + """Compute compressed pairwise convolutions between template pairs -@dataclass -class ConvBatchResult: - # arrays of length - shifted_temp_ix_a: np.ndarray - shifted_temp_ix_b: np.ndarray - # array of length such that the ith - # pair's array of upsampled convs is cconv_up[cconv_ix[i]] - cconv_ix: np.ndarray - cconv_up: Optional[np.ndarray] - max_conv: Optional[float] - best_shift: Optional[int] - - -def _pairwise_conv_job(unit_chunk): - global _pairwise_conv_context - p = _pairwise_conv_context + Takes as input all the template data and groups of pairs of units to convolve + (units_a,b). units_a,b are unit indices, not template indices (i.e., coarse + units, not superresolved bin indices). - units_a, units_b = unit_chunk + Returns compressed convolutions between all units_a[i], units_b[j], for all + shifts, superres templates, and upsamples. Some of these may be zero or may + be duplicates, so the return value is a sparse representation. See below. + """ + # what pairs, shifts, etc are we convolving? 
+ shifted_temp_ix_a, temp_ix_a, shift_a, unit_a = handle_shift_indices( + units_a, template_data.unit_ids, template_shift_index + ) + shifted_temp_ix_b, temp_ix_b, shift_b, unit_b = handle_shift_indices( + units_b, template_data.unit_ids, template_shift_index + ) - # this job consists of pairs of coarse units - # lets get all shifted superres template indices corresponding to those pairs, - # and the template indices, pitch shifts, and coarse units while we're at it - shifted_temp_ix_a = np.flatnonzero(np.isin(p.shifted_temp_ix_to_unit, units_a)) - shifted_temp_ix_b = np.flatnonzero(np.isin(p.shifted_temp_ix_to_unit, units_b)) - temp_ix_a = p.shifted_temp_ix_to_temp_ix[shifted_temp_ix_a] - temp_ix_b = p.shifted_temp_ix_to_temp_ix[shifted_temp_ix_b] - shift_a = p.shifted_temp_ix_to_shift[shifted_temp_ix_a] - shift_b = p.shifted_temp_ix_to_shift[shifted_temp_ix_b] - unit_a = p.unit_ids[temp_ix_a] - unit_b = p.unit_ids[temp_ix_b] - spatial_a = p.spatial_singular[temp_ix_a] - spatial_b = p.spatial_singular[temp_ix_b] - - # get shifted spatial components - if p.is_drifting: - spatial_a = drift_util.get_waveforms_on_static_channels( - spatial_a, - p.registered_geom, - n_pitches_shift=shift_a, - registered_geom=p.geom, - target_kdtree=p.geom_kdtree, - match_distance=p.match_distance, - fill_value=0.0, - ) - spatial_b = drift_util.get_waveforms_on_static_channels( - spatial_b, - p.registered_geom, - n_pitches_shift=shift_b, - registered_geom=p.geom, - target_kdtree=p.geom_kdtree, - match_distance=p.match_distance, - fill_value=0.0, - ) + # get (shifted) spatial components * singular values + spatial_singular_a = get_shifted_spatial_singular( + temp_ix_a, + shift_a, + template_shift_index, + low_rank_templates, + geom=geom, + registered_geom=reg_geom, + geom_kdtree=geom_kdtree, + match_distance=match_distance, + device=device, + ) + spatial_singular_b = get_shifted_spatial_singular( + temp_ix_b, + shift_b, + template_shift_index, + low_rank_templates, + geom=geom, + registered_geom=reg_geom, + geom_kdtree=geom_kdtree, + match_distance=match_distance, + device=device, + ) - # Explanation of all of the indexing going on below. - # - pair_ix_a,b index pairs of templates with nonzero cross-correlation - # so, there are i=1,...,N pairs pair_ix_a[i], pair_ix_b[i] - # - pair_ix is a N-array of indices such that... - # - up_pconv[pair_ix[i]] = set of upsampled pconv(pair_ix_a[i], pair_ix_b[i]) - # - # these are normalized/deduplicated: pair_ix contains duplicate entries. - # conv_ix contains the unique entries. in particular, the pconv between - # pair_ix_a[i], pair_ix_b[i] is being computed as that between - # pair_ix_a[conv_ix[pair_ix[i]]] and pair_ix_b[conv_ix[pair_ix[i]]]. - # - # then, we also sparsify the temporal upsampling. 
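# The pair_ix / conv_ix / compression_index bookkeeping described above reduces
# to np.unique with return_index/return_inverse; a toy check with made-up rows.
import numpy as np

determiners = np.array([[0, 1], [2, 3], [0, 1], [2, 3], [4, 5]])
_, conv_ix, compression_index = np.unique(
    determiners, axis=0, return_index=True, return_inverse=True
)
compression_index = np.ravel(compression_index)  # guard against extra axes
# conv_ix picks one representative row per class; compression_index maps every
# original row back to its representative, so this round-trips exactly:
assert np.array_equal(determiners[conv_ix][compression_index], determiners)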
- - # figure out pairs of templates to convolve - pair_ix_a, pair_ix_b, pair_ix, conv_ix, shift_diff = deduplicated_pairs( + # figure out pairs of shifted templates to convolve in a deduplicated way + pairs_ret = shift_deduplicated_pairs( shifted_temp_ix_a, shifted_temp_ix_b, - spatial_a, - spatial_b, + spatial_singular_a, + spatial_singular_b, temp_ix_a, temp_ix_b, shift_a=shift_a, shift_b=shift_b, - cooccurrence=p.cooccurrence, - conv_ignore_threshold=p.conv_ignore_threshold, - geom=p.geom, - registered_geom=p.registered_geom, - reg_geom_kdtree=p.reg_geom_kdtree, - match_distance=p.match_distance, - ) - - # to device - spatial_a = spatial_a.to(p.device) - spatial_b = spatial_b.to(p.device) - temporal_a = p.temporal[temp_ix_a].to(p.device) - temporal_up_b = p.temporal_up[temp_ix_b].to(p.device) - - # convolve valid pairs - conv_ix_a, conv_ix_b, up_pconv, pair_survived = ccorrelate_up( - spatial_a, - temporal_a, - spatial_b, - temporal_up_b, - conv_ignore_threshold=p.conv_ignore_threshold, - max_shift=p.max_shift, - pair_ix_a=pair_ix_a[conv_ix], - pair_ix_b=pair_ix_b[conv_ix], + template_shift_index=template_shift_index, + conv_ignore_threshold=conv_ignore_threshold, + geom=geom, + registered_geom=reg_geom, + reg_geom_kdtree=reg_geom_kdtree, + match_distance=match_distance, ) - if conv_ix_a is None: - return None - nco = conv_ix_a.numel() - if not nco: + if pairs_ret is None: return None + ix_a, ix_b, compression_index, conv_ix = pairs_ret - pair_ix = pair_ix[pair_survived] - pair_ix_a = pair_ix_a[pair_survived] - pair_ix_b = pair_ix_b[pair_survived] - - # # summarize units by coarse pconv when possible - # if p.coarse_approx_error_threshold > 0: - # pconv, cconv_ix = _coarse_approx( - # pconv, cconv_ix, conv_ix_a, conv_ix_b, unit_a, unit_b, p - # ) - - # for use in deconv residual distance merge - # TODO: actually probably need to do the real objective here with - # scaling. only need to do that bc of scaling right? 
- # makes it kind of a pain, because then we need to go pairwise - # (deconv objective is not symmetric) - max_conv = best_shift = None - if p.compute_max: - cconv_ = pconv.reshape(nco, pconv.shape[1] * pconv.shape[2]) - max_conv, max_index = cconv_.max(dim=1) - max_up, max_sample = np.unravel_index( - max_index.numpy(force=True), shape=pconv.shape[1:] - ) - best_shift = max_sample - (p.max_shift + 1) - # if upsample>half nup, round max shift up - best_shift += np.rint(max_up / pconv.shape[1]).astype(int) - - print(f"end {conv_ix_a.shape=}") - print(f"end {conv_ix_b.shape=}") - print(f"end {cconv_ix.shape=}") - print(f"end {pconv.shape=}") - - return ConvBatchResult( - shifted_temp_ix_a[pair_ix_a.numpy(force=True)], - shifted_temp_ix_b[pair_ix_b.numpy(force=True)], - cconv_ix, - pconv.numpy(force=True) if pconv is not None else None, - max_conv.numpy(force=True) if max_conv is not None else None, - best_shift, + # handle upsampling + # each pair will be duplicated by the b unit's number of upsampled copies + ( + ix_a, + ix_b, + compression_index, + conv_ix, + conv_upsampling_indices_b, + conv_temporal_components_up_b, + ) = compressed_upsampled_pairs( + ix_a, + ix_b, + compression_index, + conv_ix, + temp_ix_b, + shifted_temp_ix_b, + upsampled_shifted_template_index, + compressed_upsampled_temporal, ) - -# -- library code -# template index and shift pairs -# pairwise low-rank cross-correlation - - -@dataclass -class TemplateShiftIndex: - """Return value for get_shift_and_unit_pairs""" - - n_shifted_templates: int - # shift index -> shift - all_pitch_shifts: np.ndarray - # (template ix, shift index) -> shifted template index - template_shift_index: np.ndarray - # (shifted temp ix, shifted temp ix) -> did these appear at the same time - cooccurrence: np.ndarray - shifted_temp_ix_to_temp_ix: np.ndarray - shifted_temp_ix_to_shift: np.ndarray - - -def static_template_shift_index(n_templates): - temp_ixs = np.arange(n_templates) - return TemplateShiftIndex( - n_templates, - np.zeros(1), - temp_ixs[:, None], - np.ones((n_templates, n_templates), dtype=bool), - temp_ixs, - np.zeros_like(temp_ixs), + # # now, these arrays all have length n_pairs + # shifted_temp_ix_a = shifted_temp_ix_a[ix_a] + # temp_ix_a = temp_ix_a[ix_a] + # shift_a = shift_a[ix_a] + # shifted_temp_ix_b = shifted_temp_ix_b[ix_b] + # temp_ix_b = temp_ix_b[ix_b] + # shift_b = shift_b[ix_b] + + # run convolutions + temporal_a = low_rank_templates.temporal_components[temp_ix_a] + pconv, kept = correlate_pairs_lowrank( + spatial_singular_a[ix_a[conv_ix]].to(device), + spatial_singular_b[ix_b[conv_ix]].to(device), + temporal_a[ix_a[conv_ix]].to(device), + conv_temporal_components_up_b.to(device), + max_shift=max_shift, + conv_ignore_threshold=conv_ignore_threshold, + batch_size=batch_size, ) - - -def get_shift_and_unit_pairs( - chunk_time_centers_s, - geom, - template_data, - motion_est=None, -): - n_templates = len(template_data.templates) - if motion_est is None: - # no motion case - return static_template_shift_index(n_templates) - - # all observed pitch shift values - all_pitch_shifts = np.empty(shape=(0,), dtype=int) - temp_ixs = np.arange(n_templates) - # set of (template idx, shift) - template_shift_pairs = np.empty(shape=(0, 2), dtype=int) - pitch = drift_util.get_pitch(geom) - - for t_s in chunk_time_centers_s: - # see the fn `templates_at_time` - unregistered_depths_um = drift_util.invert_motion_estimate( - motion_est, t_s, template_data.registered_template_depths_um - ) - pitch_shifts = drift_util.get_spike_pitch_shifts( - 
depths_um=template_data.registered_template_depths_um, - pitch=pitch, - registered_depths_um=unregistered_depths_um, - ) - pitch_shifts = pitch_shifts.astype(int) - - # get unique pitch/unit shift pairs in chunk - template_shift = np.c_[temp_ixs, pitch_shifts] - - # update full set - all_pitch_shifts = np.union1d(all_pitch_shifts, pitch_shifts) - template_shift_pairs = np.unique( - np.concatenate((template_shift_pairs, template_shift), axis=0), axis=0 - ) - - n_shifts = len(all_pitch_shifts) - n_template_shift_pairs = len(template_shift_pairs) - - # index template/shift pairs: template_shift_index[template_ix, shift_ix] = shifted template index - # fill with an invalid index - template_shift_index = np.full((n_templates, n_shifts), n_template_shift_pairs) - shift_ix = np.searchsorted(all_pitch_shifts, template_shift_pairs[:, 1]) - assert np.array_equal(all_pitch_shifts[shift_ix], template_shift_pairs[:, 1]) - template_shift_index[template_shift_pairs[:, 0], shift_ix] = np.arange( - n_template_shift_pairs + if not kept.size: + return None + kept_pairs = np.isin(conv_ix[compression_index], conv_ix[kept]) + conv_ix = conv_ix[kept] + compression_index = compression_index[kept_pairs] + ix_a = ix_a[kept_pairs] + ix_b = ix_b[kept_pairs] + # compression_index = compression_index[kept] + pconv = pconv.cpu() + + # coarse approx + pconv, old_ix_to_new_ix = coarse_approximate( + pconv, + unit_a[ix_a[conv_ix]], + unit_b[ix_b[conv_ix]], + temp_ix_a[ix_a[conv_ix]], + shift_a[ix_a[conv_ix]], + shift_b[ix_b[conv_ix]], + coarse_approx_error_threshold=coarse_approx_error_threshold, + ) + # above function invalidates the whole idea of conv_ix + del conv_ix + compression_index = old_ix_to_new_ix[compression_index] + + # recover metadata + temp_ix_a = temp_ix_a[ix_a] + shift_ix_a = np.searchsorted(template_shift_index.all_pitch_shifts, shift_a[ix_a]) + temp_ix_b = temp_ix_b[ix_b] + shift_ix_b = np.searchsorted(template_shift_index.all_pitch_shifts, shift_b[ix_b]) + + return CompressedConvResult( + template_indices_a=temp_ix_a, + template_indices_b=temp_ix_b, + shift_indices_a=shift_ix_a, + shift_indices_b=shift_ix_b, + upsampling_indices_b=conv_upsampling_indices_b[compression_index], + compression_index=compression_index, + compressed_conv=pconv.numpy(), ) - shifted_temp_ix_to_temp_ix = template_shift_pairs[:, 0] - shifted_temp_ix_to_shift = template_shift_pairs[:, 1] - - # co-occurrence matrix: do these shifted templates appear together? 
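# The co-occurrence matrix mentioned above just marks every pair of shifted
# templates observed in the same chunk; a toy update with hypothetical indices
# (this is the matrix the new code reads as template_shift_index.cooccurrence).
import numpy as np

n_shifted = 4
cooccurrence = np.eye(n_shifted, dtype=bool)
seen_in_chunk = np.array([0, 2, 3])  # shifted template indices active in one chunk
cooccurrence[seen_in_chunk[:, None], seen_in_chunk[None, :]] = True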
- cooccurrence = np.eye(n_template_shift_pairs, dtype=bool) - for t_s in chunk_time_centers_s: - unregistered_depths_um = drift_util.invert_motion_estimate( - motion_est, t_s, template_data.registered_template_depths_um - ) - pitch_shifts = drift_util.get_spike_pitch_shifts( - depths_um=template_data.registered_template_depths_um, - pitch=pitch, - registered_depths_um=unregistered_depths_um, - ) - pitch_shifts = pitch_shifts.astype(int) - pitch_shift_ix = np.searchsorted(all_pitch_shifts, pitch_shifts) - shifted_temp_ixs = template_shift_index[temp_ixs, pitch_shift_ix] - cooccurrence[shifted_temp_ixs[:, None], shifted_temp_ixs[None, :]] = 1 - return TemplateShiftIndex( - n_template_shift_pairs, - all_pitch_shifts, - template_shift_index, - cooccurrence, - shifted_temp_ix_to_temp_ix, - shifted_temp_ix_to_shift, - ) +# -- helpers -def ccorrelate_up( +def correlate_pairs_lowrank( spatial_a, - temporal_a, spatial_b, + temporal_a, temporal_b, - upsampling_compression_map=None, - conv_ignore_threshold=0.0, max_shift="full", - covisible_mask=None, - pair_ix_a=None, - pair_ix_b=None, + conv_ignore_threshold=0.0, batch_size=128, ): - """Convolve all pairs of low-rank templates + """Convolve pairs of low rank templates - This uses too much memory to run on all pairs at once. + For each i, we want to convolve (temporal_a[i] @ spatial_a[i]) with + (temporal_b[i] @ spatial_b[i]). So, spatial_{a,b} and temporal_{a,b} + should contain lots of duplicates, since they are already representing + pairs. Templates Ka = Sa Ta, Kb = Sb Tb. The channel-summed convolution is (Ka (*) Kb) = sum_c Ka(c) * Kb(c) = (Sb.T @ Ka) (*) Tb = (Sb.T @ Sa @ Ta) (*) Tb where * is cross-correlation, and (*) is channel (or rank) summed. - We use full-height conv2d to do rank-summed convs. - upsampling_compression_map - (n_templates, upsampling_factor) - - Returns ------- - covisible_a, covisible_b : tensors of indices - Both have shape (nco,), where nco is the number of templates - whose pairwise conv exceeds conv_ignore_threshold. - So, zip(covisible_a, covisible_b) is the set of co-visible pairs. - cconv : torch.Tensor - Shape is (nco, nup, 2 * max_shift + 1) - All cross-correlations for pairs of templates (templates in b - can be upsampled.) - If max_shift is full, then 2*max_shift+1=2t-1. 
+ pconv, kept """ - na, rank, nchan = spatial_a.shape - nb, rank_, nchan_ = spatial_b.shape + n_pairs, rank, nchan = spatial_a.shape + n_pairs_, rank_, nchan_ = spatial_b.shape assert rank == rank_ assert nchan == nchan_ - na_, t, rank_ = temporal_a.shape - assert na == na_ + assert n_pairs == n_pairs_ + n_pairs_, t, rank_ = temporal_a.shape + assert n_pairs == n_pairs_ assert rank_ == rank - nb_, t_, nup, rank_ = temporal_b.shape - assert nb == nb_ + n_pairs_, t_, rank_ = temporal_b.shape + assert n_pairs == n_pairs_ assert t == t_ assert rank == rank_ - if covisible_mask is not None: - assert covisible_mask.shape == (na, nb) - - # no need to convolve templates which do not overlap enough - if pair_ix_a is None: - covisible = ( - torch.sqrt(torch.square(spatial_a).sum(1)) - @ torch.sqrt(torch.square(spatial_b).sum(1)).T - ) - covisible = covisible > conv_ignore_threshold - if covisible_mask is not None: - covisible *= covisible_mask - covisible_a, covisible_b = torch.nonzero(covisible, as_tuple=True) - else: - covisible_a, covisible_b = pair_ix_a, pair_ix_b - nco = covisible_a.numel() - if not nco: - return None, None, None - # batch over nco for memory reasons - cconv = torch.zeros( - (nco, nup, 2 * max_shift + 1), dtype=spatial_a.dtype, device=spatial_a.device + if max_shift == "full": + max_shift = t - 1 + elif max_shift == "valid": + max_shift = 0 + elif max_shift == "same": + max_shift = t // 2 + + # batch over n_pairs for memory reasons + pconv = torch.zeros( + (n_pairs, 2 * max_shift + 1), dtype=spatial_a.dtype, device=spatial_a.device ) - for istart in range(0, nco, batch_size): - iend = min(istart + batch_size, nco) - co_a = covisible_a[istart:iend] - co_b = covisible_b[istart:iend] - nco_ = iend - istart + for istart in range(0, n_pairs, batch_size): + iend = min(istart + batch_size, n_pairs) + ix = slice(istart, iend) # want conv filter: nco, 1, rank, t - template_a = torch.bmm(temporal_a, spatial_a) - conv_filt = torch.bmm(spatial_b[co_b], template_a[co_a].mT) + template_a = torch.bmm(temporal_a[ix], spatial_a[ix]) + conv_filt = torch.bmm(spatial_b[ix], template_a.mT) conv_filt = conv_filt[:, None] # (nco, 1, rank, t) # nup, nco, rank, t - conv_in = temporal_b[co_b].permute(2, 0, 3, 1) + conv_in = temporal_b[ix].permute(2, 0, 3, 1) # conv2d: # depthwise, chans=nco. batch=1. h=rank. w=t. out: nup, nco, 1, 2p+1. # input (conv_in): nup, nco, rank, t. # filters (conv_filt): nco, 1, rank, t. (groups=nco). 
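# The grouped conv2d above implements the rank-summed identity from the
# docstring (there written as (Sb.T @ Sa @ Ta) (*) Tb). A direct NumPy check of
# the same identity with time-major factors and made-up shapes, independent of
# the batching/grouping details:
import numpy as np

rng = np.random.default_rng(0)
t, rank, nchan = 21, 3, 7
Ta, Tb = rng.normal(size=(t, rank)), rng.normal(size=(t, rank))
Sa, Sb = rng.normal(size=(rank, nchan)), rng.normal(size=(rank, nchan))
Ka, Kb = Ta @ Sa, Tb @ Sb  # dense templates, shape (t, nchan)

# channel-summed full cross-correlation, done the slow dense way
direct = sum(np.correlate(Kb[:, c], Ka[:, c], "full") for c in range(nchan))
# low-rank route: fold the channel sum into a (rank, rank) overlap first
lowrank = sum(
    np.correlate(Tb[:, r], (Ta @ Sa @ Sb.T)[:, r], "full") for r in range(rank)
)
assert np.allclose(direct, lowrank)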
- cconv_ = F.conv2d(conv_in, conv_filt, padding=(0, max_shift), groups=nco_) - cconv[istart:iend] = cconv_[:, :, 0, :].permute(1, 0, 2) # nco, nup, time + pconv_ = F.conv2d( + conv_in, conv_filt, padding=(0, max_shift), groups=iend - istart + ) + pconv[istart:iend] = pconv_[:, :, 0, :].permute(1, 0, 2) # nco, nup, time # more stringent covisibility - pair_survived = slice(None) + kept = slice(None) if conv_ignore_threshold > 0: - max_val = cconv.reshape(nco, -1).abs().max(dim=1).values - pair_survived = max_val > conv_ignore_threshold - cconv = cconv[pair_survived] - covisible_a = covisible_a[pair_survived] - covisible_b = covisible_b[pair_survived] + max_val = pconv.reshape(n_pairs, -1).abs().max(dim=1).values + kept = max_val > conv_ignore_threshold + pconv = pconv[kept] + kept = np.flatnonzero(kept.numpy(force=True)) - return covisible_a, covisible_b, cconv, pair_survived + return pconv, kept -# -- helpers +def handle_shift_indices(units, unit_ids, template_shift_index): + shifted_temp_ix_to_unit = unit_ids[template_shift_index.shifted_temp_ix_to_temp_ix] + if units is None: + shifted_temp_ix = np.arange(template_shift_index.n_shifted_templates) + else: + shifted_temp_ix = np.flatnonzero(np.isin(shifted_temp_ix_to_unit, units)) + + shift = template_shift_index.shifted_temp_ix_to_shift[shifted_temp_ix] + temp_ix = template_shift_index.shifted_temp_ix_to_temp_ix[shifted_temp_ix] + unit = unit_ids[temp_ix] + + return shifted_temp_ix, temp_ix, shift, unit + + +def get_shifted_spatial_singular( + temp_ix, + shift, + template_shift_index, + low_rank_templates, + geom=None, + registered_geom=None, + geom_kdtree=None, + match_distance=None, + device=None, +): + # do we need to shift the templates? + n_shifts = template_shift_index.all_pitch_shifts.size + do_shifting = n_shifts > 1 + + spatial_singular = ( + low_rank_templates.spatial_components[temp_ix] + * low_rank_templates.singular_values[temp_ix][..., None] + ) + if do_shifting: + spatial_singular = drift_util.get_waveforms_on_static_channels( + spatial_singular, + registered_geom, + n_pitches_shift=shift, + registered_geom=geom, + target_kdtree=geom_kdtree, + match_distance=match_distance, + fill_value=0.0, + ) + spatial_singular = torch.as_tensor(spatial_singular, device=device) + return spatial_singular -def deduplicated_pairs( + +def shift_deduplicated_pairs( shifted_temp_ix_a, shifted_temp_ix_b, spatialsing_a, @@ -727,7 +543,7 @@ def deduplicated_pairs( temp_ix_b, shift_a=None, shift_b=None, - cooccurrence=None, + template_shift_index=None, conv_ignore_threshold=0.0, geom=None, registered_geom=None, @@ -755,18 +571,16 @@ def deduplicated_pairs( Size < original number of shifted templates a,b The indices of shifted templates which overlap enough to be co-visible. So, these are subsets of shifted_temp_ix_a,b - dedup_ix + compression_index Size == pair_ix_a,b size Subsets of conv_ix_a,b, so that the xcorr of templates shifted_temp_ix_a[pair_ix_a[i]], shifted_temp_ix_b[pair_ix_b[i]] is the same as that of - shifted_temp_ix_a[conv_ix[dedup_ix[i]], conv_ix[dedup_ix[i]]] + shifted_temp_ix_a[pair_ix_a[conv_ix[compression_index[i]]], + pair_ix_b[conv_ix[compression_index[i]]] conv_ix Size < original number of shifted templates a,b Pairs of templates which should actually be convolved - shift_diff - Optional. 
if not None, same size as pair_ix_a - shift_a - shift_b for this pair """ # check spatially overlapping chan_amp_a = torch.sqrt(torch.square(spatialsing_a).sum(1)) @@ -775,19 +589,21 @@ def deduplicated_pairs( pair = pair > conv_ignore_threshold # co-occurrence - if cooccurrence is not None: - pair *= cooccurrence + pair *= template_shift_index.cooccurrence # mask out lower triangle pair *= shifted_temp_ix_a[:, None] <= shifted_temp_ix_b[None, :] pair_ix_a, pair_ix_b = torch.nonzero(pair, as_tuple=True) nco = pair_ix_a.numel() + if not nco: + return None # if no shifting, deduplication is the identity - if shift_a is None: + do_shifting = template_shift_index.all_pitch_shifts.size > 1 + if not do_shifting: assert shift_b is None nco_range = torch.arange(nco, device=pair_ix_a.device) - return pair_ix_a, pair_ix_b, nco_range, nco_range, None + return pair_ix_a, pair_ix_b, nco_range, nco_range # shift deduplication. algorithm: # 1 for each shifted template, determine the set of registered channels @@ -838,65 +654,305 @@ def deduplicated_pairs( shift_diff, ] # conv_ix: indices of unique determiners - # dedup_ix: which representative does each pair belong to - _, conv_ix, dedup_ix = np.unique( + # compression_index: which representative does each pair belong to + _, conv_ix, compression_index = np.unique( conv_determiners, axis=0, return_index=True, return_inverse=True ) - return pair_ix_a, pair_ix_b, dedup_ix, conv_ix, shift_diff - - - -def _coarse_approx(cconv, cconv_ix, conv_ix_a, conv_ix_b, unit_a, unit_b, p): - # figure out coarse templates to correlate - conv_ix_a = conv_ix_a.cpu() - conv_ix_b = conv_ix_b.cpu() - conv_unit_a = unit_a[conv_ix_a] - conv_unit_b = unit_b[conv_ix_b] - coarse_units_a = np.unique(conv_unit_a) - coarse_units_b = np.unique(conv_unit_b) - coarsecovis = np.zeros((coarse_units_a.size, coarse_units_b.size), dtype=bool) - coarsecovis[ - np.searchsorted(coarse_units_a, conv_unit_a), - np.searchsorted(coarse_units_b, conv_unit_b), - ] = True - - # correlate them - coarse_ix_a, coarse_ix_b, coarse_cconv = ccorrelate_up( - p.coarse_spatial_singular[coarse_units_a].to(p.device), - p.coarse_temporal[coarse_units_a].to(p.device), - p.coarse_spatial_singular[coarse_units_b].to(p.device), - p.coarse_temporal[coarse_units_b].unsqueeze(2).to(p.device), - conv_ignore_threshold=p.conv_ignore_threshold, - max_shift=p.max_shift, - covisible_mask=torch.as_tensor(coarsecovis, device=p.device), + return pair_ix_a, pair_ix_b, compression_index, conv_ix + + +UpsampledShiftedTemplateIndex = namedtuple( + "UpsampledShiftedTemplateIndex", + [ + "n_upsampled_shifted_templates", + "upsampled_shifted_template_index", + "up_shift_temp_ix_to_shift_temp_ix", + "up_shift_temp_ix_to_temp_ix", + "up_shift_temp_ix_to_comp_up_ix", + ], +) + + +def get_upsampled_shifted_template_index( + template_shift_index, compressed_upsampled_temporal +): + """Make a compressed index space for upsampled shifted templates + + See also: template_util.{compressed_upsampled_templates,ComptessedUpsampledTemplates}. + + The comp_up_ix / compressed upsampled template indices here are indices into that + structure. 
+ + Returns + ------- + UpsampledShiftedTemplateIndex + named tuple with fields: + upsampled_shifted_template_index : (n_templates, n_shifts, up_factor) + Maps template_ix, shift_ix, up_ix -> compressed upsampled template index + up_shift_temp_ix_to_shift_temp_ix + up_shift_temp_ix_to_temp_ix + up_shift_temp_ix_to_comp_up_ix + """ + n_shifted_templates = template_shift_index.n_shifted_templates + n_templates, n_shifts = template_shift_index.template_shift_index.shape + max_upsample = compressed_upsampled_temporal.compressed_usampling_map.shape[1] + + cur_up_shift_temp_ix = 0 + # fill with an invalid index + upsampled_shifted_template_index = np.full( + (n_templates, n_shifts, max_upsample), n_shifted_templates * max_upsample ) - if coarse_ix_a is None: - return cconv, cconv_ix - - coarse_units_a = np.atleast_1d(coarse_units_a[coarse_ix_a.cpu()]) - coarse_units_b = np.atleast_1d(coarse_units_b[coarse_ix_b.cpu()]) - - # find coarse units which well summarize the fine cconvs - for coarse_unit_a, coarse_unit_b, conv in zip( - coarse_units_a, coarse_units_b, coarse_cconv - ): - # check good approx. if not, continue - in_pair = np.flatnonzero( - (conv_unit_a == coarse_unit_a) & (conv_unit_b == coarse_unit_b) + usti2sti = [] + usti2ti = [] + usti2cui = [] + for i in range(n_templates): + shifted_temps = template_shift_index.template_shift_index[i] + valid_shifts = np.flatnonzero(shifted_temps < n_shifted_templates) + + upsampled_temps = compressed_upsampled_temporal.compressed_usampling_map[i] + unique_comp_up_inds, inverse = np.unique(upsampled_temps, return_inverse=True) + + for j in valid_shifts: + up_shift_inds = unique_comp_up_inds + cur_up_shift_temp_ix + upsampled_shifted_template_index[i, j] = up_shift_inds[inverse] + cur_up_shift_temp_ix += up_shift_inds.size + + usti2sti.extend([shifted_temps[j]] * up_shift_inds.size) + usti2ti.extend([i] * up_shift_inds.size) + usti2cui.extend(unique_comp_up_inds) + + up_shift_temp_ix_to_shift_temp_ix = np.array(usti2sti) + up_shift_temp_ix_to_temp_ix = np.array(usti2ti) + up_shift_temp_ix_to_comp_up_ix = np.array(usti2cui) + + return UpsampledShiftedTemplateIndex( + up_shift_temp_ix_to_shift_temp_ix.size, + upsampled_shifted_template_index, + up_shift_temp_ix_to_shift_temp_ix, + up_shift_temp_ix_to_temp_ix, + up_shift_temp_ix_to_comp_up_ix, + ) + + +def compressed_upsampled_pairs( + ix_a, + ix_b, + compression_index, + conv_ix, + temp_ix_b, + shifted_temp_ix_b, + upsampled_shifted_template_index, + compressed_upsampled_temporal, +): + """Add in upsampling to the set of pairs that need to be convolved + + So far, ix_a,b, compression_index, and conv_ix are such that non-upsampled + convolutions between templates ix_a[i], ix_b[i] equal that between templates + ix_a[conv_ix[compression_index[i]]], ix_b[conv_ix[compression_index[i]]]. + + We will upsample the templates in the RHS (b) in a compressed way. + """ + up_factor = compressed_upsampled_temporal.compressed_usampling_map.shape[1] + if up_factor == 1: + upinds = np.zeros(conv_ix.size, dtype=int) + temp_comps = compressed_upsampled_temporal.compressed_upsampled_templates[ + temp_ix_b[ix_b[conv_ix]] + ] + return ix_a, ix_b, compression_index, conv_ix, upinds, temp_comps + + # each conv_ix needs to be duplicated as many times as its b template has + # upsampled copies. And, all ix_{a,b}[i] such that compression_ix[i] lands in + # that conv_ix need to be duplicated as well. 
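# The duplication described above pairs np.repeat (keep each pair's rows
# contiguous) with np.tile (cycle the new compressed indices within each pair);
# a toy illustration with hypothetical values for one representative conv_ix:
import numpy as np

ix_a_for_this_conv = np.array([10, 11])  # pairs whose compression_index hits this conv_ix
nupi = 3                                 # compressed upsampled copies of the b template

print(np.repeat(ix_a_for_this_conv, nupi))                # [10 10 10 11 11 11]
print(np.tile(np.arange(nupi), ix_a_for_this_conv.size))  # [0 1 2 0 1 2]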
+ ix_a_up = [] + ix_b_up = [] + compression_index_up = [] + conv_ix_up = [] + conv_compressed_upsampled_ix = [] + cur_dedup_ix = 0 + for i, convi in enumerate(conv_ix): + # get b's shifted template ix + conv_shifted_temp_ix_b = shifted_temp_ix_b[ix_b[convi]] + + # which compressed upsampled indices match this? + which_up = np.flatnonzero( + upsampled_shifted_template_index.up_shift_temp_ix_to_shift_temp_ix + == conv_shifted_temp_ix_b + ) + conv_comp_up_ix = ( + upsampled_shifted_template_index.up_shift_temp_ix_to_comp_up_ix[which_up] + ) + + # which deduplication indices map ix_a,b to this convi? + which_dedup = np.flatnonzero(compression_index == i) + + # extend arrays with new indices + nupi = conv_comp_up_ix.size + ix_a_up.extend(np.repeat(ix_a[which_dedup], nupi)) + ix_b_up.extend(np.repeat(ix_b[which_dedup], nupi)) + conv_ix_up.extend([convi] * nupi) + compression_index_up.extend( + np.tile(np.arange(cur_dedup_ix, cur_dedup_ix + nupi), which_dedup.size) ) - assert in_pair.size - fine_cconvs = cconv[cconv_ix[in_pair]] - approx_err = (fine_cconvs - conv[None]).abs().max() - if not approx_err < p.coarse_approx_error_threshold: - continue - - # replace first fine cconv with the coarse cconv - cconv[cconv_ix[in_pair[0]]] = conv - # set all fine cconv ix to the index of that first one - cconv_ix[in_pair] = cconv_ix[in_pair[0]] - - # re-index and subset cconvs - cconv_ix_subset, new_cconv_ix = np.unique(cconv_ix, return_inverse=True) - cconv = cconv[cconv_ix_subset] - return cconv, new_cconv_ix + cur_dedup_ix += nupi + conv_compressed_upsampled_ix.extend(conv_comp_up_ix) + + ix_a_up = np.array(ix_a_up) + ix_b_up = np.array(ix_b_up) + compression_index_up = np.array(compression_index_up) + conv_ix_up = np.array(conv_ix_up) + conv_compressed_upsampled_ix = np.array(conv_compressed_upsampled_ix) + + # which upsamples and which templates? + conv_upsampling_indices_b = ( + compressed_upsampled_temporal.compressed_index_to_upsampling_index[ + conv_compressed_upsampled_ix + ] + ) + conv_temporal_components_up_b = ( + compressed_upsampled_temporal.compressed_index_to_upsampling_index[ + conv_compressed_upsampled_ix + ] + ) + + return ( + ix_a_up, + ix_b_up, + compression_index_up, + conv_ix_up, + conv_upsampling_indices_b, + conv_temporal_components_up_b, + ) + + +def coarse_approximate( + pconv, + units_a, + units_b, + temp_ix_a, + shift_a, + shift_b, + coarse_approx_error_threshold=0.0, +): + """Try to replace fine (superres+temporally upsampled) convs with coarse ones + + For each pair of convolved units, we first try to replace all of the pairwise + convolutions between these units with their mean, respecting the shifts. + + If that fails, we try to do this in a factorized way: for each superres unit a, + try to replace all of its convolutions with unit b with their mean, respecting + the shifts. + + Above, "respecting the shifts" means we only do this within each shift-deduplication + class, since changes in the sets of channels being convolved cause large changes + in the cross correlation. pconv has already been deduplicated with respect to + equivalent channel neighborhoods, so all that matters for that purpose is the + shift difference. + + This needs to tell the caller how to update its bookkeeping. 
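# A toy version of the group-mean compression test described in the docstring
# above (standalone helper, not the real API; threshold and data are made up):
import numpy as np
import torch

def compress_group(convs, threshold):
    mean = convs.mean(dim=0, keepdim=True)
    if (convs - mean).abs().max() < threshold:
        # the whole group collapses to one stored convolution
        return mean, np.zeros(len(convs), dtype=int)
    # otherwise keep every member and an identity index map
    return convs, np.arange(len(convs))

group = torch.ones(4, 11) + 1e-4 * torch.randn(4, 11)
stored, old_to_new = compress_group(group, threshold=0.01)
print(stored.shape, old_to_new)  # torch.Size([1, 11]) [0 0 0 0]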
+ """ + new_pconv = [] + old_ix_to_new_ix = np.full(len(pconv), -1) + cur_new_ix = 0 + shift_diff = shift_a - shift_b + for ua in np.unique(units_a): + ina = np.flatnonzero(units_a == ua) + partners_b = np.unique(units_b[ina]) + for ub in partners_b: + inab = ina[units_b[ina] == ub] + dshift = shift_diff[inab] + for shift in np.unique(dshift): + inshift = inab[dshift == shift] + + convs = pconv[inshift] + meanconv = convs.mean(dim=0, keepdims=True) + if (convs - meanconv).abs().max() < coarse_approx_error_threshold: + # do something + new_pconv.append(meanconv) + old_ix_to_new_ix[inshift] = cur_new_ix + cur_new_ix += 1 + continue + # else: + # new_pconv.append(convs) + # old_ix_to_new_ix[inshift] = np.arange(cur_new_ix, cur_new_ix + inshift.size) + # cur_new_ix += inshift.size + + active_temp_a = temp_ix_a[inshift] + unique_active_temp_a = np.unique(active_temp_a) + if unique_active_temp_a.size == 1: + new_pconv.append(convs) + old_ix_to_new_ix[inshift] = np.arange( + cur_new_ix, cur_new_ix + inshift.size + ) + cur_new_ix += inshift.size + continue + + for tixa in unique_active_temp_a: + insup = active_temp_a == tixa + supconvs = convs[insup] + + meanconv = supconvs.mean(dim=0, keepdims=True) + if (convs - meanconv).abs().max() < coarse_approx_error_threshold: + new_pconv.append(meanconv) + old_ix_to_new_ix[insup] = cur_new_ix + cur_new_ix += 1 + else: + new_pconv.append(supconvs) + old_ix_to_new_ix[insup] = np.arange( + cur_new_ix, cur_new_ix + insup.size + ) + cur_new_ix += insup.size + + new_pconv = torch.cat(new_pconv) + return new_pconv, old_ix_to_new_ix + + +# -- parallelism helpers + + +@dataclass +class ConvWorkerContext: + template_data: templates.TemplateData + low_rank_templates: template_util.LowRankTemplates + compressed_upsampled_temporal: template_util.CompressedUpsampledTemplates + template_shift_index: drift_util.TemplateShiftIndex + upsampled_shifted_template_index: UpsampledShiftedTemplateIndex + geom: Optional[np.ndarray] = None + reg_geom: Optional[np.ndarray] = None + geom_kdtree: Optional[KDTree] = None + reg_geom_kdtree: Optional[KDTree] = None + match_distance: Optional[float] = None + conv_ignore_threshold = 0.0 + coarse_approx_error_threshold = 0.0 + max_shift = "full" + batch_size = 128 + device = None + + +_conv_worker_context = None + + +def _conv_worker_init(rank_queue, device, kwargs): + global _conv_worker_context + + my_rank = rank_queue.get() + if device is None: + device = "cuda" if torch.cuda.is_available() else "cpu" + device = torch.device(device) + if device.type == "cuda" and device.index is None: + if torch.cuda.device_count() > 1: + device = torch.device("cuda", index=my_rank % torch.cuda.device_count()) + + _conv_worker_context = ConvWorkerContext(device=device, **kwargs) + + +def _conv_job(unit_chunk): + global _pairwise_conv_context + units_a, units_b = unit_chunk + return compressed_convolve_pairs( + units_a=units_a, units_b=units_b, **asdict_shallow(_pairwise_conv_context) + ) + + +def asdict_shallow(obj): + return {field.name: getattr(obj, field.name) for field in fields(obj)} From c944ef7fbcec89e703068aa17f10e94f8bbc58ee Mon Sep 17 00:00:00 2001 From: Charlie Windolf Date: Mon, 6 Nov 2023 14:19:55 -0500 Subject: [PATCH 16/49] Update template tests --- tests/test_templates.py | 41 +++++++++++++++-------------------------- 1 file changed, 15 insertions(+), 26 deletions(-) diff --git a/tests/test_templates.py b/tests/test_templates.py index 71fc44e9..3a5896f3 100644 --- a/tests/test_templates.py +++ b/tests/test_templates.py @@ -177,6 +177,7 
@@ def test_pconv(): overlaps[(1, 2)] = overlaps[(2, 1)] = (temps[1] * temps[2]).sum() overlaps[(2, 3)] = overlaps[(3, 2)] = (temps[3] * temps[2]).sum() + print(f"--------- no drift") tdata = templates.TemplateData( templates=temps, unit_ids=np.array([0, 0, 1, 1, 2]), @@ -185,7 +186,6 @@ def test_pconv(): registered_template_depths_um=None, ) temp, sv, spat = template_util.svd_compress_templates(temps, rank=1) - print(f"{temp=} {sv=} {spat=}") tempup = temp.reshape(5, t, 1, 1) with tempfile.TemporaryDirectory() as tdir: @@ -200,6 +200,7 @@ def test_pconv(): ) pconvdb = pairwise.SparsePairwiseConv.from_h5(pconvdb_path) assert np.all(pconvdb.pconv[0] == 0) + print(f"{pconvdb.pconv.shape=}") for tixa in range(5): for tixb in range(5): @@ -217,10 +218,9 @@ def test_pconv(): # drifting version # rigid drift from -1 to 0 to 1, note pitch=1 # same templates but padded + print(f"--------- rigid drift") tempspad = np.pad(temps, [(0, 0), (0, 0), (1, 1)]) - print(f"{tempspad.shape=}") temp, sv, spat = template_util.svd_compress_templates(tempspad, rank=1) - print(f"{temp.shape=} {sv.shape=} {spat.shape=}") reg_geom = np.c_[np.zeros(c + 2), np.arange(c + 2).astype(float)] tdata = templates.TemplateData( templates=tempspad, @@ -233,28 +233,17 @@ def test_pconv(): motion_est = get_motion_estimate(time_bin_centers_s=np.array([0., 1, 2]), displacement=[-1., 0, 1]) # visualize shifted temps - for tix in range(5): - print("------------------") - print(f"{tix=}") - for shift in (-1, 0, 1): - spatial_shifted = drift_util.get_waveforms_on_static_channels( - spat[tix][None], - reg_geom, - n_pitches_shift=np.array([shift]), - registered_geom=geom, - fill_value=0.0, - ) - print(f"{shift=}") - print(f"{spatial_shifted=}") - - print() - print() - print('-=' * 30) - print('=-' * 30) - print('-=' * 30) - print('=-' * 30) - print() - print() + # for tix in range(5): + # for shift in (-1, 0, 1): + # spatial_shifted = drift_util.get_waveforms_on_static_channels( + # spat[tix][None], + # reg_geom, + # n_pitches_shift=np.array([shift]), + # registered_geom=geom, + # fill_value=0.0, + # ) + # print(f"{shift=}") + # print(f"{spatial_shifted=}") with tempfile.TemporaryDirectory() as tdir: pconvdb_path = pairwise.sparse_pairwise_conv( @@ -270,7 +259,7 @@ def test_pconv(): ) pconvdb = pairwise.SparsePairwiseConv.from_h5(pconvdb_path) assert np.all(pconvdb.pconv[0] == 0) - + print(f"{pconvdb.pconv.shape=}") print(f"{pconvdb.template_shift_index=}") for tixa in range(5): From 443b61a12ef6c69fc8ded107bc342832ee16dbd4 Mon Sep 17 00:00:00 2001 From: Charlie Windolf Date: Mon, 6 Nov 2023 15:23:08 -0500 Subject: [PATCH 17/49] Fix up tests --- src/dartsort/templates/pairwise.py | 3 ++ src/dartsort/templates/pairwise_util.py | 71 +++++++++++++------------ tests/test_templates.py | 45 ++++++++-------- 3 files changed, 63 insertions(+), 56 deletions(-) diff --git a/src/dartsort/templates/pairwise.py b/src/dartsort/templates/pairwise.py index 8ff9e534..6c958c72 100644 --- a/src/dartsort/templates/pairwise.py +++ b/src/dartsort/templates/pairwise.py @@ -60,6 +60,9 @@ def query( shifts_b=None, return_zero_convs=False, ): + template_indices_a = np.atleast_1d(template_indices_a) + template_indices_b = np.atleast_1d(template_indices_b) + # handle no shifting no_shifting = shifts_a is None or shifts_b is None shifted_template_index = self.shifted_template_index diff --git a/src/dartsort/templates/pairwise_util.py b/src/dartsort/templates/pairwise_util.py index a3cf3516..930c9b15 100644 --- a/src/dartsort/templates/pairwise_util.py +++ 
b/src/dartsort/templates/pairwise_util.py @@ -3,7 +3,7 @@ from collections import namedtuple from dataclasses import dataclass, fields from pathlib import Path -from typing import Iterator, Optional +from typing import Iterator, Optional, Union import h5py import numpy as np @@ -59,10 +59,15 @@ def compressed_convolve_to_h5( template_shift_index, compressed_upsampled_temporal ) + print(f"{template_shift_index=}") + print(f"{upsampled_shifted_template_index=}") + chunk_res_iterator = iterate_compressed_pairwise_convolutions( template_data=template_data, low_rank_templates=low_rank_templates, compressed_upsampled_temporal=compressed_upsampled_temporal, + template_shift_index=template_shift_index, + upsampled_shifted_template_index=upsampled_shifted_template_index, geom=geom, reg_geom=reg_geom, conv_ignore_threshold=conv_ignore_threshold, @@ -118,7 +123,7 @@ def compressed_convolve_to_h5( # store new pconvs n_new_pconvs = chunk_res.compressed_conv.shape[0] pconv.resize(n_pconvs + n_new_pconvs, axis=0) - pconv[n_pconvs:] = chunk_res.pconv + pconv[n_pconvs:] = chunk_res.compressed_conv n_pconvs += n_new_pconvs @@ -163,6 +168,8 @@ def iterate_compressed_pairwise_convolutions( do_shifting = n_shifts > 1 geom_kdtree = reg_geom_kdtree = match_distance = None if do_shifting: + assert geom is not None + assert reg_geom is not None geom_kdtree = KDTree(geom) reg_geom_kdtree = KDTree(reg_geom) match_distance = pdist(geom).min() / 2 @@ -196,7 +203,6 @@ def iterate_compressed_pairwise_convolutions( coarse_approx_error_threshold=coarse_approx_error_threshold, max_shift=max_shift, batch_size=conv_batch_size, - device=device, ) n_jobs, Executor, context, rank_queue = get_pool(n_jobs, with_rank_queue=True) @@ -351,18 +357,18 @@ def compressed_convolve_pairs( # run convolutions temporal_a = low_rank_templates.temporal_components[temp_ix_a] pconv, kept = correlate_pairs_lowrank( - spatial_singular_a[ix_a[conv_ix]].to(device), - spatial_singular_b[ix_b[conv_ix]].to(device), - temporal_a[ix_a[conv_ix]].to(device), - conv_temporal_components_up_b.to(device), + torch.as_tensor(spatial_singular_a[ix_a[conv_ix]]).to(device), + torch.as_tensor(spatial_singular_b[ix_b[conv_ix]]).to(device), + torch.as_tensor(temporal_a[ix_a[conv_ix]]).to(device), + torch.as_tensor(conv_temporal_components_up_b).to(device), max_shift=max_shift, conv_ignore_threshold=conv_ignore_threshold, batch_size=batch_size, ) - if not kept.size: - return None kept_pairs = np.isin(conv_ix[compression_index], conv_ix[kept]) conv_ix = conv_ix[kept] + if not conv_ix.size: + return None compression_index = compression_index[kept_pairs] ix_a = ix_a[kept_pairs] ix_b = ix_b[kept_pairs] @@ -463,17 +469,17 @@ def correlate_pairs_lowrank( conv_filt = torch.bmm(spatial_b[ix], template_a.mT) conv_filt = conv_filt[:, None] # (nco, 1, rank, t) - # nup, nco, rank, t - conv_in = temporal_b[ix].permute(2, 0, 3, 1) + # 1, nco, rank, t + conv_in = temporal_b[ix].mT[None] # conv2d: - # depthwise, chans=nco. batch=1. h=rank. w=t. out: nup, nco, 1, 2p+1. + # depthwise, chans=nco. batch=1. h=rank. w=t. out: nup=1, nco, 1, 2p+1. # input (conv_in): nup, nco, rank, t. # filters (conv_filt): nco, 1, rank, t. (groups=nco). 
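# Quick shape check for the conv_in change above: a batch of temporal factors
# (nco, t, rank) becomes a single conv2d input with nco groups, height=rank,
# width=t (sizes here are hypothetical).
import torch

temporal_b = torch.randn(5, 121, 8)  # (nco, t, rank)
conv_in = temporal_b.mT[None]        # transpose the last two dims, add a batch dim
print(conv_in.shape)                 # torch.Size([1, 5, 8, 121]) == (1, nco, rank, t)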
pconv_ = F.conv2d( conv_in, conv_filt, padding=(0, max_shift), groups=iend - istart ) - pconv[istart:iend] = pconv_[:, :, 0, :].permute(1, 0, 2) # nco, nup, time + pconv[istart:iend] = pconv_[0, :, 0, :] # nco, nup, time # more stringent covisibility kept = slice(None) @@ -555,8 +561,6 @@ def shift_deduplicated_pairs( Some pairs of shifted templates don't overlap, so we don't need to convolve them. Some pairs of shifted templates never show up in the recording at the same time (what this code calls "cooccurrence"), so we don't need to convolve them. - We don't need to convolve the same pair of templates twice, just where the indices - are ordered (shifted_temp_ix_a <= shifted_temp_ix_b). More complicated: for each shift, a certain set of registered template channels survives. Given that the some set of visible channels has survived for a pair of @@ -591,8 +595,6 @@ def shift_deduplicated_pairs( # co-occurrence pair *= template_shift_index.cooccurrence - # mask out lower triangle - pair *= shifted_temp_ix_a[:, None] <= shifted_temp_ix_b[None, :] pair_ix_a, pair_ix_b = torch.nonzero(pair, as_tuple=True) nco = pair_ix_a.numel() if not nco: @@ -601,7 +603,6 @@ def shift_deduplicated_pairs( # if no shifting, deduplication is the identity do_shifting = template_shift_index.all_pitch_shifts.size > 1 if not do_shifting: - assert shift_b is None nco_range = torch.arange(nco, device=pair_ix_a.device) return pair_ix_a, pair_ix_b, nco_range, nco_range @@ -696,7 +697,7 @@ def get_upsampled_shifted_template_index( """ n_shifted_templates = template_shift_index.n_shifted_templates n_templates, n_shifts = template_shift_index.template_shift_index.shape - max_upsample = compressed_upsampled_temporal.compressed_usampling_map.shape[1] + max_upsample = compressed_upsampled_temporal.compressed_upsampling_map.shape[1] cur_up_shift_temp_ix = 0 # fill with an invalid index @@ -710,11 +711,11 @@ def get_upsampled_shifted_template_index( shifted_temps = template_shift_index.template_shift_index[i] valid_shifts = np.flatnonzero(shifted_temps < n_shifted_templates) - upsampled_temps = compressed_upsampled_temporal.compressed_usampling_map[i] + upsampled_temps = compressed_upsampled_temporal.compressed_upsampling_map[i] unique_comp_up_inds, inverse = np.unique(upsampled_temps, return_inverse=True) for j in valid_shifts: - up_shift_inds = unique_comp_up_inds + cur_up_shift_temp_ix + up_shift_inds = cur_up_shift_temp_ix + np.arange(unique_comp_up_inds.size) upsampled_shifted_template_index[i, j] = up_shift_inds[inverse] cur_up_shift_temp_ix += up_shift_inds.size @@ -753,9 +754,9 @@ def compressed_upsampled_pairs( We will upsample the templates in the RHS (b) in a compressed way. """ - up_factor = compressed_upsampled_temporal.compressed_usampling_map.shape[1] + up_factor = compressed_upsampled_temporal.compressed_upsampling_map.shape[1] if up_factor == 1: - upinds = np.zeros(conv_ix.size, dtype=int) + upinds = np.zeros(len(conv_ix), dtype=int) temp_comps = compressed_upsampled_temporal.compressed_upsampled_templates[ temp_ix_b[ix_b[conv_ix]] ] @@ -873,9 +874,11 @@ def coarse_approximate( cur_new_ix += 1 continue # else: + # # if we don't want the factorized thing... 
# new_pconv.append(convs) # old_ix_to_new_ix[inshift] = np.arange(cur_new_ix, cur_new_ix + inshift.size) # cur_new_ix += inshift.size + # continue active_temp_a = temp_ix_a[inshift] unique_active_temp_a = np.unique(active_temp_a) @@ -894,14 +897,14 @@ def coarse_approximate( meanconv = supconvs.mean(dim=0, keepdims=True) if (convs - meanconv).abs().max() < coarse_approx_error_threshold: new_pconv.append(meanconv) - old_ix_to_new_ix[insup] = cur_new_ix + old_ix_to_new_ix[inshift[insup]] = cur_new_ix cur_new_ix += 1 else: new_pconv.append(supconvs) - old_ix_to_new_ix[insup] = np.arange( - cur_new_ix, cur_new_ix + insup.size + old_ix_to_new_ix[inshift[insup]] = np.arange( + cur_new_ix, cur_new_ix + insup.sum() ) - cur_new_ix += insup.size + cur_new_ix += insup.sum() new_pconv = torch.cat(new_pconv) return new_pconv, old_ix_to_new_ix @@ -922,11 +925,11 @@ class ConvWorkerContext: geom_kdtree: Optional[KDTree] = None reg_geom_kdtree: Optional[KDTree] = None match_distance: Optional[float] = None - conv_ignore_threshold = 0.0 - coarse_approx_error_threshold = 0.0 - max_shift = "full" - batch_size = 128 - device = None + conv_ignore_threshold: float = 0.0 + coarse_approx_error_threshold: float = 0.0 + max_shift: Union[int, str] = "full" + batch_size: int = 128 + device: Optional[torch.device] = None _conv_worker_context = None @@ -947,10 +950,10 @@ def _conv_worker_init(rank_queue, device, kwargs): def _conv_job(unit_chunk): - global _pairwise_conv_context + global _conv_worker_context units_a, units_b = unit_chunk return compressed_convolve_pairs( - units_a=units_a, units_b=units_b, **asdict_shallow(_pairwise_conv_context) + units_a=units_a, units_b=units_b, **asdict_shallow(_conv_worker_context) ) diff --git a/tests/test_templates.py b/tests/test_templates.py index 3a5896f3..90e840ee 100644 --- a/tests/test_templates.py +++ b/tests/test_templates.py @@ -4,8 +4,8 @@ import numpy as np import spikeinterface.core as sc from dartsort import config -from dartsort.templates import (get_templates, pairwise, template_util, - templates) +from dartsort.templates import (get_templates, pairwise, pairwise_util, + template_util, templates) from dartsort.util import drift_util from dartsort.util.data_util import DARTsortSorting from dredge.motion_util import get_motion_estimate @@ -185,20 +185,23 @@ def test_pconv(): registered_geom=None, registered_template_depths_um=None, ) - temp, sv, spat = template_util.svd_compress_templates(temps, rank=1) - tempup = temp.reshape(5, t, 1, 1) + svd_compressed = template_util.svd_compress_templates(temps, rank=1) + ctempup = template_util.compressed_upsampled_templates( + svd_compressed.temporal_components, + ptps=temps.ptp(1).max(1), + max_upsample=1, + kind="cubic", + ) with tempfile.TemporaryDirectory() as tdir: - pconvdb_path = pairwise.sparse_pairwise_conv( + pconvdb_path = pairwise_util.compressed_convolve_to_h5( Path(tdir) / "test.h5", - geom, - tdata, - temp, - tempup, - sv, - spat, + geom=geom, + template_data=tdata, + low_rank_templates=svd_compressed, + compressed_upsampled_temporal=ctempup, ) - pconvdb = pairwise.SparsePairwiseConv.from_h5(pconvdb_path) + pconvdb = pairwise.CompressedPairwiseConv.from_h5(pconvdb_path) assert np.all(pconvdb.pconv[0] == 0) print(f"{pconvdb.pconv.shape=}") @@ -220,7 +223,7 @@ def test_pconv(): # same templates but padded print(f"--------- rigid drift") tempspad = np.pad(temps, [(0, 0), (0, 0), (1, 1)]) - temp, sv, spat = template_util.svd_compress_templates(tempspad, rank=1) + svd_compressed = 
template_util.svd_compress_templates(tempspad, rank=1) reg_geom = np.c_[np.zeros(c + 2), np.arange(c + 2).astype(float)] tdata = templates.TemplateData( templates=tempspad, @@ -246,21 +249,19 @@ def test_pconv(): # print(f"{spatial_shifted=}") with tempfile.TemporaryDirectory() as tdir: - pconvdb_path = pairwise.sparse_pairwise_conv( + pconvdb_path = pairwise_util.compressed_convolve_to_h5( Path(tdir) / "test.h5", - geom, - tdata, - temp, - tempup, - sv, - spat, + geom=geom, + reg_geom=reg_geom, + template_data=tdata, + low_rank_templates=svd_compressed, + compressed_upsampled_temporal=ctempup, motion_est=motion_est, chunk_time_centers_s=[0, 1, 2], ) - pconvdb = pairwise.SparsePairwiseConv.from_h5(pconvdb_path) + pconvdb = pairwise.CompressedPairwiseConv.from_h5(pconvdb_path) assert np.all(pconvdb.pconv[0] == 0) print(f"{pconvdb.pconv.shape=}") - print(f"{pconvdb.template_shift_index=}") for tixa in range(5): for tixb in range(5): From 2ecc274095e32d1e4547ea881b8d35c221287742 Mon Sep 17 00:00:00 2001 From: Charlie Windolf Date: Mon, 6 Nov 2023 15:41:14 -0500 Subject: [PATCH 18/49] Fix tests after merge --- src/dartsort/localize/localize_torch.py | 121 ++++++++++++------------ 1 file changed, 61 insertions(+), 60 deletions(-) diff --git a/src/dartsort/localize/localize_torch.py b/src/dartsort/localize/localize_torch.py index 667032ae..46f3dc44 100644 --- a/src/dartsort/localize/localize_torch.py +++ b/src/dartsort/localize/localize_torch.py @@ -69,8 +69,7 @@ def localize_amplitude_vectors( assert channel_index.shape == (n_channels_tot, c) assert main_channels.shape == (n_spikes,) # we'll return numpy if user sent numpy - is_numpy = not torch.is_tensor(amplitude_vectors) - + is_numpy = not torch.is_tensor(amplitude_vectors) # handle channel subsetting if radius is not None or n_channels_subset is not None: @@ -114,26 +113,22 @@ def localize_amplitude_vectors( if model == "com": z_abs_com = zcom + geom[main_channels, 1] - nancom = torch.full_like(xcom, torch.nan) - return dict( - x=xcom, y=nancom, z_rel=zcom, z_abs=z_abs_com, alpha=nancom - ) + return dict(x=xcom, z_rel=zcom, z_abs=z_abs_com) # normalized PTP vectors # this helps to keep the objective in a similar range, so we can use # fixed constants in regularizers like the log barrier max_amplitudes = torch.max(amplitude_vectors, dim=1).values normalized_amp_vecs = amplitude_vectors / max_amplitudes[:, None] - + # -- torch optimize + if levenberg_marquardt_kwargs is None: + levenberg_marquardt_kwargs = {} + # initialize with center of mass locs = torch.column_stack((xcom, torch.full_like(xcom, y0), zcom)) - if model == "pointsource": - - if levenberg_marquardt_kwargs is None: - levenberg_marquardt_kwargs = {} locs, i = batched_levenberg_marquardt( locs, vmap_point_source_grad_and_mse, @@ -149,18 +144,13 @@ def localize_amplitude_vectors( amplitude_vectors, in_probe_mask, x, y, z_rel, local_geoms ) z_abs = z_rel + geom[main_channels, 1] - + + results = dict(x=x, y=y, z_rel=z_rel, z_abs=z_abs, alpha=alpha) if is_numpy: - x = x.numpy(force=True) - y = y.numpy(force=True) - z_rel = z_rel.numpy(force=True) - z_abs = z_abs.numpy(force=True) - alpha = alpha.numpy(force=True) - return dict(x=x, y=y, z_rel=z_rel, z_abs=z_abs, alpha=alpha) - - if model == "dipole": - if levenberg_marquardt_kwargs is None: - levenberg_marquardt_kwargs = {} + results = {k: v.numpy(force=True) for k, v in results.items()} + return results + + elif model == "dipole": locs, i = batched_levenberg_marquardt( locs, vmap_dipole_grad_and_mse, @@ -168,33 +158,39 @@ def 
localize_amplitude_vectors( extra_args=(normalized_amp_vecs, local_geoms), **levenberg_marquardt_kwargs, ) - + x, y0, z_rel = locs.T y = F.softplus(y0) projected_dist = vmap_dipole_find_projection_distance( normalized_amp_vecs, x, y, z_rel, local_geoms - ) - + ) + # if projected_dist>th_dipole_proj_dist: return the loc values from pointsource - pointsource_spikes = torch.nonzero(projected_dist>th_dipole_proj_dist, as_tuple=True) - + pointsource_spikes = torch.nonzero( + projected_dist > th_dipole_proj_dist, as_tuple=True + ) + locs_pointsource_spikes, i = batched_levenberg_marquardt( locs[pointsource_spikes], vmap_point_source_grad_and_mse, vmap_point_source_hessian, - extra_args=(normalized_amp_vecs[pointsource_spikes], in_probe_mask, local_geoms[pointsource_spikes]), + extra_args=( + normalized_amp_vecs[pointsource_spikes], + in_probe_mask, + local_geoms[pointsource_spikes], + ), **levenberg_marquardt_kwargs, ) x_pointsource_spikes, y0_pointsource_spikes, z_rel_pointsource_spikes = locs.T y_pointsource_spikes = F.softplus(y0_pointsource_spikes) - + x[pointsource_spikes] = x_pointsource_spikes y[pointsource_spikes] = y_pointsource_spikes z_rel[pointsource_spikes] = z_rel_pointsource_spikes - + z_abs = z_rel + geom[main_channels, 1] - + if is_numpy: x = x.numpy(force=True) y = y.numpy(force=True) @@ -204,9 +200,14 @@ def localize_amplitude_vectors( return dict(x=x, y=y, z_rel=z_rel, z_abs=z_abs, alpha=projected_dist) + else: + assert False + # -- point source / dipole model library functions -def point_source_amplitude_at(x, y, z, local_geom): + + +def point_source_amplitude_at(x, y, z, alpha, local_geom): """Point source model predicted amplitude at local_geom given location""" dxs = torch.square(x - local_geom[:, 0]) dzs = torch.square(z - local_geom[:, 1]) @@ -223,25 +224,8 @@ def point_source_find_alpha(amp_vec, channel_mask, x, y, z, local_geoms): ) return alpha -def dipole_find_projection_distance(normalized_amp_vec, x, y, z, local_geom): - """We can solve for the brightness (alpha) of the source in closed form given x,y,z""" - - dxs = x - local_geom[:, 0] - dzs = z - local_geom[:, 1] - dys = y - duv = torch.tensor([dxs, dys, dzs]) - X = duv / torch.pow(torch.sum(torch.square(duv)), 3/2) - beta = torch.linalg.solve(torch.matmul(X.T, X), torch.matmul(X.T, normalized_amp_vec)) - beta /= torch.sqrt(torch.square(beta).sum()) - dipole_planar_direction = torch.sqrt(np.torch(beta[[0, 2]]).sum()) - closest_chan = torch.square(duv).sum(1).argmin() - min_duv = duv[closest_chan] - val_th = torch.sqrt(torch.square(min_duv).sum())/dipole_planar_direction - return val_th -def point_source_mse( - loc, amplitude_vector, channel_mask, local_geom, logbarrier=True -): +def point_source_mse(loc, amplitude_vector, channel_mask, local_geom, logbarrier=True): """Objective in point source model Arguments @@ -264,12 +248,9 @@ def point_source_mse( x, y0, z = loc y = F.softplus(y0) - alpha = point_source_find_alpha( - amplitude_vector, channel_mask, x, y, z, local_geom - ) + alpha = point_source_find_alpha(amplitude_vector, channel_mask, x, y, z, local_geom) obj = torch.square( - amplitude_vector - - point_source_amplitude_at(x, y, z, alpha, local_geom) + amplitude_vector - point_source_amplitude_at(x, y, z, alpha, local_geom) ).mean() if logbarrier: obj -= torch.log(10.0 * y) / 10000.0 @@ -277,23 +258,43 @@ def point_source_mse( # obj -= torch.log(1000.0 - torch.sqrt(torch.square(x) + torch.square(z))).sum() / 10000.0 return obj + +def dipole_find_projection_distance(normalized_amp_vec, x, y, z, local_geom): 
+ """We can solve for the brightness (alpha) of the source in closed form given x,y,z""" + + dxs = x - local_geom[:, 0] + dzs = z - local_geom[:, 1] + dys = y + duv = torch.tensor([dxs, dys, dzs]) + X = duv / torch.pow(torch.sum(torch.square(duv)), 3 / 2) + beta = torch.linalg.solve( + torch.matmul(X.T, X), torch.matmul(X.T, normalized_amp_vec) + ) + beta /= torch.sqrt(torch.square(beta).sum()) + dipole_planar_direction = torch.sqrt(np.torch(beta[[0, 2]]).sum()) + closest_chan = torch.square(duv).sum(1).argmin() + min_duv = duv[closest_chan] + val_th = torch.sqrt(torch.square(min_duv).sum()) / dipole_planar_direction + return val_th + + def dipole_mse(loc, amplitude_vector, local_geom, logbarrier=True): """Dipole model predicted amplitude at local_geom given location""" - + x, y0, z = loc y = F.softplus(y0) dxs = x - local_geom[:, 0] dzs = z - local_geom[:, 1] dys = y - + duv = torch.tensor([dxs, dys, dzs]) - X = duv / torch.pow(torch.sum(torch.square(duv)), 3/2) - + X = duv / torch.pow(torch.sum(torch.square(duv)), 3 / 2) + beta = torch.linalg.solve(torch.matmul(X.T, X), torch.matmul(X.T, (ptp / maxptp))) qtq = torch.matmul(X, beta) - + obj = torch.square(ptp / maxptp - qtq).mean() if logbarrier: obj -= torch.log(10.0 * y) / 10000.0 From eb23cd34f1e85474a45b2ba79d215f957534d174 Mon Sep 17 00:00:00 2001 From: Charlie Windolf Date: Mon, 6 Nov 2023 15:47:05 -0500 Subject: [PATCH 19/49] reg_geom unneeded --- src/dartsort/templates/pairwise_util.py | 4 +--- tests/test_templates.py | 1 - 2 files changed, 1 insertion(+), 4 deletions(-) diff --git a/src/dartsort/templates/pairwise_util.py b/src/dartsort/templates/pairwise_util.py index 930c9b15..9bdfef7d 100644 --- a/src/dartsort/templates/pairwise_util.py +++ b/src/dartsort/templates/pairwise_util.py @@ -26,7 +26,6 @@ def compressed_convolve_to_h5( chunk_time_centers_s: Optional[np.ndarray] = None, motion_est=None, geom: Optional[np.ndarray] = None, - reg_geom: Optional[np.ndarray] = None, conv_ignore_threshold=0.0, coarse_approx_error_threshold=0.0, conv_batch_size=128, @@ -69,7 +68,6 @@ def compressed_convolve_to_h5( template_shift_index=template_shift_index, upsampled_shifted_template_index=upsampled_shifted_template_index, geom=geom, - reg_geom=reg_geom, conv_ignore_threshold=conv_ignore_threshold, coarse_approx_error_threshold=coarse_approx_error_threshold, max_shift="full", @@ -143,7 +141,6 @@ def iterate_compressed_pairwise_convolutions( template_shift_index: drift_util.TemplateShiftIndex, upsampled_shifted_template_index: UpsampledShiftedTemplateIndex, geom: Optional[np.ndarray] = None, - reg_geom: Optional[np.ndarray] = None, conv_ignore_threshold=0.0, coarse_approx_error_threshold=0.0, max_shift="full", @@ -167,6 +164,7 @@ def iterate_compressed_pairwise_convolutions( n_shifts = template_shift_index.all_pitch_shifts.size do_shifting = n_shifts > 1 geom_kdtree = reg_geom_kdtree = match_distance = None + reg_geom = template_data.registered_geom if do_shifting: assert geom is not None assert reg_geom is not None diff --git a/tests/test_templates.py b/tests/test_templates.py index 90e840ee..e574dde3 100644 --- a/tests/test_templates.py +++ b/tests/test_templates.py @@ -252,7 +252,6 @@ def test_pconv(): pconvdb_path = pairwise_util.compressed_convolve_to_h5( Path(tdir) / "test.h5", geom=geom, - reg_geom=reg_geom, template_data=tdata, low_rank_templates=svd_compressed, compressed_upsampled_temporal=ctempup, From c8cbe8f85ad8352fee13cd7a31cdcc77b31741d3 Mon Sep 17 00:00:00 2001 From: Charlie Windolf Date: Tue, 7 Nov 2023 12:46:31 -0500 
Subject: [PATCH 20/49] Pairwise convs working --- src/dartsort/templates/pairwise_util.py | 128 +++++++++++++++++++----- src/dartsort/templates/template_util.py | 10 +- src/dartsort/util/drift_util.py | 4 +- 3 files changed, 114 insertions(+), 28 deletions(-) diff --git a/src/dartsort/templates/pairwise_util.py b/src/dartsort/templates/pairwise_util.py index 9bdfef7d..05e0deed 100644 --- a/src/dartsort/templates/pairwise_util.py +++ b/src/dartsort/templates/pairwise_util.py @@ -58,9 +58,6 @@ def compressed_convolve_to_h5( template_shift_index, compressed_upsampled_temporal ) - print(f"{template_shift_index=}") - print(f"{upsampled_shifted_template_index=}") - chunk_res_iterator = iterate_compressed_pairwise_convolutions( template_data=template_data, low_rank_templates=low_rank_templates, @@ -108,11 +105,13 @@ def compressed_convolve_to_h5( ] # upsampled shifted template indices for B - up_shifted_temp_ix_b = upsampled_shifted_template_index.upsampled_shifted_template_index[ - chunk_res.template_indices_b, - chunk_res.shift_indices_b, - chunk_res.upsampling_indices_b, - ] + up_shifted_temp_ix_b = ( + upsampled_shifted_template_index.upsampled_shifted_template_index[ + chunk_res.template_indices_b, + chunk_res.shift_indices_b, + chunk_res.upsampling_indices_b, + ] + ) # store new set of indices new_pconv_indices = chunk_res.compression_index + n_pconvs @@ -127,8 +126,13 @@ def compressed_convolve_to_h5( # write fixed size outputs h5.create_dataset("shifts", data=template_shift_index.all_pitch_shifts) - h5.create_dataset("shifted_template_index", data=template_shift_index.template_shift_index) - h5.create_dataset("upsampled_shifted_template_index", data=upsampled_shifted_template_index.upsampled_shifted_template_index) + h5.create_dataset( + "shifted_template_index", data=template_shift_index.template_shift_index + ) + h5.create_dataset( + "upsampled_shifted_template_index", + data=upsampled_shifted_template_index.upsampled_shifted_template_index, + ) h5.create_dataset("pconv_index", data=pconv_index) return output_hdf5_filename @@ -180,10 +184,6 @@ def iterate_compressed_pairwise_convolutions( for start_b in range(start_a, units.size, units_batch_size): end_b = min(start_b + units_batch_size, units.size) jobs.append((units[start_a:end_a], units[start_b:end_b])) - if show_progress: - jobs = tqdm( - jobs, smoothing=0.01, desc="Pairwise convolution", unit="pair block" - ) # worker kwargs kwargs = dict( @@ -210,7 +210,16 @@ def iterate_compressed_pairwise_convolutions( initializer=_conv_worker_init, initargs=(rank_queue, device, kwargs), ) as pool: - yield from pool.map(_conv_job, jobs) + it = pool.map(_conv_job, jobs) + if show_progress: + it = tqdm( + it, + smoothing=0.01, + desc="Pairwise convolution", + unit="pair block", + total=len(jobs), + ) + yield from it @dataclass @@ -271,6 +280,10 @@ def compressed_convolve_pairs( shifts, superres templates, and upsamples. Some of these may be zero or may be duplicates, so the return value is a sparse representation. See below. """ + # print(f"{units_a.shape=}") + # print(f"{units_b.shape=}") + # print(f"{(units_a.size * units_b.size)=}") + # what pairs, shifts, etc are we convolving? 
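    # A minimal illustration of the quantity being assembled here, with
    # hypothetical shapes (t samples, c channels) and plain torch rather than
    # the batched low-rank code in this file: the pairwise "pconv" of two
    # templates is their cross-correlation at every lag, summed over channels.
    #
    #     import torch
    #     import torch.nn.functional as F
    #
    #     t, c = 121, 40
    #     temp_a = torch.randn(t, c)
    #     temp_b = torch.randn(t, c)
    #     # torch's conv1d is cross-correlation; padding=t-1 keeps all 2t-1 lags
    #     pconv = F.conv1d(
    #         temp_a.T[None],  # (1, c, t): channels become conv input channels
    #         temp_b.T[None],  # (1, c, t): single output channel sums over c
    #         padding=t - 1,
    #     )[0, 0]              # shape (2t - 1,)
    #
    # The code below computes the same thing for many (template, shift,
    # upsampling) combinations at once from the SVD-compressed templates, and
    # then deduplicates the redundant results into a sparse representation.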
shifted_temp_ix_a, temp_ix_a, shift_a, unit_a = handle_shift_indices( units_a, template_data.unit_ids, template_shift_index @@ -278,6 +291,8 @@ def compressed_convolve_pairs( shifted_temp_ix_b, temp_ix_b, shift_b, unit_b = handle_shift_indices( units_b, template_data.unit_ids, template_shift_index ) + # print(f"{shifted_temp_ix_a.shape=}") + # print(f"{shifted_temp_ix_b.shape=}") # get (shifted) spatial components * singular values spatial_singular_a = get_shifted_spatial_singular( @@ -323,6 +338,15 @@ def compressed_convolve_pairs( if pairs_ret is None: return None ix_a, ix_b, compression_index, conv_ix = pairs_ret + # print(f"A {ix_a.shape=}") + # print(f"A {ix_b.shape=}") + # print(f"A {compression_index.shape=}") + # print(f"A {conv_ix.shape=}") + + # print(f"-----------") + # print(f"after pairs {conv_ix.shape=} {compression_index.shape=}") + # print(f"{compression_index.min()=} {compression_index.max()=}") + # print(f"{ix_a.shape=} {ix_b.shape=}") # handle upsampling # each pair will be duplicated by the b unit's number of upsampled copies @@ -343,6 +367,15 @@ def compressed_convolve_pairs( upsampled_shifted_template_index, compressed_upsampled_temporal, ) + # print(f"B {ix_a.shape=}") + # print(f"B {ix_b.shape=}") + # print(f"B {compression_index.shape=}") + # print(f"B {conv_ix.shape=}") + + # print(f"-----------") + # print(f"after up {conv_ix.shape=} {compression_index.shape=}") + # print(f"{compression_index.min()=} {compression_index.max()=}") + # print(f"{ix_a.shape=} {ix_b.shape=}") # # now, these arrays all have length n_pairs # shifted_temp_ix_a = shifted_temp_ix_a[ix_a] @@ -354,6 +387,10 @@ def compressed_convolve_pairs( # run convolutions temporal_a = low_rank_templates.temporal_components[temp_ix_a] + # print(f"{spatial_singular_a[ix_a[conv_ix]].shape=}") + # print(f"{spatial_singular_b[ix_b[conv_ix]].shape=}") + # print(f"{temporal_a[ix_a[conv_ix]].shape=}") + # print(f"{conv_temporal_components_up_b.shape=}") pconv, kept = correlate_pairs_lowrank( torch.as_tensor(spatial_singular_a[ix_a[conv_ix]]).to(device), torch.as_tensor(spatial_singular_b[ix_b[conv_ix]]).to(device), @@ -363,17 +400,33 @@ def compressed_convolve_pairs( conv_ignore_threshold=conv_ignore_threshold, batch_size=batch_size, ) - kept_pairs = np.isin(conv_ix[compression_index], conv_ix[kept]) + # print(f"-----------") + # print(f"after corr {pconv.shape=} {kept.shape=}") conv_ix = conv_ix[kept] if not conv_ix.size: return None - compression_index = compression_index[kept_pairs] + kept_pairs = np.flatnonzero(np.isin(compression_index, kept)) + # print(f"-----------") + # print(f"kept {pconv.shape=} {conv_ix.shape=} {compression_index.shape=}") + # print(f"{compression_index.min()=} {compression_index.max()=}") + # print(f"{compression_index[kept_pairs].min()=} {compression_index[kept_pairs].max()=}") + # print(f"{ix_a.shape=} {ix_b.shape=}") + # print(f"{kept.shape=} {kept.dtype=} {kept.min()=} {kept.max()=}") + # print(f"{kept_pairs.shape=} {kept_pairs.dtype=} {kept_pairs.min()=} {kept_pairs.max()=}") + compression_index = np.searchsorted(kept, compression_index[kept_pairs]) + conv_ix = np.searchsorted(kept_pairs, conv_ix) ix_a = ix_a[kept_pairs] ix_b = ix_b[kept_pairs] # compression_index = compression_index[kept] pconv = pconv.cpu() + # print(f"-----------") + # print(f"after searchsorted {pconv.shape=} {conv_ix.shape=} {compression_index.shape=}") + # print(f"{compression_index.min()=} {compression_index.max()=}") + # print(f"{ix_a.shape=} {ix_b.shape=}") # coarse approx + # print(f"-----------") + # 
print(f"before approx {pconv.shape=} {conv_ix.shape=} {compression_index.shape=}") pconv, old_ix_to_new_ix = coarse_approximate( pconv, unit_a[ix_a[conv_ix]], @@ -383,9 +436,14 @@ def compressed_convolve_pairs( shift_b[ix_b[conv_ix]], coarse_approx_error_threshold=coarse_approx_error_threshold, ) + # print(f"-----------") + # print(f"after approx") + # print(f"{pconv.shape=} {conv_ix.shape=} {old_ix_to_new_ix.shape=} {compression_index.shape=}") + # print(f"{compression_index.min()=} {compression_index.max()=}") + # print(f"{old_ix_to_new_ix.min()=} {old_ix_to_new_ix.max()=}") + compression_index = old_ix_to_new_ix[compression_index] # above function invalidates the whole idea of conv_ix del conv_ix - compression_index = old_ix_to_new_ix[compression_index] # recover metadata temp_ix_a = temp_ix_a[ix_a] @@ -575,7 +633,7 @@ def shift_deduplicated_pairs( co-visible. So, these are subsets of shifted_temp_ix_a,b compression_index Size == pair_ix_a,b size - Subsets of conv_ix_a,b, so that the xcorr of templates + Arrays with shape matching pair_ix_a,b, so that the xcorr of templates shifted_temp_ix_a[pair_ix_a[i]], shifted_temp_ix_b[pair_ix_b[i]] is the same as that of shifted_temp_ix_a[pair_ix_a[conv_ix[compression_index[i]]], @@ -589,14 +647,22 @@ def shift_deduplicated_pairs( chan_amp_b = torch.sqrt(torch.square(spatialsing_b).sum(1)) pair = chan_amp_a @ chan_amp_b.T pair = pair > conv_ignore_threshold + pair = pair.cpu() + # print(f"___ after overlaps {pair.sum()=}") # co-occurrence - pair *= template_shift_index.cooccurrence + cooccurrence = template_shift_index.cooccurrence[ + shifted_temp_ix_a[:, None], + shifted_temp_ix_b[None, :], + ] + pair *= torch.as_tensor(cooccurrence, device=pair.device) + # print(f"___ after cooccur {pair.sum()=}") pair_ix_a, pair_ix_b = torch.nonzero(pair, as_tuple=True) nco = pair_ix_a.numel() if not nco: return None + # print(f"___ {nco=}") # if no shifting, deduplication is the identity do_shifting = template_shift_index.all_pitch_shifts.size > 1 @@ -633,8 +699,16 @@ def shift_deduplicated_pairs( fill_value=0, ) # 2: assign IDs to each such vector - _, active_chan_ids_a = np.unique(active_chans_a, axis=0, return_inverse=True) - _, active_chan_ids_b = np.unique(active_chans_b, axis=0, return_inverse=True) + chanset_a, active_chan_ids_a = np.unique( + active_chans_a, axis=0, return_inverse=True + ) + chanset_b, active_chan_ids_b = np.unique( + active_chans_b, axis=0, return_inverse=True + ) + # print(f"___ {chanset_a.sum(1)=}") + # print(f"___ {chanset_b.sum(1)=}") + # print(f"___ {active_chan_ids_a.shape=} {np.unique(active_chan_ids_a).shape=}") + # print(f"___ {active_chan_ids_b.shape=} {np.unique(active_chan_ids_b).shape=}") # 3 temp_ix_a = temp_ix_a[pair_ix_a] @@ -643,6 +717,13 @@ def shift_deduplicated_pairs( shift_a = shift_a[pair_ix_a] shift_b = shift_b[pair_ix_b] shift_diff = shift_a - shift_b + # print(f"{temp_ix_a=}") + # print(f"{shift_a=}") + # print(f"{active_chan_ids_a[pair_ix_a]=}") + # print(f"{temp_ix_b=}") + # print(f"{shift_b=}") + # print(f"{active_chan_ids_b[pair_ix_b]=}") + # print(f"{shift_diff=}") # figure out combinations conv_determiners = np.c_[ @@ -652,6 +733,7 @@ def shift_deduplicated_pairs( active_chan_ids_b[pair_ix_b], shift_diff, ] + # print(f"{conv_determiners=}") # conv_ix: indices of unique determiners # compression_index: which representative does each pair belong to _, conv_ix, compression_index = np.unique( @@ -809,7 +891,7 @@ def compressed_upsampled_pairs( ] ) conv_temporal_components_up_b = ( - 
compressed_upsampled_temporal.compressed_index_to_upsampling_index[ + compressed_upsampled_temporal.compressed_upsampled_templates[ conv_compressed_upsampled_ix ] ) diff --git a/src/dartsort/templates/template_util.py b/src/dartsort/templates/template_util.py index 8114f222..3bfc7d15 100644 --- a/src/dartsort/templates/template_util.py +++ b/src/dartsort/templates/template_util.py @@ -336,9 +336,13 @@ def compressed_upsampled_templates( all_upsampled_templates = temporally_upsample_templates( templates, temporal_upsampling_factor=max_upsample, kind=kind ) - compressed_upsampled_templates = all_upsampled_templates[ - template_indices, upsampling_indices - ] + # n, up, t, c + all_upsampled_templates = all_upsampled_templates.transpose(0, 2, 1, 3) + rix = np.ravel_multi_index((template_indices, upsampling_indices), all_upsampled_templates.shape[:2]) + all_upsampled_templates = all_upsampled_templates.reshape( + n_templates * max_upsample, templates.shape[1], templates.shape[2] + ) + compressed_upsampled_templates = all_upsampled_templates[rix] return CompressedUpsampledTemplates( compressed_upsampled_templates, diff --git a/src/dartsort/util/drift_util.py b/src/dartsort/util/drift_util.py index e3761924..3054da74 100644 --- a/src/dartsort/util/drift_util.py +++ b/src/dartsort/util/drift_util.py @@ -557,10 +557,10 @@ def get_shift_and_unit_pairs( ) pitch_shifts = pitch_shifts.astype(int) pitch_shift_ix = np.searchsorted(all_pitch_shifts, pitch_shifts) - + shifted_temp_ixs = template_shift_index[temp_ixs, pitch_shift_ix] cooccurrence[shifted_temp_ixs[:, None], shifted_temp_ixs[None, :]] = 1 - + return TemplateShiftIndex( n_template_shift_pairs, all_pitch_shifts, From 949a0b0acfe04f8dc81d009f9e34c56f32d3f552 Mon Sep 17 00:00:00 2001 From: julien Date: Tue, 7 Nov 2023 12:54:08 -0500 Subject: [PATCH 21/49] dipole localization --- scripts/uhd_pipeline.py | 2 +- src/dartsort/config.py | 1 + src/dartsort/localize/localize_torch.py | 56 ++++++------------ src/dartsort/localize/localize_util.py | 2 + src/dartsort/main.py | 1 + src/dartsort/transform/all_transformers.py | 4 +- src/dartsort/transform/localize.py | 67 ++-------------------- 7 files changed, 30 insertions(+), 103 deletions(-) diff --git a/scripts/uhd_pipeline.py b/scripts/uhd_pipeline.py index 215fbf20..331266a3 100644 --- a/scripts/uhd_pipeline.py +++ b/scripts/uhd_pipeline.py @@ -164,7 +164,7 @@ # Don't trust spikeinterface preprocessing :( ... 
if preprocessing: print("Preprocessing...") - preprocessing_dir = Path(output_all) / "preprocessing_test" + preprocessing_dir = Path(output_all) / "preprocessing" Path(preprocessing_dir).mkdir(exist_ok=True) if t_end_preproc is None: t_end_preproc=rec_len_sec diff --git a/src/dartsort/config.py b/src/dartsort/config.py index 0aa103e5..d78b0f37 100644 --- a/src/dartsort/config.py +++ b/src/dartsort/config.py @@ -76,6 +76,7 @@ class FeaturizationConfig: localization_radius: float = 100.0 # these are saved always if do_localization save_amplitude_vectors: bool = True + localization_model = "dipole" # -- further info about denoising # in the future we may add multi-channel or other nns diff --git a/src/dartsort/localize/localize_torch.py b/src/dartsort/localize/localize_torch.py index 667032ae..1893953c 100644 --- a/src/dartsort/localize/localize_torch.py +++ b/src/dartsort/localize/localize_torch.py @@ -69,8 +69,7 @@ def localize_amplitude_vectors( assert channel_index.shape == (n_channels_tot, c) assert main_channels.shape == (n_spikes,) # we'll return numpy if user sent numpy - is_numpy = not torch.is_tensor(amplitude_vectors) - + is_numpy = not torch.is_tensor(amplitude_vectors) # handle channel subsetting if radius is not None or n_channels_subset is not None: @@ -173,26 +172,8 @@ def localize_amplitude_vectors( y = F.softplus(y0) projected_dist = vmap_dipole_find_projection_distance( normalized_amp_vecs, x, y, z_rel, local_geoms - ) - - # if projected_dist>th_dipole_proj_dist: return the loc values from pointsource + ) - pointsource_spikes = torch.nonzero(projected_dist>th_dipole_proj_dist, as_tuple=True) - - locs_pointsource_spikes, i = batched_levenberg_marquardt( - locs[pointsource_spikes], - vmap_point_source_grad_and_mse, - vmap_point_source_hessian, - extra_args=(normalized_amp_vecs[pointsource_spikes], in_probe_mask, local_geoms[pointsource_spikes]), - **levenberg_marquardt_kwargs, - ) - x_pointsource_spikes, y0_pointsource_spikes, z_rel_pointsource_spikes = locs.T - y_pointsource_spikes = F.softplus(y0_pointsource_spikes) - - x[pointsource_spikes] = x_pointsource_spikes - y[pointsource_spikes] = y_pointsource_spikes - z_rel[pointsource_spikes] = z_rel_pointsource_spikes - z_abs = z_rel + geom[main_channels, 1] if is_numpy: @@ -206,7 +187,7 @@ def localize_amplitude_vectors( # -- point source / dipole model library functions -def point_source_amplitude_at(x, y, z, local_geom): +def point_source_amplitude_at(x, y, z, alpha, local_geom): """Point source model predicted amplitude at local_geom given location""" dxs = torch.square(x - local_geom[:, 0]) dzs = torch.square(z - local_geom[:, 1]) @@ -224,18 +205,19 @@ def point_source_find_alpha(amp_vec, channel_mask, x, y, z, local_geoms): return alpha def dipole_find_projection_distance(normalized_amp_vec, x, y, z, local_geom): - """We can solve for the brightness (alpha) of the source in closed form given x,y,z""" + """COmpute a value dist/dipole in x,z that tells us if dipole or monopole is better""" dxs = x - local_geom[:, 0] dzs = z - local_geom[:, 1] - dys = y - duv = torch.tensor([dxs, dys, dzs]) - X = duv / torch.pow(torch.sum(torch.square(duv)), 3/2) - beta = torch.linalg.solve(torch.matmul(X.T, X), torch.matmul(X.T, normalized_amp_vec)) + dys = y.expand(dzs.size()) + duv = torch.stack([dxs, dys, dzs], dim=1) + X = duv / torch.pow(torch.sum(torch.square(duv), dim=1), 3/2)[:, None] + # beta = torch.linalg.lstsq(X, amplitude_vector[:, None])[0] + beta = torch.matmul(torch.linalg.pinv(torch.matmul(X.T, X)), torch.matmul(X.T, 
normalized_amp_vec)) beta /= torch.sqrt(torch.square(beta).sum()) - dipole_planar_direction = torch.sqrt(np.torch(beta[[0, 2]]).sum()) - closest_chan = torch.square(duv).sum(1).argmin() - min_duv = duv[closest_chan] + dipole_planar_direction = torch.sqrt(torch.square(beta[[0, 2]]).sum()) + closest_chan = torch.argmin(torch.sum(torch.square(duv), dim=1)) + min_duv = duv[closest_chan[None]][0] #workaround around vmap doesn't work for one dim tensor .item() val_th = torch.sqrt(torch.square(min_duv).sum())/dipole_planar_direction return val_th @@ -285,16 +267,16 @@ def dipole_mse(loc, amplitude_vector, local_geom, logbarrier=True): dxs = x - local_geom[:, 0] dzs = z - local_geom[:, 1] - dys = y - - duv = torch.tensor([dxs, dys, dzs]) - - X = duv / torch.pow(torch.sum(torch.square(duv)), 3/2) + dys = y.expand(dzs.size()) - beta = torch.linalg.solve(torch.matmul(X.T, X), torch.matmul(X.T, (ptp / maxptp))) + duv = torch.stack([dxs, dys, dzs], dim=1) + X = duv / torch.pow(torch.sum(torch.square(duv), dim=1), 3/2)[:, None] + # beta = torch.linalg.lstsq(X, amplitude_vector[:, None])[0] + # beta = torch.linalg.solve(torch.matmul(X.T, X), torch.matmul(X.T, amplitude_vector)) + beta = torch.matmul(torch.linalg.pinv(torch.matmul(X.T, X)), torch.matmul(X.T, amplitude_vector)) qtq = torch.matmul(X, beta) - obj = torch.square(ptp / maxptp - qtq).mean() + obj = torch.square(amplitude_vector - qtq).mean() if logbarrier: obj -= torch.log(10.0 * y) / 10000.0 diff --git a/src/dartsort/localize/localize_util.py b/src/dartsort/localize/localize_util.py index b0cc1b59..f85e0ffb 100644 --- a/src/dartsort/localize/localize_util.py +++ b/src/dartsort/localize/localize_util.py @@ -43,6 +43,7 @@ def localize_hdf5( spikes_per_batch=100_000, show_progress=True, device=None, + localization_model="pointsource", ): """Run localization on a HDF5 file with stored amplitude vectors @@ -100,6 +101,7 @@ def localize_hdf5( channel_index=channel_index, radius=radius, n_channels_subset=n_channels_subset, + model=localization_model, ) xyza_batch = np.c_[ locs["x"].cpu().numpy(), diff --git a/src/dartsort/main.py b/src/dartsort/main.py index 4cb5ecf9..d2565186 100644 --- a/src/dartsort/main.py +++ b/src/dartsort/main.py @@ -173,6 +173,7 @@ def _run_peeler( amplitude_vectors_dataset_name=f"{wf_name}_amplitude_vectors", show_progress=show_progress, device=device, + localization_model=featurization_config.localization_model ) return ( diff --git a/src/dartsort/transform/all_transformers.py b/src/dartsort/transform/all_transformers.py index fd7901ee..98e81291 100644 --- a/src/dartsort/transform/all_transformers.py +++ b/src/dartsort/transform/all_transformers.py @@ -1,6 +1,6 @@ from .amplitudes import AmplitudeVector, MaxAmplitude from .enforce_decrease import EnforceDecrease -from .localize import PointSourceLocalization +from .localize import Localization from .single_channel_denoiser import SingleChannelWaveformDenoiser from .temporal_pca import TemporalPCADenoiser, TemporalPCAFeaturizer from .transform_base import Waveform @@ -13,7 +13,7 @@ SingleChannelWaveformDenoiser, TemporalPCADenoiser, TemporalPCAFeaturizer, - PointSourceLocalization, + Localization, ] transformers_by_class_name = {cls.__name__: cls for cls in all_transformers} diff --git a/src/dartsort/transform/localize.py b/src/dartsort/transform/localize.py index 083318fe..7bbe69c6 100644 --- a/src/dartsort/transform/localize.py +++ b/src/dartsort/transform/localize.py @@ -5,7 +5,7 @@ from .transform_base import BaseWaveformFeaturizer -class 
PointSourceLocalization(BaseWaveformFeaturizer): +class Localization(BaseWaveformFeaturizer): """Order of output columns: x, y, z_abs, alpha""" default_name = "point_source_localizations" @@ -22,6 +22,7 @@ def __init__( amplitude_kind="peak", name=None, name_prefix="", + localization_model="pointsource", ): assert amplitude_kind in ("peak", "ptp") super().__init__( @@ -34,6 +35,7 @@ def __init__( self.radius = radius self.n_channels_subset = n_channels_subset self.logbarrier = logbarrier + self.localization_model = localization_model def transform(self, waveforms, max_channels=None): # get amplitude vectors @@ -52,68 +54,7 @@ def transform(self, waveforms, max_channels=None): n_channels_subset=self.n_channels_subset, logbarrier=self.logbarrier, dtype=self.dtype, - ) - - localizations = torch.column_stack( - [ - loc_result["x"], - loc_result["y"], - loc_result["z_abs"], - loc_result["alpha"], - ] - ) - return localizations - -class DipoleLocalization(BaseWaveformFeaturizer): - """Order of output columns: x, y, z_abs, alpha""" - - default_name = "dipole_localizations" - shape = (4,) - dtype = torch.double - - def __init__( - self, - channel_index, - geom, - radius=None, - n_channels_subset=None, - logbarrier=True, - amplitude_kind="peak", - model="dipole", - name=None, - name_prefix="", - ): - assert amplitude_kind in ("peak", "ptp") - super().__init__( - geom=geom, - channel_index=channel_index, - name=name, - name_prefix=name_prefix, - ) - self.amplitude_kind = amplitude_kind - self.radius = radius - self.n_channels_subset = n_channels_subset - self.logbarrier = logbarrier - self.model = model - - def transform(self, waveforms, max_channels=None): - # get amplitude vectors - if self.amplitude_kind == "peak": - ampvecs = waveforms.abs().max(dim=1).values - elif self.amplitude_kind == "ptp": - ampvecs = ptp(waveforms, dim=1) - - with torch.enable_grad(): - loc_result = localize_amplitude_vectors( - ampvecs, - self.geom, - max_channels, - channel_index=self.channel_index, - radius=self.radius, - n_channels_subset=self.n_channels_subset, - logbarrier=self.logbarrier, - model=self.model, - dtype=self.dtype, + model=self.localization_model, ) localizations = torch.column_stack( From 60b085b2078ccbca7d2e272057347f282802d8e3 Mon Sep 17 00:00:00 2001 From: Charlie Windolf Date: Tue, 7 Nov 2023 13:39:02 -0500 Subject: [PATCH 22/49] slice(None)s are not always my friend --- src/dartsort/templates/pairwise_util.py | 21 +++++++++++---------- 1 file changed, 11 insertions(+), 10 deletions(-) diff --git a/src/dartsort/templates/pairwise_util.py b/src/dartsort/templates/pairwise_util.py index 05e0deed..061596ee 100644 --- a/src/dartsort/templates/pairwise_util.py +++ b/src/dartsort/templates/pairwise_util.py @@ -400,19 +400,19 @@ def compressed_convolve_pairs( conv_ignore_threshold=conv_ignore_threshold, batch_size=batch_size, ) - # print(f"-----------") - # print(f"after corr {pconv.shape=} {kept.shape=}") + print(f"-----------") + print(f"after corr {pconv.shape=} {conv_ix[kept].shape=}") conv_ix = conv_ix[kept] if not conv_ix.size: return None kept_pairs = np.flatnonzero(np.isin(compression_index, kept)) - # print(f"-----------") - # print(f"kept {pconv.shape=} {conv_ix.shape=} {compression_index.shape=}") - # print(f"{compression_index.min()=} {compression_index.max()=}") - # print(f"{compression_index[kept_pairs].min()=} {compression_index[kept_pairs].max()=}") - # print(f"{ix_a.shape=} {ix_b.shape=}") - # print(f"{kept.shape=} {kept.dtype=} {kept.min()=} {kept.max()=}") - # 
print(f"{kept_pairs.shape=} {kept_pairs.dtype=} {kept_pairs.min()=} {kept_pairs.max()=}") + print(f"-----------") + print(f"kept {pconv.shape=} {conv_ix.shape=} {compression_index.shape=}") + print(f"{compression_index.min()=} {compression_index.max()=}") + print(f"{compression_index[kept_pairs].min()=} {compression_index[kept_pairs].max()=}") + print(f"{ix_a.shape=} {ix_b.shape=}") + print(f"{kept.shape=} {kept.dtype=} {kept.min()=} {kept.max()=}") + print(f"{kept_pairs.shape=} {kept_pairs.dtype=} {kept_pairs.min()=} {kept_pairs.max()=}") compression_index = np.searchsorted(kept, compression_index[kept_pairs]) conv_ix = np.searchsorted(kept_pairs, conv_ix) ix_a = ix_a[kept_pairs] @@ -538,12 +538,13 @@ def correlate_pairs_lowrank( pconv[istart:iend] = pconv_[0, :, 0, :] # nco, nup, time # more stringent covisibility - kept = slice(None) if conv_ignore_threshold > 0: max_val = pconv.reshape(n_pairs, -1).abs().max(dim=1).values kept = max_val > conv_ignore_threshold pconv = pconv[kept] kept = np.flatnonzero(kept.numpy(force=True)) + else: + kept = np.arange(len(pconv)) return pconv, kept From 2c4d78a710baf9690d4527cceed38dce042d98f3 Mon Sep 17 00:00:00 2001 From: Charlie Windolf Date: Tue, 7 Nov 2023 13:53:23 -0500 Subject: [PATCH 23/49] All tests passing --- tests/test_grab_and_featurize.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/tests/test_grab_and_featurize.py b/tests/test_grab_and_featurize.py index 67abb6ae..67a2fccf 100644 --- a/tests/test_grab_and_featurize.py +++ b/tests/test_grab_and_featurize.py @@ -249,8 +249,14 @@ def test_grab_and_featurize(): assert np.array_equal(h5["channel_index"][()], channel_index) assert h5["last_chunk_start"][()] == 90_000 - # this is kind of a good test of reproducibility/random seeds - assert np.array_equal(locs0, locs1) + # this is kind of a good test of reproducibility + # totally reproducible on CPU, suprprisingly large diffs on GPU + if not torch.cuda.is_available(): + assert np.array_equal(locs0, locs1) + else: + valid = np.clip(locs1[:, 2], geom[:,1].min(), geom[:,1].max()) + valid = locs1[:, 2] == valid + assert np.isclose(locs0[valid], locs1[valid], atol=1e-6).all() if __name__ == "__main__": From f05bf1df7f3d684f9a0d3a0040ffd64639ff29b8 Mon Sep 17 00:00:00 2001 From: julien Date: Tue, 7 Nov 2023 13:56:55 -0500 Subject: [PATCH 24/49] dipole --- src/dartsort/localize/localize_torch.py | 20 -------------------- 1 file changed, 20 deletions(-) diff --git a/src/dartsort/localize/localize_torch.py b/src/dartsort/localize/localize_torch.py index 72e99206..56a35024 100644 --- a/src/dartsort/localize/localize_torch.py +++ b/src/dartsort/localize/localize_torch.py @@ -206,7 +206,6 @@ def dipole_find_projection_distance(normalized_amp_vec, x, y, z, local_geom): dys = y.expand(dzs.size()) duv = torch.stack([dxs, dys, dzs], dim=1) X = duv / torch.pow(torch.sum(torch.square(duv), dim=1), 3/2)[:, None] - # beta = torch.linalg.lstsq(X, amplitude_vector[:, None])[0] beta = torch.matmul(torch.linalg.pinv(torch.matmul(X.T, X)), torch.matmul(X.T, normalized_amp_vec)) beta /= torch.sqrt(torch.square(beta).sum()) dipole_planar_direction = torch.sqrt(torch.square(beta[[0, 2]]).sum()) @@ -249,25 +248,6 @@ def point_source_mse(loc, amplitude_vector, channel_mask, local_geom, logbarrier return obj -def dipole_find_projection_distance(normalized_amp_vec, x, y, z, local_geom): - """We can solve for the brightness (alpha) of the source in closed form given x,y,z""" - - dxs = x - local_geom[:, 0] - dzs = z - local_geom[:, 1] - dys = y - 
duv = torch.tensor([dxs, dys, dzs]) - X = duv / torch.pow(torch.sum(torch.square(duv)), 3 / 2) - beta = torch.linalg.solve( - torch.matmul(X.T, X), torch.matmul(X.T, normalized_amp_vec) - ) - beta /= torch.sqrt(torch.square(beta).sum()) - dipole_planar_direction = torch.sqrt(np.torch(beta[[0, 2]]).sum()) - closest_chan = torch.square(duv).sum(1).argmin() - min_duv = duv[closest_chan] - val_th = torch.sqrt(torch.square(min_duv).sum()) / dipole_planar_direction - return val_th - - def dipole_mse(loc, amplitude_vector, local_geom, logbarrier=True): """Dipole model predicted amplitude at local_geom given location""" From 4122cfbcfb2d30d0bd9985a57e04c9b36dd7e53c Mon Sep 17 00:00:00 2001 From: julien Date: Tue, 7 Nov 2023 14:10:53 -0500 Subject: [PATCH 25/49] PointSourceLoc --- src/dartsort/transform/all_transformers.py | 1 + src/dartsort/transform/localize.py | 2 ++ 2 files changed, 3 insertions(+) diff --git a/src/dartsort/transform/all_transformers.py b/src/dartsort/transform/all_transformers.py index 98e81291..736e7d00 100644 --- a/src/dartsort/transform/all_transformers.py +++ b/src/dartsort/transform/all_transformers.py @@ -14,6 +14,7 @@ TemporalPCADenoiser, TemporalPCAFeaturizer, Localization, + PointSourceLocalization, ] transformers_by_class_name = {cls.__name__: cls for cls in all_transformers} diff --git a/src/dartsort/transform/localize.py b/src/dartsort/transform/localize.py index 7bbe69c6..fa7acf37 100644 --- a/src/dartsort/transform/localize.py +++ b/src/dartsort/transform/localize.py @@ -66,3 +66,5 @@ def transform(self, waveforms, max_channels=None): ] ) return localizations + +PointSourceLocalization = Localization \ No newline at end of file From 65e8eea71a3148e370fe29a975f538b459a3804c Mon Sep 17 00:00:00 2001 From: julien Date: Tue, 7 Nov 2023 14:14:03 -0500 Subject: [PATCH 26/49] src/dartsort/transform/all_transformers.py --- src/dartsort/transform/all_transformers.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/dartsort/transform/all_transformers.py b/src/dartsort/transform/all_transformers.py index 736e7d00..bf20d743 100644 --- a/src/dartsort/transform/all_transformers.py +++ b/src/dartsort/transform/all_transformers.py @@ -1,6 +1,6 @@ from .amplitudes import AmplitudeVector, MaxAmplitude from .enforce_decrease import EnforceDecrease -from .localize import Localization +from .localize import Localization, PointSourceLocalization from .single_channel_denoiser import SingleChannelWaveformDenoiser from .temporal_pca import TemporalPCADenoiser, TemporalPCAFeaturizer from .transform_base import Waveform From 2ef7a3a5eec793b1c8752a23f4f62595365b893f Mon Sep 17 00:00:00 2001 From: Charlie Windolf Date: Tue, 7 Nov 2023 14:14:49 -0500 Subject: [PATCH 27/49] Merge and fix tests --- tests/test_grab_and_featurize.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_grab_and_featurize.py b/tests/test_grab_and_featurize.py index 67a2fccf..eb036136 100644 --- a/tests/test_grab_and_featurize.py +++ b/tests/test_grab_and_featurize.py @@ -143,7 +143,7 @@ def test_grab_and_featurize(): fit_radius=10, ), transform.Waveform(channel_index, name="tpca_waveforms"), - transform.PointSourceLocalization( + transform.Localization( channel_index=channel_index, geom=geom, radius=50.0 ), ] From 42442c90afd367dc07e8972eacb290efc9cfc4a0 Mon Sep 17 00:00:00 2001 From: Charlie Windolf Date: Wed, 8 Nov 2023 16:20:29 -0500 Subject: [PATCH 28/49] Initial, un-debugged objective updating matcher --- src/dartsort/config.py | 2 + src/dartsort/main.py | 6 +- 
src/dartsort/peel/matching.py | 292 ++++++++++++++++++------ src/dartsort/peel/peel_base.py | 14 +- src/dartsort/templates/pairwise.py | 153 ++++++++++++- src/dartsort/templates/template_util.py | 7 +- 6 files changed, 385 insertions(+), 89 deletions(-) diff --git a/src/dartsort/config.py b/src/dartsort/config.py index d78b0f37..3cbcf422 100644 --- a/src/dartsort/config.py +++ b/src/dartsort/config.py @@ -162,3 +162,5 @@ class MatchingConfig: amplitude_scaling_variance: float = 0.0 amplitude_scaling_boundary: float = 0.5 max_iter: int = 1000 + conv_ignore_threshold: float = 5.0 + coarse_approx_error_threshold: float = 5.0 diff --git a/src/dartsort/main.py b/src/dartsort/main.py index d2565186..f9dc062d 100644 --- a/src/dartsort/main.py +++ b/src/dartsort/main.py @@ -86,10 +86,12 @@ def match( residual_filename=None, show_progress=True, device=None, - template_npz_filename="matching0_templates.npz", hdf5_filename="matching0.h5", model_subdir="matching0_models", + template_npz_filename="template_data.npz", ): + model_dir = Path(output_directory) / model_subdir + # compute templates template_data = TemplateData.from_config( recording, @@ -97,7 +99,7 @@ def match( template_config, motion_est=motion_est, n_jobs=n_jobs_templates, - save_folder=output_directory, + save_folder=model_dir, overwrite=overwrite, device=device, save_npz_name=template_npz_filename, diff --git a/src/dartsort/peel/matching.py b/src/dartsort/peel/matching.py index bc22eacb..a06710f1 100644 --- a/src/dartsort/peel/matching.py +++ b/src/dartsort/peel/matching.py @@ -13,10 +13,13 @@ import torch import torch.nn.functional as F from dartsort.templates import template_util +from dartsort.templates.pairwise import CompressedPairwiseConv from dartsort.transform import WaveformPipeline -from dartsort.util import spiketorch +from dartsort.util import drift_util, spiketorch from dartsort.util.data_util import SpikeDataset from dartsort.util.waveform_util import make_channel_index +from scipy.spatial import KDTree +from scipy.spatial.distance import pdist from .peel_base import BasePeeler @@ -38,6 +41,8 @@ def __init__( refractory_radius_frames=10, amplitude_scaling_variance=0.0, amplitude_scaling_boundary=0.5, + conv_ignore_threshold=5.0, + coarse_approx_error_threshold=5.0, trough_offset_samples=42, threshold=50.0, chunk_length_samples=30_000, @@ -45,47 +50,36 @@ def __init__( fit_subsampling_random_state=0, max_iter=1000, ): - n_templates, spike_length_samples = template_data.templates.shape[:2] super().__init__( recording=recording, channel_index=channel_index, featurization_pipeline=featurization_pipeline, chunk_length_samples=chunk_length_samples, - chunk_margin_samples=2 * spike_length_samples, + chunk_margin_samples=2 * template_data.templates.shape[1], n_chunks_fit=n_chunks_fit, fit_subsampling_random_state=fit_subsampling_random_state, ) - # process templates - ( - temporal_components, - singular_values, - spatial_components, - ) = template_util.svd_compress_templates( - template_data.templates, - min_channel_amplitude=min_channel_amplitude, - rank=svd_compression_rank, - ) - temporal_components = temporal_components.astype(recording.dtype) - singular_values = singular_values.astype(recording.dtype) - spatial_components = spatial_components.astype(recording.dtype) - self.handle_upsampling( - temporal_components, - temporal_upsampling_factor=temporal_upsampling_factor, - upsampling_peak_window_radius=upsampling_peak_window_radius, - ) - # main properties + self.template_data = template_data + self.temporal_upsampling_factor 
= temporal_upsampling_factor + self.upsampling_peak_window_radius = upsampling_peak_window_radius + self.svd_compression_rank = svd_compression_rank + self.min_channel_amplitude = min_channel_amplitude self.threshold = threshold + self.conv_ignore_threshold = conv_ignore_threshold + self.coarse_approx_error_threshold = coarse_approx_error_threshold self.refractory_radius_frames = refractory_radius_frames self.max_iter = max_iter - self.n_templates = n_templates + self.n_templates, self.spike_length_samples = template_data.templates.shape[:2] self.trough_offset_samples = trough_offset_samples - self.spike_length_samples = spike_length_samples self.geom = recording.get_channel_locations() - self.svd_compression_rank = svd_compression_rank self.n_channels = len(self.geom) - self.obj_pad_len = max(refractory_radius_frames, upsampling_peak_window_radius) + self.obj_pad_len = max( + refractory_radius_frames, + upsampling_peak_window_radius, + self.spike_length_samples - 1, + ) self.n_registered_channels = ( len(template_data.registered_geom) if template_data.registered_geom is not None @@ -96,16 +90,6 @@ def __init__( self.channel_index = channel_index self.registered_template_ampvecs = template_data.templates.ptp(1) - # torch buffers - self.register_buffer("temporal_components", torch.tensor(temporal_components)) - self.register_buffer("singular_values", torch.tensor(singular_values)) - self.register_buffer("spatial_components", torch.tensor(spatial_components)) - self.register_buffer( - "_refrac_ix", - torch.arange(-refractory_radius_frames, refractory_radius_frames + 1), - ) - self.register_buffer("_rank_ix", torch.arange(svd_compression_rank)) - # amplitude scaling properties self.is_scaling = bool(amplitude_scaling_variance) self.amplitude_scaling_variance = amplitude_scaling_variance @@ -117,19 +101,43 @@ def __init__( self.motion_est = motion_est self.registered_geom = template_data.registered_geom self.registered_template_depths_um = template_data.registered_template_depths_um - - self.handle_template_groups(template_data.unit_ids) - self.check_shapes() - - self.fixed_output_data += [ - ("temporal_components", temporal_components), - ("singular_values", singular_values), - ("spatial_components", spatial_components), - ] if self.is_drifting: self.fixed_output_data.append( ("registered_geom", template_data.registered_geom) ) + self.registered_geom_kdtree = KDTree(self.registered_geom) + self.geom_kdtree = KDTree(self.geom) + self.match_distance = pdist(self.geom).min() / 2.0 + + # some parts of this constructor are deferred to precompute_peeling_data + self._peeling_needs_fit = True + + def peeling_needs_fit(self): + return self._peeling_needs_fit + + def precompute_peeling_data(self, save_folder, n_jobs=0, device=None): + self.build_template_data( + save_folder, + self.template_data, + temporal_upsampling_factor=self.temporal_upsampling_factor, + upsampling_peak_window_radius=self.upsampling_peak_window_radius, + svd_compression_rank=self.svd_compression_rank, + min_channel_amplitude=self.min_channel_amplitude, + dtype=self.recording.dtype, + n_jobs=n_jobs, + device=device, + ) + self.handle_template_groups(self.template_data.unit_ids) + # couple more torch buffers + self.register_buffer( + "_refrac_ix", + torch.arange( + -self.refractory_radius_frames, self.refractory_radius_frames + 1 + ), + ) + self.register_buffer("_rank_ix", torch.arange(self.svd_compression_rank)) + self.check_shapes() + self._peeling_needs_fit = False def out_datasets(self): datasets = super().out_datasets() @@ 
-190,28 +198,93 @@ def handle_template_groups(self, unit_ids): group_index[j, : len(row)] = row self.register_buffer("group_index", torch.from_numpy(group_index)) - def handle_upsampling( + def build_template_data( self, - temporal_components, + save_folder, + template_data, temporal_upsampling_factor=8, upsampling_peak_window_radius=8, + svd_compression_rank=10, + min_channel_amplitude=1.0, + dtype=np.float32, + n_jobs=0, + device=None, ): - self.temporal_upsampling_factor = temporal_upsampling_factor - if temporal_upsampling_factor == 1: - upsampled_temporal_components = temporal_components[:, :, None, :] - self.register_buffer( - "upsampled_temporal_components", - torch.tensor(upsampled_temporal_components), - ) - return + low_rank_templates = template_util.svd_compress_templates( + template_data.templates, + min_channel_amplitude=min_channel_amplitude, + rank=svd_compression_rank, + ) + temporal_components = low_rank_templates.temporal_components.astype(dtype) + singular_values = low_rank_templates.singular_values.astype(dtype) + spatial_components = low_rank_templates.spatial_components.astype(dtype) + self.register_buffer("temporal_components", torch.tensor(temporal_components)) + self.register_buffer("singular_values", torch.tensor(singular_values)) + self.register_buffer("spatial_components", torch.tensor(spatial_components)) - upsampled_temporal_components = template_util.temporally_upsample_templates( + compressed_upsampled_temporal = self.handle_upsampling( temporal_components, + ptps=template_data.templates.ptp(1).max(1), temporal_upsampling_factor=temporal_upsampling_factor, + upsampling_peak_window_radius=upsampling_peak_window_radius, + ) + + half_chunk = self.chunk_length_samples // 2 + chunk_centers_samples = np.arange( + half_chunk, self.recording.get_num_samples(), self.chunk_length_samples + ) + chunk_centers_s = self.recording._recording_segments[0].sample_index_to_time( + chunk_centers_samples + ) + self.pairwise_conv_db = CompressedPairwiseConv.from_template_data( + save_folder / "pconv.h5", + template_data=template_data, + low_rank_templates=low_rank_templates, + compressed_upsampled_temporal=compressed_upsampled_temporal, + chunk_time_centers_s=chunk_centers_s, + motion_est=motion_est, + geom=self.geom, + conv_ignore_threshold=self.conv_ignore_threshold, + coarse_approx_error_threshold=self.coarse_approx_error_threshold, + ) + + self.fixed_output_data += [ + ("temporal_components", temporal_components), + ("singular_values", singular_values), + ("spatial_components", spatial_components), + ( + "compressed_upsampling_map", + compressed_upsampled_temporal.compressed_upsampling_map, + ), + ( + "compressed_upsampled_temporal", + compressed_upsampled_temporal.compressed_upsampled_temporal, + ), + ] + + def handle_upsampling( + self, + temporal_components, + ptps, + temporal_upsampling_factor=8, + upsampling_peak_window_radius=8, + ): + compressed_upsampled_temporal = template_util.compressed_upsampled_templates( + temporal_components, + ptps=ptps, + max_upsample=temporal_upsampling_factor, ) self.register_buffer( - "upsampled_temporal_components", torch.tensor(upsampled_temporal_components) + "compressed_upsampling_map", + compressed_upsampled_temporal.compressed_upsampling_map, ) + self.register_buffer( + "compressed_upsampled_temporal", + compressed_upsampled_temporal.compressed_upsampled_temporal, + ) + if temporal_upsampling_factor == 1: + return compressed_upsampled_temporal + self.register_buffer( "upsampling_window", torch.arange( @@ -237,6 +310,7 @@ def 
handle_upsampling( self.register_buffer( "peak_to_time_shift", torch.tensor([0] * (radius + 1) + [1] * radius) ) + return compressed_upsampled_temporal @classmethod def from_config( @@ -266,6 +340,8 @@ def from_config( refractory_radius_frames=matching_config.refractory_radius_frames, amplitude_scaling_variance=matching_config.amplitude_scaling_variance, amplitude_scaling_boundary=matching_config.amplitude_scaling_boundary, + conv_ignore_threshold=matching_config.conv_ignore_threshold, + coarse_approx_error_threshold=matching_config.coarse_approx_error_threshold, trough_offset_samples=matching_config.trough_offset_samples, threshold=matching_config.threshold, chunk_length_samples=matching_config.chunk_length_samples, @@ -297,6 +373,7 @@ def peel_chunk( left_margin=0, right_margin=0, threshold=30, + return_residual=return_residual, ) # process spike times and create return result @@ -306,6 +383,7 @@ def peel_chunk( def templates_at_time(self, t_s): """Extract the right spatial components for each unit.""" + pconvdb = self.pairwise_conv_db if self.is_drifting: pitch_shifts, cur_spatial = template_util.templates_at_time( t_s, @@ -315,25 +393,35 @@ def templates_at_time(self, t_s): registered_geom=self.registered_geom, motion_est=self.motion_est, return_pitch_shifts=True, + geom_kdtree=self.geom_kdtree, + match_distance=self.match_distance, ) cur_ampvecs = drift_util.get_waveforms_on_static_channels( self.registered_template_ampvecs[:, None, :], self.registered_geom, n_pitches_shift=pitch_shifts, registered_geom=self.geom, + target_kdtree=self.geom_kdtree, + match_distance=self.match_distance, fill_value=0.0, ) max_channels = cur_ampvecs[:, 0, :].argmax(1) + pconvdb = pconvdb.at_shifts(pitch_shifts) else: cur_spatial = self.spatial_components max_channels = self.registered_template_ampvecs.argmax(1) + if not pconvdb._is_torch: + pconvdb = pconvdb.to(cur_spatial.device) + return CompressedTemplateData( cur_spatial, self.singular_values, self.temporal_components, - self.upsampled_temporal_components, + self.compressed_upsampling_map, + self.compressed_upsampled_temporal, torch.tensor(max_channels, device=cur_spatial.device), + pconvdb, ) def match_chunk( @@ -344,6 +432,7 @@ def match_chunk( left_margin=0, right_margin=0, threshold=30, + return_residual=False, ): """Core peeling routine for subtraction""" # initialize residual, it needs to be padded to support our channel @@ -374,20 +463,22 @@ def match_chunk( # manages buffers for spike train data (peak times, labels, etc) peaks = MatchingPeaks(device=traces.device) + + # initialize convolution + compressed_template_data.convolve( + residual, padding=self.obj_pad_len, out=padded_conv + ) + # main loop print("start") for it in range(self.max_iter): - # update objective - compressed_template_data.convolve( - residual, padding=self.obj_pad_len, out=padded_conv - ) - # unscaled objective for coarse peaks, scaled when finding high res peak + # update the coarse objective torch.add( neg_temp_normsq, padded_conv, alpha=2.0, out=padded_objective[:-1] ) # find high-res peaks - print('before find') + print("before find") new_peaks = self.find_peaks( padded_conv, padded_objective, refrac_mask, neg_temp_normsq ) @@ -411,13 +502,22 @@ def match_chunk( # subtract them # old_norm = torch.linalg.norm(residual) ** 2 - compressed_template_data.subtract( - residual_padded, + compressed_template_data.subtract_conv( + padded_conv, new_peaks.times, new_peaks.template_indices, new_peaks.upsampling_indices, new_peaks.scalings, + conv_pad_len=self.obj_pad_len, ) + if 
return_residual: + compressed_template_data.subtract( + residual_padded, + new_peaks.times, + new_peaks.template_indices, + new_peaks.upsampling_indices, + new_peaks.scalings, + ) # new_norm = torch.linalg.norm(residual) ** 2 # print(f"{it=} {new_norm=}") @@ -434,7 +534,7 @@ def match_chunk( residual_padded, peaks, compressed_template_data ) - return dict( + res = dict( n_spikes=peaks.n_spikes, times_samples=peaks.times + self.trough_offset_samples, channels=channels, @@ -445,10 +545,12 @@ def match_chunk( scores=peaks.scores, collisioncleaned_waveforms=waveforms, ) + if return_residual: + res["residual"] = residual + return res def find_peaks(self, padded_conv, padded_objective, refrac_mask, neg_temp_normsq): # first step: coarse peaks. not temporally upsampled or amplitude-scaled. - padded_obj_len = padded_objective.shape[1] objective = (padded_objective + refrac_mask)[ :-1, self.obj_pad_len : -self.obj_pad_len ] @@ -583,8 +685,10 @@ class CompressedTemplateData: spatial_components: torch.Tensor singular_values: torch.Tensor temporal_components: torch.Tensor - upsampled_temporal_components: torch.Tensor + compressed_upsampling_map: torch.LongTensor + compressed_upsampled_temporal: torch.Tensor max_channels: torch.LongTensor + pairwise_conv_db: CompressedPairwiseConv def __post_init__(self): ( @@ -593,27 +697,33 @@ def __post_init__(self): self.rank, ) = self.temporal_components.shape assert self.spatial_components.shape[:2] == (self.n_templates, self.rank) - assert self.upsampled_temporal_components.shape == ( - self.n_templates, + assert self.compressed_upsampled_temporal.shape[1:] == ( self.spike_length_samples, - self.upsampled_temporal_components.shape[2], self.rank, ) assert self.singular_values.shape == (self.n_templates, self.rank) - # squared l2 norms are the sums of squared singular values - self.template_norms_squared = torch.square(self.singular_values).sum(1) + + # squared l2 norms are usually the sums of squared singular values: + # self.template_norms_squared = torch.square(self.singular_values).sum(1) + # in this case, we have subset the spatial components, so use a diff formula self.spatial_singular = ( self.spatial_components * self.singular_values[:, :, None] ) + self.template_norms_squared = torch.square(self.spatial_singular).sum((1, 2)) self.chan_ix = torch.arange( self.spatial_components.shape[2], device=self.spatial_components.device ) self.time_ix = torch.arange( self.spike_length_samples, device=self.spatial_components.device ) + self.conv_lags = torch.arange( + -self.spike_length_samples + 1, + self.spike_length_samples, + device=self.spatial_components.device, + ) def convolve(self, traces, padding=0, out=None): - """This is not the fastest strategy on GPU, but it's low-memory and fast on CPU.""" + """Convolve all templates with traces.""" out_len = traces.shape[0] + 2 * padding - self.spike_length_samples + 1 if out is None: out = torch.zeros( @@ -643,6 +753,32 @@ def convolve(self, traces, padding=0, out=None): # back to units x time (remove extra dim used for conv1d) return out[0] + def subtract_conv( + self, + conv, + times, + template_indices, + upsampling_indices, + scalings, + conv_pad_len=0, + ): + template_indices_a, template_indices_b, pconvs = scalings[ + :, None + ] * self.pairwise_conv_db.query( + template_indices_a=None, + template_indices_b=template_indices, + upsampling_indices_b=upsampling_indices, + grid=True, + ) + ix_template = template_indices_a[:, None] + ix_time = times[None, :] + (conv_pad_len + self.conv_lags) + spiketorch.add_at_( + 
conv, + (ix_template, ix_time), + pconvs, + sign=-1, + ) + def subtract( self, traces, @@ -651,11 +787,15 @@ def subtract( upsampling_indices, scalings, ): + """Subtract templates from traces.""" + compressed_up_inds = self.compressed_upsampling_map[ + template_indices, upsampling_indices + ] batch_templates = torch.einsum( "n,nrc,ntr->ntc", scalings, self.spatial_singular[template_indices], - self.upsampled_temporal_components[template_indices, :, upsampling_indices], + self.compressed_upsampled_temporal[compressed_up_inds], ) time_ix = times[:, None, None] + self.time_ix[None, :, None] spiketorch.add_at_( diff --git a/src/dartsort/peel/peel_base.py b/src/dartsort/peel/peel_base.py index a0ae5a4e..e7e3edc0 100644 --- a/src/dartsort/peel/peel_base.py +++ b/src/dartsort/peel/peel_base.py @@ -207,6 +207,14 @@ def peel_chunk( raise NotImplementedError + def peeling_needs_fit(self): + return False + + def precompute_peeling_data(self, save_folder, n_jobs=0, device=None): + # subclasses should override if they need to cache data for peeling + # runs before fit_peeler_models() + assert not self.peeling_needs_fit() + def fit_peeler_models(self, save_folder): # subclasses should override if they need to fit models for peeling assert not self.peeling_needs_fit() @@ -270,7 +278,7 @@ def process_chunk(self, chunk_start_samples, return_residual=False): assert not any(k in features for k in peel_result) chunk_result = {**peel_result, **features} chunk_result = { - k: v.cpu().numpy() if torch.is_tensor(v) else v + k: v.numpy(force=True) if torch.is_tensor(v) else v for k, v in chunk_result.items() } @@ -310,15 +318,13 @@ def gather_chunk_result( return n_new_spikes - def peeling_needs_fit(self): - return False - def needs_fit(self): return self.peeling_needs_fit() or self.featurization_pipeline.needs_fit() def fit_models(self, save_folder, n_jobs=0, device=None): with torch.no_grad(): if self.peeling_needs_fit(): + self.precompute_peeling_data() self.fit_peeler_models( save_folder=save_folder, n_jobs=n_jobs, device=device ) diff --git a/src/dartsort/templates/pairwise.py b/src/dartsort/templates/pairwise.py index 6c958c72..ace0a170 100644 --- a/src/dartsort/templates/pairwise.py +++ b/src/dartsort/templates/pairwise.py @@ -1,7 +1,13 @@ from dataclasses import dataclass, fields +from typing import Optional import h5py import numpy as np +import torch + +from .pairwise_util import compressed_convolve_to_h5 +from .template_util import CompressedUpsampledTemplates, LowRankTemplates +from .templates import TemplateData @dataclass @@ -21,6 +27,7 @@ class CompressedPairwiseConv: correlations sparsely. .query() grabs the actual correlations for the user. """ + # shape: (n_shifts,) # shift_ix -> shift (pitch shift, an integer) shifts: np.ndarray @@ -44,6 +51,15 @@ class CompressedPairwiseConv: # the 0 index is special: pconv[0] === 0. 
pconv: np.ndarray + def __post_init__(self): + assert self.shifts.ndim == 1 + assert self.shifts.size == self.shifted_template_index.shape[1] + assert ( + self.shifted_template_index.shape + == self.upsampled_shifted_template_index.shape[:2] + ) + self._is_torch = False + @classmethod def from_h5(cls, hdf5_filename): ff = fields(cls) @@ -51,6 +67,110 @@ def from_h5(cls, hdf5_filename): data = {f.name: h5[f.name][:] for f in ff} return cls(**data) + @classmethod + def from_template_data( + cls, + hdf5_filename, + template_data: TemplateData, + low_rank_templates: LowRankTemplates, + compressed_upsampled_temporal: CompressedUpsampledTemplates, + chunk_time_centers_s: Optional[np.ndarray] = None, + motion_est=None, + geom: Optional[np.ndarray] = None, + conv_ignore_threshold=0.0, + coarse_approx_error_threshold=0.0, + conv_batch_size=128, + units_batch_size=8, + overwrite=False, + device=None, + n_jobs=0, + show_progress=True, + ): + compressed_convolve_to_h5( + hdf5_filename, + template_data=template_data, + low_rank_templates=low_rank_templates, + compressed_upsampled_temporal=compressed_upsampled_temporal, + chunk_time_centers_s=chunk_time_centers_s, + motion_est=motion_est, + geom=geom, + conv_ignore_threshold=conv_ignore_threshold, + coarse_approx_error_threshold=coarse_approx_error_threshold, + conv_batch_size=conv_batch_size, + units_batch_size=units_batch_size, + overwrite=overwrite, + device=device, + n_jobs=n_jobs, + show_progress=show_progress, + ) + return cls.from_h5(hdf5_filename) + + def at_shifts(self, shifts=None): + """Subset this database to one set of shifts. + + The database becomes shiftless (not in the pejorative sense). + """ + if shifts is None: + assert self.shifts.shape == (1,) + return self + + assert shifts.shape == len(self.shifted_template_index) + n_shifted_temps, n_up_shifted_temps = self.pconv_index.shape + + # active shifted and upsampled indices + shift_ix = np.searchsorted(self.shifts, shifts) + sub_shifted_temp_index = self.shifted_template_index[ + np.arange(len(self.shifted_template_index)), + shift_ix, + ] + sub_up_shifted_temp_index = self.upsampled_shifted_template_index[ + np.arange(len(self.shifted_template_index)), + shift_ix, + ] + + # in flat form for indexing into pconv_index. also, reindex. 
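        # A minimal sketch of the reindexing idiom used below, assuming only
        # numpy: np.unique with return_inverse=True both compacts a sparse set
        # of global indices and returns the map back into the compacted set.
        #
        #     import numpy as np
        #     sparse_ids = np.array([7, 3, 7, 9])
        #     compact, inverse = np.unique(sparse_ids, return_inverse=True)
        #     # compact -> [3, 7, 9]; inverse -> [1, 0, 1, 2]
        #     # sparse_ids[k] == compact[inverse[k]] for every k, so inverse
        #     # gives valid row indices into anything built over compact.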
+ valid_shifted = sub_shifted_temp_index < n_shifted_temps + shifted_temp_ixs, new_shifted_temp_ixs = np.unique( + sub_shifted_temp_index[valid_shifted] + ) + valid_up_shifted = sub_up_shifted_temp_index < n_up_shifted_temps + up_shifted_temp_ixs, new_up_shifted_temp_ixs = np.unique( + sub_up_shifted_temp_index[valid_up_shifted], return_inverse=True + ) + + # get relevant pconv subset and reindex + sub_pconv_indices, new_pconv_indices = np.unique( + self.pconv_index[ + shifted_temp_ixs[:, None], + up_shifted_temp_ixs.ravel()[None, :], + ], + return_inverse=True, + ) + sub_pconv = self.pconv[sub_pconv_indices] + + # reindexing + n_sub_shifted_temps = len(shifted_temp_ixs) + n_sub_up_shifted_temps = len(up_shifted_temp_ixs) + sub_pconv_index = new_pconv_indices.reshape( + n_sub_shifted_temps, n_sub_up_shifted_temps + ) + sub_shifted_temp_index[valid_shifted] = new_shifted_temp_ixs + sub_up_shifted_temp_index[valid_shifted] = new_up_shifted_temp_ixs + + return self.__class__( + shifts=np.zeros(1), + shifted_template_index=sub_shifted_temp_index, + upsampled_shifted_template_index=sub_up_shifted_temp_index, + pconv_index=sub_pconv_index, + pconv=sub_pconv, + ) + + def to(self, device=None): + """Become torch tensors on device.""" + for f in fields(self): + self.setattr(f.name, torch.as_tensor(getattr(self, f.name), device=device)) + self.device = device + def query( self, template_indices_a, @@ -59,9 +179,18 @@ def query( shifts_a=None, shifts_b=None, return_zero_convs=False, + grid=False, ): - template_indices_a = np.atleast_1d(template_indices_a) - template_indices_b = np.atleast_1d(template_indices_b) + if template_indices_a is None: + if self._is_torch: + template_indices_a = torch.arange( + len(self.shifted_template_index), device=self.device + ) + else: + template_indices_a = np.arange(len(self.shifted_template_index)) + if not self._is_torch: + template_indices_a = np.atleast_1d(template_indices_a) + template_indices_b = np.atleast_1d(template_indices_b) # handle no shifting no_shifting = shifts_a is None or shifts_b is None @@ -84,7 +213,7 @@ def query( no_upsampling = upsampling_indices_b is None if no_upsampling: assert self.upsampled_shifted_template_index.shape[2] == 1 - upsampled_shifted_template_index = self.upsampled_shifted_template_index[..., 0] + upsampled_shifted_template_index = upsampled_shifted_template_index[..., 0] else: b_ix = b_ix + (upsampling_indices_b,) @@ -94,11 +223,25 @@ def query( # upsampled shifted template indices for B up_shifted_temp_ix_b = upsampled_shifted_template_index[b_ix] - pconv_indices = self.pconv_index[shifted_temp_ix_a, up_shifted_temp_ix_b] + # return convolutions between all ai,bj or just ai,bi? 
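        # A short note on the two query modes, under the indexing set up above:
        # with grid=False the a/b arrays are read element-wise, so result i is
        # the correlation of the pair (a_i, b_i); with grid=True every a is
        # paired with every b. The matcher's subtract_conv path uses the grid
        # mode with template_indices_a=None so that one spike's contribution
        # can be subtracted from the objective rows of all templates at once.
        #
        #     # roughly, for n_a left-hand and n_b right-hand templates:
        #     #   grid=False -> n_a (== n_b) results, pairs (a_i, b_i)
        #     #   grid=True  -> n_a * n_b results, pairs (a_i, b_j), flattened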
+ if grid: + pconv_indices = self.pconv_index[shifted_temp_ix_a[:, None], up_shifted_temp_ix_b[None, :]] + if self._is_torch: + template_indices_a, template_indices_b = torch.cartesian_prod( + template_indices_a, template_indices_b + ).T + pconv_indices = pconv_indices.view(-1) + else: + template_indices_a, template_indices_b = np.meshgrid(template_indices_a, template_indices_b, indexing="ij") + template_indices_a = template_indices_a.ravel() + template_indices_b = template_indices_b.ravel() + pconv_indices = pconv_indices.ravel() + else: + pconv_indices = self.pconv_index[shifted_temp_ix_a, up_shifted_temp_ix_b] # most users will be happy not to get a bunch of zeros for pairs that don't overlap if not return_zero_convs: - which = np.flatnonzero(pconv_indices > 0) + which = pconv_indices > 0 pconv_indices = pconv_indices[which] template_indices_a = template_indices_a[which] template_indices_b = template_indices_b[which] diff --git a/src/dartsort/templates/template_util.py b/src/dartsort/templates/template_util.py index 3bfc7d15..30255af1 100644 --- a/src/dartsort/templates/template_util.py +++ b/src/dartsort/templates/template_util.py @@ -155,7 +155,8 @@ def templates_at_time( registered_geom=None, motion_est=None, return_pitch_shifts=False, - # TODO: geom kdtree + geom_kdtree=None, + match_distance=None, ): if registered_geom is None: return registered_templates @@ -180,6 +181,8 @@ def templates_at_time( n_pitches_shift=pitch_shifts, registered_geom=geom, fill_value=np.nan, + target_kdtree=geom_kdtree, + match_distance=match_distance, ) if return_pitch_shifts: return pitch_shifts, unregistered_templates @@ -282,7 +285,7 @@ def compressed_upsampled_templates( Returns ------- A CompressedUpsampledTemplates object with fields: - compressed_upsampled_templates : array (n_compressed_upsampled_templates, spike_length_samples) + compressed_upsampled_templates : array (n_compressed_upsampled_templates, spike_length_samples, n_channels) compressed_upsampling_map : array (n_templates, max_upsample) compressed_upsampled_templates[compressed_upsampling_map[unit, j]] is an approximation of the jth upsampled template for this unit. 
for low-amplitude units, From 12dc69ec18c5ef28f3d60187eae4f39853e58648 Mon Sep 17 00:00:00 2001 From: Charlie Windolf Date: Thu, 9 Nov 2023 15:57:15 -0500 Subject: [PATCH 29/49] Debugging --- src/dartsort/peel/matching.py | 17 +++++++--- src/dartsort/peel/peel_base.py | 6 ++-- src/dartsort/templates/get_templates.py | 33 ++++++++++++------- src/dartsort/templates/pairwise.py | 3 +- src/dartsort/templates/pairwise_util.py | 43 +++++++++++++++---------- src/dartsort/templates/superres_util.py | 39 ++++++++++++++-------- src/dartsort/templates/templates.py | 1 + src/dartsort/util/drift_util.py | 2 ++ 8 files changed, 94 insertions(+), 50 deletions(-) diff --git a/src/dartsort/peel/matching.py b/src/dartsort/peel/matching.py index a06710f1..8b0af61b 100644 --- a/src/dartsort/peel/matching.py +++ b/src/dartsort/peel/matching.py @@ -218,6 +218,10 @@ def build_template_data( temporal_components = low_rank_templates.temporal_components.astype(dtype) singular_values = low_rank_templates.singular_values.astype(dtype) spatial_components = low_rank_templates.spatial_components.astype(dtype) + print(f"{template_data.templates.dtype=}") + print(f"{temporal_components.dtype=}") + print(f"{singular_values.dtype=}") + print(f"{spatial_components.dtype=}") self.register_buffer("temporal_components", torch.tensor(temporal_components)) self.register_buffer("singular_values", torch.tensor(singular_values)) self.register_buffer("spatial_components", torch.tensor(spatial_components)) @@ -236,16 +240,20 @@ def build_template_data( chunk_centers_s = self.recording._recording_segments[0].sample_index_to_time( chunk_centers_samples ) + print(f"build_template_data {device=}") + print(f"{chunk_centers_s.shape=} {chunk_centers_s[:10]=}") self.pairwise_conv_db = CompressedPairwiseConv.from_template_data( save_folder / "pconv.h5", template_data=template_data, low_rank_templates=low_rank_templates, compressed_upsampled_temporal=compressed_upsampled_temporal, chunk_time_centers_s=chunk_centers_s, - motion_est=motion_est, + motion_est=self.motion_est, geom=self.geom, conv_ignore_threshold=self.conv_ignore_threshold, coarse_approx_error_threshold=self.coarse_approx_error_threshold, + device=device, + n_jobs=n_jobs, ) self.fixed_output_data += [ @@ -258,7 +266,7 @@ def build_template_data( ), ( "compressed_upsampled_temporal", - compressed_upsampled_temporal.compressed_upsampled_temporal, + compressed_upsampled_temporal.compressed_upsampled_templates, ), ] @@ -274,13 +282,14 @@ def handle_upsampling( ptps=ptps, max_upsample=temporal_upsampling_factor, ) + print(f"{compressed_upsampled_temporal.compressed_upsampled_templates.dtype=}") self.register_buffer( "compressed_upsampling_map", - compressed_upsampled_temporal.compressed_upsampling_map, + torch.tensor(compressed_upsampled_temporal.compressed_upsampling_map), ) self.register_buffer( "compressed_upsampled_temporal", - compressed_upsampled_temporal.compressed_upsampled_temporal, + torch.tensor(compressed_upsampled_temporal.compressed_upsampled_templates), ) if temporal_upsampling_factor == 1: return compressed_upsampled_temporal diff --git a/src/dartsort/peel/peel_base.py b/src/dartsort/peel/peel_base.py index e7e3edc0..f8110d12 100644 --- a/src/dartsort/peel/peel_base.py +++ b/src/dartsort/peel/peel_base.py @@ -213,7 +213,7 @@ def peeling_needs_fit(self): def precompute_peeling_data(self, save_folder, n_jobs=0, device=None): # subclasses should override if they need to cache data for peeling # runs before fit_peeler_models() - assert not self.peeling_needs_fit() + pass def 
fit_peeler_models(self, save_folder): # subclasses should override if they need to fit models for peeling @@ -324,7 +324,9 @@ def needs_fit(self): def fit_models(self, save_folder, n_jobs=0, device=None): with torch.no_grad(): if self.peeling_needs_fit(): - self.precompute_peeling_data() + self.precompute_peeling_data( + save_folder=save_folder, n_jobs=n_jobs, device=device + ) self.fit_peeler_models( save_folder=save_folder, n_jobs=n_jobs, device=device ) diff --git a/src/dartsort/templates/get_templates.py b/src/dartsort/templates/get_templates.py index fdb49cd7..a354495d 100644 --- a/src/dartsort/templates/get_templates.py +++ b/src/dartsort/templates/get_templates.py @@ -181,6 +181,7 @@ def get_templates( snr_threshold=denoising_snr_threshold, ) templates = weights * raw_templates + (1 - weights) * low_rank_templates + templates = templates.astype(recording.dtype) return dict( sorting=sorting, @@ -379,13 +380,16 @@ def get_all_shifted_raw_and_low_rank_templates( registered_kdtree = KDTree(registered_geom) n_units = sorting.labels.max() + 1 - raw_templates = np.zeros((n_units, spike_length_samples, n_template_channels)) + raw_templates = np.zeros( + (n_units, spike_length_samples, n_template_channels), dtype=recording.dtype + ) low_rank_templates = None if not raw: low_rank_templates = np.zeros( - (n_units, spike_length_samples, n_template_channels) + (n_units, spike_length_samples, n_template_channels), + dtype=recording.dtype, ) - snrs_by_channel = np.zeros((n_units, n_template_channels)) + snrs_by_channel = np.zeros((n_units, n_template_channels), dtype=recording.dtype) unit_id_chunks = [ unit_ids[i : i + units_per_job] for i in range(0, n_units, units_per_job) @@ -421,6 +425,8 @@ def get_all_shifted_raw_and_low_rank_templates( unit="template", ) for res in results: + if res is None: + continue units_chunk, raw_temps_chunk, low_rank_temps_chunk, snrs_chunk = res raw_templates[units_chunk] = raw_temps_chunk if not raw: @@ -477,12 +483,14 @@ def __init__( dtype=torch.from_numpy(np.zeros(1, dtype=recording.dtype)).dtype, ) + self.n_template_channels = self.n_channels if self.registered: self.geom = recording.get_channel_locations() self.match_distance = pdist(self.geom).min() / 2 self.registered_geom = registered_kdtree.data self.registered_kdtree = registered_kdtree self.pitch_shifts = pitch_shifts + self.n_template_channels = len(self.registered_geom) _template_process_context = None @@ -535,6 +543,8 @@ def _template_job(unit_ids): p = _template_process_context in_units_full = np.flatnonzero(np.isin(p.sorting.labels, unit_ids)) + if not in_units_full.size: + return labels_full = p.sorting.labels[in_units_full] # only so many spikes per unit @@ -564,7 +574,7 @@ def _template_job(unit_ids): (times >= p.trough_offset_samples) & (times < p.max_spike_time) ) if not valid.size: - return uids, 0, 0, 0 + return in_units = in_units[valid] labels = labels[valid] times = times[valid] @@ -581,12 +591,12 @@ def _template_job(unit_ids): # compute raw templates and spike counts per channel raw_templates = [] counts = [] + units_chunk = [] for u in uids: in_unit = np.flatnonzero(labels == u) if not in_unit.size: - raw_templates.append(np.zeros(1)) - counts.append(0) continue + units_chunk.append(u) in_unit_orig = in_units[labels == u] if p.registered: raw_templates.append( @@ -617,9 +627,10 @@ def _template_job(unit_ids): ) counts.append(in_unit.size) snrs_by_chan = [ptp(rt, 0) * c for rt, c in zip(raw_templates, counts)] + raw_templates = np.array(raw_templates) if p.denoising_tsvd is None: - return 
uids, raw_templates, None, snrs_by_chan + return units_chunk, raw_templates, None, snrs_by_chan # apply denoising waveforms = waveforms.permute(0, 2, 1).reshape(n * c, t) @@ -628,11 +639,8 @@ def _template_job(unit_ids): # get low rank templates low_rank_templates = [] - for u in uids: + for u in units_chunk: in_unit = np.flatnonzero(labels == u) - if not in_unit.size: - low_rank_templates.append(0) - continue in_unit_orig = in_units[labels == u] if p.registered: low_rank_templates.append( @@ -650,8 +658,9 @@ def _template_job(unit_ids): low_rank_templates.append( p.reducer(waveforms[in_unit], axis=0).numpy(force=True) ) + low_rank_templates = np.array(low_rank_templates) - return uids, raw_templates, low_rank_templates, snrs_by_chan + return units_chunk, raw_templates, low_rank_templates, snrs_by_chan class TorchSVDProjector(torch.nn.Module): diff --git a/src/dartsort/templates/pairwise.py b/src/dartsort/templates/pairwise.py index ace0a170..0caceecc 100644 --- a/src/dartsort/templates/pairwise.py +++ b/src/dartsort/templates/pairwise.py @@ -79,13 +79,14 @@ def from_template_data( geom: Optional[np.ndarray] = None, conv_ignore_threshold=0.0, coarse_approx_error_threshold=0.0, - conv_batch_size=128, + conv_batch_size=1024, units_batch_size=8, overwrite=False, device=None, n_jobs=0, show_progress=True, ): + print(f"pairwise from_template_data {device=}") compressed_convolve_to_h5( hdf5_filename, template_data=template_data, diff --git a/src/dartsort/templates/pairwise_util.py b/src/dartsort/templates/pairwise_util.py index 061596ee..07c7ac7f 100644 --- a/src/dartsort/templates/pairwise_util.py +++ b/src/dartsort/templates/pairwise_util.py @@ -28,7 +28,7 @@ def compressed_convolve_to_h5( geom: Optional[np.ndarray] = None, conv_ignore_threshold=0.0, coarse_approx_error_threshold=0.0, - conv_batch_size=128, + conv_batch_size=1024, units_batch_size=8, overwrite=False, device=None, @@ -57,6 +57,7 @@ def compressed_convolve_to_h5( upsampled_shifted_template_index = get_upsampled_shifted_template_index( template_shift_index, compressed_upsampled_temporal ) + print(f"compressed_convolve_to_h5 {conv_batch_size=} {units_batch_size=} {device=}") chunk_res_iterator = iterate_compressed_pairwise_convolutions( template_data=template_data, @@ -148,7 +149,7 @@ def iterate_compressed_pairwise_convolutions( conv_ignore_threshold=0.0, coarse_approx_error_threshold=0.0, max_shift="full", - conv_batch_size=128, + conv_batch_size=1024, units_batch_size=8, device=None, n_jobs=0, @@ -165,6 +166,7 @@ def iterate_compressed_pairwise_convolutions( process the results differently. """ # construct drift-related helper data if needed + print(f"iterate_compressed_pairwise_convolutions {conv_batch_size=} {units_batch_size=} {device=}") n_shifts = template_shift_index.all_pitch_shifts.size do_shifting = n_shifts > 1 geom_kdtree = reg_geom_kdtree = match_distance = None @@ -267,7 +269,7 @@ def compressed_convolve_pairs( conv_ignore_threshold=0.0, coarse_approx_error_threshold=0.0, max_shift="full", - batch_size=128, + batch_size=1024, device=None, ) -> Optional[CompressedConvResult]: """Compute compressed pairwise convolutions between template pairs @@ -280,9 +282,11 @@ def compressed_convolve_pairs( shifts, superres templates, and upsamples. Some of these may be zero or may be duplicates, so the return value is a sparse representation. See below. 
""" + # print(f"compressed_convolve_pairs {device=}") # print(f"{units_a.shape=}") # print(f"{units_b.shape=}") # print(f"{(units_a.size * units_b.size)=}") + # print(f"compressed_convolve_pairs {batch_size=} {units_a.size=} {device=}") # what pairs, shifts, etc are we convolving? shifted_temp_ix_a, temp_ix_a, shift_a, unit_a = handle_shift_indices( @@ -317,6 +321,9 @@ def compressed_convolve_pairs( match_distance=match_distance, device=device, ) + # print(f"{low_rank_templates.spatial_components.dtype=} {low_rank_templates.singular_values.dtype=}") + # print(f"{compressed_upsampled_temporal.compressed_upsampled_templates.dtype=}") + # print(f"{spatial_singular_a.dtype=} {spatial_singular_b.dtype=}") # figure out pairs of shifted templates to convolve in a deduplicated way pairs_ret = shift_deduplicated_pairs( @@ -392,27 +399,27 @@ def compressed_convolve_pairs( # print(f"{temporal_a[ix_a[conv_ix]].shape=}") # print(f"{conv_temporal_components_up_b.shape=}") pconv, kept = correlate_pairs_lowrank( - torch.as_tensor(spatial_singular_a[ix_a[conv_ix]]).to(device), - torch.as_tensor(spatial_singular_b[ix_b[conv_ix]]).to(device), - torch.as_tensor(temporal_a[ix_a[conv_ix]]).to(device), - torch.as_tensor(conv_temporal_components_up_b).to(device), + torch.as_tensor(spatial_singular_a[ix_a[conv_ix]], device=device), + torch.as_tensor(spatial_singular_b[ix_b[conv_ix]], device=device), + torch.as_tensor(temporal_a[ix_a[conv_ix]], device=device), + torch.as_tensor(conv_temporal_components_up_b, device=device), max_shift=max_shift, conv_ignore_threshold=conv_ignore_threshold, batch_size=batch_size, ) - print(f"-----------") - print(f"after corr {pconv.shape=} {conv_ix[kept].shape=}") + # print(f"-----------") + # print(f"after corr {pconv.shape=} {conv_ix[kept].shape=}") conv_ix = conv_ix[kept] if not conv_ix.size: return None kept_pairs = np.flatnonzero(np.isin(compression_index, kept)) - print(f"-----------") - print(f"kept {pconv.shape=} {conv_ix.shape=} {compression_index.shape=}") - print(f"{compression_index.min()=} {compression_index.max()=}") - print(f"{compression_index[kept_pairs].min()=} {compression_index[kept_pairs].max()=}") - print(f"{ix_a.shape=} {ix_b.shape=}") - print(f"{kept.shape=} {kept.dtype=} {kept.min()=} {kept.max()=}") - print(f"{kept_pairs.shape=} {kept_pairs.dtype=} {kept_pairs.min()=} {kept_pairs.max()=}") + # print(f"-----------") + # print(f"kept {pconv.shape=} {conv_ix.shape=} {compression_index.shape=}") + # print(f"{compression_index.min()=} {compression_index.max()=}") + # print(f"{compression_index[kept_pairs].min()=} {compression_index[kept_pairs].max()=}") + # print(f"{ix_a.shape=} {ix_b.shape=}") + # print(f"{kept.shape=} {kept.dtype=} {kept.min()=} {kept.max()=}") + # print(f"{kept_pairs.shape=} {kept_pairs.dtype=} {kept_pairs.min()=} {kept_pairs.max()=}") compression_index = np.searchsorted(kept, compression_index[kept_pairs]) conv_ix = np.searchsorted(kept_pairs, conv_ix) ix_a = ix_a[kept_pairs] @@ -472,7 +479,7 @@ def correlate_pairs_lowrank( temporal_b, max_shift="full", conv_ignore_threshold=0.0, - batch_size=128, + batch_size=1024, ): """Convolve pairs of low rank templates @@ -504,6 +511,8 @@ def correlate_pairs_lowrank( assert n_pairs == n_pairs_ assert t == t_ assert rank == rank_ + # print(f"{spatial_a.device=} {spatial_b.device=} {temporal_a.device=} {temporal_b.device=}") + # print(f"compressed_convolve_pairs {batch_size=} {n_pairs=} {spatial_a.device=}") if max_shift == "full": max_shift = t - 1 diff --git a/src/dartsort/templates/superres_util.py 
b/src/dartsort/templates/superres_util.py index 579f8cbd..2884d5a0 100644 --- a/src/dartsort/templates/superres_util.py +++ b/src/dartsort/templates/superres_util.py @@ -13,6 +13,7 @@ def superres_sorting( strategy="drift_pitch_loc_bin", superres_bin_size_um=10.0, min_spikes_per_bin=5, + probe_margin_um=200.0, ): """Construct the spatially superresolved spike train @@ -48,11 +49,20 @@ def superres_sorting( superres_sorting : DARTsortSorting """ pitch = drift_util.get_pitch(geom) - labels = sorting.labels - + full_labels = sorting.labels.copy() + + # remove spikes far away from the probe + if probe_margin_um is not None: + valid = spike_depths_um == np.clip( + spike_depths_um, + geom[:, 1].min() - probe_margin_um, + geom[:, 1].max() + probe_margin_um, + ) + full_labels[~valid] = -1 + # handle triaging - kept = np.flatnonzero(labels >= 0) - labels = labels[kept] + kept = np.flatnonzero(full_labels >= 0) + labels = full_labels[kept] spike_times_s = spike_times_s[kept] spike_depths_um = spike_depths_um[kept] @@ -80,17 +90,15 @@ def superres_sorting( ) else: raise ValueError(f"Unknown superres {strategy=}") - + # handle too-small units superres_labels, superres_to_original = remove_small_superres_units( superres_labels, superres_to_original, min_spikes_per_bin=min_spikes_per_bin ) - - # handle triaging again - full_superres_labels = sorting.labels.copy() - full_superres_labels[kept] = superres_labels - superres_sorting = replace(sorting, labels=full_superres_labels) + # back to un-triaged label space + full_labels[kept] = superres_labels + superres_sorting = replace(sorting, labels=full_labels) return superres_to_original, superres_sorting @@ -129,7 +137,9 @@ def drift_pitch_loc_bin_strategy( ) coarse_reg_depths = spike_depths_um + n_pitches_shift * pitch bin_ids = coarse_reg_depths // superres_bin_size_um - print(f"{np.isnan(n_pitches_shift).any()=} {np.isfinite(bin_ids).all()=} {superres_bin_size_um=}") + print( + f"{np.isnan(n_pitches_shift).any()=} {np.isfinite(bin_ids).all()=} {superres_bin_size_um=}" + ) print(f"{bin_ids.min()=} {bin_ids.max()=} {bin_ids.shape=}") print(f"{original_labels.min()=} {original_labels.max()=} {original_labels.shape=}") bin_ids = bin_ids.astype(int) @@ -140,8 +150,9 @@ def drift_pitch_loc_bin_strategy( return superres_labels, superres_to_original - -def remove_small_superres_units(superres_labels, superres_to_original, min_spikes_per_bin): +def remove_small_superres_units( + superres_labels, superres_to_original, min_spikes_per_bin +): if not min_spikes_per_bin: return superres_labels, superres_to_original @@ -157,4 +168,4 @@ def remove_small_superres_units(superres_labels, superres_to_original, min_spike superres_labels = relabeling[superres_labels] superres_to_original = superres_to_original[kept_labels] - return superres_labels, superres_to_original \ No newline at end of file + return superres_labels, superres_to_original diff --git a/src/dartsort/templates/templates.py b/src/dartsort/templates/templates.py index 28e7710b..d25c1c90 100644 --- a/src/dartsort/templates/templates.py +++ b/src/dartsort/templates/templates.py @@ -163,6 +163,7 @@ def from_config( # main! 
results = get_templates(recording, sorting, **kwargs) + print(f"{[(k,v.dtype) for k,v in results.items() if (isinstance(v, np.ndarray))]=}") # handle registered templates if template_config.registered_templates: diff --git a/src/dartsort/util/drift_util.py b/src/dartsort/util/drift_util.py index 3054da74..0e6cccb6 100644 --- a/src/dartsort/util/drift_util.py +++ b/src/dartsort/util/drift_util.py @@ -187,6 +187,8 @@ def registered_template( weights = valid[:, None, :] * counts[:, None, None] weights = weights / np.maximum(weights.sum(0), 1) template = (np.nan_to_num(static_templates) * weights).sum(0) + dtype = str(waveforms.dtype).split(".")[1] if is_tensor else waveforms.dtype + template = template.astype(dtype) template[:, ~valid.any(0)] = np.nan if not np.isnan(pad_value): template = np.nan_to_num(template, copy=False, nan=pad_value) From 31c6dde566cddcbb80f7ed88f44e9ca2d08d0c67 Mon Sep 17 00:00:00 2001 From: Charlie Windolf Date: Mon, 13 Nov 2023 12:51:29 -0800 Subject: [PATCH 30/49] Return spike count --- src/dartsort/util/data_util.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/src/dartsort/util/data_util.py b/src/dartsort/util/data_util.py index 16c8969c..92ece083 100644 --- a/src/dartsort/util/data_util.py +++ b/src/dartsort/util/data_util.py @@ -149,6 +149,7 @@ def check_recording( dedup_channel_index = make_channel_index( rec.get_channel_locations(), dedup_spatial_radius ) + failed = False # run detection and compute spike detection rate and data range spike_rates = [] @@ -173,6 +174,7 @@ def check_recording( "you experience memory issues.", RuntimeWarning, ) + failed = True if max_abs > expected_value_range: warn( @@ -180,5 +182,6 @@ def check_recording( "check that your data has been preprocessed, including standardization.", RuntimeWarning, ) + failed = True - return avg_detections_per_second, max_abs + return failed, avg_detections_per_second, max_abs From dd8f50b923cacfa6107bdaa10806e7919ea9fd7d Mon Sep 17 00:00:00 2001 From: Charlie Windolf Date: Mon, 13 Nov 2023 15:58:14 -0500 Subject: [PATCH 31/49] Residual filename --- src/dartsort/main.py | 2 ++ src/dartsort/peel/peel_base.py | 3 +-- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/src/dartsort/main.py b/src/dartsort/main.py index f9dc062d..02376d43 100644 --- a/src/dartsort/main.py +++ b/src/dartsort/main.py @@ -149,6 +149,8 @@ def _run_peeler( output_directory.mkdir(exist_ok=True) model_dir = output_directory / model_subdir output_hdf5_filename = output_directory / hdf5_filename + if residual_filename is not None: + residual_filename = output_directory / residual_filename # fit models if needed peeler.load_or_fit_and_save_models( diff --git a/src/dartsort/peel/peel_base.py b/src/dartsort/peel/peel_base.py index e7e3edc0..dbc3f1c9 100644 --- a/src/dartsort/peel/peel_base.py +++ b/src/dartsort/peel/peel_base.py @@ -516,11 +516,10 @@ def _peeler_process_init(peeler, device, rank_queue, save_residual): def _peeler_process_job(chunk_start_samples): - peeler = _peeler_process_context.peeler # by returning here, we are implicitly relying on pickle # we can replace this with cloudpickle or manual np.save if helpful with torch.no_grad(): - return peeler.process_chunk( + return _peeler_process_context.peeler.process_chunk( chunk_start_samples, return_residual=_peeler_process_context.save_residual, ) From 219f5741188646fa7f359d866924859f3ea99729 Mon Sep 17 00:00:00 2001 From: Charlie Windolf Date: Tue, 14 Nov 2023 22:45:44 -0500 Subject: [PATCH 32/49] Refactor pairwise conv to allow for 
different templates in LHS/RHS --- src/dartsort/main.py | 2 +- src/dartsort/templates/pairwise.py | 59 +++++---- src/dartsort/templates/pairwise_util.py | 165 ++++++++++++++++-------- src/dartsort/templates/templates.py | 53 +++++--- src/dartsort/util/drift_util.py | 153 ++++++++++++---------- 5 files changed, 261 insertions(+), 171 deletions(-) diff --git a/src/dartsort/main.py b/src/dartsort/main.py index 02376d43..b441cc52 100644 --- a/src/dartsort/main.py +++ b/src/dartsort/main.py @@ -129,7 +129,7 @@ def match( return sorting, output_hdf5_filename -# -- helper function +# -- helper function for subtract, match def _run_peeler( diff --git a/src/dartsort/templates/pairwise.py b/src/dartsort/templates/pairwise.py index 0caceecc..4fbb17ee 100644 --- a/src/dartsort/templates/pairwise.py +++ b/src/dartsort/templates/pairwise.py @@ -30,19 +30,20 @@ class CompressedPairwiseConv: # shape: (n_shifts,) # shift_ix -> shift (pitch shift, an integer) - shifts: np.ndarray + shifts_a: np.ndarray + shifts_b: np.ndarray - # shape: (n_templates, n_shifts) + # shape: (n_templates_a, n_shifts_a) # (template_ix, shift_ix) -> shifted_template_ix # shifted_template_ix can be either invalid (this template does not occur - # at this shift), or it can range from 0, ..., n_shifted_templates-1 - shifted_template_index: np.ndarray + # at this shift), or it can range from 0, ..., n_shifted_templates_a-1 + shifted_template_index_a: np.ndarray - # shape: (n_templates, n_shifts, upsampling_factor) + # shape: (n_templates_b, n_shifts_b, upsampling_factor) # (template_ix, shift_ix, upsampling_ix) -> upsampled_shifted_template_ix - upsampled_shifted_template_index: np.ndarray + upsampled_shifted_template_index_b: np.ndarray - # shape: (n_shifted_templates, n_upsampled_shifted_templates) + # shape: (n_shifted_templates_a, n_upsampled_shifted_templates_b) # (shifted_template_ix, upsampled_shifted_template_ix) -> pconv_ix pconv_index: np.ndarray @@ -52,12 +53,9 @@ class CompressedPairwiseConv: pconv: np.ndarray def __post_init__(self): - assert self.shifts.ndim == 1 - assert self.shifts.size == self.shifted_template_index.shape[1] - assert ( - self.shifted_template_index.shape - == self.upsampled_shifted_template_index.shape[:2] - ) + assert self.shifts_a.ndim == self.shifts_b.ndim == 1 + assert self.shifts_a.size == self.shifted_template_index_a.shape[1] + assert self.shifts_b.size == self.upsampled_shifted_template_index_b.shape[1] self._is_torch = False @classmethod @@ -106,26 +104,26 @@ def from_template_data( ) return cls.from_h5(hdf5_filename) - def at_shifts(self, shifts=None): + def at_shifts(self, shifts_a=None): """Subset this database to one set of shifts. The database becomes shiftless (not in the pejorative sense). 
""" - if shifts is None: - assert self.shifts.shape == (1,) + if shifts_a is None: + assert self.shifts_a.shape == (1,) return self - assert shifts.shape == len(self.shifted_template_index) + assert shifts_a.shape == len(self.shifted_template_index_a) n_shifted_temps, n_up_shifted_temps = self.pconv_index.shape # active shifted and upsampled indices - shift_ix = np.searchsorted(self.shifts, shifts) - sub_shifted_temp_index = self.shifted_template_index[ - np.arange(len(self.shifted_template_index)), + shift_ix = np.searchsorted(self.shifts_a, shifts_a) + sub_shifted_temp_index = self.shifted_template_index_a[ + np.arange(len(self.shifted_template_index_a)), shift_ix, ] - sub_up_shifted_temp_index = self.upsampled_shifted_template_index[ - np.arange(len(self.shifted_template_index)), + sub_up_shifted_temp_index = self.upsampled_shifted_template_index_b[ + np.arange(len(self.shifted_template_index_a)), shift_ix, ] @@ -185,35 +183,36 @@ def query( if template_indices_a is None: if self._is_torch: template_indices_a = torch.arange( - len(self.shifted_template_index), device=self.device + len(self.shifted_template_index_a), device=self.device ) else: - template_indices_a = np.arange(len(self.shifted_template_index)) + template_indices_a = np.arange(len(self.shifted_template_index_a)) if not self._is_torch: template_indices_a = np.atleast_1d(template_indices_a) template_indices_b = np.atleast_1d(template_indices_b) # handle no shifting no_shifting = shifts_a is None or shifts_b is None - shifted_template_index = self.shifted_template_index - upsampled_shifted_template_index = self.upsampled_shifted_template_index + shifted_template_index = self.shifted_template_index_a + upsampled_shifted_template_index = self.upsampled_shifted_template_index_b if no_shifting: assert shifts_a is None and shifts_b is None - assert self.shifts.shape == (1,) + assert self.shifts_a.shape == (1,) + assert self.shifts_b.shape == (1,) a_ix = (template_indices_a,) b_ix = (template_indices_b,) shifted_template_index = shifted_template_index[:, 0] upsampled_shifted_template_index = upsampled_shifted_template_index[:, 0] else: - shift_indices_a = np.searchsorted(self.shifts, shifts_a) - shift_indices_b = np.searchsorted(self.shifts, shifts_b) + shift_indices_a = np.searchsorted(self.shifts_a, shifts_a) + shift_indices_b = np.searchsorted(self.shifts_b, shifts_b) a_ix = (template_indices_a, shift_indices_a) b_ix = (template_indices_b, shift_indices_b) # handle no upsampling no_upsampling = upsampling_indices_b is None if no_upsampling: - assert self.upsampled_shifted_template_index.shape[2] == 1 + assert self.upsampled_shifted_template_index_b.shape[2] == 1 upsampled_shifted_template_index = upsampled_shifted_template_index[..., 0] else: b_ix = b_ix + (upsampling_indices_b,) diff --git a/src/dartsort/templates/pairwise_util.py b/src/dartsort/templates/pairwise_util.py index 07c7ac7f..d9d467e1 100644 --- a/src/dartsort/templates/pairwise_util.py +++ b/src/dartsort/templates/pairwise_util.py @@ -2,7 +2,6 @@ from collections import namedtuple from dataclasses import dataclass, fields -from pathlib import Path from typing import Iterator, Optional, Union import h5py @@ -23,6 +22,8 @@ def compressed_convolve_to_h5( template_data: templates.TemplateData, low_rank_templates: template_util.LowRankTemplates, compressed_upsampled_temporal: template_util.CompressedUpsampledTemplates, + template_data_b: Optional[templates.TemplateData] = None, + low_rank_templates_b: Optional[templates.TemplateData] = None, chunk_time_centers_s: 
Optional[np.ndarray] = None, motion_est=None, geom: Optional[np.ndarray] = None, @@ -48,23 +49,32 @@ def compressed_convolve_to_h5( pass # TODO # construct indexing helpers - template_shift_index = drift_util.get_shift_and_unit_pairs( + ( + template_shift_index_a, + template_shift_index_b, + upsampled_shifted_template_index, + cooccurrence, + ) = construct_shift_indices( chunk_time_centers_s, geom, template_data, + compressed_upsampled_temporal, + template_data_b=template_data_b, motion_est=motion_est, ) - upsampled_shifted_template_index = get_upsampled_shifted_template_index( - template_shift_index, compressed_upsampled_temporal - ) print(f"compressed_convolve_to_h5 {conv_batch_size=} {units_batch_size=} {device=}") chunk_res_iterator = iterate_compressed_pairwise_convolutions( - template_data=template_data, - low_rank_templates=low_rank_templates, + template_data_a=template_data, + low_rank_templates_a=low_rank_templates, + template_data_b=template_data_b, + low_rank_templates_b=low_rank_templates_b, compressed_upsampled_temporal=compressed_upsampled_temporal, - template_shift_index=template_shift_index, + template_shift_index_a=template_shift_index_a, + template_shift_index_b=template_shift_index_b, + cooccurrence=cooccurrence, upsampled_shifted_template_index=upsampled_shifted_template_index, + do_shifting=motion_est is not None, geom=geom, conv_ignore_threshold=conv_ignore_threshold, coarse_approx_error_threshold=coarse_approx_error_threshold, @@ -78,7 +88,7 @@ def compressed_convolve_to_h5( pconv_index = np.zeros( ( - template_shift_index.n_shifted_templates, + template_shift_index_a.n_shifted_templates, upsampled_shifted_template_index.n_upsampled_shifted_templates, ), dtype=int, @@ -100,7 +110,7 @@ def compressed_convolve_to_h5( continue # get shifted template indices for A - shifted_temp_ix_a = template_shift_index.template_shift_index[ + shifted_temp_ix_a = template_shift_index_a.template_shift_index[ chunk_res.template_indices_a, chunk_res.shift_indices_a, ] @@ -126,12 +136,13 @@ def compressed_convolve_to_h5( n_pconvs += n_new_pconvs # write fixed size outputs - h5.create_dataset("shifts", data=template_shift_index.all_pitch_shifts) + h5.create_dataset("shifts_a", data=template_shift_index_a.all_pitch_shifts) + h5.create_dataset("shifts_b", data=template_shift_index_b.all_pitch_shifts) h5.create_dataset( - "shifted_template_index", data=template_shift_index.template_shift_index + "shifted_template_index_a", data=template_shift_index_a.template_shift_index ) h5.create_dataset( - "upsampled_shifted_template_index", + "upsampled_shifted_template_index_b", data=upsampled_shifted_template_index.upsampled_shifted_template_index, ) h5.create_dataset("pconv_index", data=pconv_index) @@ -140,11 +151,16 @@ def compressed_convolve_to_h5( def iterate_compressed_pairwise_convolutions( - template_data: templates.TemplateData, - low_rank_templates: template_util.LowRankTemplates, + template_data_a: templates.TemplateData, + low_rank_templates_a: template_util.LowRankTemplates, + template_data_b: templates.TemplateData, + low_rank_templates_b: template_util.LowRankTemplates, compressed_upsampled_temporal: template_util.CompressedUpsampledTemplates, - template_shift_index: drift_util.TemplateShiftIndex, + template_shift_index_a: drift_util.TemplateShiftIndex, + template_shift_index_b: drift_util.TemplateShiftIndex, + cooccurrence: np.ndarray, upsampled_shifted_template_index: UpsampledShiftedTemplateIndex, + do_shifting: bool = True, geom: Optional[np.ndarray] = None, 
conv_ignore_threshold=0.0, coarse_approx_error_threshold=0.0, @@ -157,7 +173,6 @@ def iterate_compressed_pairwise_convolutions( ) -> Iterator[Optional[CompressedConvResult]]: """A generator of CompressedConvResults capturing all pairs of templates - Runs the function compressed_convolve_pairs on chunks of units. This is a helper function for parallelizing computation of cross correlations @@ -165,35 +180,44 @@ def iterate_compressed_pairwise_convolutions( memory, so this is a generator yielding a chunk at a time. Callers may process the results differently. """ + reg_geom = template_data_a.registered_geom + if template_data_b is None: + template_data_b = template_data_a + assert low_rank_templates_b is None + low_rank_templates_b = low_rank_templates_a + assert np.array_equal(reg_geom, template_data_b.registered_geom) + # construct drift-related helper data if needed - print(f"iterate_compressed_pairwise_convolutions {conv_batch_size=} {units_batch_size=} {device=}") - n_shifts = template_shift_index.all_pitch_shifts.size - do_shifting = n_shifts > 1 + print( + f"iterate_compressed_pairwise_convolutions {conv_batch_size=} {units_batch_size=} {device=}" + ) geom_kdtree = reg_geom_kdtree = match_distance = None - reg_geom = template_data.registered_geom if do_shifting: - assert geom is not None - assert reg_geom is not None geom_kdtree = KDTree(geom) reg_geom_kdtree = KDTree(reg_geom) match_distance = pdist(geom).min() / 2 # make chunks - units = np.unique(template_data.unit_ids) + units_a = np.unique(template_data_a.unit_ids) + units_b = np.unique(template_data_b.unit_ids) jobs = [] - for start_a in range(0, units.size, units_batch_size): - end_a = min(start_a + units_batch_size, units.size) - for start_b in range(start_a, units.size, units_batch_size): - end_b = min(start_b + units_batch_size, units.size) - jobs.append((units[start_a:end_a], units[start_b:end_b])) + for start_a in range(0, units_a.size, units_batch_size): + end_a = min(start_a + units_batch_size, units_a.size) + for start_b in range(0, units_b.size, units_batch_size): + end_b = min(start_b + units_batch_size, units_b.size) + jobs.append((units_a[start_a:end_a], units_b[start_b:end_b])) # worker kwargs kwargs = dict( - template_data=template_data, - low_rank_templates=low_rank_templates, + template_data_a=template_data_a, + template_data_b=template_data_b, + low_rank_templates_a=low_rank_templates_a, + low_rank_templates_b=low_rank_templates_b, compressed_upsampled_temporal=compressed_upsampled_temporal, - template_shift_index=template_shift_index, + template_shift_index_a=template_shift_index_a, + template_shift_index_b=template_shift_index_b, upsampled_shifted_template_index=upsampled_shifted_template_index, + cooccurrence=cooccurrence, geom=geom, reg_geom=reg_geom, geom_kdtree=geom_kdtree, @@ -254,11 +278,15 @@ class CompressedConvResult: def compressed_convolve_pairs( - template_data: templates.TemplateData, - low_rank_templates: template_util.LowRankTemplates, + template_data_a: templates.TemplateData, + template_data_b: templates.TemplateData, + low_rank_templates_a: template_util.LowRankTemplates, + low_rank_templates_b: template_util.LowRankTemplates, compressed_upsampled_temporal: template_util.CompressedUpsampledTemplates, - template_shift_index: drift_util.TemplateShiftIndex, + template_shift_index_a: drift_util.TemplateShiftIndex, + template_shift_index_b: drift_util.TemplateShiftIndex, upsampled_shifted_template_index: UpsampledShiftedTemplateIndex, + cooccurrence: np.ndarray, geom: Optional[np.ndarray] = None, 
reg_geom: Optional[np.ndarray] = None, geom_kdtree: Optional[KDTree] = None, @@ -290,10 +318,10 @@ def compressed_convolve_pairs( # what pairs, shifts, etc are we convolving? shifted_temp_ix_a, temp_ix_a, shift_a, unit_a = handle_shift_indices( - units_a, template_data.unit_ids, template_shift_index + units_a, template_data_a.unit_ids, template_shift_index_a ) shifted_temp_ix_b, temp_ix_b, shift_b, unit_b = handle_shift_indices( - units_b, template_data.unit_ids, template_shift_index + units_b, template_data_b.unit_ids, template_shift_index_b ) # print(f"{shifted_temp_ix_a.shape=}") # print(f"{shifted_temp_ix_b.shape=}") @@ -302,8 +330,8 @@ def compressed_convolve_pairs( spatial_singular_a = get_shifted_spatial_singular( temp_ix_a, shift_a, - template_shift_index, - low_rank_templates, + template_shift_index_a, + low_rank_templates_a, geom=geom, registered_geom=reg_geom, geom_kdtree=geom_kdtree, @@ -313,8 +341,8 @@ def compressed_convolve_pairs( spatial_singular_b = get_shifted_spatial_singular( temp_ix_b, shift_b, - template_shift_index, - low_rank_templates, + template_shift_index_b, + low_rank_templates_b, geom=geom, registered_geom=reg_geom, geom_kdtree=geom_kdtree, @@ -333,9 +361,9 @@ def compressed_convolve_pairs( spatial_singular_b, temp_ix_a, temp_ix_b, + cooccurrence=cooccurrence, shift_a=shift_a, shift_b=shift_b, - template_shift_index=template_shift_index, conv_ignore_threshold=conv_ignore_threshold, geom=geom, registered_geom=reg_geom, @@ -393,7 +421,7 @@ def compressed_convolve_pairs( # shift_b = shift_b[ix_b] # run convolutions - temporal_a = low_rank_templates.temporal_components[temp_ix_a] + temporal_a = low_rank_templates_a.temporal_components[temp_ix_a] # print(f"{spatial_singular_a[ix_a[conv_ix]].shape=}") # print(f"{spatial_singular_b[ix_b[conv_ix]].shape=}") # print(f"{temporal_a[ix_a[conv_ix]].shape=}") @@ -454,9 +482,9 @@ def compressed_convolve_pairs( # recover metadata temp_ix_a = temp_ix_a[ix_a] - shift_ix_a = np.searchsorted(template_shift_index.all_pitch_shifts, shift_a[ix_a]) + shift_ix_a = np.searchsorted(template_shift_index_a.all_pitch_shifts, shift_a[ix_a]) temp_ix_b = temp_ix_b[ix_b] - shift_ix_b = np.searchsorted(template_shift_index.all_pitch_shifts, shift_b[ix_b]) + shift_ix_b = np.searchsorted(template_shift_index_b.all_pitch_shifts, shift_b[ix_b]) return CompressedConvResult( template_indices_a=temp_ix_a, @@ -558,7 +586,38 @@ def correlate_pairs_lowrank( return pconv, kept +def construct_shift_indices( + chunk_time_centers_s, + geom, + template_data_a, + compressed_upsampled_temporal, + template_data_b=None, + motion_est=None, +): + ( + template_shift_index_a, + template_shift_index_b, + cooccurrence, + ) = drift_util.get_shift_and_unit_pairs( + chunk_time_centers_s, + geom, + template_data_a, + template_data_b=template_data_b, + motion_est=motion_est, + ) + upsampled_shifted_template_index = get_upsampled_shifted_template_index( + template_shift_index_b, compressed_upsampled_temporal + ) + return ( + template_shift_index_a, + template_shift_index_b, + upsampled_shifted_template_index, + cooccurrence, + ) + + def handle_shift_indices(units, unit_ids, template_shift_index): + """Determine shifted template indices belonging to a set of units.""" shifted_temp_ix_to_unit = unit_ids[template_shift_index.shifted_temp_ix_to_temp_ix] if units is None: shifted_temp_ix = np.arange(template_shift_index.n_shifted_templates) @@ -613,9 +672,9 @@ def shift_deduplicated_pairs( spatialsing_b, temp_ix_a, temp_ix_b, + cooccurrence, shift_a=None, shift_b=None, - 
template_shift_index=None, conv_ignore_threshold=0.0, geom=None, registered_geom=None, @@ -661,7 +720,7 @@ def shift_deduplicated_pairs( # print(f"___ after overlaps {pair.sum()=}") # co-occurrence - cooccurrence = template_shift_index.cooccurrence[ + cooccurrence = cooccurrence[ shifted_temp_ix_a[:, None], shifted_temp_ix_b[None, :], ] @@ -675,7 +734,7 @@ def shift_deduplicated_pairs( # print(f"___ {nco=}") # if no shifting, deduplication is the identity - do_shifting = template_shift_index.all_pitch_shifts.size > 1 + do_shifting = reg_geom_kdtree is not None if not do_shifting: nco_range = torch.arange(nco, device=pair_ix_a.device) return pair_ix_a, pair_ix_b, nco_range, nco_range @@ -1005,10 +1064,14 @@ def coarse_approximate( @dataclass class ConvWorkerContext: - template_data: templates.TemplateData - low_rank_templates: template_util.LowRankTemplates + template_data_a: templates.TemplateData + template_data_b: templates.TemplateData + low_rank_templates_a: template_util.LowRankTemplates + low_rank_templates_b: template_util.LowRankTemplates compressed_upsampled_temporal: template_util.CompressedUpsampledTemplates - template_shift_index: drift_util.TemplateShiftIndex + template_shift_index_a: drift_util.TemplateShiftIndex + template_shift_index_b: drift_util.TemplateShiftIndex + cooccurrence: np.ndarray upsampled_shifted_template_index: UpsampledShiftedTemplateIndex geom: Optional[np.ndarray] = None reg_geom: Optional[np.ndarray] = None diff --git a/src/dartsort/templates/templates.py b/src/dartsort/templates/templates.py index d25c1c90..745210d6 100644 --- a/src/dartsort/templates/templates.py +++ b/src/dartsort/templates/templates.py @@ -1,4 +1,4 @@ -from dataclasses import dataclass +from dataclasses import dataclass, replace from pathlib import Path from typing import Optional @@ -7,7 +7,8 @@ from .get_templates import get_templates from .superres_util import superres_sorting -from .template_util import get_realigned_sorting, get_template_depths +from .template_util import (get_realigned_sorting, get_template_depths, + weighted_average) _motion_error_prefix = ( "If template_config has registered_templates==True " @@ -27,25 +28,12 @@ class TemplateData: registered_geom: Optional[np.ndarray] = None registered_template_depths_um: Optional[np.ndarray] = None + localization_radius_um: float = 100.0 @classmethod def from_npz(cls, npz_path): - with np.load(npz_path) as npz: - templates = npz["templates"] - unit_ids = npz["unit_ids"] - spike_counts = npz["spike_counts"] - registered_geom = registered_template_depths_um = None - if "registered_geom" in npz: - registered_geom = npz["registered_geom"] - if "registered_template_depths_um" in npz: - registered_template_depths_um = npz["registered_template_depths_um"] - return cls( - templates, - unit_ids, - spike_counts, - registered_geom, - registered_template_depths_um, - ) + with np.load(npz_path) as data: + return cls(**data) def to_npz(self, npz_path): to_save = dict( @@ -61,6 +49,30 @@ def to_npz(self, npz_path): ] = self.registered_template_depths_um np.savez(npz_path, **to_save) + def coarsen(self): + """Weighted average all templates that share a unit id and re-localize.""" + # update templates + templates = weighted_average(self.unit_ids, self.templates, self.spike_counts) + + # collect spike counts + spike_counts = np.zeros(len(templates)) + np.add.at(spike_counts, self.unit_ids, self.spike_counts) + + # re-localize + registered_template_depths_um = get_template_depths( + templates, + self.registered_geom, + 
localization_radius_um=self.localization_radius_um, + ) + + return replace( + self, + templates=templates, + unit_ids=np.arange(len(templates)), + spike_counts=spike_counts, + registered_template_depths_um=registered_template_depths_um, + ) + @classmethod def from_config( cls, @@ -163,7 +175,9 @@ def from_config( # main! results = get_templates(recording, sorting, **kwargs) - print(f"{[(k,v.dtype) for k,v in results.items() if (isinstance(v, np.ndarray))]=}") + print( + f"{[(k,v.dtype) for k,v in results.items() if (isinstance(v, np.ndarray))]=}" + ) # handle registered templates if template_config.registered_templates: @@ -178,6 +192,7 @@ def from_config( spike_counts, kwargs["registered_geom"], registered_template_depths_um, + localization_radius_um=template_config.registered_template_localization_radius_um, ) else: obj = cls( diff --git a/src/dartsort/util/drift_util.py b/src/dartsort/util/drift_util.py index 0e6cccb6..b04c6c9d 100644 --- a/src/dartsort/util/drift_util.py +++ b/src/dartsort/util/drift_util.py @@ -476,10 +476,33 @@ class TemplateShiftIndex: # (template ix, shift index) -> shifted template index template_shift_index: np.ndarray # (shifted temp ix, shifted temp ix) -> did these appear at the same time - cooccurrence: np.ndarray shifted_temp_ix_to_temp_ix: np.ndarray shifted_temp_ix_to_shift: np.ndarray + @classmethod + def from_shift_matrix(cls, shifts): + """shift: n_times x n_templates""" + all_shifts = np.unique(shifts) + n_templates = shifts.shape[1] + pairs = np.stack(np.broadcast_arrays(np.arange(n_templates)[None, :], shifts), axis=2) + pairs = np.unique(pairs.reshape(shifts.size, 2), axis=0) + n_shifted_templates = len(pairs) + shift_ix = np.searchsorted(all_shifts, pairs[:, 1]) + template_shift_index = np.full( + (n_templates, len(all_shifts)), n_shifted_templates + ) + template_shift_index[pairs[:, 0], shift_ix] = np.arange(n_shifted_templates) + return cls( + n_shifted_templates, + all_shifts, + template_shift_index, + *pairs.T, + ) + + def shifts_to_shifted_ids(self, template_ids, shifts): + shift_ixs = np.searchsorted(self.all_pitch_shifts, shifts) + return self.template_shift_index[template_ids, shift_ixs] + def static_template_shift_index(n_templates): temp_ixs = np.arange(n_templates) @@ -487,7 +510,6 @@ def static_template_shift_index(n_templates): n_templates, np.zeros(1), temp_ixs[:, None], - np.ones((n_templates, n_templates), dtype=bool), temp_ixs, np.zeros_like(temp_ixs), ) @@ -496,78 +518,69 @@ def static_template_shift_index(n_templates): def get_shift_and_unit_pairs( chunk_time_centers_s, geom, - template_data, + template_data_a, + template_data_b=None, motion_est=None, ): - n_templates = len(template_data.templates) - if motion_est is None: - # no motion case - return static_template_shift_index(n_templates) - - # all observed pitch shift values - all_pitch_shifts = np.empty(shape=(0,), dtype=int) - temp_ixs = np.arange(n_templates) - # set of (template idx, shift) - template_shift_pairs = np.empty(shape=(0, 2), dtype=int) - pitch = get_pitch(geom) - - for t_s in chunk_time_centers_s: - # see the fn `templates_at_time` - unregistered_depths_um = invert_motion_estimate( - motion_est, t_s, template_data.registered_template_depths_um - ) - pitch_shifts = get_spike_pitch_shifts( - depths_um=template_data.registered_template_depths_um, - pitch=pitch, - registered_depths_um=unregistered_depths_um, - ) - pitch_shifts = pitch_shifts.astype(int) - - # get unique pitch/unit shift pairs in chunk - template_shift = np.c_[temp_ixs, pitch_shifts] - - # update 
full set - all_pitch_shifts = np.union1d(all_pitch_shifts, pitch_shifts) - template_shift_pairs = np.unique( - np.concatenate((template_shift_pairs, template_shift), axis=0), axis=0 - ) + if template_data_b is None: + template_data_b = template_data_a - n_shifts = len(all_pitch_shifts) - n_template_shift_pairs = len(template_shift_pairs) + na = template_data_a.templates.shape[0] + nb = template_data_b.templates.shape[0] - # index template/shift pairs: template_shift_index[template_ix, shift_ix] = shifted template index - # fill with an invalid index - template_shift_index = np.full((n_templates, n_shifts), n_template_shift_pairs) - shift_ix = np.searchsorted(all_pitch_shifts, template_shift_pairs[:, 1]) - assert np.array_equal(all_pitch_shifts[shift_ix], template_shift_pairs[:, 1]) - template_shift_index[template_shift_pairs[:, 0], shift_ix] = np.arange( - n_template_shift_pairs + if motion_est is None: + shift_index_a = static_template_shift_index(na) + shift_index_b = static_template_shift_index(nb) + cooccurrence = np.ones((na, nb), dtype=bool) + return shift_index_a, shift_index_b, cooccurrence + + reg_depths_um_a = template_data_a.registered_template_depths_um + reg_depths_um_b = template_data_b.registered_template_depths_um + same = np.array_equal(reg_depths_um_a, reg_depths_um_b) + if same: + reg_depths_um = reg_depths_um_a + else: + reg_depths_um = np.concatenate((reg_depths_um_a, reg_depths_um_b)) + + # figure out all shifts for all units at all times + unreg_depths_um = np.concatenate( + [ + motion_est.disp_at_s(t_s, depth_um=reg_depths_um, grid=True).T + for t_s in chunk_time_centers_s + ], + axis=0, + ) + assert unreg_depths_um.shape == (len(chunk_time_centers_s), len(reg_depths_um)) + pitch_shifts = get_spike_pitch_shifts( + depths_um=reg_depths_um, + pitch=get_pitch(geom), + registered_depths_um=unreg_depths_um, ) - shifted_temp_ix_to_temp_ix = template_shift_pairs[:, 0] - shifted_temp_ix_to_shift = template_shift_pairs[:, 1] + if same: + shifts_a = shifts_b = pitch_shifts + else: + shifts_a = pitch_shifts[:, :na] + shifts_b = pitch_shifts[:, na:] + + # assign ids to pitch/shift pairs + template_shift_index_a = TemplateShiftIndex.from_shift_matrix(shifts_a) + if same: + template_shift_index_b = template_shift_index_a + else: + template_shift_index_b = TemplateShiftIndex.from_shift_matrix(shifts_b) # co-occurrence matrix: do these shifted templates appear together? 
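    # A minimal sketch of the construction below, assuming one chunk time per
    # row of shifts_a / shifts_b: each row is mapped through
    # shifts_to_shifted_ids to the shifted-template ids active at that time,
    # and every (id_a, id_b) combination within a row is marked, e.g.
    #
    #     # cooccurrence[shifted_ids_a[:, None], shifted_ids_b[None, :]] = 1
    #
    # so pairs of shifted templates that never appear in the same chunk stay
    # False and their convolutions can be skipped.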
- cooccurrence = np.eye(n_template_shift_pairs, dtype=bool) - for t_s in chunk_time_centers_s: - unregistered_depths_um = invert_motion_estimate( - motion_est, t_s, template_data.registered_template_depths_um - ) - pitch_shifts = get_spike_pitch_shifts( - depths_um=template_data.registered_template_depths_um, - pitch=pitch, - registered_depths_um=unregistered_depths_um, - ) - pitch_shifts = pitch_shifts.astype(int) - pitch_shift_ix = np.searchsorted(all_pitch_shifts, pitch_shifts) - - shifted_temp_ixs = template_shift_index[temp_ixs, pitch_shift_ix] - cooccurrence[shifted_temp_ixs[:, None], shifted_temp_ixs[None, :]] = 1 - - return TemplateShiftIndex( - n_template_shift_pairs, - all_pitch_shifts, - template_shift_index, - cooccurrence, - shifted_temp_ix_to_temp_ix, - shifted_temp_ix_to_shift, - ) + cooccurrence = np.zeros( + (template_shift_index_a.n_shifted_templates, template_shift_index_b.n_shifted_templates), + dtype=bool) + temps_a = np.arange(na) + temps_b = np.arange(nb) + for j in range(len(chunk_time_centers_s)): + shifted_ids_a = template_shift_index_a.shifts_to_shifted_ids(temps_a, shifts_a[j]) + if same: + shifted_ids_b = shifted_ids_a + else: + shifted_ids_b = template_shift_index_b.shifts_to_shifted_ids(temps_b, shifts_b[j]) + cooccurrence[shifted_ids_a[:, None], shifted_ids_b[None, :]] = 1 + + return template_shift_index_a, template_shift_index_b, cooccurrence From eda817cac43605d0e5a7d0a9deb03b1ce6ea1bdf Mon Sep 17 00:00:00 2001 From: Charlie Windolf Date: Tue, 21 Nov 2023 18:46:47 -0500 Subject: [PATCH 33/49] Initial un-debugged coarse to fine matching --- src/dartsort/peel/matching.py | 519 ++++++++++++++-------- src/dartsort/templates/pairwise.py | 4 + src/dartsort/templates/template_util.py | 10 + src/dartsort/util/multiprocessing_util.py | 2 + src/dartsort/util/spiketorch.py | 66 ++- 5 files changed, 409 insertions(+), 192 deletions(-) diff --git a/src/dartsort/peel/matching.py b/src/dartsort/peel/matching.py index 8b0af61b..01dc84b1 100644 --- a/src/dartsort/peel/matching.py +++ b/src/dartsort/peel/matching.py @@ -35,6 +35,7 @@ def __init__( featurization_pipeline, motion_est=None, svd_compression_rank=10, + coarse_objective=True, temporal_upsampling_factor=8, upsampling_peak_window_radius=8, min_channel_amplitude=1.0, @@ -62,6 +63,7 @@ def __init__( # main properties self.template_data = template_data + self.coarse_objective = coarse_objective self.temporal_upsampling_factor = temporal_upsampling_factor self.upsampling_peak_window_radius = upsampling_peak_window_radius self.svd_compression_rank = svd_compression_rank @@ -174,6 +176,12 @@ def check_shapes(self): assert self.unit_ids.shape == (self.n_templates,) def handle_template_groups(self, unit_ids): + """Grouped templates in objective + + If not coarse_objective, then several rows of the objective may + belong to the same unit. They must be handled together when imposing + refractory conditions. 
+ """ self.register_buffer("unit_ids", torch.from_numpy(unit_ids)) self.grouped_temps = True unique_units = np.unique(unit_ids) @@ -181,11 +189,19 @@ def handle_template_groups(self, unit_ids): self.grouped_temps = False if not self.grouped_temps: + self.register_buffer("superres_index", torch.arange(len(unit_ids))[:, None]) return - assert unit_ids.shape == (self.n_templates,) - group_index = [np.flatnonzero(unit_ids == u) for u in unit_ids] - max_group_size = max(map(len, group_index)) + + units, counts = np.unique(self.unit_ids, return_counts=True) + superres_index = np.full((self.n_templates, counts.max()), self.n_templates, -1) + for u in units: + my_sup = np.flatnonzero(self.unit_ids == u) + superres_index[u, : len(my_sup)] = my_sup + self.register_buffer("superres_index", torch.from_numpy(superres_index)) + + if self.coarse_objective: + return # like a channel index, sort of # this is a n_templates x group_size array that maps each @@ -193,8 +209,9 @@ def handle_template_groups(self, unit_ids): # are part of its group. so that the array is not ragged, # we pad rows with -1s when their group is smaller than the # largest group. - group_index = np.full((self.n_templates, max_group_size), -1) - for j, row in enumerate(group_index): + group_index = np.full((self.n_templates, counts.max()), -1) + for j, u in enumerate(unit_ids): + row = np.flatnonzero(unit_ids == u) group_index[j, : len(row)] = row self.register_buffer("group_index", torch.from_numpy(group_index)) @@ -218,14 +235,9 @@ def build_template_data( temporal_components = low_rank_templates.temporal_components.astype(dtype) singular_values = low_rank_templates.singular_values.astype(dtype) spatial_components = low_rank_templates.spatial_components.astype(dtype) - print(f"{template_data.templates.dtype=}") - print(f"{temporal_components.dtype=}") - print(f"{singular_values.dtype=}") - print(f"{spatial_components.dtype=}") self.register_buffer("temporal_components", torch.tensor(temporal_components)) self.register_buffer("singular_values", torch.tensor(singular_values)) self.register_buffer("spatial_components", torch.tensor(spatial_components)) - compressed_upsampled_temporal = self.handle_upsampling( temporal_components, ptps=template_data.templates.ptp(1).max(1), @@ -233,6 +245,45 @@ def build_template_data( upsampling_peak_window_radius=upsampling_peak_window_radius, ) + # handle the case where objective is not superres + if self.coarse_objective: + coarse_template_data = template_data.coarsen() + coarse_low_rank_templates = template_util.svd_compress_templates( + coarse_template_data.templates, + min_channel_amplitude=min_channel_amplitude, + rank=svd_compression_rank, + ) + temporal_components = coarse_low_rank_templates.temporal_components.astype( + dtype + ) + singular_values = coarse_low_rank_templates.singular_values.astype(dtype) + spatial_components = coarse_low_rank_templates.spatial_components.astype( + dtype + ) + self.objective_template_depths_um = ( + coarse_template_data.registered_template_depths_um + ) + self.register_buffer( + "objective_temporal_components", torch.tensor(temporal_components) + ) + self.register_buffer( + "objective_singular_values", torch.tensor(singular_values) + ) + self.register_buffer( + "objective_spatial_components", torch.tensor(spatial_components) + ) + else: + coarse_template_data = template_data + coarse_low_rank_templates = low_rank_templates + self.objective_template_depths_um = self.registered_template_depths_um + self.register_buffer( + "objective_temporal_components", 
self.temporal_components + ) + self.register_buffer("objective_singular_values", self.singular_values) + self.register_buffer( + "objective_spatial_components", self.spatial_components + ) + half_chunk = self.chunk_length_samples // 2 chunk_centers_samples = np.arange( half_chunk, self.recording.get_num_samples(), self.chunk_length_samples @@ -240,12 +291,12 @@ def build_template_data( chunk_centers_s = self.recording._recording_segments[0].sample_index_to_time( chunk_centers_samples ) - print(f"build_template_data {device=}") - print(f"{chunk_centers_s.shape=} {chunk_centers_s[:10]=}") self.pairwise_conv_db = CompressedPairwiseConv.from_template_data( save_folder / "pconv.h5", - template_data=template_data, - low_rank_templates=low_rank_templates, + template_data=coarse_template_data, + low_rank_templates=coarse_low_rank_templates, + template_data_b=template_data, + low_rank_templates_b=low_rank_templates, compressed_upsampled_temporal=compressed_upsampled_temporal, chunk_time_centers_s=chunk_centers_s, motion_est=self.motion_est, @@ -260,14 +311,6 @@ def build_template_data( ("temporal_components", temporal_components), ("singular_values", singular_values), ("spatial_components", spatial_components), - ( - "compressed_upsampling_map", - compressed_upsampled_temporal.compressed_upsampling_map, - ), - ( - "compressed_upsampled_temporal", - compressed_upsampled_temporal.compressed_upsampled_templates, - ), ] def handle_upsampling( @@ -282,42 +325,21 @@ def handle_upsampling( ptps=ptps, max_upsample=temporal_upsampling_factor, ) - print(f"{compressed_upsampled_temporal.compressed_upsampled_templates.dtype=}") self.register_buffer( "compressed_upsampling_map", torch.tensor(compressed_upsampled_temporal.compressed_upsampling_map), ) self.register_buffer( - "compressed_upsampled_temporal", - torch.tensor(compressed_upsampled_temporal.compressed_upsampled_templates), - ) - if temporal_upsampling_factor == 1: - return compressed_upsampled_temporal - - self.register_buffer( - "upsampling_window", - torch.arange( - -upsampling_peak_window_radius, upsampling_peak_window_radius + 1 - ), + "compressed_upsampling_index", + torch.tensor(compressed_upsampled_temporal.compressed_upsampling_index), ) - self.upsampling_window_len = 2 * upsampling_peak_window_radius - center = upsampling_peak_window_radius * temporal_upsampling_factor - radius = temporal_upsampling_factor // 2 + temporal_upsampling_factor % 2 self.register_buffer( - "upsampled_peak_search_window", - torch.arange(center - radius, center + radius + 1), - ) - self.register_buffer( - "peak_to_upsampling_index", - torch.concatenate( - [ - torch.arange(radius, -1, -1), - (temporal_upsampling_factor - 1) - torch.arange(radius), - ] - ), + "compressed_index_to_upsampling_index", + torch.tensor(compressed_upsampled_temporal.compressed_index_to_upsampling_index), ) self.register_buffer( - "peak_to_time_shift", torch.tensor([0] * (radius + 1) + [1] * radius) + "compressed_upsampled_temporal", + torch.tensor(compressed_upsampled_temporal.compressed_upsampled_templates), ) return compressed_upsampled_temporal @@ -391,7 +413,7 @@ def peel_chunk( return match_results def templates_at_time(self, t_s): - """Extract the right spatial components for each unit.""" + """Handle drift -- grab the right spatial neighborhoods.""" pconvdb = self.pairwise_conv_db if self.is_drifting: pitch_shifts, cur_spatial = template_util.templates_at_time( @@ -414,6 +436,20 @@ def templates_at_time(self, t_s): match_distance=self.match_distance, fill_value=0.0, ) + if 
self.coarse_objective: + cur_obj_spatial = template_util.templates_at_time( + t_s, + self.objective_spatial_components, + self.geom, + registered_template_depths_um=self.objective_template_depths_um, + registered_geom=self.registered_geom, + motion_est=self.motion_est, + return_pitch_shifts=True, + geom_kdtree=self.geom_kdtree, + match_distance=self.match_distance, + ) + else: + cur_obj_spatial = cur_spatial max_channels = cur_ampvecs[:, 0, :].argmax(1) pconvdb = pconvdb.at_shifts(pitch_shifts) else: @@ -423,14 +459,21 @@ def templates_at_time(self, t_s): if not pconvdb._is_torch: pconvdb = pconvdb.to(cur_spatial.device) - return CompressedTemplateData( - cur_spatial, - self.singular_values, - self.temporal_components, - self.compressed_upsampling_map, - self.compressed_upsampled_temporal, - torch.tensor(max_channels, device=cur_spatial.device), - pconvdb, + return MatchingTemplateData( + objective_spatial_components=cur_obj_spatial, + objective_singular_values=self.objective_singular_values, + objective_temporal_components=self.objective_temporal_components, + unit_ids=self.unit_ids, + coarse_objective=self.coarse_objective, + spatial_components=cur_spatial, + singular_values=self.singular_values, + temporal_components=self.temporal_components, + compressed_upsampling_map=self.compressed_upsampling_map, + compressed_upsampling_index=self.compressed_upsampling_index, + compressed_index_to_upsampling_index=self.compressed_index_to_upsampling_index, + compressed_upsampled_temporal=self.compressed_upsampled_temporal, + max_channels=max_channels, + pairwise_conv_db=pconvdb, ) def match_chunk( @@ -489,7 +532,7 @@ def match_chunk( # find high-res peaks print("before find") new_peaks = self.find_peaks( - padded_conv, padded_objective, refrac_mask, neg_temp_normsq + padded_conv, padded_objective, refrac_mask, compressed_template_data ) if new_peaks is None: break @@ -519,14 +562,13 @@ def match_chunk( new_peaks.scalings, conv_pad_len=self.obj_pad_len, ) - if return_residual: - compressed_template_data.subtract( - residual_padded, - new_peaks.times, - new_peaks.template_indices, - new_peaks.upsampling_indices, - new_peaks.scalings, - ) + compressed_template_data.subtract( + residual_padded, + new_peaks.times, + new_peaks.template_indices, + new_peaks.upsampling_indices, + new_peaks.scalings, + ) # new_norm = torch.linalg.norm(residual) ** 2 # print(f"{it=} {new_norm=}") @@ -539,8 +581,11 @@ def match_chunk( peaks.sort() # extract collision-cleaned waveforms on small neighborhoods - channels, waveforms = self.get_collisioncleaned_waveforms( - residual_padded, peaks, compressed_template_data + channels, waveforms = compressed_template_data.get_collisioncleaned_waveforms( + residual_padded, + peaks, + self.channel_index, + spike_length_samples=self.spike_length_samples, ) res = dict( @@ -558,28 +603,38 @@ def match_chunk( res["residual"] = residual return res - def find_peaks(self, padded_conv, padded_objective, refrac_mask, neg_temp_normsq): + def find_peaks(self, residual, padded_conv, padded_objective, refrac_mask, compressed_template_data): # first step: coarse peaks. not temporally upsampled or amplitude-scaled. objective = (padded_objective + refrac_mask)[ :-1, self.obj_pad_len : -self.obj_pad_len ] # formerly used detect_and_deduplicate, but that was slow. 
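# The argrelmax call just below replaces the old detect_and_deduplicate path:
# roughly, keep samples of the per-template-maximized objective that dominate their
# temporal neighborhood and clear the threshold. A standalone torch sketch of that
# idea (not the exact dartsort helper; shapes and values are made up):
import torch
import torch.nn.functional as F

def toy_argrelmax(x, radius, threshold):
    # x: 1D objective trace; keep samples >= every neighbor within `radius`
    pooled = F.max_pool1d(
        x[None, None], kernel_size=2 * radius + 1, stride=1, padding=radius
    )[0, 0]
    return torch.nonzero((x >= pooled) & (x > threshold)).squeeze(1)

times = toy_argrelmax(torch.randn(30_000), radius=121, threshold=3.0)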
- objective_max, max_template = objective.max(dim=0) + objective_max, max_obj_template = objective.max(dim=0) times = argrelmax(objective_max, self.spike_length_samples, self.threshold) # tt = times.numpy(force=True) # print(f"{np.diff(tt).min()=} {tt=}") - template_indices = max_template[times] + obj_template_indices = max_obj_template[times] # remove peaks inside the padding if not times.numel(): return None + residual_snips = None + if self.coarse_objective or self.temporal_upsampling_factor > 1: + residual_snips = spiketorch.grab_spikes_full( + times - 1, + trough_offset=0, + spike_length_samples=self.spike_length_samples + 1, + ) + # second step: high-res peaks (upsampled and/or amp-scaled) - time_shifts, upsampling_indices, scalings, scores = self.find_fancy_peaks( - padded_conv, - padded_objective, - times + self.obj_pad_len, - template_indices, - neg_temp_normsq, + time_shifts, upsampling_indices, scalings, template_indices, scores = compressed_template_data.fine_match( + padded_conv[obj_template_indices, times], + objective_max[times], + residual_snips, + obj_template_indices, + amp_scale_variance=self.amplitude_scaling_variance, + amp_scale_min=self.amp_scale_min, + amp_scale_max=self.amp_scale_max, ) if time_shifts is not None: times += time_shifts @@ -598,103 +653,37 @@ def enforce_refractory(self, objective, times, template_indices): return # overwrite objective with -inf to enforce refractoriness time_ix = times[:, None] + self._refrac_ix[None, :] - if self.grouped_temps: + if not self.grouped_temps: + unit_ix = template_indices[:, None, None] + elif self.coarse_objective: + unit_ix = self.unit_ids[template_indices][:, None, None] + elif self.grouped_temps: unit_ix = self.group_index[template_indices] else: - unit_ix = template_indices[:, None, None] + assert False objective[unit_ix[:, :, None], time_ix[:, None, :]] = -torch.inf - def find_fancy_peaks( - self, conv, objective, times, template_indices, neg_temp_normsq - ): - """Given coarse peaks, find temporally upsampled and scaled ones.""" - # tricky bit. we search for upsampled peaks to the left and right - # of the original peak. when the up-peak comes to the right, we - # use one of the upsampled templates, no problem. when the peak - # comes to the left, it's different: it came from one of the upsampled - # templates shifted one sample (spike time += 1). - if self.temporal_upsampling_factor == 1 and not self.is_scaling: - return None, None, None, objective[template_indices, times] - - if self.is_scaling and self.temporal_upsampling_factor == 1: - inv_lambda = 1 / self.amplitude_scaling_variance - b = conv[times, template_indices] + inv_lambda - a = neg_temp_normsq[template_indices] + inv_lambda - scalings = torch.clip(b / a, self.amp_scale_min, self.amp_scale_max) - scores = 2.0 * scalings * b - torch.square(scalings) * a - inv_lambda - return None, None, scalings, scores - - # below, we are upsampling. - # get clips of objective function around the peaks - # we'll use the scaled objective here. 
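# For reference, the closed form behind the amplitude-scaling math used above and
# below, under the usual convention where the squared template norm N enters with a
# positive sign: f(s) = 2*s*c - s**2*N - (s - 1)**2/lam = 2*s*b - s**2*a - 1/lam,
# with b = c + 1/lam and a = N + 1/lam, so the unclipped maximizer is s* = b/a.
# Quick numeric check with made-up values (the clip range is also made up, not read
# from the config):
import torch

c, N, lam = torch.tensor(120.0), torch.tensor(100.0), torch.tensor(0.1)
b, a = c + 1 / lam, N + 1 / lam
s_star = torch.clamp(b / a, 0.5, 2.0)
closed_form = 2 * s_star * b - s_star**2 * a - 1 / lam
grid = torch.linspace(0.5, 2.0, 10_001)
brute_force = (2 * grid * c - grid**2 * N - (grid - 1) ** 2 / lam).max()
assert torch.isclose(closed_form, brute_force, atol=1e-3)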
- time_ix = times[:, None] + self.upsampling_window[None, :] - clip_ix = (template_indices[:, None], time_ix) - upsampled_clip_len = ( - self.upsampling_window_len * self.temporal_upsampling_factor - ) - if self.is_scaling: - high_res_conv = spiketorch.real_resample( - conv[clip_ix], upsampled_clip_len, dim=1 - ) - inv_lambda = 1.0 / self.amplitude_scaling_variance - b = high_res_conv + inv_lambda - a = neg_temp_normsq[template_indices] + inv_lambda - scalings = torch.clip(b / a, self.amp_scale_min, self.amp_scale_max) - high_res_obj = ( - 2.0 * scalings * b - torch.square(scalings) * a[:, None] - inv_lambda - ) - else: - scalings = None - obj_clips = objective[clip_ix] - high_res_obj = spiketorch.real_resample( - obj_clips, upsampled_clip_len, dim=1 - ) - - # zoom into a small upsampled area and determine the - # upsampled template and time shifts - scores, zoom_peak = torch.max( - high_res_obj[:, self.upsampled_peak_search_window], dim=1 - ) - upsampling_indices = self.peak_to_upsampling_index[zoom_peak] - time_shifts = self.peak_to_time_shift[zoom_peak] - - return time_shifts, upsampling_indices, scalings, scores - - def get_collisioncleaned_waveforms( - self, residual_padded, peaks, compressed_template_data - ): - channels = compressed_template_data.max_channels[peaks.template_indices] - waveforms = spiketorch.grab_spikes( - residual_padded, - peaks.times, - channels, - self.channel_index, - trough_offset=0, - spike_length_samples=self.spike_length_samples, - buffer=0, - already_padded=True, - ) - padded_spatial = F.pad(compressed_template_data.spatial_singular, (0, 1)) - spatial = padded_spatial[ - peaks.template_indices[:, None, None], - self._rank_ix[None, :, None], - self.channel_index[channels][:, None, :], - ] - temporal = compressed_template_data.upsampled_temporal_components[ - peaks.template_indices, :, peaks.upsampling_indices - ] - torch.baddbmm(waveforms, temporal, spatial, out=waveforms) - return channels, waveforms - @dataclass -class CompressedTemplateData: - """Objects of this class are returned by ObjectiveUpdateTemplateMatchingPeeler.templates_at_time()""" - +class MatchingTemplateData: + """All the data and math needed for computing convs etc in a single static chunk of data + + This is the 'model' for template matching in a MVC analogy. The class above is the controller. + Objects of this class are returned by ObjectiveUpdateTemplateMatchingPeeler.templates_at_time(), + which handles the drift logic and lets this class be simple. 
+ """ + + objective_spatial_components: torch.Tensor + objective_singular_values: torch.Tensor + objective_temporal_components: torch.Tensor + unit_ids: torch.LongTensor + coarse_objective: bool spatial_components: torch.Tensor singular_values: torch.Tensor temporal_components: torch.Tensor compressed_upsampling_map: torch.LongTensor + compressed_upsampling_index: torch.LongTensor + compressed_index_to_upsampling_index: torch.LongTensor compressed_upsampled_temporal: torch.Tensor max_channels: torch.LongTensor pairwise_conv_db: CompressedPairwiseConv @@ -711,48 +700,55 @@ def __post_init__(self): self.rank, ) assert self.singular_values.shape == (self.n_templates, self.rank) + device = self.spatial_components.device # squared l2 norms are usually the sums of squared singular values: # self.template_norms_squared = torch.square(self.singular_values).sum(1) # in this case, we have subset the spatial components, so use a diff formula + self.objective_spatial_singular = ( + self.objective_spatial_components + * self.objective_singular_values[:, :, None] + ) self.spatial_singular = ( self.spatial_components * self.singular_values[:, :, None] ) + self.objective_template_norms_squared = torch.square( + self.objective_spatial_singular + ).sum((1, 2)) self.template_norms_squared = torch.square(self.spatial_singular).sum((1, 2)) - self.chan_ix = torch.arange( - self.spatial_components.shape[2], device=self.spatial_components.device - ) - self.time_ix = torch.arange( - self.spike_length_samples, device=self.spatial_components.device - ) + self.chan_ix = torch.arange(self.spatial_components.shape[2], device=device) + self.rank_ix = torch.arange(self.rank, device=device) + self.time_ix = torch.arange(self.spike_length_samples, device=device) self.conv_lags = torch.arange( - -self.spike_length_samples + 1, - self.spike_length_samples, - device=self.spatial_components.device, + -self.spike_length_samples + 1, self.spike_length_samples, device=device ) def convolve(self, traces, padding=0, out=None): - """Convolve all templates with traces.""" + """Convolve the objective templates with traces.""" out_len = traces.shape[0] + 2 * padding - self.spike_length_samples + 1 if out is None: out = torch.zeros( - 1, self.n_templates, out_len, dtype=traces.dtype, device=traces.device + (self.n_templates, out_len), dtype=traces.dtype, device=traces.device ) else: assert out.shape == (self.n_templates, out_len) - out = out[None] for q in range(self.rank): # units x time - rec_spatial = self.spatial_singular[:, q, :] @ traces.T + rec_spatial = self.objective_spatial_singular[:, q, :] @ traces.T # convolve with temporal components -- units x time - temporal = self.temporal_components[:, :, q] + temporal = self.objective_temporal_components[:, :, q] # conv1d with groups! only convolve each unit with its own temporal filter. 
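# The grouped conv1d being swapped out just below computes, for each unit, the
# cross-correlation of that unit's spatially projected trace with its own temporal
# filter; whatever replaces it (here the overlap-add version) has to reproduce that.
# A toy equivalence check of the grouped form, with made-up shapes:
import torch
import torch.nn.functional as F

n_units, T, t = 3, 1000, 121
rec_spatial = torch.randn(n_units, T)
temporal = torch.randn(n_units, t)
grouped = F.conv1d(rec_spatial[None], temporal[:, None, :], groups=n_units)[0]
looped = torch.stack(
    [F.conv1d(rec_spatial[u][None, None], temporal[u][None, None])[0, 0]
     for u in range(n_units)]
)
assert torch.allclose(grouped, looped, atol=1e-4)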
- conv = F.conv1d( - rec_spatial[None], - temporal[:, None, :], - groups=self.n_templates, - padding=padding, + # conv = F.conv1d( + # rec_spatial[None], + # temporal[:, None, :], + # groups=self.n_templates, + # padding=padding, + # )[0] + conv = spiketorch.depthwise_oaconv1d( + rec_spatial, + temporal[:, :], + padding=padding ) if q: out += conv @@ -760,7 +756,7 @@ def convolve(self, traces, padding=0, out=None): out.copy_(conv) # back to units x time (remove extra dim used for conv1d) - return out[0] + return out def subtract_conv( self, @@ -788,6 +784,132 @@ def subtract_conv( sign=-1, ) + def fine_match( + self, + convs, + objs, + residual_snips, + objective_template_indices, + amp_scale_variance=0.0, + amp_scale_min=None, + amp_scale_max=None, + ): + """Determine superres ids, temporal upsampling, and scaling + + Given coarse matches (unit ids at times) and the current residual, + pick the best superres template, the best temporal offset, and the + best amplitude scaling. + + We used to upsample the objective to figure out the temporal upsampling, + but with superres in the picture we are now not computing the objective + using the same templates that we temporally upsample. So, instead + we use a greedy strategy: first pick the best (non-temporally upsampled) + superres template, then pick the upsampling and scaling at the same time. + These are all done by dotting everything and computing the objective, + which is probably more expensive than what we had before. + + Returns + ------- + time_shifts : Optional[array] + upsampling_indices : Optional[array] + scalings : Optional[array] + template_indices : array + objs : array + """ + if ( + not self.coarse_objective + and self.temporal_upsampling_factor == 1 + and not amp_scale_variance + ): + return None, None, None, objective_template_indices, objs + + if self.coarse_objective or self.temporal_upsampling_factor > 1: + # snips is a window padded by one sample, so that we have the + # traces snippets at the current times and one step back + n_spikes, window_length_samples, n_chans = residual_snips.shape + spike_length_samples = window_length_samples - 1 + # grab the current traces + snips = residual_snips[:, 1:] + # unpack the current traces and the traces one step back + snips_dt = F.unfold( + residual_snips[:, None, :, :], (spike_length_samples, snips.shape[2]) + ) + snips_dt = snips_dt.reshape( + len(snips), spike_length_samples, snips.shape[2], 2 + ) + + if self.coarse_objective: + # TODO best I came up with, but it still syncs + superres_ix = self.superres_index[objective_template_indices] + dup_ix, column_ix = (superres_ix < self.n_templates).nonzero(as_tuple=True) + template_indices = superres_ix[dup_ix, column_ix] + convs = torch.einsum( + "jtc,jrc,jtr->j", + snips[dup_ix], + self.spatial_singular[template_indices], + self.temporal_components[template_indices], + ) + neg_norms = -self.template_norms_squared[template_indices] + objs = torch.full( + superres_ix.shape, -torch.inf, device=convs.device + ) + objs[dup_ix, column_ix] = 2 * convs + neg_norms + objs, best_column_ix = objs.max(dim=1) + row_ix = torch.arange(best_column_ix.numel(), device=best_column_ix.device) + template_indices = superres_ix[row_ix, best_column_ix] + else: + template_indices = objective_template_indices + neg_norms = -self.template_norms_squared[template_indices] + objs = objs + + if self.temporal_upsampling_factor == 1 and not amp_scale_variance: + return None, None, None, template_indices, objs + + if self.temporal_upsampling_factor == 1: + # just 
scaling + inv_lambda = 1 / amp_scale_variance + b = convs + inv_lambda + a = neg_norms + inv_lambda + scalings = torch.clip(b / a, amp_scale_min, amp_scale_max) + objs = 2 * scalings * b - torch.square(scalings) * a - inv_lambda + return None, None, scalings, template_indices, objs + + # now, upsampling + # repeat the superres logic, the comp up index acts the same + comp_up_ix = self.compressed_upsampling_index[template_indices] + dup_ix, column_ix = ( + comp_up_ix < self.n_compressed_upsampled_templates + ).nonzero(as_tuple=True) + comp_up_indices = comp_up_ix[dup_ix, column_ix] + convs = torch.einsum( + "jtcd,jrc,jtr->jd", + snips_dt[dup_ix], + self.spatial_singular[template_indices[dup_ix]], + self.compressed_upsampled_temporal[comp_up_indices], + ) + neg_norms = neg_norms[dup_ix] + objs = torch.full((*comp_up_ix.shape, 2), -torch.inf, device=convs.device) + if amp_scale_variance: + inv_lambda = 1 / amp_scale_variance + b = convs + inv_lambda + a = neg_norms + inv_lambda + scalings = torch.clip(b / a, amp_scale_min, amp_scale_max) + objs[dup_ix, column_ix] = 2 * scalings * b - torch.square(scalings) * a - inv_lambda + else: + objs[dup_ix, column_ix] = 2 * convs - neg_norms + scalings = None + objs, best_column_dt_ix = objs.reshape(len(convs), -1).max(dim=1) + + best_column_ix = best_column_dt_ix // 2 + row_ix = torch.arange(best_column_ix.numel(), device=best_column_ix.device) + comp_up_indices = comp_up_ix[row_ix, best_column_ix] + upsampling_indices = self.compressed_index_to_upsampling_index[comp_up_indices] + + # even positions have were one step earlier + time_shifts = best_column_dt_ix % 2 - 1 + + return time_shifts, upsampling_indices, scalings, template_indices, objs + def subtract( self, traces, @@ -795,6 +917,7 @@ def subtract( template_indices, upsampling_indices, scalings, + batch_templates=..., ): """Subtract templates from traces.""" compressed_up_inds = self.compressed_upsampling_map[ @@ -811,6 +934,32 @@ def subtract( traces, (time_ix, self.chan_ix[None, None, :]), batch_templates, sign=-1 ) + def get_collisioncleaned_waveforms( + self, residual_padded, peaks, channel_index, spike_length_samples=121 + ): + channels = self.max_channels[peaks.template_indices] + waveforms = spiketorch.grab_spikes( + residual_padded, + peaks.times, + channels, + channel_index, + trough_offset=0, + spike_length_samples=spike_length_samples, + buffer=0, + already_padded=True, + ) + padded_spatial = F.pad(self.spatial_singular, (0, 1)) + spatial = padded_spatial[ + peaks.template_indices[:, None, None], + self.rank_ix[None, :, None], + channel_index[channels][:, None, :], + ] + temporal = self.upsampled_temporal_components[ + peaks.template_indices, :, peaks.upsampling_indices + ] + torch.baddbmm(waveforms, temporal, spatial, out=waveforms) + return channels, waveforms + class MatchingPeaks: BUFFER_INIT: int = 1500 diff --git a/src/dartsort/templates/pairwise.py b/src/dartsort/templates/pairwise.py index 4fbb17ee..73848c9f 100644 --- a/src/dartsort/templates/pairwise.py +++ b/src/dartsort/templates/pairwise.py @@ -72,6 +72,8 @@ def from_template_data( template_data: TemplateData, low_rank_templates: LowRankTemplates, compressed_upsampled_temporal: CompressedUpsampledTemplates, + template_data_b: Optional[TemplateData] = None, + low_rank_templates_b: Optional[TemplateData] = None, chunk_time_centers_s: Optional[np.ndarray] = None, motion_est=None, geom: Optional[np.ndarray] = None, @@ -90,6 +92,8 @@ def from_template_data( template_data=template_data, low_rank_templates=low_rank_templates, 
compressed_upsampled_temporal=compressed_upsampled_temporal, + template_data_b=template_data_b, + low_rank_templates_b=low_rank_templates_b, chunk_time_centers_s=chunk_time_centers_s, motion_est=motion_est, geom=geom, diff --git a/src/dartsort/templates/template_util.py b/src/dartsort/templates/template_util.py index 30255af1..197b60d5 100644 --- a/src/dartsort/templates/template_util.py +++ b/src/dartsort/templates/template_util.py @@ -261,6 +261,7 @@ def temporally_upsample_templates( CompressedUpsampledTemplates = namedtuple( "CompressedUpsampledTemplates", [ + "n_compressed_upsampled_templates", "compressed_upsampled_templates", "compressed_upsampling_map", "compressed_index_to_template_index", @@ -291,14 +292,19 @@ def compressed_upsampled_templates( of the jth upsampled template for this unit. for low-amplitude units, compressed_upsampling_map[unit] will have fewer unique entries, corresponding to fewer saved upsampled copies for that unit. + compressed_upsampling_index : array (n_templates, max_upsample) + A n_compressed_upsampled_templates-padded ragged array mapping each + template index to its compressed upsampled indices compressed_index_to_template_index compressed_index_to_upsampling_index """ n_templates = templates.shape[0] if max_upsample == 1: return CompressedUpsampledTemplates( + n_templates, templates, np.arange(n_templates)[:, None], + np.arange(n_templates)[:, None], np.arange(n_templates), np.zeros(n_templates, dtype=int) ) @@ -316,6 +322,7 @@ def compressed_upsampled_templates( # build the compressed upsampling map compressed_upsampling_map = np.zeros((n_templates, max_upsample), dtype=int) + compressed_upsampling_index = np.full((n_templates, max_upsample), -1, dtype=int) template_indices = [] upsampling_indices = [] current_compressed_index = 0 @@ -327,6 +334,7 @@ def compressed_upsampled_templates( compressed_upsampling_map[i] = current_compressed_index + np.arange(nup).repeat( compression ) + compressed_upsampling_index[i, :nup] = current_compressed_index + np.arange(nup) current_compressed_index += nup # indices of the templates to keep in the full array of upsampled templates @@ -334,6 +342,7 @@ def compressed_upsampled_templates( upsampling_indices.extend(compression * np.arange(nup)) template_indices = np.array(template_indices) upsampling_indices = np.array(upsampling_indices) + compressed_upsampling_index[compressed_upsampling_index < 0] = current_compressed_index # get the upsampled templates all_upsampled_templates = temporally_upsample_templates( @@ -348,6 +357,7 @@ def compressed_upsampled_templates( compressed_upsampled_templates = all_upsampled_templates[rix] return CompressedUpsampledTemplates( + current_compressed_index, compressed_upsampled_templates, compressed_upsampling_map, template_indices, diff --git a/src/dartsort/util/multiprocessing_util.py b/src/dartsort/util/multiprocessing_util.py index dcbfb429..efc8a691 100644 --- a/src/dartsort/util/multiprocessing_util.py +++ b/src/dartsort/util/multiprocessing_util.py @@ -2,6 +2,8 @@ from concurrent.futures import ProcessPoolExecutor from multiprocessing import get_context +# TODO: torch.multiprocessing? 
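# Back on the template_util change above: `compressed_upsampling_index` is a padded
# "ragged" table. Row i holds the compressed-upsampled ids belonging to template i,
# padded to a fixed width with the out-of-range id n_compressed_upsampled_templates,
# so callers can either mask the real entries or append a sentinel slot to whatever
# they index. Toy illustration with made-up numbers:
import numpy as np

n_comp = 5
index = np.array([[0, 1, 2, 3],                   # template 0 kept 4 upsampled copies
                  [4, n_comp, n_comp, n_comp]])   # template 1 kept only 1
valid = index < n_comp                            # mask of real entries
norms = np.arange(n_comp, dtype=float)
padded = np.concatenate([norms, [np.nan]])        # sentinel slot for the pad id
print(padded[index], valid)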
+ try: import cloudpickle except ImportError: diff --git a/src/dartsort/util/spiketorch.py b/src/dartsort/util/spiketorch.py index 439413fe..7824cd86 100644 --- a/src/dartsort/util/spiketorch.py +++ b/src/dartsort/util/spiketorch.py @@ -1,5 +1,8 @@ +import math + import torch import torch.nn.functional as F +from scipy.signal._signaltools import _calc_oa_lens from torch.fft import irfft, rfft @@ -99,6 +102,25 @@ def grab_spikes( return traces[time_ix[:, :, None], chan_ix[:, None, :]] +def grab_spikes_full( + traces, + trough_times, + trough_offset=42, + spike_length_samples=121, + buffer=0, +): + """Grab spikes from a tensor of traces""" + assert trough_times.ndim == 1 + spike_sample_offsets = torch.arange( + buffer - trough_offset, + buffer - trough_offset + spike_length_samples, + device=trough_times.device, + ) + time_ix = trough_times[:, None] + spike_sample_offsets[None, :] + chan_ix = torch.arange(traces.shape[1], device=traces.device) + return traces[time_ix[:, :, None], chan_ix[None, None, :]] + + def add_spikes_( traces, trough_times, @@ -226,26 +248,52 @@ def real_resample(x, num, dim=0): # inverse transform y = irfft(g, num, dim=dim) - y *= (float(num) / float(Nx)) + y *= float(num) / float(Nx) return y +def steps_and_pad(s1, in1_step, s2, in2_step, block_size, overlap): + shape_final = s1 + s2 - 1 + # figure out n steps and padding + if s1 > in1_step: + nstep1 = math.ceil((s1 + 1) / in1_step) + if (block_size - overlap) * nstep1 < shape_final: + nstep1 += 1 + + pad1 = nstep1 * in1_step - s1 + else: + nstep1 = 1 + pad1 = 0 + + if s2 > in2_step: + nstep2 = math.ceil((s2 + 1) / in2_step) + if (block_size - overlap) * nstep2 < shape_final: + nstep2 += 1 + + pad2 = nstep2 * in2_step - s2 + else: + nstep2 = 1 + pad2 = 0 + return nstep1, pad1, nstep2, pad2 + + def depthwise_oaconv1d(input, weight, f2=None, padding=0): - """Depthwise correlation (F.conv1d with groups=in_chans) with overlap-add - """ + """Depthwise correlation (F.conv1d with groups=in_chans) with overlap-add""" # conv on last axis # assert input.ndim == weight.ndim == 2 n1 = input.shape[0] n2 = weight.shape[0] - # assert n1 == n2 + assert n1 == n2 s1 = input.shape[1] s2 = weight.shape[1] - # assert s1 >= s2 + assert s1 >= s2 shape_final = s1 + s2 - 1 block_size, overlap, in1_step, in2_step = _calc_oa_lens(s1, s2) - nstep1, pad1, nstep2, pad2 = steps_and_pad(s1, in1_step, s2, in2_step, block_size, overlap) + nstep1, pad1, nstep2, pad2 = steps_and_pad( + s1, in1_step, s2, in2_step, block_size, overlap + ) if pad1 > 0: input = F.pad(input, (0, pad1)) @@ -272,6 +320,10 @@ def depthwise_oaconv1d(input, weight, f2=None, padding=0): oa = fold_res.reshape(n1, fold_out_len) # this is the full convolution - oa = oa[:, :shape_final - pad1] + oa = oa[:, : shape_final - pad1] + # extract correct padding + padding = padding + s2 - 1 + assert oa.shape[1] > 2 * padding + oa = oa[:, padding:oa.shape[1] - padding] return oa From a665f63b5adb9cb1fff5a0e56af139f0a8789aa7 Mon Sep 17 00:00:00 2001 From: Charlie Windolf Date: Tue, 21 Nov 2023 20:33:30 -0500 Subject: [PATCH 34/49] Fix tests --- src/dartsort/templates/pairwise_util.py | 5 +++++ src/dartsort/templates/template_util.py | 2 ++ 2 files changed, 7 insertions(+) diff --git a/src/dartsort/templates/pairwise_util.py b/src/dartsort/templates/pairwise_util.py index d9d467e1..95513076 100644 --- a/src/dartsort/templates/pairwise_util.py +++ b/src/dartsort/templates/pairwise_util.py @@ -1031,6 +1031,11 @@ def coarse_approximate( active_temp_a = temp_ix_a[inshift] unique_active_temp_a = 
np.unique(active_temp_a) + + # TODO just upsampling dedup + # active_temp_b = temp_ix_b[inshift] + # unique_active_temp_b = np.unique(active_temp_b) + # if unique_active_temp_a.size == unique_active_temp_b.size == 1: if unique_active_temp_a.size == 1: new_pconv.append(convs) old_ix_to_new_ix[inshift] = np.arange( diff --git a/src/dartsort/templates/template_util.py b/src/dartsort/templates/template_util.py index 197b60d5..2a98e4b3 100644 --- a/src/dartsort/templates/template_util.py +++ b/src/dartsort/templates/template_util.py @@ -264,6 +264,7 @@ def temporally_upsample_templates( "n_compressed_upsampled_templates", "compressed_upsampled_templates", "compressed_upsampling_map", + "compressed_upsampling_index", "compressed_index_to_template_index", "compressed_index_to_upsampling_index", ], @@ -360,6 +361,7 @@ def compressed_upsampled_templates( current_compressed_index, compressed_upsampled_templates, compressed_upsampling_map, + compressed_upsampling_index, template_indices, upsampling_indices, ) From 1aa8b40a526b67fbda87efd5481276c81fc60466 Mon Sep 17 00:00:00 2001 From: Charlie Windolf Date: Tue, 21 Nov 2023 20:36:29 -0500 Subject: [PATCH 35/49] Format --- src/dartsort/util/drift_util.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/dartsort/util/drift_util.py b/src/dartsort/util/drift_util.py index 0e6cccb6..d30d61dd 100644 --- a/src/dartsort/util/drift_util.py +++ b/src/dartsort/util/drift_util.py @@ -559,10 +559,10 @@ def get_shift_and_unit_pairs( ) pitch_shifts = pitch_shifts.astype(int) pitch_shift_ix = np.searchsorted(all_pitch_shifts, pitch_shifts) - + shifted_temp_ixs = template_shift_index[temp_ixs, pitch_shift_ix] cooccurrence[shifted_temp_ixs[:, None], shifted_temp_ixs[None, :]] = 1 - + return TemplateShiftIndex( n_template_shift_pairs, all_pitch_shifts, From 9066d76e29fc27d803c6a32d00eb5d8ea39ff55a Mon Sep 17 00:00:00 2001 From: Charlie Windolf Date: Tue, 21 Nov 2023 20:37:13 -0500 Subject: [PATCH 36/49] Dataclasses --- src/dartsort/templates/template_util.py | 31 +++++++++++++------------ 1 file changed, 16 insertions(+), 15 deletions(-) diff --git a/src/dartsort/templates/template_util.py b/src/dartsort/templates/template_util.py index 30255af1..38d38d74 100644 --- a/src/dartsort/templates/template_util.py +++ b/src/dartsort/templates/template_util.py @@ -1,4 +1,4 @@ -from collections import namedtuple +from dataclasses import dataclass import numpy as np from dartsort.localize.localize_util import localize_waveforms @@ -192,9 +192,11 @@ def templates_at_time( # -- template numerical processing -LowRankTemplates = namedtuple( - "LowRankTemplates", ["temporal_components", "singular_values", "spatial_components"] -) +@dataclass +class LowRankTemplates: + temporal_components: np.ndarray + singular_values: np.ndarray + spatial_components: np.ndarray def svd_compress_templates( @@ -258,15 +260,12 @@ def temporally_upsample_templates( return upsampled_templates -CompressedUpsampledTemplates = namedtuple( - "CompressedUpsampledTemplates", - [ - "compressed_upsampled_templates", - "compressed_upsampling_map", - "compressed_index_to_template_index", - "compressed_index_to_upsampling_index", - ], -) +@dataclass +class CompressedUpsampledTemplates: + compressed_upsampled_templates: np.ndarray + compressed_upsampling_map: np.ndarray + compressed_index_to_template_index: np.ndarray + compressed_index_to_upsampling_index: np.ndarray def default_n_upsamples_map(ptps): @@ -300,7 +299,7 @@ def compressed_upsampled_templates( templates, np.arange(n_templates)[:, 
None], np.arange(n_templates), - np.zeros(n_templates, dtype=int) + np.zeros(n_templates, dtype=int), ) # how many copies should each unit get? @@ -341,7 +340,9 @@ def compressed_upsampled_templates( ) # n, up, t, c all_upsampled_templates = all_upsampled_templates.transpose(0, 2, 1, 3) - rix = np.ravel_multi_index((template_indices, upsampling_indices), all_upsampled_templates.shape[:2]) + rix = np.ravel_multi_index( + (template_indices, upsampling_indices), all_upsampled_templates.shape[:2] + ) all_upsampled_templates = all_upsampled_templates.reshape( n_templates * max_upsample, templates.shape[1], templates.shape[2] ) From 45b9a3c89dcab447214c776c53797936d96756f1 Mon Sep 17 00:00:00 2001 From: Charlie Windolf Date: Tue, 21 Nov 2023 20:38:14 -0500 Subject: [PATCH 37/49] Get running --- src/dartsort/templates/pairwise_util.py | 224 +++++++++++++++--------- 1 file changed, 138 insertions(+), 86 deletions(-) diff --git a/src/dartsort/templates/pairwise_util.py b/src/dartsort/templates/pairwise_util.py index 07c7ac7f..a15a701c 100644 --- a/src/dartsort/templates/pairwise_util.py +++ b/src/dartsort/templates/pairwise_util.py @@ -166,7 +166,9 @@ def iterate_compressed_pairwise_convolutions( process the results differently. """ # construct drift-related helper data if needed - print(f"iterate_compressed_pairwise_convolutions {conv_batch_size=} {units_batch_size=} {device=}") + print( + f"iterate_compressed_pairwise_convolutions {conv_batch_size=} {units_batch_size=} {device=}" + ) n_shifts = template_shift_index.all_pitch_shifts.size do_shifting = n_shifts > 1 geom_kdtree = reg_geom_kdtree = match_distance = None @@ -189,7 +191,7 @@ def iterate_compressed_pairwise_convolutions( # worker kwargs kwargs = dict( - template_data=template_data, + unit_ids=template_data.unit_ids, low_rank_templates=low_rank_templates, compressed_upsampled_temporal=compressed_upsampled_temporal, template_shift_index=template_shift_index, @@ -254,7 +256,7 @@ class CompressedConvResult: def compressed_convolve_pairs( - template_data: templates.TemplateData, + unit_ids: np.ndarray, low_rank_templates: template_util.LowRankTemplates, compressed_upsampled_temporal: template_util.CompressedUpsampledTemplates, template_shift_index: drift_util.TemplateShiftIndex, @@ -290,13 +292,13 @@ def compressed_convolve_pairs( # what pairs, shifts, etc are we convolving? 
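# The shift deduplication worked out a bit further below reduces to one np.unique
# call: tag each candidate pair with a determiner (template a, template b,
# spatial/shift id), compute one convolution per unique determiner, and keep the
# inverse map from pairs back to those convolutions. Standalone toy version of the
# idiom (values are made up):
import numpy as np

determiners = np.array([[0, 1, 7],
                        [0, 1, 7],   # duplicate row -> reuses the same convolution
                        [0, 2, 7],
                        [1, 1, 3]])
_, conv_ix, compression_index = np.unique(
    determiners, axis=0, return_index=True, return_inverse=True
)
# conv_ix: rows whose convolution actually gets computed
# compression_index: for each pair, which computed convolution it maps to
print(conv_ix, compression_index)   # e.g. [0 2 3] [0 0 1 2]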
shifted_temp_ix_a, temp_ix_a, shift_a, unit_a = handle_shift_indices( - units_a, template_data.unit_ids, template_shift_index + units_a, unit_ids, template_shift_index ) shifted_temp_ix_b, temp_ix_b, shift_b, unit_b = handle_shift_indices( - units_b, template_data.unit_ids, template_shift_index + units_b, unit_ids, template_shift_index ) - # print(f"{shifted_temp_ix_a.shape=}") - # print(f"{shifted_temp_ix_b.shape=}") + # print(f"0 {shifted_temp_ix_a.shape=} {(shifted_temp_ix_a.size / np.unique(unit_a).size)=}") + # print(f"0 {shifted_temp_ix_b.shape=} {(shifted_temp_ix_b.size / np.unique(unit_b).size)=}") # get (shifted) spatial components * singular values spatial_singular_a = get_shifted_spatial_singular( @@ -344,11 +346,12 @@ def compressed_convolve_pairs( ) if pairs_ret is None: return None - ix_a, ix_b, compression_index, conv_ix = pairs_ret + ix_a, ix_b, compression_index, conv_ix, spatial_shift_ids = pairs_ret # print(f"A {ix_a.shape=}") # print(f"A {ix_b.shape=}") # print(f"A {compression_index.shape=}") # print(f"A {conv_ix.shape=}") + # print(f"A {spatial_shift_ids.shape=}") # print(f"-----------") # print(f"after pairs {conv_ix.shape=} {compression_index.shape=}") @@ -358,14 +361,13 @@ def compressed_convolve_pairs( # handle upsampling # each pair will be duplicated by the b unit's number of upsampled copies ( - ix_a, ix_b, compression_index, conv_ix, conv_upsampling_indices_b, conv_temporal_components_up_b, + compression_dup_ix, ) = compressed_upsampled_pairs( - ix_a, ix_b, compression_index, conv_ix, @@ -374,6 +376,8 @@ def compressed_convolve_pairs( upsampled_shifted_template_index, compressed_upsampled_temporal, ) + ix_a = ix_a[compression_dup_ix] + spatial_shift_ids = spatial_shift_ids[compression_dup_ix] # print(f"B {ix_a.shape=}") # print(f"B {ix_b.shape=}") # print(f"B {compression_index.shape=}") @@ -409,23 +413,17 @@ def compressed_convolve_pairs( ) # print(f"-----------") # print(f"after corr {pconv.shape=} {conv_ix[kept].shape=}") - conv_ix = conv_ix[kept] - if not conv_ix.size: - return None - kept_pairs = np.flatnonzero(np.isin(compression_index, kept)) - # print(f"-----------") - # print(f"kept {pconv.shape=} {conv_ix.shape=} {compression_index.shape=}") - # print(f"{compression_index.min()=} {compression_index.max()=}") - # print(f"{compression_index[kept_pairs].min()=} {compression_index[kept_pairs].max()=}") - # print(f"{ix_a.shape=} {ix_b.shape=}") - # print(f"{kept.shape=} {kept.dtype=} {kept.min()=} {kept.max()=}") - # print(f"{kept_pairs.shape=} {kept_pairs.dtype=} {kept_pairs.min()=} {kept_pairs.max()=}") - compression_index = np.searchsorted(kept, compression_index[kept_pairs]) - conv_ix = np.searchsorted(kept_pairs, conv_ix) - ix_a = ix_a[kept_pairs] - ix_b = ix_b[kept_pairs] - # compression_index = compression_index[kept] - pconv = pconv.cpu() + if kept is not None: + conv_ix = conv_ix[kept] + if not conv_ix.shape[0]: + return None + kept_pairs = np.flatnonzero(np.isin(compression_index, kept)) + compression_index = np.searchsorted(kept, compression_index[kept_pairs]) + conv_ix = np.searchsorted(kept_pairs, conv_ix) + ix_a = ix_a[kept_pairs] + ix_b = ix_b[kept_pairs] + spatial_shift_ids = spatial_shift_ids[kept_pairs] + assert pconv.numel() > 0 # print(f"-----------") # print(f"after searchsorted {pconv.shape=} {conv_ix.shape=} {compression_index.shape=}") # print(f"{compression_index.min()=} {compression_index.max()=}") @@ -439,8 +437,9 @@ def compressed_convolve_pairs( unit_a[ix_a[conv_ix]], unit_b[ix_b[conv_ix]], temp_ix_a[ix_a[conv_ix]], - 
shift_a[ix_a[conv_ix]], - shift_b[ix_b[conv_ix]], + # shift_a[ix_a[conv_ix]], + # shift_b[ix_b[conv_ix]], + spatial_shift_ids[conv_ix], coarse_approx_error_threshold=coarse_approx_error_threshold, ) # print(f"-----------") @@ -465,7 +464,7 @@ def compressed_convolve_pairs( shift_indices_b=shift_ix_b, upsampling_indices_b=conv_upsampling_indices_b[compression_index], compression_index=compression_index, - compressed_conv=pconv.numpy(), + compressed_conv=pconv.numpy(force=True), ) @@ -511,7 +510,6 @@ def correlate_pairs_lowrank( assert n_pairs == n_pairs_ assert t == t_ assert rank == rank_ - # print(f"{spatial_a.device=} {spatial_b.device=} {temporal_a.device=} {temporal_b.device=}") # print(f"compressed_convolve_pairs {batch_size=} {n_pairs=} {spatial_a.device=}") if max_shift == "full": @@ -525,6 +523,7 @@ def correlate_pairs_lowrank( pconv = torch.zeros( (n_pairs, 2 * max_shift + 1), dtype=spatial_a.dtype, device=spatial_a.device ) + # print(f"compressed_convolve_pairs {pconv.shape=}") for istart in range(0, n_pairs, batch_size): iend = min(istart + batch_size, n_pairs) ix = slice(istart, iend) @@ -552,8 +551,10 @@ def correlate_pairs_lowrank( kept = max_val > conv_ignore_threshold pconv = pconv[kept] kept = np.flatnonzero(kept.numpy(force=True)) + # print(f"compressed_convolve_pairs {pconv.shape=} {kept.shape=}") else: - kept = np.arange(len(pconv)) + kept = None + # print(f"compressed_convolve_pairs {pconv.shape=} {kept=}") return pconv, kept @@ -678,7 +679,7 @@ def shift_deduplicated_pairs( do_shifting = template_shift_index.all_pitch_shifts.size > 1 if not do_shifting: nco_range = torch.arange(nco, device=pair_ix_a.device) - return pair_ix_a, pair_ix_b, nco_range, nco_range + return pair_ix_a, pair_ix_b, nco_range, nco_range, np.zeros(nco, dtype=int) # shift deduplication. algorithm: # 1 for each shifted template, determine the set of registered channels @@ -726,7 +727,6 @@ def shift_deduplicated_pairs( # get the relative shifts shift_a = shift_a[pair_ix_a] shift_b = shift_b[pair_ix_b] - shift_diff = shift_a - shift_b # print(f"{temp_ix_a=}") # print(f"{shift_a=}") # print(f"{active_chan_ids_a[pair_ix_a]=}") @@ -736,12 +736,19 @@ def shift_deduplicated_pairs( # print(f"{shift_diff=}") # figure out combinations + _, spatial_shift_ids = np.unique( + np.c_[ + active_chan_ids_a[pair_ix_a], + active_chan_ids_b[pair_ix_b], + shift_a - shift_b, + ], + axis=0, + return_inverse=True, + ) conv_determiners = np.c_[ temp_ix_a, - active_chan_ids_a[pair_ix_a], temp_ix_b, - active_chan_ids_b[pair_ix_b], - shift_diff, + spatial_shift_ids, ] # print(f"{conv_determiners=}") # conv_ix: indices of unique determiners @@ -750,7 +757,7 @@ def shift_deduplicated_pairs( conv_determiners, axis=0, return_index=True, return_inverse=True ) - return pair_ix_a, pair_ix_b, compression_index, conv_ix + return pair_ix_a, pair_ix_b, compression_index, conv_ix, spatial_shift_ids UpsampledShiftedTemplateIndex = namedtuple( @@ -827,7 +834,6 @@ def get_upsampled_shifted_template_index( def compressed_upsampled_pairs( - ix_a, ix_b, compression_index, conv_ix, @@ -845,54 +851,78 @@ def compressed_upsampled_pairs( We will upsample the templates in the RHS (b) in a compressed way. 
""" up_factor = compressed_upsampled_temporal.compressed_upsampling_map.shape[1] + compression_dup_ix = slice(None) if up_factor == 1: upinds = np.zeros(len(conv_ix), dtype=int) temp_comps = compressed_upsampled_temporal.compressed_upsampled_templates[ temp_ix_b[ix_b[conv_ix]] ] - return ix_a, ix_b, compression_index, conv_ix, upinds, temp_comps + return ix_b, compression_index, conv_ix, upinds, temp_comps, compression_dup_ix # each conv_ix needs to be duplicated as many times as its b template has - # upsampled copies. And, all ix_{a,b}[i] such that compression_ix[i] lands in + # upsampled copies + conv_shifted_temp_ix_b = shifted_temp_ix_b[ix_b[conv_ix]] + upsampling_mask = ( + conv_shifted_temp_ix_b[:, None] + == upsampled_shifted_template_index.up_shift_temp_ix_to_shift_temp_ix[None, :] + ) + conv_up_i, up_shift_up_i = np.nonzero(upsampling_mask) + conv_compressed_upsampled_ix = ( + upsampled_shifted_template_index.up_shift_temp_ix_to_shift_temp_ix[ + up_shift_up_i + ] + ) + conv_ix_up = conv_ix[conv_up_i] + # And, all ix_{a,b}[i] such that compression_ix[i] lands in # that conv_ix need to be duplicated as well. - ix_a_up = [] - ix_b_up = [] - compression_index_up = [] - conv_ix_up = [] - conv_compressed_upsampled_ix = [] - cur_dedup_ix = 0 - for i, convi in enumerate(conv_ix): - # get b's shifted template ix - conv_shifted_temp_ix_b = shifted_temp_ix_b[ix_b[convi]] - - # which compressed upsampled indices match this? - which_up = np.flatnonzero( - upsampled_shifted_template_index.up_shift_temp_ix_to_shift_temp_ix - == conv_shifted_temp_ix_b - ) - conv_comp_up_ix = ( - upsampled_shifted_template_index.up_shift_temp_ix_to_comp_up_ix[which_up] - ) - - # which deduplication indices map ix_a,b to this convi? - which_dedup = np.flatnonzero(compression_index == i) - - # extend arrays with new indices - nupi = conv_comp_up_ix.size - ix_a_up.extend(np.repeat(ix_a[which_dedup], nupi)) - ix_b_up.extend(np.repeat(ix_b[which_dedup], nupi)) - conv_ix_up.extend([convi] * nupi) - compression_index_up.extend( - np.tile(np.arange(cur_dedup_ix, cur_dedup_ix + nupi), which_dedup.size) - ) - cur_dedup_ix += nupi - conv_compressed_upsampled_ix.extend(conv_comp_up_ix) - - ix_a_up = np.array(ix_a_up) - ix_b_up = np.array(ix_b_up) - compression_index_up = np.array(compression_index_up) - conv_ix_up = np.array(conv_ix_up) - conv_compressed_upsampled_ix = np.array(conv_compressed_upsampled_ix) + dup_mask = conv_ix[compression_index][:, None] == conv_ix_up[None, :] + if torch.is_tensor(dup_mask): + dup_mask = dup_mask.numpy(force=True) + compression_dup_ix, compression_index_up = np.nonzero(dup_mask) + ix_b_up = ix_b[compression_dup_ix] + # ix_a_up = np.repeat(ix_a, ndups) + # ix_b_up = np.repeat(ix_b, ndups) + + # ix_a_up = np.zeros(len(ix_a) * up_factor, dtype=int) + # ix_b_up = np.zeros(len(ix_a) * up_factor, dtype=int) + # compression_index_up = np.zeros(len(ix_a) * up_factor, dtype=int) + # conv_ix_up = np.zeros(len(conv_ix) * up_factor, dtype=int) + # conv_compressed_upsampled_ix = np.zeros(len(conv_ix) * up_factor, dtype=int) + # cur_dedup_ix = 0 + # cur_pair_ix = 0 + # cur_conv_up_ix = 0 + # for i, convi in enumerate(conv_ix): + # # get b's shifted template ix + # conv_shifted_temp_ix_b = shifted_temp_ix_b[ix_b[convi]] + + # # which compressed upsampled indices match this? 
+ # which_up = np.flatnonzero( + # upsampled_shifted_template_index.up_shift_temp_ix_to_shift_temp_ix + # == conv_shifted_temp_ix_b + # ) + # conv_comp_up_ix = ( + # upsampled_shifted_template_index.up_shift_temp_ix_to_comp_up_ix[which_up] + # ) + + # # which deduplication indices map ix_a,b to this convi? + # which_dedup = np.flatnonzero(compression_index == i) + + # # extend arrays with new indices + # nupi = conv_comp_up_ix.size + # n_new_pair = which_dedup.size * nupi + # ix_a_up[cur_pair_ix:cur_pair_ix+n_new_pair] = np.repeat(ix_a[which_dedup], nupi) + # ix_b_up[cur_pair_ix:cur_pair_ix+n_new_pair] = np.repeat(ix_b[which_dedup], nupi) + # conv_ix_up[cur_dedup_ix:cur_dedup_ix+nupi]=convi + # conv_compressed_upsampled_ix[cur_dedup_ix:cur_dedup_ix+nupi]=conv_comp_up_ix + # compression_index_up[cur_pair_ix:cur_pair_ix+n_new_pair]=np.tile(np.arange(cur_dedup_ix, cur_dedup_ix + nupi), which_dedup.size) + # cur_pair_ix += n_new_pair + # cur_dedup_ix += nupi + + # ix_a_up = ix_a_up[:cur_pair_ix] #np.array(ix_a_up) + # ix_b_up = ix_b_up[:cur_pair_ix] #np.array(ix_b_up) + # conv_compressed_upsampled_ix = conv_compressed_upsampled_ix[:cur_pair_ix] #np.array(conv_compressed_upsampled_ix) + # compression_index_up = compression_index_up[:cur_dedup_ix] #np.array(compression_index_up) + # conv_ix_up = conv_ix_up[:cur_dedup_ix] #np.array(conv_ix_up) # which upsamples and which templates? conv_upsampling_indices_b = ( @@ -907,12 +937,12 @@ def compressed_upsampled_pairs( ) return ( - ix_a_up, ix_b_up, compression_index_up, conv_ix_up, conv_upsampling_indices_b, conv_temporal_components_up_b, + compression_dup_ix, ) @@ -921,8 +951,9 @@ def coarse_approximate( units_a, units_b, temp_ix_a, - shift_a, - shift_b, + # shift_a, + # shift_b, + spatial_shift_ids, coarse_approx_error_threshold=0.0, ): """Try to replace fine (superres+temporally upsampled) convs with coarse ones @@ -942,16 +973,19 @@ def coarse_approximate( This needs to tell the caller how to update its bookkeeping. 
""" + if not pconv.numel(): + return pconv, slice(None) + new_pconv = [] old_ix_to_new_ix = np.full(len(pconv), -1) cur_new_ix = 0 - shift_diff = shift_a - shift_b + # shift_diff = shift_a - shift_b for ua in np.unique(units_a): ina = np.flatnonzero(units_a == ua) partners_b = np.unique(units_b[ina]) for ub in partners_b: inab = ina[units_b[ina] == ub] - dshift = shift_diff[inab] + dshift = spatial_shift_ids[inab] for shift in np.unique(dshift): inshift = inab[dshift == shift] @@ -996,7 +1030,7 @@ def coarse_approximate( ) cur_new_ix += insup.sum() - new_pconv = torch.cat(new_pconv) + new_pconv = torch.cat(new_pconv, out=pconv[:cur_new_ix]) return new_pconv, old_ix_to_new_ix @@ -1005,7 +1039,7 @@ def coarse_approximate( @dataclass class ConvWorkerContext: - template_data: templates.TemplateData + unit_ids: np.ndarray low_rank_templates: template_util.LowRankTemplates compressed_upsampled_temporal: template_util.CompressedUpsampledTemplates template_shift_index: drift_util.TemplateShiftIndex @@ -1021,6 +1055,24 @@ class ConvWorkerContext: batch_size: int = 128 device: Optional[torch.device] = None + def __post_init__(self): + # to device + self.compressed_upsampled_temporal.compressed_upsampled_templates = ( + torch.as_tensor( + self.compressed_upsampled_temporal.compressed_upsampled_templates, + device=self.device, + ) + ) + self.low_rank_templates.spatial_components = torch.as_tensor( + self.low_rank_templates.spatial_components, device=self.device + ) + self.low_rank_templates.singular_values = torch.as_tensor( + self.low_rank_templates.singular_values, device=self.device + ) + self.low_rank_templates.temporal_components = torch.as_tensor( + self.low_rank_templates.temporal_components, device=self.device + ) + _conv_worker_context = None From 955be9613e8ec42bdf639b2467e79c51458e1e58 Mon Sep 17 00:00:00 2001 From: Charlie Windolf Date: Tue, 21 Nov 2023 21:34:11 -0500 Subject: [PATCH 38/49] Debug --- src/dartsort/templates/pairwise_util.py | 23 ++++++++----- src/dartsort/templates/template_util.py | 1 + src/dartsort/templates/templates.py | 14 ++++++-- src/dartsort/util/drift_util.py | 44 ++++++++++++------------- 4 files changed, 48 insertions(+), 34 deletions(-) diff --git a/src/dartsort/templates/pairwise_util.py b/src/dartsort/templates/pairwise_util.py index ead136f2..c7f2c35c 100644 --- a/src/dartsort/templates/pairwise_util.py +++ b/src/dartsort/templates/pairwise_util.py @@ -191,8 +191,6 @@ def iterate_compressed_pairwise_convolutions( print( f"iterate_compressed_pairwise_convolutions {conv_batch_size=} {units_batch_size=} {device=}" ) - n_shifts = template_shift_index.all_pitch_shifts.size - do_shifting = n_shifts > 1 geom_kdtree = reg_geom_kdtree = match_distance = None if do_shifting: geom_kdtree = KDTree(geom) @@ -1131,14 +1129,23 @@ def __post_init__(self): device=self.device, ) ) - self.low_rank_templates.spatial_components = torch.as_tensor( - self.low_rank_templates.spatial_components, device=self.device + self.low_rank_templates_a.spatial_components = torch.as_tensor( + self.low_rank_templates_a.spatial_components, device=self.device + ) + self.low_rank_templates_a.singular_values = torch.as_tensor( + self.low_rank_templates_a.singular_values, device=self.device + ) + self.low_rank_templates_a.temporal_components = torch.as_tensor( + self.low_rank_templates_a.temporal_components, device=self.device + ) + self.low_rank_templates_b.spatial_components = torch.as_tensor( + self.low_rank_templates_b.spatial_components, device=self.device ) - 
self.low_rank_templates.singular_values = torch.as_tensor( - self.low_rank_templates.singular_values, device=self.device + self.low_rank_templates_b.singular_values = torch.as_tensor( + self.low_rank_templates_b.singular_values, device=self.device ) - self.low_rank_templates.temporal_components = torch.as_tensor( - self.low_rank_templates.temporal_components, device=self.device + self.low_rank_templates_b.temporal_components = torch.as_tensor( + self.low_rank_templates_b.temporal_components, device=self.device ) diff --git a/src/dartsort/templates/template_util.py b/src/dartsort/templates/template_util.py index 6b34d4da..57f303bd 100644 --- a/src/dartsort/templates/template_util.py +++ b/src/dartsort/templates/template_util.py @@ -262,6 +262,7 @@ def temporally_upsample_templates( @dataclass class CompressedUpsampledTemplates: + n_compressed_upsampled_templates: int compressed_upsampled_templates: np.ndarray compressed_upsampling_map: np.ndarray compressed_upsampling_index: np.ndarray diff --git a/src/dartsort/templates/templates.py b/src/dartsort/templates/templates.py index 745210d6..f940ff49 100644 --- a/src/dartsort/templates/templates.py +++ b/src/dartsort/templates/templates.py @@ -52,11 +52,18 @@ def to_npz(self, npz_path): def coarsen(self): """Weighted average all templates that share a unit id and re-localize.""" # update templates - templates = weighted_average(self.unit_ids, self.templates, self.spike_counts) + print(f"a {np.isnan(self.templates).any()=}") + print(f"b {np.equal(self.templates, 0).all(axis=(1,2)).sum()=}") + unit_ids_unique, flat_ids = np.unique(self.unit_ids, return_inverse=True) + templates = weighted_average(flat_ids, self.templates, self.spike_counts) + print(f"b {np.isnan(templates).any()=}") + print(f"b {np.equal(templates, 0).all(axis=(1,2)).sum()=}") # collect spike counts spike_counts = np.zeros(len(templates)) - np.add.at(spike_counts, self.unit_ids, self.spike_counts) + np.add.at(spike_counts, np.arange(unit_ids_unique.size), self.spike_counts) + print(f"b {np.isnan(spike_counts).any()=}") + print(f"b {np.isnan(self.registered_geom).any()=}") # re-localize registered_template_depths_um = get_template_depths( @@ -64,11 +71,12 @@ def coarsen(self): self.registered_geom, localization_radius_um=self.localization_radius_um, ) + print(f"b {np.isnan(registered_template_depths_um).any()=}") return replace( self, templates=templates, - unit_ids=np.arange(len(templates)), + unit_ids=unit_ids_unique, spike_counts=spike_counts, registered_template_depths_um=registered_template_depths_um, ) diff --git a/src/dartsort/util/drift_util.py b/src/dartsort/util/drift_util.py index decd05eb..71ce470c 100644 --- a/src/dartsort/util/drift_util.py +++ b/src/dartsort/util/drift_util.py @@ -543,6 +543,14 @@ def get_shift_and_unit_pairs( reg_depths_um = np.concatenate((reg_depths_um_a, reg_depths_um_b)) # figure out all shifts for all units at all times + print(f"{chunk_time_centers_s.min()=} {chunk_time_centers_s.max()=}") + print(f"{reg_depths_um.min()=} {reg_depths_um.max()=}") + print(f"{reg_depths_um_a.min()=} {reg_depths_um_a.max()=}") + print(f"{reg_depths_um_b.min()=} {reg_depths_um_b.max()=}") + print(f"{motion_est.time_bin_centers_s.min()=}") + print(f"{motion_est.time_bin_centers_s.max()=}") + print(f"{motion_est.spatial_bin_centers_um.min()=}") + print(f"{motion_est.spatial_bin_centers_um.max()=}") unreg_depths_um = np.concatenate( [ motion_est.disp_at_s(t_s, depth_um=reg_depths_um, grid=True).T @@ -570,27 +578,17 @@ def get_shift_and_unit_pairs( 
template_shift_index_b = TemplateShiftIndex.from_shift_matrix(shifts_b) # co-occurrence matrix: do these shifted templates appear together? - cooccurrence = np.eye(n_template_shift_pairs, dtype=bool) - for t_s in chunk_time_centers_s: - unregistered_depths_um = invert_motion_estimate( - motion_est, t_s, template_data.registered_template_depths_um - ) - pitch_shifts = get_spike_pitch_shifts( - depths_um=template_data.registered_template_depths_um, - pitch=pitch, - registered_depths_um=unregistered_depths_um, - ) - pitch_shifts = pitch_shifts.astype(int) - pitch_shift_ix = np.searchsorted(all_pitch_shifts, pitch_shifts) - - shifted_temp_ixs = template_shift_index[temp_ixs, pitch_shift_ix] - cooccurrence[shifted_temp_ixs[:, None], shifted_temp_ixs[None, :]] = 1 + cooccurrence = np.zeros( + (template_shift_index_a.n_shifted_templates, template_shift_index_b.n_shifted_templates), + dtype=bool) + temps_a = np.arange(na) + temps_b = np.arange(nb) + for j in range(len(chunk_time_centers_s)): + shifted_ids_a = template_shift_index_a.shifts_to_shifted_ids(temps_a, shifts_a[j]) + if same: + shifted_ids_b = shifted_ids_a + else: + shifted_ids_b = template_shift_index_b.shifts_to_shifted_ids(temps_b, shifts_b[j]) + cooccurrence[shifted_ids_a[:, None], shifted_ids_b[None, :]] = 1 - return TemplateShiftIndex( - n_template_shift_pairs, - all_pitch_shifts, - template_shift_index, - cooccurrence, - shifted_temp_ix_to_temp_ix, - shifted_temp_ix_to_shift, - ) + return template_shift_index_a, template_shift_index_b, cooccurrence From ca633779379356e00d3b6db78e4d9baccfa8c842 Mon Sep 17 00:00:00 2001 From: Charlie Windolf Date: Tue, 21 Nov 2023 21:43:43 -0500 Subject: [PATCH 39/49] unit_ids handling... --- src/dartsort/peel/matching.py | 28 +++++++++++++++------------- 1 file changed, 15 insertions(+), 13 deletions(-) diff --git a/src/dartsort/peel/matching.py b/src/dartsort/peel/matching.py index 01dc84b1..e693ebf5 100644 --- a/src/dartsort/peel/matching.py +++ b/src/dartsort/peel/matching.py @@ -129,7 +129,6 @@ def precompute_peeling_data(self, save_folder, n_jobs=0, device=None): n_jobs=n_jobs, device=device, ) - self.handle_template_groups(self.template_data.unit_ids) # couple more torch buffers self.register_buffer( "_refrac_ix", @@ -175,7 +174,7 @@ def check_shapes(self): ) assert self.unit_ids.shape == (self.n_templates,) - def handle_template_groups(self, unit_ids): + def handle_template_groups(self, obj_unit_ids, unit_ids): """Grouped templates in objective If not coarse_objective, then several rows of the objective may @@ -183,6 +182,9 @@ def handle_template_groups(self, unit_ids): refractory conditions. 
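# Editor's sketch (toy unit ids) of the grouping structure described above:
# superres_index maps each coarse objective unit to the rows of the fine
# template set that belong to it, padded with n_templates so it can be used
# as a gather index alongside a dummy row.
import numpy as np

unit_ids = np.array([0, 0, 1, 2, 2, 2])      # fine template -> coarse unit
obj_unit_ids, fine_to_coarse, counts = np.unique(
    unit_ids, return_inverse=True, return_counts=True
)
n_templates = unit_ids.size
superres_index = np.full((obj_unit_ids.size, counts.max()), n_templates)
for j, u in enumerate(obj_unit_ids):
    mine = np.flatnonzero(unit_ids == u)
    superres_index[j, : mine.size] = mine
# superres_index -> [[0, 1, 6], [2, 6, 6], [3, 4, 5]]
# fine_to_coarse -> [0, 0, 1, 2, 2, 2]: objective row for each fine template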
""" self.register_buffer("unit_ids", torch.from_numpy(unit_ids)) + self.register_buffer("obj_unit_ids", torch.from_numpy(obj_unit_ids)) + units, counts, fine_to_coarse = np.unique(unit_ids, return_counts=True, return_inverse=True) + self.register_buffer("fine_to_coarse", torch.from_numpy(fine_to_coarse)) self.grouped_temps = True unique_units = np.unique(unit_ids) if unique_units.size == unit_ids.size: @@ -193,11 +195,10 @@ def handle_template_groups(self, unit_ids): return assert unit_ids.shape == (self.n_templates,) - units, counts = np.unique(self.unit_ids, return_counts=True) - superres_index = np.full((self.n_templates, counts.max()), self.n_templates, -1) - for u in units: - my_sup = np.flatnonzero(self.unit_ids == u) - superres_index[u, : len(my_sup)] = my_sup + superres_index = np.full((len(obj_unit_ids), counts.max()), self.n_templates, -1) + for j, u in enumerate(obj_unit_ids): + my_sup = np.flatnonzero(unit_ids == u) + superres_index[j, : len(my_sup)] = my_sup self.register_buffer("superres_index", torch.from_numpy(superres_index)) if self.coarse_objective: @@ -283,6 +284,7 @@ def build_template_data( self.register_buffer( "objective_spatial_components", self.spatial_components ) + self.handle_template_groups(coarse_template_data.unit_ids, self.template_data.unit_ids) half_chunk = self.chunk_length_samples // 2 chunk_centers_samples = np.arange( @@ -463,7 +465,7 @@ def templates_at_time(self, t_s): objective_spatial_components=cur_obj_spatial, objective_singular_values=self.objective_singular_values, objective_temporal_components=self.objective_temporal_components, - unit_ids=self.unit_ids, + fine_to_coarse=self.fine_to_coarse, coarse_objective=self.coarse_objective, spatial_components=cur_spatial, singular_values=self.singular_values, @@ -654,14 +656,14 @@ def enforce_refractory(self, objective, times, template_indices): # overwrite objective with -inf to enforce refractoriness time_ix = times[:, None] + self._refrac_ix[None, :] if not self.grouped_temps: - unit_ix = template_indices[:, None, None] + row_ix = template_indices[:, None, None] elif self.coarse_objective: - unit_ix = self.unit_ids[template_indices][:, None, None] + row_ix = self.fine_to_coarse[template_indices][:, None, None] elif self.grouped_temps: - unit_ix = self.group_index[template_indices] + row_ix = self.group_index[template_indices] else: assert False - objective[unit_ix[:, :, None], time_ix[:, None, :]] = -torch.inf + objective[row_ix[:, :, None], time_ix[:, None, :]] = -torch.inf @dataclass @@ -676,7 +678,7 @@ class MatchingTemplateData: objective_spatial_components: torch.Tensor objective_singular_values: torch.Tensor objective_temporal_components: torch.Tensor - unit_ids: torch.LongTensor + fine_to_coarse: torch.LongTensor coarse_objective: bool spatial_components: torch.Tensor singular_values: torch.Tensor From 4b1d1cba64d5fbc9041c84b7088ceb651aa50ce9 Mon Sep 17 00:00:00 2001 From: Charlie Windolf Date: Mon, 27 Nov 2023 19:18:46 -0500 Subject: [PATCH 40/49] A version which runs --- src/dartsort/peel/matching.py | 138 ++++++++++++++---------- src/dartsort/peel/peel_base.py | 2 +- src/dartsort/templates/pairwise.py | 111 +++++++++++++------ src/dartsort/templates/pairwise_util.py | 15 ++- src/dartsort/templates/template_util.py | 5 + src/dartsort/templates/templates.py | 12 +-- src/dartsort/util/drift_util.py | 34 ++++-- src/dartsort/util/spiketorch.py | 13 ++- 8 files changed, 216 insertions(+), 114 deletions(-) diff --git a/src/dartsort/peel/matching.py b/src/dartsort/peel/matching.py index 
e693ebf5..48abc064 100644 --- a/src/dartsort/peel/matching.py +++ b/src/dartsort/peel/matching.py @@ -166,12 +166,6 @@ def check_shapes(self): self.svd_compression_rank, self.n_registered_channels, ) - assert self.upsampled_temporal_components.shape == ( - self.n_templates, - self.spike_length_samples, - self.temporal_upsampling_factor, - self.svd_compression_rank, - ) assert self.unit_ids.shape == (self.n_templates,) def handle_template_groups(self, obj_unit_ids, unit_ids): @@ -183,7 +177,9 @@ def handle_template_groups(self, obj_unit_ids, unit_ids): """ self.register_buffer("unit_ids", torch.from_numpy(unit_ids)) self.register_buffer("obj_unit_ids", torch.from_numpy(obj_unit_ids)) - units, counts, fine_to_coarse = np.unique(unit_ids, return_counts=True, return_inverse=True) + units, fine_to_coarse, counts = np.unique( + unit_ids, return_counts=True, return_inverse=True + ) self.register_buffer("fine_to_coarse", torch.from_numpy(fine_to_coarse)) self.grouped_temps = True unique_units = np.unique(unit_ids) @@ -195,7 +191,7 @@ def handle_template_groups(self, obj_unit_ids, unit_ids): return assert unit_ids.shape == (self.n_templates,) - superres_index = np.full((len(obj_unit_ids), counts.max()), self.n_templates, -1) + superres_index = np.full((len(obj_unit_ids), counts.max()), self.n_templates) for j, u in enumerate(obj_unit_ids): my_sup = np.flatnonzero(unit_ids == u) superres_index[j, : len(my_sup)] = my_sup @@ -273,6 +269,7 @@ def build_template_data( self.register_buffer( "objective_spatial_components", torch.tensor(spatial_components) ) + self.obj_n_templates = spatial_components.shape[0] else: coarse_template_data = template_data coarse_low_rank_templates = low_rank_templates @@ -284,7 +281,10 @@ def build_template_data( self.register_buffer( "objective_spatial_components", self.spatial_components ) - self.handle_template_groups(coarse_template_data.unit_ids, self.template_data.unit_ids) + self.obj_n_templates = self.n_templates + self.handle_template_groups( + coarse_template_data.unit_ids, self.template_data.unit_ids + ) half_chunk = self.chunk_length_samples // 2 chunk_centers_samples = np.arange( @@ -337,7 +337,9 @@ def handle_upsampling( ) self.register_buffer( "compressed_index_to_upsampling_index", - torch.tensor(compressed_upsampled_temporal.compressed_index_to_upsampling_index), + torch.tensor( + compressed_upsampled_temporal.compressed_index_to_upsampling_index + ), ) self.register_buffer( "compressed_upsampled_temporal", @@ -418,7 +420,7 @@ def templates_at_time(self, t_s): """Handle drift -- grab the right spatial neighborhoods.""" pconvdb = self.pairwise_conv_db if self.is_drifting: - pitch_shifts, cur_spatial = template_util.templates_at_time( + pitch_shifts_b, cur_spatial = template_util.templates_at_time( t_s, self.spatial_components, self.geom, @@ -429,17 +431,8 @@ def templates_at_time(self, t_s): geom_kdtree=self.geom_kdtree, match_distance=self.match_distance, ) - cur_ampvecs = drift_util.get_waveforms_on_static_channels( - self.registered_template_ampvecs[:, None, :], - self.registered_geom, - n_pitches_shift=pitch_shifts, - registered_geom=self.geom, - target_kdtree=self.geom_kdtree, - match_distance=self.match_distance, - fill_value=0.0, - ) if self.coarse_objective: - cur_obj_spatial = template_util.templates_at_time( + pitch_shifts_a, cur_obj_spatial = template_util.templates_at_time( t_s, self.objective_spatial_components, self.geom, @@ -452,14 +445,24 @@ def templates_at_time(self, t_s): ) else: cur_obj_spatial = cur_spatial + pitch_shifts_a = 
pitch_shifts_b + cur_ampvecs = drift_util.get_waveforms_on_static_channels( + self.registered_template_ampvecs[:, None, :], + self.registered_geom, + n_pitches_shift=pitch_shifts_b, + registered_geom=self.geom, + target_kdtree=self.geom_kdtree, + match_distance=self.match_distance, + fill_value=0.0, + ) max_channels = cur_ampvecs[:, 0, :].argmax(1) - pconvdb = pconvdb.at_shifts(pitch_shifts) + pconvdb = pconvdb.at_shifts(pitch_shifts_a, pitch_shifts_b) else: cur_spatial = self.spatial_components max_channels = self.registered_template_ampvecs.argmax(1) if not pconvdb._is_torch: - pconvdb = pconvdb.to(cur_spatial.device) + pconvdb.to(cur_spatial.device) return MatchingTemplateData( objective_spatial_components=cur_obj_spatial, @@ -492,6 +495,7 @@ def match_chunk( # initialize residual, it needs to be padded to support our channel # indexing convention (used later to extract small channel # neighborhoods). this copies the input. + print(f"match_chunk {traces.shape=}") residual_padded = F.pad(traces, (0, 1), value=torch.nan) residual = residual_padded[:, :-1] @@ -499,13 +503,13 @@ def match_chunk( conv_len = traces.shape[0] - self.spike_length_samples + 1 padded_obj_len = conv_len + 2 * self.obj_pad_len padded_conv = torch.zeros( - self.n_templates, + self.obj_n_templates, padded_obj_len, dtype=traces.dtype, device=traces.device, ) padded_objective = torch.zeros( - self.n_templates + 1, + self.obj_n_templates + 1, padded_obj_len, dtype=traces.dtype, device=traces.device, @@ -513,7 +517,7 @@ def match_chunk( refrac_mask = torch.zeros_like(padded_objective) # padded objective has an extra unit (for group_index) and refractory # padding (for easier implementation of enforce_refractory) - neg_temp_normsq = -compressed_template_data.template_norms_squared[:, None] + neg_temp_normsq = -compressed_template_data.objective_template_norms_squared[:, None] # manages buffers for spike train data (peak times, labels, etc) peaks = MatchingPeaks(device=traces.device) @@ -534,7 +538,7 @@ def match_chunk( # find high-res peaks print("before find") new_peaks = self.find_peaks( - padded_conv, padded_objective, refrac_mask, compressed_template_data + residual, padded_conv, padded_objective, refrac_mask, compressed_template_data ) if new_peaks is None: break @@ -605,7 +609,14 @@ def match_chunk( res["residual"] = residual return res - def find_peaks(self, residual, padded_conv, padded_objective, refrac_mask, compressed_template_data): + def find_peaks( + self, + residual, + padded_conv, + padded_objective, + refrac_mask, + compressed_template_data, + ): # first step: coarse peaks. not temporally upsampled or amplitude-scaled. 
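# Editor's sketch (toy numbers) of the objective being maximized here: for
# template u at time t, obj[u, t] = 2 * conv[u, t] - ||T_u||^2 is the decrease
# in squared residual norm from subtracting T_u at t (hence the
# torch.add(..., alpha=2.0) update of padded_objective). With a Gaussian prior
# of variance amp_scale_variance on a per-spike amplitude scaling, fine_match
# uses the closed-form optimum sketched below; the 0.5/1.5 bounds are toy
# stand-ins for amp_scale_min / amp_scale_max.
import torch

def scaled_objective(conv, temp_normsq, amp_scale_variance, lo=0.5, hi=1.5):
    inv_lambda = 1.0 / amp_scale_variance
    b = conv + inv_lambda
    a = temp_normsq + inv_lambda
    scaling = torch.clip(b / a, lo, hi)
    return scaling, 2 * scaling * b - scaling.square() * a - inv_lambda

conv = torch.tensor([40.0, 90.0])        # template-residual inner products
normsq = torch.tensor([80.0, 80.0])      # squared template norms
print(2 * conv - normsq)                 # unscaled objective: [0., 100.]
print(scaled_objective(conv, normsq, amp_scale_variance=0.01))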
objective = (padded_objective + refrac_mask)[ :-1, self.obj_pad_len : -self.obj_pad_len @@ -623,13 +634,20 @@ def find_peaks(self, residual, padded_conv, padded_objective, refrac_mask, compr residual_snips = None if self.coarse_objective or self.temporal_upsampling_factor > 1: residual_snips = spiketorch.grab_spikes_full( + residual, times - 1, trough_offset=0, spike_length_samples=self.spike_length_samples + 1, ) # second step: high-res peaks (upsampled and/or amp-scaled) - time_shifts, upsampling_indices, scalings, template_indices, scores = compressed_template_data.fine_match( + ( + time_shifts, + upsampling_indices, + scalings, + template_indices, + scores, + ) = compressed_template_data.fine_match( padded_conv[obj_template_indices, times], objective_max[times], residual_snips, @@ -637,6 +655,7 @@ def find_peaks(self, residual, padded_conv, padded_objective, refrac_mask, compr amp_scale_variance=self.amplitude_scaling_variance, amp_scale_min=self.amp_scale_min, amp_scale_max=self.amp_scale_max, + superres_index=self.superres_index, ) if time_shifts is not None: times += time_shifts @@ -703,10 +722,13 @@ def __post_init__(self): ) assert self.singular_values.shape == (self.n_templates, self.rank) device = self.spatial_components.device + self.temporal_upsampling_factor = self.compressed_upsampling_index.shape[1] + self.n_compressed_upsampled_templates = self.compressed_upsampling_map.max() + 1 # squared l2 norms are usually the sums of squared singular values: # self.template_norms_squared = torch.square(self.singular_values).sum(1) # in this case, we have subset the spatial components, so use a diff formula + self.objective_n_templates = self.objective_spatial_components.shape[0] self.objective_spatial_singular = ( self.objective_spatial_components * self.objective_singular_values[:, :, None] @@ -729,11 +751,14 @@ def convolve(self, traces, padding=0, out=None): """Convolve the objective templates with traces.""" out_len = traces.shape[0] + 2 * padding - self.spike_length_samples + 1 if out is None: - out = torch.zeros( - (self.n_templates, out_len), dtype=traces.dtype, device=traces.device + out = torch.empty( + (self.objective_n_templates, out_len), + dtype=traces.dtype, + device=traces.device, ) else: - assert out.shape == (self.n_templates, out_len) + assert out.shape == (self.objective_n_templates, out_len) + print(f"convolve {traces.shape=} {out.shape=} {out_len=} {padding=}") for q in range(self.rank): # units x time @@ -748,10 +773,9 @@ def convolve(self, traces, padding=0, out=None): # padding=padding, # )[0] conv = spiketorch.depthwise_oaconv1d( - rec_spatial, - temporal[:, :], - padding=padding + rec_spatial, temporal[:, :], padding=padding ) + print(f"{conv.shape=}") if q: out += conv else: @@ -769,16 +793,17 @@ def subtract_conv( scalings, conv_pad_len=0, ): - template_indices_a, template_indices_b, pconvs = scalings[ - :, None - ] * self.pairwise_conv_db.query( + template_indices_a, template_indices_b, times, pconvs = self.pairwise_conv_db.query( template_indices_a=None, template_indices_b=template_indices, upsampling_indices_b=upsampling_indices, + scalings_b=scalings, + times_b=times, grid=True, ) ix_template = template_indices_a[:, None] - ix_time = times[None, :] + (conv_pad_len + self.conv_lags) + ix_time = times[:, None] + (conv_pad_len + self.conv_lags)[None, :] + print(f"{ix_template.shape=} {ix_time.shape=} {scalings.shape=} {pconvs.shape=}") spiketorch.add_at_( conv, (ix_template, ix_time), @@ -795,6 +820,7 @@ def fine_match( amp_scale_variance=0.0, 
amp_scale_min=None, amp_scale_max=None, + superres_index=None, ): """Determine superres ids, temporal upsampling, and scaling @@ -842,7 +868,7 @@ def fine_match( if self.coarse_objective: # TODO best I came up with, but it still syncs - superres_ix = self.superres_index[objective_template_indices] + superres_ix = superres_index[objective_template_indices] dup_ix, column_ix = (superres_ix < self.n_templates).nonzero(as_tuple=True) template_indices = superres_ix[dup_ix, column_ix] convs = torch.einsum( @@ -851,17 +877,15 @@ def fine_match( self.spatial_singular[template_indices], self.temporal_components[template_indices], ) - neg_norms = -self.template_norms_squared[template_indices] - objs = torch.full( - superres_ix.shape, -torch.inf, device=convs.device - ) - objs[dup_ix, column_ix] = 2 * convs + neg_norms + norms = self.template_norms_squared[template_indices] + objs = torch.full(superres_ix.shape, -torch.inf, device=convs.device) + objs[dup_ix, column_ix] = 2 * convs - norms objs, best_column_ix = objs.max(dim=1) row_ix = torch.arange(best_column_ix.numel(), device=best_column_ix.device) template_indices = superres_ix[row_ix, best_column_ix] else: template_indices = objective_template_indices - neg_norms = -self.template_norms_squared[template_indices] + norms = self.template_norms_squared[template_indices] objs = objs if self.temporal_upsampling_factor == 1 and not amp_scale_variance: @@ -871,7 +895,7 @@ def fine_match( # just scaling inv_lambda = 1 / amp_scale_variance b = convs + inv_lambda - a = neg_norms + inv_lambda + a = norms + inv_lambda scalings = torch.clip(b / a, amp_scale_min, amp_scale_max) objs = 2 * scalings * b - torch.square(scalings) * a - inv_lambda return None, None, scalings, template_indices, objs @@ -889,21 +913,24 @@ def fine_match( self.spatial_singular[template_indices[dup_ix]], self.compressed_upsampled_temporal[comp_up_indices], ) - neg_norms = neg_norms[dup_ix] + norms = norms[dup_ix] objs = torch.full((*comp_up_ix.shape, 2), -torch.inf, device=convs.device) if amp_scale_variance: inv_lambda = 1 / amp_scale_variance b = convs + inv_lambda - a = neg_norms + inv_lambda + a = norms[:, None] + inv_lambda scalings = torch.clip(b / a, amp_scale_min, amp_scale_max) - objs[dup_ix, column_ix] = 2 * scalings * b - torch.square(scalings) * a - inv_lambda + objs[dup_ix, column_ix] = ( + 2 * scalings * b - torch.square(scalings) * a - inv_lambda + ) else: - objs[dup_ix, column_ix] = 2 * convs - neg_norms + print(f"{objs.shape=} {objs[dup_ix, column_ix].shape=} {convs.shape=} {norms.shape=}") + objs[dup_ix, column_ix] = 2 * convs - norms[:, None] scalings = None - objs, best_column_dt_ix = objs.reshape(len(convs), -1).max(dim=1) + objs, best_column_dt_ix = objs.reshape(len(objs), comp_up_ix.shape[1] * 2).max(dim=1) best_column_ix = best_column_dt_ix // 2 - row_ix = torch.arange(best_column_ix.numel(), device=best_column_ix.device) + row_ix = torch.arange(len(objs), device=best_column_ix.device) comp_up_indices = comp_up_ix[row_ix, best_column_ix] upsampling_indices = self.compressed_index_to_upsampling_index[comp_up_indices] @@ -956,9 +983,10 @@ def get_collisioncleaned_waveforms( self.rank_ix[None, :, None], channel_index[channels][:, None, :], ] - temporal = self.upsampled_temporal_components[ - peaks.template_indices, :, peaks.upsampling_indices + comp_up_ix = self.compressed_upsampling_map[ + peaks.template_indices, peaks.upsampling_indices ] + temporal = self.compressed_upsampled_temporal[comp_up_ix] torch.baddbmm(waveforms, temporal, spatial, out=waveforms) return 
channels, waveforms diff --git a/src/dartsort/peel/peel_base.py b/src/dartsort/peel/peel_base.py index 6b4ba7f8..0a925425 100644 --- a/src/dartsort/peel/peel_base.py +++ b/src/dartsort/peel/peel_base.py @@ -215,7 +215,7 @@ def precompute_peeling_data(self, save_folder, n_jobs=0, device=None): # runs before fit_peeler_models() pass - def fit_peeler_models(self, save_folder): + def fit_peeler_models(self, save_folder, n_jobs=0, device=None): # subclasses should override if they need to fit models for peeling assert not self.peeling_needs_fit() diff --git a/src/dartsort/templates/pairwise.py b/src/dartsort/templates/pairwise.py index 73848c9f..67f689da 100644 --- a/src/dartsort/templates/pairwise.py +++ b/src/dartsort/templates/pairwise.py @@ -108,62 +108,83 @@ def from_template_data( ) return cls.from_h5(hdf5_filename) - def at_shifts(self, shifts_a=None): + def at_shifts(self, shifts_a=None, shifts_b=None): """Subset this database to one set of shifts. The database becomes shiftless (not in the pejorative sense). """ - if shifts_a is None: + if shifts_a is None or shifts_b is None: + assert shifts_a is shifts_b assert self.shifts_a.shape == (1,) + assert self.shifts_b.shape == (1,) return self - assert shifts_a.shape == len(self.shifted_template_index_a) - n_shifted_temps, n_up_shifted_temps = self.pconv_index.shape + assert shifts_a.shape == (len(self.shifted_template_index_a),) + assert shifts_b.shape == (len(self.upsampled_shifted_template_index_b),) + n_shifted_temps_a, n_up_shifted_temps_b = self.pconv_index.shape # active shifted and upsampled indices - shift_ix = np.searchsorted(self.shifts_a, shifts_a) - sub_shifted_temp_index = self.shifted_template_index_a[ - np.arange(len(self.shifted_template_index_a)), - shift_ix, + shift_ix_a = np.searchsorted(self.shifts_a, shifts_a) + shift_ix_b = np.searchsorted(self.shifts_b, shifts_b) + print( + f"at_shifts {self.shifts_a.shape=} {self.shifts_a.min()=} {self.shifts_a.max()=}" + ) + print(f"at_shifts {shifts_a.shape=} {shifts_a.min()=} {shifts_a.max()=}") + print(f"{shift_ix_a.shape=} {shift_ix_a.min()=} {shift_ix_a.max()=}") + + print( + f"at_shifts {self.shifts_b.shape=} {self.shifts_b.min()=} {self.shifts_b.max()=}" + ) + print(f"at_shifts {shifts_b.shape=} {shifts_b.min()=} {shifts_b.max()=}") + print(f"at_shifts {shift_ix_b.shape=} {shift_ix_b.min()=} {shift_ix_b.max()=}") + + print(f"at_shifts {self.shifted_template_index_a.shape=}") + print(f"at_shifts {self.upsampled_shifted_template_index_b.shape=}") + sub_shifted_temp_index_a = self.shifted_template_index_a[ + np.arange(len(self.shifted_template_index_a))[:, None], + shift_ix_a[:, None], ] - sub_up_shifted_temp_index = self.upsampled_shifted_template_index_b[ - np.arange(len(self.shifted_template_index_a)), - shift_ix, + sub_up_shifted_temp_index_b = self.upsampled_shifted_template_index_b[ + np.arange(len(self.upsampled_shifted_template_index_b))[:, None], + shift_ix_b[:, None], ] + print(f"at_shifts {sub_shifted_temp_index_a.shape=}") + print(f"at_shifts {sub_up_shifted_temp_index_b.shape=}") # in flat form for indexing into pconv_index. also, reindex. 
- valid_shifted = sub_shifted_temp_index < n_shifted_temps - shifted_temp_ixs, new_shifted_temp_ixs = np.unique( - sub_shifted_temp_index[valid_shifted] + valid_a = sub_shifted_temp_index_a < n_shifted_temps_a + shifted_temp_ixs_a, new_shifted_temp_ixs_a = np.unique( + sub_shifted_temp_index_a[valid_a], return_inverse=True ) - valid_up_shifted = sub_up_shifted_temp_index < n_up_shifted_temps - up_shifted_temp_ixs, new_up_shifted_temp_ixs = np.unique( - sub_up_shifted_temp_index[valid_up_shifted], return_inverse=True + valid_b = sub_up_shifted_temp_index_b < n_up_shifted_temps_b + up_shifted_temp_ixs_b, new_up_shifted_temp_ixs_b = np.unique( + sub_up_shifted_temp_index_b[valid_b], return_inverse=True ) # get relevant pconv subset and reindex sub_pconv_indices, new_pconv_indices = np.unique( self.pconv_index[ - shifted_temp_ixs[:, None], - up_shifted_temp_ixs.ravel()[None, :], + shifted_temp_ixs_a[:, None], + up_shifted_temp_ixs_b.ravel()[None, :], ], return_inverse=True, ) sub_pconv = self.pconv[sub_pconv_indices] # reindexing - n_sub_shifted_temps = len(shifted_temp_ixs) - n_sub_up_shifted_temps = len(up_shifted_temp_ixs) + n_sub_shifted_temps_a = len(shifted_temp_ixs_a) + n_sub_up_shifted_temps_b = len(up_shifted_temp_ixs_b) sub_pconv_index = new_pconv_indices.reshape( - n_sub_shifted_temps, n_sub_up_shifted_temps + n_sub_shifted_temps_a, n_sub_up_shifted_temps_b ) - sub_shifted_temp_index[valid_shifted] = new_shifted_temp_ixs - sub_up_shifted_temp_index[valid_shifted] = new_up_shifted_temp_ixs + sub_shifted_temp_index_a[valid_a] = new_shifted_temp_ixs_a + sub_up_shifted_temp_index_b[valid_b] = new_up_shifted_temp_ixs_b return self.__class__( - shifts=np.zeros(1), - shifted_template_index=sub_shifted_temp_index, - upsampled_shifted_template_index=sub_up_shifted_temp_index, + shifts_a=np.zeros(1), + shifts_b=np.zeros(1), + shifted_template_index_a=sub_shifted_temp_index_a, + upsampled_shifted_template_index_b=sub_up_shifted_temp_index_b, pconv_index=sub_pconv_index, pconv=sub_pconv, ) @@ -171,8 +192,10 @@ def at_shifts(self, shifts_a=None): def to(self, device=None): """Become torch tensors on device.""" for f in fields(self): - self.setattr(f.name, torch.as_tensor(getattr(self, f.name), device=device)) + setattr(self, f.name, torch.as_tensor(getattr(self, f.name), device=device)) self.device = device + self._is_torch = True + return self def query( self, @@ -181,6 +204,8 @@ def query( upsampling_indices_b=None, shifts_a=None, shifts_b=None, + scalings_b=None, + times_b=None, return_zero_convs=False, grid=False, ): @@ -229,16 +254,29 @@ def query( # return convolutions between all ai,bj or just ai,bi? 
if grid: - pconv_indices = self.pconv_index[shifted_temp_ix_a[:, None], up_shifted_temp_ix_b[None, :]] + pconv_indices = self.pconv_index[ + shifted_temp_ix_a[:, None], up_shifted_temp_ix_b[None, :] + ] if self._is_torch: template_indices_a, template_indices_b = torch.cartesian_prod( template_indices_a, template_indices_b ).T + if scalings_b is not None: + print(f"{scalings_b.shape=} {pconv_indices.shape=}") + scalings_b = torch.broadcast_to(scalings_b[None], pconv_indices.shape).reshape(-1) + if times_b is not None: + times_b = torch.broadcast_to(times_b[None], pconv_indices.shape).reshape(-1) pconv_indices = pconv_indices.view(-1) else: - template_indices_a, template_indices_b = np.meshgrid(template_indices_a, template_indices_b, indexing="ij") + template_indices_a, template_indices_b = np.meshgrid( + template_indices_a, template_indices_b, indexing="ij" + ) template_indices_a = template_indices_a.ravel() template_indices_b = template_indices_b.ravel() + if scalings_b is not None: + scalings_b = np.broadcast_to(scalings_b[None], pconv_indices.shape).ravel() + if times_b is not None: + times_b = np.broadcast_to(times_b[None], pconv_indices.shape).ravel() pconv_indices = pconv_indices.ravel() else: pconv_indices = self.pconv_index[shifted_temp_ix_a, up_shifted_temp_ix_b] @@ -249,5 +287,16 @@ def query( pconv_indices = pconv_indices[which] template_indices_a = template_indices_a[which] template_indices_b = template_indices_b[which] + if scalings_b is not None: + scalings_b = scalings_b[which] + if times_b is not None: + times_b = times_b[which] + + pconvs = self.pconv[pconv_indices] + if scalings_b is not None: + pconvs.mul_(scalings_b[:, None]) + + if times_b is not None: + return template_indices_a, template_indices_b, times_b, pconvs - return template_indices_a, template_indices_b, self.pconv[pconv_indices] + return template_indices_a, template_indices_b, pconvs diff --git a/src/dartsort/templates/pairwise_util.py b/src/dartsort/templates/pairwise_util.py index c7f2c35c..d20e7a4c 100644 --- a/src/dartsort/templates/pairwise_util.py +++ b/src/dartsort/templates/pairwise_util.py @@ -3,6 +3,7 @@ from collections import namedtuple from dataclasses import dataclass, fields from typing import Iterator, Optional, Union +from pathlib import Path import h5py import numpy as np @@ -45,8 +46,12 @@ def compressed_convolve_to_h5( of unit pairs, so that it's not all done in memory at one time, and so that it can be done in parallel. """ - if overwrite: - pass # TODO + output_hdf5_filename = Path(output_hdf5_filename) + if not overwrite and output_hdf5_filename.exists(): + with h5py.File(output_hdf5_filename, "r") as h5: + if "pconv_index" in h5: + return output_hdf5_filename + del h5 # construct indexing helpers ( @@ -918,7 +923,7 @@ def compressed_upsampled_pairs( # each conv_ix needs to be duplicated as many times as its b template has # upsampled copies - conv_shifted_temp_ix_b = shifted_temp_ix_b[ix_b[conv_ix]] + conv_shifted_temp_ix_b = np.atleast_1d(shifted_temp_ix_b[ix_b[conv_ix]]) upsampling_mask = ( conv_shifted_temp_ix_b[:, None] == upsampled_shifted_template_index.up_shift_temp_ix_to_shift_temp_ix[None, :] @@ -1030,8 +1035,8 @@ def coarse_approximate( This needs to tell the caller how to update its bookkeeping. 
""" - if not pconv.numel(): - return pconv, slice(None) + if not pconv.numel() or not coarse_approx_error_threshold: + return pconv, np.arange(len(pconv)) new_pconv = [] old_ix_to_new_ix = np.full(len(pconv), -1) diff --git a/src/dartsort/templates/template_util.py b/src/dartsort/templates/template_util.py index 57f303bd..af28a958 100644 --- a/src/dartsort/templates/template_util.py +++ b/src/dartsort/templates/template_util.py @@ -168,12 +168,17 @@ def templates_at_time( unregistered_depths_um = drift_util.invert_motion_estimate( motion_est, t_s, registered_template_depths_um ) + print(f"templates_at_time {registered_template_depths_um.min()=} {registered_template_depths_um.max()=}") + print(f"templates_at_time {unregistered_depths_um.min()=} {unregistered_depths_um.max()=}") + diff = unregistered_depths_um - registered_template_depths_um + print(f"templates_at_time {diff.min()=} {diff.max()=}") # reverse arguments to pitch shifts since we are going the other direction pitch_shifts = drift_util.get_spike_pitch_shifts( depths_um=registered_template_depths_um, geom=geom, registered_depths_um=unregistered_depths_um, ) + print(f"templates_at_time {pitch_shifts.min()=} {pitch_shifts.max()=}") # extract relevant channel neighborhoods, also by reversing args to a drift helper unregistered_templates = drift_util.get_waveforms_on_static_channels( registered_templates, diff --git a/src/dartsort/templates/templates.py b/src/dartsort/templates/templates.py index f940ff49..5decccb8 100644 --- a/src/dartsort/templates/templates.py +++ b/src/dartsort/templates/templates.py @@ -52,18 +52,12 @@ def to_npz(self, npz_path): def coarsen(self): """Weighted average all templates that share a unit id and re-localize.""" # update templates - print(f"a {np.isnan(self.templates).any()=}") - print(f"b {np.equal(self.templates, 0).all(axis=(1,2)).sum()=}") unit_ids_unique, flat_ids = np.unique(self.unit_ids, return_inverse=True) templates = weighted_average(flat_ids, self.templates, self.spike_counts) - print(f"b {np.isnan(templates).any()=}") - print(f"b {np.equal(templates, 0).all(axis=(1,2)).sum()=}") # collect spike counts spike_counts = np.zeros(len(templates)) - np.add.at(spike_counts, np.arange(unit_ids_unique.size), self.spike_counts) - print(f"b {np.isnan(spike_counts).any()=}") - print(f"b {np.isnan(self.registered_geom).any()=}") + np.add.at(spike_counts, flat_ids, self.spike_counts) # re-localize registered_template_depths_um = get_template_depths( @@ -71,7 +65,9 @@ def coarsen(self): self.registered_geom, localization_radius_um=self.localization_radius_um, ) - print(f"b {np.isnan(registered_template_depths_um).any()=}") + print(f"coarsen {self.registered_geom[:,1].min()=} {self.registered_geom[:,1].max()=}") + print(f"coarsen {self.registered_template_depths_um.min()=} {self.registered_template_depths_um.max()=}") + print(f"coarsen {registered_template_depths_um.min()=} {registered_template_depths_um.max()=}") return replace( self, diff --git a/src/dartsort/util/drift_util.py b/src/dartsort/util/drift_util.py index 71ce470c..4013619f 100644 --- a/src/dartsort/util/drift_util.py +++ b/src/dartsort/util/drift_util.py @@ -220,9 +220,13 @@ def get_spike_pitch_shifts( else: probe_displacement = registered_depths_um - depths_um + print(f"get_spike_pitch_shifts {pitch=}") + print(f"get_spike_pitch_shifts {probe_displacement.min()=} {probe_displacement.max()=}") + # if probe_displacement > 0, then the registered position is below the original # and, to be conservative, round towards 0 rather than using // 
n_pitches_shift = (probe_displacement / pitch).astype(int) + print(f"get_spike_pitch_shifts {n_pitches_shift.min()=} {n_pitches_shift.max()=}") return n_pitches_shift @@ -543,22 +547,29 @@ def get_shift_and_unit_pairs( reg_depths_um = np.concatenate((reg_depths_um_a, reg_depths_um_b)) # figure out all shifts for all units at all times - print(f"{chunk_time_centers_s.min()=} {chunk_time_centers_s.max()=}") - print(f"{reg_depths_um.min()=} {reg_depths_um.max()=}") - print(f"{reg_depths_um_a.min()=} {reg_depths_um_a.max()=}") - print(f"{reg_depths_um_b.min()=} {reg_depths_um_b.max()=}") - print(f"{motion_est.time_bin_centers_s.min()=}") - print(f"{motion_est.time_bin_centers_s.max()=}") - print(f"{motion_est.spatial_bin_centers_um.min()=}") - print(f"{motion_est.spatial_bin_centers_um.max()=}") - unreg_depths_um = np.concatenate( + print(f"get_shift_and_unit_pairs {np.min(chunk_time_centers_s)=} {np.max(chunk_time_centers_s)=}") + print(f"get_shift_and_unit_pairs {reg_depths_um.min()=} {reg_depths_um.max()=}") + print(f"get_shift_and_unit_pairs {reg_depths_um_a.shape=} {reg_depths_um_b.shape=}") + print(f"get_shift_and_unit_pairs {reg_depths_um_a.min()=} {reg_depths_um_a.max()=}") + print(f"get_shift_and_unit_pairs {reg_depths_um_b.min()=} {reg_depths_um_b.max()=}") + print(f"get_shift_and_unit_pairs {motion_est.time_bin_centers_s.min()=}") + print(f"get_shift_and_unit_pairs {motion_est.time_bin_centers_s.max()=}") + print(f"get_shift_and_unit_pairs {motion_est.spatial_bin_centers_um.min()=}") + print(f"get_shift_and_unit_pairs {motion_est.spatial_bin_centers_um.max()=}") + unreg_depths_um = np.stack( [ - motion_est.disp_at_s(t_s, depth_um=reg_depths_um, grid=True).T + invert_motion_estimate( + motion_est, t_s, reg_depths_um + ) for t_s in chunk_time_centers_s ], axis=0, ) + print(f"get_shift_and_unit_pairs {reg_depths_um.shape=} {unreg_depths_um.shape=}") assert unreg_depths_um.shape == (len(chunk_time_centers_s), len(reg_depths_um)) + diff = reg_depths_um - unreg_depths_um + print(f"get_shift_and_unit_pairs {unreg_depths_um.min()=} {unreg_depths_um.max()=}") + print(f"get_shift_and_unit_pairs {diff.min()=} {diff.max()=}") pitch_shifts = get_spike_pitch_shifts( depths_um=reg_depths_um, pitch=get_pitch(geom), @@ -569,6 +580,9 @@ def get_shift_and_unit_pairs( else: shifts_a = pitch_shifts[:, :na] shifts_b = pitch_shifts[:, na:] + print(f"get_shift_and_unit_pairs {shifts_a.min()=} {shifts_a.max()=}") + print(f"get_shift_and_unit_pairs {shifts_b.min()=} {shifts_b.max()=}") + print(f"get_shift_and_unit_pairs {shifts_a.shape=} {shifts_b.shape=}") # assign ids to pitch/shift pairs template_shift_index_a = TemplateShiftIndex.from_shift_matrix(shifts_a) diff --git a/src/dartsort/util/spiketorch.py b/src/dartsort/util/spiketorch.py index 7824cd86..533b4a07 100644 --- a/src/dartsort/util/spiketorch.py +++ b/src/dartsort/util/spiketorch.py @@ -320,10 +320,15 @@ def depthwise_oaconv1d(input, weight, f2=None, padding=0): oa = fold_res.reshape(n1, fold_out_len) # this is the full convolution - oa = oa[:, : shape_final - pad1] + print(f"oaconv orig {oa.shape=}") + # oa = oa[:, : shape_final] + print(f"oaconv full {oa.shape=}") # extract correct padding - padding = padding + s2 - 1 - assert oa.shape[1] > 2 * padding - oa = oa[:, padding:oa.shape[1] - padding] + valid_len = s1 - s2 + 1 + valid_start = s2 - 1 + assert valid_start >= padding + oa = oa[:, valid_start - padding:valid_start + valid_len + padding] + print(f"oaconv {oa.shape=} {valid_len=} {valid_start=} {padding=}") + print(f"oaconv {(valid_start - 
padding)=} {(valid_start + valid_len + padding)=}") return oa From e7eb19f880ef6bde4056d4f46eb93bc4173d2439 Mon Sep 17 00:00:00 2001 From: Charlie Windolf Date: Tue, 28 Nov 2023 17:38:05 -0500 Subject: [PATCH 41/49] Sorta functional --- src/dartsort/localize/localize_torch.py | 2 + src/dartsort/main.py | 2 +- src/dartsort/peel/matching.py | 103 +++++++++-------- src/dartsort/templates/pairwise.py | 19 +--- src/dartsort/templates/pairwise_util.py | 145 +++--------------------- src/dartsort/templates/template_util.py | 5 - src/dartsort/templates/templates.py | 8 +- src/dartsort/util/drift_util.py | 21 +--- src/dartsort/util/spiketorch.py | 6 +- 9 files changed, 77 insertions(+), 234 deletions(-) diff --git a/src/dartsort/localize/localize_torch.py b/src/dartsort/localize/localize_torch.py index 56a35024..42741034 100644 --- a/src/dartsort/localize/localize_torch.py +++ b/src/dartsort/localize/localize_torch.py @@ -103,6 +103,8 @@ def localize_amplitude_vectors( geom_pad = F.pad(geom, (0, 0, 0, 1)) local_geoms = geom_pad[channel_index[main_channels]] local_geoms[:, :, 1] -= geom[main_channels, 1][:, None] + print(f"{amplitude_vectors.shape=}") + print(f"{local_geoms.shape=}") # center of mass initialization com = torch.divide( diff --git a/src/dartsort/main.py b/src/dartsort/main.py index b441cc52..f05a328c 100644 --- a/src/dartsort/main.py +++ b/src/dartsort/main.py @@ -169,7 +169,7 @@ def _run_peeler( ) # do localization - if featurization_config.do_localization: + if not featurization_config.denoise_only and featurization_config.do_localization: wf_name = featurization_config.output_waveforms_name localize_hdf5( output_hdf5_filename, diff --git a/src/dartsort/peel/matching.py b/src/dartsort/peel/matching.py index 48abc064..24e56859 100644 --- a/src/dartsort/peel/matching.py +++ b/src/dartsort/peel/matching.py @@ -285,11 +285,16 @@ def build_template_data( self.handle_template_groups( coarse_template_data.unit_ids, self.template_data.unit_ids ) + convlen = self.chunk_length_samples + self.chunk_margin_samples + block_size, *_ = spiketorch._calc_oa_lens(convlen, self.spike_length_samples) + self.register_buffer("objective_temporalf", torch.fft.rfft(self.objective_temporal_components, dim=1, n=block_size)) half_chunk = self.chunk_length_samples // 2 - chunk_centers_samples = np.arange( - half_chunk, self.recording.get_num_samples(), self.chunk_length_samples + chunk_starts = np.arange( + 0, self.recording.get_num_samples(), self.chunk_length_samples ) + chunk_ends = np.minimum(chunk_starts + self.chunk_length_samples, self.recording.get_num_samples()) + chunk_centers_samples = (chunk_starts + chunk_ends) / 2 chunk_centers_s = self.recording._recording_segments[0].sample_index_to_time( chunk_centers_samples ) @@ -392,6 +397,7 @@ def peel_chunk( left_margin=0, right_margin=0, return_residual=False, + return_conv=False, ): # get current template set chunk_center_samples = chunk_start_samples + self.chunk_length_samples // 2 @@ -404,11 +410,12 @@ def peel_chunk( match_results = self.match_chunk( traces, compressed_template_data, - trough_offset_samples=42, - left_margin=0, - right_margin=0, - threshold=30, + trough_offset_samples=self.trough_offset_samples, + left_margin=left_margin, + right_margin=right_margin, + threshold=self.threshold, return_residual=return_residual, + return_conv=return_conv, ) # process spike times and create return result @@ -459,15 +466,17 @@ def templates_at_time(self, t_s): pconvdb = pconvdb.at_shifts(pitch_shifts_a, pitch_shifts_b) else: cur_spatial = 
self.spatial_components + cur_obj_spatial = self.objective_spatial_components max_channels = self.registered_template_ampvecs.argmax(1) if not pconvdb._is_torch: - pconvdb.to(cur_spatial.device) + pconvdb.to(cur_obj_spatial.device) return MatchingTemplateData( objective_spatial_components=cur_obj_spatial, objective_singular_values=self.objective_singular_values, objective_temporal_components=self.objective_temporal_components, + objective_temporalf=self.objective_temporalf, fine_to_coarse=self.fine_to_coarse, coarse_objective=self.coarse_objective, spatial_components=cur_spatial, @@ -477,7 +486,7 @@ def templates_at_time(self, t_s): compressed_upsampling_index=self.compressed_upsampling_index, compressed_index_to_upsampling_index=self.compressed_index_to_upsampling_index, compressed_upsampled_temporal=self.compressed_upsampled_temporal, - max_channels=max_channels, + max_channels=torch.as_tensor(max_channels, device=cur_obj_spatial.device), pairwise_conv_db=pconvdb, ) @@ -490,12 +499,12 @@ def match_chunk( right_margin=0, threshold=30, return_residual=False, + return_conv=False, ): """Core peeling routine for subtraction""" # initialize residual, it needs to be padded to support our channel # indexing convention (used later to extract small channel # neighborhoods). this copies the input. - print(f"match_chunk {traces.shape=}") residual_padded = F.pad(traces, (0, 1), value=torch.nan) residual = residual_padded[:, :-1] @@ -517,7 +526,6 @@ def match_chunk( refrac_mask = torch.zeros_like(padded_objective) # padded objective has an extra unit (for group_index) and refractory # padding (for easier implementation of enforce_refractory) - neg_temp_normsq = -compressed_template_data.objective_template_norms_squared[:, None] # manages buffers for spike train data (peak times, labels, etc) peaks = MatchingPeaks(device=traces.device) @@ -528,28 +536,13 @@ def match_chunk( ) # main loop - print("start") for it in range(self.max_iter): - # update the coarse objective - torch.add( - neg_temp_normsq, padded_conv, alpha=2.0, out=padded_objective[:-1] - ) - # find high-res peaks - print("before find") new_peaks = self.find_peaks( residual, padded_conv, padded_objective, refrac_mask, compressed_template_data ) if new_peaks is None: break - # print("----------") - # if not it % 1: - # if new_peaks.n_spikes > 1: - # print( - # f"{it=} {new_peaks.n_spikes=} {new_peaks.scores.mean().numpy(force=True)=} {torch.diff(new_peaks.times).min()=}" - # ) - # tq = new_peaks.times.numpy(force=True) - # print(f"{np.diff(tq).min()=} {tq=}") # enforce refractoriness self.enforce_refractory( @@ -579,7 +572,8 @@ def match_chunk( # new_norm = torch.linalg.norm(residual) ** 2 # print(f"{it=} {new_norm=}") # print(f"{(new_norm-old_norm)=}") - # print(f"{new_peaks.scores.sum().numpy(force=True)=}") + # print(f"{new_peaks.n_spikes=}") + # print(f"{new_peaks.scores.mean().numpy(force=True)=}") # print("----------") # update spike train @@ -607,6 +601,8 @@ def match_chunk( ) if return_residual: res["residual"] = residual + if return_conv: + res["conv"] = padded_conv return res def find_peaks( @@ -617,6 +613,14 @@ def find_peaks( refrac_mask, compressed_template_data, ): + # update the coarse objective + torch.add( + compressed_template_data.objective_template_norms_squared.neg()[:, None], + padded_conv, + alpha=2.0, + out=padded_objective[:-1], + ) + # first step: coarse peaks. not temporally upsampled or amplitude-scaled. 
objective = (padded_objective + refrac_mask)[ :-1, self.obj_pad_len : -self.obj_pad_len @@ -624,8 +628,6 @@ def find_peaks( # formerly used detect_and_deduplicate, but that was slow. objective_max, max_obj_template = objective.max(dim=0) times = argrelmax(objective_max, self.spike_length_samples, self.threshold) - # tt = times.numpy(force=True) - # print(f"{np.diff(tt).min()=} {tt=}") obj_template_indices = max_obj_template[times] # remove peaks inside the padding if not times.numel(): @@ -648,7 +650,7 @@ def find_peaks( template_indices, scores, ) = compressed_template_data.fine_match( - padded_conv[obj_template_indices, times], + padded_conv[obj_template_indices, times + self.obj_pad_len], objective_max[times], residual_snips, obj_template_indices, @@ -659,7 +661,7 @@ def find_peaks( ) if time_shifts is not None: times += time_shifts - + return MatchingPeaks( n_spikes=times.numel(), times=times, @@ -697,6 +699,7 @@ class MatchingTemplateData: objective_spatial_components: torch.Tensor objective_singular_values: torch.Tensor objective_temporal_components: torch.Tensor + objective_temporalf: torch.Tensor fine_to_coarse: torch.LongTensor coarse_objective: bool spatial_components: torch.Tensor @@ -758,24 +761,23 @@ def convolve(self, traces, padding=0, out=None): ) else: assert out.shape == (self.objective_n_templates, out_len) - print(f"convolve {traces.shape=} {out.shape=} {out_len=} {padding=}") for q in range(self.rank): # units x time rec_spatial = self.objective_spatial_singular[:, q, :] @ traces.T # convolve with temporal components -- units x time temporal = self.objective_temporal_components[:, :, q] + temporalf = self.objective_temporalf[:, :, q] # conv1d with groups! only convolve each unit with its own temporal filter. - # conv = F.conv1d( - # rec_spatial[None], - # temporal[:, None, :], - # groups=self.n_templates, - # padding=padding, - # )[0] - conv = spiketorch.depthwise_oaconv1d( - rec_spatial, temporal[:, :], padding=padding - ) - print(f"{conv.shape=}") + conv = F.conv1d( + rec_spatial[None], + temporal[:, None, :], + groups=self.objective_n_templates, + padding=padding, + )[0] + # conv = spiketorch.depthwise_oaconv1d( + # rec_spatial, temporal, padding=padding, f2=temporalf + # ) if q: out += conv else: @@ -803,7 +805,6 @@ def subtract_conv( ) ix_template = template_indices_a[:, None] ix_time = times[:, None] + (conv_pad_len + self.conv_lags)[None, :] - print(f"{ix_template.shape=} {ix_time.shape=} {scalings.shape=} {pconvs.shape=}") spiketorch.add_at_( conv, (ix_template, ix_time), @@ -858,13 +859,12 @@ def fine_match( spike_length_samples = window_length_samples - 1 # grab the current traces snips = residual_snips[:, 1:] - # unpack the current traces and the traces one step back - snips_dt = F.unfold( - residual_snips[:, None, :, :], (spike_length_samples, snips.shape[2]) - ) - snips_dt = snips_dt.reshape( - len(snips), spike_length_samples, snips.shape[2], 2 - ) + # snips_dt = F.unfold( + # residual_snips[:, None, :, :], (spike_length_samples, snips.shape[2]) + # ) + # snips_dt = snips_dt.reshape( + # len(snips), spike_length_samples, snips.shape[2], 2 + # ) if self.coarse_objective: # TODO best I came up with, but it still syncs @@ -900,6 +900,10 @@ def fine_match( objs = 2 * scalings * b - torch.square(scalings) * a - inv_lambda return None, None, scalings, template_indices, objs + # unpack the current traces and the traces one step back + snips_prev = residual_snips[:, :-1] + snips_dt = torch.stack((snips_prev, snips), dim=3) + # now, upsampling # repeat the 
superres logic, the comp up index acts the same comp_up_ix = self.compressed_upsampling_index[template_indices] @@ -924,7 +928,6 @@ def fine_match( 2 * scalings * b - torch.square(scalings) * a - inv_lambda ) else: - print(f"{objs.shape=} {objs[dup_ix, column_ix].shape=} {convs.shape=} {norms.shape=}") objs[dup_ix, column_ix] = 2 * convs - norms[:, None] scalings = None objs, best_column_dt_ix = objs.reshape(len(objs), comp_up_ix.shape[1] * 2).max(dim=1) diff --git a/src/dartsort/templates/pairwise.py b/src/dartsort/templates/pairwise.py index 67f689da..eba3dd86 100644 --- a/src/dartsort/templates/pairwise.py +++ b/src/dartsort/templates/pairwise.py @@ -86,7 +86,6 @@ def from_template_data( n_jobs=0, show_progress=True, ): - print(f"pairwise from_template_data {device=}") compressed_convolve_to_h5( hdf5_filename, template_data=template_data, @@ -124,22 +123,9 @@ def at_shifts(self, shifts_a=None, shifts_b=None): n_shifted_temps_a, n_up_shifted_temps_b = self.pconv_index.shape # active shifted and upsampled indices + print(f"{self.shifts_a=} {shifts_a.min()=} {shifts_a.max()=}") shift_ix_a = np.searchsorted(self.shifts_a, shifts_a) shift_ix_b = np.searchsorted(self.shifts_b, shifts_b) - print( - f"at_shifts {self.shifts_a.shape=} {self.shifts_a.min()=} {self.shifts_a.max()=}" - ) - print(f"at_shifts {shifts_a.shape=} {shifts_a.min()=} {shifts_a.max()=}") - print(f"{shift_ix_a.shape=} {shift_ix_a.min()=} {shift_ix_a.max()=}") - - print( - f"at_shifts {self.shifts_b.shape=} {self.shifts_b.min()=} {self.shifts_b.max()=}" - ) - print(f"at_shifts {shifts_b.shape=} {shifts_b.min()=} {shifts_b.max()=}") - print(f"at_shifts {shift_ix_b.shape=} {shift_ix_b.min()=} {shift_ix_b.max()=}") - - print(f"at_shifts {self.shifted_template_index_a.shape=}") - print(f"at_shifts {self.upsampled_shifted_template_index_b.shape=}") sub_shifted_temp_index_a = self.shifted_template_index_a[ np.arange(len(self.shifted_template_index_a))[:, None], shift_ix_a[:, None], @@ -148,8 +134,6 @@ def at_shifts(self, shifts_a=None, shifts_b=None): np.arange(len(self.upsampled_shifted_template_index_b))[:, None], shift_ix_b[:, None], ] - print(f"at_shifts {sub_shifted_temp_index_a.shape=}") - print(f"at_shifts {sub_up_shifted_temp_index_b.shape=}") # in flat form for indexing into pconv_index. also, reindex. 
valid_a = sub_shifted_temp_index_a < n_shifted_temps_a @@ -262,7 +246,6 @@ def query( template_indices_a, template_indices_b ).T if scalings_b is not None: - print(f"{scalings_b.shape=} {pconv_indices.shape=}") scalings_b = torch.broadcast_to(scalings_b[None], pconv_indices.shape).reshape(-1) if times_b is not None: times_b = torch.broadcast_to(times_b[None], pconv_indices.shape).reshape(-1) diff --git a/src/dartsort/templates/pairwise_util.py b/src/dartsort/templates/pairwise_util.py index d20e7a4c..9f0f45eb 100644 --- a/src/dartsort/templates/pairwise_util.py +++ b/src/dartsort/templates/pairwise_util.py @@ -67,7 +67,6 @@ def compressed_convolve_to_h5( template_data_b=template_data_b, motion_est=motion_est, ) - print(f"compressed_convolve_to_h5 {conv_batch_size=} {units_batch_size=} {device=}") chunk_res_iterator = iterate_compressed_pairwise_convolutions( template_data_a=template_data, @@ -193,9 +192,6 @@ def iterate_compressed_pairwise_convolutions( assert np.array_equal(reg_geom, template_data_b.registered_geom) # construct drift-related helper data if needed - print( - f"iterate_compressed_pairwise_convolutions {conv_batch_size=} {units_batch_size=} {device=}" - ) geom_kdtree = reg_geom_kdtree = match_distance = None if do_shifting: geom_kdtree = KDTree(geom) @@ -315,12 +311,6 @@ def compressed_convolve_pairs( shifts, superres templates, and upsamples. Some of these may be zero or may be duplicates, so the return value is a sparse representation. See below. """ - # print(f"compressed_convolve_pairs {device=}") - # print(f"{units_a.shape=}") - # print(f"{units_b.shape=}") - # print(f"{(units_a.size * units_b.size)=}") - # print(f"compressed_convolve_pairs {batch_size=} {units_a.size=} {device=}") - # what pairs, shifts, etc are we convolving? shifted_temp_ix_a, temp_ix_a, shift_a, unit_a = handle_shift_indices( units_a, template_data_a.unit_ids, template_shift_index_a @@ -328,8 +318,6 @@ def compressed_convolve_pairs( shifted_temp_ix_b, temp_ix_b, shift_b, unit_b = handle_shift_indices( units_b, template_data_b.unit_ids, template_shift_index_b ) - # print(f"0 {shifted_temp_ix_a.shape=} {(shifted_temp_ix_a.size / np.unique(unit_a).size)=}") - # print(f"0 {shifted_temp_ix_b.shape=} {(shifted_temp_ix_b.size / np.unique(unit_b).size)=}") # get (shifted) spatial components * singular values spatial_singular_a = get_shifted_spatial_singular( @@ -354,9 +342,6 @@ def compressed_convolve_pairs( match_distance=match_distance, device=device, ) - # print(f"{low_rank_templates.spatial_components.dtype=} {low_rank_templates.singular_values.dtype=}") - # print(f"{compressed_upsampled_temporal.compressed_upsampled_templates.dtype=}") - # print(f"{spatial_singular_a.dtype=} {spatial_singular_b.dtype=}") # figure out pairs of shifted templates to convolve in a deduplicated way pairs_ret = shift_deduplicated_pairs( @@ -378,17 +363,7 @@ def compressed_convolve_pairs( if pairs_ret is None: return None ix_a, ix_b, compression_index, conv_ix, spatial_shift_ids = pairs_ret - # print(f"A {ix_a.shape=}") - # print(f"A {ix_b.shape=}") - # print(f"A {compression_index.shape=}") - # print(f"A {conv_ix.shape=}") - # print(f"A {spatial_shift_ids.shape=}") - - # print(f"-----------") - # print(f"after pairs {conv_ix.shape=} {compression_index.shape=}") - # print(f"{compression_index.min()=} {compression_index.max()=}") - # print(f"{ix_a.shape=} {ix_b.shape=}") - + # handle upsampling # each pair will be duplicated by the b unit's number of upsampled copies ( @@ -409,30 +384,9 @@ def compressed_convolve_pairs( ) ix_a 
= ix_a[compression_dup_ix] spatial_shift_ids = spatial_shift_ids[compression_dup_ix] - # print(f"B {ix_a.shape=}") - # print(f"B {ix_b.shape=}") - # print(f"B {compression_index.shape=}") - # print(f"B {conv_ix.shape=}") - - # print(f"-----------") - # print(f"after up {conv_ix.shape=} {compression_index.shape=}") - # print(f"{compression_index.min()=} {compression_index.max()=}") - # print(f"{ix_a.shape=} {ix_b.shape=}") - - # # now, these arrays all have length n_pairs - # shifted_temp_ix_a = shifted_temp_ix_a[ix_a] - # temp_ix_a = temp_ix_a[ix_a] - # shift_a = shift_a[ix_a] - # shifted_temp_ix_b = shifted_temp_ix_b[ix_b] - # temp_ix_b = temp_ix_b[ix_b] - # shift_b = shift_b[ix_b] # run convolutions temporal_a = low_rank_templates_a.temporal_components[temp_ix_a] - # print(f"{spatial_singular_a[ix_a[conv_ix]].shape=}") - # print(f"{spatial_singular_b[ix_b[conv_ix]].shape=}") - # print(f"{temporal_a[ix_a[conv_ix]].shape=}") - # print(f"{conv_temporal_components_up_b.shape=}") pconv, kept = correlate_pairs_lowrank( torch.as_tensor(spatial_singular_a[ix_a[conv_ix]], device=device), torch.as_tensor(spatial_singular_b[ix_b[conv_ix]], device=device), @@ -442,8 +396,6 @@ def compressed_convolve_pairs( conv_ignore_threshold=conv_ignore_threshold, batch_size=batch_size, ) - # print(f"-----------") - # print(f"after corr {pconv.shape=} {conv_ix[kept].shape=}") if kept is not None: conv_ix = conv_ix[kept] if not conv_ix.shape[0]: @@ -455,29 +407,16 @@ def compressed_convolve_pairs( ix_b = ix_b[kept_pairs] spatial_shift_ids = spatial_shift_ids[kept_pairs] assert pconv.numel() > 0 - # print(f"-----------") - # print(f"after searchsorted {pconv.shape=} {conv_ix.shape=} {compression_index.shape=}") - # print(f"{compression_index.min()=} {compression_index.max()=}") - # print(f"{ix_a.shape=} {ix_b.shape=}") # coarse approx - # print(f"-----------") - # print(f"before approx {pconv.shape=} {conv_ix.shape=} {compression_index.shape=}") pconv, old_ix_to_new_ix = coarse_approximate( pconv, unit_a[ix_a[conv_ix]], unit_b[ix_b[conv_ix]], temp_ix_a[ix_a[conv_ix]], - # shift_a[ix_a[conv_ix]], - # shift_b[ix_b[conv_ix]], spatial_shift_ids[conv_ix], coarse_approx_error_threshold=coarse_approx_error_threshold, ) - # print(f"-----------") - # print(f"after approx") - # print(f"{pconv.shape=} {conv_ix.shape=} {old_ix_to_new_ix.shape=} {compression_index.shape=}") - # print(f"{compression_index.min()=} {compression_index.max()=}") - # print(f"{old_ix_to_new_ix.min()=} {old_ix_to_new_ix.max()=}") compression_index = old_ix_to_new_ix[compression_index] # above function invalidates the whole idea of conv_ix del conv_ix @@ -541,8 +480,7 @@ def correlate_pairs_lowrank( assert n_pairs == n_pairs_ assert t == t_ assert rank == rank_ - # print(f"compressed_convolve_pairs {batch_size=} {n_pairs=} {spatial_a.device=}") - + if max_shift == "full": max_shift = t - 1 elif max_shift == "valid": @@ -554,7 +492,6 @@ def correlate_pairs_lowrank( pconv = torch.zeros( (n_pairs, 2 * max_shift + 1), dtype=spatial_a.dtype, device=spatial_a.device ) - # print(f"compressed_convolve_pairs {pconv.shape=}") for istart in range(0, n_pairs, batch_size): iend = min(istart + batch_size, n_pairs) ix = slice(istart, iend) @@ -582,10 +519,8 @@ def correlate_pairs_lowrank( kept = max_val > conv_ignore_threshold pconv = pconv[kept] kept = np.flatnonzero(kept.numpy(force=True)) - # print(f"compressed_convolve_pairs {pconv.shape=} {kept.shape=}") else: kept = None - # print(f"compressed_convolve_pairs {pconv.shape=} {kept=}") return pconv, kept @@ -721,7 
+656,6 @@ def shift_deduplicated_pairs( pair = chan_amp_a @ chan_amp_b.T pair = pair > conv_ignore_threshold pair = pair.cpu() - # print(f"___ after overlaps {pair.sum()=}") # co-occurrence cooccurrence = cooccurrence[ @@ -729,13 +663,11 @@ def shift_deduplicated_pairs( shifted_temp_ix_b[None, :], ] pair *= torch.as_tensor(cooccurrence, device=pair.device) - # print(f"___ after cooccur {pair.sum()=}") pair_ix_a, pair_ix_b = torch.nonzero(pair, as_tuple=True) nco = pair_ix_a.numel() if not nco: return None - # print(f"___ {nco=}") # if no shifting, deduplication is the identity do_shifting = reg_geom_kdtree is not None @@ -778,10 +710,6 @@ def shift_deduplicated_pairs( chanset_b, active_chan_ids_b = np.unique( active_chans_b, axis=0, return_inverse=True ) - # print(f"___ {chanset_a.sum(1)=}") - # print(f"___ {chanset_b.sum(1)=}") - # print(f"___ {active_chan_ids_a.shape=} {np.unique(active_chan_ids_a).shape=}") - # print(f"___ {active_chan_ids_b.shape=} {np.unique(active_chan_ids_b).shape=}") # 3 temp_ix_a = temp_ix_a[pair_ix_a] @@ -789,13 +717,6 @@ def shift_deduplicated_pairs( # get the relative shifts shift_a = shift_a[pair_ix_a] shift_b = shift_b[pair_ix_b] - # print(f"{temp_ix_a=}") - # print(f"{shift_a=}") - # print(f"{active_chan_ids_a[pair_ix_a]=}") - # print(f"{temp_ix_b=}") - # print(f"{shift_b=}") - # print(f"{active_chan_ids_b[pair_ix_b]=}") - # print(f"{shift_diff=}") # figure out combinations _, spatial_shift_ids = np.unique( @@ -812,7 +733,6 @@ def shift_deduplicated_pairs( temp_ix_b, spatial_shift_ids, ] - # print(f"{conv_determiners=}") # conv_ix: indices of unique determiners # compression_index: which representative does each pair belong to _, conv_ix, compression_index = np.unique( @@ -910,7 +830,8 @@ def compressed_upsampled_pairs( convolutions between templates ix_a[i], ix_b[i] equal that between templates ix_a[conv_ix[compression_index[i]]], ix_b[conv_ix[compression_index[i]]]. - We will upsample the templates in the RHS (b) in a compressed way. + We will upsample the templates in the RHS (b) in a compressed way, so that + each b index gets its own number of duplicates. """ up_factor = compressed_upsampled_temporal.compressed_upsampling_map.shape[1] compression_dup_ix = slice(None) @@ -930,61 +851,25 @@ def compressed_upsampled_pairs( ) conv_up_i, up_shift_up_i = np.nonzero(upsampling_mask) conv_compressed_upsampled_ix = ( - upsampled_shifted_template_index.up_shift_temp_ix_to_shift_temp_ix[ + upsampled_shifted_template_index.up_shift_temp_ix_to_comp_up_ix[ up_shift_up_i ] ) - conv_ix_up = conv_ix[conv_up_i] + conv_dup = conv_ix[conv_up_i] # And, all ix_{a,b}[i] such that compression_ix[i] lands in # that conv_ix need to be duplicated as well. 
- dup_mask = conv_ix[compression_index][:, None] == conv_ix_up[None, :] + dup_mask = conv_ix[compression_index][:, None] == conv_dup[None, :] if torch.is_tensor(dup_mask): dup_mask = dup_mask.numpy(force=True) compression_dup_ix, compression_index_up = np.nonzero(dup_mask) ix_b_up = ix_b[compression_dup_ix] - # ix_a_up = np.repeat(ix_a, ndups) - # ix_b_up = np.repeat(ix_b, ndups) - - # ix_a_up = np.zeros(len(ix_a) * up_factor, dtype=int) - # ix_b_up = np.zeros(len(ix_a) * up_factor, dtype=int) - # compression_index_up = np.zeros(len(ix_a) * up_factor, dtype=int) - # conv_ix_up = np.zeros(len(conv_ix) * up_factor, dtype=int) - # conv_compressed_upsampled_ix = np.zeros(len(conv_ix) * up_factor, dtype=int) - # cur_dedup_ix = 0 - # cur_pair_ix = 0 - # cur_conv_up_ix = 0 - # for i, convi in enumerate(conv_ix): - # # get b's shifted template ix - # conv_shifted_temp_ix_b = shifted_temp_ix_b[ix_b[convi]] - - # # which compressed upsampled indices match this? - # which_up = np.flatnonzero( - # upsampled_shifted_template_index.up_shift_temp_ix_to_shift_temp_ix - # == conv_shifted_temp_ix_b - # ) - # conv_comp_up_ix = ( - # upsampled_shifted_template_index.up_shift_temp_ix_to_comp_up_ix[which_up] - # ) - - # # which deduplication indices map ix_a,b to this convi? - # which_dedup = np.flatnonzero(compression_index == i) - - # # extend arrays with new indices - # nupi = conv_comp_up_ix.size - # n_new_pair = which_dedup.size * nupi - # ix_a_up[cur_pair_ix:cur_pair_ix+n_new_pair] = np.repeat(ix_a[which_dedup], nupi) - # ix_b_up[cur_pair_ix:cur_pair_ix+n_new_pair] = np.repeat(ix_b[which_dedup], nupi) - # conv_ix_up[cur_dedup_ix:cur_dedup_ix+nupi]=convi - # conv_compressed_upsampled_ix[cur_dedup_ix:cur_dedup_ix+nupi]=conv_comp_up_ix - # compression_index_up[cur_pair_ix:cur_pair_ix+n_new_pair]=np.tile(np.arange(cur_dedup_ix, cur_dedup_ix + nupi), which_dedup.size) - # cur_pair_ix += n_new_pair - # cur_dedup_ix += nupi - - # ix_a_up = ix_a_up[:cur_pair_ix] #np.array(ix_a_up) - # ix_b_up = ix_b_up[:cur_pair_ix] #np.array(ix_b_up) - # conv_compressed_upsampled_ix = conv_compressed_upsampled_ix[:cur_pair_ix] #np.array(conv_compressed_upsampled_ix) - # compression_index_up = compression_index_up[:cur_dedup_ix] #np.array(compression_index_up) - # conv_ix_up = conv_ix_up[:cur_dedup_ix] #np.array(conv_ix_up) + + # the conv ix need to be offset to keep the relation with the pairs + # ix_a[old i] + # offsets = np.cumsum((conv_ix[:, None] == conv_dup[None, :]).sum(0)) + # offsets -= offsets[0] + _, offsets = np.unique(compression_dup_ix, return_index=True) + conv_ix_up = offsets[conv_dup] # which upsamples and which templates? 
conv_upsampling_indices_b = ( @@ -997,7 +882,7 @@ def compressed_upsampled_pairs( conv_compressed_upsampled_ix ] ) - + return ( ix_b_up, compression_index_up, diff --git a/src/dartsort/templates/template_util.py b/src/dartsort/templates/template_util.py index af28a958..57f303bd 100644 --- a/src/dartsort/templates/template_util.py +++ b/src/dartsort/templates/template_util.py @@ -168,17 +168,12 @@ def templates_at_time( unregistered_depths_um = drift_util.invert_motion_estimate( motion_est, t_s, registered_template_depths_um ) - print(f"templates_at_time {registered_template_depths_um.min()=} {registered_template_depths_um.max()=}") - print(f"templates_at_time {unregistered_depths_um.min()=} {unregistered_depths_um.max()=}") - diff = unregistered_depths_um - registered_template_depths_um - print(f"templates_at_time {diff.min()=} {diff.max()=}") # reverse arguments to pitch shifts since we are going the other direction pitch_shifts = drift_util.get_spike_pitch_shifts( depths_um=registered_template_depths_um, geom=geom, registered_depths_um=unregistered_depths_um, ) - print(f"templates_at_time {pitch_shifts.min()=} {pitch_shifts.max()=}") # extract relevant channel neighborhoods, also by reversing args to a drift helper unregistered_templates = drift_util.get_waveforms_on_static_channels( registered_templates, diff --git a/src/dartsort/templates/templates.py b/src/dartsort/templates/templates.py index 5decccb8..1abeee9d 100644 --- a/src/dartsort/templates/templates.py +++ b/src/dartsort/templates/templates.py @@ -65,9 +65,6 @@ def coarsen(self): self.registered_geom, localization_radius_um=self.localization_radius_um, ) - print(f"coarsen {self.registered_geom[:,1].min()=} {self.registered_geom[:,1].max()=}") - print(f"coarsen {self.registered_template_depths_um.min()=} {self.registered_template_depths_um.max()=}") - print(f"coarsen {registered_template_depths_um.min()=} {registered_template_depths_um.max()=}") return replace( self, @@ -179,11 +176,10 @@ def from_config( # main! 
results = get_templates(recording, sorting, **kwargs) - print( - f"{[(k,v.dtype) for k,v in results.items() if (isinstance(v, np.ndarray))]=}" - ) # handle registered templates + print(f"{results['templates'].shape=}") + print(f"{kwargs['registered_geom'].shape=}") if template_config.registered_templates: registered_template_depths_um = get_template_depths( results["templates"], diff --git a/src/dartsort/util/drift_util.py b/src/dartsort/util/drift_util.py index 4013619f..0bf4464a 100644 --- a/src/dartsort/util/drift_util.py +++ b/src/dartsort/util/drift_util.py @@ -220,13 +220,9 @@ def get_spike_pitch_shifts( else: probe_displacement = registered_depths_um - depths_um - print(f"get_spike_pitch_shifts {pitch=}") - print(f"get_spike_pitch_shifts {probe_displacement.min()=} {probe_displacement.max()=}") - # if probe_displacement > 0, then the registered position is below the original # and, to be conservative, round towards 0 rather than using // n_pitches_shift = (probe_displacement / pitch).astype(int) - print(f"get_spike_pitch_shifts {n_pitches_shift.min()=} {n_pitches_shift.max()=}") return n_pitches_shift @@ -547,15 +543,6 @@ def get_shift_and_unit_pairs( reg_depths_um = np.concatenate((reg_depths_um_a, reg_depths_um_b)) # figure out all shifts for all units at all times - print(f"get_shift_and_unit_pairs {np.min(chunk_time_centers_s)=} {np.max(chunk_time_centers_s)=}") - print(f"get_shift_and_unit_pairs {reg_depths_um.min()=} {reg_depths_um.max()=}") - print(f"get_shift_and_unit_pairs {reg_depths_um_a.shape=} {reg_depths_um_b.shape=}") - print(f"get_shift_and_unit_pairs {reg_depths_um_a.min()=} {reg_depths_um_a.max()=}") - print(f"get_shift_and_unit_pairs {reg_depths_um_b.min()=} {reg_depths_um_b.max()=}") - print(f"get_shift_and_unit_pairs {motion_est.time_bin_centers_s.min()=}") - print(f"get_shift_and_unit_pairs {motion_est.time_bin_centers_s.max()=}") - print(f"get_shift_and_unit_pairs {motion_est.spatial_bin_centers_um.min()=}") - print(f"get_shift_and_unit_pairs {motion_est.spatial_bin_centers_um.max()=}") unreg_depths_um = np.stack( [ invert_motion_estimate( @@ -565,11 +552,8 @@ def get_shift_and_unit_pairs( ], axis=0, ) - print(f"get_shift_and_unit_pairs {reg_depths_um.shape=} {unreg_depths_um.shape=}") assert unreg_depths_um.shape == (len(chunk_time_centers_s), len(reg_depths_um)) diff = reg_depths_um - unreg_depths_um - print(f"get_shift_and_unit_pairs {unreg_depths_um.min()=} {unreg_depths_um.max()=}") - print(f"get_shift_and_unit_pairs {diff.min()=} {diff.max()=}") pitch_shifts = get_spike_pitch_shifts( depths_um=reg_depths_um, pitch=get_pitch(geom), @@ -580,9 +564,8 @@ def get_shift_and_unit_pairs( else: shifts_a = pitch_shifts[:, :na] shifts_b = pitch_shifts[:, na:] - print(f"get_shift_and_unit_pairs {shifts_a.min()=} {shifts_a.max()=}") - print(f"get_shift_and_unit_pairs {shifts_b.min()=} {shifts_b.max()=}") - print(f"get_shift_and_unit_pairs {shifts_a.shape=} {shifts_b.shape=}") + print(f"{shifts_a.min()=} {shifts_a.max()=}") + print(f"{shifts_b.min()=} {shifts_b.max()=}") # assign ids to pitch/shift pairs template_shift_index_a = TemplateShiftIndex.from_shift_matrix(shifts_a) diff --git a/src/dartsort/util/spiketorch.py b/src/dartsort/util/spiketorch.py index 533b4a07..aeb305b9 100644 --- a/src/dartsort/util/spiketorch.py +++ b/src/dartsort/util/spiketorch.py @@ -289,7 +289,7 @@ def depthwise_oaconv1d(input, weight, f2=None, padding=0): s2 = weight.shape[1] assert s1 >= s2 - shape_final = s1 + s2 - 1 + # shape_full = s1 + s2 - 1 block_size, overlap, in1_step, in2_step 
= _calc_oa_lens(s1, s2) nstep1, pad1, nstep2, pad2 = steps_and_pad( s1, in1_step, s2, in2_step, block_size, overlap @@ -320,15 +320,11 @@ def depthwise_oaconv1d(input, weight, f2=None, padding=0): oa = fold_res.reshape(n1, fold_out_len) # this is the full convolution - print(f"oaconv orig {oa.shape=}") # oa = oa[:, : shape_final] - print(f"oaconv full {oa.shape=}") # extract correct padding valid_len = s1 - s2 + 1 valid_start = s2 - 1 assert valid_start >= padding oa = oa[:, valid_start - padding:valid_start + valid_len + padding] - print(f"oaconv {oa.shape=} {valid_len=} {valid_start=} {padding=}") - print(f"oaconv {(valid_start - padding)=} {(valid_start + valid_len + padding)=}") return oa From e6a60c5579db063bbdef210210e1f151838cdda9 Mon Sep 17 00:00:00 2001 From: Charlie Windolf Date: Wed, 29 Nov 2023 18:03:47 -0500 Subject: [PATCH 42/49] Matching increasingly working --- src/dartsort/peel/matching.py | 43 +++++++--- src/dartsort/templates/pairwise.py | 121 ++++++++++++++++------------- src/dartsort/vis/scatterplots.py | 16 ++-- 3 files changed, 112 insertions(+), 68 deletions(-) diff --git a/src/dartsort/peel/matching.py b/src/dartsort/peel/matching.py index 24e56859..91b48332 100644 --- a/src/dartsort/peel/matching.py +++ b/src/dartsort/peel/matching.py @@ -112,10 +112,10 @@ def __init__( self.match_distance = pdist(self.geom).min() / 2.0 # some parts of this constructor are deferred to precompute_peeling_data - self._peeling_needs_fit = True + self._needs_precompute = True def peeling_needs_fit(self): - return self._peeling_needs_fit + return self._needs_precompute def precompute_peeling_data(self, save_folder, n_jobs=0, device=None): self.build_template_data( @@ -138,7 +138,7 @@ def precompute_peeling_data(self, save_folder, n_jobs=0, device=None): ) self.register_buffer("_rank_ix", torch.arange(self.svd_compression_rank)) self.check_shapes() - self._peeling_needs_fit = False + self._needs_precompute = False def out_datasets(self): datasets = super().out_datasets() @@ -426,6 +426,7 @@ def peel_chunk( def templates_at_time(self, t_s): """Handle drift -- grab the right spatial neighborhoods.""" pconvdb = self.pairwise_conv_db + pitch_shifts_a=pitch_shifts_b=None if self.is_drifting: pitch_shifts_b, cur_spatial = template_util.templates_at_time( t_s, @@ -463,14 +464,17 @@ def templates_at_time(self, t_s): fill_value=0.0, ) max_channels = cur_ampvecs[:, 0, :].argmax(1) - pconvdb = pconvdb.at_shifts(pitch_shifts_a, pitch_shifts_b) + # pconvdb = pconvdb.at_shifts(pitch_shifts_a, pitch_shifts_b) + pitch_shifts_a = torch.as_tensor(pitch_shifts_a, device=cur_obj_spatial.device) + pitch_shifts_b = torch.as_tensor(pitch_shifts_b, device=cur_obj_spatial.device) else: cur_spatial = self.spatial_components cur_obj_spatial = self.objective_spatial_components max_channels = self.registered_template_ampvecs.argmax(1) - if not pconvdb._is_torch: - pconvdb.to(cur_obj_spatial.device) + # if not pconvdb._is_torch: + # # pconvdb.to("cpu") + pconvdb.to(cur_obj_spatial.device) return MatchingTemplateData( objective_spatial_components=cur_obj_spatial, @@ -488,6 +492,8 @@ def templates_at_time(self, t_s): compressed_upsampled_temporal=self.compressed_upsampled_temporal, max_channels=torch.as_tensor(max_channels, device=cur_obj_spatial.device), pairwise_conv_db=pconvdb, + shifts_a=pitch_shifts_a, + shifts_b=pitch_shifts_b, ) def match_chunk( @@ -548,6 +554,7 @@ def match_chunk( self.enforce_refractory( refrac_mask, new_peaks.times + self.obj_pad_len, + new_peaks.objective_template_indices, 
new_peaks.template_indices, ) @@ -665,21 +672,22 @@ def find_peaks( return MatchingPeaks( n_spikes=times.numel(), times=times, + objective_template_indices=obj_template_indices, template_indices=template_indices, upsampling_indices=upsampling_indices, scalings=scalings, scores=scores, ) - def enforce_refractory(self, objective, times, template_indices): + def enforce_refractory(self, objective, times, objective_template_indices, template_indices): if not times.numel(): return # overwrite objective with -inf to enforce refractoriness time_ix = times[:, None] + self._refrac_ix[None, :] if not self.grouped_temps: - row_ix = template_indices[:, None, None] + row_ix = template_indices[:, None] elif self.coarse_objective: - row_ix = self.fine_to_coarse[template_indices][:, None, None] + row_ix = objective_template_indices[:, None] elif self.grouped_temps: row_ix = self.group_index[template_indices] else: @@ -711,6 +719,8 @@ class MatchingTemplateData: compressed_upsampled_temporal: torch.Tensor max_channels: torch.LongTensor pairwise_conv_db: CompressedPairwiseConv + shifts_a: Optional[torch.Tensor] + shifts_b: Optional[torch.Tensor] def __post_init__(self): ( @@ -802,6 +812,9 @@ def subtract_conv( scalings_b=scalings, times_b=times, grid=True, + device=conv.device, + shifts_a=self.shifts_a, + shifts_b=self.shifts_b[template_indices] if self.shifts_b is not None else None, ) ix_template = template_indices_a[:, None] ix_time = times[:, None] + (conv_pad_len + self.conv_lags)[None, :] @@ -1002,6 +1015,7 @@ def __init__( self, n_spikes: int = 0, times: Optional[torch.LongTensor] = None, + objective_template_indices: Optional[torch.LongTensor] = None, template_indices: Optional[torch.LongTensor] = None, upsampling_indices: Optional[torch.LongTensor] = None, scalings: Optional[torch.Tensor] = None, @@ -1011,6 +1025,7 @@ def __init__( self.n_spikes = n_spikes self._times = times self._template_indices = template_indices + self._objective_template_indices = objective_template_indices self._upsampling_indices = upsampling_indices self._scalings = scalings self._scores = scores @@ -1027,6 +1042,10 @@ def __init__( self._template_indices = torch.zeros( self.cur_buf_size, dtype=int, device=device ) + if objective_template_indices is None: + self._objective_template_indices = torch.zeros( + self.cur_buf_size, dtype=int, device=device + ) if scalings is None: self._scalings = torch.ones(self.cur_buf_size, device=device) if upsampling_indices is None: @@ -1043,6 +1062,9 @@ def times(self): @property def template_indices(self): return self._template_indices[: self.n_spikes] + @property + def objective_template_indices(self): + return self._objective_template_indices[: self.n_spikes] @property def upsampling_indices(self): @@ -1061,6 +1083,7 @@ def grow_buffers(self, min_size=0): k = self.n_spikes self._times = _grow_buffer(self._times, k, sz) self._template_indices = _grow_buffer(self._template_indices, k, sz) + self._objective_template_indices = _grow_buffer(self._objective_template_indices, k, sz) self._upsampling_indices = _grow_buffer(self._upsampling_indices, k, sz) self._scalings = _grow_buffer(self._scalings, k, sz) self._scores = _grow_buffer(self._scores, k, sz) @@ -1070,6 +1093,7 @@ def sort(self): order = torch.argsort(self.times[: self.n_spikes]) self._times[: self.n_spikes] = self.times[order] self._template_indices[: self.n_spikes] = self.template_indices[order] + self._objective_template_indices[: self.n_spikes] = self.objective_template_indices[order] self._upsampling_indices[: self.n_spikes] = 
self.upsampling_indices[order] self._scalings[: self.n_spikes] = self.scalings[order] self._scores[: self.n_spikes] = self.scores[order] @@ -1080,6 +1104,7 @@ def extend(self, other): self.grow_buffers(min_size=new_n_spikes) self._times[self.n_spikes : new_n_spikes] = other.times self._template_indices[self.n_spikes : new_n_spikes] = other.template_indices + self._objective_template_indices[self.n_spikes : new_n_spikes] = other.objective_template_indices self._upsampling_indices[ self.n_spikes : new_n_spikes ] = other.upsampling_indices diff --git a/src/dartsort/templates/pairwise.py b/src/dartsort/templates/pairwise.py index eba3dd86..89031b53 100644 --- a/src/dartsort/templates/pairwise.py +++ b/src/dartsort/templates/pairwise.py @@ -51,19 +51,29 @@ class CompressedPairwiseConv: # pconv_ix -> a cross-correlation array # the 0 index is special: pconv[0] === 0. pconv: np.ndarray + in_memory: bool = False def __post_init__(self): assert self.shifts_a.ndim == self.shifts_b.ndim == 1 - assert self.shifts_a.size == self.shifted_template_index_a.shape[1] - assert self.shifts_b.size == self.upsampled_shifted_template_index_b.shape[1] - self._is_torch = False + assert self.shifts_a.shape == (self.shifted_template_index_a.shape[1],) + assert self.shifts_b.shape == (self.upsampled_shifted_template_index_b.shape[1],) @classmethod - def from_h5(cls, hdf5_filename): - ff = fields(cls) - with h5py.File(hdf5_filename, "r") as h5: - data = {f.name: h5[f.name][:] for f in ff} - return cls(**data) + def from_h5(cls, hdf5_filename, in_memory=True): + ff = [f for f in fields(cls) if not f.name == "in_memory"] + if in_memory: + with h5py.File(hdf5_filename, "r") as h5: + data = {f.name: torch.from_numpy(h5[f.name][:]) for f in ff} + return cls(**data, in_memory=in_memory) + + _h5 = h5py.File(hdf5_filename, "r") + data = {} + for f in ff: + if f.name == "pconv": + data[f.name] = _h5[f.name] + else: + data[f.name] = torch.from_numpy(_h5[f.name][:]) + return cls(**data, in_memory=in_memory) @classmethod def from_template_data( @@ -123,62 +133,68 @@ def at_shifts(self, shifts_a=None, shifts_b=None): n_shifted_temps_a, n_up_shifted_temps_b = self.pconv_index.shape # active shifted and upsampled indices - print(f"{self.shifts_a=} {shifts_a.min()=} {shifts_a.max()=}") - shift_ix_a = np.searchsorted(self.shifts_a, shifts_a) - shift_ix_b = np.searchsorted(self.shifts_b, shifts_b) + shift_ix_a = torch.searchsorted(self.shifts_a, shifts_a) + shift_ix_b = torch.searchsorted(self.shifts_b, shifts_b) sub_shifted_temp_index_a = self.shifted_template_index_a[ - np.arange(len(self.shifted_template_index_a))[:, None], + torch.arange(len(self.shifted_template_index_a))[:, None], shift_ix_a[:, None], ] sub_up_shifted_temp_index_b = self.upsampled_shifted_template_index_b[ - np.arange(len(self.upsampled_shifted_template_index_b))[:, None], + torch.arange(len(self.upsampled_shifted_template_index_b))[:, None], shift_ix_b[:, None], ] # in flat form for indexing into pconv_index. also, reindex. 
valid_a = sub_shifted_temp_index_a < n_shifted_temps_a - shifted_temp_ixs_a, new_shifted_temp_ixs_a = np.unique( + shifted_temp_ixs_a, new_shifted_temp_ixs_a = torch.unique( sub_shifted_temp_index_a[valid_a], return_inverse=True ) valid_b = sub_up_shifted_temp_index_b < n_up_shifted_temps_b - up_shifted_temp_ixs_b, new_up_shifted_temp_ixs_b = np.unique( + up_shifted_temp_ixs_b, new_up_shifted_temp_ixs_b = torch.unique( sub_up_shifted_temp_index_b[valid_b], return_inverse=True ) # get relevant pconv subset and reindex - sub_pconv_indices, new_pconv_indices = np.unique( + sub_pconv_indices, new_pconv_indices = torch.unique( self.pconv_index[ shifted_temp_ixs_a[:, None], up_shifted_temp_ixs_b.ravel()[None, :], ], return_inverse=True, ) - sub_pconv = self.pconv[sub_pconv_indices] + if self.in_memory: + sub_pconv = self.pconv[sub_pconv_indices.to(self.pconv.device)] + else: + sub_pconv = torch.from_numpy(batched_h5_read(self.pconv, sub_pconv_indices)) # reindexing n_sub_shifted_temps_a = len(shifted_temp_ixs_a) n_sub_up_shifted_temps_b = len(up_shifted_temp_ixs_b) - sub_pconv_index = new_pconv_indices.reshape( + sub_pconv_index = new_pconv_indices.view( n_sub_shifted_temps_a, n_sub_up_shifted_temps_b ) sub_shifted_temp_index_a[valid_a] = new_shifted_temp_ixs_a sub_up_shifted_temp_index_b[valid_b] = new_up_shifted_temp_ixs_b return self.__class__( - shifts_a=np.zeros(1), - shifts_b=np.zeros(1), + shifts_a=torch.zeros(1), + shifts_b=torch.zeros(1), shifted_template_index_a=sub_shifted_temp_index_a, upsampled_shifted_template_index_b=sub_up_shifted_temp_index_b, pconv_index=sub_pconv_index, pconv=sub_pconv, + in_memory=True, ) - def to(self, device=None): + def to(self, device=None, incl_pconv=False): """Become torch tensors on device.""" for f in fields(self): - setattr(self, f.name, torch.as_tensor(getattr(self, f.name), device=device)) + if f.name == "pconv": + continue + v = getattr(self, f.name) + if isinstance(v, np.ndarray) or torch.is_tensor(v): + setattr(self, f.name, torch.as_tensor(v, device=device)) self.device = device - self._is_torch = True return self def query( @@ -192,17 +208,14 @@ def query( times_b=None, return_zero_convs=False, grid=False, + device=None, ): if template_indices_a is None: - if self._is_torch: template_indices_a = torch.arange( len(self.shifted_template_index_a), device=self.device ) - else: - template_indices_a = np.arange(len(self.shifted_template_index_a)) - if not self._is_torch: - template_indices_a = np.atleast_1d(template_indices_a) - template_indices_b = np.atleast_1d(template_indices_b) + template_indices_a = torch.atleast_1d(template_indices_a) + template_indices_b = torch.atleast_1d(template_indices_b) # handle no shifting no_shifting = shifts_a is None or shifts_b is None @@ -217,8 +230,8 @@ def query( shifted_template_index = shifted_template_index[:, 0] upsampled_shifted_template_index = upsampled_shifted_template_index[:, 0] else: - shift_indices_a = np.searchsorted(self.shifts_a, shifts_a) - shift_indices_b = np.searchsorted(self.shifts_b, shifts_b) + shift_indices_a = torch.searchsorted(self.shifts_a, shifts_a) + shift_indices_b = torch.searchsorted(self.shifts_b, shifts_b) a_ix = (template_indices_a, shift_indices_a) b_ix = (template_indices_b, shift_indices_b) @@ -241,26 +254,14 @@ def query( pconv_indices = self.pconv_index[ shifted_temp_ix_a[:, None], up_shifted_temp_ix_b[None, :] ] - if self._is_torch: - template_indices_a, template_indices_b = torch.cartesian_prod( - template_indices_a, template_indices_b - ).T - if scalings_b is not None: 
- scalings_b = torch.broadcast_to(scalings_b[None], pconv_indices.shape).reshape(-1) - if times_b is not None: - times_b = torch.broadcast_to(times_b[None], pconv_indices.shape).reshape(-1) - pconv_indices = pconv_indices.view(-1) - else: - template_indices_a, template_indices_b = np.meshgrid( - template_indices_a, template_indices_b, indexing="ij" - ) - template_indices_a = template_indices_a.ravel() - template_indices_b = template_indices_b.ravel() - if scalings_b is not None: - scalings_b = np.broadcast_to(scalings_b[None], pconv_indices.shape).ravel() - if times_b is not None: - times_b = np.broadcast_to(times_b[None], pconv_indices.shape).ravel() - pconv_indices = pconv_indices.ravel() + template_indices_a, template_indices_b = torch.cartesian_prod( + template_indices_a, template_indices_b + ).T + if scalings_b is not None: + scalings_b = torch.broadcast_to(scalings_b[None], pconv_indices.shape).reshape(-1) + if times_b is not None: + times_b = torch.broadcast_to(times_b[None], pconv_indices.shape).reshape(-1) + pconv_indices = pconv_indices.view(-1) else: pconv_indices = self.pconv_index[shifted_temp_ix_a, up_shifted_temp_ix_b] @@ -275,7 +276,13 @@ def query( if times_b is not None: times_b = times_b[which] - pconvs = self.pconv[pconv_indices] + if self.in_memory: + pconvs = self.pconv[pconv_indices.to(self.pconv.device)] + else: + pconvs = torch.from_numpy(batched_h5_read(self.pconv, pconv_indices.numpy(force=True))) + if device is not None: + pconvs = pconvs.to(device) + if scalings_b is not None: pconvs.mul_(scalings_b[:, None]) @@ -283,3 +290,13 @@ def query( return template_indices_a, template_indices_b, times_b, pconvs return template_indices_a, template_indices_b, pconvs + +def batched_h5_read(dataset, indices, batch_size=1000): + if indices.size < batch_size: + return dataset[indices] + else: + out = np.empty((indices.size, *dataset.shape[1:]), dtype=dataset.dtype) + for bs in range(0, indices.size, batch_size): + be = min(indices.size, bs + batch_size) + out[bs:be] = dataset[indices[bs:be]] + return out \ No newline at end of file diff --git a/src/dartsort/vis/scatterplots.py b/src/dartsort/vis/scatterplots.py index 8d1258c3..67406451 100644 --- a/src/dartsort/vis/scatterplots.py +++ b/src/dartsort/vis/scatterplots.py @@ -24,6 +24,8 @@ def scatter_spike_features( amplitude_cmap=plt.cm.viridis, max_spikes_plot=500_000, probe_margin_um=100, + t_min=-np.inf, + t_max=np.inf, s=1, linewidth=0, limits="probe_margin", @@ -56,14 +58,14 @@ def scatter_spike_features( amplitudes = h5["denoised_amplitudes"][:] geom = h5["geom"][:] - to_show = None + to_show = np.flatnonzero(np.clip(times_s, t_min, t_max) == times_s) if geom is not None: - to_show = np.flatnonzero( - (depths_um > geom[:, 1].min() - probe_margin_um) - & (depths_um < geom[:, 1].max() + probe_margin_um) - & (x > geom[:, 0].min() - probe_margin_um) - & (x < geom[:, 0].max() + probe_margin_um) - ) + to_show = to_show[ + (depths_um[to_show] > geom[:, 1].min() - probe_margin_um) + & (depths_um[to_show] < geom[:, 1].max() + probe_margin_um) + & (x[to_show] > geom[:, 0].min() - probe_margin_um) + & (x[to_show] < geom[:, 0].max() + probe_margin_um) + ] _, s_x = scatter_x_vs_depth( x=x, From e225f300fdf7f62602e06464105011fccafefece Mon Sep 17 00:00:00 2001 From: Charlie Windolf Date: Thu, 30 Nov 2023 16:55:13 -0500 Subject: [PATCH 43/49] Before trying hybrid gpu/cpu idea --- src/dartsort/peel/matching.py | 48 ++++++++++++++++++++++-------- src/dartsort/templates/pairwise.py | 2 +- 2 files changed, 36 insertions(+), 14 deletions(-) 
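A note on the fine_match hunk below: it trades the per-pair triple einsum for torch.bmm followed by torch.linalg.vecdot, so the reconstructed low-rank templates can be reused against both the current and the previous-sample residual snippets. A minimal standalone sketch of why the two forms are numerically equivalent — the shapes and variable names here (j, t, r, c, snips, spatial, temporal) are illustrative stand-ins, not taken from the patch:

    import torch

    # illustrative sizes: j candidate spikes, t samples, r SVD rank, c channels
    j, t, r, c = 4, 121, 5, 16
    snips = torch.randn(j, t, c, dtype=torch.float64)     # residual snippets
    spatial = torch.randn(j, r, c, dtype=torch.float64)   # stands in for spatial_singular[template_indices]
    temporal = torch.randn(j, t, r, dtype=torch.float64)  # stands in for the upsampled temporal components

    # triple contraction, one scalar per pair
    ref = torch.einsum("jtc,jrc,jtr->j", snips, spatial, temporal)

    # reconstruct each low-rank template once, then take a flat dot product
    temps = torch.bmm(temporal, spatial).reshape(j, -1)     # (j, t*c)
    out = torch.linalg.vecdot(snips.reshape(j, -1), temps)  # (j,)

    assert torch.allclose(ref, out)

Computing a second vecdot of the same temps against the previous-sample snippets (convs_prev in the hunk) is how the rewrite avoids building the stacked snips_dt tensor.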
diff --git a/src/dartsort/peel/matching.py b/src/dartsort/peel/matching.py index 91b48332..89607adf 100644 --- a/src/dartsort/peel/matching.py +++ b/src/dartsort/peel/matching.py @@ -915,7 +915,7 @@ def fine_match( # unpack the current traces and the traces one step back snips_prev = residual_snips[:, :-1] - snips_dt = torch.stack((snips_prev, snips), dim=3) + # snips_dt = torch.stack((snips_prev, snips), dim=3) # now, upsampling # repeat the superres logic, the comp up index acts the same @@ -924,34 +924,56 @@ def fine_match( comp_up_ix < self.n_compressed_upsampled_templates ).nonzero(as_tuple=True) comp_up_indices = comp_up_ix[dup_ix, column_ix] - convs = torch.einsum( - "jtcd,jrc,jtr->jd", - snips_dt[dup_ix], - self.spatial_singular[template_indices[dup_ix]], + # convs = torch.einsum( + # "jtcd,jrc,jtr->jd", + # snips_dt[dup_ix], + # self.spatial_singular[template_indices[dup_ix]], + # self.compressed_upsampled_temporal[comp_up_indices], + # ) + temps = torch.bmm( self.compressed_upsampled_temporal[comp_up_indices], - ) + self.spatial_singular[template_indices[dup_ix]], + ).view(len(comp_up_indices), -1) + convs = torch.linalg.vecdot(snips[dup_ix].view(len(temps), -1), temps) + convs_prev = torch.linalg.vecdot(snips_prev[dup_ix].view(len(temps), -1), temps) + # convs = torch.einsum( + # "jtc,jrc,jtr->j", + # snips[dup_ix], + # self.spatial_singular[template_indices[dup_ix]], + # self.compressed_upsampled_temporal[comp_up_indices], + # ) + # convs_prev = torch.einsum( + # "jtc,jrc,jtr->j", + # snips_prev[dup_ix], + # self.spatial_singular[template_indices[dup_ix]], + # self.compressed_upsampled_temporal[comp_up_indices], + # ) + better = convs >= convs_prev + convs = torch.maximum(convs, convs_prev) + norms = norms[dup_ix] - objs = torch.full((*comp_up_ix.shape, 2), -torch.inf, device=convs.device) + objs = torch.full(comp_up_ix.shape, -torch.inf, device=convs.device) if amp_scale_variance: inv_lambda = 1 / amp_scale_variance b = convs + inv_lambda - a = norms[:, None] + inv_lambda + a = norms + inv_lambda scalings = torch.clip(b / a, amp_scale_min, amp_scale_max) objs[dup_ix, column_ix] = ( 2 * scalings * b - torch.square(scalings) * a - inv_lambda ) else: - objs[dup_ix, column_ix] = 2 * convs - norms[:, None] + objs[dup_ix, column_ix] = 2 * convs - norms scalings = None - objs, best_column_dt_ix = objs.reshape(len(objs), comp_up_ix.shape[1] * 2).max(dim=1) + objs, best_column_ix = objs.max(dim=1) - best_column_ix = best_column_dt_ix // 2 row_ix = torch.arange(len(objs), device=best_column_ix.device) comp_up_indices = comp_up_ix[row_ix, best_column_ix] upsampling_indices = self.compressed_index_to_upsampling_index[comp_up_indices] - # even positions have were one step earlier - time_shifts = best_column_dt_ix % 2 - 1 + # prev convs were one step earlier + time_shifts = torch.full(comp_up_ix.shape, -1, device=convs.device) + time_shifts[dup_ix, column_ix] += better + time_shifts = time_shifts[row_ix, best_column_ix] return time_shifts, upsampling_indices, scalings, template_indices, objs diff --git a/src/dartsort/templates/pairwise.py b/src/dartsort/templates/pairwise.py index 89031b53..a74d077b 100644 --- a/src/dartsort/templates/pairwise.py +++ b/src/dartsort/templates/pairwise.py @@ -282,7 +282,7 @@ def query( pconvs = torch.from_numpy(batched_h5_read(self.pconv, pconv_indices.numpy(force=True))) if device is not None: pconvs = pconvs.to(device) - + if scalings_b is not None: pconvs.mul_(scalings_b[:, None]) From 8236157cbc30e7818c8a6b3ead4fe34cc0676b50 Mon Sep 17 00:00:00 2001 From: 
Charlie Windolf Date: Mon, 4 Dec 2023 10:27:59 -0500 Subject: [PATCH 44/49] Try at_shifts --- src/dartsort/peel/matching.py | 49 +++++++++------ src/dartsort/templates/pairwise.py | 95 +++++++++++++++++++++++------- 2 files changed, 106 insertions(+), 38 deletions(-) diff --git a/src/dartsort/peel/matching.py b/src/dartsort/peel/matching.py index 89607adf..db8169d5 100644 --- a/src/dartsort/peel/matching.py +++ b/src/dartsort/peel/matching.py @@ -426,7 +426,8 @@ def peel_chunk( def templates_at_time(self, t_s): """Handle drift -- grab the right spatial neighborhoods.""" pconvdb = self.pairwise_conv_db - pitch_shifts_a=pitch_shifts_b=None + pitch_shifts_a = pitch_shifts_b = None + pconvdb.to(self.objective_spatial_components.device, pin=True) if self.is_drifting: pitch_shifts_b, cur_spatial = template_util.templates_at_time( t_s, @@ -464,17 +465,22 @@ def templates_at_time(self, t_s): fill_value=0.0, ) max_channels = cur_ampvecs[:, 0, :].argmax(1) - # pconvdb = pconvdb.at_shifts(pitch_shifts_a, pitch_shifts_b) + # pitch_shifts_a = torch.as_tensor(pitch_shifts_a) + # pitch_shifts_b = torch.as_tensor(pitch_shifts_b) pitch_shifts_a = torch.as_tensor(pitch_shifts_a, device=cur_obj_spatial.device) pitch_shifts_b = torch.as_tensor(pitch_shifts_b, device=cur_obj_spatial.device) + pconvdb = pconvdb.at_shifts(pitch_shifts_a, pitch_shifts_b) + # pitch_shifts_a = torch.as_tensor(pitch_shifts_a, device=cur_obj_spatial.device) + # pitch_shifts_b = torch.as_tensor(pitch_shifts_b, device=cur_obj_spatial.device) else: cur_spatial = self.spatial_components cur_obj_spatial = self.objective_spatial_components max_channels = self.registered_template_ampvecs.argmax(1) # if not pconvdb._is_torch: - # # pconvdb.to("cpu") - pconvdb.to(cur_obj_spatial.device) + # pconvdb.to("cpu") + # if cur_obj_spatial.device.type == "cuda" and not pconvdb.device.type == "cuda": + # pconvdb.to(cur_obj_spatial.device, pin=True) return MatchingTemplateData( objective_spatial_components=cur_obj_spatial, @@ -492,8 +498,10 @@ def templates_at_time(self, t_s): compressed_upsampled_temporal=self.compressed_upsampled_temporal, max_channels=torch.as_tensor(max_channels, device=cur_obj_spatial.device), pairwise_conv_db=pconvdb, - shifts_a=pitch_shifts_a, - shifts_b=pitch_shifts_b, + shifts_a=None, + shifts_b=None, + # shifts_a=pitch_shifts_a, + # shifts_b=pitch_shifts_b, ) def match_chunk( @@ -560,20 +568,20 @@ def match_chunk( # subtract them # old_norm = torch.linalg.norm(residual) ** 2 - compressed_template_data.subtract_conv( - padded_conv, + compressed_template_data.subtract( + residual_padded, new_peaks.times, new_peaks.template_indices, new_peaks.upsampling_indices, new_peaks.scalings, - conv_pad_len=self.obj_pad_len, ) - compressed_template_data.subtract( - residual_padded, + compressed_template_data.subtract_conv( + padded_conv, new_peaks.times, new_peaks.template_indices, new_peaks.upsampling_indices, new_peaks.scalings, + conv_pad_len=self.obj_pad_len, ) # new_norm = torch.linalg.norm(residual) ** 2 @@ -627,7 +635,7 @@ def find_peaks( alpha=2.0, out=padded_objective[:-1], ) - + # first step: coarse peaks. not temporally upsampled or amplitude-scaled. 
objective = (padded_objective + refrac_mask)[ :-1, self.obj_pad_len : -self.obj_pad_len @@ -668,7 +676,7 @@ def find_peaks( ) if time_shifts is not None: times += time_shifts - + return MatchingPeaks( n_spikes=times.numel(), times=times, @@ -884,12 +892,17 @@ def fine_match( superres_ix = superres_index[objective_template_indices] dup_ix, column_ix = (superres_ix < self.n_templates).nonzero(as_tuple=True) template_indices = superres_ix[dup_ix, column_ix] - convs = torch.einsum( - "jtc,jrc,jtr->j", - snips[dup_ix], - self.spatial_singular[template_indices], + convs = torch.baddbmm( self.temporal_components[template_indices], - ) + snips[dup_ix], + self.spatial_singular[template_indices].mT, + ).sum((1, 2)) + # convs = torch.einsum( + # "jtc,jrc,jtr->j", + # snips[dup_ix], + # self.spatial_singular[template_indices], + # self.temporal_components[template_indices], + # ) norms = self.template_norms_squared[template_indices] objs = torch.full(superres_ix.shape, -torch.inf, device=convs.device) objs[dup_ix, column_ix] = 2 * convs - norms diff --git a/src/dartsort/templates/pairwise.py b/src/dartsort/templates/pairwise.py index a74d077b..5b261288 100644 --- a/src/dartsort/templates/pairwise.py +++ b/src/dartsort/templates/pairwise.py @@ -52,20 +52,35 @@ class CompressedPairwiseConv: # the 0 index is special: pconv[0] === 0. pconv: np.ndarray in_memory: bool = False + device: torch.device = torch.device("cpu") def __post_init__(self): assert self.shifts_a.ndim == self.shifts_b.ndim == 1 assert self.shifts_a.shape == (self.shifted_template_index_a.shape[1],) - assert self.shifts_b.shape == (self.upsampled_shifted_template_index_b.shape[1],) + assert self.shifts_b.shape == ( + self.upsampled_shifted_template_index_b.shape[1], + ) + + self.a_shift_offset, self.offset_shift_a_to_ix = _get_shift_indexer( + self.shifts_a + ) + self.b_shift_offset, self.offset_shift_b_to_ix = _get_shift_indexer( + self.shifts_b + ) + + def get_shift_ix_a(self, shifts_a): + return self.offset_shift_a_to_ix[shifts_a.to(int) + self.a_shift_offset] + + def get_shift_ix_b(self, shifts_b): + return self.offset_shift_b_to_ix[shifts_b.to(int) + self.b_shift_offset] @classmethod def from_h5(cls, hdf5_filename, in_memory=True): - ff = [f for f in fields(cls) if not f.name == "in_memory"] + ff = [f for f in fields(cls) if f.name not in ("in_memory", "device")] if in_memory: with h5py.File(hdf5_filename, "r") as h5: data = {f.name: torch.from_numpy(h5[f.name][:]) for f in ff} return cls(**data, in_memory=in_memory) - _h5 = h5py.File(hdf5_filename, "r") data = {} for f in ff: @@ -117,7 +132,7 @@ def from_template_data( ) return cls.from_h5(hdf5_filename) - def at_shifts(self, shifts_a=None, shifts_b=None): + def at_shifts(self, shifts_a=None, shifts_b=None, device=None): """Subset this database to one set of shifts. The database becomes shiftless (not in the pejorative sense). 
@@ -133,8 +148,8 @@ def at_shifts(self, shifts_a=None, shifts_b=None): n_shifted_temps_a, n_up_shifted_temps_b = self.pconv_index.shape # active shifted and upsampled indices - shift_ix_a = torch.searchsorted(self.shifts_a, shifts_a) - shift_ix_b = torch.searchsorted(self.shifts_b, shifts_b) + shift_ix_a = self.get_shift_ix_a(shifts_a) + shift_ix_b = self.get_shift_ix_b(shifts_b) sub_shifted_temp_index_a = self.shifted_template_index_a[ torch.arange(len(self.shifted_template_index_a))[:, None], shift_ix_a[:, None], @@ -166,6 +181,8 @@ def at_shifts(self, shifts_a=None, shifts_b=None): sub_pconv = self.pconv[sub_pconv_indices.to(self.pconv.device)] else: sub_pconv = torch.from_numpy(batched_h5_read(self.pconv, sub_pconv_indices)) + if device is not None: + sub_pconv = sub_pconv.to(device) # reindexing n_sub_shifted_temps_a = len(shifted_temp_ixs_a) @@ -184,17 +201,30 @@ def at_shifts(self, shifts_a=None, shifts_b=None): pconv_index=sub_pconv_index, pconv=sub_pconv, in_memory=True, + device=self.device, ) - def to(self, device=None, incl_pconv=False): + def to(self, device=None, incl_pconv=False, pin=False): """Become torch tensors on device.""" - for f in fields(self): - if f.name == "pconv": + print(f"to {device=}") + for name in ["offset_shift_a_to_ix", "offset_shift_b_to_ix"] + [ + f.name for f in fields(self) + ]: + if name == "pconv" and not incl_pconv: continue - v = getattr(self, f.name) + v = getattr(self, name) if isinstance(v, np.ndarray) or torch.is_tensor(v): - setattr(self, f.name, torch.as_tensor(v, device=device)) + setattr(self, name, torch.as_tensor(v, device=device)) self.device = device + if pin and self.device.type == "cuda" and torch.cuda.is_available() and not self.pconv.is_pinned(): + # self.pconv.share_memory_() + print("pin") + torch.cuda.cudart().cudaHostRegister( + self.pconv.data_ptr(), self.pconv.numel() * self.pconv.element_size(), 0 + ) + # assert x.is_shared() + assert self.pconv.is_pinned() + # self.pconv = self.pconv.pin_memory() return self def query( @@ -211,9 +241,9 @@ def query( device=None, ): if template_indices_a is None: - template_indices_a = torch.arange( - len(self.shifted_template_index_a), device=self.device - ) + template_indices_a = torch.arange( + len(self.shifted_template_index_a), device=self.device + ) template_indices_a = torch.atleast_1d(template_indices_a) template_indices_b = torch.atleast_1d(template_indices_b) @@ -230,8 +260,8 @@ def query( shifted_template_index = shifted_template_index[:, 0] upsampled_shifted_template_index = upsampled_shifted_template_index[:, 0] else: - shift_indices_a = torch.searchsorted(self.shifts_a, shifts_a) - shift_indices_b = torch.searchsorted(self.shifts_b, shifts_b) + shift_indices_a = self.get_shift_ix_a(shifts_a) + shift_indices_b = self.get_shift_ix_a(shifts_b) a_ix = (template_indices_a, shift_indices_a) b_ix = (template_indices_b, shift_indices_b) @@ -250,6 +280,9 @@ def query( up_shifted_temp_ix_b = upsampled_shifted_template_index[b_ix] # return convolutions between all ai,bj or just ai,bi? 
+ print(f"{shifted_temp_ix_a.device=} {up_shifted_temp_ix_b.device=}") + print(f"{self.device=} {self.shifts_a.device=}") + print(f"{template_indices_a.device=} {template_indices_b.device=}") if grid: pconv_indices = self.pconv_index[ shifted_temp_ix_a[:, None], up_shifted_temp_ix_b[None, :] @@ -258,9 +291,13 @@ def query( template_indices_a, template_indices_b ).T if scalings_b is not None: - scalings_b = torch.broadcast_to(scalings_b[None], pconv_indices.shape).reshape(-1) + scalings_b = torch.broadcast_to( + scalings_b[None], pconv_indices.shape + ).reshape(-1) if times_b is not None: - times_b = torch.broadcast_to(times_b[None], pconv_indices.shape).reshape(-1) + times_b = torch.broadcast_to( + times_b[None], pconv_indices.shape + ).reshape(-1) pconv_indices = pconv_indices.view(-1) else: pconv_indices = self.pconv_index[shifted_temp_ix_a, up_shifted_temp_ix_b] @@ -279,7 +316,9 @@ def query( if self.in_memory: pconvs = self.pconv[pconv_indices.to(self.pconv.device)] else: - pconvs = torch.from_numpy(batched_h5_read(self.pconv, pconv_indices.numpy(force=True))) + pconvs = torch.from_numpy( + batched_h5_read(self.pconv, pconv_indices.numpy(force=True)) + ) if device is not None: pconvs = pconvs.to(device) @@ -291,6 +330,7 @@ def query( return template_indices_a, template_indices_b, pconvs + def batched_h5_read(dataset, indices, batch_size=1000): if indices.size < batch_size: return dataset[indices] @@ -299,4 +339,19 @@ def batched_h5_read(dataset, indices, batch_size=1000): for bs in range(0, indices.size, batch_size): be = min(indices.size, bs + batch_size) out[bs:be] = dataset[indices[bs:be]] - return out \ No newline at end of file + return out + + +def _get_shift_indexer(shifts): + assert torch.equal(shifts, torch.sort(shifts).values) + shift_offset = -int(shifts[0]) + offset_shift_to_ix = [] + for j, shift in enumerate(shifts): + ix = shift + shift_offset + assert len(offset_shift_to_ix) <= ix + assert 0 <= ix < len(shifts) + while len(offset_shift_to_ix) < ix: + offset_shift_to_ix.append(len(shifts)) + offset_shift_to_ix.append(j) + offset_shift_to_ix = torch.tensor(offset_shift_to_ix, device=shifts.device) + return shift_offset, offset_shift_to_ix From e3cd9620e989291cf14f08064f2d2c85e001337c Mon Sep 17 00:00:00 2001 From: Julien Boussard Date: Mon, 4 Dec 2023 10:58:28 -0500 Subject: [PATCH 45/49] no hdbscan in requirements --- requirements.txt | 1 - 1 file changed, 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index b2d0197f..a152b042 100644 --- a/requirements.txt +++ b/requirements.txt @@ -2,4 +2,3 @@ pytest ibl-neuropixel spikeinterface cloudpickle -hdbscan From b68d4cf148f241bacf6217bf14a878687a95dd02 Mon Sep 17 00:00:00 2001 From: Charlie Windolf Date: Tue, 5 Dec 2023 19:39:29 -0500 Subject: [PATCH 46/49] Add more tests for template matching --- tests/test_matching.py | 472 ++++++++++++++++++++++++++++++++++++++++ tests/test_templates.py | 80 ++++--- tests/test_util.py | 26 +++ 3 files changed, 552 insertions(+), 26 deletions(-) create mode 100644 tests/test_matching.py create mode 100644 tests/test_util.py diff --git a/tests/test_matching.py b/tests/test_matching.py new file mode 100644 index 00000000..89c110a3 --- /dev/null +++ b/tests/test_matching.py @@ -0,0 +1,472 @@ +import numpy as np +import spikeinterface.full as si +import torch +import torch.nn.functional as F +from dartsort import config, main +from dartsort.templates import TemplateData, template_util +from dredge import motion_util +from test_util import no_overlap_recording_sorting + +nofeatcfg 
= config.FeaturizationConfig( + do_nn_denoise=False, + do_tpca_denoise=False, + do_enforce_decrease=False, + denoise_only=True, +) + +spike_length_samples = 121 +trough_offset_samples = 42 + + +def test_tiny(tmp_path): + recording_length_samples = 200 + n_channels = 2 + geom = np.c_[np.zeros(2), np.arange(2)] + geom + + # template main channel traces + trace0 = 50 * np.exp( + -(((np.arange(spike_length_samples) - trough_offset_samples) / 10) ** 2) + ) + + # templates + templates = np.zeros((2, spike_length_samples, n_channels), dtype="float32") + templates[0, :, 0] = trace0 + templates[1, :, 1] = trace0 + + # spike train + # fmt: off + tcl = [ + 50, 0, 0, + 51, 1, 1, + ] + # fmt: on + times, channels, labels = np.array(tcl).reshape(-1, 3).T + rec = np.zeros((recording_length_samples, n_channels), dtype="float32") + for t, l in zip(times, labels): + rec[ + t - trough_offset_samples : t - trough_offset_samples + spike_length_samples + ] += templates[l] + rec = si.NumpyRecording(rec, 30_000) + rec.set_dummy_probe_from_locations(geom) + + template_config = config.TemplateConfig( + low_rank_denoising=False, + superres_bin_min_spikes=0, + ) + template_data = TemplateData.from_config( + *no_overlap_recording_sorting(templates), + template_config, + motion_est=motion_util.IdentityMotionEstimate(), + n_jobs=0, + save_folder=tmp_path, + overwrite=True, + ) + + matcher = main.ObjectiveUpdateTemplateMatchingPeeler.from_config( + rec, + main.MatchingConfig( + threshold=0.01, + template_temporal_upsampling_factor=1, + ), + nofeatcfg, + template_data, + motion_est=motion_util.IdentityMotionEstimate(), + ) + matcher.precompute_peeling_data(tmp_path) + res = matcher.peel_chunk( + torch.from_numpy(rec.get_traces().copy()), + return_residual=True, + return_conv=True, + ) + + ixa, ixb, pconv = matcher.pairwise_conv_db.query( + [0, 1], [0, 1], upsampling_indices_b=[0, 0], grid=True + ) + maxpc = pconv.max(dim=1).values + for ia, ib, pc in zip(ixa, ixb, maxpc): + assert np.isclose(pc, (templates[ia] * templates[ib]).sum()) + assert res["n_spikes"] == len(times) + assert np.array_equal(res["times_samples"], times) + assert np.array_equal(res["labels"], labels) + assert np.isclose( + torch.square(res["residual"]).mean(), + 0.0, + ) + assert np.isclose( + torch.square(res["conv"]).mean(), + 0.0, + atol=1e-5, + ) + + matcher = main.ObjectiveUpdateTemplateMatchingPeeler.from_config( + rec, + main.MatchingConfig( + threshold=0.01, + template_temporal_upsampling_factor=8, + ), + nofeatcfg, + template_data, + motion_est=motion_util.IdentityMotionEstimate(), + ) + matcher.precompute_peeling_data(tmp_path) + res = matcher.peel_chunk( + torch.from_numpy(rec.get_traces().copy()), + return_residual=True, + return_conv=True, + ) + assert res["n_spikes"] == len(times) + assert np.array_equal(res["times_samples"], times) + assert np.array_equal(res["labels"], labels) + assert np.array_equal(res["upsampling_indices"], [0, 0]) + assert np.isclose( + torch.square(res["residual"]).mean(), + 0.0, + ) + assert np.isclose( + torch.square(res["conv"]).mean(), + 0.0, + atol=1e-5, + ) + + +def static_tester(tmp_path, up_factor=1): + recording_length_samples = 40_011 + n_channels = 2 + geom = np.c_[np.zeros(2), np.arange(2)] + geom + + # template main channel traces + trace0 = 50 * np.exp( + -(((np.arange(spike_length_samples) - trough_offset_samples) / 10) ** 2) + ) + trace1 = 250 * np.exp( + -(((np.arange(spike_length_samples) - trough_offset_samples) / 30) ** 2) + ) + + # templates + templates = np.zeros((3, spike_length_samples, 
n_channels), dtype="float32") + templates[0, :, 0] = trace0 + templates[1, :, 0] = trace1 + templates[2, :, 1] = trace0 + + # spike train + # fmt: off + tcl = [ + 100, 0, 0, + 150, 0, 0, + 151, 1, 2, + 500, 0, 1, + 2000, 0, 0, + 2001, 0, 1, + 35000, 1, 2, + 35001, 0, 1, + ] + # fmt: on + times, channels, labels = np.array(tcl).reshape(-1, 3).T + rec = np.zeros((recording_length_samples, n_channels), dtype="float32") + for t, l in zip(times, labels): + rec[ + t - trough_offset_samples : t - trough_offset_samples + spike_length_samples + ] += templates[l] + rec = si.NumpyRecording(rec, 30_000) + rec.set_dummy_probe_from_locations(geom) + + template_config = config.TemplateConfig( + low_rank_denoising=False, superres_bin_min_spikes=0 + ) + template_data = TemplateData.from_config( + *no_overlap_recording_sorting(templates), + template_config, + motion_est=motion_util.IdentityMotionEstimate(), + n_jobs=0, + save_folder=tmp_path, + overwrite=True, + ) + + matcher = main.ObjectiveUpdateTemplateMatchingPeeler.from_config( + rec, + main.MatchingConfig( + threshold=0.01, + template_temporal_upsampling_factor=up_factor, + coarse_approx_error_threshold=0.0, + conv_ignore_threshold=0.0, + template_svd_compression_rank=2, + ), + nofeatcfg, + template_data, + motion_est=motion_util.IdentityMotionEstimate(), + ) + matcher.precompute_peeling_data(tmp_path) + + lrt = template_util.svd_compress_templates( + template_data.templates, rank=matcher.svd_compression_rank + ) + tempup = template_util.compressed_upsampled_templates( + lrt.temporal_components, + ptps=template_data.templates.ptp(1).max(1), + max_upsample=up_factor, + ) + assert np.array_equal(matcher.compressed_upsampled_temporal, tempup.compressed_upsampled_templates) + assert np.array_equal(matcher.objective_spatial_components, lrt.spatial_components) + assert np.array_equal(matcher.objective_singular_values, lrt.singular_values) + assert np.array_equal(matcher.spatial_components, lrt.spatial_components) + assert np.array_equal(matcher.singular_values, lrt.singular_values) + for up in range(up_factor): + ixa, ixb, pconv = matcher.pairwise_conv_db.query( + np.arange(3), + np.arange(3), + upsampling_indices_b=up + np.zeros(3, dtype=int), + grid=True, + ) + centerpc = pconv[:, spike_length_samples - 1] + for ia, ib, pc, pcf in zip(ixa, ixb, centerpc, pconv): + tempupb = tempup.compressed_upsampled_templates[ + tempup.compressed_upsampling_map[ib, up] + ] + tupb = (tempupb * lrt.singular_values[ib]) @ lrt.spatial_components[ib] + tc = (templates[ia] * tupb).sum() + + template_a = torch.as_tensor(templates[ia][None]) + ssb = lrt.singular_values[ib][:, None] * lrt.spatial_components[ib] + conv_filt = torch.bmm(torch.as_tensor(ssb[None]), template_a.mT) + conv_filt = conv_filt[:, None] # (nco, 1, rank, t) + conv_in = torch.as_tensor(tempupb[None]).mT[None] + pconv_ = F.conv2d( + conv_in, conv_filt, padding=(0, 120), groups=1 + ) + pconv1 = pconv_.squeeze()[spike_length_samples - 1].numpy(force=True) + assert torch.isclose(pcf, pconv_).all() + + pconv2 = F.conv2d( + torch.as_tensor(templates[ia])[None, None], + torch.as_tensor(tupb)[None, None], + ).squeeze().numpy(force=True) + assert np.isclose(pconv2, tc) + assert np.isclose(pc, tc) + assert np.isclose(pconv1, pc) + + res = matcher.peel_chunk( + torch.from_numpy(rec.get_traces().copy()), + return_residual=True, + return_conv=True, + ) + + assert res["n_spikes"] == len(times) + assert np.array_equal(res["times_samples"], times) + assert np.array_equal(res["labels"], labels) + assert np.isclose( + 
torch.square(res["residual"]).mean(), + 0.0, + ) + assert np.isclose( + torch.square(res["conv"]).mean(), + 0.0, + atol=1e-4, + ) + + +def test_static_noup(tmp_path): + static_tester(tmp_path) + + +def test_static_up(tmp_path): + static_tester(tmp_path, up_factor=4) + + +def drifting_tester(tmp_path, up_factor=1): + recording_length_samples = 40_011 + n_channels = 2 + geom = np.c_[np.zeros(2), np.arange(2)] + geom + + # template main channel traces + trace0 = 50 * np.exp( + -(((np.arange(spike_length_samples) - trough_offset_samples) / 10) ** 2) + ) + trace1 = 250 * np.exp( + -(((np.arange(spike_length_samples) - trough_offset_samples) / 30) ** 2) + ) + + # templates + templates = np.zeros((3, spike_length_samples, n_channels), dtype="float32") + templates[0, :, 0] = trace0 + templates[1, :, 0] = trace1 + templates[2, :, 1] = trace0 + + # spike train + # fmt: off + tcl = [ + 100, 0, 0, + 150, 0, 0, + 151, 1, 2, + 500, 0, 1, + 2000, 0, 0, + 2001, 0, 1, + 25000, 1, 2, + 25001, 0, 1, + ] + # fmt: on + times, channels, labels = np.array(tcl).reshape(-1, 3).T + rec = np.zeros((recording_length_samples, n_channels), dtype="float32") + for t, l in zip(times, labels): + rec[ + t - trough_offset_samples : t - trough_offset_samples + spike_length_samples + ] += templates[l] + rec = si.NumpyRecording(rec, 30_000) + rec.set_dummy_probe_from_locations(geom) + + template_config = config.TemplateConfig( + low_rank_denoising=False, superres_bin_min_spikes=0 + ) + template_data = TemplateData.from_config( + *no_overlap_recording_sorting(templates), + template_config, + motion_est=motion_util.IdentityMotionEstimate(), + n_jobs=0, + save_folder=tmp_path, + overwrite=True, + ) + + matcher = main.ObjectiveUpdateTemplateMatchingPeeler.from_config( + rec, + main.MatchingConfig( + threshold=0.01, + template_temporal_upsampling_factor=up_factor, + coarse_approx_error_threshold=0.0, + conv_ignore_threshold=0.0, + template_svd_compression_rank=2, + ), + nofeatcfg, + template_data, + motion_est=motion_util.IdentityMotionEstimate(), + ) + matcher.precompute_peeling_data(tmp_path) + + # tup = template_util.compressed_upsampled_templates( + # template_data.templates, max_upsample=up_factor + # ) + lrt = template_util.svd_compress_templates( + template_data.templates, rank=matcher.svd_compression_rank + ) + tempup = template_util.compressed_upsampled_templates( + lrt.temporal_components, + ptps=template_data.templates.ptp(1).max(1), + max_upsample=up_factor, + ) + print(f"{lrt.temporal_components.shape=}") + print(f"{lrt.singular_values.shape=}") + print(f"{lrt.spatial_components.shape=}") + assert np.array_equal(matcher.compressed_upsampled_temporal, tempup.compressed_upsampled_templates) + assert np.array_equal(matcher.objective_spatial_components, lrt.spatial_components) + assert np.array_equal(matcher.objective_singular_values, lrt.singular_values) + assert np.array_equal(matcher.spatial_components, lrt.spatial_components) + assert np.array_equal(matcher.singular_values, lrt.singular_values) + for up in range(up_factor): + ixa, ixb, pconv = matcher.pairwise_conv_db.query( + np.arange(3), + np.arange(3), + upsampling_indices_b=up + np.zeros(3, dtype=int), + grid=True, + ) + centerpc = pconv[:, spike_length_samples - 1] + for ia, ib, pc, pcf in zip(ixa, ixb, centerpc, pconv): + # tupb = tup.compressed_upsampled_templates[ + # tup.compressed_upsampling_map[ib, up] + # ] + tempupb = tempup.compressed_upsampled_templates[ + tempup.compressed_upsampling_map[ib, up] + ] + tupb = (tempupb * lrt.singular_values[ib]) @ 
lrt.spatial_components[ib] + tc = (templates[ia] * tupb).sum() + + template_a = torch.as_tensor(templates[ia][None]) + ssb = lrt.singular_values[ib][:, None] * lrt.spatial_components[ib] + conv_filt = torch.bmm(torch.as_tensor(ssb[None]), template_a.mT) + conv_filt = conv_filt[:, None] # (nco, 1, rank, t) + conv_in = torch.as_tensor(tempupb[None]).mT[None] + pconv_ = F.conv2d( + conv_in, conv_filt, padding=(0, 120), groups=1 + ) + # print(f"{torch.abs(pcf - pconv_).max()=}") + pconv1 = pconv_.squeeze()[spike_length_samples - 1].numpy(force=True) + assert torch.isclose(pcf, pconv_).all() + + pconv2 = F.conv2d( + torch.as_tensor(templates[ia])[None, None], + torch.as_tensor(tupb)[None, None], + ).squeeze().numpy(force=True) + assert np.isclose(pconv2, tc) + + # print(f" - {ia=} {ib=}") + # print(f" {pc=} {tc=} {pconv1=} {pconv2=}") + # print(f" {pcf[120]=} {pcf[121]=} {pcf[122]=}") + # print(f" ~ {np.isclose(pc, tc)=}") + # print(f" {np.isclose(pconv1, pc)=} {np.isclose(tc, pconv2)=}") + assert np.isclose(pc, tc) + assert np.isclose(pconv1, pc) + + res = matcher.peel_chunk( + torch.from_numpy(rec.get_traces().copy()), + return_residual=True, + return_conv=True, + ) + + print() + print() + print(f"{len(times)=}") + print(f"{res['n_spikes']=}") + print() + print(f'{torch.square(res["residual"]).mean()=}') + print(f'{torch.abs(res["residual"]).max()=}') + print(f'{torch.square(res["conv"]).mean()=}') + print(f'{torch.abs(res["conv"]).max()=}') + print(f'{res["conv"].min()=} {res["conv"].max()=}') + tnsq = np.linalg.norm(templates, axis=(1, 2)) ** 2 + print(f"{res['conv'].shape=} {tnsq.shape=}") + print(f'{(2*res["conv"] - tnsq[:,None]).max()=}') + print() + print(f'{res["times_samples"]=}') + print(f"{times=}") + print() + print(f'{res["labels"]=}') + print(f"{labels=}") + print() + print(f'{np.c_[res["times_samples"], res["labels"], res["upsampling_indices"]]=}') + print(f"{np.c_[times, labels]=}") + print() + print(f'{res["upsampling_indices"]=}') + + assert res["n_spikes"] == len(times) + assert np.array_equal(res["times_samples"], times) + assert np.array_equal(res["labels"], labels) + print(f"{torch.square(res['residual']).mean()=}") + print(f"{torch.square(res['conv']).mean()=}") + assert np.isclose( + torch.square(res["residual"]).mean(), + 0.0, + ) + assert np.isclose( + torch.square(res["conv"]).mean(), + 0.0, + atol=1e-4, + ) + + +if __name__ == "__main__": + import tempfile + from pathlib import Path + + print("test tiny") + with tempfile.TemporaryDirectory() as tdir: + test_tiny(Path(tdir)) + + print() + print("test test_static_noup") + with tempfile.TemporaryDirectory() as tdir: + test_static_noup(Path(tdir)) + + print() + print("test test_static_up") + with tempfile.TemporaryDirectory() as tdir: + test_static_up(Path(tdir)) diff --git a/tests/test_templates.py b/tests/test_templates.py index e574dde3..b8ff4e67 100644 --- a/tests/test_templates.py +++ b/tests/test_templates.py @@ -8,7 +8,31 @@ template_util, templates) from dartsort.util import drift_util from dartsort.util.data_util import DARTsortSorting -from dredge.motion_util import get_motion_estimate +from dredge.motion_util import IdentityMotionEstimate, get_motion_estimate +from test_util import no_overlap_recording_sorting + + +def test_roundtrip(tmp_path): + rg = np.random.default_rng(0) + temps = rg.normal(size=(11, 121, 384)).astype(np.float32) + template_data = templates.TemplateData.from_config( + *no_overlap_recording_sorting(temps, pad=0), + template_config=config.TemplateConfig( + low_rank_denoising=False, + 
superres_bin_min_spikes=0, + realign_peaks=False, + ), + motion_est=IdentityMotionEstimate(), + n_jobs=0, + save_folder=tmp_path, + overwrite=True, + ) + print(f"{np.abs(template_data.templates - temps).max()=}") + print(f"{np.abs(template_data.templates - temps).mean()=}") + print(f"{np.abs(template_data.templates - temps).min()=}") + print(f"{template_data.templates.ptp(1).max(1)=}") + print(f"{temps.ptp(1).max(1)=}") + assert np.array_equal(template_data.templates, temps) def test_static_templates(): @@ -202,16 +226,15 @@ def test_pconv(): compressed_upsampled_temporal=ctempup, ) pconvdb = pairwise.CompressedPairwiseConv.from_h5(pconvdb_path) - assert np.all(pconvdb.pconv[0] == 0) - print(f"{pconvdb.pconv.shape=}") + assert (pconvdb.pconv[0] == 0.0).all() for tixa in range(5): for tixb in range(5): ixa, ixb, pconv = pconvdb.query(tixa, tixb) if (tixa, tixb) not in overlaps: - assert not ixa.size - assert not ixb.size - assert not pconv.size + assert not ixa.numel() + assert not ixb.numel() + assert not pconv.numel() continue olap = overlaps[tixa, tixb] @@ -259,7 +282,7 @@ def test_pconv(): chunk_time_centers_s=[0, 1, 2], ) pconvdb = pairwise.CompressedPairwiseConv.from_h5(pconvdb_path) - assert np.all(pconvdb.pconv[0] == 0) + assert (pconvdb.pconv[0] == 0.0).all() print(f"{pconvdb.pconv.shape=}") for tixa in range(5): @@ -267,9 +290,9 @@ def test_pconv(): ixa, ixb, pconv = pconvdb.query(tixa, tixb, shifts_a=0, shifts_b=0) if (tixa, tixb) not in overlaps: - assert not ixa.size - assert not ixb.size - assert not pconv.size + assert not ixa.numel() + assert not ixb.numel() + assert not pconv.numel() continue olap = overlaps[tixa, tixb] @@ -279,24 +302,24 @@ def test_pconv(): for tixb in range(5): for shiftb in (-1, 0, 1): ixa, ixb, pconv = pconvdb.query(0, tixb, shifts_a=-1, shifts_b=shiftb) - assert not ixa.size - assert not ixb.size - assert not pconv.size + assert not ixa.numel() + assert not ixb.numel() + assert not pconv.numel() for tixb in range(5): for shift in (-1, 0, 1): ixa, ixb, pconv = pconvdb.query(4, tixb, shifts_a=shift, shifts_b=shift) if tixb != 4 or shift == 1: - assert not ixa.size - assert not ixb.size - assert not pconv.size + assert not ixa.numel() + assert not ixb.numel() + assert not pconv.numel() else: assert np.isclose(pconv.max(), 4 if shift < 1 else 0) ixa, ixb, pconv = pconvdb.query(tixb, 4, shifts_a=shift, shifts_b=shift) if tixb != 4 or shift == 1: - assert not ixa.size - assert not ixb.size - assert not pconv.size + assert not ixa.numel() + assert not ixb.numel() + assert not pconv.numel() else: assert np.isclose(pconv.max(), 4) @@ -307,13 +330,18 @@ def test_pconv(): ixa, ixb, pconv = pconvdb.query(tixa, tixb, shifts_a=shifta, shifts_b=shiftb) if shifta != shiftb: # this is because we are rigid here - assert not ixa.size - assert not ixb.size - assert not pconv.size + assert not ixa.numel() + assert not ixb.numel() + assert not pconv.numel() if __name__ == "__main__": - test_static_templates() - test_drifting_templates() - test_main_object() - test_pconv() + import tempfile + from pathlib import Path + + with tempfile.TemporaryDirectory() as tdir: + test_roundtrip(Path(tdir)) + # test_static_templates() + # test_drifting_templates() + # test_main_object() + # test_pconv() diff --git a/tests/test_util.py b/tests/test_util.py new file mode 100644 index 00000000..2d84e4e3 --- /dev/null +++ b/tests/test_util.py @@ -0,0 +1,26 @@ +import numpy as np +import spikeinterface.core as sc +from dartsort.util.data_util import DARTsortSorting + + +def 
no_overlap_recording_sorting(templates, fs=30000, trough_offset_samples=42, pad=0): + n_templates, spike_length_samples, n_channels = templates.shape + rec = templates.reshape(n_templates * spike_length_samples, n_channels) + if pad > 0: + rec = np.pad(rec, [(pad, pad), (0, 0)]) + geom = np.c_[np.zeros(n_channels), np.arange(n_channels)] + rec = sc.NumpyRecording(rec, fs) + rec.set_dummy_probe_from_locations(geom) + depths = np.zeros(n_templates) + locs = np.c_[np.zeros_like(depths), np.zeros_like(depths), depths] + times = np.arange(n_templates) * spike_length_samples + trough_offset_samples + times_seconds = times / fs + sorting = DARTsortSorting( + times + pad, + np.zeros(n_templates), + np.arange(n_templates), + extra_features=dict( + point_source_localizations=locs, times_seconds=times_seconds + ), + ) + return rec, sorting From 439076bdadc0cb8ddc6a8ff1510a8900a912735b Mon Sep 17 00:00:00 2001 From: Charlie Windolf Date: Tue, 5 Dec 2023 19:39:46 -0500 Subject: [PATCH 47/49] Debug matching and implement resid dist merge --- src/dartsort/cluster/merge.py | 213 ++++++++++++++++++++++++ src/dartsort/templates/get_templates.py | 7 +- src/dartsort/templates/pairwise.py | 13 +- src/dartsort/templates/pairwise_util.py | 123 +++++++++++++- src/dartsort/templates/template_util.py | 14 +- src/dartsort/templates/templates.py | 18 +- src/dartsort/util/spikeio.py | 6 +- 7 files changed, 367 insertions(+), 27 deletions(-) diff --git a/src/dartsort/cluster/merge.py b/src/dartsort/cluster/merge.py index e69de29b..0c87bbe5 100644 --- a/src/dartsort/cluster/merge.py +++ b/src/dartsort/cluster/merge.py @@ -0,0 +1,213 @@ +from dataclasses import replace +from typing import Optional + +import numpy as np +from dartsort.config import TemplateConfig +from dartsort.templates import TemplateData, template_util +from dartsort.templates.pairwise_util import ( + construct_shift_indices, iterate_compressed_pairwise_convolutions) +from dartsort.util import DARTsortSorting +from scipy.cluster.hierarchy import complete, fcluster + + +def merge_templates( + sorting: DARTsortSorting, + recording, + template_data: Optional[TemplateData] = None, + template_config: Optional[TemplateConfig] = None, + motion_est=None, + max_shift_samples=20, + superres_linkage=np.max, + merge_distance_threshold=0.25, + temporal_upsampling_factor=8, + amplitude_scaling_variance=0.0, + amplitude_scaling_boundary=0.5, + svd_compression_rank=10, + min_channel_amplitude=0.0, + conv_batch_size=1024, + units_batch_size=8, + device=None, + n_jobs=0, + n_jobs_templates=0, + template_save_folder=None, + overwrite_templates=False, + show_progress=True, + template_npz_filename="template_data.npz", +): + if template_data is None: + template_data = TemplateData.from_config( + recording, + sorting, + template_config, + motion_est=motion_est, + n_jobs=n_jobs_templates, + save_folder=template_save_folder, + overwrite=overwrite_templates, + device=device, + save_npz_name=template_npz_filename, + ) + + # allocate distance + shift matrices. shifts[i,j] is trough[j]-trough[i]. 
+ n_templates = template_data.templates.shape[0] + sup_dists = np.full((n_templates, n_templates), np.inf) + sup_shifts = np.zero((n_templates, n_templates), dtype=int) + + # build distance matrix + dec_res_iter = get_deconv_resid_norm_iter( + template_data, + max_shift_samples=max_shift_samples, + temporal_upsampling_factor=temporal_upsampling_factor, + amplitude_scaling_variance=amplitude_scaling_variance, + amplitude_scaling_boundary=amplitude_scaling_boundary, + svd_compression_rank=svd_compression_rank, + min_channel_amplitude=min_channel_amplitude, + conv_batch_size=conv_batch_size, + units_batch_size=units_batch_size, + device=device, + n_jobs=n_jobs, + show_progress=show_progress, + ) + for res in dec_res_iter: + tixa = res.template_indices_a + tixb = res.template_indices_b + rms_ratio = res.deconv_resid_norms / res.template_a_norms + sup_dists[tixa, tixb] = rms_ratio + sup_shifts[tixa, tixb] = res.shifts + + # apply linkage to reduce across superres templates + units = np.unique(template_data.unit_ids) + if units.size < n_templates: + dists = np.full((units.size, units.size), np.inf) + shifts = np.zero((units.size, units.size), dtype=int) + for ia, ua in enumerate(units): + in_ua = np.flatnonzero(template_data.unit_ids == ua) + for ib, ub in enumerate(units): + in_ub = np.flatnonzero(template_data.unit_ids == ub) + in_pair = (in_ua[:, None], in_ub[None, :]) + dists[ia, ib] = superres_linkage(sup_dists[in_pair]) + shifts[ia, ib] = np.median(sup_shifts[in_pair]) + coarse_td = template_data.coarsen(with_locs=False) + template_snrs = coarse_td.templates.ptp(1).max(1) / coarse_td.spike_counts + else: + dists = sup_dists + shifts = sup_shifts + template_snrs = ( + template_data.templates.ptp(1).max(1) / template_data.spike_counts + ) + + # now run hierarchical clustering + return recluster( + sorting, + dists, + shifts, + template_snrs, + merge_distance_threshold=merge_distance_threshold, + ) + + +def recluster(sorting, dists, shifts, template_snrs, merge_distance_threshold=0.25): + # upper triangle not including diagonal, aka condensed distance matrix in scipy + pdist = dists[np.triu_indices(dists.shape[0], k=1)] + # scipy hierarchical clustering only supports finite values, so let's just + # drop in a huge value here + pdist[~np.isfinite(pdist)] = 1_000_000 + pdist[np.isfinite(pdist)].max() + # complete linkage: max dist between all pairs across clusters. + Z = complete(pdist) + # extract flat clustering using our max dist threshold + new_labels = fcluster(Z, merge_distance_threshold, criterion="distance") + + # update labels + labels_updated = sorting.labels.copy() + kept = np.flatnonzero(labels_updated >= 0) + labels_updated[kept] = new_labels[labels_updated[kept]] + + # update times according to shifts + times_updated = sorting.times_samples.copy() + + # find original labels in each cluster + clust_inverse = {i: [] for i in new_labels} + for orig_label, new_label in enumerate(new_labels): + clust_inverse[new_label].append(orig_label) + + # align to best snr unit + for new_label, orig_labels in clust_inverse.items(): + # we don't need to realign clusters which didn't change + if len(orig_labels) <= 1: + continue + + orig_snrs = template_snrs[orig_labels] + best_orig = orig_labels[orig_snrs.argmax()] + for ogl in np.setdiff1d(orig_labels, [best_orig]): + in_orig_unit = np.flatnonzero(sorting.labels == ogl) + # this is like trough[best] - trough[ogl] + shift_og_best = shifts[ogl, best_orig] + # if >0, trough of og is behind trough of best. 
+ # subtracting will move trough of og to the right. + times_updated[in_orig_unit] -= shift_og_best + + return replace(sorting, times_samples=times_updated, labels=labels_updated) + + +def get_deconv_resid_norm_iter( + template_data, + max_shift_samples=20, + temporal_upsampling_factor=8, + amplitude_scaling_variance=0.0, + amplitude_scaling_boundary=0.5, + svd_compression_rank=10, + min_channel_amplitude=0.0, + conv_batch_size=1024, + units_batch_size=8, + device=None, + n_jobs=0, + show_progress=True, +): + # get template aux data + low_rank_templates = template_util.svd_compress_templates( + template_data.templates, + min_channel_amplitude=min_channel_amplitude, + rank=svd_compression_rank, + ) + compressed_upsampled_temporal = template_util.compressed_upsampled_templates( + low_rank_templates.temporal_components, + ptps=template_data.templates.ptp(1).max(1), + max_upsample=temporal_upsampling_factor, + ) + + # construct helper data and run pairwise convolutions + ( + template_shift_index_a, + template_shift_index_b, + upsampled_shifted_template_index, + cooccurrence, + ) = construct_shift_indices( + None, + None, + template_data, + compressed_upsampled_temporal, + motion_est=None, + ) + yield from iterate_compressed_pairwise_convolutions( + template_data, + low_rank_templates, + template_data, + low_rank_templates, + compressed_upsampled_temporal, + template_shift_index_a, + template_shift_index_b, + cooccurrence, + upsampled_shifted_template_index, + do_shifting=False, + reduce_deconv_resid_norm=True, + geom=template_data.registered_geometry, + conv_ignore_threshold=0.0, + coarse_approx_error_threshold=0.0, + amplitude_scaling_variance=amplitude_scaling_variance, + amplitude_scaling_boundary=amplitude_scaling_boundary, + max_shift=max_shift_samples, + conv_batch_size=conv_batch_size, + units_batch_size=units_batch_size, + device=device, + n_jobs=n_jobs, + show_progress=show_progress, + ) diff --git a/src/dartsort/templates/get_templates.py b/src/dartsort/templates/get_templates.py index a354495d..51f53c57 100644 --- a/src/dartsort/templates/get_templates.py +++ b/src/dartsort/templates/get_templates.py @@ -102,6 +102,7 @@ def get_templates( # pad the trough_offset_samples and spike_length_samples so that # if the user did not request denoising we can just return the # raw templates right away + print("realign") trough_offset_load = trough_offset_samples + realign_max_sample_shift spike_length_load = spike_length_samples + 2 * realign_max_sample_shift raw_results = get_raw_templates( @@ -165,6 +166,7 @@ def get_templates( device=device, ) raw_templates, low_rank_templates, snrs_by_channel = res + print(f"{raw_templates.ptp(1).max(1)=}") if raw_only: return dict( @@ -571,8 +573,11 @@ def _template_job(unit_ids): # read waveforms for all units times = p.sorting.times_samples[in_units] valid = np.flatnonzero( - (times >= p.trough_offset_samples) & (times < p.max_spike_time) + (times >= p.trough_offset_samples) & (times <= p.max_spike_time) ) + print(f"{times=} {valid=}") + print(f"{p.trough_offset_samples=} {p.max_spike_time}") + print(f"{times.shape=} {valid.shape=}") if not valid.size: return in_units = in_units[valid] diff --git a/src/dartsort/templates/pairwise.py b/src/dartsort/templates/pairwise.py index 5b261288..e889746b 100644 --- a/src/dartsort/templates/pairwise.py +++ b/src/dartsort/templates/pairwise.py @@ -69,9 +69,11 @@ def __post_init__(self): ) def get_shift_ix_a(self, shifts_a): + shifts_a = torch.atleast_1d(torch.as_tensor(shifts_a)) return 
self.offset_shift_a_to_ix[shifts_a.to(int) + self.a_shift_offset] def get_shift_ix_b(self, shifts_b): + shifts_b = torch.atleast_1d(torch.as_tensor(shifts_b)) return self.offset_shift_b_to_ix[shifts_b.to(int) + self.b_shift_offset] @classmethod @@ -206,7 +208,6 @@ def at_shifts(self, shifts_a=None, shifts_b=None, device=None): def to(self, device=None, incl_pconv=False, pin=False): """Become torch tensors on device.""" - print(f"to {device=}") for name in ["offset_shift_a_to_ix", "offset_shift_b_to_ix"] + [ f.name for f in fields(self) ]: @@ -244,8 +245,8 @@ def query( template_indices_a = torch.arange( len(self.shifted_template_index_a), device=self.device ) - template_indices_a = torch.atleast_1d(template_indices_a) - template_indices_b = torch.atleast_1d(template_indices_b) + template_indices_a = torch.atleast_1d(torch.as_tensor(template_indices_a)) + template_indices_b = torch.atleast_1d(torch.as_tensor(template_indices_b)) # handle no shifting no_shifting = shifts_a is None or shifts_b is None @@ -271,18 +272,16 @@ def query( assert self.upsampled_shifted_template_index_b.shape[2] == 1 upsampled_shifted_template_index = upsampled_shifted_template_index[..., 0] else: - b_ix = b_ix + (upsampling_indices_b,) + b_ix = b_ix + (torch.atleast_1d(torch.as_tensor(upsampling_indices_b)),) # get shifted template indices for A + print(f"{a_ix=}") shifted_temp_ix_a = shifted_template_index[a_ix] # upsampled shifted template indices for B up_shifted_temp_ix_b = upsampled_shifted_template_index[b_ix] # return convolutions between all ai,bj or just ai,bi? - print(f"{shifted_temp_ix_a.device=} {up_shifted_temp_ix_b.device=}") - print(f"{self.device=} {self.shifts_a.device=}") - print(f"{template_indices_a.device=} {template_indices_b.device=}") if grid: pconv_indices = self.pconv_index[ shifted_temp_ix_a[:, None], up_shifted_temp_ix_b[None, :] diff --git a/src/dartsort/templates/pairwise_util.py b/src/dartsort/templates/pairwise_util.py index 9f0f45eb..6182ac46 100644 --- a/src/dartsort/templates/pairwise_util.py +++ b/src/dartsort/templates/pairwise_util.py @@ -2,8 +2,8 @@ from collections import namedtuple from dataclasses import dataclass, fields -from typing import Iterator, Optional, Union from pathlib import Path +from typing import Iterator, Optional, Union import h5py import numpy as np @@ -169,6 +169,9 @@ def iterate_compressed_pairwise_convolutions( conv_ignore_threshold=0.0, coarse_approx_error_threshold=0.0, max_shift="full", + amplitude_scaling_variance=0.0, + amplitude_scaling_boundary=0.5, + reduce_deconv_resid_norm=False, conv_batch_size=1024, units_batch_size=8, device=None, @@ -228,6 +231,9 @@ def iterate_compressed_pairwise_convolutions( coarse_approx_error_threshold=coarse_approx_error_threshold, max_shift=max_shift, batch_size=conv_batch_size, + amplitude_scaling_variance=amplitude_scaling_variance, + amplitude_scaling_boundary=amplitude_scaling_boundary, + reduce_deconv_resid_norm=reduce_deconv_resid_norm, ) n_jobs, Executor, context, rank_queue = get_pool(n_jobs, with_rank_queue=True) @@ -251,7 +257,9 @@ def iterate_compressed_pairwise_convolutions( @dataclass class CompressedConvResult: - """Return type of compressed_convolve_pairs + """Main return type of compressed_convolve_pairs + + If reduce_deconv_resid_norm=True, a DeconvResidResult is returned. After convolving a bunch of template pairs, some convolutions may be zero. Let n_pairs be the number of nonzero convolutions. 
@@ -278,6 +286,94 @@ class CompressedConvResult: compressed_conv: np.ndarray +@dataclass +class DeconvResidResult: + """Return type of compressed_convolve_pairs + + After convolving a bunch of template pairs, some convolutions + may be zero. Let n_pairs be the number of nonzero convolutions. + We don't store the zero ones. + """ + + # arrays of shape n_pairs, + # For each convolved pair, these document which templates were + # in the pair, what their relative shifts were, and what the + # upsampling was (we only upsample the RHS) + template_indices_a: np.ndarray + template_indices_b: np.ndarray + + # norm after subtracting best upsampled/scaled/shifted B template from A template + deconv_resid_norms: np.ndarray + + # ints. B trough - A trough + shifts: np.ndarray + + # for caller to implement different metrics + template_a_norms: np.ndarray + + # TODO: how to handle the nnz normalization we used to do? + # that one was done wrong -- the residual was not restricted + # to high amplitude channels. + + +def conv_to_resid( + template_data_a: templates.TemplateData, + template_data_b: templates.TemplateData, + conv_result: CompressedConvResult, + amplitude_scaling_variance=0.0, + amplitude_scaling_boundary=0.5, +) -> DeconvResidResult: + # decompress + pconvs = conv_result.compressed_conv[conv_result.compression_index] + full_length = pconvs.shape[1] + center = full_length // 2 + + # here, we just care about pairs of (superres) templates, not upsampling + # or shifting. so, get unique such pairs. + pairs = np.c_[conv_result.template_indices_a, conv_result.template_indices_b] + pairs = np.unique(pairs, axis=0) + n_pairs = len(pairs) + + # for loop to reduce over all (upsampled etc) member templates + deconv_resid_norms = np.zeros(n_pairs) + shifts = np.zeros(n_pairs, dtype=int) + template_indices_a, template_indices_b = pairs.T + templates_a = template_data_a.templates[template_indices_a].numpy(force=True) + templates_b = template_data_b.templates[template_indices_b].numpy(force=True) + template_a_norms = np.linalg.norm(templates_a, axis=(1, 2)) + template_b_norms = np.linalg.norm(templates_b, axis=(1, 2)) + for j, (ix_a, ix_b) in enumerate(pairs): + in_a = conv_result.template_indices_a == ix_a + in_b = conv_result.template_indices_b == ix_b + in_pair = np.flatnonzero(in_a & in_b) + + # reduce over fine templates + pair_conv = pconvs[in_pair].max(dim=0).values + best_conv, lag_index = pair_conv.max() + shifts[j] = lag_index - center + + # figure out scaling + if amplitude_scaling_variance: + amp_scale_min = 1 / (1 + amplitude_scaling_boundary) + amp_scale_max = 1 + amplitude_scaling_boundary + inv_lambda = 1 / amplitude_scaling_variance + b = best_conv + inv_lambda + a = template_a_norms[j] + inv_lambda + scaling = torch.clip(b / a, amp_scale_min, amp_scale_max) + norm_reduction = 2 * scaling * b - torch.square(scaling) * a - inv_lambda + else: + norm_reduction = 2 * best_conv - template_b_norms[j] + deconv_resid_norms[j] = template_a_norms[j] - norm_reduction + + return DeconvResidResult( + template_indices_a, + template_indices_b, + deconv_resid_norms, + shifts, + template_a_norms, + ) + + def compressed_convolve_pairs( template_data_a: templates.TemplateData, template_data_b: templates.TemplateData, @@ -297,6 +393,9 @@ def compressed_convolve_pairs( units_b: Optional[np.ndarray] = None, conv_ignore_threshold=0.0, coarse_approx_error_threshold=0.0, + amplitude_scaling_variance=0.0, + amplitude_scaling_boundary=0.5, + reduce_deconv_resid_norm=False, max_shift="full", batch_size=1024, 
device=None, @@ -363,7 +462,7 @@ def compressed_convolve_pairs( if pairs_ret is None: return None ix_a, ix_b, compression_index, conv_ix, spatial_shift_ids = pairs_ret - + # handle upsampling # each pair will be duplicated by the b unit's number of upsampled copies ( @@ -427,7 +526,7 @@ def compressed_convolve_pairs( temp_ix_b = temp_ix_b[ix_b] shift_ix_b = np.searchsorted(template_shift_index_b.all_pitch_shifts, shift_b[ix_b]) - return CompressedConvResult( + res = CompressedConvResult( template_indices_a=temp_ix_a, template_indices_b=temp_ix_b, shift_indices_a=shift_ix_a, @@ -436,6 +535,15 @@ def compressed_convolve_pairs( compression_index=compression_index, compressed_conv=pconv.numpy(force=True), ) + if reduce_deconv_resid_norm: + return conv_to_resid( + template_data_a, + template_data_b, + res, + amplitude_scaling_variance=amplitude_scaling_variance, + amplitude_scaling_boundary=amplitude_scaling_boundary, + ) + return res # -- helpers @@ -480,7 +588,7 @@ def correlate_pairs_lowrank( assert n_pairs == n_pairs_ assert t == t_ assert rank == rank_ - + if max_shift == "full": max_shift = t - 1 elif max_shift == "valid": @@ -514,7 +622,7 @@ def correlate_pairs_lowrank( pconv[istart:iend] = pconv_[0, :, 0, :] # nco, nup, time # more stringent covisibility - if conv_ignore_threshold > 0: + if conv_ignore_threshold is not None: max_val = pconv.reshape(n_pairs, -1).abs().max(dim=1).values kept = max_val > conv_ignore_threshold pconv = pconv[kept] @@ -1007,6 +1115,9 @@ class ConvWorkerContext: match_distance: Optional[float] = None conv_ignore_threshold: float = 0.0 coarse_approx_error_threshold: float = 0.0 + amplitude_scaling_variance: float = 0.0 + amplitude_scaling_boundary: float = 0.5 + reduce_deconv_resid_norm: bool = False max_shift: Union[int, str] = "full" batch_size: int = 128 device: Optional[torch.device] = None diff --git a/src/dartsort/templates/template_util.py b/src/dartsort/templates/template_util.py index 57f303bd..30da5991 100644 --- a/src/dartsort/templates/template_util.py +++ b/src/dartsort/templates/template_util.py @@ -321,7 +321,7 @@ def compressed_upsampled_templates( n_upsamples = np.clip(n_upsamples_map(ptps), 1, max_upsample).astype(int) # build the compressed upsampling map - compressed_upsampling_map = np.zeros((n_templates, max_upsample), dtype=int) + compressed_upsampling_map = np.full((n_templates, max_upsample), -1, dtype=int) compressed_upsampling_index = np.full((n_templates, max_upsample), -1, dtype=int) template_indices = [] upsampling_indices = [] @@ -340,9 +340,19 @@ def compressed_upsampled_templates( # indices of the templates to keep in the full array of upsampled templates template_indices.extend([i] * nup) upsampling_indices.extend(compression * np.arange(nup)) + assert (compressed_upsampling_map >= 0).all() + assert ( + np.unique(compressed_upsampling_map).size + == (compressed_upsampling_index >= 0).sum() + == compressed_upsampling_map.max() + 1 + == compressed_upsampling_index.max() + 1 + == current_compressed_index + ) template_indices = np.array(template_indices) upsampling_indices = np.array(upsampling_indices) - compressed_upsampling_index[compressed_upsampling_index < 0] = current_compressed_index + compressed_upsampling_index[ + compressed_upsampling_index < 0 + ] = current_compressed_index # get the upsampled templates all_upsampled_templates = temporally_upsample_templates( diff --git a/src/dartsort/templates/templates.py b/src/dartsort/templates/templates.py index 1abeee9d..b5af6bd7 100644 --- a/src/dartsort/templates/templates.py 
+++ b/src/dartsort/templates/templates.py @@ -49,7 +49,7 @@ def to_npz(self, npz_path): ] = self.registered_template_depths_um np.savez(npz_path, **to_save) - def coarsen(self): + def coarsen(self, with_locs=True): """Weighted average all templates that share a unit id and re-localize.""" # update templates unit_ids_unique, flat_ids = np.unique(self.unit_ids, return_inverse=True) @@ -60,11 +60,13 @@ def coarsen(self): np.add.at(spike_counts, flat_ids, self.spike_counts) # re-localize - registered_template_depths_um = get_template_depths( - templates, - self.registered_geom, - localization_radius_um=self.localization_radius_um, - ) + registered_template_depths_um = None + if with_locs: + registered_template_depths_um = get_template_depths( + templates, + self.registered_geom, + localization_radius_um=self.localization_radius_um, + ) return replace( self, @@ -167,7 +169,9 @@ def from_config( min_spikes_per_bin=template_config.superres_bin_min_spikes, ) else: + # we don't skip empty units unit_ids = np.arange(sorting.labels.max() + 1) + print(f"post superres {sorting.times_samples=} {sorting.labels=}") # count spikes in each template spike_counts = np.zeros_like(unit_ids) @@ -178,8 +182,6 @@ def from_config( results = get_templates(recording, sorting, **kwargs) # handle registered templates - print(f"{results['templates'].shape=}") - print(f"{kwargs['registered_geom'].shape=}") if template_config.registered_templates: registered_template_depths_um = get_template_depths( results["templates"], diff --git a/src/dartsort/util/spikeio.py b/src/dartsort/util/spikeio.py index 8f678c0e..0430b82a 100644 --- a/src/dartsort/util/spikeio.py +++ b/src/dartsort/util/spikeio.py @@ -21,7 +21,7 @@ def read_full_waveforms( assert times_samples.dtype.kind == "i" assert ( times_samples.max() - < recording.get_num_samples() + <= recording.get_num_samples() - (spike_length_samples - trough_offset_samples) ) n_channels = recording.get_num_channels() @@ -92,7 +92,7 @@ def read_subset_waveforms( assert times_samples.dtype.kind == "i" assert ( times_samples.max() - < recording.get_num_samples() + <= recording.get_num_samples() - (spike_length_samples - trough_offset_samples) ) n_channels = recording.get_num_channels() @@ -169,7 +169,7 @@ def read_waveforms_channel_index( assert times_samples.min() >= trough_offset_samples assert ( times_samples.max() - < recording.get_num_samples() + <= recording.get_num_samples() - (spike_length_samples - trough_offset_samples) ) n_channels = recording.get_num_channels() From f1795d9953d8a3d18c4a46d6f6093248b7abe9e5 Mon Sep 17 00:00:00 2001 From: Charlie Windolf Date: Tue, 5 Dec 2023 19:49:51 -0500 Subject: [PATCH 48/49] No print --- src/dartsort/templates/pairwise.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/src/dartsort/templates/pairwise.py b/src/dartsort/templates/pairwise.py index 5b261288..7d7f125c 100644 --- a/src/dartsort/templates/pairwise.py +++ b/src/dartsort/templates/pairwise.py @@ -206,7 +206,6 @@ def at_shifts(self, shifts_a=None, shifts_b=None, device=None): def to(self, device=None, incl_pconv=False, pin=False): """Become torch tensors on device.""" - print(f"to {device=}") for name in ["offset_shift_a_to_ix", "offset_shift_b_to_ix"] + [ f.name for f in fields(self) ]: @@ -280,9 +279,6 @@ def query( up_shifted_temp_ix_b = upsampled_shifted_template_index[b_ix] # return convolutions between all ai,bj or just ai,bi? 
- print(f"{shifted_temp_ix_a.device=} {up_shifted_temp_ix_b.device=}") - print(f"{self.device=} {self.shifts_a.device=}") - print(f"{template_indices_a.device=} {template_indices_b.device=}") if grid: pconv_indices = self.pconv_index[ shifted_temp_ix_a[:, None], up_shifted_temp_ix_b[None, :] From 96e124696e9d05dc3cdda18d303b85e229a66831 Mon Sep 17 00:00:00 2001 From: Charlie Windolf Date: Tue, 5 Dec 2023 22:52:55 -0500 Subject: [PATCH 49/49] Debug merge --- src/dartsort/cluster/merge.py | 39 ++++++++++++++++++++----- src/dartsort/peel/matching.py | 3 +- src/dartsort/templates/get_templates.py | 5 ---- src/dartsort/templates/pairwise_util.py | 20 +++++++------ src/dartsort/templates/templates.py | 3 +- src/dartsort/util/drift_util.py | 2 -- 6 files changed, 46 insertions(+), 26 deletions(-) diff --git a/src/dartsort/cluster/merge.py b/src/dartsort/cluster/merge.py index 0c87bbe5..fb0f20c3 100644 --- a/src/dartsort/cluster/merge.py +++ b/src/dartsort/cluster/merge.py @@ -6,7 +6,7 @@ from dartsort.templates import TemplateData, template_util from dartsort.templates.pairwise_util import ( construct_shift_indices, iterate_compressed_pairwise_convolutions) -from dartsort.util import DARTsortSorting +from dartsort.util.data_util import DARTsortSorting from scipy.cluster.hierarchy import complete, fcluster @@ -33,7 +33,28 @@ def merge_templates( overwrite_templates=False, show_progress=True, template_npz_filename="template_data.npz", -): +) -> DARTsortSorting: + """Template distance based merge + + Pass in a sorting, recording and template config to make templates, + and this will merge them (with superres). Or, if you have templates + already, pass them into template_data and we can skip the template + construction. + + Arguments + --------- + max_shift_samples + Max offset during matching + superres_linkage + How to combine distances between two units' superres templates + By default, it's the max. + amplitude_scaling_* + Optionally allow scaling during matching + + Returns + ------- + A new DARTsortSorting + """ if template_data is None: template_data = TemplateData.from_config( recording, @@ -50,7 +71,7 @@ def merge_templates( # allocate distance + shift matrices. shifts[i,j] is trough[j]-trough[i]. 
n_templates = template_data.templates.shape[0] sup_dists = np.full((n_templates, n_templates), np.inf) - sup_shifts = np.zero((n_templates, n_templates), dtype=int) + sup_shifts = np.zeros((n_templates, n_templates), dtype=int) # build distance matrix dec_res_iter = get_deconv_resid_norm_iter( @@ -78,7 +99,7 @@ def merge_templates( units = np.unique(template_data.unit_ids) if units.size < n_templates: dists = np.full((units.size, units.size), np.inf) - shifts = np.zero((units.size, units.size), dtype=int) + shifts = np.zeros((units.size, units.size), dtype=int) for ia, ua in enumerate(units): in_ua = np.flatnonzero(template_data.unit_ids == ua) for ib, ub in enumerate(units): @@ -98,6 +119,7 @@ def merge_templates( # now run hierarchical clustering return recluster( sorting, + units, dists, shifts, template_snrs, @@ -105,7 +127,7 @@ def merge_templates( ) -def recluster(sorting, dists, shifts, template_snrs, merge_distance_threshold=0.25): +def recluster(sorting, units, dists, shifts, template_snrs, merge_distance_threshold=0.25): # upper triangle not including diagonal, aka condensed distance matrix in scipy pdist = dists[np.triu_indices(dists.shape[0], k=1)] # scipy hierarchical clustering only supports finite values, so let's just @@ -118,8 +140,9 @@ def recluster(sorting, dists, shifts, template_snrs, merge_distance_threshold=0. # update labels labels_updated = sorting.labels.copy() - kept = np.flatnonzero(labels_updated >= 0) - labels_updated[kept] = new_labels[labels_updated[kept]] + kept = np.flatnonzero(np.isin(sorting.labels, units)) + _, flat_labels = np.unique(labels_updated[kept], return_inverse=True) + labels_updated[kept] = new_labels[flat_labels] # update times according to shifts times_updated = sorting.times_samples.copy() @@ -199,7 +222,7 @@ def get_deconv_resid_norm_iter( upsampled_shifted_template_index, do_shifting=False, reduce_deconv_resid_norm=True, - geom=template_data.registered_geometry, + geom=template_data.registered_geom, conv_ignore_threshold=0.0, coarse_approx_error_threshold=0.0, amplitude_scaling_variance=amplitude_scaling_variance, diff --git a/src/dartsort/peel/matching.py b/src/dartsort/peel/matching.py index db8169d5..bf17020b 100644 --- a/src/dartsort/peel/matching.py +++ b/src/dartsort/peel/matching.py @@ -427,7 +427,8 @@ def templates_at_time(self, t_s): """Handle drift -- grab the right spatial neighborhoods.""" pconvdb = self.pairwise_conv_db pitch_shifts_a = pitch_shifts_b = None - pconvdb.to(self.objective_spatial_components.device, pin=True) + if self.objective_spatial_components.device.type == "cuda" and not pconvdb.device.type == "cuda": + pconvdb.to(self.objective_spatial_components.device) if self.is_drifting: pitch_shifts_b, cur_spatial = template_util.templates_at_time( t_s, diff --git a/src/dartsort/templates/get_templates.py b/src/dartsort/templates/get_templates.py index 51f53c57..58b75d22 100644 --- a/src/dartsort/templates/get_templates.py +++ b/src/dartsort/templates/get_templates.py @@ -102,7 +102,6 @@ def get_templates( # pad the trough_offset_samples and spike_length_samples so that # if the user did not request denoising we can just return the # raw templates right away - print("realign") trough_offset_load = trough_offset_samples + realign_max_sample_shift spike_length_load = spike_length_samples + 2 * realign_max_sample_shift raw_results = get_raw_templates( @@ -166,7 +165,6 @@ def get_templates( device=device, ) raw_templates, low_rank_templates, snrs_by_channel = res - print(f"{raw_templates.ptp(1).max(1)=}") if raw_only: 
return dict( @@ -575,9 +573,6 @@ def _template_job(unit_ids): valid = np.flatnonzero( (times >= p.trough_offset_samples) & (times <= p.max_spike_time) ) - print(f"{times=} {valid=}") - print(f"{p.trough_offset_samples=} {p.max_spike_time}") - print(f"{times.shape=} {valid.shape=}") if not valid.size: return in_units = in_units[valid] diff --git a/src/dartsort/templates/pairwise_util.py b/src/dartsort/templates/pairwise_util.py index 6182ac46..1ac6ed8b 100644 --- a/src/dartsort/templates/pairwise_util.py +++ b/src/dartsort/templates/pairwise_util.py @@ -338,20 +338,21 @@ def conv_to_resid( deconv_resid_norms = np.zeros(n_pairs) shifts = np.zeros(n_pairs, dtype=int) template_indices_a, template_indices_b = pairs.T - templates_a = template_data_a.templates[template_indices_a].numpy(force=True) - templates_b = template_data_b.templates[template_indices_b].numpy(force=True) - template_a_norms = np.linalg.norm(templates_a, axis=(1, 2)) - template_b_norms = np.linalg.norm(templates_b, axis=(1, 2)) + templates_a = template_data_a.templates[template_indices_a] + templates_b = template_data_b.templates[template_indices_b] + template_a_norms = np.linalg.norm(templates_a, axis=(1, 2)) ** 2 + template_b_norms = np.linalg.norm(templates_b, axis=(1, 2)) ** 2 for j, (ix_a, ix_b) in enumerate(pairs): in_a = conv_result.template_indices_a == ix_a in_b = conv_result.template_indices_b == ix_b in_pair = np.flatnonzero(in_a & in_b) # reduce over fine templates - pair_conv = pconvs[in_pair].max(dim=0).values - best_conv, lag_index = pair_conv.max() + pair_conv = pconvs[in_pair].max(axis=0) + lag_index = np.argmax(pair_conv) + best_conv = pair_conv[lag_index] shifts[j] = lag_index - center - + # figure out scaling if amplitude_scaling_variance: amp_scale_min = 1 / (1 + amplitude_scaling_boundary) @@ -359,11 +360,12 @@ def conv_to_resid( inv_lambda = 1 / amplitude_scaling_variance b = best_conv + inv_lambda a = template_a_norms[j] + inv_lambda - scaling = torch.clip(b / a, amp_scale_min, amp_scale_max) - norm_reduction = 2 * scaling * b - torch.square(scaling) * a - inv_lambda + scaling = np.clip(b / a, amp_scale_min, amp_scale_max) + norm_reduction = 2 * scaling * b - np.square(scaling) * a - inv_lambda else: norm_reduction = 2 * best_conv - template_b_norms[j] deconv_resid_norms[j] = template_a_norms[j] - norm_reduction + assert deconv_resid_norms[j] >= 0 return DeconvResidResult( template_indices_a, diff --git a/src/dartsort/templates/templates.py b/src/dartsort/templates/templates.py index b5af6bd7..6ff43c34 100644 --- a/src/dartsort/templates/templates.py +++ b/src/dartsort/templates/templates.py @@ -47,6 +47,8 @@ def to_npz(self, npz_path): to_save[ "registered_template_depths_um" ] = self.registered_template_depths_um + if not npz_path.parent.exists(): + npz_path.parent.mkdir() np.savez(npz_path, **to_save) def coarsen(self, with_locs=True): @@ -171,7 +173,6 @@ def from_config( else: # we don't skip empty units unit_ids = np.arange(sorting.labels.max() + 1) - print(f"post superres {sorting.times_samples=} {sorting.labels=}") # count spikes in each template spike_counts = np.zeros_like(unit_ids) diff --git a/src/dartsort/util/drift_util.py b/src/dartsort/util/drift_util.py index 0bf4464a..7c3dce5d 100644 --- a/src/dartsort/util/drift_util.py +++ b/src/dartsort/util/drift_util.py @@ -564,8 +564,6 @@ def get_shift_and_unit_pairs( else: shifts_a = pitch_shifts[:, :na] shifts_b = pitch_shifts[:, na:] - print(f"{shifts_a.min()=} {shifts_a.max()=}") - print(f"{shifts_b.min()=} {shifts_b.max()=}") # assign ids to 
pitch/shift pairs template_shift_index_a = TemplateShiftIndex.from_shift_matrix(shifts_a)
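
The deconv-residual distance driving this merge (conv_to_resid above) reduces, in the unscaled case, to the identity ||T_a - T_b||^2 = ||T_a||^2 - (2 <T_a, T_b> - ||T_b||^2), with <T_a, T_b> taken at the best alignment found by the pairwise convolution. The snippet below is a standalone check of that identity on toy arrays at zero lag; the array names are illustrative only and are not part of the patch.

import numpy as np

rng = np.random.default_rng(0)
# toy stand-ins for one pair of (time, channels) templates
temp_a = rng.normal(size=(121, 8)).astype(np.float32)
temp_b = rng.normal(size=(121, 8)).astype(np.float32)

# at zero lag, the peak of the pairwise convolution is just the inner product
best_conv = float((temp_a * temp_b).sum())

# unscaled branch of conv_to_resid
norm_a = float(np.square(temp_a).sum())
norm_b = float(np.square(temp_b).sum())
norm_reduction = 2 * best_conv - norm_b
deconv_resid_norm = norm_a - norm_reduction

# same quantity computed directly by subtracting the template
assert np.isclose(deconv_resid_norm, np.square(temp_a - temp_b).sum(), rtol=1e-4)

# merge_templates then normalizes by ||T_a||^2 before thresholding
dist = deconv_resid_norm / norm_a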
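
recluster then feeds these distances to SciPy's hierarchical clustering in condensed (upper-triangle) form and cuts the dendrogram at merge_distance_threshold. A minimal sketch of that pattern, with a made-up 4-unit distance matrix and threshold rather than values from the patch:

import numpy as np
from scipy.cluster.hierarchy import complete, fcluster

# symmetric pairwise distances for 4 units; inf marks pairs that never
# overlapped and must not merge
dists = np.array(
    [
        [0.0, 0.1, 2.0, np.inf],
        [0.1, 0.0, 2.0, np.inf],
        [2.0, 2.0, 0.0, 0.2],
        [np.inf, np.inf, 0.2, 0.0],
    ]
)

# condensed form: row-major upper triangle without the diagonal
pdist = dists[np.triu_indices(len(dists), k=1)]
# scipy only supports finite values, so swap inf for a huge finite number
pdist[~np.isfinite(pdist)] = 1_000_000 + pdist[np.isfinite(pdist)].max()

# complete linkage: two clusters merge only if *all* cross-pair distances
# fall below the cut, the conservative choice when each unit contributes
# several superres templates
Z = complete(pdist)
new_labels = fcluster(Z, 0.25, criterion="distance")  # flat labels start at 1
# units 0 and 1 end up sharing a label, as do 2 and 3; the two groups stay
# separate because their cross distances are far above the cut
print(new_labels)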