diff --git a/thunder/core/jit_ext.py b/thunder/core/jit_ext.py
index be19748592..d98836a2e5 100644
--- a/thunder/core/jit_ext.py
+++ b/thunder/core/jit_ext.py
@@ -627,6 +627,8 @@ def _convert_pytorchfunc_to_thundertrace(
         *args:
         **kwargs
     """
+    from thunder.core.baseutils import sequencify
+
     active_jit_ctx: JitCtx = get_jit_ctx()
     active_jit_ctx.computation_trace.push_scope([])
     wrapped_func_result = _interpret_call(func, *args, **kwargs)
@@ -637,8 +639,6 @@ def _convert_pytorchfunc_to_thundertrace(
     trace.bound_symbols.extend(active_jit_ctx.computation_trace.pop_scope())
     func_result = unwrap(wrapped_func_result)
     if shallow_copy_output:
-        from thunder.core.baseutils import sequencify
-
         out_to_shallow_copy: dict[Variable, TensorProxy] = {}
         for a in sequencify(func_result):
             shallow_copy_of_a = prims.shallow_copy.meta(a)
@@ -648,7 +648,7 @@ def _convert_pytorchfunc_to_thundertrace(
         func_result = tree_map(lambda t: out_to_shallow_copy.get(variableify(t), t), func_result)
     with tracectx(trace):
         prims.python_return(func_result)
-    return trace, wrapped_func_result.provenance
+    return trace, sequencify(wrapped_func_result)[0].provenance
 
 
 @register_general_jit_lookaside(torch.autograd.function.Function.apply.__func__)
@@ -792,17 +792,17 @@ def _generate_random_str_id() -> str:
     # non_differentiable_idx = fwd_kwargs.get("non_differentiable_idx")
     length_of_tensor_args = sum(args_tensor_mask)
     new_fwd_args = (wrap_const(None),) + fwd_args[:length_of_tensor_args]
 
-    jit_ctx.computation_trace.push_scope([])
-    fwd_result = _interpret_call(fwd, *new_fwd_args)
-    if fwd_result is INTERPRETER_SIGNALS.EXCEPTION_RAISED:
-        return fwd_result
-    output, saved_values = unwrap(fwd_result)
-    wrapped_output = wrap(output, provenance=fwd_result.provenance)
+    aug_fwd_trace, aug_fwd_provenance = _convert_pytorchfunc_to_thundertrace(fwd, False, *new_fwd_args)
+    if aug_fwd_trace is INTERPRETER_SIGNALS.EXCEPTION_RAISED:
+        return aug_fwd_trace
+    aug_fwd_result = aug_fwd_trace.output
+    output, saved_values = unwrap(aug_fwd_result)
+    wrapped_output = wrap(output, provenance=aug_fwd_provenance)
     unwrapped_fwd_args = tree_map(lambda t: unwrap(t), new_fwd_args)[1:]
 
-    fwd_bsyms: list[BoundSymbol] = jit_ctx.computation_trace.pop_scope()
+    fwd_bsyms: list[BoundSymbol] = aug_fwd_trace.bound_symbols
     producer_map = utils.producers(fwd_bsyms)
     tensor_to_prod_bsym: dict[Variable, BoundSymbol] = {}
     for p in tree_flatten((output, saved_values))[0]:
@@ -813,67 +813,85 @@ def _generate_random_str_id() -> str:
         tensor_to_prod_bsym[variableify(p)] = prod_bsym
     prod_bsym_to_tensor = {v: unvariableify(k) for k, v in tensor_to_prod_bsym.items()}
 
-    # Encapsulate custom fwd into a bsym.
     sym_id = f"autograd_function_apply_{_generate_random_str_id()}"
-    sym = Symbol(
-        name=sym_id,
-        id=sym_id,
-        _module=fwd_bsyms[-1].sym.module,
-    )
-    bsym_of_custom_autograd_func = BoundSymbol(
-        sym,
-        args=unwrapped_fwd_args,
-        kwargs={},
-        output=output,
-        subsymbols=fwd_bsyms,
-        header=(
-            f"output of fwd_body: {output}, saved_values from fwd_body: "
-            f"{[t.name if isinstance(t, Proxy) else t for t in saved_values]}"
-        ),
-        source_filename=jit_ctx.computation_trace._current_source_filename,
-        source_positions=None,
-        _call_ctx=fwd_bsyms[0]._call_ctx,
-        _import_ctx=fwd_bsyms[0]._import_ctx,
-        _object_ctx=fwd_bsyms[0]._object_ctx,
-        _executor=fwd_bsyms[0]._executor,
+    vanilla_fwd_trace = TraceCtx()
+    vanilla_fwd_trace.args = unwrapped_fwd_args
+    unpack_bsyms = [
+        prims.unpack_trivial.bind(a, name=a.name, output=a)
+        for a in filter(lambda a: isinstance(a, Proxy), vanilla_fwd_trace.args)
+    ]
+    for bsym in unpack_bsyms + fwd_bsyms[:-1]:
+        vanilla_fwd_trace.add_bound_symbol(bsym)
+    vanilla_fwd_trace.add_bound_symbol(prims.python_return.bind(output, output=()))
+    vanilla_fwd_trace._siginfo = SigInfo.from_name_and_args(sym_id, vanilla_fwd_trace.args)
+
+    @wraps(vanilla_fwd_trace.python_callable())
+    def core_of_fwd(*args, **kwargs):
+        return thunder.core.trace_interpreter.interpret_trace(vanilla_fwd_trace, *args, **kwargs)
+
+    sym = jit_ctx.ad_hoc_executor.register_operator(
+        vanilla_fwd_trace._siginfo.name,
+        like=core_of_fwd,
     )
-    jit_ctx.computation_trace.scopes[-1].append(bsym_of_custom_autograd_func)
+    unwrapped_forward_result = sym(*unwrapped_fwd_args)
 
     # Define augmented fwd rule and backward rule on the fly.
     augmented_fwd_trace = TraceCtx()
-    for bsym in fwd_bsyms:
+    augmented_fwd_trace.args = vanilla_fwd_trace.args
+    for bsym in unpack_bsyms + fwd_bsyms[:-1]:
         augmented_fwd_trace.add_bound_symbol(bsym)
     augmented_fwd_trace.add_bound_symbol(prims.python_return.bind(output, saved_values, output=()))
-    si = SigInfo.from_name_and_args(f"augmented_autograd_function_apply_{sym_id}", bsym_of_custom_autograd_func.args)
+    si = SigInfo.from_name_and_args(f"augmented_autograd_function_apply_{sym_id}", augmented_fwd_trace.args)
    augmented_fwd_trace._siginfo = si
-    augmented_fwd_callable = augmented_fwd_trace.python_callable(include_decorators=False)
 
-    def augmented_fwd_rule(*args):
-        # First arg is `None` or `FunctionCtx`
-        updated_output, updated_saved_values = augmented_fwd_callable(*args)
-        residuals = tuple(sequencify(updated_saved_values))
-        return VJPDual(primal=updated_output, residuals=residuals)
+    grads = sequencify(tree_map(lambda t: TensorProxy(like=t), sequencify(output)))
+    bwd_args = (wrap_const(None),)
+    bwd_tensor_args = grads + tuple(saved_values)
+    wrapped_bwd_tensor_args = tree_map(lambda t: wrap(t, provenance=aug_fwd_provenance), bwd_tensor_args)
+    bwd_trace, bwd_trace_provenance = _convert_pytorchfunc_to_thundertrace(
+        bwd,
+        False,
+        *(bwd_args + wrapped_bwd_tensor_args),
+    )
+    if bwd_trace is INTERPRETER_SIGNALS.EXCEPTION_RAISED:
+        return bwd_trace
+    bwd_trace.args = bwd_tensor_args
+    bwd_unpack_bsyms = [
+        prims.unpack_trivial.bind(a, name=a.name, output=a)
+        for a in filter(lambda a: isinstance(a, Proxy), bwd_trace.args)
+    ]
+    bwd_trace.bound_symbols = bwd_unpack_bsyms + bwd_trace.bound_symbols
+    bwd_trace._siginfo = SigInfo.from_name_and_args(f"bwd_{sym_id}", saved_values + grads)
 
-    augmented_forward_impls[sym.id] = augmented_fwd_rule
+    @wraps(bwd_trace.python_callable())
+    def bwd_impl_callable(*args, **kwargs):
+        return thunder.core.trace_interpreter.interpret_trace(bwd_trace, *args, **kwargs)
 
-    jit_ctx.computation_trace.push_scope([])
-    bwd_trace = TraceCtx()
+    @wraps(core_of_fwd)
+    def grad_transform(*args, **kwargs):
+        from thunder.core.transforms import get_grad, put_grads
 
-    grads = sequencify(tree_map(lambda t: TensorProxy(like=t), output))
-    bwd_args = (wrap_const(None),)
-    bwd_tensor_args = grads + tuple(saved_values)
-    wrapped_bwd_tensor_args = tree_map(lambda t: wrap(t, provenance=fwd_result.provenance), bwd_tensor_args)
-    bwd_result = _interpret_call(bwd, *(bwd_args + wrapped_bwd_tensor_args))
-    if bwd_result is INTERPRETER_SIGNALS.EXCEPTION_RAISED:
-        return bwd_result
-    unwrapped_bwd_result = unwrap(bwd_result)
-    bwd_trace.bound_symbols = jit_ctx.computation_trace.pop_scope()
-    bwd_trace.bound_symbols.append(prims.python_return.bind(unwrapped_bwd_result, output=()))
-
-    bwd_trace._siginfo = SigInfo.from_name_and_args(f"bwd_{si.name}", saved_values + grads)
-    backward_impls[sym.id] = bwd_trace.python_callable(include_decorators=False)
-
-    return wrapped_output
+        primal, residuals = thunder.core.trace_interpreter.interpret_trace(
+            augmented_fwd_trace,
+            *args,
+            **kwargs,
+        )
+        grads = tree_map(lambda t: get_grad(t), sequencify(primal))
+        bwd_args = grads + residuals
+        result = bwd_impl_callable(*bwd_args)
+        put_grads(args, result)
+        return primal
+
+    jit_ctx.ad_hoc_executor.register_implementation(
+        sym,
+        execution_transform=core_of_fwd,
+        grad_transform=grad_transform,
+    )
+
+    return wrap(
+        unwrapped_forward_result,
+        provenance=ProvenanceRecord(PseudoInst.LOOKASIDE, inputs=[fwd.provenance, bwd.provenance]),
+    )
 
 
 @register_general_jit_lookaside(torch.autocast.__enter__)
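
For context only (not part of the patch): the lookaside changed above is the path `thunder.jit` takes when a user-defined `torch.autograd.Function` is applied inside a jitted function. Below is a minimal sketch of the kind of user program that exercises it; the class name `ScaleByTwo` and the function `f` are illustrative, and the exact behavior depends on the installed `torch`/`thunder` versions.

```python
# Illustrative sketch (not from this PR): a custom autograd.Function run through thunder.jit,
# which is what the torch.autograd.function.Function.apply lookaside above handles.
import torch
import thunder


class ScaleByTwo(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        # saved_values for the backward body
        ctx.save_for_backward(x)
        return x * 2.0

    @staticmethod
    def backward(ctx, grad_output):
        (x,) = ctx.saved_tensors
        return grad_output * 2.0


def f(x):
    return ScaleByTwo.apply(x).sum()


jitted_f = thunder.jit(f)

x = torch.randn(4, requires_grad=True)
out = jitted_f(x)
out.backward()
print(x.grad)  # expected to match eager autograd: a tensor filled with 2.0
```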