Skip to content

Commit

Permalink
Restrict torchcompile_cat to be applied only for CUDA inputs
Browse files Browse the repository at this point in the history
  • Loading branch information
IvanYashchuk committed Aug 15, 2024
1 parent 76f57f9 commit 54cc32f
Showing 1 changed file with 13 additions and 2 deletions.
15 changes: 13 additions & 2 deletions thunder/executors/torch_compile.py
Original file line number Diff line number Diff line change
@@ -7,11 +7,12 @@
from lightning_utilities import compare_version

from thunder.core import prims, utils
from thunder.core.proxies import Proxy, unvariableify, Variable
from thunder.core.proxies import Proxy, TensorProxy, unvariableify, Variable
from thunder.core.rematerialization import rematerialize
from thunder.core.symbol import BoundSymbol, Symbol
from thunder.core.trace import from_trace, TraceCtx, TraceProvenance
from thunder.core.transform_common import dce
from thunder.core.pytree import tree_flatten
from thunder.executors.passes import update_fusion_call_ctx
from thunder.executors.utils import Region
from thunder.extend import FusionExecutor, register_executor, ImplInfo
@@ -190,6 +191,16 @@ def _can_fuse_node(n: Node):
from thunder.executors.torchex import ex as pytorch_ex


def cuda_device_checker(*args, **kwargs):
    """Report whether every ``TensorProxy`` among the arguments is on a CUDA device.

    Positional and keyword arguments are flattened together with
    ``tree_flatten``; leaves that are not ``TensorProxy`` instances are ignored.

    Returns:
        ``False`` if any ``TensorProxy`` leaf has a ``device.type`` other than
        ``"cuda"``; ``True`` otherwise, including when the arguments contain
        no ``TensorProxy`` at all.
    """
    leaves, _ = tree_flatten((args, kwargs))
    tensor_leaves = (leaf for leaf in leaves if isinstance(leaf, TensorProxy))
    # We only want to compile when no tensor argument lives off the GPU.
    return not any(leaf.device.type != "cuda" for leaf in tensor_leaves)


# NOTE: [torch_compile_cat_ex vs torch_compile_ex]
# The former only relies on `torch.compile` for the operators where it shines the most and is meant to be used
# together with the nvfuser executor. Its current goal is only to fuse RoPE but the set of ops fused will change as each
@@ -219,7 +230,7 @@ def _can_fuse_node(n: Node):
prims.slice_prim.id,
prims.transpose.id,
}
torch_compile_cat_ex._implmap = {op: ImplInfo() for op in pytorch_ex.implmap if op in supported_ops}
# Build the implementation map from the PyTorch executor's ops that are in `supported_ops`.
# Each entry carries `cuda_device_checker` as its checker so that torchcompile_cat is
# applied only for CUDA inputs.
torch_compile_cat_ex._implmap = {op: ImplInfo(checker=cuda_device_checker) for op in pytorch_ex.implmap if op in supported_ops}


torch_compile_ex = TorchCompileExecutor(name="torchcompile")

0 comments on commit 54cc32f

Please sign in to comment.