Skip to content

Commit

Permalink
Debug unit plots and fix an alignment issue
Browse files Browse the repository at this point in the history
  • Loading branch information
cwindolf committed Dec 14, 2023
1 parent 8bb6866 commit 07488bb
Show file tree
Hide file tree
Showing 6 changed files with 395 additions and 98 deletions.
10 changes: 6 additions & 4 deletions src/dartsort/templates/get_templates.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,7 @@ def get_templates(
sorting, templates = realign_sorting(
sorting,
raw_results["raw_templates"],
raw_results["snrs_by_channel"],
max_shift=realign_max_sample_shift,
trough_offset_samples=trough_offset_samples,
recording_length_samples=recording.get_num_samples(),
Expand Down Expand Up @@ -234,6 +235,7 @@ def get_raw_templates(
def realign_sorting(
sorting,
templates,
snrs_by_channel,
max_shift=20,
trough_offset_samples=42,
recording_length_samples=None,
Expand All @@ -244,7 +246,7 @@ def realign_sorting(
return sorting, templates

# find template peak time
template_maxchans = templates.ptp(1).argmax(1)
template_maxchans = snrs_by_channel.argmax(1)
template_maxchan_traces = templates[np.arange(n), :, template_maxchans]
template_peak_times = np.abs(template_maxchan_traces).argmax(1)

Expand All @@ -254,10 +256,10 @@ def realign_sorting(

# create aligned spike train
new_times = sorting.times_samples + template_shifts[sorting.labels]
labels = sorting.labels
labels = sorting.labels.copy()
if recording_length_samples is not None:
highlim = recording_length_samples - (t - trough_offset_samples)
labels[(new_times < trough_offset_samples) & (new_times >= highlim)] = -1
labels[(new_times < trough_offset_samples) & (new_times > highlim)] = -1
aligned_sorting = replace(sorting, labels=labels, times_samples=new_times)

# trim templates
Expand Down Expand Up @@ -567,7 +569,7 @@ def _template_job(unit_ids):
order = np.argsort(in_units)
in_units = in_units[order]
labels = labels[order]

# read waveforms for all units
times = p.sorting.times_samples[in_units]
valid = np.flatnonzero(
Expand Down
9 changes: 9 additions & 0 deletions src/dartsort/templates/templates.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,8 @@ class TemplateData:
registered_geom: Optional[np.ndarray] = None
registered_template_depths_um: Optional[np.ndarray] = None
localization_radius_um: float = 100.0
trough_offset_samples: int = 42
spike_length_samples: int = 121

@classmethod
def from_npz(cls, npz_path):
Expand Down Expand Up @@ -77,6 +79,9 @@ def coarsen(self, with_locs=True):
spike_counts=spike_counts,
registered_template_depths_um=registered_template_depths_um,
)

def unit_templates(self, unit_id):
return self.templates[self.unit_ids == unit_id]

@classmethod
def from_config(
Expand Down Expand Up @@ -196,12 +201,16 @@ def from_config(
kwargs["registered_geom"],
registered_template_depths_um,
localization_radius_um=template_config.registered_template_localization_radius_um,
trough_offset_samples=template_config.trough_offset_samples,
spike_length_samples=template_config.spike_length_samples,
)
else:
obj = cls(
results["templates"],
unit_ids,
spike_counts,
trough_offset_samples=template_config.trough_offset_samples,
spike_length_samples=template_config.spike_length_samples,
)

if save_folder is not None:
Expand Down
116 changes: 81 additions & 35 deletions src/dartsort/util/analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
from ..transform import WaveformPipeline
from .data_util import DARTsortSorting
from .drift_util import (get_spike_pitch_shifts,
get_waveforms_on_static_channels)
get_waveforms_on_static_channels, registered_average)
from .spikeio import read_waveforms_channel_index
from .waveform_util import make_channel_index

Expand Down Expand Up @@ -84,16 +84,20 @@ def from_peeling_paths(
model_dir=None,
motion_est=None,
template_data_npz="template_data.npz",
template_data=None,
sorting=None,
**kwargs,
):
hdf5_path = Path(hdf5_path)
if model_dir is None:
model_dir = hdf5_path.parent / f"{hdf5_path.stem}_models"
assert model_dir.exists()
sorting = DARTsortSorting.from_peeling_hdf5(
hdf5_path, load_simple_features=False
)
template_data = TemplateData.from_npz(Path(model_dir) / template_data_npz)
if sorting is None:
sorting = DARTsortSorting.from_peeling_hdf5(
hdf5_path, load_simple_features=False
)
if template_data is None:
template_data = TemplateData.from_npz(Path(model_dir) / template_data_npz)
pipeline = torch.load(model_dir / "featurization_pipeline.pt")
return cls(
sorting, hdf5_path, recording, template_data, pipeline, motion_est, **kwargs
Expand Down Expand Up @@ -132,12 +136,14 @@ def clear_cache(self):
self._geom = None
self._tpca_features = None
self._sklearn_tpca = None
self._unit_ids = None
self._spike_counts = None
self._feats = {}

def __getstate__(self):
# remove cached stuff before pickling
return {k: v for k, v in self.__dict__.items() if not k.startswith("_")}

return {k: v if not k.startswith("_") else None for k, v in self.__dict__.items()}
# cache gizmos

@property
Expand Down Expand Up @@ -190,9 +196,21 @@ def sklearn_tpca(self):

# spike train helpers

@property
def unit_ids(self):
allunits = np.unique(self.sorting.labels)
return allunits[allunits >= 0]
if self._unit_ids is None:
allunits, counts = np.unique(self.sorting.labels, return_counts=True)
self._unit_ids = allunits[allunits >= 0]
self._spike_counts = counts[allunits >= 0]
return self._unit_ids

@property
def spike_counts(self):
if self._spike_counts is None:
allunits, counts = np.unique(self.sorting.labels, return_counts=True)
self._unit_ids = allunits[allunits >= 0]
self._spike_counts = counts[allunits >= 0]
return self._spike_counts

def in_unit(self, unit_id):
return np.flatnonzero(np.isin(self.sorting.labels, unit_id))
Expand All @@ -210,11 +228,11 @@ def x(self, which=slice(None)):
def z(self, which=slice(None), registered=True):
z = self.xyza[which, 2]
if registered and self.motion_est is not None:
z = self.motion_est.correct_s(self.sorting.times_seconds, z)
z = self.motion_est.correct_s(self.times_seconds(which=which), z)
return z

def times_seconds(self, which=slice(None)):
return self.sorting.times_seconds[which]
return self.recording._recording_segments[0].sample_index_to_time(self.times_samples(which=which))

def times_samples(self, which=slice(None)):
return self.sorting.times_samples[which]
Expand All @@ -225,14 +243,15 @@ def amplitudes(self, which=slice(None), relocated=False):

reloc_amp_vecs = relocate.relocated_waveforms_on_static_channels(
self.amplitude_vectors[which],
main_channels=self.channels[which],
main_channels=self.sorting.channels[which],
channel_index=self.channel_index,
xyza_from=self.xyza[which],
z_to=self.z(which),
geom=self.geom,
registered_geom=self.template_data.registered_geom,
target_channels=slice(None),
)
return reloc_amp_vecs.max(1)
return np.nanmax(reloc_amp_vecs, axis=1)

def tpca_features(self, which=slice(None)):
if self._tpca_features is None:
Expand Down Expand Up @@ -313,25 +332,26 @@ def unit_tpca_waveforms(
waveforms,
self.channel_index,
show_radius_um=show_radius_um,
relocate=relocated,
relocated=relocated,
)

def unit_pca_features(
self, unit_id, relocated=True, rank=2, pca_radius_um=75, random_seed=0
):
waveforms = self.unit_tpca_waveforms(
waveforms, max_chan, show_geom, show_channel_index = self.unit_tpca_waveforms(
unit_id,
relocated=relocated,
show_radius_um=pca_radius_um,
random_seed=random_seed,
)
waveforms = waveforms.reshape(len(waveforms), -1)

no_nan = np.flatnonzero(~np.isnan(waveforms).any(axis=1))
features = np.full((len(waveforms), rank), np.nan, dtype=waveforms.dtype)
if no_nan.size < max(self.min_cluster_size, self.n_pca_features):
if no_nan.size < rank:
return features

pca = PCA(self.n_pca_features, random_state=random_seed, whiten=True)
pca = PCA(rank, random_state=random_seed, whiten=True)
features[no_nan] = pca.fit_transform(waveforms[no_nan])
return features

Expand All @@ -342,22 +362,46 @@ def unit_shift_or_relocate_channels(
waveforms,
load_channel_index,
show_radius_um=75,
relocate=False,
relocated=False,
):
geom = self.recording.get_channel_locations()
show_geom = self.recording.registered_geom
show_geom = self.template_data.registered_geom
if show_geom is None:
show_geom is geom
temp = self.coarse_templates.templates[
self.coarse_templates.unit_ids == unit_id
]
assert temp.shape[0] == 1
max_chan = temp.squeeze().ptp(0).argmax()
show_geom = geom
temp = self.coarse_template_data.unit_templates(unit_id)
n_pitches_shift = None
if temp.shape[0]:
max_chan = temp.squeeze().ptp(0).argmax()
else:
amps = waveforms.ptp(1)
if self.shifting:
n_pitches_shift = get_spike_pitch_shifts(
self.z(which=which, registered=False),
geom=geom,
registered_depths_um=self.z(which=which, registered=True),
times_s=self.times_seconds(which=which),
motion_est=self.motion_est,
)
amp_template = registered_average(
amps[:, None, :],
n_pitches_shift,
geom,
show_geom,
main_channels=self.sorting.channels[which],
channel_index=load_channel_index,
)[0]
else:
amp_template = np.nanmean(amps, axis=0)
max_chan = np.nanargmax(amp_template)
show_channel_index = make_channel_index(show_geom, show_radius_um)
show_chans = show_channel_index[max_chan]
show_chans = show_chans[show_chans < len(show_geom)]

if not self.shifting:
return waveforms, max_chan, show_geom, show_channel_index

if relocate:
return relocate.relocated_waveforms_on_static_channels(
if relocated:
waveforms = relocate.relocated_waveforms_on_static_channels(
waveforms,
main_channels=self.sorting.channels[which],
channel_index=load_channel_index,
Expand All @@ -367,14 +411,16 @@ def unit_shift_or_relocate_channels(
geom=geom,
registered_geom=show_geom,
)

n_pitches_shift = get_spike_pitch_shifts(
self.z(which=which, registered=False),
geom=geom,
registered_depths_um=self.z(which=which, registered=True),
times_s=self.times_seconds(which=which),
motion_est=self.motion_est,
)
return waveforms, max_chan, show_geom, show_channel_index

if n_pitches_shift is None:
n_pitches_shift = get_spike_pitch_shifts(
self.z(which=which, registered=False),
geom=geom,
registered_depths_um=self.z(which=which, registered=True),
times_s=self.times_seconds(which=which),
motion_est=self.motion_est,
)

waveforms = get_waveforms_on_static_channels(
waveforms,
Expand Down
2 changes: 2 additions & 0 deletions src/dartsort/vis/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1,3 @@
from .scatterplots import *
from .unit import *
from .waveforms import *
Loading

0 comments on commit 07488bb

Please sign in to comment.