Merge pull request #81 from ROCm/rocm-jaxlib-v0.4.35-qa-triton-autotuner
Enable Triton Auto-tuning in XLA
i-chaochen authored Jan 8, 2025
2 parents e9826d8 + 2bd6f7e commit f55fc88
Showing 14 changed files with 400 additions and 141 deletions.
11 changes: 11 additions & 0 deletions third_party/triton/temporary/fix_InsertInstructionSchedHints.patch
@@ -0,0 +1,11 @@
--- a/third_party/amd/include/TritonAMDGPUToLLVM/Passes.td
+++ b/third_party/amd/include/TritonAMDGPUToLLVM/Passes.td
@@ -59,7 +59,7 @@
let summary = "Insert instruction scheduling hints after the dot ops in the main loop";
let constructor = "mlir::triton::createInsertInstructionSchedHintsPass()";

- let dependentDialects = ["mlir::LLVM::LLVMDialect"];
+ let dependentDialects = ["mlir::LLVM::LLVMDialect", "mlir::triton::amdgpu::TritonAMDGPUDialect"];
}

def LowerInstructionSchedHints : Pass<"lower-insert-instruction-sched-hints", "mlir::ModuleOp"> {
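For background: an MLIR pass must declare every dialect whose operations it may create, so those dialects are loaded before the pass runs; otherwise op construction fails at pass-execution time. The TableGen `dependentDialects` field patched above expands to roughly the following C++. This is a hand-written sketch of that mechanism, not the actual generated code, and it assumes (from the patch) that the pass now also creates TritonAMDGPU scheduling-hint ops:

#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/Pass/Pass.h"

// Sketch of what the TableGen `dependentDialects` list amounts to.
struct InsertInstructionSchedHintsPassSketch
    : mlir::PassWrapper<InsertInstructionSchedHintsPassSketch,
                        mlir::OperationPass<mlir::ModuleOp>> {
  void getDependentDialects(mlir::DialectRegistry &registry) const override {
    // Dialects whose ops this pass creates must be pre-registered.
    registry.insert<mlir::LLVM::LLVMDialect>();
    // The patch appends mlir::triton::amdgpu::TritonAMDGPUDialect here.
  }
  void runOnOperation() override { /* insert sched hints after dot ops */ }
};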
1 change: 1 addition & 0 deletions third_party/triton/temporary/series.bzl
@@ -17,5 +17,6 @@ temporary_patch_list = [
"//third_party/triton:temporary/fix_left_shift_overflow.patch",
"//third_party/triton:temporary/prefetch.patch",
"//third_party/triton:temporary/i4_to_bf16.patch",
"//third_party/triton:temporary/fix_InsertInstructionSchedHints.patch",
# Add new patches just above this line
]
2 changes: 2 additions & 0 deletions xla/service/gpu/BUILD
@@ -2057,13 +2057,15 @@ cc_library(
"//xla/hlo/transforms:hlo_constant_folding",
"//xla/hlo/transforms:reshape_mover",
"//xla/hlo/transforms:tuple_simplifier",
"//xla/pjrt/distributed:key_value_store_interface",
"//xla/service:call_inliner",
"//xla/service:float_support",
"//xla/service:hlo_module_config",
"//xla/service:hlo_verifier",
"//xla/service/gpu/autotuning:autotuner_util",
"//xla/service/gpu/autotuning:conv_algorithm_picker",
"//xla/service/gpu/autotuning:gemm_algorithm_picker",
"//xla/service/gpu/autotuning:gemm_fusion_autotuner",
"//xla/service/gpu/llvm_gpu_backend",
"//xla/service/gpu/transforms:algebraic_simplifier",
"//xla/service/gpu/transforms:conv_padding_legalization",
12 changes: 12 additions & 0 deletions xla/service/gpu/amdgpu_compiler.cc
@@ -35,11 +35,13 @@ limitations under the License.
#include "xla/hlo/transforms/simplifiers/hlo_constant_folding.h"
#include "xla/hlo/transforms/simplifiers/reshape_mover.h"
#include "xla/hlo/transforms/simplifiers/tuple_simplifier.h"
#include "xla/pjrt/distributed/key_value_store_interface.h"
#include "xla/service/call_inliner.h"
#include "xla/service/float_support.h"
#include "xla/service/gpu/autotuning/autotuner_util.h"
#include "xla/service/gpu/autotuning/conv_algorithm_picker.h"
#include "xla/service/gpu/autotuning/gemm_algorithm_picker.h"
#include "xla/service/gpu/autotuning/gemm_fusion_autotuner.h"
#include "xla/service/gpu/cublas_padding_requirements.h"
#include "xla/service/gpu/gpu_compiler.h"
#include "xla/service/gpu/llvm_gpu_backend/gpu_backend_lib.h"
@@ -268,5 +270,15 @@ AMDGPUCompiler::CompileTargetBinary(const HloModuleConfig& module_config,
return BackendCompileResult{"", std::move(hsaco)};
}

absl::Status AMDGPUCompiler::AddGemmFusionAutotuningPasses(
HloPassPipeline* pipeline, HloModule* hlo_module,
AutotuneConfig& autotune_config, tsl::thread::ThreadPool* thread_pool,
const MultiProcessKeyValueStore& key_value_store,
const se::SemanticVersion& toolkit_version) {
pipeline->AddPass<GemmFusionAutotuner>(autotune_config, toolkit_version,
thread_pool, key_value_store);
return absl::OkStatus();
}

} // namespace gpu
} // namespace xla
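For context, this new method fills in a backend hook on the shared GPU compiler: the generic pipeline builder calls the virtual AddGemmFusionAutotuningPasses and each backend contributes its own autotuning pass. A minimal, self-contained sketch of that pattern follows; every type here is a simplified stand-in, not the real XLA class:

#include <iostream>
#include <memory>
#include <utility>
#include <vector>

// Stand-ins for HloPassInterface / HloPassPipeline.
struct PassSketch {
  virtual ~PassSketch() = default;
  virtual void Run() = 0;
};

struct PipelineSketch {
  template <typename T, typename... Args>
  void AddPass(Args&&... args) {
    passes.push_back(std::make_unique<T>(std::forward<Args>(args)...));
  }
  std::vector<std::unique_ptr<PassSketch>> passes;
};

struct GemmFusionAutotunerSketch : PassSketch {
  void Run() override { std::cout << "autotune GEMM fusions\n"; }
};

// The generic compiler drives the pipeline; backends override one hook.
struct GpuCompilerSketch {
  virtual ~GpuCompilerSketch() = default;
  virtual void AddGemmFusionAutotuningPasses(PipelineSketch& pipeline) = 0;
  void RunHloPasses() {
    PipelineSketch pipeline;
    AddGemmFusionAutotuningPasses(pipeline);  // backend-specific step
    for (auto& pass : pipeline.passes) pass->Run();
  }
};

struct AmdGpuCompilerSketch : GpuCompilerSketch {
  void AddGemmFusionAutotuningPasses(PipelineSketch& pipeline) override {
    pipeline.AddPass<GemmFusionAutotunerSketch>();
  }
};

int main() {
  AmdGpuCompilerSketch compiler;
  compiler.RunHloPasses();  // prints: autotune GEMM fusions
}

The matching override declaration lands in amdgpu_compiler.h just below.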
6 changes: 6 additions & 0 deletions xla/service/gpu/amdgpu_compiler.h
@@ -68,6 +68,12 @@ class AMDGPUCompiler : public GpuCompiler {
se::GpuComputeCapability gpu_version, bool relocatable,
const HloModule* debug_module, const CompileOptions& options) override;

absl::Status AddGemmFusionAutotuningPasses(
HloPassPipeline* pipeline, HloModule* hlo_module,
AutotuneConfig& autotune_config, tsl::thread::ThreadPool* thread_pool,
const MultiProcessKeyValueStore& key_value_store,
const se::SemanticVersion& toolkit_version) override;

private:
AMDGPUCompiler(const AMDGPUCompiler&) = delete;
AMDGPUCompiler& operator=(const AMDGPUCompiler&) = delete;
88 changes: 83 additions & 5 deletions xla/service/gpu/autotuning/BUILD
@@ -9,6 +9,10 @@ load(
"@tsl//tsl/platform/default:cuda_build_defs.bzl",
"if_cuda_is_configured",
)
load(
"@local_config_rocm//rocm:build_defs.bzl",
"if_rocm_is_configured",
)
load("//xla:xla.bzl", "xla_cc_test")
load("//xla/tests:build_defs.bzl", "xla_test")

@@ -26,14 +30,89 @@ package_group(
)

cc_library(
name = "gemm_fusion_autotuner",
srcs = ["gemm_fusion_autotuner.cc"],
hdrs = ["gemm_fusion_autotuner.h"],
name = "gemm_fusion_autotuner_cuda",
srcs = [
"gemm_fusion_autotuner.h",
"gemm_fusion_autotuner_cuda.cc",
],
tags = [
"cuda-only",
"gpu",
],
deps = [
":autotuner_compile_util",
":autotuner_util",
"//xla:autotuning_proto_cc",
"//xla:xla_proto_cc",
"//xla/hlo/ir:hlo",
"//xla/hlo/pass:hlo_pass",
"//xla/pjrt/distributed:key_value_store_interface",
"//xla/service:algorithm_util",
"//xla/service:executable",
"//xla/service:shaped_buffer",
"//xla/service/gpu:ir_emission_utils",
"//xla/service/gpu:matmul_utils",
"//xla/service/gpu:stream_executor_util",
"//xla/service/gpu/transforms:cudnn_fusion_compiler",
"//xla/stream_executor:device_description",
"//xla/stream_executor:semantic_version",
"@com_google_absl//absl/algorithm:container",
"@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/container:flat_hash_set",
"@com_google_absl//absl/status",
"@com_google_absl//absl/status:statusor",
"@com_google_absl//absl/strings:string_view",
"@com_google_absl//absl/types:span",
"@local_config_cuda//cuda:cuda_headers",
"@tsl//tsl/platform:env",
],
)

cc_library(
name = "gemm_fusion_autotuner_rocm",
srcs = [
"gemm_fusion_autotuner.h",
"gemm_fusion_autotuner_rocm.cc",
],
tags = [
"gpu",
"rocm-only",
],
deps = [
":autotuner_compile_util",
":autotuner_util",
"//xla:autotuning_proto_cc",
"//xla:xla_proto_cc",
"//xla/hlo/ir:hlo",
"//xla/hlo/pass:hlo_pass",
"//xla/pjrt/distributed:key_value_store_interface",
"//xla/service:executable",
"//xla/service:shaped_buffer",
"//xla/service/gpu:matmul_utils",
"//xla/stream_executor:device_description",
"//xla/stream_executor:semantic_version",
"//xla/stream_executor/rocm:rocblas_plugin",
"@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/container:flat_hash_set",
"@com_google_absl//absl/status",
"@com_google_absl//absl/status:statusor",
"@com_google_absl//absl/strings:string_view",
"@com_google_absl//absl/types:span",
"@local_config_rocm//rocm:rocm_headers",
"@tsl//tsl/platform:env",
],
)

cc_library(
name = "gemm_fusion_autotuner",
srcs = [
"gemm_fusion_autotuner.cc",
],
hdrs = ["gemm_fusion_autotuner.h"],
tags = ["gpu"],
deps = if_cuda_is_configured([":gemm_fusion_autotuner_cuda"]) + if_rocm_is_configured([
":gemm_fusion_autotuner_rocm",
]) + [
":autotuner_compile_util",
":autotuner_util",
"//xla:autotune_results_proto_cc",
@@ -69,7 +148,6 @@ cc_library(
"//xla/service/gpu/kernels:custom_kernel",
"//xla/service/gpu/kernels:custom_kernel_fusion",
"//xla/service/gpu/kernels:custom_kernel_fusion_pattern",
"//xla/service/gpu/transforms:cudnn_fusion_compiler",
"//xla/service/gpu/transforms:custom_kernel_fusion_rewriter",
"//xla/service/gpu/transforms:dot_algorithm_rewriter",
"//xla/service/gpu/transforms:fusion_wrapper",
@@ -93,10 +171,10 @@ cc_library(
"@com_google_absl//absl/status:statusor",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/strings:str_format",
"@com_google_absl//absl/strings:string_view",
"@com_google_absl//absl/synchronization",
"@com_google_absl//absl/time",
"@com_google_absl//absl/types:span",
"@local_config_cuda//cuda:cuda_headers",
"@tsl//tsl/platform:blocking_counter",
"@tsl//tsl/platform:env",
"@tsl//tsl/platform:errors",
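The BUILD restructuring above is the standard Bazel recipe for a backend split: one shared header, two backend-only source libraries (tagged cuda-only / rocm-only), and a facade target whose deps pull in exactly one of them via if_cuda_is_configured / if_rocm_is_configured. In C++ terms the linker simply picks whichever definition was built. A rough sketch of that link-time dispatch (names simplified; the real AddLibConfigs takes the fusion, dot, and config-list arguments seen in the .cc diff below):

// gemm_fusion_autotuner.h (shared): declared once, defined per backend.
class GemmFusionAutotunerImplSketch {
 public:
  // Returns true when library (e.g. cuDNN) configs were added and config
  // generation should stop early.
  bool AddLibConfigs();
};

// gemm_fusion_autotuner_cuda.cc -- built only under if_cuda_is_configured():
// bool GemmFusionAutotunerImplSketch::AddLibConfigs() {
//   /* enumerate cuDNN fusion plans, push one CuDnnConfig per plan id */
//   return true;
// }

// gemm_fusion_autotuner_rocm.cc -- built only under if_rocm_is_configured():
bool GemmFusionAutotunerImplSketch::AddLibConfigs() {
  return false;  // no cuDNN on ROCm; fall through to Triton/custom kernels
}

Because exactly one backend library is ever linked, the one-definition rule holds, and CUDA-specific deps such as cudnn_fusion_compiler and cuda_headers can move out of the shared target into the CUDA-only one, as the diff shows.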
107 changes: 22 additions & 85 deletions xla/service/gpu/autotuning/gemm_fusion_autotuner.cc
@@ -40,7 +40,6 @@ limitations under the License.
#include "absl/synchronization/mutex.h"
#include "absl/time/time.h"
#include "absl/types/span.h"
#include "third_party/gpus/cuda/include/cublas_v2.h"
#include "xla/autotune_results.pb.h"
#include "xla/autotuning.pb.h"
#include "xla/hlo/ir/dfs_hlo_visitor_with_default.h"
@@ -74,7 +73,6 @@ limitations under the License.
#include "xla/service/gpu/matmul_utils.h"
#include "xla/service/gpu/split_k_gemm_rewriter.h"
#include "xla/service/gpu/stream_executor_util.h"
#include "xla/service/gpu/transforms/cudnn_fusion_compiler.h"
#include "xla/service/gpu/transforms/custom_kernel_fusion_rewriter.h"
#include "xla/service/gpu/transforms/dot_algorithm_rewriter.h"
#include "xla/service/gpu/transforms/fusion_wrapper.h"
@@ -424,29 +422,11 @@ absl::StatusOr<std::unique_ptr<HloModule>> CuDnnFusionExtractor(
return module;
}

bool IsFusionKind(const HloInstruction& hlo, absl::string_view kind) {
auto gpu_config = hlo.backend_config<GpuBackendConfig>();
if (!gpu_config.ok()) {
return false;
}
return gpu_config->fusion_backend_config().kind() == kind;
}

int GetCuDnnPlanCount(const HloInstruction& hlo,
const AutotuneConfig& autotune_config) {
if (auto gpu_config = hlo.backend_config<GpuBackendConfig>();
!gpu_config.ok() ||
gpu_config->fusion_backend_config().has_cudnn_fusion_config()) {
return {};
}
return CuDnnFusionCompiler::GetAvailablePlanCount(
*autotune_config.GetExecutor(), *DynCast<HloFusionInstruction>(&hlo));
}

AutotuneResult FromConfig(const BackendConfig& config) {
AutotuneResult res;
if (std::holds_alternative<GemmFusionAutotunerImpl::CuBlasConfig>(config)) {
res.mutable_gemm()->set_algorithm(CUBLAS_GEMM_DEFAULT);
res.mutable_gemm()->set_algorithm(
GemmFusionAutotunerImpl::BLAS_GEMM_DEFAULT);
} else if (std::holds_alternative<
GemmFusionAutotunerImpl::CustomKernelFusionConfig>(config)) {
res.mutable_custom_kernel_fusion()->set_kernel_index(
@@ -674,6 +654,15 @@ absl::Status GemmFusionAutotunerRewriterVisitor::HandleFusion(
return absl::OkStatus();
}

bool GemmFusionAutotunerImpl::IsFusionKind(const HloInstruction& hlo,
absl::string_view kind) {
auto gpu_config = hlo.backend_config<GpuBackendConfig>();
if (!gpu_config.ok()) {
return false;
}
return gpu_config->fusion_backend_config().kind() == kind;
}

// Methods required for sorting the configs.
bool GemmFusionAutotunerImpl::CuBlasConfig::operator<(
const CuBlasConfig& other) const {
@@ -788,28 +777,8 @@ GemmFusionAutotunerImpl::GenerateConfigs(const HloFusionInstruction& fusion) {
configs.push_back(CuBlasConfig{});
}

// Add cuDNN plans, if available.
bool is_hopper =
!config_.IsDeviceless() && GetComputeCapability().IsAtLeastHopper();
bool is_cudnn_enabled =
debug_options_.xla_gpu_cudnn_gemm_fusion_level() > 0 && is_hopper &&
GetDnnVersionInfoOrDefault(config_.GetExecutor()).major_version() >= 9;
if ((IsFusionKind(fusion, kCuDnnFusionKind) && IsAutotuningEnabled()) ||
(IsFusionKind(fusion, kTritonGemmFusionKind) && is_cudnn_enabled &&
algorithm_util::IsSupportedByCudnn(
dot->precision_config().algorithm()) &&
!dot->sparse_operands() && IsAutotuningEnabled())) {
const int plan_count = GetCuDnnPlanCount(fusion, config_);
for (int plan_id = 0; plan_id < plan_count; ++plan_id) {
configs.push_back(CuDnnConfig{plan_id});
}
}
if (IsFusionKind(fusion, kCuDnnFusionKind)) {
if (!IsAutotuningEnabled()) {
configs.push_back(CuDnnConfig{-1});
}
return configs;
}
// Add lib (e.g. cuDNN) plans, if available.
if (AddLibConfigs(fusion, dot, configs)) return configs;
}

// Add CustomKernelFusion (Cutlass) configs, if available.
@@ -885,8 +854,6 @@ GemmFusionAutotunerImpl::GenerateTritonConfigs(const HloDotInstruction& dot) {

// Triton configurations are adjusted and deduplicated.
absl::flat_hash_set<TritonGemmConfig> added;
bool is_hopper =
!config_.IsDeviceless() && GetComputeCapability().IsAtLeastHopper();
for (TritonGemmConfig& config : triton_configs) {
config.block_m = std::min(config.block_m, limits.block_m);
config.block_n = std::min(config.block_n, limits.block_n);
@@ -909,10 +876,8 @@ GemmFusionAutotunerImpl::GenerateTritonConfigs(const HloDotInstruction& dot) {
// Sparse meta should have at least one element per thread.
// Note: only 2:4 structured sparsity is currently supported.
if (dot.sparse_operands()) {
if (is_hopper) {
config.block_m = std::max(config.block_m, 64);
config.num_warps = std::max(config.num_warps, 4);
}
config.block_m = std::max(config.block_m, 64);
config.num_warps = std::max(config.num_warps, 4);
config.block_k = std::max(
config.block_k,
2 * std::max(kMinTileSize, kLdmatrixGranularity / minBitWidth));
@@ -1192,9 +1157,13 @@ absl::StatusOr<std::vector<AutotuneResult>> GemmFusionAutotunerImpl::Profile(
std::vector<TritonGemmConfig>
GemmFusionAutotunerImpl::GetExhaustiveTritonConfigs() const {
std::vector<TritonGemmConfig> configs;
se::CudaComputeCapability cc = GetComputeCapability();
bool tune_ctas =
debug_options_.xla_gpu_enable_triton_hopper() && cc.IsAtLeastHopper();
se::GpuComputeCapability gcc = GetComputeCapability();
bool tune_ctas = false;

if (!isRocm()) {
auto cc = std::get<se::CudaComputeCapability>(gcc);
tune_ctas = debug_options_.xla_gpu_enable_triton_hopper() && cc.IsAtLeastHopper();
}

for (int num_stages : kNumStages) {
for (int tile_m : kBlockSizes) {
@@ -1235,38 +1204,6 @@
return configs;
}

std::vector<TritonGemmConfig> GemmFusionAutotunerImpl::GetDefaultTritonConfigs()
const {
using Config = TritonGemmConfig;
std::vector<Config> configs = {
Config(32, 32, 256, 1, 1, 4), Config(64, 32, 32, 16, 1, 4),
Config(32, 64, 64, 4, 1, 4), Config(128, 128, 64, 4, 1, 4),
Config(16, 16, 256, 1, 1, 4), Config(16, 128, 32, 16, 1, 4),
Config(16, 64, 128, 1, 1, 4), Config(16, 128, 32, 8, 1, 4),
Config(16, 16, 512, 1, 1, 4), Config(32, 16, 512, 1, 1, 4),
Config(64, 32, 64, 1, 2, 8), Config(128, 256, 32, 1, 3, 8),
Config(256, 128, 32, 1, 3, 8), Config(256, 64, 32, 1, 4, 4),
Config(64, 256, 32, 1, 4, 4), Config(128, 64, 32, 1, 4, 4),
Config(64, 128, 32, 1, 4, 4), Config(256, 128, 128, 1, 3, 8),
Config(256, 64, 128, 1, 4, 4), Config(64, 256, 128, 1, 4, 4),
Config(128, 128, 128, 1, 4, 4), Config(128, 64, 64, 1, 4, 4),
Config(64, 128, 64, 1, 4, 4), Config(128, 32, 64, 1, 4, 4),
Config(64, 32, 64, 1, 4, 4), Config(32, 128, 32, 1, 4, 4),
Config(128, 128, 32, 1, 4, 4), Config(16, 16, 256, 1, 3, 4),
Config(128, 128, 64, 2, 1, 8), Config(64, 64, 64, 1, 2, 4),
Config(16, 64, 256, 8, 1, 4), Config(256, 256, 128, 1, 3, 8)};
if (GetComputeCapability().IsAtLeastHopper()) {
absl::c_copy(
std::vector<Config>{
Config(16, 32, 32, 8, 1, 2),
Config(16, 64, 128, 8, 1, 4),
Config(16, 64, 128, 16, 3, 4),
},
std::back_inserter(configs));
}
return configs;
}

absl::Status DumpAutotuningLogs(const DebugOptions& debug_opts,
const AutotuningLogs& autotuning_logs) {
if (absl::string_view file_path = debug_opts.xla_gpu_dump_autotune_logs_to();
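The GetExhaustiveTritonConfigs change above is the core portability fix in this file: GetComputeCapability now yields the se::GpuComputeCapability variant rather than a CUDA-only type, and Hopper-specific CTA tuning is gated on actually holding a CUDA capability. A small self-contained illustration of that dispatch, using stand-in types (not the real stream_executor classes) and std::get_if in place of the isRocm() guard used in the diff:

#include <iostream>
#include <variant>

// Stand-ins for se::CudaComputeCapability / se::RocmComputeCapability.
struct CudaComputeCapabilitySketch {
  int major;
  bool IsAtLeastHopper() const { return major >= 9; }
};
struct RocmComputeCapabilitySketch {
  const char* gfx_version;  // e.g. "gfx942"
};
using GpuComputeCapabilitySketch =
    std::variant<CudaComputeCapabilitySketch, RocmComputeCapabilitySketch>;

// CTA tuning only applies to Hopper-or-newer CUDA parts; any ROCm device
// leaves it disabled, mirroring the tune_ctas logic above.
bool TuneCtas(const GpuComputeCapabilitySketch& gcc, bool hopper_flag) {
  bool tune_ctas = false;
  if (const auto* cc = std::get_if<CudaComputeCapabilitySketch>(&gcc)) {
    tune_ctas = hopper_flag && cc->IsAtLeastHopper();
  }
  return tune_ctas;
}

int main() {
  std::cout << TuneCtas(CudaComputeCapabilitySketch{9}, true)         // 1
            << TuneCtas(RocmComputeCapabilitySketch{"gfx942"}, true)  // 0
            << '\n';
}

The removal of GetDefaultTritonConfigs from this file is consistent with the same split: the default config lists differ per backend, so they move next to the backend-specific sources rather than living in the shared implementation.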