@@ -10,6 +10,7 @@ def forward_backward(
     feature_scales=(1, 1, 50),
     adaptive_feature_scales=False,
     motion_est=None,
+    verbose=True,
 ):
     """
     Ensemble over HDBSCAN clustering
@@ -19,21 +20,23 @@ def forward_backward(
         return chunk_sortings[0]
 
     times_seconds = chunk_sortings[0].times_seconds
-    times_samples = chunk_sortings[0].times_samples
+
     min_time_s = chunk_time_ranges_s[0][0]
     idx_all_chunks = [get_indices_in_chunk(times_seconds, chunk_range) for chunk_range in chunk_time_ranges_s]
 
     # put all labels into one array
     # TODO: this does not allow for overlapping chunks.
-    labels_all = np.full_like(times_samples, -1)
+    labels_all = np.full_like(times_seconds, -1)
     for ix, sorting in zip(idx_all_chunks, chunk_sortings):
-        assert labels_all[ix].max() < 0  # assert non-overlapping
-        labels_all[ix] = sorting.labels[ix]
+        if len(ix):
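+            # skip empty chunks: .max() on an empty selection would raise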
+            assert labels_all[ix].max() < 0  # assert non-overlapping
+            labels_all[ix] = sorting.labels[ix]
 
     # load features that we will need
     # needs to be all features here
     amps = chunk_sortings[0].denoised_ptp_amplitudes
     xyza = chunk_sortings[0].point_source_localizations
+
     x = xyza[:, 0]
     z_reg = xyza[:, 2]
 
@@ -44,7 +47,11 @@ def forward_backward(
     if motion_est is not None:
         z_reg = motion_est.correct_s(times_seconds, z_reg)
 
-    for k in trange(len(chunk_sortings) - 1, desc="Ensembling chunks"):
+    if verbose is True:
+        tbar = trange(len(chunk_sortings) - 1, desc="Ensembling chunks")
+    else:
+        tbar = range(len(chunk_sortings) - 1)
+    for k in tbar:
         # CHANGE THE 1 ---
         # idx_1 = np.flatnonzero(np.logical_and(times_seconds>=min_time_s, times_seconds<min_time_s+k*shift+chunk_size_s))
         idx_1 = np.flatnonzero(
@@ -62,160 +69,162 @@ def forward_backward(
         amps_2 = feature_scales[2] * np.log(log_c + amps[idx_2])
         labels_1 = labels_all[idx_1].copy().astype("int")
         labels_2 = chunk_sortings[k + 1].labels[idx_2]
-        unit_label_shift = int(labels_1.max() + 1)
-        labels_2[labels_2 > -1] += unit_label_shift
 
         units_1 = np.unique(labels_1)
         units_1 = units_1[units_1 > -1]
         units_2 = np.unique(labels_2)
         units_2 = units_2[units_2 > -1]
 
-        # FORWARD PASS
-
-        dist_matrix = np.zeros((units_1.shape[0], units_2.shape[0]))
-
-        # Speed up this code - this matrix can be sparse (only compute distance for "neighboring" units) - OK for now, still pretty fast
-        for i in range(units_1.shape[0]):
-            unit_1 = units_1[i]
-            for j in range(units_2.shape[0]):
-                unit_2 = units_2[j]
-                feat_1 = np.c_[
-                    np.median(x_1[labels_1 == unit_1]),
-                    np.median(z_1[labels_1 == unit_1]),
-                    np.median(amps_1[labels_1 == unit_1]),
-                ]
-                feat_2 = np.c_[
-                    np.median(x_2[labels_2 == unit_2]),
-                    np.median(z_2[labels_2 == unit_2]),
-                    np.median(amps_2[labels_2 == unit_2]),
-                ]
-                dist_matrix[i, j] = ((feat_1 - feat_2) ** 2).sum()
-
-        # find for chunk 2 units the closest units in chunk 1 and split chunk 1 units
-        dist_forward = dist_matrix.argmin(0)
-        units_, counts_ = np.unique(dist_forward, return_counts=True)
-
-        for unit_to_split in units_[counts_ > 1]:
-            units_to_match_to = (
-                np.flatnonzero(dist_forward == unit_to_split) + unit_label_shift
-            )
-            features_to_match_to = np.c_[
-                np.median(x_2[labels_2 == units_to_match_to[0]]),
-                np.median(z_2[labels_2 == units_to_match_to[0]]),
-                np.median(amps_2[labels_2 == units_to_match_to[0]]),
-            ]
-            for u in units_to_match_to[1:]:
-                features_to_match_to = np.concatenate(
-                    (
-                        features_to_match_to,
-                        np.c_[
-                            np.median(x_2[labels_2 == u]),
-                            np.median(z_2[labels_2 == u]),
-                            np.median(amps_2[labels_2 == u]),
-                        ],
-                    )
-                )
-            spikes_to_update = np.flatnonzero(labels_1 == unit_to_split)
-            x_s_to_update = x_1[spikes_to_update]
-            z_s_to_update = z_1[spikes_to_update]
-            amps_s_to_update = amps_1[spikes_to_update]
-            for j, s in enumerate(spikes_to_update):
-                # Don't update if new distance is too high?
-                feat_s = np.c_[
-                    x_s_to_update[j], z_s_to_update[j], amps_s_to_update[j]
-                ]
-                labels_1[s] = units_to_match_to[
-                    ((feat_s - features_to_match_to) ** 2).sum(1).argmin()
-                ]
-
-        # Relabel labels_1 and labels_2
-        for unit_to_relabel in units_:
-            if counts_[np.flatnonzero(units_ == unit_to_relabel)][0] == 1:
-                idx_to_relabel = np.flatnonzero(labels_1 == unit_to_relabel)
-                labels_1[idx_to_relabel] = units_2[dist_forward == unit_to_relabel]
-
-        # BACKWARD PASS
-
-        units_not_matched = np.unique(labels_1)
-        units_not_matched = units_not_matched[units_not_matched > -1]
-        units_not_matched = units_not_matched[units_not_matched < unit_label_shift]
-
-        if len(units_not_matched):
-            all_units_to_match_to = (
-                dist_matrix[units_not_matched].argmin(1) + unit_label_shift
-            )
-            for unit_to_split in np.unique(all_units_to_match_to):
-                units_to_match_to = np.concatenate(
-                    (
-                        units_not_matched[all_units_to_match_to == unit_to_split],
-                        [unit_to_split],
-                    )
+        if len(units_2) and len(units_1):
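+            # offset the new chunk's labels so they cannot collide with chunk k's unit ids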
+            unit_label_shift = int(labels_1.max() + 1)
+            labels_2[labels_2 > -1] += unit_label_shift
+            units_2 += unit_label_shift
+
+            # FORWARD PASS
+            dist_matrix = np.zeros((units_1.shape[0], units_2.shape[0]))
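+            # dist_matrix[i, j]: squared distance between the median scaled (x, z, log-amp) features of unit i in chunk k and unit j in chunk k + 1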
+
+            # Speed up this code - this matrix can be sparse (only compute distance for "neighboring" units) - OK for now, still pretty fast
+            for i in range(units_1.shape[0]):
+                unit_1 = units_1[i]
+                for j in range(units_2.shape[0]):
+                    unit_2 = units_2[j]
+                    feat_1 = np.c_[
+                        np.median(x_1[labels_1 == unit_1]),
+                        np.median(z_1[labels_1 == unit_1]),
+                        np.median(amps_1[labels_1 == unit_1]),
+                    ]
+                    feat_2 = np.c_[
+                        np.median(x_2[labels_2 == unit_2]),
+                        np.median(z_2[labels_2 == unit_2]),
+                        np.median(amps_2[labels_2 == unit_2]),
+                    ]
+                    dist_matrix[i, j] = ((feat_1 - feat_2) ** 2).sum()
+
+            # find for chunk 2 units the closest units in chunk 1 and split chunk 1 units
+            dist_forward = dist_matrix.argmin(0)
+            units_, counts_ = np.unique(dist_forward, return_counts=True)
+
+            for unit_to_split in units_[counts_ > 1]:
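+                # several chunk k + 1 units share this unit as nearest match: split its spikes among them, per-spike nearest centroid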
+                units_to_match_to = (
+                    np.flatnonzero(dist_forward == unit_to_split) + unit_label_shift
                 )
-
                 features_to_match_to = np.c_[
-                    np.median(x_1[labels_1 == units_to_match_to[0]]),
-                    np.median(z_1[labels_1 == units_to_match_to[0]]),
-                    np.median(amps_1[labels_1 == units_to_match_to[0]]),
+                    np.median(x_2[labels_2 == units_to_match_to[0]]),
+                    np.median(z_2[labels_2 == units_to_match_to[0]]),
+                    np.median(amps_2[labels_2 == units_to_match_to[0]]),
                 ]
                 for u in units_to_match_to[1:]:
                     features_to_match_to = np.concatenate(
                         (
                             features_to_match_to,
                             np.c_[
-                                np.median(x_1[labels_1 == u]),
-                                np.median(z_1[labels_1 == u]),
-                                np.median(amps_1[labels_1 == u]),
+                                np.median(x_2[labels_2 == u]),
+                                np.median(z_2[labels_2 == u]),
+                                np.median(amps_2[labels_2 == u]),
                             ],
                         )
                     )
-                spikes_to_update = np.flatnonzero(labels_2 == unit_to_split)
-                x_s_to_update = x_2[spikes_to_update]
-                z_s_to_update = z_2[spikes_to_update]
-                amps_s_to_update = amps_2[spikes_to_update]
+                spikes_to_update = np.flatnonzero(labels_1 == unit_to_split)
+                x_s_to_update = x_1[spikes_to_update]
+                z_s_to_update = z_1[spikes_to_update]
+                amps_s_to_update = amps_1[spikes_to_update]
                 for j, s in enumerate(spikes_to_update):
+                    # Don't update if new distance is too high?
                     feat_s = np.c_[
                         x_s_to_update[j], z_s_to_update[j], amps_s_to_update[j]
                     ]
-                    labels_2[s] = units_to_match_to[
+                    labels_1[s] = units_to_match_to[
                         ((feat_s - features_to_match_to) ** 2).sum(1).argmin()
                     ]
-
-        # Do we need to "regularize" and make sure the distance intra units after merging is smaller than the distance inter units before merging
-        all_labels_1 = np.unique(labels_1)
-        all_labels_1 = all_labels_1[all_labels_1 > -1]
-
-        features_all_1 = np.c_[
-            np.median(x_1[labels_1 == all_labels_1[0]]),
-            np.median(z_1[labels_1 == all_labels_1[0]]),
-            np.median(amps_1[labels_1 == all_labels_1[0]]),
-        ]
-        for u in all_labels_1[1:]:
-            features_all_1 = np.concatenate(
-                (
-                    features_all_1,
-                    np.c_[
-                        np.median(x_1[labels_1 == u]),
-                        np.median(z_1[labels_1 == u]),
-                        np.median(amps_1[labels_1 == u]),
-                    ],
+
+            # Relabel labels_1 and labels_2
+            for unit_to_relabel in units_:
+                if counts_[np.flatnonzero(units_ == unit_to_relabel)][0] == 1:
+                    idx_to_relabel = np.flatnonzero(labels_1 == unit_to_relabel)
+                    labels_1[idx_to_relabel] = units_2[dist_forward == unit_to_relabel]
+
+            # BACKWARD PASS
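+            # chunk k units that no chunk k + 1 unit matched to are folded into their nearest chunk k + 1 unit, spike by spike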
+
+            units_not_matched = np.unique(labels_1)
+            units_not_matched = units_not_matched[units_not_matched > -1]
+            units_not_matched = units_not_matched[units_not_matched < unit_label_shift]
+
+            if len(units_not_matched):
+                all_units_to_match_to = (
+                    dist_matrix[units_not_matched].argmin(1) + unit_label_shift
                 )
+                for unit_to_split in np.unique(all_units_to_match_to):
+                    units_to_match_to = np.concatenate(
+                        (
+                            units_not_matched[all_units_to_match_to == unit_to_split],
+                            [unit_to_split],
+                        )
+                    )
+
+                    features_to_match_to = np.c_[
+                        np.median(x_1[labels_1 == units_to_match_to[0]]),
+                        np.median(z_1[labels_1 == units_to_match_to[0]]),
+                        np.median(amps_1[labels_1 == units_to_match_to[0]]),
+                    ]
+                    for u in units_to_match_to[1:]:
+                        features_to_match_to = np.concatenate(
+                            (
+                                features_to_match_to,
+                                np.c_[
+                                    np.median(x_1[labels_1 == u]),
+                                    np.median(z_1[labels_1 == u]),
+                                    np.median(amps_1[labels_1 == u]),
+                                ],
+                            )
+                        )
+                    spikes_to_update = np.flatnonzero(labels_2 == unit_to_split)
+                    x_s_to_update = x_2[spikes_to_update]
+                    z_s_to_update = z_2[spikes_to_update]
+                    amps_s_to_update = amps_2[spikes_to_update]
+                    for j, s in enumerate(spikes_to_update):
+                        feat_s = np.c_[
+                            x_s_to_update[j], z_s_to_update[j], amps_s_to_update[j]
+                        ]
+                        labels_2[s] = units_to_match_to[
+                            ((feat_s - features_to_match_to) ** 2).sum(1).argmin()
+                        ]
+
+            # Do we need to "regularize" and make sure the distance intra units after merging is smaller than the distance inter units before merging
+            # all_labels_1 = np.unique(labels_1)
+            # all_labels_1 = all_labels_1[all_labels_1 > -1]
+
+            # features_all_1 = np.c_[
+            #     np.median(x_1[labels_1 == all_labels_1[0]]),
+            #     np.median(z_1[labels_1 == all_labels_1[0]]),
+            #     np.median(amps_1[labels_1 == all_labels_1[0]]),
+            # ]
+            # for u in all_labels_1[1:]:
+            #     features_all_1 = np.concatenate(
+            #         (
+            #             features_all_1,
+            #             np.c_[
+            #                 np.median(x_1[labels_1 == u]),
+            #                 np.median(z_1[labels_1 == u]),
+            #                 np.median(amps_1[labels_1 == u]),
+            #             ],
+            #         )
+            #     )
+
+            # distance_inter = (
+            #     (features_all_1[:, :, None] - features_all_1.T[None]) ** 2
+            # ).sum(1)
+
+            labels_12 = np.concatenate((labels_1, labels_2))
+            _, labels_12[labels_12 > -1] = np.unique(
+                labels_12[labels_12 > -1], return_inverse=True
+            )  # Make contiguous
+            idx_all = np.flatnonzero(
+                times_seconds < min_time_s + chunk_time_ranges_s[k + 1][1]
             )
-
-        distance_inter = (
-            (features_all_1[:, :, None] - features_all_1.T[None]) ** 2
-        ).sum(1)
-
-        labels_12 = np.concatenate((labels_1, labels_2))
-        _, labels_12[labels_12 > -1] = np.unique(
-            labels_12[labels_12 > -1], return_inverse=True
-        )  # Make contiguous
-        idx_all = np.flatnonzero(
-            times_seconds < min_time_s + chunk_time_ranges_s[k + 1][1]
-        )
-        labels_all = -1 * np.ones(
-            times_seconds.shape[0]
-        )  # discard all spikes at the end for now
-        labels_all[idx_all] = labels_12.astype("int")
+            labels_all = -1 * np.ones(
+                times_seconds.shape[0]
+            )  # discard all spikes at the end for now
+            labels_all[idx_all] = labels_12.astype("int")
 
     return labels_all
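
A minimal standalone sketch (not part of the commit) of the matching idea used in the forward pass above: each unit is summarized by the median of its spikes' features, and each unit in the later chunk adopts the label of the nearest unit in the earlier chunk. The function name and arguments here are hypothetical; the feature scaling, split logic, and backward pass of the real function are omitted.

    import numpy as np

    def centroid_match(feats_a, labels_a, feats_b, labels_b):
        # per-unit centroids: median feature vector over each unit's spikes
        units_a = np.unique(labels_a[labels_a > -1])
        units_b = np.unique(labels_b[labels_b > -1])
        cents_a = np.array([np.median(feats_a[labels_a == u], axis=0) for u in units_a])
        cents_b = np.array([np.median(feats_b[labels_b == u], axis=0) for u in units_b])
        # squared-distance matrix, analogous to dist_matrix in the diff
        d = ((cents_a[:, None, :] - cents_b[None, :, :]) ** 2).sum(-1)
        # like dist_forward = dist_matrix.argmin(0): nearest earlier unit per later unit
        return dict(zip(units_b.tolist(), units_a[d.argmin(0)].tolist()))

The returned mapping plays the role of dist_forward above; when several later units map to one earlier unit, the full code splits that unit's spikes per spike instead of merging blindly.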