Skip to content

Commit f12e20f

Browse files
committed
pass results through epilogue if present
1 parent 2303b30 commit f12e20f

File tree

2 files changed

+26
-18
lines changed

2 files changed

+26
-18
lines changed

thunder/__init__.py

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -560,6 +560,7 @@ def get_computation_and_inputs(*args, **kwargs):
560560
cs.last_prologue_execution_stop = time.time_ns()
561561

562562
cs.last_traces = computation_traces
563+
cs.last_epilogue_traces = epilogue_traces
563564
backward_traces = []
564565
cs.last_backward_traces = backward_traces
565566
cs.last_interpreter_log = last_interpreter_log
@@ -652,6 +653,7 @@ def get_computation_and_inputs(*args, **kwargs):
652653

653654
cs.last_traces += extraces
654655
cs.last_prologue_traces = [prologue_trc] + prologue_traces
656+
cs.last_epilogue_traces = epilogue_traces
655657
cs.last_prologue = pro
656658

657659
return cache_entry, inps, pro_to_epi
@@ -689,8 +691,7 @@ def fn_(*args, **kwargs) -> Any:
689691
result = data_for_autograd["output"]
690692

691693
if cache_entry.epilogue_fn:
692-
result, comp_to_epi = result
693-
cache_entry.epilogue_fn(*pro_to_epi, *comp_to_epi)
694+
result = cache_entry.epilogue_fn(*pro_to_epi, *result)
694695

695696
cs.last_trace_host_execution_stop = time.time_ns()
696697
cs.last_computation_execution_stop = cs.last_trace_host_execution_stop
@@ -804,6 +805,18 @@ def last_prologue_traces(fn) -> TraceCtx:
804805
return cs.last_prologue_traces
805806

806807

808+
def last_epilogue_traces(fn) -> TraceCtx:
809+
"""Obtains the list of prologue traces that have been produced for the last run of the function and the selected prologue."""
810+
cs = compile_stats(fn)
811+
if cs is None:
812+
raise TypeError(f"{fn} doesn't seem to be a thunder compiled function.")
813+
if (
814+
cs.last_prologue_traces is None
815+
): # this is prologue on purpose because we might legitimately not have an epilogue trace
816+
raise TypeError(f"{fn} doesn't seem to have been called yet.")
817+
return cs.last_epilogue_traces
818+
819+
807820
def cache_option(fn) -> CACHE_OPTIONS:
808821
"""Returns the cache options set when JITting the function."""
809822
cd = compile_data(fn)

thunder/core/jit_ext.py

Lines changed: 11 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1565,14 +1565,15 @@ def process_recorded_modifications(ctx, epilogue_trace):
15651565
name = ".".join(name + [k])
15661566
with tracectx(epilogue_trace):
15671567
bsym = prims.pack_buffer.bind(root_module_proxy, name, value.value, output=None)
1568-
epilogue_trace.bound_symbols.append(bsym)
1568+
assert epilogue_trace.bound_symbols[-1].sym == prims.python_return
1569+
epilogue_trace.bound_symbols.insert(-1, bsym)
15691570
else:
15701571
raise NotImplementedError(f"Modifications {inst} on dicts are not supported")
15711572
else:
15721573
raise NotImplementedError(f"Modifications of {type(uvalue).__name__} objects are not supported")
15731574

15741575

1575-
def bind_inputs(name, trace, input_vars, input_proxies):
1576+
def bind_inputs(name, trace, input_proxies):
15761577
# Unpacks inputs into the computation trace
15771578
# TODO This currently does the unpacks at the end of the trace, then moves them to the beginning, there's
15781579
# almost certainly a more elegant way to do this
@@ -1585,7 +1586,7 @@ def bind_inputs(name, trace, input_vars, input_proxies):
15851586
trace.bound_symbols = bsyms[-len(input_proxies) :] + bsyms[: -len(input_proxies)]
15861587

15871588
si = SigInfo(name)
1588-
si.args = [(v.proxy.name, None) for v in input_vars]
1589+
si.args = [(p.name, None) for v in input_proxies]
15891590
trace._siginfo = si
15901591
trace.args = input_proxies
15911592

@@ -1655,7 +1656,9 @@ def thunder_general_jit(
16551656
with general_jit_ctx(ctx):
16561657
with tracectx(computation_trace):
16571658
result = jfn(*args, **kwargs)
1659+
with tracectx(epilogue_trace):
16581660
prims.python_return(result)
1661+
with tracectx(computation_trace):
16591662
computation_trace.set_current_source_location(None, None)
16601663
process_recorded_modifications(ctx, epilogue_trace)
16611664
last_interpreter_log = jfn._last_interpreter_log
@@ -1674,29 +1677,21 @@ def thunder_general_jit(
16741677
comp_to_epi.append(i)
16751678
else:
16761679
pro_to_epi.append(i)
1677-
comp_to_epi = tuple(comp_to_epi)
1680+
16781681
comp_to_epi_proxies = tuple(v.proxy for v in comp_to_epi)
16791682
pro_to_epi = tuple(pro_to_epi)
16801683

1681-
if epilogue_trace.bound_symbols:
1682-
with tracectx(computation_trace):
1683-
last = computation_trace.bound_symbols.pop(-1)
1684-
assert last.sym.id == prims.PrimIDs.RETURN
1685-
prims.python_return((result, comp_to_epi_proxies))
1686-
1687-
with tracectx(epilogue_trace):
1688-
prims.python_return(None)
1689-
else:
1690-
epilogue_trace = None
1684+
with tracectx(computation_trace):
1685+
prims.python_return(comp_to_epi_proxies)
16911686

16921687
pro_to_comp_proxies, pro_to_epi_proxies = unpack_inputs(ctx, prologue_trace, pro_to_comp, pro_to_epi, args, kwargs)
16931688

16941689
proxy_order = {id(p): i for i, p in enumerate(pro_to_comp_proxies)}
16951690
pro_to_comp = tuple(sorted(pro_to_comp, key=lambda v: proxy_order[id(v.proxy)]))
16961691

1697-
bind_inputs("computation", computation_trace, pro_to_comp, pro_to_comp_proxies)
1692+
bind_inputs("computation", computation_trace, pro_to_comp_proxies)
16981693
if epilogue_trace:
1699-
bind_inputs("epilogue", epilogue_trace, pro_to_epi + comp_to_epi, pro_to_epi_proxies + comp_to_epi_proxies)
1694+
bind_inputs("epilogue", epilogue_trace, pro_to_epi_proxies + comp_to_epi_proxies)
17001695

17011696
# Returns a new swapmap dictionary which has the keys (ctx._proxy_swapmap.key() & variableify(proxies))
17021697
def restrict_proxy_swapmap(proxies: tuple[Proxy]) -> dict[Variable, Proxy]:

0 commit comments

Comments
 (0)