Respect dot algorithm spec on TPU backends.
PiperOrigin-RevId: 686286362
pravnar authored and Google-ML-Automation committed Oct 18, 2024
1 parent bbcc3ee commit fa296d8
Showing 1 changed file with 21 additions and 2 deletions.
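
For context, JAX exposes dot algorithm presets through the precision argument of lax.dot, which the tests below use both in string form ("F32_F32_F32") and as lax.DotAlgorithmPreset values. A minimal sketch of requesting one of the TPU-supported presets on float32 inputs, assuming a JAX version that accepts DotAlgorithmPreset as a precision; shapes and values are illustrative, not taken from the commit:

    import jax.numpy as jnp
    from jax import lax

    lhs = jnp.ones((3, 4), dtype=jnp.float32)
    rhs = jnp.ones((4, 3), dtype=jnp.float32)

    # BF16_BF16_F32 is one of the presets the updated test allows on TPU.
    out = lax.dot(lhs, rhs, precision=lax.DotAlgorithmPreset.BF16_BF16_F32)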
tests/lax_test.py (21 additions, 2 deletions):
@@ -1112,6 +1112,21 @@ def testDotAlgorithm(self, algorithm, dtype):
       }:
         raise SkipTest(
             f"The dot algorithm '{algorithm}' is not supported on GPU.")
+    if jtu.test_device_matches(["tpu"]):
+      if algorithm not in {
+          lax.DotAlgorithmPreset.DEFAULT,
+          lax.DotAlgorithmPreset.BF16_BF16_F32,
+          lax.DotAlgorithmPreset.BF16_BF16_F32_X3,
+          lax.DotAlgorithmPreset.BF16_BF16_F32_X6,
+      }:
+        raise SkipTest(
+            f"The dot algorithm '{algorithm}' is not supported on TPU."
+        )
+      if algorithm != lax.DotAlgorithmPreset.DEFAULT and dtype != np.float32:
+        raise SkipTest(
+            f"The dot algorithm '{algorithm}' is only supported for float32 on"
+            " TPU."
+        )
     lhs_shape = (3, 4)
     rhs_shape = (4, 3)
     rng = jtu.rand_default(self.rng())
@@ -1136,6 +1151,8 @@ def testDotAlgorithmCasting(self):
     if xla_bridge.using_pjrt_c_api():
       raise SkipTest(
           "The dot algorithm attribute is not supported by PJRT C API.")
+    if jtu.test_device_matches(["tpu"]):
+      raise SkipTest("F32_F32_F32 is not supported on TPU.")
     def fun(lhs, rhs):
       return lax.dot(lhs, rhs, precision="F32_F32_F32")
     lhs_shape = (3, 4)
@@ -1188,12 +1205,14 @@ def testDotPreferredElement(self, lhs_shape, rhs_shape, dtype,
   )
   def test_mixed_fp8_dot_general(self, lhs_shape, rhs_shape, dtype_lhs, dtype_rhs):
     if jtu.test_device_matches(["tpu"]):
-      raise SkipTest("Mixed fp8 precision matmul is not yet supported on TPU")
+      raise SkipTest("Mixed fp8 precision matmul is not yet supported on TPU")
     if not jtu.is_device_rocm() and (
         dtype_lhs in [dtypes.float8_e4m3fnuz, dtypes.float8_e5m2fnuz] or
         dtype_rhs in [dtypes.float8_e4m3fnuz, dtypes.float8_e5m2fnuz]
     ):
-      raise SkipTest("float8_e4m3fnuz and float8_e5m2fnuz types are only supported on ROCm")
+      raise SkipTest(
+          "float8_e4m3fnuz and float8_e5m2fnuz types are only supported on ROCm"
+      )
     rng = jtu.rand_default(self.rng())
     lhs = rng(lhs_shape, dtype=dtype_lhs)
     rhs = rng(rhs_shape, dtype=dtype_rhs)
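
The mixed-fp8 test exercises lax.dot_general with two different float8 operand dtypes. A minimal sketch of that call pattern on a backend that supports it (the test skips TPU); the dtypes, shapes, and preferred_element_type choice here are illustrative assumptions:

    import jax.numpy as jnp
    from jax import lax

    lhs = jnp.ones((8, 16), dtype=jnp.float8_e4m3fn)
    rhs = jnp.ones((16, 8), dtype=jnp.float8_e5m2)

    # Contract lhs dim 1 with rhs dim 0 (a plain matmul), accumulating in f32.
    out = lax.dot_general(
        lhs,
        rhs,
        dimension_numbers=(((1,), (0,)), ((), ())),
        preferred_element_type=jnp.float32,
    )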
