From 594a273d78bd70d561a837fdee1c56614c850774 Mon Sep 17 00:00:00 2001 From: Charlie Windolf <cwindolf95@gmail.com> Date: Thu, 18 Jan 2024 13:56:58 -0500 Subject: [PATCH 1/2] Export some useful stuff in top level __init__ --- src/dartsort/__init__.py | 10 ++++++++++ src/dartsort/cluster/relocate.py | 2 +- src/dartsort/util/data_util.py | 5 +++-- 3 files changed, 14 insertions(+), 3 deletions(-) diff --git a/src/dartsort/__init__.py b/src/dartsort/__init__.py index e69de29b..67cbcffd 100644 --- a/src/dartsort/__init__.py +++ b/src/dartsort/__init__.py @@ -0,0 +1,10 @@ +from config import * +from main import (DARTsortSorting, ObjectiveUpdateTemplateMatchingPeeler, + SubtractionPeeler, check_recording, cluster, dartsort, match, + run_peeler, split_merge, subtract) + +from .localize.localize_util import (localize_amplitude_vectors, localize_hdf5, + localize_waveforms) +from .peel.grab import GrabAndFeaturize +from .transform import WaveformPipeline +from .waveform_util import make_channel_index diff --git a/src/dartsort/cluster/relocate.py b/src/dartsort/cluster/relocate.py index d5c12160..60b2d412 100644 --- a/src/dartsort/cluster/relocate.py +++ b/src/dartsort/cluster/relocate.py @@ -64,7 +64,7 @@ def relocated_waveforms_on_static_channels( ) rescaling = target_amplitudes / original_amplitudes shifted_waveforms *= rescaling[:, None, :] - + if two_d: shifted_waveforms = shifted_waveforms[:, 0, :] diff --git a/src/dartsort/util/data_util.py b/src/dartsort/util/data_util.py index bd4b4b57..6da109ff 100644 --- a/src/dartsort/util/data_util.py +++ b/src/dartsort/util/data_util.py @@ -94,14 +94,15 @@ def from_peeling_hdf5( channels_dataset="channels", labels_dataset="labels", load_simple_features=True, + labels=None, ): - channels = labels = None + channels = None with h5py.File(peeling_hdf5_filename, "r") as h5: times_samples = h5[times_samples_dataset][()] sampling_frequency = h5["sampling_frequency"][()] if channels_dataset in h5: channels = h5[channels_dataset][()] - if labels_dataset in h5: + if labels_dataset in h5 and labels is None: labels = h5[labels_dataset][()] n_spikes = len(times_samples) From 8b19706380c477f3d7dac70c66fe7745c4f71b29 Mon Sep 17 00:00:00 2001 From: Charlie Windolf <cwindolf95@gmail.com> Date: Thu, 18 Jan 2024 14:00:09 -0500 Subject: [PATCH 2/2] Wow, oops. --- src/dartsort/__init__.py | 11 +++++------ 1 file changed, 5 insertions(+), 6 deletions(-) diff --git a/src/dartsort/__init__.py b/src/dartsort/__init__.py index 67cbcffd..8492ce2c 100644 --- a/src/dartsort/__init__.py +++ b/src/dartsort/__init__.py @@ -1,10 +1,9 @@ -from config import * -from main import (DARTsortSorting, ObjectiveUpdateTemplateMatchingPeeler, - SubtractionPeeler, check_recording, cluster, dartsort, match, - run_peeler, split_merge, subtract) - +from .config import * from .localize.localize_util import (localize_amplitude_vectors, localize_hdf5, localize_waveforms) +from .main import (DARTsortSorting, ObjectiveUpdateTemplateMatchingPeeler, + SubtractionPeeler, check_recording, cluster, dartsort, + match, run_peeler, split_merge, subtract) from .peel.grab import GrabAndFeaturize from .transform import WaveformPipeline -from .waveform_util import make_channel_index +from .util.waveform_util import make_channel_index