Skip to content

Commit cc6546b

Browse files
authored
Fixed #1116 Avoid Importing Fastmath Functions (#1117)
1 parent 97da018 commit cc6546b

File tree

3 files changed

+40
-14
lines changed

3 files changed

+40
-14
lines changed

pyproject.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -112,3 +112,5 @@ tbb = ">=2019.5"
112112
tests = "./test.sh"
113113
coverage = "./test.sh coverage"
114114
docs = "cd docs && ./setup.sh"
115+
black = 'black --exclude=".*\.ipynb" --extend-exclude=".venv|.pixi" --diff ./'
116+
isort = 'isort --profile black --skip .venv --skip .pixi ./'

stumpy/__init__.py

Lines changed: 23 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,9 @@
1-
import importlib
1+
import ast
22
import os.path
3+
import pathlib
34
from importlib.metadata import distribution
45
from site import getsitepackages
56

6-
import numba
77
from numba import cuda
88

99
from . import cache, config
@@ -38,14 +38,27 @@
3838
# Get the default fastmath flags for all njit functions
3939
# and update the _STUMPY_DEFAULTS dictionary
4040

41-
if not numba.config.DISABLE_JIT: # pragma: no cover
42-
njit_funcs = cache.get_njit_funcs()
43-
for module_name, func_name in njit_funcs:
44-
module = importlib.import_module(f".{module_name}", package="stumpy")
45-
func = getattr(module, func_name)
46-
key = module_name + "." + func_name # e.g., core._mass
47-
key = "STUMPY_FASTMATH_" + key.upper() # e.g., STUMPY_FASTHMATH_CORE._MASS
48-
config._STUMPY_DEFAULTS[key] = func.targetoptions["fastmath"]
41+
42+
def _get_fastmath_value(module_name, func_name): # pragma: no cover
43+
fname = module_name + ".py"
44+
fname = pathlib.Path(__file__).parent / fname
45+
with open(fname, "r", encoding="utf-8") as f:
46+
src = f.read()
47+
tree = ast.parse(src)
48+
for node in ast.walk(tree):
49+
if isinstance(node, ast.FunctionDef) and node.name == func_name:
50+
for dec in node.decorator_list:
51+
for kw in dec.keywords:
52+
if kw.arg == "fastmath":
53+
fastmath_flag = ast.get_source_segment(src, kw.value)
54+
return eval(fastmath_flag)
55+
56+
57+
njit_funcs = cache.get_njit_funcs()
58+
for module_name, func_name in njit_funcs:
59+
key = module_name + "." + func_name # e.g., core._mass
60+
key = "STUMPY_FASTMATH_" + key.upper() # e.g., STUMPY_FASTHMATH_CORE._MASS
61+
config._STUMPY_DEFAULTS[key] = _get_fastmath_value(module_name, func_name)
4962

5063
if cuda.is_available():
5164
from .gpu_aamp import gpu_aamp # noqa: F401
@@ -72,9 +85,6 @@
7285
core._gpu_searchsorted_left = core._gpu_searchsorted_left_driver_not_found
7386
core._gpu_searchsorted_right = core._gpu_searchsorted_right_driver_not_found
7487

75-
import ast
76-
import pathlib
77-
7888
# Fix GPU-STUMP Docs
7989
gpu_stump.__doc__ = ""
8090
filepath = pathlib.Path(__file__).parent / "gpu_stump.py"

tests/test_fastmath.py

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,10 @@
1+
import importlib
2+
13
import numba
24
import numpy as np
5+
import pytest
36

4-
from stumpy import cache, fastmath
7+
from stumpy import _get_fastmath_value, cache, fastmath
58

69

710
def test_set():
@@ -50,3 +53,14 @@ def test_reset():
5053
assert np.isnan(fastmath._add_assoc(0.0, np.inf))
5154
else: # pragma: no cover
5255
assert fastmath._add_assoc(0.0, np.inf) == 0.0
56+
57+
58+
@pytest.mark.skipif(numba.config.DISABLE_JIT, reason="JIT Disabled")
59+
def test_get_fastmath_value(): # pragma: no cover
60+
njit_funcs = cache.get_njit_funcs()
61+
for module_name, func_name in njit_funcs:
62+
module = importlib.import_module(f".{module_name}", package="stumpy")
63+
func = getattr(module, func_name)
64+
ref = func.targetoptions["fastmath"]
65+
cmp = _get_fastmath_value(module_name, func_name)
66+
assert ref == cmp

0 commit comments

Comments
 (0)