Skip to content

Commit

Permalink
follow comments
Browse files Browse the repository at this point in the history
  • Loading branch information
kiya00 committed Oct 21, 2024
1 parent 4ff8580 commit 856141c
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 4 deletions.
4 changes: 2 additions & 2 deletions thunder/dynamo/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -332,8 +332,8 @@ def is_node_supported_by_thunder(node: torch.fx.Node) -> tuple[bool, SplitReason
if target is torch.ops.higher_order.tag_activation_checkpoint:
m = node.graph.owning_module
assert hasattr(m, node.args[0].name)
higher_order_module = getattr(m, node.args[0].name)
is_module_supported, split_reason = is_graphmodule_supported_by_thunder(higher_order_module)
checkpointed_fn = getattr(m, node.args[0].name)
is_module_supported, split_reason = is_graphmodule_supported_by_thunder(checkpointed_fn)
return is_module_supported, split_reason

# If thunder has a mapping for this operation, try executing the meta function and see.
Expand Down
4 changes: 2 additions & 2 deletions thunder/tests/test_dynamo.py
Original file line number Diff line number Diff line change
Expand Up @@ -574,10 +574,10 @@ def forward(self, x):
out = jf(x)
torch.testing.assert_close(ref_out, out)

g = torch.ones_like(out)
g = torch.randn_like(out)
out.backward(g)

ref_g = torch.ones_like(ref_out)
ref_g = g.clone()
ref_out.backward(ref_g)
torch.testing.assert_close(x.grad, x_ref.grad)
torch.testing.assert_close(tuple(model.parameters()), tuple(ref_model.parameters()))
Expand Down

0 comments on commit 856141c

Please sign in to comment.