Skip to content

Commit

Permalink
add kv trans v2 kernel for dp mode pd (#763)
Browse files Browse the repository at this point in the history
  • Loading branch information
hiworldwzj authored Mar 11, 2025
1 parent f12ba29 commit 80d101f
Show file tree
Hide file tree
Showing 2 changed files with 135 additions and 0 deletions.
94 changes: 94 additions & 0 deletions lightllm/common/kv_trans_kernel/kv_trans_v2.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,94 @@
import torch

import triton
import triton.language as tl


@triton.jit
def _kv_trans_kernel(
    input_mems_ptr,  # uint64 array: base pointer of each dp rank's kv-cache buffer
    input_stride_0,  # row stride of the source buffers (elements per token row)
    input_stride_1,  # NOTE(review): cast below but never used for indexing — presumably
    input_stride_2,  # kept for signature symmetry; the kernel assumes contiguous rows
    input_token_idx_ptr,  # source row index for each copied token
    input_dp_idx_ptr,  # which entry of input_mems_ptr each token is read from
    output_ptr,  # destination tensor (contiguous, token-major)
    output_stride_0,  # row stride of the destination
    output_stride_1,  # NOTE(review): unused in indexing, same as input_stride_1/2
    output_stride_2,
    output_token_idx_ptr,  # destination row index for each copied token
    token_num: int,  # number of token rows to copy
    head_num: int,
    head_dim: int,
    grid_count: int,  # number of launched programs; also the stride of the token loop
    BLOCK_SIZE: tl.constexpr,  # elements copied per inner-loop iteration
    NUM_STAGES: tl.constexpr,  # software-pipelining depth of the inner copy loop
):
    # Copy selected kv-cache rows from per-dp-rank buffers into one output tensor.
    # Widen strides to int64 so pointer arithmetic cannot overflow 32-bit math on
    # large buffers.
    input_stride_0 = tl.cast(input_stride_0, dtype=tl.int64)
    input_stride_1 = tl.cast(input_stride_1, dtype=tl.int64)
    output_stride_0 = tl.cast(output_stride_0, dtype=tl.int64)
    output_stride_1 = tl.cast(output_stride_1, dtype=tl.int64)

    # One token row is head_num * head_dim contiguous elements.
    head_num_dim = head_num * head_dim
    tid = tl.program_id(0)

    offs = tl.arange(0, BLOCK_SIZE)
    # Grid-stride loop over tokens: each program handles tokens
    # tid, tid + grid_count, tid + 2*grid_count, ...
    while tid < token_num:
        dp_index = tl.load(input_dp_idx_ptr + tid)
        input_token_idx = tl.load(input_token_idx_ptr + tid)
        output_token_idx = tl.load(output_token_idx_ptr + tid)
        # Reinterpret the stored uint64 base address of the chosen dp rank's buffer
        # as a pointer with the output's element type.
        input_ptr = tl.load(input_mems_ptr + dp_index).to(tl.pointer_type(output_ptr.dtype.element_ty))
        # Copy the flattened (head_num * head_dim) row in BLOCK_SIZE chunks,
        # pipelined over NUM_STAGES.
        for block_idx in tl.range(0, tl.cdiv(head_num_dim, BLOCK_SIZE), 1, num_stages=NUM_STAGES):
            cur_offs = block_idx * BLOCK_SIZE + offs
            in_datas = tl.load(input_ptr + input_stride_0 * input_token_idx + cur_offs, mask=cur_offs < head_num_dim)
            tl.store(output_ptr + output_stride_0 * output_token_idx + cur_offs, in_datas, mask=cur_offs < head_num_dim)

        tid += grid_count

    return


def kv_trans_v2(
    input_mems: torch.Tensor,
    input_idx: torch.Tensor,
    input_dp_idx: torch.Tensor,
    output: torch.Tensor,
    output_idx: torch.Tensor,
    dp_size_in_node: int,
    grid_count: int = 20,
):
    """
    Gather kv-cache token rows from per-dp-rank buffers into one output tensor.

    Args:
        input_mems: 1-D ``torch.uint64`` tensor of length ``dp_size_in_node``
            holding the raw base data pointers (``tensor.data_ptr()``) of each dp
            rank's kv-cache buffer in the corresponding mem_manager.
        input_idx: source row index inside the kv cache, one per copied token.
        input_dp_idx: for each token, which dp rank (entry of ``input_mems``) to
            read from.
        output: contiguous destination tensor of shape (token_num, head_num, head_dim).
        output_idx: destination row index in ``output`` for each token; must have
            the same length as ``input_idx``.
        dp_size_in_node: number of dp ranks whose buffers appear in ``input_mems``.
        grid_count: number of Triton programs to launch. Kept small on purpose so
            the data transfer does not occupy too many SMs; defaults to 20, the
            previously hard-coded value, so existing callers are unaffected.

    NOTE(review): the kernel receives ``output.stride()`` for BOTH the source and
    destination strides, i.e. the per-rank source buffers are assumed to share the
    output's contiguous (token, head_num, head_dim) row layout — confirm against
    the mem_manager allocation.
    """
    assert input_mems.is_contiguous()
    assert output.is_contiguous()
    assert len(input_mems.shape) == 1
    assert len(input_mems) == dp_size_in_node
    assert len(output.shape) == 3
    assert len(input_idx) == len(output_idx)
    assert len(input_idx) == len(input_dp_idx)

    _, head_num, head_dim = output.shape
    token_num = len(input_idx)
    BLOCK_SIZE = 256
    NUM_STAGES = 3
    grid = (grid_count,)

    _kv_trans_kernel[grid](
        input_mems,
        *output.stride(),
        input_idx,
        input_dp_idx,
        output,
        *output.stride(),
        output_idx,
        token_num=token_num,
        head_num=head_num,
        head_dim=head_dim,
        grid_count=grid_count,
        BLOCK_SIZE=BLOCK_SIZE,
        NUM_STAGES=NUM_STAGES,
        num_warps=1,
    )
    return
41 changes: 41 additions & 0 deletions unit_tests/common/kv_trans_kernel/test_kv_trans_v2.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
import pytest
import torch
import random
from lightllm.common.kv_trans_kernel.kv_trans_v2 import kv_trans_v2


@pytest.mark.parametrize(
    "token_num",
    list(range(5, 10)),
)
def test_kv_trans_v2(token_num):
    """Compare kv_trans_v2 against a straightforward per-token host-side copy."""
    dp_size_in_node = 8
    head_num = 2
    head_dim = 512
    kv_buffer_token_num = 512

    # One random kv-cache buffer per dp rank, plus a uint64 tensor of their raw pointers.
    mems = [
        torch.randn((kv_buffer_token_num, head_num, head_dim), dtype=torch.float16, device="cuda")
        for _ in range(dp_size_in_node)
    ]
    input_mems = torch.tensor([m.data_ptr() for m in mems], dtype=torch.uint64, device="cuda")

    # Random source rows and source ranks for each copied token.
    src_rows = [random.randint(0, kv_buffer_token_num - 1) for _ in range(token_num)]
    input_idx = torch.tensor(src_rows, dtype=torch.int32, device="cuda")
    src_ranks = [random.randint(0, dp_size_in_node - 1) for _ in range(token_num)]
    input_dp_idx = torch.tensor(src_ranks, dtype=torch.int32, device="cuda")

    true_output = torch.zeros((token_num, head_num, head_dim), dtype=torch.float16, device="cuda")
    test_output = torch.zeros_like(true_output)
    output_idx = torch.arange(0, token_num, 1, dtype=torch.int32, device="cuda")

    kv_trans_v2(input_mems, input_idx, input_dp_idx, test_output, output_idx, dp_size_in_node)

    # Reference: copy each selected row directly, token by token.
    for dst_row, (src_row, rank) in enumerate(zip(src_rows, src_ranks)):
        true_output[dst_row, :, :] = mems[rank][src_row]

    assert torch.equal(true_output, test_output)


if __name__ == "__main__":
    # Allow running this file directly: delegate collection/execution to pytest.
    pytest.main()

0 comments on commit 80d101f

Please sign in to comment.