Skip to content

Commit

Permalink
Merge pull request #12 from catalystneuro/dev
Browse files Browse the repository at this point in the history
Dev
  • Loading branch information
luiztauffer authored Mar 19, 2024
2 parents 02d51bf + 40b5974 commit 63646b8
Show file tree
Hide file tree
Showing 24 changed files with 2,108 additions and 237 deletions.
1 change: 1 addition & 0 deletions .github/workflows/build_and_push.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ jobs:
app_dir:
- si_kilosort25
- si_kilosort3
- si_mountainsort5

steps:
- name: Checkout repository
Expand Down
9 changes: 9 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,12 @@ SpikeInterface Apps for Dendro

## Dev

For each Processor, create a symbolic link to the `common` folder.

```shell
ln -s ../common common
```

Create / update spec file:
```shell
dendro make-app-spec-file --app-dir . --spec-output-file spec.json
Expand All @@ -21,6 +27,9 @@ docker push ghcr.io/catalystneuro/dendro_si_kilosort25:latest

DOCKER_BUILDKIT=1 docker build -f si_kilosort3/Dockerfile -t ghcr.io/catalystneuro/dendro_si_kilosort3:latest .
docker push ghcr.io/catalystneuro/dendro_si_kilosort3:latest

DOCKER_BUILDKIT=1 docker build -f si_mountainsort5/Dockerfile -t ghcr.io/catalystneuro/dendro_si_mountainsort5:latest .
docker push ghcr.io/catalystneuro/dendro_si_mountainsort5:latest
```

## Test locally
Expand Down
75 changes: 15 additions & 60 deletions common/models.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,16 @@
from dendro.sdk import InputFile, OutputFile
from pydantic import BaseModel, Field
from typing import List, Union
from typing import Union

from .models_preprocessing import PreprocessingContext
from .models_postprocessing import PostprocessingContext
from .models_curation import CurationContext
from .models_sorting import (
Kilosort25SortingContext,
Kilosort3SortingContext,
MountainSort5SortingContext,
SpykingCircusModel,
)


class RecordingContext(BaseModel):
Expand All @@ -21,53 +27,6 @@ class JobKwargs(BaseModel):
progress_bar: bool = Field(default=False, description='Show progress bar.')


# ------------------------------
# Sorter Models
# ------------------------------
class Kilosort25SortingContext(BaseModel):
detect_threshold: float = Field(default=6, description="Threshold for spike detection")
projection_threshold: List[int] = Field(default=[10, 4], description="Threshold on projections")
preclust_threshold: float = Field(default=8, description="Threshold crossings for pre-clustering (in PCA projection space)")
car: bool = Field(default=True, description="Enable or disable common reference")
minFR: float = Field(default=0.1, description="Minimum spike rate (Hz), if a cluster falls below this for too long it gets removed")
minfr_goodchannels: float = Field(default=0.1, description="Minimum firing rate on a 'good' channel")
nblocks: int = Field(default=5, description="blocks for registration. 0 turns it off, 1 does rigid registration. Replaces 'datashift' option.")
sig: float = Field(default=20, description="spatial smoothness constant for registration")
freq_min: float = Field(default=150, description="High-pass filter cutoff frequency")
sigmaMask: float = Field(default=30, description="Spatial constant in um for computing residual variance of spike")
lam: float = Field(default=10.0, description="The importance of the amplitude penalty (like in Kilosort1: 0 means not used, 10 is average, 50 is a lot)")
nPCs: int = Field(default=3, description="Number of PCA dimensions")
ntbuff: int = Field(default=64, description="Samples of symmetrical buffer for whitening and spike detection")
nfilt_factor: int = Field(default=4, description="Max number of clusters per good channel (even temporary ones) 4")
AUCsplit: float = Field(default=0.9, description="Threshold on the area under the curve (AUC) criterion for performing a split in the final step")
do_correction: bool = Field(default=True, description="If True drift registration is applied")
wave_length: float = Field(default=61, description="size of the waveform extracted around each detected peak, (Default 61, maximum 81)")
keep_good_only: bool = Field(default=False, description="If True only 'good' units are returned")
skip_kilosort_preprocessing: bool = Field(default=False, description="Can optionaly skip the internal kilosort preprocessing")


class Kilosort3SortingContext(BaseModel):
detect_threshold: float = Field(default=6, description="Threshold for spike detection")
projection_threshold: List[int] = Field(default=[9, 9], description="Threshold on projections")
preclust_threshold: float = Field(default=8, description="Threshold crossings for pre-clustering (in PCA projection space)")
car: bool = Field(default=True, description="Enable or disable common reference")
minFR: float = Field(default=0.2, description="Minimum spike rate (Hz), if a cluster falls below this for too long it gets removed")
minfr_goodchannels: float = Field(default=0.2, description="Minimum firing rate on a 'good' channel")
nblocks: int = Field(default=5, description="blocks for registration. 0 turns it off, 1 does rigid registration. Replaces 'datashift' option.")
sig: float = Field(default=20, description="spatial smoothness constant for registration")
freq_min: float = Field(default=300, description="High-pass filter cutoff frequency")
sigmaMask: float = Field(default=30, description="Spatial constant in um for computing residual variance of spike")
lam: float = Field(default=20.0, description="The importance of the amplitude penalty (like in Kilosort1: 0 means not used, 10 is average, 50 is a lot)")
nPCs: int = Field(default=3, description="Number of PCA dimensions")
ntbuff: int = Field(default=64, description="Samples of symmetrical buffer for whitening and spike detection")
nfilt_factor: int = Field(default=4, description="Max number of clusters per good channel (even temporary ones) 4")
AUCsplit: float = Field(default=0.8, description="Threshold on the area under the curve (AUC) criterion for performing a split in the final step")
do_correction: bool = Field(default=True, description="If True drift registration is applied")
wave_length: float = Field(default=61, description="size of the waveform extracted around each detected peak, (Default 61, maximum 81)")
keep_good_only: bool = Field(default=False, description="If True only 'good' units are returned")
skip_kilosort_preprocessing: bool = Field(default=False, description="Can optionaly skip the internal kilosort preprocessing")


# ------------------------------
# Pipeline Models
# ------------------------------
Expand All @@ -87,10 +46,17 @@ class PipelineFullContext(BaseModel):
run_preprocessing: bool = Field(default=True, description='Run preprocessing')
preprocessing_context: PreprocessingContext = Field(default=PreprocessingContext(), description='Preprocessing context')
run_spikesorting: bool = Field(default=True, description='Run spike sorting')
sorter_name: str = Field(
default='mountainsort5',
description="Name of the sorter to use.",
json_schema_extra={'options': ["kilosort2_5", "kilosort3", "mountainsort5"]}
)
spikesorting_context: Union[
Kilosort25SortingContext,
Kilosort3SortingContext,
] = Field(description='Sorting context')
MountainSort5SortingContext,
SpykingCircusModel,
] = Field(description='Sorting context', union_mode="left_to_right")
run_postprocessing: bool = Field(default=True, description='Run postprocessing')
postprocessing_context: PostprocessingContext = Field(default=PostprocessingContext(), description='Postprocessing context')
run_curation: bool = Field(default=True, description='Run curation')
Expand All @@ -99,17 +65,6 @@ class PipelineFullContext(BaseModel):
# visualization_context: VisualizationContext = Field(default=VisualizationContext(), description='Visualization context')



# # ------------------------------
# # Curation Models
# # ------------------------------
# class CurationKwargs:
# duplicate_threshold: float = Field(0.9, description="Threshold for duplicate units")
# isi_violations_ratio_threshold: float = Field(0.5, description="Threshold for ISI violations ratio")
# presence_ratio_threshold: float = Field(0.8, description="Threshold for presence ratio")
# amplitude_cutoff_threshold: float = Field(0.1, description="Threshold for amplitude cutoff")


# # ------------------------------
# # Visualization Models
# # ------------------------------
Expand Down
86 changes: 86 additions & 0 deletions common/models_sorting.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
from pydantic import BaseModel, Field
from typing import List, Union


class Kilosort25SortingContext(BaseModel):
detect_threshold: float = Field(default=6, description="Threshold for spike detection")
projection_threshold: List[int] = Field(default=[10, 4], description="Threshold on projections")
preclust_threshold: float = Field(default=8, description="Threshold crossings for pre-clustering (in PCA projection space)")
car: bool = Field(default=True, description="Enable or disable common reference")
minFR: float = Field(default=0.1, description="Minimum spike rate (Hz), if a cluster falls below this for too long it gets removed")
minfr_goodchannels: float = Field(default=0.1, description="Minimum firing rate on a 'good' channel")
nblocks: int = Field(default=5, description="blocks for registration. 0 turns it off, 1 does rigid registration. Replaces 'datashift' option.")
sig: float = Field(default=20, description="spatial smoothness constant for registration")
freq_min: float = Field(default=150, description="High-pass filter cutoff frequency")
sigmaMask: float = Field(default=30, description="Spatial constant in um for computing residual variance of spike")
lam: float = Field(default=10.0, description="The importance of the amplitude penalty (like in Kilosort1: 0 means not used, 10 is average, 50 is a lot)")
nPCs: int = Field(default=3, description="Number of PCA dimensions")
ntbuff: int = Field(default=64, description="Samples of symmetrical buffer for whitening and spike detection")
nfilt_factor: int = Field(default=4, description="Max number of clusters per good channel (even temporary ones) 4")
AUCsplit: float = Field(default=0.9, description="Threshold on the area under the curve (AUC) criterion for performing a split in the final step")
do_correction: bool = Field(default=True, description="If True drift registration is applied")
wave_length: float = Field(default=61, description="size of the waveform extracted around each detected peak, (Default 61, maximum 81)")
keep_good_only: bool = Field(default=False, description="If True only 'good' units are returned")
skip_kilosort_preprocessing: bool = Field(default=False, description="Can optionaly skip the internal kilosort preprocessing")


class Kilosort3SortingContext(BaseModel):
detect_threshold: float = Field(default=6, description="Threshold for spike detection")
projection_threshold: List[int] = Field(default=[9, 9], description="Threshold on projections")
preclust_threshold: float = Field(default=8, description="Threshold crossings for pre-clustering (in PCA projection space)")
car: bool = Field(default=True, description="Enable or disable common reference")
minFR: float = Field(default=0.2, description="Minimum spike rate (Hz), if a cluster falls below this for too long it gets removed")
minfr_goodchannels: float = Field(default=0.2, description="Minimum firing rate on a 'good' channel")
nblocks: int = Field(default=5, description="blocks for registration. 0 turns it off, 1 does rigid registration. Replaces 'datashift' option.")
sig: float = Field(default=20, description="spatial smoothness constant for registration")
freq_min: float = Field(default=300, description="High-pass filter cutoff frequency")
sigmaMask: float = Field(default=30, description="Spatial constant in um for computing residual variance of spike")
lam: float = Field(default=20.0, description="The importance of the amplitude penalty (like in Kilosort1: 0 means not used, 10 is average, 50 is a lot)")
nPCs: int = Field(default=3, description="Number of PCA dimensions")
ntbuff: int = Field(default=64, description="Samples of symmetrical buffer for whitening and spike detection")
nfilt_factor: int = Field(default=4, description="Max number of clusters per good channel (even temporary ones) 4")
AUCsplit: float = Field(default=0.8, description="Threshold on the area under the curve (AUC) criterion for performing a split in the final step")
do_correction: bool = Field(default=True, description="If True drift registration is applied")
wave_length: float = Field(default=61, description="size of the waveform extracted around each detected peak, (Default 61, maximum 81)")
keep_good_only: bool = Field(default=False, description="If True only 'good' units are returned")
skip_kilosort_preprocessing: bool = Field(default=False, description="Can optionaly skip the internal kilosort preprocessing")


class MountainSort5SortingContext(BaseModel):
scheme: str = Field(
default='2',
description="Sorting scheme",
json_schema_extra={'options': ["1", "2", "3"]}
)
detect_threshold: float = Field(default=5.5, description="Threshold for spike detection")
detect_sign: int = Field(default=-1, description="Sign of the peak")
detect_time_radius_msec: float = Field(default=0.5, description="Time radius in milliseconds")
snippet_T1: int = Field(default=20, description="Snippet T1")
snippet_T2: int = Field(default=20, description="Snippet T2")
npca_per_channel: int = Field(default=3, description="Number of PCA per channel")
npca_per_subdivision: int = Field(default=10, description="Number of PCA per subdivision")
snippet_mask_radius: int = Field(default=250, description="Snippet mask radius")
scheme1_detect_channel_radius: int = Field(default=150, description="Scheme 1 detect channel radius")
scheme2_phase1_detect_channel_radius: int = Field(default=200, description="Scheme 2 phase 1 detect channel radius")
scheme2_detect_channel_radius: int = Field(default=50, description="Scheme 2 detect channel radius")
scheme2_max_num_snippets_per_training_batch: int = Field(default=200, description="Scheme 2 max number of snippets per training batch")
scheme2_training_duration_sec: int = Field(default=300, description="Scheme 2 training duration in seconds")
scheme2_training_recording_sampling_mode: str = Field(default='uniform', description="Scheme 2 training recording sampling mode")
scheme3_block_duration_sec: int = Field(default=1800, description="Scheme 3 block duration in seconds")
freq_min: int = Field(default=300, description="High-pass filter cutoff frequency")
freq_max: int = Field(default=6000, description="Low-pass filter cutoff frequency")
filter: bool = Field(default=True, description="Enable or disable filter")
whiten: bool = Field(default=True, description="Enable or disable whiten")


class SpykingCircusModel(BaseModel):
detect_sign: int = Field(default=-1, description="Sign of the peak")
adjacency_radius: int = Field(default=100, description="Adjacency radius")
detect_threshold: float = Field(default=6, description="Threshold for spike detection")
template_width_ms: int = Field(default=3, description="Template width in milliseconds")
filter: bool = Field(default=True, description="Enable or disable filter")
merge_spikes: bool = Field(default=True, description="Enable or disable merge spikes")
auto_merge: float = Field(default=0.75, description="Auto merge")
num_workers: Union[int, None] = Field(default=None, description="Number of workers")
whitening_max_elts: int = Field(default=1000, description="Whitening max elements")
clustering_max_elts: int = Field(default=10000, description="Clustering max elements")
Loading

0 comments on commit 63646b8

Please sign in to comment.