temp: try to keep using cub reduce
Signed-off-by: YdrMaster <ydrml@hotmail.com>
YdrMaster committed Feb 1, 2024
1 parent 26f4b19 commit fe1ffc6
Showing 4 changed files with 37 additions and 47 deletions.
31 changes: 0 additions & 31 deletions src/04kernel/cuda/include/kernel/cuda/reduce.cuh

This file was deleted.
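
The deleted header itself is not shown in this diff; the cuda::blockReduce calls removed below suggest it provided a block-wide reduce wrapper roughly of the following shape. This is an illustrative guess only, not the file's actual contents:

// Illustrative guess at the removed helper, not the actual deleted code.
// It wraps cub::BlockReduce and broadcasts the block-wide result to all threads.
#include <cub/block/block_reduce.cuh>

namespace refactor::kernel::cuda {

    template<class T, class ReductionOp>
    __device__ __forceinline__ T blockReduce(T x, T init, ReductionOp op) {
        using BlockReduce = cub::BlockReduce<T, 1024>;
        __shared__ typename BlockReduce::TempStorage tempStorage;
        __shared__ T result;
        // `init` presumably seeded lanes without a valid element; here every
        // thread is assumed to contribute an identity-valued input instead.
        (void) init;
        auto acc = BlockReduce(tempStorage).Reduce(x, op);
        if (threadIdx.x == 0) { result = acc; }
        __syncthreads();
        return result;
    }

}// namespace refactor::kernel::cuda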

24 changes: 17 additions & 7 deletions src/04kernel/src/kernels/attention/cuda_kernel.cu
@@ -1,7 +1,8 @@
#include "../../utilities/cuda/cublaslt_utils.cuh"
#include "cuda_kernel.hh"
#include "hardware/functions.h"
#include "kernel/cuda/reduce.cuh"
#include "kernel/cuda/functions.cuh"
#include <cub/block/block_reduce.cuh>

namespace refactor::kernel {
using K = AttentionCuda;
@@ -27,7 +28,7 @@ namespace refactor::kernel {

// gridDim.x = batch * nHead
// gridDim.y = seqLen
-// blockDim.x = min(1024, attLen)
+// blockDim.x = 1024
// sizeof(shared) = attLen * sizeof(float)
template<class T, class Mask>
static __global__ void softmax(
@@ -36,25 +37,34 @@ namespace refactor::kernel {
uint32_t attLen,
uint32_t bufLen) {
// locate the attention region this thread block is responsible for
-att += (blockIdx.x * gridDim.x + gridDim.y) * bufLen;
+att += (blockIdx.x * gridDim.x + blockIdx.y) * bufLen;
// load the input into shared memory, casting to float and applying the mask
extern __shared__ float shared[];// size = attLen = pastSeqLen + seqLen
for (auto i = threadIdx.x; i < attLen; i += blockDim.x) {
shared[i] = mask(blockIdx.y, gridDim.y, i, attLen) ? float(att[i]) : -__FLT_MAX__;
}

+using BlockReduce = cub::BlockReduce<float, 1024>;
+__shared__ typename BlockReduce::TempStorage tempStorage;
+__shared__ float sharedMax, sharedSum;

float localMax = -1e20;
for (auto i = threadIdx.x; i < attLen; i += blockDim.x) {
localMax = cub::Max()(localMax, shared[i]);
}
-localMax = cuda::blockReduce(localMax, -1e20f, cub::Max());
+localMax = BlockReduce(tempStorage).Reduce(localMax, cub::Max(), attLen);
+if (threadIdx.x == 0) { sharedMax = localMax; }
+__syncthreads();

float localSum = 1e-20;
for (auto i = threadIdx.x; i < attLen; i += blockDim.x) {
-localSum += shared[i] = expf(shared[i] - localMax);
+localSum += shared[i] = expf(shared[i] - sharedMax);
}
-localSum = cuda::blockReduce(localSum, 1e-20f, cub::Sum());
-auto reciprocal = fdividef(1, localSum);
+localSum = BlockReduce(tempStorage).Reduce(localSum, cub::Sum(), attLen);
+if (threadIdx.x == 0) { sharedSum = localSum; }
+__syncthreads();
+
+auto reciprocal = fdividef(1, sharedSum);
for (auto i = threadIdx.x; i < attLen; i += blockDim.x) {
att[i] = shared[i] * reciprocal;
}
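
Not part of this commit: a minimal, self-contained sketch of the reduce pattern the new kernel adopts, so the BlockReduce/broadcast sequence can be read outside the diff. The kernel name toySoftmaxRow and its launch shape are made up for illustration; the repository's kernel above additionally applies the mask and works on a strided attention buffer.

#include <cub/block/block_reduce.cuh>
#include <cfloat>

// One block normalizes one contiguous row of `len` floats in place:
// block-wide max, broadcast, exponentiate, block-wide sum, broadcast, scale.
template<int BLOCK_DIM = 1024>
__global__ void toySoftmaxRow(float *data, int len) {
    using BlockReduce = cub::BlockReduce<float, BLOCK_DIM>;
    __shared__ typename BlockReduce::TempStorage tempStorage;
    __shared__ float sharedMax, sharedSum;

    auto row = data + blockIdx.x * len;// each block owns one row

    float localMax = -FLT_MAX;
    for (int i = threadIdx.x; i < len; i += blockDim.x) {
        localMax = fmaxf(localMax, row[i]);
    }
    localMax = BlockReduce(tempStorage).Reduce(localMax, cub::Max());
    if (threadIdx.x == 0) { sharedMax = localMax; }
    __syncthreads();// publishes sharedMax and frees tempStorage for reuse

    float localSum = 0;
    for (int i = threadIdx.x; i < len; i += blockDim.x) {
        localSum += row[i] = expf(row[i] - sharedMax);
    }
    localSum = BlockReduce(tempStorage).Reduce(localSum, cub::Sum());
    if (threadIdx.x == 0) { sharedSum = localSum; }
    __syncthreads();

    float reciprocal = 1.f / sharedSum;
    for (int i = threadIdx.x; i < len; i += blockDim.x) {
        row[i] *= reciprocal;
    }
}

// e.g. toySoftmaxRow<1024><<<numRows, 1024>>>(data, rowLen);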
12 changes: 7 additions & 5 deletions src/04kernel/src/kernels/softmax/cuda_kernel.cu
@@ -1,5 +1,5 @@
#include "cuda_kernel.hh"
#include "kernel/cuda/reduce.cuh"
#include <cub/cub.cuh>

namespace refactor::kernel {
using namespace runtime;
@@ -18,8 +18,8 @@ namespace refactor::kernel {
template<> __device__ __forceinline__ nv_bfloat16 reciprocal<nv_bfloat16>(nv_bfloat16 x) { return hrcp(x); }

// blockDim.x === BLOCK_DIM
-template<class T>
-__global__ void blockSoftmaxKernel(
+template<int BLOCK_DIM, class T>
+__launch_bounds__(BLOCK_DIM) __global__ void blockSoftmaxKernel(
T const *__restrict x,
T *__restrict y,
int mid,
@@ -40,8 +40,10 @@ namespace refactor::kernel {
for (int i = threadIdx.x + blockDim.x; i < mid; i += blockDim.x) {
maxSumThread = MaxSum::reduce(maxSumThread, {x[id + i * stride], 1});// reduce the data to one block
}
+using BlockReduce = cub::BlockReduce<MaxSum, BLOCK_DIM>;
+__shared__ typename BlockReduce::TempStorage tempStorage;
__shared__ MaxSum maxSumTotal;
-auto maxSumBlock = cuda::blockReduce(maxSumThread, {-__FLT_MAX__, 0}, MaxSum::reduce);
+auto maxSumBlock = BlockReduce(tempStorage).Reduce(maxSumThread, MaxSum::reduce);
if (threadIdx.x == 0) {
maxSumTotal = maxSumBlock;// only threadIdx.x == 0 writes the block result to shared memory
}
@@ -111,7 +113,7 @@ namespace refactor::kernel {
auto y = reinterpret_cast<T *>(outputs[0]);
int numBlocks = info.pre * info.post;
if (info.mid > 1024) {
-blockSoftmaxKernel<<<numBlocks, 1024>>>(x, y, info.mid, info.post);
+blockSoftmaxKernel<1024><<<numBlocks, 1024>>>(x, y, info.mid, info.post);
} else {
int blockDimX, mid = static_cast<int>(info.mid);
for (blockDimX = 32; blockDimX > 4 && mid < blockDimX; blockDimX /= 2) {}
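
The hunk above keeps using MaxSum and MaxSum::reduce, which are defined earlier in this file and not shown in the diff. For readability, a plausible shape of such an online max/sum combiner is sketched below; the repository's actual definition may differ.

// Illustrative online-softmax accumulator: it tracks the running maximum and
// the sum of exp(x - max); merging two partials rescales the smaller-max side.
struct MaxSum {
    float max, sum;

    static __device__ __forceinline__ MaxSum reduce(MaxSum a, MaxSum b) {
        return a.max > b.max
                   ? MaxSum{a.max, a.sum + b.sum * expf(b.max - a.max)}
                   : MaxSum{b.max, b.sum + a.sum * expf(a.max - b.max)};
    }
};

Per-element accumulation in the loop above then matches this shape: folding {x[i], 1} into the running pair adds exp(x[i] - max) to the sum.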
17 changes: 13 additions & 4 deletions src/04kernel/test/kernels/attention/test_cuda.cpp
@@ -13,7 +13,7 @@ using namespace hardware;
TEST(kernel, AttentionCudaNoKvCache) {
// build routine
AttentionInfo info{
-.dataType = DataType::FP16,
+.dataType = DataType::F32,
.batch = 1,
.nHead = 4,
.nKVHead = 4,
@@ -23,9 +23,9 @@ TEST(kernel, AttentionCudaNoKvCache) {
.concatCache = false,
.resetCache = false,
};
-auto q = Tensor::share(DataType::FP16, Shape{info.batch, info.nHead, info.seqLen, info.headDim}),
-k = Tensor::share(DataType::FP16, Shape{info.batch, info.nKVHead, info.seqLen, info.headDim}),
-v = Tensor::share(DataType::FP16, Shape{info.batch, info.nKVHead, info.seqLen, info.headDim}),
+auto q = Tensor::share(DataType::F32, Shape{info.batch, info.nHead, info.seqLen, info.headDim}),
+k = Tensor::share(DataType::F32, Shape{info.batch, info.nKVHead, info.seqLen, info.headDim}),
+v = Tensor::share(DataType::F32, Shape{info.batch, info.nKVHead, info.seqLen, info.headDim}),
o = q;
auto kernel = AttentionCuda::build(info);
ASSERT_TRUE(kernel);
@@ -38,6 +38,15 @@ TEST(kernel, AttentionCudaNoKvCache) {
vGpu = dev.malloc(v->bytesSize()),
oGpu = dev.malloc(o->bytesSize()),
workspace = dev.malloc(workspaceSize);
+// put input data
+std::vector<float>
+q_(q->elementsSize(), 1),
+k_(k->elementsSize(), 1),
+v_(v->elementsSize(), 1),
+o_(o->elementsSize());
+qGpu->copyFromHost(q_.data());
+kGpu->copyFromHost(k_.data());
+vGpu->copyFromHost(v_.data());
// inference
{
void const *inputs[]{*qGpu, *kGpu, *vGpu};
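
A note on what the new all-ones inputs imply, with a possible check that is not part of this diff: with Q, K and V filled with 1, every attention row is a softmax over identical scores, so the weights are uniform and the output equals the all-ones V, meaning every element of o should come back as 1. The sketch below assumes the blob API offers a copyToHost counterpart to the copyFromHost calls above, which this diff does not show.

// Hypothetical check; `copyToHost` is an assumption, not shown in this diff.
oGpu->copyToHost(o_.data());
for (auto x : o_) { EXPECT_FLOAT_EQ(x, 1.f); }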
