Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
58 changes: 58 additions & 0 deletions csrc/quantization/activation_kernels.cu
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
#include "cuda_utils.h"

#include <ATen/cuda/CUDAContext.h>
#include <torch/all.h>
#include <c10/cuda/CUDAGuard.h>
Expand Down Expand Up @@ -119,3 +121,59 @@ void silu_and_mul_quant(torch::Tensor& out, // [..., d]
TORCH_CHECK(input.size(-1) % 2 == 0);
LAUNCH_ACTIVATION_GATE_KERNEL(vllm::silu_kernel);
}

void silu_mul_fp8_quant_deep_gemm_cuda(
const at::Tensor& input, // (E, T, 2*H)
const at::Tensor& counts, // (E)
at::Tensor& y_q, // (E, T, H) [OUT]
at::Tensor& y_s, // (E, T, H//group_size) [OUT]
int64_t group_size, double eps, double fp8_min, double fp8_max,
bool use_ue8m0) {
static constexpr int NUM_WARPS = 4;

using Idx_t = uint32_t;

Idx_t E = input.size(0);
Idx_t T = input.size(1);
Idx_t H = input.size(2) / 2;
Idx_t G = cuda_utils::ceil_div(H, Idx_t(group_size * NUM_WARPS));
Idx_t stride_i_e = input.stride(0);
Idx_t stride_i_t = input.stride(1);
Idx_t stride_i_h = input.stride(2);
Idx_t stride_yq_e = y_q.stride(0);
Idx_t stride_yq_t = y_q.stride(1);
Idx_t stride_yq_h = y_q.stride(2);
Idx_t stride_ys_e = y_s.stride(0);
Idx_t stride_ys_t = y_s.stride(1);
Idx_t stride_ys_g = y_s.stride(2);

int stride_counts_e = counts.stride(0);

static constexpr int NUM_PARALLEL_TOKENS = 16;
dim3 grid(E * G, NUM_PARALLEL_TOKENS);
dim3 block(NUM_WARPS * 32);

if (use_ue8m0) {
vllm::silu_mul_fp8_quant_deep_gemm_kernel<__nv_bfloat16, NUM_WARPS, Idx_t,
NUM_PARALLEL_TOKENS, true>
<<<grid, block>>>(
reinterpret_cast<__nv_bfloat16*>(input.data_ptr()),
reinterpret_cast<__nv_fp8_e4m3*>(y_q.data_ptr<at::Float8_e4m3fn>()),
y_s.data_ptr<float>(),
reinterpret_cast<uint32_t*>(counts.data_ptr<int>()), H, G,
stride_i_e, stride_i_t, stride_i_h, stride_yq_e, stride_yq_t,
stride_yq_h, stride_ys_e, stride_ys_t, stride_ys_g, stride_counts_e,
static_cast<float>(fp8_min), static_cast<float>(fp8_max));
} else {
vllm::silu_mul_fp8_quant_deep_gemm_kernel<__nv_bfloat16, NUM_WARPS, Idx_t,
NUM_PARALLEL_TOKENS, false>
<<<grid, block>>>(
reinterpret_cast<__nv_bfloat16*>(input.data_ptr()),
reinterpret_cast<__nv_fp8_e4m3*>(y_q.data_ptr<at::Float8_e4m3fn>()),
y_s.data_ptr<float>(),
reinterpret_cast<uint32_t*>(counts.data_ptr<int>()), H, G,
stride_i_e, stride_i_t, stride_i_h, stride_yq_e, stride_yq_t,
stride_yq_h, stride_ys_e, stride_ys_t, stride_ys_g, stride_counts_e,
static_cast<float>(fp8_min), static_cast<float>(fp8_max));
}
}
12 changes: 7 additions & 5 deletions tests/kernels/moe/test_silu_mul_fp8_quant_deep_gemm.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,13 @@

# (E, T, H, group_size, seed)
CASES = [
(1, 1, 128, 64, 0),
(1, 4, 128, 128, 0),
(2, 4, 256, 128, 0),
(32, 64, 256, 128, 0),
(17, 31, 768, 128, 0),
(8, 16, 7168, 128, 0),
(8, 32, 7168, 128, 0),
(8, 64, 7168, 128, 0),
(8, 128, 7168, 128, 0),
(8, 256, 7168, 128, 0),
(8, 512, 7168, 128, 0),
(8, 1024, 7168, 128, 0),
]


Expand Down
47 changes: 47 additions & 0 deletions vllm/model_executor/layers/fused_moe/batched_deep_gemm_moe.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,53 @@ def silu_mul_fp8_quant_deep_gemm(
tokens_per_expert: torch.Tensor, # (E,) number of valid tokens per expert
group_size: int = 128,
eps: float = 1e-10,
) -> tuple[torch.Tensor, torch.Tensor]:
assert y.ndim == 3, "y must be (E, T, 2*H)"
E, T, H2 = y.shape
assert H2 % 2 == 0, "last dim of y must be even (2*H)"
H = H2 // 2
G = H // group_size
assert H % group_size == 0, "H must be divisible by group_size"
assert tokens_per_expert.ndim == 1 and tokens_per_expert.shape[0] == E

tokens_per_expert = tokens_per_expert.to(device=y.device,
dtype=torch.int32)

fp8_dtype = torch.float8_e4m3fn
y_q = torch.empty((E, T, H), dtype=fp8_dtype, device=y.device).contiguous()

stride_ys_e = T * G
stride_ys_t = 1
stride_ys_g = T
y_s = torch.empty_strided((E, T, G),
(stride_ys_e, stride_ys_t, stride_ys_g),
dtype=torch.float32,
device=y.device).contiguous()

f_info = torch.finfo(fp8_dtype)
fp8_max = f_info.max
fp8_min = f_info.min
use_ue8m0 = is_deep_gemm_e8m0_used()
torch.ops._C.silu_mul_fp8_quant_deep_gemm_cuda(
y,
tokens_per_expert,
y_q,
y_s,
group_size,
eps,
fp8_min,
fp8_max,
use_ue8m0,
)

return y_q, y_s


def silu_mul_fp8_quant_deep_gemm_old(
y: torch.Tensor, # (E, T, 2*H)
tokens_per_expert: torch.Tensor, # (E,) number of valid tokens per expert
group_size: int = 128,
eps: float = 1e-10,
) -> tuple[torch.Tensor, torch.Tensor]:
"""Quantize silu(y[..., :H]) * y[..., H:] to FP8 with group per-token scales

Expand Down
Loading