Skip to content

Commit 2633e0f

Browse files
committed
[GPU] Abstraction of cl_buffer_manager
Adding abstraction in cl_buffer_manager using const for data size Signed-off-by: Debadri Samaddar <s.debadri@samsung.com>
1 parent dab2d91 commit 2633e0f

File tree

3 files changed

+62
-27
lines changed

3 files changed

+62
-27
lines changed

nntrainer/cl_buffer_manager.cpp

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -22,20 +22,20 @@ ClBufferManager &ClBufferManager::getInstance() {
2222
// to-do: Implementation to be updated with array of Buffer objects if required
2323
// fp16 Buffer objects to be added in future
2424
void ClBufferManager::initBuffers() {
25-
readBufferA = new opencl::Buffer(context_inst_, buffer_size_bytes, true);
26-
readBufferB = new opencl::Buffer(context_inst_, buffer_size_bytes, true);
27-
readBufferC = new opencl::Buffer(context_inst_, buffer_size_bytes, true);
28-
writeBufferA = new opencl::Buffer(context_inst_, buffer_size_bytes, false);
29-
writeBufferB = new opencl::Buffer(context_inst_, buffer_size_bytes, false);
25+
inBufferA = new opencl::Buffer(context_inst_, buffer_size_bytes, true);
26+
inBufferB = new opencl::Buffer(context_inst_, buffer_size_bytes, true);
27+
inBufferC = new opencl::Buffer(context_inst_, buffer_size_bytes, true);
28+
outBufferA = new opencl::Buffer(context_inst_, buffer_size_bytes, false);
29+
outBufferB = new opencl::Buffer(context_inst_, buffer_size_bytes, false);
3030
ml_logi("ClBufferManager: Buffers initialized");
3131
}
3232

3333
ClBufferManager::~ClBufferManager() {
34-
delete readBufferA;
35-
delete readBufferB;
36-
delete readBufferC;
37-
delete writeBufferA;
38-
delete writeBufferB;
34+
delete inBufferA;
35+
delete inBufferB;
36+
delete inBufferC;
37+
delete outBufferA;
38+
delete outBufferB;
3939
ml_logi("ClBufferManager: Buffers destroyed");
4040
}
4141

nntrainer/cl_buffer_manager.h

Lines changed: 43 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,12 @@ class ClBufferManager {
3434
* @brief Private constructor to prevent object creation
3535
*
3636
*/
37-
ClBufferManager(){};
37+
ClBufferManager() :
38+
inBufferA(nullptr),
39+
inBufferB(nullptr),
40+
inBufferC(nullptr),
41+
outBufferA(nullptr),
42+
outBufferB(nullptr){};
3843

3944
/**
4045
* @brief OpenCl context global instance
@@ -45,7 +50,13 @@ class ClBufferManager {
4550
/**
4651
* @brief Buffer size in bytes preset (256 mebibytes)
4752
*/
48-
size_t buffer_size_bytes = 8192 * 8192 * sizeof(float);
53+
const size_t buffer_size_bytes = 8192 * 8192 * sizeof(float);
54+
55+
opencl::Buffer *inBufferA;
56+
opencl::Buffer *inBufferB;
57+
opencl::Buffer *inBufferC;
58+
opencl::Buffer *outBufferA;
59+
opencl::Buffer *outBufferB;
4960

5061
public:
5162
/**
@@ -55,17 +66,41 @@ class ClBufferManager {
5566
*/
5667
static ClBufferManager &getInstance();
5768

58-
opencl::Buffer *readBufferA;
59-
opencl::Buffer *readBufferB;
60-
opencl::Buffer *readBufferC;
61-
opencl::Buffer *writeBufferA;
62-
opencl::Buffer *writeBufferB;
63-
6469
/**
6570
* @brief Initialize Buffer objects.
6671
*/
6772
void initBuffers();
6873

74+
/**
75+
* @brief Get read only inBufferA.
76+
* @return opencl::Buffer* or nullptr if initBuffers() is not called
77+
*/
78+
opencl::Buffer *getInBufferA() { return inBufferA; }
79+
80+
/**
81+
* @brief Get read only inBufferB.
82+
* @return opencl::Buffer* or nullptr if initBuffers() is not called
83+
*/
84+
opencl::Buffer *getInBufferB() { return inBufferB; }
85+
86+
/**
87+
* @brief Get read only inBufferC.
88+
* @return opencl::Buffer* or nullptr if initBuffers() is not called
89+
*/
90+
opencl::Buffer *getInBufferC() { return inBufferC; }
91+
92+
/**
93+
* @brief Get read-write outBufferA.
94+
* @return opencl::Buffer* or nullptr if initBuffers() is not called
95+
*/
96+
opencl::Buffer *getOutBufferA() { return outBufferA; }
97+
98+
/**
99+
* @brief Get read-write outBufferB.
100+
* @return opencl::Buffer* or nullptr if initBuffers() is not called
101+
*/
102+
opencl::Buffer *getOutBufferB() { return outBufferB; }
103+
69104
/**
70105
* @brief Destroy Buffer pointers.
71106
*

nntrainer/tensor/cl_operations/blas_kernels.cpp

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -40,39 +40,39 @@ void sgemv_cl(const float *matAdata, const float *vecXdata, float *vecYdata,
4040
size_t dim1_size = sizeof(float) * dim1;
4141
size_t dim2_size = sizeof(float) * dim2;
4242

43-
result = clbuffInstance.readBufferA->WriteDataRegion(
43+
result = clbuffInstance.getInBufferA()->WriteDataRegion(
4444
cl_context_ref.command_queue_inst_, dim1 * dim2 * sizeof(float),
4545
matAdata);
4646
if (!result) {
4747
break;
4848
}
4949

50-
result = clbuffInstance.readBufferB->WriteDataRegion(
50+
result = clbuffInstance.getInBufferB()->WriteDataRegion(
5151
cl_context_ref.command_queue_inst_, dim2_size, vecXdata);
5252
if (!result) {
5353
break;
5454
}
5555

56-
result = clbuffInstance.writeBufferA->WriteDataRegion(
56+
result = clbuffInstance.getOutBufferA()->WriteDataRegion(
5757
cl_context_ref.command_queue_inst_, dim1_size, vecYdata);
5858
if (!result) {
5959
break;
6060
}
6161

62-
result = kernel_sgemv_ptr->SetKernelArguments(0, clbuffInstance.readBufferA,
63-
sizeof(cl_mem));
62+
result = kernel_sgemv_ptr->SetKernelArguments(
63+
0, clbuffInstance.getInBufferA(), sizeof(cl_mem));
6464
if (!result) {
6565
break;
6666
}
6767

68-
result = kernel_sgemv_ptr->SetKernelArguments(1, clbuffInstance.readBufferB,
69-
sizeof(cl_mem));
68+
result = kernel_sgemv_ptr->SetKernelArguments(
69+
1, clbuffInstance.getInBufferB(), sizeof(cl_mem));
7070
if (!result) {
7171
break;
7272
}
7373

7474
result = kernel_sgemv_ptr->SetKernelArguments(
75-
2, clbuffInstance.writeBufferA, sizeof(cl_mem));
75+
2, clbuffInstance.getOutBufferA(), sizeof(cl_mem));
7676
if (!result) {
7777
break;
7878
}
@@ -96,7 +96,7 @@ void sgemv_cl(const float *matAdata, const float *vecXdata, float *vecYdata,
9696
break;
9797
}
9898

99-
result = clbuffInstance.writeBufferA->ReadDataRegion(
99+
result = clbuffInstance.getOutBufferA()->ReadDataRegion(
100100
cl_context_ref.command_queue_inst_, dim1_size, vecYdata);
101101
if (!result) {
102102
break;

0 commit comments

Comments
 (0)