Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
82 changes: 0 additions & 82 deletions fastdeploy/cache_manager/prefix_cache_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -1398,66 +1398,6 @@ def get_block_hash_extra_keys(self, request, start_idx, end_idx, mm_idx):
hash_keys.append(mm_inputs["mm_hashes"][img_idx])
return len(mm_inputs["mm_positions"]) - 1, hash_keys

def _revert_match_blocks(
    self,
    request,
    matched_token_num: int,
    block_size: int,
    chunk_idx: int,
    match_node_ids: list,
    matche_nodes: list,
    match_gpu_block_ids: list,
    match_cpu_block_ids: list,
    gpu_match_token_num: int,
    cpu_match_token_num: int,
    swap_node_ids: list,
):
    """
    Roll back prefix-cache matches that would split a multimodal input chunk.

    All bookkeeping lists (``match_node_ids``, ``matche_nodes``,
    ``match_gpu_block_ids``, ``match_cpu_block_ids``, ``swap_node_ids``)
    are mutated in place.

    Args:
        request: request whose matched blocks are being reverted (only
            ``request_id`` and ``multimodal_inputs`` are read).
        matched_token_num: total number of matched tokens to revert.
        block_size: number of tokens held by one cache block.
        chunk_idx: index of the chunked multimodal position. Currently
            unused — kept for a future partial-revert implementation.
        match_node_ids: node ids of all matched radix-tree nodes.
        matche_nodes: matched radix-tree nodes, root-to-leaf order.
        match_gpu_block_ids: matched block ids resident on GPU.
        match_cpu_block_ids: matched block ids resident on CPU.
        gpu_match_token_num: token count matched from GPU blocks.
        cpu_match_token_num: token count matched from CPU blocks.
        swap_node_ids: block ids scheduled for CPU->GPU swap.

    Returns:
        Tuple of (gpu_match_token_num, cpu_match_token_num, current_node),
        where ``current_node`` is the deepest surviving matched node (the
        radix-tree root when everything was reverted).
    """
    # TODO(chengyanfu): when is_chunked_mm_input=True we currently revert ALL
    # matched tokens; a refined version would only revert back to the chunk
    # offset (request.multimodal_inputs["mm_positions"][chunk_idx].offset).
    revert_tokens = matched_token_num
    match_block_ids = [node.block_id for node in matche_nodes]
    logger.warning(
        f"match_block: req_id {request.request_id} revert tokens: {revert_tokens} from matched nodes: {match_block_ids}"
    )
    # Pop whole blocks off the matched path while at least one full block's
    # worth of tokens remains to be reverted.
    while revert_tokens >= block_size:
        if len(matche_nodes) == 0:
            logger.error(f"req_id {request.request_id} revert nodes error, tokens: {revert_tokens}")
            break
        revert_tokens -= block_size
        revert_block = matche_nodes.pop()
        revert_block_id = revert_block.block_id
        if revert_block_id in match_gpu_block_ids:
            match_gpu_block_ids.remove(revert_block_id)
            match_node_ids.remove(revert_block.node_id)
            gpu_match_token_num -= block_size
        elif revert_block_id in match_cpu_block_ids:
            match_cpu_block_ids.remove(revert_block_id)
            match_node_ids.remove(revert_block.node_id)
            cpu_match_token_num -= block_size
        else:
            logger.error(
                f"req_id {request.request_id} revert nodes error, nodes: {revert_block_id}, "
                f"match_gpu_block_ids: {match_gpu_block_ids}, match_cpu_block_ids: {match_cpu_block_ids}"
            )
            break
        if revert_block_id in swap_node_ids:
            swap_node_ids.remove(revert_block_id)

    # A partial block remains: shrink the matched-token counter for the last
    # surviving node without popping it.
    if revert_tokens > 0:
        if len(matche_nodes) == 0:
            # Bug fix: the previous code indexed matche_nodes[-1] here
            # unconditionally and raised IndexError when the loop above had
            # already popped every matched node.
            logger.error(f"req_id {request.request_id} revert nodes error, tokens: {revert_tokens}")
        else:
            last_block_id = matche_nodes[-1].block_id
            if last_block_id in match_gpu_block_ids:
                gpu_match_token_num -= revert_tokens
            elif last_block_id in match_cpu_block_ids:
                cpu_match_token_num -= revert_tokens
            else:
                logger.error(
                    f"req_id {request.request_id} revert nodes error, revert_tokens: {revert_tokens}, nodes: {last_block_id}, "
                    f"match_gpu_block_ids: {match_gpu_block_ids}, match_cpu_block_ids: {match_cpu_block_ids}"
                )
    current_node = self.radix_tree_root if len(matche_nodes) == 0 else matche_nodes[-1]
    return gpu_match_token_num, cpu_match_token_num, current_node

def mm_match_block(self, request, block_size):
"""
Match and retrieve cached blocks for multimodal requests using a radix tree structure.
Expand Down Expand Up @@ -1550,28 +1490,6 @@ def mm_match_block(self, request, block_size):
if has_modified_cpu_lru_leaf_heap:
heapq.heapify(self.cpu_lru_leaf_heap)

if self.cache_config.disable_chunked_mm_input:
matched_token_num = gpu_match_token_num + cpu_match_token_num
is_chunked, chunk_idx = self.is_chunked_mm_input(request.multimodal_inputs, matched_token_num)
if is_chunked:
(
gpu_match_token_num,
cpu_match_token_num,
current_match_node,
) = self._revert_match_blocks(
request=request,
matched_token_num=matched_token_num,
block_size=block_size,
chunk_idx=chunk_idx,
match_node_ids=match_node_ids,
matche_nodes=matche_nodes,
match_gpu_block_ids=match_gpu_block_ids,
match_cpu_block_ids=match_cpu_block_ids,
gpu_match_token_num=gpu_match_token_num,
cpu_match_token_num=cpu_match_token_num,
swap_node_ids=swap_node_ids,
)

logger.info(f"match_block: req_id {request.request_id} matched nodes: {match_node_ids}")
return (
match_gpu_block_ids,
Expand Down
32 changes: 28 additions & 4 deletions fastdeploy/engine/sched/resource_manager_v1.py
Original file line number Diff line number Diff line change
Expand Up @@ -356,6 +356,21 @@ def _is_mm_request(self, request):

return False

def revert_chunked_mm_input(self, mm_inputs, matched_token_num):
    """
    Clamp ``matched_token_num`` so the matched prefix never ends inside a
    multimodal span.

    If the count falls strictly inside one of the ``mm_positions`` spans,
    roll it back to that span's start offset; otherwise return it unchanged.
    """
    if mm_inputs is None or "mm_positions" not in mm_inputs:
        return matched_token_num
    positions = mm_inputs["mm_positions"]
    if len(positions) == 0:
        return matched_token_num

    for position in positions:
        # Positions are ordered; once we are left of a span, no later span
        # can contain matched_token_num either.
        if matched_token_num < position.offset:
            break
        span_end = position.offset + position.length
        if position.offset < matched_token_num < span_end:
            return position.offset
    return matched_token_num

def _get_num_new_tokens(self, request, token_budget):
# TODO: set condition to new _get_num_new_tokens
num_new_tokens = request.need_prefill_tokens - request.num_computed_tokens
Expand Down Expand Up @@ -957,11 +972,20 @@ def get_prefix_cached_blocks(self, request: Request):
main_process_metrics.prefix_gpu_cache_token_num.inc(request.metrics.gpu_cache_token_num)
main_process_metrics.prefix_cpu_cache_token_num.inc(request.metrics.gpu_cache_token_num)

if matched_token_num == request.need_prefill_tokens:
request.num_computed_tokens = matched_token_num - self.config.cache_config.block_size
request.skip_allocate = True
if self.config.cache_config.disable_chunked_mm_input:
if matched_token_num == request.need_prefill_tokens:
matched_token_num = matched_token_num - self.config.cache_config.block_size
request.skip_allocate = True
request.num_computed_tokens = self.revert_chunked_mm_input(
request.multimodal_inputs, matched_token_num
)
else:
request.num_computed_tokens = matched_token_num
if matched_token_num == request.need_prefill_tokens:
request.num_computed_tokens = matched_token_num - self.config.cache_config.block_size
request.skip_allocate = True
else:
request.num_computed_tokens = matched_token_num
llm_logger.info(f"request {request.request_id} num_computed_tokens: {request.num_computed_tokens}")
return True
except Exception as e:
llm_logger.error(f"prefix match blocks error: {e}, {str(traceback.format_exc())} waiting reschedule...")
Expand Down
66 changes: 0 additions & 66 deletions tests/cache_manager/test_prefix_cache_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -821,17 +821,6 @@ def test_update_cache_blocks_refreshes_mappings(self):
self.assertIn(req_id, manager.leaf_req_map[new_leaf])
self.assertEqual(task.num_cached_blocks, 2)

def test_is_chunked_mm_input_detects_overlap(self):
    """A multimodal span straddling the matched-token boundary is flagged as chunked."""
    manager = _create_manager()
    inputs = {
        "mm_hashes": ["img"],
        "mm_positions": [SimpleNamespace(offset=2, length=3)],
    }

    is_chunked, position_idx = manager.is_chunked_mm_input(inputs, matched_token_num=3)

    self.assertTrue(is_chunked)
    self.assertEqual(position_idx, 0)

def test_issue_and_sync_swap_tasks(self):
manager = _create_manager()
manager.cache_task_queue = _DummyEngineCacheQueue()
Expand Down Expand Up @@ -1101,33 +1090,6 @@ def test_free_block_ids_async_consumes_finished_future(self):
self.assertIsNone(manager.gpu_free_task_future)
self.assertTrue(finished.result_called)

def test_mm_match_block_reverts_chunked_inputs(self):
    # With disable_chunked_mm_input set, a prefix match that would end inside
    # a multimodal span must be reverted entirely: mm_match_block should
    # report no matched GPU blocks.
    manager = _create_manager(num_gpu_blocks=4)
    manager.cache_config.disable_chunked_mm_input = True
    block_size = 2
    input_ids = [1, 2, 3, 4]
    hash_input = get_hash_str(input_ids)
    hash_first = get_hash_str([1, 2])
    # Second block's hash includes the image key, mirroring mm-aware hashing.
    hash_second = get_hash_str([3, 4], ["img"])
    # Build a two-node radix-tree path covering both blocks of the prompt.
    node1 = BlockNode(80, input_ids, hash_input, 1, 0, block_size, hash_first, 0, parent=manager.radix_tree_root)
    node2 = BlockNode(81, input_ids, hash_input, 2, 1, block_size, hash_second, 0, parent=node1)
    manager.radix_tree_root.children[hash_first] = node1
    node1.children[hash_second] = node2

    # The image span (offset=1, length=3) straddles the block boundary at
    # token 2, so any full match is necessarily chunked.
    request = SimpleNamespace(
        prompt_token_ids=input_ids,
        output_token_ids=[],
        request_id="chunk-req",
        multimodal_inputs={
            "mm_positions": [SimpleNamespace(offset=1, length=3)],
            "mm_hashes": ["img"],
        },
        num_total_tokens=4,
    )

    match_gpu, *_ = manager.mm_match_block(request, block_size)
    self.assertEqual(match_gpu, [])

def test_mm_build_path_creates_new_nodes(self):
manager = _create_manager(num_gpu_blocks=6)
request = SimpleNamespace(
Expand Down Expand Up @@ -1194,34 +1156,6 @@ def test_clear_prefix_cache_resets_on_signal(self):
with self.assertRaises(SystemExit):
manager.clear_prefix_cache()

@unittest.skip("Skip TestRevertMatchBlocks")
def test_revert_match_blocks_adjusts_lists(self):
    # Verifies _revert_match_blocks pops the matched node, shrinks the GPU
    # token count by one block (4 -> 2), and resets the current node to the
    # radix-tree root once no matched nodes remain.
    # NOTE(review): currently skipped — presumably pending the
    # _revert_match_blocks rework; confirm before re-enabling.
    manager = _create_manager()
    request = SimpleNamespace(
        request_id="revert",
        multimodal_inputs={"mm_positions": [SimpleNamespace(offset=2, length=2)]},
    )
    node = BlockNode(120, [1, 2], 0, 1, 0, 2, get_hash_str([1, 2]), 0, parent=manager.radix_tree_root)
    matche_nodes = [node]
    match_gpu = [0]
    match_node_ids = [node.node_id]
    swap_nodes = [node.block_id]
    gpu_tokens, cpu_tokens, current = manager._revert_match_blocks(
        request=request,
        matched_token_num=4,
        block_size=2,
        chunk_idx=0,
        match_node_ids=match_node_ids,
        matche_nodes=matche_nodes,
        match_gpu_block_ids=match_gpu,
        match_cpu_block_ids=[],
        gpu_match_token_num=4,
        cpu_match_token_num=0,
        swap_node_ids=swap_nodes,
    )
    self.assertEqual(gpu_tokens, 2)
    self.assertEqual(current, manager.radix_tree_root)


# Coverage-oriented tests. These are used to lightly exercise specific
# implementation details without constraining core behavior.
Expand Down
Loading
Loading