diff --git a/src/dartsort/templates/pairwise_util.py b/src/dartsort/templates/pairwise_util.py index ead136f2..c7f2c35c 100644 --- a/src/dartsort/templates/pairwise_util.py +++ b/src/dartsort/templates/pairwise_util.py @@ -191,8 +191,6 @@ def iterate_compressed_pairwise_convolutions( print( f"iterate_compressed_pairwise_convolutions {conv_batch_size=} {units_batch_size=} {device=}" ) - n_shifts = template_shift_index.all_pitch_shifts.size - do_shifting = n_shifts > 1 geom_kdtree = reg_geom_kdtree = match_distance = None if do_shifting: geom_kdtree = KDTree(geom) @@ -1131,14 +1129,23 @@ def __post_init__(self): device=self.device, ) ) - self.low_rank_templates.spatial_components = torch.as_tensor( - self.low_rank_templates.spatial_components, device=self.device + self.low_rank_templates_a.spatial_components = torch.as_tensor( + self.low_rank_templates_a.spatial_components, device=self.device + ) + self.low_rank_templates_a.singular_values = torch.as_tensor( + self.low_rank_templates_a.singular_values, device=self.device + ) + self.low_rank_templates_a.temporal_components = torch.as_tensor( + self.low_rank_templates_a.temporal_components, device=self.device + ) + self.low_rank_templates_b.spatial_components = torch.as_tensor( + self.low_rank_templates_b.spatial_components, device=self.device ) - self.low_rank_templates.singular_values = torch.as_tensor( - self.low_rank_templates.singular_values, device=self.device + self.low_rank_templates_b.singular_values = torch.as_tensor( + self.low_rank_templates_b.singular_values, device=self.device ) - self.low_rank_templates.temporal_components = torch.as_tensor( - self.low_rank_templates.temporal_components, device=self.device + self.low_rank_templates_b.temporal_components = torch.as_tensor( + self.low_rank_templates_b.temporal_components, device=self.device ) diff --git a/src/dartsort/templates/template_util.py b/src/dartsort/templates/template_util.py index 6b34d4da..57f303bd 100644 --- a/src/dartsort/templates/template_util.py +++ b/src/dartsort/templates/template_util.py @@ -262,6 +262,7 @@ def temporally_upsample_templates( @dataclass class CompressedUpsampledTemplates: + n_compressed_upsampled_templates: int compressed_upsampled_templates: np.ndarray compressed_upsampling_map: np.ndarray compressed_upsampling_index: np.ndarray diff --git a/src/dartsort/templates/templates.py b/src/dartsort/templates/templates.py index 745210d6..f940ff49 100644 --- a/src/dartsort/templates/templates.py +++ b/src/dartsort/templates/templates.py @@ -52,11 +52,18 @@ def to_npz(self, npz_path): def coarsen(self): """Weighted average all templates that share a unit id and re-localize.""" # update templates - templates = weighted_average(self.unit_ids, self.templates, self.spike_counts) + print(f"a {np.isnan(self.templates).any()=}") + print(f"b {np.equal(self.templates, 0).all(axis=(1,2)).sum()=}") + unit_ids_unique, flat_ids = np.unique(self.unit_ids, return_inverse=True) + templates = weighted_average(flat_ids, self.templates, self.spike_counts) + print(f"b {np.isnan(templates).any()=}") + print(f"b {np.equal(templates, 0).all(axis=(1,2)).sum()=}") # collect spike counts spike_counts = np.zeros(len(templates)) - np.add.at(spike_counts, self.unit_ids, self.spike_counts) + np.add.at(spike_counts, np.arange(unit_ids_unique.size), self.spike_counts) + print(f"b {np.isnan(spike_counts).any()=}") + print(f"b {np.isnan(self.registered_geom).any()=}") # re-localize registered_template_depths_um = get_template_depths( @@ -64,11 +71,12 @@ def coarsen(self): self.registered_geom, localization_radius_um=self.localization_radius_um, ) + print(f"b {np.isnan(registered_template_depths_um).any()=}") return replace( self, templates=templates, - unit_ids=np.arange(len(templates)), + unit_ids=unit_ids_unique, spike_counts=spike_counts, registered_template_depths_um=registered_template_depths_um, ) diff --git a/src/dartsort/util/drift_util.py b/src/dartsort/util/drift_util.py index decd05eb..71ce470c 100644 --- a/src/dartsort/util/drift_util.py +++ b/src/dartsort/util/drift_util.py @@ -543,6 +543,14 @@ def get_shift_and_unit_pairs( reg_depths_um = np.concatenate((reg_depths_um_a, reg_depths_um_b)) # figure out all shifts for all units at all times + print(f"{chunk_time_centers_s.min()=} {chunk_time_centers_s.max()=}") + print(f"{reg_depths_um.min()=} {reg_depths_um.max()=}") + print(f"{reg_depths_um_a.min()=} {reg_depths_um_a.max()=}") + print(f"{reg_depths_um_b.min()=} {reg_depths_um_b.max()=}") + print(f"{motion_est.time_bin_centers_s.min()=}") + print(f"{motion_est.time_bin_centers_s.max()=}") + print(f"{motion_est.spatial_bin_centers_um.min()=}") + print(f"{motion_est.spatial_bin_centers_um.max()=}") unreg_depths_um = np.concatenate( [ motion_est.disp_at_s(t_s, depth_um=reg_depths_um, grid=True).T @@ -570,27 +578,17 @@ def get_shift_and_unit_pairs( template_shift_index_b = TemplateShiftIndex.from_shift_matrix(shifts_b) # co-occurrence matrix: do these shifted templates appear together? - cooccurrence = np.eye(n_template_shift_pairs, dtype=bool) - for t_s in chunk_time_centers_s: - unregistered_depths_um = invert_motion_estimate( - motion_est, t_s, template_data.registered_template_depths_um - ) - pitch_shifts = get_spike_pitch_shifts( - depths_um=template_data.registered_template_depths_um, - pitch=pitch, - registered_depths_um=unregistered_depths_um, - ) - pitch_shifts = pitch_shifts.astype(int) - pitch_shift_ix = np.searchsorted(all_pitch_shifts, pitch_shifts) - - shifted_temp_ixs = template_shift_index[temp_ixs, pitch_shift_ix] - cooccurrence[shifted_temp_ixs[:, None], shifted_temp_ixs[None, :]] = 1 + cooccurrence = np.zeros( + (template_shift_index_a.n_shifted_templates, template_shift_index_b.n_shifted_templates), + dtype=bool) + temps_a = np.arange(na) + temps_b = np.arange(nb) + for j in range(len(chunk_time_centers_s)): + shifted_ids_a = template_shift_index_a.shifts_to_shifted_ids(temps_a, shifts_a[j]) + if same: + shifted_ids_b = shifted_ids_a + else: + shifted_ids_b = template_shift_index_b.shifts_to_shifted_ids(temps_b, shifts_b[j]) + cooccurrence[shifted_ids_a[:, None], shifted_ids_b[None, :]] = 1 - return TemplateShiftIndex( - n_template_shift_pairs, - all_pitch_shifts, - template_shift_index, - cooccurrence, - shifted_temp_ix_to_temp_ix, - shifted_temp_ix_to_shift, - ) + return template_shift_index_a, template_shift_index_b, cooccurrence