Commit 069bc57
keep torchsymbol
Signed-off-by: Masaki Kozuki <mkozuki@nvidia.com>
crcrpar committed Nov 14, 2024
1 parent f0cf8ae commit 069bc57
Showing 3 changed files with 65 additions and 77 deletions.
96 changes: 20 additions & 76 deletions thunder/core/jit_ext.py
@@ -772,10 +772,9 @@ def grad_transform(*args, **kwargs):
# ref: https://github.com/pytorch/pytorch/blob/38114ec/torch/_functorch/autograd_function.py#L715-L752
@register_general_jit_lookaside(torch.ops.higher_order.autograd_function_apply)
def _general_jit_torch_ops_higher_order_autograd_function_apply(fwd, bwd, *fwd_args, **fwd_kwargs):
from thunder.core import utils
from thunder.core.baseutils import sequencify
from thunder.core.pytree import tree_flatten, tree_map
from thunder.core.transforms import VJPDual, augmented_forward_impls, backward_impls
from thunder.core.pytree import tree_map
from thunder.core.trace_interpreter import interpret_trace

def _generate_random_str_id() -> str:
import secrets
@@ -784,8 +783,6 @@ def _generate_random_str_id() -> str:
length = 5
return "".join(secrets.choice(string.ascii_lowercase) for _ in range(length))

jit_ctx: JitCtx = get_jit_ctx()

args_tensor_mask = unwrap(fwd_kwargs["args_tensor_mask"])
# TODO(crcrpar): Think about making use of `non_differentiable_idx`
# note that this key is quite new: https://github.com/pytorch/pytorch/pull/134087
@@ -798,51 +795,11 @@ def _generate_random_str_id() -> str:
return aug_fwd_trace
aug_fwd_result = aug_fwd_trace.output
output, saved_values = unwrap(aug_fwd_result)
wrapped_output = wrap(output, provenance=aug_fwd_provenance)

unwrapped_fwd_args = tree_map(lambda t: unwrap(t), new_fwd_args)[1:]

fwd_bsyms: list[BoundSymbol] = aug_fwd_trace.bound_symbols
producer_map = utils.producers(fwd_bsyms)
tensor_to_prod_bsym: dict[Variable, BoundSymbol] = {}
for p in tree_flatten((output, saved_values))[0]:
if not isinstance(p, TensorProxy):
continue
if p in producer_map:
prod_bsym = producer_map[p]
tensor_to_prod_bsym[variableify(p)] = prod_bsym
prod_bsym_to_tensor = {v: unvariableify(k) for k, v in tensor_to_prod_bsym.items()}

sym_id = f"autograd_function_apply_{_generate_random_str_id()}"
vanilla_fwd_trace = TraceCtx()
vanilla_fwd_trace.args = unwrapped_fwd_args
unpack_bsyms = [
prims.unpack_trivial.bind(a, name=a.name, output=a)
for a in filter(lambda a: isinstance(a, Proxy), vanilla_fwd_trace.args)
]
for bsym in unpack_bsyms + fwd_bsyms[:-1]:
vanilla_fwd_trace.add_bound_symbol(bsym)
vanilla_fwd_trace.add_bound_symbol(prims.python_return.bind(output, output=()))
vanilla_fwd_trace._siginfo = SigInfo.from_name_and_args(sym_id, vanilla_fwd_trace.args)

@wraps(vanilla_fwd_trace.python_callable())
def core_of_fwd(*args, **kwargs):
return thunder.core.trace_interpreter.interpret_trace(vanilla_fwd_trace, *args, **kwargs)

sym = jit_ctx.ad_hoc_executor.register_operator(
vanilla_fwd_trace._siginfo.name,
like=core_of_fwd,
)
unwrapped_forward_result = sym(*unwrapped_fwd_args)
unwrapped_fwd_args = tree_map(lambda t: unwrap(t), new_fwd_args)

# Define augmented fwd rule and backward rule on the fly.
augmented_fwd_trace = TraceCtx()
augmented_fwd_trace.args = vanilla_fwd_trace.args
for bsym in unpack_bsyms + fwd_bsyms[:-1]:
augmented_fwd_trace.add_bound_symbol(bsym)
augmented_fwd_trace.add_bound_symbol(prims.python_return.bind(output, saved_values, output=()))
si = SigInfo.from_name_and_args(f"augmented_autograd_function_apply_{sym_id}", augmented_fwd_trace.args)
augmented_fwd_trace._siginfo = si
tmp_name = _generate_random_str_id()
aug_fwd_trace.args = unwrapped_fwd_args
aug_fwd_trace._siginfo = SigInfo.from_name_and_args(tmp_name, aug_fwd_trace.args)

grads = sequencify(tree_map(lambda t: TensorProxy(like=t), sequencify(output)))
bwd_args = (wrap_const(None),)
@@ -855,42 +812,29 @@ def core_of_fwd(*args, **kwargs):
)
if bwd_trace is INTERPRETER_SIGNALS.EXCEPTION_RAISED:
return bwd_trace
bwd_trace.args = bwd_tensor_args
bwd_trace.args = (None,) + bwd_tensor_args
bwd_unpack_bsyms = [
prims.unpack_trivial.bind(a, name=a.name, output=a)
for a in filter(lambda a: isinstance(a, Proxy), bwd_trace.args)
]
bwd_trace.bound_symbols = bwd_unpack_bsyms + bwd_trace.bound_symbols
bwd_trace._siginfo = SigInfo.from_name_and_args(f"bwd_{sym_id}", saved_values + grads)

@wraps(bwd_trace.python_callable())
def bwd_impl_callable(*args, **kwargs):
return thunder.core.trace_interpreter.interpret_trace(bwd_trace, *args, **kwargs)
bwd_trace._siginfo = SigInfo.from_name_and_args(f"bwd_{tmp_name}", saved_values + grads)

@wraps(core_of_fwd)
def grad_transform(*args, **kwargs):
from thunder.core.transforms import get_grad, put_grads
@wraps(aug_fwd_trace.python_callable())
def augmented_forward_caller(*args, **kwargs):
return interpret_trace(aug_fwd_trace, *args, **kwargs)

primal, residuals = thunder.core.trace_interpreter.interpret_trace(
augmented_fwd_trace,
*args,
**kwargs,
)
grads = tree_map(lambda t: get_grad(t), sequencify(primal))
bwd_args = grads + residuals
result = bwd_impl_callable(*bwd_args)
put_grads(args, result)
return primal
@wraps(bwd_trace.python_callable())
def backward_caller(*args, **kwargs):
return interpret_trace(bwd_trace, *args, **kwargs)

jit_ctx.ad_hoc_executor.register_implementation(
sym,
execution_transform=core_of_fwd,
grad_transform=grad_transform,
)
from thunder.torch import autograd_function_apply

return wrap(
unwrapped_forward_result,
provenance=ProvenanceRecord(PseudoInst.LOOKASIDE, inputs=[fwd.provenance, bwd.provenance]),
return interpreter_needs_wrap(autograd_function_apply)(
wrap_const(augmented_forward_caller),
wrap_const(backward_caller),
*fwd_args,
**fwd_kwargs,
)


6 changes: 5 additions & 1 deletion thunder/tests/test_jit_general.py
@@ -1236,7 +1236,11 @@ def my_sin(x):
bsym_str_ids = tuple(
bsym.sym.id for bsym in initial_computation_trace.bound_symbols if isinstance(bsym.sym.id, str)
)
assert any(bsid.startswith("autograd_function_apply") for bsid in bsym_str_ids), bsym_str_ids
assert any(
bsym.sym.id == "torch.ops.higher_order.autograd_function_apply"
for bsym in initial_computation_trace.bound_symbols
if isinstance(bsym.sym.id, str)
)

grad = torch.rand_like(y)
actual_grad = torch.autograd.grad(y, x, grad)
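
For reference, here is a minimal, hypothetical sketch of the calling convention this test exercises. The forward/backward/my_sin definitions below are illustrative stand-ins, not the test's actual code, and they assume a PyTorch recent enough to accept the non_differentiable_idx keyword (older versions simply ignore it):

import torch

def forward(ctx, x):
    # The higher-order op passes None for ctx; return (output, saved_values).
    return torch.sin(x), [x]

def backward(ctx, grad_output, x):
    # Invoked as bwd(None, *grad_outputs, *saved_values).
    return grad_output * torch.cos(x)

def my_sin(x):
    return torch.ops.higher_order.autograd_function_apply(
        forward, backward, x, args_tensor_mask=[True], non_differentiable_idx=[]
    )

x = torch.randn(4, requires_grad=True)
y = my_sin(x)
y.sum().backward()  # x.grad == cos(x), computed by the user-supplied backward

Under thunder.jit, the reworked lookaside records such a call as a single torch.ops.higher_order.autograd_function_apply bound symbol, which is what the new assertion checks.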
40 changes: 40 additions & 0 deletions thunder/torch/__init__.py
@@ -5583,6 +5583,46 @@ def wait(slf) -> None:
utils.check(False, lambda: "torch.distributed is not available")


# ref: https://github.com/pytorch/pytorch/blob/b99ef1a/torch/_functorch/autograd_function.py#L715-L752
@torchsymbol(
torch.ops.higher_order.autograd_function_apply,
id="torch.ops.higher_order.autograd_function_apply",
is_method=False,
)
def autograd_function_apply(
fwd: Callable[list[TensorProxy], TensorProxy | tuple[TensorProxy, ...]],
bwd: Callable[list[TensorProxy], TensorProxy | tuple[TensorProxy, ...]],
*args: Any,
args_tensor_mask: Sequence[bool] | None,
non_differentiable_idx: Sequence[int] | None = None,
) -> TensorProxy | tuple[TensorProxy, ...]:
result, saved_for_backward = fwd(None, *args)
return result


@register_augmented_forward("torch.ops.higher_order.autograd_function_apply")
def augmented_forward_autograd_function_apply(
fwd: Callable[list[Any | TensorProxy], TensorProxy | tuple[TensorProxy, ...]],
bwd: Callable[list[Any | TensorProxy], tuple[TensorProxy, ...]],
*args: Any,
args_tensor_mask: Sequence[bool],
non_differentiable_idx: Sequence[int] | None = None,
) -> tuple[TensorProxy | tuple[TensorProxy, ...], tuple[Any, ...]]:
result, saved_for_backward = fwd(None, *args)
return result, (saved_for_backward, bwd, args_tensor_mask, non_differentiable_idx)


@register_backward("torch.ops.higher_order.autograd_function_apply")
def backward_autograd_function_apply(
saved_for_backward: tuple[Any, ...],
bwd: Callable[list[Any | TensorProxy], tuple[TensorProxy, ...]],
args_tensor_mask: Sequence[bool],
non_differentiable_idx: Sequence[int] | None = None,
*grad_output: Sequence[TensorProxy],
) -> tuple[Any, ...]:
return bwd(None, *grad_output, *saved_for_backward)


@torchsymbol(
torch.amp.autocast_mode._enter_autocast,
id="torch.amp.autocast_mode._enter_autocast",
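
The three autograd_function_apply registrations added in this file follow the contract of the higher-order op that Dynamo emits for torch.autograd.Function.apply calls (see the pytorch reference above): fwd receives a None placeholder for ctx and returns (output, saved_values), and bwd receives None, then the output gradients, then the saved values; the augmented-forward rule stashes (saved_for_backward, bwd, args_tensor_mask, non_differentiable_idx) as residuals, and the backward rule simply replays bwd on them. Roughly, the sketch in the test section corresponds to this ordinary autograd.Function (illustrative only; the class below is not part of the diff):

import torch

class MySin(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        ctx.save_for_backward(x)
        return torch.sin(x)

    @staticmethod
    def backward(ctx, grad_output):
        (x,) = ctx.saved_tensors
        # Same math as bwd(None, grad_output, x) above: d/dx sin(x) = cos(x).
        return grad_output * torch.cos(x)

x = torch.randn(4, requires_grad=True)
MySin.apply(x).sum().backward()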
