Skip to content

Commit

Permalink
Some debugging
Browse files Browse the repository at this point in the history
  • Loading branch information
cwindolf committed Dec 20, 2024
1 parent 0cbdd00 commit 4f0e2e2
Show file tree
Hide file tree
Showing 3 changed files with 10 additions and 2 deletions.
3 changes: 3 additions & 0 deletions src/dartsort/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ class DARTsortUserConfig:
"or 'cuda' or 'cuda:1'. If unset, uses n_jobs_gpu of your CUDA "
"GPUs if you have multiple, or else just the one, or your CPU.",
)
executor: str = "threading_unless_multigpu"

# -- waveform snippet length parameters
ms_before: Annotated[float, Field(gt=0)] = argfield(
Expand Down Expand Up @@ -128,6 +129,7 @@ class DARTsortUserConfig:
rigid: bool = argfield(
default=False, doc="Use rigid registration and ignore the window parameters."
)
probe_boundary_padding_um: float = 100.0
spatial_bin_length_um: Annotated[float, Field(gt=0)] = 1.0
temporal_bin_length_s: Annotated[float, Field(gt=0)] = 1.0
window_step_um: Annotated[float, Field(gt=0)] = 400.0
Expand All @@ -140,6 +142,7 @@ class DARTsortUserConfig:
default=None, arg_type=float
)
correlation_threshold: Annotated[float, Field(gt=0, lt=1)] = 0.1
min_amplitude: float | None = argfield(default=None, arg_type=float)


@dataclass(frozen=True, kw_only=True, slots=True)
Expand Down
5 changes: 4 additions & 1 deletion src/dartsort/util/internal_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -510,7 +510,10 @@ def to_internal_config(cfg):
extract_radius=cfg.featurization_radius_um,
)
computation_config = ComputationConfig(
n_jobs_cpu=cfg.n_jobs_cpu, n_jobs_gpu=cfg.n_jobs_gpu, device=cfg.device
n_jobs_cpu=cfg.n_jobs_cpu,
n_jobs_gpu=cfg.n_jobs_gpu,
device=cfg.device,
executor=cfg.executor,
)

return DARTsortInternalConfig(
Expand Down
4 changes: 3 additions & 1 deletion src/dartsort/util/universal_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -189,12 +189,14 @@ def singlechan_to_library(
max_distance=max_distance,
dx=dx,
)
if torch.is_tensor(singlechan_templates):
singlechan_templates = singlechan_templates.numpy(force=True)
nf, nc = footprints.shape
nsct, nt = singlechan_templates.shape
templates = footprints[:, None, None, :] * singlechan_templates[None, :, :, None]
assert templates.shape == (nf, nsct, nt, nc)
templates = templates.reshape(nf * nsct, nt, nc)
templates /= np.linalg.norm(templates, axis=(1, 2))
templates /= np.linalg.norm(templates, axis=(1, 2), keepdims=True)

return TemplateData(
templates,
Expand Down

0 comments on commit 4f0e2e2

Please sign in to comment.