diff --git a/thunder/dynamo/utils.py b/thunder/dynamo/utils.py
index b4b83fb028..0244975fa0 100644
--- a/thunder/dynamo/utils.py
+++ b/thunder/dynamo/utils.py
@@ -332,8 +332,8 @@ def is_node_supported_by_thunder(node: torch.fx.Node) -> tuple[bool, SplitReason
     if target is torch.ops.higher_order.tag_activation_checkpoint:
         m = node.graph.owning_module
         assert hasattr(m, node.args[0].name)
-        higher_order_module = getattr(m, node.args[0].name)
-        is_module_supported, split_reason = is_graphmodule_supported_by_thunder(higher_order_module)
+        checkpointed_fn = getattr(m, node.args[0].name)
+        is_module_supported, split_reason = is_graphmodule_supported_by_thunder(checkpointed_fn)
         return is_module_supported, split_reason

     # If thunder has a mapping for this operation, try executing the meta function and see.
diff --git a/thunder/tests/test_dynamo.py b/thunder/tests/test_dynamo.py
index e5733652c1..a3f76e2ca4 100644
--- a/thunder/tests/test_dynamo.py
+++ b/thunder/tests/test_dynamo.py
@@ -574,10 +574,10 @@ def forward(self, x):
     out = jf(x)
     torch.testing.assert_close(ref_out, out)

-    g = torch.ones_like(out)
+    g = torch.randn_like(out)
     out.backward(g)

-    ref_g = torch.ones_like(ref_out)
+    ref_g = g.clone()
     ref_out.backward(ref_g)
     torch.testing.assert_close(x.grad, x_ref.grad)
     torch.testing.assert_close(tuple(model.parameters()), tuple(ref_model.parameters()))
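
For context, here is a minimal sketch of the gradient-comparison pattern the updated test adopts, in plain PyTorch with activation checkpointing. The point of the test change: backpropagating an all-ones cotangent (`torch.ones_like`) is a weak check, since a backward pass that mishandles the incoming gradient's values can still produce matching results; a random cotangent (`torch.randn_like`) exercises the backward more thoroughly, and `ref_g = g.clone()` ensures both the compiled and the reference graph receive the same values (with `ones_like` they matched trivially). The helper `run_checkpointed` and the toy models below are illustrative stand-ins, not from the PR; in the actual test the compiled callable is the thunder-jitted `jf`.

```python
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint

# Hypothetical stand-in for the compiled callable under test (`jf` in
# test_dynamo.py); here it simply runs the model under activation
# checkpointing, so activations are recomputed during backward.
def run_checkpointed(model: nn.Module, x: torch.Tensor) -> torch.Tensor:
    return checkpoint(model, x, use_reentrant=False)

model = nn.Sequential(nn.Linear(8, 8), nn.GELU(), nn.Linear(8, 8))
ref_model = nn.Sequential(nn.Linear(8, 8), nn.GELU(), nn.Linear(8, 8))
ref_model.load_state_dict(model.state_dict())

x = torch.randn(4, 8, requires_grad=True)
x_ref = x.detach().clone().requires_grad_()

out = run_checkpointed(model, x)
ref_out = ref_model(x_ref)
torch.testing.assert_close(ref_out, out)

# Random cotangents catch backward bugs that an all-ones gradient can mask.
g = torch.randn_like(out)
out.backward(g)
# Clone so both graphs see identical, independent cotangent values.
ref_out.backward(g.clone())
torch.testing.assert_close(x.grad, x_ref.grad)
```

Checkpointing recomputes the forward during backward rather than storing activations, so the resulting gradients should match the non-checkpointed reference exactly; the same equivalence is what the thunder test asserts for the jitted module.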