diff --git a/fastdeploy/cache_manager/prefix_cache_manager.py b/fastdeploy/cache_manager/prefix_cache_manager.py index 5c0cc7cc5ad..ab6ec98d543 100644 --- a/fastdeploy/cache_manager/prefix_cache_manager.py +++ b/fastdeploy/cache_manager/prefix_cache_manager.py @@ -1272,66 +1272,6 @@ def hash_block_features(self, input_ids, extra_keys: list = []): """ return hashlib.sha256(pickle.dumps((input_ids, extra_keys))).hexdigest() - def _revert_match_blocks( - self, - request, - matched_token_num: int, - block_size: int, - chunk_idx: int, - match_node_ids: list, - matche_nodes: list, - match_gpu_block_ids: list, - match_cpu_block_ids: list, - gpu_match_token_num: int, - cpu_match_token_num: int, - swap_node_ids: list, - ): - # position = request.multimodal_inputs["mm_positions"][chunk_idx] - # revert_tokens = matched_token_num - position.offset - # TODO(chengyanfu): fix when is_chunked_mm_input=True, revert all matched tokens - revert_tokens = matched_token_num - match_block_ids = [node.block_id for node in matche_nodes] - logger.warning( - f"match_block: req_id {request.request_id} revert tokens: {revert_tokens} from matched nodes: {match_block_ids}" - ) - while revert_tokens >= block_size: - if len(matche_nodes) == 0: - logger.error(f"req_id {request.request_id} revert nodes error, tokens: {revert_tokens}") - break - revert_tokens -= block_size - revert_block = matche_nodes.pop() - revert_block_id = revert_block.block_id - if revert_block_id in match_gpu_block_ids: - match_gpu_block_ids.remove(revert_block_id) - match_node_ids.remove(revert_block.node_id) - gpu_match_token_num -= block_size - elif revert_block_id in match_cpu_block_ids: - match_cpu_block_ids.remove(revert_block_id) - match_node_ids.remove(revert_block.node_id) - cpu_match_token_num -= block_size - else: - logger.error( - f"req_id {request.request_id} revert nodes error, nodes: {revert_block_id}, " - f"match_gpu_block_ids: {match_gpu_block_ids}, match_cpu_block_ids: {match_cpu_block_ids}" - ) - break - if revert_block_id in swap_node_ids: - swap_node_ids.remove(revert_block_id) - - if revert_tokens > 0: - last_block_id = matche_nodes[-1].block_id - if last_block_id in match_gpu_block_ids: - gpu_match_token_num -= revert_tokens - elif last_block_id in match_cpu_block_ids: - cpu_match_token_num -= revert_tokens - else: - logger.error( - f"req_id {request.request_id} revert nodes error, revert_tokens: {revert_tokens}, nodes: {last_block_id}, " - f"match_gpu_block_ids: {match_gpu_block_ids}, match_cpu_block_ids: {match_cpu_block_ids}" - ) - current_node = self.radix_tree_root if len(matche_nodes) == 0 else matche_nodes[-1] - return gpu_match_token_num, cpu_match_token_num, current_node - def mm_match_block(self, request, block_size): """ Match and retrieve cached blocks for multimodal requests using a radix tree structure. @@ -1420,28 +1360,6 @@ def mm_match_block(self, request, block_size): if has_modified_cpu_lru_leaf_heap: heapq.heapify(self.cpu_lru_leaf_heap) - if self.cache_config.disable_chunked_mm_input: - matched_token_num = gpu_match_token_num + cpu_match_token_num - is_chunked, chunk_idx = self.is_chunked_mm_input(request.multimodal_inputs, matched_token_num) - if is_chunked: - ( - gpu_match_token_num, - cpu_match_token_num, - current_match_node, - ) = self._revert_match_blocks( - request=request, - matched_token_num=matched_token_num, - block_size=block_size, - chunk_idx=chunk_idx, - match_node_ids=match_node_ids, - matche_nodes=matche_nodes, - match_gpu_block_ids=match_gpu_block_ids, - match_cpu_block_ids=match_cpu_block_ids, - gpu_match_token_num=gpu_match_token_num, - cpu_match_token_num=cpu_match_token_num, - swap_node_ids=swap_node_ids, - ) - logger.info(f"match_block: req_id {request.request_id} matched nodes: {match_node_ids}") return ( match_gpu_block_ids, diff --git a/fastdeploy/engine/sched/resource_manager_v1.py b/fastdeploy/engine/sched/resource_manager_v1.py index 1106b56f9fe..9b1303682c9 100644 --- a/fastdeploy/engine/sched/resource_manager_v1.py +++ b/fastdeploy/engine/sched/resource_manager_v1.py @@ -353,6 +353,21 @@ def _is_mm_request(self, request): return False + def revert_chunked_mm_input(self, mm_inputs, matched_token_num): + """ + revert mm_inputs that is chunked + """ + if mm_inputs is None or "mm_positions" not in mm_inputs or len(mm_inputs["mm_positions"]) == 0: + return matched_token_num + + for idx in range(len(mm_inputs["mm_positions"])): + position = mm_inputs["mm_positions"][idx] + if position.offset < matched_token_num < position.offset + position.length: + return position.offset + elif matched_token_num < position.offset: + break + return matched_token_num + def _get_num_new_tokens(self, request, token_budget): # TODO: set condition to new _get_num_new_tokens num_new_tokens = request.need_prefill_tokens - request.num_computed_tokens @@ -904,11 +919,20 @@ def get_prefix_cached_blocks(self, request: Request): main_process_metrics.prefix_gpu_cache_token_num.inc(request.gpu_cache_token_num) main_process_metrics.prefix_cpu_cache_token_num.inc(request.cpu_cache_token_num) - if matched_token_num == request.need_prefill_tokens: - request.num_computed_tokens = matched_token_num - self.config.cache_config.block_size - request.skip_allocate = True + if self.config.cache_config.disable_chunked_mm_input: + if matched_token_num == request.need_prefill_tokens: + matched_token_num = matched_token_num - self.config.cache_config.block_size + request.skip_allocate = True + request.num_computed_tokens = self.revert_chunked_mm_input( + request.multimodal_inputs, matched_token_num + ) else: - request.num_computed_tokens = matched_token_num + if matched_token_num == request.need_prefill_tokens: + request.num_computed_tokens = matched_token_num - self.config.cache_config.block_size + request.skip_allocate = True + else: + request.num_computed_tokens = matched_token_num + llm_logger.info(f"request {request.request_id} num_computed_tokens: {request.num_computed_tokens}") request.cache_prepare_time = time.time() - cache_prepare_time return True except Exception as e: diff --git a/tests/v1/cache_manager/test_revert_blocks.py b/tests/v1/cache_manager/test_revert_blocks.py deleted file mode 100644 index 0cc3def4ae7..00000000000 --- a/tests/v1/cache_manager/test_revert_blocks.py +++ /dev/null @@ -1,302 +0,0 @@ -# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import unittest -from dataclasses import asdict -from types import SimpleNamespace - -from fastdeploy.cache_manager.cache_data import BlockNode -from fastdeploy.cache_manager.prefix_cache_manager import PrefixCacheManager -from fastdeploy.config import CacheConfig, FDConfig, ParallelConfig -from fastdeploy.engine.args_utils import EngineArgs -from fastdeploy.engine.request import ImagePosition, Request -from fastdeploy.scheduler import SchedulerConfig - - -def make_prefix_cache_manager(max_num_seqs, enable_mm=False, num_gpu_blocks_override=100, max_num_batched_tokens=3200): - engine_args = EngineArgs( - max_num_seqs=max_num_seqs, - num_gpu_blocks_override=num_gpu_blocks_override, - max_num_batched_tokens=max_num_batched_tokens, - ) - args = asdict(engine_args) - cache_cfg = CacheConfig(args) - model_cfg = SimpleNamespace(enable_mm=enable_mm, max_model_len=4196) - speculative_cfg = SimpleNamespace(method=None) - model_cfg.print = print - model_cfg.architectures = ["test_model"] - cache_cfg.bytes_per_layer_per_block = 1 - parallel_cfg = ParallelConfig(args) - scheduler_cfg = SchedulerConfig(args) - graph_opt_cfg = engine_args.create_graph_optimization_config() - fd_config = FDConfig( - model_config=model_cfg, - cache_config=cache_cfg, - parallel_config=parallel_cfg, - graph_opt_config=graph_opt_cfg, - speculative_config=speculative_cfg, - scheduler_config=scheduler_cfg, - ) - return PrefixCacheManager(config=fd_config, tensor_parallel_size=8, splitwise_role="mixed") - - -class TestIsChunkedMMInput(unittest.TestCase): - def setUp(self): - self.cache_manager = make_prefix_cache_manager(max_num_seqs=3, enable_mm=True, num_gpu_blocks_override=100) - - def test_is_chunked_mm_input_none_input(self): - result, idx = self.cache_manager.is_chunked_mm_input(None, 10) - self.assertFalse(result) - self.assertEqual(idx, 0) - - def test_is_chunked_mm_input_no_mm_positions(self): - mm_inputs = {"other_field": "value"} - result, idx = self.cache_manager.is_chunked_mm_input(mm_inputs, 10) - self.assertFalse(result) - self.assertEqual(idx, 0) - - def test_is_chunked_mm_input_empty_positions(self): - mm_inputs = {"mm_positions": []} - result, idx = self.cache_manager.is_chunked_mm_input(mm_inputs, 10) - self.assertFalse(result) - self.assertEqual(idx, 0) - - def test_is_chunked_mm_input_matched_in_chunk(self): - mm_inputs = { - "mm_positions": [ - ImagePosition(offset=5, length=10), - ImagePosition(offset=20, length=10), - ] - } - result, idx = self.cache_manager.is_chunked_mm_input(mm_inputs, 8) - self.assertTrue(result) - self.assertEqual(idx, 0) - - def test_is_chunked_mm_input_matched_in_second_chunk(self): - mm_inputs = { - "mm_positions": [ - ImagePosition(offset=5, length=10), - ImagePosition(offset=20, length=10), - ] - } - result, idx = self.cache_manager.is_chunked_mm_input(mm_inputs, 25) - self.assertTrue(result) - self.assertEqual(idx, 1) - - def test_is_chunked_mm_input_before_first_chunk(self): - mm_inputs = { - "mm_positions": [ - ImagePosition(offset=5, length=10), - ImagePosition(offset=20, length=10), - ] - } - result, idx = self.cache_manager.is_chunked_mm_input(mm_inputs, 3) - self.assertFalse(result) - self.assertEqual(idx, 0) - - def test_is_chunked_mm_input_after_last_chunk(self): - mm_inputs = { - "mm_positions": [ - ImagePosition(offset=5, length=10), - ImagePosition(offset=20, length=10), - ] - } - result, idx = self.cache_manager.is_chunked_mm_input(mm_inputs, 35) - self.assertFalse(result) - self.assertEqual(idx, 0) - - -@unittest.skip("Skip TestRevertMatchBlocks") -class TestRevertMatchBlocks(unittest.TestCase): - def setUp(self): - self.block_size = 64 - self.cache_manager = make_prefix_cache_manager(max_num_seqs=3, enable_mm=True, num_gpu_blocks_override=100) - - def make_match_blocks(self, gpu_block_num, cpu_block_num): - block_num = gpu_block_num + cpu_block_num - matched_token_num = block_num * self.block_size - match_node_ids = [] - matche_nodes = [] - match_gpu_block_ids = [] - match_cpu_block_ids = [] - for idx in range(block_num): - node_id = idx + 10 - block = BlockNode(node_id, [], 0, 0, idx, 0, None, None, None) - match_node_ids.append(node_id) - matche_nodes.append(block) - match_gpu_block_ids.append(idx) - - for _ in range(cpu_block_num): - match_cpu_block_ids.append(match_gpu_block_ids.pop()) - - gpu_match_token_num = len(match_gpu_block_ids) * self.block_size - cpu_match_token_num = len(match_cpu_block_ids) * self.block_size - return ( - matched_token_num, - match_node_ids, - matche_nodes, - match_gpu_block_ids, - match_cpu_block_ids, - gpu_match_token_num, - cpu_match_token_num, - ) - - def test_revert_full_blocks(self): - # Setup test data - multimodal_inputs = { - "mm_positions": [ImagePosition(offset=0, length=1200)], - "mm_hashes": ["image1"], - } - req_dict = { - "request_id": "req1", - "prompt_token_ids": [-1] * 1200 + [2] * 120, - "prompt_token_ids_len": 1320, - "multimodal_inputs": multimodal_inputs, - } - - ( - matched_token_num, - match_node_ids, - matche_nodes, - match_gpu_block_ids, - match_cpu_block_ids, - gpu_match_token_num, - cpu_match_token_num, - ) = self.make_match_blocks(gpu_block_num=2, cpu_block_num=0) - - # Call method - ( - gpu_match_token_num, - cpu_match_token_num, - current_match_node, - ) = self.cache_manager._revert_match_blocks( - request=Request.from_dict(req_dict), - matched_token_num=matched_token_num, - block_size=self.block_size, - chunk_idx=0, - match_node_ids=match_node_ids, - matche_nodes=matche_nodes, - match_gpu_block_ids=match_gpu_block_ids, - match_cpu_block_ids=match_cpu_block_ids, - gpu_match_token_num=gpu_match_token_num, - cpu_match_token_num=cpu_match_token_num, - swap_node_ids=[], - ) - - # Assertions - self.assertEqual(gpu_match_token_num, 0) - self.assertEqual(cpu_match_token_num, 0) - self.assertEqual(len(match_node_ids), 0) - self.assertEqual(len(match_gpu_block_ids), 0) - - def test_revert_partial_block(self): - # Setup test data - multimodal_inputs = { - "mm_positions": [ImagePosition(offset=120, length=1200)], - "mm_hashes": ["image1"], - } - req_dict = { - "request_id": "req1", - "prompt_token_ids": [1] * 120 + [-1] * 1200 + [2] * 120, - "prompt_token_ids_len": 1440, - "multimodal_inputs": multimodal_inputs, - } - - ( - matched_token_num, - match_node_ids, - matche_nodes, - match_gpu_block_ids, - match_cpu_block_ids, - gpu_match_token_num, - cpu_match_token_num, - ) = self.make_match_blocks(gpu_block_num=20, cpu_block_num=0) - - # Call method - ( - gpu_match_token_num, - cpu_match_token_num, - current_match_node, - ) = self.cache_manager._revert_match_blocks( - request=Request.from_dict(req_dict), - matched_token_num=matched_token_num, - block_size=self.block_size, - chunk_idx=0, - match_node_ids=match_node_ids, - matche_nodes=matche_nodes, - match_gpu_block_ids=match_gpu_block_ids, - match_cpu_block_ids=match_cpu_block_ids, - gpu_match_token_num=gpu_match_token_num, - cpu_match_token_num=cpu_match_token_num, - swap_node_ids=[], - ) - - # Assertions - self.assertEqual(gpu_match_token_num, 120) - self.assertEqual(cpu_match_token_num, 0) - self.assertEqual(len(match_node_ids), 2) - self.assertEqual(len(match_gpu_block_ids), 2) - - def test_revert_with_cpu_blocks(self): - # Setup test data - multimodal_inputs = { - "mm_positions": [ImagePosition(offset=120, length=1200), ImagePosition(offset=1440, length=420)], - "mm_hashes": ["image1", "image2"], - } - req_dict = { - "request_id": "req1", - "prompt_token_ids": [1] * 120 + [-1] * 1200 + [2] * 120 + [-1] * 420, - "prompt_token_ids_len": 1860, - "multimodal_inputs": multimodal_inputs, - } - - ( - matched_token_num, - match_node_ids, - matche_nodes, - match_gpu_block_ids, - match_cpu_block_ids, - gpu_match_token_num, - cpu_match_token_num, - ) = self.make_match_blocks(gpu_block_num=22, cpu_block_num=6) - - # Call method - ( - gpu_match_token_num, - cpu_match_token_num, - current_match_node, - ) = self.cache_manager._revert_match_blocks( - request=Request.from_dict(req_dict), - matched_token_num=matched_token_num, - block_size=self.block_size, - chunk_idx=1, - match_node_ids=match_node_ids, - matche_nodes=matche_nodes, - match_gpu_block_ids=match_gpu_block_ids, - match_cpu_block_ids=match_cpu_block_ids, - gpu_match_token_num=gpu_match_token_num, - cpu_match_token_num=cpu_match_token_num, - swap_node_ids=[], - ) - - # Assertions - self.assertEqual(gpu_match_token_num, 22 * self.block_size) - self.assertEqual(cpu_match_token_num, 32) - self.assertEqual(len(match_node_ids), 23) - self.assertEqual(len(match_gpu_block_ids), 22) - self.assertEqual(len(match_cpu_block_ids), 1) - - -if __name__ == "__main__": - unittest.main() diff --git a/tests/v1/test_resource_manager_v1.py b/tests/v1/test_resource_manager_v1.py index 3864f41eb88..6d00e6d3d9d 100644 --- a/tests/v1/test_resource_manager_v1.py +++ b/tests/v1/test_resource_manager_v1.py @@ -9,7 +9,7 @@ from fastdeploy.config import CacheConfig, FDConfig, ParallelConfig, SchedulerConfig from fastdeploy.engine.args_utils import EngineArgs -from fastdeploy.engine.request import Request +from fastdeploy.engine.request import ImagePosition, Request from fastdeploy.engine.sched.resource_manager_v1 import ResourceManagerV1 @@ -173,5 +173,100 @@ def test_download_features_retry(self): self.assertEqual(self.request.error_code, 530) +class TestRevertChunkedMMInput(unittest.TestCase): + def setUp(self): + max_num_seqs = 2 + engine_args = EngineArgs( + max_num_seqs=max_num_seqs, + num_gpu_blocks_override=102, + max_num_batched_tokens=3200, + ) + args = asdict(engine_args) + + cache_cfg = CacheConfig(args) + model_cfg = SimpleNamespace(enable_mm=True) # Enable multimodal for feature testing + speculative_cfg = SimpleNamespace(method=None) + model_cfg.print = print + model_cfg.max_model_len = 5120 + model_cfg.architectures = ["test_model"] + cache_cfg.bytes_per_layer_per_block = 1 + parallel_cfg = ParallelConfig(args) + scheduler_cfg = SchedulerConfig(args) + graph_opt_cfg = engine_args.create_graph_optimization_config() + + fd_config = FDConfig( + model_config=model_cfg, + cache_config=cache_cfg, + parallel_config=parallel_cfg, + graph_opt_config=graph_opt_cfg, + speculative_config=speculative_cfg, + scheduler_config=scheduler_cfg, + ) + self.manager = ResourceManagerV1( + max_num_seqs=max_num_seqs, config=fd_config, tensor_parallel_size=8, splitwise_role="mixed" + ) + req_dict = { + "request_id": "test_request", + "multimodal_inputs": {}, + } + self.request = Request.from_dict(req_dict) + self.request.async_process_futures = [] + self.request.multimodal_inputs = {} + + def test_revert_chunked_mm_input_none_input(self): + result = self.manager.revert_chunked_mm_input(None, 10) + self.assertEqual(result, 10) + + def test_revert_chunked_mm_input_no_mm_positions(self): + mm_inputs = {"other_field": "value"} + result = self.manager.revert_chunked_mm_input(mm_inputs, 10) + self.assertEqual(result, 10) + + def test_revert_chunked_mm_input_empty_positions(self): + mm_inputs = {"mm_positions": []} + result = self.manager.revert_chunked_mm_input(mm_inputs, 10) + self.assertEqual(result, 10) + + def test_revert_chunked_mm_input_matched_in_chunk(self): + mm_inputs = { + "mm_positions": [ + ImagePosition(offset=5, length=10), + ImagePosition(offset=20, length=10), + ] + } + result = self.manager.revert_chunked_mm_input(mm_inputs, 8) + self.assertEqual(result, 5) + + def test_revert_chunked_mm_input_matched_in_second_chunk(self): + mm_inputs = { + "mm_positions": [ + ImagePosition(offset=5, length=10), + ImagePosition(offset=20, length=10), + ] + } + result = self.manager.revert_chunked_mm_input(mm_inputs, 25) + self.assertEqual(result, 20) + + def test_revert_chunked_mm_input_before_first_chunk(self): + mm_inputs = { + "mm_positions": [ + ImagePosition(offset=5, length=10), + ImagePosition(offset=20, length=10), + ] + } + result = self.manager.revert_chunked_mm_input(mm_inputs, 3) + self.assertEqual(result, 3) + + def test_revert_chunked_mm_input_after_last_chunk(self): + mm_inputs = { + "mm_positions": [ + ImagePosition(offset=5, length=10), + ImagePosition(offset=20, length=10), + ] + } + result = self.manager.revert_chunked_mm_input(mm_inputs, 35) + self.assertEqual(result, 35) + + if __name__ == "__main__": unittest.main()