Skip to content

Commit

Permalink
New analysis and visualization tools
Browse files Browse the repository at this point in the history
  • Loading branch information
cwindolf committed Dec 13, 2023
1 parent 1256c99 commit 5358b3d
Show file tree
Hide file tree
Showing 2 changed files with 224 additions and 32 deletions.
63 changes: 49 additions & 14 deletions src/dartsort/util/analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,13 @@ class DARTsortAnalysis:

@classmethod
def from_peeling_hdf5_and_recording(
cls, hdf5_path, recording, template_data, featurization_pipeline=None, motion_est=None, **kwargs
cls,
hdf5_path,
recording,
template_data,
featurization_pipeline=None,
motion_est=None,
**kwargs,
):
return cls(
DARTsortSorting.from_peeling_hdf5(hdf5_path, load_simple_features=False),
Expand All @@ -83,7 +89,9 @@ def from_peeling_paths(
if model_dir is None:
model_dir = hdf5_path.parent / f"{hdf5_path.stem}_models"
assert model_dir.exists()
sorting = DARTsortSorting.from_peeling_hdf5(hdf5_path, load_simple_features=False)
sorting = DARTsortSorting.from_peeling_hdf5(
hdf5_path, load_simple_features=False
)
template_data = TemplateData.from_npz(Path(model_dir) / template_data_npz)
pipeline = torch.load(model_dir / "featurization_pipeline.pt")
return cls(
Expand Down Expand Up @@ -182,7 +190,8 @@ def sklearn_tpca(self):
# spike train helpers

def unit_ids(self):
return np.unique(self.sorting.labels)
allunits = np.unique(self.sorting.labels)
return allunits[allunits >= 0]

def in_unit(self, unit_id):
return np.flatnonzero(np.isin(self.sorting.labels, unit_id))
Expand Down Expand Up @@ -233,14 +242,15 @@ def tpca_features(self, which=slice(None)):

# cluster-dependent feature loading methods

def unit_waveforms(
def unit_raw_waveforms(
self,
unit_id,
max_count=250,
random_seed=0,
show_radius_um=75,
trough_offset_samples=42,
spike_length_samples=121,
relocated=False,
):
which = self.in_unit(unit_id)
if which.size > max_count:
Expand Down Expand Up @@ -268,10 +278,22 @@ def unit_waveforms(
if not self.shifting:
return waveforms

return self.unit_shift_channels(unit_id, which, waveforms, load_ci, show_radius_um=show_radius_um)
return self.unit_shift_or_relocate_channels(
unit_id,
which,
waveforms,
load_ci,
show_radius_um=show_radius_um,
relocated=relocated,
)

def unit_tpca_waveforms(
self, unit_id, max_count=250, random_seed=0, show_radius_um=75, relocated=False,
self,
unit_id,
max_count=250,
random_seed=0,
show_radius_um=75,
relocated=False,
):
which = self.in_unit(unit_id)
if not which.size:
Expand All @@ -293,8 +315,15 @@ def unit_tpca_waveforms(
relocate=relocated,
)

def unit_pca_features(self, unit_id, relocated=True, rank=2, pca_radius_um=75, random_seed=0):
waveforms = self.unit_tpca_waveforms(unit_id, relocated=relocated, show_radius_um=pca_radius_um, random_seed=random_seed)
def unit_pca_features(
self, unit_id, relocated=True, rank=2, pca_radius_um=75, random_seed=0
):
waveforms = self.unit_tpca_waveforms(
unit_id,
relocated=relocated,
show_radius_um=pca_radius_um,
random_seed=random_seed,
)

no_nan = np.flatnonzero(~np.isnan(waveforms).any(axis=1))
features = np.full((len(waveforms), rank), np.nan, dtype=waveforms.dtype)
Expand All @@ -306,7 +335,13 @@ def unit_pca_features(self, unit_id, relocated=True, rank=2, pca_radius_um=75, r
return features

def unit_shift_or_relocate_channels(
self, unit_id, which, waveforms, load_channel_index, show_radius_um=75, relocate=False
self,
unit_id,
which,
waveforms,
load_channel_index,
show_radius_um=75,
relocate=False,
):
geom = self.recording.get_channel_locations()
show_geom = self.recording.registered_geom
Expand All @@ -317,10 +352,8 @@ def unit_shift_or_relocate_channels(
]
assert temp.shape[0] == 1
max_chan = temp.squeeze().ptp(0).argmax()
show_chans = np.flatnonzero(
np.square(show_geom - show_geom[max_chan][None]).sum(1)
< show_radius_um**2
)
show_channel_index = make_channel_index(show_geom, show_radius_um)
show_chans = show_channel_index[max_chan]

if relocate:
return relocate.relocated_waveforms_on_static_channels(
Expand All @@ -342,7 +375,7 @@ def unit_shift_or_relocate_channels(
motion_est=self.motion_est,
)

return get_waveforms_on_static_channels(
waveforms = get_waveforms_on_static_channels(
waveforms,
geom=geom,
n_pitches_shift=n_pitches_shift,
Expand All @@ -352,6 +385,8 @@ def unit_shift_or_relocate_channels(
registered_geom=show_geom,
)

return waveforms, max_chan, show_geom, show_channel_index


# -- h5 helper... slow reading...

Expand Down
Loading

0 comments on commit 5358b3d

Please sign in to comment.