Skip to content

Commit e65dae0

Browse files
authored
Preserve tags when swapping proxies (#1189)
1 parent b5cd708 commit e65dae0

File tree

1 file changed

+7
-0
lines changed

1 file changed

+7
-0
lines changed

thunder/core/jit_ext.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1577,6 +1577,11 @@ def _get_process_group_from(*fn_and_args) -> Optional["ProcessGroup"]:
15771577
return found_pg
15781578

15791579

1580+
def update_tags(proxy_swapmap: dict[Variable, Proxy]) -> None:
1581+
for old, new in proxy_swapmap.items():
1582+
new.tags.update(unvariableify(old).tags)
1583+
1584+
15801585
def thunder_general_jit(
15811586
fn: Callable,
15821587
args: tuple[Any, ...],
@@ -1682,6 +1687,8 @@ def restrict_proxy_swapmap(proxies: tuple[Proxy]) -> dict[Variable, Proxy]:
16821687
# Update prologue trace by renaming proxies which are passed from prologue to the computation trace
16831688
prologue_trace = _apply_trace_proxy_rename(prologue_trace, restrict_proxy_swapmap(pro_to_comp_proxies))
16841689

1690+
update_tags(ctx._proxy_swapmap)
1691+
16851692
# Update computation trace by renaming proxies which are in the ctx._proxy_swapmap
16861693
computation_trace = _apply_trace_proxy_rename(computation_trace, ctx._proxy_swapmap, "computation")
16871694

0 commit comments

Comments
 (0)