From 83d93cf1b13c49e543f1d6698901b415367cba82 Mon Sep 17 00:00:00 2001 From: Charlie Windolf Date: Tue, 5 Nov 2024 10:00:24 -0800 Subject: [PATCH] Analysis is slow... --- src/dartsort/templates/pairwise_util.py | 4 ++-- src/dartsort/util/analysis.py | 5 +++-- 2 files changed, 5 insertions(+), 4 deletions(-) diff --git a/src/dartsort/templates/pairwise_util.py b/src/dartsort/templates/pairwise_util.py index 66134dcc..7d98e3ce 100644 --- a/src/dartsort/templates/pairwise_util.py +++ b/src/dartsort/templates/pairwise_util.py @@ -894,8 +894,8 @@ def shift_deduplicated_pairs( dot = chan_amp_a @ chan_amp_b.T pair = dot > conv_ignore_threshold if min_spatial_cosine: - norm_a = torch.sqrt((chan_amp_a * chan_amp_a).sum(1)) - norm_b = torch.sqrt((chan_amp_b * chan_amp_b).sum(1)) + norm_a = torch.sqrt(chan_amp_a.square().sum(1)) + norm_b = torch.sqrt(chan_amp_b.square().sum(1)) cos = dot / (norm_a[:, None] * norm_b[None, :]) pair = pair & (cos > min_spatial_cosine) diff --git a/src/dartsort/util/analysis.py b/src/dartsort/util/analysis.py index 40fdf952..9b3ac0ef 100644 --- a/src/dartsort/util/analysis.py +++ b/src/dartsort/util/analysis.py @@ -68,7 +68,7 @@ class DARTsortAnalysis: merge_distance_kind: str = "rms" merge_distance_spatial_radius_a: Optional[float] = None merge_distance_min_channel_amplitude: float = 0.0 - merge_distance_min_spatial_cosine: float = 0.0 + merge_distance_min_spatial_cosine: float = 0.5 merge_temporal_upsampling: int = 1 merge_superres_linkage: Callable[[np.ndarray], float] = np.max compute_distances: bool = "if_hdf5" @@ -101,7 +101,8 @@ def from_sorting( assert model_dir.exists() featurization_pipeline = torch.load( - model_dir / "featurization_pipeline.pt" + model_dir / "featurization_pipeline.pt", + weights_only=True, ) have_templates = False