Skip to content

use transform_for_execution to get callable for torch compile #1041

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Aug 26, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
45 changes: 22 additions & 23 deletions thunder/executors/torch_compile.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,11 @@
from thunder.core.trace import from_trace, TraceCtx, TraceProvenance
from thunder.core.transform_common import dce
from thunder.core.pytree import tree_flatten
from thunder.executors.passes import update_fusion_call_ctx
from thunder.executors.passes import (
update_fusion_call_ctx,
_transform_for_operator_executor_execution,
transform_for_execution,
)
from thunder.executors.utils import Region
from thunder.extend import FusionExecutor, register_executor, ImplInfo
from thunder.core.compile_data import get_compile_option
Expand Down Expand Up @@ -55,36 +59,31 @@ def make_compiled(
from thunder import trace
from thunder.core.transforms import eval_trace
from thunder.executors.torchex import no_autocast
from thunder.executors.torchex import ex as torchex
from thunder.executors.pythonex import ex as pythonex
from thunder.core.codeutils import SigInfo

# Here we construct a trace that will be used to compile the function
# TODO: maybe we should have a utility that does this properly
region_trace = TraceCtx(None)
region_trace.bound_symbols = list(bsyms)
region_trace.args = sorted_unique_inputs
region_trace.kwargs = {}
region_trace.bound_symbols.append(prims.python_return.bind(sorted_unique_outputs, output=()))
for a in region_trace.args:
region_trace.add_name(a.name)
for bsym in region_trace.bound_symbols:
for o in bsym.flat_outs:
if o is not None: # TODO: investigate
region_trace.add_name(o.name)

# maybe make this the default if no sig info is present?
region_trace._siginfo = SigInfo("to_be_compiled")
region_trace._siginfo.args = [(a.name, None) for a in region_trace.args]

torchex_trace = transform_for_execution(region_trace, executors_list=(torchex,))
trace_callable = torchex_trace.python_callable(include_decorators=False)

def torch_interpreted_func(*args):
return eval_trace(region_trace, *args, symbol_mapper=to_torch_translator)

# Here instead of using thunder.trace we could use torch_trace =
# passes._transform_for_operator_executor_execution(region_trace, [torchex])
# but then we would need to handle unpacking of the args explicitly. For
# example with:
# try:
# token = set_tracectx(region_trace)
# col = CollectionProxy(region_trace.args, name="args")
# _ = prims.unpack_sequence(col, len(region_trace.args))
# finally:
# reset_tracectx(token)
# region_trace.bound_symbols.extend(bsyms)
# But there are some issues with the
# _transform_for_operator_executor_execution implementation that need to be
# fixed first. One issue is that it doesn't maintain the ssa form of the
# trace, which is needed for all the passes to work correctly.
# TODO: issue "Try using _transform_for_operator_executor_execution for
# torch.compile executor"
torch_trace = trace(inline_trace=False)(torch_interpreted_func, *sorted_unique_inputs)
trace_callable = torch_trace.python_callable(include_decorators=False)
torch_compile_fullgraph: None | bool = get_compile_option(
"torch_compile_fullgraph", "Whether to enable `fullgraph` from `torch.compile`. Defaults to `True`."
)
Expand Down
11 changes: 11 additions & 0 deletions thunder/tests/test_torch_compile_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from thunder.tests.bf16 import device_supports_bf16
from thunder.tests.litgpt_model import GPT, Config
from thunder.tests.framework import requiresCUDA
from torch.testing import assert_close


def test_supported_ops_are_in_pytorch_executor():
Expand Down Expand Up @@ -71,3 +72,13 @@ def test_torch_compile_cat_rope_single_fusion():
backward_execution_trace = thunder.last_backward_traces(jfn)[-1]
assert len(get_fusions(backward_execution_trace)) == 1
assert len(backward_execution_trace.bound_symbols) == 14


@pytest.mark.skipif(not is_inductor_supported(), reason="inductor unsupported")
def test_transform_for_execution_for_callable():
    """The jitted function must produce the same result as eager execution
    when compiled with the torch.compile executor."""

    # Tensor.type with a dtype string exercises an op that the executor's
    # transform_for_execution-built callable has to handle.
    def fn(a):
        return a.type("torch.DoubleTensor")

    executor = thunder.executors.torch_compile.torch_compile_ex
    jitted_fn = thunder.jit(fn, executors=(executor,))

    x = torch.randn(3)
    assert_close(jitted_fn(x), fn(x))
Loading