thunderFX : pass to remove empty autocast regions (#1400)
kshitij12345 authored Nov 19, 2024
1 parent 052bac3 commit a617503
Showing 3 changed files with 99 additions and 1 deletion.
4 changes: 3 additions & 1 deletion thunder/dynamo/compiler.py
@@ -7,7 +7,7 @@
 import torch
 
 from thunder.core.baseutils import run_once
-from thunder.dynamo.utils import recompile_graph
+from thunder.dynamo.utils import recompile_graph, remove_empty_autocast
 from thunder.dynamo.splitter import _splitter
 
 if TYPE_CHECKING:
@@ -72,6 +72,8 @@ def __init__(self, **thunder_options):
         self._torch_compile = partial(torch.compile, **torch_inductor_options)
 
     def __call__(self, gm: torch.fx.GraphModule, sample_args: list[torch.SymInt, torch.Tensor]):
+        gm = remove_empty_autocast(gm)
+
         # Dynamo uses lazy generation of the underlying Python code, so we need to
         # force recompilation of the GraphModule before passing it to Thunder.
         recompile_graph(gm)
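For context (not part of this commit's diff): a minimal sketch of how the backend above is typically driven, assuming `ThunderCompiler` is importable from `thunder.dynamo` as in the tests further down. Every FX graph that Dynamo captures is passed to `ThunderCompiler.__call__`, so `remove_empty_autocast` now runs before the graph is split and compiled.

import torch

from thunder.dynamo import ThunderCompiler

def fn(x):
    with torch.autocast(dtype=torch.bfloat16, device_type="cpu"):
        pass  # empty region: stripped by the new pass before Thunder sees the graph
    return x @ x

backend = ThunderCompiler()
jfn = torch.compile(fn, backend=backend)  # __call__ above runs once per captured graph
jfn(torch.randn(3, 3))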
41 changes: 41 additions & 0 deletions thunder/dynamo/utils.py
@@ -512,3 +512,44 @@ def checkpoint_converter(gm: torch.fx.GraphModule, sub_gm: torch.fx.GraphModule):
else:
function_module = getattr(gm, n.args[0].name)
_checkpoint_function_converter(function_module)


def remove_empty_autocast(graph_module: torch.fx.GraphModule) -> torch.fx.GraphModule:
"""
Function to remove empty autocast regions from GraphModule.
Dynamo can provide empty autocast regions in which case, it is more performant to remove them
from the graph than to compile them and pay the cost of calling a wrapped optimized function
which does nothing.
Args:
graph_module: Graph module to which this pass is applied.
"""

    empty_autocast_removed_graph_module = copy.deepcopy(graph_module)

    # Dummy initial node so that the first comparison in the loop below is well-defined.
    prev_node = torch.fx.node.Node(graph_module.graph, "start_node", "call_function", lambda: None, None, None)
    nodes_to_erase = []
    for node in empty_autocast_removed_graph_module.graph.nodes:
        # `_enter_autocast` and `_exit_autocast` bracket the region created by the context manager,
        # so an `_enter_autocast` immediately followed by an `_exit_autocast` marks an empty region.
        if (
            prev_node.target == torch.amp.autocast_mode._enter_autocast
            and node.target == torch.amp.autocast_mode._exit_autocast
        ):
            # NOTE: The order in which the nodes are appended matters.
            # A node can only be erased once it has zero users, so `_exit_autocast`
            # (which consumes the output of `_enter_autocast`) is erased first,
            # and then the corresponding `_enter_autocast` can be erased.
            nodes_to_erase.append(node)
            nodes_to_erase.append(prev_node)

        prev_node = node

    # Erase the marked nodes.
    for node in nodes_to_erase:
        empty_autocast_removed_graph_module.graph.erase_node(node)

    return empty_autocast_removed_graph_module
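Not part of the diff: a small self-contained sketch of the node pattern this pass targets, assuming a PyTorch version in which Dynamo traces `torch.autocast` regions into `_enter_autocast`/`_exit_autocast` call_function nodes (the premise of this commit). The capture backend and `fn` below are illustrative only.

import torch

from thunder.dynamo.utils import remove_empty_autocast

captured = []

def capture_backend(gm: torch.fx.GraphModule, example_inputs):
    # Trivial Dynamo backend that only records the captured graph.
    captured.append(gm)
    return gm.forward

def fn(x):
    with torch.autocast(dtype=torch.bfloat16, device_type="cpu"):
        pass  # empty region: still traced as _enter_autocast/_exit_autocast nodes
    return x + 1

torch.compile(fn, backend=capture_backend)(torch.randn(3))

autocast_ops = (torch.amp.autocast_mode._enter_autocast, torch.amp.autocast_mode._exit_autocast)
print(sum(n.target in autocast_ops for n in captured[0].graph.nodes))  # expected: 2 before the pass
print(sum(n.target in autocast_ops for n in remove_empty_autocast(captured[0]).graph.nodes))  # expected: 0 after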
55 changes: 55 additions & 0 deletions thunder/tests/test_dynamo.py
@@ -1,5 +1,6 @@
 import pytest
 import warnings
+import itertools
 import torch
 import torch.fx
 import torch.nn as nn
@@ -515,6 +516,60 @@ def func(x):
torch.testing.assert_close(actual_grad, expected_grad)


def test_empty_autocast():
    autocast_ops = (torch.amp.autocast_mode._enter_autocast, torch.amp.autocast_mode._exit_autocast)

    def _call_thunder_backend(fn, args):
        backend = ThunderCompiler()
        jf = torch.compile(backend=backend)(fn)
        jf(*args)
        return backend

    # The autocast region is removed.
    def f():
        with torch.autocast(dtype=torch.bfloat16, device_type="cpu"):
            pass
        return

    backend = _call_thunder_backend(f, ())
    assert all(node.target not in autocast_ops for node in backend.subgraph_infos[0].split_graph_module.graph.nodes)

    # Both autocast regions are removed.
    def f(x):
        with torch.autocast(dtype=torch.bfloat16, device_type="cpu"):
            pass
        y = x @ x
        with torch.autocast(dtype=torch.bfloat16, device_type="cpu"):
            pass
        return y

    x = torch.randn(3, 3)
    backend = _call_thunder_backend(f, (x,))

    all_nodes = itertools.chain(
        backend.subgraph_infos[0].split_graph_module.graph.nodes,
        backend.subgraph_infos[0].split_graph_module.thunder_1.graph.nodes,
    )
    assert all(node.target not in autocast_ops for node in all_nodes)

    # The first autocast region is removed and the second isn't.
    def f(x):
        with torch.autocast(dtype=torch.bfloat16, device_type="cpu"):
            pass
        y = x @ x
        with torch.autocast(dtype=torch.bfloat16, device_type="cpu"):
            y = y @ y
        return y

    x = torch.randn(3, 3)
    backend = _call_thunder_backend(f, (x,))
    all_nodes = itertools.chain(
        backend.subgraph_infos[0].split_graph_module.graph.nodes,
        backend.subgraph_infos[0].split_graph_module.thunder_1.graph.nodes,
    )
    # The non-empty second region keeps its _enter_autocast/_exit_autocast pair, so exactly two remain.
    assert sum(node.target in autocast_ops for node in all_nodes) == 2


# Sample command to run the benchmark using ThunderCompilerGraphBenchmarking
# pytest thunder/tests/test_dynamo.py -k test_ThunderCompilerGraphBenchmarking_groupby --benchmark-group-by='graph-by-graph:param:GraphID,param:SplitModuleName'
# For more details, see :class:`thunder.dynamo.compiler_graph_benchmark.ThunderCompilerGraphBenchmarking`
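An editor's note, not part of the commit: following the sample-command convention this test file already uses, the new test can be selected on its own with standard pytest filtering.

# Sample command to run only the new test:
# pytest thunder/tests/test_dynamo.py -k test_empty_autocast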
