File tree Expand file tree Collapse file tree 1 file changed +7
-0
lines changed Expand file tree Collapse file tree 1 file changed +7
-0
lines changed Original file line number Diff line number Diff line change @@ -1577,6 +1577,11 @@ def _get_process_group_from(*fn_and_args) -> Optional["ProcessGroup"]:
1577
1577
return found_pg
1578
1578
1579
1579
1580
+ def update_tags (proxy_swapmap : dict [Variable , Proxy ]) -> None :
1581
+ for old , new in proxy_swapmap .items ():
1582
+ new .tags .update (unvariableify (old ).tags )
1583
+
1584
+
1580
1585
def thunder_general_jit (
1581
1586
fn : Callable ,
1582
1587
args : tuple [Any , ...],
@@ -1682,6 +1687,8 @@ def restrict_proxy_swapmap(proxies: tuple[Proxy]) -> dict[Variable, Proxy]:
1682
1687
# Update prologue trace by renaming proxies which are passed from prologue to the computation trace
1683
1688
prologue_trace = _apply_trace_proxy_rename (prologue_trace , restrict_proxy_swapmap (pro_to_comp_proxies ))
1684
1689
1690
+ update_tags (ctx ._proxy_swapmap )
1691
+
1685
1692
# Update computation trace by renaming proxies which are in the ctx._proxy_swapmap
1686
1693
computation_trace = _apply_trace_proxy_rename (computation_trace , ctx ._proxy_swapmap , "computation" )
1687
1694
You can’t perform that action at this time.
0 commit comments