From 9f3fc95ee58f63d4478df532a286538c89f531f7 Mon Sep 17 00:00:00 2001 From: kshitij12345 Date: Mon, 26 Aug 2024 13:22:35 +0200 Subject: [PATCH 1/9] add splitter to dynamo backend --- thunder/dynamo/compiler.py | 434 ++++++++++++++++++++++++-- thunder/dynamo/splitter_experiment.py | 210 +++++++++++++ thunder/tests/test_dynamo.py | 120 ++++++- 3 files changed, 740 insertions(+), 24 deletions(-) create mode 100644 thunder/dynamo/splitter_experiment.py diff --git a/thunder/dynamo/compiler.py b/thunder/dynamo/compiler.py index e67216c0eb..86e5f46214 100644 --- a/thunder/dynamo/compiler.py +++ b/thunder/dynamo/compiler.py @@ -1,7 +1,25 @@ +from enum import Enum, auto +import dataclasses +from typing import List, Dict, Optional, Tuple +from collections.abc import Callable +import pprint +import itertools +import copy +from functools import partial +import operator +import inspect + import torch +from torch.fx.passes.split_module import split_module as fx_split_module import warnings +from collections.abc import Mapping +from torch.fx.passes import operator_support from thunder.core.baseutils import run_once +from thunder.torch.default_torch_ops import torch_auto_registered_ops +from thunder.torch import _torch_to_thunder_function_map + +auto_register_ops = set(itertools.chain(*torch_auto_registered_ops.values())) @run_once @@ -12,6 +30,262 @@ def _warn_thunder_compiler(): ) +class CompilerType(Enum): + """ + An enumeration representing different types of compilers. + """ + + THUNDER = auto() + TORCH_INDUCTOR = auto() + + +@dataclasses.dataclass +class CompiledFunction: + """ + A dataclass representing a compiled function along with its original graph module and compiler type. + + Attributes: + original_graph_module (torch.fx.GraphModule): The original graph module from which the function is compiled. + compiled_fn (Callable): The compiled function. + compiler (CompilerType): The type of compiler used to compile the function. + """ + + original_graph_module: torch.fx.GraphModule + compiled_fn: Callable + compiler: CompilerType + + +class SplitReasonType(Enum): + """ + An enumeration representing different reasons for split in the graph. + """ + + UNSUPPORTED_NODE = auto() + MISSING_OP_SUPPORT = auto() + EXCEPTION_PROXY_THUNDER_OP = auto() + EXCEPTION_META_THUNDER_OP = auto() + + +@dataclasses.dataclass +class SplitReason: + """ + A dataclass containing information about a split. + + Attributes: + type (SplitReasonType): Reason for the split. + info (str): String with details of what caused the split. + exception (Exception | None): Exception if there was any. + """ + + type: SplitReasonType + info: str | None + exception: Exception | None = None + + +@dataclasses.dataclass +class SubgraphInfo: + """ + A dataclass containing information about a subgraph. + + Attributes: + original_graph_module (torch.fx.GraphModule): The original graph module. + compiled_functions (list[CompiledFunction]): A list of compiled functions derived from the subgraph. This will be a list with one function in case the graph was not split. + is_split (bool): Indicates whether the subgraph has been split. This happens if there was a thunder unsupported functionality. + split_reasons (list[SplitReason] | None): Optional list of reasons explaining why the subgraph was split. Present only if `is_split` is True. + split_graph_module (torch.fx.GraphModule | None): Optional. The graph module for the split subgraph. Present only if `is_split` is True. 
+ """ + + original_graph_module: torch.fx.GraphModule + compiled_functions: list[CompiledFunction] + is_split: bool + split_reasons: list | None = None + split_graph_module: torch.fx.GraphModule | None = None + + +def try_execute_symbol(thunder_symbol: "Symbol", node: torch.fx.Node) -> tuple[bool, SplitReason | None]: + """ + Attempts to execute a given Thunder symbol within a tracing context, using proxies for the node's arguments. + + This function operates within a Thunder tracing context to generate proxies for the provided node's arguments. + It then attempts to execute the Thunder symbol with these proxies. If any exceptions occur during proxy creation + or execution, it returns a tuple indicating failure and provides a `SplitReason` detailing the exception. + + Args: + thunder_symbol (Symbol): The Thunder symbol to be executed. This is expected to be a callable that can + operate on proxy arguments. + node (torch.fx.Node): The Torch FX node whose arguments are to be proxied and passed to the Thunder symbol. + + Returns: + tuple[bool, SplitReason | None]: A tuple where the first element is a boolean whether the execution passed or failed. + The second element is a `SplitReason` object if an error occurred, or `None` if the execution was successful. + """ + import thunder + from thunder.core.trace import TraceCtx + from thunder.core.proxies import proxy + + trc = TraceCtx() + # We need to be under trace context to generate proxies. + with thunder.core.trace.tracectx(trc): + try: + + def make_tensor_proxy(arg_node): + # This is a Node in the graph representing a Tensor. + if isinstance(arg_node, torch.fx.Node): + example_value = arg_node.meta["example_value"] + + # This fails if the shape of the FakeTensor contains SymInts. + return proxy(example_value) + + # This is int, float, etc. + # TODO(kshitij12345) - verify the above line for more cases. + return arg_node + + proxy_args = tuple(map(make_tensor_proxy, node.args)) + proxy_kwargs = {k: make_tensor_proxy(v) for k, v in node.kwargs.items()} + except Exception as e: + return False, SplitReason( + SplitReasonType.EXCEPTION_PROXY_THUNDER_OP, + f"Failed while creating proxy for node with name: {node.name} and target: {node.target}, see exception field", + exception=e, + ) + + try: + thunder_symbol(*proxy_args, **proxy_kwargs) + except Exception as e: + return False, SplitReason( + SplitReasonType.EXCEPTION_META_THUNDER_OP, + f"Failed while running meta for node with name: {node.name} and target: {node.target}, see exception field", + exception=e, + ) + + # Execution with proxies was successful. + return True, None + + +class ThunderOperatorSupport: + def __init__(self, gm): + self.gm = gm + self.unsupported_nodes = set() + self.find_unsupported_ctx_regions(gm) + self.split_reasons: list[SplitReason] = [] + + def find_unsupported_ctx_regions(self, gm): + """ + Finds the node within `autocast` or other supported context and marks them as unsupported. + Even though, thunder may support the operation within the reason, it doesn't correctly apply the change + triggered from the context. + """ + # NOTE - Currently only detects the autocast regions. + + ctx_cnt = 0 # Count of `enters_autocast` we have seen till now + + # We want to mark nodes with `_enter_autocast` and `_exit_autocast` + # as unsupported as `thunder` doesn't correctly deal with these stateful functions. 
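+        # The counter makes this robust to nested autocast blocks: a node is only flagged
+        # while at least one `_enter_autocast` is still open (i.e. ctx_cnt > 0).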
+ for node in gm.graph.nodes: + if node.op == "call_function" and node.target in (torch.amp.autocast_mode._enter_autocast,): + ctx_cnt += 1 + elif node.op == "call_function" and node.target in (torch.amp.autocast_mode._exit_autocast,): + ctx_cnt -= 1 + else: + if ctx_cnt > 0: + self.unsupported_nodes.add(node) + + def is_node_supported(self, submodules: Mapping[str, torch.nn.Module], node: torch.fx.Node): + """ + Determine whether thunder can execute the operation described by this node. + """ + # These are the nodes which are in unsupported context regions + if node in self.unsupported_nodes: + self.split_reasons.append( + SplitReason( + SplitReasonType.UNSUPPORTED_NODE, + info=f"node with name: {node.name} and target: {node.target} is not supported probably because it is in unsupported context.", + ) + ) + return False + + # Docs from the torch.fx.Node - https://pytorch.org/docs/stable/fx.html#torch.fx.Node + # Each Node has a function specified by its op property + # Below are the details for the ones this function is interested in - + # `call_function` applies a free function to some values. + # name is similarly the name of the value to assign to. + # target is the function to be applied. args and kwargs represent + # the arguments to the function, following the Python calling convention + # `call_method` calls a method on a value. + # name is as similar. + # target is the string name of the method to apply to the self argument. + # args and kwargs represent the arguments to invoke the module on, including the self argument + # + # NOTE: `call_module` should be inlined in dynamo graphs since https://github.com/pytorch/pytorch/pull/131275 + # But there is flag to disable inlining `call_module`. Determining `call_module` support would actually require calling `thunder.jit` on it. + # + # `call_module` applies a module in the module hierarchy’s forward() method to given arguments. + # name is as previous. target is the fully-qualified name of the module in the module hierarchy to call. + # args and kwargs represent the arguments to invoke the module on, excluding the self argument + + target = node.target # Target is the function to call. + if node.op == "call_method": + self_arg = node.args[0] + target = getattr(torch.Tensor, node.target, None) + assert target is not None, f"Failed to find method {node.target}" + + # If the operation has automatic registration, we mark it as unsupported as `inductor` might be + # able to deal with it better. + if target in auto_register_ops: + self.split_reasons.append( + SplitReason( + SplitReasonType.MISSING_OP_SUPPORT, + info=f"node with name: {node.name} and target: {node.target} only has an automatic torch fallback in thunder.", + ) + ) + return False + + # If thunder has a mapping for this operation, try executing the meta function and see. + # We have a symbol for `torch.where`, but we don't support one overload of it. + # So, we try and execute the meta to get a real signal. + # + # Regarding `inspect.isbuiltin`, dynamo graph uses `+`, `>` which are builtin `add`, `gt`. + # We try to proxify the arguments and call these operations on them to see if they are supported. + if target in _torch_to_thunder_function_map or inspect.isbuiltin(target): + if target in [torch.ones]: # Factory functions. (removing this will lead to split) + # NOTE - Factory functions don't work as the expect `_cache_info` to be populated + # with default dtype but `_cache_info` is only created and populated in `thunder.jit` path. 
+ # my_trc = TraceCtx() + # with tracectx(my_trc): + # thunder.torch.ones(3, 3) + return True + + thunder_symbol_or_builtin = _torch_to_thunder_function_map.get(target, target) + did_run, opt_split_reason = try_execute_symbol(thunder_symbol_or_builtin, node) + if opt_split_reason is not None: + self.split_reasons.append(opt_split_reason) + return did_run + + # We found no automatic fallback registration and no mapping to thunder symbol. + self.split_reasons.append( + SplitReason( + SplitReasonType.MISSING_OP_SUPPORT, + info=f"node with name: {node.name} and target: {node.target} didn't have any mapping in thunder.", + ) + ) + return False + + +def _all_graph_supported_by_thunder(gm: torch.fx.GraphModule, sample_input: list[torch.SymInt, torch.Tensor]) -> bool: + """ + Determine whether there is any thunder unsupported operation. + """ + # NOTE - Unused for now. + op_support = ThunderOperatorSupport(gm) + supported = True + for node in gm.graph.nodes: + if node.op in ["call_method", "call_function"]: + supported = op_support.is_node_supported(gm, node) + if not supported: + break + return supported + + class ThunderCompiler: def __init__(self, **thunder_options): """ @@ -36,27 +310,141 @@ def __init__(self, **thunder_options): ... return x - 1 >>> out = func(x) """ - from thunder import ThunderModule + from thunder import ThunderModule, jit _warn_thunder_compiler() # Thunder-compiled functions should be readily available for inspection - # and testing, so we will store them in a list. The order of the + # and testing, so we will store them in a list[SubgraphInfo]. The order of the # functions in the list will be the same as the order in which they were - # compiled. In addition, we will store a mapping from the ThunderModule - # to the GraphModule that was passed to ThunderCompiler. This will allow - # us to inspect the GraphModule that was compiled by Thunder. - self.thunder_fns: list[ThunderModule] = [] - self.thunder_to_gm: dict[ThunderModule, torch.fx.GraphModule] = {} + # compiled. + # Ref to the documentation of `SubgraphInfo` to know more about the information it contains. + self.subgraph_infos: list[SubgraphInfo] = [] self.thunder_options = thunder_options + self._thunder_jit = partial(jit, **thunder_options) + + def splitter( + self, gm: torch.fx.GraphModule, _unused_sample_args: list[torch.SymInt, torch.Tensor] + ) -> torch.fx.GraphModule: + """ + This method will split graph into multiple graph modules based on thunder supported operations. + This function will try to split the graph in contiguous partitions. + + Example: + # All operations are supported by thunder + class GraphModule(torch.nn.Module): + def forward(self, L_x_: "f32[2]"): + l_x_ = L_x_ + + y: "f32[2]" = torch.sin(l_x_) + matmul: "f32[]" = torch.matmul(l_x_, y); l_x_ = y = None + return (matmul,) - # TODO: There will be pieces of Dynamo IR that Thunder cannot compile, so we - # will need to build a fallback mechanism to handle those cases. - # Possible stages of the compilation that need to be saved for inspection: - # 1. The GraphModule as it was passed to ThunderCompiler. - # 2. The GraphModule after split for Thunder/PyTorch. - # 3. If the whole GraphModule is not supported, record the reasons why. + # Split Graph: All operations are supported by thunder, we will see only one partition. 
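+            # (The single partition `submod_1` below is thunder supported, so it is compiled with `thunder.jit`.)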
+ class GraphModule(torch.nn.Module): + def forward(self, l_x_: "f32[2]"): + submod_1 = self.submod_1(l_x_); l_x_ = None + return (submod_1,) + + class submod_1(torch.nn.Module): + def forward(self, l_x_: "f32[2]"): + y: "f32[2]" = torch.sin(l_x_) + matmul: "f32[]" = torch.matmul(l_x_, y); l_x_ = y = None + return matmul + + Example: + # With unsupported operation `sinc` + class GraphModule(torch.nn.Module): + def forward(self, L_x_: "f32[2]"): + l_x_ = L_x_ + + y: "f32[2]" = torch.sinc(l_x_) + + matmul: "f32[]" = torch.matmul(l_x_, y); l_x_ = y = None + return (matmul,) + + # Split Graph: Since `sinc` is unsupported, we will see two partitions, one for thunder and one for inductor. + class GraphModule(torch.nn.Module): + def forward(self, l_x_: "f32[2]"): + submod_1 = self.submod_1(l_x_) + submod_2 = self.submod_2(l_x_, submod_1); l_x_ = submod_1 = None + return (submod_2,) + + class submod_1(torch.nn.Module): # Partition for inductor + def forward(self, l_x_: "f32[2]"): + y: "f32[2]" = torch.sinc(l_x_); l_x_ = None + return y + + class submod_2(torch.nn.Module): # Partition for thunder + def forward(self, l_x_: "f32[2]", y: "f32[2]"): + matmul: "f32[]" = torch.matmul(l_x_, y); l_x_ = y = None + return matmul + """ + # Create an `ThunderOperatorSupport` instance which will be used in the callback. + # This will determine whether the operation represented by the node is supported by thunder. + operator_support = ThunderOperatorSupport(gm) + + # The callback below is called for every node in the graph. + # It returns an `int` denoting the parition where the node should be placed. + # We want to partition the graph into contiguous regions (with one or more operations) + # into thunder supported or unsupported region. + # `prev_value` is used to determine if we are still in same region (i.e. supported region or unsupported region). + # `partition_cnt` is bumped everytime we change the region i.e. flip from supported to unsupported or from unsupported to supported. + # `supported_partitions` is used to track the thunder supported partitions. + prev_value = None + partition_cnt = 0 + supported_partitions = set() + + def callback(node) -> int: + assert node.op not in ( + "placeholder", + "get_attr", + "output", + ), f"fx.split_module should have only passed node.op=call_* but received {node.op}" + nonlocal prev_value, partition_cnt + is_thunder_supported = operator_support.is_node_supported(gm, node) + if prev_value == is_thunder_supported: # We are in the same region. + return partition_cnt + + # There is a flip. Either from supported to unsupported or unsupported to supported. + prev_value = is_thunder_supported + partition_cnt += 1 # Bump the region cnt. + + if is_thunder_supported: + supported_partitions.add(partition_cnt) + return partition_cnt + + # `fx_split_module` iterates over nodes and determines the partition to place them based on the callback. + split_module: torch.fx.GraphModule = fx_split_module( + gm, root_m=None, split_callback=callback, keep_original_order=True, keep_original_node_name=True + ) + + def is_thunder_supported_partition(node: torch.fx.Node) -> bool: + return node.name.startswith("submod") and int(node.name.replace("submod_", "")) in supported_partitions + + # Call compile on the split region/s. 
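+        # Thunder supported partitions are swapped in-place (via `setattr`) with `thunder.jit` callables,
+        # while the remaining `submod_*` partitions fall back to `torch.compile(backend="inductor")`.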
+ comipled_fn = [] + for node in split_module.graph.nodes: + if is_thunder_supported_partition(node): + graph_module = getattr(split_module, node.name) + jit_fn = self._thunder_jit(graph_module) + setattr(split_module, node.name, jit_fn) + comipled_fn.append(CompiledFunction(graph_module, jit_fn, CompilerType.THUNDER)) + elif node.name.startswith("submod"): # For inductor + graph_module = getattr(split_module, node.name) + jit_fn = torch.compile(graph_module, backend="inductor") + setattr(split_module, node.name, jit_fn) + comipled_fn.append(CompiledFunction(graph_module, jit_fn, CompilerType.TORCH_INDUCTOR)) + else: + # Everything else is a glue code to call and pass outputs between the other partitions. + pass + + gm.print_readable() + # Append the details regarding this graph/subgraph. + self.subgraph_infos.append(SubgraphInfo(gm, comipled_fn, True, operator_support.split_reasons, split_module)) + split_module.print_readable() + return split_module def __call__(self, gm: torch.fx.GraphModule, sample_args: list[torch.SymInt, torch.Tensor]): from thunder import jit @@ -65,9 +453,17 @@ def __call__(self, gm: torch.fx.GraphModule, sample_args: list[torch.SymInt, tor # force recompilation of the GraphModule before passing it to Thunder. gm.real_recompile() - # Here in the future we could add some logic to check if the GraphModule - # is executable by Thunder, but for now we simply compile it and return - jitted_gm = jit(gm, **self.thunder_options) - self.thunder_fns.append(jitted_gm) - self.thunder_to_gm[jitted_gm] = gm - return jitted_gm + # Check if the complete graph `gm` is supported by thunder + # If yes, pass the whole `gm` to `thunder.jit` and return the compiled function. + # if is_graph_supported_by_thunder(gm, sample_args): + # jitted_gm = self.thunder_jit(gm) + # self.thunder_fns.append(jitted_gm) + # self.thunder_to_gm[jitted_gm] = gm + # compiled_fn = CompiledFunction(gm, jitted_gm, CompilerType.THUNDER) + # self.subgraph_infos.append(SubgraphInfo(gm, [compiled_fn], False)) + # return jitted_gm + + # The whole graph may not be supported by `thunder`, so we split it in `thunder` supported sections + # and unsupported sections which are passed to `torch.compile(backend='inductor')` + split_module = self.splitter(gm, sample_args) + return split_module diff --git a/thunder/dynamo/splitter_experiment.py b/thunder/dynamo/splitter_experiment.py new file mode 100644 index 0000000000..67fbb08218 --- /dev/null +++ b/thunder/dynamo/splitter_experiment.py @@ -0,0 +1,210 @@ +# CapabilityBasedParitioner returns a GraphModule where `fused_*` represent the subgraphs +# that should go to `thunder` and the forward of this graph module should be passed to `torch.compile` (after removing thunder bits) +# Example - +# class GraphModule(torch.nn.Module): +# def forward(self, L_x_: "f32[2]"): +# l_x_ = L_x_ + +# # No stacktrace found for following nodes +# _enter_autocast = torch.amp.autocast_mode._enter_autocast('cpu', None, True, None) + +# # File: /home/kkalambarkar/lightning-thunder/scratchpad/test.py:181 in func, code: return torch.matmul(x, y) +# fused_0: "bf16[]" = self.fused_0(l_x_); l_x_ = None + +# # No stacktrace found for following nodes +# _exit_autocast = torch.amp.autocast_mode._exit_autocast(_enter_autocast); _enter_autocast = _exit_autocast = None +# return (fused_0,) + +# class fused_0(torch.nn.Module): +# def forward(self, l_x_: "f32[2]"): +# # File: /home/kkalambarkar/lightning-thunder/scratchpad/test.py:177 in func, code: x = x + 2 +# x: "f32[2]" = l_x_ + 2; l_x_ = None + +# # 
File: /home/kkalambarkar/lightning-thunder/scratchpad/test.py:179 in func, code: z = torch.ones(3, 3) +# z: "f32[3, 3]" = torch.ones(3, 3); z = None + +# # File: /home/kkalambarkar/lightning-thunder/scratchpad/test.py:180 in func, code: y = torch.sin(x) +# y: "f32[2]" = torch.sin(x) + +# # File: /home/kkalambarkar/lightning-thunder/scratchpad/test.py:181 in func, code: return torch.matmul(x, y) +# matmul: "bf16[]" = torch.matmul(x, y); x = y = None +# return matmul + + +def capability_partitioner_splitter(gm, sample_args): + gm_copy = copy.deepcopy(gm) + op_support = ThunderOperatorSupport(gm_copy) + partitioner = CapabilityBasedPartitioner(gm_copy, op_support) + fused_partition = partitioner.partition_and_fuse() + gm_copy.print_readable() + return gm_copy + + +# Splitter with _SplitterBase. + +# class GraphModule(torch.nn.Module): +# def forward(self, L_x_: "f32[2]"): +# l_x_ = L_x_ + +# # File: /home/kkalambarkar/lightning-thunder/scratchpad/test.py:176 in func, code: x = x + 2 +# x: "f32[2]" = l_x_ + 2; l_x_ = None + +# # No stacktrace found for following nodes +# _enter_autocast = torch.amp.autocast_mode._enter_autocast('cpu', None, True, None) + +# # File: /home/kkalambarkar/lightning-thunder/scratchpad/test.py:178 in func, code: y = torch.sin(x) +# y: "f32[2]" = torch.sin(x) + +# # File: /home/kkalambarkar/lightning-thunder/scratchpad/test.py:179 in func, code: return torch.matmul(x, y) +# matmul: "bf16[]" = torch.matmul(x, y); x = y = None + +# # No stacktrace found for following nodes +# _exit_autocast = torch.amp.autocast_mode._exit_autocast(_enter_autocast); _enter_autocast = _exit_autocast = None +# return (matmul,) + +# Got 1 acc subgraphs and 2 non-acc subgraphs +# class GraphModule(torch.nn.Module): +# def forward(self, l_x_: "f32[2]"): +# # No stacktrace found for following nodes +# _run_on_cpu_0 = self._run_on_cpu_0(); _run_on_cpu_0 = None +# _run_on_acc_1 = self._run_on_acc_1(l_x_); l_x_ = None +# _run_on_cpu_2 = self._run_on_cpu_2(_run_on_acc_1); _run_on_acc_1 = None +# return (_run_on_cpu_2,) + +# class _run_on_cpu_0(torch.nn.Module): +# def forward(self): +# # No stacktrace found for following nodes +# _enter_autocast = torch.amp.autocast_mode._enter_autocast('cpu', None, True, None) +# _exit_autocast = torch.amp.autocast_mode._exit_autocast(_enter_autocast); _enter_autocast = _exit_autocast = None +# return () + +# class _run_on_acc_1(torch.nn.Module): +# def forward(self, l_x_: "f32[2]"): +# # File: /home/kkalambarkar/lightning-thunder/scratchpad/test.py:176 in func, code: x = x + 2 +# x: "f32[2]" = l_x_ + 2; l_x_ = None +# return x + +# class _run_on_cpu_2(torch.nn.Module): +# def forward(self, x: "f32[2]"): +# # File: /home/kkalambarkar/lightning-thunder/scratchpad/test.py:178 in func, code: y = torch.sin(x) +# y: "f32[2]" = torch.sin(x) + +# # File: /home/kkalambarkar/lightning-thunder/scratchpad/test.py:179 in func, code: return torch.matmul(x, y) +# matmul: "bf16[]" = torch.matmul(x, y); x = y = None +# return matmul + + +class GraphModuleSplitter(torch.fx.passes.splitter_base._SplitterBase): + def starter_nodes(self): + """ + Finds nodes that consume module inputs or get_attr nodes. 
+ """ + starter_cpu_nodes: NodeSet = set() + starter_acc_nodes: NodeSet = set() + + for node in self.module.graph.nodes: + if node.op not in {"placeholder", "get_attr"}: + continue + for user in node.users: + if user in self.acc_nodes: + starter_acc_nodes.add(user) + else: + starter_cpu_nodes.add(user) + + for node in self.module.graph.nodes: + if node.op in {"output", "placeholder", "get_attr"}: + continue + + if len(self.deps[node]) == 0: + if node in self.acc_nodes: + starter_acc_nodes.add(node) + else: + starter_cpu_nodes.add(node) + + return starter_cpu_nodes, starter_acc_nodes + + +def splitter(self, gm, sample_input): + """ + This function splits the graph provided by Dynamo + if it contains any operation or construct that is not supported by thunder. + For the unsupported subgraph, it is passed to inductor. + """ + from thunder import jit + + # Setup the splitter class + settings = torch.fx.passes.splitter_base._SplitterSettingBase(allow_non_tensor=True) + splitter = GraphModuleSplitter(gm, sample_input, operator_support=ThunderOperatorSupport(gm), settings=settings) + gm.print_readable() + # Call the splitter to split GraphModule. + split_module = splitter() + split_module.print_readable() + compiled_funcs = [] + for node in split_module.graph.nodes: + if node.name.startswith("_run_on_acc_"): + graph_module = getattr(split_module, node.name) + jit_fn = self.thunder_jit(graph_module) + setattr(split_module, node.name, jit_fn) + compiled_funcs.append(jit_fn) + if node.name.startswith("_run_on_cpu_") or node.name.startswith("_run_on_gpu_"): + graph_module = getattr(split_module, node.name) + jit_fn = torch.compile(graph_module, backend="inductor") + setattr(split_module, node.name, jit_fn) + compiled_funcs.append(jit_fn) + + self.subgraph_infos.append(SubgraphInfo(gm, True, compiled_funcs, [], split_module)) + # split_module.print_readable() + return split_module + + +# With the current, approach +# Original Graph +# class GraphModule(torch.nn.Module): +# def forward(self, L_x_: "f32[2]"): +# l_x_ = L_x_ + +# # File: /home/kkalambarkar/lightning-thunder/scratchpad/test.py:176 in func, code: x = x + 2 +# x: "f32[2]" = l_x_ + 2; l_x_ = None + +# # No stacktrace found for following nodes +# _enter_autocast = torch.amp.autocast_mode._enter_autocast('cpu', None, True, None) + +# # File: /home/kkalambarkar/lightning-thunder/scratchpad/test.py:178 in func, code: y = torch.sin(x) +# y: "f32[2]" = torch.sin(x) + +# # File: /home/kkalambarkar/lightning-thunder/scratchpad/test.py:179 in func, code: return torch.matmul(x, y) +# matmul: "bf16[]" = torch.matmul(x, y); x = y = None + +# # No stacktrace found for following nodes +# _exit_autocast = torch.amp.autocast_mode._exit_autocast(_enter_autocast); _enter_autocast = _exit_autocast = None +# return (matmul,)` +# + +# Split Graph +# class GraphModule(torch.nn.Module): +# def forward(self, l_x_: "f32[2]"): +# # No stacktrace found for following nodes +# submod_1 = self.submod_1(l_x_); l_x_ = None +# submod_2 = self.submod_2(submod_1); submod_1 = None +# return (submod_2,) + +# class submod_1(torch.nn.Module): +# def forward(self, l_x_: "f32[2]"): +# # File: /home/kkalambarkar/lightning-thunder/scratchpad/test.py:176 in func, code: x = x + 2 +# x: "f32[2]" = l_x_ + 2; l_x_ = None +# return x + +# class submod_2(torch.nn.Module): +# def forward(self, x: "f32[2]"): +# # No stacktrace found for following nodes +# _enter_autocast = torch.amp.autocast_mode._enter_autocast('cpu', None, True, None) + +# # File: 
/home/kkalambarkar/lightning-thunder/scratchpad/test.py:178 in func, code: y = torch.sin(x) +# y: "f32[2]" = torch.sin(x) + +# # File: /home/kkalambarkar/lightning-thunder/scratchpad/test.py:179 in func, code: return torch.matmul(x, y) +# matmul: "bf16[]" = torch.matmul(x, y); x = y = None + +# # No stacktrace found for following nodes +# _exit_autocast = torch.amp.autocast_mode._exit_autocast(_enter_autocast); _enter_autocast = _exit_autocast = None +# return matmul diff --git a/thunder/tests/test_dynamo.py b/thunder/tests/test_dynamo.py index bd6f26abe5..3e117655fc 100644 --- a/thunder/tests/test_dynamo.py +++ b/thunder/tests/test_dynamo.py @@ -2,6 +2,7 @@ from thunder.tests.framework import instantiate, NOTHING, DynamoThunderExecutor from thunder import dtypes from thunder.dynamo import ThunderCompiler +from thunder.dynamo.compiler import CompilerType from thunder import last_traces import torch @@ -32,9 +33,118 @@ def func(x): assert out.grad_fn.name() == "ThunderFunctionBackward" # We record the GraphModules that was compiled by ThunderCompiler - assert len(backend.thunder_to_gm) == 2 - thunder_func, gm = list(backend.thunder_to_gm.items())[0] - assert isinstance(gm, torch.fx.GraphModule) + assert len(backend.subgraph_infos) == 2 - # This shouldn't be empty - assert last_traces(thunder_func) + subgraph_info = backend.subgraph_infos[0] + if dynamic: + assert len(subgraph_info.compiled_functions) == 2 # Due to Symint! + else: + assert len(subgraph_info.compiled_functions) == 1 + idx = 1 if dynamic else 0 + compiled_fn_info = subgraph_info.compiled_functions[idx] + assert compiled_fn_info.compiler == CompilerType.THUNDER + assert last_traces(compiled_fn_info.compiled_fn) + assert isinstance(compiled_fn_info.original_graph_module, torch.fx.GraphModule) + + +@instantiate( + dtypes=NOTHING, + executors=[DynamoThunderExecutor], + decorators=(pytest.mark.parametrize("dynamic", (True, False, None), ids=("dynamic", "static", "auto")),), +) +def test_basic_splitter(executor, device: str, dtype: dtypes.dtype, dynamic: bool | None): + x = torch.ones(2, 2, device=device, dtype=dtype, requires_grad=True) + + backend = ThunderCompiler() + + def func(x): + # torch.sinc has automatic fallback registered. 
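+        # so this graph is expected to split: `exp` and `cos` stay with thunder while `sinc`
+        # is handed to inductor (see the `split_reasons` assertions below).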
+ x = x.exp() + y = torch.sinc(x) + torch.cos(x) + return y + 1 + + cfunc = torch.compile(func, backend=backend, dynamic=dynamic) + expected = torch.compile(func, dynamic=False)(x) + actual = cfunc(x) + + g = torch.rand_like(actual) + torch.testing.assert_close(actual, expected) + actual_grad = torch.autograd.grad(actual, x, g) + expected_grad = torch.autograd.grad(expected, x, g) + torch.testing.assert_close(actual_grad, expected_grad) + + assert len(backend.subgraph_infos) == 1 + assert backend.subgraph_infos[0].is_split + assert any( + "automatic torch fallback" in split_reason.info for split_reason in backend.subgraph_infos[0].split_reasons + ) + + +@instantiate( + dtypes=NOTHING, + executors=[DynamoThunderExecutor], + decorators=(pytest.mark.parametrize("dynamic", (True, False, None), ids=("dynamic", "static", "auto")),), +) +def test_splitter_unsupported_ctx(executor, device: str, dtype: dtypes.dtype, dynamic: bool | None): + x = torch.rand(2, 2, device=device, dtype=dtype, requires_grad=True) + + backend = ThunderCompiler() + + def func(x): + x = x + 2 + with torch.autocast("cpu"): + y = torch.log(x) + return torch.matmul(x, y) + + expected = torch.compile(func, dynamic=False)(x) + + cfunc = torch.compile(func, backend=backend, dynamic=dynamic) + actual = cfunc(x) + + g = torch.rand_like(actual) + torch.testing.assert_close(actual, expected) + actual_grad = torch.autograd.grad(actual, x, g) + expected_grad = torch.autograd.grad(expected, x, g) + torch.testing.assert_close(actual_grad, expected_grad) + + assert len(backend.subgraph_infos) == 1 + assert backend.subgraph_infos[0].is_split + assert any( + "didn't have any mapping in thunder" in split_reason.info + for split_reason in backend.subgraph_infos[0].split_reasons + ) + + +@instantiate( + dtypes=NOTHING, + executors=[DynamoThunderExecutor], + decorators=(pytest.mark.parametrize("dynamic", (True, False, None), ids=("dynamic", "static", "auto")),), +) +def test_splitter_unsupported_ctx_with_graph_break(executor, device: str, dtype: dtypes.dtype, dynamic: bool | None): + x = torch.rand(2, 2, device=device, dtype=dtype, requires_grad=True) + + backend = ThunderCompiler() + + def func(x): + x = x + 2 + with torch.autocast("cpu"): + y = torch.sin(x) + torch._dynamo.graph_break() + return torch.matmul(x, y) + + expected = torch.compile(func, dynamic=False)(x) + cfunc = torch.compile(func, backend=backend, dynamic=dynamic) + actual = cfunc(x) + + g = torch.rand_like(actual) + torch.testing.assert_close(actual, expected) + actual_grad = torch.autograd.grad(actual, x, g) + expected_grad = torch.autograd.grad(expected, x, g) + torch.testing.assert_close(actual_grad, expected_grad) + + assert len(backend.subgraph_infos) == 2 + for subgraph_info in backend.subgraph_infos: + assert subgraph_info.is_split + assert any( + "didn't have any mapping in thunder" in split_reason.info for split_reason in subgraph_info.split_reasons + ) From ea3de2bffcadc39e3a0041be39705778dd056dbc Mon Sep 17 00:00:00 2001 From: kshitij12345 Date: Mon, 26 Aug 2024 14:06:04 +0200 Subject: [PATCH 2/9] failure with dynamic=True --- thunder/tests/test_dynamo.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/thunder/tests/test_dynamo.py b/thunder/tests/test_dynamo.py index 3e117655fc..8dd8842245 100644 --- a/thunder/tests/test_dynamo.py +++ b/thunder/tests/test_dynamo.py @@ -30,7 +30,11 @@ def func(x): # out should have grad_fn and its name should be ThunderFunctionBackward assert out.grad_fn is not None - assert out.grad_fn.name() == 
"ThunderFunctionBackward" + if not dynamic: + # If dynamic, while trying to execute `x + 1`, we fail with + # "s0 had an unexpected type . Supported types are (, )") + # as the FakeTensor for `x` has shape with SymInt. + assert out.grad_fn.name() == "ThunderFunctionBackward" # We record the GraphModules that was compiled by ThunderCompiler assert len(backend.subgraph_infos) == 2 From a5158d05e25a8d0f6eb64a5ac049433c75fd64b6 Mon Sep 17 00:00:00 2001 From: kshitij12345 Date: Thu, 29 Aug 2024 11:41:53 +0200 Subject: [PATCH 3/9] address review : part 1 --- thunder/dynamo/compiler.py | 341 ++++++----------------------------- thunder/dynamo/utils.py | 262 +++++++++++++++++++++++++++ thunder/tests/test_dynamo.py | 32 ++-- 3 files changed, 328 insertions(+), 307 deletions(-) create mode 100644 thunder/dynamo/utils.py diff --git a/thunder/dynamo/compiler.py b/thunder/dynamo/compiler.py index 86e5f46214..7f05c81282 100644 --- a/thunder/dynamo/compiler.py +++ b/thunder/dynamo/compiler.py @@ -1,25 +1,24 @@ -from enum import Enum, auto -import dataclasses -from typing import List, Dict, Optional, Tuple +from typing import List, Dict, Optional, Tuple, Set from collections.abc import Callable import pprint -import itertools -import copy from functools import partial -import operator -import inspect import torch -from torch.fx.passes.split_module import split_module as fx_split_module +from torch.fx.passes.split_module import split_module import warnings from collections.abc import Mapping -from torch.fx.passes import operator_support from thunder.core.baseutils import run_once -from thunder.torch.default_torch_ops import torch_auto_registered_ops -from thunder.torch import _torch_to_thunder_function_map -auto_register_ops = set(itertools.chain(*torch_auto_registered_ops.values())) +from thunder.dynamo.utils import ( + SubgraphInfo, + CompiledFunction, + CompilerType, + SplitReason, + SplitReasonType, + is_node_supported, + get_nodes_in_unsupported_ctx_regions, +) @run_once @@ -30,262 +29,6 @@ def _warn_thunder_compiler(): ) -class CompilerType(Enum): - """ - An enumeration representing different types of compilers. - """ - - THUNDER = auto() - TORCH_INDUCTOR = auto() - - -@dataclasses.dataclass -class CompiledFunction: - """ - A dataclass representing a compiled function along with its original graph module and compiler type. - - Attributes: - original_graph_module (torch.fx.GraphModule): The original graph module from which the function is compiled. - compiled_fn (Callable): The compiled function. - compiler (CompilerType): The type of compiler used to compile the function. - """ - - original_graph_module: torch.fx.GraphModule - compiled_fn: Callable - compiler: CompilerType - - -class SplitReasonType(Enum): - """ - An enumeration representing different reasons for split in the graph. - """ - - UNSUPPORTED_NODE = auto() - MISSING_OP_SUPPORT = auto() - EXCEPTION_PROXY_THUNDER_OP = auto() - EXCEPTION_META_THUNDER_OP = auto() - - -@dataclasses.dataclass -class SplitReason: - """ - A dataclass containing information about a split. - - Attributes: - type (SplitReasonType): Reason for the split. - info (str): String with details of what caused the split. - exception (Exception | None): Exception if there was any. - """ - - type: SplitReasonType - info: str | None - exception: Exception | None = None - - -@dataclasses.dataclass -class SubgraphInfo: - """ - A dataclass containing information about a subgraph. - - Attributes: - original_graph_module (torch.fx.GraphModule): The original graph module. 
- compiled_functions (list[CompiledFunction]): A list of compiled functions derived from the subgraph. This will be a list with one function in case the graph was not split. - is_split (bool): Indicates whether the subgraph has been split. This happens if there was a thunder unsupported functionality. - split_reasons (list[SplitReason] | None): Optional list of reasons explaining why the subgraph was split. Present only if `is_split` is True. - split_graph_module (torch.fx.GraphModule | None): Optional. The graph module for the split subgraph. Present only if `is_split` is True. - """ - - original_graph_module: torch.fx.GraphModule - compiled_functions: list[CompiledFunction] - is_split: bool - split_reasons: list | None = None - split_graph_module: torch.fx.GraphModule | None = None - - -def try_execute_symbol(thunder_symbol: "Symbol", node: torch.fx.Node) -> tuple[bool, SplitReason | None]: - """ - Attempts to execute a given Thunder symbol within a tracing context, using proxies for the node's arguments. - - This function operates within a Thunder tracing context to generate proxies for the provided node's arguments. - It then attempts to execute the Thunder symbol with these proxies. If any exceptions occur during proxy creation - or execution, it returns a tuple indicating failure and provides a `SplitReason` detailing the exception. - - Args: - thunder_symbol (Symbol): The Thunder symbol to be executed. This is expected to be a callable that can - operate on proxy arguments. - node (torch.fx.Node): The Torch FX node whose arguments are to be proxied and passed to the Thunder symbol. - - Returns: - tuple[bool, SplitReason | None]: A tuple where the first element is a boolean whether the execution passed or failed. - The second element is a `SplitReason` object if an error occurred, or `None` if the execution was successful. - """ - import thunder - from thunder.core.trace import TraceCtx - from thunder.core.proxies import proxy - - trc = TraceCtx() - # We need to be under trace context to generate proxies. - with thunder.core.trace.tracectx(trc): - try: - - def make_tensor_proxy(arg_node): - # This is a Node in the graph representing a Tensor. - if isinstance(arg_node, torch.fx.Node): - example_value = arg_node.meta["example_value"] - - # This fails if the shape of the FakeTensor contains SymInts. - return proxy(example_value) - - # This is int, float, etc. - # TODO(kshitij12345) - verify the above line for more cases. - return arg_node - - proxy_args = tuple(map(make_tensor_proxy, node.args)) - proxy_kwargs = {k: make_tensor_proxy(v) for k, v in node.kwargs.items()} - except Exception as e: - return False, SplitReason( - SplitReasonType.EXCEPTION_PROXY_THUNDER_OP, - f"Failed while creating proxy for node with name: {node.name} and target: {node.target}, see exception field", - exception=e, - ) - - try: - thunder_symbol(*proxy_args, **proxy_kwargs) - except Exception as e: - return False, SplitReason( - SplitReasonType.EXCEPTION_META_THUNDER_OP, - f"Failed while running meta for node with name: {node.name} and target: {node.target}, see exception field", - exception=e, - ) - - # Execution with proxies was successful. - return True, None - - -class ThunderOperatorSupport: - def __init__(self, gm): - self.gm = gm - self.unsupported_nodes = set() - self.find_unsupported_ctx_regions(gm) - self.split_reasons: list[SplitReason] = [] - - def find_unsupported_ctx_regions(self, gm): - """ - Finds the node within `autocast` or other supported context and marks them as unsupported. 
- Even though, thunder may support the operation within the reason, it doesn't correctly apply the change - triggered from the context. - """ - # NOTE - Currently only detects the autocast regions. - - ctx_cnt = 0 # Count of `enters_autocast` we have seen till now - - # We want to mark nodes with `_enter_autocast` and `_exit_autocast` - # as unsupported as `thunder` doesn't correctly deal with these stateful functions. - for node in gm.graph.nodes: - if node.op == "call_function" and node.target in (torch.amp.autocast_mode._enter_autocast,): - ctx_cnt += 1 - elif node.op == "call_function" and node.target in (torch.amp.autocast_mode._exit_autocast,): - ctx_cnt -= 1 - else: - if ctx_cnt > 0: - self.unsupported_nodes.add(node) - - def is_node_supported(self, submodules: Mapping[str, torch.nn.Module], node: torch.fx.Node): - """ - Determine whether thunder can execute the operation described by this node. - """ - # These are the nodes which are in unsupported context regions - if node in self.unsupported_nodes: - self.split_reasons.append( - SplitReason( - SplitReasonType.UNSUPPORTED_NODE, - info=f"node with name: {node.name} and target: {node.target} is not supported probably because it is in unsupported context.", - ) - ) - return False - - # Docs from the torch.fx.Node - https://pytorch.org/docs/stable/fx.html#torch.fx.Node - # Each Node has a function specified by its op property - # Below are the details for the ones this function is interested in - - # `call_function` applies a free function to some values. - # name is similarly the name of the value to assign to. - # target is the function to be applied. args and kwargs represent - # the arguments to the function, following the Python calling convention - # `call_method` calls a method on a value. - # name is as similar. - # target is the string name of the method to apply to the self argument. - # args and kwargs represent the arguments to invoke the module on, including the self argument - # - # NOTE: `call_module` should be inlined in dynamo graphs since https://github.com/pytorch/pytorch/pull/131275 - # But there is flag to disable inlining `call_module`. Determining `call_module` support would actually require calling `thunder.jit` on it. - # - # `call_module` applies a module in the module hierarchy’s forward() method to given arguments. - # name is as previous. target is the fully-qualified name of the module in the module hierarchy to call. - # args and kwargs represent the arguments to invoke the module on, excluding the self argument - - target = node.target # Target is the function to call. - if node.op == "call_method": - self_arg = node.args[0] - target = getattr(torch.Tensor, node.target, None) - assert target is not None, f"Failed to find method {node.target}" - - # If the operation has automatic registration, we mark it as unsupported as `inductor` might be - # able to deal with it better. - if target in auto_register_ops: - self.split_reasons.append( - SplitReason( - SplitReasonType.MISSING_OP_SUPPORT, - info=f"node with name: {node.name} and target: {node.target} only has an automatic torch fallback in thunder.", - ) - ) - return False - - # If thunder has a mapping for this operation, try executing the meta function and see. - # We have a symbol for `torch.where`, but we don't support one overload of it. - # So, we try and execute the meta to get a real signal. - # - # Regarding `inspect.isbuiltin`, dynamo graph uses `+`, `>` which are builtin `add`, `gt`. 
- # We try to proxify the arguments and call these operations on them to see if they are supported. - if target in _torch_to_thunder_function_map or inspect.isbuiltin(target): - if target in [torch.ones]: # Factory functions. (removing this will lead to split) - # NOTE - Factory functions don't work as the expect `_cache_info` to be populated - # with default dtype but `_cache_info` is only created and populated in `thunder.jit` path. - # my_trc = TraceCtx() - # with tracectx(my_trc): - # thunder.torch.ones(3, 3) - return True - - thunder_symbol_or_builtin = _torch_to_thunder_function_map.get(target, target) - did_run, opt_split_reason = try_execute_symbol(thunder_symbol_or_builtin, node) - if opt_split_reason is not None: - self.split_reasons.append(opt_split_reason) - return did_run - - # We found no automatic fallback registration and no mapping to thunder symbol. - self.split_reasons.append( - SplitReason( - SplitReasonType.MISSING_OP_SUPPORT, - info=f"node with name: {node.name} and target: {node.target} didn't have any mapping in thunder.", - ) - ) - return False - - -def _all_graph_supported_by_thunder(gm: torch.fx.GraphModule, sample_input: list[torch.SymInt, torch.Tensor]) -> bool: - """ - Determine whether there is any thunder unsupported operation. - """ - # NOTE - Unused for now. - op_support = ThunderOperatorSupport(gm) - supported = True - for node in gm.graph.nodes: - if node.op in ["call_method", "call_function"]: - supported = op_support.is_node_supported(gm, node) - if not supported: - break - return supported - - class ThunderCompiler: def __init__(self, **thunder_options): """ @@ -381,10 +124,6 @@ def forward(self, l_x_: "f32[2]", y: "f32[2]"): matmul: "f32[]" = torch.matmul(l_x_, y); l_x_ = y = None return matmul """ - # Create an `ThunderOperatorSupport` instance which will be used in the callback. - # This will determine whether the operation represented by the node is supported by thunder. - operator_support = ThunderOperatorSupport(gm) - # The callback below is called for every node in the graph. # It returns an `int` denoting the parition where the node should be placed. # We want to partition the graph into contiguous regions (with one or more operations) @@ -394,7 +133,10 @@ def forward(self, l_x_: "f32[2]", y: "f32[2]"): # `supported_partitions` is used to track the thunder supported partitions. prev_value = None partition_cnt = 0 - supported_partitions = set() + supported_partitions: set[int] = set() + split_reasons: list[SplitReason] = [] + + nodes_in_unsupported_ctx_regions = get_nodes_in_unsupported_ctx_regions(gm) def callback(node) -> int: assert node.op not in ( @@ -403,7 +145,19 @@ def callback(node) -> int: "output", ), f"fx.split_module should have only passed node.op=call_* but received {node.op}" nonlocal prev_value, partition_cnt - is_thunder_supported = operator_support.is_node_supported(gm, node) + + if node in nodes_in_unsupported_ctx_regions: + is_thunder_supported = False + split_reason = SplitReason( + SplitReasonType.UNSUPPORTED_NODE, + info=f"node with name: {node.name} and target: {node.target} is not supported probably because it is in unsupported context.", + ) + split_reasons.append(split_reason) + else: + is_thunder_supported, split_reason = is_node_supported(node) + if split_reason is not None: + split_reasons.append(split_reason) + if prev_value == is_thunder_supported: # We are in the same region. 
return partition_cnt @@ -416,7 +170,7 @@ def callback(node) -> int: return partition_cnt # `fx_split_module` iterates over nodes and determines the partition to place them based on the callback. - split_module: torch.fx.GraphModule = fx_split_module( + split_gm: torch.fx.GraphModule = split_module( gm, root_m=None, split_callback=callback, keep_original_order=True, keep_original_node_name=True ) @@ -424,27 +178,40 @@ def is_thunder_supported_partition(node: torch.fx.Node) -> bool: return node.name.startswith("submod") and int(node.name.replace("submod_", "")) in supported_partitions # Call compile on the split region/s. - comipled_fn = [] - for node in split_module.graph.nodes: + thunder_compiled_fns = [] + submodule_to_compiled_fns = {} + is_split = False + for node in split_gm.graph.nodes: if is_thunder_supported_partition(node): - graph_module = getattr(split_module, node.name) + # there is erase method on GraphModule + graph_module = getattr(split_gm, node.name) jit_fn = self._thunder_jit(graph_module) - setattr(split_module, node.name, jit_fn) - comipled_fn.append(CompiledFunction(graph_module, jit_fn, CompilerType.THUNDER)) + setattr(split_gm, node.name, jit_fn) + thunder_compiled_fns.append(jit_fn) + submodule_to_compiled_fns[graph_module] = CompiledFunction(jit_fn, CompilerType.THUNDER) elif node.name.startswith("submod"): # For inductor - graph_module = getattr(split_module, node.name) + graph_module = getattr(split_gm, node.name) jit_fn = torch.compile(graph_module, backend="inductor") - setattr(split_module, node.name, jit_fn) - comipled_fn.append(CompiledFunction(graph_module, jit_fn, CompilerType.TORCH_INDUCTOR)) + setattr(split_gm, node.name, jit_fn) + submodule_to_compiled_fns[graph_module] = CompiledFunction(jit_fn, CompilerType.TORCH_INDUCTOR) + is_split = True else: # Everything else is a glue code to call and pass outputs between the other partitions. pass - gm.print_readable() + # gm.print_readable() # Append the details regarding this graph/subgraph. - self.subgraph_infos.append(SubgraphInfo(gm, comipled_fn, True, operator_support.split_reasons, split_module)) - split_module.print_readable() - return split_module + self.subgraph_infos.append( + SubgraphInfo( + gm, + split_gm, + thunder_compiled_fns, + submodule_to_compiled_fns, + split_reasons, + ) + ) + # split_gm.print_readable() + return split_gm def __call__(self, gm: torch.fx.GraphModule, sample_args: list[torch.SymInt, torch.Tensor]): from thunder import jit diff --git a/thunder/dynamo/utils.py b/thunder/dynamo/utils.py new file mode 100644 index 0000000000..527f4df9e1 --- /dev/null +++ b/thunder/dynamo/utils.py @@ -0,0 +1,262 @@ +from enum import Enum, auto +import dataclasses +from typing import List, Dict, Optional, Tuple, Set +from collections.abc import Callable +import pprint +import itertools +import copy +import inspect + +import torch +from torch.fx.passes.split_module import split_module +import warnings +from collections.abc import Mapping + +from thunder.torch.default_torch_ops import torch_auto_registered_ops +from thunder.torch import _torch_to_thunder_function_map + + +auto_register_ops = set(itertools.chain(*torch_auto_registered_ops.values())) + + +class CompilerType(Enum): + """ + An enumeration representing different types of compilers. + """ + + THUNDER = auto() + TORCH_INDUCTOR = auto() + + +@dataclasses.dataclass +class CompiledFunction: + """ + A dataclass representing a compiled function along with its original graph module and compiler type. 
+ + Attributes: + compiled_fn (Callable): The compiled function. + compiler (CompilerType): The type of compiler used to compile the function. + """ + + compiled_fn: Callable + compiler: CompilerType + + +class SplitReasonType(Enum): + """ + An enumeration representing different reasons for split in the graph. + """ + + UNSUPPORTED_NODE = auto() + MISSING_OP_SUPPORT = auto() + EXCEPTION_PROXY_THUNDER_OP = auto() + EXCEPTION_META_THUNDER_OP = auto() + + +@dataclasses.dataclass +class SplitReason: + """ + A dataclass containing information about a split. + + Attributes: + type (SplitReasonType): Reason for the split. + info (str): String with details of what caused the split. + exception (Exception | None): Exception if there was any. + """ + + type: SplitReasonType + info: str | None + exception: Exception | None = None + + +@dataclasses.dataclass +class SubgraphInfo: + """ + A dataclass containing information about a subgraph. + + Attributes: + original_graph_module (torch.fx.GraphModule): The original graph module. + split_graph_module (torch.fx.GraphModule): Optional. The graph module for the split subgraph. + thunder_compiled_fns (list[Callable]): List of thunder optimized callables. This could be None if there the graph module was not supported by thunder. Look at the `split_reasons` for further information. + compiled_functions (list[CompiledFunction]): A list of compiled functions derived from the subgraph. This will be a list with one function in case the graph was not split. + split_reasons (list[SplitReason] | None): Optional list of reasons explaining why the subgraph was split. Present only if there are was a split. + """ + + original_graph_module: torch.fx.GraphModule + split_graph_module: torch.fx.GraphModule + thunder_compiled_fns: list[Callable] + submodule_to_compiled_functions: Mapping[torch.fx.GraphModule, CompiledFunction] + split_reasons: list | None = None + + +def try_execute_symbol(thunder_symbol: "Symbol", node: torch.fx.Node) -> tuple[bool, SplitReason | None]: + """ + Attempts to execute a given Thunder symbol within a tracing context, using proxies for the node's arguments. + + This function operates within a Thunder tracing context to generate proxies for the provided node's arguments. + It then attempts to execute the Thunder symbol with these proxies. If any exceptions occur during proxy creation + or execution, it returns a tuple indicating failure and provides a `SplitReason` detailing the exception. + + Args: + thunder_symbol (Symbol): The Thunder symbol to be executed. This is expected to be a callable that can + operate on proxy arguments. + node (torch.fx.Node): The Torch FX node whose arguments are to be proxied and passed to the Thunder symbol. + + Returns: + tuple[bool, SplitReason | None]: A tuple where the first element is a boolean whether the execution passed or failed. + The second element is a `SplitReason` object if an error occurred, or `None` if the execution was successful. + """ + import thunder + from thunder.core.trace import TraceCtx + from thunder.core.proxies import proxy + + @thunder._with_cache_info_ctx + def _run_with_cache_info(): + + # We need cache info here as the default dtype and device support + # for factory functions like ones, zeros, etc expects these details to be present. + # TODO: Move this to CompileData as well? + # This details are in cache info because `jit_ext.py` + # adds checks in prologue for the details which are present in here. 
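+        # Populate those defaults here as well so that factory-style calls resolve their
+        # dtype/device while tracing this single node in isolation.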
+ cache_info = thunder._get_cache_info() + cache_info["default_dtype"] = torch.get_default_dtype() + cache_info["default_device"] = torch.get_default_device() + + trc = TraceCtx() + # We need to be under trace context to generate proxies. + with thunder.core.trace.tracectx(trc): + try: + + def make_tensor_proxy(arg_node): + # This is a Node in the graph representing a Tensor. + if isinstance(arg_node, torch.fx.Node): + example_value = arg_node.meta["example_value"] + + if isinstance(example_value, torch.Tensor): + # If `dynamic` shapes are enabled, we may see a FakeTensor + # where shape has SymInt. In that case, we check if we can + # get the concrete value from SymInt. + # Here, we only want to verify that thunder can run an operation. + # So, it is ok to verify with concrete value. + def concrete_shape(x): + def get_backed_value(s): + if isinstance(s, torch.SymInt): + return s.node.hint + return s + + shape = tuple(map(get_backed_value, x.shape)) + return shape + + example_value = example_value.new_ones( + concrete_shape(example_value), device=example_value.device, dtype=example_value.dtype + ) + return proxy(example_value) + + # This is int, float, etc. + # TODO(kshitij12345) - verify the above line for more cases. + return arg_node + + proxy_args = tuple(map(make_tensor_proxy, node.args)) + proxy_kwargs = {k: make_tensor_proxy(v) for k, v in node.kwargs.items()} + except Exception as e: + return False, SplitReason( + SplitReasonType.EXCEPTION_PROXY_THUNDER_OP, + f"Failed while creating proxy for node with name: {node.name} and target: {node.target}, see exception field", + exception=e, + ) + + try: + thunder_symbol(*proxy_args, **proxy_kwargs) + except Exception as e: + return False, SplitReason( + SplitReasonType.EXCEPTION_META_THUNDER_OP, + f"Failed while running meta for node with name: {node.name} and target: {node.target}, see exception field", + exception=e, + ) + + # Execution with proxies was successful. + return True, None + + return _run_with_cache_info() + + +def get_nodes_in_unsupported_ctx_regions(gm) -> set[torch.fx.Node]: + """ + Finds the node within `autocast` or other supported context and marks them as unsupported. + Even though, thunder may support the operation within the reason, it doesn't correctly apply the change + triggered from the context. + """ + # NOTE - Currently only detects the autocast regions. + + nodes_in_unsupported_ctx_regions: set[torch.fx.Node] = set() + ctx_cnt = 0 # Count of `enters_autocast` we have seen till now + + # We want to mark nodes with `_enter_autocast` and `_exit_autocast` + # as unsupported as `thunder` doesn't correctly deal with these stateful functions. + for node in gm.graph.nodes: + if node.op == "call_function" and node.target in (torch.amp.autocast_mode._enter_autocast,): + ctx_cnt += 1 + elif node.op == "call_function" and node.target in (torch.amp.autocast_mode._exit_autocast,): + ctx_cnt -= 1 + else: + if ctx_cnt > 0: + nodes_in_unsupported_ctx_regions.add(node) + + return nodes_in_unsupported_ctx_regions + + +def is_node_supported(node: torch.fx.Node) -> tuple[bool, SplitReason | None]: + """ + Determine whether thunder can execute the operation described by this node. + """ + # Docs from the torch.fx.Node - https://pytorch.org/docs/stable/fx.html#torch.fx.Node + # Each Node has a function specified by its op property + # Below are the details for the ones this function is interested in - + # `call_function` applies a free function to some values. + # name is similarly the name of the value to assign to. 
+ # target is the function to be applied. args and kwargs represent + # the arguments to the function, following the Python calling convention + # `call_method` calls a method on a value. + # name is as similar. + # target is the string name of the method to apply to the self argument. + # args and kwargs represent the arguments to invoke the module on, including the self argument + # + # NOTE: `call_module` should be inlined in dynamo graphs since https://github.com/pytorch/pytorch/pull/131275 + # But there is flag to disable inlining `call_module`. Determining `call_module` support would actually require calling `thunder.jit` on it. + # + # `call_module` applies a module in the module hierarchy’s forward() method to given arguments. + # name is as previous. target is the fully-qualified name of the module in the module hierarchy to call. + # args and kwargs represent the arguments to invoke the module on, excluding the self argument + + target = node.target # Target is the function to call. + if node.op == "call_method": + self_arg = node.args[0] + target = getattr(torch.Tensor, node.target, None) + assert target is not None, f"Failed to find method {node.target}" + + # If the operation has automatic registration, we mark it as unsupported as `inductor` might be + # able to deal with it better. + if target in auto_register_ops: + split_reason = SplitReason( + SplitReasonType.MISSING_OP_SUPPORT, + info=f"node with name: {node.name} and target: {node.target} only has an automatic torch fallback in thunder.", + ) + return False, split_reason + + # If thunder has a mapping for this operation, try executing the meta function and see. + # We have a symbol for `torch.where`, but we don't support one overload of it. + # So, we try and execute the meta to get a real signal. + # + # Regarding `inspect.isbuiltin`, dynamo graph uses `+`, `>` which are builtin `add`, `gt`. + # We try to proxify the arguments and call these operations on them to see if they are supported. + if target in _torch_to_thunder_function_map or inspect.isbuiltin(target): + thunder_symbol_or_builtin = _torch_to_thunder_function_map.get(target, target) + did_run, opt_split_reason = try_execute_symbol(thunder_symbol_or_builtin, node) + return did_run, opt_split_reason + + # We found no automatic fallback registration and no mapping to thunder symbol. + split_reason = SplitReason( + SplitReasonType.MISSING_OP_SUPPORT, + info=f"node with name: {node.name} and target: {node.target} didn't have any mapping in thunder.", + ) + return False, split_reason diff --git a/thunder/tests/test_dynamo.py b/thunder/tests/test_dynamo.py index 8dd8842245..59d85f8ad9 100644 --- a/thunder/tests/test_dynamo.py +++ b/thunder/tests/test_dynamo.py @@ -2,7 +2,6 @@ from thunder.tests.framework import instantiate, NOTHING, DynamoThunderExecutor from thunder import dtypes from thunder.dynamo import ThunderCompiler -from thunder.dynamo.compiler import CompilerType from thunder import last_traces import torch @@ -30,25 +29,16 @@ def func(x): # out should have grad_fn and its name should be ThunderFunctionBackward assert out.grad_fn is not None - if not dynamic: - # If dynamic, while trying to execute `x + 1`, we fail with - # "s0 had an unexpected type . Supported types are (, )") - # as the FakeTensor for `x` has shape with SymInt. 
-        assert out.grad_fn.name() == "ThunderFunctionBackward"
+    assert out.grad_fn.name() == "ThunderFunctionBackward"
 
     # We record the GraphModules that was compiled by ThunderCompiler
-    assert len(backend.subgraph_infos) == 2
+    assert len(backend.subgraph_infos) == 2  # 2 due to data-dependent flow
 
-    subgraph_info = backend.subgraph_infos[0]
-    if dynamic:
-        assert len(subgraph_info.compiled_functions) == 2  # Due to Symint!
-    else:
-        assert len(subgraph_info.compiled_functions) == 1
-    idx = 1 if dynamic else 0
-    compiled_fn_info = subgraph_info.compiled_functions[idx]
-    assert compiled_fn_info.compiler == CompilerType.THUNDER
-    assert last_traces(compiled_fn_info.compiled_fn)
-    assert isinstance(compiled_fn_info.original_graph_module, torch.fx.GraphModule)
+    for subgraph_info in backend.subgraph_infos:
+        assert isinstance(subgraph_info.original_graph_module, torch.fx.GraphModule)
+        assert len(subgraph_info.thunder_compiled_fns)  # There was at least one function compiled with thunder.
+        for thunder_fn in subgraph_info.thunder_compiled_fns:
+            assert last_traces(thunder_fn)  # Verify that we can fetch last_traces
 
 
 @instantiate(
@@ -78,7 +68,7 @@ def func(x):
     torch.testing.assert_close(actual_grad, expected_grad)
 
     assert len(backend.subgraph_infos) == 1
-    assert backend.subgraph_infos[0].is_split
+    assert len(backend.subgraph_infos[0].submodule_to_compiled_functions) > 1  # Verify that the subgraph was split.
     assert any(
         "automatic torch fallback" in split_reason.info for split_reason in backend.subgraph_infos[0].split_reasons
     )
@@ -112,7 +102,7 @@ def func(x):
     torch.testing.assert_close(actual_grad, expected_grad)
 
     assert len(backend.subgraph_infos) == 1
-    assert backend.subgraph_infos[0].is_split
+    assert len(backend.subgraph_infos[0].submodule_to_compiled_functions) > 1  # Verify that the subgraph was split.
     assert any(
         "didn't have any mapping in thunder" in split_reason.info for split_reason in backend.subgraph_infos[0].split_reasons
     )
@@ -146,9 +136,11 @@ def func(x):
     expected_grad = torch.autograd.grad(expected, x, g)
     torch.testing.assert_close(actual_grad, expected_grad)
 
+    # 2 subgraphs due to graph-break
     assert len(backend.subgraph_infos) == 2
+
     for subgraph_info in backend.subgraph_infos:
-        assert subgraph_info.is_split
+        # Verify that for each subgraph there was a split due to `autocast` being enabled.
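+        # Each entry in `split_reasons` is a `SplitReason` carrying the reason type, an info string,
+        # and the exception (if any); the check below matches on the info string.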
assert any( "didn't have any mapping in thunder" in split_reason.info for split_reason in subgraph_info.split_reasons ) From f22a92117aa26d9c7cac57defdfc515b08ea1ca0 Mon Sep 17 00:00:00 2001 From: kshitij12345 Date: Thu, 29 Aug 2024 11:42:34 +0200 Subject: [PATCH 4/9] remove experimental code --- thunder/dynamo/splitter_experiment.py | 210 -------------------------- 1 file changed, 210 deletions(-) delete mode 100644 thunder/dynamo/splitter_experiment.py diff --git a/thunder/dynamo/splitter_experiment.py b/thunder/dynamo/splitter_experiment.py deleted file mode 100644 index 67fbb08218..0000000000 --- a/thunder/dynamo/splitter_experiment.py +++ /dev/null @@ -1,210 +0,0 @@ -# CapabilityBasedParitioner returns a GraphModule where `fused_*` represent the subgraphs -# that should go to `thunder` and the forward of this graph module should be passed to `torch.compile` (after removing thunder bits) -# Example - -# class GraphModule(torch.nn.Module): -# def forward(self, L_x_: "f32[2]"): -# l_x_ = L_x_ - -# # No stacktrace found for following nodes -# _enter_autocast = torch.amp.autocast_mode._enter_autocast('cpu', None, True, None) - -# # File: /home/kkalambarkar/lightning-thunder/scratchpad/test.py:181 in func, code: return torch.matmul(x, y) -# fused_0: "bf16[]" = self.fused_0(l_x_); l_x_ = None - -# # No stacktrace found for following nodes -# _exit_autocast = torch.amp.autocast_mode._exit_autocast(_enter_autocast); _enter_autocast = _exit_autocast = None -# return (fused_0,) - -# class fused_0(torch.nn.Module): -# def forward(self, l_x_: "f32[2]"): -# # File: /home/kkalambarkar/lightning-thunder/scratchpad/test.py:177 in func, code: x = x + 2 -# x: "f32[2]" = l_x_ + 2; l_x_ = None - -# # File: /home/kkalambarkar/lightning-thunder/scratchpad/test.py:179 in func, code: z = torch.ones(3, 3) -# z: "f32[3, 3]" = torch.ones(3, 3); z = None - -# # File: /home/kkalambarkar/lightning-thunder/scratchpad/test.py:180 in func, code: y = torch.sin(x) -# y: "f32[2]" = torch.sin(x) - -# # File: /home/kkalambarkar/lightning-thunder/scratchpad/test.py:181 in func, code: return torch.matmul(x, y) -# matmul: "bf16[]" = torch.matmul(x, y); x = y = None -# return matmul - - -def capability_partitioner_splitter(gm, sample_args): - gm_copy = copy.deepcopy(gm) - op_support = ThunderOperatorSupport(gm_copy) - partitioner = CapabilityBasedPartitioner(gm_copy, op_support) - fused_partition = partitioner.partition_and_fuse() - gm_copy.print_readable() - return gm_copy - - -# Splitter with _SplitterBase. 
- -# class GraphModule(torch.nn.Module): -# def forward(self, L_x_: "f32[2]"): -# l_x_ = L_x_ - -# # File: /home/kkalambarkar/lightning-thunder/scratchpad/test.py:176 in func, code: x = x + 2 -# x: "f32[2]" = l_x_ + 2; l_x_ = None - -# # No stacktrace found for following nodes -# _enter_autocast = torch.amp.autocast_mode._enter_autocast('cpu', None, True, None) - -# # File: /home/kkalambarkar/lightning-thunder/scratchpad/test.py:178 in func, code: y = torch.sin(x) -# y: "f32[2]" = torch.sin(x) - -# # File: /home/kkalambarkar/lightning-thunder/scratchpad/test.py:179 in func, code: return torch.matmul(x, y) -# matmul: "bf16[]" = torch.matmul(x, y); x = y = None - -# # No stacktrace found for following nodes -# _exit_autocast = torch.amp.autocast_mode._exit_autocast(_enter_autocast); _enter_autocast = _exit_autocast = None -# return (matmul,) - -# Got 1 acc subgraphs and 2 non-acc subgraphs -# class GraphModule(torch.nn.Module): -# def forward(self, l_x_: "f32[2]"): -# # No stacktrace found for following nodes -# _run_on_cpu_0 = self._run_on_cpu_0(); _run_on_cpu_0 = None -# _run_on_acc_1 = self._run_on_acc_1(l_x_); l_x_ = None -# _run_on_cpu_2 = self._run_on_cpu_2(_run_on_acc_1); _run_on_acc_1 = None -# return (_run_on_cpu_2,) - -# class _run_on_cpu_0(torch.nn.Module): -# def forward(self): -# # No stacktrace found for following nodes -# _enter_autocast = torch.amp.autocast_mode._enter_autocast('cpu', None, True, None) -# _exit_autocast = torch.amp.autocast_mode._exit_autocast(_enter_autocast); _enter_autocast = _exit_autocast = None -# return () - -# class _run_on_acc_1(torch.nn.Module): -# def forward(self, l_x_: "f32[2]"): -# # File: /home/kkalambarkar/lightning-thunder/scratchpad/test.py:176 in func, code: x = x + 2 -# x: "f32[2]" = l_x_ + 2; l_x_ = None -# return x - -# class _run_on_cpu_2(torch.nn.Module): -# def forward(self, x: "f32[2]"): -# # File: /home/kkalambarkar/lightning-thunder/scratchpad/test.py:178 in func, code: y = torch.sin(x) -# y: "f32[2]" = torch.sin(x) - -# # File: /home/kkalambarkar/lightning-thunder/scratchpad/test.py:179 in func, code: return torch.matmul(x, y) -# matmul: "bf16[]" = torch.matmul(x, y); x = y = None -# return matmul - - -class GraphModuleSplitter(torch.fx.passes.splitter_base._SplitterBase): - def starter_nodes(self): - """ - Finds nodes that consume module inputs or get_attr nodes. - """ - starter_cpu_nodes: NodeSet = set() - starter_acc_nodes: NodeSet = set() - - for node in self.module.graph.nodes: - if node.op not in {"placeholder", "get_attr"}: - continue - for user in node.users: - if user in self.acc_nodes: - starter_acc_nodes.add(user) - else: - starter_cpu_nodes.add(user) - - for node in self.module.graph.nodes: - if node.op in {"output", "placeholder", "get_attr"}: - continue - - if len(self.deps[node]) == 0: - if node in self.acc_nodes: - starter_acc_nodes.add(node) - else: - starter_cpu_nodes.add(node) - - return starter_cpu_nodes, starter_acc_nodes - - -def splitter(self, gm, sample_input): - """ - This function splits the graph provided by Dynamo - if it contains any operation or construct that is not supported by thunder. - For the unsupported subgraph, it is passed to inductor. - """ - from thunder import jit - - # Setup the splitter class - settings = torch.fx.passes.splitter_base._SplitterSettingBase(allow_non_tensor=True) - splitter = GraphModuleSplitter(gm, sample_input, operator_support=ThunderOperatorSupport(gm), settings=settings) - gm.print_readable() - # Call the splitter to split GraphModule. 
- split_module = splitter() - split_module.print_readable() - compiled_funcs = [] - for node in split_module.graph.nodes: - if node.name.startswith("_run_on_acc_"): - graph_module = getattr(split_module, node.name) - jit_fn = self.thunder_jit(graph_module) - setattr(split_module, node.name, jit_fn) - compiled_funcs.append(jit_fn) - if node.name.startswith("_run_on_cpu_") or node.name.startswith("_run_on_gpu_"): - graph_module = getattr(split_module, node.name) - jit_fn = torch.compile(graph_module, backend="inductor") - setattr(split_module, node.name, jit_fn) - compiled_funcs.append(jit_fn) - - self.subgraph_infos.append(SubgraphInfo(gm, True, compiled_funcs, [], split_module)) - # split_module.print_readable() - return split_module - - -# With the current, approach -# Original Graph -# class GraphModule(torch.nn.Module): -# def forward(self, L_x_: "f32[2]"): -# l_x_ = L_x_ - -# # File: /home/kkalambarkar/lightning-thunder/scratchpad/test.py:176 in func, code: x = x + 2 -# x: "f32[2]" = l_x_ + 2; l_x_ = None - -# # No stacktrace found for following nodes -# _enter_autocast = torch.amp.autocast_mode._enter_autocast('cpu', None, True, None) - -# # File: /home/kkalambarkar/lightning-thunder/scratchpad/test.py:178 in func, code: y = torch.sin(x) -# y: "f32[2]" = torch.sin(x) - -# # File: /home/kkalambarkar/lightning-thunder/scratchpad/test.py:179 in func, code: return torch.matmul(x, y) -# matmul: "bf16[]" = torch.matmul(x, y); x = y = None - -# # No stacktrace found for following nodes -# _exit_autocast = torch.amp.autocast_mode._exit_autocast(_enter_autocast); _enter_autocast = _exit_autocast = None -# return (matmul,)` -# - -# Split Graph -# class GraphModule(torch.nn.Module): -# def forward(self, l_x_: "f32[2]"): -# # No stacktrace found for following nodes -# submod_1 = self.submod_1(l_x_); l_x_ = None -# submod_2 = self.submod_2(submod_1); submod_1 = None -# return (submod_2,) - -# class submod_1(torch.nn.Module): -# def forward(self, l_x_: "f32[2]"): -# # File: /home/kkalambarkar/lightning-thunder/scratchpad/test.py:176 in func, code: x = x + 2 -# x: "f32[2]" = l_x_ + 2; l_x_ = None -# return x - -# class submod_2(torch.nn.Module): -# def forward(self, x: "f32[2]"): -# # No stacktrace found for following nodes -# _enter_autocast = torch.amp.autocast_mode._enter_autocast('cpu', None, True, None) - -# # File: /home/kkalambarkar/lightning-thunder/scratchpad/test.py:178 in func, code: y = torch.sin(x) -# y: "f32[2]" = torch.sin(x) - -# # File: /home/kkalambarkar/lightning-thunder/scratchpad/test.py:179 in func, code: return torch.matmul(x, y) -# matmul: "bf16[]" = torch.matmul(x, y); x = y = None - -# # No stacktrace found for following nodes -# _exit_autocast = torch.amp.autocast_mode._exit_autocast(_enter_autocast); _enter_autocast = _exit_autocast = None -# return matmul From 31bbc30fa4da9bee6d6af9f8626d242c203e6189 Mon Sep 17 00:00:00 2001 From: kshitij12345 Date: Thu, 29 Aug 2024 13:56:55 +0200 Subject: [PATCH 5/9] address review : part 2 --- thunder/dynamo/compiler.py | 55 ++++++++++++++++++++------------------ thunder/dynamo/utils.py | 9 +++++++ thunder/tests/framework.py | 3 ++- 3 files changed, 40 insertions(+), 27 deletions(-) diff --git a/thunder/dynamo/compiler.py b/thunder/dynamo/compiler.py index 7f05c81282..2dcd8295d8 100644 --- a/thunder/dynamo/compiler.py +++ b/thunder/dynamo/compiler.py @@ -18,6 +18,7 @@ SplitReasonType, is_node_supported, get_nodes_in_unsupported_ctx_regions, + update_node_and_submodule, ) @@ -30,7 +31,7 @@ def _warn_thunder_compiler(): class 
ThunderCompiler: - def __init__(self, **thunder_options): + def __init__(self, *, thunder_options: dict | None = None, torch_inductor_options: dict | None = None): """ A class that compiles a `fx.GraphModule` to a `thunder.ThunderModule`. This class is meant to be used as a backend for the `torch.compile` @@ -38,6 +39,7 @@ def __init__(self, **thunder_options): Keyword arguments: thunder_options: a dictionary of options to pass to `thunder.jit`. + torch_inductor_options: a dictionary of options to pass to `torch.compile`. Example: >>> import torch @@ -64,10 +66,17 @@ def __init__(self, **thunder_options): # Ref to the documentation of `SubgraphInfo` to know more about the information it contains. self.subgraph_infos: list[SubgraphInfo] = [] + if thunder_options is None: + thunder_options = {} + + if torch_inductor_options is None: + torch_inductor_options = {} + self.thunder_options = thunder_options self._thunder_jit = partial(jit, **thunder_options) + self._torch_compile = partial(torch.compile, **torch_inductor_options) - def splitter( + def _splitter( self, gm: torch.fx.GraphModule, _unused_sample_args: list[torch.SymInt, torch.Tensor] ) -> torch.fx.GraphModule: """ @@ -87,10 +96,10 @@ def forward(self, L_x_: "f32[2]"): # Split Graph: All operations are supported by thunder, we will see only one partition. class GraphModule(torch.nn.Module): def forward(self, l_x_: "f32[2]"): - submod_1 = self.submod_1(l_x_); l_x_ = None - return (submod_1,) + thunder_1 = self.thunder_1(l_x_); l_x_ = None + return (thunder_1,) - class submod_1(torch.nn.Module): + class thunder_1(torch.nn.Module): def forward(self, l_x_: "f32[2]"): y: "f32[2]" = torch.sin(l_x_) matmul: "f32[]" = torch.matmul(l_x_, y); l_x_ = y = None @@ -110,16 +119,16 @@ def forward(self, L_x_: "f32[2]"): # Split Graph: Since `sinc` is unsupported, we will see two partitions, one for thunder and one for inductor. class GraphModule(torch.nn.Module): def forward(self, l_x_: "f32[2]"): - submod_1 = self.submod_1(l_x_) - submod_2 = self.submod_2(l_x_, submod_1); l_x_ = submod_1 = None - return (submod_2,) + inductor_1 = self.inductor_1(l_x_) + thunder_2 = self.thunder_2(l_x_, inductor_1); l_x_ = inductor_1 = None + return (thunder_2,) - class submod_1(torch.nn.Module): # Partition for inductor + class inductor_1(torch.nn.Module): # Partition for inductor def forward(self, l_x_: "f32[2]"): y: "f32[2]" = torch.sinc(l_x_); l_x_ = None return y - class submod_2(torch.nn.Module): # Partition for thunder + class thunder_2(torch.nn.Module): # Partition for thunder def forward(self, l_x_: "f32[2]", y: "f32[2]"): matmul: "f32[]" = torch.matmul(l_x_, y); l_x_ = y = None return matmul @@ -139,12 +148,13 @@ def forward(self, l_x_: "f32[2]", y: "f32[2]"): nodes_in_unsupported_ctx_regions = get_nodes_in_unsupported_ctx_regions(gm) def callback(node) -> int: + nonlocal prev_value, partition_cnt, split_reasons, supported_partitions + assert node.op not in ( "placeholder", "get_attr", "output", ), f"fx.split_module should have only passed node.op=call_* but received {node.op}" - nonlocal prev_value, partition_cnt if node in nodes_in_unsupported_ctx_regions: is_thunder_supported = False @@ -169,7 +179,7 @@ def callback(node) -> int: supported_partitions.add(partition_cnt) return partition_cnt - # `fx_split_module` iterates over nodes and determines the partition to place them based on the callback. + # `split_module` iterates over nodes and determines the partition to place them based on the callback. 
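+        # E.g. for the `sinc` example above, the callback returns partition 1 for `torch.sinc`
+        # (unsupported) and partition 2 for `torch.matmul` (supported), so `split_module` emits
+        # `submod_1` and `submod_2`, which are renamed below to `inductor_1` and `thunder_2`.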
split_gm: torch.fx.GraphModule = split_module( gm, root_m=None, split_callback=callback, keep_original_order=True, keep_original_node_name=True ) @@ -186,19 +196,22 @@ def is_thunder_supported_partition(node: torch.fx.Node) -> bool: # there is erase method on GraphModule graph_module = getattr(split_gm, node.name) jit_fn = self._thunder_jit(graph_module) - setattr(split_gm, node.name, jit_fn) + update_node_and_submodule(split_gm, node, node.name.replace("submod", "thunder"), jit_fn) thunder_compiled_fns.append(jit_fn) submodule_to_compiled_fns[graph_module] = CompiledFunction(jit_fn, CompilerType.THUNDER) elif node.name.startswith("submod"): # For inductor graph_module = getattr(split_gm, node.name) - jit_fn = torch.compile(graph_module, backend="inductor") - setattr(split_gm, node.name, jit_fn) + jit_fn = self._torch_compile(graph_module) + update_node_and_submodule(split_gm, node, node.name.replace("submod", "inductor"), jit_fn) submodule_to_compiled_fns[graph_module] = CompiledFunction(jit_fn, CompilerType.TORCH_INDUCTOR) is_split = True else: # Everything else is a glue code to call and pass outputs between the other partitions. pass + # We update the GraphModule in `update_node_and_submodule`, so we need to recompile. + split_gm.recompile() + # gm.print_readable() # Append the details regarding this graph/subgraph. self.subgraph_infos.append( @@ -220,17 +233,7 @@ def __call__(self, gm: torch.fx.GraphModule, sample_args: list[torch.SymInt, tor # force recompilation of the GraphModule before passing it to Thunder. gm.real_recompile() - # Check if the complete graph `gm` is supported by thunder - # If yes, pass the whole `gm` to `thunder.jit` and return the compiled function. - # if is_graph_supported_by_thunder(gm, sample_args): - # jitted_gm = self.thunder_jit(gm) - # self.thunder_fns.append(jitted_gm) - # self.thunder_to_gm[jitted_gm] = gm - # compiled_fn = CompiledFunction(gm, jitted_gm, CompilerType.THUNDER) - # self.subgraph_infos.append(SubgraphInfo(gm, [compiled_fn], False)) - # return jitted_gm - # The whole graph may not be supported by `thunder`, so we split it in `thunder` supported sections # and unsupported sections which are passed to `torch.compile(backend='inductor')` - split_module = self.splitter(gm, sample_args) + split_module = self._splitter(gm, sample_args) return split_module diff --git a/thunder/dynamo/utils.py b/thunder/dynamo/utils.py index 527f4df9e1..a596851a4b 100644 --- a/thunder/dynamo/utils.py +++ b/thunder/dynamo/utils.py @@ -260,3 +260,12 @@ def is_node_supported(node: torch.fx.Node) -> tuple[bool, SplitReason | None]: info=f"node with name: {node.name} and target: {node.target} didn't have any mapping in thunder.", ) return False, split_reason + + +def update_node_and_submodule(graph_module, node, new_name, new_callable): + assert graph_module.delete_submodule( + node.name + ), f"Didn't find a submodule named {node.name} in graph_module {graph_module}" + node.name = new_name + node.target = new_name + return graph_module.add_submodule(node.name, new_callable) diff --git a/thunder/tests/framework.py b/thunder/tests/framework.py index 7d99675087..96a98aace3 100644 --- a/thunder/tests/framework.py +++ b/thunder/tests/framework.py @@ -255,7 +255,8 @@ class DynamoThunderTestExecutor(TestExecutor): supported_dtypes = (datatypes.dtype,) def make_callable(self, fn, **kwargs): - return torch.compile(backend=ThunderCompiler(**kwargs))(fn) + # We assume all kwargs are for `thunder.jit` and not for `torch.compile` which is used in splitter. 
+ return torch.compile(backend=ThunderCompiler(thunder_options=kwargs))(fn) # TODO Refactor these executors into the actual executor (sub)modules From 734a3be86c30ecea5f85fc5917039031611b4979 Mon Sep 17 00:00:00 2001 From: kshitij12345 Date: Thu, 29 Aug 2024 14:41:03 +0200 Subject: [PATCH 6/9] update code and add comments --- thunder/dynamo/compiler.py | 7 ++-- thunder/dynamo/utils.py | 66 +++++++++++++++++++++++++++----------- 2 files changed, 53 insertions(+), 20 deletions(-) diff --git a/thunder/dynamo/compiler.py b/thunder/dynamo/compiler.py index 2dcd8295d8..2f9eeefca9 100644 --- a/thunder/dynamo/compiler.py +++ b/thunder/dynamo/compiler.py @@ -16,7 +16,7 @@ CompilerType, SplitReason, SplitReasonType, - is_node_supported, + is_node_supported_by_thunder, get_nodes_in_unsupported_ctx_regions, update_node_and_submodule, ) @@ -157,6 +157,9 @@ def callback(node) -> int: ), f"fx.split_module should have only passed node.op=call_* but received {node.op}" if node in nodes_in_unsupported_ctx_regions: + # If node was in unsupported ctx region like `autocast`, + # even though the operation maybe supported, we pass it to `torch.compile` + # as `thunder` doesn't correctly work with these. is_thunder_supported = False split_reason = SplitReason( SplitReasonType.UNSUPPORTED_NODE, @@ -164,7 +167,7 @@ def callback(node) -> int: ) split_reasons.append(split_reason) else: - is_thunder_supported, split_reason = is_node_supported(node) + is_thunder_supported, split_reason = is_node_supported_by_thunder(node) if split_reason is not None: split_reasons.append(split_reason) diff --git a/thunder/dynamo/utils.py b/thunder/dynamo/utils.py index a596851a4b..9c92fcf054 100644 --- a/thunder/dynamo/utils.py +++ b/thunder/dynamo/utils.py @@ -89,7 +89,21 @@ class SubgraphInfo: split_reasons: list | None = None -def try_execute_symbol(thunder_symbol: "Symbol", node: torch.fx.Node) -> tuple[bool, SplitReason | None]: +def _concrete_shape(x): + """ + Get the concrete shape for a FakeTensor if it has `torch.SymInt` in its shape. + """ + + def get_backed_value(s): + if isinstance(s, torch.SymInt): + return s.node.hint + # Value is already concrete. + return s + + return tuple(map(get_backed_value, x.shape)) + + +def try_execute_thunder_symbol(thunder_symbol: "Symbol", node: torch.fx.Node) -> tuple[bool, SplitReason | None]: """ Attempts to execute a given Thunder symbol within a tracing context, using proxies for the node's arguments. @@ -116,7 +130,7 @@ def _run_with_cache_info(): # We need cache info here as the default dtype and device support # for factory functions like ones, zeros, etc expects these details to be present. # TODO: Move this to CompileData as well? - # This details are in cache info because `jit_ext.py` + # This details are in cache_info because `jit_ext.py` # adds checks in prologue for the details which are present in here. cache_info = thunder._get_cache_info() cache_info["default_dtype"] = torch.get_default_dtype() @@ -128,7 +142,7 @@ def _run_with_cache_info(): try: def make_tensor_proxy(arg_node): - # This is a Node in the graph representing a Tensor. + # This is a Node in the graph representing a Tensor or tuple of Tensors. if isinstance(arg_node, torch.fx.Node): example_value = arg_node.meta["example_value"] @@ -138,22 +152,22 @@ def make_tensor_proxy(arg_node): # get the concrete value from SymInt. # Here, we only want to verify that thunder can run an operation. # So, it is ok to verify with concrete value. 
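                         # (Each `torch.SymInt` carries a concrete hint taken from the example input,
                         # e.g. `s0` with a hint of 2, and that hint is what gets extracted below.)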
-                        def concrete_shape(x):
-                            def get_backed_value(s):
-                                if isinstance(s, torch.SymInt):
-                                    return s.node.hint
-                                return s
-
-                            shape = tuple(map(get_backed_value, x.shape))
-                            return shape
-
                         example_value = example_value.new_ones(
-                            concrete_shape(example_value), device=example_value.device, dtype=example_value.dtype
+                            _concrete_shape(example_value), device=example_value.device, dtype=example_value.dtype
+                        )
+                    elif isinstance(example_value, tuple):
+                        example_value = tuple(
+                            e_v.new_ones(_concrete_shape(e_v), device=e_v.device, dtype=e_v.dtype)
+                            for e_v in example_value
+                        )
+                    else:
+                        # NOTE - This will be caught and be part of the SplitReason.
+                        raise TypeError(
+                            f"`make_tensor_proxy` received example_value which wasn't Tensor or Tuple"
                         )
                     return proxy(example_value)
 
                 # This is int, float, etc.
-                # TODO(kshitij12345) - verify the above line for more cases.
                 return arg_node
 
             proxy_args = tuple(map(make_tensor_proxy, node.args))
@@ -205,7 +219,7 @@ def get_nodes_in_unsupported_ctx_regions(gm) -> set[torch.fx.Node]:
     return nodes_in_unsupported_ctx_regions
 
 
-def is_node_supported(node: torch.fx.Node) -> tuple[bool, SplitReason | None]:
+def is_node_supported_by_thunder(node: torch.fx.Node) -> tuple[bool, SplitReason | None]:
     """
     Determine whether thunder can execute the operation described by this node.
     """
@@ -251,7 +265,7 @@ def is_node_supported(node: torch.fx.Node) -> tuple[bool, SplitReason | None]:
     # We try to proxify the arguments and call these operations on them to see if they are supported.
     if target in _torch_to_thunder_function_map or inspect.isbuiltin(target):
         thunder_symbol_or_builtin = _torch_to_thunder_function_map.get(target, target)
-        did_run, opt_split_reason = try_execute_symbol(thunder_symbol_or_builtin, node)
+        did_run, opt_split_reason = try_execute_thunder_symbol(thunder_symbol_or_builtin, node)
         return did_run, opt_split_reason
 
     # We found no automatic fallback registration and no mapping to thunder symbol.
@@ -262,10 +276,26 @@ def is_node_supported(node: torch.fx.Node) -> tuple[bool, SplitReason | None]:
     return False, split_reason
 
 
-def update_node_and_submodule(graph_module, node, new_name, new_callable):
+def update_node_and_submodule(
+    graph_module: torch.fx.GraphModule, node: torch.fx.Node, new_name: str, new_callable: Callable
+):
+    """
+    Updates the graph module and the node in place with a new name and a new callable as the target.
+
+    This function removes the existing submodule associated with the node's current name in graph_module and replaces
+    it with a new submodule using the specified new name and callable. The node's name and target are updated accordingly.
+
+    Args:
+        graph_module (torch.fx.GraphModule): The graph module containing the node and submodules.
+        node (torch.fx.Node): The node to be updated within the graph module.
+        new_name (str): The new name to assign to the node and the submodule.
+        new_callable (Callable): The new callable to be used as the target for the submodule.
+ """ assert graph_module.delete_submodule( node.name ), f"Didn't find a submodule named {node.name} in graph_module {graph_module}" node.name = new_name node.target = new_name - return graph_module.add_submodule(node.name, new_callable) + assert graph_module.add_submodule( + node.name, new_callable + ), f"Adding submodule with name {node.name} in graph_module {graph_module} failed" From e28ec467f7013e534c5ac5ffe38964cba2b2f601 Mon Sep 17 00:00:00 2001 From: kshitij12345 Date: Thu, 29 Aug 2024 18:12:24 +0200 Subject: [PATCH 7/9] add comment --- thunder/dynamo/compiler.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/thunder/dynamo/compiler.py b/thunder/dynamo/compiler.py index 2f9eeefca9..410b4839fe 100644 --- a/thunder/dynamo/compiler.py +++ b/thunder/dynamo/compiler.py @@ -196,15 +196,16 @@ def is_thunder_supported_partition(node: torch.fx.Node) -> bool: is_split = False for node in split_gm.graph.nodes: if is_thunder_supported_partition(node): - # there is erase method on GraphModule graph_module = getattr(split_gm, node.name) jit_fn = self._thunder_jit(graph_module) + # Update the node name from "submod_*" to "thunder_*" for more user-friendly names update_node_and_submodule(split_gm, node, node.name.replace("submod", "thunder"), jit_fn) thunder_compiled_fns.append(jit_fn) submodule_to_compiled_fns[graph_module] = CompiledFunction(jit_fn, CompilerType.THUNDER) elif node.name.startswith("submod"): # For inductor graph_module = getattr(split_gm, node.name) jit_fn = self._torch_compile(graph_module) + # Update the node name from "submod_*" to "inductor_*" for more user-friendly names update_node_and_submodule(split_gm, node, node.name.replace("submod", "inductor"), jit_fn) submodule_to_compiled_fns[graph_module] = CompiledFunction(jit_fn, CompilerType.TORCH_INDUCTOR) is_split = True From fbe6737666aaceae0befbe6aa7578a682a4c45ce Mon Sep 17 00:00:00 2001 From: kshitij12345 Date: Fri, 30 Aug 2024 11:10:44 +0200 Subject: [PATCH 8/9] test for submodule name --- thunder/tests/test_dynamo.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/thunder/tests/test_dynamo.py b/thunder/tests/test_dynamo.py index 59d85f8ad9..44c88d82f4 100644 --- a/thunder/tests/test_dynamo.py +++ b/thunder/tests/test_dynamo.py @@ -72,6 +72,8 @@ def func(x): assert any( "automatic torch fallback" in split_reason.info for split_reason in backend.subgraph_infos[0].split_reasons ) + targets = (node.target for node in backend.subgraph_infos[0].split_graph_module.graph.nodes) + assert any(target.startswith("thunder_") for target in targets) # Verify that the submodules have name `thunder_*` @instantiate( @@ -107,6 +109,11 @@ def func(x): "didn't have any mapping in thunder" in split_reason.info for split_reason in backend.subgraph_infos[0].split_reasons ) + targets = (node.target for node in backend.subgraph_infos[0].split_graph_module.graph.nodes) + assert any(target.startswith("thunder_") for target in targets) # Verify that the submodules have name `thunder_*` + assert any( + target.startswith("inductor_") for target in targets + ) # Verify that the submodules have name `inductor_*` @instantiate( From ecde8d55c3e2901727c2b55d62197db38a22f5ef Mon Sep 17 00:00:00 2001 From: kshitij12345 Date: Sat, 31 Aug 2024 00:49:00 +0200 Subject: [PATCH 9/9] address review --- thunder/dynamo/compiler.py | 185 ++--------------------------------- thunder/dynamo/splitter.py | 173 ++++++++++++++++++++++++++++++++ thunder/tests/framework.py | 3 +- thunder/tests/test_dynamo.py | 34 +++++-- 4 files changed, 211 
insertions(+), 184 deletions(-) create mode 100644 thunder/dynamo/splitter.py diff --git a/thunder/dynamo/compiler.py b/thunder/dynamo/compiler.py index 410b4839fe..914f413376 100644 --- a/thunder/dynamo/compiler.py +++ b/thunder/dynamo/compiler.py @@ -1,6 +1,5 @@ from typing import List, Dict, Optional, Tuple, Set from collections.abc import Callable -import pprint from functools import partial import torch @@ -10,16 +9,8 @@ from thunder.core.baseutils import run_once -from thunder.dynamo.utils import ( - SubgraphInfo, - CompiledFunction, - CompilerType, - SplitReason, - SplitReasonType, - is_node_supported_by_thunder, - get_nodes_in_unsupported_ctx_regions, - update_node_and_submodule, -) +from thunder.dynamo.utils import SubgraphInfo +from thunder.dynamo.splitter import _splitter @run_once @@ -31,15 +22,16 @@ def _warn_thunder_compiler(): class ThunderCompiler: - def __init__(self, *, thunder_options: dict | None = None, torch_inductor_options: dict | None = None): + def __init__(self, **thunder_options): """ A class that compiles a `fx.GraphModule` to a `thunder.ThunderModule`. This class is meant to be used as a backend for the `torch.compile` function. Keyword arguments: - thunder_options: a dictionary of options to pass to `thunder.jit`. - torch_inductor_options: a dictionary of options to pass to `torch.compile`. + thunder_options: a dictionary of options to pass to `thunder.jit`. Besides all the arguments to `thunder.jit`, + it accepts `torch_inductor_options` which are passed to `torch.compile` if part of the graph + is not supported by thunder. Example: >>> import torch @@ -66,178 +58,19 @@ def __init__(self, *, thunder_options: dict | None = None, torch_inductor_option # Ref to the documentation of `SubgraphInfo` to know more about the information it contains. self.subgraph_infos: list[SubgraphInfo] = [] - if thunder_options is None: - thunder_options = {} - - if torch_inductor_options is None: - torch_inductor_options = {} + torch_inductor_options = thunder_options.pop("torch_inductor_options", {}) self.thunder_options = thunder_options self._thunder_jit = partial(jit, **thunder_options) self._torch_compile = partial(torch.compile, **torch_inductor_options) - def _splitter( - self, gm: torch.fx.GraphModule, _unused_sample_args: list[torch.SymInt, torch.Tensor] - ) -> torch.fx.GraphModule: - """ - This method will split graph into multiple graph modules based on thunder supported operations. - This function will try to split the graph in contiguous partitions. - - Example: - # All operations are supported by thunder - class GraphModule(torch.nn.Module): - def forward(self, L_x_: "f32[2]"): - l_x_ = L_x_ - - y: "f32[2]" = torch.sin(l_x_) - matmul: "f32[]" = torch.matmul(l_x_, y); l_x_ = y = None - return (matmul,) - - # Split Graph: All operations are supported by thunder, we will see only one partition. 
- class GraphModule(torch.nn.Module): - def forward(self, l_x_: "f32[2]"): - thunder_1 = self.thunder_1(l_x_); l_x_ = None - return (thunder_1,) - - class thunder_1(torch.nn.Module): - def forward(self, l_x_: "f32[2]"): - y: "f32[2]" = torch.sin(l_x_) - matmul: "f32[]" = torch.matmul(l_x_, y); l_x_ = y = None - return matmul - - Example: - # With unsupported operation `sinc` - class GraphModule(torch.nn.Module): - def forward(self, L_x_: "f32[2]"): - l_x_ = L_x_ - - y: "f32[2]" = torch.sinc(l_x_) - - matmul: "f32[]" = torch.matmul(l_x_, y); l_x_ = y = None - return (matmul,) - - # Split Graph: Since `sinc` is unsupported, we will see two partitions, one for thunder and one for inductor. - class GraphModule(torch.nn.Module): - def forward(self, l_x_: "f32[2]"): - inductor_1 = self.inductor_1(l_x_) - thunder_2 = self.thunder_2(l_x_, inductor_1); l_x_ = inductor_1 = None - return (thunder_2,) - - class inductor_1(torch.nn.Module): # Partition for inductor - def forward(self, l_x_: "f32[2]"): - y: "f32[2]" = torch.sinc(l_x_); l_x_ = None - return y - - class thunder_2(torch.nn.Module): # Partition for thunder - def forward(self, l_x_: "f32[2]", y: "f32[2]"): - matmul: "f32[]" = torch.matmul(l_x_, y); l_x_ = y = None - return matmul - """ - # The callback below is called for every node in the graph. - # It returns an `int` denoting the parition where the node should be placed. - # We want to partition the graph into contiguous regions (with one or more operations) - # into thunder supported or unsupported region. - # `prev_value` is used to determine if we are still in same region (i.e. supported region or unsupported region). - # `partition_cnt` is bumped everytime we change the region i.e. flip from supported to unsupported or from unsupported to supported. - # `supported_partitions` is used to track the thunder supported partitions. - prev_value = None - partition_cnt = 0 - supported_partitions: set[int] = set() - split_reasons: list[SplitReason] = [] - - nodes_in_unsupported_ctx_regions = get_nodes_in_unsupported_ctx_regions(gm) - - def callback(node) -> int: - nonlocal prev_value, partition_cnt, split_reasons, supported_partitions - - assert node.op not in ( - "placeholder", - "get_attr", - "output", - ), f"fx.split_module should have only passed node.op=call_* but received {node.op}" - - if node in nodes_in_unsupported_ctx_regions: - # If node was in unsupported ctx region like `autocast`, - # even though the operation maybe supported, we pass it to `torch.compile` - # as `thunder` doesn't correctly work with these. - is_thunder_supported = False - split_reason = SplitReason( - SplitReasonType.UNSUPPORTED_NODE, - info=f"node with name: {node.name} and target: {node.target} is not supported probably because it is in unsupported context.", - ) - split_reasons.append(split_reason) - else: - is_thunder_supported, split_reason = is_node_supported_by_thunder(node) - if split_reason is not None: - split_reasons.append(split_reason) - - if prev_value == is_thunder_supported: # We are in the same region. - return partition_cnt - - # There is a flip. Either from supported to unsupported or unsupported to supported. - prev_value = is_thunder_supported - partition_cnt += 1 # Bump the region cnt. - - if is_thunder_supported: - supported_partitions.add(partition_cnt) - return partition_cnt - - # `split_module` iterates over nodes and determines the partition to place them based on the callback. 
- split_gm: torch.fx.GraphModule = split_module( - gm, root_m=None, split_callback=callback, keep_original_order=True, keep_original_node_name=True - ) - - def is_thunder_supported_partition(node: torch.fx.Node) -> bool: - return node.name.startswith("submod") and int(node.name.replace("submod_", "")) in supported_partitions - - # Call compile on the split region/s. - thunder_compiled_fns = [] - submodule_to_compiled_fns = {} - is_split = False - for node in split_gm.graph.nodes: - if is_thunder_supported_partition(node): - graph_module = getattr(split_gm, node.name) - jit_fn = self._thunder_jit(graph_module) - # Update the node name from "submod_*" to "thunder_*" for more user-friendly names - update_node_and_submodule(split_gm, node, node.name.replace("submod", "thunder"), jit_fn) - thunder_compiled_fns.append(jit_fn) - submodule_to_compiled_fns[graph_module] = CompiledFunction(jit_fn, CompilerType.THUNDER) - elif node.name.startswith("submod"): # For inductor - graph_module = getattr(split_gm, node.name) - jit_fn = self._torch_compile(graph_module) - # Update the node name from "submod_*" to "inductor_*" for more user-friendly names - update_node_and_submodule(split_gm, node, node.name.replace("submod", "inductor"), jit_fn) - submodule_to_compiled_fns[graph_module] = CompiledFunction(jit_fn, CompilerType.TORCH_INDUCTOR) - is_split = True - else: - # Everything else is a glue code to call and pass outputs between the other partitions. - pass - - # We update the GraphModule in `update_node_and_submodule`, so we need to recompile. - split_gm.recompile() - - # gm.print_readable() - # Append the details regarding this graph/subgraph. - self.subgraph_infos.append( - SubgraphInfo( - gm, - split_gm, - thunder_compiled_fns, - submodule_to_compiled_fns, - split_reasons, - ) - ) - # split_gm.print_readable() - return split_gm - def __call__(self, gm: torch.fx.GraphModule, sample_args: list[torch.SymInt, torch.Tensor]): - from thunder import jit - # Dynamo uses lazy generation of the underlying Python code, so we need to # force recompilation of the GraphModule before passing it to Thunder. gm.real_recompile() # The whole graph may not be supported by `thunder`, so we split it in `thunder` supported sections # and unsupported sections which are passed to `torch.compile(backend='inductor')` - split_module = self._splitter(gm, sample_args) + split_module, subgraph_info = _splitter(gm, self._thunder_jit, self._torch_compile, sample_args) + self.subgraph_infos.append(subgraph_info) return split_module diff --git a/thunder/dynamo/splitter.py b/thunder/dynamo/splitter.py new file mode 100644 index 0000000000..613ce50acd --- /dev/null +++ b/thunder/dynamo/splitter.py @@ -0,0 +1,173 @@ +from typing import List, Dict, Optional, Tuple, Set +from collections.abc import Callable +from functools import partial + +import torch +from torch.fx.passes.split_module import split_module +import warnings +from collections.abc import Mapping + +from thunder.core.baseutils import run_once + +from thunder.dynamo.utils import ( + SubgraphInfo, + CompiledFunction, + CompilerType, + SplitReason, + SplitReasonType, + is_node_supported_by_thunder, + get_nodes_in_unsupported_ctx_regions, + update_node_and_submodule, +) + + +def _splitter( + gm: torch.fx.GraphModule, + thunder_jit: Callable, + torch_inductor: Callable, + _unused_sample_args: list[torch.SymInt, torch.Tensor], +) -> torch.fx.GraphModule: + """ + This method will split graph into multiple graph modules based on thunder supported operations. 
+    This function will try to split the graph in contiguous partitions.
+
+    Example:
+        # All operations are supported by thunder
+        class GraphModule(torch.nn.Module):
+            def forward(self, L_x_: "f32[2]"):
+                l_x_ = L_x_
+
+                y: "f32[2]" = torch.sin(l_x_)
+                matmul: "f32[]" = torch.matmul(l_x_, y); l_x_ = y = None
+                return (matmul,)
+
+        # Split Graph: All operations are supported by thunder, we will see only one partition.
+        class GraphModule(torch.nn.Module):
+            def forward(self, l_x_: "f32[2]"):
+                thunder_1 = self.thunder_1(l_x_); l_x_ = None
+                return (thunder_1,)
+
+            class thunder_1(torch.nn.Module):
+                def forward(self, l_x_: "f32[2]"):
+                    y: "f32[2]" = torch.sin(l_x_)
+                    matmul: "f32[]" = torch.matmul(l_x_, y); l_x_ = y = None
+                    return matmul
+
+    Example:
+        # With unsupported operation `sinc`
+        class GraphModule(torch.nn.Module):
+            def forward(self, L_x_: "f32[2]"):
+                l_x_ = L_x_
+
+                y: "f32[2]" = torch.sinc(l_x_)
+
+                matmul: "f32[]" = torch.matmul(l_x_, y); l_x_ = y = None
+                return (matmul,)
+
+        # Split Graph: Since `sinc` is unsupported, we will see two partitions, one for thunder and one for inductor.
+        class GraphModule(torch.nn.Module):
+            def forward(self, l_x_: "f32[2]"):
+                inductor_1 = self.inductor_1(l_x_)
+                thunder_2 = self.thunder_2(l_x_, inductor_1); l_x_ = inductor_1 = None
+                return (thunder_2,)
+
+            class inductor_1(torch.nn.Module):  # Partition for inductor
+                def forward(self, l_x_: "f32[2]"):
+                    y: "f32[2]" = torch.sinc(l_x_); l_x_ = None
+                    return y
+
+            class thunder_2(torch.nn.Module):  # Partition for thunder
+                def forward(self, l_x_: "f32[2]", y: "f32[2]"):
+                    matmul: "f32[]" = torch.matmul(l_x_, y); l_x_ = y = None
+                    return matmul
+    """
+    # The callback below is called for every node in the graph.
+    # It returns an `int` denoting the partition where the node should be placed.
+    # We want to partition the graph into contiguous regions (with one or more operations)
+    # that are either supported or unsupported by thunder.
+    # `prev_value` is used to determine if we are still in the same region (i.e. supported region or unsupported region).
+    # `partition_cnt` is bumped every time we change the region i.e. flip from supported to unsupported or from unsupported to supported.
+    # `supported_partitions` is used to track the thunder supported partitions.
+    prev_value = None
+    partition_cnt = 0
+    supported_partitions: set[int] = set()
+    split_reasons: list[SplitReason] = []
+
+    nodes_in_unsupported_ctx_regions = get_nodes_in_unsupported_ctx_regions(gm)
+
+    def callback(node) -> int:
+        nonlocal prev_value, partition_cnt, split_reasons, supported_partitions
+
+        assert node.op not in (
+            "placeholder",
+            "get_attr",
+            "output",
+        ), f"fx.split_module should have only passed node.op=call_* but received {node.op}"
+
+        if node in nodes_in_unsupported_ctx_regions:
+            # If node was in unsupported ctx region like `autocast`,
+            # even though the operation may be supported, we pass it to `torch.compile`
+            # as `thunder` doesn't correctly work with these.
+            is_thunder_supported = False
+            split_reason = SplitReason(
+                SplitReasonType.UNSUPPORTED_NODE,
+                info=f"node with name: {node.name} and target: {node.target} is not supported probably because it is in unsupported context.",
+            )
+            split_reasons.append(split_reason)
+        else:
+            is_thunder_supported, split_reason = is_node_supported_by_thunder(node)
+            if split_reason is not None:
+                split_reasons.append(split_reason)
+
+        if prev_value == is_thunder_supported:  # We are in the same region.
+            return partition_cnt
+
+        # There is a flip.
Either from supported to unsupported or unsupported to supported. + prev_value = is_thunder_supported + partition_cnt += 1 # Bump the region cnt. + + if is_thunder_supported: + supported_partitions.add(partition_cnt) + return partition_cnt + + # `split_module` iterates over nodes and determines the partition to place them based on the callback. + split_gm: torch.fx.GraphModule = split_module( + gm, root_m=None, split_callback=callback, keep_original_order=True, keep_original_node_name=True + ) + + def is_thunder_supported_partition(node: torch.fx.Node) -> bool: + return node.name.startswith("submod") and int(node.name.replace("submod_", "")) in supported_partitions + + # Call compile on the split region/s. + thunder_compiled_fns = [] + submodule_to_compiled_fns = {} + is_split = False + for node in split_gm.graph.nodes: + if is_thunder_supported_partition(node): + graph_module = getattr(split_gm, node.name) + jit_fn = thunder_jit(graph_module) + # Update the node name from "submod_*" to "thunder_*" for more user-friendly names + update_node_and_submodule(split_gm, node, node.name.replace("submod", "thunder"), jit_fn) + thunder_compiled_fns.append(jit_fn) + submodule_to_compiled_fns[graph_module] = CompiledFunction(jit_fn, CompilerType.THUNDER) + elif node.name.startswith("submod"): # For inductor + graph_module = getattr(split_gm, node.name) + jit_fn = torch_inductor(graph_module) + # Update the node name from "submod_*" to "inductor_*" for more user-friendly names + update_node_and_submodule(split_gm, node, node.name.replace("submod", "inductor"), jit_fn) + submodule_to_compiled_fns[graph_module] = CompiledFunction(jit_fn, CompilerType.TORCH_INDUCTOR) + is_split = True + else: + # Everything else is a glue code to call and pass outputs between the other partitions. + pass + + # We update the GraphModule in `update_node_and_submodule`, so we need to recompile. + split_gm.recompile() + + return split_gm, SubgraphInfo( + gm, + split_gm, + thunder_compiled_fns, + submodule_to_compiled_fns, + split_reasons, + ) diff --git a/thunder/tests/framework.py b/thunder/tests/framework.py index 96a98aace3..7d99675087 100644 --- a/thunder/tests/framework.py +++ b/thunder/tests/framework.py @@ -255,8 +255,7 @@ class DynamoThunderTestExecutor(TestExecutor): supported_dtypes = (datatypes.dtype,) def make_callable(self, fn, **kwargs): - # We assume all kwargs are for `thunder.jit` and not for `torch.compile` which is used in splitter. 
- return torch.compile(backend=ThunderCompiler(thunder_options=kwargs))(fn) + return torch.compile(backend=ThunderCompiler(**kwargs))(fn) # TODO Refactor these executors into the actual executor (sub)modules diff --git a/thunder/tests/test_dynamo.py b/thunder/tests/test_dynamo.py index 44c88d82f4..b4eb8ab147 100644 --- a/thunder/tests/test_dynamo.py +++ b/thunder/tests/test_dynamo.py @@ -1,5 +1,5 @@ import torch.fx -from thunder.tests.framework import instantiate, NOTHING, DynamoThunderExecutor +from thunder.tests.framework import instantiate, NOTHING, DynamoThunderExecutor, IS_WINDOWS from thunder import dtypes from thunder.dynamo import ThunderCompiler from thunder import last_traces @@ -44,7 +44,14 @@ def func(x): @instantiate( dtypes=NOTHING, executors=[DynamoThunderExecutor], - decorators=(pytest.mark.parametrize("dynamic", (True, False, None), ids=("dynamic", "static", "auto")),), + decorators=( + pytest.mark.parametrize("dynamic", (True, False, None), ids=("dynamic", "static", "auto")), + pytest.mark.xfail( + condition=IS_WINDOWS, + strict=True, + reason="torch.compile Windows support is still WIP - https://github.com/pytorch/pytorch/issues/122094", + ), + ), ) def test_basic_splitter(executor, device: str, dtype: dtypes.dtype, dynamic: bool | None): x = torch.ones(2, 2, device=device, dtype=dtype, requires_grad=True) @@ -52,7 +59,8 @@ def test_basic_splitter(executor, device: str, dtype: dtypes.dtype, dynamic: boo backend = ThunderCompiler() def func(x): - # torch.sinc has automatic fallback registered. + # torch.sinc has automatic fallback registered, + # so that operation will be given to inductor. x = x.exp() y = torch.sinc(x) + torch.cos(x) return y + 1 @@ -71,7 +79,7 @@ def func(x): assert len(backend.subgraph_infos[0].submodule_to_compiled_functions) > 1 # Verify that the subgraph was split. 
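     # `submodule_to_compiled_functions` maps each split submodule to its `CompiledFunction`
     # (thunder- or inductor-compiled), so more than one entry means the graph was indeed split.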
assert any( "automatic torch fallback" in split_reason.info for split_reason in backend.subgraph_infos[0].split_reasons - ) + ) # Verify that we had a split because we detected an `automatic registered operator` targets = (node.target for node in backend.subgraph_infos[0].split_graph_module.graph.nodes) assert any(target.startswith("thunder_") for target in targets) # Verify that the submodules have name `thunder_*` @@ -79,7 +87,14 @@ def func(x): @instantiate( dtypes=NOTHING, executors=[DynamoThunderExecutor], - decorators=(pytest.mark.parametrize("dynamic", (True, False, None), ids=("dynamic", "static", "auto")),), + decorators=( + pytest.mark.parametrize("dynamic", (True, False, None), ids=("dynamic", "static", "auto")), + pytest.mark.xfail( + condition=IS_WINDOWS, + strict=True, + reason="torch.compile Windows support is still WIP - https://github.com/pytorch/pytorch/issues/122094", + ), + ), ) def test_splitter_unsupported_ctx(executor, device: str, dtype: dtypes.dtype, dynamic: bool | None): x = torch.rand(2, 2, device=device, dtype=dtype, requires_grad=True) @@ -119,7 +134,14 @@ def func(x): @instantiate( dtypes=NOTHING, executors=[DynamoThunderExecutor], - decorators=(pytest.mark.parametrize("dynamic", (True, False, None), ids=("dynamic", "static", "auto")),), + decorators=( + pytest.mark.parametrize("dynamic", (True, False, None), ids=("dynamic", "static", "auto")), + pytest.mark.xfail( + condition=IS_WINDOWS, + strict=True, + reason="torch.compile Windows support is still WIP - https://github.com/pytorch/pytorch/issues/122094", + ), + ), ) def test_splitter_unsupported_ctx_with_graph_break(executor, device: str, dtype: dtypes.dtype, dynamic: bool | None): x = torch.rand(2, 2, device=device, dtype=dtype, requires_grad=True)