From f1f66fe331b5ee6e94aa8538eabd938eba6db9dd Mon Sep 17 00:00:00 2001 From: hiworldwzj <30762946+hiworldwzj@users.noreply.github.com> Date: Mon, 18 Mar 2024 15:35:45 +0800 Subject: [PATCH] [Feature] Dynamic prompt cache (#356) Something like RadixAttention . link https://lmsys.org/blog/2024-01-17-sglang/ --------- Co-authored-by: wangzaijun Co-authored-by: shihaobai <42648726+shihaobai@users.noreply.github.com> Co-authored-by: baishihao Co-authored-by: Wu SiYu Co-authored-by: Siyu Wu --- .pre-commit-config.yaml | 2 +- lightllm/common/basemodel/basemodel.py | 270 +++++++---- lightllm/common/basemodel/infer_struct.py | 14 +- .../basemodel/splitfuse_infer_struct.py | 23 +- .../splitfuse_copy_kv_index_to_req.py | 38 +- lightllm/common/infer_utils.py | 14 +- lightllm/common/mem_manager.py | 47 +- .../bloom/layer_infer/post_layer_infer.py | 2 +- .../layer_infer/transformer_layer_infer.py | 41 +- .../context_flashattention_nopad.py | 208 ++++++-- lightllm/models/llama/infer_struct.py | 14 +- .../llama/layer_infer/post_layer_infer.py | 6 +- .../layer_infer/transformer_layer_infer.py | 42 +- .../models/llama/splitfuse_infer_struct.py | 18 +- .../context_flashattention_nopad.py | 273 ++++++++--- .../splitfuse_context_flashattention_nopad.py | 273 +++++++---- lightllm/models/llava/model.py | 11 +- lightllm/models/starcoder/infer_struct.py | 14 +- .../layer_infer/transformer_layer_infer.py | 2 +- lightllm/server/api_server.py | 268 ++++++----- lightllm/server/httpserver/manager.py | 79 +-- lightllm/server/io_struct.py | 171 ++++--- .../server/router/dynamic_prompt/__init__.py | 0 .../router/dynamic_prompt/radix_cache.py | 452 ++++++++++++++++++ .../router/dynamic_prompt/shared_arr.py | 154 ++++++ lightllm/server/router/manager.py | 170 +++---- .../server/router/model_infer/infer_batch.py | 152 +++--- .../server/router/model_infer/model_rpc.py | 158 +++--- .../server/router/model_infer/pre_process.py | 205 ++++---- lightllm/server/router/req_queue.py | 80 ++-- lightllm/server/tokenizer.py | 22 +- lightllm/server/visualserver/manager.py | 32 +- test/model/model_infer.py | 83 +++- test/model/test_bloom.py | 25 +- test/model/test_chatglm2.py | 27 +- test/model/test_intern.py | 27 +- test/model/test_llama.py | 27 +- test/model/test_llama2.py | 27 +- test/model/test_starcoder.py | 27 +- test/model/test_starcoder_quantized.py | 23 +- 40 files changed, 2374 insertions(+), 1147 deletions(-) create mode 100644 lightllm/server/router/dynamic_prompt/__init__.py create mode 100644 lightllm/server/router/dynamic_prompt/radix_cache.py create mode 100644 lightllm/server/router/dynamic_prompt/shared_arr.py diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 965457a97..678ac8b8f 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -11,4 +11,4 @@ repos: hooks: - id: flake8 additional_dependencies: [flake8-typing-imports==1.9.0] - args: ['--config=.flake8', '--max-line-length=120', '--ignore=C901, E203, E266, E402, E302, E241, E902, E731, F403, E701, F405, F401, W292, W293, W503, W606'] \ No newline at end of file + args: ['--config=.flake8', '--max-line-length=120', '--ignore=TYP001, E722, C901, E203, E266, E402, E302, E241, E902, E731, F403, E701, F405, F401, W292, W293, W503, W606'] \ No newline at end of file diff --git a/lightllm/common/basemodel/basemodel.py b/lightllm/common/basemodel/basemodel.py index 02831d74b..906bc67ea 100755 --- a/lightllm/common/basemodel/basemodel.py +++ b/lightllm/common/basemodel/basemodel.py @@ -1,4 +1,5 @@ import os + # 
os.environ["CUDA_LAUNCH_BLOCKING"] = "1" import json import torch @@ -54,32 +55,36 @@ def __init__(self, kvargs): self._init_some_value() self._init_custom() return - + def _init_config(self): - with open(os.path.join(self.weight_dir_, "config.json"), 'r') as json_file: + with open(os.path.join(self.weight_dir_, "config.json"), "r") as json_file: self.config = json.load(json_file) # rename keys repair_config(self.config, same_names=["num_attention_heads", "n_head"]) repair_config(self.config, same_names=["hidden_size", "n_embd", "n_embed"]) repair_config(self.config, same_names=["num_hidden_layers", "n_layer"]) if self.finetune_config: - self.config['vocab_size'] = self.finetune_config.vocab_size + self.config["vocab_size"] = self.finetune_config.vocab_size return - + @final def _verify_must(self): assert self.config["num_attention_heads"] % self.world_size_ == 0 return - + def _verify_params(self): assert self.load_way == "HF", "only support HF format weights" assert self.config["num_key_value_heads"] % self.world_size_ == 0 return def _init_weights(self): - self.pre_post_weight = self.pre_and_post_weight_class(self.tp_rank_, self.world_size_, torch.float16, network_config=self.config, mode=self.mode) + self.pre_post_weight = self.pre_and_post_weight_class( + self.tp_rank_, self.world_size_, torch.float16, network_config=self.config, mode=self.mode + ) self.trans_layers_weight = [ - self.transformer_weight_class(i, self.tp_rank_, self.world_size_, torch.float16, network_config=self.config, mode=self.mode) + self.transformer_weight_class( + i, self.tp_rank_, self.world_size_, torch.float16, network_config=self.config, mode=self.mode + ) for i in range(self.config["n_layer"]) ] load_hf_weights( @@ -87,39 +92,42 @@ def _init_weights(self): weight_dir=self.weight_dir_, pre_post_layer=self.pre_post_weight, transformer_layer_list=self.trans_layers_weight, - weight_dict=self.weight_dict) + weight_dict=self.weight_dict, + ) self.pre_post_weight.verify_load() [weight.verify_load() for weight in self.trans_layers_weight] - return - + return + def _init_mem_manager(self): assert self.config["num_attention_heads"] % self.world_size_ == 0 - self.mem_manager = MemoryManager(self.max_total_token_num, - dtype=torch.float16, - head_num=self.config["num_attention_heads"] // self.world_size_, - head_dim=self.config["n_embed"] // self.config["num_attention_heads"], - layer_num=self.config["n_layer"]) + self.mem_manager = MemoryManager( + self.max_total_token_num, + dtype=torch.float16, + head_num=self.config["num_attention_heads"] // self.world_size_, + head_dim=self.config["n_embed"] // self.config["num_attention_heads"], + layer_num=self.config["n_layer"], + ) return - + def _init_req_manager(self): - self.req_manager = ReqManager(self.max_req_num, - self.max_seq_length, - self.mem_manager) - return - + self.req_manager = ReqManager(self.max_req_num, self.max_seq_length, self.mem_manager) + return + def _init_infer_layer(self): - self.pre_infer = self.pre_layer_infer_class(tp_rank=self.tp_rank_, world_size=self.world_size_, network_config=self.config, mode=self.mode) - self.post_infer = self.post_layer_infer_class(tp_rank=self.tp_rank_, world_size=self.world_size_, network_config=self.config, mode=self.mode) + self.pre_infer = self.pre_layer_infer_class( + tp_rank=self.tp_rank_, world_size=self.world_size_, network_config=self.config, mode=self.mode + ) + self.post_infer = self.post_layer_infer_class( + tp_rank=self.tp_rank_, world_size=self.world_size_, network_config=self.config, mode=self.mode + ) 
self.layers_infer = [ self.transformer_layer_infer_class( - i, - tp_rank=self.tp_rank_, - world_size=self.world_size_, - network_config=self.config, - mode=self.mode) for i in range( - self.config["n_layer"])] + i, tp_rank=self.tp_rank_, world_size=self.world_size_, network_config=self.config, mode=self.mode + ) + for i in range(self.config["n_layer"]) + ] return - + def _init_some_value(self): self.head_dim_ = self.config["n_embed"] // self.config["num_attention_heads"] self.tp_k_head_num_ = self.config["num_key_value_heads"] // self.world_size_ @@ -127,47 +135,77 @@ def _init_some_value(self): self.layers_num = self.config["n_layer"] self.vocab_size = self.config["vocab_size"] return - + def _init_custom(self): pass - @torch.no_grad() def forward( - self, - batch_size, - total_token_num, - max_len_in_batch, - input_ids : torch.Tensor, - b_req_idx : torch.Tensor, - b_start_loc : torch.Tensor, - b_seq_len : torch.Tensor, - multimodal_params=None, - is_prefill=True): + self, + batch_size, + total_token_num, + max_len_in_batch, + input_ids: torch.Tensor, + b_req_idx: torch.Tensor, + b_start_loc: torch.Tensor, + b_seq_len: torch.Tensor, + b_ready_cache_len: torch.Tensor = None, + multimodal_params=None, + is_prefill=True, + ): if is_prefill: - return self._prefill(batch_size, total_token_num, max_len_in_batch, input_ids, b_req_idx, b_start_loc, b_seq_len, multimodal_params) + return self._prefill( + batch_size, + total_token_num, + max_len_in_batch, + input_ids, + b_req_idx, + b_start_loc, + b_seq_len, + b_ready_cache_len, + multimodal_params, + ) else: - return self._decode(batch_size, total_token_num, max_len_in_batch, input_ids, b_req_idx, b_start_loc, b_seq_len, multimodal_params) + return self._decode( + batch_size, + total_token_num, + max_len_in_batch, + input_ids, + b_req_idx, + b_start_loc, + b_seq_len, + multimodal_params, + ) - - def _prefill(self, batch_size, total_token_num, max_len_in_batch, input_ids, b_req_idx, b_start_loc, b_seq_len, multimodal_params): + def _prefill( + self, + batch_size, + total_token_num, + max_len_in_batch, + input_ids, + b_req_idx, + b_start_loc, + b_seq_len, + b_ready_cache_len, + multimodal_params, + ): infer_state = self.infer_state_class() infer_state.is_prefill = True infer_state.return_all_prompt_logprobs = self.return_all_prompt_logprobs infer_state.batch_size = batch_size infer_state.total_token_num = total_token_num infer_state.max_len_in_batch = max_len_in_batch - assert (input_ids.shape[0] == total_token_num) - assert (b_req_idx.shape[0] == b_start_loc.shape[0] == b_seq_len.shape[0]) + assert b_req_idx.shape[0] == b_start_loc.shape[0] == b_seq_len.shape[0] infer_state.b_req_idx = b_req_idx infer_state.b_start_loc = b_start_loc infer_state.b_seq_len = b_seq_len + infer_state.b_ready_cache_len = b_ready_cache_len infer_state.multimodal_params = multimodal_params infer_state.mem_manager = self.mem_manager infer_state.req_manager = self.req_manager - alloc_mem = self.mem_manager.alloc_contiguous(infer_state.total_token_num) + alloc_mem = self.mem_manager.alloc_contiguous(input_ids.shape[0]) if alloc_mem is not None: infer_state.mem_is_contiguous = True infer_state.mem_index = alloc_mem[0] @@ -176,29 +214,49 @@ def _prefill(self, batch_size, total_token_num, max_len_in_batch, input_ids, b_r else: infer_state.mem_is_contiguous = False - alloc_mem = self.mem_manager.alloc(infer_state.total_token_num) + alloc_mem = self.mem_manager.alloc(input_ids.shape[0]) infer_state.mem_index = alloc_mem - infer_state.kv_buffer = 
torch.empty((infer_state.total_token_num, self.tp_k_head_num_ + self.tp_v_head_num_, self.head_dim_), dtype=torch.float16, device="cuda") - - init_req_to_token_indexes(self.req_manager.req_to_token_indexs, b_req_idx, b_seq_len, - max_len_in_batch, infer_state.mem_index) + infer_state.kv_buffer = torch.empty( + (input_ids.shape[0], self.tp_k_head_num_ + self.tp_v_head_num_, self.head_dim_), + dtype=torch.float16, + device="cuda", + ) + + init_req_to_token_indexes( + self.req_manager.req_to_token_indexs, + b_req_idx, + b_seq_len, + b_ready_cache_len, + max_len_in_batch, + infer_state.mem_index, + ) infer_state.init_some_extra_state(self, input_ids) predict_logics = self._context_forward(input_ids, infer_state) return predict_logics - - def _decode(self, batch_size, total_token_num, max_len_in_batch, input_ids, b_req_idx, b_start_loc, b_seq_len, multimodal_params): + + def _decode( + self, + batch_size, + total_token_num, + max_len_in_batch, + input_ids, + b_req_idx, + b_start_loc, + b_seq_len, + multimodal_params, + ): infer_state = self.infer_state_class() infer_state.is_prefill = False infer_state.batch_size = batch_size infer_state.total_token_num = total_token_num infer_state.max_len_in_batch = max_len_in_batch - assert (b_req_idx.shape[0] == b_start_loc.shape[0] == b_seq_len.shape[0]) + assert b_req_idx.shape[0] == b_start_loc.shape[0] == b_seq_len.shape[0] infer_state.b_req_idx = b_req_idx infer_state.b_start_loc = b_start_loc infer_state.b_seq_len = b_seq_len infer_state.multimodal_params = multimodal_params - + infer_state.mem_manager = self.mem_manager infer_state.req_manager = self.req_manager @@ -213,31 +271,35 @@ def _decode(self, batch_size, total_token_num, max_len_in_batch, input_ids, b_re infer_state.mem_is_contiguous = False alloc_mem = self.mem_manager.alloc(batch_size) infer_state.mem_index = alloc_mem - infer_state.kv_buffer = torch.empty((batch_size, self.tp_k_head_num_ + self.tp_v_head_num_, self.head_dim_), dtype=torch.float16, device="cuda") + infer_state.kv_buffer = torch.empty( + (batch_size, self.tp_k_head_num_ + self.tp_v_head_num_, self.head_dim_), + dtype=torch.float16, + device="cuda", + ) copy_kv_index_to_req(self.req_manager.req_to_token_indexs, b_req_idx, b_seq_len, infer_state.mem_index) infer_state.init_some_extra_state(self, input_ids) predict_logics = self._token_forward(input_ids, infer_state) return predict_logics - + @torch.no_grad() def splitfuse_forward( - self, - input_ids, - decode_req_num, - decode_total_token_num, - decode_b_req_idx : torch.Tensor, - decode_b_start_loc : torch.Tensor, - decode_b_seq_len : torch.Tensor, - decode_max_len_in_batch, - - prefill_req_num, - prefill_b_req_idx : torch.Tensor, - prefill_b_split_start_loc : torch.Tensor, - prefill_b_split_seq_len : torch.Tensor, - prefill_max_split_seq_len_in_batch, - prefill_b_seq_len : torch.Tensor): - + self, + input_ids, + decode_req_num, + decode_total_token_num, + decode_b_req_idx: torch.Tensor, + decode_b_start_loc: torch.Tensor, + decode_b_seq_len: torch.Tensor, + decode_max_len_in_batch, + prefill_req_num, + prefill_b_req_idx: torch.Tensor, + prefill_b_split_start_loc: torch.Tensor, + prefill_b_split_ready_cache_len: torch.Tensor, + prefill_max_split_seq_len_in_batch, + prefill_b_seq_len: torch.Tensor, + ): + infer_state = self.splitfuse_infer_state_class() infer_state.batch_size = decode_req_num + prefill_req_num @@ -251,14 +313,14 @@ def splitfuse_forward( infer_state.prefill_req_num = prefill_req_num infer_state.prefill_b_req_idx = prefill_b_req_idx 
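# Illustrative sketch (not part of this patch): with the new b_ready_cache_len tensor,
# prefill only allocates and writes KV slots for tokens that are not already cached for a
# request. The helper below mirrors the arithmetic behind alloc(input_ids.shape[0]) in
# _prefill and the [ready:seq] write in init_req_to_token_indexes, using assumed names;
# it is a simplification, not the actual lightllm code.
def new_token_count(b_seq_len, b_ready_cache_len):
    # KV slots needed for this prefill step: total tokens per request minus the
    # prefix length already resident in the cache
    return int(sum(s - c for s, c in zip(b_seq_len, b_ready_cache_len)))

# example: two requests of length 7 and 5 with cached prefixes of 4 and 0 tokens
# -> only (7 - 4) + (5 - 0) = 8 new KV entries are allocated for the batch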
infer_state.prefill_b_split_start_loc = prefill_b_split_start_loc - infer_state.prefill_b_split_seq_len = prefill_b_split_seq_len + infer_state.prefill_b_split_ready_cache_len = prefill_b_split_ready_cache_len infer_state.prefill_max_split_seq_len_in_batch = prefill_max_split_seq_len_in_batch infer_state.prefill_b_seq_len = prefill_b_seq_len # infer_state.event = [torch.cuda.Event() for _ in range(self.layers_num)] infer_state.mem_manager = self.mem_manager infer_state.req_manager = self.req_manager - + alloc_size = len(input_ids) alloc_mem = self.mem_manager.alloc_contiguous(alloc_size) if alloc_mem is not None: @@ -270,35 +332,45 @@ def splitfuse_forward( infer_state.mem_is_contiguous = False alloc_mem = self.mem_manager.alloc(alloc_size) infer_state.mem_index = alloc_mem - infer_state.kv_buffer = torch.empty((alloc_size, self.tp_k_head_num_ + self.tp_v_head_num_, self.head_dim_), dtype=torch.float16, device="cuda") - + infer_state.kv_buffer = torch.empty( + (alloc_size, self.tp_k_head_num_ + self.tp_v_head_num_, self.head_dim_), + dtype=torch.float16, + device="cuda", + ) + # decode 部分 if decode_req_num != 0: - copy_kv_index_to_req(self.req_manager.req_to_token_indexs, - decode_b_req_idx, - decode_b_seq_len, - infer_state.mem_index[0:decode_req_num]) - + copy_kv_index_to_req( + self.req_manager.req_to_token_indexs, + decode_b_req_idx, + decode_b_seq_len, + infer_state.mem_index[0:decode_req_num], + ) + # split prefill 部分 if prefill_req_num != 0: - splitfuse_copy_kv_index_to_req(self.req_manager.req_to_token_indexs, - prefill_b_req_idx, - prefill_b_split_seq_len, - prefill_b_seq_len, - infer_state.mem_index[decode_req_num:]) - + splitfuse_copy_kv_index_to_req( + self.req_manager.req_to_token_indexs, + prefill_b_req_idx, + prefill_b_split_ready_cache_len, + prefill_b_seq_len, + infer_state.mem_index[decode_req_num:], + ) + infer_state.init_some_extra_state(self, input_ids) infer_state.create_inner_decode_infer_status() predict_logics = self._splitfuse_forward(input_ids, infer_state) return predict_logics - + @final def _context_forward(self, input_ids, infer_state: InferStateInfo): cuda_input_ids = input_ids input_embs = self.pre_infer.context_forward(cuda_input_ids, infer_state, self.pre_post_weight) for i in range(self.layers_num): input_embs = self.layers_infer[i].context_forward(input_embs, infer_state, self.trans_layers_weight[i]) - predict_logics = self.post_infer.token_forward(input_embs, infer_state, self.pre_post_weight, return_logics=True) + predict_logics = self.post_infer.token_forward( + input_embs, infer_state, self.pre_post_weight, return_logics=True + ) return predict_logics @final @@ -307,16 +379,18 @@ def _token_forward(self, input_ids, infer_state: InferStateInfo): input_embs = self.pre_infer.token_forward(cuda_input_ids, infer_state, self.pre_post_weight) for i in range(self.layers_num): input_embs = self.layers_infer[i].token_forward(input_embs, infer_state, self.trans_layers_weight[i]) - predict_logics = self.post_infer.token_forward(input_embs, infer_state, self.pre_post_weight, return_logics=True) + predict_logics = self.post_infer.token_forward( + input_embs, infer_state, self.pre_post_weight, return_logics=True + ) return predict_logics - + @final def _splitfuse_forward(self, input_ids, infer_state: SplitFuseInferStateInfo): cuda_input_ids = input_ids input_embs = self.pre_infer.splitfuse_forward(cuda_input_ids, infer_state, self.pre_post_weight) for i in range(self.layers_num): input_embs = self.layers_infer[i].splitfuse_forward(input_embs, infer_state, 
self.trans_layers_weight[i]) - predict_logics = self.post_infer.splitfuse_forward(input_embs, infer_state, self.pre_post_weight, return_logics=True) + predict_logics = self.post_infer.splitfuse_forward( + input_embs, infer_state, self.pre_post_weight, return_logics=True + ) return predict_logics - - diff --git a/lightllm/common/basemodel/infer_struct.py b/lightllm/common/basemodel/infer_struct.py index 34c068279..672b56b35 100755 --- a/lightllm/common/basemodel/infer_struct.py +++ b/lightllm/common/basemodel/infer_struct.py @@ -2,6 +2,7 @@ from lightllm.common.mem_manager import MemoryManager from lightllm.common.req_manager import ReqManager + class InferStateInfo: """ 推理时用的信息结构体 @@ -12,23 +13,26 @@ def __init__(self): self.total_token_num = None self.b_req_idx = None self.b_start_loc = None + self.b_ready_cache_len = None # only for prefill prompt cache used. self.b_seq_len = None + # max_len_in_batch prefill 和 decode 阶段含义不同 + # prefill 阶段指每个req 输入token的长度(不包括已经cache的部分)最大值 + # decode 阶段指的是每个req的总长 最大值 self.max_len_in_batch = None self.is_prefill = None - + self.mem_manager: MemoryManager = None self.req_manager: ReqManager = None - + self.mem_is_contiguous = None self.mem_index = None - self.mem_start = None + self.mem_start = None self.mem_end = None self.kv_buffer = None self.is_splitfuse = False self.return_all_prompt_logprobs = False self.multimodal_params = None - - def init_some_extra_state(self, model, input_ids : torch.Tensor): + def init_some_extra_state(self, model, input_ids: torch.Tensor): pass diff --git a/lightllm/common/basemodel/splitfuse_infer_struct.py b/lightllm/common/basemodel/splitfuse_infer_struct.py index 694361394..dddfba31b 100755 --- a/lightllm/common/basemodel/splitfuse_infer_struct.py +++ b/lightllm/common/basemodel/splitfuse_infer_struct.py @@ -3,6 +3,7 @@ from lightllm.common.mem_manager import MemoryManager from lightllm.common.req_manager import ReqManager + class SplitFuseInferStateInfo: """ 推理时用的信息结构体 @@ -12,26 +13,26 @@ class SplitFuseInferStateInfo: def __init__(self): self.batch_size = None - + self.decode_req_num = None self.decode_total_token_num = None - self.decode_b_req_idx : torch.Tensor = None - self.decode_b_start_loc : torch.Tensor = None - self.decode_b_seq_len : torch.Tensor = None + self.decode_b_req_idx: torch.Tensor = None + self.decode_b_start_loc: torch.Tensor = None + self.decode_b_seq_len: torch.Tensor = None self.decode_max_len_in_batch = None self.prefill_req_num = None - self.prefill_b_req_idx : torch.Tensor = None - self.prefill_b_split_start_loc : torch.Tensor = None - self.prefill_b_split_seq_len : torch.Tensor = None + self.prefill_b_req_idx: torch.Tensor = None + self.prefill_b_split_start_loc: torch.Tensor = None + self.prefill_b_split_ready_cache_len: torch.Tensor = None self.prefill_max_split_seq_len_in_batch = None - self.prefill_b_seq_len : torch.Tensor = None + self.prefill_b_seq_len: torch.Tensor = None self.mem_manager: MemoryManager = None self.req_manager: ReqManager = None self.mem_is_contiguous = None - self.mem_start = None + self.mem_start = None self.mem_end = None self.mem_index = None self.kv_buffer = None @@ -59,6 +60,6 @@ def create_inner_decode_infer_status(self): self.inner_decode_infer_status = infer_status return infer_status - - def init_some_extra_state(self, model, input_ids : torch.Tensor): + + def init_some_extra_state(self, model, input_ids: torch.Tensor): pass diff --git a/lightllm/common/basemodel/triton_kernel/splitfuse_copy_kv_index_to_req.py 
b/lightllm/common/basemodel/triton_kernel/splitfuse_copy_kv_index_to_req.py index bc3f7db3d..da36a9f0c 100644 --- a/lightllm/common/basemodel/triton_kernel/splitfuse_copy_kv_index_to_req.py +++ b/lightllm/common/basemodel/triton_kernel/splitfuse_copy_kv_index_to_req.py @@ -6,9 +6,15 @@ @triton.jit def _fwd_kernel_copy_kv_index_to_req( - req_to_token_indexs, b_req_idx, b_split_seq_len, cumsum_split_seq_len, b_seq_len, memindex, - stride_req_to_token_b, stride_req_to_token_s, - BLOCK_M: tl.constexpr + req_to_token_indexs, + b_req_idx, + b_split_seq_len, + cumsum_split_seq_len, + b_seq_len, + memindex, + stride_req_to_token_b, + stride_req_to_token_s, + BLOCK_M: tl.constexpr, ): cur_index = tl.program_id(0) cur_req_idx = tl.load(b_req_idx + cur_index) @@ -18,24 +24,36 @@ def _fwd_kernel_copy_kv_index_to_req( store_end = tl.load(b_seq_len + cur_index) store_start = store_end - q_split_len - + off_m = tl.arange(0, BLOCK_M) for block_start in range(0, q_split_len, BLOCK_M): - read_index = tl.load(memindex + q_mem_start + block_start + off_m, mask = q_mem_start + block_start + off_m < q_mem_end, other=0) - tl.store(req_to_token_indexs + cur_req_idx * stride_req_to_token_b + (block_start + store_start + off_m), read_index, - mask = block_start + store_start + off_m < store_end) + read_index = tl.load( + memindex + q_mem_start + block_start + off_m, mask=q_mem_start + block_start + off_m < q_mem_end, other=0 + ) + tl.store( + req_to_token_indexs + cur_req_idx * stride_req_to_token_b + (block_start + store_start + off_m), + read_index, + mask=block_start + store_start + off_m < store_end, + ) return @torch.no_grad() -def splitfuse_copy_kv_index_to_req(req_to_token_indexs, b_req_idx, b_split_seq_len, b_seq_len, memindex): +def splitfuse_copy_kv_index_to_req(req_to_token_indexs, b_req_idx, b_ready_cache_len, b_seq_len, memindex): batch_size = b_seq_len.shape[0] grid = (batch_size,) num_warps = 1 + b_split_seq_len = b_seq_len - b_ready_cache_len cumsum_split_seq_len = torch.cumsum(b_split_seq_len, dim=0) _fwd_kernel_copy_kv_index_to_req[grid]( - req_to_token_indexs, b_req_idx, b_split_seq_len, cumsum_split_seq_len, b_seq_len, memindex, - req_to_token_indexs.stride(0), req_to_token_indexs.stride(1), + req_to_token_indexs, + b_req_idx, + b_split_seq_len, + cumsum_split_seq_len, + b_seq_len, + memindex, + req_to_token_indexs.stride(0), + req_to_token_indexs.stride(1), BLOCK_M=32, num_warps=num_warps, num_stages=1, diff --git a/lightllm/common/infer_utils.py b/lightllm/common/infer_utils.py index f38b51584..da2f35e08 100644 --- a/lightllm/common/infer_utils.py +++ b/lightllm/common/infer_utils.py @@ -1,9 +1,15 @@ -def init_req_to_token_indexes(req_to_token_indexs, b_req_idx, b_seq_len, max_len_in_batch, alloc_mem_index): +def init_req_to_token_indexes( + req_to_token_indexs, b_req_idx, b_seq_len, b_ready_cache_len, max_len_in_batch, alloc_mem_index +): start_index = 0 b_seq_len_numpy = b_seq_len.cpu().numpy() + b_ready_cache_len_numpy = b_ready_cache_len.cpu().numpy() b_req_idx_numpy = b_req_idx.cpu().numpy() for i in range(len(b_seq_len)): cur_seq_len = b_seq_len_numpy[i] - req_to_token_indexs[b_req_idx_numpy[i], 0:cur_seq_len] = alloc_mem_index[start_index:start_index + cur_seq_len] - start_index += cur_seq_len - return \ No newline at end of file + cur_ready_cache_len = b_ready_cache_len_numpy[i] + req_to_token_indexs[b_req_idx_numpy[i], cur_ready_cache_len:cur_seq_len] = alloc_mem_index[ + start_index : start_index + cur_seq_len - cur_ready_cache_len + ] + start_index += cur_seq_len - 
cur_ready_cache_len + return diff --git a/lightllm/common/mem_manager.py b/lightllm/common/mem_manager.py index f967323c7..e281b46f7 100755 --- a/lightllm/common/mem_manager.py +++ b/lightllm/common/mem_manager.py @@ -2,59 +2,63 @@ from lightllm.utils.log_utils import init_logger logger = init_logger(__name__) - + class MemoryManager: def __init__(self, size, dtype, head_num, head_dim, layer_num, always_copy=False): - self.size = size + self.size = size self.dtype = dtype self.head_num = head_num self.head_dim = head_dim self.layer_num = layer_num self.always_copy = always_copy - + # mem_state 修改为使用计数方式,方便后期实现token共享机制,实现beam search 等 self.mem_state = torch.zeros((size,), dtype=torch.int32, device="cuda") self.indexes = torch.arange(0, size, dtype=torch.long, device="cuda") self.can_use_mem_size = size self._init_buffers(size, dtype, head_num, head_dim, layer_num) - + def _init_buffers(self, size, dtype, head_num, head_dim, layer_num): - self.kv_buffer = [torch.empty((size, 2 * head_num, head_dim), dtype=dtype, device="cuda") for _ in range(layer_num)] - + self.kv_buffer = [ + torch.empty((size, 2 * head_num, head_dim), dtype=dtype, device="cuda") for _ in range(layer_num) + ] + def _free_buffers(self): self.kv_buffer = None - + @torch.no_grad() def alloc(self, need_size): if need_size > self.can_use_mem_size: - logger.warn(f'warn no enough cache need_size {need_size} left_size {self.can_use_mem_size}') + logger.warn(f"warn no enough cache need_size {need_size} left_size {self.can_use_mem_size}") return None can_use_index = torch.nonzero(self.mem_state == 0).view(-1) - select_index = can_use_index[0 : need_size] + select_index = can_use_index[0:need_size] self.add_refs(select_index) return select_index - + @torch.no_grad() def alloc_contiguous(self, need_size): if self.always_copy: return None if need_size > self.can_use_mem_size: - logger.warn(f'warn no enough cache need_size {need_size} left_size {self.can_use_mem_size}') + logger.warn(f"warn no enough cache need_size {need_size} left_size {self.can_use_mem_size}") return None - + can_use_index = torch.nonzero(self.mem_state == 0).view(-1) can_use_index_size = len(can_use_index) - can_use_index = can_use_index[0 : can_use_index_size - need_size + 1][(can_use_index[need_size - 1: ] - can_use_index[0 : can_use_index_size - need_size + 1]) == need_size - 1] + can_use_index = can_use_index[0 : can_use_index_size - need_size + 1][ + (can_use_index[need_size - 1 :] - can_use_index[0 : can_use_index_size - need_size + 1]) == need_size - 1 + ] if can_use_index.shape[0] == 0: # logger.warn(f'warn no enough cache need_size {need_size} left_size {self.can_use_mem_size}') return None start = can_use_index[0].item() end = start + need_size - select_index = self.indexes[start : end] + select_index = self.indexes[start:end] self.add_refs(select_index) return select_index, start, end - + @torch.no_grad() def free(self, free_index): """_summary_ @@ -67,7 +71,7 @@ def free(self, free_index): if self.can_use_mem_size == len(self.mem_state): logger.debug(f"freed all gpu mem size {self.can_use_mem_size}") return - + @torch.no_grad() def add_refs(self, token_index: torch.Tensor): state = self.mem_state[token_index] @@ -76,33 +80,32 @@ def add_refs(self, token_index: torch.Tensor): self.can_use_mem_size -= all_tokens - has_used_tokens self.mem_state[token_index] += 1 return - + @torch.no_grad() def decrease_refs(self, token_index: torch.Tensor): - self.mem_state[token_index] -= 1 + token_index, counts = token_index.unique(return_counts=True) + 
self.mem_state[token_index] -= counts state = self.mem_state[token_index] used_tokens = torch.count_nonzero(state).item() all_tokens = len(state) self.can_use_mem_size += all_tokens - used_tokens return - @torch.no_grad() def free_all(self): self.can_use_mem_size = len(self.mem_state) self.mem_state[:] = 0 - + @torch.no_grad() def resize_mem(self, new_size): """ just for test code """ - size = new_size + size = new_size dtype = self.dtype head_num = self.head_num head_dim = self.head_dim layer_num = self.layer_num - always_copy = self.always_copy self.mem_state = torch.zeros((size,), dtype=torch.int32, device="cuda") self.indexes = torch.arange(0, size, dtype=torch.long, device="cuda") diff --git a/lightllm/models/bloom/layer_infer/post_layer_infer.py b/lightllm/models/bloom/layer_infer/post_layer_infer.py index d27be9382..bbff0bff5 100644 --- a/lightllm/models/bloom/layer_infer/post_layer_infer.py +++ b/lightllm/models/bloom/layer_infer/post_layer_infer.py @@ -31,7 +31,7 @@ def token_forward(self, input_embdings, infer_state: InferStateInfo, layer_weigh batch_size = infer_state.batch_size last_input = torch.empty((batch_size, self.embed_dim_), device=input_embdings.device, dtype=torch.float16) if infer_state.is_prefill: - last_index = torch.cumsum(infer_state.b_seq_len, dim=0, dtype=torch.long) - 1 + last_index = torch.cumsum(infer_state.b_seq_len - infer_state.b_ready_cache_len, dim=0, dtype=torch.long) - 1 last_input[:, :] = input_embdings[last_index, :] else: last_input[:, :] = input_embdings[-batch_size:, :] diff --git a/lightllm/models/bloom/layer_infer/transformer_layer_infer.py b/lightllm/models/bloom/layer_infer/transformer_layer_infer.py index 5dea5f975..4c7a76ff9 100755 --- a/lightllm/models/bloom/layer_infer/transformer_layer_infer.py +++ b/lightllm/models/bloom/layer_infer/transformer_layer_infer.py @@ -63,16 +63,37 @@ def _context_attention_kernel( self, q, kv, infer_state: InferStateInfo, layer_weight: BloomTransformerLayerWeight, out=None ) -> torch.Tensor: o_tensor = torch.empty_like(q) if out is None else out - context_attention_fwd( - q.view(-1, self.tp_q_head_num_, self.head_dim_), - kv[:, 0 : self.tp_k_head_num_, :], - kv[:, self.tp_k_head_num_ : self.tp_k_head_num_ + self.tp_v_head_num_, :], - o_tensor.view(-1, self.tp_q_head_num_, self.head_dim_), - layer_weight.tp_alibi, - infer_state.b_start_loc, - infer_state.b_seq_len, - infer_state.max_len_in_batch, - ) + import triton + if triton.__version__ >= "2.1.0": + context_attention_fwd( + q.view(-1, self.tp_q_head_num_, self.head_dim_), + infer_state.mem_manager.kv_buffer[self.layer_num_][:, 0 : self.tp_k_head_num_, :], + infer_state.mem_manager.kv_buffer[self.layer_num_][ + :, self.tp_k_head_num_ : self.tp_k_head_num_ + self.tp_v_head_num_, : + ], + o_tensor.view(-1, self.tp_q_head_num_, self.head_dim_), + infer_state.b_req_idx, + layer_weight.tp_alibi, + infer_state.b_start_loc, + infer_state.b_seq_len, + infer_state.b_ready_cache_len, + infer_state.max_len_in_batch, + infer_state.req_manager.req_to_token_indexs, + ) + elif triton.__version__ == "2.0.0": + context_attention_fwd( + q.view(-1, self.tp_q_head_num_, self.head_dim_), + kv[:, 0 : self.tp_k_head_num_, :], + kv[:, self.tp_k_head_num_ : self.tp_k_head_num_ + self.tp_v_head_num_, :], + o_tensor.view(-1, self.tp_q_head_num_, self.head_dim_), + layer_weight.tp_alibi, + infer_state.b_start_loc, + infer_state.b_seq_len, + infer_state.max_len_in_batch, + ) + else: + assert False + return o_tensor def _token_attention_kernel( diff --git 
a/lightllm/models/bloom/triton_kernel/context_flashattention_nopad.py b/lightllm/models/bloom/triton_kernel/context_flashattention_nopad.py index 6a511a5e8..acf749b8f 100644 --- a/lightllm/models/bloom/triton_kernel/context_flashattention_nopad.py +++ b/lightllm/models/bloom/triton_kernel/context_flashattention_nopad.py @@ -6,23 +6,46 @@ import torch.nn.functional as F if triton.__version__ >= "2.1.0": + @triton.jit def _fwd_kernel( - Q, K, V, sm_scale, Alibi, B_Start_Loc, B_Seqlen, + Q, + K, + V, + sm_scale, + Alibi, + B_Start_Loc, + B_Seqlen, Out, - stride_qbs, stride_qh, stride_qd, - stride_kbs, stride_kh, stride_kd, - stride_vbs, stride_vh, stride_vd, - stride_obs, stride_oh, stride_od, - BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr, + Req_to_tokens, + B_req_idx, + stride_qbs, + stride_qh, + stride_qd, + stride_kbs, + stride_kh, + stride_kd, + stride_vbs, + stride_vh, + stride_vd, + stride_obs, + stride_oh, + stride_od, + stride_req_to_tokens_b, + stride_req_to_tokens_s, + b_ready_cache_len, + BLOCK_M: tl.constexpr, + BLOCK_DMODEL: tl.constexpr, BLOCK_N: tl.constexpr, ): cur_batch = tl.program_id(0) cur_head = tl.program_id(1) start_m = tl.program_id(2) - cur_batch_seq_len = tl.load(B_Seqlen + cur_batch) cur_batch_in_all_start_index = tl.load(B_Start_Loc + cur_batch) + cur_batch_req_idx = tl.load(B_req_idx + cur_batch) + ready_cache_len = tl.load(b_ready_cache_len + cur_batch) + cur_batch_seq_len = tl.load(B_Seqlen + cur_batch) - ready_cache_len block_start_loc = BLOCK_M * start_m @@ -30,15 +53,14 @@ def _fwd_kernel( offs_n = tl.arange(0, BLOCK_N) offs_d = tl.arange(0, BLOCK_DMODEL) offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M) - off_q = (cur_batch_in_all_start_index + offs_m[:, None]) * stride_qbs + cur_head * stride_qh + offs_d[None, :] * stride_qd - off_k = offs_n[None, :] * stride_kbs + cur_head * stride_kh + offs_d[:, None] * stride_kd - off_v = offs_n[:, None] * stride_vbs + cur_head * stride_vh + offs_d[None, :] * stride_vd + off_q = ( + (cur_batch_in_all_start_index + offs_m[:, None]) * stride_qbs + + cur_head * stride_qh + + offs_d[None, :] * stride_qd + ) q = tl.load(Q + off_q, mask=offs_m[:, None] < cur_batch_seq_len, other=0.0) - k_ptrs = K + off_k - v_ptrs = V + off_v - m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") l_i = tl.zeros([BLOCK_M], dtype=tl.float32) acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32) @@ -46,21 +68,29 @@ def _fwd_kernel( alibi_m = tl.load(Alibi + cur_head) block_mask = tl.where(block_start_loc < cur_batch_seq_len, 1, 0) + block_end_loc = tl.minimum((start_m + 1) * BLOCK_M + ready_cache_len, cur_batch_seq_len + ready_cache_len) - for start_n in range(0, block_mask * (start_m + 1) * BLOCK_M, BLOCK_N): + for start_n in range(0, block_mask * block_end_loc, BLOCK_N): start_n = tl.multiple_of(start_n, BLOCK_N) # -- compute qk ---- - k = tl.load(k_ptrs + (cur_batch_in_all_start_index + start_n) * stride_kbs, - mask=(start_n + offs_n[None, :]) < cur_batch_seq_len, other=0.0) + kv_loc = tl.load( + Req_to_tokens + + stride_req_to_tokens_b * cur_batch_req_idx + + stride_req_to_tokens_s * (start_n + offs_n), + mask=(start_n + offs_n) < block_end_loc, + other=0, + ) + off_k = kv_loc[None, :] * stride_kbs + cur_head * stride_kh + offs_d[:, None] * stride_kd + k = tl.load(K + off_k, mask=(start_n + offs_n[None, :]) < block_end_loc, other=0.0) qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) qk += tl.dot(q, k) qk *= sm_scale - alibi_loc = offs_m[:, None] - (start_n + offs_n[None, :]) + alibi_loc = ready_cache_len + offs_m[:, None] - 
(start_n + offs_n[None, :]) qk -= alibi_loc * alibi_m - qk = tl.where(offs_m[:, None] >= (start_n + offs_n[None, :]), qk, float("-inf")) + qk = tl.where((offs_m[:, None] + ready_cache_len) >= (start_n + offs_n[None, :]), qk, -10000000.0) m_ij = tl.max(qk, 1) p = tl.exp(qk - m_ij[:, None]) @@ -78,8 +108,8 @@ def _fwd_kernel( acc_scale = l_i / l_i_new * alpha acc = acc * acc_scale[:, None] # update acc - v = tl.load(v_ptrs + (cur_batch_in_all_start_index + start_n) * stride_vbs, - mask=(start_n + offs_n[:, None]) < cur_batch_seq_len, other=0.0) + off_v = kv_loc[:, None] * stride_vbs + cur_head * stride_vh + offs_d[None, :] * stride_vd + v = tl.load(V + off_v, mask=(start_n + offs_n[:, None]) < block_end_loc, other=0.0) p = p.to(v.dtype) acc += tl.dot(p, v) @@ -87,20 +117,26 @@ def _fwd_kernel( l_i = l_i_new m_i = m_i_new # initialize pointers to output - off_o = (cur_batch_in_all_start_index + offs_m[:, None]) * stride_obs + cur_head * stride_oh + offs_d[None, :] * stride_od + off_o = ( + (cur_batch_in_all_start_index + offs_m[:, None]) * stride_obs + + cur_head * stride_oh + + offs_d[None, :] * stride_od + ) out_ptrs = Out + off_o tl.store(out_ptrs, acc, mask=offs_m[:, None] < cur_batch_seq_len) return @torch.no_grad() - def context_attention_fwd(q, k, v, o, alibi, b_start_loc, b_seq_len, max_input_len): + def context_attention_fwd( + q, k, v, o, b_req_idx, alibi, b_start_loc, b_seq_len, b_ready_cache_len, max_input_len, req_to_token_indexs + ): BLOCK = 128 # shape constraints Lq, Lk, Lv = q.shape[-1], k.shape[-1], v.shape[-1] assert Lq == Lk and Lk == Lv assert Lk in {16, 32, 64, 128} - sm_scale = 1.0 / (Lq**0.5) + sm_scale = 1.0 / (Lq ** 0.5) batch, head = b_seq_len.shape[0], q.shape[1] grid = (batch, head, triton.cdiv(max_input_len, BLOCK)) @@ -108,12 +144,31 @@ def context_attention_fwd(q, k, v, o, alibi, b_start_loc, b_seq_len, max_input_l num_warps = 4 if Lk <= 64 else 8 _fwd_kernel[grid]( - q, k, v, sm_scale, alibi, b_start_loc, b_seq_len, + q, + k, + v, + sm_scale, + alibi, + b_start_loc, + b_seq_len, o, - q.stride(0), q.stride(1), q.stride(2), - k.stride(0), k.stride(1), k.stride(2), - v.stride(0), v.stride(1), v.stride(2), - o.stride(0), o.stride(1), o.stride(2), + req_to_token_indexs, + b_req_idx, + q.stride(0), + q.stride(1), + q.stride(2), + k.stride(0), + k.stride(1), + k.stride(2), + v.stride(0), + v.stride(1), + v.stride(2), + o.stride(0), + o.stride(1), + o.stride(2), + req_to_token_indexs.stride(0), + req_to_token_indexs.stride(1), + b_ready_cache_len, BLOCK_M=BLOCK, BLOCK_DMODEL=Lk, BLOCK_N=BLOCK, @@ -121,18 +176,37 @@ def context_attention_fwd(q, k, v, o, alibi, b_start_loc, b_seq_len, max_input_l num_stages=1, ) return + elif triton.__version__ == "2.0.0": + @triton.jit def _fwd_kernel( - Q, K, V, sm_scale, Alibi, B_Start_Loc, B_Seqlen, + Q, + K, + V, + sm_scale, + Alibi, + B_Start_Loc, + B_Seqlen, TMP, # NOTE: TMP is a scratchpad buffer to workaround a compiler bug Out, - stride_qbs, stride_qh, stride_qd, - stride_kbs, stride_kh, stride_kd, - stride_vbs, stride_vh, stride_vd, - stride_obs, stride_oh, stride_od, - stride_tmp_b, stride_tmp_h, stride_tmp_s, - BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr, + stride_qbs, + stride_qh, + stride_qd, + stride_kbs, + stride_kh, + stride_kd, + stride_vbs, + stride_vh, + stride_vd, + stride_obs, + stride_oh, + stride_od, + stride_tmp_b, + stride_tmp_h, + stride_tmp_s, + BLOCK_M: tl.constexpr, + BLOCK_DMODEL: tl.constexpr, BLOCK_N: tl.constexpr, ): cur_batch = tl.program_id(0) @@ -148,7 +222,11 @@ def _fwd_kernel( offs_n = 
tl.arange(0, BLOCK_N) offs_d = tl.arange(0, BLOCK_DMODEL) offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M) - off_q = (cur_batch_in_all_start_index + offs_m[:, None]) * stride_qbs + cur_head * stride_qh + offs_d[None, :] * stride_qd + off_q = ( + (cur_batch_in_all_start_index + offs_m[:, None]) * stride_qbs + + cur_head * stride_qh + + offs_d[None, :] * stride_qd + ) off_k = offs_n[None, :] * stride_kbs + cur_head * stride_kh + offs_d[:, None] * stride_kd off_v = offs_n[:, None] * stride_vbs + cur_head * stride_vh + offs_d[None, :] * stride_vd @@ -169,8 +247,11 @@ def _fwd_kernel( for start_n in range(0, block_mask * (start_m + 1) * BLOCK_M, BLOCK_N): start_n = tl.multiple_of(start_n, BLOCK_N) # -- compute qk ---- - k = tl.load(k_ptrs + (cur_batch_in_all_start_index + start_n) * stride_kbs, - mask=(start_n + offs_n[None, :]) < cur_batch_seq_len, other=0.0) + k = tl.load( + k_ptrs + (cur_batch_in_all_start_index + start_n) * stride_kbs, + mask=(start_n + offs_n[None, :]) < cur_batch_seq_len, + other=0.0, + ) qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) qk += tl.dot(q, k) @@ -199,8 +280,11 @@ def _fwd_kernel( acc_scale = tl.load(t_ptrs) acc = acc * acc_scale[:, None] # update acc - v = tl.load(v_ptrs + (cur_batch_in_all_start_index + start_n) * stride_vbs, - mask=(start_n + offs_n[:, None]) < cur_batch_seq_len, other=0.0) + v = tl.load( + v_ptrs + (cur_batch_in_all_start_index + start_n) * stride_vbs, + mask=(start_n + offs_n[:, None]) < cur_batch_seq_len, + other=0.0, + ) p = p.to(v.dtype) acc += tl.dot(p, v) @@ -208,7 +292,11 @@ def _fwd_kernel( l_i = l_i_new m_i = m_i_new # initialize pointers to output - off_o = (cur_batch_in_all_start_index + offs_m[:, None]) * stride_obs + cur_head * stride_oh + offs_d[None, :] * stride_od + off_o = ( + (cur_batch_in_all_start_index + offs_m[:, None]) * stride_obs + + cur_head * stride_oh + + offs_d[None, :] * stride_od + ) out_ptrs = Out + off_o tl.store(out_ptrs, acc, mask=offs_m[:, None] < cur_batch_seq_len) return @@ -221,7 +309,7 @@ def context_attention_fwd(q, k, v, o, alibi, b_start_loc, b_seq_len, max_input_l assert Lq == Lk and Lk == Lv assert Lk in {16, 32, 64, 128} - sm_scale = 1.0 / (Lq**0.5) + sm_scale = 1.0 / (Lq ** 0.5) batch, head = b_seq_len.shape[0], q.shape[1] grid = (batch, head, triton.cdiv(max_input_len, BLOCK)) @@ -231,14 +319,30 @@ def context_attention_fwd(q, k, v, o, alibi, b_start_loc, b_seq_len, max_input_l tmp = torch.empty((batch, head, max_input_len + 256), device=q.device, dtype=torch.float32) _fwd_kernel[grid]( - q, k, v, sm_scale, alibi, b_start_loc, b_seq_len, + q, + k, + v, + sm_scale, + alibi, + b_start_loc, + b_seq_len, tmp, o, - q.stride(0), q.stride(1), q.stride(2), - k.stride(0), k.stride(1), k.stride(2), - v.stride(0), v.stride(1), v.stride(2), - o.stride(0), o.stride(1), o.stride(2), - tmp.stride(0), tmp.stride(1), tmp.stride(2), + q.stride(0), + q.stride(1), + q.stride(2), + k.stride(0), + k.stride(1), + k.stride(2), + v.stride(0), + v.stride(1), + v.stride(2), + o.stride(0), + o.stride(1), + o.stride(2), + tmp.stride(0), + tmp.stride(1), + tmp.stride(2), BLOCK_M=BLOCK, BLOCK_DMODEL=Lk, BLOCK_N=BLOCK, @@ -246,6 +350,7 @@ def context_attention_fwd(q, k, v, o, alibi, b_start_loc, b_seq_len, max_input_l num_stages=1, ) return + else: raise Exception("error triton version!") @@ -255,7 +360,7 @@ def torch_att(xq, xk, xv, bs, seqlen, num_head, head_dim): xk = xk.view(bs, seqlen, num_head, head_dim) xv = xv.view(bs, seqlen, num_head, head_dim) mask = torch.tril(torch.ones(seqlen, seqlen), 
diagonal=0).unsqueeze(0).unsqueeze(0).cuda() - mask[mask == 0.] = -100000000.0 + mask[mask == 0.0] = -100000000.0 mask = mask.repeat(bs, num_head, 1, 1) keys = xk values = xv @@ -264,8 +369,9 @@ def torch_att(xq, xk, xv, bs, seqlen, num_head, head_dim): values = values.transpose(1, 2) scores = torch.matmul(xq, keys.transpose(2, 3)) / math.sqrt(head_dim) scores = F.softmax(scores.float() + mask, dim=-1).type_as(xq) - output = torch.matmul(scores, values).transpose(1, 2).contiguous().reshape(-1, - num_head, head_dim) # (bs, n_local_heads, slen, head_dim) + output = ( + torch.matmul(scores, values).transpose(1, 2).contiguous().reshape(-1, num_head, head_dim) + ) # (bs, n_local_heads, slen, head_dim) return output diff --git a/lightllm/models/llama/infer_struct.py b/lightllm/models/llama/infer_struct.py index a87e2be0e..f1097738b 100644 --- a/lightllm/models/llama/infer_struct.py +++ b/lightllm/models/llama/infer_struct.py @@ -3,18 +3,24 @@ from lightllm.common.basemodel import InferStateInfo from lightllm.common.req_manager import ReqManager + class LlamaInferStateInfo(InferStateInfo): def __init__(self): super().__init__() self.position_cos = None self.position_sin = None self.other_kv_index = None - - def init_some_extra_state(self, model, input_ids : torch.Tensor): + + def init_some_extra_state(self, model, input_ids: torch.Tensor): if self.is_prefill: b_seq_len_numpy = self.b_seq_len.cpu().numpy() - position_ids = torch.from_numpy(np.concatenate([np.arange(0, b_seq_len_numpy[i]) - for i in range(len(b_seq_len_numpy))], axis=0)).cuda() + b_ready_cache_len_numpy = self.b_ready_cache_len.cpu().numpy() + position_ids = torch.from_numpy( + np.concatenate( + [np.arange(b_ready_cache_len_numpy[i], b_seq_len_numpy[i]) for i in range(len(b_seq_len_numpy))], + axis=0, + ) + ).cuda() self.position_cos = torch.index_select(model._cos_cached, 0, position_ids).view(position_ids.shape[0], -1) self.position_sin = torch.index_select(model._sin_cached, 0, position_ids).view(position_ids.shape[0], -1) position_ids = None diff --git a/lightllm/models/llama/layer_infer/post_layer_infer.py b/lightllm/models/llama/layer_infer/post_layer_infer.py index 840e8669a..031568739 100644 --- a/lightllm/models/llama/layer_infer/post_layer_infer.py +++ b/lightllm/models/llama/layer_infer/post_layer_infer.py @@ -34,7 +34,7 @@ def _slice_get_last_input(self, input_embdings, infer_state: LlamaInferStateInfo tmp_ = torch.cat( [ torch.ones(infer_state.decode_req_num, dtype=torch.int32, device="cuda"), - infer_state.prefill_b_split_seq_len, + infer_state.prefill_b_seq_len - infer_state.prefill_b_split_ready_cache_len, ], dim=0, ) @@ -45,7 +45,9 @@ def _slice_get_last_input(self, input_embdings, infer_state: LlamaInferStateInfo if not infer_state.is_splitfuse and infer_state.is_prefill and not infer_state.return_all_prompt_logprobs: batch_size = infer_state.batch_size last_input = torch.empty((batch_size, self.embed_dim_), device=input_embdings.device, dtype=torch.float16) - last_index = torch.cumsum(infer_state.b_seq_len, dim=0, dtype=torch.long) - 1 + last_index = ( + torch.cumsum(infer_state.b_seq_len - infer_state.b_ready_cache_len, dim=0, dtype=torch.long) - 1 + ) last_input[:, :] = input_embdings[last_index, :] return last_input, batch_size diff --git a/lightllm/models/llama/layer_infer/transformer_layer_infer.py b/lightllm/models/llama/layer_infer/transformer_layer_infer.py index 3875433ba..27faa53f2 100755 --- a/lightllm/models/llama/layer_infer/transformer_layer_infer.py +++ 
b/lightllm/models/llama/layer_infer/transformer_layer_infer.py @@ -123,15 +123,35 @@ def _context_attention_kernel( self, q, kv, infer_state: LlamaInferStateInfo, layer_weight, out=None ) -> torch.Tensor: o_tensor = torch.empty_like(q) if out is None else out - context_attention_fwd( - q.view(-1, self.tp_q_head_num_, self.head_dim_), - kv[:, 0 : self.tp_k_head_num_, :], - kv[:, self.tp_k_head_num_ : self.tp_k_head_num_ + self.tp_v_head_num_, :], - o_tensor.view(-1, self.tp_q_head_num_, self.head_dim_), - infer_state.b_start_loc, - infer_state.b_seq_len, - infer_state.max_len_in_batch, - ) + import triton + if triton.__version__ >= "2.1.0": + context_attention_fwd( + q.view(-1, self.tp_q_head_num_, self.head_dim_), + infer_state.mem_manager.kv_buffer[self.layer_num_][:, 0 : self.tp_k_head_num_, :], + infer_state.mem_manager.kv_buffer[self.layer_num_][ + :, self.tp_k_head_num_ : self.tp_k_head_num_ + self.tp_v_head_num_, : + ], + o_tensor.view(-1, self.tp_q_head_num_, self.head_dim_), + infer_state.b_req_idx, + infer_state.b_start_loc, + infer_state.b_seq_len, + infer_state.b_ready_cache_len, + infer_state.max_len_in_batch, + infer_state.req_manager.req_to_token_indexs, + ) + elif triton.__version__ == "2.0.0": + context_attention_fwd( + q.view(-1, self.tp_q_head_num_, self.head_dim_), + kv[:, 0 : self.tp_k_head_num_, :], + kv[:, self.tp_k_head_num_ : self.tp_k_head_num_ + self.tp_v_head_num_, :], + o_tensor.view(-1, self.tp_q_head_num_, self.head_dim_), + infer_state.b_start_loc, + infer_state.b_seq_len, + infer_state.max_len_in_batch, + ) + else: + assert False + return o_tensor def _splitfuse_attention_kernel( @@ -163,7 +183,7 @@ def _splitfuse_attention_kernel( infer_state.req_manager.req_to_token_indexs, infer_state.prefill_b_req_idx, infer_state.prefill_b_split_start_loc, - infer_state.prefill_b_split_seq_len, + infer_state.prefill_b_split_ready_cache_len, infer_state.prefill_b_seq_len, infer_state.prefill_max_split_seq_len_in_batch, ) @@ -205,7 +225,7 @@ def _splitfuse_attention_kernel_int8kv( infer_state.req_manager.req_to_token_indexs, infer_state.prefill_b_req_idx, infer_state.prefill_b_split_start_loc, - infer_state.prefill_b_split_seq_len, + infer_state.prefill_b_split_ready_cache_len, infer_state.prefill_b_seq_len, infer_state.prefill_max_split_seq_len_in_batch, ) diff --git a/lightllm/models/llama/splitfuse_infer_struct.py b/lightllm/models/llama/splitfuse_infer_struct.py index ea1556e57..5c754aaf1 100644 --- a/lightllm/models/llama/splitfuse_infer_struct.py +++ b/lightllm/models/llama/splitfuse_infer_struct.py @@ -4,8 +4,9 @@ from lightllm.common.req_manager import ReqManager from .infer_struct import LlamaInferStateInfo + class LlamaSplitFuseInferStateInfo(SplitFuseInferStateInfo): - + inner_decode_infer_state_class = LlamaInferStateInfo def __init__(self): @@ -13,17 +14,18 @@ def __init__(self): self.position_cos = None self.position_sin = None self.other_kv_index = None - - def init_some_extra_state(self, model, input_ids : torch.Tensor): + + def init_some_extra_state(self, model, input_ids: torch.Tensor): position_ids = [] if self.decode_req_num != 0: position_ids.append((self.decode_b_seq_len - 1).cpu().numpy()) if self.prefill_req_num != 0: b_seq_len_numpy = self.prefill_b_seq_len.cpu().numpy() - b_split_len_numpy = self.prefill_b_split_seq_len.cpu().numpy() - b_start_numpy = b_seq_len_numpy - b_split_len_numpy - position_ids.extend([np.arange(b_start_numpy[i], b_seq_len_numpy[i]) for i in range(len(b_seq_len_numpy))]) - + b_ready_cache_len_numpy = 
self.prefill_b_split_ready_cache_len.cpu().numpy() + position_ids.extend( + [np.arange(b_ready_cache_len_numpy[i], b_seq_len_numpy[i]) for i in range(len(b_seq_len_numpy))] + ) + position_ids = torch.from_numpy(np.concatenate(position_ids, axis=0)).cuda().view(-1) self.position_cos = torch.index_select(model._cos_cached, 0, position_ids).view(position_ids.shape[0], -1) self.position_sin = torch.index_select(model._sin_cached, 0, position_ids).view(position_ids.shape[0], -1) @@ -33,7 +35,7 @@ def init_some_extra_state(self, model, input_ids : torch.Tensor): elif self.prefill_req_num != 0: self.other_kv_index = self.req_manager.req_to_token_indexs[self.prefill_b_req_idx[0], 0].item() return - + def create_inner_decode_infer_status(self): infer_status = super().create_inner_decode_infer_status() infer_status.other_kv_index = self.other_kv_index diff --git a/lightllm/models/llama/triton_kernel/context_flashattention_nopad.py b/lightllm/models/llama/triton_kernel/context_flashattention_nopad.py index d63999d6f..e7076c360 100644 --- a/lightllm/models/llama/triton_kernel/context_flashattention_nopad.py +++ b/lightllm/models/llama/triton_kernel/context_flashattention_nopad.py @@ -5,30 +5,52 @@ import math import torch.nn.functional as F -TESLA = 'Tesla' in torch.cuda.get_device_name(0) +TESLA = "Tesla" in torch.cuda.get_device_name(0) if triton.__version__ >= "2.1.0": + @triton.jit def _fwd_kernel( - Q, K, V, sm_scale, B_Start_Loc, B_Seqlen, # B_LOC 内部记录每个batch 输入的真实位置, B_SEQ_len 记录当前输入的真实长度 + Q, + K, + V, + sm_scale, + B_Start_Loc, + B_Seqlen, # B_LOC 内部记录每个batch 输入的真实位置, B_SEQ_len 记录当前输入的真实长度 Out, - stride_qbs, stride_qh, stride_qd, - stride_kbs, stride_kh, stride_kd, - stride_vbs, stride_vh, stride_vd, - stride_obs, stride_oh, stride_od, + Req_to_tokens, + B_req_idx, + stride_qbs, + stride_qh, + stride_qd, + stride_kbs, + stride_kh, + stride_kd, + stride_vbs, + stride_vh, + stride_vd, + stride_obs, + stride_oh, + stride_od, + stride_req_to_tokens_b, + stride_req_to_tokens_s, kv_group_num, - BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr, + b_prompt_cache_len, + BLOCK_M: tl.constexpr, + BLOCK_DMODEL: tl.constexpr, BLOCK_N: tl.constexpr, ): cur_batch = tl.program_id(0) cur_head = tl.program_id(1) start_m = tl.program_id(2) - + cur_kv_head = cur_head // kv_group_num - cur_batch_seq_len = tl.load(B_Seqlen + cur_batch) cur_batch_in_all_start_index = tl.load(B_Start_Loc + cur_batch) + prompt_cache_len = tl.load(b_prompt_cache_len + cur_batch) + cur_batch_seq_len = tl.load(B_Seqlen + cur_batch) - prompt_cache_len + cur_batch_req_idx = tl.load(B_req_idx + cur_batch) block_start_loc = BLOCK_M * start_m @@ -36,33 +58,39 @@ def _fwd_kernel( offs_n = tl.arange(0, BLOCK_N) offs_d = tl.arange(0, BLOCK_DMODEL) offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M) - off_q = (cur_batch_in_all_start_index + offs_m[:, None]) * stride_qbs + cur_head * stride_qh + offs_d[None, :] * stride_qd - off_k = offs_n[None, :] * stride_kbs + cur_kv_head * stride_kh + offs_d[:, None] * stride_kd - off_v = offs_n[:, None] * stride_vbs + cur_kv_head * stride_vh + offs_d[None, :] * stride_vd + off_q = ( + (cur_batch_in_all_start_index + offs_m[:, None]) * stride_qbs + + cur_head * stride_qh + + offs_d[None, :] * stride_qd + ) q = tl.load(Q + off_q, mask=offs_m[:, None] < cur_batch_seq_len, other=0.0) - k_ptrs = K + off_k - v_ptrs = V + off_v - # initialize pointer to m and l m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") l_i = tl.zeros([BLOCK_M], dtype=tl.float32) acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], 
dtype=tl.float32) block_mask = tl.where(block_start_loc < cur_batch_seq_len, 1, 0) + block_end_loc = tl.minimum((start_m + 1) * BLOCK_M + prompt_cache_len, cur_batch_seq_len + prompt_cache_len) - for start_n in range(0, block_mask * (start_m + 1) * BLOCK_M, BLOCK_N): + for start_n in range(0, block_mask * block_end_loc, BLOCK_N): start_n = tl.multiple_of(start_n, BLOCK_N) # -- compute qk ---- - k = tl.load(k_ptrs + (cur_batch_in_all_start_index + start_n) * stride_kbs, - mask=(start_n + offs_n[None, :]) < cur_batch_seq_len, other=0.0) - # mask = tl.load(mask_ptrs + start_n, mask=start_n + offs_n < cur_batch_end_loc, other=0.0) + kv_loc = tl.load( + Req_to_tokens + + stride_req_to_tokens_b * cur_batch_req_idx + + stride_req_to_tokens_s * (start_n + offs_n), + mask=(start_n + offs_n) < block_end_loc, + other=0, + ) + off_k = kv_loc[None, :] * stride_kbs + cur_kv_head * stride_kh + offs_d[:, None] * stride_kd + k = tl.load(K + off_k, mask=(start_n + offs_n[None, :]) < block_end_loc, other=0.0) qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) qk += tl.dot(q, k) qk *= sm_scale - qk = tl.where(offs_m[:, None] >= (start_n + offs_n[None, :]), qk, float("-inf")) + qk = tl.where(offs_m[:, None] + prompt_cache_len >= start_n + offs_n[None, :], qk, float("-100000000.0")) # -- compute m_ij, p, l_ij m_ij = tl.max(qk, 1) @@ -79,45 +107,69 @@ def _fwd_kernel( p = p * p_scale[:, None] # scale acc acc_scale = l_i / l_i_new * alpha + acc_scale = tl.where(offs_m + prompt_cache_len >= start_n, acc_scale, 1.0) acc = acc * acc_scale[:, None] # update acc - v = tl.load(v_ptrs + (cur_batch_in_all_start_index + start_n) * stride_vbs, - mask=(start_n + offs_n[:, None]) < cur_batch_seq_len, other=0.0) - + off_v = kv_loc[:, None] * stride_vbs + cur_kv_head * stride_vh + offs_d[None, :] * stride_vd + v = tl.load(V + off_v, mask=(start_n + offs_n[:, None]) < block_end_loc, other=0.0) p = p.to(v.dtype) acc += tl.dot(p, v) # update m_i and l_i l_i = l_i_new m_i = m_i_new # initialize pointers to output - off_o = (cur_batch_in_all_start_index + offs_m[:, None]) * stride_obs + cur_head * stride_oh + offs_d[None, :] * stride_od + off_o = ( + (cur_batch_in_all_start_index + offs_m[:, None]) * stride_obs + + cur_head * stride_oh + + offs_d[None, :] * stride_od + ) out_ptrs = Out + off_o tl.store(out_ptrs, acc, mask=offs_m[:, None] < cur_batch_seq_len) return @torch.no_grad() - def context_attention_fwd(q, k, v, o, b_start_loc, b_seq_len, max_input_len): + def context_attention_fwd( + q, k, v, o, b_req_idx, b_start_loc, b_seq_len, b_prompt_cache_len, max_input_len, req_to_token_indexs + ): BLOCK = 128 if not TESLA else 64 # shape constraints Lq, Lk, Lv = q.shape[-1], k.shape[-1], v.shape[-1] assert Lq == Lk and Lk == Lv assert Lk in {16, 32, 64, 128, 256} - sm_scale = 1.0 / (Lq**0.5) # 计算scale系数 + sm_scale = 1.0 / (Lq ** 0.5) # 计算scale系数 batch, head = b_seq_len.shape[0], q.shape[1] kv_group_num = q.shape[1] // k.shape[1] - + grid = (batch, head, triton.cdiv(max_input_len, BLOCK)) # batch, head, num_warps = 4 if Lk <= 64 else 8 _fwd_kernel[grid]( - q, k, v, sm_scale, b_start_loc, b_seq_len, + q, + k, + v, + sm_scale, + b_start_loc, + b_seq_len, o, - q.stride(0), q.stride(1), q.stride(2), - k.stride(0), k.stride(1), k.stride(2), - v.stride(0), v.stride(1), v.stride(2), - o.stride(0), o.stride(1), o.stride(2), + req_to_token_indexs, + b_req_idx, + q.stride(0), + q.stride(1), + q.stride(2), + k.stride(0), + k.stride(1), + k.stride(2), + v.stride(0), + v.stride(1), + v.stride(2), + o.stride(0), + o.stride(1), + o.stride(2), + 
req_to_token_indexs.stride(0), + req_to_token_indexs.stride(1), kv_group_num=kv_group_num, + b_prompt_cache_len=b_prompt_cache_len, BLOCK_M=BLOCK, BLOCK_DMODEL=Lk, BLOCK_N=BLOCK, @@ -127,24 +179,41 @@ def context_attention_fwd(q, k, v, o, b_start_loc, b_seq_len, max_input_len): return elif triton.__version__ == "2.0.0": + @triton.jit def _fwd_kernel( - Q, K, V, sm_scale, B_Start_Loc, B_Seqlen, + Q, + K, + V, + sm_scale, + B_Start_Loc, + B_Seqlen, TMP, # NOTE: TMP is a scratchpad buffer to workaround a compiler bug Out, - stride_qbs, stride_qh, stride_qd, - stride_kbs, stride_kh, stride_kd, - stride_vbs, stride_vh, stride_vd, - stride_obs, stride_oh, stride_od, - stride_tmp_b, stride_tmp_h, stride_tmp_s, + stride_qbs, + stride_qh, + stride_qd, + stride_kbs, + stride_kh, + stride_kd, + stride_vbs, + stride_vh, + stride_vd, + stride_obs, + stride_oh, + stride_od, + stride_tmp_b, + stride_tmp_h, + stride_tmp_s, kv_group_num, - BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr, + BLOCK_M: tl.constexpr, + BLOCK_DMODEL: tl.constexpr, BLOCK_N: tl.constexpr, ): cur_batch = tl.program_id(0) cur_head = tl.program_id(1) start_m = tl.program_id(2) - + cur_kv_head = cur_head // kv_group_num cur_batch_seq_len = tl.load(B_Seqlen + cur_batch) @@ -156,7 +225,11 @@ def _fwd_kernel( offs_n = tl.arange(0, BLOCK_N) offs_d = tl.arange(0, BLOCK_DMODEL) offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M) - off_q = (cur_batch_in_all_start_index + offs_m[:, None]) * stride_qbs + cur_head * stride_qh + offs_d[None, :] * stride_qd + off_q = ( + (cur_batch_in_all_start_index + offs_m[:, None]) * stride_qbs + + cur_head * stride_qh + + offs_d[None, :] * stride_qd + ) off_k = offs_n[None, :] * stride_kbs + cur_kv_head * stride_kh + offs_d[:, None] * stride_kd off_v = offs_n[:, None] * stride_vbs + cur_kv_head * stride_vh + offs_d[None, :] * stride_vd q = tl.load(Q + off_q, mask=offs_m[:, None] < cur_batch_seq_len, other=0.0) @@ -175,8 +248,11 @@ def _fwd_kernel( for start_n in range(0, block_mask * (start_m + 1) * BLOCK_M, BLOCK_N): start_n = tl.multiple_of(start_n, BLOCK_N) # -- compute qk ---- - k = tl.load(k_ptrs + (cur_batch_in_all_start_index + start_n) * stride_kbs, - mask=(start_n + offs_n[None, :]) < cur_batch_seq_len, other=0.0) + k = tl.load( + k_ptrs + (cur_batch_in_all_start_index + start_n) * stride_kbs, + mask=(start_n + offs_n[None, :]) < cur_batch_seq_len, + other=0.0, + ) qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) qk += tl.dot(q, k) @@ -201,8 +277,11 @@ def _fwd_kernel( acc_scale = tl.load(t_ptrs) # BUG: have to store and immediately load acc = acc * acc_scale[:, None] # update acc - v = tl.load(v_ptrs + (cur_batch_in_all_start_index + start_n) * stride_vbs, - mask=(start_n + offs_n[:, None]) < cur_batch_seq_len, other=0.0) + v = tl.load( + v_ptrs + (cur_batch_in_all_start_index + start_n) * stride_vbs, + mask=(start_n + offs_n[:, None]) < cur_batch_seq_len, + other=0.0, + ) p = p.to(v.dtype) acc += tl.dot(p, v) @@ -210,7 +289,11 @@ def _fwd_kernel( l_i = l_i_new m_i = m_i_new # initialize pointers to output - off_o = (cur_batch_in_all_start_index + offs_m[:, None]) * stride_obs + cur_head * stride_oh + offs_d[None, :] * stride_od + off_o = ( + (cur_batch_in_all_start_index + offs_m[:, None]) * stride_obs + + cur_head * stride_oh + + offs_d[None, :] * stride_od + ) out_ptrs = Out + off_o tl.store(out_ptrs, acc, mask=offs_m[:, None] < cur_batch_seq_len) @@ -224,24 +307,39 @@ def context_attention_fwd(q, k, v, o, b_start_loc, b_seq_len, max_input_len): assert Lq == Lk and Lk == Lv assert Lk in {16, 
32, 64, 128, 256} - sm_scale = 1.0 / (Lq**0.5) + sm_scale = 1.0 / (Lq ** 0.5) batch, head = b_seq_len.shape[0], q.shape[1] kv_group_num = q.shape[1] // k.shape[1] - + grid = (batch, head, triton.cdiv(max_input_len, BLOCK)) tmp = torch.empty((batch, head, max_input_len + 256), device=q.device, dtype=torch.float32) num_warps = 4 if Lk <= 64 else 8 # num_warps = 4 _fwd_kernel[grid]( - q, k, v, sm_scale, b_start_loc, b_seq_len, + q, + k, + v, + sm_scale, + b_start_loc, + b_seq_len, tmp, o, - q.stride(0), q.stride(1), q.stride(2), - k.stride(0), k.stride(1), k.stride(2), - v.stride(0), v.stride(1), v.stride(2), - o.stride(0), o.stride(1), o.stride(2), - tmp.stride(0), tmp.stride(1), tmp.stride(2), + q.stride(0), + q.stride(1), + q.stride(2), + k.stride(0), + k.stride(1), + k.stride(2), + v.stride(0), + v.stride(1), + v.stride(2), + o.stride(0), + o.stride(1), + o.stride(2), + tmp.stride(0), + tmp.stride(1), + tmp.stride(2), kv_group_num=kv_group_num, BLOCK_M=BLOCK, BLOCK_DMODEL=Lk, @@ -251,18 +349,18 @@ def context_attention_fwd(q, k, v, o, b_start_loc, b_seq_len, max_input_len): ) return - else: raise Exception("error triton version!") - -def torch_att(xq, xk, xv, bs, seqlen, num_head, head_dim): +def torch_att(xq, xk, xv, bs, seqlen, num_head, head_dim, prompt_cache_len): xq = xq.view(bs, seqlen, num_head, head_dim) - xk = xk.view(bs, seqlen, num_head, head_dim) - xv = xv.view(bs, seqlen, num_head, head_dim) + xk = xk.view(bs, seqlen + prompt_cache_len, num_head, head_dim) + xv = xv.view(bs, seqlen + prompt_cache_len, num_head, head_dim) + mask_cache = torch.ones((seqlen, prompt_cache_len)).cuda().unsqueeze(0).unsqueeze(0).cuda() mask = torch.tril(torch.ones(seqlen, seqlen), diagonal=0).unsqueeze(0).unsqueeze(0).cuda() - mask[mask == 0.] = -100000000.0 + mask[mask == 0.0] = -100000000.0 + mask = torch.cat([mask_cache, mask], dim=-1) mask = mask.repeat(bs, num_head, 1, 1) keys = xk values = xv @@ -277,38 +375,61 @@ def torch_att(xq, xk, xv, bs, seqlen, num_head, head_dim): def test(): import torch + import numpy as np - Z, H, N_CTX, D_HEAD = 4, 6, 1024, 128 + Z, H, N_CTX, D_HEAD = 1, 6, 500, 128 dtype = torch.float16 - Z = 3 + Z = 1 q = torch.empty((Z * N_CTX, H, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0.1, std=0.2) - k = torch.empty((Z * N_CTX, H, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0.4, std=0.2) - v = torch.empty((Z * N_CTX, H, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0.3, std=0.2) + k = torch.empty((Z * N_CTX + 7000, H, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0.4, std=0.2) + v = torch.empty((Z * N_CTX + 7000, H, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0.3, std=0.2) o = torch.empty((Z * N_CTX, H, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0.3, std=0.2) - + req_to_token_indexs = torch.zeros((10, Z * N_CTX + 7000), dtype=torch.int32, device="cuda") max_input_len = N_CTX - Z = 4 + Z = 1 b_start_loc = torch.zeros((Z,), dtype=torch.int32, device="cuda") b_seq_len = torch.ones((Z,), dtype=torch.int32, device="cuda") + b_req_idx = torch.ones((Z,), dtype=torch.int32, device="cuda") + b_prompt_cache_len = torch.zeros(1, dtype=torch.int32, device="cuda") + b_prompt_cache_len[0] = 10 + prompt_cache_len = 10 - b_seq_len[0] = 512 - b_seq_len[1] = 1024 - b_seq_len[2] = 512 - b_seq_len[3] = 1024 - - for i in range(1, Z): - b_start_loc[i] = b_start_loc[i - 1] + b_seq_len[i - 1] + b_seq_len[0] = 500 + b_req_idx[0] = 0 + req_to_token_indexs[0][: prompt_cache_len + N_CTX] = torch.tensor( + np.arange(prompt_cache_len + N_CTX), dtype=torch.int32 + 
).cuda() torch_out = [] start = 0 for i in range(Z): end = start + b_seq_len[i] - torch_o = torch_att(q[start:end], k[start:end], v[start:end], 1, b_seq_len[i], H, D_HEAD) + torch_o = torch_att( + q[start:end], + k[start : end + prompt_cache_len], + v[start : end + prompt_cache_len], + 1, + b_seq_len[i], + H, + D_HEAD, + prompt_cache_len, + ) start = end torch_out.append(torch_o) + torch_out = torch.cat(torch_out, dim=0) - context_attention_fwd(q, k, v, o, b_start_loc, b_seq_len, max_input_len) - print(o.shape, torch_out.shape) + import time + + torch.cuda.synchronize() + a = time.time() + for i in range(10000): + context_attention_fwd( + q, k, v, o, b_req_idx, b_start_loc, b_seq_len, b_prompt_cache_len, max_input_len, req_to_token_indexs + ) + torch.cuda.synchronize() + b = time.time() + # print(o.shape, torch_out.shape) + print((b - a) / 10000) print("max ", torch.max(torch.abs(torch_out - o))) print("mean ", torch.mean(torch.abs(torch_out - o))) diff --git a/lightllm/models/llama/triton_kernel/splitfuse_context_flashattention_nopad.py b/lightllm/models/llama/triton_kernel/splitfuse_context_flashattention_nopad.py index 2aff79201..fc5b228b6 100644 --- a/lightllm/models/llama/triton_kernel/splitfuse_context_flashattention_nopad.py +++ b/lightllm/models/llama/triton_kernel/splitfuse_context_flashattention_nopad.py @@ -8,33 +8,47 @@ @triton.jit def _fwd_kernel( - Q, K, V, sm_scale, Req_to_tokens, B_req_idx, - B_split_start_loc, - B_split_seq_len, - B_seqlen, + Q, + K, + V, + sm_scale, + Req_to_tokens, + B_req_idx, + B_split_start_loc, + B_split_ready_cache_len, + B_seqlen, Out, - stride_qbs, stride_qh, stride_qd, - stride_kbs, stride_kh, stride_kd, - stride_vbs, stride_vh, stride_vd, - stride_obs, stride_oh, stride_od, - stride_req_to_tokens_b, stride_req_to_tokens_s, + stride_qbs, + stride_qh, + stride_qd, + stride_kbs, + stride_kh, + stride_kd, + stride_vbs, + stride_vh, + stride_vd, + stride_obs, + stride_oh, + stride_od, + stride_req_to_tokens_b, + stride_req_to_tokens_s, kv_group_num, - BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr, + BLOCK_M: tl.constexpr, + BLOCK_DMODEL: tl.constexpr, BLOCK_N: tl.constexpr, ): cur_batch = tl.program_id(0) cur_head = tl.program_id(1) start_m = tl.program_id(2) - + cur_kv_head = cur_head // kv_group_num - - cur_batch_req_idx = tl.load(B_req_idx + cur_batch) - cur_batch_q_split_start_loc = tl.load(B_split_start_loc + cur_batch) - cur_batch_q_split_seq_len = tl.load(B_split_seq_len + cur_batch) + cur_batch_req_idx = tl.load(B_req_idx + cur_batch) + cur_batch_q_split_start_loc = tl.load(B_split_start_loc + cur_batch) + cur_batch_seq_start = tl.load(B_split_ready_cache_len + cur_batch) cur_batch_seq_len = tl.load(B_seqlen + cur_batch) - cur_batch_seq_start = cur_batch_seq_len - cur_batch_q_split_seq_len - + cur_batch_q_split_seq_len = cur_batch_seq_len - cur_batch_seq_start + # initialize offsets offs_n = tl.arange(0, BLOCK_N) offs_d = tl.arange(0, BLOCK_DMODEL) @@ -58,14 +72,19 @@ def _fwd_kernel( for start_n in range(0, block_mask * (cur_batch_seq_start + (start_m + 1) * BLOCK_M), BLOCK_N): start_n = tl.multiple_of(start_n, BLOCK_N) # -- compute qk ---- - kv_loc = tl.load(Req_to_tokens + cur_batch_req_idx * stride_req_to_tokens_b + start_n + offs_n, mask=(start_n + offs_n) < cur_batch_seq_len, other=0) - k = tl.load(k_ptrs + kv_loc[None, :] * stride_kbs, - mask=(start_n + offs_n[None, :]) < cur_batch_seq_len, other=0.0) + kv_loc = tl.load( + Req_to_tokens + cur_batch_req_idx * stride_req_to_tokens_b + start_n + offs_n, + mask=(start_n + offs_n) < 
cur_batch_seq_len, + other=0, + ) + k = tl.load( + k_ptrs + kv_loc[None, :] * stride_kbs, mask=(start_n + offs_n[None, :]) < cur_batch_seq_len, other=0.0 + ) qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) qk += tl.dot(q, k) qk *= sm_scale - qk = tl.where(cur_batch_seq_start + offs_m[:, None] >= (start_n + offs_n[None, :]), qk, float("-inf")) + qk = tl.where(cur_batch_seq_start + offs_m[:, None] >= (start_n + offs_n[None, :]), qk, float("-100000000.0")) # -- compute m_ij, p, l_ij m_ij = tl.max(qk, 1) @@ -84,8 +103,9 @@ def _fwd_kernel( acc_scale = l_i / l_i_new * alpha acc = acc * acc_scale[:, None] # update acc - v = tl.load(v_ptrs + kv_loc[:, None] * stride_vbs, - mask=(start_n + offs_n[:, None]) < cur_batch_seq_len, other=0.0) + v = tl.load( + v_ptrs + kv_loc[:, None] * stride_vbs, mask=(start_n + offs_n[:, None]) < cur_batch_seq_len, other=0.0 + ) p = p.to(v.dtype) acc += tl.dot(p, v) @@ -98,15 +118,21 @@ def _fwd_kernel( tl.store(out_ptrs, acc, mask=offs_m[:, None] < cur_batch_q_split_seq_len) return + @torch.no_grad() -def splitfuse_context_attention_fwd(q, k, v, o, - prefill_req_num, - req_to_tokens, - prefill_b_req_idx, - prefill_b_split_start_loc, - prefill_b_split_seq_len, - prefill_b_seq_len, - prefill_max_split_seq_len_in_batch): +def splitfuse_context_attention_fwd( + q, + k, + v, + o, + prefill_req_num, + req_to_tokens, + prefill_b_req_idx, + prefill_b_split_start_loc, + prefill_b_split_ready_cache_len, + prefill_b_seq_len, + prefill_max_split_seq_len_in_batch, +): if triton.__version__ == "2.0.0": raise Exception("triton version is not right") @@ -116,51 +142,85 @@ def splitfuse_context_attention_fwd(q, k, v, o, assert Lq == Lk and Lk == Lv assert Lk in {16, 32, 64, 128} - sm_scale = 1.0 / (Lq**0.5) # 计算scale系数 - batch, head = prefill_b_seq_len.shape[0], q.shape[1] + sm_scale = 1.0 / (Lq ** 0.5) # 计算scale系数 + _, head = prefill_b_seq_len.shape[0], q.shape[1] kv_group_num = q.shape[1] // k.shape[1] - + grid = (prefill_req_num, head, triton.cdiv(prefill_max_split_seq_len_in_batch, BLOCK)) num_warps = 4 if Lk <= 64 else 8 _fwd_kernel[grid]( - q, k, v, sm_scale, - req_to_tokens, + q, + k, + v, + sm_scale, + req_to_tokens, prefill_b_req_idx, prefill_b_split_start_loc, - prefill_b_split_seq_len, + prefill_b_split_ready_cache_len, prefill_b_seq_len, o, - q.stride(0), q.stride(1), q.stride(2), - k.stride(0), k.stride(1), k.stride(2), - v.stride(0), v.stride(1), v.stride(2), - o.stride(0), o.stride(1), o.stride(2), - req_to_tokens.stride(0), req_to_tokens.stride(1), + q.stride(0), + q.stride(1), + q.stride(2), + k.stride(0), + k.stride(1), + k.stride(2), + v.stride(0), + v.stride(1), + v.stride(2), + o.stride(0), + o.stride(1), + o.stride(2), + req_to_tokens.stride(0), + req_to_tokens.stride(1), kv_group_num=kv_group_num, BLOCK_M=BLOCK, BLOCK_DMODEL=Lk, BLOCK_N=BLOCK, num_warps=num_warps, - num_stages=1 + num_stages=1, ) return + @triton.jit def _fwd_kernel_int8( - Q, K, K_scale, V, V_scale, sm_scale, Req_to_tokens, B_req_idx, - B_split_start_loc, - B_split_seq_len, - B_seqlen, + Q, + K, + K_scale, + V, + V_scale, + sm_scale, + Req_to_tokens, + B_req_idx, + B_split_start_loc, + B_split_ready_cache_len, + B_seqlen, Out, - stride_qbs, stride_qh, stride_qd, - stride_kbs, stride_kh, stride_kd, - stride_ksbs, stride_ksh, stride_ksd, - stride_vbs, stride_vh, stride_vd, - stride_vsbs, stride_vsh, stride_vsd, - stride_obs, stride_oh, stride_od, - stride_req_to_tokens_b, stride_req_to_tokens_s, + stride_qbs, + stride_qh, + stride_qd, + stride_kbs, + stride_kh, + stride_kd, + stride_ksbs, + 
stride_ksh, + stride_ksd, + stride_vbs, + stride_vh, + stride_vd, + stride_vsbs, + stride_vsh, + stride_vsd, + stride_obs, + stride_oh, + stride_od, + stride_req_to_tokens_b, + stride_req_to_tokens_s, kv_group_num, - BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr, + BLOCK_M: tl.constexpr, + BLOCK_DMODEL: tl.constexpr, BLOCK_N: tl.constexpr, ): cur_batch = tl.program_id(0) @@ -170,11 +230,10 @@ def _fwd_kernel_int8( cur_kv_head = cur_head // kv_group_num cur_batch_req_idx = tl.load(B_req_idx + cur_batch) - cur_batch_q_split_start_loc = tl.load(B_split_start_loc + cur_batch) - cur_batch_q_split_seq_len = tl.load(B_split_seq_len + cur_batch) - + cur_batch_q_split_start_loc = tl.load(B_split_start_loc + cur_batch) cur_batch_seq_len = tl.load(B_seqlen + cur_batch) - cur_batch_seq_start = cur_batch_seq_len - cur_batch_q_split_seq_len + cur_batch_seq_start = tl.load(B_split_ready_cache_len + cur_batch) + cur_batch_q_split_seq_len = cur_batch_seq_len - cur_batch_seq_start # initialize offsets offs_n = tl.arange(0, BLOCK_N) @@ -201,15 +260,22 @@ def _fwd_kernel_int8( for start_n in range(0, block_mask * (cur_batch_seq_start + (start_m + 1) * BLOCK_M), BLOCK_N): start_n = tl.multiple_of(start_n, BLOCK_N) # -- compute qk ---- - kv_loc = tl.load(Req_to_tokens + cur_batch_req_idx * stride_req_to_tokens_b + start_n + offs_n, mask=(start_n + offs_n) < cur_batch_seq_len, other=0) - k = tl.load(k_ptrs + kv_loc[None, :] * stride_kbs, - mask=(start_n + offs_n[None, :]) < cur_batch_seq_len, other=0.0) - k_scale = tl.load(ks_ptrs + kv_loc[None, :] * stride_ksbs, mask=(start_n + offs_n[None, :]) < cur_batch_seq_len, other=0.0) + kv_loc = tl.load( + Req_to_tokens + cur_batch_req_idx * stride_req_to_tokens_b + start_n + offs_n, + mask=(start_n + offs_n) < cur_batch_seq_len, + other=0, + ) + k = tl.load( + k_ptrs + kv_loc[None, :] * stride_kbs, mask=(start_n + offs_n[None, :]) < cur_batch_seq_len, other=0.0 + ) + k_scale = tl.load( + ks_ptrs + kv_loc[None, :] * stride_ksbs, mask=(start_n + offs_n[None, :]) < cur_batch_seq_len, other=0.0 + ) qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) qk += tl.dot(q, (k_scale * k)) qk *= sm_scale - qk = tl.where(cur_batch_seq_start + offs_m[:, None] >= (start_n + offs_n[None, :]), qk, float("-inf")) + qk = tl.where(cur_batch_seq_start + offs_m[:, None] >= (start_n + offs_n[None, :]), qk, float("-100000000.0")) # -- compute m_ij, p, l_ij m_ij = tl.max(qk, 1) @@ -228,9 +294,12 @@ def _fwd_kernel_int8( acc_scale = l_i / l_i_new * alpha acc = acc * acc_scale[:, None] # update acc - v = tl.load(v_ptrs + kv_loc[:, None] * stride_vbs, - mask=(start_n + offs_n[:, None]) < cur_batch_seq_len, other=0.0) - v_scale = tl.load(vs_ptrs + kv_loc[:, None] * stride_vsbs, mask=(start_n + offs_n)[:, None] < cur_batch_seq_len, other=0.0) + v = tl.load( + v_ptrs + kv_loc[:, None] * stride_vbs, mask=(start_n + offs_n[:, None]) < cur_batch_seq_len, other=0.0 + ) + v_scale = tl.load( + vs_ptrs + kv_loc[:, None] * stride_vsbs, mask=(start_n + offs_n)[:, None] < cur_batch_seq_len, other=0.0 + ) p = p.to(tl.float16) acc += tl.dot(p, v.to(tl.float16) * v_scale) @@ -244,15 +313,23 @@ def _fwd_kernel_int8( tl.store(out_ptrs, acc, mask=offs_m[:, None] < cur_batch_q_split_seq_len) return + @torch.no_grad() -def splitfuse_context_attention_fwd_int8kv(q, k, k_scale, v, v_scale, o, - prefill_req_num, - req_to_tokens, - prefill_b_req_idx, - prefill_b_split_start_loc, - prefill_b_split_seq_len, - prefill_b_seq_len, - prefill_max_split_seq_len_in_batch): +def splitfuse_context_attention_fwd_int8kv( + q, + k, 
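+    # k_scale / v_scale below carry the per-token dequantization scales applied to the int8-stored k / v cache entries inside the kernel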
+ k_scale, + v, + v_scale, + o, + prefill_req_num, + req_to_tokens, + prefill_b_req_idx, + prefill_b_split_start_loc, + prefill_b_split_ready_cache_len, + prefill_b_seq_len, + prefill_max_split_seq_len_in_batch, +): if triton.__version__ == "2.0.0": raise Exception("triton version is not right") @@ -262,33 +339,51 @@ def splitfuse_context_attention_fwd_int8kv(q, k, k_scale, v, v_scale, o, assert Lq == Lk and Lk == Lv assert Lk in {16, 32, 64, 128} - sm_scale = 1.0 / (Lq**0.5) # 计算scale系数 - batch, head = prefill_b_seq_len.shape[0], q.shape[1] + sm_scale = 1.0 / (Lq ** 0.5) # 计算scale系数 + _, head = prefill_b_seq_len.shape[0], q.shape[1] kv_group_num = q.shape[1] // k.shape[1] - + grid = (prefill_req_num, head, triton.cdiv(prefill_max_split_seq_len_in_batch, BLOCK)) num_warps = 4 if Lk <= 64 else 8 _fwd_kernel_int8[grid]( - q, k, k_scale, v, v_scale, sm_scale, - req_to_tokens, + q, + k, + k_scale, + v, + v_scale, + sm_scale, + req_to_tokens, prefill_b_req_idx, prefill_b_split_start_loc, - prefill_b_split_seq_len, + prefill_b_split_ready_cache_len, prefill_b_seq_len, o, - q.stride(0), q.stride(1), q.stride(2), - k.stride(0), k.stride(1), k.stride(2), - k_scale.stride(0), k_scale.stride(1), k_scale.stride(2), - v.stride(0), v.stride(1), v.stride(2), - v_scale.stride(0), v_scale.stride(1), v_scale.stride(2), - o.stride(0), o.stride(1), o.stride(2), - req_to_tokens.stride(0), req_to_tokens.stride(1), + q.stride(0), + q.stride(1), + q.stride(2), + k.stride(0), + k.stride(1), + k.stride(2), + k_scale.stride(0), + k_scale.stride(1), + k_scale.stride(2), + v.stride(0), + v.stride(1), + v.stride(2), + v_scale.stride(0), + v_scale.stride(1), + v_scale.stride(2), + o.stride(0), + o.stride(1), + o.stride(2), + req_to_tokens.stride(0), + req_to_tokens.stride(1), kv_group_num=kv_group_num, BLOCK_M=BLOCK, BLOCK_DMODEL=Lk, BLOCK_N=BLOCK, num_warps=num_warps, - num_stages=1 + num_stages=1, ) return diff --git a/lightllm/models/llava/model.py b/lightllm/models/llava/model.py index 19e626cf9..f5d2f1a22 100644 --- a/lightllm/models/llava/model.py +++ b/lightllm/models/llava/model.py @@ -8,20 +8,19 @@ # Warp of the origal tokenizer class LlavaTokenizer: - def __init__(self, tokenizer, model_cfg): self.tokenizer = tokenizer self.image_token = model_cfg.get("image_token", "") - mm_vision_tower = model_cfg.get('mm_vision_tower', 'openai/clip-vit-large-patch14-336') + mm_vision_tower = model_cfg.get("mm_vision_tower", "openai/clip-vit-large-patch14-336") if isinstance(mm_vision_tower, list): mm_vision_tower = mm_vision_tower[0] - mm_vision_tower = mm_vision_tower.split('/')[-1] - vision_tower_match = re.match(r'^clip-vit-large-patch(\d+)-(\d+)$', mm_vision_tower) + mm_vision_tower = mm_vision_tower.split("/")[-1] + vision_tower_match = re.match(r"^clip-vit-large-patch(\d+)-(\d+)$", mm_vision_tower) patch_size = int(vision_tower_match.group(1)) default_img_size = int(vision_tower_match.group(2)) image_size = model_cfg.get("img_size", default_img_size) image_size = model_cfg.get("mm_image_size", image_size) - # (image_size // patch_size) ** 2: (336 // 14) ** 2 = 576 + # (image_size // patch_size) ** 2: (336 // 14) ** 2 = 576 self.image_length = (image_size // patch_size) ** 2 self.skip_start = model_cfg.get("skip_start", True) @@ -51,7 +50,7 @@ def encode(self, prompt, multimodal_params: MultimodalParams = None): return input_ids def __getattr__(self, name): - if name != 'encode': + if name != "encode": return getattr(self.tokenizer, name) return self.encode diff --git a/lightllm/models/starcoder/infer_struct.py 
b/lightllm/models/starcoder/infer_struct.py index fa7013202..613dd3b16 100644 --- a/lightllm/models/starcoder/infer_struct.py +++ b/lightllm/models/starcoder/infer_struct.py @@ -2,16 +2,22 @@ import numpy as np from lightllm.common.basemodel import InferStateInfo + class StarcoderInferStateInfo(InferStateInfo): def __init__(self): super().__init__() self.position_ids = None - - def init_some_extra_state(self, model, input_ids : torch.Tensor): + + def init_some_extra_state(self, model, input_ids: torch.Tensor): if self.is_prefill: b_seq_len_numpy = self.b_seq_len.cpu().numpy() - self.position_ids = torch.from_numpy(np.concatenate([np.arange(0, b_seq_len_numpy[i]) - for i in range(len(b_seq_len_numpy))], axis=0)).cuda() + b_ready_cache_len_numpy = self.b_ready_cache_len.cpu().numpy() + self.position_ids = torch.from_numpy( + np.concatenate( + [np.arange(b_ready_cache_len_numpy[i], b_seq_len_numpy[i]) for i in range(len(b_seq_len_numpy))], + axis=0, + ) + ).cuda() else: self.position_ids = self.b_seq_len - 1 self.other_kv_index = self.req_manager.req_to_token_indexs[self.b_req_idx[0], 0].item() diff --git a/lightllm/models/starcoder_wquant/layer_infer/transformer_layer_infer.py b/lightllm/models/starcoder_wquant/layer_infer/transformer_layer_infer.py index 745875f79..0e8989a09 100755 --- a/lightllm/models/starcoder_wquant/layer_infer/transformer_layer_infer.py +++ b/lightllm/models/starcoder_wquant/layer_infer/transformer_layer_infer.py @@ -44,7 +44,7 @@ def _bind_func(self): self._ffn_norm = partial(BloomTransformerLayerInfer._ffn_norm, self) LlamaTransformerLayerInferWquant._bind_matmul(self) - LlamaTransformerLayerInferWquant._bind_attention(self) + LlamaTransformerLayerInfer._bind_attention(self) return def _get_qkv( diff --git a/lightllm/server/api_server.py b/lightllm/server/api_server.py index 15be9a587..f0557cf34 100755 --- a/lightllm/server/api_server.py +++ b/lightllm/server/api_server.py @@ -59,6 +59,7 @@ ) from lightllm.utils.log_utils import init_logger + logger = init_logger(__name__) TIMEOUT_KEEP_ALIVE = 5 # seconds. @@ -78,6 +79,7 @@ def create_error_response(status_code: HTTPStatus, message: str) -> JSONResponse def healthcheck(): return "OK" + @app.post("/generate") async def generate(request: Request) -> Response: global isFirst @@ -110,7 +112,7 @@ async def generate(request: Request) -> Response: # Abort the request if the client disconnects. 
await httpserver_manager.abort(request_id) return Response(status_code=499) - + # when set "--return_all_prompt_logprobs", the first token metadata will contains # prompt_logprobs and prompt_token_ids if is_first_metadata: @@ -132,7 +134,7 @@ async def generate(request: Request) -> Response: ret = { "generated_text": ["".join(final_output)], "count_output_tokens": count_output_tokens, - "finish_reason": finish_status.get_finish_reason() + "finish_reason": finish_status.get_finish_reason(), } if return_details: ret["tokens"] = tokens @@ -154,7 +156,7 @@ async def generate_stream(request: Request) -> Response: request_dict = await request.json() prompt = request_dict.pop("inputs") sample_params_dict = request_dict["parameters"] - return_details = sample_params_dict.pop("return_details", False) + _ = sample_params_dict.pop("return_details", False) sampling_params = SamplingParams(**sample_params_dict) sampling_params.verify() multimodal_params_dict = request_dict.get("multimodal_params", {}) @@ -177,12 +179,10 @@ async def stream_results() -> AsyncGenerator[bytes, None]: "generated_text": None, "finished": finish_status.is_finished(), "finish_reason": finish_status.get_finish_reason(), - "details": None + "details": None, } - yield ("data:" + json.dumps(ret, ensure_ascii=False) + f"\n\n").encode( - "utf-8" - ) + yield ("data:" + json.dumps(ret, ensure_ascii=False) + "\n\n").encode("utf-8") async def abort_request() -> None: await httpserver_manager.abort(request_id) @@ -191,15 +191,11 @@ async def abort_request() -> None: # Abort the request if the client disconnects. background_tasks.add_task(abort_request) - return StreamingResponse( - stream_results(), media_type="text/event-stream", background=background_tasks - ) + return StreamingResponse(stream_results(), media_type="text/event-stream", background=background_tasks) @app.post("/v1/chat/completions", response_model=ChatCompletionResponse) -async def chat_completions( - request: ChatCompletionRequest, raw_request: Request -) -> Response: +async def chat_completions(request: ChatCompletionRequest, raw_request: Request) -> Response: global isFirst if isFirst: loop = asyncio.get_event_loop() @@ -213,14 +209,10 @@ async def chat_completions( ) if request.n > 1: - return create_error_response( - HTTPStatus.BAD_REQUEST, "The n parameter currently only supports 1" - ) + return create_error_response(HTTPStatus.BAD_REQUEST, "The n parameter currently only supports 1") if request.function_call != "none": - return create_error_response( - HTTPStatus.BAD_REQUEST, "The function call feature is not supported" - ) + return create_error_response(HTTPStatus.BAD_REQUEST, "The function call feature is not supported") created_time = int(time.time()) prompt = await build_prompt(request) @@ -233,7 +225,7 @@ async def chat_completions( top_k=request.top_k, ignore_eos=request.ignore_eos, max_new_tokens=request.max_tokens, - stop_sequences=request.stop + stop_sequences=request.stop, ) sampling_params.verify() multimodal_params = MultimodalParams(images=[]) @@ -259,16 +251,12 @@ async def chat_completions( usage = UsageInfo( prompt_tokens=prompt_tokens, completion_tokens=completion_tokens, - total_tokens=prompt_tokens + completion_tokens + total_tokens=prompt_tokens + completion_tokens, ) chat_message = ChatMessage(role="assistant", content="".join(final_output)) choice = ChatCompletionResponseChoice(index=0, message=chat_message) resp = ChatCompletionResponse( - id=request_id, - created=created_time, - model=request.model, - choices=[choice], - usage=usage + 
id=request_id, created=created_time, model=request.model, choices=[choice], usage=usage ) return resp @@ -277,9 +265,7 @@ async def stream_results() -> AsyncGenerator[bytes, None]: async for request_output, metadata, _ in results_generator: delta_message = DeltaMessage(role="assistant", content=request_output) - stream_choice = ChatCompletionStreamResponseChoice( - index=0, delta=delta_message - ) + stream_choice = ChatCompletionStreamResponseChoice(index=0, delta=delta_message) stream_resp = ChatCompletionStreamResponse( id=request_id, @@ -287,7 +273,7 @@ async def stream_results() -> AsyncGenerator[bytes, None]: model=request.model, choices=[stream_choice], ) - yield ("data: " + stream_resp.json(ensure_ascii=False) + f"\n\n").encode("utf-8") + yield ("data: " + stream_resp.json(ensure_ascii=False) + "\n\n").encode("utf-8") async def abort_request() -> None: await httpserver_manager.abort(request_id) @@ -296,9 +282,7 @@ async def abort_request() -> None: # Abort the request if the client disconnects. background_tasks.add_task(abort_request) - return StreamingResponse( - stream_results(), media_type="text/event-stream", background=background_tasks - ) + return StreamingResponse(stream_results(), media_type="text/event-stream", background=background_tasks) def main(): @@ -306,86 +290,120 @@ def main(): parser.add_argument("--host", type=str, default="127.0.0.1") parser.add_argument("--port", type=int, default=8000) - parser.add_argument("--model_dir", type=str, default=None, - help="the model weight dir path, the app will load config, weights and tokenizer from this dir") - parser.add_argument("--tokenizer_mode", type=str, default="slow", - help="""tokenizer load mode, can be slow or auto, slow mode load fast but run slow, slow mode is good for debug and test, - when you want to get best performance, try auto mode""") - parser.add_argument("--load_way", type=str, default="HF", - help="the way of loading model weights, the default is HF(Huggingface format), llama also supports DS(Deepspeed)") - parser.add_argument("--max_total_token_num", type=int, default=6000, - help="the total token nums the gpu and model can support, equals = max_batch * (input_len + output_len)") - parser.add_argument("--batch_max_tokens", type=int, default=None, - help="max tokens num for new cat batch, it control prefill batch size to Preventing OOM") - parser.add_argument("--eos_id", type=int, default=2, - help="eos stop token id") - parser.add_argument("--running_max_req_size", type=int, default=1000, - help="the max size for forward requests in the same time") - parser.add_argument("--tp", type=int, default=1, - help="model tp parral size, the default is 1") - parser.add_argument("--max_req_input_len", type=int, default=2048, - help="the max value for req input tokens num") - parser.add_argument("--max_req_total_len", type=int, default=2048 + 1024, - help="the max value for req_input_len + req_output_len") - parser.add_argument("--nccl_port", type=int, default=28765, - help="the nccl_port to build a distributed environment for PyTorch") - parser.add_argument("--mode", type=str, default=[], nargs='+', - help="""Model mode: [triton_int8kv | ppl_int8kv | ppl_fp16 | triton_flashdecoding - | triton_gqa_attention | triton_gqa_flashdecoding] - [triton_w4a16 | triton_w8a16 | lmdeploy_w4a16 | ppl_w4a16 | ppl_w8a8], + parser.add_argument( + "--model_dir", + type=str, + default=None, + help="the model weight dir path, the app will load config, weights and tokenizer from this dir", + ) + parser.add_argument( + 
"--tokenizer_mode", + type=str, + default="slow", + help="""tokenizer load mode, can be slow or auto, slow mode load fast but run slow, slow mode is + good for debug and test, when you want to get best performance, try auto mode""", + ) + parser.add_argument( + "--load_way", + type=str, + default="HF", + help="""the way of loading model weights, the default is HF(Huggingface format), llama also supports + DS(Deepspeed)""", + ) + parser.add_argument( + "--max_total_token_num", + type=int, + default=6000, + help="the total token nums the gpu and model can support, equals = max_batch * (input_len + output_len)", + ) + parser.add_argument( + "--batch_max_tokens", + type=int, + default=None, + help="max tokens num for new cat batch, it control prefill batch size to Preventing OOM", + ) + parser.add_argument("--eos_id", type=int, default=2, help="eos stop token id") + parser.add_argument( + "--running_max_req_size", type=int, default=1000, help="the max size for forward requests in the same time" + ) + parser.add_argument("--tp", type=int, default=1, help="model tp parral size, the default is 1") + parser.add_argument("--max_req_input_len", type=int, default=2048, help="the max value for req input tokens num") + parser.add_argument( + "--max_req_total_len", type=int, default=2048 + 1024, help="the max value for req_input_len + req_output_len" + ) + parser.add_argument( + "--nccl_port", type=int, default=28765, help="the nccl_port to build a distributed environment for PyTorch" + ) + parser.add_argument( + "--mode", + type=str, + default=[], + nargs="+", + help="""Model mode: [triton_int8kv | ppl_int8kv | ppl_fp16 | triton_flashdecoding + | triton_gqa_attention | triton_gqa_flashdecoding] + [triton_w4a16 | triton_w8a16 | lmdeploy_w4a16 | ppl_w4a16 | ppl_w8a8], triton_flashdecoding mode is for long context, current support llama llama2 qwen; triton_gqa_attention and triton_gqa_flashdecoding is fast kernel for model which use GQA; triton_int8kv mode use int8 to store kv cache, can increase token capacity, use triton kernel; ppl_int8kv mode use int8 to store kv cache, and use ppl fast kernel; ppl_fp16 mode use ppl fast fp16 decode attention kernel; - triton_int8weight and triton_int4weight and lmdeploy_int4weight or ppl_int4weight mode use int8 and int4 to store weights; - you need to read source code to make sure the supported detail mode for all models""") - parser.add_argument("--trust_remote_code", action='store_true', - help="Whether or not to allow for custom models defined on the Hub in their own modeling files.") - parser.add_argument("--disable_log_stats", action='store_true', - help="disable logging throughput stats.") - parser.add_argument("--log_stats_interval", type=int, default=10, - help="log stats interval in second.") - - parser.add_argument("--router_token_ratio", type=float, default=0.0, - help="token ratio to control router dispatch") - parser.add_argument("--router_max_new_token_len", type=int, default=1024, - help="the request max new token len for router") - - parser.add_argument("--no_skipping_special_tokens", action="store_true", - help="whether to skip special tokens when decoding") - parser.add_argument("--no_spaces_between_special_tokens", action="store_true", - help="whether to add spaces between special tokens when decoding") + triton_int8weight and triton_int4weight and lmdeploy_int4weight or ppl_int4weight mode + use int8 and int4 to store weights; + you need to read source code to make sure the supported detail mode for all models""", + ) + parser.add_argument( + 
"--trust_remote_code", + action="store_true", + help="Whether or not to allow for custom models defined on the Hub in their own modeling files.", + ) + parser.add_argument("--disable_log_stats", action="store_true", help="disable logging throughput stats.") + parser.add_argument("--log_stats_interval", type=int, default=10, help="log stats interval in second.") + + parser.add_argument("--router_token_ratio", type=float, default=0.0, help="token ratio to control router dispatch") + parser.add_argument( + "--router_max_new_token_len", type=int, default=1024, help="the request max new token len for router" + ) + + parser.add_argument( + "--no_skipping_special_tokens", action="store_true", help="whether to skip special tokens when decoding" + ) + parser.add_argument( + "--no_spaces_between_special_tokens", + action="store_true", + help="whether to add spaces between special tokens when decoding", + ) + + parser.add_argument("--use_dynamic_prompt_cache", action="store_true", help="use_dynamic_prompt_cache test") + + parser.add_argument("--splitfuse_mode", action="store_true", help="use splitfuse mode") - parser.add_argument("--splitfuse_mode", action='store_true', - help="use splitfuse mode") - parser.add_argument("--splitfuse_block_size", type=int, default=256, - help="splitfuse block size") - parser.add_argument("--prompt_cache_strs", type=str, default=[], nargs='+', - help="""prompt cache strs""") - parser.add_argument("--enable_multimodal", action='store_true', - help="Whether or not to allow to load additional multimodal models.") - parser.add_argument("--cache_capacity", type=int, default=200, - help="cache server capacity for multimodal resources") - parser.add_argument("--cache_reserved_ratio", type=float, default=0.5, - help="cache server reserved capacity ratio after clear") - parser.add_argument("--return_all_prompt_logprobs", action="store_true", - help="return all prompt tokens logprobs") - parser.add_argument("--long_truncation_mode", type=str, choices=[None, 'head', 'center'], default=None, - help="""use to select the handle way when input token len > max_req_input_len. - None : raise Exception + parser.add_argument("--splitfuse_block_size", type=int, default=256, help="splitfuse block size") + parser.add_argument( + "--enable_multimodal", action="store_true", help="Whether or not to allow to load additional multimodal models." + ) + parser.add_argument( + "--cache_capacity", type=int, default=200, help="cache server capacity for multimodal resources" + ) + parser.add_argument( + "--cache_reserved_ratio", type=float, default=0.5, help="cache server reserved capacity ratio after clear" + ) + parser.add_argument("--return_all_prompt_logprobs", action="store_true", help="return all prompt tokens logprobs") + parser.add_argument( + "--long_truncation_mode", + type=str, + choices=[None, "head", "center"], + default=None, + help="""use to select the handle way when input token len > max_req_input_len. 
+ None : raise Exception head : remove some head tokens to make input token len <= max_req_input_len - center : remove some tokens in center loc to make input token len <= max_req_input_len""") - - args = parser.parse_args() + center : remove some tokens in center loc to make input token len <= max_req_input_len""", + ) - # 非splitfuse 模式,不支持 prompt cache 特性 - if not args.splitfuse_mode: - assert len(args.prompt_cache_strs) == 0 + args = parser.parse_args() assert args.max_req_input_len < args.max_req_total_len assert args.max_req_total_len <= args.max_total_token_num - + if not args.splitfuse_mode: # 普通模式下 if args.batch_max_tokens is None: @@ -393,9 +411,7 @@ def main(): batch_max_tokens = max(batch_max_tokens, args.max_req_total_len) args.batch_max_tokens = batch_max_tokens else: - assert ( - args.batch_max_tokens >= args.max_req_total_len - ), "batch_max_tokens must >= max_req_total_len" + assert args.batch_max_tokens >= args.max_req_total_len, "batch_max_tokens must >= max_req_total_len" else: # splitfuse 模式下 # assert args.batch_max_tokens is not None, "need to set by yourself" @@ -404,22 +420,24 @@ def main(): batch_max_tokens = max(batch_max_tokens, args.splitfuse_block_size) args.batch_max_tokens = batch_max_tokens - can_use_ports = alloc_can_use_network_port( - num=5 + args.tp, used_nccl_port=args.nccl_port - ) + can_use_ports = alloc_can_use_network_port(num=5 + args.tp, used_nccl_port=args.nccl_port) router_port, detokenization_port, httpserver_port, visual_port, cache_port = can_use_ports[0:5] model_rpc_ports = can_use_ports[5:] if args.enable_multimodal: - start_submodule_processes(start_funcs=[start_cache_manager,], - start_args=[(cache_port, args)]) + start_submodule_processes( + start_funcs=[ + start_cache_manager, + ], + start_args=[(cache_port, args)], + ) # help to manage data stored on Ceph - if 's3://' in args.model_dir: + if "s3://" in args.model_dir: from lightllm.utils.petrel_helper import s3_model_prepare + s3_model_prepare(args.model_dir) - from .httpserver.manager import HttpServerManager global httpserver_manager httpserver_manager = HttpServerManager( args, @@ -429,17 +447,27 @@ def main(): httpserver_port=httpserver_port, enable_multimodal=args.enable_multimodal, ) - - from .detokenization.manager import start_detokenization_process - start_submodule_processes(start_funcs=[start_router_process, start_detokenization_process], - start_args=[(args, router_port, detokenization_port, model_rpc_ports), - (args, detokenization_port, httpserver_port)]) + + start_submodule_processes( + start_funcs=[start_router_process, start_detokenization_process], + start_args=[ + (args, router_port, detokenization_port, model_rpc_ports), + (args, detokenization_port, httpserver_port), + ], + ) if args.enable_multimodal: - start_submodule_processes(start_funcs=[start_visual_process,], - start_args=[(args, router_port, visual_port, cache_port),]) + start_submodule_processes( + start_funcs=[ + start_visual_process, + ], + start_args=[ + (args, router_port, visual_port, cache_port), + ], + ) if "s3://" in args.model_dir: from lightllm.utils.petrel_helper import s3_model_clear + s3_model_clear(args.model_dir) uvicorn.run( @@ -453,5 +481,5 @@ def main(): if __name__ == "__main__": - torch.multiprocessing.set_start_method('spawn'), # this code will not be ok for settings to fork to subprocess + torch.multiprocessing.set_start_method("spawn"), # this code will not be ok for settings to fork to subprocess main() diff --git a/lightllm/server/httpserver/manager.py 
b/lightllm/server/httpserver/manager.py index 4f2fb0f26..25ade7d45 100644 --- a/lightllm/server/httpserver/manager.py +++ b/lightllm/server/httpserver/manager.py @@ -11,6 +11,7 @@ from ..io_struct import BatchStrOut, AbortReq, FinishStatus from ..embed_cache.utils import get_shm_name_data, create_shm + class HttpServerManager: def __init__( self, @@ -34,10 +35,8 @@ def __init__( self.recv_from_detokenization = context.socket(zmq.PULL) self.recv_from_detokenization.bind(f"tcp://127.0.0.1:{httpserver_port}") - - self.tokenizer = get_tokenizer( - args.model_dir, args.tokenizer_mode, trust_remote_code=args.trust_remote_code - ) + + self.tokenizer = get_tokenizer(args.model_dir, args.tokenizer_mode, trust_remote_code=args.trust_remote_code) self.req_id_to_out_inf = {} # value type (out_str, metadata, finished, event) @@ -45,34 +44,7 @@ def __init__( self.max_req_input_len = args.max_req_input_len self.max_req_total_len = args.max_req_total_len - self._init_prompt_cache() - return - - def _init_prompt_cache(self): - """ - 初始化 prompt cache 特性, 这个地方的id 分配要于 router 中 的id 分配对齐 - """ - self.prompt_cache_reqs = [] - # 初始化 prompt cahce, 然后初始化请求队列 - if self.args.splitfuse_mode: - id = -1 # id 从 -1, -2, .... 避免和正常的 id 占用 - for prompt_cache_str in self.args.prompt_cache_strs: - prompt_ids = self.tokenizer.encode(prompt_cache_str) - self.prompt_cache_reqs.append((id, prompt_ids)) - id -= 1 return - - def _find_prompt_cache_req(self, token_ids): - prompt_cache_len = 0 - prompt_cache_req_id = None - for (req_id, prompt_ids) in self.prompt_cache_reqs: - prompt_len = len(prompt_ids) - if len(token_ids) > prompt_len: - if token_ids[0 : prompt_len] == prompt_ids: - prompt_cache_len = prompt_len - prompt_cache_req_id = req_id - break - return prompt_cache_len, prompt_cache_req_id # connect cache server, calculate md5, alloc resource, return uuid async def _alloc_resource(self, data, num): @@ -95,9 +67,9 @@ async def _alloc_resource(self, data, num): async def _alloc_multimodal_resources(self, multimodal_params): for img in multimodal_params.images: record = await self._alloc_resource(img.read(), self.tokenizer.image_length) - img.uuid = record['id'] - img.token_id = record['token_id'] - img.token_num = record['token_num'] + img.uuid = record["id"] + img.token_id = record["token_id"] + img.token_num = record["token_num"] async def _release_multimodal_resources(self, multimodal_params): if multimodal_params is not None: @@ -107,9 +79,7 @@ async def _release_multimodal_resources(self, multimodal_params): async def generate(self, prompt, sampling_params, request_id, multimodal_params): if self.enable_multimodal: - assert ( - len(multimodal_params.images) <= self.args.cache_capacity - ), "too many images!" + assert len(multimodal_params.images) <= self.args.cache_capacity, "too many images!" await self._alloc_multimodal_resources(multimodal_params) prompt_ids = self.tokenizer.encode(prompt, multimodal_params) else: @@ -119,14 +89,15 @@ async def generate(self, prompt, sampling_params, request_id, multimodal_params) if prompt_tokens > self.max_req_input_len: # use long_truncation_mode to truncate long input len req. 
if self.args.long_truncation_mode is None: - raise ValueError( - f"the input prompt token len {prompt_tokens} is too long > {self.max_req_input_len}" - ) + raise ValueError(f"the input prompt token len {prompt_tokens} is too long > {self.max_req_input_len}") elif self.args.long_truncation_mode == "head": - prompt_ids = prompt_ids[-self.max_req_input_len:] + prompt_ids = prompt_ids[-self.max_req_input_len :] prompt_tokens = len(prompt_ids) elif self.args.long_truncation_mode == "center": - prompt_ids = prompt_ids[0:self.max_req_input_len // 2] + prompt_ids[-(self.max_req_input_len - self.max_req_input_len // 2):] + prompt_ids = ( + prompt_ids[0 : self.max_req_input_len // 2] + + prompt_ids[-(self.max_req_input_len - self.max_req_input_len // 2) :] + ) prompt_tokens = len(prompt_ids) assert prompt_tokens == self.max_req_input_len else: @@ -135,26 +106,21 @@ async def generate(self, prompt, sampling_params, request_id, multimodal_params) req_total_len = prompt_tokens + sampling_params.max_new_tokens if req_total_len > self.max_req_total_len: raise ValueError( - f"the req token total len (input len + output len) is too long > max_req_total_len:{self.max_req_total_len}" + f"the req total len (input len + output len) is too long > max_req_total_len:{self.max_req_total_len}" ) if req_total_len + 1 > self.total_token_num: - raise ValueError( - f"the req token total len + 1 (input len + output len + 1) is too long > max_total_token_num:{self.total_token_num}" - ) - + raise ValueError(f"the req token total len + 1 is too long > max_total_token_num:{self.total_token_num}") + sampling_params.stop_sentences_to_token_ids(self.tokenizer) req_status = ReqStatus(request_id, multimodal_params) event = req_status.event self.req_id_to_out_inf[request_id] = req_status - # 寻找是否有可用的prompt cache 可用 - prompt_cache_len, prompt_cache_req_id = self._find_prompt_cache_req(prompt_ids) - if self.enable_multimodal: - self.send_to_visual.send_pyobj((prompt_ids, sampling_params, multimodal_params, request_id, prompt_cache_len, prompt_cache_req_id)) + self.send_to_visual.send_pyobj((prompt_ids, sampling_params, multimodal_params, request_id)) else: - self.send_to_router.send_pyobj((prompt_ids, sampling_params, multimodal_params, request_id, prompt_cache_len, prompt_cache_req_id)) + self.send_to_router.send_pyobj((prompt_ids, sampling_params, multimodal_params, request_id)) while True: try: @@ -197,15 +163,13 @@ async def abort(self, request_id): async def handle_loop(self): while True: recv_ans: BatchStrOut = await self.recv_from_detokenization.recv_pyobj() - assert isinstance( - recv_ans, BatchStrOut - ), f"error recv type {type(recv_ans)}" + assert isinstance(recv_ans, BatchStrOut), f"error recv type {type(recv_ans)}" for req_id, text, metadata, finish_status in recv_ans.reqs_infs: finish_status = FinishStatus(finish_status) try: if not finish_status.is_aborted(): - req_status : ReqStatus = self.req_id_to_out_inf[req_id] - async with req_status.lock: + req_status: ReqStatus = self.req_id_to_out_inf[req_id] + async with req_status.lock: req_status.out_token_info_list.append((text, metadata, finish_status)) req_status.event.set() else: @@ -214,6 +178,7 @@ async def handle_loop(self): pass return + class ReqStatus: def __init__(self, req_id, multimodal_params) -> None: self.req_id = req_id diff --git a/lightllm/server/io_struct.py b/lightllm/server/io_struct.py index 7f3f88b5a..deed1b806 100644 --- a/lightllm/server/io_struct.py +++ b/lightllm/server/io_struct.py @@ -4,23 +4,23 @@ import asyncio import enum + class 
ReqRunStatus(enum.Enum): - WAIT_IN_QUEUE = 0 # 在队列中等待 - RUNNING = 1 # 运行 - PAUSED_AND_KVKEEP = 2 # 暂停保留KV - PAUSED_AND_OFFLOAD = 3 # 暂停卸载KV - RERUNNING_FROM_KVKEEP = 4 # 从暂停中恢复 - RERUNNING_FROM_OFFLOAD = 5 # 从卸载KV中恢复 + WAIT_IN_QUEUE = 0 # 在队列中等待 + RUNNING = 1 # 运行 + PAUSED_AND_OFFLOAD = 2 # 暂停卸载KV + RERUNNING_FROM_OFFLOAD = 3 # 从卸载KV中恢复 + class FinishStatus(enum.Enum): - NO_FINISH = 0 # 没有结束 - FINISHED_STOP = 1 # 因为遇到了STOP token 而结束 - FINISHED_LENGTH = 2 # 因为长度达到了最大长度而结束 - FINISHED_ABORT = 3 # 因为请求被中止而结束 + NO_FINISH = 0 # 没有结束 + FINISHED_STOP = 1 # 因为遇到了STOP token 而结束 + FINISHED_LENGTH = 2 # 因为长度达到了最大长度而结束 + FINISHED_ABORT = 3 # 因为请求被中止而结束 def is_finished(self): return 1 <= self.value <= 3 - + def is_aborted(self): return self == FinishStatus.FINISHED_ABORT @@ -35,8 +35,9 @@ def get_finish_reason(self): finish_reason = None return finish_reason + class Req: - def __init__(self, request_id, prompt_ids, sample_params: SamplingParams, multimodal_params: MultimodalParams, prompt_cache_len=0, prompt_cache_req_id=None): + def __init__(self, request_id, prompt_ids, sample_params: SamplingParams, multimodal_params: MultimodalParams): self.request_id = request_id self.prompt_ids = prompt_ids self.input_len = len(prompt_ids) @@ -48,28 +49,27 @@ def __init__(self, request_id, prompt_ids, sample_params: SamplingParams, multim self.req_status = ReqRunStatus.WAIT_IN_QUEUE self.finish_status = FinishStatus.NO_FINISH - self.cur_kv_len = 0 # 当前已经占用掉 token 的 kv len 长度 - self.prompt_cache_len = prompt_cache_len # 可以复用的一些公共 prompt 头对应的 kv cache 长度, 只有 splitfuse 模式当前才实际使用 - self.prompt_cache_req_id = prompt_cache_req_id # 对应的可复用的请求的 id,方便初始化的时候,将其 kv cache 复制到当前请求中, 默认值 为 None - assert self.input_len > self.prompt_cache_len + self.cur_kv_len = 0 # 当前已经占用掉 token 的 kv len 长度 return - + def to_rpc_obj(self): - return {"request_id": self.request_id, - "input_id": self.prompt_ids, - "output_len": self.max_output_len, - "sampling_param": self.sample_params.to_dict(), - "multimodal_params": self.multimodal_params.to_dict(), - "prompt_cache_len": self.prompt_cache_len, - "prompt_cache_req_id": self.prompt_cache_req_id, - "req_status": self.req_status} - + return { + "request_id": self.request_id, + "input_id": self.prompt_ids, + "output_len": self.max_output_len, + "sampling_param": self.sample_params.to_dict(), + "multimodal_params": self.multimodal_params.to_dict(), + "req_status": self.req_status, + } + def to_req_detokenization_state(self): - out = ReqDetokenizationState(self.request_id, self.prompt_ids, self.max_output_len, self.sample_params.ignore_eos) + out = ReqDetokenizationState( + self.request_id, self.prompt_ids, self.max_output_len, self.sample_params.ignore_eos + ) # if self.output_metadata_list: # looks like no use # out.gen_metadata.update(self.output_metadata_list[-1]) return out - + def stop_sequences_matched(self): for stop_token_ids in self.sample_params.stop_sequences: stop_len = len(stop_token_ids) @@ -80,26 +80,26 @@ def stop_sequences_matched(self): return False def __repr__(self): - return (f"request_id(n={self.request_id}, " - f"prompt_ids={self.prompt_ids}, ") - + return f"request_id(n={self.request_id}, " f"prompt_ids={self.prompt_ids}, " + def get_used_tokens(self): - return max(0, self.cur_kv_len - self.prompt_cache_len) + return max(0, self.cur_kv_len) def get_tuple_tokens(self, is_busy, router_max_new_token_len): raise Exception("need to impl") - + def get_decode_need_tokens(self): raise Exception("need to impl") - + def get_first_router_need_tokens(self): raise Exception("need to 
impl") + class NormalReq(Req): - def __init__(self, request_id, prompt_ids, sample_params: SamplingParams, multimodal_params: MultimodalParams, prompt_cache_len=0, prompt_cache_req_id=None): - super().__init__(request_id, prompt_ids, sample_params, multimodal_params, prompt_cache_len, prompt_cache_req_id) + def __init__(self, request_id, prompt_ids, sample_params: SamplingParams, multimodal_params: MultimodalParams): + super().__init__(request_id, prompt_ids, sample_params, multimodal_params) return - + def get_tuple_tokens(self, is_busy, router_max_new_token_len): """ 普通continues batch调度模式, 先prefill 后 decode 的估计方式 的实现 @@ -115,39 +115,43 @@ def get_tuple_tokens(self, is_busy, router_max_new_token_len): cur_max_new_token_len = min(self.max_output_len, max(int(1.1 * has_out_len), router_max_new_token_len)) if self.req_status == ReqRunStatus.RUNNING: - return (self.input_len + has_out_len - self.prompt_cache_len, max(0, cur_max_new_token_len - has_out_len - 1)) + return (self.input_len + has_out_len, max(0, cur_max_new_token_len - has_out_len - 1)) elif self.req_status == ReqRunStatus.WAIT_IN_QUEUE: - return (self.input_len + 1 - self.prompt_cache_len, max(0, cur_max_new_token_len - 1 - 1)) + return (self.input_len + 1, max(0, cur_max_new_token_len - 1 - 1)) elif self.req_status == ReqRunStatus.PAUSED_AND_OFFLOAD: - return (self.input_len + has_out_len + 1 - self.prompt_cache_len, max(0, cur_max_new_token_len - has_out_len - 1 - 1)) - elif self.req_status == ReqRunStatus.PAUSED_AND_KVKEEP: - return (self.input_len + has_out_len - self.prompt_cache_len, max(0, cur_max_new_token_len - has_out_len - 1)) + return (self.input_len + has_out_len + 1, max(0, cur_max_new_token_len - has_out_len - 1 - 1)) else: assert False, "error state" return - + def get_decode_need_tokens(self): if self.req_status == ReqRunStatus.RUNNING: return 1 else: assert False, "error state" - + def get_first_router_need_tokens(self): if self.req_status == ReqRunStatus.WAIT_IN_QUEUE: return self.input_len elif self.req_status == ReqRunStatus.PAUSED_AND_OFFLOAD: return self.input_len + len(self.output_ids) - elif self.req_status == ReqRunStatus.PAUSED_AND_KVKEEP: - return 0 else: assert False, "error state" + class SplitFuseReq(Req): - def __init__(self, request_id, prompt_ids, sample_params: SamplingParams, multimodal_params: MultimodalParams, prompt_cache_len=0, prompt_cache_req_id=None, splitfuse_block_size=None): - super().__init__(request_id, prompt_ids, sample_params, multimodal_params, prompt_cache_len, prompt_cache_req_id) + def __init__( + self, + request_id, + prompt_ids, + sample_params: SamplingParams, + multimodal_params: MultimodalParams, + splitfuse_block_size=None, + ): + super().__init__(request_id, prompt_ids, sample_params, multimodal_params) self.splitfuse_block_size = splitfuse_block_size return - + def get_tuple_tokens(self, is_busy, router_max_new_token_len): """ splitfuse 调度模式的实现 @@ -160,19 +164,43 @@ def get_tuple_tokens(self, is_busy, router_max_new_token_len): else: cur_max_new_token_len = min(self.max_output_len, max(int(1.1 * has_out_len), router_max_new_token_len)) - if self.req_status == ReqRunStatus.RUNNING or self.req_status == ReqRunStatus.PAUSED_AND_KVKEEP: - return (self.input_len + has_out_len - self.prompt_cache_len, - max(0, (self.input_len + has_out_len - self.prompt_cache_len - self.cur_kv_len + self.splitfuse_block_size - 1) // self.splitfuse_block_size + cur_max_new_token_len - has_out_len - 1)) + if self.req_status == ReqRunStatus.RUNNING: + return ( + self.input_len + has_out_len, + 
max( + 0, + (self.input_len + has_out_len - self.cur_kv_len + self.splitfuse_block_size - 1) + // self.splitfuse_block_size + + cur_max_new_token_len + - has_out_len + - 1, + ), + ) elif self.req_status == ReqRunStatus.WAIT_IN_QUEUE: - return (self.input_len - self.prompt_cache_len, - max(0, (self.input_len - self.prompt_cache_len + self.splitfuse_block_size - 1) // self.splitfuse_block_size + cur_max_new_token_len - 1)) + return ( + self.input_len, + max( + 0, + (self.input_len + self.splitfuse_block_size - 1) // self.splitfuse_block_size + + cur_max_new_token_len + - 1, + ), + ) elif self.req_status == ReqRunStatus.PAUSED_AND_OFFLOAD: - return (self.input_len + has_out_len - self.prompt_cache_len, - max(0, (self.input_len + has_out_len - self.prompt_cache_len + self.splitfuse_block_size - 1) // self.splitfuse_block_size + cur_max_new_token_len - has_out_len - 1)) + return ( + self.input_len + has_out_len, + max( + 0, + (self.input_len + has_out_len + self.splitfuse_block_size - 1) // self.splitfuse_block_size + + cur_max_new_token_len + - has_out_len + - 1, + ), + ) else: assert False, "error state" return - + def get_decode_need_tokens(self): """ splitfuse 调度模式的实现 @@ -181,17 +209,16 @@ def get_decode_need_tokens(self): return min(self.input_len + len(self.output_ids) - self.cur_kv_len, self.splitfuse_block_size) else: assert False, "error state" - + def get_first_router_need_tokens(self): if self.req_status == ReqRunStatus.WAIT_IN_QUEUE: - return min(self.input_len - self.prompt_cache_len, self.splitfuse_block_size) + return min(self.input_len, self.splitfuse_block_size) elif self.req_status == ReqRunStatus.PAUSED_AND_OFFLOAD: - return min(self.input_len + len(self.output_ids) - self.prompt_cache_len, self.splitfuse_block_size) - elif self.req_status == ReqRunStatus.PAUSED_AND_KVKEEP: - return min(self.input_len + len(self.output_ids) - self.cur_kv_len, self.splitfuse_block_size) + return min(self.input_len + len(self.output_ids), self.splitfuse_block_size) else: assert False, "error state" + class ReqDetokenizationState: def __init__( self, @@ -211,6 +238,7 @@ def __init__( self.ignore_eos = ignore_eos self.gen_metadata = {} + class Batch: def __init__(self, batch_id, reqs: List[Req]): self.batch_id = batch_id @@ -249,16 +277,16 @@ def mark_and_get_finished_req_and_preupdate_status(self, eos_id): self.batch_decode_need_tokens -= req.get_decode_need_tokens() else: unfinished_req_ids.append(req.request_id) - + return unfinished_req_ids, finished_req_ids - + def filter_out_finished_req(self, unfinished_req_ids, finished_req_ids): # update batch if len(finished_req_ids) != 0: self.reqs = [self.id_to_reqs[req_id] for req_id in unfinished_req_ids] self.id_to_reqs = {req.request_id: req for req in self.reqs} return - + def pop_req(self, req_id): self.reqs = [req for req in self.reqs if req.request_id != req_id] req = self.id_to_reqs[req_id] @@ -279,18 +307,19 @@ def merge(self, mini_batch): return def __repr__(self): - return (f"batch_id={self.batch_id}, " - f"reqs={self.reqs}, ") - + return f"batch_id={self.batch_id}, " f"reqs={self.reqs}, " + + class BatchTokenIdOut: def __init__(self): self.reqs_infs: List[Tuple[str, int, Dict, int]] = [] # [req_id, new_token_id, gen_metadata, finish_status] + class BatchStrOut: def __init__(self): - self.reqs_infs: List[Tuple[str, str, Dict, int]] = [] # [req_id, token_str, gen_metadata, finish_status] - + self.reqs_infs: List[Tuple[str, str, Dict, int]] = [] # [req_id, token_str, gen_metadata, finish_status] + + class AbortReq: def __init__(self, 
req_id): self.req_id = req_id - diff --git a/lightllm/server/router/dynamic_prompt/__init__.py b/lightllm/server/router/dynamic_prompt/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/lightllm/server/router/dynamic_prompt/radix_cache.py b/lightllm/server/router/dynamic_prompt/radix_cache.py new file mode 100644 index 000000000..2a978109f --- /dev/null +++ b/lightllm/server/router/dynamic_prompt/radix_cache.py @@ -0,0 +1,452 @@ +# Adapted from https://github.com/sgl-project/sglang/blob/main/python/sglang/srt/managers/router/radix_cache.py +import torch +import heapq +import time +import numpy as np +from collections import defaultdict +from dataclasses import dataclass +from typing import Tuple +from sortedcontainers import SortedSet +from .shared_arr import SharedArray, SharedTreeInfoNode, SharedLinkedListManager + + +class UniqueTimeIdGenerator: + def __init__(self): + self.counter = 0 + + def generate_time_id(self): + self.counter += 1 + return self.counter + + +time_gen = UniqueTimeIdGenerator() + + +class TreeNode: + def __init__(self, shared_idx_manager): + self.shared_idx_manager: SharedLinkedListManager = shared_idx_manager + self.children = {} # 这里的键 为 token_id_key 的第一个元素 + self.parent: TreeNode = None + self.token_id_key: torch.Tensor = None + self.token_mem_index_value: torch.Tensor = None # 用于记录存储的 token_index 为每个元素在 token mem 中的index位置 + self.ref_counter = 0 + self.shared_idx_node: SharedTreeInfoNode = self.shared_idx_manager.alloc() + self.time_id = time_gen.generate_time_id() # 用于标识时间周期 + + def get_compare_key(self): + return (0 if self.ref_counter == 0 else 1, len(self.children), self.time_id) + + def split_node(self, prefix_len): + split_parent_node = TreeNode(self.shared_idx_manager) + split_parent_node.parent = self.parent + split_parent_node.parent.children[self.token_id_key[0].item()] = split_parent_node + split_parent_node.token_id_key = self.token_id_key[0:prefix_len] + split_parent_node.token_mem_index_value = self.token_mem_index_value[0:prefix_len] + split_parent_node.children = {} + split_parent_node.children[self.token_id_key[prefix_len].item()] = self + split_parent_node.ref_counter = self.ref_counter + + split_parent_node.shared_idx_node.set_parent_idx(self.shared_idx_node.get_parent_idx()) + new_len = len(split_parent_node.token_mem_index_value) + split_parent_node.shared_idx_node.set_node_value_len(new_len) + split_parent_node.shared_idx_node.set_node_prefix_total_len( + split_parent_node.get_parent_prefix_total_len() + new_len + ) + + self.token_id_key = self.token_id_key[prefix_len:] + self.token_mem_index_value = self.token_mem_index_value[prefix_len:] + self.parent = split_parent_node + self.shared_idx_node.set_parent_idx(split_parent_node.shared_idx_node.get_idx()) + new_len = len(self.token_mem_index_value) + self.shared_idx_node.set_node_value_len(new_len) + self.shared_idx_node.set_node_prefix_total_len(self.get_parent_prefix_total_len() + new_len) + + return split_parent_node + + def add_and_return_new_child(self, token_id_key, token_mem_index_value): + child = TreeNode(self.shared_idx_manager) + child.token_id_key = token_id_key + child.token_mem_index_value = token_mem_index_value + first_token_key = child.token_id_key[0].item() + assert first_token_key not in self.children.keys() + self.children[first_token_key] = child + child.parent = self + + # 更新shared 信息 + child.shared_idx_node.set_parent_idx(self.shared_idx_node.get_idx()) + new_len = len(child.token_mem_index_value) + child.shared_idx_node.set_node_value_len(new_len) 
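+        # the child's prefix total length = its parent's prefix total length + the length of its own token segment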
+ child.shared_idx_node.set_node_prefix_total_len(child.get_parent_prefix_total_len() + new_len) + return child + + def remove_child(self, child_node): + del self.children[child_node.token_id_key[0].item()] + child_node.parent = None + return + + def update_time(self): + self.time_id = time_gen.generate_time_id() + + def is_leaf(self): + return len(self.children) == 0 + + def get_parent_prefix_total_len(self): + return self.parent.shared_idx_node.get_node_prefix_total_len() + + +def match(key, seq): + i = 0 + for k, w in zip(key, seq): + if k != w: + break + i += 1 + return i + + +class RadixCache: + """ + unique_name 主要用于解决单机,多实列部署时的shm冲突 + """ + + def __init__(self, unique_name, total_token_num, tp_id): + self._key_dtype = torch.int64 + self._value_dtype = torch.int64 + + self.shared_idx_manager = SharedLinkedListManager(unique_name, total_token_num, tp_id) + + self.root_node = TreeNode(self.shared_idx_manager) + self.root_node.token_id_key = torch.zeros((0,), device="cpu", dtype=self._key_dtype) + self.root_node.token_mem_index_value = torch.zeros((0,), device="cpu", dtype=self._value_dtype) + self.root_node.ref_counter = 1 # 初始化为 1 保证永远不会被 evict 掉 + + self.evict_tree_set = SortedSet(key=lambda x: x.get_compare_key()) # 自定义比较器 + self.evict_tree_set.add(self.root_node) + + self.refed_tokens_num = SharedArray(f"{unique_name}_refed_tokens_num_{tp_id}", (1,), dtype=np.int64) + self.refed_tokens_num.arr[0] = 0 + self.tree_total_tokens_num = SharedArray(f"{unique_name}_tree_total_tokens_num_{tp_id}", (1,), dtype=np.int64) + self.tree_total_tokens_num.arr[0] = 0 + + def insert(self, key, value=None): + if value is None: + value = key + + assert len(key) == len(value) and len(key) >= 1 + return self._insert_helper(self.root_node, key, value) + + def _insert_helper(self, node: TreeNode, key, value): + if node.is_leaf(): + self.evict_tree_set.discard(node) + + try: + first_key_id = key[0].item() + if first_key_id in node.children.keys(): + child: TreeNode = node.children[first_key_id] + prefix_len = match(key, child.token_id_key) + if prefix_len == len(key): + if child.is_leaf(): + self.evict_tree_set.discard(child) + child.update_time() + if child.is_leaf(): + self.evict_tree_set.add(child) + return prefix_len + + elif prefix_len < len(key) and prefix_len < len(child.token_id_key): + if child.is_leaf(): + self.evict_tree_set.discard(child) + + key = key[prefix_len:] + value = value[prefix_len:] + split_parent_node = child.split_node(prefix_len) + new_node = split_parent_node.add_and_return_new_child(key, value) + # update total token num + self.tree_total_tokens_num.arr[0] += len(new_node.token_mem_index_value) + + if split_parent_node.is_leaf(): + self.evict_tree_set.add(split_parent_node) + if new_node.is_leaf(): + self.evict_tree_set.add(new_node) + + if child.is_leaf(): + self.evict_tree_set.add(child) + return prefix_len + elif prefix_len < len(key) and prefix_len == len(child.token_id_key): + return prefix_len + self._insert_helper(child, key[prefix_len:], value[prefix_len:]) + else: + assert False, "can not run to here" + + else: + new_node = node.add_and_return_new_child(key, value) + # update total token num + self.tree_total_tokens_num.arr[0] += len(new_node.token_mem_index_value) + if new_node.is_leaf(): + self.evict_tree_set.add(new_node) + return 0 + finally: + node.update_time() + if node.is_leaf(): + self.evict_tree_set.add(node) + + def match_prefix(self, key, update_refs=False): + assert len(key) != 0 + ans_value_list = [] + tree_node = self._match_prefix_helper(self.root_node, 
key, ans_value_list, update_refs=update_refs) + if tree_node != self.root_node: + if len(ans_value_list) != 0: + value = torch.concat(ans_value_list) + else: + value = torch.zeros((0,), device="cpu", dtype=self._value_dtype) + return tree_node, len(value), value + else: + self.dec_node_ref_counter(self.root_node) + return None, 0, None + + def _match_prefix_helper(self, node: TreeNode, key, ans_value_list: list, update_refs=False) -> TreeNode: + if node.is_leaf(): + self.evict_tree_set.discard(node) + + if update_refs: + node.ref_counter += 1 + # from 0 to 1 need update refs token num + if node.ref_counter == 1: + self.refed_tokens_num.arr[0] += len(node.token_mem_index_value) + + try: + if len(key) == 0: + return node + + first_key_id = key[0].item() + if first_key_id not in node.children.keys(): + return node + else: + child = node.children[first_key_id] + prefix_len = match(key, child.token_id_key) + if prefix_len == len(child.token_id_key): + ans_value_list.append(child.token_mem_index_value) + return self._match_prefix_helper(child, key[prefix_len:], ans_value_list, update_refs=update_refs) + elif prefix_len < len(child.token_id_key): + if child.is_leaf(): + self.evict_tree_set.discard(child) + + split_parent_node = child.split_node(prefix_len) + ans_value_list.append(split_parent_node.token_mem_index_value) + + if update_refs: + split_parent_node.ref_counter += 1 + # from 0 to 1 need update refs token num + if split_parent_node.ref_counter == 1: + self.refed_tokens_num.arr[0] += len(split_parent_node.token_mem_index_value) + + if child.is_leaf(): + self.evict_tree_set.add(child) + if split_parent_node.is_leaf(): + self.evict_tree_set.add(split_parent_node) + + return split_parent_node + else: + assert False, "error state" + finally: + node.update_time() + if node.is_leaf(): + self.evict_tree_set.add(node) + + def evict(self, need_remove_tokens, evict_callback): + if self.tree_total_tokens_num.arr[0] - self.refed_tokens_num.arr[0] < need_remove_tokens: + assert False, f"""can not free tree tokens {need_remove_tokens}, + tree_total_tokens_num {self.tree_total_tokens_num.arr[0]}, + refed_tokens_num {self.refed_tokens_num.arr[0]}""" + num_evicted = 0 + while num_evicted < need_remove_tokens: + node: TreeNode = self.evict_tree_set.pop(0) + assert ( + node.ref_counter == 0 and len(node.children) == 0 and node != self.root_node + ), "error evict tree node state" + num_evicted += len(node.token_mem_index_value) + evict_callback(node.token_mem_index_value) + # update total token num + self.tree_total_tokens_num.arr[0] -= len(node.token_mem_index_value) + parent_node: TreeNode = node.parent + parent_node.remove_child(node) + if parent_node.is_leaf(): + self.evict_tree_set.add(parent_node) + + # 回收 shared 链表资源 + self.shared_idx_manager.free(node.shared_idx_node.get_idx()) + return + + def clear_tree_nodes(self): + """ + 该函数只在测试时调用 + """ + while True: + node: TreeNode = self.evict_tree_set.pop(0) + if node != self.root_node: + parent_node: TreeNode = node.parent + parent_node.remove_child(node) + if parent_node.is_leaf(): + self.evict_tree_set.add(parent_node) + + self.shared_idx_manager.free(node.shared_idx_node.get_idx()) + else: + break + + self.tree_total_tokens_num.arr[0] = 0 + self.refed_tokens_num.arr[0] = 0 + return + + def dec_node_ref_counter(self, node): + while node is not None: + if node.ref_counter == 1: + self.refed_tokens_num.arr[0] -= len(node.token_mem_index_value) + node.ref_counter -= 1 + node = node.parent + return + + def get_refed_tokens_num(self): + return 
self.refed_tokens_num.arr[0] + + def get_tree_total_tokens_num(self): + return self.tree_total_tokens_num.arr[0] + + def print_self(self, indent=0): + self._print_helper(self.root_node, indent) + + def _print_helper(self, node: TreeNode, indent): + print( + " " * indent, + f"shared_idx: {node.shared_idx_node.get_idx()} p_idx: {node.shared_idx_node.get_parent_idx()} \ + k: {node.token_id_key[0:10]} v: {node.token_mem_index_value[0:10]} refs: {node.ref_counter} \ + time_id: {node.time_id} prefix_total_len: {node.shared_idx_node.get_node_prefix_total_len()} \ + node_value_len: {node.shared_idx_node.get_node_value_len()}", + ) + for _, child in node.children.items(): + self._print_helper(child, indent=indent + 2) + return + + +class RadixCacheReadOnlyClient: + """ + router 端只读用的客户端,用于从共享内存中读取树结构中的信息,用于进行prompt cache 的调度估计。 + """ + + def __init__(self, unique_name, total_token_num, tp_id): + self.shared_idx_manager = SharedLinkedListManager(unique_name, total_token_num, tp_id) + self.refed_tokens_num = SharedArray(f"{unique_name}_refed_tokens_num_{tp_id}", (1,), dtype=np.int64) + self.tree_total_tokens_num = SharedArray(f"{unique_name}_tree_total_tokens_num_{tp_id}", (1,), dtype=np.int64) + + def get_refed_tokens_num(self): + return self.refed_tokens_num.arr[0] + + def get_tree_total_tokens_num(self): + return self.tree_total_tokens_num.arr[0] + + def get_shared_node(self, idx): + return self.shared_idx_manager.get_shared_node(idx) + + def get_all_parent_shared_nodes(self, idx): + node = self.shared_idx_manager.get_shared_node(idx) + ans_list = [node] + while node.get_parent_idx() != -1: + node = self.shared_idx_manager.get_shared_node(node.get_parent_idx()) + ans_list.append(node) + return ans_list + + +# /////////////////////////////////////////////////////////////////////////////// + +if __name__ == "__main__": + # test 1 + def test1(): + tree = RadixCache("unique_name", 100, 0) + ans = tree.insert(torch.tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], dtype=torch.int64, device="cpu")) + assert ans == 0 + tree.print_self() + ans = tree.insert(torch.tensor([0, 1, 2, 3, 4, 7, 8, 9], dtype=torch.int64, device="cpu")) + assert ans == 5 + tree.print_self() + ans = tree.insert(torch.tensor([0, 1, 2, 3, 4, 7, 8, 9], dtype=torch.int64, device="cpu")) + assert ans == 8 + tree.print_self() + + assert tree.get_refed_tokens_num() == 0 + assert tree.get_tree_total_tokens_num() == 13 + + # print("evict") + tree.evict(9, lambda x: x) + tree.print_self() + assert tree.get_refed_tokens_num() == 0 and tree.get_tree_total_tokens_num() == 0 + + test1() + + # test 2 + def test2(): + tree = RadixCache("unique_name", 100, 1) + ans = tree.insert(torch.tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], dtype=torch.int64, device="cpu")) + ans = tree.insert(torch.tensor([0, 1, 2, 3, 4, 7, 8, 9], dtype=torch.int64, device="cpu")) + tree.print_self() + + tree_node, size, values = tree.match_prefix( + torch.tensor([0, 1, 2, 3, 4], dtype=torch.int64, device="cpu"), update_refs=False + ) + assert tree_node.shared_idx_node.get_node_prefix_total_len() == 5 and size == 5 and len(values) == 5 + tree_node, size, values = tree.match_prefix( + torch.tensor([0, 1, 2, 3, 4, 9], dtype=torch.int64, device="cpu"), update_refs=False + ) + assert tree_node.shared_idx_node.get_node_prefix_total_len() == 5 and size == 5 and len(values) == 5 + tree_node, size, values = tree.match_prefix( + torch.tensor([0, 1, 2, 3, 4, 7, 8], dtype=torch.int64, device="cpu"), update_refs=False + ) + assert tree_node.shared_idx_node.get_node_prefix_total_len() == 7 and size == 7 
and len(values) == 7 + tree_node, size, values = tree.match_prefix( + torch.tensor([0, 1, 2, 3, 4, 7, 9], dtype=torch.int64, device="cpu"), update_refs=False + ) + assert tree_node.shared_idx_node.get_node_prefix_total_len() == 6 and size == 6 and len(values) == 6 + print(ans) + return + + # test2() + + # test 3 + def test3(): + tree = RadixCache("unique_name", 100, 2) + ans = tree.insert(torch.tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], dtype=torch.int64, device="cpu")) + ans = tree.insert(torch.tensor([0, 1, 2, 3, 4, 7, 8, 9], dtype=torch.int64, device="cpu")) + tree.print_self() + + tree_node, size, values = tree.match_prefix( + torch.tensor([0, 1, 2, 3, 4], dtype=torch.int64, device="cpu"), update_refs=True + ) + assert tree_node.shared_idx_node.get_node_prefix_total_len() == 5 and size == 5 and len(values) == 5 + assert tree.get_refed_tokens_num() == 5 and tree.get_tree_total_tokens_num() == 13 + + tree_node, size, values = tree.match_prefix( + torch.tensor([0, 1, 2, 3, 4, 7, 9], dtype=torch.int64, device="cpu"), update_refs=True + ) + assert tree_node.shared_idx_node.get_node_prefix_total_len() == 6 and size == 6 and len(values) == 6 + assert tree.get_refed_tokens_num() == 6 and tree.get_tree_total_tokens_num() == 13 + + tree.print_self() + tree.evict(2, lambda x: x) + assert tree.get_refed_tokens_num() == 6 and tree.get_tree_total_tokens_num() == 8 + tree.print_self() + + tree.dec_node_ref_counter(tree_node) + tree.print_self() + print(ans) + return + + test3() + + def test4(): + + tree = RadixCache("unique_name", 100, 2) + ans = tree.insert(torch.tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], dtype=torch.int64, device="cpu")) + ans = tree.insert(torch.tensor([0, 1, 2, 3, 4, 7, 8, 9], dtype=torch.int64, device="cpu")) + tree.print_self() + + tree.clear_tree_nodes() + assert tree.shared_idx_manager.can_alloc_num() == 100 + print(ans) + return + + test4() diff --git a/lightllm/server/router/dynamic_prompt/shared_arr.py b/lightllm/server/router/dynamic_prompt/shared_arr.py new file mode 100644 index 000000000..37710464b --- /dev/null +++ b/lightllm/server/router/dynamic_prompt/shared_arr.py @@ -0,0 +1,154 @@ +# import faulthandler +# faulthandler.enable() + +import numpy as np +import multiprocessing as mp +from multiprocessing import shared_memory + + +class SharedArray: + def __init__(self, name, shape, dtype): + dtype_byte_num = np.array([1], dtype=dtype).dtype.itemsize + try: + shm = shared_memory.SharedMemory(name=name, create=True, size=np.prod(shape) * dtype_byte_num) + print(f"create shm {name}") + except: + shm = shared_memory.SharedMemory(name=name, create=False, size=np.prod(shape) * dtype_byte_num) + print(f"link shm {name}") + self.shm = shm # SharedMemory 对象一定要被持有,否则会被释放 + self.arr = np.ndarray(shape, dtype=dtype, buffer=self.shm.buf) + + +class SharedTreeInfoNode: + def __init__(self, manager, idx) -> None: + self.manager = manager + self.idx = idx + + def get_idx(self): + return self.idx + + def get_parent_idx(self): + return self.manager._values[self.idx, 3] + + def set_parent_idx(self, p_idx): + self.manager._values[self.idx, 3] = p_idx + + def get_parent_idx_shared_node(self): + return SharedTreeInfoNode(self.manager, self.manager._values[self.idx, 3]) + + def get_node_value_len(self): + return self.manager._values[self.idx, 4] + + def set_node_value_len(self, value_len): + self.manager._values[self.idx, 4] = value_len + + def get_node_prefix_total_len(self): + return self.manager._values[self.idx, 5] + + def set_node_prefix_total_len(self, prefix_total_len): + 
self.manager._values[self.idx, 5] = prefix_total_len + + +class SharedLinkedListManager: + VALUE_INDEX = 0 + PRE_INDEX = 1 + NEXT_INDEX = 2 + + def __init__(self, unique_name, total_token_num, tp_id) -> None: + self.size = total_token_num + 2 # 因为 0 号节点不分配,所以为了满足充分可用性需要 + 2. + # 第二维对应信息 0 idx 1 pre index 2 next index 用于链表管理 3 tree_node parent node idx + # 4 tree_node value len 5 tree node prefix total len + self._shm_array = SharedArray(f"{unique_name} SharedLinkedList_{tp_id}", shape=(self.size, 6), dtype=np.int64) + self._values = self._shm_array.arr + # idx + self._values[:, self.VALUE_INDEX] = np.arange(0, self.size, 1) + # pre + self._values[0, self.PRE_INDEX] = -1 + self._values[1:, self.PRE_INDEX] = np.arange(0, self.size - 1, 1) + # next + self._values[0 : self.size - 1, self.NEXT_INDEX] = np.arange(1, self.size, 1) + self._values[self.size - 1, self.NEXT_INDEX] = -1 + + # tree node value + self._values[:, 3] = -1 + self._values[:, 4] = 0 + self._values[:, 5] = 0 + + def alloc(self): + if self._values[0, self.NEXT_INDEX] != -1: + alloc_idx = self._values[0, self.NEXT_INDEX] + if self._values[alloc_idx, self.NEXT_INDEX] == -1: + self._values[0, self.NEXT_INDEX] = -1 + ans = SharedTreeInfoNode(self, alloc_idx) + + nn_idx = self._values[alloc_idx, self.NEXT_INDEX] + self._values[0, self.NEXT_INDEX] = nn_idx + self._values[nn_idx, self.PRE_INDEX] = 0 + ans = SharedTreeInfoNode(self, alloc_idx) + # 初始化值 + ans.set_parent_idx(-1) + ans.set_node_value_len(0) + ans.set_node_prefix_total_len(0) + return ans + + assert False, "error cannot alloc" + + def free(self, idx): + nn_idx = self._values[0, self.NEXT_INDEX] + self._values[0, self.NEXT_INDEX] = idx + self._values[idx, self.PRE_INDEX] = 0 + self._values[idx, self.NEXT_INDEX] = nn_idx + if nn_idx != -1: + self._values[nn_idx, self.PRE_INDEX] = idx + return + + def can_alloc_num(self): + num = 0 + cur_loc = 0 + while self._values[cur_loc, self.NEXT_INDEX] != -1: + num += 1 + cur_loc = self._values[cur_loc, self.NEXT_INDEX] + return num + + def get_shared_node(self, idx): + return SharedTreeInfoNode(self, idx) + + +if __name__ == "__main__": + # test SharedArray + a = SharedArray("sb_abc", (1,), dtype=np.int32) + a.arr[0] = 10 + assert a.arr[0] == 10 + a.arr[0] += 10 + assert a.arr[0] == 20 + + # test SharedTreeIdxManager + mananger = SharedLinkedListManager("unique_name", 100, 0) + node1 = mananger.alloc() + node1.set_parent_idx(10) + assert node1.get_parent_idx() == 10 + node1.set_node_value_len(10) + assert node1.get_node_value_len() == 10 + node1.set_node_prefix_total_len(100) + assert node1.get_node_prefix_total_len() == 100 + mananger.free(node1.get_idx()) + alloc_nodes = [] + for _ in range(101): + node1 = alloc_nodes.append(mananger.alloc()) + + try: + node_tmp = mananger.alloc() + except: + assert True + + for e in alloc_nodes: + mananger.free(e.get_idx()) + + alloc_nodes = [] + for _ in range(101): + node1 = alloc_nodes.append(mananger.alloc()) + + try: + node_tmp = mananger.alloc() + except: + assert True diff --git a/lightllm/server/router/manager.py b/lightllm/server/router/manager.py index a8b89633e..9f8af1af2 100644 --- a/lightllm/server/router/manager.py +++ b/lightllm/server/router/manager.py @@ -2,6 +2,7 @@ import uuid import uvloop import asyncio + asyncio.set_event_loop_policy(uvloop.EventLoopPolicy()) import zmq import zmq.asyncio @@ -23,7 +24,6 @@ class RouterManager: - def __init__(self, args, router_port, detokenization_port, model_rpc_ports): self.args = args self.model_weightdir = args.model_dir @@ -31,17 +31,17 @@ 
def __init__(self, args, router_port, detokenization_port, model_rpc_ports): self.load_way = args.load_way self.mode = args.mode self.max_total_token_num = args.max_total_token_num - + self.pause_strategy = Fcfs() self.running_batch: Batch = None self.eos_id = args.eos_id self.has_wait_tokens = 0 self.max_wait_tokens = 10 - + context = zmq.asyncio.Context(2) self.recv_from_httpserver = context.socket(zmq.PULL) self.recv_from_httpserver.bind(f"tcp://127.0.0.1:{router_port}") - + self.send_to_detokenization = context.socket(zmq.PUSH) self.send_to_detokenization.connect(f"tcp://127.0.0.1:{detokenization_port}") self.model_rpc_ports = model_rpc_ports @@ -49,9 +49,6 @@ def __init__(self, args, router_port, detokenization_port, model_rpc_ports): self.is_splitfuse_mode = args.splitfuse_mode self.splitfuse_block_size = args.splitfuse_block_size - if self.is_splitfuse_mode and len(args.prompt_cache_strs) != 0: - self.tokenizer = get_tokenizer(self.model_weightdir, args.tokenizer_mode, args.trust_remote_code) - self.stats_tool = Stats(not args.disable_log_stats, args.log_stats_interval) return @@ -65,49 +62,25 @@ async def wait_to_model_ready(self): init_model_ret = [] for rank_id in range(self.world_size): # async init model process kvargs = { - "rank_id" : rank_id, - "world_size" : self.world_size, - "weight_dir" : self.model_weightdir, - "load_way" : self.load_way, - "max_total_token_num" : self.max_total_token_num, - "mode" : self.mode, - "max_req_num" : self.args.running_max_req_size + 8, - "max_seq_length" : self.args.max_req_total_len + 8, # 留一点余量 - "nccl_port" : self.args.nccl_port, - "is_splitfuse_mode" : self.is_splitfuse_mode, - "splitfuse_block_size" : self.splitfuse_block_size, - "return_all_prompt_logprobs" : self.args.return_all_prompt_logprobs + "rank_id": rank_id, + "world_size": self.world_size, + "weight_dir": self.model_weightdir, + "load_way": self.load_way, + "max_total_token_num": self.max_total_token_num, + "mode": self.mode, + "max_req_num": self.args.running_max_req_size + 8, + "max_seq_length": self.args.max_req_total_len + 8, # 留一点余量 + "nccl_port": self.args.nccl_port, + "is_splitfuse_mode": self.is_splitfuse_mode, + "splitfuse_block_size": self.splitfuse_block_size, + "return_all_prompt_logprobs": self.args.return_all_prompt_logprobs, + "use_dynamic_prompt_cache": self.args.use_dynamic_prompt_cache, } init_model_ret.append(self.model_rpcs[rank_id].init_model(kvargs)) await asyncio.gather(*init_model_ret) - await self._init_prompt_cache() - - self.req_queue = ReqQueue(self.args, - self.prompt_cache_used_tokens, - self.prompt_cache_req_num) - return - - async def _init_prompt_cache(self): - """ - 初始化 prompt cache 特性, 这个地方的id 分配要于 httpserver 中的id 分配对齐 - """ - # 初始化 prompt cahce, 然后初始化请求队列 - self.prompt_cache_used_tokens = 0 - self.prompt_cache_req_num = len(self.args.prompt_cache_strs) - if self.is_splitfuse_mode: - reqs = [] - id = -1 # id 从 -1, -2, .... 
避免和正常的 id 占用 - for prompt_cache_str in self.args.prompt_cache_strs: - prompt_ids = self.tokenizer.encode(prompt_cache_str) - req = NormalReq(id, prompt_ids, SamplingParams(stop_sequences=[])) - self.prompt_cache_used_tokens += len(prompt_ids) - reqs.append(req) - id -= 1 - if len(reqs) != 0: - self.prompt_cache_batch = Batch(uuid.uuid4().hex, reqs) - await self._prefill_to_init_prompt_cache(self.prompt_cache_batch) + self.req_queue = ReqQueue(self.args) return def add_req( @@ -116,15 +89,11 @@ def add_req( sampling_params: SamplingParams, multimodal_params: MultimodalParams, request_id: str, - prompt_cache_len, - prompt_cache_req_id - ): + ): if self.is_splitfuse_mode: - req = SplitFuseReq(request_id, prompt_ids, sampling_params, multimodal_params, - prompt_cache_len, prompt_cache_req_id, self.splitfuse_block_size) + req = SplitFuseReq(request_id, prompt_ids, sampling_params, multimodal_params, self.splitfuse_block_size) else: - req = NormalReq(request_id, prompt_ids, sampling_params, multimodal_params, - prompt_cache_len, prompt_cache_req_id) + req = NormalReq(request_id, prompt_ids, sampling_params, multimodal_params) self.req_queue.append(req) self.send_to_detokenization.send_pyobj(req.to_req_detokenization_state()) return @@ -139,23 +108,25 @@ async def abort(self, request_id): req.finish_status = FinishStatus.FINISHED_ABORT return - async def loop_for_fwd(self,): + async def loop_for_fwd( + self, + ): counter_count = 0 while True: await self._step() counter_count += 1 if self.running_batch is not None: if counter_count % 50 == 0: - total_used_tokens = self.prompt_cache_used_tokens + self.running_batch.batch_used_tokens + self.req_queue.pause_req_used_tokens + total_used_tokens = self.running_batch.batch_used_tokens + self.req_queue.pause_req_used_tokens token_ratio = total_used_tokens / self.max_total_token_num logger.debug( - f"current batch size: {len(self.running_batch.reqs)} " + f"current batch size: {len(self.running_batch.reqs)} " f"paused req num: {len(self.req_queue.pause_req_dict)} " f"token used ratio: {token_ratio} " ) pass self.stats_tool.print_stats() - + if self.running_batch is None: await asyncio.sleep(0.01) # 10ms @@ -197,7 +168,9 @@ async def _step(self): return else: # pause strategy - paused_reqs = select_paused_reqs(self.running_batch, self.pause_strategy, self.req_queue, self.max_total_token_num) + paused_reqs = select_paused_reqs( + self.running_batch, self.pause_strategy, self.req_queue, self.max_total_token_num + ) await self._pause_reqs(self.running_batch, paused_reqs) logger.debug(f"pasued req num: {len(self.req_queue.pause_req_dict)}") self.has_wait_tokens = 0 @@ -212,11 +185,11 @@ async def _init_batch(self, batch: Batch): req_to_req_status = obtain(ans[0]) else: req_to_req_status = ans[0] - + self._update_init_status_to_batch(batch, req_to_req_status) return - async def _prefill_batch(self, batch:Batch): + async def _prefill_batch(self, batch: Batch): await self._init_batch(batch) if not self.is_splitfuse_mode: # 在 非 splitfuse 模式下,才需要真的执行 prefill 的操作。 @@ -233,24 +206,8 @@ async def _prefill_batch(self, batch:Batch): batch.filter_out_finished_req(unfinished_req_ids, finished_req_ids) await self._handle_finish_req(batch, unfinished_req_ids, finished_req_ids) return - - async def _prefill_to_init_prompt_cache(self, batch:Batch): - """ - 专用于初始化prompt cahce 请求的接口, 只在 splitfuse + prompt cache 模式下调用 - """ - await self._init_batch(batch) - # 在 splitfuse 模式下,才需要真的执行 prefill 的操作。 - rets = [self.model_rpcs[tp_rank].prefill_batch(batch.batch_id) for tp_rank in 
range(self.world_size)] - ans = await asyncio.gather(*rets) - if self.world_size != 1: - req_to_out_status = obtain(ans[0]) - else: - req_to_out_status = ans[0] - - self._update_out_status_to_batch(batch, req_to_out_status) - return - async def _decode_batch(self, batch:Batch): + async def _decode_batch(self, batch: Batch): rets = [self.model_rpcs[tp_rank].decode_batch(batch.batch_id) for tp_rank in range(self.world_size)] ans = await asyncio.gather(*rets) if self.world_size != 1: @@ -266,12 +223,17 @@ async def _decode_batch(self, batch:Batch): return async def _filter_batch(self, batch: Batch, unfinished_req_ids, finished_req_ids: List): - rets = [self.model_rpcs[tp_rank].filter_batch(batch.batch_id, unfinished_req_ids, finished_req_ids) for tp_rank in range(self.world_size)] + rets = [ + self.model_rpcs[tp_rank].filter_batch(batch.batch_id, unfinished_req_ids, finished_req_ids) + for tp_rank in range(self.world_size) + ] await asyncio.gather(*rets) return async def _merge_batch(self, batch1, batch2): - rets = [self.model_rpcs[tp_rank].merge_batch(batch1.batch_id, batch2.batch_id) for tp_rank in range(self.world_size)] + rets = [ + self.model_rpcs[tp_rank].merge_batch(batch1.batch_id, batch2.batch_id) for tp_rank in range(self.world_size) + ] await asyncio.gather(*rets) return @@ -279,10 +241,12 @@ async def _remove_batch(self, batch): rets = [self.model_rpcs[tp_rank].remove_batch(batch.batch_id) for tp_rank in range(self.world_size)] await asyncio.gather(*rets) return - + async def _pause_reqs(self, batch: Batch, pasue_reqs): pasue_reqs_info = [(r.request_id, r.req_status) for r in pasue_reqs] - rets = [self.model_rpcs[tp_rank].pause_reqs(batch.batch_id, pasue_reqs_info) for tp_rank in range(self.world_size)] + rets = [ + self.model_rpcs[tp_rank].pause_reqs(batch.batch_id, pasue_reqs_info) for tp_rank in range(self.world_size) + ] await asyncio.gather(*rets) return @@ -298,27 +262,27 @@ def _filter_runing_batch(self): if self.running_batch is not None and self.running_batch.is_clear(): self.running_batch = None return - + def _update_init_status_to_batch(self, batch: Batch, req_to_req_status): # 更新请求状态 new_batch_used_tokens = 0 - new_batch_decode_need_tokens = 0 # 只有在 splitfuse 模式下有意义 + new_batch_decode_need_tokens = 0 # 只有在 splitfuse 模式下有意义 for req_id, (req_status, cur_kv_len) in req_to_req_status.items(): r_obj = batch.id_to_reqs[req_id] r_obj.req_status = req_status r_obj.cur_kv_len = cur_kv_len new_batch_used_tokens += r_obj.get_used_tokens() new_batch_decode_need_tokens += r_obj.get_decode_need_tokens() - + batch.batch_used_tokens = new_batch_used_tokens batch.batch_decode_need_tokens = new_batch_decode_need_tokens return - + def _update_out_status_to_batch(self, batch: Batch, req_to_out_status): new_batch_used_tokens = 0 - new_batch_decode_need_tokens = 0 # 只有在 splitfuse 模式下有意义 + new_batch_decode_need_tokens = 0 # 只有在 splitfuse 模式下有意义 for req_id, (req_status, cur_kv_len, new_token_id, new_gen_metadata) in req_to_out_status.items(): - req : Req = batch.id_to_reqs[req_id] + req: Req = batch.id_to_reqs[req_id] req.req_status = req_status req.cur_kv_len = cur_kv_len if new_token_id is not None: @@ -326,16 +290,16 @@ def _update_out_status_to_batch(self, batch: Batch, req_to_out_status): req.output_metadata_list.append(new_gen_metadata) new_batch_used_tokens += req.get_used_tokens() new_batch_decode_need_tokens += req.get_decode_need_tokens() - + batch.batch_used_tokens = new_batch_used_tokens batch.batch_decode_need_tokens = new_batch_decode_need_tokens return - + def 
_can_decode(self, batch: Batch): - total_used_tokens = self.prompt_cache_used_tokens + batch.batch_used_tokens + self.req_queue.pause_req_used_tokens + total_used_tokens = batch.batch_used_tokens + self.req_queue.pause_req_used_tokens remaining_tokens = self.max_total_token_num - total_used_tokens return batch.batch_decode_need_tokens <= remaining_tokens - + def _send_to_detokenization_proc(self, batch: Batch, req_ans): batch_out = BatchTokenIdOut() for req_id, (_, _, new_token_id, new_gen_metadata) in req_ans.items(): @@ -343,16 +307,16 @@ def _send_to_detokenization_proc(self, batch: Batch, req_ans): if new_token_id is not None: # req.finish_status 传输 value值 不传送对象,可以减少序列化对象的大小。 batch_out.reqs_infs.append((req_id, new_token_id, new_gen_metadata, req.finish_status.value)) - + self.send_to_detokenization.send_pyobj(batch_out) return async def loop_for_netio_req(self): while True: recv_req = await self.recv_from_httpserver.recv_pyobj() - if isinstance(recv_req, tuple) and len(recv_req) == 6: - prompt_ids, sampling_params, multimodal_params, request_id, prompt_cache_len, prompt_cache_req_id = recv_req - self.add_req(prompt_ids, sampling_params, multimodal_params, request_id, prompt_cache_len, prompt_cache_req_id) + if isinstance(recv_req, tuple) and len(recv_req) == 4: + prompt_ids, sampling_params, multimodal_params, request_id = recv_req + self.add_req(prompt_ids, sampling_params, multimodal_params, request_id) elif isinstance(recv_req, AbortReq): abort_req = recv_req request_id = abort_req.req_id @@ -368,26 +332,26 @@ def clean_up(self): model_rpc.rpc_server_process.join() return + def start_router_process(args, router_port, detokenization_port, model_rpc_ports, pipe_writer): try: router = RouterManager( - args, - router_port=router_port, - detokenization_port=detokenization_port, - model_rpc_ports=model_rpc_ports) - + args, router_port=router_port, detokenization_port=detokenization_port, model_rpc_ports=model_rpc_ports + ) + asyncio.run(router.wait_to_model_ready()) - except Exception as e: + except: import traceback import sys + etype, evalue, tb = sys.exc_info() - err_str = '\n'.join(traceback.format_exception(etype, evalue, tb)) + err_str = "\n".join(traceback.format_exception(etype, evalue, tb)) pipe_writer.send(err_str) router.clean_up() raise - pipe_writer.send('init ok') - + pipe_writer.send("init ok") + loop = asyncio.new_event_loop() asyncio.set_event_loop(loop) loop.create_task(router.loop_for_fwd()) diff --git a/lightllm/server/router/model_infer/infer_batch.py b/lightllm/server/router/model_infer/infer_batch.py index 8daa524bf..cf9316f9e 100644 --- a/lightllm/server/router/model_infer/infer_batch.py +++ b/lightllm/server/router/model_infer/infer_batch.py @@ -9,12 +9,16 @@ from lightllm.common.mem_manager import MemoryManager from lightllm.utils.infer_utils import mark_start, mark_end from lightllm.server.io_struct import ReqRunStatus +from lightllm.server.router.dynamic_prompt.radix_cache import RadixCache +from lightllm.utils.log_utils import init_logger + +logger = init_logger(__name__) requests_mapping = {} -class InferSamplingParams: +class InferSamplingParams: def __init__( self, do_sample: bool = False, @@ -39,7 +43,6 @@ def __init__( class InferReq: - def __init__( self, r_id, @@ -49,8 +52,6 @@ def __init__( req_idx=-1, prompt_len=0, req_status=None, - prompt_cache_len=None, - prompt_cache_req_id=None, multimodal_params=None, ) -> None: self.r_id = r_id @@ -61,9 +62,10 @@ def __init__( self.prompt_len = prompt_len self.input_token_ids = input_token_ids self.req_status 
= req_status - self.cur_kv_len = 0 # 当前已经占用掉 token 现存的 kv len 长度 - self.prompt_cache_len = prompt_cache_len # 可以复用的一些公共 prompt 头对应的 kv cache 长度, prompt cache 目前只会在 splitfuse 模式下使用 - self.prompt_cache_req_id = prompt_cache_req_id # 对应的可复用的请求的 id,方便初始化的时候,将其 kv cache 复制到当前请求中 + self.cur_kv_len = 0 # 当前已经占用掉 token 现存的 kv len 长度 + + self.shared_kv_node = None + self.ready_cache_len = 0 return @@ -72,63 +74,75 @@ class InferBatch: batch_id: int request_ids: List req_manager: ReqManager - + radix_cache: RadixCache + @classmethod @torch.no_grad() - def init_batch(cls, batch_id, requests, dtype: torch.dtype, device: torch.device, req_manager:ReqManager, vocab_size: int): + def init_batch( + cls, + batch_id, + requests, + dtype: torch.dtype, + device: torch.device, + req_manager: ReqManager, + vocab_size: int, + radix_cache: RadixCache = None, + ): request_ids = [] - need_alloc_size = len([r for r in requests if r['request_id'] not in requests_mapping]) + need_alloc_size = len([r for r in requests if r["request_id"] not in requests_mapping]) nopad_b_req_idx = req_manager.alloc(need_alloc_size) nopad_b_req_idx = nopad_b_req_idx.cpu().numpy() - + index = 0 for r in requests: # request id -> idx in list mapping - r_id = r['request_id'] + r_id = r["request_id"] if r_id not in requests_mapping.keys(): - tokenized_input = r['input_id'] + tokenized_input = r["input_id"] input_length = len(tokenized_input) # postprocessor sampling_param = r["sampling_param"] multimodal_params = r["multimodal_params"] sampling_param["vocab_size"] = vocab_size - assert r['req_status'] == ReqRunStatus.WAIT_IN_QUEUE - r_obj = InferReq(r_id, - input_token_ids=tokenized_input, - out_token_id_count=collections.defaultdict(int), - sampling_param=InferSamplingParams(**sampling_param), - multimodal_params=multimodal_params, - req_idx=nopad_b_req_idx[index], - prompt_len=input_length, - req_status=r['req_status'], - prompt_cache_len=r['prompt_cache_len'], - prompt_cache_req_id=r['prompt_cache_req_id']) + assert r["req_status"] == ReqRunStatus.WAIT_IN_QUEUE + r_obj = InferReq( + r_id, + input_token_ids=tokenized_input, + out_token_id_count=collections.defaultdict(int), + sampling_param=InferSamplingParams(**sampling_param), + multimodal_params=multimodal_params, + req_idx=nopad_b_req_idx[index], + prompt_len=input_length, + req_status=r["req_status"], + ) requests_mapping[r_id] = r_obj index += 1 else: if requests_mapping[r_id].req_status == ReqRunStatus.PAUSED_AND_OFFLOAD: - r_obj : InferReq = requests_mapping[r_id] + r_obj: InferReq = requests_mapping[r_id] r_obj.req_status = ReqRunStatus.RERUNNING_FROM_OFFLOAD - elif requests_mapping[r_id].req_status == ReqRunStatus.PAUSED_AND_KVKEEP: - r_obj : InferReq = requests_mapping[r_id] - r_obj.req_status = ReqRunStatus.RERUNNING_FROM_KVKEEP else: assert False, f"should not exist {requests_mapping[r_id].req_status}" - + request_ids.append(r_id) # 如果是具有 prompt_cache 的使用特性则需要进行提前的填充和恢复操作。 if r_obj.req_status in [ReqRunStatus.RERUNNING_FROM_OFFLOAD, ReqRunStatus.WAIT_IN_QUEUE]: - if r_obj.prompt_cache_len != 0: # 有利用prompt_cache_len - prompt_cache_req_obj : InferReq = requests_mapping[r_obj.prompt_cache_req_id] - prompt_kv_tokens = req_manager.req_to_token_indexs[prompt_cache_req_obj.req_idx, 0:r_obj.prompt_cache_len] - mem_manager : MemoryManager = req_manager.mem_manager - mem_manager.add_refs(prompt_kv_tokens.long()) # 加 refs - req_manager.req_to_token_indexs[r_obj.req_idx, 0:r_obj.prompt_cache_len] = prompt_kv_tokens - r_obj.cur_kv_len = r_obj.prompt_cache_len - + if radix_cache is not 
None: + key = torch.tensor(r_obj.input_token_ids, dtype=torch.int64, device="cpu") + key = key[0 : len(key) - 1] # 最后一个不需要,因为需要一个额外的token,让其在prefill的时候输出下一个token的值 + share_node, kv_len, value_tensor = radix_cache.match_prefix(key, update_refs=True) + if share_node is not None: + r_obj.shared_kv_node = share_node + r_obj.ready_cache_len = share_node.shared_idx_node.get_node_prefix_total_len() + mem_manager: MemoryManager = req_manager.mem_manager + value_tensor = value_tensor.long().cuda() + mem_manager.add_refs(value_tensor) # 加 refs + req_manager.req_to_token_indexs[r_obj.req_idx, 0 : r_obj.ready_cache_len] = value_tensor + r_obj.cur_kv_len = r_obj.ready_cache_len + # 初始化之后 所有请求状态置换为 RUNNING 状态 r_obj.req_status = ReqRunStatus.RUNNING @@ -136,23 +150,48 @@ def init_batch(cls, batch_id, requests, dtype: torch.dtype, device: torch.device batch_id=batch_id, request_ids=request_ids, req_manager=req_manager, + radix_cache=radix_cache, ) - + + def _free_a_req_mem(self, free_token_index: List, req: InferReq): + if self.radix_cache is None: + free_token_index.append(self.req_manager.req_to_token_indexs[req.req_idx][: req.cur_kv_len]) + else: + key = torch.tensor(req.input_token_ids[0 : req.cur_kv_len], dtype=torch.int64, device="cpu") + value = self.req_manager.req_to_token_indexs[req.req_idx][: req.cur_kv_len].detach().cpu() + prefix_len = self.radix_cache.insert(key, value) + free_token_index.append(self.req_manager.req_to_token_indexs[req.req_idx][:prefix_len]) + if req.shared_kv_node is not None: + assert req.shared_kv_node.shared_idx_node.get_node_prefix_total_len() <= prefix_len + self.radix_cache.dec_node_ref_counter(req.shared_kv_node) + req.shared_kv_node = None + req.ready_cache_len = 0 + @torch.no_grad() def free_self(self): free_req_index = [] free_token_index = [] for request_id in self.request_ids: - req : InferReq = requests_mapping.pop(request_id) + req: InferReq = requests_mapping.pop(request_id) free_req_index.append(req.req_idx) - free_token_index.append(self.req_manager.req_to_token_indexs[req.req_idx][:req.cur_kv_len]) - + self._free_a_req_mem(free_token_index, req) + req.cur_kv_len = 0 + free_token_index = torch.cat(free_token_index, dim=-1) self.req_manager.free(free_req_index, free_token_index) if len(requests_mapping) == 0: requests_mapping.clear() + + if self.radix_cache is not None: + logger.info( + f"""free a batch state: + radix refed token num {self.radix_cache.get_refed_tokens_num()} + radix hold token num {self.radix_cache.get_tree_total_tokens_num()} + mem manager can alloc token num {self.req_manager.mem_manager.can_use_mem_size} + mem manager total size {self.req_manager.mem_manager.size}""" + ) return - + @torch.no_grad() def filter(self, request_ids: List[str], finished_request_ids: List[str]): if len(requests_mapping) == 0: @@ -161,49 +200,48 @@ def filter(self, request_ids: List[str], finished_request_ids: List[str]): return self if len(request_ids) == 0: self.free_self() - return InferBatch( - batch_id=self.batch_id, - request_ids=[], - req_manager=self.req_manager - ) + return InferBatch(batch_id=self.batch_id, request_ids=[], req_manager=self.req_manager, radix_cache=self.radix_cache) free_req_index = [] free_token_index = [] for request_id in finished_request_ids: - req : InferReq = requests_mapping.pop(request_id) + req: InferReq = requests_mapping.pop(request_id) free_req_index.append(req.req_idx) - free_token_index.append(self.req_manager.req_to_token_indexs[req.req_idx][:req.cur_kv_len]) + self._free_a_req_mem(free_token_index, req) + req.cur_kv_len = 
0 + free_token_index = torch.cat(free_token_index, dim=-1) self.req_manager.free(free_req_index, free_token_index) - + return InferBatch( - batch_id=self.batch_id, - request_ids=request_ids, - req_manager=self.req_manager, + batch_id=self.batch_id, request_ids=request_ids, req_manager=self.req_manager, radix_cache=self.radix_cache ) @torch.no_grad() def pause_reqs(self, pause_reqs: List[str]): for request_id, pause_way in pause_reqs: - req : InferReq = requests_mapping[request_id] + req: InferReq = requests_mapping[request_id] req.req_status = pause_way self.request_ids.remove(request_id) if pause_way == ReqRunStatus.PAUSED_AND_OFFLOAD: # 现在只支持全卸载一个请求的所有 kv 了 - self.req_manager.free_token(self.req_manager.req_to_token_indexs[req.req_idx][:req.cur_kv_len]) + free_token_index = [] + self._free_a_req_mem(free_token_index, req) + self.req_manager.free_token(free_token_index[0]) req.cur_kv_len = 0 + return self @classmethod @torch.no_grad() def merge(cls, batch1, batch2): request_ids = batch1.request_ids + batch2.request_ids - + return InferBatch( batch_id=batch1.batch_id, request_ids=request_ids, req_manager=batch1.req_manager, + radix_cache=batch1.radix_cache, ) def __len__(self): return len(self.request_ids) - diff --git a/lightllm/server/router/model_infer/model_rpc.py b/lightllm/server/router/model_infer/model_rpc.py index 2ecd2dd6d..c49fd12ed 100644 --- a/lightllm/server/router/model_infer/model_rpc.py +++ b/lightllm/server/router/model_infer/model_rpc.py @@ -2,7 +2,6 @@ import numpy as np import rpyc import torch -import traceback from datetime import timedelta from typing import Dict, List, Tuple from transformers.configuration_utils import PretrainedConfig @@ -45,6 +44,7 @@ from .infer_batch import InferReq from lightllm.server.io_struct import ReqRunStatus from lightllm.utils.log_utils import init_logger +from lightllm.server.router.dynamic_prompt.radix_cache import RadixCache class ModelRpcServer(rpyc.Service): @@ -65,6 +65,7 @@ def exposed_init_model(self, kvargs): self.is_splitfuse_mode = kvargs.get("is_splitfuse_mode", False) self.splitfuse_block_size = kvargs.get("splitfuse_block_size", None) self.return_all_prompt_logprobs = kvargs.get("return_all_prompt_logprobs", False) + self.use_dynamic_prompt_cache = kvargs.get("use_dynamic_prompt_cache", False) self.cache = {} self.logger = init_logger(__name__) @@ -77,6 +78,12 @@ def exposed_init_model(self, kvargs): ) torch.cuda.set_device(self.tp_rank) + self.radix_cache = ( + RadixCache(str(kvargs["nccl_port"]), max_total_token_num, self.tp_rank) + if self.use_dynamic_prompt_cache + else None + ) + model_cfg, _ = PretrainedConfig.get_config_dict(weight_dir) model_kvargs = { @@ -166,10 +173,13 @@ def exposed_init_model(self, kvargs): raise Exception(f"can not support {self.model_type} now") except Exception as e: self.logger.error(f"load model error: {str(e)} {e} {type(e)}") + import traceback + traceback.print_exc() raise e set_random_seed(2147483647) + return # @calculate_time(show=True, min_cost_ms=0.1) @@ -183,7 +193,13 @@ def exposed_add_batch(self, batch_id, reqs, dtype): else: assert False, "error dtype" batch_data = InferBatch.init_batch( - batch_id, reqs, dtype, torch.cuda.current_device(), self.model.req_manager, self.model.vocab_size + batch_id, + reqs, + dtype, + torch.cuda.current_device(), + self.model.req_manager, + self.model.vocab_size, + self.radix_cache, ) self.cache[batch_id] = batch_data @@ -256,38 +272,31 @@ def forward(self, batch_id, is_prefill): output_dict = {} batch: InferBatch = self.cache.pop(batch_id) if 
is_prefill: - kwargs, run_reqs, not_run_reqs = prepare_prefill_inputs(batch, self.is_multimodal) + kwargs, run_reqs = prepare_prefill_inputs( + batch, self.radix_cache, self.model.mem_manager, self.is_multimodal + ) else: - kwargs, run_reqs, not_run_reqs = prepare_decode_inputs(batch) + kwargs, run_reqs = prepare_decode_inputs(batch, self.radix_cache, self.model.mem_manager) - if len(run_reqs) >= 1: - logits = self.model.forward(**kwargs) - next_token_ids, next_token_probs = sample(logits, run_reqs) - next_token_ids = next_token_ids.detach().cpu().numpy() - next_token_logprobs = torch.log(next_token_probs).detach().cpu().numpy() - - for req_obj, next_token_id, next_token_logprob in zip(run_reqs, next_token_ids, next_token_logprobs): - # prefill and decode is same - req_obj.cur_kv_len = len(req_obj.input_token_ids) - req_obj.input_token_ids.append(next_token_id) - req_obj.out_token_id_count[next_token_id] += 1 - metadata = { - "id": int(next_token_id), - "logprob": float(next_token_logprob), - } - output_dict[req_obj.r_id] = ( - req_obj.req_status, - req_obj.cur_kv_len, - int(next_token_id), - metadata, - ) # 状态, cur_kv_len, token_id, metadata + logits = self.model.forward(**kwargs) + next_token_ids, next_token_probs = sample(logits, run_reqs) + next_token_ids = next_token_ids.detach().cpu().numpy() + next_token_logprobs = torch.log(next_token_probs).detach().cpu().numpy() - for req_obj in not_run_reqs: + for req_obj, next_token_id, next_token_logprob in zip(run_reqs, next_token_ids, next_token_logprobs): + # prefill and decode is same + req_obj.cur_kv_len = len(req_obj.input_token_ids) + req_obj.input_token_ids.append(next_token_id) + req_obj.out_token_id_count[next_token_id] += 1 + metadata = { + "id": int(next_token_id), + "logprob": float(next_token_logprob), + } output_dict[req_obj.r_id] = ( req_obj.req_status, req_obj.cur_kv_len, - None, - None, + int(next_token_id), + metadata, ) # 状态, cur_kv_len, token_id, metadata self.cache[batch.batch_id] = batch @@ -295,62 +304,55 @@ def forward(self, batch_id, is_prefill): @torch.no_grad() def _prefill_to_return_all_prompt_logprobs(self, batch_id): + # 在 return all_prompt_logprobs 的模式下,不能启用 dynamic prompt cache + assert self.radix_cache is None output_dict = {} batch: InferBatch = self.cache.pop(batch_id) - kwargs, run_reqs, not_run_reqs = prepare_prefill_inputs(batch) - - if len(run_reqs) >= 1: - prompt_all_logits = self.model.forward(**kwargs) - input_ids = kwargs["input_ids"] - b_start_loc = kwargs["b_start_loc"] - b_seq_len = kwargs["b_seq_len"] - last_index = torch.cumsum(b_seq_len, dim=0, dtype=torch.long) - 1 - logits = prompt_all_logits[last_index, :] - - next_token_ids, next_token_probs = sample(logits, run_reqs) - next_token_ids = next_token_ids.detach().cpu().numpy() - next_token_logprobs = torch.log(next_token_probs).detach().cpu().numpy() - - b_start_loc = b_start_loc.cpu().numpy() - b_seq_len = b_seq_len.cpu().numpy() - for req_obj, next_token_id, next_token_logprob, start_loc, seq_len in zip( - run_reqs, next_token_ids, next_token_logprobs, b_start_loc, b_seq_len - ): - # prefill and decode is same - req_obj.cur_kv_len = len(req_obj.input_token_ids) - req_obj.input_token_ids.append(next_token_id) - req_obj.out_token_id_count[next_token_id] += 1 - metadata = { - "id": int(next_token_id), - "logprob": float(next_token_logprob), - } - - cur_ids: torch.Tensor = input_ids[start_loc : start_loc + seq_len] - cur_logits = prompt_all_logits[start_loc : start_loc + seq_len] - cur_logprobs = torch.log_softmax(cur_logits, dim=-1, 
dtype=torch.float)[0:-1, :] - cur_logprobs = torch.gather(cur_logprobs, dim=1, index=cur_ids[1:].view(-1, 1)).detach().cpu().numpy() + kwargs, run_reqs = prepare_prefill_inputs(batch, self.radix_cache, self.model.mem_manager) - cur_ids = cur_ids.cpu().numpy() - all_prompts = [] - for index in range(len(cur_ids) - 1): - tmp_dict = {int(cur_ids[index + 1]): float(cur_logprobs[index, 0])} - all_prompts.append([int(cur_ids[index]), tmp_dict]) + prompt_all_logits = self.model.forward(**kwargs) + input_ids = kwargs["input_ids"] + b_start_loc = kwargs["b_start_loc"] + b_seq_len = kwargs["b_seq_len"] + last_index = torch.cumsum(b_seq_len, dim=0, dtype=torch.long) - 1 + logits = prompt_all_logits[last_index, :] - metadata["prompt_logprobs"] = all_prompts - metadata["prompt_token_ids"] = [int(e) for e in cur_ids] - output_dict[req_obj.r_id] = ( - req_obj.req_status, - req_obj.cur_kv_len, - int(next_token_id), - metadata, - ) # 状态, cur_kv_len, token_id, metadata + next_token_ids, next_token_probs = sample(logits, run_reqs) + next_token_ids = next_token_ids.detach().cpu().numpy() + next_token_logprobs = torch.log(next_token_probs).detach().cpu().numpy() - for req_obj in not_run_reqs: + b_start_loc = b_start_loc.cpu().numpy() + b_seq_len = b_seq_len.cpu().numpy() + for req_obj, next_token_id, next_token_logprob, start_loc, seq_len in zip( + run_reqs, next_token_ids, next_token_logprobs, b_start_loc, b_seq_len + ): + # prefill and decode is same + req_obj.cur_kv_len = len(req_obj.input_token_ids) + req_obj.input_token_ids.append(next_token_id) + req_obj.out_token_id_count[next_token_id] += 1 + metadata = { + "id": int(next_token_id), + "logprob": float(next_token_logprob), + } + + cur_ids: torch.Tensor = input_ids[start_loc : start_loc + seq_len] + cur_logits = prompt_all_logits[start_loc : start_loc + seq_len] + cur_logprobs = torch.log_softmax(cur_logits, dim=-1, dtype=torch.float)[0:-1, :] + cur_logprobs = torch.gather(cur_logprobs, dim=1, index=cur_ids[1:].view(-1, 1)).detach().cpu().numpy() + + cur_ids = cur_ids.cpu().numpy() + all_prompts = [] + for index in range(len(cur_ids) - 1): + tmp_dict = {int(cur_ids[index + 1]): float(cur_logprobs[index, 0])} + all_prompts.append([int(cur_ids[index]), tmp_dict]) + + metadata["prompt_logprobs"] = all_prompts + metadata["prompt_token_ids"] = [int(e) for e in cur_ids] output_dict[req_obj.r_id] = ( req_obj.req_status, req_obj.cur_kv_len, - None, - None, + int(next_token_id), + metadata, ) # 状态, cur_kv_len, token_id, metadata self.cache[batch.batch_id] = batch @@ -360,7 +362,9 @@ def _prefill_to_return_all_prompt_logprobs(self, batch_id): def splitfuse_forward(self, batch_id): output_dict = {} batch: InferBatch = self.cache.pop(batch_id) - kwargs, decode_reqs, prefill_reqs = splitfuse_prepare_decode_inputs(batch, self.splitfuse_block_size) + kwargs, decode_reqs, prefill_reqs = splitfuse_prepare_decode_inputs( + batch, self.splitfuse_block_size, self.radix_cache, self.model.mem_manager + ) decode_req_num = len(decode_reqs) all_reqs = decode_reqs all_reqs.extend(prefill_reqs) diff --git a/lightllm/server/router/model_infer/pre_process.py b/lightllm/server/router/model_infer/pre_process.py index be5f34f2b..0f9dfdfa7 100644 --- a/lightllm/server/router/model_infer/pre_process.py +++ b/lightllm/server/router/model_infer/pre_process.py @@ -3,10 +3,12 @@ from .infer_batch import requests_mapping, InferReq, InferBatch from lightllm.server.io_struct import ReqRunStatus from lightllm.utils.infer_utils import calculate_time +from 
lightllm.server.router.dynamic_prompt.radix_cache import RadixCache +from lightllm.common.mem_manager import MemoryManager -#@calculate_time(show=True, min_cost_ms=1) -def prepare_prefill_inputs(batch:InferBatch, is_multimodal=False): - run_reqs, not_run_reqs = [], [] +# @calculate_time(show=True, min_cost_ms=1) +def prepare_prefill_inputs(batch: InferBatch, radix_cache: RadixCache, mem_manager: MemoryManager, is_multimodal=False): + run_reqs = [] nopad_total_token_num = 0 nopad_max_len_in_batch = 0 start_loc = 0 @@ -15,56 +17,59 @@ def prepare_prefill_inputs(batch:InferBatch, is_multimodal=False): nopad_b_start_loc = [] nopad_b_seq_len = [] batch_multimodal_params = [] + b_ready_cache_len = [] for request_id in batch.request_ids: - req : InferReq = requests_mapping[request_id] + req: InferReq = requests_mapping[request_id] assert req.req_status == ReqRunStatus.RUNNING - # 当请求已经存在 cur_kv_len 不为 0 的时候,就不需要做全 prefill 操作了, - # 说明是从 RERUNNING_FROM_KVKEEP 中 恢复的请求 - if req.cur_kv_len != 0: - not_run_reqs.append(req) - continue - + run_reqs.append(req) batch_multimodal_params.append(req.multimodal_params) nopad_b_req_idx.append(req.req_idx) nopad_b_start_loc.append(start_loc) - + seq_len = len(req.input_token_ids) - input_id = req.input_token_ids - + input_token_len = seq_len - req.ready_cache_len + + input_id = req.input_token_ids[req.ready_cache_len :] + nopad_b_seq_len.append(seq_len) input_ids.append(input_id) nopad_total_token_num += seq_len - nopad_max_len_in_batch = max(nopad_max_len_in_batch, seq_len) - start_loc += seq_len - - if len(run_reqs) >= 1: - - input_ids = np.concatenate(input_ids, dtype=np.int64) - - input_ids = torch.tensor(input_ids, dtype=torch.int64, device='cuda') - nopad_b_req_idx = torch.tensor(nopad_b_req_idx, dtype=torch.int32, device='cuda') - nopad_b_start_loc = torch.tensor(nopad_b_start_loc, dtype=torch.int32, device='cuda') - nopad_b_seq_len = torch.tensor(nopad_b_seq_len, dtype=torch.int32, device='cuda') - kwargs = { - "batch_size": len(batch), - "total_token_num": nopad_total_token_num, - "max_len_in_batch": nopad_max_len_in_batch, - "input_ids": input_ids, - "b_req_idx": nopad_b_req_idx, - "b_start_loc": nopad_b_start_loc, - "b_seq_len": nopad_b_seq_len, - "is_prefill": True, - } - if is_multimodal: - kwargs["multimodal_params"] = batch_multimodal_params - return kwargs, run_reqs, not_run_reqs - else: - return {}, run_reqs, not_run_reqs - -#@calculate_time(show=True, min_cost_ms=1) -def prepare_decode_inputs(batch:InferBatch): - run_reqs, not_run_reqs = [], [] + nopad_max_len_in_batch = max(nopad_max_len_in_batch, input_token_len) + b_ready_cache_len.append(req.ready_cache_len) + start_loc += input_token_len + + input_ids = np.concatenate(input_ids, dtype=np.int64) + + input_ids = torch.tensor(input_ids, dtype=torch.int64, device="cuda") + nopad_b_req_idx = torch.tensor(nopad_b_req_idx, dtype=torch.int32, device="cuda") + nopad_b_start_loc = torch.tensor(nopad_b_start_loc, dtype=torch.int32, device="cuda") + nopad_b_seq_len = torch.tensor(nopad_b_seq_len, dtype=torch.int32, device="cuda") + b_ready_cache_len = torch.tensor(b_ready_cache_len, dtype=torch.int32, device="cuda") + kwargs = { + "batch_size": len(batch), + "total_token_num": nopad_total_token_num, + "max_len_in_batch": nopad_max_len_in_batch, + "input_ids": input_ids, + "b_req_idx": nopad_b_req_idx, + "b_start_loc": nopad_b_start_loc, + "b_seq_len": nopad_b_seq_len, + "b_ready_cache_len": b_ready_cache_len, + "is_prefill": True, + } + if is_multimodal: + kwargs["multimodal_params"] = 
batch_multimodal_params + + # dynamic prompt cache 准备 token + if radix_cache is not None: + _free_radix_cache_to_get_enough_token(radix_cache, mem_manager, input_ids.shape[0]) + + return kwargs, run_reqs + + +# @calculate_time(show=True, min_cost_ms=1) +def prepare_decode_inputs(batch: InferBatch, radix_cache: RadixCache, mem_manager: MemoryManager): + run_reqs = [] nopad_total_token_num = 0 nopad_max_len_in_batch = 0 start_loc = 0 @@ -73,7 +78,7 @@ def prepare_decode_inputs(batch:InferBatch): nopad_b_start_loc = [] nopad_b_seq_len = [] for request_id in batch.request_ids: - req : InferReq = requests_mapping[request_id] + req: InferReq = requests_mapping[request_id] assert req.req_status == ReqRunStatus.RUNNING run_reqs.append(req) nopad_b_req_idx.append(req.req_idx) @@ -86,32 +91,35 @@ def prepare_decode_inputs(batch:InferBatch): nopad_total_token_num += seq_len nopad_max_len_in_batch = max(nopad_max_len_in_batch, seq_len) start_loc += seq_len - - if len(run_reqs) >= 1: - - input_ids = torch.tensor(input_ids, dtype=torch.int64, device='cuda') - nopad_b_req_idx = torch.tensor(nopad_b_req_idx, dtype=torch.int32, device='cuda') - nopad_b_start_loc = torch.tensor(nopad_b_start_loc, dtype=torch.int32, device='cuda') - nopad_b_seq_len = torch.tensor(nopad_b_seq_len, dtype=torch.int32, device='cuda') - kwargs = { - "batch_size": len(batch), - "total_token_num": nopad_total_token_num, - "max_len_in_batch": nopad_max_len_in_batch, - "input_ids": input_ids, - "b_req_idx": nopad_b_req_idx, - "b_start_loc": nopad_b_start_loc, - "b_seq_len": nopad_b_seq_len, - "is_prefill": False - } - return kwargs, run_reqs, not_run_reqs - else: - return {}, run_reqs, not_run_reqs - -#@calculate_time(show=True, min_cost_ms=1) -def splitfuse_prepare_decode_inputs(batch:InferBatch, splitfuse_block_size): + + input_ids = torch.tensor(input_ids, dtype=torch.int64, device="cuda") + nopad_b_req_idx = torch.tensor(nopad_b_req_idx, dtype=torch.int32, device="cuda") + nopad_b_start_loc = torch.tensor(nopad_b_start_loc, dtype=torch.int32, device="cuda") + nopad_b_seq_len = torch.tensor(nopad_b_seq_len, dtype=torch.int32, device="cuda") + kwargs = { + "batch_size": len(batch), + "total_token_num": nopad_total_token_num, + "max_len_in_batch": nopad_max_len_in_batch, + "input_ids": input_ids, + "b_req_idx": nopad_b_req_idx, + "b_start_loc": nopad_b_start_loc, + "b_seq_len": nopad_b_seq_len, + "is_prefill": False, + } + # dynamic prompt cache 准备 token + if radix_cache is not None: + _free_radix_cache_to_get_enough_token(radix_cache, mem_manager, input_ids.shape[0]) + + return kwargs, run_reqs + + +# @calculate_time(show=True, min_cost_ms=1) +def splitfuse_prepare_decode_inputs( + batch: InferBatch, splitfuse_block_size, radix_cache: RadixCache, mem_manager: MemoryManager +): decode_reqs, prefill_reqs = [], [] for request_id in batch.request_ids: - req : InferReq = requests_mapping[request_id] + req: InferReq = requests_mapping[request_id] if req.cur_kv_len == len(req.input_token_ids) - 1: decode_reqs.append(req) elif req.cur_kv_len < len(req.input_token_ids) - 1: @@ -137,12 +145,12 @@ def splitfuse_prepare_decode_inputs(batch:InferBatch, splitfuse_block_size): decode_b_seq_len.append(seq_len) decode_max_len_in_batch = max(decode_max_len_in_batch, seq_len) input_ids.append(req.input_token_ids[-1]) - + prefill_req_num = len(prefill_reqs) prefill_b_req_idx = [] prefill_b_split_start_loc = [] split_start_loc = 0 - prefill_b_split_seq_len = [] + prefill_b_split_ready_cache_len = [] prefill_max_split_seq_len_in_batch = 0 
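# --- Illustrative aside (editor's sketch, not part of the patch) ---------------
# A rough numeric sketch of the bookkeeping in this function, assuming a single
# prefill request and made-up token ids. With the dynamic prompt cache the model
# only receives the token ids that still need KV computation; the length of the
# prefix whose KV is already resident is passed separately
# (b_ready_cache_len for plain prefill, prefill_b_split_ready_cache_len here).

_splitfuse_block_size = 4
_prompt = [11, 12, 13, 14, 15, 16, 17, 18, 19, 20]   # hypothetical prompt tokens
_cur_kv_len = 6                                      # 6 tokens already have KV cache

_split_len = min(len(_prompt) - _cur_kv_len, _splitfuse_block_size)   # -> 4
_seq_len = _cur_kv_len + _split_len                                   # -> 10
_fed_tokens = _prompt[_seq_len - _split_len : _seq_len]               # -> [17, 18, 19, 20]

assert _split_len == 4 and _seq_len == 10 and _fed_tokens == [17, 18, 19, 20]
# Only the 4 new positions get fresh KV slots this step; the first 6 positions are
# read back through req_to_token_indexs[req_idx, :cur_kv_len], which was filled
# from the radix cache (or by earlier splitfuse steps) during init_batch.
# -------------------------------------------------------------------------------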
prefill_b_seq_len = [] @@ -151,27 +159,48 @@ def splitfuse_prepare_decode_inputs(batch:InferBatch, splitfuse_block_size): split_len = min(len(req.input_token_ids) - req.cur_kv_len, splitfuse_block_size) prefill_b_split_start_loc.append(split_start_loc) split_start_loc += split_len - prefill_b_split_seq_len.append(split_len) + prefill_b_split_ready_cache_len.append(req.cur_kv_len) prefill_max_split_seq_len_in_batch = max(prefill_max_split_seq_len_in_batch, split_len) seq_len = req.cur_kv_len + split_len prefill_b_seq_len.append(seq_len) input_ids.extend(req.input_token_ids[seq_len - split_len : seq_len]) - + + input_ids = torch.tensor(input_ids, dtype=torch.int64, device="cuda") kwargs = { - "input_ids": torch.tensor(input_ids, dtype=torch.int64, device='cuda'), - "decode_req_num": decode_req_num, - "decode_total_token_num": decode_total_token_num, - "decode_b_req_idx": torch.tensor(decode_b_req_idx, dtype=torch.int32, device='cuda'), - "decode_b_start_loc": torch.tensor(decode_b_start_loc, dtype=torch.int32, device="cuda"), - "decode_b_seq_len": torch.tensor(decode_b_seq_len, dtype=torch.int32, device="cuda"), - "decode_max_len_in_batch": decode_max_len_in_batch, - - "prefill_req_num": prefill_req_num, - "prefill_b_req_idx": torch.tensor(prefill_b_req_idx, dtype=torch.int32, device="cuda"), - "prefill_b_split_start_loc" : torch.tensor(prefill_b_split_start_loc, dtype=torch.int32, device="cuda"), - "prefill_b_split_seq_len" : torch.tensor(prefill_b_split_seq_len, dtype=torch.int32, device="cuda"), - "prefill_max_split_seq_len_in_batch" : prefill_max_split_seq_len_in_batch, - "prefill_b_seq_len" : torch.tensor(prefill_b_seq_len, dtype=torch.int32, device="cuda") - } + "input_ids": input_ids, + "decode_req_num": decode_req_num, + "decode_total_token_num": decode_total_token_num, + "decode_b_req_idx": torch.tensor(decode_b_req_idx, dtype=torch.int32, device="cuda"), + "decode_b_start_loc": torch.tensor(decode_b_start_loc, dtype=torch.int32, device="cuda"), + "decode_b_seq_len": torch.tensor(decode_b_seq_len, dtype=torch.int32, device="cuda"), + "decode_max_len_in_batch": decode_max_len_in_batch, + "prefill_req_num": prefill_req_num, + "prefill_b_req_idx": torch.tensor(prefill_b_req_idx, dtype=torch.int32, device="cuda"), + "prefill_b_split_start_loc": torch.tensor(prefill_b_split_start_loc, dtype=torch.int32, device="cuda"), + "prefill_b_split_ready_cache_len": torch.tensor( + prefill_b_split_ready_cache_len, dtype=torch.int32, device="cuda" + ), + "prefill_max_split_seq_len_in_batch": prefill_max_split_seq_len_in_batch, + "prefill_b_seq_len": torch.tensor(prefill_b_seq_len, dtype=torch.int32, device="cuda"), + } + + # dynamic prompt cache 准备 token + if radix_cache is not None: + _free_radix_cache_to_get_enough_token(radix_cache, mem_manager, input_ids.shape[0]) + return kwargs, decode_reqs, prefill_reqs - + + +def _free_radix_cache_to_get_enough_token(radix_cache: RadixCache, mem_manager: MemoryManager, need_token_num): + if need_token_num > mem_manager.can_use_mem_size: + need_evict_token_num = need_token_num - mem_manager.can_use_mem_size + release_mems = [] + + def release_mem(mem_index): + release_mems.append(mem_index) + return + + radix_cache.evict(need_evict_token_num, release_mem) + mem_index = torch.concat(release_mems) + mem_manager.free(mem_index.cuda()) + return diff --git a/lightllm/server/router/req_queue.py b/lightllm/server/router/req_queue.py index bacde73e9..5a1bd5d85 100644 --- a/lightllm/server/router/req_queue.py +++ b/lightllm/server/router/req_queue.py @@ -2,14 +2,13 @@ 
 import asyncio
 import numpy as np
 from typing import List
-from ..io_struct import Batch, Req
 from lightllm.utils.infer_utils import calculate_time
-from lightllm.server.io_struct import Req
+from lightllm.server.io_struct import Batch, Req
 from lightllm.server.io_struct import ReqRunStatus, FinishStatus
 
 
-class ReqQueue:
-    def __init__(self, args, prompt_cache_used_tokens, prompt_cache_req_num) -> None:
+class ReqQueue:
+    def __init__(self, args) -> None:
         self.max_total_tokens = args.max_total_token_num
         assert args.batch_max_tokens is not None
         self.batch_max_tokens = args.batch_max_tokens
@@ -17,75 +16,75 @@ def __init__(self, args, prompt_cache_used_tokens, prompt_cache_req_num) -> None
         self.waiting_req_list: List[Req] = []
         self.router_token_ratio = args.router_token_ratio
         self.router_max_new_token_len = args.router_max_new_token_len
-        self.pause_req_dict = {} # holds queued requests paused with ReqRunStatus.PAUSED_AND_KVKEEP or ReqRunStatus.PAUSED_AND_OFFLOAD
+        self.pause_req_dict = {}  # holds queued requests paused with ReqRunStatus.PAUSED_AND_OFFLOAD
         self.pause_req_used_tokens = 0
         self.is_splitfuse_mode = args.splitfuse_mode
         self.splitfuse_block_size = args.splitfuse_block_size
-        # bookkeeping variables maintained when the prompt cache feature is used
-        self.prompt_cache_used_tokens = prompt_cache_used_tokens
-        self.prompt_cache_req_num = prompt_cache_req_num
-
     def append(self, req):
         self.waiting_req_list.append(req)
         return
-    
-    def back_to_wait_list(self, req_list:List[Req]):
+
+    def back_to_wait_list(self, req_list: List[Req]):
         for req in req_list:
-            if req.req_status in [ReqRunStatus.PAUSED_AND_KVKEEP, ReqRunStatus.PAUSED_AND_OFFLOAD]:
+            if req.req_status in [
+                ReqRunStatus.PAUSED_AND_OFFLOAD,
+            ]:
                 self.pause_req_dict[req.request_id] = req
         self.waiting_req_list = req_list + self.waiting_req_list
         self.recalcu_pause_req_used_tokens()
-        return 
+        return
 
-    def _init_cache_list(self, current_batch:Batch, is_busy):
+    def _init_cache_list(self, current_batch: Batch, is_busy):
         self.cache_pause_reqs_used_tokens = self.pause_req_used_tokens
-        self.cache_pause_reqs_num = len(self.pause_req_dict) 
+        self.cache_pause_reqs_num = len(self.pause_req_dict)
         if current_batch is not None:
-            self.cache_len_list = [req.get_tuple_tokens(is_busy, self.router_max_new_token_len) for req in current_batch.reqs]
+            self.cache_len_list = [
+                req.get_tuple_tokens(is_busy, self.router_max_new_token_len) for req in current_batch.reqs
+            ]
         else:
             self.cache_len_list = []
 
     # @calculate_time(show=True, min_cost_ms=0.1)
-    def _can_add_new_req(self, req:Req, is_busy):
-        self.cache_len_list.append(req.get_tuple_tokens(is_busy, self.router_max_new_token_len)) # hard to analysis
+    def _can_add_new_req(self, req: Req, is_busy):
+        self.cache_len_list.append(req.get_tuple_tokens(is_busy, self.router_max_new_token_len))  # hard to analyze
         self.cache_len_list.sort(key=lambda x: -x[1])
-        
+
         left_out_len_array = np.array([e[1] for e in self.cache_len_list])
         # assert left_out_len_array.min() >= 0
         has_run_len_array = np.array([e[0] for e in self.cache_len_list])
         cum_run_len_array = np.cumsum(has_run_len_array)
         size_array = np.arange(1, len(self.cache_len_list) + 1, 1)
-        
+
         need_max_token_num = (left_out_len_array * size_array + cum_run_len_array).max()
-        if req.req_status in [ReqRunStatus.PAUSED_AND_KVKEEP, ReqRunStatus.PAUSED_AND_OFFLOAD]:
+        if req.req_status in [ReqRunStatus.PAUSED_AND_OFFLOAD]:
            self.cache_pause_reqs_used_tokens -= req.get_used_tokens()
            self.cache_pause_reqs_num -= 1
-        ok_token_num = need_max_token_num < self.max_total_tokens - self.cache_pause_reqs_used_tokens - self.prompt_cache_used_tokens
-        ok_req_num = len(self.cache_len_list) + self.cache_pause_reqs_num + self.prompt_cache_req_num <= self.running_max_req_size
+        ok_token_num = need_max_token_num < self.max_total_tokens - self.cache_pause_reqs_used_tokens
+        ok_req_num = len(self.cache_len_list) + self.cache_pause_reqs_num <= self.running_max_req_size
         if ok_token_num and ok_req_num:
             return True
         else:
             return False
-    
-    #@calculate_time(show=True, min_cost_ms=10)
-    def generate_new_batch(self, current_batch:Batch):
+
+    # @calculate_time(show=True, min_cost_ms=10)
+    def generate_new_batch(self, current_batch: Batch):
         # If the number of requests already scheduled has reached the limit, do not schedule any new ones.
-        exist_req_num = self.prompt_cache_req_num
+        exist_req_num = 0
         exist_req_num += 0 if current_batch is None else len(current_batch.reqs)
         exist_req_num += len(self.pause_req_dict)
         req_is_full = exist_req_num >= self.running_max_req_size
         if req_is_full:
             return None
-    
+
         # Compute the total token usage right now, counting both running and paused requests.
         cur_all_used_tokens = 0 if current_batch is None else current_batch.batch_used_tokens
-        cur_all_used_tokens += self.recalcu_pause_req_used_tokens() + self.prompt_cache_used_tokens
-    
+        cur_all_used_tokens += self.recalcu_pause_req_used_tokens()
+
         # Check whether the service is in a high token-utilization state; if so, scheduling should lean conservative.
         cur_token_ratio = cur_all_used_tokens / self.max_total_tokens
         is_busy = cur_token_ratio >= self.router_token_ratio
@@ -97,39 +96,42 @@ def generate_new_batch(self, current_batch:Batch):
             cur_batch_decode_need_tokens = 0
         else:
             cur_batch_decode_need_tokens = 0 if current_batch is None else current_batch.batch_decode_need_tokens
-    
+
         self._init_cache_list(current_batch, is_busy)
         can_run_list = []
-        new_batch_first_router_need_tokens = 0 # mainly a limit on prefill or large splitfuse block computation
+        new_batch_first_router_need_tokens = 0  # mainly a limit on prefill or large splitfuse block computation
        aborted_count = 0
        for req in self.waiting_req_list:
-            if req.finish_status.is_aborted() and req.req_status == ReqRunStatus.WAIT_IN_QUEUE: 
-                # Because of the management complexity, only requests that have never been scheduled to run can be dropped from the queue directly on abort. 
+            if req.finish_status.is_aborted() and req.req_status == ReqRunStatus.WAIT_IN_QUEUE:
+                # Because of the management complexity, only requests that have never been scheduled to run can be dropped from the queue directly on abort.
                 # Paused requests must first be resumed and are then filtered out by the router manager. Keep this approach for now; otherwise managed tokens would leak.
                 aborted_count += 1
                 continue
             req_first_router_need_tokens = req.get_first_router_need_tokens()
-            if self._can_add_new_req(req, is_busy) and cur_batch_decode_need_tokens + new_batch_first_router_need_tokens + req_first_router_need_tokens <= self.batch_max_tokens:
+            if (
+                self._can_add_new_req(req, is_busy)
+                and cur_batch_decode_need_tokens + new_batch_first_router_need_tokens + req_first_router_need_tokens
+                <= self.batch_max_tokens
+            ):
                 can_run_list.append(req)
                 new_batch_first_router_need_tokens += req_first_router_need_tokens
-                if req.req_status in [ReqRunStatus.PAUSED_AND_KVKEEP, ReqRunStatus.PAUSED_AND_OFFLOAD]:
+                if req.req_status in [ReqRunStatus.PAUSED_AND_OFFLOAD]:
                     self.pause_req_dict.pop(req.request_id)
             else:
                 break

         if len(can_run_list) != 0:
             new_batch = Batch(uuid.uuid4().hex, can_run_list)
-            self.waiting_req_list = self.waiting_req_list[len(can_run_list) + aborted_count:]
+            self.waiting_req_list = self.waiting_req_list[len(can_run_list) + aborted_count :]
             # After generating the new batch, refresh the paused-request bookkeeping.
             self.recalcu_pause_req_used_tokens()
             return new_batch
         else:
             return None
-    
+
     def recalcu_pause_req_used_tokens(self):
         used_tokens = 0
         for req_id, req_obj in self.pause_req_dict.items():
             used_tokens += req_obj.get_used_tokens()
         self.pause_req_used_tokens = used_tokens
         return self.pause_req_used_tokens
-
diff --git a/lightllm/server/tokenizer.py b/lightllm/server/tokenizer.py
index 0275872d3..274d568bc 100644
--- a/lightllm/server/tokenizer.py
+++ b/lightllm/server/tokenizer.py
@@ -18,8 +18,7 @@
 from typing import List, Tuple, Union
-from transformers import (AutoTokenizer, PreTrainedTokenizer,
-                          PreTrainedTokenizerFast)
+from transformers import AutoTokenizer, PreTrainedTokenizer, PreTrainedTokenizerFast
 from transformers.convert_slow_tokenizer import convert_slow_tokenizer
 from transformers import LlamaTokenizer
 from transformers.configuration_utils import PretrainedConfig
@@ -44,8 +43,7 @@ def get_tokenizer(
     """Gets a tokenizer for the given model name via Huggingface."""
     if tokenizer_mode == "slow":
         if kwargs.get("use_fast", False):
-            raise ValueError(
-                "Cannot use the fast tokenizer in slow tokenizer mode.")
+            raise ValueError("Cannot use the fast tokenizer in slow tokenizer mode.")
         kwargs["use_fast"] = False

     if "llama" in tokenizer_name.lower() and kwargs.get("use_fast", True):
@@ -53,20 +51,20 @@ def get_tokenizer(
             "For some LLaMA-based models, initializing the fast tokenizer may "
             "take a long time. To eliminate the initialization time, consider "
             f"using '{_FAST_LLAMA_TOKENIZER}' instead of the original "
-            "tokenizer.")
+            "tokenizer."
+        )
         # tokenizer = LlamaTokenizer.from_pretrained(tokenizer_name)
         # tokenizer = convert_slow_tokenizer(tokenizer)
         # return tokenizer
     try:
-        tokenizer = AutoTokenizer.from_pretrained(tokenizer_name, trust_remote_code=trust_remote_code, *args,
-                                                  **kwargs)
+        tokenizer = AutoTokenizer.from_pretrained(tokenizer_name, trust_remote_code=trust_remote_code, *args, **kwargs)
     except TypeError as e:
         # The LLaMA tokenizer causes a protobuf error in some environments, using slow mode.
- # you can try pip install protobuf==3.20.0 to try repair + # you can try pip install protobuf==3.20.0 to try repair + logger.warning(f"load fast tokenizer fail: {str(e)}") kwargs["use_fast"] = False - tokenizer = AutoTokenizer.from_pretrained(tokenizer_name, trust_remote_code=trust_remote_code, *args, - **kwargs) + tokenizer = AutoTokenizer.from_pretrained(tokenizer_name, trust_remote_code=trust_remote_code, *args, **kwargs) model_cfg, _ = PretrainedConfig.get_config_dict(tokenizer_name) model_type = model_cfg.get("model_type", "") @@ -78,6 +76,6 @@ def get_tokenizer( if not isinstance(tokenizer, PreTrainedTokenizerFast): logger.info( "Using a slow tokenizer. This might cause a significant " - "slowdown. Consider using a fast tokenizer instead.") + "slowdown. Consider using a fast tokenizer instead." + ) return tokenizer - diff --git a/lightllm/server/visualserver/manager.py b/lightllm/server/visualserver/manager.py index caa2f2ead..91e309ffa 100644 --- a/lightllm/server/visualserver/manager.py +++ b/lightllm/server/visualserver/manager.py @@ -3,10 +3,12 @@ import asyncio import uvloop import rpyc +from typing import List from transformers import AutoConfig from ..io_struct import AbortReq from ..embed_cache.utils import tensor2bytes, read_shm, create_shm, get_shm_name_data, get_shm_name_embed from rpyc.utils.classic import obtain + asyncio.set_event_loop_policy(uvloop.EventLoopPolicy()) from .model_infer.model_rpc import start_model_process, VisualModelRpcClient from io import BytesIO @@ -31,14 +33,14 @@ def __init__( self.recv_from_httpserver = context.socket(zmq.PULL) self.recv_from_httpserver.bind(f"tcp://127.0.0.1:{visual_port}") self.cache_client = rpyc.connect("localhost", client_port) - self.waiting_reqs = [] + self.waiting_reqs = [] self.model_weightdir = args.model_dir self.tp_world_size = args.tp self.world_size = 1 self.infer_batch_size = infer_batch_size - self.trust_remote_code=args.trust_remote_code + self.trust_remote_code = args.trust_remote_code self.args = args - + async def wait_to_model_ready(self): self.model_rpcs: List[VisualModelRpcClient] = [] @@ -49,8 +51,8 @@ async def wait_to_model_ready(self): init_model_ret = [] for rank_id in range(self.world_size): # async init model process kvargs = { - "weight_dir" : self.model_weightdir, - "trust_remote_code" : self.trust_remote_code, + "weight_dir": self.model_weightdir, + "trust_remote_code": self.trust_remote_code, "rank_id": rank_id, } init_model_ret.append(self.model_rpcs[rank_id].init_model(kvargs)) @@ -78,7 +80,7 @@ async def infer_imgs(self, uuids): else: img_embed = ans[0] torch.cuda.synchronize() - b = time.time() + # b = time.time() for i in range(len(uuids)): # print(" + set_item_embed:", uuids[i], img_embed[i].shape) if not self.cache_client.root.get_item_embed(uuids[i]): @@ -97,7 +99,7 @@ async def loop_for_fwd(self): uuids_need_infer = [] while cur_batch_size < self.infer_batch_size and len(self.waiting_reqs) > 0: req = self.waiting_reqs.pop(0) - _, _, multimodal_params, _, _, _ = req + _, _, multimodal_params, _ = req for img in multimodal_params.images: if not self.cache_client.root.get_item_embed(img.uuid): cur_batch_size += 1 @@ -116,7 +118,7 @@ async def loop_for_fwd(self): async def loop_for_netio_req(self): while True: recv_req = await self.recv_from_httpserver.recv_pyobj() - if isinstance(recv_req, tuple) and len(recv_req) == 6: + if isinstance(recv_req, tuple) and len(recv_req) == 4: self.waiting_reqs.append(recv_req) elif isinstance(recv_req, AbortReq): abort_req = recv_req @@ -132,21 +134,19 @@ 
def clean_up(self): model_rpc.rpc_server_process.join() return + def start_visual_process(args, router_port, visual_port, client_port, pipe_writer): - try: - visualserver = VisualManager( - args, - router_port, - visual_port, - client_port) + try: + visualserver = VisualManager(args, router_port, visual_port, client_port) asyncio.run(visualserver.wait_to_model_ready()) except Exception as e: import traceback - err_str = '\n'.join(traceback.format_exception(e)) + + err_str = "\n".join(traceback.format_exception(e)) pipe_writer.send(err_str) visualserver.clean_up() raise - pipe_writer.send('init ok') + pipe_writer.send("init ok") loop = asyncio.new_event_loop() asyncio.set_event_loop(loop) loop.create_task(visualserver.loop_for_fwd()) diff --git a/test/model/model_infer.py b/test/model/model_infer.py index dc627571e..290161d7b 100644 --- a/test/model/model_infer.py +++ b/test/model/model_infer.py @@ -2,6 +2,7 @@ from multiprocessing import Queue import multiprocessing + def test_model_inference(world_size, model_dir, model_class, batch_size, input_len, output_len, mode): ans_queue = Queue() workers = [] @@ -10,14 +11,16 @@ def test_model_inference(world_size, model_dir, model_class, batch_size, input_l "tp_rank": rank_id, "world_size": world_size, "weight_dir": model_dir, - "max_total_token_num":batch_size * (input_len + output_len), + "max_total_token_num": batch_size * (input_len + output_len), "load_way": "HF", "mode": mode, "max_req_num": batch_size, - "max_seq_length": (input_len + output_len) + "max_seq_length": (input_len + output_len), } - - proc = multiprocessing.Process(target=tppart_model_infer, args=(model_class, model_kvargs, batch_size, input_len, output_len, ans_queue)) + + proc = multiprocessing.Process( + target=tppart_model_infer, args=(model_class, model_kvargs, batch_size, input_len, output_len, ans_queue) + ) proc.start() workers.append(proc) @@ -27,19 +30,21 @@ def test_model_inference(world_size, model_dir, model_class, batch_size, input_l assert not ans_queue.empty() while not ans_queue.empty(): assert ans_queue.get() - return + return def tppart_model_infer(model_class, model_kvargs, batch_size, input_len, output_len, ans_queue): import torch import torch.distributed as dist + rank_id = model_kvargs["tp_rank"] world_size = model_kvargs["world_size"] - dist.init_process_group('nccl', init_method='tcp://127.0.0.1:28765', rank=rank_id, world_size=world_size) + dist.init_process_group("nccl", init_method="tcp://127.0.0.1:28765", rank=rank_id, world_size=world_size) torch.cuda.set_device(rank_id) import torch.distributed as dist + dist.barrier() torch.cuda.empty_cache() @@ -52,19 +57,23 @@ def tppart_model_infer(model_class, model_kvargs, batch_size, input_len, output_ b_req_idx = model_part.req_manager.alloc(batch_size).int() b_start_loc = torch.zeros(batch_size, dtype=torch.int32, device="cuda") b_seq_len = torch.zeros(batch_size, dtype=torch.int32, device="cuda") + b_ready_cache_len = torch.zeros(batch_size, dtype=torch.int32, device="cuda") for i in range(batch_size): b_start_loc[i] = i * input_len b_seq_len[i] = input_len total_token_num = input_len * batch_size - logics = model_part.forward(batch_size, - total_token_num, - input_len, - test_data, - b_req_idx, - b_start_loc, - b_seq_len, - is_prefill=True) + logics = model_part.forward( + batch_size, + total_token_num, + input_len, + test_data, + b_req_idx, + b_start_loc, + b_seq_len, + b_ready_cache_len=b_ready_cache_len, + is_prefill=True, + ) prob_out = torch.softmax(logics, dim=-1) predict_ids = 
torch.argmax(prob_out, dim=1, keepdim=True) predict_ids = predict_ids.detach().cpu().numpy() @@ -73,25 +82,34 @@ def tppart_model_infer(model_class, model_kvargs, batch_size, input_len, output_ b_start_loc = b_start_loc + torch.arange(0, batch_size, dtype=torch.int32, device="cuda") total_token_num += batch_size b_seq_len += 1 - logics = model_part.forward(batch_size, total_token_num, input_len + i + 1, torch.from_numpy( - predict_ids).cuda().reshape(-1), b_req_idx, b_start_loc, b_seq_len, is_prefill=False) + logics = model_part.forward( + batch_size, + total_token_num, + input_len + i + 1, + torch.from_numpy(predict_ids).cuda().reshape(-1), + b_req_idx, + b_start_loc, + b_seq_len, + is_prefill=False, + ) prob_out = torch.softmax(logics, dim=-1) predict_ids = torch.argmax(prob_out, dim=1, keepdim=True) predict_ids = predict_ids.detach().cpu().numpy() model_part.mem_manager.free_all() model_part.req_manager.free_all() - + if rank_id == 0: print("can use mem size:", model_part.mem_manager.can_use_mem_size) print("can use req size:", model_part.req_manager.can_use_req_size) - + b_req_idx = None b_start_loc = None b_seq_len = None - + dist.barrier() import time + torch.cuda.synchronize() start_time = time.time() @@ -105,8 +123,17 @@ def tppart_model_infer(model_class, model_kvargs, batch_size, input_len, output_ b_seq_len[i] = input_len total_token_num = batch_size * input_len - logics = model_part.forward(batch_size, total_token_num, input_len, test_data, - b_req_idx, b_start_loc, b_seq_len, is_prefill=True) + logics = model_part.forward( + batch_size, + total_token_num, + input_len, + test_data, + b_req_idx, + b_start_loc, + b_seq_len, + b_ready_cache_len=b_ready_cache_len, + is_prefill=True, + ) prob_out = torch.softmax(logics, dim=-1) predict_ids = torch.argmax(prob_out, dim=1, keepdim=True) predict_ids = predict_ids.detach().cpu().numpy() @@ -122,8 +149,16 @@ def tppart_model_infer(model_class, model_kvargs, batch_size, input_len, output_ total_token_num += batch_size b_seq_len += 1 - logics = model_part.forward(batch_size, total_token_num, input_len + i + 1, torch.from_numpy( - predict_ids).cuda().reshape(-1), b_req_idx, b_start_loc, b_seq_len, is_prefill=False) + logics = model_part.forward( + batch_size, + total_token_num, + input_len + i + 1, + torch.from_numpy(predict_ids).cuda().reshape(-1), + b_req_idx, + b_start_loc, + b_seq_len, + is_prefill=False, + ) prob_out = torch.softmax(logics, dim=-1) predict_ids = torch.argmax(prob_out, dim=1, keepdim=True) predict_ids = predict_ids.detach().cpu().numpy() @@ -140,5 +175,3 @@ def tppart_model_infer(model_class, model_kvargs, batch_size, input_len, output_ ans_queue.put(True) return - - diff --git a/test/model/test_bloom.py b/test/model/test_bloom.py index 95c68770c..89020fc66 100644 --- a/test/model/test_bloom.py +++ b/test/model/test_bloom.py @@ -2,21 +2,28 @@ import sys import unittest from model_infer import test_model_inference + sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))) -class TestBloomInfer(unittest.TestCase): +class TestBloomInfer(unittest.TestCase): def test_bloom_infer(self): from lightllm.models.bloom.model import BloomTpPartModel - test_model_inference(world_size=1, - model_dir="/path/bloom-7b", - model_class=BloomTpPartModel, - batch_size=20, - input_len=1024, - output_len=1024, - mode=[]) + + test_model_inference( + world_size=1, + model_dir="/path/bloom-7b", + model_class=BloomTpPartModel, + batch_size=20, + input_len=1024, + output_len=1024, + mode=[], + ) return -if 
__name__ == '__main__': +if __name__ == "__main__": + import torch + + torch.multiprocessing.set_start_method("spawn"), # this code will not be ok for settings to fork to subprocess unittest.main() diff --git a/test/model/test_chatglm2.py b/test/model/test_chatglm2.py index 2594d033d..67c34e96f 100644 --- a/test/model/test_chatglm2.py +++ b/test/model/test_chatglm2.py @@ -2,21 +2,28 @@ import sys import unittest from model_infer import test_model_inference + sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))) -class TestChatglm2Infer(unittest.TestCase): +class TestChatglm2Infer(unittest.TestCase): def test_chatglm2_infer(self): from lightllm.models.chatglm2.model import ChatGlm2TpPartModel - test_model_inference(world_size=1, - model_dir="/nvme/baishihao/chatglm2-6b/", - model_class=ChatGlm2TpPartModel, - batch_size=20, - input_len=1024, - output_len=1024, - mode=[]) + + test_model_inference( + world_size=1, + model_dir="/nvme/xxx/chatglm2-6b/", + model_class=ChatGlm2TpPartModel, + batch_size=20, + input_len=1024, + output_len=1024, + mode=[], + ) return -if __name__ == '__main__': - unittest.main() \ No newline at end of file +if __name__ == "__main__": + import torch + + torch.multiprocessing.set_start_method("spawn") + unittest.main() diff --git a/test/model/test_intern.py b/test/model/test_intern.py index 7a3e4fac1..d36b71604 100644 --- a/test/model/test_intern.py +++ b/test/model/test_intern.py @@ -2,21 +2,28 @@ import sys import unittest from model_infer import test_model_inference + sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))) -class TestInternlmInfer(unittest.TestCase): +class TestInternlmInfer(unittest.TestCase): def test_internlm_infer(self): from lightllm.models.internlm.model import InternlmTpPartModel - test_model_inference(world_size=8, - model_dir="/path/internlm-chat-7b/", - model_class=InternlmTpPartModel, - batch_size=20, - input_len=1024, - output_len=1024, - mode=[]) + + test_model_inference( + world_size=8, + model_dir="/path/internlm-chat-7b/", + model_class=InternlmTpPartModel, + batch_size=20, + input_len=1024, + output_len=1024, + mode=[], + ) return -if __name__ == '__main__': - unittest.main() \ No newline at end of file +if __name__ == "__main__": + import torch + + torch.multiprocessing.set_start_method("spawn") + unittest.main() diff --git a/test/model/test_llama.py b/test/model/test_llama.py index 881c1497c..939ad1db2 100644 --- a/test/model/test_llama.py +++ b/test/model/test_llama.py @@ -2,21 +2,28 @@ import sys import unittest from model_infer import test_model_inference + sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))) -class TestLlamaInfer(unittest.TestCase): +class TestLlamaInfer(unittest.TestCase): def test_llama_infer(self): from lightllm.models.llama.model import LlamaTpPartModel - test_model_inference(world_size=8, - model_dir="/path/to/llama-7b", - model_class=LlamaTpPartModel, - batch_size=20, - input_len=1024, - output_len=1024, - mode=[]) + + test_model_inference( + world_size=8, + model_dir="/path/to/llama-7b", + model_class=LlamaTpPartModel, + batch_size=20, + input_len=1024, + output_len=1024, + mode=[], + ) return -if __name__ == '__main__': - unittest.main() \ No newline at end of file +if __name__ == "__main__": + import torch + + torch.multiprocessing.set_start_method("spawn") + unittest.main() diff --git a/test/model/test_llama2.py b/test/model/test_llama2.py index 9af386677..7eed9d331 100644 --- 
a/test/model/test_llama2.py +++ b/test/model/test_llama2.py @@ -2,21 +2,28 @@ import sys import unittest from model_infer import test_model_inference + sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))) -class TestLlama2Infer(unittest.TestCase): +class TestLlama2Infer(unittest.TestCase): def test_llama2_infer(self): from lightllm.models.llama.model import LlamaTpPartModel - test_model_inference(world_size=8, - model_dir="/path/llama2-7b-chat", - model_class=LlamaTpPartModel, - batch_size=20, - input_len=1024, - output_len=1024, - mode=[]) + + test_model_inference( + world_size=8, + model_dir="/path/llama2-7b-chat", + model_class=LlamaTpPartModel, + batch_size=20, + input_len=1024, + output_len=1024, + mode=[], + ) return -if __name__ == '__main__': - unittest.main() \ No newline at end of file +if __name__ == "__main__": + import torch + + torch.multiprocessing.set_start_method("spawn") + unittest.main() diff --git a/test/model/test_starcoder.py b/test/model/test_starcoder.py index af89c4eff..a1074977d 100644 --- a/test/model/test_starcoder.py +++ b/test/model/test_starcoder.py @@ -2,21 +2,28 @@ import sys import unittest from model_infer import test_model_inference + sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))) -class TestStarcoderInfer(unittest.TestCase): +class TestStarcoderInfer(unittest.TestCase): def test_starcoder_infer(self): from lightllm.models.starcoder.model import StarcoderTpPartModel - test_model_inference(world_size=1, - model_dir="/path/xxxx", - model_class=StarcoderTpPartModel, - batch_size=20, - input_len=1024, - output_len=1024, - mode=[]) + + test_model_inference( + world_size=1, + model_dir="/path/xxxx", + model_class=StarcoderTpPartModel, + batch_size=20, + input_len=1024, + output_len=1024, + mode=[], + ) return -if __name__ == '__main__': - unittest.main() \ No newline at end of file +if __name__ == "__main__": + import torch + + torch.multiprocessing.set_start_method("spawn") + unittest.main() diff --git a/test/model/test_starcoder_quantized.py b/test/model/test_starcoder_quantized.py index 88a6f49af..d42b12174 100644 --- a/test/model/test_starcoder_quantized.py +++ b/test/model/test_starcoder_quantized.py @@ -8,18 +8,23 @@ class TestStarcoderInfer(unittest.TestCase): - def test_starcoder_infer(self): from lightllm.models.starcoder_wquant.model import StarcoderTpPartModelWQuant - test_model_inference(world_size=1, - model_dir="/path/xxxx", - model_class=StarcoderTpPartModelWQuant, - batch_size=2, - input_len=10, - output_len=10, - mode=["triton_int8weight"]) + + test_model_inference( + world_size=1, + model_dir="/path/xxxx", + model_class=StarcoderTpPartModelWQuant, + batch_size=2, + input_len=10, + output_len=10, + mode=["triton_int8weight"], + ) return -if __name__ == '__main__': +if __name__ == "__main__": + import torch + + torch.multiprocessing.set_start_method("spawn") unittest.main()
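
Editor's note: the two sketches below are added for review convenience and are not part of the patch. The first illustrates the eviction contract behind the _free_radix_cache_to_get_enough_token helper earlier in the patch: when the memory manager does not have enough free KV slots for the incoming tokens, the radix cache is asked to evict cached entries (oldest first in this toy) and reports the freed KV indices through a callback, and the caller concatenates them and returns them to the memory manager. The _Toy* classes are illustrative stand-ins with assumed names and signatures, not the real RadixCache / MemoryManager APIs.

import torch


class _ToyMemManager:
    def __init__(self, can_use_mem_size: int):
        self.can_use_mem_size = can_use_mem_size

    def free(self, mem_index: torch.Tensor):
        # returning indices to the pool makes them allocatable again
        self.can_use_mem_size += mem_index.numel()


class _ToyRadixCache:
    def __init__(self, cached_entries):
        # least recently used entries first
        self._entries = [torch.tensor(e, dtype=torch.int64) for e in cached_entries]

    def evict(self, need_token_num: int, release_callback):
        freed = 0
        while freed < need_token_num and self._entries:
            idx = self._entries.pop(0)   # drop the least recently used entry
            release_callback(idx)        # hand its KV-cache indices back to the caller
            freed += idx.numel()


mem_manager = _ToyMemManager(can_use_mem_size=1)
radix_cache = _ToyRadixCache([[0, 1, 2], [3, 4]])

need_token_num = 4
if need_token_num > mem_manager.can_use_mem_size:
    release_mems = []
    radix_cache.evict(need_token_num - mem_manager.can_use_mem_size, release_mems.append)
    mem_manager.free(torch.concat(release_mems))

print(mem_manager.can_use_mem_size)  # 4: the three evicted KV slots plus the one that was already free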
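
The second sketch illustrates the bookkeeping that the new b_ready_cache_len / prefill_b_split_ready_cache_len arguments carry: the number of leading prompt tokens whose KV entries are already cached, so prefill only has to compute the remaining suffix. The dict-based ToyPrefixCache below is a deliberately simplified stand-in for the radix tree added by this patch; its methods are assumptions for illustration, not the real interface.

from typing import Dict, List, Tuple


class ToyPrefixCache:
    """Toy stand-in for the radix cache: maps token-id prefixes to a cached KV length."""

    def __init__(self) -> None:
        self._store: Dict[Tuple[int, ...], int] = {}

    def insert(self, token_ids: List[int]) -> None:
        # remember every prefix of a finished prompt as reusable KV cache
        for i in range(1, len(token_ids) + 1):
            self._store[tuple(token_ids[:i])] = i

    def match_prefix(self, token_ids: List[int]) -> int:
        # length of the longest cached prefix; only the remainder needs prefill
        for i in range(len(token_ids), 0, -1):
            if tuple(token_ids[:i]) in self._store:
                return i
        return 0


cache = ToyPrefixCache()
cache.insert([1, 2, 3, 4, 5])                  # first request populates the cache
prompt = [1, 2, 3, 9, 9]                       # second request shares a 3-token prefix
ready_cache_len = cache.match_prefix(prompt)   # 3, playing the role of cur_kv_len / b_ready_cache_len
print("tokens to prefill:", len(prompt) - ready_cache_len)  # only the 2 un-cached tokens are computed

In the full system the matched prefix also has to be protected from eviction while the request is running, which is why the eviction path in the first sketch only frees entries the cache is willing to release.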