From 71db6cd9d5fa06ede7fd63b52303ccc5c84afcef Mon Sep 17 00:00:00 2001 From: Masaki Kozuki Date: Thu, 3 Oct 2024 23:42:29 +0900 Subject: [PATCH] remove useless context Signed-off-by: Masaki Kozuki --- thunder/core/jit_ext.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/thunder/core/jit_ext.py b/thunder/core/jit_ext.py index 733602f497..f6da1bd4c0 100644 --- a/thunder/core/jit_ext.py +++ b/thunder/core/jit_ext.py @@ -848,8 +848,7 @@ def _generate_random_str_id() -> str: augmented_fwd_trace = TraceCtx() for bsym in fwd_bsyms: augmented_fwd_trace.add_bound_symbol(bsym) - with tracectx(augmented_fwd_trace): - augmented_fwd_trace.add_bound_symbol(prims.python_return.bind(output, saved_values, output=())) + augmented_fwd_trace.add_bound_symbol(prims.python_return.bind(output, saved_values, output=())) si = SigInfo(f"augmented_autograd_function_apply_{sym_id}") for a in bsym_of_custom_autograd_func.args: if isinstance(a, Proxy): @@ -881,8 +880,7 @@ def augmented_fwd_rule(*args): if bwd_result is INTERPRETER_SIGNALS.EXCEPTION_RAISED: return bwd_result unwrapped_bwd_result = unwrap(bwd_result) - with tracectx(bwd_trace): - bwd_trace.bound_symbols.append(prims.python_return.bind(unwrapped_bwd_result, output=())) + bwd_trace.bound_symbols.append(prims.python_return.bind(unwrapped_bwd_result, output=())) bwd_si = SigInfo(f"bwd_{si.name}") for a in saved_values + grads: