Skip to content

Commit 1e22953

Browse files
reedwm authored and Google-ML-Automation committed
Remove Thunk::Cleanup method.
The method was effectively unused, since it wasn't overridden by SequentialThunk, and so SequentialThunk wouldn't call Cleanup on its subthunks. NcclRaggedAllToAllStartThunk overrode Cleanup to free some device buffers, but these were never freed since Cleanup was not called. The memory is now stored in DeviceMemoryHandles, which automatically free the buffers in the destructor. PiperOrigin-RevId: 725409574
1 parent 4825d5d commit 1e22953

File tree

5 files changed

+12
-58
lines changed

5 files changed

+12
-58
lines changed

xla/backends/gpu/runtime/BUILD

+1
Original file line numberDiff line numberDiff line change
@@ -787,6 +787,7 @@ cc_library(
787787
"//xla/service:collective_ops_utils",
788788
"//xla/service/gpu/transforms/collectives:collective_ops_utils",
789789
"//xla/stream_executor:device_memory",
790+
"//xla/stream_executor:device_memory_handle",
790791
"//xla/stream_executor:memory_allocation",
791792
"//xla/stream_executor:stream",
792793
"//xla/tsl/platform:errors",

xla/backends/gpu/runtime/nccl_ragged_all_to_all_thunk.cc

+7-18
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,7 @@ limitations under the License.
4141
#include "xla/shape.h"
4242
#include "xla/shape_util.h"
4343
#include "xla/stream_executor/device_memory.h"
44+
#include "xla/stream_executor/device_memory_handle.h"
4445
#include "xla/stream_executor/memory_allocation.h"
4546
#include "xla/stream_executor/stream.h"
4647
#include "xla/tsl/platform/errors.h"
@@ -172,15 +173,16 @@ absl::Status NcclRaggedAllToAllStartThunk::Initialize(
172173
}
173174

174175
if (!device_buffer_allocs_.contains(params.executor)) {
175-
se::DeviceMemoryBase output_offsets_device_buffer =
176-
params.executor->Allocate(config_.num_ragged_rows * sizeof(int64_t));
176+
se::DeviceMemoryHandle output_offsets_device_buffer{
177+
params.executor,
178+
params.executor->Allocate(config_.num_ragged_rows * sizeof(int64_t))};
177179

178-
if (output_offsets_device_buffer.is_null()) {
180+
if (output_offsets_device_buffer.memory().is_null()) {
179181
return absl::InternalError("Failed to allocate output offsets buffer.");
180182
}
181183

182184
device_buffer_allocs_.emplace(params.executor,
183-
output_offsets_device_buffer);
185+
std::move(output_offsets_device_buffer));
184186
}
185187

186188
if (should_use_memcpy()) {
@@ -214,19 +216,6 @@ absl::Status NcclRaggedAllToAllStartThunk::Initialize(
214216
return absl::OkStatus();
215217
}
216218

217-
absl::Status NcclRaggedAllToAllStartThunk::Cleanup(
218-
const CleanupParams& params) {
219-
absl::MutexLock lock(&mutex_);
220-
221-
if (device_buffer_allocs_.contains(params.executor)) {
222-
se::DeviceMemoryBase alloc =
223-
device_buffer_allocs_.extract(params.executor).mapped();
224-
params.executor->Deallocate(&alloc);
225-
}
226-
227-
return absl::OkStatus();
228-
}
229-
230219
bool NcclRaggedAllToAllStartThunk::is_local() const {
231220
CHECK_NE(device_count_, -1);
232221
for (const auto& replica_group : config_.config.replica_groups) {
@@ -267,7 +256,7 @@ absl::Status NcclRaggedAllToAllStartThunk::RunNcclCollective(
267256

268257
auto jt = device_buffer_allocs_.find(stream.parent());
269258
CHECK(jt != device_buffer_allocs_.end());
270-
output_offsets_device_buffer = jt->second;
259+
output_offsets_device_buffer = jt->second.memory();
271260
}
272261

273262
if (should_use_memcpy()) {

xla/backends/gpu/runtime/nccl_ragged_all_to_all_thunk.h

+2-3
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@ limitations under the License.
3232
#include "xla/hlo/ir/hlo_instructions.h"
3333
#include "xla/service/collective_ops_utils.h"
3434
#include "xla/stream_executor/device_memory.h"
35+
#include "xla/stream_executor/device_memory_handle.h"
3536
#include "xla/stream_executor/memory_allocation.h"
3637
#include "xla/stream_executor/stream.h"
3738

@@ -61,8 +62,6 @@ class NcclRaggedAllToAllStartThunk : public NcclCollectiveThunk {
6162

6263
absl::Status Initialize(const InitializeParams& params) override;
6364

64-
absl::Status Cleanup(const CleanupParams& params) override;
65-
6665
static const char* GetHloOpName() { return "ragged-all-to-all-start"; }
6766

6867
static CollectiveOpGroupMode GetGroupMode(
@@ -92,7 +91,7 @@ class NcclRaggedAllToAllStartThunk : public NcclCollectiveThunk {
9291
std::vector<std::unique_ptr<se::MemoryAllocation>>>
9392
host_buffer_allocs_ ABSL_GUARDED_BY(mutex_);
9493

95-
absl::flat_hash_map<se::StreamExecutor*, se::DeviceMemoryBase>
94+
absl::flat_hash_map<se::StreamExecutor*, se::DeviceMemoryHandle>
9695
device_buffer_allocs_ ABSL_GUARDED_BY(mutex_);
9796

9897
absl::Mutex pointers_mutex_;

xla/backends/gpu/runtime/thunk.h

-25
Original file line numberDiff line numberDiff line change
@@ -423,23 +423,6 @@ class Thunk {
423423
bool requires_exclusive_lock_on_gpu = false);
424424
};
425425

426-
//===--------------------------------------------------------------------===//
427-
// CleanupParams
428-
//===--------------------------------------------------------------------===//
429-
430-
// Parameters passed to Cleanup. Before returning from executable execution,
431-
// thunks may need to clean up any resource allocated or registered through
432-
// runtime APIs.
433-
struct CleanupParams {
434-
se::StreamExecutor* executor = nullptr;
435-
436-
// Parameters for executing collective operations.
437-
CollectiveExecuteParams* collective_params = nullptr;
438-
439-
// Collective cliques acquired based on resource requests.
440-
CollectiveCliques* collective_cliques = nullptr;
441-
};
442-
443426
//===--------------------------------------------------------------------===//
444427

445428
Thunk(Kind kind, ThunkInfo thunk_info)
@@ -481,14 +464,6 @@ class Thunk {
481464
// Precondition: Initialize(initialize_params) has been called.
482465
virtual absl::Status ExecuteOnStream(const ExecuteParams& params) = 0;
483466

484-
// Cleans up any resources after thunk execution.
485-
//
486-
// This may be called multiple times. Its main purpose is to free up
487-
// any resources occupied after initialization and execution.
488-
virtual absl::Status Cleanup(const CleanupParams& params) {
489-
return absl::OkStatus();
490-
}
491-
492467
static absl::string_view KindToString(Thunk::Kind kind);
493468

494469
ExecutionStreamId execution_stream_id() const { return execution_stream_id_; }

xla/service/gpu/gpu_executable.cc

+2-12
Original file line numberDiff line numberDiff line change
@@ -347,18 +347,8 @@ absl::Status ExecuteThunks(
347347

348348
TF_RETURN_IF_ERROR(thunk_sequence.ExecuteOnStream(execute_params));
349349

350-
auto status =
351-
MaybeSyncAndProfile(run_options, execution_timer.get(),
352-
block_host_until_done ? main_stream : nullptr);
353-
354-
Thunk::CleanupParams cleanup_params{
355-
executor,
356-
&collective_params,
357-
&collective_cliques,
358-
};
359-
TF_RETURN_IF_ERROR(thunk_sequence.Cleanup(cleanup_params));
360-
361-
return status;
350+
return MaybeSyncAndProfile(run_options, execution_timer.get(),
351+
block_host_until_done ? main_stream : nullptr);
362352
}
363353

364354
namespace {

0 commit comments

Comments (0)