diff --git a/src/dartsort/cluster/density.py b/src/dartsort/cluster/density.py index 19066795..84df35e3 100644 --- a/src/dartsort/cluster/density.py +++ b/src/dartsort/cluster/density.py @@ -31,7 +31,7 @@ def kdtree_inliers( return inliers, kdtree -def get_smoothed_densities(X, inliers=slice(None), sigmas=None, return_hist=False, sigma_lows=None, sigma_ramp_ax=-1, bin_sizes=None, bin_size_ratio=5.0, min_bin_size=1.0, ramp_min_bin_size=5.0): +def get_smoothed_densities(X, inliers=slice(None), sigmas=None, return_hist=False, sigma_lows=None, sigma_ramp_ax=-1, bin_sizes=None, bin_size_ratio=5.0, min_bin_size=1.0, ramp_min_bin_size=5.0, revert=True): """Get RBF density estimates for each X[i] and bandwidth in sigmas Outliers will be marked with NaN KDEs. Please pass inliers, or else your @@ -87,6 +87,11 @@ def get_smoothed_densities(X, inliers=slice(None), sigmas=None, return_hist=Fals elif sigma is not None and sigma_low is not None: # filter by a sequence of bandwidths ramp = np.linspace(sigma_low, sigma, num=hist.shape[sigma_ramp_ax]) + if revert: + ramp[:, sigma_ramp_ax] = np.linspace(sigma[sigma_ramp_ax], sigma_low[sigma_ramp_ax], num=hist.shape[sigma_ramp_ax]) + # ramp[:, 0] = sigma[0] + # ramp[:, 1] = sigma[1] + # operate along the ramp axis hist_move = np.moveaxis(hist, sigma_ramp_ax, 0) hist_smoothed = hist_move.copy() @@ -94,7 +99,7 @@ def get_smoothed_densities(X, inliers=slice(None), sigmas=None, return_hist=Fals sig_move = sig.copy() sig_move[0] = sig[sigma_ramp_ax] sig_move[sigma_ramp_ax] = sig[0] - hist_smoothed[j] = gaussian_filter(hist_move, sig_move)[j] + hist_smoothed[j] = gaussian_filter(hist_move, sig_move)[j] #sig_move hist = np.moveaxis(hist_smoothed, 0, sigma_ramp_ax) if return_hist: hists.append(hist) @@ -189,6 +194,8 @@ def density_peaks_clustering( workers=1, return_extra=False, triage_quantile_per_cluster=0, + amp_no_triaging=12, + revert=True, ): n = len(X) @@ -201,7 +208,7 @@ def density_peaks_clustering( ) do_ratio = int(sigma_regional is not None) - density = get_smoothed_densities(X, inliers=inliers, sigmas=sigma_local, sigma_lows=sigma_local_low) + density = get_smoothed_densities(X, inliers=inliers, sigmas=sigma_local, sigma_lows=sigma_local_low, revert=revert) if do_ratio: reg_density = get_smoothed_densities(X, inliers=inliers, sigmas=sigma_regional, sigma_lows=sigma_regional_low) density = np.nan_to_num(density / reg_density) @@ -239,8 +246,13 @@ def density_peaks_clustering( if triage_quantile_per_cluster>0: for k in np.unique(labels[labels>-1]): idx_label = np.flatnonzero(labels == k) + amp_vec = X[idx_label, 2] + # triage_quantile_unit = triage_quantile_per_cluster q = np.quantile(density[idx_label], triage_quantile_per_cluster) - labels[idx_label[density[idx_label]self.min_cluster_size: clust = HDBSCAN( min_cluster_size=self.min_cluster_size, min_samples=self.min_samples, @@ -455,10 +456,12 @@ def split_cluster(self, in_unit_all): prediction_data=self.reassign_outliers, ) hdb_labels = clust.fit_predict(features) - elif self.cluster_alg == "dpc": + is_split = np.setdiff1d(np.unique(hdb_labels), [-1]).size > 1 + elif self.cluster_alg == "dpc" and features.shape[0]>self.remove_clusters_smaller_than: hdb_labels = density.density_peaks_clustering( features, sigma_local=self.sigma_local, + sigma_local_low=self.sigma_local_low, sigma_regional=None, noise_density=self.noise_density, n_neighbors_search=self.n_neighbors_search, @@ -466,10 +469,9 @@ def split_cluster(self, in_unit_all): triage_quantile_per_cluster=self.triage_quantile_per_cluster, remove_clusters_smaller_than=self.remove_clusters_smaller_than, ) - else: - print("cluster_alg needs to be hdbscan or dpc") - - is_split = np.setdiff1d(np.unique(hdb_labels), [-1]).size > 1 + is_split = np.setdiff1d(np.unique(hdb_labels), [-1]).size > 1 + else: + is_split=False if is_split and self.reassign_outliers: hdb_labels = cluster_util.knn_reassign_outliers(hdb_labels, features) diff --git a/src/dartsort/config.py b/src/dartsort/config.py index 587a2b83..7cf2baf2 100644 --- a/src/dartsort/config.py +++ b/src/dartsort/config.py @@ -244,6 +244,7 @@ class ClusteringConfig: noise_density: float = 0.0 attach_density_feature: bool = False triage_quantile_per_cluster: float = 0.0 + revert: bool = False # -- ensembling ensemble_strategy: Optional[str] = "forward_backward" diff --git a/src/dartsort/templates/templates.py b/src/dartsort/templates/templates.py index 939cfd13..6dc5b08c 100644 --- a/src/dartsort/templates/templates.py +++ b/src/dartsort/templates/templates.py @@ -182,7 +182,7 @@ def from_config( # handle superresolved templates if template_config.superres_templates: - unit_ids, superres_sorting = superres_sorting( + unit_ids, superres_sort = superres_sorting( sorting, sorting.times_seconds, spike_depths_um, @@ -195,7 +195,7 @@ def from_config( adaptive_bin_size=template_config.adaptive_bin_size, ) else: - superres_sorting = sorting + superres_sort = sorting # we don't skip empty units unit_ids = np.arange(sorting.labels.max() + 1) @@ -205,7 +205,7 @@ def from_config( spike_counts[ix[ix >= 0]] = counts[ix >= 0] # main! - results = get_templates(recording, superres_sorting, **kwargs) + results = get_templates(recording, superres_sort, **kwargs) # handle registered templates if template_config.registered_templates and motion_est is not None: diff --git a/src/spike_psvae/cluster_viz.py b/src/spike_psvae/cluster_viz.py index 13e6b652..fdbcf9ae 100644 --- a/src/spike_psvae/cluster_viz.py +++ b/src/spike_psvae/cluster_viz.py @@ -113,9 +113,7 @@ def array_scatter_5_features( x, z, maxptp, - trough_val, - trough_time, - tip_val, + density, zlim=(-50, 3900), xlim=None, ptplim=None, @@ -129,7 +127,7 @@ def array_scatter_5_features( ): fig = None if axes is None: - fig, axes = plt.subplots(1, 6, sharey=True, figsize=figsize) + fig, axes = plt.subplots(1, 5, sharey=True, figsize=figsize) if title is not None: fig.suptitle(title) @@ -161,7 +159,7 @@ def array_scatter_5_features( axes[1].scatter( x, z, - c=maxptp_c, + c=np.exp(maxptp_c/50)-5, s=s_dot, alpha=0.1, marker=".", @@ -180,58 +178,47 @@ def array_scatter_5_features( excluded_ids=excluded_ids, do_ellipse=do_ellipse, ) - axes[2].set_xlabel("Amplitude (s.u.)", fontsize=16) + axes[2].set_xlabel("50*log(5+ptp) (s.u.)", fontsize=16) axes[2].tick_params(axis='x', labelsize=16) cluster_scatter( - trough_val, + density, z, labels, ax=axes[3], s=s_dot, alpha=0.1, excluded_ids=excluded_ids, - do_ellipse=do_ellipse, + do_ellipse=False, ) axes[3].set_xlabel("Trough Val (s.u.)", fontsize=16) axes[3].tick_params(axis='x', labelsize=16) - cluster_scatter( - trough_time, + axes[4].scatter( + x, z, - labels, - ax=axes[4], + c=density, s=s_dot, alpha=0.1, - excluded_ids=excluded_ids, - do_ellipse=do_ellipse, + marker=".", + cmap=plt.cm.jet, ) axes[4].set_xlabel("Trough Time (1/30ms)", fontsize=16) axes[4].tick_params(axis='x', labelsize=16) + axes[4].scatter(*geom.T, c="orange", marker="s", s=s_size_geom) + axes[4].set_title("colored by ptps") - cluster_scatter( - tip_val, - z, - labels, - ax=axes[5], - s=s_dot, - alpha=0.1, - excluded_ids=excluded_ids, - do_ellipse=do_ellipse, - ) - axes[5].set_xlabel("Tip Val (s.u.)", fontsize=16) - axes[5].tick_params(axis='x', labelsize=16) axes[0].set_ylim(zlim) axes[1].set_ylim(zlim) axes[2].set_ylim(zlim) axes[3].set_ylim(zlim) axes[4].set_ylim(zlim) - axes[5].set_ylim(zlim) if xlim is not None: axes[0].set_xlim(xlim) axes[1].set_xlim(xlim) + axes[4].set_xlim(xlim) if ptplim is not None: axes[2].set_xlim(ptplim)