From f07c584bfa4d6a6b3723f4d9dad93c2d369d139e Mon Sep 17 00:00:00 2001 From: ming1753 Date: Tue, 28 Oct 2025 18:51:02 +0800 Subject: [PATCH 1/6] fix paddleocr prefix cache bug --- .../paddleocr_vl_processor.py | 38 +- .../input/paddleocr_vl_processor/process.py | 358 ++++++++++++------ .../paddleocr_vl_processor/process_video.py | 82 ++++ 3 files changed, 348 insertions(+), 130 deletions(-) create mode 100644 fastdeploy/input/paddleocr_vl_processor/process_video.py diff --git a/fastdeploy/input/paddleocr_vl_processor/paddleocr_vl_processor.py b/fastdeploy/input/paddleocr_vl_processor/paddleocr_vl_processor.py index d9da0ef3828..2e9e680c0b5 100644 --- a/fastdeploy/input/paddleocr_vl_processor/paddleocr_vl_processor.py +++ b/fastdeploy/input/paddleocr_vl_processor/paddleocr_vl_processor.py @@ -47,6 +47,7 @@ def __init__( mm_processor_kwargs=None, reasoning_parser_obj=None, tool_parser_obj=None, + enable_processor_cache=False, ): """ Initialize PaddleOCRVLProcessor instance. @@ -65,6 +66,7 @@ def __init__( processor_kwargs = self._parse_processor_kwargs(mm_processor_kwargs) self.processor = DataProcessor( model_path=model_name_or_path, + enable_processor_cache=enable_processor_cache, tokens_per_second=config.vision_config.tokens_per_second, tokenizer=self.tokenizer, **processor_kwargs, @@ -252,7 +254,7 @@ def process_request_dict(self, request, max_model_len=None): return request - def append_generated_tokens(self, outputs, generated_token_ids): + def append_generated_tokens(self, multimodal_inputs, generated_token_ids): """ Append generated tokens to existing outputs. @@ -260,19 +262,13 @@ def append_generated_tokens(self, outputs, generated_token_ids): outputs: Current model outputs generated_token_ids: Generated tokens to append """ - out = {"input_ids": [], "token_type_ids": [], "position_ids": [], "cur_position": outputs["cur_position"]} - self.processor._add_text(generated_token_ids, out) + num_tokens = len(generated_token_ids) + multimodal_inputs["input_ids"].extend(generated_token_ids) + multimodal_inputs["token_type_ids"].extend([0] * num_tokens) - outputs["input_ids"] = np.concatenate( - [outputs["input_ids"], np.array(out["input_ids"], dtype=np.int64)], axis=0 - ) - outputs["token_type_ids"] = np.concatenate( - [outputs["token_type_ids"], np.array(out["token_type_ids"], dtype=np.int64)], axis=0 - ) - outputs["position_ids"] = np.concatenate( - [outputs["position_ids"], out["position_ids"][0]], axis=1, dtype=np.int64 - ) - outputs["cur_position"] = out["cur_position"] + pos_ids = self.processor._compute_text_positions(multimodal_inputs["cur_position"], num_tokens) + multimodal_inputs["position_ids"].append(pos_ids) + multimodal_inputs["cur_position"] += num_tokens def pack_outputs(self, outputs): """ @@ -284,6 +280,22 @@ def pack_outputs(self, outputs): Returns: dict: Packed output dictionary with all required fields """ + if not outputs["images"]: + outputs["images"] = None # No images case + outputs["grid_thw"] = None # No spatial dimensions + outputs["image_type_ids"] = None # No type IDs + else: + outputs["images"] = np.vstack(outputs["images"]) # Stack image features vertically + outputs["grid_thw"] = np.vstack(outputs["grid_thw"]) # Stack spatial dimensions + outputs["image_type_ids"] = np.array(outputs["image_type_ids"]) # Convert to numpy array + + # Convert all outputs to numpy arrays with appropriate types + outputs["input_ids"] = np.array(outputs["input_ids"], dtype=np.int64) # Token IDs as int64 + outputs["token_type_ids"] = np.array(outputs["token_type_ids"], 
dtype=np.int64) # Type IDs as int64 + outputs["position_ids"] = np.concatenate( + outputs["position_ids"], axis=1, dtype=np.int64 + ) # Concatenate position ID + outputs["image_patch_id"] = self.processor.image_token_id outputs["video_patch_id"] = self.processor.video_token_id outputs["position_ids"] = outputs["position_ids"].transpose(1, 0) diff --git a/fastdeploy/input/paddleocr_vl_processor/process.py b/fastdeploy/input/paddleocr_vl_processor/process.py index 436721b52c6..67acbad4488 100644 --- a/fastdeploy/input/paddleocr_vl_processor/process.py +++ b/fastdeploy/input/paddleocr_vl_processor/process.py @@ -15,16 +15,23 @@ # limitations under the License. """ -from typing import Any, Dict, List, Union +import pickle +from typing import Any, Dict, List, Optional, Tuple, Union import numpy as np +import zmq from paddleformers.transformers import AutoTokenizer +from PIL import Image +from fastdeploy.engine.request import ImagePosition from fastdeploy.entrypoints.chat_utils import parse_chat_messages +from fastdeploy.input.ernie4_5_vl_processor import read_video_decord from fastdeploy.input.utils import IDS_TYPE_FLAG +from fastdeploy.multimodal.hasher import MultimodalHasher from fastdeploy.utils import data_processor_logger from .image_processor import ImageProcessor +from .process_video import sample_frames class DataProcessor: @@ -48,8 +55,11 @@ class DataProcessor: def __init__( self, model_path: str, + enable_processor_cache: bool = False, video_min_frames: int = 4, video_max_frames: int = 768, + video_target_frames: int = -1, + video_fps: int = -1, tokens_per_second: int = 2, tokenizer=None, **kwargs, @@ -66,6 +76,8 @@ def __init__( """ self.min_frames = video_min_frames self.max_frames = video_max_frames + self.target_frames = video_target_frames + self.fps = video_fps # Initialize tokenizer with left padding and fast tokenizer if tokenizer is None: @@ -74,13 +86,13 @@ def __init__( else: self.tokenizer = tokenizer self.image_processor = ImageProcessor.from_pretrained(model_path) # Initialize image processor + self.enable_processor_cache = enable_processor_cache # Convolution sizes for patch aggregation self.spatial_conv_size = self.image_processor.merge_size self.temporal_conv_size = self.image_processor.temporal_patch_size # Special tokens and IDs - self.image_token = "<|IMAGE_PLACEHOLDER|>" self.video_token = "<|video_pad|>" @@ -100,41 +112,7 @@ def __init__( "assistant": "Assistant: ", } - def _pack_outputs(self, outputs): - """ - Pack and convert all output data into numpy arrays with appropriate types. 
- - Args: - outputs (dict): Dictionary containing model outputs with keys: - - images: List of visual features - - grid_thw: List of spatial dimensions - - image_type_ids: List of content type indicators - - input_ids: List of token IDs - - token_type_ids: List of type identifiers - - position_ids: List of position embeddings - - Returns: - dict: Processed outputs with all values converted to numpy arrays - """ - # Process visual outputs - stack if exists or set to None if empty - if not outputs["images"]: - outputs["images"] = None # No images case - outputs["grid_thw"] = None # No spatial dimensions - outputs["image_type_ids"] = None # No type IDs - else: - outputs["images"] = np.vstack(outputs["images"]) # Stack image features vertically - outputs["grid_thw"] = np.vstack(outputs["grid_thw"]) # Stack spatial dimensions - outputs["image_type_ids"] = np.array(outputs["image_type_ids"]) # Convert to numpy array - - # Convert all outputs to numpy arrays with appropriate types - outputs["input_ids"] = np.array(outputs["input_ids"], dtype=np.int64) # Token IDs as int64 - outputs["token_type_ids"] = np.array(outputs["token_type_ids"], dtype=np.int64) # Type IDs as int64 - outputs["position_ids"] = np.concatenate( - outputs["position_ids"], axis=1, dtype=np.int64 - ) # Concatenate position IDs - return outputs - - def text2ids(self, text, images=None, videos=None): + def text2ids(self, text, images=None, videos=None, image_uuid=None, video_uuid=None): """ Convert text with image/video placeholders into model inputs. @@ -142,6 +120,8 @@ def text2ids(self, text, images=None, videos=None): text: Input text with <|image@placeholder|> and <|video@placeholder|> markers images: List of PIL Images corresponding to image placeholders videos: List of video data corresponding to video placeholders + image_uuid: List of unique identifiers for each image, used for caching or hashing. + video_uuid: List of unique identifiers for each video, used for caching or hashing. 
Returns: Dict containing: @@ -162,11 +142,14 @@ def text2ids(self, text, images=None, videos=None): "image_type_ids": [], "labels": [], "cur_position": 0, - "pic_cnt": 0, "video_cnt": 0, + "fps": [], + "mm_positions": [], + "mm_hashes": [], "vit_seqlen": [], "vit_position_ids": [], } + # Define placeholders and their lengths IMAGE_PLACEHOLDER = self.image_token VIDEO_PLACEHOLDER = self.video_token @@ -188,23 +171,30 @@ def text2ids(self, text, images=None, videos=None): break if ed == image_pos: - outputs["pic_cnt"] += 1 - self._add_image(images[image_idx], outputs) + image = images[image_idx] + uuid = image_uuid[image_idx] if image_uuid else None + if not isinstance(image, tuple): + self._add_image(image, outputs, uuid) + else: + self._add_processed_image(image, outputs, uuid) image_idx += 1 st = ed + IMAGE_PLACEHOLDER_LEN else: item = videos[video_idx] - if isinstance(item, dict): - frames, meta = self._load_and_process_video(item["video"], item) + uuid = video_uuid[video_idx] if video_uuid else None + if not isinstance(item, tuple): + if isinstance(item, dict): + frames, meta = self._load_and_process_video(item["video"], item) + else: + frames, meta = self._load_and_process_video(item, {}) + self._add_video(frames, meta, outputs, uuid) else: - frames, meta = self._load_and_process_video(item, {}) - - outputs["video_cnt"] += 1 - self._add_video(frames, meta, outputs) + # cached frames are already processed + self._add_processed_video(item, outputs, uuid) video_idx += 1 st = ed + VIDEO_PLACEHOLDER_LEN - return self._pack_outputs(outputs) + return outputs def request2ids( self, request: Dict[str, Any], tgts: List[str] = None @@ -222,76 +212,84 @@ def request2ids( Dict with same structure as text2ids() output """ - outputs = { - "input_ids": [], - "token_type_ids": [], - "position_ids": [], - "images": [], - "grid_thw": [], - "image_type_ids": [], - "labels": [], - "cur_position": 0, - "pic_cnt": 0, - "video_cnt": 0, - "vit_seqlen": [], - "vit_position_ids": [], - } - # Parse and validate chat messages messages = parse_chat_messages(request.get("messages")) - image_message_list = [] # Store visual content messages - + mm_items = [] for msg in messages: role = msg.get("role") assert role in self.role_prefixes, f"Unsupported role: {role}" # Normalize content to list format - content_items = msg.get("content") - if not isinstance(content_items, list): - content_items = [content_items] - + content = msg.get("content") + if not isinstance(content, list): + content = [content] # Collect all visual content items - for item in content_items: - if isinstance(item, dict) and item.get("type") in ["image", "video"]: - image_message_list.append(item) - - raw_messages = request["messages"] - request["messages"] = messages - - prompt_token_ids = self.apply_chat_template(request) - if len(prompt_token_ids) == 0: - raise ValueError("Invalid input: prompt_token_ids must be a non-empty sequence of token IDs") - request["messages"] = raw_messages - - vision_start_index = 0 - vision_message_index = 0 - for i in range(len(prompt_token_ids)): - if prompt_token_ids[i] == self.vision_start_id: - self._add_text(prompt_token_ids[vision_start_index : i + 1], outputs) - - vision_start_index = i + 1 - image_message = image_message_list[vision_message_index] - - if image_message["type"] == "image": - img = image_message.get("image") - if img is None: - continue - outputs["pic_cnt"] += 1 - self._add_image(img, outputs) - - elif image_message["type"] == "video": - video_bytes = image_message.get("video") - if video_bytes 
is None: - continue - frames, meta = self._load_and_process_video(video_bytes, image_message) + for item in content: + if item.get("type") in ["image", "video"]: + mm_items.append(item) + + missing_hashes, missing_idx = [], [] + for idx, item in enumerate(mm_items): + if not item.get("data"): + # raw data not provided, should be retrieved from processor cache + missing_hashes.append(item.get("uuid")) + missing_idx.append(idx) + + if len(missing_hashes) > 0 and not self.enable_processor_cache: + raise ValueError("Missing items cannot be retrieved without processor cache.") + + if self.enable_processor_cache: + context = zmq.Context() + dealer = context.socket(zmq.DEALER) + dealer.connect("ipc:///dev/shm/processor_cache.ipc") + + missing_items = self.get_processor_cache(dealer, missing_hashes) + for idx in range(len(missing_items)): + if not missing_items[idx]: + raise ValueError(f"Missing item {idx} not found in processor cache") + mm_items[missing_idx[idx]]["data"] = missing_items[idx] + + images, videos = [], [] + image_uuid, video_uuid = [], [] + for item in mm_items: + if item.get("type") == "image": + images.append(item["data"]) + image_uuid.append(item["uuid"]) + elif item.get("type") == "video": + videos.append(item["data"]) + video_uuid.append(item["uuid"]) + else: + raise ValueError(f"Unsupported multimodal type: {item.get('type')}") - outputs["video_cnt"] += 1 - self._add_video(frames, meta, outputs) + if self.tokenizer.chat_template is None: + raise ValueError("This model does not support chat template.") - vision_message_index += 1 + chat_template_kwargs = request.get("chat_template_kwargs", {}) + prompt = self.tokenizer.apply_chat_template( + messages, + tokenize=False, + add_generation_prompt=request.get("add_generation_prompt", True), + **chat_template_kwargs, + ) + request["prompt_tokens"] = prompt + + outputs = self.text2ids(prompt, images, videos, image_uuid, video_uuid) + + if self.enable_processor_cache: + missing_idx = set(missing_idx) + hashes_to_cache, items_to_cache = [], [] + for idx in range(len(mm_items)): + if idx in missing_idx: + continue + meta = {} + t, h, w = outputs["grid_thw"][idx] + meta["thw"] = (t, h, w) + meta["fps"] = outputs["fps"][idx] + hashes_to_cache.append(outputs["mm_hashes"][idx]) + items_to_cache.append((outputs["images"][idx], meta)) + self.update_processor_cache(dealer, hashes_to_cache, items_to_cache) - self._add_text(prompt_token_ids[vision_start_index:], outputs) - return self._pack_outputs(outputs) + return outputs def _add_text(self, tokens, outputs: Dict) -> None: """ @@ -316,9 +314,9 @@ def _add_text(self, tokens, outputs: Dict) -> None: outputs["input_ids"].extend(tokens) outputs["token_type_ids"].extend([IDS_TYPE_FLAG["text"]] * num_tokens) - position_ids = self._compute_text_positions(outputs["cur_position"], num_tokens) - outputs["position_ids"].append(position_ids) - outputs["cur_position"] = position_ids.max() + 1 + pos_ids = self._compute_text_positions(outputs["cur_position"], num_tokens) + outputs["position_ids"].append(pos_ids) + outputs["cur_position"] = pos_ids.max() + 1 def _compute_text_positions(self, start_pos: int, num_tokens: int) -> np.ndarray: """ @@ -336,7 +334,7 @@ def _compute_text_positions(self, start_pos: int, num_tokens: int) -> np.ndarray position = text_index + start_pos return position - def _add_image(self, img, outputs: Dict) -> None: + def _add_image(self, img, outputs: Dict, uuid: Optional[str]) -> None: """ Add image data to model inputs dictionary. 
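For reference, the processor-cache lookup that the new request2ids performs above is a plain pickle-over-ZMQ exchange on a DEALER socket: send the list of multimodal hashes, receive the cached (features, meta) tuples in the same order, and push locally processed items back as a (hashes, items) pair. A minimal client-side sketch of that round trip, assuming a cache server is already listening on the same IPC endpoint; the endpoint and framing mirror get_processor_cache/update_processor_cache in this patch, while the hash values and feature array are made up for illustration:

import pickle

import numpy as np
import zmq

# Hypothetical hashes; in the patch they come from item["uuid"] or
# MultimodalHasher.hash_features(pixel_values).
missing_hashes = ["img-abc123", "vid-def456"]

context = zmq.Context()
dealer = context.socket(zmq.DEALER)
dealer.connect("ipc:///dev/shm/processor_cache.ipc")

# Fetch: pickled list of hashes out, pickled list of cached items back
# (a falsy entry means the server has nothing for that hash).
dealer.send_multipart([b"", pickle.dumps(missing_hashes)])
_, reply = dealer.recv_multipart()
cached_items = pickle.loads(reply)

for mm_hash, item in zip(missing_hashes, cached_items):
    if not item:
        raise ValueError(f"Missing item {mm_hash} not found in processor cache")
    features, meta = item  # cached as (pixel features, {"thw": (t, h, w), "fps": fps})

# Update: push a locally processed item so later requests can skip preprocessing.
dummy_features = np.zeros((4, 8), dtype="float32")  # stand-in for processed pixel values
dealer.send_multipart(
    [b"", pickle.dumps((["img-new789"], [(dummy_features, {"thw": (1, 2, 2), "fps": 0})]))]
)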
@@ -353,23 +351,49 @@ def _add_image(self, img, outputs: Dict) -> None: num_tokens = ret["grid_thw"].prod() // self.image_processor.merge_size**2 grid_thw = ret["grid_thw"].tolist() + outputs["mm_positions"].append(ImagePosition(len(outputs["input_ids"]), num_tokens)) outputs["input_ids"].extend([self.image_token_id] * num_tokens) outputs["token_type_ids"].extend([IDS_TYPE_FLAG["image"]] * num_tokens) outputs["images"].append(ret["pixel_values"]) + if not uuid: + outputs["mm_hashes"].append(MultimodalHasher.hash_features(ret["pixel_values"])) + else: + outputs["mm_hashes"].append(uuid) outputs["grid_thw"].append(grid_thw) outputs["image_type_ids"].append(0) # position_ids t, h, w = grid_thw - position_ids = self._compute_vision_positions(outputs["cur_position"], t, h, w, 0) - outputs["position_ids"].append(position_ids) - outputs["cur_position"] = position_ids.max() + 1 + pos_ids = self._compute_vision_positions(outputs["cur_position"], t, h, w, 0) + outputs["position_ids"].append(pos_ids) + outputs["cur_position"] = pos_ids.max() + 1 + outputs["fps"].append(0) numel = h * w outputs["vit_seqlen"].append(numel) outputs["vit_position_ids"].append(np.arange(numel) % numel) - def _add_video(self, frames, meta: Dict, outputs: Dict) -> None: + def _add_processed_image(self, img_cache: Tuple[np.ndarray, dict], outputs: Dict, uuid: str) -> None: + img, meta = img_cache + num_tokens = img.shape[0] // self.image_processor.merge_size**2 + + outputs["mm_positions"].append(ImagePosition(len(outputs["input_ids"]), num_tokens)) + outputs["input_ids"].extend([self.image_patch_id] * num_tokens) + outputs["token_type_ids"].extend([IDS_TYPE_FLAG["image"]] * num_tokens) + + _, h, w = meta["thw"] + pos_ids = self._compute_vision_positions(outputs["cur_position"], 1, h, w, 0) + outputs["position_ids"].append(pos_ids) + outputs["cur_position"] = pos_ids.max() + 1 + + outputs["images"].append(img) + outputs["mm_hashes"].append(uuid) + outputs["grid_thw"].append(np.array([[1, h, w]])) + outputs["image_type_ids"].append(0) + + outputs["fps"].append(0) + + def _add_video(self, frames, meta: Dict, outputs: Dict, uuid: Optional[str]) -> None: """ Add video data to model inputs dictionary. 
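The bookkeeping in _add_image and _add_processed_image above reduces to a small amount of shape arithmetic: an image patched into a (t, h, w) grid occupies grid_thw.prod() // merge_size**2 placeholder tokens in input_ids, its span is recorded as an ImagePosition, and its hash is either the caller-supplied uuid or a content hash of the pixel values. A rough worked example with made-up numbers, assuming merge_size == 2 and a single 1 x 32 x 32 grid; the token id is a hypothetical placeholder:

import numpy as np

merge_size = 2
grid_thw = np.array([1, 32, 32])                      # (t, h, w) patch grid for one image
num_tokens = int(grid_thw.prod()) // merge_size**2    # 1 * 32 * 32 // 4 == 256

IMAGE_TOKEN_ID = 0                                    # hypothetical; the real id comes from the tokenizer
input_ids = [11, 12, 13]                              # text tokens already in the prompt
start = len(input_ids)                                # recorded as ImagePosition(start=3, num_tokens=256)
input_ids.extend([IMAGE_TOKEN_ID] * num_tokens)

# ViT-side bookkeeping: one position per pre-merge patch, as in vit_position_ids.
numel = int(grid_thw[1] * grid_thw[2])                # 1024 patches before merging
vit_position_ids = np.arange(numel) % numel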
@@ -387,24 +411,52 @@ def _add_video(self, frames, meta: Dict, outputs: Dict) -> None: num_tokens = ret["image_grid_thw"].prod() // self.image_processor.merge_size**2 grid_thw = ret["image_grid_thw"].tolist() + outputs["mm_positions"].append(ImagePosition(len(outputs["input_ids"]), num_tokens)) outputs["input_ids"].extend([self.video_token_id] * num_tokens) outputs["token_type_ids"].extend([IDS_TYPE_FLAG["video"]] * num_tokens) outputs["images"].append(ret["pixel_values"]) + if not uuid: + outputs["mm_hashes"].append(MultimodalHasher.hash_features(ret["pixel_values"])) + else: + outputs["mm_hashes"].append(uuid) outputs["grid_thw"].append(grid_thw) outputs["image_type_ids"].extend([1] * grid_thw[0]) fps = meta["fps"] second_per_grid_t = self.temporal_conv_size / fps t, h, w = grid_thw - position_ids = self._compute_vision_positions(outputs["cur_position"], t, h, w, second_per_grid_t) + pos_ids = self._compute_vision_positions(outputs["cur_position"], t, h, w, second_per_grid_t) - outputs["position_ids"].append(position_ids) - outputs["cur_position"] = position_ids.max() + 1 + outputs["position_ids"].append(pos_ids) + outputs["cur_position"] = pos_ids.max() + 1 + outputs["fps"].append(fps) numel = h * w outputs["vit_seqlen"].append(numel) outputs["vit_position_ids"].append(np.arange(numel) % numel) + def _add_processed_video(self, frames_cache: Tuple[np.ndarray, dict], outputs: Dict, uuid: str) -> None: + frames, meta = frames_cache + num_tokens = frames.shape[0] // self.image_processor.merge_size**2 + + t, h, w = meta["thw"] + outputs["images"].append(frames) + outputs["mm_hashes"].append(uuid) + outputs["grid_thw"].append(np.array([[t, h, w]])) + + outputs["mm_positions"].append(ImagePosition(len(outputs["input_ids"]), num_tokens)) + outputs["input_ids"].extend([self.image_patch_id] * num_tokens) + outputs["token_type_ids"].extend([IDS_TYPE_FLAG["video"]] * num_tokens) + outputs["image_type_ids"].extend([1] * t) + + fps = meta["fps"] + second_per_grid_t = self.temporal_conv_size / fps + pos_ids = self._compute_vision_positions(outputs["cur_position"], t, h, w, second_per_grid_t) + outputs["position_ids"].append(pos_ids) + outputs["cur_position"] = pos_ids.max() + 1 + + outputs["fps"].append(fps) + def _compute_vision_positions( self, start_pos: int, t: int, h: int, w: int, second_per_grid_t: float ) -> np.ndarray: @@ -438,6 +490,78 @@ def _compute_vision_positions( position = np.stack([t_index, h_index, w_index]) + start_pos return position + def _load_and_process_video(self, url: str, item: Dict) -> Tuple[np.ndarray, Dict]: + """ + Load and preprocess video into frames. 
+ + Args: + url: Video file path or bytes + item: Dictionary containing processing parameters + + Returns: + tuple: (frames, metadata) where: + - frames: Processed video frames as numpy array + - metadata: Updated video metadata dictionary + """ + reader, meta, _ = read_video_decord(url, save_to_disk=False) + + # Apply frame sampling if fps or target_frames specified + fps = item.get("fps", self.fps) + num_frames = item.get("target_frames", self.target_frames) + + frame_indices = list(range(meta["num_of_frame"])) + if fps > 0 or num_frames > 0: + # Get frame sampling constraints + min_frames = item.get("min_frames", self.min_frames) + max_frames = item.get("max_frames", self.max_frames) + + # Sample frames according to specifications + frame_indices = sample_frames( + frame_factor=self.temporal_conv_size, # Ensure divisible by temporal patch size + min_frames=min_frames, + max_frames=max_frames, + metadata=meta, + fps=fps, + num_frames=num_frames, + ) + + # Update metadata with new frame count and fps + meta["num_of_frame"] = len(frame_indices) + if fps is not None: + meta["fps"] = fps # Use specified fps + meta["duration"] = len(frame_indices) / fps + else: + meta["fps"] = len(frame_indices) / meta["duration"] # Calculate fps from sampled frames + + frames = [] + for idx in frame_indices: + frame = reader[idx].asnumpy() + image = Image.fromarray(frame, "RGB") + frames.append(image) + frames = np.stack([np.array(f.convert("RGB")) for f in frames], axis=0) + + return frames, meta + + def get_processor_cache(self, socket, mm_hashes: list[str]) -> list: + """ + get cache correspond to given hash values + """ + req = pickle.dumps(mm_hashes) + socket.send_multipart([b"", req]) + _, resp = socket.recv_multipart() + mm_items = pickle.loads(resp) + data_processor_logger.info(f"Get cache of mm_hashes: {mm_hashes}") + + return mm_items + + def update_processor_cache(self, socket, mm_hashes: list[str], mm_items): + """ + update cache data + """ + req = pickle.dumps((mm_hashes, mm_items)) + socket.send_multipart([b"", req]) + data_processor_logger.info(f"Update cache of mm_hashes: {mm_hashes}") + def apply_chat_template(self, request): """ Apply chat template to convert messages into token sequence. diff --git a/fastdeploy/input/paddleocr_vl_processor/process_video.py b/fastdeploy/input/paddleocr_vl_processor/process_video.py new file mode 100644 index 00000000000..c7089d26dc2 --- /dev/null +++ b/fastdeploy/input/paddleocr_vl_processor/process_video.py @@ -0,0 +1,82 @@ +""" +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" + +import math +from typing import Optional, Union + +import numpy as np + + +def sample_frames( + frame_factor: int, + min_frames: int, + max_frames: int, + metadata: Optional[dict] = None, + fps: Optional[Union[int, float]] = None, + num_frames: Optional[int] = None, +): + """ + Sample frames from video according to specified criteria. 
+ + Args: + frame_factor: Ensure sampled frames are multiples of this factor + min_frames: Minimum number of frames to sample + max_frames: Maximum number of frames to sample + metadata: Video metadata containing fps information + fps: Target frames per second for sampling + num_frames: Exact number of frames to sample + + Returns: + np.ndarray: Sampled video frames + + Raises: + ValueError: If both fps and num_frames are specified, + or if required metadata is missing, + or if requested frames exceed available frames + """ + if fps > 0 and num_frames > 0: + raise ValueError("`num_frames` and `fps` are mutually exclusive arguments, please use only one!") + + total_num_frames = metadata["num_of_frame"] + + # If num_frames is not given but fps is, calculate num_frames from fps + if num_frames > 0: + num_frames = round(num_frames / frame_factor) * frame_factor + elif fps > 0: + if metadata is None: + raise ValueError( + "Asked to sample `fps` frames per second but no video metadata was provided which is required when sampling with `fps`. " + "Please pass in `VideoMetadata` object or use a fixed `num_frames` per input video" + ) + max_frames = math.floor(min(max_frames, total_num_frames) / frame_factor) * frame_factor + num_frames = total_num_frames / metadata["fps"] * fps + num_frames = min(min(max(num_frames, min_frames), max_frames), total_num_frames) + num_frames = math.floor(num_frames / frame_factor) * frame_factor + if num_frames > total_num_frames: + raise ValueError( + f"Video can't be sampled. The inferred `num_frames={num_frames}` exceeds `total_num_frames={total_num_frames}`. " + "Decrease `num_frames` or `fps` for sampling." + ) + + # Calculate frame indices based on sampling strategy + if num_frames > 0: + # Evenly spaced sampling for target frame count + indices = np.arange(0, total_num_frames, total_num_frames / num_frames).astype(np.int32) + else: + # Keep all frames if no sampling requested + indices = np.arange(0, total_num_frames).astype(np.int32) + + return indices From e0f9cef8d63272ef7e0947a07b4682fd5da42bff Mon Sep 17 00:00:00 2001 From: Limerances <466107905@qq.com> Date: Tue, 28 Oct 2025 18:56:25 +0800 Subject: [PATCH 2/6] add test for paddleocr_vl --- tests/e2e/test_paddleocr_vl_serving.py | 281 +++++++++++++++++++++++++ 1 file changed, 281 insertions(+) create mode 100644 tests/e2e/test_paddleocr_vl_serving.py diff --git a/tests/e2e/test_paddleocr_vl_serving.py b/tests/e2e/test_paddleocr_vl_serving.py new file mode 100644 index 00000000000..5b650ee79ff --- /dev/null +++ b/tests/e2e/test_paddleocr_vl_serving.py @@ -0,0 +1,281 @@ +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
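Before the new e2e test below, a quick numeric walk-through of the sample_frames helper added above: with fps-based sampling it converts the requested rate into a frame count, clamps it to [min_frames, max_frames] and the clip length, floors it to a multiple of frame_factor, and then spaces the indices evenly over the clip. The clip length and rates here are made up for illustration:

import math

import numpy as np

# Hypothetical clip: 300 frames recorded at 30 fps, sampled at a target of 2 fps.
metadata = {"num_of_frame": 300, "fps": 30}
frame_factor, min_frames, max_frames, fps = 2, 4, 768, 2

total = metadata["num_of_frame"]
max_frames = math.floor(min(max_frames, total) / frame_factor) * frame_factor   # 300
num_frames = total / metadata["fps"] * fps                                      # 20.0
num_frames = min(min(max(num_frames, min_frames), max_frames), total)           # 20.0
num_frames = math.floor(num_frames / frame_factor) * frame_factor               # 20

indices = np.arange(0, total, total / num_frames).astype(np.int32)              # 0, 15, 30, ..., 285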
+ +import json +import os +import shutil +import signal +import socket +import subprocess +import sys +import time + +import pytest +import requests + +# Read ports from environment variables; use default values if not set +FD_API_PORT = int(os.getenv("FD_API_PORT", 8188)) +FD_ENGINE_QUEUE_PORT = int(os.getenv("FD_ENGINE_QUEUE_PORT", 8133)) +FD_METRICS_PORT = int(os.getenv("FD_METRICS_PORT", 8233)) +FD_CACHE_QUEUE_PORT = int(os.getenv("FD_CACHE_QUEUE_PORT", 8333)) + +# List of ports to clean before and after tests +PORTS_TO_CLEAN = [FD_API_PORT, FD_ENGINE_QUEUE_PORT, FD_METRICS_PORT, FD_CACHE_QUEUE_PORT] + +os.environ["FD_USE_MACHETE"] = "0" + + +def is_port_open(host: str, port: int, timeout=1.0): + """ + Check if a TCP port is open on the given host. + Returns True if connection succeeds, False otherwise. + """ + try: + with socket.create_connection((host, port), timeout): + return True + except Exception: + return False + + +def kill_process_on_port(port: int): + """ + Kill processes that are listening on the given port. + Uses `lsof` to find process ids and sends SIGKILL. + """ + try: + output = subprocess.check_output(f"lsof -i:{port} -t", shell=True).decode().strip() + current_pid = os.getpid() + parent_pid = os.getppid() + for pid in output.splitlines(): + pid = int(pid) + if pid in (current_pid, parent_pid): + print(f"Skip killing current process (pid={pid}) on port {port}") + continue + os.kill(pid, signal.SIGKILL) + print(f"Killed process on port {port}, pid={pid}") + except subprocess.CalledProcessError: + pass + + +def clean_ports(): + """ + Kill all processes occupying the ports listed in PORTS_TO_CLEAN. + """ + for port in PORTS_TO_CLEAN: + kill_process_on_port(port) + time.sleep(2) + + +@pytest.fixture(scope="session", autouse=True) +def setup_and_run_server(): + """ + Pytest fixture that runs once per test session: + - Cleans ports before tests + - Starts the API server as a subprocess + - Waits for server port to open (up to 30 seconds) + - Tears down server after all tests finish + """ + print("Pre-test port cleanup...") + clean_ports() + print("log dir clean ") + if os.path.exists("log") and os.path.isdir("log"): + shutil.rmtree("log") + + base_path = os.getenv("MODEL_PATH") + if base_path: + model_path = os.path.join(base_path, "PaddleOCR-VL-0.9B") + else: + # model_path = "./PaddleOCR-VL-0.9B" + model_path = "/workspace/ocr/PaddleOCR-VL-0.9B" + + log_path = "server.log" + limit_mm_str = json.dumps({"image": 100, "video": 100}) + + cmd = [ + sys.executable, + "-m", + "fastdeploy.entrypoints.openai.api_server", + "--model", + model_path, + "--port", + str(FD_API_PORT), + "--tensor-parallel-size", + "2", + "--engine-worker-queue-port", + str(FD_ENGINE_QUEUE_PORT), + "--metrics-port", + str(FD_METRICS_PORT), + "--cache-queue-port", + str(FD_CACHE_QUEUE_PORT), + "--enable-mm", + "--max-model-len", + "32768", + "--max-num-batched-tokens", + "384", + "--max-num-seqs", + "128", + "--limit-mm-per-prompt", + limit_mm_str, + "--enable-chunked-prefill", + "--kv-cache-ratio", + "0.71", + "--quantization", + "wint4", + "--graph-optimization-config", + '{"graph_opt_level":0, "use_cudagraph":true}', + ] + + # Start subprocess in new process group + with open(log_path, "w") as logfile: + process = subprocess.Popen( + cmd, + stdout=logfile, + stderr=subprocess.STDOUT, + start_new_session=True, # Enables killing full group via os.killpg + ) + + # Wait up to 10 minutes for API server to be ready + for _ in range(10 * 60): + if is_port_open("127.0.0.1", FD_API_PORT): + print(f"API server 
is up on port {FD_API_PORT}") + break + time.sleep(1) + else: + print("[TIMEOUT] API server failed to start in 5 minutes. Cleaning up...") + try: + os.killpg(process.pid, signal.SIGTERM) + except Exception as e: + print(f"Failed to kill process group: {e}") + raise RuntimeError(f"API server did not start on port {FD_API_PORT}") + + yield # Run tests + + print("\n===== Post-test server cleanup... =====") + try: + os.killpg(process.pid, signal.SIGTERM) + print(f"API server (pid={process.pid}) terminated") + clean_ports() + except Exception as e: + print(f"Failed to terminate API server: {e}") + + +@pytest.fixture(scope="session") +def api_url(request): + """ + Returns the API endpoint URL for chat completions. + """ + return f"http://0.0.0.0:{FD_API_PORT}/v1/chat/completions" + + +@pytest.fixture(scope="session") +def metrics_url(request): + """ + Returns the metrics endpoint URL. + """ + return f"http://0.0.0.0:{FD_METRICS_PORT}/metrics" + + +@pytest.fixture +def headers(): + """ + Returns common HTTP request headers. + """ + return {"Content-Type": "application/json"} + + +@pytest.fixture +def consistent_payload(): + """ + Returns a fixed payload for consistency testing, + including a fixed random seed and temperature. + """ + return { + "messages": [ + { + "role": "user", + "content": [ + { + "type": "image_url", + "image_url": { + "url": "https://paddle-model-ecology.bj.bcebos.com/PPOCRVL/dataset/ocr_v5_eval/handwrite_ch_rec_val/中文手写古籍_000054_crop_32.jpg", + }, + }, + {"type": "text", "text": "OCR:"}, + ], + } + ], + "temperature": 0.8, + "top_p": 0, # fix top_p to reduce randomness + "seed": 13, # fixed random seed + } + + +# ========================== +# Helper function to calculate difference rate between two texts +# ========================== +def calculate_diff_rate(text1, text2): + """ + Calculate the difference rate between two strings + based on the normalized Levenshtein edit distance. + Returns a float in [0,1], where 0 means identical. + """ + if text1 == text2: + return 0.0 + + len1, len2 = len(text1), len(text2) + dp = [[0] * (len2 + 1) for _ in range(len1 + 1)] + + for i in range(len1 + 1): + for j in range(len2 + 1): + if i == 0 or j == 0: + dp[i][j] = i + j + elif text1[i - 1] == text2[j - 1]: + dp[i][j] = dp[i - 1][j - 1] + else: + dp[i][j] = 1 + min(dp[i - 1][j], dp[i][j - 1], dp[i - 1][j - 1]) + + edit_distance = dp[len1][len2] + max_len = max(len1, len2) + return edit_distance / max_len if max_len > 0 else 0.0 + + +# ========================== +# Consistency test for repeated runs with fixed payload +# ========================== +def test_consistency_between_runs(api_url, headers, consistent_payload): + """ + Test that two runs with the same fixed input produce similar outputs. 
+ """ + # First request + resp1 = requests.post(api_url, headers=headers, json=consistent_payload) + assert resp1.status_code == 200 + result1 = resp1.json() + content1 = result1["choices"][0]["message"]["content"] + print(content1) + + # Second request + resp2 = requests.post(api_url, headers=headers, json=consistent_payload) + assert resp2.status_code == 200 + result2 = resp2.json() + + content2 = result2["choices"][0]["message"]["content"] + print(content2) + + # Calculate difference rate + diff_rate = calculate_diff_rate(content1, content2) + + # Verify that the difference rate is below the threshold + assert diff_rate < 0.05, f"Output difference too large ({diff_rate:.4%})" + + assert content1 == "生甘草" From 6c6f26fa04446eb6a4b78903e90fb99c920a573a Mon Sep 17 00:00:00 2001 From: ming1753 Date: Tue, 28 Oct 2025 19:26:47 +0800 Subject: [PATCH 3/6] disable prefix-caching in ocr --- fastdeploy/engine/args_utils.py | 2 ++ fastdeploy/worker/worker_process.py | 2 ++ 2 files changed, 4 insertions(+) diff --git a/fastdeploy/engine/args_utils.py b/fastdeploy/engine/args_utils.py index 6fcbeac1e4b..c7fd2009c36 100644 --- a/fastdeploy/engine/args_utils.py +++ b/fastdeploy/engine/args_utils.py @@ -452,6 +452,8 @@ def __post_init__(self): if "PaddleOCR" in get_model_architecture(self.model, self.model_config_name): envs.FD_ENABLE_MAX_PREFILL = 1 + self.enable_prefix_caching = False + self.max_encoder_cache = 0 @staticmethod def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: diff --git a/fastdeploy/worker/worker_process.py b/fastdeploy/worker/worker_process.py index 9ea08542bc3..30bf8b103b5 100644 --- a/fastdeploy/worker/worker_process.py +++ b/fastdeploy/worker/worker_process.py @@ -814,6 +814,8 @@ def initialize_fd_config(args, ranks: int = 1, local_rank: int = 0) -> FDConfig: architecture = fd_config.model_config.architectures[0] if "PaddleOCR" in architecture: envs.FD_ENABLE_MAX_PREFILL = 1 + fd_config.cache_config.enable_prefix_caching = False + fd_config.cache_config.max_encoder_cache = 0 return fd_config From 6d1dd20be411a0197a40327a870df9abb155eb0d Mon Sep 17 00:00:00 2001 From: Limerances <466107905@qq.com> Date: Tue, 28 Oct 2025 19:45:34 +0800 Subject: [PATCH 4/6] add test for paddleocr_vl --- tests/e2e/test_paddleocr_vl_serving.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/tests/e2e/test_paddleocr_vl_serving.py b/tests/e2e/test_paddleocr_vl_serving.py index 5b650ee79ff..997bbddf87a 100644 --- a/tests/e2e/test_paddleocr_vl_serving.py +++ b/tests/e2e/test_paddleocr_vl_serving.py @@ -96,8 +96,7 @@ def setup_and_run_server(): if base_path: model_path = os.path.join(base_path, "PaddleOCR-VL-0.9B") else: - # model_path = "./PaddleOCR-VL-0.9B" - model_path = "/workspace/ocr/PaddleOCR-VL-0.9B" + model_path = "./PaddleOCR-VL-0.9B" log_path = "server.log" limit_mm_str = json.dumps({"image": 100, "video": 100}) From 73527d5b87251fc481357783750fc126cc1cd104 Mon Sep 17 00:00:00 2001 From: Limerances <466107905@qq.com> Date: Fri, 31 Oct 2025 16:13:19 +0800 Subject: [PATCH 5/6] Fix top_p for rejection sampling --- .../input/ernie4_5_vl_processor/ernie4_5_vl_processor.py | 5 +++++ .../input/paddleocr_vl_processor/paddleocr_vl_processor.py | 6 +++++- 2 files changed, 10 insertions(+), 1 deletion(-) diff --git a/fastdeploy/input/ernie4_5_vl_processor/ernie4_5_vl_processor.py b/fastdeploy/input/ernie4_5_vl_processor/ernie4_5_vl_processor.py index 1211eccf532..04e8b435ae6 100644 --- a/fastdeploy/input/ernie4_5_vl_processor/ernie4_5_vl_processor.py +++ 
b/fastdeploy/input/ernie4_5_vl_processor/ernie4_5_vl_processor.py @@ -26,6 +26,8 @@ from .process import DataProcessor +_SAMPLING_EPS = 1e-5 + class Ernie4_5_VLProcessor(Ernie4_5Processor): """The processor class for ERNIE MoE VL models.""" @@ -260,6 +262,9 @@ def process_request_dict(self, request, max_model_len=None): request["reasoning_max_tokens"] = max(int(request["max_tokens"] * 0.8), 1) data_processor_logger.info(f"Processed request {request}") + if request.get("top_p") < _SAMPLING_EPS: + request["top_p"] = _SAMPLING_EPS + return request def append_completion_tokens(self, multimodal_inputs, completion_token_ids): diff --git a/fastdeploy/input/paddleocr_vl_processor/paddleocr_vl_processor.py b/fastdeploy/input/paddleocr_vl_processor/paddleocr_vl_processor.py index 2e9e680c0b5..544484661ff 100644 --- a/fastdeploy/input/paddleocr_vl_processor/paddleocr_vl_processor.py +++ b/fastdeploy/input/paddleocr_vl_processor/paddleocr_vl_processor.py @@ -22,6 +22,8 @@ from .process import DataProcessor +_SAMPLING_EPS = 1e-5 + class PaddleOCRVLProcessor(TextProcessor): """ @@ -61,7 +63,6 @@ def __init__( tool_parser_obj: Tool parser instance """ super().__init__(model_name_or_path, reasoning_parser_obj, tool_parser_obj) - data_processor_logger.info(f"model_name_or_path: {model_name_or_path}") processor_kwargs = self._parse_processor_kwargs(mm_processor_kwargs) self.processor = DataProcessor( @@ -252,6 +253,9 @@ def process_request_dict(self, request, max_model_len=None): if request.get("max_tokens") is None: request["max_tokens"] = max(1, max_model_len - len(request["prompt_token_ids"])) # Ensure at least 1 token + if request.get("top_p") < _SAMPLING_EPS: + request["top_p"] = _SAMPLING_EPS + return request def append_generated_tokens(self, multimodal_inputs, generated_token_ids): From 7c6d5ef7809587d1781478b5b39a6cecbd573f17 Mon Sep 17 00:00:00 2001 From: Limerances <466107905@qq.com> Date: Fri, 5 Dec 2025 17:10:46 +0800 Subject: [PATCH 6/6] support mxfp4 in gpt-oss --- fastdeploy/config.py | 3 + fastdeploy/envs.py | 2 + fastdeploy/model_executor/layers/linear.py | 7 +- fastdeploy/model_executor/layers/moe/moe.py | 83 +++- .../model_executor/layers/normalization.py | 20 +- .../layers/quantization/__init__.py | 5 + .../layers/quantization/mxfp4.py | 405 ++++++++++++++++++ .../layers/quantization/weight_only.py | 8 +- fastdeploy/model_executor/layers/utils.py | 20 + fastdeploy/model_executor/models/gpt_oss.py | 51 ++- fastdeploy/utils.py | 2 +- 11 files changed, 573 insertions(+), 33 deletions(-) create mode 100644 fastdeploy/model_executor/layers/quantization/mxfp4.py diff --git a/fastdeploy/config.py b/fastdeploy/config.py index 14e5579764a..6f0f3612fd9 100644 --- a/fastdeploy/config.py +++ b/fastdeploy/config.py @@ -336,6 +336,9 @@ def read_model_config(self): elif "dtype" in self.model_config: self.model_format = "paddle" logger.info("The model format is Paddle") + elif "model_type" in self.model_config and self.model_config["model_type"] == "gpt_oss": + self.model_format = "torch" + logger.info("The model format is Hugging Face") else: raise ValueError( "Unknown model format. Please ensure your config.json contains " diff --git a/fastdeploy/envs.py b/fastdeploy/envs.py index 1eb7af39490..b222b8799ac 100644 --- a/fastdeploy/envs.py +++ b/fastdeploy/envs.py @@ -54,6 +54,8 @@ "FD_SAMPLING_CLASS": lambda: os.getenv("FD_SAMPLING_CLASS", "base"), # Set moe backend."cutlass","marlin" and "triton" can be set currently. 
"FD_MOE_BACKEND": lambda: os.getenv("FD_MOE_BACKEND", "cutlass"), + # Whether to use FLASHINFER as MXFP4 backend for MoE. + "FD_USE_FLASHINFER_MOE_MXFP4_BF16": lambda: os.getenv("FD_USE_FLASHINFER_MOE_MXFP4_BF16", "0"), # Whether to use Machete for wint4 dense gemm. "FD_USE_MACHETE": lambda: os.getenv("FD_USE_MACHETE", "1"), # Set whether to disable recompute the request when the KV cache is full. diff --git a/fastdeploy/model_executor/layers/linear.py b/fastdeploy/model_executor/layers/linear.py index b329844daa9..1db2a9ffc11 100644 --- a/fastdeploy/model_executor/layers/linear.py +++ b/fastdeploy/model_executor/layers/linear.py @@ -30,7 +30,7 @@ ) from fastdeploy.platforms import current_platform -from .utils import _set_var_distributed, divide, get_tensor +from .utils import _set_var_distributed, divide, get_tensor, modules_to_convert class UnquantizedLinearMethod(QuantMethodBase): @@ -156,7 +156,7 @@ def __init__( self.output_size, ] - if fd_config.quant_config and not skip_quant: + if fd_config.quant_config and not skip_quant and modules_to_convert(prefix, self.fd_config): self.quant_method = fd_config.quant_config.get_quant_method(self) else: self.quant_method: Optional[QuantMethodBase] = UnquantizedLinearMethod() @@ -589,7 +589,7 @@ class QKVParallelLinear(ColumnParallelLinear): QKVParallelLinear Layer. """ - def __init__(self, fd_config, prefix, with_bias=False, add_bias=True): + def __init__(self, fd_config, prefix, with_bias=False, add_bias=True, skip_quant=False): """ Initialize the QKV Linear layer with given parameters. @@ -623,6 +623,7 @@ def __init__(self, fd_config, prefix, with_bias=False, add_bias=True): output_size=output_size, with_bias=with_bias, add_bias=add_bias, + skip_quant=skip_quant, ) def _get_shard_size_mapping(self, loaded_shard_id: str, head_dim: int): diff --git a/fastdeploy/model_executor/layers/moe/moe.py b/fastdeploy/model_executor/layers/moe/moe.py index 09330e549a7..edaf9ec5aa4 100644 --- a/fastdeploy/model_executor/layers/moe/moe.py +++ b/fastdeploy/model_executor/layers/moe/moe.py @@ -209,10 +209,12 @@ def __init__( tp_size={self.tp_size}." 
) - def weight_loader(self, param, loaded_weight, expert_id, shard_id: Optional[str] = None): + def weight_loader( + self, param, loaded_weight, expert_id, shard_id: Optional[str] = None, loaded_weight_name: Optional[str] = None + ): if expert_id is None and shard_id is None: # MoE experts has been fused in disk - self._load_fused_experts_weight(param, loaded_weight) + self._load_fused_experts_weight(param, loaded_weight, loaded_weight_name) return if hasattr(param, "SHARD_ID_TO_SHARDED_DIM"): SHARD_ID_TO_SHARDED_DIM = param.SHARD_ID_TO_SHARDED_DIM @@ -331,7 +333,7 @@ def _load_down_weight(self, param, expert_id, loaded_weight, shard_id, shard_dim loaded_weight = loaded_weight.cast(expert_param.dtype) expert_param.copy_(loaded_weight, False) - def _load_fused_experts_weight(self, param, loaded_weight): + def _load_fused_experts_weight(self, param, loaded_weight, loaded_weight_name: Optional[str] = None): if self.tp_size > 1: dim = -1 if isinstance(loaded_weight, (np.ndarray, paddle.Tensor)): @@ -342,11 +344,76 @@ def _load_fused_experts_weight(self, param, loaded_weight): shard_offset = self.tp_rank * block_size shard_size = (self.tp_rank + 1) * block_size loaded_weight = slice_fn(loaded_weight, dim, shard_offset, shard_size) - assert param.shape == loaded_weight.shape, ( - f"Attempted to load weight ({loaded_weight.shape}) " f"into parameter ({param.shape})" - ) - loaded_weight = get_tensor(loaded_weight) - param.copy_(loaded_weight, False) + + if self.moe_quant_config.name() == "mxfp4": + assert loaded_weight_name is not None + weight = get_tensor(loaded_weight) + if "block" in loaded_weight_name: + if "up" in loaded_weight_name: + weight = weight.reshape([self.num_experts, 2 * self.moe_intermediate_size, -1]) + elif "down" in loaded_weight_name: + weight = weight.reshape([self.num_experts, self.hidden_size, -1]) + weight = paddle.nn.functional.pad( + weight.cast("int32"), + pad=[0, param.shape[-1] - weight.shape[-1], 0, param.shape[-2] - weight.shape[-2]], + mode="constant", + value=0, + ).cast("uint8") + + if "up" in loaded_weight_name: + gate_w, up_w = weight[:, ::2, :], weight[:, 1::2, :] + param.copy_(paddle.concat([up_w, gate_w], axis=1), False) + else: + param.copy_(weight, False) + + elif "scale" in loaded_weight_name: + if "up" in loaded_weight_name: + weight = weight.reshape([self.num_experts, 2 * self.moe_intermediate_size, -1]) + elif "down" in loaded_weight_name: + weight = weight.reshape([self.num_experts, self.hidden_size, -1]) + weight = paddle.nn.functional.pad( + weight.cast("int32"), + pad=[0, param.shape[-1] - weight.shape[-1], 0, param.shape[-2] - weight.shape[-2]], + mode="constant", + value=0, + ).cast("uint8") + + def _interleave_mxfp4_cutlass_sm90(w): + w_shape = w.shape + w_interleaved = w.reshape([w_shape[0], w_shape[1], (w_shape[2] // 4), 4]) + w_interleaved = w_interleaved.permute([0, 2, 1, 3]) + w_interleaved = w_interleaved.reshape([w_shape[0], w_shape[2] // 4, w_shape[1] * 4]) + return w_interleaved + + if "up" in loaded_weight_name: + gate_s, up_s = weight[:, ::2, :], weight[:, 1::2, :] + up_gate_proj_scale = paddle.concat([up_s, gate_s], axis=1) + up_gate_proj_scale_interleaved = _interleave_mxfp4_cutlass_sm90(up_gate_proj_scale) + param.copy_(up_gate_proj_scale_interleaved, False) + else: + down_proj_scale = weight + down_proj_scale_interleaved = _interleave_mxfp4_cutlass_sm90(down_proj_scale) + param.copy_(down_proj_scale_interleaved, False) + + elif "bias" in loaded_weight_name: + + weight = paddle.nn.functional.pad( + weight, pad=[0, param.shape[-1] 
- weight.shape[-1]], mode="constant", value=0 + ) + + if "up" in loaded_weight_name: + gate_b, up_b = weight[:, ::2].cast("bfloat16"), weight[:, 1::2].cast("bfloat16") + param.copy_(paddle.concat([up_b, gate_b], axis=-1), False) + else: + param.copy_(weight.cast("bfloat16"), False) + + else: + assert param.shape == loaded_weight.shape, ( + f"Attempted to load weight ({loaded_weight.shape}) " f"into parameter ({param.shape})" + ) + + loaded_weight = get_tensor(loaded_weight) + param.copy_(loaded_weight, False) if hasattr(param, "tensor_track"): for i in range(self.num_local_experts): diff --git a/fastdeploy/model_executor/layers/normalization.py b/fastdeploy/model_executor/layers/normalization.py index 6bcb05ba727..c7120ba8ac8 100644 --- a/fastdeploy/model_executor/layers/normalization.py +++ b/fastdeploy/model_executor/layers/normalization.py @@ -29,7 +29,7 @@ from fastdeploy.config import FDConfig -from .utils import get_tensor +from .utils import get_tensor, modules_to_convert class RMSNorm(nn.Layer): @@ -92,9 +92,21 @@ def __init__( "float16", ], f"Unsupported dtype: {dtype}. Must be one of: float32, bfloat16, float16" - self.quant_round_type: int = self.fd_config.quant_config.quant_round_type if fd_config.quant_config else 0 - self.quant_max_bound: int = self.fd_config.quant_config.quant_max_bound if fd_config.quant_config else 0 - self.quant_min_bound: int = self.fd_config.quant_config.quant_min_bound if fd_config.quant_config else 0 + self.quant_round_type: int = ( + self.fd_config.quant_config.quant_round_type + if fd_config.quant_config and modules_to_convert(prefix, self.fd_config) + else 0 + ) + self.quant_max_bound: int = ( + self.fd_config.quant_config.quant_max_bound + if fd_config.quant_config and modules_to_convert(prefix, self.fd_config) + else 0 + ) + self.quant_min_bound: int = ( + self.fd_config.quant_config.quant_min_bound + if fd_config.quant_config and modules_to_convert(prefix, self.fd_config) + else 0 + ) self.begin_norm_axis: int = begin_norm_axis self.init_weight() diff --git a/fastdeploy/model_executor/layers/quantization/__init__.py b/fastdeploy/model_executor/layers/quantization/__init__.py index f8716369852..bec0dc977b4 100644 --- a/fastdeploy/model_executor/layers/quantization/__init__.py +++ b/fastdeploy/model_executor/layers/quantization/__init__.py @@ -33,6 +33,7 @@ "mix_quant", "tensor_wise_fp8", "kvcache", + "mxfp4", ] @@ -99,6 +100,8 @@ def _get_offline_quant_config_name(quantization_config, is_torch_weight, is_v1_l has_block_size = "weight_block_size" in quantization_config if quant_method == "fp8" and has_block_size: quant_config_name = "block_wise_fp8" + elif quant_method == "mxfp4": + quant_config_name = "mxfp4" else: raise ValueError("Torch weight offline quantization only supports block-wise FP8.") else: @@ -116,6 +119,7 @@ def get_quantization_config(quantization: str) -> Type[QuantConfigBase]: from .block_wise_fp8 import BlockWiseFP8Config from .kv_cache import KvCacheQuantConfig from .mix_quant import MixQuantConfig + from .mxfp4 import MXFP4Config from .tensor_wise_fp8 import TensorWiseFP8Config from .w4a8 import W4A8Config from .w4afp8 import W4AFP8Config @@ -137,6 +141,7 @@ def get_quantization_config(quantization: str) -> Type[QuantConfigBase]: "tensor_wise_fp8": TensorWiseFP8Config, "kvcache": KvCacheQuantConfig, "mix_quant": MixQuantConfig, + "mxfp4": MXFP4Config, } return method_to_config[quantization] diff --git a/fastdeploy/model_executor/layers/quantization/mxfp4.py b/fastdeploy/model_executor/layers/quantization/mxfp4.py new file 
mode 100644 index 00000000000..4986222a0d9 --- /dev/null +++ b/fastdeploy/model_executor/layers/quantization/mxfp4.py @@ -0,0 +1,405 @@ +""" +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" + +import importlib +import importlib.util +from enum import Enum +from typing import Optional + +import paddle +from paddle import nn + +from fastdeploy import envs +from fastdeploy.model_executor.ops.gpu import moe_expert_dispatch +from fastdeploy.model_executor.utils import set_weight_attrs +from fastdeploy.platforms import current_platform +from fastdeploy.utils import get_logger + +from ..moe import FusedMoE +from .quant_base import QuantConfigBase, QuantMethodBase + +paddle.compat.enable_torch_proxy() +import torch +from torch.nn import functional as F + +logger = get_logger("config", "config.log") + + +class Mxfp4Backend(Enum): + NONE = 0 + + # FlashInfer Backend + SM90_FI_MXFP4_BF16 = 1 + + # Triton Backend + TRITON = 2 + + +def check_device_capability(num): + if paddle.is_compiled_with_cuda(): + device = paddle.device.get_device() + major, minor = paddle.device.cuda.get_device_capability(device) + return major * 10 + minor >= num + else: + return False + + +def has_flashinfer(): + return importlib.util.find_spec("flashinfer") is not None + + +def round_up(a, b): + return ((a + b - 1) // b) * b + + +def get_mxfp4_backend(): + if current_platform.is_cuda(): + if check_device_capability(90) and has_flashinfer() and envs.FD_USE_FLASHINFER_MOE_MXFP4_BF16: + logger.info("FastDeploy Using FlashInfer MXFP4 BF16 backend for SM90 in MoE") + return Mxfp4Backend.SM90_FI_MXFP4_BF16 + else: + logger.info("FastDeploy Using Triton backend in MoE") + return Mxfp4Backend.TRITON + else: + raise NotImplementedError + + +class MXFP4Config(QuantConfigBase): + """Base class for quantization configs.""" + + def __init__(self, is_checkpoint_bf16: bool = False): + super().__init__() + self.is_checkpoint_bf16 = is_checkpoint_bf16 + + def name(self) -> str: + return "mxfp4" + + @classmethod + def from_config(cls, config: dict) -> "MXFP4Config": + is_checkpoint_bf16 = not config.get("is_quantized", False) + return cls(is_checkpoint_bf16) + + def get_quant_method(self, layer) -> Optional[QuantMethodBase]: + if isinstance(layer, FusedMoE): + return MXFP4MoeMethod(self) + else: + raise NotImplementedError + + +class MXFP4MoeMethod(QuantMethodBase): + def __init__( + self, + quant_config: MXFP4Config, + ) -> None: + super().__init__() + self.quant_config = quant_config + self.mxfp4_backend = get_mxfp4_backend() + + def create_weights(self, layer, **extra_weight_attrs): + + block_size = 32 + + intermediate_size_pad = layer.moe_intermediate_size + hidden_size_pad = layer.hidden_size + + if self.mxfp4_backend == Mxfp4Backend.SM90_FI_MXFP4_BF16: + intermediate_size_pad = round_up(intermediate_size_pad, 128) + hidden_size_pad = round_up(hidden_size_pad, 128) + else: + intermediate_size_pad = round_up(intermediate_size_pad, 64) + + self.intermediate_size_pad = 
intermediate_size_pad + self.hidden_size_pad = hidden_size_pad + self.num_experts = layer.num_local_experts + + self.up_gate_proj_weight_shape = [ + self.num_experts, + intermediate_size_pad * 2, + hidden_size_pad // 2, # uint8 + ] + + self.down_proj_weight_shape = [ + self.num_experts, + hidden_size_pad, + intermediate_size_pad // 2, # uint8 + ] + + self.up_gate_proj_scale_shape = [ + self.num_experts, + intermediate_size_pad * 2, + hidden_size_pad // block_size, + ] + + self.down_proj_scale_shape = [ + self.num_experts, + hidden_size_pad, + intermediate_size_pad // block_size, + ] + + self.weight_dtype = "uint8" + + setattr( + layer, + "up_gate_proj_weight", + layer.create_parameter( + shape=self.up_gate_proj_weight_shape, + dtype=self.weight_dtype, + default_initializer=paddle.nn.initializer.Constant(0), + ), + ) + setattr( + layer, + "down_proj_weight", + layer.create_parameter( + shape=self.down_proj_weight_shape, + dtype=self.weight_dtype, + default_initializer=paddle.nn.initializer.Constant(0), + ), + ) + + setattr( + layer, + "up_gate_proj_scale", + layer.create_parameter( + shape=self.up_gate_proj_scale_shape, + dtype=self.weight_dtype, + default_initializer=paddle.nn.initializer.Constant(0), + ), + ) + + setattr( + layer, + "down_proj_scale", + layer.create_parameter( + shape=self.down_proj_scale_shape, + dtype=self.weight_dtype, + default_initializer=paddle.nn.initializer.Constant(0), + ), + ) + + extra_weight_attrs["weight_need_transpose"] = not extra_weight_attrs.get("model_format") == "torch" + + set_weight_attrs(layer.up_gate_proj_weight, extra_weight_attrs) + set_weight_attrs(layer.down_proj_weight, extra_weight_attrs) + + set_weight_attrs(layer.up_gate_proj_scale, extra_weight_attrs) + set_weight_attrs(layer.down_proj_scale, extra_weight_attrs) + + if layer.with_bias: + layer.up_gate_proj_bias = layer.create_parameter( + shape=[self.num_experts, intermediate_size_pad * 2], + dtype=layer.weight_dtype, + default_initializer=paddle.nn.initializer.Constant(0), + ) + + layer.down_proj_bias = layer.create_parameter( + shape=[self.num_experts, hidden_size_pad], + dtype=layer.weight_dtype, + default_initializer=paddle.nn.initializer.Constant(0), + ) + + set_weight_attrs( + layer.up_gate_proj_bias, + extra_weight_attrs, + ) + set_weight_attrs( + layer.down_proj_bias, + extra_weight_attrs, + ) + + if layer.activation == "swigluoai": + gemm1_alpha = layer.create_parameter( + shape=[self.num_experts], + dtype="float32", + default_initializer=paddle.nn.initializer.Constant(1.702), + ) + gemm1_alpha.initialize() + setattr(layer, "gemm1_alpha", gemm1_alpha) + + gemm1_beta = layer.create_parameter( + shape=[self.num_experts], + dtype="float32", + default_initializer=paddle.nn.initializer.Constant(1.0), + ) + gemm1_beta.initialize() + setattr(layer, "gemm1_beta", gemm1_beta) + + gemm1_clamp_limit = layer.create_parameter( + shape=[self.num_experts], + dtype="float32", + default_initializer=paddle.nn.initializer.Constant(7.0), + ) + gemm1_clamp_limit.initialize() + setattr(layer, "gemm1_clamp_limit", gemm1_clamp_limit) + + def process_weights_after_loading(self, layer) -> None: + return + block_size = 32 + if self.mxfp4_backend == Mxfp4Backend.SM90_FI_MXFP4_BF16: + assert ( + layer.up_gate_proj_weight.dim() == 3 + and layer.up_gate_proj_weight.shape[0] == self.num_experts + and layer.up_gate_proj_weight.shape[1] == self.intermediate_size_pad * 2 + and layer.up_gate_proj_weight.shape[2] == self.hidden_size_pad // 2 + ) + assert ( + layer.up_gate_proj_scale.dim() == 3 + and 
layer.up_gate_proj_scale.shape[0] == self.num_experts + and layer.up_gate_proj_scale.shape[1] == self.intermediate_size_pad * 2 + and layer.up_gate_proj_scale.shape[2] == self.hidden_size_pad // block_size + ) + assert ( + layer.down_proj_weight.dim() == 3 + and layer.down_proj_weight.shape[0] == self.num_experts + and layer.down_proj_weight.shape[1] == self.hidden_size_pad + and layer.down_proj_weight.shape[2] == self.intermediate_size_pad // 2 + ) + assert ( + layer.down_proj_scale.dim() == 3 + and layer.down_proj_scale.shape[0] == self.num_experts + and layer.down_proj_scale.shape[1] == self.hidden_size_pad + and layer.down_proj_scale.shape[2] == self.intermediate_size_pad // block_size + ) + if layer.with_bias: + assert ( + layer.up_gate_proj_bias.dim() == 2 + and layer.up_gate_proj_bias.shape[0] == self.num_experts + and layer.up_gate_proj_bias.shape[1] == self.intermediate_size_pad * 2 + ) + assert ( + layer.down_proj_bias.dim() == 2 + and layer.down_proj_bias.shape[0] == self.num_experts + and layer.down_proj_bias.shape[1] == self.hidden_size_pad + ) + + gate_w, up_w = layer.up_gate_proj_weight[:, ::2, :], layer.up_gate_proj_weight[:, 1::2, :] + gate_b, up_b = layer.up_gate_proj_bias[:, ::2].cast("bfloat16"), layer.up_gate_proj_bias[:, 1::2].cast( + "bfloat16" + ) + gate_s, up_s = layer.up_gate_proj_scale[:, ::2, :], layer.up_gate_proj_scale[:, 1::2, :] + + layer.up_gate_proj_weight.copy_(paddle.concat([up_w, gate_w], axis=1), False) + layer.up_gate_proj_bias.copy_(paddle.concat([up_b, gate_b], axis=-1), False) + + def _interleave_mxfp4_cutlass_sm90(w): + w_shape = w.shape + w_interleaved = w.reshape([w_shape[0], w_shape[1], (w_shape[2] // 4), 4]) + w_interleaved = w_interleaved.permute([0, 2, 1, 3]) + w_interleaved = w_interleaved.reshape([w_shape[0], w_shape[2] // 4, w_shape[1] * 4]) + return w_interleaved + + up_gate_proj_scale = paddle.concat([up_s, gate_s], axis=1) + up_gate_proj_scale_interleaved = _interleave_mxfp4_cutlass_sm90(up_gate_proj_scale) + + down_proj_scale = layer.down_proj_scale + down_proj_scale_interleaved = _interleave_mxfp4_cutlass_sm90(down_proj_scale) + + layer.up_gate_proj_scale.copy_(up_gate_proj_scale_interleaved, False) + layer.down_proj_scale.copy_(down_proj_scale_interleaved, False) + else: + raise ValueError(f"Unsupported backend: {self.mxfp4_backend}") + + def compute_routing(self, router_logits: paddle.Tensor, top_k: int): + """ + Compute routing weights and selected experts from router logits. 
+
+        Args:
+            router_logits (paddle.Tensor): Router logits of shape [batch_size, num_experts]
+            top_k (int): Number of experts to route to per token
+
+        Returns:
+            tuple[paddle.Tensor, paddle.Tensor]: A tuple containing:
+                - routing_weights: Expert weights of shape [batch_size, top_k]
+                - selected_experts: Expert indices of shape [batch_size, top_k]
+        """
+        routing_weights = paddle.nn.functional.softmax(router_logits, axis=1, dtype="float32")
+        routing_weights, selected_experts = paddle.topk(routing_weights, top_k, axis=-1)
+        routing_weights /= routing_weights.sum(axis=-1, keepdim=True)
+        routing_weights = routing_weights.astype("float32")
+        return routing_weights, selected_experts
+
+    def apply(self, layer: nn.Layer, x: paddle.Tensor, router: nn.Layer) -> paddle.Tensor:
+        router_out = router(x.cast("float32"))
+
+        if self.mxfp4_backend == Mxfp4Backend.SM90_FI_MXFP4_BF16:
+
+            (
+                _,
+                _,
+                _,
+                topk_weights,
+                topk_idx,
+                _,
+            ) = moe_expert_dispatch(
+                x,
+                router_out,
+                layer.gate_correction_bias,
+                (
+                    layer.up_gate_proj_in_scale if hasattr(layer, "up_gate_proj_in_scale") else None
+                ),  # if set, permute_input will be int8_t
+                layer.top_k,
+                False,
+                self.quant_config.name(),
+                topk_only_mode=False,
+            )
+            topk_weights = topk_weights / topk_weights.sum(axis=-1, keepdim=True)
+
+            quant_scales = [
+                layer.up_gate_proj_scale,
+                layer.down_proj_scale,
+            ]
+            extra_kwargs = dict(
+                use_w4_group_scaling=True,
+                fc1_expert_weights=layer.up_gate_proj_weight,
+                fc2_expert_weights=layer.down_proj_weight,
+            )
+
+            from flashinfer.fused_moe import (
+                cutlass_fused_moe as flashinfer_cutlass_fused_moe,
+            )
+
+            x = paddle.nn.functional.pad(x, pad=[0, self.hidden_size_pad - x.shape[-1]], mode="constant", value=0)
+
+            output = paddle.zeros_like(x, dtype="bfloat16")
+
+            _ = flashinfer_cutlass_fused_moe(
+                input=x,
+                token_selected_experts=topk_idx,
+                token_final_scales=topk_weights,
+                output_dtype=torch.bfloat16,
+                output=output,
+                quant_scales=quant_scales,
+                fc1_expert_biases=layer.up_gate_proj_bias,
+                fc2_expert_biases=layer.down_proj_bias,
+                swiglu_alpha=layer.gemm1_alpha,
+                swiglu_beta=layer.gemm1_beta,
+                swiglu_limit=layer.gemm1_clamp_limit,
+                # tp_size=self.moe.tp_size,
+                # tp_rank=self.moe.tp_rank,
+                # ep_size=self.moe.ep_size,
+                # ep_rank=self.moe.ep_rank,
+                tune_max_num_tokens=8192,
+                **extra_kwargs,
+            )
+
+            return output[..., : layer.hidden_size]
+
+    def process_loaded_weights(self, layer, weights):
+        """Process the weight after loading.
+
+        This can be used for example, to transpose weights for computation.
+ """ + return diff --git a/fastdeploy/model_executor/layers/quantization/weight_only.py b/fastdeploy/model_executor/layers/quantization/weight_only.py index 5cd7ec79e74..75a75e11374 100644 --- a/fastdeploy/model_executor/layers/quantization/weight_only.py +++ b/fastdeploy/model_executor/layers/quantization/weight_only.py @@ -253,7 +253,7 @@ def create_weights(self, layer, **extra_weight_attrs): else: if isinstance(self, MacheteWeightOnlyLinearMethod): # Using group scale for machete, group size is 128 - weight_scale_shape = [(layer.weight_shape[0] + 127) // 128, layer.weight_shape[1]] + weight_scale_shape = [(layer.weight_shape[0] + 63) // 64, layer.weight_shape[1]] if self.quant_config.name() == "wint4": layer.weight_shape[0] //= 8 else: @@ -313,7 +313,7 @@ def process_weights_after_loading(self, layer) -> None: w=layer.weight, atype=layer._dtype, quant_type="uint4b8" if self.quant_config.name() == "wint4" else "uint8b128", - group_size=128, + group_size=64, ) else: quanted_weight_tensor, weight_scale_tensor = weight_quantize( @@ -417,7 +417,7 @@ def process_loaded_weights(self, layer, weight) -> None: w=weight, atype=layer._dtype, quant_type="uint4b8" if self.quant_config.name() == "wint4" else "uint8b128", - group_size=128, + group_size=64, ) layer.weight.set_value(quanted_weight_tensor) layer.weight_scale.set_value(weight_scale_tensor.astype(paddle.get_default_dtype())) @@ -430,7 +430,7 @@ def apply(self, layer, x): w_prepack=layer.weight, w_g_s=layer.weight_scale, weight_dtype="uint4b8" if self.quant_config.name() == "wint4" else "uint8b128", - group_size=128, + group_size=64, ) if layer.with_bias: linear_out = paddle.add(linear_out, layer.bias) diff --git a/fastdeploy/model_executor/layers/utils.py b/fastdeploy/model_executor/layers/utils.py index c0644896e8e..3b1c37c0e88 100644 --- a/fastdeploy/model_executor/layers/utils.py +++ b/fastdeploy/model_executor/layers/utils.py @@ -23,6 +23,7 @@ from paddle.framework import in_dynamic_mode from scipy.linalg import block_diag +from fastdeploy.config import FDConfig from fastdeploy.platforms import current_platform if current_platform.is_cuda() and current_platform.available(): @@ -402,3 +403,22 @@ def vocab_range_from_per_partition_vocab_size(per_partition_vocab_size: int, ran def vocab_range_from_global_vocab_size(global_vocab_size: int, rank: int, world_size: int, offset: int = 0): per_partition_vocab_size = divide(global_vocab_size, world_size) return vocab_range_from_per_partition_vocab_size(per_partition_vocab_size, rank, offset=offset) + + +def modules_to_convert(prefix: str, fd_config: FDConfig): + import fnmatch + + if ( + hasattr(fd_config.model_config, "quantization_config") + and fd_config.model_config.quantization_config is not None + ): + if "modules_to_not_convert" in fd_config.model_config.quantization_config: + patterns = fd_config.model_config.quantization_config["modules_to_not_convert"] + for p in patterns: + if fnmatch.fnmatch(prefix, p) or fnmatch.fnmatch(prefix, p + ".*"): + return False + return True + else: + return True + else: + return True diff --git a/fastdeploy/model_executor/models/gpt_oss.py b/fastdeploy/model_executor/models/gpt_oss.py index c263d357772..03244859dda 100644 --- a/fastdeploy/model_executor/models/gpt_oss.py +++ b/fastdeploy/model_executor/models/gpt_oss.py @@ -227,6 +227,13 @@ def forward(self, ids_remove_padding: paddle.Tensor, forward_meta: ForwardMeta): ) class GptOssForCausalLM(ModelForCasualLM): def __init__(self, fd_config: FDConfig): + if ( + hasattr(fd_config, "quant_config") + and 
fd_config.model_config.quantization_config is not None + and "modules_to_not_convert" in fd_config.model_config.quantization_config + ): + fd_config.model_config.quantization_config["modules_to_not_convert"].append("*norm") + super(GptOssForCausalLM, self).__init__(fd_config) self.fd_config = fd_config self.model = GptOssModel(fd_config=fd_config) @@ -266,14 +273,20 @@ def load_weights(self, weights_iterator) -> None: ] expert_params_mapping = [ # (param_name, weight_name, expert_id, shard_id) - ("up_gate_proj_weight", "gate_up_proj", None, None), ("up_gate_proj_bias", "gate_up_proj_bias", None, None), - ("down_proj_weight", "down_proj", None, None), ("down_proj_bias", "down_proj_bias", None, None), + ("up_gate_proj_weight", "gate_up_proj", None, None), + ("down_proj_weight", "down_proj", None, None), + ("up_gate_proj_weight", "gate_up_proj_blocks", None, None), + ("up_gate_proj_scale", "gate_up_proj_scales", None, None), + ("down_proj_weight", "down_proj_blocks", None, None), + ("down_proj_scale", "down_proj_scales", None, None), ] params_dict = dict(self.named_parameters()) process_weights_after_loading_fn = process_weights_after_loading(dict(self.named_sublayers())) + for loaded_weight_name, loaded_weight in weights_iterator: + matched = False for param_name, weight_name, shard_id in stacked_params_mapping: if weight_name not in loaded_weight_name: continue @@ -285,26 +298,38 @@ def load_weights(self, weights_iterator) -> None: param = params_dict[model_param_name] weight_loader = getattr(param, "weight_loader", default_weight_loader(self.fd_config)) weight_loader(param, loaded_weight, shard_id) + matched = True break - else: - for mapping in expert_params_mapping: - param_name, weight_name, expert_id, shard_id = mapping + if not matched: + for param_name, weight_name, expert_id, shard_id in expert_params_mapping: if weight_name not in loaded_weight_name: continue + model_param_name = loaded_weight_name.replace(weight_name, param_name) if model_param_name not in params_dict: continue + param = params_dict[model_param_name] weight_loader = param.weight_loader - weight_loader(param, loaded_weight, shard_id=shard_id, expert_id=expert_id) + weight_loader( + param, + loaded_weight, + loaded_weight_name=loaded_weight_name, + shard_id=shard_id, + expert_id=expert_id, + ) + + matched = True break - else: - model_param_name = loaded_weight_name - if model_param_name not in params_dict: - continue - param = params_dict[model_param_name] - weight_loader = getattr(param, "weight_loader", default_weight_loader(self.fd_config)) - weight_loader(param, loaded_weight) + if not matched: + + model_param_name = loaded_weight_name + if model_param_name not in params_dict: + continue + + param = params_dict[model_param_name] + weight_loader = getattr(param, "weight_loader", default_weight_loader(self.fd_config)) + weight_loader(param, loaded_weight) model_sublayer_name = re.sub(r"\.(up_gate_proj_weight|down_proj_weight|weight)$", "", model_param_name) process_weights_after_loading_fn(model_sublayer_name, param) diff --git a/fastdeploy/utils.py b/fastdeploy/utils.py index a5a45a8a287..46c256156d6 100644 --- a/fastdeploy/utils.py +++ b/fastdeploy/utils.py @@ -568,7 +568,7 @@ def check_unified_ckpt(model_dir): try: # check all the file exists - safetensors_num = int(model_files[0].strip(".safetensors").split("-")[-1]) + safetensors_num = int(model_files[0].strip(".safetensors").split("-")[-1]) + 1 flags = [0] * safetensors_num for x in model_files: current_index = int(x.strip(".safetensors").split("-")[1])