diff --git a/.github/workflows/build_and_push.yaml b/.github/workflows/build_and_push.yaml new file mode 100644 index 0000000..3784d8f --- /dev/null +++ b/.github/workflows/build_and_push.yaml @@ -0,0 +1,63 @@ +name: Build and Push Docker Images + +on: + push: + branches: + - dev + paths: + - '*/**' + +jobs: + build: + runs-on: ubuntu-latest + strategy: + fail-fast: false + matrix: + app_dir: + - si_kilosort25 + - si_kilosort3 + + steps: + - name: Checkout repository + uses: actions/checkout@v3 + + - name: Set up Docker Builder + uses: docker/setup-buildx-action@v2 + + - name: Log in to GitHub Container Registry + uses: docker/login-action@v1 + with: + registry: ghcr.io + username: ${{ github.actor }} + password: ${{ secrets.GITHUB_TOKEN }} + # username: YOUR_GITHUB_USERNAME # Replace with your GitHub username + # password: ${{ secrets.GHCR_PAT }} # Use the PAT secret + + - name: Pip install dendro + run: | + python -m pip install --upgrade pip + git clone https://github.com/flatironinstitute/dendro.git + cd dendro/python + git pull + pip install . + + - name: Generate Spec File and Build Docker Image + # if: contains(github.event.head_commit.modified, matrix.app_dir) + run: | + cd ${{ matrix.app_dir }} + dendro make-app-spec-file --app-dir . --spec-output-file spec.json + docker build -t ghcr.io/catalystneuro/dendro_${{ matrix.app_dir }}:latest . + docker push ghcr.io/catalystneuro/dendro_${{ matrix.app_dir }}:latest + + - name: Commit files + run: | + git config --local user.email "github-actions[bot]@users.noreply.github.com" + git config --local user.name "github-actions[bot]" + GIT_STATUS=$(git status -s) + [[ ! -z "$GIT_STATUS" ]] && git add ${{ matrix.app_dir }}/spec.json && git commit -m "update spec.json" -a || echo "No changes to commit" + + - name: Push changes + uses: ad-m/github-push-action@master + with: + github_token: ${{ secrets.GITHUB_TOKEN }} + branch: ${{ github.ref }} diff --git a/README.md b/README.md index 4780248..7c180d9 100644 --- a/README.md +++ b/README.md @@ -7,4 +7,10 @@ SpikeInterface Apps for Dendro Build single App image: ```shell DOCKER_BUILDKIT=1 docker build -t . +``` + +Examples: +```shell +DOCKER_BUILDKIT=1 docker build -t ghcr.io/catalystneuro/dendro_si_kilosort25:latest . +docker push ghcr.io/catalystneuro/dendro_si_kilosort25:latest ``` \ No newline at end of file diff --git a/si_kilosort25/.gitignore b/si_kilosort25/.gitignore new file mode 100644 index 0000000..dd736b8 --- /dev/null +++ b/si_kilosort25/.gitignore @@ -0,0 +1,3 @@ +results/ +scratch/ +output/ \ No newline at end of file diff --git a/si_kilosort25/Dockerfile b/si_kilosort25/Dockerfile index 17daeaa..a5dffaf 100644 --- a/si_kilosort25/Dockerfile +++ b/si_kilosort25/Dockerfile @@ -6,6 +6,18 @@ RUN git clone https://github.com/flatironinstitute/dendro.git && \ cd dendro/python && \ pip install -e . +# Install spikeinterface-pipelines from source, for now +RUN git clone https://github.com/SpikeInterface/spikeinterface_pipelines.git && \ + cd spikeinterface_pipelines && \ + git checkout dev && \ + pip install -e . 
+ +# Install spikeinterface from source, for now +RUN git clone https://github.com/SpikeInterface/spikeinterface.git && \ + cd spikeinterface && \ + # git checkout dev && \ + pip install -e .[full] + # Copy files into the container WORKDIR /app COPY *.py ./ diff --git a/si_kilosort25/main.py b/si_kilosort25/main.py index 1e1a5a9..b3f91df 100644 --- a/si_kilosort25/main.py +++ b/si_kilosort25/main.py @@ -1,3 +1,5 @@ +#!/usr/bin/env python3 + from dendro.sdk import App from processor_pipeline import PipelineProcessor @@ -8,7 +10,7 @@ app = App( name=app_name, description="Spike Interface Pipeline - Kilosort2.5", - app_image=f"ghcr.io/catalystneuro/{app_name}", + app_image=f"ghcr.io/catalystneuro/dendro_{app_name}", app_executable="/app/main.py" ) diff --git a/si_kilosort25/models.py b/si_kilosort25/models.py index f75eadb..c9cb5b1 100644 --- a/si_kilosort25/models.py +++ b/si_kilosort25/models.py @@ -46,13 +46,20 @@ class HighpassSpatialFilter(BaseModel): highpass_butter_wn: float = Field(default=0.01, description="Natural frequency for the Butterworth filter") +class MotionCorrection(BaseModel): + compute: bool = Field(default=True, description="Whether to compute motion correction") + apply: bool = Field(default=False, description="Whether to apply motion correction") + preset: str = Field(default="nonrigid_accurate", description="Preset for motion correction") + + class PreprocessingContext(BaseModel): + preprocessing_strategy: str = Field(default="cmr", description="Strategy for preprocessing") highpass_filter: HighpassFilter = Field(default=HighpassFilter(), description="Highpass filter") phase_shift: PhaseShift = Field(default=PhaseShift(), description="Phase shift") detect_bad_channels: DetectBadChannels = Field(default=DetectBadChannels(), description="Detect bad channels") common_reference: CommonReference = Field(default=CommonReference(), description="Common reference") highpass_spatial_filter: HighpassSpatialFilter = Field(default=HighpassSpatialFilter(), description="Highpass spatial filter") - preprocessing_strategy: str = Field(default="cmr", description="Strategy for preprocessing") + motion_correction: MotionCorrection = Field(default=MotionCorrection(), description="Motion correction") remove_out_channels: bool = Field(default=False, description="Flag to remove out channels") remove_bad_channels: bool = Field(default=False, description="Flag to remove bad channels") max_bad_channel_fraction_to_remove: float = Field(default=1.1, description="Maximum fraction of bad channels to remove") @@ -72,19 +79,20 @@ class Kilosort25SortingContext(BaseModel): sig: float = Field(default=20, description="spatial smoothness constant for registration") freq_min: float = Field(default=150, description="High-pass filter cutoff frequency") sigmaMask: float = Field(default=30, description="Spatial constant in um for computing residual variance of spike") + lam: float = Field(default=10.0, description="The importance of the amplitude penalty (like in Kilosort1: 0 means not used, 10 is average, 50 is a lot)") nPCs: int = Field(default=3, description="Number of PCA dimensions") ntbuff: int = Field(default=64, description="Samples of symmetrical buffer for whitening and spike detection") nfilt_factor: int = Field(default=4, description="Max number of clusters per good channel (even temporary ones) 4") - NT: int = Field(default=-1, description='Batch size (if -1 it is automatically computed)') + # NT: int = Field(default=-1, description='Batch size (if -1 it is automatically computed)') AUCsplit: 
float = Field(default=0.9, description="Threshold on the area under the curve (AUC) criterion for performing a split in the final step") do_correction: bool = Field(default=True, description="If True drift registration is applied") wave_length: float = Field(default=61, description="size of the waveform extracted around each detected peak, (Default 61, maximum 81)") - keep_good_only: bool = Field(default=True, description="If True only 'good' units are returned") + keep_good_only: bool = Field(default=False, description="If True only 'good' units are returned") skip_kilosort_preprocessing: bool = Field(default=False, description="Can optionally skip the internal kilosort preprocessing") - scaleproc: int = Field(default=-1, description="int16 scaling of whitened data, if -1 set to 200.") + # scaleproc: int = Field(default=-1, description="int16 scaling of whitened data, if -1 set to 200.") -class SortingContext(BaseModel): +class SpikeSortingContext(BaseModel): sorter_name: str = Field(default="kilosort2_5", description="Name of the sorter to use.") sorter_kwargs: Kilosort25SortingContext = Field(default=Kilosort25SortingContext(), description="Sorter specific kwargs.") @@ -109,11 +117,16 @@ class CurationContext(BaseModel): class PipelineContext(BaseModel): input: InputFile = Field(description='Input NWB file') output: OutputFile = Field(description='Output NWB file') + lazy_read_input: bool = Field(default=True, description='Lazily read the input file (stream instead of downloading)') + stub_test: bool = Field(default=False, description='Run a stub test on a short initial segment of the recording') recording_context: RecordingContext = Field(description='Recording context') - preprocessing_context: PreprocessingContext = Field(description='Preprocessing context') - sorting_context: SortingContext = Field(description='Sorting context') - postprocessing_context: PostprocessingContext = Field(description='Postprocessing context') - curation_context: CurationContext = Field(description='Curation context') + run_preprocessing: bool = Field(default=True, description='Run preprocessing') + preprocessing_context: PreprocessingContext = Field(default=PreprocessingContext(), description='Preprocessing context') + run_spikesorting: bool = Field(default=True, description='Run spike sorting') + spikesorting_context: SpikeSortingContext = Field(default=SpikeSortingContext(), description='Sorting context') + run_postprocessing: bool = Field(default=True, description='Run postprocessing') + # postprocessing_context: PostprocessingContext = Field(default=PostprocessingContext(), description='Postprocessing context') + # curation_context: CurationContext = Field(default=CurationContext(), description='Curation context') # # ------------------------------ diff --git a/si_kilosort25/nwb_utils.py b/si_kilosort25/nwb_utils.py index c6b0016..66cff2d 100644 --- a/si_kilosort25/nwb_utils.py +++ b/si_kilosort25/nwb_utils.py @@ -1,16 +1,23 @@ from typing import Union, List -from neuroconv.tools.spikeinterface import write_sorting +# from neuroconv.tools.spikeinterface import write_sorting from pynwb import NWBFile from pynwb.file import Subject from uuid import uuid4 import numpy as np import h5py import spikeinterface as si +import pynwb class NwbRecording(si.BaseRecording): - def __init__(self, file: h5py.File, electrical_series_path: str) -> None: - electrical_series: h5py.Group = file[electrical_series_path] + def __init__( + self, + file, # file-like object + electrical_series_path: str + ) -> None: + h5_file = h5py.File(file, 'r') + + electrical_series: h5py.Group =
h5_file[electrical_series_path] electrical_series_data = electrical_series['data'] dtype = electrical_series_data.dtype @@ -24,26 +31,30 @@ def __init__(self, file: h5py.File, electrical_series_path: str) -> None: # Get channel ids electrode_indices = electrical_series['electrodes'][:] - electrodes_table = file['/general/extracellular_ephys/electrodes'] + electrodes_table = h5_file['/general/extracellular_ephys/electrodes'] channel_ids = [electrodes_table['id'][i] for i in electrode_indices] - si.BaseRecording.__init__(self, channel_ids=channel_ids, sampling_frequency=sampling_frequency, dtype=dtype) + super().__init__( + channel_ids=channel_ids, + sampling_frequency=sampling_frequency, + dtype=dtype + ) # Set electrode locations - if 'x' in electrodes_table: - channel_loc_x = [electrodes_table['x'][i] for i in electrode_indices] - channel_loc_y = [electrodes_table['y'][i] for i in electrode_indices] - if 'z' in electrodes_table: - channel_loc_z = [electrodes_table['z'][i] for i in electrode_indices] - else: - channel_loc_z = None - elif 'rel_x' in electrodes_table: + if 'rel_x' in electrodes_table: channel_loc_x = [electrodes_table['rel_x'][i] for i in electrode_indices] channel_loc_y = [electrodes_table['rel_y'][i] for i in electrode_indices] if 'rel_z' in electrodes_table: channel_loc_z = [electrodes_table['rel_z'][i] for i in electrode_indices] else: channel_loc_z = None + elif 'x' in electrodes_table: + channel_loc_x = [electrodes_table['x'][i] for i in electrode_indices] + channel_loc_y = [electrodes_table['y'][i] for i in electrode_indices] + if 'z' in electrodes_table: + channel_loc_z = [electrodes_table['z'][i] for i in electrode_indices] + else: + channel_loc_z = None else: channel_loc_x = None channel_loc_y = None @@ -58,6 +69,18 @@ def __init__(self, file: h5py.File, electrical_series_path: str) -> None: locations[i, 2] = channel_loc_z[electrode_index] self.set_dummy_probe_from_locations(locations) + # Extractors channel groups must be integers, but Nwb electrodes group_name can be strings + if "group_name" in electrodes_table: + unique_electrode_group_names = list(np.unique(electrodes_table["group_name"][:])) + print(unique_electrode_group_names) + + groups = [] + for electrode_index in electrode_indices: + group_name = electrodes_table["group_name"][electrode_index] + group_id = unique_electrode_group_names.index(group_name) + groups.append(group_id) + self.set_channel_groups(groups) + recording_segment = NwbRecordingSegment( electrical_series_data=electrical_series_data, sampling_frequency=sampling_frequency @@ -68,7 +91,7 @@ def __init__(self, file: h5py.File, electrical_series_path: str) -> None: class NwbRecordingSegment(si.BaseRecordingSegment): def __init__(self, electrical_series_data: h5py.Dataset, sampling_frequency: float) -> None: self._electrical_series_data = electrical_series_data - si.BaseRecordingSegment.__init__(self, sampling_frequency=sampling_frequency) + super().__init__(sampling_frequency=sampling_frequency) def get_num_samples(self) -> int: return self._electrical_series_data.shape[0] @@ -80,7 +103,7 @@ def get_traces(self, start_frame: int, end_frame: int, channel_indices: Union[Li return self._electrical_series_data[start_frame:end_frame, channel_indices] -def create_sorting_out_nwb_file(nwbfile_original: NWBFile, sorting: si.BaseSorting, sorting_out_fname: str): +def create_sorting_out_nwb_file(nwbfile_original, sorting: si.BaseSorting, sorting_out_fname: str): nwbfile = NWBFile( session_description=nwbfile_original.session_description + " - spike 
sorting results.", identifier=str(uuid4()), @@ -91,19 +114,29 @@ def create_sorting_out_nwb_file(nwbfile_original: NWBFile, sorting: si.BaseSorti institution=nwbfile_original.institution, experiment_description=nwbfile_original.experiment_description, related_publications=nwbfile_original.related_publications, + subject=Subject( + subject_id=nwbfile_original.subject.subject_id, + age=nwbfile_original.subject.age, + description=nwbfile_original.subject.description, + species=nwbfile_original.subject.species, + sex=nwbfile_original.subject.sex, + ) ) - subject = Subject( - subject_id=nwbfile_original.subject.subject_id, - age=nwbfile_original.subject.age, - description=nwbfile_original.subject.description, - species=nwbfile_original.subject.species, - sex=nwbfile_original.subject.sex, - ) - nwbfile.subject = subject - write_sorting( - sorting=sorting, - nwbfile=nwbfile, - nwbfile_path=sorting_out_fname, - overwrite=True - ) + for ii, unit_id in enumerate(sorting.get_unit_ids()): + st = sorting.get_unit_spike_train(unit_id) / sorting.get_sampling_frequency() + nwbfile.add_unit( + id=ii + 1, # must be an int + spike_times=st + ) + + # Write the nwb file + with pynwb.NWBHDF5IO(path=sorting_out_fname, mode='w') as io: + io.write(container=nwbfile, cache_spec=True) + + # write_sorting( + # sorting=sorting, + # nwbfile=nwbfile, + # nwbfile_path=sorting_out_fname, + # overwrite=True + # ) diff --git a/si_kilosort25/processor_pipeline.py b/si_kilosort25/processor_pipeline.py index b0ee668..3b4defb 100644 --- a/si_kilosort25/processor_pipeline.py +++ b/si_kilosort25/processor_pipeline.py @@ -1,17 +1,24 @@ from dendro.sdk import ProcessorBase from spikeinterface_pipelines import pipeline as si_pipeline +from spikeinterface.extractors import NwbRecordingExtractor import os import pynwb +import h5py +import logging from models import PipelineContext -from nwb_utils import NwbRecording, create_sorting_out_nwb_file +from nwb_utils import create_sorting_out_nwb_file + + +logging.basicConfig(level=logging.INFO) +logger = logging.getLogger(__name__) class PipelineProcessor(ProcessorBase): - name = 'spikeinterface_pipeline' - label = 'SpikeInterface Pipeline' - description = "SpikeInterface Pipeline Processor" - tags = ['spike_interface', 'electrophysiology', 'preprocessing', 'spike_sorter', 'postprocessing'] + name = 'spikeinterface_pipeline_ks25' + label = 'SpikeInterface Pipeline - Kilosort 2.5' + description = "SpikeInterface Pipeline Processor for Kilosort 2.5" + tags = ['spike_sorting', 'spike_interface', 'electrophysiology', 'pipeline'] attributes = { 'wip': True } @@ -20,25 +27,60 @@ class PipelineProcessor(ProcessorBase): def run(context: PipelineContext): # Create SI recording from InputFile - input = context.input_file - recording = NwbRecording( - file=input.get_h5py_file(), - electrical_series_path=context.recording_context.electrical_series_path + logger.info('Opening remote input file') + download = not context.lazy_read_input + ff = context.input.get_file(download=download) + + logger.info('Creating input recording') + recording = NwbRecordingExtractor( + file=ff, + electrical_series_location=context.recording_context.electrical_series_path, + # file_path=context.input.get_url(), + # stream_mode="remfile" ) + if context.stub_test: + logger.info('Running in stub test mode') + n_frames = int(min(300_000, recording.get_num_frames())) + recording = recording.frame_slice(start_frame=0, end_frame=n_frames) + + logger.info(recording) + # TODO - run pipeline - _, sorting = si_pipeline.pipeline( + 
job_kwargs = { + 'n_jobs': -1, + 'chunk_duration': '1s', + 'progress_bar': False + } + + run_preprocessing = context.run_preprocessing + preprocessing_params = context.preprocessing_context.model_dump() + + run_spikesorting = context.run_spikesorting + spikesorting_params = context.spikesorting_context.model_dump() + + run_postprocessing = context.run_postprocessing + # postprocessing_params = context.postprocessing_context.model_dump() + + logger.info('Running pipeline') + _, sorting, _ = si_pipeline.run_pipeline( recording=recording, - results_path="./results/", - preprocessing_params=context.preprocessing_params, - sorting_params=context.sorting_context, - # postprocessing_params=context.postprocessing_params, - # run_preprocessing=context.run_preprocessing, + scratch_folder="./scratch/", + results_folder="./results/", + job_kwargs=job_kwargs, + run_preprocessing=run_preprocessing, + preprocessing_params=preprocessing_params, + run_spikesorting=run_spikesorting, + spikesorting_params=spikesorting_params, + run_postprocessing=run_postprocessing, + # postprocessing_params=postprocessing_params, ) # TODO - upload output file - print('Writing output NWB file') - with pynwb.NWBHDF5IO(file=input.get_h5py_file(), mode='r', load_namespaces=True) as io: + logger.info('Writing output NWB file') + h5_file = h5py.File(ff, 'r') + with pynwb.NWBHDF5IO(file=h5_file, mode='r', load_namespaces=True) as io: + # with pynwb.NWBHDF5IO(file=input.get_h5py_file(), mode='r', load_namespaces=True) as io: nwbfile_rec = io.read() if not os.path.exists('output'): @@ -46,10 +88,10 @@ def run(context: PipelineContext): sorting_out_fname = 'output/sorting.nwb' create_sorting_out_nwb_file( - nwbfile_rec=nwbfile_rec, + nwbfile_original=nwbfile_rec, sorting=sorting, sorting_out_fname=sorting_out_fname ) - print('Uploading output NWB file') - context.output.set(sorting_out_fname) + logger.info('Uploading output NWB file') + context.output.upload(sorting_out_fname) diff --git a/si_kilosort25/requirements.txt b/si_kilosort25/requirements.txt index da8a948..153f45e 100644 --- a/si_kilosort25/requirements.txt +++ b/si_kilosort25/requirements.txt @@ -1,3 +1,4 @@ -spikeinterface[full]==0.98.2 +pynwb +# spikeinterface[full]==0.99.1 # spikeinterface_pipelines # dendro diff --git a/si_kilosort25/sample_context_1.yaml b/si_kilosort25/sample_context_1.yaml new file mode 100644 index 0000000..1e498c9 --- /dev/null +++ b/si_kilosort25/sample_context_1.yaml @@ -0,0 +1,11 @@ +input: https://dandi-api-staging-dandisets.s3.amazonaws.com/blobs/eb0/202/eb020241-616b-47ce-8d52-76151fe9e90d +output: ./output/sorting_upload.nwb +lazy_read_input: true +stub_test: true +recording_context: + electrical_series_path: /acquisition/ElectricalSeriesRaw +run_preprocessing: false +run_spikesorting: true +spikesorting_context: + do_correction: false +run_postprocessing: false diff --git a/si_kilosort25/sample_context_2.yaml b/si_kilosort25/sample_context_2.yaml new file mode 100644 index 0000000..0abfbe3 --- /dev/null +++ b/si_kilosort25/sample_context_2.yaml @@ -0,0 +1,6 @@ +input: https://dandi-api-staging-dandisets.s3.amazonaws.com/blobs/1ed/41e/1ed41e35-8445-4608-b327-b30f74388bea +output: ./output/sorting.nwb +lazy_read_input: true +stub_test: false +recording_context: + electrical_series_path: /acquisition/ElectricalSeriesRaw diff --git a/si_kilosort25/spec.json b/si_kilosort25/spec.json index 6973f47..ce6b714 100644 --- a/si_kilosort25/spec.json +++ b/si_kilosort25/spec.json @@ -1,180 +1,221 @@ { "name": "si_kilosort25", "description": "Spike 
Interface Pipeline - Kilosort2.5", - "appImage": "ghcr.io/catalystneuro/si_kilosort25", + "appImage": "ghcr.io/catalystneuro/dendro_si_kilosort25", "appExecutable": "/app/main.py", "executable": "/app/main.py", "processors": [ { - "name": "spikeinterface_pipeline", - "description": "SpikeInterface Pipeline Processor", - "label": "SpikeInterface Pipeline", + "name": "spikeinterface_pipeline_ks25", + "description": "SpikeInterface Pipeline Processor for Kilosort 2.5", + "label": "SpikeInterface Pipeline - Kilosort 2.5", "inputs": [ { "name": "input", - "description": "" + "description": "Input NWB file" } ], "outputs": [ { "name": "output", - "description": "" + "description": "Output NWB file" } ], "parameters": [ + { + "name": "lazy_read_input", + "description": "Lazy read input file", + "type": "bool", + "default": true + }, + { + "name": "stub_test", + "description": "Stub test", + "type": "bool", + "default": false + }, { "name": "recording_context.electrical_series_path", - "description": "", + "description": "Path to the electrical series in the NWB file", + "type": "str" + }, + { + "name": "run_preprocessing", + "description": "Run preprocessing", + "type": "bool", + "default": true + }, + { + "name": "preprocessing_context.preprocessing_strategy", + "description": "Strategy for preprocessing", "type": "str", - "default": null + "default": "cmr" }, { "name": "preprocessing_context.highpass_filter.freq_min", - "description": "", + "description": "Minimum frequency for the highpass filter", "type": "float", "default": 300.0 }, { "name": "preprocessing_context.highpass_filter.margin_ms", - "description": "", + "description": "Margin in milliseconds", "type": "float", "default": 5.0 }, { "name": "preprocessing_context.phase_shift.margin_ms", - "description": "", + "description": "Margin in milliseconds for phase shift", "type": "float", "default": 100.0 }, { "name": "preprocessing_context.detect_bad_channels.method", - "description": "", + "description": "Method to detect bad channels", "type": "str", "default": "coherence+psd" }, { "name": "preprocessing_context.detect_bad_channels.dead_channel_threshold", - "description": "", + "description": "Threshold for dead channel", "type": "float", "default": -0.5 }, { "name": "preprocessing_context.detect_bad_channels.noisy_channel_threshold", - "description": "", + "description": "Threshold for noisy channel", "type": "float", "default": 1.0 }, { "name": "preprocessing_context.detect_bad_channels.outside_channel_threshold", - "description": "", + "description": "Threshold for outside channel", "type": "float", "default": -0.3 }, { "name": "preprocessing_context.detect_bad_channels.n_neighbors", - "description": "", + "description": "Number of neighbors", "type": "int", "default": 11 }, { "name": "preprocessing_context.detect_bad_channels.seed", - "description": "", + "description": "Seed value", "type": "int", "default": 0 }, { "name": "preprocessing_context.common_reference.reference", - "description": "", + "description": "Type of reference", "type": "str", "default": "global" }, { "name": "preprocessing_context.common_reference.operator", - "description": "", + "description": "Operator used for common reference", "type": "str", "default": "median" }, { "name": "preprocessing_context.highpass_spatial_filter.n_channel_pad", - "description": "", + "description": "Number of channels to pad", "type": "int", "default": 60 }, { "name": "preprocessing_context.highpass_spatial_filter.n_channel_taper", - "description": "", + "description": "Number of 
channels to taper", "type": "int", "default": null }, { "name": "preprocessing_context.highpass_spatial_filter.direction", - "description": "", + "description": "Direction for the spatial filter", "type": "str", "default": "y" }, { "name": "preprocessing_context.highpass_spatial_filter.apply_agc", - "description": "", + "description": "Whether to apply automatic gain control", "type": "bool", "default": true }, { "name": "preprocessing_context.highpass_spatial_filter.agc_window_length_s", - "description": "", + "description": "Window length in seconds for AGC", "type": "float", "default": 0.01 }, { "name": "preprocessing_context.highpass_spatial_filter.highpass_butter_order", - "description": "", + "description": "Order for the Butterworth filter", "type": "int", "default": 3 }, { "name": "preprocessing_context.highpass_spatial_filter.highpass_butter_wn", - "description": "", + "description": "Natural frequency for the Butterworth filter", "type": "float", "default": 0.01 }, { - "name": "preprocessing_context.preprocessing_strategy", - "description": "", + "name": "preprocessing_context.motion_correction.compute", + "description": "Whether to compute motion correction", + "type": "bool", + "default": true + }, + { + "name": "preprocessing_context.motion_correction.apply", + "description": "Whether to apply motion correction", + "type": "bool", + "default": false + }, + { + "name": "preprocessing_context.motion_correction.preset", + "description": "Preset for motion correction", "type": "str", - "default": "cmr" + "default": "nonrigid_accurate" }, { "name": "preprocessing_context.remove_out_channels", - "description": "", + "description": "Flag to remove out channels", "type": "bool", "default": false }, { "name": "preprocessing_context.remove_bad_channels", - "description": "", + "description": "Flag to remove bad channels", "type": "bool", "default": false }, { "name": "preprocessing_context.max_bad_channel_fraction_to_remove", - "description": "", + "description": "Maximum fraction of bad channels to remove", "type": "float", "default": 1.1 }, { - "name": "sorting_context.sorter_name", - "description": "", + "name": "run_spikesorting", + "description": "Run spike sorting", + "type": "bool", + "default": true + }, + { + "name": "spikesorting_context.sorter_name", + "description": "Name of the sorter to use.", "type": "str", "default": "kilosort2_5" }, { - "name": "sorting_context.sorter_kwargs.detect_threshold", - "description": "", + "name": "spikesorting_context.sorter_kwargs.detect_threshold", + "description": "Threshold for spike detection", "type": "float", "default": 6 }, { - "name": "sorting_context.sorter_kwargs.projection_threshold", - "description": "", + "name": "spikesorting_context.sorter_kwargs.projection_threshold", + "description": "Threshold on projections", "type": "List[int]", "default": [ 10, @@ -182,112 +223,112 @@ ] }, { - "name": "sorting_context.sorter_kwargs.preclust_threshold", - "description": "", + "name": "spikesorting_context.sorter_kwargs.preclust_threshold", + "description": "Threshold crossings for pre-clustering (in PCA projection space)", "type": "float", "default": 8 }, { - "name": "sorting_context.sorter_kwargs.car", - "description": "", + "name": "spikesorting_context.sorter_kwargs.car", + "description": "Enable or disable common reference", "type": "bool", "default": true }, { - "name": "sorting_context.sorter_kwargs.minFR", - "description": "", + "name": "spikesorting_context.sorter_kwargs.minFR", + "description": "Minimum spike rate (Hz), if a 
cluster falls below this for too long it gets removed", "type": "float", "default": 0.1 }, { - "name": "sorting_context.sorter_kwargs.minfr_goodchannels", - "description": "", + "name": "spikesorting_context.sorter_kwargs.minfr_goodchannels", + "description": "Minimum firing rate on a 'good' channel", "type": "float", "default": 0.1 }, { - "name": "sorting_context.sorter_kwargs.nblocks", - "description": "", + "name": "spikesorting_context.sorter_kwargs.nblocks", + "description": "blocks for registration. 0 turns it off, 1 does rigid registration. Replaces 'datashift' option.", "type": "int", "default": 5 }, { - "name": "sorting_context.sorter_kwargs.sig", - "description": "", + "name": "spikesorting_context.sorter_kwargs.sig", + "description": "spatial smoothness constant for registration", "type": "float", "default": 20 }, { - "name": "sorting_context.sorter_kwargs.freq_min", - "description": "", + "name": "spikesorting_context.sorter_kwargs.freq_min", + "description": "High-pass filter cutoff frequency", "type": "float", "default": 150 }, { - "name": "sorting_context.sorter_kwargs.sigmaMask", - "description": "", + "name": "spikesorting_context.sorter_kwargs.sigmaMask", + "description": "Spatial constant in um for computing residual variance of spike", "type": "float", "default": 30 }, { - "name": "sorting_context.sorter_kwargs.nPCs", - "description": "", + "name": "spikesorting_context.sorter_kwargs.lam", + "description": "The importance of the amplitude penalty (like in Kilosort1: 0 means not used, 10 is average, 50 is a lot)", + "type": "float", + "default": 10.0 + }, + { + "name": "spikesorting_context.sorter_kwargs.nPCs", + "description": "Number of PCA dimensions", "type": "int", "default": 3 }, { - "name": "sorting_context.sorter_kwargs.ntbuff", - "description": "", + "name": "spikesorting_context.sorter_kwargs.ntbuff", + "description": "Samples of symmetrical buffer for whitening and spike detection", "type": "int", "default": 64 }, { - "name": "sorting_context.sorter_kwargs.nfilt_factor", - "description": "", + "name": "spikesorting_context.sorter_kwargs.nfilt_factor", + "description": "Max number of clusters per good channel (even temporary ones) 4", "type": "int", "default": 4 }, { - "name": "sorting_context.sorter_kwargs.NT", - "description": "", - "type": "int", - "default": -1 - }, - { - "name": "sorting_context.sorter_kwargs.AUCsplit", - "description": "", + "name": "spikesorting_context.sorter_kwargs.AUCsplit", + "description": "Threshold on the area under the curve (AUC) criterion for performing a split in the final step", "type": "float", "default": 0.9 }, { - "name": "sorting_context.sorter_kwargs.do_correction", - "description": "", + "name": "spikesorting_context.sorter_kwargs.do_correction", + "description": "If True drift registration is applied", "type": "bool", "default": true }, { - "name": "sorting_context.sorter_kwargs.wave_length", - "description": "", + "name": "spikesorting_context.sorter_kwargs.wave_length", + "description": "size of the waveform extracted around each detected peak, (Default 61, maximum 81)", "type": "float", "default": 61 }, { - "name": "sorting_context.sorter_kwargs.keep_good_only", - "description": "", + "name": "spikesorting_context.sorter_kwargs.keep_good_only", + "description": "If True only 'good' units are returned", "type": "bool", - "default": true + "default": false }, { - "name": "sorting_context.sorter_kwargs.skip_kilosort_preprocessing", - "description": "", + "name": 
"spikesorting_context.sorter_kwargs.skip_kilosort_preprocessing", + "description": "Can optionaly skip the internal kilosort preprocessing", "type": "bool", "default": false }, { - "name": "sorting_context.sorter_kwargs.scaleproc", - "description": "", - "type": "int", - "default": -1 + "name": "run_postprocessing", + "description": "Run postprocessing", + "type": "bool", + "default": true } ], "attributes": [ @@ -298,19 +339,16 @@ ], "tags": [ { - "tag": "spike_interface" - }, - { - "tag": "electrophysiology" + "tag": "spike_sorting" }, { - "tag": "preprocessing" + "tag": "spike_interface" }, { - "tag": "spike_sorter" + "tag": "electrophysiology" }, { - "tag": "postprocessing" + "tag": "pipeline" } ] } diff --git a/si_kilosort25/test_in_container.sh b/si_kilosort25/test_in_container.sh new file mode 100644 index 0000000..b889390 --- /dev/null +++ b/si_kilosort25/test_in_container.sh @@ -0,0 +1,19 @@ +#!/bin/bash + +# Docker image +IMAGE="ghcr.io/catalystneuro/dendro_si_kilosort25" + +# Command to be executed inside the container +ENTRYPOINT_CMD="dendro" +ARGS="test-app-processor --app-dir . --processor spikeinterface_pipeline_ks25 --context sample_context_1.yaml" + + +# Run the Docker container +docker run --gpus all \ + -v $(pwd):/app \ + -v /mnt/shared_storage/Github/dendro/python:/src/dendro/python \ + -v /mnt/shared_storage/Github/spikeinterface_pipelines:/src/spikeinterface_pipelines \ + -w /app \ + --entrypoint "$ENTRYPOINT_CMD" \ + $IMAGE \ + $ARGS