From ce89694690dac8109b10eb529186a254d62c233f Mon Sep 17 00:00:00 2001 From: Sean Law Date: Tue, 28 Jan 2025 11:52:25 -0500 Subject: [PATCH 01/19] Added numba cache dir for pytest --- stumpy/cache.py | 37 ++++++++++++++++++++++++++----------- tests/test_cache.py | 7 ++++++- 2 files changed, 32 insertions(+), 12 deletions(-) diff --git a/stumpy/cache.py b/stumpy/cache.py index fbaf35230..4b12312cd 100644 --- a/stumpy/cache.py +++ b/stumpy/cache.py @@ -5,6 +5,7 @@ import ast import importlib import inspect +import os import pathlib import site import warnings @@ -102,48 +103,57 @@ def _enable(): raise -def _clear(): +def _clear(cache_dir=None): """ Clear numba cache Parameters ---------- - None + cache_dir : str + The path to the numba cache directory Returns ------- None """ - site_pkg_dir = site.getsitepackages()[0] - numba_cache_dir = site_pkg_dir + "/stumpy/__pycache__" + if cache_dir is not None: # pragma: no cover + numba_cache_dir = str(cache_dir) + elif "PYTEST_CURRENT_TEST" in os.environ: + numba_cache_dir = "stumpy/__pycache__" + else: # pragma: no cover + site_pkg_dir = site.getsitepackages()[0] + numba_cache_dir = site_pkg_dir + "/stumpy/__pycache__" + [f.unlink() for f in pathlib.Path(numba_cache_dir).glob("*nb*") if f.is_file()] -def clear(): +def clear(cache_dir=None): """ Clear numba cache directory Parameters ---------- - None + cache_dir : str + The path to the numba cache directory Returns ------- None """ warnings.warn(CACHE_WARNING) - _clear() + _clear(cache_dir) return -def _get_cache(): +def _get_cache(cache_dir=None): """ Retrieve a list of cached numba functions Parameters ---------- - None + cache_dir : str + The path to the numba cache directory Returns ------- @@ -151,8 +161,13 @@ def _get_cache(): A list of cached numba functions """ warnings.warn(CACHE_WARNING) - site_pkg_dir = site.getsitepackages()[0] - numba_cache_dir = site_pkg_dir + "/stumpy/__pycache__" + if cache_dir is not None: # pragma: no cover + numba_cache_dir = str(cache_dir) + if "PYTEST_CURRENT_TEST" in os.environ: + numba_cache_dir = "stumpy/__pycache__" + else: # pragma: no cover + site_pkg_dir = site.getsitepackages()[0] + numba_cache_dir = site_pkg_dir + "/stumpy/__pycache__" return [f.name for f in pathlib.Path(numba_cache_dir).glob("*nb*") if f.is_file()] diff --git a/tests/test_cache.py b/tests/test_cache.py index 4b8af788d..80d085aa2 100644 --- a/tests/test_cache.py +++ b/tests/test_cache.py @@ -9,11 +9,13 @@ def test_cache_get_njit_funcs(): def test_cache_save_after_clear(): + cache.clear() + cache.save() + T = np.random.rand(10) m = 3 stump(T, m) - cache.save() ref_cache = cache._get_cache() cache.clear() @@ -21,7 +23,10 @@ def test_cache_save_after_clear(): assert len(cache._get_cache()) == 0 cache.save() + stump(T, m) comp_cache = cache._get_cache() # testing cache._save() after cache._clear() assert sorted(ref_cache) == sorted(comp_cache) + + cache.clear() From 4a5b0cf88e69de04698a0571318895fa11ddd69a Mon Sep 17 00:00:00 2001 From: Sean Law Date: Tue, 28 Jan 2025 12:27:00 -0500 Subject: [PATCH 02/19] Added cache._clear() to cache._save() --- tests/test_cache.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/test_cache.py b/tests/test_cache.py index 80d085aa2..36c046fd1 100644 --- a/tests/test_cache.py +++ b/tests/test_cache.py @@ -9,7 +9,6 @@ def test_cache_get_njit_funcs(): def test_cache_save_after_clear(): - cache.clear() cache.save() T = np.random.rand(10) From ffa468c3930875e669bf6da4033d2a90a1152652 Mon Sep 17 00:00:00 2001 From: Sean Law Date: Tue, 28 Jan 2025 21:36:36 -0500 Subject: [PATCH 03/19] Removed recompile from fastmath --- stumpy/cache.py | 1 + stumpy/fastmath.py | 7 +++++-- tests/test_fastmath.py | 8 +++++++- 3 files changed, 13 insertions(+), 3 deletions(-) diff --git a/stumpy/cache.py b/stumpy/cache.py index 4b12312cd..7e3a049e0 100644 --- a/stumpy/cache.py +++ b/stumpy/cache.py @@ -218,6 +218,7 @@ def _save(): None """ _enable() + _clear() _recompile() return diff --git a/stumpy/fastmath.py b/stumpy/fastmath.py index 5aac4ee0a..05402c921 100644 --- a/stumpy/fastmath.py +++ b/stumpy/fastmath.py @@ -1,4 +1,5 @@ import importlib +import warnings import numba from numba import njit @@ -55,11 +56,13 @@ def _set(module_name, func_name, flag): func = getattr(module, func_name) try: func.targetoptions["fastmath"] = flag - func.recompile() + msg = "One or more fastmath flags have been set/reset. " + msg += "Please call `cache._recompile()` to ensure that all njit functions " + msg += "are properly recompiled." + warnings.warn(msg) except AttributeError as e: if numba.config.DISABLE_JIT and ( str(e) == "'function' object has no attribute 'targetoptions'" - or str(e) == "'function' object has no attribute 'recompile'" ): pass else: # pragma: no cover diff --git a/tests/test_fastmath.py b/tests/test_fastmath.py index d2a993069..a16bb6898 100644 --- a/tests/test_fastmath.py +++ b/tests/test_fastmath.py @@ -1,7 +1,7 @@ import numba import numpy as np -from stumpy import fastmath +from stumpy import cache, fastmath def test_set(): @@ -11,11 +11,13 @@ def test_set(): # case1: flag=False fastmath._set("fastmath", "_add_assoc", flag=False) + cache._recompile() out = fastmath._add_assoc(0, np.inf) assert np.isnan(out) # case2: flag={'reassoc', 'nsz'} fastmath._set("fastmath", "_add_assoc", flag={"reassoc", "nsz"}) + cache._recompile() out = fastmath._add_assoc(0, np.inf) if numba.config.DISABLE_JIT: assert np.isnan(out) @@ -24,11 +26,13 @@ def test_set(): # case3: flag={'reassoc'} fastmath._set("fastmath", "_add_assoc", flag={"reassoc"}) + cache._recompile() out = fastmath._add_assoc(0, np.inf) assert np.isnan(out) # case4: flag={'nsz'} fastmath._set("fastmath", "_add_assoc", flag={"nsz"}) + cache._recompile() out = fastmath._add_assoc(0, np.inf) assert np.isnan(out) @@ -39,7 +43,9 @@ def test_reset(): # https://numba.pydata.org/numba-doc/dev/user/performance-tips.html#fastmath # and then reset it to the default value, i.e. `True` fastmath._set("fastmath", "_add_assoc", False) + cache._recompile() fastmath._reset("fastmath", "_add_assoc") + cache._recompile() if numba.config.DISABLE_JIT: assert np.isnan(fastmath._add_assoc(0.0, np.inf)) else: # pragma: no cover From ef30b96b56c0365854856a0538d184c26feb3bdd Mon Sep 17 00:00:00 2001 From: Sean Law Date: Tue, 28 Jan 2025 22:39:11 -0500 Subject: [PATCH 04/19] Added ref cache length check --- tests/test_cache.py | 13 ++++++++----- 1 file changed, 8 insertions(+), 5 deletions(-) diff --git a/tests/test_cache.py b/tests/test_cache.py index 36c046fd1..3750b47bc 100644 --- a/tests/test_cache.py +++ b/tests/test_cache.py @@ -1,3 +1,4 @@ +import numba import numpy as np from stumpy import cache, stump @@ -9,23 +10,25 @@ def test_cache_get_njit_funcs(): def test_cache_save_after_clear(): - cache.save() - T = np.random.rand(10) m = 3 - stump(T, m) + cache.save() + stump(T, m) ref_cache = cache._get_cache() + if numba.config.DISABLE_JIT: + assert len(ref_cache) == 0 + else: + assert len(ref_cache) > 0 + cache.clear() - # testing cache._clear() assert len(cache._get_cache()) == 0 cache.save() stump(T, m) comp_cache = cache._get_cache() - # testing cache._save() after cache._clear() assert sorted(ref_cache) == sorted(comp_cache) cache.clear() From 0c022a76c75653c42113b8950399da94754b5a4e Mon Sep 17 00:00:00 2001 From: Sean Law Date: Tue, 28 Jan 2025 23:04:45 -0500 Subject: [PATCH 05/19] Improved coverage --- tests/test_cache.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_cache.py b/tests/test_cache.py index 3750b47bc..8cf7bca21 100644 --- a/tests/test_cache.py +++ b/tests/test_cache.py @@ -19,7 +19,7 @@ def test_cache_save_after_clear(): if numba.config.DISABLE_JIT: assert len(ref_cache) == 0 - else: + else: # pragma: no cover assert len(ref_cache) > 0 cache.clear() From c3aaf01fc4f8f7aea39ea1544012c0ab44525231 Mon Sep 17 00:00:00 2001 From: Sean Law Date: Wed, 29 Jan 2025 06:58:34 -0500 Subject: [PATCH 06/19] Fixed black formatting --- tests/test_floss.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_floss.py b/tests/test_floss.py index 3a11a5320..74735d1af 100644 --- a/tests/test_floss.py +++ b/tests/test_floss.py @@ -92,7 +92,7 @@ def naive_rea(cac, n_regimes, L, excl_factor): return np.array(loc_regimes, dtype=np.int64) -test_data = [(np.random.randint(0, 50, size=50, dtype=np.int64))] +test_data = [np.random.randint(0, 50, size=50, dtype=np.int64)] substitution_locations = [(slice(0, 0), 0, -1, slice(1, 3), [0, 3])] substitution_values = [np.nan, np.inf] From c54d63e24e365ef0013f781931ae0b74abf42d1f Mon Sep 17 00:00:00 2001 From: Sean Law Date: Wed, 29 Jan 2025 07:41:10 -0500 Subject: [PATCH 07/19] Fixed if to elif --- stumpy/cache.py | 18 ++++++++++-------- 1 file changed, 10 insertions(+), 8 deletions(-) diff --git a/stumpy/cache.py b/stumpy/cache.py index 7e3a049e0..95d2dc49f 100644 --- a/stumpy/cache.py +++ b/stumpy/cache.py @@ -133,7 +133,7 @@ def clear(cache_dir=None): Parameters ---------- - cache_dir : str + cache_dir : str, default None The path to the numba cache directory Returns @@ -163,7 +163,7 @@ def _get_cache(cache_dir=None): warnings.warn(CACHE_WARNING) if cache_dir is not None: # pragma: no cover numba_cache_dir = str(cache_dir) - if "PYTEST_CURRENT_TEST" in os.environ: + elif "PYTEST_CURRENT_TEST" in os.environ: numba_cache_dir = "stumpy/__pycache__" else: # pragma: no cover site_pkg_dir = site.getsitepackages()[0] @@ -205,33 +205,35 @@ def _recompile(): return -def _save(): +def _save(cache_dir): """ Save all njit functions Parameters ---------- - None + cache_dir : str + The path to the numba cache directory Returns ------- None """ _enable() - _clear() + _clear(cache_dir) _recompile() return -def save(): +def save(cache_dir=None): """ Save/overwrite all the cache data files of all-so-far compiled njit functions. Parameters ---------- - None + cache_dir : str, default None + The path to the numba cache directory Returns ------- @@ -243,6 +245,6 @@ def save(): else: # pragma: no cover warnings.warn(CACHE_WARNING) - _save() + _save(cache_dir) return From f8370fd73cdf43d8e8af81a80d573377ebb325ef Mon Sep 17 00:00:00 2001 From: Sean Law Date: Wed, 29 Jan 2025 08:17:15 -0500 Subject: [PATCH 08/19] Made get_cache more verbose --- stumpy/cache.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/stumpy/cache.py b/stumpy/cache.py index 95d2dc49f..47b2594d6 100644 --- a/stumpy/cache.py +++ b/stumpy/cache.py @@ -168,7 +168,12 @@ def _get_cache(cache_dir=None): else: # pragma: no cover site_pkg_dir = site.getsitepackages()[0] numba_cache_dir = site_pkg_dir + "/stumpy/__pycache__" - return [f.name for f in pathlib.Path(numba_cache_dir).glob("*nb*") if f.is_file()] + + return [ + f"{numba_cache_dir}/{f.name}" + for f in pathlib.Path(numba_cache_dir).glob("*nb*") + if f.is_file() + ] def _recompile(): From 4b647897f389a7e40417611fa7eff44806700382 Mon Sep 17 00:00:00 2001 From: Sean Law Date: Wed, 29 Jan 2025 08:25:37 -0500 Subject: [PATCH 09/19] Added warning --- stumpy/cache.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/stumpy/cache.py b/stumpy/cache.py index 47b2594d6..05544cace 100644 --- a/stumpy/cache.py +++ b/stumpy/cache.py @@ -250,6 +250,11 @@ def save(cache_dir=None): else: # pragma: no cover warnings.warn(CACHE_WARNING) + if numba.config.CACHE_DIR != '': + msg = "Found user specified `NUMBA_CACHE_DIR`/`numba.config.CACHE_DIR`. " + msg =+ "The `stumpy` cache files may not be saved/cleared correctly!" + warnings.warn(msg) + _save(cache_dir) return From f334516e08fb50095b745150c6092e9cc34afe8f Mon Sep 17 00:00:00 2001 From: Sean Law Date: Wed, 29 Jan 2025 10:46:51 -0500 Subject: [PATCH 10/19] Fixed typo --- stumpy/cache.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/stumpy/cache.py b/stumpy/cache.py index 05544cace..0726a74c1 100644 --- a/stumpy/cache.py +++ b/stumpy/cache.py @@ -250,9 +250,9 @@ def save(cache_dir=None): else: # pragma: no cover warnings.warn(CACHE_WARNING) - if numba.config.CACHE_DIR != '': + if numba.config.CACHE_DIR != "": msg = "Found user specified `NUMBA_CACHE_DIR`/`numba.config.CACHE_DIR`. " - msg =+ "The `stumpy` cache files may not be saved/cleared correctly!" + msg += "The `stumpy` cache files may not be saved/cleared correctly!" warnings.warn(msg) _save(cache_dir) From a505b87a101f15466b31ae0c6f562e5099cf969f Mon Sep 17 00:00:00 2001 From: Sean Law Date: Wed, 29 Jan 2025 22:00:34 -0500 Subject: [PATCH 11/19] Refactored code --- stumpy/cache.py | 7 +------ stumpy/fastmath.py | 1 + tests/test_cache.py | 22 +++++++++------------- 3 files changed, 11 insertions(+), 19 deletions(-) diff --git a/stumpy/cache.py b/stumpy/cache.py index 0726a74c1..8bbec4164 100644 --- a/stumpy/cache.py +++ b/stumpy/cache.py @@ -5,7 +5,6 @@ import ast import importlib import inspect -import os import pathlib import site import warnings @@ -118,8 +117,6 @@ def _clear(cache_dir=None): """ if cache_dir is not None: # pragma: no cover numba_cache_dir = str(cache_dir) - elif "PYTEST_CURRENT_TEST" in os.environ: - numba_cache_dir = "stumpy/__pycache__" else: # pragma: no cover site_pkg_dir = site.getsitepackages()[0] numba_cache_dir = site_pkg_dir + "/stumpy/__pycache__" @@ -163,8 +160,6 @@ def _get_cache(cache_dir=None): warnings.warn(CACHE_WARNING) if cache_dir is not None: # pragma: no cover numba_cache_dir = str(cache_dir) - elif "PYTEST_CURRENT_TEST" in os.environ: - numba_cache_dir = "stumpy/__pycache__" else: # pragma: no cover site_pkg_dir = site.getsitepackages()[0] numba_cache_dir = site_pkg_dir + "/stumpy/__pycache__" @@ -250,7 +245,7 @@ def save(cache_dir=None): else: # pragma: no cover warnings.warn(CACHE_WARNING) - if numba.config.CACHE_DIR != "": + if numba.config.CACHE_DIR != "": # pragma: no cover msg = "Found user specified `NUMBA_CACHE_DIR`/`numba.config.CACHE_DIR`. " msg += "The `stumpy` cache files may not be saved/cleared correctly!" warnings.warn(msg) diff --git a/stumpy/fastmath.py b/stumpy/fastmath.py index 05402c921..de99694c1 100644 --- a/stumpy/fastmath.py +++ b/stumpy/fastmath.py @@ -64,6 +64,7 @@ def _set(module_name, func_name, flag): if numba.config.DISABLE_JIT and ( str(e) == "'function' object has no attribute 'targetoptions'" ): + warnings.warn("Fastmath flags could not be set as Numba JIT is disabled") pass else: # pragma: no cover raise diff --git a/tests/test_cache.py b/tests/test_cache.py index 8cf7bca21..72f44a022 100644 --- a/tests/test_cache.py +++ b/tests/test_cache.py @@ -1,4 +1,3 @@ -import numba import numpy as np from stumpy import cache, stump @@ -13,22 +12,19 @@ def test_cache_save_after_clear(): T = np.random.rand(10) m = 3 - cache.save() - stump(T, m) - ref_cache = cache._get_cache() + cache_dir = "stumpy/__pycache__" - if numba.config.DISABLE_JIT: - assert len(ref_cache) == 0 - else: # pragma: no cover - assert len(ref_cache) > 0 + cache.save(cache_dir) + stump(T, m) + ref_cache = cache._get_cache(cache_dir) - cache.clear() - assert len(cache._get_cache()) == 0 + cache.clear(cache_dir) + assert len(cache._get_cache(cache_dir)) == 0 - cache.save() + cache.save(cache_dir) stump(T, m) - comp_cache = cache._get_cache() + comp_cache = cache._get_cache(cache_dir) assert sorted(ref_cache) == sorted(comp_cache) - cache.clear() + cache.clear(cache_dir) From e7d9b2b8b430e34a4d7cb7bd59e5d16b19270d2e Mon Sep 17 00:00:00 2001 From: Sean Law Date: Thu, 30 Jan 2025 07:30:04 -0500 Subject: [PATCH 12/19] Cleaned up from comments --- stumpy/cache.py | 21 +++++++++++---------- tests/test_cache.py | 12 ++++++++++-- 2 files changed, 21 insertions(+), 12 deletions(-) diff --git a/stumpy/cache.py b/stumpy/cache.py index 8bbec4164..92d88cd77 100644 --- a/stumpy/cache.py +++ b/stumpy/cache.py @@ -205,39 +205,40 @@ def _recompile(): return -def _save(cache_dir): +def _save(): """ Save all njit functions Parameters ---------- - cache_dir : str - The path to the numba cache directory + None Returns ------- None """ _enable() - _clear(cache_dir) _recompile() return -def save(cache_dir=None): +def save(): """ - Save/overwrite all the cache data files of - all-so-far compiled njit functions. + Save/overwrite all of the cached njit functions. Parameters ---------- - cache_dir : str, default None - The path to the numba cache directory + None Returns ------- None + + Notes + ----- + The cache is never cleared before saving/overwriting and may be explicitly + cleared by calling `cache.clear()` before saving. """ if numba.config.DISABLE_JIT: msg = "Could not save/cache function because NUMBA JIT is disabled" @@ -250,6 +251,6 @@ def save(cache_dir=None): msg += "The `stumpy` cache files may not be saved/cleared correctly!" warnings.warn(msg) - _save(cache_dir) + _save() return diff --git a/tests/test_cache.py b/tests/test_cache.py index 72f44a022..ef53eb9dc 100644 --- a/tests/test_cache.py +++ b/tests/test_cache.py @@ -1,3 +1,4 @@ +import numba import numpy as np from stumpy import cache, stump @@ -14,14 +15,21 @@ def test_cache_save_after_clear(): cache_dir = "stumpy/__pycache__" - cache.save(cache_dir) + cache.clear(cache_dir) + cache.save() + stump(T, m) ref_cache = cache._get_cache(cache_dir) + if numba.config.DISABLE_JIT: + assert len(ref_cache) == 0 + else: # pragma: no cover + assert len(ref_cache) > 0 + cache.clear(cache_dir) assert len(cache._get_cache(cache_dir)) == 0 + cache.save() - cache.save(cache_dir) stump(T, m) comp_cache = cache._get_cache(cache_dir) From 723c63aff44f1d5e30189e29648d54e033fec295 Mon Sep 17 00:00:00 2001 From: Sean Law Date: Thu, 30 Jan 2025 09:24:03 -0500 Subject: [PATCH 13/19] Added warning to clear before save --- stumpy/cache.py | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/stumpy/cache.py b/stumpy/cache.py index 92d88cd77..4dac4b0d4 100644 --- a/stumpy/cache.py +++ b/stumpy/cache.py @@ -15,6 +15,7 @@ CACHE_WARNING += "and should never be used or depended upon as it is not supported! " CACHE_WARNING += "All caching capabilities are not tested and may be removed/changed " CACHE_WARNING += "without prior notice. Please proceed with caution!" +CACHE_CLEARED = True def get_njit_funcs(): @@ -115,6 +116,8 @@ def _clear(cache_dir=None): ------- None """ + global CACHE_CLEARED + if cache_dir is not None: # pragma: no cover numba_cache_dir = str(cache_dir) else: # pragma: no cover @@ -123,6 +126,8 @@ def _clear(cache_dir=None): [f.unlink() for f in pathlib.Path(numba_cache_dir).glob("*nb*") if f.is_file()] + CACHE_CLEARED = True + def clear(cache_dir=None): """ @@ -240,6 +245,8 @@ def save(): The cache is never cleared before saving/overwriting and may be explicitly cleared by calling `cache.clear()` before saving. """ + global CACHE_CLEARED + if numba.config.DISABLE_JIT: msg = "Could not save/cache function because NUMBA JIT is disabled" warnings.warn(msg) @@ -251,6 +258,11 @@ def save(): msg += "The `stumpy` cache files may not be saved/cleared correctly!" warnings.warn(msg) + if not CACHE_CLEARED: # pragma: no cover + msg = "The cached files are not cleared before saving/overwriting. " + msg = "You may need to call `cache.clear()` before calling `cache.save()`." + warnings.warn("msg") + _save() return From 27b2a2bb43e31440ac57e5062eaeabb2e93b6e2c Mon Sep 17 00:00:00 2001 From: Sean Law Date: Thu, 30 Jan 2025 09:30:59 -0500 Subject: [PATCH 14/19] Reset CACHE_CLEARED after cache._save() is called --- stumpy/cache.py | 14 +++++++++----- 1 file changed, 9 insertions(+), 5 deletions(-) diff --git a/stumpy/cache.py b/stumpy/cache.py index 4dac4b0d4..af8427c14 100644 --- a/stumpy/cache.py +++ b/stumpy/cache.py @@ -222,9 +222,18 @@ def _save(): ------- None """ + global CACHE_CLEARED + + if not CACHE_CLEARED: # pragma: no cover + msg = "Numba njit cached files are not cleared before saving/overwriting. " + msg = "You may need to call `cache.clear()` before calling `cache.save()`." + warnings.warn(msg) + _enable() _recompile() + CACHE_CLEARED = False + return @@ -258,11 +267,6 @@ def save(): msg += "The `stumpy` cache files may not be saved/cleared correctly!" warnings.warn(msg) - if not CACHE_CLEARED: # pragma: no cover - msg = "The cached files are not cleared before saving/overwriting. " - msg = "You may need to call `cache.clear()` before calling `cache.save()`." - warnings.warn("msg") - _save() return From b2dbebd1191e34e6777260799addcb6f2300dbd3 Mon Sep 17 00:00:00 2001 From: Sean Law Date: Thu, 30 Jan 2025 23:05:25 -0500 Subject: [PATCH 15/19] Cleaned up code --- stumpy/cache.py | 11 +++++------ tests/test_cache.py | 2 +- 2 files changed, 6 insertions(+), 7 deletions(-) diff --git a/stumpy/cache.py b/stumpy/cache.py index af8427c14..c1cb93367 100644 --- a/stumpy/cache.py +++ b/stumpy/cache.py @@ -109,7 +109,7 @@ def _clear(cache_dir=None): Parameters ---------- - cache_dir : str + cache_dir : str, default None The path to the numba cache directory Returns @@ -118,7 +118,7 @@ def _clear(cache_dir=None): """ global CACHE_CLEARED - if cache_dir is not None: # pragma: no cover + if cache_dir is not None: numba_cache_dir = str(cache_dir) else: # pragma: no cover site_pkg_dir = site.getsitepackages()[0] @@ -136,7 +136,8 @@ def clear(cache_dir=None): Parameters ---------- cache_dir : str, default None - The path to the numba cache directory + The path to the numba cache directory. When `cache_dir` is `None`, then this + defaults to `site-packages/stumpy/__pycache__`. Returns ------- @@ -163,7 +164,7 @@ def _get_cache(cache_dir=None): A list of cached numba functions """ warnings.warn(CACHE_WARNING) - if cache_dir is not None: # pragma: no cover + if cache_dir is not None: numba_cache_dir = str(cache_dir) else: # pragma: no cover site_pkg_dir = site.getsitepackages()[0] @@ -254,8 +255,6 @@ def save(): The cache is never cleared before saving/overwriting and may be explicitly cleared by calling `cache.clear()` before saving. """ - global CACHE_CLEARED - if numba.config.DISABLE_JIT: msg = "Could not save/cache function because NUMBA JIT is disabled" warnings.warn(msg) diff --git a/tests/test_cache.py b/tests/test_cache.py index ef53eb9dc..ef7e37c39 100644 --- a/tests/test_cache.py +++ b/tests/test_cache.py @@ -30,7 +30,7 @@ def test_cache_save_after_clear(): assert len(cache._get_cache(cache_dir)) == 0 cache.save() - stump(T, m) + # stump(T, m) comp_cache = cache._get_cache(cache_dir) assert sorted(ref_cache) == sorted(comp_cache) From 62b9deafc52787bbadf589ad95fc373b0d6a8eec Mon Sep 17 00:00:00 2001 From: Sean Law Date: Fri, 31 Jan 2025 21:43:20 -0500 Subject: [PATCH 16/19] Added detailed cache note --- stumpy/cache.py | 11 +++++++++-- tests/test_cache.py | 7 +++---- 2 files changed, 12 insertions(+), 6 deletions(-) diff --git a/stumpy/cache.py b/stumpy/cache.py index c1cb93367..9df0e5994 100644 --- a/stumpy/cache.py +++ b/stumpy/cache.py @@ -252,8 +252,15 @@ def save(): Notes ----- - The cache is never cleared before saving/overwriting and may be explicitly - cleared by calling `cache.clear()` before saving. + The cache is never cleared before saving/overwriting and may be explicitly cleared + by calling `cache.clear()` before saving. If `cache.save()` is called for the first + time (before any `njit` function is called) then only the `.nbi` files (i.e., the + "cache index") for all `njit` functions are saved. As each `njit` function (and + sub-functions) is called then their corresponding `.nbc` file (i.e., "object code") + is saved. Each `.nbc` file will only be saved after its `njit` function is called + once. However, subsequent calls to `cache.save()` (after clearing the cache via + `cache.clear()`) will automatically save BOTH the `.nbi` files as well as the `.nbc` + files as long as their `njit` function has been called at least once. """ if numba.config.DISABLE_JIT: msg = "Could not save/cache function because NUMBA JIT is disabled" diff --git a/tests/test_cache.py b/tests/test_cache.py index ef7e37c39..63770fc2a 100644 --- a/tests/test_cache.py +++ b/tests/test_cache.py @@ -16,9 +16,9 @@ def test_cache_save_after_clear(): cache_dir = "stumpy/__pycache__" cache.clear(cache_dir) - cache.save() + cache.save() # Saves nbi files only until njit funcs are called for the first time - stump(T, m) + stump(T, m) # Saves nbc files, subsequent saves will write both nbi and nbc files ref_cache = cache._get_cache(cache_dir) if numba.config.DISABLE_JIT: @@ -28,9 +28,8 @@ def test_cache_save_after_clear(): cache.clear(cache_dir) assert len(cache._get_cache(cache_dir)) == 0 - cache.save() + cache.save() # Save both nbi and nbc files without needing to call `stump` function - # stump(T, m) comp_cache = cache._get_cache(cache_dir) assert sorted(ref_cache) == sorted(comp_cache) From 707f4c13b4e35946c59fcc4d4efd35619376048e Mon Sep 17 00:00:00 2001 From: Sean Law Date: Fri, 31 Jan 2025 21:45:21 -0500 Subject: [PATCH 17/19] Fixed black formatting --- stumpy/cache.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/stumpy/cache.py b/stumpy/cache.py index 9df0e5994..5a0392e6b 100644 --- a/stumpy/cache.py +++ b/stumpy/cache.py @@ -260,7 +260,7 @@ def save(): is saved. Each `.nbc` file will only be saved after its `njit` function is called once. However, subsequent calls to `cache.save()` (after clearing the cache via `cache.clear()`) will automatically save BOTH the `.nbi` files as well as the `.nbc` - files as long as their `njit` function has been called at least once. + files as long as their `njit` function has been called at least once. """ if numba.config.DISABLE_JIT: msg = "Could not save/cache function because NUMBA JIT is disabled" From f5719c35d6025154e015612e26706d4221f59b7a Mon Sep 17 00:00:00 2001 From: Sean Law Date: Fri, 31 Jan 2025 23:05:50 -0500 Subject: [PATCH 18/19] Added example --- stumpy/cache.py | 22 ++++++++++++++++------ 1 file changed, 16 insertions(+), 6 deletions(-) diff --git a/stumpy/cache.py b/stumpy/cache.py index 5a0392e6b..6c3552079 100644 --- a/stumpy/cache.py +++ b/stumpy/cache.py @@ -253,14 +253,24 @@ def save(): Notes ----- The cache is never cleared before saving/overwriting and may be explicitly cleared - by calling `cache.clear()` before saving. If `cache.save()` is called for the first - time (before any `njit` function is called) then only the `.nbi` files (i.e., the - "cache index") for all `njit` functions are saved. As each `njit` function (and + by calling `cache.clear()` before saving. It is best practice to call `cache.save()` + only after calling all of your `njit` functions. If `cache.save()` is called for the + first time (before any `njit` function is called) then only the `.nbi` files (i.e., + the "cache index") for all `njit` functions are saved. As each `njit` function (and sub-functions) is called then their corresponding `.nbc` file (i.e., "object code") is saved. Each `.nbc` file will only be saved after its `njit` function is called - once. However, subsequent calls to `cache.save()` (after clearing the cache via - `cache.clear()`) will automatically save BOTH the `.nbi` files as well as the `.nbc` - files as long as their `njit` function has been called at least once. + at least once. However, subsequent calls to `cache.save()` (after clearing the cache + via `cache.clear()`) will automatically save BOTH the `.nbi` files as well as the + `.nbc` files as long as their `njit` function has been called at least once. + + Examples + -------- + >>> import stumpy + >>> from stumpy import cache + >>> import numpy as np + >>> cache.clear() + >>> mp = stumpy.stump(np.array([584., -11., 23., 79., 1001., 0., -19.]), m=3) + >>> cache.save() """ if numba.config.DISABLE_JIT: msg = "Could not save/cache function because NUMBA JIT is disabled" From f0bb30941ecdc95fc5bfb6b9d68b2f991164ced7 Mon Sep 17 00:00:00 2001 From: Sean Law Date: Sat, 1 Feb 2025 07:16:08 -0500 Subject: [PATCH 19/19] Updated test and added more comments --- tests/test_cache.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/tests/test_cache.py b/tests/test_cache.py index 63770fc2a..2127c8ed2 100644 --- a/tests/test_cache.py +++ b/tests/test_cache.py @@ -16,9 +16,9 @@ def test_cache_save_after_clear(): cache_dir = "stumpy/__pycache__" cache.clear(cache_dir) - cache.save() # Saves nbi files only until njit funcs are called for the first time + stump(T, m) + cache.save() # Enable and save both `.nbi` and `.nbc` cache files - stump(T, m) # Saves nbc files, subsequent saves will write both nbi and nbc files ref_cache = cache._get_cache(cache_dir) if numba.config.DISABLE_JIT: @@ -28,7 +28,9 @@ def test_cache_save_after_clear(): cache.clear(cache_dir) assert len(cache._get_cache(cache_dir)) == 0 - cache.save() # Save both nbi and nbc files without needing to call `stump` function + # Note that `stump(T, m)` has already been called once above and any subsequent + # calls to `cache.save()` will automatically save both `.nbi` and `.nbc` cache files + cache.save() # Save both `.nbi` and `.nbc` cache files comp_cache = cache._get_cache(cache_dir)