diff --git a/thunder/dynamo/compiler.py b/thunder/dynamo/compiler.py
index f55ee29657..d61aa10812 100644
--- a/thunder/dynamo/compiler.py
+++ b/thunder/dynamo/compiler.py
@@ -7,7 +7,7 @@
 import torch
 
 from thunder.core.baseutils import run_once
-from thunder.dynamo.utils import recompile_graph
+from thunder.dynamo.utils import recompile_graph, remove_empty_autocast
 from thunder.dynamo.splitter import _splitter
 
 if TYPE_CHECKING:
@@ -72,6 +72,8 @@ def __init__(self, **thunder_options):
         self._torch_compile = partial(torch.compile, **torch_inductor_options)
 
     def __call__(self, gm: torch.fx.GraphModule, sample_args: list[torch.SymInt, torch.Tensor]):
+        gm = remove_empty_autocast(gm)
+
         # Dynamo uses lazy generation of the underlying Python code, so we need to
         # force recompilation of the GraphModule before passing it to Thunder.
         recompile_graph(gm)
diff --git a/thunder/dynamo/utils.py b/thunder/dynamo/utils.py
index 8b4c690c0a..d434d02342 100644
--- a/thunder/dynamo/utils.py
+++ b/thunder/dynamo/utils.py
@@ -512,3 +512,44 @@ def checkpoint_converter(gm: torch.fx.GraphModule, sub_gm: torch.fx.GraphModule)
         else:
             function_module = getattr(gm, n.args[0].name)
             _checkpoint_function_converter(function_module)
+
+
+def remove_empty_autocast(graph_module: torch.fx.GraphModule) -> torch.fx.GraphModule:
+    """
+    Remove empty autocast regions from a GraphModule.
+
+    Dynamo can produce empty autocast regions, in which case it is more performant to remove them
+    from the graph than to compile them and pay the cost of calling a wrapped optimized function
+    that does nothing.
+
+    Args:
+        graph_module: Graph module to which this pass is applied.
+
+    """
+
+    empty_autocast_removed_graph_module = copy.deepcopy(graph_module)
+
+    # Dummy init node.
+    prev_node = torch.fx.node.Node(graph_module.graph, "start_node", "call_function", lambda: None, None, None)
+    nodes_to_erase = []
+    for node in empty_autocast_removed_graph_module.graph.nodes:
+        # As `_enter_autocast` and `_exit_autocast` delimit the regions created by the context manager,
+        # an empty region shows up as an `_enter_autocast` immediately followed by an `_exit_autocast`.
+        if (
+            prev_node.target == torch.amp.autocast_mode._enter_autocast
+            and node.target == torch.amp.autocast_mode._exit_autocast
+        ):
+            # NOTE: The order in which the nodes are appended matters.
+            # A node can only be erased once it has zero users, so we remove
+            # `_exit_autocast` first (it consumes the output of `_enter_autocast`)
+            # and only then the corresponding `_enter_autocast`.
+            nodes_to_erase.append(node)
+            nodes_to_erase.append(prev_node)
+
+        prev_node = node
+
+    # Erase the marked nodes.
+    for node in nodes_to_erase:
+        empty_autocast_removed_graph_module.graph.erase_node(node)
+
+    return empty_autocast_removed_graph_module
diff --git a/thunder/tests/test_dynamo.py b/thunder/tests/test_dynamo.py
index 2f9bb0d124..da9129dcbe 100644
--- a/thunder/tests/test_dynamo.py
+++ b/thunder/tests/test_dynamo.py
@@ -1,5 +1,6 @@
 import pytest
 import warnings
+import itertools
 import torch
 import torch.fx
 import torch.nn as nn
@@ -515,6 +516,60 @@ def func(x):
     torch.testing.assert_close(actual_grad, expected_grad)
 
 
+def test_empty_autocast():
+    autocast_ops = (torch.amp.autocast_mode._enter_autocast, torch.amp.autocast_mode._exit_autocast)
+
+    def _call_thunder_backend(fn, args):
+        backend = ThunderCompiler()
+        jf = torch.compile(backend=backend)(fn)
+        jf(*args)
+        return backend
+
+    # Autocast region is removed
+    def f():
+        with torch.autocast(dtype=torch.bfloat16, device_type="cpu"):
+            pass
+        return
+
+    backend = _call_thunder_backend(f, ())
+    assert all(node.target not in autocast_ops for node in backend.subgraph_infos[0].split_graph_module.graph.nodes)
+
+    # Both autocast regions are removed
+    def f(x):
+        with torch.autocast(dtype=torch.bfloat16, device_type="cpu"):
+            pass
+        y = x @ x
+        with torch.autocast(dtype=torch.bfloat16, device_type="cpu"):
+            pass
+        return y
+
+    x = torch.randn(3, 3)
+    backend = _call_thunder_backend(f, (x,))
+
+    all_nodes = itertools.chain(
+        backend.subgraph_infos[0].split_graph_module.graph.nodes,
+        backend.subgraph_infos[0].split_graph_module.thunder_1.graph.nodes,
+    )
+    assert all(node.target not in autocast_ops for node in all_nodes)
+
+    # First autocast region is removed and second isn't
+    def f(x):
+        with torch.autocast(dtype=torch.bfloat16, device_type="cpu"):
+            pass
+        y = x @ x
+        with torch.autocast(dtype=torch.bfloat16, device_type="cpu"):
+            y = y @ y
+        return y
+
+    x = torch.randn(3, 3)
+    backend = _call_thunder_backend(f, (x,))
+    all_nodes = itertools.chain(
+        backend.subgraph_infos[0].split_graph_module.graph.nodes,
+        backend.subgraph_infos[0].split_graph_module.thunder_1.graph.nodes,
+    )
+    assert sum(node.target in autocast_ops for node in all_nodes) == 2
+
+
 # Sample command to run the benchmark using ThunderCompilerGraphBenchmarking
 # pytest thunder/tests/test_dynamo.py -k test_ThunderCompilerGraphBenchmarking_groupby --benchmark-group-by='graph-by-graph:param:GraphID,param:SplitModuleName'
 # For more details, see :class:`thunder.dynamo.compiler_graph_benchmark.ThunderCompilerGraphBenchmarking`
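
Illustrative sketch (not part of the patch): the snippet below shows the kind of FX graph this pass targets and how the new `remove_empty_autocast` helper would be applied standalone, assuming the diff above is installed. The `inspecting_backend` helper, the `captured` dict, and the sample function `fn` are made up for illustration; only `torch.compile`, the private `torch.amp.autocast_mode._enter_autocast`/`_exit_autocast` targets, and the `remove_empty_autocast` signature come from the diff.

import torch
import torch.fx

from thunder.dynamo.utils import remove_empty_autocast  # added by this patch

captured = {}


def inspecting_backend(gm: torch.fx.GraphModule, example_inputs):
    # Capture the FX graph that Dynamo produced and run it eagerly.
    captured["gm"] = gm
    return gm.forward


def fn(x):
    with torch.autocast(dtype=torch.bfloat16, device_type="cpu"):
        pass  # empty region: Dynamo still records _enter_autocast/_exit_autocast
    return x + 1


torch.compile(fn, backend=inspecting_backend)(torch.randn(2))

gm = captured["gm"]
autocast_ops = (torch.amp.autocast_mode._enter_autocast, torch.amp.autocast_mode._exit_autocast)

# The captured graph keeps the enter/exit pair even though the region is empty.
print(sum(n.target in autocast_ops for n in gm.graph.nodes))  # expected: 2

# The pass returns a deep copy with the paired, empty enter/exit nodes erased.
cleaned = remove_empty_autocast(gm)
print(sum(n.target in autocast_ops for n in cleaned.graph.nodes))  # expected: 0

The erase order noted in the pass matters because `torch.fx.Graph.erase_node` refuses to remove a node that still has users, so `_exit_autocast` (the consumer) must go before its matching `_enter_autocast`.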