diff --git a/environment.yml b/environment.yml index 73c18341..766a52dc 100644 --- a/environment.yml +++ b/environment.yml @@ -4,7 +4,6 @@ channels: - conda-forge dependencies: - python=3.11 - - click - h5py - hdbscan - matplotlib diff --git a/pyproject.toml b/pyproject.toml index b93d05fc..0114f66e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -10,7 +10,7 @@ authors = [ ] description = "DARTsort" readme = "README.md" -requires-python = ">=3.8" +requires-python = ">=3.11" # tomllib classifiers = [ "Programming Language :: Python :: 3", "License :: OSI Approved :: MIT License", @@ -25,5 +25,4 @@ classifiers = [ "dartsort.pretrained" = ["*.pt", "*.npz"] [project.scripts] -"dartsort_si_config_py" = "dartsort.cli:dartsort_si_config_py" -"dartvis_si_all" = "dartsort.cli:dartvis_si_all" +"dartsort" = "dartsort.cli:dartsort_cli" diff --git a/requirements-ci.txt b/requirements-ci.txt index 24e5849d..fd141ec0 100644 --- a/requirements-ci.txt +++ b/requirements-ci.txt @@ -1,5 +1,4 @@ --extra-index-url https://download.pytorch.org/whl/cpu -click h5py hdbscan matplotlib diff --git a/requirements.txt b/requirements.txt index ab14c82d..293b2c6b 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,5 +1,4 @@ --extra-index-url https://download.pytorch.org/whl/cpu -click h5py hdbscan matplotlib diff --git a/src/dartsort/cli.py b/src/dartsort/cli.py index 00724b32..39fa1024 100644 --- a/src/dartsort/cli.py +++ b/src/dartsort/cli.py @@ -1,180 +1,62 @@ -""" -This is very work in progress! -I'm not sure that things will work this way at all in the future! -Not sure we'll keep using click -- I want to auto generate documentation -from the config objects? -""" +import numpy as np +import argparse +import spikeinterface.core as sc -import importlib.util -from pathlib import Path +from .util import cli_util, internal_config +from . import config, main -import click -import spikeinterface.full as si -from .main import dartsort, default_dartsort_config -from .vis.vismain import visualize_all_sorting_steps +def dartsort_cli(): + """dartsort_cli -# -- entry points + -- Not stable. - -@click.command() -@click.argument("si_rec_path") -@click.argument("output_directory") -@click.option("--config_path", type=str, default=None) -@click.option("--take_subtraction_from", type=str, default=None) -@click.option("--n_jobs_gpu", default=None, type=int) -@click.option("--n_jobs_cpu", default=None, type=int) -@click.option("--overwrite", default=False, flag_value=True, is_flag=True) -@click.option("--no_show_progress", default=False, flag_value=True, is_flag=True) -@click.option("--device", type=str, default=None) -@click.option("--rec_to_memory", default=False, flag_value=True, is_flag=True) -def dartsort_si_config_py( - si_rec_path, - output_directory, - config_path=None, - take_subtraction_from=None, - n_jobs_gpu=None, - n_jobs_cpu=None, - overwrite=False, - no_show_progress=False, - device=None, - rec_to_memory=False, -): - run_from_si_rec_path_and_config_py( - si_rec_path, - output_directory, - config_path=config_path, - take_subtraction_from=take_subtraction_from, - n_jobs_gpu=n_jobs_gpu, - n_jobs_cpu=n_jobs_cpu, - overwrite=overwrite, - show_progress=not no_show_progress, - device=device, - rec_to_memory=rec_to_memory, + I am figuring out how to do preprocessing still. It may be configured? + """ + # -- define CLI + ap = argparse.ArgumentParser( + prog="dartsort", + epilog=dartsort_cli.__doc__.split("--")[1], + formatter_class=argparse.ArgumentDefaultsHelpFormatter, ) - - -@click.command() -@click.argument("si_rec_path") -@click.argument("dartsort_dir") -@click.argument("visualizations_dir") -@click.option("--channel_show_radius_um", default=50.0) -@click.option("--pca_radius_um", default=75.0) -@click.option("--no_superres_templates", default=False, flag_value=True, is_flag=True) -@click.option("--n_jobs_gpu", default=0) -@click.option("--n_jobs_cpu", default=0) -@click.option("--overwrite", default=False, flag_value=True, is_flag=True) -@click.option("--no_scatterplots", default=False, flag_value=True, is_flag=True) -@click.option("--no_summaries", default=False, flag_value=True, is_flag=True) -@click.option("--no_animations", default=False, flag_value=True, is_flag=True) -@click.option("--rec_to_memory", default=False, flag_value=True, is_flag=True) -def dartvis_si_all( - si_rec_path, - dartsort_dir, - visualizations_dir, - channel_show_radius_um=50.0, - pca_radius_um=75.0, - no_superres_templates=False, - n_jobs_gpu=0, - n_jobs_cpu=0, - overwrite=False, - no_scatterplots=False, - no_summaries=False, - no_animations=False, - rec_to_memory=False, -): - recording = si.load_extractor(si_rec_path) - if rec_to_memory: - recording = recording.save_to_memory(n_jobs=n_jobs_cpu) - visualize_all_sorting_steps( - recording, - dartsort_dir, - visualizations_dir, - superres_templates=not no_superres_templates, - channel_show_radius_um=channel_show_radius_um, - pca_radius_um=pca_radius_um, - make_scatterplots=not no_scatterplots, - make_unit_summaries=not no_summaries, - make_animations=not no_animations, - n_jobs=n_jobs_gpu, - n_jobs_templates=n_jobs_cpu, - overwrite=overwrite, + ap.add_argument("recording", help="Path to SpikeInterface RecordingExtractor.") + ap.add_argument("output_directory", help="Folder where outputs will be saved.") + ap.add_argument( + "--config-toml", + type=str, + default=None, + help="Path to configuration in TOML format. Arguments passed on the " + "command line will override their values in the TOML file.", ) - - -# -- scripting utils - - -def run_from_si_rec_path_and_config_py( - si_rec_path, - output_directory, - config_path=None, - take_subtraction_from=None, - n_jobs_gpu=None, - n_jobs_cpu=None, - overwrite=False, - show_progress=True, - device=None, - rec_to_memory=False, -): - # stub for eventual function that reads a config file - # I'm not sure this will be the way we actually do configuration - # maybe we'll end up deserializing DARTsortConfigs from a non-python - # config language - if config_path is None: - cfg = default_dartsort_config - else: - spec = importlib.util.spec_from_file_location("config_module", config_path) - module = importlib.util.module_from_spec(spec) - spec.loader.exec_module(module) - cfg = module.cfg - - recording = si.load_extractor(si_rec_path) - print(f"{recording=}") - - if rec_to_memory: - recording = recording.save_to_memory() - - if take_subtraction_from is not None: - symlink_subtraction_and_motion( - take_subtraction_from, - output_directory, - ) - - return dartsort( - recording, - output_directory, - cfg=cfg, - motion_est=None, - n_jobs_gpu=n_jobs_gpu, - n_jobs_cpu=n_jobs_cpu, - overwrite=overwrite, - show_progress=show_progress, - device=device, + # user-facing API + cli_util.dataclass_to_argparse(config.DARTsortUserConfig, parser=ap) + + # super secret developer-only args + dev_args = ap.add_argument_group("Secret development flags ($1.50 fee to use)") + cli_util.dataclass_to_argparse( + config.DeveloperConfig, + parser=dev_args, + prefix="_", + skipnames=cli_util.fieldnames(config.DARTsortUserConfig), ) + # -- parse args + args = ap.parse_args() -def symlink_subtraction_and_motion(input_dir, output_dir): - input_dir = Path(input_dir) - output_dir = Path(output_dir) - output_dir.mkdir(exist_ok=True) + # load the recording + # TODO: preprocessing management + rec = sc.load_extractor(cli_util.ensurepath(args.recording)) - sub_h5 = input_dir / "subtraction.h5" - if not sub_h5.exists(): - print(f"Can't symlink {sub_h5}") - return - - targ_sub_h5 = output_dir / "subtraction.h5" - if not targ_sub_h5.exists(): - targ_sub_h5.symlink_to(sub_h5) - - sub_models = input_dir / "subtraction_models" - targ_sub_models = output_dir / "subtraction_models" - if not targ_sub_models.exists(): - targ_sub_models.symlink_to(sub_models, target_is_directory=True) + # determine the config from the command line args + cfg = cli_util.combine_toml_and_argv( + (config.DARTsortUserConfig, config.DeveloperConfig), + config.DeveloperConfig, + cli_util.ensurepath(args.config_toml), + args, + ) - motion_est_pkl = input_dir / "motion_est.pkl" - if motion_est_pkl.exists(): - targ_me_pkl = output_dir / "motion_est.pkl" - if not targ_me_pkl.exists(): - targ_me_pkl.symlink_to(motion_est_pkl) + # -- run + # TODO: maybe this should dump to Phy? + output_directory = cli_util.ensurepath(args.output_directory, strict=False) + ret = main.dartsort(rec, output_directory, cfg=cfg, return_extra=cfg.needs_extra) + main.run_dev_tasks(ret, output_directory, cfg) diff --git a/src/dartsort/cluster/gaussian_mixture.py b/src/dartsort/cluster/gaussian_mixture.py index 0be24cc3..e1b1578b 100644 --- a/src/dartsort/cluster/gaussian_mixture.py +++ b/src/dartsort/cluster/gaussian_mixture.py @@ -2069,8 +2069,9 @@ def merge_units( # merge behavior is either a hierarchical merge or this tree-based # idea, depending on the value of a parameter if merge_kind is None: - merge_kind = "hierarchical" - if self.merge_criterion_threshold is not None: + if self.merge_criterion == "bimodality": + merge_kind = "hierarchical" + else: merge_kind = "tree" # distances are needed by both methods diff --git a/src/dartsort/config.py b/src/dartsort/config.py index ebb0523b..d83fee87 100644 --- a/src/dartsort/config.py +++ b/src/dartsort/config.py @@ -13,7 +13,9 @@ class DARTsortUserConfig: """User-facing configuration options""" # -- high level behavior - dredge_only: bool = False + dredge_only: bool = argfield( + False, doc="Whether to stop after initial localization and motion tracking." + ) matching_iterations: int = 1 # -- computer options @@ -155,12 +157,26 @@ class DeveloperConfig(DARTsortUserConfig): use_universal_templates: bool = False signal_rank: Annotated[int, Field(ge=0)] = 0 - merge_criterion_threshold: float | None = 0.0 + merge_criterion_threshold: float = 0.0 merge_criterion: Literal[ - "heldout_loglik", "heldout_ccl", "loglik", "ccl", "aic", "bic", "icl" + "heldout_loglik", + "heldout_ccl", + "loglik", + "ccl", + "aic", + "bic", + "icl", + "bimodality", ] = "heldout_ccl" merge_bimodality_threshold: float = 0.05 n_refinement_iters: int = 3 gmm_max_spikes: Annotated[int, Field(gt=0)] = 4_000_000 gmm_val_proportion: Annotated[float, Field(gt=0)] = 0.25 + + # flags for dev tasks run by main.run_dev_tasks + save_intermediate_labels: bool = False + + @property + def needs_extra(self): + return self.save_intermediate_labels diff --git a/src/dartsort/main.py b/src/dartsort/main.py index b39dc099..275b888f 100644 --- a/src/dartsort/main.py +++ b/src/dartsort/main.py @@ -41,6 +41,7 @@ def dartsort( return_extra=False, ): output_directory = Path(output_directory) + output_directory.mkdir(exist_ok=True) cfg = to_internal_config(cfg) ret = {} @@ -95,11 +96,11 @@ def dartsort( computation_config=cfg.computation_config, ) if return_extra: - ret["refined_labels"] = sorting.labels.copy() + ret["refined0_labels"] = sorting.labels.copy() # alternate matching with - for step in range(cfg.matching_iterations): - is_final = step == cfg.matching_iterations - 1 + for step in range(1, cfg.matching_iterations + 1): + is_final = step == cfg.matching_iterations prop = 1.0 if is_final else cfg.intermediate_matching_subsampling sorting, match_h5 = match( @@ -307,3 +308,10 @@ def match_chunked( hdf5_filenames.append(chunk_h5) return sortings, hdf5_filenames + + +def run_dev_tasks(results, output_directory, cfg): + if cfg.save_intermediate_labels: + for k, v in results.items(): + if k.endswith("_labels"): + np.save(output_directory / f"{k}.npy", v, allow_pickle=False) diff --git a/src/dartsort/util/cli_util.py b/src/dartsort/util/cli_util.py new file mode 100644 index 00000000..e6579986 --- /dev/null +++ b/src/dartsort/util/cli_util.py @@ -0,0 +1,134 @@ +from pathlib import Path +from dataclasses import MISSING, fields, field, asdict +from argparse import ArgumentParser, BooleanOptionalAction, _StoreAction +import tomllib +import typing + +from torch import Value + + +def ensurepath(path, strict=True): + path = Path(path) + path = path.expanduser() + path = path.resolve(strict=strict) + return path + + +def argfield( + default=MISSING, default_factory=MISSING, arg_type=MISSING, cli=True, doc="" +): + """Helper for defining fields with extended CLI behavior. + + This is only needed when a field's type is not a callable which can + take string inputs and return an object of the right type, such as + typing.Union or something. Then arg_type is what the CLI will call + to convert the argv element into an object of the desired type. + + Fields with cli=False will not be available from the command line. + """ + metadata = dict(cli=cli, doc=doc) + if arg_type is not MISSING: + metadata["arg_type"] = arg_type + return field(default=default, default_factory=default_factory, metadata=metadata) + + +def fieldnames(cls): + return set(f.name for f in fields(cls)) + + +def manglefieldset(name): + return f"{name}$$fieldset" + + +class FieldStoreAction(_StoreAction): + def __call__(self, parser, namespace, values, option_string=None): + super().__call__(parser, namespace, values, option_string=option_string) + setattr(namespace, f"{self.dest}$$fieldset", values) + + +class FieldBooleanOptionalAction(BooleanOptionalAction): + def __call__(self, parser, namespace, values, option_string=None): + super().__call__(parser, namespace, values, option_string=option_string) + setattr(namespace, manglefieldset(self.dest), True) + + +def dataclass_to_argparse(cls, parser=None, prefix="", skipnames=None): + """Add a dataclass's fields as arguments to an ArgumentParser + + Inspired by Jeremy Stafford's datacli. Works together with argfield + to set metadata needed sometimes. + """ + if parser is None: + parser = ArgumentParser() + + for field in fields(cls): + if skipnames and field.name in skipnames: + continue + if not field.metadata.get("cli", True): + continue + + required = field.default is MISSING and field.default_factory is MISSING + doc = field.metadata.get("doc", None) + type_ = field.metadata.get("arg_type", field.type) + if type_ is MISSING: + raise ValueError(f"Need type or arg_type for {field}.") + choices = None + if typing.get_origin(type_) == typing.Literal: + choices = typing.get_args(type_) + type_ = type(choices[0]) + + name = f"--{prefix}{field.name.replace('_', '-')}" + metavar = field.name.upper() + default = field.default + if default is MISSING: + default = None + kw = dict( + default=default, help=doc, metavar=metavar, dest=field.name, choices=choices + ) + + try: + if type_ == bool: + parser.add_argument(name, action=FieldBooleanOptionalAction, **kw) + else: + parser.add_argument( + name, action=FieldStoreAction, type=type_, required=required, **kw + ) + except Exception as e: + ee = ValueError(f"Exception raised while adding {field=} to CLI") + raise ee from e + + return parser + + +def dataclass_from_toml(clss, toml_path): + with open(toml_path, "r") as toml: + for cls in clss: + try: + return cls(**tomllib.load(toml)) + except TypeError: + continue + + +def update_dataclass_from_args(cls, obj, args): + if obj is None: + kv = {} + else: + kv = asdict(obj) + + for field in fields(cls): + if hasattr(args, manglefieldset(field.name)): + kv[field.name] = getattr(args, field.name) + + return cls(**kv) + + +def combine_toml_and_argv(toml_dataclasses, target_dataclass, toml_path, args): + # validate the toml file, if supplied + cfg = None + if toml_path: + cfg = dataclass_from_toml(toml_dataclasses, toml_path) + + # update with additional arguments + cfg = update_dataclass_from_args(target_dataclass, cfg, args) + + return cfg diff --git a/src/dartsort/util/internal_config.py b/src/dartsort/util/internal_config.py index a177c8c0..5332ec7f 100644 --- a/src/dartsort/util/internal_config.py +++ b/src/dartsort/util/internal_config.py @@ -7,6 +7,7 @@ from pydantic.dataclasses import dataclass from .py_util import int_or_inf +from .cli_util import argfield try: from importlib.resources import files @@ -21,25 +22,6 @@ default_pretrained_path = str(default_pretrained_path) -def argfield( - default=MISSING, default_factory=MISSING, arg_type=MISSING, cli=True, doc="" -): - """Helper for defining fields with extended CLI behavior. - - This is only needed when a field's type is not a callable which can - take string inputs and return an object of the right type, such as - typing.Union or something. Then arg_type is what the CLI will call - to convert the argv element into an object of the desired type. - - Fields with cli=False will not be available from the command line. - """ - return field( - default=default, - default_factory=default_factory, - metadata=dict(arg_type=arg_type, cli=cli, doc=""), - ) - - @dataclass(frozen=True, kw_only=True, slots=True) class WaveformConfig: """Defaults yield 42 sample trough offset and 121 total at 30kHz."""