Skip to content

Commit

Permalink
Fix a couple tests
Browse files Browse the repository at this point in the history
  • Loading branch information
cwindolf committed Nov 7, 2024
1 parent 5c70958 commit 41ca76d
Show file tree
Hide file tree
Showing 2 changed files with 3 additions and 22 deletions.
17 changes: 2 additions & 15 deletions tests/test_matching.py
Original file line number Diff line number Diff line change
Expand Up @@ -182,7 +182,6 @@ def test_tiny_up(tmp_path, up_factor=8):
temp = cupts.compressed_upsampled_templates[
cupts.compressed_upsampling_map[l, u]
]
print(f"{t=} {temp.ptp(0).max()=}")
rec0[
t - trough_offset_samples : t - trough_offset_samples + spike_length_samples
] += temp
Expand All @@ -205,7 +204,6 @@ def test_tiny_up(tmp_path, up_factor=8):
overwrite=True,
with_locs=True,
)
print(f"{template_data.templates.ptp(1).max(1)=}")

matcher = main.ObjectiveUpdateTemplateMatchingPeeler.from_config(
rec,
Expand All @@ -225,7 +223,7 @@ def test_tiny_up(tmp_path, up_factor=8):
)
tempup = template_util.compressed_upsampled_templates(
lrt.temporal_components,
ptps=template_data.templates.ptp(1).max(1),
ptps=np.ptp(template_data.templates, 1).max(1),
max_upsample=up_factor,
)
assert np.array_equal(
Expand Down Expand Up @@ -287,17 +285,6 @@ def test_tiny_up(tmp_path, up_factor=8):
return_conv=True,
)

print(f'{res["n_spikes"]=} {len(times)=}')
print(f"{cupts.compressed_upsampled_templates.ptp(1).max(1)=}")
print(
f'{res["collisioncleaned_waveforms"].numpy(force=True).ptp(1).max(1)=}'
)
print(
f'{np.c_[res["times_samples"], res["labels"], res["upsampling_indices"]]=}'
)
print(f"{np.c_[times, labels, upsampling_indices]=}")
print(f'{torch.square(res["residual"]).mean()=}')
print(f"C {torch.square(res['conv']).mean()=}")
assert res["n_spikes"] == len(times)
assert np.array_equal(res["times_samples"], times)
assert np.array_equal(res["labels"], labels)
Expand Down Expand Up @@ -392,7 +379,7 @@ def static_tester(tmp_path, up_factor=1):
)
tempup = template_util.compressed_upsampled_templates(
lrt.temporal_components,
ptps=template_data.templates.ptp(1).max(1),
ptps=np.ptp(template_data.templates, 1).max(1),
max_upsample=up_factor,
)
assert np.array_equal(
Expand Down
8 changes: 1 addition & 7 deletions tests/test_templates.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
from dartsort import config
from dartsort.templates import (get_templates, pairwise, pairwise_util,
template_util, templates)
from dartsort.util import drift_util
from dartsort.util.data_util import DARTsortSorting
from dredge.motion_util import IdentityMotionEstimate, get_motion_estimate
from test_util import no_overlap_recording_sorting
Expand All @@ -27,11 +26,6 @@ def test_roundtrip(tmp_path):
save_folder=tmp_path,
overwrite=True,
)
print(f"{np.abs(template_data.templates - temps).max()=}")
print(f"{np.abs(template_data.templates - temps).mean()=}")
print(f"{np.abs(template_data.templates - temps).min()=}")
print(f"{template_data.templates.ptp(1).max(1)=}")
print(f"{temps.ptp(1).max(1)=}")
assert np.array_equal(template_data.templates, temps)


Expand Down Expand Up @@ -218,7 +212,7 @@ def test_pconv():
svd_compressed = template_util.svd_compress_templates(temps, rank=1)
ctempup = template_util.compressed_upsampled_templates(
svd_compressed.temporal_components,
ptps=temps.ptp(1).max(1),
ptps=np.ptp(temps, 1).max(1),
max_upsample=1,
kind="cubic",
)
Expand Down

0 comments on commit 41ca76d

Please sign in to comment.