diff --git a/thunder/core/transform_common.py b/thunder/core/transform_common.py index 3926a938ff..89a7353583 100644 --- a/thunder/core/transform_common.py +++ b/thunder/core/transform_common.py @@ -415,10 +415,12 @@ def has_tag(bsym: BoundSymbol, tag: prims.OpTags) -> bool: if in_tensor in trace_args_set: continue prod_bsym: BoundSymbol = producer_bsyms[in_tensor] - if not prod_bsym.flat_proxy_args: + + flat_tensor_proxy_args = tuple(filter(lambda p: isinstance(p, TensorProxy), prod_bsym.flat_args)) + if not flat_tensor_proxy_args: # assuming `prod_bsym` is a tensor factory method such as `torch.empty`, `torch.zeros`, and `torch.ones` continue - orig_tensor = prod_bsym.flat_proxy_args[0] + orig_tensor = flat_tensor_proxy_args[0] consumer_of_orig_tensor = consumers[orig_tensor] # When the orig tensor is not used by consumers other than `prod_bsym`, it'd be safe. # Otherwise, we'd need to replace the use of ``orig_tensor`` with a view, unless the original