[thunderFX] splitter - handle no_grad correctly (#1282)
kshitij12345 authored Oct 11, 2024
1 parent 16ad769 commit cc7335c
Showing 2 changed files with 65 additions and 2 deletions.
37 changes: 35 additions & 2 deletions thunder/dynamo/utils.py
@@ -18,6 +18,11 @@
auto_register_ops = set(itertools.chain(*torch_auto_registered_ops.values()))


# Currently, thunder has a mapping for these torch functions, but they
# just throw a warning.
UNSUPPORTED_THUNDER_FUNCTION = (torch._C._set_grad_enabled,)


class CompilerType(Enum):
"""
An enumeration representing different types of compilers.
@@ -221,10 +226,29 @@ def get_nodes_in_unsupported_ctx_regions(gm: torch.fx.GraphModule) -> set[torch.

# We want to mark nodes with `_enter_autocast` and `_exit_autocast`
# as unsupported, as `thunder` doesn't correctly deal with these stateful functions.

def is_no_grad_ctx_enter(node):
if node.target == torch._C._set_grad_enabled:
arg: bool = node.args[0]
assert isinstance(arg, bool)
return not arg # arg is False (i.e. grad was disabled)
return False

def is_no_grad_ctx_exit(node):
if node.target == torch._C._set_grad_enabled:
arg: bool = node.args[0]
assert isinstance(arg, bool)
return arg # arg is True (i.e. grad was enabled)
return False

for node in gm.graph.nodes:
if node.op == "call_function" and node.target in (torch.amp.autocast_mode._enter_autocast,):
if node.op == "call_function" and (
node.target in (torch.amp.autocast_mode._enter_autocast,) or is_no_grad_ctx_enter(node)
):
ctx_cnt += 1
elif node.op == "call_function" and node.target in (torch.amp.autocast_mode._exit_autocast,):
elif node.op == "call_function" and (
node.target in (torch.amp.autocast_mode._exit_autocast,) or is_no_grad_ctx_exit(node)
):
ctx_cnt -= 1
else:
if ctx_cnt > 0:
@@ -271,6 +295,15 @@ def is_node_supported_by_thunder(node: torch.fx.Node) -> tuple[bool, SplitReason
)
return False, split_reason

# These functions are present in `_torch_to_thunder_function_map` but don't mimic the exact behavior.
# E.g. torch._C._set_grad_enabled's thunder implementation just throws a warning that this is unsupported.
if target in UNSUPPORTED_THUNDER_FUNCTION:
split_reason = SplitReason(
SplitReasonType.UNSUPPORTED_NODE,
info=f"node with name: {node.name} and target: {node.target} has been manually disabled.",
)
return False, split_reason

# If thunder has a mapping for this operation, try executing the meta function and see.
# We have a symbol for `torch.where`, but we don't support one overload of it.
# So, we try and execute the meta to get a real signal.
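For context (not part of this commit): a minimal sketch of how a torch.no_grad() region shows up in the Dynamo-captured FX graph, namely as paired torch._C._set_grad_enabled call_function nodes with argument False on entry and True on exit, which is exactly what is_no_grad_ctx_enter/is_no_grad_ctx_exit match on. The inspect_backend name below is just for illustration.

import torch


def fn(x):
    with torch.no_grad():
        y = x @ x
    return y + x


def inspect_backend(gm: torch.fx.GraphModule, example_inputs):
    # Print the grad-mode mutation nodes Dynamo recorded for the no_grad() region.
    for node in gm.graph.nodes:
        if node.op == "call_function" and node.target is torch._C._set_grad_enabled:
            print(node.name, node.args)  # args are (False,) on enter and (True,) on exit
    return gm  # run the captured graph unchanged


torch.compile(fn, backend=inspect_backend)(torch.randn(3, 3, requires_grad=True))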
30 changes: 30 additions & 0 deletions thunder/tests/test_dynamo.py
@@ -435,3 +435,33 @@ def test_thundercompiler_optim_step(executor, device, dtype, optim):
tuple(ref_model.parameters()),
msg=lambda s: f"{i+1}-iter {s}",
)


@instantiate(dtypes=NOTHING, executors=[DynamoThunderExecutor])
def test_no_grad_ctx_manager(executor, device: str, dtype: dtypes.dtype):
backend = ThunderCompiler()

def func(x):
with torch.no_grad():
with torch.autocast("cuda", dtype=torch.bfloat16):
y = x @ x
return y + x

x = torch.randn(3, 3, device=device, dtype=dtype, requires_grad=True)
actual = torch.compile(func, backend=backend)(x)
expected = torch.compile(func, backend="eager")(x)

# We record the GraphModules that were compiled by ThunderCompiler
assert len(backend.subgraph_infos) == 1

for subgraph_info in backend.subgraph_infos:
assert len(subgraph_info.split_reasons) > 1 # Verify there were splits in the subgraph.
assert isinstance(subgraph_info.original_graph_module, torch.fx.GraphModule)
assert any("has been manually disabled" in split_reason.info for split_reason in subgraph_info.split_reasons)

torch.testing.assert_close(actual, expected)

g = torch.randn_like(actual)
actual_grad = torch.autograd.grad(actual, x, g)
expected_grad = torch.autograd.grad(expected, x, g)
torch.testing.assert_close(actual_grad, expected_grad)
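
A hedged usage sketch outside the test harness, covering just the no_grad part of the test above and assuming ThunderCompiler is importable from thunder.dynamo: compile with the Thunder backend and print the recorded split reasons to see torch._C._set_grad_enabled reported as manually disabled.

import torch
from thunder.dynamo import ThunderCompiler

backend = ThunderCompiler()


def fn(x):
    with torch.no_grad():
        y = x @ x
    return y + x


torch.compile(fn, backend=backend)(torch.randn(3, 3, requires_grad=True))

# Each compiled subgraph records why it was split; the no_grad region should
# surface a "has been manually disabled" reason for torch._C._set_grad_enabled.
for subgraph_info in backend.subgraph_infos:
    for split_reason in subgraph_info.split_reasons:
        print(split_reason.info)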
