From cfcae3763c14a467b26e397fbc516624b9cbe94a Mon Sep 17 00:00:00 2001 From: Peter Hawkins Date: Tue, 13 Feb 2024 11:03:53 -0800 Subject: [PATCH] Speed up user filename identification. * add caching * use regexs to match a set of strings, rather than repeated matches. PiperOrigin-RevId: 606682169 --- jax/_src/source_info_util.py | 24 +++++++++++++++++++++--- 1 file changed, 21 insertions(+), 3 deletions(-) diff --git a/jax/_src/source_info_util.py b/jax/_src/source_info_util.py index f28e78a20a26..356e5c0d883f 100644 --- a/jax/_src/source_info_util.py +++ b/jax/_src/source_info_util.py @@ -20,6 +20,7 @@ import functools import itertools import os.path +import re import sys import sysconfig import threading @@ -54,15 +55,32 @@ class Frame(NamedTuple): os.path.dirname(sysconfig.__file__) ] +@functools.cache +def _exclude_path_regex() -> re.Pattern[str]: + # The regex below would not handle an empty set of exclusions correctly. + assert len(_exclude_paths) > 0 + return re.compile('|'.join(f'^{re.escape(path)}' for path in _exclude_paths)) + + def register_exclusion(path: str): _exclude_paths.append(path) + _exclude_path_regex.cache_clear() + is_user_filename.cache_clear() # Explicit inclusions take priority over exclude paths. _include_paths: list[str] = [] +@functools.cache +def _include_path_regex() -> re.Pattern[str]: + patterns = [f'^{re.escape(path)}' for path in _include_paths] + patterns.append('_test.py$') + return re.compile('|'.join(patterns)) + def register_inclusion(path: str): _include_paths.append(path) + _include_path_regex.cache_clear() + is_user_filename.cache_clear() class Scope(NamedTuple): @@ -138,11 +156,11 @@ def replace(self, *, traceback: Traceback | None = None, def new_source_info() -> SourceInfo: return SourceInfo(None, NameStack()) +@functools.cache def is_user_filename(filename: str) -> bool: """Heuristic that guesses the identity of the user's code in a stack trace.""" - return (filename.endswith("_test.py") or - not any(filename.startswith(p) for p in _exclude_paths) or - any(filename.startswith(p) for p in _include_paths)) + return (_include_path_regex().search(filename) is not None + or _exclude_path_regex().search(filename) is None) if sys.version_info >= (3, 11): def raw_frame_to_frame(code: types.CodeType, lasti: int) -> Frame: