3 changes: 3 additions & 0 deletions fastdeploy/config.py
@@ -350,6 +350,9 @@ def read_model_config(self):
elif "dtype" in self.model_config:
self.model_format = "paddle"
logger.info("The model format is Paddle")
elif "model_type" in self.model_config and self.model_config["model_type"] == "gpt_oss":
self.model_format = "torch"
logger.info("The model format is Hugging Face")
else:
raise ValueError(
"Unknown model format. Please ensure your config.json contains "
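Editorial note on the hunk above: the new branch keys on the Hugging Face-style config.json, which carries a "model_type" field rather than the Paddle "dtype" key. A minimal, self-contained sketch of that detection logic, with illustrative config dicts (real config.json files carry many more keys):

# Sketch of the format detection added in read_model_config; the dicts are made up.
def detect_model_format(model_config: dict) -> str:
    if "dtype" in model_config:
        return "paddle"
    elif model_config.get("model_type") == "gpt_oss":
        return "torch"
    raise ValueError("Unknown model format. Please ensure your config.json contains the expected keys.")

assert detect_model_format({"dtype": "bfloat16"}) == "paddle"    # Paddle checkpoint
assert detect_model_format({"model_type": "gpt_oss"}) == "torch"  # Hugging Face gpt_oss checkpoint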
2 changes: 2 additions & 0 deletions fastdeploy/envs.py
@@ -54,6 +54,8 @@
"FD_SAMPLING_CLASS": lambda: os.getenv("FD_SAMPLING_CLASS", "base"),
# Set moe backend."cutlass","marlin" and "triton" can be set currently.
"FD_MOE_BACKEND": lambda: os.getenv("FD_MOE_BACKEND", "cutlass"),
# Whether to use FLASHINFER as MXFP4 backend for MoE.
"FD_USE_FLASHINFER_MOE_MXFP4_BF16": lambda: os.getenv("FD_USE_FLASHINFER_MOE_MXFP4_BF16", "0"),
Copilot AI Dec 5, 2025
For consistency with other boolean environment variables in this file (e.g., FD_USE_DEEP_GEMM, FD_USE_HF_TOKENIZER), the value should be converted to a boolean using the bool(int(os.getenv(...))) pattern instead of being returned as the string "0". This ensures consistent handling of boolean environment variables throughout the codebase.

Suggested change
"FD_USE_FLASHINFER_MOE_MXFP4_BF16": lambda: os.getenv("FD_USE_FLASHINFER_MOE_MXFP4_BF16", "0"),
"FD_USE_FLASHINFER_MOE_MXFP4_BF16": lambda: bool(int(os.getenv("FD_USE_FLASHINFER_MOE_MXFP4_BF16", "0"))),

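Editorial note on the suggestion above: the underlying pitfall is that any non-empty string is truthy in Python, so a raw "0" read from the environment would still count as enabled in a plain truthiness check. A quick illustration:

# Returning the raw string leaves truthiness to the caller: non-empty strings are always truthy.
assert bool("0") is True and bool("1") is True

# Converting through int() first gives the intended on/off semantics, as the suggestion does.
assert bool(int("0")) is False and bool(int("1")) is True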
# Whether to use Machete for wint4 dense gemm.
"FD_USE_MACHETE": lambda: os.getenv("FD_USE_MACHETE", "1"),
# Set whether to disable recompute the request when the KV cache is full.
9 changes: 7 additions & 2 deletions fastdeploy/model_executor/layers/linear.py
@@ -32,7 +32,7 @@
)
from fastdeploy.platforms import current_platform

from .utils import _set_var_distributed, divide, get_tensor
from .utils import _set_var_distributed, divide, get_tensor, modules_to_convert


class UnquantizedLinearMethod(QuantMethodBase):
@@ -168,7 +168,12 @@ def __init__(
self.output_size,
]

if fd_config.quant_config and not skip_quant and fd_config.quant_config.get_quant_method(self):
if (
fd_config.quant_config
and not skip_quant
and modules_to_convert(prefix, self.fd_config)
and fd_config.quant_config.get_quant_method(self)
):
self.quant_method = fd_config.quant_config.get_quant_method(self)
else:
self.quant_method: Optional[QuantMethodBase] = UnquantizedLinearMethod()
85 changes: 78 additions & 7 deletions fastdeploy/model_executor/layers/moe/moe.py
@@ -249,14 +249,20 @@ def __init__(
)

def weight_loader(
self, param, loaded_weight, expert_id, shard_id: Optional[str] = None, source: Optional[str] = None
self,
param,
loaded_weight,
expert_id,
shard_id: Optional[str] = None,
source: Optional[str] = None,
loaded_weight_name: Optional[str] = None,
):
"""
source:Avoid redundant transpose of fused weights when weight_loader is called iteratively
"""
if expert_id is None and shard_id is None:
# MoE experts has been fused in disk
self._load_fused_experts_weight(param, loaded_weight)
self._load_fused_experts_weight(param, loaded_weight, loaded_weight_name)
return
if hasattr(param, "SHARD_ID_TO_SHARDED_DIM"):
SHARD_ID_TO_SHARDED_DIM = param.SHARD_ID_TO_SHARDED_DIM
@@ -368,7 +374,7 @@ def _load_down_weight(self, param, expert_id, loaded_weight, shard_id, shard_dim
loaded_weight = loaded_weight.cast(expert_param.dtype)
h2d_copy(dst=expert_param, src=loaded_weight)

def _load_fused_experts_weight(self, param, loaded_weight):
def _load_fused_experts_weight(self, param, loaded_weight, loaded_weight_name: Optional[str] = None):
if self.tp_size > 1:
dim = -1
if isinstance(loaded_weight, (np.ndarray, paddle.Tensor)):
@@ -379,10 +385,75 @@ def _load_fused_experts_weight(self, param, loaded_weight):
shard_offset = self.tp_rank * block_size
shard_size = (self.tp_rank + 1) * block_size
loaded_weight = slice_fn(loaded_weight, dim, shard_offset, shard_size)
assert param.shape == loaded_weight.shape, (
f"Attempted to load weight ({loaded_weight.shape}) " f"into parameter ({param.shape})"
)
h2d_copy(dst=param, src=loaded_weight)

if self.moe_quant_config.name() == "mxfp4":
Copilot AI Dec 5, 2025
The code accesses self.moe_quant_config.name() without first checking that moe_quant_config is not None, so this line raises an AttributeError when it is None. Consider adding a None check: if self.moe_quant_config and self.moe_quant_config.name() == "mxfp4":

Suggested change
if self.moe_quant_config.name() == "mxfp4":
if self.moe_quant_config and self.moe_quant_config.name() == "mxfp4":

assert loaded_weight_name is not None
weight = get_tensor(loaded_weight)
if "block" in loaded_weight_name:
if "up" in loaded_weight_name:
weight = weight.reshape([self.num_experts, 2 * self.moe_intermediate_size, -1])
elif "down" in loaded_weight_name:
weight = weight.reshape([self.num_experts, self.hidden_size, -1])
weight = paddle.nn.functional.pad(
weight.cast("int32"),
pad=[0, param.shape[-1] - weight.shape[-1], 0, param.shape[-2] - weight.shape[-2]],
mode="constant",
value=0,
).cast("uint8")

if "up" in loaded_weight_name:
gate_w, up_w = weight[:, ::2, :], weight[:, 1::2, :]
param.copy_(paddle.concat([up_w, gate_w], axis=1), False)
else:
param.copy_(weight, False)

elif "scale" in loaded_weight_name:
if "up" in loaded_weight_name:
weight = weight.reshape([self.num_experts, 2 * self.moe_intermediate_size, -1])
elif "down" in loaded_weight_name:
weight = weight.reshape([self.num_experts, self.hidden_size, -1])
weight = paddle.nn.functional.pad(
weight.cast("int32"),
pad=[0, param.shape[-1] - weight.shape[-1], 0, param.shape[-2] - weight.shape[-2]],
mode="constant",
value=0,
).cast("uint8")

def _interleave_mxfp4_cutlass_sm90(w):
w_shape = w.shape
w_interleaved = w.reshape([w_shape[0], w_shape[1], (w_shape[2] // 4), 4])
w_interleaved = w_interleaved.permute([0, 2, 1, 3])
w_interleaved = w_interleaved.reshape([w_shape[0], w_shape[2] // 4, w_shape[1] * 4])
return w_interleaved
Comment on lines +422 to +427
Copilot AI Dec 5, 2025
The function _interleave_mxfp4_cutlass_sm90 is duplicated - it appears both here and in fastdeploy/model_executor/layers/quantization/mxfp4.py (lines 297-302 and 422-427). Consider extracting this as a shared utility function to avoid code duplication and improve maintainability.

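Editorial note on the comment above: a minimal sketch of the extraction it asks for, assuming a shared home such as fastdeploy/model_executor/layers/utils.py (the location is hypothetical; the reshape/permute sequence is the one used in the hunk):

import paddle

def interleave_mxfp4_cutlass_sm90(w: paddle.Tensor) -> paddle.Tensor:
    """Interleave MXFP4 scale tensors for the SM90 CUTLASS layout.

    Intended as a shared helper so moe.py and quantization/mxfp4.py do not
    each carry their own copy of the same reshape/permute sequence.
    """
    e, rows, cols = w.shape
    w = w.reshape([e, rows, cols // 4, 4])
    w = paddle.transpose(w, perm=[0, 2, 1, 3])  # same permutation as .permute([0, 2, 1, 3])
    return w.reshape([e, cols // 4, rows * 4])

Both call sites would then import the helper (for example, from fastdeploy.model_executor.layers.utils import interleave_mxfp4_cutlass_sm90) instead of defining a local copy.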

if "up" in loaded_weight_name:
gate_s, up_s = weight[:, ::2, :], weight[:, 1::2, :]
up_gate_proj_scale = paddle.concat([up_s, gate_s], axis=1)
up_gate_proj_scale_interleaved = _interleave_mxfp4_cutlass_sm90(up_gate_proj_scale)
param.copy_(up_gate_proj_scale_interleaved, False)
else:
down_proj_scale = weight
down_proj_scale_interleaved = _interleave_mxfp4_cutlass_sm90(down_proj_scale)
param.copy_(down_proj_scale_interleaved, False)

elif "bias" in loaded_weight_name:

weight = paddle.nn.functional.pad(
weight, pad=[0, param.shape[-1] - weight.shape[-1]], mode="constant", value=0
)

if "up" in loaded_weight_name:
gate_b, up_b = weight[:, ::2].cast("bfloat16"), weight[:, 1::2].cast("bfloat16")
param.copy_(paddle.concat([up_b, gate_b], axis=-1), False)
else:
param.copy_(weight.cast("bfloat16"), False)

else:
assert param.shape == loaded_weight.shape, (
f"Attempted to load weight ({loaded_weight.shape}) " f"into parameter ({param.shape})"
)

h2d_copy(dst=param, src=loaded_weight)

if hasattr(param, "tensor_track"):
for i in range(self.num_local_experts):
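Editorial note on the mxfp4 branch above: it assumes the checkpoint stores the fused projection with gate and up rows interleaved (gate_0, up_0, gate_1, up_1, ...), splits them with [:, ::2, :] and [:, 1::2, :], and writes them back as [up, gate] along the row axis. A tiny NumPy sketch of that de-interleave, with made-up shapes:

import numpy as np

num_experts, inter_size, hidden = 2, 3, 4
gate = np.zeros((num_experts, inter_size, hidden))
up = np.ones((num_experts, inter_size, hidden))

# Build a fused tensor whose rows alternate gate_i, up_i, gate_(i+1), up_(i+1), ...
fused = np.empty((num_experts, 2 * inter_size, hidden))
fused[:, ::2, :] = gate
fused[:, 1::2, :] = up

# The loader's split-and-reorder: even rows are gate, odd rows are up,
# and the target parameter expects the up block first, then the gate block.
gate_w, up_w = fused[:, ::2, :], fused[:, 1::2, :]
up_gate = np.concatenate([up_w, gate_w], axis=1)

assert (up_gate[:, :inter_size, :] == 1).all()  # up block comes first
assert (up_gate[:, inter_size:, :] == 0).all()  # gate block comes second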
20 changes: 16 additions & 4 deletions fastdeploy/model_executor/layers/normalization.py
@@ -30,7 +30,7 @@
from fastdeploy.config import FDConfig
from fastdeploy.model_executor.forward_meta import ForwardMeta

from .utils import get_tensor
from .utils import get_tensor, modules_to_convert


class RMSNorm(nn.Layer):
@@ -94,9 +94,21 @@ def __init__(
"float16",
], f"Unsupported dtype: {dtype}. Must be one of: float32, bfloat16, float16"

self.quant_round_type: int = self.fd_config.quant_config.quant_round_type if fd_config.quant_config else 0
self.quant_max_bound: int = self.fd_config.quant_config.quant_max_bound if fd_config.quant_config else 0
self.quant_min_bound: int = self.fd_config.quant_config.quant_min_bound if fd_config.quant_config else 0
self.quant_round_type: int = (
self.fd_config.quant_config.quant_round_type
if fd_config.quant_config and modules_to_convert(prefix, self.fd_config)
else 0
)
self.quant_max_bound: int = (
self.fd_config.quant_config.quant_max_bound
if fd_config.quant_config and modules_to_convert(prefix, self.fd_config)
else 0
)
self.quant_min_bound: int = (
self.fd_config.quant_config.quant_min_bound
if fd_config.quant_config and modules_to_convert(prefix, self.fd_config)
else 0
)
self.begin_norm_axis: int = begin_norm_axis

self.layer_id = layer_id
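Editorial note on the hunk above: the same eligibility check (fd_config.quant_config and modules_to_convert(prefix, self.fd_config)) is evaluated three times. A sketch that hoists it into a single helper; the stand-in objects below are illustrative, and modules_to_convert is assumed to return a truthy value when the prefixed module should be quantized:

def resolve_quant_bounds(quant_config, prefix, fd_config, modules_to_convert):
    # Evaluate the shared check once, then pick all three bounds in one place.
    quant_enabled = bool(quant_config) and modules_to_convert(prefix, fd_config)
    if quant_enabled:
        return (
            quant_config.quant_round_type,
            quant_config.quant_max_bound,
            quant_config.quant_min_bound,
        )
    return 0, 0, 0

# Smoke test with stand-in values (the real bounds come from the quant config).
class _StubQuantConfig:
    quant_round_type, quant_max_bound, quant_min_bound = 1, 448, -448

assert resolve_quant_bounds(_StubQuantConfig(), "layers.0.norm", None, lambda p, c: False) == (0, 0, 0)
assert resolve_quant_bounds(_StubQuantConfig(), "layers.0.norm", None, lambda p, c: True) == (1, 448, -448)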
5 changes: 5 additions & 0 deletions fastdeploy/model_executor/layers/quantization/__init__.py
@@ -33,6 +33,7 @@
"mix_quant",
"tensor_wise_fp8",
"kvcache",
"mxfp4",
]


@@ -112,6 +113,8 @@ def _get_offline_quant_config_name(quantization_config, is_torch_weight, is_v1_l
has_block_size = "weight_block_size" in quantization_config
if quant_method == "fp8" and has_block_size:
quant_config_name = "block_wise_fp8"
elif quant_method == "mxfp4":
quant_config_name = "mxfp4"
else:
raise ValueError("Torch weight offline quantization only supports block-wise FP8.")
else:
@@ -129,6 +132,7 @@ def get_quantization_config(quantization: str) -> Type[QuantConfigBase]:
from .block_wise_fp8 import BlockWiseFP8Config
from .kv_cache import KvCacheQuantConfig
from .mix_quant import MixQuantConfig
from .mxfp4 import MXFP4Config
from .tensor_wise_fp8 import TensorWiseFP8Config
from .w4a8 import W4A8Config
from .w4afp8 import W4AFP8Config
@@ -150,6 +154,7 @@ def get_quantization_config(quantization: str) -> Type[QuantConfigBase]:
"tensor_wise_fp8": TensorWiseFP8Config,
"kvcache": KvCacheQuantConfig,
"mix_quant": MixQuantConfig,
"mxfp4": MXFP4Config,
}

return method_to_config[quantization]
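Editorial note on the hunk above: with "mxfp4" added to the name list, the offline-config branch, and the name-to-class mapping, a Hugging Face checkpoint whose quantization_config declares quant_method "mxfp4" resolves to MXFP4Config. A rough sketch of that dispatch, with an illustrative (incomplete) quantization_config:

# Mirrors the branch added to _get_offline_quant_config_name; the dict is made up.
quantization_config = {"quant_method": "mxfp4"}

quant_method = quantization_config.get("quant_method", "")
has_block_size = "weight_block_size" in quantization_config
if quant_method == "fp8" and has_block_size:
    quant_config_name = "block_wise_fp8"
elif quant_method == "mxfp4":
    quant_config_name = "mxfp4"
else:
    raise ValueError("Unsupported torch-weight offline quantization method in this sketch.")

assert quant_config_name == "mxfp4"
# get_quantization_config(quant_config_name) would then return the MXFP4Config class.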