Skip to content

Commit

Permalink
feat(kernel): 封装 attLen 计算
Browse files Browse the repository at this point in the history
Signed-off-by: YdrMaster <ydrml@hotmail.com>
  • Loading branch information
YdrMaster committed Feb 1, 2024
1 parent 29790ed commit fddac13
Show file tree
Hide file tree
Showing 3 changed files with 25 additions and 8 deletions.
3 changes: 3 additions & 0 deletions src/04kernel/include/kernel/attributes/attention_info.h
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,9 @@ namespace refactor::kernel {
DataType dataType;
dim_t batch, nHead, nKVHead, seqLen, headDim, cacheLen;
bool concatCache, resetCache;

dim_t attLen(dim_t pastSeqLen) const noexcept;
size_t attSize(dim_t pastSeqLen) const noexcept;
};

}// namespace refactor::kernel
Expand Down
13 changes: 13 additions & 0 deletions src/04kernel/src/attributes/attention_info.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
#include "kernel/attributes/attention_info.h"

namespace refactor::kernel {

/// Total attention length for the current step: the cached (past)
/// tokens plus the tokens being processed in this call.
dim_t AttentionInfo::attLen(dim_t pastSeqLen) const noexcept {
    return seqLen + pastSeqLen;
}

/// Byte size of the attention score matrix [batch, nHead, seqLen, attLen].
/// NOTE: the first operand is widened to size_t so the whole product is
/// computed in 64 bits — multiplying in dim_t (32-bit) first could overflow
/// for large batch * nHead * seqLen * attLen before the implicit conversion.
size_t AttentionInfo::attSize(dim_t pastSeqLen) const noexcept {
    return static_cast<size_t>(batch) * nHead * seqLen * attLen(pastSeqLen) * dataType.size();
}

}// namespace refactor::kernel
17 changes: 9 additions & 8 deletions src/04kernel/src/kernels/attention/cuda_kernel.cu
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ namespace refactor::kernel {
MatMulDescriptor mul;
MatrixDescriptor q, k, v, att;
cublasLtMatmulAlgo_t algoQK, algoAV;
size_t attSize, workspaceSizeQK, workspaceSizeAV;
size_t workspaceSizeQK, workspaceSizeAV;

Descriptors(CublasLtContext const &context,
AttentionInfo info)
Expand Down Expand Up @@ -112,8 +112,7 @@ namespace refactor::kernel {
.order = ROW_MAJOR,
.batchCount = static_cast<int32_t>(info.batch * info.nHead),
.batchStride = static_cast<int64_t>(info.seqLen * info.seqLen),
}),
attSize(info.batch * info.nHead * info.seqLen * info.seqLen * info.dataType.size()) {
}) {
auto [algoQK_, workspaceSizeQK_] = tune(context.handle, mul, q, k, att);
auto [algoAV_, workspaceSizeAV_] = tune(context.handle, mul, att, v, q);
algoQK = algoQK_;
Expand All @@ -125,7 +124,7 @@ namespace refactor::kernel {

auto const &context = *res.fetchOrStore<CublasLtContext>();
auto d = std::make_shared<Descriptors>(context, info);
auto workspaceSize = d->attSize;
auto workspaceSize = info.attSize(0);
workspaceSize = hardware::alignBytes(workspaceSize, 256);
workspaceSize += d->workspaceSizeQK;
workspaceSize += d->workspaceSizeAV;
Expand All @@ -139,7 +138,7 @@ namespace refactor::kernel {
auto v = inputs[2];
auto o = outputs[0];
auto att = reinterpret_cast<half *>(workspace);
auto workspaceQK = reinterpret_cast<uint8_t *>(workspace) + hardware::alignBytes(d->attSize, 256);
auto workspaceQK = reinterpret_cast<uint8_t *>(workspace) + hardware::alignBytes(info.attSize(0), 256);
auto workspaceAV = workspaceQK + hardware::alignBytes(d->workspaceSizeQK, 256);
{
half alpha = rsqrtf(info.headDim), beta = 0;
Expand All @@ -155,10 +154,12 @@ namespace refactor::kernel {
workspaceQK, d->workspaceSizeQK,
cudaStreamLegacy);
}
auto attLen = info.attLen(0);
auto bufLen = attLen;
softmax<<<dim3(info.batch * info.nHead, info.seqLen),
info.seqLen,
info.seqLen * sizeof(float)>>>(
att, causualMask, info.seqLen, info.seqLen);
std::min(1024u, attLen),
attLen * sizeof(float)>>>(
att, causualMask, attLen, bufLen);
{
half alpha = 1, beta = 0;
cublasLtMatmul(
Expand Down

0 comments on commit fddac13

Please sign in to comment.