Skip to content

Commit

Permalink
Add version checks to FindCudaExecutable
Browse files Browse the repository at this point in the history
Currently we look for ptxas and nvlink in a few different places on the host machine, then we choose the first found binary without taking its version into account. If the chosen binary doesn't fulfill our version requirements we will later fail even if there was a suitable ptxas or nvlink in the search path in the first place.

This change makes it take the version of each binary into account when going through the search path. Unsuitable binaries will be discarded right away and the search continues until we are out of locations to check.

This should help with host environments that have multiple CUDA toolkits installed and should make ptxas and nvlink selection more robust.

The concreate changes:

1. `FindCudaExecutable` now also takes a minimum version and a list of forbidden (think buggy) versions that are supposed to be skipped.
2. `WarnIfBadPtxAsVersion` has been removed. It was checking for ptxas < 11.1 which is way older than our minimum supported version of 11.8 and was not doing anything given the check described in #3.
3. There was another version check for `ptxas` in `NVPTXCompiler::ChooseLinkingMethod` which was checking for `version(ptxas)` < 11.8. This has also been removed/replace by the version check described in #4.
4. Version checking for `ptxas` and `nvlink` has been consolidated into 2 methods `FindPtxAsExectuable` and `FindNvLinkExecutable`. These methods hard code the current minimum version (and the list of excluded versions) of each tool in one place. It's still not great but at least less spaghetti-like.

PiperOrigin-RevId: 618797392
  • Loading branch information
beckerhe authored and copybara-github committed Mar 25, 2024
1 parent c5b90c0 commit af5a607
Show file tree
Hide file tree
Showing 6 changed files with 94 additions and 90 deletions.
36 changes: 5 additions & 31 deletions xla/service/gpu/nvptx_compiler.cc
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@ limitations under the License.
#include <fstream>
#include <iterator>
#include <memory>
#include <optional>
#include <string>
#include <tuple>
#include <utility>
Expand Down Expand Up @@ -672,8 +671,7 @@ NVPTXCompiler::CompileGpuAsmOrGetCachedResult(
return cache_value->maybe_cubin;
}

static std::optional<std::array<int64_t, 3>> GetNvLinkVersion(
const std::string& preferred_cuda_dir) {
static bool IsNvlinkEnabled() {
const bool use_nvlink_by_default =
#ifdef TF_DISABLE_NVLINK_BY_DEFAULT
false;
Expand All @@ -684,24 +682,7 @@ static std::optional<std::array<int64_t, 3>> GetNvLinkVersion(
TF_CHECK_OK(tsl::ReadBoolFromEnvVar("TF_USE_NVLINK_FOR_PARALLEL_COMPILATION",
/*default_val=*/
use_nvlink_by_default, &use_nvlink));

if (!use_nvlink) {
return std::nullopt;
}

// Make sure nvlink exists and is executable.
absl::StatusOr<std::string> bin_path =
se::FindCudaExecutable("nvlink", preferred_cuda_dir);

if (!bin_path.ok()) {
return std::nullopt;
}

auto version = se::GetToolVersion(bin_path.value());
if (!version.ok()) {
return std::nullopt;
}
return *version;
return use_nvlink;
}

absl::StatusOr<NVPTXCompiler::LinkingMethod> ChooseLinkingMethodImpl(
Expand All @@ -710,16 +691,9 @@ absl::StatusOr<NVPTXCompiler::LinkingMethod> ChooseLinkingMethodImpl(
TF_ASSIGN_OR_RETURN(auto ptxas_version_tuple,
se::GetAsmCompilerVersion(preferred_cuda_dir));

// ptxas versions prior to 11.8 are not supported anymore. We check this here,
// since we are fetching the ptxas version anyway. Catching the error
// elsewhere might introduce unnecessary overhead.
if (ptxas_version_tuple < std::array<int64_t, 3>{11, 8, 0}) {
return absl::InternalError("XLA requires ptxas version 11.8 or higher");
}

std::optional<std::array<int64_t, 3>> nvlink_version =
GetNvLinkVersion(preferred_cuda_dir);
if (nvlink_version && *nvlink_version >= ptxas_version_tuple) {
auto nvlink_version = stream_executor::GetNvLinkVersion(preferred_cuda_dir);
if (IsNvlinkEnabled() && nvlink_version.ok() &&
nvlink_version.value() >= ptxas_version_tuple) {
return LinkingMethod::kNvLink;
}

Expand Down
4 changes: 1 addition & 3 deletions xla/stream_executor/cuda/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -566,17 +566,15 @@ cuda_only_cc_library(
"//xla:status_macros",
"//xla/stream_executor:stream_executor_headers",
"//xla/stream_executor/gpu:asm_compiler_header",
"//xla/stream_executor/gpu:gpu_asm_opts",
"//xla/stream_executor/gpu:gpu_driver_header",
"@com_google_absl//absl/algorithm:container",
"@com_google_absl//absl/base",
"@com_google_absl//absl/base:core_headers",
"@com_google_absl//absl/cleanup",
"@com_google_absl//absl/log",
"@com_google_absl//absl/status",
"@com_google_absl//absl/status:statusor",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/strings:str_format",
"@com_google_absl//absl/types:span",
"@local_config_cuda//cuda:cuda_headers",
"@tsl//tsl/platform:env",
"@tsl//tsl/platform:errors",
Expand Down
33 changes: 24 additions & 9 deletions xla/stream_executor/cuda/cuda_asm_compiler.cc
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@ limitations under the License.
#include <utility>
#include <vector>

#include "absl/base/call_once.h"
#include "absl/base/optimization.h"
#include "absl/cleanup/cleanup.h"
#include "absl/log/log.h"
Expand All @@ -31,6 +30,8 @@ limitations under the License.
#include "absl/strings/str_cat.h"
#include "absl/strings/str_format.h"
#include "absl/strings/string_view.h"
#include "absl/types/span.h"
#include "third_party/gpus/cuda/include/cuda.h"
#include "xla/status_macros.h"
#include "xla/stream_executor/gpu/asm_compiler.h"
#include "xla/stream_executor/gpu/gpu_driver.h"
Expand All @@ -56,18 +57,32 @@ namespace stream_executor {
} \
} while (false)

static absl::StatusOr<std::string> FindNvlinkExecutable(
std::string_view preferred_cuda_dir) {
static constexpr ToolVersion kMinimumNvlinkVersion{11, 8, 0};
static constexpr absl::Span<const ToolVersion> kNoExcludedVersions{};
static constexpr std::string_view kNvLinkBinaryName = "nvlink";

return FindCudaExecutable(kNvLinkBinaryName, preferred_cuda_dir,
kMinimumNvlinkVersion, kNoExcludedVersions);
}

absl::StatusOr<ToolVersion> GetNvLinkVersion(
std::string_view preferred_cuda_dir) {
// Make sure nvlink exists and is executable.
TF_ASSIGN_OR_RETURN(std::string bin_path,
FindNvlinkExecutable(preferred_cuda_dir));

return GetToolVersion(bin_path);
}

absl::StatusOr<std::vector<uint8_t>> LinkUsingNvlink(
absl::string_view preferred_cuda_dir, gpu::GpuContext* context,
std::vector<CubinOrPTXImage> images) {
{
static absl::once_flag log_once;
absl::call_once(log_once,
[] { LOG(INFO) << "Using nvlink for parallel linking"; });
}
LOG_FIRST_N(INFO, 1) << "Using nvlink for parallel linking";

TF_ASSIGN_OR_RETURN(
std::string bin_path,
FindCudaExecutable("nvlink", std::string(preferred_cuda_dir)));
TF_ASSIGN_OR_RETURN(std::string bin_path,
FindNvlinkExecutable(preferred_cuda_dir));

if (images.empty()) {
return std::vector<uint8>();
Expand Down
2 changes: 2 additions & 0 deletions xla/stream_executor/gpu/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -445,6 +445,7 @@ gpu_only_cc_library(
"//xla/stream_executor/cuda:ptx_compiler",
"//xla/stream_executor/cuda:ptx_compiler_support",
"//xla/stream_executor/platform",
"@com_google_absl//absl/algorithm:container",
"@com_google_absl//absl/base:core_headers",
"@com_google_absl//absl/cleanup",
"@com_google_absl//absl/container:flat_hash_map",
Expand All @@ -468,6 +469,7 @@ gpu_only_cc_library(
"@tsl//tsl/platform:status",
"@tsl//tsl/platform:statusor",
"@tsl//tsl/platform:subprocess",
"@tsl//tsl/util:env_var",
] + if_cuda_is_configured([
"//xla/stream_executor/cuda:cuda_asm_compiler",
"//xla/stream_executor/cuda:cuda_driver",
Expand Down
94 changes: 50 additions & 44 deletions xla/stream_executor/gpu/asm_compiler.cc
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ limitations under the License.
#include <utility>
#include <vector>

#include "absl/algorithm/container.h"
#include "absl/base/const_init.h"
#include "absl/base/optimization.h"
#include "absl/base/thread_annotations.h"
Expand Down Expand Up @@ -63,7 +64,7 @@ limitations under the License.
namespace stream_executor {

static absl::StatusOr<std::string> GetToolVersionString(
absl::string_view binary_path) {
std::string_view binary_path) {
// If binary_path doesn't exist, then tsl::SubProcess will log a bunch of
// error messages that have confused users in the past. Therefore we first
// check whether the binary_path exists and error out early if not.
Expand Down Expand Up @@ -103,7 +104,7 @@ static absl::StatusOr<ToolVersion> GetToolVersionImpl(
}
static constexpr LazyRE2 kVersionRegex = {R"(\bV(\d+)\.(\d+)\.(\d+)\b)"};
ToolVersion version{};
absl::string_view vmaj_str, vmin_str, vdot_str;
std::string_view vmaj_str, vmin_str, vdot_str;
if (!RE2::PartialMatch(tool_version.value(), *kVersionRegex, &vmaj_str,
&vmin_str, &vdot_str) ||
!absl::SimpleAtoi(vmaj_str, &version[0]) ||
Expand Down Expand Up @@ -134,28 +135,6 @@ absl::StatusOr<ToolVersion> GetToolVersion(std::string_view tool_path) {
.first->second;
}

// Prints a warning if the ptxas at ptxas_path has known bugs.
//
// Only prints a warning the first time it's called for a particular value of
// ptxas_path.
//
// Locks on entry.˝
static void WarnIfBadPtxasVersion(absl::string_view ptxas_path) {
absl::StatusOr<std::array<int64_t, 3>> version = GetToolVersion(ptxas_path);
if (!version.ok()) {
LOG(WARNING) << "Couldn't get ptxas version : " << version.status();
return;
}

if (std::make_tuple((*version)[0], (*version)[1]) < std::make_tuple(11, 1)) {
LOG(ERROR) << "*** WARNING *** You are using ptxas " << (*version)[0] << "."
<< (*version)[1] << "." << (*version)[2]
<< ", which is older than 11.1. ptxas before 11.1 is known to "
"miscompile XLA code, leading to incorrect results or "
"invalid-address errors.\n";
}
}

absl::StatusOr<absl::Span<const uint8_t>> CompileGpuAsmOrGetCached(
int device_ordinal, const char* ptx, GpuAsmOpts compilation_options) {
using PtxCacheKey = std::tuple<int, std::string, GpuAsmOpts::PtxOptionsTuple>;
Expand Down Expand Up @@ -201,7 +180,9 @@ absl::StatusOr<std::vector<uint8_t>> CompileGpuAsm(int device_ordinal,
}

absl::StatusOr<std::string> FindCudaExecutable(
std::string_view binary_name, std::string_view preferred_cuda_dir) {
std::string_view binary_name, std::string_view preferred_cuda_dir,
ToolVersion minimum_version,
absl::Span<const ToolVersion> excluded_versions) {
#if defined(PLATFORM_WINDOWS)
const std::string binary_filename = std::string{binary_name} + ".exe";
#else
Expand Down Expand Up @@ -234,18 +215,44 @@ absl::StatusOr<std::string> FindCudaExecutable(

for (const auto& candidate : candidates) {
VLOG(2) << "Looking for " << candidate;
if (GetToolVersion(candidate).ok()) {
VLOG(2) << "Using " << candidate;
return candidate;
auto candidate_version = GetToolVersion(candidate);
if (!candidate_version.ok()) {
continue;
}

if (candidate_version.value() < minimum_version) {
VLOG(2) << candidate << " with version "
<< absl::StrJoin(minimum_version, ".") << " is too old.";
continue;
}

if (absl::c_find(excluded_versions, candidate_version.value()) !=
excluded_versions.end()) {
VLOG(2) << candidate << " has version "
<< absl::StrJoin(candidate_version.value(), ".")
<< " which was explicitly excluded.";
continue;
}

VLOG(2) << "Using " << candidate << " with version "
<< absl::StrJoin(candidate_version.value(), ".");
return candidate;
}

return absl::NotFoundError(
absl::StrCat("Couldn't find ", binary_name,
absl::StrCat("Couldn't find a suitable version of ", binary_name,
". The following locations were considered: ",
absl::StrJoin(candidates, ", ")));
}

absl::StatusOr<std::string> FindCudaExecutable(
std::string_view binary_name, std::string_view preferred_cuda_dir) {
static constexpr ToolVersion kNoMinimumVersion{0, 0, 0};
static constexpr absl::Span<const ToolVersion> kNoExcludedVersions{};
return FindCudaExecutable(binary_name, preferred_cuda_dir, kNoMinimumVersion,
kNoExcludedVersions);
}

static void LogPtxasTooOld(const std::string& ptxas_path, int cc_major,
int cc_minor) {
using AlreadyLoggedSetTy =
Expand Down Expand Up @@ -274,29 +281,28 @@ static void AppendArgsFromOptions(GpuAsmOpts options,
options.extra_flags.end());
}

absl::StatusOr<std::array<int64_t, 3>> GetAsmCompilerVersion(
const std::string& preferred_cuda_dir) {
static absl::StatusOr<std::string> FindPtxAsExecutable(
std::string_view preferred_cuda_dir) {
static constexpr ToolVersion kMinimumSupportedPtxAsVersion{11, 8, 0};
static constexpr ToolVersion kBuggyPtxAsVersions[] = {{12, 3, 103}};
static constexpr std::string_view kPtxAsBinaryName = "ptxas";

return FindCudaExecutable(kPtxAsBinaryName, preferred_cuda_dir,
kMinimumSupportedPtxAsVersion, kBuggyPtxAsVersions);
}

absl::StatusOr<ToolVersion> GetAsmCompilerVersion(
std::string_view preferred_cuda_dir) {
TF_ASSIGN_OR_RETURN(std::string ptxas_path,
FindCudaExecutable("ptxas", preferred_cuda_dir));
FindPtxAsExecutable(preferred_cuda_dir));
return GetToolVersion(ptxas_path);
}

absl::StatusOr<std::vector<uint8_t>> CompileGpuAsmUsingPtxAs(
int cc_major, int cc_minor, const char* ptx_contents, GpuAsmOpts options,
bool cancel_if_reg_spill) {
TF_ASSIGN_OR_RETURN(auto ptxas_version_tuple,
GetAsmCompilerVersion(options.preferred_cuda_dir));
if (ptxas_version_tuple == std::array<int64_t, 3>{12, 3, 103}) {
return absl::InternalError(absl::StrFormat(
"ptxas %d.%d.%d has a bug that we think can affect XLA. "
"Please use a different version.",
std::get<0>(ptxas_version_tuple), std::get<1>(ptxas_version_tuple),
std::get<2>(ptxas_version_tuple)));
}
TF_ASSIGN_OR_RETURN(std::string ptxas_path,
FindCudaExecutable("ptxas", options.preferred_cuda_dir));

WarnIfBadPtxasVersion(ptxas_path);
FindPtxAsExecutable(options.preferred_cuda_dir));

// Write ptx into a temporary file.
std::string ptx_path;
Expand Down
15 changes: 12 additions & 3 deletions xla/stream_executor/gpu/asm_compiler.h
Original file line number Diff line number Diff line change
Expand Up @@ -104,16 +104,25 @@ absl::StatusOr<std::vector<uint8_t>> LinkUsingNvlink(
absl::string_view preferred_cuda_dir, gpu::GpuContext* context,
std::vector<CubinOrPTXImage> images);

using ToolVersion = std::array<int64_t, 3>;
absl::StatusOr<std::string> FindCudaExecutable(
std::string_view binary_name, std::string_view preferred_cuda_dir,
ToolVersion minimum_version,
absl::Span<const ToolVersion> excluded_versions);

absl::StatusOr<std::string> FindCudaExecutable(
std::string_view binary_name, std::string_view preferred_cuda_dir);

// Runs tool --version and parses its version string.
using ToolVersion = std::array<int64_t, 3>;
absl::StatusOr<ToolVersion> GetToolVersion(std::string_view tool_path);

// On NVIDIA GPUs, returns the CUDA toolkit version supported by the driver,
// On NVIDIA GPUs, returns the version of the ptxas command line tool.
absl::StatusOr<ToolVersion> GetAsmCompilerVersion(
const std::string& preferred_cuda_dir);
std::string_view preferred_cuda_dir);

// On NVIDIA GPUs, returns the version of the nvlink command line tool.
absl::StatusOr<ToolVersion> GetNvLinkVersion(
std::string_view preferred_cuda_dir);

#if GOOGLE_CUDA
// Maintains a cache of pointers to loaded kernels
Expand Down

0 comments on commit af5a607

Please sign in to comment.