diff --git a/thunder/core/jit_ext.py b/thunder/core/jit_ext.py
index 1b09d274fa..be788375bd 100644
--- a/thunder/core/jit_ext.py
+++ b/thunder/core/jit_ext.py
@@ -590,6 +590,9 @@ def _general_jit_named_buffers_lookaside(obj: Any, *args, **kwargs):
     )
 
 
+autograd_counter = 0
+
+
 @register_general_jit_lookaside(torch.autograd.function.Function.apply.__func__)
 def _general_jit_torch_autograd_function_apply_lookaside(obj: Any, *args, **kwargs):
     """Encapsulate forward into a bsym, define and register augmented fwd and bwd.
@@ -601,17 +604,19 @@ def _general_jit_torch_autograd_function_apply_lookaside(obj: Any, *args, **kwar
     3. Trace ``MyFunc.backward``, define :class:`~thunder.core.trace.TraceCtx` whose args are
        ``(*residuals, *grads)``. So far, non-tensor ``ctx`` attributes seem to be folded into a trace.
     """
+    global autograd_counter
+
     from thunder.core.baseutils import check, sequencify
     from thunder.core.transforms import augmented_forward_impls, backward_impls, VJPDual
 
     jit_ctx: JitCtx = get_jit_ctx()
 
-    custom_fwd_bsyms: list[BoundSymbol] = []
-    orig_scopes = jit_ctx.computation_trace.scopes
-
-    jit_ctx.computation_trace.scopes = [custom_fwd_bsyms]
+    jit_ctx.computation_trace.push_scope([])
 
     custom_autograd_function_cls = unwrap(obj)
+    autograd_counter += 1
+    symbol_name = f"{custom_autograd_function_cls.__name__}_{autograd_counter}"
+
     custom_forward = custom_autograd_function_cls.forward
     ctx = torch.autograd.function.FunctionCtx()
     ctx_proxy = proxy(ctx, name=None, history=None)
@@ -623,26 +628,11 @@ def _general_jit_torch_autograd_function_apply_lookaside(obj: Any, *args, **kwar
     # Forward.
     unwrapped_custom_forward_args = tree_map(lambda a: unwrap(a), args)
     unwrapped_custom_forward_result = unwrap(custom_forward_result)
-    symbol_name = custom_autograd_function_cls.__name__
-    custom_fwd_sym = Symbol(
-        name=symbol_name,
-        id=symbol_name,
-        meta=lambda *unwrapped_custom_forward_args: unwrapped_custom_forward_result,
+    # autograd.Function produces views of the tensor to attach the autograd node to
+    unwrapped_custom_forward_result = tree_map(
+        lambda x: prims.shallow_copy(x) if isinstance(x, TensorProxy) else x, unwrapped_custom_forward_result
     )
-    bsym_of_custom_fwd = BoundSymbol(
-        custom_fwd_sym,
-        args=unwrapped_custom_forward_args,
-        kwargs={},
-        output=unwrapped_custom_forward_result,
-        subsymbols=custom_fwd_bsyms,
-        source_filename=jit_ctx.computation_trace._current_source_filename,
-        source_positions=None,
-        _call_ctx=custom_fwd_bsyms[0]._call_ctx,
-        _import_ctx=custom_fwd_bsyms[0]._import_ctx,
-        _object_ctx=custom_fwd_bsyms[0]._object_ctx,
-        _executor=custom_fwd_bsyms[0]._executor,
-    )
-    orig_scopes[-1].append(bsym_of_custom_fwd)
+    custom_fwd_bsyms: list[BoundSymbol] = jit_ctx.computation_trace.pop_scope()
 
     # not augmented for when we don't need grad
     trace_of_fwd = TraceCtx()
@@ -651,7 +641,7 @@ def _general_jit_torch_autograd_function_apply_lookaside(obj: Any, *args, **kwar
     with tracectx(trace_of_fwd):
         prims.python_return(unwrapped_custom_forward_result)
 
-    si = SigInfo(custom_fwd_sym.name)
+    si = SigInfo(symbol_name)
     for a in unwrapped_custom_forward_args:
         if isinstance(a, Proxy):
             si.args.append((a.name, None))
@@ -665,6 +655,24 @@ def _general_jit_torch_autograd_function_apply_lookaside(obj: Any, *args, **kwar
     def core_of_forward(*args, **kwargs):
         return thunder.core.trace_interpreter.interpret_trace(trace_of_fwd, *args, **kwargs)
 
+    custom_fwd_meta = lambda *unwrapped_custom_forward_args: unwrapped_custom_forward_result
+
+    def bind_postprocess(bsym):
+        bsym._call_ctx = {}
+
+    custom_fwd_sym = Symbol(
+        name=symbol_name,
+        id=symbol_name,
+        meta=core_of_forward,
+        _bind_postprocess=bind_postprocess,
+    )
+
+    unwrapped_forward_result = custom_fwd_sym(*unwrapped_custom_forward_args)
+    forward_result = wrap(
+        unwrapped_forward_result,
+        provenance=ProvenanceRecord(PseudoInst.LOOKASIDE, inputs=[obj.provenance, custom_forward_result.provenance]),
+    )
+
     thunder.executors.torchex._register_implementation(
         custom_fwd_sym, core_of_forward, checker=thunder.executors.torchex._always_executable
     )
@@ -676,7 +684,8 @@ def core_of_forward(*args, **kwargs):
     trace_of_augmented_fwd = TraceCtx()
     for bsym in custom_fwd_bsyms:
         trace_of_augmented_fwd.add_bound_symbol(bsym)
-    trace_of_augmented_fwd.add_bound_symbol(prims.python_return.bind(augmented_bsym_output, output=()))
+    with tracectx(trace_of_augmented_fwd):
+        prims.python_return(augmented_bsym_output)
     si = SigInfo(custom_fwd_sym.name)
     for a in unwrapped_custom_forward_args:
         if isinstance(a, Proxy):
@@ -691,8 +700,6 @@ def core_of_forward(*args, **kwargs):
     def core_of_augmented_forward(*args, **kwargs):
         return thunder.core.trace_interpreter.interpret_trace(trace_of_augmented_fwd, *args, **kwargs)
 
-    # core_of_augmented_forward = trace_of_augmented_fwd.python_callable(include_decorators=False)
-
     @wraps(core_of_augmented_forward)
     def augmented_custom_forward_rule(*args, **kwargs):
         primal, residulas = core_of_augmented_forward(*args, **kwargs)
@@ -703,21 +710,12 @@ def augmented_custom_forward_rule(*args, **kwargs):
 
     # Backward definition
     custom_backward = custom_autograd_function_cls.backward
-    custom_bwd_bsyms: list[BoundSymbol] = []
-    jit_ctx.computation_trace.scopes = [custom_bwd_bsyms]
+
     grads = tree_map(
         lambda a: a.replace_name(f"grad_{a.name}"),
         sequencify(unwrapped_custom_forward_result),
     )
-    wrapped_grads = tree_map(lambda g: wrap(g, provenance=custom_forward_result.provenance), grads)
-    custom_backward_result = _interpret_call(custom_backward, wrapped_ctx, *wrapped_grads)
-    if custom_backward_result is INTERPRETER_SIGNALS.EXCEPTION_RAISED:
-        return custom_backward_result
-
     trace_of_backward = TraceCtx()
-    for bsym in custom_bwd_bsyms:
-        trace_of_backward.add_bound_symbol(bsym)
-    trace_of_backward.add_bound_symbol(prims.python_return.bind(*unwrap(custom_backward_result), output=()))
     bwd_si = SigInfo(f"{custom_fwd_sym.name}_backward")
     for a in ctx_proxy.saved_tensors + grads:
         if isinstance(a, Proxy):
@@ -728,6 +726,19 @@ def augmented_custom_forward_rule(*args, **kwargs):
     trace_of_backward._siginfo = bwd_si
     trace_of_backward.args = tuple(ctx_proxy.saved_tensors + grads)
 
+    jit_ctx.computation_trace.push_scope([])
+    wrapped_grads = tree_map(lambda g: wrap(g, provenance=custom_forward_result.provenance), grads)
+    custom_backward_result = _interpret_call(custom_backward, wrapped_ctx, *wrapped_grads)
+    if custom_backward_result is INTERPRETER_SIGNALS.EXCEPTION_RAISED:
+        return custom_backward_result
+
+    custom_bwd_bsyms: list[BoundSymbol] = jit_ctx.computation_trace.pop_scope()
+
+    for bsym in custom_bwd_bsyms:
+        trace_of_backward.add_bound_symbol(bsym)
+    with tracectx(trace_of_backward):
+        prims.python_return(*unwrap(custom_backward_result))
+
     @wraps(trace_of_backward.python_callable())
     def bwd_trace_callable_interface(*args, **kwargs):
         return thunder.core.trace_interpreter.interpret_trace(trace_of_backward, *args, **kwargs)
@@ -757,10 +768,7 @@ def backward_impl(*args, **kwargs):
             return bwd_impl_callable(*new_args)
 
     backward_impls[custom_fwd_sym.name] = backward_impl
-
-    # Cosmetic
-    jit_ctx.computation_trace.scopes = orig_scopes
-    return custom_forward_result
+    return forward_result
 
 
 @register_general_jit_lookaside(torch.finfo)
diff --git a/thunder/core/prims.py b/thunder/core/prims.py
index c46bda2277..d847b30936 100644
--- a/thunder/core/prims.py
+++ b/thunder/core/prims.py
@@ -166,6 +166,7 @@ class PrimIDs(Enum):
     TRANSPOSE = auto()
     UNFOLD = auto()
     VIEW = auto()
+    SHALLOW_COPY = auto()  # a view copy
     # Memory layout prims (Experimental)
     STRIDE_ORDER = auto()
     # Elementwise unary prims
@@ -1658,7 +1659,7 @@ def return_printer(
 ):
     utils.check(
         len(kwarg_printables) == 0,
-        lambda: f"Expected no kwargs for del but got {kwarg_printables}",
+        lambda: f"Expected no kwargs for return but got {kwarg_printables}",
         exception_type=AssertionError,
     )
 
@@ -3538,6 +3539,13 @@ def transpose_meta(a: TensorProxy, /, permutation: tuple[int, ...]) -> TensorPro
 view = make_prim(PrimIDs.VIEW, "view", meta=reshape_meta, tags=(OpTags.SHAPE_OP,))
 
 
+def shallow_copy_meta(a: TensorProxy, /) -> TensorProxy:
+    return TensorProxy(like=a)
+
+
+shallow_copy = make_prim(PrimIDs.SHALLOW_COPY, "shallow_copy", meta=shallow_copy_meta, tags=(OpTags.SHAPE_OP,))
+
+
 def unfold_meta(a: TensorProxy, /, dim: int, size: int, step: int) -> TensorProxy:
     dim = utils.canonicalize_dim(a.ndim, dim)
     max_size = 1 if a.ndim == 0 else a.shape[dim]
diff --git a/thunder/executors/torchex.py b/thunder/executors/torchex.py
index ef609ab5b9..455ef6aca4 100644
--- a/thunder/executors/torchex.py
+++ b/thunder/executors/torchex.py
@@ -2169,3 +2169,6 @@ def _shape_impl(t):
 
 shape = ex.register_operator("shape", meta=prims.shape_meta, fn=_shape_impl)
 _register_implementation(prims.shape, shape, checker=_always_executable)
+
+shallow_copy = ex.register_operator("shallow_copy", meta=prims.shallow_copy, fn=lambda x: x)
+_register_implementation(prims.shallow_copy, shallow_copy, checker=_always_executable)
diff --git a/thunder/tests/test_jit_general.py b/thunder/tests/test_jit_general.py
index 1a8a721c6e..44ea9cd04c 100644
--- a/thunder/tests/test_jit_general.py
+++ b/thunder/tests/test_jit_general.py
@@ -1272,6 +1272,41 @@ def my_sin_with_wrong_backward(x):
     gradcheck(jitted, (x,))
 
 
+def test_autograd_function_empty_forward():
+
+    class Fn(torch.autograd.Function):
+        @staticmethod
+        def forward(self, x):
+            return x
+
+        @staticmethod
+        def backward(self, grad_x):
+            return 2 * grad_x
+
+    def fn(x):
+        # TODO: there still is a bug when the result is directly returned
+        return Fn.apply(x) * 3
+
+    a = torch.randn(2)
+    jfn = thunder.jit(fn)
+
+    ref = fn(a)
+    out = jfn(a)
+
+    assert_close(out, ref)
+
+    a = torch.randn(2, requires_grad=True)
+    go = torch.randn_like(a)
+
+    ref = fn(a)
+    out = jfn(a)
+    (grad,) = torch.autograd.grad(out, a, go)
+    (grad_ref,) = torch.autograd.grad(ref, a, go)
+
+    assert_close(out, ref)
+    assert_close(grad, grad_ref)
+
+
 @requiresCUDA  # I have not found a good other object to use
 def test_cpp_property():
     def fn():
diff --git a/thunder/torch/__init__.py b/thunder/torch/__init__.py
index 2c4141c36d..51f8c62d54 100644
--- a/thunder/torch/__init__.py
+++ b/thunder/torch/__init__.py
@@ -5812,6 +5812,7 @@ def check_overlap_ops():
     tensor_split,
     chunk,
     getitem,
+    prims.shallow_copy,
 }
 
 # Add all auto-registered torch operators symbol that return tensor views to _syms_returning_views
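
Note on the shallow_copy rationale (commentary, not part of the patch): torch.autograd.Function.apply does not hand back the exact tensor object that forward returned; autograd wraps outputs as view-like aliases so it can attach the backward node. The new shallow_copy prim mirrors this at trace time (a fresh TensorProxy, identity at execution). A quick PyTorch-only check of the underlying behavior; the Identity class below is an illustrative assumption:

    import torch

    class Identity(torch.autograd.Function):
        captured = None

        @staticmethod
        def forward(ctx, x):
            # Remember the exact object forward returns.
            Identity.captured = x
            return x

        @staticmethod
        def backward(ctx, g):
            return g

    a = torch.randn(2, requires_grad=True)
    out = Identity.apply(a)
    # apply() returns an alias with the autograd node attached,
    # not the object forward returned; the data is shared.
    assert out is not Identity.captured
    assert out.data_ptr() == a.data_ptr()
    assert out.grad_fn is not None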
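
End-to-end sketch of what the lookaside change enables (illustrative, not part of the patch): each apply() call is captured as its own numbered Symbol, its output is shallow-copied so the alias relationship survives tracing, and the user-defined backward drives the gradient. The Scale class, shapes, and expected grad value are assumptions for this example; the diff's test_autograd_function_empty_forward exercises the same path:

    import torch
    from torch.testing import assert_close
    import thunder

    class Scale(torch.autograd.Function):
        @staticmethod
        def forward(ctx, x):
            ctx.save_for_backward(x)
            return 2 * x

        @staticmethod
        def backward(ctx, grad_out):
            # d(2x)/dx = 2
            return 2 * grad_out

    def fn(x):
        return Scale.apply(x) * 3

    jfn = thunder.jit(fn)
    x = torch.randn(2, requires_grad=True)
    out = jfn(x)
    assert_close(out, fn(x))
    out.sum().backward()
    # out = 6x, so the grad of out.sum() w.r.t. x is 6 everywhere.
    assert_close(x.grad, torch.full_like(x, 6.0))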