refactor(kernel): implement a BlockReduce that does not depend on template parameters and use it for softmax
Signed-off-by: YdrMaster <ydrml@hotmail.com>
YdrMaster committed Jan 31, 2024
1 parent f5197bb commit fa6e5b4
Showing 3 changed files with 41 additions and 14 deletions.
24 changes: 23 additions & 1 deletion src/04kernel/cuda/include/kernel/cuda/reduce.cuh
@@ -4,6 +4,28 @@
 #include <cub/warp/warp_reduce.cuh>
 
 namespace refactor::kernel::cuda {
-}
+
+    template<class T, class ReductionOp>
+    __inline__ __device__ T blockReduce(T x, T init, ReductionOp op) {
+        using WarpReduce = cub::WarpReduce<T>;
+        __shared__ typename WarpReduce::TempStorage tempStorage;
+        __shared__ T shared[32], ans;
+
+        auto reduce = WarpReduce(tempStorage);
+        int lane = threadIdx.x % 32;
+        int wid = threadIdx.x / 32;
+        x = reduce.Reduce(x, op);
+        if (lane == 0) { shared[wid] = x; }
+        __syncthreads();
+        if (wid == 0) {
+            x = (threadIdx.x < blockDim.x / 32) ? shared[lane] : init;
+            shared[lane] = reduce.Reduce(x, op);
+            if (lane == 0) { ans = shared[0]; }
+        }
+        __syncthreads();
+        return ans;// avoid RAW hazard
+    }
 
 }// namespace refactor::kernel::cuda
 
 #endif// KERNEL_CUDA_REDUCE_CUH
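
With the block size no longer a template parameter, any kernel can call the helper whatever its launch configuration. A minimal usage sketch (the rowMaxKernel name and shapes are hypothetical, not part of this commit; note that blockReduce as written assumes blockDim.x is a multiple of 32 and at most 1024, since the scratch array holds one partial result per warp):

#include "kernel/cuda/reduce.cuh"
#include <cub/thread/thread_operators.cuh>// cub::Max
#include <cfloat>                          // FLT_MAX

// Hypothetical kernel: block b reduces row b of a numBlocks x n matrix to its maximum.
__global__ void rowMaxKernel(float const *__restrict x, float *__restrict y, int n) {
    float v = -FLT_MAX;
    for (int i = threadIdx.x; i < n; i += blockDim.x) {
        v = fmaxf(v, x[blockIdx.x * n + i]);// block-stride loop over the row
    }
    // the same instantiation serves 32-, 256-, or 1024-thread blocks
    v = refactor::kernel::cuda::blockReduce(v, -FLT_MAX, cub::Max());
    if (threadIdx.x == 0) { y[blockIdx.x] = v; }
}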
12 changes: 5 additions & 7 deletions src/04kernel/src/kernels/softmax/cuda_kernel.cu
@@ -1,5 +1,5 @@
#include "cuda_kernel.hh"
#include <cub/cub.cuh>
#include "kernel/cuda/reduce.cuh"

namespace refactor::kernel {
using namespace runtime;
@@ -18,8 +18,8 @@ namespace refactor::kernel {
     template<> __device__ __forceinline__ nv_bfloat16 reciprocal<nv_bfloat16>(nv_bfloat16 x) { return hrcp(x); }
 
     // blockDim.x === BLOCK_DIM
-    template<int BLOCK_DIM, class T>
-    __launch_bounds__(BLOCK_DIM) __global__ void blockSoftmaxKernel(
+    template<class T>
+    __global__ void blockSoftmaxKernel(
         T const *__restrict x,
         T *__restrict y,
         int mid,
@@ -40,10 +40,8 @@ namespace refactor::kernel {
         for (int i = threadIdx.x + blockDim.x; i < mid; i += blockDim.x) {
             maxSumThread = MaxSum::reduce(maxSumThread, {x[id + i * stride], 1});// reduce the data to one block
         }
-        using BlockReduce = cub::BlockReduce<MaxSum, BLOCK_DIM>;
-        __shared__ typename BlockReduce::TempStorage tempStorage;
         __shared__ MaxSum maxSumTotal;
-        auto maxSumBlock = BlockReduce(tempStorage).Reduce(maxSumThread, MaxSum::reduce);
+        auto maxSumBlock = cuda::blockReduce(maxSumThread, {-__FLT_MAX__, 0}, MaxSum::reduce);
         if (threadIdx.x == 0) {
             maxSumTotal = maxSumBlock;// must set threadIdx.x = 0 write the output to memory
         }
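
The new init argument is the identity element for MaxSum::reduce: merging any state with {-__FLT_MAX__, 0} leaves it unchanged, because the zero sum vanishes after rescaling. For reference, a hedged sketch of an online-softmax (max, sum) pair with such a combiner — MaxSumSketch is a hypothetical stand-in, and the repository's actual MaxSum may differ:

// Hypothetical stand-in for the repository's MaxSum.
struct MaxSumSketch {
    float max, sum;

    // Merge two partial (max, sum of exp) states, rescaling the smaller
    // side so both sums are expressed relative to the common maximum.
    static __device__ __forceinline__ MaxSumSketch reduce(MaxSumSketch a, MaxSumSketch b) {
        return a.max > b.max
                   ? MaxSumSketch{a.max, a.sum + b.sum * expf(b.max - a.max)}
                   : MaxSumSketch{b.max, b.sum + a.sum * expf(a.max - b.max)};
    }
};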
@@ -113,7 +111,7 @@ namespace refactor::kernel {
         auto y = reinterpret_cast<T *>(outputs[0]);
         int numBlocks = info.pre * info.post;
         if (info.mid > 1024) {
-            blockSoftmaxKernel<1024><<<numBlocks, 1024>>>(x, y, info.mid, info.post);
+            blockSoftmaxKernel<<<numBlocks, 1024>>>(x, y, info.mid, info.post);
         } else {
             int blockDimX, mid = static_cast<int>(info.mid);
             for (blockDimX = 32; blockDimX > 4 && mid < blockDimX; blockDimX /= 2) {}
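
This dispatch change is the payoff of the refactor: cub::BlockReduce sizes its shared TempStorage from the compile-time BLOCK_DIM, so each block size required its own instantiation, while the warp-based helper reserves scratch for the 32-warp worst case and accepts whatever block size the launch supplies. A hedged sketch of the pattern this enables — a fragment borrowing info, numBlocks, x, and y from the routine above, not code from this commit:

// Pick the block size at run time; one compiled instance serves them all.
// cuda::blockReduce expects a multiple of 32, and the kernel's first load
// assumes blockDim.x <= mid, so round mid down (valid for mid >= 32).
int threads = std::min(1024, static_cast<int>(info.mid) / 32 * 32);// needs <algorithm>
blockSoftmaxKernel<<<numBlocks, threads>>>(x, y, info.mid, info.post);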
19 changes: 13 additions & 6 deletions src/04kernel/test/kernels/softmax/test_cuda.cpp
@@ -4,18 +4,19 @@
#include "../../../src/kernels/softmax/cuda_kernel.hh"
#include "hardware/device_manager.h"
#include <gtest/gtest.h>
#include <numeric>

using namespace refactor;
using namespace kernel;
using namespace hardware;

TEST(kernel, SoftmaxCuda) {
static void test(Shape shape, int axis) {
// build routine
auto xTensor = Tensor::share(DataType::F32, Shape{2, 3, 2, 5, 4});
auto outTensor = Tensor::share(DataType::F32, Shape{2, 3, 2, 5, 4});
dim_t axis = 1;
auto kCpu = SoftmaxCpu::build(SoftmaxInfo(*xTensor, axis));
auto kCuda = SoftmaxCuda::build(SoftmaxInfo(*xTensor, axis));
auto xTensor = Tensor::share(DataType::F32, shape);
auto outTensor = Tensor::share(DataType::F32, shape);
SoftmaxInfo info(*xTensor, axis);
auto kCpu = SoftmaxCpu::build(info);
auto kCuda = SoftmaxCuda::build(info);
ASSERT_TRUE(kCpu && kCuda);
auto res = runtime::Resources();
auto rCpu = kCpu->lower(res).routine;
@@ -28,6 +29,7 @@ TEST(kernel, SoftmaxCuda) {
     std::vector<float>
         data(xTensor->elementsSize(), 0),
         cpuOut(outTensor->elementsSize());
+    std::iota(data.begin(), data.end(), 0);
     gpuIn->copyFromHost(data.data(), xTensor->bytesSize());
     // inference
     {
@@ -49,4 +51,9 @@
     }
 }
 
+TEST(kernel, SoftmaxCuda) {
+    test({2, 3, 2, 5, 4}, 1);
+    test({2, 2048, 2, 5, 4}, 1);
+}
+
 #endif
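
With axis = 1, the second shape has mid = 2048 > 1024, so it exercises the now-untemplated blockSoftmaxKernel path, while the first shape (mid = 3) takes the smaller-block branch chosen by the blockDimX loop; the std::iota fill makes the inputs distinct, so the CPU and CUDA outputs are compared on non-trivial values.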
