fix(kernel): fix attention memory access error
Signed-off-by: YdrMaster <ydrml@hotmail.com>
YdrMaster committed Feb 1, 2024
1 parent d2d43a3 commit 26f4b19
Showing 2 changed files with 17 additions and 13 deletions.
28 changes: 15 additions & 13 deletions src/04kernel/src/kernels/attention/cuda_kernel.cu
@@ -12,25 +12,27 @@ namespace refactor::kernel {
     // seqLen: number of tokens processed in this pass
     // posId: position in the kv cache
     // attLen = pastSeqLen + seqLen
-    static __forceinline__ __device__ bool
-    causualMask(int tokenId, int seqLen,
-                int posId, int attLen) {
-        // tokenId ↓ |<---attLen---->|
-        //         0 | * * ... *     |
-        //         1 | * * ... * *   |
-        //         2 | * * ... * * * |
-        // seqLen: 3 |---------------|
-        return attLen + tokenId >= posId + seqLen;
-    }
+    struct AttentionCausualMask {
+        __forceinline__ __device__ bool
+        operator()(int tokenId, int seqLen,
+                   int posId, int attLen) {
+            // tokenId ↓ |<---attLen---->|
+            //         0 | * * ... *     |
+            //         1 | * * ... * *   |
+            //         2 | * * ... * * * |
+            // seqLen: 3 |---------------|
+            return attLen + tokenId >= posId + seqLen;
+        }
+    };

     // gridDim.x = batch * nHead
     // gridDim.y = seqLen
     // blockDim.x = min(1024, attLen)
     // sizeof(shared) = attLen * sizeof(float)
-    template<class T>
+    template<class T, class Mask>
     static __global__ void softmax(
         T *__restrict__ att,
-        bool (*mask)(int, int, int, int),
+        Mask mask,
         uint32_t attLen,
         uint32_t bufLen) {
         // locate the attention region for this thread block
@@ -161,7 +163,7 @@ namespace refactor::kernel {
             std::min(1024u, attLen),
             attLen * sizeof(float),
             stream>>>(
-            att, causualMask, attLen, bufLen);
+            att, AttentionCausualMask(), attLen, bufLen);
         {
             half alpha = 1, beta = 0;
             cublasLtMatmul(
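
Why this fixes the memory access error: `causualMask` is a `__device__` function, but the old code took its address on the host and passed it into the kernel as a plain `bool (*)(int, int, int, int)` argument. A host-side address of a `__device__` function is not callable on the GPU, so the call through `mask` inside `softmax` faulted. Replacing the pointer parameter with a `Mask` template parameter and passing a stateless functor by value lets the compiler resolve (and inline) the call when the template is instantiated, with no device-side pointer chase. A minimal sketch of the pattern, with hypothetical names rather than code from this repository:

#include <cuda_runtime.h>

// The broken shape, for contrast (do not do this):
//   __global__ void softmax(..., bool (*mask)(int, int, int, int), ...);
//   softmax<<<grid, block, shm, stream>>>(..., causualMask, ...);
// Depending on the toolchain this may even compile with only a warning,
// but the pointer is a host address and calling through it on the device
// is an illegal memory access.

struct IsEven {
    __forceinline__ __device__ bool operator()(int x) const {
        return x % 2 == 0;
    }
};

template<class Pred>
__global__ void filter(int const *in, bool *out, int n, Pred pred) {
    int i = blockIdx.x * blockDim.x + threadIdx.x;
    if (i < n) {
        out[i] = pred(in[i]);// resolved at compile time, no indirection
    }
}

// Usage: filter<<<blocks, threads>>>(dIn, dOut, n, IsEven{});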
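Note on the predicate itself (unchanged by this commit): token `tokenId` of the `seqLen` new tokens sits at absolute position `attLen - seqLen + tokenId` in the cache, and may attend any key position at or before it, i.e. `posId <= attLen - seqLen + tokenId`, which rearranges to the `attLen + tokenId >= posId + seqLen` in the code. A host-side sanity check of that equivalence (illustrative only, not part of the commit):

#include <cassert>

// Mirrors AttentionCausualMask::operator() above.
constexpr bool causalMask(int tokenId, int seqLen, int posId, int attLen) {
    return attLen + tokenId >= posId + seqLen;
}

int main() {
    // seqLen = 3, pastSeqLen = 2, so attLen = 5: token 0 sees positions
    // 0..2, token 1 sees 0..3, token 2 sees 0..4 — the staircase in the
    // diagram above.
    constexpr int seqLen = 3, attLen = 5;
    for (int tokenId = 0; tokenId < seqLen; ++tokenId)
        for (int posId = 0; posId < attLen; ++posId)
            assert(causalMask(tokenId, seqLen, posId, attLen)
                   == (posId <= attLen - seqLen + tokenId));
}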
2 changes: 2 additions & 0 deletions src/04kernel/test/kernels/attention/test_cuda.cpp
@@ -2,6 +2,7 @@
 
 #include "../../../src/kernels/attention/cuda_kernel.hh"
 #include "hardware/device_manager.h"
+#include "kernel/cuda/functions.cuh"
 #include <gtest/gtest.h>
 #include <numeric>
 
@@ -43,6 +44,7 @@ TEST(kernel, AttentionCudaNoKvCache) {
         void *outputs[]{*oGpu};
         routine(res, *workspace, inputs, outputs);
     }
+    cuda::sync();
 }
 
 #endif
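
Why the test needs `cuda::sync()`: kernel launches are asynchronous, so without a device synchronization the test body can finish before the kernel actually runs, and a kernel-side fault like the one fixed above either goes unreported or surfaces in an unrelated later call. Synchronizing inside the test pins the failure to this test. `cuda::sync()` comes from the newly included `kernel/cuda/functions.cuh`; its definition is not shown in this diff, but it is presumably a checked wrapper in this spirit (an assumption, not the repository's actual code):

#include <cuda_runtime.h>
#include <cstdio>
#include <cstdlib>

namespace cuda {

    // Assumed shape of the helper: block until all queued device work is
    // done, and turn any deferred kernel fault into an error at this site.
    inline void sync() {
        if (auto err = cudaDeviceSynchronize(); err != cudaSuccess) {
            std::fprintf(stderr, "CUDA error: %s\n", cudaGetErrorString(err));
            std::abort();
        }
    }

}// namespace cuda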
