diff --git a/thunder/core/jit_ext.py b/thunder/core/jit_ext.py index f8ca3adba1..4b7935b3c7 100644 --- a/thunder/core/jit_ext.py +++ b/thunder/core/jit_ext.py @@ -1691,6 +1691,16 @@ def thunder_general_jit( bind_inputs("computation", computation_trace, pro_to_comp_proxies) if epilogue_trace: + l = len(epilogue_trace.bound_symbols) + if l == 0: + epilogue_trace = None + elif l == 1: + (r,) = epilogue_trace.bound_symbols + assert r.sym == prims.python_return + epilogue_trace = None + + if epilogue_trace: + print(epilogue_trace) bind_inputs("epilogue", epilogue_trace, pro_to_epi_proxies + comp_to_epi_proxies) # Returns a new swapmap dictionary which has the keys (ctx._proxy_swapmap.key() & variableify(proxies))