Commit 441ddf6

Merge branch 'main' of github.com:cwindolf/spike-psvae into main

cwindolf committed Feb 18, 2024
2 parents dbb0ab6 + 0d08b2d

Showing 7 changed files with 50 additions and 46 deletions.
20 changes: 16 additions & 4 deletions src/dartsort/cluster/density.py
@@ -31,7 +31,7 @@ def kdtree_inliers(
     return inliers, kdtree


-def get_smoothed_densities(X, inliers=slice(None), sigmas=None, return_hist=False, sigma_lows=None, sigma_ramp_ax=-1, bin_sizes=None, bin_size_ratio=5.0, min_bin_size=1.0, ramp_min_bin_size=5.0):
+def get_smoothed_densities(X, inliers=slice(None), sigmas=None, return_hist=False, sigma_lows=None, sigma_ramp_ax=-1, bin_sizes=None, bin_size_ratio=5.0, min_bin_size=1.0, ramp_min_bin_size=5.0, revert=True):
     """Get RBF density estimates for each X[i] and bandwidth in sigmas

     Outliers will be marked with NaN KDEs. Please pass inliers, or else your
@@ -87,14 +87,19 @@ def get_smoothed_densities(
     elif sigma is not None and sigma_low is not None:
         # filter by a sequence of bandwidths
         ramp = np.linspace(sigma_low, sigma, num=hist.shape[sigma_ramp_ax])
+        if revert:
+            ramp[:, sigma_ramp_ax] = np.linspace(sigma[sigma_ramp_ax], sigma_low[sigma_ramp_ax], num=hist.shape[sigma_ramp_ax])
+            # ramp[:, 0] = sigma[0]
+            # ramp[:, 1] = sigma[1]
+
         # operate along the ramp axis
         hist_move = np.moveaxis(hist, sigma_ramp_ax, 0)
         hist_smoothed = hist_move.copy()
         for j, sig in enumerate(ramp):
             sig_move = sig.copy()
             sig_move[0] = sig[sigma_ramp_ax]
             sig_move[sigma_ramp_ax] = sig[0]
-            hist_smoothed[j] = gaussian_filter(hist_move, sig_move)[j]
+            hist_smoothed[j] = gaussian_filter(hist_move, sig_move)[j]  # sig_move
         hist = np.moveaxis(hist_smoothed, 0, sigma_ramp_ax)
     if return_hist:
         hists.append(hist)
@@ -189,6 +194,8 @@ def density_peaks_clustering(
     workers=1,
     return_extra=False,
     triage_quantile_per_cluster=0,
+    amp_no_triaging=12,
+    revert=True,
 ):
     n = len(X)

@@ -201,7 +208,7 @@
     )

     do_ratio = int(sigma_regional is not None)
-    density = get_smoothed_densities(X, inliers=inliers, sigmas=sigma_local, sigma_lows=sigma_local_low)
+    density = get_smoothed_densities(X, inliers=inliers, sigmas=sigma_local, sigma_lows=sigma_local_low, revert=revert)
     if do_ratio:
         reg_density = get_smoothed_densities(X, inliers=inliers, sigmas=sigma_regional, sigma_lows=sigma_regional_low)
         density = np.nan_to_num(density / reg_density)
@@ -239,8 +246,13 @@
     if triage_quantile_per_cluster>0:
         for k in np.unique(labels[labels>-1]):
             idx_label = np.flatnonzero(labels == k)
+            amp_vec = X[idx_label, 2]
+            # triage_quantile_unit = triage_quantile_per_cluster
             q = np.quantile(density[idx_label], triage_quantile_per_cluster)
-            labels[idx_label[density[idx_label]<q]] = -1
+            spikes_to_remove = np.flatnonzero(np.logical_and(
+                density[idx_label]<q, amp_vec<amp_no_triaging,
+            ))
+            labels[idx_label[spikes_to_remove]] = -1

     if not return_extra:
         return labels
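Two behavioral changes land in density.py. First, revert=True flips the direction of the bandwidth ramp along sigma_ramp_ax. A minimal sketch of what the ramp construction does, with toy bandwidth values (not the library's defaults):

import numpy as np

sigma = np.array([3.0, 10.0])     # high bandwidth per feature dimension
sigma_low = np.array([1.0, 2.0])  # low bandwidth per feature dimension
n, ramp_ax = 5, -1                # histogram size along the ramp axis

ramp = np.linspace(sigma_low, sigma, num=n)  # (n, 2): low -> high in every dim
# with revert=True, the ramp-axis component runs high -> low instead,
# while the other dimensions still ramp low -> high:
ramp[:, ramp_ax] = np.linspace(sigma[ramp_ax], sigma_low[ramp_ax], num=n)

Second, the per-cluster triage is now amplitude-gated: a spike must sit below its cluster's density quantile and below amp_no_triaging to be sent to the noise label. A self-contained sketch of the rule (it assumes, as the hunk does, that amplitude is the third column of X):

import numpy as np

def triage_sketch(X, labels, density, quantile=0.1, amp_no_triaging=12.0):
    """Mark low-density, low-amplitude spikes in each cluster as noise (-1)."""
    labels = labels.copy()
    for k in np.unique(labels[labels > -1]):
        idx = np.flatnonzero(labels == k)
        q = np.quantile(density[idx], quantile)
        # only spikes failing BOTH tests are triaged; large spikes survive
        drop = np.flatnonzero((density[idx] < q) & (X[idx, 2] < amp_no_triaging))
        labels[idx[drop]] = -1
    return labels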
5 changes: 3 additions & 2 deletions src/dartsort/cluster/forward_backward.py
@@ -11,6 +11,7 @@ def forward_backward(
     adaptive_feature_scales=False,
     motion_est=None,
     verbose=True,
+    min_cluster_size=25,
 ):
     """
     Ensemble over HDBSCAN clustering
@@ -92,8 +93,8 @@
     dist_matrix = np.zeros((units_1.shape[0], units_2.shape[0]))
     for i, unit_1 in enumerate(units_1):
         for j, unit_2 in enumerate(units_2):
-            idxunit1 = np.flatnonzero(labels_1 == unit_1)
-            idxunit2 = np.flatnonzero(labels_2 == unit_2)
+            idxunit1 = np.flatnonzero(labels_1 == unit_1)[-min_cluster_size:]
+            idxunit2 = np.flatnonzero(labels_2 == unit_2)[:min_cluster_size]
             feat_1 = np.median(features1[idxunit1], axis=0)
             feat_2 = np.median(features2[idxunit2], axis=0)
             dist_matrix[i, j] = ((feat_1 - feat_2) ** 2).sum()
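The effect of the new slicing: instead of taking medians over every spike in each unit, the chunk-linking distance now compares at most min_cluster_size spikes adjacent to the chunk boundary — the last spikes of the earlier unit against the first spikes of the later one. A sketch of the idea as a standalone helper (boundary_distance is hypothetical, not a function in this module), under the assumption that spike indices are time-ordered within each chunk:

import numpy as np

def boundary_distance(features1, labels_1, unit_1,
                      features2, labels_2, unit_2, min_cluster_size=25):
    # spikes at the end of chunk 1 vs. spikes at the start of chunk 2
    idx1 = np.flatnonzero(labels_1 == unit_1)[-min_cluster_size:]
    idx2 = np.flatnonzero(labels_2 == unit_2)[:min_cluster_size]
    feat_1 = np.median(features1[idx1], axis=0)
    feat_2 = np.median(features2[idx2], axis=0)
    return ((feat_1 - feat_2) ** 2).sum()

This makes the match more robust to slow drift within long chunks, since only the features closest in time to the boundary vote on the link.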
1 change: 1 addition & 0 deletions src/dartsort/cluster/initial.py
@@ -116,6 +116,7 @@ def cluster_chunk(
         remove_clusters_smaller_than=clustering_config.remove_clusters_smaller_than,
         noise_density=clustering_config.noise_density,
         triage_quantile_per_cluster=clustering_config.triage_quantile_per_cluster,
+        revert=clustering_config.triage_quantile_per_cluster,
         workers=4,
         return_extra=clustering_config.attach_density_feature,
     )
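Note that, as committed, revert is forwarded clustering_config.triage_quantile_per_cluster — a float — rather than the new boolean revert field added to ClusteringConfig in this same commit; since get_smoothed_densities only evaluates if revert:, any nonzero quantile enables the reverted ramp. A one-line illustration of that truthiness:

assert bool(0.1) is True   # nonzero quantile -> reverted ramp enabled
assert bool(0.0) is False  # zero quantile -> reverted ramp disabled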
22 changes: 12 additions & 10 deletions src/dartsort/cluster/split.py
@@ -181,6 +181,7 @@ def __init__(
         min_samples=25,
         cluster_selection_epsilon=25,
         sigma_local=5,
+        sigma_local_low=None,
         noise_density=0.1,
         n_neighbors_search=20,
         radius_search=10,
@@ -247,6 +248,7 @@ def __init__(

         #DPC parameters
         self.sigma_local=sigma_local
+        self.sigma_local_low=sigma_local_low
         self.noise_density=noise_density
         self.n_neighbors_search=n_neighbors_search
         self.radius_search=radius_search
@@ -272,7 +274,7 @@ def __init__(
             self.chunk_size_s is not None
         ), "Need to input chunk size for ensembling over chunks"

-        assert np.isin(cluster_alg, ["hdbscan", "dpc"])
+        assert np.isin(cluster_alg, ["hdbscan", "dpc"]), "cluster_alg needs to be hdbscan or dpc"
         self.cluster_alg = cluster_alg

         # load up the required h5 datasets
@@ -354,7 +356,7 @@ def split_cluster_chunks(self, in_unit):
             split_sortings,
             log_c=self.log_c,
             feature_scales=self.localization_feature_scales,
-            adaptive_feature_scales=False, # self.rescale_all_features,
+            adaptive_feature_scales=self.rescale_all_features,
             motion_est=self.motion_est,
             verbose=False,
         )
@@ -399,9 +401,8 @@ def split_cluster(self, in_unit_all):
         else:
             kept = np.arange(in_unit.shape[0])

-        self.spread_feature(in_unit)
-
         features = []
+
         if self.use_localization_features:
             loc_features = self.localization_features[in_unit]
             if self.relocated:
@@ -446,7 +447,7 @@

         features = np.column_stack([f[kept] for f in features])

-        if self.cluster_alg == "hdbscan":
+        if self.cluster_alg == "hdbscan" and features.shape[0]>self.min_cluster_size:
             clust = HDBSCAN(
                 min_cluster_size=self.min_cluster_size,
                 min_samples=self.min_samples,
@@ -455,21 +456,22 @@
                 prediction_data=self.reassign_outliers,
             )
             hdb_labels = clust.fit_predict(features)
-        elif self.cluster_alg == "dpc":
+            is_split = np.setdiff1d(np.unique(hdb_labels), [-1]).size > 1
+        elif self.cluster_alg == "dpc" and features.shape[0]>self.remove_clusters_smaller_than:
             hdb_labels = density.density_peaks_clustering(
                 features,
                 sigma_local=self.sigma_local,
+                sigma_local_low=self.sigma_local_low,
                 sigma_regional=None,
                 noise_density=self.noise_density,
                 n_neighbors_search=self.n_neighbors_search,
                 radius_search=self.radius_search,
                 triage_quantile_per_cluster=self.triage_quantile_per_cluster,
                 remove_clusters_smaller_than=self.remove_clusters_smaller_than,
             )
-        else:
-            print("cluster_alg needs to be hdbscan or dpc")
-
-        is_split = np.setdiff1d(np.unique(hdb_labels), [-1]).size > 1
+            is_split = np.setdiff1d(np.unique(hdb_labels), [-1]).size > 1
+        else:
+            is_split=False

         if is_split and self.reassign_outliers:
             hdb_labels = cluster_util.knn_reassign_outliers(hdb_labels, features)
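Pulled out of the diff context, the new control flow in split_cluster reads as below — a condensed sketch, not the full method: HDBSCAN is hdbscan's estimator, and the density.density_peaks_clustering call keeps only the arguments shown in the hunk.

if self.cluster_alg == "hdbscan" and features.shape[0] > self.min_cluster_size:
    hdb_labels = HDBSCAN(min_cluster_size=self.min_cluster_size).fit_predict(features)
    is_split = np.setdiff1d(np.unique(hdb_labels), [-1]).size > 1
elif self.cluster_alg == "dpc" and features.shape[0] > self.remove_clusters_smaller_than:
    hdb_labels = density.density_peaks_clustering(
        features, sigma_local=self.sigma_local, sigma_local_low=self.sigma_local_low,
    )
    is_split = np.setdiff1d(np.unique(hdb_labels), [-1]).size > 1
else:
    is_split = False  # too few spikes for either algorithm: report no split

Units too small to cluster no longer reach either estimator, and the print branch is gone — cluster_alg is already validated by the strengthened assert in __init__.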
1 change: 1 addition & 0 deletions src/dartsort/config.py
@@ -244,6 +244,7 @@ class ClusteringConfig:
     noise_density: float = 0.0
     attach_density_feature: bool = False
     triage_quantile_per_cluster: float = 0.0
+    revert: bool = False

     # -- ensembling
     ensemble_strategy: Optional[str] = "forward_backward"
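A hypothetical call site for the new field (field names from the diff; this assumes the remaining ClusteringConfig fields keep their defaults):

from dartsort.config import ClusteringConfig

cfg = ClusteringConfig(
    triage_quantile_per_cluster=0.1,  # triage the bottom 10% by density
    revert=True,                      # reverse the sigma ramp axis in the KDE
)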
6 changes: 3 additions & 3 deletions src/dartsort/templates/templates.py
@@ -182,7 +182,7 @@ def from_config(

         # handle superresolved templates
         if template_config.superres_templates:
-            unit_ids, superres_sorting = superres_sorting(
+            unit_ids, superres_sort = superres_sorting(
                 sorting,
                 sorting.times_seconds,
                 spike_depths_um,
@@ -195,7 +195,7 @@
                 adaptive_bin_size=template_config.adaptive_bin_size,
             )
         else:
-            superres_sorting = sorting
+            superres_sort = sorting
             # we don't skip empty units
             unit_ids = np.arange(sorting.labels.max() + 1)

@@ -205,7 +205,7 @@
         spike_counts[ix[ix >= 0]] = counts[ix >= 0]

         # main!
-        results = get_templates(recording, superres_sorting, **kwargs)
+        results = get_templates(recording, superres_sort, **kwargs)

         # handle registered templates
         if template_config.registered_templates and motion_est is not None:
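The rename from superres_sorting to superres_sort matters because the original code assigned to the same name it was calling. A minimal reproduction with generic names: assigning to a name anywhere in a function body makes that name local to the whole function, so the call raises UnboundLocalError before the assignment ever runs.

def double(x):
    return 2 * x

def broken():
    y = double(3)  # UnboundLocalError: 'double' is local because of the next line
    double = y
    return double

def fixed():
    y = double(3)  # fine: 'double' is only read, never assigned, in this scope
    dbl = y
    return dbl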
41 changes: 14 additions & 27 deletions src/spike_psvae/cluster_viz.py
@@ -113,9 +113,7 @@ def array_scatter_5_features(
     x,
     z,
     maxptp,
-    trough_val,
-    trough_time,
-    tip_val,
+    density,
     zlim=(-50, 3900),
     xlim=None,
     ptplim=None,
@@ -129,7 +127,7 @@
 ):
     fig = None
     if axes is None:
-        fig, axes = plt.subplots(1, 6, sharey=True, figsize=figsize)
+        fig, axes = plt.subplots(1, 5, sharey=True, figsize=figsize)

     if title is not None:
         fig.suptitle(title)
@@ -161,7 +159,7 @@
     axes[1].scatter(
         x,
         z,
-        c=maxptp_c,
+        c=np.exp(maxptp_c/50)-5,
         s=s_dot,
         alpha=0.1,
         marker=".",
@@ -180,58 +178,47 @@
         excluded_ids=excluded_ids,
         do_ellipse=do_ellipse,
     )
-    axes[2].set_xlabel("Amplitude (s.u.)", fontsize=16)
+    axes[2].set_xlabel("50*log(5+ptp) (s.u.)", fontsize=16)
     axes[2].tick_params(axis='x', labelsize=16)

     cluster_scatter(
-        trough_val,
+        density,
         z,
         labels,
         ax=axes[3],
         s=s_dot,
         alpha=0.1,
         excluded_ids=excluded_ids,
-        do_ellipse=do_ellipse,
+        do_ellipse=False,
     )
     axes[3].set_xlabel("Trough Val (s.u.)", fontsize=16)
     axes[3].tick_params(axis='x', labelsize=16)

-    cluster_scatter(
-        trough_time,
+    axes[4].scatter(
+        x,
         z,
-        labels,
-        ax=axes[4],
+        c=density,
         s=s_dot,
         alpha=0.1,
-        excluded_ids=excluded_ids,
-        do_ellipse=do_ellipse,
+        marker=".",
+        cmap=plt.cm.jet,
     )
-    axes[4].set_xlabel("Trough Time (1/30ms)", fontsize=16)
     axes[4].tick_params(axis='x', labelsize=16)
+    axes[4].scatter(*geom.T, c="orange", marker="s", s=s_size_geom)
+    axes[4].set_title("colored by ptps")

-    cluster_scatter(
-        tip_val,
-        z,
-        labels,
-        ax=axes[5],
-        s=s_dot,
-        alpha=0.1,
-        excluded_ids=excluded_ids,
-        do_ellipse=do_ellipse,
-    )
-    axes[5].set_xlabel("Tip Val (s.u.)", fontsize=16)
-    axes[5].tick_params(axis='x', labelsize=16)

     axes[0].set_ylim(zlim)
     axes[1].set_ylim(zlim)
     axes[2].set_ylim(zlim)
     axes[3].set_ylim(zlim)
     axes[4].set_ylim(zlim)
     axes[5].set_ylim(zlim)

     if xlim is not None:
         axes[0].set_xlim(xlim)
         axes[1].set_xlim(xlim)
+        axes[4].set_xlim(xlim)
     if ptplim is not None:
         axes[2].set_xlim(ptplim)
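The new color argument undoes the log-compression of amplitudes before coloring the scatter: if maxptp_c stores 50*log(5+ptp), as the updated axes[2] label indicates (natural log assumed here), then np.exp(maxptp_c/50) - 5 recovers the raw PTP. A quick round-trip check:

import numpy as np

ptp = np.array([1.0, 6.0, 25.0])       # toy peak-to-peak amplitudes (s.u.)
maxptp_c = 50 * np.log(5 + ptp)        # the compressed form per the x-label
recovered = np.exp(maxptp_c / 50) - 5  # the expression passed as c= above
assert np.allclose(recovered, ptp)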
