MoE grouped gemm and fused topk_softmax (#8)

* Initial
* group gemm
* Fix install. Add topk_softmax kernels.
1 parent c448678 · commit fa58402
Showing 10 changed files with 1,013 additions and 26 deletions.
New file (fused SiLU-and-mul activation kernel):
@@ -0,0 +1,56 @@
#include <ATen/cuda/CUDAContext.h>
#include <torch/extension.h>
#include <c10/cuda/CUDAGuard.h>

#define VLLM_LDG(arg) *(arg)

#define VLLM_DISPATCH_CASE_FLOATING_TYPES(...)              \
  AT_DISPATCH_CASE(at::ScalarType::Float, __VA_ARGS__)      \
  AT_DISPATCH_CASE(at::ScalarType::Half, __VA_ARGS__)       \
  AT_DISPATCH_CASE(at::ScalarType::BFloat16, __VA_ARGS__)

#define VLLM_DISPATCH_FLOATING_TYPES(TYPE, NAME, ...)       \
  AT_DISPATCH_SWITCH(                                       \
      TYPE, NAME, VLLM_DISPATCH_CASE_FLOATING_TYPES(__VA_ARGS__))

// SiLU (a.k.a. swish): x * sigmoid(x), computed in fp32 for accuracy.
template <typename T>
__device__ __forceinline__ T silu(const T& x) {
  return (T)(((float)x) / (1.0f + expf((float)-x)));
}

// One thread block per token; threads stride over the hidden dim d.
// input packs two halves along the last dim: the first is gated through
// silu, the second is the multiplicative branch (SwiGLU-style activation).
template <typename scalar_t>
__global__ void silu_and_mul_kernel(
    scalar_t* __restrict__ out,          // [..., d]
    const scalar_t* __restrict__ input,  // [..., 2, d]
    const int d) {
  const int64_t token_idx = blockIdx.x;
  for (int64_t idx = threadIdx.x; idx < d; idx += blockDim.x) {
    const scalar_t x = VLLM_LDG(&input[token_idx * 2 * d + idx]);
    const scalar_t y = VLLM_LDG(&input[token_idx * 2 * d + d + idx]);
    out[token_idx * d + idx] = silu(x) * y;
  }
}

void silu_and_mul(
    torch::Tensor& out,    // [..., d]
    torch::Tensor& input)  // [..., 2 * d]
{
  int64_t num_tokens = input.numel() / input.size(-1);
  int d = input.size(-1) / 2;

  dim3 grid(num_tokens);
  dim3 block(std::min(d, 1024));
  const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  VLLM_DISPATCH_FLOATING_TYPES(
      input.scalar_type(), "silu_and_mul_kernel", [&] {
        silu_and_mul_kernel<scalar_t><<<grid, block, 0, stream>>>(
            out.data_ptr<scalar_t>(),
            input.data_ptr<scalar_t>(),
            d);
      });
}
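For context, this op is typically called on the fused gate/up projection output of an MLP. A minimal caller sketch follows; num_tokens and d are illustrative values, not taken from this diff:

// Hypothetical caller sketch: input packs [gate, up] along the last dim.
int64_t num_tokens = 4, d = 11008;
auto input = torch::randn({num_tokens, 2 * d},
                          torch::dtype(torch::kFloat16).device(torch::kCUDA));
auto out = torch::empty({num_tokens, d}, input.options());
silu_and_mul(out, input);  // out[t][i] = silu(input[t][i]) * input[t][d + i]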
New file (declaration for the activation op):
@@ -0,0 +1,3 @@
void silu_and_mul(
    torch::Tensor& out,
    torch::Tensor& input);
New file (MoE block-size alignment kernel):
@@ -0,0 +1,91 @@
#include <torch/extension.h>
#include <ATen/cuda/CUDAContext.h>

#include <ATen/ATen.h>
#include <THC/THCAtomics.cuh>

const static size_t NUM_MAX_EXPERTS = 64;

#define VLLM_DISPATCH_CASE_INTEGRAL_TYPES(...)          \
  AT_DISPATCH_CASE(at::ScalarType::Byte, __VA_ARGS__)   \
  AT_DISPATCH_CASE(at::ScalarType::Char, __VA_ARGS__)   \
  AT_DISPATCH_CASE(at::ScalarType::Short, __VA_ARGS__)  \
  AT_DISPATCH_CASE(at::ScalarType::Int, __VA_ARGS__)    \
  AT_DISPATCH_CASE(at::ScalarType::Long, __VA_ARGS__)

#define VLLM_DISPATCH_INTEGRAL_TYPES(TYPE, NAME, ...)   \
  AT_DISPATCH_SWITCH(                                   \
      TYPE, NAME, VLLM_DISPATCH_CASE_INTEGRAL_TYPES(__VA_ARGS__))

// Launched as a single block with one thread per expert. Each thread
// histograms the tokens routed to each expert over its slice of topk_ids,
// the histograms are prefix-summed, each expert's token range is padded up
// to a multiple of block_size, and finally every token index is scattered
// to its padded, expert-sorted position in sorted_token_ids.
template <typename scalar_t>
__global__ void moe_alig_block_size_kernel(scalar_t* __restrict__ topk_ids,
                                           int32_t* sorted_token_ids,
                                           int32_t* expert_ids,
                                           int32_t* total_tokens_post_pad,
                                           int32_t num_experts,
                                           int32_t block_size,
                                           size_t numel) {
  const size_t tokens_per_thread = (numel + blockDim.x - 1) / blockDim.x;
  const size_t start_idx = threadIdx.x * tokens_per_thread;
  __shared__ int32_t tokens_cnts[NUM_MAX_EXPERTS + 1][NUM_MAX_EXPERTS];
  __shared__ int32_t cumsum[NUM_MAX_EXPERTS + 1];

  // Per-thread histogram of expert assignments. Row 0 stays zero so that
  // after the prefix sum below, row t holds the counts of threads 0..t-1.
  for (int i = 0; i < num_experts; i++) {
    tokens_cnts[threadIdx.x + 1][i] = 0;
  }
  for (int i = start_idx; i < numel && i < start_idx + tokens_per_thread; ++i) {
    ++tokens_cnts[threadIdx.x + 1][topk_ids[i]];
  }

  __syncthreads();

  // Column-wise prefix sum: thread e accumulates the counts for expert e.
  tokens_cnts[0][threadIdx.x] = 0;
  for (int i = 1; i <= blockDim.x; ++i) {
    tokens_cnts[i][threadIdx.x] += tokens_cnts[i - 1][threadIdx.x];
  }

  __syncthreads();

  // Round each expert's total up to a multiple of block_size and compute
  // the expert offsets into the padded output.
  if (threadIdx.x == 0) {
    cumsum[0] = 0;
    for (int i = 1; i <= num_experts; ++i) {
      cumsum[i] = cumsum[i - 1] +
          (tokens_cnts[blockDim.x][i - 1] + block_size - 1) / block_size * block_size;
    }
    *total_tokens_post_pad = cumsum[num_experts];
  }

  __syncthreads();

  // Record which expert owns each block of the padded output.
  for (int i = cumsum[threadIdx.x]; i < cumsum[threadIdx.x + 1]; i += block_size) {
    expert_ids[i / block_size] = threadIdx.x;
  }

  // Scatter every token index to its padded, expert-sorted position.
  for (int i = start_idx; i < numel && i < start_idx + tokens_per_thread; ++i) {
    int32_t expert_id = topk_ids[i];
    int32_t rank_post_pad = tokens_cnts[threadIdx.x][expert_id] + cumsum[expert_id];
    sorted_token_ids[rank_post_pad] = i;
    ++tokens_cnts[threadIdx.x][expert_id];
  }
}

void moe_alig_block_size(
    torch::Tensor topk_ids,
    int num_experts,
    int block_size,
    torch::Tensor sorted_token_ids,
    torch::Tensor experts_ids,
    torch::Tensor num_tokens_post_pad) {
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  assert(num_experts <= NUM_MAX_EXPERTS);
  VLLM_DISPATCH_INTEGRAL_TYPES(
      topk_ids.scalar_type(), "moe_alig_block_size_kernel", [&] {
        // One block, one thread per expert: the kernel indexes shared
        // memory by threadIdx.x as an expert id.
        moe_alig_block_size_kernel<scalar_t><<<1, num_experts, 0, stream>>>(
            topk_ids.data_ptr<scalar_t>(),
            sorted_token_ids.data_ptr<int32_t>(),
            experts_ids.data_ptr<int32_t>(),
            num_tokens_post_pad.data_ptr<int32_t>(),
            num_experts,
            block_size,
            topk_ids.numel());
      });
}
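To make the alignment concrete, here is a hypothetical host-side sketch; tensor names and sizes are illustrative, not from this commit. With num_experts = 2, block_size = 4 and topk_ids = [0, 1, 0, 0, 1], expert 0 owns tokens {0, 2, 3} padded to 4 slots and expert 1 owns {1, 4} padded to 4 slots, so num_tokens_post_pad becomes 8 and sorted_token_ids starts [0, 2, 3, _, 1, 4, _, _], where _ marks padding slots the kernel leaves unwritten.

// Hypothetical usage sketch -- buffer sizing mirrors the worst case where
// every expert's range gains up to block_size - 1 padding slots.
auto opts = torch::dtype(torch::kInt32).device(torch::kCUDA);
auto topk_ids = torch::tensor({0, 1, 0, 0, 1}, opts);
int num_experts = 2, block_size = 4;
int64_t max_post_pad = topk_ids.numel() + num_experts * (block_size - 1);
auto sorted_token_ids = torch::empty({max_post_pad}, opts);
auto expert_ids = torch::empty({(max_post_pad + block_size - 1) / block_size}, opts);
auto num_tokens_post_pad = torch::empty({1}, opts);
moe_alig_block_size(topk_ids, num_experts, block_size,
                    sorted_token_ids, expert_ids, num_tokens_post_pad);
// num_tokens_post_pad now holds 8; output blocks 0 and 1 belong to
// experts 0 and 1 respectively, each block_size tokens wide.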
New file (declaration for the MoE alignment op):
@@ -0,0 +1,8 @@
void moe_alig_block_size(
    torch::Tensor topk_ids,
    int num_experts,
    int block_size,
    torch::Tensor sorted_token_ids,
    torch::Tensor experts_ids,
    torch::Tensor num_tokens_post_pad);
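The registration that exposes these two ops to Python is not among the files shown in this excerpt. A minimal torch-extension binding sketch follows; the docstrings are assumptions, and it presumes a translation unit that includes both declarations above:

// Hypothetical binding sketch -- the commit's actual registration file is
// not shown here. TORCH_EXTENSION_NAME is defined by PyTorch's build system.
#include <torch/extension.h>

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("silu_and_mul", &silu_and_mul,
        "Fused SiLU-gated multiply: out = silu(x) * y");
  m.def("moe_alig_block_size", &moe_alig_block_size,
        "Sort token indices by expert, padding each expert's range to a "
        "multiple of block_size");
}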