Skip to content

Commit

Permalink
add helper for SigInfo creation in Function.apply lookaside (#1283)
Browse files Browse the repository at this point in the history
  • Loading branch information
crcrpar authored Oct 10, 2024
1 parent dafc79d commit 588cf3c
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 24 deletions.
13 changes: 13 additions & 0 deletions thunder/core/codeutils.py
Original file line number Diff line number Diff line change
Expand Up @@ -339,6 +339,19 @@ def _arg_printer(name: str, has_default: bool, default: Any = None) -> str:

return f"def {self.name}({arg_str}):"

@staticmethod
def from_name_and_args(name: str, args: Sequence[Any]):
si = SigInfo(name)
for a in args:
if isinstance(a, ProxyInterface):
si.args.append((a.name, None))
else:
from thunder.core.proxies import proxy

pa = proxy(a)
si.args.append((pa.name, None))
return si


# Creates a SigInfo object from a function and the inputs to it
# The SigInfo object contains name and value information for the args, varargs, kwargs, and varkwargs
Expand Down
30 changes: 6 additions & 24 deletions thunder/core/jit_ext.py
Original file line number Diff line number Diff line change
Expand Up @@ -640,14 +640,7 @@ def _general_jit_torch_autograd_function_apply_lookaside(obj: Any, *args, **kwar
with tracectx(trace_of_fwd):
prims.python_return(unwrapped_custom_forward_result)

si = SigInfo(symbol_name)
for a in unwrapped_custom_forward_args:
if isinstance(a, Proxy):
si.args.append((a.name, None))
else:
pa = proxy(a)
si.args.append((pa.name, None))
trace_of_fwd._siginfo = si
trace_of_fwd._siginfo = SigInfo.from_name_and_args(symbol_name, unwrapped_custom_forward_args)
trace_of_fwd.args = unwrapped_custom_forward_args

@wraps(trace_of_fwd.python_callable())
Expand Down Expand Up @@ -687,14 +680,7 @@ def bind_postprocess(bsym):
trace_of_augmented_fwd.add_bound_symbol(bsym)
with tracectx(trace_of_augmented_fwd):
prims.python_return(augmented_bsym_output)
si = SigInfo(custom_fwd_sym.name)
for a in unwrapped_custom_forward_args:
if isinstance(a, Proxy):
si.args.append((a.name, None))
else:
pa = proxy(a)
si.args.append((pa.name, None))
trace_of_augmented_fwd._siginfo = si
trace_of_augmented_fwd._siginfo = SigInfo.from_name_and_args(custom_fwd_sym.name, unwrapped_custom_forward_args)
trace_of_augmented_fwd.args = unwrapped_custom_forward_args

@wraps(trace_of_augmented_fwd.python_callable())
Expand Down Expand Up @@ -745,18 +731,14 @@ def augmented_custom_forward_rule(*args, **kwargs):
def bwd_trace_callable_interface(*args, **kwargs):
return thunder.core.trace_interpreter.interpret_trace(trace_of_backward, *args, **kwargs)

bwd_si = SigInfo("backward_impl")
for a in ctx_proxy.saved_consts + ctx_proxy.saved_tensors + grads:
if isinstance(a, Proxy):
bwd_si.args.append((a.name, None))
else:
pa = proxy(a)
bwd_si.args.append((pa.name, None))
bwd_trace_impl = TraceCtx()
for bsym in custom_bwd_bsyms:
bwd_trace_impl.add_bound_symbol(bsym)
bwd_trace_impl.add_bound_symbol(prims.python_return.bind(*sequencify(unwrap(custom_backward_result)), output=()))
bwd_trace_impl._siginfo = bwd_si
bwd_trace_impl._siginfo = SigInfo.from_name_and_args(
"backward_impl",
ctx_proxy.saved_consts + ctx_proxy.saved_tensors + grads,
)
bwd_trace_impl.args = tuple(ctx_proxy.saved_consts + ctx_proxy.saved_tensors + grads)

@wraps(bwd_trace_impl.python_callable())
Expand Down

0 comments on commit 588cf3c

Please sign in to comment.