Skip to content

Commit

Permalink
Add a CLI to help parallelize experiments.
Browse files Browse the repository at this point in the history
  • Loading branch information
cwindolf committed Jan 8, 2025
1 parent 68dca10 commit 4824027
Show file tree
Hide file tree
Showing 10 changed files with 220 additions and 201 deletions.
1 change: 0 additions & 1 deletion environment.yml
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@ channels:
- conda-forge
dependencies:
- python=3.11
- click
- h5py
- hdbscan
- matplotlib
Expand Down
5 changes: 2 additions & 3 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ authors = [
]
description = "DARTsort"
readme = "README.md"
requires-python = ">=3.8"
requires-python = ">=3.11" # tomllib
classifiers = [
"Programming Language :: Python :: 3",
"License :: OSI Approved :: MIT License",
Expand All @@ -25,5 +25,4 @@ classifiers = [
"dartsort.pretrained" = ["*.pt", "*.npz"]

[project.scripts]
"dartsort_si_config_py" = "dartsort.cli:dartsort_si_config_py"
"dartvis_si_all" = "dartsort.cli:dartvis_si_all"
"dartsort" = "dartsort.cli:dartsort_cli"
1 change: 0 additions & 1 deletion requirements-ci.txt
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
--extra-index-url https://download.pytorch.org/whl/cpu
click
h5py
hdbscan
matplotlib
Expand Down
1 change: 0 additions & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
--extra-index-url https://download.pytorch.org/whl/cpu
click
h5py
hdbscan
matplotlib
Expand Down
218 changes: 50 additions & 168 deletions src/dartsort/cli.py
Original file line number Diff line number Diff line change
@@ -1,180 +1,62 @@
"""
This is very work in progress!
I'm not sure that things will work this way at all in the future!
Not sure we'll keep using click -- I want to auto generate documentation
from the config objects?
"""
import numpy as np
import argparse
import spikeinterface.core as sc

import importlib.util
from pathlib import Path
from .util import cli_util, internal_config
from . import config, main

import click
import spikeinterface.full as si

from .main import dartsort, default_dartsort_config
from .vis.vismain import visualize_all_sorting_steps
def dartsort_cli():
"""dartsort_cli
# -- entry points
--<!!> Not stable.

@click.command()
@click.argument("si_rec_path")
@click.argument("output_directory")
@click.option("--config_path", type=str, default=None)
@click.option("--take_subtraction_from", type=str, default=None)
@click.option("--n_jobs_gpu", default=None, type=int)
@click.option("--n_jobs_cpu", default=None, type=int)
@click.option("--overwrite", default=False, flag_value=True, is_flag=True)
@click.option("--no_show_progress", default=False, flag_value=True, is_flag=True)
@click.option("--device", type=str, default=None)
@click.option("--rec_to_memory", default=False, flag_value=True, is_flag=True)
def dartsort_si_config_py(
si_rec_path,
output_directory,
config_path=None,
take_subtraction_from=None,
n_jobs_gpu=None,
n_jobs_cpu=None,
overwrite=False,
no_show_progress=False,
device=None,
rec_to_memory=False,
):
run_from_si_rec_path_and_config_py(
si_rec_path,
output_directory,
config_path=config_path,
take_subtraction_from=take_subtraction_from,
n_jobs_gpu=n_jobs_gpu,
n_jobs_cpu=n_jobs_cpu,
overwrite=overwrite,
show_progress=not no_show_progress,
device=device,
rec_to_memory=rec_to_memory,
I am figuring out how to do preprocessing still. It may be configured?
"""
# -- define CLI
ap = argparse.ArgumentParser(
prog="dartsort",
epilog=dartsort_cli.__doc__.split("--")[1],
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
)


@click.command()
@click.argument("si_rec_path")
@click.argument("dartsort_dir")
@click.argument("visualizations_dir")
@click.option("--channel_show_radius_um", default=50.0)
@click.option("--pca_radius_um", default=75.0)
@click.option("--no_superres_templates", default=False, flag_value=True, is_flag=True)
@click.option("--n_jobs_gpu", default=0)
@click.option("--n_jobs_cpu", default=0)
@click.option("--overwrite", default=False, flag_value=True, is_flag=True)
@click.option("--no_scatterplots", default=False, flag_value=True, is_flag=True)
@click.option("--no_summaries", default=False, flag_value=True, is_flag=True)
@click.option("--no_animations", default=False, flag_value=True, is_flag=True)
@click.option("--rec_to_memory", default=False, flag_value=True, is_flag=True)
def dartvis_si_all(
si_rec_path,
dartsort_dir,
visualizations_dir,
channel_show_radius_um=50.0,
pca_radius_um=75.0,
no_superres_templates=False,
n_jobs_gpu=0,
n_jobs_cpu=0,
overwrite=False,
no_scatterplots=False,
no_summaries=False,
no_animations=False,
rec_to_memory=False,
):
recording = si.load_extractor(si_rec_path)
if rec_to_memory:
recording = recording.save_to_memory(n_jobs=n_jobs_cpu)
visualize_all_sorting_steps(
recording,
dartsort_dir,
visualizations_dir,
superres_templates=not no_superres_templates,
channel_show_radius_um=channel_show_radius_um,
pca_radius_um=pca_radius_um,
make_scatterplots=not no_scatterplots,
make_unit_summaries=not no_summaries,
make_animations=not no_animations,
n_jobs=n_jobs_gpu,
n_jobs_templates=n_jobs_cpu,
overwrite=overwrite,
ap.add_argument("recording", help="Path to SpikeInterface RecordingExtractor.")
ap.add_argument("output_directory", help="Folder where outputs will be saved.")
ap.add_argument(
"--config-toml",
type=str,
default=None,
help="Path to configuration in TOML format. Arguments passed on the "
"command line will override their values in the TOML file.",
)


# -- scripting utils


def run_from_si_rec_path_and_config_py(
si_rec_path,
output_directory,
config_path=None,
take_subtraction_from=None,
n_jobs_gpu=None,
n_jobs_cpu=None,
overwrite=False,
show_progress=True,
device=None,
rec_to_memory=False,
):
# stub for eventual function that reads a config file
# I'm not sure this will be the way we actually do configuration
# maybe we'll end up deserializing DARTsortConfigs from a non-python
# config language
if config_path is None:
cfg = default_dartsort_config
else:
spec = importlib.util.spec_from_file_location("config_module", config_path)
module = importlib.util.module_from_spec(spec)
spec.loader.exec_module(module)
cfg = module.cfg

recording = si.load_extractor(si_rec_path)
print(f"{recording=}")

if rec_to_memory:
recording = recording.save_to_memory()

if take_subtraction_from is not None:
symlink_subtraction_and_motion(
take_subtraction_from,
output_directory,
)

return dartsort(
recording,
output_directory,
cfg=cfg,
motion_est=None,
n_jobs_gpu=n_jobs_gpu,
n_jobs_cpu=n_jobs_cpu,
overwrite=overwrite,
show_progress=show_progress,
device=device,
# user-facing API
cli_util.dataclass_to_argparse(config.DARTsortUserConfig, parser=ap)

# super secret developer-only args
dev_args = ap.add_argument_group("Secret development flags ($1.50 fee to use)")
cli_util.dataclass_to_argparse(
config.DeveloperConfig,
parser=dev_args,
prefix="_",
skipnames=cli_util.fieldnames(config.DARTsortUserConfig),
)

# -- parse args
args = ap.parse_args()

def symlink_subtraction_and_motion(input_dir, output_dir):
input_dir = Path(input_dir)
output_dir = Path(output_dir)
output_dir.mkdir(exist_ok=True)
# load the recording
# TODO: preprocessing management
rec = sc.load_extractor(cli_util.ensurepath(args.recording))

sub_h5 = input_dir / "subtraction.h5"
if not sub_h5.exists():
print(f"Can't symlink {sub_h5}")
return

targ_sub_h5 = output_dir / "subtraction.h5"
if not targ_sub_h5.exists():
targ_sub_h5.symlink_to(sub_h5)

sub_models = input_dir / "subtraction_models"
targ_sub_models = output_dir / "subtraction_models"
if not targ_sub_models.exists():
targ_sub_models.symlink_to(sub_models, target_is_directory=True)
# determine the config from the command line args
cfg = cli_util.combine_toml_and_argv(
(config.DARTsortUserConfig, config.DeveloperConfig),
config.DeveloperConfig,
cli_util.ensurepath(args.config_toml),
args,
)

motion_est_pkl = input_dir / "motion_est.pkl"
if motion_est_pkl.exists():
targ_me_pkl = output_dir / "motion_est.pkl"
if not targ_me_pkl.exists():
targ_me_pkl.symlink_to(motion_est_pkl)
# -- run
# TODO: maybe this should dump to Phy?
output_directory = cli_util.ensurepath(args.output_directory, strict=False)
ret = main.dartsort(rec, output_directory, cfg=cfg, return_extra=cfg.needs_extra)
main.run_dev_tasks(ret, output_directory, cfg)
5 changes: 3 additions & 2 deletions src/dartsort/cluster/gaussian_mixture.py
Original file line number Diff line number Diff line change
Expand Up @@ -2069,8 +2069,9 @@ def merge_units(
# merge behavior is either a hierarchical merge or this tree-based
# idea, depending on the value of a parameter
if merge_kind is None:
merge_kind = "hierarchical"
if self.merge_criterion_threshold is not None:
if self.merge_criterion == "bimodality":
merge_kind = "hierarchical"
else:
merge_kind = "tree"

# distances are needed by both methods
Expand Down
22 changes: 19 additions & 3 deletions src/dartsort/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,9 @@ class DARTsortUserConfig:
"""User-facing configuration options"""

# -- high level behavior
dredge_only: bool = False
dredge_only: bool = argfield(
False, doc="Whether to stop after initial localization and motion tracking."
)
matching_iterations: int = 1

# -- computer options
Expand Down Expand Up @@ -155,12 +157,26 @@ class DeveloperConfig(DARTsortUserConfig):
use_universal_templates: bool = False
signal_rank: Annotated[int, Field(ge=0)] = 0

merge_criterion_threshold: float | None = 0.0
merge_criterion_threshold: float = 0.0
merge_criterion: Literal[
"heldout_loglik", "heldout_ccl", "loglik", "ccl", "aic", "bic", "icl"
"heldout_loglik",
"heldout_ccl",
"loglik",
"ccl",
"aic",
"bic",
"icl",
"bimodality",
] = "heldout_ccl"
merge_bimodality_threshold: float = 0.05
n_refinement_iters: int = 3

gmm_max_spikes: Annotated[int, Field(gt=0)] = 4_000_000
gmm_val_proportion: Annotated[float, Field(gt=0)] = 0.25

# flags for dev tasks run by main.run_dev_tasks
save_intermediate_labels: bool = False

@property
def needs_extra(self):
    # True when any enabled dev task (see main.run_dev_tasks) requires
    # main.dartsort to be called with return_extra so intermediate results
    # are collected. Currently only save_intermediate_labels needs this.
    return self.save_intermediate_labels
14 changes: 11 additions & 3 deletions src/dartsort/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ def dartsort(
return_extra=False,
):
output_directory = Path(output_directory)
output_directory.mkdir(exist_ok=True)
cfg = to_internal_config(cfg)

ret = {}
Expand Down Expand Up @@ -95,11 +96,11 @@ def dartsort(
computation_config=cfg.computation_config,
)
if return_extra:
ret["refined_labels"] = sorting.labels.copy()
ret["refined0_labels"] = sorting.labels.copy()

# alternate matching with
for step in range(cfg.matching_iterations):
is_final = step == cfg.matching_iterations - 1
for step in range(1, cfg.matching_iterations + 1):
is_final = step == cfg.matching_iterations
prop = 1.0 if is_final else cfg.intermediate_matching_subsampling

sorting, match_h5 = match(
Expand Down Expand Up @@ -307,3 +308,10 @@ def match_chunked(
hdf5_filenames.append(chunk_h5)

return sortings, hdf5_filenames


def run_dev_tasks(results, output_directory, cfg):
    """Run developer-only post-sorting tasks selected by dev config flags.

    Currently the only task is dumping intermediate label arrays: every entry
    of `results` whose key ends in "_labels" is written to
    `output_directory / f"{key}.npy"` when `cfg.save_intermediate_labels` is
    set. Does nothing otherwise.
    """
    if not cfg.save_intermediate_labels:
        return
    for key, value in results.items():
        if key.endswith("_labels"):
            # allow_pickle=False: label arrays are plain numeric data.
            np.save(output_directory / f"{key}.npy", value, allow_pickle=False)
Loading

0 comments on commit 4824027

Please sign in to comment.