Skip to content

Commit 14e6c9b

Browse files
authored
Check number of proxy args before indexing (#642)
1 parent 6fb7e28 commit 14e6c9b

File tree

1 file changed

+6
-1
lines changed

1 file changed

+6
-1
lines changed

thunder/core/transform_common.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -415,7 +415,12 @@ def has_tag(bsym: BoundSymbol, tag: prims.OpTags) -> bool:
415415
if in_tensor in trace_args_set:
416416
continue
417417
prod_bsym: BoundSymbol = producer_bsyms[in_tensor]
418-
orig_tensor = prod_bsym.flat_proxy_args[0]
418+
419+
flat_tensor_proxy_args = tuple(filter(lambda p: isinstance(p, TensorProxy), prod_bsym.flat_args))
420+
if not flat_tensor_proxy_args:
421+
# assuming `prod_bsym` is a tensor factory method such as `torch.empty`, `torch.zeros`, and `torch.ones`
422+
continue
423+
orig_tensor = flat_tensor_proxy_args[0]
419424
consumer_of_orig_tensor = consumers[orig_tensor]
420425
# When the orig tensor is not used by consumers other than `prod_bsym`, it'd be safe.
421426
# Otherwise, we'd need to replace the use of ``orig_tensor`` with a view, unless the original

0 commit comments

Comments
 (0)