Skip to content

Commit

Permalink
PR openxla#19342: [ROCm] Skip unsupported tests in dot_algorithms_test
Browse files Browse the repository at this point in the history
Imported from GitHub PR openxla#19342

Triton is currently disabled on ROCm. Skipping the following subtests in `dot_algorithms_test`:
- TritonAlgorithmTest.Algorithm_BF16_BF16_F32_X3
- TritonAlgorithmTest.Algorithm_BF16_BF16_F32_X6
- TritonAlgorithmTest.Algorithm_TF32_TF32_F32
- TritonAlgorithmTest.Algorithm_TF32_TF32_F32_X3
- TritonAlgorithmTest.Algorithm_BF16_BF16_F32
Copybara import of the project:

--
32bd775 by Milica Makevic <Milica.Makevic@amd.com>:

Disable unsupported Triton subtests

Merging this change closes openxla#19342

COPYBARA_INTEGRATE_REVIEW=openxla#19342 from ROCm:disable_triton_tests 32bd775
PiperOrigin-RevId: 696740956
  • Loading branch information
mmakevic-amd authored and hsharsha committed Jan 7, 2025
1 parent 1572448 commit cad9c6f
Showing 1 changed file with 35 additions and 0 deletions.
35 changes: 35 additions & 0 deletions xla/service/gpu/fusions/triton/dot_algorithms_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1146,6 +1146,9 @@ CHECK-NOT: mma.sync.aligned.{{.*}}.row.col.f32.tf32.tf32.f32
}

TEST_F(TritonAlgorithmTest, Algorithm_BF16_BF16_F32_X3) {
if (std::holds_alternative<se::RocmComputeCapability>(GpuComputeComp())) {
GTEST_SKIP() << "Triton currently disabled on ROCM.";
}
const std::string kHloText = R"(
HloModule Algorithm_BF16_BF16_F32_X3
Expand All @@ -1166,6 +1169,9 @@ TEST_F(TritonAlgorithmTest, Algorithm_BF16_BF16_F32_X3) {
}

TEST_F(TritonAlgorithmTest, Algorithm_BF16_BF16_F32_X6) {
if (std::holds_alternative<se::RocmComputeCapability>(GpuComputeComp())) {
GTEST_SKIP() << "Triton currently disabled on ROCM.";
}
const std::string kHloText = R"(
HloModule Algorithm_BF16_BF16_F32_X6
Expand All @@ -1185,7 +1191,33 @@ TEST_F(TritonAlgorithmTest, Algorithm_BF16_BF16_F32_X6) {
EXPECT_TRUE(ok);
}

TEST_F(TritonAlgorithmTest, Algorithm_TF32_TF32_F32) {
if (std::holds_alternative<se::RocmComputeCapability>(GpuComputeComp())) {
GTEST_SKIP() << "Triton currently disabled on ROCM.";
}
const std::string kHloText = R"(
HloModule Algorithm_TF32_TF32_F32
ENTRY main {
lhs = f32[128,1]{1,0} parameter(0)
rhs = f32[1,128]{1,0} parameter(1)
ROOT dot = f32[128,128]{1,0} dot(lhs, rhs),
algorithm=dot_tf32_tf32_f32,
lhs_contracting_dims={1},
rhs_contracting_dims={0}
}
)";
const std::string pattern =
R"(CHECK: "kind":"__triton_gemm","triton_gemm_config")";
TF_ASSERT_OK_AND_ASSIGN(auto module, GetOptimizedModule(kHloText));
TF_ASSERT_OK_AND_ASSIGN(auto ok, RunFileCheck(module->ToString(), pattern));
EXPECT_TRUE(ok);
}

TEST_F(TritonAlgorithmTest, Algorithm_TF32_TF32_F32_X3) {
if (std::holds_alternative<se::RocmComputeCapability>(GpuComputeComp())) {
GTEST_SKIP() << "Triton currently disabled on ROCM.";
}
const std::string kHloText = R"(
HloModule Algorithm_TF32_TF32_F32_X3
Expand All @@ -1209,6 +1241,9 @@ TEST_F(TritonAlgorithmTest, Algorithm_BF16_BF16_F32) {
if (!SupportsBF16(GpuComputeComp())) {
GTEST_SKIP() << "BF16 not supported.";
}
if (std::holds_alternative<se::RocmComputeCapability>(GpuComputeComp())) {
GTEST_SKIP() << "Triton currently disabled on ROCM.";
}
const std::string kHloText = R"(
HloModule Algorithm_BF16_BF16_F32
Expand Down

0 comments on commit cad9c6f

Please sign in to comment.