Skip to content

Commit

Permalink
create and check the list of tensor proxies
Browse files Browse the repository at this point in the history
Signed-off-by: Masaki Kozuki <mkozuki@nvidia.com>
  • Loading branch information
crcrpar committed Jun 24, 2024
1 parent 7230a04 commit ddd5d1a
Showing 1 changed file with 4 additions and 2 deletions.
6 changes: 4 additions & 2 deletions thunder/core/transform_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -415,10 +415,12 @@ def has_tag(bsym: BoundSymbol, tag: prims.OpTags) -> bool:
if in_tensor in trace_args_set:
continue
prod_bsym: BoundSymbol = producer_bsyms[in_tensor]
if not prod_bsym.flat_proxy_args:

flat_tensor_proxy_args = tuple(filter(lambda p: isinstance(p, TensorProxy), prod_bsym.flat_args))
if not flat_tensor_proxy_args:
# assuming `prod_bsym` is a tensor factory method such as `torch.empty`, `torch.zeros`, and `torch.ones`
continue
orig_tensor = prod_bsym.flat_proxy_args[0]
orig_tensor = flat_tensor_proxy_args[0]
consumer_of_orig_tensor = consumers[orig_tensor]
# When the orig tensor is not used by consumers other than `prod_bsym`, it'd be safe.
# Otherwise, we'd need to replace the use of ``orig_tensor`` with a view, unless the original
Expand Down

0 comments on commit ddd5d1a

Please sign in to comment.