
Commit 4d6d6fa

julienboussard committed: fix split and add split per chunk

1 parent: e2de16d

File tree

2 files changed (+263, -146 lines)

src/dartsort/cluster/ensemble_utils.py

Lines changed: 143 additions & 134 deletions
@@ -10,6 +10,7 @@ def forward_backward(
     feature_scales=(1, 1, 50),
     adaptive_feature_scales=False,
     motion_est=None,
+    verbose=True,
 ):
     """
     Ensemble over HDBSCAN clustering
@@ -19,21 +20,23 @@ def forward_backward(
         return chunk_sortings[0]

     times_seconds = chunk_sortings[0].times_seconds
-    times_samples = chunk_sortings[0].times_samples
+
     min_time_s = chunk_time_ranges_s[0][0]
     idx_all_chunks = [get_indices_in_chunk(times_seconds, chunk_range) for chunk_range in chunk_time_ranges_s]

     # put all labels into one array
     # TODO: this does not allow for overlapping chunks.
-    labels_all = np.full_like(times_samples, -1)
+    labels_all = np.full_like(times_seconds, -1)
     for ix, sorting in zip(idx_all_chunks, chunk_sortings):
-        assert labels_all[ix].max() < 0  # assert non-overlapping
-        labels_all[ix] = sorting.labels[ix]
+        if len(ix):
+            assert labels_all[ix].max() < 0  # assert non-overlapping
+            labels_all[ix] = sorting.labels[ix]

     # load features that we will need
     # needs to be all features here
     amps = chunk_sortings[0].denoised_ptp_amplitudes
     xyza = chunk_sortings[0].point_source_localizations
+
     x = xyza[:, 0]
     z_reg = xyza[:, 2]

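A note on the hunk above: the template array for labels_all changes from times_samples to times_seconds, and the fill loop now skips empty chunks, since calling .max() on an empty selection raises a ValueError. Because np.full_like inherits the template's dtype, basing it on the float times_seconds gives a float labels array, which is why the function later casts with .astype("int"). A minimal standalone sketch of the guarded fill, with made-up stand-ins for the chunk indices and sorting labels:

import numpy as np

# Hypothetical spike times (seconds), per-chunk indices, and labels.
times_seconds = np.array([0.1, 0.5, 1.2, 1.8, 2.3])
idx_all_chunks = [np.array([0, 1]), np.array([2, 3, 4]), np.array([], dtype=int)]
chunk_labels = [np.array([0, 1, -1, -1, -1]),
                np.array([-1, -1, 0, 0, 1]),
                np.full(5, -1)]

# full_like copies the template's dtype, so this is a float array of -1.0.
labels_all = np.full_like(times_seconds, -1)
for ix, labels in zip(idx_all_chunks, chunk_labels):
    if len(ix):  # an empty chunk would make the .max() below raise
        assert labels_all[ix].max() < 0  # assert non-overlapping
        labels_all[ix] = labels[ix]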
@@ -44,7 +47,11 @@ def forward_backward(
     if motion_est is not None:
         z_reg = motion_est.correct_s(times_seconds, z_reg)

-    for k in trange(len(chunk_sortings) - 1, desc="Ensembling chunks"):
+    if verbose is True:
+        tbar = trange(len(chunk_sortings) - 1, desc="Ensembling chunks")
+    else:
+        tbar = range(len(chunk_sortings) - 1)
+    for k in tbar:
         # CHANGE THE 1 ---
         # idx_1 = np.flatnonzero(np.logical_and(times_seconds>=min_time_s, times_seconds<min_time_s+k*shift+chunk_size_s))
         idx_1 = np.flatnonzero(
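
The verbose flag added above uses a common tqdm pattern: choose the iterator once, trange for a progress bar or plain range for silence, so the loop body is written only once. A small self-contained sketch of that pattern (ensemble_pairs is a hypothetical placeholder, not a function from this module):

from tqdm import trange

def ensemble_pairs(n_chunks, verbose=True):
    if verbose:
        tbar = trange(n_chunks - 1, desc="Ensembling chunks")
    else:
        tbar = range(n_chunks - 1)
    for k in tbar:
        pass  # placeholder: match chunk k against chunk k + 1 here

ensemble_pairs(5, verbose=False)  # runs without printing a progress bar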
@@ -62,160 +69,162 @@ def forward_backward(
         amps_2 = feature_scales[2] * np.log(log_c + amps[idx_2])
         labels_1 = labels_all[idx_1].copy().astype("int")
         labels_2 = chunk_sortings[k + 1].labels[idx_2]
-        unit_label_shift = int(labels_1.max() + 1)
-        labels_2[labels_2 > -1] += unit_label_shift

         units_1 = np.unique(labels_1)
         units_1 = units_1[units_1 > -1]
         units_2 = np.unique(labels_2)
         units_2 = units_2[units_2 > -1]

-        # FORWARD PASS
-
-        dist_matrix = np.zeros((units_1.shape[0], units_2.shape[0]))
-
-        # Speed up this code - this matrix can be sparse (only compute distance for "neighboring" units) - OK for now, still pretty fast
-        for i in range(units_1.shape[0]):
-            unit_1 = units_1[i]
-            for j in range(units_2.shape[0]):
-                unit_2 = units_2[j]
-                feat_1 = np.c_[
-                    np.median(x_1[labels_1 == unit_1]),
-                    np.median(z_1[labels_1 == unit_1]),
-                    np.median(amps_1[labels_1 == unit_1]),
-                ]
-                feat_2 = np.c_[
-                    np.median(x_2[labels_2 == unit_2]),
-                    np.median(z_2[labels_2 == unit_2]),
-                    np.median(amps_2[labels_2 == unit_2]),
-                ]
-                dist_matrix[i, j] = ((feat_1 - feat_2) ** 2).sum()
-
-        # find for chunk 2 units the closest units in chunk 1 and split chunk 1 units
-        dist_forward = dist_matrix.argmin(0)
-        units_, counts_ = np.unique(dist_forward, return_counts=True)
-
-        for unit_to_split in units_[counts_ > 1]:
-            units_to_match_to = (
-                np.flatnonzero(dist_forward == unit_to_split) + unit_label_shift
-            )
-            features_to_match_to = np.c_[
-                np.median(x_2[labels_2 == units_to_match_to[0]]),
-                np.median(z_2[labels_2 == units_to_match_to[0]]),
-                np.median(amps_2[labels_2 == units_to_match_to[0]]),
-            ]
-            for u in units_to_match_to[1:]:
-                features_to_match_to = np.concatenate(
-                    (
-                        features_to_match_to,
-                        np.c_[
-                            np.median(x_2[labels_2 == u]),
-                            np.median(z_2[labels_2 == u]),
-                            np.median(amps_2[labels_2 == u]),
-                        ],
-                    )
-                )
-            spikes_to_update = np.flatnonzero(labels_1 == unit_to_split)
-            x_s_to_update = x_1[spikes_to_update]
-            z_s_to_update = z_1[spikes_to_update]
-            amps_s_to_update = amps_1[spikes_to_update]
-            for j, s in enumerate(spikes_to_update):
-                # Don't update if new distance is too high?
-                feat_s = np.c_[
-                    x_s_to_update[j], z_s_to_update[j], amps_s_to_update[j]
-                ]
-                labels_1[s] = units_to_match_to[
-                    ((feat_s - features_to_match_to) ** 2).sum(1).argmin()
-                ]
-
-        # Relabel labels_1 and labels_2
-        for unit_to_relabel in units_:
-            if counts_[np.flatnonzero(units_ == unit_to_relabel)][0] == 1:
-                idx_to_relabel = np.flatnonzero(labels_1 == unit_to_relabel)
-                labels_1[idx_to_relabel] = units_2[dist_forward == unit_to_relabel]
-
-        # BACKWARD PASS
-
-        units_not_matched = np.unique(labels_1)
-        units_not_matched = units_not_matched[units_not_matched > -1]
-        units_not_matched = units_not_matched[units_not_matched < unit_label_shift]
-
-        if len(units_not_matched):
-            all_units_to_match_to = (
-                dist_matrix[units_not_matched].argmin(1) + unit_label_shift
-            )
-            for unit_to_split in np.unique(all_units_to_match_to):
-                units_to_match_to = np.concatenate(
-                    (
-                        units_not_matched[all_units_to_match_to == unit_to_split],
-                        [unit_to_split],
-                    )
+        if len(units_2) and len(units_1):
+            unit_label_shift = int(labels_1.max() + 1)
+            labels_2[labels_2 > -1] += unit_label_shift
+            units_2 += unit_label_shift
+
+            # FORWARD PASS
+            dist_matrix = np.zeros((units_1.shape[0], units_2.shape[0]))
+
+            # Speed up this code - this matrix can be sparse (only compute distance for "neighboring" units) - OK for now, still pretty fast
+            for i in range(units_1.shape[0]):
+                unit_1 = units_1[i]
+                for j in range(units_2.shape[0]):
+                    unit_2 = units_2[j]
+                    feat_1 = np.c_[
+                        np.median(x_1[labels_1 == unit_1]),
+                        np.median(z_1[labels_1 == unit_1]),
+                        np.median(amps_1[labels_1 == unit_1]),
+                    ]
+                    feat_2 = np.c_[
+                        np.median(x_2[labels_2 == unit_2]),
+                        np.median(z_2[labels_2 == unit_2]),
+                        np.median(amps_2[labels_2 == unit_2]),
+                    ]
+                    dist_matrix[i, j] = ((feat_1 - feat_2) ** 2).sum()
+
+            # find for chunk 2 units the closest units in chunk 1 and split chunk 1 units
+            dist_forward = dist_matrix.argmin(0)
+            units_, counts_ = np.unique(dist_forward, return_counts=True)
+
+            for unit_to_split in units_[counts_ > 1]:
+                units_to_match_to = (
+                    np.flatnonzero(dist_forward == unit_to_split) + unit_label_shift
                 )
-
                 features_to_match_to = np.c_[
-                    np.median(x_1[labels_1 == units_to_match_to[0]]),
-                    np.median(z_1[labels_1 == units_to_match_to[0]]),
-                    np.median(amps_1[labels_1 == units_to_match_to[0]]),
+                    np.median(x_2[labels_2 == units_to_match_to[0]]),
+                    np.median(z_2[labels_2 == units_to_match_to[0]]),
+                    np.median(amps_2[labels_2 == units_to_match_to[0]]),
                 ]
                 for u in units_to_match_to[1:]:
                     features_to_match_to = np.concatenate(
                         (
                             features_to_match_to,
                             np.c_[
-                                np.median(x_1[labels_1 == u]),
-                                np.median(z_1[labels_1 == u]),
-                                np.median(amps_1[labels_1 == u]),
+                                np.median(x_2[labels_2 == u]),
+                                np.median(z_2[labels_2 == u]),
+                                np.median(amps_2[labels_2 == u]),
                             ],
                         )
                     )
-                spikes_to_update = np.flatnonzero(labels_2 == unit_to_split)
-                x_s_to_update = x_2[spikes_to_update]
-                z_s_to_update = z_2[spikes_to_update]
-                amps_s_to_update = amps_2[spikes_to_update]
+                spikes_to_update = np.flatnonzero(labels_1 == unit_to_split)
+                x_s_to_update = x_1[spikes_to_update]
+                z_s_to_update = z_1[spikes_to_update]
+                amps_s_to_update = amps_1[spikes_to_update]
                 for j, s in enumerate(spikes_to_update):
+                    # Don't update if new distance is too high?
                     feat_s = np.c_[
                         x_s_to_update[j], z_s_to_update[j], amps_s_to_update[j]
                     ]
-                    labels_2[s] = units_to_match_to[
+                    labels_1[s] = units_to_match_to[
                         ((feat_s - features_to_match_to) ** 2).sum(1).argmin()
                     ]
-
-        # Do we need to "regularize" and make sure the distance intra units after merging is smaller than the distance inter units before merging
-        all_labels_1 = np.unique(labels_1)
-        all_labels_1 = all_labels_1[all_labels_1 > -1]
-
-        features_all_1 = np.c_[
-            np.median(x_1[labels_1 == all_labels_1[0]]),
-            np.median(z_1[labels_1 == all_labels_1[0]]),
-            np.median(amps_1[labels_1 == all_labels_1[0]]),
-        ]
-        for u in all_labels_1[1:]:
-            features_all_1 = np.concatenate(
-                (
-                    features_all_1,
-                    np.c_[
-                        np.median(x_1[labels_1 == u]),
-                        np.median(z_1[labels_1 == u]),
-                        np.median(amps_1[labels_1 == u]),
-                    ],
+
+            # Relabel labels_1 and labels_2
+            for unit_to_relabel in units_:
+                if counts_[np.flatnonzero(units_ == unit_to_relabel)][0] == 1:
+                    idx_to_relabel = np.flatnonzero(labels_1 == unit_to_relabel)
+                    labels_1[idx_to_relabel] = units_2[dist_forward == unit_to_relabel]
+
+            # BACKWARD PASS
+
+            units_not_matched = np.unique(labels_1)
+            units_not_matched = units_not_matched[units_not_matched > -1]
+            units_not_matched = units_not_matched[units_not_matched < unit_label_shift]
+
+            if len(units_not_matched):
+                all_units_to_match_to = (
+                    dist_matrix[units_not_matched].argmin(1) + unit_label_shift
                 )
+                for unit_to_split in np.unique(all_units_to_match_to):
+                    units_to_match_to = np.concatenate(
+                        (
+                            units_not_matched[all_units_to_match_to == unit_to_split],
+                            [unit_to_split],
+                        )
+                    )
+
+                    features_to_match_to = np.c_[
+                        np.median(x_1[labels_1 == units_to_match_to[0]]),
+                        np.median(z_1[labels_1 == units_to_match_to[0]]),
+                        np.median(amps_1[labels_1 == units_to_match_to[0]]),
+                    ]
+                    for u in units_to_match_to[1:]:
+                        features_to_match_to = np.concatenate(
+                            (
+                                features_to_match_to,
+                                np.c_[
+                                    np.median(x_1[labels_1 == u]),
+                                    np.median(z_1[labels_1 == u]),
+                                    np.median(amps_1[labels_1 == u]),
+                                ],
+                            )
+                        )
+                    spikes_to_update = np.flatnonzero(labels_2 == unit_to_split)
+                    x_s_to_update = x_2[spikes_to_update]
+                    z_s_to_update = z_2[spikes_to_update]
+                    amps_s_to_update = amps_2[spikes_to_update]
+                    for j, s in enumerate(spikes_to_update):
+                        feat_s = np.c_[
+                            x_s_to_update[j], z_s_to_update[j], amps_s_to_update[j]
+                        ]
+                        labels_2[s] = units_to_match_to[
+                            ((feat_s - features_to_match_to) ** 2).sum(1).argmin()
+                        ]
+
+            # Do we need to "regularize" and make sure the distance intra units after merging is smaller than the distance inter units before merging
+            # all_labels_1 = np.unique(labels_1)
+            # all_labels_1 = all_labels_1[all_labels_1 > -1]
+
+            # features_all_1 = np.c_[
+            #     np.median(x_1[labels_1 == all_labels_1[0]]),
+            #     np.median(z_1[labels_1 == all_labels_1[0]]),
+            #     np.median(amps_1[labels_1 == all_labels_1[0]]),
+            # ]
+            # for u in all_labels_1[1:]:
+            #     features_all_1 = np.concatenate(
+            #         (
+            #             features_all_1,
+            #             np.c_[
+            #                 np.median(x_1[labels_1 == u]),
+            #                 np.median(z_1[labels_1 == u]),
+            #                 np.median(amps_1[labels_1 == u]),
+            #             ],
+            #         )
+            #     )
+
+            # distance_inter = (
+            #     (features_all_1[:, :, None] - features_all_1.T[None]) ** 2
+            # ).sum(1)
+
+            labels_12 = np.concatenate((labels_1, labels_2))
+            _, labels_12[labels_12 > -1] = np.unique(
+                labels_12[labels_12 > -1], return_inverse=True
+            )  # Make contiguous
+            idx_all = np.flatnonzero(
+                times_seconds < min_time_s + chunk_time_ranges_s[k + 1][1]
             )
-
-        distance_inter = (
-            (features_all_1[:, :, None] - features_all_1.T[None]) ** 2
-        ).sum(1)
-
-        labels_12 = np.concatenate((labels_1, labels_2))
-        _, labels_12[labels_12 > -1] = np.unique(
-            labels_12[labels_12 > -1], return_inverse=True
-        )  # Make contiguous
-        idx_all = np.flatnonzero(
-            times_seconds < min_time_s + chunk_time_ranges_s[k + 1][1]
-        )
-        labels_all = -1 * np.ones(
-            times_seconds.shape[0]
-        )  # discard all spikes at the end for now
-        labels_all[idx_all] = labels_12.astype("int")
+            labels_all = -1 * np.ones(
+                times_seconds.shape[0]
+            )  # discard all spikes at the end for now
+            labels_all[idx_all] = labels_12.astype("int")

     return labels_all

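Two ingredients of the rewritten loop body are worth isolating. The forward pass builds a squared-distance matrix between per-unit feature medians and takes argmin(0) to find, for each chunk-2 unit, its nearest chunk-1 unit; a chunk-1 unit claimed more than once becomes a split candidate. The new if len(units_2) and len(units_1) guard skips all of this when either chunk has no units, which is what makes the per-chunk split safe. Finally, np.unique with return_inverse=True compacts the surviving labels to 0, 1, 2, ... while leaving the -1 noise label untouched. A condensed sketch of both ideas on made-up unit centroids (a vectorized stand-in for the double loop in the diff, not the code as committed):

import numpy as np

# Hypothetical per-unit medians of (x, z, scaled log-amp) for two chunks.
feats_1 = np.array([[10.0, 100.0, 3.0], [40.0, 250.0, 2.5]])
feats_2 = np.array([[11.0, 102.0, 3.1], [39.0, 248.0, 2.4], [41.0, 252.0, 2.6]])

if len(feats_1) and len(feats_2):  # mirror of the new empty-chunk guard
    # Squared distance between every pair of unit centroids.
    dist_matrix = ((feats_1[:, None, :] - feats_2[None, :, :]) ** 2).sum(2)
    # Nearest chunk-1 unit for each chunk-2 unit.
    dist_forward = dist_matrix.argmin(0)
    units_, counts_ = np.unique(dist_forward, return_counts=True)
    print(units_[counts_ > 1])  # chunk-1 units matched twice: split candidates

# Contiguous relabeling: -1 stays put, other labels become 0, 1, 2, ...
labels_12 = np.array([-1, 3, 7, 3, -1, 12])
_, labels_12[labels_12 > -1] = np.unique(
    labels_12[labels_12 > -1], return_inverse=True
)  # Make contiguous
print(labels_12)  # [-1  0  1  0 -1  2]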