Skip to content

Commit

Permalink
Speed up user filename identification.
Browse files Browse the repository at this point in the history
* add caching
* use regexs to match a set of strings, rather than repeated matches.

PiperOrigin-RevId: 606682169
  • Loading branch information
hawkinsp authored and jax authors committed Feb 13, 2024
1 parent 031fdac commit cfcae37
Showing 1 changed file with 21 additions and 3 deletions.
24 changes: 21 additions & 3 deletions jax/_src/source_info_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import functools
import itertools
import os.path
import re
import sys
import sysconfig
import threading
Expand Down Expand Up @@ -54,15 +55,32 @@ class Frame(NamedTuple):
os.path.dirname(sysconfig.__file__)
]

@functools.cache
def _exclude_path_regex() -> re.Pattern[str]:
# The regex below would not handle an empty set of exclusions correctly.
assert len(_exclude_paths) > 0
return re.compile('|'.join(f'^{re.escape(path)}' for path in _exclude_paths))


def register_exclusion(path: str):
_exclude_paths.append(path)
_exclude_path_regex.cache_clear()
is_user_filename.cache_clear()


# Explicit inclusions take priority over exclude paths.
_include_paths: list[str] = []

@functools.cache
def _include_path_regex() -> re.Pattern[str]:
patterns = [f'^{re.escape(path)}' for path in _include_paths]
patterns.append('_test.py$')
return re.compile('|'.join(patterns))

def register_inclusion(path: str):
_include_paths.append(path)
_include_path_regex.cache_clear()
is_user_filename.cache_clear()


class Scope(NamedTuple):
Expand Down Expand Up @@ -138,11 +156,11 @@ def replace(self, *, traceback: Traceback | None = None,
def new_source_info() -> SourceInfo:
return SourceInfo(None, NameStack())

@functools.cache
def is_user_filename(filename: str) -> bool:
"""Heuristic that guesses the identity of the user's code in a stack trace."""
return (filename.endswith("_test.py") or
not any(filename.startswith(p) for p in _exclude_paths) or
any(filename.startswith(p) for p in _include_paths))
return (_include_path_regex().search(filename) is not None
or _exclude_path_regex().search(filename) is None)

if sys.version_info >= (3, 11):
def raw_frame_to_frame(code: types.CodeType, lasti: int) -> Frame:
Expand Down

0 comments on commit cfcae37

Please sign in to comment.