|
1 | | -import importlib |
| 1 | +import ast |
2 | 2 | import os.path |
| 3 | +import pathlib |
3 | 4 | from importlib.metadata import distribution |
4 | 5 | from site import getsitepackages |
5 | 6 |
|
6 | | -import numba |
7 | 7 | from numba import cuda |
8 | 8 |
|
9 | 9 | from . import cache, config |
|
38 | 38 | # Get the default fastmath flags for all njit functions |
39 | 39 | # and update the _STUMPY_DEFAULTS dictionary |
40 | 40 |
|
41 | | -if not numba.config.DISABLE_JIT: # pragma: no cover |
42 | | - njit_funcs = cache.get_njit_funcs() |
43 | | - for module_name, func_name in njit_funcs: |
44 | | - module = importlib.import_module(f".{module_name}", package="stumpy") |
45 | | - func = getattr(module, func_name) |
46 | | - key = module_name + "." + func_name # e.g., core._mass |
47 | | - key = "STUMPY_FASTMATH_" + key.upper() # e.g., STUMPY_FASTHMATH_CORE._MASS |
48 | | - config._STUMPY_DEFAULTS[key] = func.targetoptions["fastmath"] |
| 41 | + |
| 42 | +def _get_fastmath_value(module_name, func_name): # pragma: no cover |
| 43 | + fname = module_name + ".py" |
| 44 | + fname = pathlib.Path(__file__).parent / fname |
| 45 | + with open(fname, "r", encoding="utf-8") as f: |
| 46 | + src = f.read() |
| 47 | + tree = ast.parse(src) |
| 48 | + for node in ast.walk(tree): |
| 49 | + if isinstance(node, ast.FunctionDef) and node.name == func_name: |
| 50 | + for dec in node.decorator_list: |
| 51 | + for kw in dec.keywords: |
| 52 | + if kw.arg == "fastmath": |
| 53 | + fastmath_flag = ast.get_source_segment(src, kw.value) |
| 54 | + return eval(fastmath_flag) |
| 55 | + |
| 56 | + |
| 57 | +njit_funcs = cache.get_njit_funcs() |
| 58 | +for module_name, func_name in njit_funcs: |
| 59 | + key = module_name + "." + func_name # e.g., core._mass |
| 60 | + key = "STUMPY_FASTMATH_" + key.upper() # e.g., STUMPY_FASTHMATH_CORE._MASS |
| 61 | + config._STUMPY_DEFAULTS[key] = _get_fastmath_value(module_name, func_name) |
49 | 62 |
|
50 | 63 | if cuda.is_available(): |
51 | 64 | from .gpu_aamp import gpu_aamp # noqa: F401 |
|
72 | 85 | core._gpu_searchsorted_left = core._gpu_searchsorted_left_driver_not_found |
73 | 86 | core._gpu_searchsorted_right = core._gpu_searchsorted_right_driver_not_found |
74 | 87 |
|
75 | | - import ast |
76 | | - import pathlib |
77 | | - |
78 | 88 | # Fix GPU-STUMP Docs |
79 | 89 | gpu_stump.__doc__ = "" |
80 | 90 | filepath = pathlib.Path(__file__).parent / "gpu_stump.py" |
|
0 commit comments