-
Notifications
You must be signed in to change notification settings - Fork 14
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #93 from InfiniTensor/kernel-rope
feat (kernel): 增加RoPE算子
- Loading branch information
Showing
17 changed files
with
547 additions
and
4 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,17 @@ | ||
#ifndef KERNEL_CUDA_ROPE_CUH | ||
#define KERNEL_CUDA_ROPE_CUH | ||
|
||
namespace refactor::kernel::cuda { | ||
void launchRoPE( | ||
void const *input, | ||
int64_t const *posIDs, | ||
void *output, | ||
unsigned int batchSize, | ||
unsigned int seqLen, | ||
unsigned int nHeads, | ||
unsigned int headDim, | ||
float theta, | ||
bool useHalf); | ||
} | ||
|
||
#endif |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,81 @@ | ||
#include "kernel/cuda/functions.cuh" | ||
#include "kernel/cuda/rope.cuh" | ||
#include <cstdio> | ||
#include <cuda_fp16.h> | ||
|
||
namespace refactor::kernel::cuda { | ||
|
||
template<class T> | ||
__global__ static void RoPEKernel( | ||
unsigned int dim_model,// dim_model = num_head * dim_head | ||
unsigned int dim_head, | ||
unsigned int hidden_stride, // hidden_stride = seq_len * dim_model | ||
unsigned int pos_stride, // pos_stride = seq_len | ||
const int64_t *__restrict__ pos,// (batch, seq_len) | ||
const T *__restrict__ in, // (batch, seq_len, num_head, dim_head) | ||
T *__restrict__ out // (batch, seq_len, num_head, dim_head) | ||
) { | ||
unsigned int batch_id = blockIdx.x; | ||
int64_t target_pos = pos[batch_id * pos_stride + blockIdx.y]; | ||
size_t ith = blockIdx.z * blockDim.x + threadIdx.x; | ||
unsigned int col = ith % dim_head; | ||
size_t offset = batch_id * hidden_stride + blockIdx.y * dim_model; | ||
|
||
if (ith >= dim_model) | ||
return; | ||
|
||
unsigned int half_dim = dim_head / 2; | ||
if (col < half_dim) { | ||
float freq = target_pos * powf(10000, -float(col * 2) / dim_head); | ||
float cos_freq = cos(freq); | ||
float sin_freq = sin(freq); | ||
out[offset + ith] = | ||
in[offset + ith] * T(cos_freq) - in[offset + ith + half_dim] * T(sin_freq); | ||
} else { | ||
float freq = target_pos * powf(10000, -float((col - half_dim) * 2) / dim_head); | ||
float cos_freq = cos(freq); | ||
float sin_freq = sin(freq); | ||
out[offset + ith] = | ||
in[offset + ith] * T(cos_freq) + in[offset + ith - half_dim] * T(sin_freq); | ||
} | ||
} | ||
|
||
|
||
void launchRoPE( | ||
void const *input, | ||
int64_t const *posIDs, | ||
void *output, | ||
unsigned int batchSize, | ||
unsigned int seqLen, | ||
unsigned int nHeads, | ||
unsigned int headDim, | ||
float theta, | ||
bool useHalf) { | ||
|
||
unsigned int dimModel = nHeads * headDim; | ||
unsigned int hiddenStride = seqLen * dimModel; | ||
unsigned int threads = min(1024, round_up(dimModel, 32)); | ||
dim3 gridDim(batchSize, seqLen, round_up(dimModel, threads) / threads); | ||
dim3 blockDim(threads, 1, 1); | ||
if (useHalf) { | ||
RoPEKernel<<<gridDim, blockDim, 0, 0>>>( | ||
dimModel, | ||
headDim, | ||
hiddenStride, | ||
seqLen, | ||
reinterpret_cast<const int64_t *>(posIDs), | ||
reinterpret_cast<const half *>(input), | ||
reinterpret_cast<half *>(output)); | ||
|
||
} else { | ||
RoPEKernel<<<gridDim, blockDim, 0, 0>>>( | ||
dimModel, | ||
headDim, | ||
hiddenStride, | ||
seqLen, | ||
reinterpret_cast<const int64_t *>(posIDs), | ||
reinterpret_cast<const float *>(input), | ||
reinterpret_cast<float *>(output)); | ||
} | ||
} | ||
}// namespace refactor::kernel::cuda |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,17 @@ | ||
#ifndef KERNEL_ROPE_INFO_H | ||
#define KERNEL_ROPE_INFO_H | ||
|
||
#include "../tensor.h" | ||
|
||
namespace refactor::kernel { | ||
struct RoPEInfo { | ||
dim_t batchsize = 1; | ||
dim_t seq_len, n_heads, head_dim; | ||
float theta; | ||
|
||
RoPEInfo(Tensor const &input, float _theta); | ||
}; | ||
|
||
}// namespace refactor::kernel | ||
|
||
#endif |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,18 @@ | ||
#ifndef KERNEL_ROPE_COLLECTOR_H | ||
#define KERNEL_ROPE_COLLECTOR_H | ||
|
||
#include "../collector.h" | ||
|
||
namespace refactor::kernel { | ||
struct RoPECollector final : public InfoCollector { | ||
float theta; | ||
constexpr RoPECollector(decltype(_target) target, float _theta) noexcept | ||
: InfoCollector(target), theta(_theta){} | ||
|
||
std::vector<KernelBox> | ||
filter(TensorRefs inputs, TensorRefs outputs) const final; | ||
}; | ||
|
||
}// namespace refactor::kernel | ||
|
||
#endif |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,17 @@ | ||
#include "kernel/attributes/rope_info.h" | ||
|
||
namespace refactor::kernel { | ||
RoPEInfo::RoPEInfo(Tensor const &input, float _theta) : theta(_theta) { | ||
if (input.rank() == 4) { | ||
batchsize = input.shape[0]; | ||
seq_len = input.shape[1]; | ||
n_heads = input.shape[2]; | ||
head_dim = input.shape[3]; | ||
} else { | ||
batchsize = 1; | ||
seq_len = input.shape[0]; | ||
n_heads = input.shape[1]; | ||
head_dim = input.shape[2]; | ||
} | ||
} | ||
}// namespace refactor::kernel |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,27 @@ | ||
#include "kernel/collectors/rope.h" | ||
#include "kernel/attributes/rope_info.h" | ||
#include "../kernels/rope/cpu_kernel.hh" | ||
#include "../kernels/rope/cuda_kernel.hh" | ||
namespace refactor::kernel { | ||
std::vector<KernelBox> | ||
RoPECollector::filter(TensorRefs inputs, TensorRefs outputs) const { | ||
RoPEInfo info(inputs[0], theta); | ||
|
||
std::vector<KernelBox> ans; | ||
switch (_target) { | ||
case decltype(_target)::Cpu: | ||
if (auto ptr = RoPECpu::build(info, inputs[0]); ptr != nullptr) { | ||
ans.emplace_back(std::move(ptr)); | ||
} | ||
break; | ||
case decltype(_target)::Nvidia: | ||
if (auto ptr = RoPECuda::build(info, inputs[0]); ptr != nullptr) { | ||
ans.emplace_back(std::move(ptr)); | ||
} | ||
break; | ||
default: | ||
UNREACHABLEX(void, "Unknown target"); | ||
} | ||
return ans; | ||
} | ||
}// namespace refactor::kernel |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,67 @@ | ||
#include "cpu_kernel.hh" | ||
#include <cmath> | ||
|
||
namespace refactor::kernel { | ||
using K = RoPECpu; | ||
|
||
K::RoPECpu(decltype(info) info_) noexcept | ||
: Kernel(), info(info_) {} | ||
|
||
auto K::build(decltype(info) info, Tensor const &x) noexcept -> KernelBox { | ||
if (x.dataType != DataType::F32) { | ||
return nullptr; | ||
} | ||
return std::make_unique<K>(info); | ||
} | ||
auto K::typeId() noexcept -> size_t { | ||
static uint8_t ID = 1; | ||
return reinterpret_cast<size_t>(&ID); | ||
} | ||
|
||
auto K::kernelTypeId() const noexcept -> size_t { | ||
return typeId(); | ||
} | ||
auto K::description() const noexcept -> std::string_view { | ||
return "Performing rotary position embedding on cpu"; | ||
} | ||
|
||
|
||
auto K::lower(Resources &) const -> RoutineWorkspace { | ||
return [batchsize = this->info.batchsize, | ||
seq_len = this->info.seq_len, | ||
n_heads = this->info.n_heads, | ||
head_dim = this->info.head_dim, | ||
theta = this->info.theta]// | ||
(Resources &, void *, void const *const *inputs, void *const *outputs) { | ||
auto input = reinterpret_cast<float const *>(inputs[0]); | ||
auto pos_ids = reinterpret_cast<int64_t const *>(inputs[1]); | ||
auto output = reinterpret_cast<float *>(outputs[0]); | ||
auto half_dim = head_dim / 2; | ||
|
||
for (unsigned int batch_id = 0; batch_id < batchsize; batch_id++) { | ||
for (unsigned int pos = 0; pos < seq_len; pos++) { | ||
auto pos_id = pos_ids[batch_id * seq_len + pos]; | ||
for (unsigned int head = 0; head < n_heads; head++) { | ||
auto offset = batch_id * seq_len * n_heads * head_dim + pos * n_heads * head_dim + head * head_dim; | ||
for (unsigned int i = 0; i < head_dim; i++) { | ||
if (i < half_dim) { | ||
float freq = pos_id * powf(theta, -float(i * 2) / head_dim); | ||
float cos_freq = cos(freq); | ||
float sin_freq = sin(freq); | ||
output[offset + i] = | ||
input[offset + i] * float(cos_freq) - input[offset + i + half_dim] * float(sin_freq); | ||
} else { | ||
float freq = pos_id * powf(theta, -float((i - half_dim) * 2) / head_dim); | ||
float cos_freq = cos(freq); | ||
float sin_freq = sin(freq); | ||
output[offset + i] = | ||
input[offset + i] * float(cos_freq) + input[offset + i - half_dim] * float(sin_freq); | ||
} | ||
} | ||
} | ||
} | ||
} | ||
}; | ||
} | ||
|
||
}// namespace refactor::kernel |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,24 @@ | ||
#ifndef KERNEL_ROPE_CPU_KERNEL_HH | ||
#define KERNEL_ROPE_CPU_KERNEL_HH | ||
|
||
|
||
#include "kernel/kernel.h" | ||
#include "kernel/tensor.h" | ||
#include "kernel/attributes/rope_info.h" | ||
|
||
namespace refactor::kernel { | ||
struct RoPECpu final : public Kernel { | ||
RoPEInfo info; | ||
RoPECpu(decltype(info)) noexcept; | ||
|
||
static KernelBox build(decltype(info), Tensor const &x) noexcept; | ||
static size_t typeId() noexcept; | ||
|
||
size_t kernelTypeId() const noexcept final; | ||
std::string_view description() const noexcept final; | ||
RoutineWorkspace lower(Resources &) const final; | ||
|
||
}; | ||
} | ||
|
||
#endif |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,49 @@ | ||
#include "cuda_kernel.hh" | ||
|
||
#ifdef USE_CUDA | ||
#include "kernel/cuda/rope.cuh" | ||
#endif | ||
namespace refactor::kernel { | ||
using K = RoPECuda; | ||
|
||
K::RoPECuda(decltype(info) info_, DataType _dtype) noexcept | ||
: Kernel(), info(info_), dtype(_dtype) {} | ||
|
||
auto K::build(decltype(info) info, Tensor const &x) noexcept -> KernelBox { | ||
#ifndef USE_CUDA | ||
return nullptr; | ||
#endif | ||
return std::make_unique<K>(info, x.dataType); | ||
} | ||
|
||
auto K::typeId() noexcept -> size_t { | ||
static uint8_t ID = 1; | ||
return reinterpret_cast<size_t>(&ID); | ||
} | ||
|
||
auto K::kernelTypeId() const noexcept -> size_t { | ||
return typeId(); | ||
} | ||
auto K::description() const noexcept -> std::string_view { | ||
return "Performing rotary position embedding on Nvidia gpu"; | ||
} | ||
#ifdef USE_CUDA | ||
auto K::lower(Resources &) const -> RoutineWorkspace { | ||
|
||
return [info = this->info, useHalf = this->dtype == DataType::FP16]// | ||
(Resources &, void *workspace, void const *const *inputs, void *const *outputs) { | ||
cuda::launchRoPE( | ||
inputs[0], | ||
reinterpret_cast<const int64_t *>(inputs[1]), | ||
outputs[0], | ||
info.batchsize, | ||
info.seq_len, | ||
info.n_heads, | ||
info.head_dim, | ||
info.theta, | ||
useHalf); | ||
}; | ||
} | ||
#endif | ||
|
||
}// namespace refactor::kernel |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,26 @@ | ||
#ifndef KERNEL_ROPE_CUDA_KERNEL_HH | ||
#define KERNEL_ROPE_CUDA_KERNEL_HH | ||
|
||
|
||
#include "kernel/attributes/rope_info.h" | ||
#include "kernel/kernel.h" | ||
#include "kernel/tensor.h" | ||
|
||
namespace refactor::kernel { | ||
struct RoPECuda final : public Kernel { | ||
RoPEInfo info; | ||
DataType dtype; | ||
RoPECuda(decltype(info), DataType) noexcept; | ||
|
||
static KernelBox build(decltype(info), Tensor const &x) noexcept; | ||
static size_t typeId() noexcept; | ||
|
||
size_t kernelTypeId() const noexcept final; | ||
std::string_view description() const noexcept final; | ||
#ifdef USE_CUDA | ||
RoutineWorkspace lower(Resources &) const final; | ||
#endif | ||
}; | ||
}// namespace refactor::kernel | ||
|
||
#endif |
Oops, something went wrong.