Add a converter of a PyTorch func to a Thunder trace
This makes the lookaside of `torch.autograd.Function` a bit more readable.

Signed-off-by: Masaki Kozuki <mkozuki@nvidia.com>
crcrpar committed Oct 18, 2024
1 parent 947e9ef commit 95d047b
Showing 1 changed file with 60 additions and 45 deletions.
105 changes: 60 additions & 45 deletions thunder/core/jit_ext.py
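
For context, the lookaside touched here is what lets a `thunder.jit`-compiled function trace through a user-defined `torch.autograd.Function`. A minimal sketch of the kind of function it handles (the class `ScaleBy2`, the tensor shapes, and the use of `thunder.jit` as the entry point are illustrative assumptions, not part of this commit):

import torch
import thunder


class ScaleBy2(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        ctx.save_for_backward(x)  # saved tensors end up in the augmented forward's residuals
        return x * 2

    @staticmethod
    def backward(ctx, grad_out):
        (x,) = ctx.saved_tensors
        return grad_out * 2


def fn(x):
    return ScaleBy2.apply(x)


x = torch.randn(4, requires_grad=True)
jfn = thunder.jit(fn)     # the lookaside encapsulates ScaleBy2.forward into a single bound symbol
out = jfn(x)
out.sum().backward()      # gradients flow through the backward trace registered by the lookaside
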
@@ -596,6 +596,47 @@ def _general_jit_named_buffers_lookaside(obj: Any, *args, **kwargs):
)


def _convert_pytorchfunc_to_thundertrace(
    func: Callable[[Any], Any],
    shallow_copy_output: bool,
    *args,
    **kwargs,
) -> tuple[TraceCtx | INTERPRETER_SIGNALS, ProvenanceRecord | None]:
    """Converts a PyTorch function into a Thunder trace.

    Note that the generated trace does not have ``_siginfo`` and ``args`` set.

    Args:
        func: A callable composed of PyTorch functions.
        shallow_copy_output: Needs to be :obj:`True` only if func is `torch.autograd.Function.apply`,
            as it produces views of the tensor to attach the autograd node to.
        *args: Positional arguments to pass to func.
        **kwargs: Keyword arguments to pass to func.
    """
    active_jit_ctx: JitCtx = get_jit_ctx()
    active_jit_ctx.computation_trace.push_scope([])
    wrapped_func_result = _interpret_call(func, *args, **kwargs)
    if wrapped_func_result is INTERPRETER_SIGNALS.EXCEPTION_RAISED:
        return wrapped_func_result, None

    trace = TraceCtx()
    trace.bound_symbols.extend(active_jit_ctx.computation_trace.pop_scope())
    func_result = unwrap(wrapped_func_result)
    if shallow_copy_output:
        from thunder.core.baseutils import sequencify

        out_to_shallow_copy: dict[Variable, TensorProxy] = {}
        for a in sequencify(func_result):
            shallow_copy_of_a = prims.shallow_copy.meta(a)
            bsym = prims.shallow_copy.bind(a, output=shallow_copy_of_a)
            trace.add_bound_symbol(bsym)
            out_to_shallow_copy[variableify(a)] = shallow_copy_of_a
        func_result = tree_map(lambda t: out_to_shallow_copy.get(variableify(t), t), func_result)
    with tracectx(trace):
        prims.python_return(func_result)
    return trace, wrapped_func_result.provenance


@register_general_jit_lookaside(torch.autograd.function.Function.apply.__func__)
def _general_jit_torch_autograd_function_apply_lookaside(obj: Any, *args, **kwargs):
"""Encapsulate forward into a bsym, define and register augmented fwd and bwd.
@@ -609,61 +650,43 @@ def _general_jit_torch_autograd_function_apply_lookaside(obj: Any, *args, **kwargs):
"""
from thunder.core.baseutils import check, sequencify

jit_ctx: JitCtx = get_jit_ctx()

jit_ctx.computation_trace.push_scope([])

custom_autograd_function_cls = unwrap(obj)
symbol_name = custom_autograd_function_cls.__name__

custom_forward = custom_autograd_function_cls.forward
ctx = torch.autograd.function.FunctionCtx()
ctx_proxy = proxy(ctx, name=None, history=None)
wrapped_ctx = wrap_const(ctx_proxy)
custom_forward_result = _interpret_call(custom_forward, wrapped_ctx, *args, **kwargs)
if custom_forward_result is INTERPRETER_SIGNALS.EXCEPTION_RAISED:
return custom_forward_result
trace_of_fwd, fwd_output_provenance = _convert_pytorchfunc_to_thundertrace(
custom_forward, True, wrapped_ctx, *args, **kwargs
)

# Forward.
unwrapped_custom_forward_args = tree_map(lambda a: unwrap(a), args)
unwrapped_custom_forward_result = unwrap(custom_forward_result)
    # autograd.Function produces views of the tensor to attach the autograd node to
    unwrapped_custom_forward_result = tree_map(
        lambda x: prims.shallow_copy(x) if isinstance(x, TensorProxy) else x,
        unwrapped_custom_forward_result,
    trace_of_fwd._siginfo = SigInfo.from_name_and_args(
        custom_autograd_function_cls.__name__,
        unwrapped_custom_forward_args,
    )
    custom_fwd_bsyms: list[BoundSymbol] = jit_ctx.computation_trace.pop_scope()

    # not augmented for when we don't need grad
    trace_of_fwd = TraceCtx()
    trace_of_fwd.bound_symbols.extend(custom_fwd_bsyms)
    with tracectx(trace_of_fwd):
        prims.python_return(unwrapped_custom_forward_result)

    trace_of_fwd._siginfo = SigInfo.from_name_and_args(symbol_name, unwrapped_custom_forward_args)
    trace_of_fwd.args = unwrapped_custom_forward_args

    @wraps(trace_of_fwd.python_callable())
    def core_of_forward(*args, **kwargs):
        return thunder.core.trace_interpreter.interpret_trace(trace_of_fwd, *args, **kwargs)

    custom_fwd_sym = jit_ctx.ad_hoc_executor.register_operator(
        symbol_name,
    custom_fwd_sym = get_jit_ctx().ad_hoc_executor.register_operator(
        trace_of_fwd._siginfo.name,
        like=core_of_forward,
    )

    unwrapped_forward_result = custom_fwd_sym(*unwrapped_custom_forward_args)
    forward_result = wrap(
        unwrapped_forward_result,
        provenance=ProvenanceRecord(PseudoInst.LOOKASIDE, inputs=[obj.provenance, custom_forward_result.provenance]),
        provenance=ProvenanceRecord(PseudoInst.LOOKASIDE, inputs=[obj.provenance, fwd_output_provenance]),
    )

    augmented_bsym_output: tuple[tuple[TensorProxy, ...], tuple[TensorProxy, ...]] = (
        tuple(sequencify(unwrapped_custom_forward_result)),
        tuple(sequencify(trace_of_fwd.output)),
        ctx_proxy.saved_tensors,
    )
    trace_of_augmented_fwd = TraceCtx()
    trace_of_augmented_fwd.bound_symbols.extend(custom_fwd_bsyms)
    trace_of_augmented_fwd.bound_symbols.extend(trace_of_fwd.bound_symbols[:-1])
    with tracectx(trace_of_augmented_fwd):
        prims.python_return(augmented_bsym_output)
    trace_of_augmented_fwd._siginfo = SigInfo.from_name_and_args(custom_fwd_sym.name, unwrapped_custom_forward_args)
@@ -673,26 +696,18 @@ def core_of_forward(*args, **kwargs):
    custom_backward = custom_autograd_function_cls.backward
    grads = tree_map(
        lambda a: a.replace_name(f"grad_{a.name}"),
        sequencify(unwrapped_custom_forward_result),
        sequencify(trace_of_fwd.output),
    )
    trace_of_backward = TraceCtx()
    wrapped_grads = tree_map(lambda g: wrap(g, provenance=fwd_output_provenance), grads)
    trace_of_backward, _ = _convert_pytorchfunc_to_thundertrace(custom_backward, False, wrapped_ctx, *wrapped_grads)
    if trace_of_backward is INTERPRETER_SIGNALS.EXCEPTION_RAISED:
        return trace_of_backward
    trace_of_backward._siginfo = SigInfo.from_name_and_args(
        f"{custom_fwd_sym.name}_backward",
        ctx_proxy.saved_tensors + grads,
    )
    trace_of_backward.args = tuple(ctx_proxy.saved_tensors + grads)

    jit_ctx.computation_trace.push_scope([])
    wrapped_grads = tree_map(lambda g: wrap(g, provenance=custom_forward_result.provenance), grads)
    custom_backward_result = _interpret_call(custom_backward, wrapped_ctx, *wrapped_grads)
    if custom_backward_result is INTERPRETER_SIGNALS.EXCEPTION_RAISED:
        return custom_backward_result

    for bsym in jit_ctx.computation_trace.pop_scope():
        trace_of_backward.add_bound_symbol(bsym)
    with tracectx(trace_of_backward):
        prims.python_return(unwrap(custom_backward_result))

    bwd_trace_impl = TraceCtx()
    bwd_trace_impl.bound_symbols.extend(trace_of_backward.bound_symbols)
    bwd_trace_impl._siginfo = SigInfo.from_name_and_args(
@@ -714,13 +729,13 @@ def grad_transform(*args, **kwargs):
        check(not kwargs, lambda: f"{kwargs=} should be empty")
        primal, residuals = interpret_trace(trace_of_augmented_fwd, *args, **kwargs)
        check(len(primal) == 1, lambda: f"{primal=} has {len(primal)} proxies but expected 1")
        grads = (get_grad(primal[0]),)
        bwd_args = ctx_proxy.saved_consts + residuals + grads
        grads = tree_map(lambda t: get_grad(t), primal)
        bwd_args = ctx_proxy.saved_consts + tuple(sequencify(residuals)) + grads
        result = bwd_impl_callable(*bwd_args)
        put_grads(args, result)
        return primal

    jit_ctx.ad_hoc_executor.register_implementation(
    get_jit_ctx().ad_hoc_executor.register_implementation(
        custom_fwd_sym,
        execution_transform=core_of_forward,
        grad_transform=grad_transform,
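A side note on the `shallow_copy_output` flag of the new helper: in eager PyTorch, `torch.autograd.Function.apply` does not return the exact tensor objects produced by `forward` but alias-like copies, so that an autograd node can be attached to them; the helper mirrors this in the trace with `prims.shallow_copy` bound symbols. A small eager-mode illustration (the `Identity` class is an assumption for demonstration; the commented results reflect expected eager behavior, not output from this commit):

import torch


class Identity(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        return x  # return an input tensor unchanged

    @staticmethod
    def backward(ctx, grad_out):
        return grad_out


x = torch.randn(3, requires_grad=True)
y = Identity.apply(x)

print(y is x)                        # False: apply hands back a different tensor object
print(y.data_ptr() == x.data_ptr())  # True: it shares storage with x
print(y.grad_fn)                     # the autograd node is attached to y, while x stays a leaf

In the traced version, `shallow_copy_output=True` plays the analogous role: each output proxy is replaced by a `prims.shallow_copy` of itself before the trace returns.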
