From fe1ffc64c4d3c250a592d321e90e11f8bf0605aa Mon Sep 17 00:00:00 2001
From: YdrMaster
Date: Thu, 1 Feb 2024 14:50:57 +0800
Subject: [PATCH] temp: try to keep using cub reduce
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Signed-off-by: YdrMaster
---
 .../cuda/include/kernel/cuda/reduce.cuh       | 31 -------------------
 .../src/kernels/attention/cuda_kernel.cu      | 24 +++++++++-----
 .../src/kernels/softmax/cuda_kernel.cu        | 12 ++++---
 .../test/kernels/attention/test_cuda.cpp      | 17 +++++++---
 4 files changed, 37 insertions(+), 47 deletions(-)
 delete mode 100644 src/04kernel/cuda/include/kernel/cuda/reduce.cuh

diff --git a/src/04kernel/cuda/include/kernel/cuda/reduce.cuh b/src/04kernel/cuda/include/kernel/cuda/reduce.cuh
deleted file mode 100644
index 6a5be4a3..00000000
--- a/src/04kernel/cuda/include/kernel/cuda/reduce.cuh
+++ /dev/null
@@ -1,31 +0,0 @@
-#ifndef KERNEL_CUDA_REDUCE_CUH
-#define KERNEL_CUDA_REDUCE_CUH
-
-#include <cub/warp/warp_reduce.cuh>
-
-namespace refactor::kernel::cuda {
-
-    template<class T, class ReductionOp>
-    __inline__ __device__ T blockReduce(T x, T init, ReductionOp op) {
-        using WarpReduce = cub::WarpReduce<T>;
-        __shared__ typename WarpReduce::TempStorage tempStorage;
-        __shared__ T shared[32], ans;
-
-        auto reduce = WarpReduce(tempStorage);
-        int lane = threadIdx.x % 32;
-        int wid = threadIdx.x / 32;
-        x = reduce.Reduce(x, op);
-        if (lane == 0) { shared[wid] = x; }
-        __syncthreads();
-        if (wid == 0) {
-            x = (threadIdx.x < blockDim.x / 32) ? shared[lane] : init;
-            shared[lane] = reduce.Reduce(x, op);
-            if (lane == 0) { ans = shared[0]; }
-        }
-        __syncthreads();
-        return ans;// avoid RAW hazard
-    }
-
-}// namespace refactor::kernel::cuda
-
-#endif// KERNEL_CUDA_REDUCE_CUH
diff --git a/src/04kernel/src/kernels/attention/cuda_kernel.cu b/src/04kernel/src/kernels/attention/cuda_kernel.cu
index c66c2c43..1196e753 100644
--- a/src/04kernel/src/kernels/attention/cuda_kernel.cu
+++ b/src/04kernel/src/kernels/attention/cuda_kernel.cu
@@ -1,7 +1,8 @@
 #include "../../utilities/cuda/cublaslt_utils.cuh"
 #include "cuda_kernel.hh"
 #include "hardware/functions.h"
-#include "kernel/cuda/reduce.cuh"
+#include "kernel/cuda/functions.cuh"
+#include <cub/block/block_reduce.cuh>
 
 namespace refactor::kernel {
     using K = AttentionCuda;
@@ -27,7 +28,7 @@
 
     // gridDim.x = batch * nHead
     // gridDim.y = seqLen
-    // blockDim.x = min(1024, attLen)
+    // blockDim.x = 1024
     // sizeof(shared) = attLen * sizeof(float)
     template<class T>
     static __global__ void softmax(
@@ -36,25 +37,34 @@ namespace refactor::kernel {
         uint32_t attLen,
         uint32_t bufLen) {
         // locate the attention region for this thread block
-        att += (blockIdx.x * gridDim.x + gridDim.y) * bufLen;
+        att += (blockIdx.x * gridDim.x + blockIdx.y) * bufLen;
         // load the input into shared memory and apply cast + mask
         extern __shared__ float shared[];// size = attLen = pastSeqLen + seqLen
         for (auto i = threadIdx.x; i < attLen; i += blockDim.x) {
             shared[i] = mask(blockIdx.y, gridDim.y, i, attLen)
                             ? float(att[i])
                             : -__FLT_MAX__;
         }
+        using BlockReduce = cub::BlockReduce<float, 1024>;
+        __shared__ typename BlockReduce::TempStorage tempStorage;
+        __shared__ float sharedMax, sharedSum;
         float localMax = -1e20;
         for (auto i = threadIdx.x; i < attLen; i += blockDim.x) {
             localMax = cub::Max()(localMax, shared[i]);
         }
-        localMax = cuda::blockReduce(localMax, -1e20f, cub::Max());
+        localMax = BlockReduce(tempStorage).Reduce(localMax, cub::Max(), attLen);
+        if (threadIdx.x == 0) { sharedMax = localMax; }
+        __syncthreads();
+
         float localSum = 1e-20;
         for (auto i = threadIdx.x; i < attLen; i += blockDim.x) {
-            localSum += shared[i] = expf(shared[i] - localMax);
+            localSum += shared[i] = expf(shared[i] - sharedMax);
         }
-        localSum = cuda::blockReduce(localSum, 1e-20f, cub::Sum());
-        auto reciprocal = fdividef(1, localSum);
+        localSum = BlockReduce(tempStorage).Reduce(localSum, cub::Sum(), attLen);
+        if (threadIdx.x == 0) { sharedSum = localSum; }
+        __syncthreads();
+
+        auto reciprocal = fdividef(1, sharedSum);
         for (auto i = threadIdx.x; i < attLen; i += blockDim.x) {
             att[i] = shared[i] * reciprocal;
         }
diff --git a/src/04kernel/src/kernels/softmax/cuda_kernel.cu b/src/04kernel/src/kernels/softmax/cuda_kernel.cu
index c5ecb821..114cb453 100644
--- a/src/04kernel/src/kernels/softmax/cuda_kernel.cu
+++ b/src/04kernel/src/kernels/softmax/cuda_kernel.cu
@@ -1,5 +1,5 @@
 #include "cuda_kernel.hh"
-#include "kernel/cuda/reduce.cuh"
+#include <cub/block/block_reduce.cuh>
 
 namespace refactor::kernel {
     using namespace runtime;
@@ -18,8 +18,8 @@
     template<> __device__ __forceinline__ nv_bfloat16 reciprocal(nv_bfloat16 x) { return hrcp(x); }
 
     // blockDim.x === BLOCK_DIM
-    template<class T>
-    __global__ void blockSoftmaxKernel(
+    template<int BLOCK_DIM, class T>
+    __launch_bounds__(BLOCK_DIM) __global__ void blockSoftmaxKernel(
         T const *__restrict x,
         T *__restrict y,
         int mid,
@@ -40,8 +40,10 @@
         for (int i = threadIdx.x + blockDim.x; i < mid; i += blockDim.x) {
             maxSumThread = MaxSum::reduce(maxSumThread, {x[id + i * stride], 1});// reduce the data to one block
         }
+        using BlockReduce = cub::BlockReduce<MaxSum, BLOCK_DIM>;
+        __shared__ typename BlockReduce::TempStorage tempStorage;
         __shared__ MaxSum maxSumTotal;
-        auto maxSumBlock = cuda::blockReduce(maxSumThread, {-__FLT_MAX__, 0}, MaxSum::reduce);
+        auto maxSumBlock = BlockReduce(tempStorage).Reduce(maxSumThread, MaxSum::reduce);
         if (threadIdx.x == 0) {
             maxSumTotal = maxSumBlock;// must set threadIdx.x = 0 write the output to memory
         }
@@ -111,7 +113,7 @@
         auto y = reinterpret_cast<T *>(outputs[0]);
         int numBlocks = info.pre * info.post;
         if (info.mid > 1024) {
-            blockSoftmaxKernel<<<numBlocks, 1024>>>(x, y, info.mid, info.post);
+            blockSoftmaxKernel<1024><<<numBlocks, 1024>>>(x, y, info.mid, info.post);
         } else {
             int blockDimX, mid = static_cast<int>(info.mid);
             for (blockDimX = 32; blockDimX > 4 && mid < blockDimX; blockDimX /= 2) {}
diff --git a/src/04kernel/test/kernels/attention/test_cuda.cpp b/src/04kernel/test/kernels/attention/test_cuda.cpp
index 64555f16..794ae174 100644
--- a/src/04kernel/test/kernels/attention/test_cuda.cpp
+++ b/src/04kernel/test/kernels/attention/test_cuda.cpp
@@ -13,7 +13,7 @@ using namespace hardware;
 TEST(kernel, AttentionCudaNoKvCache) {
     // build routine
     AttentionInfo info{
-        .dataType = DataType::FP16,
+        .dataType = DataType::F32,
         .batch = 1,
         .nHead = 4,
         .nKVHead = 4,
@@ -23,9 +23,9 @@ TEST(kernel, AttentionCudaNoKvCache) {
         .concatCache = false,
         .resetCache = false,
     };
-    auto q = Tensor::share(DataType::FP16, Shape{info.batch, info.nHead, info.seqLen, info.headDim}),
-         k = Tensor::share(DataType::FP16, Shape{info.batch, info.nKVHead, info.seqLen, info.headDim}),
-         v = Tensor::share(DataType::FP16, Shape{info.batch, info.nKVHead, info.seqLen, info.headDim}),
+    auto q = Tensor::share(DataType::F32, Shape{info.batch, info.nHead, info.seqLen, info.headDim}),
+         k = Tensor::share(DataType::F32, Shape{info.batch, info.nKVHead, info.seqLen, info.headDim}),
+         v = Tensor::share(DataType::F32, Shape{info.batch, info.nKVHead, info.seqLen, info.headDim}),
          o = q;
     auto kernel = AttentionCuda::build(info);
     ASSERT_TRUE(kernel);
@@ -38,6 +38,15 @@ TEST(kernel, AttentionCudaNoKvCache) {
          vGpu = dev.malloc(v->bytesSize()),
          oGpu = dev.malloc(o->bytesSize()),
          workspace = dev.malloc(workspaceSize);
+    // put input data
+    std::vector<float>
+        q_(q->elementsSize(), 1),
+        k_(k->elementsSize(), 1),
+        v_(v->elementsSize(), 1),
+        o_(o->elementsSize());
+    qGpu->copyFromHost(q_.data());
+    kGpu->copyFromHost(k_.data());
+    vGpu->copyFromHost(v_.data());
     // inference
     {
         void const *inputs[]{*qGpu, *kGpu, *vGpu};