diff --git a/thunder/dynamo/compiler.py b/thunder/dynamo/compiler.py index e67216c0eb..914f413376 100644 --- a/thunder/dynamo/compiler.py +++ b/thunder/dynamo/compiler.py @@ -1,8 +1,17 @@ +from typing import List, Dict, Optional, Tuple, Set +from collections.abc import Callable +from functools import partial + import torch +from torch.fx.passes.split_module import split_module import warnings +from collections.abc import Mapping from thunder.core.baseutils import run_once +from thunder.dynamo.utils import SubgraphInfo +from thunder.dynamo.splitter import _splitter + @run_once def _warn_thunder_compiler(): @@ -20,7 +29,9 @@ def __init__(self, **thunder_options): function. Keyword arguments: - thunder_options: a dictionary of options to pass to `thunder.jit`. + thunder_options: a dictionary of options to pass to `thunder.jit`. Besides all the arguments to `thunder.jit`, + it accepts `torch_inductor_options` which are passed to `torch.compile` if part of the graph + is not supported by thunder. Example: >>> import torch @@ -36,38 +47,30 @@ def __init__(self, **thunder_options): ... return x - 1 >>> out = func(x) """ - from thunder import ThunderModule + from thunder import ThunderModule, jit _warn_thunder_compiler() # Thunder-compiled functions should be readily available for inspection - # and testing, so we will store them in a list. The order of the + # and testing, so we will store them in a list[SubgraphInfo]. The order of the # functions in the list will be the same as the order in which they were - # compiled. In addition, we will store a mapping from the ThunderModule - # to the GraphModule that was passed to ThunderCompiler. This will allow - # us to inspect the GraphModule that was compiled by Thunder. - self.thunder_fns: list[ThunderModule] = [] - self.thunder_to_gm: dict[ThunderModule, torch.fx.GraphModule] = {} + # compiled. + # Ref to the documentation of `SubgraphInfo` to know more about the information it contains. + self.subgraph_infos: list[SubgraphInfo] = [] - self.thunder_options = thunder_options + torch_inductor_options = thunder_options.pop("torch_inductor_options", {}) - # TODO: There will be pieces of Dynamo IR that Thunder cannot compile, so we - # will need to build a fallback mechanism to handle those cases. - # Possible stages of the compilation that need to be saved for inspection: - # 1. The GraphModule as it was passed to ThunderCompiler. - # 2. The GraphModule after split for Thunder/PyTorch. - # 3. If the whole GraphModule is not supported, record the reasons why. + self.thunder_options = thunder_options + self._thunder_jit = partial(jit, **thunder_options) + self._torch_compile = partial(torch.compile, **torch_inductor_options) def __call__(self, gm: torch.fx.GraphModule, sample_args: list[torch.SymInt, torch.Tensor]): - from thunder import jit - # Dynamo uses lazy generation of the underlying Python code, so we need to # force recompilation of the GraphModule before passing it to Thunder. 
         gm.real_recompile()
 
-        # Here in the future we could add some logic to check if the GraphModule
-        # is executable by Thunder, but for now we simply compile it and return
-        jitted_gm = jit(gm, **self.thunder_options)
-        self.thunder_fns.append(jitted_gm)
-        self.thunder_to_gm[jitted_gm] = gm
-        return jitted_gm
+        # The whole graph may not be supported by `thunder`, so we split it into `thunder`-supported sections
+        # and unsupported sections, which are passed to `torch.compile(backend='inductor')`
+        split_module, subgraph_info = _splitter(gm, self._thunder_jit, self._torch_compile, sample_args)
+        self.subgraph_infos.append(subgraph_info)
+        return split_module
diff --git a/thunder/dynamo/splitter.py b/thunder/dynamo/splitter.py
new file mode 100644
index 0000000000..613ce50acd
--- /dev/null
+++ b/thunder/dynamo/splitter.py
@@ -0,0 +1,173 @@
+from typing import List, Dict, Optional, Tuple, Set
+from collections.abc import Callable
+from functools import partial
+
+import torch
+from torch.fx.passes.split_module import split_module
+import warnings
+from collections.abc import Mapping
+
+from thunder.core.baseutils import run_once
+
+from thunder.dynamo.utils import (
+    SubgraphInfo,
+    CompiledFunction,
+    CompilerType,
+    SplitReason,
+    SplitReasonType,
+    is_node_supported_by_thunder,
+    get_nodes_in_unsupported_ctx_regions,
+    update_node_and_submodule,
+)
+
+
+def _splitter(
+    gm: torch.fx.GraphModule,
+    thunder_jit: Callable,
+    torch_inductor: Callable,
+    _unused_sample_args: list[torch.SymInt, torch.Tensor],
+) -> tuple[torch.fx.GraphModule, SubgraphInfo]:
+    """
+    This function splits the graph into multiple graph modules based on the operations supported by thunder.
+    It tries to split the graph into contiguous partitions.
+
+    Example:
+        # All operations are supported by thunder
+        class GraphModule(torch.nn.Module):
+            def forward(self, L_x_: "f32[2]"):
+                l_x_ = L_x_
+
+                y: "f32[2]" = torch.sin(l_x_)
+                matmul: "f32[]" = torch.matmul(l_x_, y); l_x_ = y = None
+                return (matmul,)
+
+        # Split Graph: All operations are supported by thunder, so we will see only one partition.
+        class GraphModule(torch.nn.Module):
+            def forward(self, l_x_: "f32[2]"):
+                thunder_1 = self.thunder_1(l_x_); l_x_ = None
+                return (thunder_1,)
+
+        class thunder_1(torch.nn.Module):
+            def forward(self, l_x_: "f32[2]"):
+                y: "f32[2]" = torch.sin(l_x_)
+                matmul: "f32[]" = torch.matmul(l_x_, y); l_x_ = y = None
+                return matmul
+
+    Example:
+        # With unsupported operation `sinc`
+        class GraphModule(torch.nn.Module):
+            def forward(self, L_x_: "f32[2]"):
+                l_x_ = L_x_
+
+                y: "f32[2]" = torch.sinc(l_x_)
+
+                matmul: "f32[]" = torch.matmul(l_x_, y); l_x_ = y = None
+                return (matmul,)
+
+        # Split Graph: Since `sinc` is unsupported, we will see two partitions, one for thunder and one for inductor.
+        class GraphModule(torch.nn.Module):
+            def forward(self, l_x_: "f32[2]"):
+                inductor_1 = self.inductor_1(l_x_)
+                thunder_2 = self.thunder_2(l_x_, inductor_1); l_x_ = inductor_1 = None
+                return (thunder_2,)
+
+        class inductor_1(torch.nn.Module):  # Partition for inductor
+            def forward(self, l_x_: "f32[2]"):
+                y: "f32[2]" = torch.sinc(l_x_); l_x_ = None
+                return y
+
+        class thunder_2(torch.nn.Module):  # Partition for thunder
+            def forward(self, l_x_: "f32[2]", y: "f32[2]"):
+                matmul: "f32[]" = torch.matmul(l_x_, y); l_x_ = y = None
+                return matmul
+    """
+    # The callback below is called for every node in the graph.
+    # It returns an `int` denoting the partition where the node should be placed.
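+    # For example (illustrative): a graph with the operations [sin (supported), sinc (unsupported), matmul (supported)]
+    # gets the partitions [1, 2, 3]; partitions 1 and 3 are compiled with thunder and partition 2 with inductor.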
+    # We want to partition the graph into contiguous regions (each with one or more operations)
+    # that are either supported or unsupported by thunder.
+    # `prev_value` is used to determine if we are still in the same region (i.e. a supported or unsupported region).
+    # `partition_cnt` is bumped every time we change the region, i.e. we flip from supported to unsupported or from unsupported to supported.
+    # `supported_partitions` is used to track the thunder-supported partitions.
+    prev_value = None
+    partition_cnt = 0
+    supported_partitions: set[int] = set()
+    split_reasons: list[SplitReason] = []
+
+    nodes_in_unsupported_ctx_regions = get_nodes_in_unsupported_ctx_regions(gm)
+
+    def callback(node) -> int:
+        nonlocal prev_value, partition_cnt, split_reasons, supported_partitions
+
+        assert node.op not in (
+            "placeholder",
+            "get_attr",
+            "output",
+        ), f"fx.split_module should have only passed node.op=call_* but received {node.op}"
+
+        if node in nodes_in_unsupported_ctx_regions:
+            # If the node is in an unsupported ctx region like `autocast`, we pass it to `torch.compile`
+            # even though the operation itself may be supported, as `thunder` doesn't correctly work with these regions.
+            is_thunder_supported = False
+            split_reason = SplitReason(
+                SplitReasonType.UNSUPPORTED_NODE,
+                info=f"node with name: {node.name} and target: {node.target} is not supported, probably because it is in an unsupported context.",
+            )
+            split_reasons.append(split_reason)
+        else:
+            is_thunder_supported, split_reason = is_node_supported_by_thunder(node)
+            if split_reason is not None:
+                split_reasons.append(split_reason)
+
+        if prev_value == is_thunder_supported:  # We are in the same region.
+            return partition_cnt
+
+        # There is a flip, either from supported to unsupported or from unsupported to supported.
+        prev_value = is_thunder_supported
+        partition_cnt += 1  # Bump the region count.
+
+        if is_thunder_supported:
+            supported_partitions.add(partition_cnt)
+        return partition_cnt
+
+    # `split_module` iterates over the nodes and determines the partition to place them in based on the callback.
+    split_gm: torch.fx.GraphModule = split_module(
+        gm, root_m=None, split_callback=callback, keep_original_order=True, keep_original_node_name=True
+    )
+
+    def is_thunder_supported_partition(node: torch.fx.Node) -> bool:
+        return node.name.startswith("submod") and int(node.name.replace("submod_", "")) in supported_partitions
+
+    # Compile the split region(s).
+    thunder_compiled_fns = []
+    submodule_to_compiled_fns = {}
+    is_split = False
+    for node in split_gm.graph.nodes:
+        if is_thunder_supported_partition(node):
+            graph_module = getattr(split_gm, node.name)
+            jit_fn = thunder_jit(graph_module)
+            # Update the node name from "submod_*" to "thunder_*" for more user-friendly names
+            update_node_and_submodule(split_gm, node, node.name.replace("submod", "thunder"), jit_fn)
+            thunder_compiled_fns.append(jit_fn)
+            submodule_to_compiled_fns[graph_module] = CompiledFunction(jit_fn, CompilerType.THUNDER)
+        elif node.name.startswith("submod"):  # For inductor
+            graph_module = getattr(split_gm, node.name)
+            jit_fn = torch_inductor(graph_module)
+            # Update the node name from "submod_*" to "inductor_*" for more user-friendly names
+            update_node_and_submodule(split_gm, node, node.name.replace("submod", "inductor"), jit_fn)
+            submodule_to_compiled_fns[graph_module] = CompiledFunction(jit_fn, CompilerType.TORCH_INDUCTOR)
+            is_split = True
+        else:
+            # Everything else is glue code that calls and passes outputs between the other partitions.
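+            # (e.g. the placeholder and output nodes that `split_module` keeps at the top level)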
+            pass
+
+    # We update the GraphModule in `update_node_and_submodule`, so we need to recompile.
+    split_gm.recompile()
+
+    return split_gm, SubgraphInfo(
+        gm,
+        split_gm,
+        thunder_compiled_fns,
+        submodule_to_compiled_fns,
+        split_reasons,
+    )
diff --git a/thunder/dynamo/utils.py b/thunder/dynamo/utils.py
new file mode 100644
index 0000000000..9c92fcf054
--- /dev/null
+++ b/thunder/dynamo/utils.py
@@ -0,0 +1,301 @@
+from enum import Enum, auto
+import dataclasses
+from typing import List, Dict, Optional, Tuple, Set
+from collections.abc import Callable
+import pprint
+import itertools
+import copy
+import inspect
+
+import torch
+from torch.fx.passes.split_module import split_module
+import warnings
+from collections.abc import Mapping
+
+from thunder.torch.default_torch_ops import torch_auto_registered_ops
+from thunder.torch import _torch_to_thunder_function_map
+
+
+auto_register_ops = set(itertools.chain(*torch_auto_registered_ops.values()))
+
+
+class CompilerType(Enum):
+    """
+    An enumeration representing different types of compilers.
+    """
+
+    THUNDER = auto()
+    TORCH_INDUCTOR = auto()
+
+
+@dataclasses.dataclass
+class CompiledFunction:
+    """
+    A dataclass representing a compiled function along with its original graph module and compiler type.
+
+    Attributes:
+        compiled_fn (Callable): The compiled function.
+        compiler (CompilerType): The type of compiler used to compile the function.
+    """
+
+    compiled_fn: Callable
+    compiler: CompilerType
+
+
+class SplitReasonType(Enum):
+    """
+    An enumeration representing different reasons for splits in the graph.
+    """
+
+    UNSUPPORTED_NODE = auto()
+    MISSING_OP_SUPPORT = auto()
+    EXCEPTION_PROXY_THUNDER_OP = auto()
+    EXCEPTION_META_THUNDER_OP = auto()
+
+
+@dataclasses.dataclass
+class SplitReason:
+    """
+    A dataclass containing information about a split.
+
+    Attributes:
+        type (SplitReasonType): Reason for the split.
+        info (str): String with details of what caused the split.
+        exception (Exception | None): Exception if there was any.
+    """
+
+    type: SplitReasonType
+    info: str | None
+    exception: Exception | None = None
+
+
+@dataclasses.dataclass
+class SubgraphInfo:
+    """
+    A dataclass containing information about a subgraph.
+
+    Attributes:
+        original_graph_module (torch.fx.GraphModule): The original graph module.
+        split_graph_module (torch.fx.GraphModule): The graph module for the split subgraph.
+        thunder_compiled_fns (list[Callable]): List of thunder-optimized callables. This could be empty if the graph module was not supported by thunder. Look at `split_reasons` for further information.
+        submodule_to_compiled_functions (Mapping[torch.fx.GraphModule, CompiledFunction]): Map from a split submodule to its compiled function. This will contain a single entry if the graph was not split.
+        split_reasons (list[SplitReason] | None): Optional list of reasons explaining why the subgraph was split. Present only if there was a split.
+    """
+
+    original_graph_module: torch.fx.GraphModule
+    split_graph_module: torch.fx.GraphModule
+    thunder_compiled_fns: list[Callable]
+    submodule_to_compiled_functions: Mapping[torch.fx.GraphModule, CompiledFunction]
+    split_reasons: list | None = None
+
+
+def _concrete_shape(x):
+    """
+    Get the concrete shape for a FakeTensor if it has `torch.SymInt` in its shape.
+    """
+
+    def get_backed_value(s):
+        if isinstance(s, torch.SymInt):
+            return s.node.hint
+        # Value is already concrete.
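+        # (e.g. a plain Python `int` for a dimension that is already static)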
+        return s
+
+    return tuple(map(get_backed_value, x.shape))
+
+
+def try_execute_thunder_symbol(thunder_symbol: "Symbol", node: torch.fx.Node) -> tuple[bool, SplitReason | None]:
+    """
+    Attempts to execute a given Thunder symbol within a tracing context, using proxies for the node's arguments.
+
+    This function operates within a Thunder tracing context to generate proxies for the provided node's arguments.
+    It then attempts to execute the Thunder symbol with these proxies. If any exceptions occur during proxy creation
+    or execution, it returns a tuple indicating failure and provides a `SplitReason` detailing the exception.
+
+    Args:
+        thunder_symbol (Symbol): The Thunder symbol to be executed. This is expected to be a callable that can
+            operate on proxy arguments.
+        node (torch.fx.Node): The Torch FX node whose arguments are to be proxied and passed to the Thunder symbol.
+
+    Returns:
+        tuple[bool, SplitReason | None]: A tuple where the first element is a boolean indicating whether the execution passed or failed.
+            The second element is a `SplitReason` object if an error occurred, or `None` if the execution was successful.
+    """
+    import thunder
+    from thunder.core.trace import TraceCtx
+    from thunder.core.proxies import proxy
+
+    @thunder._with_cache_info_ctx
+    def _run_with_cache_info():
+
+        # We need the cache info here as the default dtype and device support
+        # for factory functions like ones, zeros, etc. expects these details to be present.
+        # TODO: Move this to CompileData as well?
+        # These details are in cache_info because `jit_ext.py`
+        # adds checks in the prologue for the details which are present here.
+        cache_info = thunder._get_cache_info()
+        cache_info["default_dtype"] = torch.get_default_dtype()
+        cache_info["default_device"] = torch.get_default_device()
+
+        trc = TraceCtx()
+        # We need to be under a trace context to generate proxies.
+        with thunder.core.trace.tracectx(trc):
+            try:
+
+                def make_tensor_proxy(arg_node):
+                    # This is a Node in the graph representing a Tensor or a tuple of Tensors.
+                    if isinstance(arg_node, torch.fx.Node):
+                        example_value = arg_node.meta["example_value"]
+
+                        if isinstance(example_value, torch.Tensor):
+                            # If `dynamic` shapes are enabled, we may see a FakeTensor
+                            # whose shape has SymInt. In that case, we check if we can
+                            # get the concrete value from the SymInt.
+                            # Here, we only want to verify that thunder can run the operation.
+                            # So, it is ok to verify with a concrete value.
+                            example_value = example_value.new_ones(
+                                _concrete_shape(example_value), device=example_value.device, dtype=example_value.dtype
+                            )
+                        elif isinstance(example_value, tuple):
+                            example_value = tuple(
+                                e_v.new_ones(_concrete_shape(e_v), device=e_v.device, dtype=e_v.dtype)
+                                for e_v in example_value
+                            )
+                        else:
+                            # NOTE - This will be caught and become part of the SplitReason.
+                            raise TypeError(
+                                "`make_tensor_proxy` received an example_value which wasn't a Tensor or a tuple"
+                            )
+                        return proxy(example_value)
+
+                    # This is an int, float, etc.
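+                    # Such non-Node values (Python scalars, dtypes, ...) are passed to the thunder symbol unchanged.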
+                    return arg_node
+
+                proxy_args = tuple(map(make_tensor_proxy, node.args))
+                proxy_kwargs = {k: make_tensor_proxy(v) for k, v in node.kwargs.items()}
+            except Exception as e:
+                return False, SplitReason(
+                    SplitReasonType.EXCEPTION_PROXY_THUNDER_OP,
+                    f"Failed while creating proxy for node with name: {node.name} and target: {node.target}, see exception field",
+                    exception=e,
+                )
+
+        try:
+            thunder_symbol(*proxy_args, **proxy_kwargs)
+        except Exception as e:
+            return False, SplitReason(
+                SplitReasonType.EXCEPTION_META_THUNDER_OP,
+                f"Failed while running meta for node with name: {node.name} and target: {node.target}, see exception field",
+                exception=e,
+            )
+
+        # Execution with proxies was successful.
+        return True, None
+
+    return _run_with_cache_info()
+
+
+def get_nodes_in_unsupported_ctx_regions(gm) -> set[torch.fx.Node]:
+    """
+    Finds the nodes within `autocast` or other unsupported context regions and marks them as unsupported.
+    Even though thunder may support the operation within the region, it doesn't correctly apply the change
+    triggered by the context.
+    """
+    # NOTE - Currently only detects the autocast regions.
+
+    nodes_in_unsupported_ctx_regions: set[torch.fx.Node] = set()
+    ctx_cnt = 0  # Count of `_enter_autocast` calls we have seen so far
+
+    # We want to mark the nodes between `_enter_autocast` and `_exit_autocast`
+    # as unsupported, as `thunder` doesn't correctly deal with these stateful functions.
+    for node in gm.graph.nodes:
+        if node.op == "call_function" and node.target in (torch.amp.autocast_mode._enter_autocast,):
+            ctx_cnt += 1
+        elif node.op == "call_function" and node.target in (torch.amp.autocast_mode._exit_autocast,):
+            ctx_cnt -= 1
+        else:
+            if ctx_cnt > 0:
+                nodes_in_unsupported_ctx_regions.add(node)
+
+    return nodes_in_unsupported_ctx_regions
+
+
+def is_node_supported_by_thunder(node: torch.fx.Node) -> tuple[bool, SplitReason | None]:
+    """
+    Determine whether thunder can execute the operation described by this node.
+    """
+    # Docs from the torch.fx.Node - https://pytorch.org/docs/stable/fx.html#torch.fx.Node
+    # Each Node has a function specified by its op property
+    # Below are the details for the ones this function is interested in -
+    # `call_function` applies a free function to some values.
+    #       name is similarly the name of the value to assign to.
+    #       target is the function to be applied. args and kwargs represent
+    #       the arguments to the function, following the Python calling convention
+    # `call_method` calls a method on a value.
+    #       name is similar.
+    #       target is the string name of the method to apply to the self argument.
+    #       args and kwargs represent the arguments to invoke the module on, including the self argument
+    #
+    # NOTE: `call_module` should be inlined in dynamo graphs since https://github.com/pytorch/pytorch/pull/131275
+    # But there is a flag to disable inlining `call_module`. Determining `call_module` support would actually require calling `thunder.jit` on it.
+    #
+    # `call_module` applies a module in the module hierarchy's forward() method to given arguments.
+    #       name is as previous. target is the fully-qualified name of the module in the module hierarchy to call.
+    #       args and kwargs represent the arguments to invoke the module on, excluding the self argument
+
+    target = node.target  # Target is the function to call.
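+    # For example (illustrative): a `call_function` node created from `torch.sin(x)` has the callable `torch.sin`
+    # as its target, while a `call_method` node created from `x.sin()` has the string "sin" as its target,
+    # which is resolved to `torch.Tensor.sin` below.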
+ if node.op == "call_method": + self_arg = node.args[0] + target = getattr(torch.Tensor, node.target, None) + assert target is not None, f"Failed to find method {node.target}" + + # If the operation has automatic registration, we mark it as unsupported as `inductor` might be + # able to deal with it better. + if target in auto_register_ops: + split_reason = SplitReason( + SplitReasonType.MISSING_OP_SUPPORT, + info=f"node with name: {node.name} and target: {node.target} only has an automatic torch fallback in thunder.", + ) + return False, split_reason + + # If thunder has a mapping for this operation, try executing the meta function and see. + # We have a symbol for `torch.where`, but we don't support one overload of it. + # So, we try and execute the meta to get a real signal. + # + # Regarding `inspect.isbuiltin`, dynamo graph uses `+`, `>` which are builtin `add`, `gt`. + # We try to proxify the arguments and call these operations on them to see if they are supported. + if target in _torch_to_thunder_function_map or inspect.isbuiltin(target): + thunder_symbol_or_builtin = _torch_to_thunder_function_map.get(target, target) + did_run, opt_split_reason = try_execute_thunder_symbol(thunder_symbol_or_builtin, node) + return did_run, opt_split_reason + + # We found no automatic fallback registration and no mapping to thunder symbol. + split_reason = SplitReason( + SplitReasonType.MISSING_OP_SUPPORT, + info=f"node with name: {node.name} and target: {node.target} didn't have any mapping in thunder.", + ) + return False, split_reason + + +def update_node_and_submodule( + graph_module: torch.fx.GraphModule, node: torch.fx.Node, new_name: str, new_callable: Callable +): + """ + Updates the graph module and the node in place with a new name and a new callable as the target. + + This function removes the existing submodule associated with the node's current name in graph_module and replaces + it with a new submodule using the specified new name and callable. The node's name and target are updated accordingly. + + Args: + graph_module (torch.fx.GraphModule): The graph module containing the node and submodules. + node (torch.fx.Node): The node to be updated within the graph module. + new_name (str): The new name to assign to the node and the submodule. + new_callable (Callable): The new callable to be used as the target for the submodule. 
+ """ + assert graph_module.delete_submodule( + node.name + ), f"Didn't find a submodule named {node.name} in graph_module {graph_module}" + node.name = new_name + node.target = new_name + assert graph_module.add_submodule( + node.name, new_callable + ), f"Adding submodule with name {node.name} in graph_module {graph_module} failed" diff --git a/thunder/tests/test_dynamo.py b/thunder/tests/test_dynamo.py index bd6f26abe5..b4eb8ab147 100644 --- a/thunder/tests/test_dynamo.py +++ b/thunder/tests/test_dynamo.py @@ -1,5 +1,5 @@ import torch.fx -from thunder.tests.framework import instantiate, NOTHING, DynamoThunderExecutor +from thunder.tests.framework import instantiate, NOTHING, DynamoThunderExecutor, IS_WINDOWS from thunder import dtypes from thunder.dynamo import ThunderCompiler from thunder import last_traces @@ -32,9 +32,144 @@ def func(x): assert out.grad_fn.name() == "ThunderFunctionBackward" # We record the GraphModules that was compiled by ThunderCompiler - assert len(backend.thunder_to_gm) == 2 - thunder_func, gm = list(backend.thunder_to_gm.items())[0] - assert isinstance(gm, torch.fx.GraphModule) + assert len(backend.subgraph_infos) == 2 # 2 due to data-dependent flow - # This shouldn't be empty - assert last_traces(thunder_func) + for subgraph_info in backend.subgraph_infos: + assert isinstance(subgraph_info.original_graph_module, torch.fx.GraphModule) + assert len(subgraph_info.thunder_compiled_fns) # There was atleast one function compiled with thunder. + for thunder_fn in subgraph_info.thunder_compiled_fns: + assert last_traces(thunder_fn) # Verify that we can fetch last_traces + + +@instantiate( + dtypes=NOTHING, + executors=[DynamoThunderExecutor], + decorators=( + pytest.mark.parametrize("dynamic", (True, False, None), ids=("dynamic", "static", "auto")), + pytest.mark.xfail( + condition=IS_WINDOWS, + strict=True, + reason="torch.compile Windows support is still WIP - https://github.com/pytorch/pytorch/issues/122094", + ), + ), +) +def test_basic_splitter(executor, device: str, dtype: dtypes.dtype, dynamic: bool | None): + x = torch.ones(2, 2, device=device, dtype=dtype, requires_grad=True) + + backend = ThunderCompiler() + + def func(x): + # torch.sinc has automatic fallback registered, + # so that operation will be given to inductor. + x = x.exp() + y = torch.sinc(x) + torch.cos(x) + return y + 1 + + cfunc = torch.compile(func, backend=backend, dynamic=dynamic) + expected = torch.compile(func, dynamic=False)(x) + actual = cfunc(x) + + g = torch.rand_like(actual) + torch.testing.assert_close(actual, expected) + actual_grad = torch.autograd.grad(actual, x, g) + expected_grad = torch.autograd.grad(expected, x, g) + torch.testing.assert_close(actual_grad, expected_grad) + + assert len(backend.subgraph_infos) == 1 + assert len(backend.subgraph_infos[0].submodule_to_compiled_functions) > 1 # Verify that the subgraph was split. 
+    assert any(
+        "automatic torch fallback" in split_reason.info for split_reason in backend.subgraph_infos[0].split_reasons
+    )  # Verify that we had a split because we detected an automatically registered operator.
+    targets = (node.target for node in backend.subgraph_infos[0].split_graph_module.graph.nodes)
+    assert any(target.startswith("thunder_") for target in targets)  # Verify that the submodules have name `thunder_*`
+
+
+@instantiate(
+    dtypes=NOTHING,
+    executors=[DynamoThunderExecutor],
+    decorators=(
+        pytest.mark.parametrize("dynamic", (True, False, None), ids=("dynamic", "static", "auto")),
+        pytest.mark.xfail(
+            condition=IS_WINDOWS,
+            strict=True,
+            reason="torch.compile Windows support is still WIP - https://github.com/pytorch/pytorch/issues/122094",
+        ),
+    ),
+)
+def test_splitter_unsupported_ctx(executor, device: str, dtype: dtypes.dtype, dynamic: bool | None):
+    x = torch.rand(2, 2, device=device, dtype=dtype, requires_grad=True)
+
+    backend = ThunderCompiler()
+
+    def func(x):
+        x = x + 2
+        with torch.autocast("cpu"):
+            y = torch.log(x)
+            return torch.matmul(x, y)
+
+    expected = torch.compile(func, dynamic=False)(x)
+
+    cfunc = torch.compile(func, backend=backend, dynamic=dynamic)
+    actual = cfunc(x)
+
+    g = torch.rand_like(actual)
+    torch.testing.assert_close(actual, expected)
+    actual_grad = torch.autograd.grad(actual, x, g)
+    expected_grad = torch.autograd.grad(expected, x, g)
+    torch.testing.assert_close(actual_grad, expected_grad)
+
+    assert len(backend.subgraph_infos) == 1
+    assert len(backend.subgraph_infos[0].submodule_to_compiled_functions) > 1  # Verify that the subgraph was split.
+    assert any(
+        "didn't have any mapping in thunder" in split_reason.info
+        for split_reason in backend.subgraph_infos[0].split_reasons
+    )
+    # Materialize the targets in a list so that both `any` checks below see every node.
+    targets = [node.target for node in backend.subgraph_infos[0].split_graph_module.graph.nodes]
+    assert any(target.startswith("thunder_") for target in targets)  # Verify that the submodules have name `thunder_*`
+    assert any(
+        target.startswith("inductor_") for target in targets
+    )  # Verify that the submodules have name `inductor_*`
+
+
+@instantiate(
+    dtypes=NOTHING,
+    executors=[DynamoThunderExecutor],
+    decorators=(
+        pytest.mark.parametrize("dynamic", (True, False, None), ids=("dynamic", "static", "auto")),
+        pytest.mark.xfail(
+            condition=IS_WINDOWS,
+            strict=True,
+            reason="torch.compile Windows support is still WIP - https://github.com/pytorch/pytorch/issues/122094",
+        ),
+    ),
+)
+def test_splitter_unsupported_ctx_with_graph_break(executor, device: str, dtype: dtypes.dtype, dynamic: bool | None):
+    x = torch.rand(2, 2, device=device, dtype=dtype, requires_grad=True)
+
+    backend = ThunderCompiler()
+
+    def func(x):
+        x = x + 2
+        with torch.autocast("cpu"):
+            y = torch.sin(x)
+            torch._dynamo.graph_break()
+            return torch.matmul(x, y)
+
+    expected = torch.compile(func, dynamic=False)(x)
+    cfunc = torch.compile(func, backend=backend, dynamic=dynamic)
+    actual = cfunc(x)
+
+    g = torch.rand_like(actual)
+    torch.testing.assert_close(actual, expected)
+    actual_grad = torch.autograd.grad(actual, x, g)
+    expected_grad = torch.autograd.grad(expected, x, g)
+    torch.testing.assert_close(actual_grad, expected_grad)
+
+    # 2 subgraphs due to the graph-break
+    assert len(backend.subgraph_infos) == 2
+
+    for subgraph_info in backend.subgraph_infos:
+        # Verify that each subgraph had a split due to `autocast` being enabled.
+        assert any(
+            "didn't have any mapping in thunder" in split_reason.info for split_reason in subgraph_info.split_reasons
+        )