dev #1

Merged · 18 commits · Jan 5, 2024
63 changes: 63 additions & 0 deletions .github/workflows/build_and_push.yaml
@@ -0,0 +1,63 @@
name: Build and Push Docker Images

on:
  push:
    branches:
      - dev
    paths:
      - '*/**'

jobs:
  build:
    runs-on: ubuntu-latest
    strategy:
      fail-fast: false
      matrix:
        app_dir:
          - si_kilosort25
          - si_kilosort3

    steps:
      - name: Checkout repository
        uses: actions/checkout@v3

      - name: Set up Docker Builder
        uses: docker/setup-buildx-action@v2

      - name: Log in to GitHub Container Registry
        uses: docker/login-action@v1
        with:
          registry: ghcr.io
          username: ${{ github.actor }}
          password: ${{ secrets.GITHUB_TOKEN }}
          # username: YOUR_GITHUB_USERNAME # Replace with your GitHub username
          # password: ${{ secrets.GHCR_PAT }} # Use the PAT secret

      - name: Pip install dendro
        run: |
          python -m pip install --upgrade pip
          git clone https://github.com/flatironinstitute/dendro.git
          cd dendro/python
          git pull
          pip install .

      - name: Generate Spec File and Build Docker Image
        # if: contains(github.event.head_commit.modified, matrix.app_dir)
        run: |
          cd ${{ matrix.app_dir }}
          dendro make-app-spec-file --app-dir . --spec-output-file spec.json
          docker build -t ghcr.io/catalystneuro/dendro_${{ matrix.app_dir }}:latest .
          docker push ghcr.io/catalystneuro/dendro_${{ matrix.app_dir }}:latest

      - name: Commit files
        run: |
          git config --local user.email "github-actions[bot]@users.noreply.github.com"
          git config --local user.name "github-actions[bot]"
          GIT_STATUS=$(git status -s)
          [[ ! -z "$GIT_STATUS" ]] && git add ${{ matrix.app_dir }}/spec.json && git commit -m "update spec.json" -a || echo "No changes to commit"

      - name: Push changes
        uses: ad-m/github-push-action@master
        with:
          github_token: ${{ secrets.GITHUB_TOKEN }}
          branch: ${{ github.ref }}
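For a local dry run of what this workflow does per matrix entry, the same steps can be replayed by hand; a sketch, assuming Docker is installed and you have push access to `ghcr.io/catalystneuro` (shown for `si_kilosort25`):

```shell
# Install dendro from source, as the workflow does
git clone https://github.com/flatironinstitute/dendro.git
pip install ./dendro/python

# Regenerate the app spec, then build and push the image for one app
cd si_kilosort25
dendro make-app-spec-file --app-dir . --spec-output-file spec.json
docker build -t ghcr.io/catalystneuro/dendro_si_kilosort25:latest .
docker push ghcr.io/catalystneuro/dendro_si_kilosort25:latest
```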
6 changes: 6 additions & 0 deletions README.md
@@ -7,4 +7,10 @@ SpikeInterface Apps for Dendro
Build single App image:
```shell
DOCKER_BUILDKIT=1 docker build -t <tag-name> .
```

Examples:
```shell
DOCKER_BUILDKIT=1 docker build -t ghcr.io/catalystneuro/dendro_si_kilosort25:latest .
docker push ghcr.io/catalystneuro/dendro_si_kilosort25:latest
```
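Once pushed, the image can be pulled anywhere with read access to the registry (the tag must match what was pushed):

```shell
docker pull ghcr.io/catalystneuro/dendro_si_kilosort25:latest
```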
3 changes: 3 additions & 0 deletions si_kilosort25/.gitignore
@@ -0,0 +1,3 @@
results/
scratch/
output/
12 changes: 12 additions & 0 deletions si_kilosort25/Dockerfile
@@ -6,6 +6,18 @@ RUN git clone https://github.com/flatironinstitute/dendro.git && \
    cd dendro/python && \
    pip install -e .

# Install spikeinterface-pipelines from source, for now
RUN git clone https://github.com/SpikeInterface/spikeinterface_pipelines.git && \
    cd spikeinterface_pipelines && \
    git checkout dev && \
    pip install -e .

# Install spikeinterface from source, for now
RUN git clone https://github.com/SpikeInterface/spikeinterface.git && \
    cd spikeinterface && \
    # git checkout dev && \
    pip install -e .[full]

# Copy files into the container
WORKDIR /app
COPY *.py ./
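The source installs above are explicitly temporary ("for now"). If matching releases become available, the three RUN blocks could presumably collapse into a single pip install; a hypothetical sketch, assuming the packages are published on PyPI under these names:

```dockerfile
# Hypothetical replacement once pinned releases exist on PyPI:
RUN pip install dendro spikeinterface-pipelines "spikeinterface[full]"
```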
4 changes: 3 additions & 1 deletion si_kilosort25/main.py
@@ -1,3 +1,5 @@
#!/usr/bin/env python3

from dendro.sdk import App

from processor_pipeline import PipelineProcessor
@@ -8,7 +10,7 @@
app = App(
    name=app_name,
    description="Spike Interface Pipeline - Kilosort2.5",
    app_image=f"ghcr.io/catalystneuro/{app_name}",
    app_image=f"ghcr.io/catalystneuro/dendro_{app_name}",
    app_executable="/app/main.py"
)

31 changes: 22 additions & 9 deletions si_kilosort25/models.py
@@ -46,13 +46,20 @@ class HighpassSpatialFilter(BaseModel):
    highpass_butter_wn: float = Field(default=0.01, description="Natural frequency for the Butterworth filter")


class MotionCorrection(BaseModel):
    compute: bool = Field(default=True, description="Whether to compute motion correction")
    apply: bool = Field(default=False, description="Whether to apply motion correction")
    preset: str = Field(default="nonrigid_accurate", description="Preset for motion correction")


class PreprocessingContext(BaseModel):
    preprocessing_strategy: str = Field(default="cmr", description="Strategy for preprocessing")
    highpass_filter: HighpassFilter = Field(default=HighpassFilter(), description="Highpass filter")
    phase_shift: PhaseShift = Field(default=PhaseShift(), description="Phase shift")
    detect_bad_channels: DetectBadChannels = Field(default=DetectBadChannels(), description="Detect bad channels")
    common_reference: CommonReference = Field(default=CommonReference(), description="Common reference")
    highpass_spatial_filter: HighpassSpatialFilter = Field(default=HighpassSpatialFilter(), description="Highpass spatial filter")
    preprocessing_strategy: str = Field(default="cmr", description="Strategy for preprocessing")
    motion_correction: MotionCorrection = Field(default=MotionCorrection(), description="Motion correction")
    remove_out_channels: bool = Field(default=False, description="Flag to remove out channels")
    remove_bad_channels: bool = Field(default=False, description="Flag to remove bad channels")
    max_bad_channel_fraction_to_remove: float = Field(default=1.1, description="Maximum fraction of bad channels to remove")
@@ -72,19 +79,20 @@ class Kilosort25SortingContext(BaseModel):
    sig: float = Field(default=20, description="spatial smoothness constant for registration")
    freq_min: float = Field(default=150, description="High-pass filter cutoff frequency")
    sigmaMask: float = Field(default=30, description="Spatial constant in um for computing residual variance of spike")
    lam: float = Field(default=10.0, description="The importance of the amplitude penalty (like in Kilosort1: 0 means not used, 10 is average, 50 is a lot)")
    nPCs: int = Field(default=3, description="Number of PCA dimensions")
    ntbuff: int = Field(default=64, description="Samples of symmetrical buffer for whitening and spike detection")
    nfilt_factor: int = Field(default=4, description="Max number of clusters per good channel (even temporary ones) 4")
    NT: int = Field(default=-1, description='Batch size (if -1 it is automatically computed)')
    # NT: int = Field(default=-1, description='Batch size (if -1 it is automatically computed)')
    AUCsplit: float = Field(default=0.9, description="Threshold on the area under the curve (AUC) criterion for performing a split in the final step")
    do_correction: bool = Field(default=True, description="If True drift registration is applied")
    wave_length: float = Field(default=61, description="Size of the waveform extracted around each detected peak (default 61, maximum 81)")
    keep_good_only: bool = Field(default=True, description="If True only 'good' units are returned")
    keep_good_only: bool = Field(default=False, description="If True only 'good' units are returned")
    skip_kilosort_preprocessing: bool = Field(default=False, description="Optionally skip the internal Kilosort preprocessing")
    scaleproc: int = Field(default=-1, description="int16 scaling of whitened data, if -1 set to 200.")
    # scaleproc: int = Field(default=-1, description="int16 scaling of whitened data, if -1 set to 200.")


class SortingContext(BaseModel):
class SpikeSortingContext(BaseModel):
    sorter_name: str = Field(default="kilosort2_5", description="Name of the sorter to use.")
    sorter_kwargs: Kilosort25SortingContext = Field(default=Kilosort25SortingContext(), description="Sorter specific kwargs.")

@@ -109,11 +117,16 @@ class CurationContext(BaseModel):
class PipelineContext(BaseModel):
    input: InputFile = Field(description='Input NWB file')
    output: OutputFile = Field(description='Output NWB file')
    lazy_read_input: bool = Field(default=True, description='Lazy read input file')
    stub_test: bool = Field(default=False, description='Stub test')
    recording_context: RecordingContext = Field(description='Recording context')
    preprocessing_context: PreprocessingContext = Field(description='Preprocessing context')
    sorting_context: SortingContext = Field(description='Sorting context')
    postprocessing_context: PostprocessingContext = Field(description='Postprocessing context')
    curation_context: CurationContext = Field(description='Curation context')
    run_preprocessing: bool = Field(default=True, description='Run preprocessing')
    preprocessing_context: PreprocessingContext = Field(default=PreprocessingContext(), description='Preprocessing context')
    run_spikesorting: bool = Field(default=True, description='Run spike sorting')
    spikesorting_context: SpikeSortingContext = Field(default=SpikeSortingContext(), description='Sorting context')
    run_postprocessing: bool = Field(default=True, description='Run postprocessing')
    # postprocessing_context: PostprocessingContext = Field(default=PostprocessingContext(), description='Postprocessing context')
    # curation_context: CurationContext = Field(default=CurationContext(), description='Curation context')


# # ------------------------------
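To make the effect of the new defaults and run flags concrete, here is a minimal sketch (assuming these Pydantic models are importable from `models`; the dendro `InputFile`/`OutputFile` handles are supplied by the platform at run time and omitted here):

```python
from models import PreprocessingContext, SpikeSortingContext

# Every nested context now carries a default, so bare construction is valid.
pre = PreprocessingContext()
sorting = SpikeSortingContext()

print(pre.preprocessing_strategy)            # "cmr"
print(pre.motion_correction.compute)         # True: motion is estimated...
print(pre.motion_correction.apply)           # False: ...but not applied by default
print(sorting.sorter_name)                   # "kilosort2_5"
print(sorting.sorter_kwargs.keep_good_only)  # False (changed from True in this PR)
```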
91 changes: 62 additions & 29 deletions si_kilosort25/nwb_utils.py
@@ -1,16 +1,23 @@
from typing import Union, List
from neuroconv.tools.spikeinterface import write_sorting
# from neuroconv.tools.spikeinterface import write_sorting
from pynwb import NWBFile
from pynwb.file import Subject
from uuid import uuid4
import numpy as np
import h5py
import spikeinterface as si
import pynwb


class NwbRecording(si.BaseRecording):
    def __init__(self, file: h5py.File, electrical_series_path: str) -> None:
        electrical_series: h5py.Group = file[electrical_series_path]
    def __init__(
        self,
        file,  # file-like object
        electrical_series_path: str
    ) -> None:
        h5_file = h5py.File(file, 'r')

        electrical_series: h5py.Group = h5_file[electrical_series_path]
        electrical_series_data = electrical_series['data']
        dtype = electrical_series_data.dtype

@@ -24,26 +31,30 @@ def __init__(self, file: h5py.File, electrical_series_path: str) -> None:

        # Get channel ids
        electrode_indices = electrical_series['electrodes'][:]
        electrodes_table = file['/general/extracellular_ephys/electrodes']
        electrodes_table = h5_file['/general/extracellular_ephys/electrodes']
        channel_ids = [electrodes_table['id'][i] for i in electrode_indices]

        si.BaseRecording.__init__(self, channel_ids=channel_ids, sampling_frequency=sampling_frequency, dtype=dtype)
        super().__init__(
            channel_ids=channel_ids,
            sampling_frequency=sampling_frequency,
            dtype=dtype
        )

        # Set electrode locations
        if 'x' in electrodes_table:
            channel_loc_x = [electrodes_table['x'][i] for i in electrode_indices]
            channel_loc_y = [electrodes_table['y'][i] for i in electrode_indices]
            if 'z' in electrodes_table:
                channel_loc_z = [electrodes_table['z'][i] for i in electrode_indices]
            else:
                channel_loc_z = None
        elif 'rel_x' in electrodes_table:
        if 'rel_x' in electrodes_table:
            channel_loc_x = [electrodes_table['rel_x'][i] for i in electrode_indices]
            channel_loc_y = [electrodes_table['rel_y'][i] for i in electrode_indices]
            if 'rel_z' in electrodes_table:
                channel_loc_z = [electrodes_table['rel_z'][i] for i in electrode_indices]
            else:
                channel_loc_z = None
        elif 'x' in electrodes_table:
            channel_loc_x = [electrodes_table['x'][i] for i in electrode_indices]
            channel_loc_y = [electrodes_table['y'][i] for i in electrode_indices]
            if 'z' in electrodes_table:
                channel_loc_z = [electrodes_table['z'][i] for i in electrode_indices]
            else:
                channel_loc_z = None
        else:
            channel_loc_x = None
            channel_loc_y = None
@@ -58,6 +69,18 @@ def __init__(self, file: h5py.File, electrical_series_path: str) -> None:
                locations[i, 2] = channel_loc_z[electrode_index]
        self.set_dummy_probe_from_locations(locations)

        # Extractor channel groups must be integers, but NWB electrodes group_name can be strings
        if "group_name" in electrodes_table:
            unique_electrode_group_names = list(np.unique(electrodes_table["group_name"][:]))
            print(unique_electrode_group_names)

            groups = []
            for electrode_index in electrode_indices:
                group_name = electrodes_table["group_name"][electrode_index]
                group_id = unique_electrode_group_names.index(group_name)
                groups.append(group_id)
            self.set_channel_groups(groups)

        recording_segment = NwbRecordingSegment(
            electrical_series_data=electrical_series_data,
            sampling_frequency=sampling_frequency
@@ -68,7 +91,7 @@ def __init__(self, file: h5py.File, electrical_series_path: str) -> None:
class NwbRecordingSegment(si.BaseRecordingSegment):
    def __init__(self, electrical_series_data: h5py.Dataset, sampling_frequency: float) -> None:
        self._electrical_series_data = electrical_series_data
        si.BaseRecordingSegment.__init__(self, sampling_frequency=sampling_frequency)
        super().__init__(sampling_frequency=sampling_frequency)

    def get_num_samples(self) -> int:
        return self._electrical_series_data.shape[0]
@@ -80,7 +103,7 @@ def get_traces(self, start_frame: int, end_frame: int, channel_indices: Union[Li
        return self._electrical_series_data[start_frame:end_frame, channel_indices]


def create_sorting_out_nwb_file(nwbfile_original: NWBFile, sorting: si.BaseSorting, sorting_out_fname: str):
def create_sorting_out_nwb_file(nwbfile_original, sorting: si.BaseSorting, sorting_out_fname: str):
    nwbfile = NWBFile(
        session_description=nwbfile_original.session_description + " - spike sorting results.",
        identifier=str(uuid4()),
@@ -91,19 +114,29 @@ def create_sorting_out_nwb_file(nwbfile_original, sorting: si.BaseSorti
        institution=nwbfile_original.institution,
        experiment_description=nwbfile_original.experiment_description,
        related_publications=nwbfile_original.related_publications,
        subject=Subject(
            subject_id=nwbfile_original.subject.subject_id,
            age=nwbfile_original.subject.age,
            description=nwbfile_original.subject.description,
            species=nwbfile_original.subject.species,
            sex=nwbfile_original.subject.sex,
        )
    )
    subject = Subject(
        subject_id=nwbfile_original.subject.subject_id,
        age=nwbfile_original.subject.age,
        description=nwbfile_original.subject.description,
        species=nwbfile_original.subject.species,
        sex=nwbfile_original.subject.sex,
    )
    nwbfile.subject = subject

    write_sorting(
        sorting=sorting,
        nwbfile=nwbfile,
        nwbfile_path=sorting_out_fname,
        overwrite=True
    )
    for ii, unit_id in enumerate(sorting.get_unit_ids()):
        st = sorting.get_unit_spike_train(unit_id) / sorting.get_sampling_frequency()
        nwbfile.add_unit(
            id=ii + 1,  # must be an int
            spike_times=st
        )

    # Write the nwb file
    with pynwb.NWBHDF5IO(path=sorting_out_fname, mode='w') as io:
        io.write(container=nwbfile, cache_spec=True)

    # write_sorting(
    #     sorting=sorting,
    #     nwbfile=nwbfile,
    #     nwbfile_path=sorting_out_fname,
    #     overwrite=True
    # )
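A minimal end-to-end usage sketch for these helpers (the file name and electrical series path are hypothetical, and the toy sorting stands in for the Kilosort output the pipeline would normally produce; the input NWB file is assumed to carry Subject metadata, which `create_sorting_out_nwb_file` copies):

```python
import pynwb
import spikeinterface as si

from nwb_utils import NwbRecording, create_sorting_out_nwb_file

# Wrap the raw traces of a local NWB file as a SpikeInterface recording.
recording = NwbRecording(
    file="session.nwb",  # hypothetical file; h5py also accepts file-like objects
    electrical_series_path="/acquisition/ElectricalSeries",  # hypothetical path
)
print(recording.get_num_channels(), recording.get_sampling_frequency())

# Stand-in sorting; the real pipeline gets this from Kilosort 2.5.
sorting = si.generate_sorting(
    num_units=3,
    sampling_frequency=recording.get_sampling_frequency(),
    durations=[recording.get_total_duration()],
)

with pynwb.NWBHDF5IO("session.nwb", mode="r") as io:
    nwbfile_original = io.read()
    create_sorting_out_nwb_file(
        nwbfile_original=nwbfile_original,
        sorting=sorting,
        sorting_out_fname="session_sorting.nwb",
    )
```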