From 96e124696e9d05dc3cdda18d303b85e229a66831 Mon Sep 17 00:00:00 2001
From: Charlie Windolf
Date: Tue, 5 Dec 2023 22:52:55 -0500
Subject: [PATCH] Debug merge

---
 src/dartsort/cluster/merge.py           | 39 ++++++++++++++++++++-----
 src/dartsort/peel/matching.py           |  3 +-
 src/dartsort/templates/get_templates.py |  5 ----
 src/dartsort/templates/pairwise_util.py | 20 +++++++------
 src/dartsort/templates/templates.py     |  3 +-
 src/dartsort/util/drift_util.py         |  2 --
 6 files changed, 46 insertions(+), 26 deletions(-)

diff --git a/src/dartsort/cluster/merge.py b/src/dartsort/cluster/merge.py
index 0c87bbe5..fb0f20c3 100644
--- a/src/dartsort/cluster/merge.py
+++ b/src/dartsort/cluster/merge.py
@@ -6,7 +6,7 @@
 from dartsort.templates import TemplateData, template_util
 from dartsort.templates.pairwise_util import (
     construct_shift_indices, iterate_compressed_pairwise_convolutions)
-from dartsort.util import DARTsortSorting
+from dartsort.util.data_util import DARTsortSorting
 from scipy.cluster.hierarchy import complete, fcluster
 
@@ -33,7 +33,28 @@ def merge_templates(
     overwrite_templates=False,
     show_progress=True,
     template_npz_filename="template_data.npz",
-):
+) -> DARTsortSorting:
+    """Template-distance-based merge.
+
+    Pass in a sorting, a recording, and a template config to build
+    templates, and this will merge them (with superres). Or, if you
+    already have templates, pass them in via template_data and the
+    template construction step is skipped.
+
+    Arguments
+    ---------
+    max_shift_samples
+        Maximum shift allowed (in samples) during matching.
+    superres_linkage
+        How to combine the distances between two units' superres
+        templates. By default, the max is taken.
+    amplitude_scaling_*
+        Optionally allow amplitude scaling during matching.
+
+    Returns
+    -------
+    A new DARTsortSorting
+    """
     if template_data is None:
         template_data = TemplateData.from_config(
             recording,
@@ -50,7 +71,7 @@ def merge_templates(
     # allocate distance + shift matrices. shifts[i,j] is trough[j]-trough[i].
     n_templates = template_data.templates.shape[0]
     sup_dists = np.full((n_templates, n_templates), np.inf)
-    sup_shifts = np.zero((n_templates, n_templates), dtype=int)
+    sup_shifts = np.zeros((n_templates, n_templates), dtype=int)
 
     # build distance matrix
     dec_res_iter = get_deconv_resid_norm_iter(
@@ -78,7 +99,7 @@ def merge_templates(
     units = np.unique(template_data.unit_ids)
     if units.size < n_templates:
         dists = np.full((units.size, units.size), np.inf)
-        shifts = np.zero((units.size, units.size), dtype=int)
+        shifts = np.zeros((units.size, units.size), dtype=int)
         for ia, ua in enumerate(units):
             in_ua = np.flatnonzero(template_data.unit_ids == ua)
             for ib, ub in enumerate(units):
@@ -98,6 +119,7 @@ def merge_templates(
     # now run hierarchical clustering
     return recluster(
         sorting,
+        units,
         dists,
         shifts,
         template_snrs,
@@ -105,7 +127,7 @@
     )
 
 
-def recluster(sorting, dists, shifts, template_snrs, merge_distance_threshold=0.25):
+def recluster(sorting, units, dists, shifts, template_snrs, merge_distance_threshold=0.25):
     # upper triangle not including diagonal, aka condensed distance matrix in scipy
     pdist = dists[np.triu_indices(dists.shape[0], k=1)]
     # scipy hierarchical clustering only supports finite values, so let's just
@@ -118,8 +140,9 @@ def recluster(sorting, dists, shifts, template_snrs, merge_distance_threshold=0.
 
     # update labels
     labels_updated = sorting.labels.copy()
-    kept = np.flatnonzero(labels_updated >= 0)
-    labels_updated[kept] = new_labels[labels_updated[kept]]
+    kept = np.flatnonzero(np.isin(sorting.labels, units))
+    _, flat_labels = np.unique(labels_updated[kept], return_inverse=True)
+    labels_updated[kept] = new_labels[flat_labels]
 
     # update times according to shifts
     times_updated = sorting.times_samples.copy()
@@ -199,7 +222,7 @@ def get_deconv_resid_norm_iter(
         upsampled_shifted_template_index,
         do_shifting=False,
         reduce_deconv_resid_norm=True,
-        geom=template_data.registered_geometry,
+        geom=template_data.registered_geom,
         conv_ignore_threshold=0.0,
         coarse_approx_error_threshold=0.0,
         amplitude_scaling_variance=amplitude_scaling_variance,
diff --git a/src/dartsort/peel/matching.py b/src/dartsort/peel/matching.py
index db8169d5..bf17020b 100644
--- a/src/dartsort/peel/matching.py
+++ b/src/dartsort/peel/matching.py
@@ -427,7 +427,8 @@ def templates_at_time(self, t_s):
         """Handle drift -- grab the right spatial neighborhoods."""
         pconvdb = self.pairwise_conv_db
         pitch_shifts_a = pitch_shifts_b = None
-        pconvdb.to(self.objective_spatial_components.device, pin=True)
+        if self.objective_spatial_components.device.type == "cuda" and pconvdb.device.type != "cuda":
+            pconvdb.to(self.objective_spatial_components.device)
         if self.is_drifting:
             pitch_shifts_b, cur_spatial = template_util.templates_at_time(
                 t_s,
diff --git a/src/dartsort/templates/get_templates.py b/src/dartsort/templates/get_templates.py
index 51f53c57..58b75d22 100644
--- a/src/dartsort/templates/get_templates.py
+++ b/src/dartsort/templates/get_templates.py
@@ -102,7 +102,6 @@ def get_templates(
     # pad the trough_offset_samples and spike_length_samples so that
     # if the user did not request denoising we can just return the
     # raw templates right away
-    print("realign")
     trough_offset_load = trough_offset_samples + realign_max_sample_shift
     spike_length_load = spike_length_samples + 2 * realign_max_sample_shift
     raw_results = get_raw_templates(
@@ -166,7 +165,6 @@ def get_templates(
         device=device,
     )
     raw_templates, low_rank_templates, snrs_by_channel = res
-    print(f"{raw_templates.ptp(1).max(1)=}")
 
     if raw_only:
         return dict(
@@ -575,9 +573,6 @@ def _template_job(unit_ids):
     valid = np.flatnonzero(
         (times >= p.trough_offset_samples) & (times <= p.max_spike_time)
     )
-    print(f"{times=} {valid=}")
-    print(f"{p.trough_offset_samples=} {p.max_spike_time}")
-    print(f"{times.shape=} {valid.shape=}")
     if not valid.size:
         return
     in_units = in_units[valid]
diff --git a/src/dartsort/templates/pairwise_util.py b/src/dartsort/templates/pairwise_util.py
index 6182ac46..1ac6ed8b 100644
--- a/src/dartsort/templates/pairwise_util.py
+++ b/src/dartsort/templates/pairwise_util.py
@@ -338,20 +338,21 @@ def conv_to_resid(
     deconv_resid_norms = np.zeros(n_pairs)
     shifts = np.zeros(n_pairs, dtype=int)
     template_indices_a, template_indices_b = pairs.T
-    templates_a = template_data_a.templates[template_indices_a].numpy(force=True)
-    templates_b = template_data_b.templates[template_indices_b].numpy(force=True)
-    template_a_norms = np.linalg.norm(templates_a, axis=(1, 2))
-    template_b_norms = np.linalg.norm(templates_b, axis=(1, 2))
+    templates_a = template_data_a.templates[template_indices_a]
+    templates_b = template_data_b.templates[template_indices_b]
+    template_a_norms = np.linalg.norm(templates_a, axis=(1, 2)) ** 2
+    template_b_norms = np.linalg.norm(templates_b, axis=(1, 2)) ** 2
 
     for j, (ix_a, ix_b) in enumerate(pairs):
         in_a = conv_result.template_indices_a == ix_a
         in_b = conv_result.template_indices_b == ix_b
         in_pair = np.flatnonzero(in_a & in_b)
 
         # reduce over fine templates
-        pair_conv = pconvs[in_pair].max(dim=0).values
-        best_conv, lag_index = pair_conv.max()
+        pair_conv = pconvs[in_pair].max(axis=0)
+        lag_index = np.argmax(pair_conv)
+        best_conv = pair_conv[lag_index]
         shifts[j] = lag_index - center
-
+        # figure out scaling
         if amplitude_scaling_variance:
             amp_scale_min = 1 / (1 + amplitude_scaling_boundary)
@@ -359,11 +360,12 @@ def conv_to_resid(
             inv_lambda = 1 / amplitude_scaling_variance
             b = best_conv + inv_lambda
             a = template_a_norms[j] + inv_lambda
-            scaling = torch.clip(b / a, amp_scale_min, amp_scale_max)
-            norm_reduction = 2 * scaling * b - torch.square(scaling) * a - inv_lambda
+            scaling = np.clip(b / a, amp_scale_min, amp_scale_max)
+            norm_reduction = 2 * scaling * b - np.square(scaling) * a - inv_lambda
         else:
             norm_reduction = 2 * best_conv - template_b_norms[j]
         deconv_resid_norms[j] = template_a_norms[j] - norm_reduction
+        assert deconv_resid_norms[j] >= 0
 
     return DeconvResidResult(
         template_indices_a,
diff --git a/src/dartsort/templates/templates.py b/src/dartsort/templates/templates.py
index b5af6bd7..6ff43c34 100644
--- a/src/dartsort/templates/templates.py
+++ b/src/dartsort/templates/templates.py
@@ -47,6 +47,8 @@ def to_npz(self, npz_path):
             to_save[
                 "registered_template_depths_um"
             ] = self.registered_template_depths_um
+        if not npz_path.parent.exists():
+            npz_path.parent.mkdir()
         np.savez(npz_path, **to_save)
 
     def coarsen(self, with_locs=True):
@@ -171,7 +173,6 @@ def from_config(
         else:
             # we don't skip empty units
             unit_ids = np.arange(sorting.labels.max() + 1)
-        print(f"post superres {sorting.times_samples=} {sorting.labels=}")
 
         # count spikes in each template
         spike_counts = np.zeros_like(unit_ids)
diff --git a/src/dartsort/util/drift_util.py b/src/dartsort/util/drift_util.py
index 0bf4464a..7c3dce5d 100644
--- a/src/dartsort/util/drift_util.py
+++ b/src/dartsort/util/drift_util.py
@@ -564,8 +564,6 @@ def get_shift_and_unit_pairs(
     else:
         shifts_a = pitch_shifts[:, :na]
         shifts_b = pitch_shifts[:, na:]
-    print(f"{shifts_a.min()=} {shifts_a.max()=}")
-    print(f"{shifts_b.min()=} {shifts_b.max()=}")
 
     # assign ids to pitch/shift pairs
     template_shift_index_a = TemplateShiftIndex.from_shift_matrix(shifts_a)
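
Reviewer note on the conv_to_resid and recluster changes above: the toy
script below reproduces, in isolation, the two pieces of math they touch.
First, the regularized amplitude-scaling residual -- the a/b/norm_reduction
algebra, which is why the template norms are now squared (the new ** 2) and
why the added assert can expect a nonnegative value. Second, the
complete-linkage merge over a condensed distance matrix. This is a sketch,
not dartsort's API: scaled_resid_norm_sq, the toy templates, and the 0.25
threshold are invented for illustration; only the scipy calls (complete,
fcluster) and the np.triu_indices condensed-matrix trick mirror the patch.

    import numpy as np
    from scipy.cluster.hierarchy import complete, fcluster

    def scaled_resid_norm_sq(u, v, scaling_variance=0.1, boundary=0.5):
        # Fit a scale s of v to u with a prior keeping s near 1:
        #   minimize ||u - s*v||^2 + (s - 1)^2 / scaling_variance.
        # Completing the square gives s = b / a with a, b as below,
        # the same algebra as in conv_to_resid.
        inv_lambda = 1.0 / scaling_variance
        b = u @ v + inv_lambda  # best-lag correlation plus prior term
        a = v @ v + inv_lambda  # squared norm plus prior term
        s = np.clip(b / a, 1 / (1 + boundary), 1 + boundary)
        norm_reduction = 2 * s * b - np.square(s) * a - inv_lambda
        # u @ u - norm_reduction equals
        # ||u - s*v||^2 + (s - 1)^2 / scaling_variance, a sum of squares,
        # hence nonnegative -- the invariant the new assert checks.
        return u @ u - norm_reduction

    rng = np.random.default_rng(0)
    temps = rng.normal(size=(4, 121 * 8))  # four toy flattened templates
    temps[1] = temps[0] + 0.05 * rng.normal(size=temps.shape[1])  # near-dup

    n = len(temps)
    dists = np.zeros((n, n))
    for i in range(n):
        for j in range(n):
            if i != j:
                # normalize by the target's energy for a unitless distance
                dists[i, j] = (
                    scaled_resid_norm_sq(temps[i], temps[j])
                    / (temps[i] @ temps[i])
                )
    dists = np.maximum(dists, dists.T)  # symmetrize for the linkage

    # upper triangle excluding the diagonal, aka scipy's condensed form,
    # exactly as in recluster
    pdist = dists[np.triu_indices(n, k=1)]
    Z = complete(pdist)
    new_labels = fcluster(Z, 0.25, criterion="distance")
    print(new_labels)  # templates 0 and 1 land in the same cluster

Running this prints a clustering in which the near-duplicate pair merges
while the two unrelated templates stay singletons, which is the behavior
merge_distance_threshold is meant to control.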
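
A second, independent toy for the label bookkeeping fix in recluster: the
old code indexed new_labels directly with the kept spike labels, which
breaks once unit ids are non-contiguous (as they can be after superres
grouping); the patched version first compresses the kept labels through
np.unique(..., return_inverse=True). The values here are hypothetical --
the labels array and the pretend fcluster output are made up.

    import numpy as np

    labels = np.array([0, 2, 2, -1, 5, 0])  # spike labels; -1 = unsorted
    units = np.array([0, 2, 5])             # unit ids that were clustered
    new_labels = np.array([0, 0, 1])        # pretend merge of units 0 and 2

    # only touch spikes whose label is one of the clustered units
    kept = np.flatnonzero(np.isin(labels, units))
    # compress the kept labels 0, 2, 5 down to positions 0, 1, 2 ...
    _, flat_labels = np.unique(labels[kept], return_inverse=True)
    # ... so they can safely index new_labels despite the gaps
    labels[kept] = new_labels[flat_labels]
    print(labels)  # -> 0, 0, 0, -1, 1, 0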