Commit: Work on tests
cwindolf committed Oct 22, 2024
1 parent ee19a71 commit 1130bfc
Showing 11 changed files with 52 additions and 50 deletions.
35 changes: 12 additions & 23 deletions src/dartsort/main.py
@@ -1,31 +1,22 @@
 from dataclasses import asdict
 from pathlib import Path
-import numpy as np

+import numpy as np
 from dartsort.cluster.initial import ensemble_chunks
 from dartsort.cluster.merge import merge_templates
 from dartsort.cluster.split import split_clusters
-from dartsort.config import (
-    DARTsortConfig,
-    default_clustering_config,
-    default_dartsort_config,
-    default_featurization_config,
-    default_matching_config,
-    default_split_merge_config,
-    default_subtraction_config,
-    default_template_config,
-    default_waveform_config,
-)
-from dartsort.peel import (
-    ObjectiveUpdateTemplateMatchingPeeler,
-    SubtractionPeeler,
-)
+from dartsort.config import (DARTsortConfig, default_clustering_config,
+                             default_dartsort_config,
+                             default_featurization_config,
+                             default_matching_config,
+                             default_split_merge_config,
+                             default_subtraction_config,
+                             default_template_config, default_waveform_config)
+from dartsort.peel import (ObjectiveUpdateTemplateMatchingPeeler,
+                           SubtractionPeeler)
 from dartsort.templates import TemplateData
-from dartsort.util.data_util import (
-    DARTsortSorting,
-    check_recording,
-    keep_only_most_recent_spikes,
-)
+from dartsort.util.data_util import (DARTsortSorting, check_recording,
+                                     keep_only_most_recent_spikes)
 from dartsort.util.peel_util import run_peeler
 from dartsort.util.registration_util import estimate_motion

@@ -269,7 +260,6 @@ def match(
     featurization_config=default_featurization_config,
     matching_config=default_matching_config,
     chunk_starts_samples=None,
-    subsampling_proportion=1.0,
     n_jobs_templates=0,
     n_jobs_match=0,
     overwrite=False,
@@ -323,7 +313,6 @@ def match(
         model_subdir,
         featurization_config,
         chunk_starts_samples=chunk_starts_samples,
-        subsampling_proportion=subsampling_proportion,
         overwrite=overwrite,
         n_jobs=n_jobs_match,
         residual_filename=residual_filename,
2 changes: 1 addition & 1 deletion src/dartsort/peel/matching.py
@@ -657,7 +657,7 @@ def match_chunk(
             collisioncleaned_waveforms=waveforms,
         )
         if return_residual:
-            res["residual"] = residual
+            res["residual"] = residual[left_margin : traces.shape[0] - right_margin]
         if return_conv:
             res["conv"] = padded_conv
         return res
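A note on the one-line fix above: match_chunk works on a chunk padded with left and right margins, and the residual it builds has that padded length, so the new slice returns only the samples belonging to the chunk itself. A minimal sketch of the indexing, with made-up margins and shapes rather than dartsort's real values:

import numpy as np

# made-up margins and chunk length; `traces` stands in for the padded chunk
left_margin, right_margin, chunk_len = 120, 120, 30_000
traces = np.zeros((chunk_len + left_margin + right_margin, 384))
residual = traces.copy()  # residual has the same padded length as traces

trimmed = residual[left_margin : traces.shape[0] - right_margin]
assert trimmed.shape[0] == chunk_len  # margins dropped, chunk samples kept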
9 changes: 8 additions & 1 deletion src/dartsort/peel/peel_base.py
@@ -285,7 +285,7 @@ def precompute_peeling_data(
         # runs before fit_peeler_models()
         pass

-    def fit_peeler_models(self, save_folder, n_jobs=0, device=None):
+    def fit_peeler_models(self, save_folder, tmp_dir=None, n_jobs=0, device=None):
         # subclasses should override if they need to fit models for peeling
         assert not self.peeling_needs_fit()

@@ -538,6 +538,7 @@ def get_chunk_starts(
         subsampled=False,
         n_chunks=None,
         ordered=False,
+        skip_last=False,
     ):
         if chunk_starts_samples is not None:
             return chunk_starts_samples
@@ -550,6 +551,10 @@ def get_chunk_starts(
         if t_start is None:
             t_start = 0
         chunk_starts_samples = range(t_start, t_end, chunk_length_samples)
+        if skip_last:
+            chunk_starts_samples = list(chunk_starts_samples)
+            if t_end - chunk_starts_samples[-1] < chunk_length_samples:
+                chunk_starts_samples = chunk_starts_samples[:-1]

         if not subsampled:
             return chunk_starts_samples
@@ -584,6 +589,7 @@ def run_subsampled_peeling(
         task_name=None,
         overwrite=True,
         ordered=False,
+        skip_last=False,
     ):
         # run peeling on these chunks to the temp folder
         chunk_starts = self.get_chunk_starts(
@@ -593,6 +599,7 @@ def run_subsampled_peeling(
             t_end=t_end,
             n_chunks=n_chunks,
             ordered=ordered,
+            skip_last=skip_last,
         )
         self.peel(
             hdf5_filename,
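The skip_last flag added above drops the final chunk start when fewer than chunk_length_samples remain after it, so callers that want only full-length chunks (the "Residual snips" pass in run_peeler below) never see a short tail. The same logic as a self-contained sketch:

def chunk_starts(t_start, t_end, chunk_length_samples, skip_last=False):
    # evenly spaced starts; optionally drop a start whose chunk would be partial
    starts = list(range(t_start, t_end, chunk_length_samples))
    if skip_last and starts and t_end - starts[-1] < chunk_length_samples:
        starts = starts[:-1]
    return starts

# 0..950 in chunks of 300: the start at 900 is dropped, only 50 samples follow it
assert chunk_starts(0, 950, 300, skip_last=True) == [0, 300, 600]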
1 change: 0 additions & 1 deletion src/dartsort/transform/all_transformers.py
@@ -27,7 +27,6 @@
     TemporalPCA,
     Voltage,
     Decollider,
-    Passthrough,
 ]

 transformers_by_class_name = {cls.__name__: cls for cls in all_transformers}
2 changes: 1 addition & 1 deletion src/dartsort/transform/transform_base.py
@@ -86,7 +86,7 @@ class BaseWaveformAutoencoder(BaseWaveformDenoiser, BaseWaveformFeaturizer):

 class Passthrough(BaseWaveformDenoiser, BaseWaveformFeaturizer):

-    def __init__(self, pipeline):
+    def __init__(self, pipeline, geom=None, channel_index=None):
         t = [t for t in pipeline if t.is_featurizer]
         if not len(t):
             t = pipeline.transformers
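Presumably the new geom and channel_index keywords let Passthrough be built through the same code path as the other transformers, which receive probe-layout arguments; it accepts and ignores them. A hedged sketch of that pattern, where the factory is hypothetical and not dartsort code:

class Passthrough:
    def __init__(self, pipeline, geom=None, channel_index=None):
        # layout kwargs accepted only for constructor uniformity
        self.pipeline = pipeline

def make_transformer(cls, geom, channel_index, **kwargs):
    # hypothetical factory: every transformer class gets the layout kwargs
    return cls(geom=geom, channel_index=channel_index, **kwargs)

pt = make_transformer(Passthrough, geom=None, channel_index=None, pipeline=[])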
1 change: 1 addition & 0 deletions src/dartsort/util/peel_util.py
@@ -93,6 +93,7 @@ def run_peeler(
         task_name="Residual snips",
         overwrite=False,
         ordered=True,
+        skip_last=True,
     )

     return (
2 changes: 1 addition & 1 deletion src/dartsort/util/spikeio.py
@@ -28,7 +28,7 @@ def read_full_waveforms(
     read_times = times_samples - trough_offset_samples

     if not return_scaled and recording.binary_compatible_with(
-        file_offset=0, time_axis=0, file_paths_lenght=1
+        file_offset=0, time_axis=0, file_paths_length=1
     ):
         # fast path. this is like 2x as fast as the read_traces for loop
         # below, but requires a recording on disk in a nice format
3 changes: 3 additions & 0 deletions src/dartsort/util/spiketorch.py
@@ -40,6 +40,9 @@ def ravel_multi_index(multi_index, dims):
         Indices into the flattened tensor of shape `dims`
     """
     if len(dims) == 1:
+        if isinstance(multi_index, tuple):
+            assert len(multi_index) == 1
+            multi_index = multi_index[0]
         assert multi_index.ndim == 1
         return multi_index

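With the guard added above, the 1-D fast path also accepts numpy's tuple-of-index-arrays calling convention, unwrapping a 1-tuple before returning the indices unchanged; in one dimension, raveling is the identity. A small check of the behavior being mirrored:

import numpy as np
import torch

idx = torch.tensor([0, 3, 7])

# np.ravel_multi_index takes a tuple of index arrays; with dims (10,) it
# returns the single index array unchanged
flat = np.ravel_multi_index((idx.numpy(),), (10,))
assert np.array_equal(flat, idx.numpy())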
14 changes: 3 additions & 11 deletions tests/test_grab_and_featurize.py
@@ -252,17 +252,9 @@ def test_grab_and_featurize():
         assert h5["last_chunk_start"][()] == 90_000

     # this is kind of a good test of reproducibility
-    # totally reproducible on CPU, suprprisingly large diffs on GPU
-    # reproducibility is fine on some BLAS but not MKL?
-    repro = (not torch.cuda.is_available()) and (
-        "BLAS_INFO=mkl" not in torch.__config__.show()
-    )
-    if repro:
-        assert np.array_equal(locs0, locs1)
-    else:
-        valid = np.clip(locs1[:, 2], geom[:, 1].min(), geom[:, 1].max())
-        valid = locs1[:, 2] == valid
-        assert np.isclose(locs0[valid], locs1[valid], atol=1e-6).all()
+    valid = np.clip(locs1[:, 2], geom[:, 1].min(), geom[:, 1].max())
+    valid = locs1[:, 2] == valid
+    assert np.isclose(locs0[valid], locs1[valid], rtol=1e-3, atol=1e-3).all()


 if __name__ == "__main__":
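The platform-dependent exact-equality branch is gone: the test now always compares localizations through the clip idiom at a looser tolerance. An entry satisfies x == np.clip(x, lo, hi) exactly when it lies in [lo, hi], so only spikes whose fitted depth lands within the probe span are compared. A tiny demonstration of the mask:

import numpy as np

depths = np.array([-5.0, 10.0, 42.0, 500.0])
lo, hi = 0.0, 100.0  # stand-ins for geom[:, 1].min() and geom[:, 1].max()
valid = depths == np.clip(depths, lo, hi)
assert valid.tolist() == [False, True, True, False]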
3 changes: 3 additions & 0 deletions tests/test_matching.py
@@ -70,6 +70,7 @@ def test_tiny(tmp_path):
         n_jobs=0,
         save_folder=tmp_path,
         overwrite=True,
+        with_locs=True,
     )

     matcher = main.ObjectiveUpdateTemplateMatchingPeeler.from_config(
@@ -202,6 +203,7 @@ def test_tiny_up(tmp_path, up_factor=8):
         n_jobs=0,
         save_folder=tmp_path,
         overwrite=True,
+        with_locs=True,
     )
     print(f"{template_data.templates.ptp(1).max(1)=}")

@@ -366,6 +368,7 @@ def static_tester(tmp_path, up_factor=1):
         n_jobs=0,
         save_folder=tmp_path,
         overwrite=True,
+        with_locs=True,
     )

     matcher = main.ObjectiveUpdateTemplateMatchingPeeler.from_config(
30 changes: 19 additions & 11 deletions tests/test_subtract.py
@@ -11,6 +11,14 @@
 from dartsort.util import waveform_util
 from test_util import dense_layout

+fixedlenkeys = (
+    "subtract_channel_index",
+    "channel_index",
+    "geom",
+    "residual",
+    "residual_times_seconds",
+)
+

 def test_fakedata_nonn():
     print("test_fakedata_nonn")
@@ -109,7 +117,7 @@ def test_fakedata_nonn():
     with h5py.File(out_h5, locking=False) as h5:
         assert h5["times_samples"].shape == (ns0,)
         assert h5["channels"].shape == (ns0,)
-        assert h5["point_source_localizations"].shape == (ns0, 4)
+        assert h5["point_source_localizations"].shape in [(ns0, 4), (ns0, 3)]
         assert np.array_equal(h5["channel_index"][:], channel_index)
         assert h5["collisioncleaned_tpca_features"].shape == (
             ns0,
@@ -133,7 +141,7 @@ def test_fakedata_nonn():
     with h5py.File(out_h5, locking=False) as h5:
         assert h5["times_samples"].shape == (ns0,)
         assert h5["channels"].shape == (ns0,)
-        assert h5["point_source_localizations"].shape == (ns0, 4)
+        assert h5["point_source_localizations"].shape in [(ns0, 4), (ns0, 3)]
         assert np.array_equal(h5["channel_index"][:], channel_index)
         assert np.array_equal(h5["geom"][()], geom)
         assert h5["last_chunk_start"][()] == int(np.floor(T_s) * fs)
@@ -157,7 +165,7 @@ def test_fakedata_nonn():
     with h5py.File(out_h5, locking=False) as h5:
         assert h5["times_samples"].shape == (ns0,)
         assert h5["channels"].shape == (ns0,)
-        assert h5["point_source_localizations"].shape == (ns0, 4)
+        assert h5["point_source_localizations"].shape in [(ns0, 4), (ns0, 3)]
         assert np.array_equal(h5["channel_index"][:], channel_index)
         assert np.array_equal(h5["geom"][()], geom)
         assert h5["last_chunk_start"][()] == int(np.floor(T_s) * fs)
@@ -181,7 +189,7 @@ def test_fakedata_nonn():
     with h5py.File(out_h5, locking=False) as h5:
         assert h5["times_samples"].shape == (ns0,)
         assert h5["channels"].shape == (ns0,)
-        assert h5["point_source_localizations"].shape == (ns0, 4)
+        assert h5["point_source_localizations"].shape in [(ns0, 4), (ns0, 3)]
         assert np.array_equal(h5["channel_index"][:], channel_index)
         assert h5["collisioncleaned_tpca_features"].shape == (
             ns0,
@@ -206,7 +214,7 @@ def test_fakedata_nonn():
     with h5py.File(out_h5, locking=False) as h5:
         assert h5["times_samples"].shape == (ns0,)
         assert h5["channels"].shape == (ns0,)
-        assert h5["point_source_localizations"].shape == (ns0, 4)
+        assert h5["point_source_localizations"].shape in [(ns0, 4), (ns0, 3)]
         assert np.array_equal(h5["channel_index"][:], channel_index)
         assert np.array_equal(h5["geom"][()], geom)
         assert h5["last_chunk_start"][()] == int(np.floor(T_s) * fs)
@@ -231,7 +239,7 @@ def test_fakedata_nonn():
     with h5py.File(out_h5, locking=False) as h5:
         assert h5["times_samples"].shape == (ns0,)
         assert h5["channels"].shape == (ns0,)
-        assert h5["point_source_localizations"].shape == (ns0, 4)
+        assert h5["point_source_localizations"].shape in [(ns0, 4), (ns0, 3)]
         assert np.array_equal(h5["channel_index"][:], channel_index)
         assert np.array_equal(h5["geom"][()], geom)
         assert h5["last_chunk_start"][()] == int(np.floor(T_s) * fs)
@@ -348,7 +356,7 @@ def test_small_nonn():
     with h5py.File(out_h5, locking=False) as h5:
         lens = []
         for k in h5.keys():
-            if k not in ("subtract_channel_index", "channel_index", "geom") and h5[k].ndim >= 1:
+            if k not in fixedlenkeys and h5[k].ndim >= 1:
                 lens.append(h5[k].shape[0])
         assert np.unique(lens).size == 1

@@ -367,7 +375,7 @@ def test_small_nonn():
     with h5py.File(out_h5, locking=False) as h5:
         lens = []
         for k in h5.keys():
-            if k not in ("subtract_channel_index", "channel_index", "geom") and h5[k].ndim >= 1:
+            if k not in fixedlenkeys and h5[k].ndim >= 1:
                 lens.append(h5[k].shape[0])
         assert np.unique(lens).size == 1

@@ -385,7 +393,7 @@ def test_small_nonn():
     with h5py.File(out_h5, locking=False) as h5:
         lens = []
         for k in h5.keys():
-            if k not in ("subtract_channel_index", "channel_index", "geom") and h5[k].ndim >= 1:
+            if k not in fixedlenkeys and h5[k].ndim >= 1:
                 lens.append(h5[k].shape[0])
         assert np.unique(lens).size == 1

@@ -423,7 +431,7 @@ def small_default_config(extract_radius=200):
     with h5py.File(out_h5, locking=False) as h5:
         lens = []
         for k in h5.keys():
-            if k not in ("subtract_channel_index", "channel_index", "geom") and h5[k].ndim >= 1:
+            if k not in fixedlenkeys and h5[k].ndim >= 1:
                 lens.append(h5[k].shape[0])
         assert np.unique(lens).size == 1

@@ -439,7 +447,7 @@ def small_default_config(extract_radius=200):
     with h5py.File(out_h5, locking=False) as h5:
         lens = []
         for k in h5.keys():
-            if k not in ("subtract_channel_index", "channel_index", "geom") and h5[k].ndim >= 1:
+            if k not in fixedlenkeys and h5[k].ndim >= 1:
                 lens.append(h5[k].shape[0])
         assert np.unique(lens).size == 1

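The repeated inline tuples become the module-level fixedlenkeys constant, which now also exempts the residual and residual_times_seconds datasets; each test then checks that every other dataset with at least one dimension shares a single per-spike leading length. That invariant as a standalone helper, with a hypothetical name not taken from the diff:

import h5py
import numpy as np

def assert_one_spike_length(h5_path, exempt):
    # every non-exempt dataset with ndim >= 1 must share its leading length
    with h5py.File(h5_path, locking=False) as h5:
        lens = [h5[k].shape[0] for k in h5.keys()
                if k not in exempt and h5[k].ndim >= 1]
    assert np.unique(lens).size == 1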
