diff --git a/thunder/dynamo/compiler.py b/thunder/dynamo/compiler.py
index 86e5f46214..7f05c81282 100644
--- a/thunder/dynamo/compiler.py
+++ b/thunder/dynamo/compiler.py
@@ -1,25 +1,24 @@
-from enum import Enum, auto
-import dataclasses
-from typing import List, Dict, Optional, Tuple
+from typing import List, Dict, Optional, Tuple, Set
 from collections.abc import Callable
 import pprint
-import itertools
-import copy
 from functools import partial
-import operator
-import inspect

 import torch
-from torch.fx.passes.split_module import split_module as fx_split_module
+from torch.fx.passes.split_module import split_module
 import warnings
 from collections.abc import Mapping
-from torch.fx.passes import operator_support

 from thunder.core.baseutils import run_once
-from thunder.torch.default_torch_ops import torch_auto_registered_ops
-from thunder.torch import _torch_to_thunder_function_map
-
-auto_register_ops = set(itertools.chain(*torch_auto_registered_ops.values()))
+from thunder.dynamo.utils import (
+    SubgraphInfo,
+    CompiledFunction,
+    CompilerType,
+    SplitReason,
+    SplitReasonType,
+    is_node_supported,
+    get_nodes_in_unsupported_ctx_regions,
+)


 @run_once
@@ -30,262 +29,6 @@ def _warn_thunder_compiler():
     )


-class CompilerType(Enum):
-    """
-    An enumeration representing different types of compilers.
-    """
-
-    THUNDER = auto()
-    TORCH_INDUCTOR = auto()
-
-
-@dataclasses.dataclass
-class CompiledFunction:
-    """
-    A dataclass representing a compiled function along with its original graph module and compiler type.
-
-    Attributes:
-        original_graph_module (torch.fx.GraphModule): The original graph module from which the function is compiled.
-        compiled_fn (Callable): The compiled function.
-        compiler (CompilerType): The type of compiler used to compile the function.
-    """
-
-    original_graph_module: torch.fx.GraphModule
-    compiled_fn: Callable
-    compiler: CompilerType
-
-
-class SplitReasonType(Enum):
-    """
-    An enumeration representing different reasons for split in the graph.
-    """
-
-    UNSUPPORTED_NODE = auto()
-    MISSING_OP_SUPPORT = auto()
-    EXCEPTION_PROXY_THUNDER_OP = auto()
-    EXCEPTION_META_THUNDER_OP = auto()
-
-
-@dataclasses.dataclass
-class SplitReason:
-    """
-    A dataclass containing information about a split.
-
-    Attributes:
-        type (SplitReasonType): Reason for the split.
-        info (str): String with details of what caused the split.
-        exception (Exception | None): Exception if there was any.
-    """
-
-    type: SplitReasonType
-    info: str | None
-    exception: Exception | None = None
-
-
-@dataclasses.dataclass
-class SubgraphInfo:
-    """
-    A dataclass containing information about a subgraph.
-
-    Attributes:
-        original_graph_module (torch.fx.GraphModule): The original graph module.
-        compiled_functions (list[CompiledFunction]): A list of compiled functions derived from the subgraph. This will be a list with one function in case the graph was not split.
-        is_split (bool): Indicates whether the subgraph has been split. This happens if there was a thunder unsupported functionality.
-        split_reasons (list[SplitReason] | None): Optional list of reasons explaining why the subgraph was split. Present only if `is_split` is True.
-        split_graph_module (torch.fx.GraphModule | None): Optional. The graph module for the split subgraph. Present only if `is_split` is True.
- """ - - original_graph_module: torch.fx.GraphModule - compiled_functions: list[CompiledFunction] - is_split: bool - split_reasons: list | None = None - split_graph_module: torch.fx.GraphModule | None = None - - -def try_execute_symbol(thunder_symbol: "Symbol", node: torch.fx.Node) -> tuple[bool, SplitReason | None]: - """ - Attempts to execute a given Thunder symbol within a tracing context, using proxies for the node's arguments. - - This function operates within a Thunder tracing context to generate proxies for the provided node's arguments. - It then attempts to execute the Thunder symbol with these proxies. If any exceptions occur during proxy creation - or execution, it returns a tuple indicating failure and provides a `SplitReason` detailing the exception. - - Args: - thunder_symbol (Symbol): The Thunder symbol to be executed. This is expected to be a callable that can - operate on proxy arguments. - node (torch.fx.Node): The Torch FX node whose arguments are to be proxied and passed to the Thunder symbol. - - Returns: - tuple[bool, SplitReason | None]: A tuple where the first element is a boolean whether the execution passed or failed. - The second element is a `SplitReason` object if an error occurred, or `None` if the execution was successful. - """ - import thunder - from thunder.core.trace import TraceCtx - from thunder.core.proxies import proxy - - trc = TraceCtx() - # We need to be under trace context to generate proxies. - with thunder.core.trace.tracectx(trc): - try: - - def make_tensor_proxy(arg_node): - # This is a Node in the graph representing a Tensor. - if isinstance(arg_node, torch.fx.Node): - example_value = arg_node.meta["example_value"] - - # This fails if the shape of the FakeTensor contains SymInts. - return proxy(example_value) - - # This is int, float, etc. - # TODO(kshitij12345) - verify the above line for more cases. - return arg_node - - proxy_args = tuple(map(make_tensor_proxy, node.args)) - proxy_kwargs = {k: make_tensor_proxy(v) for k, v in node.kwargs.items()} - except Exception as e: - return False, SplitReason( - SplitReasonType.EXCEPTION_PROXY_THUNDER_OP, - f"Failed while creating proxy for node with name: {node.name} and target: {node.target}, see exception field", - exception=e, - ) - - try: - thunder_symbol(*proxy_args, **proxy_kwargs) - except Exception as e: - return False, SplitReason( - SplitReasonType.EXCEPTION_META_THUNDER_OP, - f"Failed while running meta for node with name: {node.name} and target: {node.target}, see exception field", - exception=e, - ) - - # Execution with proxies was successful. - return True, None - - -class ThunderOperatorSupport: - def __init__(self, gm): - self.gm = gm - self.unsupported_nodes = set() - self.find_unsupported_ctx_regions(gm) - self.split_reasons: list[SplitReason] = [] - - def find_unsupported_ctx_regions(self, gm): - """ - Finds the node within `autocast` or other supported context and marks them as unsupported. - Even though, thunder may support the operation within the reason, it doesn't correctly apply the change - triggered from the context. - """ - # NOTE - Currently only detects the autocast regions. - - ctx_cnt = 0 # Count of `enters_autocast` we have seen till now - - # We want to mark nodes with `_enter_autocast` and `_exit_autocast` - # as unsupported as `thunder` doesn't correctly deal with these stateful functions. 
-        for node in gm.graph.nodes:
-            if node.op == "call_function" and node.target in (torch.amp.autocast_mode._enter_autocast,):
-                ctx_cnt += 1
-            elif node.op == "call_function" and node.target in (torch.amp.autocast_mode._exit_autocast,):
-                ctx_cnt -= 1
-            else:
-                if ctx_cnt > 0:
-                    self.unsupported_nodes.add(node)
-
-    def is_node_supported(self, submodules: Mapping[str, torch.nn.Module], node: torch.fx.Node):
-        """
-        Determine whether thunder can execute the operation described by this node.
-        """
-        # These are the nodes which are in unsupported context regions
-        if node in self.unsupported_nodes:
-            self.split_reasons.append(
-                SplitReason(
-                    SplitReasonType.UNSUPPORTED_NODE,
-                    info=f"node with name: {node.name} and target: {node.target} is not supported probably because it is in unsupported context.",
-                )
-            )
-            return False
-
-        # Docs from the torch.fx.Node - https://pytorch.org/docs/stable/fx.html#torch.fx.Node
-        # Each Node has a function specified by its op property
-        # Below are the details for the ones this function is interested in -
-        # `call_function` applies a free function to some values.
-        #       name is similarly the name of the value to assign to.
-        #       target is the function to be applied. args and kwargs represent
-        #       the arguments to the function, following the Python calling convention
-        # `call_method` calls a method on a value.
-        #       name is as similar.
-        #       target is the string name of the method to apply to the self argument.
-        #       args and kwargs represent the arguments to invoke the module on, including the self argument
-        #
-        # NOTE: `call_module` should be inlined in dynamo graphs since https://github.com/pytorch/pytorch/pull/131275
-        #       But there is flag to disable inlining `call_module`. Determining `call_module` support would actually require calling `thunder.jit` on it.
-        #
-        #       `call_module` applies a module in the module hierarchy’s forward() method to given arguments.
-        #       name is as previous. target is the fully-qualified name of the module in the module hierarchy to call.
-        #       args and kwargs represent the arguments to invoke the module on, excluding the self argument
-
-        target = node.target  # Target is the function to call.
-        if node.op == "call_method":
-            self_arg = node.args[0]
-            target = getattr(torch.Tensor, node.target, None)
-            assert target is not None, f"Failed to find method {node.target}"
-
-        # If the operation has automatic registration, we mark it as unsupported as `inductor` might be
-        # able to deal with it better.
-        if target in auto_register_ops:
-            self.split_reasons.append(
-                SplitReason(
-                    SplitReasonType.MISSING_OP_SUPPORT,
-                    info=f"node with name: {node.name} and target: {node.target} only has an automatic torch fallback in thunder.",
-                )
-            )
-            return False
-
-        # If thunder has a mapping for this operation, try executing the meta function and see.
-        # We have a symbol for `torch.where`, but we don't support one overload of it.
-        # So, we try and execute the meta to get a real signal.
-        #
-        # Regarding `inspect.isbuiltin`, dynamo graph uses `+`, `>` which are builtin `add`, `gt`.
-        # We try to proxify the arguments and call these operations on them to see if they are supported.
-        if target in _torch_to_thunder_function_map or inspect.isbuiltin(target):
-            if target in [torch.ones]:  # Factory functions. (removing this will lead to split)
-                # NOTE - Factory functions don't work as the expect `_cache_info` to be populated
-                # with default dtype but `_cache_info` is only created and populated in `thunder.jit` path.
-                # my_trc = TraceCtx()
-                # with tracectx(my_trc):
-                #     thunder.torch.ones(3, 3)
-                return True
-
-            thunder_symbol_or_builtin = _torch_to_thunder_function_map.get(target, target)
-            did_run, opt_split_reason = try_execute_symbol(thunder_symbol_or_builtin, node)
-            if opt_split_reason is not None:
-                self.split_reasons.append(opt_split_reason)
-            return did_run
-
-        # We found no automatic fallback registration and no mapping to thunder symbol.
-        self.split_reasons.append(
-            SplitReason(
-                SplitReasonType.MISSING_OP_SUPPORT,
-                info=f"node with name: {node.name} and target: {node.target} didn't have any mapping in thunder.",
-            )
-        )
-        return False
-
-
-def _all_graph_supported_by_thunder(gm: torch.fx.GraphModule, sample_input: list[torch.SymInt, torch.Tensor]) -> bool:
-    """
-    Determine whether there is any thunder unsupported operation.
-    """
-    # NOTE - Unused for now.
-    op_support = ThunderOperatorSupport(gm)
-    supported = True
-    for node in gm.graph.nodes:
-        if node.op in ["call_method", "call_function"]:
-            supported = op_support.is_node_supported(gm, node)
-            if not supported:
-                break
-    return supported
-
-
 class ThunderCompiler:
     def __init__(self, **thunder_options):
         """
@@ -381,10 +124,6 @@ def forward(self, l_x_: "f32[2]", y: "f32[2]"):
                 matmul: "f32[]" = torch.matmul(l_x_, y);  l_x_ = y = None
                 return matmul
         """
-        # Create an `ThunderOperatorSupport` instance which will be used in the callback.
-        # This will determine whether the operation represented by the node is supported by thunder.
-        operator_support = ThunderOperatorSupport(gm)
-
         # The callback below is called for every node in the graph.
         # It returns an `int` denoting the parition where the node should be placed.
         # We want to partition the graph into contiguous regions (with one or more operations)
@@ -394,7 +133,10 @@ def forward(self, l_x_: "f32[2]", y: "f32[2]"):
         # `supported_partitions` is used to track the thunder supported partitions.
         prev_value = None
         partition_cnt = 0
-        supported_partitions = set()
+        supported_partitions: set[int] = set()
+        split_reasons: list[SplitReason] = []
+
+        nodes_in_unsupported_ctx_regions = get_nodes_in_unsupported_ctx_regions(gm)

         def callback(node) -> int:
             assert node.op not in (
@@ -403,7 +145,19 @@ def callback(node) -> int:
                 "output",
             ), f"fx.split_module should have only passed node.op=call_* but received {node.op}"
             nonlocal prev_value, partition_cnt
-            is_thunder_supported = operator_support.is_node_supported(gm, node)
+
+            if node in nodes_in_unsupported_ctx_regions:
+                is_thunder_supported = False
+                split_reason = SplitReason(
+                    SplitReasonType.UNSUPPORTED_NODE,
+                    info=f"node with name: {node.name} and target: {node.target} is not supported probably because it is in unsupported context.",
+                )
+                split_reasons.append(split_reason)
+            else:
+                is_thunder_supported, split_reason = is_node_supported(node)
+                if split_reason is not None:
+                    split_reasons.append(split_reason)
+
             if prev_value == is_thunder_supported:  # We are in the same region.
                 return partition_cnt
@@ -416,7 +170,7 @@ def callback(node) -> int:
                 return partition_cnt

         # `fx_split_module` iterates over nodes and determines the partition to place them based on the callback.
-        split_module: torch.fx.GraphModule = fx_split_module(
+        split_gm: torch.fx.GraphModule = split_module(
             gm, root_m=None, split_callback=callback, keep_original_order=True, keep_original_node_name=True
         )
@@ -424,27 +178,40 @@ def is_thunder_supported_partition(node: torch.fx.Node) -> bool:
             return node.name.startswith("submod") and int(node.name.replace("submod_", "")) in supported_partitions

         # Call compile on the split region/s.
-        comipled_fn = []
-        for node in split_module.graph.nodes:
+        thunder_compiled_fns = []
+        submodule_to_compiled_fns = {}
+        is_split = False
+        for node in split_gm.graph.nodes:
             if is_thunder_supported_partition(node):
-                graph_module = getattr(split_module, node.name)
+                # there is erase method on GraphModule
+                graph_module = getattr(split_gm, node.name)
                 jit_fn = self._thunder_jit(graph_module)
-                setattr(split_module, node.name, jit_fn)
-                comipled_fn.append(CompiledFunction(graph_module, jit_fn, CompilerType.THUNDER))
+                setattr(split_gm, node.name, jit_fn)
+                thunder_compiled_fns.append(jit_fn)
+                submodule_to_compiled_fns[graph_module] = CompiledFunction(jit_fn, CompilerType.THUNDER)
             elif node.name.startswith("submod"):  # For inductor
-                graph_module = getattr(split_module, node.name)
+                graph_module = getattr(split_gm, node.name)
                 jit_fn = torch.compile(graph_module, backend="inductor")
-                setattr(split_module, node.name, jit_fn)
-                comipled_fn.append(CompiledFunction(graph_module, jit_fn, CompilerType.TORCH_INDUCTOR))
+                setattr(split_gm, node.name, jit_fn)
+                submodule_to_compiled_fns[graph_module] = CompiledFunction(jit_fn, CompilerType.TORCH_INDUCTOR)
+                is_split = True
             else:
                 # Everything else is a glue code to call and pass outputs between the other partitions.
                 pass

-        gm.print_readable()
+        # gm.print_readable()
         # Append the details regarding this graph/subgraph.
-        self.subgraph_infos.append(SubgraphInfo(gm, comipled_fn, True, operator_support.split_reasons, split_module))
-        split_module.print_readable()
-        return split_module
+        self.subgraph_infos.append(
+            SubgraphInfo(
+                gm,
+                split_gm,
+                thunder_compiled_fns,
+                submodule_to_compiled_fns,
+                split_reasons,
+            )
+        )
+        # split_gm.print_readable()
+        return split_gm

     def __call__(self, gm: torch.fx.GraphModule, sample_args: list[torch.SymInt, torch.Tensor]):
         from thunder import jit
diff --git a/thunder/dynamo/utils.py b/thunder/dynamo/utils.py
new file mode 100644
index 0000000000..527f4df9e1
--- /dev/null
+++ b/thunder/dynamo/utils.py
@@ -0,0 +1,262 @@
+from enum import Enum, auto
+import dataclasses
+from typing import List, Dict, Optional, Tuple, Set
+from collections.abc import Callable
+import pprint
+import itertools
+import copy
+import inspect
+
+import torch
+from torch.fx.passes.split_module import split_module
+import warnings
+from collections.abc import Mapping
+
+from thunder.torch.default_torch_ops import torch_auto_registered_ops
+from thunder.torch import _torch_to_thunder_function_map
+
+
+auto_register_ops = set(itertools.chain(*torch_auto_registered_ops.values()))
+
+
+class CompilerType(Enum):
+    """
+    An enumeration representing different types of compilers.
+    """
+
+    THUNDER = auto()
+    TORCH_INDUCTOR = auto()
+
+
+@dataclasses.dataclass
+class CompiledFunction:
+    """
+    A dataclass representing a compiled function along with the compiler used to compile it.
+
+    Attributes:
+        compiled_fn (Callable): The compiled function.
+        compiler (CompilerType): The type of compiler used to compile the function.
+ """ + + compiled_fn: Callable + compiler: CompilerType + + +class SplitReasonType(Enum): + """ + An enumeration representing different reasons for split in the graph. + """ + + UNSUPPORTED_NODE = auto() + MISSING_OP_SUPPORT = auto() + EXCEPTION_PROXY_THUNDER_OP = auto() + EXCEPTION_META_THUNDER_OP = auto() + + +@dataclasses.dataclass +class SplitReason: + """ + A dataclass containing information about a split. + + Attributes: + type (SplitReasonType): Reason for the split. + info (str): String with details of what caused the split. + exception (Exception | None): Exception if there was any. + """ + + type: SplitReasonType + info: str | None + exception: Exception | None = None + + +@dataclasses.dataclass +class SubgraphInfo: + """ + A dataclass containing information about a subgraph. + + Attributes: + original_graph_module (torch.fx.GraphModule): The original graph module. + split_graph_module (torch.fx.GraphModule): Optional. The graph module for the split subgraph. + thunder_compiled_fns (list[Callable]): List of thunder optimized callables. This could be None if there the graph module was not supported by thunder. Look at the `split_reasons` for further information. + compiled_functions (list[CompiledFunction]): A list of compiled functions derived from the subgraph. This will be a list with one function in case the graph was not split. + split_reasons (list[SplitReason] | None): Optional list of reasons explaining why the subgraph was split. Present only if there are was a split. + """ + + original_graph_module: torch.fx.GraphModule + split_graph_module: torch.fx.GraphModule + thunder_compiled_fns: list[Callable] + submodule_to_compiled_functions: Mapping[torch.fx.GraphModule, CompiledFunction] + split_reasons: list | None = None + + +def try_execute_symbol(thunder_symbol: "Symbol", node: torch.fx.Node) -> tuple[bool, SplitReason | None]: + """ + Attempts to execute a given Thunder symbol within a tracing context, using proxies for the node's arguments. + + This function operates within a Thunder tracing context to generate proxies for the provided node's arguments. + It then attempts to execute the Thunder symbol with these proxies. If any exceptions occur during proxy creation + or execution, it returns a tuple indicating failure and provides a `SplitReason` detailing the exception. + + Args: + thunder_symbol (Symbol): The Thunder symbol to be executed. This is expected to be a callable that can + operate on proxy arguments. + node (torch.fx.Node): The Torch FX node whose arguments are to be proxied and passed to the Thunder symbol. + + Returns: + tuple[bool, SplitReason | None]: A tuple where the first element is a boolean whether the execution passed or failed. + The second element is a `SplitReason` object if an error occurred, or `None` if the execution was successful. + """ + import thunder + from thunder.core.trace import TraceCtx + from thunder.core.proxies import proxy + + @thunder._with_cache_info_ctx + def _run_with_cache_info(): + + # We need cache info here as the default dtype and device support + # for factory functions like ones, zeros, etc expects these details to be present. + # TODO: Move this to CompileData as well? + # This details are in cache info because `jit_ext.py` + # adds checks in prologue for the details which are present in here. 
+        cache_info = thunder._get_cache_info()
+        cache_info["default_dtype"] = torch.get_default_dtype()
+        cache_info["default_device"] = torch.get_default_device()
+
+        trc = TraceCtx()
+        # We need to be under trace context to generate proxies.
+        with thunder.core.trace.tracectx(trc):
+            try:
+
+                def make_tensor_proxy(arg_node):
+                    # This is a Node in the graph representing a Tensor.
+                    if isinstance(arg_node, torch.fx.Node):
+                        example_value = arg_node.meta["example_value"]
+
+                        if isinstance(example_value, torch.Tensor):
+                            # If `dynamic` shapes are enabled, we may see a FakeTensor
+                            # where shape has SymInt. In that case, we check if we can
+                            # get the concrete value from SymInt.
+                            # Here, we only want to verify that thunder can run an operation.
+                            # So, it is ok to verify with concrete value.
+                            def concrete_shape(x):
+                                def get_backed_value(s):
+                                    if isinstance(s, torch.SymInt):
+                                        return s.node.hint
+                                    return s
+
+                                shape = tuple(map(get_backed_value, x.shape))
+                                return shape
+
+                            example_value = example_value.new_ones(
+                                concrete_shape(example_value), device=example_value.device, dtype=example_value.dtype
+                            )
+                        return proxy(example_value)
+
+                    # This is int, float, etc.
+                    # TODO(kshitij12345) - verify the above line for more cases.
+                    return arg_node
+
+                proxy_args = tuple(map(make_tensor_proxy, node.args))
+                proxy_kwargs = {k: make_tensor_proxy(v) for k, v in node.kwargs.items()}
+            except Exception as e:
+                return False, SplitReason(
+                    SplitReasonType.EXCEPTION_PROXY_THUNDER_OP,
+                    f"Failed while creating proxy for node with name: {node.name} and target: {node.target}, see exception field",
+                    exception=e,
+                )
+
+            try:
+                thunder_symbol(*proxy_args, **proxy_kwargs)
+            except Exception as e:
+                return False, SplitReason(
+                    SplitReasonType.EXCEPTION_META_THUNDER_OP,
+                    f"Failed while running meta for node with name: {node.name} and target: {node.target}, see exception field",
+                    exception=e,
+                )
+
+        # Execution with proxies was successful.
+        return True, None
+
+    return _run_with_cache_info()
+
+
+def get_nodes_in_unsupported_ctx_regions(gm) -> set[torch.fx.Node]:
+    """
+    Finds the nodes within `autocast` or other unsupported context managers and marks them as unsupported.
+    Even though thunder may support the operation within the region, it doesn't correctly apply the change
+    triggered from the context.
+    """
+    # NOTE - Currently only detects the autocast regions.
+
+    nodes_in_unsupported_ctx_regions: set[torch.fx.Node] = set()
+    ctx_cnt = 0  # Count of `_enter_autocast` nodes we have seen so far
+
+    # We want to mark nodes with `_enter_autocast` and `_exit_autocast`
+    # as unsupported as `thunder` doesn't correctly deal with these stateful functions.
+    for node in gm.graph.nodes:
+        if node.op == "call_function" and node.target in (torch.amp.autocast_mode._enter_autocast,):
+            ctx_cnt += 1
+        elif node.op == "call_function" and node.target in (torch.amp.autocast_mode._exit_autocast,):
+            ctx_cnt -= 1
+        else:
+            if ctx_cnt > 0:
+                nodes_in_unsupported_ctx_regions.add(node)
+
+    return nodes_in_unsupported_ctx_regions
+
+
+def is_node_supported(node: torch.fx.Node) -> tuple[bool, SplitReason | None]:
+    """
+    Determine whether thunder can execute the operation described by this node.
+    """
+    # Docs from the torch.fx.Node - https://pytorch.org/docs/stable/fx.html#torch.fx.Node
+    # Each Node has a function specified by its op property
+    # Below are the details for the ones this function is interested in -
+    # `call_function` applies a free function to some values.
+    #       name is similarly the name of the value to assign to.
+    #       target is the function to be applied. args and kwargs represent
+    #       the arguments to the function, following the Python calling convention
+    # `call_method` calls a method on a value.
+    #       name is as similar.
+    #       target is the string name of the method to apply to the self argument.
+    #       args and kwargs represent the arguments to invoke the module on, including the self argument
+    #
+    # NOTE: `call_module` should be inlined in dynamo graphs since https://github.com/pytorch/pytorch/pull/131275
+    #       But there is flag to disable inlining `call_module`. Determining `call_module` support would actually require calling `thunder.jit` on it.
+    #
+    #       `call_module` applies a module in the module hierarchy’s forward() method to given arguments.
+    #       name is as previous. target is the fully-qualified name of the module in the module hierarchy to call.
+    #       args and kwargs represent the arguments to invoke the module on, excluding the self argument
+
+    target = node.target  # Target is the function to call.
+    if node.op == "call_method":
+        self_arg = node.args[0]
+        target = getattr(torch.Tensor, node.target, None)
+        assert target is not None, f"Failed to find method {node.target}"
+
+    # If the operation has automatic registration, we mark it as unsupported as `inductor` might be
+    # able to deal with it better.
+    if target in auto_register_ops:
+        split_reason = SplitReason(
+            SplitReasonType.MISSING_OP_SUPPORT,
+            info=f"node with name: {node.name} and target: {node.target} only has an automatic torch fallback in thunder.",
+        )
+        return False, split_reason
+
+    # If thunder has a mapping for this operation, try executing the meta function and see.
+    # We have a symbol for `torch.where`, but we don't support one overload of it.
+    # So, we try and execute the meta to get a real signal.
+    #
+    # Regarding `inspect.isbuiltin`, dynamo graph uses `+`, `>` which are builtin `add`, `gt`.
+    # We try to proxify the arguments and call these operations on them to see if they are supported.
+    if target in _torch_to_thunder_function_map or inspect.isbuiltin(target):
+        thunder_symbol_or_builtin = _torch_to_thunder_function_map.get(target, target)
+        did_run, opt_split_reason = try_execute_symbol(thunder_symbol_or_builtin, node)
+        return did_run, opt_split_reason
+
+    # We found no automatic fallback registration and no mapping to thunder symbol.
+    split_reason = SplitReason(
+        SplitReasonType.MISSING_OP_SUPPORT,
+        info=f"node with name: {node.name} and target: {node.target} didn't have any mapping in thunder.",
+    )
+    return False, split_reason
diff --git a/thunder/tests/test_dynamo.py b/thunder/tests/test_dynamo.py
index 8dd8842245..59d85f8ad9 100644
--- a/thunder/tests/test_dynamo.py
+++ b/thunder/tests/test_dynamo.py
@@ -2,7 +2,6 @@
 from thunder.tests.framework import instantiate, NOTHING, DynamoThunderExecutor
 from thunder import dtypes
 from thunder.dynamo import ThunderCompiler
-from thunder.dynamo.compiler import CompilerType
 from thunder import last_traces

 import torch
@@ -30,25 +29,16 @@ def func(x):

     # out should have grad_fn and its name should be ThunderFunctionBackward
     assert out.grad_fn is not None
-    if not dynamic:
-        # If dynamic, while trying to execute `x + 1`, we fail with
-        # "s0 had an unexpected type . Supported types are (, )")
-        # as the FakeTensor for `x` has shape with SymInt.
-        assert out.grad_fn.name() == "ThunderFunctionBackward"
+    assert out.grad_fn.name() == "ThunderFunctionBackward"

     # We record the GraphModules that was compiled by ThunderCompiler
-    assert len(backend.subgraph_infos) == 2
+    assert len(backend.subgraph_infos) == 2  # 2 due to data-dependent flow

-    subgraph_info = backend.subgraph_infos[0]
-    if dynamic:
-        assert len(subgraph_info.compiled_functions) == 2  # Due to Symint!
-    else:
-        assert len(subgraph_info.compiled_functions) == 1
-    idx = 1 if dynamic else 0
-    compiled_fn_info = subgraph_info.compiled_functions[idx]
-    assert compiled_fn_info.compiler == CompilerType.THUNDER
-    assert last_traces(compiled_fn_info.compiled_fn)
-    assert isinstance(compiled_fn_info.original_graph_module, torch.fx.GraphModule)
+    for subgraph_info in backend.subgraph_infos:
+        assert isinstance(subgraph_info.original_graph_module, torch.fx.GraphModule)
+        assert len(subgraph_info.thunder_compiled_fns)  # There was at least one function compiled with thunder.
+        for thunder_fn in subgraph_info.thunder_compiled_fns:
+            assert last_traces(thunder_fn)  # Verify that we can fetch last_traces


 @instantiate(
@@ -78,7 +68,7 @@ def func(x):
     torch.testing.assert_close(actual_grad, expected_grad)

     assert len(backend.subgraph_infos) == 1
-    assert backend.subgraph_infos[0].is_split
+    assert len(backend.subgraph_infos[0].submodule_to_compiled_functions) > 1  # Verify that the subgraph was split.
     assert any(
         "automatic torch fallback" in split_reason.info for split_reason in backend.subgraph_infos[0].split_reasons
     )
@@ -112,7 +102,7 @@ def func(x):
     torch.testing.assert_close(actual_grad, expected_grad)

     assert len(backend.subgraph_infos) == 1
-    assert backend.subgraph_infos[0].is_split
+    assert len(backend.subgraph_infos[0].submodule_to_compiled_functions) > 1  # Verify that the subgraph was split.
     assert any(
         "didn't have any mapping in thunder" in split_reason.info
         for split_reason in backend.subgraph_infos[0].split_reasons
@@ -146,9 +136,11 @@ def func(x):
     expected_grad = torch.autograd.grad(expected, x, g)
     torch.testing.assert_close(actual_grad, expected_grad)

+    # 2 subgraphs due to graph-break
     assert len(backend.subgraph_infos) == 2
+
     for subgraph_info in backend.subgraph_infos:
-        assert subgraph_info.is_split
+        # Verify that each subgraph was split due to `autocast` being enabled.
         assert any(
             "didn't have any mapping in thunder" in split_reason.info for split_reason in subgraph_info.split_reasons
         )
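
A minimal usage sketch (not part of the patch) of the interfaces touched above: `ThunderCompiler` used as a `torch.compile` backend and the new `SubgraphInfo` bookkeeping (`thunder_compiled_fns`, `split_reasons`, `submodule_to_compiled_functions`). The toy function `fn` is hypothetical; the inspection pattern simply mirrors the updated tests, so treat it as an assumption-laden sketch rather than documented API.

import torch
from thunder import last_traces
from thunder.dynamo import ThunderCompiler

def fn(x):
    # Small hypothetical function; its ops have thunder mappings, so no split is expected.
    return torch.sin(x) + x

backend = ThunderCompiler()
cfn = torch.compile(fn, backend=backend)
out = cfn(torch.randn(4))

for info in backend.subgraph_infos:
    # One SubgraphInfo per FX graph that Dynamo handed to the backend.
    assert isinstance(info.original_graph_module, torch.fx.GraphModule)
    # Reasons (if any) why the graph was split between thunder and inductor.
    for reason in info.split_reasons:
        print(reason.type, reason.info)
    # Thunder-compiled callables keep thunder's usual introspection helpers.
    for thunder_fn in info.thunder_compiled_fns:
        assert last_traces(thunder_fn)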