diff --git a/fastdeploy/model_executor/layers/normalization.py b/fastdeploy/model_executor/layers/normalization.py
index a66172fc1b5..d66113f5705 100644
--- a/fastdeploy/model_executor/layers/normalization.py
+++ b/fastdeploy/model_executor/layers/normalization.py
@@ -29,6 +29,7 @@
 from fastdeploy.config import FDConfig
 from fastdeploy.model_executor.forward_meta import ForwardMeta
+from fastdeploy.model_executor.ops.triton_ops import _TRITON_AVAILABLE, qk_rmsnorm_fused
 
 from .utils import get_tensor
@@ -256,6 +257,92 @@ def forward(
         return out, residual_out
 
 
+class QKRMSNorm(nn.Layer):
+    """
+    Per-head RMSNorm applied to the Q and K slices of the packed QKV output;
+    uses the fused Triton kernel when available, else two RMSNorm layers.
+    """
+
+    def __init__(
+        self,
+        fd_config: FDConfig,
+        head_dim: int,
+        q_size: int,
+        kv_size: int,
+        eps: float = 1e-5,
+        prefix: str = "",
+        begin_norm_axis: int = 1,
+        dtype: str = None,
+    ) -> None:
+        super().__init__()
+        self.fd_config = fd_config
+        self.prefix: str = prefix
+        self.head_dim: int = head_dim
+        self.q_weight_key: Optional[str] = f"{prefix}.q_norm.weight"
+        self.k_weight_key: Optional[str] = f"{prefix}.k_norm.weight"
+        self.eps: float = eps
+        self._norm_weight_dtype = dtype
+        if self._norm_weight_dtype is None:
+            self._norm_weight_dtype = self._helper.get_default_dtype()
+        else:
+            assert dtype in [
+                "float32",
+                "bfloat16",
+                "float16",
+            ], f"Unsupported dtype: {dtype}. Must be one of: float32, bfloat16, float16"
+
+        self.q_size = q_size
+        self.kv_size = kv_size
+
+        self.q_norm = RMSNorm(
+            fd_config,
+            hidden_size=self.head_dim,
+            eps=self.eps,
+            prefix=f"{prefix}.q_norm",
+            begin_norm_axis=begin_norm_axis,
+        )
+        self.k_norm = RMSNorm(
+            fd_config,
+            hidden_size=self.head_dim,
+            eps=self.eps,
+            prefix=f"{prefix}.k_norm",
+            begin_norm_axis=begin_norm_axis,
+        )
+        self.qk_norm_fused = _TRITON_AVAILABLE
+
+    def load_state_dict(self, state_dict):
+        self.q_norm.load_state_dict(state_dict)
+        self.k_norm.load_state_dict(state_dict)
+
+    def forward(
+        self,
+        qkv_out,
+    ) -> paddle.Tensor:
+        if self.qk_norm_fused:
+            qkv_out = qk_rmsnorm_fused(
+                qkv_out,
+                self.q_norm.weight,
+                self.k_norm.weight,
+                self.eps,
+                self.q_size,
+                self.kv_size,
+                self.head_dim,
+            )
+        else:
+            q, k, v = qkv_out.split([self.q_size, self.kv_size, self.kv_size], axis=-1)
+
+            q_by_head = q.reshape([*q.shape[:-1], q.shape[-1] // self.head_dim, self.head_dim])
+            q_by_head = self.q_norm(q_by_head)[0]
+            q = q_by_head.reshape(q.shape)
+
+            k_by_head = k.reshape([*k.shape[:-1], k.shape[-1] // self.head_dim, self.head_dim])
+            k_by_head = self.k_norm(k_by_head)[0]
+            k = k_by_head.reshape(k.shape)
+
+            qkv_out = paddle.concat([q, k, v], axis=-1)
+        return qkv_out
+
+
 class LayerNorm(nn.Layer):
     """
     Initializes the LayerNormalization layer
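For reference, a quick parity sketch (illustrative, not part of the patch) of what the fused path must agree with: the unfused fallback applies plain per-head RMSNorm to the Q and K slices and leaves V untouched. Sizes, dtype, and tolerance are assumptions, and a GPU build with Triton is required:

    import paddle
    from fastdeploy.model_executor.ops.triton_ops import qk_rmsnorm_fused

    head_dim, q_size, kv_size, eps = 128, 8 * 128, 2 * 128, 1e-6  # 8 Q heads, 2 KV heads
    qkv = paddle.randn([4, q_size + 2 * kv_size]).astype("bfloat16")
    q_w = paddle.rand([head_dim]).astype("bfloat16")
    k_w = paddle.rand([head_dim]).astype("bfloat16")

    def ref_rmsnorm(x, w):  # RMSNorm over the last (head_dim) axis, computed in fp32
        x32 = x.astype("float32")
        var = (x32 * x32).mean(axis=-1, keepdim=True)
        return x32 * paddle.rsqrt(var + eps) * w.astype("float32")

    q, k, v = qkv.split([q_size, kv_size, kv_size], axis=-1)
    q_ref = ref_rmsnorm(q.reshape([-1, q_size // head_dim, head_dim]), q_w).reshape(q.shape)
    k_ref = ref_rmsnorm(k.reshape([-1, kv_size // head_dim, head_dim]), k_w).reshape(k.shape)
    ref = paddle.concat([q_ref, k_ref, v.astype("float32")], axis=-1)

    out = qk_rmsnorm_fused(qkv.clone(), q_w, k_w, eps, q_size, kv_size, head_dim)
    assert paddle.allclose(out.astype("float32"), ref, atol=2e-2).item()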
diff --git a/fastdeploy/model_executor/models/glm4_moe.py b/fastdeploy/model_executor/models/glm4_moe.py
index 32775a3e6cf..6bf11bccb0d 100644
--- a/fastdeploy/model_executor/models/glm4_moe.py
+++ b/fastdeploy/model_executor/models/glm4_moe.py
@@ -41,7 +41,7 @@
 )
 from fastdeploy.model_executor.layers.lm_head import ParallelLMHead
 from fastdeploy.model_executor.layers.moe.moe import FusedMoE
-from fastdeploy.model_executor.layers.normalization import RMSNorm
+from fastdeploy.model_executor.layers.normalization import QKRMSNorm, RMSNorm
 from fastdeploy.model_executor.models.model_base import (
     ModelCategory,
     ModelForCasualLM,
@@ -205,18 +205,13 @@ def __init__(self, fd_config: FDConfig, layer_id: int, prefix: str = "") -> None
             rms_norm_eps=fd_config.model_config.rms_norm_eps,
         )
         if self.use_qk_norm:
-            self.q_norm = RMSNorm(
+            self.qk_norm = QKRMSNorm(
                 fd_config,
-                hidden_size=self.head_dim,
+                head_dim=self.head_dim,
+                q_size=self.q_size,
+                kv_size=self.kv_size,
                 eps=fd_config.model_config.rms_norm_eps,
-                prefix=f"{prefix}.q_norm",
-                begin_norm_axis=2,
-            )
-            self.k_norm = RMSNorm(
-                fd_config,
-                hidden_size=self.head_dim,
-                eps=fd_config.model_config.rms_norm_eps,
-                prefix=f"{prefix}.k_norm",
+                prefix=prefix,
                 begin_norm_axis=2,
             )
 
@@ -227,13 +222,9 @@ def forward(
     ):
         """ """
         qkv_out = self.qkv_proj(hidden_states)
-        if self.use_qk_norm:
-            q, k, v = qkv_out.split([self.q_size, self.kv_size, self.kv_size], axis=-1)
-            q = self.q_norm(q.reshape([-1, self.num_heads, self.head_dim]))[0].reshape(q.shape)
-            k = self.k_norm(k.reshape([-1, self.num_kv_heads, self.head_dim]))[0].reshape(k.shape)
-            qkv_out = paddle.concat([q, k, v], axis=-1)
-
+        if self.use_qk_norm:
+            qkv_out = self.qk_norm(qkv_out)
         atten_out = self.attn(
             qkv=qkv_out,
             forward_meta=forward_meta,
         )
@@ -435,6 +425,11 @@ def load_weights(self, weights_iterator) -> None:
             ("lm_head.linear", "lm_head", None),
             ("experts.gate_correction_bias", "gate.e_score_correction_bias", None),
         ]
+
+        if self.fd_config.model_config.use_qk_norm:
+            stacked_params_mapping.append(("qk_norm.q_norm", "q_norm", None))
+            stacked_params_mapping.append(("qk_norm.k_norm", "k_norm", None))
+
         # (param_name, weight_name, expert_id, shard_id)
         expert_params_mapping = FusedMoE.make_expert_params_mapping(
             num_experts=self.fd_config.model_config.n_routed_experts,
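Both call sites now hand per-rank slice sizes to QKRMSNorm. As a worked example of the sizing formulas with assumed dimensions (32 attention heads, 8 KV heads, head_dim 128, tensor parallel size 2; the numbers, not the formulas, are hypothetical):

    head_dim, num_heads, num_kv_heads, tp_size = 128, 32, 8, 2
    num_kv_heads_replicas = max(1, tp_size // num_kv_heads)               # 1
    q_size = num_heads * head_dim // tp_size                              # 16 heads -> 2048
    kv_size = num_kv_heads * head_dim * num_kv_heads_replicas // tp_size  # 4 heads -> 512
    assert q_size + 2 * kv_size == 3072  # per-rank packed QKV width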
diff --git a/fastdeploy/model_executor/models/qwen3.py b/fastdeploy/model_executor/models/qwen3.py
index f148235081f..a33161158c7 100644
--- a/fastdeploy/model_executor/models/qwen3.py
+++ b/fastdeploy/model_executor/models/qwen3.py
@@ -33,7 +33,7 @@
 from fastdeploy.model_executor.layers.embeddings import VocabParallelEmbedding
 from fastdeploy.model_executor.layers.linear import QKVParallelLinear, RowParallelLinear
 from fastdeploy.model_executor.layers.lm_head import ParallelLMHead
-from fastdeploy.model_executor.layers.normalization import RMSNorm
+from fastdeploy.model_executor.layers.normalization import QKRMSNorm, RMSNorm
 from fastdeploy.model_executor.models.model_base import (
     ModelCategory,
     ModelForCasualLM,
@@ -57,6 +57,10 @@ def __init__(self, fd_config: FDConfig, layer_id: int, prefix: str = "") -> None
         self.fd_config = fd_config
         self.head_dim = fd_config.model_config.head_dim
+        tp_size = fd_config.parallel_config.tensor_parallel_size
+        num_kv_heads_replicas = max(1, tp_size // fd_config.model_config.num_key_value_heads)
+        self.q_size = fd_config.model_config.num_attention_heads * self.head_dim // tp_size
+        self.kv_size = fd_config.model_config.num_key_value_heads * self.head_dim * num_kv_heads_replicas // tp_size
 
         self.qkv_proj = QKVParallelLinear(fd_config, prefix=f"{prefix}.qkv_proj", with_bias=False)
@@ -75,32 +79,21 @@ def __init__(self, fd_config: FDConfig, layer_id: int, prefix: str = "") -> None
             use_neox_rotary_style=True,
         )
 
-        self.q_norm = RMSNorm(
+        self.qk_norm = QKRMSNorm(
             fd_config,
-            hidden_size=self.head_dim,
+            head_dim=self.head_dim,
+            q_size=self.q_size,
+            kv_size=self.kv_size,
             eps=fd_config.model_config.rms_norm_eps,
-            prefix=f"{prefix}.q_norm",
-            begin_norm_axis=2,
-        )
-        self.k_norm = RMSNorm(
-            fd_config,
-            hidden_size=self.head_dim,
-            eps=fd_config.model_config.rms_norm_eps,
-            prefix=f"{prefix}.k_norm",
+            prefix=prefix,
             begin_norm_axis=2,
         )
 
-        tp_size = fd_config.parallel_config.tensor_parallel_size
-        num_kv_heads_replicas = max(1, tp_size // fd_config.model_config.num_key_value_heads)
-        self.q_size = fd_config.model_config.num_attention_heads * self.head_dim // tp_size
-        self.kv_size = fd_config.model_config.num_key_value_heads * self.head_dim * num_kv_heads_replicas // tp_size
-
     def load_state_dict(self, state_dict):
         """ """
         self.qkv_proj.load_state_dict(state_dict)
         self.o_proj.load_state_dict(state_dict)
-        self.q_norm.load_state_dict(state_dict)
-        self.k_norm.load_state_dict(state_dict)
+        self.qk_norm.load_state_dict(state_dict)
         self.attn.load_state_dict(state_dict)
 
     def forward(
@@ -110,19 +103,7 @@ def forward(
     ):
         """ """
         qkv_out = self.qkv_proj(hidden_states)
-        # origin_qkv_out = qkv_out
-        q, k, v = qkv_out.split([self.q_size, self.kv_size, self.kv_size], axis=-1)
-
-        q_by_head = q.reshape([*q.shape[:-1], q.shape[-1] // self.head_dim, self.head_dim])
-        q_by_head = self.q_norm(q_by_head)[0]
-        q = q_by_head.reshape(q.shape)
-
-        k_by_head = k.reshape([*k.shape[:-1], k.shape[-1] // self.head_dim, self.head_dim])
-        k_by_head = self.k_norm(k_by_head)[0]
-        k = k_by_head.reshape(k.shape)
-
-        qkv_out = paddle.concat([q, k, v], axis=-1)
-
+        qkv_out = self.qk_norm(qkv_out)
         atten_out = self.attn(
             qkv=qkv_out,
             forward_meta=forward_meta,
@@ -280,6 +261,8 @@ def load_weights(self, weights_iterator) -> None:
             ("up_gate_proj", "up_proj", "up"),
             ("embed_tokens.embeddings", "embed_tokens", None),
             ("lm_head.linear", "lm_head", None),
+            ("qk_norm.q_norm", "q_norm", None),
+            ("qk_norm.k_norm", "k_norm", None),
         ]
 
         params_dict = dict(self.named_parameters())
diff --git a/fastdeploy/model_executor/models/qwen3_vl/qwen3_vl.py b/fastdeploy/model_executor/models/qwen3_vl/qwen3_vl.py
index 41ba2aa0c36..bd746c336fe 100644
--- a/fastdeploy/model_executor/models/qwen3_vl/qwen3_vl.py
+++ b/fastdeploy/model_executor/models/qwen3_vl/qwen3_vl.py
@@ -207,6 +207,8 @@ def load_weights(self, weights_iterator) -> None:
             ("embed_tokens.embeddings", "embed_tokens", None),
             ("lm_head.linear", "lm_head", None),
             ("visual", "model.visual", None),
+            ("qk_norm.q_norm", "q_norm", None),
+            ("qk_norm.k_norm", "k_norm", None),
         ]
 
         params_dict = dict(self.named_parameters())
diff --git a/fastdeploy/model_executor/models/qwen3_vl/qwen3_vl_moe.py b/fastdeploy/model_executor/models/qwen3_vl/qwen3_vl_moe.py
index 8f6e27f7632..36c7544518d 100644
--- a/fastdeploy/model_executor/models/qwen3_vl/qwen3_vl_moe.py
+++ b/fastdeploy/model_executor/models/qwen3_vl/qwen3_vl_moe.py
@@ -227,6 +227,8 @@ def load_weights(self, weights_iterator) -> None:
             ("embed_tokens.embeddings", "embed_tokens", None),
             ("lm_head.linear", "lm_head", None),
             ("visual", "model.visual", None),
+            ("qk_norm.q_norm", "q_norm", None),
+            ("qk_norm.k_norm", "k_norm", None),
         ]
 
         expert_params_mapping = self.get_expert_mapping()  # Not actually used
diff --git a/fastdeploy/model_executor/models/qwen3moe.py b/fastdeploy/model_executor/models/qwen3moe.py
index 9301233bb69..f19bdba1b84 100644
--- a/fastdeploy/model_executor/models/qwen3moe.py
+++ b/fastdeploy/model_executor/models/qwen3moe.py
@@ -358,6 +358,8 @@ def load_weights(self, weights_iterator) -> None:
             ("up_gate_proj", "up_proj", "up"),
             ("embed_tokens.embeddings", "embed_tokens", None),
             ("lm_head.linear", "lm_head", None),
+            ("qk_norm.q_norm", "q_norm", None),
+            ("qk_norm.k_norm", "k_norm", None),
         ]
         expert_params_mapping = self.get_expert_mapping()
         params_dict = dict(self.named_parameters())
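All four model files add the same two mapping entries because checkpoints still store the flat q_norm/k_norm keys while the parameters now live under the qk_norm submodule. Roughly, the loader's substring rewrite does this (simplified sketch):

    ckpt_key = "model.layers.0.self_attn.q_norm.weight"
    param_name, weight_name = "qk_norm.q_norm", "q_norm"
    model_key = ckpt_key.replace(weight_name, param_name)
    assert model_key == "model.layers.0.self_attn.qk_norm.q_norm.weight"

The same rename is what the e2e baseline update at the end of this diff reflects.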
diff --git a/fastdeploy/model_executor/ops/triton_ops/__init__.py b/fastdeploy/model_executor/ops/triton_ops/__init__.py
index 3481c30caa6..47b069a8b2c 100644
--- a/fastdeploy/model_executor/ops/triton_ops/__init__.py
+++ b/fastdeploy/model_executor/ops/triton_ops/__init__.py
@@ -15,10 +15,12 @@
 """
 
 try:
+    from .qk_rmsnorm_fused_kernel import qk_rmsnorm_fused
     from .repetition_early_stop_kernel import repetition_early_stopper_kernel
-    from .wint2_fused_moe import fused_moe_wint2_triton
     from .wint2_fused_moe_kernel import moe_wint2_ffn_kernel
 
-    __all__ = ["fused_moe_wint2_triton", "moe_wint2_ffn_kernel", "repetition_early_stopper_kernel"]
+    _TRITON_AVAILABLE = True
+
+    __all__ = ["moe_wint2_ffn_kernel", "repetition_early_stopper_kernel", "qk_rmsnorm_fused"]
 except:
-    pass
+    _TRITON_AVAILABLE = False
diff --git a/fastdeploy/model_executor/ops/triton_ops/qk_rmsnorm_fused_kernel.py b/fastdeploy/model_executor/ops/triton_ops/qk_rmsnorm_fused_kernel.py
new file mode 100644
index 00000000000..255e87b9c8f
--- /dev/null
+++ b/fastdeploy/model_executor/ops/triton_ops/qk_rmsnorm_fused_kernel.py
@@ -0,0 +1,130 @@
+"""
+# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+
+import triton
+import triton.language as tl
+
+from fastdeploy.model_executor.ops.triton_ops.triton_utils import (
+    enable_compat_on_triton_kernel,
+)
+from fastdeploy.utils import ceil_div
+
+
+@enable_compat_on_triton_kernel
+@triton.jit
+def qk_rmsnorm_fused_kernel(
+    x_ptr,
+    q_weight_ptr,
+    k_weight_ptr,
+    M,
+    q_size,
+    kv_size,
+    eps,
+    num_q_heads: tl.constexpr,
+    num_kv_heads: tl.constexpr,
+    head_dim: tl.constexpr,
+    BLOCK_HEADS: tl.constexpr,
+):
+    pid = tl.program_id(0)
+
+    blocks_per_token = tl.cdiv(num_q_heads, BLOCK_HEADS)
+    token_id = pid // blocks_per_token
+    head_block = pid % blocks_per_token
+
+    if token_id >= M:
+        return
+
+    offs_h = tl.arange(0, BLOCK_HEADS)
+    offs_d = tl.arange(0, head_dim)
+
+    head_ids = head_block * BLOCK_HEADS + offs_h
+
+    q_mask = head_ids < num_q_heads
+    kv_mask = head_ids < num_kv_heads
+
+    row_base = token_id * (q_size + 2 * kv_size)
+
+    # -------------------
+    # Q RMSNorm
+    # -------------------
+    q_ptrs = x_ptr + row_base + head_ids[:, None] * head_dim + offs_d[None, :]
+
+    q = tl.load(q_ptrs, mask=q_mask[:, None], other=0.0).to(tl.float32)
+    q_var = tl.sum(q * q, axis=1) / head_dim
+    q_hat = q * tl.rsqrt(q_var[:, None] + eps)
+
+    q_w = tl.load(q_weight_ptr + offs_d).to(tl.float32)
+    q_out = q_hat * q_w[None, :]
+
+    tl.store(
+        q_ptrs,
+        q_out,
+        mask=q_mask[:, None],
+    )
+
+    # -------------------
+    # K RMSNorm
+    # -------------------
+    k_ptrs = x_ptr + row_base + q_size + head_ids[:, None] * head_dim + offs_d[None, :]
+
+    k = tl.load(k_ptrs, mask=kv_mask[:, None], other=0.0).to(tl.float32)
+    k_var = tl.sum(k * k, axis=1) / head_dim
+    k_hat = k * tl.rsqrt(k_var[:, None] + eps)
+
+    k_w = tl.load(k_weight_ptr + offs_d).to(tl.float32)
+    k_out = k_hat * k_w[None, :]
+
+    tl.store(
+        k_ptrs,
+        k_out,
+        mask=kv_mask[:, None],
+    )
+
+
+def qk_rmsnorm_fused(
+    qkv_out,
+    q_norm_weight,
+    k_norm_weight,
+    eps,
+    q_size,
+    kv_size,
+    head_dim,
+):
+    assert qkv_out.ndim == 2
+    M, _ = qkv_out.shape
+
+    num_q_heads = q_size // head_dim
+    num_kv_heads = kv_size // head_dim
+
+    BLOCK_HEADS = 4 if num_q_heads <= 32 else 8
+
+    grid = (M * ceil_div(num_q_heads, BLOCK_HEADS),)
+
+    qk_rmsnorm_fused_kernel[grid](
+        x_ptr=qkv_out,
+        q_weight_ptr=q_norm_weight,
+        k_weight_ptr=k_norm_weight,
+        M=M,
+        q_size=q_size,
+        kv_size=kv_size,
+        eps=eps,
+        num_q_heads=num_q_heads,
+        num_kv_heads=num_kv_heads,
+        head_dim=head_dim,
+        BLOCK_HEADS=BLOCK_HEADS,
+        num_warps=2,
+    )
+    return qkv_out
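To make the launch geometry concrete (same assumed per-rank sizes as the earlier worked example): each program normalizes up to BLOCK_HEADS query heads and the matching block of key heads for one token; the K side is fully covered because num_kv_heads <= num_q_heads in the GQA models touched here:

    M, q_size, kv_size, head_dim = 256, 2048, 512, 128
    num_q_heads = q_size // head_dim              # 16
    num_kv_heads = kv_size // head_dim            # 4, handled by the first block only
    BLOCK_HEADS = 4 if num_q_heads <= 32 else 8   # 4
    grid = M * -(-num_q_heads // BLOCK_HEADS)     # ceil-div: 256 * 4 = 1024 programs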
diff --git a/fastdeploy/model_executor/ops/triton_ops/triton_utils.py b/fastdeploy/model_executor/ops/triton_ops/triton_utils.py
index a61268044bd..f92e8d9c94f 100644
--- a/fastdeploy/model_executor/ops/triton_ops/triton_utils.py
+++ b/fastdeploy/model_executor/ops/triton_ops/triton_utils.py
@@ -30,6 +30,17 @@
 python_path = sys.executable
 
 
+def enable_compat_on_triton_kernel(triton_kernel):
+    class WrappedTritonKernel:
+        def __init__(self, kernel):
+            self.kernel = kernel
+
+        def __getitem__(self, index):
+            return paddle.use_compat_guard(enable=True, silent=True)(self.kernel[index])
+
+    return WrappedTritonKernel(triton_kernel)
+
+
 def SubstituteTemplate(template, values):
     """
     Substitute all variables in the given template string using the provided values dictionary.
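The wrapper leaves the kernel itself alone and only intercepts the kernel[grid] launch step so the call runs under paddle's compat guard. A minimal usage sketch with a hypothetical toy kernel (assumes Triton and a paddle build that provides use_compat_guard, as the wrapper requires):

    import triton
    import triton.language as tl

    from fastdeploy.model_executor.ops.triton_ops.triton_utils import (
        enable_compat_on_triton_kernel,
    )

    @enable_compat_on_triton_kernel
    @triton.jit
    def scale_kernel(x_ptr, n, s, BLOCK: tl.constexpr):
        # scales n elements of x in place by s
        offs = tl.program_id(0) * BLOCK + tl.arange(0, BLOCK)
        mask = offs < n
        tl.store(x_ptr + offs, tl.load(x_ptr + offs, mask=mask) * s, mask=mask)

    # scale_kernel[(num_blocks,)](x, n, 2.0, BLOCK=128) now launches inside the guard.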
diff --git a/tests/e2e/Qwen3VLMOE_RL/baseline.txt b/tests/e2e/Qwen3VLMOE_RL/baseline.txt
index 697490b7e40..8f299cc1e45 100644
--- a/tests/e2e/Qwen3VLMOE_RL/baseline.txt
+++ b/tests/e2e/Qwen3VLMOE_RL/baseline.txt
@@ -14,13 +14,13 @@ model.layers.0.mlp.gate.weight
 model.layers.0.mlp.gate.weight:model.layers.0.mlp.gate.weight
 model.layers.0.post_attention_layernorm.weight
 model.layers.0.post_attention_layernorm.weight:model.layers.0.post_attention_layernorm.weight
-model.layers.0.self_attn.k_norm.weight
-model.layers.0.self_attn.k_norm.weight:model.layers.0.self_attn.k_norm.weight
 model.layers.0.self_attn.o_proj.weight
 model.layers.0.self_attn.o_proj.weight:model.layers.0.self_attn.o_proj.weight
 model.layers.0.self_attn.o_proj.weight_scale
-model.layers.0.self_attn.q_norm.weight
-model.layers.0.self_attn.q_norm.weight:model.layers.0.self_attn.q_norm.weight
+model.layers.0.self_attn.qk_norm.k_norm.weight
+model.layers.0.self_attn.qk_norm.k_norm.weight:model.layers.0.self_attn.qk_norm.k_norm.weight
+model.layers.0.self_attn.qk_norm.q_norm.weight
+model.layers.0.self_attn.qk_norm.q_norm.weight:model.layers.0.self_attn.qk_norm.q_norm.weight
 model.layers.0.self_attn.qkv_proj.weight
 model.layers.0.self_attn.qkv_proj.weight:model.layers.0.self_attn.qkv_proj.weight
 model.layers.0.self_attn.qkv_proj.weight_scale
@@ -36,13 +36,13 @@ model.layers.1.mlp.gate.weight
 model.layers.1.mlp.gate.weight:model.layers.1.mlp.gate.weight
 model.layers.1.post_attention_layernorm.weight
 model.layers.1.post_attention_layernorm.weight:model.layers.1.post_attention_layernorm.weight
-model.layers.1.self_attn.k_norm.weight
-model.layers.1.self_attn.k_norm.weight:model.layers.1.self_attn.k_norm.weight
 model.layers.1.self_attn.o_proj.weight
 model.layers.1.self_attn.o_proj.weight:model.layers.1.self_attn.o_proj.weight
 model.layers.1.self_attn.o_proj.weight_scale
-model.layers.1.self_attn.q_norm.weight
-model.layers.1.self_attn.q_norm.weight:model.layers.1.self_attn.q_norm.weight
+model.layers.1.self_attn.qk_norm.k_norm.weight
+model.layers.1.self_attn.qk_norm.k_norm.weight:model.layers.1.self_attn.qk_norm.k_norm.weight
+model.layers.1.self_attn.qk_norm.q_norm.weight
+model.layers.1.self_attn.qk_norm.q_norm.weight:model.layers.1.self_attn.qk_norm.q_norm.weight
 model.layers.1.self_attn.qkv_proj.weight
 model.layers.1.self_attn.qkv_proj.weight:model.layers.1.self_attn.qkv_proj.weight
 model.layers.1.self_attn.qkv_proj.weight_scale
@@ -58,13 +58,13 @@ model.layers.10.mlp.gate.weight
 model.layers.10.mlp.gate.weight:model.layers.10.mlp.gate.weight
 model.layers.10.post_attention_layernorm.weight
 model.layers.10.post_attention_layernorm.weight:model.layers.10.post_attention_layernorm.weight
-model.layers.10.self_attn.k_norm.weight
-model.layers.10.self_attn.k_norm.weight:model.layers.10.self_attn.k_norm.weight
 model.layers.10.self_attn.o_proj.weight
 model.layers.10.self_attn.o_proj.weight:model.layers.10.self_attn.o_proj.weight
 model.layers.10.self_attn.o_proj.weight_scale
-model.layers.10.self_attn.q_norm.weight
-model.layers.10.self_attn.q_norm.weight:model.layers.10.self_attn.q_norm.weight
+model.layers.10.self_attn.qk_norm.k_norm.weight
+model.layers.10.self_attn.qk_norm.k_norm.weight:model.layers.10.self_attn.qk_norm.k_norm.weight
+model.layers.10.self_attn.qk_norm.q_norm.weight
+model.layers.10.self_attn.qk_norm.q_norm.weight:model.layers.10.self_attn.qk_norm.q_norm.weight
 model.layers.10.self_attn.qkv_proj.weight
 model.layers.10.self_attn.qkv_proj.weight:model.layers.10.self_attn.qkv_proj.weight
 model.layers.10.self_attn.qkv_proj.weight_scale
@@ -80,13 +80,13 @@ model.layers.11.mlp.gate.weight
 model.layers.11.mlp.gate.weight:model.layers.11.mlp.gate.weight
 model.layers.11.post_attention_layernorm.weight
 model.layers.11.post_attention_layernorm.weight:model.layers.11.post_attention_layernorm.weight
-model.layers.11.self_attn.k_norm.weight
-model.layers.11.self_attn.k_norm.weight:model.layers.11.self_attn.k_norm.weight
 model.layers.11.self_attn.o_proj.weight
 model.layers.11.self_attn.o_proj.weight:model.layers.11.self_attn.o_proj.weight
 model.layers.11.self_attn.o_proj.weight_scale
-model.layers.11.self_attn.q_norm.weight
-model.layers.11.self_attn.q_norm.weight:model.layers.11.self_attn.q_norm.weight
+model.layers.11.self_attn.qk_norm.k_norm.weight
+model.layers.11.self_attn.qk_norm.k_norm.weight:model.layers.11.self_attn.qk_norm.k_norm.weight
+model.layers.11.self_attn.qk_norm.q_norm.weight
+model.layers.11.self_attn.qk_norm.q_norm.weight:model.layers.11.self_attn.qk_norm.q_norm.weight
 model.layers.11.self_attn.qkv_proj.weight
 model.layers.11.self_attn.qkv_proj.weight:model.layers.11.self_attn.qkv_proj.weight
 model.layers.11.self_attn.qkv_proj.weight_scale
@@ -102,13 +102,13 @@ model.layers.12.mlp.gate.weight
 model.layers.12.mlp.gate.weight:model.layers.12.mlp.gate.weight
 model.layers.12.post_attention_layernorm.weight
 model.layers.12.post_attention_layernorm.weight:model.layers.12.post_attention_layernorm.weight
-model.layers.12.self_attn.k_norm.weight
-model.layers.12.self_attn.k_norm.weight:model.layers.12.self_attn.k_norm.weight
 model.layers.12.self_attn.o_proj.weight
 model.layers.12.self_attn.o_proj.weight:model.layers.12.self_attn.o_proj.weight
 model.layers.12.self_attn.o_proj.weight_scale
-model.layers.12.self_attn.q_norm.weight
-model.layers.12.self_attn.q_norm.weight:model.layers.12.self_attn.q_norm.weight
+model.layers.12.self_attn.qk_norm.k_norm.weight
+model.layers.12.self_attn.qk_norm.k_norm.weight:model.layers.12.self_attn.qk_norm.k_norm.weight
+model.layers.12.self_attn.qk_norm.q_norm.weight
+model.layers.12.self_attn.qk_norm.q_norm.weight:model.layers.12.self_attn.qk_norm.q_norm.weight
 model.layers.12.self_attn.qkv_proj.weight
 model.layers.12.self_attn.qkv_proj.weight:model.layers.12.self_attn.qkv_proj.weight
 model.layers.12.self_attn.qkv_proj.weight_scale
@@ -124,13 +124,13 @@ model.layers.13.mlp.gate.weight
 model.layers.13.mlp.gate.weight:model.layers.13.mlp.gate.weight
 model.layers.13.post_attention_layernorm.weight
 model.layers.13.post_attention_layernorm.weight:model.layers.13.post_attention_layernorm.weight
-model.layers.13.self_attn.k_norm.weight
-model.layers.13.self_attn.k_norm.weight:model.layers.13.self_attn.k_norm.weight
 model.layers.13.self_attn.o_proj.weight
 model.layers.13.self_attn.o_proj.weight:model.layers.13.self_attn.o_proj.weight
 model.layers.13.self_attn.o_proj.weight_scale
-model.layers.13.self_attn.q_norm.weight
-model.layers.13.self_attn.q_norm.weight:model.layers.13.self_attn.q_norm.weight
+model.layers.13.self_attn.qk_norm.k_norm.weight
+model.layers.13.self_attn.qk_norm.k_norm.weight:model.layers.13.self_attn.qk_norm.k_norm.weight
+model.layers.13.self_attn.qk_norm.q_norm.weight
+model.layers.13.self_attn.qk_norm.q_norm.weight:model.layers.13.self_attn.qk_norm.q_norm.weight
 model.layers.13.self_attn.qkv_proj.weight
 model.layers.13.self_attn.qkv_proj.weight:model.layers.13.self_attn.qkv_proj.weight
 model.layers.13.self_attn.qkv_proj.weight_scale
@@ -146,13 +146,13 @@ model.layers.14.mlp.gate.weight
 model.layers.14.mlp.gate.weight:model.layers.14.mlp.gate.weight
 model.layers.14.post_attention_layernorm.weight
 model.layers.14.post_attention_layernorm.weight:model.layers.14.post_attention_layernorm.weight
-model.layers.14.self_attn.k_norm.weight
-model.layers.14.self_attn.k_norm.weight:model.layers.14.self_attn.k_norm.weight
 model.layers.14.self_attn.o_proj.weight
 model.layers.14.self_attn.o_proj.weight:model.layers.14.self_attn.o_proj.weight
 model.layers.14.self_attn.o_proj.weight_scale
-model.layers.14.self_attn.q_norm.weight
-model.layers.14.self_attn.q_norm.weight:model.layers.14.self_attn.q_norm.weight
+model.layers.14.self_attn.qk_norm.k_norm.weight
+model.layers.14.self_attn.qk_norm.k_norm.weight:model.layers.14.self_attn.qk_norm.k_norm.weight
+model.layers.14.self_attn.qk_norm.q_norm.weight
+model.layers.14.self_attn.qk_norm.q_norm.weight:model.layers.14.self_attn.qk_norm.q_norm.weight
 model.layers.14.self_attn.qkv_proj.weight
 model.layers.14.self_attn.qkv_proj.weight:model.layers.14.self_attn.qkv_proj.weight
 model.layers.14.self_attn.qkv_proj.weight_scale
@@ -168,13 +168,13 @@ model.layers.15.mlp.gate.weight
 model.layers.15.mlp.gate.weight:model.layers.15.mlp.gate.weight
 model.layers.15.post_attention_layernorm.weight
 model.layers.15.post_attention_layernorm.weight:model.layers.15.post_attention_layernorm.weight
-model.layers.15.self_attn.k_norm.weight
-model.layers.15.self_attn.k_norm.weight:model.layers.15.self_attn.k_norm.weight
 model.layers.15.self_attn.o_proj.weight
 model.layers.15.self_attn.o_proj.weight:model.layers.15.self_attn.o_proj.weight
 model.layers.15.self_attn.o_proj.weight_scale
-model.layers.15.self_attn.q_norm.weight
-model.layers.15.self_attn.q_norm.weight:model.layers.15.self_attn.q_norm.weight
+model.layers.15.self_attn.qk_norm.k_norm.weight
+model.layers.15.self_attn.qk_norm.k_norm.weight:model.layers.15.self_attn.qk_norm.k_norm.weight
+model.layers.15.self_attn.qk_norm.q_norm.weight
+model.layers.15.self_attn.qk_norm.q_norm.weight:model.layers.15.self_attn.qk_norm.q_norm.weight
 model.layers.15.self_attn.qkv_proj.weight
 model.layers.15.self_attn.qkv_proj.weight:model.layers.15.self_attn.qkv_proj.weight
 model.layers.15.self_attn.qkv_proj.weight_scale
@@ -190,13 +190,13 @@ model.layers.16.mlp.gate.weight
 model.layers.16.mlp.gate.weight:model.layers.16.mlp.gate.weight
 model.layers.16.post_attention_layernorm.weight
 model.layers.16.post_attention_layernorm.weight:model.layers.16.post_attention_layernorm.weight
-model.layers.16.self_attn.k_norm.weight
-model.layers.16.self_attn.k_norm.weight:model.layers.16.self_attn.k_norm.weight
 model.layers.16.self_attn.o_proj.weight
 model.layers.16.self_attn.o_proj.weight:model.layers.16.self_attn.o_proj.weight
 model.layers.16.self_attn.o_proj.weight_scale
-model.layers.16.self_attn.q_norm.weight
-model.layers.16.self_attn.q_norm.weight:model.layers.16.self_attn.q_norm.weight
+model.layers.16.self_attn.qk_norm.k_norm.weight
+model.layers.16.self_attn.qk_norm.k_norm.weight:model.layers.16.self_attn.qk_norm.k_norm.weight
+model.layers.16.self_attn.qk_norm.q_norm.weight
+model.layers.16.self_attn.qk_norm.q_norm.weight:model.layers.16.self_attn.qk_norm.q_norm.weight
 model.layers.16.self_attn.qkv_proj.weight
 model.layers.16.self_attn.qkv_proj.weight:model.layers.16.self_attn.qkv_proj.weight
 model.layers.16.self_attn.qkv_proj.weight_scale
@@ -212,13 +212,13 @@ model.layers.17.mlp.gate.weight
 model.layers.17.mlp.gate.weight:model.layers.17.mlp.gate.weight
 model.layers.17.post_attention_layernorm.weight
 model.layers.17.post_attention_layernorm.weight:model.layers.17.post_attention_layernorm.weight
-model.layers.17.self_attn.k_norm.weight
-model.layers.17.self_attn.k_norm.weight:model.layers.17.self_attn.k_norm.weight
 model.layers.17.self_attn.o_proj.weight
 model.layers.17.self_attn.o_proj.weight:model.layers.17.self_attn.o_proj.weight
 model.layers.17.self_attn.o_proj.weight_scale
-model.layers.17.self_attn.q_norm.weight
-model.layers.17.self_attn.q_norm.weight:model.layers.17.self_attn.q_norm.weight
+model.layers.17.self_attn.qk_norm.k_norm.weight
+model.layers.17.self_attn.qk_norm.k_norm.weight:model.layers.17.self_attn.qk_norm.k_norm.weight
+model.layers.17.self_attn.qk_norm.q_norm.weight
+model.layers.17.self_attn.qk_norm.q_norm.weight:model.layers.17.self_attn.qk_norm.q_norm.weight
 model.layers.17.self_attn.qkv_proj.weight
 model.layers.17.self_attn.qkv_proj.weight:model.layers.17.self_attn.qkv_proj.weight
 model.layers.17.self_attn.qkv_proj.weight_scale
@@ -234,13 +234,13 @@ model.layers.18.mlp.gate.weight
 model.layers.18.mlp.gate.weight:model.layers.18.mlp.gate.weight
 model.layers.18.post_attention_layernorm.weight
 model.layers.18.post_attention_layernorm.weight:model.layers.18.post_attention_layernorm.weight
-model.layers.18.self_attn.k_norm.weight
-model.layers.18.self_attn.k_norm.weight:model.layers.18.self_attn.k_norm.weight
 model.layers.18.self_attn.o_proj.weight
 model.layers.18.self_attn.o_proj.weight:model.layers.18.self_attn.o_proj.weight
 model.layers.18.self_attn.o_proj.weight_scale
-model.layers.18.self_attn.q_norm.weight
-model.layers.18.self_attn.q_norm.weight:model.layers.18.self_attn.q_norm.weight
+model.layers.18.self_attn.qk_norm.k_norm.weight
+model.layers.18.self_attn.qk_norm.k_norm.weight:model.layers.18.self_attn.qk_norm.k_norm.weight
+model.layers.18.self_attn.qk_norm.q_norm.weight
+model.layers.18.self_attn.qk_norm.q_norm.weight:model.layers.18.self_attn.qk_norm.q_norm.weight
 model.layers.18.self_attn.qkv_proj.weight
 model.layers.18.self_attn.qkv_proj.weight:model.layers.18.self_attn.qkv_proj.weight
 model.layers.18.self_attn.qkv_proj.weight_scale
@@ -256,13 +256,13 @@ model.layers.19.mlp.gate.weight
 model.layers.19.mlp.gate.weight:model.layers.19.mlp.gate.weight
 model.layers.19.post_attention_layernorm.weight
 model.layers.19.post_attention_layernorm.weight:model.layers.19.post_attention_layernorm.weight
-model.layers.19.self_attn.k_norm.weight
-model.layers.19.self_attn.k_norm.weight:model.layers.19.self_attn.k_norm.weight
 model.layers.19.self_attn.o_proj.weight
 model.layers.19.self_attn.o_proj.weight:model.layers.19.self_attn.o_proj.weight
 model.layers.19.self_attn.o_proj.weight_scale
-model.layers.19.self_attn.q_norm.weight
-model.layers.19.self_attn.q_norm.weight:model.layers.19.self_attn.q_norm.weight
+model.layers.19.self_attn.qk_norm.k_norm.weight
+model.layers.19.self_attn.qk_norm.k_norm.weight:model.layers.19.self_attn.qk_norm.k_norm.weight
+model.layers.19.self_attn.qk_norm.q_norm.weight
+model.layers.19.self_attn.qk_norm.q_norm.weight:model.layers.19.self_attn.qk_norm.q_norm.weight
 model.layers.19.self_attn.qkv_proj.weight
 model.layers.19.self_attn.qkv_proj.weight:model.layers.19.self_attn.qkv_proj.weight
 model.layers.19.self_attn.qkv_proj.weight_scale
@@ -278,13 +278,13 @@ model.layers.2.mlp.gate.weight
 model.layers.2.mlp.gate.weight:model.layers.2.mlp.gate.weight
 model.layers.2.post_attention_layernorm.weight
 model.layers.2.post_attention_layernorm.weight:model.layers.2.post_attention_layernorm.weight
-model.layers.2.self_attn.k_norm.weight
-model.layers.2.self_attn.k_norm.weight:model.layers.2.self_attn.k_norm.weight
 model.layers.2.self_attn.o_proj.weight
 model.layers.2.self_attn.o_proj.weight:model.layers.2.self_attn.o_proj.weight
 model.layers.2.self_attn.o_proj.weight_scale
-model.layers.2.self_attn.q_norm.weight
-model.layers.2.self_attn.q_norm.weight:model.layers.2.self_attn.q_norm.weight
+model.layers.2.self_attn.qk_norm.k_norm.weight
+model.layers.2.self_attn.qk_norm.k_norm.weight:model.layers.2.self_attn.qk_norm.k_norm.weight
+model.layers.2.self_attn.qk_norm.q_norm.weight
+model.layers.2.self_attn.qk_norm.q_norm.weight:model.layers.2.self_attn.qk_norm.q_norm.weight
 model.layers.2.self_attn.qkv_proj.weight
 model.layers.2.self_attn.qkv_proj.weight:model.layers.2.self_attn.qkv_proj.weight
 model.layers.2.self_attn.qkv_proj.weight_scale
@@ -300,13 +300,13 @@ model.layers.20.mlp.gate.weight
 model.layers.20.mlp.gate.weight:model.layers.20.mlp.gate.weight
 model.layers.20.post_attention_layernorm.weight
 model.layers.20.post_attention_layernorm.weight:model.layers.20.post_attention_layernorm.weight
-model.layers.20.self_attn.k_norm.weight
-model.layers.20.self_attn.k_norm.weight:model.layers.20.self_attn.k_norm.weight
 model.layers.20.self_attn.o_proj.weight
 model.layers.20.self_attn.o_proj.weight:model.layers.20.self_attn.o_proj.weight
 model.layers.20.self_attn.o_proj.weight_scale
-model.layers.20.self_attn.q_norm.weight
-model.layers.20.self_attn.q_norm.weight:model.layers.20.self_attn.q_norm.weight
+model.layers.20.self_attn.qk_norm.k_norm.weight
+model.layers.20.self_attn.qk_norm.k_norm.weight:model.layers.20.self_attn.qk_norm.k_norm.weight
+model.layers.20.self_attn.qk_norm.q_norm.weight
+model.layers.20.self_attn.qk_norm.q_norm.weight:model.layers.20.self_attn.qk_norm.q_norm.weight
 model.layers.20.self_attn.qkv_proj.weight
 model.layers.20.self_attn.qkv_proj.weight:model.layers.20.self_attn.qkv_proj.weight
 model.layers.20.self_attn.qkv_proj.weight_scale
@@ -322,13 +322,13 @@ model.layers.21.mlp.gate.weight
 model.layers.21.mlp.gate.weight:model.layers.21.mlp.gate.weight
 model.layers.21.post_attention_layernorm.weight
 model.layers.21.post_attention_layernorm.weight:model.layers.21.post_attention_layernorm.weight
-model.layers.21.self_attn.k_norm.weight
-model.layers.21.self_attn.k_norm.weight:model.layers.21.self_attn.k_norm.weight
 model.layers.21.self_attn.o_proj.weight
 model.layers.21.self_attn.o_proj.weight:model.layers.21.self_attn.o_proj.weight
 model.layers.21.self_attn.o_proj.weight_scale
-model.layers.21.self_attn.q_norm.weight
-model.layers.21.self_attn.q_norm.weight:model.layers.21.self_attn.q_norm.weight
+model.layers.21.self_attn.qk_norm.k_norm.weight
+model.layers.21.self_attn.qk_norm.k_norm.weight:model.layers.21.self_attn.qk_norm.k_norm.weight
+model.layers.21.self_attn.qk_norm.q_norm.weight
+model.layers.21.self_attn.qk_norm.q_norm.weight:model.layers.21.self_attn.qk_norm.q_norm.weight
 model.layers.21.self_attn.qkv_proj.weight
 model.layers.21.self_attn.qkv_proj.weight:model.layers.21.self_attn.qkv_proj.weight
 model.layers.21.self_attn.qkv_proj.weight_scale
@@ -344,13 +344,13 @@ model.layers.22.mlp.gate.weight
 model.layers.22.mlp.gate.weight:model.layers.22.mlp.gate.weight
 model.layers.22.post_attention_layernorm.weight
 model.layers.22.post_attention_layernorm.weight:model.layers.22.post_attention_layernorm.weight
-model.layers.22.self_attn.k_norm.weight
-model.layers.22.self_attn.k_norm.weight:model.layers.22.self_attn.k_norm.weight
 model.layers.22.self_attn.o_proj.weight
 model.layers.22.self_attn.o_proj.weight:model.layers.22.self_attn.o_proj.weight
 model.layers.22.self_attn.o_proj.weight_scale
-model.layers.22.self_attn.q_norm.weight
-model.layers.22.self_attn.q_norm.weight:model.layers.22.self_attn.q_norm.weight
+model.layers.22.self_attn.qk_norm.k_norm.weight
+model.layers.22.self_attn.qk_norm.k_norm.weight:model.layers.22.self_attn.qk_norm.k_norm.weight
+model.layers.22.self_attn.qk_norm.q_norm.weight
+model.layers.22.self_attn.qk_norm.q_norm.weight:model.layers.22.self_attn.qk_norm.q_norm.weight
 model.layers.22.self_attn.qkv_proj.weight
 model.layers.22.self_attn.qkv_proj.weight:model.layers.22.self_attn.qkv_proj.weight
 model.layers.22.self_attn.qkv_proj.weight_scale
@@ -366,13 +366,13 @@ model.layers.23.mlp.gate.weight
 model.layers.23.mlp.gate.weight:model.layers.23.mlp.gate.weight
 model.layers.23.post_attention_layernorm.weight
 model.layers.23.post_attention_layernorm.weight:model.layers.23.post_attention_layernorm.weight
-model.layers.23.self_attn.k_norm.weight
-model.layers.23.self_attn.k_norm.weight:model.layers.23.self_attn.k_norm.weight
 model.layers.23.self_attn.o_proj.weight
 model.layers.23.self_attn.o_proj.weight:model.layers.23.self_attn.o_proj.weight
 model.layers.23.self_attn.o_proj.weight_scale
-model.layers.23.self_attn.q_norm.weight
-model.layers.23.self_attn.q_norm.weight:model.layers.23.self_attn.q_norm.weight
+model.layers.23.self_attn.qk_norm.k_norm.weight
+model.layers.23.self_attn.qk_norm.k_norm.weight:model.layers.23.self_attn.qk_norm.k_norm.weight
+model.layers.23.self_attn.qk_norm.q_norm.weight
+model.layers.23.self_attn.qk_norm.q_norm.weight:model.layers.23.self_attn.qk_norm.q_norm.weight
 model.layers.23.self_attn.qkv_proj.weight
 model.layers.23.self_attn.qkv_proj.weight:model.layers.23.self_attn.qkv_proj.weight
 model.layers.23.self_attn.qkv_proj.weight_scale
@@ -388,13 +388,13 @@ model.layers.24.mlp.gate.weight
 model.layers.24.mlp.gate.weight:model.layers.24.mlp.gate.weight
 model.layers.24.post_attention_layernorm.weight
 model.layers.24.post_attention_layernorm.weight:model.layers.24.post_attention_layernorm.weight
-model.layers.24.self_attn.k_norm.weight
-model.layers.24.self_attn.k_norm.weight:model.layers.24.self_attn.k_norm.weight
 model.layers.24.self_attn.o_proj.weight
 model.layers.24.self_attn.o_proj.weight:model.layers.24.self_attn.o_proj.weight
 model.layers.24.self_attn.o_proj.weight_scale
-model.layers.24.self_attn.q_norm.weight
-model.layers.24.self_attn.q_norm.weight:model.layers.24.self_attn.q_norm.weight
+model.layers.24.self_attn.qk_norm.k_norm.weight
+model.layers.24.self_attn.qk_norm.k_norm.weight:model.layers.24.self_attn.qk_norm.k_norm.weight
+model.layers.24.self_attn.qk_norm.q_norm.weight
+model.layers.24.self_attn.qk_norm.q_norm.weight:model.layers.24.self_attn.qk_norm.q_norm.weight
 model.layers.24.self_attn.qkv_proj.weight
 model.layers.24.self_attn.qkv_proj.weight:model.layers.24.self_attn.qkv_proj.weight
 model.layers.24.self_attn.qkv_proj.weight_scale
@@ -410,13 +410,13 @@ model.layers.25.mlp.gate.weight
 model.layers.25.mlp.gate.weight:model.layers.25.mlp.gate.weight
 model.layers.25.post_attention_layernorm.weight
 model.layers.25.post_attention_layernorm.weight:model.layers.25.post_attention_layernorm.weight
-model.layers.25.self_attn.k_norm.weight
-model.layers.25.self_attn.k_norm.weight:model.layers.25.self_attn.k_norm.weight
 model.layers.25.self_attn.o_proj.weight
 model.layers.25.self_attn.o_proj.weight:model.layers.25.self_attn.o_proj.weight
 model.layers.25.self_attn.o_proj.weight_scale
-model.layers.25.self_attn.q_norm.weight
-model.layers.25.self_attn.q_norm.weight:model.layers.25.self_attn.q_norm.weight
+model.layers.25.self_attn.qk_norm.k_norm.weight
+model.layers.25.self_attn.qk_norm.k_norm.weight:model.layers.25.self_attn.qk_norm.k_norm.weight
+model.layers.25.self_attn.qk_norm.q_norm.weight
+model.layers.25.self_attn.qk_norm.q_norm.weight:model.layers.25.self_attn.qk_norm.q_norm.weight
 model.layers.25.self_attn.qkv_proj.weight
 model.layers.25.self_attn.qkv_proj.weight:model.layers.25.self_attn.qkv_proj.weight
 model.layers.25.self_attn.qkv_proj.weight_scale
@@ -432,13 +432,13 @@ model.layers.26.mlp.gate.weight
 model.layers.26.mlp.gate.weight:model.layers.26.mlp.gate.weight
 model.layers.26.post_attention_layernorm.weight
 model.layers.26.post_attention_layernorm.weight:model.layers.26.post_attention_layernorm.weight
-model.layers.26.self_attn.k_norm.weight
-model.layers.26.self_attn.k_norm.weight:model.layers.26.self_attn.k_norm.weight
 model.layers.26.self_attn.o_proj.weight
 model.layers.26.self_attn.o_proj.weight:model.layers.26.self_attn.o_proj.weight
 model.layers.26.self_attn.o_proj.weight_scale
-model.layers.26.self_attn.q_norm.weight
-model.layers.26.self_attn.q_norm.weight:model.layers.26.self_attn.q_norm.weight
+model.layers.26.self_attn.qk_norm.k_norm.weight
+model.layers.26.self_attn.qk_norm.k_norm.weight:model.layers.26.self_attn.qk_norm.k_norm.weight
+model.layers.26.self_attn.qk_norm.q_norm.weight
+model.layers.26.self_attn.qk_norm.q_norm.weight:model.layers.26.self_attn.qk_norm.q_norm.weight
 model.layers.26.self_attn.qkv_proj.weight
 model.layers.26.self_attn.qkv_proj.weight:model.layers.26.self_attn.qkv_proj.weight
 model.layers.26.self_attn.qkv_proj.weight_scale
@@ -454,13 +454,13 @@ model.layers.27.mlp.gate.weight
 model.layers.27.mlp.gate.weight:model.layers.27.mlp.gate.weight
 model.layers.27.post_attention_layernorm.weight
 model.layers.27.post_attention_layernorm.weight:model.layers.27.post_attention_layernorm.weight
-model.layers.27.self_attn.k_norm.weight
-model.layers.27.self_attn.k_norm.weight:model.layers.27.self_attn.k_norm.weight
 model.layers.27.self_attn.o_proj.weight
 model.layers.27.self_attn.o_proj.weight:model.layers.27.self_attn.o_proj.weight
 model.layers.27.self_attn.o_proj.weight_scale
-model.layers.27.self_attn.q_norm.weight
-model.layers.27.self_attn.q_norm.weight:model.layers.27.self_attn.q_norm.weight
+model.layers.27.self_attn.qk_norm.k_norm.weight
+model.layers.27.self_attn.qk_norm.k_norm.weight:model.layers.27.self_attn.qk_norm.k_norm.weight
+model.layers.27.self_attn.qk_norm.q_norm.weight
+model.layers.27.self_attn.qk_norm.q_norm.weight:model.layers.27.self_attn.qk_norm.q_norm.weight
 model.layers.27.self_attn.qkv_proj.weight
 model.layers.27.self_attn.qkv_proj.weight:model.layers.27.self_attn.qkv_proj.weight
 model.layers.27.self_attn.qkv_proj.weight_scale
@@ -476,13 +476,13 @@ model.layers.28.mlp.gate.weight
 model.layers.28.mlp.gate.weight:model.layers.28.mlp.gate.weight
 model.layers.28.post_attention_layernorm.weight
 model.layers.28.post_attention_layernorm.weight:model.layers.28.post_attention_layernorm.weight
-model.layers.28.self_attn.k_norm.weight
-model.layers.28.self_attn.k_norm.weight:model.layers.28.self_attn.k_norm.weight
 model.layers.28.self_attn.o_proj.weight
 model.layers.28.self_attn.o_proj.weight:model.layers.28.self_attn.o_proj.weight
 model.layers.28.self_attn.o_proj.weight_scale
-model.layers.28.self_attn.q_norm.weight
-model.layers.28.self_attn.q_norm.weight:model.layers.28.self_attn.q_norm.weight
+model.layers.28.self_attn.qk_norm.k_norm.weight
+model.layers.28.self_attn.qk_norm.k_norm.weight:model.layers.28.self_attn.qk_norm.k_norm.weight
+model.layers.28.self_attn.qk_norm.q_norm.weight
+model.layers.28.self_attn.qk_norm.q_norm.weight:model.layers.28.self_attn.qk_norm.q_norm.weight
 model.layers.28.self_attn.qkv_proj.weight
 model.layers.28.self_attn.qkv_proj.weight:model.layers.28.self_attn.qkv_proj.weight
 model.layers.28.self_attn.qkv_proj.weight_scale
@@ -498,13 +498,13 @@ model.layers.29.mlp.gate.weight
 model.layers.29.mlp.gate.weight:model.layers.29.mlp.gate.weight
 model.layers.29.post_attention_layernorm.weight
 model.layers.29.post_attention_layernorm.weight:model.layers.29.post_attention_layernorm.weight
-model.layers.29.self_attn.k_norm.weight
-model.layers.29.self_attn.k_norm.weight:model.layers.29.self_attn.k_norm.weight
 model.layers.29.self_attn.o_proj.weight
 model.layers.29.self_attn.o_proj.weight:model.layers.29.self_attn.o_proj.weight
 model.layers.29.self_attn.o_proj.weight_scale
-model.layers.29.self_attn.q_norm.weight
-model.layers.29.self_attn.q_norm.weight:model.layers.29.self_attn.q_norm.weight
+model.layers.29.self_attn.qk_norm.k_norm.weight
+model.layers.29.self_attn.qk_norm.k_norm.weight:model.layers.29.self_attn.qk_norm.k_norm.weight
+model.layers.29.self_attn.qk_norm.q_norm.weight
+model.layers.29.self_attn.qk_norm.q_norm.weight:model.layers.29.self_attn.qk_norm.q_norm.weight
 model.layers.29.self_attn.qkv_proj.weight
 model.layers.29.self_attn.qkv_proj.weight:model.layers.29.self_attn.qkv_proj.weight
 model.layers.29.self_attn.qkv_proj.weight_scale
@@ -520,13 +520,13 @@ model.layers.3.mlp.gate.weight
 model.layers.3.mlp.gate.weight:model.layers.3.mlp.gate.weight
 model.layers.3.post_attention_layernorm.weight
 model.layers.3.post_attention_layernorm.weight:model.layers.3.post_attention_layernorm.weight
-model.layers.3.self_attn.k_norm.weight
-model.layers.3.self_attn.k_norm.weight:model.layers.3.self_attn.k_norm.weight
 model.layers.3.self_attn.o_proj.weight
 model.layers.3.self_attn.o_proj.weight:model.layers.3.self_attn.o_proj.weight
 model.layers.3.self_attn.o_proj.weight_scale
-model.layers.3.self_attn.q_norm.weight
-model.layers.3.self_attn.q_norm.weight:model.layers.3.self_attn.q_norm.weight
+model.layers.3.self_attn.qk_norm.k_norm.weight
+model.layers.3.self_attn.qk_norm.k_norm.weight:model.layers.3.self_attn.qk_norm.k_norm.weight
+model.layers.3.self_attn.qk_norm.q_norm.weight
+model.layers.3.self_attn.qk_norm.q_norm.weight:model.layers.3.self_attn.qk_norm.q_norm.weight
 model.layers.3.self_attn.qkv_proj.weight
 model.layers.3.self_attn.qkv_proj.weight:model.layers.3.self_attn.qkv_proj.weight
 model.layers.3.self_attn.qkv_proj.weight_scale
@@ -542,13 +542,13 @@ model.layers.30.mlp.gate.weight
 model.layers.30.mlp.gate.weight:model.layers.30.mlp.gate.weight
 model.layers.30.post_attention_layernorm.weight
 model.layers.30.post_attention_layernorm.weight:model.layers.30.post_attention_layernorm.weight
-model.layers.30.self_attn.k_norm.weight
-model.layers.30.self_attn.k_norm.weight:model.layers.30.self_attn.k_norm.weight
 model.layers.30.self_attn.o_proj.weight
 model.layers.30.self_attn.o_proj.weight:model.layers.30.self_attn.o_proj.weight
 model.layers.30.self_attn.o_proj.weight_scale
-model.layers.30.self_attn.q_norm.weight
-model.layers.30.self_attn.q_norm.weight:model.layers.30.self_attn.q_norm.weight
+model.layers.30.self_attn.qk_norm.k_norm.weight
+model.layers.30.self_attn.qk_norm.k_norm.weight:model.layers.30.self_attn.qk_norm.k_norm.weight
+model.layers.30.self_attn.qk_norm.q_norm.weight
+model.layers.30.self_attn.qk_norm.q_norm.weight:model.layers.30.self_attn.qk_norm.q_norm.weight
 model.layers.30.self_attn.qkv_proj.weight
 model.layers.30.self_attn.qkv_proj.weight:model.layers.30.self_attn.qkv_proj.weight
 model.layers.30.self_attn.qkv_proj.weight_scale
@@ -564,13 +564,13 @@ model.layers.31.mlp.gate.weight
 model.layers.31.mlp.gate.weight:model.layers.31.mlp.gate.weight
 model.layers.31.post_attention_layernorm.weight
 model.layers.31.post_attention_layernorm.weight:model.layers.31.post_attention_layernorm.weight
-model.layers.31.self_attn.k_norm.weight
-model.layers.31.self_attn.k_norm.weight:model.layers.31.self_attn.k_norm.weight
 model.layers.31.self_attn.o_proj.weight
 model.layers.31.self_attn.o_proj.weight:model.layers.31.self_attn.o_proj.weight
 model.layers.31.self_attn.o_proj.weight_scale
-model.layers.31.self_attn.q_norm.weight
-model.layers.31.self_attn.q_norm.weight:model.layers.31.self_attn.q_norm.weight
+model.layers.31.self_attn.qk_norm.k_norm.weight
+model.layers.31.self_attn.qk_norm.k_norm.weight:model.layers.31.self_attn.qk_norm.k_norm.weight
+model.layers.31.self_attn.qk_norm.q_norm.weight
+model.layers.31.self_attn.qk_norm.q_norm.weight:model.layers.31.self_attn.qk_norm.q_norm.weight
 model.layers.31.self_attn.qkv_proj.weight
 model.layers.31.self_attn.qkv_proj.weight:model.layers.31.self_attn.qkv_proj.weight
 model.layers.31.self_attn.qkv_proj.weight_scale
@@ -586,13 +586,13 @@ model.layers.32.mlp.gate.weight
 model.layers.32.mlp.gate.weight:model.layers.32.mlp.gate.weight
 model.layers.32.post_attention_layernorm.weight
 model.layers.32.post_attention_layernorm.weight:model.layers.32.post_attention_layernorm.weight
-model.layers.32.self_attn.k_norm.weight
-model.layers.32.self_attn.k_norm.weight:model.layers.32.self_attn.k_norm.weight
 model.layers.32.self_attn.o_proj.weight
 model.layers.32.self_attn.o_proj.weight:model.layers.32.self_attn.o_proj.weight
 model.layers.32.self_attn.o_proj.weight_scale
-model.layers.32.self_attn.q_norm.weight
-model.layers.32.self_attn.q_norm.weight:model.layers.32.self_attn.q_norm.weight
+model.layers.32.self_attn.qk_norm.k_norm.weight
+model.layers.32.self_attn.qk_norm.k_norm.weight:model.layers.32.self_attn.qk_norm.k_norm.weight
+model.layers.32.self_attn.qk_norm.q_norm.weight
+model.layers.32.self_attn.qk_norm.q_norm.weight:model.layers.32.self_attn.qk_norm.q_norm.weight
 model.layers.32.self_attn.qkv_proj.weight
 model.layers.32.self_attn.qkv_proj.weight:model.layers.32.self_attn.qkv_proj.weight
 model.layers.32.self_attn.qkv_proj.weight_scale
@@ -608,13 +608,13 @@ model.layers.33.mlp.gate.weight
 model.layers.33.mlp.gate.weight:model.layers.33.mlp.gate.weight
 model.layers.33.post_attention_layernorm.weight
 model.layers.33.post_attention_layernorm.weight:model.layers.33.post_attention_layernorm.weight
-model.layers.33.self_attn.k_norm.weight
-model.layers.33.self_attn.k_norm.weight:model.layers.33.self_attn.k_norm.weight
 model.layers.33.self_attn.o_proj.weight
 model.layers.33.self_attn.o_proj.weight:model.layers.33.self_attn.o_proj.weight
 model.layers.33.self_attn.o_proj.weight_scale
-model.layers.33.self_attn.q_norm.weight
-model.layers.33.self_attn.q_norm.weight:model.layers.33.self_attn.q_norm.weight
+model.layers.33.self_attn.qk_norm.k_norm.weight
+model.layers.33.self_attn.qk_norm.k_norm.weight:model.layers.33.self_attn.qk_norm.k_norm.weight
+model.layers.33.self_attn.qk_norm.q_norm.weight
+model.layers.33.self_attn.qk_norm.q_norm.weight:model.layers.33.self_attn.qk_norm.q_norm.weight
 model.layers.33.self_attn.qkv_proj.weight
 model.layers.33.self_attn.qkv_proj.weight:model.layers.33.self_attn.qkv_proj.weight
 model.layers.33.self_attn.qkv_proj.weight_scale
@@ -630,13 +630,13 @@ model.layers.34.mlp.gate.weight
 model.layers.34.mlp.gate.weight:model.layers.34.mlp.gate.weight
 model.layers.34.post_attention_layernorm.weight
 model.layers.34.post_attention_layernorm.weight:model.layers.34.post_attention_layernorm.weight
-model.layers.34.self_attn.k_norm.weight
-model.layers.34.self_attn.k_norm.weight:model.layers.34.self_attn.k_norm.weight
 model.layers.34.self_attn.o_proj.weight
 model.layers.34.self_attn.o_proj.weight:model.layers.34.self_attn.o_proj.weight
 model.layers.34.self_attn.o_proj.weight_scale
-model.layers.34.self_attn.q_norm.weight
-model.layers.34.self_attn.q_norm.weight:model.layers.34.self_attn.q_norm.weight
+model.layers.34.self_attn.qk_norm.k_norm.weight
+model.layers.34.self_attn.qk_norm.k_norm.weight:model.layers.34.self_attn.qk_norm.k_norm.weight
+model.layers.34.self_attn.qk_norm.q_norm.weight
+model.layers.34.self_attn.qk_norm.q_norm.weight:model.layers.34.self_attn.qk_norm.q_norm.weight
 model.layers.34.self_attn.qkv_proj.weight
 model.layers.34.self_attn.qkv_proj.weight:model.layers.34.self_attn.qkv_proj.weight
 model.layers.34.self_attn.qkv_proj.weight_scale
@@ -652,13 +652,13 @@ model.layers.35.mlp.gate.weight
 model.layers.35.mlp.gate.weight:model.layers.35.mlp.gate.weight
 model.layers.35.post_attention_layernorm.weight
 model.layers.35.post_attention_layernorm.weight:model.layers.35.post_attention_layernorm.weight
-model.layers.35.self_attn.k_norm.weight
-model.layers.35.self_attn.k_norm.weight:model.layers.35.self_attn.k_norm.weight
 model.layers.35.self_attn.o_proj.weight
 model.layers.35.self_attn.o_proj.weight:model.layers.35.self_attn.o_proj.weight
 model.layers.35.self_attn.o_proj.weight_scale
-model.layers.35.self_attn.q_norm.weight
-model.layers.35.self_attn.q_norm.weight:model.layers.35.self_attn.q_norm.weight
+model.layers.35.self_attn.qk_norm.k_norm.weight
+model.layers.35.self_attn.qk_norm.k_norm.weight:model.layers.35.self_attn.qk_norm.k_norm.weight
+model.layers.35.self_attn.qk_norm.q_norm.weight
+model.layers.35.self_attn.qk_norm.q_norm.weight:model.layers.35.self_attn.qk_norm.q_norm.weight
 model.layers.35.self_attn.qkv_proj.weight
 model.layers.35.self_attn.qkv_proj.weight:model.layers.35.self_attn.qkv_proj.weight
 model.layers.35.self_attn.qkv_proj.weight_scale
@@ -674,13 +674,13 @@ model.layers.36.mlp.gate.weight
 model.layers.36.mlp.gate.weight:model.layers.36.mlp.gate.weight
 model.layers.36.post_attention_layernorm.weight
 model.layers.36.post_attention_layernorm.weight:model.layers.36.post_attention_layernorm.weight
-model.layers.36.self_attn.k_norm.weight
-model.layers.36.self_attn.k_norm.weight:model.layers.36.self_attn.k_norm.weight
 model.layers.36.self_attn.o_proj.weight
 model.layers.36.self_attn.o_proj.weight:model.layers.36.self_attn.o_proj.weight
 model.layers.36.self_attn.o_proj.weight_scale
-model.layers.36.self_attn.q_norm.weight
-model.layers.36.self_attn.q_norm.weight:model.layers.36.self_attn.q_norm.weight
+model.layers.36.self_attn.qk_norm.k_norm.weight
+model.layers.36.self_attn.qk_norm.k_norm.weight:model.layers.36.self_attn.qk_norm.k_norm.weight
+model.layers.36.self_attn.qk_norm.q_norm.weight
+model.layers.36.self_attn.qk_norm.q_norm.weight:model.layers.36.self_attn.qk_norm.q_norm.weight
 model.layers.36.self_attn.qkv_proj.weight
 model.layers.36.self_attn.qkv_proj.weight:model.layers.36.self_attn.qkv_proj.weight
 model.layers.36.self_attn.qkv_proj.weight_scale
@@ -696,13 +696,13 @@ model.layers.37.mlp.gate.weight
 model.layers.37.mlp.gate.weight:model.layers.37.mlp.gate.weight
 model.layers.37.post_attention_layernorm.weight
 model.layers.37.post_attention_layernorm.weight:model.layers.37.post_attention_layernorm.weight
-model.layers.37.self_attn.k_norm.weight
-model.layers.37.self_attn.k_norm.weight:model.layers.37.self_attn.k_norm.weight
 model.layers.37.self_attn.o_proj.weight
 model.layers.37.self_attn.o_proj.weight:model.layers.37.self_attn.o_proj.weight
 model.layers.37.self_attn.o_proj.weight_scale
-model.layers.37.self_attn.q_norm.weight
-model.layers.37.self_attn.q_norm.weight:model.layers.37.self_attn.q_norm.weight
+model.layers.37.self_attn.qk_norm.k_norm.weight
+model.layers.37.self_attn.qk_norm.k_norm.weight:model.layers.37.self_attn.qk_norm.k_norm.weight
+model.layers.37.self_attn.qk_norm.q_norm.weight
+model.layers.37.self_attn.qk_norm.q_norm.weight:model.layers.37.self_attn.qk_norm.q_norm.weight
 model.layers.37.self_attn.qkv_proj.weight
 model.layers.37.self_attn.qkv_proj.weight:model.layers.37.self_attn.qkv_proj.weight
 model.layers.37.self_attn.qkv_proj.weight_scale
@@ -718,13 +718,13 @@ model.layers.38.mlp.gate.weight
 model.layers.38.mlp.gate.weight:model.layers.38.mlp.gate.weight
 model.layers.38.post_attention_layernorm.weight
 model.layers.38.post_attention_layernorm.weight:model.layers.38.post_attention_layernorm.weight
-model.layers.38.self_attn.k_norm.weight
-model.layers.38.self_attn.k_norm.weight:model.layers.38.self_attn.k_norm.weight
 model.layers.38.self_attn.o_proj.weight
 model.layers.38.self_attn.o_proj.weight:model.layers.38.self_attn.o_proj.weight
 model.layers.38.self_attn.o_proj.weight_scale
-model.layers.38.self_attn.q_norm.weight
-model.layers.38.self_attn.q_norm.weight:model.layers.38.self_attn.q_norm.weight
+model.layers.38.self_attn.qk_norm.k_norm.weight
+model.layers.38.self_attn.qk_norm.k_norm.weight:model.layers.38.self_attn.qk_norm.k_norm.weight
+model.layers.38.self_attn.qk_norm.q_norm.weight
+model.layers.38.self_attn.qk_norm.q_norm.weight:model.layers.38.self_attn.qk_norm.q_norm.weight
 model.layers.38.self_attn.qkv_proj.weight
 model.layers.38.self_attn.qkv_proj.weight:model.layers.38.self_attn.qkv_proj.weight
 model.layers.38.self_attn.qkv_proj.weight_scale
@@ -740,13 +740,13 @@ model.layers.39.mlp.gate.weight
 model.layers.39.mlp.gate.weight:model.layers.39.mlp.gate.weight
 model.layers.39.post_attention_layernorm.weight
 model.layers.39.post_attention_layernorm.weight:model.layers.39.post_attention_layernorm.weight
-model.layers.39.self_attn.k_norm.weight
-model.layers.39.self_attn.k_norm.weight:model.layers.39.self_attn.k_norm.weight
 model.layers.39.self_attn.o_proj.weight
 model.layers.39.self_attn.o_proj.weight:model.layers.39.self_attn.o_proj.weight
 model.layers.39.self_attn.o_proj.weight_scale
-model.layers.39.self_attn.q_norm.weight
-model.layers.39.self_attn.q_norm.weight:model.layers.39.self_attn.q_norm.weight
+model.layers.39.self_attn.qk_norm.k_norm.weight
+model.layers.39.self_attn.qk_norm.k_norm.weight:model.layers.39.self_attn.qk_norm.k_norm.weight
+model.layers.39.self_attn.qk_norm.q_norm.weight
+model.layers.39.self_attn.qk_norm.q_norm.weight:model.layers.39.self_attn.qk_norm.q_norm.weight
 model.layers.39.self_attn.qkv_proj.weight
 model.layers.39.self_attn.qkv_proj.weight:model.layers.39.self_attn.qkv_proj.weight
 model.layers.39.self_attn.qkv_proj.weight_scale
@@ -762,13 +762,13 @@ model.layers.4.mlp.gate.weight
 model.layers.4.mlp.gate.weight:model.layers.4.mlp.gate.weight
 model.layers.4.post_attention_layernorm.weight
 model.layers.4.post_attention_layernorm.weight:model.layers.4.post_attention_layernorm.weight
-model.layers.4.self_attn.k_norm.weight
-model.layers.4.self_attn.k_norm.weight:model.layers.4.self_attn.k_norm.weight
 model.layers.4.self_attn.o_proj.weight
 model.layers.4.self_attn.o_proj.weight:model.layers.4.self_attn.o_proj.weight
 model.layers.4.self_attn.o_proj.weight_scale
-model.layers.4.self_attn.q_norm.weight
-model.layers.4.self_attn.q_norm.weight:model.layers.4.self_attn.q_norm.weight
+model.layers.4.self_attn.qk_norm.k_norm.weight
+model.layers.4.self_attn.qk_norm.k_norm.weight:model.layers.4.self_attn.qk_norm.k_norm.weight
+model.layers.4.self_attn.qk_norm.q_norm.weight
+model.layers.4.self_attn.qk_norm.q_norm.weight:model.layers.4.self_attn.qk_norm.q_norm.weight
 model.layers.4.self_attn.qkv_proj.weight
 model.layers.4.self_attn.qkv_proj.weight:model.layers.4.self_attn.qkv_proj.weight
 model.layers.4.self_attn.qkv_proj.weight_scale
@@ -784,13 +784,13 @@ model.layers.40.mlp.gate.weight
 model.layers.40.mlp.gate.weight:model.layers.40.mlp.gate.weight
 model.layers.40.post_attention_layernorm.weight
 model.layers.40.post_attention_layernorm.weight:model.layers.40.post_attention_layernorm.weight
-model.layers.40.self_attn.k_norm.weight
-model.layers.40.self_attn.k_norm.weight:model.layers.40.self_attn.k_norm.weight
 model.layers.40.self_attn.o_proj.weight
 model.layers.40.self_attn.o_proj.weight:model.layers.40.self_attn.o_proj.weight
 model.layers.40.self_attn.o_proj.weight_scale
-model.layers.40.self_attn.q_norm.weight
-model.layers.40.self_attn.q_norm.weight:model.layers.40.self_attn.q_norm.weight
+model.layers.40.self_attn.qk_norm.k_norm.weight
+model.layers.40.self_attn.qk_norm.k_norm.weight:model.layers.40.self_attn.qk_norm.k_norm.weight
+model.layers.40.self_attn.qk_norm.q_norm.weight
+model.layers.40.self_attn.qk_norm.q_norm.weight:model.layers.40.self_attn.qk_norm.q_norm.weight
 model.layers.40.self_attn.qkv_proj.weight
 model.layers.40.self_attn.qkv_proj.weight:model.layers.40.self_attn.qkv_proj.weight
 model.layers.40.self_attn.qkv_proj.weight_scale
@@ -806,13 +806,13 @@ model.layers.41.mlp.gate.weight
 model.layers.41.mlp.gate.weight:model.layers.41.mlp.gate.weight
 model.layers.41.post_attention_layernorm.weight
 model.layers.41.post_attention_layernorm.weight:model.layers.41.post_attention_layernorm.weight
-model.layers.41.self_attn.k_norm.weight
-model.layers.41.self_attn.k_norm.weight:model.layers.41.self_attn.k_norm.weight
 model.layers.41.self_attn.o_proj.weight
 model.layers.41.self_attn.o_proj.weight:model.layers.41.self_attn.o_proj.weight
 model.layers.41.self_attn.o_proj.weight_scale
-model.layers.41.self_attn.q_norm.weight
-model.layers.41.self_attn.q_norm.weight:model.layers.41.self_attn.q_norm.weight
+model.layers.41.self_attn.qk_norm.k_norm.weight
+model.layers.41.self_attn.qk_norm.k_norm.weight:model.layers.41.self_attn.qk_norm.k_norm.weight
+model.layers.41.self_attn.qk_norm.q_norm.weight
+model.layers.41.self_attn.qk_norm.q_norm.weight:model.layers.41.self_attn.qk_norm.q_norm.weight
 model.layers.41.self_attn.qkv_proj.weight
 model.layers.41.self_attn.qkv_proj.weight:model.layers.41.self_attn.qkv_proj.weight
 model.layers.41.self_attn.qkv_proj.weight_scale
@@ -828,13 +828,13 @@ model.layers.42.mlp.gate.weight
 model.layers.42.mlp.gate.weight:model.layers.42.mlp.gate.weight
 model.layers.42.post_attention_layernorm.weight
 model.layers.42.post_attention_layernorm.weight:model.layers.42.post_attention_layernorm.weight
-model.layers.42.self_attn.k_norm.weight
-model.layers.42.self_attn.k_norm.weight:model.layers.42.self_attn.k_norm.weight
 model.layers.42.self_attn.o_proj.weight
 model.layers.42.self_attn.o_proj.weight:model.layers.42.self_attn.o_proj.weight
 model.layers.42.self_attn.o_proj.weight_scale
-model.layers.42.self_attn.q_norm.weight
-model.layers.42.self_attn.q_norm.weight:model.layers.42.self_attn.q_norm.weight
+model.layers.42.self_attn.qk_norm.k_norm.weight
+model.layers.42.self_attn.qk_norm.k_norm.weight:model.layers.42.self_attn.qk_norm.k_norm.weight
+model.layers.42.self_attn.qk_norm.q_norm.weight
+model.layers.42.self_attn.qk_norm.q_norm.weight:model.layers.42.self_attn.qk_norm.q_norm.weight
 model.layers.42.self_attn.qkv_proj.weight
 model.layers.42.self_attn.qkv_proj.weight:model.layers.42.self_attn.qkv_proj.weight
 model.layers.42.self_attn.qkv_proj.weight_scale
@@ -850,13 +850,13 @@ model.layers.43.mlp.gate.weight
 model.layers.43.mlp.gate.weight:model.layers.43.mlp.gate.weight
 model.layers.43.post_attention_layernorm.weight
 model.layers.43.post_attention_layernorm.weight:model.layers.43.post_attention_layernorm.weight
-model.layers.43.self_attn.k_norm.weight
-model.layers.43.self_attn.k_norm.weight:model.layers.43.self_attn.k_norm.weight
 model.layers.43.self_attn.o_proj.weight
 model.layers.43.self_attn.o_proj.weight:model.layers.43.self_attn.o_proj.weight
 model.layers.43.self_attn.o_proj.weight_scale
-model.layers.43.self_attn.q_norm.weight
-model.layers.43.self_attn.q_norm.weight:model.layers.43.self_attn.q_norm.weight
+model.layers.43.self_attn.qk_norm.k_norm.weight
+model.layers.43.self_attn.qk_norm.k_norm.weight:model.layers.43.self_attn.qk_norm.k_norm.weight
+model.layers.43.self_attn.qk_norm.q_norm.weight
+model.layers.43.self_attn.qk_norm.q_norm.weight:model.layers.43.self_attn.qk_norm.q_norm.weight
 model.layers.43.self_attn.qkv_proj.weight
 model.layers.43.self_attn.qkv_proj.weight:model.layers.43.self_attn.qkv_proj.weight
 model.layers.43.self_attn.qkv_proj.weight_scale
@@ -872,13 +872,13 @@ model.layers.44.mlp.gate.weight
 model.layers.44.mlp.gate.weight:model.layers.44.mlp.gate.weight
 model.layers.44.post_attention_layernorm.weight
 model.layers.44.post_attention_layernorm.weight:model.layers.44.post_attention_layernorm.weight
-model.layers.44.self_attn.k_norm.weight
-model.layers.44.self_attn.k_norm.weight:model.layers.44.self_attn.k_norm.weight
 model.layers.44.self_attn.o_proj.weight
 model.layers.44.self_attn.o_proj.weight:model.layers.44.self_attn.o_proj.weight
 model.layers.44.self_attn.o_proj.weight_scale
-model.layers.44.self_attn.q_norm.weight
-model.layers.44.self_attn.q_norm.weight:model.layers.44.self_attn.q_norm.weight
+model.layers.44.self_attn.qk_norm.k_norm.weight
+model.layers.44.self_attn.qk_norm.k_norm.weight:model.layers.44.self_attn.qk_norm.k_norm.weight
+model.layers.44.self_attn.qk_norm.q_norm.weight
+model.layers.44.self_attn.qk_norm.q_norm.weight:model.layers.44.self_attn.qk_norm.q_norm.weight
 model.layers.44.self_attn.qkv_proj.weight
 model.layers.44.self_attn.qkv_proj.weight:model.layers.44.self_attn.qkv_proj.weight
 model.layers.44.self_attn.qkv_proj.weight_scale
@@ -894,13 +894,13 @@ model.layers.45.mlp.gate.weight
 model.layers.45.mlp.gate.weight:model.layers.45.mlp.gate.weight
 model.layers.45.post_attention_layernorm.weight
 model.layers.45.post_attention_layernorm.weight:model.layers.45.post_attention_layernorm.weight
-model.layers.45.self_attn.k_norm.weight
-model.layers.45.self_attn.k_norm.weight:model.layers.45.self_attn.k_norm.weight
 model.layers.45.self_attn.o_proj.weight
 model.layers.45.self_attn.o_proj.weight:model.layers.45.self_attn.o_proj.weight
 model.layers.45.self_attn.o_proj.weight_scale
-model.layers.45.self_attn.q_norm.weight
-model.layers.45.self_attn.q_norm.weight:model.layers.45.self_attn.q_norm.weight
+model.layers.45.self_attn.qk_norm.k_norm.weight
+model.layers.45.self_attn.qk_norm.k_norm.weight:model.layers.45.self_attn.qk_norm.k_norm.weight
+model.layers.45.self_attn.qk_norm.q_norm.weight
+model.layers.45.self_attn.qk_norm.q_norm.weight:model.layers.45.self_attn.qk_norm.q_norm.weight
 model.layers.45.self_attn.qkv_proj.weight
 model.layers.45.self_attn.qkv_proj.weight:model.layers.45.self_attn.qkv_proj.weight
 model.layers.45.self_attn.qkv_proj.weight_scale
@@ -916,13 +916,13 @@ model.layers.46.mlp.gate.weight
 model.layers.46.mlp.gate.weight:model.layers.46.mlp.gate.weight
 model.layers.46.post_attention_layernorm.weight
 model.layers.46.post_attention_layernorm.weight:model.layers.46.post_attention_layernorm.weight
-model.layers.46.self_attn.k_norm.weight
-model.layers.46.self_attn.k_norm.weight:model.layers.46.self_attn.k_norm.weight
 model.layers.46.self_attn.o_proj.weight
 model.layers.46.self_attn.o_proj.weight:model.layers.46.self_attn.o_proj.weight
 model.layers.46.self_attn.o_proj.weight_scale
-model.layers.46.self_attn.q_norm.weight
-model.layers.46.self_attn.q_norm.weight:model.layers.46.self_attn.q_norm.weight
+model.layers.46.self_attn.qk_norm.k_norm.weight
+model.layers.46.self_attn.qk_norm.k_norm.weight:model.layers.46.self_attn.qk_norm.k_norm.weight
+model.layers.46.self_attn.qk_norm.q_norm.weight
+model.layers.46.self_attn.qk_norm.q_norm.weight:model.layers.46.self_attn.qk_norm.q_norm.weight
 model.layers.46.self_attn.qkv_proj.weight
 model.layers.46.self_attn.qkv_proj.weight:model.layers.46.self_attn.qkv_proj.weight
 model.layers.46.self_attn.qkv_proj.weight_scale
@@ -938,13 +938,13 @@ model.layers.47.mlp.gate.weight
 model.layers.47.mlp.gate.weight:model.layers.47.mlp.gate.weight
 model.layers.47.post_attention_layernorm.weight
 model.layers.47.post_attention_layernorm.weight:model.layers.47.post_attention_layernorm.weight
-model.layers.47.self_attn.k_norm.weight
-model.layers.47.self_attn.k_norm.weight:model.layers.47.self_attn.k_norm.weight
 model.layers.47.self_attn.o_proj.weight
 model.layers.47.self_attn.o_proj.weight:model.layers.47.self_attn.o_proj.weight
 model.layers.47.self_attn.o_proj.weight_scale
-model.layers.47.self_attn.q_norm.weight
-model.layers.47.self_attn.q_norm.weight:model.layers.47.self_attn.q_norm.weight
+model.layers.47.self_attn.qk_norm.k_norm.weight
+model.layers.47.self_attn.qk_norm.k_norm.weight:model.layers.47.self_attn.qk_norm.k_norm.weight
+model.layers.47.self_attn.qk_norm.q_norm.weight
+model.layers.47.self_attn.qk_norm.q_norm.weight:model.layers.47.self_attn.qk_norm.q_norm.weight
 model.layers.47.self_attn.qkv_proj.weight
 model.layers.47.self_attn.qkv_proj.weight:model.layers.47.self_attn.qkv_proj.weight
 model.layers.47.self_attn.qkv_proj.weight_scale
@@ -960,13 +960,13 @@ model.layers.5.mlp.gate.weight
 model.layers.5.mlp.gate.weight:model.layers.5.mlp.gate.weight
 model.layers.5.post_attention_layernorm.weight
 model.layers.5.post_attention_layernorm.weight:model.layers.5.post_attention_layernorm.weight
-model.layers.5.self_attn.k_norm.weight
-model.layers.5.self_attn.k_norm.weight:model.layers.5.self_attn.k_norm.weight
 model.layers.5.self_attn.o_proj.weight
 model.layers.5.self_attn.o_proj.weight:model.layers.5.self_attn.o_proj.weight
 model.layers.5.self_attn.o_proj.weight_scale
-model.layers.5.self_attn.q_norm.weight
-model.layers.5.self_attn.q_norm.weight:model.layers.5.self_attn.q_norm.weight
+model.layers.5.self_attn.qk_norm.k_norm.weight
+model.layers.5.self_attn.qk_norm.k_norm.weight:model.layers.5.self_attn.qk_norm.k_norm.weight
+model.layers.5.self_attn.qk_norm.q_norm.weight
+model.layers.5.self_attn.qk_norm.q_norm.weight:model.layers.5.self_attn.qk_norm.q_norm.weight
 model.layers.5.self_attn.qkv_proj.weight
 model.layers.5.self_attn.qkv_proj.weight:model.layers.5.self_attn.qkv_proj.weight
 model.layers.5.self_attn.qkv_proj.weight_scale
@@ -982,13 +982,13 @@ model.layers.6.mlp.gate.weight
 model.layers.6.mlp.gate.weight:model.layers.6.mlp.gate.weight
 model.layers.6.post_attention_layernorm.weight
 model.layers.6.post_attention_layernorm.weight:model.layers.6.post_attention_layernorm.weight
-model.layers.6.self_attn.k_norm.weight
-model.layers.6.self_attn.k_norm.weight:model.layers.6.self_attn.k_norm.weight
 model.layers.6.self_attn.o_proj.weight
 model.layers.6.self_attn.o_proj.weight:model.layers.6.self_attn.o_proj.weight
 model.layers.6.self_attn.o_proj.weight_scale
-model.layers.6.self_attn.q_norm.weight
-model.layers.6.self_attn.q_norm.weight:model.layers.6.self_attn.q_norm.weight
+model.layers.6.self_attn.qk_norm.k_norm.weight
+model.layers.6.self_attn.qk_norm.k_norm.weight:model.layers.6.self_attn.qk_norm.k_norm.weight
+model.layers.6.self_attn.qk_norm.q_norm.weight
+model.layers.6.self_attn.qk_norm.q_norm.weight:model.layers.6.self_attn.qk_norm.q_norm.weight
 model.layers.6.self_attn.qkv_proj.weight
 model.layers.6.self_attn.qkv_proj.weight:model.layers.6.self_attn.qkv_proj.weight
 model.layers.6.self_attn.qkv_proj.weight_scale
@@ -1004,13 +1004,13 @@ model.layers.7.mlp.gate.weight
 model.layers.7.mlp.gate.weight:model.layers.7.mlp.gate.weight
 model.layers.7.post_attention_layernorm.weight
 model.layers.7.post_attention_layernorm.weight:model.layers.7.post_attention_layernorm.weight
-model.layers.7.self_attn.k_norm.weight
-model.layers.7.self_attn.k_norm.weight:model.layers.7.self_attn.k_norm.weight model.layers.7.self_attn.o_proj.weight model.layers.7.self_attn.o_proj.weight:model.layers.7.self_attn.o_proj.weight model.layers.7.self_attn.o_proj.weight_scale -model.layers.7.self_attn.q_norm.weight -model.layers.7.self_attn.q_norm.weight:model.layers.7.self_attn.q_norm.weight +model.layers.7.self_attn.qk_norm.k_norm.weight +model.layers.7.self_attn.qk_norm.k_norm.weight:model.layers.7.self_attn.qk_norm.k_norm.weight +model.layers.7.self_attn.qk_norm.q_norm.weight +model.layers.7.self_attn.qk_norm.q_norm.weight:model.layers.7.self_attn.qk_norm.q_norm.weight model.layers.7.self_attn.qkv_proj.weight model.layers.7.self_attn.qkv_proj.weight:model.layers.7.self_attn.qkv_proj.weight model.layers.7.self_attn.qkv_proj.weight_scale @@ -1026,13 +1026,13 @@ model.layers.8.mlp.gate.weight model.layers.8.mlp.gate.weight:model.layers.8.mlp.gate.weight model.layers.8.post_attention_layernorm.weight model.layers.8.post_attention_layernorm.weight:model.layers.8.post_attention_layernorm.weight -model.layers.8.self_attn.k_norm.weight -model.layers.8.self_attn.k_norm.weight:model.layers.8.self_attn.k_norm.weight model.layers.8.self_attn.o_proj.weight model.layers.8.self_attn.o_proj.weight:model.layers.8.self_attn.o_proj.weight model.layers.8.self_attn.o_proj.weight_scale -model.layers.8.self_attn.q_norm.weight -model.layers.8.self_attn.q_norm.weight:model.layers.8.self_attn.q_norm.weight +model.layers.8.self_attn.qk_norm.k_norm.weight +model.layers.8.self_attn.qk_norm.k_norm.weight:model.layers.8.self_attn.qk_norm.k_norm.weight +model.layers.8.self_attn.qk_norm.q_norm.weight +model.layers.8.self_attn.qk_norm.q_norm.weight:model.layers.8.self_attn.qk_norm.q_norm.weight model.layers.8.self_attn.qkv_proj.weight model.layers.8.self_attn.qkv_proj.weight:model.layers.8.self_attn.qkv_proj.weight model.layers.8.self_attn.qkv_proj.weight_scale @@ -1048,13 +1048,13 @@ model.layers.9.mlp.gate.weight model.layers.9.mlp.gate.weight:model.layers.9.mlp.gate.weight model.layers.9.post_attention_layernorm.weight model.layers.9.post_attention_layernorm.weight:model.layers.9.post_attention_layernorm.weight -model.layers.9.self_attn.k_norm.weight -model.layers.9.self_attn.k_norm.weight:model.layers.9.self_attn.k_norm.weight model.layers.9.self_attn.o_proj.weight model.layers.9.self_attn.o_proj.weight:model.layers.9.self_attn.o_proj.weight model.layers.9.self_attn.o_proj.weight_scale -model.layers.9.self_attn.q_norm.weight -model.layers.9.self_attn.q_norm.weight:model.layers.9.self_attn.q_norm.weight +model.layers.9.self_attn.qk_norm.k_norm.weight +model.layers.9.self_attn.qk_norm.k_norm.weight:model.layers.9.self_attn.qk_norm.k_norm.weight +model.layers.9.self_attn.qk_norm.q_norm.weight +model.layers.9.self_attn.qk_norm.q_norm.weight:model.layers.9.self_attn.qk_norm.q_norm.weight model.layers.9.self_attn.qkv_proj.weight model.layers.9.self_attn.qkv_proj.weight:model.layers.9.self_attn.qkv_proj.weight model.layers.9.self_attn.qkv_proj.weight_scale diff --git a/tests/e2e/Qwen3VL_RL/baseline.txt b/tests/e2e/Qwen3VL_RL/baseline.txt index 6f3d36f57eb..def3b469867 100644 --- a/tests/e2e/Qwen3VL_RL/baseline.txt +++ b/tests/e2e/Qwen3VL_RL/baseline.txt @@ -12,13 +12,13 @@ model.layers.0.mlp.up_gate_proj.weight:model.layers.0.mlp.gate_up_fused_proj.wei model.layers.0.mlp.up_gate_proj.weight_scale model.layers.0.post_attention_layernorm.weight model.layers.0.post_attention_layernorm.weight:model.layers.0.post_attention_layernorm.weight 
-model.layers.0.self_attn.k_norm.weight -model.layers.0.self_attn.k_norm.weight:model.layers.0.self_attn.k_norm.weight model.layers.0.self_attn.o_proj.weight model.layers.0.self_attn.o_proj.weight:model.layers.0.self_attn.o_proj.weight model.layers.0.self_attn.o_proj.weight_scale -model.layers.0.self_attn.q_norm.weight -model.layers.0.self_attn.q_norm.weight:model.layers.0.self_attn.q_norm.weight +model.layers.0.self_attn.qk_norm.k_norm.weight +model.layers.0.self_attn.qk_norm.k_norm.weight:model.layers.0.self_attn.qk_norm.k_norm.weight +model.layers.0.self_attn.qk_norm.q_norm.weight +model.layers.0.self_attn.qk_norm.q_norm.weight:model.layers.0.self_attn.qk_norm.q_norm.weight model.layers.0.self_attn.qkv_proj.weight model.layers.0.self_attn.qkv_proj.weight:model.layers.0.self_attn.qkv_proj.weight model.layers.0.self_attn.qkv_proj.weight_scale @@ -32,13 +32,13 @@ model.layers.1.mlp.up_gate_proj.weight:model.layers.1.mlp.gate_up_fused_proj.wei model.layers.1.mlp.up_gate_proj.weight_scale model.layers.1.post_attention_layernorm.weight model.layers.1.post_attention_layernorm.weight:model.layers.1.post_attention_layernorm.weight -model.layers.1.self_attn.k_norm.weight -model.layers.1.self_attn.k_norm.weight:model.layers.1.self_attn.k_norm.weight model.layers.1.self_attn.o_proj.weight model.layers.1.self_attn.o_proj.weight:model.layers.1.self_attn.o_proj.weight model.layers.1.self_attn.o_proj.weight_scale -model.layers.1.self_attn.q_norm.weight -model.layers.1.self_attn.q_norm.weight:model.layers.1.self_attn.q_norm.weight +model.layers.1.self_attn.qk_norm.k_norm.weight +model.layers.1.self_attn.qk_norm.k_norm.weight:model.layers.1.self_attn.qk_norm.k_norm.weight +model.layers.1.self_attn.qk_norm.q_norm.weight +model.layers.1.self_attn.qk_norm.q_norm.weight:model.layers.1.self_attn.qk_norm.q_norm.weight model.layers.1.self_attn.qkv_proj.weight model.layers.1.self_attn.qkv_proj.weight:model.layers.1.self_attn.qkv_proj.weight model.layers.1.self_attn.qkv_proj.weight_scale @@ -52,13 +52,13 @@ model.layers.10.mlp.up_gate_proj.weight:model.layers.10.mlp.gate_up_fused_proj.w model.layers.10.mlp.up_gate_proj.weight_scale model.layers.10.post_attention_layernorm.weight model.layers.10.post_attention_layernorm.weight:model.layers.10.post_attention_layernorm.weight -model.layers.10.self_attn.k_norm.weight -model.layers.10.self_attn.k_norm.weight:model.layers.10.self_attn.k_norm.weight model.layers.10.self_attn.o_proj.weight model.layers.10.self_attn.o_proj.weight:model.layers.10.self_attn.o_proj.weight model.layers.10.self_attn.o_proj.weight_scale -model.layers.10.self_attn.q_norm.weight -model.layers.10.self_attn.q_norm.weight:model.layers.10.self_attn.q_norm.weight +model.layers.10.self_attn.qk_norm.k_norm.weight +model.layers.10.self_attn.qk_norm.k_norm.weight:model.layers.10.self_attn.qk_norm.k_norm.weight +model.layers.10.self_attn.qk_norm.q_norm.weight +model.layers.10.self_attn.qk_norm.q_norm.weight:model.layers.10.self_attn.qk_norm.q_norm.weight model.layers.10.self_attn.qkv_proj.weight model.layers.10.self_attn.qkv_proj.weight:model.layers.10.self_attn.qkv_proj.weight model.layers.10.self_attn.qkv_proj.weight_scale @@ -72,13 +72,13 @@ model.layers.11.mlp.up_gate_proj.weight:model.layers.11.mlp.gate_up_fused_proj.w model.layers.11.mlp.up_gate_proj.weight_scale model.layers.11.post_attention_layernorm.weight model.layers.11.post_attention_layernorm.weight:model.layers.11.post_attention_layernorm.weight -model.layers.11.self_attn.k_norm.weight 
-model.layers.11.self_attn.k_norm.weight:model.layers.11.self_attn.k_norm.weight model.layers.11.self_attn.o_proj.weight model.layers.11.self_attn.o_proj.weight:model.layers.11.self_attn.o_proj.weight model.layers.11.self_attn.o_proj.weight_scale -model.layers.11.self_attn.q_norm.weight -model.layers.11.self_attn.q_norm.weight:model.layers.11.self_attn.q_norm.weight +model.layers.11.self_attn.qk_norm.k_norm.weight +model.layers.11.self_attn.qk_norm.k_norm.weight:model.layers.11.self_attn.qk_norm.k_norm.weight +model.layers.11.self_attn.qk_norm.q_norm.weight +model.layers.11.self_attn.qk_norm.q_norm.weight:model.layers.11.self_attn.qk_norm.q_norm.weight model.layers.11.self_attn.qkv_proj.weight model.layers.11.self_attn.qkv_proj.weight:model.layers.11.self_attn.qkv_proj.weight model.layers.11.self_attn.qkv_proj.weight_scale @@ -92,13 +92,13 @@ model.layers.12.mlp.up_gate_proj.weight:model.layers.12.mlp.gate_up_fused_proj.w model.layers.12.mlp.up_gate_proj.weight_scale model.layers.12.post_attention_layernorm.weight model.layers.12.post_attention_layernorm.weight:model.layers.12.post_attention_layernorm.weight -model.layers.12.self_attn.k_norm.weight -model.layers.12.self_attn.k_norm.weight:model.layers.12.self_attn.k_norm.weight model.layers.12.self_attn.o_proj.weight model.layers.12.self_attn.o_proj.weight:model.layers.12.self_attn.o_proj.weight model.layers.12.self_attn.o_proj.weight_scale -model.layers.12.self_attn.q_norm.weight -model.layers.12.self_attn.q_norm.weight:model.layers.12.self_attn.q_norm.weight +model.layers.12.self_attn.qk_norm.k_norm.weight +model.layers.12.self_attn.qk_norm.k_norm.weight:model.layers.12.self_attn.qk_norm.k_norm.weight +model.layers.12.self_attn.qk_norm.q_norm.weight +model.layers.12.self_attn.qk_norm.q_norm.weight:model.layers.12.self_attn.qk_norm.q_norm.weight model.layers.12.self_attn.qkv_proj.weight model.layers.12.self_attn.qkv_proj.weight:model.layers.12.self_attn.qkv_proj.weight model.layers.12.self_attn.qkv_proj.weight_scale @@ -112,13 +112,13 @@ model.layers.13.mlp.up_gate_proj.weight:model.layers.13.mlp.gate_up_fused_proj.w model.layers.13.mlp.up_gate_proj.weight_scale model.layers.13.post_attention_layernorm.weight model.layers.13.post_attention_layernorm.weight:model.layers.13.post_attention_layernorm.weight -model.layers.13.self_attn.k_norm.weight -model.layers.13.self_attn.k_norm.weight:model.layers.13.self_attn.k_norm.weight model.layers.13.self_attn.o_proj.weight model.layers.13.self_attn.o_proj.weight:model.layers.13.self_attn.o_proj.weight model.layers.13.self_attn.o_proj.weight_scale -model.layers.13.self_attn.q_norm.weight -model.layers.13.self_attn.q_norm.weight:model.layers.13.self_attn.q_norm.weight +model.layers.13.self_attn.qk_norm.k_norm.weight +model.layers.13.self_attn.qk_norm.k_norm.weight:model.layers.13.self_attn.qk_norm.k_norm.weight +model.layers.13.self_attn.qk_norm.q_norm.weight +model.layers.13.self_attn.qk_norm.q_norm.weight:model.layers.13.self_attn.qk_norm.q_norm.weight model.layers.13.self_attn.qkv_proj.weight model.layers.13.self_attn.qkv_proj.weight:model.layers.13.self_attn.qkv_proj.weight model.layers.13.self_attn.qkv_proj.weight_scale @@ -132,13 +132,13 @@ model.layers.14.mlp.up_gate_proj.weight:model.layers.14.mlp.gate_up_fused_proj.w model.layers.14.mlp.up_gate_proj.weight_scale model.layers.14.post_attention_layernorm.weight model.layers.14.post_attention_layernorm.weight:model.layers.14.post_attention_layernorm.weight -model.layers.14.self_attn.k_norm.weight 
-model.layers.14.self_attn.k_norm.weight:model.layers.14.self_attn.k_norm.weight model.layers.14.self_attn.o_proj.weight model.layers.14.self_attn.o_proj.weight:model.layers.14.self_attn.o_proj.weight model.layers.14.self_attn.o_proj.weight_scale -model.layers.14.self_attn.q_norm.weight -model.layers.14.self_attn.q_norm.weight:model.layers.14.self_attn.q_norm.weight +model.layers.14.self_attn.qk_norm.k_norm.weight +model.layers.14.self_attn.qk_norm.k_norm.weight:model.layers.14.self_attn.qk_norm.k_norm.weight +model.layers.14.self_attn.qk_norm.q_norm.weight +model.layers.14.self_attn.qk_norm.q_norm.weight:model.layers.14.self_attn.qk_norm.q_norm.weight model.layers.14.self_attn.qkv_proj.weight model.layers.14.self_attn.qkv_proj.weight:model.layers.14.self_attn.qkv_proj.weight model.layers.14.self_attn.qkv_proj.weight_scale @@ -152,13 +152,13 @@ model.layers.15.mlp.up_gate_proj.weight:model.layers.15.mlp.gate_up_fused_proj.w model.layers.15.mlp.up_gate_proj.weight_scale model.layers.15.post_attention_layernorm.weight model.layers.15.post_attention_layernorm.weight:model.layers.15.post_attention_layernorm.weight -model.layers.15.self_attn.k_norm.weight -model.layers.15.self_attn.k_norm.weight:model.layers.15.self_attn.k_norm.weight model.layers.15.self_attn.o_proj.weight model.layers.15.self_attn.o_proj.weight:model.layers.15.self_attn.o_proj.weight model.layers.15.self_attn.o_proj.weight_scale -model.layers.15.self_attn.q_norm.weight -model.layers.15.self_attn.q_norm.weight:model.layers.15.self_attn.q_norm.weight +model.layers.15.self_attn.qk_norm.k_norm.weight +model.layers.15.self_attn.qk_norm.k_norm.weight:model.layers.15.self_attn.qk_norm.k_norm.weight +model.layers.15.self_attn.qk_norm.q_norm.weight +model.layers.15.self_attn.qk_norm.q_norm.weight:model.layers.15.self_attn.qk_norm.q_norm.weight model.layers.15.self_attn.qkv_proj.weight model.layers.15.self_attn.qkv_proj.weight:model.layers.15.self_attn.qkv_proj.weight model.layers.15.self_attn.qkv_proj.weight_scale @@ -172,13 +172,13 @@ model.layers.16.mlp.up_gate_proj.weight:model.layers.16.mlp.gate_up_fused_proj.w model.layers.16.mlp.up_gate_proj.weight_scale model.layers.16.post_attention_layernorm.weight model.layers.16.post_attention_layernorm.weight:model.layers.16.post_attention_layernorm.weight -model.layers.16.self_attn.k_norm.weight -model.layers.16.self_attn.k_norm.weight:model.layers.16.self_attn.k_norm.weight model.layers.16.self_attn.o_proj.weight model.layers.16.self_attn.o_proj.weight:model.layers.16.self_attn.o_proj.weight model.layers.16.self_attn.o_proj.weight_scale -model.layers.16.self_attn.q_norm.weight -model.layers.16.self_attn.q_norm.weight:model.layers.16.self_attn.q_norm.weight +model.layers.16.self_attn.qk_norm.k_norm.weight +model.layers.16.self_attn.qk_norm.k_norm.weight:model.layers.16.self_attn.qk_norm.k_norm.weight +model.layers.16.self_attn.qk_norm.q_norm.weight +model.layers.16.self_attn.qk_norm.q_norm.weight:model.layers.16.self_attn.qk_norm.q_norm.weight model.layers.16.self_attn.qkv_proj.weight model.layers.16.self_attn.qkv_proj.weight:model.layers.16.self_attn.qkv_proj.weight model.layers.16.self_attn.qkv_proj.weight_scale @@ -192,13 +192,13 @@ model.layers.17.mlp.up_gate_proj.weight:model.layers.17.mlp.gate_up_fused_proj.w model.layers.17.mlp.up_gate_proj.weight_scale model.layers.17.post_attention_layernorm.weight model.layers.17.post_attention_layernorm.weight:model.layers.17.post_attention_layernorm.weight -model.layers.17.self_attn.k_norm.weight 
-model.layers.17.self_attn.k_norm.weight:model.layers.17.self_attn.k_norm.weight model.layers.17.self_attn.o_proj.weight model.layers.17.self_attn.o_proj.weight:model.layers.17.self_attn.o_proj.weight model.layers.17.self_attn.o_proj.weight_scale -model.layers.17.self_attn.q_norm.weight -model.layers.17.self_attn.q_norm.weight:model.layers.17.self_attn.q_norm.weight +model.layers.17.self_attn.qk_norm.k_norm.weight +model.layers.17.self_attn.qk_norm.k_norm.weight:model.layers.17.self_attn.qk_norm.k_norm.weight +model.layers.17.self_attn.qk_norm.q_norm.weight +model.layers.17.self_attn.qk_norm.q_norm.weight:model.layers.17.self_attn.qk_norm.q_norm.weight model.layers.17.self_attn.qkv_proj.weight model.layers.17.self_attn.qkv_proj.weight:model.layers.17.self_attn.qkv_proj.weight model.layers.17.self_attn.qkv_proj.weight_scale @@ -212,13 +212,13 @@ model.layers.18.mlp.up_gate_proj.weight:model.layers.18.mlp.gate_up_fused_proj.w model.layers.18.mlp.up_gate_proj.weight_scale model.layers.18.post_attention_layernorm.weight model.layers.18.post_attention_layernorm.weight:model.layers.18.post_attention_layernorm.weight -model.layers.18.self_attn.k_norm.weight -model.layers.18.self_attn.k_norm.weight:model.layers.18.self_attn.k_norm.weight model.layers.18.self_attn.o_proj.weight model.layers.18.self_attn.o_proj.weight:model.layers.18.self_attn.o_proj.weight model.layers.18.self_attn.o_proj.weight_scale -model.layers.18.self_attn.q_norm.weight -model.layers.18.self_attn.q_norm.weight:model.layers.18.self_attn.q_norm.weight +model.layers.18.self_attn.qk_norm.k_norm.weight +model.layers.18.self_attn.qk_norm.k_norm.weight:model.layers.18.self_attn.qk_norm.k_norm.weight +model.layers.18.self_attn.qk_norm.q_norm.weight +model.layers.18.self_attn.qk_norm.q_norm.weight:model.layers.18.self_attn.qk_norm.q_norm.weight model.layers.18.self_attn.qkv_proj.weight model.layers.18.self_attn.qkv_proj.weight:model.layers.18.self_attn.qkv_proj.weight model.layers.18.self_attn.qkv_proj.weight_scale @@ -232,13 +232,13 @@ model.layers.19.mlp.up_gate_proj.weight:model.layers.19.mlp.gate_up_fused_proj.w model.layers.19.mlp.up_gate_proj.weight_scale model.layers.19.post_attention_layernorm.weight model.layers.19.post_attention_layernorm.weight:model.layers.19.post_attention_layernorm.weight -model.layers.19.self_attn.k_norm.weight -model.layers.19.self_attn.k_norm.weight:model.layers.19.self_attn.k_norm.weight model.layers.19.self_attn.o_proj.weight model.layers.19.self_attn.o_proj.weight:model.layers.19.self_attn.o_proj.weight model.layers.19.self_attn.o_proj.weight_scale -model.layers.19.self_attn.q_norm.weight -model.layers.19.self_attn.q_norm.weight:model.layers.19.self_attn.q_norm.weight +model.layers.19.self_attn.qk_norm.k_norm.weight +model.layers.19.self_attn.qk_norm.k_norm.weight:model.layers.19.self_attn.qk_norm.k_norm.weight +model.layers.19.self_attn.qk_norm.q_norm.weight +model.layers.19.self_attn.qk_norm.q_norm.weight:model.layers.19.self_attn.qk_norm.q_norm.weight model.layers.19.self_attn.qkv_proj.weight model.layers.19.self_attn.qkv_proj.weight:model.layers.19.self_attn.qkv_proj.weight model.layers.19.self_attn.qkv_proj.weight_scale @@ -252,13 +252,13 @@ model.layers.2.mlp.up_gate_proj.weight:model.layers.2.mlp.gate_up_fused_proj.wei model.layers.2.mlp.up_gate_proj.weight_scale model.layers.2.post_attention_layernorm.weight model.layers.2.post_attention_layernorm.weight:model.layers.2.post_attention_layernorm.weight -model.layers.2.self_attn.k_norm.weight 
-model.layers.2.self_attn.k_norm.weight:model.layers.2.self_attn.k_norm.weight model.layers.2.self_attn.o_proj.weight model.layers.2.self_attn.o_proj.weight:model.layers.2.self_attn.o_proj.weight model.layers.2.self_attn.o_proj.weight_scale -model.layers.2.self_attn.q_norm.weight -model.layers.2.self_attn.q_norm.weight:model.layers.2.self_attn.q_norm.weight +model.layers.2.self_attn.qk_norm.k_norm.weight +model.layers.2.self_attn.qk_norm.k_norm.weight:model.layers.2.self_attn.qk_norm.k_norm.weight +model.layers.2.self_attn.qk_norm.q_norm.weight +model.layers.2.self_attn.qk_norm.q_norm.weight:model.layers.2.self_attn.qk_norm.q_norm.weight model.layers.2.self_attn.qkv_proj.weight model.layers.2.self_attn.qkv_proj.weight:model.layers.2.self_attn.qkv_proj.weight model.layers.2.self_attn.qkv_proj.weight_scale @@ -272,13 +272,13 @@ model.layers.20.mlp.up_gate_proj.weight:model.layers.20.mlp.gate_up_fused_proj.w model.layers.20.mlp.up_gate_proj.weight_scale model.layers.20.post_attention_layernorm.weight model.layers.20.post_attention_layernorm.weight:model.layers.20.post_attention_layernorm.weight -model.layers.20.self_attn.k_norm.weight -model.layers.20.self_attn.k_norm.weight:model.layers.20.self_attn.k_norm.weight model.layers.20.self_attn.o_proj.weight model.layers.20.self_attn.o_proj.weight:model.layers.20.self_attn.o_proj.weight model.layers.20.self_attn.o_proj.weight_scale -model.layers.20.self_attn.q_norm.weight -model.layers.20.self_attn.q_norm.weight:model.layers.20.self_attn.q_norm.weight +model.layers.20.self_attn.qk_norm.k_norm.weight +model.layers.20.self_attn.qk_norm.k_norm.weight:model.layers.20.self_attn.qk_norm.k_norm.weight +model.layers.20.self_attn.qk_norm.q_norm.weight +model.layers.20.self_attn.qk_norm.q_norm.weight:model.layers.20.self_attn.qk_norm.q_norm.weight model.layers.20.self_attn.qkv_proj.weight model.layers.20.self_attn.qkv_proj.weight:model.layers.20.self_attn.qkv_proj.weight model.layers.20.self_attn.qkv_proj.weight_scale @@ -292,13 +292,13 @@ model.layers.21.mlp.up_gate_proj.weight:model.layers.21.mlp.gate_up_fused_proj.w model.layers.21.mlp.up_gate_proj.weight_scale model.layers.21.post_attention_layernorm.weight model.layers.21.post_attention_layernorm.weight:model.layers.21.post_attention_layernorm.weight -model.layers.21.self_attn.k_norm.weight -model.layers.21.self_attn.k_norm.weight:model.layers.21.self_attn.k_norm.weight model.layers.21.self_attn.o_proj.weight model.layers.21.self_attn.o_proj.weight:model.layers.21.self_attn.o_proj.weight model.layers.21.self_attn.o_proj.weight_scale -model.layers.21.self_attn.q_norm.weight -model.layers.21.self_attn.q_norm.weight:model.layers.21.self_attn.q_norm.weight +model.layers.21.self_attn.qk_norm.k_norm.weight +model.layers.21.self_attn.qk_norm.k_norm.weight:model.layers.21.self_attn.qk_norm.k_norm.weight +model.layers.21.self_attn.qk_norm.q_norm.weight +model.layers.21.self_attn.qk_norm.q_norm.weight:model.layers.21.self_attn.qk_norm.q_norm.weight model.layers.21.self_attn.qkv_proj.weight model.layers.21.self_attn.qkv_proj.weight:model.layers.21.self_attn.qkv_proj.weight model.layers.21.self_attn.qkv_proj.weight_scale @@ -312,13 +312,13 @@ model.layers.22.mlp.up_gate_proj.weight:model.layers.22.mlp.gate_up_fused_proj.w model.layers.22.mlp.up_gate_proj.weight_scale model.layers.22.post_attention_layernorm.weight model.layers.22.post_attention_layernorm.weight:model.layers.22.post_attention_layernorm.weight -model.layers.22.self_attn.k_norm.weight 
-model.layers.22.self_attn.k_norm.weight:model.layers.22.self_attn.k_norm.weight model.layers.22.self_attn.o_proj.weight model.layers.22.self_attn.o_proj.weight:model.layers.22.self_attn.o_proj.weight model.layers.22.self_attn.o_proj.weight_scale -model.layers.22.self_attn.q_norm.weight -model.layers.22.self_attn.q_norm.weight:model.layers.22.self_attn.q_norm.weight +model.layers.22.self_attn.qk_norm.k_norm.weight +model.layers.22.self_attn.qk_norm.k_norm.weight:model.layers.22.self_attn.qk_norm.k_norm.weight +model.layers.22.self_attn.qk_norm.q_norm.weight +model.layers.22.self_attn.qk_norm.q_norm.weight:model.layers.22.self_attn.qk_norm.q_norm.weight model.layers.22.self_attn.qkv_proj.weight model.layers.22.self_attn.qkv_proj.weight:model.layers.22.self_attn.qkv_proj.weight model.layers.22.self_attn.qkv_proj.weight_scale @@ -332,13 +332,13 @@ model.layers.23.mlp.up_gate_proj.weight:model.layers.23.mlp.gate_up_fused_proj.w model.layers.23.mlp.up_gate_proj.weight_scale model.layers.23.post_attention_layernorm.weight model.layers.23.post_attention_layernorm.weight:model.layers.23.post_attention_layernorm.weight -model.layers.23.self_attn.k_norm.weight -model.layers.23.self_attn.k_norm.weight:model.layers.23.self_attn.k_norm.weight model.layers.23.self_attn.o_proj.weight model.layers.23.self_attn.o_proj.weight:model.layers.23.self_attn.o_proj.weight model.layers.23.self_attn.o_proj.weight_scale -model.layers.23.self_attn.q_norm.weight -model.layers.23.self_attn.q_norm.weight:model.layers.23.self_attn.q_norm.weight +model.layers.23.self_attn.qk_norm.k_norm.weight +model.layers.23.self_attn.qk_norm.k_norm.weight:model.layers.23.self_attn.qk_norm.k_norm.weight +model.layers.23.self_attn.qk_norm.q_norm.weight +model.layers.23.self_attn.qk_norm.q_norm.weight:model.layers.23.self_attn.qk_norm.q_norm.weight model.layers.23.self_attn.qkv_proj.weight model.layers.23.self_attn.qkv_proj.weight:model.layers.23.self_attn.qkv_proj.weight model.layers.23.self_attn.qkv_proj.weight_scale @@ -352,13 +352,13 @@ model.layers.24.mlp.up_gate_proj.weight:model.layers.24.mlp.gate_up_fused_proj.w model.layers.24.mlp.up_gate_proj.weight_scale model.layers.24.post_attention_layernorm.weight model.layers.24.post_attention_layernorm.weight:model.layers.24.post_attention_layernorm.weight -model.layers.24.self_attn.k_norm.weight -model.layers.24.self_attn.k_norm.weight:model.layers.24.self_attn.k_norm.weight model.layers.24.self_attn.o_proj.weight model.layers.24.self_attn.o_proj.weight:model.layers.24.self_attn.o_proj.weight model.layers.24.self_attn.o_proj.weight_scale -model.layers.24.self_attn.q_norm.weight -model.layers.24.self_attn.q_norm.weight:model.layers.24.self_attn.q_norm.weight +model.layers.24.self_attn.qk_norm.k_norm.weight +model.layers.24.self_attn.qk_norm.k_norm.weight:model.layers.24.self_attn.qk_norm.k_norm.weight +model.layers.24.self_attn.qk_norm.q_norm.weight +model.layers.24.self_attn.qk_norm.q_norm.weight:model.layers.24.self_attn.qk_norm.q_norm.weight model.layers.24.self_attn.qkv_proj.weight model.layers.24.self_attn.qkv_proj.weight:model.layers.24.self_attn.qkv_proj.weight model.layers.24.self_attn.qkv_proj.weight_scale @@ -372,13 +372,13 @@ model.layers.25.mlp.up_gate_proj.weight:model.layers.25.mlp.gate_up_fused_proj.w model.layers.25.mlp.up_gate_proj.weight_scale model.layers.25.post_attention_layernorm.weight model.layers.25.post_attention_layernorm.weight:model.layers.25.post_attention_layernorm.weight -model.layers.25.self_attn.k_norm.weight 
-model.layers.25.self_attn.k_norm.weight:model.layers.25.self_attn.k_norm.weight model.layers.25.self_attn.o_proj.weight model.layers.25.self_attn.o_proj.weight:model.layers.25.self_attn.o_proj.weight model.layers.25.self_attn.o_proj.weight_scale -model.layers.25.self_attn.q_norm.weight -model.layers.25.self_attn.q_norm.weight:model.layers.25.self_attn.q_norm.weight +model.layers.25.self_attn.qk_norm.k_norm.weight +model.layers.25.self_attn.qk_norm.k_norm.weight:model.layers.25.self_attn.qk_norm.k_norm.weight +model.layers.25.self_attn.qk_norm.q_norm.weight +model.layers.25.self_attn.qk_norm.q_norm.weight:model.layers.25.self_attn.qk_norm.q_norm.weight model.layers.25.self_attn.qkv_proj.weight model.layers.25.self_attn.qkv_proj.weight:model.layers.25.self_attn.qkv_proj.weight model.layers.25.self_attn.qkv_proj.weight_scale @@ -392,13 +392,13 @@ model.layers.26.mlp.up_gate_proj.weight:model.layers.26.mlp.gate_up_fused_proj.w model.layers.26.mlp.up_gate_proj.weight_scale model.layers.26.post_attention_layernorm.weight model.layers.26.post_attention_layernorm.weight:model.layers.26.post_attention_layernorm.weight -model.layers.26.self_attn.k_norm.weight -model.layers.26.self_attn.k_norm.weight:model.layers.26.self_attn.k_norm.weight model.layers.26.self_attn.o_proj.weight model.layers.26.self_attn.o_proj.weight:model.layers.26.self_attn.o_proj.weight model.layers.26.self_attn.o_proj.weight_scale -model.layers.26.self_attn.q_norm.weight -model.layers.26.self_attn.q_norm.weight:model.layers.26.self_attn.q_norm.weight +model.layers.26.self_attn.qk_norm.k_norm.weight +model.layers.26.self_attn.qk_norm.k_norm.weight:model.layers.26.self_attn.qk_norm.k_norm.weight +model.layers.26.self_attn.qk_norm.q_norm.weight +model.layers.26.self_attn.qk_norm.q_norm.weight:model.layers.26.self_attn.qk_norm.q_norm.weight model.layers.26.self_attn.qkv_proj.weight model.layers.26.self_attn.qkv_proj.weight:model.layers.26.self_attn.qkv_proj.weight model.layers.26.self_attn.qkv_proj.weight_scale @@ -412,13 +412,13 @@ model.layers.27.mlp.up_gate_proj.weight:model.layers.27.mlp.gate_up_fused_proj.w model.layers.27.mlp.up_gate_proj.weight_scale model.layers.27.post_attention_layernorm.weight model.layers.27.post_attention_layernorm.weight:model.layers.27.post_attention_layernorm.weight -model.layers.27.self_attn.k_norm.weight -model.layers.27.self_attn.k_norm.weight:model.layers.27.self_attn.k_norm.weight model.layers.27.self_attn.o_proj.weight model.layers.27.self_attn.o_proj.weight:model.layers.27.self_attn.o_proj.weight model.layers.27.self_attn.o_proj.weight_scale -model.layers.27.self_attn.q_norm.weight -model.layers.27.self_attn.q_norm.weight:model.layers.27.self_attn.q_norm.weight +model.layers.27.self_attn.qk_norm.k_norm.weight +model.layers.27.self_attn.qk_norm.k_norm.weight:model.layers.27.self_attn.qk_norm.k_norm.weight +model.layers.27.self_attn.qk_norm.q_norm.weight +model.layers.27.self_attn.qk_norm.q_norm.weight:model.layers.27.self_attn.qk_norm.q_norm.weight model.layers.27.self_attn.qkv_proj.weight model.layers.27.self_attn.qkv_proj.weight:model.layers.27.self_attn.qkv_proj.weight model.layers.27.self_attn.qkv_proj.weight_scale @@ -432,13 +432,13 @@ model.layers.28.mlp.up_gate_proj.weight:model.layers.28.mlp.gate_up_fused_proj.w model.layers.28.mlp.up_gate_proj.weight_scale model.layers.28.post_attention_layernorm.weight model.layers.28.post_attention_layernorm.weight:model.layers.28.post_attention_layernorm.weight -model.layers.28.self_attn.k_norm.weight 
-model.layers.28.self_attn.k_norm.weight:model.layers.28.self_attn.k_norm.weight model.layers.28.self_attn.o_proj.weight model.layers.28.self_attn.o_proj.weight:model.layers.28.self_attn.o_proj.weight model.layers.28.self_attn.o_proj.weight_scale -model.layers.28.self_attn.q_norm.weight -model.layers.28.self_attn.q_norm.weight:model.layers.28.self_attn.q_norm.weight +model.layers.28.self_attn.qk_norm.k_norm.weight +model.layers.28.self_attn.qk_norm.k_norm.weight:model.layers.28.self_attn.qk_norm.k_norm.weight +model.layers.28.self_attn.qk_norm.q_norm.weight +model.layers.28.self_attn.qk_norm.q_norm.weight:model.layers.28.self_attn.qk_norm.q_norm.weight model.layers.28.self_attn.qkv_proj.weight model.layers.28.self_attn.qkv_proj.weight:model.layers.28.self_attn.qkv_proj.weight model.layers.28.self_attn.qkv_proj.weight_scale @@ -452,13 +452,13 @@ model.layers.29.mlp.up_gate_proj.weight:model.layers.29.mlp.gate_up_fused_proj.w model.layers.29.mlp.up_gate_proj.weight_scale model.layers.29.post_attention_layernorm.weight model.layers.29.post_attention_layernorm.weight:model.layers.29.post_attention_layernorm.weight -model.layers.29.self_attn.k_norm.weight -model.layers.29.self_attn.k_norm.weight:model.layers.29.self_attn.k_norm.weight model.layers.29.self_attn.o_proj.weight model.layers.29.self_attn.o_proj.weight:model.layers.29.self_attn.o_proj.weight model.layers.29.self_attn.o_proj.weight_scale -model.layers.29.self_attn.q_norm.weight -model.layers.29.self_attn.q_norm.weight:model.layers.29.self_attn.q_norm.weight +model.layers.29.self_attn.qk_norm.k_norm.weight +model.layers.29.self_attn.qk_norm.k_norm.weight:model.layers.29.self_attn.qk_norm.k_norm.weight +model.layers.29.self_attn.qk_norm.q_norm.weight +model.layers.29.self_attn.qk_norm.q_norm.weight:model.layers.29.self_attn.qk_norm.q_norm.weight model.layers.29.self_attn.qkv_proj.weight model.layers.29.self_attn.qkv_proj.weight:model.layers.29.self_attn.qkv_proj.weight model.layers.29.self_attn.qkv_proj.weight_scale @@ -472,13 +472,13 @@ model.layers.3.mlp.up_gate_proj.weight:model.layers.3.mlp.gate_up_fused_proj.wei model.layers.3.mlp.up_gate_proj.weight_scale model.layers.3.post_attention_layernorm.weight model.layers.3.post_attention_layernorm.weight:model.layers.3.post_attention_layernorm.weight -model.layers.3.self_attn.k_norm.weight -model.layers.3.self_attn.k_norm.weight:model.layers.3.self_attn.k_norm.weight model.layers.3.self_attn.o_proj.weight model.layers.3.self_attn.o_proj.weight:model.layers.3.self_attn.o_proj.weight model.layers.3.self_attn.o_proj.weight_scale -model.layers.3.self_attn.q_norm.weight -model.layers.3.self_attn.q_norm.weight:model.layers.3.self_attn.q_norm.weight +model.layers.3.self_attn.qk_norm.k_norm.weight +model.layers.3.self_attn.qk_norm.k_norm.weight:model.layers.3.self_attn.qk_norm.k_norm.weight +model.layers.3.self_attn.qk_norm.q_norm.weight +model.layers.3.self_attn.qk_norm.q_norm.weight:model.layers.3.self_attn.qk_norm.q_norm.weight model.layers.3.self_attn.qkv_proj.weight model.layers.3.self_attn.qkv_proj.weight:model.layers.3.self_attn.qkv_proj.weight model.layers.3.self_attn.qkv_proj.weight_scale @@ -492,13 +492,13 @@ model.layers.30.mlp.up_gate_proj.weight:model.layers.30.mlp.gate_up_fused_proj.w model.layers.30.mlp.up_gate_proj.weight_scale model.layers.30.post_attention_layernorm.weight model.layers.30.post_attention_layernorm.weight:model.layers.30.post_attention_layernorm.weight -model.layers.30.self_attn.k_norm.weight 
-model.layers.30.self_attn.k_norm.weight:model.layers.30.self_attn.k_norm.weight model.layers.30.self_attn.o_proj.weight model.layers.30.self_attn.o_proj.weight:model.layers.30.self_attn.o_proj.weight model.layers.30.self_attn.o_proj.weight_scale -model.layers.30.self_attn.q_norm.weight -model.layers.30.self_attn.q_norm.weight:model.layers.30.self_attn.q_norm.weight +model.layers.30.self_attn.qk_norm.k_norm.weight +model.layers.30.self_attn.qk_norm.k_norm.weight:model.layers.30.self_attn.qk_norm.k_norm.weight +model.layers.30.self_attn.qk_norm.q_norm.weight +model.layers.30.self_attn.qk_norm.q_norm.weight:model.layers.30.self_attn.qk_norm.q_norm.weight model.layers.30.self_attn.qkv_proj.weight model.layers.30.self_attn.qkv_proj.weight:model.layers.30.self_attn.qkv_proj.weight model.layers.30.self_attn.qkv_proj.weight_scale @@ -512,13 +512,13 @@ model.layers.31.mlp.up_gate_proj.weight:model.layers.31.mlp.gate_up_fused_proj.w model.layers.31.mlp.up_gate_proj.weight_scale model.layers.31.post_attention_layernorm.weight model.layers.31.post_attention_layernorm.weight:model.layers.31.post_attention_layernorm.weight -model.layers.31.self_attn.k_norm.weight -model.layers.31.self_attn.k_norm.weight:model.layers.31.self_attn.k_norm.weight model.layers.31.self_attn.o_proj.weight model.layers.31.self_attn.o_proj.weight:model.layers.31.self_attn.o_proj.weight model.layers.31.self_attn.o_proj.weight_scale -model.layers.31.self_attn.q_norm.weight -model.layers.31.self_attn.q_norm.weight:model.layers.31.self_attn.q_norm.weight +model.layers.31.self_attn.qk_norm.k_norm.weight +model.layers.31.self_attn.qk_norm.k_norm.weight:model.layers.31.self_attn.qk_norm.k_norm.weight +model.layers.31.self_attn.qk_norm.q_norm.weight +model.layers.31.self_attn.qk_norm.q_norm.weight:model.layers.31.self_attn.qk_norm.q_norm.weight model.layers.31.self_attn.qkv_proj.weight model.layers.31.self_attn.qkv_proj.weight:model.layers.31.self_attn.qkv_proj.weight model.layers.31.self_attn.qkv_proj.weight_scale @@ -532,13 +532,13 @@ model.layers.32.mlp.up_gate_proj.weight:model.layers.32.mlp.gate_up_fused_proj.w model.layers.32.mlp.up_gate_proj.weight_scale model.layers.32.post_attention_layernorm.weight model.layers.32.post_attention_layernorm.weight:model.layers.32.post_attention_layernorm.weight -model.layers.32.self_attn.k_norm.weight -model.layers.32.self_attn.k_norm.weight:model.layers.32.self_attn.k_norm.weight model.layers.32.self_attn.o_proj.weight model.layers.32.self_attn.o_proj.weight:model.layers.32.self_attn.o_proj.weight model.layers.32.self_attn.o_proj.weight_scale -model.layers.32.self_attn.q_norm.weight -model.layers.32.self_attn.q_norm.weight:model.layers.32.self_attn.q_norm.weight +model.layers.32.self_attn.qk_norm.k_norm.weight +model.layers.32.self_attn.qk_norm.k_norm.weight:model.layers.32.self_attn.qk_norm.k_norm.weight +model.layers.32.self_attn.qk_norm.q_norm.weight +model.layers.32.self_attn.qk_norm.q_norm.weight:model.layers.32.self_attn.qk_norm.q_norm.weight model.layers.32.self_attn.qkv_proj.weight model.layers.32.self_attn.qkv_proj.weight:model.layers.32.self_attn.qkv_proj.weight model.layers.32.self_attn.qkv_proj.weight_scale @@ -552,13 +552,13 @@ model.layers.33.mlp.up_gate_proj.weight:model.layers.33.mlp.gate_up_fused_proj.w model.layers.33.mlp.up_gate_proj.weight_scale model.layers.33.post_attention_layernorm.weight model.layers.33.post_attention_layernorm.weight:model.layers.33.post_attention_layernorm.weight -model.layers.33.self_attn.k_norm.weight 
-model.layers.33.self_attn.k_norm.weight:model.layers.33.self_attn.k_norm.weight model.layers.33.self_attn.o_proj.weight model.layers.33.self_attn.o_proj.weight:model.layers.33.self_attn.o_proj.weight model.layers.33.self_attn.o_proj.weight_scale -model.layers.33.self_attn.q_norm.weight -model.layers.33.self_attn.q_norm.weight:model.layers.33.self_attn.q_norm.weight +model.layers.33.self_attn.qk_norm.k_norm.weight +model.layers.33.self_attn.qk_norm.k_norm.weight:model.layers.33.self_attn.qk_norm.k_norm.weight +model.layers.33.self_attn.qk_norm.q_norm.weight +model.layers.33.self_attn.qk_norm.q_norm.weight:model.layers.33.self_attn.qk_norm.q_norm.weight model.layers.33.self_attn.qkv_proj.weight model.layers.33.self_attn.qkv_proj.weight:model.layers.33.self_attn.qkv_proj.weight model.layers.33.self_attn.qkv_proj.weight_scale @@ -572,13 +572,13 @@ model.layers.34.mlp.up_gate_proj.weight:model.layers.34.mlp.gate_up_fused_proj.w model.layers.34.mlp.up_gate_proj.weight_scale model.layers.34.post_attention_layernorm.weight model.layers.34.post_attention_layernorm.weight:model.layers.34.post_attention_layernorm.weight -model.layers.34.self_attn.k_norm.weight -model.layers.34.self_attn.k_norm.weight:model.layers.34.self_attn.k_norm.weight model.layers.34.self_attn.o_proj.weight model.layers.34.self_attn.o_proj.weight:model.layers.34.self_attn.o_proj.weight model.layers.34.self_attn.o_proj.weight_scale -model.layers.34.self_attn.q_norm.weight -model.layers.34.self_attn.q_norm.weight:model.layers.34.self_attn.q_norm.weight +model.layers.34.self_attn.qk_norm.k_norm.weight +model.layers.34.self_attn.qk_norm.k_norm.weight:model.layers.34.self_attn.qk_norm.k_norm.weight +model.layers.34.self_attn.qk_norm.q_norm.weight +model.layers.34.self_attn.qk_norm.q_norm.weight:model.layers.34.self_attn.qk_norm.q_norm.weight model.layers.34.self_attn.qkv_proj.weight model.layers.34.self_attn.qkv_proj.weight:model.layers.34.self_attn.qkv_proj.weight model.layers.34.self_attn.qkv_proj.weight_scale @@ -592,13 +592,13 @@ model.layers.35.mlp.up_gate_proj.weight:model.layers.35.mlp.gate_up_fused_proj.w model.layers.35.mlp.up_gate_proj.weight_scale model.layers.35.post_attention_layernorm.weight model.layers.35.post_attention_layernorm.weight:model.layers.35.post_attention_layernorm.weight -model.layers.35.self_attn.k_norm.weight -model.layers.35.self_attn.k_norm.weight:model.layers.35.self_attn.k_norm.weight model.layers.35.self_attn.o_proj.weight model.layers.35.self_attn.o_proj.weight:model.layers.35.self_attn.o_proj.weight model.layers.35.self_attn.o_proj.weight_scale -model.layers.35.self_attn.q_norm.weight -model.layers.35.self_attn.q_norm.weight:model.layers.35.self_attn.q_norm.weight +model.layers.35.self_attn.qk_norm.k_norm.weight +model.layers.35.self_attn.qk_norm.k_norm.weight:model.layers.35.self_attn.qk_norm.k_norm.weight +model.layers.35.self_attn.qk_norm.q_norm.weight +model.layers.35.self_attn.qk_norm.q_norm.weight:model.layers.35.self_attn.qk_norm.q_norm.weight model.layers.35.self_attn.qkv_proj.weight model.layers.35.self_attn.qkv_proj.weight:model.layers.35.self_attn.qkv_proj.weight model.layers.35.self_attn.qkv_proj.weight_scale @@ -612,13 +612,13 @@ model.layers.4.mlp.up_gate_proj.weight:model.layers.4.mlp.gate_up_fused_proj.wei model.layers.4.mlp.up_gate_proj.weight_scale model.layers.4.post_attention_layernorm.weight model.layers.4.post_attention_layernorm.weight:model.layers.4.post_attention_layernorm.weight -model.layers.4.self_attn.k_norm.weight 
-model.layers.4.self_attn.k_norm.weight:model.layers.4.self_attn.k_norm.weight model.layers.4.self_attn.o_proj.weight model.layers.4.self_attn.o_proj.weight:model.layers.4.self_attn.o_proj.weight model.layers.4.self_attn.o_proj.weight_scale -model.layers.4.self_attn.q_norm.weight -model.layers.4.self_attn.q_norm.weight:model.layers.4.self_attn.q_norm.weight +model.layers.4.self_attn.qk_norm.k_norm.weight +model.layers.4.self_attn.qk_norm.k_norm.weight:model.layers.4.self_attn.qk_norm.k_norm.weight +model.layers.4.self_attn.qk_norm.q_norm.weight +model.layers.4.self_attn.qk_norm.q_norm.weight:model.layers.4.self_attn.qk_norm.q_norm.weight model.layers.4.self_attn.qkv_proj.weight model.layers.4.self_attn.qkv_proj.weight:model.layers.4.self_attn.qkv_proj.weight model.layers.4.self_attn.qkv_proj.weight_scale @@ -632,13 +632,13 @@ model.layers.5.mlp.up_gate_proj.weight:model.layers.5.mlp.gate_up_fused_proj.wei model.layers.5.mlp.up_gate_proj.weight_scale model.layers.5.post_attention_layernorm.weight model.layers.5.post_attention_layernorm.weight:model.layers.5.post_attention_layernorm.weight -model.layers.5.self_attn.k_norm.weight -model.layers.5.self_attn.k_norm.weight:model.layers.5.self_attn.k_norm.weight model.layers.5.self_attn.o_proj.weight model.layers.5.self_attn.o_proj.weight:model.layers.5.self_attn.o_proj.weight model.layers.5.self_attn.o_proj.weight_scale -model.layers.5.self_attn.q_norm.weight -model.layers.5.self_attn.q_norm.weight:model.layers.5.self_attn.q_norm.weight +model.layers.5.self_attn.qk_norm.k_norm.weight +model.layers.5.self_attn.qk_norm.k_norm.weight:model.layers.5.self_attn.qk_norm.k_norm.weight +model.layers.5.self_attn.qk_norm.q_norm.weight +model.layers.5.self_attn.qk_norm.q_norm.weight:model.layers.5.self_attn.qk_norm.q_norm.weight model.layers.5.self_attn.qkv_proj.weight model.layers.5.self_attn.qkv_proj.weight:model.layers.5.self_attn.qkv_proj.weight model.layers.5.self_attn.qkv_proj.weight_scale @@ -652,13 +652,13 @@ model.layers.6.mlp.up_gate_proj.weight:model.layers.6.mlp.gate_up_fused_proj.wei model.layers.6.mlp.up_gate_proj.weight_scale model.layers.6.post_attention_layernorm.weight model.layers.6.post_attention_layernorm.weight:model.layers.6.post_attention_layernorm.weight -model.layers.6.self_attn.k_norm.weight -model.layers.6.self_attn.k_norm.weight:model.layers.6.self_attn.k_norm.weight model.layers.6.self_attn.o_proj.weight model.layers.6.self_attn.o_proj.weight:model.layers.6.self_attn.o_proj.weight model.layers.6.self_attn.o_proj.weight_scale -model.layers.6.self_attn.q_norm.weight -model.layers.6.self_attn.q_norm.weight:model.layers.6.self_attn.q_norm.weight +model.layers.6.self_attn.qk_norm.k_norm.weight +model.layers.6.self_attn.qk_norm.k_norm.weight:model.layers.6.self_attn.qk_norm.k_norm.weight +model.layers.6.self_attn.qk_norm.q_norm.weight +model.layers.6.self_attn.qk_norm.q_norm.weight:model.layers.6.self_attn.qk_norm.q_norm.weight model.layers.6.self_attn.qkv_proj.weight model.layers.6.self_attn.qkv_proj.weight:model.layers.6.self_attn.qkv_proj.weight model.layers.6.self_attn.qkv_proj.weight_scale @@ -672,13 +672,13 @@ model.layers.7.mlp.up_gate_proj.weight:model.layers.7.mlp.gate_up_fused_proj.wei model.layers.7.mlp.up_gate_proj.weight_scale model.layers.7.post_attention_layernorm.weight model.layers.7.post_attention_layernorm.weight:model.layers.7.post_attention_layernorm.weight -model.layers.7.self_attn.k_norm.weight -model.layers.7.self_attn.k_norm.weight:model.layers.7.self_attn.k_norm.weight model.layers.7.self_attn.o_proj.weight 
model.layers.7.self_attn.o_proj.weight:model.layers.7.self_attn.o_proj.weight
model.layers.7.self_attn.o_proj.weight_scale
-model.layers.7.self_attn.q_norm.weight
-model.layers.7.self_attn.q_norm.weight:model.layers.7.self_attn.q_norm.weight
+model.layers.7.self_attn.qk_norm.k_norm.weight
+model.layers.7.self_attn.qk_norm.k_norm.weight:model.layers.7.self_attn.qk_norm.k_norm.weight
+model.layers.7.self_attn.qk_norm.q_norm.weight
+model.layers.7.self_attn.qk_norm.q_norm.weight:model.layers.7.self_attn.qk_norm.q_norm.weight
model.layers.7.self_attn.qkv_proj.weight
model.layers.7.self_attn.qkv_proj.weight:model.layers.7.self_attn.qkv_proj.weight
model.layers.7.self_attn.qkv_proj.weight_scale
@@ -692,13 +692,13 @@ model.layers.8.mlp.up_gate_proj.weight:model.layers.8.mlp.gate_up_fused_proj.wei
model.layers.8.mlp.up_gate_proj.weight_scale
model.layers.8.post_attention_layernorm.weight
model.layers.8.post_attention_layernorm.weight:model.layers.8.post_attention_layernorm.weight
-model.layers.8.self_attn.k_norm.weight
-model.layers.8.self_attn.k_norm.weight:model.layers.8.self_attn.k_norm.weight
model.layers.8.self_attn.o_proj.weight
model.layers.8.self_attn.o_proj.weight:model.layers.8.self_attn.o_proj.weight
model.layers.8.self_attn.o_proj.weight_scale
-model.layers.8.self_attn.q_norm.weight
-model.layers.8.self_attn.q_norm.weight:model.layers.8.self_attn.q_norm.weight
+model.layers.8.self_attn.qk_norm.k_norm.weight
+model.layers.8.self_attn.qk_norm.k_norm.weight:model.layers.8.self_attn.qk_norm.k_norm.weight
+model.layers.8.self_attn.qk_norm.q_norm.weight
+model.layers.8.self_attn.qk_norm.q_norm.weight:model.layers.8.self_attn.qk_norm.q_norm.weight
model.layers.8.self_attn.qkv_proj.weight
model.layers.8.self_attn.qkv_proj.weight:model.layers.8.self_attn.qkv_proj.weight
model.layers.8.self_attn.qkv_proj.weight_scale
@@ -712,13 +712,13 @@ model.layers.9.mlp.up_gate_proj.weight:model.layers.9.mlp.gate_up_fused_proj.wei
model.layers.9.mlp.up_gate_proj.weight_scale
model.layers.9.post_attention_layernorm.weight
model.layers.9.post_attention_layernorm.weight:model.layers.9.post_attention_layernorm.weight
-model.layers.9.self_attn.k_norm.weight
-model.layers.9.self_attn.k_norm.weight:model.layers.9.self_attn.k_norm.weight
model.layers.9.self_attn.o_proj.weight
model.layers.9.self_attn.o_proj.weight:model.layers.9.self_attn.o_proj.weight
model.layers.9.self_attn.o_proj.weight_scale
-model.layers.9.self_attn.q_norm.weight
-model.layers.9.self_attn.q_norm.weight:model.layers.9.self_attn.q_norm.weight
+model.layers.9.self_attn.qk_norm.k_norm.weight
+model.layers.9.self_attn.qk_norm.k_norm.weight:model.layers.9.self_attn.qk_norm.k_norm.weight
+model.layers.9.self_attn.qk_norm.q_norm.weight
+model.layers.9.self_attn.qk_norm.q_norm.weight:model.layers.9.self_attn.qk_norm.q_norm.weight
model.layers.9.self_attn.qkv_proj.weight
model.layers.9.self_attn.qkv_proj.weight:model.layers.9.self_attn.qkv_proj.weight
model.layers.9.self_attn.qkv_proj.weight_scale
diff --git a/tests/operators/test_qk_rmsnorm_fused.py b/tests/operators/test_qk_rmsnorm_fused.py
new file mode 100644
index 00000000000..6318778955e
--- /dev/null
+++ b/tests/operators/test_qk_rmsnorm_fused.py
@@ -0,0 +1,133 @@
+"""
+# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+
+import unittest
+
+import numpy as np
+import paddle
+
+from fastdeploy.model_executor.ops.triton_ops import qk_rmsnorm_fused
+from tests.utils import OpPerformanceTester
+
+paddle.set_default_dtype("bfloat16")
+paddle.seed(99)
+
+
+class TestQKNorm(unittest.TestCase):
+    def setUp(self) -> None:
+        # Qwen3-30B-A3B TP1
+        self.hidden_size = 2048
+        self.num_attention_heads = 32
+        self.num_key_value_heads = 4
+        self.num_hidden_layers = 48
+        self.head_dim = 128
+        self.rms_norm_eps = 1e-6
+        self.tp_size = 1
+
+        # # Qwen3-235B-A22B TP4
+        # self.hidden_size = 4096
+        # self.num_attention_heads = 64
+        # self.num_key_value_heads = 4
+        # self.num_hidden_layers = 94
+        # self.head_dim = 128
+        # self.rms_norm_eps = 1e-6
+        # self.tp_size = 4
+
+        # # GLM_4.6 TP4
+        # self.hidden_size = 5120
+        # self.num_attention_heads = 96
+        # self.num_key_value_heads = 8
+        # self.num_hidden_layers = 92
+        # self.head_dim = 128
+        # self.rms_norm_eps = 1e-5
+        # self.tp_size = 4
+
+        self.num_kv_heads_replicas = max(1, self.tp_size // self.num_key_value_heads)
+        self.q_size = self.num_attention_heads * self.head_dim // self.tp_size
+        self.kv_size = self.num_key_value_heads * self.head_dim * self.num_kv_heads_replicas // self.tp_size
+        self.q_norm_weight = paddle.randn([self.head_dim], paddle.bfloat16)
+        self.k_norm_weight = paddle.randn([self.head_dim], paddle.bfloat16)
+
+    def qk_norm_paddle(self, qkv_out):
+        # Unfused reference path: split the packed QKV row, RMS-normalize Q and K
+        # per head with their own weights, leave V untouched, then re-concat.
+        q, k, v = qkv_out.split([self.q_size, self.kv_size, self.kv_size], axis=-1)
+        q_by_head = q.reshape([*q.shape[:-1], q.shape[-1] // self.head_dim, self.head_dim])
+        q_by_head = paddle.incubate.nn.functional.fused_rms_norm(
+            q_by_head, self.q_norm_weight, None, self.rms_norm_eps, begin_norm_axis=2
+        )[0]
+        q = q_by_head.reshape(q.shape)
+
+        k_by_head = k.reshape([*k.shape[:-1], k.shape[-1] // self.head_dim, self.head_dim])
+        k_by_head = paddle.incubate.nn.functional.fused_rms_norm(
+            k_by_head, self.k_norm_weight, None, self.rms_norm_eps, begin_norm_axis=2
+        )[0]
+        k = k_by_head.reshape(k.shape)
+
+        qkv_out = paddle.concat([q, k, v], axis=-1)
+        return qkv_out
+
+    def qk_norm_triton_fused(self, qkv_out):
+        # Fused Triton path: a single kernel covers the split, the per-head
+        # RMSNorm of Q and K, and the pass-through of V, with no intermediates.
+        qkv_out = qk_rmsnorm_fused(
+            qkv_out,
+            self.q_norm_weight,
+            self.k_norm_weight,
+            self.rms_norm_eps,
+            self.q_size,
+            self.kv_size,
+            self.head_dim,
+        )
+        return qkv_out
+
+    def test_qk_norm_paddle_performance(self):
+        tester_paddle = OpPerformanceTester(
+            op_name="qk_norm_paddle",
+            op_fn=self.qk_norm_paddle,
+            num_layers=self.num_hidden_layers,
+        )
+
+        tester_paddle.benchmark(
+            input_size=self.head_dim
+            * (self.num_attention_heads // self.tp_size + 2 * self.num_key_value_heads // self.tp_size),
+            batch_sizes=[1, 8, 64, 128, 1024, 2048, 4096, 8192],
+        )
+
+    def test_qk_norm_fused_performance(self):
+        tester = OpPerformanceTester(
+            op_name="qk_norm_triton_fused",
+            op_fn=self.qk_norm_triton_fused,
+            num_layers=self.num_hidden_layers,
+        )
+        tester.benchmark(
+            input_size=self.head_dim
+            * (self.num_attention_heads // self.tp_size + 2 * self.num_key_value_heads // self.tp_size),
+            batch_sizes=[1, 8, 64, 128, 1024, 2048, 4096, 8192],
+        )
+
+    def test_qk_norm_result(self):
+        x = paddle.randn(
+            [
+                128,
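+                # 128 tokens; the row width below equals q_size + 2 * kv_size,
+                # i.e. one packed QKV row as produced by qkv_proj.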
+                self.head_dim
+                * (self.num_attention_heads // self.tp_size + 2 * self.num_key_value_heads // self.tp_size),
+            ],
+            paddle.bfloat16,
+        )
+        out_paddle = self.qk_norm_paddle(x)
+        out_triton_fused = self.qk_norm_triton_fused(x)
+        np.testing.assert_allclose(out_triton_fused.numpy(), out_paddle.numpy(), rtol=1e-4, atol=1e-4)
+
+
+if __name__ == "__main__":
+    unittest.main()
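Note: the baseline.txt renames above (self_attn.q_norm / self_attn.k_norm becoming self_attn.qk_norm.q_norm / self_attn.qk_norm.k_norm) simply track the new submodule nesting, while the test exercises the fused op itself. The Triton kernel behind qk_rmsnorm_fused is not part of these hunks; for orientation only, here is a minimal sketch of the fusion idea the test verifies. The names, signature, and grid layout (_qk_rmsnorm_sketch, one program per (token, head), in-place update of the packed buffer) are illustrative assumptions, not the shipped implementation.

    # Minimal sketch, for orientation only -- NOT the kernel shipped under
    # fastdeploy/model_executor/ops/triton_ops. Names/signature are assumptions.
    import triton
    import triton.language as tl


    @triton.jit
    def _qk_rmsnorm_sketch(
        qkv_ptr,  # packed [num_tokens, q_size + 2 * kv_size] buffer, updated in place
        q_w_ptr,  # [HEAD_DIM] q_norm weight
        k_w_ptr,  # [HEAD_DIM] k_norm weight
        eps,
        q_size,
        row_stride,
        HEAD_DIM: tl.constexpr,
    ):
        token = tl.program_id(0)
        head = tl.program_id(1)  # axis 1 spans only the Q and K heads, never V
        offs = tl.arange(0, HEAD_DIM)
        base = token * row_stride + head * HEAD_DIM
        x = tl.load(qkv_ptr + base + offs).to(tl.float32)
        # RMSNorm over one head: x / sqrt(mean(x * x) + eps) * weight
        inv_rms = 1.0 / tl.sqrt(tl.sum(x * x, axis=0) / HEAD_DIM + eps)
        qw = tl.load(q_w_ptr + offs).to(tl.float32)
        kw = tl.load(k_w_ptr + offs).to(tl.float32)
        is_q = head * HEAD_DIM < q_size  # Q heads occupy the first q_size columns
        y = x * inv_rms * tl.where(is_q, qw, kw)
        tl.store(qkv_ptr + base + offs, y.to(qkv_ptr.dtype.element_ty))


    def qk_rmsnorm_sketch(qkv, q_norm_weight, k_norm_weight, eps, q_size, kv_size, head_dim):
        # V is the last kv_size columns of each row; the grid simply stops before it.
        num_tokens, row_width = qkv.shape
        grid = (num_tokens, (q_size + kv_size) // head_dim)
        _qk_rmsnorm_sketch[grid](
            qkv, q_norm_weight, k_norm_weight, eps, q_size, row_width, HEAD_DIM=head_dim
        )
        return qkv

The point of the fusion, and what the two OpPerformanceTester benchmarks compare, is that the unfused path materializes split/reshape/concat intermediates while a fused kernel reads and writes each Q/K head exactly once and never touches V. The correctness check (test_qk_norm_result) can be run standalone with `python tests/operators/test_qk_rmsnorm_fused.py`.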