From 9906f75707ef095c8495b4267c8de79d71380525 Mon Sep 17 00:00:00 2001 From: Weichao Luo Date: Fri, 7 Mar 2025 19:55:21 +0800 Subject: [PATCH 1/4] single kv transfer process for pd. --- lightllm/common/deepseek2_mem_manager.py | 21 +- lightllm/common/mem_manager.py | 23 +- lightllm/distributed/pynccl.py | 332 ++++++++++++++ lightllm/distributed/pynccl_wrapper.py | 405 ++++++++++++++++++ lightllm/server/pd_io_struct.py | 13 + .../decode_kv_move_manager.py | 111 ++--- .../decode_node_impl/decode_trans_process.py | 113 ++--- .../prefill_kv_move_manager.py | 82 ++-- .../prefill_trans_process.py | 131 +++--- 9 files changed, 1011 insertions(+), 220 deletions(-) create mode 100644 lightllm/distributed/pynccl.py create mode 100644 lightllm/distributed/pynccl_wrapper.py diff --git a/lightllm/common/deepseek2_mem_manager.py b/lightllm/common/deepseek2_mem_manager.py index 7fd8dee26..c7f437e96 100644 --- a/lightllm/common/deepseek2_mem_manager.py +++ b/lightllm/common/deepseek2_mem_manager.py @@ -6,6 +6,7 @@ from typing import List, Union from lightllm.utils.log_utils import init_logger from lightllm.common.kv_trans_kernel.kv_trans import kv_trans +from lightllm.distributed.pynccl import PyNcclCommunicator logger = init_logger(__name__) @@ -40,7 +41,8 @@ def alloc_kv_move_buffer(self, max_req_total_len): return def send_to_decode_node( - self, move_tasks: List[KVMoveTask], mem_managers: List["Deepseek2MemoryManager"], dp_size_in_node: int + self, move_tasks: List[KVMoveTask], mem_managers: List["Deepseek2MemoryManager"], dp_size_in_node: int, + nccl_comm: PyNcclCommunicator ): assert dp_size_in_node == 1 @@ -54,7 +56,7 @@ def send_to_decode_node( cur_mem = mem_managers[cur_device_index] for layer_index in range(cur_mem.layer_num): move_buffer = cur_mem._get_kv_move_data(move_token_indexes, layer_index) - dist.send(move_buffer, dst=1) + nccl_comm.send(move_buffer, dst=1) return def _get_kv_move_data(self, token_indexes: List[int], layer_index: int): @@ -66,7 +68,8 @@ def _get_kv_move_data(self, token_indexes: List[int], layer_index: int): return move_buffer def receive_from_prefill_node( - self, move_tasks: List[KVMoveTask], mem_managers: List["MemoryManager"], dp_size_in_node: int + self, move_tasks: List[KVMoveTask], mem_managers: List["MemoryManager"], dp_size_in_node: int, + nccl_comm: PyNcclCommunicator ): assert dp_size_in_node == 1 @@ -81,7 +84,7 @@ def receive_from_prefill_node( move_size = self.kv_buffer.numel() // self.layer_num // self.size * token_num recive_buffer = self.kv_move_buffer.view(-1)[0:move_size].view(1, token_num, self.head_num, self.head_dim) for layer_index in range(self.layer_num): - dist.recv(recive_buffer, src=0) + nccl_comm.recv(recive_buffer, src=0) for i, mem in enumerate(mem_managers): if i == cur_device_index: mem._write_kv_move_data(move_token_indexes, recive_buffer, layer_index) @@ -98,7 +101,8 @@ def _write_kv_move_data(self, token_indexes: torch.Tensor, buffer_tensor: torch. 
return def send_to_decode_node_p2p( - self, move_tasks: List[KVMoveTask], mem_managers: List["MemoryManager"], dp_size_in_node: int + self, move_tasks: List[KVMoveTask], mem_managers: List["MemoryManager"], dp_size_in_node: int, + nccl_comm: PyNcclCommunicator ): """ 使用 p2p triton kernel 进行数据复制和传输的实现方式。 @@ -113,7 +117,7 @@ def send_to_decode_node_p2p( move_token_indexes = torch.tensor(move_token_indexes, dtype=torch.int64, device="cuda") for layer_index in range(self.layer_num): move_buffer = self._get_kv_move_data_p2p(move_token_indexes, layer_index, self.kv_move_buffer) - dist.send(move_buffer, dst=1) + nccl_comm.send(move_buffer, dst=1) return def _get_kv_move_data_p2p(self, token_indexes: torch.Tensor, layer_index: int, kv_move_buffer: torch.Tensor): @@ -126,7 +130,8 @@ def _get_kv_move_data_p2p(self, token_indexes: torch.Tensor, layer_index: int, k return move_buffer def receive_from_prefill_node_p2p( - self, move_tasks: List[KVMoveTask], mem_managers: List["MemoryManager"], dp_size_in_node: int + self, move_tasks: List[KVMoveTask], mem_managers: List["MemoryManager"], dp_size_in_node: int, + nccl_comm: PyNcclCommunicator ): assert dp_size_in_node == 1 @@ -141,7 +146,7 @@ def receive_from_prefill_node_p2p( move_size = self.kv_buffer.numel() // self.layer_num // self.size * token_num recive_buffer = self.kv_move_buffer.view(-1)[0:move_size].view(token_num, self.head_num, self.head_dim) for layer_index in range(self.layer_num): - dist.recv(recive_buffer, src=0) + nccl_comm.recv(recive_buffer, src=0) for i, mem in enumerate(mem_managers): mem._write_kv_move_data_p2p(move_token_indexes, recive_buffer, layer_index) return diff --git a/lightllm/common/mem_manager.py b/lightllm/common/mem_manager.py index 9f396d855..79edd3cdc 100755 --- a/lightllm/common/mem_manager.py +++ b/lightllm/common/mem_manager.py @@ -10,6 +10,7 @@ from lightllm.common.kv_trans_kernel.kv_trans import kv_trans from lightllm.utils.dist_utils import get_current_rank_in_node from lightllm.utils.envs_utils import get_unique_server_name, get_env_start_args +from lightllm.distributed.pynccl import PyNcclCommunicator logger = init_logger(__name__) @@ -86,7 +87,8 @@ def alloc_kv_move_buffer(self, max_req_total_len): return def send_to_decode_node( - self, move_tasks: List[KVMoveTask], mem_managers: List["MemoryManager"], dp_size_in_node: int + self, move_tasks: List[KVMoveTask], mem_managers: List["MemoryManager"], dp_size_in_node: int, + nccl_comm: PyNcclCommunicator ): assert dp_size_in_node == 1 @@ -103,14 +105,14 @@ def send_to_decode_node( for layer_index in range(mem.layer_num): move_buffer = mem._get_kv_move_data(move_token_indexes, layer_index) if i == cur_device_index: - dist.send(move_buffer, dst=1) + nccl_comm.send(move_buffer, dst=1) else: move_size = move_buffer.numel() new_move_buffer = cur_mem.kv_move_buffer.view(-1)[0:move_size].view(move_buffer.shape) from torch.cuda import comm comm.broadcast(move_buffer, out=[new_move_buffer]) - dist.send(new_move_buffer, dst=1) + nccl_comm.send(new_move_buffer, dst=1) return def _get_kv_move_data(self, token_indexes: List[int], layer_index: int): @@ -122,7 +124,8 @@ def _get_kv_move_data(self, token_indexes: List[int], layer_index: int): return move_buffer def receive_from_prefill_node( - self, move_tasks: List[KVMoveTask], mem_managers: List["MemoryManager"], dp_size_in_node: int + self, move_tasks: List[KVMoveTask], mem_managers: List["MemoryManager"], dp_size_in_node: int, + nccl_comm: PyNcclCommunicator, ): assert dp_size_in_node == 1 @@ -139,7 +142,7 @@ def 
receive_from_prefill_node( recive_buffer = self.kv_move_buffer.view(-1)[0:move_size].view(1, token_num, 2 * self.head_num, self.head_dim) for i, mem in enumerate(mem_managers): for layer_index in range(mem.layer_num): - dist.recv(recive_buffer, src=0) + nccl_comm.recv(recive_buffer, src=0) if i == cur_device_index: mem._write_kv_move_data(move_token_indexes, recive_buffer, layer_index) else: @@ -155,7 +158,8 @@ def _write_kv_move_data(self, token_indexes: torch.Tensor, buffer_tensor: torch. return def send_to_decode_node_p2p( - self, move_tasks: List[KVMoveTask], mem_managers: List["MemoryManager"], dp_size_in_node: int + self, move_tasks: List[KVMoveTask], mem_managers: List["MemoryManager"], dp_size_in_node: int, + nccl_comm: PyNcclCommunicator ): """ 使用 p2p triton kernel 进行数据复制和传输的实现方式。 @@ -173,7 +177,7 @@ def send_to_decode_node_p2p( for i, mem in enumerate(mem_managers): for layer_index in range(mem.layer_num): move_buffer = mem._get_kv_move_data_p2p(move_token_indexes, layer_index, self.kv_move_buffer) - dist.send(move_buffer, dst=1) + nccl_comm.send(move_buffer, dst=1) return def _get_kv_move_data_p2p(self, token_indexes: torch.Tensor, layer_index: int, kv_move_buffer: torch.Tensor): @@ -186,7 +190,8 @@ def _get_kv_move_data_p2p(self, token_indexes: torch.Tensor, layer_index: int, k return move_buffer def receive_from_prefill_node_p2p( - self, move_tasks: List[KVMoveTask], mem_managers: List["MemoryManager"], dp_size_in_node: int + self, move_tasks: List[KVMoveTask], mem_managers: List["MemoryManager"], dp_size_in_node: int, + nccl_comm: PyNcclCommunicator, ): assert dp_size_in_node == 1 @@ -204,7 +209,7 @@ def receive_from_prefill_node_p2p( recive_buffer = self.kv_move_buffer.view(-1)[0:move_size].view(token_num, 2 * self.head_num, self.head_dim) for i, mem in enumerate(mem_managers): for layer_index in range(mem.layer_num): - dist.recv(recive_buffer, src=0) + nccl_comm.recv(recive_buffer, src=0) mem._write_kv_move_data_p2p(move_token_indexes, recive_buffer, layer_index) return diff --git a/lightllm/distributed/pynccl.py b/lightllm/distributed/pynccl.py new file mode 100644 index 000000000..9a01dd116 --- /dev/null +++ b/lightllm/distributed/pynccl.py @@ -0,0 +1,332 @@ +# Adapted from +# https://github.com/vllm-project/vllm/blob/main/vllm/distributed/device_communicators/pynccl.py +# of the vllm-project/vllm GitHub repository. +# +# Copyright 2023 ModelTC Team +# Copyright 2023 vLLM Team +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
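The hunks above change `MemoryManager.send_to_decode_node` / `receive_from_prefill_node` (and their `_p2p` variants) to take an explicit `PyNcclCommunicator` and to call `nccl_comm.send` / `nccl_comm.recv` instead of the global `dist.send` / `dist.recv`, so a single transfer process can hold one point-to-point communicator per peer node. The new `lightllm/distributed/pynccl.py` module below provides that communicator. A minimal caller-side sketch of the new signature (hypothetical wiring; the real call sites are the trans processes later in this patch, and `move_tasks`, the memory managers and an already-connected communicator are assumed to exist):

```python
from typing import List
from lightllm.common.mem_manager import MemoryManager
from lightllm.distributed.pynccl import PyNcclCommunicator
from lightllm.server.pd_io_struct import KVMoveTask


def send_kv(move_tasks: List[KVMoveTask],
            mem_managers: List[MemoryManager],
            comm: PyNcclCommunicator,
            dp_size_in_node: int = 1) -> None:
    # The communicator is assumed to be connected to the decode peer
    # (rank 1 of the two-rank group); the memory manager bound to the
    # communicator's device drives the layer-by-layer transfer.
    cur_mem = mem_managers[comm.device.index]
    cur_mem.send_to_decode_node(move_tasks, mem_managers, dp_size_in_node, comm)
```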
+ +# SPDX-License-Identifier: Apache-2.0 + +import dataclasses +from datetime import timedelta +import pickle +import time +from typing import Optional, Union, Dict, Deque, Tuple, Any +from collections import deque +import logging + +# ===================== import region ===================== +import torch +import torch.distributed as dist +from torch.distributed import ProcessGroup, ReduceOp, TCPStore + +from lightllm.distributed.pynccl_wrapper import ( + NCCLLibrary, buffer_type, cudaStream_t, ncclComm_t, ncclDataTypeEnum, + ncclRedOpTypeEnum, ncclUniqueId) + +logger = logging.getLogger(__name__) + +_current_stream = None + +def current_stream() -> torch.cuda.Stream: + global _current_stream + if _current_stream is None: + _current_stream = torch.cuda.current_stream() + return _current_stream + +@dataclasses.dataclass +class StatelessP2PProcessGroup: + """A dataclass to hold a metadata store, and the rank, world_size of the + group. Only use it to communicate metadata between processes. + For data-plane communication, create NCCL-related objects. + """ + + dest_id: int + src_id: int + is_server: bool + + rank: int = 0 + world_size: int = 2 + store: TCPStore = None + data_expiration_seconds: int = 3600 # 1 hour + # dst rank -> counter + send_dst_counter: int = 0 + # src rank -> counter + recv_src_counter: int = 0 + entries: Deque[Tuple[str, float]] = dataclasses.field(default_factory=deque) + + def __post_init__(self): + self.rank = 0 if self.is_server else 1 + self.world_size = 2 + self.send_dst_counter = 0 + self.recv_src_counter = 0 + + def send_obj(self, obj: Any): + """Send an object to a destination rank.""" + self.expire_data() + key = f"send_to/{self.dest_id}/{self.send_dst_counter}" + self.store.set(key, pickle.dumps(obj)) + self.send_dst_counter += 1 + self.entries.append((key, time.time())) + + def expire_data(self): + """Expire data that is older than `data_expiration_seconds` seconds.""" + while self.entries: + # check the oldest entry + key, timestamp = self.entries[0] + if time.time() - timestamp > self.data_expiration_seconds: + self.store.delete_key(key) + self.entries.popleft() + else: + break + + def recv_obj(self) -> Any: + """Receive an object from a source rank.""" + obj = pickle.loads( + self.store.get( + f"send_to/{self.dest_id}/{self.recv_src_counter}")) + self.recv_src_counter += 1 + return obj + + @staticmethod + def create( + src_id: int, + dest_id: int, + is_server: bool, + store: torch._C._distributed_c10d.Store + ) -> "StatelessP2PProcessGroup": + """A replacement for `torch.distributed.init_process_group` that does not + pollute the global state. + + If we have process A and process B called `torch.distributed.init_process_group` + to form a group, and then we want to form another group with process A, B, C, + D, it is not possible in PyTorch, because process A and process B have already + formed a group, and process C and process D cannot join that group. This + function is a workaround for this issue. + + `torch.distributed.init_process_group` is a global call, while this function + is a stateless call. It will return a `StatelessProcessGroup` object that can be + used for exchanging metadata. With this function, process A and process B + can call `StatelessProcessGroup.create` to form a group, and then process A, B, + C, and D can call `StatelessProcessGroup.create` to form another group. 
+ """ # noqa + return StatelessP2PProcessGroup(src_id=src_id, dest_id=dest_id, is_server=is_server, store=store) + + +class PyNcclCommunicator: + + def __init__( + self, + group: Union[ProcessGroup, StatelessP2PProcessGroup], + device: Union[int, str, torch.device], + library_path: Optional[str] = None, + ): + """ + Args: + group: the process group to work on. If None, it will use the + default process group. + device: the device to bind the PyNcclCommunicator to. If None, + it will be bind to f"cuda:{local_rank}". + library_path: the path to the NCCL library. If None, it will + use the default library path. + It is the caller's responsibility to make sure each communicator + is bind to a unique device. + """ + if not isinstance(group, StatelessP2PProcessGroup): + assert dist.is_initialized() + assert dist.get_backend(group) != dist.Backend.NCCL, ( + "PyNcclCommunicator should be attached to a non-NCCL group.") + # note: this rank is the rank in the group + self.rank = dist.get_rank(group) + self.world_size = dist.get_world_size(group) + else: + self.rank = group.rank + self.world_size = group.world_size + + self.group = group + + # if world_size == 1, no need to create communicator + if self.world_size == 1: + self.available = False + self.disabled = True + return + try: + self.nccl = NCCLLibrary(library_path) + except Exception: + # disable because of missing NCCL library + # e.g. in a non-GPU environment + self.available = False + self.disabled = True + return + + self.available = True + self.disabled = False + + logger.info("LightLLM is using nccl==%s", self.nccl.ncclGetVersion()) + + if self.rank == 0: + # get the unique id from NCCL + self.unique_id = self.nccl.ncclGetUniqueId() + else: + # construct an empty unique id + self.unique_id = ncclUniqueId() + + if not isinstance(group, StatelessP2PProcessGroup): + tensor = torch.ByteTensor(list(self.unique_id.internal)) + ranks = dist.get_process_group_ranks(group) + # arg `src` in `broadcast` is the global rank + dist.broadcast(tensor, src=ranks[0], group=group) + byte_list = tensor.tolist() + for i, byte in enumerate(byte_list): + self.unique_id.internal[i] = byte + else: + if group.rank == 0: + group.send_obj(self.unique_id) + else: + self.unique_id = group.recv_obj() + if isinstance(device, int): + device = torch.device(f"cuda:{device}") + elif isinstance(device, str): + device = torch.device(device) + # now `device` is a `torch.device` object + assert isinstance(device, torch.device) + self.device = device + # nccl communicator and stream will use this device + # `torch.cuda.device` is a context manager that changes the + # current cuda device to the specified one + with torch.cuda.device(device): + self.comm: ncclComm_t = self.nccl.ncclCommInitRank( + self.world_size, self.unique_id, self.rank) + + stream = current_stream() + # A small all_reduce for warmup. 
+ data = torch.zeros(1, device=device) + self.all_reduce(data) + stream.synchronize() + del data + + def destroy(self): + self.nccl.ncclCommDestroy(self.comm) + + def all_reduce(self, + in_tensor: torch.Tensor, + op: ReduceOp = ReduceOp.SUM, + stream=None) -> torch.Tensor: + if self.disabled: + return None + # nccl communicator created on a specific device + # will only work on tensors on the same device + # otherwise it will cause "illegal memory access" + assert in_tensor.device == self.device, ( + f"this nccl communicator is created to work on {self.device}, " + f"but the input tensor is on {in_tensor.device}") + + out_tensor = torch.empty_like(in_tensor) + + if stream is None: + stream = current_stream() + self.nccl.ncclAllReduce(buffer_type(in_tensor.data_ptr()), + buffer_type(out_tensor.data_ptr()), + in_tensor.numel(), + ncclDataTypeEnum.from_torch(in_tensor.dtype), + ncclRedOpTypeEnum.from_torch(op), self.comm, + cudaStream_t(stream.cuda_stream)) + return out_tensor + + def all_gather(self, + output_tensor: torch.Tensor, + input_tensor: torch.Tensor, + stream=None): + if self.disabled: + return + # nccl communicator created on a specific device + # will only work on tensors on the same device + # otherwise it will cause "illegal memory access" + assert input_tensor.device == self.device, ( + f"this nccl communicator is created to work on {self.device}, " + f"but the input tensor is on {input_tensor.device}") + if stream is None: + stream = current_stream() + self.nccl.ncclAllGather( + buffer_type(input_tensor.data_ptr()), + buffer_type(output_tensor.data_ptr()), input_tensor.numel(), + ncclDataTypeEnum.from_torch(input_tensor.dtype), self.comm, + cudaStream_t(stream.cuda_stream)) + + def reduce_scatter(self, + output_tensor: torch.Tensor, + input_tensor: torch.Tensor, + op: ReduceOp = ReduceOp.SUM, + stream=None): + if self.disabled: + return + # nccl communicator created on a specific device + # will only work on tensors on the same device + # otherwise it will cause "illegal memory access" + assert input_tensor.device == self.device, ( + f"this nccl communicator is created to work on {self.device}, " + f"but the input tensor is on {input_tensor.device}") + if stream is None: + stream = current_stream() + self.nccl.ncclReduceScatter( + buffer_type(input_tensor.data_ptr()), + buffer_type(output_tensor.data_ptr()), output_tensor.numel(), + ncclDataTypeEnum.from_torch(input_tensor.dtype), + ncclRedOpTypeEnum.from_torch(op), self.comm, + cudaStream_t(stream.cuda_stream)) + + def send(self, tensor: torch.Tensor, dst: int, stream=None): + if self.disabled: + return + assert tensor.device == self.device, ( + f"this nccl communicator is created to work on {self.device}, " + f"but the input tensor is on {tensor.device}") + if stream is None: + stream = current_stream() + self.nccl.ncclSend(buffer_type(tensor.data_ptr()), tensor.numel(), + ncclDataTypeEnum.from_torch(tensor.dtype), dst, + self.comm, cudaStream_t(stream.cuda_stream)) + + def recv(self, tensor: torch.Tensor, src: int, stream=None): + if self.disabled: + return + assert tensor.device == self.device, ( + f"this nccl communicator is created to work on {self.device}, " + f"but the input tensor is on {tensor.device}") + if stream is None: + stream = current_stream() + self.nccl.ncclRecv(buffer_type(tensor.data_ptr()), tensor.numel(), + ncclDataTypeEnum.from_torch(tensor.dtype), src, + self.comm, cudaStream_t(stream.cuda_stream)) + + def broadcast(self, tensor: torch.Tensor, src: int, stream=None): + if self.disabled: + return + 
assert tensor.device == self.device, ( + f"this nccl communicator is created to work on {self.device}, " + f"but the input tensor is on {tensor.device}") + if stream is None: + stream = current_stream() + if src == self.rank: + sendbuff = buffer_type(tensor.data_ptr()) + # NCCL requires the sender also to have a receive buffer + recvbuff = buffer_type(tensor.data_ptr()) + else: + sendbuff = buffer_type() + recvbuff = buffer_type(tensor.data_ptr()) + self.nccl.ncclBroadcast(sendbuff, recvbuff, tensor.numel(), + ncclDataTypeEnum.from_torch(tensor.dtype), src, + self.comm, cudaStream_t(stream.cuda_stream)) + diff --git a/lightllm/distributed/pynccl_wrapper.py b/lightllm/distributed/pynccl_wrapper.py new file mode 100644 index 000000000..d35ec8e3a --- /dev/null +++ b/lightllm/distributed/pynccl_wrapper.py @@ -0,0 +1,405 @@ +# Adapted from +# https://github.com/vllm-project/vllm/blob/main/vllm/distributed/device_communicators/pynccl_wrapper.py +# of the vllm-project/vllm GitHub repository. +# +# Copyright 2023 ModelTC Team +# Copyright 2023 vLLM Team +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# SPDX-License-Identifier: Apache-2.0 + +# This file is a pure Python wrapper for the NCCL library. +# The main purpose is to use NCCL combined with CUDA graph. +# Before writing this script, we tried the following approach: +# 1. We tried to use `cupy`, it calls NCCL correctly, but `cupy` itself +# often gets stuck when initializing the NCCL communicator. +# 2. We tried to use `torch.distributed`, but `torch.distributed.all_reduce` +# contains many other potential cuda APIs, that are not allowed during +# capturing the CUDA graph. For further details, please check +# https://discuss.pytorch.org/t/pytorch-cudagraph-with-nccl-operation-failed/ . +# +# Another rejected idea is to write a C/C++ binding for NCCL. It is usually +# doable, but we often encounter issues related with nccl versions, and need +# to switch between different versions of NCCL. See +# https://github.com/NVIDIA/nccl/issues/1234 for more details. +# A C/C++ binding is not flexible enough to handle this. It requires +# recompilation of the code every time we want to switch between different +# versions. This current implementation, with a **pure** Python wrapper, is +# more flexible. We can easily switch between different versions of NCCL by +# changing the environment variable `VLLM_NCCL_SO_PATH`, or the `so_file` +# variable in the code. + +import ctypes +import platform +from dataclasses import dataclass +from typing import Any, Dict, List, Optional + +import torch +from torch.distributed import ReduceOp + +import logging + +logger = logging.getLogger(__name__) + + +def find_nccl_library() -> str: + """ + We either use the library file specified by the `VLLM_NCCL_SO_PATH` + environment variable, or we find the library file brought by PyTorch. + After importing `torch`, `libnccl.so.2` or `librccl.so.1` can be + found by `ctypes` automatically. 
+ """ + so_file = None + + # manually load the nccl library + if so_file: + logger.info( + "Found nccl from environment variable VLLM_NCCL_SO_PATH=%s", + so_file) + else: + if torch.version.cuda is not None: + so_file = "libnccl.so.2" + elif torch.version.hip is not None: + so_file = "librccl.so.1" + else: + raise ValueError("NCCL only supports CUDA and ROCm backends.") + logger.info("Found nccl from library %s", so_file) + return so_file + + +# === export types and functions from nccl to Python === +# for the original nccl definition, please check +# https://github.com/NVIDIA/nccl/blob/master/src/nccl.h.in + +ncclResult_t = ctypes.c_int +ncclComm_t = ctypes.c_void_p + + +class ncclUniqueId(ctypes.Structure): + _fields_ = [("internal", ctypes.c_byte * 128)] + + +cudaStream_t = ctypes.c_void_p +buffer_type = ctypes.c_void_p + +ncclDataType_t = ctypes.c_int + + +class ncclDataTypeEnum: + ncclInt8 = 0 + ncclChar = 0 + ncclUint8 = 1 + ncclInt32 = 2 + ncclInt = 2 + ncclUint32 = 3 + ncclInt64 = 4 + ncclUint64 = 5 + ncclFloat16 = 6 + ncclHalf = 6 + ncclFloat32 = 7 + ncclFloat = 7 + ncclFloat64 = 8 + ncclDouble = 8 + ncclBfloat16 = 9 + ncclNumTypes = 10 + + @classmethod + def from_torch(cls, dtype: torch.dtype) -> int: + if dtype == torch.int8: + return cls.ncclInt8 + if dtype == torch.uint8: + return cls.ncclUint8 + if dtype == torch.int32: + return cls.ncclInt32 + if dtype == torch.int64: + return cls.ncclInt64 + if dtype == torch.float16: + return cls.ncclFloat16 + if dtype == torch.float32: + return cls.ncclFloat32 + if dtype == torch.float64: + return cls.ncclFloat64 + if dtype == torch.bfloat16: + return cls.ncclBfloat16 + raise ValueError(f"Unsupported dtype: {dtype}") + + +ncclRedOp_t = ctypes.c_int + + +class ncclRedOpTypeEnum: + ncclSum = 0 + ncclProd = 1 + ncclMax = 2 + ncclMin = 3 + ncclAvg = 4 + ncclNumOps = 5 + + @classmethod + def from_torch(cls, op: ReduceOp) -> int: + if op == ReduceOp.SUM: + return cls.ncclSum + if op == ReduceOp.PRODUCT: + return cls.ncclProd + if op == ReduceOp.MAX: + return cls.ncclMax + if op == ReduceOp.MIN: + return cls.ncclMin + if op == ReduceOp.AVG: + return cls.ncclAvg + raise ValueError(f"Unsupported op: {op}") + + +@dataclass +class Function: + name: str + restype: Any + argtypes: List[Any] + + +class NCCLLibrary: + exported_functions = [ + # const char* ncclGetErrorString(ncclResult_t result) + Function("ncclGetErrorString", ctypes.c_char_p, [ncclResult_t]), + # ncclResult_t ncclGetVersion(int *version); + Function("ncclGetVersion", ncclResult_t, + [ctypes.POINTER(ctypes.c_int)]), + # ncclResult_t ncclGetUniqueId(ncclUniqueId* uniqueId); + Function("ncclGetUniqueId", ncclResult_t, + [ctypes.POINTER(ncclUniqueId)]), + # ncclResult_t ncclCommInitRank( + # ncclComm_t* comm, int nranks, ncclUniqueId commId, int rank); + # note that ncclComm_t is a pointer type, so the first argument + # is a pointer to a pointer + Function("ncclCommInitRank", ncclResult_t, [ + ctypes.POINTER(ncclComm_t), ctypes.c_int, ncclUniqueId, + ctypes.c_int + ]), + # ncclResult_t ncclAllReduce( + # const void* sendbuff, void* recvbuff, size_t count, + # ncclDataType_t datatype, ncclRedOp_t op, ncclComm_t comm, + # cudaStream_t stream); + # note that cudaStream_t is a pointer type, so the last argument + # is a pointer + Function("ncclAllReduce", ncclResult_t, [ + buffer_type, buffer_type, ctypes.c_size_t, ncclDataType_t, + ncclRedOp_t, ncclComm_t, cudaStream_t + ]), + + # ncclResult_t ncclAllGather( + # const void* sendbuff, void* recvbuff, size_t count, + # ncclDataType_t 
datatype, ncclComm_t comm, + # cudaStream_t stream); + # note that cudaStream_t is a pointer type, so the last argument + # is a pointer + Function("ncclAllGather", ncclResult_t, [ + buffer_type, buffer_type, ctypes.c_size_t, ncclDataType_t, + ncclComm_t, cudaStream_t + ]), + + # ncclResult_t ncclReduceScatter( + # const void* sendbuff, void* recvbuff, size_t count, + # ncclDataType_t datatype, ncclRedOp_t op, ncclComm_t comm, + # cudaStream_t stream); + # note that cudaStream_t is a pointer type, so the last argument + # is a pointer + Function("ncclReduceScatter", ncclResult_t, [ + buffer_type, buffer_type, ctypes.c_size_t, ncclDataType_t, + ncclRedOp_t, ncclComm_t, cudaStream_t + ]), + + # ncclResult_t ncclSend( + # const void* sendbuff, size_t count, ncclDataType_t datatype, + # int dest, ncclComm_t comm, cudaStream_t stream); + Function("ncclSend", ncclResult_t, [ + buffer_type, ctypes.c_size_t, ncclDataType_t, ctypes.c_int, + ncclComm_t, cudaStream_t + ]), + + # ncclResult_t ncclRecv( + # void* recvbuff, size_t count, ncclDataType_t datatype, + # int src, ncclComm_t comm, cudaStream_t stream); + Function("ncclRecv", ncclResult_t, [ + buffer_type, ctypes.c_size_t, ncclDataType_t, ctypes.c_int, + ncclComm_t, cudaStream_t + ]), + + # ncclResult_t ncclBroadcast( + # const void* sendbuff, void* recvbuff, size_t count, + # ncclDataType_t datatype, int root, ncclComm_t comm, + # cudaStream_t stream); + Function("ncclBroadcast", ncclResult_t, [ + buffer_type, buffer_type, ctypes.c_size_t, ncclDataType_t, + ctypes.c_int, ncclComm_t, cudaStream_t + ]), + + # be cautious! this is a collective call, it will block until all + # processes in the communicator have called this function. + # because Python object destruction can happen in random order, + # it is better not to call it at all. + # ncclResult_t ncclCommDestroy(ncclComm_t comm); + Function("ncclCommDestroy", ncclResult_t, [ncclComm_t]), + ] + + # class attribute to store the mapping from the path to the library + # to avoid loading the same library multiple times + path_to_library_cache: Dict[str, Any] = {} + + # class attribute to store the mapping from library path + # to the corresponding dictionary + path_to_dict_mapping: Dict[str, Dict[str, Any]] = {} + + def __init__(self, so_file: Optional[str] = None): + + so_file = so_file or find_nccl_library() + + try: + if so_file not in NCCLLibrary.path_to_dict_mapping: + lib = ctypes.CDLL(so_file) + NCCLLibrary.path_to_library_cache[so_file] = lib + self.lib = NCCLLibrary.path_to_library_cache[so_file] + except Exception as e: + logger.error( + "Failed to load NCCL library from %s. " + "It is expected if you are not running on NVIDIA/AMD GPUs." + "Otherwise, the nccl library might not exist, be corrupted " + "or it does not support the current platform %s. 
" + "If you already have the library, please set the " + "environment variable VLLM_NCCL_SO_PATH" + " to point to the correct nccl library path.", so_file, + platform.platform()) + raise e + + if so_file not in NCCLLibrary.path_to_dict_mapping: + _funcs: Dict[str, Any] = {} + for func in NCCLLibrary.exported_functions: + f = getattr(self.lib, func.name) + f.restype = func.restype + f.argtypes = func.argtypes + _funcs[func.name] = f + NCCLLibrary.path_to_dict_mapping[so_file] = _funcs + self._funcs = NCCLLibrary.path_to_dict_mapping[so_file] + + def ncclGetErrorString(self, result: ncclResult_t) -> str: + return self._funcs["ncclGetErrorString"](result).decode("utf-8") + + def NCCL_CHECK(self, result: ncclResult_t) -> None: + if result != 0: + error_str = self.ncclGetErrorString(result) + raise RuntimeError(f"NCCL error: {error_str}") + + def ncclGetVersion(self) -> str: + version = ctypes.c_int() + self.NCCL_CHECK(self._funcs["ncclGetVersion"](ctypes.byref(version))) + version_str = str(version.value) + # something like 21903 --> "2.19.3" + major = version_str[0].lstrip("0") + minor = version_str[1:3].lstrip("0") + patch = version_str[3:].lstrip("0") + return f"{major}.{minor}.{patch}" + + def ncclGetUniqueId(self) -> ncclUniqueId: + unique_id = ncclUniqueId() + self.NCCL_CHECK(self._funcs["ncclGetUniqueId"]( + ctypes.byref(unique_id))) + return unique_id + + def ncclCommInitRank(self, world_size: int, unique_id: ncclUniqueId, + rank: int) -> ncclComm_t: + comm = ncclComm_t() + self.NCCL_CHECK(self._funcs["ncclCommInitRank"](ctypes.byref(comm), + world_size, unique_id, + rank)) + return comm + + def ncclAllReduce(self, sendbuff: buffer_type, recvbuff: buffer_type, + count: int, datatype: int, op: int, comm: ncclComm_t, + stream: cudaStream_t) -> None: + # `datatype` actually should be `ncclDataType_t` + # and `op` should be `ncclRedOp_t` + # both are aliases of `ctypes.c_int` + # when we pass int to a function, it will be converted to `ctypes.c_int` + # by ctypes automatically + self.NCCL_CHECK(self._funcs["ncclAllReduce"](sendbuff, recvbuff, count, + datatype, op, comm, + stream)) + + def ncclReduceScatter(self, sendbuff: buffer_type, recvbuff: buffer_type, + count: int, datatype: int, op: int, comm: ncclComm_t, + stream: cudaStream_t) -> None: + # `datatype` actually should be `ncclDataType_t` + # and `op` should be `ncclRedOp_t` + # both are aliases of `ctypes.c_int` + # when we pass int to a function, it will be converted to `ctypes.c_int` + # by ctypes automatically + self.NCCL_CHECK(self._funcs["ncclReduceScatter"](sendbuff, recvbuff, + count, datatype, op, + comm, stream)) + + def ncclAllGather(self, sendbuff: buffer_type, recvbuff: buffer_type, + count: int, datatype: int, comm: ncclComm_t, + stream: cudaStream_t) -> None: + # `datatype` actually should be `ncclDataType_t` + # which is an aliases of `ctypes.c_int` + # when we pass int to a function, it will be converted to `ctypes.c_int` + # by ctypes automatically + self.NCCL_CHECK(self._funcs["ncclAllGather"](sendbuff, recvbuff, count, + datatype, comm, stream)) + + def ncclSend(self, sendbuff: buffer_type, count: int, datatype: int, + dest: int, comm: ncclComm_t, stream: cudaStream_t) -> None: + self.NCCL_CHECK(self._funcs["ncclSend"](sendbuff, count, datatype, + dest, comm, stream)) + + def ncclRecv(self, recvbuff: buffer_type, count: int, datatype: int, + src: int, comm: ncclComm_t, stream: cudaStream_t) -> None: + self.NCCL_CHECK(self._funcs["ncclRecv"](recvbuff, count, datatype, src, + comm, stream)) + + def 
ncclBroadcast(self, sendbuff: buffer_type, recvbuff: buffer_type, + count: int, datatype: int, root: int, comm: ncclComm_t, + stream: cudaStream_t) -> None: + self.NCCL_CHECK(self._funcs["ncclBroadcast"](sendbuff, recvbuff, count, + datatype, root, comm, + stream)) + + def ncclCommDestroy(self, comm: ncclComm_t) -> None: + self.NCCL_CHECK(self._funcs["ncclCommDestroy"](comm)) + + +__all__ = [ + "NCCLLibrary", "ncclDataTypeEnum", "ncclRedOpTypeEnum", "ncclUniqueId", + "ncclComm_t", "cudaStream_t", "buffer_type" +] + + + +def test_ncclGetUniqueId(): + lib = NCCLLibrary() + unique_id = lib.ncclGetUniqueId() + print(unique_id.internal) + # `list(unique_id.internal)` is something like this: + # [34, -16, 23, 83, 109, -19, 59, 95, 2, 0, -86, 55, 10, -128, 0, 29, 0, + # 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + # 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + # 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + # 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + # 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0] + # as long as the function doesn't raise an exception, we're good + assert unique_id is not None + +if __name__ == '__main__': + import torch; + torch.cuda.set_device(0) + test_ncclGetUniqueId() diff --git a/lightllm/server/pd_io_struct.py b/lightllm/server/pd_io_struct.py index d5d22c8ea..222cd5887 100644 --- a/lightllm/server/pd_io_struct.py +++ b/lightllm/server/pd_io_struct.py @@ -74,6 +74,19 @@ class DecodeNodeInfo: rpyc_port: str max_new_tokens: int +@dataclass +class PDTransJoinInfo: + decode_id: int + decode_device_id: int + prefill_id: int + prefill_device_id: int + prefill_ip: str + prefill_port: int + +@dataclass +class PDTransLeaveInfo: + decode_id: int + prefill_id: int @dataclass class KVMoveTask: diff --git a/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/decode_node_impl/decode_kv_move_manager.py b/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/decode_node_impl/decode_kv_move_manager.py index 30096e3e5..0c46b6dd5 100644 --- a/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/decode_node_impl/decode_kv_move_manager.py +++ b/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/decode_node_impl/decode_kv_move_manager.py @@ -16,7 +16,7 @@ from .decode_infer_rpyc import PDDecodeInferRpcServer from ..task_queue import TaskQueue import torch.multiprocessing as mp -from lightllm.server.pd_io_struct import KVMoveTask, UpKVStatus +from lightllm.server.pd_io_struct import KVMoveTask, UpKVStatus, PDTransJoinInfo, PDTransLeaveInfo from lightllm.utils.retry_utils import retry import numpy as np from rpyc import AsyncResult @@ -33,12 +33,11 @@ @dataclass class TransProcessObj: - prefill_node_id: str = None - process: mp.Process = None + prefill_node_id: int = None task_in_queue: mp.Queue = None task_out_queue: mp.Queue = None - nccl_ip: str = None - nccl_port: str = None + prefill_ip: str = None + prefill_port: int = None device_index: int = None manager: "DecodeKVMoveManager" = None has_error: bool = False @@ -48,26 +47,31 @@ class TransProcessObj: put_to_radix_thread: threading.Thread = None latest_check_time: float = None - def create(self, prefill_node_id: str, nccl_ip: str, nccl_port: int, manager: "DecodeKVMoveManager"): - from .decode_trans_process import start_decode_trans_process + def create( + self, prefill_node_id: str, prefill_ip: str, prefill_port: int, manager: "DecodeKVMoveManager" 
+ ): - task_in_queue = mp.Queue() - task_out_queue = mp.Queue() device_index = manager.get_next_device_index() - proc = start_decode_trans_process( - manager.args, device_index, nccl_ip, nccl_port, task_in_queue, task_out_queue, manager.mem_queues - ) - assert task_out_queue.get(timeout=30) == "proc_start" - manager._put_mem_manager_to_mem_queue() - assert task_out_queue.get(timeout=60) == "get_mem_managers_ok" + decode_node_id = manager.args.pd_node_id + task_in_queue = manager.kv_trans_task_in_queue + task_out_queue = manager.kv_trans_task_out_queue + + task_in_queue.put(PDTransJoinInfo( + prefill_id=prefill_node_id, + prefill_device_id=-1, + prefill_ip=prefill_ip, + prefill_port=prefill_port, + decode_id=decode_node_id, + decode_device_id=device_index, + )) assert task_out_queue.get(timeout=60) == "nccl_ok" self.prefill_node_id = prefill_node_id - self.process = proc + self.decode_node_id = decode_node_id self.task_in_queue = task_in_queue self.task_out_queue = task_out_queue - self.nccl_ip = nccl_ip - self.nccl_port = nccl_port + self.prefill_ip = prefill_ip + self.prefill_port = prefill_port self.device_index = device_index self.manager = manager @@ -86,20 +90,6 @@ def create(self, prefill_node_id: str, nccl_ip: str, nccl_port: int, manager: "D self.put_to_radix_thread.start() return - def check_trans_process(self, raise_exception=True): - process = psutil.Process(self.process.pid) - if not (process.is_running() and process.status() != psutil.STATUS_ZOMBIE): - self.set_has_error() - if raise_exception: - raise Exception(f"trans process: {self.process.pid} is dead") - return - - def timer_to_check_status(self, raise_exception=True): - if time.time() - self.latest_check_time >= 2.0: - self.latest_check_time = time.time() - self.check_trans_process(raise_exception=raise_exception) - return - def _transfer_kv(self, move_tasks: List[KVMoveTask]): with self.manager.device_locks[self.device_index]: self.task_in_queue.put(move_tasks.copy(), timeout=10) @@ -130,8 +120,6 @@ def kv_move_loop(self): logger.info(f"{func_name} get task {task.to_decode_log_info()}") try: - self.timer_to_check_status(raise_exception=True) - if not kv_trans_use_p2p(): with self.manager.kv_trans_lock: self._transfer_kv(move_tasks) @@ -148,6 +136,10 @@ def kv_move_loop(self): self.manager.put_to_fail_release_task_queue(move_tasks) logger.error(f"{func_name} prefill id {self.prefill_node_id} device_index {self.device_index} thread quit") + self.task_in_queue.put(PDTransLeaveInfo( + decode_id=self.decode_node_id, + prefill_id=self.prefill_node_id + )) return def put_to_radix_loop(self): @@ -163,8 +155,6 @@ def put_to_radix_loop(self): try: # random to check stats - self.timer_to_check_status(raise_exception=True) - self.manager._put_kv_received_to_radix_cache(move_tasks.copy()) for task in move_tasks.copy(): logger.info( @@ -239,12 +229,6 @@ def __del__(self): logger.error(f"trans obj deled, prefill node id {self.prefill_node_id} device_index {self.device_index}") - # 强制关闭连接和杀掉传输进程 - if self.process is not None: - logger.warning(f"trans kv process {self.process.pid} is killed") - os.kill(self.process.pid, signal.SIGKILL) - pass - class DecodeKVMoveManager(rpyc.Service): def __init__(self, args, info_queue: mp.Queue, mem_queues: List[mp.Queue]): @@ -284,6 +268,18 @@ def __init__(self, args, info_queue: mp.Queue, mem_queues: List[mp.Queue]): self.kv_trans_lock = threading.Lock() # 需要每个卡有一个锁来规划每次只能有一个tran obj 操作对应显卡上的传输任务。 self.device_locks = [threading.Lock() for _ in range(self.node_world_size)] + + # start a single kv 
trans process + self.kv_trans_task_in_queue = mp.Queue() + self.kv_trans_task_out_queue = mp.Queue() + from .decode_trans_process import start_decode_trans_process + self.kv_trans_process = start_decode_trans_process( + self.args, self.kv_trans_task_in_queue, self.kv_trans_task_out_queue, self.mem_queues) + + assert self.kv_trans_task_out_queue.get(timeout=30) == "proc_start" + self._put_mem_manager_to_mem_queue() + assert self.kv_trans_task_out_queue.get(timeout=60) == "get_mem_managers_ok" + return def put_to_fail_release_task_queue(self, task: Union[KVMoveTask, List[KVMoveTask]]): @@ -392,17 +388,17 @@ def exposed_check_alive(self): # 用于 prefill node check 通信连接的状态。 return - def exposed_build_trans_process(self, prefill_node_id, nccl_ip, nccl_port, prefill_node_max_kv_trans_num): - prefill_node_id, nccl_ip, nccl_port, prefill_node_max_kv_trans_num = list( - map(obtain, [prefill_node_id, nccl_ip, nccl_port, prefill_node_max_kv_trans_num]) + def exposed_build_trans_process(self, prefill_node_id, prefill_ip, prefill_port, prefill_node_max_kv_trans_num): + prefill_node_id, prefill_ip, prefill_port, prefill_node_max_kv_trans_num = list( + map(obtain, [prefill_node_id, prefill_ip, prefill_port, prefill_node_max_kv_trans_num]) ) thread_local_data.prefill_node_id = prefill_node_id - logger.info(f"build trans infos {prefill_node_id} {nccl_ip} {nccl_port}") + logger.info(f"build trans infos {prefill_node_id} {prefill_ip} {prefill_port}") # 如果有历史残留,一并移除 self.remove_trans_obj(prefill_node_id) tran_obj = TransProcessObj() - tran_obj.create(prefill_node_id, nccl_ip, nccl_port, self) + tran_obj.create(prefill_node_id, prefill_ip, prefill_port, self) self.node_id_to_trans_obj[prefill_node_id] = tran_obj return min(prefill_node_max_kv_trans_num, self.args.max_total_token_num) @@ -499,10 +495,25 @@ def remove_trans_obj(self, prefill_node_id): trans_obj.set_has_error() return + def check_trans_process(self, raise_exception=True): + process = psutil.Process(self.kv_trans_process.pid) + if not (process.is_running() and process.status() != psutil.STATUS_ZOMBIE): + if raise_exception: + raise Exception(f"trans process: {self.kv_trans_process.pid} is dead") + return + def timer_loop(self): - while True: - self._unfrozen_time_out_reqs_tokens() - time.sleep(3.5) + try: + last_check_time = time.time() + while True: + self._unfrozen_time_out_reqs_tokens() + time.sleep(3.5) + if last_check_time - time.time() > 10.0: + self.check_trans_process() + last_check_time = time.time() + except (BaseException, RuntimeError) as e: + logger.exception(str(e)) + raise e def _init_env(args, info_queue: mp.Queue, mem_queues: List[mp.Queue], event: mp.Event): diff --git a/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/decode_node_impl/decode_trans_process.py b/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/decode_node_impl/decode_trans_process.py index b70bf8efe..010074b10 100644 --- a/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/decode_node_impl/decode_trans_process.py +++ b/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/decode_node_impl/decode_trans_process.py @@ -3,91 +3,100 @@ import sys import inspect import torch.multiprocessing as mp -from typing import List, Dict +from torch.distributed import TCPStore +from typing import List, Dict, Union from lightllm.utils.log_utils import init_logger from lightllm.common.mem_manager import MemoryManager -from lightllm.server.pd_io_struct import KVMoveTask +from 
lightllm.server.pd_io_struct import KVMoveTask, PDTransJoinInfo, PDTransLeaveInfo from lightllm.utils.device_utils import kv_trans_use_p2p from lightllm.utils.graceful_utils import graceful_registry +from lightllm.distributed.pynccl import PyNcclCommunicator, StatelessP2PProcessGroup logger = init_logger(__name__) +def _handle_kvmove_task(move_tasks: List[KVMoveTask], task_out_queue: mp.Queue, + mem_managers: List[MemoryManager], prefill_to_comm: Dict[int, PyNcclCommunicator], + dp_size_in_node: int): + total_move_kv_len = sum([task.move_kv_len for task in move_tasks]) + try: + prefill_id = move_tasks[0].prefill_node_id + device_index = prefill_to_comm[prefill_id].device.index + start = time.time() + if total_move_kv_len != 0: + cur_mem = mem_managers[device_index] + logger.info(f"trans start: {move_tasks[0].to_decode_log_info()}") + if kv_trans_use_p2p(): + cur_mem.receive_from_prefill_node_p2p(move_tasks, mem_managers, dp_size_in_node, prefill_to_comm[prefill_id]) + else: + cur_mem.receive_from_prefill_node(move_tasks, mem_managers, dp_size_in_node, prefill_to_comm[prefill_id]) + logger.info(f"trans finished: {move_tasks[0].to_decode_log_info()} move len: {total_move_kv_len}") + torch.cuda.synchronize() + logger.info(f"trans cost time: {(time.time() - start)}, {move_tasks[0].to_decode_log_info()}") + task_out_queue.put("ok") + except BaseException as e: + logger.exception(str(e)) + task_out_queue.put("fail") + raise e + +def _handle_prefill_join(node_info: PDTransJoinInfo, task_out_queue: mp.Queue, prefill_to_comm: Dict[int, PyNcclCommunicator]): + try: + store_client = TCPStore(host_name=node_info.prefill_ip, port=node_info.prefill_port, is_master=False, use_libuv=False) + group = StatelessP2PProcessGroup.create( + src_id=node_info.prefill_id, + dest_id=node_info.decode_id, + is_server=False, + store=store_client) + comm = PyNcclCommunicator(group, node_info.decode_device_id) + prefill_to_comm[node_info.prefill_id] = comm + logger.info(f"{node_info} kv trans connected") + task_out_queue.put('nccl_ok') + except Exception as e: + logger.warning(f"error while connect to prefill node: {e}") + def _init_env( args, - device_index: int, - nccl_ip, - nccl_port, task_in_queue: mp.Queue, task_out_queue: mp.Queue, - mem_queues: List[mp.Queue], -): - import os - - # os.environ["NCCL_DEBUG"] = "INFO" - os.environ["NCCL_MAX_NCHANNELS"] = "2" - os.environ["NCCL_NSOCKS_PER_CHANNEL"] = "1" - os.environ["NCCL_SOCKET_NTHREADS"] = "1" - torch.backends.cudnn.enabled = False + mem_queues: List[mp.Queue]): dp_size_in_node = max(1, args.dp // args.nnodes) node_world_size = args.tp // args.nnodes try: - # 注册graceful 退出的处理 graceful_registry(inspect.currentframe().f_code.co_name) - torch.cuda.set_device(device_index) - task_out_queue.put("proc_start") mem_managers: List[MemoryManager] = [mem_queue.get(timeout=60) for mem_queue in mem_queues] assert len(mem_managers) == node_world_size task_out_queue.put("get_mem_managers_ok") - import torch.distributed as dist - from datetime import timedelta - - dist.init_process_group( - "nccl", init_method=f"tcp://{nccl_ip}:{nccl_port}", rank=1, world_size=2, timeout=timedelta(seconds=60) - ) - task_out_queue.put("nccl_ok") + prefill_to_comm: Dict[int, PyNcclCommunicator] = {} while True: - move_tasks: List[KVMoveTask] = task_in_queue.get() - total_move_kv_len = sum([task.move_kv_len for task in move_tasks]) - try: - start = time.time() - if total_move_kv_len != 0: - cur_mem = mem_managers[device_index] - logger.info(f"trans start: {move_tasks[0].to_decode_log_info()}") - if 
kv_trans_use_p2p(): - cur_mem.receive_from_prefill_node_p2p(move_tasks, mem_managers, dp_size_in_node) - else: - cur_mem.receive_from_prefill_node(move_tasks, mem_managers, dp_size_in_node) - logger.info(f"trans finished: {move_tasks[0].to_decode_log_info()} move len: {total_move_kv_len}") - torch.cuda.synchronize() - logger.info(f"trans cost time: {(time.time() - start)}, {move_tasks[0].to_decode_log_info()}") - task_out_queue.put("ok") - except BaseException as e: - logger.exception(str(e)) - task_out_queue.put("fail") - raise e - except BaseException as e: - logger.exception(str(e)) - sys.exit(-1) - return + task: Union[List, PDTransJoinInfo, PDTransLeaveInfo] = task_in_queue.get() + if isinstance(task, List): + _handle_kvmove_task(task, task_out_queue, mem_managers, prefill_to_comm, dp_size_in_node) + elif isinstance(task, PDTransJoinInfo): + _handle_prefill_join(task, task_out_queue, prefill_to_comm) + elif isinstance(task, PDTransLeaveInfo): + prefill_to_comm[task.prefill_id].destroy() + logger.info(f"destory {task.prefill_id} nccl communicator.") + else: + logger.warning(f'unexpected task type: {task}') + + except Exception as e: + logger.error(f"Fatal error happened in kv trans process: {e}") + raise def start_decode_trans_process( args, - device_index: int, - nccl_ip, - nccl_port, task_in_queue: mp.Queue, task_out_queue: mp.Queue, mem_queues: List[mp.Queue], ): proc = mp.Process( - target=_init_env, args=(args, device_index, nccl_ip, nccl_port, task_in_queue, task_out_queue, mem_queues) + target=_init_env, args=(args, task_in_queue, task_out_queue, mem_queues) ) proc.start() assert proc.is_alive() - logger.info(f"decode trans kv process start, nccl_ip: {nccl_ip}, nccl_port: {nccl_port}") + logger.info(f"decode trans kv process start!") return proc diff --git a/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/prefill_node_impl/prefill_kv_move_manager.py b/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/prefill_node_impl/prefill_kv_move_manager.py index 27b0fbb19..fbff30d20 100644 --- a/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/prefill_node_impl/prefill_kv_move_manager.py +++ b/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/prefill_node_impl/prefill_kv_move_manager.py @@ -17,7 +17,7 @@ from .prefill_infer_rpyc import PDPrefillInferRpcServer from lightllm.common.mem_manager import MemoryManager import torch.multiprocessing as mp -from lightllm.server.pd_io_struct import KVMoveTask +from lightllm.server.pd_io_struct import KVMoveTask, PDTransJoinInfo, PDTransLeaveInfo from lightllm.utils.net_utils import find_available_port from lightllm.utils.retry_utils import retry from rpyc.utils.classic import obtain @@ -35,13 +35,10 @@ @dataclass class TransProcessObj: - decode_node_id: str = None + decode_node_id: int = None rpyc_conn: object = None # rpyc_con 的连接对象 - process: mp.Process = None task_in_queue: mp.Queue = None task_out_queue: mp.Queue = None - nccl_ip: str = None - nccl_port: str = None device_index: str = None # 使用的gpu序号 manager: "PrefillKVMoveManager" = None has_error: bool = False @@ -52,42 +49,38 @@ class TransProcessObj: latest_check_time: float = None def create( - self, decode_node_id: str, decode_node_ip: str, decode_node_rpyc_port: int, manager: "PrefillKVMoveManager" + self, decode_node_id: int, decode_node_ip: str, decode_node_rpyc_port: int, manager: "PrefillKVMoveManager" ): con = rpyc.connect( host=decode_node_ip, port=decode_node_rpyc_port, config={"allow_pickle": 
True}, keepalive=True ) - nccl_ip = manager.host_ip - nccl_port = find_available_port(manager.args.pd_p_allowed_port_min, manager.args.pd_p_allowed_port_max) - if nccl_port is None: - raise Exception("no pd nccl port can be used") - - from .prefill_trans_process import start_prefill_trans_process device_index = manager.get_next_device_index() # 分配 trans 进程使用的显卡 - task_in_queue = mp.Queue() - task_out_queue = mp.Queue() - proc = start_prefill_trans_process( - manager.args, device_index, nccl_ip, nccl_port, task_in_queue, task_out_queue, manager.mem_queues - ) - assert task_out_queue.get(timeout=30) == "proc_start" - manager._put_mem_manager_to_mem_queue() - assert task_out_queue.get(timeout=60) == "get_mem_managers_ok" prefill_node_id = manager.args.pd_node_id + task_in_queue = manager.kv_trans_task_in_queue + task_out_queue = manager.kv_trans_task_out_queue + + task_in_queue.put(PDTransJoinInfo( + prefill_id=prefill_node_id, + prefill_device_id=device_index, + prefill_ip=manager.host_ip, + prefill_port=manager.kv_trans_port, + decode_id=decode_node_id, + decode_device_id=-1 + )) + # 异步调用, 让decode节点建立与prefill节点进行nccl通信的进程 max_kv_trans_token_num = obtain( - con.root.build_trans_process(prefill_node_id, nccl_ip, nccl_port, manager.args.max_total_token_num) + con.root.build_trans_process(prefill_node_id, manager.host_ip, manager.kv_trans_port, manager.args.max_total_token_num) ) self.max_kv_trans_token_num = max_kv_trans_token_num assert task_out_queue.get(timeout=60) == "nccl_ok" self.decode_node_id = decode_node_id + self.prefill_node_id = prefill_node_id self.rpyc_conn = con - self.process = proc self.task_in_queue = task_in_queue self.task_out_queue = task_out_queue - self.nccl_port = nccl_port - self.nccl_ip = nccl_ip self.device_index = device_index self.manager = manager self.latest_check_time = time.time() @@ -114,13 +107,6 @@ def _get_request_tasks(self, datas: List[KVMoveTask]): break return ans_list - def check_trans_process(self, raise_exception=True): - process = psutil.Process(self.process.pid) - if not (process.is_running() and process.status() != psutil.STATUS_ZOMBIE): - self.set_has_error() - if raise_exception: - raise Exception(f"trans process: {self.process.pid} is dead") - return def check_connect(self, raise_exception=True): try: @@ -134,7 +120,6 @@ def check_connect(self, raise_exception=True): def timer_check_status(self, raise_exception=True): if time.time() - self.latest_check_time >= 2.0: self.latest_check_time = time.time() - self.check_trans_process(raise_exception=raise_exception) self.check_connect(raise_exception=raise_exception) if self.has_error: self.manager.remove_trans_obj(self.decode_node_id) @@ -249,6 +234,8 @@ def kv_trans_handle_loop(self): self.manager.put_to_release_task_queue(move_tasks) logger.error(f"trans kv thread, decode id {self.decode_node_id} device_index {self.device_index} thread quit") + self.task_in_queue.put(PDTransLeaveInfo( + decode_id=self.decode_node_id, prefill_id=self.prefill_node_id)) return def wait_thread_quit(self): @@ -302,12 +289,6 @@ def __del__(self): logger.error(f"trans obj deled, decode node id {self.decode_node_id} device_index {self.device_index}") - # 强制关闭连接和杀掉传输进程 - if self.process is not None: - logger.warning(f"prefill trans process {self.process.pid} is killed") - os.kill(self.process.pid, signal.SIGKILL) - pass - class PrefillKVMoveManager: def __init__(self, args, info_queue: mp.Queue, mem_queues: List[mp.Queue]): @@ -344,6 +325,19 @@ def __init__(self, args, info_queue: mp.Queue, mem_queues: List[mp.Queue]): 
self.release_task_queue = TaskQueue(lambda datas: datas[0:KV_MOVE_MAX_NUM], fail_func=None) self.release_tasks_thread = threading.Thread(target=self.handle_release_task_loop, daemon=True) self.release_tasks_thread.start() + + # start a single kv trans process + self.kv_trans_task_in_queue = mp.Queue() + self.kv_trans_task_out_queue = mp.Queue() + from .prefill_trans_process import start_decode_trans_process + self.kv_trans_port = find_available_port(self.args.pd_p_allowed_port_min, self.args.pd_p_allowed_port_max) + self.kv_trans_process = start_decode_trans_process( + self.args, self.host_ip, self.kv_trans_port, self.kv_trans_task_in_queue, self.kv_trans_task_out_queue, self.mem_queues) + + assert self.kv_trans_task_out_queue.get(timeout=30) == "proc_start" + self._put_mem_manager_to_mem_queue() + assert self.kv_trans_task_out_queue.get(timeout=60) == "get_mem_managers_ok" + return def put_to_release_task_queue(self, task: Union[KVMoveTask, List[KVMoveTask]]): @@ -364,6 +358,13 @@ def handle_release_task_loop(self): self._remove_req_refs_from_prompt_cache(handle_list) return + def check_trans_process(self, raise_exception=True): + process = psutil.Process(self.kv_trans_process.pid) + if not (process.is_running() and process.status() != psutil.STATUS_ZOMBIE): + if raise_exception: + raise Exception(f"trans process: {self.kv_trans_process.pid} is dead") + return + def get_next_device_index(self): counts = [0 for _ in range(self.node_world_size)] for obj in self.node_id_to_trans_obj.values(): @@ -403,6 +404,7 @@ def remove_dead_trans_obj(self): def task_dispatcher_loop(self): try: + last_check_time = time.time() # 获取任务,并分发给相关卡的处理队列 while True: move_task: KVMoveTask = self.info_queue.get() @@ -415,6 +417,10 @@ def task_dispatcher_loop(self): finally: trans_obj = None + if time.time() - last_check_time > 10.0: + self.check_trans_process() + last_check_time = time.time() + except (BaseException, RuntimeError) as e: logger.exception(str(e)) raise e diff --git a/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/prefill_node_impl/prefill_trans_process.py b/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/prefill_node_impl/prefill_trans_process.py index 1973aabac..b9b9b7242 100644 --- a/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/prefill_node_impl/prefill_trans_process.py +++ b/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/prefill_node_impl/prefill_trans_process.py @@ -3,97 +3,102 @@ import sys import inspect import torch.multiprocessing as mp -from typing import List, Dict +from torch.distributed import TCPStore +from typing import List, Dict, Union from lightllm.utils.log_utils import init_logger from lightllm.common.mem_manager import MemoryManager -from lightllm.server.pd_io_struct import KVMoveTask +from lightllm.server.pd_io_struct import KVMoveTask, PDTransJoinInfo, PDTransLeaveInfo from lightllm.utils.device_utils import kv_trans_use_p2p from lightllm.utils.graceful_utils import graceful_registry +from lightllm.distributed.pynccl import StatelessP2PProcessGroup, PyNcclCommunicator logger = init_logger(__name__) -# device_index 是用来指示,当前传输进程使用的用于数据传输的显卡id -# 当模型是多卡推理的时候,需要传输的 kv 需要先移动到 device_index -# 指定的显卡上,然后再进行传输,因为torch nccl 限制了只能操作一张显卡上的数据 +def _handle_kvmove_task(move_tasks: List[KVMoveTask], task_out_queue: mp.Queue, + mem_managers: List[MemoryManager], decode_to_comm: Dict[int, PyNcclCommunicator], + dp_size_in_node: int): + total_move_kv_len = sum([task.move_kv_len for task in 
move_tasks]) + try: + decode_id = move_tasks[0].decode_node.node_id + device_index = decode_to_comm[decode_id].device.index + torch.cuda.set_device(device_index) + start = time.time() + if total_move_kv_len != 0: + logger.info(f"trans start: {move_tasks[0].to_prefill_log_info()}") + cur_mem = mem_managers[device_index] + if kv_trans_use_p2p(): + cur_mem.send_to_decode_node_p2p(move_tasks, mem_managers, dp_size_in_node, decode_to_comm[decode_id]) + else: + cur_mem.send_to_decode_node(move_tasks, mem_managers, dp_size_in_node, decode_to_comm[decode_id]) + logger.info(f"trans finished: {move_tasks[0].to_prefill_log_info()} move len: {total_move_kv_len}") + torch.cuda.synchronize() + logger.info( + f"trans cost time: {(time.time() - start)}," + f"move_total_kv_len: {total_move_kv_len}, {move_tasks[0].to_prefill_log_info()}" + ) + task_out_queue.put("ok") + except BaseException as e: + logger.exception(str(e)) + task_out_queue.put("fail") + +def _handle_decode_join(node_info: PDTransJoinInfo, task_out_queue: mp.Queue, decode_to_comm: Dict[str, PyNcclCommunicator], store: TCPStore): + try: + group = StatelessP2PProcessGroup.create(node_info.prefill_id, node_info.decode_id, True, store) + comm = PyNcclCommunicator(group, node_info.prefill_device_id) + decode_to_comm[node_info.decode_id] = comm + logger.info(f"{node_info} kv trans connected!") + task_out_queue.put("nccl_ok") + except Exception as e: + logger.warning(f"error while connect to decode node: {e}") + def _init_env( args, - device_index: int, - nccl_ip, - nccl_port, + store_ip, + store_port, task_in_queue: mp.Queue, task_out_queue: mp.Queue, - mem_queues: List[mp.Queue], -): - import os - - # os.environ["NCCL_DEBUG"] = "INFO" - os.environ["NCCL_MAX_NCHANNELS"] = "2" - os.environ["NCCL_NSOCKS_PER_CHANNEL"] = "1" - os.environ["NCCL_SOCKET_NTHREADS"] = "1" - torch.backends.cudnn.enabled = False - - dp_size_in_node = max(1, args.dp // args.nnodes) - node_world_size = args.tp // args.nnodes - + mem_queues: List[mp.Queue],): try: - # 注册graceful 退出的处理 graceful_registry(inspect.currentframe().f_code.co_name) - torch.cuda.set_device(device_index) - + master_store = TCPStore(host_name=store_ip, port=store_port, is_master=True, use_libuv=True) + dp_size_in_node = max(1, args.dp // args.nnodes) + node_world_size = args.tp // args.nnodes task_out_queue.put("proc_start") mem_managers: List[MemoryManager] = [mem_queue.get(timeout=60) for mem_queue in mem_queues] assert len(mem_managers) == node_world_size task_out_queue.put("get_mem_managers_ok") - import torch.distributed as dist - from datetime import timedelta + decode_to_comm: Dict[int, PyNcclCommunicator] = {} - dist.init_process_group( - "nccl", init_method=f"tcp://{nccl_ip}:{nccl_port}", rank=0, world_size=2, timeout=timedelta(seconds=60) - ) - task_out_queue.put("nccl_ok") while True: - move_tasks: List[KVMoveTask] = task_in_queue.get() - total_move_kv_len = sum([task.move_kv_len for task in move_tasks]) - try: - start = time.time() - if total_move_kv_len != 0: - logger.info(f"trans start: {move_tasks[0].to_prefill_log_info()}") - cur_mem = mem_managers[device_index] - if kv_trans_use_p2p(): - cur_mem.send_to_decode_node_p2p(move_tasks, mem_managers, dp_size_in_node) - else: - cur_mem.send_to_decode_node(move_tasks, mem_managers, dp_size_in_node) - logger.info(f"trans finished: {move_tasks[0].to_prefill_log_info()} move len: {total_move_kv_len}") - torch.cuda.synchronize() - logger.info( - f"trans cost time: {(time.time() - start)}," - f"move_total_kv_len: {total_move_kv_len}, 
{move_tasks[0].to_prefill_log_info()}" - ) - task_out_queue.put("ok") - except BaseException as e: - logger.exception(str(e)) - task_out_queue.put("fail") - raise e - except BaseException as e: - logger.exception(str(e)) - sys.exit(-1) - return + task: Union[List, PDTransJoinInfo, PDTransLeaveInfo] = task_in_queue.get() + if isinstance(task, List): + _handle_kvmove_task(task, task_out_queue, mem_managers, decode_to_comm, dp_size_in_node) + elif isinstance(task, PDTransJoinInfo): + _handle_decode_join(task, task_out_queue, decode_to_comm, master_store) + elif isinstance(task, PDTransLeaveInfo): + decode_to_comm[task.decode_id].destroy() + logger.info(f"destory {task.decode_id} nccl communicator.") + else: + logger.warning(f'unexpected task type: {task}') + + except Exception as e: + logger.error(f"Fatal error happened in kv trans process: {e}") + pass -def start_prefill_trans_process( +def start_decode_trans_process( args, - device_index: int, - nccl_ip, - nccl_port, + store_ip, + store_port, task_in_queue: mp.Queue, task_out_queue: mp.Queue, mem_queues: List[mp.Queue], ): proc = mp.Process( - target=_init_env, args=(args, device_index, nccl_ip, nccl_port, task_in_queue, task_out_queue, mem_queues) + target=_init_env, args=(args, store_ip, store_port, task_in_queue, task_out_queue, mem_queues) ) proc.start() assert proc.is_alive() - logger.info(f"trans kv process start, nccl_ip: {nccl_ip}, nccl_port: {nccl_port}") - return proc + logger.info(f"trans kv process started!") + return proc \ No newline at end of file From 1063c2075acab77f14c07425ea7891b44bad8f0f Mon Sep 17 00:00:00 2001 From: Weichao Luo Date: Mon, 10 Mar 2025 13:02:01 +0800 Subject: [PATCH 2/4] fix name. --- .../pd_mode/prefill_node_impl/prefill_kv_move_manager.py | 4 ++-- .../pd_mode/prefill_node_impl/prefill_trans_process.py | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/prefill_node_impl/prefill_kv_move_manager.py b/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/prefill_node_impl/prefill_kv_move_manager.py index fbff30d20..e0c342654 100644 --- a/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/prefill_node_impl/prefill_kv_move_manager.py +++ b/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/prefill_node_impl/prefill_kv_move_manager.py @@ -329,9 +329,9 @@ def __init__(self, args, info_queue: mp.Queue, mem_queues: List[mp.Queue]): # start a single kv trans process self.kv_trans_task_in_queue = mp.Queue() self.kv_trans_task_out_queue = mp.Queue() - from .prefill_trans_process import start_decode_trans_process + from .prefill_trans_process import start_prefill_trans_process self.kv_trans_port = find_available_port(self.args.pd_p_allowed_port_min, self.args.pd_p_allowed_port_max) - self.kv_trans_process = start_decode_trans_process( + self.kv_trans_process = start_prefill_trans_process( self.args, self.host_ip, self.kv_trans_port, self.kv_trans_task_in_queue, self.kv_trans_task_out_queue, self.mem_queues) assert self.kv_trans_task_out_queue.get(timeout=30) == "proc_start" diff --git a/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/prefill_node_impl/prefill_trans_process.py b/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/prefill_node_impl/prefill_trans_process.py index b9b9b7242..b6fa0f032 100644 --- a/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/prefill_node_impl/prefill_trans_process.py 
+++ b/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/prefill_node_impl/prefill_trans_process.py @@ -87,7 +87,7 @@ def _init_env( pass -def start_decode_trans_process( +def start_prefill_trans_process( args, store_ip, store_port, From 2ec8673194a9d41c11d2cdc4074c5feff927ea49 Mon Sep 17 00:00:00 2001 From: Weichao Luo Date: Mon, 10 Mar 2025 13:59:13 +0800 Subject: [PATCH 3/4] fix style. --- lightllm/common/deepseek2_mem_manager.py | 28 ++- lightllm/common/mem_manager.py | 24 ++- lightllm/distributed/pynccl.py | 138 +++++++----- lightllm/distributed/pynccl_wrapper.py | 201 ++++++++++-------- lightllm/server/pd_io_struct.py | 3 + .../decode_kv_move_manager.py | 33 ++- .../decode_node_impl/decode_trans_process.py | 50 +++-- .../prefill_kv_move_manager.py | 37 ++-- .../prefill_trans_process.py | 30 ++- 9 files changed, 318 insertions(+), 226 deletions(-) diff --git a/lightllm/common/deepseek2_mem_manager.py b/lightllm/common/deepseek2_mem_manager.py index c7f437e96..0afb16048 100644 --- a/lightllm/common/deepseek2_mem_manager.py +++ b/lightllm/common/deepseek2_mem_manager.py @@ -41,8 +41,11 @@ def alloc_kv_move_buffer(self, max_req_total_len): return def send_to_decode_node( - self, move_tasks: List[KVMoveTask], mem_managers: List["Deepseek2MemoryManager"], dp_size_in_node: int, - nccl_comm: PyNcclCommunicator + self, + move_tasks: List[KVMoveTask], + mem_managers: List["Deepseek2MemoryManager"], + dp_size_in_node: int, + nccl_comm: PyNcclCommunicator, ): assert dp_size_in_node == 1 @@ -68,8 +71,11 @@ def _get_kv_move_data(self, token_indexes: List[int], layer_index: int): return move_buffer def receive_from_prefill_node( - self, move_tasks: List[KVMoveTask], mem_managers: List["MemoryManager"], dp_size_in_node: int, - nccl_comm: PyNcclCommunicator + self, + move_tasks: List[KVMoveTask], + mem_managers: List["MemoryManager"], + dp_size_in_node: int, + nccl_comm: PyNcclCommunicator, ): assert dp_size_in_node == 1 @@ -101,8 +107,11 @@ def _write_kv_move_data(self, token_indexes: torch.Tensor, buffer_tensor: torch. 
return def send_to_decode_node_p2p( - self, move_tasks: List[KVMoveTask], mem_managers: List["MemoryManager"], dp_size_in_node: int, - nccl_comm: PyNcclCommunicator + self, + move_tasks: List[KVMoveTask], + mem_managers: List["MemoryManager"], + dp_size_in_node: int, + nccl_comm: PyNcclCommunicator, ): """ 使用 p2p triton kernel 进行数据复制和传输的实现方式。 @@ -130,8 +139,11 @@ def _get_kv_move_data_p2p(self, token_indexes: torch.Tensor, layer_index: int, k return move_buffer def receive_from_prefill_node_p2p( - self, move_tasks: List[KVMoveTask], mem_managers: List["MemoryManager"], dp_size_in_node: int, - nccl_comm: PyNcclCommunicator + self, + move_tasks: List[KVMoveTask], + mem_managers: List["MemoryManager"], + dp_size_in_node: int, + nccl_comm: PyNcclCommunicator, ): assert dp_size_in_node == 1 diff --git a/lightllm/common/mem_manager.py b/lightllm/common/mem_manager.py index 79edd3cdc..7c4dd35ef 100755 --- a/lightllm/common/mem_manager.py +++ b/lightllm/common/mem_manager.py @@ -87,8 +87,11 @@ def alloc_kv_move_buffer(self, max_req_total_len): return def send_to_decode_node( - self, move_tasks: List[KVMoveTask], mem_managers: List["MemoryManager"], dp_size_in_node: int, - nccl_comm: PyNcclCommunicator + self, + move_tasks: List[KVMoveTask], + mem_managers: List["MemoryManager"], + dp_size_in_node: int, + nccl_comm: PyNcclCommunicator, ): assert dp_size_in_node == 1 @@ -124,7 +127,10 @@ def _get_kv_move_data(self, token_indexes: List[int], layer_index: int): return move_buffer def receive_from_prefill_node( - self, move_tasks: List[KVMoveTask], mem_managers: List["MemoryManager"], dp_size_in_node: int, + self, + move_tasks: List[KVMoveTask], + mem_managers: List["MemoryManager"], + dp_size_in_node: int, nccl_comm: PyNcclCommunicator, ): assert dp_size_in_node == 1 @@ -158,8 +164,11 @@ def _write_kv_move_data(self, token_indexes: torch.Tensor, buffer_tensor: torch. 
return def send_to_decode_node_p2p( - self, move_tasks: List[KVMoveTask], mem_managers: List["MemoryManager"], dp_size_in_node: int, - nccl_comm: PyNcclCommunicator + self, + move_tasks: List[KVMoveTask], + mem_managers: List["MemoryManager"], + dp_size_in_node: int, + nccl_comm: PyNcclCommunicator, ): """ 使用 p2p triton kernel 进行数据复制和传输的实现方式。 @@ -190,7 +199,10 @@ def _get_kv_move_data_p2p(self, token_indexes: torch.Tensor, layer_index: int, k return move_buffer def receive_from_prefill_node_p2p( - self, move_tasks: List[KVMoveTask], mem_managers: List["MemoryManager"], dp_size_in_node: int, + self, + move_tasks: List[KVMoveTask], + mem_managers: List["MemoryManager"], + dp_size_in_node: int, nccl_comm: PyNcclCommunicator, ): assert dp_size_in_node == 1 diff --git a/lightllm/distributed/pynccl.py b/lightllm/distributed/pynccl.py index 9a01dd116..3637b04dd 100644 --- a/lightllm/distributed/pynccl.py +++ b/lightllm/distributed/pynccl.py @@ -33,19 +33,27 @@ from torch.distributed import ProcessGroup, ReduceOp, TCPStore from lightllm.distributed.pynccl_wrapper import ( - NCCLLibrary, buffer_type, cudaStream_t, ncclComm_t, ncclDataTypeEnum, - ncclRedOpTypeEnum, ncclUniqueId) + NCCLLibrary, + buffer_type, + cudaStream_t, + ncclComm_t, + ncclDataTypeEnum, + ncclRedOpTypeEnum, + ncclUniqueId, +) logger = logging.getLogger(__name__) _current_stream = None + def current_stream() -> torch.cuda.Stream: global _current_stream if _current_stream is None: _current_stream = torch.cuda.current_stream() return _current_stream + @dataclasses.dataclass class StatelessP2PProcessGroup: """A dataclass to hold a metadata store, and the rank, world_size of the @@ -94,18 +102,13 @@ def expire_data(self): def recv_obj(self) -> Any: """Receive an object from a source rank.""" - obj = pickle.loads( - self.store.get( - f"send_to/{self.dest_id}/{self.recv_src_counter}")) + obj = pickle.loads(self.store.get(f"send_to/{self.dest_id}/{self.recv_src_counter}")) self.recv_src_counter += 1 return obj @staticmethod def create( - src_id: int, - dest_id: int, - is_server: bool, - store: torch._C._distributed_c10d.Store + src_id: int, dest_id: int, is_server: bool, store: torch._C._distributed_c10d.Store ) -> "StatelessP2PProcessGroup": """A replacement for `torch.distributed.init_process_group` that does not pollute the global state. @@ -121,12 +124,11 @@ def create( used for exchanging metadata. With this function, process A and process B can call `StatelessProcessGroup.create` to form a group, and then process A, B, C, and D can call `StatelessProcessGroup.create` to form another group. - """ # noqa + """ # noqa return StatelessP2PProcessGroup(src_id=src_id, dest_id=dest_id, is_server=is_server, store=store) class PyNcclCommunicator: - def __init__( self, group: Union[ProcessGroup, StatelessP2PProcessGroup], @@ -146,8 +148,9 @@ def __init__( """ if not isinstance(group, StatelessP2PProcessGroup): assert dist.is_initialized() - assert dist.get_backend(group) != dist.Backend.NCCL, ( - "PyNcclCommunicator should be attached to a non-NCCL group.") + assert ( + dist.get_backend(group) != dist.Backend.NCCL + ), "PyNcclCommunicator should be attached to a non-NCCL group." 
# note: this rank is the rank in the group self.rank = dist.get_rank(group) self.world_size = dist.get_world_size(group) @@ -207,8 +210,7 @@ def __init__( # `torch.cuda.device` is a context manager that changes the # current cuda device to the specified one with torch.cuda.device(device): - self.comm: ncclComm_t = self.nccl.ncclCommInitRank( - self.world_size, self.unique_id, self.rank) + self.comm: ncclComm_t = self.nccl.ncclCommInitRank(self.world_size, self.unique_id, self.rank) stream = current_stream() # A small all_reduce for warmup. @@ -220,10 +222,7 @@ def __init__( def destroy(self): self.nccl.ncclCommDestroy(self.comm) - def all_reduce(self, - in_tensor: torch.Tensor, - op: ReduceOp = ReduceOp.SUM, - stream=None) -> torch.Tensor: + def all_reduce(self, in_tensor: torch.Tensor, op: ReduceOp = ReduceOp.SUM, stream=None) -> torch.Tensor: if self.disabled: return None # nccl communicator created on a specific device @@ -231,24 +230,25 @@ def all_reduce(self, # otherwise it will cause "illegal memory access" assert in_tensor.device == self.device, ( f"this nccl communicator is created to work on {self.device}, " - f"but the input tensor is on {in_tensor.device}") + f"but the input tensor is on {in_tensor.device}" + ) out_tensor = torch.empty_like(in_tensor) if stream is None: stream = current_stream() - self.nccl.ncclAllReduce(buffer_type(in_tensor.data_ptr()), - buffer_type(out_tensor.data_ptr()), - in_tensor.numel(), - ncclDataTypeEnum.from_torch(in_tensor.dtype), - ncclRedOpTypeEnum.from_torch(op), self.comm, - cudaStream_t(stream.cuda_stream)) + self.nccl.ncclAllReduce( + buffer_type(in_tensor.data_ptr()), + buffer_type(out_tensor.data_ptr()), + in_tensor.numel(), + ncclDataTypeEnum.from_torch(in_tensor.dtype), + ncclRedOpTypeEnum.from_torch(op), + self.comm, + cudaStream_t(stream.cuda_stream), + ) return out_tensor - def all_gather(self, - output_tensor: torch.Tensor, - input_tensor: torch.Tensor, - stream=None): + def all_gather(self, output_tensor: torch.Tensor, input_tensor: torch.Tensor, stream=None): if self.disabled: return # nccl communicator created on a specific device @@ -256,20 +256,22 @@ def all_gather(self, # otherwise it will cause "illegal memory access" assert input_tensor.device == self.device, ( f"this nccl communicator is created to work on {self.device}, " - f"but the input tensor is on {input_tensor.device}") + f"but the input tensor is on {input_tensor.device}" + ) if stream is None: stream = current_stream() self.nccl.ncclAllGather( buffer_type(input_tensor.data_ptr()), - buffer_type(output_tensor.data_ptr()), input_tensor.numel(), - ncclDataTypeEnum.from_torch(input_tensor.dtype), self.comm, - cudaStream_t(stream.cuda_stream)) - - def reduce_scatter(self, - output_tensor: torch.Tensor, - input_tensor: torch.Tensor, - op: ReduceOp = ReduceOp.SUM, - stream=None): + buffer_type(output_tensor.data_ptr()), + input_tensor.numel(), + ncclDataTypeEnum.from_torch(input_tensor.dtype), + self.comm, + cudaStream_t(stream.cuda_stream), + ) + + def reduce_scatter( + self, output_tensor: torch.Tensor, input_tensor: torch.Tensor, op: ReduceOp = ReduceOp.SUM, stream=None + ): if self.disabled: return # nccl communicator created on a specific device @@ -277,46 +279,63 @@ def reduce_scatter(self, # otherwise it will cause "illegal memory access" assert input_tensor.device == self.device, ( f"this nccl communicator is created to work on {self.device}, " - f"but the input tensor is on {input_tensor.device}") + f"but the input tensor is on {input_tensor.device}" + ) if stream 
is None: stream = current_stream() self.nccl.ncclReduceScatter( buffer_type(input_tensor.data_ptr()), - buffer_type(output_tensor.data_ptr()), output_tensor.numel(), + buffer_type(output_tensor.data_ptr()), + output_tensor.numel(), ncclDataTypeEnum.from_torch(input_tensor.dtype), - ncclRedOpTypeEnum.from_torch(op), self.comm, - cudaStream_t(stream.cuda_stream)) + ncclRedOpTypeEnum.from_torch(op), + self.comm, + cudaStream_t(stream.cuda_stream), + ) def send(self, tensor: torch.Tensor, dst: int, stream=None): if self.disabled: return assert tensor.device == self.device, ( f"this nccl communicator is created to work on {self.device}, " - f"but the input tensor is on {tensor.device}") + f"but the input tensor is on {tensor.device}" + ) if stream is None: stream = current_stream() - self.nccl.ncclSend(buffer_type(tensor.data_ptr()), tensor.numel(), - ncclDataTypeEnum.from_torch(tensor.dtype), dst, - self.comm, cudaStream_t(stream.cuda_stream)) + self.nccl.ncclSend( + buffer_type(tensor.data_ptr()), + tensor.numel(), + ncclDataTypeEnum.from_torch(tensor.dtype), + dst, + self.comm, + cudaStream_t(stream.cuda_stream), + ) def recv(self, tensor: torch.Tensor, src: int, stream=None): if self.disabled: return assert tensor.device == self.device, ( f"this nccl communicator is created to work on {self.device}, " - f"but the input tensor is on {tensor.device}") + f"but the input tensor is on {tensor.device}" + ) if stream is None: stream = current_stream() - self.nccl.ncclRecv(buffer_type(tensor.data_ptr()), tensor.numel(), - ncclDataTypeEnum.from_torch(tensor.dtype), src, - self.comm, cudaStream_t(stream.cuda_stream)) + self.nccl.ncclRecv( + buffer_type(tensor.data_ptr()), + tensor.numel(), + ncclDataTypeEnum.from_torch(tensor.dtype), + src, + self.comm, + cudaStream_t(stream.cuda_stream), + ) def broadcast(self, tensor: torch.Tensor, src: int, stream=None): if self.disabled: return assert tensor.device == self.device, ( f"this nccl communicator is created to work on {self.device}, " - f"but the input tensor is on {tensor.device}") + f"but the input tensor is on {tensor.device}" + ) if stream is None: stream = current_stream() if src == self.rank: @@ -326,7 +345,12 @@ def broadcast(self, tensor: torch.Tensor, src: int, stream=None): else: sendbuff = buffer_type() recvbuff = buffer_type(tensor.data_ptr()) - self.nccl.ncclBroadcast(sendbuff, recvbuff, tensor.numel(), - ncclDataTypeEnum.from_torch(tensor.dtype), src, - self.comm, cudaStream_t(stream.cuda_stream)) - + self.nccl.ncclBroadcast( + sendbuff, + recvbuff, + tensor.numel(), + ncclDataTypeEnum.from_torch(tensor.dtype), + src, + self.comm, + cudaStream_t(stream.cuda_stream), + ) diff --git a/lightllm/distributed/pynccl_wrapper.py b/lightllm/distributed/pynccl_wrapper.py index d35ec8e3a..344689d96 100644 --- a/lightllm/distributed/pynccl_wrapper.py +++ b/lightllm/distributed/pynccl_wrapper.py @@ -64,9 +64,7 @@ def find_nccl_library() -> str: # manually load the nccl library if so_file: - logger.info( - "Found nccl from environment variable VLLM_NCCL_SO_PATH=%s", - so_file) + logger.info("Found nccl from environment variable VLLM_NCCL_SO_PATH=%s", so_file) else: if torch.version.cuda is not None: so_file = "libnccl.so.2" @@ -173,77 +171,74 @@ class NCCLLibrary: # const char* ncclGetErrorString(ncclResult_t result) Function("ncclGetErrorString", ctypes.c_char_p, [ncclResult_t]), # ncclResult_t ncclGetVersion(int *version); - Function("ncclGetVersion", ncclResult_t, - [ctypes.POINTER(ctypes.c_int)]), + Function("ncclGetVersion", ncclResult_t, 
[ctypes.POINTER(ctypes.c_int)]), # ncclResult_t ncclGetUniqueId(ncclUniqueId* uniqueId); - Function("ncclGetUniqueId", ncclResult_t, - [ctypes.POINTER(ncclUniqueId)]), + Function("ncclGetUniqueId", ncclResult_t, [ctypes.POINTER(ncclUniqueId)]), # ncclResult_t ncclCommInitRank( # ncclComm_t* comm, int nranks, ncclUniqueId commId, int rank); # note that ncclComm_t is a pointer type, so the first argument # is a pointer to a pointer - Function("ncclCommInitRank", ncclResult_t, [ - ctypes.POINTER(ncclComm_t), ctypes.c_int, ncclUniqueId, - ctypes.c_int - ]), + Function( + "ncclCommInitRank", ncclResult_t, [ctypes.POINTER(ncclComm_t), ctypes.c_int, ncclUniqueId, ctypes.c_int] + ), # ncclResult_t ncclAllReduce( # const void* sendbuff, void* recvbuff, size_t count, # ncclDataType_t datatype, ncclRedOp_t op, ncclComm_t comm, # cudaStream_t stream); # note that cudaStream_t is a pointer type, so the last argument # is a pointer - Function("ncclAllReduce", ncclResult_t, [ - buffer_type, buffer_type, ctypes.c_size_t, ncclDataType_t, - ncclRedOp_t, ncclComm_t, cudaStream_t - ]), - + Function( + "ncclAllReduce", + ncclResult_t, + [buffer_type, buffer_type, ctypes.c_size_t, ncclDataType_t, ncclRedOp_t, ncclComm_t, cudaStream_t], + ), # ncclResult_t ncclAllGather( # const void* sendbuff, void* recvbuff, size_t count, # ncclDataType_t datatype, ncclComm_t comm, # cudaStream_t stream); # note that cudaStream_t is a pointer type, so the last argument # is a pointer - Function("ncclAllGather", ncclResult_t, [ - buffer_type, buffer_type, ctypes.c_size_t, ncclDataType_t, - ncclComm_t, cudaStream_t - ]), - + Function( + "ncclAllGather", + ncclResult_t, + [buffer_type, buffer_type, ctypes.c_size_t, ncclDataType_t, ncclComm_t, cudaStream_t], + ), # ncclResult_t ncclReduceScatter( # const void* sendbuff, void* recvbuff, size_t count, # ncclDataType_t datatype, ncclRedOp_t op, ncclComm_t comm, # cudaStream_t stream); # note that cudaStream_t is a pointer type, so the last argument # is a pointer - Function("ncclReduceScatter", ncclResult_t, [ - buffer_type, buffer_type, ctypes.c_size_t, ncclDataType_t, - ncclRedOp_t, ncclComm_t, cudaStream_t - ]), - + Function( + "ncclReduceScatter", + ncclResult_t, + [buffer_type, buffer_type, ctypes.c_size_t, ncclDataType_t, ncclRedOp_t, ncclComm_t, cudaStream_t], + ), # ncclResult_t ncclSend( # const void* sendbuff, size_t count, ncclDataType_t datatype, # int dest, ncclComm_t comm, cudaStream_t stream); - Function("ncclSend", ncclResult_t, [ - buffer_type, ctypes.c_size_t, ncclDataType_t, ctypes.c_int, - ncclComm_t, cudaStream_t - ]), - + Function( + "ncclSend", + ncclResult_t, + [buffer_type, ctypes.c_size_t, ncclDataType_t, ctypes.c_int, ncclComm_t, cudaStream_t], + ), # ncclResult_t ncclRecv( # void* recvbuff, size_t count, ncclDataType_t datatype, # int src, ncclComm_t comm, cudaStream_t stream); - Function("ncclRecv", ncclResult_t, [ - buffer_type, ctypes.c_size_t, ncclDataType_t, ctypes.c_int, - ncclComm_t, cudaStream_t - ]), - + Function( + "ncclRecv", + ncclResult_t, + [buffer_type, ctypes.c_size_t, ncclDataType_t, ctypes.c_int, ncclComm_t, cudaStream_t], + ), # ncclResult_t ncclBroadcast( # const void* sendbuff, void* recvbuff, size_t count, # ncclDataType_t datatype, int root, ncclComm_t comm, # cudaStream_t stream); - Function("ncclBroadcast", ncclResult_t, [ - buffer_type, buffer_type, ctypes.c_size_t, ncclDataType_t, - ctypes.c_int, ncclComm_t, cudaStream_t - ]), - + Function( + "ncclBroadcast", + ncclResult_t, + [buffer_type, buffer_type, ctypes.c_size_t, 
ncclDataType_t, ctypes.c_int, ncclComm_t, cudaStream_t], + ), # be cautious! this is a collective call, it will block until all # processes in the communicator have called this function. # because Python object destruction can happen in random order, @@ -277,8 +272,10 @@ def __init__(self, so_file: Optional[str] = None): "or it does not support the current platform %s. " "If you already have the library, please set the " "environment variable VLLM_NCCL_SO_PATH" - " to point to the correct nccl library path.", so_file, - platform.platform()) + " to point to the correct nccl library path.", + so_file, + platform.platform(), + ) raise e if so_file not in NCCLLibrary.path_to_dict_mapping: @@ -311,80 +308,100 @@ def ncclGetVersion(self) -> str: def ncclGetUniqueId(self) -> ncclUniqueId: unique_id = ncclUniqueId() - self.NCCL_CHECK(self._funcs["ncclGetUniqueId"]( - ctypes.byref(unique_id))) + self.NCCL_CHECK(self._funcs["ncclGetUniqueId"](ctypes.byref(unique_id))) return unique_id - def ncclCommInitRank(self, world_size: int, unique_id: ncclUniqueId, - rank: int) -> ncclComm_t: + def ncclCommInitRank(self, world_size: int, unique_id: ncclUniqueId, rank: int) -> ncclComm_t: comm = ncclComm_t() - self.NCCL_CHECK(self._funcs["ncclCommInitRank"](ctypes.byref(comm), - world_size, unique_id, - rank)) + self.NCCL_CHECK(self._funcs["ncclCommInitRank"](ctypes.byref(comm), world_size, unique_id, rank)) return comm - def ncclAllReduce(self, sendbuff: buffer_type, recvbuff: buffer_type, - count: int, datatype: int, op: int, comm: ncclComm_t, - stream: cudaStream_t) -> None: + def ncclAllReduce( + self, + sendbuff: buffer_type, + recvbuff: buffer_type, + count: int, + datatype: int, + op: int, + comm: ncclComm_t, + stream: cudaStream_t, + ) -> None: # `datatype` actually should be `ncclDataType_t` # and `op` should be `ncclRedOp_t` # both are aliases of `ctypes.c_int` # when we pass int to a function, it will be converted to `ctypes.c_int` # by ctypes automatically - self.NCCL_CHECK(self._funcs["ncclAllReduce"](sendbuff, recvbuff, count, - datatype, op, comm, - stream)) - - def ncclReduceScatter(self, sendbuff: buffer_type, recvbuff: buffer_type, - count: int, datatype: int, op: int, comm: ncclComm_t, - stream: cudaStream_t) -> None: + self.NCCL_CHECK(self._funcs["ncclAllReduce"](sendbuff, recvbuff, count, datatype, op, comm, stream)) + + def ncclReduceScatter( + self, + sendbuff: buffer_type, + recvbuff: buffer_type, + count: int, + datatype: int, + op: int, + comm: ncclComm_t, + stream: cudaStream_t, + ) -> None: # `datatype` actually should be `ncclDataType_t` # and `op` should be `ncclRedOp_t` # both are aliases of `ctypes.c_int` # when we pass int to a function, it will be converted to `ctypes.c_int` # by ctypes automatically - self.NCCL_CHECK(self._funcs["ncclReduceScatter"](sendbuff, recvbuff, - count, datatype, op, - comm, stream)) - - def ncclAllGather(self, sendbuff: buffer_type, recvbuff: buffer_type, - count: int, datatype: int, comm: ncclComm_t, - stream: cudaStream_t) -> None: + self.NCCL_CHECK(self._funcs["ncclReduceScatter"](sendbuff, recvbuff, count, datatype, op, comm, stream)) + + def ncclAllGather( + self, + sendbuff: buffer_type, + recvbuff: buffer_type, + count: int, + datatype: int, + comm: ncclComm_t, + stream: cudaStream_t, + ) -> None: # `datatype` actually should be `ncclDataType_t` # which is an aliases of `ctypes.c_int` # when we pass int to a function, it will be converted to `ctypes.c_int` # by ctypes automatically - self.NCCL_CHECK(self._funcs["ncclAllGather"](sendbuff, 
recvbuff, count, - datatype, comm, stream)) - - def ncclSend(self, sendbuff: buffer_type, count: int, datatype: int, - dest: int, comm: ncclComm_t, stream: cudaStream_t) -> None: - self.NCCL_CHECK(self._funcs["ncclSend"](sendbuff, count, datatype, - dest, comm, stream)) - - def ncclRecv(self, recvbuff: buffer_type, count: int, datatype: int, - src: int, comm: ncclComm_t, stream: cudaStream_t) -> None: - self.NCCL_CHECK(self._funcs["ncclRecv"](recvbuff, count, datatype, src, - comm, stream)) - - def ncclBroadcast(self, sendbuff: buffer_type, recvbuff: buffer_type, - count: int, datatype: int, root: int, comm: ncclComm_t, - stream: cudaStream_t) -> None: - self.NCCL_CHECK(self._funcs["ncclBroadcast"](sendbuff, recvbuff, count, - datatype, root, comm, - stream)) + self.NCCL_CHECK(self._funcs["ncclAllGather"](sendbuff, recvbuff, count, datatype, comm, stream)) + + def ncclSend( + self, sendbuff: buffer_type, count: int, datatype: int, dest: int, comm: ncclComm_t, stream: cudaStream_t + ) -> None: + self.NCCL_CHECK(self._funcs["ncclSend"](sendbuff, count, datatype, dest, comm, stream)) + + def ncclRecv( + self, recvbuff: buffer_type, count: int, datatype: int, src: int, comm: ncclComm_t, stream: cudaStream_t + ) -> None: + self.NCCL_CHECK(self._funcs["ncclRecv"](recvbuff, count, datatype, src, comm, stream)) + + def ncclBroadcast( + self, + sendbuff: buffer_type, + recvbuff: buffer_type, + count: int, + datatype: int, + root: int, + comm: ncclComm_t, + stream: cudaStream_t, + ) -> None: + self.NCCL_CHECK(self._funcs["ncclBroadcast"](sendbuff, recvbuff, count, datatype, root, comm, stream)) def ncclCommDestroy(self, comm: ncclComm_t) -> None: self.NCCL_CHECK(self._funcs["ncclCommDestroy"](comm)) __all__ = [ - "NCCLLibrary", "ncclDataTypeEnum", "ncclRedOpTypeEnum", "ncclUniqueId", - "ncclComm_t", "cudaStream_t", "buffer_type" + "NCCLLibrary", + "ncclDataTypeEnum", + "ncclRedOpTypeEnum", + "ncclUniqueId", + "ncclComm_t", + "cudaStream_t", + "buffer_type", ] - def test_ncclGetUniqueId(): lib = NCCLLibrary() unique_id = lib.ncclGetUniqueId() @@ -399,7 +416,9 @@ def test_ncclGetUniqueId(): # as long as the function doesn't raise an exception, we're good assert unique_id is not None -if __name__ == '__main__': - import torch; + +if __name__ == "__main__": + import torch + torch.cuda.set_device(0) test_ncclGetUniqueId() diff --git a/lightllm/server/pd_io_struct.py b/lightllm/server/pd_io_struct.py index 222cd5887..d63a38e99 100644 --- a/lightllm/server/pd_io_struct.py +++ b/lightllm/server/pd_io_struct.py @@ -74,6 +74,7 @@ class DecodeNodeInfo: rpyc_port: str max_new_tokens: int + @dataclass class PDTransJoinInfo: decode_id: int @@ -83,11 +84,13 @@ class PDTransJoinInfo: prefill_ip: str prefill_port: int + @dataclass class PDTransLeaveInfo: decode_id: int prefill_id: int + @dataclass class KVMoveTask: group_request_id: int diff --git a/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/decode_node_impl/decode_kv_move_manager.py b/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/decode_node_impl/decode_kv_move_manager.py index 0c46b6dd5..8af8952fa 100644 --- a/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/decode_node_impl/decode_kv_move_manager.py +++ b/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/decode_node_impl/decode_kv_move_manager.py @@ -47,23 +47,23 @@ class TransProcessObj: put_to_radix_thread: threading.Thread = None latest_check_time: float = None - def create( - self, prefill_node_id: str, 
prefill_ip: str, prefill_port: int, manager: "DecodeKVMoveManager" - ): + def create(self, prefill_node_id: str, prefill_ip: str, prefill_port: int, manager: "DecodeKVMoveManager"): device_index = manager.get_next_device_index() decode_node_id = manager.args.pd_node_id task_in_queue = manager.kv_trans_task_in_queue task_out_queue = manager.kv_trans_task_out_queue - task_in_queue.put(PDTransJoinInfo( - prefill_id=prefill_node_id, - prefill_device_id=-1, - prefill_ip=prefill_ip, - prefill_port=prefill_port, - decode_id=decode_node_id, - decode_device_id=device_index, - )) + task_in_queue.put( + PDTransJoinInfo( + prefill_id=prefill_node_id, + prefill_device_id=-1, + prefill_ip=prefill_ip, + prefill_port=prefill_port, + decode_id=decode_node_id, + decode_device_id=device_index, + ) + ) assert task_out_queue.get(timeout=60) == "nccl_ok" self.prefill_node_id = prefill_node_id @@ -136,10 +136,7 @@ def kv_move_loop(self): self.manager.put_to_fail_release_task_queue(move_tasks) logger.error(f"{func_name} prefill id {self.prefill_node_id} device_index {self.device_index} thread quit") - self.task_in_queue.put(PDTransLeaveInfo( - decode_id=self.decode_node_id, - prefill_id=self.prefill_node_id - )) + self.task_in_queue.put(PDTransLeaveInfo(decode_id=self.decode_node_id, prefill_id=self.prefill_node_id)) return def put_to_radix_loop(self): @@ -269,12 +266,14 @@ def __init__(self, args, info_queue: mp.Queue, mem_queues: List[mp.Queue]): # 需要每个卡有一个锁来规划每次只能有一个tran obj 操作对应显卡上的传输任务。 self.device_locks = [threading.Lock() for _ in range(self.node_world_size)] - # start a single kv trans process + # start a single kv trans process self.kv_trans_task_in_queue = mp.Queue() self.kv_trans_task_out_queue = mp.Queue() from .decode_trans_process import start_decode_trans_process + self.kv_trans_process = start_decode_trans_process( - self.args, self.kv_trans_task_in_queue, self.kv_trans_task_out_queue, self.mem_queues) + self.args, self.kv_trans_task_in_queue, self.kv_trans_task_out_queue, self.mem_queues + ) assert self.kv_trans_task_out_queue.get(timeout=30) == "proc_start" self._put_mem_manager_to_mem_queue() diff --git a/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/decode_node_impl/decode_trans_process.py b/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/decode_node_impl/decode_trans_process.py index 010074b10..7f3f9c676 100644 --- a/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/decode_node_impl/decode_trans_process.py +++ b/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/decode_node_impl/decode_trans_process.py @@ -15,9 +15,13 @@ logger = init_logger(__name__) -def _handle_kvmove_task(move_tasks: List[KVMoveTask], task_out_queue: mp.Queue, - mem_managers: List[MemoryManager], prefill_to_comm: Dict[int, PyNcclCommunicator], - dp_size_in_node: int): +def _handle_kvmove_task( + move_tasks: List[KVMoveTask], + task_out_queue: mp.Queue, + mem_managers: List[MemoryManager], + prefill_to_comm: Dict[int, PyNcclCommunicator], + dp_size_in_node: int, +): total_move_kv_len = sum([task.move_kv_len for task in move_tasks]) try: prefill_id = move_tasks[0].prefill_node_id @@ -27,9 +31,13 @@ def _handle_kvmove_task(move_tasks: List[KVMoveTask], task_out_queue: mp.Queue, cur_mem = mem_managers[device_index] logger.info(f"trans start: {move_tasks[0].to_decode_log_info()}") if kv_trans_use_p2p(): - cur_mem.receive_from_prefill_node_p2p(move_tasks, mem_managers, dp_size_in_node, prefill_to_comm[prefill_id]) + 
cur_mem.receive_from_prefill_node_p2p( + move_tasks, mem_managers, dp_size_in_node, prefill_to_comm[prefill_id] + ) else: - cur_mem.receive_from_prefill_node(move_tasks, mem_managers, dp_size_in_node, prefill_to_comm[prefill_id]) + cur_mem.receive_from_prefill_node( + move_tasks, mem_managers, dp_size_in_node, prefill_to_comm[prefill_id] + ) logger.info(f"trans finished: {move_tasks[0].to_decode_log_info()} move len: {total_move_kv_len}") torch.cuda.synchronize() logger.info(f"trans cost time: {(time.time() - start)}, {move_tasks[0].to_decode_log_info()}") @@ -39,26 +47,26 @@ def _handle_kvmove_task(move_tasks: List[KVMoveTask], task_out_queue: mp.Queue, task_out_queue.put("fail") raise e -def _handle_prefill_join(node_info: PDTransJoinInfo, task_out_queue: mp.Queue, prefill_to_comm: Dict[int, PyNcclCommunicator]): + +def _handle_prefill_join( + node_info: PDTransJoinInfo, task_out_queue: mp.Queue, prefill_to_comm: Dict[int, PyNcclCommunicator] +): try: - store_client = TCPStore(host_name=node_info.prefill_ip, port=node_info.prefill_port, is_master=False, use_libuv=False) + store_client = TCPStore( + host_name=node_info.prefill_ip, port=node_info.prefill_port, is_master=False, use_libuv=False + ) group = StatelessP2PProcessGroup.create( - src_id=node_info.prefill_id, - dest_id=node_info.decode_id, - is_server=False, - store=store_client) + src_id=node_info.prefill_id, dest_id=node_info.decode_id, is_server=False, store=store_client + ) comm = PyNcclCommunicator(group, node_info.decode_device_id) prefill_to_comm[node_info.prefill_id] = comm logger.info(f"{node_info} kv trans connected") - task_out_queue.put('nccl_ok') + task_out_queue.put("nccl_ok") except Exception as e: logger.warning(f"error while connect to prefill node: {e}") -def _init_env( - args, - task_in_queue: mp.Queue, - task_out_queue: mp.Queue, - mem_queues: List[mp.Queue]): + +def _init_env(args, task_in_queue: mp.Queue, task_out_queue: mp.Queue, mem_queues: List[mp.Queue]): dp_size_in_node = max(1, args.dp // args.nnodes) node_world_size = args.tp // args.nnodes @@ -80,7 +88,7 @@ def _init_env( prefill_to_comm[task.prefill_id].destroy() logger.info(f"destory {task.prefill_id} nccl communicator.") else: - logger.warning(f'unexpected task type: {task}') + logger.warning(f"unexpected task type: {task}") except Exception as e: logger.error(f"Fatal error happened in kv trans process: {e}") @@ -93,10 +101,8 @@ def start_decode_trans_process( task_out_queue: mp.Queue, mem_queues: List[mp.Queue], ): - proc = mp.Process( - target=_init_env, args=(args, task_in_queue, task_out_queue, mem_queues) - ) + proc = mp.Process(target=_init_env, args=(args, task_in_queue, task_out_queue, mem_queues)) proc.start() assert proc.is_alive() - logger.info(f"decode trans kv process start!") + logger.info("decode trans kv process start!") return proc diff --git a/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/prefill_node_impl/prefill_kv_move_manager.py b/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/prefill_node_impl/prefill_kv_move_manager.py index e0c342654..5c4e946cf 100644 --- a/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/prefill_node_impl/prefill_kv_move_manager.py +++ b/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/prefill_node_impl/prefill_kv_move_manager.py @@ -60,18 +60,22 @@ def create( task_in_queue = manager.kv_trans_task_in_queue task_out_queue = manager.kv_trans_task_out_queue - task_in_queue.put(PDTransJoinInfo( - 
prefill_id=prefill_node_id, - prefill_device_id=device_index, - prefill_ip=manager.host_ip, - prefill_port=manager.kv_trans_port, - decode_id=decode_node_id, - decode_device_id=-1 - )) + task_in_queue.put( + PDTransJoinInfo( + prefill_id=prefill_node_id, + prefill_device_id=device_index, + prefill_ip=manager.host_ip, + prefill_port=manager.kv_trans_port, + decode_id=decode_node_id, + decode_device_id=-1, + ) + ) # 异步调用, 让decode节点建立与prefill节点进行nccl通信的进程 max_kv_trans_token_num = obtain( - con.root.build_trans_process(prefill_node_id, manager.host_ip, manager.kv_trans_port, manager.args.max_total_token_num) + con.root.build_trans_process( + prefill_node_id, manager.host_ip, manager.kv_trans_port, manager.args.max_total_token_num + ) ) self.max_kv_trans_token_num = max_kv_trans_token_num assert task_out_queue.get(timeout=60) == "nccl_ok" @@ -107,7 +111,6 @@ def _get_request_tasks(self, datas: List[KVMoveTask]): break return ans_list - def check_connect(self, raise_exception=True): try: self.rpyc_conn.root.check_alive() @@ -234,8 +237,7 @@ def kv_trans_handle_loop(self): self.manager.put_to_release_task_queue(move_tasks) logger.error(f"trans kv thread, decode id {self.decode_node_id} device_index {self.device_index} thread quit") - self.task_in_queue.put(PDTransLeaveInfo( - decode_id=self.decode_node_id, prefill_id=self.prefill_node_id)) + self.task_in_queue.put(PDTransLeaveInfo(decode_id=self.decode_node_id, prefill_id=self.prefill_node_id)) return def wait_thread_quit(self): @@ -326,13 +328,20 @@ def __init__(self, args, info_queue: mp.Queue, mem_queues: List[mp.Queue]): self.release_tasks_thread = threading.Thread(target=self.handle_release_task_loop, daemon=True) self.release_tasks_thread.start() - # start a single kv trans process + # start a single kv trans process self.kv_trans_task_in_queue = mp.Queue() self.kv_trans_task_out_queue = mp.Queue() from .prefill_trans_process import start_prefill_trans_process + self.kv_trans_port = find_available_port(self.args.pd_p_allowed_port_min, self.args.pd_p_allowed_port_max) self.kv_trans_process = start_prefill_trans_process( - self.args, self.host_ip, self.kv_trans_port, self.kv_trans_task_in_queue, self.kv_trans_task_out_queue, self.mem_queues) + self.args, + self.host_ip, + self.kv_trans_port, + self.kv_trans_task_in_queue, + self.kv_trans_task_out_queue, + self.mem_queues, + ) assert self.kv_trans_task_out_queue.get(timeout=30) == "proc_start" self._put_mem_manager_to_mem_queue() diff --git a/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/prefill_node_impl/prefill_trans_process.py b/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/prefill_node_impl/prefill_trans_process.py index b6fa0f032..c6def3b3c 100644 --- a/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/prefill_node_impl/prefill_trans_process.py +++ b/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/prefill_node_impl/prefill_trans_process.py @@ -15,9 +15,14 @@ logger = init_logger(__name__) -def _handle_kvmove_task(move_tasks: List[KVMoveTask], task_out_queue: mp.Queue, - mem_managers: List[MemoryManager], decode_to_comm: Dict[int, PyNcclCommunicator], - dp_size_in_node: int): + +def _handle_kvmove_task( + move_tasks: List[KVMoveTask], + task_out_queue: mp.Queue, + mem_managers: List[MemoryManager], + decode_to_comm: Dict[int, PyNcclCommunicator], + dp_size_in_node: int, +): total_move_kv_len = sum([task.move_kv_len for task in move_tasks]) try: decode_id = move_tasks[0].decode_node.node_id @@ 
-42,7 +47,10 @@ def _handle_kvmove_task(move_tasks: List[KVMoveTask], task_out_queue: mp.Queue, logger.exception(str(e)) task_out_queue.put("fail") -def _handle_decode_join(node_info: PDTransJoinInfo, task_out_queue: mp.Queue, decode_to_comm: Dict[str, PyNcclCommunicator], store: TCPStore): + +def _handle_decode_join( + node_info: PDTransJoinInfo, task_out_queue: mp.Queue, decode_to_comm: Dict[str, PyNcclCommunicator], store: TCPStore +): try: group = StatelessP2PProcessGroup.create(node_info.prefill_id, node_info.decode_id, True, store) comm = PyNcclCommunicator(group, node_info.prefill_device_id) @@ -52,13 +60,15 @@ def _handle_decode_join(node_info: PDTransJoinInfo, task_out_queue: mp.Queue, de except Exception as e: logger.warning(f"error while connect to decode node: {e}") + def _init_env( args, store_ip, store_port, task_in_queue: mp.Queue, task_out_queue: mp.Queue, - mem_queues: List[mp.Queue],): + mem_queues: List[mp.Queue], +): try: graceful_registry(inspect.currentframe().f_code.co_name) master_store = TCPStore(host_name=store_ip, port=store_port, is_master=True, use_libuv=True) @@ -80,7 +90,7 @@ def _init_env( decode_to_comm[task.decode_id].destroy() logger.info(f"destory {task.decode_id} nccl communicator.") else: - logger.warning(f'unexpected task type: {task}') + logger.warning(f"unexpected task type: {task}") except Exception as e: logger.error(f"Fatal error happened in kv trans process: {e}") @@ -95,10 +105,8 @@ def start_prefill_trans_process( task_out_queue: mp.Queue, mem_queues: List[mp.Queue], ): - proc = mp.Process( - target=_init_env, args=(args, store_ip, store_port, task_in_queue, task_out_queue, mem_queues) - ) + proc = mp.Process(target=_init_env, args=(args, store_ip, store_port, task_in_queue, task_out_queue, mem_queues)) proc.start() assert proc.is_alive() - logger.info(f"trans kv process started!") - return proc \ No newline at end of file + logger.info("prefill trans kv process started!") + return proc From 1d248b0d7f01a4ab2a35266c09993060b77f9f6d Mon Sep 17 00:00:00 2001 From: Weichao Luo Date: Tue, 11 Mar 2025 17:57:08 +0800 Subject: [PATCH 4/4] one kv trans process per tp. 
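
This patch replaces the single shared kv-trans process with one process per tp device,
each owning its own task_in/task_out queue pair (and, on the prefill side, its own
TCPStore port), and makes device selection skip devices whose trans process has died.
Below is a minimal sketch, not the committed code, of that selection policy; the helper
name pick_device is made up for illustration, while the dead-process penalty mirrors the
get_next_device_index change in the diff that follows:

    import numpy as np

    def pick_device(process_alive, trans_objs, node_world_size):
        # a dead trans process gets a huge penalty so argmin prefers live devices
        counts = [0 if process_alive[d] else (1 << 20) for d in range(node_world_size)]
        for obj in trans_objs:
            counts[obj.device_index] += 1
        return int(np.argmin(counts))

check_trans_process now keeps a per-device alive flag and only raises once every trans
process has died, so a single dead device no longer brings the whole manager down.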
--- lightllm/distributed/pynccl.py | 71 ----------------- .../decode_kv_move_manager.py | 59 ++++++++++---- .../decode_node_impl/decode_trans_process.py | 12 +-- .../prefill_kv_move_manager.py | 79 +++++++++++++------ .../prefill_trans_process.py | 12 +-- 5 files changed, 110 insertions(+), 123 deletions(-) diff --git a/lightllm/distributed/pynccl.py b/lightllm/distributed/pynccl.py index 3637b04dd..b96e0d1ba 100644 --- a/lightllm/distributed/pynccl.py +++ b/lightllm/distributed/pynccl.py @@ -248,51 +248,6 @@ def all_reduce(self, in_tensor: torch.Tensor, op: ReduceOp = ReduceOp.SUM, strea ) return out_tensor - def all_gather(self, output_tensor: torch.Tensor, input_tensor: torch.Tensor, stream=None): - if self.disabled: - return - # nccl communicator created on a specific device - # will only work on tensors on the same device - # otherwise it will cause "illegal memory access" - assert input_tensor.device == self.device, ( - f"this nccl communicator is created to work on {self.device}, " - f"but the input tensor is on {input_tensor.device}" - ) - if stream is None: - stream = current_stream() - self.nccl.ncclAllGather( - buffer_type(input_tensor.data_ptr()), - buffer_type(output_tensor.data_ptr()), - input_tensor.numel(), - ncclDataTypeEnum.from_torch(input_tensor.dtype), - self.comm, - cudaStream_t(stream.cuda_stream), - ) - - def reduce_scatter( - self, output_tensor: torch.Tensor, input_tensor: torch.Tensor, op: ReduceOp = ReduceOp.SUM, stream=None - ): - if self.disabled: - return - # nccl communicator created on a specific device - # will only work on tensors on the same device - # otherwise it will cause "illegal memory access" - assert input_tensor.device == self.device, ( - f"this nccl communicator is created to work on {self.device}, " - f"but the input tensor is on {input_tensor.device}" - ) - if stream is None: - stream = current_stream() - self.nccl.ncclReduceScatter( - buffer_type(input_tensor.data_ptr()), - buffer_type(output_tensor.data_ptr()), - output_tensor.numel(), - ncclDataTypeEnum.from_torch(input_tensor.dtype), - ncclRedOpTypeEnum.from_torch(op), - self.comm, - cudaStream_t(stream.cuda_stream), - ) - def send(self, tensor: torch.Tensor, dst: int, stream=None): if self.disabled: return @@ -328,29 +283,3 @@ def recv(self, tensor: torch.Tensor, src: int, stream=None): self.comm, cudaStream_t(stream.cuda_stream), ) - - def broadcast(self, tensor: torch.Tensor, src: int, stream=None): - if self.disabled: - return - assert tensor.device == self.device, ( - f"this nccl communicator is created to work on {self.device}, " - f"but the input tensor is on {tensor.device}" - ) - if stream is None: - stream = current_stream() - if src == self.rank: - sendbuff = buffer_type(tensor.data_ptr()) - # NCCL requires the sender also to have a receive buffer - recvbuff = buffer_type(tensor.data_ptr()) - else: - sendbuff = buffer_type() - recvbuff = buffer_type(tensor.data_ptr()) - self.nccl.ncclBroadcast( - sendbuff, - recvbuff, - tensor.numel(), - ncclDataTypeEnum.from_torch(tensor.dtype), - src, - self.comm, - cudaStream_t(stream.cuda_stream), - ) diff --git a/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/decode_node_impl/decode_kv_move_manager.py b/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/decode_node_impl/decode_kv_move_manager.py index 8af8952fa..7a6f120cb 100644 --- a/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/decode_node_impl/decode_kv_move_manager.py +++ 
b/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/decode_node_impl/decode_kv_move_manager.py @@ -51,8 +51,8 @@ def create(self, prefill_node_id: str, prefill_ip: str, prefill_port: int, manag device_index = manager.get_next_device_index() decode_node_id = manager.args.pd_node_id - task_in_queue = manager.kv_trans_task_in_queue - task_out_queue = manager.kv_trans_task_out_queue + task_in_queue = manager.kv_trans_task_in_queues[device_index] + task_out_queue = manager.kv_trans_task_out_queues[device_index] task_in_queue.put( PDTransJoinInfo( @@ -136,7 +136,6 @@ def kv_move_loop(self): self.manager.put_to_fail_release_task_queue(move_tasks) logger.error(f"{func_name} prefill id {self.prefill_node_id} device_index {self.device_index} thread quit") - self.task_in_queue.put(PDTransLeaveInfo(decode_id=self.decode_node_id, prefill_id=self.prefill_node_id)) return def put_to_radix_loop(self): @@ -217,6 +216,7 @@ def __del__(self): try: self.set_has_error() self.wait_thread_quit() + self.task_in_queue.put(PDTransLeaveInfo(decode_id=self.decode_node_id, prefill_id=self.prefill_node_id)) if self.ready_to_move_queue is not None: self.ready_to_move_queue.clear_tasks() if self.move_finished_queue is not None: @@ -266,18 +266,31 @@ def __init__(self, args, info_queue: mp.Queue, mem_queues: List[mp.Queue]): # 需要每个卡有一个锁来规划每次只能有一个tran obj 操作对应显卡上的传输任务。 self.device_locks = [threading.Lock() for _ in range(self.node_world_size)] - # start a single kv trans process - self.kv_trans_task_in_queue = mp.Queue() - self.kv_trans_task_out_queue = mp.Queue() from .decode_trans_process import start_decode_trans_process - self.kv_trans_process = start_decode_trans_process( - self.args, self.kv_trans_task_in_queue, self.kv_trans_task_out_queue, self.mem_queues - ) + self.kv_trans_processes = [] + self.kv_trans_task_in_queues = [] + self.kv_trans_task_out_queues = [] + self.kv_trans_process_alive = [] + + for device_index in range(self.node_world_size): + kv_trans_task_in_queue = mp.Queue() + kv_trans_task_out_queue = mp.Queue() + kv_trans_process = start_decode_trans_process( + self.args, + device_index, + kv_trans_task_in_queue, + kv_trans_task_out_queue, + self.mem_queues, + ) + assert kv_trans_task_out_queue.get(timeout=30) == "proc_start" + self._put_mem_manager_to_mem_queue() + assert kv_trans_task_out_queue.get(timeout=60) == "get_mem_managers_ok" - assert self.kv_trans_task_out_queue.get(timeout=30) == "proc_start" - self._put_mem_manager_to_mem_queue() - assert self.kv_trans_task_out_queue.get(timeout=60) == "get_mem_managers_ok" + self.kv_trans_processes.append(kv_trans_process) + self.kv_trans_task_in_queues.append(kv_trans_task_in_queue) + self.kv_trans_task_out_queues.append(kv_trans_task_out_queue) + self.kv_trans_process_alive.append(True) return @@ -462,7 +475,9 @@ def exposed_request_data_transfer(self, tasks: List[KVMoveTask]) -> List[Optiona return ans_list def get_next_device_index(self): - counts = [0 for _ in range(self.node_world_size)] + counts = [ + 0 if self.kv_trans_process_alive[device_id] else (1 << 20) for device_id in range(self.node_world_size) + ] for obj in self.node_id_to_trans_obj.values(): counts[obj.device_index] += 1 device_index = int(np.argmin(counts)) @@ -495,10 +510,22 @@ def remove_trans_obj(self, prefill_node_id): return def check_trans_process(self, raise_exception=True): - process = psutil.Process(self.kv_trans_process.pid) - if not (process.is_running() and process.status() != psutil.STATUS_ZOMBIE): + at_least_one_alive = False + for device_id in 
range(self.node_world_size): + if not self.kv_trans_process_alive[device_id]: + continue + + process = psutil.Process(self.kv_trans_processes[device_id].pid) + if not (process.is_running() and process.status() != psutil.STATUS_ZOMBIE): + self.kv_trans_process_alive[device_id] = False + logger.error(f"kv trans process for device: {device_id} dead!!!") + else: + at_least_one_alive = True + + if not at_least_one_alive: if raise_exception: - raise Exception(f"trans process: {self.kv_trans_process.pid} is dead") + raise Exception("All trans process are dead!!!") + return def timer_loop(self): diff --git a/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/decode_node_impl/decode_trans_process.py b/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/decode_node_impl/decode_trans_process.py index 7f3f9c676..100b05eaf 100644 --- a/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/decode_node_impl/decode_trans_process.py +++ b/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/decode_node_impl/decode_trans_process.py @@ -66,16 +66,17 @@ def _handle_prefill_join( logger.warning(f"error while connect to prefill node: {e}") -def _init_env(args, task_in_queue: mp.Queue, task_out_queue: mp.Queue, mem_queues: List[mp.Queue]): +def _init_env(args, device_id: int, task_in_queue: mp.Queue, task_out_queue: mp.Queue, mem_queues: List[mp.Queue]): dp_size_in_node = max(1, args.dp // args.nnodes) - node_world_size = args.tp // args.nnodes try: + torch.cuda.set_device(device_id) graceful_registry(inspect.currentframe().f_code.co_name) task_out_queue.put("proc_start") + mem_managers: List[MemoryManager] = [mem_queue.get(timeout=60) for mem_queue in mem_queues] - assert len(mem_managers) == node_world_size + task_out_queue.put("get_mem_managers_ok") prefill_to_comm: Dict[int, PyNcclCommunicator] = {} while True: @@ -97,12 +98,13 @@ def _init_env(args, task_in_queue: mp.Queue, task_out_queue: mp.Queue, mem_queue def start_decode_trans_process( args, + device_id: int, task_in_queue: mp.Queue, task_out_queue: mp.Queue, mem_queues: List[mp.Queue], ): - proc = mp.Process(target=_init_env, args=(args, task_in_queue, task_out_queue, mem_queues)) + proc = mp.Process(target=_init_env, args=(args, device_id, task_in_queue, task_out_queue, mem_queues)) proc.start() assert proc.is_alive() - logger.info("decode trans kv process start!") + logger.info(f"decode trans kv process for device: {device_id} start!") return proc diff --git a/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/prefill_node_impl/prefill_kv_move_manager.py b/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/prefill_node_impl/prefill_kv_move_manager.py index 5c4e946cf..5ebce1021 100644 --- a/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/prefill_node_impl/prefill_kv_move_manager.py +++ b/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/prefill_node_impl/prefill_kv_move_manager.py @@ -39,7 +39,7 @@ class TransProcessObj: rpyc_conn: object = None # rpyc_con 的连接对象 task_in_queue: mp.Queue = None task_out_queue: mp.Queue = None - device_index: str = None # 使用的gpu序号 + device_index: int = None # 使用的gpu序号 manager: "PrefillKVMoveManager" = None has_error: bool = False request_kv_trans_task_queue: TaskQueue = None @@ -57,15 +57,15 @@ def create( device_index = manager.get_next_device_index() # 分配 trans 进程使用的显卡 prefill_node_id = manager.args.pd_node_id - task_in_queue = 
manager.kv_trans_task_in_queue - task_out_queue = manager.kv_trans_task_out_queue + task_in_queue = manager.kv_trans_task_in_queues[device_index] + task_out_queue = manager.kv_trans_task_out_queues[device_index] task_in_queue.put( PDTransJoinInfo( prefill_id=prefill_node_id, prefill_device_id=device_index, prefill_ip=manager.host_ip, - prefill_port=manager.kv_trans_port, + prefill_port=manager.kv_trans_ports[device_index], decode_id=decode_node_id, decode_device_id=-1, ) @@ -74,7 +74,7 @@ def create( # 异步调用, 让decode节点建立与prefill节点进行nccl通信的进程 max_kv_trans_token_num = obtain( con.root.build_trans_process( - prefill_node_id, manager.host_ip, manager.kv_trans_port, manager.args.max_total_token_num + prefill_node_id, manager.host_ip, manager.kv_trans_ports[device_index], manager.args.max_total_token_num ) ) self.max_kv_trans_token_num = max_kv_trans_token_num @@ -237,7 +237,6 @@ def kv_trans_handle_loop(self): self.manager.put_to_release_task_queue(move_tasks) logger.error(f"trans kv thread, decode id {self.decode_node_id} device_index {self.device_index} thread quit") - self.task_in_queue.put(PDTransLeaveInfo(decode_id=self.decode_node_id, prefill_id=self.prefill_node_id)) return def wait_thread_quit(self): @@ -282,6 +281,7 @@ def __del__(self): try: self.set_has_error() self.wait_thread_quit() + self.task_in_queue.put(PDTransLeaveInfo(decode_id=self.decode_node_id, prefill_id=self.prefill_node_id)) if self.request_kv_trans_task_queue is not None: self.request_kv_trans_task_queue.clear_tasks() if self.ready_kv_trans_task_queue is not None: @@ -329,24 +329,37 @@ def __init__(self, args, info_queue: mp.Queue, mem_queues: List[mp.Queue]): self.release_tasks_thread.start() # start a single kv trans process - self.kv_trans_task_in_queue = mp.Queue() - self.kv_trans_task_out_queue = mp.Queue() - from .prefill_trans_process import start_prefill_trans_process - - self.kv_trans_port = find_available_port(self.args.pd_p_allowed_port_min, self.args.pd_p_allowed_port_max) - self.kv_trans_process = start_prefill_trans_process( - self.args, - self.host_ip, - self.kv_trans_port, - self.kv_trans_task_in_queue, - self.kv_trans_task_out_queue, - self.mem_queues, - ) - assert self.kv_trans_task_out_queue.get(timeout=30) == "proc_start" - self._put_mem_manager_to_mem_queue() - assert self.kv_trans_task_out_queue.get(timeout=60) == "get_mem_managers_ok" + from .prefill_trans_process import start_prefill_trans_process + self.kv_trans_ports = [] + self.kv_trans_processes = [] + self.kv_trans_task_in_queues = [] + self.kv_trans_task_out_queues = [] + self.kv_trans_process_alive = [] + + for device_id in range(self.node_world_size): + kv_trans_task_in_queue = mp.Queue() + kv_trans_task_out_queue = mp.Queue() + kv_trans_port = find_available_port(self.args.pd_p_allowed_port_min, self.args.pd_p_allowed_port_max) + kv_trans_process = start_prefill_trans_process( + self.args, + self.host_ip, + kv_trans_port, + device_id, + kv_trans_task_in_queue, + kv_trans_task_out_queue, + self.mem_queues, + ) + assert kv_trans_task_out_queue.get(timeout=30) == "proc_start" + self._put_mem_manager_to_mem_queue() + assert kv_trans_task_out_queue.get(timeout=60) == "get_mem_managers_ok" + + self.kv_trans_ports.append(kv_trans_port) + self.kv_trans_processes.append(kv_trans_process) + self.kv_trans_task_in_queues.append(kv_trans_task_in_queue) + self.kv_trans_task_out_queues.append(kv_trans_task_out_queue) + self.kv_trans_process_alive.append(True) return def put_to_release_task_queue(self, task: Union[KVMoveTask, List[KVMoveTask]]): @@ 
-368,14 +381,28 @@ def handle_release_task_loop(self): return def check_trans_process(self, raise_exception=True): - process = psutil.Process(self.kv_trans_process.pid) - if not (process.is_running() and process.status() != psutil.STATUS_ZOMBIE): + at_least_one_alive = False + for device_id in range(self.node_world_size): + if not self.kv_trans_process_alive[device_id]: + continue + + process = psutil.Process(self.kv_trans_processes[device_id].pid) + if not (process.is_running() and process.status() != psutil.STATUS_ZOMBIE): + self.kv_trans_process_alive[device_id] = False + logger.error(f"kv trans process for device: {device_id} dead!!!") + else: + at_least_one_alive = True + + if not at_least_one_alive: if raise_exception: - raise Exception(f"trans process: {self.kv_trans_process.pid} is dead") + raise Exception("All trans process are dead!!!") + return def get_next_device_index(self): - counts = [0 for _ in range(self.node_world_size)] + counts = [ + 0 if self.kv_trans_process_alive[device_id] else (1 << 20) for device_id in range(self.node_world_size) + ] for obj in self.node_id_to_trans_obj.values(): counts[obj.device_index] += 1 device_index = int(np.argmin(counts)) diff --git a/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/prefill_node_impl/prefill_trans_process.py b/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/prefill_node_impl/prefill_trans_process.py index c6def3b3c..62327a11c 100644 --- a/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/prefill_node_impl/prefill_trans_process.py +++ b/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/prefill_node_impl/prefill_trans_process.py @@ -27,7 +27,6 @@ def _handle_kvmove_task( try: decode_id = move_tasks[0].decode_node.node_id device_index = decode_to_comm[decode_id].device.index - torch.cuda.set_device(device_index) start = time.time() if total_move_kv_len != 0: logger.info(f"trans start: {move_tasks[0].to_prefill_log_info()}") @@ -65,18 +64,18 @@ def _init_env( args, store_ip, store_port, + device_id, task_in_queue: mp.Queue, task_out_queue: mp.Queue, mem_queues: List[mp.Queue], ): try: + torch.cuda.set_device(device_id) graceful_registry(inspect.currentframe().f_code.co_name) master_store = TCPStore(host_name=store_ip, port=store_port, is_master=True, use_libuv=True) dp_size_in_node = max(1, args.dp // args.nnodes) - node_world_size = args.tp // args.nnodes task_out_queue.put("proc_start") mem_managers: List[MemoryManager] = [mem_queue.get(timeout=60) for mem_queue in mem_queues] - assert len(mem_managers) == node_world_size task_out_queue.put("get_mem_managers_ok") decode_to_comm: Dict[int, PyNcclCommunicator] = {} @@ -101,12 +100,15 @@ def start_prefill_trans_process( args, store_ip, store_port, + device_id, task_in_queue: mp.Queue, task_out_queue: mp.Queue, mem_queues: List[mp.Queue], ): - proc = mp.Process(target=_init_env, args=(args, store_ip, store_port, task_in_queue, task_out_queue, mem_queues)) + proc = mp.Process( + target=_init_env, args=(args, store_ip, store_port, device_id, task_in_queue, task_out_queue, mem_queues) + ) proc.start() assert proc.is_alive() - logger.info("prefill trans kv process started!") + logger.info(f"prefill trans kv process for device: {device_id} started!") return proc
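
Taken together, the series swaps the global torch.distributed NCCL group for a two-rank,
per-connection PyNcclCommunicator that rendezvouses through a TCPStore owned by the
prefill trans process. Below is a minimal usage sketch of that rendezvous, not code from
the patches: the node ids, host, port, device id, tensor shape and dtype are placeholders,
and the rank assignment (prefill as rank 0, decode as rank 1) is an assumption made for
illustration.

    import torch
    from torch.distributed import TCPStore
    from lightllm.distributed.pynccl import StatelessP2PProcessGroup, PyNcclCommunicator

    PREFILL_ID, DECODE_ID = 1000, 2000  # placeholder pd_node_id values

    def prefill_side(host="127.0.0.1", port=29600, device_id=0):
        # the prefill trans process owns the store and acts as the server side
        store = TCPStore(host_name=host, port=port, is_master=True, use_libuv=True)
        group = StatelessP2PProcessGroup.create(PREFILL_ID, DECODE_ID, True, store)
        comm = PyNcclCommunicator(group, device_id)
        buf = torch.ones(16, dtype=torch.float16, device=f"cuda:{device_id}")
        comm.send(buf, dst=1)  # decode side is assumed to be rank 1

    def decode_side(host="127.0.0.1", port=29600, device_id=0):
        # the decode trans process connects to the prefill-owned store as a client
        store = TCPStore(host_name=host, port=port, is_master=False, use_libuv=False)
        group = StatelessP2PProcessGroup.create(PREFILL_ID, DECODE_ID, False, store)
        comm = PyNcclCommunicator(group, device_id)
        buf = torch.empty(16, dtype=torch.float16, device=f"cuda:{device_id}")
        comm.recv(buf, src=0)  # prefill side is assumed to be rank 0

In the committed code the same calls are driven from _handle_prefill_join and
_handle_decode_join, and the resulting communicator is handed to the memory managers'
send/receive paths for the layer-by-layer kv transfer.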