diff --git a/src/04kernel/cuda/include/kernel/cuda/functions.cuh b/src/04kernel/cuda/include/kernel/cuda/functions.cuh
index 23fcf8f7..9fb17edc 100644
--- a/src/04kernel/cuda/include/kernel/cuda/functions.cuh
+++ b/src/04kernel/cuda/include/kernel/cuda/functions.cuh
@@ -6,11 +6,15 @@ namespace refactor::kernel::cuda {
 
     int currentDevice();
 
     void sync();
-
+    void setCudaDevice(int);
 
     void copyOut(void *dst, const void *src, size_t size);
 
+    template<class T, class Tb>
+    inline T round_up(T m, Tb d) {
+        return ((m + T(d) - 1) / T(d)) * T(d);
+    }
 }// namespace refactor::kernel::cuda
 
 #endif// KERNEL_CUDA_FUNCTIONS_CUH
diff --git a/src/04kernel/cuda/include/kernel/cuda/rope.cuh b/src/04kernel/cuda/include/kernel/cuda/rope.cuh
new file mode 100644
index 00000000..704eca51
--- /dev/null
+++ b/src/04kernel/cuda/include/kernel/cuda/rope.cuh
@@ -0,0 +1,17 @@
+#ifndef KERNEL_CUDA_ROPE_CUH
+#define KERNEL_CUDA_ROPE_CUH
+
+namespace refactor::kernel::cuda {
+    void launchRoPE(
+        void const *input,
+        int64_t const *posIDs,
+        void *output,
+        unsigned int batchSize,
+        unsigned int seqLen,
+        unsigned int nHeads,
+        unsigned int headDim,
+        float theta,
+        bool useHalf);
+}
+
+#endif
diff --git a/src/04kernel/cuda/src/rope.cu b/src/04kernel/cuda/src/rope.cu
new file mode 100644
index 00000000..900e720e
--- /dev/null
+++ b/src/04kernel/cuda/src/rope.cu
@@ -0,0 +1,84 @@
+#include "kernel/cuda/functions.cuh"
+#include "kernel/cuda/rope.cuh"
+#include <cuda_fp16.h>
+#include <cstdint>
+
+namespace refactor::kernel::cuda {
+
+    template<class T>
+    __global__ static void RoPEKernel(
+        unsigned int dim_model,// dim_model = num_head * dim_head
+        unsigned int dim_head,
+        unsigned int hidden_stride,// hidden_stride = seq_len * dim_model
+        unsigned int pos_stride,   // pos_stride = seq_len
+        float theta,               // rotary base, e.g. 10000
+        const int64_t *__restrict__ pos,// (batch, seq_len)
+        const T *__restrict__ in,       // (batch, seq_len, num_head, dim_head)
+        T *__restrict__ out             // (batch, seq_len, num_head, dim_head)
+    ) {
+        unsigned int batch_id = blockIdx.x;
+        int64_t target_pos = pos[batch_id * pos_stride + blockIdx.y];
+        size_t ith = blockIdx.z * blockDim.x + threadIdx.x;
+        unsigned int col = ith % dim_head;
+        size_t offset = batch_id * hidden_stride + blockIdx.y * dim_model;
+
+        if (ith >= dim_model)
+            return;
+
+        unsigned int half_dim = dim_head / 2;
+        if (col < half_dim) {
+            float freq = target_pos * powf(theta, -float(col * 2) / dim_head);
+            float cos_freq = cos(freq);
+            float sin_freq = sin(freq);
+            out[offset + ith] =
+                in[offset + ith] * T(cos_freq) - in[offset + ith + half_dim] * T(sin_freq);
+        } else {
+            float freq = target_pos * powf(theta, -float((col - half_dim) * 2) / dim_head);
+            float cos_freq = cos(freq);
+            float sin_freq = sin(freq);
+            out[offset + ith] =
+                in[offset + ith] * T(cos_freq) + in[offset + ith - half_dim] * T(sin_freq);
+        }
+    }
+
+
+    void launchRoPE(
+        void const *input,
+        int64_t const *posIDs,
+        void *output,
+        unsigned int batchSize,
+        unsigned int seqLen,
+        unsigned int nHeads,
+        unsigned int headDim,
+        float theta,
+        bool useHalf) {
+
+        unsigned int dimModel = nHeads * headDim;
+        unsigned int hiddenStride = seqLen * dimModel;
+        unsigned int threads = min(1024, round_up(dimModel, 32));
+        dim3 gridDim(batchSize, seqLen, round_up(dimModel, threads) / threads);
+        dim3 blockDim(threads, 1, 1);
+        if (useHalf) {
+            RoPEKernel<half><<<gridDim, blockDim, 0, 0>>>(
+                dimModel,
+                headDim,
+                hiddenStride,
+                seqLen,
+                theta,
+                reinterpret_cast<int64_t const *>(posIDs),
+                reinterpret_cast<half const *>(input),
+                reinterpret_cast<half *>(output));
+
+        } else {
+            RoPEKernel<float><<<gridDim, blockDim, 0, 0>>>(
+                dimModel,
+                headDim,
+                hiddenStride,
+                seqLen,
+                theta,
+                reinterpret_cast<int64_t const *>(posIDs),
+                reinterpret_cast<float const *>(input),
+                reinterpret_cast<float *>(output));
+        }
+    }
+}// namespace refactor::kernel::cuda
diff --git a/src/04kernel/include/kernel/attributes/rope_info.h b/src/04kernel/include/kernel/attributes/rope_info.h
new file mode 100644
index 00000000..b88f5fa1
--- /dev/null
+++ b/src/04kernel/include/kernel/attributes/rope_info.h
@@ -0,0 +1,17 @@
+#ifndef KERNEL_ROPE_INFO_H
+#define KERNEL_ROPE_INFO_H
+
+#include "../tensor.h"
+
+namespace refactor::kernel {
+    struct RoPEInfo {
+        dim_t batchsize = 1;
+        dim_t seq_len, n_heads, head_dim;
+        float theta;
+
+        RoPEInfo(Tensor const &input, float _theta);
+    };
+
+}// namespace refactor::kernel
+
+#endif
diff --git a/src/04kernel/include/kernel/collectors/rope.h b/src/04kernel/include/kernel/collectors/rope.h
new file mode 100644
index 00000000..7a475acd
--- /dev/null
+++ b/src/04kernel/include/kernel/collectors/rope.h
@@ -0,0 +1,18 @@
+#ifndef KERNEL_ROPE_COLLECTOR_H
+#define KERNEL_ROPE_COLLECTOR_H
+
+#include "../collector.h"
+
+namespace refactor::kernel {
+    struct RoPECollector final : public InfoCollector {
+        float theta;
+        constexpr RoPECollector(decltype(_target) target, float _theta) noexcept
+            : InfoCollector(target), theta(_theta) {}
+
+        std::vector<KernelBox>
+        filter(TensorRefs inputs, TensorRefs outputs) const final;
+    };
+
+}// namespace refactor::kernel
+
+#endif
diff --git a/src/04kernel/src/attributes/rope_info.cc b/src/04kernel/src/attributes/rope_info.cc
new file mode 100644
index 00000000..73b6afec
--- /dev/null
+++ b/src/04kernel/src/attributes/rope_info.cc
@@ -0,0 +1,17 @@
+#include "kernel/attributes/rope_info.h"
+
+namespace refactor::kernel {
+    RoPEInfo::RoPEInfo(Tensor const &input, float _theta) : theta(_theta) {
+        if (input.rank() == 4) {
+            batchsize = input.shape[0];
+            seq_len = input.shape[1];
+            n_heads = input.shape[2];
+            head_dim = input.shape[3];
+        } else {
+            batchsize = 1;
+            seq_len = input.shape[0];
+            n_heads = input.shape[1];
+            head_dim = input.shape[2];
+        }
+    }
+}// namespace refactor::kernel
diff --git a/src/04kernel/src/collectors/rope.cc b/src/04kernel/src/collectors/rope.cc
new file mode 100644
index 00000000..93a4075e
--- /dev/null
+++ b/src/04kernel/src/collectors/rope.cc
@@ -0,0 +1,27 @@
+#include "kernel/collectors/rope.h"
+#include "kernel/attributes/rope_info.h"
+#include "../kernels/rope/cpu_kernel.hh"
+#include "../kernels/rope/cuda_kernel.hh"
+namespace refactor::kernel {
+    std::vector<KernelBox>
+    RoPECollector::filter(TensorRefs inputs, TensorRefs outputs) const {
+        RoPEInfo info(inputs[0], theta);
+
+        std::vector<KernelBox> ans;
+        switch (_target) {
+            case decltype(_target)::Cpu:
+                if (auto ptr = RoPECpu::build(info, inputs[0]); ptr != nullptr) {
+                    ans.emplace_back(std::move(ptr));
+                }
+                break;
+            case decltype(_target)::Nvidia:
+                if (auto ptr = RoPECuda::build(info, inputs[0]); ptr != nullptr) {
+                    ans.emplace_back(std::move(ptr));
+                }
+                break;
+            default:
+                UNREACHABLEX(void, "Unknown target");
+        }
+        return ans;
+    }
+}// namespace refactor::kernel
diff --git a/src/04kernel/src/kernels/rope/cpu_kernel.cc b/src/04kernel/src/kernels/rope/cpu_kernel.cc
new file mode 100644
index 00000000..c26345e7
--- /dev/null
+++ b/src/04kernel/src/kernels/rope/cpu_kernel.cc
@@ -0,0 +1,67 @@
+#include "cpu_kernel.hh"
+#include <cmath>
+
+namespace refactor::kernel {
+    using K = RoPECpu;
+
+    K::RoPECpu(decltype(info) info_) noexcept
+        : Kernel(), info(info_) {}
+
+    auto K::build(decltype(info) info, Tensor const &x) noexcept -> KernelBox {
+        if (x.dataType != DataType::F32) {
+            return nullptr;
+        }
+        return std::make_unique<K>(info);
+    }
+    auto K::typeId() noexcept -> size_t {
+        static uint8_t ID = 1;
+        return reinterpret_cast<size_t>(&ID);
+    }
+
+    auto K::kernelTypeId() const noexcept -> size_t {
+        return typeId();
+    }
+    auto K::description() const noexcept -> std::string_view {
+        return "Performing rotary position embedding on cpu";
+    }
+
+
+    auto K::lower(Resources &) const -> RoutineWorkspace {
+        return [batchsize = this->info.batchsize,
+                seq_len = this->info.seq_len,
+                n_heads = this->info.n_heads,
+                head_dim = this->info.head_dim,
+                theta = this->info.theta]//
+            (Resources &, void *, void const *const *inputs, void *const *outputs) {
+                auto input = reinterpret_cast<float const *>(inputs[0]);
+                auto pos_ids = reinterpret_cast<int64_t const *>(inputs[1]);
+                auto output = reinterpret_cast<float *>(outputs[0]);
+                auto half_dim = head_dim / 2;
+
+                for (unsigned int batch_id = 0; batch_id < batchsize; batch_id++) {
+                    for (unsigned int pos = 0; pos < seq_len; pos++) {
+                        auto pos_id = pos_ids[batch_id * seq_len + pos];
+                        for (unsigned int head = 0; head < n_heads; head++) {
+                            auto offset = batch_id * seq_len * n_heads * head_dim + pos * n_heads * head_dim + head * head_dim;
+                            for (unsigned int i = 0; i < head_dim; i++) {
+                                if (i < half_dim) {
+                                    float freq = pos_id * powf(theta, -float(i * 2) / head_dim);
+                                    float cos_freq = cos(freq);
+                                    float sin_freq = sin(freq);
+                                    output[offset + i] =
+                                        input[offset + i] * cos_freq - input[offset + i + half_dim] * sin_freq;
+                                } else {
+                                    float freq = pos_id * powf(theta, -float((i - half_dim) * 2) / head_dim);
+                                    float cos_freq = cos(freq);
+                                    float sin_freq = sin(freq);
+                                    output[offset + i] =
+                                        input[offset + i] * cos_freq + input[offset + i - half_dim] * sin_freq;
+                                }
+                            }
+                        }
+                    }
+                }
+            };
+    }
+
+}// namespace refactor::kernel
diff --git a/src/04kernel/src/kernels/rope/cpu_kernel.hh b/src/04kernel/src/kernels/rope/cpu_kernel.hh
new file mode 100644
index 00000000..48cd779f
--- /dev/null
+++ b/src/04kernel/src/kernels/rope/cpu_kernel.hh
@@ -0,0 +1,24 @@
+#ifndef KERNEL_ROPE_CPU_KERNEL_HH
+#define KERNEL_ROPE_CPU_KERNEL_HH
+
+#include "kernel/attributes/rope_info.h"
+#include "kernel/kernel.h"
+#include "kernel/tensor.h"
+
+namespace refactor::kernel {
+    struct RoPECpu final : public Kernel {
+        RoPEInfo info;
+
+        RoPECpu(decltype(info)) noexcept;
+
+        static KernelBox build(decltype(info), Tensor const &x) noexcept;
+        static size_t typeId() noexcept;
+
+        size_t kernelTypeId() const noexcept final;
+        std::string_view description() const noexcept final;
+        RoutineWorkspace lower(Resources &) const final;
+    };
+
+}// namespace refactor::kernel
+
+#endif
diff --git a/src/04kernel/src/kernels/rope/cuda_kernel.cc b/src/04kernel/src/kernels/rope/cuda_kernel.cc
new file mode 100644
index 00000000..b1c82bf5
--- /dev/null
+++ b/src/04kernel/src/kernels/rope/cuda_kernel.cc
@@ -0,0 +1,49 @@
+#include "cuda_kernel.hh"
+
+#ifdef USE_CUDA
+#include "kernel/cuda/rope.cuh"
+#endif
+namespace refactor::kernel {
+    using K = RoPECuda;
+
+    K::RoPECuda(decltype(info) info_, DataType _dtype) noexcept
+        : Kernel(), info(info_), dtype(_dtype) {}
+
+    auto K::build(decltype(info) info, Tensor const &x) noexcept -> KernelBox {
+#ifndef USE_CUDA
+        return nullptr;
+#endif
+        return std::make_unique<K>(info, x.dataType);
+    }
+
+    auto K::typeId() noexcept -> size_t {
+        static uint8_t ID = 1;
+        return reinterpret_cast<size_t>(&ID);
+    }
+
+    auto K::kernelTypeId() const noexcept -> size_t {
+        return typeId();
+    }
+    auto K::description() const noexcept -> std::string_view {
+        return "Performing rotary position embedding on Nvidia gpu";
+    }
+#ifdef USE_CUDA
+    auto K::lower(Resources &) const -> RoutineWorkspace {
+
+        return [info = this->info, useHalf = this->dtype == DataType::FP16]//
+            (Resources &, void *workspace, void const *const *inputs, void *const *outputs) {
+                cuda::launchRoPE(
+                    inputs[0],
+                    reinterpret_cast<int64_t const *>(inputs[1]),
+                    outputs[0],
+                    info.batchsize,
+                    info.seq_len,
+                    info.n_heads,
+                    info.head_dim,
+                    info.theta,
+                    useHalf);
+            };
+    }
+#endif
+
+}// namespace refactor::kernel
diff --git a/src/04kernel/src/kernels/rope/cuda_kernel.hh b/src/04kernel/src/kernels/rope/cuda_kernel.hh
new file mode 100644
index 00000000..c189ace1
--- /dev/null
+++ b/src/04kernel/src/kernels/rope/cuda_kernel.hh
@@ -0,0 +1,26 @@
+#ifndef KERNEL_ROPE_CUDA_KERNEL_HH
+#define KERNEL_ROPE_CUDA_KERNEL_HH
+
+#include "kernel/attributes/rope_info.h"
+#include "kernel/kernel.h"
+#include "kernel/tensor.h"
+
+namespace refactor::kernel {
+    struct RoPECuda final : public Kernel {
+        RoPEInfo info;
+        DataType dtype;
+
+        RoPECuda(decltype(info), DataType) noexcept;
+
+        static KernelBox build(decltype(info), Tensor const &x) noexcept;
+        static size_t typeId() noexcept;
+
+        size_t kernelTypeId() const noexcept final;
+        std::string_view description() const noexcept final;
+#ifdef USE_CUDA
+        RoutineWorkspace lower(Resources &) const final;
+#endif
+    };
+}// namespace refactor::kernel
+
+#endif
diff --git a/src/04kernel/test/kernels/rope/test_cuda.cpp b/src/04kernel/test/kernels/rope/test_cuda.cpp
new file mode 100644
index 00000000..9abf00de
--- /dev/null
+++ b/src/04kernel/test/kernels/rope/test_cuda.cpp
@@ -0,0 +1,51 @@
+#ifdef USE_CUDA
+
+#include "../../../src/kernels/rope/cpu_kernel.hh"
+#include "../../../src/kernels/rope/cuda_kernel.hh"
+#include "hardware/device_manager.h"
+#include <gtest/gtest.h>
+#include <numeric>
+
+using namespace refactor;
+using namespace kernel;
+using namespace hardware;
+
+TEST(kernel, RoPECuda) {
+    auto input = Tensor::share(DataType::F32, Shape{2, 3, 1, 2});
+    auto posIDs = Tensor::share(DataType::I64, Shape{2, 3});
+    auto output = Tensor::share(DataType::F32, Shape{2, 3, 1, 2});
+    auto res = runtime::Resources();
+    RoPEInfo info(*input, 10000.0f);
+    auto cudaKernel = RoPECuda::build(info, *input)->lower(res).routine;
+    auto cpuKernel = RoPECpu::build(info, *input)->lower(res).routine;
+    // malloc
+    auto &dev = *device::init(Device::Type::Nvidia, 0, "");
+    auto inputGpu = dev.malloc(input->bytesSize()),
+         posIDsGpu = dev.malloc(posIDs->bytesSize()),
+         outputGpu = dev.malloc(output->bytesSize());
+    // put input data
+    std::vector<float> output_(output->elementsSize());
+    std::vector<int64_t> posIDs_({0, 1, 2, 0, 1, 2});
+    std::vector<float> input_(input->elementsSize(), 0.2);
+
+    inputGpu->copyFromHost(input_.data(), input->bytesSize());
+    posIDsGpu->copyFromHost(posIDs_.data(), posIDs->bytesSize());
+    {
+        void const *inputs[]{*inputGpu, *posIDsGpu};
+        void *outputs[]{*outputGpu};
+        cudaKernel(res, nullptr, inputs, outputs);
+    }
+    {
+        void const *inputs[]{input_.data(), posIDs_.data()};
+        void *outputs[]{output_.data()};
+        cpuKernel(res, nullptr, inputs, outputs);
+    }
+    // check
+    std::vector<float> result(output->elementsSize());
+    outputGpu->copyToHost(result.data(), output->bytesSize());
+    for (auto i : range0_(output_.size())) {
+        EXPECT_FLOAT_EQ(result[i], output_[i]);
+    }
+}
+
+#endif
diff --git a/src/05computation/include/computation/operators/rope.h b/src/05computation/include/computation/operators/rope.h
new file mode 100644
index 00000000..155e55a2
--- /dev/null
+++ b/src/05computation/include/computation/operators/rope.h
@@ -0,0 +1,23 @@
+#ifndef COMPUTATION_ROPE_H
+#define COMPUTATION_ROPE_H
+
+#include "../operator.h"
+
+namespace refactor::computation {
+
+    struct RotaryPositionEmbedding final : public Operator {
+        float theta;
+
+        constexpr RotaryPositionEmbedding(float _theta) noexcept
+            : Operator(), theta(_theta) {}
+
+        static size_t typeId() noexcept;
+        size_t opTypeId() const noexcept final;
+        std::string_view name() const noexcept final;
+        kernel::CollectorBox candidateKernels(Target) const final;
+        std::string serialize() const noexcept final;
+    };
+
+}// namespace refactor::computation
+
+#endif
diff --git a/src/05computation/src/operators/rope.cc b/src/05computation/src/operators/rope.cc
new file mode 100644
index 00000000..4fe84977
--- /dev/null
+++ b/src/05computation/src/operators/rope.cc
@@ -0,0 +1,26 @@
+#include "computation/operators/rope.h"
+#include "kernel/collectors/rope.h"
+
+namespace refactor::computation {
+    using Op = RotaryPositionEmbedding;
+
+    auto Op::typeId() noexcept -> size_t {
+        static uint8_t ID = 1;
+        return reinterpret_cast<size_t>(&ID);
+    }
+    auto Op::opTypeId() const noexcept -> size_t { return typeId(); }
+    auto Op::name() const noexcept -> std::string_view { return "RotaryPositionEmbedding"; }
+    auto Op::candidateKernels(Target target) const -> kernel::CollectorBox {
+        using Collector_ = kernel::RoPECollector;
+        return std::make_unique<Collector_>(target, theta);
+    }
+    auto Op::serialize() const noexcept -> std::string {
+        union code {
+            float f;
+            int32_t i;
+        };
+        return fmt::format(("{}({:e}={:#010x})"),
+                           name(), theta,
+                           code{theta}.i);
+    }
+}// namespace refactor::computation
diff --git a/src/08-01llm/src/operators.cpp b/src/08-01llm/src/operators.cpp
index a99adb08..f97db344 100644
--- a/src/08-01llm/src/operators.cpp
+++ b/src/08-01llm/src/operators.cpp
@@ -2,6 +2,7 @@
 #include "operators/attention.hh"
 #include "operators/mat_mul.hh"
 #include "operators/rms_normalization.hh"
+#include "operators/rope.hh"
 
 namespace refactor::llm {
     using namespace frontend;
@@ -9,9 +10,10 @@ namespace refactor::llm {
     void register_() {
 #define REGISTER(NAME, CLASS) Operator::register_<CLASS>("llm::" #NAME)
         // clang-format off
-        REGISTER(Attention       , Attention       );
-        REGISTER(RmsNormalization, RmsNormalization);
-        REGISTER(MatMul          , MatMul          );
+        REGISTER(Attention              , Attention              );
+        REGISTER(RmsNormalization       , RmsNormalization       );
+        REGISTER(MatMul                 , MatMul                 );
+        REGISTER(RotaryPositionEmbedding, RotaryPositionEmbedding);
         // clang-format on
 #undef REGISTER
     }
diff --git a/src/08-01llm/src/operators/rope.cc b/src/08-01llm/src/operators/rope.cc
new file mode 100644
index 00000000..54d96ba9
--- /dev/null
+++ b/src/08-01llm/src/operators/rope.cc
@@ -0,0 +1,66 @@
+#include "rope.hh"
+#include "common.h"
+#include "computation/operators/rope.h"
+namespace refactor::llm {
+    using Op = RotaryPositionEmbedding;
+
+    Op::RotaryPositionEmbedding(float _theta) : Operator(), theta(_theta) {}
+
+    auto Op::build(ModelContext const &, std::string_view, Attributes attributes) -> OpBox {
+        auto _theta = attributes.getOrInsert("theta", {10000}).float_();
+        return OpBox(std::make_unique<Op>(_theta));
+    }
+
+    auto Op::typeId() -> size_t {
+        static uint8_t ID = 1;
+        return reinterpret_cast<size_t>(&ID);
+    }
+
+    auto Op::opTypeId() const -> size_t { return typeId(); }
+    auto Op::opTypeName() const -> std::string_view { return "llm::RotaryPositionEmbedding"; }
+
+    auto Op::infer(TensorRefs inputs, InferOptions const &) const -> InferResult {
+        EXPECT_SIZE(2)
+
+        auto input = inputs[0];  // (batchsize?, seq_len, n_heads, head_dim)
+        auto pos_ids = inputs[1];// (batchsize?, seq_len)
+
+        // Check shapes
+        if (input.rank() == 4) {
+            if (pos_ids.rank() != 2) {
+                return Err(InferError(ERROR_MSG("Invalid pos_ids shape")));
+            }
+            if (input.shape[0] != pos_ids.shape[0]) {
+                return Err(InferError(ERROR_MSG("Batchsizes not matched")));
+            }
+            if (input.shape[1] != pos_ids.shape[1]) {
+                return Err(InferError(ERROR_MSG("Seq_len not matched")));
+            }
+        } else if (input.rank() == 3) {
+            if (pos_ids.rank() != 1) {
+                return Err(InferError(ERROR_MSG("Invalid pos_ids shape")));
+            }
+            if (input.shape[0] != pos_ids.shape[0]) {
+                return Err(InferError(ERROR_MSG("Seq_len not matched")));
+            }
+        } else {
+            return Err(InferError(ERROR_MSG("Invalid input shape")));
+        }
+
+        // Check types
+        if (input.dataType != DataType::F32 && input.dataType != DataType::FP16) {
+            return Err(InferError(ERROR_MSG("Invalid input dtype")));
+        }
+        if (pos_ids.dataType != DataType::I64) {
+            return Err(InferError(ERROR_MSG("Invalid pos_ids dtype")));
+        }
+
+        return Ok(Tensors{Tensor::share(input.dataType, input.shape, extractDependency(inputs))});
+    }
+
+    auto Op::lower(TensorRefs) const -> computation::OpBox {
+        using Op_ = computation::RotaryPositionEmbedding;
+        return std::make_unique<Op_>(theta);
+    }
+
+}// namespace refactor::llm
diff --git a/src/08-01llm/src/operators/rope.hh b/src/08-01llm/src/operators/rope.hh
new file mode 100644
index 00000000..58b5b395
--- /dev/null
+++ b/src/08-01llm/src/operators/rope.hh
@@ -0,0 +1,28 @@
+#ifndef LLM_ROPE_HH
+#define LLM_ROPE_HH
+
+#include "frontend/operator.h"
+
+namespace refactor::llm {
+    using namespace frontend;
+
+    struct RotaryPositionEmbedding final : public Operator {
+        float theta;
+
+        explicit RotaryPositionEmbedding(float _theta);
+
+        static OpBox build(ModelContext const &, std::string_view, Attributes);
+
+        static size_t typeId();
+        size_t opTypeId() const final;
+        std::string_view opTypeName() const final;
+
+        InferResult infer(TensorRefs, InferOptions const &) const final;
+
+        computation::OpBox lower(TensorRefs) const final;
+
+    };
+}// namespace refactor::llm
+
+
+#endif
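Reviewer note, not part of the patch: both kernels use the rotate-half layout, pairing element i of a head (i < head_dim/2) with element i + head_dim/2 and rotating the pair by the angle pos * theta^(-2i/head_dim). The following is a minimal standalone C++ sketch of that reference math; the helper name ropeReference is hypothetical and independent of the repo's types, but it is handy for checking a kernel output by hand.

#include <cmath>
#include <vector>

// Apply RoPE to a single head (size head_dim) at sequence position `pos`.
// Mirrors the per-element formula in RoPEKernel / the CPU routine above.
std::vector<float> ropeReference(std::vector<float> const &head, long pos, float theta = 1e4f) {
    auto d = head.size(), half = d / 2;
    std::vector<float> out(d);
    for (size_t i = 0; i < half; ++i) {
        // angle grows with position and shrinks for higher frequency pairs
        float angle = float(pos) * std::pow(theta, -float(2 * i) / float(d));
        float c = std::cos(angle), s = std::sin(angle);
        out[i] = head[i] * c - head[i + half] * s;
        out[i + half] = head[i + half] * c + head[i] * s;
    }
    return out;
}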