# Patch provenance: 80d101fe6a426a2fcfc9cc0244d30e4c00da48c0
# Author: hiworldwzj <30762946+hiworldwzj@users.noreply.github.com>
# Date:   Tue, 11 Mar 2025 10:38:29 +0800
# Subject: [PATCH] add kv trans v2 kernel for dp mode pd (#763)
#
# Adds two new files:
#   lightllm/common/kv_trans_kernel/kv_trans_v2.py          (94 lines)
#   unit_tests/common/kv_trans_kernel/test_kv_trans_v2.py   (41 lines)

# --- file: lightllm/common/kv_trans_kernel/kv_trans_v2.py ---
import torch

import triton
import triton.language as tl


@triton.jit
def _kv_trans_kernel(
    input_mems_ptr,
    input_stride_0,
    input_stride_1,
    input_stride_2,
    input_token_idx_ptr,
    input_dp_idx_ptr,
    output_ptr,
    output_stride_0,
    output_stride_1,
    output_stride_2,
    output_token_idx_ptr,
    token_num: int,
    head_num: int,
    head_dim: int,
    grid_count: int,
    BLOCK_SIZE: tl.constexpr,
    NUM_STAGES: tl.constexpr,
):
    """Gather KV rows from per-dp buffers into a contiguous output tensor.

    For each token t in [0, token_num):
      dp      = input_dp_idx_ptr[t]     -> which per-dp KV buffer to read
      src_row = input_token_idx_ptr[t]  -> row in that buffer
      dst_row = output_token_idx_ptr[t] -> row in `output`
    and copies head_num * head_dim contiguous elements from source to dest.
    """
    # Widen strides to int64 so `stride * row` pointer arithmetic cannot
    # overflow 32-bit math for large KV buffers.
    input_stride_0 = tl.cast(input_stride_0, dtype=tl.int64)
    input_stride_1 = tl.cast(input_stride_1, dtype=tl.int64)
    output_stride_0 = tl.cast(output_stride_0, dtype=tl.int64)
    output_stride_1 = tl.cast(output_stride_1, dtype=tl.int64)

    head_num_dim = head_num * head_dim
    tid = tl.program_id(0)

    offs = tl.arange(0, BLOCK_SIZE)
    # Each program handles tokens tid, tid + grid_count, tid + 2*grid_count, ...
    # so a small fixed grid covers any token_num.
    while tid < token_num:
        dp_index = tl.load(input_dp_idx_ptr + tid)
        input_token_idx = tl.load(input_token_idx_ptr + tid)
        output_token_idx = tl.load(output_token_idx_ptr + tid)
        # input_mems_ptr holds raw uint64 base addresses of the per-dp KV
        # buffers; reinterpret the selected address as a pointer of the
        # output's element type.
        input_ptr = tl.load(input_mems_ptr + dp_index).to(tl.pointer_type(output_ptr.dtype.element_ty))
        # Copy one token's head_num * head_dim elements, BLOCK_SIZE at a time.
        for block_idx in tl.range(0, tl.cdiv(head_num_dim, BLOCK_SIZE), 1, num_stages=NUM_STAGES):
            cur_offs = block_idx * BLOCK_SIZE + offs
            mask = cur_offs < head_num_dim  # hoisted: used by both load and store
            in_datas = tl.load(input_ptr + input_stride_0 * input_token_idx + cur_offs, mask=mask)
            tl.store(output_ptr + output_stride_0 * output_token_idx + cur_offs, in_datas, mask=mask)

        tid += grid_count

    return


def kv_trans_v2(
    input_mems: torch.Tensor,
    input_idx: torch.Tensor,
    input_dp_idx: torch.Tensor,
    output: torch.Tensor,
    output_idx: torch.Tensor,
    dp_size_in_node: int,
):
    """Copy selected KV-cache rows from per-dp buffers into `output`.

    Args:
        input_mems: 1-D torch.uint64 tensor of length `dp_size_in_node`; each
            element is the base pointer (data_ptr) of the kv cache of the
            corresponding mem_manager object currently in use.
        input_idx: per-token source row index inside the selected dp buffer.
        input_dp_idx: per-token dp-buffer index (which pointer in input_mems).
        output: contiguous (num_tokens, head_num, head_dim) destination.
        output_idx: per-token destination row index into `output`.
        dp_size_in_node: number of dp buffers; must equal len(input_mems).
    """
    assert input_mems.is_contiguous()
    assert output.is_contiguous()
    assert len(input_mems.shape) == 1
    assert len(input_mems) == dp_size_in_node
    assert len(output.shape) == 3
    assert len(input_idx) == len(output_idx)
    assert len(input_idx) == len(input_dp_idx)

    _, head_num, head_dim = output.shape
    token_num = len(input_idx)
    # Use a small, fixed amount of resources for the transfer so it does not
    # occupy too many SM compute units.
    grid_count = 20
    BLOCK_SIZE = 256
    NUM_STAGES = 3
    grid = (grid_count,)

    _kv_trans_kernel[grid](
        input_mems,
        # NOTE(review): `output.stride()` is passed for the INPUT strides as
        # well — this assumes every dp kv buffer shares output's per-row
        # (head_num, head_dim) contiguous layout; confirm against mem_manager.
        *output.stride(),
        input_idx,
        input_dp_idx,
        output,
        *output.stride(),
        output_idx,
        token_num=token_num,
        head_num=head_num,
        head_dim=head_dim,
        grid_count=grid_count,
        BLOCK_SIZE=BLOCK_SIZE,
        NUM_STAGES=NUM_STAGES,
        num_warps=1,
    )
    return


# --- file: unit_tests/common/kv_trans_kernel/test_kv_trans_v2.py ---
import pytest
import torch
import random
from lightllm.common.kv_trans_kernel.kv_trans_v2 import kv_trans_v2


@pytest.mark.parametrize(
    "token_num",
    [token_num for token_num in range(5, 10)],
)
def test_kv_trans_v2(token_num):
    """Check the kernel against a straightforward per-token Python gather."""
    dp_size_in_node = 8
    head_num = 2
    head_dim = 512
    kv_buffer_token_num = 512
    mems = []
    for _ in range(dp_size_in_node):
        mems.append(torch.randn((kv_buffer_token_num, head_num, head_dim), dtype=torch.float16, device="cuda"))
    # Pass the buffers to the kernel as raw base addresses.
    input_mems = torch.tensor([e.data_ptr() for e in mems], dtype=torch.uint64, device="cuda")
    input_idx = [random.randint(0, kv_buffer_token_num - 1) for _ in range(token_num)]
    input_idx = torch.tensor(input_idx, dtype=torch.int32, device="cuda")
    input_dp_idx = [random.randint(0, dp_size_in_node - 1) for _ in range(token_num)]
    input_dp_idx = torch.tensor(input_dp_idx, dtype=torch.int32, device="cuda")

    true_output = torch.zeros((token_num, head_num, head_dim), dtype=torch.float16, device="cuda")
    test_output = torch.zeros((token_num, head_num, head_dim), dtype=torch.float16, device="cuda")
    output_idx = torch.arange(0, token_num, 1, dtype=torch.int32, device="cuda")

    kv_trans_v2(input_mems, input_idx, input_dp_idx, test_output, output_idx, dp_size_in_node)

    # Reference: copy the same rows with plain tensor indexing.
    for dest_token_index, token_index, dp_index in zip(
        list(range(token_num)), input_idx.cpu().numpy(), input_dp_idx.cpu().numpy()
    ):
        true_output[dest_token_index, :, :] = mems[dp_index][token_index]

    assert torch.equal(true_output, test_output)
    return


if __name__ == "__main__":
    pytest.main()