diff --git a/custom_ops/gpu_ops/cpp_extensions.cc b/custom_ops/gpu_ops/cpp_extensions.cc index 85f88cf123..de99799350 100644 --- a/custom_ops/gpu_ops/cpp_extensions.cc +++ b/custom_ops/gpu_ops/cpp_extensions.cc @@ -249,6 +249,7 @@ paddle::Tensor MoeExpertFFNFunc( const paddle::Tensor& permute_input, const paddle::Tensor& tokens_expert_prefix_sum, const paddle::Tensor& up_gate_proj_weight, const paddle::Tensor& down_proj_weight, + const paddle::optional& up_proj_in_scale, const paddle::optional& up_gate_proj_bias, const paddle::optional& up_gate_proj_scale, const paddle::optional& down_proj_scale, diff --git a/custom_ops/gpu_ops/moe/fast_hardamard_kernel.h b/custom_ops/gpu_ops/moe/fast_hardmard/fast_hardamard_kernel.h similarity index 100% rename from custom_ops/gpu_ops/moe/fast_hardamard_kernel.h rename to custom_ops/gpu_ops/moe/fast_hardmard/fast_hardamard_kernel.h diff --git a/custom_ops/gpu_ops/moe/fast_hardamard_kernel.cu b/custom_ops/gpu_ops/moe/fast_hardmard/fast_hardamard_kernel.hpp similarity index 96% rename from custom_ops/gpu_ops/moe/fast_hardamard_kernel.cu rename to custom_ops/gpu_ops/moe/fast_hardmard/fast_hardamard_kernel.hpp index 1323cb4839..02906fa9de 100644 --- a/custom_ops/gpu_ops/moe/fast_hardamard_kernel.cu +++ b/custom_ops/gpu_ops/moe/fast_hardmard/fast_hardamard_kernel.hpp @@ -974,79 +974,3 @@ void MoeFastHardamardWrapper(const T *x_data, } } } - -template void MoeFastHardamardWrapper( - const phi::dtype::float16 *x_data, - const int64_t *expert_idx_per_token, - const int64_t *recv_expert_count, - const phi::dtype::float16 *shift, - const phi::dtype::float16 *smooth, - const float* quant_scales, - const int quant_round_type, - const float quant_max_bound, - const float quant_min_bound, - const int64_t token_num, - const int64_t dim, - const int num_max_tokens_per_expert, - bool used_in_ep_low_latency, - const int hadamard_block_size, - phi::dtype::float16 *out, - cudaStream_t &stream -); - -template void MoeFastHardamardWrapper( - const phi::dtype::float16 *x_data, - const int64_t *expert_idx_per_token, - const int64_t *recv_expert_count, - const phi::dtype::float16 *shift, - const phi::dtype::float16 *smooth, - const float* quant_scales, - const int quant_round_type, - const float quant_max_bound, - const float quant_min_bound, - const int64_t token_num, - const int64_t dim, - const int num_max_tokens_per_expert, - bool used_in_ep_low_latency, - const int hadamard_block_size, - int8_t *out, - cudaStream_t &stream -); - -template void MoeFastHardamardWrapper( - const phi::dtype::bfloat16 *x_data, - const int64_t *expert_idx_per_token, - const int64_t *recv_expert_count, - const phi::dtype::bfloat16 *shift, - const phi::dtype::bfloat16 *smooth, - const float* quant_scales, - const int quant_round_type, - const float quant_max_bound, - const float quant_min_bound, - const int64_t token_num, - const int64_t dim, - const int num_max_tokens_per_expert, - bool used_in_ep_low_latency, - const int hadamard_block_size, - phi::dtype::bfloat16 *out, - cudaStream_t &stream -); - -template void MoeFastHardamardWrapper( - const phi::dtype::bfloat16 *x_data, - const int64_t *expert_idx_per_token, - const int64_t *recv_expert_count, - const phi::dtype::bfloat16 *shift, - const phi::dtype::bfloat16 *smooth, - const float* quant_scales, - const int quant_round_type, - const float quant_max_bound, - const float quant_min_bound, - const int64_t token_num, - const int64_t dim, - const int num_max_tokens_per_expert, - bool used_in_ep_low_latency, - const int hadamard_block_size, - int8_t *out, - cudaStream_t &stream -); diff --git a/custom_ops/gpu_ops/moe/fast_hardmard/fast_hardamard_kernel_bf16_bf16.cu b/custom_ops/gpu_ops/moe/fast_hardmard/fast_hardamard_kernel_bf16_bf16.cu new file mode 100644 index 0000000000..21800dc636 --- /dev/null +++ b/custom_ops/gpu_ops/moe/fast_hardmard/fast_hardamard_kernel_bf16_bf16.cu @@ -0,0 +1,34 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "fast_hardamard_kernel.hpp" + +template void MoeFastHardamardWrapper( + const phi::dtype::bfloat16 *x_data, + const int64_t *expert_idx_per_token, + const int64_t *recv_expert_count, + const phi::dtype::bfloat16 *shift, + const phi::dtype::bfloat16 *smooth, + const float* quant_scales, + const int quant_round_type, + const float quant_max_bound, + const float quant_min_bound, + const int64_t token_num, + const int64_t dim, + const int num_max_tokens_per_expert, + bool used_in_ep_low_latency, + const int hadamard_block_size, + phi::dtype::bfloat16 *out, + cudaStream_t &stream +); diff --git a/custom_ops/gpu_ops/moe/fast_hardmard/fast_hardamard_kernel_bf16_fp8.cu b/custom_ops/gpu_ops/moe/fast_hardmard/fast_hardamard_kernel_bf16_fp8.cu new file mode 100644 index 0000000000..bbc47d975b --- /dev/null +++ b/custom_ops/gpu_ops/moe/fast_hardmard/fast_hardamard_kernel_bf16_fp8.cu @@ -0,0 +1,34 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "fast_hardamard_kernel.hpp" + +template void MoeFastHardamardWrapper( + const phi::dtype::bfloat16 *x_data, + const int64_t *expert_idx_per_token, + const int64_t *recv_expert_count, + const phi::dtype::bfloat16 *shift, + const phi::dtype::bfloat16 *smooth, + const float* quant_scales, + const int quant_round_type, + const float quant_max_bound, + const float quant_min_bound, + const int64_t token_num, + const int64_t dim, + const int num_max_tokens_per_expert, + bool used_in_ep_low_latency, + const int hadamard_block_size, + phi::dtype::float8_e4m3fn *out, + cudaStream_t &stream +); diff --git a/custom_ops/gpu_ops/moe/fast_hardmard/fast_hardamard_kernel_bf16_int8.cu b/custom_ops/gpu_ops/moe/fast_hardmard/fast_hardamard_kernel_bf16_int8.cu new file mode 100644 index 0000000000..e3f9663fe7 --- /dev/null +++ b/custom_ops/gpu_ops/moe/fast_hardmard/fast_hardamard_kernel_bf16_int8.cu @@ -0,0 +1,34 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "fast_hardamard_kernel.hpp" + +template void MoeFastHardamardWrapper( + const phi::dtype::bfloat16 *x_data, + const int64_t *expert_idx_per_token, + const int64_t *recv_expert_count, + const phi::dtype::bfloat16 *shift, + const phi::dtype::bfloat16 *smooth, + const float* quant_scales, + const int quant_round_type, + const float quant_max_bound, + const float quant_min_bound, + const int64_t token_num, + const int64_t dim, + const int num_max_tokens_per_expert, + bool used_in_ep_low_latency, + const int hadamard_block_size, + int8_t *out, + cudaStream_t &stream +); diff --git a/custom_ops/gpu_ops/moe/fast_hardmard/fast_hardamard_kernel_fp16_fp16.cu b/custom_ops/gpu_ops/moe/fast_hardmard/fast_hardamard_kernel_fp16_fp16.cu new file mode 100644 index 0000000000..e61bf44378 --- /dev/null +++ b/custom_ops/gpu_ops/moe/fast_hardmard/fast_hardamard_kernel_fp16_fp16.cu @@ -0,0 +1,34 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "fast_hardamard_kernel.hpp" + +template void MoeFastHardamardWrapper( + const phi::dtype::float16 *x_data, + const int64_t *expert_idx_per_token, + const int64_t *recv_expert_count, + const phi::dtype::float16 *shift, + const phi::dtype::float16 *smooth, + const float* quant_scales, + const int quant_round_type, + const float quant_max_bound, + const float quant_min_bound, + const int64_t token_num, + const int64_t dim, + const int num_max_tokens_per_expert, + bool used_in_ep_low_latency, + const int hadamard_block_size, + phi::dtype::float16 *out, + cudaStream_t &stream +); diff --git a/custom_ops/gpu_ops/moe/fast_hardmard/fast_hardamard_kernel_fp16_int8.cu b/custom_ops/gpu_ops/moe/fast_hardmard/fast_hardamard_kernel_fp16_int8.cu new file mode 100644 index 0000000000..e4edb32b5f --- /dev/null +++ b/custom_ops/gpu_ops/moe/fast_hardmard/fast_hardamard_kernel_fp16_int8.cu @@ -0,0 +1,34 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "fast_hardamard_kernel.hpp" + +template void MoeFastHardamardWrapper( + const phi::dtype::float16 *x_data, + const int64_t *expert_idx_per_token, + const int64_t *recv_expert_count, + const phi::dtype::float16 *shift, + const phi::dtype::float16 *smooth, + const float* quant_scales, + const int quant_round_type, + const float quant_max_bound, + const float quant_min_bound, + const int64_t token_num, + const int64_t dim, + const int num_max_tokens_per_expert, + bool used_in_ep_low_latency, + const int hadamard_block_size, + int8_t *out, + cudaStream_t &stream +); diff --git a/custom_ops/gpu_ops/moe/fused_moe_helper.h b/custom_ops/gpu_ops/moe/fused_moe_helper.h index 703a7c11f0..f24f12ea70 100644 --- a/custom_ops/gpu_ops/moe/fused_moe_helper.h +++ b/custom_ops/gpu_ops/moe/fused_moe_helper.h @@ -250,7 +250,7 @@ template class MoeHelper { initialize_moe_routing_kernelLauncher( input_activations, permuted_data_, permuted_rows_, nullptr, nullptr, - expanded_source_row_to_expanded_dest_row, num_rows, num_rows, + expanded_source_row_to_expanded_dest_row, nullptr, num_rows, num_rows, hidden_size, k, stream); const int64_t expanded_active_expert_rows = k * num_rows; diff --git a/custom_ops/gpu_ops/moe/fused_moe_op.h b/custom_ops/gpu_ops/moe/fused_moe_op.h index eeaecb716f..b3687ea9a7 100644 --- a/custom_ops/gpu_ops/moe/fused_moe_op.h +++ b/custom_ops/gpu_ops/moe/fused_moe_op.h @@ -128,6 +128,17 @@ struct SumOp { __device__ __forceinline__ T operator()(T const& x, T const& y) { return x + y; } }; +template +struct MaxOp { +__device__ inline T operator()(T const & x, T const & y) { return x > y ? x : y; } +}; + +template <> +struct MaxOp { +__device__ inline float operator()(float const &x, float const &y) { return fmax(x, y); } +}; + + template __forceinline__ __device__ OutType QuantHelperFunc(const InType input, const float scale, @@ -139,101 +150,114 @@ __forceinline__ __device__ OutType QuantHelperFunc(const InType input, template __global__ void masked_quantize_moe_input_kernel(const T* permuted_inputs, -const int64_t* expert_idx_per_token, -const float* quant_scales, -const float quant_max_bound, -const float quant_min_bound, -const int64_t token_num, -const int64_t dim, -float* permuted_input_row_sum, -const int64_t* recv_expert_count, -const int num_max_tokens_per_expert, -OutT* out) { -using LoadT = AlignedVector; -using LoadOutT = AlignedVector; -LoadT input_vec; -LoadOutT output_vec; -float scale_factor = -7.0f / 512.0f; -using vec_t = typename BytesToType::Type; -for (int token_idx = blockIdx.x; token_idx < token_num; token_idx += gridDim.x) { - const auto token_idx_in_expert = token_idx % num_max_tokens_per_expert; - const auto expert_id = token_idx / num_max_tokens_per_expert; - if (token_idx_in_expert >= recv_expert_count[expert_id]) { - auto next_expert_start_idx = (expert_id + 1) * num_max_tokens_per_expert; - auto num_iters_to_next_expert = (next_expert_start_idx - token_idx - 1) / gridDim.x; - token_idx += num_iters_to_next_expert * gridDim.x; - continue; - } - int64_t expert_idx = expert_idx_per_token[token_idx]; - float quant_scale = quant_scales[expert_idx]; - float thread_row_sum = 0.0f; - for(int idx = threadIdx.x; idx < dim / VecSize; idx += blockDim.x) { - int64_t offset = token_idx * dim + idx * VecSize; - Load(&permuted_inputs[offset], &input_vec); - #pragma unroll - for (int i = 0; i < VecSize; i++) { - output_vec[i] = QuantHelperFunc(input_vec[i], quant_scale, quant_max_bound, quant_min_bound); - thread_row_sum += static_cast(output_vec[i]); + const int64_t* expert_idx_per_token, + const int64_t token_num, + const int64_t dim, + float* input_dequant_scale, + const int64_t* recv_expert_count, + const int num_max_tokens_per_expert, + OutT* out) { + using LoadT = AlignedVector; + using LoadOutT = AlignedVector; + LoadT input_vec; + LoadOutT output_vec; + using vec_t = typename BytesToType::Type; + extern __shared__ char smem_[]; + for (int token_idx = blockIdx.x; token_idx < token_num; token_idx += gridDim.x) { + const auto token_idx_in_expert = token_idx % num_max_tokens_per_expert; + const auto expert_id = token_idx / num_max_tokens_per_expert; + if (token_idx_in_expert >= recv_expert_count[expert_id]) { + auto next_expert_start_idx = (expert_id + 1) * num_max_tokens_per_expert; + auto num_iters_to_next_expert = (next_expert_start_idx - token_idx - 1) / gridDim.x; + token_idx += num_iters_to_next_expert * gridDim.x; + continue; + } + int64_t expert_idx = expert_idx_per_token[token_idx]; + float abs_max = 0.0f; + for(int idx = threadIdx.x; idx < dim / VecSize; idx += blockDim.x) { + int64_t offset = token_idx * dim + idx * VecSize; + #pragma unroll + for (int i = 0; i < VecSize; i++) { + float res = static_cast(input_vec[i]); + abs_max = fmax(abs_max, fabs(res)); + } + Store(input_vec, reinterpret_cast(smem_) + idx * VecSize); + } + abs_max = BlockAllReduce(abs_max); + input_dequant_scale[token_idx] = abs_max; + float quant_scale = 440.0f / abs_max; + for(int idx = threadIdx.x; idx < dim / VecSize; idx += blockDim.x) { + int64_t offset = token_idx * dim + idx * VecSize; + Load(reinterpret_cast(smem_) + idx * VecSize, &input_vec); + #pragma unroll + for (int i = 0; i < VecSize; i++) { + float res = static_cast(input_vec[i]); + output_vec[i] = static_cast(res * quant_scale); + } + *(reinterpret_cast(&out[offset])) = *(reinterpret_cast(&output_vec)); + } } - *(reinterpret_cast(&out[offset])) = *(reinterpret_cast(&output_vec)); - } - float block_row_sum = BlockAllReduce(thread_row_sum); - permuted_input_row_sum[token_idx] = block_row_sum * scale_factor; - } } template __global__ void quantize_moe_input_kernel(const T* permuted_inputs, -const int64_t* expert_idx_per_token, -const float* quant_scales, -const float quant_max_bound, -const float quant_min_bound, -const int64_t token_num, -const int64_t dim, -float* permuted_input_row_sum, -const int64_t* recv_expert_count, -const int num_max_tokens_per_expert, -OutT* out) { -using LoadT = AlignedVector; -using LoadOutT = AlignedVector; -LoadT input_vec; -LoadOutT output_vec; -using vec_t = typename BytesToType::Type; -float scale_factor = -7.0f / 512.0f; -for (int token_idx = blockIdx.x; token_idx < token_num; token_idx += gridDim.x) { - int64_t expert_idx = expert_idx_per_token[token_idx]; - float quant_scale = quant_scales[expert_idx]; - float thread_row_sum = 0.0f; - for(int idx = threadIdx.x; idx < dim / VecSize; idx += blockDim.x) { - int64_t offset = token_idx * dim + idx * VecSize; - Load(&permuted_inputs[offset], &input_vec); - #pragma unroll - for (int i = 0; i < VecSize; i++) { - output_vec[i] = QuantHelperFunc(input_vec[i], quant_scale, quant_max_bound, quant_min_bound); - thread_row_sum += static_cast(output_vec[i]); + const int64_t* expert_idx_per_token, + const int64_t token_num, + const int64_t dim, + float* input_dequant_scale, + const int64_t* recv_expert_count, + const int num_max_tokens_per_expert, + OutT* out) { + using LoadT = AlignedVector; + using LoadOutT = AlignedVector; + LoadT input_vec; + LoadOutT output_vec; + using vec_t = typename BytesToType::Type; + + extern __shared__ char smem_[]; + + for (int token_idx = blockIdx.x; token_idx < token_num; token_idx += gridDim.x) { + int64_t expert_idx = expert_idx_per_token[token_idx]; + float abs_max = 0.0f; + for(int idx = threadIdx.x; idx < dim / VecSize; idx += blockDim.x) { + int64_t offset = token_idx * dim + idx * VecSize; + Load(&permuted_inputs[offset], &input_vec); + #pragma unroll + for (int i = 0; i < VecSize; i++) { + float res = static_cast(input_vec[i]); + abs_max = fmax(abs_max, fabs(res)); + } + Store(input_vec, reinterpret_cast(smem_) + idx * VecSize); + } + abs_max = BlockAllReduce(abs_max); + input_dequant_scale[token_idx] = abs_max; + float quant_scale = 440.0f / abs_max; + + for(int idx = threadIdx.x; idx < dim / VecSize; idx += blockDim.x) { + int64_t offset = token_idx * dim + idx * VecSize; + Load(reinterpret_cast(smem_) + idx * VecSize, &input_vec); + #pragma unroll + for (int i = 0; i < VecSize; i++) { + float res = static_cast(input_vec[i]); + output_vec[i] = static_cast(res * quant_scale); + } + *(reinterpret_cast(&out[offset])) = *(reinterpret_cast(&output_vec)); + } } - *(reinterpret_cast(&out[offset])) = *(reinterpret_cast(&output_vec)); - } - float block_row_sum = BlockAllReduce(thread_row_sum); - permuted_input_row_sum[token_idx] = block_row_sum * scale_factor; - } } template void quantize_moe_input( - const T* permuted_inputs, - const int64_t* expert_idx_per_token, - const float* quant_scales, - const float quant_max_bound, - const float quant_min_bound, - const int64_t token_num, - const int64_t dim, - float* permuted_input_row_sum, - const int64_t* recv_expert_count, - const int num_max_tokens_per_expert, - bool used_in_ep_low_latency, - OutT* out, - cudaStream_t stream) { + const T* permuted_inputs, + const int64_t* expert_idx_per_token, + const int64_t token_num, + const int64_t dim, + float* input_quant_scale, + const int64_t* recv_expert_count, + const int num_max_tokens_per_expert, + bool used_in_ep_low_latency, + OutT* out, + cudaStream_t stream) { constexpr int VecSize = 16 / sizeof(T); constexpr int threads_per_block = 128; const int dev_id = 0; @@ -247,112 +271,21 @@ void quantize_moe_input( const int num_blocks_per_wave = sm_count * act_blocks_per_sm; dim3 grid; grid.x = min(static_cast(num_blocks_per_wave), token_num); - kernel<<>>( + const int smem_size = dim * sizeof(T); + if (smem_size >= 48 * 1024) { + cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size); + } + kernel<<>>( permuted_inputs, expert_idx_per_token, - quant_scales, - quant_max_bound, - quant_min_bound, token_num, dim, - permuted_input_row_sum, + input_quant_scale, recv_expert_count, num_max_tokens_per_expert, out); } -template -__global__ void masked_compute_row_sum_kernel( -const T* permuted_inputs, -const int64_t token_num, -const int64_t dim, -float* permuted_input_row_sum, -const int64_t* recv_expert_count, -const int num_max_tokens_per_expert) { -using LoadT = AlignedVector; -LoadT input_vec; -float scale_factor = -7.0f / 512.0f; -for (int token_idx = blockIdx.x; token_idx < token_num; token_idx += gridDim.x) { - const auto token_idx_in_expert = token_idx % num_max_tokens_per_expert; - const auto expert_id = token_idx / num_max_tokens_per_expert; - if (token_idx_in_expert >= recv_expert_count[expert_id]) { - auto next_expert_start_idx = (expert_id + 1) * num_max_tokens_per_expert; - auto num_iters_to_next_expert = (next_expert_start_idx - token_idx - 1) / gridDim.x; - token_idx += num_iters_to_next_expert * gridDim.x; - continue; - } - float thread_row_sum = 0.0f; - for(int idx = threadIdx.x; idx < dim / VecSize; idx += blockDim.x) { - int64_t offset = token_idx * dim + idx * VecSize; - Load(&permuted_inputs[offset], &input_vec); - #pragma unroll - for (int i = 0; i < VecSize; i++) { - thread_row_sum += static_cast(input_vec[i]); - } - } - float block_row_sum = BlockAllReduce(thread_row_sum); - permuted_input_row_sum[token_idx] = block_row_sum * scale_factor; - } -} - -template -__global__ void compute_row_sum_kernel( -const T* permuted_inputs, -const int64_t token_num, -const int64_t dim, -float* permuted_input_row_sum, -const int64_t* recv_expert_count, -const int num_max_tokens_per_expert) { -using LoadT = AlignedVector; -LoadT input_vec; -float scale_factor = -7.0f / 512.0f; -for (int token_idx = blockIdx.x; token_idx < token_num; token_idx += gridDim.x) { - float thread_row_sum = 0.0f; - for(int idx = threadIdx.x; idx < dim / VecSize; idx += blockDim.x) { - int64_t offset = token_idx * dim + idx * VecSize; - Load(&permuted_inputs[offset], &input_vec); - #pragma unroll - for (int i = 0; i < VecSize; i++) { - thread_row_sum += static_cast(input_vec[i]); - } - } - float block_row_sum = BlockAllReduce(thread_row_sum); - permuted_input_row_sum[token_idx] = block_row_sum * scale_factor; - } -} - -template -void compute_row_sum( - const T* permuted_inputs, - const int64_t token_num, - const int64_t dim, - float* permuted_input_row_sum, - const int64_t* recv_expert_count, - const int num_max_tokens_per_expert, - bool used_in_ep_low_latency, - cudaStream_t stream) { - constexpr int VecSize = 16 / sizeof(T); - constexpr int threads_per_block = 128; - const int dev_id = 0; - int sm_count; - int act_blocks_per_sm; - cudaDeviceGetAttribute(&sm_count, cudaDevAttrMultiProcessorCount, dev_id); - assert(dim % VecSize == 0); - auto kernel = used_in_ep_low_latency ? masked_compute_row_sum_kernel : compute_row_sum_kernel; - cudaOccupancyMaxActiveBlocksPerMultiprocessor( - &act_blocks_per_sm, kernel, threads_per_block, 0); - const int num_blocks_per_wave = sm_count * act_blocks_per_sm; - dim3 grid; - grid.x = min(static_cast(num_blocks_per_wave), token_num); - kernel<<>>( - permuted_inputs, - token_num, - dim, - permuted_input_row_sum, - recv_expert_count, - num_max_tokens_per_expert); - } - // ====================== Softmax things =============================== // We have our own implementation of softmax here so we can support transposing // the output in the softmax kernel when we extend this module to support @@ -1237,7 +1170,7 @@ void topk_gating_softmax_kernelLauncher(const T* input, // to row 0 in the original matrix. Thus, to know where to read in the source // matrix, we simply take the modulus of the expanded index. -template +template __global__ void initialize_moe_routing_kernel( const T* unpermuted_input, OutT* permuted_output, @@ -1245,6 +1178,7 @@ __global__ void initialize_moe_routing_kernel( const int *expert_idx_per_token, const float *w4a8_in_scale, int* expanded_source_row_to_expanded_dest_row, + float *dequant_scale, const int64_t num_rows, const int64_t active_rows, const int64_t cols, @@ -1266,15 +1200,49 @@ __global__ void initialize_moe_routing_kernel( expanded_dest_row; } + extern __shared__ char smem_[]; + + T * data_smem = reinterpret_cast(smem_); + if (expanded_dest_row < active_rows) { const int expert_idx = expert_idx_per_token[expanded_dest_row]; - const float scale = w4a8_in_scale ? w4a8_in_scale[expert_idx] : -1; + float scale; const int source_row = expanded_source_row % num_rows; const T* source_row_ptr = unpermuted_input + source_row * cols; OutT *dest_row_ptr = permuted_output + expanded_dest_row * cols; + if constexpr(std::is_same::value) { + if (dequant_scale != nullptr) { + float abs_max = 0.f; + for (int tid = threadIdx.x * VecSize; tid < cols; + tid += blockDim.x * VecSize) { + Load(&source_row_ptr[tid], &src_vec); + Store(src_vec, &data_smem[tid]); + for (int j = 0; j < VecSize; j++) { + abs_max = fmaxf(abs_max, fabsf(static_cast(src_vec[j]))); + } + } + abs_max = BlockAllReduce(abs_max); + scale = 440.0f / abs_max; + dequant_scale[expanded_dest_row] = abs_max; + for (int tid = threadIdx.x * VecSize; tid < cols; + tid += blockDim.x * VecSize) { + Load(&data_smem[tid], &src_vec); + using StoreT = AlignedVector; + StoreT dest_vec; + for (int j = 0; j < VecSize; j++) { + float quant_value = scale * static_cast(src_vec[j]); + dest_vec[j] = static_cast(quant_value); + } + Store(dest_vec, &dest_row_ptr[tid]); + } + return; + } else { + scale = w4a8_in_scale ? w4a8_in_scale[expert_idx] : -1; + } + } for (int tid = threadIdx.x * VecSize; tid < cols; tid += blockDim.x * VecSize) { // dest_row_ptr[tid] = source_row_ptr[tid]; @@ -1320,41 +1288,35 @@ void initialize_moe_routing_kernelLauncher( const int *expert_idx_per_token, const float *w4a8_in_scale, int* expanded_source_row_to_expanded_dest_row, + float * dequant_scale, const int64_t num_rows, const int64_t active_rows, const int64_t cols, const int64_t k, cudaStream_t stream) { - const int threads = std::min(cols, int64_t(1024)); + constexpr int threads = 256; constexpr int max_pack_size = 16 / sizeof(T); const auto config_initialize = Get1DBlocksAnd2DGridsMoe(num_rows * k); - if (cols % max_pack_size == 0) { - initialize_moe_routing_kernel - <<>>( - unpermuted_input, - permuted_output, - expanded_dest_row_to_expanded_source_row, - expert_idx_per_token, - w4a8_in_scale, - expanded_source_row_to_expanded_dest_row, - num_rows, - k * active_rows, - cols, - num_rows * k); - } else { - initialize_moe_routing_kernel - <<>>( + const int smem_size = cols * sizeof(float); + auto kernel = &initialize_moe_routing_kernel; + if (cols % max_pack_size != 0) { + kernel = &initialize_moe_routing_kernel; + } + if (smem_size >= 48 * 1024) { + cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size); + } + kernel<<>>( unpermuted_input, permuted_output, expanded_dest_row_to_expanded_source_row, expert_idx_per_token, w4a8_in_scale, expanded_source_row_to_expanded_dest_row, + dequant_scale, num_rows, k * active_rows, cols, num_rows * k); - } } // ============================== Infer GEMM sizes diff --git a/custom_ops/gpu_ops/moe/moe_dispatch.cu b/custom_ops/gpu_ops/moe/moe_dispatch.cu index bc18ece456..495b522250 100644 --- a/custom_ops/gpu_ops/moe/moe_dispatch.cu +++ b/custom_ops/gpu_ops/moe/moe_dispatch.cu @@ -33,7 +33,7 @@ void MoeDispatchKernel( const int hidden_size, const int expert_num, paddle::Tensor *permute_input, paddle::Tensor *tokens_expert_prefix_sum, paddle::Tensor *permute_indices_per_token, paddle::Tensor *topk_weight, - paddle::Tensor *topk_idx, paddle::Tensor *expert_idx_per_token) { + paddle::Tensor *topk_idx, paddle::Tensor *expert_idx_per_token, paddle::Tensor *dequant_scale) { using namespace phi; if (num_rows == 0){ @@ -120,22 +120,34 @@ void MoeDispatchKernel( initialize_moe_routing_kernelLauncher( input.data(), permute_input->data(), permuted_rows_, expert_idx_per_token->data(), w4a8_in_scale->data(), - permute_indices_per_token->data(), num_rows, num_rows, + permute_indices_per_token->data(), nullptr, + num_rows, num_rows, hidden_size, moe_topk, stream); } else if (permute_input->dtype() == paddle::DataType::FLOAT8_E4M3FN) { initialize_moe_routing_kernelLauncher( input.data(), permute_input->data(), permuted_rows_, expert_idx_per_token->data(), w4a8_in_scale->data(), - permute_indices_per_token->data(), num_rows, num_rows, + permute_indices_per_token->data(), nullptr, + num_rows, num_rows, hidden_size, moe_topk, stream); } } else { - initialize_moe_routing_kernelLauncher( - input.data(), permute_input->data(), permuted_rows_, - expert_idx_per_token->data(), nullptr, - permute_indices_per_token->data(), num_rows, num_rows, + if (permute_input->dtype() == paddle::DataType::FLOAT8_E4M3FN) { + initialize_moe_routing_kernelLauncher( + input.data(), permute_input->data(), + permuted_rows_, expert_idx_per_token->data(), + nullptr, + permute_indices_per_token->data(), dequant_scale->data(), + num_rows, num_rows, hidden_size, moe_topk, stream); + } else { + initialize_moe_routing_kernelLauncher( + input.data(), permute_input->data(), permuted_rows_, + expert_idx_per_token->data(), nullptr, + permute_indices_per_token->data(), nullptr, num_rows, num_rows, + hidden_size, moe_topk, stream); + } } compute_total_rows_before_expert( @@ -170,10 +182,20 @@ std::vector MoeExpertDispatch( } else if (moe_quant_type == "w4afp8") { permute_input_dtype = paddle::DataType::FLOAT8_E4M3FN; } + } else { + if (moe_quant_type == "w4afp8") { + permute_input_dtype = paddle::DataType::FLOAT8_E4M3FN; + } } auto permute_input = GetEmptyTensor({moe_topk * num_rows, hidden_size}, permute_input_dtype, place); + int dequant_scale_size = 1; + if (moe_quant_type == "w4afp8" && !w4a8_in_scale) { + dequant_scale_size = moe_topk * num_rows; + } + + auto dequant_scale = GetEmptyTensor({dequant_scale_size}, paddle::DataType::FLOAT32, place); // correspond to the weighted coefficients of the results from each expert. auto topk_weight = GetEmptyTensor({num_rows, moe_topk}, paddle::DataType::FLOAT32, place); @@ -194,7 +216,8 @@ std::vector MoeExpertDispatch( permute_indices_per_token, topk_weight, topk_idx, - expert_idx_per_token}; + expert_idx_per_token, + dequant_scale}; } switch (input_type) { @@ -203,14 +226,14 @@ std::vector MoeExpertDispatch( input, gating_output, gating_correction_bias, w4a8_in_scale, moe_topk, group_moe, topk_only_mode, num_rows, hidden_size, expert_num, &permute_input, &tokens_expert_prefix_sum, &permute_indices_per_token, - &topk_weight, &topk_idx, &expert_idx_per_token); + &topk_weight, &topk_idx, &expert_idx_per_token, &dequant_scale); break; case paddle::DataType::FLOAT16: MoeDispatchKernel( input, gating_output, gating_correction_bias, w4a8_in_scale, moe_topk, group_moe, topk_only_mode, num_rows, hidden_size, expert_num, &permute_input, &tokens_expert_prefix_sum, &permute_indices_per_token, - &topk_weight, &topk_idx, &expert_idx_per_token); + &topk_weight, &topk_idx, &expert_idx_per_token, &dequant_scale); break; default: PD_THROW("Unsupported data type for MoeDispatchKernel"); @@ -220,7 +243,8 @@ std::vector MoeExpertDispatch( permute_indices_per_token, topk_weight, topk_idx, - expert_idx_per_token}; + expert_idx_per_token, + dequant_scale}; } std::vector> MoeExpertDispatchInferShape( @@ -311,7 +335,7 @@ PD_BUILD_STATIC_OP(moe_expert_dispatch) paddle::Optional("w4a8_in_scale")}) .Outputs({"permute_input", "tokens_expert_prefix_sum", "permute_indices_per_token", "topk_weight", "topk_idx", - "expert_idx_per_token"}) + "expert_idx_per_token", "dequant_scale"}) .Attrs({"moe_topk:int", "group_moe:bool", "moe_quant_type:std::string", "topk_only_mode:bool"}) .SetKernelFn(PD_KERNEL(MoeExpertDispatch)) .SetInferShapeFn(PD_INFER_SHAPE(MoeExpertDispatchInferShape)) diff --git a/custom_ops/gpu_ops/moe/moe_expert_ffn_wint2.cu b/custom_ops/gpu_ops/moe/moe_expert_ffn_wint2.cu index f3e51bfcfa..47bcba3087 100644 --- a/custom_ops/gpu_ops/moe/moe_expert_ffn_wint2.cu +++ b/custom_ops/gpu_ops/moe/moe_expert_ffn_wint2.cu @@ -17,7 +17,7 @@ #include "cutlass/numeric_conversion.h" #include "group_swiglu_with_masked.h" #include "helper.h" -#include "moe/fast_hardamard_kernel.h" +#include "moe/fast_hardmard/fast_hardamard_kernel.h" #include "moe/fused_moe_helper.h" template diff --git a/custom_ops/gpu_ops/moe/moe_ffn.cu b/custom_ops/gpu_ops/moe/moe_ffn.cu index c135903778..fe48861bf7 100644 --- a/custom_ops/gpu_ops/moe/moe_ffn.cu +++ b/custom_ops/gpu_ops/moe/moe_ffn.cu @@ -18,7 +18,7 @@ #include "cutlass_kernels/w4a8_moe/w4a8_moe_gemm_kernel.h" #include "group_swiglu_with_masked.h" #include "helper.h" -#include "moe/fast_hardamard_kernel.h" +#include "moe/fast_hardmard/fast_hardamard_kernel.h" #include "moe/fused_moe_helper.h" #include "w4afp8_gemm/w4afp8_gemm.h" @@ -27,6 +27,7 @@ void MoeFFNKernel(const paddle::Tensor& permute_input, const paddle::Tensor& tokens_expert_prefix_sum, const paddle::Tensor& up_gate_proj_weight, const paddle::Tensor& down_proj_weight, + const paddle::optional& up_proj_in_scale, const paddle::optional& up_gate_proj_bias, const paddle::optional& up_gate_proj_scale, const paddle::optional& down_proj_scale, @@ -178,37 +179,22 @@ void MoeFFNKernel(const paddle::Tensor& permute_input, typedef PDTraits traits_fp8; typedef typename traits_fp8::DataType DataType_fp8; typedef typename traits_fp8::data_t data_t_fp8; - - Allocator::AllocationPtr ffn1_input_row_sum; - ffn1_input_row_sum = allocator->Allocate( - sizeof(float) * expanded_active_expert_rows); - - compute_row_sum( - permute_input.data(), - expanded_active_expert_rows, - hidden_size, - reinterpret_cast(ffn1_input_row_sum->ptr()), - const_cast(tokens_expert_prefix_sum.data()), - num_max_tokens_per_expert, - used_in_ep_low_latency, - stream); - - - float* row_scale = nullptr; + paddle::Tensor weight_scale_tensor = *const_cast(up_gate_proj_scale.get_ptr()); + const int weight_scale_group_size = weight_scale_tensor.dims().size() == 2 ? hidden_size : weight_scale_tensor.dims()[3]; + const float* input_dequant_scale = up_proj_in_scale ? up_proj_in_scale.get().data() : nullptr; DisPatchW4AFp8GemmWrapper( reinterpret_cast(permute_input.data()), reinterpret_cast(up_gate_proj_weight.data()), const_cast(tokens_expert_prefix_sum.data()), - reinterpret_cast(ffn1_input_row_sum->ptr()), - row_scale, - const_cast(up_gate_proj_scale.get_ptr()) - ->data(), + input_dequant_scale, + weight_scale_tensor.data(), reinterpret_cast(fc1_out), used_in_ep_low_latency ? num_max_tokens_per_expert : 0, used_in_ep_low_latency ? num_max_tokens_per_expert : permute_input.dims()[0], num_experts, inter_size, hidden_size, + weight_scale_group_size, stream); } else { typename cutlass::WintQuantTraits::Arguments quant_args; @@ -319,63 +305,83 @@ void MoeFFNKernel(const paddle::Tensor& permute_input, } else if (quant_method == "w4afp8") { data_t *ffn2_shift = nullptr; data_t *ffn2_smooth = nullptr; - float* row_scale = nullptr; + float* input_dequant_scale = nullptr; Allocator::AllocationPtr fp8_act_out; fp8_act_out = allocator->Allocate( SizeOf(paddle::DataType::INT8) * act_out_tensor.numel()); - Allocator::AllocationPtr ffn2_input_row_sum; - ffn2_input_row_sum = allocator->Allocate( - sizeof(float) * expanded_active_expert_rows); - // note(yuanxiaolan): optimize this - MoeFastHardamardWrapper( - act_out_tensor.data(), - expert_idx_per_token ? expert_idx_per_token.get().data() : nullptr, - const_cast(tokens_expert_prefix_sum.data()), - ffn2_shift, // ffn2_shift->data(), - ffn2_smooth, // ffn2_smooth->data(), - nullptr, - 1, - 448.0f, - -448.0f, - expanded_active_expert_rows, - inter_size / 2, - num_max_tokens_per_expert, - used_in_ep_low_latency, - hadamard_block_size, - act_out_tensor.data(), - stream - ); + if (down_proj_in_scale) { + MoeFastHardamardWrapper( + act_out_tensor.data(), + expert_idx_per_token ? expert_idx_per_token.get().data() : nullptr, + const_cast(tokens_expert_prefix_sum.data()), + ffn2_shift, + ffn2_smooth, + down_proj_in_scale ? const_cast(down_proj_in_scale.get_ptr())->data() : nullptr, + 1, + 448.0f, + -448.0f, + expanded_active_expert_rows, + inter_size / 2, + num_max_tokens_per_expert, + used_in_ep_low_latency, + hadamard_block_size, + reinterpret_cast(fp8_act_out->ptr()), + stream + ); + } else { + Allocator::AllocationPtr ffn2_input_dequant_scale; + ffn2_input_dequant_scale = allocator->Allocate( + sizeof(float) * expanded_active_expert_rows); + input_dequant_scale = reinterpret_cast(ffn2_input_dequant_scale->ptr()); + MoeFastHardamardWrapper( + act_out_tensor.data(), + expert_idx_per_token ? expert_idx_per_token.get().data() : nullptr, + const_cast(tokens_expert_prefix_sum.data()), + ffn2_shift, // ffn2_shift->data(), + ffn2_smooth, // ffn2_smooth->data(), + nullptr, + 1, + 448.0f, + -448.0f, + expanded_active_expert_rows, + inter_size / 2, + num_max_tokens_per_expert, + used_in_ep_low_latency, + hadamard_block_size, + act_out_tensor.data(), + stream + ); - quantize_moe_input(act_out_tensor.data(), - expert_idx_per_token ? expert_idx_per_token.get().data() : nullptr, - down_proj_in_scale ? const_cast(down_proj_in_scale.get_ptr())->data() : nullptr, - 448.0f, - -448.0f, - expanded_active_expert_rows, - inter_size / 2, - reinterpret_cast(ffn2_input_row_sum->ptr()), - const_cast(tokens_expert_prefix_sum.data()), - num_max_tokens_per_expert, - used_in_ep_low_latency, - reinterpret_cast(fp8_act_out->ptr()), - stream + quantize_moe_input(act_out_tensor.data(), + expert_idx_per_token ? expert_idx_per_token.get().data() : nullptr, + expanded_active_expert_rows, + inter_size / 2, + input_dequant_scale, + const_cast(tokens_expert_prefix_sum.data()), + num_max_tokens_per_expert, + used_in_ep_low_latency, + reinterpret_cast(fp8_act_out->ptr()), + stream ); + } + + paddle::Tensor weight_scale_tensor = *const_cast(down_proj_scale.get_ptr()); + const int weight_scale_group_size = weight_scale_tensor.dims().size() == 2 ? inter_size / 2 : weight_scale_tensor.dims()[3]; DisPatchW4AFp8GemmWrapper( reinterpret_cast(fp8_act_out->ptr()), reinterpret_cast(down_proj_weight.data()), const_cast(tokens_expert_prefix_sum.data()), - reinterpret_cast(ffn2_input_row_sum->ptr()), - row_scale, - const_cast(down_proj_scale.get_ptr()) - ->data(), + input_dequant_scale, + weight_scale_tensor.data(), reinterpret_cast(ffn_out_data), used_in_ep_low_latency ? num_max_tokens_per_expert : 0, used_in_ep_low_latency ? num_max_tokens_per_expert : act_out_tensor.dims()[0], num_experts, hidden_size, inter_size / 2, + weight_scale_group_size, stream); } else { typename cutlass::WintQuantTraits::Arguments quant_args; @@ -400,6 +406,7 @@ paddle::Tensor MoeExpertFFNFunc( const paddle::Tensor& tokens_expert_prefix_sum, const paddle::Tensor& up_gate_proj_weight, const paddle::Tensor& down_proj_weight, + const paddle::optional& up_proj_in_scale, const paddle::optional& up_gate_proj_bias, const paddle::optional& up_gate_proj_scale, const paddle::optional& down_proj_scale, @@ -421,6 +428,7 @@ const auto t_type = (quant_method == "w4a8") ? up_gate_proj_scale.get().dtype() tokens_expert_prefix_sum, up_gate_proj_weight, down_proj_weight, + up_proj_in_scale, up_gate_proj_bias, up_gate_proj_scale, down_proj_scale, @@ -437,6 +445,7 @@ const auto t_type = (quant_method == "w4a8") ? up_gate_proj_scale.get().dtype() tokens_expert_prefix_sum, up_gate_proj_weight, down_proj_weight, + up_proj_in_scale, up_gate_proj_bias, up_gate_proj_scale, down_proj_scale, @@ -459,6 +468,7 @@ std::vector MoeExpertFFN( const paddle::Tensor& tokens_expert_prefix_sum, const paddle::Tensor& up_gate_proj_weight, const paddle::Tensor& down_proj_weight, + const paddle::optional& up_proj_in_scale, const paddle::optional& up_gate_proj_bias, const paddle::optional& up_gate_proj_scale, const paddle::optional& down_proj_scale, @@ -471,6 +481,7 @@ std::vector MoeExpertFFN( tokens_expert_prefix_sum, up_gate_proj_weight, down_proj_weight, + up_proj_in_scale, up_gate_proj_bias, up_gate_proj_scale, down_proj_scale, @@ -487,6 +498,7 @@ std::vector> MoeExpertFFNInferShape( const std::vector& tokens_expert_prefix_sum_shape, const std::vector& up_gate_proj_weight_shape, const std::vector& down_proj_weight_shape, + const paddle::optional>& up_proj_in_scale_shape, const paddle::optional>& up_gate_proj_bias_shape, const paddle::optional>& up_gate_proj_scale_shape, const paddle::optional>& down_proj_scale_shape, @@ -504,6 +516,7 @@ std::vector MoeExpertFFNInferDtype( const paddle::DataType &tokens_expert_prefix_sum_dtype, const paddle::DataType &up_gate_proj_weight_dtype, const paddle::DataType &down_proj_weight_dtype, + const paddle::optional &up_proj_in_scale_dtype, const paddle::optional &up_gate_proj_bias_dtype, const paddle::optional &up_gate_proj_scale_dtype, const paddle::optional &down_proj_scale_dtype, @@ -577,6 +590,7 @@ PD_BUILD_STATIC_OP(moe_expert_ffn) "tokens_expert_prefix_sum", "up_gate_proj_weight", "down_proj_weight", + paddle::Optional("up_proj_in_scale"), paddle::Optional("up_gate_proj_bias"), paddle::Optional("up_gate_proj_scale"), paddle::Optional("down_proj_scale"), diff --git a/custom_ops/gpu_ops/w4afp8_gemm/kernel_traits.h b/custom_ops/gpu_ops/w4afp8_gemm/kernel_traits.h index 71e37a8ba3..48025e8962 100644 --- a/custom_ops/gpu_ops/w4afp8_gemm/kernel_traits.h +++ b/custom_ops/gpu_ops/w4afp8_gemm/kernel_traits.h @@ -24,12 +24,13 @@ using namespace cute; template + class SmemLayoutB, class SmemLayoutC, class SmemLayoutScale> struct SharedStorage { union { struct { cute::array_aligned> smem_a; cute::array_aligned> smem_b; + cute::array_aligned> smem_scale; }; cute::array_aligned> smem_c; }; @@ -41,16 +42,16 @@ struct SharedStorage { template struct Kernel_traits { using Element = elem_type; - using ElementAccum = float; using ElementOutput = OutputType; + using ElementAccum = typename std::conditional_t; static_assert(cutlass::sizeof_bits_v == 8); static constexpr int kNWarps = kNWarps_; @@ -66,10 +67,10 @@ struct Kernel_traits { static constexpr int kTiles = kTiles_; static constexpr int TokenPackSize = TokenPackSize_; static constexpr int M = M_; - static constexpr int TAIL_N = TAIL_N_; + static constexpr int K = K_; + static constexpr int WeightScaleGroup = WeightScaleGroup_; using TileShape_MNK = Shape, Int, Int>; - using TileShape_MNK_TAIL = Shape, Int, Int>; static constexpr int kClusterM = kClusterM_; using ClusterShape_MNK = Shape, _1, _1>; @@ -83,10 +84,6 @@ struct Kernel_traits { cute::GMMA::rs_op_selector(), AtomLayoutMNK{})); - using TiledMma_TAIL = decltype(cute::make_tiled_mma( - cute::GMMA::rs_op_selector(), - AtomLayoutMNK{})); - using SmemLayoutAtomA = decltype( cutlass::gemm::collective::detail::rs_smem_selector< GMMA::Major::K, Element, Int, Int>()); @@ -103,20 +100,6 @@ struct Kernel_traits { using SmemLayoutB = decltype( tile_to_shape(SmemLayoutAtomB{}, make_shape(shape<1>(TileShape_MNK{}), shape<2>(TileShape_MNK{}), Int{}))); - - using SmemLayoutAtomB_TAIL = decltype( - cutlass::gemm::collective::detail::rs_smem_selector< - GMMA::Major::K, Element, decltype(cute::get<1>(TileShape_MNK_TAIL{})), - decltype(cute::get<2>(TileShape_MNK_TAIL{}))>()); - - using SmemLayoutB_TAIL = decltype( - tile_to_shape(SmemLayoutAtomB_TAIL{}, - make_shape( - shape<1>(TileShape_MNK_TAIL{}), - shape<2>(TileShape_MNK_TAIL{}), - Int{}) - )); - using SmemLayoutAtomC = decltype( cutlass::gemm::collective::detail::rs_smem_selector< GMMA::Major::K, ElementOutput, @@ -128,8 +111,10 @@ struct Kernel_traits { using SmemCopyAtomAB = Copy_Atom; using SmemCopyAtomC = Copy_Atom; + using SmemLayoutScale = Layout, Int>>; + using SharedStorage = SharedStorage< - kStages, Element, ElementOutput, SmemLayoutA, SmemLayoutB, SmemLayoutC>; + kStages, Element, ElementOutput, SmemLayoutA, SmemLayoutB, SmemLayoutC, SmemLayoutScale>; using MainloopPipeline = typename cutlass::PipelineTmaAsync; using PipelineState = typename cutlass::PipelineState; diff --git a/custom_ops/gpu_ops/w4afp8_gemm/mainloop_fwd.h b/custom_ops/gpu_ops/w4afp8_gemm/mainloop_fwd.h index cb46397d51..b650371415 100644 --- a/custom_ops/gpu_ops/w4afp8_gemm/mainloop_fwd.h +++ b/custom_ops/gpu_ops/w4afp8_gemm/mainloop_fwd.h @@ -35,19 +35,19 @@ struct CollectiveMainloopFwd { using Element = typename Ktraits::Element; using ElementOutput = typename Ktraits::ElementOutput; using TileShape_MNK = typename Ktraits::TileShape_MNK; - using TileShape_MNK_TAIL = typename Ktraits::TileShape_MNK_TAIL; using ClusterShape = typename Ktraits::ClusterShape_MNK; using ElementAccum = typename Ktraits::ElementAccum; static constexpr int kStages = Ktraits::kStages; static constexpr int kBlockM = Ktraits::kBlockM; static constexpr int kBlockN = Ktraits::kBlockN; - static constexpr int TAIL_N = Ktraits::TAIL_N; static constexpr int kBlockK = Ktraits::kBlockK; static constexpr int NumCopyThreads = cutlass::NumThreadsPerWarpGroup; static constexpr int kTiles = Ktraits::kTiles; static constexpr int M = Ktraits::M; + static constexpr int K = Ktraits::K; static constexpr int TokenPackSize = Ktraits::TokenPackSize; + static constexpr int WeightScaleGroup = Ktraits::WeightScaleGroup; using GmemTiledCopy = cute::SM90_TMA_LOAD; @@ -55,12 +55,16 @@ struct CollectiveMainloopFwd { using SmemLayoutA = typename Ktraits::SmemLayoutA; using SmemLayoutB = typename Ktraits::SmemLayoutB; using SmemLayoutC = typename Ktraits::SmemLayoutC; - using SmemLayoutB_TAIL = typename Ktraits::SmemLayoutB_TAIL; + using SmemLayoutScale = typename Ktraits::SmemLayoutScale; using ShapeT = cute::Shape; using StrideT = cute::Shape; using LayoutT = cute::Layout; + using ShapeTScale = cute::Shape; + using StrideTScale = cute::Shape<_1, int64_t, int64_t>; + using LayoutTScale = cute::Layout; + using TMA_A = decltype(make_tma_copy( GmemTiledCopy{}, make_tensor( @@ -83,6 +87,17 @@ struct CollectiveMainloopFwd { select<1, 2>(TileShape_MNK{}), size<0>(ClusterShape{}))); + using TMA_Scale = decltype(make_tma_copy( + GmemTiledCopy{}, + make_tensor( + make_gmem_ptr(static_cast(nullptr)), + ShapeTScale{}, + StrideTScale{} + ), + SmemLayoutScale{}(_, _0{}), + select<0>(Shape>{}), + size<0>(ClusterShape{}))); + static constexpr int NumMmaThreads = size(typename Ktraits::TiledMma{}); using MainloopPipeline = typename Ktraits::MainloopPipeline; using PipelineParams = typename MainloopPipeline::Params; @@ -93,6 +108,7 @@ struct CollectiveMainloopFwd { static constexpr uint32_t TmaTransactionBytesA = static_cast(size(take<0, 2>(SmemLayoutA{})) * cutlass::sizeof_bits_v / 8); static constexpr uint32_t TmaTransactionBytesB = static_cast(size(take<0, 2>(SmemLayoutB{})) * cutlass::sizeof_bits_v / 8); + static constexpr uint32_t TmaTransactionBytesScale = static_cast(size(SmemLayoutScale{}(_, _0{})) * cutlass::sizeof_bits_v / 8); struct Arguments { Element const* ptr_A; @@ -102,18 +118,21 @@ struct CollectiveMainloopFwd { ElementOutput * ptr_C; LayoutT layout_C; const float *weight_scale; - const float *input_row_sum; + LayoutTScale layout_Scale; + const float *input_scale; const int64_t * tokens; }; struct Params { LayoutT layout_A; LayoutT layout_B; + LayoutTScale layout_Scale; TMA_A tma_load_A; TMA_B tma_load_B; + TMA_Scale tma_load_Scale; ElementOutput * ptr_C; const float *weight_scale; - const float *input_row_sum; + const float *input_scale; const int64_t * tokens; }; @@ -134,25 +153,37 @@ struct CollectiveMainloopFwd { SmemLayoutB{}(_, _, _0{}), select<1, 2>(TileShape_MNK{}), size<0>(ClusterShape{})); + Tensor mScale = make_tensor(make_gmem_ptr(args.weight_scale), args.layout_Scale); + TMA_Scale tma_load_Scale = make_tma_copy( + GmemTiledCopy{}, + mScale, + SmemLayoutScale{}(_, _0{}), + select<0>(Shape>{}), + size<0>(ClusterShape{})); - return {args.layout_A, args.layout_B, tma_load_A, tma_load_B, - args.ptr_C, args.weight_scale, args.input_row_sum, args.tokens}; + return { + args.layout_A, args.layout_B, args.layout_Scale, + tma_load_A, tma_load_B, tma_load_Scale, + args.ptr_C, args.weight_scale, args.input_scale, args.tokens}; } CUTLASS_DEVICE static void prefetch_tma_descriptors(Params const& mainloop_params) { cute::prefetch_tma_descriptor(mainloop_params.tma_load_A.get_tma_descriptor()); cute::prefetch_tma_descriptor(mainloop_params.tma_load_B.get_tma_descriptor()); + if constexpr (WeightScaleGroup < K) { + cute::prefetch_tma_descriptor(mainloop_params.tma_load_Scale.get_tma_descriptor()); + } } - template + template CUTLASS_DEVICE void store(Params const& mainloop_params, FrgTensorO & tOrO, SharedStorage& shared_storage, TiledMma tiled_mma, - const float *input_row_sum, const float *weight_scale, + const float *input_scale, const int64_t tokens, const int64_t pre_fix_tokens, const int bidm, @@ -163,24 +194,55 @@ struct CollectiveMainloopFwd { using packHalf = typename PackedHalf::Type; Tensor tOrO_out = make_tensor(tOrO.layout()); - #pragma unroll - for (int i = 0; i < size(tOrO); i+=4) { - const int sum_idx = i * 2; - tOrO[i] = (tOrO[i] + input_row_sum[sum_idx]) * weight_scale[0]; - tOrO[i + 1] = (tOrO[i + 1] + input_row_sum[sum_idx + 1]) * weight_scale[0]; - tOrO[i + 2] = (tOrO[i + 2] + input_row_sum[sum_idx]) * weight_scale[1]; - tOrO[i + 3] = (tOrO[i + 3] + input_row_sum[sum_idx + 1]) * weight_scale[1]; - *reinterpret_cast(&tOrO_out[i]) = packHalf(tOrO[i], tOrO[i + 2]); - *reinterpret_cast(&tOrO_out[i + 2]) = packHalf(tOrO[i + 1], tOrO[i + 3]); + if (input_scale != nullptr) { + const int lane_id = tidx % 4 * 2; + if constexpr (WeightScaleGroup == K) { + #pragma unroll + for (int i = 0; i < size(tOrO); i+=4) { + const int scale_idx = i * 2 + lane_id; + tOrO[i] = tOrO[i] * weight_scale[0] * input_scale[scale_idx]; + tOrO[i + 1] = tOrO[i + 1] * weight_scale[0] * input_scale[scale_idx + 1]; + tOrO[i + 2] = tOrO[i + 2] * weight_scale[1] * input_scale[scale_idx]; + tOrO[i + 3] = tOrO[i + 3] * weight_scale[1] * input_scale[scale_idx + 1]; + *reinterpret_cast(&tOrO_out[i]) = packHalf(tOrO[i], tOrO[i + 2]); + *reinterpret_cast(&tOrO_out[i + 2]) = packHalf(tOrO[i + 1], tOrO[i + 3]); + } + } else { + #pragma unroll + for (int i = 0; i < size(tOrO); i+=4) { + const int scale_idx = i * 2 + lane_id; + *reinterpret_cast(&tOrO_out[i]) = packHalf(float(tOrO[i]) * input_scale[scale_idx], float(tOrO[i + 2]) * input_scale[scale_idx]); + *reinterpret_cast(&tOrO_out[i + 2]) = packHalf(float(tOrO[i + 1]) * input_scale[scale_idx + 1], float(tOrO[i + 3]) * input_scale[scale_idx + 1]); + } + } + } else { + if constexpr (WeightScaleGroup == K) { + #pragma unroll + for (int i = 0; i < size(tOrO); i+=4) { + tOrO[i] = (tOrO[i]) * weight_scale[0]; + tOrO[i + 1] = tOrO[i + 1] * weight_scale[0]; + tOrO[i + 2] = tOrO[i + 2] * weight_scale[1]; + tOrO[i + 3] = tOrO[i + 3] * weight_scale[1]; + *reinterpret_cast(&tOrO_out[i]) = packHalf(tOrO[i], tOrO[i + 2]); + *reinterpret_cast(&tOrO_out[i + 2]) = packHalf(tOrO[i + 1], tOrO[i + 3]); + } + } else { + #pragma unroll + for (int i = 0; i < size(tOrO); i+=4) { + *reinterpret_cast(&tOrO_out[i]) = packHalf(float(tOrO[i]), float(tOrO[i + 2])); + *reinterpret_cast(&tOrO_out[i + 2]) = packHalf(float(tOrO[i + 1]), float(tOrO[i + 3])); + } + } } + uint16_t *smem_c = reinterpret_cast(shared_storage.smem_c.data()); uint32_t * reg_data = reinterpret_cast(tOrO_out.data()); cutlass::arch::NamedBarrier::sync(NumMmaThreads, 0); - constexpr int k_copy_times = CUR_N / 16; + constexpr int k_copy_times = kBlockN / 16; #pragma unroll for (int i = 0; i < k_copy_times; i++) { @@ -193,8 +255,8 @@ struct CollectiveMainloopFwd { } cutlass::arch::NamedBarrier::sync(NumMmaThreads, 0); - const int batch_idx = TokenPackSize == 0 ? pre_fix_tokens * M : bidb * M * TokenPackSize; - ElementOutput * store_c = mainloop_params.ptr_C + batch_idx + bidn * (M * kBlockN) + bidm * kBlockM; + const int expert_idx = TokenPackSize == 0 ? pre_fix_tokens * M : bidb * M * TokenPackSize; + ElementOutput * store_c = mainloop_params.ptr_C + expert_idx + bidn * (M * kBlockN) + bidm * kBlockM; const int reamin_tokens = tokens - bidn * kBlockN; @@ -202,7 +264,7 @@ struct CollectiveMainloopFwd { constexpr int kPackSize = 16 / sizeof(ElementOutput); constexpr int kNumVecElem = kBlockM / kPackSize; - constexpr int copy_len = CUR_N * kNumVecElem; + constexpr int copy_len = kBlockN * kNumVecElem; #pragma unroll for (int idx = tidx; idx < copy_len; idx += NumMmaThreads) { const int idx_div2 = idx / 2; @@ -246,16 +308,17 @@ struct CollectiveMainloopFwd { Tensor sA = make_tensor(make_smem_ptr(shared_storage.smem_a.data()), SmemLayoutA{}); Tensor sB = make_tensor(make_smem_ptr(shared_storage.smem_b.data()), SmemLayoutB{}); + Tensor sScale = make_tensor(make_smem_ptr(shared_storage.smem_scale.data()), SmemLayoutScale{}); Tensor mA = mainloop_params.tma_load_A.get_tma_tensor(mainloop_params.layout_A.shape()); Tensor mB = mainloop_params.tma_load_B.get_tma_tensor(mainloop_params.layout_B.shape()); + Tensor mScale = mainloop_params.tma_load_Scale.get_tma_tensor(mainloop_params.layout_Scale.shape()); Tensor gA = local_tile(mA(_, _, bidb), select<0, 1>(Shape, Int>{}), make_coord(bidm, _)); + Tensor gScale = local_tile(mScale(_, bidm, bidb), select<0>(Shape>{}), make_coord(_)); auto [tAgA, tAsA] = tma_partition(mainloop_params.tma_load_A, _0{}, Layout{}, group_modes<0, 2>(sA), group_modes<0, 2>(gA)); - const int kIters = kTiles / kStages; - if constexpr (TokenPackSize == 0) { Tensor gB = get_local_no_packed_tensor( mB, @@ -267,72 +330,54 @@ struct CollectiveMainloopFwd { if (tidx == 0) { #pragma unroll - for (int kiter = 0; kiter < kIters; ++kiter) { - #pragma unroll - for (int s = 0; s < kStages; s++) { - const int i = kiter * kStages + s; - pipeline.producer_acquire(smem_pipe_write); - copy(mainloop_params.tma_load_A.with(*pipeline.producer_get_barrier(smem_pipe_write), 0), - tAgA(_, i), tAsA(_, smem_pipe_write.index())); - - copy(mainloop_params.tma_load_B.with(*pipeline.producer_get_barrier(smem_pipe_write), 0), - tBgB(_, i), tBsB(_, smem_pipe_write.index())); - ++smem_pipe_write; - } - } - - #pragma unroll - for (int i = kIters * kStages; i < kTiles; ++i) { + for (int kiter = 0; kiter < kTiles; ++kiter) { pipeline.producer_acquire(smem_pipe_write); copy(mainloop_params.tma_load_A.with(*pipeline.producer_get_barrier(smem_pipe_write), 0), - tAgA(_, i), tAsA(_, smem_pipe_write.index())); + tAgA(_, kiter), tAsA(_, smem_pipe_write.index())); copy(mainloop_params.tma_load_B.with(*pipeline.producer_get_barrier(smem_pipe_write), 0), - tBgB(_, i), tBsB(_, smem_pipe_write.index())); - ++smem_pipe_write; + tBgB(_, kiter), tBsB(_, smem_pipe_write.index())); + + if constexpr (WeightScaleGroup < K) { + copy(mainloop_params.tma_load_Scale.with(*pipeline.producer_get_barrier(smem_pipe_write), 0), + gScale(_, kiter), sScale(_, smem_pipe_write.index())); + } + + ++smem_pipe_write; } } } else { - auto mB_this_batch = make_tensor( + auto mB_this_expert = make_tensor( mB(_, _, bidb).data(), make_layout( cute::make_shape(tokens, size<1>(mB)), mB.stride() )); - Tensor gB = local_tile(mB_this_batch, select<1, 2>(TileShape_MNK{}), make_coord(bidn, _)); + Tensor gB = local_tile(mB_this_expert, select<1, 2>(TileShape_MNK{}), make_coord(bidn, _)); auto [tBgB, tBsB] = tma_partition(mainloop_params.tma_load_B, _0{}, Layout{}, group_modes<0, 2>(sB), group_modes<0, 2>(gB)); if (tidx == 0) { #pragma unroll - for (int kiter = 0; kiter < kIters; ++kiter) { - #pragma unroll - for (int s = 0; s < kStages; s++) { - const int i = kiter * kStages + s; - pipeline.producer_acquire(smem_pipe_write); - copy(mainloop_params.tma_load_A.with(*pipeline.producer_get_barrier(smem_pipe_write), 0), - tAgA(_, i), tAsA(_, smem_pipe_write.index())); - - copy(mainloop_params.tma_load_B.with(*pipeline.producer_get_barrier(smem_pipe_write), 0), - tBgB(_, i), tBsB(_, smem_pipe_write.index())); - ++smem_pipe_write; - } - } - - #pragma unroll - for (int i = kIters * kStages; i < kTiles; ++i) { + for (int kiter = 0; kiter < kTiles; ++kiter) { pipeline.producer_acquire(smem_pipe_write); copy(mainloop_params.tma_load_A.with(*pipeline.producer_get_barrier(smem_pipe_write), 0), - tAgA(_, i), tAsA(_, smem_pipe_write.index())); + tAgA(_, kiter), tAsA(_, smem_pipe_write.index())); copy(mainloop_params.tma_load_B.with(*pipeline.producer_get_barrier(smem_pipe_write), 0), - tBgB(_, i), tBsB(_, smem_pipe_write.index())); - ++smem_pipe_write; + tBgB(_, kiter), tBsB(_, smem_pipe_write.index())); + + if constexpr (WeightScaleGroup < K) { + copy(mainloop_params.tma_load_Scale.with(*pipeline.producer_get_barrier(smem_pipe_write), 0), + gScale(_, kiter), sScale(_, smem_pipe_write.index())); + } + + ++smem_pipe_write; } } } } - template + template CUTLASS_DEVICE void mma(Params const& mainloop_params, TiledMma tiled_mma, @@ -341,17 +386,53 @@ struct CollectiveMainloopFwd { SharedStorage& shared_storage, FrgTensorO &tSrS, const int tidx) { + Tensor sA = make_tensor(make_smem_ptr(shared_storage.smem_a.data()), SmemLayoutA{}); + Tensor sB = make_tensor(make_smem_ptr(shared_storage.smem_b.data()), SmemLayoutB{}); + tiled_mma.accumulate_ = GMMA::ScaleOut::One; + + auto threadMma = tiled_mma.get_thread_slice(tidx); + + auto smem_tiled_copy_A = make_tiled_copy_A(SmemCopyAtomAB{}, tiled_mma); + auto smem_thr_copy_A = smem_tiled_copy_A.get_thread_slice(tidx); + + Tensor tSrA = threadMma.partition_fragment_A(sA(_, _, 0)); + Tensor tSrB = threadMma.partition_fragment_B(sB); + + auto consumer_wait = [](auto& pipeline, auto& smem_pipe_read) { + auto barrier_token = pipeline.consumer_try_wait(smem_pipe_read); + pipeline.consumer_wait(smem_pipe_read, barrier_token); + }; + #pragma unroll + for (int kiter = 0; kiter < kTiles; ++kiter) { + Tensor tSsA = smem_thr_copy_A.partition_S(sA(_, _, smem_pipe_read.index())); + consumer_wait(pipeline, smem_pipe_read); + gemm(tiled_mma, tSrA, tSsA, tSrB(_, _, _, smem_pipe_read.index()), tSrS, smem_tiled_copy_A, smem_thr_copy_A); + pipeline.consumer_release(smem_pipe_read); + ++smem_pipe_read; + } - using sMemBLayout = std::conditional_t< - CUR_N == kBlockN, - SmemLayoutB, - SmemLayoutB_TAIL - >; + } + + template + CUTLASS_DEVICE void + mma_pipeline(Params const& mainloop_params, + TiledMma tiled_mma, + MainloopPipeline pipeline, + PipelineState& smem_pipe_read, + SharedStorage& shared_storage, + FrgTensorO &tSrS, + const int tidx) { Tensor sA = make_tensor(make_smem_ptr(shared_storage.smem_a.data()), SmemLayoutA{}); - Tensor sB = make_tensor(make_smem_ptr(shared_storage.smem_b.data()), sMemBLayout{}); + Tensor sB = make_tensor(make_smem_ptr(shared_storage.smem_b.data()), SmemLayoutB{}); + float2 *weight_scale = reinterpret_cast(shared_storage.smem_scale.data()) + tidx / 4; - tiled_mma.accumulate_ = GMMA::ScaleOut::One; + Tensor tSrS1 = make_fragment_like(tSrS); + Tensor tSrS2 = make_fragment_like(tSrS); + + __half2 * tSrS_data = reinterpret_cast<__half2*>(raw_pointer_cast(tSrS.data())); + __half2 * tSrS1_data = reinterpret_cast<__half2*>(raw_pointer_cast(tSrS1.data())); + __half2 * tSrS2_data = reinterpret_cast<__half2*>(raw_pointer_cast(tSrS2.data())); auto threadMma = tiled_mma.get_thread_slice(tidx); @@ -366,29 +447,53 @@ struct CollectiveMainloopFwd { pipeline.consumer_wait(smem_pipe_read, barrier_token); }; - const int kIters = kTiles / kStages; + __half2 scale1, scale2, scale3, scale4; + float2 scale_cur_k; + #pragma unroll + for (int kiter = 0; kiter < kTiles;) { + Tensor tSsA1 = smem_thr_copy_A.partition_S(sA(_, _, smem_pipe_read.index())); + consumer_wait(pipeline, smem_pipe_read); + scale_cur_k = *(weight_scale + smem_pipe_read.index() * (kBlockM / 2)); + scale1 = __half2(scale_cur_k.x, scale_cur_k.x); + scale2 = __half2(scale_cur_k.y, scale_cur_k.y); - constexpr int B_STEPS = CUR_N == 0 ? 1 : (kBlockN / CUR_N); + gemm(tiled_mma, tSrA, tSsA1, tSrB(_, _, _, smem_pipe_read.index()), tSrS1, smem_tiled_copy_A, smem_thr_copy_A); + pipeline.consumer_release(smem_pipe_read); + tiled_mma.accumulate_ = GMMA::ScaleOut::Zero; - #pragma unroll - for (int kiter = 0; kiter < kIters; ++kiter) { - #pragma unroll - for (int s = 0; s < kStages; s++) { - Tensor tSsA = smem_thr_copy_A.partition_S(sA(_, _, s)); + if (kiter > 0) { + for (int i = 0; i < size(tSrS) / 2; i+=2) { + tSrS_data[i] = __hfma2(tSrS2_data[i], scale3, tSrS_data[i]); + tSrS_data[i+1] = __hfma2(tSrS2_data[i + 1], scale4, tSrS_data[i+1]); + } + } + + ++smem_pipe_read; + ++kiter; + + if (kiter < kTiles) { + Tensor tSsA2 = smem_thr_copy_A.partition_S(sA(_, _, smem_pipe_read.index())); consumer_wait(pipeline, smem_pipe_read); - gemm(tiled_mma, tSrA, tSsA, tSrB(_, _, _, s * B_STEPS), tSrS, smem_tiled_copy_A, smem_thr_copy_A); + scale_cur_k = *(weight_scale + smem_pipe_read.index() * (kBlockM / 2)); + scale3 = __half2(scale_cur_k.x, scale_cur_k.x); + scale4 = __half2(scale_cur_k.y, scale_cur_k.y); + + gemm(tiled_mma, tSrA, tSsA2, tSrB(_, _, _, smem_pipe_read.index()), tSrS2, smem_tiled_copy_A, smem_thr_copy_A); pipeline.consumer_release(smem_pipe_read); ++smem_pipe_read; + ++kiter; + } + + for (int i = 0; i < size(tSrS) / 2; i+=2) { + tSrS_data[i] = __hfma2(tSrS1_data[i], scale1, tSrS_data[i]); + tSrS_data[i+1] = __hfma2(tSrS1_data[i + 1], scale2, tSrS_data[i+1]); } + tiled_mma.accumulate_ = GMMA::ScaleOut::Zero; } - #pragma unroll - for (int i = 0; i < kTiles % kStages; ++i) { - Tensor tSsA = smem_thr_copy_A.partition_S(sA(_, _, i)); - consumer_wait(pipeline, smem_pipe_read); - gemm(tiled_mma, tSrA, tSsA, tSrB(_, _, _, i * B_STEPS), tSrS, smem_tiled_copy_A, smem_thr_copy_A); - pipeline.consumer_release(smem_pipe_read); - ++smem_pipe_read; + for (int i = 0; i < size(tSrS) / 2; i+=2) { + tSrS_data[i] = __hfma2(tSrS2_data[i], scale3, tSrS_data[i]); + tSrS_data[i+1] = __hfma2(tSrS2_data[i + 1], scale4, tSrS_data[i+1]); } } }; diff --git a/custom_ops/gpu_ops/w4afp8_gemm/utils.hpp b/custom_ops/gpu_ops/w4afp8_gemm/utils.hpp index 2c0f685fe7..128ea564aa 100644 --- a/custom_ops/gpu_ops/w4afp8_gemm/utils.hpp +++ b/custom_ops/gpu_ops/w4afp8_gemm/utils.hpp @@ -62,8 +62,12 @@ template __forceinline__ __device__ void convert_c4_2_fp8(const int32_t * src, int32_t * dst1, int32_t * dst2) { #pragma unroll for (int i = 0; i < numel; ++i) { - dst1[i] = (src[i] >> 4) & 0x0f0f0f0f; - dst2[i] = src[i] & 0x0f0f0f0f; + uint32_t head1 = src[i] & 0x80808080; + dst1[i] = (src[i] >> 4) & 0x07070707; + dst1[i] = dst1[i] | head1; + uint32_t head2 = (src[i] & 0x08080808) << 4; + dst2[i] = src[i] & 0x07070707; + dst2[i] = dst2[i] | head2; } } @@ -88,7 +92,6 @@ __forceinline__ __device__ void gemm( warpgroup_arrive(); } constexpr int numel = decltype(size(tCrA(_, _, 0)))::value / 4; - Tensor tCrA_copy_view = thr_copy_A.retile_D(tCrA); cute::copy(tiled_copy_A, tCsA(_, _, _0{}), tCrA_copy_view(_, _, _0{})); @@ -103,7 +106,9 @@ __forceinline__ __device__ void gemm( convert_c4_2_fp8(tCrA_data, tCrA1_data, tCrA2_data); cute::gemm(tiled_mma, tCrA1(_,_,k_block), tCrB(_,_,2 * k_block), tCrC); + tiled_mma.accumulate_ = GMMA::ScaleOut::One; cute::gemm(tiled_mma, tCrA2(_,_,k_block), tCrB(_,_, 2 * k_block + 1), tCrC); + } if constexpr (commit) { warpgroup_commit_batch(); diff --git a/custom_ops/gpu_ops/w4afp8_gemm/w4afp8_gemm.cu b/custom_ops/gpu_ops/w4afp8_gemm/w4afp8_gemm.cu index 53685c5c97..0b72b3454f 100644 --- a/custom_ops/gpu_ops/w4afp8_gemm/w4afp8_gemm.cu +++ b/custom_ops/gpu_ops/w4afp8_gemm/w4afp8_gemm.cu @@ -20,25 +20,11 @@ #include "paddle/extension.h" #include "w4afp8_gemm_template.h" #include "w4afp8_gemm.h" +#include "weight_kernel.hpp" +#include "weight_scale_kernel.hpp" + -void weight_convert(const uint8_t *weight, uint8_t *weight_new, int batch, int M, int K) { - assert(K % 64 == 0); - for (int b = 0; b < batch; ++b) { - for (int m = 0; m < M; ++m) { - for (int k = 0; k < K; k+=64) { - for (int k_inner = 0; k_inner < 32; ++k_inner) { - uint8_t temp = 0; - uint8_t left = weight[b * M * K + m * K + k + k_inner]; - uint8_t right = weight[b * M * K + m * K + k + k_inner + 32]; - temp |= left << 4; - temp |= right; - weight_new[b * M * K / 2 + m * K / 2 + k / 2 + k_inner] = *reinterpret_cast(&temp); - } - } - } - } -} template class NVTraits; @@ -65,26 +51,26 @@ void DisPatchW4AFp8Gemm( const cutlass::float_e4m3_t* input, const cutlass::float_e4m3_t* weight, const int64_t * tokens, - const float * input_row_sum, const float * weight_scale, + const float * input_dequant_scale, OutputType * out, const int64_t token_padding_size, const int64_t max_tokens, - const int batch_size, + const int Experts, const int64_t M, const int64_t K, + const int WeightScaleGroup, cudaStream_t stream) { int kBlockN = 256; - int TailN = 0; if constexpr (std::is_same_v) { GEMM_SWITCH_BF16( - M, K, batch_size, token_padding_size, kBlockN, TailN, + M, K, Experts, token_padding_size, kBlockN, WeightScaleGroup, weight, input, out, weight_scale, - input_row_sum, + input_dequant_scale, tokens, max_tokens, stream) @@ -97,16 +83,17 @@ std::vector W4AFp8Gemm( const paddle::Tensor& input, const paddle::Tensor& weight, const paddle::Tensor& tokens, // If tokenpadding=0, this tensor represents the prefix sum of tensors, otherwise it represents the number of tokens in each group - const paddle::Tensor& input_row_sum, const paddle::Tensor& weight_scale, + const paddle::Tensor& input_dequant_scale, const int64_t token_padding_size, const int64_t max_tokens, const bool is_bfloat16) { - const int batch_size = weight.dims()[0]; + const int Experts = weight.dims()[0]; const int M = weight.dims()[1]; const int K = weight.dims()[2] * 2; + const int WeightScaleGroup = weight_scale.dims().size() == 2 ? K : weight_scale.dims()[3]; if (input.dtype() != paddle::DataType::FLOAT8_E4M3FN) { PD_THROW("Only supported dtype in ['FLOAT8_E4M3FN']."); @@ -121,14 +108,15 @@ std::vector W4AFp8Gemm( reinterpret_cast(input.data()), reinterpret_cast(weight.data()), tokens.data(), - input_row_sum.data(), weight_scale.data(), + input_dequant_scale.data(), reinterpret_cast(out_data), token_padding_size, max_tokens, - batch_size, + Experts, M, K, + WeightScaleGroup, input.stream()); return {out}; } else { @@ -136,20 +124,21 @@ std::vector W4AFp8Gemm( } } else { if (is_bfloat16) { - paddle::Tensor out = paddle::empty({batch_size, token_padding_size, M}, paddle::DataType::BFLOAT16, input.place()); + paddle::Tensor out = paddle::empty({Experts, token_padding_size, M}, paddle::DataType::BFLOAT16, input.place()); phi::dtype::bfloat16 * out_data = out.data(); DisPatchW4AFp8Gemm( reinterpret_cast(input.data()), reinterpret_cast(weight.data()), tokens.data(), - input_row_sum.data(), weight_scale.data(), + input_dequant_scale.data(), reinterpret_cast(out_data), token_padding_size, max_tokens, - batch_size, + Experts, M, K, + WeightScaleGroup, input.stream()); return {out}; } else { @@ -163,8 +152,7 @@ void DisPatchW4AFp8GemmWrapper( const InputType* input, const InputType* weight, const int64_t* total_rows_before_expert, - const float* input_row_sum, - const float* row_scale, + const float* input_dequant_scale, const float* weight_scale, OutputType * out, const int64_t token_padding_size, @@ -172,6 +160,7 @@ void DisPatchW4AFp8GemmWrapper( const int num_experts, const int64_t M, const int64_t K, + const int WeightScaleGroup, cudaStream_t stream) { using InType = typename NVTraits::data_t; using OutType = typename NVTraits::data_t; @@ -179,78 +168,18 @@ void DisPatchW4AFp8GemmWrapper( reinterpret_cast(input), reinterpret_cast(weight), total_rows_before_expert, - input_row_sum, weight_scale, + input_dequant_scale, reinterpret_cast(out), token_padding_size, max_tokens, num_experts, M, K, + WeightScaleGroup, stream); } - -std::vector W4AFp8GemmWeightConvert(const paddle::Tensor& weight) { - const int batch_size = weight.dims()[0]; - const int M = weight.dims()[1]; - const int K = weight.dims()[2]; - paddle::Tensor weight_new = paddle::empty({batch_size, M, K / 2}, paddle::DataType::UINT8, weight.place()); - weight_convert(weight.data(), weight_new.data(), batch_size, M, K); - return {weight_new}; -} - -template -__global__ void permute_scale_kernel( - T* input_data, - const int numel) { - using LoadT = AlignedVector; - LoadT input_vec; - LoadT dst_vec; - const int load_idx = (blockIdx.x * blockDim.x + threadIdx.x) * kPackSize; - if (load_idx >= numel) { - return; - } - Load(&input_data[load_idx], &input_vec); - - for (int i = 0; i < kPackSize; i+=2) { - dst_vec[i] = input_vec[i / 2]; - dst_vec[i + 1] = input_vec[i / 2 + 8]; - } - - Store(dst_vec, &input_data[load_idx]); -} - -void W4AFp8GemmScalePermute(const paddle::Tensor& scale) { - const int row = scale.dims().size() == 2 ? scale.dims()[0] : 1; - const int col = scale.dims().size() == 2 ? scale.dims()[1] : scale.dims()[0]; - if (col % 16 != 0) { - PD_THROW("Only supported when col is divisible by 16."); - } - const int numel = row * col; - const int threads = 128; - const int kPackSize = 16; - const int grid_size = (numel / kPackSize + threads - 1) / threads; - - if (scale.dtype() == paddle::DataType::BFLOAT16) { - permute_scale_kernel<<>>( - const_cast(scale.data()), - numel - ); - } else if (scale.dtype() == paddle::DataType::FLOAT16) { - permute_scale_kernel<<>>( - const_cast(scale.data()), - numel - ); - } else if (scale.dtype() == paddle::DataType::FLOAT32) { - permute_scale_kernel<<>>( - const_cast(scale.data()), - numel - ); - } - -} - PD_BUILD_STATIC_OP(w4afp8_gemm_scale_permute) .Inputs({"weight_scale"}) .Outputs({"permute_scale"}) @@ -261,7 +190,6 @@ PD_BUILD_STATIC_OP(w4afp8_gemm) .Inputs({"input", "weight", "tokens", - "input_row_sum", "weight_scale"}) .Outputs({"out"}) .Attrs({"token_padding_size: int64_t", @@ -278,8 +206,7 @@ template void DisPatchW4AFp8GemmWrapper<__nv_fp8_e4m3, __nv_bfloat16>( const __nv_fp8_e4m3* input, const __nv_fp8_e4m3* weight, const int64_t * tokens, - const float * input_row_sum, - const float * row_scale, + const float * input_dequant_scale, const float * weight_scale, __nv_bfloat16 * out, const int64_t token_padding_size, @@ -287,6 +214,7 @@ template void DisPatchW4AFp8GemmWrapper<__nv_fp8_e4m3, __nv_bfloat16>( const int num_experts, const int64_t M, const int64_t K, + const int WeightScaleGroup, cudaStream_t stream ); @@ -294,8 +222,7 @@ template void DisPatchW4AFp8GemmWrapper<__nv_fp8_e4m3, half>( const __nv_fp8_e4m3* input, const __nv_fp8_e4m3* weight, const int64_t * tokens, - const float * input_row_sum, - const float * row_scale, + const float * input_dequant_scale, const float * weight_scale, half * out, const int64_t token_padding_size, @@ -303,5 +230,6 @@ template void DisPatchW4AFp8GemmWrapper<__nv_fp8_e4m3, half>( const int num_experts, const int64_t M, const int64_t K, + const int WeightScaleGroup, cudaStream_t stream ); diff --git a/custom_ops/gpu_ops/w4afp8_gemm/w4afp8_gemm.h b/custom_ops/gpu_ops/w4afp8_gemm/w4afp8_gemm.h index c2474d419b..b8c393ae1d 100644 --- a/custom_ops/gpu_ops/w4afp8_gemm/w4afp8_gemm.h +++ b/custom_ops/gpu_ops/w4afp8_gemm/w4afp8_gemm.h @@ -24,8 +24,8 @@ std::vector W4AFp8Gemm( const paddle::Tensor& input, const paddle::Tensor& weight, const paddle::Tensor& tokens, // If tokenpadding=0, this tensor represents the prefix sum of tensors, otherwise it represents the number of tokens in each group - const paddle::Tensor& input_row_sum, const paddle::Tensor& weight_scale, + const paddle::Tensor& input_dequant_scale, const int64_t token_padding_size, const int64_t max_tokens, const bool is_bfloat16); @@ -35,8 +35,7 @@ void DisPatchW4AFp8GemmWrapper( const InputType* input, const InputType* weight, const int64_t * tokens, - const float * input_row_sum, - const float * row_scale, + const float * input_dequant_scale, const float * weight_scale, OutputType * out, const int64_t token_padding_size, @@ -44,4 +43,5 @@ void DisPatchW4AFp8GemmWrapper( const int num_experts, const int64_t M, const int64_t K, + const int WeightScaleGroup, cudaStream_t stream); diff --git a/custom_ops/gpu_ops/w4afp8_gemm/w4afp8_gemm_kernel.hpp b/custom_ops/gpu_ops/w4afp8_gemm/w4afp8_gemm_kernel.hpp index 01a8dd114c..425d648e61 100644 --- a/custom_ops/gpu_ops/w4afp8_gemm/w4afp8_gemm_kernel.hpp +++ b/custom_ops/gpu_ops/w4afp8_gemm/w4afp8_gemm_kernel.hpp @@ -34,7 +34,6 @@ void __global__ __launch_bounds__(Ktraits::kNWarps * cutlass::NumThreadsPerWarp static_assert(cutlass::sizeof_bits_v == 8); using TileShape_MNK = typename Ktraits::TileShape_MNK; - using TileShape_MNK_TAIL = typename Ktraits::TileShape_MNK_TAIL; using ClusterShape = typename Ktraits::ClusterShape_MNK; static constexpr int NumMmaThreads = size(typename Ktraits::TiledMma{}); @@ -42,8 +41,9 @@ void __global__ __launch_bounds__(Ktraits::kNWarps * cutlass::NumThreadsPerWarp static constexpr int kBlockN = Ktraits::kBlockN; static constexpr int kBlockM = Ktraits::kBlockM; static constexpr int M = Ktraits::M; + static constexpr int K = Ktraits::K; static constexpr int TokenPackSize = Ktraits::TokenPackSize; - static constexpr int TAIL_N = Ktraits::TAIL_N; + static constexpr int WeightScaleGroup = Ktraits::WeightScaleGroup; using CollectiveMainloop = CollectiveMainloopFwd; @@ -68,7 +68,11 @@ void __global__ __launch_bounds__(Ktraits::kNWarps * cutlass::NumThreadsPerWarp int const warp_group_thread_idx = threadIdx.x % cutlass::NumThreadsPerWarpGroup; PipelineParams pipeline_params; - pipeline_params.transaction_bytes = CollectiveMainloop::TmaTransactionBytesA + CollectiveMainloop::TmaTransactionBytesB; + if constexpr (WeightScaleGroup == K) { + pipeline_params.transaction_bytes = CollectiveMainloop::TmaTransactionBytesA + CollectiveMainloop::TmaTransactionBytesB; + } else { + pipeline_params.transaction_bytes = CollectiveMainloop::TmaTransactionBytesA + CollectiveMainloop::TmaTransactionBytesB + CollectiveMainloop::TmaTransactionBytesScale; + } int warp_group_idx = cutlass::canonical_warp_group_idx(); pipeline_params.role = warp_group_idx == 0 ? MainloopPipeline::ThreadCategory::Producer @@ -96,8 +100,9 @@ void __global__ __launch_bounds__(Ktraits::kNWarps * cutlass::NumThreadsPerWarp return; } - float* input_row_sum = reinterpret_cast( - shared_memory + sizeof(typename Ktraits::SharedStorage)); + const bool is_need_input_scale = mainloop_params.input_scale != nullptr; + + float* input_scale = is_need_input_scale ? reinterpret_cast(shared_memory + sizeof(typename Ktraits::SharedStorage)) : nullptr; if (warp_group_idx == 0) { cutlass::arch::warpgroup_reg_dealloc(); @@ -119,53 +124,40 @@ void __global__ __launch_bounds__(Ktraits::kNWarps * cutlass::NumThreadsPerWarp typename Ktraits::TiledMma tiled_mma; - typename Ktraits::TiledMma_TAIL tiled_mma_tail; - const int mma_tidx = tidx - NumCopyThreads; - const int lane_id = mma_tidx % 4 * 2; - - const float2 weight_scale = reinterpret_cast(mainloop_params.weight_scale + bidb * M + bidm * kBlockM)[mma_tidx / 4]; - if constexpr (TokenPackSize == 0) { - const int input_sum_idx = pre_fix_tokens + bidn * kBlockN; - if (mma_tidx < kBlockN) { - reinterpret_cast(input_row_sum)[mma_tidx] = reinterpret_cast(mainloop_params.input_row_sum + input_sum_idx)[mma_tidx]; - } - } else { - const int input_sum_idx = bidb * TokenPackSize + bidn * kBlockN; - if (mma_tidx < kBlockN / 4) { - reinterpret_cast(input_row_sum)[mma_tidx] = reinterpret_cast(mainloop_params.input_row_sum + input_sum_idx)[mma_tidx]; + if (is_need_input_scale) { + if constexpr (TokenPackSize == 0) { + const int input_scale_idx = pre_fix_tokens + bidn * kBlockN; + if (mma_tidx < tokens) { + reinterpret_cast(input_scale)[mma_tidx] = reinterpret_cast(mainloop_params.input_scale + input_scale_idx)[mma_tidx]; + } + } else { + const int input_scale_idx = bidb * TokenPackSize + bidn * kBlockN; + if (mma_tidx < kBlockN / 4) { + reinterpret_cast(input_scale)[mma_tidx] = reinterpret_cast(mainloop_params.input_scale + input_scale_idx)[mma_tidx]; + } } } - const int reamin_tokens = tokens - bidn * kBlockN; + float2 weight_scale; + + if constexpr (WeightScaleGroup == K) { + weight_scale = reinterpret_cast(mainloop_params.weight_scale + bidb * M + bidm * kBlockM)[mma_tidx / 4]; + } + Tensor tSrS = partition_fragment_C(tiled_mma, select<0, 1>(TileShape_MNK{})); - if (TAIL_N > 0 && reamin_tokens < kBlockN) { - Tensor tSrS_tail = partition_fragment_C(tiled_mma_tail, select<0, 1>(TileShape_MNK_TAIL{})); - collective_mainloop.mma( + if constexpr (WeightScaleGroup == K) { + collective_mainloop.mma( mainloop_params, - tiled_mma_tail, + tiled_mma, pipeline, smem_pipe_read, shared_storage, - tSrS_tail, - mma_tidx); - collective_mainloop.store( - mainloop_params, - tSrS_tail, - shared_storage, - tiled_mma_tail, - input_row_sum + lane_id, - reinterpret_cast(&weight_scale), - tokens, - pre_fix_tokens, - bidm, - bidn, - bidb, + tSrS, mma_tidx); } else { - Tensor tSrS = partition_fragment_C(tiled_mma, select<0, 1>(TileShape_MNK{})); - collective_mainloop.mma( + collective_mainloop.mma_pipeline( mainloop_params, tiled_mma, pipeline, @@ -173,41 +165,55 @@ void __global__ __launch_bounds__(Ktraits::kNWarps * cutlass::NumThreadsPerWarp shared_storage, tSrS, mma_tidx); - collective_mainloop.store( - mainloop_params, - tSrS, - shared_storage, - tiled_mma, - input_row_sum + lane_id, - reinterpret_cast(&weight_scale), - tokens, - pre_fix_tokens, - bidm, - bidn, - bidb, - mma_tidx); } + + + collective_mainloop.store( + mainloop_params, + tSrS, + shared_storage, + tiled_mma, + reinterpret_cast(&weight_scale), + input_scale, + tokens, + pre_fix_tokens, + bidm, + bidn, + bidb, + mma_tidx); } } -template +template auto get_gmem_layout(const int Rows, const int Cols) { return make_layout( make_shape( static_cast(Rows), static_cast(Cols), - static_cast(Batch)), + static_cast(Experts)), make_stride( static_cast(Cols), cute::_1{}, static_cast(Rows * Cols))); } +template +auto get_scale_layout(const int Rows, const int Cols) { + return make_layout( + make_shape( + static_cast(Cols), + static_cast(Rows), + static_cast(Experts)), + make_stride( + cute::_1{}, + static_cast(Cols), + static_cast(Rows * Cols))); +} + -template -void run_gemm(const InputType * A, const InputType * B, OutputType * C, const float *weight_scale, - const float *input_row_sum, const int64_t * tokens, const int64_t max_tokens, cudaStream_t stream) { +template +void run_gemm(const InputType * A, const InputType * B, OutputType * C, const float *weight_scale, const float * input_dequant_scale, const int64_t * tokens, const int max_tokens, cudaStream_t stream) { using ElementOutput = typename Kernel_traits::ElementOutput; using Element = typename Kernel_traits::Element; @@ -216,24 +222,28 @@ void run_gemm(const InputType * A, const InputType * B, OutputType * C, const fl constexpr int M_nums = (M + Kernel_traits::kBlockM - 1) / Kernel_traits::kBlockM; const int N_nums = (max_tokens + Kernel_traits::kBlockN - 1) / Kernel_traits::kBlockN; + constexpr int K_scale_nums = K / Kernel_traits::kBlockM; + static_assert(K % WeightScaleGroup == 0); + static_assert(WeightScaleGroup == 128 || WeightScaleGroup == K); typename CollectiveMainloop::Params mainloop_params = CollectiveMainloop::to_underlying_arguments({ static_cast(A), - get_gmem_layout(M, K / 2), + get_gmem_layout(M, K / 2), static_cast(B), - get_gmem_layout(TokenPackSize == 0 ? max_tokens: TokenPackSize, K), + get_gmem_layout(TokenPackSize == 0 ? max_tokens: TokenPackSize, K), static_cast(C), - get_gmem_layout(M, TokenPackSize == 0 ? max_tokens : TokenPackSize), + get_gmem_layout(M, TokenPackSize == 0 ? max_tokens : TokenPackSize), weight_scale, - input_row_sum, + get_scale_layout(M_nums, K_scale_nums * Kernel_traits::kBlockM), + input_dequant_scale, tokens }); void *kernel; kernel = (void *)w4afp8_gemm_kernel; - int smem_size = sizeof(typename Kernel_traits::SharedStorage) + sizeof(float) * Kernel_traits::kBlockN; + int smem_size = sizeof(typename Kernel_traits::SharedStorage) + Kernel_traits::kBlockN * sizeof(float); if (smem_size >= 48 * 1024) { cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size); @@ -242,7 +252,7 @@ void run_gemm(const InputType * A, const InputType * B, OutputType * C, const fl dim3 grid_dims; grid_dims.x = M_nums; grid_dims.y = N_nums; - grid_dims.z = Batch; + grid_dims.z = Experts; static constexpr int ctaSize = Kernel_traits::kNWarps * 32; dim3 block_dims(ctaSize); dim3 cluster_dims(size<0>(ClusterShape{}), size<1>(ClusterShape{}), size<2>(ClusterShape{})); diff --git a/custom_ops/gpu_ops/w4afp8_gemm/weight_kernel.hpp b/custom_ops/gpu_ops/w4afp8_gemm/weight_kernel.hpp new file mode 100644 index 0000000000..7501bdaebc --- /dev/null +++ b/custom_ops/gpu_ops/w4afp8_gemm/weight_kernel.hpp @@ -0,0 +1,123 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +#include "helper.h" +#include "paddle/extension.h" + +void weight_convert(const uint8_t *weight, uint8_t *weight_new, int experts, int M, int K) { + assert(K % 64 == 0); + for (int b = 0; b < experts; ++b) { + for (int m = 0; m < M; ++m) { + for (int k = 0; k < K; k+=64) { + for (int k_inner = 0; k_inner < 32; ++k_inner) { + uint8_t temp = 0; + uint8_t left = weight[b * M * K + m * K + k + k_inner]; + uint8_t right = weight[b * M * K + m * K + k + k_inner + 32]; + temp |= left << 4; + temp |= right; + weight_new[b * M * K / 2 + m * K / 2 + k / 2 + k_inner] = *reinterpret_cast(&temp); + } + } + } + } +} + +__global__ void weight_permute_interleave_kernelw4afp8( + const int8_t* input_data, + int8_t* output_data, + const int original_k, + const int original_n) { + + const int numel = original_k * original_n / 4; + const int pack_group_size = 64; + const int thread_group_size = pack_group_size / 4; // 16 + const int thread_k_stride = original_k / 4; + + const int linear_idx = blockIdx.x * blockDim.x + threadIdx.x; + + if (linear_idx >= numel) return; + + const int n_id = linear_idx / thread_k_stride; + const int k_id = linear_idx % thread_k_stride; + const int k_group_idx = k_id / thread_group_size; + const int k_idx_in_group = k_id % thread_group_size; + + const int8_t* src = input_data + + k_group_idx * pack_group_size / 2 * original_n + + k_idx_in_group * original_n + n_id; + + int8_t tmp0 = src[0]; + int8_t tmp1 = src[pack_group_size / 4 * original_n]; + + int8_t tmp00 = (tmp0 & 0xF0) + 112; + int8_t tmp01 = ((tmp0 << 4) & 0xF0) + 112; + int8_t tmp10 = (tmp1 & 0xF0) + 112; + int8_t tmp11 = ((tmp1 << 4) & 0xF0) + 112; + + uint8_t utmp00 = *(reinterpret_cast(&tmp00)); + uint8_t utmp01 = *(reinterpret_cast(&tmp01)); + uint8_t utmp10 = *(reinterpret_cast(&tmp10)); + uint8_t utmp11 = *(reinterpret_cast(&tmp11)); + + utmp00 = (utmp00 & 0xF0) >> 4; + utmp01 = (utmp01 & 0xF0) >> 4; + utmp10 = (utmp10 & 0xF0) >> 4; + utmp11 = (utmp11 & 0xF0) >> 4; + + tmp00 = *(reinterpret_cast(&utmp00)) - 7; + tmp01 = *(reinterpret_cast(&utmp01)) - 7; + tmp10 = *(reinterpret_cast(&utmp10)) - 7; + tmp11 = *(reinterpret_cast(&utmp11)) - 7; + + if (tmp00 <= 0) { + tmp00 = 8 - tmp00; + } + if (tmp01 <= 0) { + tmp01 = 8 - tmp01; + } + if (tmp10 <= 0) { + tmp10 = 8 - tmp10; + } + if (tmp11 <= 0) { + tmp11 = 8 - tmp11; + } + + int8_t dst0 = (tmp01 << 4) | tmp11; + int8_t dst1 = (tmp00 << 4) | tmp10; + + int8_t* dst = output_data + n_id * original_k / 2 + (k_group_idx * pack_group_size / 2) + k_idx_in_group * 2; + dst[0] = dst0; + dst[1] = dst1; +} + +std::vector W4AFp8GemmWeightConvert(const paddle::Tensor& weight) { + if (weight.place() == paddle::CPUPlace()) { + const int experts = weight.dims()[0]; + const int M = weight.dims()[1]; + const int K = weight.dims()[2]; + paddle::Tensor weight_new = paddle::empty({experts, M, K / 2}, paddle::DataType::UINT8, weight.place()); + weight_convert(weight.data(), weight_new.data(), experts, M, K); + return {weight_new}; + } else { + const int original_k = weight.dims()[0] * 2; + const int original_n = weight.dims()[1]; + paddle::Tensor weight_new = paddle::empty(weight.shape(), paddle::DataType::INT8, weight.place()); + const int block_dim = 256; + const int original_numel = original_k * original_n; + const int grid_size = (original_numel + block_dim - 1) / block_dim; + + weight_permute_interleave_kernelw4afp8<<>>( + weight.data(), weight_new.data(), original_k, original_n); + return {weight_new}; + } +} diff --git a/custom_ops/gpu_ops/w4afp8_gemm/weight_scale_kernel.hpp b/custom_ops/gpu_ops/w4afp8_gemm/weight_scale_kernel.hpp new file mode 100644 index 0000000000..9a37bf4a2f --- /dev/null +++ b/custom_ops/gpu_ops/w4afp8_gemm/weight_scale_kernel.hpp @@ -0,0 +1,66 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +#include "helper.h" +#include "paddle/extension.h" + +template +__global__ void permute_scale_kernel( + T* input_data, + const int numel) { + using LoadT = AlignedVector; + LoadT input_vec; + LoadT dst_vec; + const int load_idx = (blockIdx.x * blockDim.x + threadIdx.x) * kPackSize; + if (load_idx >= numel) { + return; + } + Load(&input_data[load_idx], &input_vec); + + for (int i = 0; i < kPackSize; i+=2) { + dst_vec[i] = input_vec[i / 2]; + dst_vec[i + 1] = input_vec[i / 2 + 8]; + } + + Store(dst_vec, &input_data[load_idx]); +} + +void W4AFp8GemmScalePermute(const paddle::Tensor& scale) { + const int row = scale.dims().size() == 2 ? scale.dims()[0] : 1; + const int col = scale.dims().size() == 2 ? scale.dims()[1] : scale.dims()[0]; + if (col % 16 != 0) { + PD_THROW("Only supported when col is divisible by 16."); + } + const int numel = row * col; + const int threads = 128; + const int kPackSize = 16; + const int grid_size = (numel / kPackSize + threads - 1) / threads; + + if (scale.dtype() == paddle::DataType::BFLOAT16) { + permute_scale_kernel<<>>( + const_cast(scale.data()), + numel + ); + } else if (scale.dtype() == paddle::DataType::FLOAT16) { + permute_scale_kernel<<>>( + const_cast(scale.data()), + numel + ); + } else if (scale.dtype() == paddle::DataType::FLOAT32) { + permute_scale_kernel<<>>( + const_cast(scale.data()), + numel + ); + } + +} diff --git a/custom_ops/utils/auto_gen_w4afp8_gemm_kernel.py b/custom_ops/utils/auto_gen_w4afp8_gemm_kernel.py index 1acf3c80ae..802942eecd 100644 --- a/custom_ops/utils/auto_gen_w4afp8_gemm_kernel.py +++ b/custom_ops/utils/auto_gen_w4afp8_gemm_kernel.py @@ -11,6 +11,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import os +import re file_dir = "./gpu_ops/w4afp8_gemm/" @@ -30,12 +32,12 @@ #include """ gemm_template_case = """ -void w4afp8_gemm_M{M}_N{N}_TAILN{TAILN}_K{K}_B{BATCH}_P{PADDING}_{TYPE}( +void w4afp8_gemm_M{M}_N{N}_G{GROUPSIZE}_K{K}_E{EXPERTS}_P{PADDING}_{TYPE}( const cutlass::float_e4m3_t * weight, const cutlass::float_e4m3_t * input, {cutlass_type} * out, const float *weight_scale, - const float *input_row_sum, + const float * input_dequant_scale, const int64_t *tokens, const int64_t max_tokens, cudaStream_t stream); @@ -48,22 +50,22 @@ """ gemm_template_cu_template = """ -void w4afp8_gemm_M{M}_N{N}_TAILN{TAILN}_K{K}_B{BATCH}_P{PADDING}_{TYPE}( +void w4afp8_gemm_M{M}_N{N}_G{GROUPSIZE}_K{K}_E{EXPERTS}_P{PADDING}_{TYPE}( const cutlass::float_e4m3_t * weight, const cutlass::float_e4m3_t * input, {cutlass_type} * out, const float *weight_scale, - const float *input_row_sum, + const float * input_dequant_scale, const int64_t *tokens, const int64_t max_tokens, cudaStream_t stream) {{ constexpr static int M = {M}; constexpr static int K = {K}; - constexpr static int Batch = {BATCH}; + constexpr static int EXPERTS = {EXPERTS}; constexpr static int TokenPackSize = {PADDING}; constexpr static int kBlockN = {N}; - constexpr static int kBlockN_TAIL = {TAILN}; + constexpr static int kGroupSize = {GROUPSIZE}; constexpr static int kBlockM = 128; constexpr static int kBlockK = 128; constexpr static int kNWarps = 4 + kBlockM / 16; @@ -74,22 +76,24 @@ using Kernel_traits = Kernel_traits< kBlockM, kBlockN, kBlockK, kNWarps, kStages, kTiles, - M, TokenPackSize, kBlockN_TAIL, kCluster, cutlass::float_e4m3_t, + M, K, TokenPackSize, kGroupSize, kCluster, cutlass::float_e4m3_t, {cutlass_type}>; run_gemm - (weight, input, out, weight_scale, - input_row_sum, tokens, max_tokens, stream); + Kernel_traits, M, K, EXPERTS, TokenPackSize, kGroupSize> + (weight, input, out, weight_scale, input_dequant_scale, tokens, max_tokens, stream); }} """ +# [M, K, Number of experts, token Padding Size, weight K group size] gemm_case = [ - [8192, 3584, 8, 0], # eb45T ffn1 - [8192, 3584, 8, 2048], # eb45T ffn1 - [7168, 8192, 8, 0], # eb45T ffn2 - [7168, 8192, 8, 2048], # eb45T ffn2 - [1792, 8192, 64, 0], # eb45t ffn1 - [8192, 896, 64, 0], # eb45t ffn2 + [8192, 3584, 8, 0, 3584], # eb45T ffn1 + [8192, 3584, 8, 2048, 3584], # eb45T ffn1 + [7168, 8192, 8, 0, 8192], # eb45T ffn2 + [7168, 8192, 8, 2048, 8192], # eb45T ffn2 + [1792, 8192, 64, 0, 8192], # eb45t ffn1 + [8192, 896, 64, 0, 896], # eb45t ffn2 + [1792, 8192, 64, 0, 128], # eb45t ffn1 + [8192, 896, 64, 0, 128], # eb45t ffn2 ] dtype = ["BF16"] @@ -97,6 +101,19 @@ use_fast_compile = True n_range = [256] if use_fast_compile else [i for i in range(16, 257, 16)] +all_cu_files = [] +for type in dtype: + for case in gemm_case: + for n in n_range: + all_cu_files.append(f"w4afp8_gemm_M{case[0]}_N{n}_G{case[4]}_K{case[1]}_E{case[2]}_P{case[3]}_{type}.cu") + +for file_path, empty_list, file_name_list in os.walk(file_dir): + for file_name in file_name_list: + if re.match(r"^w4afp8_gemm_M\d+_N\d+_.*\.cu$", file_name): + if file_name not in all_cu_files: + print("delete w4afp8 kernel file", file_path + file_name) + os.remove(file_path + file_name) + def get_cutlass_type(type): if type == "BF16": @@ -116,28 +133,16 @@ def get_cutlass_type(type): M=case[0], K=case[1], N=n, - BATCH=case[2], + EXPERTS=case[2], TYPE=type, PADDING=case[3], - TAILN=0, - cutlass_type=get_cutlass_type(type), - ) - ) - template_head_file.write( - gemm_template_case.format( - M=case[0], - K=case[1], - N=256, - BATCH=case[2], - TYPE=type, - PADDING=case[3], - TAILN=n - 16, + GROUPSIZE=case[4], cutlass_type=get_cutlass_type(type), ) ) template_cu_file = open( - f"{file_dir}w4afp8_gemm_M{case[0]}_N{n}_TAILN{0}_K{case[1]}_B{case[2]}_P{case[3]}_{type}.cu", "w" + f"{file_dir}w4afp8_gemm_M{case[0]}_N{n}_G{case[4]}_K{case[1]}_E{case[2]}_P{case[3]}_{type}.cu", "w" ) template_cu_file.write(gemm_template_cu_head) template_cu_file.write( @@ -145,29 +150,10 @@ def get_cutlass_type(type): M=case[0], K=case[1], N=n, - BATCH=case[2], + EXPERTS=case[2], TYPE=type, PADDING=case[3], - TAILN=0, - cutlass_type=get_cutlass_type(type), - ) - ) - - template_cu_file.close() - - template_cu_file = open( - f"{file_dir}w4afp8_gemm_M{case[0]}_N{256}_TAILN{n-16}_K{case[1]}_B{case[2]}_P{case[3]}_{type}.cu", "w" - ) - template_cu_file.write(gemm_template_cu_head) - template_cu_file.write( - gemm_template_cu_template.format( - M=case[0], - K=case[1], - N=256, - BATCH=case[2], - TYPE=type, - PADDING=case[3], - TAILN=n - 16, + GROUPSIZE=case[4], cutlass_type=get_cutlass_type(type), ) ) @@ -177,8 +163,8 @@ def get_cutlass_type(type): for type in dtype: template_head_file.write("\n") template_head_file.write( - """#define GEMM_SWITCH_{TYPE}(_M, _K, _BATCH, _TokenPaddingSize, _kBlockN, _TailN, ...) {{ \\ - if (_M == 0 && _K == 0 && _BATCH == 0 && _TokenPaddingSize == 0 && _kBlockN == 0 && _TailN == 0) {{ \\""".format( + """#define GEMM_SWITCH_{TYPE}(_M, _K, _EXPERTS, _TokenPaddingSize, _kBlockN, _GROUPSIZE, ...) {{ \\ + if (_M == 0 && _K == 0 && _EXPERTS == 0 && _TokenPaddingSize == 0 && _kBlockN == 0 && _GROUPSIZE == 0) {{ \\""".format( TYPE=type ) ) @@ -188,23 +174,16 @@ def get_cutlass_type(type): for case in gemm_case: for n in n_range: template_head_file.write( - """ }} else if (_M == {M} && _K == {K} && _BATCH == {BATCH} && _TokenPaddingSize == {PADDING} && _kBlockN == {N} && _TailN == {TAILN}) {{ \\ - w4afp8_gemm_M{M}_N{N}_TAILN{TAILN}_K{K}_B{BATCH}_P{PADDING}_{TYPE}(__VA_ARGS__); \\""".format( - M=case[0], K=case[1], N=n, BATCH=case[2], TYPE=type, PADDING=case[3], TAILN=0 - ) - ) - template_head_file.write("\n") - template_head_file.write( - """ }} else if (_M == {M} && _K == {K} && _BATCH == {BATCH} && _TokenPaddingSize == {PADDING} && _kBlockN == {N} && _TailN == {TAILN}) {{ \\ - w4afp8_gemm_M{M}_N{N}_TAILN{TAILN}_K{K}_B{BATCH}_P{PADDING}_{TYPE}(__VA_ARGS__); \\""".format( - M=case[0], K=case[1], N=256, BATCH=case[2], TYPE=type, PADDING=case[3], TAILN=n - 16 + """ }} else if (_M == {M} && _K == {K} && _EXPERTS == {EXPERTS} && _TokenPaddingSize == {PADDING} && _kBlockN == {N} && _GROUPSIZE == {GROUPSIZE}) {{ \\ + w4afp8_gemm_M{M}_N{N}_G{GROUPSIZE}_K{K}_E{EXPERTS}_P{PADDING}_{TYPE}(__VA_ARGS__); \\""".format( + M=case[0], K=case[1], N=n, EXPERTS=case[2], TYPE=type, PADDING=case[3], GROUPSIZE=case[4] ) ) template_head_file.write("\n") template_head_file.write( """ } else { \\ - PADDLE_THROW(phi::errors::Unimplemented("W4aFp8 not supported m=%d k=%d batch=%d token_padding_size=%d kBlockN=%d tailN=%d\\n", _M, _K, _BATCH, _TokenPaddingSize, _kBlockN, _TailN)); \\ + PADDLE_THROW(phi::errors::Unimplemented("W4aFp8 not supported m=%d k=%d experts=%d token_padding_size=%d kBlockN=%d groupsize=%d\\n", _M, _K, _EXPERTS, _TokenPaddingSize, _kBlockN, _GROUPSIZE)); \\ } \\ }""" ) diff --git a/docs/features/plas_attention.md b/docs/features/plas_attention.md index 6551b4e8f7..b096fceeb6 100644 --- a/docs/features/plas_attention.md +++ b/docs/features/plas_attention.md @@ -15,7 +15,7 @@ In terms of training efficiency, the training cost is very low because only the Following the approaches of NSA and MoBA, we partition the KV into multiple blocks. During both the prefill and decode stages, instead of performing attention computation over all KV, we dynamically select the top-K blocks with the highest attention scores for each query token, thereby enabling efficient sparse attention computation.
-Attention Gate Module +Attention Gate Module
* **Attention Gate Module**: As illustrated in the figure above, to estimate the importance of each block with low computational overhead, we design a lightweight attention gate module. This module first compresses each K block via a MLP layer to generate a representative low-dimensional representation: $K_c^T=W_{kp}K^T$, where $W_{kp}$ denotes the MLP layer weights. Compared to directly applying mean pooling, the learnable MLP can more effectively capture semantic relationships and importance distributions among different tokens, thereby providing a refined representation of each block. After obtaining the compressed representation $K_c$, the importance of each query token with respect to each block is estimated via: $Softmax(Q\cdot K_c^T)$. To enhance the discriminative ability of the MLP layer, we use the full attention result after 1D max pooling $1DMaxPooling(Softmax(Q \cdot K^T))$ as the ground truth. By minimizing the distribution divergence between the two, the MLP layer is guided to learn feature representations that better align with the true attention distribution. @@ -27,7 +27,7 @@ Following the approaches of NSA and MoBA, we partition the KV into multiple bloc During sparse attention computation, each query token may dynamically select different KV blocks, leading to highly irregular memory access patterns in HBM. It is feasible to simply process each query token separately, but it will lead to excessively fine-grained computing, which cannot make full use of the tensor core, thus significantly reducing the GPU computing efficiency.
-Token/Head Union +Token/Head Union
To optimize performance in both the prefill and decode stages, we design a special joint strategy to adapt to their respective characteristics: diff --git a/docs/zh/features/plas_attention.md b/docs/zh/features/plas_attention.md index a49cb25fde..0d8fcb2b97 100644 --- a/docs/zh/features/plas_attention.md +++ b/docs/zh/features/plas_attention.md @@ -15,7 +15,7 @@ 借鉴 NSA 和 MoBA 的方法,我们将键值对 (KV) 划分为多个块。在预填充和解码阶段,我们不再对所有键值进行注意力计算,而是动态地为每个查询 token 选择注意力得分最高的前 K 个块,从而实现高效的稀疏注意力计算。
-Attention Gate Module +Attention Gate Module
* **Attention Gate Module**: 如上图所示,为了以较低的计算开销估计每个块的重要性,我们设计了一个轻量级的注意力门模块。该模块首先通过一个MLP层压缩每个K个块,生成一个具有代表性的低维表示: $K_c^T=W_{kp}K^T$ ,其中 $W_{kp}$ 表示 MLP 层的权重。与直接应用均值池化相比,可学习的 MLP 可以更有效地捕捉不同 token 之间的语义关系和重要性分布,从而提供每个块的精细表示。在获得压缩表示 $K_c$ 之后,通过以下公式估计每个查询 token 相对于每个块的重要性:$Softmax(Q\cdot K_c^T)$。为了增强 MLP 层的判别能力,我们使用一维最大池化后的完整注意力结果 $1DMaxPooling(Softmax(Q \cdot K^T))$ 作为 ground truth。通过最小化两者之间的分布差异,引导 MLP 层学习更符合真实注意力分布的特征表示。 @@ -29,7 +29,7 @@ 在稀疏注意力计算过程中,每个查询 token 可能会动态选择不同的 KV 块,导致 HBM 的内存访问模式非常不规则。简单地对每个查询 token 进行单独处理是可行的,但这会导致计算粒度过细,无法充分利用张量核,从而显著降低 GPU 的计算效率。
-Token/Head Union +Token/Head Union
为了优化预填充和解码阶段的性能,我们设计了一种特殊的联合策略来适应各自的特点: diff --git a/fastdeploy/model_executor/layers/moe/fused_moe_cutlass_backend.py b/fastdeploy/model_executor/layers/moe/fused_moe_cutlass_backend.py index eaa46448c4..c36536f013 100644 --- a/fastdeploy/model_executor/layers/moe/fused_moe_cutlass_backend.py +++ b/fastdeploy/model_executor/layers/moe/fused_moe_cutlass_backend.py @@ -30,7 +30,10 @@ from fastdeploy.model_executor.ops.gpu import moe_expert_dispatch, moe_expert_reduce try: - from fastdeploy.model_executor.ops.gpu import w4afp8_gemm_scale_permute + from fastdeploy.model_executor.ops.gpu import ( + w4afp8_gemm_scale_permute, + w4afp8_gemm_weight_convert, + ) except: logger.warning("import w4afp8_gemm_scale_permute Failed!") elif current_platform.is_iluvatar(): @@ -67,6 +70,7 @@ def compute_ffn( expert_idx_per_token: paddle.Tensor, used_in_ep_low_latency: bool = False, estimate_total_token_nums: int = -1, + dequant_scale: paddle.Tensor = None, ): """ Paddle Cutlass compute Fused MoE. @@ -90,6 +94,7 @@ def compute_ffn( token_nums_per_expert, getattr(layer, self.added_weight_attrs[0]), getattr(layer, self.added_weight_attrs[1]), + dequant_scale, None, (layer.up_gate_proj_weight_scale if hasattr(layer, "up_gate_proj_weight_scale") else None), (layer.down_proj_weight_scale if hasattr(layer, "down_proj_weight_scale") else None), @@ -244,6 +249,7 @@ def apply_tp( topk_weights, topk_idx, expert_idx_per_token, + dequant_scale, ) = moe_expert_dispatch( x, gate_out, @@ -264,19 +270,21 @@ def apply_tp( topk_weights, topk_idx, expert_idx_per_token, + dequant_scale, ) = moe_expert_dispatch( x, gate_out, layer.gate_correction_bias, - ( - layer.up_gate_proj_in_scale if hasattr(layer, "up_gate_proj_in_scale") else None - ), # if set, permute_input will be int8_t + (layer.up_gate_proj_in_scale if hasattr(layer, "up_gate_proj_in_scale") else None), layer.top_k, False, self.moe_quant_type, topk_only_mode=False, ) + if hasattr(layer, "up_gate_proj_in_scale"): + dequant_scale = None + if self.moe_quant_type != "w4a8" and self.moe_quant_type != "w4afp8": # only w4a8 need expert_idx_per_token # Other need not this tensor, so we make it None. @@ -284,7 +292,9 @@ def apply_tp( else: expert_idx_per_token = expert_idx_per_token.cast("int64") - ffn_out = self.compute_ffn(layer, permute_input, token_nums_per_expert, expert_idx_per_token) + ffn_out = self.compute_ffn( + layer, permute_input, token_nums_per_expert, expert_idx_per_token, False, -1, dequant_scale + ) # reduce 中会做 topk 个 weight 的 norm 和 routed_scaling_factor fused_moe_out = moe_expert_reduce( @@ -785,7 +795,7 @@ def process_loaded_weights(self, layer: nn.Layer, state_dict): weight_name = self.added_weight_attrs[idx] weight_list = [] for i in range(layer.num_local_experts): - quant_weight, scale = weight_quantize(weight_tensor[i], algo=self.moe_quant_type, arch=80) + quant_weight = w4afp8_gemm_weight_convert(weight_tensor[i]) weight_list.append(quant_weight) quanted_weight = paddle.stack(weight_list, axis=0) getattr(layer, weight_name).set_value(quanted_weight) @@ -815,16 +825,16 @@ def create_w4afp8_scale_weights(self, layer: nn.Layer, weight_key_map: dict): ) # in_scales - for in_scale_name in ["up_gate_proj_in_scale", "down_proj_in_scale"]: - setattr( - layer, - in_scale_name, - layer.create_parameter( - shape=[layer.num_local_experts], - dtype="float32", - default_initializer=paddle.nn.initializer.Constant(0), - ), - ) + # for in_scale_name in ["up_gate_proj_in_scale", "down_proj_in_scale"]: + # setattr( + # layer, + # in_scale_name, + # layer.create_parameter( + # shape=[layer.num_local_experts], + # dtype="float32", + # default_initializer=paddle.nn.initializer.Constant(0), + # ), + # ) # weight_scales setattr( @@ -882,10 +892,52 @@ def _permute_weight_scale(weight_scale: paddle.Tensor): return weight_scale def _process_weight_scale(name: str, weight_scales: list[paddle.Tensor], processed_in_scale: paddle.Tensor): - processed_weight_scale = ( - paddle.stack(weight_scales, axis=0) / (448 * 7 * 2 ** (-9)) / processed_in_scale[:, None] - ) - processed_weight_scale = _permute_weight_scale(processed_weight_scale) + if processed_in_scale is not None: + processed_weight_scale = ( + paddle.stack(weight_scales, axis=0) / (448 * 7 * 2 ** (-9)) / processed_in_scale[:, None] + ) + else: + processed_weight_scale = paddle.stack(weight_scales, axis=0) / (448 * 7 * 2 ** (-9)) + + if len(processed_weight_scale.shape) == 3: + if name == "up_gate_proj_weight_scale" and processed_weight_scale.shape[-1] * 128 != layer.hidden_size: + assert ( + layer.hidden_size // 128 % processed_weight_scale.shape[-1] == 0 + ), "weight_scale_group_size must be a multiple of 128" + # If it is a multiple of 128, repeat to 128 + processed_weight_scale = processed_weight_scale.repeat_interleave( + layer.hidden_size // 128 // processed_weight_scale.shape[-1], axis=-1 + ) + elif name == "down_proj_weight_scale": + assert ( + layer.moe_intermediate_size // 128 % processed_weight_scale.shape[-1] == 0 + ), "weight_scale_group_size must be a multiple of 128" + # If it is a multiple of 128, repeat to 128 + processed_weight_scale = processed_weight_scale.repeat_interleave( + layer.moe_intermediate_size // 128 // processed_weight_scale.shape[-1], axis=-1 + ) + else: + raise ValueError(f"Invalid weight scale name: {name}") + + origin_shape = processed_weight_scale.shape + processed_weight_scale = processed_weight_scale.transpose([0, 2, 1]) + processed_weight_scale = processed_weight_scale.reshape([-1, processed_weight_scale.shape[-1]]) + processed_weight_scale = _permute_weight_scale(processed_weight_scale) + processed_weight_scale = processed_weight_scale.reshape( + [origin_shape[0], origin_shape[2], origin_shape[1] // 128, 128] + ) + processed_weight_scale = processed_weight_scale.transpose([0, 2, 1, 3]) + setattr( + layer, + name, + layer.create_parameter( + shape=processed_weight_scale.shape, + dtype="float32", + default_initializer=paddle.nn.initializer.Constant(0), + ), + ) + else: + processed_weight_scale = _permute_weight_scale(processed_weight_scale) getattr(layer, name).set_value(processed_weight_scale) # 1. Init scale containers and maps @@ -932,16 +984,15 @@ def _process_weight_scale(name: str, weight_scales: list[paddle.Tensor], process scale_tensor = _extract_scale_tensor(layer, state_dict, scale_key_template, expert_idx) scale_weight_map[name].append(scale_tensor) - # 3. Process scale tensor and set to layer - in_scales = [] - for in_scale_name in ["up_gate_proj_in_scale", "down_proj_in_scale"]: - in_scales.append(_process_in_scale(in_scale_name, scale_weight_map[in_scale_name])) - for i, weight_scale_name in enumerate(["up_gate_proj_weight_scale", "down_proj_weight_scale"]): + in_scale_name = weight_scale_name.replace("_weight_scale", "_in_scale") + in_scale = None + if hasattr(layer, in_scale_name) and in_scale_name in scale_weight_map.keys(): + in_scale = _process_in_scale(in_scale_name, scale_weight_map[in_scale_name]) _process_weight_scale( weight_scale_name, scale_weight_map[weight_scale_name], - in_scales[i], + in_scale, ) diff --git a/fastdeploy/model_executor/layers/moe/fused_moe_wint2_backend.py b/fastdeploy/model_executor/layers/moe/fused_moe_wint2_backend.py index f9f717d313..f390749e47 100644 --- a/fastdeploy/model_executor/layers/moe/fused_moe_wint2_backend.py +++ b/fastdeploy/model_executor/layers/moe/fused_moe_wint2_backend.py @@ -275,6 +275,7 @@ def apply( topk_weights, topk_idx, expert_idx_per_token, + dequant_scale, ) = moe_expert_dispatch( x, gate_out, diff --git a/tests/operators/test_w4afp8_gemm.py b/tests/operators/test_w4afp8_gemm.py index 29459ddf39..9c97840575 100644 --- a/tests/operators/test_w4afp8_gemm.py +++ b/tests/operators/test_w4afp8_gemm.py @@ -23,10 +23,10 @@ class TestW4AFP8GEMM(unittest.TestCase): def setUp(self): paddle.seed(0) - self.tokens_per_group = 256 - self.N = 256 - self.K = 256 - self.BATCH = 1 + self.tokens_per_group = 1 + self.N = 1792 + self.K = 8192 + self.BATCH = 64 self.TokenPadding = 0 tokens = [self.tokens_per_group] * self.BATCH @@ -38,14 +38,15 @@ def setUp(self): self.input_fp8 = paddle.randn([self.all_tokens, self.K], dtype="bfloat16").astype(paddle.float8_e4m3fn) self.input_bf16 = self.input_fp8.astype("bfloat16") - self.weight = paddle.randn([self.BATCH, self.N, self.K], dtype="bfloat16") / 10 + self.weight = paddle.randn([self.BATCH, self.N, self.K], dtype="bfloat16") self.weight_scale = 7 / self.weight.abs().max(axis=-1).reshape([self.BATCH, self.N, 1]) - self.weight_quant = (self.weight * self.weight_scale).astype("int") + 7 - self.weight_quant = paddle.clip(self.weight_quant, 0, 14) + self.weight_quant = (self.weight * self.weight_scale).astype("int") + self.weight_quant = paddle.clip(self.weight_quant, -7, 7) + self.weight_quant_naive = self.weight_quant.astype("float32") self.weight_quant = self.weight_quant.astype("bfloat16") + self.weight_quant = paddle.where(self.weight_quant > 0, self.weight_quant, 8 - self.weight_quant) self.weight_dequant_scale = 1 / self.weight_scale.astype("float32") - self.input_row_sum = self.input_bf16.sum(axis=1) * -7 / 512 self.max_tokens = int(self.tokens.max()) def w4afp8_gemm_naive(self, input_bf16, weight_quant, tokens, weight_dequant_scale): @@ -54,7 +55,7 @@ def w4afp8_gemm_naive(self, input_bf16, weight_quant, tokens, weight_dequant_sca pre_fix_token = 0 for i in range(self.BATCH): input = input_bf16[pre_fix_token : pre_fix_token + tokens[i], :] - weight = (weight_quant[i] - 7.0) * weight_dequant_scale[i] + weight = weight_quant[i] * weight_dequant_scale[i] out_i = paddle.matmul(input, weight.astype("bfloat16"), transpose_y=True) out[pre_fix_token : pre_fix_token + tokens[i], :] = out_i pre_fix_token += tokens[i] @@ -72,7 +73,9 @@ def permute_scale(self, weight_scale): return weight_scale def test_w4afp8_gemm(self): - out_naive = self.w4afp8_gemm_naive(self.input_bf16, self.weight_quant, self.tokens, self.weight_dequant_scale) + out_naive = self.w4afp8_gemm_naive( + self.input_bf16, self.weight_quant_naive, self.tokens, self.weight_dequant_scale + ) weight_dequant_scale = paddle.to_tensor(self.permute_scale(self.weight_dequant_scale) * 512) weight_int4 = w4afp8_gemm_weight_convert(self.weight_quant.astype("uint8").cpu()) @@ -82,10 +85,9 @@ def test_w4afp8_gemm(self): self.input_fp8, weight_int4.cuda(), self.tokens_prefix_sum, - self.input_row_sum.astype("float32"), weight_dequant_scale.astype("float32"), int(self.TokenPadding), - self.max_tokens, + self.all_tokens, True, ) else: @@ -93,7 +95,6 @@ def test_w4afp8_gemm(self): self.input_fp8, weight_int4.cuda(), self.tokens, - self.input_row_sum.astype("float32"), weight_dequant_scale.astype("float32"), int(self.TokenPadding), self.max_tokens, @@ -101,7 +102,7 @@ def test_w4afp8_gemm(self): ) gap = (out_cuda - out_naive).abs() - self.assertLess(float(gap.mean()), 0.07) + self.assertLess(float(gap.mean()), 0.11) if __name__ == "__main__":