From cc7335cc94102473cba6a07169a7a305f722f2a3 Mon Sep 17 00:00:00 2001
From: Kshiteej K
Date: Fri, 11 Oct 2024 13:33:06 +0200
Subject: [PATCH] [thunderFX] splitter - handle no_grad correctly (#1282)

---
 thunder/dynamo/utils.py      | 37 ++++++++++++++++++++++++++++++++++--
 thunder/tests/test_dynamo.py | 30 +++++++++++++++++++++++++++++
 2 files changed, 65 insertions(+), 2 deletions(-)

diff --git a/thunder/dynamo/utils.py b/thunder/dynamo/utils.py
index 34925438f3..b0ef08097e 100644
--- a/thunder/dynamo/utils.py
+++ b/thunder/dynamo/utils.py
@@ -18,6 +18,11 @@
 auto_register_ops = set(itertools.chain(*torch_auto_registered_ops.values()))
 
 
+# Currently, thunder has a mapping for these torch functions, but the mapped
+# implementations just warn that the operation is unsupported.
+UNSUPPORTED_THUNDER_FUNCTION = (torch._C._set_grad_enabled,)
+
+
 class CompilerType(Enum):
     """
     An enumeration representing different types of compilers.
@@ -221,10 +226,29 @@ def get_nodes_in_unsupported_ctx_regions(gm: torch.fx.GraphModule) -> set[torch.
 
     # We want to mark nodes with `_enter_autocast` and `_exit_autocast`
     # as unsupported as `thunder` doesn't correctly deal with these stateful functions.
+
+    def is_no_grad_ctx_enter(node):
+        if node.target == torch._C._set_grad_enabled:
+            arg: bool = node.args[0]
+            assert isinstance(arg, bool)
+            return not arg  # arg is False (i.e. grad was disabled)
+        return False
+
+    def is_no_grad_ctx_exit(node):
+        if node.target == torch._C._set_grad_enabled:
+            arg: bool = node.args[0]
+            assert isinstance(arg, bool)
+            return arg  # arg is True (i.e. grad was enabled)
+        return False
+
     for node in gm.graph.nodes:
-        if node.op == "call_function" and node.target in (torch.amp.autocast_mode._enter_autocast,):
+        if node.op == "call_function" and (
+            node.target in (torch.amp.autocast_mode._enter_autocast,) or is_no_grad_ctx_enter(node)
+        ):
             ctx_cnt += 1
-        elif node.op == "call_function" and node.target in (torch.amp.autocast_mode._exit_autocast,):
+        elif node.op == "call_function" and (
+            node.target in (torch.amp.autocast_mode._exit_autocast,) or is_no_grad_ctx_exit(node)
+        ):
             ctx_cnt -= 1
         else:
             if ctx_cnt > 0:
@@ -271,6 +295,15 @@ def is_node_supported_by_thunder(node: torch.fx.Node) -> tuple[bool, SplitReason
         )
         return False, split_reason
 
+    # These functions are present in `_torch_to_thunder_function_map` but don't mimic the exact behavior.
+    # E.g. torch._C._set_grad_enabled's thunder implementation just throws a warning that this is unsupported.
+    if target in UNSUPPORTED_THUNDER_FUNCTION:
+        split_reason = SplitReason(
+            SplitReasonType.UNSUPPORTED_NODE,
+            info=f"node with name: {node.name} and target: {node.target} has been manually disabled.",
+        )
+        return False, split_reason
+
     # If thunder has a mapping for this operation, try executing the meta function and see.
     # We have a symbol for `torch.where`, but we don't support one overload of it.
     # So, we try and execute the meta to get a real signal.
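For context: Dynamo desugars `with torch.no_grad():` into a pair of `torch._C._set_grad_enabled` call_function nodes in the captured FX graph, which is exactly what `is_no_grad_ctx_enter` / `is_no_grad_ctx_exit` above look for. A minimal sketch for seeing those nodes outside the splitter; the `print_graph_backend` helper is hypothetical, used only for illustration and not part of this patch:

import torch

def fn(x):
    with torch.no_grad():
        y = x @ x
    return y + x

# A throwaway Dynamo backend that prints every node of the captured graph
# and then runs the graph unchanged.
def print_graph_backend(gm: torch.fx.GraphModule, example_inputs):
    for node in gm.graph.nodes:
        print(node.op, node.target, node.args)
    return gm.forward

compiled = torch.compile(fn, backend=print_graph_backend)
compiled(torch.randn(3, 3, requires_grad=True))
# The output contains call_function nodes targeting torch._C._set_grad_enabled
# with args (False,) and (True,), bracketing the no_grad region that the
# splitter walks over.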
diff --git a/thunder/tests/test_dynamo.py b/thunder/tests/test_dynamo.py
index afc2e9e29d..a667405a13 100644
--- a/thunder/tests/test_dynamo.py
+++ b/thunder/tests/test_dynamo.py
@@ -435,3 +435,33 @@ def test_thundercompiler_optim_step(executor, device, dtype, optim):
             tuple(ref_model.parameters()),
             msg=lambda s: f"{i+1}-iter {s}",
         )
+
+
+@instantiate(dtypes=NOTHING, executors=[DynamoThunderExecutor])
+def test_no_grad_ctx_manager(executor, device: str, dtype: dtypes.dtype):
+    backend = ThunderCompiler()
+
+    def func(x):
+        with torch.no_grad():
+            with torch.autocast("cuda", dtype=torch.bfloat16):
+                y = x @ x
+        return y + x
+
+    x = torch.randn(3, 3, device=device, dtype=dtype, requires_grad=True)
+    actual = torch.compile(func, backend=backend)(x)
+    expected = torch.compile(func, backend="eager")(x)
+
+    # We record the GraphModules that were compiled by ThunderCompiler
+    assert len(backend.subgraph_infos) == 1
+
+    for subgraph_info in backend.subgraph_infos:
+        assert len(subgraph_info.split_reasons) > 1  # Verify there were splits in the subgraph.
+        assert isinstance(subgraph_info.original_graph_module, torch.fx.GraphModule)
+        assert any("has been manually disabled" in split_reason.info for split_reason in subgraph_info.split_reasons)
+
+    torch.testing.assert_close(actual, expected)
+
+    g = torch.randn_like(actual)
+    actual_grad = torch.autograd.grad(actual, x, g)
+    expected_grad = torch.autograd.grad(expected, x, g)
+    torch.testing.assert_close(actual_grad, expected_grad)
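The same split reasons can be inspected outside the test harness. A sketch, assuming `ThunderCompiler` is importable from `thunder.dynamo` as the test module uses it; the autocast block is dropped here so the example is device-agnostic:

import torch
from thunder.dynamo import ThunderCompiler

backend = ThunderCompiler()

def fn(x):
    with torch.no_grad():
        y = x @ x
    return y + x

compiled = torch.compile(fn, backend=backend)
compiled(torch.randn(3, 3, requires_grad=True))

# Each graph Dynamo handed to the backend is recorded in subgraph_infos;
# split_reasons lists every cut the splitter made, including the
# "has been manually disabled" reason added for torch._C._set_grad_enabled.
for subgraph_info in backend.subgraph_infos:
    for split_reason in subgraph_info.split_reasons:
        print(split_reason.info)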