From f7b2e1590fd0d98985131c54653b2a95ba4729ed Mon Sep 17 00:00:00 2001 From: Masaki Kozuki Date: Tue, 12 Nov 2024 04:47:38 +0900 Subject: [PATCH] [jit_ext] Inserts unpack bsyms to traces generated inside `torch.autograd.Function` (#1414) --- thunder/core/jit_ext.py | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/thunder/core/jit_ext.py b/thunder/core/jit_ext.py index 0c6b21e55e..c7ec43ad45 100644 --- a/thunder/core/jit_ext.py +++ b/thunder/core/jit_ext.py @@ -682,6 +682,11 @@ def _general_jit_torch_autograd_function_apply_lookaside(obj: Any, *args, **kwar unwrapped_custom_forward_args, ) trace_of_fwd.args = unwrapped_custom_forward_args + unpack_bsyms = [ + prims.unpack_trivial.bind(a, name=a.name, output=a) + for a in filter(lambda a: isinstance(a, Proxy), trace_of_fwd.args) + ] + trace_of_fwd.bound_symbols = unpack_bsyms + trace_of_fwd.bound_symbols @wraps(trace_of_fwd.python_callable()) def core_of_forward(*args, **kwargs): @@ -723,6 +728,11 @@ def core_of_forward(*args, **kwargs): ctx_proxy.saved_tensors + grads, ) trace_of_backward.args = tuple(ctx_proxy.saved_tensors + grads) + bwd_unpack_bsyms = [ + prims.unpack_trivial.bind(a, name=a.name, output=a) + for a in filter(lambda a: isinstance(a, Proxy), trace_of_backward.args) + ] + trace_of_backward.bound_symbols = bwd_unpack_bsyms + trace_of_backward.bound_symbols bwd_trace_impl = TraceCtx() bwd_trace_impl.bound_symbols.extend(trace_of_backward.bound_symbols)