Skip to content

Commit

Permalink
[jit_ext] Inserts unpack bsyms to traces generated inside `torch.auto…
Browse files Browse the repository at this point in the history
…grad.Function` (#1414)
  • Loading branch information
crcrpar authored Nov 11, 2024
1 parent a13a099 commit f7b2e15
Showing 1 changed file with 10 additions and 0 deletions.
10 changes: 10 additions & 0 deletions thunder/core/jit_ext.py
Original file line number Diff line number Diff line change
Expand Up @@ -682,6 +682,11 @@ def _general_jit_torch_autograd_function_apply_lookaside(obj: Any, *args, **kwar
unwrapped_custom_forward_args,
)
trace_of_fwd.args = unwrapped_custom_forward_args
unpack_bsyms = [
prims.unpack_trivial.bind(a, name=a.name, output=a)
for a in filter(lambda a: isinstance(a, Proxy), trace_of_fwd.args)
]
trace_of_fwd.bound_symbols = unpack_bsyms + trace_of_fwd.bound_symbols

@wraps(trace_of_fwd.python_callable())
def core_of_forward(*args, **kwargs):
Expand Down Expand Up @@ -723,6 +728,11 @@ def core_of_forward(*args, **kwargs):
ctx_proxy.saved_tensors + grads,
)
trace_of_backward.args = tuple(ctx_proxy.saved_tensors + grads)
bwd_unpack_bsyms = [
prims.unpack_trivial.bind(a, name=a.name, output=a)
for a in filter(lambda a: isinstance(a, Proxy), trace_of_backward.args)
]
trace_of_backward.bound_symbols = bwd_unpack_bsyms + trace_of_backward.bound_symbols

bwd_trace_impl = TraceCtx()
bwd_trace_impl.bound_symbols.extend(trace_of_backward.bound_symbols)
Expand Down

0 comments on commit f7b2e15

Please sign in to comment.