diff --git a/src/02hardware/include/hardware/devices/nvidia.h b/src/02hardware/include/hardware/devices/nvidia.h
index d19dd315..18a4269d 100644
--- a/src/02hardware/include/hardware/devices/nvidia.h
+++ b/src/02hardware/include/hardware/devices/nvidia.h
@@ -3,6 +3,12 @@
 
 #include "../device.h"
 
+#define CUDA_ASSERT(STATUS)                                                          \
+    if (auto status = (STATUS); status != cudaSuccess) {                             \
+        RUNTIME_ERROR(fmt::format("cuda failed on \"" #STATUS "\" with \"{}\" ({})", \
+                                  cudaGetErrorString(status), (int) status));        \
+    }
+
 namespace refactor::hardware {
 
     class Nvidia final : public Device {
diff --git a/src/02hardware/src/devices/nvidia/device.cc b/src/02hardware/src/devices/nvidia/device.cc
index fd10cb70..20f63c0f 100644
--- a/src/02hardware/src/devices/nvidia/device.cc
+++ b/src/02hardware/src/devices/nvidia/device.cc
@@ -4,12 +4,6 @@
 #ifdef USE_CUDA
 #include "memory.hh"
 #include <cuda_runtime.h>
-
-#define CUDA_ASSERT(STATUS)                                                          \
-    if (auto status = (STATUS); status != cudaSuccess) {                             \
-        RUNTIME_ERROR(fmt::format("cuda failed on \"" #STATUS "\" with \"{}\" ({})", \
-                                  cudaGetErrorString(status), (int) status));        \
-    }
 #endif
 
 namespace refactor::hardware {
diff --git a/src/02hardware/src/devices/nvidia/memory.cc b/src/02hardware/src/devices/nvidia/memory.cc
index 42310196..1c3be21e 100644
--- a/src/02hardware/src/devices/nvidia/memory.cc
+++ b/src/02hardware/src/devices/nvidia/memory.cc
@@ -1,15 +1,9 @@
 #ifdef USE_CUDA
 #include "memory.hh"
-#include "common.h"
+#include "hardware/devices/nvidia.h"
 #include <cuda_runtime.h>
 
-#define CUDA_ASSERT(STATUS)                                                          \
-    if (auto status = (STATUS); status != cudaSuccess) {                             \
-        RUNTIME_ERROR(fmt::format("cuda failed on \"" #STATUS "\" with \"{}\" ({})", \
-                                  cudaGetErrorString(status), (int) status));        \
-    }
-
 namespace refactor::hardware {
     using M = NvidiaMemory;
diff --git a/src/04kernel/include/kernel/attributes/attention_info.h b/src/04kernel/include/kernel/attributes/attention_info.h
new file mode 100644
index 00000000..16d5fb0e
--- /dev/null
+++ b/src/04kernel/include/kernel/attributes/attention_info.h
@@ -0,0 +1,16 @@
+#ifndef KERNEL_ATTENTION_INFO_H
+#define KERNEL_ATTENTION_INFO_H
+
+#include "../tensor.h"
+
+namespace refactor::kernel {
+
+    struct AttentionInfo {
+        DataType dataType;
+        dim_t batch, nHead, nKVHead, seqLen, headDim, cacheLen;
+        bool concatCache, resetCache;
+    };
+
+}// namespace refactor::kernel
+
+#endif// KERNEL_ATTENTION_INFO_H
diff --git a/src/04kernel/include/kernel/collectors/attention.h b/src/04kernel/include/kernel/collectors/attention.h
index 527bc63f..abf33957 100644
--- a/src/04kernel/include/kernel/collectors/attention.h
+++ b/src/04kernel/include/kernel/collectors/attention.h
@@ -6,9 +6,8 @@
 namespace refactor::kernel {
 
     struct AttentionCollector final : public InfoCollector {
-        dim_t maxSeqLen;
 
-        AttentionCollector(decltype(_target), decltype(maxSeqLen)) noexcept;
+        AttentionCollector(decltype(_target)) noexcept;
 
         std::vector<KernelBox>
         filter(TensorRefs inputs, TensorRefs outputs) const final;
diff --git a/src/04kernel/src/collectors/attention.cc b/src/04kernel/src/collectors/attention.cc
index 3933097f..a778c128 100644
--- a/src/04kernel/src/collectors/attention.cc
+++ b/src/04kernel/src/collectors/attention.cc
@@ -1,38 +1,57 @@
 #include "kernel/collectors/attention.h"
+#include "kernel/attributes/attention_info.h"
 // #include "../kernels/attention/cpu_kernel.hh"
 #include "../kernels/attention/cuda_kernel.hh"
 
 namespace refactor::kernel {
 
     AttentionCollector::AttentionCollector(
-        decltype(_target) target,
-        decltype(maxSeqLen) maxSeqLen_) noexcept
-        : InfoCollector(target),
-          maxSeqLen(maxSeqLen_) {}
+        decltype(_target) target) noexcept
+        : InfoCollector(target) {}
 
     std::vector<KernelBox>
     AttentionCollector::filter(TensorRefs inputs, TensorRefs outputs) const {
         auto const &query = inputs[0].get();
         auto const &key = inputs[1].get();
-        auto pastSeqLen = inputs.size() == 3 ? 0 : *inputs[2].get().data->get<int64_t>();
-        auto cacheLen = outputs.size() == 1 ? 0 : outputs[1].get().shape[2];
-        std::vector<KernelBox> ans;
+        AttentionInfo info{
+            .dataType = query.dataType,
+            .batch = query.shape[0],
+            .nHead = query.shape[1],
+            .nKVHead = key.shape[1],
+            .seqLen = query.shape[2],
+            .headDim = query.shape[3],
+            .cacheLen = 0,
+            .concatCache = false,
+            .resetCache = false,
+        };
+        switch (outputs.size()) {
+            case 1:
+                // no kv cache
+                ASSERT(inputs.size() == 3, "");
+                break;
+            case 3:
+                switch (inputs.size()) {
+                    case 6:
+                        info.resetCache = true;// fall through
+                    case 4:
+                        info.concatCache = true;// fall through
+                    case 3:
+                        info.cacheLen = outputs[1].get().shape[2];
+                        break;
+                    default:
+                        UNREACHABLE();
+                }
+                break;
+            default:
+                UNREACHABLE();
+        }
+
+        std::vector<KernelBox> ans;
         switch (_target) {
             case decltype(_target)::Cpu:
                 break;
            case decltype(_target)::Nvidia: {
-                decltype(AttentionCuda::info) info{
-                    .dataType = query.dataType,
-                    .batch = query.shape[0],
-                    .nHead = query.shape[1],
-                    .nKVHead = key.shape[1],
-                    .pastSeqLen = static_cast<dim_t>(pastSeqLen),
-                    .seqLen = query.shape[2],
-                    .cacheLen = cacheLen,
-                    .headDim = query.shape[3],
-                    .resetCache = false,
-                };
                 if (auto ptr = AttentionCuda::build(info); ptr) {
                     ans.emplace_back(std::move(ptr));
                 }
diff --git a/src/04kernel/src/kernels/attention/cuda_kernel.cu b/src/04kernel/src/kernels/attention/cuda_kernel.cu
new file mode 100644
index 00000000..a0f3f56a
--- /dev/null
+++ b/src/04kernel/src/kernels/attention/cuda_kernel.cu
@@ -0,0 +1,127 @@
+#include "../../utilities/cuda/cublaslt_utils.cuh"
+#include "cuda_kernel.hh"
+#include "hardware/functions.h"
+
+namespace refactor::kernel {
+    using K = AttentionCuda;
+    using namespace cublas;
+
+    RoutineWorkspace K::lower(Resources &res) const {
+        auto handle = res.fetchOrStore<CublasLtContext>()->handle;
+
+        constexpr auto ROW_MAJOR = CUBLASLT_ORDER_ROW;
+        constexpr auto COL_MAJOR = CUBLASLT_ORDER_COL;
+
+        if (!info.cacheLen) {
+            if (info.nHead == info.nKVHead) {
+                // RAII for closure
+                struct Descriptors {
+                    MatMulDescriptor mul;
+                    MatrixDescriptor q, k, v, att;
+                    cublasLtMatmulAlgo_t algoQK, algoAV;
+                    size_t attSize, workspaceSizeQK, workspaceSizeAV;
+
+                    Descriptors(CublasLtContext const &context,
+                                cublasComputeType_t compute,
+                                AttentionInfo info)
+                        : mul(compute, CUDA_R_32F),
+                          q(MatrixLayout{
+                              .dataType = dataTypeConvert(info.dataType),
+                              .rows = static_cast<uint64_t>(info.seqLen),
+                              .cols = static_cast<uint64_t>(info.headDim),
+                              .majorStride = static_cast<int64_t>(info.headDim),
+                              .order = ROW_MAJOR,
+                              .batchCount = static_cast<int32_t>(info.batch * info.nHead),
+                              .batchStride = static_cast<int64_t>(info.seqLen * info.headDim),
+                          }),
+                          k(MatrixLayout{
+                              .dataType = dataTypeConvert(info.dataType),
+                              .rows = static_cast<uint64_t>(info.headDim),
+                              .cols = static_cast<uint64_t>(info.seqLen),
+                              .majorStride = static_cast<int64_t>(info.headDim),
+                              .order = COL_MAJOR,
+                              .batchCount = static_cast<int32_t>(info.batch * info.nHead),
+                              .batchStride = static_cast<int64_t>(info.seqLen * info.headDim),
+                          }),
+                          v(MatrixLayout{
+                              .dataType = dataTypeConvert(info.dataType),
+                              .rows = static_cast<uint64_t>(info.seqLen),
+                              .cols = static_cast<uint64_t>(info.headDim),
+                              .majorStride = static_cast<int64_t>(info.headDim),
+                              .order = ROW_MAJOR,
+                              .batchCount = static_cast<int32_t>(info.batch * info.nHead),
+                              .batchStride = static_cast<int64_t>(info.seqLen * info.headDim),
+                          }),
+                          att(MatrixLayout{
+                              .dataType = dataTypeConvert(info.dataType),
+                              .rows = static_cast<uint64_t>(info.seqLen),
+                              .cols = static_cast<uint64_t>(info.seqLen),
+                              .majorStride = static_cast<int64_t>(info.seqLen),
+                              .order = ROW_MAJOR,
+                              .batchCount = static_cast<int32_t>(info.batch * info.nHead),
+                              .batchStride = static_cast<int64_t>(info.seqLen * info.seqLen),
+                          }),
+                          attSize(info.batch * info.nHead * info.seqLen * info.seqLen * info.dataType.size()) {
+                        auto [algoQK_, workspaceSizeQK_] = tune(context.handle, mul, q, k, att);
+                        auto [algoAV_, workspaceSizeAV_] = tune(context.handle, mul, att, v, q);
+                        algoQK = algoQK_;
+                        algoAV = algoAV_;
+                        workspaceSizeQK = workspaceSizeQK_;
+                        workspaceSizeAV = workspaceSizeAV_;
+                    }
+                };
+
+                auto const &context = *res.fetchOrStore<CublasLtContext>();
+                auto d = std::make_shared<Descriptors>(context, CUBLAS_COMPUTE_32F, info);
+                auto workspaceSize = d->attSize;// workspace: att scores, then QK scratch, then AV scratch, each 256-byte aligned
+                workspaceSize = hardware::alignBytes(workspaceSize, 256);
+                workspaceSize += d->workspaceSizeQK;
+                workspaceSize = hardware::alignBytes(workspaceSize, 256);
+                workspaceSize += d->workspaceSizeAV;
+                workspaceSize = hardware::alignBytes(workspaceSize, 256);
+
+                auto routine = [d = std::move(d), info = this->info]//
+                    (Resources & res, void *workspace, void const *const *inputs, void *const *outputs) {
+                        auto handle = res.fetchOrStore<CublasLtContext>()->handle;
+                        auto q = inputs[0];
+                        auto k = inputs[1];
+                        auto v = inputs[2];
+                        auto o = outputs[0];
+                        auto att = workspace;
+                        auto workspaceQK = reinterpret_cast<uint8_t *>(workspace) + hardware::alignBytes(d->attSize, 256);
+                        auto workspaceAV = workspaceQK + hardware::alignBytes(d->workspaceSizeQK, 256);
+
+                        float alpha = 1, beta = 0;
+                        cublasLtMatmul(// att = q x k^T
+                            handle, d->mul.get(),
+                            &alpha,
+                            q, d->q.get(),
+                            k, d->k.get(),
+                            &beta,
+                            att, d->att.get(),
+                            att, d->att.get(),
+                            &d->algoQK,
+                            workspaceQK, d->workspaceSizeQK,
+                            cudaStreamLegacy);
+
+                        // TODO inline mask && softmax
+
+                        cublasLtMatmul(// o = att x v
+                            handle, d->mul.get(),
+                            &alpha,
+                            att, d->att.get(),
+                            v, d->v.get(),
+                            &beta,
+                            o, d->q.get(),
+                            o, d->q.get(),
+                            &d->algoAV,
+                            workspaceAV, d->workspaceSizeAV,
+                            cudaStreamLegacy);
+                    };
+                return {std::move(routine), workspaceSize};
+            }
+        }
+        TODO("");
+    }
+
+}// namespace refactor::kernel
diff --git a/src/04kernel/src/kernels/attention/cuda_kernel.hh b/src/04kernel/src/kernels/attention/cuda_kernel.hh
index 5ea19ae8..20cf9712 100644
--- a/src/04kernel/src/kernels/attention/cuda_kernel.hh
+++ b/src/04kernel/src/kernels/attention/cuda_kernel.hh
@@ -1,17 +1,13 @@
 #ifndef KERNEL_ATTENTION_CUDA_KERNEL_HH
 #define KERNEL_ATTENTION_CUDA_KERNEL_HH
 
+#include "kernel/attributes/attention_info.h"
 #include "kernel/kernel.h"
-#include "kernel/tensor.h"
 
 namespace refactor::kernel {
 
     struct AttentionCuda final : public Kernel {
-        struct {
-            DataType dataType;
-            dim_t batch, nHead, nKVHead, pastSeqLen, seqLen, cacheLen, headDim;
-            bool resetCache;
-        } info;
+        AttentionInfo info;
 
         AttentionCuda(decltype(info)) noexcept;
diff --git a/src/04kernel/src/utilities/cuda/cublaslt_context.cu b/src/04kernel/src/utilities/cuda/cublaslt_context.cu
deleted file mode 100644
index 2fc8fb18..00000000
--- a/src/04kernel/src/utilities/cuda/cublaslt_context.cu
+++ /dev/null
@@ -1,33 +0,0 @@
-#include "common.h"
-#include "cublaslt_context.hh"
-
-namespace refactor::kernel::cublas {
-
-    CublasLtContext::CublasLtContext() : runtime::Resource() {
-        if (cublasLtCreate(&handle) != CUBLAS_STATUS_SUCCESS) {
-            RUNTIME_ERROR("Failed to create cublasLt handle");
-        }
-    }
-    CublasLtContext::~CublasLtContext() {
-        if (cublasLtDestroy(handle) != CUBLAS_STATUS_SUCCESS) {
-            fmt::println("Failed to destroy cublasLt handle");
-            abort();
-        }
-    }
-
-    auto CublasLtContext::typeId() noexcept -> size_t {
-        static uint8_t ID = 1;
-        return reinterpret_cast<size_t>(&ID);
-    }
-    auto CublasLtContext::build() noexcept -> runtime::ResourceBox {
-        return std::make_unique<CublasLtContext>();
-    }
-
-    auto CublasLtContext::resourceTypeId() const noexcept -> size_t {
-        return typeId();
-    }
-    auto CublasLtContext::description() const noexcept -> std::string_view {
-        return "CublasLtContext";
-    }
-
-}// namespace refactor::kernel::cublas
diff --git a/src/04kernel/src/utilities/cuda/cublaslt_context.hh b/src/04kernel/src/utilities/cuda/cublaslt_context.hh
deleted file mode 100644
index 84e1d2d9..00000000
--- a/src/04kernel/src/utilities/cuda/cublaslt_context.hh
+++ /dev/null
@@ -1,33 +0,0 @@
-#ifndef KERNEL_CUBLASLT_CONTEXT_HH
-#define KERNEL_CUBLASLT_CONTEXT_HH
-
-#include "runtime/resource.h"
-#include <cublasLt.h>
-
-#define CUBLAS_ASSERT(STATUS)                                          \
-    if (auto status = (STATUS); status != CUBLAS_STATUS_SUCCESS) {     \
-        fmt::println("cublas failed on \"" #STATUS "\" with {}",       \
-                     (int) status);                                    \
-        abort();                                                       \
-    }
-
-namespace refactor::kernel::cublas {
-
-    struct CublasLtContext final : public runtime::Resource {
-        cublasLtHandle_t handle;
-
-        CublasLtContext();
-        ~CublasLtContext();
-        CublasLtContext(CublasLtContext const &) noexcept = delete;
-        CublasLtContext(CublasLtContext &&) noexcept = delete;
-
-        static size_t typeId() noexcept;
-        static runtime::ResourceBox build() noexcept;
-
-        size_t resourceTypeId() const noexcept final;
-        std::string_view description() const noexcept final;
-    };
-
-}// namespace refactor::kernel::cublas
-
-#endif// KERNEL_CUBLASLT_CONTEXT_HH
diff --git a/src/04kernel/src/utilities/cuda/cublaslt_utils.cu b/src/04kernel/src/utilities/cuda/cublaslt_utils.cu
new file mode 100644
index 00000000..d07af6ab
--- /dev/null
+++ b/src/04kernel/src/utilities/cuda/cublaslt_utils.cu
@@ -0,0 +1,145 @@
+#include "cublaslt_utils.cuh"
+#include "hardware/devices/nvidia.h"
+
+namespace refactor::kernel::cublas {
+
+    CublasLtContext::CublasLtContext() : runtime::Resource() {
+        if (cublasLtCreate(&handle) != CUBLAS_STATUS_SUCCESS) {
+            RUNTIME_ERROR("Failed to create cublasLt handle");
+        }
+    }
+    CublasLtContext::~CublasLtContext() {
+        if (cublasLtDestroy(handle) != CUBLAS_STATUS_SUCCESS) {
+            fmt::println("Failed to destroy cublasLt handle");
+            abort();
+        }
+    }
+
+    auto CublasLtContext::typeId() noexcept -> size_t {
+        static uint8_t ID = 1;
+        return reinterpret_cast<size_t>(&ID);
+    }
+    auto CublasLtContext::build() noexcept -> runtime::ResourceBox {
+        return std::make_unique<CublasLtContext>();
+    }
+
+    auto CublasLtContext::resourceTypeId() const noexcept -> size_t {
+        return typeId();
+    }
+    auto CublasLtContext::description() const noexcept -> std::string_view {
+        return "CublasLtContext";
+    }
+
+    cudaDataType dataTypeConvert(DataType dt) {
+        switch (dt) {
+            case DataType::F32:
+                return CUDA_R_32F;
+            default:
+                TODO("");
+        }
+    }
+
+    MatMulDescriptor::MatMulDescriptor(cublasComputeType_t compute, cudaDataType data)
+        : _internal(nullptr) {
+        CUBLASLT_ASSERT(cublasLtMatmulDescCreate(&_internal, compute, data));
+    }
+    MatMulDescriptor::~MatMulDescriptor() {
+        CUBLASLT_ASSERT(cublasLtMatmulDescDestroy(_internal));
+    }
+    cublasLtMatmulDesc_t MatMulDescriptor::get() const noexcept {
+        return _internal;
+    }
+
+    MatrixDescriptor::MatrixDescriptor(MatrixLayout layout)
+        : _internal(nullptr) {
+        CUBLASLT_ASSERT(cublasLtMatrixLayoutCreate(
+            &_internal,
+            layout.dataType,
+            layout.rows,
+            layout.cols,
+            layout.majorStride));
+        CUBLASLT_ASSERT(cublasLtMatrixLayoutSetAttribute(
+            _internal,
+            CUBLASLT_MATRIX_LAYOUT_ORDER,
+            &layout.order,
+            sizeof(layout.order)));
+        CUBLASLT_ASSERT(cublasLtMatrixLayoutSetAttribute(
+            _internal,
+            CUBLASLT_MATRIX_LAYOUT_BATCH_COUNT,
+            &layout.batchCount,
+            sizeof(layout.batchCount)));
+        CUBLASLT_ASSERT(cublasLtMatrixLayoutSetAttribute(
+            _internal,
+            CUBLASLT_MATRIX_LAYOUT_STRIDED_BATCH_OFFSET,
+            &layout.batchStride,
+            sizeof(layout.batchStride)));
+    }
+    MatrixDescriptor::~MatrixDescriptor() {
+        CUBLASLT_ASSERT(cublasLtMatrixLayoutDestroy(_internal));
+    }
+    cublasLtMatrixLayout_t MatrixDescriptor::get() const noexcept {
+        return _internal;
+    }
+
+    std::pair<cublasLtMatmulAlgo_t, size_t>
+    tune(cublasLtHandle_t handle,
+         MatMulDescriptor const &matmul,
+         MatrixDescriptor const &a,
+         MatrixDescriptor const &b,
+         MatrixDescriptor const &c) {
+
+        int device;
+        CUDA_ASSERT(cudaGetDevice(&device));
+        cudaDeviceProp prop;
+        CUDA_ASSERT(cudaGetDeviceProperties(&prop, device));
+
+        auto workspace = std::numeric_limits<uint64_t>::max();// no workspace cap when querying the heuristic
+        auto alignment = prop.textureAlignment;
+
+        cublasLtMatmulPreference_t preference;
+        CUBLASLT_ASSERT(cublasLtMatmulPreferenceCreate(&preference));
+        CUBLASLT_ASSERT(cublasLtMatmulPreferenceSetAttribute(
+            preference,
+            CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES,
+            &workspace,
+            sizeof(workspace)));
+        CUBLASLT_ASSERT(cublasLtMatmulPreferenceSetAttribute(
+            preference,
+            CUBLASLT_MATMUL_PREF_MIN_ALIGNMENT_A_BYTES,
+            &alignment,
+            sizeof(alignment)));
+        CUBLASLT_ASSERT(cublasLtMatmulPreferenceSetAttribute(
+            preference,
+            CUBLASLT_MATMUL_PREF_MIN_ALIGNMENT_B_BYTES,
+            &alignment,
+            sizeof(alignment)));
+        CUBLASLT_ASSERT(cublasLtMatmulPreferenceSetAttribute(
+            preference,
+            CUBLASLT_MATMUL_PREF_MIN_ALIGNMENT_C_BYTES,
+            &alignment,
+            sizeof(alignment)));
+        CUBLASLT_ASSERT(cublasLtMatmulPreferenceSetAttribute(
+            preference,
+            CUBLASLT_MATMUL_PREF_MIN_ALIGNMENT_D_BYTES,
+            &alignment,
+            sizeof(alignment)));
+
+        cublasLtMatmulHeuristicResult_t result;
+        int ansN;
+        CUBLASLT_ASSERT(cublasLtMatmulAlgoGetHeuristic(
+            handle,
+            matmul.get(),
+            a.get(),
+            b.get(),
+            c.get(),
+            c.get(),
+            preference,
+            1,
+            &result,
+            &ansN));
+        ASSERT(ansN == 1, "");
+
+        return {result.algo, result.workspaceSize};
+    }
+
+}// namespace refactor::kernel::cublas
diff --git a/src/04kernel/src/utilities/cuda/cublaslt_utils.cuh b/src/04kernel/src/utilities/cuda/cublaslt_utils.cuh
new file mode 100644
index 00000000..5dd23607
--- /dev/null
+++ b/src/04kernel/src/utilities/cuda/cublaslt_utils.cuh
@@ -0,0 +1,74 @@
+#ifndef KERNEL_CUBLASLT_UTILS_CUH
+#define KERNEL_CUBLASLT_UTILS_CUH
+
+#include "common.h"
+#include "runtime/resource.h"
+#include <cublasLt.h>
+
+#define CUBLASLT_ASSERT(STATUS)                                        \
+    if (auto status = (STATUS); status != CUBLAS_STATUS_SUCCESS) {     \
+        fmt::println("cublasLt failed on \"" #STATUS "\" with {}",     \
+                     (int) status);                                    \
+        abort();                                                       \
+    }
+
+namespace refactor::kernel::cublas {
+
+    struct CublasLtContext final : public runtime::Resource {
+        cublasLtHandle_t handle;
+
+        CublasLtContext();
+        ~CublasLtContext();
+        CublasLtContext(CublasLtContext const &) noexcept = delete;
+        CublasLtContext(CublasLtContext &&) noexcept = delete;
+
+        static size_t typeId() noexcept;
+        static runtime::ResourceBox build() noexcept;
+
+        size_t resourceTypeId() const noexcept final;
+        std::string_view description() const noexcept final;
+    };
+
+    cudaDataType dataTypeConvert(DataType);
+
+    class MatMulDescriptor {
+        cublasLtMatmulDesc_t _internal;
+
+    public:
+        MatMulDescriptor(cublasComputeType_t, cudaDataType);
+        ~MatMulDescriptor();
+        MatMulDescriptor(MatMulDescriptor const &) noexcept = delete;
+        MatMulDescriptor(MatMulDescriptor &&) noexcept = delete;
+        cublasLtMatmulDesc_t get() const noexcept;
+    };
+
+    struct MatrixLayout {
+        cudaDataType dataType;
+        uint64_t rows, cols;
+        int64_t majorStride;
+        cublasLtOrder_t order;
+        int32_t batchCount;
+        int64_t batchStride;
+    };
+
+    class MatrixDescriptor {
+        cublasLtMatrixLayout_t _internal;
+
+    public:
+        MatrixDescriptor(MatrixLayout layout);
+        ~MatrixDescriptor();
+        MatrixDescriptor(MatrixDescriptor const &) noexcept = delete;
+        MatrixDescriptor(MatrixDescriptor &&) noexcept = delete;
+        cublasLtMatrixLayout_t get() const noexcept;
+    };
+
+    std::pair<cublasLtMatmulAlgo_t, size_t>
+    tune(cublasLtHandle_t,
+         MatMulDescriptor const &,
+         MatrixDescriptor const &,
+         MatrixDescriptor const &,
+         MatrixDescriptor const &);
+
+}// namespace refactor::kernel::cublas
+
+#endif// KERNEL_CUBLASLT_UTILS_CUH