diff --git a/awq/modules/fused/moe.py b/awq/modules/fused/moe.py index bab3bf55..70431609 100644 --- a/awq/modules/fused/moe.py +++ b/awq/modules/fused/moe.py @@ -1,7 +1,5 @@ import torch -import triton from typing import Dict -import triton.language as tl try: import awq_ext # with CUDA kernels @@ -52,17 +50,6 @@ def apply_moe_weights( topk: int, renormalize: bool, ) -> torch.Tensor: - # NOTE: DISABLED FOR NOW - FP16_MATMUL_HEURISTIC_CONDITION = x.shape[:-1].numel() >= 1e9 #1024 - if FP16_MATMUL_HEURISTIC_CONDITION: - dequant_w1 = awq_ext.dequantize_weights_cuda( - w1.qweight, w1.scales, w1.qzeros, 0, 0, 0, False - ).permute(0, 2, 1) - dequant_w2 = awq_ext.dequantize_weights_cuda( - w2.qweight, w2.scales, w2.qzeros, 0, 0, 0, False - ).permute(0, 2, 1) - return fused_moe(x, dequant_w1, dequant_w2, gating_output, topk, renormalize) - topk_weights, topk_ids = fused_topk(gating_output, topk, renormalize) (sorted_token_ids, expert_ids, num_tokens_post_padded) = moe_align_block_size( topk_ids, 16, w1.qweight.shape[0] @@ -104,125 +91,6 @@ def apply_moe_weights( return torch.sum(out, dim=1) -@triton.jit -def fused_moe_kernel( - # Pointers to matrices - a_ptr, - b_ptr, - c_ptr, - topk_weights_ptr, - sorted_token_ids_ptr, - expert_ids_ptr, - num_tokens_post_padded_ptr, - # Matrix dimensions - N, - K, - EM, - num_valid_tokens, - # The stride variables represent how much to increase the ptr by when moving by 1 - # element in a particular dimension. E.g. `stride_am` is how much to increase `a_ptr` - # by to get the element one row down (A has M rows). - stride_am, - stride_ak, - stride_be, - stride_bk, - stride_bn, - stride_cm, - stride_cn, - # Meta-parameters - BLOCK_SIZE_M: tl.constexpr, - BLOCK_SIZE_N: tl.constexpr, - BLOCK_SIZE_K: tl.constexpr, - GROUP_SIZE_M: tl.constexpr, - MUL_ROUTED_WEIGHT: tl.constexpr, - top_k: tl.constexpr, - compute_type: tl.constexpr, -): - """ - Implements the fused computation for a Mixture of Experts (MOE) using token and expert matrices. - - Key Parameters: - - A: The input tensor representing tokens with shape (*, K), where '*' can be any shape representing batches and K is the feature dimension of each token. - - B: The stacked MOE weight tensor with shape (E, N, K), where E is the number of experts, K is the input feature dimension, and N is the output feature dimension. - - C: The output cache tensor with shape (M, topk, N), where M is the total number of tokens post padding, topk is the number of times each token is repeated, - and N is the output feature dimension. - - sorted_token_ids: A tensor containing the sorted indices of tokens, repeated topk times and arranged by the expert index they are assigned to. - - expert_ids: A tensor containing the indices of the expert for each block. It determines which expert matrix from B should be used for each block in A. - This kernel performs the multiplication of a token by its corresponding expert matrix as determined by `expert_ids`. The sorting of `sorted_token_ids` - by expert index and padding ensures divisibility by BLOCK_SIZE_M, which is necessary to maintain consistency in block matrix multiplication across different blocks processed by the same expert. - """ - # ----------------------------------------------------------- - # Map program ids `pid` to the block of C it should compute. - # This is done in a grouped ordering to promote L2 data reuse. - pid = tl.program_id(axis=0) - num_pid_m = tl.cdiv(EM, BLOCK_SIZE_M) - num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) - num_pid_in_group = GROUP_SIZE_M * num_pid_n - group_id = pid // num_pid_in_group - first_pid_m = group_id * GROUP_SIZE_M - group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) - pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m) - pid_n = (pid % num_pid_in_group) // group_size_m - - # ---------------------------------------------------------- - # Create pointers for the first blocks of A and B. - # We will advance this pointer as we move in the K direction - # and accumulate - # `a_ptrs` is a block of [BLOCK_SIZE_M, BLOCK_SIZE_K] pointers - # `b_ptrs` is a block of [BLOCK_SIZE_K, BLOCK_SIZE_N] pointers - num_tokens_post_padded = tl.load(num_tokens_post_padded_ptr) - if pid_m * BLOCK_SIZE_M >= num_tokens_post_padded: - return - offs_token_id = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) - offs_token = tl.load(sorted_token_ids_ptr + offs_token_id) - token_mask = offs_token < num_valid_tokens - - offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N - offs_k = tl.arange(0, BLOCK_SIZE_K) - a_ptrs = a_ptr + ( - offs_token[:, None] // top_k * stride_am + offs_k[None, :] * stride_ak - ) - - off_experts = tl.load(expert_ids_ptr + pid_m) - b_ptrs = ( - b_ptr - + off_experts * stride_be - + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn) - ) - - # ----------------------------------------------------------- - # Iterate to compute a block of the C matrix. - # We accumulate into a `[BLOCK_SIZE_M, BLOCK_SIZE_N]` block - # of fp32 values for higher accuracy. - # `accumulator` will be converted back to fp16 after the loop. - accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) - - for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)): - # Load the next block of A and B, generate a mask by checking the K dimension. - a = tl.load( - a_ptrs, - mask=token_mask[:, None] & (offs_k[None, :] < K - k * BLOCK_SIZE_K), - other=0.0, - ) - b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0) - # We accumulate along the K dimension. - accumulator += tl.dot(a, b) - # Advance the ptrs to the next K block. - a_ptrs += BLOCK_SIZE_K * stride_ak - b_ptrs += BLOCK_SIZE_K * stride_bk - - if MUL_ROUTED_WEIGHT: - moe_weight = tl.load(topk_weights_ptr + offs_token, mask=token_mask, other=0) - accumulator = accumulator * moe_weight[:, None] - - accumulator = accumulator.to(compute_type) - # ----------------------------------------------------------- - # Write back the block of the output - offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) - c_ptrs = c_ptr + stride_cm * offs_token[:, None] + stride_cn * offs_cn[None, :] - c_mask = token_mask[:, None] & (offs_cn[None, :] < N) - tl.store(c_ptrs, accumulator, mask=c_mask) - def moe_align_block_size(topk_ids: torch.Tensor, block_size: int, num_experts: int): """ @@ -267,53 +135,6 @@ def moe_align_block_size(topk_ids: torch.Tensor, block_size: int, num_experts: i return sorted_ids, expert_ids, num_tokens_post_pad -def invoke_fused_moe_kernel( - A: torch.Tensor, - B: torch.Tensor, - C: torch.Tensor, - topk_weights: torch.Tensor, - topk_ids: torch.Tensor, - sorted_token_ids: torch.Tensor, - expert_ids: torch.Tensor, - num_tokens_post_padded: torch.Tensor, - mul_routed_weight: bool, - top_k: int, - config: dict, -): - assert topk_weights.stride(1) == 1 - assert sorted_token_ids.stride(0) == 1 - - grid = lambda META: ( - triton.cdiv(sorted_token_ids.shape[0], META["BLOCK_SIZE_M"]) - * triton.cdiv(B.shape[1], META["BLOCK_SIZE_N"]), - ) - - fused_moe_kernel[grid]( - A, - B, - C, - topk_weights, - sorted_token_ids, - expert_ids, - num_tokens_post_padded, - B.shape[1], - B.shape[2], - sorted_token_ids.shape[0], - topk_ids.numel(), - A.stride(0), - A.stride(1), - B.stride(0), - B.stride(2), - B.stride(1), - C.stride(1), - C.stride(2), - MUL_ROUTED_WEIGHT=mul_routed_weight, - top_k=top_k, - compute_type=tl.bfloat16 if A.dtype == torch.bfloat16 else tl.float16, - **config, - ) - - def fused_topk( gating_output: torch.Tensor, topk: int, @@ -349,114 +170,3 @@ def fused_topk( if renormalize: topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True) return topk_weights, topk_ids - - -def fused_moe( - hidden_states: torch.Tensor, - w1: torch.Tensor, - w2: torch.Tensor, - gating_output: torch.Tensor, - topk: int, - renormalize: bool, - inplace: bool = True, -) -> torch.Tensor: - """ - This function computes a Mixture of Experts (MoE) layer using two sets of weights, w1 and w2, and top-k gating mechanism. - - Parameters: - - hidden_states (torch.Tensor): The input tensor to the MoE layer. - - w1 (torch.Tensor): The first set of expert weights. - - w2 (torch.Tensor): The second set of expert weights. - - gating_output (torch.Tensor): The output of the gating operation (before softmax). - - topk (int): The number of top-k experts to select. - - renormalize (bool): If True, renormalize the top-k weights to sum to 1. - - inplace (bool): If True, perform the operation in-place. Defaults to False. - - Returns: - - torch.Tensor: The output tensor after applying the MoE layer. - """ - # Check constraints. - assert hidden_states.shape[0] == gating_output.shape[0], "Number of tokens mismatch" - assert hidden_states.shape[1] == w1.shape[2], "Hidden size mismatch" - assert gating_output.shape[1] == w1.shape[0], "Number of experts mismatch" - assert hidden_states.is_contiguous(), "Hidden_states must be contiguous" - # assert w1.is_contiguous(), "Expert weights1 must be contiguous" - # assert w2.is_contiguous(), "Expert weights2 must be contiguous" - assert hidden_states.dtype in [torch.float32, torch.float16, torch.bfloat16] - M, _ = hidden_states.shape - E, N, _ = w1.shape - - topk_weights, topk_ids = fused_topk(gating_output, topk, renormalize) - - config = { - "BLOCK_SIZE_M": 64, - "BLOCK_SIZE_N": 64, - "BLOCK_SIZE_K": 32, - "GROUP_SIZE_M": 8, - } - - if topk_ids.numel() <= w1.shape[0]: - config = { - "BLOCK_SIZE_M": 16, - "BLOCK_SIZE_N": 32, - "BLOCK_SIZE_K": 64, - "GROUP_SIZE_M": 1, - } - - intermediate_cache1 = torch.empty( - (M, topk_ids.shape[1], N), - device=hidden_states.device, - dtype=hidden_states.dtype, - ) - intermediate_cache2 = torch.empty( - (M * topk_ids.shape[1], N // 2), - device=hidden_states.device, - dtype=hidden_states.dtype, - ) - intermediate_cache3 = torch.empty( - (M, topk_ids.shape[1], w2.shape[1]), - device=hidden_states.device, - dtype=hidden_states.dtype, - ) - - sorted_token_ids, expert_ids, num_tokens_post_padded = moe_align_block_size( - topk_ids, config["BLOCK_SIZE_M"], E - ) - - invoke_fused_moe_kernel( - hidden_states, - w1, - intermediate_cache1, - topk_weights, - topk_ids, - sorted_token_ids, - expert_ids, - num_tokens_post_padded, - False, - topk_ids.shape[1], - config, - ) - - awq_ext.silu_and_mul(intermediate_cache2, intermediate_cache1.view(-1, N)) - - invoke_fused_moe_kernel( - intermediate_cache2, - w2, - intermediate_cache3, - topk_weights, - topk_ids, - sorted_token_ids, - expert_ids, - num_tokens_post_padded, - True, - 1, - config, - ) - - if inplace: - return torch.sum( - intermediate_cache3.view(*intermediate_cache3.shape), - dim=1, - out=hidden_states, - ) - return torch.sum(intermediate_cache3.view(*intermediate_cache3.shape), dim=1)