diff --git a/thunder/executors/torch_compile.py b/thunder/executors/torch_compile.py
index 5ccce153b6..b000d4bf61 100644
--- a/thunder/executors/torch_compile.py
+++ b/thunder/executors/torch_compile.py
@@ -13,7 +13,11 @@
 from thunder.core.trace import from_trace, TraceCtx, TraceProvenance
 from thunder.core.transform_common import dce
 from thunder.core.pytree import tree_flatten
-from thunder.executors.passes import update_fusion_call_ctx
+from thunder.executors.passes import (
+    update_fusion_call_ctx,
+    _transform_for_operator_executor_execution,
+    transform_for_execution,
+)
 from thunder.executors.utils import Region
 from thunder.extend import FusionExecutor, register_executor, ImplInfo
 from thunder.core.compile_data import get_compile_option
@@ -55,36 +59,31 @@ def make_compiled(
     from thunder import trace
     from thunder.core.transforms import eval_trace
     from thunder.executors.torchex import no_autocast
+    from thunder.executors.torchex import ex as torchex
+    from thunder.executors.pythonex import ex as pythonex
+    from thunder.core.codeutils import SigInfo

     # Here we construct a trace that will be used to compile the function
+    # TODO: maybe we should have a utility that does this properly
     region_trace = TraceCtx(None)
     region_trace.bound_symbols = list(bsyms)
     region_trace.args = sorted_unique_inputs
     region_trace.kwargs = {}
     region_trace.bound_symbols.append(prims.python_return.bind(sorted_unique_outputs, output=()))
+    for a in region_trace.args:
+        region_trace.add_name(a.name)
+    for bsym in region_trace.bound_symbols:
+        for o in bsym.flat_outs:
+            if o is not None:  # TODO: investigate
+                region_trace.add_name(o.name)
+
+    # maybe make this the default if no sig info is present?
+    region_trace._siginfo = SigInfo("to_be_compiled")
+    region_trace._siginfo.args = [(a.name, None) for a in region_trace.args]
+
+    torchex_trace = transform_for_execution(region_trace, executors_list=(torchex,))
+    trace_callable = torchex_trace.python_callable(include_decorators=False)

-    def torch_interpreted_func(*args):
-        return eval_trace(region_trace, *args, symbol_mapper=to_torch_translator)
-
-    # Here instead of using thunder.trace we could use torch_trace =
-    # passes._transform_for_operator_executor_execution(region_trace, [torchex])
-    # but then we would need to handle unpacking of the args explicitly For
-    # example with:
-    # try:
-    #     token = set_tracectx(region_trace)
-    #     col = CollectionProxy(region_trace.args, name="args")
-    #     _ = prims.unpack_sequence(col, len(region_trace.args))
-    # finally:
-    #     reset_tracectx(token)
-    # region_trace.bound_symbols.extend(bsyms)
-    # But there are some issues with the
-    # _transform_for_operator_executor_execution implementation that need to be
-    # fixed first. One issue is that it doesn't maintain the ssa form of the
-    # trace, which is needed for all the passes to work correctly.
-    # TODO: issue "Try using _transform_for_operator_executor_execution for
-    # torch.compile executor"
-    torch_trace = trace(inline_trace=False)(torch_interpreted_func, *sorted_unique_inputs)
-    trace_callable = torch_trace.python_callable(include_decorators=False)
     torch_compile_fullgraph: None | bool = get_compile_option(
         "torch_compile_fullgraph", "Whether to enable `fullgraph` from `torch.compile`. Defaults to `True`."
     )
diff --git a/thunder/tests/test_torch_compile_executor.py b/thunder/tests/test_torch_compile_executor.py
index 48e37f8a89..3f7b331387 100644
--- a/thunder/tests/test_torch_compile_executor.py
+++ b/thunder/tests/test_torch_compile_executor.py
@@ -9,6 +9,7 @@
 from thunder.tests.bf16 import device_supports_bf16
 from thunder.tests.litgpt_model import GPT, Config
 from thunder.tests.framework import requiresCUDA
+from torch.testing import assert_close


 def test_supported_ops_are_in_pytorch_executor():
@@ -71,3 +72,13 @@ def test_torch_compile_cat_rope_single_fusion():
     backward_execution_trace = thunder.last_backward_traces(jfn)[-1]
     assert len(get_fusions(backward_execution_trace)) == 1
     assert len(backward_execution_trace.bound_symbols) == 14
+
+
+@pytest.mark.skipif(not is_inductor_supported(), reason="inductor unsupported")
+def test_transform_for_execution_for_callable():
+    def fn(a):
+        return a.type("torch.DoubleTensor")
+
+    a = torch.randn(3)
+    jfn = thunder.jit(fn, executors=(thunder.executors.torch_compile.torch_compile_ex,))
+    assert_close(jfn(a), fn(a))
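Usage sketch (not part of the patch): the new test above exercises the changed make_compiled path end to end, and a minimal standalone reproduction would look like the snippet below. This assumes thunder and a PyTorch build with inductor support are installed; the function body and executor name mirror the test.

import torch
import thunder
from thunder.executors.torch_compile import torch_compile_ex

def fn(a):
    # Tensor.type is the op covered by the new test; with this patch the region
    # containing it is lowered via transform_for_execution (see make_compiled above)
    # before being compiled.
    return a.type("torch.DoubleTensor")

a = torch.randn(3)
jfn = thunder.jit(fn, executors=(torch_compile_ex,))
torch.testing.assert_close(jfn(a), fn(a))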