Commit
feat(kernel): start implementing softmax
Signed-off-by: YdrMaster <ydrml@hotmail.com>
YdrMaster committed Jan 31, 2024
1 parent 61c7ddb commit 614c6a5
Showing 1 changed file with 28 additions and 15 deletions.
43 changes: 28 additions & 15 deletions src/04kernel/src/kernels/attention/cuda_kernel.cu
@@ -7,26 +7,39 @@ namespace refactor::kernel {
using K = AttentionCuda;
using namespace cublas;

static __forceinline__ __device__ bool mask(int tokid, int posid) {
return true;
// Causal attention mask.
// tokenId: index of this token within the current pass
// seqLen: number of tokens processed in this pass
// posId: position in the kv cache
// attLen = pastSeqLen + seqLen
static __forceinline__ __device__ bool
causualMask(int tokenId, int seqLen,
int posId, int attLen) {
// tokenId ↓ |<---attLen---->|
// 0 | * * ... * |
// 1 | * * ... * * |
// 2 | * * ... * * * |
// seqLen: 3 |---------------|
return attLen + tokenId >= posId + seqLen;
}
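// Illustration only — not part of this commit. The inequality above is equivalent to
// posId <= pastSeqLen + tokenId: a token attends to every cached position plus the new
// tokens up to and including itself. A hypothetical host-side mirror of the predicate
// makes this easy to verify at compile time (here pastSeqLen = 2, seqLen = 3, attLen = 5):
constexpr bool causalMaskRef(int tokenId, int seqLen, int posId, int attLen) {
return attLen + tokenId >= posId + seqLen;
}
static_assert(causalMaskRef(0, 3, 2, 5) && !causalMaskRef(0, 3, 3, 5), "token 0 sees cached positions 0..2 only");
static_assert(causalMaskRef(2, 3, 4, 5), "the last token sees every position");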

// gridDim.x = batch * nHead
// gridDim.y = seqLen
template<class T, class Mask>
// blockDim.x = min(1024, attLen)
template<class T>
static __global__ void softmax(
T *__restrict__ attention,
Mask mask,
uint32_t seqLen,
T *__restrict__ att,
uint32_t attLen,
uint32_t bufLen) {
// int offset = (blockIdx.x * len_q + blockIdx.y) * len_buf;
// SharedMemory<float> shared;
// float *smem = shared.getPointer();
// locate the attention row handled by this thread block
att += (blockIdx.x * gridDim.y + blockIdx.y) * bufLen;
// load the row into shared memory, casting to float and applying the causal mask
extern __shared__ float shared[]; // size = attLen = pastSeqLen + seqLen
for (auto i = threadIdx.x; i < attLen; i += blockDim.x) {
shared[i] = causualMask(blockIdx.y, gridDim.y, i, attLen) ? float(att[i]) : -__FLT_MAX__;
}
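// ——— Illustration only: not part of this commit. ———
// A minimal sketch of how the softmax itself could continue from here, assuming
// blockDim.x <= 1024 so a fixed-size buffer can hold one partial result per thread:
// 1) block-wide max for numerical stability, 2) exponentiate in shared memory while
// accumulating a block-wide sum, 3) normalize and cast back to T.
__shared__ float reduce[1024];

float localMax = -__FLT_MAX__;
for (auto i = threadIdx.x; i < attLen; i += blockDim.x) {
localMax = fmaxf(localMax, shared[i]);
}
reduce[threadIdx.x] = localMax;
__syncthreads();
// interleaved tree reduction; also works for non-power-of-two block sizes
for (unsigned stride = 1; stride < blockDim.x; stride *= 2) {
if (threadIdx.x % (2 * stride) == 0 && threadIdx.x + stride < blockDim.x) {
reduce[threadIdx.x] = fmaxf(reduce[threadIdx.x], reduce[threadIdx.x + stride]);
}
__syncthreads();
}
float rowMax = reduce[0];
__syncthreads();// reduce[] is reused below

float localSum = 0;
for (auto i = threadIdx.x; i < attLen; i += blockDim.x) {
shared[i] = expf(shared[i] - rowMax);
localSum += shared[i];
}
reduce[threadIdx.x] = localSum;
__syncthreads();
for (unsigned stride = 1; stride < blockDim.x; stride *= 2) {
if (threadIdx.x % (2 * stride) == 0 && threadIdx.x + stride < blockDim.x) {
reduce[threadIdx.x] += reduce[threadIdx.x + stride];
}
__syncthreads();
}
float rowSum = reduce[0];

// normalize and cast back; positions beyond attLen (if any) are left untouched here
for (auto i = threadIdx.x; i < attLen; i += blockDim.x) {
att[i] = T(shared[i] / rowSum);
}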

// for (int i = threadIdx.x; i < len_buf; i += blockDim.x) {
// T pb = (position_bias == nullptr) ? T(0.) : position_bias[offset + i];
// smem[i] = mask[blockIdx.y * len_buf + i] > 0 ? x[offset + i] * scale + pb : -Inf<T>();
// }
// float local_max = -1e20;
// for (int i = threadIdx.x; i < len_buf; i += blockDim.x) {
// local_max = fmaxf(local_max, smem[i]);
@@ -125,7 +138,7 @@ namespace refactor::kernel {
auto k = inputs[1];
auto v = inputs[2];
auto o = outputs[0];
auto att = workspace;
auto att = reinterpret_cast<half *>(workspace);
auto workspaceQK = reinterpret_cast<uint8_t *>(workspace) + hardware::alignBytes(d->attSize, 256);
auto workspaceAV = workspaceQK + hardware::alignBytes(d->workspaceSizeQK, 256);
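// Workspace layout (assuming hardware::alignBytes rounds its first argument up to a
// multiple of the given 256-byte alignment):
// [ att: attention scores | QK matmul workspace | AV matmul workspace ]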

@@ -143,7 +156,7 @@
cudaStreamLegacy);

softmax<<<dim3(info.batch * info.nHead, info.seqLen), info.seqLen < 1024 ? info.seqLen : 1024, info.seqLen * sizeof(float)>>>(
att, mask, info.seqLen, info.seqLen);
att, info.seqLen, info.seqLen);
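// (the third launch parameter sizes the dynamic shared memory: one float per attention
// position, matching the extern __shared__ buffer in the kernel above)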

cublasLtMatmul(
handle, d->mul.get(),
