From 594a273d78bd70d561a837fdee1c56614c850774 Mon Sep 17 00:00:00 2001
From: Charlie Windolf <cwindolf95@gmail.com>
Date: Thu, 18 Jan 2024 13:56:58 -0500
Subject: [PATCH 1/2] Export some useful stuff in top level __init__

---
 src/dartsort/__init__.py         | 10 ++++++++++
 src/dartsort/cluster/relocate.py |  2 +-
 src/dartsort/util/data_util.py   |  5 +++--
 3 files changed, 14 insertions(+), 3 deletions(-)

diff --git a/src/dartsort/__init__.py b/src/dartsort/__init__.py
index e69de29b..67cbcffd 100644
--- a/src/dartsort/__init__.py
+++ b/src/dartsort/__init__.py
@@ -0,0 +1,10 @@
+from config import *
+from main import (DARTsortSorting, ObjectiveUpdateTemplateMatchingPeeler,
+                  SubtractionPeeler, check_recording, cluster, dartsort, match,
+                  run_peeler, split_merge, subtract)
+
+from .localize.localize_util import (localize_amplitude_vectors, localize_hdf5,
+                                     localize_waveforms)
+from .peel.grab import GrabAndFeaturize
+from .transform import WaveformPipeline
+from .waveform_util import make_channel_index
diff --git a/src/dartsort/cluster/relocate.py b/src/dartsort/cluster/relocate.py
index d5c12160..60b2d412 100644
--- a/src/dartsort/cluster/relocate.py
+++ b/src/dartsort/cluster/relocate.py
@@ -64,7 +64,7 @@ def relocated_waveforms_on_static_channels(
     )
     rescaling = target_amplitudes / original_amplitudes
     shifted_waveforms *= rescaling[:, None, :]
-    
+
     if two_d:
         shifted_waveforms = shifted_waveforms[:, 0, :]
 
diff --git a/src/dartsort/util/data_util.py b/src/dartsort/util/data_util.py
index bd4b4b57..6da109ff 100644
--- a/src/dartsort/util/data_util.py
+++ b/src/dartsort/util/data_util.py
@@ -94,14 +94,15 @@ def from_peeling_hdf5(
         channels_dataset="channels",
         labels_dataset="labels",
         load_simple_features=True,
+        labels=None,
     ):
-        channels = labels = None
+        channels = None
         with h5py.File(peeling_hdf5_filename, "r") as h5:
             times_samples = h5[times_samples_dataset][()]
             sampling_frequency = h5["sampling_frequency"][()]
             if channels_dataset in h5:
                 channels = h5[channels_dataset][()]
-            if labels_dataset in h5:
+            if labels_dataset in h5 and labels is None:
                 labels = h5[labels_dataset][()]
 
             n_spikes = len(times_samples)

From 8b19706380c477f3d7dac70c66fe7745c4f71b29 Mon Sep 17 00:00:00 2001
From: Charlie Windolf <cwindolf95@gmail.com>
Date: Thu, 18 Jan 2024 14:00:09 -0500
Subject: [PATCH 2/2] Wow, oops.

---
 src/dartsort/__init__.py | 11 +++++------
 1 file changed, 5 insertions(+), 6 deletions(-)

diff --git a/src/dartsort/__init__.py b/src/dartsort/__init__.py
index 67cbcffd..8492ce2c 100644
--- a/src/dartsort/__init__.py
+++ b/src/dartsort/__init__.py
@@ -1,10 +1,9 @@
-from config import *
-from main import (DARTsortSorting, ObjectiveUpdateTemplateMatchingPeeler,
-                  SubtractionPeeler, check_recording, cluster, dartsort, match,
-                  run_peeler, split_merge, subtract)
-
+from .config import *
 from .localize.localize_util import (localize_amplitude_vectors, localize_hdf5,
                                      localize_waveforms)
+from .main import (DARTsortSorting, ObjectiveUpdateTemplateMatchingPeeler,
+                   SubtractionPeeler, check_recording, cluster, dartsort,
+                   match, run_peeler, split_merge, subtract)
 from .peel.grab import GrabAndFeaturize
 from .transform import WaveformPipeline
-from .waveform_util import make_channel_index
+from .util.waveform_util import make_channel_index