Skip to content

Commit

Permalink
remove redundant copy bsym check
Browse files Browse the repository at this point in the history
Signed-off-by: Masaki Kozuki <mkozuki@nvidia.com>
  • Loading branch information
crcrpar committed Jun 16, 2024
1 parent 2f3f467 commit 205c29e
Showing 1 changed file with 6 additions and 9 deletions.
15 changes: 6 additions & 9 deletions thunder/core/transform_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -436,17 +436,14 @@ def is_inplace(bsym: BoundSymbol) -> bool:
hasattr(thunder.torch, functional_sym_name),
lambda: f"{functional_sym_name}, out-of-place impl of {bsym.sym.id=} not found in `thunder.torch` namespace",
)
sub_bsyms: list[BoundSymbol] = bsym.subsymbols
check(sub_bsyms, lambda: f"{bsym.sym.id=} expected to have subsymbols but {bsym.subsymbols=}")
copy_bsym = sub_bsyms[-1]
check(
copy_bsym.sym.id == prims.PrimIDs.COPY_,
lambda: f"bsym.subsymbols[-1] expected to be {prims.PrimIDs.COPY_} but {copy_bsym.sym.id=}",
)
if (copy_to := copy_bsym.flat_proxy_args[1]) in trace_args_set:
copy_bsym = bsym.subsymbols[-1]
copy_return = copy_bsym.flat_proxy_outs[0]
copy_from = copy_bsym.flat_proxy_args[0]
copy_to = copy_bsym.flat_proxy_args[1]
if copy_to in trace_args_set:
new_bsyms.append(new_bsym)
else:
swap_map[variableify(copy_bsym.flat_proxy_outs[0])] = copy_bsym.flat_proxy_args[0]
swap_map[variableify(copy_return)] = copy_from
new_bsym.subsymbols = new_bsym.subsymbols[:-1]
new_bsym = new_bsym.from_bsym_swap_proxies(swap_map)
functional_sym: Symbol = getattr(thunder.torch, functional_sym_name)
Expand Down

0 comments on commit 205c29e

Please sign in to comment.