From 4ef715e4cde13eef36eb84f1abd67a61bee7b22e Mon Sep 17 00:00:00 2001
From: Charlie Windolf
Date: Thu, 19 Dec 2024 14:23:38 -0500
Subject: [PATCH] Big config refactoring to finally have a user-friendly API

---
 src/dartsort/config.py                    | 566 +++++-----------------
 src/dartsort/main.py                      |  23 +-
 src/dartsort/peel/subtract.py             |   6 +-
 src/dartsort/templates/get_templates.py   |  64 ++-
 src/dartsort/templates/templates.py       |  14 +-
 src/dartsort/util/internal_config.py      | 524 ++++++++++++++++++++
 src/dartsort/util/multiprocessing_util.py |   9 +
 7 files changed, 724 insertions(+), 482 deletions(-)
 create mode 100644 src/dartsort/util/internal_config.py

diff --git a/src/dartsort/config.py b/src/dartsort/config.py
index 067224b8..b2bffebb 100644
--- a/src/dartsort/config.py
+++ b/src/dartsort/config.py
@@ -1,463 +1,153 @@
-"""Configuration classes
-
-Users should not edit this file!
-
-Rather, make your own custom configs by instantiating new
-config objects, for example, to turn off neural net denoising
-in the featurization pipeline you can make:
-
-```
-featurization_config = FeaturizationConfig(do_nn_denoise=False)
-```
-
-This will use all the other parameters' default values. This
-object can then be passed into the high level functions like
-`subtract(...)`.
-"""
+"""Configuration: instantiate a DARTsortUserConfig and pass it to dartsort()."""
 
 from dataclasses import MISSING, field
-from typing import Literal
+from typing import Literal, Annotated
 
-import numpy as np
-import torch
+from pydantic import Field
 from pydantic.dataclasses import dataclass
 
-try:
-    from importlib.resources import files
-except ImportError:
-    try:
-        from importlib_resources import files
-    except ImportError:
-        raise ValueError("Need python>=3.10 or pip install importlib_resources.")
-
-default_pretrained_path = files("dartsort.pretrained")
-default_pretrained_path = default_pretrained_path.joinpath("single_chan_denoiser.pt")
-default_pretrained_path = str(default_pretrained_path)
-
-
-def argfield(default=MISSING, default_factory=MISSING, arg_type=MISSING, cli=True):
-    """Helper for defining fields with extended CLI behavior.
-
-    This is only needed when a field's type is not a callable which can
-    take string inputs and return an object of the right type, such as
-    typing.Union or something. Then arg_type is what the CLI will call
-    to convert the argv element into an object of the desired type.
-
-    Fields with cli=False will not be available from the command line.
- """ - return field( - default=default, - default_factory=default_factory, - metadata=dict(arg_type=arg_type, cli=cli), - ) - - -@dataclass(frozen=True, kw_only=True, slots=True) -class WaveformConfig: - """Defaults yield 42 sample trough offset and 121 total at 30kHz.""" - - ms_before: float = 1.4 - ms_after: float = 2.6 - - def trough_offset_samples(self, sampling_frequency=30_000): - sampling_frequency = np.round(sampling_frequency) - return int(self.ms_before * (sampling_frequency / 1000)) - - def spike_length_samples(self, sampling_frequency=30_000): - spike_len_ms = self.ms_before + self.ms_after - sampling_frequency = np.round(sampling_frequency) - length = int(spike_len_ms * (sampling_frequency / 1000)) - # odd is better for convolution arithmetic elsewhere - length = 2 * (length // 2) + 1 - return length - - def relative_slice(self, other, sampling_frequency=30_000): - """My trough-aligned subset of samples in other, which contains me.""" - assert other.ms_before >= self.ms_before - assert other.ms_after >= self.ms_after - my_trough = self.trough_offset_samples(sampling_frequency) - my_len = self.spike_length_samples(sampling_frequency) - other_trough = other.trough_offset_samples(sampling_frequency) - other_len = other.spike_length_samples(sampling_frequency) - start_offset = other_trough - my_trough - end_offset = (other_len - other_trough) - (my_len - my_trough) - if start_offset == end_offset == 0: - return slice(None) - return slice(start_offset, other_len - end_offset) +from .util.internal_config import * @dataclass(frozen=True, kw_only=True, slots=True) -class FeaturizationConfig: - """Featurization and denoising configuration - - Parameters for a featurization and denoising pipeline - which has the flow: - [input waveforms] - -> [featurization of input waveforms] - -> [denoising] - -> [featurization of output waveforms] - - The flags below allow users to control which features - are computed for the input waveforms, what denoising - operations are applied, and what features are computed - for the output (post-denoising) waveforms. - - Users who'd rather do something not covered by this - typical case can manually instantiate a WaveformPipeline - and pass it into their peeler. 
- """ - - skip: bool = False - - # -- denoising configuration - do_nn_denoise: bool = False - do_tpca_denoise: bool = True - do_enforce_decrease: bool = True - # turn off features below - denoise_only: bool = False - - # -- residual snips - n_residual_snips: int = 4096 - - # -- featurization configuration - save_input_voltages: bool = False - save_input_waveforms: bool = False - save_input_tpca_projs: bool = True - save_output_waveforms: bool = False - save_output_tpca_projs: bool = False - save_amplitudes: bool = True - save_all_amplitudes: bool = False - # localization runs on output waveforms - do_localization: bool = True - localization_radius: float = 100.0 - # these are saved always if do_localization - localization_amplitude_type: Literal["peak", "ptp"] = "peak" - localization_model: Literal["pointsource", "dipole"] = "pointsource" - nn_localization: bool = True +class DARTsortUserConfig: + """User-facing configuration options""" - # -- further info about denoising - # in the future we may add multi-channel or other nns - nn_denoiser_class_name: str = "SingleChannelWaveformDenoiser" - nn_denoiser_pretrained_path: str = default_pretrained_path - nn_denoiser_train_epochs: int = 50 - nn_denoiser_extra_kwargs: dict | None = argfield(None, cli=False) + # -- high level behavior + dredge_only: bool = False + matching_iterations: int = 1 - # optionally restrict how many channels TPCA are fit on - tpca_fit_radius: float = 75.0 - tpca_rank: int = 8 - tpca_centered: bool = False - learn_cleaned_tpca_basis: bool = False - input_tpca_waveform_config: WaveformConfig | None = WaveformConfig( - ms_before=0.75, ms_after=1.25 + # -- parallelism options + n_jobs_cpu: int = argfield( + default=0, + doc="Number of parallel workers to use when running on CPU. " + "0 means everything runs on the main thread.", + ) + n_jobs_gpu: int = argfield( + default=0, + doc="Number of parallel workers to use when running on GPU. " + "0 means everything runs on the main thread.", + ) + device: str | None = argfield( + default=None, + arg_type=str, + doc="The name of the PyTorch device to use. For example, 'cpu' " + "or 'cuda' or 'cuda:1'. If unset, uses n_jobs_gpu of your CUDA " + "GPUs if you have multiple, or else just the one, or your CPU.", ) - # used when naming datasets saved to h5 files - input_waveforms_name: str = "collisioncleaned" - output_waveforms_name: str = "denoised" - - -@dataclass(frozen=True, kw_only=True, slots=True) -class SubtractionConfig: - detection_threshold: float = 4.0 - chunk_length_samples: int = 30_000 - peak_sign: str = "both" - spatial_dedup_radius: float = 150.0 - subtract_radius: float = 200.0 - extract_radius: float = 100.0 - n_chunks_fit: int = 100 - max_waveforms_fit: int = 50_000 - n_waveforms_fit: int = 20_000 - fit_subsampling_random_state: int = 0 - fit_sampling: str = "random" - residnorm_decrease_threshold: float = 3.162 # sqrt(10) - use_singlechan_templates: bool = False - singlechan_threshold: float = 50.0 - n_singlechan_templates: int = 10 - singlechan_alignment_padding: int = 20 - use_universal_templates: bool = False - - # how will waveforms be denoised before subtraction? 
-    # users can also save waveforms/features during subtraction
-    subtraction_denoising_config: FeaturizationConfig = FeaturizationConfig(
-        denoise_only=True,
-        do_nn_denoise=True,
-        input_waveforms_name="raw",
-        output_waveforms_name="subtracted",
+    # -- waveform snippet length parameters
+    ms_before: Annotated[float, Field(gt=0)] = argfield(
+        default=1.4,
+        doc="Length of time (ms) before trough (or peak) in waveform snippets. "
+        "Default value corresponds to 42 samples at 30kHz.",
+    )
+    ms_after: Annotated[float, Field(gt=0)] = argfield(
+        default=2.6,
+        doc="Length of time (ms) after trough (or peak) in waveform snippets. "
+        "Default value corresponds to 79 samples at 30kHz.",
+    )
+    alignment_ms: Annotated[float, Field(gt=0)] = argfield(
+        default=0.8,
+        doc="Maximum time shift (ms) allowed when aligning events.",
    )
 
+    # -- thresholds
+    initial_threshold: Annotated[float, Field(gt=0)] = argfield(
+        default=4.0,
+        doc="Threshold in standardized voltage units for initial detection; "
+        "peaks or troughs larger than this value are detected.",
+    )
+    matching_threshold: Annotated[float, Field(gt=0)] = argfield(
+        default=10.0,
+        doc="Template matching threshold. A match is kept only if "
+        "subtracting the template decreases the norm of the residual "
+        "by at least this much.",
+    )
+    denoiser_badness_factor: Annotated[float, Field(gt=0, lt=1)] = argfield(
+        default=0.1,
+        doc="In initial detection, a clean waveform inferred by the NN "
+        "denoiser need only decrease the squared residual norm by this "
+        "multiple of the squared matching threshold to be accepted.",
+    )
 
-@dataclass(frozen=True, kw_only=True, slots=True)
-class MotionEstimationConfig:
-    """Configure motion estimation.
+    # -- featurization length, radius, rank parameters
+    temporal_pca_rank: Annotated[int, Field(gt=0)] = argfield(
+        default=8, doc="Rank of global temporal PCA."
+    )
+    feature_ms_before: Annotated[float, Field(gt=0)] = argfield(
+        default=0.75,
+        doc="As ms_before, but used only when computing PCA features in clustering.",
+    )
+    feature_ms_after: Annotated[float, Field(gt=0)] = argfield(
+        default=1.25,
+        doc="As ms_after, but used only when computing PCA features in clustering.",
+    )
+    subtraction_radius_um: Annotated[float, Field(gt=0)] = argfield(
+        default=200.0,
+        doc="Radius of the neighborhood extracted around each spike event "
+        "when denoising and subtracting.",
+    )
+    deduplication_radius_um: Annotated[float, Field(gt=0)] = argfield(
+        default=150.0,
+        doc="During initial detection, if two spike events occur at the "
+        "same time within this radius, then the smaller of the two is "
+        "ignored, as are detections on the larger event's "
+        "secondary channels.",
+    )
+    featurization_radius_um: Annotated[float, Field(gt=0)] = argfield(
+        default=100.0,
+        doc="Radius around detection channel or template peak channel used "
+        "to extract spike features for clustering.",
+    )
+    fit_radius_um: Annotated[float, Field(gt=0)] = argfield(
+        default=75.0,
+        doc="Extraction radius when fitting features like PCA; "
+        "smaller than other radii to include less noise.",
+    )
+    localization_radius_um: Annotated[float, Field(gt=0)] = argfield(
+        default=100.0,
+        doc="Radius around main channel used when localizing spikes.",
+    )
 
-    """
+    # -- clustering parameters
+    density_bandwidth: Annotated[float, Field(gt=0)] = 5.0
+    interpolation_bandwidth: Annotated[float, Field(gt=0)] = 20.0
 
-    do_motion_estimation: bool = True
+    # -- matching parameters
+    amplitude_scaling_stddev: Annotated[float, Field(ge=0)] = 0.1
+    amplitude_scaling_limit: Annotated[float, Field(ge=0)] = 1.0
+    temporal_upsamples: Annotated[int, Field(gt=1)] = 4
 
-    # sometimes spikes can be localized far away from the probe, causing
-    # issues with motion estimation, we will ignore such spikes
-    probe_boundary_padding_um: float = 100.0
+    # -- motion estimation parameters
+    do_motion_estimation: bool = argfield(
+        default=True,
+        doc="Set this to False if your data is stable or already motion-corrected.",
+    )
 
     # DREDge parameters
-    spatial_bin_length_um: float = 1.0
-    temporal_bin_length_s: float = 1.0
-    window_step_um: float = 400.0
-    window_scale_um: float = 450.0
-    window_margin_um: float | None = argfield(default=None, arg_type=float)
-    max_dt_s: float = 1000.0
-    max_disp_um: float | None = argfield(default=None, arg_type=float)
-    correlation_threshold: float = 0.1
-    min_amplitude: float | None = argfield(default=None, arg_type=float)
-    rigid: bool = False
-
-
-@dataclass(frozen=True, kw_only=True, slots=True)
-class TemplateConfig:
-    spikes_per_unit: int = 500
-
-    # -- template construction parameters
-    # registered templates?
-    registered_templates: bool = True
-    registered_template_localization_radius_um: float = 100.0
-
-    # superresolved templates
-    superres_templates: bool = False
-    superres_bin_size_um: float = 10.0
-    superres_bin_min_spikes: int = 5
-    superres_strategy: str = "drift_pitch_loc_bin"
-    adaptive_bin_size: bool = False
-
-    # low rank denoising?
- low_rank_denoising: bool = True - denoising_rank: int = 5 - denoising_snr_threshold: float = 50.0 - denoising_fit_radius: float = 75.0 - - # realignment - realign_peaks: bool = True - realign_max_sample_shift: int = 20 - - # track template over time - time_tracking: bool = False - chunk_size_s: int = 300 - - -@dataclass(frozen=True, kw_only=True, slots=True) -class MatchingConfig: - chunk_length_samples: int = 30_000 - extract_radius: float = 100.0 - n_chunks_fit: int = 100 - max_waveforms_fit: int = 50_000 - n_waveforms_fit: int = 20_000 - fit_subsampling_random_state: int = 0 - fit_sampling: str = "random" - - # template matching parameters - threshold: float = 150.0 - template_svd_compression_rank: int = 10 - template_temporal_upsampling_factor: int = 8 - template_min_channel_amplitude: float = 1.0 - refractory_radius_frames: int = 10 - amplitude_scaling_variance: float = 0.0 - amplitude_scaling_boundary: float = 0.5 - max_iter: int = 1000 - conv_ignore_threshold: float = 5.0 - coarse_approx_error_threshold: float = 0.0 - coarse_objective: bool = True - - -@dataclass(frozen=True, kw_only=True, slots=True) -class SplitMergeConfig: - # -- split - split_strategy: str = "FeatureSplit" - recursive_split: bool = True - split_strategy_kwargs: dict | None = field( - default_factory=lambda: dict(max_spikes=20_000) + rigid: bool = argfield( + default=False, doc="Use rigid registration and ignore the window parameters." ) - - # -- merge - merge_template_config: TemplateConfig = TemplateConfig(superres_templates=False) - linkage: str = "complete" - merge_distance_threshold: float = 0.25 - cross_merge_distance_threshold: float = 0.5 - min_spatial_cosine: float = 0.0 - - -@dataclass(frozen=True, kw_only=True, slots=True) -class ClusteringConfig: - # -- initial clustering - cluster_strategy: str = "dpc" - - # initial clustering features - use_amplitude: bool = True - amp_log_c: float = 5.0 - amp_scale: float = 50.0 - n_main_channel_pcs: int = 0 - pc_scale: float = 10.0 - adaptive_feature_scales: bool = False - - # density peaks parameters - sigma_local: float = 5.0 - sigma_regional: float | None = argfield(default=25.0, arg_type=float) - workers: int = -1 - n_neighbors_search: int = 20 - radius_search: float = 5.0 - remove_clusters_smaller_than: int = 10 - noise_density: float = 0.0 - outlier_radius: float = 5.0 - outlier_neighbor_count: int = 5 - kdtree_subsample_max_size: int = 2_500_000 - - # hdbscan parameters - min_cluster_size: int = 25 - min_samples: int = 25 - cluster_selection_epsilon: int = 1 - recursive: bool = False - remove_duplicates: bool = False - - # remove large clusters in hdbscan? 
- remove_big_units: bool = False - zstd_big_units: float = 50.0 - - # grid snap parameters - grid_dx: float = 15.0 - grid_dz: float = 15.0 - - # uhd version of density peaks parameters - sigma_local_low: float | None = argfield(default=None, arg_type=float) - sigma_regional_low: float | None = argfield(default=None, arg_type=float) - distance_dependent_noise_density: bool = False - attach_density_feature: bool = False - triage_quantile_per_cluster: float = 0.0 - revert: bool = False - ramp_triage_per_cluster: bool = False - triage_quantile_before_clustering: float = 0.0 - amp_no_triaging_before_clustering: float = 6.0 - amp_no_triaging_after_clustering: float = 8.0 - use_y_triaging: bool = False - remove_small_far_clusters: bool = False - - # -- ensembling - ensemble_strategy: str | None = argfield(default=None, arg_type=str) - chunk_size_s: float = 300.0 - split_merge_ensemble_config: SplitMergeConfig | None = None - - -@dataclass(frozen=True, kw_only=True, slots=True) -class RefinementConfig: - refinement_stragegy: Literal["gmm", "splitmerge"] = "gmm" - - # -- gmm parameters - # noise params - cov_kind = "full" - # feature params - core_radius: float = 35.0 - interpolation_sigma: float = 20.0 - val_proportion: float = 0.25 - max_n_spikes: float | int = argfield(default=np.inf, arg_type=float) - # model params - signal_rank: int = 0 - n_spikes_fit: int = 4096 - n_em_iters: int = 25 - distance_metric: Literal["noise_metric", "kl", "reverse_kl", "symkl"] = "symkl" - distance_normalization_kind: Literal["none", "noise", "channels"] = "noise" - merge_distance_threshold: float = 1.5 - # if None, switches to bimodality - merge_criterion_threshold: float | None = 0.0 - merge_criterion: Literal[ - "heldout_loglik", "heldout_ccl", "loglik", "ccl", "aic", "bic", "icl" - ] = "heldout_ccl" - merge_bimodality_threshold: float = 0.05 - em_converged_prop: float = 0.02 - em_converged_churn: float = 0.01 - em_converged_atol: float = 1e-2 - n_total_iters: int = 3 - - # if someone wants this - split_merge_config: SplitMergeConfig | None = None - - -@dataclass(frozen=True, kw_only=True, slots=True) -class ComputationConfig: - n_jobs_cpu: int = 0 - n_jobs_gpu: int = 0 - executor: str = "ThreadPoolExecutor" - device: str | None = argfield(default=None, arg_type=str) - - def actual_device(self): - if self.device is None: - have_cuda = torch.cuda.is_available() - if have_cuda: - return torch.device("cuda") - return torch.device("cpu") - return torch.device(self.device) - - def actual_n_jobs(self): - if self.actual_device().type == "cuda": - return self.n_jobs_gpu - return self.n_jobs_cpu - - -@dataclass(frozen=True, kw_only=True, slots=True) -class DARTsortInternalConfig: - waveform_config: WaveformConfig = WaveformConfig() - featurization_config: FeaturizationConfig = FeaturizationConfig() - subtraction_config: SubtractionConfig = SubtractionConfig() - template_config: TemplateConfig = TemplateConfig() - clustering_config: ClusteringConfig = ClusteringConfig() - refinement_config: RefinementConfig = RefinementConfig() - matching_config: MatchingConfig = MatchingConfig() - motion_estimation_config: MotionEstimationConfig = MotionEstimationConfig() - computation_config: ComputationConfig = ComputationConfig() - - # high level behavior - subtract_only: bool = False - final_refinement: bool = True - matching_iterations: int = 1 - intermediate_matching_subsampling: float = 1.0 - - def to_internal_config(self): - return self - - -@dataclass(frozen=True, kw_only=True, slots=True) -class DeveloperConfig: - pass - - def 
to_internal_config(self): - return DARTsortInternalConfig() + spatial_bin_length_um: Annotated[float, Field(gt=0)] = 1.0 + temporal_bin_length_s: Annotated[float, Field(gt=0)] = 1.0 + window_step_um: Annotated[float, Field(gt=0)] = 400.0 + window_scale_um: Annotated[float, Field(gt=0)] = 450.0 + window_margin_um: Annotated[float, Field(gt=0)] | None = argfield( + default=None, arg_type=float + ) + max_dt_s: Annotated[float, Field(gt=0)] = 1000.0 + max_disp_um: Annotated[float, Field(gt=0)] | None = argfield( + default=None, arg_type=float + ) + correlation_threshold: Annotated[float, Field(gt=0, lt=1)] = 0.1 @dataclass(frozen=True, kw_only=True, slots=True) -class DARTsortUserConfig: - pass - - def to_internal_config(self): - return DARTsortInternalConfig() +class DeveloperConfig(DARTsortUserConfig): + """Additional parameters for experiments. This API will never be stable.""" - -default_waveform_config = WaveformConfig() -default_featurization_config = FeaturizationConfig() -default_subtraction_config = SubtractionConfig() -default_template_config = TemplateConfig() -default_clustering_config = ClusteringConfig() -default_split_merge_config = SplitMergeConfig() -coarse_template_config = TemplateConfig(superres_templates=False) -raw_template_config = TemplateConfig( - realign_peaks=False, low_rank_denoising=False, superres_templates=False -) -unshifted_raw_template_config = TemplateConfig( - registered_templates=False, - realign_peaks=False, - low_rank_denoising=False, - superres_templates=False, -) -unaligned_coarse_denoised_template_config = TemplateConfig( - realign_peaks=False, low_rank_denoising=True, superres_templates=False -) -default_matching_config = MatchingConfig() -default_motion_estimation_config = MotionEstimationConfig() -default_computation_config = ComputationConfig() -default_dartsort_config = DARTsortInternalConfig() -default_refinement_config = RefinementConfig() + use_nn_in_subtraction: bool = True + use_singlechan_templates: bool = False + use_universal_templates: bool = False + signal_rank: Annotated[int, Field(ge=0)] = 0 diff --git a/src/dartsort/main.py b/src/dartsort/main.py index 42dd78fe..82c12612 100644 --- a/src/dartsort/main.py +++ b/src/dartsort/main.py @@ -9,14 +9,14 @@ DARTsortUserConfig, DARTsortInternalConfig, DeveloperConfig, - default_clustering_config, + to_internal_config, default_dartsort_config, + default_waveform_config, + default_template_config, + default_clustering_config, default_featurization_config, - default_matching_config, - default_split_merge_config, default_subtraction_config, - default_template_config, - default_waveform_config, + default_matching_config, default_computation_config, ) from dartsort.peel import ObjectiveUpdateTemplateMatchingPeeler, SubtractionPeeler @@ -40,7 +40,7 @@ def dartsort( overwrite=False, ): output_directory = Path(output_directory) - cfg = cfg.to_internal_config() + cfg = to_internal_config(cfg) # first step: subtraction and motion estimation sorting, sub_h5 = subtract( @@ -52,6 +52,8 @@ def dartsort( computation_config=cfg.computation_config, overwrite=overwrite, ) + if cfg.subtract_only: + return dict(sorting=sorting) if motion_est is None: motion_est = estimate_motion( recording, @@ -61,8 +63,8 @@ def dartsort( device=cfg.computation_config.actual_device(), **asdict(cfg.motion_estimation_config), ) - if cfg.subtract_only: - return sorting + if cfg.dredge_only: + return dict(sorting=sorting, motion_est=motion_est) # clustering E/M. start by initializing clusters. 
     sorting = initial_clustering(
@@ -80,11 +82,10 @@
         computation_config=cfg.computation_config,
     )
 
-    # E/M iterations
     for step in range(cfg.matching_iterations):
-        # E step: deconvolution
         is_final = step == cfg.matching_iterations - 1
         prop = 1.0 if is_final else cfg.intermediate_matching_subsampling
+
         sorting, match_h5 = match(
             recording,
             sorting,
@@ -110,7 +111,7 @@
     )
 
     # done~
-    return sorting
+    return dict(sorting=sorting, motion_est=motion_est)
 
 
 def subtract(
diff --git a/src/dartsort/peel/subtract.py b/src/dartsort/peel/subtract.py
index 92517007..553ea933 100644
--- a/src/dartsort/peel/subtract.py
+++ b/src/dartsort/peel/subtract.py
@@ -218,6 +218,10 @@ def from_config(
             f"waveform_config {trough_offset_samples=} {spike_length_samples=} "
             f"since {recording.sampling_frequency=}"
         )
+        singlechan_alignment_padding = int(
+            subtraction_config.singlechan_alignment_padding_ms
+            * (recording.sampling_frequency / 1000)
+        )
 
         return cls(
             recording,
@@ -240,7 +244,7 @@
             use_singlechan_templates=subtraction_config.use_singlechan_templates,
             n_singlechan_templates=subtraction_config.n_singlechan_templates,
             singlechan_threshold=subtraction_config.singlechan_threshold,
-            singlechan_alignment_padding=subtraction_config.singlechan_alignment_padding,
+            singlechan_alignment_padding=singlechan_alignment_padding,
         )
 
     def peel_chunk(
diff --git a/src/dartsort/templates/get_templates.py b/src/dartsort/templates/get_templates.py
index 14255ce4..3a9c9a69 100644
--- a/src/dartsort/templates/get_templates.py
+++ b/src/dartsort/templates/get_templates.py
@@ -3,6 +3,7 @@
 The class TemplateData in templates.py provides a friendlier interface,
 where you can get templates using the TemplateConfig in config.py.
 """
+
 from dataclasses import replace
 
 import numpy as np
@@ -28,6 +29,7 @@ def get_templates(
     registered_geom=None,
     realign_peaks=False,
     realign_max_sample_shift=20,
+    realign_to="max",
     low_rank_denoising=True,
     denoising_tsvd=None,
     denoising_rank=5,
@@ -131,6 +133,7 @@
         raw_results["raw_templates"],
         raw_results["snrs_by_channel"],
         raw_results["unit_ids"],
+        realign_to=realign_to,
         max_shift=realign_max_sample_shift,
         trough_offset_samples=trough_offset_samples,
         recording_length_samples=recording.get_num_samples(),
@@ -177,7 +180,14 @@
         min_count_at_shift=min_count_at_shift,
         device=device,
     )
-    unit_ids, spike_counts, raw_templates, low_rank_templates, snrs_by_channel, spike_counts_by_channel = res
+    (
+        unit_ids,
+        spike_counts,
+        raw_templates,
+        low_rank_templates,
+        snrs_by_channel,
+        spike_counts_by_channel,
+    ) = res
 
     if raw_only:
         return dict(
@@ -261,6 +271,7 @@ def realign_sorting(
     templates,
     snrs_by_channel,
     unit_ids,
+    realign_to="trough",
     max_shift=20,
     trough_offset_samples=42,
     recording_length_samples=None,
@@ -273,7 +284,13 @@
     # find template peak time
     template_maxchans = snrs_by_channel.argmax(1)
     template_maxchan_traces = templates[np.arange(n), :, template_maxchans]
-    template_peak_times = np.abs(template_maxchan_traces).argmax(1)
+    if realign_to == "max":
+        template_peak_times = np.abs(template_maxchan_traces).argmax(1)
+    elif realign_to == "trough":
+        # align to the trough: the most negative sample
+ template_peak_times = template_maxchan_traces.argmin(1) + else: + assert False # find unit sample time shifts template_shifts_ = template_peak_times - (trough_offset_samples + max_shift) @@ -356,8 +373,7 @@ def denoising_weights( d=6.0, edge_behavior="saturate", ): - """Weights are applied to raw template, 1-weights to low rank - """ + """Weights are applied to raw template, 1-weights to low rank""" # v shaped function for time weighting vt = np.abs(np.arange(spike_length_samples) - trough_offset, dtype=float) if trough_offset < spike_length_samples: @@ -402,9 +418,7 @@ def get_all_shifted_raw_and_low_rank_templates( dtype=np.float32, device=None, ): - n_jobs, Executor, context, rank_queue = get_pool( - n_jobs, with_rank_queue=True - ) + n_jobs, Executor, context, rank_queue = get_pool(n_jobs, with_rank_queue=True) unit_ids, spike_counts = np.unique(sorting.labels, return_counts=True) spike_counts = spike_counts[unit_ids >= 0] unit_ids = unit_ids[unit_ids >= 0] @@ -428,16 +442,11 @@ def get_all_shifted_raw_and_low_rank_templates( (n_units, spike_length_samples, n_template_channels), dtype=dtype, ) - snrs_by_channel = np.zeros( - (n_units, n_template_channels), dtype=dtype - ) - spike_counts_by_channel = np.zeros( - (n_units, n_template_channels), dtype=dtype - ) + snrs_by_channel = np.zeros((n_units, n_template_channels), dtype=dtype) + spike_counts_by_channel = np.zeros((n_units, n_template_channels), dtype=dtype) unit_id_chunks = [ - unit_ids[i : i + units_per_job] - for i in range(0, n_units, units_per_job) + unit_ids[i : i + units_per_job] for i in range(0, n_units, units_per_job) ] with Executor( @@ -475,7 +484,13 @@ def get_all_shifted_raw_and_low_rank_templates( for res in results: if res is None: continue - units_chunk, raw_temps_chunk, low_rank_temps_chunk, snrs_chunk, chancounts_chunk = res + ( + units_chunk, + raw_temps_chunk, + low_rank_temps_chunk, + snrs_chunk, + chancounts_chunk, + ) = res ix_chunk = np.isin(unit_ids, units_chunk) raw_templates[ix_chunk] = raw_temps_chunk if not raw: @@ -487,7 +502,14 @@ def get_all_shifted_raw_and_low_rank_templates( if show_progress: pbar.close() - return unit_ids, spike_counts, raw_templates, low_rank_templates, snrs_by_channel, spike_counts_by_channel + return ( + unit_ids, + spike_counts, + raw_templates, + low_rank_templates, + snrs_by_channel, + spike_counts_by_channel, + ) class TemplateProcessContext: @@ -520,9 +542,7 @@ def __init__( self.denoising_tsvd = denoising_tsvd if denoising_tsvd is not None: self.denoising_tsvd = TorchSVDProjector( - torch.from_numpy( - denoising_tsvd.components_.astype(dtype) - ) + torch.from_numpy(denoising_tsvd.components_.astype(dtype)) ) self.denoising_tsvd.to(self.device) self.spikes_per_unit = spikes_per_unit @@ -584,9 +604,7 @@ def _template_process_init( device = torch.device(device) if device.type == "cuda" and device.index is None: if torch.cuda.device_count() > 1: - device = torch.device( - "cuda", index=rank % torch.cuda.device_count() - ) + device = torch.device("cuda", index=rank % torch.cuda.device_count()) torch.set_grad_enabled(False) rg = np.random.default_rng(random_seed + rank) diff --git a/src/dartsort/templates/templates.py b/src/dartsort/templates/templates.py index 498f80c1..0d8eb8df 100644 --- a/src/dartsort/templates/templates.py +++ b/src/dartsort/templates/templates.py @@ -170,12 +170,10 @@ def from_config( "TemplateData.from_config needs sorting!=None when its .npz file does not exist." 
) - trough_offset_samples = waveform_config.trough_offset_samples( - recording.sampling_frequency - ) - spike_length_samples = waveform_config.spike_length_samples( - recording.sampling_frequency - ) + fs = recording.sampling_frequency + trough_offset_samples = waveform_config.trough_offset_samples(fs) + spike_length_samples = waveform_config.spike_length_samples(fs) + realign_max_sample_shift = int(template_config.realign_shift_ms * (fs / 1000)) motion_aware = ( template_config.registered_templates or template_config.superres_templates @@ -204,7 +202,7 @@ def from_config( spikes_per_unit=template_config.spikes_per_unit, # realign handled in advance below, not needed in kwargs # realign_peaks=False, - realign_max_sample_shift=template_config.realign_max_sample_shift, + realign_max_sample_shift=realign_max_sample_shift, denoising_rank=template_config.denoising_rank, denoising_fit_radius=template_config.denoising_fit_radius, denoising_snr_threshold=template_config.denoising_snr_threshold, @@ -384,8 +382,6 @@ def get_chunked_templates( waveform_config=waveform_config, tsvd=tsvd, ) - print(f"{full_template_data.unit_ids.shape=}") - print(f"{full_template_data.templates.shape=}") # break it up back into chunks chunk_template_data = [] diff --git a/src/dartsort/util/internal_config.py b/src/dartsort/util/internal_config.py new file mode 100644 index 00000000..ce8d47e5 --- /dev/null +++ b/src/dartsort/util/internal_config.py @@ -0,0 +1,524 @@ +from dataclasses import MISSING, field, fields +import dataclasses +from typing import Literal, Annotated + +import numpy as np +import torch +from pydantic.dataclasses import dataclass + +try: + from importlib.resources import files +except ImportError: + try: + from importlib_resources import files + except ImportError: + raise ValueError("Need python>=3.10 or pip install importlib_resources.") + +default_pretrained_path = files("dartsort.pretrained") +default_pretrained_path = default_pretrained_path.joinpath("single_chan_denoiser.pt") +default_pretrained_path = str(default_pretrained_path) + + +def argfield( + default=MISSING, default_factory=MISSING, arg_type=MISSING, cli=True, doc="" +): + """Helper for defining fields with extended CLI behavior. + + This is only needed when a field's type is not a callable which can + take string inputs and return an object of the right type, such as + typing.Union or something. Then arg_type is what the CLI will call + to convert the argv element into an object of the desired type. + + Fields with cli=False will not be available from the command line. 
+ """ + return field( + default=default, + default_factory=default_factory, + metadata=dict(arg_type=arg_type, cli=cli, doc=""), + ) + + +@dataclass(frozen=True, kw_only=True, slots=True) +class WaveformConfig: + """Defaults yield 42 sample trough offset and 121 total at 30kHz.""" + + ms_before: float = 1.4 + ms_after: float = 2.6 + + def trough_offset_samples(self, sampling_frequency=30_000): + sampling_frequency = np.round(sampling_frequency) + return int(self.ms_before * (sampling_frequency / 1000)) + + def spike_length_samples(self, sampling_frequency=30_000): + spike_len_ms = self.ms_before + self.ms_after + sampling_frequency = np.round(sampling_frequency) + length = int(spike_len_ms * (sampling_frequency / 1000)) + # odd is better for convolution arithmetic elsewhere + length = 2 * (length // 2) + 1 + return length + + def relative_slice(self, other, sampling_frequency=30_000): + """My trough-aligned subset of samples in other, which contains me.""" + assert other.ms_before >= self.ms_before + assert other.ms_after >= self.ms_after + my_trough = self.trough_offset_samples(sampling_frequency) + my_len = self.spike_length_samples(sampling_frequency) + other_trough = other.trough_offset_samples(sampling_frequency) + other_len = other.spike_length_samples(sampling_frequency) + start_offset = other_trough - my_trough + end_offset = (other_len - other_trough) - (my_len - my_trough) + if start_offset == end_offset == 0: + return slice(None) + return slice(start_offset, other_len - end_offset) + + +@dataclass(frozen=True, kw_only=True, slots=True) +class FeaturizationConfig: + """Featurization and denoising configuration + + Parameters for a featurization and denoising pipeline + which has the flow: + [input waveforms] + -> [featurization of input waveforms] + -> [denoising] + -> [featurization of output waveforms] + + The flags below allow users to control which features + are computed for the input waveforms, what denoising + operations are applied, and what features are computed + for the output (post-denoising) waveforms. + + Users who'd rather do something not covered by this + typical case can manually instantiate a WaveformPipeline + and pass it into their peeler. 
+ """ + + skip: bool = False + + # -- denoising configuration + do_nn_denoise: bool = False + do_tpca_denoise: bool = True + do_enforce_decrease: bool = True + # turn off features below + denoise_only: bool = False + + # -- residual snips + n_residual_snips: int = 4096 + + # -- featurization configuration + save_input_voltages: bool = False + save_input_waveforms: bool = False + save_input_tpca_projs: bool = True + save_output_waveforms: bool = False + save_output_tpca_projs: bool = False + save_amplitudes: bool = True + save_all_amplitudes: bool = False + # localization runs on output waveforms + do_localization: bool = True + localization_radius: float = 100.0 + # these are saved always if do_localization + localization_amplitude_type: Literal["peak", "ptp"] = "peak" + localization_model: Literal["pointsource", "dipole"] = "pointsource" + nn_localization: bool = True + + # -- further info about denoising + # in the future we may add multi-channel or other nns + nn_denoiser_class_name: str = "SingleChannelWaveformDenoiser" + nn_denoiser_pretrained_path: str = default_pretrained_path + nn_denoiser_train_epochs: int = 50 + nn_denoiser_extra_kwargs: dict | None = argfield(None, cli=False) + + # optionally restrict how many channels TPCA are fit on + tpca_fit_radius: float = 75.0 + tpca_rank: int = 8 + tpca_centered: bool = False + learn_cleaned_tpca_basis: bool = False + input_tpca_waveform_config: WaveformConfig | None = WaveformConfig( + ms_before=0.75, ms_after=1.25 + ) + + # used when naming datasets saved to h5 files + input_waveforms_name: str = "collisioncleaned" + output_waveforms_name: str = "denoised" + + +@dataclass(frozen=True, kw_only=True, slots=True) +class SubtractionConfig: + detection_threshold: float = 4.0 + chunk_length_samples: int = 30_000 + peak_sign: str = "both" + spatial_dedup_radius: float = 150.0 + subtract_radius: float = 200.0 + extract_radius: float = 100.0 + n_chunks_fit: int = 100 + max_waveforms_fit: int = 50_000 + n_waveforms_fit: int = 20_000 + fit_subsampling_random_state: int = 0 + fit_sampling: str = "random" + residnorm_decrease_threshold: float = 3.162 # sqrt(10) + use_singlechan_templates: bool = False + singlechan_threshold: float = 50.0 + n_singlechan_templates: int = 10 + singlechan_alignment_padding_ms: float = 0.7 + use_universal_templates: bool = False + + # how will waveforms be denoised before subtraction? 
+ # users can also save waveforms/features during subtraction + subtraction_denoising_config: FeaturizationConfig = FeaturizationConfig( + denoise_only=True, + do_nn_denoise=True, + input_waveforms_name="raw", + output_waveforms_name="subtracted", + ) + + +@dataclass(frozen=True, kw_only=True, slots=True) +class MotionEstimationConfig: + """Configure motion estimation.""" + + do_motion_estimation: bool = True + + # sometimes spikes can be localized far away from the probe, causing + # issues with motion estimation, we will ignore such spikes + probe_boundary_padding_um: float = 100.0 + + # DREDge parameters + spatial_bin_length_um: float = 1.0 + temporal_bin_length_s: float = 1.0 + window_step_um: float = 400.0 + window_scale_um: float = 450.0 + window_margin_um: float | None = argfield(default=None, arg_type=float) + max_dt_s: float = 1000.0 + max_disp_um: float | None = argfield(default=None, arg_type=float) + correlation_threshold: float = 0.1 + min_amplitude: float | None = argfield(default=None, arg_type=float) + rigid: bool = False + + +@dataclass(frozen=True, kw_only=True, slots=True) +class TemplateConfig: + spikes_per_unit: int = 500 + + # -- template construction parameters + # registered templates? + registered_templates: bool = True + registered_template_localization_radius_um: float = 100.0 + + # superresolved templates + superres_templates: bool = False + superres_bin_size_um: float = 10.0 + superres_bin_min_spikes: int = 5 + superres_strategy: str = "drift_pitch_loc_bin" + adaptive_bin_size: bool = False + + # low rank denoising? + low_rank_denoising: bool = True + denoising_rank: int = 5 + denoising_snr_threshold: float = 50.0 + denoising_fit_radius: float = 75.0 + + # realignment + realign_peaks: bool = True + realign_shift_ms: float = 0.7 + + # track template over time + time_tracking: bool = False + chunk_size_s: int = 300 + + +@dataclass(frozen=True, kw_only=True, slots=True) +class MatchingConfig: + chunk_length_samples: int = 30_000 + extract_radius: float = 100.0 + n_chunks_fit: int = 100 + max_waveforms_fit: int = 50_000 + n_waveforms_fit: int = 20_000 + fit_subsampling_random_state: int = 0 + fit_sampling: str = "random" + + # template matching parameters + threshold: float = 150.0 + template_svd_compression_rank: int = 10 + template_temporal_upsampling_factor: int = 8 + template_min_channel_amplitude: float = 1.0 + refractory_radius_frames: int = 10 + amplitude_scaling_variance: float = 0.0 + amplitude_scaling_boundary: float = 0.5 + max_iter: int = 1000 + conv_ignore_threshold: float = 5.0 + coarse_approx_error_threshold: float = 0.0 + coarse_objective: bool = True + + +@dataclass(frozen=True, kw_only=True, slots=True) +class SplitMergeConfig: + # -- split + split_strategy: str = "FeatureSplit" + recursive_split: bool = True + split_strategy_kwargs: dict | None = field( + default_factory=lambda: dict(max_spikes=20_000) + ) + + # -- merge + merge_template_config: TemplateConfig = TemplateConfig(superres_templates=False) + linkage: str = "complete" + merge_distance_threshold: float = 0.25 + cross_merge_distance_threshold: float = 0.5 + min_spatial_cosine: float = 0.0 + + +@dataclass(frozen=True, kw_only=True, slots=True) +class ClusteringConfig: + # -- initial clustering + cluster_strategy: str = "dpc" + + # initial clustering features + use_amplitude: bool = True + amp_log_c: float = 5.0 + amp_scale: float = 50.0 + n_main_channel_pcs: int = 0 + pc_scale: float = 10.0 + adaptive_feature_scales: bool = False + + # density peaks parameters + sigma_local: float = 5.0 
+ sigma_regional: float | None = argfield(default=25.0, arg_type=float) + workers: int = -1 + n_neighbors_search: int = 20 + radius_search: float = 5.0 + remove_clusters_smaller_than: int = 10 + noise_density: float = 0.0 + outlier_radius: float = 5.0 + outlier_neighbor_count: int = 5 + kdtree_subsample_max_size: int = 2_500_000 + + # hdbscan parameters + min_cluster_size: int = 25 + min_samples: int = 25 + cluster_selection_epsilon: int = 1 + recursive: bool = False + remove_duplicates: bool = False + + # remove large clusters in hdbscan? + remove_big_units: bool = False + zstd_big_units: float = 50.0 + + # grid snap parameters + grid_dx: float = 15.0 + grid_dz: float = 15.0 + + # uhd version of density peaks parameters + sigma_local_low: float | None = argfield(default=None, arg_type=float) + sigma_regional_low: float | None = argfield(default=None, arg_type=float) + distance_dependent_noise_density: bool = False + attach_density_feature: bool = False + triage_quantile_per_cluster: float = 0.0 + revert: bool = False + ramp_triage_per_cluster: bool = False + triage_quantile_before_clustering: float = 0.0 + amp_no_triaging_before_clustering: float = 6.0 + amp_no_triaging_after_clustering: float = 8.0 + use_y_triaging: bool = False + remove_small_far_clusters: bool = False + + # -- ensembling + ensemble_strategy: str | None = argfield(default=None, arg_type=str) + chunk_size_s: float = 300.0 + split_merge_ensemble_config: SplitMergeConfig | None = None + + +@dataclass(frozen=True, kw_only=True, slots=True) +class RefinementConfig: + refinement_stragegy: Literal["gmm", "splitmerge"] = "gmm" + + # -- gmm parameters + # noise params + cov_kind = "full" + # feature params + core_radius: float = 35.0 + interpolation_sigma: float = 20.0 + val_proportion: float = 0.25 + max_n_spikes: float | int = argfield(default=np.inf, arg_type=float) + # model params + signal_rank: int = 0 + n_spikes_fit: int = 4096 + n_em_iters: int = 25 + distance_metric: Literal["noise_metric", "kl", "reverse_kl", "symkl"] = "symkl" + distance_normalization_kind: Literal["none", "noise", "channels"] = "noise" + merge_distance_threshold: float = 1.5 + # if None, switches to bimodality + merge_criterion_threshold: float | None = 0.0 + merge_criterion: Literal[ + "heldout_loglik", "heldout_ccl", "loglik", "ccl", "aic", "bic", "icl" + ] = "heldout_ccl" + merge_bimodality_threshold: float = 0.05 + em_converged_prop: float = 0.02 + em_converged_churn: float = 0.01 + em_converged_atol: float = 1e-2 + n_total_iters: int = 3 + + # if someone wants this + split_merge_config: SplitMergeConfig | None = None + + +@dataclass(frozen=True, kw_only=True, slots=True) +class ComputationConfig: + n_jobs_cpu: int = 0 + n_jobs_gpu: int = 0 + executor: str = "threading_unless_multigpu" + device: str | None = argfield(default=None, arg_type=str) + + def actual_device(self): + if self.device is None: + have_cuda = torch.cuda.is_available() + if have_cuda: + return torch.device("cuda") + return torch.device("cpu") + return torch.device(self.device) + + def actual_n_jobs(self): + if self.actual_device().type == "cuda": + return self.n_jobs_gpu + return self.n_jobs_cpu + + def is_multi_gpu(self): + if self.n_jobs_gpu in (0, 1): + return False + dev = self.actual_device() + if dev.type != "cuda": + return False + if dev.index is not None: + return False + return torch.cuda.device_count() > 1 + + +@dataclass(frozen=True, kw_only=True, slots=True) +class DARTsortInternalConfig: + """This is an internal object. 
Make a DARTsortUserConfig, not one of these.""" + + waveform_config: WaveformConfig = WaveformConfig() + featurization_config: FeaturizationConfig = FeaturizationConfig() + subtraction_config: SubtractionConfig = SubtractionConfig() + template_config: TemplateConfig = TemplateConfig() + clustering_config: ClusteringConfig = ClusteringConfig() + refinement_config: RefinementConfig = RefinementConfig() + matching_config: MatchingConfig = MatchingConfig() + motion_estimation_config: MotionEstimationConfig = MotionEstimationConfig() + computation_config: ComputationConfig = ComputationConfig() + + # high level behavior + subtract_only: bool = False + dredge_only: bool = False + final_refinement: bool = True + matching_iterations: int = 1 + intermediate_matching_subsampling: float = 1.0 + + +default_waveform_config = WaveformConfig() +default_featurization_config = FeaturizationConfig() +default_subtraction_config = SubtractionConfig() +default_template_config = TemplateConfig() +default_clustering_config = ClusteringConfig() +default_split_merge_config = SplitMergeConfig() +coarse_template_config = TemplateConfig(superres_templates=False) +raw_template_config = TemplateConfig( + realign_peaks=False, low_rank_denoising=False, superres_templates=False +) +unshifted_raw_template_config = TemplateConfig( + registered_templates=False, + realign_peaks=False, + low_rank_denoising=False, + superres_templates=False, +) +unaligned_coarse_denoised_template_config = TemplateConfig( + realign_peaks=False, low_rank_denoising=True, superres_templates=False +) +default_matching_config = MatchingConfig() +default_motion_estimation_config = MotionEstimationConfig() +default_computation_config = ComputationConfig() +default_dartsort_config = DARTsortInternalConfig() +default_refinement_config = RefinementConfig() + + +def to_internal_config(cfg): + from dartsort.config import DARTsortUserConfig, DeveloperConfig + + if isinstance(cfg, DARTsortInternalConfig): + return cfg + else: + assert isinstance(cfg, (DARTsortUserConfig, DeveloperConfig)) + + # if we have a user cfg, dump into dev cfg, and work from there + if isinstance(cfg, DARTsortUserConfig): + cfg = DeveloperConfig(**dataclasses.asdict(cfg)) + + waveform_config = WaveformConfig(ms_before=cfg.ms_before, ms_after=cfg.ms_after) + tpca_waveform_config = WaveformConfig( + ms_before=cfg.feature_ms_before, ms_after=cfg.feature_ms_after + ) + featurization_config = FeaturizationConfig( + tpca_rank=cfg.temporal_pca_rank, + input_tpca_waveform_config=tpca_waveform_config, + localization_radius=cfg.localization_radius_um, + tpca_fit_radius=cfg.fit_radius_um, + ) + subtraction_denoising_config = FeaturizationConfig( + denoise_only=True, + do_nn_denoise=cfg.use_nn_in_subtraction, + tpca_rank=cfg.temporal_pca_rank, + tpca_fit_radius=cfg.fit_radius_um, + ) + subtraction_config = SubtractionConfig( + detection_threshold=cfg.initial_threshold, + spatial_dedup_radius=cfg.deduplication_radius_um, + subtract_radius=cfg.subtraction_radius_um, + extract_radius=cfg.featurization_radius_um, + singlechan_alignment_padding_ms=cfg.alignment_ms, + use_singlechan_templates=cfg.use_singlechan_templates, + use_universal_templates=cfg.use_universal_templates, + subtraction_denoising_config=subtraction_denoising_config, + residnorm_decrease_threshold=np.sqrt( + cfg.denoiser_badness_factor * cfg.matching_threshold**2 + ), + ) + template_config = TemplateConfig( + registered_template_localization_radius_um=cfg.localization_radius_um, + denoising_fit_radius=cfg.fit_radius_um, + 
realign_shift_ms=cfg.alignment_ms,
+    )
+    clustering_config = ClusteringConfig(
+        sigma_local=cfg.density_bandwidth,
+        sigma_regional=5 * cfg.density_bandwidth,
+        outlier_radius=cfg.density_bandwidth,
+        radius_search=cfg.density_bandwidth,
+    )
+    refinement_config = RefinementConfig(
+        signal_rank=cfg.signal_rank, interpolation_sigma=cfg.interpolation_bandwidth
+    )
+    motion_estimation_config = MotionEstimationConfig(
+        # the user config exposes only a subset of these fields; the
+        # others keep their MotionEstimationConfig defaults
+        **{
+            k.name: getattr(cfg, k.name)
+            for k in fields(MotionEstimationConfig)
+            if hasattr(cfg, k.name)
+        }
+    )
+    matching_config = MatchingConfig(
+        threshold=cfg.matching_threshold,
+        amplitude_scaling_variance=cfg.amplitude_scaling_stddev**2,
+        amplitude_scaling_boundary=cfg.amplitude_scaling_limit,
+        template_temporal_upsampling_factor=cfg.temporal_upsamples,
+        extract_radius=cfg.featurization_radius_um,
+    )
+    computation_config = ComputationConfig(
+        n_jobs_cpu=cfg.n_jobs_cpu, n_jobs_gpu=cfg.n_jobs_gpu, device=cfg.device
+    )
+
+    return DARTsortInternalConfig(
+        waveform_config=waveform_config,
+        featurization_config=featurization_config,
+        subtraction_config=subtraction_config,
+        template_config=template_config,
+        clustering_config=clustering_config,
+        refinement_config=refinement_config,
+        matching_config=matching_config,
+        motion_estimation_config=motion_estimation_config,
+        computation_config=computation_config,
+        dredge_only=cfg.dredge_only,
+        matching_iterations=cfg.matching_iterations,
+    )
diff --git a/src/dartsort/util/multiprocessing_util.py b/src/dartsort/util/multiprocessing_util.py
index db34863c..f6ea793f 100644
--- a/src/dartsort/util/multiprocessing_util.py
+++ b/src/dartsort/util/multiprocessing_util.py
@@ -107,6 +107,7 @@ def __init__(self):
 
 
 def cloudpickle_run(fn, args):
+    assert cloudpickle is not None
     fn = cloudpickle.loads(fn)
     args, kwargs = cloudpickle.loads(args)
     # return cloudpickle.dumps(fn(*args, **kwargs))
@@ -120,6 +121,7 @@
 
 class CloudpicklePoolExecutor(ProcessPoolExecutor):
     def submit(self, fn, /, *args, **kwargs):
+        assert cloudpickle is not None
         args = cloudpickle.dumps((args, kwargs))
         future = super().submit(cloudpickle_run, cloudpickle.dumps(fn), args)
         # future.add_done_callback(uncloudpickle_callback)
@@ -141,6 +143,7 @@ def pool_from_cfg(computation_config=None, with_rank_queue=False, check_local=Fa
         cls=computation_config.executor,
         with_rank_queue=with_rank_queue,
         check_local=check_local,
+        multi_gpu=computation_config.is_multi_gpu(),
     )
 
 
@@ -153,6 +156,7 @@ def get_pool(
     n_tasks=None,
     max_tasks_per_child=None,
     check_local=False,
+    multi_gpu=False,
 ):
     if n_jobs == -1:
         n_jobs = multiprocessing.cpu_count()
@@ -160,6 +164,11 @@
         n_jobs = max(1, n_jobs)
 
     if isinstance(cls, str):
+        if cls == "threading_unless_multigpu":
+            if n_jobs > 1 and multi_gpu:
+                cls = "ProcessPoolExecutor"
+            else:
+                cls = "ThreadPoolExecutor"
         if cls == "CloudpicklePoolExecutor":
            cls = CloudpicklePoolExecutor
        elif cls == "ThreadPoolExecutor":
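
-- 
A sketch of the intended end-to-end usage after this refactor. Illustrative
only: dartsort()'s full parameter list is not visible in the hunks above, so
passing `recording` positionally and `cfg` by keyword is an assumption, and
`rec` stands in for whatever recording object the pipeline accepts.
to_internal_config is re-exported by dartsort.config via its star import.

    from dartsort.main import dartsort
    from dartsort.config import DARTsortUserConfig, to_internal_config

    # user configs are frozen, keyword-only, and validated by pydantic
    cfg = DARTsortUserConfig(initial_threshold=5.0, n_jobs_gpu=2)

    # dartsort() lowers this to the internal config itself, but the
    # mapping can also be inspected directly:
    internal = to_internal_config(cfg)
    assert internal.subtraction_config.detection_threshold == 5.0
    assert internal.computation_config.n_jobs_gpu == 2

    # dartsort() now returns a dict rather than a bare sorting
    results = dartsort(rec, output_directory="dartsort_out", cfg=cfg)
    sorting, motion_est = results["sorting"], results["motion_est"]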