Skip to content

Commit

Permalink
makes is_constant_for_vjp return False for higher_order_autograd_func…
Browse files Browse the repository at this point in the history
…tion_apply
  • Loading branch information
kiya00 committed Dec 18, 2024
1 parent 21641ff commit 4c9f389
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 4 deletions.
2 changes: 2 additions & 0 deletions thunder/core/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -2522,6 +2522,8 @@ def is_constant_for_vjp(symbol: prims.Symbol) -> bool:
Returns:
bool: True if the symbol is constant, False otherwise.
"""
if isinstance(symbol.sym.id, str) and symbol.sym.id.startswith("higher_order_autograd_function_apply"):
return False
are_all_args_non_differentiable = not any(isinstance(arg, (FloatProxy, TensorProxy)) for arg in symbol.flat_args)
# Symbol's tag their output in `torch.no_grad` regions with `DETACHED_AUTOGRAD_GRAPH`.
# These are treated as constant for VJP.
Expand Down
15 changes: 11 additions & 4 deletions thunder/tests/test_dynamo.py
Original file line number Diff line number Diff line change
Expand Up @@ -274,7 +274,7 @@ def forward(ctx, x):
@staticmethod
def backward(ctx, g):
(x,) = ctx.saved_tensors
return g * torch.cos(x)
return g * torch.cos(x) * 100

def func(x):
y = torch.cos(x) + Sin.apply(x)
Expand All @@ -286,9 +286,16 @@ def func(x):
actual = cfunc(x)

backend = cfunc._backend
targets = (node.target for node in backend.subgraph_infos[0].split_graph_module.graph.nodes)
assert any(target.startswith("thunder_") for target in targets)
assert any(target.startswith("inductor_") for target in targets)
assert len(backend.subgraph_infos) == 1 # no graph break in dynamo
subgraph_info = backend.subgraph_infos[0]
assert len(subgraph_info.split_reasons) == 0 # no split
assert len(subgraph_info.thunder_compiled_fns) == 1
jfunc = subgraph_info.thunder_compiled_fns[0]
trc = last_traces(jfunc)[0]
assert any(
isinstance(bsym.sym.id, str) and bsym.sym.id.startswith("higher_order_autograd_function_apply")
for bsym in trc.bound_symbols
)

# Verify forward pass
torch.testing.assert_close(actual, expected)
Expand Down

0 comments on commit 4c9f389

Please sign in to comment.