
Commit b8d8503

Merge branch 'main' of github.com:cwindolf/spike-psvae
cwindolf committed Jan 19, 2024
2 parents 1e0f752 + c29f13f commit b8d8503
Showing 6 changed files with 87 additions and 16 deletions.
5 changes: 5 additions & 0 deletions src/dartsort/cluster/cluster_util.py
@@ -60,6 +60,7 @@ def hdbscan_clustering(
     min_samples=25,
     cluster_selection_epsilon=1,
     scales=(1, 1, 50),
+    adaptive_feature_scales=False,
     log_c=5,
     recursive=True,
     remove_duplicates=True,
@@ -75,6 +76,10 @@ def hdbscan_clustering(
     else:
         z_reg = motion_est.correct_s(times_seconds, z_abs)
 
+    if adaptive_feature_scales:
+        scales = (1, 1, np.median(np.abs(x - np.median(x))) / np.median(np.abs(np.log(log_c + amps) - np.median(np.log(log_c + amps))))
+        )
+
     features = np.c_[x * scales[0], z_reg * scales[1], np.log(log_c + amps) * scales[2]]
     if features.shape[1] >= features.shape[0]:
         return -1 * np.ones(features.shape[0])
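
A note on the new branch above: when adaptive_feature_scales is on, the third scale becomes the ratio of two median absolute deviations (MADs), so the log-amplitude feature is rescaled to have the same robust spread as the x-position feature before HDBSCAN sees them. The same computation is repeated in ensemble_utils.forward_backward below. A minimal standalone sketch of what it computes, not part of the commit (x and amps stand in for the localized spike x positions and amplitudes):

    import numpy as np

    def adaptive_scales(x, amps, log_c=5.0):
        # robust spread (MAD) of the horizontal positions, in um
        mad_x = np.median(np.abs(x - np.median(x)))
        # robust spread of the log-compressed amplitudes
        log_amps = np.log(log_c + amps)
        mad_log_amps = np.median(np.abs(log_amps - np.median(log_amps)))
        # multiplying log_amps by this ratio gives it the same MAD as x;
        # x and registered depth keep unit scale
        return (1.0, 1.0, mad_x / mad_log_amps)

Presumably this avoids the fixed (1, 1, 50) amplitude weight dominating or vanishing when amplitude units vary across datasets.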
8 changes: 8 additions & 0 deletions src/dartsort/cluster/ensemble_utils.py
@@ -8,6 +8,7 @@ def forward_backward(
     chunk_sortings,
     log_c=5,
     feature_scales=(1, 1, 50),
+    adaptive_feature_scales=False,
     motion_est=None,
 ):
     """
@@ -17,6 +18,8 @@ def forward_backward(
     if len(chunk_sortings) == 1:
         return chunk_sortings[0]
 
+
+
     times_seconds = chunk_sortings[0].times_seconds
     times_samples = chunk_sortings[0].times_samples
     min_time_s = chunk_time_ranges_s[0][0]
@@ -35,6 +38,11 @@ def forward_backward(
     xyza = chunk_sortings[0].point_source_localizations
     x = xyza[:, 0]
     z_reg = xyza[:, 2]
+
+    if adaptive_feature_scales:
+        feature_scales = (1, 1, np.median(np.abs(x - np.median(x))) / np.median(np.abs(np.log(log_c + amps) - np.median(np.log(log_c + amps))))
+        )
+
     if motion_est is not None:
         z_reg = motion_est.correct_s(times_seconds, z_reg)
 
2 changes: 2 additions & 0 deletions src/dartsort/cluster/initial.py
@@ -84,6 +84,7 @@ def cluster_chunk(
         log_c=clustering_config.log_c,
         cluster_selection_epsilon=clustering_config.cluster_selection_epsilon,
         scales=clustering_config.feature_scales,
+        adaptive_feature_scales=clustering_config.adaptive_feature_scales,
         recursive=clustering_config.recursive,
         remove_duplicates=clustering_config.remove_duplicates,
     )
@@ -192,6 +193,7 @@ def ensemble_chunks(
             chunk_sortings,
             log_c=clustering_config.log_c,
             feature_scales=clustering_config.feature_scales,
+            adaptive_feature_scales=clustering_config.adaptive_feature_scales,
             motion_est=motion_est,
         )
         sorting = replace(chunk_sortings[0], labels=labels)
4 changes: 3 additions & 1 deletion src/dartsort/config.py
@@ -139,6 +139,7 @@ class TemplateConfig:
     superres_bin_size_um: float = 10.0
     superres_bin_min_spikes: int = 5
     superres_strategy: str = "drift_pitch_loc_bin"
+    adaptive_bin_size: bool = False
 
     # low rank denoising?
     low_rank_denoising: bool = True
@@ -184,8 +185,9 @@ class ClusteringConfig:
     min_samples: int = 25
     cluster_selection_epsilon: int = 1
     feature_scales: Tuple[float] = (1.0, 1.0, 50.0)
+    adaptive_feature_scales: bool = False
     log_c: float = 5.0
-    recursive: bool = True
+    recursive: bool = False
     remove_duplicates: bool = True
 
     # grid snap parameters
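
Both new flags default to False, so existing configs behave exactly as before unless opted in. A hedged usage sketch, assuming the remaining fields of these dataclasses all have defaults (as the visible ones do):

    from dartsort.config import ClusteringConfig, TemplateConfig

    # derive the amplitude feature scale from the data (MAD matching)
    # instead of using the fixed (1.0, 1.0, 50.0)
    clustering_config = ClusteringConfig(adaptive_feature_scales=True)

    # size superres template bins per unit from each unit's x spread
    template_config = TemplateConfig(adaptive_bin_size=True)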
81 changes: 66 additions & 15 deletions src/dartsort/templates/superres_util.py
@@ -14,6 +14,8 @@ def superres_sorting(
     superres_bin_size_um=10.0,
     min_spikes_per_bin=5,
     probe_margin_um=200.0,
+    spike_x_um=None,
+    adaptive_bin_size=False,
 ):
     """Construct the spatially superresolved spike train
@@ -78,6 +80,8 @@ def superres_sorting(
             pitch,
             motion_est,
             superres_bin_size_um=superres_bin_size_um,
+            spike_x_um=spike_x_um,
+            adaptive_bin_size=adaptive_bin_size,
         )
     elif strategy == "drift_pitch_loc_bin":
         superres_labels, superres_to_original = drift_pitch_loc_bin_strategy(
@@ -87,6 +91,8 @@ def superres_sorting(
             pitch,
             motion_est,
             superres_bin_size_um=superres_bin_size_um,
+            spike_x_um=spike_x_um,
+            adaptive_bin_size=adaptive_bin_size,
         )
     else:
         raise ValueError(f"Unknown superres {strategy=}")
@@ -109,6 +115,8 @@ def motion_estimate_strategy(
     pitch,
     motion_est,
     superres_bin_size_um=10.0,
+    spike_x_um=None,  # x positions of all spikes
+    adaptive_bin_size=False,
 ):
     """ """
     # reg_pos = pos - disp, pos = reg_pos + disp
@@ -118,14 +126,34 @@ def motion_estimate_strategy(
     else:
         displacements = motion_est.disp_at_s(spike_times_s, spike_depths_um)
     mod_positions = displacements % pitch
-    bin_ids = mod_positions // superres_bin_size_um
-    bin_ids = bin_ids.astype(int)
-    orig_label_and_bin, superres_labels = np.unique(
-        np.c_[original_labels, bin_ids], axis=0, return_inverse=True
-    )
-    superres_to_original = orig_label_and_bin[:, 0]
-    return superres_labels, superres_to_original
-
+
+    if not adaptive_bin_size:
+        bin_ids = mod_positions // superres_bin_size_um
+        bin_ids = bin_ids.astype(int)
+        orig_label_and_bin, superres_labels = np.unique(
+            np.c_[original_labels, bin_ids], axis=0, return_inverse=True
+        )
+        superres_to_original = orig_label_and_bin[:, 0]
+        return superres_labels, superres_to_original
+    else:
+        if spike_x_um is None:
+            raise ValueError("Adaptive bin size with unknown cluster width")
+        else:
+            superres_labels, superres_to_original = np.zeros(original_labels.shape), []
+            cmp = 0
+            for unit in np.unique(original_labels):
+                idx_unit = np.flatnonzero(original_labels == unit)
+                x_spread = np.median(np.abs(spike_x_um[idx_unit] - np.median(spike_x_um[idx_unit]))) / 0.6745
+                unit_superres_bin_size_um = np.maximum(np.round(2 * x_spread / pitch) * pitch / 2, 1)
+                bin_ids = mod_positions[idx_unit] // unit_superres_bin_size_um
+                bin_ids = bin_ids.astype(int)
+                orig_label_and_bin, superres_labels_unit = np.unique(
+                    np.c_[original_labels[idx_unit], bin_ids], axis=0, return_inverse=True
+                )
+                superres_to_original.append(orig_label_and_bin[:, 0])
+                superres_labels[idx_unit] = superres_labels_unit + cmp
+                cmp += superres_labels_unit.max() + 1
+            return superres_labels.astype(int), np.hstack(superres_to_original)
 
 def drift_pitch_loc_bin_strategy(
     original_labels,
@@ -134,19 +162,42 @@ def drift_pitch_loc_bin_strategy(
     pitch,
     motion_est,
     superres_bin_size_um=10.0,
+    spike_x_um=None,
+    adaptive_bin_size=False,
 ):
     n_pitches_shift = drift_util.get_spike_pitch_shifts(
         spike_depths_um, pitch=pitch, times_s=spike_times_s, motion_est=motion_est
     )
     coarse_reg_depths = spike_depths_um + n_pitches_shift * pitch
-    bin_ids = coarse_reg_depths // superres_bin_size_um
-    bin_ids = bin_ids.astype(int)
-    orig_label_and_bin, superres_labels = np.unique(
-        np.c_[original_labels, bin_ids], axis=0, return_inverse=True
-    )
-    superres_to_original = orig_label_and_bin[:, 0]
-    return superres_labels, superres_to_original
 
+    if not adaptive_bin_size:
+        bin_ids = coarse_reg_depths // superres_bin_size_um
+        bin_ids = bin_ids.astype(int)
+        orig_label_and_bin, superres_labels = np.unique(
+            np.c_[original_labels, bin_ids], axis=0, return_inverse=True
+        )
+        superres_to_original = orig_label_and_bin[:, 0]
+        return superres_labels, superres_to_original
+    else:
+        if spike_x_um is None:
+            raise ValueError("Adaptive bin size with unknown cluster width")
+        else:
+            superres_labels, superres_to_original = np.zeros(original_labels.shape), []
+            cmp = 0
+            for unit in np.unique(original_labels):
+                idx_unit = np.flatnonzero(original_labels == unit)
+                x_spread = np.median(np.abs(spike_x_um[idx_unit] - np.median(spike_x_um[idx_unit]))) / 0.6745
+                unit_superres_bin_size_um = np.maximum(np.round(2 * x_spread / pitch) * pitch / 2, 1)
+
+                bin_ids = coarse_reg_depths[idx_unit] // unit_superres_bin_size_um
+                bin_ids = bin_ids.astype(int)
+                orig_label_and_bin, superres_labels_unit = np.unique(
+                    np.c_[original_labels[idx_unit], bin_ids], axis=0, return_inverse=True
+                )
+                superres_to_original.append(orig_label_and_bin[:, 0])
+                superres_labels[idx_unit] = superres_labels_unit + cmp
+                cmp += superres_labels_unit.max() + 1
+            return superres_labels.astype(int), np.hstack(superres_to_original)
 
 def remove_small_superres_units(
     superres_labels, superres_to_original, min_spikes_per_bin
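
Both strategies pick each unit's bin size the same way: the unit's x spread is estimated robustly as MAD / 0.6745 (the constant that makes the MAD a consistent estimate of a Gaussian standard deviation), twice that spread is rounded to a whole number of probe pitches, halved, and floored at 1 um. A standalone sketch of just that rule, not from the commit, with a toy example:

    import numpy as np

    def unit_bin_size_um(spike_x_um, pitch):
        # robust sigma estimate: MAD / 0.6745 ~ std for Gaussian data
        x_spread = np.median(np.abs(spike_x_um - np.median(spike_x_um))) / 0.6745
        # round 2*sigma to an integer number of pitches, halve, floor at 1 um
        return np.maximum(np.round(2 * x_spread / pitch) * pitch / 2, 1)

    # toy unit on a 20 um pitch probe: MAD = 7.5 um -> sigma ~ 11.1 um,
    # 2*sigma/pitch ~ 1.11 -> 1 pitch -> bin size 10 um
    print(unit_bin_size_um(np.array([-7.5, 0.0, 7.5]), pitch=20.0))  # 10.0

Wider units thus get coarser superres bins, and compact units get finer ones, rather than every unit sharing the fixed superres_bin_size_um.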
3 changes: 3 additions & 0 deletions src/dartsort/templates/templates.py
@@ -128,6 +128,7 @@ def from_config(
     # load spike depths
     # TODO: relying on this index feels wrong
     spike_depths_um = sorting.extra_features[localizations_dataset_name][:, 2]
+    spike_x_um = sorting.extra_features[localizations_dataset_name][:, 0]
     geom = recording.get_channel_locations()
 
     kwargs = dict(
@@ -177,6 +178,8 @@ def from_config(
             strategy=template_config.superres_strategy,
             superres_bin_size_um=template_config.superres_bin_size_um,
             min_spikes_per_bin=template_config.superres_bin_min_spikes,
+            spike_x_um=spike_x_um,
+            adaptive_bin_size=template_config.adaptive_bin_size,
         )
     else:
         # we don't skip empty units
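
The new spike_x_um plumbing relies on the same column convention as the existing depth lookup (the one the TODO above flags): localizations are stored one row per spike, with x in column 0 and depth in column 2. The xyza naming elsewhere in the codebase suggests the columns are x, y, z, alpha, though only 0 and 2 are used here. A tiny self-contained sketch of that indexing, with made-up values:

    import numpy as np

    # one row per spike; columns assumed to be [x_um, y_um, z_um, alpha]
    xyza = np.array([
        [12.0, 5.0, 150.0, 0.9],
        [15.0, 4.0, 900.0, 1.1],
    ])
    spike_x_um = xyza[:, 0]       # horizontal positions: drive adaptive bin sizes
    spike_depths_um = xyza[:, 2]  # depths: drive drift handling and depth binning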
