[XLA:GPU] Support the HLO rewrite for the dot BF16_BF16_F32_X6 algorithm.

Now that we have all the pieces of the puzzle for the X3 algorithm, we can easily add its equivalent for X6.

PiperOrigin-RevId: 688294267
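
For background, the X6 rewrite splits each f32 operand into three bf16-representable pieces (high, mid, low) and reassembles the product from six bf16 dots plus five f32 additions, the same way the X3 rewrite uses two pieces and three dots. A minimal scalar sketch of the idea (illustration only; the helper names below are hypothetical and not part of this change):

// Scalar f32 analogue of the BF16_BF16_F32_X6 decomposition. Truncation keeps
// the sign, the exponent and the top 7 mantissa bits, matching the bit mask
// used by Truncate() in dot_algorithm_rewriter.cc.
#include <cstdint>
#include <cstdio>
#include <cstring>

static float TruncateToBF16(float x) {
  uint32_t bits;
  std::memcpy(&bits, &x, sizeof(bits));
  bits &= 0xFFFF0000u;  // Drop the low 16 mantissa bits.
  std::memcpy(&x, &bits, sizeof(x));
  return x;
}

int main() {
  float x = 1.2345678f, y = 7.6543219f;
  // Split3x: high = truncate(x), mid = truncate(x - high),
  // low = truncate(x - high - mid).
  float xh = TruncateToBF16(x), xm = TruncateToBF16(x - xh),
        xl = TruncateToBF16(x - xh - xm);
  float yh = TruncateToBF16(y), ym = TruncateToBF16(y - yh),
        yl = TruncateToBF16(y - yh - ym);
  // Six partial products, accumulated from the smallest contribution to the
  // largest, in the same order as RewriteF32ToBF16X6.
  float x6 = ((((xm * ym + xh * yl) + xl * yh) + xh * ym) + xm * yh) + xh * yh;
  std::printf("plain f32 product: %.9f  x6 approximation: %.9f\n", x * y, x6);
  return 0;
}
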
loislo authored and hsharsha committed Jan 6, 2025
1 parent 790b249 commit 2c121c1
Showing 4 changed files with 160 additions and 14 deletions.
2 changes: 2 additions & 0 deletions xla/service/algorithm_util.cc
@@ -47,6 +47,7 @@ absl::StatusOr<se::blas::ComputationType> GetBlasComputationType(
case PrecisionConfig::ALG_DOT_ANY_F8_ANY_F8_F32_FAST_ACCUM:
case PrecisionConfig::ALG_DOT_F16_F16_F32:
case PrecisionConfig::ALG_DOT_BF16_BF16_F32_X3:
case PrecisionConfig::ALG_DOT_BF16_BF16_F32_X6:
case PrecisionConfig::ALG_DOT_F32_F32_F32:
case PrecisionConfig::ALG_DOT_TF32_TF32_F32_X3:
return se::blas::ComputationType::kF32;
@@ -113,6 +114,7 @@ bool IsSupportedByCublasOrCublasLt(
switch (algorithm) {
case PrecisionConfig::ALG_DOT_BF16_BF16_F32:
case PrecisionConfig::ALG_DOT_BF16_BF16_F32_X3:
case PrecisionConfig::ALG_DOT_BF16_BF16_F32_X6:
case PrecisionConfig::ALG_UNSET:
case PrecisionConfig::ALG_DOT_F16_F16_F32:
case PrecisionConfig::ALG_DOT_F32_F32_F32:
2 changes: 0 additions & 2 deletions xla/service/gpu/autotuning/gemm_fusion_autotuner.cc
@@ -334,8 +334,6 @@ absl::StatusOr<std::unique_ptr<HloModule>> CublasGemmAutotuneExtractor(
// don't use cuBlas in the end. This assumes that the substituting algorithm
// has results that are close enough for the check in this file.
if (dot->precision_config().algorithm() ==
PrecisionConfig::ALG_DOT_BF16_BF16_F32_X6 ||
dot->precision_config().algorithm() ==
PrecisionConfig::ALG_DOT_TF32_TF32_F32_X3) {
dot->mutable_precision_config()->set_algorithm(
PrecisionConfig::ALG_DOT_F32_F32_F32);
93 changes: 93 additions & 0 deletions xla/service/gpu/fusions/triton/dot_algorithms_test.cc
@@ -177,6 +177,22 @@ TEST_F(AlgorithmTest, Algorithm3xBF16) {
RunAndCompare(kHloText, ErrorSpec{/*aabs=*/0.001, /*arel=*/0.001}));
}

TEST_F(AlgorithmTest, Algorithm6xBF16) {
constexpr std::string_view kHloText = R"(
HloModule Algorithm6xBF16
ENTRY e {
p0 = f32[128,128] parameter(0)
p1 = f32[128,128] parameter(1)
ROOT dot = f32[128,128] dot(p0, p1),
lhs_contracting_dims={1}, rhs_contracting_dims={0},
algorithm=dot_bf16_bf16_f32_x6
}
)";
EXPECT_TRUE(
RunAndCompare(kHloText, ErrorSpec{/*aabs=*/0.001, /*arel=*/0.001}));
}

TEST_F(BlasAlgorithmTest, Algorithm_BF16_BF16_F32) {
// We check that the algorithm is propagated to the BLAS call.
// We also check that the kernel name matches the algorithm for Ampere.
@@ -299,6 +315,63 @@ TEST_F(BlasAlgorithmTest, Algorithm_BF16_BF16_F32_X3) {
}
}

TEST_F(BlasAlgorithmTest, Algorithm_BF16_BF16_F32_X6) {
if (!SupportsBF16(GpuComputeComp())) {
GTEST_SKIP() << "BF16 not supported.";
}
constexpr std::string_view kHloText = R"(
HloModule Algorithm_BF16_BF16_F32_X6
ENTRY main {
lhs = f32[8512,256]{1,0} parameter(0)
rhs = f32[256,8512]{1,0} parameter(1)
ROOT dot = f32[8512,8512]{1,0} dot(lhs, rhs),
algorithm=dot_bf16_bf16_f32_x6,
lhs_contracting_dims={1},
rhs_contracting_dims={0}
}
)";
  // The single dot is replaced with 6 dots.
const std::string pattern = R"(
CHECK-COUNT-6: custom_call_target="__cublas$gemm"
)";

TF_ASSERT_OK_AND_ASSIGN(auto module, GetOptimizedModule(kHloText));
TF_ASSERT_OK_AND_ASSIGN(auto ok, RunFileCheck(module->ToString(), pattern));
ASSERT_TRUE(ok);

auto tracer = KernelNameTracer::Create();
if (tracer == nullptr) {
GTEST_SKIP() << "KernelNameTracer is not implemented.";
}
tracer->start();
EXPECT_TRUE(Run(std::move(module), /*run_hlo_passes=*/false));
auto kernel_names = tracer->stop();

auto cc = GetCudaComputeCapability();
using CudaComputeCapabilities =
stream_executor::CudaComputeCapability::CudaComputeCapabilities;
switch (cc.major) {
case CudaComputeCapabilities::BLACKWELL:
GTEST_SKIP() << "CudaComputeCapabilities::BLACKWELL has the kernel name: "
<< kernel_names[0];
break;
case CudaComputeCapabilities::AMPERE:
ASSERT_EQ(kernel_names.size(), 1);
EXPECT_THAT(kernel_names[0], ::testing::Eq("loop_convert_fusion_1"));
break;
case CudaComputeCapabilities::HOPPER:
EXPECT_THAT(kernel_names,
::testing::UnorderedElementsAre(
::testing::HasSubstr("loop_convert_fusion"),
::testing::HasSubstr("gemm_bf16f32_bf16f32_f32_")));
break;
default:
GTEST_SKIP() << "Unsupported compute capability: " << cc.major
<< " has the kernel name: " << kernel_names[0];
}
}

TEST_F(BlasAlgorithmTest, Algorithm_TF32_TF32_F32_X3) {
// We check that the algorithm is propagated to the BLAS call.
// We also check that the kernel name matches the algorithm for Ampere.
@@ -1092,6 +1165,26 @@ TEST_F(TritonAlgorithmTest, Algorithm_BF16_BF16_F32_X3) {
EXPECT_TRUE(ok);
}

TEST_F(TritonAlgorithmTest, Algorithm_BF16_BF16_F32_X6) {
const std::string kHloText = R"(
HloModule Algorithm_BF16_BF16_F32_X6
ENTRY main {
lhs = f32[8512,64]{1,0} parameter(0)
rhs = f32[64,8512]{1,0} parameter(1)
ROOT dot = f32[8512,8512]{1,0} dot(lhs, rhs),
algorithm=dot_bf16_bf16_f32_x6,
lhs_contracting_dims={1},
rhs_contracting_dims={0}
}
)";
const std::string pattern =
R"(CHECK: "kind":"__triton_gemm","triton_gemm_config")";
TF_ASSERT_OK_AND_ASSIGN(auto module, GetOptimizedModule(kHloText));
TF_ASSERT_OK_AND_ASSIGN(auto ok, RunFileCheck(module->ToString(), pattern));
EXPECT_TRUE(ok);
}

TEST_F(TritonAlgorithmTest, Algorithm_TF32_TF32_F32_X3) {
const std::string kHloText = R"(
HloModule Algorithm_TF32_TF32_F32_X3
77 changes: 65 additions & 12 deletions xla/service/gpu/transforms/dot_algorithm_rewriter.cc
@@ -16,6 +16,7 @@ limitations under the License.
#include "xla/service/gpu/transforms/dot_algorithm_rewriter.h"

#include <cstdint>
#include <tuple>
#include <utility>

#include "absl/container/flat_hash_set.h"
@@ -54,22 +55,31 @@ HloInstruction* Truncate(HloInstruction* f32_param) {
HloInstruction::CreateBitcastConvert(f32_param->shape(), masked_u32));
}

HloInstruction* SubAndRoundToBF16(HloInstruction* instr, HloInstruction* high) {
HloInstruction* sub = instr->AddInstruction(HloInstruction::CreateBinary(
HloInstruction* Sub(HloInstruction* instr, HloInstruction* high) {
return instr->AddInstruction(HloInstruction::CreateBinary(
instr->shape(), HloOpcode::kSubtract, instr, high));
}

HloInstruction* RoundToBF16(HloInstruction* instr) {
Shape new_shape = instr->shape();
new_shape.set_element_type(PrimitiveType::BF16);
return sub->AddInstruction(HloInstruction::CreateConvert(new_shape, sub));
return instr->AddInstruction(HloInstruction::CreateConvert(new_shape, instr));
}

std::pair<HloInstruction*, HloInstruction*> Split(HloInstruction* f32_param) {
std::pair<HloInstruction*, HloInstruction*> Split2x(HloInstruction* f32_param) {
HloInstruction* high_f32 = Truncate(f32_param);
HloInstruction* low_bf16 = SubAndRoundToBF16(f32_param, high_f32);
Shape bf16_shape = high_f32->shape();
bf16_shape.set_element_type(PrimitiveType::BF16);
HloInstruction* high_bf16 = high_f32->AddInstruction(
HloInstruction::CreateConvert(bf16_shape, high_f32));
return std::make_pair(high_bf16, low_bf16);
HloInstruction* low_f32 = Sub(f32_param, high_f32);
return std::make_pair(RoundToBF16(high_f32), RoundToBF16(low_f32));
}

std::tuple<HloInstruction*, HloInstruction*, HloInstruction*> Split3x(
HloInstruction* f32_param) {
HloInstruction* high_f32_t = Truncate(f32_param);
HloInstruction* mid_f32 = Sub(f32_param, high_f32_t);
HloInstruction* mid_f32_t = Truncate(mid_f32);
HloInstruction* low_f32_t = Truncate(Sub(mid_f32, mid_f32_t));
return std::make_tuple(RoundToBF16(high_f32_t), RoundToBF16(mid_f32_t),
RoundToBF16(low_f32_t));
}

void RewriteF32ToBF16X3(HloInstruction* instr) {
@@ -80,8 +90,8 @@ void RewriteF32ToBF16X3(HloInstruction* instr) {
const Shape& shape = dot->shape();
const DotDimensionNumbers& dnums = dot->dot_dimension_numbers();

auto [lhs_high_bf16, lhs_low_bf16] = Split(dot->mutable_operand(0));
auto [rhs_high_bf16, rhs_low_bf16] = Split(dot->mutable_operand(1));
auto [lhs_high_bf16, lhs_low_bf16] = Split2x(dot->mutable_operand(0));
auto [rhs_high_bf16, rhs_low_bf16] = Split2x(dot->mutable_operand(1));

HloInstruction* high_dot =
computation->AddInstruction(HloInstruction::CreateDot(
Expand All @@ -102,6 +112,45 @@ void RewriteF32ToBF16X3(HloInstruction* instr) {
TF_CHECK_OK(dot->parent()->RemoveInstruction(dot));
}

void RewriteF32ToBF16X6(HloInstruction* instr) {
HloComputation* computation = instr->parent();
HloDotInstruction* original_dot = Cast<HloDotInstruction>(instr);
PrecisionConfig precision_config = original_dot->precision_config();
precision_config.clear_algorithm();
const Shape& shape = original_dot->shape();
const DotDimensionNumbers& dnums = original_dot->dot_dimension_numbers();
auto dot = [&](HloInstruction* lhs, HloInstruction* rhs) {
return computation->AddInstruction(
HloInstruction::CreateDot(shape, lhs, rhs, dnums, precision_config));
};
auto sum = [&](HloInstruction* lhs, HloInstruction* rhs) {
return computation->AddInstruction(
HloInstruction::CreateBinary(shape, HloOpcode::kAdd, lhs, rhs));
};

auto [lhs_high_bf16, lhs_mid_bf16, lhs_low_bf16] =
Split3x(original_dot->mutable_operand(0));
auto [rhs_high_bf16, rhs_mid_bf16, rhs_low_bf16] =
Split3x(original_dot->mutable_operand(1));

HloInstruction* middle_middle_dot = dot(lhs_mid_bf16, rhs_mid_bf16);
HloInstruction* high_low_dot = dot(lhs_high_bf16, rhs_low_bf16);
HloInstruction* low_high_dot = dot(lhs_low_bf16, rhs_high_bf16);
HloInstruction* high_middle_dot = dot(lhs_high_bf16, rhs_mid_bf16);
HloInstruction* middle_high_dot = dot(lhs_mid_bf16, rhs_high_bf16);
HloInstruction* high_high_dot = dot(lhs_high_bf16, rhs_high_bf16);

HloInstruction* result = nullptr;
result = sum(middle_middle_dot, high_low_dot);
result = sum(result, low_high_dot);
result = sum(result, high_middle_dot);
result = sum(result, middle_high_dot);
result = sum(result, high_high_dot);

TF_CHECK_OK(original_dot->ReplaceAllUsesWith(result));
TF_CHECK_OK(original_dot->parent()->RemoveInstruction(original_dot));
}

} // namespace

absl::StatusOr<bool> DotAlgorithmRewriter::Run(
@@ -122,6 +171,10 @@ absl::StatusOr<bool> DotAlgorithmRewriter::Run(
RewriteF32ToBF16X3(instruction);
changed = true;
break;
case PrecisionConfig::ALG_DOT_BF16_BF16_F32_X6:
RewriteF32ToBF16X6(instruction);
changed = true;
break;
default:
break;
}
