@@ -1565,14 +1565,15 @@ def process_recorded_modifications(ctx, epilogue_trace):
1565
1565
name = "." .join (name + [k ])
1566
1566
with tracectx (epilogue_trace ):
1567
1567
bsym = prims .pack_buffer .bind (root_module_proxy , name , value .value , output = None )
1568
- epilogue_trace .bound_symbols .append (bsym )
1568
+ assert epilogue_trace .bound_symbols [- 1 ].sym == prims .python_return
1569
+ epilogue_trace .bound_symbols .insert (- 1 , bsym )
1569
1570
else :
1570
1571
raise NotImplementedError (f"Modifications { inst } on dicts are not supported" )
1571
1572
else :
1572
1573
raise NotImplementedError (f"Modifications of { type (uvalue ).__name__ } objects are not supported" )
1573
1574
1574
1575
1575
- def bind_inputs (name , trace , input_vars , input_proxies ):
1576
+ def bind_inputs (name , trace , input_proxies ):
1576
1577
# Unpacks inputs into the computation trace
1577
1578
# TODO This currently does the unpacks at the end of the trace, then moves them to the beginning, there's
1578
1579
# almost certainly a more elegant way to do this
@@ -1585,7 +1586,7 @@ def bind_inputs(name, trace, input_vars, input_proxies):
1585
1586
trace .bound_symbols = bsyms [- len (input_proxies ) :] + bsyms [: - len (input_proxies )]
1586
1587
1587
1588
si = SigInfo (name )
1588
- si .args = [(v . proxy . name , None ) for v in input_vars ]
1589
+ si .args = [(p . name , None ) for v in input_proxies ]
1589
1590
trace ._siginfo = si
1590
1591
trace .args = input_proxies
1591
1592
@@ -1655,7 +1656,9 @@ def thunder_general_jit(
1655
1656
with general_jit_ctx (ctx ):
1656
1657
with tracectx (computation_trace ):
1657
1658
result = jfn (* args , ** kwargs )
1659
+ with tracectx (epilogue_trace ):
1658
1660
prims .python_return (result )
1661
+ with tracectx (computation_trace ):
1659
1662
computation_trace .set_current_source_location (None , None )
1660
1663
process_recorded_modifications (ctx , epilogue_trace )
1661
1664
last_interpreter_log = jfn ._last_interpreter_log
@@ -1674,29 +1677,21 @@ def thunder_general_jit(
1674
1677
comp_to_epi .append (i )
1675
1678
else :
1676
1679
pro_to_epi .append (i )
1677
- comp_to_epi = tuple ( comp_to_epi )
1680
+
1678
1681
comp_to_epi_proxies = tuple (v .proxy for v in comp_to_epi )
1679
1682
pro_to_epi = tuple (pro_to_epi )
1680
1683
1681
- if epilogue_trace .bound_symbols :
1682
- with tracectx (computation_trace ):
1683
- last = computation_trace .bound_symbols .pop (- 1 )
1684
- assert last .sym .id == prims .PrimIDs .RETURN
1685
- prims .python_return ((result , comp_to_epi_proxies ))
1686
-
1687
- with tracectx (epilogue_trace ):
1688
- prims .python_return (None )
1689
- else :
1690
- epilogue_trace = None
1684
+ with tracectx (computation_trace ):
1685
+ prims .python_return (comp_to_epi_proxies )
1691
1686
1692
1687
pro_to_comp_proxies , pro_to_epi_proxies = unpack_inputs (ctx , prologue_trace , pro_to_comp , pro_to_epi , args , kwargs )
1693
1688
1694
1689
proxy_order = {id (p ): i for i , p in enumerate (pro_to_comp_proxies )}
1695
1690
pro_to_comp = tuple (sorted (pro_to_comp , key = lambda v : proxy_order [id (v .proxy )]))
1696
1691
1697
- bind_inputs ("computation" , computation_trace , pro_to_comp , pro_to_comp_proxies )
1692
+ bind_inputs ("computation" , computation_trace , pro_to_comp_proxies )
1698
1693
if epilogue_trace :
1699
- bind_inputs ("epilogue" , epilogue_trace , pro_to_epi + comp_to_epi , pro_to_epi_proxies + comp_to_epi_proxies )
1694
+ bind_inputs ("epilogue" , epilogue_trace , pro_to_epi_proxies + comp_to_epi_proxies )
1700
1695
1701
1696
# Returns a new swapmap dictionary which has the keys (ctx._proxy_swapmap.key() & variableify(proxies))
1702
1697
def restrict_proxy_swapmap (proxies : tuple [Proxy ]) -> dict [Variable , Proxy ]:
0 commit comments