thunderFX: delegate autocast regions to thunder (#1378)
kshitij12345 authored Nov 4, 2024
1 parent dcf0729 commit 3d42c10
Showing 3 changed files with 87 additions and 24 deletions.
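For orientation, here is a minimal usage sketch distilled from the new tests below (the toy `func` is illustrative, not part of the commit): with this change, a `torch.autocast` region inside a `torch.compile`d function is delegated to Thunder instead of forcing a split to Inductor.

```python
import torch
from thunder.dynamo import ThunderCompiler

backend = ThunderCompiler()

def func(x):
    x = x + 2
    with torch.autocast("cpu"):
        y = torch.sin(x)
    return torch.matmul(x, y)

cfunc = torch.compile(func, backend=backend)
out = cfunc(torch.rand(2, 2, requires_grad=True))

# The whole graph, including the autocast region, now goes to Thunder:
# no split reasons are recorded for the captured subgraph.
assert len(backend.subgraph_infos[0].split_reasons) == 0
```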
27 changes: 18 additions & 9 deletions thunder/dynamo/utils.py
@@ -127,8 +127,13 @@ def get_proxy_inputs_from_node(node: torch.fx.Node) -> tuple[tuple, dict]:
with thunder.core.trace.tracectx(TraceCtx()):

def make_tensor_proxy(arg_node):
# This is a Node in the graph representing a Tensor or tuple of Tensors.
# This is a Node in the graph representing a Tensor or tuple of Tensors or
# a PyTorch object like one representing torch.autocast.
if isinstance(arg_node, torch.fx.Node):
if "example_value" not in arg_node.meta:
# This is a non-tensor object, e.g. a `torch.autocast` context manager object.
return arg_node

example_value = arg_node.meta["example_value"]

if isinstance(example_value, torch.Tensor):
@@ -176,7 +181,15 @@ def try_execute_thunder_symbol(thunder_symbol: Symbol, node: torch.fx.Node) -> t
"""
import thunder
from thunder.core.trace import TraceCtx
from thunder.core.compile_data import compile_data_and_stats
from thunder.common import CompileData, CompileStats

# This is required for verifying `_enter_autocast`
# which pushes state onto `CompileData.autocast_stack`.
cd = CompileData(fn=lambda x: x, disable_preprocessing=True)
cs = CompileStats()

@compile_data_and_stats(cd, cs)
@thunder._with_cache_info_ctx
def _run_with_cache_info():

@@ -226,8 +239,8 @@ def get_nodes_in_unsupported_ctx_regions(gm: torch.fx.GraphModule) -> set[torch.
nodes_in_unsupported_ctx_regions: set[torch.fx.Node] = set()
ctx_cnt = 0  # Count of context-manager enters we have seen so far

# We want to mark nodes with `_enter_autocast` and `_exit_autocast`
# as unsupported as `thunder` doesn't correctly deal with these stateful functions.
# We want to mark nodes inside regions that disable `autograd` as unsupported
# because `thunder` doesn't correctly deal with these stateful functions.

def is_no_grad_ctx_enter(node):
if node.target == torch._C._set_grad_enabled:
@@ -244,13 +257,9 @@ def is_no_grad_ctx_exit(node):
return False

for node in gm.graph.nodes:
if node.op == "call_function" and (
node.target in (torch.amp.autocast_mode._enter_autocast,) or is_no_grad_ctx_enter(node)
):
if node.op == "call_function" and is_no_grad_ctx_enter(node):
ctx_cnt += 1
elif node.op == "call_function" and (
node.target in (torch.amp.autocast_mode._exit_autocast,) or is_no_grad_ctx_exit(node)
):
elif node.op == "call_function" and is_no_grad_ctx_exit(node):
ctx_cnt -= 1
else:
if ctx_cnt > 0:
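To illustrate why the `make_tensor_proxy` change above keys off `"example_value"` (a rough sketch using a hypothetical capture-only backend, not part of the commit): Dynamo stores a fake-tensor `example_value` in `node.meta` for tensor-producing nodes, while the `_enter_autocast`/`_exit_autocast` nodes return a context-manager object and carry no such entry, so they are now passed through to Thunder untouched.

```python
import torch

def f(x):
    with torch.autocast("cpu"):
        return torch.sin(x)

captured = []

def capture_backend(gm: torch.fx.GraphModule, example_inputs):
    # Toy backend: record the FX graph Dynamo produced and run it eagerly.
    captured.append(gm)
    return gm.forward

torch.compile(f, backend=capture_backend)(torch.rand(2, 2))

for node in captured[0].graph.nodes:
    # Tensor nodes carry "example_value" in node.meta; the autocast
    # enter/exit call_function nodes do not, which is how they are told
    # apart from tensor arguments in make_tensor_proxy.
    print(node.op, node.target, "example_value" in node.meta)
```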
75 changes: 61 additions & 14 deletions thunder/tests/test_dynamo.py
@@ -7,6 +7,7 @@

from thunder import dtypes
from thunder.dynamo import ThunderCompiler
from thunder.dynamo.utils import CompilerType
from thunder.dynamo.compiler_graph_benchmark import ThunderCompilerGraphBenchmarking
from thunder import last_traces
from thunder.core.symbol import Symbol
@@ -126,7 +127,7 @@ def func(x):
),
),
)
def test_splitter_unsupported_ctx(executor, device: str, dtype: dtypes.dtype, dynamic: bool | None):
def test_splitter_autocast_ctx(executor, device: str, dtype: dtypes.dtype, dynamic: bool | None):
x = torch.rand(2, 2, device=device, dtype=dtype, requires_grad=True)

backend = ThunderCompiler()
@@ -149,15 +150,10 @@ def func(x):
torch.testing.assert_close(actual_grad, expected_grad)

assert len(backend.subgraph_infos) == 1
assert len(backend.subgraph_infos[0].submodule_to_compiled_functions) > 1 # Verify that the subgraph was split.
assert any(
"it is in unsupported context" in split_reason.info for split_reason in backend.subgraph_infos[0].split_reasons
)
targets = (node.target for node in backend.subgraph_infos[0].split_graph_module.graph.nodes)
assert any(target.startswith("thunder_") for target in targets) # Verify that the submodules have name `thunder_*`
assert any(
target.startswith("inductor_") for target in targets
) # Verify that the submodules have name `inductor_*`
assert len(backend.subgraph_infos[0].split_reasons) == 0
compiled_functions = tuple(backend.subgraph_infos[0].submodule_to_compiled_functions.values())
assert all(compiled_fn.compiler == CompilerType.THUNDER for compiled_fn in compiled_functions)
assert not any(compiled_fn.compiler == CompilerType.TORCH_INDUCTOR for compiled_fn in compiled_functions)


@instantiate(
@@ -172,7 +168,7 @@ def func(x):
),
),
)
def test_splitter_unsupported_ctx_with_graph_break(executor, device: str, dtype: dtypes.dtype, dynamic: bool | None):
def test_splitter_autocast_ctx_with_graph_break(executor, device: str, dtype: dtypes.dtype, dynamic: bool | None):
x = torch.rand(2, 2, device=device, dtype=dtype, requires_grad=True)

backend = ThunderCompiler()
@@ -184,7 +180,7 @@ def func(x):
torch._dynamo.graph_break()
return torch.matmul(x, y)

expected = torch.compile(func, dynamic=False)(x)
expected = torch.compile(func, dynamic=dynamic)(x)
cfunc = torch.compile(func, backend=backend, dynamic=dynamic)
actual = cfunc(x)

@@ -197,8 +193,59 @@ def func(x):
# 2 subgraphs due to graph-break
assert len(backend.subgraph_infos) == 2
for subgraph_info in backend.subgraph_infos:
# Verify that for each subgraph we had split due to `autocast` being enabled.
assert any("it is in unsupported context" in split_reason.info for split_reason in subgraph_info.split_reasons)
assert len(subgraph_info.split_reasons) == 0
compiled_functions = tuple(subgraph_info.submodule_to_compiled_functions.values())
assert all(compiled_fn.compiler == CompilerType.THUNDER for compiled_fn in compiled_functions)
assert not any(compiled_fn.compiler == CompilerType.TORCH_INDUCTOR for compiled_fn in compiled_functions)


@instantiate(
dtypes=NOTHING,
executors=[DynamoThunderExecutor],
decorators=(
pytest.mark.parametrize("dynamic", (True, False, None), ids=("dynamic", "static", "auto")),
pytest.mark.xfail(
condition=IS_WINDOWS,
strict=True,
reason="torch.compile Windows support is still WIP - https://github.com/pytorch/pytorch/issues/122094",
),
),
)
def test_splitter_autocast_ctx_with_split(executor, device: str, dtype: dtypes.dtype, dynamic: bool | None):
x = torch.rand(2, 2, device=device, dtype=dtype, requires_grad=True)

backend = ThunderCompiler()

def func(x):
x = x + 2
with torch.autocast(device):
y = torch.sin(x)

# torch.sinc has an automatic fallback registered,
# so that operation will be given to inductor.
y = torch.sinc(y)
return torch.matmul(x, y)

expected = torch.compile(func, dynamic=dynamic)(x)
cfunc = torch.compile(func, backend=backend, dynamic=dynamic)
actual = cfunc(x)

g = torch.rand_like(actual)
torch.testing.assert_close(actual, expected)
actual_grad = torch.autograd.grad(actual, x, g)
expected_grad = torch.autograd.grad(expected, x, g)
torch.testing.assert_close(actual_grad, expected_grad)

assert len(backend.subgraph_infos) == 1 # no graph break in dynamo

subgraph_info = backend.subgraph_infos[0]
assert len(subgraph_info.split_reasons) > 1 # Split due to `torch.sinc`
compiled_functions = tuple(subgraph_info.submodule_to_compiled_functions.values())
assert any(compiled_fn.compiler == CompilerType.THUNDER for compiled_fn in compiled_functions)
assert any(compiled_fn.compiler == CompilerType.TORCH_INDUCTOR for compiled_fn in compiled_functions)
assert any(
"automatic torch fallback" in split_reason.info for split_reason in subgraph_info.split_reasons
) # Verify that we had a split because we detected an automatically registered torch fallback


@instantiate(
9 changes: 8 additions & 1 deletion thunder/torch/__init__.py
@@ -5616,7 +5616,12 @@ def backward_autograd_function_apply(
id="torch.amp.autocast_mode._enter_autocast",
tags=(prims.OpTags.DONT_DCE, prims.OpTags.CTX_MANAGER_ENTER_EXIT_OP),
)
def autocast_enter(device_type, dtype=None, enabled=True):
def autocast_enter(device_type, dtype=None, enabled=True, _unused_cache_enabled=True):
# We may receive device_type="cuda:0", but PyTorch applies autocast
# irrespective of the device index, so here we extract just the device
# type from the string.
device_type, unused_deviceno = devices._device_from_string_helper(device_type)
device_type = devices.devicetype_string(device_type)
if dtype is None:
dtype = torch.get_autocast_dtype(device_type)
get_compile_data().autocast_stack.push(device_type, dtype, enabled)
@@ -5628,6 +5633,8 @@ def autocast_enter(device_type, dtype=None, enabled=True):
tags=(prims.OpTags.DONT_DCE, prims.OpTags.CTX_MANAGER_ENTER_EXIT_OP),
)
def autocast_exit(*args):
if get_compile_data().autocast_stack.is_empty():
return
get_compile_data().autocast_stack.pop()


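The device-string normalization in `autocast_enter` above exists because autocast state is keyed on the device type only; here is a sketch of the intent with a hypothetical helper (the commit uses thunder's `devices._device_from_string_helper` and `devices.devicetype_string` for the real work):

```python
def _device_type_only(device_str: str) -> str:
    # "cuda:0" -> "cuda"; "cpu" -> "cpu". PyTorch applies autocast per
    # device type, irrespective of the device index.
    return device_str.split(":", 1)[0]

assert _device_type_only("cuda:0") == "cuda"
assert _device_type_only("cpu") == "cpu"
```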