From 63b4bdd07a127f4b1632e7ff8df4a75ec91ee0db Mon Sep 17 00:00:00 2001 From: wangzaijun Date: Tue, 11 Mar 2025 17:33:14 +0800 Subject: [PATCH] deepseek2 support pd multi dp kv trans. --- lightllm/common/deepseek2_mem_manager.py | 70 ++++++++++-- .../common/kv_trans_kernel/kv_trans_v2.py | 105 +++++++++++++++++- .../kv_trans_kernel/test_kv_trans_v2.py | 42 ++++++- 3 files changed, 201 insertions(+), 16 deletions(-) diff --git a/lightllm/common/deepseek2_mem_manager.py b/lightllm/common/deepseek2_mem_manager.py index 7fd8dee26..404413f84 100644 --- a/lightllm/common/deepseek2_mem_manager.py +++ b/lightllm/common/deepseek2_mem_manager.py @@ -6,6 +6,7 @@ from typing import List, Union from lightllm.utils.log_utils import init_logger from lightllm.common.kv_trans_kernel.kv_trans import kv_trans +from lightllm.common.kv_trans_kernel.kv_trans_v2 import kv_trans_v2_for_d_node, kv_trans_v2_for_p_node logger = init_logger(__name__) @@ -103,50 +104,99 @@ def send_to_decode_node_p2p( """ 使用 p2p triton kernel 进行数据复制和传输的实现方式。 """ - assert dp_size_in_node == 1 + if not hasattr(self, "mem_ptrs_dict"): + self.mem_ptrs_dict = {} + for layer_index in range(self.layer_num): + mems_ptr = [] + for i in range(0, len(mem_managers), len(mem_managers) // dp_size_in_node): + mems_ptr.append(mem_managers[i].kv_buffer[layer_index, :, :, :].data_ptr()) + mems_ptr = torch.tensor(mems_ptr, dtype=torch.uint64, device="cuda") + self.mem_ptrs_dict[layer_index] = mems_ptr move_token_indexes = [] + token_dp_indexes = [] for task in move_tasks: if task.move_kv_len != 0: move_token_indexes.extend(task.prefill_token_indexes[-task.move_kv_len :]) + token_dp_indexes.extend([task.prefill_dp_index for _ in range(task.move_kv_len)]) move_token_indexes = torch.tensor(move_token_indexes, dtype=torch.int64, device="cuda") + token_dp_indexes = torch.tensor(token_dp_indexes, dtype=torch.int32, device="cuda") for layer_index in range(self.layer_num): - move_buffer = self._get_kv_move_data_p2p(move_token_indexes, 
layer_index, self.kv_move_buffer) + move_buffer = self._get_kv_move_data_p2p( + move_token_indexes, token_dp_indexes, layer_index, self.kv_move_buffer, dp_size_in_node + ) dist.send(move_buffer, dst=1) return - def _get_kv_move_data_p2p(self, token_indexes: torch.Tensor, layer_index: int, kv_move_buffer: torch.Tensor): + def _get_kv_move_data_p2p( + self, + token_indexes: torch.Tensor, + token_dp_indexes: torch.Tensor, + layer_index: int, + kv_move_buffer: torch.Tensor, + dp_size_in_node: int, + ): move_token_num = len(token_indexes) move_size = self.kv_buffer.numel() // self.layer_num // self.size * move_token_num move_buffer = kv_move_buffer.view(-1)[0:move_size].view(move_token_num, self.head_num, self.head_dim) - kv_trans( - self.kv_buffer[layer_index, :, :, :], token_indexes, move_buffer, self.kv_move_buf_indexes[0:move_token_num] + kv_trans_v2_for_p_node( + input_mems=self.mem_ptrs_dict[layer_index], + input_idx=token_indexes, + input_dp_idx=token_dp_indexes, + output=move_buffer, + output_idx=self.kv_move_buf_indexes[0:move_token_num], + dp_size_in_node=dp_size_in_node, ) return move_buffer def receive_from_prefill_node_p2p( self, move_tasks: List[KVMoveTask], mem_managers: List["MemoryManager"], dp_size_in_node: int ): - assert dp_size_in_node == 1 + if not hasattr(self, "mem_ptrs_dict"): + self.mem_ptrs_dict = {} + for layer_index in range(self.layer_num): + mems_ptr = [] + for i in range(0, len(mem_managers)): + mems_ptr.append(mem_managers[i].kv_buffer[layer_index, :, :, :].data_ptr()) + mems_ptr = torch.tensor(mems_ptr, dtype=torch.uint64, device="cuda") + self.mem_ptrs_dict[layer_index] = mems_ptr move_token_indexes = [] + token_dp_indexes = [] for task in move_tasks: if task.move_kv_len != 0: move_token_indexes.extend(task.decode_token_indexes[-task.move_kv_len :]) + token_dp_indexes.extend([task.decode_dp_index for _ in range(task.move_kv_len)]) move_token_indexes = torch.tensor(move_token_indexes, dtype=torch.int64, device="cuda") + token_dp_indexes 
= torch.tensor(token_dp_indexes, dtype=torch.int32, device="cuda") token_num = len(move_token_indexes) move_size = self.kv_buffer.numel() // self.layer_num // self.size * token_num recive_buffer = self.kv_move_buffer.view(-1)[0:move_size].view(token_num, self.head_num, self.head_dim) for layer_index in range(self.layer_num): dist.recv(recive_buffer, src=0) - for i, mem in enumerate(mem_managers): - mem._write_kv_move_data_p2p(move_token_indexes, recive_buffer, layer_index) + self._write_kv_move_data_p2p( + move_token_indexes, token_dp_indexes, recive_buffer, layer_index, dp_size_in_node + ) return - def _write_kv_move_data_p2p(self, token_indexes: torch.Tensor, buffer_tensor: torch.Tensor, layer_index): + def _write_kv_move_data_p2p( + self, + token_indexes: torch.Tensor, + token_dp_indexes: torch.Tensor, + buffer_tensor: torch.Tensor, + layer_index, + dp_size_in_node: int, + ): move_token_num = len(token_indexes) - kv_trans(buffer_tensor, self.kv_move_buf_indexes[0:move_token_num], self.kv_buffer[layer_index], token_indexes) + kv_trans_v2_for_d_node( + output_mems=self.mem_ptrs_dict[layer_index], + output_idx=token_indexes, + output_dp_idx=token_dp_indexes, + input=buffer_tensor, + input_idx=self.kv_move_buf_indexes[0:move_token_num], + dp_size_in_node=dp_size_in_node, + ) return diff --git a/lightllm/common/kv_trans_kernel/kv_trans_v2.py b/lightllm/common/kv_trans_kernel/kv_trans_v2.py index 4465ad65b..772de5f6c 100644 --- a/lightllm/common/kv_trans_kernel/kv_trans_v2.py +++ b/lightllm/common/kv_trans_kernel/kv_trans_v2.py @@ -5,7 +5,7 @@ @triton.jit -def _kv_trans_kernel( +def _kv_trans_prefill_node_kernel( input_mems_ptr, input_stride_0, input_stride_1, @@ -48,7 +48,7 @@ def _kv_trans_kernel( return -def kv_trans_v2( +def kv_trans_v2_for_p_node( input_mems: torch.Tensor, input_idx: torch.Tensor, input_dp_idx: torch.Tensor, @@ -75,7 +75,7 @@ def kv_trans_v2( NUM_STAGES = 3 grid = (grid_count,) - _kv_trans_kernel[grid]( + _kv_trans_prefill_node_kernel[grid]( 
input_mems, *output.stride(), input_idx, @@ -92,3 +92,102 @@ def kv_trans_v2( num_warps=1, ) return + + +@triton.jit +def _kv_trans_decode_node_kernel( + output_mems_ptr, + output_stride_0, + output_stride_1, + output_stride_2, + output_token_idx_ptr, + output_token_dp_index_ptr, + input_ptr, + input_stride_0, + input_stride_1, + input_stride_2, + input_token_idx_ptr, + token_num: int, + head_num: int, + head_dim: int, + grid_count: int, + BLOCK_SIZE: tl.constexpr, + NUM_STAGES: tl.constexpr, + CARD_NUM_PER_D: tl.constexpr, +): + input_stride_0 = tl.cast(input_stride_0, dtype=tl.int64) + input_stride_1 = tl.cast(input_stride_1, dtype=tl.int64) + output_stride_0 = tl.cast(output_stride_0, dtype=tl.int64) + output_stride_1 = tl.cast(output_stride_1, dtype=tl.int64) + + head_num_dim = head_num * head_dim + tid = tl.program_id(0) + + offs = tl.arange(0, BLOCK_SIZE) + while tid < token_num: + dp_index = tl.load(output_token_dp_index_ptr + tid) + input_token_idx = tl.load(input_token_idx_ptr + tid) + output_token_idx = tl.load(output_token_idx_ptr + tid) + for block_idx in tl.range(0, tl.cdiv(head_num_dim, BLOCK_SIZE), 1, num_stages=NUM_STAGES): + cur_offs = block_idx * BLOCK_SIZE + offs + in_datas = tl.load(input_ptr + input_stride_0 * input_token_idx + cur_offs, mask=cur_offs < head_num_dim) + for mem_index in tl.range( + dp_index * CARD_NUM_PER_D, (dp_index + 1) * CARD_NUM_PER_D, num_stages=NUM_STAGES + ): + output_ptr = tl.load(output_mems_ptr + mem_index).to(tl.pointer_type(input_ptr.dtype.element_ty)) + tl.store( + output_ptr + output_stride_0 * output_token_idx + cur_offs, in_datas, mask=cur_offs < head_num_dim + ) + + tid += grid_count + + return + + +def kv_trans_v2_for_d_node( + output_mems: torch.Tensor, + output_idx: torch.Tensor, + output_dp_idx: torch.Tensor, + input: torch.Tensor, + input_idx: torch.Tensor, + dp_size_in_node: int, +): + """ + output_mems 是一个 torch.uint64 的tensor, 其内部存储了当前使用的对应的mem_manager对象中kv cache的首指针。 + """ + assert 
output_mems.is_contiguous() + assert input.is_contiguous() + assert len(output_mems.shape) == 1 + assert len(input.shape) == 3 + assert len(input_idx) == len(output_idx) + assert len(input_idx) == len(output_dp_idx) + assert len(output_mems) % dp_size_in_node == 0 + + card_num_per_d = len(output_mems) // dp_size_in_node + + _, head_num, head_dim = input.shape + token_num = len(input_idx) + # 用较少的资源来做数据传输,防止占用过多的 sm 计算单元 + grid_count = 20 + BLOCK_SIZE = 256 + NUM_STAGES = 3 + grid = (grid_count,) + + _kv_trans_decode_node_kernel[grid]( + output_mems, + *input.stride(), + output_idx, + output_dp_idx, + input, + *input.stride(), + input_idx, + token_num=token_num, + head_num=head_num, + head_dim=head_dim, + grid_count=grid_count, + BLOCK_SIZE=BLOCK_SIZE, + NUM_STAGES=NUM_STAGES, + CARD_NUM_PER_D=card_num_per_d, + num_warps=1, + ) + return diff --git a/unit_tests/common/kv_trans_kernel/test_kv_trans_v2.py b/unit_tests/common/kv_trans_kernel/test_kv_trans_v2.py index 9d0afc949..509415da0 100644 --- a/unit_tests/common/kv_trans_kernel/test_kv_trans_v2.py +++ b/unit_tests/common/kv_trans_kernel/test_kv_trans_v2.py @@ -1,14 +1,14 @@ import pytest import torch import random -from lightllm.common.kv_trans_kernel.kv_trans_v2 import kv_trans_v2 +from lightllm.common.kv_trans_kernel.kv_trans_v2 import kv_trans_v2_for_p_node, kv_trans_v2_for_d_node @pytest.mark.parametrize( "token_num", [token_num for token_num in range(5, 10)], ) -def test_kv_trans_v2(token_num): +def test_kv_trans_v2_for_p_node(token_num): dp_size_in_node = 8 head_num = 2 head_dim = 512 @@ -26,7 +26,7 @@ def test_kv_trans_v2(token_num): test_output = torch.zeros((token_num, head_num, head_dim), dtype=torch.float16, device="cuda") output_idx = torch.arange(0, token_num, 1, dtype=torch.int32, device="cuda") - kv_trans_v2(input_mems, input_idx, input_dp_idx, test_output, output_idx, dp_size_in_node) + kv_trans_v2_for_p_node(input_mems, input_idx, input_dp_idx, test_output, output_idx, dp_size_in_node) for 
dest_token_index, token_index, dp_index in zip(
        list(range(token_num)), input_idx.cpu().numpy(), input_dp_idx.cpu().numpy()
 @@ -37,5 +37,41 @@
     return
 
 
+@pytest.mark.parametrize(
+    "token_num",
+    [token_num for token_num in range(5, 10)],
+)
+def test_kv_trans_v2_for_d_node(token_num):
+    card_num = 8
+    dp_size_in_node = 4
+    head_num = 2
+    head_dim = 512
+    kv_buffer_token_num = 512
+    mems = []
+    for _ in range(card_num):
+        mems.append(torch.randn((kv_buffer_token_num, head_num, head_dim), dtype=torch.float16, device="cuda"))
+    output_mems = torch.tensor([e.data_ptr() for e in mems], dtype=torch.uint64, device="cuda")
+    output_idx = [random.randint(0, kv_buffer_token_num - 1) for _ in range(token_num)]
+    output_idx = torch.tensor(output_idx, dtype=torch.int32, device="cuda")
+    output_dp_idx = [random.randint(0, dp_size_in_node - 1) for _ in range(token_num)]
+    output_dp_idx = torch.tensor(output_dp_idx, dtype=torch.int32, device="cuda")
+
+    test_input = torch.randn((token_num, head_num, head_dim), dtype=torch.float16, device="cuda")
+    input_idx = torch.arange(0, token_num, 1, dtype=torch.int32, device="cuda")
+
+    kv_trans_v2_for_d_node(output_mems, output_idx, output_dp_idx, test_input, input_idx, dp_size_in_node)
+
+    for _, token_index, dest_token_index, dp_index in zip(
+        list(range(token_num)),
+        input_idx.cpu().numpy(),
+        output_idx.cpu().numpy(),
+        output_dp_idx.cpu().numpy(),
+    ):
+        for mem_index in range(dp_index * card_num // dp_size_in_node, (dp_index + 1) * card_num // dp_size_in_node):
+            assert torch.equal(mems[mem_index][dest_token_index, :, :], test_input[token_index, :, :])
+
+    return
+
+
 if __name__ == "__main__":
     pytest.main()