From 4c9f389b3f6979c2cfab75ed7338d4f8cdfe22a1 Mon Sep 17 00:00:00 2001
From: Yan Wang
Date: Wed, 18 Dec 2024 13:58:40 +0100
Subject: [PATCH] makes is_constant_for_vjp return False for
 higher_order_autograd_function_apply

---
 thunder/core/transforms.py   |  2 ++
 thunder/tests/test_dynamo.py | 15 +++++++++++----
 2 files changed, 13 insertions(+), 4 deletions(-)

diff --git a/thunder/core/transforms.py b/thunder/core/transforms.py
index 3cb34a623e..b0b9fcc96b 100644
--- a/thunder/core/transforms.py
+++ b/thunder/core/transforms.py
@@ -2522,6 +2522,8 @@ def is_constant_for_vjp(symbol: prims.Symbol) -> bool:
     Returns:
         bool: True if the symbol is constant, False otherwise.
     """
+    if isinstance(symbol.sym.id, str) and symbol.sym.id.startswith("higher_order_autograd_function_apply"):
+        return False
     are_all_args_non_differentiable = not any(isinstance(arg, (FloatProxy, TensorProxy)) for arg in symbol.flat_args)
     # Symbol's tag their output in `torch.no_grad` regions with `DETACHED_AUTOGRAD_GRAPH`.
     # These are treated as constant for VJP.
diff --git a/thunder/tests/test_dynamo.py b/thunder/tests/test_dynamo.py
index c653a18310..bc9e23550a 100644
--- a/thunder/tests/test_dynamo.py
+++ b/thunder/tests/test_dynamo.py
@@ -274,7 +274,7 @@ def forward(ctx, x):
         @staticmethod
         def backward(ctx, g):
             (x,) = ctx.saved_tensors
-            return g * torch.cos(x)
+            return g * torch.cos(x) * 100

     def func(x):
         y = torch.cos(x) + Sin.apply(x)
@@ -286,9 +286,16 @@ def func(x):
     actual = cfunc(x)

     backend = cfunc._backend
-    targets = (node.target for node in backend.subgraph_infos[0].split_graph_module.graph.nodes)
-    assert any(target.startswith("thunder_") for target in targets)
-    assert any(target.startswith("inductor_") for target in targets)
+    assert len(backend.subgraph_infos) == 1  # no graph break in dynamo
+    subgraph_info = backend.subgraph_infos[0]
+    assert len(subgraph_info.split_reasons) == 0  # no split
+    assert len(subgraph_info.thunder_compiled_fns) == 1
+    jfunc = subgraph_info.thunder_compiled_fns[0]
+    trc = last_traces(jfunc)[0]
+    assert any(
+        isinstance(bsym.sym.id, str) and bsym.sym.id.startswith("higher_order_autograd_function_apply")
+        for bsym in trc.bound_symbols
+    )

     # Verify forward pass
     torch.testing.assert_close(actual, expected)
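
For context, a minimal sketch (plain PyTorch only, no thunder required) of the pattern the updated
test exercises. The Sin autograd.Function deliberately scales its backward by 100, so a gradient
produced without the custom backward (e.g. if the apply were treated as constant for VJP) is easy
to distinguish from one that used it. The save_for_backward call and the gradient check below are
assumptions added for illustration; only the Sin/func shapes mirror the test in the patch.

    import torch


    class Sin(torch.autograd.Function):
        @staticmethod
        def forward(ctx, x):
            # Save x so the custom backward can use it.
            ctx.save_for_backward(x)
            return torch.sin(x)

        @staticmethod
        def backward(ctx, g):
            (x,) = ctx.saved_tensors
            # Intentionally off by a factor of 100: if this backward runs,
            # the gradient differs from the analytical d/dx sin(x) = cos(x).
            return g * torch.cos(x) * 100


    def func(x):
        return torch.cos(x) + Sin.apply(x)


    x = torch.randn(4, requires_grad=True)
    func(x).sum().backward()
    # Expected gradient only if the custom backward actually ran:
    # d/dx cos(x) = -sin(x), plus the scaled 100 * cos(x) from Sin.backward.
    expected_grad = -torch.sin(x) + torch.cos(x) * 100
    torch.testing.assert_close(x.grad, expected_grad)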