Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

DebugTransform - new interface, fixes and doc update. #978

Merged
merged 8 commits into from
Aug 19, 2024
40 changes: 34 additions & 6 deletions thunder/dev_utils/debug_transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,12 +9,15 @@


def create_debug_boundsymbol(name: str, bsym: BoundSymbol, call_ctx: Callable):
debug_sym = Symbol(name, lambda *_: None, is_prim=True)
debug_bsym = debug_sym.bind(*bsym.args, output=None, _call_ctx={name: partial(call_ctx, bsym)}, **bsym.kwargs)
def bind_postprocess(debug_bsym):
debug_bsym._call_ctx = {name: partial(call_ctx, debug_bsym, bsym)}

debug_sym = Symbol(name, lambda *_: None, is_prim=True, _bind_postprocess=bind_postprocess)
debug_bsym = debug_sym.bind(*bsym.args, output=None, **bsym.kwargs)
return debug_bsym


class DebugTransform(thunder.core.transforms.Transform):
class _DebugTransform(thunder.core.transforms.Transform):
def __init__(
self,
*,
Expand All @@ -39,7 +42,7 @@ def transform_trace_post_optimization(self, trace: TraceCtx, **kwargs) -> TraceC

if self.pre_callback is not None:

def _pre_call_ctx(bsym, *args, **kwargs):
def _pre_call_ctx(pre_debug_bsym, bsym, *args, **kwargs):
out = self.pre_callback(bsym, *args, **kwargs)
thunder.core.utils.check_type(out, str)
pre_debug_bsym.header = out
Expand All @@ -52,7 +55,7 @@ def _pre_call_ctx(bsym, *args, **kwargs):

if self.post_callback is not None:

def _post_call_ctx(bsym, *args, **kwargs):
def _post_call_ctx(post_debug_bsym, bsym, *args, **kwargs):
out = self.post_callback(bsym, *args, **kwargs)
thunder.core.utils.check_type(out, str)
post_debug_bsym.header = out
Expand All @@ -61,10 +64,35 @@ def _post_call_ctx(bsym, *args, **kwargs):
post_debug_bsym = create_debug_boundsymbol(post_debug_name, bsym, _post_call_ctx)
new_bsyms.append(post_debug_bsym)

debug_counter += 1

debug_trace.bound_symbols = new_bsyms
elapsed_time_ns = time.perf_counter_ns() - start_time_ns

debug_trace.set_provenance(TraceProvenance(f"Debug trace (took {elapsed_time_ns * 1e-6:.2f} milliseconds)"))

debug_counter += 1
return debug_trace


def debug_execution_trace(cfn, pre_callback: Callable | None = None, post_callback: Callable | None = None):
"""
kshitij12345 marked this conversation as resolved.
Show resolved Hide resolved
Adds a debugging transform to the trace allowing pre and post execution callbacks.

The function inserts debug symbols in the computation traces to call the callbacks before and/or after each symbol
in the trace. These callbacks can be used to inspect or log information about the execution of the computation.

Args:
cfn: :func:`thunder.jit` function to debug.
pre_callback: An optional callable that is executed before each bound symbol is processed.
It should have the signature ``(BoundSymbol, *args, **kwargs)`` and return a string. If :obj:`None`, no
pre-execution callback is used.
post_callback: An optional callable that is executed after each bound symbol is processed.
kshitij12345 marked this conversation as resolved.
Show resolved Hide resolved
It should have the signature ``(BoundSymbol, *args, **kwargs)`` and return a string. If :obj:`None`, no
post-execution callback is used.
"""
if pre_callback is None and post_callback is None:
raise RuntimeError(
"debug_execution_trace: Both `pre_callback` and `post_callback` were None, expected atleast one of them to not be None."
)
_debug_transform = _DebugTransform(pre_callback=pre_callback, post_callback=post_callback)
return thunder.core.transforms.add_transform(cfn, transform=_debug_transform)
41 changes: 41 additions & 0 deletions thunder/tests/test_transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -221,3 +221,44 @@ def test_nvfuser_cse():
assert prologue_proxy.device == thunder.core.devices.to_device(t.device)
assert comp_proxy.dtype == thunder.core.dtypes.to_dtype(t.dtype)
assert prologue_proxy.dtype == thunder.core.dtypes.to_dtype(t.dtype)


def test_debug_transform():
crcrpar marked this conversation as resolved.
Show resolved Hide resolved
from thunder.dev_utils.debug_transform import debug_execution_trace

# Only use the primitive operations in `fn` so that
# we can count them easily.
N_PRIMITIVE_OPS = 3

def fn(x, y):
return (x + y * y) / x

def pre_callback(bsym, *args, **kwargs):
return f"Pre - {bsym.sym.name}"

def post_callback(bsym, *args, **kwargs):
return f"Post - {bsym.sym.name}"

jfn = debug_execution_trace(thunder.jit(fn), pre_callback=pre_callback, post_callback=post_callback)
x = torch.randn(3, 3)
y = torch.randn(3, 3)
jfn(x, y)

fwd_exec_trace = thunder.last_traces(jfn)[-1]

debug_syms = set()
for bsym in fwd_exec_trace.bound_symbols:
if bsym.sym.name.startswith("debug"):
debug_syms.add(bsym)

n_expected_debug_syms = 2 * N_PRIMITIVE_OPS # Multiply by 2 as we have both pre and post callbacks
assert len(debug_syms) == n_expected_debug_syms

# As `debug_syms` have name of the form `debug_{pre|post}_{sym_name}_{debug_count}`,
# we expect to see debug_sym with `N_PRIMITIVE_OPS` at `{debug_count}` part of the name.
assert any(map(lambda bsym: bsym.sym.name.endswith(f"{str(N_PRIMITIVE_OPS)}"), debug_syms))

# Verify that we have correctly set the header for all debug_syms.
debug_headers = {sym.header for sym in debug_syms}
expected_headers = {"Pre - true_divide", "Pre - add", "Pre - mul", "Post - true_divide", "Post - add", "Post - mul"}
assert debug_headers == expected_headers
Loading