Skip to content

Commit

Permalink
check num of proxy args
Browse files Browse the repository at this point in the history
there are some ops, mainly factory ops like `torch.empty`, that do not take any tensor
proxies as their args/kwargs.

Signed-off-by: Masaki Kozuki <mkozuki@nvidia.com>
  • Loading branch information
crcrpar committed Jun 24, 2024
1 parent 8c953b3 commit 7230a04
Showing 1 changed file with 3 additions and 0 deletions.
3 changes: 3 additions & 0 deletions thunder/core/transform_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -415,6 +415,9 @@ def has_tag(bsym: BoundSymbol, tag: prims.OpTags) -> bool:
if in_tensor in trace_args_set:
continue
prod_bsym: BoundSymbol = producer_bsyms[in_tensor]
if not prod_bsym.flat_proxy_args:
# assuming `prod_bsym` is a tensor factory method such as `torch.empty`, `torch.zeros`, and `torch.ones`
continue
orig_tensor = prod_bsym.flat_proxy_args[0]
consumer_of_orig_tensor = consumers[orig_tensor]
# When the orig tensor is not used by consumers other than `prod_bsym`, it'd be safe.
Expand Down

0 comments on commit 7230a04

Please sign in to comment.