speed up forward-backward
Julien Boussard committed Feb 14, 2024
1 parent 822054f commit f91a2a6
Showing 1 changed file with 44 additions and 65 deletions.
109 changes: 44 additions & 65 deletions src/dartsort/cluster/forward_backward.py
@@ -82,6 +82,9 @@ def forward_backward(
     units_2 = np.unique(labels_2)
     units_2 = units_2[units_2 > -1]
 
+    features1 = np.c_[x_1, z_1, amps_1]
+    features2 = np.c_[x_2, z_2, amps_2]
+
     if len(units_2) and len(units_1):
         unit_label_shift = int(labels_1.max() + 1)
         labels_2[labels_2 > -1] += unit_label_shift
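
The stacked arrays are the core of the speedup: one np.median call per unit over an (n_spikes, 3) array replaces three separate median calls on x, z, and amps. A minimal sketch with toy data (the values and labels here are made up for illustration):

    import numpy as np

    # Toy stand-ins for the real x/z/amplitude localization features.
    x_1 = np.array([1.0, 1.1, 0.9, 5.0, 5.2])
    z_1 = np.array([2.0, 2.1, 1.9, 7.0, 7.1])
    amps_1 = np.array([10.0, 11.0, 9.0, 30.0, 31.0])
    labels_1 = np.array([0, 0, 0, 1, 1])

    features1 = np.c_[x_1, z_1, amps_1]  # shape (n_spikes, 3), one row per spike

    # Per-unit median over all three feature columns in a single call.
    feat_unit0 = np.median(features1[labels_1 == 0], axis=0)
    print(feat_unit0)  # [ 1.  2. 10.]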
@@ -91,20 +94,12 @@
         dist_matrix = np.zeros((units_1.shape[0], units_2.shape[0]))
 
         # Speed up this code - this matrix can be sparse (only compute distance for "neighboring" units) - OK for now, still pretty fast
-        for i in range(units_1.shape[0]):
-            unit_1 = units_1[i]
-            for j in range(units_2.shape[0]):
-                unit_2 = units_2[j]
-                feat_1 = np.c_[
-                    np.median(x_1[labels_1 == unit_1]),
-                    np.median(z_1[labels_1 == unit_1]),
-                    np.median(amps_1[labels_1 == unit_1]),
-                ]
-                feat_2 = np.c_[
-                    np.median(x_2[labels_2 == unit_2]),
-                    np.median(z_2[labels_2 == unit_2]),
-                    np.median(amps_2[labels_2 == unit_2]),
-                ]
+        for i, unit_1 in enumerate(units_1):
+            for j, unit_2 in enumerate(units_2):
+                idxunit1 = np.flatnonzero(labels_1 == unit_1)
+                idxunit2 = np.flatnonzero(labels_2 == unit_2)
+                feat_1 = np.median(features1[idxunit1], axis=0)
+                feat_2 = np.median(features2[idxunit2], axis=0)
                 dist_matrix[i, j] = ((feat_1 - feat_2) ** 2).sum()
 
         # find for chunk 2 units the closest units in chunk 1 and split chunk 1 units
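
The in-diff comment notes this matrix could be made sparse; even kept dense, the pairwise loop can be collapsed with broadcasting. A sketch of an equivalent dense computation (not part of the commit), assuming the features1/features2 arrays and unit lists defined above:

    import numpy as np

    def unit_medians(features, labels, units):
        # One (3,) median feature vector per unit, stacked into (n_units, 3).
        return np.array([np.median(features[labels == u], axis=0) for u in units])

    meds_1 = unit_medians(features1, labels_1, units_1)  # (n_units_1, 3)
    meds_2 = unit_medians(features2, labels_2, units_2)  # (n_units_2, 3)

    # All pairwise squared distances at once:
    # (n_units_1, 1, 3) - (1, n_units_2, 3) -> (n_units_1, n_units_2, 3)
    dist_matrix = ((meds_1[:, None, :] - meds_2[None, :, :]) ** 2).sum(axis=2)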
@@ -115,34 +110,31 @@
             units_to_match_to = (
                 np.flatnonzero(dist_forward == unit_to_split) + unit_label_shift
             )
-            features_to_match_to = np.c_[
-                np.median(x_2[labels_2 == units_to_match_to[0]]),
-                np.median(z_2[labels_2 == units_to_match_to[0]]),
-                np.median(amps_2[labels_2 == units_to_match_to[0]]),
-            ]
+
+            idxunit2 = np.flatnonzero(labels_2 == units_to_match_to[0])
+            features_to_match_to = np.median(features2[idxunit2], axis=0)
+            # features_to_match_to = np.c_[
+            #     np.median(x_2[labels_2 == units_to_match_to[0]]),
+            #     np.median(z_2[labels_2 == units_to_match_to[0]]),
+            #     np.median(amps_2[labels_2 == units_to_match_to[0]]),
+            # ]
+            #CAN CHANGE THIS
             for u in units_to_match_to[1:]:
-                features_to_match_to = np.concatenate(
-                    (
-                        features_to_match_to,
-                        np.c_[
-                            np.median(x_2[labels_2 == u]),
-                            np.median(z_2[labels_2 == u]),
-                            np.median(amps_2[labels_2 == u]),
-                        ],
-                    )
-                )
-            spikes_to_update = np.flatnonzero(labels_1 == unit_to_split)
-            x_s_to_update = x_1[spikes_to_update]
-            z_s_to_update = z_1[spikes_to_update]
-            amps_s_to_update = amps_1[spikes_to_update]
-            for j, s in enumerate(spikes_to_update):
-                # Don't update if new distance is too high?
-                feat_s = np.c_[
-                    x_s_to_update[j], z_s_to_update[j], amps_s_to_update[j]
-                ]
-                labels_1[s] = units_to_match_to[
-                    ((feat_s - features_to_match_to) ** 2).sum(1).argmin()
-                ]
+                idxunit2 = np.flatnonzero(labels_2 == u)
+                features_to_match_to = np.c_[
+                    features_to_match_to,
+                    np.median(features2[idxunit2], axis=0)
+                ]
+
+            spikes_to_update = np.flatnonzero(labels_1 == unit_to_split)
+
+            # x_s_to_update = x_1[spikes_to_update]
+            # z_s_to_update = z_1[spikes_to_update]
+            # amps_s_to_update = amps_1[spikes_to_update]
+            # feat_s = np.c_[x_s_to_update, z_s_to_update, amps_s_to_update]
+            feat_s = features1[spikes_to_update]
+
+            labels_1[spikes_to_update] = units_to_match_to[((features_to_match_to.T[:, None] - feat_s[None]) ** 2).sum(2).argmin(0)]
 
             # Relabel labels_1 and labels_2
             for unit_to_relabel in units_:
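
The vectorized reassignment replaces the per-spike Python loop. np.c_ stacks the k candidate medians as columns, so features_to_match_to ends up with shape (3, k), and the broadcast works out as sketched below (standalone, random data; note that if units_to_match_to had a single entry, features_to_match_to would stay 1-D and the broadcast would not index candidates correctly, so this branch presumably only runs with two or more candidates):

    import numpy as np

    rng = np.random.default_rng(0)
    k, n = 4, 6                                      # k candidate units, n spikes
    features_to_match_to = rng.normal(size=(3, k))   # medians as columns (np.c_)
    feat_s = rng.normal(size=(n, 3))                 # per-spike features

    # (k, 1, 3) - (1, n, 3) -> (k, n, 3); sum over feature axis -> (k, n)
    d2 = ((features_to_match_to.T[:, None] - feat_s[None]) ** 2).sum(2)
    nearest = d2.argmin(0)                           # (n,) nearest candidate per spike
    print(nearest.shape, nearest.max() < k)          # (6,) True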
@@ -167,35 +159,22 @@
                         [unit_to_split],
                     )
                 )
-
-                features_to_match_to = np.c_[
-                    np.median(x_1[labels_1 == units_to_match_to[0]]),
-                    np.median(z_1[labels_1 == units_to_match_to[0]]),
-                    np.median(amps_1[labels_1 == units_to_match_to[0]]),
-                ]
-
+                features_to_match_to = np.median(features1[labels_1 == units_to_match_to[0]], axis=0)
                 for u in units_to_match_to[1:]:
-                    features_to_match_to = np.concatenate(
-                        (
-                            features_to_match_to,
-                            np.c_[
-                                np.median(x_1[labels_1 == u]),
-                                np.median(z_1[labels_1 == u]),
-                                np.median(amps_1[labels_1 == u]),
-                            ],
-                        )
-                    )
-                spikes_to_update = np.flatnonzero(labels_2 == unit_to_split)
-                x_s_to_update = x_2[spikes_to_update]
-                z_s_to_update = z_2[spikes_to_update]
-                amps_s_to_update = amps_2[spikes_to_update]
-                for j, s in enumerate(spikes_to_update):
-                    feat_s = np.c_[
-                        x_s_to_update[j], z_s_to_update[j], amps_s_to_update[j]
-                    ]
-                    labels_2[s] = units_to_match_to[
-                        ((feat_s - features_to_match_to) ** 2).sum(1).argmin()
-                    ]
+                    features_to_match_to = np.c_[
+                        features_to_match_to,
+                        np.median(features1[labels_1 == u], axis=0)
+                    ]
+
+
+                spikes_to_update = np.flatnonzero(labels_2 == unit_to_split)
+                features2[spikes_to_update]
+                # x_s_to_update = x_2[spikes_to_update]
+                # z_s_to_update = z_2[spikes_to_update]
+                # amps_s_to_update = amps_2[spikes_to_update]
+                # feat_s = np.c_[x_s_to_update, z_s_to_update, amps_s_to_update]
+                labels_2[spikes_to_update] = units_to_match_to[((features_to_match_to.T[:, None] - features2[spikes_to_update][None]) ** 2).sum(2).argmin(0)]
 
                 # Do we need to "regularize" and make sure the distance intra units after merging is smaller than the distance inter units before merging
                 # all_labels_1 = np.unique(labels_1)
                 # all_labels_1 = all_labels_1[all_labels_1 > -1]
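
The trailing comment asks whether merges should be gated by comparing intra-unit spread after merging against the inter-unit distance before merging. A hypothetical sketch of such a check (none of this is in the commit; the helper name and ratio threshold are made up):

    import numpy as np

    def merge_is_reasonable(features, labels, unit_a, unit_b, ratio=1.0):
        # Hypothetical guard: accept a merge only if the merged cloud is not
        # much more spread out than the two unit centroids are far apart.
        fa = features[labels == unit_a]
        fb = features[labels == unit_b]
        inter = np.linalg.norm(np.median(fa, axis=0) - np.median(fb, axis=0))
        merged = np.concatenate([fa, fb])
        intra = np.linalg.norm(merged - np.median(merged, axis=0), axis=1).mean()
        return intra <= ratio * inter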