diff --git a/fastdeploy/config.py b/fastdeploy/config.py
index 58e7c4f3144..445de4c2d1c 100644
--- a/fastdeploy/config.py
+++ b/fastdeploy/config.py
@@ -350,6 +350,9 @@ def read_model_config(self):
         elif "dtype" in self.model_config:
             self.model_format = "paddle"
             logger.info("The model format is Paddle")
+        elif "model_type" in self.model_config and self.model_config["model_type"] == "gpt_oss":
+            self.model_format = "torch"
+            logger.info("The model format is Hugging Face")
         else:
             raise ValueError(
                 "Unknown model format. Please ensure your config.json contains "
diff --git a/fastdeploy/envs.py b/fastdeploy/envs.py
index 93f135d09da..e5978f50faa 100644
--- a/fastdeploy/envs.py
+++ b/fastdeploy/envs.py
@@ -54,6 +54,8 @@
     "FD_SAMPLING_CLASS": lambda: os.getenv("FD_SAMPLING_CLASS", "base"),
     # Set moe backend."cutlass","marlin" and "triton" can be set currently.
     "FD_MOE_BACKEND": lambda: os.getenv("FD_MOE_BACKEND", "cutlass"),
+    # Whether to use FLASHINFER as MXFP4 backend for MoE.
+    "FD_USE_FLASHINFER_MOE_MXFP4_BF16": lambda: os.getenv("FD_USE_FLASHINFER_MOE_MXFP4_BF16", "0"),
     # Whether to use Machete for wint4 dense gemm.
     "FD_USE_MACHETE": lambda: os.getenv("FD_USE_MACHETE", "1"),
     # Set whether to disable recompute the request when the KV cache is full.
diff --git a/fastdeploy/model_executor/layers/linear.py b/fastdeploy/model_executor/layers/linear.py
index e126aed2ba1..953a8e806fd 100644
--- a/fastdeploy/model_executor/layers/linear.py
+++ b/fastdeploy/model_executor/layers/linear.py
@@ -32,7 +32,7 @@
 )
 from fastdeploy.platforms import current_platform

-from .utils import _set_var_distributed, divide, get_tensor
+from .utils import _set_var_distributed, divide, get_tensor, modules_to_convert


 class UnquantizedLinearMethod(QuantMethodBase):
@@ -168,7 +168,12 @@ def __init__(
             self.output_size,
         ]

-        if fd_config.quant_config and not skip_quant and fd_config.quant_config.get_quant_method(self):
+        if (
+            fd_config.quant_config
+            and not skip_quant
+            and modules_to_convert(prefix, self.fd_config)
+            and fd_config.quant_config.get_quant_method(self)
+        ):
             self.quant_method = fd_config.quant_config.get_quant_method(self)
         else:
             self.quant_method: Optional[QuantMethodBase] = UnquantizedLinearMethod()
diff --git a/fastdeploy/model_executor/layers/moe/moe.py b/fastdeploy/model_executor/layers/moe/moe.py
index 743e05031f6..c36358afcc4 100644
--- a/fastdeploy/model_executor/layers/moe/moe.py
+++ b/fastdeploy/model_executor/layers/moe/moe.py
@@ -249,14 +249,20 @@ def __init__(
         )

     def weight_loader(
-        self, param, loaded_weight, expert_id, shard_id: Optional[str] = None, source: Optional[str] = None
+        self,
+        param,
+        loaded_weight,
+        expert_id,
+        shard_id: Optional[str] = None,
+        source: Optional[str] = None,
+        loaded_weight_name: Optional[str] = None,
     ):
         """
         source:Avoid redundant transpose of fused weights when weight_loader is called iteratively
         """
         if expert_id is None and shard_id is None:
             # MoE experts has been fused in disk
-            self._load_fused_experts_weight(param, loaded_weight)
+            self._load_fused_experts_weight(param, loaded_weight, loaded_weight_name)
             return
         if hasattr(param, "SHARD_ID_TO_SHARDED_DIM"):
             SHARD_ID_TO_SHARDED_DIM = param.SHARD_ID_TO_SHARDED_DIM
@@ -368,7 +374,7 @@ def _load_down_weight(self, param, expert_id, loaded_weight, shard_id, shard_dim
             loaded_weight = loaded_weight.cast(expert_param.dtype)
         h2d_copy(dst=expert_param, src=loaded_weight)

-    def _load_fused_experts_weight(self, param, loaded_weight):
+    def _load_fused_experts_weight(self, param, loaded_weight, loaded_weight_name: Optional[str] = None):
         if self.tp_size > 1:
             dim = -1
             if isinstance(loaded_weight, (np.ndarray, paddle.Tensor)):
@@ -379,10 +385,75 @@ def _load_fused_experts_weight(self, param, loaded_weight):
             shard_offset = self.tp_rank * block_size
             shard_size = (self.tp_rank + 1) * block_size
             loaded_weight = slice_fn(loaded_weight, dim, shard_offset, shard_size)
-        assert param.shape == loaded_weight.shape, (
-            f"Attempted to load weight ({loaded_weight.shape}) " f"into parameter ({param.shape})"
-        )
-        h2d_copy(dst=param, src=loaded_weight)
+
+        if self.moe_quant_config.name() == "mxfp4":
+            # gpt-oss mxfp4 checkpoints ship fused expert tensors as *_blocks / *_scales / *_bias.
+            assert loaded_weight_name is not None
+            weight = get_tensor(loaded_weight)
+            if "block" in loaded_weight_name:
+                if "up" in loaded_weight_name:
+                    weight = weight.reshape([self.num_experts, 2 * self.moe_intermediate_size, -1])
+                elif "down" in loaded_weight_name:
+                    weight = weight.reshape([self.num_experts, self.hidden_size, -1])
+                weight = paddle.nn.functional.pad(
+                    weight.cast("int32"),
+                    pad=[0, param.shape[-1] - weight.shape[-1], 0, param.shape[-2] - weight.shape[-2]],
+                    mode="constant",
+                    value=0,
+                ).cast("uint8")
+
+                if "up" in loaded_weight_name:
+                    gate_w, up_w = weight[:, ::2, :], weight[:, 1::2, :]
+                    param.copy_(paddle.concat([up_w, gate_w], axis=1), False)
+                else:
+                    param.copy_(weight, False)
+
+            elif "scale" in loaded_weight_name:
+                if "up" in loaded_weight_name:
+                    weight = weight.reshape([self.num_experts, 2 * self.moe_intermediate_size, -1])
+                elif "down" in loaded_weight_name:
+                    weight = weight.reshape([self.num_experts, self.hidden_size, -1])
+                weight = paddle.nn.functional.pad(
+                    weight.cast("int32"),
+                    pad=[0, param.shape[-1] - weight.shape[-1], 0, param.shape[-2] - weight.shape[-2]],
+                    mode="constant",
+                    value=0,
+                ).cast("uint8")
+
+                def _interleave_mxfp4_cutlass_sm90(w):
+                    w_shape = w.shape
+                    w_interleaved = w.reshape([w_shape[0], w_shape[1], (w_shape[2] // 4), 4])
+                    w_interleaved = w_interleaved.permute([0, 2, 1, 3])
+                    w_interleaved = w_interleaved.reshape([w_shape[0], w_shape[2] // 4, w_shape[1] * 4])
+                    return w_interleaved
+
+                if "up" in loaded_weight_name:
+                    gate_s, up_s = weight[:, ::2, :], weight[:, 1::2, :]
+                    up_gate_proj_scale = paddle.concat([up_s, gate_s], axis=1)
+                    up_gate_proj_scale_interleaved = _interleave_mxfp4_cutlass_sm90(up_gate_proj_scale)
+                    param.copy_(up_gate_proj_scale_interleaved, False)
+                else:
+                    down_proj_scale = weight
+                    down_proj_scale_interleaved = _interleave_mxfp4_cutlass_sm90(down_proj_scale)
+                    param.copy_(down_proj_scale_interleaved, False)
+
+            elif "bias" in loaded_weight_name:
+
+                weight = paddle.nn.functional.pad(
+                    weight, pad=[0, param.shape[-1] - weight.shape[-1]], mode="constant", value=0
+                )
+
+                if "up" in loaded_weight_name:
+                    gate_b, up_b = weight[:, ::2].cast("bfloat16"), weight[:, 1::2].cast("bfloat16")
+                    param.copy_(paddle.concat([up_b, gate_b], axis=-1), False)
+                else:
+                    param.copy_(weight.cast("bfloat16"), False)
+
+        else:
+            assert param.shape == loaded_weight.shape, (
+                f"Attempted to load weight ({loaded_weight.shape}) " f"into parameter ({param.shape})"
+            )
+
+            h2d_copy(dst=param, src=loaded_weight)

         if hasattr(param, "tensor_track"):
             for i in range(self.num_local_experts):
diff --git a/fastdeploy/model_executor/layers/normalization.py b/fastdeploy/model_executor/layers/normalization.py
index ec1f0e65891..8cc4d6d24a9 100644
--- a/fastdeploy/model_executor/layers/normalization.py
+++ b/fastdeploy/model_executor/layers/normalization.py
@@ -30,7 +30,7 @@
 from fastdeploy.config import FDConfig
 from fastdeploy.model_executor.forward_meta import ForwardMeta

-from .utils import get_tensor
+from .utils import get_tensor, modules_to_convert


 class RMSNorm(nn.Layer):
@@ -94,9 +94,21 @@ def __init__(
             "float16",
         ], f"Unsupported dtype: {dtype}. Must be one of: float32, bfloat16, float16"

-        self.quant_round_type: int = self.fd_config.quant_config.quant_round_type if fd_config.quant_config else 0
-        self.quant_max_bound: int = self.fd_config.quant_config.quant_max_bound if fd_config.quant_config else 0
-        self.quant_min_bound: int = self.fd_config.quant_config.quant_min_bound if fd_config.quant_config else 0
+        self.quant_round_type: int = (
+            self.fd_config.quant_config.quant_round_type
+            if fd_config.quant_config and modules_to_convert(prefix, self.fd_config)
+            else 0
+        )
+        self.quant_max_bound: int = (
+            self.fd_config.quant_config.quant_max_bound
+            if fd_config.quant_config and modules_to_convert(prefix, self.fd_config)
+            else 0
+        )
+        self.quant_min_bound: int = (
+            self.fd_config.quant_config.quant_min_bound
+            if fd_config.quant_config and modules_to_convert(prefix, self.fd_config)
+            else 0
+        )

         self.begin_norm_axis: int = begin_norm_axis
         self.layer_id = layer_id
diff --git a/fastdeploy/model_executor/layers/quantization/__init__.py b/fastdeploy/model_executor/layers/quantization/__init__.py
index 644021b1a47..7f72dd99c11 100644
--- a/fastdeploy/model_executor/layers/quantization/__init__.py
+++ b/fastdeploy/model_executor/layers/quantization/__init__.py
@@ -33,6 +33,7 @@
     "mix_quant",
     "tensor_wise_fp8",
     "kvcache",
+    "mxfp4",
 ]


@@ -112,6 +113,8 @@ def _get_offline_quant_config_name(quantization_config, is_torch_weight, is_v1_l
         has_block_size = "weight_block_size" in quantization_config
         if quant_method == "fp8" and has_block_size:
             quant_config_name = "block_wise_fp8"
+        elif quant_method == "mxfp4":
+            quant_config_name = "mxfp4"
         else:
             raise ValueError("Torch weight offline quantization only supports block-wise FP8.")
     else:
@@ -129,6 +132,7 @@ def get_quantization_config(quantization: str) -> Type[QuantConfigBase]:
     from .block_wise_fp8 import BlockWiseFP8Config
     from .kv_cache import KvCacheQuantConfig
     from .mix_quant import MixQuantConfig
+    from .mxfp4 import MXFP4Config
     from .tensor_wise_fp8 import TensorWiseFP8Config
     from .w4a8 import W4A8Config
     from .w4afp8 import W4AFP8Config
@@ -150,6 +154,7 @@ def get_quantization_config(quantization: str) -> Type[QuantConfigBase]:
         "tensor_wise_fp8": TensorWiseFP8Config,
         "kvcache": KvCacheQuantConfig,
         "mix_quant": MixQuantConfig,
+        "mxfp4": MXFP4Config,
     }

     return method_to_config[quantization]
diff --git a/fastdeploy/model_executor/layers/quantization/mxfp4.py b/fastdeploy/model_executor/layers/quantization/mxfp4.py
new file mode 100644
index 00000000000..a20265d87ad
--- /dev/null
+++ b/fastdeploy/model_executor/layers/quantization/mxfp4.py
@@ -0,0 +1,405 @@
+"""
+# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+""" + +import importlib +import importlib.util +from enum import Enum +from typing import Optional + +import paddle +from paddle import nn + +from fastdeploy import envs +from fastdeploy.model_executor.ops.gpu import moe_expert_dispatch +from fastdeploy.model_executor.utils import set_weight_attrs +from fastdeploy.platforms import current_platform +from fastdeploy.utils import get_logger + +from ..moe import FusedMoE +from .quant_base import QuantConfigBase, QuantMethodBase + +paddle.compat.enable_torch_proxy() +import torch +from torch.nn import functional as F + +logger = get_logger("config", "config.log") + + +class Mxfp4Backend(Enum): + NONE = 0 + + # FlashInfer Backend + SM90_FI_MXFP4_BF16 = 1 + + # Triton Backend + TRITON = 2 + + +def check_device_capability(num): + if paddle.is_compiled_with_cuda(): + device = paddle.device.get_device() + major, minor = paddle.device.cuda.get_device_capability(device) + return major * 10 + minor >= num + else: + return False + + +def has_flashinfer(): + return importlib.util.find_spec("flashinfer") is not None + + +def round_up(a, b): + return ((a + b - 1) // b) * b + + +def get_mxfp4_backend(): + if current_platform.is_cuda(): + if check_device_capability(90) and has_flashinfer() and envs.FD_USE_FLASHINFER_MOE_MXFP4_BF16: + logger.info("FastDeploy Using FlashInfer MXFP4 BF16 backend for SM90 in MoE") + return Mxfp4Backend.SM90_FI_MXFP4_BF16 + else: + logger.info("FastDeploy Using Triton backend in MoE") + return Mxfp4Backend.TRITON + else: + raise NotImplementedError + + +class MXFP4Config(QuantConfigBase): + """Base class for quantization configs.""" + + def __init__(self, is_checkpoint_bf16: bool = False): + super().__init__() + self.is_checkpoint_bf16 = is_checkpoint_bf16 + + def name(self) -> str: + return "mxfp4" + + @classmethod + def from_config(cls, config: dict) -> "MXFP4Config": + is_checkpoint_bf16 = not config.get("is_quantized", False) + return cls(is_checkpoint_bf16) + + def get_quant_method(self, layer) -> Optional[QuantMethodBase]: + if isinstance(layer, FusedMoE): + return MXFP4MoeMethod(self) + else: + raise NotImplementedError + + +class MXFP4MoeMethod(QuantMethodBase): + def __init__( + self, + quant_config: MXFP4Config, + ) -> None: + super().__init__() + self.quant_config = quant_config + self.mxfp4_backend = get_mxfp4_backend() + + def create_weights(self, layer, **extra_weight_attrs): + + block_size = 32 + + intermediate_size_pad = layer.moe_intermediate_size + hidden_size_pad = layer.hidden_size + + if self.mxfp4_backend == Mxfp4Backend.SM90_FI_MXFP4_BF16: + intermediate_size_pad = round_up(intermediate_size_pad, 128) + hidden_size_pad = round_up(hidden_size_pad, 128) + else: + intermediate_size_pad = round_up(intermediate_size_pad, 64) + + self.intermediate_size_pad = intermediate_size_pad + self.hidden_size_pad = hidden_size_pad + self.num_experts = layer.num_local_experts + + self.up_gate_proj_weight_shape = [ + self.num_experts, + intermediate_size_pad * 2, + hidden_size_pad // 2, # uint8 + ] + + self.down_proj_weight_shape = [ + self.num_experts, + hidden_size_pad, + intermediate_size_pad // 2, # uint8 + ] + + self.up_gate_proj_scale_shape = [ + self.num_experts, + intermediate_size_pad * 2, + hidden_size_pad // block_size, + ] + + self.down_proj_scale_shape = [ + self.num_experts, + hidden_size_pad, + intermediate_size_pad // block_size, + ] + + self.weight_dtype = "uint8" + + setattr( + layer, + "up_gate_proj_weight", + layer.create_parameter( + shape=self.up_gate_proj_weight_shape, + dtype=self.weight_dtype, + 
+                default_initializer=paddle.nn.initializer.Constant(0),
+            ),
+        )
+        setattr(
+            layer,
+            "down_proj_weight",
+            layer.create_parameter(
+                shape=self.down_proj_weight_shape,
+                dtype=self.weight_dtype,
+                default_initializer=paddle.nn.initializer.Constant(0),
+            ),
+        )
+
+        setattr(
+            layer,
+            "up_gate_proj_scale",
+            layer.create_parameter(
+                shape=self.up_gate_proj_scale_shape,
+                dtype=self.weight_dtype,
+                default_initializer=paddle.nn.initializer.Constant(0),
+            ),
+        )
+
+        setattr(
+            layer,
+            "down_proj_scale",
+            layer.create_parameter(
+                shape=self.down_proj_scale_shape,
+                dtype=self.weight_dtype,
+                default_initializer=paddle.nn.initializer.Constant(0),
+            ),
+        )
+
+        extra_weight_attrs["weight_need_transpose"] = not extra_weight_attrs.get("model_format") == "torch"
+
+        set_weight_attrs(layer.up_gate_proj_weight, extra_weight_attrs)
+        set_weight_attrs(layer.down_proj_weight, extra_weight_attrs)
+
+        set_weight_attrs(layer.up_gate_proj_scale, extra_weight_attrs)
+        set_weight_attrs(layer.down_proj_scale, extra_weight_attrs)
+
+        if layer.with_bias:
+            layer.up_gate_proj_bias = layer.create_parameter(
+                shape=[self.num_experts, intermediate_size_pad * 2],
+                dtype=layer.weight_dtype,
+                default_initializer=paddle.nn.initializer.Constant(0),
+            )
+
+            layer.down_proj_bias = layer.create_parameter(
+                shape=[self.num_experts, hidden_size_pad],
+                dtype=layer.weight_dtype,
+                default_initializer=paddle.nn.initializer.Constant(0),
+            )
+
+            set_weight_attrs(
+                layer.up_gate_proj_bias,
+                extra_weight_attrs,
+            )
+            set_weight_attrs(
+                layer.down_proj_bias,
+                extra_weight_attrs,
+            )
+
+        if layer.activation == "swigluoai":
+            gemm1_alpha = layer.create_parameter(
+                shape=[self.num_experts],
+                dtype="float32",
+                default_initializer=paddle.nn.initializer.Constant(1.702),
+            )
+            gemm1_alpha.initialize()
+            setattr(layer, "gemm1_alpha", gemm1_alpha)
+
+            gemm1_beta = layer.create_parameter(
+                shape=[self.num_experts],
+                dtype="float32",
+                default_initializer=paddle.nn.initializer.Constant(1.0),
+            )
+            gemm1_beta.initialize()
+            setattr(layer, "gemm1_beta", gemm1_beta)
+
+            gemm1_clamp_limit = layer.create_parameter(
+                shape=[self.num_experts],
+                dtype="float32",
+                default_initializer=paddle.nn.initializer.Constant(7.0),
+            )
+            gemm1_clamp_limit.initialize()
+            setattr(layer, "gemm1_clamp_limit", gemm1_clamp_limit)
+
+    def process_weights_after_loading(self, layer) -> None:
+        return
+        block_size = 32
+        if self.mxfp4_backend == Mxfp4Backend.SM90_FI_MXFP4_BF16:
+            assert (
+                layer.up_gate_proj_weight.dim() == 3
+                and layer.up_gate_proj_weight.shape[0] == self.num_experts
+                and layer.up_gate_proj_weight.shape[1] == self.intermediate_size_pad * 2
+                and layer.up_gate_proj_weight.shape[2] == self.hidden_size_pad // 2
+            )
+            assert (
+                layer.up_gate_proj_scale.dim() == 3
+                and layer.up_gate_proj_scale.shape[0] == self.num_experts
+                and layer.up_gate_proj_scale.shape[1] == self.intermediate_size_pad * 2
+                and layer.up_gate_proj_scale.shape[2] == self.hidden_size_pad // block_size
+            )
+            assert (
+                layer.down_proj_weight.dim() == 3
+                and layer.down_proj_weight.shape[0] == self.num_experts
+                and layer.down_proj_weight.shape[1] == self.hidden_size_pad
+                and layer.down_proj_weight.shape[2] == self.intermediate_size_pad // 2
+            )
+            assert (
+                layer.down_proj_scale.dim() == 3
+                and layer.down_proj_scale.shape[0] == self.num_experts
+                and layer.down_proj_scale.shape[1] == self.hidden_size_pad
+                and layer.down_proj_scale.shape[2] == self.intermediate_size_pad // block_size
+            )
+            if layer.with_bias:
+                assert (
+                    layer.up_gate_proj_bias.dim() == 2
+                    and layer.up_gate_proj_bias.shape[0] == self.num_experts
+                    and layer.up_gate_proj_bias.shape[1] == self.intermediate_size_pad * 2
+                )
+                assert (
+                    layer.down_proj_bias.dim() == 2
+                    and layer.down_proj_bias.shape[0] == self.num_experts
+                    and layer.down_proj_bias.shape[1] == self.hidden_size_pad
+                )
+
+            gate_w, up_w = layer.up_gate_proj_weight[:, ::2, :], layer.up_gate_proj_weight[:, 1::2, :]
+            gate_b, up_b = layer.up_gate_proj_bias[:, ::2].cast("bfloat16"), layer.up_gate_proj_bias[:, 1::2].cast(
+                "bfloat16"
+            )
+            gate_s, up_s = layer.up_gate_proj_scale[:, ::2, :], layer.up_gate_proj_scale[:, 1::2, :]
+
+            layer.up_gate_proj_weight.copy_(paddle.concat([up_w, gate_w], axis=1), False)
+            layer.up_gate_proj_bias.copy_(paddle.concat([up_b, gate_b], axis=-1), False)
+
+            def _interleave_mxfp4_cutlass_sm90(w):
+                w_shape = w.shape
+                w_interleaved = w.reshape([w_shape[0], w_shape[1], (w_shape[2] // 4), 4])
+                w_interleaved = w_interleaved.permute([0, 2, 1, 3])
+                w_interleaved = w_interleaved.reshape([w_shape[0], w_shape[2] // 4, w_shape[1] * 4])
+                return w_interleaved
+
+            up_gate_proj_scale = paddle.concat([up_s, gate_s], axis=1)
+            up_gate_proj_scale_interleaved = _interleave_mxfp4_cutlass_sm90(up_gate_proj_scale)
+
+            down_proj_scale = layer.down_proj_scale
+            down_proj_scale_interleaved = _interleave_mxfp4_cutlass_sm90(down_proj_scale)
+
+            layer.up_gate_proj_scale.copy_(up_gate_proj_scale_interleaved, False)
+            layer.down_proj_scale.copy_(down_proj_scale_interleaved, False)
+        else:
+            raise ValueError(f"Unsupported backend: {self.mxfp4_backend}")
+
+    def compute_routing(self, router_logits: paddle.Tensor, top_k: int):
+        """
+        Compute routing weights and selected experts from router logits.
+
+        Args:
+            router_logits (torch.Tensor): Router logits of shape [batch_size, num_experts]
+            top_k (int): Number of experts to route to per token
+
+        Returns:
+            tuple[torch.Tensor, torch.Tensor]: A tuple containing:
+                - routing_weights: Expert weights of shape [batch_size, top_k]
+                - selected_experts: Expert indices of shape [batch_size, top_k]
+        """
+        routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
+        routing_weights, selected_experts = torch.topk(routing_weights, top_k, dim=-1)
+        routing_weights /= routing_weights.sum(dim=-1, keepdim=True)
+        routing_weights = routing_weights.float()
+        return routing_weights, selected_experts
+
+    def apply(self, layer: nn.Layer, x: paddle.Tensor, router: nn.Layer) -> paddle.Tensor:
+        router_out = router(x.cast("float32"))
+
+        if self.mxfp4_backend == Mxfp4Backend.SM90_FI_MXFP4_BF16:
+
+            (
+                _,
+                _,
+                _,
+                topk_weights,
+                topk_idx,
+                *_,
+            ) = moe_expert_dispatch(
+                x,
+                router_out,
+                layer.gate_correction_bias,
+                (
+                    layer.up_gate_proj_in_scale if hasattr(layer, "up_gate_proj_in_scale") else None
+                ),  # if set, permute_input will be int8_t
+                layer.top_k,
+                False,
+                self.quant_config.name(),
+                topk_only_mode=False,
+            )
+            topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
+
+            quant_scales = [
+                layer.up_gate_proj_scale,
+                layer.down_proj_scale,
+            ]
+            extra_kwargs = dict(
+                use_w4_group_scaling=True,
+                fc1_expert_weights=layer.up_gate_proj_weight,
+                fc2_expert_weights=layer.down_proj_weight,
+            )
+
+            from flashinfer.fused_moe import (
+                cutlass_fused_moe as flashinfer_cutlass_fused_moe,
+            )
+
+            x = paddle.nn.functional.pad(x, pad=[0, self.hidden_size_pad - x.shape[-1]], mode="constant", value=0)
+
+            output = paddle.zeros_like(x, dtype="bfloat16")
+
+            _ = flashinfer_cutlass_fused_moe(
+                input=x,
+                token_selected_experts=topk_idx,
+                token_final_scales=topk_weights,
+                output_dtype=torch.bfloat16,
+                output=output,
+                quant_scales=quant_scales,
+                fc1_expert_biases=layer.up_gate_proj_bias,
+                fc2_expert_biases=layer.down_proj_bias,
+                swiglu_alpha=layer.gemm1_alpha,
+                swiglu_beta=layer.gemm1_beta,
+                swiglu_limit=layer.gemm1_clamp_limit,
+                # tp_size=self.moe.tp_size,
+                # tp_rank=self.moe.tp_rank,
+                # ep_size=self.moe.ep_size,
+                # ep_rank=self.moe.ep_rank,
+                tune_max_num_tokens=8192,
+                **extra_kwargs,
+            )
+
+            return output[..., : layer.hidden_size]
+
+    def process_loaded_weights(self, layer, weights):
+        """Process the weight after loading.
+
+        This can be used for example, to transpose weights for computation.
+        """
+        return
diff --git a/fastdeploy/model_executor/layers/utils.py b/fastdeploy/model_executor/layers/utils.py
index c18f062457e..4254a97c7bf 100644
--- a/fastdeploy/model_executor/layers/utils.py
+++ b/fastdeploy/model_executor/layers/utils.py
@@ -23,6 +23,7 @@
 from paddle.framework import in_dynamic_mode
 from scipy.linalg import block_diag

+from fastdeploy.config import FDConfig
 from fastdeploy.platforms import current_platform

 if current_platform.is_cuda() and current_platform.available():
@@ -572,3 +573,22 @@ def vocab_range_from_per_partition_vocab_size(per_partition_vocab_size: int, ran
 def vocab_range_from_global_vocab_size(global_vocab_size: int, rank: int, world_size: int, offset: int = 0):
     per_partition_vocab_size = divide(global_vocab_size, world_size)
     return vocab_range_from_per_partition_vocab_size(per_partition_vocab_size, rank, offset=offset)
+
+
+def modules_to_convert(prefix: str, fd_config: FDConfig):
+    # Returns False when `prefix` matches a quantization_config["modules_to_not_convert"] pattern,
+    # i.e. the module should stay unquantized; returns True otherwise.
+    import fnmatch
+
+    if (
+        hasattr(fd_config.model_config, "quantization_config")
+        and fd_config.model_config.quantization_config is not None
+    ):
+        if "modules_to_not_convert" in fd_config.model_config.quantization_config:
+            patterns = fd_config.model_config.quantization_config["modules_to_not_convert"]
+            for p in patterns:
+                if fnmatch.fnmatch(prefix, p) or fnmatch.fnmatch(prefix, p + ".*"):
+                    return False
+            return True
+        else:
+            return True
+    else:
+        return True
diff --git a/fastdeploy/model_executor/models/gpt_oss.py b/fastdeploy/model_executor/models/gpt_oss.py
index 682c9f5f1ec..8bf52ab091a 100644
--- a/fastdeploy/model_executor/models/gpt_oss.py
+++ b/fastdeploy/model_executor/models/gpt_oss.py
@@ -226,6 +226,13 @@ def forward(self, ids_remove_padding: paddle.Tensor, forward_meta: ForwardMeta):
 )
 class GptOssForCausalLM(ModelForCasualLM):
     def __init__(self, fd_config: FDConfig):
+        # Keep gpt-oss norm layers unquantized by listing them in modules_to_not_convert.
+        if (
+            hasattr(fd_config, "quant_config")
+            and fd_config.model_config.quantization_config is not None
+            and "modules_to_not_convert" in fd_config.model_config.quantization_config
+        ):
+            fd_config.model_config.quantization_config["modules_to_not_convert"].append("*norm")
+
         super(GptOssForCausalLM, self).__init__(fd_config)
         self.fd_config = fd_config
         self.model = GptOssModel(fd_config=fd_config)
@@ -265,14 +272,19 @@ def load_weights(self, weights_iterator) -> None:
         ]
         expert_params_mapping = [
             # (param_name, weight_name, expert_id, shard_id)
-            ("up_gate_proj_weight", "gate_up_proj", None, None),
             ("up_gate_proj_bias", "gate_up_proj_bias", None, None),
-            ("down_proj_weight", "down_proj", None, None),
             ("down_proj_bias", "down_proj_bias", None, None),
+            ("up_gate_proj_weight", "gate_up_proj", None, None),
+            ("down_proj_weight", "down_proj", None, None),
+            ("up_gate_proj_weight", "gate_up_proj_blocks", None, None),
+            ("up_gate_proj_scale", "gate_up_proj_scales", None, None),
+            ("down_proj_weight", "down_proj_blocks", None, None),
+            ("down_proj_scale", "down_proj_scales", None, None),
         ]
         params_dict = dict(self.named_parameters())
         process_weights_after_loading_fn = process_weights_after_loading(dict(self.named_sublayers()), self.fd_config)
         for loaded_weight_name, loaded_weight in weights_iterator:
+            matched = False
             for param_name, weight_name, shard_id in stacked_params_mapping:
                 if weight_name not in loaded_weight_name:
                     continue
@@ -284,26 +296,38 @@ def load_weights(self, weights_iterator) -> None:
                 param = params_dict[model_param_name]
                 weight_loader = getattr(param, "weight_loader", default_weight_loader(self.fd_config))
                 weight_loader(param, loaded_weight, shard_id)
+                matched = True
                 break
-            else:
-                for mapping in expert_params_mapping:
-                    param_name, weight_name, expert_id, shard_id = mapping
+            if not matched:
+                for param_name, weight_name, expert_id, shard_id in expert_params_mapping:
                     if weight_name not in loaded_weight_name:
                         continue
+
                     model_param_name = loaded_weight_name.replace(weight_name, param_name)
                     if model_param_name not in params_dict:
                         continue
+
                     param = params_dict[model_param_name]
                     weight_loader = param.weight_loader
-                    weight_loader(param, loaded_weight, shard_id=shard_id, expert_id=expert_id)
+                    weight_loader(
+                        param,
+                        loaded_weight,
+                        loaded_weight_name=loaded_weight_name,
+                        shard_id=shard_id,
+                        expert_id=expert_id,
+                    )
+
+                    matched = True
                     break
-                else:
-                    model_param_name = loaded_weight_name
-                    if model_param_name not in params_dict:
-                        continue
-                    param = params_dict[model_param_name]
-                    weight_loader = getattr(param, "weight_loader", default_weight_loader(self.fd_config))
-                    weight_loader(param, loaded_weight)
+            if not matched:
+
+                model_param_name = loaded_weight_name
+                if model_param_name not in params_dict:
+                    continue
+
+                param = params_dict[model_param_name]
+                weight_loader = getattr(param, "weight_loader", default_weight_loader(self.fd_config))
+                weight_loader(param, loaded_weight)
             model_sublayer_name = re.sub(r"\.(up_gate_proj_weight|down_proj_weight|weight)$", "", model_param_name)
             process_weights_after_loading_fn(model_sublayer_name, param)
diff --git a/fastdeploy/utils.py b/fastdeploy/utils.py
index a0878fa7c73..29e8e3c7cb3 100644
--- a/fastdeploy/utils.py
+++ b/fastdeploy/utils.py
@@ -571,7 +571,7 @@ def check_unified_ckpt(model_dir):

     try:
         # check all the file exists
-        safetensors_num = int(model_files[0].strip(".safetensors").split("-")[-1])
+        safetensors_num = int(model_files[0].strip(".safetensors").split("-")[-1]) + 1
         flags = [0] * safetensors_num
         for x in model_files:
             current_index = int(x.strip(".safetensors").split("-")[1])
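
Note on the pattern matching introduced by this patch (illustrative, not part of the diff): modules_to_convert() decides per layer prefix whether a quant method should apply, using fnmatch against quantization_config["modules_to_not_convert"]; GptOssForCausalLM appends "*norm" so normalization layers stay unquantized. A minimal, self-contained sketch of that matching behavior follows; the helper name keeps_quant and the example prefixes are hypothetical placeholders.

    import fnmatch

    patterns = ["*norm"]  # what GptOssForCausalLM.__init__ appends to modules_to_not_convert

    def keeps_quant(prefix: str) -> bool:
        # Mirrors the matching logic of modules_to_convert() in fastdeploy/model_executor/layers/utils.py
        return not any(fnmatch.fnmatch(prefix, p) or fnmatch.fnmatch(prefix, p + ".*") for p in patterns)

    print(keeps_quant("model.layers.0.post_attention_layernorm"))  # False -> layer stays unquantized
    print(keeps_quant("model.layers.0.mlp.experts"))               # True  -> layer gets the mxfp4 quant method

The FlashInfer SM90 MoE path is additionally gated by the new FD_USE_FLASHINFER_MOE_MXFP4_BF16 environment variable (default "0"); otherwise get_mxfp4_backend() falls back to the Triton backend.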