Skip to content

Commit

Permalink
removing test for symbolic values cache running on different devices
Browse files Browse the repository at this point in the history
  • Loading branch information
jjsjann123 committed Oct 21, 2024
1 parent 79e59d0 commit 048f94a
Showing 1 changed file with 6 additions and 27 deletions.
33 changes: 6 additions & 27 deletions thunder/tests/test_jit_general.py
Original file line number Diff line number Diff line change
Expand Up @@ -836,20 +836,13 @@ def test_tom_overrides_proxy(device):
assert v is params_actual[k]


@pytest.mark.parametrize(
"device",
("cpu", "cuda"),
)
def test_cache_symbolic_values_basic(device):
if device == "cuda" and not torch.cuda.is_available():
pytest.skip("CUDA not available")

def test_cache_symbolic_values_basic():
def foo(a, scalar):
return (a * scalar).sum(scalar)

jfoo = thunder.jit(foo, cache="symbolic values")

a = torch.randn((2, 2, 2), device=device)
a = torch.randn((2, 2, 2), device="cpu")
b = 1

actual = jfoo(a, b)
Expand Down Expand Up @@ -1118,15 +1111,8 @@ def forward(self, x):
torch.testing.assert_close(actual, expected)


@pytest.mark.parametrize(
"device",
("cpu", "cuda"),
)
def test_cache_symbolic_values_reshape(device):
if device == "cuda" and not torch.cuda.is_available():
pytest.skip("CUDA not available")

a = torch.randn((4, 8, 6), device=device)
def test_cache_symbolic_values_reshape():
a = torch.randn((4, 8, 6), device="cpu")

def foo(t, batch_size):
return t.reshape(batch_size, -1).sum(-1)
Expand Down Expand Up @@ -1421,20 +1407,13 @@ def fn(a0, a1, a2, a3, a4, a5, a6, a7, a8, a9, a10):
assert [a.name for a in thunder.last_traces(fn)[-1].args] == [f"a{i}" for i in range(11)]


@pytest.mark.parametrize(
"device",
("cpu", "cuda"),
)
def test_cache_symbolic_values_dynamic_shape(device):
if device == "cuda" and not torch.cuda.is_available():
pytest.skip("CUDA not available")

def foo(a):
return a.relu()

jfoo = thunder.jit(foo, cache="symbolic values")

a = torch.randn((2, 2, 2), device=device)
a = torch.randn((2, 2, 2), device="cpu")

actual = jfoo(a)
expected = foo(a)
Expand All @@ -1443,7 +1422,7 @@ def foo(a):
assert thunder.cache_misses(jfoo) == 1
assert thunder.cache_hits(jfoo) == 0

a = torch.randn((3, 4, 5), device=device)
a = torch.randn((3, 4, 5), device="cpu")

actual = jfoo(a)
expected = foo(a)
Expand Down

0 comments on commit 048f94a

Please sign in to comment.