From 3cf395c51adc904843c51f199179329fbcf55696 Mon Sep 17 00:00:00 2001 From: Ayakouji Date: Fri, 21 Nov 2025 17:05:16 +0800 Subject: [PATCH 1/3] support v1 loader --- fastdeploy/config.py | 15 +- .../qwen_vl_processor/qwen_vl_processor.py | 2 +- fastdeploy/model_executor/models/__init__.py | 4 +- .../models/qwen3_vl/__init__.py | 0 .../models/qwen3_vl/dfnrope/__init__.py | 0 .../models/qwen3_vl/dfnrope/activation.py | 142 +++++ .../models/qwen3_vl/dfnrope/configuration.py | 64 +++ .../models/qwen3_vl/dfnrope/modeling.py | 508 ++++++++++++++++++ .../models/qwen3_vl/qwen3_vl.py | 484 +++++++++++++++++ fastdeploy/model_executor/utils.py | 4 +- fastdeploy/worker/gpu_model_runner.py | 1 + 11 files changed, 1219 insertions(+), 5 deletions(-) create mode 100644 fastdeploy/model_executor/models/qwen3_vl/__init__.py create mode 100644 fastdeploy/model_executor/models/qwen3_vl/dfnrope/__init__.py create mode 100644 fastdeploy/model_executor/models/qwen3_vl/dfnrope/activation.py create mode 100644 fastdeploy/model_executor/models/qwen3_vl/dfnrope/configuration.py create mode 100644 fastdeploy/model_executor/models/qwen3_vl/dfnrope/modeling.py create mode 100644 fastdeploy/model_executor/models/qwen3_vl/qwen3_vl.py diff --git a/fastdeploy/config.py b/fastdeploy/config.py index 5ec3df934ac..f1708841924 100644 --- a/fastdeploy/config.py +++ b/fastdeploy/config.py @@ -209,6 +209,13 @@ def __init__( pretrained_config, _ = PretrainedConfig.get_config_dict(self.model) self.pretrained_config = PretrainedConfig.from_dict(pretrained_config) + # Some exported configs (e.g. Qwen3-VL) embed the text model's configuration under a `text_config` key. + if "text_config" in pretrained_config and isinstance(pretrained_config["text_config"], dict): + text_fg = pretrained_config.pop("text_config") + for key, value in text_fg.items(): + if not hasattr(self, key): + setattr(self, key, value) + # set attribute from pretrained_config for key, value in pretrained_config.items(): setattr(self, key, value) @@ -325,7 +332,13 @@ def reset_config_value(key, value): def read_model_config(self): config_path = os.path.join(self.model, "config.json") if os.path.exists(config_path): - self.model_config = json.load(open(config_path, "r", encoding="utf-8")) + raw_cfg = json.load(open(config_path, "r", encoding="utf-8")) + if "text_config" in raw_cfg and isinstance(raw_cfg["text_config"], dict): + text_cfg = raw_cfg.pop("text_config") + for k, v in text_cfg.items(): + if k not in raw_cfg: + raw_cfg[k] = v + self.model_config = raw_cfg if "torch_dtype" in self.model_config and "dtype" in self.model_config: raise ValueError( "Only one of 'torch_dtype' or 'dtype' should be present in config.json. 
" diff --git a/fastdeploy/input/qwen_vl_processor/qwen_vl_processor.py b/fastdeploy/input/qwen_vl_processor/qwen_vl_processor.py index 06f43f335ae..fcaf15f6dce 100644 --- a/fastdeploy/input/qwen_vl_processor/qwen_vl_processor.py +++ b/fastdeploy/input/qwen_vl_processor/qwen_vl_processor.py @@ -67,7 +67,7 @@ def __init__( self.processor = DataProcessor( model_path=model_name_or_path, enable_processor_cache=enable_processor_cache, - tokens_per_second=config.vision_config.tokens_per_second, + # tokens_per_second=config.vision_config.tokens_per_second, tokenizer=self.tokenizer, **processor_kwargs, ) diff --git a/fastdeploy/model_executor/models/__init__.py b/fastdeploy/model_executor/models/__init__.py index 9ac761d2f68..51a46a176ea 100644 --- a/fastdeploy/model_executor/models/__init__.py +++ b/fastdeploy/model_executor/models/__init__.py @@ -59,8 +59,8 @@ def auto_models_registry(dir_path, register_path="fastdeploy.model_executor.mode ): ModelRegistry.register_pretrained_model(attr) - except ImportError: - raise ImportError(f"{module_file=} import error") + except Exception as e: + raise ImportError(f"{module_file=} import error, error message: {e}") auto_models_registry(os.path.dirname(__file__)) diff --git a/fastdeploy/model_executor/models/qwen3_vl/__init__.py b/fastdeploy/model_executor/models/qwen3_vl/__init__.py new file mode 100644 index 00000000000..e69de29bb2d diff --git a/fastdeploy/model_executor/models/qwen3_vl/dfnrope/__init__.py b/fastdeploy/model_executor/models/qwen3_vl/dfnrope/__init__.py new file mode 100644 index 00000000000..e69de29bb2d diff --git a/fastdeploy/model_executor/models/qwen3_vl/dfnrope/activation.py b/fastdeploy/model_executor/models/qwen3_vl/dfnrope/activation.py new file mode 100644 index 00000000000..fc4168cc79e --- /dev/null +++ b/fastdeploy/model_executor/models/qwen3_vl/dfnrope/activation.py @@ -0,0 +1,142 @@ +""" +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+""" + +import math +from collections import OrderedDict + +import paddle +import paddle.nn.functional as F +from paddle import Tensor, nn + + +class NewGELUActivation(nn.Layer): + """Google BERT style GELU.""" + + def forward(self, input: Tensor) -> Tensor: + return ( + 0.5 * input * (1.0 + paddle.tanh(math.sqrt(2.0 / math.pi) * (input + 0.044715 * paddle.pow(input, 3.0)))) + ) + + +class GELUActivation(nn.Layer): + """Original GELU implementation.""" + + def __init__(self, use_gelu_python: bool = False): + super().__init__() + self.act = self._gelu_python if use_gelu_python else nn.functional.gelu + + def _gelu_python(self, input: Tensor) -> Tensor: + return input * 0.5 * (1.0 + paddle.erf(input / math.sqrt(2.0))) + + def forward(self, input: Tensor) -> Tensor: + return self.act(input) + + +class FastGELUActivation(nn.Layer): + """Fast GELU approximation.""" + + def forward(self, input: Tensor) -> Tensor: + return 0.5 * input * (1.0 + paddle.tanh(input * 0.7978845608 * (1.0 + 0.044715 * input * input))) + + +class QuickGELUActivation(nn.Layer): + """Quick GELU approximation.""" + + def forward(self, input: Tensor) -> Tensor: + return input * F.sigmoid(1.702 * input) + + +class ClippedGELUActivation(nn.Layer): + """Clipped GELU used by some quantized models.""" + + def __init__(self, min: float, max: float): + if min > max: + raise ValueError(f"min should be < max (got min: {min}, max: {max})") + super().__init__() + self.min = min + self.max = max + + def forward(self, x: Tensor) -> Tensor: + return paddle.clip(gelu(x), self.min, self.max) + + +class SiLUActivation(nn.Layer): + """SiLU / Swish activation.""" + + def forward(self, input: Tensor) -> Tensor: + return F.silu(input) + + +class MishActivation(nn.Layer): + """Mish activation.""" + + def forward(self, input: Tensor) -> Tensor: + return F.mish(input) + + +class LinearActivation(nn.Layer): + """Identity activation.""" + + def forward(self, input: Tensor) -> Tensor: + return input + + +class ClassInstantier(OrderedDict): + """Instantiates layers lazily when accessed.""" + + def __getitem__(self, key): + content = super().__getitem__(key) + cls, kwargs = content if isinstance(content, tuple) else (content, {}) + return cls(**kwargs) + + +ACT2CLS = { + "gelu": GELUActivation, + "gelu_10": (ClippedGELUActivation, {"min": -10, "max": 10}), + "gelu_fast": FastGELUActivation, + "gelu_new": NewGELUActivation, + "gelu_tanh": (nn.GELU, {"approximate": "tanh"}), + "gelu_python": (GELUActivation, {"use_gelu_python": True}), + "linear": LinearActivation, + "mish": MishActivation, + "quick_gelu": QuickGELUActivation, + "relu": nn.ReLU, + "relu6": nn.ReLU6, + "sigmoid": nn.Sigmoid, + "silu": SiLUActivation, + "swish": SiLUActivation, + "tanh": nn.Tanh, +} +ACT2FN = ClassInstantier(ACT2CLS) + + +def get_activation_fn(hidden_act: str): + if hidden_act == "gelu_pytorch_tanh": + return ACT2FN["gelu_tanh"] + if hidden_act in ACT2FN: + return ACT2FN[hidden_act] + raise KeyError(f"function {hidden_act} not found in ACT2FN mapping {list(ACT2FN.keys())}") + + +# For backwards compatibility with: from activations import gelu_python +gelu_python = get_activation_fn("gelu_python") +gelu_new = get_activation_fn("gelu_new") +gelu = get_activation_fn("gelu") +gelu_fast = get_activation_fn("gelu_fast") +quick_gelu = get_activation_fn("quick_gelu") +silu = get_activation_fn("silu") +mish = get_activation_fn("mish") +linear_act = get_activation_fn("linear") diff --git a/fastdeploy/model_executor/models/qwen3_vl/dfnrope/configuration.py 
b/fastdeploy/model_executor/models/qwen3_vl/dfnrope/configuration.py new file mode 100644 index 00000000000..c4928184cb0 --- /dev/null +++ b/fastdeploy/model_executor/models/qwen3_vl/dfnrope/configuration.py @@ -0,0 +1,64 @@ +""" +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" + +from __future__ import annotations + +from paddleformers.transformers.configuration_utils import PretrainedConfig + +__all__ = [ + "Qwen3VisionTransformerConfig", +] + + +class Qwen3VisionTransformerConfig(PretrainedConfig): + r"""Configuration for the Qwen3 vision encoder used in Qwen3-VL.""" + + model_type = "qwen3_vision_transformer" + + def __init__( + self, + depth: int = 27, + hidden_size: int = 1152, + hidden_act: str = "gelu_tanh", + intermediate_size: int = 4304, + num_heads: int = 16, + in_channels: int = 3, + patch_size: int = 16, + spatial_merge_size: int = 2, + temporal_patch_size: int = 2, + out_hidden_size: int = 3584, + num_position_embeddings: int = 2304, + deepstack_visual_indexes: list[int] | None = None, + initializer_range: float = 0.02, + tokens_per_second: int = 2, + **kwargs, + ) -> None: + super().__init__(**kwargs) + + self.depth = depth + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size + self.hidden_act = hidden_act + self.num_heads = num_heads + self.in_channels = in_channels + self.patch_size = patch_size + self.spatial_merge_size = spatial_merge_size + self.temporal_patch_size = temporal_patch_size + self.out_hidden_size = out_hidden_size + self.num_position_embeddings = num_position_embeddings + self.initializer_range = initializer_range + self.deepstack_visual_indexes = list(deepstack_visual_indexes or []) + self.tokens_per_second = tokens_per_second diff --git a/fastdeploy/model_executor/models/qwen3_vl/dfnrope/modeling.py b/fastdeploy/model_executor/models/qwen3_vl/dfnrope/modeling.py new file mode 100644 index 00000000000..77ee3ff47a9 --- /dev/null +++ b/fastdeploy/model_executor/models/qwen3_vl/dfnrope/modeling.py @@ -0,0 +1,508 @@ +""" +Qwen3 vision encoder implementation for FastDeploy. 
+""" + +from __future__ import annotations + +from functools import partial + +import numpy as np +import paddle +from paddle import nn +from paddle.distributed import fleet +from paddle.distributed.fleet.meta_parallel import ( + ColumnParallelLinear, + RowParallelLinear, +) +from paddleformers.transformers.model_utils import PretrainedModel +from paddleformers.utils.log import logger + +from fastdeploy.model_executor.layers.utils import get_tensor +from fastdeploy.model_executor.models.qwen2_5_vl.dfnrope.modeling import ( + VisionFlashAttention2, + VisionRotaryEmbedding, +) +from fastdeploy.model_executor.utils import set_weight_attrs + +from .activation import get_activation_fn +from .configuration import Qwen3VisionTransformerConfig + + +class Qwen3VisionPatchEmbed(nn.Layer): + def __init__( + self, + patch_size: int = 14, + temporal_patch_size: int = 2, + in_channels: int = 3, + hidden_size: int = 1152, + model_format: str = "", + ) -> None: + super().__init__() + self.patch_size = patch_size + self.temporal_patch_size = temporal_patch_size + self.in_channels = in_channels + self.hidden_size = hidden_size + + kernel_size = (temporal_patch_size, patch_size, patch_size) + self.proj = nn.Conv3D( + in_channels, + hidden_size, + kernel_size=kernel_size, + stride=kernel_size, + bias_attr=True, + ) + set_weight_attrs(self.proj.weight, {"weight_need_transpose": model_format == "torch"}) + + def forward(self, hidden_states: paddle.Tensor) -> paddle.Tensor: + # L, C = hidden_states.shape + # hidden_states = hidden_states.view(L, -1, self.temporal_patch_size, self.patch_size, self.patch_size) + # hidden_states = self.proj(hidden_states).view(L, self.hidden_size) + target_dtype = self.proj.weight.dtype + sequence_length = hidden_states.shape[0] + hidden_states = hidden_states.reshape( + [-1, self.in_channels, self.temporal_patch_size, self.patch_size, self.patch_size] + ) + hidden_states = self.proj(paddle.cast(hidden_states, target_dtype)).reshape( + [sequence_length, self.hidden_size] + ) + return hidden_states + + +class Qwen3VisionMLP(nn.Layer): + def __init__( + self, + dim: int, + hidden_dim: int, + hidden_act: str = "silu", + tensor_parallel_degree: int = 1, + model_format: str = "", + ) -> None: + super().__init__() + self.tensor_parallel_degree = tensor_parallel_degree + + if tensor_parallel_degree > 1: + mp_group = fleet.get_hybrid_communicate_group().get_model_parallel_group() + self.linear_fc1 = ColumnParallelLinear( + dim, + hidden_dim, + mp_group=mp_group, + gather_output=False, + has_bias=True, + ) + self.linear_fc2 = RowParallelLinear( + hidden_dim, + dim, + mp_group=mp_group, + input_is_parallel=True, + has_bias=True, + ) + set_weight_attrs(self.linear_fc1.weight, {"output_dim": True}) + set_weight_attrs(self.linear_fc2.weight, {"output_dim": False}) + else: + self.linear_fc1 = nn.Linear(dim, hidden_dim, bias_attr=True) + self.linear_fc2 = nn.Linear(hidden_dim, dim, bias_attr=True) + + set_weight_attrs(self.linear_fc1.weight, {"weight_need_transpose": model_format == "torch"}) + set_weight_attrs(self.linear_fc2.weight, {"weight_need_transpose": model_format == "torch"}) + self.act = get_activation_fn(hidden_act) + + def forward(self, hidden_states: paddle.Tensor) -> paddle.Tensor: + hidden_states = self.linear_fc2(self.act(self.linear_fc1(hidden_states))) + return hidden_states + + +class Qwen3VisionPatchMerger(nn.Layer): + def __init__( + self, + d_model: int, + context_dim: int, + spatial_merge_size: int, + tensor_parallel_degree: int, + use_postshuffle_norm: bool = False, + 
norm_eps: float = 1e-6, + model_format: str = "", + ) -> None: + super().__init__() + self.tensor_parallel_degree = tensor_parallel_degree + self.spatial_merge_size = spatial_merge_size + self.hidden_size = context_dim * (spatial_merge_size**2) + self.use_postshuffle_norm = use_postshuffle_norm + + norm_shape = context_dim if not use_postshuffle_norm else self.hidden_size + self.norm = nn.LayerNorm(norm_shape, epsilon=norm_eps) + + if tensor_parallel_degree > 1: + mp_group = fleet.get_hybrid_communicate_group().get_model_parallel_group() + self.linear_fc1 = ColumnParallelLinear( + self.hidden_size, + self.hidden_size, + mp_group=mp_group, + gather_output=False, + has_bias=True, + ) + self.linear_fc2 = RowParallelLinear( + self.hidden_size, + d_model, + mp_group=mp_group, + input_is_parallel=True, + has_bias=True, + ) + set_weight_attrs(self.linear_fc1.weight, {"output_dim": True}) + set_weight_attrs(self.linear_fc2.weight, {"output_dim": False}) + else: + self.linear_fc1 = nn.Linear(self.hidden_size, self.hidden_size, bias_attr=True) + self.linear_fc2 = nn.Linear(self.hidden_size, d_model, bias_attr=True) + + set_weight_attrs(self.linear_fc1.weight, {"weight_need_transpose": model_format == "torch"}) + set_weight_attrs(self.linear_fc2.weight, {"weight_need_transpose": model_format == "torch"}) + + self.act = nn.GELU() + + def forward(self, hidden_states: paddle.Tensor) -> paddle.Tensor: + # if self.use_postshuffle_norm: + # hidden_states = self.norm(hidden_states.reshape([-1, self.hidden_size])) + # else: + # hidden_states = self.norm(hidden_states).reshape([-1, self.hidden_size]) + + # hidden_states = self.linear_fc1(hidden_states) + # hidden_states = + # hidden_states = + if self.use_postshuffle_norm: + hidden_states = self.norm(hidden_states.view(-1, self.hidden_size)) + else: + hidden_states = self.norm(hidden_states).view(-1, self.hidden_size) + + hidden_states = self.linear_fc2(self.act(self.linear_fc1(hidden_states))) + return hidden_states + + +class Qwen3VisionBlock(nn.Layer): + def __init__( + self, + dim: int, + num_heads: int, + mlp_hidden_dim: int, + hidden_act: str, + tensor_parallel_degree: int, + norm_eps: float, + model_format: str = "", + ) -> None: + super().__init__() + self.norm1 = nn.LayerNorm(dim, epsilon=norm_eps) + self.norm2 = nn.LayerNorm(dim, epsilon=norm_eps) + self.attn = VisionFlashAttention2( + dim=dim, + num_heads=num_heads, + tensor_parallel_degree=tensor_parallel_degree, + model_format=model_format, + ) + self.mlp = Qwen3VisionMLP( + dim=dim, + hidden_dim=mlp_hidden_dim, + hidden_act=hidden_act, + tensor_parallel_degree=tensor_parallel_degree, + model_format=model_format, + ) + + def forward( + self, + hidden_states: paddle.Tensor, + cu_seqlens: paddle.Tensor, + max_seqlen: int, + rotary_pos_emb: paddle.Tensor, + ) -> paddle.Tensor: + hidden_states = hidden_states + self.attn( + self.norm1(hidden_states), + cu_seqlens, + max_seqlen, + rotary_pos_emb, + ) + hidden_states = hidden_states + self.mlp(self.norm2(hidden_states)) + return hidden_states + + +class Qwen3VisionTransformerPretrainedModel(PretrainedModel): + """Qwen3 vision encoder.""" + + config_class = Qwen3VisionTransformerConfig + + def __init__(self, config, prefix_name: str = "") -> None: + vision_config = config.vision_config + super().__init__(vision_config) + self.prefix_name = prefix_name + self.spatial_merge_size = vision_config.spatial_merge_size + self.temporal_patch_size = vision_config.temporal_patch_size + self.deepstack_visual_indexes = vision_config.deepstack_visual_indexes + 
self.num_position_embeddings = vision_config.num_position_embeddings + self.num_grid_per_side = int(max(self.num_position_embeddings, 1) ** 0.5) + model_format = getattr(config, "model_format", "") + self.patch_embed = Qwen3VisionPatchEmbed( + patch_size=vision_config.patch_size, + temporal_patch_size=vision_config.temporal_patch_size, + in_channels=vision_config.in_channels, + hidden_size=vision_config.hidden_size, + model_format=model_format, + ) + self.pos_embed = nn.Embedding(self.num_position_embeddings, vision_config.hidden_size) + + head_dim = vision_config.hidden_size // vision_config.num_heads + self.rotary_pos_emb = VisionRotaryEmbedding(head_dim // 2) + + self.merger = Qwen3VisionPatchMerger( + d_model=vision_config.out_hidden_size, + context_dim=vision_config.hidden_size, + spatial_merge_size=self.spatial_merge_size, + tensor_parallel_degree=config.pretrained_config.tensor_parallel_degree, + use_postshuffle_norm=False, + norm_eps=1e-6, + model_format=model_format, + ) + + self.deepstack_merger_list = nn.LayerList( + [ + Qwen3VisionPatchMerger( + d_model=vision_config.out_hidden_size, + context_dim=vision_config.hidden_size, + spatial_merge_size=self.spatial_merge_size, + tensor_parallel_degree=config.pretrained_config.tensor_parallel_degree, + use_postshuffle_norm=True, + norm_eps=1e-6, + model_format=model_format, + ) + for _ in self.deepstack_visual_indexes + ] + ) + + self.blocks = nn.LayerList( + [ + Qwen3VisionBlock( + dim=vision_config.hidden_size, + num_heads=vision_config.num_heads, + mlp_hidden_dim=vision_config.intermediate_size, + hidden_act=vision_config.hidden_act, + tensor_parallel_degree=config.pretrained_config.tensor_parallel_degree, + norm_eps=1e-6, + model_format=model_format, + ) + for _ in range(vision_config.depth) + ] + ) + + self.out_hidden_size = vision_config.out_hidden_size * (1 + len(self.deepstack_visual_indexes)) + self._set_model_format_attrs(model_format) + + def _set_model_format_attrs(self, model_format): + if model_format is None: + return + for name, param in self.named_parameters(): + if "weight" in name and len(param.shape) == 2: + logger.info(f"[Vision] {name} need to be transposed weight.") + set_weight_attrs(param, {"weight_need_transpose": model_format == "torch"}) + + @property + def dtype(self) -> paddle.dtype: + return self.patch_embed.proj.weight.dtype + + def fast_pos_embed_interpolate(self, grid_thw: list[list[int]]) -> paddle.Tensor: + num_grid_per_side = self.num_grid_per_side + merge_size = self.spatial_merge_size + hidden_dim = self.pos_embed.weight.shape[-1] + outputs = [] + + for t, h, w in grid_thw: + h_idxs = paddle.linspace(0, num_grid_per_side - 1, h, dtype="float32") + w_idxs = paddle.linspace(0, num_grid_per_side - 1, w, dtype="float32") + + h_floor = paddle.floor(h_idxs).astype("int64") + w_floor = paddle.floor(w_idxs).astype("int64") + h_ceil = paddle.clip(h_floor + 1, max=num_grid_per_side - 1) + w_ceil = paddle.clip(w_floor + 1, max=num_grid_per_side - 1) + + dh = h_idxs - paddle.cast(h_floor, "float32") + dw = w_idxs - paddle.cast(w_floor, "float32") + + dh_grid, dw_grid = paddle.meshgrid(dh, dw) + h_floor_grid, w_floor_grid = paddle.meshgrid(h_floor, w_floor) + h_ceil_grid, w_ceil_grid = paddle.meshgrid(h_ceil, w_ceil) + + w11 = dh_grid * dw_grid + w10 = dh_grid - w11 + w01 = dw_grid - w11 + w00 = 1.0 - dh_grid - w01 + + h_grid = paddle.stack([h_floor_grid, h_floor_grid, h_ceil_grid, h_ceil_grid]) + w_grid = paddle.stack([w_floor_grid, w_ceil_grid, w_floor_grid, w_ceil_grid]) + h_grid_idx = h_grid * 
num_grid_per_side + + indices = (h_grid_idx + w_grid).reshape([4, -1]) + weights = paddle.stack([w00, w01, w10, w11], axis=0).reshape([4, -1, 1]).astype(self.dtype) + + embeds = self.pos_embed(indices) + weighted = embeds * weights + combined = weighted.sum(axis=0) + + combined = combined.reshape([h // merge_size, merge_size, w // merge_size, merge_size, hidden_dim]) + combined = combined.transpose([0, 2, 1, 3, 4]).reshape([1, -1, hidden_dim]) + combined = combined.tile([t, 1, 1]).reshape([-1, hidden_dim]) + outputs.append(combined) + + return paddle.concat(outputs, axis=0) + + def rot_pos_emb(self, grid_thw: list[list[int]]) -> paddle.Tensor: + pos_ids = [] + max_grid_size = 0 + for t, h, w in grid_thw: + max_grid_size = max(max_grid_size, h, w) + hpos_ids = paddle.arange(h).unsqueeze(1).tile([1, w]) + hpos_ids = hpos_ids.reshape( + [ + h // self.spatial_merge_size, + self.spatial_merge_size, + w // self.spatial_merge_size, + self.spatial_merge_size, + ] + ) + hpos_ids = hpos_ids.transpose([0, 2, 1, 3]).reshape([-1]) + + wpos_ids = paddle.arange(w).unsqueeze(0).tile([h, 1]) + wpos_ids = wpos_ids.reshape( + [ + h // self.spatial_merge_size, + self.spatial_merge_size, + w // self.spatial_merge_size, + self.spatial_merge_size, + ] + ) + wpos_ids = wpos_ids.transpose([0, 2, 1, 3]).reshape([-1]) + pos_ids.append(paddle.stack([hpos_ids, wpos_ids], axis=-1).tile([t, 1])) + + pos_ids = paddle.concat(pos_ids, axis=0) + rotary_pos_emb_full = self.rotary_pos_emb(max_grid_size) + rotary_pos_emb = rotary_pos_emb_full[pos_ids].reshape([pos_ids.shape[0], -1]) + return rotary_pos_emb + + def _build_cu_seqlens(self, grid_thw: list[list[int]]) -> paddle.Tensor: + grid_tensor = paddle.to_tensor(grid_thw, dtype="int32") + per_item = grid_tensor[:, 1] * grid_tensor[:, 2] + repeats = grid_tensor[:, 0] + per_frame = paddle.repeat_interleave(per_item, repeats) + cu_seqlens = paddle.cumsum(per_frame, axis=0) + cu_seqlens = paddle.concat([paddle.zeros([1], dtype="int32"), cu_seqlens]) + return cu_seqlens + + def compute_attn_mask_seqlen(self, cu_seqlens: paddle.Tensor) -> int: + if cu_seqlens.shape[0] <= 1: + return 0 + diffs = cu_seqlens[1:] - cu_seqlens[:-1] + return diffs.max().item() + + def forward(self, hidden_states: paddle.Tensor, grid_thw: paddle.Tensor | list, num_pad: int = 0) -> paddle.Tensor: + if isinstance(grid_thw, paddle.Tensor): + grid_list = grid_thw.astype("int32").numpy().tolist() + else: + grid_list = grid_thw + + hidden_states = self.patch_embed(hidden_states) + pos_embeds = self.fast_pos_embed_interpolate(grid_list) + hidden_states = hidden_states + paddle.cast(pos_embeds, hidden_states.dtype) + rotary_pos_emb = self.rot_pos_emb(grid_list) + + cu_seqlens = self._build_cu_seqlens(grid_list) + max_seqlen = self.compute_attn_mask_seqlen(cu_seqlens) + + deepstack_features = [] + for layer_id, block in enumerate(self.blocks): + hidden_states = block(hidden_states, cu_seqlens, max_seqlen, rotary_pos_emb) + if layer_id in self.deepstack_visual_indexes: + ds_idx = self.deepstack_visual_indexes.index(layer_id) + deepstack_features.append(self.deepstack_merger_list[ds_idx](hidden_states)) + + hidden_states = self.merger(hidden_states) + if deepstack_features: + hidden_states = paddle.concat([hidden_states] + deepstack_features, axis=1) + return hidden_states + + def extract_feature(self, hidden_states: paddle.Tensor, grid_thw: paddle.Tensor) -> paddle.Tensor: + return self.forward(hidden_states, grid_thw) + + @classmethod + def _get_tensor_parallel_mappings(cls, config, is_split=True): + from 
paddleformers.transformers.conversion_utils import split_or_merge_func + + from fastdeploy.model_executor.models.tp_utils import build_expanded_keys + + fn = split_or_merge_func( + is_split=is_split, + tensor_parallel_degree=config.tensor_parallel_degree, + tensor_parallel_rank=config.tensor_parallel_rank, + ) + + vision_config = config.vision_config + tp_degree = getattr(config, "tensor_parallel_degree", 1) + tp_rank = getattr(config, "tensor_parallel_rank", 0) + + def split_qkv_weight(weight): + hidden = vision_config.hidden_size + head_dim = hidden // vision_config.num_heads + weight = weight.reshape([hidden, 3, vision_config.num_heads, head_dim]) + weight = np.split(weight, tp_degree, axis=2)[tp_rank] + return weight.reshape([hidden, -1]) + + def split_qkv_bias(bias): + head_dim = vision_config.hidden_size // vision_config.num_heads + bias = bias.reshape([3, vision_config.num_heads, head_dim]) + bias = np.split(bias, tp_degree, axis=1)[tp_rank] + return bias.reshape([-1]) + + base_actions = { + "visual.blocks.0.attn.proj.weight": partial(fn, is_column=False), + "visual.blocks.0.mlp.linear_fc1.weight": partial(fn, is_column=True), + "visual.blocks.0.mlp.linear_fc1.bias": partial(fn, is_column=True), + "visual.blocks.0.mlp.linear_fc2.weight": partial(fn, is_column=False), + "visual.blocks.0.attn.qkv.weight": split_qkv_weight, + "visual.blocks.0.attn.qkv.bias": split_qkv_bias, + "visual.merger.linear_fc1.weight": partial(fn, is_column=True), + "visual.merger.linear_fc1.bias": partial(fn, is_column=True), + "visual.merger.linear_fc2.weight": partial(fn, is_column=False), + } + + for idx in range(len(vision_config.deepstack_visual_indexes)): + base_actions[f"visual.deepstack_merger_list.{idx}.linear_fc1.weight"] = partial(fn, is_column=True) + base_actions[f"visual.deepstack_merger_list.{idx}.linear_fc1.bias"] = partial(fn, is_column=True) + base_actions[f"visual.deepstack_merger_list.{idx}.linear_fc2.weight"] = partial(fn, is_column=False) + + final_actions = {} + final_actions.update( + build_expanded_keys( + {k: v for k, v in base_actions.items() if "visual.blocks.0." in k}, + vision_config.depth, + ) + ) + for k, v in base_actions.items(): + if "visual.blocks.0." not in k: + final_actions[k] = v + return final_actions + + def load_state_dict(self, state_dict): + params_dict = dict(self.named_parameters()) + buffers_dict = dict(self.named_buffers()) + + prefix = f"{self.prefix_name}." 
if self.prefix_name else "" + + for name, param in params_dict.items(): + key = prefix + name + if key not in state_dict: + raise ValueError(f"Missing parameter {key} in state_dict") + tensor = get_tensor(state_dict.pop(key)) + if tensor.shape != param.shape: + raise ValueError(f"Shape mismatch for {key}: expected {param.shape}, got {tensor.shape}") + param.copy_(tensor, False) + + for name, buffer in buffers_dict.items(): + key = prefix + name + if key not in state_dict: + continue + tensor = get_tensor(state_dict.pop(key)) + if tensor.shape != buffer.shape: + raise ValueError(f"Shape mismatch for buffer {key}: expected {buffer.shape}, got {tensor.shape}") + buffer.copy_(tensor, False) diff --git a/fastdeploy/model_executor/models/qwen3_vl/qwen3_vl.py b/fastdeploy/model_executor/models/qwen3_vl/qwen3_vl.py new file mode 100644 index 00000000000..f0774e89f4d --- /dev/null +++ b/fastdeploy/model_executor/models/qwen3_vl/qwen3_vl.py @@ -0,0 +1,484 @@ +from __future__ import annotations + +import re +from functools import partial +from typing import Dict, List, Optional, Union + +import numpy as np +import paddle +from paddle import nn +from paddleformers.transformers import PretrainedModel +from paddleformers.transformers.configuration_utils import PretrainedConfig +from paddleformers.utils.log import logger + +from fastdeploy.config import FDConfig +from fastdeploy.model_executor.forward_meta import ForwardMeta +from fastdeploy.model_executor.layers.embeddings import VocabParallelEmbedding +from fastdeploy.model_executor.layers.lm_head import ParallelLMHead +from fastdeploy.model_executor.layers.normalization import RMSNorm +from fastdeploy.model_executor.layers.utils import get_tensor +from fastdeploy.model_executor.models.model_base import ( + ModelCategory, + ModelForCasualLM, + ModelRegistry, +) +from fastdeploy.model_executor.models.qwen3 import Qwen3DecoderLayer +from fastdeploy.model_executor.models.qwen3_vl.dfnrope.modeling import ( + Qwen3VisionTransformerPretrainedModel, +) +from fastdeploy.model_executor.models.tp_utils import TensorSplitMode as tsm +from fastdeploy.model_executor.models.utils import LayerIdPlaceholder as layerid +from fastdeploy.model_executor.models.utils import WeightMeta + + +# @support_graph_optimization +class Qwen3_VLModel(nn.Layer): + """Language backbone for Qwen3-VL.""" + + def __init__(self, fd_config: FDConfig) -> None: + super().__init__() + + self.num_layers = fd_config.model_config.num_hidden_layers + self.image_token_id = fd_config.model_config.image_token_id + self.video_token_id = fd_config.model_config.video_token_id + self._dtype = fd_config.model_config.dtype + fd_config.model_config.pretrained_config.prefix_name = "model" + self.fd_config = fd_config + + self.embed_tokens = VocabParallelEmbedding( + fd_config=fd_config, + num_embeddings=fd_config.model_config.vocab_size, + embedding_dim=fd_config.model_config.hidden_size, + params_dtype=paddle.get_default_dtype, + prefix=f"{fd_config.model_config.pretrained_config.prefix_name}.embed_tokens", + ) + + self.layers = nn.LayerList( + [ + Qwen3DecoderLayer( + fd_config=fd_config, + prefix=f"{fd_config.model_config.pretrained_config.prefix_name}.layers.{i}", + ) + for i in range(self.num_layers) + ] + ) + + self.norm = RMSNorm( + fd_config, + hidden_size=fd_config.model_config.hidden_size, + eps=fd_config.model_config.rms_norm_eps, + prefix=f"{fd_config.model_config.pretrained_config.prefix_name}.norm", + ) + + # model_format = getattr(fd_config.model_config, "model_format", None) + # 
self._set_model_format_attrs(model_format) + + # def _set_model_format_attrs(self, model_format): + # if model_format is None: + # return + # for name, param in self.named_parameters(): + # if "weight" in name and len(param.shape) == 2: + # logger.info(f"[Model] {name} need to be transposed weight.") + # set_weight_attrs(param, {"weight_need_transpose": model_format == "torch"}) + + def load_state_dict(self, state_dict): + self.embed_tokens.load_state_dict(state_dict) + self.norm.load_state_dict(state_dict) + for i in range(self.num_layers): + logger.info(f"Start load layer {i}") + self.layers[i].load_state_dict(state_dict) + + def get_input_embeddings(self, ids_remove_padding: paddle.Tensor) -> paddle.Tensor: + return self.embed_tokens(ids_remove_padding=ids_remove_padding) + + def forward( + self, + input_embeddings: paddle.Tensor, + ids_remove_padding: paddle.Tensor, + image_features: Optional[paddle.Tensor], + forward_meta: ForwardMeta, + deepstack_inputs: Optional[List[paddle.Tensor]] = None, + ) -> paddle.Tensor: + hidden_states = input_embeddings + residual = None + for layer_id, layer in enumerate(self.layers): + hidden_states, residual = layer( + forward_meta, + hidden_states, + residual, + ) + if deepstack_inputs is not None and layer_id < len(deepstack_inputs): + hidden_states = hidden_states + deepstack_inputs[layer_id] + logger.info(f"after block layers {hidden_states}, residual {residual}") + hidden_states, _ = self.norm(hidden_states, residual) + return hidden_states + + +@ModelRegistry.register_model_class( + architecture="Qwen3VLForConditionalGeneration", + module_name="qwen3_vl.qwen3_vl", + category=ModelCategory.MULTIMODAL, + primary_use=ModelCategory.MULTIMODAL, +) +class Qwen3VLForConditionalGeneration(ModelForCasualLM): + def __init__(self, fd_config: FDConfig) -> None: + super().__init__(fd_config) + self.visual = self._init_vision_model(fd_config.model_config) + self.model = Qwen3_VLModel(fd_config=fd_config) + # token ids (convenience aliases) + self.image_token_id = fd_config.model_config.image_token_id + self.video_token_id = fd_config.model_config.video_token_id + self.context_hidden_size = fd_config.model_config.hidden_size + + vision_config = fd_config.model_config.vision_config + self.visual_hidden_size = vision_config.out_hidden_size + self.deepstack_visual_indexes = vision_config.deepstack_visual_indexes + + self.use_deepstack = hasattr(vision_config, "deepstack_visual_indexes") + self.deepstack_num_level = len(vision_config.deepstack_visual_indexes) if self.use_deepstack else 0 + self._deepstack_cache_capacity = fd_config.model_config.max_model_len if self.use_deepstack else 0 + self._deepstack_cache_len = 0 + self.deepstack_input_embeds: Optional[List[paddle.Tensor]] = None + if self.use_deepstack: + dtype = fd_config.model_config.dtype + cache_tokens = fd_config.scheduler_config.max_num_batched_tokens + self.deepstack_input_embeds = [ + paddle.zeros([cache_tokens, self.context_hidden_size], dtype=dtype) + for _ in range(self.deepstack_num_level) + ] + + self.visual_dim = vision_config.out_hidden_size + self.multiscale_dim = self.visual_dim * self.deepstack_num_level + + # self._input_embeddings = paddle.zeros( + # [fd_config.model_config.max_model_len, fd_config.model_config.hidden_size], + # dtype=fd_config.model_config.dtype, + # ) + + self.ori_vocab_size = fd_config.model_config.ori_vocab_size + self.lm_head = ParallelLMHead( + fd_config=fd_config, + embedding_dim=fd_config.model_config.hidden_size, + num_embeddings=fd_config.model_config.vocab_size, 
+ prefix="lm_head", + ) + self.tie_word_embeddings = fd_config.model_config.tie_word_embeddings + self.fd_config = fd_config + + def _init_vision_model(self, model_config) -> nn.Layer: + visual = Qwen3VisionTransformerPretrainedModel(model_config, prefix_name="visual") + visual = paddle.amp.decorate(models=visual, level="O2", dtype="bfloat16") + visual.eval() + return visual + + @classmethod + def name(cls) -> str: + return "Qwen3VLForConditionalGeneration" + + @paddle.no_grad() + def load_weights(self, weights_iterator) -> None: + """Load model parameters from a given weights iterator.""" + + from fastdeploy.model_executor.utils import ( + default_weight_loader, + process_weights_after_loading, + ) + + stacked_params_mapping = [ + # (param_name, weight_name, expert_id, shard_id) + ("qkv_proj", "q_proj", "q"), + ("qkv_proj", "k_proj", "k"), + ("qkv_proj", "v_proj", "v"), + ("up_gate_proj", "gate_proj", "gate"), + ("up_gate_proj", "up_proj", "up"), + ("embed_tokens.embeddings", "embed_tokens", None), + ("lm_head.linear", "lm_head", None), + ("visual", "model.visual", None), + ] + + params_dict = dict(self.named_parameters()) + # params_name model.embed_tokens.embeddings.weight + # weight_name model.language_model.embed_tokens.weight + process_weights_after_loading_fn = process_weights_after_loading(dict(self.named_sublayers()), self.fd_config) + logger.info(f"[Qwen3-VL] params_dict names: {list(params_dict.keys())} ") + for loaded_weight_name, loaded_weight in weights_iterator: + loaded_weight_name = loaded_weight_name.replace(".language_model", "") + for param_name, weight_name, shard_id in stacked_params_mapping: + if weight_name not in loaded_weight_name: + continue + model_param_name = loaded_weight_name.replace(weight_name, param_name) + logger.info( + f"[Qwen3-VL] loaded_weight_name: {loaded_weight_name}, weight_name {weight_name}, param_name {param_name}, model_param_name {model_param_name} 1" + ) + if model_param_name not in params_dict: + logger.info(f"[Qwen3-VL] {model_param_name} not in params_dict1") + continue + param = params_dict[model_param_name] + weight_loader = getattr(param, "weight_loader", default_weight_loader(self.fd_config)) + weight_loader(param, loaded_weight, shard_id) + break + else: + model_param_name = loaded_weight_name + if model_param_name not in params_dict: + logger.info(f"[Qwen3-VL] {model_param_name} not in params_dict2") + continue + param = params_dict[model_param_name] + weight_loader = getattr(param, "weight_loader", default_weight_loader(self.fd_config)) + weight_loader(param, loaded_weight) + + model_sublayer_name = re.sub(r"\.(weight)$", "", model_param_name) + process_weights_after_loading_fn(model_sublayer_name, param) + + if self.tie_word_embeddings: + weight_tensor = get_tensor(self.model.embed_tokens.embeddings.weight) + self.lm_head.linear.weight.set_value(weight_tensor) + + @paddle.no_grad() + def set_state_dict(self, state_dict: Dict[str, Union[np.ndarray, paddle.Tensor]]) -> None: + self.model.load_state_dict(state_dict) + self.visual.load_state_dict(state_dict) + if self.tie_word_embeddings: + self.lm_head.load_state_dict({self.lm_head.weight_key: self.model.embed_tokens.embeddings.weight}) + else: + self.lm_head.load_state_dict(state_dict) + + def compute_logits(self, hidden_states: paddle.Tensor) -> paddle.Tensor: + logits = self.lm_head(hidden_states) + logits = paddle.cast(logits, paddle.float32) + logits[:, self.ori_vocab_size :] = -float("inf") + return logits + + def _set_deepstack_input_embeds(self, deepstack_input_embeds: 
paddle.Tensor) -> None: + num_tokens = deepstack_input_embeds.shape[1] + if num_tokens > self.deepstack_input_embeds[0].shape[0]: + self.deepstack_input_embeds = [ + paddle.zeros( + num_tokens, + self.context_hidden_size, + dtype=self.deepstack_input_embeds[0].dtype, + # device=self.deepstack_input_embeds[0].place, + ) + for _ in range(self.deepstack_num_level) + ] + for idx in range(self.deepstack_num_level): + self.deepstack_input_embeds[idx][:num_tokens].copy_(deepstack_input_embeds[idx], False) + + def _get_deepstack_input_embeds(self, num_tokens: int) -> Optional[List[paddle.Tensor]]: + return [tensor[:num_tokens] for tensor in self.deepstack_input_embeds] + + def _clear_deepstack_input_embeds(self, num_token: int) -> None: + if num_token > 0: + for idx in range(self.deepstack_num_level): + self.deepstack_input_embeds[idx][:num_token].zero_() + + def _compute_deepstack_embeds_v0( + self, + input_embeddings: paddle.Tensor, + image_features: paddle.Tensor, + image_mask: paddle.Tensor, + ): + """For only image inputs case""" + ( + mm_embeddings_main, + mm_embeddings_multiscale, + ) = paddle.spilit( + image_features, + [self.visual_dim, self.multiscale_dim], + dim=-1, + ) + + deepstack_input_embeds = input_embeddings.new_zeros( + size=[input_embeddings.shape[0], self.deepstack_num_level * input_embeddings.shape[1]], + ) + + deepstack_input_embeds[image_mask] = mm_embeddings_multiscale + deepstack_input_embeds = deepstack_input_embeds.view( + input_embeddings.shape[0], self.deepstack_num_level, self.visual_dim + ) + deepstack_input_embeds = deepstack_input_embeds.transpose([1, 0, 2]) + + return deepstack_input_embeds, mm_embeddings_main + + def _compute_deepstack_embeds(self): + pass + + def get_input_embeddings( + self, + ids_remove_padding: paddle.Tensor, + image_features: Optional[paddle.Tensor] = None, + ) -> paddle.Tensor: + input_embeddings = self.model.get_input_embeddings(ids_remove_padding=ids_remove_padding) + + if image_features is None: + return input_embeddings + + image_mask = ids_remove_padding == self.image_token_id + + deepstack_input_embeds = None + + if self.use_deepstack: + ( + deepstack_input_embeds, + mm_embeddings, + ) = self._compute_deepstack_embeds_v0( + input_embeddings, + image_features, + image_mask, + ) + + input_embeddings[image_mask] = mm_embeddings + + if deepstack_input_embeds is not None: + self._set_deepstack_input_embeds(deepstack_input_embeds) + + return input_embeddings + + def forward( + self, + ids_remove_padding: paddle.Tensor, + image_features: Optional[paddle.Tensor], + forward_meta: ForwardMeta, + ) -> paddle.Tensor: + logger.info(f"ids_remove_padding: {ids_remove_padding}") + input_embeddings = self.get_input_embeddings(ids_remove_padding, image_features) + logger.info(f"input_embeddings: {input_embeddings}") + # self._input_embeddings.copy_(input_embeddings, False) + deepstack_inputs = None + if self.use_deepstack: + deepstack_inputs = self._get_deepstack_input_embeds(input_embeddings.shape[0]) + + hidden_states = self.model( + input_embeddings=input_embeddings, + ids_remove_padding=ids_remove_padding, + image_features=image_features, + forward_meta=forward_meta, + deepstack_inputs=deepstack_inputs, + ) + + logger.info(f"hidden_states: {hidden_states}") + if self.use_deepstack: + self._clear_deepstack_input_embeds(input_embeddings.shape[0]) + + return hidden_states + + +class Qwen3_VLPretrainedModel(PretrainedModel): + """Utilities for tensor-parallel weight splitting.""" + + config_class = FDConfig + + def _init_weight(self, layer): + return 
None + + @classmethod + def arch_name(cls) -> str: + return "Qwen3VLForConditionalGeneration" + + weight_infos = [ + WeightMeta(f".layers.{{{layerid.LAYER_ID}}}.self_attn.q_proj.weight", True), + WeightMeta(f".layers.{{{layerid.LAYER_ID}}}.self_attn.q_proj.bias", True), + WeightMeta(f".layers.{{{layerid.LAYER_ID}}}.self_attn.k_proj.weight", True), + WeightMeta(f".layers.{{{layerid.LAYER_ID}}}.self_attn.k_proj.bias", True), + WeightMeta(f".layers.{{{layerid.LAYER_ID}}}.self_attn.v_proj.weight", True), + WeightMeta(f".layers.{{{layerid.LAYER_ID}}}.self_attn.v_proj.bias", True), + WeightMeta(f".layers.{{{layerid.LAYER_ID}}}.self_attn.o_proj.weight", False), + WeightMeta(f".layers.{{{layerid.LAYER_ID}}}.mlp.gate_proj.weight", True), + WeightMeta(f".layers.{{{layerid.LAYER_ID}}}.mlp.up_proj.weight", True), + WeightMeta(f".layers.{{{layerid.LAYER_ID}}}.mlp.down_proj.weight", False), + WeightMeta(".embed_tokens.weight", False), + WeightMeta("lm_head.weight", True), + ] + + weight_vision = [ + WeightMeta(f"visual.blocks.{{{layerid.LAYER_ID}}}.attn.proj.weight", False), + WeightMeta(f"visual.blocks.{{{layerid.LAYER_ID}}}.mlp.linear_fc1.weight", True), + WeightMeta(f"visual.blocks.{{{layerid.LAYER_ID}}}.mlp.linear_fc1.bias", True), + WeightMeta(f"visual.blocks.{{{layerid.LAYER_ID}}}.mlp.linear_fc2.weight", False), + WeightMeta( + f"visual.blocks.{{{layerid.LAYER_ID}}}.attn.qkv.weight", + True, + tsm.GQA, + ), + WeightMeta( + f"visual.blocks.{{{layerid.LAYER_ID}}}.attn.qkv.bias", + True, + tsm.GQA, + ), + WeightMeta("visual.merger.linear_fc1.weight", True), + WeightMeta("visual.merger.linear_fc1.bias", True), + WeightMeta("visual.merger.linear_fc2.weight", False), + ] + + @classmethod + def _get_tensor_parallel_mappings(cls, config: PretrainedConfig, is_split: bool = True): + logger.info("qwen3_vl inference model _get_tensor_parallel_mappings") + from fastdeploy.model_executor.models.tp_utils import ( + build_expanded_keys, + has_prefix, + split_or_merge_func_v1, + ) + + fn = split_or_merge_func_v1( + is_split=is_split, + tensor_parallel_degree=config.tensor_parallel_degree, + tensor_parallel_rank=config.tensor_parallel_rank, + num_attention_heads=config.num_attention_heads, + num_key_value_heads=config.num_key_value_heads, + head_dim=config.head_dim, + ) + + vision_num_heads = config.vision_config.get("num_heads") + vision_hidden = config.vision_config.get("hidden_size") + vision_head_dim = vision_hidden // vision_num_heads + vision_fn = split_or_merge_func_v1( + is_split=is_split, + tensor_parallel_degree=config.tensor_parallel_degree, + tensor_parallel_rank=config.tensor_parallel_rank, + num_attention_heads=vision_num_heads, + num_key_value_heads=vision_num_heads, + head_dim=vision_head_dim, + ) + + def get_tensor_parallel_split_mappings(num_layers: int, prefix_name: str): + base_actions = {} + for weight_name, is_column, extra in cls.weight_infos: + params = {"is_column": is_column, **({extra.value: True} if extra else {})} + + if "lm_head.weight" in weight_name or weight_name.startswith("."): + key = weight_name + elif not has_prefix(prefix_name, weight_name): + key = f"{prefix_name}{weight_name}" + else: + key = weight_name + base_actions[key] = partial(fn, **params) + + return build_expanded_keys(base_actions, num_layers) + + def get_vision_parallel_split_mappings(num_layers: int, deepstack_count: int): + base_actions = {} + for weight_name, is_column, extra in cls.weight_vision: + params = {"is_column": is_column, **({extra.value: True} if extra else {})} + base_actions[weight_name] = 
partial(vision_fn, **params) + + actions = build_expanded_keys( + {k: v for k, v in base_actions.items() if "visual.blocks." in k}, + num_layers, + ) + + for key, action in base_actions.items(): + if "visual.blocks." not in key: + actions[key] = action + + for idx in range(deepstack_count): + actions[f"visual.deepstack_merger_list.{idx}.linear_fc1.weight"] = partial(vision_fn, is_column=True) + actions[f"visual.deepstack_merger_list.{idx}.linear_fc1.bias"] = partial(vision_fn, is_column=True) + actions[f"visual.deepstack_merger_list.{idx}.linear_fc2.weight"] = partial(vision_fn, is_column=False) + return actions + + mappings = get_tensor_parallel_split_mappings(config.num_hidden_layers, config.prefix_name) + vision_depth = config.vision_config.get("depth", 0) + deepstack_count = len(config.vision_config.get("deepstack_visual_indexes", [])) + vision_mappings = get_vision_parallel_split_mappings(vision_depth, deepstack_count) + + mappings.update(vision_mappings) + return mappings diff --git a/fastdeploy/model_executor/utils.py b/fastdeploy/model_executor/utils.py index e0434106165..523c685a7fc 100644 --- a/fastdeploy/model_executor/utils.py +++ b/fastdeploy/model_executor/utils.py @@ -264,7 +264,8 @@ def fn(param, loaded_weight, shard_id: Optional[Union[int, str]] = None): output_dim = getattr(param, "output_dim", None) weight_need_transpose = getattr(param, "weight_need_transpose", False) - if weight_need_transpose: + if weight_need_transpose and loaded_weight.ndim == 2: + logger.info(f"[Torch] {param.name}.weight need transpose, from {loaded_weight.shape} to {param.shape}") loaded_weight = loaded_weight.transpose([1, 0]) # Tensor parallelism splits the weight along the output_dim if output_dim is not None and fd_config is not None and fd_config.parallel_config.tensor_parallel_size > 1: @@ -282,6 +283,7 @@ def fn(param, loaded_weight, shard_id: Optional[Union[int, str]] = None): loaded_weight = fd_cast(loaded_weight, param) if param.shape != loaded_weight.shape: # for e_score_correction_bias + logger.info(f"[Torch] {param.name} weight reshaped") loaded_weight = loaded_weight.reshape(param.shape) assert param.shape == loaded_weight.shape, ( f" Attempted to load weight ({loaded_weight.shape}) " f"into parameter ({param.shape})" diff --git a/fastdeploy/worker/gpu_model_runner.py b/fastdeploy/worker/gpu_model_runner.py index 5b30babffd4..a4b9ea30384 100644 --- a/fastdeploy/worker/gpu_model_runner.py +++ b/fastdeploy/worker/gpu_model_runner.py @@ -1804,6 +1804,7 @@ def _dummy_run( self._dummy_pooler_run(hidden_states, model_output) break else: + logger.info(f"model_output shape: {model_output.shape}") hidden_states = rebuild_padding( model_output, self.share_inputs["cu_seqlens_q"], From 9083b4c4110e0b350b90d5b47edd852acbcd4cce Mon Sep 17 00:00:00 2001 From: Ayakouji Date: Fri, 5 Dec 2025 15:20:26 +0800 Subject: [PATCH 2/3] remove useless code --- fastdeploy/model_executor/utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/fastdeploy/model_executor/utils.py b/fastdeploy/model_executor/utils.py index 523c685a7fc..bbbd24d62de 100644 --- a/fastdeploy/model_executor/utils.py +++ b/fastdeploy/model_executor/utils.py @@ -264,7 +264,7 @@ def fn(param, loaded_weight, shard_id: Optional[Union[int, str]] = None): output_dim = getattr(param, "output_dim", None) weight_need_transpose = getattr(param, "weight_need_transpose", False) - if weight_need_transpose and loaded_weight.ndim == 2: + if weight_need_transpose: logger.info(f"[Torch] {param.name}.weight need transpose, from 
{loaded_weight.shape} to {param.shape}")
             loaded_weight = loaded_weight.transpose([1, 0])
         # Tensor parallelism splits the weight along the output_dim

From 87af261b37e4e0537f4183b38b2f668cfcabcb2b Mon Sep 17 00:00:00 2001
From: Ayakouji
Date: Fri, 5 Dec 2025 17:44:42 +0800
Subject: [PATCH 3/3] remove useless

---
 fastdeploy/model_executor/models/qwen3_vl/dfnrope/modeling.py | 4 ----
 1 file changed, 4 deletions(-)

diff --git a/fastdeploy/model_executor/models/qwen3_vl/dfnrope/modeling.py b/fastdeploy/model_executor/models/qwen3_vl/dfnrope/modeling.py
index 77ee3ff47a9..f1a31425ace 100644
--- a/fastdeploy/model_executor/models/qwen3_vl/dfnrope/modeling.py
+++ b/fastdeploy/model_executor/models/qwen3_vl/dfnrope/modeling.py
@@ -1,7 +1,3 @@
-"""
-Qwen3 vision encoder implementation for FastDeploy.
-"""
-
 from __future__ import annotations
 
 from functools import partial