Commit
feat: start implementing attention
Signed-off-by: YdrMaster <ydrml@hotmail.com>
YdrMaster committed Jan 31, 2024
1 parent 27a8ad6 commit 90c5180
Showing 12 changed files with 409 additions and 105 deletions.
6 changes: 6 additions & 0 deletions src/02hardware/include/hardware/devices/nvidia.h
@@ -3,6 +3,12 @@

#include "../device.h"

#define CUDA_ASSERT(STATUS) \
if (auto status = (STATUS); status != cudaSuccess) { \
RUNTIME_ERROR(fmt::format("cuda failed on \"" #STATUS "\" with \"{}\" ({})", \
cudaGetErrorString(status), (int) status)); \
}

namespace refactor::hardware {

class Nvidia final : public Device {
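The macro now lives in the public nvidia.h header so every CUDA-facing translation unit shares one definition; the local copies in device.cc and memory.cc are removed below. A minimal usage sketch (the runtime calls here are illustrative, not part of this commit):

// Wraps any call returning cudaError_t; on failure, RUNTIME_ERROR receives
// the stringified expression, cudaGetErrorString(status), and the numeric code.
void *ptr = nullptr;
CUDA_ASSERT(cudaMalloc(&ptr, 4096));
CUDA_ASSERT(cudaMemset(ptr, 0, 4096));
CUDA_ASSERT(cudaFree(ptr));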
6 changes: 0 additions & 6 deletions src/02hardware/src/devices/nvidia/device.cc
@@ -4,12 +4,6 @@
#ifdef USE_CUDA
#include "memory.hh"
#include <cuda_runtime.h>

#define CUDA_ASSERT(STATUS) \
if (auto status = (STATUS); status != cudaSuccess) { \
RUNTIME_ERROR(fmt::format("cuda failed on \"" #STATUS "\" with \"{}\" ({})", \
cudaGetErrorString(status), (int) status)); \
}
#endif

namespace refactor::hardware {
8 changes: 1 addition & 7 deletions src/02hardware/src/devices/nvidia/memory.cc
@@ -1,15 +1,9 @@
#ifdef USE_CUDA

#include "memory.hh"
#include "common.h"
#include "hardware/devices/nvidia.h"
#include <cuda_runtime.h>

#define CUDA_ASSERT(STATUS) \
if (auto status = (STATUS); status != cudaSuccess) { \
RUNTIME_ERROR(fmt::format("cuda failed on \"" #STATUS "\" with \"{}\" ({})", \
cudaGetErrorString(status), (int) status)); \
}

namespace refactor::hardware {
using M = NvidiaMemory;

16 changes: 16 additions & 0 deletions src/04kernel/include/kernel/attributes/attention_info.h
@@ -0,0 +1,16 @@
#ifndef KERNEL_ATTENTION_INFO_H
#define KERNEL_ATTENTION_INFO_H

#include "../tensor.h"

namespace refactor::kernel {

struct AttentionInfo {
DataType dataType;
dim_t batch, nHead, nKVHead, seqLen, headDim, cacheLen;
bool concatCache, resetCache;
};

}// namespace refactor::kernel

#endif// KERNEL_ATTENTION_INFO_H
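As a reading aid (an interpretation of this commit, not stated in it): the collector below fills this struct from the operator's tensors, and the assumed layouts are

// Assumed shape/field correspondence (hedged; only the query mapping is
// explicit in the collector below):
//   query : [batch, nHead,   seqLen,   headDim]
//   key   : [batch, nKVHead, seqLen,   headDim]
//   cache : [batch, nKVHead, cacheLen, headDim]   (outputs[1] when a kv cache is produced)
//   concatCache : append the new keys/values to an existing cache
//   resetCache  : rebuild the cache from scratch before use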
3 changes: 1 addition & 2 deletions src/04kernel/include/kernel/collectors/attention.h
@@ -6,9 +6,8 @@
namespace refactor::kernel {

struct AttentionCollector final : public InfoCollector {
dim_t maxSeqLen;

AttentionCollector(decltype(_target), decltype(maxSeqLen)) noexcept;
AttentionCollector(decltype(_target)) noexcept;

std::vector<KernelBox>
filter(TensorRefs inputs, TensorRefs outputs) const final;
55 changes: 37 additions & 18 deletions src/04kernel/src/collectors/attention.cc
@@ -1,38 +1,57 @@
#include "kernel/collectors/attention.h"
#include "kernel/attributes/attention_info.h"
// #include "../kernels/attention/cpu_kernel.hh"
#include "../kernels/attention/cuda_kernel.hh"

namespace refactor::kernel {

AttentionCollector::AttentionCollector(
decltype(_target) target,
decltype(maxSeqLen) maxSeqLen_) noexcept
: InfoCollector(target),
maxSeqLen(maxSeqLen_) {}
decltype(_target) target) noexcept
: InfoCollector(target) {}

std::vector<KernelBox>
AttentionCollector::filter(TensorRefs inputs, TensorRefs outputs) const {
auto const &query = inputs[0].get();
auto const &key = inputs[1].get();
auto pastSeqLen = inputs.size() == 3 ? 0 : *inputs[2].get().data->get<int64_t>();
auto cacheLen = outputs.size() == 1 ? 0 : outputs[1].get().shape[2];

std::vector<KernelBox> ans;
AttentionInfo info{
.dataType = query.dataType,
.batch = query.shape[0],
.nHead = query.shape[1],
.nKVHead = key.shape[1],
.seqLen = query.shape[2],
.headDim = query.shape[3],
.cacheLen = 0,
.concatCache = false,
.resetCache = false,
};
switch (outputs.size()) {
case 1:
// no kv cache
ASSERT(inputs.size() == 3, "");
break;
case 3:
switch (inputs.size()) {
case 6:
info.resetCache = true;
case 4:
info.concatCache = true;
case 3:
info.cacheLen = outputs[1].get().shape[2];
break;
default:
UNREACHABLE();
}
break;
default:
UNREACHABLE();
}

std::vector<KernelBox> ans;
switch (_target) {
case decltype(_target)::Cpu:
break;
case decltype(_target)::Nvidia: {
decltype(AttentionCuda::info) info{
.dataType = query.dataType,
.batch = query.shape[0],
.nHead = query.shape[1],
.nKVHead = key.shape[1],
.pastSeqLen = static_cast<dim_t>(pastSeqLen),
.seqLen = query.shape[2],
.cacheLen = cacheLen,
.headDim = query.shape[3],
.resetCache = false,
};
if (auto ptr = AttentionCuda::build(info); ptr) {
ans.emplace_back(std::move(ptr));
}
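A plain reading of the new filter logic (an interpretation of the case counts above, not documented operator semantics): the output count decides whether a kv cache is produced, and the input count selects the cache behaviour through deliberate case fallthrough.

// Input/output arity -> AttentionInfo flags, as encoded by the switch above:
//   3 inputs, 1 output  : no kv cache
//   3 inputs, 3 outputs : cache produced, cacheLen = outputs[1].shape[2]
//   4 inputs, 3 outputs : additionally concatCache (case 4 falls through to case 3)
//   6 inputs, 3 outputs : additionally resetCache (case 6 falls through to 4 and 3)
// Any other combination hits UNREACHABLE().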
127 changes: 127 additions & 0 deletions src/04kernel/src/kernels/attention/cuda_kernel.cu
@@ -0,0 +1,127 @@
#include "../../utilities/cuda/cublaslt_utils.cuh"
#include "cuda_kernel.hh"
#include "hardware/functions.h"

namespace refactor::kernel {
using K = AttentionCuda;
using namespace cublas;

RoutineWorkspace K::lower(Resources &res) const {
auto handle = res.fetchOrStore<CublasLtContext>()->handle;

constexpr auto ROW_MAJOR = CUBLASLT_ORDER_ROW;
constexpr auto COL_MAJOR = CUBLASLT_ORDER_COL;

if (!info.cacheLen) {
if (info.nHead == info.nKVHead) {
// RAII for closure
struct Descriptors {
MatMulDescriptor mul;
MatrixDescriptor q, k, v, att;
cublasLtMatmulAlgo_t algoQK, algoAV;
size_t attSize, workspaceSizeQK, workspaceSizeAV;

Descriptors(CublasLtContext const &context,
cublasComputeType_t compute,
AttentionInfo info)
: mul(compute, CUDA_R_32F),
q(MatrixLayout{
.dataType = dataTypeConvert(info.dataType),
.rows = static_cast<uint64_t>(info.seqLen),
.cols = static_cast<uint64_t>(info.headDim),
.majorStride = static_cast<int64_t>(info.headDim),
.order = ROW_MAJOR,
.batchCount = static_cast<int32_t>(info.batch * info.nHead),
.batchStride = static_cast<int64_t>(info.seqLen * info.headDim),
}),
k(MatrixLayout{
.dataType = dataTypeConvert(info.dataType),
.rows = static_cast<uint64_t>(info.headDim),
.cols = static_cast<uint64_t>(info.seqLen),
.majorStride = static_cast<int64_t>(info.headDim),
.order = COL_MAJOR,
.batchCount = static_cast<int32_t>(info.batch * info.nHead),
.batchStride = static_cast<int64_t>(info.seqLen * info.headDim),
}),
v(MatrixLayout{
.dataType = dataTypeConvert(info.dataType),
.rows = static_cast<uint64_t>(info.seqLen),
.cols = static_cast<uint64_t>(info.headDim),
.majorStride = static_cast<int64_t>(info.headDim),
.order = ROW_MAJOR,
.batchCount = static_cast<int32_t>(info.batch * info.nHead),
.batchStride = static_cast<int64_t>(info.seqLen * info.headDim),
}),
att(MatrixLayout{
.dataType = dataTypeConvert(info.dataType),
.rows = static_cast<uint64_t>(info.seqLen),
.cols = static_cast<uint64_t>(info.seqLen),
.majorStride = static_cast<int64_t>(info.seqLen),
.order = ROW_MAJOR,
.batchCount = static_cast<int32_t>(info.batch * info.nHead),
.batchStride = static_cast<int64_t>(info.seqLen * info.seqLen),
}),
attSize(info.batch * info.nHead * info.seqLen * info.seqLen * info.dataType.size()) {
auto [algoQK_, workspaceSizeQK_] = tune(context.handle, mul, q, k, att);
auto [algoAV_, workspaceSizeAV_] = tune(context.handle, mul, att, v, q);
algoQK = algoQK_;
algoAV = algoAV_;
workspaceSizeQK = workspaceSizeQK_;
workspaceSizeAV = workspaceSizeAV_;
}
};

auto const &context = *res.fetchOrStore<CublasLtContext>();
auto d = std::make_shared<Descriptors>(context, CUBLAS_COMPUTE_32F, info);
auto workspaceSize = d->attSize;
workspaceSize = hardware::alignBytes(workspaceSize, 256);
workspaceSize += d->workspaceSizeQK;
workspaceSize = hardware::alignBytes(workspaceSize, 256);
workspaceSize += d->workspaceSizeAV;
workspaceSize = hardware::alignBytes(workspaceSize, 256);

auto routine = [d = std::move(d), info = this->info]//
(Resources & res, void *workspace, void const *const *inputs, void *const *outputs) {
auto handle = res.fetchOrStore<CublasLtContext>()->handle;
auto q = inputs[0];
auto k = inputs[1];
auto v = inputs[2];
auto o = outputs[0];
auto att = workspace;
auto workspaceQK = reinterpret_cast<uint8_t *>(workspace) + hardware::alignBytes(d->attSize, 256);
auto workspaceAV = workspaceQK + hardware::alignBytes(d->workspaceSizeQK, 256);

float alpha = 1, beta = 0;
cublasLtMatmul(
handle, d->mul.get(),
&alpha,
q, d->q.get(),
k, d->k.get(),
&beta,
att, d->att.get(),
att, d->att.get(),
&d->algoQK,
workspaceQK, d->workspaceSizeQK,
cudaStreamLegacy);

// TODO inline mask && softmax

cublasLtMatmul(
handle, d->mul.get(),
&alpha,
att, d->att.get(),
v, d->v.get(),
&beta,
o, d->q.get(),
o, d->q.get(),
&d->algoAV,
workspaceAV, d->workspaceSizeAV,
cudaStreamLegacy);
};
return {std::move(routine), workspaceSize};
}
}
TODO("");
}

}// namespace refactor::kernel
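For reference, the two cublasLtMatmul calls above are the two batched GEMMs of scaled dot-product attention; in this commit the scaling, mask, and softmax between them are still missing (alpha is fixed at 1 and the TODO marks where they belong). The intended per-head computation is

    Att = softmax(Q · K^T / sqrt(headDim) + mask),    O = Att · V

The workspace passed to the routine is laid out as three 256-byte-aligned regions, matching the offset arithmetic in the lambda: the Att buffer of attSize bytes, then the cuBLASLt workspace for the Q·K^T matmul, then the workspace for the Att·V matmul.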
8 changes: 2 additions & 6 deletions src/04kernel/src/kernels/attention/cuda_kernel.hh
@@ -1,17 +1,13 @@
#ifndef KERNEL_ATTENTION_CUDA_KERNEL_HH
#define KERNEL_ATTENTION_CUDA_KERNEL_HH

#include "kernel/attributes/attention_info.h"
#include "kernel/kernel.h"
#include "kernel/tensor.h"

namespace refactor::kernel {

struct AttentionCuda final : public Kernel {
struct {
DataType dataType;
dim_t batch, nHead, nKVHead, pastSeqLen, seqLen, cacheLen, headDim;
bool resetCache;
} info;
AttentionInfo info;

AttentionCuda(decltype(info)) noexcept;

33 changes: 0 additions & 33 deletions src/04kernel/src/utilities/cuda/cublaslt_context.cu

This file was deleted.

33 changes: 0 additions & 33 deletions src/04kernel/src/utilities/cuda/cublaslt_context.hh

This file was deleted.

