diff --git a/thunder/core/jit_ext.py b/thunder/core/jit_ext.py index ce774c4936..2d99879b20 100644 --- a/thunder/core/jit_ext.py +++ b/thunder/core/jit_ext.py @@ -792,7 +792,22 @@ def _generate_random_str_id() -> str: # note that this key is quite new: https://github.com/pytorch/pytorch/pull/134087 # non_differentiable_idx = fwd_kwargs.get("non_differentiable_idx") length_of_tensor_args = sum(args_tensor_mask) - new_fwd_args = (wrap_const(None),) + fwd_args[:length_of_tensor_args] + + # N.B.(crcrpar) When `torch.compile(..., dynamic=True)`, + # GraphModules' forward seems to take `SymInt` and other values + # as its arguments with some probability. Though that piece of information unfortunately + # does not seem to be indicated in ``args_tensor_mask`` nor ``non_differentiable_idx``. + # Thus we optimistically iterate over ``fwd_args`` and gather non-tensor values into ``new_fwd_args``. + new_fwd_args = [] + for i, v in enumerate(fwd_args): + if i < length_of_tensor_args: + new_fwd_args.append(v) + else: + # note(crcrpar): we might want to include `FutureTensorProxy` and + # a proxy of tensor subclass in the near future. + if not isinstance(unwrap(v), TensorProxy): + new_fwd_args.append(v) + new_fwd_args = (wrap_const(None),) + tuple(new_fwd_args) aug_fwd_trace, aug_fwd_provenance = _convert_pytorchfunc_to_thundertrace(fwd, False, *new_fwd_args) if aug_fwd_trace is INTERPRETER_SIGNALS.EXCEPTION_RAISED: