Re-enable scaling in merge; multiprocessing QoL; scatterplots QoL
cwindolf committed Feb 18, 2024
1 parent e427b27 commit dbb0ab6
Showing 6 changed files with 82 additions and 63 deletions.
30 changes: 15 additions & 15 deletions src/dartsort/cluster/merge.py
@@ -24,9 +24,9 @@ def merge_templates(
     sym_function=np.minimum,
     merge_distance_threshold=0.25,
     temporal_upsampling_factor=8,
-    amplitude_scaling_variance=0.0,
-    amplitude_scaling_boundary=0.5,
-    svd_compression_rank=10,
+    amplitude_scaling_variance=0.001,
+    amplitude_scaling_boundary=0.1,
+    svd_compression_rank=20,
     min_channel_amplitude=1.0,
     min_spatial_cosine=0.0,
     conv_batch_size=128,
@@ -115,9 +115,9 @@ def merge_across_sortings(
     sym_function=np.minimum,
     max_shift_samples=20,
     temporal_upsampling_factor=8,
-    amplitude_scaling_variance=0.0,
-    amplitude_scaling_boundary=0.5,
-    svd_compression_rank=10,
+    amplitude_scaling_variance=0.001,
+    amplitude_scaling_boundary=0.1,
+    svd_compression_rank=20,
     min_channel_amplitude=0.0,
     min_spatial_cosine=0.0,
     conv_batch_size=128,
@@ -213,9 +213,9 @@ def calculate_merge_distances(
     sym_function=np.minimum,
     max_shift_samples=20,
     temporal_upsampling_factor=8,
-    amplitude_scaling_variance=0.0,
-    amplitude_scaling_boundary=0.5,
-    svd_compression_rank=10,
+    amplitude_scaling_variance=0.001,
+    amplitude_scaling_boundary=0.1,
+    svd_compression_rank=20,
     min_channel_amplitude=1.0,
     min_spatial_cosine=0.0,
     cooccurrence_mask=None,
@@ -299,9 +299,9 @@ def cross_match_distance_matrix(
     sym_function=np.minimum,
     max_shift_samples=20,
     temporal_upsampling_factor=8,
-    amplitude_scaling_variance=0.0,
-    amplitude_scaling_boundary=0.5,
-    svd_compression_rank=10,
+    amplitude_scaling_variance=0.001,
+    amplitude_scaling_boundary=0.1,
+    svd_compression_rank=20,
     min_channel_amplitude=0.0,
     min_spatial_cosine=0.0,
     conv_batch_size=128,
@@ -482,9 +482,9 @@ def get_deconv_resid_norm_iter(
     template_data,
     max_shift_samples=20,
     temporal_upsampling_factor=8,
-    amplitude_scaling_variance=0.0,
-    amplitude_scaling_boundary=0.5,
-    svd_compression_rank=10,
+    amplitude_scaling_variance=0.001,
+    amplitude_scaling_boundary=0.1,
+    svd_compression_rank=20,
     min_channel_amplitude=0.0,
     min_spatial_cosine=0.0,
     cooccurrence_mask=None,
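
Note on the merge.py changes above: the amplitude-scaling defaults move from disabled (variance 0.0) to a tight, regularized range, and the SVD rank used for template compression doubles from 10 to 20. A rough sketch of how the two scaling knobs behave, mirroring the conv_to_resid logic in pairwise_util.py below (illustrative only, not part of the commit):

import numpy as np

def scaling_range(amplitude_scaling_variance=0.001, amplitude_scaling_boundary=0.1):
    # boundary bounds the per-pair template scale; variance sets the ridge penalty
    amp_scale_min = 1.0 / (1.0 + amplitude_scaling_boundary)
    amp_scale_max = 1.0 + amplitude_scaling_boundary
    inv_lambda = 1.0 / amplitude_scaling_variance
    return amp_scale_min, amp_scale_max, inv_lambda

print(scaling_range())  # (0.909..., 1.1, 1000.0): scales confined near 1, strongly regularized
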
72 changes: 45 additions & 27 deletions src/dartsort/templates/pairwise_util.py
@@ -323,8 +323,9 @@ class DeconvResidResult:


 def conv_to_resid(
-    template_data_a: templates.TemplateData,
-    template_data_b: templates.TemplateData,
+    # template_data_a: templates.TemplateData,
+    low_rank_templates_a: template_util.LowRankTemplates,
+    low_rank_templates_b: template_util.LowRankTemplates,
     conv_result: CompressedConvResult,
     amplitude_scaling_variance=0.0,
     amplitude_scaling_boundary=0.5,
@@ -344,10 +345,19 @@ def conv_to_resid(
     deconv_resid_norms = np.zeros(n_pairs)
     shifts = np.zeros(n_pairs, dtype=int)
     template_indices_a, template_indices_b = pairs.T
-    templates_a = template_data_a.templates[template_indices_a]
-    templates_b = template_data_b.templates[template_indices_b]
-    template_a_norms = np.linalg.norm(templates_a, axis=(1, 2)) ** 2
-    template_b_norms = np.linalg.norm(templates_b, axis=(1, 2)) ** 2
+
+    # templates_a = template_data_a.templates[template_indices_a]
+    # template_a_norms = np.linalg.norm(templates_a, axis=(1, 2)) ** 2
+    # templates_b = template_data_b.templates[template_indices_b]
+    # template_b_norms = np.linalg.norm(templates_b, axis=(1, 2)) ** 2
+
+    # low rank template norms
+    svs_a = low_rank_templates_a.singular_values[template_indices_a]
+    template_a_norms = torch.square(svs_a).sum(1).numpy(force=True)
+    svs_b = low_rank_templates_b.singular_values[template_indices_b]
+    template_b_norms = torch.square(svs_b).sum(1).numpy(force=True)
+
+    # now, compute reduction in norm of A after matching by B
     for j, (ix_a, ix_b) in enumerate(pairs):
         in_a = conv_result.template_indices_a == ix_a
         in_b = conv_result.template_indices_b == ix_b
@@ -361,18 +371,18 @@

         # figure out scaling
         if amplitude_scaling_variance:
-            amp_scale_min = 1. / (1. + amplitude_scaling_boundary)
-            amp_scale_max = 1. + amplitude_scaling_boundary
-            inv_lambda = 1. / amplitude_scaling_variance
+            amp_scale_min = 1.0 / (1.0 + amplitude_scaling_boundary)
+            amp_scale_max = 1.0 + amplitude_scaling_boundary
+            inv_lambda = 1.0 / amplitude_scaling_variance
             b = best_conv + inv_lambda
-            a = template_a_norms[j] + inv_lambda
-            scaling = np.clip(b / a, amp_scale_min, amp_scale_max)
-            norm_reduction = 2. * scaling * b - np.square(scaling) * a - inv_lambda
+            a = template_b_norms[j] + inv_lambda
+            scaling = (b / a).clip(amp_scale_min, amp_scale_max)
+            norm_reduction = 2.0 * scaling * b - np.square(scaling) * a - inv_lambda
         else:
-            norm_reduction = 2. * best_conv - template_b_norms[j]
+            norm_reduction = 2.0 * best_conv - template_b_norms[j]

         deconv_resid_norms[j] = template_a_norms[j] - norm_reduction

-        assert deconv_resid_norms[j] >= -1e-1
+    assert (deconv_resid_norms >= -0.01).all()

     return DeconvResidResult(
         template_indices_a,
@@ -482,7 +492,7 @@ def compressed_convolve_pairs(
         compression_index,
         conv_ix,
         conv_upsampling_indices_b,
-        conv_temporal_components_up_b, #Need to change this conv_temporal_components_up_b[conv_compressed_upsampled_ix_b]
+        conv_temporal_components_up_b,  # Need to change this conv_temporal_components_up_b[conv_compressed_upsampled_ix_b]
         conv_compressed_upsampled_ix_b,
         compression_dup_ix,
     ) = compressed_upsampled_pairs(
@@ -554,8 +564,8 @@ def compressed_convolve_pairs(
         )
     if reduce_deconv_resid_norm:
         return conv_to_resid(
-            template_data_a,
-            template_data_b,
+            low_rank_templates_a,
+            low_rank_templates_b,
             res,
             amplitude_scaling_variance=amplitude_scaling_variance,
             amplitude_scaling_boundary=amplitude_scaling_boundary,
@@ -630,7 +640,9 @@ def correlate_pairs_lowrank(
         ix = slice(istart, iend)

         # want conv filter: nco, 1, rank, t
-        template_a = torch.bmm(temporal_a[ix_a[conv_ix][ix]], spatial_a[ix_a[conv_ix][ix]])
+        template_a = torch.bmm(
+            temporal_a[ix_a[conv_ix][ix]], spatial_a[ix_a[conv_ix][ix]]
+        )
         conv_filt = torch.bmm(spatial_b[ix_b[conv_ix][ix]], template_a.mT)
         conv_filt = conv_filt[:, None]  # (nco, 1, rank, t)

@@ -982,7 +994,15 @@ def compressed_upsampled_pairs(
         # temp_comps = compressed_upsampled_temporal.compressed_upsampled_templates[
         #     np.atleast_1d(temp_ix_b[ix_b[conv_ix]])
         # ]
-        return ix_b, compression_index, conv_ix, upinds, compressed_upsampled_temporal.compressed_upsampled_templates, np.atleast_1d(temp_ix_b[ix_b[conv_ix]]), compression_dup_ix
+        return (
+            ix_b,
+            compression_index,
+            conv_ix,
+            upinds,
+            compressed_upsampled_temporal.compressed_upsampled_templates,
+            np.atleast_1d(temp_ix_b[ix_b[conv_ix]]),
+            compression_dup_ix,
+        )

     # each conv_ix needs to be duplicated as many times as its b template has
     # upsampled copies
@@ -993,9 +1013,7 @@
     )
     conv_up_i, up_shift_up_i = np.nonzero(upsampling_mask)
     conv_compressed_upsampled_ix = (
-        upsampled_shifted_template_index.up_shift_temp_ix_to_comp_up_ix[
-            up_shift_up_i
-        ]
+        upsampled_shifted_template_index.up_shift_temp_ix_to_comp_up_ix[up_shift_up_i]
     )
     conv_dup = conv_ix[conv_up_i]
     # And, all ix_{a,b}[i] such that compression_ix[i] lands in
@@ -1005,9 +1023,9 @@
     dup_mask = dup_mask.numpy(force=True)
     compression_dup_ix, compression_index_up = np.nonzero(dup_mask)
     ix_b_up = ix_b[compression_dup_ix]

     # the conv ix need to be offset to keep the relation with the pairs
-    # ix_a[old i]
+    # ix_a[old i]
     # offsets = np.cumsum((conv_ix[:, None] == conv_dup[None, :]).sum(0))
     # offsets -= offsets[0]
     _, offsets = np.unique(compression_dup_ix, return_index=True)
@@ -1019,9 +1037,9 @@
             conv_compressed_upsampled_ix
         ]
     )

     # conv_temporal_components_up_b = compressed_upsampled_temporal.compressed_upsampled_templates

     return (
         ix_b_up,
         compression_index_up,
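
The substantive change in pairwise_util.py is that conv_to_resid now takes the low-rank templates rather than the dense ones, getting squared template norms from the singular values (with orthonormal spatial and temporal factors, the squared Frobenius norm is the sum of squared singular values), and the ridge-regularized scaling now divides by the norm of template B, the template being subtracted. A minimal numpy sketch of that arithmetic, with assumed shapes (singular_values is (n_templates, rank); best_conv is the peak cross-correlation of A with B), not the actual implementation:

import numpy as np

def template_sq_norms(singular_values):
    # ||T||_F^2 = sum_i s_i^2 when the low-rank factors are orthonormal
    return np.square(singular_values).sum(axis=1)

def scaled_norm_reduction(best_conv, template_b_sq_norm,
                          amplitude_scaling_variance=0.001,
                          amplitude_scaling_boundary=0.1):
    # optimal clipped scale s of B against A under a ridge penalty (s - 1)^2 / variance;
    # returns the drop in ||A||^2 after subtracting s * B at the best lag
    inv_lambda = 1.0 / amplitude_scaling_variance
    b = best_conv + inv_lambda
    a = template_b_sq_norm + inv_lambda
    s = np.clip(b / a, 1.0 / (1.0 + amplitude_scaling_boundary), 1.0 + amplitude_scaling_boundary)
    return 2.0 * s * b - np.square(s) * a - inv_lambda

# per-pair residual norm, as in conv_to_resid above:
# resid = template_a_sq_norm - scaled_norm_reduction(best_conv, template_b_sq_norm)
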
3 changes: 2 additions & 1 deletion src/dartsort/util/analysis.py
@@ -580,7 +580,8 @@ def nearby_coarse_templates(self, unit_id, n_neighbors=5):
         unit_ix = np.searchsorted(self.unit_ids, unit_id)
         unit_dists = self.merge_dist[unit_ix]
         distance_order = np.argsort(unit_dists)
-        assert distance_order[0] == unit_ix
+        distance_order = np.concatenate(([unit_ix], distance_order[distance_order != unit_ix]))
+        # assert distance_order[0] == unit_ix
         neighbor_ixs = distance_order[:n_neighbors]
         neighbor_ids = self.unit_ids[neighbor_ixs]
         neighbor_dists = self.merge_dist[neighbor_ixs[:, None], neighbor_ixs[None, :]]
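
The analysis.py change swaps a hard assert for an explicit reorder: argsort alone does not guarantee the query unit comes first when another unit ties (or beats) its self-distance, so the unit is now moved to the front before taking neighbors. A toy illustration of the failure mode, with hypothetical distances:

import numpy as np

unit_ix = 2
unit_dists = np.array([0.0, 0.3, 0.0, 0.7])   # unit 0 ties the self-distance of unit 2
distance_order = np.argsort(unit_dists)        # -> [0, 2, 1, 3]; the old assert would fail
distance_order = np.concatenate(([unit_ix], distance_order[distance_order != unit_ix]))
print(distance_order)                          # [2, 0, 1, 3]; the query unit is first again
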
9 changes: 8 additions & 1 deletion src/dartsort/util/multiprocessing_util.py
@@ -6,8 +6,10 @@

 try:
     import cloudpickle
+    have_cloudpickle = True
 except ImportError:
     pass
+    have_cloudpickle = False


 class MockFuture:
@@ -69,12 +71,17 @@ def submit(self, fn, /, *args, **kwargs):


 def get_pool(
-    n_jobs, context="spawn", cls=ProcessPoolExecutor, with_rank_queue=False
+    n_jobs,
+    context="spawn",
+    cls=ProcessPoolExecutor,
+    with_rank_queue=False,
 ):
     if n_jobs == -1:
         n_jobs = multiprocessing.cpu_count()
     do_parallel = n_jobs >= 1
     n_jobs = max(1, n_jobs)
+    if cls == CloudpicklePoolExecutor and not have_cloudpickle:
+        cls = ProcessPoolExecutor
     Executor = cls if do_parallel else MockPoolExecutor
     context = get_context(context)
     if with_rank_queue:
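
With the multiprocessing_util.py change, requesting the cloudpickle-backed executor is safe even when cloudpickle is not installed: get_pool silently falls back to the stock ProcessPoolExecutor. A usage sketch modeled on the call site in unit.py below (assuming the with_rank_queue=False return shape of (n_jobs, Executor, context)):

from dartsort.util.multiprocessing_util import CloudpicklePoolExecutor, get_pool

n_jobs, Executor, context = get_pool(4, cls=CloudpicklePoolExecutor)
with Executor(max_workers=n_jobs, mp_context=context) as pool:
    futures = [pool.submit(pow, 2, k) for k in range(8)]
    print([f.result() for f in futures])  # [1, 2, 4, ..., 128]
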
27 changes: 10 additions & 17 deletions src/dartsort/vis/scatterplots.py
@@ -486,29 +486,22 @@ def scatter_feature_vs_depth(
         c = np.clip(amplitudes, 0, amplitude_color_cutoff)
         cmap = amplitude_cmap
         kept = slice(None)
-    elif show_triaged:
-        c = labels
-        # cmap = colorcet.m_glasbey_light
-        cmap = glasbey1024
-        print("quack")
-        c = cmap[c % len(cmap)]
-        kept = labels[to_show] >= 0
-        ax.scatter(
-            feature[to_show[~kept]],
-            depths_um[to_show[~kept]],
-            color="dimgray",
-            s=s,
-            linewidth=linewidth,
-            rasterized=rasterized,
-            **scatter_kw,
-        )
     else:
         c = labels
         # cmap = colorcet.m_glasbey_light
         cmap = glasbey1024
-        print("quack")
         c = cmap[c % len(cmap)]
         kept = labels[to_show] >= 0
+        if show_triaged:
+            ax.scatter(
+                feature[to_show[~kept]],
+                depths_um[to_show[~kept]],
+                color="dimgray",
+                s=s,
+                linewidth=linewidth,
+                rasterized=rasterized,
+                **scatter_kw,
+            )

     s = ax.scatter(
         feature[to_show[kept]],
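
The scatterplots.py change folds the show_triaged case into the labeled branch and drops a stray debug print: kept spikes are colored by label with the glasbey1024 palette, and when show_triaged is set, triaged spikes (label < 0) are underlaid in gray. A schematic of that coloring logic with a stand-in palette (illustrative only; glasbey1024 here is just a random (1024, 3) RGB array, not the real colormap):

import numpy as np
import matplotlib.pyplot as plt

rng = np.random.default_rng(0)
labels = rng.integers(-1, 5, size=200)        # -1 means triaged
feature, depths_um = rng.normal(size=(2, 200))
glasbey1024 = rng.uniform(size=(1024, 3))     # stand-in for the real palette

kept = labels >= 0
c = glasbey1024[labels % len(glasbey1024)]
fig, ax = plt.subplots()
ax.scatter(feature[~kept], depths_um[~kept], color="dimgray", s=3)  # triaged underlay
ax.scatter(feature[kept], depths_um[kept], c=c[kept], s=3)          # colored by unit label
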
4 changes: 2 additions & 2 deletions src/dartsort/vis/unit.py
@@ -17,7 +17,7 @@
 import numpy as np
 from matplotlib.legend_handler import HandlerTuple

-from ..util.multiprocessing_util import get_pool
+from ..util.multiprocessing_util import get_pool, CloudpicklePoolExecutor
 from .waveforms import geomplot

 # -- main class. see fn make_unit_summary below to make lots of UnitPlots.
@@ -747,7 +747,7 @@ def make_all_summaries(
     save_folder = Path(save_folder)
     save_folder.mkdir(exist_ok=True)

-    n_jobs, Executor, context = get_pool(n_jobs)
+    n_jobs, Executor, context = get_pool(n_jobs, cls=CloudpicklePoolExecutor)
     with Executor(
         max_workers=n_jobs,
         mp_context=context,
