diff --git a/thunder/core/jit_ext.py b/thunder/core/jit_ext.py index ce774c4936..9b22e15b48 100644 --- a/thunder/core/jit_ext.py +++ b/thunder/core/jit_ext.py @@ -792,7 +792,20 @@ def _generate_random_str_id() -> str: # note that this key is quite new: https://github.com/pytorch/pytorch/pull/134087 # non_differentiable_idx = fwd_kwargs.get("non_differentiable_idx") length_of_tensor_args = sum(args_tensor_mask) - new_fwd_args = (wrap_const(None),) + fwd_args[:length_of_tensor_args] + + # note(crcrpar) When `torch.compile(..., dynamic=True)`, + # GraphModules' forward seem to take `SymInt` and other values + # as its argument with some probability. Though that piece of information unfortunately + # does not seem to be indicated in ``args_tensor_`` nor ``non_differentiable_idx``. + # Thus we optimistically iterate over ``fwd_args`` and gather non-tensor values to ``fwd_args``. + new_fwd_args = [] + for i, v in enumerate(fwd_args): + if i < length_of_tensor_args: + new_fwd_args.append(v) + else: + if not isinstance(unwrap(v), (torch.Tensor, TensorProxy)): + new_fwd_args.append(v) + new_fwd_args = (wrap_const(None),) + tuple(new_fwd_args) aug_fwd_trace, aug_fwd_provenance = _convert_pytorchfunc_to_thundertrace(fwd, False, *new_fwd_args) if aug_fwd_trace is INTERPRETER_SIGNALS.EXCEPTION_RAISED: