Skip to content

Commit 53a3277

Browse files
committed
adjust tests details
1 parent 4ed48f6 commit 53a3277

File tree

2 files changed

+4
-3
lines changed

2 files changed

+4
-3
lines changed

thunder/tests/test_grad.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1938,6 +1938,7 @@ def func(x):
19381938
torch.testing.assert_close(actual, expected)
19391939
torch.testing.assert_close(actual_gr, expected_gr)
19401940

1941+
19411942
@pytest.mark.parametrize("device", ("cuda", "cpu"))
19421943
def test_backward_recomputation_decomposed_ops(device):
19431944
if device == "cuda" and not torch.cuda.is_available():
@@ -1951,8 +1952,8 @@ def fn(a):
19511952
a = torch.randn(2, 2, device=device, requires_grad=True)
19521953
res = jfn(a)
19531954
res2 = jfn2(a)
1954-
assert len(res.grad_fn.saved_tensors) == 3 # should be decomposed
1955-
assert len(res2.grad_fn.saved_tensors) == 1
1955+
assert len(res.grad_fn.next_functions[0][0].saved_tensors) == 3 # should be decomposed
1956+
assert len(res2.grad_fn.next_functions[0][0].saved_tensors) == 1
19561957

19571958
if NVFUSER_AVAILABLE and device == "cuda":
19581959
# check everything is fused

thunder/tests/test_networks.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -553,4 +553,4 @@ def test_hf_llama():
553553

554554
top_level_symbol_names = {bsym.sym.name for bsym in thunder.last_traces(jm)[-1].bound_symbols}
555555
# changes this to fewer as needed, the goal is to not have too many fusions
556-
assert len([s for s in top_level_symbol_names if s.startswith("nvFusion")]) == 7
556+
assert len([s for s in top_level_symbol_names if s.startswith("nvFusion")]) == 8

0 commit comments

Comments
 (0)