diff --git a/environment.yml b/environment.yml index cf7f34c6..d8123cce 100644 --- a/environment.yml +++ b/environment.yml @@ -7,3 +7,4 @@ dependencies: - h5py - tqdm - scikit-learn + - colorcet diff --git a/requirements.txt b/requirements.txt index e18ff27a..a152b042 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,4 +1,4 @@ pytest ibl-neuropixel spikeinterface -cloudpickle \ No newline at end of file +cloudpickle diff --git a/scripts/uhd_pipeline.py b/scripts/uhd_pipeline.py index 215fbf20..331266a3 100644 --- a/scripts/uhd_pipeline.py +++ b/scripts/uhd_pipeline.py @@ -164,7 +164,7 @@ # Don't trust spikeinterface preprocessing :( ... if preprocessing: print("Preprocessing...") - preprocessing_dir = Path(output_all) / "preprocessing_test" + preprocessing_dir = Path(output_all) / "preprocessing" Path(preprocessing_dir).mkdir(exist_ok=True) if t_end_preproc is None: t_end_preproc=rec_len_sec diff --git a/src/dartsort/cluster/merge.py b/src/dartsort/cluster/merge.py index e69de29b..fb0f20c3 100644 --- a/src/dartsort/cluster/merge.py +++ b/src/dartsort/cluster/merge.py @@ -0,0 +1,236 @@ +from dataclasses import replace +from typing import Optional + +import numpy as np +from dartsort.config import TemplateConfig +from dartsort.templates import TemplateData, template_util +from dartsort.templates.pairwise_util import ( + construct_shift_indices, iterate_compressed_pairwise_convolutions) +from dartsort.util.data_util import DARTsortSorting +from scipy.cluster.hierarchy import complete, fcluster + + +def merge_templates( + sorting: DARTsortSorting, + recording, + template_data: Optional[TemplateData] = None, + template_config: Optional[TemplateConfig] = None, + motion_est=None, + max_shift_samples=20, + superres_linkage=np.max, + merge_distance_threshold=0.25, + temporal_upsampling_factor=8, + amplitude_scaling_variance=0.0, + amplitude_scaling_boundary=0.5, + svd_compression_rank=10, + min_channel_amplitude=0.0, + conv_batch_size=1024, + units_batch_size=8, + device=None, + n_jobs=0, + n_jobs_templates=0, + template_save_folder=None, + overwrite_templates=False, + show_progress=True, + template_npz_filename="template_data.npz", +) -> DARTsortSorting: + """Template distance based merge + + Pass in a sorting, recording and template config to make templates, + and this will merge them (with superres). Or, if you have templates + already, pass them into template_data and we can skip the template + construction. + + Arguments + --------- + max_shift_samples + Max offset during matching + superres_linkage + How to combine distances between two units' superres templates + By default, it's the max. + amplitude_scaling_* + Optionally allow scaling during matching + + Returns + ------- + A new DARTsortSorting + """ + if template_data is None: + template_data = TemplateData.from_config( + recording, + sorting, + template_config, + motion_est=motion_est, + n_jobs=n_jobs_templates, + save_folder=template_save_folder, + overwrite=overwrite_templates, + device=device, + save_npz_name=template_npz_filename, + ) + + # allocate distance + shift matrices. shifts[i,j] is trough[j]-trough[i]. + n_templates = template_data.templates.shape[0] + sup_dists = np.full((n_templates, n_templates), np.inf) + sup_shifts = np.zeros((n_templates, n_templates), dtype=int) + + # build distance matrix + dec_res_iter = get_deconv_resid_norm_iter( + template_data, + max_shift_samples=max_shift_samples, + temporal_upsampling_factor=temporal_upsampling_factor, + amplitude_scaling_variance=amplitude_scaling_variance, + amplitude_scaling_boundary=amplitude_scaling_boundary, + svd_compression_rank=svd_compression_rank, + min_channel_amplitude=min_channel_amplitude, + conv_batch_size=conv_batch_size, + units_batch_size=units_batch_size, + device=device, + n_jobs=n_jobs, + show_progress=show_progress, + ) + for res in dec_res_iter: + tixa = res.template_indices_a + tixb = res.template_indices_b + rms_ratio = res.deconv_resid_norms / res.template_a_norms + sup_dists[tixa, tixb] = rms_ratio + sup_shifts[tixa, tixb] = res.shifts + + # apply linkage to reduce across superres templates + units = np.unique(template_data.unit_ids) + if units.size < n_templates: + dists = np.full((units.size, units.size), np.inf) + shifts = np.zeros((units.size, units.size), dtype=int) + for ia, ua in enumerate(units): + in_ua = np.flatnonzero(template_data.unit_ids == ua) + for ib, ub in enumerate(units): + in_ub = np.flatnonzero(template_data.unit_ids == ub) + in_pair = (in_ua[:, None], in_ub[None, :]) + dists[ia, ib] = superres_linkage(sup_dists[in_pair]) + shifts[ia, ib] = np.median(sup_shifts[in_pair]) + coarse_td = template_data.coarsen(with_locs=False) + template_snrs = coarse_td.templates.ptp(1).max(1) / coarse_td.spike_counts + else: + dists = sup_dists + shifts = sup_shifts + template_snrs = ( + template_data.templates.ptp(1).max(1) / template_data.spike_counts + ) + + # now run hierarchical clustering + return recluster( + sorting, + units, + dists, + shifts, + template_snrs, + merge_distance_threshold=merge_distance_threshold, + ) + + +def recluster(sorting, units, dists, shifts, template_snrs, merge_distance_threshold=0.25): + # upper triangle not including diagonal, aka condensed distance matrix in scipy + pdist = dists[np.triu_indices(dists.shape[0], k=1)] + # scipy hierarchical clustering only supports finite values, so let's just + # drop in a huge value here + pdist[~np.isfinite(pdist)] = 1_000_000 + pdist[np.isfinite(pdist)].max() + # complete linkage: max dist between all pairs across clusters. + Z = complete(pdist) + # extract flat clustering using our max dist threshold + new_labels = fcluster(Z, merge_distance_threshold, criterion="distance") + + # update labels + labels_updated = sorting.labels.copy() + kept = np.flatnonzero(np.isin(sorting.labels, units)) + _, flat_labels = np.unique(labels_updated[kept], return_inverse=True) + labels_updated[kept] = new_labels[flat_labels] + + # update times according to shifts + times_updated = sorting.times_samples.copy() + + # find original labels in each cluster + clust_inverse = {i: [] for i in new_labels} + for orig_label, new_label in enumerate(new_labels): + clust_inverse[new_label].append(orig_label) + + # align to best snr unit + for new_label, orig_labels in clust_inverse.items(): + # we don't need to realign clusters which didn't change + if len(orig_labels) <= 1: + continue + + orig_snrs = template_snrs[orig_labels] + best_orig = orig_labels[orig_snrs.argmax()] + for ogl in np.setdiff1d(orig_labels, [best_orig]): + in_orig_unit = np.flatnonzero(sorting.labels == ogl) + # this is like trough[best] - trough[ogl] + shift_og_best = shifts[ogl, best_orig] + # if >0, trough of og is behind trough of best. + # subtracting will move trough of og to the right. + times_updated[in_orig_unit] -= shift_og_best + + return replace(sorting, times_samples=times_updated, labels=labels_updated) + + +def get_deconv_resid_norm_iter( + template_data, + max_shift_samples=20, + temporal_upsampling_factor=8, + amplitude_scaling_variance=0.0, + amplitude_scaling_boundary=0.5, + svd_compression_rank=10, + min_channel_amplitude=0.0, + conv_batch_size=1024, + units_batch_size=8, + device=None, + n_jobs=0, + show_progress=True, +): + # get template aux data + low_rank_templates = template_util.svd_compress_templates( + template_data.templates, + min_channel_amplitude=min_channel_amplitude, + rank=svd_compression_rank, + ) + compressed_upsampled_temporal = template_util.compressed_upsampled_templates( + low_rank_templates.temporal_components, + ptps=template_data.templates.ptp(1).max(1), + max_upsample=temporal_upsampling_factor, + ) + + # construct helper data and run pairwise convolutions + ( + template_shift_index_a, + template_shift_index_b, + upsampled_shifted_template_index, + cooccurrence, + ) = construct_shift_indices( + None, + None, + template_data, + compressed_upsampled_temporal, + motion_est=None, + ) + yield from iterate_compressed_pairwise_convolutions( + template_data, + low_rank_templates, + template_data, + low_rank_templates, + compressed_upsampled_temporal, + template_shift_index_a, + template_shift_index_b, + cooccurrence, + upsampled_shifted_template_index, + do_shifting=False, + reduce_deconv_resid_norm=True, + geom=template_data.registered_geom, + conv_ignore_threshold=0.0, + coarse_approx_error_threshold=0.0, + amplitude_scaling_variance=amplitude_scaling_variance, + amplitude_scaling_boundary=amplitude_scaling_boundary, + max_shift=max_shift_samples, + conv_batch_size=conv_batch_size, + units_batch_size=units_batch_size, + device=device, + n_jobs=n_jobs, + show_progress=show_progress, + ) diff --git a/src/dartsort/cluster/split.py b/src/dartsort/cluster/split.py index 4b4c7f66..479b4232 100644 --- a/src/dartsort/cluster/split.py +++ b/src/dartsort/cluster/split.py @@ -74,9 +74,7 @@ def split_clusters( new_labels = split_result.new_labels triaged = split_result.new_labels < 0 labels[in_unit[triaged]] = new_labels[triaged] - labels[in_unit[new_labels > 0]] = ( - cur_max_label + new_labels[new_labels > 0] - ) + labels[in_unit[new_labels > 0]] = cur_max_label + new_labels[new_labels > 0] new_untriaged_labels = labels[in_unit[new_labels >= 0]] cur_max_label = new_untriaged_labels.max() @@ -84,9 +82,7 @@ def split_clusters( if recursive: new_units = np.unique(new_untriaged_labels) for i in new_units: - jobs.append( - pool.submit(_split_job, np.flatnonzero(labels == i)) - ) + jobs.append(pool.submit(_split_job, np.flatnonzero(labels == i))) if show_progress: iterator.total += len(new_units) @@ -151,7 +147,7 @@ def __init__( min_cluster_size=25, min_samples=25, cluster_selection_epsilon=25, - reassign_outliers=True, + reassign_outliers=False, random_state=0, **dataset_name_kwargs, ): @@ -241,18 +237,14 @@ def split_cluster(self, in_unit): is_split = np.setdiff1d(np.unique(hdb_labels), [-1]).size > 1 if is_split and self.reassign_outliers: - hdb_labels = cluster_util.knn_reassign_outliers( - hdb_labels, features - ) + hdb_labels = cluster_util.knn_reassign_outliers(hdb_labels, features) new_labels = None if is_split: new_labels = np.full(n_spikes, -1) new_labels[kept] = hdb_labels - return SplitResult( - is_split=is_split, in_unit=in_unit, new_labels=new_labels - ) + return SplitResult(is_split=is_split, in_unit=in_unit, new_labels=new_labels) def pca_features(self, in_unit): """Compute relocated PCA features on a drift-invariant channel set""" @@ -316,12 +308,8 @@ def pca_features(self, in_unit): return False, no_nan, None # fit pca and embed - pca = PCA( - self.n_pca_features, random_state=self.random_state, whiten=True - ) - pca_projs = np.full( - (n, self.n_pca_features), np.nan, dtype=waveforms.dtype - ) + pca = PCA(self.n_pca_features, random_state=self.random_state, whiten=True) + pca_projs = np.full((n, self.n_pca_features), np.nan, dtype=waveforms.dtype) pca_projs[no_nan] = pca.fit_transform(waveforms[no_nan]) return True, no_nan, pca_projs @@ -335,7 +323,7 @@ def initialize_from_h5( amplitudes_dataset_name="denoised_amplitudes", amplitude_vectors_dataset_name="denoised_amplitude_vectors", ): - h5 = h5py.File(peeling_hdf5_filename, "r") + h5 = h5py.File(peeling_hdf5_filename, "r", locking=False) self.geom = h5["geom"][:] self.channel_index = h5["channel_index"][:] @@ -386,9 +374,7 @@ def initialize_from_h5( # this is to help split_clusters take a string argument all_split_strategies = [FeatureSplit] -split_strategies_by_class_name = { - cls.__name__: cls for cls in all_split_strategies -} +split_strategies_by_class_name = {cls.__name__: cls for cls in all_split_strategies} # -- parallelism widgets @@ -404,9 +390,7 @@ def __init__(self, split_strategy): def _split_job_init(split_strategy_class_name, split_strategy_kwargs): global _split_job_context split_strategy = split_strategies_by_class_name[split_strategy_class_name] - _split_job_context = SplitJobContext( - split_strategy(**split_strategy_kwargs) - ) + _split_job_context = SplitJobContext(split_strategy(**split_strategy_kwargs)) def _split_job(in_unit): diff --git a/src/dartsort/config.py b/src/dartsort/config.py index 7afb974c..2c74bbb3 100644 --- a/src/dartsort/config.py +++ b/src/dartsort/config.py @@ -76,6 +76,7 @@ class FeaturizationConfig: localization_radius: float = 100.0 # these are saved always if do_localization save_amplitude_vectors: bool = True + localization_model = "dipole" # -- further info about denoising # in the future we may add multi-channel or other nns @@ -96,7 +97,7 @@ class FeaturizationConfig: class SubtractionConfig: trough_offset_samples: int = 42 spike_length_samples: int = 121 - detection_thresholds: List[int] = (12, 10, 8, 6, 5, 4) + detection_thresholds: List[int] = (12, 9, 6) chunk_length_samples: int = 30_000 peak_sign: str = "neg" spatial_dedup_radius: float = 150.0 @@ -153,7 +154,7 @@ class MatchingConfig: fit_subsampling_random_state: int = 0 # template matching parameters - threshold: float = 30.0 + threshold: float = 50.0 template_svd_compression_rank: int = 10 template_temporal_upsampling_factor: int = 8 template_min_channel_amplitude: float = 1.0 @@ -161,6 +162,8 @@ class MatchingConfig: amplitude_scaling_variance: float = 0.0 amplitude_scaling_boundary: float = 0.5 max_iter: int = 1000 + conv_ignore_threshold: float = 5.0 + coarse_approx_error_threshold: float = 5.0 @dataclass(frozen=True) class ClusteringConfig: @@ -174,4 +177,4 @@ class ClusteringConfig: # -- ensembling ensemble_strategy: Optional[str] = "forward_backward" chunk_size_s: int = 300 - # forward-backward \ No newline at end of file + # forward-backward diff --git a/src/dartsort/detect/detect.py b/src/dartsort/detect/detect.py index 10e46074..7700e2f0 100644 --- a/src/dartsort/detect/detect.py +++ b/src/dartsort/detect/detect.py @@ -41,7 +41,8 @@ def detect_and_deduplicate( with corresponding channels """ nsamples, nchans = traces.shape - if dedup_channel_index is not None: + all_dedup = isinstance(dedup_channel_index, str) and dedup_channel_index == "all" + if not all_dedup and dedup_channel_index is not None: assert dedup_channel_index.shape[0] == nchans # -- handle peak sign. we use max pool below, so make peaks positive @@ -79,14 +80,17 @@ def detect_and_deduplicate( F.threshold_(energies, threshold, 0.0) # -- temporal deduplication - max_energies = energies if dedup_temporal_radius > 0: - max_energies, indices = F.max_pool2d_with_indices( + max_energies = F.max_pool2d( energies, kernel_size=[2 * dedup_temporal_radius + 1, 1], stride=1, padding=[dedup_temporal_radius, 0], ) + elif dedup_channel_index is not None: + max_energies = energies.clone() + else: + max_energies = energies # back to TC energies = energies[0, 0] max_energies = max_energies[0, 0] @@ -94,7 +98,9 @@ def detect_and_deduplicate( # -- spatial deduplication # we would like to max pool again on the other axis, # but that doesn't support any old radial neighborhood - if dedup_channel_index is not None: + if all_dedup: + max_energies[:] = max_energies.max(dim=1, keepdim=True).values + elif dedup_channel_index is not None: # pad channel axis with extra chan of 0s max_energies = F.pad(max_energies, (0, 1)) for batch_start in range(0, nsamples, spatial_dedup_batch_size): diff --git a/src/dartsort/localize/localize_torch.py b/src/dartsort/localize/localize_torch.py index d7e4d20b..42741034 100644 --- a/src/dartsort/localize/localize_torch.py +++ b/src/dartsort/localize/localize_torch.py @@ -19,6 +19,7 @@ def localize_amplitude_vectors( dtype=torch.double, y0=1.0, levenberg_marquardt_kwargs=None, + th_dipole_proj_dist=250.0, ): """Localize a bunch of amplitude vectors with torch @@ -59,7 +60,7 @@ def localize_amplitude_vectors( # maybe this will become a wrapper function if we want more models. # and, this is why we return a dict, different models will have different # parameters - assert model in ("com", "pointsource") + assert model in ("com", "pointsource", "dipole") n_spikes, c = amplitude_vectors.shape n_channels_tot = len(geom) if channel_index is None: @@ -67,6 +68,8 @@ def localize_amplitude_vectors( channel_index = full_channel_index(n_channels_tot) assert channel_index.shape == (n_channels_tot, c) assert main_channels.shape == (n_spikes,) + # we'll return numpy if user sent numpy + is_numpy = not torch.is_tensor(amplitude_vectors) # handle channel subsetting if radius is not None or n_channels_subset is not None: @@ -100,6 +103,8 @@ def localize_amplitude_vectors( geom_pad = F.pad(geom, (0, 0, 0, 1)) local_geoms = geom_pad[channel_index[main_channels]] local_geoms[:, :, 1] -= geom[main_channels, 1][:, None] + print(f"{amplitude_vectors.shape=}") + print(f"{local_geoms.shape=}") # center of mass initialization com = torch.divide( @@ -110,10 +115,7 @@ def localize_amplitude_vectors( if model == "com": z_abs_com = zcom + geom[main_channels, 1] - nancom = torch.full_like(xcom, torch.nan) - return dict( - x=xcom, y=nancom, z_rel=zcom, z_abs=z_abs_com, alpha=nancom - ) + return dict(x=xcom, z_rel=zcom, z_abs=z_abs_com) # normalized PTP vectors # this helps to keep the objective in a similar range, so we can use @@ -122,32 +124,65 @@ def localize_amplitude_vectors( normalized_amp_vecs = amplitude_vectors / max_amplitudes[:, None] # -- torch optimize - # initialize with center of mass - locs = torch.column_stack((xcom, torch.full_like(xcom, y0), zcom)) if levenberg_marquardt_kwargs is None: levenberg_marquardt_kwargs = {} - locs, i = batched_levenberg_marquardt( - locs, - vmap_point_source_grad_and_mse, - vmap_point_source_hessian, - extra_args=(normalized_amp_vecs, in_probe_mask, local_geoms), - **levenberg_marquardt_kwargs, - ) - # finish: get alpha closed form - x, y0, z_rel = locs.T - y = F.softplus(y0) - alpha = vmap_point_source_find_alpha( - amplitude_vectors, in_probe_mask, x, y, z_rel, local_geoms - ) - z_abs = z_rel + geom[main_channels, 1] + # initialize with center of mass + locs = torch.column_stack((xcom, torch.full_like(xcom, y0), zcom)) + + if model == "pointsource": + locs, i = batched_levenberg_marquardt( + locs, + vmap_point_source_grad_and_mse, + vmap_point_source_hessian, + extra_args=(normalized_amp_vecs, in_probe_mask, local_geoms), + **levenberg_marquardt_kwargs, + ) + + # finish: get alpha closed form + x, y0, z_rel = locs.T + y = F.softplus(y0) + alpha = vmap_point_source_find_alpha( + amplitude_vectors, in_probe_mask, x, y, z_rel, local_geoms + ) + z_abs = z_rel + geom[main_channels, 1] + + results = dict(x=x, y=y, z_rel=z_rel, z_abs=z_abs, alpha=alpha) + if is_numpy: + results = {k: v.numpy(force=True) for k, v in results.items()} + return results + + elif model == "dipole": + locs, i = batched_levenberg_marquardt( + locs, + vmap_dipole_grad_and_mse, + vmap_dipole_hessian, + extra_args=(normalized_amp_vecs, local_geoms), + **levenberg_marquardt_kwargs, + ) - return dict(x=x, y=y, z_rel=z_rel, z_abs=z_abs, alpha=alpha) + x, y0, z_rel = locs.T + y = F.softplus(y0) + projected_dist = vmap_dipole_find_projection_distance( + normalized_amp_vecs, x, y, z_rel, local_geoms + ) + z_abs = z_rel + geom[main_channels, 1] -# -- point source model library functions + if is_numpy: + x = x.numpy(force=True) + y = y.numpy(force=True) + z_rel = z_rel.numpy(force=True) + z_abs = z_abs.numpy(force=True) + alpha = alpha.numpy(force=True) + return dict(x=x, y=y, z_rel=z_rel, z_abs=z_abs, alpha=projected_dist) + else: + assert False + + +# -- point source / dipole model library functions def point_source_amplitude_at(x, y, z, alpha, local_geom): """Point source model predicted amplitude at local_geom given location""" dxs = torch.square(x - local_geom[:, 0]) @@ -165,10 +200,23 @@ def point_source_find_alpha(amp_vec, channel_mask, x, y, z, local_geoms): ) return alpha +def dipole_find_projection_distance(normalized_amp_vec, x, y, z, local_geom): + """COmpute a value dist/dipole in x,z that tells us if dipole or monopole is better""" + + dxs = x - local_geom[:, 0] + dzs = z - local_geom[:, 1] + dys = y.expand(dzs.size()) + duv = torch.stack([dxs, dys, dzs], dim=1) + X = duv / torch.pow(torch.sum(torch.square(duv), dim=1), 3/2)[:, None] + beta = torch.matmul(torch.linalg.pinv(torch.matmul(X.T, X)), torch.matmul(X.T, normalized_amp_vec)) + beta /= torch.sqrt(torch.square(beta).sum()) + dipole_planar_direction = torch.sqrt(torch.square(beta[[0, 2]]).sum()) + closest_chan = torch.argmin(torch.sum(torch.square(duv), dim=1)) + min_duv = duv[closest_chan[None]][0] #workaround around vmap doesn't work for one dim tensor .item() + val_th = torch.sqrt(torch.square(min_duv).sum())/dipole_planar_direction + return val_th -def point_source_mse( - loc, amplitude_vector, channel_mask, local_geom, logbarrier=True -): +def point_source_mse(loc, amplitude_vector, channel_mask, local_geom, logbarrier=True): """Objective in point source model Arguments @@ -191,12 +239,9 @@ def point_source_mse( x, y0, z = loc y = F.softplus(y0) - alpha = point_source_find_alpha( - amplitude_vector, channel_mask, x, y, z, local_geom - ) + alpha = point_source_find_alpha(amplitude_vector, channel_mask, x, y, z, local_geom) obj = torch.square( - amplitude_vector - - point_source_amplitude_at(x, y, z, alpha, local_geom) + amplitude_vector - point_source_amplitude_at(x, y, z, alpha, local_geom) ).mean() if logbarrier: obj -= torch.log(10.0 * y) / 10000.0 @@ -205,7 +250,35 @@ def point_source_mse( return obj +def dipole_mse(loc, amplitude_vector, local_geom, logbarrier=True): + """Dipole model predicted amplitude at local_geom given location""" + + x, y0, z = loc + y = F.softplus(y0) + + dxs = x - local_geom[:, 0] + dzs = z - local_geom[:, 1] + dys = y.expand(dzs.size()) + + duv = torch.stack([dxs, dys, dzs], dim=1) + X = duv / torch.pow(torch.sum(torch.square(duv), dim=1), 3/2)[:, None] + # beta = torch.linalg.lstsq(X, amplitude_vector[:, None])[0] + # beta = torch.linalg.solve(torch.matmul(X.T, X), torch.matmul(X.T, amplitude_vector)) + beta = torch.matmul(torch.linalg.pinv(torch.matmul(X.T, X)), torch.matmul(X.T, amplitude_vector)) + qtq = torch.matmul(X, beta) + + obj = torch.square(amplitude_vector - qtq).mean() + if logbarrier: + obj -= torch.log(10.0 * y) / 10000.0 + + return obj + + # vmapped functions for use in the optimizer, and might be handy for users too vmap_point_source_grad_and_mse = vmap(grad_and_value(point_source_mse)) vmap_point_source_hessian = vmap(hessian(point_source_mse)) vmap_point_source_find_alpha = vmap(point_source_find_alpha) + +vmap_dipole_grad_and_mse = vmap(grad_and_value(dipole_mse)) +vmap_dipole_hessian = vmap(hessian(dipole_mse)) +vmap_dipole_find_projection_distance = vmap(dipole_find_projection_distance) diff --git a/src/dartsort/localize/localize_util.py b/src/dartsort/localize/localize_util.py index b0cc1b59..f85e0ffb 100644 --- a/src/dartsort/localize/localize_util.py +++ b/src/dartsort/localize/localize_util.py @@ -43,6 +43,7 @@ def localize_hdf5( spikes_per_batch=100_000, show_progress=True, device=None, + localization_model="pointsource", ): """Run localization on a HDF5 file with stored amplitude vectors @@ -100,6 +101,7 @@ def localize_hdf5( channel_index=channel_index, radius=radius, n_channels_subset=n_channels_subset, + model=localization_model, ) xyza_batch = np.c_[ locs["x"].cpu().numpy(), diff --git a/src/dartsort/main.py b/src/dartsort/main.py index ac122b69..806d1f44 100644 --- a/src/dartsort/main.py +++ b/src/dartsort/main.py @@ -3,9 +3,9 @@ from dartsort.config import (FeaturizationConfig, ClusteringConfig, MatchingConfig, SubtractionConfig, TemplateConfig) from dartsort.localize.localize_util import localize_hdf5 -# from dartsort.peel import (ResidualUpdateTemplateMatchingPeeler, -# SubtractionPeeler) from dartsort.cluster.initial import ensemble_chunks +from dartsort.peel import (ObjectiveUpdateTemplateMatchingPeeler, + SubtractionPeeler) from dartsort.templates import TemplateData from dartsort.util.data_util import DARTsortSorting, check_recording @@ -102,7 +102,10 @@ def match( device=None, hdf5_filename="matching0.h5", model_subdir="matching0_models", + template_npz_filename="template_data.npz", ): + model_dir = Path(output_directory) / model_subdir + # compute templates template_data = TemplateData.from_config( recording, @@ -110,11 +113,14 @@ def match( template_config, motion_est=motion_est, n_jobs=n_jobs_templates, - save_folder=output_directory, + save_folder=model_dir, overwrite=overwrite, + device=device, + save_npz_name=template_npz_filename, ) + # instantiate peeler - matching_peeler = ResidualUpdateTemplateMatchingPeeler.from_config( + matching_peeler = ObjectiveUpdateTemplateMatchingPeeler.from_config( recording, matching_config, featurization_config, @@ -137,7 +143,7 @@ def match( return sorting, output_hdf5_filename -# -- helper function +# -- helper function for subtract, match def _run_peeler( @@ -157,6 +163,8 @@ def _run_peeler( output_directory.mkdir(exist_ok=True) model_dir = output_directory / model_subdir output_hdf5_filename = output_directory / hdf5_filename + if residual_filename is not None: + residual_filename = output_directory / residual_filename # fit models if needed peeler.load_or_fit_and_save_models( @@ -171,10 +179,11 @@ def _run_peeler( overwrite=overwrite, residual_filename=residual_filename, show_progress=show_progress, + device=device, ) # do localization - if featurization_config.do_localization: + if not featurization_config.denoise_only and featurization_config.do_localization: wf_name = featurization_config.output_waveforms_name localize_hdf5( output_hdf5_filename, @@ -182,6 +191,7 @@ def _run_peeler( amplitude_vectors_dataset_name=f"{wf_name}_amplitude_vectors", show_progress=show_progress, device=device, + localization_model=featurization_config.localization_model ) return ( diff --git a/src/dartsort/peel/__init__.py b/src/dartsort/peel/__init__.py index aa03dd24..811f3f96 100644 --- a/src/dartsort/peel/__init__.py +++ b/src/dartsort/peel/__init__.py @@ -1,2 +1,2 @@ -from .matching import ResidualUpdateTemplateMatchingPeeler +from .matching import ObjectiveUpdateTemplateMatchingPeeler from .subtract import SubtractionPeeler, subtract_chunk diff --git a/src/dartsort/peel/matching.py b/src/dartsort/peel/matching.py index b5cfb32f..bf17020b 100644 --- a/src/dartsort/peel/matching.py +++ b/src/dartsort/peel/matching.py @@ -12,17 +12,19 @@ import numpy as np import torch import torch.nn.functional as F -from dartsort.detect import detect_and_deduplicate from dartsort.templates import template_util +from dartsort.templates.pairwise import CompressedPairwiseConv from dartsort.transform import WaveformPipeline -from dartsort.util import spiketorch +from dartsort.util import drift_util, spiketorch from dartsort.util.data_util import SpikeDataset from dartsort.util.waveform_util import make_channel_index +from scipy.spatial import KDTree +from scipy.spatial.distance import pdist from .peel_base import BasePeeler -class ResidualUpdateTemplateMatchingPeeler(BasePeeler): +class ObjectiveUpdateTemplateMatchingPeeler(BasePeeler): peel_kind = "TemplateMatching" def __init__( @@ -33,56 +35,53 @@ def __init__( featurization_pipeline, motion_est=None, svd_compression_rank=10, + coarse_objective=True, temporal_upsampling_factor=8, upsampling_peak_window_radius=8, min_channel_amplitude=1.0, refractory_radius_frames=10, amplitude_scaling_variance=0.0, amplitude_scaling_boundary=0.5, + conv_ignore_threshold=5.0, + coarse_approx_error_threshold=5.0, trough_offset_samples=42, - threshold=30.0, + threshold=50.0, chunk_length_samples=30_000, n_chunks_fit=40, fit_subsampling_random_state=0, max_iter=1000, ): - n_templates, spike_length_samples = template_data.templates.shape[:2] super().__init__( recording=recording, channel_index=channel_index, featurization_pipeline=featurization_pipeline, chunk_length_samples=chunk_length_samples, - chunk_margin_samples=2 * spike_length_samples, + chunk_margin_samples=2 * template_data.templates.shape[1], n_chunks_fit=n_chunks_fit, fit_subsampling_random_state=fit_subsampling_random_state, ) - # process templates - ( - temporal_components, - singular_values, - spatial_components, - ) = template_util.svd_compress_templates( - template_data.templates, - min_channel_amplitude=min_channel_amplitude, - rank=svd_compression_rank, - ) - self.handle_upsampling( - temporal_components, - temporal_upsampling_factor=temporal_upsampling_factor, - upsampling_peak_window_radius=upsampling_peak_window_radius, - ) - # main properties + self.template_data = template_data + self.coarse_objective = coarse_objective + self.temporal_upsampling_factor = temporal_upsampling_factor + self.upsampling_peak_window_radius = upsampling_peak_window_radius + self.svd_compression_rank = svd_compression_rank + self.min_channel_amplitude = min_channel_amplitude self.threshold = threshold + self.conv_ignore_threshold = conv_ignore_threshold + self.coarse_approx_error_threshold = coarse_approx_error_threshold self.refractory_radius_frames = refractory_radius_frames self.max_iter = max_iter - self.n_templates = n_templates - self.spike_length_samples = spike_length_samples + self.n_templates, self.spike_length_samples = template_data.templates.shape[:2] + self.trough_offset_samples = trough_offset_samples self.geom = recording.get_channel_locations() - self.svd_compression_rank = svd_compression_rank self.n_channels = len(self.geom) - self.obj_pad_len = max(refractory_radius_frames, upsampling_peak_window_radius) + self.obj_pad_len = max( + refractory_radius_frames, + upsampling_peak_window_radius, + self.spike_length_samples - 1, + ) self.n_registered_channels = ( len(template_data.registered_geom) if template_data.registered_geom is not None @@ -93,15 +92,6 @@ def __init__( self.channel_index = channel_index self.registered_template_ampvecs = template_data.templates.ptp(1) - # torch buffers - self.register_buffer("temporal_components", torch.tensor(temporal_components)) - self.register_buffer("singular_values", torch.tensor(singular_values)) - self.register_buffer("spatial_components", torch.tensor(spatial_components)) - self.register_buffer( - "_refrac_ix", - torch.arange(-refractory_radius_frames, refractory_radius_frames + 1), - ) - # amplitude scaling properties self.is_scaling = bool(amplitude_scaling_variance) self.amplitude_scaling_variance = amplitude_scaling_variance @@ -113,23 +103,42 @@ def __init__( self.motion_est = motion_est self.registered_geom = template_data.registered_geom self.registered_template_depths_um = template_data.registered_template_depths_um - - self.handle_template_groups(template_data.unit_ids) - self.check_shapes() - - self.fixed_output_data += [ - ("temporal_components", temporal_components), - ("singular_values", singular_values), - ("spatial_components", spatial_components), - ( - "upsampled_temporal_components", - self.upsampled_temporal_components.numpy(force=True).copy(), - ), - ] if self.is_drifting: self.fixed_output_data.append( ("registered_geom", template_data.registered_geom) ) + self.registered_geom_kdtree = KDTree(self.registered_geom) + self.geom_kdtree = KDTree(self.geom) + self.match_distance = pdist(self.geom).min() / 2.0 + + # some parts of this constructor are deferred to precompute_peeling_data + self._needs_precompute = True + + def peeling_needs_fit(self): + return self._needs_precompute + + def precompute_peeling_data(self, save_folder, n_jobs=0, device=None): + self.build_template_data( + save_folder, + self.template_data, + temporal_upsampling_factor=self.temporal_upsampling_factor, + upsampling_peak_window_radius=self.upsampling_peak_window_radius, + svd_compression_rank=self.svd_compression_rank, + min_channel_amplitude=self.min_channel_amplitude, + dtype=self.recording.dtype, + n_jobs=n_jobs, + device=device, + ) + # couple more torch buffers + self.register_buffer( + "_refrac_ix", + torch.arange( + -self.refractory_radius_frames, self.refractory_radius_frames + 1 + ), + ) + self.register_buffer("_rank_ix", torch.arange(self.svd_compression_rank)) + self.check_shapes() + self._needs_precompute = False def out_datasets(self): datasets = super().out_datasets() @@ -157,27 +166,39 @@ def check_shapes(self): self.svd_compression_rank, self.n_registered_channels, ) - assert self.upsampled_temporal_components.shape == ( - self.n_templates, - self.spike_length_samples, - self.temporal_upsampling_factor, - self.svd_compression_rank, - ) assert self.unit_ids.shape == (self.n_templates,) - def handle_template_groups(self, unit_ids): - self.unit_ids = unit_ids + def handle_template_groups(self, obj_unit_ids, unit_ids): + """Grouped templates in objective + + If not coarse_objective, then several rows of the objective may + belong to the same unit. They must be handled together when imposing + refractory conditions. + """ + self.register_buffer("unit_ids", torch.from_numpy(unit_ids)) + self.register_buffer("obj_unit_ids", torch.from_numpy(obj_unit_ids)) + units, fine_to_coarse, counts = np.unique( + unit_ids, return_counts=True, return_inverse=True + ) + self.register_buffer("fine_to_coarse", torch.from_numpy(fine_to_coarse)) self.grouped_temps = True unique_units = np.unique(unit_ids) if unique_units.size == unit_ids.size: self.grouped_temps = False if not self.grouped_temps: + self.register_buffer("superres_index", torch.arange(len(unit_ids))[:, None]) return - assert unit_ids.shape == (self.n_templates,) - group_index = [np.flatnonzero(unit_ids == u) for u in unit_ids] - max_group_size = max(map(len, group_index)) + + superres_index = np.full((len(obj_unit_ids), counts.max()), self.n_templates) + for j, u in enumerate(obj_unit_ids): + my_sup = np.flatnonzero(unit_ids == u) + superres_index[j, : len(my_sup)] = my_sup + self.register_buffer("superres_index", torch.from_numpy(superres_index)) + + if self.coarse_objective: + return # like a channel index, sort of # this is a n_templates x group_size array that maps each @@ -185,53 +206,151 @@ def handle_template_groups(self, unit_ids): # are part of its group. so that the array is not ragged, # we pad rows with -1s when their group is smaller than the # largest group. - group_index = np.full((self.n_templates, max_group_size), -1) - for j, row in enumerate(group_index): + group_index = np.full((self.n_templates, counts.max()), -1) + for j, u in enumerate(unit_ids): + row = np.flatnonzero(unit_ids == u) group_index[j, : len(row)] = row - self.group_index = torch.tensor(group_index) + self.register_buffer("group_index", torch.from_numpy(group_index)) - def handle_upsampling( + def build_template_data( self, - temporal_components, + save_folder, + template_data, temporal_upsampling_factor=8, upsampling_peak_window_radius=8, + svd_compression_rank=10, + min_channel_amplitude=1.0, + dtype=np.float32, + n_jobs=0, + device=None, ): - self.temporal_upsampling_factor = temporal_upsampling_factor - upsampled_temporal_components = temporal_components - if temporal_upsampling_factor > 1: - upsampled_temporal_components = template_util.temporally_upsample_templates( - temporal_components, - temporal_upsampling_factor=temporal_upsampling_factor, + low_rank_templates = template_util.svd_compress_templates( + template_data.templates, + min_channel_amplitude=min_channel_amplitude, + rank=svd_compression_rank, + ) + temporal_components = low_rank_templates.temporal_components.astype(dtype) + singular_values = low_rank_templates.singular_values.astype(dtype) + spatial_components = low_rank_templates.spatial_components.astype(dtype) + self.register_buffer("temporal_components", torch.tensor(temporal_components)) + self.register_buffer("singular_values", torch.tensor(singular_values)) + self.register_buffer("spatial_components", torch.tensor(spatial_components)) + compressed_upsampled_temporal = self.handle_upsampling( + temporal_components, + ptps=template_data.templates.ptp(1).max(1), + temporal_upsampling_factor=temporal_upsampling_factor, + upsampling_peak_window_radius=upsampling_peak_window_radius, + ) + + # handle the case where objective is not superres + if self.coarse_objective: + coarse_template_data = template_data.coarsen() + coarse_low_rank_templates = template_util.svd_compress_templates( + coarse_template_data.templates, + min_channel_amplitude=min_channel_amplitude, + rank=svd_compression_rank, + ) + temporal_components = coarse_low_rank_templates.temporal_components.astype( + dtype + ) + singular_values = coarse_low_rank_templates.singular_values.astype(dtype) + spatial_components = coarse_low_rank_templates.spatial_components.astype( + dtype + ) + self.objective_template_depths_um = ( + coarse_template_data.registered_template_depths_um ) self.register_buffer( - "upsampled_temporal_components", - torch.tensor(upsampled_temporal_components), + "objective_temporal_components", torch.tensor(temporal_components) ) self.register_buffer( - "upsampling_window", - torch.arange( - -upsampling_peak_window_radius, upsampling_peak_window_radius + 1 - ), + "objective_singular_values", torch.tensor(singular_values) ) - self.upsampling_window_len = 2 * upsampling_peak_window_radius - center = upsampling_peak_window_radius * temporal_upsampling_factor - radius = temporal_upsampling_factor // 2 + temporal_upsampling_factor % 2 self.register_buffer( - "upsampled_peak_search_window", - torch.arange(center - radius, center + radius + 1), + "objective_spatial_components", torch.tensor(spatial_components) ) + self.obj_n_templates = spatial_components.shape[0] + else: + coarse_template_data = template_data + coarse_low_rank_templates = low_rank_templates + self.objective_template_depths_um = self.registered_template_depths_um self.register_buffer( - "peak_to_upsampling_index", - torch.concatenate( - [ - torch.arange(radius, -1, -1), - (temporal_upsampling_factor - 1) - torch.arange(radius), - ] - ), + "objective_temporal_components", self.temporal_components ) + self.register_buffer("objective_singular_values", self.singular_values) self.register_buffer( - "peak_to_time_shift", torch.tensor([0] * (radius + 1) + [1] * radius) + "objective_spatial_components", self.spatial_components ) + self.obj_n_templates = self.n_templates + self.handle_template_groups( + coarse_template_data.unit_ids, self.template_data.unit_ids + ) + convlen = self.chunk_length_samples + self.chunk_margin_samples + block_size, *_ = spiketorch._calc_oa_lens(convlen, self.spike_length_samples) + self.register_buffer("objective_temporalf", torch.fft.rfft(self.objective_temporal_components, dim=1, n=block_size)) + + half_chunk = self.chunk_length_samples // 2 + chunk_starts = np.arange( + 0, self.recording.get_num_samples(), self.chunk_length_samples + ) + chunk_ends = np.minimum(chunk_starts + self.chunk_length_samples, self.recording.get_num_samples()) + chunk_centers_samples = (chunk_starts + chunk_ends) / 2 + chunk_centers_s = self.recording._recording_segments[0].sample_index_to_time( + chunk_centers_samples + ) + self.pairwise_conv_db = CompressedPairwiseConv.from_template_data( + save_folder / "pconv.h5", + template_data=coarse_template_data, + low_rank_templates=coarse_low_rank_templates, + template_data_b=template_data, + low_rank_templates_b=low_rank_templates, + compressed_upsampled_temporal=compressed_upsampled_temporal, + chunk_time_centers_s=chunk_centers_s, + motion_est=self.motion_est, + geom=self.geom, + conv_ignore_threshold=self.conv_ignore_threshold, + coarse_approx_error_threshold=self.coarse_approx_error_threshold, + device=device, + n_jobs=n_jobs, + ) + + self.fixed_output_data += [ + ("temporal_components", temporal_components), + ("singular_values", singular_values), + ("spatial_components", spatial_components), + ] + + def handle_upsampling( + self, + temporal_components, + ptps, + temporal_upsampling_factor=8, + upsampling_peak_window_radius=8, + ): + compressed_upsampled_temporal = template_util.compressed_upsampled_templates( + temporal_components, + ptps=ptps, + max_upsample=temporal_upsampling_factor, + ) + self.register_buffer( + "compressed_upsampling_map", + torch.tensor(compressed_upsampled_temporal.compressed_upsampling_map), + ) + self.register_buffer( + "compressed_upsampling_index", + torch.tensor(compressed_upsampled_temporal.compressed_upsampling_index), + ) + self.register_buffer( + "compressed_index_to_upsampling_index", + torch.tensor( + compressed_upsampled_temporal.compressed_index_to_upsampling_index + ), + ) + self.register_buffer( + "compressed_upsampled_temporal", + torch.tensor(compressed_upsampled_temporal.compressed_upsampled_templates), + ) + return compressed_upsampled_temporal @classmethod def from_config( @@ -261,6 +380,8 @@ def from_config( refractory_radius_frames=matching_config.refractory_radius_frames, amplitude_scaling_variance=matching_config.amplitude_scaling_variance, amplitude_scaling_boundary=matching_config.amplitude_scaling_boundary, + conv_ignore_threshold=matching_config.conv_ignore_threshold, + coarse_approx_error_threshold=matching_config.coarse_approx_error_threshold, trough_offset_samples=matching_config.trough_offset_samples, threshold=matching_config.threshold, chunk_length_samples=matching_config.chunk_length_samples, @@ -276,10 +397,11 @@ def peel_chunk( left_margin=0, right_margin=0, return_residual=False, + return_conv=False, ): # get current template set chunk_center_samples = chunk_start_samples + self.chunk_length_samples // 2 - + segment = self.recording._recording_segments[0] chunk_center_seconds = segment.sample_index_to_time(chunk_center_samples) compressed_template_data = self.templates_at_time(chunk_center_seconds) @@ -288,10 +410,12 @@ def peel_chunk( match_results = self.match_chunk( traces, compressed_template_data, - trough_offset_samples=42, - left_margin=0, - right_margin=0, - threshold=30, + trough_offset_samples=self.trough_offset_samples, + left_margin=left_margin, + right_margin=right_margin, + threshold=self.threshold, + return_residual=return_residual, + return_conv=return_conv, ) # process spike times and create return result @@ -300,35 +424,85 @@ def peel_chunk( return match_results def templates_at_time(self, t_s): - """Extract the right spatial components for each unit.""" + """Handle drift -- grab the right spatial neighborhoods.""" + pconvdb = self.pairwise_conv_db + pitch_shifts_a = pitch_shifts_b = None + if self.objective_spatial_components.device.type == "cuda" and not pconvdb.device.type == "cuda": + pconvdb.to(self.objective_spatial_components.device) if self.is_drifting: - cur_spatial = template_util.templates_at_time( + pitch_shifts_b, cur_spatial = template_util.templates_at_time( t_s, self.spatial_components, self.geom, registered_template_depths_um=self.registered_template_depths_um, registered_geom=self.registered_geom, motion_est=self.motion_est, + return_pitch_shifts=True, + geom_kdtree=self.geom_kdtree, + match_distance=self.match_distance, ) - cur_ampvecs = template_util.templates_at_time( - t_s, + if self.coarse_objective: + pitch_shifts_a, cur_obj_spatial = template_util.templates_at_time( + t_s, + self.objective_spatial_components, + self.geom, + registered_template_depths_um=self.objective_template_depths_um, + registered_geom=self.registered_geom, + motion_est=self.motion_est, + return_pitch_shifts=True, + geom_kdtree=self.geom_kdtree, + match_distance=self.match_distance, + ) + else: + cur_obj_spatial = cur_spatial + pitch_shifts_a = pitch_shifts_b + cur_ampvecs = drift_util.get_waveforms_on_static_channels( self.registered_template_ampvecs[:, None, :], - self.geom, - registered_template_depths_um=self.registered_template_depths_um, - registered_geom=self.registered_geom, - motion_est=self.motion_est, + self.registered_geom, + n_pitches_shift=pitch_shifts_b, + registered_geom=self.geom, + target_kdtree=self.geom_kdtree, + match_distance=self.match_distance, + fill_value=0.0, ) max_channels = cur_ampvecs[:, 0, :].argmax(1) + # pitch_shifts_a = torch.as_tensor(pitch_shifts_a) + # pitch_shifts_b = torch.as_tensor(pitch_shifts_b) + pitch_shifts_a = torch.as_tensor(pitch_shifts_a, device=cur_obj_spatial.device) + pitch_shifts_b = torch.as_tensor(pitch_shifts_b, device=cur_obj_spatial.device) + pconvdb = pconvdb.at_shifts(pitch_shifts_a, pitch_shifts_b) + # pitch_shifts_a = torch.as_tensor(pitch_shifts_a, device=cur_obj_spatial.device) + # pitch_shifts_b = torch.as_tensor(pitch_shifts_b, device=cur_obj_spatial.device) else: cur_spatial = self.spatial_components + cur_obj_spatial = self.objective_spatial_components max_channels = self.registered_template_ampvecs.argmax(1) - return CompressedTemplateData( - cur_spatial, - self.singular_values, - self.temporal_components, - self.upsampled_temporal_components, - max_channels, + # if not pconvdb._is_torch: + # pconvdb.to("cpu") + # if cur_obj_spatial.device.type == "cuda" and not pconvdb.device.type == "cuda": + # pconvdb.to(cur_obj_spatial.device, pin=True) + + return MatchingTemplateData( + objective_spatial_components=cur_obj_spatial, + objective_singular_values=self.objective_singular_values, + objective_temporal_components=self.objective_temporal_components, + objective_temporalf=self.objective_temporalf, + fine_to_coarse=self.fine_to_coarse, + coarse_objective=self.coarse_objective, + spatial_components=cur_spatial, + singular_values=self.singular_values, + temporal_components=self.temporal_components, + compressed_upsampling_map=self.compressed_upsampling_map, + compressed_upsampling_index=self.compressed_upsampling_index, + compressed_index_to_upsampling_index=self.compressed_index_to_upsampling_index, + compressed_upsampled_temporal=self.compressed_upsampled_temporal, + max_channels=torch.as_tensor(max_channels, device=cur_obj_spatial.device), + pairwise_conv_db=pconvdb, + shifts_a=None, + shifts_b=None, + # shifts_a=pitch_shifts_a, + # shifts_b=pitch_shifts_b, ) def match_chunk( @@ -339,6 +513,8 @@ def match_chunk( left_margin=0, right_margin=0, threshold=30, + return_residual=False, + return_conv=False, ): """Core peeling routine for subtraction""" # initialize residual, it needs to be padded to support our channel @@ -348,57 +524,88 @@ def match_chunk( residual = residual_padded[:, :-1] # name objective variables so that we can update them in-place later - conv = None conv_len = traces.shape[0] - self.spike_length_samples + 1 padded_obj_len = conv_len + 2 * self.obj_pad_len + padded_conv = torch.zeros( + self.obj_n_templates, + padded_obj_len, + dtype=traces.dtype, + device=traces.device, + ) padded_objective = torch.zeros( - self.n_templates + 1, + self.obj_n_templates + 1, padded_obj_len, dtype=traces.dtype, device=traces.device, ) + refrac_mask = torch.zeros_like(padded_objective) # padded objective has an extra unit (for group_index) and refractory # padding (for easier implementation of enforce_refractory) - objective = padded_objective[ - :-1, self.refractory_radius_frames : -self.refractory_radius_frames - ] - neg_temp_normsq = -compressed_template_data.template_norms_squared[:, None] # manages buffers for spike train data (peak times, labels, etc) - peaks = MatchingPeaks() + peaks = MatchingPeaks(device=traces.device) - # main loop - for _ in range(self.max_iter): - # update objective - conv = compressed_template_data.convolve(residual, out=conv) - # unscaled objective for coarse peaks, scaled when finding high res peak - torch.add(neg_temp_normsq, conv, alpha=2.0, out=objective) + # initialize convolution + compressed_template_data.convolve( + residual, padding=self.obj_pad_len, out=padded_conv + ) + # main loop + for it in range(self.max_iter): # find high-res peaks - new_peaks = self.find_peaks(conv, padded_objective, peaks, neg_temp_normsq) + new_peaks = self.find_peaks( + residual, padded_conv, padded_objective, refrac_mask, compressed_template_data + ) if new_peaks is None: break + # enforce refractoriness + self.enforce_refractory( + refrac_mask, + new_peaks.times + self.obj_pad_len, + new_peaks.objective_template_indices, + new_peaks.template_indices, + ) + # subtract them - # offset times: conv result peaks with valid padding are offset - # by spike len - 1 samples from the corresponding trace peaks - sample_times = new_peaks.times + self.spike_length_samples - 1 + # old_norm = torch.linalg.norm(residual) ** 2 compressed_template_data.subtract( - residual, - sample_times, + residual_padded, + new_peaks.times, new_peaks.template_indices, new_peaks.upsampling_indices, new_peaks.scalings, ) + compressed_template_data.subtract_conv( + padded_conv, + new_peaks.times, + new_peaks.template_indices, + new_peaks.upsampling_indices, + new_peaks.scalings, + conv_pad_len=self.obj_pad_len, + ) + + # new_norm = torch.linalg.norm(residual) ** 2 + # print(f"{it=} {new_norm=}") + # print(f"{(new_norm-old_norm)=}") + # print(f"{new_peaks.n_spikes=}") + # print(f"{new_peaks.scores.mean().numpy(force=True)=}") + # print("----------") # update spike train peaks.extend(new_peaks) peaks.sort() # extract collision-cleaned waveforms on small neighborhoods - channels, waveforms = self.get_collisioncleaned_waveforms() + channels, waveforms = compressed_template_data.get_collisioncleaned_waveforms( + residual_padded, + peaks, + self.channel_index, + spike_length_samples=self.spike_length_samples, + ) - return dict( + res = dict( + n_spikes=peaks.n_spikes, times_samples=peaks.times + self.trough_offset_samples, channels=channels, labels=self.unit_ids[peaks.template_indices], @@ -408,147 +615,121 @@ def match_chunk( scores=peaks.scores, collisioncleaned_waveforms=waveforms, ) + if return_residual: + res["residual"] = residual + if return_conv: + res["conv"] = padded_conv + return res - def find_peaks(self, conv, padded_objective, peaks, neg_temp_normsq): - # zeroth step: enforce refractoriness. - self.enforce_refractory( - padded_objective, - peaks.times + self.obj_pad_len, - peaks.template_indices, + def find_peaks( + self, + residual, + padded_conv, + padded_objective, + refrac_mask, + compressed_template_data, + ): + # update the coarse objective + torch.add( + compressed_template_data.objective_template_norms_squared.neg()[:, None], + padded_conv, + alpha=2.0, + out=padded_objective[:-1], ) # first step: coarse peaks. not temporally upsampled or amplitude-scaled. - objective = padded_objective[:-1, self.obj_pad_len : -self.obj_pad_len] - times, template_indices = detect_and_deduplicate( - objective.T, - self.threshold, - dedup_channel_index=None, - peak_sign="pos", - # add 1 here to account for possible time_shifts later - relative_peak_radius=self.spike_length_samples + 1, - dedup_temporal_radius=0, - # dedup_temporal_radius=self.spike_length_samples + 1, - ) + objective = (padded_objective + refrac_mask)[ + :-1, self.obj_pad_len : -self.obj_pad_len + ] + # formerly used detect_and_deduplicate, but that was slow. + objective_max, max_obj_template = objective.max(dim=0) + times = argrelmax(objective_max, self.spike_length_samples, self.threshold) + obj_template_indices = max_obj_template[times] + # remove peaks inside the padding if not times.numel(): return None + residual_snips = None + if self.coarse_objective or self.temporal_upsampling_factor > 1: + residual_snips = spiketorch.grab_spikes_full( + residual, + times - 1, + trough_offset=0, + spike_length_samples=self.spike_length_samples + 1, + ) + # second step: high-res peaks (upsampled and/or amp-scaled) - time_shifts, upsampling_indices, scalings, scores = self.find_fancy_peaks( - conv, objective, times, template_indices, neg_temp_normsq + ( + time_shifts, + upsampling_indices, + scalings, + template_indices, + scores, + ) = compressed_template_data.fine_match( + padded_conv[obj_template_indices, times + self.obj_pad_len], + objective_max[times], + residual_snips, + obj_template_indices, + amp_scale_variance=self.amplitude_scaling_variance, + amp_scale_min=self.amp_scale_min, + amp_scale_max=self.amp_scale_max, + superres_index=self.superres_index, ) if time_shifts is not None: times += time_shifts return MatchingPeaks( - n_spikes=times.size, + n_spikes=times.numel(), times=times, + objective_template_indices=obj_template_indices, template_indices=template_indices, upsampling_indices=upsampling_indices, scalings=scalings, scores=scores, ) - def enforce_refractory(self, objective, times, template_indices): + def enforce_refractory(self, objective, times, objective_template_indices, template_indices): + if not times.numel(): + return # overwrite objective with -inf to enforce refractoriness - time_ix = times[None, :] + self._refrac_ix[:, None] - if self.grouped_temps: - unit_ix = self.group_index[template_indices] - else: - unit_ix = template_indices[:, None] - objective[unit_ix, time_ix] = -torch.inf - - def find_fancy_peaks( - self, conv, objective, times, template_indices, neg_temp_normsq - ): - """Given coarse peaks, find temporally upsampled and scaled ones.""" - # tricky bit. we search for upsampled peaks to the left and right - # of the original peak. when the up-peak comes to the right, we - # use one of the upsampled templates, no problem. when the peak - # comes to the left, it's different: it came from one of the upsampled - # templates shifted one sample (spike time += 1). - if self.up_factor == 1 and not self.is_scaling: - return None, None, None, objective[template_indices, times] - - if self.is_scaling and self.up_factor == 1: - inv_lambda = 1 / self.amplitude_scaling_variance - b = conv[times, template_indices] + inv_lambda - a = neg_temp_normsq[template_indices] + inv_lambda - scalings = torch.clip(b / a, self.amp_scale_min, self.amp_scale_max) - scores = 2.0 * scalings * b - torch.square(scalings) * a - inv_lambda - return None, None, scalings, scores - - # below, we are upsampling. - # get clips of objective function around the peaks - # we'll use the scaled objective here. - time_ix = times[:, None] + self.upsampling_window[None, :] - clip_ix = (template_indices[:, None], time_ix) - upsampled_clip_len = ( - self.upsampling_window_len * self.temporal_upsampling_factor - ) - if self.is_scaling: - high_res_conv = spiketorch.real_resample( - conv[clip_ix], upsampled_clip_len, dim=1 - ) - inv_lambda = 1.0 / self.amplitude_scaling_variance - b = high_res_conv + inv_lambda - a = neg_temp_normsq[template_indices] + inv_lambda - scalings = torch.clip(b / a, self.amp_scale_min, self.amp_scale_max) - high_res_obj = ( - 2.0 * scalings * b - torch.square(scalings) * a[:, None] - inv_lambda - ) + time_ix = times[:, None] + self._refrac_ix[None, :] + if not self.grouped_temps: + row_ix = template_indices[:, None] + elif self.coarse_objective: + row_ix = objective_template_indices[:, None] + elif self.grouped_temps: + row_ix = self.group_index[template_indices] else: - scalings = None - obj_clips = objective[clip_ix] - high_res_obj = spiketorch.real_resample( - obj_clips, upsampled_clip_len, dim=1 - ) - - # zoom into a small upsampled area and determine the - # upsampled template and time shifts - scores, zoom_peak = torch.max( - high_res_obj[:, self.upsampled_peak_search_window], dim=1 - ) - upsampling_indices = self.peak_to_upsampling_index[zoom_peak] - time_shifts = self.peak_to_time_shifts[zoom_peak] - - return time_shifts, upsampling_indices, scalings, scores - - def get_collisioncleaned_waveforms( - self, residual_padded, peaks, compressed_template_data - ): - channels = compressed_template_data.max_channels[peaks.template_indices] - waveforms = spiketorch.grab_spikes( - residual_padded, - peaks.times, - channels, - self.channel_index, - trough_offset=0, - spike_length_samples=self.spike_length_samples, - buffer=0, - already_padded=True, - ) - spatial = compressed_template_data.spatial_singular[ - peaks.template_indices[:, None, None], - :, - self.channel_index[channels][:, None, :], - ] - temporal = compressed_template_data.upsampled_temporal_components[ - peaks.template_indices, - peaks.upsampling_indices, - ] - torch.baddbmm(waveforms, temporal, spatial, out=waveforms) - return channels, waveforms + assert False + objective[row_ix[:, :, None], time_ix[:, None, :]] = -torch.inf @dataclass -class CompressedTemplateData: - """Objects of this class are returned by ResidualUpdateTemplateMatchingPeeler.templates_at_time()""" - +class MatchingTemplateData: + """All the data and math needed for computing convs etc in a single static chunk of data + + This is the 'model' for template matching in a MVC analogy. The class above is the controller. + Objects of this class are returned by ObjectiveUpdateTemplateMatchingPeeler.templates_at_time(), + which handles the drift logic and lets this class be simple. + """ + + objective_spatial_components: torch.Tensor + objective_singular_values: torch.Tensor + objective_temporal_components: torch.Tensor + objective_temporalf: torch.Tensor + fine_to_coarse: torch.LongTensor + coarse_objective: bool spatial_components: torch.Tensor singular_values: torch.Tensor temporal_components: torch.Tensor - upsampled_temporal_components: torch.Tensor + compressed_upsampling_map: torch.LongTensor + compressed_upsampling_index: torch.LongTensor + compressed_index_to_upsampling_index: torch.LongTensor + compressed_upsampled_temporal: torch.Tensor max_channels: torch.LongTensor + pairwise_conv_db: CompressedPairwiseConv + shifts_a: Optional[torch.Tensor] + shifts_b: Optional[torch.Tensor] def __post_init__(self): ( @@ -556,47 +737,259 @@ def __post_init__(self): self.spike_length_samples, self.rank, ) = self.temporal_components.shape - # squared l2 norms are the sums of squared singular values - self.template_norms_squared = torch.square(self.singular_values).sum(1) + assert self.spatial_components.shape[:2] == (self.n_templates, self.rank) + assert self.compressed_upsampled_temporal.shape[1:] == ( + self.spike_length_samples, + self.rank, + ) + assert self.singular_values.shape == (self.n_templates, self.rank) + device = self.spatial_components.device + self.temporal_upsampling_factor = self.compressed_upsampling_index.shape[1] + self.n_compressed_upsampled_templates = self.compressed_upsampling_map.max() + 1 + + # squared l2 norms are usually the sums of squared singular values: + # self.template_norms_squared = torch.square(self.singular_values).sum(1) + # in this case, we have subset the spatial components, so use a diff formula + self.objective_n_templates = self.objective_spatial_components.shape[0] + self.objective_spatial_singular = ( + self.objective_spatial_components + * self.objective_singular_values[:, :, None] + ) self.spatial_singular = ( self.spatial_components * self.singular_values[:, :, None] ) - self.chan_ix = torch.arange( - self.spatial_components.shape[2], device=self.spatial_components.device - ) - self.time_ix = torch.arange( - self.spike_length_samples, device=self.spatial_components.device + self.objective_template_norms_squared = torch.square( + self.objective_spatial_singular + ).sum((1, 2)) + self.template_norms_squared = torch.square(self.spatial_singular).sum((1, 2)) + self.chan_ix = torch.arange(self.spatial_components.shape[2], device=device) + self.rank_ix = torch.arange(self.rank, device=device) + self.time_ix = torch.arange(self.spike_length_samples, device=device) + self.conv_lags = torch.arange( + -self.spike_length_samples + 1, self.spike_length_samples, device=device ) - def convolve(self, traces, out=None): - """This is not the fastest strategy on GPU, but it's low-memory and fast on CPU.""" + def convolve(self, traces, padding=0, out=None): + """Convolve the objective templates with traces.""" + out_len = traces.shape[0] + 2 * padding - self.spike_length_samples + 1 if out is None: - out = torch.zeros( - 1, - self.n_templates, - traces.shape[0] - self.spike_length_samples + 1, + out = torch.empty( + (self.objective_n_templates, out_len), dtype=traces.dtype, device=traces.device, ) else: - assert out.shape == ( - self.n_templates, - traces.shape[0] - self.spike_length_samples + 1, - ) - out = out[None] + assert out.shape == (self.objective_n_templates, out_len) for q in range(self.rank): # units x time - rec_spatial = self.spatial_singular[:, q, :] @ traces.T + rec_spatial = self.objective_spatial_singular[:, q, :] @ traces.T # convolve with temporal components -- units x time - temporal = self.temporal_components[:, :, q] + temporal = self.objective_temporal_components[:, :, q] + temporalf = self.objective_temporalf[:, :, q] # conv1d with groups! only convolve each unit with its own temporal filter. - out += F.conv1d( - rec_spatial[None], temporal[:, None, :], groups=self.n_templates - ) + conv = F.conv1d( + rec_spatial[None], + temporal[:, None, :], + groups=self.objective_n_templates, + padding=padding, + )[0] + # conv = spiketorch.depthwise_oaconv1d( + # rec_spatial, temporal, padding=padding, f2=temporalf + # ) + if q: + out += conv + else: + out.copy_(conv) # back to units x time (remove extra dim used for conv1d) - return out[0] + return out + + def subtract_conv( + self, + conv, + times, + template_indices, + upsampling_indices, + scalings, + conv_pad_len=0, + ): + template_indices_a, template_indices_b, times, pconvs = self.pairwise_conv_db.query( + template_indices_a=None, + template_indices_b=template_indices, + upsampling_indices_b=upsampling_indices, + scalings_b=scalings, + times_b=times, + grid=True, + device=conv.device, + shifts_a=self.shifts_a, + shifts_b=self.shifts_b[template_indices] if self.shifts_b is not None else None, + ) + ix_template = template_indices_a[:, None] + ix_time = times[:, None] + (conv_pad_len + self.conv_lags)[None, :] + spiketorch.add_at_( + conv, + (ix_template, ix_time), + pconvs, + sign=-1, + ) + + def fine_match( + self, + convs, + objs, + residual_snips, + objective_template_indices, + amp_scale_variance=0.0, + amp_scale_min=None, + amp_scale_max=None, + superres_index=None, + ): + """Determine superres ids, temporal upsampling, and scaling + + Given coarse matches (unit ids at times) and the current residual, + pick the best superres template, the best temporal offset, and the + best amplitude scaling. + + We used to upsample the objective to figure out the temporal upsampling, + but with superres in the picture we are now not computing the objective + using the same templates that we temporally upsample. So, instead + we use a greedy strategy: first pick the best (non-temporally upsampled) + superres template, then pick the upsampling and scaling at the same time. + These are all done by dotting everything and computing the objective, + which is probably more expensive than what we had before. + + Returns + ------- + time_shifts : Optional[array] + upsampling_indices : Optional[array] + scalings : Optional[array] + template_indices : array + objs : array + """ + if ( + not self.coarse_objective + and self.temporal_upsampling_factor == 1 + and not amp_scale_variance + ): + return None, None, None, objective_template_indices, objs + + if self.coarse_objective or self.temporal_upsampling_factor > 1: + # snips is a window padded by one sample, so that we have the + # traces snippets at the current times and one step back + n_spikes, window_length_samples, n_chans = residual_snips.shape + spike_length_samples = window_length_samples - 1 + # grab the current traces + snips = residual_snips[:, 1:] + # snips_dt = F.unfold( + # residual_snips[:, None, :, :], (spike_length_samples, snips.shape[2]) + # ) + # snips_dt = snips_dt.reshape( + # len(snips), spike_length_samples, snips.shape[2], 2 + # ) + + if self.coarse_objective: + # TODO best I came up with, but it still syncs + superres_ix = superres_index[objective_template_indices] + dup_ix, column_ix = (superres_ix < self.n_templates).nonzero(as_tuple=True) + template_indices = superres_ix[dup_ix, column_ix] + convs = torch.baddbmm( + self.temporal_components[template_indices], + snips[dup_ix], + self.spatial_singular[template_indices].mT, + ).sum((1, 2)) + # convs = torch.einsum( + # "jtc,jrc,jtr->j", + # snips[dup_ix], + # self.spatial_singular[template_indices], + # self.temporal_components[template_indices], + # ) + norms = self.template_norms_squared[template_indices] + objs = torch.full(superres_ix.shape, -torch.inf, device=convs.device) + objs[dup_ix, column_ix] = 2 * convs - norms + objs, best_column_ix = objs.max(dim=1) + row_ix = torch.arange(best_column_ix.numel(), device=best_column_ix.device) + template_indices = superres_ix[row_ix, best_column_ix] + else: + template_indices = objective_template_indices + norms = self.template_norms_squared[template_indices] + objs = objs + + if self.temporal_upsampling_factor == 1 and not amp_scale_variance: + return None, None, None, template_indices, objs + + if self.temporal_upsampling_factor == 1: + # just scaling + inv_lambda = 1 / amp_scale_variance + b = convs + inv_lambda + a = norms + inv_lambda + scalings = torch.clip(b / a, amp_scale_min, amp_scale_max) + objs = 2 * scalings * b - torch.square(scalings) * a - inv_lambda + return None, None, scalings, template_indices, objs + + # unpack the current traces and the traces one step back + snips_prev = residual_snips[:, :-1] + # snips_dt = torch.stack((snips_prev, snips), dim=3) + + # now, upsampling + # repeat the superres logic, the comp up index acts the same + comp_up_ix = self.compressed_upsampling_index[template_indices] + dup_ix, column_ix = ( + comp_up_ix < self.n_compressed_upsampled_templates + ).nonzero(as_tuple=True) + comp_up_indices = comp_up_ix[dup_ix, column_ix] + # convs = torch.einsum( + # "jtcd,jrc,jtr->jd", + # snips_dt[dup_ix], + # self.spatial_singular[template_indices[dup_ix]], + # self.compressed_upsampled_temporal[comp_up_indices], + # ) + temps = torch.bmm( + self.compressed_upsampled_temporal[comp_up_indices], + self.spatial_singular[template_indices[dup_ix]], + ).view(len(comp_up_indices), -1) + convs = torch.linalg.vecdot(snips[dup_ix].view(len(temps), -1), temps) + convs_prev = torch.linalg.vecdot(snips_prev[dup_ix].view(len(temps), -1), temps) + # convs = torch.einsum( + # "jtc,jrc,jtr->j", + # snips[dup_ix], + # self.spatial_singular[template_indices[dup_ix]], + # self.compressed_upsampled_temporal[comp_up_indices], + # ) + # convs_prev = torch.einsum( + # "jtc,jrc,jtr->j", + # snips_prev[dup_ix], + # self.spatial_singular[template_indices[dup_ix]], + # self.compressed_upsampled_temporal[comp_up_indices], + # ) + better = convs >= convs_prev + convs = torch.maximum(convs, convs_prev) + + norms = norms[dup_ix] + objs = torch.full(comp_up_ix.shape, -torch.inf, device=convs.device) + if amp_scale_variance: + inv_lambda = 1 / amp_scale_variance + b = convs + inv_lambda + a = norms + inv_lambda + scalings = torch.clip(b / a, amp_scale_min, amp_scale_max) + objs[dup_ix, column_ix] = ( + 2 * scalings * b - torch.square(scalings) * a - inv_lambda + ) + else: + objs[dup_ix, column_ix] = 2 * convs - norms + scalings = None + objs, best_column_ix = objs.max(dim=1) + + row_ix = torch.arange(len(objs), device=best_column_ix.device) + comp_up_indices = comp_up_ix[row_ix, best_column_ix] + upsampling_indices = self.compressed_index_to_upsampling_index[comp_up_indices] + + # prev convs were one step earlier + time_shifts = torch.full(comp_up_ix.shape, -1, device=convs.device) + time_shifts[dup_ix, column_ix] += better + time_shifts = time_shifts[row_ix, best_column_ix] + + return time_shifts, upsampling_indices, scalings, template_indices, objs def subtract( self, @@ -605,47 +998,98 @@ def subtract( template_indices, upsampling_indices, scalings, + batch_templates=..., ): + """Subtract templates from traces.""" + compressed_up_inds = self.compressed_upsampling_map[ + template_indices, upsampling_indices + ] batch_templates = torch.einsum( - "n,nrc,ntr", + "n,nrc,ntr->ntc", scalings, self.spatial_singular[template_indices], - self.upsampled_temporal_components[template_indices, upsampling_indices], + self.compressed_upsampled_temporal[compressed_up_inds], ) time_ix = times[:, None, None] + self.time_ix[None, :, None] spiketorch.add_at_( traces, (time_ix, self.chan_ix[None, None, :]), batch_templates, sign=-1 ) + def get_collisioncleaned_waveforms( + self, residual_padded, peaks, channel_index, spike_length_samples=121 + ): + channels = self.max_channels[peaks.template_indices] + waveforms = spiketorch.grab_spikes( + residual_padded, + peaks.times, + channels, + channel_index, + trough_offset=0, + spike_length_samples=spike_length_samples, + buffer=0, + already_padded=True, + ) + padded_spatial = F.pad(self.spatial_singular, (0, 1)) + spatial = padded_spatial[ + peaks.template_indices[:, None, None], + self.rank_ix[None, :, None], + channel_index[channels][:, None, :], + ] + comp_up_ix = self.compressed_upsampling_map[ + peaks.template_indices, peaks.upsampling_indices + ] + temporal = self.compressed_upsampled_temporal[comp_up_ix] + torch.baddbmm(waveforms, temporal, spatial, out=waveforms) + return channels, waveforms + class MatchingPeaks: - BUFFER_INIT: int = 1000 + BUFFER_INIT: int = 1500 BUFFER_GROWTH: float = 1.5 def __init__( self, n_spikes: int = 0, times: Optional[torch.LongTensor] = None, + objective_template_indices: Optional[torch.LongTensor] = None, template_indices: Optional[torch.LongTensor] = None, upsampling_indices: Optional[torch.LongTensor] = None, scalings: Optional[torch.Tensor] = None, scores: Optional[torch.Tensor] = None, + device=None, ): self.n_spikes = n_spikes + self._times = times + self._template_indices = template_indices + self._objective_template_indices = objective_template_indices + self._upsampling_indices = upsampling_indices + self._scalings = scalings + self._scores = scores + + if device is None and times is not None: + device = times.device if times is None: - cur_buf_size = self.BUFFER_INIT - self._times = torch.zeros(cur_buf_size, dtype=int) + self.cur_buf_size = self.BUFFER_INIT + self._times = torch.zeros(self.cur_buf_size, dtype=int, device=device) else: - cur_buf_size = times.size - assert cur_buf_size == n_spikes + self.cur_buf_size = times.numel() + assert self.cur_buf_size == n_spikes if template_indices is None: - self._template_indices = torch.zeros(cur_buf_size, dtype=int) + self._template_indices = torch.zeros( + self.cur_buf_size, dtype=int, device=device + ) + if objective_template_indices is None: + self._objective_template_indices = torch.zeros( + self.cur_buf_size, dtype=int, device=device + ) if scalings is None: - self._scalings = torch.zeros(cur_buf_size) + self._scalings = torch.ones(self.cur_buf_size, device=device) if upsampling_indices is None: - self._upsampling_indices = torch.zeros(cur_buf_size, dtype=int) + self._upsampling_indices = torch.zeros( + self.cur_buf_size, dtype=int, device=device + ) if scores is None: - self._scores = torch.zeros(cur_buf_size) + self._scores = torch.zeros(self.cur_buf_size, device=device) @property def times(self): @@ -654,6 +1098,9 @@ def times(self): @property def template_indices(self): return self._template_indices[: self.n_spikes] + @property + def objective_template_indices(self): + return self._objective_template_indices[: self.n_spikes] @property def upsampling_indices(self): @@ -668,34 +1115,21 @@ def scores(self): return self._scores[: self.n_spikes] def grow_buffers(self, min_size=0): - new_buf_size = max(min_size, int(self.cur_buf_size * self.BUFFER_GROWTH)) - new_times = torch.zeros(new_buf_size, dtype=self._times.dtype) - new_template_indices = torch.zeros( - new_buf_size, dtype=self._template_indices.dtype - ) - new_upsampling_indices = torch.zeros( - new_buf_size, dtype=self._upsampling_indices.dtype - ) - new_scalings = torch.zeros(new_buf_size, dtype=self._scalings.dtype) - new_scores = torch.zeros(new_buf_size, dtype=self._scores.dtype) - - new_times[: self.n_spikes] = self.times - new_template_indices[: self.n_spikes] = self.template_indices - new_upsampling_indices[: self.n_spikes] = self.upsampling_indices - new_scalings[: self.n_spikes] = self.scalings - new_scores[: self.n_spikes] = self.scores - - self.cur_buf_size = new_buf_size - self._times = new_times - self._template_indices = new_template_indices - self._upsampling_indices = new_upsampling_indices - self._scalings = new_scalings - self._scores = new_scores + sz = max(min_size, int(self.cur_buf_size * self.BUFFER_GROWTH)) + k = self.n_spikes + self._times = _grow_buffer(self._times, k, sz) + self._template_indices = _grow_buffer(self._template_indices, k, sz) + self._objective_template_indices = _grow_buffer(self._objective_template_indices, k, sz) + self._upsampling_indices = _grow_buffer(self._upsampling_indices, k, sz) + self._scalings = _grow_buffer(self._scalings, k, sz) + self._scores = _grow_buffer(self._scores, k, sz) + self.cur_buf_size = sz def sort(self): order = torch.argsort(self.times[: self.n_spikes]) self._times[: self.n_spikes] = self.times[order] self._template_indices[: self.n_spikes] = self.template_indices[order] + self._objective_template_indices[: self.n_spikes] = self.objective_template_indices[order] self._upsampling_indices[: self.n_spikes] = self.upsampling_indices[order] self._scalings[: self.n_spikes] = self.scalings[order] self._scores[: self.n_spikes] = self.scores[order] @@ -706,9 +1140,31 @@ def extend(self, other): self.grow_buffers(min_size=new_n_spikes) self._times[self.n_spikes : new_n_spikes] = other.times self._template_indices[self.n_spikes : new_n_spikes] = other.template_indices + self._objective_template_indices[self.n_spikes : new_n_spikes] = other.objective_template_indices self._upsampling_indices[ self.n_spikes : new_n_spikes ] = other.upsampling_indices self._scalings[self.n_spikes : new_n_spikes] = other.scalings self._scores[self.n_spikes : new_n_spikes] = other.scores self.n_spikes = new_n_spikes + + +def _grow_buffer(x, old_length, new_size): + new = torch.empty(new_size, dtype=x.dtype, device=x.device) + new[:old_length] = x[:old_length] + return new + + +def argrelmax(x, radius, threshold, exclude_edge=True): + x1 = F.max_pool1d( + x[None, None], + kernel_size=2 * radius + 1, + padding=radius, + stride=1, + )[0, 0] + x1[x < x1] = 0 + F.threshold_(x1, threshold, 0.0) + ix = torch.nonzero(x1)[:, 0] + if exclude_edge: + return ix[(ix > 0) & (ix < x.numel() - 1)] + return ix diff --git a/src/dartsort/peel/peel_base.py b/src/dartsort/peel/peel_base.py index a0ae5a4e..0a925425 100644 --- a/src/dartsort/peel/peel_base.py +++ b/src/dartsort/peel/peel_base.py @@ -207,7 +207,15 @@ def peel_chunk( raise NotImplementedError - def fit_peeler_models(self, save_folder): + def peeling_needs_fit(self): + return False + + def precompute_peeling_data(self, save_folder, n_jobs=0, device=None): + # subclasses should override if they need to cache data for peeling + # runs before fit_peeler_models() + pass + + def fit_peeler_models(self, save_folder, n_jobs=0, device=None): # subclasses should override if they need to fit models for peeling assert not self.peeling_needs_fit() @@ -270,7 +278,7 @@ def process_chunk(self, chunk_start_samples, return_residual=False): assert not any(k in features for k in peel_result) chunk_result = {**peel_result, **features} chunk_result = { - k: v.cpu().numpy() if torch.is_tensor(v) else v + k: v.numpy(force=True) if torch.is_tensor(v) else v for k, v in chunk_result.items() } @@ -310,15 +318,15 @@ def gather_chunk_result( return n_new_spikes - def peeling_needs_fit(self): - return False - def needs_fit(self): return self.peeling_needs_fit() or self.featurization_pipeline.needs_fit() def fit_models(self, save_folder, n_jobs=0, device=None): with torch.no_grad(): if self.peeling_needs_fit(): + self.precompute_peeling_data( + save_folder=save_folder, n_jobs=n_jobs, device=device + ) self.fit_peeler_models( save_folder=save_folder, n_jobs=n_jobs, device=device ) @@ -510,11 +518,10 @@ def _peeler_process_init(peeler, device, rank_queue, save_residual): def _peeler_process_job(chunk_start_samples): - peeler = _peeler_process_context.peeler # by returning here, we are implicitly relying on pickle # we can replace this with cloudpickle or manual np.save if helpful with torch.no_grad(): - return peeler.process_chunk( + return _peeler_process_context.peeler.process_chunk( chunk_start_samples, return_residual=_peeler_process_context.save_residual, ) diff --git a/src/dartsort/templates/get_templates.py b/src/dartsort/templates/get_templates.py index d8b1f993..58b75d22 100644 --- a/src/dartsort/templates/get_templates.py +++ b/src/dartsort/templates/get_templates.py @@ -10,7 +10,8 @@ from dartsort.util import spikeio from dartsort.util.drift_util import registered_template from dartsort.util.multiprocessing_util import get_pool -from dartsort.util.waveform_util import fast_nanmedian, make_channel_index +from dartsort.util.spiketorch import fast_nanmedian, ptp +from dartsort.util.waveform_util import make_channel_index from scipy.spatial import KDTree from scipy.spatial.distance import pdist from sklearn.decomposition import TruncatedSVD @@ -180,6 +181,7 @@ def get_templates( snr_threshold=denoising_snr_threshold, ) templates = weights * raw_templates + (1 - weights) * low_rank_templates + templates = templates.astype(recording.dtype) return dict( sorting=sorting, @@ -378,13 +380,16 @@ def get_all_shifted_raw_and_low_rank_templates( registered_kdtree = KDTree(registered_geom) n_units = sorting.labels.max() + 1 - raw_templates = np.zeros((n_units, spike_length_samples, n_template_channels)) + raw_templates = np.zeros( + (n_units, spike_length_samples, n_template_channels), dtype=recording.dtype + ) low_rank_templates = None if not raw: low_rank_templates = np.zeros( - (n_units, spike_length_samples, n_template_channels) + (n_units, spike_length_samples, n_template_channels), + dtype=recording.dtype, ) - snrs_by_channel = np.zeros((n_units, n_template_channels)) + snrs_by_channel = np.zeros((n_units, n_template_channels), dtype=recording.dtype) unit_id_chunks = [ unit_ids[i : i + units_per_job] for i in range(0, n_units, units_per_job) @@ -420,6 +425,8 @@ def get_all_shifted_raw_and_low_rank_templates( unit="template", ) for res in results: + if res is None: + continue units_chunk, raw_temps_chunk, low_rank_temps_chunk, snrs_chunk = res raw_templates[units_chunk] = raw_temps_chunk if not raw: @@ -469,19 +476,21 @@ def __init__( self.max_spike_time = recording.get_num_samples() - ( spike_length_samples - trough_offset_samples ) - + self.spike_buffer = torch.zeros( (spikes_per_unit * units_per_job, spike_length_samples, self.n_channels), device=device, dtype=torch.from_numpy(np.zeros(1, dtype=recording.dtype)).dtype, ) + self.n_template_channels = self.n_channels if self.registered: self.geom = recording.get_channel_locations() self.match_distance = pdist(self.geom).min() / 2 self.registered_geom = registered_kdtree.data self.registered_kdtree = registered_kdtree self.pitch_shifts = pitch_shifts + self.n_template_channels = len(self.registered_geom) _template_process_context = None @@ -534,6 +543,8 @@ def _template_job(unit_ids): p = _template_process_context in_units_full = np.flatnonzero(np.isin(p.sorting.labels, unit_ids)) + if not in_units_full.size: + return labels_full = p.sorting.labels[in_units_full] # only so many spikes per unit @@ -560,10 +571,10 @@ def _template_job(unit_ids): # read waveforms for all units times = p.sorting.times_samples[in_units] valid = np.flatnonzero( - (times >= p.trough_offset_samples) & (times < p.max_spike_time) + (times >= p.trough_offset_samples) & (times <= p.max_spike_time) ) if not valid.size: - return uids, 0, 0, 0 + return in_units = in_units[valid] labels = labels[valid] times = times[valid] @@ -573,19 +584,19 @@ def _template_job(unit_ids): trough_offset_samples=p.trough_offset_samples, spike_length_samples=p.spike_length_samples, ) - p.spike_buffer[:times.size] = torch.from_numpy(waveforms) - waveforms = p.spike_buffer[:times.size] + p.spike_buffer[: times.size] = torch.from_numpy(waveforms) + waveforms = p.spike_buffer[: times.size] n, t, c = waveforms.shape # compute raw templates and spike counts per channel raw_templates = [] counts = [] + units_chunk = [] for u in uids: in_unit = np.flatnonzero(labels == u) if not in_unit.size: - raw_templates.append(np.zeros(1)) - counts.append(0) continue + units_chunk.append(u) in_unit_orig = in_units[labels == u] if p.registered: raw_templates.append( @@ -611,12 +622,15 @@ def _template_job(unit_ids): ) ) else: - raw_templates.append(p.reducer(waveforms[in_unit], axis=0)) + raw_templates.append( + p.reducer(waveforms[in_unit], axis=0).numpy(force=True) + ) counts.append(in_unit.size) - snrs_by_chan = [rt.ptp(0) * c for rt, c in zip(raw_templates, counts)] + snrs_by_chan = [ptp(rt, 0) * c for rt, c in zip(raw_templates, counts)] + raw_templates = np.array(raw_templates) if p.denoising_tsvd is None: - return uids, raw_templates, None, snrs_by_chan + return units_chunk, raw_templates, None, snrs_by_chan # apply denoising waveforms = waveforms.permute(0, 2, 1).reshape(n * c, t) @@ -625,11 +639,8 @@ def _template_job(unit_ids): # get low rank templates low_rank_templates = [] - for u in uids: + for u in units_chunk: in_unit = np.flatnonzero(labels == u) - if not in_unit.size: - low_rank_templates.append(0) - continue in_unit_orig = in_units[labels == u] if p.registered: low_rank_templates.append( @@ -644,9 +655,12 @@ def _template_job(unit_ids): ) ) else: - low_rank_templates.append(p.reducer(waveforms[in_unit], axis=0)) + low_rank_templates.append( + p.reducer(waveforms[in_unit], axis=0).numpy(force=True) + ) + low_rank_templates = np.array(low_rank_templates) - return uids, raw_templates, low_rank_templates, snrs_by_chan + return units_chunk, raw_templates, low_rank_templates, snrs_by_chan class TorchSVDProjector(torch.nn.Module): diff --git a/src/dartsort/templates/pairwise.py b/src/dartsort/templates/pairwise.py new file mode 100644 index 00000000..e889746b --- /dev/null +++ b/src/dartsort/templates/pairwise.py @@ -0,0 +1,356 @@ +from dataclasses import dataclass, fields +from typing import Optional + +import h5py +import numpy as np +import torch + +from .pairwise_util import compressed_convolve_to_h5 +from .template_util import CompressedUpsampledTemplates, LowRankTemplates +from .templates import TemplateData + + +@dataclass +class CompressedPairwiseConv: + """A database of channel-summed cross-correlations between template pairs + + There are too many templates to store all of these, especially after + superres binning, temporal upsampling, and pitch shifting. We compress + this as much as possible, first by deduplication (many convolutions of + templates at different shifts are identical), next by not wasting space + (no need to compute as many upsampled copies of small templates), and + finally by approximation (for pairs of far-away units, correlations of + superres templates are very close to correlations of the non-superres + template). + + This database holds some indexing structures that help us store these + correlations sparsely. .query() grabs the actual correlations for the + user. + """ + + # shape: (n_shifts,) + # shift_ix -> shift (pitch shift, an integer) + shifts_a: np.ndarray + shifts_b: np.ndarray + + # shape: (n_templates_a, n_shifts_a) + # (template_ix, shift_ix) -> shifted_template_ix + # shifted_template_ix can be either invalid (this template does not occur + # at this shift), or it can range from 0, ..., n_shifted_templates_a-1 + shifted_template_index_a: np.ndarray + + # shape: (n_templates_b, n_shifts_b, upsampling_factor) + # (template_ix, shift_ix, upsampling_ix) -> upsampled_shifted_template_ix + upsampled_shifted_template_index_b: np.ndarray + + # shape: (n_shifted_templates_a, n_upsampled_shifted_templates_b) + # (shifted_template_ix, upsampled_shifted_template_ix) -> pconv_ix + pconv_index: np.ndarray + + # shape: (n_pconvs, 2 * spike_length_samples - 1) + # pconv_ix -> a cross-correlation array + # the 0 index is special: pconv[0] === 0. + pconv: np.ndarray + in_memory: bool = False + device: torch.device = torch.device("cpu") + + def __post_init__(self): + assert self.shifts_a.ndim == self.shifts_b.ndim == 1 + assert self.shifts_a.shape == (self.shifted_template_index_a.shape[1],) + assert self.shifts_b.shape == ( + self.upsampled_shifted_template_index_b.shape[1], + ) + + self.a_shift_offset, self.offset_shift_a_to_ix = _get_shift_indexer( + self.shifts_a + ) + self.b_shift_offset, self.offset_shift_b_to_ix = _get_shift_indexer( + self.shifts_b + ) + + def get_shift_ix_a(self, shifts_a): + shifts_a = torch.atleast_1d(torch.as_tensor(shifts_a)) + return self.offset_shift_a_to_ix[shifts_a.to(int) + self.a_shift_offset] + + def get_shift_ix_b(self, shifts_b): + shifts_b = torch.atleast_1d(torch.as_tensor(shifts_b)) + return self.offset_shift_b_to_ix[shifts_b.to(int) + self.b_shift_offset] + + @classmethod + def from_h5(cls, hdf5_filename, in_memory=True): + ff = [f for f in fields(cls) if f.name not in ("in_memory", "device")] + if in_memory: + with h5py.File(hdf5_filename, "r") as h5: + data = {f.name: torch.from_numpy(h5[f.name][:]) for f in ff} + return cls(**data, in_memory=in_memory) + _h5 = h5py.File(hdf5_filename, "r") + data = {} + for f in ff: + if f.name == "pconv": + data[f.name] = _h5[f.name] + else: + data[f.name] = torch.from_numpy(_h5[f.name][:]) + return cls(**data, in_memory=in_memory) + + @classmethod + def from_template_data( + cls, + hdf5_filename, + template_data: TemplateData, + low_rank_templates: LowRankTemplates, + compressed_upsampled_temporal: CompressedUpsampledTemplates, + template_data_b: Optional[TemplateData] = None, + low_rank_templates_b: Optional[TemplateData] = None, + chunk_time_centers_s: Optional[np.ndarray] = None, + motion_est=None, + geom: Optional[np.ndarray] = None, + conv_ignore_threshold=0.0, + coarse_approx_error_threshold=0.0, + conv_batch_size=1024, + units_batch_size=8, + overwrite=False, + device=None, + n_jobs=0, + show_progress=True, + ): + compressed_convolve_to_h5( + hdf5_filename, + template_data=template_data, + low_rank_templates=low_rank_templates, + compressed_upsampled_temporal=compressed_upsampled_temporal, + template_data_b=template_data_b, + low_rank_templates_b=low_rank_templates_b, + chunk_time_centers_s=chunk_time_centers_s, + motion_est=motion_est, + geom=geom, + conv_ignore_threshold=conv_ignore_threshold, + coarse_approx_error_threshold=coarse_approx_error_threshold, + conv_batch_size=conv_batch_size, + units_batch_size=units_batch_size, + overwrite=overwrite, + device=device, + n_jobs=n_jobs, + show_progress=show_progress, + ) + return cls.from_h5(hdf5_filename) + + def at_shifts(self, shifts_a=None, shifts_b=None, device=None): + """Subset this database to one set of shifts. + + The database becomes shiftless (not in the pejorative sense). + """ + if shifts_a is None or shifts_b is None: + assert shifts_a is shifts_b + assert self.shifts_a.shape == (1,) + assert self.shifts_b.shape == (1,) + return self + + assert shifts_a.shape == (len(self.shifted_template_index_a),) + assert shifts_b.shape == (len(self.upsampled_shifted_template_index_b),) + n_shifted_temps_a, n_up_shifted_temps_b = self.pconv_index.shape + + # active shifted and upsampled indices + shift_ix_a = self.get_shift_ix_a(shifts_a) + shift_ix_b = self.get_shift_ix_b(shifts_b) + sub_shifted_temp_index_a = self.shifted_template_index_a[ + torch.arange(len(self.shifted_template_index_a))[:, None], + shift_ix_a[:, None], + ] + sub_up_shifted_temp_index_b = self.upsampled_shifted_template_index_b[ + torch.arange(len(self.upsampled_shifted_template_index_b))[:, None], + shift_ix_b[:, None], + ] + + # in flat form for indexing into pconv_index. also, reindex. + valid_a = sub_shifted_temp_index_a < n_shifted_temps_a + shifted_temp_ixs_a, new_shifted_temp_ixs_a = torch.unique( + sub_shifted_temp_index_a[valid_a], return_inverse=True + ) + valid_b = sub_up_shifted_temp_index_b < n_up_shifted_temps_b + up_shifted_temp_ixs_b, new_up_shifted_temp_ixs_b = torch.unique( + sub_up_shifted_temp_index_b[valid_b], return_inverse=True + ) + + # get relevant pconv subset and reindex + sub_pconv_indices, new_pconv_indices = torch.unique( + self.pconv_index[ + shifted_temp_ixs_a[:, None], + up_shifted_temp_ixs_b.ravel()[None, :], + ], + return_inverse=True, + ) + if self.in_memory: + sub_pconv = self.pconv[sub_pconv_indices.to(self.pconv.device)] + else: + sub_pconv = torch.from_numpy(batched_h5_read(self.pconv, sub_pconv_indices)) + if device is not None: + sub_pconv = sub_pconv.to(device) + + # reindexing + n_sub_shifted_temps_a = len(shifted_temp_ixs_a) + n_sub_up_shifted_temps_b = len(up_shifted_temp_ixs_b) + sub_pconv_index = new_pconv_indices.view( + n_sub_shifted_temps_a, n_sub_up_shifted_temps_b + ) + sub_shifted_temp_index_a[valid_a] = new_shifted_temp_ixs_a + sub_up_shifted_temp_index_b[valid_b] = new_up_shifted_temp_ixs_b + + return self.__class__( + shifts_a=torch.zeros(1), + shifts_b=torch.zeros(1), + shifted_template_index_a=sub_shifted_temp_index_a, + upsampled_shifted_template_index_b=sub_up_shifted_temp_index_b, + pconv_index=sub_pconv_index, + pconv=sub_pconv, + in_memory=True, + device=self.device, + ) + + def to(self, device=None, incl_pconv=False, pin=False): + """Become torch tensors on device.""" + for name in ["offset_shift_a_to_ix", "offset_shift_b_to_ix"] + [ + f.name for f in fields(self) + ]: + if name == "pconv" and not incl_pconv: + continue + v = getattr(self, name) + if isinstance(v, np.ndarray) or torch.is_tensor(v): + setattr(self, name, torch.as_tensor(v, device=device)) + self.device = device + if pin and self.device.type == "cuda" and torch.cuda.is_available() and not self.pconv.is_pinned(): + # self.pconv.share_memory_() + print("pin") + torch.cuda.cudart().cudaHostRegister( + self.pconv.data_ptr(), self.pconv.numel() * self.pconv.element_size(), 0 + ) + # assert x.is_shared() + assert self.pconv.is_pinned() + # self.pconv = self.pconv.pin_memory() + return self + + def query( + self, + template_indices_a, + template_indices_b, + upsampling_indices_b=None, + shifts_a=None, + shifts_b=None, + scalings_b=None, + times_b=None, + return_zero_convs=False, + grid=False, + device=None, + ): + if template_indices_a is None: + template_indices_a = torch.arange( + len(self.shifted_template_index_a), device=self.device + ) + template_indices_a = torch.atleast_1d(torch.as_tensor(template_indices_a)) + template_indices_b = torch.atleast_1d(torch.as_tensor(template_indices_b)) + + # handle no shifting + no_shifting = shifts_a is None or shifts_b is None + shifted_template_index = self.shifted_template_index_a + upsampled_shifted_template_index = self.upsampled_shifted_template_index_b + if no_shifting: + assert shifts_a is None and shifts_b is None + assert self.shifts_a.shape == (1,) + assert self.shifts_b.shape == (1,) + a_ix = (template_indices_a,) + b_ix = (template_indices_b,) + shifted_template_index = shifted_template_index[:, 0] + upsampled_shifted_template_index = upsampled_shifted_template_index[:, 0] + else: + shift_indices_a = self.get_shift_ix_a(shifts_a) + shift_indices_b = self.get_shift_ix_a(shifts_b) + a_ix = (template_indices_a, shift_indices_a) + b_ix = (template_indices_b, shift_indices_b) + + # handle no upsampling + no_upsampling = upsampling_indices_b is None + if no_upsampling: + assert self.upsampled_shifted_template_index_b.shape[2] == 1 + upsampled_shifted_template_index = upsampled_shifted_template_index[..., 0] + else: + b_ix = b_ix + (torch.atleast_1d(torch.as_tensor(upsampling_indices_b)),) + + # get shifted template indices for A + print(f"{a_ix=}") + shifted_temp_ix_a = shifted_template_index[a_ix] + + # upsampled shifted template indices for B + up_shifted_temp_ix_b = upsampled_shifted_template_index[b_ix] + + # return convolutions between all ai,bj or just ai,bi? + if grid: + pconv_indices = self.pconv_index[ + shifted_temp_ix_a[:, None], up_shifted_temp_ix_b[None, :] + ] + template_indices_a, template_indices_b = torch.cartesian_prod( + template_indices_a, template_indices_b + ).T + if scalings_b is not None: + scalings_b = torch.broadcast_to( + scalings_b[None], pconv_indices.shape + ).reshape(-1) + if times_b is not None: + times_b = torch.broadcast_to( + times_b[None], pconv_indices.shape + ).reshape(-1) + pconv_indices = pconv_indices.view(-1) + else: + pconv_indices = self.pconv_index[shifted_temp_ix_a, up_shifted_temp_ix_b] + + # most users will be happy not to get a bunch of zeros for pairs that don't overlap + if not return_zero_convs: + which = pconv_indices > 0 + pconv_indices = pconv_indices[which] + template_indices_a = template_indices_a[which] + template_indices_b = template_indices_b[which] + if scalings_b is not None: + scalings_b = scalings_b[which] + if times_b is not None: + times_b = times_b[which] + + if self.in_memory: + pconvs = self.pconv[pconv_indices.to(self.pconv.device)] + else: + pconvs = torch.from_numpy( + batched_h5_read(self.pconv, pconv_indices.numpy(force=True)) + ) + if device is not None: + pconvs = pconvs.to(device) + + if scalings_b is not None: + pconvs.mul_(scalings_b[:, None]) + + if times_b is not None: + return template_indices_a, template_indices_b, times_b, pconvs + + return template_indices_a, template_indices_b, pconvs + + +def batched_h5_read(dataset, indices, batch_size=1000): + if indices.size < batch_size: + return dataset[indices] + else: + out = np.empty((indices.size, *dataset.shape[1:]), dtype=dataset.dtype) + for bs in range(0, indices.size, batch_size): + be = min(indices.size, bs + batch_size) + out[bs:be] = dataset[indices[bs:be]] + return out + + +def _get_shift_indexer(shifts): + assert torch.equal(shifts, torch.sort(shifts).values) + shift_offset = -int(shifts[0]) + offset_shift_to_ix = [] + for j, shift in enumerate(shifts): + ix = shift + shift_offset + assert len(offset_shift_to_ix) <= ix + assert 0 <= ix < len(shifts) + while len(offset_shift_to_ix) < ix: + offset_shift_to_ix.append(len(shifts)) + offset_shift_to_ix.append(j) + offset_shift_to_ix = torch.tensor(offset_shift_to_ix, device=shifts.device) + return shift_offset, offset_shift_to_ix diff --git a/src/dartsort/templates/pairwise_util.py b/src/dartsort/templates/pairwise_util.py new file mode 100644 index 00000000..1ac6ed8b --- /dev/null +++ b/src/dartsort/templates/pairwise_util.py @@ -0,0 +1,1181 @@ +from __future__ import annotations # allow forward type references + +from collections import namedtuple +from dataclasses import dataclass, fields +from pathlib import Path +from typing import Iterator, Optional, Union + +import h5py +import numpy as np +import torch +import torch.nn.functional as F +from dartsort.util import drift_util +from dartsort.util.multiprocessing_util import get_pool +from scipy.spatial import KDTree +from scipy.spatial.distance import pdist +from tqdm.auto import tqdm + +from . import template_util, templates + + +def compressed_convolve_to_h5( + output_hdf5_filename, + template_data: templates.TemplateData, + low_rank_templates: template_util.LowRankTemplates, + compressed_upsampled_temporal: template_util.CompressedUpsampledTemplates, + template_data_b: Optional[templates.TemplateData] = None, + low_rank_templates_b: Optional[templates.TemplateData] = None, + chunk_time_centers_s: Optional[np.ndarray] = None, + motion_est=None, + geom: Optional[np.ndarray] = None, + conv_ignore_threshold=0.0, + coarse_approx_error_threshold=0.0, + conv_batch_size=1024, + units_batch_size=8, + overwrite=False, + device=None, + n_jobs=0, + show_progress=True, +): + """Convolve all pairs of templates and store result in a .h5 + + See pairwise.CompressedPairwiseConvDB for how to read the + resulting convolutions back. + + This runs compressed_convolve_pairs in a loop over chunks + of unit pairs, so that it's not all done in memory at one time, + and so that it can be done in parallel. + """ + output_hdf5_filename = Path(output_hdf5_filename) + if not overwrite and output_hdf5_filename.exists(): + with h5py.File(output_hdf5_filename, "r") as h5: + if "pconv_index" in h5: + return output_hdf5_filename + del h5 + + # construct indexing helpers + ( + template_shift_index_a, + template_shift_index_b, + upsampled_shifted_template_index, + cooccurrence, + ) = construct_shift_indices( + chunk_time_centers_s, + geom, + template_data, + compressed_upsampled_temporal, + template_data_b=template_data_b, + motion_est=motion_est, + ) + + chunk_res_iterator = iterate_compressed_pairwise_convolutions( + template_data_a=template_data, + low_rank_templates_a=low_rank_templates, + template_data_b=template_data_b, + low_rank_templates_b=low_rank_templates_b, + compressed_upsampled_temporal=compressed_upsampled_temporal, + template_shift_index_a=template_shift_index_a, + template_shift_index_b=template_shift_index_b, + cooccurrence=cooccurrence, + upsampled_shifted_template_index=upsampled_shifted_template_index, + do_shifting=motion_est is not None, + geom=geom, + conv_ignore_threshold=conv_ignore_threshold, + coarse_approx_error_threshold=coarse_approx_error_threshold, + max_shift="full", + conv_batch_size=conv_batch_size, + units_batch_size=units_batch_size, + device=device, + n_jobs=n_jobs, + show_progress=show_progress, + ) + + pconv_index = np.zeros( + ( + template_shift_index_a.n_shifted_templates, + upsampled_shifted_template_index.n_upsampled_shifted_templates, + ), + dtype=int, + ) + n_pconvs = 1 + with h5py.File(output_hdf5_filename, "w") as h5: + # resizeable pconv dataset + spike_length_samples = template_data.templates.shape[1] + pconv = h5.create_dataset( + "pconv", + dtype=np.float32, + shape=(1, 2 * spike_length_samples - 1), + maxshape=(None, 2 * spike_length_samples - 1), + chunks=(128, 2 * spike_length_samples - 1), + ) + + for chunk_res in chunk_res_iterator: + if chunk_res is None: + continue + + # get shifted template indices for A + shifted_temp_ix_a = template_shift_index_a.template_shift_index[ + chunk_res.template_indices_a, + chunk_res.shift_indices_a, + ] + + # upsampled shifted template indices for B + up_shifted_temp_ix_b = ( + upsampled_shifted_template_index.upsampled_shifted_template_index[ + chunk_res.template_indices_b, + chunk_res.shift_indices_b, + chunk_res.upsampling_indices_b, + ] + ) + + # store new set of indices + new_pconv_indices = chunk_res.compression_index + n_pconvs + pconv_index[shifted_temp_ix_a, up_shifted_temp_ix_b] = new_pconv_indices + + # store new pconvs + n_new_pconvs = chunk_res.compressed_conv.shape[0] + pconv.resize(n_pconvs + n_new_pconvs, axis=0) + pconv[n_pconvs:] = chunk_res.compressed_conv + + n_pconvs += n_new_pconvs + + # write fixed size outputs + h5.create_dataset("shifts_a", data=template_shift_index_a.all_pitch_shifts) + h5.create_dataset("shifts_b", data=template_shift_index_b.all_pitch_shifts) + h5.create_dataset( + "shifted_template_index_a", data=template_shift_index_a.template_shift_index + ) + h5.create_dataset( + "upsampled_shifted_template_index_b", + data=upsampled_shifted_template_index.upsampled_shifted_template_index, + ) + h5.create_dataset("pconv_index", data=pconv_index) + + return output_hdf5_filename + + +def iterate_compressed_pairwise_convolutions( + template_data_a: templates.TemplateData, + low_rank_templates_a: template_util.LowRankTemplates, + template_data_b: templates.TemplateData, + low_rank_templates_b: template_util.LowRankTemplates, + compressed_upsampled_temporal: template_util.CompressedUpsampledTemplates, + template_shift_index_a: drift_util.TemplateShiftIndex, + template_shift_index_b: drift_util.TemplateShiftIndex, + cooccurrence: np.ndarray, + upsampled_shifted_template_index: UpsampledShiftedTemplateIndex, + do_shifting: bool = True, + geom: Optional[np.ndarray] = None, + conv_ignore_threshold=0.0, + coarse_approx_error_threshold=0.0, + max_shift="full", + amplitude_scaling_variance=0.0, + amplitude_scaling_boundary=0.5, + reduce_deconv_resid_norm=False, + conv_batch_size=1024, + units_batch_size=8, + device=None, + n_jobs=0, + show_progress=True, +) -> Iterator[Optional[CompressedConvResult]]: + """A generator of CompressedConvResults capturing all pairs of templates + + Runs the function compressed_convolve_pairs on chunks of units. + + This is a helper function for parallelizing computation of cross correlations + between pairs of templates. There are too many to store all the results in + memory, so this is a generator yielding a chunk at a time. Callers may + process the results differently. + """ + reg_geom = template_data_a.registered_geom + if template_data_b is None: + template_data_b = template_data_a + assert low_rank_templates_b is None + low_rank_templates_b = low_rank_templates_a + assert np.array_equal(reg_geom, template_data_b.registered_geom) + + # construct drift-related helper data if needed + geom_kdtree = reg_geom_kdtree = match_distance = None + if do_shifting: + geom_kdtree = KDTree(geom) + reg_geom_kdtree = KDTree(reg_geom) + match_distance = pdist(geom).min() / 2 + + # make chunks + units_a = np.unique(template_data_a.unit_ids) + units_b = np.unique(template_data_b.unit_ids) + jobs = [] + for start_a in range(0, units_a.size, units_batch_size): + end_a = min(start_a + units_batch_size, units_a.size) + for start_b in range(0, units_b.size, units_batch_size): + end_b = min(start_b + units_batch_size, units_b.size) + jobs.append((units_a[start_a:end_a], units_b[start_b:end_b])) + + # worker kwargs + kwargs = dict( + template_data_a=template_data_a, + template_data_b=template_data_b, + low_rank_templates_a=low_rank_templates_a, + low_rank_templates_b=low_rank_templates_b, + compressed_upsampled_temporal=compressed_upsampled_temporal, + template_shift_index_a=template_shift_index_a, + template_shift_index_b=template_shift_index_b, + upsampled_shifted_template_index=upsampled_shifted_template_index, + cooccurrence=cooccurrence, + geom=geom, + reg_geom=reg_geom, + geom_kdtree=geom_kdtree, + reg_geom_kdtree=reg_geom_kdtree, + match_distance=match_distance, + conv_ignore_threshold=conv_ignore_threshold, + coarse_approx_error_threshold=coarse_approx_error_threshold, + max_shift=max_shift, + batch_size=conv_batch_size, + amplitude_scaling_variance=amplitude_scaling_variance, + amplitude_scaling_boundary=amplitude_scaling_boundary, + reduce_deconv_resid_norm=reduce_deconv_resid_norm, + ) + + n_jobs, Executor, context, rank_queue = get_pool(n_jobs, with_rank_queue=True) + with Executor( + n_jobs, + mp_context=context, + initializer=_conv_worker_init, + initargs=(rank_queue, device, kwargs), + ) as pool: + it = pool.map(_conv_job, jobs) + if show_progress: + it = tqdm( + it, + smoothing=0.01, + desc="Pairwise convolution", + unit="pair block", + total=len(jobs), + ) + yield from it + + +@dataclass +class CompressedConvResult: + """Main return type of compressed_convolve_pairs + + If reduce_deconv_resid_norm=True, a DeconvResidResult is returned. + + After convolving a bunch of template pairs, some convolutions + may be zero. Let n_pairs be the number of nonzero convolutions. + We don't store the zero ones. + """ + + # arrays of shape n_pairs, + # For each convolved pair, these document which templates were + # in the pair, what their relative shifts were, and what the + # upsampling was (we only upsample the RHS) + template_indices_a: np.ndarray + template_indices_b: np.ndarray + shift_indices_a: np.ndarray + shift_indices_b: np.ndarray + upsampling_indices_b: np.ndarray + + # another one of shape n_pairs + # maps a pair index to the corresponding convolution index + # some convolutions are duplicates, so this array contains + # many duplicate entries in the range 0, ..., n_convs-1 + compression_index: np.ndarray + + # this one has shape (n_convs, 2 * spike_length_samples - 1) + compressed_conv: np.ndarray + + +@dataclass +class DeconvResidResult: + """Return type of compressed_convolve_pairs + + After convolving a bunch of template pairs, some convolutions + may be zero. Let n_pairs be the number of nonzero convolutions. + We don't store the zero ones. + """ + + # arrays of shape n_pairs, + # For each convolved pair, these document which templates were + # in the pair, what their relative shifts were, and what the + # upsampling was (we only upsample the RHS) + template_indices_a: np.ndarray + template_indices_b: np.ndarray + + # norm after subtracting best upsampled/scaled/shifted B template from A template + deconv_resid_norms: np.ndarray + + # ints. B trough - A trough + shifts: np.ndarray + + # for caller to implement different metrics + template_a_norms: np.ndarray + + # TODO: how to handle the nnz normalization we used to do? + # that one was done wrong -- the residual was not restricted + # to high amplitude channels. + + +def conv_to_resid( + template_data_a: templates.TemplateData, + template_data_b: templates.TemplateData, + conv_result: CompressedConvResult, + amplitude_scaling_variance=0.0, + amplitude_scaling_boundary=0.5, +) -> DeconvResidResult: + # decompress + pconvs = conv_result.compressed_conv[conv_result.compression_index] + full_length = pconvs.shape[1] + center = full_length // 2 + + # here, we just care about pairs of (superres) templates, not upsampling + # or shifting. so, get unique such pairs. + pairs = np.c_[conv_result.template_indices_a, conv_result.template_indices_b] + pairs = np.unique(pairs, axis=0) + n_pairs = len(pairs) + + # for loop to reduce over all (upsampled etc) member templates + deconv_resid_norms = np.zeros(n_pairs) + shifts = np.zeros(n_pairs, dtype=int) + template_indices_a, template_indices_b = pairs.T + templates_a = template_data_a.templates[template_indices_a] + templates_b = template_data_b.templates[template_indices_b] + template_a_norms = np.linalg.norm(templates_a, axis=(1, 2)) ** 2 + template_b_norms = np.linalg.norm(templates_b, axis=(1, 2)) ** 2 + for j, (ix_a, ix_b) in enumerate(pairs): + in_a = conv_result.template_indices_a == ix_a + in_b = conv_result.template_indices_b == ix_b + in_pair = np.flatnonzero(in_a & in_b) + + # reduce over fine templates + pair_conv = pconvs[in_pair].max(axis=0) + lag_index = np.argmax(pair_conv) + best_conv = pair_conv[lag_index] + shifts[j] = lag_index - center + + # figure out scaling + if amplitude_scaling_variance: + amp_scale_min = 1 / (1 + amplitude_scaling_boundary) + amp_scale_max = 1 + amplitude_scaling_boundary + inv_lambda = 1 / amplitude_scaling_variance + b = best_conv + inv_lambda + a = template_a_norms[j] + inv_lambda + scaling = np.clip(b / a, amp_scale_min, amp_scale_max) + norm_reduction = 2 * scaling * b - np.square(scaling) * a - inv_lambda + else: + norm_reduction = 2 * best_conv - template_b_norms[j] + deconv_resid_norms[j] = template_a_norms[j] - norm_reduction + assert deconv_resid_norms[j] >= 0 + + return DeconvResidResult( + template_indices_a, + template_indices_b, + deconv_resid_norms, + shifts, + template_a_norms, + ) + + +def compressed_convolve_pairs( + template_data_a: templates.TemplateData, + template_data_b: templates.TemplateData, + low_rank_templates_a: template_util.LowRankTemplates, + low_rank_templates_b: template_util.LowRankTemplates, + compressed_upsampled_temporal: template_util.CompressedUpsampledTemplates, + template_shift_index_a: drift_util.TemplateShiftIndex, + template_shift_index_b: drift_util.TemplateShiftIndex, + upsampled_shifted_template_index: UpsampledShiftedTemplateIndex, + cooccurrence: np.ndarray, + geom: Optional[np.ndarray] = None, + reg_geom: Optional[np.ndarray] = None, + geom_kdtree: Optional[KDTree] = None, + reg_geom_kdtree: Optional[KDTree] = None, + match_distance: Optional[float] = None, + units_a: Optional[np.ndarray] = None, + units_b: Optional[np.ndarray] = None, + conv_ignore_threshold=0.0, + coarse_approx_error_threshold=0.0, + amplitude_scaling_variance=0.0, + amplitude_scaling_boundary=0.5, + reduce_deconv_resid_norm=False, + max_shift="full", + batch_size=1024, + device=None, +) -> Optional[CompressedConvResult]: + """Compute compressed pairwise convolutions between template pairs + + Takes as input all the template data and groups of pairs of units to convolve + (units_a,b). units_a,b are unit indices, not template indices (i.e., coarse + units, not superresolved bin indices). + + Returns compressed convolutions between all units_a[i], units_b[j], for all + shifts, superres templates, and upsamples. Some of these may be zero or may + be duplicates, so the return value is a sparse representation. See below. + """ + # what pairs, shifts, etc are we convolving? + shifted_temp_ix_a, temp_ix_a, shift_a, unit_a = handle_shift_indices( + units_a, template_data_a.unit_ids, template_shift_index_a + ) + shifted_temp_ix_b, temp_ix_b, shift_b, unit_b = handle_shift_indices( + units_b, template_data_b.unit_ids, template_shift_index_b + ) + + # get (shifted) spatial components * singular values + spatial_singular_a = get_shifted_spatial_singular( + temp_ix_a, + shift_a, + template_shift_index_a, + low_rank_templates_a, + geom=geom, + registered_geom=reg_geom, + geom_kdtree=geom_kdtree, + match_distance=match_distance, + device=device, + ) + spatial_singular_b = get_shifted_spatial_singular( + temp_ix_b, + shift_b, + template_shift_index_b, + low_rank_templates_b, + geom=geom, + registered_geom=reg_geom, + geom_kdtree=geom_kdtree, + match_distance=match_distance, + device=device, + ) + + # figure out pairs of shifted templates to convolve in a deduplicated way + pairs_ret = shift_deduplicated_pairs( + shifted_temp_ix_a, + shifted_temp_ix_b, + spatial_singular_a, + spatial_singular_b, + temp_ix_a, + temp_ix_b, + cooccurrence=cooccurrence, + shift_a=shift_a, + shift_b=shift_b, + conv_ignore_threshold=conv_ignore_threshold, + geom=geom, + registered_geom=reg_geom, + reg_geom_kdtree=reg_geom_kdtree, + match_distance=match_distance, + ) + if pairs_ret is None: + return None + ix_a, ix_b, compression_index, conv_ix, spatial_shift_ids = pairs_ret + + # handle upsampling + # each pair will be duplicated by the b unit's number of upsampled copies + ( + ix_b, + compression_index, + conv_ix, + conv_upsampling_indices_b, + conv_temporal_components_up_b, + compression_dup_ix, + ) = compressed_upsampled_pairs( + ix_b, + compression_index, + conv_ix, + temp_ix_b, + shifted_temp_ix_b, + upsampled_shifted_template_index, + compressed_upsampled_temporal, + ) + ix_a = ix_a[compression_dup_ix] + spatial_shift_ids = spatial_shift_ids[compression_dup_ix] + + # run convolutions + temporal_a = low_rank_templates_a.temporal_components[temp_ix_a] + pconv, kept = correlate_pairs_lowrank( + torch.as_tensor(spatial_singular_a[ix_a[conv_ix]], device=device), + torch.as_tensor(spatial_singular_b[ix_b[conv_ix]], device=device), + torch.as_tensor(temporal_a[ix_a[conv_ix]], device=device), + torch.as_tensor(conv_temporal_components_up_b, device=device), + max_shift=max_shift, + conv_ignore_threshold=conv_ignore_threshold, + batch_size=batch_size, + ) + if kept is not None: + conv_ix = conv_ix[kept] + if not conv_ix.shape[0]: + return None + kept_pairs = np.flatnonzero(np.isin(compression_index, kept)) + compression_index = np.searchsorted(kept, compression_index[kept_pairs]) + conv_ix = np.searchsorted(kept_pairs, conv_ix) + ix_a = ix_a[kept_pairs] + ix_b = ix_b[kept_pairs] + spatial_shift_ids = spatial_shift_ids[kept_pairs] + assert pconv.numel() > 0 + + # coarse approx + pconv, old_ix_to_new_ix = coarse_approximate( + pconv, + unit_a[ix_a[conv_ix]], + unit_b[ix_b[conv_ix]], + temp_ix_a[ix_a[conv_ix]], + spatial_shift_ids[conv_ix], + coarse_approx_error_threshold=coarse_approx_error_threshold, + ) + compression_index = old_ix_to_new_ix[compression_index] + # above function invalidates the whole idea of conv_ix + del conv_ix + + # recover metadata + temp_ix_a = temp_ix_a[ix_a] + shift_ix_a = np.searchsorted(template_shift_index_a.all_pitch_shifts, shift_a[ix_a]) + temp_ix_b = temp_ix_b[ix_b] + shift_ix_b = np.searchsorted(template_shift_index_b.all_pitch_shifts, shift_b[ix_b]) + + res = CompressedConvResult( + template_indices_a=temp_ix_a, + template_indices_b=temp_ix_b, + shift_indices_a=shift_ix_a, + shift_indices_b=shift_ix_b, + upsampling_indices_b=conv_upsampling_indices_b[compression_index], + compression_index=compression_index, + compressed_conv=pconv.numpy(force=True), + ) + if reduce_deconv_resid_norm: + return conv_to_resid( + template_data_a, + template_data_b, + res, + amplitude_scaling_variance=amplitude_scaling_variance, + amplitude_scaling_boundary=amplitude_scaling_boundary, + ) + return res + + +# -- helpers + + +def correlate_pairs_lowrank( + spatial_a, + spatial_b, + temporal_a, + temporal_b, + max_shift="full", + conv_ignore_threshold=0.0, + batch_size=1024, +): + """Convolve pairs of low rank templates + + For each i, we want to convolve (temporal_a[i] @ spatial_a[i]) with + (temporal_b[i] @ spatial_b[i]). So, spatial_{a,b} and temporal_{a,b} + should contain lots of duplicates, since they are already representing + pairs. + + Templates Ka = Sa Ta, Kb = Sb Tb. The channel-summed convolution is + (Ka (*) Kb) = sum_c Ka(c) * Kb(c) + = (Sb.T @ Ka) (*) Tb + = (Sb.T @ Sa @ Ta) (*) Tb + where * is cross-correlation, and (*) is channel (or rank) summed. + We use full-height conv2d to do rank-summed convs. + + Returns + ------- + pconv, kept + """ + n_pairs, rank, nchan = spatial_a.shape + n_pairs_, rank_, nchan_ = spatial_b.shape + assert rank == rank_ + assert nchan == nchan_ + assert n_pairs == n_pairs_ + n_pairs_, t, rank_ = temporal_a.shape + assert n_pairs == n_pairs_ + assert rank_ == rank + n_pairs_, t_, rank_ = temporal_b.shape + assert n_pairs == n_pairs_ + assert t == t_ + assert rank == rank_ + + if max_shift == "full": + max_shift = t - 1 + elif max_shift == "valid": + max_shift = 0 + elif max_shift == "same": + max_shift = t // 2 + + # batch over n_pairs for memory reasons + pconv = torch.zeros( + (n_pairs, 2 * max_shift + 1), dtype=spatial_a.dtype, device=spatial_a.device + ) + for istart in range(0, n_pairs, batch_size): + iend = min(istart + batch_size, n_pairs) + ix = slice(istart, iend) + + # want conv filter: nco, 1, rank, t + template_a = torch.bmm(temporal_a[ix], spatial_a[ix]) + conv_filt = torch.bmm(spatial_b[ix], template_a.mT) + conv_filt = conv_filt[:, None] # (nco, 1, rank, t) + + # 1, nco, rank, t + conv_in = temporal_b[ix].mT[None] + + # conv2d: + # depthwise, chans=nco. batch=1. h=rank. w=t. out: nup=1, nco, 1, 2p+1. + # input (conv_in): nup, nco, rank, t. + # filters (conv_filt): nco, 1, rank, t. (groups=nco). + pconv_ = F.conv2d( + conv_in, conv_filt, padding=(0, max_shift), groups=iend - istart + ) + pconv[istart:iend] = pconv_[0, :, 0, :] # nco, nup, time + + # more stringent covisibility + if conv_ignore_threshold is not None: + max_val = pconv.reshape(n_pairs, -1).abs().max(dim=1).values + kept = max_val > conv_ignore_threshold + pconv = pconv[kept] + kept = np.flatnonzero(kept.numpy(force=True)) + else: + kept = None + + return pconv, kept + + +def construct_shift_indices( + chunk_time_centers_s, + geom, + template_data_a, + compressed_upsampled_temporal, + template_data_b=None, + motion_est=None, +): + ( + template_shift_index_a, + template_shift_index_b, + cooccurrence, + ) = drift_util.get_shift_and_unit_pairs( + chunk_time_centers_s, + geom, + template_data_a, + template_data_b=template_data_b, + motion_est=motion_est, + ) + upsampled_shifted_template_index = get_upsampled_shifted_template_index( + template_shift_index_b, compressed_upsampled_temporal + ) + return ( + template_shift_index_a, + template_shift_index_b, + upsampled_shifted_template_index, + cooccurrence, + ) + + +def handle_shift_indices(units, unit_ids, template_shift_index): + """Determine shifted template indices belonging to a set of units.""" + shifted_temp_ix_to_unit = unit_ids[template_shift_index.shifted_temp_ix_to_temp_ix] + if units is None: + shifted_temp_ix = np.arange(template_shift_index.n_shifted_templates) + else: + shifted_temp_ix = np.flatnonzero(np.isin(shifted_temp_ix_to_unit, units)) + + shift = template_shift_index.shifted_temp_ix_to_shift[shifted_temp_ix] + temp_ix = template_shift_index.shifted_temp_ix_to_temp_ix[shifted_temp_ix] + unit = unit_ids[temp_ix] + + return shifted_temp_ix, temp_ix, shift, unit + + +def get_shifted_spatial_singular( + temp_ix, + shift, + template_shift_index, + low_rank_templates, + geom=None, + registered_geom=None, + geom_kdtree=None, + match_distance=None, + device=None, +): + # do we need to shift the templates? + n_shifts = template_shift_index.all_pitch_shifts.size + do_shifting = n_shifts > 1 + + spatial_singular = ( + low_rank_templates.spatial_components[temp_ix] + * low_rank_templates.singular_values[temp_ix][..., None] + ) + if do_shifting: + spatial_singular = drift_util.get_waveforms_on_static_channels( + spatial_singular, + registered_geom, + n_pitches_shift=shift, + registered_geom=geom, + target_kdtree=geom_kdtree, + match_distance=match_distance, + fill_value=0.0, + ) + spatial_singular = torch.as_tensor(spatial_singular, device=device) + + return spatial_singular + + +def shift_deduplicated_pairs( + shifted_temp_ix_a, + shifted_temp_ix_b, + spatialsing_a, + spatialsing_b, + temp_ix_a, + temp_ix_b, + cooccurrence, + shift_a=None, + shift_b=None, + conv_ignore_threshold=0.0, + geom=None, + registered_geom=None, + reg_geom_kdtree=None, + match_distance=None, +): + """Choose a set of pairs of indices from group A and B to convolve + + Some pairs of shifted templates don't overlap, so we don't need to convolve them. + Some pairs of shifted templates never show up in the recording at the same time + (what this code calls "cooccurrence"), so we don't need to convolve them. + + More complicated: for each shift, a certain set of registered template channels + survives. Given that the some set of visible channels has survived for a pair of + templates at shifts shift_a and shift_b, their cross-correlation at these shifts + is the same as the one at shift_a_prime and shift_b_prime if the same exact channels + survived at shift_a_prime and shift_b_prime and if + shift_a-shift_b == shift_a_prime-shift_b_prime. + + Returns + ------- + pair_ix_a, pair_ix_b + Size < original number of shifted templates a,b + The indices of shifted templates which overlap enough to be + co-visible. So, these are subsets of shifted_temp_ix_a,b + compression_index + Size == pair_ix_a,b size + Arrays with shape matching pair_ix_a,b, so that the xcorr of templates + shifted_temp_ix_a[pair_ix_a[i]], shifted_temp_ix_b[pair_ix_b[i]] + is the same as that of + shifted_temp_ix_a[pair_ix_a[conv_ix[compression_index[i]]], + pair_ix_b[conv_ix[compression_index[i]]] + conv_ix + Size < original number of shifted templates a,b + Pairs of templates which should actually be convolved + """ + # check spatially overlapping + chan_amp_a = torch.sqrt(torch.square(spatialsing_a).sum(1)) + chan_amp_b = torch.sqrt(torch.square(spatialsing_b).sum(1)) + pair = chan_amp_a @ chan_amp_b.T + pair = pair > conv_ignore_threshold + pair = pair.cpu() + + # co-occurrence + cooccurrence = cooccurrence[ + shifted_temp_ix_a[:, None], + shifted_temp_ix_b[None, :], + ] + pair *= torch.as_tensor(cooccurrence, device=pair.device) + + pair_ix_a, pair_ix_b = torch.nonzero(pair, as_tuple=True) + nco = pair_ix_a.numel() + if not nco: + return None + + # if no shifting, deduplication is the identity + do_shifting = reg_geom_kdtree is not None + if not do_shifting: + nco_range = torch.arange(nco, device=pair_ix_a.device) + return pair_ix_a, pair_ix_b, nco_range, nco_range, np.zeros(nco, dtype=int) + + # shift deduplication. algorithm: + # 1 for each shifted template, determine the set of registered channels + # which it occupies + # 2 assign each such set an ID (an "active channel ID") + # - // then a pair of shifted templates' xcorr is a function of the pair + # // of active channel IDs and the difference of shifts + # 3 figure out the set of unique (active chan id a, active chan id b, shift diff a,b) + # combinations in each pair of units + + # 1: get active channel neighborhoods as many-hot len(reg_geom)-vectors + active_chans_a = drift_util.get_waveforms_on_static_channels( + (chan_amp_a > 0).numpy(force=True), + geom, + n_pitches_shift=-shift_a, + registered_geom=registered_geom, + target_kdtree=reg_geom_kdtree, + match_distance=match_distance, + fill_value=0, + ) + active_chans_b = drift_util.get_waveforms_on_static_channels( + (chan_amp_b > 0).numpy(force=True), + geom, + n_pitches_shift=-shift_b, + registered_geom=registered_geom, + target_kdtree=reg_geom_kdtree, + match_distance=match_distance, + fill_value=0, + ) + # 2: assign IDs to each such vector + chanset_a, active_chan_ids_a = np.unique( + active_chans_a, axis=0, return_inverse=True + ) + chanset_b, active_chan_ids_b = np.unique( + active_chans_b, axis=0, return_inverse=True + ) + + # 3 + temp_ix_a = temp_ix_a[pair_ix_a] + temp_ix_b = temp_ix_b[pair_ix_b] + # get the relative shifts + shift_a = shift_a[pair_ix_a] + shift_b = shift_b[pair_ix_b] + + # figure out combinations + _, spatial_shift_ids = np.unique( + np.c_[ + active_chan_ids_a[pair_ix_a], + active_chan_ids_b[pair_ix_b], + shift_a - shift_b, + ], + axis=0, + return_inverse=True, + ) + conv_determiners = np.c_[ + temp_ix_a, + temp_ix_b, + spatial_shift_ids, + ] + # conv_ix: indices of unique determiners + # compression_index: which representative does each pair belong to + _, conv_ix, compression_index = np.unique( + conv_determiners, axis=0, return_index=True, return_inverse=True + ) + + return pair_ix_a, pair_ix_b, compression_index, conv_ix, spatial_shift_ids + + +UpsampledShiftedTemplateIndex = namedtuple( + "UpsampledShiftedTemplateIndex", + [ + "n_upsampled_shifted_templates", + "upsampled_shifted_template_index", + "up_shift_temp_ix_to_shift_temp_ix", + "up_shift_temp_ix_to_temp_ix", + "up_shift_temp_ix_to_comp_up_ix", + ], +) + + +def get_upsampled_shifted_template_index( + template_shift_index, compressed_upsampled_temporal +): + """Make a compressed index space for upsampled shifted templates + + See also: template_util.{compressed_upsampled_templates,ComptessedUpsampledTemplates}. + + The comp_up_ix / compressed upsampled template indices here are indices into that + structure. + + Returns + ------- + UpsampledShiftedTemplateIndex + named tuple with fields: + upsampled_shifted_template_index : (n_templates, n_shifts, up_factor) + Maps template_ix, shift_ix, up_ix -> compressed upsampled template index + up_shift_temp_ix_to_shift_temp_ix + up_shift_temp_ix_to_temp_ix + up_shift_temp_ix_to_comp_up_ix + """ + n_shifted_templates = template_shift_index.n_shifted_templates + n_templates, n_shifts = template_shift_index.template_shift_index.shape + max_upsample = compressed_upsampled_temporal.compressed_upsampling_map.shape[1] + + cur_up_shift_temp_ix = 0 + # fill with an invalid index + upsampled_shifted_template_index = np.full( + (n_templates, n_shifts, max_upsample), n_shifted_templates * max_upsample + ) + usti2sti = [] + usti2ti = [] + usti2cui = [] + for i in range(n_templates): + shifted_temps = template_shift_index.template_shift_index[i] + valid_shifts = np.flatnonzero(shifted_temps < n_shifted_templates) + + upsampled_temps = compressed_upsampled_temporal.compressed_upsampling_map[i] + unique_comp_up_inds, inverse = np.unique(upsampled_temps, return_inverse=True) + + for j in valid_shifts: + up_shift_inds = cur_up_shift_temp_ix + np.arange(unique_comp_up_inds.size) + upsampled_shifted_template_index[i, j] = up_shift_inds[inverse] + cur_up_shift_temp_ix += up_shift_inds.size + + usti2sti.extend([shifted_temps[j]] * up_shift_inds.size) + usti2ti.extend([i] * up_shift_inds.size) + usti2cui.extend(unique_comp_up_inds) + + up_shift_temp_ix_to_shift_temp_ix = np.array(usti2sti) + up_shift_temp_ix_to_temp_ix = np.array(usti2ti) + up_shift_temp_ix_to_comp_up_ix = np.array(usti2cui) + + return UpsampledShiftedTemplateIndex( + up_shift_temp_ix_to_shift_temp_ix.size, + upsampled_shifted_template_index, + up_shift_temp_ix_to_shift_temp_ix, + up_shift_temp_ix_to_temp_ix, + up_shift_temp_ix_to_comp_up_ix, + ) + + +def compressed_upsampled_pairs( + ix_b, + compression_index, + conv_ix, + temp_ix_b, + shifted_temp_ix_b, + upsampled_shifted_template_index, + compressed_upsampled_temporal, +): + """Add in upsampling to the set of pairs that need to be convolved + + So far, ix_a,b, compression_index, and conv_ix are such that non-upsampled + convolutions between templates ix_a[i], ix_b[i] equal that between templates + ix_a[conv_ix[compression_index[i]]], ix_b[conv_ix[compression_index[i]]]. + + We will upsample the templates in the RHS (b) in a compressed way, so that + each b index gets its own number of duplicates. + """ + up_factor = compressed_upsampled_temporal.compressed_upsampling_map.shape[1] + compression_dup_ix = slice(None) + if up_factor == 1: + upinds = np.zeros(len(conv_ix), dtype=int) + temp_comps = compressed_upsampled_temporal.compressed_upsampled_templates[ + temp_ix_b[ix_b[conv_ix]] + ] + return ix_b, compression_index, conv_ix, upinds, temp_comps, compression_dup_ix + + # each conv_ix needs to be duplicated as many times as its b template has + # upsampled copies + conv_shifted_temp_ix_b = np.atleast_1d(shifted_temp_ix_b[ix_b[conv_ix]]) + upsampling_mask = ( + conv_shifted_temp_ix_b[:, None] + == upsampled_shifted_template_index.up_shift_temp_ix_to_shift_temp_ix[None, :] + ) + conv_up_i, up_shift_up_i = np.nonzero(upsampling_mask) + conv_compressed_upsampled_ix = ( + upsampled_shifted_template_index.up_shift_temp_ix_to_comp_up_ix[ + up_shift_up_i + ] + ) + conv_dup = conv_ix[conv_up_i] + # And, all ix_{a,b}[i] such that compression_ix[i] lands in + # that conv_ix need to be duplicated as well. + dup_mask = conv_ix[compression_index][:, None] == conv_dup[None, :] + if torch.is_tensor(dup_mask): + dup_mask = dup_mask.numpy(force=True) + compression_dup_ix, compression_index_up = np.nonzero(dup_mask) + ix_b_up = ix_b[compression_dup_ix] + + # the conv ix need to be offset to keep the relation with the pairs + # ix_a[old i] + # offsets = np.cumsum((conv_ix[:, None] == conv_dup[None, :]).sum(0)) + # offsets -= offsets[0] + _, offsets = np.unique(compression_dup_ix, return_index=True) + conv_ix_up = offsets[conv_dup] + + # which upsamples and which templates? + conv_upsampling_indices_b = ( + compressed_upsampled_temporal.compressed_index_to_upsampling_index[ + conv_compressed_upsampled_ix + ] + ) + conv_temporal_components_up_b = ( + compressed_upsampled_temporal.compressed_upsampled_templates[ + conv_compressed_upsampled_ix + ] + ) + + return ( + ix_b_up, + compression_index_up, + conv_ix_up, + conv_upsampling_indices_b, + conv_temporal_components_up_b, + compression_dup_ix, + ) + + +def coarse_approximate( + pconv, + units_a, + units_b, + temp_ix_a, + # shift_a, + # shift_b, + spatial_shift_ids, + coarse_approx_error_threshold=0.0, +): + """Try to replace fine (superres+temporally upsampled) convs with coarse ones + + For each pair of convolved units, we first try to replace all of the pairwise + convolutions between these units with their mean, respecting the shifts. + + If that fails, we try to do this in a factorized way: for each superres unit a, + try to replace all of its convolutions with unit b with their mean, respecting + the shifts. + + Above, "respecting the shifts" means we only do this within each shift-deduplication + class, since changes in the sets of channels being convolved cause large changes + in the cross correlation. pconv has already been deduplicated with respect to + equivalent channel neighborhoods, so all that matters for that purpose is the + shift difference. + + This needs to tell the caller how to update its bookkeeping. + """ + if not pconv.numel() or not coarse_approx_error_threshold: + return pconv, np.arange(len(pconv)) + + new_pconv = [] + old_ix_to_new_ix = np.full(len(pconv), -1) + cur_new_ix = 0 + # shift_diff = shift_a - shift_b + for ua in np.unique(units_a): + ina = np.flatnonzero(units_a == ua) + partners_b = np.unique(units_b[ina]) + for ub in partners_b: + inab = ina[units_b[ina] == ub] + dshift = spatial_shift_ids[inab] + for shift in np.unique(dshift): + inshift = inab[dshift == shift] + + convs = pconv[inshift] + meanconv = convs.mean(dim=0, keepdims=True) + if (convs - meanconv).abs().max() < coarse_approx_error_threshold: + # do something + new_pconv.append(meanconv) + old_ix_to_new_ix[inshift] = cur_new_ix + cur_new_ix += 1 + continue + # else: + # # if we don't want the factorized thing... + # new_pconv.append(convs) + # old_ix_to_new_ix[inshift] = np.arange(cur_new_ix, cur_new_ix + inshift.size) + # cur_new_ix += inshift.size + # continue + + active_temp_a = temp_ix_a[inshift] + unique_active_temp_a = np.unique(active_temp_a) + + # TODO just upsampling dedup + # active_temp_b = temp_ix_b[inshift] + # unique_active_temp_b = np.unique(active_temp_b) + # if unique_active_temp_a.size == unique_active_temp_b.size == 1: + if unique_active_temp_a.size == 1: + new_pconv.append(convs) + old_ix_to_new_ix[inshift] = np.arange( + cur_new_ix, cur_new_ix + inshift.size + ) + cur_new_ix += inshift.size + continue + + for tixa in unique_active_temp_a: + insup = active_temp_a == tixa + supconvs = convs[insup] + + meanconv = supconvs.mean(dim=0, keepdims=True) + if (convs - meanconv).abs().max() < coarse_approx_error_threshold: + new_pconv.append(meanconv) + old_ix_to_new_ix[inshift[insup]] = cur_new_ix + cur_new_ix += 1 + else: + new_pconv.append(supconvs) + old_ix_to_new_ix[inshift[insup]] = np.arange( + cur_new_ix, cur_new_ix + insup.sum() + ) + cur_new_ix += insup.sum() + + new_pconv = torch.cat(new_pconv, out=pconv[:cur_new_ix]) + return new_pconv, old_ix_to_new_ix + + +# -- parallelism helpers + + +@dataclass +class ConvWorkerContext: + template_data_a: templates.TemplateData + template_data_b: templates.TemplateData + low_rank_templates_a: template_util.LowRankTemplates + low_rank_templates_b: template_util.LowRankTemplates + compressed_upsampled_temporal: template_util.CompressedUpsampledTemplates + template_shift_index_a: drift_util.TemplateShiftIndex + template_shift_index_b: drift_util.TemplateShiftIndex + cooccurrence: np.ndarray + upsampled_shifted_template_index: UpsampledShiftedTemplateIndex + geom: Optional[np.ndarray] = None + reg_geom: Optional[np.ndarray] = None + geom_kdtree: Optional[KDTree] = None + reg_geom_kdtree: Optional[KDTree] = None + match_distance: Optional[float] = None + conv_ignore_threshold: float = 0.0 + coarse_approx_error_threshold: float = 0.0 + amplitude_scaling_variance: float = 0.0 + amplitude_scaling_boundary: float = 0.5 + reduce_deconv_resid_norm: bool = False + max_shift: Union[int, str] = "full" + batch_size: int = 128 + device: Optional[torch.device] = None + + def __post_init__(self): + # to device + self.compressed_upsampled_temporal.compressed_upsampled_templates = ( + torch.as_tensor( + self.compressed_upsampled_temporal.compressed_upsampled_templates, + device=self.device, + ) + ) + self.low_rank_templates_a.spatial_components = torch.as_tensor( + self.low_rank_templates_a.spatial_components, device=self.device + ) + self.low_rank_templates_a.singular_values = torch.as_tensor( + self.low_rank_templates_a.singular_values, device=self.device + ) + self.low_rank_templates_a.temporal_components = torch.as_tensor( + self.low_rank_templates_a.temporal_components, device=self.device + ) + self.low_rank_templates_b.spatial_components = torch.as_tensor( + self.low_rank_templates_b.spatial_components, device=self.device + ) + self.low_rank_templates_b.singular_values = torch.as_tensor( + self.low_rank_templates_b.singular_values, device=self.device + ) + self.low_rank_templates_b.temporal_components = torch.as_tensor( + self.low_rank_templates_b.temporal_components, device=self.device + ) + + +_conv_worker_context = None + + +def _conv_worker_init(rank_queue, device, kwargs): + global _conv_worker_context + + my_rank = rank_queue.get() + if device is None: + device = "cuda" if torch.cuda.is_available() else "cpu" + device = torch.device(device) + if device.type == "cuda" and device.index is None: + if torch.cuda.device_count() > 1: + device = torch.device("cuda", index=my_rank % torch.cuda.device_count()) + + _conv_worker_context = ConvWorkerContext(device=device, **kwargs) + + +def _conv_job(unit_chunk): + global _conv_worker_context + units_a, units_b = unit_chunk + return compressed_convolve_pairs( + units_a=units_a, units_b=units_b, **asdict_shallow(_conv_worker_context) + ) + + +def asdict_shallow(obj): + return {field.name: getattr(obj, field.name) for field in fields(obj)} diff --git a/src/dartsort/templates/superres_util.py b/src/dartsort/templates/superres_util.py index 80d4c6a1..2884d5a0 100644 --- a/src/dartsort/templates/superres_util.py +++ b/src/dartsort/templates/superres_util.py @@ -13,6 +13,7 @@ def superres_sorting( strategy="drift_pitch_loc_bin", superres_bin_size_um=10.0, min_spikes_per_bin=5, + probe_margin_um=200.0, ): """Construct the spatially superresolved spike train @@ -48,11 +49,20 @@ def superres_sorting( superres_sorting : DARTsortSorting """ pitch = drift_util.get_pitch(geom) - labels = sorting.labels - + full_labels = sorting.labels.copy() + + # remove spikes far away from the probe + if probe_margin_um is not None: + valid = spike_depths_um == np.clip( + spike_depths_um, + geom[:, 1].min() - probe_margin_um, + geom[:, 1].max() + probe_margin_um, + ) + full_labels[~valid] = -1 + # handle triaging - kept = np.flatnonzero(labels >= 0) - labels = labels[kept] + kept = np.flatnonzero(full_labels >= 0) + labels = full_labels[kept] spike_times_s = spike_times_s[kept] spike_depths_um = spike_depths_um[kept] @@ -80,17 +90,15 @@ def superres_sorting( ) else: raise ValueError(f"Unknown superres {strategy=}") - + # handle too-small units superres_labels, superres_to_original = remove_small_superres_units( superres_labels, superres_to_original, min_spikes_per_bin=min_spikes_per_bin ) - - # handle triaging again - full_superres_labels = sorting.labels.copy() - full_superres_labels[kept] = superres_labels - superres_sorting = replace(sorting, labels=full_superres_labels) + # back to un-triaged label space + full_labels[kept] = superres_labels + superres_sorting = replace(sorting, labels=full_labels) return superres_to_original, superres_sorting @@ -108,6 +116,7 @@ def motion_estimate_strategy( displacements = motion_est.disp_at_s(spike_times_s, spike_depths_um) mod_positions = displacements % pitch bin_ids = mod_positions // superres_bin_size_um + bin_ids = bin_ids.astype(int) orig_label_and_bin, superres_labels = np.unique( np.c_[original_labels, bin_ids], axis=0, return_inverse=True ) @@ -128,6 +137,12 @@ def drift_pitch_loc_bin_strategy( ) coarse_reg_depths = spike_depths_um + n_pitches_shift * pitch bin_ids = coarse_reg_depths // superres_bin_size_um + print( + f"{np.isnan(n_pitches_shift).any()=} {np.isfinite(bin_ids).all()=} {superres_bin_size_um=}" + ) + print(f"{bin_ids.min()=} {bin_ids.max()=} {bin_ids.shape=}") + print(f"{original_labels.min()=} {original_labels.max()=} {original_labels.shape=}") + bin_ids = bin_ids.astype(int) orig_label_and_bin, superres_labels = np.unique( np.c_[original_labels, bin_ids], axis=0, return_inverse=True ) @@ -135,8 +150,9 @@ def drift_pitch_loc_bin_strategy( return superres_labels, superres_to_original - -def remove_small_superres_units(superres_labels, superres_to_original, min_spikes_per_bin): +def remove_small_superres_units( + superres_labels, superres_to_original, min_spikes_per_bin +): if not min_spikes_per_bin: return superres_labels, superres_to_original @@ -152,4 +168,4 @@ def remove_small_superres_units(superres_labels, superres_to_original, min_spike superres_labels = relabeling[superres_labels] superres_to_original = superres_to_original[kept_labels] - return superres_labels, superres_to_original \ No newline at end of file + return superres_labels, superres_to_original diff --git a/src/dartsort/templates/template_util.py b/src/dartsort/templates/template_util.py index 7ee3d868..30da5991 100644 --- a/src/dartsort/templates/template_util.py +++ b/src/dartsort/templates/template_util.py @@ -1,8 +1,10 @@ +from dataclasses import dataclass + import numpy as np from dartsort.localize.localize_util import localize_waveforms from dartsort.util import drift_util from dartsort.util.data_util import DARTsortSorting -from dartsort.util.waveform_util import fast_nanmedian +from dartsort.util.spiketorch import fast_nanmedian from scipy.interpolate import interp1d from .get_templates import get_raw_templates, get_templates @@ -55,7 +57,6 @@ def get_registered_templates( denoising_fit_radius=75, denoising_spikes_fit=50_000, denoising_snr_threshold=50.0, - zero_radius_um=None, reducer=fast_nanmedian, random_seed=0, n_jobs=0, @@ -84,7 +85,6 @@ def get_registered_templates( denoising_fit_radius=denoising_fit_radius, denoising_spikes_fit=denoising_spikes_fit, denoising_snr_threshold=denoising_snr_threshold, - zero_radius_um=zero_radius_um, reducer=reducer, random_seed=random_seed, n_jobs=n_jobs, @@ -118,6 +118,23 @@ def get_realigned_sorting( return results["sorting"] +def weighted_average(unit_ids, templates, weights): + n_out = unit_ids.max() + 1 + n_in, t, c = templates.shape + out = np.zeros((n_out, t, c), dtype=templates.dtype) + weights = weights.astype(float) + for i in range(n_out): + which_in = np.flatnonzero(unit_ids == i) + if not which_in.size: + continue + + w = weights[which_in][:, None, None] + w /= w.sum() + out[i] = (w * templates[which_in]).sum(0) + + return out + + # -- template drift handling @@ -126,6 +143,7 @@ def get_template_depths(templates, geom, localization_radius_um=100): templates, geom=geom, radius=localization_radius_um ) template_depths_um = template_locs["z_abs"] + return template_depths_um @@ -136,6 +154,9 @@ def templates_at_time( registered_template_depths_um=None, registered_geom=None, motion_est=None, + return_pitch_shifts=False, + geom_kdtree=None, + match_distance=None, ): if registered_geom is None: return registered_templates @@ -160,16 +181,27 @@ def templates_at_time( n_pitches_shift=pitch_shifts, registered_geom=geom, fill_value=np.nan, + target_kdtree=geom_kdtree, + match_distance=match_distance, ) - assert not np.isnan(unregistered_templates).any() - + if return_pitch_shifts: + return pitch_shifts, unregistered_templates return unregistered_templates # -- template numerical processing -def svd_compress_templates(templates, min_channel_amplitude=1.0, rank=5): +@dataclass +class LowRankTemplates: + temporal_components: np.ndarray + singular_values: np.ndarray + spatial_components: np.ndarray + + +def svd_compress_templates( + templates, min_channel_amplitude=1.0, rank=5, channel_sparse=True +): """ Returns: temporal_components: n_units, spike_length_samples, rank @@ -178,12 +210,37 @@ def svd_compress_templates(templates, min_channel_amplitude=1.0, rank=5): """ vis_mask = templates.ptp(axis=1, keepdims=True) > min_channel_amplitude vis_templates = templates * vis_mask - U, s, Vh = np.linalg.svd(vis_templates, full_matrices=False) - # s is descending. - temporal_components = U[..., :, :rank] - singular_values = s[..., :rank] - spatial_components = Vh[..., :rank, :] - return temporal_components, singular_values, spatial_components + dtype = templates.dtype + + if not channel_sparse: + U, s, Vh = np.linalg.svd(vis_templates, full_matrices=False) + # s is descending. + temporal_components = U[:, :, :rank].astype(dtype) + singular_values = s[:, :rank].astype(dtype) + spatial_components = Vh[:, :rank, :].astype(dtype) + return temporal_components, singular_values, spatial_components + + # channel sparse: only SVD the nonzero channels + # this encodes the same exact subspace as above, and the reconstruction + # error is the same as above as a function of rank. it's just that + # we can zero out some spatial components, which is a useful property + # (used in pairwise convolutions for instance) + n, t, c = templates.shape + temporal_components = np.zeros((n, t, rank), dtype=dtype) + singular_values = np.zeros((n, rank), dtype=dtype) + spatial_components = np.zeros((n, rank, c), dtype=dtype) + for i in range(len(templates)): + template = templates[i] + mask = np.flatnonzero(vis_mask[i, 0]) + k = min(rank, mask.size) + if not k: + continue + U, s, Vh = np.linalg.svd(template[:, mask], full_matrices=False) + temporal_components[i, :, :k] = U[:, :rank] + singular_values[i, :k] = s[:rank] + spatial_components[i, :k, mask] = Vh[:rank].T + + return LowRankTemplates(temporal_components, singular_values, spatial_components) def temporally_upsample_templates( @@ -192,9 +249,130 @@ def temporally_upsample_templates( """Note, also works on temporal components thanks to compatible shape.""" n, t, c = templates.shape tp = np.arange(t).astype(float) - erp = interp1d(tp, templates, axis=1, bounds_error=True) - tup = np.arange(t, step=1. / temporal_upsampling_factor) + erp = interp1d(tp, templates, axis=1, bounds_error=True, kind=kind) + tup = np.arange(t, step=1.0 / temporal_upsampling_factor) tup.clip(0, t - 1, out=tup) upsampled_templates = erp(tup) - upsampled_templates = upsampled_templates.reshape(n, t, temporal_upsampling_factor, c) + upsampled_templates = upsampled_templates.reshape( + n, t, temporal_upsampling_factor, c + ) + upsampled_templates = upsampled_templates.astype(templates.dtype) return upsampled_templates + + +@dataclass +class CompressedUpsampledTemplates: + n_compressed_upsampled_templates: int + compressed_upsampled_templates: np.ndarray + compressed_upsampling_map: np.ndarray + compressed_upsampling_index: np.ndarray + compressed_index_to_template_index: np.ndarray + compressed_index_to_upsampling_index: np.ndarray + + +def default_n_upsamples_map(ptps): + return 4 ** (ptps // 2) + + +def compressed_upsampled_templates( + templates, + ptps=None, + max_upsample=8, + n_upsamples_map=default_n_upsamples_map, + kind="cubic", +): + """compressedly store fewer temporally upsampled copies of lower amplitude templates + + Returns + ------- + A CompressedUpsampledTemplates object with fields: + compressed_upsampled_templates : array (n_compressed_upsampled_templates, spike_length_samples, n_channels) + compressed_upsampling_map : array (n_templates, max_upsample) + compressed_upsampled_templates[compressed_upsampling_map[unit, j]] is an approximation + of the jth upsampled template for this unit. for low-amplitude units, + compressed_upsampling_map[unit] will have fewer unique entries, corresponding + to fewer saved upsampled copies for that unit. + compressed_upsampling_index : array (n_templates, max_upsample) + A n_compressed_upsampled_templates-padded ragged array mapping each + template index to its compressed upsampled indices + compressed_index_to_template_index + compressed_index_to_upsampling_index + """ + n_templates = templates.shape[0] + if max_upsample == 1: + return CompressedUpsampledTemplates( + n_templates, + templates, + np.arange(n_templates)[:, None], + np.arange(n_templates)[:, None], + np.arange(n_templates), + np.zeros(n_templates, dtype=int), + ) + + # how many copies should each unit get? + # sometimes users may pass temporal SVD components in instead of templates, + # so we allow them to pass in the amplitudes of the actual templates + if ptps is None: + ptps = templates.ptp(1).max(1) + assert ptps.shape == (n_templates,) + if n_upsamples_map is None: + n_upsamples = np.full(n_templates, max_upsample) + else: + n_upsamples = np.clip(n_upsamples_map(ptps), 1, max_upsample).astype(int) + + # build the compressed upsampling map + compressed_upsampling_map = np.full((n_templates, max_upsample), -1, dtype=int) + compressed_upsampling_index = np.full((n_templates, max_upsample), -1, dtype=int) + template_indices = [] + upsampling_indices = [] + current_compressed_index = 0 + for i, nup in enumerate(n_upsamples): + compression = max_upsample // nup + nup = max_upsample // compression # handle divisibility failure + + # new compressed indices + compressed_upsampling_map[i] = current_compressed_index + np.arange(nup).repeat( + compression + ) + compressed_upsampling_index[i, :nup] = current_compressed_index + np.arange(nup) + current_compressed_index += nup + + # indices of the templates to keep in the full array of upsampled templates + template_indices.extend([i] * nup) + upsampling_indices.extend(compression * np.arange(nup)) + assert (compressed_upsampling_map >= 0).all() + assert ( + np.unique(compressed_upsampling_map).size + == (compressed_upsampling_index >= 0).sum() + == compressed_upsampling_map.max() + 1 + == compressed_upsampling_index.max() + 1 + == current_compressed_index + ) + template_indices = np.array(template_indices) + upsampling_indices = np.array(upsampling_indices) + compressed_upsampling_index[ + compressed_upsampling_index < 0 + ] = current_compressed_index + + # get the upsampled templates + all_upsampled_templates = temporally_upsample_templates( + templates, temporal_upsampling_factor=max_upsample, kind=kind + ) + # n, up, t, c + all_upsampled_templates = all_upsampled_templates.transpose(0, 2, 1, 3) + rix = np.ravel_multi_index( + (template_indices, upsampling_indices), all_upsampled_templates.shape[:2] + ) + all_upsampled_templates = all_upsampled_templates.reshape( + n_templates * max_upsample, templates.shape[1], templates.shape[2] + ) + compressed_upsampled_templates = all_upsampled_templates[rix] + + return CompressedUpsampledTemplates( + current_compressed_index, + compressed_upsampled_templates, + compressed_upsampling_map, + compressed_upsampling_index, + template_indices, + upsampling_indices, + ) diff --git a/src/dartsort/templates/templates.py b/src/dartsort/templates/templates.py index d36ff123..6ff43c34 100644 --- a/src/dartsort/templates/templates.py +++ b/src/dartsort/templates/templates.py @@ -1,17 +1,14 @@ -from dataclasses import dataclass +from dataclasses import dataclass, replace +from pathlib import Path from typing import Optional import numpy as np -from pathlib import Path +from dartsort.util import drift_util from .get_templates import get_templates from .superres_util import superres_sorting -from .template_util import ( - get_registered_templates, - get_realigned_sorting, - get_template_depths, -) -from dartsort.util import drift_util +from .template_util import (get_realigned_sorting, get_template_depths, + weighted_average) _motion_error_prefix = ( "If template_config has registered_templates==True " @@ -22,31 +19,65 @@ @dataclass class TemplateData: + # (n_templates, spike_length_samples, n_registered_channels or n_channels) templates: np.ndarray + # (n_templates,) maps template index to unit index (multiple templates can share a unit index) unit_ids: np.ndarray + # (n_templates,) spike count for each template + spike_counts: np.ndarray + registered_geom: Optional[np.ndarray] = None registered_template_depths_um: Optional[np.ndarray] = None - + localization_radius_um: float = 100.0 + @classmethod def from_npz(cls, npz_path): - with np.load(npz_path) as npz: - templates = npz["templates"] - unit_ids = npz["unit_ids"] - registered_geom = registered_template_depths_um = None - if "registered_geom" in npz: - registered_geom = npz["registered_geom"] - if "registered_template_depths_um" in npz: - registered_template_depths_um = npz["registered_template_depths_um"] - return cls(templates, unit_ids, registered_geom, registered_template_depths_um) - + with np.load(npz_path) as data: + return cls(**data) + def to_npz(self, npz_path): - to_save = dict(templates=self.templates, unit_ids=self.unit_ids) + to_save = dict( + templates=self.templates, + unit_ids=self.unit_ids, + spike_counts=self.spike_counts, + ) if self.registered_geom is not None: to_save["registered_geom"] = self.registered_geom if self.registered_template_depths_um is not None: - to_save["registered_template_depths_um"] = self.registered_template_depths_um + to_save[ + "registered_template_depths_um" + ] = self.registered_template_depths_um + if not npz_path.parent.exists(): + npz_path.parent.mkdir() np.savez(npz_path, **to_save) - + + def coarsen(self, with_locs=True): + """Weighted average all templates that share a unit id and re-localize.""" + # update templates + unit_ids_unique, flat_ids = np.unique(self.unit_ids, return_inverse=True) + templates = weighted_average(flat_ids, self.templates, self.spike_counts) + + # collect spike counts + spike_counts = np.zeros(len(templates)) + np.add.at(spike_counts, flat_ids, self.spike_counts) + + # re-localize + registered_template_depths_um = None + if with_locs: + registered_template_depths_um = get_template_depths( + templates, + self.registered_geom, + localization_radius_um=self.localization_radius_um, + ) + + return replace( + self, + templates=templates, + unit_ids=unit_ids_unique, + spike_counts=spike_counts, + registered_template_depths_um=registered_template_depths_um, + ) + @classmethod def from_config( cls, @@ -59,7 +90,7 @@ def from_config( save_npz_name="template_data.npz", localizations_dataset_name="point_source_localizations", n_jobs=0, - device=None, + device=None, ): if save_folder is not None: save_folder = Path(save_folder) @@ -67,8 +98,8 @@ def from_config( save_folder.mkdir() npz_path = save_folder / save_npz_name if npz_path.exists() and not overwrite: - return cls.from_npz(npz_path) - + return cls.from_npz(npz_path) + motion_aware = ( template_config.registered_templates or template_config.superres_templates ) @@ -95,6 +126,7 @@ def from_config( trough_offset_samples=template_config.trough_offset_samples, spike_length_samples=template_config.spike_length_samples, spikes_per_unit=template_config.spikes_per_unit, + # realign handled in advance below, not needed in kwargs # realign_peaks=template_config.realign_peaks, realign_max_sample_shift=template_config.realign_max_sample_shift, denoising_rank=template_config.denoising_rank, @@ -139,8 +171,15 @@ def from_config( min_spikes_per_bin=template_config.superres_bin_min_spikes, ) else: + # we don't skip empty units unit_ids = np.arange(sorting.labels.max() + 1) + # count spikes in each template + spike_counts = np.zeros_like(unit_ids) + ix, counts = np.unique(sorting.labels, return_counts=True) + spike_counts[ix[ix >= 0]] = counts[ix >= 0] + + # main! results = get_templates(recording, sorting, **kwargs) # handle registered templates @@ -153,15 +192,18 @@ def from_config( obj = cls( results["templates"], unit_ids, + spike_counts, kwargs["registered_geom"], registered_template_depths_um, + localization_radius_um=template_config.registered_template_localization_radius_um, ) else: - obj = cls( + obj = cls( results["templates"], unit_ids, + spike_counts, ) - + if save_folder is not None: obj.to_npz(npz_path) diff --git a/src/dartsort/transform/all_transformers.py b/src/dartsort/transform/all_transformers.py index fd7901ee..bf20d743 100644 --- a/src/dartsort/transform/all_transformers.py +++ b/src/dartsort/transform/all_transformers.py @@ -1,6 +1,6 @@ from .amplitudes import AmplitudeVector, MaxAmplitude from .enforce_decrease import EnforceDecrease -from .localize import PointSourceLocalization +from .localize import Localization, PointSourceLocalization from .single_channel_denoiser import SingleChannelWaveformDenoiser from .temporal_pca import TemporalPCADenoiser, TemporalPCAFeaturizer from .transform_base import Waveform @@ -13,6 +13,7 @@ SingleChannelWaveformDenoiser, TemporalPCADenoiser, TemporalPCAFeaturizer, + Localization, PointSourceLocalization, ] diff --git a/src/dartsort/transform/localize.py b/src/dartsort/transform/localize.py index 29786441..fa7acf37 100644 --- a/src/dartsort/transform/localize.py +++ b/src/dartsort/transform/localize.py @@ -5,7 +5,7 @@ from .transform_base import BaseWaveformFeaturizer -class PointSourceLocalization(BaseWaveformFeaturizer): +class Localization(BaseWaveformFeaturizer): """Order of output columns: x, y, z_abs, alpha""" default_name = "point_source_localizations" @@ -22,6 +22,7 @@ def __init__( amplitude_kind="peak", name=None, name_prefix="", + localization_model="pointsource", ): assert amplitude_kind in ("peak", "ptp") super().__init__( @@ -34,6 +35,7 @@ def __init__( self.radius = radius self.n_channels_subset = n_channels_subset self.logbarrier = logbarrier + self.localization_model = localization_model def transform(self, waveforms, max_channels=None): # get amplitude vectors @@ -52,6 +54,7 @@ def transform(self, waveforms, max_channels=None): n_channels_subset=self.n_channels_subset, logbarrier=self.logbarrier, dtype=self.dtype, + model=self.localization_model, ) localizations = torch.column_stack( @@ -63,3 +66,5 @@ def transform(self, waveforms, max_channels=None): ] ) return localizations + +PointSourceLocalization = Localization \ No newline at end of file diff --git a/src/dartsort/util/data_util.py b/src/dartsort/util/data_util.py index 16c8969c..bd4b4b57 100644 --- a/src/dartsort/util/data_util.py +++ b/src/dartsort/util/data_util.py @@ -72,9 +72,8 @@ def to_numpy_sorting(self): def __str__(self): name = self.__class__.__name__ nspikes = self.times_samples.size - units = np.unique(self.labels) - units = units[units >= 0] - unit_str = f"{units.size} unit" + ("s" if units.size > 1 else "") + nunits = (np.unique(self.labels) >= 0).sum() + unit_str = f"{nunits} unit" + "s" * (nunits > 1) feat_str = "" if self.extra_features: feat_str = ", ".join(self.extra_features.keys()) @@ -149,6 +148,7 @@ def check_recording( dedup_channel_index = make_channel_index( rec.get_channel_locations(), dedup_spatial_radius ) + failed = False # run detection and compute spike detection rate and data range spike_rates = [] @@ -173,6 +173,7 @@ def check_recording( "you experience memory issues.", RuntimeWarning, ) + failed = True if max_abs > expected_value_range: warn( @@ -180,5 +181,6 @@ def check_recording( "check that your data has been preprocessed, including standardization.", RuntimeWarning, ) + failed = True - return avg_detections_per_second, max_abs + return failed, avg_detections_per_second, max_abs diff --git a/src/dartsort/util/drift_util.py b/src/dartsort/util/drift_util.py index 76748c7a..7c3dce5d 100644 --- a/src/dartsort/util/drift_util.py +++ b/src/dartsort/util/drift_util.py @@ -10,12 +10,15 @@ by integer numbers of pitches. As many shifted copies are created as needed to capture all the drift. """ +from dataclasses import dataclass + import numpy as np import torch from scipy.spatial import KDTree from scipy.spatial.distance import pdist -from .waveform_util import fast_nanmedian, get_pitch +from .spiketorch import fast_nanmedian +from .waveform_util import get_pitch # -- registered geometry and templates helpers @@ -184,6 +187,8 @@ def registered_template( weights = valid[:, None, :] * counts[:, None, None] weights = weights / np.maximum(weights.sum(0), 1) template = (np.nan_to_num(static_templates) * weights).sum(0) + dtype = str(waveforms.dtype).split(".")[1] if is_tensor else waveforms.dtype + template = template.astype(dtype) template[:, ~valid.any(0)] = np.nan if not np.isnan(pad_value): template = np.nan_to_num(template, copy=False, nan=pad_value) @@ -230,7 +235,9 @@ def invert_motion_estimate(motion_est, t_s, registered_depths_um): hasattr(motion_est, "spatial_bin_centers_um") and motion_est.spatial_bin_centers_um is not None ): + # nonrigid motion bin_centers = motion_est.spatial_bin_centers_um + t_s = np.full(bin_centers.shape, t_s) bin_center_disps = motion_est.disp_at_s(t_s, depth_um=bin_centers) # registered_bin_centers = motion_est.correct_s(t_s, depths_um=bin_centers) registered_bin_centers = bin_centers - bin_center_disps @@ -239,6 +246,7 @@ def invert_motion_estimate(motion_est, t_s, registered_depths_um): registered_depths_um, registered_bin_centers, bin_center_disps ) else: + # rigid motion disps = motion_est.disp_at_s(t_s) return registered_depths_um + disps @@ -374,11 +382,19 @@ def get_waveforms_on_static_channels( # scatter the waveforms into their static channel neighborhoods if out is None: - static_waveforms = np.full( - (n_spikes, t, n_static_channels + 1), - fill_value=fill_value, - dtype=waveforms.dtype, - ) + if torch.is_tensor(waveforms): + static_waveforms = torch.full( + (n_spikes, t, n_static_channels + 1), + fill_value=fill_value, + dtype=waveforms.dtype, + device=waveforms.device, + ) + else: + static_waveforms = np.full( + (n_spikes, t, n_static_channels + 1), + fill_value=fill_value, + dtype=waveforms.dtype, + ) else: assert out.shape == (n_spikes, t, n_static_channels + 1) out.fill(fill_value) @@ -404,12 +420,22 @@ def _full_probe_shifting_fast( fill_value, out=None, ): + is_tensor = torch.is_tensor(waveforms) + if out is None: - static_waveforms = np.full( - (*waveforms.shape[:2], target_kdtree.n + 1), - fill_value=fill_value, - dtype=waveforms.dtype, - ) + if is_tensor: + static_waveforms = torch.full( + (*waveforms.shape[:2], target_kdtree.n + 1), + fill_value=fill_value, + dtype=waveforms.dtype, + device=waveforms.device, + ) + else: + static_waveforms = np.full( + (*waveforms.shape[:2], target_kdtree.n + 1), + fill_value=fill_value, + dtype=waveforms.dtype, + ) else: assert out.shape == (*waveforms.shape[:2], target_kdtree.n + 1) out.fill(fill_value) @@ -434,3 +460,130 @@ def _full_probe_shifting_fast( shifted_channels[shift_inverse][:, None, :], ] = waveforms return static_waveforms[:, :, : target_kdtree.n] + + +# -- which templates appear at which shifts in a recording? +# and, which pairs of shifted templates appear together? + + +@dataclass +class TemplateShiftIndex: + """Return value for get_shift_and_unit_pairs""" + + n_shifted_templates: int + # shift index -> shift + all_pitch_shifts: np.ndarray + # (template ix, shift index) -> shifted template index + template_shift_index: np.ndarray + # (shifted temp ix, shifted temp ix) -> did these appear at the same time + shifted_temp_ix_to_temp_ix: np.ndarray + shifted_temp_ix_to_shift: np.ndarray + + @classmethod + def from_shift_matrix(cls, shifts): + """shift: n_times x n_templates""" + all_shifts = np.unique(shifts) + n_templates = shifts.shape[1] + pairs = np.stack(np.broadcast_arrays(np.arange(n_templates)[None, :], shifts), axis=2) + pairs = np.unique(pairs.reshape(shifts.size, 2), axis=0) + n_shifted_templates = len(pairs) + shift_ix = np.searchsorted(all_shifts, pairs[:, 1]) + template_shift_index = np.full( + (n_templates, len(all_shifts)), n_shifted_templates + ) + template_shift_index[pairs[:, 0], shift_ix] = np.arange(n_shifted_templates) + return cls( + n_shifted_templates, + all_shifts, + template_shift_index, + *pairs.T, + ) + + def shifts_to_shifted_ids(self, template_ids, shifts): + shift_ixs = np.searchsorted(self.all_pitch_shifts, shifts) + return self.template_shift_index[template_ids, shift_ixs] + + +def static_template_shift_index(n_templates): + temp_ixs = np.arange(n_templates) + return TemplateShiftIndex( + n_templates, + np.zeros(1), + temp_ixs[:, None], + temp_ixs, + np.zeros_like(temp_ixs), + ) + + +def get_shift_and_unit_pairs( + chunk_time_centers_s, + geom, + template_data_a, + template_data_b=None, + motion_est=None, +): + if template_data_b is None: + template_data_b = template_data_a + + na = template_data_a.templates.shape[0] + nb = template_data_b.templates.shape[0] + + if motion_est is None: + shift_index_a = static_template_shift_index(na) + shift_index_b = static_template_shift_index(nb) + cooccurrence = np.ones((na, nb), dtype=bool) + return shift_index_a, shift_index_b, cooccurrence + + reg_depths_um_a = template_data_a.registered_template_depths_um + reg_depths_um_b = template_data_b.registered_template_depths_um + same = np.array_equal(reg_depths_um_a, reg_depths_um_b) + if same: + reg_depths_um = reg_depths_um_a + else: + reg_depths_um = np.concatenate((reg_depths_um_a, reg_depths_um_b)) + + # figure out all shifts for all units at all times + unreg_depths_um = np.stack( + [ + invert_motion_estimate( + motion_est, t_s, reg_depths_um + ) + for t_s in chunk_time_centers_s + ], + axis=0, + ) + assert unreg_depths_um.shape == (len(chunk_time_centers_s), len(reg_depths_um)) + diff = reg_depths_um - unreg_depths_um + pitch_shifts = get_spike_pitch_shifts( + depths_um=reg_depths_um, + pitch=get_pitch(geom), + registered_depths_um=unreg_depths_um, + ) + if same: + shifts_a = shifts_b = pitch_shifts + else: + shifts_a = pitch_shifts[:, :na] + shifts_b = pitch_shifts[:, na:] + + # assign ids to pitch/shift pairs + template_shift_index_a = TemplateShiftIndex.from_shift_matrix(shifts_a) + if same: + template_shift_index_b = template_shift_index_a + else: + template_shift_index_b = TemplateShiftIndex.from_shift_matrix(shifts_b) + + # co-occurrence matrix: do these shifted templates appear together? + cooccurrence = np.zeros( + (template_shift_index_a.n_shifted_templates, template_shift_index_b.n_shifted_templates), + dtype=bool) + temps_a = np.arange(na) + temps_b = np.arange(nb) + for j in range(len(chunk_time_centers_s)): + shifted_ids_a = template_shift_index_a.shifts_to_shifted_ids(temps_a, shifts_a[j]) + if same: + shifted_ids_b = shifted_ids_a + else: + shifted_ids_b = template_shift_index_b.shifts_to_shifted_ids(temps_b, shifts_b[j]) + cooccurrence[shifted_ids_a[:, None], shifted_ids_b[None, :]] = 1 + + return template_shift_index_a, template_shift_index_b, cooccurrence diff --git a/src/dartsort/util/multiprocessing_util.py b/src/dartsort/util/multiprocessing_util.py index dcbfb429..efc8a691 100644 --- a/src/dartsort/util/multiprocessing_util.py +++ b/src/dartsort/util/multiprocessing_util.py @@ -2,6 +2,8 @@ from concurrent.futures import ProcessPoolExecutor from multiprocessing import get_context +# TODO: torch.multiprocessing? + try: import cloudpickle except ImportError: diff --git a/src/dartsort/util/spikeio.py b/src/dartsort/util/spikeio.py index 8f678c0e..0430b82a 100644 --- a/src/dartsort/util/spikeio.py +++ b/src/dartsort/util/spikeio.py @@ -21,7 +21,7 @@ def read_full_waveforms( assert times_samples.dtype.kind == "i" assert ( times_samples.max() - < recording.get_num_samples() + <= recording.get_num_samples() - (spike_length_samples - trough_offset_samples) ) n_channels = recording.get_num_channels() @@ -92,7 +92,7 @@ def read_subset_waveforms( assert times_samples.dtype.kind == "i" assert ( times_samples.max() - < recording.get_num_samples() + <= recording.get_num_samples() - (spike_length_samples - trough_offset_samples) ) n_channels = recording.get_num_channels() @@ -169,7 +169,7 @@ def read_waveforms_channel_index( assert times_samples.min() >= trough_offset_samples assert ( times_samples.max() - < recording.get_num_samples() + <= recording.get_num_samples() - (spike_length_samples - trough_offset_samples) ) n_channels = recording.get_num_channels() diff --git a/src/dartsort/util/spiketorch.py b/src/dartsort/util/spiketorch.py index b892958f..aeb305b9 100644 --- a/src/dartsort/util/spiketorch.py +++ b/src/dartsort/util/spiketorch.py @@ -1,9 +1,24 @@ +import math + import torch import torch.nn.functional as F +from scipy.signal._signaltools import _calc_oa_lens from torch.fft import irfft, rfft +def fast_nanmedian(x, axis=-1): + is_tensor = torch.is_tensor(x) + x = torch.nanmedian(torch.as_tensor(x), dim=axis).values + if is_tensor: + return x + else: + return x.numpy() + + def ptp(waveforms, dim=1): + is_tensor = torch.is_tensor(waveforms) + if not is_tensor: + return waveforms.ptp(axis=dim) return waveforms.max(dim=dim).values - waveforms.min(dim=dim).values @@ -24,8 +39,6 @@ def ravel_multi_index(multi_index, dims): Indices into the flattened tensor of shape `dims` """ assert len(multi_index) == len(dims) - if any(torch.any((ix < 0) | (ix >= d)) for ix, d in zip(multi_index, dims)): - raise ValueError("Out of bounds indices in ravel_multi_index") # collect multi indices multi_index = torch.broadcast_tensors(*multi_index) @@ -53,9 +66,10 @@ def add_at_(dest, ix, src, sign=1): src = src.neg() elif sign != 1: src = sign * src + flat_ix = ravel_multi_index(ix, dest.shape) dest.view(-1).scatter_add_( 0, - ravel_multi_index(ix, dest.shape), + flat_ix, src.reshape(-1), ) @@ -88,6 +102,25 @@ def grab_spikes( return traces[time_ix[:, :, None], chan_ix[:, None, :]] +def grab_spikes_full( + traces, + trough_times, + trough_offset=42, + spike_length_samples=121, + buffer=0, +): + """Grab spikes from a tensor of traces""" + assert trough_times.ndim == 1 + spike_sample_offsets = torch.arange( + buffer - trough_offset, + buffer - trough_offset + spike_length_samples, + device=trough_times.device, + ) + time_ix = trough_times[:, None] + spike_sample_offsets[None, :] + chan_ix = torch.arange(traces.shape[1], device=traces.device) + return traces[time_ix[:, :, None], chan_ix[None, None, :]] + + def add_spikes_( traces, trough_times, @@ -215,6 +248,83 @@ def real_resample(x, num, dim=0): # inverse transform y = irfft(g, num, dim=dim) - y *= (float(num) / float(Nx)) + y *= float(num) / float(Nx) return y + + +def steps_and_pad(s1, in1_step, s2, in2_step, block_size, overlap): + shape_final = s1 + s2 - 1 + # figure out n steps and padding + if s1 > in1_step: + nstep1 = math.ceil((s1 + 1) / in1_step) + if (block_size - overlap) * nstep1 < shape_final: + nstep1 += 1 + + pad1 = nstep1 * in1_step - s1 + else: + nstep1 = 1 + pad1 = 0 + + if s2 > in2_step: + nstep2 = math.ceil((s2 + 1) / in2_step) + if (block_size - overlap) * nstep2 < shape_final: + nstep2 += 1 + + pad2 = nstep2 * in2_step - s2 + else: + nstep2 = 1 + pad2 = 0 + return nstep1, pad1, nstep2, pad2 + + +def depthwise_oaconv1d(input, weight, f2=None, padding=0): + """Depthwise correlation (F.conv1d with groups=in_chans) with overlap-add""" + # conv on last axis + # assert input.ndim == weight.ndim == 2 + n1 = input.shape[0] + n2 = weight.shape[0] + assert n1 == n2 + s1 = input.shape[1] + s2 = weight.shape[1] + assert s1 >= s2 + + # shape_full = s1 + s2 - 1 + block_size, overlap, in1_step, in2_step = _calc_oa_lens(s1, s2) + nstep1, pad1, nstep2, pad2 = steps_and_pad( + s1, in1_step, s2, in2_step, block_size, overlap + ) + + if pad1 > 0: + input = F.pad(input, (0, pad1)) + input = input.reshape(n1, nstep1, in1_step) + + # freq domain correlation + f1 = torch.fft.rfft(input, n=block_size) + if f2 is None: + f2 = torch.fft.rfft(weight, n=block_size) + # .conj() here to do cross-correlation instead of convolution (time reversal property of rfft) + f1.mul_(f2.conj()[:, None, :]) + res = torch.fft.irfft(f1, n=block_size) + + # overlap add part with torch + fold_input = res.reshape(n1, nstep1, block_size).permute(0, 2, 1) + fold_out_len = nstep1 * in1_step + overlap + fold_res = F.fold( + fold_input, + output_size=(1, fold_out_len), + kernel_size=(1, block_size), + stride=(1, in1_step), + ) + assert fold_res.shape == (n1, 1, 1, fold_out_len) + + oa = fold_res.reshape(n1, fold_out_len) + # this is the full convolution + # oa = oa[:, : shape_final] + # extract correct padding + valid_len = s1 - s2 + 1 + valid_start = s2 - 1 + assert valid_start >= padding + oa = oa[:, valid_start - padding:valid_start + valid_len + padding] + + return oa diff --git a/src/dartsort/util/waveform_util.py b/src/dartsort/util/waveform_util.py index 4acee974..a1394734 100644 --- a/src/dartsort/util/waveform_util.py +++ b/src/dartsort/util/waveform_util.py @@ -335,14 +335,3 @@ def get_channel_subset( npx.arange(N)[:, None], rel_sub_channel_index[max_channels][:, :], ] - - -# -- general util - -def fast_nanmedian(x, axis=-1): - is_tensor = torch.is_tensor(x) - x = torch.nanmedian(torch.as_tensor(x), dim=axis).values - if is_tensor: - return x - else: - return x.numpy() diff --git a/src/dartsort/vis/scatterplots.py b/src/dartsort/vis/scatterplots.py index f639394b..67406451 100644 --- a/src/dartsort/vis/scatterplots.py +++ b/src/dartsort/vis/scatterplots.py @@ -24,6 +24,8 @@ def scatter_spike_features( amplitude_cmap=plt.cm.viridis, max_spikes_plot=500_000, probe_margin_um=100, + t_min=-np.inf, + t_max=np.inf, s=1, linewidth=0, limits="probe_margin", @@ -56,14 +58,14 @@ def scatter_spike_features( amplitudes = h5["denoised_amplitudes"][:] geom = h5["geom"][:] - to_show = None + to_show = np.flatnonzero(np.clip(times_s, t_min, t_max) == times_s) if geom is not None: - to_show = np.flatnonzero( - (depths_um > geom[:, 1].min() - probe_margin_um) - & (depths_um < geom[:, 1].max() + probe_margin_um) - & (x > geom[:, 0].min() - probe_margin_um) - & (x < geom[:, 0].max() + probe_margin_um) - ) + to_show = to_show[ + (depths_um[to_show] > geom[:, 1].min() - probe_margin_um) + & (depths_um[to_show] < geom[:, 1].max() + probe_margin_um) + & (x[to_show] > geom[:, 0].min() - probe_margin_um) + & (x[to_show] < geom[:, 0].max() + probe_margin_um) + ] _, s_x = scatter_x_vs_depth( x=x, @@ -128,7 +130,7 @@ def scatter_spike_features( to_show=to_show, **scatter_kw, ) - + if label_axes: axes[0].set_ylabel("depth (um)") axes[0].set_xlabel("x (um)") @@ -166,6 +168,8 @@ def scatter_time_vs_depth( the times_s, depths_um, and (one of) amplitudes or labels as arrays, or alternatively, these can be left unset and they will be loaded from hdf5_filename when it is supplied. + + Returns: axis, scatter """ if hdf5_filename is not None: with h5py.File(hdf5_filename, "r") as h5: diff --git a/src/spike_psvae/chunk_features.py b/src/spike_psvae/chunk_features.py index e7940148..a0defb89 100644 --- a/src/spike_psvae/chunk_features.py +++ b/src/spike_psvae/chunk_features.py @@ -392,6 +392,7 @@ def transform( else: if self.ptp_precision_decimals is not None: ptps = np.round(ptps, decimals=self.ptp_precision_decimals) + ( xs, ys, @@ -487,7 +488,7 @@ def raw_fit(self, wfs, max_channels): self.needs_fit = False self.dtype = self.tpca.components_.dtype - self.n_components = self.tpca.n_components + self.n_components = self.n_components self.components_ = self.tpca.components_ self.mean_ = self.tpca.mean_ if self.centered: # otherwise SVD diff --git a/src/spike_psvae/denoise.py b/src/spike_psvae/denoise.py index a42e992f..4ecf62cf 100644 --- a/src/spike_psvae/denoise.py +++ b/src/spike_psvae/denoise.py @@ -95,7 +95,9 @@ def phase_shift_and_hallucination_idx_preshift(waveforms_roll_denoise, waveforms which = slice(offset-10, offset+10) - d_s_corr = wfs_corr(waveforms_roll_denoise[:, which], waveforms_roll[:, which])#torch.sum(wfs_denoised[which]*chan_wfs[which], 1)/torch.sqrt(torch.sum(chan_wfs[which]*chan_wfs[which],1) * torch.sum(wfs_denoised[which]*wfs_denoised[which],1)) ## didn't use which at the beginning! check whether this changes the results + d_s_corr = wfs_corr(waveforms_roll_denoise[:, which], waveforms_roll[:, which]) + # torch.sum(wfs_denoised[which]*chan_wfs[which], 1)/torch.sqrt(torch.sum(chan_wfs[which]*chan_wfs[which],1) * torch.sum(wfs_denoised[which]*wfs_denoised[which],1)) + # didn't use which at the beginning! check whether this changes the results halu_idx = (ptp(waveforms_roll_denoise, 1) decr_ptp[parents_rel].max(): decr_ptp[c] *= decr_ptp[parents_rel].max() / decr_ptp[c] + # decreasing_ptps[i] = decr_ptp # apply decreasing ptps to the original waveforms rescale = (decreasing_ptps / orig_ptps)[:, None, :] + if is_torch: rescale = torch.as_tensor(rescale, device=waveforms.device) if in_place: diff --git a/src/spike_psvae/hybrid_analysis.py b/src/spike_psvae/hybrid_analysis.py index 8305b3c6..57a0d48c 100644 --- a/src/spike_psvae/hybrid_analysis.py +++ b/src/spike_psvae/hybrid_analysis.py @@ -1836,7 +1836,7 @@ def calc_template_snrs( spike_length_samples=spike_length_samples, buffer=wf_buffer, ) - denominator = np.abs(np.einsum("ij,nij->n", t, noise) / C).mean() + denominator = np.abs(np.einsum("ij,nij->n", t, noise) / C).std()#.mean() snrs.append(numerator / denominator) return np.array(snrs) diff --git a/src/spike_psvae/localize_index.py b/src/spike_psvae/localize_index.py index ea8708a9..b90f8d73 100644 --- a/src/spike_psvae/localize_index.py +++ b/src/spike_psvae/localize_index.py @@ -65,9 +65,9 @@ def ptp_at_dipole(x1, y1, z1, alpha, x2, y2, z2): ) - 1 / np.sqrt( - np.square(x2 - local_geom[:, 0]) - + np.square(z2 - local_geom[:, 1]) - + np.square(y2) + np.square(x2 + x1 - local_geom[:, 0]) + + np.square(z2 + z1 - local_geom[:, 1]) + + np.square(y2 + y1) ) ) return ptp_dipole_out @@ -107,18 +107,24 @@ def mse(loc): # - (np.log1p(10.0 * y) / 10000.0 if logbarrier else 0) # ) - def mse_dipole(x_in): - x1 = x_in[0] - y1 = x_in[1] - z1 = x_in[2] - x2 = x_in[3] - y2 = x_in[4] - z2 = x_in[5] - q = ptp_at_dipole(x1, y1, z1, 1.0, x2, y2, z2) - alpha = (q * ptp).sum() / (q * q).sum() - return np.square( - ptp - ptp_at_dipole(x1, y1, z1, alpha, x2, y2, z2) - ).mean() - (np.log1p(10.0 * y1) / 10000.0 if logbarrier else 0) + def mse_dipole(loc): + x, y, z = loc + # q = ptp_at(x, y, z, 1.0) + # alpha = (q * (ptp / maxptp - delta)).sum() / (q * q).sum() + duv = np.c_[ + x - local_geom[:, 0], + np.broadcast_to(y, ptp.shape), + z - local_geom[:, 1], + ] + X = duv / np.power(np.square(duv).sum(axis=1, keepdims=True), 3/2) + beta = np.linalg.solve(X.T @ X, X.T @ (ptp / maxptp)) + qtq = X @ beta + return ( + np.square(ptp / maxptp - qtq).mean() + # np.square(ptp / maxptp - delta - ptp_at(x, y, z, alpha)).mean() + # np.square(np.maximum(0, ptp / maxptp - ptp_at(x, y, z, alpha))).mean() + - np.log1p(10.0 * y) / 10000.0 + ) if model == "pointsource": result = minimize( @@ -146,24 +152,51 @@ def mse_dipole(x_in): result = minimize( mse_dipole, - x0=[xcom, Y0, zcom, xcom + 1, Y0 + 1, zcom + 1], + x0=[xcom, Y0, zcom], bounds=[ (local_geom[:, 0].min() - DX, local_geom[:, 0].max() + DX), (1e-4, 250), (-DZ, DZ), - (-100, 100), - (-100, 100), - (-100, 100), ], ) # print(result) - bx, by, bz_rel, bpx, bpy, bpz = result.x - - q = ptp_at_dipole(bx, by, bz_rel, 1.0, bpx, bpy, bpz) - - balpha = (q * ptp).sum() / (q * q).sum() - return bx, by, bz_rel, balpha + bx, by, bz_rel = result.x + + duv = np.c_[ + bx - local_geom[:, 0], + np.broadcast_to(by, ptp.shape), + bz_rel - local_geom[:, 1], + ] + X = duv / np.power(np.square(duv).sum(axis=1, keepdims=True), 3/2) + beta = np.linalg.solve(X.T @ X, X.T @ (ptp / maxptp)) + beta /= np.sqrt(np.square(beta).sum()) + dipole_planar_direction = np.sqrt(np.square(beta[[0, 2]]).sum()) + closest_chan = np.square(duv).sum(1).argmin() + min_duv = duv[closest_chan] + + val_th = np.sqrt(np.square(min_duv).sum())/dipole_planar_direction + + # reparameterized_dist = np.sqrt(np.square(min_duv[0]/beta[2]) + np.square(min_duv[2]/beta[0]) + # + np.square(min_duv[1]/beta[1])) + + if val_th<250: + return bx, by, bz_rel, val_th + else: + result = minimize( + mse, + x0=[xcom, Y0, zcom], + bounds=[ + (local_geom[:, 0].min() - DX, local_geom[:, 0].max() + DX), + (1e-4, 250), + (-DZ, DZ), + ], + ) + # print(result) + bx, by, bz_rel = result.x + q = ptp_at(bx, by, bz_rel, 1.0) + balpha = (ptp * q).sum() / np.square(q).sum() + return bx, by, bz_rel, val_th else: raise NameError("Wrong localization model") @@ -230,6 +263,5 @@ def localize_ptps_index( ys[n] = y z_rels[n] = z_rel alphas[n] = alpha - z_abss = z_rels + geom[maxchans, 1] return xs, ys, z_rels, z_abss, alphas diff --git a/src/spike_psvae/subtract.py b/src/spike_psvae/subtract.py index d1e3d62a..879b3645 100644 --- a/src/spike_psvae/subtract.py +++ b/src/spike_psvae/subtract.py @@ -668,10 +668,10 @@ def subtraction_binary( n_channels = geom.shape[0] recording = sc.read_binary( - standardized_bin, - sampling_rate, - n_channels, - binary_dtype, + file_paths=standardized_bin, + sampling_frequency=sampling_rate, + num_channels=n_channels, + dtype=binary_dtype, time_axis=time_axis, is_filtered=True, ) @@ -1077,7 +1077,7 @@ def subtraction_batch( batch_data_folder / f"{prefix}{f.name}.npy", feat, ) - + denoised_wfs = full_denoising( cleaned_wfs, spike_index[:, 1], diff --git a/tests/test_grab_and_featurize.py b/tests/test_grab_and_featurize.py index 67abb6ae..eb036136 100644 --- a/tests/test_grab_and_featurize.py +++ b/tests/test_grab_and_featurize.py @@ -143,7 +143,7 @@ def test_grab_and_featurize(): fit_radius=10, ), transform.Waveform(channel_index, name="tpca_waveforms"), - transform.PointSourceLocalization( + transform.Localization( channel_index=channel_index, geom=geom, radius=50.0 ), ] @@ -249,8 +249,14 @@ def test_grab_and_featurize(): assert np.array_equal(h5["channel_index"][()], channel_index) assert h5["last_chunk_start"][()] == 90_000 - # this is kind of a good test of reproducibility/random seeds - assert np.array_equal(locs0, locs1) + # this is kind of a good test of reproducibility + # totally reproducible on CPU, suprprisingly large diffs on GPU + if not torch.cuda.is_available(): + assert np.array_equal(locs0, locs1) + else: + valid = np.clip(locs1[:, 2], geom[:,1].min(), geom[:,1].max()) + valid = locs1[:, 2] == valid + assert np.isclose(locs0[valid], locs1[valid], atol=1e-6).all() if __name__ == "__main__": diff --git a/tests/test_matching.py b/tests/test_matching.py new file mode 100644 index 00000000..89c110a3 --- /dev/null +++ b/tests/test_matching.py @@ -0,0 +1,472 @@ +import numpy as np +import spikeinterface.full as si +import torch +import torch.nn.functional as F +from dartsort import config, main +from dartsort.templates import TemplateData, template_util +from dredge import motion_util +from test_util import no_overlap_recording_sorting + +nofeatcfg = config.FeaturizationConfig( + do_nn_denoise=False, + do_tpca_denoise=False, + do_enforce_decrease=False, + denoise_only=True, +) + +spike_length_samples = 121 +trough_offset_samples = 42 + + +def test_tiny(tmp_path): + recording_length_samples = 200 + n_channels = 2 + geom = np.c_[np.zeros(2), np.arange(2)] + geom + + # template main channel traces + trace0 = 50 * np.exp( + -(((np.arange(spike_length_samples) - trough_offset_samples) / 10) ** 2) + ) + + # templates + templates = np.zeros((2, spike_length_samples, n_channels), dtype="float32") + templates[0, :, 0] = trace0 + templates[1, :, 1] = trace0 + + # spike train + # fmt: off + tcl = [ + 50, 0, 0, + 51, 1, 1, + ] + # fmt: on + times, channels, labels = np.array(tcl).reshape(-1, 3).T + rec = np.zeros((recording_length_samples, n_channels), dtype="float32") + for t, l in zip(times, labels): + rec[ + t - trough_offset_samples : t - trough_offset_samples + spike_length_samples + ] += templates[l] + rec = si.NumpyRecording(rec, 30_000) + rec.set_dummy_probe_from_locations(geom) + + template_config = config.TemplateConfig( + low_rank_denoising=False, + superres_bin_min_spikes=0, + ) + template_data = TemplateData.from_config( + *no_overlap_recording_sorting(templates), + template_config, + motion_est=motion_util.IdentityMotionEstimate(), + n_jobs=0, + save_folder=tmp_path, + overwrite=True, + ) + + matcher = main.ObjectiveUpdateTemplateMatchingPeeler.from_config( + rec, + main.MatchingConfig( + threshold=0.01, + template_temporal_upsampling_factor=1, + ), + nofeatcfg, + template_data, + motion_est=motion_util.IdentityMotionEstimate(), + ) + matcher.precompute_peeling_data(tmp_path) + res = matcher.peel_chunk( + torch.from_numpy(rec.get_traces().copy()), + return_residual=True, + return_conv=True, + ) + + ixa, ixb, pconv = matcher.pairwise_conv_db.query( + [0, 1], [0, 1], upsampling_indices_b=[0, 0], grid=True + ) + maxpc = pconv.max(dim=1).values + for ia, ib, pc in zip(ixa, ixb, maxpc): + assert np.isclose(pc, (templates[ia] * templates[ib]).sum()) + assert res["n_spikes"] == len(times) + assert np.array_equal(res["times_samples"], times) + assert np.array_equal(res["labels"], labels) + assert np.isclose( + torch.square(res["residual"]).mean(), + 0.0, + ) + assert np.isclose( + torch.square(res["conv"]).mean(), + 0.0, + atol=1e-5, + ) + + matcher = main.ObjectiveUpdateTemplateMatchingPeeler.from_config( + rec, + main.MatchingConfig( + threshold=0.01, + template_temporal_upsampling_factor=8, + ), + nofeatcfg, + template_data, + motion_est=motion_util.IdentityMotionEstimate(), + ) + matcher.precompute_peeling_data(tmp_path) + res = matcher.peel_chunk( + torch.from_numpy(rec.get_traces().copy()), + return_residual=True, + return_conv=True, + ) + assert res["n_spikes"] == len(times) + assert np.array_equal(res["times_samples"], times) + assert np.array_equal(res["labels"], labels) + assert np.array_equal(res["upsampling_indices"], [0, 0]) + assert np.isclose( + torch.square(res["residual"]).mean(), + 0.0, + ) + assert np.isclose( + torch.square(res["conv"]).mean(), + 0.0, + atol=1e-5, + ) + + +def static_tester(tmp_path, up_factor=1): + recording_length_samples = 40_011 + n_channels = 2 + geom = np.c_[np.zeros(2), np.arange(2)] + geom + + # template main channel traces + trace0 = 50 * np.exp( + -(((np.arange(spike_length_samples) - trough_offset_samples) / 10) ** 2) + ) + trace1 = 250 * np.exp( + -(((np.arange(spike_length_samples) - trough_offset_samples) / 30) ** 2) + ) + + # templates + templates = np.zeros((3, spike_length_samples, n_channels), dtype="float32") + templates[0, :, 0] = trace0 + templates[1, :, 0] = trace1 + templates[2, :, 1] = trace0 + + # spike train + # fmt: off + tcl = [ + 100, 0, 0, + 150, 0, 0, + 151, 1, 2, + 500, 0, 1, + 2000, 0, 0, + 2001, 0, 1, + 35000, 1, 2, + 35001, 0, 1, + ] + # fmt: on + times, channels, labels = np.array(tcl).reshape(-1, 3).T + rec = np.zeros((recording_length_samples, n_channels), dtype="float32") + for t, l in zip(times, labels): + rec[ + t - trough_offset_samples : t - trough_offset_samples + spike_length_samples + ] += templates[l] + rec = si.NumpyRecording(rec, 30_000) + rec.set_dummy_probe_from_locations(geom) + + template_config = config.TemplateConfig( + low_rank_denoising=False, superres_bin_min_spikes=0 + ) + template_data = TemplateData.from_config( + *no_overlap_recording_sorting(templates), + template_config, + motion_est=motion_util.IdentityMotionEstimate(), + n_jobs=0, + save_folder=tmp_path, + overwrite=True, + ) + + matcher = main.ObjectiveUpdateTemplateMatchingPeeler.from_config( + rec, + main.MatchingConfig( + threshold=0.01, + template_temporal_upsampling_factor=up_factor, + coarse_approx_error_threshold=0.0, + conv_ignore_threshold=0.0, + template_svd_compression_rank=2, + ), + nofeatcfg, + template_data, + motion_est=motion_util.IdentityMotionEstimate(), + ) + matcher.precompute_peeling_data(tmp_path) + + lrt = template_util.svd_compress_templates( + template_data.templates, rank=matcher.svd_compression_rank + ) + tempup = template_util.compressed_upsampled_templates( + lrt.temporal_components, + ptps=template_data.templates.ptp(1).max(1), + max_upsample=up_factor, + ) + assert np.array_equal(matcher.compressed_upsampled_temporal, tempup.compressed_upsampled_templates) + assert np.array_equal(matcher.objective_spatial_components, lrt.spatial_components) + assert np.array_equal(matcher.objective_singular_values, lrt.singular_values) + assert np.array_equal(matcher.spatial_components, lrt.spatial_components) + assert np.array_equal(matcher.singular_values, lrt.singular_values) + for up in range(up_factor): + ixa, ixb, pconv = matcher.pairwise_conv_db.query( + np.arange(3), + np.arange(3), + upsampling_indices_b=up + np.zeros(3, dtype=int), + grid=True, + ) + centerpc = pconv[:, spike_length_samples - 1] + for ia, ib, pc, pcf in zip(ixa, ixb, centerpc, pconv): + tempupb = tempup.compressed_upsampled_templates[ + tempup.compressed_upsampling_map[ib, up] + ] + tupb = (tempupb * lrt.singular_values[ib]) @ lrt.spatial_components[ib] + tc = (templates[ia] * tupb).sum() + + template_a = torch.as_tensor(templates[ia][None]) + ssb = lrt.singular_values[ib][:, None] * lrt.spatial_components[ib] + conv_filt = torch.bmm(torch.as_tensor(ssb[None]), template_a.mT) + conv_filt = conv_filt[:, None] # (nco, 1, rank, t) + conv_in = torch.as_tensor(tempupb[None]).mT[None] + pconv_ = F.conv2d( + conv_in, conv_filt, padding=(0, 120), groups=1 + ) + pconv1 = pconv_.squeeze()[spike_length_samples - 1].numpy(force=True) + assert torch.isclose(pcf, pconv_).all() + + pconv2 = F.conv2d( + torch.as_tensor(templates[ia])[None, None], + torch.as_tensor(tupb)[None, None], + ).squeeze().numpy(force=True) + assert np.isclose(pconv2, tc) + assert np.isclose(pc, tc) + assert np.isclose(pconv1, pc) + + res = matcher.peel_chunk( + torch.from_numpy(rec.get_traces().copy()), + return_residual=True, + return_conv=True, + ) + + assert res["n_spikes"] == len(times) + assert np.array_equal(res["times_samples"], times) + assert np.array_equal(res["labels"], labels) + assert np.isclose( + torch.square(res["residual"]).mean(), + 0.0, + ) + assert np.isclose( + torch.square(res["conv"]).mean(), + 0.0, + atol=1e-4, + ) + + +def test_static_noup(tmp_path): + static_tester(tmp_path) + + +def test_static_up(tmp_path): + static_tester(tmp_path, up_factor=4) + + +def drifting_tester(tmp_path, up_factor=1): + recording_length_samples = 40_011 + n_channels = 2 + geom = np.c_[np.zeros(2), np.arange(2)] + geom + + # template main channel traces + trace0 = 50 * np.exp( + -(((np.arange(spike_length_samples) - trough_offset_samples) / 10) ** 2) + ) + trace1 = 250 * np.exp( + -(((np.arange(spike_length_samples) - trough_offset_samples) / 30) ** 2) + ) + + # templates + templates = np.zeros((3, spike_length_samples, n_channels), dtype="float32") + templates[0, :, 0] = trace0 + templates[1, :, 0] = trace1 + templates[2, :, 1] = trace0 + + # spike train + # fmt: off + tcl = [ + 100, 0, 0, + 150, 0, 0, + 151, 1, 2, + 500, 0, 1, + 2000, 0, 0, + 2001, 0, 1, + 25000, 1, 2, + 25001, 0, 1, + ] + # fmt: on + times, channels, labels = np.array(tcl).reshape(-1, 3).T + rec = np.zeros((recording_length_samples, n_channels), dtype="float32") + for t, l in zip(times, labels): + rec[ + t - trough_offset_samples : t - trough_offset_samples + spike_length_samples + ] += templates[l] + rec = si.NumpyRecording(rec, 30_000) + rec.set_dummy_probe_from_locations(geom) + + template_config = config.TemplateConfig( + low_rank_denoising=False, superres_bin_min_spikes=0 + ) + template_data = TemplateData.from_config( + *no_overlap_recording_sorting(templates), + template_config, + motion_est=motion_util.IdentityMotionEstimate(), + n_jobs=0, + save_folder=tmp_path, + overwrite=True, + ) + + matcher = main.ObjectiveUpdateTemplateMatchingPeeler.from_config( + rec, + main.MatchingConfig( + threshold=0.01, + template_temporal_upsampling_factor=up_factor, + coarse_approx_error_threshold=0.0, + conv_ignore_threshold=0.0, + template_svd_compression_rank=2, + ), + nofeatcfg, + template_data, + motion_est=motion_util.IdentityMotionEstimate(), + ) + matcher.precompute_peeling_data(tmp_path) + + # tup = template_util.compressed_upsampled_templates( + # template_data.templates, max_upsample=up_factor + # ) + lrt = template_util.svd_compress_templates( + template_data.templates, rank=matcher.svd_compression_rank + ) + tempup = template_util.compressed_upsampled_templates( + lrt.temporal_components, + ptps=template_data.templates.ptp(1).max(1), + max_upsample=up_factor, + ) + print(f"{lrt.temporal_components.shape=}") + print(f"{lrt.singular_values.shape=}") + print(f"{lrt.spatial_components.shape=}") + assert np.array_equal(matcher.compressed_upsampled_temporal, tempup.compressed_upsampled_templates) + assert np.array_equal(matcher.objective_spatial_components, lrt.spatial_components) + assert np.array_equal(matcher.objective_singular_values, lrt.singular_values) + assert np.array_equal(matcher.spatial_components, lrt.spatial_components) + assert np.array_equal(matcher.singular_values, lrt.singular_values) + for up in range(up_factor): + ixa, ixb, pconv = matcher.pairwise_conv_db.query( + np.arange(3), + np.arange(3), + upsampling_indices_b=up + np.zeros(3, dtype=int), + grid=True, + ) + centerpc = pconv[:, spike_length_samples - 1] + for ia, ib, pc, pcf in zip(ixa, ixb, centerpc, pconv): + # tupb = tup.compressed_upsampled_templates[ + # tup.compressed_upsampling_map[ib, up] + # ] + tempupb = tempup.compressed_upsampled_templates[ + tempup.compressed_upsampling_map[ib, up] + ] + tupb = (tempupb * lrt.singular_values[ib]) @ lrt.spatial_components[ib] + tc = (templates[ia] * tupb).sum() + + template_a = torch.as_tensor(templates[ia][None]) + ssb = lrt.singular_values[ib][:, None] * lrt.spatial_components[ib] + conv_filt = torch.bmm(torch.as_tensor(ssb[None]), template_a.mT) + conv_filt = conv_filt[:, None] # (nco, 1, rank, t) + conv_in = torch.as_tensor(tempupb[None]).mT[None] + pconv_ = F.conv2d( + conv_in, conv_filt, padding=(0, 120), groups=1 + ) + # print(f"{torch.abs(pcf - pconv_).max()=}") + pconv1 = pconv_.squeeze()[spike_length_samples - 1].numpy(force=True) + assert torch.isclose(pcf, pconv_).all() + + pconv2 = F.conv2d( + torch.as_tensor(templates[ia])[None, None], + torch.as_tensor(tupb)[None, None], + ).squeeze().numpy(force=True) + assert np.isclose(pconv2, tc) + + # print(f" - {ia=} {ib=}") + # print(f" {pc=} {tc=} {pconv1=} {pconv2=}") + # print(f" {pcf[120]=} {pcf[121]=} {pcf[122]=}") + # print(f" ~ {np.isclose(pc, tc)=}") + # print(f" {np.isclose(pconv1, pc)=} {np.isclose(tc, pconv2)=}") + assert np.isclose(pc, tc) + assert np.isclose(pconv1, pc) + + res = matcher.peel_chunk( + torch.from_numpy(rec.get_traces().copy()), + return_residual=True, + return_conv=True, + ) + + print() + print() + print(f"{len(times)=}") + print(f"{res['n_spikes']=}") + print() + print(f'{torch.square(res["residual"]).mean()=}') + print(f'{torch.abs(res["residual"]).max()=}') + print(f'{torch.square(res["conv"]).mean()=}') + print(f'{torch.abs(res["conv"]).max()=}') + print(f'{res["conv"].min()=} {res["conv"].max()=}') + tnsq = np.linalg.norm(templates, axis=(1, 2)) ** 2 + print(f"{res['conv'].shape=} {tnsq.shape=}") + print(f'{(2*res["conv"] - tnsq[:,None]).max()=}') + print() + print(f'{res["times_samples"]=}') + print(f"{times=}") + print() + print(f'{res["labels"]=}') + print(f"{labels=}") + print() + print(f'{np.c_[res["times_samples"], res["labels"], res["upsampling_indices"]]=}') + print(f"{np.c_[times, labels]=}") + print() + print(f'{res["upsampling_indices"]=}') + + assert res["n_spikes"] == len(times) + assert np.array_equal(res["times_samples"], times) + assert np.array_equal(res["labels"], labels) + print(f"{torch.square(res['residual']).mean()=}") + print(f"{torch.square(res['conv']).mean()=}") + assert np.isclose( + torch.square(res["residual"]).mean(), + 0.0, + ) + assert np.isclose( + torch.square(res["conv"]).mean(), + 0.0, + atol=1e-4, + ) + + +if __name__ == "__main__": + import tempfile + from pathlib import Path + + print("test tiny") + with tempfile.TemporaryDirectory() as tdir: + test_tiny(Path(tdir)) + + print() + print("test test_static_noup") + with tempfile.TemporaryDirectory() as tdir: + test_static_noup(Path(tdir)) + + print() + print("test test_static_up") + with tempfile.TemporaryDirectory() as tdir: + test_static_up(Path(tdir)) diff --git a/tests/test_templates.py b/tests/test_templates.py index f6faad40..b8ff4e67 100644 --- a/tests/test_templates.py +++ b/tests/test_templates.py @@ -1,9 +1,38 @@ +import tempfile +from pathlib import Path + import numpy as np import spikeinterface.core as sc from dartsort import config -from dartsort.templates import get_templates, template_util, templates +from dartsort.templates import (get_templates, pairwise, pairwise_util, + template_util, templates) +from dartsort.util import drift_util from dartsort.util.data_util import DARTsortSorting -from dredge.motion_util import get_motion_estimate +from dredge.motion_util import IdentityMotionEstimate, get_motion_estimate +from test_util import no_overlap_recording_sorting + + +def test_roundtrip(tmp_path): + rg = np.random.default_rng(0) + temps = rg.normal(size=(11, 121, 384)).astype(np.float32) + template_data = templates.TemplateData.from_config( + *no_overlap_recording_sorting(temps, pad=0), + template_config=config.TemplateConfig( + low_rank_denoising=False, + superres_bin_min_spikes=0, + realign_peaks=False, + ), + motion_est=IdentityMotionEstimate(), + n_jobs=0, + save_folder=tmp_path, + overwrite=True, + ) + print(f"{np.abs(template_data.templates - temps).max()=}") + print(f"{np.abs(template_data.templates - temps).mean()=}") + print(f"{np.abs(template_data.templates - temps).min()=}") + print(f"{template_data.templates.ptp(1).max(1)=}") + print(f"{temps.ptp(1).max(1)=}") + assert np.array_equal(template_data.templates, temps) def test_static_templates(): @@ -130,17 +159,189 @@ def test_main_object(): labels=[0, 0, 1, 1], channels=[0, 0, 0, 0], sampling_frequency=1, - extra_features=dict(point_source_localizations=np.zeros((4, 4)), times_seconds=[0, 2, 6, 8]), + extra_features=dict( + point_source_localizations=np.zeros((4, 4)), times_seconds=[0, 2, 6, 8] + ), ) tdata = templates.TemplateData.from_config( rec, sorting, - config.TemplateConfig(trough_offset_samples=0, spike_length_samples=2, realign_peaks=False), + config.TemplateConfig( + trough_offset_samples=0, + spike_length_samples=2, + realign_peaks=False, + superres_templates=False, + denoising_rank=2, + ), motion_est=me, ) +def test_pconv(): + # want to make sure drift handling is as expected + # design an experiment + + # 4 chans, no drift + # 3 units (superres): 0 (0,1), 1 (2,3), 3 (4) + # temps overlap like: + # 0 chan=0 z=0 + # 12 1 1 + # 23 2 2 + # 4 3 3 + t = 2 + c = 4 + temps = np.zeros((5, t, c), dtype=np.float32) + temps[0, 0, 0] = 2 + temps[1, 0, 1] = 2 + temps[2, 0, [1, 2]] = 2 + temps[3, 0, 2] = 2 + temps[4, 0, 3] = 2 + geom = np.c_[np.zeros(c), np.arange(c).astype(float)] + overlaps = {(i, i): np.square(temps[i]).sum() for i in range(5)} + overlaps[(1, 2)] = overlaps[(2, 1)] = (temps[1] * temps[2]).sum() + overlaps[(2, 3)] = overlaps[(3, 2)] = (temps[3] * temps[2]).sum() + + print(f"--------- no drift") + tdata = templates.TemplateData( + templates=temps, + unit_ids=np.array([0, 0, 1, 1, 2]), + spike_counts=np.ones(5), + registered_geom=None, + registered_template_depths_um=None, + ) + svd_compressed = template_util.svd_compress_templates(temps, rank=1) + ctempup = template_util.compressed_upsampled_templates( + svd_compressed.temporal_components, + ptps=temps.ptp(1).max(1), + max_upsample=1, + kind="cubic", + ) + + with tempfile.TemporaryDirectory() as tdir: + pconvdb_path = pairwise_util.compressed_convolve_to_h5( + Path(tdir) / "test.h5", + geom=geom, + template_data=tdata, + low_rank_templates=svd_compressed, + compressed_upsampled_temporal=ctempup, + ) + pconvdb = pairwise.CompressedPairwiseConv.from_h5(pconvdb_path) + assert (pconvdb.pconv[0] == 0.0).all() + + for tixa in range(5): + for tixb in range(5): + ixa, ixb, pconv = pconvdb.query(tixa, tixb) + if (tixa, tixb) not in overlaps: + assert not ixa.numel() + assert not ixb.numel() + assert not pconv.numel() + continue + + olap = overlaps[tixa, tixb] + assert (ixa, ixb) == (tixa, tixb) + assert np.isclose(pconv.max(), olap) + + # drifting version + # rigid drift from -1 to 0 to 1, note pitch=1 + # same templates but padded + print(f"--------- rigid drift") + tempspad = np.pad(temps, [(0, 0), (0, 0), (1, 1)]) + svd_compressed = template_util.svd_compress_templates(tempspad, rank=1) + reg_geom = np.c_[np.zeros(c + 2), np.arange(c + 2).astype(float)] + tdata = templates.TemplateData( + templates=tempspad, + unit_ids=np.array([0, 0, 1, 1, 2]), + spike_counts=np.ones(5), + registered_geom=reg_geom, + registered_template_depths_um=np.zeros(5), + ) + geom = np.c_[np.zeros(c), np.arange(1, c + 1).astype(float)] + motion_est = get_motion_estimate(time_bin_centers_s=np.array([0., 1, 2]), displacement=[-1., 0, 1]) + + # visualize shifted temps + # for tix in range(5): + # for shift in (-1, 0, 1): + # spatial_shifted = drift_util.get_waveforms_on_static_channels( + # spat[tix][None], + # reg_geom, + # n_pitches_shift=np.array([shift]), + # registered_geom=geom, + # fill_value=0.0, + # ) + # print(f"{shift=}") + # print(f"{spatial_shifted=}") + + with tempfile.TemporaryDirectory() as tdir: + pconvdb_path = pairwise_util.compressed_convolve_to_h5( + Path(tdir) / "test.h5", + geom=geom, + template_data=tdata, + low_rank_templates=svd_compressed, + compressed_upsampled_temporal=ctempup, + motion_est=motion_est, + chunk_time_centers_s=[0, 1, 2], + ) + pconvdb = pairwise.CompressedPairwiseConv.from_h5(pconvdb_path) + assert (pconvdb.pconv[0] == 0.0).all() + print(f"{pconvdb.pconv.shape=}") + + for tixa in range(5): + for tixb in range(5): + ixa, ixb, pconv = pconvdb.query(tixa, tixb, shifts_a=0, shifts_b=0) + + if (tixa, tixb) not in overlaps: + assert not ixa.numel() + assert not ixb.numel() + assert not pconv.numel() + continue + + olap = overlaps[tixa, tixb] + assert (ixa, ixb) == (tixa, tixb) + assert np.isclose(pconv.max(), olap) + + for tixb in range(5): + for shiftb in (-1, 0, 1): + ixa, ixb, pconv = pconvdb.query(0, tixb, shifts_a=-1, shifts_b=shiftb) + assert not ixa.numel() + assert not ixb.numel() + assert not pconv.numel() + + for tixb in range(5): + for shift in (-1, 0, 1): + ixa, ixb, pconv = pconvdb.query(4, tixb, shifts_a=shift, shifts_b=shift) + if tixb != 4 or shift == 1: + assert not ixa.numel() + assert not ixb.numel() + assert not pconv.numel() + else: + assert np.isclose(pconv.max(), 4 if shift < 1 else 0) + ixa, ixb, pconv = pconvdb.query(tixb, 4, shifts_a=shift, shifts_b=shift) + if tixb != 4 or shift == 1: + assert not ixa.numel() + assert not ixb.numel() + assert not pconv.numel() + else: + assert np.isclose(pconv.max(), 4) + + for shifta in (-1, 0, 1): + for shiftb in (-1, 0, 1): + for tixa in range(5): + for tixb in range(5): + ixa, ixb, pconv = pconvdb.query(tixa, tixb, shifts_a=shifta, shifts_b=shiftb) + if shifta != shiftb: + # this is because we are rigid here + assert not ixa.numel() + assert not ixb.numel() + assert not pconv.numel() + + if __name__ == "__main__": - test_static_templates() - test_drifting_templates() - test_main_object() + import tempfile + from pathlib import Path + + with tempfile.TemporaryDirectory() as tdir: + test_roundtrip(Path(tdir)) + # test_static_templates() + # test_drifting_templates() + # test_main_object() + # test_pconv() diff --git a/tests/test_util.py b/tests/test_util.py new file mode 100644 index 00000000..2d84e4e3 --- /dev/null +++ b/tests/test_util.py @@ -0,0 +1,26 @@ +import numpy as np +import spikeinterface.core as sc +from dartsort.util.data_util import DARTsortSorting + + +def no_overlap_recording_sorting(templates, fs=30000, trough_offset_samples=42, pad=0): + n_templates, spike_length_samples, n_channels = templates.shape + rec = templates.reshape(n_templates * spike_length_samples, n_channels) + if pad > 0: + rec = np.pad(rec, [(pad, pad), (0, 0)]) + geom = np.c_[np.zeros(n_channels), np.arange(n_channels)] + rec = sc.NumpyRecording(rec, fs) + rec.set_dummy_probe_from_locations(geom) + depths = np.zeros(n_templates) + locs = np.c_[np.zeros_like(depths), np.zeros_like(depths), depths] + times = np.arange(n_templates) * spike_length_samples + trough_offset_samples + times_seconds = times / fs + sorting = DARTsortSorting( + times + pad, + np.zeros(n_templates), + np.arange(n_templates), + extra_features=dict( + point_source_localizations=locs, times_seconds=times_seconds + ), + ) + return rec, sorting