Skip to content

Commit

Permalink
Debug
Browse files Browse the repository at this point in the history
  • Loading branch information
cwindolf committed Nov 22, 2023
1 parent fd428ee commit 955be96
Show file tree
Hide file tree
Showing 4 changed files with 48 additions and 34 deletions.
23 changes: 15 additions & 8 deletions src/dartsort/templates/pairwise_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -191,8 +191,6 @@ def iterate_compressed_pairwise_convolutions(
print(
f"iterate_compressed_pairwise_convolutions {conv_batch_size=} {units_batch_size=} {device=}"
)
n_shifts = template_shift_index.all_pitch_shifts.size
do_shifting = n_shifts > 1
geom_kdtree = reg_geom_kdtree = match_distance = None
if do_shifting:
geom_kdtree = KDTree(geom)
Expand Down Expand Up @@ -1131,14 +1129,23 @@ def __post_init__(self):
device=self.device,
)
)
self.low_rank_templates.spatial_components = torch.as_tensor(
self.low_rank_templates.spatial_components, device=self.device
self.low_rank_templates_a.spatial_components = torch.as_tensor(
self.low_rank_templates_a.spatial_components, device=self.device
)
self.low_rank_templates_a.singular_values = torch.as_tensor(
self.low_rank_templates_a.singular_values, device=self.device
)
self.low_rank_templates_a.temporal_components = torch.as_tensor(
self.low_rank_templates_a.temporal_components, device=self.device
)
self.low_rank_templates_b.spatial_components = torch.as_tensor(
self.low_rank_templates_b.spatial_components, device=self.device
)
self.low_rank_templates.singular_values = torch.as_tensor(
self.low_rank_templates.singular_values, device=self.device
self.low_rank_templates_b.singular_values = torch.as_tensor(
self.low_rank_templates_b.singular_values, device=self.device
)
self.low_rank_templates.temporal_components = torch.as_tensor(
self.low_rank_templates.temporal_components, device=self.device
self.low_rank_templates_b.temporal_components = torch.as_tensor(
self.low_rank_templates_b.temporal_components, device=self.device
)


Expand Down
1 change: 1 addition & 0 deletions src/dartsort/templates/template_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -262,6 +262,7 @@ def temporally_upsample_templates(

@dataclass
class CompressedUpsampledTemplates:
n_compressed_upsampled_templates: int
compressed_upsampled_templates: np.ndarray
compressed_upsampling_map: np.ndarray
compressed_upsampling_index: np.ndarray
Expand Down
14 changes: 11 additions & 3 deletions src/dartsort/templates/templates.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,23 +52,31 @@ def to_npz(self, npz_path):
def coarsen(self):
"""Weighted average all templates that share a unit id and re-localize."""
# update templates
templates = weighted_average(self.unit_ids, self.templates, self.spike_counts)
print(f"a {np.isnan(self.templates).any()=}")
print(f"b {np.equal(self.templates, 0).all(axis=(1,2)).sum()=}")
unit_ids_unique, flat_ids = np.unique(self.unit_ids, return_inverse=True)
templates = weighted_average(flat_ids, self.templates, self.spike_counts)
print(f"b {np.isnan(templates).any()=}")
print(f"b {np.equal(templates, 0).all(axis=(1,2)).sum()=}")

# collect spike counts
spike_counts = np.zeros(len(templates))
np.add.at(spike_counts, self.unit_ids, self.spike_counts)
np.add.at(spike_counts, np.arange(unit_ids_unique.size), self.spike_counts)
print(f"b {np.isnan(spike_counts).any()=}")
print(f"b {np.isnan(self.registered_geom).any()=}")

# re-localize
registered_template_depths_um = get_template_depths(
templates,
self.registered_geom,
localization_radius_um=self.localization_radius_um,
)
print(f"b {np.isnan(registered_template_depths_um).any()=}")

return replace(
self,
templates=templates,
unit_ids=np.arange(len(templates)),
unit_ids=unit_ids_unique,
spike_counts=spike_counts,
registered_template_depths_um=registered_template_depths_um,
)
Expand Down
44 changes: 21 additions & 23 deletions src/dartsort/util/drift_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -543,6 +543,14 @@ def get_shift_and_unit_pairs(
reg_depths_um = np.concatenate((reg_depths_um_a, reg_depths_um_b))

# figure out all shifts for all units at all times
print(f"{chunk_time_centers_s.min()=} {chunk_time_centers_s.max()=}")
print(f"{reg_depths_um.min()=} {reg_depths_um.max()=}")
print(f"{reg_depths_um_a.min()=} {reg_depths_um_a.max()=}")
print(f"{reg_depths_um_b.min()=} {reg_depths_um_b.max()=}")
print(f"{motion_est.time_bin_centers_s.min()=}")
print(f"{motion_est.time_bin_centers_s.max()=}")
print(f"{motion_est.spatial_bin_centers_um.min()=}")
print(f"{motion_est.spatial_bin_centers_um.max()=}")
unreg_depths_um = np.concatenate(
[
motion_est.disp_at_s(t_s, depth_um=reg_depths_um, grid=True).T
Expand Down Expand Up @@ -570,27 +578,17 @@ def get_shift_and_unit_pairs(
template_shift_index_b = TemplateShiftIndex.from_shift_matrix(shifts_b)

# co-occurrence matrix: do these shifted templates appear together?
cooccurrence = np.eye(n_template_shift_pairs, dtype=bool)
for t_s in chunk_time_centers_s:
unregistered_depths_um = invert_motion_estimate(
motion_est, t_s, template_data.registered_template_depths_um
)
pitch_shifts = get_spike_pitch_shifts(
depths_um=template_data.registered_template_depths_um,
pitch=pitch,
registered_depths_um=unregistered_depths_um,
)
pitch_shifts = pitch_shifts.astype(int)
pitch_shift_ix = np.searchsorted(all_pitch_shifts, pitch_shifts)

shifted_temp_ixs = template_shift_index[temp_ixs, pitch_shift_ix]
cooccurrence[shifted_temp_ixs[:, None], shifted_temp_ixs[None, :]] = 1
cooccurrence = np.zeros(
(template_shift_index_a.n_shifted_templates, template_shift_index_b.n_shifted_templates),
dtype=bool)
temps_a = np.arange(na)
temps_b = np.arange(nb)
for j in range(len(chunk_time_centers_s)):
shifted_ids_a = template_shift_index_a.shifts_to_shifted_ids(temps_a, shifts_a[j])
if same:
shifted_ids_b = shifted_ids_a
else:
shifted_ids_b = template_shift_index_b.shifts_to_shifted_ids(temps_b, shifts_b[j])
cooccurrence[shifted_ids_a[:, None], shifted_ids_b[None, :]] = 1

return TemplateShiftIndex(
n_template_shift_pairs,
all_pitch_shifts,
template_shift_index,
cooccurrence,
shifted_temp_ix_to_temp_ix,
shifted_temp_ix_to_shift,
)
return template_shift_index_a, template_shift_index_b, cooccurrence

0 comments on commit 955be96

Please sign in to comment.