From 07488bb0d43bd53646993d040290262a2e453af7 Mon Sep 17 00:00:00 2001
From: Charlie Windolf
Date: Thu, 14 Dec 2023 12:06:33 -0500
Subject: [PATCH] Debug unit plots and fix an alignment issue

---
 src/dartsort/templates/get_templates.py |  10 +-
 src/dartsort/templates/templates.py     |   9 +
 src/dartsort/util/analysis.py           | 116 ++++++---
 src/dartsort/vis/__init__.py            |   2 +
 src/dartsort/vis/unit.py                | 319 ++++++++++++++++++++----
 src/dartsort/vis/waveforms.py           |  37 ++-
 6 files changed, 395 insertions(+), 98 deletions(-)

diff --git a/src/dartsort/templates/get_templates.py b/src/dartsort/templates/get_templates.py
index 58b75d22..ac2c1999 100644
--- a/src/dartsort/templates/get_templates.py
+++ b/src/dartsort/templates/get_templates.py
@@ -122,6 +122,7 @@ def get_templates(
         sorting, templates = realign_sorting(
             sorting,
             raw_results["raw_templates"],
+            raw_results["snrs_by_channel"],
             max_shift=realign_max_sample_shift,
             trough_offset_samples=trough_offset_samples,
             recording_length_samples=recording.get_num_samples(),
@@ -234,6 +235,7 @@ def get_raw_templates(
 def realign_sorting(
     sorting,
     templates,
+    snrs_by_channel,
     max_shift=20,
     trough_offset_samples=42,
     recording_length_samples=None,
@@ -244,7 +246,7 @@
         return sorting, templates
 
     # find template peak time
-    template_maxchans = templates.ptp(1).argmax(1)
+    template_maxchans = snrs_by_channel.argmax(1)
     template_maxchan_traces = templates[np.arange(n), :, template_maxchans]
     template_peak_times = np.abs(template_maxchan_traces).argmax(1)
 
@@ -254,10 +256,10 @@
 
     # create aligned spike train
     new_times = sorting.times_samples + template_shifts[sorting.labels]
-    labels = sorting.labels
+    labels = sorting.labels.copy()
     if recording_length_samples is not None:
         highlim = recording_length_samples - (t - trough_offset_samples)
-        labels[(new_times < trough_offset_samples) & (new_times >= highlim)] = -1
+        labels[(new_times < trough_offset_samples) | (new_times > highlim)] = -1
     aligned_sorting = replace(sorting, labels=labels, times_samples=new_times)
 
     # trim templates
@@ -567,7 +569,7 @@ def _template_job(unit_ids):
     order = np.argsort(in_units)
     in_units = in_units[order]
     labels = labels[order]
-    
+
     # read waveforms for all units
     times = p.sorting.times_samples[in_units]
     valid = np.flatnonzero(
diff --git a/src/dartsort/templates/templates.py b/src/dartsort/templates/templates.py
index 75e62197..3a9284ea 100644
--- a/src/dartsort/templates/templates.py
+++ b/src/dartsort/templates/templates.py
@@ -29,6 +29,8 @@ class TemplateData:
     registered_geom: Optional[np.ndarray] = None
     registered_template_depths_um: Optional[np.ndarray] = None
     localization_radius_um: float = 100.0
+    trough_offset_samples: int = 42
+    spike_length_samples: int = 121
 
     @classmethod
     def from_npz(cls, npz_path):
@@ -77,6 +79,9 @@ def coarsen(self, with_locs=True):
             spike_counts=spike_counts,
             registered_template_depths_um=registered_template_depths_um,
         )
+
+    def unit_templates(self, unit_id):
+        return self.templates[self.unit_ids == unit_id]
 
     @classmethod
     def from_config(
@@ -196,12 +201,16 @@
                 kwargs["registered_geom"],
                 registered_template_depths_um,
                 localization_radius_um=template_config.registered_template_localization_radius_um,
+                trough_offset_samples=template_config.trough_offset_samples,
+                spike_length_samples=template_config.spike_length_samples,
             )
         else:
             obj = cls(
                 results["templates"],
                 unit_ids,
                 spike_counts,
+                trough_offset_samples=template_config.trough_offset_samples,
+                spike_length_samples=template_config.spike_length_samples,
             )
         if save_folder is not None:
diff --git a/src/dartsort/util/analysis.py b/src/dartsort/util/analysis.py
index 46cad9fb..e51b8342 100644
--- a/src/dartsort/util/analysis.py
+++ b/src/dartsort/util/analysis.py
@@ -23,7 +23,7 @@
 from ..transform import WaveformPipeline
 from .data_util import DARTsortSorting
 from .drift_util import (get_spike_pitch_shifts,
-                         get_waveforms_on_static_channels)
+                         get_waveforms_on_static_channels, registered_average)
 from .spikeio import read_waveforms_channel_index
 from .waveform_util import make_channel_index
 
@@ -84,16 +84,20 @@ def from_peeling_paths(
         model_dir=None,
         motion_est=None,
         template_data_npz="template_data.npz",
+        template_data=None,
+        sorting=None,
         **kwargs,
     ):
         hdf5_path = Path(hdf5_path)
         if model_dir is None:
             model_dir = hdf5_path.parent / f"{hdf5_path.stem}_models"
             assert model_dir.exists()
-        sorting = DARTsortSorting.from_peeling_hdf5(
-            hdf5_path, load_simple_features=False
-        )
-        template_data = TemplateData.from_npz(Path(model_dir) / template_data_npz)
+        if sorting is None:
+            sorting = DARTsortSorting.from_peeling_hdf5(
+                hdf5_path, load_simple_features=False
+            )
+        if template_data is None:
+            template_data = TemplateData.from_npz(Path(model_dir) / template_data_npz)
         pipeline = torch.load(model_dir / "featurization_pipeline.pt")
         return cls(
             sorting, hdf5_path, recording, template_data, pipeline, motion_est, **kwargs
@@ -132,12 +136,14 @@ def clear_cache(self):
         self._geom = None
         self._tpca_features = None
         self._sklearn_tpca = None
+        self._unit_ids = None
+        self._spike_counts = None
         self._feats = {}
 
     def __getstate__(self):
         # remove cached stuff before pickling
-        return {k: v for k, v in self.__dict__.items() if not k.startswith("_")}
-
+        return {k: v if not k.startswith("_") else None for k, v in self.__dict__.items()}
+
     # cache gizmos
 
     @property
@@ -190,9 +196,21 @@ def sklearn_tpca(self):
 
     # spike train helpers
 
+    @property
     def unit_ids(self):
-        allunits = np.unique(self.sorting.labels)
-        return allunits[allunits >= 0]
+        if self._unit_ids is None:
+            allunits, counts = np.unique(self.sorting.labels, return_counts=True)
+            self._unit_ids = allunits[allunits >= 0]
+            self._spike_counts = counts[allunits >= 0]
+        return self._unit_ids
+
+    @property
+    def spike_counts(self):
+        if self._spike_counts is None:
+            allunits, counts = np.unique(self.sorting.labels, return_counts=True)
+            self._unit_ids = allunits[allunits >= 0]
+            self._spike_counts = counts[allunits >= 0]
+        return self._spike_counts
 
     def in_unit(self, unit_id):
         return np.flatnonzero(np.isin(self.sorting.labels, unit_id))
@@ -210,11 +228,11 @@ def x(self, which=slice(None)):
     def z(self, which=slice(None), registered=True):
         z = self.xyza[which, 2]
         if registered and self.motion_est is not None:
-            z = self.motion_est.correct_s(self.sorting.times_seconds, z)
+            z = self.motion_est.correct_s(self.times_seconds(which=which), z)
         return z
 
     def times_seconds(self, which=slice(None)):
-        return self.sorting.times_seconds[which]
+        return self.recording._recording_segments[0].sample_index_to_time(self.times_samples(which=which))
 
     def times_samples(self, which=slice(None)):
         return self.sorting.times_samples[which]
@@ -225,14 +243,15 @@ def amplitudes(self, which=slice(None), relocated=False):
             reloc_amp_vecs = relocate.relocated_waveforms_on_static_channels(
                 self.amplitude_vectors[which],
-                main_channels=self.channels[which],
+                main_channels=self.sorting.channels[which],
                 channel_index=self.channel_index,
                 xyza_from=self.xyza[which],
                 z_to=self.z(which),
                 geom=self.geom,
                 registered_geom=self.template_data.registered_geom,
+                target_channels=slice(None),
             )
-            return reloc_amp_vecs.max(1)
+            return np.nanmax(reloc_amp_vecs, axis=1)
 
     def tpca_features(self, which=slice(None)):
         if self._tpca_features is None:
@@ -313,25 +332,26 @@ def unit_tpca_waveforms(
             waveforms,
             self.channel_index,
             show_radius_um=show_radius_um,
-            relocate=relocated,
+            relocated=relocated,
         )
 
     def unit_pca_features(
         self, unit_id, relocated=True, rank=2, pca_radius_um=75, random_seed=0
     ):
-        waveforms = self.unit_tpca_waveforms(
+        waveforms, max_chan, show_geom, show_channel_index = self.unit_tpca_waveforms(
             unit_id,
             relocated=relocated,
             show_radius_um=pca_radius_um,
            random_seed=random_seed,
         )
+        waveforms = waveforms.reshape(len(waveforms), -1)
 
         no_nan = np.flatnonzero(~np.isnan(waveforms).any(axis=1))
         features = np.full((len(waveforms), rank), np.nan, dtype=waveforms.dtype)
-        if no_nan.size < max(self.min_cluster_size, self.n_pca_features):
+        if no_nan.size < rank:
            return features
 
-        pca = PCA(self.n_pca_features, random_state=random_seed, whiten=True)
+        pca = PCA(rank, random_state=random_seed, whiten=True)
         features[no_nan] = pca.fit_transform(waveforms[no_nan])
         return features
 
@@ -342,22 +362,46 @@ def unit_shift_or_relocate_channels(
         waveforms,
         load_channel_index,
         show_radius_um=75,
-        relocate=False,
+        relocated=False,
     ):
         geom = self.recording.get_channel_locations()
-        show_geom = self.recording.registered_geom
+        show_geom = self.template_data.registered_geom
         if show_geom is None:
-            show_geom is geom
-        temp = self.coarse_templates.templates[
-            self.coarse_templates.unit_ids == unit_id
-        ]
-        assert temp.shape[0] == 1
-        max_chan = temp.squeeze().ptp(0).argmax()
+            show_geom = geom
+        temp = self.coarse_template_data.unit_templates(unit_id)
+        n_pitches_shift = None
+        if temp.shape[0]:
+            max_chan = temp.squeeze().ptp(0).argmax()
+        else:
+            amps = waveforms.ptp(1)
+            if self.shifting:
+                n_pitches_shift = get_spike_pitch_shifts(
+                    self.z(which=which, registered=False),
+                    geom=geom,
+                    registered_depths_um=self.z(which=which, registered=True),
+                    times_s=self.times_seconds(which=which),
+                    motion_est=self.motion_est,
+                )
+                amp_template = registered_average(
+                    amps[:, None, :],
+                    n_pitches_shift,
+                    geom,
+                    show_geom,
+                    main_channels=self.sorting.channels[which],
+                    channel_index=load_channel_index,
+                )[0]
+            else:
+                amp_template = np.nanmean(amps, axis=0)
+            max_chan = np.nanargmax(amp_template)
         show_channel_index = make_channel_index(show_geom, show_radius_um)
         show_chans = show_channel_index[max_chan]
+        show_chans = show_chans[show_chans < len(show_geom)]
+
+        if not self.shifting:
+            return waveforms, max_chan, show_geom, show_channel_index
 
-        if relocate:
-            return relocate.relocated_waveforms_on_static_channels(
+        if relocated:
+            waveforms = relocate.relocated_waveforms_on_static_channels(
                 waveforms,
                 main_channels=self.sorting.channels[which],
                 channel_index=load_channel_index,
@@ -367,14 +411,16 @@
                 geom=geom,
                 registered_geom=show_geom,
             )
-
-        n_pitches_shift = get_spike_pitch_shifts(
-            self.z(which=which, registered=False),
-            geom=geom,
-            registered_depths_um=self.z(which=which, registered=True),
-            times_s=self.times_seconds(which=which),
-            motion_est=self.motion_est,
-        )
+            return waveforms, max_chan, show_geom, show_channel_index
+
+        if n_pitches_shift is None:
+            n_pitches_shift = get_spike_pitch_shifts(
+                self.z(which=which, registered=False),
+                geom=geom,
+                registered_depths_um=self.z(which=which, registered=True),
+                times_s=self.times_seconds(which=which),
+                motion_est=self.motion_est,
+            )
         waveforms = get_waveforms_on_static_channels(
             waveforms,
diff --git a/src/dartsort/vis/__init__.py b/src/dartsort/vis/__init__.py
index 5010da17..693f438e 100644
--- a/src/dartsort/vis/__init__.py
+++ b/src/dartsort/vis/__init__.py
@@ -1 +1,3 @@
 from .scatterplots import *
+from .unit import *
+from .waveforms import *
diff --git a/src/dartsort/vis/unit.py b/src/dartsort/vis/unit.py
index de236905..9f46e369 100644
--- a/src/dartsort/vis/unit.py
+++ b/src/dartsort/vis/unit.py
@@ -11,10 +11,12 @@
 from pathlib import Path
 
 import matplotlib.pyplot as plt
+from matplotlib.legend_handler import HandlerTuple
+from matplotlib.figure import Figure
 import numpy as np
 from tqdm.auto import tqdm
 
-from ..multiprocessing_util import get_pool
+from ..util.multiprocessing_util import get_pool
 from .waveforms import geomplot
 
 # -- main class. see fn make_unit_summary below to make lots of UnitPlots.
@@ -29,6 +31,30 @@ def draw(self, axis, sorting_analysis, unit_id):
         raise NotImplementedError
 
 
+class TextInfo(UnitPlot):
+    kind = "text"
+    height = 0.5
+
+    def draw(self, axis, sorting_analysis, unit_id):
+        axis.axis("off")
+        msg = f"unit {unit_id}\n"
+        msg += f"feature source: {sorting_analysis.hdf5_path.name}\n"
+
+        nspikes = sorting_analysis.spike_counts[
+            sorting_analysis.unit_ids == unit_id
+        ].sum()
+        msg += f"n spikes: {nspikes}\n"
+
+        temps = sorting_analysis.template_data.unit_templates(unit_id)
+        if temps.size:
+            ptp = temps.ptp(1).max(1).mean()
+            msg += f"mean superres maxptp: {ptp:0.1f}su\n"
+        else:
+            msg += "no template (too few spikes)\n"
+
+        # draw the text only once the full message has been assembled
+        axis.text(0, 0, msg, fontsize=6.5)
+
+
 # -- small summary plots
 
 
@@ -44,7 +70,7 @@ def draw(self, axis, sorting_analysis, unit_id):
             which=sorting_analysis.in_unit(unit_id)
         )
         lags, acg = correlogram(times_samples, max_lag=self.max_lag)
-        axis.bar(lags[:-1], acg)
+        axis.bar(lags, acg)
         axis.set_xlabel("lag (samples)")
         axis.set_ylabel("acg")
 
@@ -53,43 +79,60 @@ class ISIHistogram(UnitPlot):
     kind = "histogram"
     height = 0.75
 
-    def __init__(self, bin_ms=0.1):
+    def __init__(self, bin_ms=0.1, max_ms=5):
         self.bin_ms = bin_ms
+        self.max_ms = max_ms
 
     def draw(self, axis, sorting_analysis, unit_id):
-        times_ms = (
-            sorting_analysis.times_seconds(which=sorting_analysis.in_unit(unit_id))
-            / 1000
+        times_s = sorting_analysis.times_seconds(
+            which=sorting_analysis.in_unit(unit_id)
         )
+        dt_ms = np.diff(times_s) * 1000
         bin_edges = np.arange(
-            np.floor(times_ms.min()),
-            np.floor(times_ms.max()),
+            0,
+            self.max_ms + self.bin_ms,
             self.bin_ms,
         )
-        axis.hist(times_ms, bins=bin_edges)
+        axis.hist(dt_ms, bin_edges)
         axis.set_xlabel("isi (ms)")
-        axis.set_ylabel("count")
+        axis.set_ylabel(f"count (out of {dt_ms.size} total isis)")
 
 
 class XZScatter(UnitPlot):
     kind = "scatter"
 
-    def __init__(self, relocate_amplitudes=False, registered=True, max_amplitude=15):
+    def __init__(
+        self,
+        relocate_amplitudes=False,
+        registered=True,
+        max_amplitude=15,
+        probe_margin_um=100,
+    ):
         self.relocate_amplitudes = relocate_amplitudes
         self.registered = registered
         self.max_amplitude = max_amplitude
+        self.probe_margin_um = probe_margin_um
 
     def draw(self, axis, sorting_analysis, unit_id):
         in_unit = sorting_analysis.in_unit(unit_id)
         x = sorting_analysis.x(which=in_unit)
         z = sorting_analysis.z(which=in_unit, registered=self.registered)
+        geomx, geomz = sorting_analysis.geom.T
+        pad = self.probe_margin_um
+        valid = x == np.clip(x, geomx.min() - pad, geomx.max() + pad)
+        valid &= z == np.clip(z, geomz.min() - pad, geomz.max() + pad)
         amps = sorting_analysis.amplitudes(
-            which=in_unit, relocated=self.relocate_amplitudes
+            which=in_unit[valid], relocated=self.relocate_amplitudes
         )
-        s = axis.scatter(x, z, c=np.minimum(amps, self.max_amplitude), lw=0, s=3)
-        self.set_xlabel("x (um)")
+        s = axis.scatter(
+            x[valid], z[valid], c=np.minimum(amps, self.max_amplitude), lw=0, s=3
+        )
+        axis.set_xlabel("x (um)")
         reg_str = "registered " * self.registered
-        self.set_ylabel(reg_str + "z (um)")
+        axis.set_ylabel(reg_str + "z (um)")
         reloc_str = "relocated " * self.relocate_amplitudes
         plt.colorbar(s, ax=axis, shrink=0.5, label=reloc_str + "amplitude (su)")
 
@@ -104,16 +147,16 @@ def __init__(self, relocate_amplitudes=False, relocated=True, max_amplitude=15):
 
     def draw(self, axis, sorting_analysis, unit_id):
         in_unit = sorting_analysis.in_unit(unit_id)
-        loadings = sorting_analysis.pca_features(
-            which=in_unit, relocated=self.relocated
+        loadings = sorting_analysis.unit_pca_features(
+            unit_id=unit_id, relocated=self.relocated
         )
         amps = sorting_analysis.amplitudes(
             which=in_unit, relocated=self.relocate_amplitudes
         )
         s = axis.scatter(*loadings.T, c=np.minimum(amps, self.max_amplitude), lw=0, s=3)
         reloc_str = "relocated " * self.relocated
-        self.set_xlabel(reloc_str + "per-unit PC1 (um)")
-        self.set_ylabel(reloc_str + "per-unit PC2 (um)")
+        axis.set_xlabel(reloc_str + "per-unit PC1 (um)")
+        axis.set_ylabel(reloc_str + "per-unit PC2 (um)")
         reloc_amp_str = "relocated " * self.relocate_amplitudes
         plt.colorbar(s, ax=axis, shrink=0.5, label=reloc_amp_str + "amplitude (su)")
 
@@ -125,22 +168,34 @@ class TZScatter(UnitPlot):
     kind = "widescatter"
     width = 2
 
-    def __init__(self, relocate_amplitudes=False, registered=True, max_amplitude=15):
+    def __init__(
+        self,
+        relocate_amplitudes=False,
+        registered=True,
+        max_amplitude=15,
+        probe_margin_um=100,
+    ):
         self.relocate_amplitudes = relocate_amplitudes
         self.registered = registered
         self.max_amplitude = max_amplitude
+        self.probe_margin_um = probe_margin_um
 
     def draw(self, axis, sorting_analysis, unit_id):
         in_unit = sorting_analysis.in_unit(unit_id)
         t = sorting_analysis.times_seconds(which=in_unit)
         z = sorting_analysis.z(which=in_unit, registered=self.registered)
+        geomx, geomz = sorting_analysis.geom.T
+        pad = self.probe_margin_um
+        valid = z == np.clip(z, geomz.min() - pad, geomz.max() + pad)
         amps = sorting_analysis.amplitudes(
-            which=in_unit, relocated=self.relocate_amplitudes
+            which=in_unit[valid], relocated=self.relocate_amplitudes
+        )
+        s = axis.scatter(
+            t[valid], z[valid], c=np.minimum(amps, self.max_amplitude), lw=0, s=3
         )
-        s = axis.scatter(t, z, c=np.minimum(amps, self.max_amplitude), lw=0, s=3)
-        self.set_xlabel("time (seconds)")
+        axis.set_xlabel("time (seconds)")
         reg_str = "registered " * self.registered
-        self.set_ylabel(reg_str + "z (um)")
+        axis.set_ylabel(reg_str + "z (um)")
         reloc_str = "relocated " * self.relocate_amplitudes
         plt.colorbar(s, ax=axis, shrink=0.5, label=reloc_str + "amplitude (su)")
 
@@ -164,16 +219,16 @@ def __init__(
     def draw(self, axis, sorting_analysis, unit_id):
         in_unit = sorting_analysis.in_unit(unit_id)
         t = sorting_analysis.times_seconds(which=in_unit)
-        z = sorting_analysis.named_feature(self.feat_name, which=in_unit)
+        feat = sorting_analysis.named_feature(self.feat_name, which=in_unit)
         c = None
         if self.color_by_amplitude:
             amps = sorting_analysis.amplitudes(
                 which=in_unit, relocated=self.relocate_amplitudes
             )
             c = np.minimum(amps, self.max_amplitude)
-        s = axis.scatter(t, z, c=c, lw=0, s=3)
-        self.set_xlabel("time (seconds)")
-        self.set_ylabel(self.feat_name)
+        s = axis.scatter(t, feat, c=c, lw=0, s=3)
+        axis.set_xlabel("time (seconds)")
+        axis.set_ylabel(self.feat_name)
         if self.color_by_amplitude:
             reloc_str = "relocated " * self.relocate_amplitudes
             plt.colorbar(s, ax=axis, shrink=0.5, label=reloc_str + "amplitude (su)")
@@ -194,9 +249,9 @@ def draw(self, axis, sorting_analysis, unit_id):
             which=in_unit, relocated=self.relocate_amplitudes
         )
         axis.scatter(t, amps, c="k", lw=0, s=3)
-        self.set_xlabel("time (seconds)")
+        axis.set_xlabel("time (seconds)")
         reloc_str = "relocated " * self.relocate_amplitudes
-        self.set_ylabel(reloc_str + "amplitude (su)")
+        axis.set_ylabel(reloc_str + "amplitude (su)")
 
 
 # -- waveform plots
@@ -206,15 +261,23 @@ class WaveformPlot(UnitPlot):
     kind = "waveform"
     width = 2
     height = 2
+    title = "waveforms"
 
     def __init__(
         self,
         trough_offset_samples=42,
         spike_length_samples=121,
         count=250,
-        show_radius_um=75,
+        show_radius_um=50,
         relocated=False,
         color="k",
+        alpha=0.1,
+        show_superres_templates=True,
+        superres_template_cmap=plt.cm.winter,
+        show_template=True,
+        template_color="orange",
+        max_abs_template_scale=1.35,
+        legend=True,
     ):
         self.count = count
         self.show_radius_um = show_radius_um
@@ -222,13 +285,48 @@ def __init__(
         self.color = color
         self.trough_offset_samples = trough_offset_samples
         self.spike_length_samples = spike_length_samples
+        self.alpha = alpha
+        self.show_template = show_template
+        self.template_color = template_color
+        self.show_superres_templates = show_superres_templates
+        self.superres_template_cmap = superres_template_cmap
+        self.legend = legend
+        self.max_abs_template_scale = max_abs_template_scale
 
     def get_waveforms(self, sorting_analysis, unit_id):
         raise NotImplementedError
 
     def draw(self, axis, sorting_analysis, unit_id):
         waveforms, max_chan, geom, ci = self.get_waveforms(sorting_analysis, unit_id)
-        geomplot(
+
+        max_abs_amp = None
+        show_template = self.show_template
+        if show_template:
+            templates = sorting_analysis.coarse_template_data.unit_templates(unit_id)
+            show_template = bool(templates.size)
+        if show_template:
+            templates = trim_waveforms(
+                templates,
+                old_offset=sorting_analysis.coarse_template_data.trough_offset_samples,
+                new_offset=self.trough_offset_samples,
+                new_length=self.spike_length_samples,
+            )
+            max_abs_amp = self.max_abs_template_scale * np.abs(templates).max()
+        show_superres_templates = self.show_superres_templates
+        if show_superres_templates:
+            suptemplates = sorting_analysis.template_data.unit_templates(unit_id)
+            show_superres_templates = bool(suptemplates.size)
+        if show_superres_templates:
+            suptemplates = trim_waveforms(
+                suptemplates,
+                old_offset=sorting_analysis.template_data.trough_offset_samples,
+                new_offset=self.trough_offset_samples,
+                new_length=self.spike_length_samples,
+            )
+            show_superres_templates = suptemplates.shape[0] > 1
+            max_abs_amp = self.max_abs_template_scale * np.abs(suptemplates).max()
+
+        ls = geomplot(
             waveforms,
             max_channels=np.full(len(waveforms), max_chan),
             channel_index=ci,
@@ -239,10 +337,73 @@ def draw(self, axis, sorting_analysis, unit_id):
             msbar=False,
             zlim="tight",
             color=self.color,
+            alpha=self.alpha,
+            max_abs_amp=max_abs_amp,
+            lw=1,
         )
+        handles = [ls[0]]
+        labels = ["waveforms"]
+
+        if show_superres_templates:
+            showchans = ci[max_chan]
+            showchans = showchans[showchans < len(geom)]
+            colors = self.superres_template_cmap(
+                np.linspace(0, 1, num=suptemplates.shape[0])
+            )
+            suphandles = []
+            for i in range(suptemplates.shape[0]):
+                ls = geomplot(
+                    suptemplates[i][:, showchans],
+                    geom=geom[showchans],
+                    ax=axis,
+                    show_zero=False,
+                    zlim="tight",
+                    color=colors[i],
+                    alpha=1,
+                    max_abs_amp=max_abs_amp,
+                    lw=1,
+                )
+                suphandles.append(ls[0])
+            handles.append(tuple(suphandles))
+            labels.append("superres templates")
+
+        if show_template:
+            showchans = ci[max_chan]
+            showchans = showchans[showchans < len(geom)]
+            ls = geomplot(
+                templates[:, :, showchans],
+                geom=geom[showchans],
+                ax=axis,
+                show_zero=False,
+                zlim="tight",
+                color=self.template_color,
+                alpha=1,
+                max_abs_amp=max_abs_amp,
+                lw=1,
+            )
+            handles.append(ls[0])
+            labels.append("mean of superres templates")
+
+        reloc_str = "relocated " * self.relocated
+        shift_str = "shifted " * sorting_analysis.shifting
+        axis.set_title(reloc_str + shift_str + self.title)
+        reg_str = "registered " * sorting_analysis.shifting
+        axis.set_ylabel(reg_str + "depth (um)")
+        axis.set_xticks([])
+
+        if self.legend:
+            axis.legend(
+                handles,
+                labels,
+                handler_map={tuple: HandlerTuple(ndivide=None)},
+                fancybox=False,
+                loc="upper left",
+            )
 
 
 class RawWaveformPlot(WaveformPlot):
+    title = "raw waveforms"
+
     def get_waveforms(self, sorting_analysis, unit_id):
         return sorting_analysis.unit_raw_waveforms(
             unit_id,
@@ -255,6 +416,8 @@ def get_waveforms(self, sorting_analysis, unit_id):
 
 
 class TPCAWaveformPlot(WaveformPlot):
+    title = "collision-cleaned tpca waveforms"
+
     def get_waveforms(self, sorting_analysis, unit_id):
         return sorting_analysis.unit_tpca_waveforms(
             unit_id,
@@ -267,6 +430,7 @@ def get_waveforms(self, sorting_analysis, unit_id):
 # -- main routines
 
 default_plots = (
+    TextInfo(),
     ACG(),
     ISIHistogram(),
     XZScatter(),
@@ -276,7 +440,7 @@
     TAmpScatter(),
     TAmpScatter(relocate_amplitudes=True),
     RawWaveformPlot(),
-    TPCAWaveformPlot(),
+    TPCAWaveformPlot(relocated=True),
 )
 
 
@@ -286,13 +450,15 @@ def make_unit_summary(
     plots=default_plots,
     max_height=4,
     figsize=(11, 8.5),
+    figure=None,
 ):
     # -- lay out the figure
     columns = summary_layout(plots, max_height=max_height)
 
     # -- draw the figure
     width_ratios = [column[0].width for column in columns]
-    figure = plt.figure(figsize=figsize, layout="constrained")
+    if figure is None:
+        figure = plt.figure(figsize=figsize, layout="constrained")
     subfigures = figure.subfigures(
         nrows=1, ncols=len(columns), hspace=0.1, width_ratios=width_ratios
     )
@@ -302,15 +468,17 @@ def make_unit_summary(
         height_ratios = [card.height for card in column]
         remaining_height = max_height - sum(height_ratios)
         if remaining_height > 0:
-            height_ratios.append([remaining_height])
+            height_ratios.append(remaining_height)
         cardfigs = subfig.subfigures(
             nrows=n_cards + (remaining_height > 0), ncols=1, height_ratios=height_ratios
         )
+        cardfigs = np.atleast_1d(cardfigs)
         all_panels.extend(cardfigs)
 
         for cardfig, card in zip(cardfigs, column):
             axes = cardfig.subplots(nrows=len(card.plots), ncols=1)
+            axes = np.atleast_1d(axes)
             for plot, axis in zip(card.plots, axes):
                 plot.draw(axis, sorting_analysis, unit_id)
 
@@ -332,6 +500,7 @@ def make_all_summaries(
     image_ext="png",
     n_jobs=0,
     show_progress=True,
+    overwrite=False,
 ):
     save_folder = Path(save_folder)
     save_folder.mkdir(exist_ok=True)
@@ -349,16 +518,19 @@ def make_all_summaries(
             dpi,
             save_folder,
             image_ext,
+            overwrite,
         ),
     ) as pool:
         jobs = sorting_analysis.unit_ids
+        results = pool.map(_summary_job, jobs)
         if show_progress:
-            jobs = tqdm(
-                jobs,
+            results = tqdm(
+                results,
                 desc="Unit summaries",
                 smoothing=0,
+                total=len(jobs),
             )
-        for res in pool.map(_summary_job, jobs):
+        for res in results:
             pass
 
 
@@ -377,8 +549,10 @@ def correlogram(times_a, times_b=None, max_lag=50):
         times_b = np.sort(times_b)
 
     for i, lag in enumerate(lags):
-        insertion_inds = np.searchsorted(times_a, times_b + lag)
-        ccg[i] = np.sum(times_a[insertion_inds] == times_b)
+        lagged_b = times_b + lag
+        insertion_inds = np.searchsorted(times_a, lagged_b)
+        found = insertion_inds < len(times_a)
+        ccg[i] = np.sum(times_a[insertion_inds[found]] == lagged_b[found])
 
     if auto:
         ccg[lags == 0] = 0
@@ -386,6 +560,15 @@ def correlogram(times_a, times_b=None, max_lag=50):
     return lags, ccg
 
 
+def trim_waveforms(waveforms, old_offset=42, new_offset=42, new_length=121):
+    if waveforms.shape[1] == new_length and old_offset == new_offset:
+        return waveforms
+
+    start = old_offset - new_offset
+    end = start + new_length
+    return waveforms[:, start:end]
+
+
 # -- plotting helpers
 
 
@@ -409,11 +592,20 @@ def summary_layout(plots, max_height=4):
             card_plots.append(plot)
         else:
             cards.append(
-                Card(plots[0].kind, width, sum(p.height for p in card_plots))
+                Card(
+                    plots[0].kind,
+                    width,
+                    sum(p.height for p in card_plots),
+                    card_plots,
+                )
             )
             card_plots = []
     if card_plots:
-        cards.append(Card(plots[0].kind, width, sum(p.height for p in card_plots)))
+        cards.append(
+            Card(
+                plots[0].kind, width, sum(p.height for p in card_plots), card_plots
+            )
+        )
     cards = sorted(cards, key=lambda card: card.width)
 
     # flow the same-width cards over columns
@@ -438,7 +630,15 @@
 
 class SummaryJobContext:
     def __init__(
-        self, sorting_analysis, plots, max_height, figsize, dpi, save_folder, image_ext
+        self,
+        sorting_analysis,
+        plots,
+        max_height,
+        figsize,
+        dpi,
+        save_folder,
+        image_ext,
+        overwrite,
     ):
         self.sorting_analysis = sorting_analysis
         self.plots = plots
@@ -447,6 +647,7 @@ def __init__(
         self.dpi = dpi
         self.save_folder = save_folder
         self.image_ext = image_ext
+        self.overwrite = overwrite
 
 
 _summary_job_context = None
@@ -458,15 +659,33 @@ def _summary_init(*args):
 
 
 def _summary_job(unit_id):
-    fig = make_unit_summary(
-        _summary_job_context.orting_analysis,
+    # handle resuming/overwriting
+    ext = _summary_job_context.image_ext
+    tmp_out = _summary_job_context.save_folder / f"tmp_unit{unit_id:04d}.{ext}"
+    final_out = _summary_job_context.save_folder / f"unit{unit_id:04d}.{ext}"
+    if tmp_out.exists():
+        tmp_out.unlink()
+    if not _summary_job_context.overwrite and final_out.exists():
+        return
+    if _summary_job_context.overwrite and final_out.exists():
+        final_out.unlink()
+
+    fig = plt.figure(
+        figsize=_summary_job_context.figsize,
+        layout="constrained",
+        # dpi=_summary_job_context.dpi,
+    )
+    make_unit_summary(
+        _summary_job_context.sorting_analysis,
         unit_id,
         plots=_summary_job_context.plots,
         max_height=_summary_job_context.max_height,
         figsize=_summary_job_context.figsize,
+        figure=fig,
     )
-    ext = _summary_job_context.image_ext
-    fig.savefig(
-        _summary_job_context.save_folder / f"unit{unit_id:04d}.{ext}",
-        dpi=_summary_job_context.dpi,
-    )
+
+    # save atomically: write to a temp file, then rename, so that resuming
+    # never mistakes a half-written image file for a finished summary
+    fig.savefig(tmp_out, dpi=_summary_job_context.dpi)
+    tmp_out.rename(final_out)
+    plt.close(fig)
diff --git a/src/dartsort/vis/waveforms.py b/src/dartsort/vis/waveforms.py
index 045ad5a0..80d4658d 100644
--- a/src/dartsort/vis/waveforms.py
+++ b/src/dartsort/vis/waveforms.py
@@ -19,6 +19,8 @@ def geomplot(
     xlim_factor=1,
     subar=False,
     msbar=False,
+    bar_color="k",
+    bar_background="w",
     zlim="tight",
     **plot_kwargs,
 ):
@@ -27,6 +29,9 @@ def geomplot(
     ax = ax or plt.gca()
 
     # -- validate shapes
+    if waveforms.ndim == 2:
+        waveforms = waveforms[None]
+    assert waveforms.ndim == 3
     if max_channels is None and channel_index is None:
         max_channels = np.zeros(waveforms.shape[0], dtype=int)
         channel_index = (
@@ -34,15 +39,11 @@ def geomplot(
             * np.ones(geom.shape[0], dtype=int)[:, None]
         )
     max_channels = np.atleast_1d(max_channels)
-    if waveforms.ndim == 2:
-        waveforms = waveforms[None]
-    else:
-        assert waveforms.ndim == 3
     n_channels, C = channel_index.shape
     assert geom.shape == (n_channels, 2)
     T = waveforms.shape[1]
     if waveforms.shape != (*max_channels.shape, T, C):
-        raise ValueError(f"Bad shapes: {waveforms.shape=}, {max_channels.shape=}")
+        raise ValueError(f"Bad shapes: {waveforms.shape=}, {max_channels.shape=}, {C=}")
 
     # -- figure out units for plotting
     z_uniq, z_ix = np.unique(geom[:, 1], return_inverse=True)
@@ -93,7 +94,7 @@ def geomplot(
     for c in unique_chans:
         if show_zero:
             if show_zero_kwargs is None:
-                show_zero_kwargs = dict(color="gray", lw=1, linestyle="--")
+                show_zero_kwargs = dict(color="gray", lw=0.8, linestyle="--")
             ax.axhline(geom_plot[c, 1], **show_zero_kwargs)
         if show_chan_label:
             ax.annotate(chan_labels[c], geom_plot[c] + ann_offset, size=6, color="gray")
@@ -110,6 +111,20 @@ def geomplot(
         min_z = min(geom_plot[c, 1] for c in unique_chans)
         if msbar:
             min_z += max_abs_amp
+        if bar_background:
+            ax.add_patch(
+                Rectangle(
+                    [
+                        geom_plot[:, 0].max() + T // 4 - 2,
+                        min_z - max_abs_amp / 2 - subar / 10,
+                    ],
+                    4 + 7 + 2,
+                    subar + subar / 5,
+                    fc=bar_background,
+                    zorder=11,
+                    alpha=0.8,
+                )
+            )
         ax.add_patch(
             Rectangle(
                 [
@@ -118,11 +133,12 @@ def geomplot(
                 ],
                 4,
                 subar,
-                fc="k",
+                fc=bar_color,
+                zorder=12,
             )
         )
         ax.text(
-            geom_plot[:, 0].max() + T // 4 + 4 + 5,
+            geom_plot[:, 0].max() + T // 4 + 5,
            min_z - max_abs_amp / 2 + subar / 2,
            f"{subar} s.u.",
            transform=ax.transData,
@@ -130,6 +146,8 @@ def geomplot(
             ha="left",
             va="center",
             rotation=-90,
+            color=bar_color,
+            zorder=12,
         )
 
     if msbar:
@@ -140,7 +158,7 @@ def geomplot(
                 geom_plot[:, 0].max(),
             ],
             2 * [min_z - max_abs_amp],
-            color="k",
+            color=bar_color,
             lw=2,
             zorder=890,
         )
@@ -152,6 +170,7 @@ def geomplot(
             fontsize=5,
             ha="center",
             va="bottom",
+            color=bar_color,
         )
 
     if zlim is None:
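
Notes added in review (illustrative sketches only, not part of the applied diff;
all concrete values below are made up):

1. The realignment bounds check in realign_sorting marks shifted spikes whose
waveforms no longer fit inside the recording. The two conditions must be
combined with `|`: a spike is invalid if it starts too early *or* runs past
the end, and with `&` the mask can never be true, so nothing would be dropped.

    import numpy as np

    recording_length_samples = 30_000
    trough_offset_samples = 42
    spike_length_samples = 121  # `t` in the hunk above

    new_times = np.array([10, 500, 29_950])
    highlim = recording_length_samples - (spike_length_samples - trough_offset_samples)
    invalid = (new_times < trough_offset_samples) | (new_times > highlim)
    # invalid -> [True, False, True]: the first spike starts before a full
    # window can be read, the last one runs past the end of the recording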
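
2. The correlogram change guards against np.searchsorted returning
len(times_a) for lagged values beyond the last spike, which previously
indexed out of bounds. A brute-force cross-check of the fixed loop (sketch,
assuming sorted, unique spike times):

    import numpy as np

    rng = np.random.default_rng(0)
    times = np.cumsum(rng.integers(1, 20, size=200))  # sorted and unique
    max_lag = 50
    lags = np.arange(-max_lag, max_lag + 1)

    ccg = np.zeros(lags.shape, dtype=int)
    for i, lag in enumerate(lags):
        lagged = times + lag
        inds = np.searchsorted(times, lagged)
        found = inds < len(times)  # the guard added in the patch
        ccg[i] = np.sum(times[inds[found]] == lagged[found])
    ccg[lags == 0] = 0  # zeroed center bin, as for the acg

    brute = np.array([np.sum((times[None, :] - times[:, None]) == lag) for lag in lags])
    brute[lags == 0] = 0
    assert np.array_equal(ccg, brute)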
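
3. trim_waveforms lets the waveform plots overlay templates whose stored
trough offsets differ from the plot's trough_offset_samples. A shape check
(sketch with made-up offsets: a trough stored at sample 50 is moved to sample
42 of a 121-sample window):

    import numpy as np

    waveforms = np.zeros((1, 140, 4))
    waveforms[:, 50, :] = -10.0  # trough at the old offset

    start = 50 - 42  # old_offset - new_offset
    trimmed = waveforms[:, start : start + 121]
    assert trimmed.shape == (1, 121, 4)
    assert trimmed[0, :, 0].argmin() == 42  # trough now at the new offset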
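
4. The resume logic in _summary_job depends on the write-to-temp-then-rename
pattern: Path.rename is atomic on POSIX filesystems, so a summary image either
exists complete or not at all, and interrupted runs can be restarted with
overwrite=False. The pattern in isolation (hypothetical paths; write_bytes
stands in for fig.savefig):

    from pathlib import Path

    final_out = Path("summaries") / "unit0000.png"
    tmp_out = final_out.with_name("tmp_" + final_out.name)
    final_out.parent.mkdir(exist_ok=True)

    tmp_out.write_bytes(b"...image bytes...")  # fig.savefig(tmp_out) in the patch
    tmp_out.rename(final_out)  # readers never observe a partial file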