diff --git a/src/04kernel/src/kernels/attention/cuda_kernel.cu b/src/04kernel/src/kernels/attention/cuda_kernel.cu
index 79aa6f2b..7d018607 100644
--- a/src/04kernel/src/kernels/attention/cuda_kernel.cu
+++ b/src/04kernel/src/kernels/attention/cuda_kernel.cu
@@ -7,6 +7,44 @@ namespace refactor::kernel {
 
     using K = AttentionCuda;
     using namespace cublas;
 
+    static __forceinline__ __device__ bool mask(int tokid, int posid) {
+        return true;
+    }
+
+    // gridDim.x = batch * nHead
+    // gridDim.y = seqLen
+    template<class T, class Mask>
+    static __global__ void softmax(
+        T *__restrict__ attention,
+        Mask mask,
+        uint32_t seqLen,
+        uint32_t bufLen) {
+        // int offset = (blockIdx.x * len_q + blockIdx.y) * len_buf;
+        // SharedMemory<float> shared;
+        // float *smem = shared.getPointer();
+
+        // for (int i = threadIdx.x; i < len_buf; i += blockDim.x) {
+        //     T pb = (position_bias == nullptr) ? T(0.) : position_bias[offset + i];
+        //     smem[i] = mask[blockIdx.y * len_buf + i] > 0 ? x[offset + i] * scale + pb : -Inf();
+        // }
+        // float local_max = -1e20;
+        // for (int i = threadIdx.x; i < len_buf; i += blockDim.x) {
+        //     local_max = fmaxf(local_max, smem[i]);
+        // }
+        // local_max = functions::blockReduceMax(local_max);
+
+        // float local_sum = 1e-20;
+        // for (int i = threadIdx.x; i < len_buf; i += blockDim.x) {
+        //     float v = expf(float(smem[i]) - local_max);
+        //     smem[i] = v;
+        //     local_sum += v;
+        // }
+        // local_sum = functions::blockReduceSum(local_sum);
+        // for (int i = threadIdx.x; i < len_buf; i += blockDim.x) {
+        //     x[offset + i] = float(smem[i]) / local_sum;
+        // }
+    }
+
     RoutineWorkspace K::lower(Resources &res) const {
         auto handle = res.fetchOrStore<CublasLtContext>()->handle;
@@ -23,9 +61,9 @@ namespace refactor::kernel {
             size_t attSize, workspaceSizeQK, workspaceSizeAV;
 
             Descriptors(CublasLtContext const &context,
-                        cublasComputeType_t compute,
                         AttentionInfo info)
-                : mul(compute, CUDA_R_32F),
+                : mul(computeTypeConvert(info.dataType),
+                      dataTypeConvert(info.dataType)),
                   q(MatrixLayout{
                       .dataType = dataTypeConvert(info.dataType),
                       .rows = static_cast<uint64_t>(info.seqLen),
@@ -73,11 +111,10 @@ namespace refactor::kernel {
         };
 
         auto const &context = *res.fetchOrStore<CublasLtContext>();
-        auto d = std::make_shared<Descriptors>(context, CUBLAS_COMPUTE_32F, info);
+        auto d = std::make_shared<Descriptors>(context, info);
         auto workspaceSize = d->attSize;
         workspaceSize = hardware::alignBytes(workspaceSize, 256);
         workspaceSize += d->workspaceSizeQK;
-        workspaceSize = hardware::alignBytes(workspaceSize, 256);
         workspaceSize += d->workspaceSizeAV;
         workspaceSize = hardware::alignBytes(workspaceSize, 256);
 
@@ -105,7 +142,8 @@ namespace refactor::kernel {
                 workspaceQK, d->workspaceSizeQK,
                 cudaStreamLegacy);
 
-            // TODO inline mask && softmax
+            softmax<<<dim3(info.batch * info.nHead, info.seqLen), info.seqLen>>>(
+                att, mask, info.seqLen, info.seqLen);
             cublasLtMatmul(
                 handle, d->mul.get(),
diff --git a/src/04kernel/src/utilities/cuda/cublaslt_utils.cu b/src/04kernel/src/utilities/cuda/cublaslt_utils.cu
index d07af6ab..7c34ad4c 100644
--- a/src/04kernel/src/utilities/cuda/cublaslt_utils.cu
+++ b/src/04kernel/src/utilities/cuda/cublaslt_utils.cu
@@ -34,6 +34,21 @@ namespace refactor::kernel::cublas {
         switch (dt) {
             case DataType::F32:
                 return CUDA_R_32F;
+            case DataType::FP16:
+                return CUDA_R_16F;
+            case DataType::BF16:
+                return CUDA_R_16BF;
+            default:
+                TODO("");
+        }
+    }
+    cublasComputeType_t computeTypeConvert(DataType dt) {
+        switch (dt) {
+            case DataType::F32:
+            case DataType::BF16:
+                return CUBLAS_COMPUTE_32F;
+            case DataType::FP16:
+                return CUBLAS_COMPUTE_16F;
             default:
                 TODO("");
         }
diff --git a/src/04kernel/src/utilities/cuda/cublaslt_utils.cuh b/src/04kernel/src/utilities/cuda/cublaslt_utils.cuh
index 5dd23607..ccaad7ec 100644
--- a/src/04kernel/src/utilities/cuda/cublaslt_utils.cuh
+++ b/src/04kernel/src/utilities/cuda/cublaslt_utils.cuh
@@ -30,6 +30,7 @@ namespace refactor::kernel::cublas {
     };
 
     cudaDataType dataTypeConvert(DataType);
+    cublasComputeType_t computeTypeConvert(DataType);
 
     class MatMulDescriptor {
         cublasLtMatmulDesc_t _internal;
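
Note on the softmax kernel added in cuda_kernel.cu: its body is still the commented-out reference implementation, so the launch inserted after the Q*K^T matmul does not yet normalize the attention scores, and the free __device__ function mask presumably has to become a callable object, because the address of a __device__ function cannot be taken in the host code that launches the kernel. The sketch below is one way the body and the mask could be filled in against the new signature (attention, mask, seqLen, bufLen). The AttentionCausalMask functor, the blockReduce helper, and the power-of-two block-size assumption are illustrative stand-ins rather than part of the change; the functions::blockReduceMax / functions::blockReduceSum helpers mentioned in the comments could replace blockReduce where available.

// Sketch only: one block per (batch * head, query token) attention row,
// blockDim.x assumed to be a power of two <= 1024.
struct AttentionCausalMask {
    // a query token attends to itself and to earlier positions only
    __forceinline__ __device__ bool operator()(int tokId, int posId) const {
        return posId <= tokId;
    }
};

template<class Op>
static __forceinline__ __device__ float blockReduce(float x, Op op) {
    // shared-memory tree reduction across the thread block
    __shared__ float buf[1024];
    buf[threadIdx.x] = x;
    __syncthreads();
    for (unsigned stride = blockDim.x >> 1; stride; stride >>= 1) {
        if (threadIdx.x < stride)
            buf[threadIdx.x] = op(buf[threadIdx.x], buf[threadIdx.x + stride]);
        __syncthreads();
    }
    auto ans = buf[0];
    __syncthreads();// keep buf stable until every thread has read the result
    return ans;
}

// gridDim.x = batch * nHead
// gridDim.y = seqLen
template<class T, class Mask>
static __global__ void softmax(
    T *__restrict__ attention,
    Mask mask,
    uint32_t seqLen,
    uint32_t bufLen) {
    auto row = attention + (blockIdx.x * gridDim.y + blockIdx.y) * size_t(bufLen);
    auto tokId = int(blockIdx.y);

    // 1. masked row maximum, for numerical stability
    auto localMax = -1e20f;
    for (auto i = threadIdx.x; i < bufLen; i += blockDim.x)
        localMax = fmaxf(localMax, mask(tokId, int(i)) ? float(row[i]) : -1e20f);
    localMax = blockReduce(localMax, [](float a, float b) { return fmaxf(a, b); });

    // 2. exponentiate in place and accumulate the denominator
    auto localSum = 1e-20f;
    for (auto i = threadIdx.x; i < bufLen; i += blockDim.x) {
        auto v = mask(tokId, int(i)) ? expf(float(row[i]) - localMax) : 0.f;
        row[i] = T(v);
        localSum += v;
    }
    localSum = blockReduce(localSum, [](float a, float b) { return a + b; });

    // 3. normalize the row in place
    for (auto i = threadIdx.x; i < bufLen; i += blockDim.x)
        row[i] = T(float(row[i]) / localSum);
}

A matching launch would then pass the functor instead of the free function, e.g. softmax<<<dim3(info.batch * info.nHead, info.seqLen), 1024>>>(att, AttentionCausalMask(), info.seqLen, info.seqLen). seqLen and bufLen only diverge once a KV cache is attached; at that point the causal comparison shifts to posId - (bufLen - seqLen) <= tokId so that cached positions remain visible.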