Skip to content

Commit

Permalink
Use ModuleHandle for module management in Executor
Browse files Browse the repository at this point in the history
We used to use void pointers to identify loaded GPU binaries. This change makes it use the `ModuleHandle` type. It avoids a bunch of `reinterpret_cast` calls and also makes it clearer what the void pointer represents.

In addition, this replaces the out parameters in a bunch of related functions with `absl::StatusOr<...>` return types.

PiperOrigin-RevId: 686290766
  • Loading branch information
beckerhe authored and Google-ML-Automation committed Oct 16, 2024
1 parent 5d60fc7 commit 918e7cf
Show file tree
Hide file tree
Showing 10 changed files with 175 additions and 155 deletions.
2 changes: 1 addition & 1 deletion xla/service/gpu/gpu_executable.cc
Original file line number Diff line number Diff line change
Expand Up @@ -621,7 +621,7 @@ GpuExecutable::ResolveConstantGlobals(se::Stream* stream) {
if (!(executor->GetPlatform()->id() ==
stream_executor::cuda::kCudaPlatformId &&
binary().empty() && text().empty())) {
TF_RETURN_IF_ERROR(executor->LoadModule(module_spec, &module_handle));
TF_ASSIGN_OR_RETURN(module_handle, executor->LoadModule(module_spec));
}

// A flag signalling if constant initialization submitted memcpy operations
Expand Down
1 change: 1 addition & 0 deletions xla/stream_executor/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -193,6 +193,7 @@ cc_library(
name = "module_spec",
hdrs = ["module_spec.h"],
deps = [
"@com_google_absl//absl/strings:str_format",
"@com_google_absl//absl/types:span",
"@tsl//tsl/platform:logging",
],
Expand Down
138 changes: 70 additions & 68 deletions xla/stream_executor/cuda/cuda_executor.cc
Original file line number Diff line number Diff line change
Expand Up @@ -123,14 +123,14 @@ tsl::thread::ThreadPool* GetDriverExecutor() {
return thread_pool;
}

// Loads ptx_contents with the CUDA driver's PTX JIT and stores the resulting
// handle in "module". Any error logs that are produced are logged internally.
absl::Status LoadPtx(Context* context, const char* ptx_contents,
CUmodule* module) {
// Loads ptx_contents with the CUDA driver's PTX JIT and return the resulting
// handle. Any error logs that are produced are logged internally.
absl::StatusOr<CUmodule> LoadPtx(Context* context, const char* ptx_contents) {
absl::Notification notification;
absl::Status returned_status = absl::OkStatus();
CUmodule module;
GetDriverExecutor()->Schedule(
[context, ptx_contents, module, &returned_status, &notification]() {
[context, ptx_contents, &module, &returned_status, &notification]() {
ScopedActivateContext activation(context);
void* ptx_data = const_cast<char*>(ptx_contents);
static const unsigned int kLogBufferBytesLimit = 1024;
Expand All @@ -155,7 +155,7 @@ absl::Status LoadPtx(Context* context, const char* ptx_contents,

absl::Status status;
status = cuda::ToStatus(cuModuleLoadDataEx(
module, ptx_data, TF_ARRAYSIZE(options), options, option_values));
&module, ptx_data, TF_ARRAYSIZE(options), options, option_values));

// The PTX JIT mutates the values in the option values array to reflect
// the size of the logs it output; now that we've made the call, read
Expand Down Expand Up @@ -195,24 +195,26 @@ absl::Status LoadPtx(Context* context, const char* ptx_contents,
});
notification.WaitForNotification();

return returned_status;
TF_RETURN_IF_ERROR(returned_status);
return module;
}

// Loads cubin_bytes with the CUDA driver's blob loading interface and stores
// the resulting handle in "module".
absl::Status LoadCubin(Context* context, const char* cubin_bytes,
CUmodule* module) {
absl::StatusOr<CUmodule> LoadCubin(Context* context, const char* cubin_bytes) {
ScopedActivateContext activation(context);
return cuda::ToStatus(
cuModuleLoadFatBinary(module, cubin_bytes),
"Failed to load in-memory CUBIN (compiled for a different GPU?).");
CUmodule module;
TF_RETURN_IF_ERROR(cuda::ToStatus(
cuModuleLoadFatBinary(&module, cubin_bytes),
"Failed to load in-memory CUBIN (compiled for a different GPU?)."));
return module;
}

// Retrieves a named kernel from a loaded module, and places the resulting
// handle into function (outparam) on success. Neither kernel_name nor
// function may be null. No ownership is taken of kernel_name.
absl::Status GetModuleFunction(Context* context, CUmodule module,
const char* kernel_name, CUfunction* function) {
// Retrieves a named kernel from a loaded module, and return the CUfunction
// handle on success. Neither kernel_name nor function may be null. No ownership
// is taken of kernel_name.
absl::StatusOr<CUfunction> GetModuleFunction(Context* context, CUmodule module,
const char* kernel_name) {
ScopedActivateContext activated{context};
CHECK(module != nullptr && kernel_name != nullptr);
cudaError_t cuda_error = cudaPeekAtLastError();
Expand All @@ -222,8 +224,11 @@ absl::Status GetModuleFunction(Context* context, CUmodule module,
cuda_error, "): ", cudaGetErrorName(cuda_error), " : ",
cudaGetErrorString(cuda_error)));
}
return cuda::ToStatus(cuModuleGetFunction(function, module, kernel_name),
"Failed to get module function");
CUfunction function;
TF_RETURN_IF_ERROR(
cuda::ToStatus(cuModuleGetFunction(&function, module, kernel_name),
"Failed to get module function"));
return function;
}

// Retrieves a named global/constant symbol from a loaded module, and returns
Expand Down Expand Up @@ -609,57 +614,66 @@ absl::StatusOr<bool> CudaExecutor::DelayKernelIsSupported() {
return static_cast<bool>(status);
}

absl::Status CudaExecutor::LoadModuleFromCuBin(const char* cubin,
CUmodule* module) {
absl::StatusOr<ModuleHandle> CudaExecutor::LoadModuleFromCuBin(
const char* cubin) {
ModuleHandle module_handle{cubin};
uint64_t module_refcount;
std::tie(*module, module_refcount) = gpu_binary_to_module_[cubin];
CUmodule module;
std::tie(module, module_refcount) = gpu_binary_to_module_[module_handle];

if (*module == nullptr) {
TF_RETURN_IF_ERROR(LoadCubin(gpu_context(), cubin, module));
if (module == nullptr) {
TF_ASSIGN_OR_RETURN(module, LoadCubin(gpu_context(), cubin));
module_refcount = 1;
VLOG(3) << "Loaded CUBIN " << static_cast<const void*>(cubin)
<< " as module " << *module;
<< " as module " << module;
} else {
++module_refcount;
VLOG(3) << "CUBIN " << static_cast<const void*>(cubin)
<< " is already loaded as module " << *module;
<< " is already loaded as module " << module;
}
gpu_binary_to_module_[cubin] = {*module, module_refcount};
return absl::OkStatus();
gpu_binary_to_module_[module_handle] = {module, module_refcount};
return module_handle;
}

absl::Status CudaExecutor::LoadModuleFromPtx(const char* ptx,
CUmodule* module) {
absl::StatusOr<ModuleHandle> CudaExecutor::LoadModuleFromPtx(const char* ptx) {
ModuleHandle module_handle{ptx};
uint64_t module_refcount;
std::tie(*module, module_refcount) = gpu_binary_to_module_[ptx];
CUmodule module;
std::tie(module, module_refcount) = gpu_binary_to_module_[module_handle];

if (*module == nullptr) {
TF_RETURN_IF_ERROR(LoadPtx(gpu_context(), ptx, module));
if (module == nullptr) {
TF_ASSIGN_OR_RETURN(module, LoadPtx(gpu_context(), ptx));
VLOG(3) << "Loaded PTX " << static_cast<const void*>(ptx) << " as module "
<< *module;
<< module;
module_refcount = 1;
} else {
++module_refcount;
VLOG(3) << "PTX " << static_cast<const void*>(ptx)
<< " is already loaded as module " << module;
}
gpu_binary_to_module_[ptx] = {*module, module_refcount};
return absl::OkStatus();
gpu_binary_to_module_[module_handle] = {module, module_refcount};
return module_handle;
}

absl::StatusOr<std::unique_ptr<Kernel>> CudaExecutor::LoadKernel(
const MultiKernelLoaderSpec& spec) {
auto cuda_kernel = std::make_unique<CudaKernel>(this);
CUmodule module;
const std::string* kernel_name;

if (spec.has_cuda_cubin_in_memory()) {
absl::MutexLock lock{&in_memory_modules_mu_};
kernel_name = &spec.cuda_cubin_in_memory().kernel_name();
const char* cubin = reinterpret_cast<const char*>(
spec.cuda_cubin_in_memory().cubin_bytes().data());
TF_RETURN_IF_ERROR(LoadModuleFromCuBin(cubin, &module));
kernel_to_gpu_binary_[cuda_kernel.get()] = cubin;
TF_ASSIGN_OR_RETURN(ModuleHandle module_handle, LoadModuleFromCuBin(cubin));
kernel_to_gpu_binary_[cuda_kernel.get()] = module_handle;

CUmodule module = gpu_binary_to_module_.at(module_handle).first;
VLOG(2) << "getting function " << *kernel_name << " from module " << module;
TF_ASSIGN_OR_RETURN(
CUfunction function,
GetModuleFunction(gpu_context(), module, kernel_name->c_str()));
cuda_kernel->set_gpu_function(function);

} else if (spec.has_cuda_ptx_in_memory()) {
kernel_name = &spec.cuda_ptx_in_memory().kernel_name();
Expand All @@ -677,8 +691,15 @@ absl::StatusOr<std::unique_ptr<Kernel>> CudaExecutor::LoadKernel(
}

absl::MutexLock lock{&in_memory_modules_mu_};
TF_RETURN_IF_ERROR(LoadModuleFromPtx(ptx, &module));
kernel_to_gpu_binary_[cuda_kernel.get()] = ptx;
TF_ASSIGN_OR_RETURN(ModuleHandle module_handle, LoadModuleFromPtx(ptx));
kernel_to_gpu_binary_[cuda_kernel.get()] = module_handle;

CUmodule module = gpu_binary_to_module_.at(module_handle).first;
VLOG(2) << "getting function " << *kernel_name << " from module " << module;
TF_ASSIGN_OR_RETURN(
CUfunction function,
GetModuleFunction(gpu_context(), module, kernel_name->c_str()));
cuda_kernel->set_gpu_function(function);

} else if (spec.has_in_process_symbol()) {
kernel_name = &spec.in_process_symbol().kernel_name();
Expand All @@ -695,15 +716,6 @@ absl::StatusOr<std::unique_ptr<Kernel>> CudaExecutor::LoadKernel(
return absl::InternalError("No method of loading CUDA kernel provided");
}
VLOG(3) << "LoadKernel on kernel : " << *kernel_name;
// If we resolved kernel from a symbol pointer, there is no need to load it
// from a module, as CUDA runtime did that automatically for us.
if (!spec.has_in_process_symbol()) {
VLOG(2) << "getting function " << *kernel_name << " from module " << module;
CUfunction function;
TF_RETURN_IF_ERROR(GetModuleFunction(gpu_context(), module,
kernel_name->c_str(), &function));
cuda_kernel->set_gpu_function(function);
}

// Update CUDA kernel properties after it was loaded in the CUDA context.
cuda_kernel->set_name(*kernel_name);
Expand Down Expand Up @@ -733,7 +745,7 @@ CudaExecutor::CreateEventBasedTimer(GpuStream* stream, bool use_delay_kernel) {
return std::make_unique<CudaTimer>(std::move(timer));
}

bool CudaExecutor::UnloadGpuBinary(const void* gpu_binary) {
bool CudaExecutor::UnloadGpuBinary(ModuleHandle gpu_binary) {
auto module_it = gpu_binary_to_module_.find(gpu_binary);
if (gpu_binary_to_module_.end() == module_it) {
VLOG(3) << "No loaded CUDA module for " << gpu_binary;
Expand Down Expand Up @@ -768,19 +780,14 @@ void CudaExecutor::UnloadKernel(const Kernel* kernel) {
kernel_to_gpu_binary_.erase(gpu_binary_it);
}

absl::Status CudaExecutor::LoadModule(const MultiModuleLoaderSpec& spec,
ModuleHandle* module_handle) {
absl::StatusOr<ModuleHandle> CudaExecutor::LoadModule(
const MultiModuleLoaderSpec& spec) {
// In GpuExecutor we store the pointer to the GPU binary (PTX or CUBIN) as
// ModuleHandle::id().
CUmodule cu_module;
if (spec.has_cuda_cubin_in_memory()) {
absl::MutexLock lock{&in_memory_modules_mu_};
TF_RETURN_IF_ERROR(LoadModuleFromCuBin(
reinterpret_cast<const char*>(spec.cuda_cubin_in_memory().data()),
&cu_module));
*module_handle = ModuleHandle(const_cast<void*>(
static_cast<const void*>(spec.cuda_cubin_in_memory().data())));
return absl::OkStatus();
return LoadModuleFromCuBin(
reinterpret_cast<const char*>(spec.cuda_cubin_in_memory().data()));
} else if (spec.has_cuda_ptx_in_memory()) {
if (cc_major_ == 0 && cc_minor_ == 0) {
return absl::InternalError("Compute capability not set");
Expand All @@ -791,19 +798,14 @@ absl::Status CudaExecutor::LoadModule(const MultiModuleLoaderSpec& spec,
}

absl::MutexLock lock{&in_memory_modules_mu_};
TF_RETURN_IF_ERROR(
LoadModuleFromPtx(spec.cuda_ptx_in_memory(), &cu_module));
*module_handle = ModuleHandle(
const_cast<void*>(static_cast<const void*>(spec.cuda_ptx_in_memory())));
return absl::OkStatus();
return LoadModuleFromPtx(spec.cuda_ptx_in_memory());
}
return absl::InternalError("No method of loading CUDA module provided");
}

bool CudaExecutor::UnloadModule(ModuleHandle module_handle) {
const char* gpu_binary = reinterpret_cast<const char*>(module_handle.id());
absl::MutexLock lock{&in_memory_modules_mu_};
return UnloadGpuBinary(gpu_binary);
return UnloadGpuBinary(module_handle);
}

namespace {
Expand Down Expand Up @@ -1122,7 +1124,7 @@ absl::StatusOr<DeviceMemoryBase> CudaExecutor::GetSymbol(

{ // give limited scope to mutex_lock
absl::MutexLock lock{&in_memory_modules_mu_};
auto it = gpu_binary_to_module_.find(module_handle.id());
auto it = gpu_binary_to_module_.find(module_handle);
CHECK(it != gpu_binary_to_module_.end());

CUmodule gpu_module_handle = it->second.first;
Expand Down
23 changes: 11 additions & 12 deletions xla/stream_executor/cuda/cuda_executor.h
Original file line number Diff line number Diff line change
Expand Up @@ -91,8 +91,8 @@ class CudaExecutor : public GpuExecutor {
absl::StatusOr<std::unique_ptr<Kernel>> LoadKernel(
const MultiKernelLoaderSpec& spec) override;
void UnloadKernel(const Kernel* kernel) override;
absl::Status LoadModule(const MultiModuleLoaderSpec& spec,
ModuleHandle* module_handle) override;
absl::StatusOr<ModuleHandle> LoadModule(
const MultiModuleLoaderSpec& spec) override;
bool UnloadModule(ModuleHandle module_handle) override;
absl::StatusOr<std::shared_ptr<DeviceMemoryBase>> CreateOrShareConstant(
Stream* stream, absl::Span<const uint8_t> content) override;
Expand Down Expand Up @@ -142,16 +142,15 @@ class CudaExecutor : public GpuExecutor {
absl::Status GetKernelMetadata(GpuKernel* cuda_kernel,
KernelMetadata* kernel_metadata);

// (supported on CUDA only)
absl::Status LoadModuleFromCuBin(const char* cubin, CUmodule* module)
// Loads a module in cubin format.
absl::StatusOr<ModuleHandle> LoadModuleFromCuBin(const char* cubin)
ABSL_EXCLUSIVE_LOCKS_REQUIRED(in_memory_modules_mu_);

// Loads the PTX text `ptx` as a CUDA module. `ptx` must be null terminated.
// (supported on CUDA only)
absl::Status LoadModuleFromPtx(const char* ptx, CUmodule* module)
// Loads the PTX text `ptx` as a CUDA module. `ptx` must be null terminated.
absl::StatusOr<ModuleHandle> LoadModuleFromPtx(const char* ptx)
ABSL_EXCLUSIVE_LOCKS_REQUIRED(in_memory_modules_mu_);

bool UnloadGpuBinary(const void* gpu_binary)
bool UnloadGpuBinary(ModuleHandle gpu_binary)
ABSL_EXCLUSIVE_LOCKS_REQUIRED(in_memory_modules_mu_);

// Returns true if a delay kernel is supported.
Expand All @@ -167,12 +166,12 @@ class CudaExecutor : public GpuExecutor {
std::map<const absl::uint128, std::weak_ptr<DeviceMemoryBase>>
shared_constants_ ABSL_GUARDED_BY(shared_constants_mu_);

// Kernel -> loaded GPU binary. Many kernels may load the same binary.
absl::flat_hash_map<const Kernel*, const void*> kernel_to_gpu_binary_
// Kernel -> loaded GPU module. Many kernels may load the same binary.
absl::flat_hash_map<const Kernel*, ModuleHandle> kernel_to_gpu_binary_
ABSL_GUARDED_BY(in_memory_modules_mu_);

// GPU binary (PTX or CUBIN) -> {CUDA module, reference count}.
absl::flat_hash_map<const void*, std::pair<CUmodule, uint64_t>>
// Loaded GPU module handle -> {CUDA module, reference count}.
absl::flat_hash_map<ModuleHandle, std::pair<CUmodule, uint64_t>>
gpu_binary_to_module_ ABSL_GUARDED_BY(in_memory_modules_mu_);

// Handle for the CUDA device being operated on. Immutable
Expand Down
5 changes: 2 additions & 3 deletions xla/stream_executor/gpu/mock_gpu_executor.h
Original file line number Diff line number Diff line change
Expand Up @@ -55,9 +55,8 @@ class MockGpuExecutor : public GpuExecutor {
MOCK_METHOD(absl::StatusOr<std::unique_ptr<Kernel>>, LoadKernel,
(const MultiKernelLoaderSpec& spec), (override));
MOCK_METHOD(bool, UnloadModule, (ModuleHandle module_handle), (override));
MOCK_METHOD(absl::Status, LoadModule,
(const MultiModuleLoaderSpec& spec, ModuleHandle* module_handle),
(override));
MOCK_METHOD(absl::StatusOr<ModuleHandle>, LoadModule,
(const MultiModuleLoaderSpec& spec), (override));
MOCK_METHOD(absl::StatusOr<std::shared_ptr<DeviceMemoryBase>>,
CreateOrShareConstant,
(Stream * stream, absl::Span<const uint8_t> content), (override));
Expand Down
5 changes: 2 additions & 3 deletions xla/stream_executor/mock_stream_executor.h
Original file line number Diff line number Diff line change
Expand Up @@ -55,9 +55,8 @@ class MockStreamExecutor : public StreamExecutor {
(const MultiKernelLoaderSpec& spec), (override));
MOCK_METHOD(std::unique_ptr<ActivateContext>, Activate, (), (override));
MOCK_METHOD(bool, UnloadModule, (ModuleHandle module_handle), (override));
MOCK_METHOD(absl::Status, LoadModule,
(const MultiModuleLoaderSpec& spec, ModuleHandle* module_handle),
(override));
MOCK_METHOD(absl::StatusOr<ModuleHandle>, LoadModule,
(const MultiModuleLoaderSpec& spec), (override));
MOCK_METHOD(absl::StatusOr<std::shared_ptr<DeviceMemoryBase>>,
CreateOrShareConstant,
(Stream * stream, absl::Span<const uint8_t> content), (override));
Expand Down
Loading

0 comments on commit 918e7cf

Please sign in to comment.