Skip to content

Commit 620dfdf

Browse files
committed
Add second attempt for assertion
1 parent 2774302 commit 620dfdf

File tree

1 file changed

+63
-15
lines changed

1 file changed

+63
-15
lines changed

tests/test_precision.py

Lines changed: 63 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import functools
2+
import importlib
23
from unittest.mock import patch
34

45
import naive
@@ -8,7 +9,7 @@
89
from numba import cuda
910

1011
import stumpy
11-
from stumpy import config, core
12+
from stumpy import cache, config, core
1213

1314
try:
1415
from numba.errors import NumbaPerformanceWarning
@@ -146,20 +147,67 @@ def test_snippets():
146147
cmp_regimes,
147148
) = stumpy.snippets(T, m, k, s=s, mpdist_T_subseq_isconstant=isconstant_custom_func)
148149

149-
npt.assert_almost_equal(
150-
ref_snippets, cmp_snippets, decimal=config.STUMPY_TEST_PRECISION
151-
)
152-
npt.assert_almost_equal(
153-
ref_indices, cmp_indices, decimal=config.STUMPY_TEST_PRECISION
154-
)
155-
npt.assert_almost_equal(
156-
ref_profiles, cmp_profiles, decimal=config.STUMPY_TEST_PRECISION
157-
)
158-
npt.assert_almost_equal(
159-
ref_fractions, cmp_fractions, decimal=config.STUMPY_TEST_PRECISION
160-
)
161-
npt.assert_almost_equal(ref_areas, cmp_areas, decimal=config.STUMPY_TEST_PRECISION)
162-
npt.assert_almost_equal(ref_regimes, cmp_regimes)
150+
# Revise fastmath flag, recompile, and re-calculate snippets,
151+
# and then revert the changes
152+
config.STUMPY_FASTMATH_FLAGS = {"nsz", "arcp", "contract", "afn"}
153+
core._calculate_squared_distance.targetoptions["fastmath"] = config.STUMPY_FASTMATH_FLAGS
154+
njit_funcs = cache.get_njit_funcs()
155+
for module_name, func_name in njit_funcs:
156+
code = f"from stumpy.{module_name} import {func_name}; {func_name}.recompile()"
157+
exec(code)
158+
# module = importlib.import_module(f".{module_name}", package="stumpy")
159+
# func = getattr(module, func_name)
160+
# func.recompile()
161+
162+
(
163+
cmp_snippets_NOreassoc,
164+
cmp_indices_NOreassoc,
165+
cmp_profiles_NOreassoc,
166+
cmp_fractions_NOreassoc,
167+
cmp_areas_NOreassoc,
168+
cmp_regimes_NOreassoc,
169+
) = stumpy.snippets(T, m, k, s=s, mpdist_T_subseq_isconstant=isconstant_custom_func)
170+
171+
config._reset("STUMPY_FASTMATH_FLAGS")
172+
for module_name, func_name in njit_funcs:
173+
module = importlib.import_module(f".{module_name}", package="stumpy")
174+
func = getattr(module, func_name)
175+
func.recompile()
176+
177+
if np.allclose(ref_snippets, cmp_snippets):
178+
npt.assert_almost_equal(
179+
ref_snippets, cmp_snippets, decimal=config.STUMPY_TEST_PRECISION
180+
)
181+
npt.assert_almost_equal(
182+
ref_indices, cmp_indices, decimal=config.STUMPY_TEST_PRECISION
183+
)
184+
npt.assert_almost_equal(
185+
ref_profiles, cmp_profiles, decimal=config.STUMPY_TEST_PRECISION
186+
)
187+
npt.assert_almost_equal(
188+
ref_fractions, cmp_fractions, decimal=config.STUMPY_TEST_PRECISION
189+
)
190+
npt.assert_almost_equal(
191+
ref_areas, cmp_areas, decimal=config.STUMPY_TEST_PRECISION
192+
)
193+
npt.assert_almost_equal(ref_regimes, cmp_regimes)
194+
else:
195+
npt.assert_almost_equal(
196+
ref_snippets, cmp_snippets_NOreassoc, decimal=config.STUMPY_TEST_PRECISION
197+
)
198+
npt.assert_almost_equal(
199+
ref_indices, cmp_indices_NOreassoc, decimal=config.STUMPY_TEST_PRECISION
200+
)
201+
npt.assert_almost_equal(
202+
ref_profiles, cmp_profiles_NOreassoc, decimal=config.STUMPY_TEST_PRECISION
203+
)
204+
npt.assert_almost_equal(
205+
ref_fractions, cmp_fractions_NOreassoc, decimal=config.STUMPY_TEST_PRECISION
206+
)
207+
npt.assert_almost_equal(
208+
ref_areas, cmp_areas_NOreassoc, decimal=config.STUMPY_TEST_PRECISION
209+
)
210+
npt.assert_almost_equal(ref_regimes, cmp_regimes_NOreassoc)
163211

164212

165213
@pytest.mark.filterwarnings("ignore", category=NumbaPerformanceWarning)

0 commit comments

Comments
 (0)