diff --git a/src/04kernel/src/kernels/attention/cuda_kernel.cu b/src/04kernel/src/kernels/attention/cuda_kernel.cu
index 7d018607..b976ac2c 100644
--- a/src/04kernel/src/kernels/attention/cuda_kernel.cu
+++ b/src/04kernel/src/kernels/attention/cuda_kernel.cu
@@ -7,26 +7,39 @@ namespace refactor::kernel {
     using K = AttentionCuda;
     using namespace cublas;
 
-    static __forceinline__ __device__ bool mask(int tokid, int posid) {
-        return true;
+    // Attention mask for a causal system.
+    // tokenId: index of the token within this pass
+    // seqLen: number of tokens processed in this pass
+    // posId: position in the kv cache
+    // attLen = pastSeqLen + seqLen
+    static __forceinline__ __device__ bool
+    causualMask(int tokenId, int seqLen,
+                int posId, int attLen) {
+        // tokenId ↓ |<---attLen---->|
+        //         0 | * * ... *     |
+        //         1 | * * ... * *   |
+        //         2 | * * ... * * * |
+        // seqLen: 3 |---------------|
+        return attLen + tokenId >= posId + seqLen;
     }
 
     // gridDim.x = batch * nHead
     // gridDim.y = seqLen
-    template<class T, class Mask>
+    // blockDim.x = min(1024, attLen)
+    template<class T>
     static __global__ void softmax(
-        T *__restrict__ attention,
-        Mask mask,
-        uint32_t seqLen,
+        T *__restrict__ att,
+        bool (*mask)(int, int, int, int),
+        uint32_t attLen,
         uint32_t bufLen) {
-        // int offset = (blockIdx.x * len_q + blockIdx.y) * len_buf;
-        // SharedMemory<float> shared;
-        // float *smem = shared.getPointer();
+        // locate the attention row handled by this thread block
+        att += (blockIdx.x * gridDim.y + blockIdx.y) * bufLen;
+        // load the input into shared memory, with cast + mask
+        extern __shared__ float shared[];// size = attLen = pastSeqLen + seqLen
+        for (auto i = threadIdx.x; i < attLen; i += blockDim.x) {
+            shared[i] = mask(blockIdx.y, gridDim.y, i, attLen) ? float(att[i]) : -__FLT_MAX__;
+        }
 
-        // for (int i = threadIdx.x; i < len_buf; i += blockDim.x) {
-        //     T pb = (position_bias == nullptr) ? T(0.) : position_bias[offset + i];
-        //     smem[i] = mask[blockIdx.y * len_buf + i] > 0 ? x[offset + i] * scale + pb : -Inf();
-        // }
         // float local_max = -1e20;
         // for (int i = threadIdx.x; i < len_buf; i += blockDim.x) {
         //     local_max = fmaxf(local_max, smem[i]);
@@ -125,7 +138,7 @@ namespace refactor::kernel {
                 auto k = inputs[1];
                 auto v = inputs[2];
                 auto o = outputs[0];
-                auto att = workspace;
+                auto att = reinterpret_cast<half *>(workspace);
                 auto workspaceQK = reinterpret_cast<uint8_t *>(workspace) + hardware::alignBytes(d->attSize, 256);
                 auto workspaceAV = workspaceQK + hardware::alignBytes(d->workspaceSizeQK, 256);
@@ -143,7 +156,7 @@ namespace refactor::kernel {
                     cudaStreamLegacy);
 
                 softmax<<<dim3(info.batch * info.nHead, info.seqLen), std::min(1024u, info.seqLen), info.seqLen * sizeof(float)>>>(
-                    att, mask, info.seqLen, info.seqLen);
+                    att, causualMask, info.seqLen, info.seqLen);
 
                 cublasLtMatmul(
                     handle, d->mul.get(),
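
The inequality in `causualMask` is equivalent to `posId <= pastSeqLen + tokenId`: a token may attend to every cached position, to itself, and to earlier tokens of the current pass. A quick host-side check reproduces the diagram from the comment (the `main` below is a hypothetical illustration, not part of the patch), here with seqLen = 3 and pastSeqLen = 2:

```cuda
#include <cstdio>

// same predicate as causualMask in the patch
static bool causualMask(int tokenId, int seqLen, int posId, int attLen) {
    return attLen + tokenId >= posId + seqLen;
}

int main() {
    int seqLen = 3, pastSeqLen = 2, attLen = pastSeqLen + seqLen;
    for (int tok = 0; tok < seqLen; ++tok) {
        std::printf("tokenId %d | ", tok);
        for (int pos = 0; pos < attLen; ++pos) {
            std::printf("%c ", causualMask(tok, seqLen, pos, attLen) ? '*' : '.');
        }
        std::printf("|\n");
    }
    // tokenId 0 | * * * . . |
    // tokenId 1 | * * * * . |
    // tokenId 2 | * * * * * |
}
```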
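The patch stops where the old commented-out code begins; the remaining steps of a numerically stable softmax (block-wide max, exponentiation, block-wide sum, normalization) are still to be filled in. The sketch below shows one standard way to complete the kernel, assuming blockDim.x is a power of two and using a hypothetical static `reduceBuf` for the block reductions; it illustrates the technique, not the repository's final code:

```cuda
#include <cuda_fp16.h>
#include <cfloat>
#include <cstdint>

template<class T>
static __global__ void softmaxSketch(
    T *__restrict__ att,
    bool (*mask)(int, int, int, int),
    uint32_t attLen,
    uint32_t bufLen) {
    // same addressing as the patch: one block per (batch*head, token) pair
    att += (blockIdx.x * gridDim.y + blockIdx.y) * bufLen;

    extern __shared__ float shared[];// attLen floats of dynamic shared memory
    __shared__ float reduceBuf[1024];// blockDim.x <= 1024

    // 1) load + cast + mask, as in the patch
    for (auto i = threadIdx.x; i < attLen; i += blockDim.x) {
        shared[i] = mask(blockIdx.y, gridDim.y, i, attLen) ? float(att[i]) : -FLT_MAX;
    }
    __syncthreads();

    // 2) block-wide max (tree reduction; assumes blockDim.x is a power of two)
    float localMax = -FLT_MAX;
    for (auto i = threadIdx.x; i < attLen; i += blockDim.x) {
        localMax = fmaxf(localMax, shared[i]);
    }
    reduceBuf[threadIdx.x] = localMax;
    __syncthreads();
    for (auto s = blockDim.x / 2; s > 0; s >>= 1) {
        if (threadIdx.x < s) {
            reduceBuf[threadIdx.x] = fmaxf(reduceBuf[threadIdx.x], reduceBuf[threadIdx.x + s]);
        }
        __syncthreads();
    }
    float max_ = reduceBuf[0];
    __syncthreads();

    // 3) exponentiate and block-wide sum
    float localSum = 0;
    for (auto i = threadIdx.x; i < attLen; i += blockDim.x) {
        shared[i] = expf(shared[i] - max_);
        localSum += shared[i];
    }
    reduceBuf[threadIdx.x] = localSum;
    __syncthreads();
    for (auto s = blockDim.x / 2; s > 0; s >>= 1) {
        if (threadIdx.x < s) {
            reduceBuf[threadIdx.x] += reduceBuf[threadIdx.x + s];
        }
        __syncthreads();
    }
    float sum = reduceBuf[0];

    // 4) normalize and write back in the original precision
    for (auto i = threadIdx.x; i < attLen; i += blockDim.x) {
        att[i] = T(shared[i] / sum);
    }
}
```

Subtracting the block-wide maximum before `expf` keeps the exponentials in range; masked positions, set to `-FLT_MAX`, come out of `expf` as zero and drop out of the normalization.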
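One CUDA caveat worth flagging in this diff: if the `softmax<<<...>>>(att, causualMask, ...)` call site is host code, it takes the address of a `__device__` function from the host, which the CUDA programming guide disallows. A common workaround, sketched here under the assumption that the function-pointer parameter is kept (the `causualMaskPtr` and `maskOnDevice` names are hypothetical), is to publish the device-side address through a `__device__` variable and copy it back with `cudaMemcpyFromSymbol`:

```cuda
// Type of the mask predicate expected by the kernel.
using MaskFn = bool (*)(int, int, int, int);

// Initialized by device code, so it holds a device-valid address of causualMask.
static __device__ MaskFn causualMaskPtr = causualMask;

// Host side: fetch the device address once, then pass it to the kernel launch.
static MaskFn deviceMaskFn() {
    MaskFn maskOnDevice = nullptr;
    cudaMemcpyFromSymbol(&maskOnDevice, causualMaskPtr, sizeof(MaskFn));
    return maskOnDevice;
}

// softmax<<<grid, block, smemBytes>>>(att, deviceMaskFn(), attLen, bufLen);
```

An alternative that avoids the indirect call entirely is to keep the mask as a template parameter, as the removed `Mask mask` signature did, and pass a functor or lambda instead of a raw function pointer.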