deepseek2 support pd multi dp kv trans.
hiworldwzj committed Mar 11, 2025
1 parent 5f26114 commit 63b4bdd
Showing 3 changed files with 201 additions and 16 deletions.
70 changes: 60 additions & 10 deletions lightllm/common/deepseek2_mem_manager.py
@@ -6,6 +6,7 @@
from typing import List, Union
from lightllm.utils.log_utils import init_logger
from lightllm.common.kv_trans_kernel.kv_trans import kv_trans
+from lightllm.common.kv_trans_kernel.kv_trans_v2 import kv_trans_v2_for_d_node, kv_trans_v2_for_p_node

logger = init_logger(__name__)

@@ -103,50 +104,99 @@ def send_to_decode_node_p2p(
"""
使用 p2p triton kernel 进行数据复制和传输的实现方式。
"""
-        assert dp_size_in_node == 1
+        if not hasattr(self, "mem_ptrs_dict"):
+            self.mem_ptrs_dict = {}
+            for layer_index in range(self.layer_num):
+                mems_ptr = []
+                for i in range(0, len(mem_managers), len(mem_managers) // dp_size_in_node):
+                    mems_ptr.append(mem_managers[i].kv_buffer[layer_index, :, :, :].data_ptr())
+                mems_ptr = torch.tensor(mems_ptr, dtype=torch.uint64, device="cuda")
+                self.mem_ptrs_dict[layer_index] = mems_ptr

        move_token_indexes = []
+        token_dp_indexes = []
        for task in move_tasks:
            if task.move_kv_len != 0:
                move_token_indexes.extend(task.prefill_token_indexes[-task.move_kv_len :])
+                token_dp_indexes.extend([task.prefill_dp_index for _ in range(task.move_kv_len)])

        move_token_indexes = torch.tensor(move_token_indexes, dtype=torch.int64, device="cuda")
+        token_dp_indexes = torch.tensor(token_dp_indexes, dtype=torch.int32, device="cuda")
        for layer_index in range(self.layer_num):
-            move_buffer = self._get_kv_move_data_p2p(move_token_indexes, layer_index, self.kv_move_buffer)
+            move_buffer = self._get_kv_move_data_p2p(
+                move_token_indexes, token_dp_indexes, layer_index, self.kv_move_buffer, dp_size_in_node
+            )
            dist.send(move_buffer, dst=1)
        return
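
For context: the prefill node keeps one kv-cache base pointer per dp group, because the cards inside a group hold identical copies of the cache (the decode-side kernel below writes every card of a group for the same reason). A minimal sketch of that pointer-table layout, not part of this commit, with hypothetical sizes (8 cards, 2 dp groups):

import torch

num_cards, dp_size_in_node = 8, 2
cards_per_dp = num_cards // dp_size_in_node  # 4 cards per dp group
kv_buffers = [torch.zeros((16, 1, 8), dtype=torch.float16, device="cuda") for _ in range(num_cards)]
# sample card 0 of each group; its data_ptr() stands in for the whole group
mems_ptr = torch.tensor(
    [kv_buffers[i].data_ptr() for i in range(0, num_cards, cards_per_dp)],
    dtype=torch.uint64,
    device="cuda",
)
assert len(mems_ptr) == dp_size_in_node  # one base pointer per dp group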

-    def _get_kv_move_data_p2p(self, token_indexes: torch.Tensor, layer_index: int, kv_move_buffer: torch.Tensor):
+    def _get_kv_move_data_p2p(
+        self,
+        token_indexes: torch.Tensor,
+        token_dp_indexes: torch.Tensor,
+        layer_index: int,
+        kv_move_buffer: torch.Tensor,
+        dp_size_in_node: int,
+    ):
        move_token_num = len(token_indexes)
        move_size = self.kv_buffer.numel() // self.layer_num // self.size * move_token_num
        move_buffer = kv_move_buffer.view(-1)[0:move_size].view(move_token_num, self.head_num, self.head_dim)
-        kv_trans(
-            self.kv_buffer[layer_index, :, :, :], token_indexes, move_buffer, self.kv_move_buf_indexes[0:move_token_num]
+        kv_trans_v2_for_p_node(
+            input_mems=self.mem_ptrs_dict[layer_index],
+            input_idx=token_indexes,
+            input_dp_idx=token_dp_indexes,
+            output=move_buffer,
+            output_idx=self.kv_move_buf_indexes[0:move_token_num],
+            dp_size_in_node=dp_size_in_node,
        )
        return move_buffer
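
The move_size arithmetic above works because kv_buffer has shape (layer_num, size, head_num, head_dim): numel() // layer_num // size is the per-token element count head_num * head_dim. A worked example with hypothetical sizes, not part of this commit:

layer_num, size, head_num, head_dim = 2, 1024, 1, 512
numel = layer_num * size * head_num * head_dim  # kv_buffer.numel()
per_token_elems = numel // layer_num // size    # == head_num * head_dim == 512
move_token_num = 7
move_size = per_token_elems * move_token_num    # 3584 elements staged for one layer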

    def receive_from_prefill_node_p2p(
        self, move_tasks: List[KVMoveTask], mem_managers: List["MemoryManager"], dp_size_in_node: int
    ):
-        assert dp_size_in_node == 1
+        if not hasattr(self, "mem_ptrs_dict"):
+            self.mem_ptrs_dict = {}
+            for layer_index in range(self.layer_num):
+                mems_ptr = []
+                for i in range(0, len(mem_managers)):
+                    mems_ptr.append(mem_managers[i].kv_buffer[layer_index, :, :, :].data_ptr())
+                mems_ptr = torch.tensor(mems_ptr, dtype=torch.uint64, device="cuda")
+                self.mem_ptrs_dict[layer_index] = mems_ptr

        move_token_indexes = []
+        token_dp_indexes = []
        for task in move_tasks:
            if task.move_kv_len != 0:
                move_token_indexes.extend(task.decode_token_indexes[-task.move_kv_len :])
+                token_dp_indexes.extend([task.decode_dp_index for _ in range(task.move_kv_len)])

        move_token_indexes = torch.tensor(move_token_indexes, dtype=torch.int64, device="cuda")
+        token_dp_indexes = torch.tensor(token_dp_indexes, dtype=torch.int32, device="cuda")

        token_num = len(move_token_indexes)
        move_size = self.kv_buffer.numel() // self.layer_num // self.size * token_num
        receive_buffer = self.kv_move_buffer.view(-1)[0:move_size].view(token_num, self.head_num, self.head_dim)
        for layer_index in range(self.layer_num):
            dist.recv(receive_buffer, src=0)
-            for i, mem in enumerate(mem_managers):
-                mem._write_kv_move_data_p2p(move_token_indexes, receive_buffer, layer_index)
+            self._write_kv_move_data_p2p(
+                move_token_indexes, token_dp_indexes, receive_buffer, layer_index, dp_size_in_node
+            )
        return

-    def _write_kv_move_data_p2p(self, token_indexes: torch.Tensor, buffer_tensor: torch.Tensor, layer_index):
+    def _write_kv_move_data_p2p(
+        self,
+        token_indexes: torch.Tensor,
+        token_dp_indexes: torch.Tensor,
+        buffer_tensor: torch.Tensor,
+        layer_index,
+        dp_size_in_node: int,
+    ):
        move_token_num = len(token_indexes)
-        kv_trans(buffer_tensor, self.kv_move_buf_indexes[0:move_token_num], self.kv_buffer[layer_index], token_indexes)
+        kv_trans_v2_for_d_node(
+            output_mems=self.mem_ptrs_dict[layer_index],
+            output_idx=token_indexes,
+            output_dp_idx=token_dp_indexes,
+            input=buffer_tensor,
+            input_idx=self.kv_move_buf_indexes[0:move_token_num],
+            dp_size_in_node=dp_size_in_node,
+        )
        return
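
The transfer itself runs in per-layer lockstep: the prefill rank stages one layer into the move buffer and sends it, the decode rank receives into an equally sized buffer and scatters it before the next layer arrives. A minimal sketch of that pairing, not part of this commit; it assumes a two-rank process group where rank 0 is the prefill side and rank 1 the decode side, matching the dst=1/src=0 constants above, and stage_layer/scatter_layer are hypothetical stand-ins for the two triton kernels:

import torch.distributed as dist

def transfer_layers(layer_num, stage_buffer, is_prefill_rank, stage_layer, scatter_layer):
    for layer_index in range(layer_num):
        if is_prefill_rank:
            stage_layer(stage_buffer, layer_index)    # gather, cf. kv_trans_v2_for_p_node
            dist.send(stage_buffer, dst=1)            # blocking send keeps layers ordered
        else:
            dist.recv(stage_buffer, src=0)            # matching blocking recv
            scatter_layer(stage_buffer, layer_index)  # fan out, cf. kv_trans_v2_for_d_node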
105 changes: 102 additions & 3 deletions lightllm/common/kv_trans_kernel/kv_trans_v2.py
@@ -5,7 +5,7 @@


@triton.jit
-def _kv_trans_kernel(
+def _kv_trans_prefill_node_kernel(
    input_mems_ptr,
    input_stride_0,
    input_stride_1,
@@ -48,7 +48,7 @@ def _kv_trans_kernel(
    return


-def kv_trans_v2(
+def kv_trans_v2_for_p_node(
    input_mems: torch.Tensor,
    input_idx: torch.Tensor,
    input_dp_idx: torch.Tensor,
@@ -75,7 +75,7 @@ def kv_trans_v2(
    NUM_STAGES = 3
    grid = (grid_count,)

-    _kv_trans_kernel[grid](
+    _kv_trans_prefill_node_kernel[grid](
        input_mems,
        *output.stride(),
        input_idx,
@@ -92,3 +92,102 @@
        num_warps=1,
    )
    return


@triton.jit
def _kv_trans_decode_node_kernel(
    output_mems_ptr,
    output_stride_0,
    output_stride_1,
    output_stride_2,
    output_token_idx_ptr,
    output_token_dp_index_ptr,
    input_ptr,
    input_stride_0,
    input_stride_1,
    input_stride_2,
    input_token_idx_ptr,
    token_num: int,
    head_num: int,
    head_dim: int,
    grid_count: int,
    BLOCK_SIZE: tl.constexpr,
    NUM_STAGES: tl.constexpr,
    CARD_NUM_PER_D: tl.constexpr,
):
    input_stride_0 = tl.cast(input_stride_0, dtype=tl.int64)
    input_stride_1 = tl.cast(input_stride_1, dtype=tl.int64)
    output_stride_0 = tl.cast(output_stride_0, dtype=tl.int64)
    output_stride_1 = tl.cast(output_stride_1, dtype=tl.int64)

    head_num_dim = head_num * head_dim
    tid = tl.program_id(0)

    offs = tl.arange(0, BLOCK_SIZE)
    while tid < token_num:
        dp_index = tl.load(output_token_dp_index_ptr + tid)
        input_token_idx = tl.load(input_token_idx_ptr + tid)
        output_token_idx = tl.load(output_token_idx_ptr + tid)
        for block_idx in tl.range(0, tl.cdiv(head_num_dim, BLOCK_SIZE), 1, num_stages=NUM_STAGES):
            cur_offs = block_idx * BLOCK_SIZE + offs
            in_datas = tl.load(input_ptr + input_stride_0 * input_token_idx + cur_offs, mask=cur_offs < head_num_dim)
            for mem_index in tl.range(
                dp_index * CARD_NUM_PER_D, (dp_index + 1) * CARD_NUM_PER_D, num_stages=NUM_STAGES
            ):
                output_ptr = tl.load(output_mems_ptr + mem_index).to(tl.pointer_type(input_ptr.dtype.element_ty))
                tl.store(
                    output_ptr + output_stride_0 * output_token_idx + cur_offs, in_datas, mask=cur_offs < head_num_dim
                )

        tid += grid_count

    return
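
The tl.pointer_type cast above is what lets a single kernel address many mem managers: the host passes raw data_ptr() values as a uint64 tensor and the kernel rehydrates each entry into a typed pointer. A standalone sketch of just that trick, not part of this commit (copy_head_elems_kernel and its buffers are hypothetical):

import torch
import triton
import triton.language as tl

@triton.jit
def copy_head_elems_kernel(mems_ptr, out_ptr, N: tl.constexpr):
    for i in range(N):
        # rehydrate the i-th uint64 entry into a typed pointer, then read its element 0
        src_ptr = tl.load(mems_ptr + i).to(tl.pointer_type(tl.float16))
        tl.store(out_ptr + i, tl.load(src_ptr))

bufs = [torch.full((4,), float(i), dtype=torch.float16, device="cuda") for i in range(3)]
mems = torch.tensor([b.data_ptr() for b in bufs], dtype=torch.uint64, device="cuda")
out = torch.empty(3, dtype=torch.float16, device="cuda")
copy_head_elems_kernel[(1,)](mems, out, N=3)
assert out.tolist() == [0.0, 1.0, 2.0]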


def kv_trans_v2_for_d_node(
    output_mems: torch.Tensor,
    output_idx: torch.Tensor,
    output_dp_idx: torch.Tensor,
    input: torch.Tensor,
    input_idx: torch.Tensor,
    dp_size_in_node: int,
):
    """
    output_mems is a torch.uint64 tensor that stores the base pointers of the kv
    cache held by each mem_manager object currently in use.
    """
    assert output_mems.is_contiguous()
    assert input.is_contiguous()
    assert len(output_mems.shape) == 1
    assert len(input.shape) == 3
    assert len(input_idx) == len(output_idx)
    assert len(input_idx) == len(output_dp_idx)
    assert len(output_mems) % dp_size_in_node == 0

    card_num_per_d = len(output_mems) // dp_size_in_node

    _, head_num, head_dim = input.shape
    token_num = len(input_idx)
    # use modest resources for the transfer so it does not occupy too many SM compute units
    grid_count = 20
    BLOCK_SIZE = 256
    NUM_STAGES = 3
    grid = (grid_count,)

    _kv_trans_decode_node_kernel[grid](
        output_mems,
        *input.stride(),  # the per-card kv buffers share the input's (head_num, head_dim) layout, hence its strides
        output_idx,
        output_dp_idx,
        input,
        *input.stride(),
        input_idx,
        token_num=token_num,
        head_num=head_num,
        head_dim=head_dim,
        grid_count=grid_count,
        BLOCK_SIZE=BLOCK_SIZE,
        NUM_STAGES=NUM_STAGES,
        CARD_NUM_PER_D=card_num_per_d,
        num_warps=1,
    )
    return
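
Taken together, the two kernels implement a gather on the prefill node and a replicated scatter on the decode node. A pure-PyTorch reference of those semantics, not part of this commit and intended only as a mental model or test oracle; mems stands in for the tensors behind the pointer tables, and the p-node indexing is inferred from the wrapper and test since that kernel body is collapsed above:

def p_node_reference(mems, input_idx, input_dp_idx, output, output_idx):
    # gather: one source kv cache per dp group on the prefill node
    for out_i, in_i, dp_i in zip(output_idx.tolist(), input_idx.tolist(), input_dp_idx.tolist()):
        output[out_i] = mems[dp_i][in_i]

def d_node_reference(mems, output_idx, output_dp_idx, input, input_idx, dp_size_in_node):
    # scatter: every card in the token's dp group receives the same row
    cards_per_dp = len(mems) // dp_size_in_node
    for in_i, out_i, dp_i in zip(input_idx.tolist(), output_idx.tolist(), output_dp_idx.tolist()):
        for mem_index in range(dp_i * cards_per_dp, (dp_i + 1) * cards_per_dp):
            mems[mem_index][out_i] = input[in_i]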
42 changes: 39 additions & 3 deletions unit_tests/common/kv_trans_kernel/test_kv_trans_v2.py
@@ -1,14 +1,14 @@
import pytest
import torch
import random
-from lightllm.common.kv_trans_kernel.kv_trans_v2 import kv_trans_v2
+from lightllm.common.kv_trans_kernel.kv_trans_v2 import kv_trans_v2_for_p_node, kv_trans_v2_for_d_node


@pytest.mark.parametrize(
    "token_num",
    [token_num for token_num in range(5, 10)],
)
-def test_kv_trans_v2(token_num):
+def test_kv_trans_v2_for_p_node(token_num):
    dp_size_in_node = 8
    head_num = 2
    head_dim = 512
@@ -26,7 +26,7 @@ def test_kv_trans_v2(token_num):
    test_output = torch.zeros((token_num, head_num, head_dim), dtype=torch.float16, device="cuda")
    output_idx = torch.arange(0, token_num, 1, dtype=torch.int32, device="cuda")

-    kv_trans_v2(input_mems, input_idx, input_dp_idx, test_output, output_idx, dp_size_in_node)
+    kv_trans_v2_for_p_node(input_mems, input_idx, input_dp_idx, test_output, output_idx, dp_size_in_node)

    for dest_token_index, token_index, dp_index in zip(
        list(range(token_num)), input_idx.cpu().numpy(), input_dp_idx.cpu().numpy()
@@ -37,5 +37,41 @@
    return


@pytest.mark.parametrize(
    "token_num",
    [token_num for token_num in range(5, 10)],
)
def test_kv_trans_v2_for_d_node(token_num):
    card_num = 8
    dp_size_in_node = 4
    head_num = 2
    head_dim = 512
    kv_buffer_token_num = 512
    mems = []
    for _ in range(card_num):
        mems.append(torch.randn((kv_buffer_token_num, head_num, head_dim), dtype=torch.float16, device="cuda"))
    output_mems = torch.tensor([e.data_ptr() for e in mems], dtype=torch.uint64, device="cuda")
    output_idx = [random.randint(0, kv_buffer_token_num - 1) for _ in range(token_num)]
    output_idx = torch.tensor(output_idx, dtype=torch.int32, device="cuda")
    output_dp_idx = [random.randint(0, dp_size_in_node - 1) for _ in range(token_num)]
    output_dp_idx = torch.tensor(output_dp_idx, dtype=torch.int32, device="cuda")

    test_input = torch.randn((token_num, head_num, head_dim), dtype=torch.float16, device="cuda")
    input_idx = torch.arange(0, token_num, 1, dtype=torch.int32, device="cuda")

    kv_trans_v2_for_d_node(output_mems, output_idx, output_dp_idx, test_input, input_idx, dp_size_in_node)

    for _, token_index, dest_token_index, dp_index in zip(
        list(range(token_num)),
        input_idx.cpu().numpy(),
        output_idx.cpu().numpy(),
        output_dp_idx.cpu().numpy(),
    ):
        for mem_index in range(dp_index * card_num // dp_size_in_node, (dp_index + 1) * card_num // dp_size_in_node):
            assert torch.equal(mems[mem_index][dest_token_index, :, :], test_input[token_index, :, :])

    return


if __name__ == "__main__":
    pytest.main()
