diff --git a/thunder/tests/test_jit_general.py b/thunder/tests/test_jit_general.py
index aa0fe8728c..e59f9adef0 100644
--- a/thunder/tests/test_jit_general.py
+++ b/thunder/tests/test_jit_general.py
@@ -836,20 +836,13 @@ def test_tom_overrides_proxy(device):
     assert v is params_actual[k]
 
 
-@pytest.mark.parametrize(
-    "device",
-    ("cpu", "cuda"),
-)
-def test_cache_symbolic_values_basic(device):
-    if device == "cuda" and not torch.cuda.is_available():
-        pytest.skip("CUDA not available")
-
+def test_cache_symbolic_values_basic():
     def foo(a, scalar):
         return (a * scalar).sum(scalar)
 
     jfoo = thunder.jit(foo, cache="symbolic values")
-    a = torch.randn((2, 2, 2), device=device)
+    a = torch.randn((2, 2, 2), device="cpu")
     b = 1
 
     actual = jfoo(a, b)
 
@@ -1118,15 +1111,8 @@ def forward(self, x):
     torch.testing.assert_close(actual, expected)
 
 
-@pytest.mark.parametrize(
-    "device",
-    ("cpu", "cuda"),
-)
-def test_cache_symbolic_values_reshape(device):
-    if device == "cuda" and not torch.cuda.is_available():
-        pytest.skip("CUDA not available")
-
-    a = torch.randn((4, 8, 6), device=device)
+def test_cache_symbolic_values_reshape():
+    a = torch.randn((4, 8, 6), device="cpu")
 
     def foo(t, batch_size):
         return t.reshape(batch_size, -1).sum(-1)
@@ -1421,20 +1407,13 @@ def fn(a0, a1, a2, a3, a4, a5, a6, a7, a8, a9, a10):
     assert [a.name for a in thunder.last_traces(fn)[-1].args] == [f"a{i}" for i in range(11)]
 
 
-@pytest.mark.parametrize(
-    "device",
-    ("cpu", "cuda"),
-)
-def test_cache_symbolic_values_dynamic_shape(device):
-    if device == "cuda" and not torch.cuda.is_available():
-        pytest.skip("CUDA not available")
-
+def test_cache_symbolic_values_dynamic_shape():
     def foo(a):
         return a.relu()
 
     jfoo = thunder.jit(foo, cache="symbolic values")
-    a = torch.randn((2, 2, 2), device=device)
+    a = torch.randn((2, 2, 2), device="cpu")
 
     actual = jfoo(a)
     expected = foo(a)
 
@@ -1443,7 +1422,7 @@ def foo(a):
     assert thunder.cache_misses(jfoo) == 1
     assert thunder.cache_hits(jfoo) == 0
 
-    a = torch.randn((3, 4, 5), device=device)
+    a = torch.randn((3, 4, 5), device="cpu")
     actual = jfoo(a)
     expected = foo(a)