Skip to content

Commit

Permalink
Iterate over fwd_args for hopefully more precise new_fwd_args
Browse files Browse the repository at this point in the history
in the lookaside of `torch.ops.higher_order.autograd_function_apply`
that could take `torch.SymInt` and/or `torch.SymFloat` as its arguments
passed to GraphModule's forward/backward.

Signed-off-by: Masaki Kozuki <mkozuki@nvidia.com>
  • Loading branch information
crcrpar committed Dec 18, 2024
1 parent d201c8c commit 651366f
Showing 1 changed file with 17 additions and 1 deletion.
18 changes: 17 additions & 1 deletion thunder/core/jit_ext.py
Original file line number Diff line number Diff line change
Expand Up @@ -792,7 +792,23 @@ def _generate_random_str_id() -> str:
# note that this key is quite new: https://github.com/pytorch/pytorch/pull/134087
# non_differentiable_idx = fwd_kwargs.get("non_differentiable_idx")
length_of_tensor_args = sum(args_tensor_mask)
new_fwd_args = (wrap_const(None),) + fwd_args[:length_of_tensor_args]

# N.B.(crcrpar) When `torch.compile(..., dynamic=True)`,
# GraphModules' forward seem to take `SymInt` and other values
# as its argument with some probability. Though that piece of information unfortunately
# does not seem to be indicated in ``args_tensor_`` nor ``non_differentiable_idx``.
# Thus we optimistically iterate over ``fwd_args`` and gather non-tensor values to ``fwd_args``.
new_fwd_args = []
for i, v in enumerate(fwd_args):
if i < length_of_tensor_args:
print(unwrap(v))
new_fwd_args.append(v)
else:
# note(crcrpar): we might want to include `FutureTensorProxy` and
# a proxy of tensor subclass in the near future.
if not isinstance(unwrap(v), TensorProxy):
new_fwd_args.append(v)
new_fwd_args = (wrap_const(None),) + tuple(new_fwd_args)

aug_fwd_trace, aug_fwd_provenance = _convert_pytorchfunc_to_thundertrace(fwd, False, *new_fwd_args)
if aug_fwd_trace is INTERPRETER_SIGNALS.EXCEPTION_RAISED:
Expand Down

0 comments on commit 651366f

Please sign in to comment.