Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 14 additions & 1 deletion fastdeploy/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -209,6 +209,13 @@ def __init__(
pretrained_config, _ = PretrainedConfig.get_config_dict(self.model)
self.pretrained_config = PretrainedConfig.from_dict(pretrained_config)

# Some exported configs (e.g. Qwen3-VL) embed the text model's configuration under a `text_config` key.
if "text_config" in pretrained_config and isinstance(pretrained_config["text_config"], dict):
text_fg = pretrained_config.pop("text_config")
for key, value in text_fg.items():
if not hasattr(self, key):
setattr(self, key, value)

# set attribute from pretrained_config
for key, value in pretrained_config.items():
setattr(self, key, value)
Expand Down Expand Up @@ -325,7 +332,13 @@ def reset_config_value(key, value):
def read_model_config(self):
config_path = os.path.join(self.model, "config.json")
if os.path.exists(config_path):
self.model_config = json.load(open(config_path, "r", encoding="utf-8"))
raw_cfg = json.load(open(config_path, "r", encoding="utf-8"))
if "text_config" in raw_cfg and isinstance(raw_cfg["text_config"], dict):
text_cfg = raw_cfg.pop("text_config")
for k, v in text_cfg.items():
if k not in raw_cfg:
raw_cfg[k] = v
self.model_config = raw_cfg
if "torch_dtype" in self.model_config and "dtype" in self.model_config:
raise ValueError(
"Only one of 'torch_dtype' or 'dtype' should be present in config.json. "
Expand Down
2 changes: 1 addition & 1 deletion fastdeploy/input/qwen_vl_processor/qwen_vl_processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ def __init__(
self.processor = DataProcessor(
model_path=model_name_or_path,
enable_processor_cache=enable_processor_cache,
tokens_per_second=config.vision_config.tokens_per_second,
# tokens_per_second=config.vision_config.tokens_per_second,
tokenizer=self.tokenizer,
**processor_kwargs,
)
Expand Down
4 changes: 2 additions & 2 deletions fastdeploy/model_executor/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,8 +59,8 @@ def auto_models_registry(dir_path, register_path="fastdeploy.model_executor.mode
):
ModelRegistry.register_pretrained_model(attr)

except ImportError:
raise ImportError(f"{module_file=} import error")
except Exception as e:
raise ImportError(f"{module_file=} import error, error message: {e}")


auto_models_registry(os.path.dirname(__file__))
Expand Down
Empty file.
Empty file.
142 changes: 142 additions & 0 deletions fastdeploy/model_executor/models/qwen3_vl/dfnrope/activation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,142 @@
"""
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""

import math
from collections import OrderedDict

import paddle
import paddle.nn.functional as F
from paddle import Tensor, nn


class NewGELUActivation(nn.Layer):
    """Tanh-based GELU approximation (the "Google BERT style" variant)."""

    def forward(self, input: Tensor) -> Tensor:
        """Return ``0.5 * x * (1 + tanh(sqrt(2 / pi) * (x + 0.044715 * x^3)))``."""
        # Polynomial fed to tanh: x + 0.044715 * x^3.
        cubic_term = input + 0.044715 * paddle.pow(input, 3.0)
        tanh_gate = 1.0 + paddle.tanh(math.sqrt(2.0 / math.pi) * cubic_term)
        return 0.5 * input * tanh_gate


class GELUActivation(nn.Layer):
    """Exact (erf-based) GELU.

    Args:
        use_gelu_python: When true, evaluate GELU with the explicit
            ``paddle.erf`` formula implemented in this class; otherwise
            delegate to ``nn.functional.gelu``.
    """

    def __init__(self, use_gelu_python: bool = False):
        super().__init__()
        # Pick the implementation once so forward() is a plain dispatch.
        if use_gelu_python:
            self.act = self._gelu_python
        else:
            self.act = nn.functional.gelu

    def _gelu_python(self, input: Tensor) -> Tensor:
        """Evaluate ``x * 0.5 * (1 + erf(x / sqrt(2)))``."""
        erf_term = paddle.erf(input / math.sqrt(2.0))
        return input * 0.5 * (1.0 + erf_term)

    def forward(self, input: Tensor) -> Tensor:
        """Apply the implementation selected at construction time."""
        return self.act(input)


class FastGELUActivation(nn.Layer):
    """Tanh-based GELU approximation using a hard-coded 0.7978845608 scale."""

    def forward(self, input: Tensor) -> Tensor:
        """Return ``0.5 * x * (1 + tanh(x * 0.7978845608 * (1 + 0.044715 * x * x)))``."""
        # Operand order is kept as in the one-line formula so rounding is unchanged.
        poly = 1.0 + 0.044715 * input * input
        tanh_gate = 1.0 + paddle.tanh(input * 0.7978845608 * poly)
        return 0.5 * input * tanh_gate


class QuickGELUActivation(nn.Layer):
    """Sigmoid-based GELU approximation: ``x * sigmoid(1.702 * x)``."""

    def forward(self, input: Tensor) -> Tensor:
        """Scale the input by a sigmoid gate computed from ``1.702 * input``."""
        gate = F.sigmoid(1.702 * input)
        return input * gate


class ClippedGELUActivation(nn.Layer):
    """GELU whose output is clamped to the closed range ``[min, max]``.

    Described upstream as the clipped GELU used by some quantized models.

    Args:
        min: Lower bound applied to the GELU output.
        max: Upper bound applied to the GELU output.

    Raises:
        ValueError: If ``min`` is greater than ``max``.
    """

    def __init__(self, min: float, max: float):
        # Equal bounds are accepted (the output collapses to a constant), so
        # the message states the actual requirement: min <= max.
        if min > max:
            raise ValueError(f"min should be <= max (got min: {min}, max: {max})")
        super().__init__()
        self.min = min
        self.max = max

    def forward(self, x: Tensor) -> Tensor:
        """Apply GELU to ``x`` and clip the result to ``[self.min, self.max]``."""
        # Use the functional GELU directly instead of the module-level `gelu`
        # alias, so this class does not depend on a name that is only bound
        # after the whole module has finished importing.
        return paddle.clip(F.gelu(x), self.min, self.max)


class SiLUActivation(nn.Layer):
    """Layer wrapper around ``F.silu`` (SiLU, also known as Swish)."""

    def forward(self, input: Tensor) -> Tensor:
        """Return ``F.silu(input)``."""
        activated = F.silu(input)
        return activated


class MishActivation(nn.Layer):
    """Layer wrapper around ``F.mish``."""

    def forward(self, input: Tensor) -> Tensor:
        """Return ``F.mish(input)``."""
        activated = F.mish(input)
        return activated


class LinearActivation(nn.Layer):
    """No-op activation: the input is handed back untouched."""

    def forward(self, input: Tensor) -> Tensor:
        """Return ``input`` itself (no copy, no computation)."""
        passthrough = input
        return passthrough


class ClassInstantier(OrderedDict):
    """Ordered mapping that builds a fresh object on every item lookup.

    Each stored value is either a callable (invoked with no arguments) or a
    ``(callable, kwargs)`` tuple (invoked as ``callable(**kwargs)``).
    """

    def __getitem__(self, key):
        """Look up ``key`` and return a newly constructed object for it."""
        entry = super().__getitem__(key)
        if isinstance(entry, tuple):
            factory, init_kwargs = entry
        else:
            factory, init_kwargs = entry, {}
        return factory(**init_kwargs)


# Maps an activation name to either a layer class (constructed with no
# arguments) or a ``(layer class, constructor kwargs)`` tuple.
ACT2CLS = {
    "gelu": GELUActivation,
    "gelu_10": (ClippedGELUActivation, {"min": -10, "max": 10}),
    "gelu_fast": FastGELUActivation,
    "gelu_new": NewGELUActivation,
    # `approximate` is documented by Paddle as a boolean switch; True selects
    # the tanh approximation. NOTE(review): newer Paddle releases appear to
    # also accept the string "tanh" -- the boolean is used here because it is
    # the documented, portable form.
    "gelu_tanh": (nn.GELU, {"approximate": True}),
    "gelu_python": (GELUActivation, {"use_gelu_python": True}),
    "linear": LinearActivation,
    "mish": MishActivation,
    "quick_gelu": QuickGELUActivation,
    "relu": nn.ReLU,
    "relu6": nn.ReLU6,
    "sigmoid": nn.Sigmoid,
    "silu": SiLUActivation,
    "swish": SiLUActivation,
    "tanh": nn.Tanh,
}
# Indexing ACT2FN returns a newly constructed layer for the requested name.
ACT2FN = ClassInstantier(ACT2CLS)


def get_activation_fn(hidden_act: str):
    """Return a newly built activation layer for the name ``hidden_act``.

    ``"gelu_pytorch_tanh"`` is accepted as an alias for ``"gelu_tanh"``.

    Raises:
        KeyError: If ``hidden_act`` is not a key of ``ACT2FN``.
    """
    # Alias is resolved first; it is not itself a key of the mapping.
    if hidden_act == "gelu_pytorch_tanh":
        return ACT2FN["gelu_tanh"]
    if hidden_act not in ACT2FN:
        raise KeyError(f"function {hidden_act} not found in ACT2FN mapping {list(ACT2FN.keys())}")
    return ACT2FN[hidden_act]


# For backwards compatibility with: from activations import gelu_python
# Each name below is bound once, at import time, to a separate layer instance
# returned by get_activation_fn (every lookup constructs a new object).
# These names are part of the module's importable surface; check for users
# before renaming or removing any of them.
gelu_python = get_activation_fn("gelu_python")
gelu_new = get_activation_fn("gelu_new")
gelu = get_activation_fn("gelu")
gelu_fast = get_activation_fn("gelu_fast")
quick_gelu = get_activation_fn("quick_gelu")
silu = get_activation_fn("silu")
mish = get_activation_fn("mish")
linear_act = get_activation_fn("linear")
64 changes: 64 additions & 0 deletions fastdeploy/model_executor/models/qwen3_vl/dfnrope/configuration.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
"""
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""

from __future__ import annotations

from paddleformers.transformers.configuration_utils import PretrainedConfig

__all__ = [
"Qwen3VisionTransformerConfig",
]


class Qwen3VisionTransformerConfig(PretrainedConfig):
    r"""Configuration for the Qwen3 vision encoder used in Qwen3-VL.

    Every named constructor argument is stored on the instance under the same
    attribute name. ``deepstack_visual_indexes`` is the one exception to a
    plain store: it is copied into a new list, and ``None`` (or any falsy
    value) becomes an empty list. Remaining keyword arguments are forwarded
    to ``PretrainedConfig.__init__``.
    """

    model_type = "qwen3_vision_transformer"

    def __init__(
        self,
        depth: int = 27,
        hidden_size: int = 1152,
        hidden_act: str = "gelu_tanh",
        intermediate_size: int = 4304,
        num_heads: int = 16,
        in_channels: int = 3,
        patch_size: int = 16,
        spatial_merge_size: int = 2,
        temporal_patch_size: int = 2,
        out_hidden_size: int = 3584,
        num_position_embeddings: int = 2304,
        deepstack_visual_indexes: list[int] | None = None,
        initializer_range: float = 0.02,
        tokens_per_second: int = 2,
        **kwargs,
    ) -> None:
        # Base-class initialisation runs first, before any field is assigned.
        super().__init__(**kwargs)

        # Always own a private list so the caller's sequence is never shared.
        if deepstack_visual_indexes:
            visual_indexes = list(deepstack_visual_indexes)
        else:
            visual_indexes = []

        self.depth = depth
        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size
        self.hidden_act = hidden_act
        self.num_heads = num_heads

        self.in_channels = in_channels
        self.patch_size = patch_size
        self.spatial_merge_size = spatial_merge_size
        self.temporal_patch_size = temporal_patch_size

        self.out_hidden_size = out_hidden_size
        self.num_position_embeddings = num_position_embeddings
        self.initializer_range = initializer_range
        self.deepstack_visual_indexes = visual_indexes
        self.tokens_per_second = tokens_per_second
Loading
Loading