handle empty autograd.Function (#1221)
t-vi authored Oct 1, 2024
1 parent 69ee6a2 commit 10a8a44
Showing 5 changed files with 97 additions and 42 deletions.
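The situation this commit handles, as a minimal sketch (it mirrors the new test added in thunder/tests/test_jit_general.py below; the Identity class and the 2 * grad_out backward are illustrative choices, not part of the diff): a torch.autograd.Function whose forward records no tensor operations left the lookaside with an empty list of bound symbols, so the old code's custom_fwd_bsyms[0] accesses raised.

import torch
import thunder

class Identity(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        # "empty" forward: returns its input and records no tensor ops
        return x

    @staticmethod
    def backward(ctx, grad_out):
        return 2 * grad_out

def fn(x):
    return Identity.apply(x) * 3

jfn = thunder.jit(fn)
out = jfn(torch.randn(2, requires_grad=True))  # failed before this commit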
90 changes: 49 additions & 41 deletions thunder/core/jit_ext.py
@@ -590,6 +590,9 @@ def _general_jit_named_buffers_lookaside(obj: Any, *args, **kwargs):
)


autograd_counter = 0


@register_general_jit_lookaside(torch.autograd.function.Function.apply.__func__)
def _general_jit_torch_autograd_function_apply_lookaside(obj: Any, *args, **kwargs):
"""Encapsulate forward into a bsym, define and register augmented fwd and bwd.
@@ -601,17 +604,19 @@ def _general_jit_torch_autograd_function_apply_lookaside(obj: Any, *args, **kwargs):
3. Trace ``MyFunc.backward``, define :class:`~thunder.core.trace.TraceCtx` whose args are ``(*residuals, *grads)``.
So far, non-tensor ``ctx`` attributes seem to be folded into a trace.
"""
global autograd_counter

from thunder.core.baseutils import check, sequencify
from thunder.core.transforms import augmented_forward_impls, backward_impls, VJPDual

jit_ctx: JitCtx = get_jit_ctx()

custom_fwd_bsyms: list[BoundSymbol] = []
orig_scopes = jit_ctx.computation_trace.scopes

jit_ctx.computation_trace.scopes = [custom_fwd_bsyms]
jit_ctx.computation_trace.push_scope([])

custom_autograd_function_cls = unwrap(obj)
autograd_counter += 1
symbol_name = f"{custom_autograd_function_cls.__name__}_{autograd_counter}"

custom_forward = custom_autograd_function_cls.forward
ctx = torch.autograd.function.FunctionCtx()
ctx_proxy = proxy(ctx, name=None, history=None)
@@ -623,26 +628,11 @@ def _general_jit_torch_autograd_function_apply_lookaside(obj: Any, *args, **kwargs):
# Forward.
unwrapped_custom_forward_args = tree_map(lambda a: unwrap(a), args)
unwrapped_custom_forward_result = unwrap(custom_forward_result)
symbol_name = custom_autograd_function_cls.__name__
custom_fwd_sym = Symbol(
name=symbol_name,
id=symbol_name,
meta=lambda *unwrapped_custom_forward_args: unwrapped_custom_forward_result,
# autograd.Function produces views of the tensor to attach the autograd node to
unwrapped_custom_forward_result = tree_map(
lambda x: prims.shallow_copy(x) if isinstance(x, TensorProxy) else x, unwrapped_custom_forward_result
)
bsym_of_custom_fwd = BoundSymbol(
custom_fwd_sym,
args=unwrapped_custom_forward_args,
kwargs={},
output=unwrapped_custom_forward_result,
subsymbols=custom_fwd_bsyms,
source_filename=jit_ctx.computation_trace._current_source_filename,
source_positions=None,
_call_ctx=custom_fwd_bsyms[0]._call_ctx,
_import_ctx=custom_fwd_bsyms[0]._import_ctx,
_object_ctx=custom_fwd_bsyms[0]._object_ctx,
_executor=custom_fwd_bsyms[0]._executor,
)
orig_scopes[-1].append(bsym_of_custom_fwd)
custom_fwd_bsyms: list[BoundSymbol] = jit_ctx.computation_trace.pop_scope()

# not augmented for when we don't need grad
trace_of_fwd = TraceCtx()
@@ -651,7 +641,7 @@ def _general_jit_torch_autograd_function_apply_lookaside(obj: Any, *args, **kwargs):
with tracectx(trace_of_fwd):
prims.python_return(unwrapped_custom_forward_result)

si = SigInfo(custom_fwd_sym.name)
si = SigInfo(symbol_name)
for a in unwrapped_custom_forward_args:
if isinstance(a, Proxy):
si.args.append((a.name, None))
@@ -665,6 +655,24 @@ def _general_jit_torch_autograd_function_apply_lookaside(obj: Any, *args, **kwargs):
def core_of_forward(*args, **kwargs):
return thunder.core.trace_interpreter.interpret_trace(trace_of_fwd, *args, **kwargs)

custom_fwd_meta = lambda *unwrapped_custom_forward_args: unwrapped_custom_forward_result

def bind_postprocess(bsym):
bsym._call_ctx = {}

custom_fwd_sym = Symbol(
name=symbol_name,
id=symbol_name,
meta=core_of_forward,
_bind_postprocess=bind_postprocess,
)

unwrapped_forward_result = custom_fwd_sym(*unwrapped_custom_forward_args)
forward_result = wrap(
unwrapped_forward_result,
provenance=ProvenanceRecord(PseudoInst.LOOKASIDE, inputs=[obj.provenance, custom_forward_result.provenance]),
)

thunder.executors.torchex._register_implementation(
custom_fwd_sym, core_of_forward, checker=thunder.executors.torchex._always_executable
)
@@ -676,7 +684,8 @@ def core_of_forward(*args, **kwargs):
trace_of_augmented_fwd = TraceCtx()
for bsym in custom_fwd_bsyms:
trace_of_augmented_fwd.add_bound_symbol(bsym)
trace_of_augmented_fwd.add_bound_symbol(prims.python_return.bind(augmented_bsym_output, output=()))
with tracectx(trace_of_augmented_fwd):
prims.python_return(augmented_bsym_output)
si = SigInfo(custom_fwd_sym.name)
for a in unwrapped_custom_forward_args:
if isinstance(a, Proxy):
@@ -691,8 +700,6 @@ def core_of_forward(*args, **kwargs):
def core_of_augmented_forward(*args, **kwargs):
return thunder.core.trace_interpreter.interpret_trace(trace_of_augmented_fwd, *args, **kwargs)

# core_of_augmented_forward = trace_of_augmented_fwd.python_callable(include_decorators=False)

@wraps(core_of_augmented_forward)
def augmented_custom_forward_rule(*args, **kwargs):
primal, residulas = core_of_augmented_forward(*args, **kwargs)
@@ -703,21 +710,12 @@ def augmented_custom_forward_rule(*args, **kwargs):

# Backward definition
custom_backward = custom_autograd_function_cls.backward
custom_bwd_bsyms: list[BoundSymbol] = []
jit_ctx.computation_trace.scopes = [custom_bwd_bsyms]

grads = tree_map(
lambda a: a.replace_name(f"grad_{a.name}"),
sequencify(unwrapped_custom_forward_result),
)
wrapped_grads = tree_map(lambda g: wrap(g, provenance=custom_forward_result.provenance), grads)
custom_backward_result = _interpret_call(custom_backward, wrapped_ctx, *wrapped_grads)
if custom_backward_result is INTERPRETER_SIGNALS.EXCEPTION_RAISED:
return custom_backward_result

trace_of_backward = TraceCtx()
for bsym in custom_bwd_bsyms:
trace_of_backward.add_bound_symbol(bsym)
trace_of_backward.add_bound_symbol(prims.python_return.bind(*unwrap(custom_backward_result), output=()))
bwd_si = SigInfo(f"{custom_fwd_sym.name}_backward")
for a in ctx_proxy.saved_tensors + grads:
if isinstance(a, Proxy):
@@ -728,6 +726,19 @@ def augmented_custom_forward_rule(*args, **kwargs):
trace_of_backward._siginfo = bwd_si
trace_of_backward.args = tuple(ctx_proxy.saved_tensors + grads)

jit_ctx.computation_trace.push_scope([])
wrapped_grads = tree_map(lambda g: wrap(g, provenance=custom_forward_result.provenance), grads)
custom_backward_result = _interpret_call(custom_backward, wrapped_ctx, *wrapped_grads)
if custom_backward_result is INTERPRETER_SIGNALS.EXCEPTION_RAISED:
return custom_backward_result

custom_bwd_bsyms: list[BoundSymbol] = jit_ctx.computation_trace.pop_scope()

for bsym in custom_bwd_bsyms:
trace_of_backward.add_bound_symbol(bsym)
with tracectx(trace_of_backward):
prims.python_return.bind(*unwrap(custom_backward_result), output=())

@wraps(trace_of_backward.python_callable())
def bwd_trace_callable_interface(*args, **kwargs):
return thunder.core.trace_interpreter.interpret_trace(trace_of_backward, *args, **kwargs)
@@ -757,10 +768,7 @@ def backward_impl(*args, **kwargs):
return bwd_impl_callable(*new_args)

backward_impls[custom_fwd_sym.name] = backward_impl

# Cosmetic
jit_ctx.computation_trace.scopes = orig_scopes
return custom_forward_result
return forward_result


@register_general_jit_lookaside(torch.finfo)
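A note on the scope-handling change in this file: rather than replacing jit_ctx.computation_trace.scopes wholesale and restoring it at the end (the deleted "# Cosmetic" lines), the lookaside now pushes a fresh scope, interprets the custom forward and backward, and pops the collected bound symbols, which is well defined even when the list comes back empty. A self-contained sketch of that stack discipline (TraceScopes is an illustrative stand-in, not Thunder's actual class):

class TraceScopes:
    """Minimal stand-in for a trace's scope stack."""

    def __init__(self):
        self.scopes = [[]]  # new symbols are appended to the innermost scope

    def push_scope(self, scope):
        self.scopes.append(scope)

    def pop_scope(self):
        return self.scopes.pop()

    def record(self, bsym):
        self.scopes[-1].append(bsym)

trace = TraceScopes()
trace.push_scope([])
# ...interpreting an "empty" forward records nothing here...
collected = trace.pop_scope()
assert collected == []  # fine now; the old custom_fwd_bsyms[0] access raised IndexError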
10 changes: 9 additions & 1 deletion thunder/core/prims.py
@@ -166,6 +166,7 @@ class PrimIDs(Enum):
TRANSPOSE = auto()
UNFOLD = auto()
VIEW = auto()
SHALLOW_COPY = auto() # a view copy
# Memory layout prims (Experimental)
STRIDE_ORDER = auto()
# Elementwise unary prims
@@ -1658,7 +1659,7 @@ def return_printer(
):
utils.check(
len(kwarg_printables) == 0,
lambda: f"Expected no kwargs for del but got {kwarg_printables}",
lambda: f"Expected no kwargs for return but got {kwarg_printables}",
exception_type=AssertionError,
)

@@ -3538,6 +3539,13 @@ def transpose_meta(a: TensorProxy, /, permutation: tuple[int, ...]) -> TensorProxy:
view = make_prim(PrimIDs.VIEW, "view", meta=reshape_meta, tags=(OpTags.SHAPE_OP,))


def shallow_copy_meta(a: TensorProxy, /) -> TensorProxy:
return TensorProxy(like=a)


shallow_copy = make_prim(PrimIDs.SHALLOW_COPY, "shallow_copy", meta=shallow_copy_meta, tags=(OpTags.SHAPE_OP,))


def unfold_meta(a: TensorProxy, /, dim: int, size: int, step: int) -> TensorProxy:
dim = utils.canonicalize_dim(a.ndim, dim)
max_size = 1 if a.ndim == 0 else a.shape[dim]
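The new shallow_copy prim backs the comment added in the jit_ext.py diff above: PyTorch's autograd.Function hands back a fresh view when an output aliases an input, so the backward node has a distinct tensor to attach to, and shallow_copy_meta's TensorProxy(like=a) is the trace-level analogue of that fresh view. A PyTorch-only illustration of the behavior being mirrored (no Thunder involved):

import torch

class Identity(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        return x

    @staticmethod
    def backward(ctx, g):
        return g

x = torch.randn(2, requires_grad=True)
y = Identity.apply(x)
print(y is x)     # False: autograd returned a view of x
print(y.grad_fn)  # the autograd node hangs off that view, not off x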
3 changes: 3 additions & 0 deletions thunder/executors/torchex.py
@@ -2169,3 +2169,6 @@ def _shape_impl(t):

shape = ex.register_operator("shape", meta=prims.shape_meta, fn=_shape_impl)
_register_implementation(prims.shape, shape, checker=_always_executable)

shallow_copy = ex.register_operator("shallow_copy", meta=prims.shallow_copy, fn=lambda x: x)
_register_implementation(prims.shallow_copy, shallow_copy, checker=_always_executable)
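The runtime side is deliberately trivial: the fresh proxy only matters while tracing, so the Torch executor can lower shallow_copy to fn=lambda x: x and return the input tensor unchanged. A quick check of what that identity lowering means (a sketch, not the executor machinery itself):

import torch

def shallow_copy_impl(x):
    return x  # matches fn=lambda x: x above; no data is copied

t = torch.randn(2)
assert shallow_copy_impl(t) is t  # same tensor object at runtime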
35 changes: 35 additions & 0 deletions thunder/tests/test_jit_general.py
@@ -1272,6 +1272,41 @@ def my_sin_with_wrong_backward(x):
gradcheck(jitted, (x,))


def test_autograd_function_empty_forward():

class Fn(torch.autograd.Function):
@staticmethod
def forward(self, x):
return x

@staticmethod
def backward(self, grad_x):
return 2 * grad_x

def fn(x):
# TODO: there still is a bug when the result is directly returned
return Fn.apply(x) * 3

a = torch.randn(2)
jfn = thunder.jit(fn)

ref = fn(a)
out = jfn(a)

assert_close(out, ref)

a = torch.randn(2, requires_grad=True)
go = torch.randn_like(a)

ref = fn(a)
out = jfn(a)
(grad,) = torch.autograd.grad(out, a, go)
(grad_ref,) = torch.autograd.grad(ref, a, go)

assert_close(out, ref)
assert_close(grad, grad_ref)


@requiresCUDA  # I have not found another suitable object to use
def test_cpp_property():
def fn():
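To exercise just the new test, an invocation along these lines should work (the exact form depends on the local checkout):

pytest thunder/tests/test_jit_general.py -k test_autograd_function_empty_forward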
1 change: 1 addition & 0 deletions thunder/torch/__init__.py
@@ -5812,6 +5812,7 @@ def check_overlap_ops():
tensor_split,
chunk,
getitem,
prims.shallow_copy,
}

# Add all auto-registered torch operators symbol that return tensor views to _syms_returning_views
