Skip to content

Commit

Permalink
Merge pull request #4 from catalystneuro/motion-correction-models
Browse files Browse the repository at this point in the history
motion correction models
  • Loading branch information
luiztauffer authored Jan 13, 2024
2 parents 3bdca68 + 6b0a1da commit 80a8024
Show file tree
Hide file tree
Showing 5 changed files with 886 additions and 13 deletions.
30 changes: 30 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,11 @@ SpikeInterface Apps for Dendro

## Dev

Create / update spec file:
```shell
dendro make-app-spec-file --app-dir . --spec-output-file spec.json
```

Build single App image:
```shell
DOCKER_BUILDKIT=1 docker build -t <tag-name> .
Expand All @@ -13,4 +18,29 @@ Examples:
```shell
DOCKER_BUILDKIT=1 docker build -t ghcr.io/catalystneuro/dendro_si_kilosort25:latest .
docker push ghcr.io/catalystneuro/dendro_si_kilosort25:latest
```

## Test locally

Set up a bash script similar to this:
```shell
#!/bin/bash

# Docker image
IMAGE="ghcr.io/catalystneuro/dendro_si_kilosort25"

# Command to be executed inside the container
ENTRYPOINT_CMD="dendro"
ARGS="test-app-processor --app-dir . --processor spikeinterface_pipeline_ks25 --context sample_context_1.yaml"


# Run the Docker container, with hot-reload to local code versions
docker run --gpus all \
-v $(pwd):/app \
-v /mnt/shared_storage/Github/dendro/python:/src/dendro/python \
-v /mnt/shared_storage/Github/spikeinterface_pipelines:/src/spikeinterface_pipelines \
-w /app \
--entrypoint "$ENTRYPOINT_CMD" \
$IMAGE \
$ARGS
```
132 changes: 128 additions & 4 deletions si_kilosort25/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,12 +46,136 @@ class HighpassSpatialFilter(BaseModel):
highpass_butter_wn: float = Field(default=0.01, description="Natural frequency for the Butterworth filter")


class MotionCorrection(BaseModel):
compute: bool = Field(default=True, description="Whether to compute motion correction")
apply: bool = Field(default=False, description="Whether to apply motion correction")
preset: str = Field(default="nonrigid_accurate", description="Preset for motion correction")
# ---------------------------------------------------------------
# Motion Correction Models
# ---------------------------------------------------------------
class MCDetectKwargs(BaseModel):
method: str = Field(default="locally_exclusive", description="")
peak_sign: str = Field(default="neg", description="")
detect_threshold: float = Field(default=8.0, description="")
exclude_sweep_ms: float = Field(default=0.1, description="")
radius_um: float = Field(default=50.0, description="")


class MCLocalizeCenterOfMass(BaseModel):
radius_um: float = Field(default=75.0, description="Radius in um for channel sparsity.")
feature: str = Field(default="ptp", description="'ptp', 'mean', 'energy' or 'peak_voltage'. Feature to consider for computation")


class MCLocalizeMonopolarTriangulation(BaseModel):
radius_um: float = Field(default=75.0, description="For channel sparsity.")
max_distance_um: float = Field(default=150.0, description="Boundary for distance estimation.")
optimizer: str = Field(default="minimize_with_log_penality", description="")
enforce_decrease: bool = Field(default=True, description="Enforce spatial decreasingness for PTP vectors")
feature: str = Field(default="ptp", description="'ptp', 'energy' or 'peak_voltage'. The available features to consider for estimating the position via monopolar triangulation are peak-to-peak amplitudes (ptp, default), energy ('energy', as L2 norm) or voltages at the center of the waveform (peak_voltage)")


class MCLocalizeGridConvolution(BaseModel):
radius_um: float = Field(default=40.0, description="Radius in um for channel sparsity.")
upsampling_um: float = Field(default=5.0, description="Upsampling resolution for the grid of templates.")
sigma_um: List[float] = Field(default=[5.0, 25.0, 5], description="Spatial decays of the fake templates.")
sigma_ms: float = Field(default=0.25, description="The temporal decay of the fake templates.")
margin_um: float = Field(default=30.0, description="The margin for the grid of fake templates.")
percentile: float = Field(default=10.0, description="The percentage in [0, 100] of the best scalar products kept to estimate the position.")
sparsity_threshold: float = Field(default=0.01, description="The sparsity threshold (in [0, 1]) below which weights should be considered as 0.")


class MCEstimateMotionDecentralized(BaseModel):
method: str = Field(default="decentralized", description="")
direction: str = Field(default="y", description="")
bin_duration_s: float = Field(default=2.0, description="")
rigid: bool = Field(default=False, description="")
bin_um: float = Field(default=5.0, description="")
margin_um: float = Field(default=0.0, description="")
win_shape: str = Field(default="gaussian", description="")
win_step_um: float = Field(default=100.0, description="")
win_sigma_um: float = Field(default=200.0, description="")
histogram_depth_smooth_um: float = Field(default=5.0, description="")
histogram_time_smooth_s: Union[float, None] = Field(default=None, description="")
pairwise_displacement_method: str = Field(default="conv", description="")
max_displacement_um: float = Field(default=100.0, description="")
weight_scale: str = Field(default="linear", description="")
error_sigma: float = Field(default=0.2, description="")
conv_engine: str = Field(
default="numpy",
description="",
json_schema_extra={'options': ["torch", "numpy"]},
)
torch_device: str = Field(default="", description="")
batch_size: int = Field(default=1, description="")
corr_threshold: float = Field(default=0.0, description="")
time_horizon_s: Union[float, None] = Field(default=None, description="")
convergence_method: str = Field(default="lsmr", description="")
soft_weights: bool = Field(default=False, description="")
normalized_xcorr: bool = Field(default=True, description="")
centered_xcorr: bool = Field(default=True, description="")
temporal_prior: bool = Field(default=True, description="")
spatial_prior: bool = Field(default=False, description="")
force_spatial_median_continuity: bool = Field(default=False, description="")
reference_displacement: str = Field(default="median", description="")
reference_displacement_time_s: float = Field(default=0, description="")
robust_regression_sigma: int = Field(default=2, description="")
weight_with_amplitude: bool = Field(default=False, description="")


class MCEstimateMotionIterativeTemplate(BaseModel):
bin_duration_s: float = Field(default=2.0, description="")
rigid: bool = Field(default=False, description="")
win_step_um: float = Field(default=50.0, description="")
win_sigma_um: float = Field(default=150.0, description="")
margin_um: float = Field(default=0.0, description="")
win_shape: str = Field(default="rect", description="")


class MCInterpolateMotionKwargs(BaseModel):
direction: int = Field(default=1, description="0 | 1 | 2. Dimension along which channel_locations are shifted (0 - x, 1 - y, 2 - z).")
border_mode: str = Field(default="remove_channels", description="'remove_channels' | 'force_extrapolate' | 'force_zeros'. Control how channels are handled on border.")
spatial_interpolation_method: str = Field(default="idw", description="The spatial interpolation method used to interpolate the channel locations.")
sigma_um: float = Field(default=20.0, description="Used in the 'kriging' formula")
p: int = Field(default=1, description="Used in the 'kriging' formula")
num_closest: int = Field(default=3, description="Number of closest channels used by 'idw' method for interpolation.")


class MCNonrigidAccurate(BaseModel):
detect_kwargs: MCDetectKwargs = Field(default=MCDetectKwargs(), description="")
localize_peaks_kwargs: MCLocalizeMonopolarTriangulation = Field(default=MCLocalizeMonopolarTriangulation(), description="")
estimate_motion_kwargs: MCEstimateMotionDecentralized = Field(default=MCEstimateMotionDecentralized(), description="")
interpolate_motion_kwargs: MCInterpolateMotionKwargs = Field(default=MCInterpolateMotionKwargs(), description="")


class MCRigidFast(BaseModel):
detect_kwargs: MCDetectKwargs = Field(default=MCDetectKwargs(), description="")
localize_peaks_kwargs: MCLocalizeCenterOfMass = Field(default=MCLocalizeCenterOfMass(), description="")
estimate_motion_kwargs: MCEstimateMotionDecentralized = Field(default=MCEstimateMotionDecentralized(bin_duration_s=10.0, rigid=True), description="")
interpolate_motion_kwargs: MCInterpolateMotionKwargs = Field(default=MCInterpolateMotionKwargs(), description="")


class MCKilosortLike(BaseModel):
detect_kwargs: MCDetectKwargs = Field(default=MCDetectKwargs(), description="")
localize_peaks_kwargs: MCLocalizeGridConvolution = Field(default=MCLocalizeGridConvolution(), description="")
estimate_motion_kwargs: MCEstimateMotionIterativeTemplate = Field(default=MCEstimateMotionIterativeTemplate(), description="")
interpolate_motion_kwargs: MCInterpolateMotionKwargs = Field(default=MCInterpolateMotionKwargs(border_mode="force_extrapolate", spatial_interpolation_method="kriging"), description="")


class MotionCorrection(BaseModel):
strategy: str = Field(
default="compute",
description="What strategy to use for motion correction",
json_schema_extra={'options': ["skip", "compute", "apply"]},
)
preset: str = Field(
default="nonrigid_accurate",
description="Preset for motion correction",
json_schema_extra={'options': ["nonrigid_accurate", "rigid_fast", "kilosort_like"]},
)
motion_kwargs_nonrigid_accurate: MCNonrigidAccurate = Field(default=MCNonrigidAccurate(), description="Motion correction parameters for nonrigid_accurate preset")
motion_kwargs_rigid_fast: MCRigidFast = Field(default=MCRigidFast(), description="Motion correction parameters for rigid_fast preset")
motion_kwargs_kilosort_like: MCKilosortLike = Field(default=MCKilosortLike(), description="Motion correction parameters for kilosort_like preset")


# ---------------------------------------------------------------
# Preprocessing Context
# ---------------------------------------------------------------
class PreprocessingContext(BaseModel):
preprocessing_strategy: str = Field(default="cmr", description="Strategy for preprocessing")
highpass_filter: HighpassFilter = Field(default=HighpassFilter(), description="Highpass filter")
Expand Down
12 changes: 11 additions & 1 deletion si_kilosort25/processor_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ def run(context: PipelineContext):

logger.info(recording)

# TODO - run pipeline
# Run pipeline
job_kwargs = {
'n_jobs': -1,
'chunk_duration': '1s',
Expand All @@ -55,6 +55,16 @@ def run(context: PipelineContext):

run_preprocessing = context.run_preprocessing
preprocessing_params = context.preprocessing_context.model_dump()
motion_correction_preset = preprocessing_params['motion_correction']['preset']
nonrigid_accurate_kwargs = preprocessing_params['motion_correction'].pop('motion_kwargs_nonrigid_accurate')
rigid_fast_kwargs = preprocessing_params['motion_correction'].pop('motion_kwargs_rigid_fast')
kilosort_like_kwargs = preprocessing_params['motion_correction'].pop('motion_kwargs_kilosort_like')
if motion_correction_preset == 'nonrigid_accurate':
preprocessing_params['motion_correction']['motion_kwargs'] = nonrigid_accurate_kwargs
elif motion_correction_preset == 'rigid_fast':
preprocessing_params['motion_correction']['motion_kwargs'] = rigid_fast_kwargs
elif motion_correction_preset == 'kilosort_like':
preprocessing_params['motion_correction']['motion_kwargs'] = kilosort_like_kwargs

run_spikesorting = context.run_spikesorting
spikesorting_params = context.spikesorting_context.model_dump()
Expand Down
5 changes: 4 additions & 1 deletion si_kilosort25/sample_context_1.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,10 @@ lazy_read_input: true
stub_test: true
recording_context:
electrical_series_path: /acquisition/ElectricalSeriesRaw
run_preprocessing: false
run_preprocessing: true
preprocessing_context:
motion_correction:
preset: nonrigid_accurate
run_spikesorting: true
spikesorting_context:
do_correction: false
Expand Down
Loading

0 comments on commit 80a8024

Please sign in to comment.