Skip to content

Commit

Permalink
Consistently handle algorithm selection in GemmAlgorithmPicker
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 618789520
  • Loading branch information
beckerhe authored and copybara-github committed Mar 25, 2024
1 parent 24a91cb commit c5b90c0
Showing 1 changed file with 17 additions and 4 deletions.
21 changes: 17 additions & 4 deletions xla/service/gpu/gemm_algorithm_picker.cc
Original file line number Diff line number Diff line change
Expand Up @@ -417,15 +417,28 @@ absl::StatusOr<bool> RunOnInstruction(HloInstruction* gemm,
config.GetGpuComputeCapability());

if (update_algorithm) {
int64_t new_algorithm{};
if (algorithm.has_gemm()) {
backend_config.set_selected_algorithm(algorithm.gemm().algorithm());
new_algorithm = algorithm.gemm().algorithm();
} else {
// NOTE: runtime autotuning is no longer available => set to default
backend_config.set_selected_algorithm(se::blas::kDefaultAlgorithm);
new_algorithm = se::blas::kDefaultAlgorithm;
}

if (new_algorithm == old_algorithm &&
backend_config.has_selected_algorithm()) {
// We don't need to update the backend config if
// the algorithm hasn't changed unless previously
// the algorithm wasn't set explicitly.
return false;
}

backend_config.set_selected_algorithm(new_algorithm);
TF_RETURN_IF_ERROR(gemm->set_backend_config(gpu_config));
return true; // We changed `gemm`
}
TF_RETURN_IF_ERROR(gemm->set_backend_config(gpu_config));
return old_algorithm != backend_config.selected_algorithm();

return false; // No change to `gemm`
}

absl::StatusOr<bool> RunOnComputation(HloComputation* computation,
Expand Down

0 comments on commit c5b90c0

Please sign in to comment.