Skip to content

Commit 86dd69a

Browse files
committed
Test reproducibility
1 parent 244a415 commit 86dd69a

File tree

2 files changed

+12
-4
lines changed

2 files changed

+12
-4
lines changed

tests/test_grab_and_featurize.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -253,7 +253,11 @@ def test_grab_and_featurize():
253253

254254
# this is kind of a good test of reproducibility
255255
# totally reproducible on CPU, suprprisingly large diffs on GPU
256-
if not torch.cuda.is_available():
256+
# reproducibility is fine on some BLAS but not MKL?
257+
repro = (not torch.cuda.is_available()) and (
258+
"BLAS_INFO=mkl" not in torch.__config__.show()
259+
)
260+
if repro:
257261
assert np.array_equal(locs0, locs1)
258262
else:
259263
valid = np.clip(locs1[:, 2], geom[:, 1].min(), geom[:, 1].max())

tests/test_matching.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -92,11 +92,12 @@ def test_tiny(tmp_path):
9292
assert np.isclose(
9393
torch.square(res["residual"]).mean(),
9494
0.0,
95+
atol=1e-4,
9596
)
9697
assert np.isclose(
9798
torch.square(res["conv"]).mean(),
9899
0.0,
99-
atol=1e-5,
100+
atol=1e-4,
100101
)
101102

102103
matcher = main.ObjectiveUpdateTemplateMatchingPeeler.from_config(
@@ -124,11 +125,12 @@ def test_tiny(tmp_path):
124125
assert np.isclose(
125126
torch.square(res["residual"]).mean(),
126127
0.0,
128+
atol=1e-4,
127129
)
128130
assert np.isclose(
129131
torch.square(res["conv"]).mean(),
130132
0.0,
131-
atol=1e-5,
133+
atol=1e-4,
132134
)
133135

134136

@@ -270,11 +272,12 @@ def test_tiny_up(tmp_path, up_factor=8):
270272
assert np.isclose(
271273
torch.square(res["residual"]).mean(),
272274
0.0,
275+
atol=1e-4,
273276
)
274277
assert np.isclose(
275278
torch.square(res["conv"]).mean(),
276279
0.0,
277-
atol=1e-5,
280+
atol=1e-4,
278281
)
279282

280283

@@ -410,6 +413,7 @@ def static_tester(tmp_path, up_factor=1):
410413
assert np.isclose(
411414
torch.square(res["residual"]).mean(),
412415
0.0,
416+
atol=1e-4,
413417
)
414418
assert np.isclose(
415419
torch.square(res["conv"]).mean(),

0 commit comments

Comments
 (0)