Skip to content

Commit e8aeed0

Browse files
pemeliya authored and tensorflower-gardener committed
PR tensorflow#21886: [ROCM][NFC] BlasLt interface refactoring & simplifying: part I
Imported from GitHub PR openxla/xla#21886 After this PR tensorflow#73926 is merged, we can remove unnecessary low-level DoMatmul functions from GpuBlasLt interface (which otherwise looks scary and unnecessarily complicated). Furthermore, we can also remove **ValidateInputs** function from the interface and derived classes since a high-level **ExecuteOnStream** function already handles data-types correctly. This also greatly simplifies the code. Also, I have packed the input arguments of ExecuteOnStream calls to a struct **MemoryArgs** to simplify arguments passing in derived classes and improve code readability. Finally, in the original GpuBlasLt PR: openxla/xla#5911, I made a sort of mistake by adding a reference to **blas_lt** to the MatmulPlan class [here](https://github.com/openxla/xla/blob/main/xla/stream_executor/rocm/hip_blas_lt.h#L135), thereby making MatmulPlans bound to a **particular BlasLt instance**. This resulted in some further bugfixes and, most importantly, complicated GpuBlasLt cache design in gpublas_lt_matmul_thunk.cc/.h. In this PR, I remove this reference again from MatmulPlan class and in the next NFC PR the cache mechanics can also be simplified. Unfortunately, this change also requires a tandem PR for Tensorflow: tensorflow#85835 @xla-rotation Would you please have a look Copybara import of the project: -- e96bb2fbedab3f53b31ef0e1748582c76e9fb105 by Pavel Emeliyanenko <pavel.emeliyanenko@amd.com>: blaslt interface refactoring: removing blas_lt_ref added cuda adaptions cuda-side adaptions cuda side adaptions fix fixing pointers Merging this change closes tensorflow#21886 PiperOrigin-RevId: 727898957
1 parent 1b71f5e commit e8aeed0

File tree

12 files changed

+256
-554
lines changed

12 files changed

+256
-554
lines changed

tensorflow/core/kernels/matmul_util.cc

+6-4
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,11 @@
11
/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2+
23
Licensed under the Apache License, Version 2.0 (the "License");
34
you may not use this file except in compliance with the License.
45
You may obtain a copy of the License at
6+
57
http://www.apache.org/licenses/LICENSE-2.0
8+
69
Unless required by applicable law or agreed to in writing, software
710
distributed under the License is distributed on an "AS IS" BASIS,
811
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
@@ -176,7 +179,7 @@ StatusOr<se::blas::ComputationType> GetBlasComputationType(
176179

177180
TF_ASSIGN_OR_RETURN(
178181
auto algorithms,
179-
plan->GetAlgorithms(*max_algorithm_count, max_scratch_size));
182+
plan->GetAlgorithms(stream, *max_algorithm_count, max_scratch_size));
180183

181184
ptr->second = {std::move(plan), std::move(algorithms)};
182185
}
@@ -201,9 +204,8 @@ Status PlanAndAlgorithms::ExecuteOnStream(
201204
se::DeviceMemoryBase{}, // c_scale_buffer
202205
se::DeviceMemoryBase{}, // d_scale_buffer
203206
se::DeviceMemoryBase{}, // d_amax_buffer
204-
algorithms[algorithm_idx],
205-
std::nullopt, // workspace
206-
&scratch_allocator, profile_result);
207+
algorithms[algorithm_idx], scratch_allocator,
208+
profile_result);
207209
}
208210

209211
} // namespace tensorflow

third_party/xla/xla/backends/gpu/runtime/command_buffer_cmd.cc

+8-5
Original file line numberDiff line numberDiff line change
@@ -1205,7 +1205,7 @@ CublasLtCmd::CublasLtCmd(
12051205
workspace_buffer_(workspace_buffer) {}
12061206

12071207
absl::StatusOr<se::gpu::BlasLt::MatmulPlan*> CublasLtCmd::GetMatmulPlan(
1208-
const stream_executor::Stream* stream) {
1208+
const se::Stream* stream) {
12091209
auto it = matmul_plans_cache_.find(stream);
12101210
if (it != matmul_plans_cache_.end()) return it->second.get();
12111211
TF_ASSIGN_OR_RETURN(auto plan, se::gpu::BlasLt::GetMatmulPlan(
@@ -1215,13 +1215,14 @@ absl::StatusOr<se::gpu::BlasLt::MatmulPlan*> CublasLtCmd::GetMatmulPlan(
12151215
}
12161216

12171217
absl::StatusOr<se::gpu::BlasLt::MatmulAlgorithm>
1218-
CublasLtCmd::GetMatmulAlgorithm(const se::gpu::BlasLt::MatmulPlan* plan,
1218+
CublasLtCmd::GetMatmulAlgorithm(const se::Stream* stream,
1219+
const se::gpu::BlasLt::MatmulPlan* plan,
12191220
int64_t max_workspace) {
12201221
auto it = matmul_algorithm_cache_.find(plan);
12211222
if (it != matmul_algorithm_cache_.end()) return it->second;
12221223
TF_ASSIGN_OR_RETURN(
12231224
auto algorithms,
1224-
plan->GetAlgorithms(/*max_algorithm_count*/ 128,
1225+
plan->GetAlgorithms(stream, /*max_algorithm_count*/ 128,
12251226
/*max_workspace_size*/ max_workspace));
12261227
TF_RET_CHECK(algorithm_idx_ >= 0 && algorithm_idx_ < algorithms.size());
12271228
auto [it_insert, _] =
@@ -1237,7 +1238,8 @@ absl::Status CublasLtCmd::Initialize(const Thunk::InitializeParams& params,
12371238
// Populate plan and algorithm cache;
12381239
TF_ASSIGN_OR_RETURN(auto plan, GetMatmulPlan(params.stream));
12391240
TF_RETURN_IF_ERROR(
1240-
GetMatmulAlgorithm(plan, workspace_buffer_.size()).status());
1241+
GetMatmulAlgorithm(params.stream, plan, workspace_buffer_.size())
1242+
.status());
12411243
return absl::OkStatus();
12421244
}
12431245

@@ -1246,7 +1248,8 @@ absl::Status CublasLtCmd::Record(const Thunk::ExecuteParams& execute_params,
12461248
se::CommandBuffer* command_buffer) {
12471249
TF_ASSIGN_OR_RETURN(auto plan, GetMatmulPlan(execute_params.stream));
12481250
TF_ASSIGN_OR_RETURN(auto algorithm,
1249-
GetMatmulAlgorithm(plan, workspace_buffer_.size()));
1251+
GetMatmulAlgorithm(execute_params.stream, plan,
1252+
workspace_buffer_.size()));
12501253

12511254
const BufferAllocations& allocs = *execute_params.buffer_allocations;
12521255

third_party/xla/xla/backends/gpu/runtime/command_buffer_cmd.h

+4-4
Original file line numberDiff line numberDiff line change
@@ -803,13 +803,13 @@ class CublasLtCmd : public TracedCommandBufferCmd {
803803

804804
private:
805805
absl::StatusOr<se::gpu::BlasLt::MatmulPlan*> GetMatmulPlan(
806-
const stream_executor::Stream* stream);
806+
const se::Stream* stream);
807807

808808
absl::StatusOr<se::gpu::BlasLt::MatmulAlgorithm> GetMatmulAlgorithm(
809-
const se::gpu::BlasLt::MatmulPlan* plan, int64_t max_workspace);
809+
const se::Stream* stream, const se::gpu::BlasLt::MatmulPlan* plan,
810+
int64_t max_workspace);
810811

811-
absl::flat_hash_map<const stream_executor::Stream*,
812-
se::gpu::BlasLt::MatmulPlanPtr>
812+
absl::flat_hash_map<const se::Stream*, se::gpu::BlasLt::MatmulPlanPtr>
813813
matmul_plans_cache_;
814814

815815
absl::flat_hash_map<const se::gpu::BlasLt::MatmulPlan*,

third_party/xla/xla/backends/gpu/runtime/gpublas_lt_matmul_thunk.cc

+11-9
Original file line numberDiff line numberDiff line change
@@ -65,11 +65,11 @@ CublasLtMatmulThunk::CublasLtMatmulThunk(
6565
absl::Status CublasLtMatmulThunk::ExecuteOnStream(const ExecuteParams& params) {
6666
TF_ASSIGN_OR_RETURN(auto plan, GetMatmulPlan(params.stream));
6767

68-
TF_ASSIGN_OR_RETURN(
69-
auto algorithm,
70-
GetMatmulAlgorithm(plan, workspace_buffer_.has_value()
71-
? workspace_buffer_.value().size()
72-
: 0));
68+
TF_ASSIGN_OR_RETURN(auto algorithm,
69+
GetMatmulAlgorithm(params.stream, plan,
70+
workspace_buffer_.has_value()
71+
? workspace_buffer_.value().size()
72+
: 0));
7373

7474
VLOG(3) << "Running cublas_lt matmul thunk";
7575
const BufferAllocations& allocs = *params.buffer_allocations;
@@ -99,7 +99,7 @@ absl::Status CublasLtMatmulThunk::ExecuteOnStream(const ExecuteParams& params) {
9999
aux = allocs.GetDeviceAddress(aux_buffer_);
100100
}
101101

102-
std::optional<se::DeviceMemoryBase> workspace;
102+
se::DeviceMemoryBase workspace;
103103
if (workspace_buffer_.has_value()) {
104104
workspace = allocs.GetDeviceAddress(workspace_buffer_.value());
105105
}
@@ -112,7 +112,7 @@ absl::Status CublasLtMatmulThunk::ExecuteOnStream(const ExecuteParams& params) {
112112
}
113113

114114
absl::StatusOr<se::gpu::BlasLt::MatmulPlan*> CublasLtMatmulThunk::GetMatmulPlan(
115-
const stream_executor::Stream* stream) {
115+
const se::Stream* stream) {
116116
{
117117
absl::MutexLock lock(&matmul_plans_cache_mutex_);
118118
auto it = matmul_plans_cache_.find(stream);
@@ -127,7 +127,8 @@ absl::StatusOr<se::gpu::BlasLt::MatmulPlan*> CublasLtMatmulThunk::GetMatmulPlan(
127127
}
128128

129129
absl::StatusOr<se::gpu::BlasLt::MatmulAlgorithm>
130-
CublasLtMatmulThunk::GetMatmulAlgorithm(const se::gpu::BlasLt::MatmulPlan* plan,
130+
CublasLtMatmulThunk::GetMatmulAlgorithm(const se::Stream* stream,
131+
const se::gpu::BlasLt::MatmulPlan* plan,
131132
int64_t max_workspace) {
132133
{
133134
absl::MutexLock lock(&matmul_algorithm_cache_mutex_);
@@ -136,7 +137,8 @@ CublasLtMatmulThunk::GetMatmulAlgorithm(const se::gpu::BlasLt::MatmulPlan* plan,
136137
}
137138
TF_ASSIGN_OR_RETURN(
138139
auto algorithms,
139-
plan->GetAlgorithms(/*max_algorithm_count*/ 128,
140+
plan->GetAlgorithms(stream,
141+
/*max_algorithm_count*/ 128,
140142
/*max_workspace_size*/ max_workspace));
141143
TF_RET_CHECK(algorithm_idx_ >= 0 && algorithm_idx_ < algorithms.size());
142144

third_party/xla/xla/backends/gpu/runtime/gpublas_lt_matmul_thunk.h

+2-1
Original file line numberDiff line numberDiff line change
@@ -74,7 +74,8 @@ class CublasLtMatmulThunk : public Thunk {
7474
absl::StatusOr<se::gpu::BlasLt::MatmulPlan*> GetMatmulPlan(
7575
const stream_executor::Stream* stream);
7676
absl::StatusOr<se::gpu::BlasLt::MatmulAlgorithm> GetMatmulAlgorithm(
77-
const se::gpu::BlasLt::MatmulPlan* plan, int64_t max_workspace);
77+
const se::Stream* stream, const se::gpu::BlasLt::MatmulPlan* plan,
78+
int64_t max_workspace);
7879

7980
absl::Mutex matmul_plans_cache_mutex_;
8081
absl::flat_hash_map<const stream_executor::Stream*,

third_party/xla/xla/service/gpu/autotuning/gemm_algorithm_picker.cc

+1-1
Original file line numberDiff line numberDiff line change
@@ -186,7 +186,7 @@ class GemmAutotuner {
186186

187187
TF_ASSIGN_OR_RETURN(
188188
auto algorithms,
189-
plan->GetAlgorithms(/*max_algorithm_count*/ 128,
189+
plan->GetAlgorithms(stream_, /*max_algorithm_count*/ 128,
190190
/*max_workspace_size*/ workspace_buffer.size()));
191191

192192
auto tuned_func = [&](const BlasLt::MatmulAlgorithm& algorithm)

0 commit comments

Comments
 (0)