Merged
86 changes: 86 additions & 0 deletions fastdeploy/model_executor/layers/normalization.py
@@ -29,6 +29,7 @@

from fastdeploy.config import FDConfig
from fastdeploy.model_executor.forward_meta import ForwardMeta
from fastdeploy.model_executor.ops.triton_ops import _TRITON_AVAILABLE, qk_rmsnorm_fused

from .utils import get_tensor

@@ -256,6 +257,91 @@ def forward(
return out, residual_out


class QKRMSNorm(nn.Layer):
"""
Q/K RMS normalization over a packed QKV tensor: applies per-head RMSNorm to the
query and key slices, using the fused Triton kernel when available and falling
back to an unfused per-head path otherwise.
"""

def __init__(
self,
fd_config: FDConfig,
head_dim: int,
q_size: int,
kv_size: int,
eps: float = 1e-5,
prefix: str = "",
begin_norm_axis: int = 1,
dtype: Optional[str] = None,
) -> None:
super().__init__()
self.fd_config = fd_config
self.prefix: str = prefix
self.head_dim: int = head_dim
self.q_weight_key: Optional[str] = f"{prefix}.q_norm.weight"
self.k_weight_key: Optional[str] = f"{prefix}.k_norm.weight"
self.eps: float = eps
self._norm_weight_dtype = dtype
if self._norm_weight_dtype is None:
self._norm_weight_dtype = self._helper.get_default_dtype()
else:
assert dtype in [
"float32",
"bfloat16",
"float16",
], f"Unsupported dtype: {dtype}. Must be one of: float32, bfloat16, float16"

self.q_size = q_size
self.kv_size = kv_size

self.q_norm = RMSNorm(
fd_config,
hidden_size=self.head_dim,
eps=self.eps,
prefix=f"{prefix}.q_norm",
begin_norm_axis=begin_norm_axis,
)
self.k_norm = RMSNorm(
fd_config,
hidden_size=self.head_dim,
eps=self.eps,
prefix=f"{prefix}.k_norm",
begin_norm_axis=begin_norm_axis,
)
self.qk_norm_fused = _TRITON_AVAILABLE

def load_state_dict(self, state_dict):
self.q_norm.load_state_dict(state_dict)
self.k_norm.load_state_dict(state_dict)

def forward(
self,
qkv_out,
) -> paddle.Tensor:
if self.qk_norm_fused:
qkv_out = qk_rmsnorm_fused(
qkv_out,
self.q_norm.weight,
self.k_norm.weight,
self.eps,
self.q_size,
self.kv_size,
self.head_dim,
)
else:
q, k, v = qkv_out.split([self.q_size, self.kv_size, self.kv_size], axis=-1)

q_by_head = q.reshape([*q.shape[:-1], q.shape[-1] // self.head_dim, self.head_dim])
q_by_head = self.q_norm(q_by_head)[0]
q = q_by_head.reshape(q.shape)

k_by_head = k.reshape([*k.shape[:-1], k.shape[-1] // self.head_dim, self.head_dim])
k_by_head = self.k_norm(k_by_head)[0]
k = k_by_head.reshape(k.shape)

qkv_out = paddle.concat([q, k, v], axis=-1)
return qkv_out

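For reference, a minimal NumPy sketch of the per-head RMSNorm that the unfused fallback above computes and that qk_rmsnorm_fused is expected to perform in a single kernel; the helper name, shapes, and dtype handling here are illustrative rather than part of this change:

import numpy as np

def qk_rmsnorm_reference(qkv, q_weight, k_weight, eps, q_size, kv_size, head_dim):
    # qkv: [num_tokens, q_size + 2 * kv_size]; q_weight / k_weight: [head_dim].
    q, k, v = np.split(qkv, [q_size, q_size + kv_size], axis=-1)

    def per_head_rmsnorm(x, weight):
        # Normalize each head_dim-sized slice independently, accumulating in float32.
        xf = x.reshape(*x.shape[:-1], -1, head_dim).astype(np.float32)
        variance = np.mean(np.square(xf), axis=-1, keepdims=True)
        xf = xf / np.sqrt(variance + eps) * weight
        return xf.reshape(*x.shape[:-1], -1).astype(qkv.dtype)

    return np.concatenate([per_head_rmsnorm(q, q_weight), per_head_rmsnorm(k, k_weight), v], axis=-1)

The fused path should match this result while avoiding the split, the per-head reshapes, and the final concat.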

class LayerNorm(nn.Layer):
"""
Initializes the LayerNormalization layer
29 changes: 12 additions & 17 deletions fastdeploy/model_executor/models/glm4_moe.py
@@ -41,7 +41,7 @@
)
from fastdeploy.model_executor.layers.lm_head import ParallelLMHead
from fastdeploy.model_executor.layers.moe.moe import FusedMoE
from fastdeploy.model_executor.layers.normalization import RMSNorm
from fastdeploy.model_executor.layers.normalization import QKRMSNorm, RMSNorm
from fastdeploy.model_executor.models.model_base import (
ModelCategory,
ModelForCasualLM,
@@ -205,18 +205,13 @@ def __init__(self, fd_config: FDConfig, layer_id: int, prefix: str = "") -> None
rms_norm_eps=fd_config.model_config.rms_norm_eps,
)
if self.use_qk_norm:
self.q_norm = RMSNorm(
self.qk_norm = QKRMSNorm(
fd_config,
hidden_size=self.head_dim,
head_dim=self.head_dim,
q_size=self.q_size,
kv_size=self.kv_size,
eps=fd_config.model_config.rms_norm_eps,
prefix=f"{prefix}.q_norm",
begin_norm_axis=2,
)
self.k_norm = RMSNorm(
fd_config,
hidden_size=self.head_dim,
eps=fd_config.model_config.rms_norm_eps,
prefix=f"{prefix}.k_norm",
prefix=prefix,
begin_norm_axis=2,
)

@@ -227,13 +222,8 @@ def forward(
):
""" """
qkv_out = self.qkv_proj(hidden_states)

if self.use_qk_norm:
q, k, v = qkv_out.split([self.q_size, self.kv_size, self.kv_size], axis=-1)
q = self.q_norm(q.reshape([-1, self.num_heads, self.head_dim]))[0].reshape(q.shape)
k = self.k_norm(k.reshape([-1, self.num_kv_heads, self.head_dim]))[0].reshape(k.shape)
qkv_out = paddle.concat([q, k, v], axis=-1)

qkv_out = self.qk_norm(qkv_out)
atten_out = self.attn(
qkv=qkv_out,
forward_meta=forward_meta,
@@ -435,6 +425,11 @@ def load_weights(self, weights_iterator) -> None:
("lm_head.linear", "lm_head", None),
("experts.gate_correction_bias", "gate.e_score_correction_bias", None),
]

if self.fd_config.model_config.use_qk_norm:
stacked_params_mapping.append(("qk_norm.q_norm", "q_norm", None))
stacked_params_mapping.append(("qk_norm.k_norm", "k_norm", None))

# (param_name, weight_name, expert_id, shard_id)
expert_params_mapping = FusedMoE.make_expert_params_mapping(
num_experts=self.fd_config.model_config.n_routed_experts,
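The two stacked_params_mapping entries added above are needed because checkpoints still store the weights as ...self_attn.q_norm.weight and ...self_attn.k_norm.weight, while the parameters now live under the qk_norm submodule. A rough sketch of the renaming the mapping encodes, assuming a simple substring substitution (the helper below is hypothetical; the actual FastDeploy loader also handles sharding and the other mapping entries):

qk_norm_mapping = [
    # (param_name_fragment, checkpoint_name_fragment, shard_id)
    ("qk_norm.q_norm", "q_norm", None),
    ("qk_norm.k_norm", "k_norm", None),
]

def remap_checkpoint_key(weight_name: str) -> str:
    # Rewrite a checkpoint key to the module path used by the model.
    for param_fragment, ckpt_fragment, _shard_id in qk_norm_mapping:
        if ckpt_fragment in weight_name:
            return weight_name.replace(ckpt_fragment, param_fragment)
    return weight_name

# "model.layers.0.self_attn.q_norm.weight"
#   -> "model.layers.0.self_attn.qk_norm.q_norm.weight"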
45 changes: 14 additions & 31 deletions fastdeploy/model_executor/models/qwen3.py
@@ -33,7 +33,7 @@
from fastdeploy.model_executor.layers.embeddings import VocabParallelEmbedding
from fastdeploy.model_executor.layers.linear import QKVParallelLinear, RowParallelLinear
from fastdeploy.model_executor.layers.lm_head import ParallelLMHead
from fastdeploy.model_executor.layers.normalization import RMSNorm
from fastdeploy.model_executor.layers.normalization import QKRMSNorm, RMSNorm
from fastdeploy.model_executor.models.model_base import (
ModelCategory,
ModelForCasualLM,
@@ -57,6 +57,10 @@ def __init__(self, fd_config: FDConfig, layer_id: int, prefix: str = "") -> None

self.fd_config = fd_config
self.head_dim = fd_config.model_config.head_dim
tp_size = fd_config.parallel_config.tensor_parallel_size
num_kv_heads_replicas = max(1, tp_size // fd_config.model_config.num_key_value_heads)
self.q_size = fd_config.model_config.num_attention_heads * self.head_dim // tp_size
self.kv_size = fd_config.model_config.num_key_value_heads * self.head_dim * num_kv_heads_replicas // tp_size

self.qkv_proj = QKVParallelLinear(fd_config, prefix=f"{prefix}.qkv_proj", with_bias=False)

@@ -75,32 +79,21 @@ def __init__(self, fd_config: FDConfig, layer_id: int, prefix: str = "") -> None
use_neox_rotary_style=True,
)

self.q_norm = RMSNorm(
self.qk_norm = QKRMSNorm(
fd_config,
hidden_size=self.head_dim,
head_dim=self.head_dim,
q_size=self.q_size,
kv_size=self.kv_size,
eps=fd_config.model_config.rms_norm_eps,
prefix=f"{prefix}.q_norm",
begin_norm_axis=2,
)
self.k_norm = RMSNorm(
fd_config,
hidden_size=self.head_dim,
eps=fd_config.model_config.rms_norm_eps,
prefix=f"{prefix}.k_norm",
prefix=prefix,
begin_norm_axis=2,
)

tp_size = fd_config.parallel_config.tensor_parallel_size
num_kv_heads_replicas = max(1, tp_size // fd_config.model_config.num_key_value_heads)
self.q_size = fd_config.model_config.num_attention_heads * self.head_dim // tp_size
self.kv_size = fd_config.model_config.num_key_value_heads * self.head_dim * num_kv_heads_replicas // tp_size

def load_state_dict(self, state_dict):
""" """
self.qkv_proj.load_state_dict(state_dict)
self.o_proj.load_state_dict(state_dict)
self.q_norm.load_state_dict(state_dict)
self.k_norm.load_state_dict(state_dict)
self.qk_norm.load_state_dict(state_dict)
self.attn.load_state_dict(state_dict)

def forward(
@@ -110,19 +103,7 @@ def forward(
):
""" """
qkv_out = self.qkv_proj(hidden_states)
# origin_qkv_out = qkv_out
q, k, v = qkv_out.split([self.q_size, self.kv_size, self.kv_size], axis=-1)

q_by_head = q.reshape([*q.shape[:-1], q.shape[-1] // self.head_dim, self.head_dim])
q_by_head = self.q_norm(q_by_head)[0]
q = q_by_head.reshape(q.shape)

k_by_head = k.reshape([*k.shape[:-1], k.shape[-1] // self.head_dim, self.head_dim])
k_by_head = self.k_norm(k_by_head)[0]
k = k_by_head.reshape(k.shape)

qkv_out = paddle.concat([q, k, v], axis=-1)

qkv_out = self.qk_norm(qkv_out)
atten_out = self.attn(
qkv=qkv_out,
forward_meta=forward_meta,
@@ -280,6 +261,8 @@ def load_weights(self, weights_iterator) -> None:
("up_gate_proj", "up_proj", "up"),
("embed_tokens.embeddings", "embed_tokens", None),
("lm_head.linear", "lm_head", None),
("qk_norm.q_norm", "q_norm", None),
("qk_norm.k_norm", "k_norm", None),
]

params_dict = dict(self.named_parameters())
2 changes: 2 additions & 0 deletions fastdeploy/model_executor/models/qwen3_vl/qwen3_vl.py
@@ -207,6 +207,8 @@ def load_weights(self, weights_iterator) -> None:
("embed_tokens.embeddings", "embed_tokens", None),
("lm_head.linear", "lm_head", None),
("visual", "model.visual", None),
("qk_norm.q_norm", "q_norm", None),
("qk_norm.k_norm", "k_norm", None),
]

params_dict = dict(self.named_parameters())
2 changes: 2 additions & 0 deletions fastdeploy/model_executor/models/qwen3_vl/qwen3_vl_moe.py
@@ -227,6 +227,8 @@ def load_weights(self, weights_iterator) -> None:
("embed_tokens.embeddings", "embed_tokens", None),
("lm_head.linear", "lm_head", None),
("visual", "model.visual", None),
("qk_norm.q_norm", "q_norm", None),
("qk_norm.k_norm", "k_norm", None),
]

expert_params_mapping = self.get_expert_mapping() # Not actually used
2 changes: 2 additions & 0 deletions fastdeploy/model_executor/models/qwen3moe.py
@@ -358,6 +358,8 @@ def load_weights(self, weights_iterator) -> None:
("up_gate_proj", "up_proj", "up"),
("embed_tokens.embeddings", "embed_tokens", None),
("lm_head.linear", "lm_head", None),
("qk_norm.q_norm", "q_norm", None),
("qk_norm.k_norm", "k_norm", None),
]
expert_params_mapping = self.get_expert_mapping()
params_dict = dict(self.named_parameters())
8 changes: 5 additions & 3 deletions fastdeploy/model_executor/ops/triton_ops/__init__.py
@@ -15,10 +15,12 @@
"""

try:
from .qk_rmsnorm_fused_kernel import qk_rmsnorm_fused
from .repetition_early_stop_kernel import repetition_early_stopper_kernel
from .wint2_fused_moe import fused_moe_wint2_triton
from .wint2_fused_moe_kernel import moe_wint2_ffn_kernel

__all__ = ["fused_moe_wint2_triton", "moe_wint2_ffn_kernel", "repetition_early_stopper_kernel"]
_TRITON_AVAILABLE = True

__all__ = ["moe_wint2_ffn_kernel", "repetition_early_stopper_kernel", "qk_rmsnorm_fused"]
except:
pass
_TRITON_AVAILABLE = False
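A hypothetical parity check between the Triton kernel and the NumPy reference sketched earlier (qk_rmsnorm_reference). It assumes a GPU build where _TRITON_AVAILABLE is True, that qk_rmsnorm_fused returns the normalized tensor (as its use in QKRMSNorm.forward suggests), and that the kernel accepts the illustrative sizes and float32 inputs below; a real kernel may require half precision:

import numpy as np
import paddle

from fastdeploy.model_executor.ops.triton_ops import _TRITON_AVAILABLE, qk_rmsnorm_fused

if _TRITON_AVAILABLE:
    head_dim, eps = 128, 1e-6
    q_size, kv_size = 32 * head_dim, 4 * head_dim
    qkv = paddle.randn([8, q_size + 2 * kv_size], dtype="float32")
    q_weight = paddle.ones([head_dim], dtype="float32")
    k_weight = paddle.ones([head_dim], dtype="float32")

    # clone() guards against the kernel normalizing qkv in place.
    fused = qk_rmsnorm_fused(qkv.clone(), q_weight, k_weight, eps, q_size, kv_size, head_dim)
    ref = qk_rmsnorm_reference(qkv.numpy(), q_weight.numpy(), k_weight.numpy(), eps, q_size, kv_size, head_dim)
    np.testing.assert_allclose(fused.numpy(), ref, rtol=2e-3, atol=2e-3)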