@@ -38,11 +38,14 @@ def dartsort(
38
38
) = default_dartsort_config ,
39
39
motion_est = None ,
40
40
overwrite = False ,
41
+ return_extra = False ,
41
42
):
42
43
output_directory = Path (output_directory )
43
44
cfg = to_internal_config (cfg )
44
45
45
- # first step: subtraction and motion estimation
46
+ ret = {}
47
+
48
+ # first step: initial detection and motion estimation
46
49
sorting , sub_h5 = subtract (
47
50
recording ,
48
51
output_directory ,
@@ -52,8 +55,13 @@ def dartsort(
52
55
computation_config = cfg .computation_config ,
53
56
overwrite = overwrite ,
54
57
)
58
+ if return_extra :
59
+ ret ["initial_detection" ] = sorting
60
+
55
61
if cfg .subtract_only :
56
- return dict (sorting = sorting )
62
+ ret ["sorting" ] = sorting
63
+ return ret
64
+
57
65
if motion_est is None :
58
66
motion_est = estimate_motion (
59
67
recording ,
@@ -63,25 +71,33 @@ def dartsort(
63
71
device = cfg .computation_config .actual_device (),
64
72
** asdict (cfg .motion_estimation_config ),
65
73
)
74
+ ret ["motion_est" ] = motion_est
75
+
66
76
if cfg .dredge_only :
67
- return dict (sorting = sorting , motion_est = motion_est )
77
+ ret ["sorting" ] = sorting
78
+ return ret
68
79
69
- # clustering E/M. start by initializing clusters.
80
+ # clustering
70
81
sorting = initial_clustering (
71
82
recording ,
72
83
sorting = sorting ,
73
84
motion_est = motion_est ,
74
85
clustering_config = cfg .clustering_config ,
75
86
computation_config = cfg .computation_config ,
76
87
)
88
+ if return_extra :
89
+ ret ["initial_labels" ] = sorting .labels
77
90
sorting = refine_clustering (
78
91
recording = recording ,
79
92
sorting = sorting ,
80
93
motion_est = motion_est ,
81
94
refinement_config = cfg .refinement_config ,
82
95
computation_config = cfg .computation_config ,
83
96
)
97
+ if return_extra :
98
+ ret ["refined_labels" ] = sorting .labels
84
99
100
+ # alternate matching with
85
101
for step in range (cfg .matching_iterations ):
86
102
is_final = step == cfg .matching_iterations - 1
87
103
prop = 1.0 if is_final else cfg .intermediate_matching_subsampling
@@ -100,6 +116,8 @@ def dartsort(
100
116
hdf5_filename = f"matching{ step } .h5" ,
101
117
model_subdir = f"matching{ step } _models" ,
102
118
)
119
+ if return_extra :
120
+ ret [f"matching{ step } " ] = sorting
103
121
104
122
if (not is_final ) or cfg .final_refinement :
105
123
sorting = refine_clustering (
@@ -109,9 +127,12 @@ def dartsort(
109
127
refinement_config = cfg .refinement_config ,
110
128
computation_config = cfg .computation_config ,
111
129
)
130
+ if return_extra :
131
+ ret [f"refined{ step } _labels" ] = sorting .labels
112
132
113
133
# done~
114
- return dict (sorting = sorting , motion_est = motion_est )
134
+ ret ["sorting" ] = sorting
135
+ return ret
115
136
116
137
117
138
def subtract (
0 commit comments