feat(kernel): confirm the softmax interface called by attention
Signed-off-by: YdrMaster <ydrml@hotmail.com>
YdrMaster committed Jan 31, 2024
1 parent fa6e5b4 commit 61c7ddb
Showing 3 changed files with 59 additions and 5 deletions.
48 changes: 43 additions & 5 deletions src/04kernel/src/kernels/attention/cuda_kernel.cu
@@ -7,6 +7,44 @@ namespace refactor::kernel {
     using K = AttentionCuda;
     using namespace cublas;
 
+    // Placeholder mask: no masking yet, every (token, position) pair passes.
+    // A functor rather than a bare __device__ function, since host code may
+    // not reference a __device__ function at the launch site, while a functor
+    // can be constructed on the host and its call inlined in the kernel.
+    struct MaskAll {
+        __forceinline__ __device__ bool
+        operator()(int tokId, int posId) const noexcept {
+            return true;
+        }
+    };
+
+    // gridDim.x = batch * nHead
+    // gridDim.y = seqLen
+    template<class T, class Mask>
+    static __global__ void softmax(
+        T *__restrict__ attention,
+        Mask mask,
+        uint32_t seqLen,
+        uint32_t bufLen) {
+        // Body to come. The commented reference below sketches the planned
+        // three passes (masked load, max-shifted exp, normalize), with names
+        // adapted to the interface above; the position-bias and scale terms
+        // of the original reference are dropped, since the confirmed
+        // interface takes neither.
+
+        // int offset = (blockIdx.x * seqLen + blockIdx.y) * bufLen;
+        // SharedMemory<float> shared;
+        // float *smem = shared.getPointer();
+
+        // for (int i = threadIdx.x; i < bufLen; i += blockDim.x) {
+        //     smem[i] = mask(blockIdx.y, i) ? float(attention[offset + i]) : -1e20f;
+        // }
+        // float local_max = -1e20;
+        // for (int i = threadIdx.x; i < bufLen; i += blockDim.x) {
+        //     local_max = fmaxf(local_max, smem[i]);
+        // }
+        // local_max = functions::blockReduceMax<float>(local_max);
+
+        // float local_sum = 1e-20;
+        // for (int i = threadIdx.x; i < bufLen; i += blockDim.x) {
+        //     float v = expf(smem[i] - local_max);
+        //     smem[i] = v;
+        //     local_sum += v;
+        // }
+        // local_sum = functions::blockReduceSum<float>(local_sum);
+        // for (int i = threadIdx.x; i < bufLen; i += blockDim.x) {
+        //     attention[offset + i] = T(smem[i] / local_sum);
+        // }
+    }
+
     RoutineWorkspace K::lower(Resources &res) const {
         auto handle = res.fetchOrStore<CublasLtContext>()->handle;
 
@@ -23,9 +61,9 @@
             size_t attSize, workspaceSizeQK, workspaceSizeAV;
 
             Descriptors(CublasLtContext const &context,
-                        cublasComputeType_t compute,
                         AttentionInfo info)
-                : mul(compute, CUDA_R_32F),
+                : mul(computeTypeConvert(info.dataType),
+                      dataTypeConvert(info.dataType)),
                   q(MatrixLayout{
                       .dataType = dataTypeConvert(info.dataType),
                       .rows = static_cast<uint64_t>(info.seqLen),
@@ -73,11 +111,10 @@
            };
 
            auto const &context = *res.fetchOrStore<CublasLtContext>();
-           auto d = std::make_shared<Descriptors>(context, CUBLAS_COMPUTE_32F, info);
+           auto d = std::make_shared<Descriptors>(context, info);
            auto workspaceSize = d->attSize;
            workspaceSize = hardware::alignBytes(workspaceSize, 256);
            workspaceSize += d->workspaceSizeQK;
-           workspaceSize = hardware::alignBytes(workspaceSize, 256);
            workspaceSize += d->workspaceSizeAV;
            workspaceSize = hardware::alignBytes(workspaceSize, 256);
 
@@ -105,7 +142,8 @@
                    workspaceQK, d->workspaceSizeQK,
                    cudaStreamLegacy);
 
-               // TODO inline mask && softmax
+               softmax<<<dim3(info.batch * info.nHead, info.seqLen), info.seqLen>>>(
+                   att, MaskAll{}, info.seqLen, info.seqLen);
 
                cublasLtMatmul(
                    handle, d->mul.get(),
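For readers who want to see the commented reference as something runnable, here is a compilable sketch of the same block softmax. It is an illustration under stated assumptions, not this repository's implementation: `blockMax`/`blockSum` are simple stand-ins for the `functions::blockReduceMax`/`blockReduceSum` helpers the reference assumes, the `-1e20f` mask floor mirrors the reference, and `CausalMask`, the launch shape, and the test data are invented for the demo. The launch convention follows the kernel's contract above: one block per (batch·head, token) pair, with `blockDim.x` threads striding over `bufLen`, so `seqLen` as the block size in the diff implicitly assumes it stays within the device's thread-per-block limit.

```cuda
#include <cstdio>
#include <cuda_runtime.h>

// Shared-memory tree reductions; blockDim.x must be a power of two here.
// Stand-ins for the functions::blockReduce* helpers the reference uses.
__device__ float blockMax(float *buf, float v) {
    buf[threadIdx.x] = v;
    __syncthreads();
    for (unsigned s = blockDim.x / 2; s > 0; s >>= 1) {
        if (threadIdx.x < s)
            buf[threadIdx.x] = fmaxf(buf[threadIdx.x], buf[threadIdx.x + s]);
        __syncthreads();
    }
    float r = buf[0];
    __syncthreads();// so the scratch buffer can be reused right away
    return r;
}

__device__ float blockSum(float *buf, float v) {
    buf[threadIdx.x] = v;
    __syncthreads();
    for (unsigned s = blockDim.x / 2; s > 0; s >>= 1) {
        if (threadIdx.x < s) buf[threadIdx.x] += buf[threadIdx.x + s];
        __syncthreads();
    }
    float r = buf[0];
    __syncthreads();
    return r;
}

// Example mask functor: causal attention, token t sees positions <= t.
struct CausalMask {
    __device__ bool operator()(int tokId, int posId) const {
        return posId <= tokId;
    }
};

// gridDim.x = batch * nHead, gridDim.y = seqLen, as in the diff above.
// Each pass strides the same indices per thread, so the only syncs needed
// are the ones inside the reductions.
template<class T, class Mask>
__global__ void softmaxSketch(T *att, Mask mask, unsigned seqLen, unsigned bufLen) {
    extern __shared__ float smem[];     // bufLen scores + blockDim.x scratch
    float *scores = smem, *scratch = smem + bufLen;
    size_t offset = (size_t)(blockIdx.x * seqLen + blockIdx.y) * bufLen;

    float localMax = -1e20f;            // masked load, tracking the row max
    for (unsigned i = threadIdx.x; i < bufLen; i += blockDim.x) {
        scores[i] = mask(blockIdx.y, i) ? float(att[offset + i]) : -1e20f;
        localMax = fmaxf(localMax, scores[i]);
    }
    float rowMax = blockMax(scratch, localMax);

    float localSum = 0.f;               // shift by the max, then exponentiate
    for (unsigned i = threadIdx.x; i < bufLen; i += blockDim.x) {
        float e = expf(scores[i] - rowMax);
        scores[i] = e;
        localSum += e;
    }
    float rowSum = blockSum(scratch, localSum) + 1e-20f;

    for (unsigned i = threadIdx.x; i < bufLen; i += blockDim.x)
        att[offset + i] = T(scores[i] / rowSum);
}

int main() {
    const unsigned batchHead = 2, seqLen = 4, bufLen = 4, threads = 64;
    const size_t n = (size_t)batchHead * seqLen * bufLen;
    float *att;
    cudaMallocManaged(&att, n * sizeof(float));
    for (size_t i = 0; i < n; ++i) att[i] = float(i % 7) * 0.3f;

    size_t shm = (bufLen + threads) * sizeof(float);
    softmaxSketch<<<dim3(batchHead, seqLen), threads, shm>>>(
        att, CausalMask{}, seqLen, bufLen);
    cudaDeviceSynchronize();

    // rows of the first head should each sum to ~1 after normalization
    for (unsigned t = 0; t < seqLen; ++t) {
        float s = 0;
        for (unsigned p = 0; p < bufLen; ++p) s += att[t * bufLen + p];
        std::printf("token %u: row sum = %f\n", t, s);
    }
    cudaFree(att);
    return 0;
}
```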
15 changes: 15 additions & 0 deletions src/04kernel/src/utilities/cuda/cublaslt_utils.cu
@@ -34,6 +34,21 @@ namespace refactor::kernel::cublas {
         switch (dt) {
             case DataType::F32:
                 return CUDA_R_32F;
+            case DataType::FP16:
+                return CUDA_R_16F;
+            case DataType::BF16:
+                return CUDA_R_16BF;
             default:
                 TODO("");
         }
     }
+    cublasComputeType_t computeTypeConvert(DataType dt) {
+        switch (dt) {
+            case DataType::F32:
+            case DataType::BF16:
+                return CUBLAS_COMPUTE_32F;
+            case DataType::FP16:
+                return CUBLAS_COMPUTE_16F;
+            default:
+                TODO("");
+        }
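A note on this mapping: cublasLt exposes no bf16 accumulator type, so BF16 inputs accumulate in FP32 via CUBLAS_COMPUTE_32F, while FP16 can take the native CUBLAS_COMPUTE_16F path. The sketch below shows how the two converters are meant to pair up when building a matmul descriptor, mirroring the `mul(computeTypeConvert(...), dataTypeConvert(...))` initializer in the attention kernel above. The helper itself is hypothetical, and one caveat is worth flagging: per the cublasLt documentation, BF16 inputs under CUBLAS_COMPUTE_32F require an FP32 scale type, which is why the sketch derives the scale type from the compute type instead of reusing dataTypeConvert.

```cuda
#include <cublasLt.h>

// Hypothetical helper, assuming DataType and the two converters are visible
// in the enclosing namespace; only the cublasLt call and the converters
// themselves come from the source.
cublasLtMatmulDesc_t makeMatmulDesc(DataType dt) {
    auto compute = computeTypeConvert(dt);
    // FP16 compute scales in half precision; 32F compute scales in FP32,
    // including the BF16 case.
    auto scale = compute == CUBLAS_COMPUTE_16F ? CUDA_R_16F : CUDA_R_32F;
    cublasLtMatmulDesc_t desc;
    cublasLtMatmulDescCreate(&desc, compute, scale);
    return desc;
}
```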
1 change: 1 addition & 0 deletions src/04kernel/src/utilities/cuda/cublaslt_utils.cuh
@@ -30,6 +30,7 @@ namespace refactor::kernel::cublas {
     };
 
     cudaDataType dataTypeConvert(DataType);
+    cublasComputeType_t computeTypeConvert(DataType);
 
     class MatMulDescriptor {
         cublasLtMatmulDesc_t _internal;