From dbb0ab60dc60dd70910ba94ca90a1f4f04b625e0 Mon Sep 17 00:00:00 2001 From: Charlie Windolf Date: Sun, 18 Feb 2024 12:17:25 -0500 Subject: [PATCH] Re-enable scaling in merge; multiprocessing QoL; scatterplots QoL --- src/dartsort/cluster/merge.py | 30 +++++----- src/dartsort/templates/pairwise_util.py | 72 ++++++++++++++--------- src/dartsort/util/analysis.py | 3 +- src/dartsort/util/multiprocessing_util.py | 9 ++- src/dartsort/vis/scatterplots.py | 27 ++++----- src/dartsort/vis/unit.py | 4 +- 6 files changed, 82 insertions(+), 63 deletions(-) diff --git a/src/dartsort/cluster/merge.py b/src/dartsort/cluster/merge.py index 802228be..baf88c8a 100644 --- a/src/dartsort/cluster/merge.py +++ b/src/dartsort/cluster/merge.py @@ -24,9 +24,9 @@ def merge_templates( sym_function=np.minimum, merge_distance_threshold=0.25, temporal_upsampling_factor=8, - amplitude_scaling_variance=0.0, - amplitude_scaling_boundary=0.5, - svd_compression_rank=10, + amplitude_scaling_variance=0.001, + amplitude_scaling_boundary=0.1, + svd_compression_rank=20, min_channel_amplitude=1.0, min_spatial_cosine=0.0, conv_batch_size=128, @@ -115,9 +115,9 @@ def merge_across_sortings( sym_function=np.minimum, max_shift_samples=20, temporal_upsampling_factor=8, - amplitude_scaling_variance=0.0, - amplitude_scaling_boundary=0.5, - svd_compression_rank=10, + amplitude_scaling_variance=0.001, + amplitude_scaling_boundary=0.1, + svd_compression_rank=20, min_channel_amplitude=0.0, min_spatial_cosine=0.0, conv_batch_size=128, @@ -213,9 +213,9 @@ def calculate_merge_distances( sym_function=np.minimum, max_shift_samples=20, temporal_upsampling_factor=8, - amplitude_scaling_variance=0.0, - amplitude_scaling_boundary=0.5, - svd_compression_rank=10, + amplitude_scaling_variance=0.001, + amplitude_scaling_boundary=0.1, + svd_compression_rank=20, min_channel_amplitude=1.0, min_spatial_cosine=0.0, cooccurrence_mask=None, @@ -299,9 +299,9 @@ def cross_match_distance_matrix( sym_function=np.minimum, max_shift_samples=20, temporal_upsampling_factor=8, - amplitude_scaling_variance=0.0, - amplitude_scaling_boundary=0.5, - svd_compression_rank=10, + amplitude_scaling_variance=0.001, + amplitude_scaling_boundary=0.1, + svd_compression_rank=20, min_channel_amplitude=0.0, min_spatial_cosine=0.0, conv_batch_size=128, @@ -482,9 +482,9 @@ def get_deconv_resid_norm_iter( template_data, max_shift_samples=20, temporal_upsampling_factor=8, - amplitude_scaling_variance=0.0, - amplitude_scaling_boundary=0.5, - svd_compression_rank=10, + amplitude_scaling_variance=0.001, + amplitude_scaling_boundary=0.1, + svd_compression_rank=20, min_channel_amplitude=0.0, min_spatial_cosine=0.0, cooccurrence_mask=None, diff --git a/src/dartsort/templates/pairwise_util.py b/src/dartsort/templates/pairwise_util.py index a6870e85..c8b7a6de 100644 --- a/src/dartsort/templates/pairwise_util.py +++ b/src/dartsort/templates/pairwise_util.py @@ -323,8 +323,9 @@ class DeconvResidResult: def conv_to_resid( - template_data_a: templates.TemplateData, - template_data_b: templates.TemplateData, + # template_data_a: templates.TemplateData, + low_rank_templates_a: template_util.LowRankTemplates, + low_rank_templates_b: template_util.LowRankTemplates, conv_result: CompressedConvResult, amplitude_scaling_variance=0.0, amplitude_scaling_boundary=0.5, @@ -344,10 +345,19 @@ def conv_to_resid( deconv_resid_norms = np.zeros(n_pairs) shifts = np.zeros(n_pairs, dtype=int) template_indices_a, template_indices_b = pairs.T - templates_a = template_data_a.templates[template_indices_a] - templates_b = 
template_data_b.templates[template_indices_b] - template_a_norms = np.linalg.norm(templates_a, axis=(1, 2)) ** 2 - template_b_norms = np.linalg.norm(templates_b, axis=(1, 2)) ** 2 + + # templates_a = template_data_a.templates[template_indices_a] + # template_a_norms = np.linalg.norm(templates_a, axis=(1, 2)) ** 2 + # templates_b = template_data_b.templates[template_indices_b] + # template_b_norms = np.linalg.norm(templates_b, axis=(1, 2)) ** 2 + + # low rank template norms + svs_a = low_rank_templates_a.singular_values[template_indices_a] + template_a_norms = torch.square(svs_a).sum(1).numpy(force=True) + svs_b = low_rank_templates_b.singular_values[template_indices_b] + template_b_norms = torch.square(svs_b).sum(1).numpy(force=True) + + # now, compute reduction in norm of A after matching by B for j, (ix_a, ix_b) in enumerate(pairs): in_a = conv_result.template_indices_a == ix_a in_b = conv_result.template_indices_b == ix_b @@ -361,18 +371,18 @@ def conv_to_resid( # figure out scaling if amplitude_scaling_variance: - amp_scale_min = 1. / (1. + amplitude_scaling_boundary) - amp_scale_max = 1. + amplitude_scaling_boundary - inv_lambda = 1. / amplitude_scaling_variance + amp_scale_min = 1.0 / (1.0 + amplitude_scaling_boundary) + amp_scale_max = 1.0 + amplitude_scaling_boundary + inv_lambda = 1.0 / amplitude_scaling_variance b = best_conv + inv_lambda - a = template_a_norms[j] + inv_lambda - scaling = np.clip(b / a, amp_scale_min, amp_scale_max) - norm_reduction = 2. * scaling * b - np.square(scaling) * a - inv_lambda + a = template_b_norms[j] + inv_lambda + scaling = (b / a).clip(amp_scale_min, amp_scale_max) + norm_reduction = 2.0 * scaling * b - np.square(scaling) * a - inv_lambda else: - norm_reduction = 2. * best_conv - template_b_norms[j] + norm_reduction = 2.0 * best_conv - template_b_norms[j] + deconv_resid_norms[j] = template_a_norms[j] - norm_reduction - - assert deconv_resid_norms[j] >= -1e-1 + assert (deconv_resid_norms >= -0.01).all() return DeconvResidResult( template_indices_a, @@ -482,7 +492,7 @@ def compressed_convolve_pairs( compression_index, conv_ix, conv_upsampling_indices_b, - conv_temporal_components_up_b, #Need to change this conv_temporal_components_up_b[conv_compressed_upsampled_ix_b] + conv_temporal_components_up_b, # Need to change this conv_temporal_components_up_b[conv_compressed_upsampled_ix_b] conv_compressed_upsampled_ix_b, compression_dup_ix, ) = compressed_upsampled_pairs( @@ -554,8 +564,8 @@ def compressed_convolve_pairs( ) if reduce_deconv_resid_norm: return conv_to_resid( - template_data_a, - template_data_b, + low_rank_templates_a, + low_rank_templates_b, res, amplitude_scaling_variance=amplitude_scaling_variance, amplitude_scaling_boundary=amplitude_scaling_boundary, @@ -630,7 +640,9 @@ def correlate_pairs_lowrank( ix = slice(istart, iend) # want conv filter: nco, 1, rank, t - template_a = torch.bmm(temporal_a[ix_a[conv_ix][ix]], spatial_a[ix_a[conv_ix][ix]]) + template_a = torch.bmm( + temporal_a[ix_a[conv_ix][ix]], spatial_a[ix_a[conv_ix][ix]] + ) conv_filt = torch.bmm(spatial_b[ix_b[conv_ix][ix]], template_a.mT) conv_filt = conv_filt[:, None] # (nco, 1, rank, t) @@ -982,7 +994,15 @@ def compressed_upsampled_pairs( # temp_comps = compressed_upsampled_temporal.compressed_upsampled_templates[ # np.atleast_1d(temp_ix_b[ix_b[conv_ix]]) # ] - return ix_b, compression_index, conv_ix, upinds, compressed_upsampled_temporal.compressed_upsampled_templates, np.atleast_1d(temp_ix_b[ix_b[conv_ix]]), compression_dup_ix + return ( + ix_b, + compression_index, + 
conv_ix, + upinds, + compressed_upsampled_temporal.compressed_upsampled_templates, + np.atleast_1d(temp_ix_b[ix_b[conv_ix]]), + compression_dup_ix, + ) # each conv_ix needs to be duplicated as many times as its b template has # upsampled copies @@ -993,9 +1013,7 @@ def compressed_upsampled_pairs( ) conv_up_i, up_shift_up_i = np.nonzero(upsampling_mask) conv_compressed_upsampled_ix = ( - upsampled_shifted_template_index.up_shift_temp_ix_to_comp_up_ix[ - up_shift_up_i - ] + upsampled_shifted_template_index.up_shift_temp_ix_to_comp_up_ix[up_shift_up_i] ) conv_dup = conv_ix[conv_up_i] # And, all ix_{a,b}[i] such that compression_ix[i] lands in @@ -1005,9 +1023,9 @@ def compressed_upsampled_pairs( dup_mask = dup_mask.numpy(force=True) compression_dup_ix, compression_index_up = np.nonzero(dup_mask) ix_b_up = ix_b[compression_dup_ix] - + # the conv ix need to be offset to keep the relation with the pairs - # ix_a[old i] + # ix_a[old i] # offsets = np.cumsum((conv_ix[:, None] == conv_dup[None, :]).sum(0)) # offsets -= offsets[0] _, offsets = np.unique(compression_dup_ix, return_index=True) @@ -1019,9 +1037,9 @@ def compressed_upsampled_pairs( conv_compressed_upsampled_ix ] ) - + # conv_temporal_components_up_b = compressed_upsampled_temporal.compressed_upsampled_templates - + return ( ix_b_up, compression_index_up, diff --git a/src/dartsort/util/analysis.py b/src/dartsort/util/analysis.py index 45ec288c..381f2428 100644 --- a/src/dartsort/util/analysis.py +++ b/src/dartsort/util/analysis.py @@ -580,7 +580,8 @@ def nearby_coarse_templates(self, unit_id, n_neighbors=5): unit_ix = np.searchsorted(self.unit_ids, unit_id) unit_dists = self.merge_dist[unit_ix] distance_order = np.argsort(unit_dists) - assert distance_order[0] == unit_ix + distance_order = np.concatenate(([unit_ix], distance_order[distance_order != unit_ix])) + # assert distance_order[0] == unit_ix neighbor_ixs = distance_order[:n_neighbors] neighbor_ids = self.unit_ids[neighbor_ixs] neighbor_dists = self.merge_dist[neighbor_ixs[:, None], neighbor_ixs[None, :]] diff --git a/src/dartsort/util/multiprocessing_util.py b/src/dartsort/util/multiprocessing_util.py index efc8a691..f88e4804 100644 --- a/src/dartsort/util/multiprocessing_util.py +++ b/src/dartsort/util/multiprocessing_util.py @@ -6,8 +6,10 @@ try: import cloudpickle + have_cloudpickle = True except ImportError: pass + have_cloudpickle = False class MockFuture: @@ -69,12 +71,17 @@ def submit(self, fn, /, *args, **kwargs): def get_pool( - n_jobs, context="spawn", cls=ProcessPoolExecutor, with_rank_queue=False + n_jobs, + context="spawn", + cls=ProcessPoolExecutor, + with_rank_queue=False, ): if n_jobs == -1: n_jobs = multiprocessing.cpu_count() do_parallel = n_jobs >= 1 n_jobs = max(1, n_jobs) + if cls == CloudpicklePoolExecutor and not have_cloudpickle: + cls = ProcessPoolExecutor Executor = cls if do_parallel else MockPoolExecutor context = get_context(context) if with_rank_queue: diff --git a/src/dartsort/vis/scatterplots.py b/src/dartsort/vis/scatterplots.py index bac13368..0dd28aec 100644 --- a/src/dartsort/vis/scatterplots.py +++ b/src/dartsort/vis/scatterplots.py @@ -486,29 +486,22 @@ def scatter_feature_vs_depth( c = np.clip(amplitudes, 0, amplitude_color_cutoff) cmap = amplitude_cmap kept = slice(None) - elif show_triaged: - c = labels - # cmap = colorcet.m_glasbey_light - cmap = glasbey1024 - print("quack") - c = cmap[c % len(cmap)] - kept = labels[to_show] >= 0 - ax.scatter( - feature[to_show[~kept]], - depths_um[to_show[~kept]], - color="dimgray", - s=s, - 
linewidth=linewidth, - rasterized=rasterized, - **scatter_kw, - ) else: c = labels # cmap = colorcet.m_glasbey_light cmap = glasbey1024 - print("quack") c = cmap[c % len(cmap)] kept = labels[to_show] >= 0 + if show_triaged: + ax.scatter( + feature[to_show[~kept]], + depths_um[to_show[~kept]], + color="dimgray", + s=s, + linewidth=linewidth, + rasterized=rasterized, + **scatter_kw, + ) s = ax.scatter( feature[to_show[kept]], diff --git a/src/dartsort/vis/unit.py b/src/dartsort/vis/unit.py index 1f2a6c47..d3d04fe4 100644 --- a/src/dartsort/vis/unit.py +++ b/src/dartsort/vis/unit.py @@ -17,7 +17,7 @@ import numpy as np from matplotlib.legend_handler import HandlerTuple -from ..util.multiprocessing_util import get_pool +from ..util.multiprocessing_util import get_pool, CloudpicklePoolExecutor from .waveforms import geomplot # -- main class. see fn make_unit_summary below to make lots of UnitPlots. @@ -747,7 +747,7 @@ def make_all_summaries( save_folder = Path(save_folder) save_folder.mkdir(exist_ok=True) - n_jobs, Executor, context = get_pool(n_jobs) + n_jobs, Executor, context = get_pool(n_jobs, cls=CloudpicklePoolExecutor) with Executor( max_workers=n_jobs, mp_context=context,
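
In conv_to_resid above, the squared template norms now come from the singular values stored on the LowRankTemplates objects instead of the dense template arrays. That substitution relies on the identity that a matrix's squared Frobenius norm equals the sum of its squared singular values. A standalone check of that identity, with a hypothetical (time, channels) template shape that is not taken from the repo:

    import numpy as np

    # The squared Frobenius norm of a template equals the sum of its squared
    # singular values, which is why torch.square(singular_values).sum(1) can
    # stand in for np.linalg.norm(templates, axis=(1, 2)) ** 2.
    rng = np.random.default_rng(0)
    template = rng.normal(size=(121, 40))  # hypothetical (time, channels) shape

    dense_norm = np.linalg.norm(template) ** 2
    singular_values = np.linalg.svd(template, compute_uv=False)
    low_rank_norm = np.square(singular_values).sum()

    assert np.isclose(dense_norm, low_rank_norm)

With the truncated rank used in the patch (svd_compression_rank=20), the sum covers only the retained singular values, so it is the norm of the rank-20 approximation rather than of the dense template, which matches the low-rank convolutions that produce best_conv.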
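
The re-enabled scaling then uses those norms: with a nonzero amplitude_scaling_variance, conv_to_resid fits template b to template a with a regularized scale, clipped to [1/(1 + boundary), 1 + boundary], and subtracts the resulting norm reduction from ||T_a||^2. The sketch below replays that arithmetic outside the package, with made-up values for the convolution peak and the template norms:

    import numpy as np

    # Made-up inputs standing in for one (a, b) template pair.
    best_conv = 90.0         # peak of the pairwise cross-correlation
    template_a_norm = 120.0  # ||T_a||^2
    template_b_norm = 100.0  # ||T_b||^2
    amplitude_scaling_variance = 0.001
    amplitude_scaling_boundary = 0.1

    inv_lambda = 1.0 / amplitude_scaling_variance
    amp_scale_min = 1.0 / (1.0 + amplitude_scaling_boundary)
    amp_scale_max = 1.0 + amplitude_scaling_boundary

    # regularized, clipped scale and the norm reduction it buys, as in conv_to_resid
    b = best_conv + inv_lambda
    a = template_b_norm + inv_lambda
    scaling = np.clip(b / a, amp_scale_min, amp_scale_max)
    scaled_reduction = 2.0 * scaling * b - np.square(scaling) * a - inv_lambda

    # unscaled alternative used when amplitude_scaling_variance == 0
    unscaled_reduction = 2.0 * best_conv - template_b_norm

    resid_scaled = template_a_norm - scaled_reduction
    resid_unscaled = template_a_norm - unscaled_reduction
    print(resid_scaled, resid_unscaled)

With the new defaults (variance 0.001, boundary 0.1), the fitted scale stays close to 1, so the merge distance tolerates small amplitude drift between near-identical templates without letting one unit absorb a neighbor at a very different amplitude.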
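
Separately, get_pool now degrades gracefully when cloudpickle is unavailable: the import sets a have_cloudpickle flag, and a request for CloudpicklePoolExecutor silently falls back to a stock ProcessPoolExecutor, so make_all_summaries in unit.py can always ask for the cloudpickle pool (useful for shipping callables that plain pickle may reject to worker processes). A minimal usage sketch of that calling pattern; the describe worker and the job count are invented for illustration:

    from dartsort.util.multiprocessing_util import CloudpicklePoolExecutor, get_pool

    def describe(unit_id):
        # stand-in for the real per-unit summary work
        return f"unit {unit_id}"

    if __name__ == "__main__":
        # falls back to ProcessPoolExecutor if cloudpickle is missing, and to the
        # serial MockPoolExecutor when n_jobs is 0
        n_jobs, Executor, context = get_pool(4, cls=CloudpicklePoolExecutor)
        with Executor(max_workers=n_jobs, mp_context=context) as pool:
            futures = [pool.submit(describe, uid) for uid in range(8)]
            summaries = [f.result() for f in futures]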