Skip to content

Commit

Permalink
PR #17259: Adding Strictness level to PGLE accuracy checker.
Browse files Browse the repository at this point in the history
Imported from GitHub PR #17259

Values for the new flag: `xla_gpu_pgle_accuracy_checker: {OFF, WARN, ERROR}`
Copybara import of the project:

--
86ccf83 by Shraiysh Vaishay <svaishay@nvidia.com>:

Add xla_gpu_pgle_accuracy_checker to set strictness levels

xla_gpu_pgle_accuracy_checker can take the values {OFF, WARN, ERROR}
and this flag decides what will be done when there are missing
instructions in PGLE profile: either do nothing (OFF), warn about it
(WARN) or halt compilation (ERROR)

Merging this change closes #17259

FUTURE_COPYBARA_INTEGRATE_REVIEW=#17259 from shraiysh:pgle_strictness_levels 86ccf83
PiperOrigin-RevId: 685614179
  • Loading branch information
shraiysh authored and Google-ML-Automation committed Oct 21, 2024
1 parent c5c29c1 commit 7d4b093
Show file tree
Hide file tree
Showing 8 changed files with 105 additions and 38 deletions.
28 changes: 21 additions & 7 deletions xla/debug_options_flags.cc
Original file line number Diff line number Diff line change
Expand Up @@ -288,7 +288,8 @@ DebugOptions DefaultDebugOptionsIgnoringFlags() {

opts.set_xla_gpu_cudnn_gemm_max_plans(5);

opts.set_xla_gpu_enable_pgle_accuracy_checker(false);
opts.set_xla_gpu_pgle_accuracy_checker(
DebugOptions::PGLE_STRICTNESS_LEVEL_WARN);

opts.set_xla_gpu_executable_warn_stuck_timeout_seconds(10);
opts.set_xla_gpu_executable_terminate_timeout_seconds(30);
Expand Down Expand Up @@ -701,6 +702,18 @@ void MakeDebugOptionsFlags(std::vector<tsl::Flag>* flag_list,
return true;
};

// Custom "sub-parser" lambda for xla_gpu_pgle_accuracy_checker.
auto setter_for_xla_gpu_pgle_accuracy_checker =
[debug_options](const std::string& value) {
DebugOptions::PGLEStrictnessLevel strictness_level;
if (!DebugOptions::PGLEStrictnessLevel_Parse(value,
&strictness_level)) {
return false;
}
debug_options->set_xla_gpu_pgle_accuracy_checker(strictness_level);
return true;
};

// Don't use an initializer list for initializing the vector; this would
// create a temporary copy, and exceeds the stack space when compiling with
// certain configurations.
Expand Down Expand Up @@ -1975,12 +1988,13 @@ void MakeDebugOptionsFlags(std::vector<tsl::Flag>* flag_list,
"a training. The location of the marker (if any) is determined "
"by the option value of type DebugOptions::StepMarkerLocation."));
flag_list->push_back(tsl::Flag(
"xla_gpu_enable_pgle_accuracy_checker",
bool_setter_for(&DebugOptions::set_xla_gpu_enable_pgle_accuracy_checker),
debug_options->xla_gpu_enable_pgle_accuracy_checker(),
"Enables strict PGLE checking. If an FDO profile is specified and "
"latency hiding scheduler encounters missing instructions in the profile "
"compilation will halt."));
"xla_gpu_pgle_accuracy_checker", setter_for_xla_gpu_pgle_accuracy_checker,
DebugOptions::PGLEStrictnessLevel_Name(
debug_options->xla_gpu_pgle_accuracy_checker()),
"If an FDO profile is specified and latency hiding scheduler encounters "
"missing instructions in the profile, then the compilation will halt "
"(ERROR), or a warning will be emitted (WARN), or the checker is "
"disabled (OFF)"));

flag_list->push_back(tsl::Flag(
"xla_gpu_executable_warn_stuck_timeout",
Expand Down
3 changes: 2 additions & 1 deletion xla/service/gpu/gpu_compiler.cc
Original file line number Diff line number Diff line change
Expand Up @@ -2623,7 +2623,8 @@ absl::Status GpuCompiler::RunPostSchedulingPipelines(
pipeline.AddPass<SanitizeConstantNames>();
}

if (module->config().debug_options().xla_gpu_enable_pgle_accuracy_checker()) {
if (module->config().debug_options().xla_gpu_pgle_accuracy_checker() ==
DebugOptions::PGLE_STRICTNESS_LEVEL_ERROR) {
AddHloVerifier(
&main_pipeline,
module->config().debug_options().xla_experimental_ignore_channel_id(),
Expand Down
38 changes: 26 additions & 12 deletions xla/service/gpu/gpu_hlo_schedule.cc
Original file line number Diff line number Diff line change
Expand Up @@ -449,15 +449,32 @@ absl::StatusOr<ScheduleMetadata> ScheduleGpuModule(
VLOG(1) << "Fingerprint before LHS for module " << module->name() << "("
<< module->unique_id() << ") = " << fingerprint;

const DebugOptions& options = module->config().debug_options();
const bool enable_latency_hiding_scheduler =
module->config()
.debug_options()
.xla_gpu_enable_latency_hiding_scheduler();
options.xla_gpu_enable_latency_hiding_scheduler();

if (!enable_latency_hiding_scheduler) {
return ScheduleMetadata{memory_limit};
}

if (options.xla_gpu_pgle_profile_file_or_directory_path().empty() &&
module->config().fdo_profile().empty() &&
options.xla_gpu_pgle_accuracy_checker() ==
DebugOptions::PGLE_STRICTNESS_LEVEL_ERROR) {
return absl::InvalidArgumentError(
"xla_gpu_pgle_accuracy_checker is set to ERROR, but no profile "
"path specified in xla_gpu_pgle_profile_file_or_directory_path");
}

if (options.xla_gpu_pgle_profile_file_or_directory_path().empty() &&
module->config().fdo_profile().empty() &&
options.xla_gpu_pgle_accuracy_checker(),
DebugOptions::PGLE_STRICTNESS_LEVEL_WARN) {
LOG(WARNING)
<< "xla_gpu_pgle_accuracy_checker is set to WARN, but no profile path "
"specified in xla_gpu_pgle_profile_file_or_directory_path";
}

SchedulerConfig config = GetSchedulerConfig(
memory_limit,
module->config()
Expand All @@ -481,9 +498,7 @@ absl::StatusOr<ScheduleMetadata> ScheduleGpuModule(
ReadPGLEProfile(module, fingerprint);

const bool enable_analytical_latency_estimator =
module->config()
.debug_options()
.xla_gpu_enable_analytical_latency_estimator();
options.xla_gpu_enable_analytical_latency_estimator();
HloPassPipeline pipeline("latency-hiding-scheduler");
if (profile.has_value()) {
auto aggregator = std::make_unique<GPUProfileStatisticsAggregator>();
Expand All @@ -492,9 +507,10 @@ absl::StatusOr<ScheduleMetadata> ScheduleGpuModule(
std::move(aggregator));
LOG(INFO) << "Found profile, using profile guided latency estimator";
VLOG(1) << "Profile:\n" << profile->DebugString();
if (module->config()
.debug_options()
.xla_gpu_enable_pgle_accuracy_checker()) {
if (options.xla_gpu_pgle_accuracy_checker() ==
DebugOptions::PGLE_STRICTNESS_LEVEL_WARN ||
options.xla_gpu_pgle_accuracy_checker() ==
DebugOptions::PGLE_STRICTNESS_LEVEL_ERROR) {
pipeline.AddPass<PGLEAccuracyChecker>(*pg_latency_estimator);
}
latency_estimator = std::move(pg_latency_estimator);
Expand All @@ -511,9 +527,7 @@ absl::StatusOr<ScheduleMetadata> ScheduleGpuModule(
}

auto async_tracker = [&]() -> std::unique_ptr<AsyncTracker> {
return module->config()
.debug_options()
.xla_gpu_lhs_enable_gpu_async_tracker()
return options.xla_gpu_lhs_enable_gpu_async_tracker()
? std::make_unique<GpuAsyncTracker>(config)
: std::make_unique<GpuAsyncTrackerBase>(config);
}();
Expand Down
28 changes: 27 additions & 1 deletion xla/service/gpu/gpu_hlo_schedule_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -566,7 +566,8 @@ TEST_F(GpuHloScheduleTest, ProfileGuidedCostModelFailsWithIncompleteProfile) {

HloModuleConfig config(module->config());
DebugOptions dboptions(config.debug_options());
dboptions.set_xla_gpu_enable_pgle_accuracy_checker(true);
dboptions.set_xla_gpu_pgle_accuracy_checker(
DebugOptions::PGLE_STRICTNESS_LEVEL_ERROR);
config.set_debug_options(dboptions);
module->set_config(config);

Expand Down Expand Up @@ -1696,5 +1697,30 @@ TEST_F(GpuHloScheduleTest, CopyStartDoneScheduled) {
)"));
}

TEST_F(GpuHloScheduleTest, InvalidPGLEOptions) {
const char* hlo = R"(
HloModule test
ENTRY add {
a = s32[] parameter(0)
b = s32[] parameter(1)
ROOT add = add(a,b)
}
)";

HloModuleConfig config;
DebugOptions options;
options.set_xla_gpu_pgle_accuracy_checker(
DebugOptions::PGLE_STRICTNESS_LEVEL_ERROR);
options.set_xla_gpu_enable_latency_hiding_scheduler(true);
config.set_debug_options(options);
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<xla::VerifiedHloModule> module,
ParseAndReturnVerifiedModule(hlo, config));

GTEST_FLAG_SET(death_test_style, "threadsafe");
EXPECT_DEATH(BuildHloOrdering(module.get()),
"xla_gpu_pgle_accuracy_checker is set to ERROR, but no profile "
"path specified in xla_gpu_pgle_profile_file_or_directory_path");
}

} // namespace gpu
} // namespace xla
14 changes: 7 additions & 7 deletions xla/service/gpu/gpu_latency_hiding_scheduler_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -55,17 +55,17 @@ int GetIndexByName(absl::Span<HloInstruction* const> instruction_sequence,
class GpuLatencyHidingSchedulerBaseTest : public HloTestBase {
protected:
absl::StatusOr<HloModule*> ScheduleModule(
HloModule* module, int64_t num_parallel_resources = 1) {
HloModule* module, int64_t num_parallel_resources = 1,
DebugOptions::PGLEStrictnessLevel strictness =
DebugOptions::PGLE_STRICTNESS_LEVEL_ERROR) {
auto& test_backend = backend();
const auto& gpu_device_info =
test_backend.default_stream_executor()->GetDeviceDescription();
HloModuleConfig config(module->config());
DebugOptions dboptions(config.debug_options());
dboptions.set_xla_gpu_enable_pgle_accuracy_checker(true);
dboptions.set_xla_gpu_experimental_parallel_collective_overlap_limit(
DebugOptions& options = module->mutable_config().mutable_debug_options();
options.set_xla_gpu_experimental_parallel_collective_overlap_limit(
num_parallel_resources);
config.set_debug_options(dboptions);
module->set_config(config);
options.set_xla_gpu_pgle_accuracy_checker(strictness);

TF_RETURN_IF_ERROR(
ScheduleGpuModule(module, /*pointer_size=*/8, gpu_device_info)
.status());
Expand Down
4 changes: 4 additions & 0 deletions xla/service/gpu/transforms/pgle_accuracy_checker_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -148,6 +148,10 @@ TEST_F(PGLEAccuracyCheckerTest,
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
ParseAndReturnVerifiedModule(kHloString));
*module->mutable_config().mutable_fdo_profile() = kProfileString;
module->mutable_config()
.mutable_debug_options()
.set_xla_gpu_pgle_accuracy_checker(
DebugOptions::PGLE_STRICTNESS_LEVEL_ERROR);

auto pgle_estimator = GetProfileGuidedLatencyEstimator(profile);
PGLEAccuracyChecker pgle_accuracy_checker(*pgle_estimator);
Expand Down
19 changes: 11 additions & 8 deletions xla/service/profile_guided_latency_estimator.cc
Original file line number Diff line number Diff line change
Expand Up @@ -188,16 +188,19 @@ absl::Status ProfileGuidedLatencyEstimator::CheckAccuracy(
ProfileStatisticsAggregator::Statistics stats = aggregator_->GetStats();
size_t missing_instructions_count = stats.missing_instructions.size();
if (missing_instructions_count > 0) {
LOG(ERROR) << "Found " << stats.found_instructions_count
<< " instructions from the profile.";
LOG(ERROR) << "Missing " << missing_instructions_count
<< " instructions from the profile.";
LOG(WARNING) << "Found " << stats.found_instructions_count
<< " instructions from the profile.";
LOG(WARNING) << "Missing " << missing_instructions_count
<< " instructions from the profile.";
for (const HloInstruction* instr : stats.missing_instructions) {
LOG(ERROR) << " " << instr->name();
LOG(WARNING) << " " << instr->name();
}
if (module.config().debug_options().xla_gpu_pgle_accuracy_checker() ==
DebugOptions::PGLE_STRICTNESS_LEVEL_ERROR) {
return absl::InvalidArgumentError(
absl::StrCat("Found ", missing_instructions_count,
" missing instructions. Discarding the profile."));
}
return absl::InvalidArgumentError(
absl::StrCat("Found ", missing_instructions_count,
" missing instructions. Discarding the profile."));
}
return absl::OkStatus();
}
Expand Down
9 changes: 7 additions & 2 deletions xla/xla.proto
Original file line number Diff line number Diff line change
Expand Up @@ -992,8 +992,13 @@ message DebugOptions {

// Enables strict PGLE checking. If an FDO profile is specified and latency
// hiding scheduler encounters missing instructions in the profile
// compilation will halt.
bool xla_gpu_enable_pgle_accuracy_checker = 326;
// compilation will halt or warn depending on the value of this option.
enum PGLEStrictnessLevel {
PGLE_STRICTNESS_LEVEL_OFF = 0;
PGLE_STRICTNESS_LEVEL_WARN = 1;
PGLE_STRICTNESS_LEVEL_ERROR = 2;
}
PGLEStrictnessLevel xla_gpu_pgle_accuracy_checker = 326;

// Timeouts for RendezvousSingle stuck warning and termination.
int32 xla_gpu_executable_warn_stuck_timeout_seconds = 327;
Expand Down

0 comments on commit 7d4b093

Please sign in to comment.