diff --git a/xla/service/gpu/nvptx_compiler.cc b/xla/service/gpu/nvptx_compiler.cc index 62920ee5adb92..3422cd810290f 100644 --- a/xla/service/gpu/nvptx_compiler.cc +++ b/xla/service/gpu/nvptx_compiler.cc @@ -20,7 +20,6 @@ limitations under the License. #include #include #include -#include #include #include #include @@ -672,8 +671,7 @@ NVPTXCompiler::CompileGpuAsmOrGetCachedResult( return cache_value->maybe_cubin; } -static std::optional> GetNvLinkVersion( - const std::string& preferred_cuda_dir) { +static bool IsNvlinkEnabled() { const bool use_nvlink_by_default = #ifdef TF_DISABLE_NVLINK_BY_DEFAULT false; @@ -684,24 +682,7 @@ static std::optional> GetNvLinkVersion( TF_CHECK_OK(tsl::ReadBoolFromEnvVar("TF_USE_NVLINK_FOR_PARALLEL_COMPILATION", /*default_val=*/ use_nvlink_by_default, &use_nvlink)); - - if (!use_nvlink) { - return std::nullopt; - } - - // Make sure nvlink exists and is executable. - absl::StatusOr bin_path = - se::FindCudaExecutable("nvlink", preferred_cuda_dir); - - if (!bin_path.ok()) { - return std::nullopt; - } - - auto version = se::GetToolVersion(bin_path.value()); - if (!version.ok()) { - return std::nullopt; - } - return *version; + return use_nvlink; } absl::StatusOr ChooseLinkingMethodImpl( @@ -710,16 +691,9 @@ absl::StatusOr ChooseLinkingMethodImpl( TF_ASSIGN_OR_RETURN(auto ptxas_version_tuple, se::GetAsmCompilerVersion(preferred_cuda_dir)); - // ptxas versions prior to 11.8 are not supported anymore. We check this here, - // since we are fetching the ptxas version anyway. Catching the error - // elsewhere might introduce unnecessary overhead. 
- if (ptxas_version_tuple < std::array{11, 8, 0}) { - return absl::InternalError("XLA requires ptxas version 11.8 or higher"); - } - - std::optional> nvlink_version = - GetNvLinkVersion(preferred_cuda_dir); - if (nvlink_version && *nvlink_version >= ptxas_version_tuple) { + auto nvlink_version = stream_executor::GetNvLinkVersion(preferred_cuda_dir); + if (IsNvlinkEnabled() && nvlink_version.ok() && + nvlink_version.value() >= ptxas_version_tuple) { return LinkingMethod::kNvLink; } diff --git a/xla/stream_executor/cuda/BUILD b/xla/stream_executor/cuda/BUILD index 37011dc3c6ab8..af38d30389621 100644 --- a/xla/stream_executor/cuda/BUILD +++ b/xla/stream_executor/cuda/BUILD @@ -566,10 +566,7 @@ cuda_only_cc_library( "//xla:status_macros", "//xla/stream_executor:stream_executor_headers", "//xla/stream_executor/gpu:asm_compiler_header", - "//xla/stream_executor/gpu:gpu_asm_opts", "//xla/stream_executor/gpu:gpu_driver_header", - "@com_google_absl//absl/algorithm:container", - "@com_google_absl//absl/base", "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/cleanup", "@com_google_absl//absl/log", @@ -577,6 +574,7 @@ cuda_only_cc_library( "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", "@com_google_absl//absl/strings:str_format", + "@com_google_absl//absl/types:span", "@local_config_cuda//cuda:cuda_headers", "@tsl//tsl/platform:env", "@tsl//tsl/platform:errors", diff --git a/xla/stream_executor/cuda/cuda_asm_compiler.cc b/xla/stream_executor/cuda/cuda_asm_compiler.cc index 2d435a74ef308..771d570b5f3f3 100644 --- a/xla/stream_executor/cuda/cuda_asm_compiler.cc +++ b/xla/stream_executor/cuda/cuda_asm_compiler.cc @@ -21,7 +21,6 @@ limitations under the License. #include #include -#include "absl/base/call_once.h" #include "absl/base/optimization.h" #include "absl/cleanup/cleanup.h" #include "absl/log/log.h" @@ -31,6 +30,8 @@ limitations under the License. 
#include "absl/strings/str_cat.h" #include "absl/strings/str_format.h" #include "absl/strings/string_view.h" +#include "absl/types/span.h" +#include "third_party/gpus/cuda/include/cuda.h" #include "xla/status_macros.h" #include "xla/stream_executor/gpu/asm_compiler.h" #include "xla/stream_executor/gpu/gpu_driver.h" @@ -56,18 +57,32 @@ namespace stream_executor { } \ } while (false) +static absl::StatusOr FindNvlinkExecutable( + std::string_view preferred_cuda_dir) { + static constexpr ToolVersion kMinimumNvlinkVersion{11, 8, 0}; + static constexpr absl::Span kNoExcludedVersions{}; + static constexpr std::string_view kNvLinkBinaryName = "nvlink"; + + return FindCudaExecutable(kNvLinkBinaryName, preferred_cuda_dir, + kMinimumNvlinkVersion, kNoExcludedVersions); +} + +absl::StatusOr GetNvLinkVersion( + std::string_view preferred_cuda_dir) { + // Make sure nvlink exists and is executable. + TF_ASSIGN_OR_RETURN(std::string bin_path, + FindNvlinkExecutable(preferred_cuda_dir)); + + return GetToolVersion(bin_path); +} + absl::StatusOr> LinkUsingNvlink( absl::string_view preferred_cuda_dir, gpu::GpuContext* context, std::vector images) { - { - static absl::once_flag log_once; - absl::call_once(log_once, - [] { LOG(INFO) << "Using nvlink for parallel linking"; }); - } + LOG_FIRST_N(INFO, 1) << "Using nvlink for parallel linking"; - TF_ASSIGN_OR_RETURN( - std::string bin_path, - FindCudaExecutable("nvlink", std::string(preferred_cuda_dir))); + TF_ASSIGN_OR_RETURN(std::string bin_path, + FindNvlinkExecutable(preferred_cuda_dir)); if (images.empty()) { return std::vector(); diff --git a/xla/stream_executor/gpu/BUILD b/xla/stream_executor/gpu/BUILD index dc89a624530b8..a78839a68b914 100644 --- a/xla/stream_executor/gpu/BUILD +++ b/xla/stream_executor/gpu/BUILD @@ -445,6 +445,7 @@ gpu_only_cc_library( "//xla/stream_executor/cuda:ptx_compiler", "//xla/stream_executor/cuda:ptx_compiler_support", "//xla/stream_executor/platform", + "@com_google_absl//absl/algorithm:container", 
"@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/cleanup", "@com_google_absl//absl/container:flat_hash_map", @@ -468,6 +469,7 @@ gpu_only_cc_library( "@tsl//tsl/platform:status", "@tsl//tsl/platform:statusor", "@tsl//tsl/platform:subprocess", + "@tsl//tsl/util:env_var", ] + if_cuda_is_configured([ "//xla/stream_executor/cuda:cuda_asm_compiler", "//xla/stream_executor/cuda:cuda_driver", diff --git a/xla/stream_executor/gpu/asm_compiler.cc b/xla/stream_executor/gpu/asm_compiler.cc index 54641031ee58d..6293d7a8eaa7c 100644 --- a/xla/stream_executor/gpu/asm_compiler.cc +++ b/xla/stream_executor/gpu/asm_compiler.cc @@ -26,6 +26,7 @@ limitations under the License. #include #include +#include "absl/algorithm/container.h" #include "absl/base/const_init.h" #include "absl/base/optimization.h" #include "absl/base/thread_annotations.h" @@ -63,7 +64,7 @@ limitations under the License. namespace stream_executor { static absl::StatusOr GetToolVersionString( - absl::string_view binary_path) { + std::string_view binary_path) { // If binary_path doesn't exist, then tsl::SubProcess will log a bunch of // error messages that have confused users in the past. Therefore we first // check whether the binary_path exists and error out early if not. @@ -103,7 +104,7 @@ static absl::StatusOr GetToolVersionImpl( } static constexpr LazyRE2 kVersionRegex = {R"(\bV(\d+)\.(\d+)\.(\d+)\b)"}; ToolVersion version{}; - absl::string_view vmaj_str, vmin_str, vdot_str; + std::string_view vmaj_str, vmin_str, vdot_str; if (!RE2::PartialMatch(tool_version.value(), *kVersionRegex, &vmaj_str, &vmin_str, &vdot_str) || !absl::SimpleAtoi(vmaj_str, &version[0]) || @@ -134,28 +135,6 @@ absl::StatusOr GetToolVersion(std::string_view tool_path) { .first->second; } -// Prints a warning if the ptxas at ptxas_path has known bugs. -// -// Only prints a warning the first time it's called for a particular value of -// ptxas_path. 
-// -// Locks on entry.˝ -static void WarnIfBadPtxasVersion(absl::string_view ptxas_path) { - absl::StatusOr> version = GetToolVersion(ptxas_path); - if (!version.ok()) { - LOG(WARNING) << "Couldn't get ptxas version : " << version.status(); - return; - } - - if (std::make_tuple((*version)[0], (*version)[1]) < std::make_tuple(11, 1)) { - LOG(ERROR) << "*** WARNING *** You are using ptxas " << (*version)[0] << "." - << (*version)[1] << "." << (*version)[2] - << ", which is older than 11.1. ptxas before 11.1 is known to " - "miscompile XLA code, leading to incorrect results or " - "invalid-address errors.\n"; - } -} - absl::StatusOr> CompileGpuAsmOrGetCached( int device_ordinal, const char* ptx, GpuAsmOpts compilation_options) { using PtxCacheKey = std::tuple; @@ -201,7 +180,9 @@ absl::StatusOr> CompileGpuAsm(int device_ordinal, } absl::StatusOr FindCudaExecutable( - std::string_view binary_name, std::string_view preferred_cuda_dir) { + std::string_view binary_name, std::string_view preferred_cuda_dir, + ToolVersion minimum_version, + absl::Span excluded_versions) { #if defined(PLATFORM_WINDOWS) const std::string binary_filename = std::string{binary_name} + ".exe"; #else @@ -234,18 +215,44 @@ absl::StatusOr FindCudaExecutable( for (const auto& candidate : candidates) { VLOG(2) << "Looking for " << candidate; - if (GetToolVersion(candidate).ok()) { - VLOG(2) << "Using " << candidate; - return candidate; + auto candidate_version = GetToolVersion(candidate); + if (!candidate_version.ok()) { + continue; + } + + if (candidate_version.value() < minimum_version) { + VLOG(2) << candidate << " with version " + << absl::StrJoin(candidate_version.value(), ".") << " is too old."; + continue; + } + + if (absl::c_find(excluded_versions, candidate_version.value()) != + excluded_versions.end()) { + VLOG(2) << candidate << " has version " + << absl::StrJoin(candidate_version.value(), ".") + << " which was explicitly excluded."; + continue; } + + VLOG(2) << "Using " << candidate << " with 
version " + << absl::StrJoin(candidate_version.value(), "."); + return candidate; } return absl::NotFoundError( - absl::StrCat("Couldn't find ", binary_name, + absl::StrCat("Couldn't find a suitable version of ", binary_name, ". The following locations were considered: ", absl::StrJoin(candidates, ", "))); } +absl::StatusOr FindCudaExecutable( + std::string_view binary_name, std::string_view preferred_cuda_dir) { + static constexpr ToolVersion kNoMinimumVersion{0, 0, 0}; + static constexpr absl::Span kNoExcludedVersions{}; + return FindCudaExecutable(binary_name, preferred_cuda_dir, kNoMinimumVersion, + kNoExcludedVersions); +} + static void LogPtxasTooOld(const std::string& ptxas_path, int cc_major, int cc_minor) { using AlreadyLoggedSetTy = @@ -274,29 +281,28 @@ static void AppendArgsFromOptions(GpuAsmOpts options, options.extra_flags.end()); } -absl::StatusOr> GetAsmCompilerVersion( - const std::string& preferred_cuda_dir) { +static absl::StatusOr FindPtxAsExecutable( + std::string_view preferred_cuda_dir) { + static constexpr ToolVersion kMinimumSupportedPtxAsVersion{11, 8, 0}; + static constexpr ToolVersion kBuggyPtxAsVersions[] = {{12, 3, 103}}; + static constexpr std::string_view kPtxAsBinaryName = "ptxas"; + + return FindCudaExecutable(kPtxAsBinaryName, preferred_cuda_dir, + kMinimumSupportedPtxAsVersion, kBuggyPtxAsVersions); +} + +absl::StatusOr GetAsmCompilerVersion( + std::string_view preferred_cuda_dir) { TF_ASSIGN_OR_RETURN(std::string ptxas_path, - FindCudaExecutable("ptxas", preferred_cuda_dir)); + FindPtxAsExecutable(preferred_cuda_dir)); return GetToolVersion(ptxas_path); } absl::StatusOr> CompileGpuAsmUsingPtxAs( int cc_major, int cc_minor, const char* ptx_contents, GpuAsmOpts options, bool cancel_if_reg_spill) { - TF_ASSIGN_OR_RETURN(auto ptxas_version_tuple, - GetAsmCompilerVersion(options.preferred_cuda_dir)); - if (ptxas_version_tuple == std::array{12, 3, 103}) { - return absl::InternalError(absl::StrFormat( - "ptxas %d.%d.%d has a bug that 
we think can affect XLA. " - "Please use a different version.", - std::get<0>(ptxas_version_tuple), std::get<1>(ptxas_version_tuple), - std::get<2>(ptxas_version_tuple))); - } TF_ASSIGN_OR_RETURN(std::string ptxas_path, - FindCudaExecutable("ptxas", options.preferred_cuda_dir)); - - WarnIfBadPtxasVersion(ptxas_path); + FindPtxAsExecutable(options.preferred_cuda_dir)); // Write ptx into a temporary file. std::string ptx_path; diff --git a/xla/stream_executor/gpu/asm_compiler.h b/xla/stream_executor/gpu/asm_compiler.h index 78d785f76041b..5933a218baca1 100644 --- a/xla/stream_executor/gpu/asm_compiler.h +++ b/xla/stream_executor/gpu/asm_compiler.h @@ -104,16 +104,25 @@ absl::StatusOr> LinkUsingNvlink( absl::string_view preferred_cuda_dir, gpu::GpuContext* context, std::vector images); +using ToolVersion = std::array; +absl::StatusOr FindCudaExecutable( + std::string_view binary_name, std::string_view preferred_cuda_dir, + ToolVersion minimum_version, + absl::Span excluded_versions); + absl::StatusOr FindCudaExecutable( std::string_view binary_name, std::string_view preferred_cuda_dir); // Runs tool --version and parses its version string. -using ToolVersion = std::array; absl::StatusOr GetToolVersion(std::string_view tool_path); -// On NVIDIA GPUs, returns the CUDA toolkit version supported by the driver, +// On NVIDIA GPUs, returns the version of the ptxas command line tool. absl::StatusOr GetAsmCompilerVersion( - const std::string& preferred_cuda_dir); + std::string_view preferred_cuda_dir); + +// On NVIDIA GPUs, returns the version of the nvlink command line tool. +absl::StatusOr GetNvLinkVersion( + std::string_view preferred_cuda_dir); #if GOOGLE_CUDA // Maintains a cache of pointers to loaded kernels