From fddac13c314a6266dff7c33550863b5da4d5788c Mon Sep 17 00:00:00 2001 From: YdrMaster Date: Thu, 1 Feb 2024 12:07:08 +0800 Subject: [PATCH] =?UTF-8?q?feat(kernel):=20=E5=B0=81=E8=A3=85=20attLen=20?= =?UTF-8?q?=E8=AE=A1=E7=AE=97?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: YdrMaster --- .../include/kernel/attributes/attention_info.h | 3 +++ src/04kernel/src/attributes/attention_info.cc | 13 +++++++++++++ .../src/kernels/attention/cuda_kernel.cu | 17 +++++++++-------- 3 files changed, 25 insertions(+), 8 deletions(-) create mode 100644 src/04kernel/src/attributes/attention_info.cc diff --git a/src/04kernel/include/kernel/attributes/attention_info.h b/src/04kernel/include/kernel/attributes/attention_info.h index 16d5fb0e..9cd64a56 100644 --- a/src/04kernel/include/kernel/attributes/attention_info.h +++ b/src/04kernel/include/kernel/attributes/attention_info.h @@ -9,6 +9,9 @@ namespace refactor::kernel { DataType dataType; dim_t batch, nHead, nKVHead, seqLen, headDim, cacheLen; bool concatCache, resetCache; + + dim_t attLen(dim_t pastSeqLen) const noexcept; + size_t attSize(dim_t pastSeqLen) const noexcept; }; }// namespace refactor::kernel diff --git a/src/04kernel/src/attributes/attention_info.cc b/src/04kernel/src/attributes/attention_info.cc new file mode 100644 index 00000000..c16c59fa --- /dev/null +++ b/src/04kernel/src/attributes/attention_info.cc @@ -0,0 +1,13 @@ +#include "kernel/attributes/attention_info.h" + +namespace refactor::kernel { + + dim_t AttentionInfo::attLen(dim_t pastSeqLen) const noexcept { + return pastSeqLen + seqLen; + } + + size_t AttentionInfo::attSize(dim_t pastSeqLen) const noexcept { + return batch * nHead * seqLen * attLen(pastSeqLen) * dataType.size(); + } + +}// namespace refactor::kernel diff --git a/src/04kernel/src/kernels/attention/cuda_kernel.cu b/src/04kernel/src/kernels/attention/cuda_kernel.cu index 6f5b7d13..bcf31dc4 100644 --- a/src/04kernel/src/kernels/attention/cuda_kernel.cu +++ b/src/04kernel/src/kernels/attention/cuda_kernel.cu @@ -71,7 +71,7 @@ namespace refactor::kernel { MatMulDescriptor mul; MatrixDescriptor q, k, v, att; cublasLtMatmulAlgo_t algoQK, algoAV; - size_t attSize, workspaceSizeQK, workspaceSizeAV; + size_t workspaceSizeQK, workspaceSizeAV; Descriptors(CublasLtContext const &context, AttentionInfo info) @@ -112,8 +112,7 @@ namespace refactor::kernel { .order = ROW_MAJOR, .batchCount = static_cast(info.batch * info.nHead), .batchStride = static_cast(info.seqLen * info.seqLen), - }), - attSize(info.batch * info.nHead * info.seqLen * info.seqLen * info.dataType.size()) { + }) { auto [algoQK_, workspaceSizeQK_] = tune(context.handle, mul, q, k, att); auto [algoAV_, workspaceSizeAV_] = tune(context.handle, mul, att, v, q); algoQK = algoQK_; @@ -125,7 +124,7 @@ namespace refactor::kernel { auto const &context = *res.fetchOrStore(); auto d = std::make_shared(context, info); - auto workspaceSize = d->attSize; + auto workspaceSize = info.attSize(0); workspaceSize = hardware::alignBytes(workspaceSize, 256); workspaceSize += d->workspaceSizeQK; workspaceSize += d->workspaceSizeAV; @@ -139,7 +138,7 @@ namespace refactor::kernel { auto v = inputs[2]; auto o = outputs[0]; auto att = reinterpret_cast(workspace); - auto workspaceQK = reinterpret_cast(workspace) + hardware::alignBytes(d->attSize, 256); + auto workspaceQK = reinterpret_cast(workspace) + hardware::alignBytes(info.attSize(0), 256); auto workspaceAV = workspaceQK + hardware::alignBytes(d->workspaceSizeQK, 256); { half alpha = rsqrtf(info.headDim), beta = 0; @@ -155,10 +154,12 @@ namespace refactor::kernel { workspaceQK, d->workspaceSizeQK, cudaStreamLegacy); } + auto attLen = info.attLen(0); + auto bufLen = attLen; softmax<<>>( - att, causualMask, info.seqLen, info.seqLen); + std::min(1024u, attLen), + attLen * sizeof(float)>>>( + att, causualMask, attLen, bufLen); { half alpha = 1, beta = 0; cublasLtMatmul(