Skip to content

Commit

Permalink
Initial undebugged split/merge ensemble, plus a computation config th…
Browse files Browse the repository at this point in the history
…ing that I will likely live to regret
  • Loading branch information
cwindolf committed Feb 12, 2024
1 parent 923e8cf commit 2da8295
Show file tree
Hide file tree
Showing 7 changed files with 400 additions and 33 deletions.
12 changes: 6 additions & 6 deletions src/dartsort/cluster/density.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
import numpy as np
from scipy.spatial import KDTree
from scipy.ndimage import gaussian_filter
from scipy.interpolate import RegularGridInterpolator
from scipy.ndimage import gaussian_filter
from scipy.sparse import coo_array
from scipy.sparse.csgraph import connected_components
from scipy import sparse
from scipy.spatial import KDTree


def kdtree_inliers(
Expand Down Expand Up @@ -156,10 +156,10 @@ def density_peaks_clustering(
)
if noise_density:
nhdn[density <= noise_density] = n
nhdn = nhdn.astype(np.int32)
has_nhdn = np.flatnonzero(nhdn < n).astype(np.int32)
nhdn = nhdn.astype(np.intc)
has_nhdn = np.flatnonzero(nhdn < n).astype(np.intc)

graph = sparse.coo_array(
graph = coo_array(
(np.ones(has_nhdn.size), (nhdn[has_nhdn], has_nhdn)), shape=(n, n)
)
ncc, labels = connected_components(graph)
Expand Down
44 changes: 43 additions & 1 deletion src/dartsort/cluster/ensemble_utils.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
import numpy as np
from tqdm.auto import trange
from tqdm.auto import tqdm, trange

from ..config import default_split_merge_config
from . import merge, split


def forward_backward(
Expand Down Expand Up @@ -236,3 +239,42 @@ def get_indices_in_chunk(times_s, chunk_time_range_s):
return np.flatnonzero(
(times_s >= chunk_time_range_s[0]) & (times_s < chunk_time_range_s[1])
)


def split_merge_ensemble(
    recording,
    chunk_sortings,
    motion_est=None,
    split_merge_config=default_split_merge_config,
    n_jobs_split=0,
    n_jobs_merge=0,
    device=None,
    show_progress=True,
):
    """Ensemble per-chunk sortings: split within each chunk, then merge across chunks.

    Parameters
    ----------
    recording : recording object shared by all of the chunk sortings
    chunk_sortings : list of sortings, one per time chunk
    motion_est : optional motion estimate forwarded to the merge step
    split_merge_config : config carrying the split strategy, recursion flag,
        merge template config, and the within/cross merge distance thresholds
    n_jobs_split : parallelism for the per-chunk split stage
    n_jobs_merge : parallelism for the merge stage (also used for templates)
    device : device passed through to the merge/template computation
        (None lets the callee pick)
    show_progress : whether to display progress bars for both stages

    Returns
    -------
    A single sorting combining all chunks.
    """
    # Split inside each chunk. The inner split's own progress bar stays off;
    # the outer per-chunk tqdm is the progress indicator for this stage, and
    # it honors the caller's show_progress flag.
    chunk_sortings = [
        split.split_clusters(
            sorting,
            split_strategy=split_merge_config.split_strategy,
            recursive=split_merge_config.recursive_split,
            n_jobs=n_jobs_split,
            show_progress=False,
        )
        for sorting in tqdm(
            chunk_sortings, desc="Split within chunks", disable=not show_progress
        )
    ]

    # Merge within and across chunks.
    sorting = merge.merge_across_sortings(
        chunk_sortings,
        recording,
        template_config=split_merge_config.merge_template_config,
        motion_est=motion_est,
        cross_merge_distance_threshold=split_merge_config.cross_merge_distance_threshold,
        within_merge_distance_threshold=split_merge_config.merge_distance_threshold,
        device=device,
        n_jobs=n_jobs_merge,
        n_jobs_templates=n_jobs_merge,
        # Fix: was hard-coded True, which ignored the caller's show_progress flag.
        show_progress=show_progress,
    )

    return sorting
19 changes: 17 additions & 2 deletions src/dartsort/cluster/initial.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,10 @@

import h5py
import numpy as np
from dartsort.util import job_util
from dartsort.util.data_util import DARTsortSorting

from . import cluster_util, ensemble_utils, density
from . import cluster_util, density, ensemble_utils


def cluster_chunk(
Expand Down Expand Up @@ -179,6 +180,7 @@ def ensemble_chunks(
peeling_hdf5_filename,
recording,
clustering_config,
computation_config=None,
motion_est=None,
):
"""Initial clustering combined across chunks of time
Expand All @@ -205,7 +207,9 @@ def ensemble_chunks(
if len(chunk_sortings) == 1:
return chunk_sortings[0]

assert clustering_config.ensemble_strategy in ("forward_backward",)
assert clustering_config.ensemble_strategy in ("forward_backward", "split_merge")
if computation_config is None:
computation_config = job_util.get_global_computation_config()

if clustering_config.ensemble_strategy == "forward_backward":
labels = ensemble_utils.forward_backward(
Expand All @@ -218,5 +222,16 @@ def ensemble_chunks(
motion_est=motion_est,
)
sorting = replace(chunk_sortings[0], labels=labels)
elif clustering_config.ensemble_strategy == "split_merge":
sorting = ensemble_utils.split_merge_ensemble(
recording,
chunk_sortings,
motion_est=motion_est,
split_merge_config=clustering_config.split_merge_config,
n_jobs_split=computation_config.n_jobs_cpu,
n_jobs_merge=computation_config.actual_n_jobs_gpu,
device=None,
show_progress=True,
)

return sorting
Loading

0 comments on commit 2da8295

Please sign in to comment.