From dceda273fb110e43556a97a8a058eaaea76b5674 Mon Sep 17 00:00:00 2001
From: Kyle Lucke
Date: Wed, 16 Oct 2024 12:26:03 -0700
Subject: [PATCH] Cache the actual concrete Context pointer in each GpuExecutor
 class, and stop calling gpu_context() where possible.

PiperOrigin-RevId: 686600677
---
 xla/stream_executor/cuda/cuda_executor.cc | 67 ++++++++++----------
 xla/stream_executor/cuda/cuda_executor.h  |  4 ++
 xla/stream_executor/gpu/gpu_executor.h    |  5 --
 xla/stream_executor/rocm/rocm_executor.cc | 77 +++++++++++------------
 xla/stream_executor/rocm/rocm_executor.h  |  1 +
 5 files changed, 77 insertions(+), 77 deletions(-)

diff --git a/xla/stream_executor/cuda/cuda_executor.cc b/xla/stream_executor/cuda/cuda_executor.cc
index c3c89ce58efcd..fc11568919db4 100644
--- a/xla/stream_executor/cuda/cuda_executor.cc
+++ b/xla/stream_executor/cuda/cuda_executor.cc
@@ -555,17 +555,17 @@ absl::StatusOr<DeviceMemoryBase> CudaExecutor::GetMemoryRange(
 }
 
 std::unique_ptr<ActivateContext> CudaExecutor::Activate() {
-  return std::make_unique<ScopedActivateContext>(gpu_context());
+  return std::make_unique<ScopedActivateContext>(cuda_context_);
 }
 
 CudaExecutor::~CudaExecutor() {
-  CHECK(kernel_to_gpu_binary_.empty()) << "GpuExecutor has live kernels.";
-  CHECK(gpu_binary_to_module_.empty()) << "GpuExecutor has loaded modules.";
+  CHECK(kernel_to_gpu_binary_.empty()) << "CudaExecutor has live kernels.";
+  CHECK(gpu_binary_to_module_.empty()) << "CudaExecutor has loaded modules.";
   set_context(nullptr);
 }
 
 void CudaExecutor::UnifiedMemoryDeallocate(void* location) {
-  ScopedActivateContext activation(gpu_context());
+  std::unique_ptr<ActivateContext> activation = Activate();
   CUdeviceptr pointer = absl::bit_cast<CUdeviceptr>(location);
   auto status = cuda::ToStatus(cuMemFree(pointer));
   if (!status.ok()) {
@@ -573,12 +573,12 @@ void CudaExecutor::UnifiedMemoryDeallocate(void* location) {
                << "; result: " << status;
   } else {
     VLOG(2) << "deallocated unified memory at " << location << " for context "
-            << gpu_context();
+            << cuda_context_;
   }
 }
 
 void* CudaExecutor::UnifiedMemoryAllocate(uint64_t size) {
-  ScopedActivateContext activation(gpu_context());
+  std::unique_ptr<ActivateContext> activation = Activate();
   CUdeviceptr result = 0;
   // "Portable" memory is visible to all CUDA contexts. Safe for our use model.
   auto status =
@@ -589,16 +589,17 @@ void* CudaExecutor::UnifiedMemoryAllocate(uint64_t size) {
     return nullptr;
   }
   void* ptr = reinterpret_cast<void*>(result);
-  VLOG(2) << "allocated " << ptr << " for context " << gpu_context() << " of "
+  VLOG(2) << "allocated " << ptr << " for context " << cuda_context_ << " of "
           << size << " bytes in unified memory";
   return ptr;
 }
 
 absl::Status CudaExecutor::Init() {
   TF_ASSIGN_OR_RETURN(device_, GetDevice(device_ordinal()));
-  TF_ASSIGN_OR_RETURN(Context * context,
+  TF_ASSIGN_OR_RETURN(CudaContext * context,
                       CudaContext::Create(device_ordinal(), device_));
   set_context(context);
+  cuda_context_ = context;
   TF_RETURN_IF_ERROR(GetComputeCapability(&cc_major_, &cc_minor_, device_));
   TF_ASSIGN_OR_RETURN(delay_kernels_supported_, DelayKernelIsSupported());
   return absl::OkStatus();
@@ -622,7 +623,7 @@ absl::StatusOr<CUmodule> CudaExecutor::LoadModuleFromCuBin(
   std::tie(module, module_refcount) = gpu_binary_to_module_[module_handle];
 
   if (module == nullptr) {
-    TF_ASSIGN_OR_RETURN(module, LoadCubin(gpu_context(), cubin));
+    TF_ASSIGN_OR_RETURN(module, LoadCubin(cuda_context_, cubin));
     module_refcount = 1;
     VLOG(3) << "Loaded CUBIN " << static_cast<const void*>(cubin)
             << " as module " << module;
@@ -642,7 +643,7 @@ absl::StatusOr<CUmodule> CudaExecutor::LoadModuleFromPtx(const char* ptx) {
   std::tie(module, module_refcount) = gpu_binary_to_module_[module_handle];
 
   if (module == nullptr) {
-    TF_ASSIGN_OR_RETURN(module, LoadPtx(gpu_context(), ptx));
+    TF_ASSIGN_OR_RETURN(module, LoadPtx(cuda_context_, ptx));
     VLOG(3) << "Loaded PTX " << static_cast<const void*>(ptx) << " as module "
             << module;
     module_refcount = 1;
@@ -672,7 +673,7 @@ absl::StatusOr<std::unique_ptr<Kernel>> CudaExecutor::LoadKernel(
     VLOG(2) << "getting function " << *kernel_name << " from module " << module;
     TF_ASSIGN_OR_RETURN(
         CUfunction function,
-        GetModuleFunction(gpu_context(), module, kernel_name->c_str()));
+        GetModuleFunction(cuda_context_, module, kernel_name->c_str()));
     cuda_kernel->set_gpu_function(function);
 
   } else if (spec.has_cuda_ptx_in_memory()) {
@@ -698,7 +699,7 @@ absl::StatusOr<std::unique_ptr<Kernel>> CudaExecutor::LoadKernel(
     VLOG(2) << "getting function " << *kernel_name << " from module " << module;
     TF_ASSIGN_OR_RETURN(
         CUfunction function,
-        GetModuleFunction(gpu_context(), module, kernel_name->c_str()));
+        GetModuleFunction(cuda_context_, module, kernel_name->c_str()));
     cuda_kernel->set_gpu_function(function);
 
   } else if (spec.has_in_process_symbol()) {
@@ -756,7 +757,7 @@ bool CudaExecutor::UnloadGpuBinary(ModuleHandle gpu_binary) {
   VLOG(3) << "Found CUDA module " << module << " with refcount " << refcount;
   if (--refcount == 0) {
     VLOG(3) << "Unloading CUDA module " << module;
-    UnloadCudaModule(gpu_context(), module);
+    UnloadCudaModule(cuda_context_, module);
     gpu_binary_to_module_.erase(module_it);
   }
   return true;
@@ -782,7 +783,7 @@ void CudaExecutor::UnloadKernel(const Kernel* kernel) {
 
 absl::StatusOr<ModuleHandle> CudaExecutor::LoadModule(
     const MultiModuleLoaderSpec& spec) {
-  // In GpuExecutor we store the pointer to the GPU binary (PTX or CUBIN) as
+  // We store the pointer to the GPU binary (PTX or CUBIN) as
   // ModuleHandle::id().
   if (spec.has_cuda_cubin_in_memory()) {
     absl::MutexLock lock{&in_memory_modules_mu_};
@@ -905,15 +906,15 @@ DeviceMemoryBase CudaExecutor::Allocate(uint64_t size, int64_t memory_space) {
     return DeviceMemoryBase(nullptr, 0);
   } else if (memory_space ==
              static_cast<int64_t>(stream_executor::MemoryType::kHost)) {
-    return DeviceMemoryBase(HostAllocate(gpu_context(), size), size);
+    return DeviceMemoryBase(HostAllocate(cuda_context_, size), size);
   }
   CHECK_EQ(memory_space, 0);
-  return DeviceMemoryBase(DeviceAllocate(gpu_context(), size), size);
+  return DeviceMemoryBase(DeviceAllocate(cuda_context_, size), size);
 }
 
 absl::StatusOr<std::unique_ptr<MemoryAllocation>>
 CudaExecutor::HostMemoryAllocate(uint64_t size) {
-  auto* buffer = HostAllocate(gpu_context(), size);
+  auto* buffer = HostAllocate(cuda_context_, size);
   if (buffer == nullptr && size > 0) {
     return absl::InternalError(
         absl::StrFormat("Failed to allocate HostMemory of size %d", size));
@@ -929,25 +930,25 @@ void CudaExecutor::Deallocate(DeviceMemoryBase* mem) {
   }
   auto memory_space = status_or_memory_space.value();
   if (memory_space == MemoryType::kHost) {
-    HostDeallocate(gpu_context(), mem->opaque());
+    HostDeallocate(cuda_context_, mem->opaque());
   } else {
-    DeviceDeallocate(gpu_context(), mem->opaque());
+    DeviceDeallocate(cuda_context_, mem->opaque());
   }
 }
 
 void CudaExecutor::HostMemoryDeallocate(void* location) {
-  return HostDeallocate(gpu_context(), location);
+  return HostDeallocate(cuda_context_, location);
 }
 
 bool CudaExecutor::SynchronizeAllActivity() {
-  return gpu_context()->Synchronize().ok();
+  return cuda_context_->Synchronize().ok();
 }
 
 bool CudaExecutor::HostMemoryRegister(void* location, uint64_t size) {
   VLOG(1) << "Called StreamExecutor::HostMemoryRegister(data=" << location
           << ")";
-  ScopedActivateContext activation(gpu_context());
+  std::unique_ptr<ActivateContext> activation = Activate();
   // "Portable" memory is visible to all CUDA contexts. Safe for our use model.
   auto status = cuda::ToStatus(
       cuMemHostRegister(location, size, CU_MEMHOSTREGISTER_PORTABLE));
@@ -962,7 +963,7 @@ bool CudaExecutor::HostMemoryRegister(void* location, uint64_t size) {
 
 bool CudaExecutor::HostMemoryUnregister(void* location) {
   VLOG(1) << "Called StreamExecutor::HostUnregister(data=" << location << ")";
-  ScopedActivateContext activation(gpu_context());
+  std::unique_ptr<ActivateContext> activation = Activate();
   auto status = cuda::ToStatus(cuMemHostUnregister(location));
   if (!status.ok()) {
     LOG(ERROR) << "error unregistering host memory at " << location << ": "
@@ -974,7 +975,7 @@ bool CudaExecutor::HostMemoryUnregister(void* location) {
 
 absl::Status CudaExecutor::SynchronousMemZero(DeviceMemoryBase* location,
                                               uint64_t size) {
-  ScopedActivateContext activation(gpu_context());
+  std::unique_ptr<ActivateContext> activation = Activate();
   CUdeviceptr cuda_location = AsCudaDevicePtr(location);
   if (reinterpret_cast<uintptr_t>(location->opaque()) % sizeof(uint32_t) == 0 &&
       size % sizeof(uint32_t) == 0) {
@@ -989,7 +990,7 @@ absl::Status CudaExecutor::SynchronousMemZero(DeviceMemoryBase* location,
 
 absl::Status CudaExecutor::SynchronousMemcpy(DeviceMemoryBase* gpu_dst,
                                              const void* host_src, uint64_t size) {
-  ScopedActivateContext activation(gpu_context());
+  std::unique_ptr<ActivateContext> activation = Activate();
   TF_RETURN_IF_ERROR(cuda::ToStatus(
       cuMemcpyHtoD(AsCudaDevicePtr(gpu_dst), host_src, size),
       absl::StrFormat(
@@ -1003,7 +1004,7 @@ absl::Status CudaExecutor::SynchronousMemcpy(DeviceMemoryBase* gpu_dst,
 
 absl::Status CudaExecutor::SynchronousMemcpy(void* host_dst,
                                              const DeviceMemoryBase& gpu_src, uint64_t size) {
-  ScopedActivateContext activation(gpu_context());
+  std::unique_ptr<ActivateContext> activation = Activate();
   TF_RETURN_IF_ERROR(cuda::ToStatus(
       cuMemcpyDtoH(host_dst, AsCudaDevicePtr(gpu_src), size),
       absl::StrFormat("failed to synchronous memcpy from device to host "
@@ -1026,7 +1027,7 @@ void CudaExecutor::DeallocateStream(Stream* stream) {
 }
 
 absl::Status CudaExecutor::BlockHostUntilDone(Stream* stream) {
-  return GpuDriver::SynchronizeStream(gpu_context(), AsGpuStreamValue(stream));
+  return GpuDriver::SynchronizeStream(cuda_context_, AsGpuStreamValue(stream));
 }
 
 blas::BlasSupport* CudaExecutor::AsBlas() {
@@ -1091,18 +1092,18 @@ fft::FftSupport* CudaExecutor::AsFft() {
 }
 
 bool CudaExecutor::CanEnablePeerAccessTo(StreamExecutor* other) {
-  GpuExecutor* cuda_other = static_cast<GpuExecutor*>(other);
-  return CanEnablePeerAccess(gpu_context(), cuda_other->gpu_context());
+  CudaExecutor* cuda_other = static_cast<CudaExecutor*>(other);
+  return CanEnablePeerAccess(cuda_context_, cuda_other->cuda_context_);
 }
 
 absl::Status CudaExecutor::EnablePeerAccessTo(StreamExecutor* other) {
-  GpuExecutor* cuda_other = static_cast<GpuExecutor*>(other);
-  return EnablePeerAccess(gpu_context(), cuda_other->gpu_context());
+  CudaExecutor* cuda_other = static_cast<CudaExecutor*>(other);
+  return EnablePeerAccess(cuda_context_, cuda_other->cuda_context_);
 }
 
 bool CudaExecutor::DeviceMemoryUsage(int64_t* free_out,
                                      int64_t* total_out) const {
-  ScopedActivateContext activation(gpu_context());
+  ScopedActivateContext activation(cuda_context_);
   size_t free = 0;
   size_t total = 0;
   auto status = cuda::ToStatus(cuMemGetInfo(&free, &total));
@@ -1130,7 +1131,7 @@ absl::StatusOr<DeviceMemoryBase> CudaExecutor::GetSymbol(
     CUmodule gpu_module_handle = it->second.first;
     CHECK(gpu_module_handle != nullptr);
     TF_RETURN_IF_ERROR(
-        GetModuleSymbol(gpu_context(), gpu_module_handle, symbol_name.c_str(),
+        GetModuleSymbol(cuda_context_, gpu_module_handle, symbol_name.c_str(),
                         reinterpret_cast<CUdeviceptr*>(&mem), &bytes));
     return DeviceMemoryBase(mem, bytes);
   }
diff --git a/xla/stream_executor/cuda/cuda_executor.h b/xla/stream_executor/cuda/cuda_executor.h
index a112da7a31deb..7adf051f5c572 100644
--- a/xla/stream_executor/cuda/cuda_executor.h
+++ b/xla/stream_executor/cuda/cuda_executor.h
@@ -1,4 +1,5 @@
 #include "xla/stream_executor/activate_context.h"
+#include "xla/stream_executor/cuda/cuda_context.h"
 /* Copyright 2024 The OpenXLA Authors.
 
 Licensed under the Apache License, Version 2.0 (the "License");
@@ -207,6 +208,9 @@ class CudaExecutor : public GpuExecutor {
   // Lookup map for alive streams, from raw stream pointers.
   absl::flat_hash_map<void*, Stream*> alive_gpu_streams_
       ABSL_GUARDED_BY(alive_gpu_streams_mu_);
+
+  // CudaContext for this device.
+  CudaContext* cuda_context_;
 };
 
 }  // namespace stream_executor::gpu
diff --git a/xla/stream_executor/gpu/gpu_executor.h b/xla/stream_executor/gpu/gpu_executor.h
index 177f9234cd807..7b9a59416a941 100644
--- a/xla/stream_executor/gpu/gpu_executor.h
+++ b/xla/stream_executor/gpu/gpu_executor.h
@@ -17,7 +17,6 @@ limitations under the License.
 #define XLA_STREAM_EXECUTOR_GPU_GPU_EXECUTOR_H_
 
 #include
-#include
 #include
 #include
 #include
@@ -26,12 +25,8 @@ limitations under the License.
 #include "absl/status/status.h"
 #include "absl/status/statusor.h"
 #include "absl/synchronization/mutex.h"
-#include "xla/stream_executor/device_description.h"
-#include "xla/stream_executor/event_based_timer.h"
 #include "xla/stream_executor/gpu/context.h"
 #include "xla/stream_executor/host_memory_allocation.h"
-#include "xla/stream_executor/kernel.h"
-#include "xla/stream_executor/kernel_spec.h"
 #include "xla/stream_executor/platform.h"
 #include "xla/stream_executor/stream_executor.h"
 #include "xla/stream_executor/stream_executor_common.h"
diff --git a/xla/stream_executor/rocm/rocm_executor.cc b/xla/stream_executor/rocm/rocm_executor.cc
index c287085cd949a..608db59b50f2f 100644
--- a/xla/stream_executor/rocm/rocm_executor.cc
+++ b/xla/stream_executor/rocm/rocm_executor.cc
@@ -153,7 +153,7 @@ absl::StatusOr<hipModule_t> LoadHsaco(Context* context,
   hipModule_t module;
   GetDriverExecutor()->Schedule(
       [context, hsaco_contents, &module, &returned_status, &notification]() {
-        ScopedActivateContext activation{context};
+        ScopedActivateContext activation(context);
         hipError_t res = wrap::hipModuleLoadData(&module, hsaco_contents);
 
         if (res != hipSuccess) {
@@ -177,7 +177,7 @@ absl::StatusOr<hipFunction_t> GetModuleFunction(Context* context,
                                                 hipModule_t module,
                                                 const char* kernel_name) {
-  ScopedActivateContext activated{context};
+  ScopedActivateContext activated(context);
   CHECK(module != nullptr && kernel_name != nullptr);
   hipFunction_t function;
   TF_RETURN_IF_ERROR(
@@ -193,7 +193,7 @@ absl::StatusOr<hipFunction_t> GetModuleFunction(Context* context,
 absl::Status GetModuleSymbol(Context* context, hipModule_t module,
                              const char* symbol_name, hipDeviceptr_t* dptr,
                              size_t* bytes) {
-  ScopedActivateContext activated{context};
+  ScopedActivateContext activated(context);
   CHECK(module != nullptr && symbol_name != nullptr &&
         (dptr != nullptr || bytes != nullptr));
   return ToStatus(wrap::hipModuleGetGlobal(dptr, bytes, module, symbol_name),
@@ -202,7 +202,7 @@ absl::Status GetModuleSymbol(Context* context, hipModule_t module,
 // Unloads module from the current context via cuModuleUnload.
 void UnloadRocmModule(Context* context, hipModule_t module) {
-  ScopedActivateContext activated{context};
+  ScopedActivateContext activated(context);
   hipError_t res = wrap::hipModuleUnload(module);
   if (res != hipSuccess) {
     LOG(ERROR) << "failed to unload module " << module
@@ -333,7 +333,7 @@ absl::StatusOr<hipDevice_t> GetDevice(int device_ordinal) {
 
 // Returns the device associated with the given context.
 absl::StatusOr<hipDevice_t> DeviceFromContext(Context* context) {
-  ScopedActivateContext activated{context};
+  ScopedActivateContext activated(context);
   hipDevice_t device = -1;
   hipError_t result = wrap::hipCtxGetDevice(&device);
   if (result == hipSuccess) return device;
@@ -378,7 +378,7 @@ absl::Status EnablePeerAccess(Context* from, Context* to) {
     return absl::OkStatus();  // A device can always access its own memory.
   }
 
-  ScopedActivateContext activated{from};
+  ScopedActivateContext activated(from);
   hipError_t result = wrap::hipCtxEnablePeerAccess(
       tensorflow::down_cast<RocmContext*>(to)->context(), 0 /* = flags */);
   if (result != hipSuccess && result != hipErrorPeerAccessAlreadyEnabled) {
@@ -424,7 +424,7 @@ void* DeviceAllocate(Context* context, uint64_t bytes) {
     return nullptr;
   }
 
-  ScopedActivateContext activated{context};
+  ScopedActivateContext activated(context);
   hipDeviceptr_t result = 0;
   hipError_t res = wrap::hipMalloc(&result, bytes);
   if (res != hipSuccess) {
@@ -444,7 +444,7 @@ void* DeviceAllocate(Context* context, uint64_t bytes) {
 // Deallocates memory on the GPU device that was previously allocated via
 // DeviceAllocate.
 void DeviceDeallocate(Context* context, void* location) {
-  ScopedActivateContext activation{context};
+  ScopedActivateContext activation(context);
   hipDeviceptr_t pointer = absl::bit_cast<hipDeviceptr_t>(location);
   hipError_t res = wrap::hipFree(pointer);
   if (res != hipSuccess) {
@@ -458,7 +458,7 @@ void DeviceDeallocate(Context* context, void* location) {
 
 // Allocates memory on the host.
 void* HostAllocate(Context* context, uint64_t bytes) {
-  ScopedActivateContext activation{context};
+  ScopedActivateContext activation(context);
   void* host_mem = nullptr;
   // "Portable" memory is visible to all ROCM contexts. Safe for our use model.
   hipError_t res =
       wrap::hipHostMalloc(&host_mem, bytes, hipHostMallocPortable);
@@ -473,15 +473,15 @@ void* HostAllocate(Context* context, uint64_t bytes) {
 
 RocmExecutor::~RocmExecutor() {
   for (auto& it : in_memory_modules_) {
-    UnloadRocmModule(gpu_context(), it.second);
+    UnloadRocmModule(rocm_context_, it.second);
   }
   set_context(nullptr);
-  CHECK(kernel_to_gpu_binary_.empty()) << "GpuExecutor has live kernels.";
-  CHECK(gpu_binary_to_module_.empty()) << "GpuExecutor has loaded modules.";
+  CHECK(kernel_to_gpu_binary_.empty()) << "RocmExecutor has live kernels.";
+  CHECK(gpu_binary_to_module_.empty()) << "RocmExecutor has loaded modules.";
 }
 
 std::unique_ptr<ActivateContext> RocmExecutor::Activate() {
-  return std::make_unique<ScopedActivateContext>(gpu_context());
+  return std::make_unique<ScopedActivateContext>(rocm_context_);
 }
 
 bool RocmExecutor::UnloadModule(ModuleHandle module_handle) {
@@ -582,7 +582,7 @@ bool RocmExecutor::UnloadGpuBinary(ModuleHandle module_handle) {
   VLOG(3) << "Found HSACO module " << module << " with refcount " << refcount;
   if (--refcount == 0) {
     VLOG(3) << "Unloading HSACO module " << module;
-    UnloadRocmModule(gpu_context(), module);
+    UnloadRocmModule(rocm_context_, module);
     gpu_binary_to_module_.erase(module_it);
     ModuleHandle mem_it{};
     for (auto x : in_memory_modules_) {
@@ -634,14 +634,14 @@ absl::StatusOr<std::unique_ptr<Kernel>> RocmExecutor::LoadKernel(
     hipModule_t& module = in_memory_modules_[module_handle];
 
     if (module == nullptr) {
-      TF_ASSIGN_OR_RETURN(module, LoadHsaco(gpu_context(), hsaco));
+      TF_ASSIGN_OR_RETURN(module, LoadHsaco(rocm_context_, hsaco));
     }
     kernel_to_gpu_binary_[rocm_kernel.get()] = module_handle;
 
     VLOG(2) << "getting function " << *kernel_name << " from module " << module;
     TF_ASSIGN_OR_RETURN(
         hipFunction_t function,
-        GetModuleFunction(gpu_context(), module, kernel_name->c_str()));
+        GetModuleFunction(rocm_context_, module, kernel_name->c_str()));
     rocm_kernel->set_gpu_function(function);
   } else if (spec.has_in_process_symbol()) {
     kernel_name = &spec.in_process_symbol().kernel_name();
@@ -694,8 +694,7 @@ absl::Status RocmExecutor::GetKernelMetadata(GpuKernel* rocm_kernel,
 
 absl::StatusOr<ModuleHandle> RocmExecutor::LoadModule(
     const MultiModuleLoaderSpec& spec) {
-  // In GpuExecutor we store the pointer to the HSACO binary as
-  // ModuleHandle::id().
+  // We store the pointer to the HSACO binary as ModuleHandle::id().
 
   // TODO(ROCm): Need generic term instead of cubin/cuda/ptx
   if (spec.has_cuda_cubin_in_memory()) {
@@ -715,7 +714,7 @@ absl::StatusOr<hipModule_t> RocmExecutor::LoadModuleFromHsaco(
   std::tie(module, module_refcount) = gpu_binary_to_module_[module_handle];
 
   if (module == nullptr) {
-    TF_ASSIGN_OR_RETURN(module, LoadHsaco(gpu_context(), hsaco));
+    TF_ASSIGN_OR_RETURN(module, LoadHsaco(rocm_context_, hsaco));
     module_refcount = 1;
     in_memory_modules_[module_handle] = module;
     VLOG(3) << "Loaded HSACO " << static_cast<const void*>(hsaco)
@@ -732,14 +731,14 @@ absl::StatusOr<hipModule_t> RocmExecutor::LoadModuleFromHsaco(
 
 DeviceMemoryBase RocmExecutor::Allocate(uint64_t size, int64_t memory_space) {
   if (memory_space ==
       static_cast<int64_t>(stream_executor::MemoryType::kHost)) {
-    return DeviceMemoryBase(HostAllocate(gpu_context(), size), size);
+    return DeviceMemoryBase(HostAllocate(rocm_context_, size), size);
   }
   CHECK_EQ(memory_space, 0);
-  return DeviceMemoryBase(DeviceAllocate(gpu_context(), size), size);
+  return DeviceMemoryBase(DeviceAllocate(rocm_context_, size), size);
 }
 
 absl::StatusOr<std::unique_ptr<MemoryAllocation>>
 RocmExecutor::HostMemoryAllocate(uint64_t size) {
-  auto* buffer = HostAllocate(gpu_context(), size);
+  auto* buffer = HostAllocate(rocm_context_, size);
   if (buffer == nullptr && size > 0) {
     return absl::InternalError(
         absl::StrFormat("Failed to allocate HostMemory of size %d", size));
@@ -748,7 +747,7 @@ RocmExecutor::HostMemoryAllocate(uint64_t size) {
 }
 
 void RocmExecutor::HostMemoryDeallocate(void* location) {
-  ScopedActivateContext activation{gpu_context()};
+  std::unique_ptr<ActivateContext> activation = Activate();
   hipError_t res = wrap::hipHostFree(location);
   if (res != hipSuccess) {
     LOG(ERROR) << "error deallocating host memory at " << location << ": "
@@ -757,11 +756,11 @@ void RocmExecutor::HostMemoryDeallocate(void* location) {
 }
 
 void RocmExecutor::Deallocate(DeviceMemoryBase* mem) {
-  DeviceDeallocate(gpu_context(), mem->opaque());
+  DeviceDeallocate(rocm_context_, mem->opaque());
 }
 
 void* RocmExecutor::UnifiedMemoryAllocate(uint64_t size) {
-  ScopedActivateContext activated{gpu_context()};
+  std::unique_ptr<ActivateContext> activation = Activate();
   hipDeviceptr_t result = 0;
   // "managed" memory is visible to both CPU and GPU.
   hipError_t res = wrap::hipMallocManaged(&result, size, hipMemAttachGlobal);
@@ -771,13 +770,13 @@ void* RocmExecutor::UnifiedMemoryAllocate(uint64_t size) {
     return nullptr;
   }
   void* ptr = reinterpret_cast<void*>(result);
-  VLOG(2) << "allocated " << ptr << " for context " << gpu_context() << " of "
+  VLOG(2) << "allocated " << ptr << " for context " << rocm_context_ << " of "
           << size << " bytes in unified memory";
   return ptr;
 }
 
 void RocmExecutor::UnifiedMemoryDeallocate(void* location) {
-  ScopedActivateContext activation(gpu_context());
+  std::unique_ptr<ActivateContext> activation = Activate();
   hipDeviceptr_t pointer = absl::bit_cast<hipDeviceptr_t>(location);
   hipError_t res = wrap::hipFree(pointer);
   if (res != hipSuccess) {
@@ -785,17 +784,17 @@ void RocmExecutor::UnifiedMemoryDeallocate(void* location) {
                << "; result: " << ToString(res);
   } else {
     VLOG(2) << "deallocated unified memory at " << location << " for context "
-            << gpu_context();
+            << rocm_context_;
   }
 }
 
 bool RocmExecutor::SynchronizeAllActivity() {
-  return gpu_context()->Synchronize().ok();
+  return rocm_context_->Synchronize().ok();
 }
 
 absl::Status RocmExecutor::SynchronousMemZero(DeviceMemoryBase* location,
                                               uint64_t size) {
-  ScopedActivateContext activation{gpu_context()};
+  std::unique_ptr<ActivateContext> activation = Activate();
   hipDeviceptr_t rocm_location = AsROCmDevicePtr(location);
   if (reinterpret_cast<uintptr_t>(location->opaque()) % sizeof(uint32_t) == 0 &&
       size % sizeof(uint32_t) == 0) {
@@ -810,7 +809,7 @@ absl::Status RocmExecutor::SynchronousMemZero(DeviceMemoryBase* location,
 
 absl::Status RocmExecutor::SynchronousMemcpy(DeviceMemoryBase* gpu_dst,
                                              const void* host_src, uint64_t size) {
-  ScopedActivateContext activation(gpu_context());
+  std::unique_ptr<ActivateContext> activation = Activate();
   TF_RETURN_IF_ERROR(ToStatus(
       wrap::hipMemcpyHtoD(AsROCmDevicePtr(gpu_dst),
                           const_cast<void*>(host_src), size),
@@ -825,7 +824,7 @@ absl::Status RocmExecutor::SynchronousMemcpy(DeviceMemoryBase* gpu_dst,
 
 absl::Status RocmExecutor::SynchronousMemcpy(void* host_dst,
                                              const DeviceMemoryBase& gpu_src, uint64_t size) {
-  ScopedActivateContext activation{gpu_context()};
+  std::unique_ptr<ActivateContext> activation = Activate();
   TF_RETURN_IF_ERROR(ToStatus(
       wrap::hipMemcpyDtoH(host_dst, AsROCmDevicePtr(gpu_src), size),
       absl::StrFormat("failed to synchronous memcpy from device to host: "
@@ -849,7 +848,7 @@ void RocmExecutor::DeallocateStream(Stream* stream) {
 }
 
 absl::Status RocmExecutor::BlockHostUntilDone(Stream* stream) {
-  return GpuDriver::SynchronizeStream(gpu_context(), AsGpuStreamValue(stream));
+  return GpuDriver::SynchronizeStream(rocm_context_, AsGpuStreamValue(stream));
 }
 
 blas::BlasSupport* RocmExecutor::AsBlas() {
@@ -914,13 +913,13 @@ fft::FftSupport* RocmExecutor::AsFft() {
 }
 
 bool RocmExecutor::CanEnablePeerAccessTo(StreamExecutor* other) {
-  GpuExecutor* rocm_other = static_cast<GpuExecutor*>(other);
-  return CanEnablePeerAccess(gpu_context(), rocm_other->gpu_context());
+  RocmExecutor* rocm_other = static_cast<RocmExecutor*>(other);
+  return CanEnablePeerAccess(rocm_context_, rocm_other->rocm_context_);
 }
 
 absl::Status RocmExecutor::EnablePeerAccessTo(StreamExecutor* other) {
-  GpuExecutor* rocm_other = static_cast<GpuExecutor*>(other);
-  return EnablePeerAccess(gpu_context(), rocm_other->gpu_context());
+  RocmExecutor* rocm_other = static_cast<RocmExecutor*>(other);
+  return EnablePeerAccess(rocm_context_, rocm_other->rocm_context_);
 }
 
 bool RocmExecutor::DeviceMemoryUsage(int64_t* free, int64_t* total) const {
@@ -937,14 +936,14 @@ absl::StatusOr<DeviceMemoryBase> RocmExecutor::GetSymbol(
     auto it = gpu_binary_to_module_.find(module_handle);
     CHECK(it != gpu_binary_to_module_.end());
     TF_RETURN_IF_ERROR(
-        GetModuleSymbol(gpu_context(), it->second.first, symbol_name.c_str(),
+        GetModuleSymbol(rocm_context_, it->second.first, symbol_name.c_str(),
                         reinterpret_cast<hipDeviceptr_t*>(&mem), &bytes));
     return DeviceMemoryBase(mem, bytes);
   }
 
   for (auto& it : gpu_binary_to_module_) {
     TF_RETURN_IF_ERROR(
-        GetModuleSymbol(gpu_context(), it.second.first, symbol_name.c_str(),
+        GetModuleSymbol(rocm_context_, it.second.first, symbol_name.c_str(),
                         reinterpret_cast<hipDeviceptr_t*>(&mem), &bytes));
     return DeviceMemoryBase(mem, bytes);
   }
diff --git a/xla/stream_executor/rocm/rocm_executor.h b/xla/stream_executor/rocm/rocm_executor.h
index 9fe4af1b54a1d..8adb755fa1e14 100644
--- a/xla/stream_executor/rocm/rocm_executor.h
+++ b/xla/stream_executor/rocm/rocm_executor.h
@@ -32,6 +32,7 @@ limitations under the License.
 #include "absl/synchronization/mutex.h"
 #include "absl/types/span.h"
 #include "rocm/include/hip/hip_runtime.h"
+#include "xla/stream_executor/activate_context.h"
 #include "xla/stream_executor/blas.h"
 #include "xla/stream_executor/command_buffer.h"
 #include "xla/stream_executor/device_description.h"