From cad9c6fd8cac9119e01b4620659093e2e93e5db0 Mon Sep 17 00:00:00 2001 From: mmakevic-amd Date: Thu, 14 Nov 2024 20:07:48 -0800 Subject: [PATCH] PR #19342: [ROCm] Skip unsupported tests in dot_algorithms_test Imported from GitHub PR https://github.com/openxla/xla/pull/19342 Triton is currently disabled on ROCm. Skipping the following subtests in `dot_algorithms_test`: - TritonAlgorithmTest.Algorithm_BF16_BF16_F32_X3 - TritonAlgorithmTest.Algorithm_BF16_BF16_F32_X6 - TritonAlgorithmTest.Algorithm_TF32_TF32_F32 - TritonAlgorithmTest.Algorithm_TF32_TF32_F32_X3 - TritonAlgorithmTest.Algorithm_BF16_BF16_F32 Copybara import of the project: -- 32bd775f87e142bcb194dcc8cc9807c864c995da by Milica Makevic : Disable unsupported Triton subtests Merging this change closes #19342 COPYBARA_INTEGRATE_REVIEW=https://github.com/openxla/xla/pull/19342 from ROCm:disable_triton_tests 32bd775f87e142bcb194dcc8cc9807c864c995da PiperOrigin-RevId: 696740956 --- .../gpu/fusions/triton/dot_algorithms_test.cc | 35 +++++++++++++++++++ 1 file changed, 35 insertions(+) diff --git a/xla/service/gpu/fusions/triton/dot_algorithms_test.cc b/xla/service/gpu/fusions/triton/dot_algorithms_test.cc index bcebfb9776248..c111fb388d821 100644 --- a/xla/service/gpu/fusions/triton/dot_algorithms_test.cc +++ b/xla/service/gpu/fusions/triton/dot_algorithms_test.cc @@ -1146,6 +1146,9 @@ CHECK-NOT: mma.sync.aligned.{{.*}}.row.col.f32.tf32.tf32.f32 } TEST_F(TritonAlgorithmTest, Algorithm_BF16_BF16_F32_X3) { + if (std::holds_alternative(GpuComputeComp())) { + GTEST_SKIP() << "Triton currently disabled on ROCM."; + } const std::string kHloText = R"( HloModule Algorithm_BF16_BF16_F32_X3 @@ -1166,6 +1169,9 @@ TEST_F(TritonAlgorithmTest, Algorithm_BF16_BF16_F32_X3) { } TEST_F(TritonAlgorithmTest, Algorithm_BF16_BF16_F32_X6) { + if (std::holds_alternative(GpuComputeComp())) { + GTEST_SKIP() << "Triton currently disabled on ROCM."; + } const std::string kHloText = R"( HloModule Algorithm_BF16_BF16_F32_X6 @@ -1185,7 +1191,33 @@ TEST_F(TritonAlgorithmTest, Algorithm_BF16_BF16_F32_X6) { EXPECT_TRUE(ok); } +TEST_F(TritonAlgorithmTest, Algorithm_TF32_TF32_F32) { + if (std::holds_alternative(GpuComputeComp())) { + GTEST_SKIP() << "Triton currently disabled on ROCM."; + } + const std::string kHloText = R"( + HloModule Algorithm_TF32_TF32_F32 + + ENTRY main { + lhs = f32[128,1]{1,0} parameter(0) + rhs = f32[1,128]{1,0} parameter(1) + ROOT dot = f32[128,128]{1,0} dot(lhs, rhs), + algorithm=dot_tf32_tf32_f32, + lhs_contracting_dims={1}, + rhs_contracting_dims={0} + } + )"; + const std::string pattern = + R"(CHECK: "kind":"__triton_gemm","triton_gemm_config")"; + TF_ASSERT_OK_AND_ASSIGN(auto module, GetOptimizedModule(kHloText)); + TF_ASSERT_OK_AND_ASSIGN(auto ok, RunFileCheck(module->ToString(), pattern)); + EXPECT_TRUE(ok); +} + TEST_F(TritonAlgorithmTest, Algorithm_TF32_TF32_F32_X3) { + if (std::holds_alternative(GpuComputeComp())) { + GTEST_SKIP() << "Triton currently disabled on ROCM."; + } const std::string kHloText = R"( HloModule Algorithm_TF32_TF32_F32_X3 @@ -1209,6 +1241,9 @@ TEST_F(TritonAlgorithmTest, Algorithm_BF16_BF16_F32) { if (!SupportsBF16(GpuComputeComp())) { GTEST_SKIP() << "BF16 not supported."; } + if (std::holds_alternative(GpuComputeComp())) { + GTEST_SKIP() << "Triton currently disabled on ROCM."; + } const std::string kHloText = R"( HloModule Algorithm_BF16_BF16_F32