Merge pull request #6 from catalystneuro/preprocessing-app
preprocessing app
luiztauffer authored Feb 6, 2024
2 parents c2f9886 + 255841d commit b1fd312
Showing 8 changed files with 121 additions and 32 deletions.
25 changes: 18 additions & 7 deletions common/models.py
@@ -4,10 +4,15 @@
 from .models_preprocessing import PreprocessingContext
 from .models_postprocessing import PostprocessingContext
+from .models_curation import CurationContext


 class RecordingContext(BaseModel):
     electrical_series_path: str = Field(description='Path to the electrical series in the NWB file')
+    lazy_read_input: bool = Field(default=True, description='Lazy read input file')
+    write_recording: bool = Field(default=False, description='Write recording')
+    stub_test: bool = Field(default=False, description='Stub test')
+    stub_test_duration_sec: float = Field(default=300, description='Stub test duration in seconds')


 class JobKwargs(BaseModel):
@@ -66,13 +71,17 @@ class Kilosort3SortingContext(BaseModel):
 # ------------------------------
 # Pipeline Models
 # ------------------------------
-class PipelineContext(BaseModel):
+class PipelinePreprocessingContext(BaseModel):
+    input: InputFile = Field(description='Input NWB file')
+    output: OutputFile = Field(description='Output NWB file')
+    job_kwargs: JobKwargs = Field(default=JobKwargs(), description='Job kwargs')
+    recording_context: RecordingContext = Field(description='Recording context')
+    preprocessing_context: PreprocessingContext = Field(default=PreprocessingContext())
+
+
+class PipelineFullContext(BaseModel):
     input: InputFile = Field(description='Input NWB file')
     output: OutputFile = Field(description='Output NWB file')
-    lazy_read_input: bool = Field(default=True, description='Lazy read input file')
-    write_recording: bool = Field(default=False, description='Write recording')
-    stub_test: bool = Field(default=False, description='Stub test')
-    stub_test_duration_sec: float = Field(default=300, description='Stub test duration in seconds')
     job_kwargs: JobKwargs = Field(default=JobKwargs(), description='Job kwargs')
     recording_context: RecordingContext = Field(description='Recording context')
     run_preprocessing: bool = Field(default=True, description='Run preprocessing')
@@ -84,8 +93,10 @@ class PipelineContext(BaseModel):
     ] = Field(description='Sorting context')
     run_postprocessing: bool = Field(default=True, description='Run postprocessing')
     postprocessing_context: PostprocessingContext = Field(default=PostprocessingContext(), description='Postprocessing context')
-    # curation_context: CurationContext = Field(default=CurationContext(), description='Curation context')
-
+    run_curation: bool = Field(default=True, description='Run curation')
+    curation_context: CurationContext = Field(default=CurationContext(), description='Curation context')
+    run_visualization: bool = Field(default=True, description='Run visualization')
     # visualization_context: VisualizationContext = Field(default=VisualizationContext(), description='Visualization context')
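A minimal sketch of how the reorganized models are meant to be used: the lazy-read, write-recording, and stub-test options now live on RecordingContext rather than on the pipeline context itself. Field names come from the diff above; the path and values are hypothetical examples, not from this PR.

```python
# Hedged sketch, assuming common.models is importable with its dependencies.
from common.models import RecordingContext

rec_ctx = RecordingContext(
    electrical_series_path="acquisition/ElectricalSeries",  # example path, not from this PR
    stub_test=True,               # process only a short segment
    stub_test_duration_sec=60.0,  # instead of the 300 s default
)
print(rec_ctx.model_dump())
```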
11 changes: 11 additions & 0 deletions common/models_curation.py
@@ -0,0 +1,11 @@
+from pydantic import BaseModel, Field
+
+
+class CurationContext(BaseModel):
+    curation_query: str = Field(
+        default="isi_violations_ratio < 0.5 and amplitude_cutoff < 0.1 and presence_ratio > 0.8",
+        description=(
+            "Query to select units to keep after curation. "
+            "Default is 'isi_violations_ratio < 0.5 and amplitude_cutoff < 0.1 and presence_ratio > 0.8'."
+        )
+    )
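The query string uses pandas DataFrame.query syntax over the unit quality-metrics table computed during postprocessing; the filtering itself is delegated to si_pipeline.run_pipeline, not done in this file. A standalone sketch of the idea, with made-up metric values:

```python
# Illustration only: apply the default curation query to a toy metrics table.
import pandas as pd

metrics = pd.DataFrame(
    {
        "isi_violations_ratio": [0.1, 0.9],
        "amplitude_cutoff": [0.05, 0.2],
        "presence_ratio": [0.95, 0.5],
    },
    index=["unit_0", "unit_1"],
)

query = "isi_violations_ratio < 0.5 and amplitude_cutoff < 0.1 and presence_ratio > 0.8"
units_to_keep = metrics.query(query).index.tolist()
print(units_to_keep)  # ['unit_0']; unit_1 fails all three criteria
```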
1 change: 1 addition & 0 deletions common/models_preprocessing.py
@@ -163,6 +163,7 @@ class MotionCorrection(BaseModel):


 class PreprocessingContext(BaseModel):
+    add_preprocessed_to_output_nwb: bool = Field(default=False, description="Whether to add preprocessed data to output NWB file or not")
     preprocessing_strategy: str = Field(default="cmr", description="Strategy for preprocessing")
     highpass_filter: HighpassFilter = Field(default=HighpassFilter(), description="Highpass filter")
     phase_shift: PhaseShift = Field(default=PhaseShift(), description="Phase shift")
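A quick sketch of toggling the new flag; the field names and defaults come from the diff above, everything else is left at its default.

```python
# Hedged sketch: turn on writing of preprocessed traces to the output NWB
# file; all other preprocessing options keep their defaults.
from common.models_preprocessing import PreprocessingContext

pp_ctx = PreprocessingContext(add_preprocessed_to_output_nwb=True)
print(pp_ctx.preprocessing_strategy)  # -> "cmr" (the default strategy)
```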
25 changes: 16 additions & 9 deletions common/nwb_utils.py
@@ -1,4 +1,4 @@
-from typing import Union, List
+from typing import Union, List, Optional
 # from neuroconv.tools.spikeinterface import write_sorting
 from pynwb import NWBFile
 from pynwb.file import Subject
@@ -103,7 +103,12 @@ def get_traces(self, start_frame: int, end_frame: int, channel_indices: Union[Li
         return self._electrical_series_data[start_frame:end_frame, channel_indices]


-def create_sorting_out_nwb_file(nwbfile_original, sorting: si.BaseSorting, sorting_out_fname: str):
+def create_sorting_out_nwb_file(
+    nwbfile_original,
+    recording: Optional[si.BaseRecording] = None,
+    sorting: Optional[si.BaseSorting] = None,
+    output_fname: Optional[str] = None
+):
     nwbfile = NWBFile(
         session_description=nwbfile_original.session_description + " - spike sorting results.",
         identifier=str(uuid4()),
@@ -123,15 +128,17 @@
         )
     )

-    for ii, unit_id in enumerate(sorting.get_unit_ids()):
-        st = sorting.get_unit_spike_train(unit_id) / sorting.get_sampling_frequency()
-        nwbfile.add_unit(
-            id=ii + 1,  # must be an int
-            spike_times=st
-        )
+    # Add sorting
+    if sorting is not None:
+        for ii, unit_id in enumerate(sorting.get_unit_ids()):
+            st = sorting.get_unit_spike_train(unit_id) / sorting.get_sampling_frequency()
+            nwbfile.add_unit(
+                id=ii + 1,  # must be an int
+                spike_times=st
+            )

     # Write the nwb file
-    with pynwb.NWBHDF5IO(path=sorting_out_fname, mode='w') as io:
+    with pynwb.NWBHDF5IO(path=output_fname, mode='w') as io:
         io.write(container=nwbfile, cache_spec=True)

     # write_sorting(
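With the new signature, both recording and sorting are optional, so a preprocessing-only run can still produce an output NWB file. A hedged usage sketch follows; only the function and parameter names come from this diff, the source-file construction is a hypothetical minimum, and the parts of the function hidden behind the collapsed hunks may require extra metadata (e.g., a Subject) on the original file.

```python
# Hedged sketch of calling the new signature with no units to write.
from datetime import datetime, timezone
from uuid import uuid4
from pynwb import NWBFile
from pynwb.file import Subject
from common.nwb_utils import create_sorting_out_nwb_file

nwbfile_in = NWBFile(
    session_description="demo session",
    identifier=str(uuid4()),
    session_start_time=datetime.now(timezone.utc),
    subject=Subject(subject_id="demo", species="Mus musculus", sex="U", age="P90D"),
)

create_sorting_out_nwb_file(
    nwbfile_original=nwbfile_in,
    recording=None,       # no preprocessed traces in this sketch
    sorting=None,         # no units either; only session metadata is copied
    output_fname="output.nwb",
)
```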
41 changes: 27 additions & 14 deletions common/processor_pipeline.py
@@ -6,7 +6,7 @@
 import h5py
 import logging

-from .models import PipelineContext
+from .models import PipelineFullContext
 from .nwb_utils import create_sorting_out_nwb_file
@@ -16,7 +16,7 @@
 os.environ["OPENBLAS_NUM_THREADS"] = "1"


-def run_pipeline(context: PipelineContext):
+def run_pipeline(context: PipelineFullContext):
     """
     Runs the spikeinterface pipeline.
@@ -31,7 +31,7 @@ def run_pipeline(context: PipelineContext):

     # Create SI recording from InputFile
     logger.info('Opening remote input file')
-    download = not context.lazy_read_input
+    download = not context.recording_context.lazy_read_input
     ff = context.input.get_file(download=download)

     logger.info('Creating input recording')
@@ -42,15 +42,15 @@
         # stream_mode="remfile"
     )

-    if context.stub_test:
+    if context.recording_context.stub_test:
         logger.info('Running in stub test mode')
-        stub_test_num_frames = context.stub_test_duration_sec * recording.get_sampling_frequency()
+        stub_test_num_frames = context.recording_context.stub_test_duration_sec * recording.get_sampling_frequency()
         n_frames = int(min(stub_test_num_frames, recording.get_num_frames()))
         recording = recording.frame_slice(start_frame=0, end_frame=n_frames)

     logger.info(recording)

-    if context.write_recording:
+    if context.recording_context.write_recording:
         logger.info('Writing binary recording')
         recording = recording.save(folder=scratch_folder / "recording")
@@ -101,19 +101,30 @@ def run_pipeline(context: PipelineContext):
         qm_list.append('nn_noise_overlap')
     postprocessing_params['quality_metrics']['metric_names'] = qm_list

+    # Curation params
+    run_curation = context.run_curation
+    curation_params = context.curation_context.model_dump()
+
+    # Visualization params
+    run_visualization = context.run_visualization
+
     # Run pipeline
     logger.info('Running pipeline')
-    _, sorting, _ = si_pipeline.run_pipeline(
+    recording_preprocessed, sorting, waveform_extractor, sorting_curated, visualization_output = si_pipeline.run_pipeline(
         recording=recording,
         scratch_folder="./scratch/",
         results_folder="./results/",
         job_kwargs=job_kwargs,
-        run_preprocessing=run_preprocessing,
         preprocessing_params=preprocessing_params,
-        run_spikesorting=run_spikesorting,
         spikesorting_params=spikesorting_params,
-        run_postprocessing=run_postprocessing,
         postprocessing_params=postprocessing_params,
+        curation_params=curation_params,
+        # visualization_params=,
+        run_preprocessing=run_preprocessing,
+        run_spikesorting=run_spikesorting,
+        run_postprocessing=run_postprocessing,
+        run_curation=run_curation,
+        run_visualization=run_visualization
     )

     # Upload output file
@@ -125,13 +136,15 @@

     if not os.path.exists('output'):
         os.mkdir('output')
-    sorting_out_fname = 'output/sorting.nwb'
+
+    output_fname = 'output/output.nwb'

     create_sorting_out_nwb_file(
         nwbfile_original=nwbfile_rec,
-        sorting=sorting,
-        sorting_out_fname=sorting_out_fname
+        recording=recording_preprocessed,
+        sorting=sorting_curated if run_curation else sorting,
+        output_fname=output_fname
     )

     logger.info('Uploading output NWB file')
-    context.output.upload(sorting_out_fname)
+    context.output.upload(output_fname)
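One step in this file worth spelling out is the stub-test slice near the top of run_pipeline: the requested duration is converted to frames and capped at the recording length. A self-contained arithmetic sketch, with hypothetical numbers:

```python
# Worked example of the stub-test frame math (numbers are hypothetical).
sampling_frequency = 30_000.0        # Hz
stub_test_duration_sec = 300         # the model's default
num_frames_in_recording = 6_000_000  # i.e. a 200-second recording

stub_test_num_frames = stub_test_duration_sec * sampling_frequency  # 9_000_000
n_frames = int(min(stub_test_num_frames, num_frames_in_recording))
print(n_frames)  # 6_000_000: the shorter recording wins
```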
4 changes: 2 additions & 2 deletions si_kilosort25/main.py
@@ -4,7 +4,7 @@
 from pydantic import Field
 from common.models import (
     Kilosort25SortingContext,
-    PipelineContext as CommonPipelineContext
+    PipelineFullContext
 )
 from common.processor_pipeline import run_pipeline
@@ -20,7 +20,7 @@


 # We need to overwrite this with the specific sorter, to generate the correct forms
-class PipelineContext(CommonPipelineContext):
+class PipelineContext(PipelineFullContext):
     spikesorting_context: Kilosort25SortingContext = Field(default=Kilosort25SortingContext())
Empty file added si_preprocessing/Dockerfile
46 changes: 46 additions & 0 deletions si_preprocessing/main.py
@@ -0,0 +1,46 @@
+#!/usr/bin/env python3
+
+from dendro.sdk import App, ProcessorBase
+from common.models import PipelineFullContext, PipelinePreprocessingContext
+from common.processor_pipeline import run_pipeline
+
+
+app_name = 'si_preprocessing'
+
+app = App(
+    name=app_name,
+    description="Spike Interface Pipeline - Preprocessing",
+    app_image=f"ghcr.io/catalystneuro/dendro_{app_name}",
+    app_executable="/app/main.py"
+)
+
+
+class PipelineProcessor(ProcessorBase):
+    name = 'spikeinterface_pipeline_preprocessing'
+    label = 'SpikeInterface Pipeline - Preprocessing'
+    description = "SpikeInterface Pipeline Processor for Preprocessing tasks"
+    tags = ['spike_interface', 'preprocessing', 'electrophysiology', 'pipeline']
+    attributes = {
+        'wip': True
+    }
+
+    @staticmethod
+    def run(context: PipelinePreprocessingContext):
+        context_preprocessing = context.model_dump()
+        context_preprocessing['preprocessing_context']['add_preprocessed_to_output_nwb'] = True
+        context_full = PipelineFullContext(
+            run_preprocessing=True,
+            run_spikesorting=False,
+            run_postprocessing=False,
+            run_curation=False,
+            run_visualization=False,
+            **context_preprocessing
+        )
+        run_pipeline(context_full)
+
+
+app.add_processor(PipelineProcessor)
+
+
+if __name__ == '__main__':
+    app.run()
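The run method above uses a dump-override-rebuild pattern: serialize the narrow preprocessing context, force the nested add_preprocessed_to_output_nwb flag, then construct the full-pipeline context with every later stage disabled. A self-contained pydantic sketch of the same pattern, with hypothetical model names standing in for the real contexts:

```python
# Minimal sketch of the dump-override-rebuild pattern (hypothetical models;
# model_dump assumes pydantic v2, which this PR's code also relies on).
from pydantic import BaseModel

class Inner(BaseModel):
    add_preprocessed_to_output_nwb: bool = False

class Narrow(BaseModel):
    preprocessing_context: Inner = Inner()

class Wide(Narrow):
    run_spikesorting: bool = True

data = Narrow().model_dump()                                   # nested plain dict
data['preprocessing_context']['add_preprocessed_to_output_nwb'] = True
wide = Wide(run_spikesorting=False, **data)                    # rebuild the wider model
assert wide.preprocessing_context.add_preprocessed_to_output_nwb is True
```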
