Reduce KV transfer processes to the number of TP ranks for PD #758

Open · wants to merge 4 commits into base: main
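For context, the substance of the change is that the KV-move methods on the memory managers now take an explicit PyNcclCommunicator and perform their sends/receives on it, instead of calling torch.distributed's default process group directly. A minimal caller-side sketch of the new signature follows; the function name and surrounding setup are assumptions for illustration, not code from this PR.

```python
# Hypothetical caller-side sketch: only the method signature and the
# nccl_comm.send/recv calls come from this PR; everything else is assumed.
from lightllm.distributed.pynccl import PyNcclCommunicator


def push_kv_to_decode(mem_manager, mem_managers, move_tasks, nccl_comm: PyNcclCommunicator):
    # Prefill side: stream the KV cache of the moved requests to the decode
    # node over a dedicated NCCL communicator instead of the default group.
    mem_manager.send_to_decode_node(
        move_tasks=move_tasks,
        mem_managers=mem_managers,
        dp_size_in_node=1,    # both implementations assert dp_size_in_node == 1
        nccl_comm=nccl_comm,  # new parameter introduced by this PR
    )
```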
33 changes: 25 additions & 8 deletions lightllm/common/deepseek2_mem_manager.py
@@ -6,6 +6,7 @@
from typing import List, Union
from lightllm.utils.log_utils import init_logger
from lightllm.common.kv_trans_kernel.kv_trans import kv_trans
from lightllm.distributed.pynccl import PyNcclCommunicator

logger = init_logger(__name__)

@@ -40,7 +41,11 @@ def alloc_kv_move_buffer(self, max_req_total_len):
return

def send_to_decode_node(
self, move_tasks: List[KVMoveTask], mem_managers: List["Deepseek2MemoryManager"], dp_size_in_node: int
self,
move_tasks: List[KVMoveTask],
mem_managers: List["Deepseek2MemoryManager"],
dp_size_in_node: int,
nccl_comm: PyNcclCommunicator,
):
assert dp_size_in_node == 1

@@ -54,7 +59,7 @@
cur_mem = mem_managers[cur_device_index]
for layer_index in range(cur_mem.layer_num):
move_buffer = cur_mem._get_kv_move_data(move_token_indexes, layer_index)
dist.send(move_buffer, dst=1)
nccl_comm.send(move_buffer, dst=1)
return

def _get_kv_move_data(self, token_indexes: List[int], layer_index: int):
@@ -66,7 +71,11 @@ def _get_kv_move_data(self, token_indexes: List[int], layer_index: int):
return move_buffer

def receive_from_prefill_node(
self, move_tasks: List[KVMoveTask], mem_managers: List["MemoryManager"], dp_size_in_node: int
self,
move_tasks: List[KVMoveTask],
mem_managers: List["MemoryManager"],
dp_size_in_node: int,
nccl_comm: PyNcclCommunicator,
):
assert dp_size_in_node == 1

@@ -81,7 +90,7 @@
move_size = self.kv_buffer.numel() // self.layer_num // self.size * token_num
recive_buffer = self.kv_move_buffer.view(-1)[0:move_size].view(1, token_num, self.head_num, self.head_dim)
for layer_index in range(self.layer_num):
dist.recv(recive_buffer, src=0)
nccl_comm.recv(recive_buffer, src=0)
for i, mem in enumerate(mem_managers):
if i == cur_device_index:
mem._write_kv_move_data(move_token_indexes, recive_buffer, layer_index)
@@ -98,7 +107,11 @@ def _write_kv_move_data(self, token_indexes: torch.Tensor, buffer_tensor: torch.
return

def send_to_decode_node_p2p(
self, move_tasks: List[KVMoveTask], mem_managers: List["MemoryManager"], dp_size_in_node: int
self,
move_tasks: List[KVMoveTask],
mem_managers: List["MemoryManager"],
dp_size_in_node: int,
nccl_comm: PyNcclCommunicator,
):
"""
Implementation that uses a p2p triton kernel to copy and transfer the data.
@@ -113,7 +126,7 @@ def send_to_decode_node_p2p(
move_token_indexes = torch.tensor(move_token_indexes, dtype=torch.int64, device="cuda")
for layer_index in range(self.layer_num):
move_buffer = self._get_kv_move_data_p2p(move_token_indexes, layer_index, self.kv_move_buffer)
dist.send(move_buffer, dst=1)
nccl_comm.send(move_buffer, dst=1)
return

def _get_kv_move_data_p2p(self, token_indexes: torch.Tensor, layer_index: int, kv_move_buffer: torch.Tensor):
@@ -126,7 +139,11 @@ def _get_kv_move_data_p2p(self, token_indexes: torch.Tensor, layer_index: int, k
return move_buffer

def receive_from_prefill_node_p2p(
self, move_tasks: List[KVMoveTask], mem_managers: List["MemoryManager"], dp_size_in_node: int
self,
move_tasks: List[KVMoveTask],
mem_managers: List["MemoryManager"],
dp_size_in_node: int,
nccl_comm: PyNcclCommunicator,
):
assert dp_size_in_node == 1

@@ -141,7 +158,7 @@ def receive_from_prefill_node_p2p(
move_size = self.kv_buffer.numel() // self.layer_num // self.size * token_num
recive_buffer = self.kv_move_buffer.view(-1)[0:move_size].view(token_num, self.head_num, self.head_dim)
for layer_index in range(self.layer_num):
dist.recv(recive_buffer, src=0)
nccl_comm.recv(recive_buffer, src=0)
for i, mem in enumerate(mem_managers):
mem._write_kv_move_data_p2p(move_token_indexes, recive_buffer, layer_index)
return
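The send and receive paths above are mirror images: both walk the layers in the same order over the same two-rank communicator, the prefill side always sending to rank 1 and the decode side always receiving from rank 0. A stripped-down sketch of that pairing, with the buffer handling reduced to placeholder helpers (an illustration of the pattern, not the PR's code):

```python
# Illustration only: get_layer_buffer / write_layer_buffer stand in for
# _get_kv_move_data(_p2p) and _write_kv_move_data(_p2p) in the real code.
def prefill_send(layer_num, get_layer_buffer, nccl_comm):
    for layer_index in range(layer_num):
        move_buffer = get_layer_buffer(layer_index)  # gather KV of the moved tokens
        nccl_comm.send(move_buffer, dst=1)           # decode side is rank 1


def decode_recv(layer_num, recive_buffer, write_layer_buffer, nccl_comm):
    for layer_index in range(layer_num):
        nccl_comm.recv(recive_buffer, src=0)            # prefill side is rank 0
        write_layer_buffer(layer_index, recive_buffer)  # scatter into the local KV cache
```

If either side skipped a layer or reordered its loop, the matching send/recv pairs on the communicator would no longer line up.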
35 changes: 26 additions & 9 deletions lightllm/common/mem_manager.py
@@ -10,6 +10,7 @@
from lightllm.common.kv_trans_kernel.kv_trans import kv_trans
from lightllm.utils.dist_utils import get_current_rank_in_node
from lightllm.utils.envs_utils import get_unique_server_name, get_env_start_args
from lightllm.distributed.pynccl import PyNcclCommunicator


logger = init_logger(__name__)
@@ -86,7 +87,11 @@ def alloc_kv_move_buffer(self, max_req_total_len):
return

def send_to_decode_node(
self, move_tasks: List[KVMoveTask], mem_managers: List["MemoryManager"], dp_size_in_node: int
self,
move_tasks: List[KVMoveTask],
mem_managers: List["MemoryManager"],
dp_size_in_node: int,
nccl_comm: PyNcclCommunicator,
):
assert dp_size_in_node == 1

@@ -103,14 +108,14 @@ def send_to_decode_node(
for layer_index in range(mem.layer_num):
move_buffer = mem._get_kv_move_data(move_token_indexes, layer_index)
if i == cur_device_index:
dist.send(move_buffer, dst=1)
nccl_comm.send(move_buffer, dst=1)
else:
move_size = move_buffer.numel()
new_move_buffer = cur_mem.kv_move_buffer.view(-1)[0:move_size].view(move_buffer.shape)
from torch.cuda import comm

comm.broadcast(move_buffer, out=[new_move_buffer])
dist.send(new_move_buffer, dst=1)
nccl_comm.send(new_move_buffer, dst=1)
return

def _get_kv_move_data(self, token_indexes: List[int], layer_index: int):
@@ -122,7 +127,11 @@ def _get_kv_move_data(self, token_indexes: List[int], layer_index: int):
return move_buffer

def receive_from_prefill_node(
self, move_tasks: List[KVMoveTask], mem_managers: List["MemoryManager"], dp_size_in_node: int
self,
move_tasks: List[KVMoveTask],
mem_managers: List["MemoryManager"],
dp_size_in_node: int,
nccl_comm: PyNcclCommunicator,
):
assert dp_size_in_node == 1

@@ -139,7 +148,7 @@
recive_buffer = self.kv_move_buffer.view(-1)[0:move_size].view(1, token_num, 2 * self.head_num, self.head_dim)
for i, mem in enumerate(mem_managers):
for layer_index in range(mem.layer_num):
dist.recv(recive_buffer, src=0)
nccl_comm.recv(recive_buffer, src=0)
if i == cur_device_index:
mem._write_kv_move_data(move_token_indexes, recive_buffer, layer_index)
else:
@@ -155,7 +164,11 @@ def _write_kv_move_data(self, token_indexes: torch.Tensor, buffer_tensor: torch.
return

def send_to_decode_node_p2p(
self, move_tasks: List[KVMoveTask], mem_managers: List["MemoryManager"], dp_size_in_node: int
self,
move_tasks: List[KVMoveTask],
mem_managers: List["MemoryManager"],
dp_size_in_node: int,
nccl_comm: PyNcclCommunicator,
):
"""
Implementation that uses a p2p triton kernel to copy and transfer the data.
@@ -173,7 +186,7 @@ def send_to_decode_node_p2p(
for i, mem in enumerate(mem_managers):
for layer_index in range(mem.layer_num):
move_buffer = mem._get_kv_move_data_p2p(move_token_indexes, layer_index, self.kv_move_buffer)
dist.send(move_buffer, dst=1)
nccl_comm.send(move_buffer, dst=1)
return

def _get_kv_move_data_p2p(self, token_indexes: torch.Tensor, layer_index: int, kv_move_buffer: torch.Tensor):
@@ -186,7 +199,11 @@ def _get_kv_move_data_p2p(self, token_indexes: torch.Tensor, layer_index: int, k
return move_buffer

def receive_from_prefill_node_p2p(
self, move_tasks: List[KVMoveTask], mem_managers: List["MemoryManager"], dp_size_in_node: int
self,
move_tasks: List[KVMoveTask],
mem_managers: List["MemoryManager"],
dp_size_in_node: int,
nccl_comm: PyNcclCommunicator,
):
assert dp_size_in_node == 1

@@ -204,7 +221,7 @@
recive_buffer = self.kv_move_buffer.view(-1)[0:move_size].view(token_num, 2 * self.head_num, self.head_dim)
for i, mem in enumerate(mem_managers):
for layer_index in range(mem.layer_num):
dist.recv(recive_buffer, src=0)
nccl_comm.recv(recive_buffer, src=0)
mem._write_kv_move_data_p2p(move_token_indexes, recive_buffer, layer_index)
return

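On the receive side, the staging buffer is sized as kv_buffer.numel() // layer_num // size * token_num elements (the formula is visible in the deepseek2_mem_manager.py hunks above) and viewed here as (token_num, 2 * head_num, head_dim); the DeepSeek2 manager's view drops the factor of 2, consistent with its compressed KV layout. A tiny worked example of that arithmetic, with made-up dimensions:

```python
# Made-up dimensions, purely to illustrate the move_size arithmetic in the diff.
layer_num, size = 32, 4096          # size = number of KV cache token slots
head_num, head_dim = 8, 128
token_num = 16                      # tokens being moved in this task batch

# Assume the flat KV buffer holds layer_num * size slots of (2 * head_num, head_dim).
kv_buffer_numel = layer_num * size * 2 * head_num * head_dim

# Per-layer staging size for token_num tokens, as computed in the diff:
move_size = kv_buffer_numel // layer_num // size * token_num
assert move_size == token_num * 2 * head_num * head_dim == 32768

# The kv_move_buffer is then viewed as (token_num, 2 * head_num, head_dim)
# before each nccl_comm.recv(recive_buffer, src=0) fills it layer by layer.
```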