Commit: Debug merge

cwindolf committed Dec 6, 2023
1 parent a124a1f commit 96e1246
Showing 6 changed files with 46 additions and 26 deletions.
src/dartsort/cluster/merge.py (31 additions, 8 deletions)
@@ -6,7 +6,7 @@
 from dartsort.templates import TemplateData, template_util
 from dartsort.templates.pairwise_util import (
     construct_shift_indices, iterate_compressed_pairwise_convolutions)
-from dartsort.util import DARTsortSorting
+from dartsort.util.data_util import DARTsortSorting
 from scipy.cluster.hierarchy import complete, fcluster
@@ -33,7 +33,28 @@ def merge_templates(
     overwrite_templates=False,
     show_progress=True,
     template_npz_filename="template_data.npz",
-):
+) -> DARTsortSorting:
+    """Template distance based merge
+
+    Pass in a sorting, recording and template config to make templates,
+    and this will merge them (with superres). Or, if you have templates
+    already, pass them into template_data and we can skip the template
+    construction.
+
+    Arguments
+    ---------
+    max_shift_samples
+        Max offset during matching
+    superres_linkage
+        How to combine distances between two units' superres templates.
+        By default, it's the max.
+    amplitude_scaling_*
+        Optionally allow scaling during matching
+
+    Returns
+    -------
+    A new DARTsortSorting
+    """
     if template_data is None:
         template_data = TemplateData.from_config(
             recording,
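
A minimal usage sketch of the two modes the new docstring describes (the positional order and any keywords not visible in this diff are assumptions, not confirmed by the commit):

    import numpy as np
    from dartsort.cluster.merge import merge_templates

    # either let merge_templates build templates from the sorting +
    # recording + template config, or pass precomputed templates to
    # skip construction entirely
    merged = merge_templates(
        sorting,                        # a DARTsortSorting
        recording,
        template_data=template_data,    # skip template construction
        superres_linkage=np.max,        # "By default, it's the max."
        merge_distance_threshold=0.25,  # recluster's default, per below
    )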
@@ -50,7 +71,7 @@ def merge_templates(
     # allocate distance + shift matrices. shifts[i,j] is trough[j]-trough[i].
     n_templates = template_data.templates.shape[0]
     sup_dists = np.full((n_templates, n_templates), np.inf)
-    sup_shifts = np.zero((n_templates, n_templates), dtype=int)
+    sup_shifts = np.zeros((n_templates, n_templates), dtype=int)

     # build distance matrix
     dec_res_iter = get_deconv_resid_norm_iter(
@@ -78,7 +99,7 @@ def merge_templates(
     units = np.unique(template_data.unit_ids)
     if units.size < n_templates:
         dists = np.full((units.size, units.size), np.inf)
-        shifts = np.zero((units.size, units.size), dtype=int)
+        shifts = np.zeros((units.size, units.size), dtype=int)
         for ia, ua in enumerate(units):
             in_ua = np.flatnonzero(template_data.unit_ids == ua)
             for ib, ub in enumerate(units):
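
A toy sketch (hypothetical shapes and names) of the reduction this loop performs: collapsing the superres-template distance matrix to one distance per unit pair, using the max linkage the docstring names as the default:

    import numpy as np

    rng = np.random.default_rng(0)
    sup_dists = rng.random((5, 5))        # 5 superres templates, toy distances
    unit_ids = np.array([0, 0, 1, 1, 2])  # parent unit of each superres template
    units = np.unique(unit_ids)
    dists = np.full((units.size, units.size), np.inf)
    for ia, ua in enumerate(units):
        in_ua = np.flatnonzero(unit_ids == ua)
        for ib, ub in enumerate(units):
            in_ub = np.flatnonzero(unit_ids == ub)
            # max linkage: worst-case distance across the two units' templates
            dists[ia, ib] = sup_dists[np.ix_(in_ua, in_ub)].max()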
@@ -98,14 +119,15 @@
     # now run hierarchical clustering
     return recluster(
         sorting,
+        units,
         dists,
         shifts,
         template_snrs,
         merge_distance_threshold=merge_distance_threshold,
     )


-def recluster(sorting, dists, shifts, template_snrs, merge_distance_threshold=0.25):
+def recluster(sorting, units, dists, shifts, template_snrs, merge_distance_threshold=0.25):
     # upper triangle not including diagonal, aka condensed distance matrix in scipy
     pdist = dists[np.triu_indices(dists.shape[0], k=1)]
     # scipy hierarchical clustering only supports finite values, so let's just
@@ -118,8 +140,9 @@ def recluster(

     # update labels
     labels_updated = sorting.labels.copy()
-    kept = np.flatnonzero(labels_updated >= 0)
-    labels_updated[kept] = new_labels[labels_updated[kept]]
+    kept = np.flatnonzero(np.isin(sorting.labels, units))
+    _, flat_labels = np.unique(labels_updated[kept], return_inverse=True)
+    labels_updated[kept] = new_labels[flat_labels]

     # update times according to shifts
     times_updated = sorting.times_samples.copy()
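
Two details worth unpacking in a toy sketch (values are hypothetical): scipy's complete/fcluster pair works on the condensed distance matrix (upper triangle, diagonal excluded), and the np.unique(..., return_inverse=True) step maps possibly non-contiguous unit ids down to 0..n_units-1 before indexing new_labels; the replaced code assumed labels were already contiguous:

    import numpy as np
    from scipy.cluster.hierarchy import complete, fcluster

    dists = np.array([[0.0, 0.1, 0.9],
                      [0.1, 0.0, 0.8],
                      [0.9, 0.8, 0.0]])
    pdist = dists[np.triu_indices(3, k=1)]  # condensed form: [0.1, 0.9, 0.8]
    new_labels = fcluster(complete(pdist), 0.25, criterion="distance") - 1
    # units 0 and 1 (distance 0.1 <= 0.25) merge; unit 2 stays separate

    labels = np.array([2, 7, 2, 5, -1, 7])  # non-contiguous ids, -1 = unsorted
    units = np.array([2, 5, 7])
    kept = np.flatnonzero(np.isin(labels, units))
    _, flat = np.unique(labels[kept], return_inverse=True)  # 2->0, 5->1, 7->2
    labels[kept] = new_labels[flat]  # safe even with gaps in the id range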
@@ -199,7 +222,7 @@ def get_deconv_resid_norm_iter(
         upsampled_shifted_template_index,
         do_shifting=False,
         reduce_deconv_resid_norm=True,
-        geom=template_data.registered_geometry,
+        geom=template_data.registered_geom,
         conv_ignore_threshold=0.0,
         coarse_approx_error_threshold=0.0,
         amplitude_scaling_variance=amplitude_scaling_variance,
src/dartsort/peel/matching.py (2 additions, 1 deletion)
@@ -427,7 +427,8 @@ def templates_at_time(self, t_s):
         """Handle drift -- grab the right spatial neighborhoods."""
         pconvdb = self.pairwise_conv_db
         pitch_shifts_a = pitch_shifts_b = None
-        pconvdb.to(self.objective_spatial_components.device, pin=True)
+        if self.objective_spatial_components.device.type == "cuda" and not pconvdb.device.type == "cuda":
+            pconvdb.to(self.objective_spatial_components.device)
         if self.is_drifting:
             pitch_shifts_b, cur_spatial = template_util.templates_at_time(
                 t_s,
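
The replaced line moved (and pinned) the pairwise conv database on every call; the new guard transfers it at most once, and only when the objective actually lives on CUDA. A minimal sketch of the same pattern (maybe_to_cuda is a hypothetical helper; torch-style .device/.to semantics are assumed for the database object):

    import torch

    def maybe_to_cuda(db, ref: torch.Tensor):
        # transfer only when needed; a no-op once db is already on CUDA
        if ref.device.type == "cuda" and db.device.type != "cuda":
            db = db.to(ref.device)
        return db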
src/dartsort/templates/get_templates.py (0 additions, 5 deletions)
@@ -102,7 +102,6 @@ def get_templates(
     # pad the trough_offset_samples and spike_length_samples so that
     # if the user did not request denoising we can just return the
     # raw templates right away
-    print("realign")
     trough_offset_load = trough_offset_samples + realign_max_sample_shift
     spike_length_load = spike_length_samples + 2 * realign_max_sample_shift
     raw_results = get_raw_templates(
@@ -166,7 +165,6 @@ def get_templates(
         device=device,
     )
     raw_templates, low_rank_templates, snrs_by_channel = res
-    print(f"{raw_templates.ptp(1).max(1)=}")

     if raw_only:
         return dict(
@@ -575,9 +573,6 @@ def _template_job(unit_ids):
     valid = np.flatnonzero(
         (times >= p.trough_offset_samples) & (times <= p.max_spike_time)
     )
-    print(f"{times=} {valid=}")
-    print(f"{p.trough_offset_samples=} {p.max_spike_time}")
-    print(f"{times.shape=} {valid.shape=}")
     if not valid.size:
         return
     in_units = in_units[valid]
src/dartsort/templates/pairwise_util.py (11 additions, 9 deletions)
@@ -338,32 +338,34 @@ def conv_to_resid(
     deconv_resid_norms = np.zeros(n_pairs)
     shifts = np.zeros(n_pairs, dtype=int)
     template_indices_a, template_indices_b = pairs.T
-    templates_a = template_data_a.templates[template_indices_a].numpy(force=True)
-    templates_b = template_data_b.templates[template_indices_b].numpy(force=True)
-    template_a_norms = np.linalg.norm(templates_a, axis=(1, 2))
-    template_b_norms = np.linalg.norm(templates_b, axis=(1, 2))
+    templates_a = template_data_a.templates[template_indices_a]
+    templates_b = template_data_b.templates[template_indices_b]
+    template_a_norms = np.linalg.norm(templates_a, axis=(1, 2)) ** 2
+    template_b_norms = np.linalg.norm(templates_b, axis=(1, 2)) ** 2
     for j, (ix_a, ix_b) in enumerate(pairs):
         in_a = conv_result.template_indices_a == ix_a
         in_b = conv_result.template_indices_b == ix_b
         in_pair = np.flatnonzero(in_a & in_b)

         # reduce over fine templates
-        pair_conv = pconvs[in_pair].max(dim=0).values
-        best_conv, lag_index = pair_conv.max()
+        pair_conv = pconvs[in_pair].max(axis=0)
+        lag_index = np.argmax(pair_conv)
+        best_conv = pair_conv[lag_index]
         shifts[j] = lag_index - center

         # figure out scaling
         if amplitude_scaling_variance:
             amp_scale_min = 1 / (1 + amplitude_scaling_boundary)
             amp_scale_max = 1 + amplitude_scaling_boundary
             inv_lambda = 1 / amplitude_scaling_variance
             b = best_conv + inv_lambda
             a = template_a_norms[j] + inv_lambda
-            scaling = torch.clip(b / a, amp_scale_min, amp_scale_max)
-            norm_reduction = 2 * scaling * b - torch.square(scaling) * a - inv_lambda
+            scaling = np.clip(b / a, amp_scale_min, amp_scale_max)
+            norm_reduction = 2 * scaling * b - np.square(scaling) * a - inv_lambda
         else:
             norm_reduction = 2 * best_conv - template_b_norms[j]
         deconv_resid_norms[j] = template_a_norms[j] - norm_reduction
+        assert deconv_resid_norms[j] >= 0

     return DeconvResidResult(
         template_indices_a,
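
The scaling branch is the closed-form minimizer of a scaled-residual objective with a quadratic amplitude prior: minimizing ||x - s*y||^2 + (s - 1)^2 / lam over s gives s_hat = (x.y + 1/lam) / (y.y + 1/lam) = b / a, and the minimum equals ||x||^2 - (2*s_hat*b - s_hat^2*a - 1/lam), i.e. the norm_reduction form above. A toy numpy check of that identity (arrays are hypothetical, not repo code):

    import numpy as np

    rng = np.random.default_rng(0)
    x, y = rng.normal(size=121), rng.normal(size=121)  # toy "templates"
    lam = 0.001                                        # amplitude_scaling_variance
    inv_lambda = 1 / lam
    b = x @ y + inv_lambda
    a = y @ y + inv_lambda
    s = b / a                                          # unclipped optimal scaling
    direct = np.sum((x - s * y) ** 2) + (s - 1) ** 2 * inv_lambda
    via_reduction = x @ x - (2 * s * b - s ** 2 * a - inv_lambda)
    assert np.isclose(direct, via_reduction)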
src/dartsort/templates/templates.py (2 additions, 1 deletion)
@@ -47,6 +47,8 @@ def to_npz(self, npz_path):
             to_save[
                 "registered_template_depths_um"
             ] = self.registered_template_depths_um
+        if not npz_path.parent.exists():
+            npz_path.parent.mkdir()
         np.savez(npz_path, **to_save)

     def coarsen(self, with_locs=True):
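
A hedged aside, not part of the commit: a bare Path.mkdir() still raises if the grandparent directory is also missing, or if another process creates the directory between the exists() check and the call. pathlib's race-free idiom is

    npz_path.parent.mkdir(parents=True, exist_ok=True)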
@@ -171,7 +173,6 @@ def from_config(
         else:
             # we don't skip empty units
             unit_ids = np.arange(sorting.labels.max() + 1)
-        print(f"post superres {sorting.times_samples=} {sorting.labels=}")

         # count spikes in each template
         spike_counts = np.zeros_like(unit_ids)
src/dartsort/util/drift_util.py (0 additions, 2 deletions)
@@ -564,8 +564,6 @@ def get_shift_and_unit_pairs(
     else:
         shifts_a = pitch_shifts[:, :na]
         shifts_b = pitch_shifts[:, na:]
-    print(f"{shifts_a.min()=} {shifts_a.max()=}")
-    print(f"{shifts_b.min()=} {shifts_b.max()=}")

     # assign ids to pitch/shift pairs
     template_shift_index_a = TemplateShiftIndex.from_shift_matrix(shifts_a)
