diff --git a/src/dartsort/cluster/merge.py b/src/dartsort/cluster/merge.py index e51a711e..53a02bd2 100644 --- a/src/dartsort/cluster/merge.py +++ b/src/dartsort/cluster/merge.py @@ -18,6 +18,7 @@ def merge_templates( motion_est=None, max_shift_samples=20, superres_linkage=np.max, + sym_function=np.minimum, merge_distance_threshold=0.25, temporal_upsampling_factor=8, amplitude_scaling_variance=0.0, @@ -68,6 +69,50 @@ def merge_templates( save_npz_name=template_npz_filename, ) + units, dists, shifts, template_snrs = calculate_merge_distances( + template_data, + superres_linkage=superres_linkage, + sym_function=sym_function, + max_shift_samples=max_shift_samples, + temporal_upsampling_factor=temporal_upsampling_factor, + amplitude_scaling_variance=amplitude_scaling_variance, + amplitude_scaling_boundary=amplitude_scaling_boundary, + svd_compression_rank=svd_compression_rank, + min_channel_amplitude=min_channel_amplitude, + conv_batch_size=conv_batch_size, + units_batch_size=units_batch_size, + device=device, + n_jobs=n_jobs, + show_progress=show_progress, + ) + + # now run hierarchical clustering + return recluster( + sorting, + units, + dists, + shifts, + template_snrs, + merge_distance_threshold=merge_distance_threshold, + ) + + +def calculate_merge_distances( + template_data, + superres_linkage=np.max, + sym_function=np.minimum, + max_shift_samples=20, + temporal_upsampling_factor=8, + amplitude_scaling_variance=0.0, + amplitude_scaling_boundary=0.5, + svd_compression_rank=10, + min_channel_amplitude=0.0, + conv_batch_size=128, + units_batch_size=8, + device=None, + n_jobs=0, + show_progress=True, +): # allocate distance + shift matrices. shifts[i,j] is trough[j]-trough[i]. n_templates = template_data.templates.shape[0] sup_dists = np.full((n_templates, n_templates), np.inf) @@ -116,15 +161,9 @@ def merge_templates( template_data.templates.ptp(1).max(1) / template_data.spike_counts ) - # now run hierarchical clustering - return recluster( - sorting, - units, - dists, - shifts, - template_snrs, - merge_distance_threshold=merge_distance_threshold, - ) + dists = sym_function(dists, dists.T) + + return units, dists, shifts, template_snrs def recluster( @@ -134,9 +173,7 @@ def recluster( shifts, template_snrs, merge_distance_threshold=0.25, - sym_function=np.minimum, ): - dists = sym_function(dists, dists.T) # upper triangle not including diagonal, aka condensed distance matrix in scipy pdist = dists[np.triu_indices(dists.shape[0], k=1)] diff --git a/src/dartsort/config.py b/src/dartsort/config.py index 44e8473f..27373fef 100644 --- a/src/dartsort/config.py +++ b/src/dartsort/config.py @@ -31,7 +31,6 @@ default_pretrained_path = default_pretrained_path.joinpath("single_chan_denoiser.pt") -# TODO: integrate this in the other configs @dataclass(frozen=True) class WaveformConfig: """Defaults yield 42 sample trough offset and 121 total at 30kHz.""" @@ -44,7 +43,10 @@ def trough_offset_samples(self, sampling_frequency=30_000): def spike_length_samples(self, sampling_frequency=30_000): spike_len_ms = self.ms_before + self.ms_after - return int(spike_len_ms * (sampling_frequency / 1000)) + length = int(spike_len_ms * (sampling_frequency / 1000)) + # odd is better for convolution arithmetic elsewhere + length = 2 * (length // 2) + 1 + return length @dataclass(frozen=True) @@ -103,8 +105,6 @@ class FeaturizationConfig: @dataclass(frozen=True) class SubtractionConfig: - trough_offset_samples: int = 42 - spike_length_samples: int = 121 detection_thresholds: List[int] = (10, 8, 6, 5, 4) 
    chunk_length_samples: int = 30_000
     peak_sign: str = "both"
@@ -147,8 +147,6 @@ class MotionEstimationConfig:
 
 @dataclass(frozen=True)
 class TemplateConfig:
-    trough_offset_samples: int = 42
-    spike_length_samples: int = 121
     spikes_per_unit = 500
 
     # -- template construction parameters
@@ -177,8 +175,6 @@ class TemplateConfig:
 
 @dataclass(frozen=True)
 class MatchingConfig:
-    trough_offset_samples: int = 42
-    spike_length_samples: int = 121
     chunk_length_samples: int = 30_000
     extract_radius: float = 200.0
     n_chunks_fit: int = 40
@@ -235,6 +231,7 @@ class SplitMergeConfig:
     merge_distance_threshold: float = 0.25
 
 
+default_waveform_config = WaveformConfig()
 default_featurization_config = FeaturizationConfig()
 default_subtraction_config = SubtractionConfig()
 default_template_config = TemplateConfig()
@@ -242,3 +239,4 @@ class SplitMergeConfig:
 default_split_merge_config = SplitMergeConfig()
 coarse_template_config = TemplateConfig(superres_templates=False)
 default_matching_config = MatchingConfig()
+default_motion_estimation_config = MotionEstimationConfig()
diff --git a/src/dartsort/main.py b/src/dartsort/main.py
index c9448daf..8923a865 100644
--- a/src/dartsort/main.py
+++ b/src/dartsort/main.py
@@ -1,3 +1,4 @@
+from dataclasses import asdict
 from pathlib import Path
 
 from dartsort.cluster.initial import ensemble_chunks
@@ -6,14 +7,16 @@
 from dartsort.config import (default_clustering_config,
                              default_featurization_config,
                              default_matching_config,
+                             default_motion_estimation_config,
                              default_split_merge_config,
                              default_subtraction_config,
-                             default_template_config)
+                             default_template_config, default_waveform_config)
 from dartsort.peel import (ObjectiveUpdateTemplateMatchingPeeler,
                            SubtractionPeeler)
 from dartsort.templates import TemplateData
 from dartsort.util.data_util import check_recording
 from dartsort.util.peel_util import run_peeler
+from dartsort.util.registration_util import estimate_motion
 
 
 def dartsort_from_config(
@@ -26,7 +29,9 @@ def dartsort(
     recording,
     output_directory,
+    waveform_config=default_waveform_config,
     featurization_config=default_featurization_config,
+    motion_estimation_config=default_motion_estimation_config,
     subtraction_config=default_subtraction_config,
-    matching_config=default_subtraction_config,
+    matching_config=default_matching_config,
     template_config=default_template_config,
@@ -43,6 +48,7 @@ def dartsort(
     sorting, sub_h5 = subtract(
         recording,
         output_directory,
+        waveform_config=waveform_config,
         featurization_config=featurization_config,
         subtraction_config=subtraction_config,
         n_jobs=n_jobs,
@@ -50,8 +56,14 @@ def dartsort(
         device=device,
     )
     if motion_est is None:
-        # TODO
-        motion_est = estimate_motion()
+        motion_est = estimate_motion(
+            recording,
+            sorting,
+            output_directory,
+            overwrite=overwrite,
+            device=device,
+            **asdict(motion_estimation_config),
+        )
     sorting = cluster(
         sub_h5,
         recording,
@@ -78,6 +90,7 @@ def dartsort(
         output_directory,
         motion_est=motion_est,
         template_config=template_config,
+        waveform_config=waveform_config,
         featurization_config=featurization_config,
         matching_config=matching_config,
         n_jobs_templates=n_jobs,
@@ -95,6 +108,7 @@ def dartsort(
 def subtract(
     recording,
     output_directory,
+    waveform_config=default_waveform_config,
     featurization_config=default_featurization_config,
     subtraction_config=default_subtraction_config,
     chunk_starts_samples=None,
@@ -109,6 +123,7 @@ def subtract(
     check_recording(recording)
     subtraction_peeler = SubtractionPeeler.from_config(
         recording,
+        waveform_config=waveform_config,
         subtraction_config=subtraction_config,
featurization_config=featurization_config, ) @@ -174,6 +189,7 @@ def match( sorting=None, output_directory=None, motion_est=None, + waveform_config=default_waveform_config, template_config=default_template_config, featurization_config=default_featurization_config, matching_config=default_matching_config, @@ -192,21 +208,30 @@ def match( model_dir = Path(output_directory) / model_subdir # compute templates + trough_offset_samples = waveform_config.trough_offset_samples( + recording.sampling_frequency + ) + spike_length_samples = waveform_config.spike_length_samples( + recording.sampling_frequency + ) template_data = TemplateData.from_config( recording, sorting, - template_config, + template_config=template_config, motion_est=motion_est, n_jobs=n_jobs_templates, save_folder=model_dir, overwrite=overwrite, device=device, save_npz_name=template_npz_filename, + trough_offset_samples=trough_offset_samples, + spike_length_samples=spike_length_samples, ) # instantiate peeler matching_peeler = ObjectiveUpdateTemplateMatchingPeeler.from_config( recording, + waveform_config, matching_config, featurization_config, template_data, diff --git a/src/dartsort/peel/matching.py b/src/dartsort/peel/matching.py index eefbb353..9f489c32 100644 --- a/src/dartsort/peel/matching.py +++ b/src/dartsort/peel/matching.py @@ -365,6 +365,7 @@ def handle_upsampling( def from_config( cls, recording, + waveform_config, matching_config, featurization_config, template_data, @@ -377,6 +378,9 @@ def from_config( featurization_pipeline = WaveformPipeline.from_config( geom, channel_index, featurization_config ) + trough_offset_samples = waveform_config.trough_offset_samples( + recording.sampling_frequency + ) return cls( recording, template_data, @@ -391,7 +395,7 @@ def from_config( amplitude_scaling_boundary=matching_config.amplitude_scaling_boundary, conv_ignore_threshold=matching_config.conv_ignore_threshold, coarse_approx_error_threshold=matching_config.coarse_approx_error_threshold, - trough_offset_samples=matching_config.trough_offset_samples, + trough_offset_samples=trough_offset_samples, threshold=matching_config.threshold, chunk_length_samples=matching_config.chunk_length_samples, n_chunks_fit=matching_config.n_chunks_fit, diff --git a/src/dartsort/peel/subtract.py b/src/dartsort/peel/subtract.py index 2732f4e7..3671802f 100644 --- a/src/dartsort/peel/subtract.py +++ b/src/dartsort/peel/subtract.py @@ -1,3 +1,4 @@ +import warnings from collections import namedtuple from pathlib import Path @@ -74,22 +75,20 @@ def peeling_needs_fit(self): def save_models(self, save_folder): super().save_models(save_folder) - sub_denoise_pt = ( - Path(save_folder) / "subtraction_denoising_pipeline.pt" - ) + sub_denoise_pt = Path(save_folder) / "subtraction_denoising_pipeline.pt" torch.save(self.subtraction_denoising_pipeline, sub_denoise_pt) def load_models(self, save_folder): super().load_models(save_folder) - sub_denoise_pt = ( - Path(save_folder) / "subtraction_denoising_pipeline.pt" - ) + sub_denoise_pt = Path(save_folder) / "subtraction_denoising_pipeline.pt" if sub_denoise_pt.exists(): self.subtraction_denoising_pipeline = torch.load(sub_denoise_pt) @classmethod - def from_config(cls, recording, subtraction_config, featurization_config): + def from_config( + cls, recording, waveform_config, subtraction_config, featurization_config + ): # waveform extraction channel neighborhoods geom = torch.tensor(recording.get_channel_locations()) channel_index = make_channel_index( @@ -110,13 +109,28 @@ def from_config(cls, recording, 
subtraction_config, featurization_config): geom, channel_index, featurization_config ) + # waveform logic + trough_offset_samples = waveform_config.trough_offset_samples( + recording.sampling_frequency + ) + spike_length_samples = waveform_config.spike_length_samples( + recording.sampling_frequency + ) + + if trough_offset_samples != 42 or spike_length_samples != 121: + # temporary warning just so I can see if this happens + warnings.warn( + f"waveform_config {trough_offset_samples=} {spike_length_samples=} " + f"since {recording.sampling_frequency=}" + ) + return cls( recording, channel_index, subtraction_denoising_pipeline, featurization_pipeline, - trough_offset_samples=subtraction_config.trough_offset_samples, - spike_length_samples=subtraction_config.spike_length_samples, + trough_offset_samples=trough_offset_samples, + spike_length_samples=spike_length_samples, detection_thresholds=subtraction_config.detection_thresholds, chunk_length_samples=subtraction_config.chunk_length_samples, peak_sign=subtraction_config.peak_sign, @@ -203,16 +217,11 @@ def _fit_subtraction_transformers( if which == "denoisers": self.subtraction_denoising_pipeline = WaveformPipeline( [init_waveform_feature] - + [ - t - for t in orig_denoise - if (t.is_denoiser and not t.needs_fit()) - ] + + [t for t in orig_denoise if (t.is_denoiser and not t.needs_fit())] ) else: self.subtraction_denoising_pipeline = WaveformPipeline( - [init_waveform_feature] - + [t for t in orig_denoise if t.is_denoiser] + [init_waveform_feature] + [t for t in orig_denoise if t.is_denoiser] ) # and we don't need any features for this @@ -392,8 +401,7 @@ def subtract_chunk( # discard spikes in the margins and sort times_samples for caller keep = torch.nonzero( - (spike_times >= left_margin) - & (spike_times < traces.shape[0] - right_margin) + (spike_times >= left_margin) & (spike_times < traces.shape[0] - right_margin) )[:, 0] if not keep.any(): return empty_chunk_subtraction_result( @@ -437,9 +445,7 @@ def subtract_chunk( ) -def empty_chunk_subtraction_result( - spike_length_samples, channel_index, residual -): +def empty_chunk_subtraction_result(spike_length_samples, channel_index, residual): empty_waveforms = torch.empty( (0, spike_length_samples, channel_index.shape[1]), dtype=residual.dtype, diff --git a/src/dartsort/templates/templates.py b/src/dartsort/templates/templates.py index bdae1c59..6e8cd005 100644 --- a/src/dartsort/templates/templates.py +++ b/src/dartsort/templates/templates.py @@ -3,13 +3,13 @@ from typing import Optional import numpy as np +from dartsort.localize.localize_util import localize_waveforms from dartsort.util import drift_util from .get_templates import get_templates from .superres_util import superres_sorting from .template_util import (get_realigned_sorting, get_template_depths, weighted_average) -from dartsort.localize.localize_util import localize_waveforms _motion_error_prefix = ( "If template_config has registered_templates==True " @@ -82,11 +82,12 @@ def coarsen(self, with_locs=True): ) def template_locations(self): - template_locations = localize_waveforms(self.templates, - self.registered_geom, - main_channels=self.templates.ptp(1).argmax(1), - radius=self.localization_radius_um - ) + template_locations = localize_waveforms( + self.templates, + self.registered_geom, + main_channels=self.templates.ptp(1).argmax(1), + radius=self.localization_radius_um, + ) return template_locations def unit_templates(self, unit_id): @@ -105,6 +106,8 @@ def from_config( 
        localizations_dataset_name="point_source_localizations",
         n_jobs=0,
         device=None,
+        trough_offset_samples=42,
+        spike_length_samples=121,
     ):
         if save_folder is not None:
             save_folder = Path(save_folder)
@@ -141,8 +144,8 @@ def from_config(
             geom = recording.get_channel_locations()
 
         kwargs = dict(
-            trough_offset_samples=template_config.trough_offset_samples,
-            spike_length_samples=template_config.spike_length_samples,
+            trough_offset_samples=trough_offset_samples,
+            spike_length_samples=spike_length_samples,
             spikes_per_unit=template_config.spikes_per_unit,
             # realign handled in advance below, not needed in kwargs
             # realign_peaks=template_config.realign_peaks,
@@ -216,16 +219,16 @@ def from_config(
                 kwargs["registered_geom"],
                 registered_template_depths_um,
                 localization_radius_um=template_config.registered_template_localization_radius_um,
-                trough_offset_samples=template_config.trough_offset_samples,
-                spike_length_samples=template_config.spike_length_samples,
+                trough_offset_samples=trough_offset_samples,
+                spike_length_samples=spike_length_samples,
             )
         else:
             obj = cls(
                 results["templates"],
                 unit_ids,
                 spike_counts,
-                trough_offset_samples=template_config.trough_offset_samples,
-                spike_length_samples=template_config.spike_length_samples,
+                trough_offset_samples=trough_offset_samples,
+                spike_length_samples=spike_length_samples,
             )
         if save_folder is not None:
             obj.to_npz(npz_path)
diff --git a/src/dartsort/util/analysis.py b/src/dartsort/util/analysis.py
index b6aa04a7..bcc89851 100644
--- a/src/dartsort/util/analysis.py
+++ b/src/dartsort/util/analysis.py
@@ -9,7 +9,7 @@
 """
 from dataclasses import dataclass, replace
 from pathlib import Path
-from typing import Optional
+from typing import Callable, Optional, Union
 
 import h5py
 import numpy as np
@@ -19,7 +19,7 @@
 from sklearn.decomposition import PCA
 from spikeinterface.comparison import GroundTruthComparison
 
-from ..cluster import relocate
+from ..cluster import merge, relocate
 from ..templates import TemplateData
 from ..transform import WaveformPipeline
 from .data_util import DARTsortSorting
@@ -56,6 +56,11 @@ class DARTsortAnalysis:
     tpca_features_dataset = "collisioncleaned_tpca_features"
     template_indices_dataset = "collisioncleaned_tpca_features"
 
+    # configuration for analysis computations not included in above objects
+    device: Optional[Union[str, torch.device]] = None
+    merge_distance_templates_kind: str = "coarse"
+    merge_superres_linkage: Callable[[np.ndarray], float] = np.max
+
     # helper constructors
 
     @classmethod
@@ -110,7 +115,9 @@ def from_peeling_paths(
     def __post_init__(self):
         if self.featurization_pipeline is not None:
             assert not self.featurization_pipeline.needs_fit()
-        assert np.isin(self.template_data.unit_ids, np.unique(self.sorting.labels)).all()
+        assert np.isin(
+            self.template_data.unit_ids, np.unique(self.sorting.labels)
+        ).all()
         assert self.hdf5_path.exists()
 
         self.coarse_template_data = self.template_data.coarsen()
@@ -127,6 +134,7 @@ def __post_init__(self):
             self.motion_est is not None
             and self.template_data.registered_geom is not None
         )
+        assert np.array_equal(self.coarse_template_data.unit_ids, self.unit_ids)
 
         # cached hdf5 pointer
         self._h5 = None
@@ -145,6 +153,7 @@ def clear_cache(self):
         self._sklearn_tpca = None
         self._unit_ids = None
         self._spike_counts = None
+        self._merge_dist = None
         self._feats = {}
 
     def __getstate__(self):
@@ -209,6 +218,12 @@ def sklearn_tpca(self):
             self._sklearn_tpca = tpca_feature[0].to_sklearn()
         return self._sklearn_tpca
 
+    @property
+    def merge_dist(self):
+        if self._merge_dist is None:
+            self._merge_dist = self._calc_merge_dist()
+        return self._merge_dist
+
     # spike train helpers
 
     @property
@@ -236,6 +251,16 @@ def in_template(self, template_index):
     def unit_template_indices(self, unit_id):
-        return np.flatnonzero(self.template_data.unit_ids == self.unit_id)
+        return np.flatnonzero(self.template_data.unit_ids == unit_id)
 
+    @property
+    def show_geom(self):
+        show_geom = self.template_data.registered_geom
+        if show_geom is None:
+            show_geom = self.recording.get_channel_locations()
+        return show_geom
+
+    def show_channel_index(self, radius_um=50):
+        return make_channel_index(self.show_geom, radius_um)
+
     # spike feature loading methods
 
     def named_feature(self, name, which=slice(None)):
@@ -330,7 +355,12 @@ def unit_raw_waveforms(
         if not self.shifting:
             return which, waveforms
 
-        waveforms, max_chan, show_geom, show_channel_index = self.unit_shift_or_relocate_channels(
+        (
+            waveforms,
+            max_chan,
+            show_geom,
+            show_channel_index,
+        ) = self.unit_shift_or_relocate_channels(
             unit_id,
             which,
             waveforms,
@@ -367,7 +397,12 @@ def unit_tpca_waveforms(
         t = waveforms.shape[1]
         waveforms = waveforms.reshape(n, c, t).transpose(0, 2, 1)
 
-        waveforms, max_chan, show_geom, show_channel_index = self.unit_shift_or_relocate_channels(
+        (
+            waveforms,
+            max_chan,
+            show_geom,
+            show_channel_index,
+        ) = self.unit_shift_or_relocate_channels(
             unit_id,
             which,
             waveforms,
@@ -378,9 +413,21 @@ def unit_tpca_waveforms(
         return which, waveforms, max_chan, show_geom, show_channel_index
 
     def unit_pca_features(
-        self, unit_id, relocated=True, rank=2, pca_radius_um=75, random_seed=0, max_count=500
+        self,
+        unit_id,
+        relocated=True,
+        rank=2,
+        pca_radius_um=75,
+        random_seed=0,
+        max_count=500,
     ):
-        which, waveforms, max_chan, show_geom, show_channel_index = self.unit_tpca_waveforms(
+        (
+            which,
+            waveforms,
+            max_chan,
+            show_geom,
+            show_channel_index,
+        ) = self.unit_tpca_waveforms(
             unit_id,
             relocated=relocated,
             show_radius_um=pca_radius_um,
@@ -439,7 +486,9 @@ def unit_shift_or_relocate_channels(
         show_channel_index = make_channel_index(show_geom, show_radius_um)
         show_chans = show_channel_index[max_chan]
         show_chans = show_chans[show_chans < len(show_geom)]
-        show_channel_index = np.broadcast_to(show_chans[None], (len(show_geom), show_chans.size))
+        show_channel_index = np.broadcast_to(
+            show_chans[None], (len(show_geom), show_chans.size)
+        )
 
         if not self.shifting:
             return waveforms, max_chan, show_geom, show_channel_index
@@ -478,6 +527,34 @@ def unit_shift_or_relocate_channels(
 
         return waveforms, max_chan, show_geom, show_channel_index
 
+    def nearby_coarse_templates(self, unit_id, n_neighbors=5):
+        unit_ix = np.searchsorted(self.unit_ids, unit_id)
+        unit_dists = self.merge_dist[unit_ix]
+        distance_order = np.argsort(unit_dists)
+        assert distance_order[0] == unit_ix
+        neighbor_ixs = distance_order[:n_neighbors]
+        neighbor_ids = self.unit_ids[neighbor_ixs]
+        neighbor_dists = self.merge_dist[neighbor_ixs[:, None], neighbor_ixs[None, :]]
+        neighbor_coarse_templates = self.coarse_template_data.templates[neighbor_ixs]
+        return neighbor_ids, neighbor_dists, neighbor_coarse_templates
+
+    # computation
+
+    def _calc_merge_dist(self):
+        """Compute the merge distance matrix"""
+        merge_td = self.template_data
+        if self.merge_distance_templates_kind == "coarse":
+            merge_td = self.coarse_template_data
+
+        units, dists, shifts, template_snrs = merge.calculate_merge_distances(
+            merge_td,
+            superres_linkage=self.merge_superres_linkage,
+            device=self.device,
+            n_jobs=1,
+        )
+        assert np.array_equal(units, self.unit_ids)
+        return dists
+
 
 @dataclass
 class DARTsortGroundTruthComparison:
diff --git a/src/dartsort/util/registration_util.py b/src/dartsort/util/registration_util.py
index 43839895..a8566565 100644
--- a/src/dartsort/util/registration_util.py
+++ b/src/dartsort/util/registration_util.py
@@ -1,14 +1,79 @@
+import pickle
+from pathlib import Path
+from typing import Optional
+
+import numpy as np
+
 try:
     from dredge import dredge_ap
+
     have_dredge = True
 except ImportError:
     have_dredge = False
     pass
 
 
-def estimate_motion(recording, sorting, motion_estimation_config=None, localizations_dataset_name="point_source_localizations"):
-    if not motion_estimation_config.do_motion_estimation:
+def estimate_motion(
+    recording,
+    sorting,
+    output_directory,
+    overwrite=False,
+    do_motion_estimation=True,
+    probe_boundary_padding_um=100.0,
+    spatial_bin_length_um: float = 1.0,
+    temporal_bin_length_s: float = 1.0,
+    window_step_um: float = 400.0,
+    window_scale_um: float = 450.0,
+    window_margin_um: Optional[float] = None,
+    max_dt_s: float = 0.1,
+    max_disp_um: Optional[float] = None,
+    localizations_dataset_name="point_source_localizations",
+    amplitudes_dataset_name="denoised_ptp_amplitudes",
+    device=None,
+):
+    if not do_motion_estimation:
         return None
 
+    motion_est_pkl = Path(output_directory) / "motion_est.pkl"
+    if not overwrite and motion_est_pkl.exists():
+        with open(motion_est_pkl, "rb") as jar:
+            return pickle.load(jar)
+
     if not have_dredge:
         raise ValueError("Please install DREDge to use motion estimation.")
+
+    x = getattr(sorting, localizations_dataset_name)[:, 0]
+    z = getattr(sorting, localizations_dataset_name)[:, 1]
+    geom = recording.get_channel_locations()
+    xmin = geom[:, 0].min() - probe_boundary_padding_um
+    xmax = geom[:, 0].max() + probe_boundary_padding_um
+    zmin = geom[:, 1].min() - probe_boundary_padding_um
+    zmax = geom[:, 1].max() + probe_boundary_padding_um
+    xvalid = x == np.clip(x, xmin, xmax)
+    zvalid = z == np.clip(z, zmin, zmax)
+    valid = np.flatnonzero(xvalid & zvalid)
+
+    # features for registration
+    z = z[valid]
+    t_s = sorting.times_seconds[valid]
+    a = getattr(sorting, amplitudes_dataset_name)[valid]
+
+    # run registration
+    motion_est, info = dredge_ap.register(
+        a,
+        z,
+        t_s,
+        window_step_um=window_step_um,
+        bin_um=spatial_bin_length_um,
+        bin_s=temporal_bin_length_s,
+        window_scale_um=window_scale_um,
+        window_margin_um=window_margin_um,
+        max_disp_um=max_disp_um,
+        max_dt_s=max_dt_s,
+        device=device,
+    )
+
+    with open(motion_est_pkl, "wb") as jar:
+        pickle.dump(motion_est, jar)
+
+    return motion_est
diff --git a/src/dartsort/vis/analysis_plots.py b/src/dartsort/vis/analysis_plots.py
new file mode 100644
index 00000000..085078f8
--- /dev/null
+++ b/src/dartsort/vis/analysis_plots.py
@@ -0,0 +1,2 @@
+def distance_matrix_dendro():
+    pass
diff --git a/src/dartsort/vis/unit.py b/src/dartsort/vis/unit.py
index 26338c22..e94a2461 100644
--- a/src/dartsort/vis/unit.py
+++ b/src/dartsort/vis/unit.py
@@ -10,6 +10,7 @@
 from collections import namedtuple
 from pathlib import Path
 
+import colorcet as cc
 import matplotlib.pyplot as plt
 import numpy as np
 from matplotlib.legend_handler import HandlerTuple
@@ -303,7 +304,9 @@ def get_waveforms(self, sorting_analysis, unit_id):
         raise NotImplementedError
 
     def draw(self, axis, sorting_analysis, unit_id):
-        which, waveforms, max_chan, geom, ci = self.get_waveforms(sorting_analysis, unit_id)
+        which, waveforms, max_chan, geom, ci = self.get_waveforms(
+            sorting_analysis, unit_id
+        )
         max_abs_amp = None
         show_template = self.show_template
@@ -327,7 +330,9 @@ def draw(self, axis, sorting_analysis, unit_id):
                 new_length=self.spike_length_samples,
             )
             max_abs_amp = self.max_abs_template_scale * np.abs(templates).max()
 
-        show_superres_templates = self.show_superres_templates and self.template_index is None
+        show_superres_templates = (
+            self.show_superres_templates and self.template_index is None
+        )
         if show_superres_templates:
             suptemplates = sorting_analysis.template_data.unit_templates(unit_id)
             show_superres_templates = bool(suptemplates.size)
@@ -444,6 +449,103 @@ def get_waveforms(self, sorting_analysis, unit_id):
         )
 
 
+# -- merge-focused plots
+
+
+class NearbyCoarseTemplatesPlot(UnitPlot):
+    title = "nearby coarse templates"
+    kind = "neighbors"
+    width = 2
+    height = 2
+
+    def __init__(self, show_radius_um=50, n_neighbors=5, legend=True):
+        self.show_radius_um = show_radius_um
+        self.n_neighbors = n_neighbors
+        self.legend = legend
+
+    def draw(self, axis, sorting_analysis, unit_id):
+        (
+            neighbor_ids,
+            neighbor_dists,
+            neighbor_coarse_templates,
+        ) = sorting_analysis.nearby_coarse_templates(
+            unit_id, n_neighbors=self.n_neighbors
+        )
+        colors = np.asarray(cc.glasbey_light)[neighbor_ids % len(cc.glasbey_light)]
+        assert neighbor_ids[0] == unit_id
+        chan = neighbor_coarse_templates[0].ptp(0).argmax()
+        ci = sorting_analysis.show_channel_index(self.show_radius_um)
+        channels = ci[chan]
+        neighbor_coarse_templates = neighbor_coarse_templates[:, :, channels]
+        maxamp = np.abs(neighbor_coarse_templates).max()
+
+        labels = []
+        handles = []
+        for uid, color, template in zip(
+            neighbor_ids, colors, neighbor_coarse_templates
+        ):
+            lines = geomplot(
+                template[None],
+                max_channels=[chan],
+                channel_index=ci,
+                geom=sorting_analysis.show_geom,
+                ax=axis,
+                show_zero=False,
+                max_abs_amp=maxamp,
+                subar=True,
+                bar_color="k",
+                bar_background="w",
+                zlim="tight",
+                color=color,
+            )
+            labels.append(str(uid))
+            handles.append(lines[0])
+        if self.legend:
+            axis.legend(handles=handles, labels=labels, fancybox=False)
+        axis.set_xticks([])
+        axis.set_yticks([])
+
+
+class CoarseTemplateDistancePlot(UnitPlot):
+    title = "coarse template distance"
+    kind = "neighbors"
+    width = 2
+    height = 2
+
+    def __init__(self, show_radius_um=50, n_neighbors=5, dist_vmax=1.0):
+        self.show_radius_um = show_radius_um
+        self.n_neighbors = n_neighbors
+        self.dist_vmax = dist_vmax
+
+    def draw(self, axis, sorting_analysis, unit_id):
+        (
+            neighbor_ids,
+            neighbor_dists,
+            neighbor_coarse_templates,
+        ) = sorting_analysis.nearby_coarse_templates(
+            unit_id, n_neighbors=self.n_neighbors
+        )
+        colors = np.asarray(cc.glasbey_light)[neighbor_ids % len(cc.glasbey_light)]
+        assert neighbor_ids[0] == unit_id
+
+        im = axis.imshow(
+            neighbor_dists,
+            vmin=0,
+            vmax=self.dist_vmax,
+            cmap=plt.cm.RdGy,
+            origin="lower",
+            interpolation="none",
+        )
+        plt.colorbar(im, ax=axis)
+        axis.set_xticks(range(len(neighbor_ids)), neighbor_ids)
+        axis.set_yticks(range(len(neighbor_ids)), neighbor_ids)
+        for i, (tx, ty) in enumerate(
+            zip(axis.xaxis.get_ticklabels(), axis.yaxis.get_ticklabels())
+        ):
+            tx.set_color(colors[i])
+            ty.set_color(colors[i])
+
+
 # -- multi plots
 # these have multiple plots per unit, and we don't know in advance how many
 # for instance, making separate plots of spikes belonging to each superres template
@@ -524,6 +625,8 @@ def unit_plots(self, sorting_analysis, unit_id):
     TimeAmpScatter(relocate_amplitudes=True),
     RawWaveformPlot(),
     TPCAWaveformPlot(relocated=True),
+    NearbyCoarseTemplatesPlot(),
+    CoarseTemplateDistancePlot(),
 )
 
 
diff --git a/tests/test_matching.py b/tests/test_matching.py
index 6b8a1b6c..2e3e298f 100644
--- a/tests/test_matching.py
+++ b/tests/test_matching.py
@@ -65,6 +65,7 @@ def test_tiny(tmp_path):
 
     matcher = 
main.ObjectiveUpdateTemplateMatchingPeeler.from_config( rec, + config.default_waveform_config, config.MatchingConfig( threshold=0.01, template_temporal_upsampling_factor=1, @@ -102,6 +103,7 @@ def test_tiny(tmp_path): matcher = main.ObjectiveUpdateTemplateMatchingPeeler.from_config( rec, + config.default_waveform_config, config.MatchingConfig( threshold=0.01, template_temporal_upsampling_factor=8, @@ -191,6 +193,7 @@ def test_tiny_up(tmp_path, up_factor=8): matcher = main.ObjectiveUpdateTemplateMatchingPeeler.from_config( rec, + config.default_waveform_config, config.MatchingConfig( threshold=0.01, template_temporal_upsampling_factor=up_factor, @@ -336,6 +339,7 @@ def static_tester(tmp_path, up_factor=1): matcher = main.ObjectiveUpdateTemplateMatchingPeeler.from_config( rec, + config.default_waveform_config, config.MatchingConfig( threshold=0.01, template_temporal_upsampling_factor=up_factor, diff --git a/tests/test_subtract.py b/tests/test_subtract.py index ac7406bd..ff2678e7 100644 --- a/tests/test_subtract.py +++ b/tests/test_subtract.py @@ -324,7 +324,7 @@ def test_small_nonn(): h = dense_layout() geom = np.c_[h["x"], h["y"]][:n_channels] - rec = sc.NumpyRecording(noise, 10_000) + rec = sc.NumpyRecording(noise, 30_000) rec.set_dummy_probe_from_locations(geom) subconf = SubtractionConfig( @@ -405,7 +405,7 @@ def test_small_default_config(): h = dense_layout() geom = np.c_[h["x"], h["y"]][:n_channels] - rec = sc.NumpyRecording(noise, 10_000) + rec = sc.NumpyRecording(noise, 30_000) rec.set_dummy_probe_from_locations(geom) with tempfile.TemporaryDirectory() as tempdir: diff --git a/tests/test_templates.py b/tests/test_templates.py index b8ff4e67..202ee79e 100644 --- a/tests/test_templates.py +++ b/tests/test_templates.py @@ -167,13 +167,13 @@ def test_main_object(): rec, sorting, config.TemplateConfig( - trough_offset_samples=0, - spike_length_samples=2, realign_peaks=False, superres_templates=False, denoising_rank=2, ), motion_est=me, + trough_offset_samples=0, + spike_length_samples=2, )
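
Notes on the changes above (not part of the patch):

The merge.py refactor exists so that the template distance matrix can be computed once and reused: merge_templates now delegates to calculate_merge_distances plus recluster, and DARTsortAnalysis._calc_merge_dist calls the same function to back the merge_dist property and the new neighbor plots. A minimal sketch of the intended call pattern, assuming a TemplateData instance template_data and a DARTsortSorting sorting as used elsewhere in this diff:

    import numpy as np
    from dartsort.cluster import merge

    # distances are computed once; sym_function symmetrizes the matrix
    # (np.minimum keeps the smaller of dists[i, j] and dists[j, i])
    units, dists, shifts, template_snrs = merge.calculate_merge_distances(
        template_data,
        superres_linkage=np.max,
        sym_function=np.minimum,
        n_jobs=1,
    )

    # the same outputs feed hierarchical clustering...
    merged_sorting = merge.recluster(
        sorting, units, dists, shifts, template_snrs,
        merge_distance_threshold=0.25,
    )
    # ...and can be cached for visualization, as DARTsortAnalysis.merge_dist does.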
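The WaveformConfig change deserves a sanity check: forcing an odd spike length preserves the documented defaults. Assuming ms_before=1.4 and ms_after=2.6 (the field values implied by the 42/121 docstring; trough_offset_samples is not shown in this diff, so its rounding is an assumption too), at 30 kHz int(4.0 * 30) = 120 and 2 * (120 // 2) + 1 = 121, so nothing changes for the common case, while other rates now always get an odd length:

    from dartsort.config import WaveformConfig

    wc = WaveformConfig(ms_before=1.4, ms_after=2.6)
    assert wc.trough_offset_samples(30_000) == 42
    assert wc.spike_length_samples(30_000) == 121
    # at, e.g., 25 kHz the length is still forced odd: int(4.0 * 25) = 100 -> 101
    assert wc.spike_length_samples(25_000) == 101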
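Finally, estimate_motion now caches its result to output_directory/motion_est.pkl and takes the MotionEstimationConfig fields as keyword arguments, which is what **asdict(motion_estimation_config) in dartsort() relies on, so those field names must stay in sync with this signature. Standalone use might look like the following sketch (the recording and sorting objects and the output path are placeholders):

    from pathlib import Path

    from dartsort.util.registration_util import estimate_motion

    motion_est = estimate_motion(
        recording,              # spikeinterface recording (provides the geometry)
        sorting,                # DARTsortSorting with localizations and amplitudes
        Path("dartsort_out"),   # result cached at dartsort_out/motion_est.pkl
        overwrite=False,        # reuse the cached pickle when present
        max_disp_um=100.0,      # optional displacement cap passed through to DREDge
    )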