Skip to content

Commit ae6b8c0

Browse files
committed
[GPU] Optimized operations in the blas kernels with the latest buffer changes.
Updated the pipeline for both fp32 and fp16. SGEMM, SGEMV, DotCL, SSACL, Transpose ops updated. Signed-off-by: Niket Agarwal <niket.a@samsung.com>
1 parent 8184b61 commit ae6b8c0

File tree

2 files changed

+99
-123
lines changed

2 files changed

+99
-123
lines changed

nntrainer/tensor/cl_operations/blas_kernels.cpp

Lines changed: 42 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -120,31 +120,26 @@ float dot_cl(const float *vecAdata, const float *vecXdata, unsigned int dim1) {
120120

121121
size_t dim1_size = sizeof(float) * dim1;
122122

123-
opencl::Buffer inputA(cl_context_ref.context_inst_, dim1_size, true,
124-
nullptr);
125-
126-
opencl::Buffer inputX(cl_context_ref.context_inst_, dim1_size, true,
127-
nullptr);
128-
129-
opencl::Buffer dotResult(cl_context_ref.context_inst_, sizeof(float), true,
130-
&cl_ret);
131-
132-
result = inputA.WriteData(cl_context_ref.command_queue_inst_, vecAdata);
123+
result = clbuffInstance.getInBufferA()->WriteDataRegion(
124+
cl_context_ref.command_queue_inst_, dim1_size, vecAdata);
133125
if (!result) {
134126
break;
135127
}
136128

137-
result = inputX.WriteData(cl_context_ref.command_queue_inst_, vecXdata);
129+
result = clbuffInstance.getInBufferB()->WriteDataRegion(
130+
cl_context_ref.command_queue_inst_, dim1_size, vecXdata);
138131
if (!result) {
139132
break;
140133
}
141134

142-
result = kernel_dot_ptr->SetKernelArguments(0, &inputA, sizeof(cl_mem));
135+
result = kernel_dot_ptr->SetKernelArguments(
136+
0, clbuffInstance.getInBufferA(), sizeof(cl_mem));
143137
if (!result) {
144138
break;
145139
}
146140

147-
result = kernel_dot_ptr->SetKernelArguments(1, &inputX, sizeof(cl_mem));
141+
result = kernel_dot_ptr->SetKernelArguments(
142+
1, clbuffInstance.getInBufferB(), sizeof(cl_mem));
148143
if (!result) {
149144
break;
150145
}
@@ -154,7 +149,8 @@ float dot_cl(const float *vecAdata, const float *vecXdata, unsigned int dim1) {
154149
break;
155150
}
156151

157-
result = kernel_dot_ptr->SetKernelArguments(3, &dotResult, sizeof(cl_mem));
152+
result = kernel_dot_ptr->SetKernelArguments(
153+
3, clbuffInstance.getOutBufferA(), sizeof(cl_mem));
158154
if (!result) {
159155
break;
160156
}
@@ -168,7 +164,8 @@ float dot_cl(const float *vecAdata, const float *vecXdata, unsigned int dim1) {
168164
break;
169165
}
170166

171-
result = dotResult.ReadData(cl_context_ref.command_queue_inst_, &cl_ret);
167+
result = clbuffInstance.getOutBufferA()->ReadDataRegion(
168+
cl_context_ref.command_queue_inst_, sizeof(float), &cl_ret);
172169
if (!result) {
173170
break;
174171
}
@@ -213,41 +210,38 @@ void sgemm_cl(bool TransA, bool TransB, const float *A, const float *B,
213210
size_t k_n_size = K * N * sizeof(float);
214211
size_t m_n_size = M * N * sizeof(float);
215212

216-
opencl::Buffer inputA(cl_context_ref.context_inst_, m_k_size, true,
217-
nullptr);
218-
219-
opencl::Buffer inputB(cl_context_ref.context_inst_, k_n_size, true,
220-
nullptr);
221-
222-
opencl::Buffer inOutC(cl_context_ref.context_inst_, m_n_size, true,
223-
nullptr);
224-
225-
result = inputA.WriteData(cl_context_ref.command_queue_inst_, A);
213+
result = clbuffInstance.getInBufferA()->WriteDataRegion(
214+
cl_context_ref.command_queue_inst_, m_k_size, A);
226215
if (!result) {
227216
break;
228217
}
229218

230-
result = inputB.WriteData(cl_context_ref.command_queue_inst_, B);
219+
result = clbuffInstance.getInBufferB()->WriteDataRegion(
220+
cl_context_ref.command_queue_inst_, k_n_size, B);
231221
if (!result) {
232222
break;
233223
}
234224

235-
result = inOutC.WriteData(cl_context_ref.command_queue_inst_, C);
225+
result = clbuffInstance.getOutBufferA()->WriteDataRegion(
226+
cl_context_ref.command_queue_inst_, m_n_size, C);
236227
if (!result) {
237228
break;
238229
}
239230

240-
result = kernel_sgemm_ptr->SetKernelArguments(0, &inputA, sizeof(cl_mem));
231+
result = kernel_sgemm_ptr->SetKernelArguments(
232+
0, clbuffInstance.getInBufferA(), sizeof(cl_mem));
241233
if (!result) {
242234
break;
243235
}
244236

245-
result = kernel_sgemm_ptr->SetKernelArguments(1, &inputB, sizeof(cl_mem));
237+
result = kernel_sgemm_ptr->SetKernelArguments(
238+
1, clbuffInstance.getInBufferB(), sizeof(cl_mem));
246239
if (!result) {
247240
break;
248241
}
249242

250-
result = kernel_sgemm_ptr->SetKernelArguments(2, &inOutC, sizeof(cl_mem));
243+
result = kernel_sgemm_ptr->SetKernelArguments(
244+
2, clbuffInstance.getOutBufferA(), sizeof(cl_mem));
251245
if (!result) {
252246
break;
253247
}
@@ -281,7 +275,8 @@ void sgemm_cl(bool TransA, bool TransB, const float *A, const float *B,
281275
break;
282276
}
283277

284-
result = inOutC.ReadData(cl_context_ref.command_queue_inst_, C);
278+
result = clbuffInstance.getOutBufferA()->ReadDataRegion(
279+
cl_context_ref.command_queue_inst_, m_n_size, C);
285280
if (!result) {
286281
break;
287282
}
@@ -372,14 +367,14 @@ void sscal_cl(float *X, const unsigned int N, const float alpha) {
372367

373368
size_t x_size = N * sizeof(float);
374369

375-
opencl::Buffer inputX(cl_context_ref.context_inst_, x_size, false, nullptr);
376-
377-
result = inputX.WriteData(cl_context_ref.command_queue_inst_, X);
370+
result = clbuffInstance.getOutBufferA()->WriteDataRegion(
371+
cl_context_ref.command_queue_inst_, x_size, X);
378372
if (!result) {
379373
break;
380374
}
381375

382-
result = kernel_ptr->SetKernelArguments(0, &inputX, sizeof(cl_mem));
376+
result = kernel_ptr->SetKernelArguments(0, clbuffInstance.getOutBufferA(),
377+
sizeof(cl_mem));
383378
if (!result) {
384379
break;
385380
}
@@ -398,7 +393,8 @@ void sscal_cl(float *X, const unsigned int N, const float alpha) {
398393
break;
399394
}
400395

401-
result = inputX.ReadData(cl_context_ref.command_queue_inst_, X);
396+
result = clbuffInstance.getOutBufferA()->ReadDataRegion(
397+
cl_context_ref.command_queue_inst_, x_size, X);
402398
if (!result) {
403399
break;
404400
}
@@ -439,30 +435,26 @@ void transpose_cl_axis(const float *in, float *res,
439435
size_t dim_size = sizeof(float) * input_batch_size * input_height *
440436
input_width * input_channels;
441437

442-
opencl::Buffer inputA(cl_context_ref.context_inst_, dim_size, true,
443-
nullptr);
444-
445-
opencl::Buffer inOutRes(cl_context_ref.context_inst_, dim_size, true,
446-
nullptr);
447-
448-
result = inputA.WriteData(cl_context_ref.command_queue_inst_, in);
438+
result = clbuffInstance.getInBufferA()->WriteDataRegion(
439+
cl_context_ref.command_queue_inst_, dim_size, in);
449440
if (!result) {
450441
break;
451442
}
452443

453-
result = inOutRes.WriteData(cl_context_ref.command_queue_inst_, res);
444+
result = clbuffInstance.getOutBufferA()->WriteDataRegion(
445+
cl_context_ref.command_queue_inst_, dim_size, res);
454446
if (!result) {
455447
break;
456448
}
457449

458-
result =
459-
kernel_transpose_ptr->SetKernelArguments(0, &inputA, sizeof(cl_mem));
450+
result = kernel_transpose_ptr->SetKernelArguments(
451+
0, clbuffInstance.getInBufferA(), sizeof(cl_mem));
460452
if (!result) {
461453
break;
462454
}
463455

464-
result =
465-
kernel_transpose_ptr->SetKernelArguments(1, &inOutRes, sizeof(cl_mem));
456+
result = kernel_transpose_ptr->SetKernelArguments(
457+
1, clbuffInstance.getOutBufferA(), sizeof(cl_mem));
466458
if (!result) {
467459
break;
468460
}
@@ -503,7 +495,8 @@ void transpose_cl_axis(const float *in, float *res,
503495
break;
504496
}
505497

506-
result = inOutRes.ReadData(cl_context_ref.command_queue_inst_, res);
498+
result = clbuffInstance.getOutBufferA()->ReadDataRegion(
499+
cl_context_ref.command_queue_inst_, dim_size, res);
507500
if (!result) {
508501
break;
509502
}

0 commit comments

Comments
 (0)