Skip to content

Commit

Permalink
Move TPU ops test to ops_test.py
Browse files Browse the repository at this point in the history
Move the TPU ops test from `tpu_ops_test.py` to `ops_test.py`. The functions tested in this file are not TPU-specific operations, so we don't need a separate test file.

PiperOrigin-RevId: 656347969
  • Loading branch information
ayaka14732 authored and jax authors committed Jul 26, 2024
1 parent 2db99e0 commit bb160cf
Show file tree
Hide file tree
Showing 2 changed files with 50 additions and 84 deletions.
50 changes: 50 additions & 0 deletions tests/pallas/ops_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -1171,5 +1171,55 @@ class PallasPrimitivesInterpreterTest(PallasPrimitivesTest):
INTERPRET = True


class TpuOpsTest(PallasBaseTest):

def setUp(self):
if not jtu.test_device_matches(["tpu"]):
self.skipTest("Test requires TPU device.")

super().setUp()

@parameterized.parameters([-3.2, -1.0, -0.4, 0., 0.72, 1.0, 2.4])
def test_erf_inv(self, x):
@jax.jit
@functools.partial(
pl.pallas_call,
# TODO(ayx): add float64 support for `erf_inv`
out_shape=jax.ShapeDtypeStruct((4,), jnp.float32),
)
def kernel(x_ref, o_ref):
o_ref[...] = lax.erf_inv(x_ref[...])

x = jnp.full((4,), x)
out = kernel(x)
expected = lax.erf_inv(x)
np.testing.assert_array_equal(out, expected)

SIGN_PARAMS = [
(jnp.int32, (-3, 0, 5)),
(jnp.uint32, (0, 5)),
(jnp.float32, (-3.2, -0., 0., 5.1, jnp.nan, jnp.inf, -jnp.inf)),
]

@parameterized.named_parameters(
(f"{dtype.__name__}_{value}", dtype, value)
for dtype, values in SIGN_PARAMS
for value in values
)
def test_sign(self, dtype, value):
@jax.jit
@functools.partial(
pl.pallas_call,
out_shape=jax.ShapeDtypeStruct((4,), dtype),
)
def kernel(x_ref, o_ref):
o_ref[...] = jnp.sign(x_ref[...])

x = jnp.full((4,), value, dtype=dtype)
out = kernel(x)
expected = jnp.sign(x)
np.testing.assert_array_equal(out, expected)


if __name__ == "__main__":
absltest.main()
84 changes: 0 additions & 84 deletions tests/pallas/tpu_ops_test.py

This file was deleted.

0 comments on commit bb160cf

Please sign in to comment.