Commit

cosmetic
Signed-off-by: Masaki Kozuki <mkozuki@nvidia.com>
crcrpar committed Nov 14, 2024
1 parent 069bc57 commit 7729af1
Showing 1 changed file with 6 additions and 7 deletions.
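
In short: the backward trace's arguments are now assembled once, as bwd_args = (None,) + grads + tuple(saved_values), wrapped in a single tree_map pass, and reused for the call into _convert_pytorchfunc_to_thundertrace, for bwd_trace.args, and for the trace's signature, where the old code wrapped None via wrap_const and the tensor arguments through a separate path. The autograd_function_apply import also moves to the top of the enclosing function.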
13 changes: 6 additions & 7 deletions thunder/core/jit_ext.py
@@ -775,6 +775,7 @@ def _general_jit_torch_ops_higher_order_autograd_function_apply(fwd, bwd, *fwd_a
     from thunder.core.baseutils import sequencify
     from thunder.core.pytree import tree_map
     from thunder.core.trace_interpreter import interpret_trace
+    from thunder.torch import autograd_function_apply
 
     def _generate_random_str_id() -> str:
         import secrets
@@ -802,23 +803,23 @@ def _generate_random_str_id() -> str:
     aug_fwd_trace._siginfo = SigInfo.from_name_and_args(tmp_name, aug_fwd_trace.args)
 
     grads = sequencify(tree_map(lambda t: TensorProxy(like=t), sequencify(output)))
-    bwd_args = (wrap_const(None),)
     bwd_tensor_args = grads + tuple(saved_values)
-    wrapped_bwd_tensor_args = tree_map(lambda t: wrap(t, provenance=aug_fwd_provenance), bwd_tensor_args)
+    bwd_args = (None,) + bwd_tensor_args
+    wrapped_bwd_args = tree_map(lambda t: wrap(t, provenance=aug_fwd_provenance), bwd_args)
     bwd_trace, bwd_trace_provenance = _convert_pytorchfunc_to_thundertrace(
         bwd,
         False,
-        *(bwd_args + wrapped_bwd_tensor_args),
+        *wrapped_bwd_args,
     )
     if bwd_trace is INTERPRETER_SIGNALS.EXCEPTION_RAISED:
         return bwd_trace
-    bwd_trace.args = (None,) + bwd_tensor_args
+    bwd_trace.args = bwd_args
     bwd_unpack_bsyms = [
         prims.unpack_trivial.bind(a, name=a.name, output=a)
         for a in filter(lambda a: isinstance(a, Proxy), bwd_trace.args)
     ]
     bwd_trace.bound_symbols = bwd_unpack_bsyms + bwd_trace.bound_symbols
-    bwd_trace._siginfo = SigInfo.from_name_and_args(f"bwd_{tmp_name}", saved_values + grads)
+    bwd_trace._siginfo = SigInfo.from_name_and_args(f"bwd_{tmp_name}", bwd_trace.args)
 
     @wraps(aug_fwd_trace.python_callable())
     def augmented_forward_caller(*args, **kwargs):
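
The substance of this hunk: instead of wrapping None ahead of time with wrap_const and the tensor arguments separately, the new code builds a single bwd_args tuple and derives everything from it — the wrapped call arguments, bwd_trace.args, and the trace's signature. A minimal sketch of the pattern in plain Python, with toy stand-ins (wrap and tree_map below are hypothetical placeholders for thunder's provenance-tracking wrapper and pytree map, and the values are invented):

```python
# Toy stand-ins -- NOT thunder's real wrap()/tree_map(); illustration only.
def wrap(x):
    return ("wrapped", x)

def tree_map(fn, tup):
    # thunder's tree_map walks arbitrary pytrees; a flat tuple suffices here
    return tuple(fn(x) for x in tup)

grads = ("grad_out",)           # hypothetical placeholder values
saved_values = ("x", "weight")

# Single source of truth for the backward's arguments:
bwd_tensor_args = grads + tuple(saved_values)
bwd_args = (None,) + bwd_tensor_args       # leading ctx slot, then tensors
wrapped_bwd_args = tree_map(wrap, bwd_args)

# The same tuple now feeds the trace call, bwd_trace.args, and the siginfo,
# so the three can no longer drift apart (the old code wrapped None and the
# tensor args through two different code paths).
assert wrapped_bwd_args[0] == wrap(None)
assert wrapped_bwd_args[1:] == tree_map(wrap, bwd_tensor_args)
```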
@@ -828,8 +829,6 @@ def augmented_forward_caller(*args, **kwargs):
     def backward_caller(*args, **kwargs):
         return interpret_trace(bwd_trace, *args, **kwargs)
 
-    from thunder.torch import autograd_function_apply
-
     return interpreter_needs_wrap(autograd_function_apply)(
         wrap_const(augmented_forward_caller),
         wrap_const(backward_caller),
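
For context, autograd_function_apply pairs an augmented forward that returns (output, saved_values) with a backward that — per the bwd_args layout above — is invoked as bwd(None, *grads, *saved_values). A minimal sketch of that calling convention with ordinary PyTorch tensors (the function bodies are invented; only the argument ordering is taken from the diff):

```python
import torch

def fwd(ctx, x):
    # augmented forward: return the output plus the values backward needs
    out = torch.sin(x)
    saved_values = (x,)
    return out, saved_values

def bwd(ctx, grad_out, x):
    # invoked as bwd(None, *grads, *saved_values): grads first, then saved
    return grad_out * torch.cos(x)

x = torch.randn(3)
out, saved_values = fwd(None, x)
grad_in = bwd(None, torch.ones_like(out), *saved_values)  # ctx slot is None
```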
