Skip to content

Commit 2388b24

Browse files
committed
Debug test reproducibility on MKL
1 parent 86dd69a commit 2388b24

File tree

2 files changed

+6
-7
lines changed

2 files changed

+6
-7
lines changed

tests/test_drift_util.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ def test_shifted_waveforms():
1515
assert np.array_equal(np.unique(reg_geom[:, 1]), [1, 2, 3, 4, 5, 6, 7])
1616

1717
# fixed check
18-
waveforms = np.arange(15).reshape(5, 3)[:, None, :]
18+
waveforms = np.arange(15).reshape(5, 3)[:, None, :].astype(np.float32)
1919
w = drift_util.get_waveforms_on_static_channels(waveforms, geom=geom)
2020
assert np.array_equal(w, waveforms)
2121
w = drift_util.get_waveforms_on_static_channels(

tests/test_matching.py

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -92,8 +92,8 @@ def test_tiny(tmp_path):
9292
assert np.isclose(
9393
torch.square(res["residual"]).mean(),
9494
0.0,
95-
atol=1e-4,
9695
)
96+
print(f"A {torch.square(res['conv']).mean()=}")
9797
assert np.isclose(
9898
torch.square(res["conv"]).mean(),
9999
0.0,
@@ -125,8 +125,8 @@ def test_tiny(tmp_path):
125125
assert np.isclose(
126126
torch.square(res["residual"]).mean(),
127127
0.0,
128-
atol=1e-4,
129128
)
129+
print(f"B {torch.square(res['conv']).mean()=}")
130130
assert np.isclose(
131131
torch.square(res["conv"]).mean(),
132132
0.0,
@@ -265,14 +265,13 @@ def test_tiny_up(tmp_path, up_factor=8):
265265
print(f'{np.c_[res["times_samples"], res["labels"], res["upsampling_indices"]]=}')
266266
print(f"{np.c_[times, labels, upsampling_indices]=}")
267267
print(f'{torch.square(res["residual"]).mean()=}')
268-
print(f'{torch.square(res["conv"]).mean()=}')
268+
print(f"C {torch.square(res['conv']).mean()=}")
269269
assert res["n_spikes"] == len(times)
270270
assert np.array_equal(res["times_samples"], times)
271271
assert np.array_equal(res["labels"], labels)
272272
assert np.isclose(
273273
torch.square(res["residual"]).mean(),
274274
0.0,
275-
atol=1e-4,
276275
)
277276
assert np.isclose(
278277
torch.square(res["conv"]).mean(),
@@ -413,12 +412,12 @@ def static_tester(tmp_path, up_factor=1):
413412
assert np.isclose(
414413
torch.square(res["residual"]).mean(),
415414
0.0,
416-
atol=1e-4,
417415
)
416+
print(f"D {torch.square(res['conv']).mean()=}")
418417
assert np.isclose(
419418
torch.square(res["conv"]).mean(),
420419
0.0,
421-
atol=1e-4,
420+
atol=1e-3,
422421
)
423422

424423

0 commit comments

Comments
 (0)