diff --git a/src/dartsort/__init__.py b/src/dartsort/__init__.py index e69de29b..8492ce2c 100644 --- a/src/dartsort/__init__.py +++ b/src/dartsort/__init__.py @@ -0,0 +1,9 @@ +from .config import * +from .localize.localize_util import (localize_amplitude_vectors, localize_hdf5, + localize_waveforms) +from .main import (DARTsortSorting, ObjectiveUpdateTemplateMatchingPeeler, + SubtractionPeeler, check_recording, cluster, dartsort, + match, run_peeler, split_merge, subtract) +from .peel.grab import GrabAndFeaturize +from .transform import WaveformPipeline +from .util.waveform_util import make_channel_index diff --git a/src/dartsort/cluster/relocate.py b/src/dartsort/cluster/relocate.py index d5c12160..60b2d412 100644 --- a/src/dartsort/cluster/relocate.py +++ b/src/dartsort/cluster/relocate.py @@ -64,7 +64,7 @@ def relocated_waveforms_on_static_channels( ) rescaling = target_amplitudes / original_amplitudes shifted_waveforms *= rescaling[:, None, :] - + if two_d: shifted_waveforms = shifted_waveforms[:, 0, :] diff --git a/src/dartsort/util/data_util.py b/src/dartsort/util/data_util.py index bd4b4b57..6da109ff 100644 --- a/src/dartsort/util/data_util.py +++ b/src/dartsort/util/data_util.py @@ -94,14 +94,15 @@ def from_peeling_hdf5( channels_dataset="channels", labels_dataset="labels", load_simple_features=True, + labels=None, ): - channels = labels = None + channels = None with h5py.File(peeling_hdf5_filename, "r") as h5: times_samples = h5[times_samples_dataset][()] sampling_frequency = h5["sampling_frequency"][()] if channels_dataset in h5: channels = h5[channels_dataset][()] - if labels_dataset in h5: + if labels_dataset in h5 and labels is None: labels = h5[labels_dataset][()] n_spikes = len(times_samples)