Skip to content

Commit

Permalink
Add grid_snap parcellation
Browse files Browse the repository at this point in the history
  • Loading branch information
cwindolf committed Jan 15, 2024
1 parent d553067 commit 86ef67f
Show file tree
Hide file tree
Showing 3 changed files with 58 additions and 8 deletions.
28 changes: 26 additions & 2 deletions src/dartsort/cluster/cluster_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,30 @@ def closest_registered_channels(times_seconds, x, z_abs, geom, motion_est=None):
return reg_channels


def grid_snap(times_seconds, x, z_abs, geom, grid_dx=15, grid_dz=15, motion_est=None):
    """Assign each event to the nearest node of a regular (x, z) grid.

    Events are first drift-corrected into registered space, then snapped to
    the closest point of a grid laid over the registered geometry's bounding
    box. The returned grid-node indices serve as cluster labels.

    Parameters
    ----------
    times_seconds : array of event times in seconds
    x, z_abs : arrays of event x positions and absolute (uncorrected) depths
    geom : probe geometry array of shape (n_channels, 2)
    grid_dx, grid_dz : float
        Grid spacing along x and z in the same units as geom.
    motion_est : optional motion estimate with a .correct_s(t, z) method;
        identity (no drift correction) when None.

    Returns
    -------
    reg_channels : integer array of nearest-grid-node indices, one per event
    """
    if motion_est is None:
        # Bug fix: this line previously read `motion_est == Identity...`,
        # a no-op comparison, leaving motion_est as None and crashing on
        # the correct_s call below.
        motion_est = IdentityMotionEstimate()
    z_reg = motion_est.correct_s(times_seconds, z_abs)
    reg_pos = np.c_[x, z_reg]

    # make a grid inside the registered geom bounding box,
    # re-centered so the grid midpoint matches the box midpoint
    registered_geom = drift_util.registered_geometry(geom, motion_est)
    min_x, max_x = registered_geom[:, 0].min(), registered_geom[:, 0].max()
    min_z, max_z = registered_geom[:, 1].min(), registered_geom[:, 1].max()
    grid_x = np.arange(min_x, max_x, grid_dx)
    grid_x += (min_x + max_x) / 2 - grid_x.mean()
    grid_z = np.arange(min_z, max_z, grid_dz)
    grid_z += (min_z + max_z) / 2 - grid_z.mean()
    grid_xx, grid_zz = np.meshgrid(grid_x, grid_z, indexing="ij")
    grid = np.c_[grid_xx.ravel(), grid_zz.ravel()]

    # snap each registered position to its closest grid point
    registered_kdt = KDTree(grid)
    distances, reg_channels = registered_kdt.query(reg_pos)

    return reg_channels


def hdbscan_clustering(
recording,
times_seconds,
Expand All @@ -38,7 +62,7 @@ def hdbscan_clustering(
scales=(1, 1, 50),
log_c=5,
recursive=True,
do_remove_dups=True,
remove_duplicates=True,
frames_dedup=12,
frame_dedup_cluster=20,
):
Expand All @@ -60,7 +84,7 @@ def hdbscan_clustering(
)
clusterer.fit(features)

if do_remove_dups:
if remove_duplicates:
(
clusterer,
duplicate_indices,
Expand Down
30 changes: 25 additions & 5 deletions src/dartsort/cluster/initial.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,16 @@ def cluster_chunk(
labels[in_chunk] = cluster_util.closest_registered_channels(
times_s[in_chunk], xyza[in_chunk, 0], xyza[in_chunk, 2], geom, motion_est
)
elif clustering_config.cluster_strategy == "grid_snap":
labels[in_chunk] = cluster_util.grid_snap(
times_s[in_chunk],
xyza[in_chunk, 0],
xyza[in_chunk, 2],
geom,
grid_dx=clustering_config.grid_dx,
grid_dz=clustering_config.grid_dz,
motion_est=motion_est,
)
elif clustering_config.cluster_strategy == "hdbscan":
labels[in_chunk] = cluster_util.hdbscan_clustering(
times_s[in_chunk],
Expand All @@ -69,6 +79,8 @@ def cluster_chunk(
min_samples=clustering_config.min_samples,
cluster_selection_epsilon=clustering_config.cluster_selection_epsilon,
scales=clustering_config.feature_scales,
recursive=clustering_config.recursive,
remove_duplicates=clustering_config.remove_duplicates,
)
else:
assert False
Expand Down Expand Up @@ -100,17 +112,25 @@ def cluster_chunks(
"""
chunk_samples = recording.sampling_frequency * clustering_config.chunk_size_s

# determine number of chunks, and we'll count the extra if it's at least 2/3
n_chunks = recording.get_num_samples() / chunk_samples
n_chunks = np.floor(n_chunks) + (n_chunks - np.floor(n_chunks) > 0.66)
n_chunks = int(max(1, n_chunks))
# determine number of chunks
# if we're not ensembling, that's 1 chunk.
if (
not clustering_config.ensemble_strategy
or clustering_config.ensemble_strategy.lower() == "none"
):
n_chunks = 1
else:
n_chunks = recording.get_num_samples() / chunk_samples
# we'll count the remainder as a chunk if it's at least 2/3 of one
n_chunks = np.floor(n_chunks) + (n_chunks - np.floor(n_chunks) > 0.66)
n_chunks = int(max(1, n_chunks))

# evenly divide the recording into chunks
assert recording.get_num_segments() == 1
start_time_s, end_time_s = recording._recording_segments[0].sample_index_to_time(
[0, recording.get_num_samples() - 1]
)
chunk_times_s = np.linspace(start_time_s, end_time_s, num=n_chunks)
chunk_times_s = np.linspace(start_time_s, end_time_s, num=n_chunks + 1)
chunk_time_ranges_s = list(zip(chunk_times_s[:-1], chunk_times_s[1:]))

# cluster each chunk. can be parallelized in the future.
Expand Down
8 changes: 7 additions & 1 deletion src/dartsort/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -181,10 +181,16 @@ class ClusteringConfig:
cluster_selection_epsilon: int = 1
feature_scales: Tuple[float] = (1.0, 1.0, 50.0)
log_c: float = 5.0
recursive: bool = True
remove_duplicates: bool = True

# grid snap parameters
grid_dx: float = 15.0
grid_dz: float = 15.0

# -- ensembling
ensemble_strategy: Optional[str] = "forward_backward"
chunk_size_s: int = 150
chunk_size_s: float = 150.0


default_featurization_config = FeaturizationConfig()
Expand Down

0 comments on commit 86ef67f

Please sign in to comment.