diff --git a/src/dartsort/cluster/split.py b/src/dartsort/cluster/split.py
index 5d9ed88c..dfe47186 100644
--- a/src/dartsort/cluster/split.py
+++ b/src/dartsort/cluster/split.py
@@ -5,7 +5,6 @@
 import numpy as np
 import torch
 from dartsort.util import drift_util, waveform_util
-from dartsort.util.data_util import DARTsortSorting
 from dartsort.util.multiprocessing_util import get_pool
 from hdbscan import HDBSCAN
 from sklearn.decomposition import PCA
@@ -210,10 +209,15 @@ def split_cluster(self, in_unit):
         if n_spikes < self.min_cluster_size:
             return SplitResult()
 
-        max_registered_channel, n_pitches_shift, reloc_amplitudes, kept = self.get_registered_channels(in_unit)
+        (
+            max_registered_channel,
+            n_pitches_shift,
+            reloc_amplitudes,
+            kept,
+        ) = self.get_registered_channels(in_unit)
         if not kept.size:
             return SplitResult()
-        
+
         features = []
         if self.use_localization_features:
             loc_features = self.localization_features[in_unit]
@@ -222,7 +226,9 @@ def split_cluster(self, in_unit):
             features.append(loc_features)
 
         if self.n_pca_features > 0:
-            enough_good_spikes, kept, pca_embeds = self.pca_features(in_unit, max_registered_channel, n_pitches_shift)
+            enough_good_spikes, kept, pca_embeds = self.pca_features(
+                in_unit, max_registered_channel, n_pitches_shift
+            )
             if not enough_good_spikes:
                 return SplitResult()
             # scale pc features to match localization features
diff --git a/src/dartsort/config.py b/src/dartsort/config.py
index fcbf7271..33d140a0 100644
--- a/src/dartsort/config.py
+++ b/src/dartsort/config.py
@@ -170,4 +170,5 @@ class MatchingConfig:
 default_featurization_config = FeaturizationConfig()
 default_subtraction_config = SubtractionConfig()
 default_template_config = TemplateConfig()
+coarse_template_config = TemplateConfig(superres_templates=False)
 default_matching_config = MatchingConfig()
diff --git a/src/dartsort/util/analysis.py b/src/dartsort/util/analysis.py
new file mode 100644
index 00000000..a696ceb9
--- /dev/null
+++ b/src/dartsort/util/analysis.py
@@ -0,0 +1,367 @@
+"""Deeper object-oriented interaction with sorter data
+
+This is meant to make implementing plotting code easier: this
+code becomes the model in an MVC framework, and vis/unit.py can
+implement a view and controller.
+
+This should also make it easier to compute drift-aware metrics
+(e.g., d' using registered templates and shifted waveforms).
+"""
+from dataclasses import dataclass, replace
+from pathlib import Path
+from typing import Optional
+
+import h5py
+import numpy as np
+import spikeinterface.core as sc
+import torch
+from dredge.motion_util import MotionEstimate
+from sklearn.decomposition import PCA
+
+from ..cluster import relocate
+from ..templates import TemplateData
+from ..transform import WaveformPipeline
+from .data_util import DARTsortSorting
+from .drift_util import (get_spike_pitch_shifts,
+                         get_waveforms_on_static_channels)
+from .spikeio import read_waveforms_channel_index
+from .waveform_util import make_channel_index
+
+
+@dataclass
+class DARTsortAnalysis:
+    """Stores all relevant properties for a drift-aware waveform analysis
+
+    If motion_est is None, there is no motion correction applied.
+
+    If motion_est is not None but relocated is False, waveforms are shifted
+    across channel neighborhoods to account for drift.
+
+    If additionally relocated is True, point-source relocation is applied
+    to rescale the amplitudes on each channel.
+ """ + + sorting: DARTsortSorting + hdf5_path: Path + recording: sc.BaseRecording + template_data: TemplateData + featurization_pipeline: Optional[WaveformPipeline] = None + motion_est: Optional[MotionEstimate] = None + + # hdf5 keys + localizations_dataset = "point_source_localizations" + amplitudes_dataset = "denoised_amplitudes" + amplitude_vectors_dataset = "denoised_amplitude_vectors" + tpca_features_dataset = "collisioncleaned_tpca_features" + + # helper constructors + + @classmethod + def from_peeling_hdf5_and_recording( + cls, hdf5_path, recording, template_data, featurization_pipeline=None, motion_est=None, **kwargs + ): + return cls( + DARTsortSorting.from_peeling_hdf5(hdf5_path, load_simple_features=False), + Path(hdf5_path), + recording, + template_data=template_data, + featurization_pipeline=featurization_pipeline, + motion_est=motion_est, + **kwargs, + ) + + @classmethod + def from_peeling_paths( + cls, + recording, + hdf5_path, + model_dir=None, + motion_est=None, + template_data_npz="template_data.npz", + **kwargs, + ): + hdf5_path = Path(hdf5_path) + if model_dir is None: + model_dir = hdf5_path.parent / f"{hdf5_path.stem}_models" + assert model_dir.exists() + sorting = DARTsortSorting.from_peeling_hdf5(hdf5_path, load_simple_features=False) + template_data = TemplateData.from_npz(Path(model_dir) / template_data_npz) + pipeline = torch.load(model_dir / "featurization_pipeline.pt") + return cls( + sorting, hdf5_path, recording, template_data, pipeline, motion_est, **kwargs + ) + + # pickle/h5py gizmos + + def __post_init__(self): + assert self.hdf5_path.exists() + self.coarse_template_data = self.template_data.coarsen() + + # this obj will be pickled and we don't use these, let's save ourselves the ram + if self.sorting.extra_features: + self.sorting = replace(self.sorting, extra_features=None) + self.shifting = ( + self.motion_est is not None + or self.template_data.registered_geom is not None + ) + if self.shifting: + assert ( + self.motion_est is not None + and self.template_data.registered_geom is not None + ) + + # cached hdf5 pointer + self._h5 = None + + # cached arrays + self.clear_cache() + + def clear_cache(self): + self._xyza = None + self._max_chan_amplitudes = None + self._amplitude_vectors = None + self._channel_index = None + self._geom = None + self._tpca_features = None + self._sklearn_tpca = None + self._feats = {} + + def __getstate__(self): + # remove cached stuff before pickling + return {k: v for k, v in self.__dict__.items() if not k.startswith("_")} + + # cache gizmos + + @property + def h5(self): + if self._h5 is None: + self._h5 = h5py.File(self.hdf5_path, "r") + return self._h5 + + @property + def xyza(self): + if self._xyza is None: + self._xyza = self.h5[self.localizations_dataset][:] + return self._xyza + + @property + def max_chan_amplitudes(self): + if self._max_chan_amplitudes is None: + self._max_chan_amplitudes = self.h5[self.amplitudes_dataset][:] + return self._max_chan_amplitudes + + @property + def amplitude_vectors(self): + if self._amplitude_vectors is None: + self._amplitude_vectors = self.h5[self.amplitude_vectors_dataset][:] + return self._amplitude_vectors + + @property + def geom(self): + if self._geom is None: + self._geom = self.h5["geom"][:] + return self._geom + + @property + def channel_index(self): + if self._channel_index is None: + self._channel_index = self.h5["channel_index"][:] + return self._channel_index + + @property + def sklearn_tpca(self): + if self._sklearn_tpca is None: + tpca_feature = [ + f + for f in 
self.featurization_pipeline.transformers
+                if f.name == self.tpca_features_dataset
+            ]
+            assert len(tpca_feature) == 1
+            self._sklearn_tpca = tpca_feature[0].to_sklearn()
+        return self._sklearn_tpca
+
+    # spike train helpers
+
+    def unit_ids(self):
+        return np.unique(self.sorting.labels)
+
+    def in_unit(self, unit_id):
+        return np.flatnonzero(np.isin(self.sorting.labels, unit_id))
+
+    # spike feature loading methods
+
+    def named_feature(self, name, which=slice(None)):
+        if name not in self._feats:
+            self._feats[name] = self.h5[name][:]
+        return self._feats[name][which]
+
+    def x(self, which=slice(None)):
+        return self.xyza[which, 0]
+
+    def z(self, which=slice(None), registered=True):
+        z = self.xyza[which, 2]
+        if registered and self.motion_est is not None:
+            # register the same subset of spikes that was sliced above
+            z = self.motion_est.correct_s(self.times_seconds(which), z)
+        return z
+
+    def times_seconds(self, which=slice(None)):
+        return self.sorting.times_seconds[which]
+
+    def times_samples(self, which=slice(None)):
+        return self.sorting.times_samples[which]
+
+    def amplitudes(self, which=slice(None), relocated=False):
+        if not relocated or self.motion_est is None:
+            return self.max_chan_amplitudes[which]
+
+        reloc_amp_vecs = relocate.relocated_waveforms_on_static_channels(
+            self.amplitude_vectors[which],
+            main_channels=self.sorting.channels[which],
+            channel_index=self.channel_index,
+            xyza_from=self.xyza[which],
+            z_to=self.z(which),
+            geom=self.geom,
+            registered_geom=self.template_data.registered_geom,
+        )
+        return reloc_amp_vecs.max(1)
+
+    def tpca_features(self, which=slice(None)):
+        if self._tpca_features is None:
+            self._tpca_features = self.h5[self.tpca_features_dataset]
+        if isinstance(which, slice):
+            which = np.arange(len(self.sorting))[which]
+        return batched_h5_read(self._tpca_features, which)
+
+    # cluster-dependent feature loading methods
+
+    def unit_waveforms(
+        self,
+        unit_id,
+        max_count=250,
+        random_seed=0,
+        show_radius_um=75,
+        trough_offset_samples=42,
+        spike_length_samples=121,
+    ):
+        which = self.in_unit(unit_id)
+        if which.size > max_count:
+            rg = np.random.default_rng(random_seed)
+            which = rg.choice(which, size=max_count, replace=False)
+        if not which.size:
+            return np.zeros((0, spike_length_samples, 1))
+
+        # read waveforms from disk
+        load_ci = self.channel_index
+        if self.shifting:
+            load_ci = make_channel_index(
+                self.recording.get_channel_locations(), show_radius_um
+            )
+        waveforms = read_waveforms_channel_index(
+            self.recording,
+            self.times_samples(which=which),
+            load_ci,
+            self.sorting.channels[which],
+            trough_offset_samples=trough_offset_samples,
+            spike_length_samples=spike_length_samples,
+            fill_value=np.nan,
+        )
+        if not self.shifting:
+            return waveforms
+
+        return self.unit_shift_or_relocate_channels(
+            unit_id, which, waveforms, load_ci, show_radius_um=show_radius_um
+        )
+
+    def unit_tpca_waveforms(
+        self, unit_id, max_count=250, random_seed=0, show_radius_um=75, relocated=False,
+    ):
+        which = self.in_unit(unit_id)
+        if not which.size:
+            return np.zeros((0, 1, self.channel_index.shape[1]))
+
+        tpca_embeds = self.tpca_features(which=which)
+        n, rank, c = tpca_embeds.shape
+        waveforms = tpca_embeds.transpose(0, 2, 1).reshape(n * c, rank)
+        waveforms = self.sklearn_tpca.inverse_transform(waveforms)
+        t = waveforms.shape[1]
+        waveforms = waveforms.reshape(n, c, t).transpose(0, 2, 1)
+
+        return self.unit_shift_or_relocate_channels(
+            unit_id,
+            which,
+            waveforms,
+            self.channel_index,
+            show_radius_um=show_radius_um,
+            relocated=relocated,
+        )
+
+    def unit_pca_features(
+        self, unit_id, relocated=True, rank=2, pca_radius_um=75, random_seed=0
+    ):
+        waveforms = self.unit_tpca_waveforms(
+            unit_id,
+            relocated=relocated,
+            show_radius_um=pca_radius_um,
+            random_seed=random_seed,
+        )
+
+        # flatten (n, rank, chans) -> (n, rank * chans) and drop spikes with
+        # NaN channels before fitting
+        waveforms = waveforms.reshape(len(waveforms), -1)
+        no_nan = np.flatnonzero(~np.isnan(waveforms).any(axis=1))
+        features = np.full((len(waveforms), rank), np.nan, dtype=waveforms.dtype)
+        if no_nan.size < rank:
+            return features
+
+        pca = PCA(rank, random_state=random_seed, whiten=True)
+        features[no_nan] = pca.fit_transform(waveforms[no_nan])
+        return features
+
+    def unit_shift_or_relocate_channels(
+        self, unit_id, which, waveforms, load_channel_index, show_radius_um=75, relocated=False
+    ):
+        geom = self.recording.get_channel_locations()
+        show_geom = self.template_data.registered_geom
+        if show_geom is None:
+            show_geom = geom
+        temp = self.coarse_template_data.templates[
+            self.coarse_template_data.unit_ids == unit_id
+        ]
+        assert temp.shape[0] == 1
+        max_chan = temp.squeeze().ptp(0).argmax()
+        show_chans = np.flatnonzero(
+            np.square(show_geom - show_geom[max_chan][None]).sum(1)
+            < show_radius_um**2
+        )
+
+        if relocated:
+            return relocate.relocated_waveforms_on_static_channels(
+                waveforms,
+                main_channels=self.sorting.channels[which],
+                channel_index=load_channel_index,
+                xyza_from=self.xyza[which],
+                target_channels=show_chans,
+                z_to=self.z(which=which, registered=True),
+                geom=geom,
+                registered_geom=show_geom,
+            )
+
+        n_pitches_shift = get_spike_pitch_shifts(
+            self.z(which=which, registered=False),
+            geom=geom,
+            registered_depths_um=self.z(which=which, registered=True),
+            times_s=self.times_seconds(which=which),
+            motion_est=self.motion_est,
+        )
+
+        return get_waveforms_on_static_channels(
+            waveforms,
+            geom=geom,
+            n_pitches_shift=n_pitches_shift,
+            main_channels=self.sorting.channels[which],
+            channel_index=load_channel_index,
+            target_channels=show_chans,
+            registered_geom=show_geom,
+        )
+
+
+# -- h5 helper... slow reading...
+
+
+def batched_h5_read(dataset, indices, batch_size=1000):
+    if indices.size < batch_size:
+        return dataset[indices]
+    else:
+        out = np.empty((indices.size, *dataset.shape[1:]), dtype=dataset.dtype)
+        for bs in range(0, indices.size, batch_size):
+            be = min(indices.size, bs + batch_size)
+            out[bs:be] = dataset[indices[bs:be]]
+        return out
diff --git a/src/dartsort/util/py_util.py b/src/dartsort/util/py_util.py
index e0f36fbb..6572297e 100644
--- a/src/dartsort/util/py_util.py
+++ b/src/dartsort/util/py_util.py
@@ -3,6 +3,15 @@
 
 
 class timer:
+    """
+    with timer("hi"):
+        bubblesort(np.arange(1e6)[::-1])
+    # prints: hi took <> s
+    with timer("zoom") as tic:
+        pass
+    assert np.isclose(tic.dt, 0)
+    """
+
     def __init__(self, name="timer"):
         self.name = name
 
@@ -11,8 +20,8 @@ def __enter__(self):
         return self
 
     def __exit__(self, *args):
-        self.t = time.time() - self.start
-        print(self.name, "took", self.t, "s")
+        self.dt = time.time() - self.start
+        print(self.name, "took", self.dt, "s")
 
 
 class NoKeyboardInterrupt:
diff --git a/src/dartsort/vis/unit.py b/src/dartsort/vis/unit.py
new file mode 100644
index 00000000..a4f95360
--- /dev/null
+++ b/src/dartsort/vis/unit.py
@@ -0,0 +1,315 @@
+"""Toolkit for extensible single unit summary plots
+
+The goal is to make it easy to add and remove plots without the plotting
+code turning into a web of if statements, for loops, and bizarre subplot
+and subfigure mazes.
+
+Relies on the DARTsortAnalysis object of util/analysis.py to do most of
+the data work so that this file can focus on plotting (sort of MVC).
+"""
+from collections import namedtuple
+
+import matplotlib.pyplot as plt
+import numpy as np
+
+# -- main class. see fn make_unit_summary below to make lots of UnitPlots.
+
+
+class UnitPlot:
+    kind: str
+    width = 1
+    height = 1
+
+    def draw(self, axis, sorting_analysis, unit_id):
+        raise NotImplementedError
+
+
+# -- small summary plots
+
+
+class ACG(UnitPlot):
+    kind = "histogram"
+    height = 0.75
+
+    def __init__(self, max_lag=50):
+        self.max_lag = max_lag
+
+    def draw(self, axis, sorting_analysis, unit_id):
+        times_samples = sorting_analysis.times_samples(
+            which=sorting_analysis.in_unit(unit_id)
+        )
+        lags, acg = correlogram(times_samples, max_lag=self.max_lag)
+        axis.bar(lags, acg)
+        axis.set_xlabel("lag (samples)")
+        axis.set_ylabel("acg")
+
+
+class ISIHistogram(UnitPlot):
+    kind = "histogram"
+    height = 0.75
+
+    def __init__(self, bin_ms=0.1):
+        self.bin_ms = bin_ms
+
+    def draw(self, axis, sorting_analysis, unit_id):
+        times_s = sorting_analysis.times_seconds(
+            which=sorting_analysis.in_unit(unit_id)
+        )
+        # interspike intervals in milliseconds
+        isis_ms = 1000 * np.diff(times_s)
+        bin_edges = np.arange(
+            0,
+            isis_ms.max() + self.bin_ms,
+            self.bin_ms,
+        )
+        axis.hist(isis_ms, bins=bin_edges)
+        axis.set_xlabel("isi (ms)")
+        axis.set_ylabel("count")
+
+
+class XZScatter(UnitPlot):
+    kind = "scatter"
+
+    def __init__(self, relocate_amplitudes=False, registered=True, max_amplitude=15):
+        self.relocate_amplitudes = relocate_amplitudes
+        self.registered = registered
+        self.max_amplitude = max_amplitude
+
+    def draw(self, axis, sorting_analysis, unit_id):
+        in_unit = sorting_analysis.in_unit(unit_id)
+        x = sorting_analysis.x(which=in_unit)
+        z = sorting_analysis.z(which=in_unit, registered=self.registered)
+        amps = sorting_analysis.amplitudes(
+            which=in_unit, relocated=self.relocate_amplitudes
+        )
+        s = axis.scatter(x, z, c=np.minimum(amps, self.max_amplitude), lw=0, s=3)
+        axis.set_xlabel("x (um)")
+        reg_str = "registered " * self.registered
+        axis.set_ylabel(reg_str + "z (um)")
+        reloc_str = "relocated " * self.relocate_amplitudes
+        plt.colorbar(s, ax=axis, shrink=0.5, label=reloc_str + "amplitude (su)")
+
+
+class PCAScatter(UnitPlot):
+    kind = "scatter"
+
+    def __init__(self, relocate_amplitudes=False, relocated=True, max_amplitude=15):
+        self.relocated = relocated
+        self.relocate_amplitudes = relocate_amplitudes
+        self.max_amplitude = max_amplitude
+
+    def draw(self, axis, sorting_analysis, unit_id):
+        in_unit = sorting_analysis.in_unit(unit_id)
+        loadings = sorting_analysis.unit_pca_features(
+            unit_id, relocated=self.relocated
+        )
+        amps = sorting_analysis.amplitudes(
+            which=in_unit, relocated=self.relocate_amplitudes
+        )
+        s = axis.scatter(*loadings.T, c=np.minimum(amps, self.max_amplitude), lw=0, s=3)
+        reloc_str = "relocated " * self.relocated
+        axis.set_xlabel(reloc_str + "per-unit PC1")
+        axis.set_ylabel(reloc_str + "per-unit PC2")
+        reloc_amp_str = "relocated " * self.relocate_amplitudes
+        plt.colorbar(s, ax=axis, shrink=0.5, label=reloc_amp_str + "amplitude (su)")
+
+
+# -- wide scatter plots
+
+
+class TZScatter(UnitPlot):
+    kind = "widescatter"
+    width = 2
+
+    def __init__(self, relocate_amplitudes=False, registered=True, max_amplitude=15):
+        self.relocate_amplitudes = relocate_amplitudes
+        self.registered = registered
+        self.max_amplitude = max_amplitude
+
+    def draw(self, axis, sorting_analysis, unit_id):
+        in_unit = sorting_analysis.in_unit(unit_id)
+        t = sorting_analysis.times_seconds(which=in_unit)
+        z = sorting_analysis.z(which=in_unit, registered=self.registered)
+        amps = sorting_analysis.amplitudes(
+            which=in_unit, relocated=self.relocate_amplitudes
+        )
+        s = axis.scatter(t, z, c=np.minimum(amps, self.max_amplitude), lw=0, s=3)
+        axis.set_xlabel("time (seconds)")
+        reg_str = "registered " * self.registered
+        axis.set_ylabel(reg_str + "z (um)")
+        reloc_str = "relocated " * self.relocate_amplitudes
+        plt.colorbar(s, ax=axis, shrink=0.5, label=reloc_str + "amplitude (su)")
+
+
+class TFeatScatter(UnitPlot):
+    kind = "widescatter"
+    width = 2
+
+    def __init__(
+        self, feat_name, color_by_amplitude=True, relocate_amplitudes=False, max_amplitude=15
+    ):
+        self.relocate_amplitudes = relocate_amplitudes
+        self.feat_name = feat_name
+        self.max_amplitude = max_amplitude
+        self.color_by_amplitude = color_by_amplitude
+
+    def draw(self, axis, sorting_analysis, unit_id):
+        in_unit = sorting_analysis.in_unit(unit_id)
+        t = sorting_analysis.times_seconds(which=in_unit)
+        feat = sorting_analysis.named_feature(self.feat_name, which=in_unit)
+        c = None
+        if self.color_by_amplitude:
+            amps = sorting_analysis.amplitudes(
+                which=in_unit, relocated=self.relocate_amplitudes
+            )
+            c = np.minimum(amps, self.max_amplitude)
+        s = axis.scatter(t, feat, c=c, lw=0, s=3)
+        axis.set_xlabel("time (seconds)")
+        axis.set_ylabel(self.feat_name)
+        if self.color_by_amplitude:
+            reloc_str = "relocated " * self.relocate_amplitudes
+            plt.colorbar(s, ax=axis, shrink=0.5, label=reloc_str + "amplitude (su)")
+
+
+class TAmpScatter(UnitPlot):
+    kind = "widescatter"
+    width = 2
+
+    def __init__(self, relocate_amplitudes=False, max_amplitude=15):
+        self.relocate_amplitudes = relocate_amplitudes
+        self.max_amplitude = max_amplitude
+
+    def draw(self, axis, sorting_analysis, unit_id):
+        in_unit = sorting_analysis.in_unit(unit_id)
+        t = sorting_analysis.times_seconds(which=in_unit)
+        amps = sorting_analysis.amplitudes(
+            which=in_unit, relocated=self.relocate_amplitudes
+        )
+        axis.scatter(t, amps, c="k", lw=0, s=3)
+        axis.set_xlabel("time (seconds)")
+        reloc_str = "relocated " * self.relocate_amplitudes
+        axis.set_ylabel(reloc_str + "amplitude (su)")
+
+
+# -- waveform plots
+
+
+# -- main routines
+
+
+default_plots = (
+    ACG(),
+    ISIHistogram(),
+    XZScatter(),
+    PCAScatter(),
+    TZScatter(),
+    TZScatter(registered=False),
+    TAmpScatter(),
+    TAmpScatter(relocate_amplitudes=True),
+)
+
+
+def make_unit_summary(
+    sorting_analysis,
+    unit_id,
+    plots=default_plots,
+    max_height=4,
+    figsize=(11, 8.5),
+):
+    plots_by_kind = {}
+    for plot in plots:
+        if plot.kind not in plots_by_kind:
+            plots_by_kind[plot.kind] = []
+        plots_by_kind[plot.kind].append(plot)
+
+    # -- lay out the figure
+    columns = summary_layout(plots_by_kind, max_height=max_height)
+
+    # -- draw the figure
+    width_ratios = [column[0].width for column in columns]
+    figure = plt.figure(figsize=figsize, layout="constrained")
+    subfigures = figure.subfigures(
+        nrows=1, ncols=len(columns), hspace=0.1, width_ratios=width_ratios
+    )
+    all_panels = subfigures.tolist()
+    for column, subfig in zip(columns, subfigures):
+        n_cards = len(column)
+        height_ratios = [card.height for card in column]
+        remaining_height = max_height - sum(height_ratios)
+        if remaining_height > 0:
+            height_ratios.append(remaining_height)
+
+        cardfigs = subfig.subfigures(
+            nrows=n_cards + (remaining_height > 0), ncols=1, height_ratios=height_ratios
+        )
+        cardfigs = np.atleast_1d(cardfigs)
+        all_panels.extend(cardfigs)
+
+        for cardfig, card in zip(cardfigs, column):
+            axes = np.atleast_1d(cardfig.subplots(nrows=len(card.plots), ncols=1))
+            for plot, axis in zip(card.plots, axes):
+                plot.draw(axis, sorting_analysis, unit_id)
+
+    # clean up the panels, or else things get clipped
+    for panel in all_panels:
+        panel.set_facecolor([0, 0, 0, 0])
+        panel.patch.set_facecolor([0, 0, 0, 0])
+
+    return figure
+
+
+def make_all_summaries(
+    sorting_analysis, save_folder, max_height=4, figsize=(11, 8.5), dpi=200
+):
+    pass
+
+
+# -- utilities
+
+
+def correlogram(times_a, times_b=None, max_lag=50):
+    lags = np.arange(-max_lag, max_lag + 1)
+    ccg = np.zeros(len(lags), dtype=int)
+
+    times_a = np.sort(times_a)
+    auto = times_b is None
+    if auto:
+        times_b = times_a
+    else:
+        times_b = np.sort(times_b)
+
+    for i, lag in enumerate(lags):
+        # count pairs where a spike in times_a lands exactly lag samples
+        # after a spike in times_b
+        insertion_inds = np.searchsorted(times_a, times_b + lag)
+        found = insertion_inds < times_a.size
+        ccg[i] = np.sum(times_a[insertion_inds[found]] == times_b[found] + lag)
+
+    if auto:
+        ccg[lags == 0] = 0
+
+    return lags, ccg
+
+
+# -- plotting helpers
+
+
+Card = namedtuple("Card", ["kind", "width", "height", "plots"])
+
+
+def summary_layout(plots_by_kind, max_height=4):
+    # break plots into groups ("cards") by kind
+    cards = []
+    for kind, plots in plots_by_kind.items():
+        width = max(p.width for p in plots)
+        card_plots = []
+        for plot in plots:
+            if sum(p.height for p in card_plots) + plot.height <= max_height:
+                card_plots.append(plot)
+            else:
+                cards.append(
+                    Card(kind, width, sum(p.height for p in card_plots), card_plots)
+                )
+                card_plots = [plot]
+        if card_plots:
+            cards.append(Card(kind, width, sum(p.height for p in card_plots), card_plots))
+    cards = sorted(cards, key=lambda card: card.width)
+
+    # flow the same-width cards over columns
+    columns = [[]]
+    cur_width = cards[0].width
+    for card in cards:
+        if card.width != cur_width:
+            columns.append([card])
+            cur_width = card.width
+            continue
+
+        if sum(c.height for c in columns[-1]) + card.height <= max_height:
+            columns[-1].append(card)
+        else:
+            columns.append([card])
+
+    return columns
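
Usage note (reviewer sketch, not part of the patch): the new analysis and vis modules are meant to be driven together roughly as below. This is a minimal sketch assuming a preprocessed recording folder and a peeler output HDF5 with its models directory already exist on disk; the paths, the recording loader, and the unit id are placeholders, and motion_est can be a dredge MotionEstimate when drift registration was run.

    import spikeinterface.core as sc
    from dartsort.util.analysis import DARTsortAnalysis
    from dartsort.vis.unit import make_unit_summary

    # placeholder paths -- substitute real subtraction/matching outputs
    recording = sc.load_extractor("preprocessed_rec")
    analysis = DARTsortAnalysis.from_peeling_paths(
        recording,
        "matching0.h5",   # peeler output HDF5
        model_dir=None,   # defaults to matching0_models/ next to the HDF5
        motion_est=None,  # or a dredge MotionEstimate for drift-aware plots
    )
    fig = make_unit_summary(analysis, unit_id=0)
    fig.savefig("unit0000.png", dpi=200)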
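
A small self-contained check of the correlogram helper's counting rule (made-up spike times), of the kind that could become a unit test; it assumes the pair-counting convention used in correlogram above, where a coincidence at lag means a spike in times_a falls exactly lag samples after a spike in times_b.

    import numpy as np
    from dartsort.vis.unit import correlogram

    # two spikes 5 samples apart: one pair at lag +5, one at lag -5,
    # and the zero-lag bin is cleared for autocorrelograms
    times = np.array([100, 105, 300, 1000])
    lags, acg = correlogram(times, max_lag=10)
    assert acg[lags == 5] == 1
    assert acg[lags == -5] == 1
    assert acg[lags == 0] == 0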