From 7230a047d97d185230aa467d178765a2a26a90cf Mon Sep 17 00:00:00 2001 From: Masaki Kozuki Date: Tue, 25 Jun 2024 00:52:31 +0900 Subject: [PATCH] check num of proxy args there are some ops, mainly factory ops like `torch.empty`, that do not take any tensor proxies as their args/kwargs. Signed-off-by: Masaki Kozuki --- thunder/core/transform_common.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/thunder/core/transform_common.py b/thunder/core/transform_common.py index 122dfee68a..3926a938ff 100644 --- a/thunder/core/transform_common.py +++ b/thunder/core/transform_common.py @@ -415,6 +415,9 @@ def has_tag(bsym: BoundSymbol, tag: prims.OpTags) -> bool: if in_tensor in trace_args_set: continue prod_bsym: BoundSymbol = producer_bsyms[in_tensor] + if not prod_bsym.flat_proxy_args: + # assuming `prod_bsym` is a tensor factory method such as `torch.empty`, `torch.zeros`, and `torch.ones` + continue orig_tensor = prod_bsym.flat_proxy_args[0] consumer_of_orig_tensor = consumers[orig_tensor] # When the orig tensor is not used by consumers other than `prod_bsym`, it'd be safe.