diff --git a/thunder/core/jit_ext.py b/thunder/core/jit_ext.py index 4d0c239356..735e48f6c7 100644 --- a/thunder/core/jit_ext.py +++ b/thunder/core/jit_ext.py @@ -596,6 +596,47 @@ def _general_jit_named_buffers_lookaside(obj: Any, *args, **kwargs): ) +def _convert_pytorchfunc_to_thundertrace( + func: Callable[[Any], Any], + shallow_copy_output: bool, + *args, + **kwargs, +) -> tuple[TraceCtx | INTERPRETER_SIGNALS, ProvenanceRecord | None]: + """Converts pytorch function to thunder trace. + + Note that the generated trace would not have _siginfo and args set. + + Args: + func: A callable composed of pytorch functions. + shallow_copy_output: Needs to be :obj:`True` only if func is `torch.autograd.Function.apply` as + it produces views of the tensor to attach the autograd node to. + *args: + **kwargs + """ + active_jit_ctx: JitCtx = get_jit_ctx() + active_jit_ctx.computation_trace.push_scope([]) + wrapped_func_result = _interpret_call(func, *args, **kwargs) + if wrapped_func_result is INTERPRETER_SIGNALS.EXCEPTION_RAISED: + return wrapped_func_result, None + + trace = TraceCtx() + trace.bound_symbols.extend(active_jit_ctx.computation_trace.pop_scope()) + func_result = unwrap(wrapped_func_result) + if shallow_copy_output: + from thunder.core.baseutils import sequencify + + out_to_shallow_copy: dict[Variable, TensorProxy] = {} + for a in sequencify(func_result): + shallow_copy_of_a = prims.shallow_copy.meta(a) + bsym = prims.shallow_copy.bind(a, output=shallow_copy_of_a) + trace.add_bound_symbol(bsym) + out_to_shallow_copy[variableify(a)] = shallow_copy_of_a + func_result = tree_map(lambda t: out_to_shallow_copy.get(variableify(t), t), func_result) + with tracectx(trace): + prims.python_return(func_result) + return trace, wrapped_func_result.provenance + + @register_general_jit_lookaside(torch.autograd.function.Function.apply.__func__) def _general_jit_torch_autograd_function_apply_lookaside(obj: Any, *args, **kwargs): """Encapsulate forward into a bsym, define and register augmented fwd and bwd. @@ -609,61 +650,43 @@ def _general_jit_torch_autograd_function_apply_lookaside(obj: Any, *args, **kwar """ from thunder.core.baseutils import check, sequencify - jit_ctx: JitCtx = get_jit_ctx() - - jit_ctx.computation_trace.push_scope([]) - custom_autograd_function_cls = unwrap(obj) - symbol_name = custom_autograd_function_cls.__name__ - custom_forward = custom_autograd_function_cls.forward ctx = torch.autograd.function.FunctionCtx() ctx_proxy = proxy(ctx, name=None, history=None) wrapped_ctx = wrap_const(ctx_proxy) - custom_forward_result = _interpret_call(custom_forward, wrapped_ctx, *args, **kwargs) - if custom_forward_result is INTERPRETER_SIGNALS.EXCEPTION_RAISED: - return custom_forward_result + trace_of_fwd, fwd_output_provenance = _convert_pytorchfunc_to_thundertrace( + custom_forward, True, wrapped_ctx, *args, **kwargs + ) # Forward. unwrapped_custom_forward_args = tree_map(lambda a: unwrap(a), args) - unwrapped_custom_forward_result = unwrap(custom_forward_result) - # autograd.Function produces views of the tensor to attache the autograd node to - unwrapped_custom_forward_result = tree_map( - lambda x: prims.shallow_copy(x) if isinstance(x, TensorProxy) else x, - unwrapped_custom_forward_result, + trace_of_fwd._siginfo = SigInfo.from_name_and_args( + custom_autograd_function_cls.__name__, + unwrapped_custom_forward_args, ) - custom_fwd_bsyms: list[BoundSymbol] = jit_ctx.computation_trace.pop_scope() - - # not augmented for when we don't need grad - trace_of_fwd = TraceCtx() - trace_of_fwd.bound_symbols.extend(custom_fwd_bsyms) - with tracectx(trace_of_fwd): - prims.python_return(unwrapped_custom_forward_result) - - trace_of_fwd._siginfo = SigInfo.from_name_and_args(symbol_name, unwrapped_custom_forward_args) trace_of_fwd.args = unwrapped_custom_forward_args @wraps(trace_of_fwd.python_callable()) def core_of_forward(*args, **kwargs): return thunder.core.trace_interpreter.interpret_trace(trace_of_fwd, *args, **kwargs) - custom_fwd_sym = jit_ctx.ad_hoc_executor.register_operator( - symbol_name, + custom_fwd_sym = get_jit_ctx().ad_hoc_executor.register_operator( + trace_of_fwd._siginfo.name, like=core_of_forward, ) - unwrapped_forward_result = custom_fwd_sym(*unwrapped_custom_forward_args) forward_result = wrap( unwrapped_forward_result, - provenance=ProvenanceRecord(PseudoInst.LOOKASIDE, inputs=[obj.provenance, custom_forward_result.provenance]), + provenance=ProvenanceRecord(PseudoInst.LOOKASIDE, inputs=[obj.provenance, fwd_output_provenance]), ) augmented_bsym_output: tuple[tuple[TensorProxy, ...], tuple[TensorProxy, ...]] = ( - tuple(sequencify(unwrapped_custom_forward_result)), + tuple(sequencify(trace_of_fwd.output)), ctx_proxy.saved_tensors, ) trace_of_augmented_fwd = TraceCtx() - trace_of_augmented_fwd.bound_symbols.extend(custom_fwd_bsyms) + trace_of_augmented_fwd.bound_symbols.extend(trace_of_fwd.bound_symbols[:-1]) with tracectx(trace_of_augmented_fwd): prims.python_return(augmented_bsym_output) trace_of_augmented_fwd._siginfo = SigInfo.from_name_and_args(custom_fwd_sym.name, unwrapped_custom_forward_args) @@ -673,26 +696,18 @@ def core_of_forward(*args, **kwargs): custom_backward = custom_autograd_function_cls.backward grads = tree_map( lambda a: a.replace_name(f"grad_{a.name}"), - sequencify(unwrapped_custom_forward_result), + sequencify(trace_of_fwd.output), ) - trace_of_backward = TraceCtx() + wrapped_grads = tree_map(lambda g: wrap(g, provenance=fwd_output_provenance), grads) + trace_of_backward, _ = _convert_pytorchfunc_to_thundertrace(custom_backward, False, wrapped_ctx, *wrapped_grads) + if trace_of_backward is INTERPRETER_SIGNALS.EXCEPTION_RAISED: + return trace_of_backward trace_of_backward._siginfo = SigInfo.from_name_and_args( f"{custom_fwd_sym.name}_backward", ctx_proxy.saved_tensors + grads, ) trace_of_backward.args = tuple(ctx_proxy.saved_tensors + grads) - jit_ctx.computation_trace.push_scope([]) - wrapped_grads = tree_map(lambda g: wrap(g, provenance=custom_forward_result.provenance), grads) - custom_backward_result = _interpret_call(custom_backward, wrapped_ctx, *wrapped_grads) - if custom_backward_result is INTERPRETER_SIGNALS.EXCEPTION_RAISED: - return custom_backward_result - - for bsym in jit_ctx.computation_trace.pop_scope(): - trace_of_backward.add_bound_symbol(bsym) - with tracectx(trace_of_backward): - prims.python_return(unwrap(custom_backward_result)) - bwd_trace_impl = TraceCtx() bwd_trace_impl.bound_symbols.extend(trace_of_backward.bound_symbols) bwd_trace_impl._siginfo = SigInfo.from_name_and_args( @@ -714,13 +729,13 @@ def grad_transform(*args, **kwargs): check(not kwargs, lambda: f"{kwargs=} should be empty") primal, residuals = interpret_trace(trace_of_augmented_fwd, *args, **kwargs) check(len(primal) == 1, lambda: f"{primal=} has {len(primal)} proxies but expected 1") - grads = (get_grad(primal[0]),) - bwd_args = ctx_proxy.saved_consts + residuals + grads + grads = tree_map(lambda t: get_grad(t), primal) + bwd_args = ctx_proxy.saved_consts + tuple(sequencify(residuals)) + grads result = bwd_impl_callable(*bwd_args) put_grads(args, result) return primal - jit_ctx.ad_hoc_executor.register_implementation( + get_jit_ctx().ad_hoc_executor.register_implementation( custom_fwd_sym, execution_transform=core_of_forward, grad_transform=grad_transform,