Skip to content

Commit

Permalink
Sketch out some ideas for hybrid comparison
Browse files Browse the repository at this point in the history
  • Loading branch information
cwindolf committed Jan 19, 2024
1 parent b8d8503 commit 40561bd
Show file tree
Hide file tree
Showing 2 changed files with 66 additions and 2 deletions.
57 changes: 55 additions & 2 deletions src/dartsort/util/analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
import torch
from dredge.motion_util import MotionEstimate
from sklearn.decomposition import PCA
from spikeinterface.comparison import GroundTruthComparison

from ..cluster import relocate
from ..templates import TemplateData
Expand All @@ -42,9 +43,9 @@ class DARTsortAnalysis:
"""

sorting: DARTsortSorting
hdf5_path: Path
recording: sc.BaseRecording
template_data: TemplateData
hdf5_path: Optional[Path] = None
featurization_pipeline: Optional[WaveformPipeline] = None
motion_est: Optional[MotionEstimate] = None

Expand Down Expand Up @@ -288,6 +289,7 @@ def tpca_features(self, which=slice(None)):
def unit_raw_waveforms(
self,
unit_id,
which=None,
template_index=None,
max_count=250,
random_seed=0,
Expand All @@ -296,10 +298,13 @@ def unit_raw_waveforms(
spike_length_samples=121,
relocated=False,
):
which = self.in_unit(unit_id)
if which is None:
which = self.in_unit(unit_id)
if template_index is not None:
assert template_index in self.unit_template_indices(unit_id)
which = self.in_template(template_index)
if max_count is None:
max_count = which.size
if which.size > max_count:
rg = np.random.default_rng(0)
which = rg.choice(which, size=max_count, replace=False)
Expand Down Expand Up @@ -475,6 +480,54 @@ def unit_shift_or_relocate_channels(
return waveforms, max_chan, show_geom, show_channel_index


@dataclass
class DARTsortGroundTruthComparison:
gt_analysis: DARTsortAnalysis
predicted_analysis: DARTsortAnalysis
gt_name: Optional[str] = None
predicted_name: Optional[str] = None
delta_time: float = 0.4
match_score: float = 0.1
well_detected_score: float = 0.8
exhaustive_gt: bool = False
n_jobs: int = -1
match_mode: str = "hungarian"

def __post_init__(self):
self.comparison = GroundTruthComparison(
gt_sorting=self.gt_analysis.sorting.to_numpy_sorting(),
tested_sorting=self.predicted_analysis.sorting.to_numpy_sorting(),
gt_name=self.gt_name,
predicted_name=self.predicted_name,
delta_time=self.delta_time,
match_score=self.match_score,
well_detected_score=self.well_detected_score,
exhaustive_gt=self.exhaustive_gt,
n_jobs=self.n_jobs,
match_mode=self.match_mode,
)

def get_match(self, gt_unit):
pass

def get_spikes_by_category(self, gt_unit, predicted_unit=None):
if predicted_unit is None:
predicted_unit = self.get_match(gt_unit)

return dict(
matched_predicted_indices=...,
matched_gt_indices=...,
only_gt_indices=...,
only_predicted_indices=...,
)

def get_performance(self, gt_unit):
pass

def get_waveforms_by_category(self, gt_unit, predicted_unit=None):
return ...


# -- h5 helper... slow reading...


Expand Down
11 changes: 11 additions & 0 deletions src/dartsort/vis/comparison.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
class UnitComparisonPlot:
kind: str
width = 1
height = 1

def draw(self, axis, sorting_comparison, ground_truth_unit_id, predicted_unit_id=None):
raise NotImplementedError


class RawWaveformComparisonPlot(UnitComparisonPlot):
pass

0 comments on commit 40561bd

Please sign in to comment.