Skip to content

Commit 0cbdd00

Browse files
committed
Return more details
1 parent 0596b09 commit 0cbdd00

File tree

1 file changed

+26
-5
lines changed

1 file changed

+26
-5
lines changed

src/dartsort/main.py

Lines changed: 26 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -38,11 +38,14 @@ def dartsort(
3838
) = default_dartsort_config,
3939
motion_est=None,
4040
overwrite=False,
41+
return_extra=False,
4142
):
4243
output_directory = Path(output_directory)
4344
cfg = to_internal_config(cfg)
4445

45-
# first step: subtraction and motion estimation
46+
ret = {}
47+
48+
# first step: initial detection and motion estimation
4649
sorting, sub_h5 = subtract(
4750
recording,
4851
output_directory,
@@ -52,8 +55,13 @@ def dartsort(
5255
computation_config=cfg.computation_config,
5356
overwrite=overwrite,
5457
)
58+
if return_extra:
59+
ret["initial_detection"] = sorting
60+
5561
if cfg.subtract_only:
56-
return dict(sorting=sorting)
62+
ret["sorting"] = sorting
63+
return ret
64+
5765
if motion_est is None:
5866
motion_est = estimate_motion(
5967
recording,
@@ -63,25 +71,33 @@ def dartsort(
6371
device=cfg.computation_config.actual_device(),
6472
**asdict(cfg.motion_estimation_config),
6573
)
74+
ret["motion_est"] = motion_est
75+
6676
if cfg.dredge_only:
67-
return dict(sorting=sorting, motion_est=motion_est)
77+
ret["sorting"] = sorting
78+
return ret
6879

69-
# clustering E/M. start by initializing clusters.
80+
# clustering
7081
sorting = initial_clustering(
7182
recording,
7283
sorting=sorting,
7384
motion_est=motion_est,
7485
clustering_config=cfg.clustering_config,
7586
computation_config=cfg.computation_config,
7687
)
88+
if return_extra:
89+
ret["initial_labels"] = sorting.labels
7790
sorting = refine_clustering(
7891
recording=recording,
7992
sorting=sorting,
8093
motion_est=motion_est,
8194
refinement_config=cfg.refinement_config,
8295
computation_config=cfg.computation_config,
8396
)
97+
if return_extra:
98+
ret["refined_labels"] = sorting.labels
8499

100+
# alternate matching with
85101
for step in range(cfg.matching_iterations):
86102
is_final = step == cfg.matching_iterations - 1
87103
prop = 1.0 if is_final else cfg.intermediate_matching_subsampling
@@ -100,6 +116,8 @@ def dartsort(
100116
hdf5_filename=f"matching{step}.h5",
101117
model_subdir=f"matching{step}_models",
102118
)
119+
if return_extra:
120+
ret[f"matching{step}"] = sorting
103121

104122
if (not is_final) or cfg.final_refinement:
105123
sorting = refine_clustering(
@@ -109,9 +127,12 @@ def dartsort(
109127
refinement_config=cfg.refinement_config,
110128
computation_config=cfg.computation_config,
111129
)
130+
if return_extra:
131+
ret[f"refined{step}_labels"] = sorting.labels
112132

113133
# done~
114-
return dict(sorting=sorting, motion_est=motion_est)
134+
ret["sorting"] = sorting
135+
return ret
115136

116137

117138
def subtract(

0 commit comments

Comments
 (0)