From 8338688468f7f8e6b3636da07a9500ce9ded0aa5 Mon Sep 17 00:00:00 2001 From: luiz Date: Thu, 18 Jan 2024 17:36:01 +0100 Subject: [PATCH 1/5] start organizing it --- common/models.py | 20 +++++++++++---- common/models_preprocessing.py | 1 + common/nwb_utils.py | 25 +++++++++++------- common/processor_pipeline.py | 20 ++++++++------- si_preprocessing/Dockerfile | 0 si_preprocessing/main.py | 46 ++++++++++++++++++++++++++++++++++ 6 files changed, 89 insertions(+), 23 deletions(-) create mode 100644 si_preprocessing/Dockerfile create mode 100644 si_preprocessing/main.py diff --git a/common/models.py b/common/models.py index 00054c5..188c5dc 100644 --- a/common/models.py +++ b/common/models.py @@ -8,6 +8,9 @@ class RecordingContext(BaseModel): electrical_series_path: str = Field(description='Path to the electrical series in the NWB file') + lazy_read_input: bool = Field(default=True, description='Lazy read input file') + write_recording: bool = Field(default=False, description='Write recording') + stub_test: bool = Field(default=False, description='Stub test') class JobKwargs(BaseModel): @@ -66,12 +69,17 @@ class Kilosort3SortingContext(BaseModel): # ------------------------------ # Pipeline Models # ------------------------------ -class PipelineContext(BaseModel): +class PipelinePreprocessingContext(BaseModel): + input: InputFile = Field(description='Input NWB file') + output: OutputFile = Field(description='Output NWB file') + job_kwargs: JobKwargs = Field(default=JobKwargs(), description='Job kwargs') + recording_context: RecordingContext = Field(description='Recording context') + preprocessing_context: PreprocessingContext = Field(default=PreprocessingContext()) + + +class PipelineFullContext(BaseModel): input: InputFile = Field(description='Input NWB file') output: OutputFile = Field(description='Output NWB file') - lazy_read_input: bool = Field(default=True, description='Lazy read input file') - write_recording: bool = Field(default=False, description='Write recording') - stub_test: bool = Field(default=False, description='Stub test') job_kwargs: JobKwargs = Field(default=JobKwargs(), description='Job kwargs') recording_context: RecordingContext = Field(description='Recording context') run_preprocessing: bool = Field(default=True, description='Run preprocessing') @@ -83,8 +91,10 @@ class PipelineContext(BaseModel): ] = Field(description='Sorting context') run_postprocessing: bool = Field(default=True, description='Run postprocessing') postprocessing_context: PostprocessingContext = Field(default=PostprocessingContext(), description='Postprocessing context') + run_curation: bool = Field(default=True, description='Run curation') # curation_context: CurationContext = Field(default=CurationContext(), description='Curation context') - + run_visualization: bool = Field(default=True, description='Run visualization') + # visualization_context: VisualizationContext = Field(default=VisualizationContext(), description='Visualization context') diff --git a/common/models_preprocessing.py b/common/models_preprocessing.py index 7e720be..33b9300 100644 --- a/common/models_preprocessing.py +++ b/common/models_preprocessing.py @@ -163,6 +163,7 @@ class MotionCorrection(BaseModel): class PreprocessingContext(BaseModel): + add_preprocessed_to_output_nwb: bool = Field(default=False, description="Whether to add preprocessed data to output NWB file or not") preprocessing_strategy: str = Field(default="cmr", description="Strategy for preprocessing") highpass_filter: HighpassFilter = 
Field(default=HighpassFilter(), description="Highpass filter") phase_shift: PhaseShift = Field(default=PhaseShift(), description="Phase shift") diff --git a/common/nwb_utils.py b/common/nwb_utils.py index 66cff2d..ca38a7b 100644 --- a/common/nwb_utils.py +++ b/common/nwb_utils.py @@ -1,4 +1,4 @@ -from typing import Union, List +from typing import Union, List, Optional # from neuroconv.tools.spikeinterface import write_sorting from pynwb import NWBFile from pynwb.file import Subject @@ -103,7 +103,12 @@ def get_traces(self, start_frame: int, end_frame: int, channel_indices: Union[Li return self._electrical_series_data[start_frame:end_frame, channel_indices] -def create_sorting_out_nwb_file(nwbfile_original, sorting: si.BaseSorting, sorting_out_fname: str): +def create_sorting_out_nwb_file( + nwbfile_original, + recording: Optional[si.BaseRecording] = None, + sorting: Optional[si.BaseSorting] = None, + output_fname: Optional[str] = None +): nwbfile = NWBFile( session_description=nwbfile_original.session_description + " - spike sorting results.", identifier=str(uuid4()), @@ -123,15 +128,17 @@ def create_sorting_out_nwb_file(nwbfile_original, sorting: si.BaseSorting, sorti ) ) - for ii, unit_id in enumerate(sorting.get_unit_ids()): - st = sorting.get_unit_spike_train(unit_id) / sorting.get_sampling_frequency() - nwbfile.add_unit( - id=ii + 1, # must be an int - spike_times=st - ) + # Add sorting + if sorting is not None: + for ii, unit_id in enumerate(sorting.get_unit_ids()): + st = sorting.get_unit_spike_train(unit_id) / sorting.get_sampling_frequency() + nwbfile.add_unit( + id=ii + 1, # must be an int + spike_times=st + ) # Write the nwb file - with pynwb.NWBHDF5IO(path=sorting_out_fname, mode='w') as io: + with pynwb.NWBHDF5IO(path=output_fname, mode='w') as io: io.write(container=nwbfile, cache_spec=True) # write_sorting( diff --git a/common/processor_pipeline.py b/common/processor_pipeline.py index 5c4289b..f1c9527 100644 --- a/common/processor_pipeline.py +++ b/common/processor_pipeline.py @@ -6,7 +6,7 @@ import h5py import logging -from .models import PipelineContext +from .models import PipelineFullContext from .nwb_utils import create_sorting_out_nwb_file @@ -16,7 +16,7 @@ os.environ["OPENBLAS_NUM_THREADS"] = "1" -def run_pipeline(context: PipelineContext): +def run_pipeline(context: PipelineFullContext): """ Runs the spikeinterface pipeline. 
@@ -31,7 +31,7 @@ def run_pipeline(context: PipelineContext): # Create SI recording from InputFile logger.info('Opening remote input file') - download = not context.lazy_read_input + download = not context.recording_context.lazy_read_input ff = context.input.get_file(download=download) logger.info('Creating input recording') @@ -42,14 +42,14 @@ def run_pipeline(context: PipelineContext): # stream_mode="remfile" ) - if context.stub_test: + if context.recording_context.stub_test: logger.info('Running in stub test mode') n_frames = int(min(300_000, recording.get_num_frames())) recording = recording.frame_slice(start_frame=0, end_frame=n_frames) logger.info(recording) - if context.write_recording: + if context.recording_context.write_recording: logger.info('Writing binary recording') recording = recording.save(folder=scratch_folder / "recording") @@ -102,7 +102,7 @@ def run_pipeline(context: PipelineContext): # Run pipeline logger.info('Running pipeline') - _, sorting, _ = si_pipeline.run_pipeline( + recording_preprocessed, sorting, waveform_extractor = si_pipeline.run_pipeline( recording=recording, scratch_folder="./scratch/", results_folder="./results/", @@ -124,13 +124,15 @@ def run_pipeline(context: PipelineContext): if not os.path.exists('output'): os.mkdir('output') - sorting_out_fname = 'output/sorting.nwb' + + output_fname = 'output/output.nwb' create_sorting_out_nwb_file( nwbfile_original=nwbfile_rec, + recording=recording_preprocessed, sorting=sorting, - sorting_out_fname=sorting_out_fname + output_fname=output_fname ) logger.info('Uploading output NWB file') - context.output.upload(sorting_out_fname) + context.output.upload(output_fname) diff --git a/si_preprocessing/Dockerfile b/si_preprocessing/Dockerfile new file mode 100644 index 0000000..e69de29 diff --git a/si_preprocessing/main.py b/si_preprocessing/main.py new file mode 100644 index 0000000..ca7b8ec --- /dev/null +++ b/si_preprocessing/main.py @@ -0,0 +1,46 @@ +#!/usr/bin/env python3 + +from dendro.sdk import App, ProcessorBase +from common.models import PipelineFullContext, PipelinePreprocessingContext +from common.processor_pipeline import run_pipeline + + +app_name = 'si_preprocessing' + +app = App( + name=app_name, + description="Spike Interface Pipeline - Preprocessing", + app_image=f"ghcr.io/catalystneuro/dendro_{app_name}", + app_executable="/app/main.py" +) + + +class PipelineProcessor(ProcessorBase): + name = 'spikeinterface_pipeline_preprocessing' + label = 'SpikeInterface Pipeline - Preprocessing' + description = "SpikeInterface Pipeline Processor for Preprocessing tasks" + tags = ['spike_interface', 'preprocessing', 'electrophysiology', 'pipeline'] + attributes = { + 'wip': True + } + + @staticmethod + def run(context: PipelinePreprocessingContext): + context_preprocessing = context.model_dump() + context_preprocessing['preprocessing_context']['add_preprocessed_to_output_nwb'] = True + context_full = PipelineFullContext( + run_preprocessing=True, + run_spikesorting=False, + run_postprocessing=False, + run_curation=False, + run_visualization=False, + **context_preprocessing + ) + run_pipeline(context_full) + + +app.add_processor(PipelineProcessor) + + +if __name__ == '__main__': + app.run() From 98334f93760060560790acc9a59b7b824009e071 Mon Sep 17 00:00:00 2001 From: luiz Date: Tue, 6 Feb 2024 13:51:25 -0300 Subject: [PATCH 2/5] fix model --- common/processor_pipeline.py | 2 +- si_kilosort25/main.py | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/common/processor_pipeline.py 
b/common/processor_pipeline.py index 8743844..87c5eb1 100644 --- a/common/processor_pipeline.py +++ b/common/processor_pipeline.py @@ -44,7 +44,7 @@ def run_pipeline(context: PipelineFullContext): if context.recording_context.stub_test: logger.info('Running in stub test mode') - stub_test_num_frames = context.stub_test_duration_sec * recording.get_sampling_frequency() + stub_test_num_frames = context.recording_context.stub_test_duration_sec * recording.get_sampling_frequency() n_frames = int(min(stub_test_num_frames, recording.get_num_frames())) recording = recording.frame_slice(start_frame=0, end_frame=n_frames) diff --git a/si_kilosort25/main.py b/si_kilosort25/main.py index cf86292..53e759f 100755 --- a/si_kilosort25/main.py +++ b/si_kilosort25/main.py @@ -4,7 +4,7 @@ from pydantic import Field from common.models import ( Kilosort25SortingContext, - PipelineContext as CommonPipelineContext + PipelineFullContext ) from common.processor_pipeline import run_pipeline @@ -20,7 +20,7 @@ # We need to overwrite this with the specific sorter, to generate the correct forms -class PipelineContext(CommonPipelineContext): +class PipelineContext(PipelineFullContext): spikesorting_context: Kilosort25SortingContext = Field(default=Kilosort25SortingContext()) From a3a866d28743118794234f38cdd9007679eff5d3 Mon Sep 17 00:00:00 2001 From: luiz Date: Tue, 6 Feb 2024 14:05:09 -0300 Subject: [PATCH 3/5] args --- common/processor_pipeline.py | 18 ++++++++++++++---- 1 file changed, 14 insertions(+), 4 deletions(-) diff --git a/common/processor_pipeline.py b/common/processor_pipeline.py index 87c5eb1..19b65e9 100644 --- a/common/processor_pipeline.py +++ b/common/processor_pipeline.py @@ -101,19 +101,29 @@ def run_pipeline(context: PipelineFullContext): qm_list.append('nn_noise_overlap') postprocessing_params['quality_metrics']['metric_names'] = qm_list + # Curation params + run_curation = context.run_curation + + # Visualization params + run_visualization = context.run_visualization + # Run pipeline logger.info('Running pipeline') - recording_preprocessed, sorting, waveform_extractor = si_pipeline.run_pipeline( + recording_preprocessed, sorting, waveform_extractor, sorting_curated, visualization_output = si_pipeline.run_pipeline( recording=recording, scratch_folder="./scratch/", results_folder="./results/", job_kwargs=job_kwargs, - run_preprocessing=run_preprocessing, preprocessing_params=preprocessing_params, - run_spikesorting=run_spikesorting, spikesorting_params=spikesorting_params, - run_postprocessing=run_postprocessing, postprocessing_params=postprocessing_params, + # curation_params=, + # visualization_params=, + run_preprocessing=run_preprocessing, + run_spikesorting=run_spikesorting, + run_postprocessing=run_postprocessing, + run_curation=run_curation, + run_visualization=run_visualization ) # Upload output file From ebc6addd3f8694aa773581fa00f07fe07a70e038 Mon Sep 17 00:00:00 2001 From: luiz Date: Tue, 6 Feb 2024 14:09:51 -0300 Subject: [PATCH 4/5] path --- common/processor_pipeline.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/common/processor_pipeline.py b/common/processor_pipeline.py index 19b65e9..b1a4da5 100644 --- a/common/processor_pipeline.py +++ b/common/processor_pipeline.py @@ -37,7 +37,7 @@ def run_pipeline(context: PipelineFullContext): logger.info('Creating input recording') recording = NwbRecordingExtractor( file=ff, - electrical_series_location=context.recording_context.electrical_series_path, + 
electrical_series_path=context.recording_context.electrical_series_path, # file_path=context.input.get_url(), # stream_mode="remfile" ) From 255841d2420d3b1160475147d6ee02e8b1e3b022 Mon Sep 17 00:00:00 2001 From: luiz Date: Tue, 6 Feb 2024 14:24:44 -0300 Subject: [PATCH 5/5] curation --- common/models.py | 3 ++- common/models_curation.py | 11 +++++++++++ common/processor_pipeline.py | 5 +++-- 3 files changed, 16 insertions(+), 3 deletions(-) create mode 100644 common/models_curation.py diff --git a/common/models.py b/common/models.py index e5cf2cd..161b74e 100644 --- a/common/models.py +++ b/common/models.py @@ -4,6 +4,7 @@ from .models_preprocessing import PreprocessingContext from .models_postprocessing import PostprocessingContext +from .models_curation import CurationContext class RecordingContext(BaseModel): @@ -93,7 +94,7 @@ class PipelineFullContext(BaseModel): run_postprocessing: bool = Field(default=True, description='Run postprocessing') postprocessing_context: PostprocessingContext = Field(default=PostprocessingContext(), description='Postprocessing context') run_curation: bool = Field(default=True, description='Run curation') - # curation_context: CurationContext = Field(default=CurationContext(), description='Curation context') + curation_context: CurationContext = Field(default=CurationContext(), description='Curation context') run_visualization: bool = Field(default=True, description='Run visualization') # visualization_context: VisualizationContext = Field(default=VisualizationContext(), description='Visualization context') diff --git a/common/models_curation.py b/common/models_curation.py new file mode 100644 index 0000000..2ca7e01 --- /dev/null +++ b/common/models_curation.py @@ -0,0 +1,11 @@ +from pydantic import BaseModel, Field + + +class CurationContext(BaseModel): + curation_query: str = Field( + default="isi_violations_ratio < 0.5 and amplitude_cutoff < 0.1 and presence_ratio > 0.8", + description=( + "Query to select units to keep after curation. " + "Default is 'isi_violations_ratio < 0.5 and amplitude_cutoff < 0.1 and presence_ratio > 0.8'." + ) + ) \ No newline at end of file diff --git a/common/processor_pipeline.py b/common/processor_pipeline.py index b1a4da5..3efa183 100644 --- a/common/processor_pipeline.py +++ b/common/processor_pipeline.py @@ -103,6 +103,7 @@ def run_pipeline(context: PipelineFullContext): # Curation params run_curation = context.run_curation + curation_params = context.curation_context.model_dump() # Visualization params run_visualization = context.run_visualization @@ -117,7 +118,7 @@ def run_pipeline(context: PipelineFullContext): preprocessing_params=preprocessing_params, spikesorting_params=spikesorting_params, postprocessing_params=postprocessing_params, - # curation_params=, + curation_params=curation_params, # visualization_params=, run_preprocessing=run_preprocessing, run_spikesorting=run_spikesorting, @@ -141,7 +142,7 @@ def run_pipeline(context: PipelineFullContext): create_sorting_out_nwb_file( nwbfile_original=nwbfile_rec, recording=recording_preprocessed, - sorting=sorting, + sorting=sorting_curated if run_curation else sorting, output_fname=output_fname )
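
For readers following the final patch: the curation_query added in common/models_curation.py is a pandas-style boolean expression evaluated against the quality-metrics table produced during postprocessing. Below is a minimal sketch of how a query of this form is typically applied in SpikeInterface. The curate_sorting helper is hypothetical and stands in for logic that lives inside si_pipeline.run_pipeline; compute_quality_metrics and select_units are existing SpikeInterface APIs.

    import spikeinterface.qualitymetrics as sqm

    def curate_sorting(sorting, waveform_extractor, curation_query):
        # Hypothetical helper, for illustration only -- the pipeline's real
        # curation step runs inside si_pipeline.run_pipeline.
        # Quality metrics come back as a pandas DataFrame indexed by unit id,
        # with columns such as isi_violations_ratio, amplitude_cutoff and
        # presence_ratio -- the same names used in the default query.
        metrics = sqm.compute_quality_metrics(waveform_extractor)
        # DataFrame.query evaluates the boolean expression row-wise; the index
        # of the surviving rows is the set of unit ids to keep.
        keep_unit_ids = metrics.query(curation_query).index.to_list()
        return sorting.select_units(unit_ids=keep_unit_ids)

Assuming this reading, a unit is kept only if it passes all three thresholds in the default query; sorting_curated in patch 5 would be the result of this kind of filtering, which is why create_sorting_out_nwb_file receives sorting_curated if run_curation else sorting.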