From 15a050abac1f0cd186dafada4c131a504c6a9d64 Mon Sep 17 00:00:00 2001 From: shihaobai <42648726+shihaobai@users.noreply.github.com> Date: Wed, 10 Apr 2024 22:28:27 +0800 Subject: [PATCH] Add bf16 inference for llm model (#387) Co-authored-by: baishihao Co-authored-by: hiworldwzj <30762946+hiworldwzj@users.noreply.github.com> --- lightllm/common/basemodel/basemodel.py | 28 ++++++++++++----- .../basemodel/layer_weights/hf_load_utils.py | 3 +- .../pre_and_post_layer_weight.py | 4 +-- .../bloom/layer_infer/post_layer_infer.py | 4 +-- .../bloom/layer_weights/hf_load_utils.py | 3 +- lightllm/models/bloom/model.py | 6 ++-- .../token_attention_nopad_reduceV.py | 2 +- lightllm/models/chatglm2/model.py | 4 +-- lightllm/models/gemma_2b/model.py | 6 ++-- .../gemma_2b/triton_kernel/gelu_and_mul.py | 2 +- .../llama/layer_infer/post_layer_infer.py | 6 ++-- lightllm/models/llama/model.py | 30 +++++++++---------- .../models/llama/triton_kernel/rmsnorm.py | 2 +- .../llama/triton_kernel/silu_and_mul.py | 2 +- .../splitfuse_context_flashattention_nopad.py | 4 +-- lightllm/models/mistral/model.py | 6 ++-- .../token_attention_nopad_reduceV.py | 2 +- lightllm/models/mixtral/model.py | 6 ++-- lightllm/models/qwen/model.py | 4 +-- lightllm/models/qwen2/model.py | 6 ++-- lightllm/models/starcoder/model.py | 2 +- lightllm/models/starcoder2/model.py | 6 ++-- lightllm/server/api_server.py | 5 ++++ lightllm/server/router/manager.py | 1 + .../server/router/model_infer/model_rpc.py | 2 +- lightllm/server/tokenizer.py | 1 - 26 files changed, 83 insertions(+), 64 deletions(-) diff --git a/lightllm/common/basemodel/basemodel.py b/lightllm/common/basemodel/basemodel.py index eda614d8f..22bf2493b 100755 --- a/lightllm/common/basemodel/basemodel.py +++ b/lightllm/common/basemodel/basemodel.py @@ -45,7 +45,9 @@ def __init__(self, kvargs): self.max_seq_length = kvargs.get("max_seq_length", 1024 * 5) self.return_all_prompt_logprobs = kvargs.get("return_all_prompt_logprobs", False) self.use_dynamic_prompt_cache = kvargs.get("use_dynamic_prompt_cache", False) - + self.data_type = kvargs.get("data_type", "float16") + + self._init_datatype() self._init_config() self._verify_must() self._verify_params() @@ -80,16 +82,16 @@ def _verify_params(self): def _init_weights(self): self.pre_post_weight = self.pre_and_post_weight_class( - self.tp_rank_, self.world_size_, torch.float16, network_config=self.config, mode=self.mode + self.tp_rank_, self.world_size_, self.data_type, network_config=self.config, mode=self.mode ) self.trans_layers_weight = [ self.transformer_weight_class( - i, self.tp_rank_, self.world_size_, torch.float16, network_config=self.config, mode=self.mode + i, self.tp_rank_, self.world_size_, self.data_type, network_config=self.config, mode=self.mode ) for i in range(self.config["n_layer"]) ] load_hf_weights( - "fp16", + self.data_type, weight_dir=self.weight_dir_, pre_post_layer=self.pre_post_weight, transformer_layer_list=self.trans_layers_weight, @@ -103,7 +105,7 @@ def _init_mem_manager(self): assert self.config["num_attention_heads"] % self.world_size_ == 0 self.mem_manager = MemoryManager( self.max_total_token_num, - dtype=torch.float16, + dtype=self.data_type, head_num=self.config["num_attention_heads"] // self.world_size_, head_dim=self.config["n_embed"] // self.config["num_attention_heads"], layer_num=self.config["n_layer"], @@ -137,6 +139,16 @@ def _init_some_value(self): self.vocab_size = self.config["vocab_size"] return + def _init_datatype(self): + if self.data_type in ["fp16", "float16"]: + self.data_type = 
torch.float16 + elif self.data_type in ["bf16", "bfloat16"]: + self.data_type = torch.bfloat16 + elif self.data_type in ["fp32", "float32"]: + self.data_type =torch.float32 + else: + raise ValueError(f"Unsupport datatype {self.data_type}!") + def _init_custom(self): pass @@ -223,7 +235,7 @@ def _prefill( infer_state.mem_index = alloc_mem infer_state.kv_buffer = torch.empty( (input_ids.shape[0], self.tp_k_head_num_ + self.tp_v_head_num_, self.head_dim_), - dtype=torch.float16, + dtype=self.data_type, device="cuda", ) @@ -279,7 +291,7 @@ def _decode( infer_state.mem_index = alloc_mem infer_state.kv_buffer = torch.empty( (batch_size, self.tp_k_head_num_ + self.tp_v_head_num_, self.head_dim_), - dtype=torch.float16, + dtype=self.data_type, device="cuda", ) copy_kv_index_to_req(self.req_manager.req_to_token_indexs, b_req_idx, b_seq_len, infer_state.mem_index) @@ -341,7 +353,7 @@ def splitfuse_forward( infer_state.mem_index = alloc_mem infer_state.kv_buffer = torch.empty( (alloc_size, self.tp_k_head_num_ + self.tp_v_head_num_, self.head_dim_), - dtype=torch.float16, + dtype=self.data_type, device="cuda", ) diff --git a/lightllm/common/basemodel/layer_weights/hf_load_utils.py b/lightllm/common/basemodel/layer_weights/hf_load_utils.py index 8a46c9224..958529db4 100755 --- a/lightllm/common/basemodel/layer_weights/hf_load_utils.py +++ b/lightllm/common/basemodel/layer_weights/hf_load_utils.py @@ -27,7 +27,8 @@ def load_func(file_, use_safetensors=False, pre_post_layer=None, transformer_lay def load_hf_weights(data_type, weight_dir, pre_post_layer=None, transformer_layer_list=None, weight_dict=None): - data_type = torch.float16 if data_type == 'fp16' else torch.float32 + if isinstance(data_type, str): + data_type = torch.float16 if data_type == 'fp16' else torch.float32 if pre_post_layer is not None: assert pre_post_layer.data_type_ == data_type, "type is not right" if transformer_layer_list is not None: diff --git a/lightllm/models/baichuan2_7b/layer_weights/pre_and_post_layer_weight.py b/lightllm/models/baichuan2_7b/layer_weights/pre_and_post_layer_weight.py index 08778c5e8..641c993d9 100644 --- a/lightllm/models/baichuan2_7b/layer_weights/pre_and_post_layer_weight.py +++ b/lightllm/models/baichuan2_7b/layer_weights/pre_and_post_layer_weight.py @@ -21,8 +21,8 @@ def load_hf_weights(self, weights): (self.tp_rank_ + 1), :]) if 'lm_head.weight' in weights: # print(weights['lm_head.weight'].shape) - self.lm_head_weight_ = nn.functional.normalize(weights['lm_head.weight'].to( - torch.float16).cuda())[split_vob_size * self.tp_rank_:split_vob_size * (self.tp_rank_ + 1), :] + self.lm_head_weight_ = self._cuda( + nn.functional.normalize(weights['lm_head.weight'])[split_vob_size * self.tp_rank_:split_vob_size * (self.tp_rank_ + 1), :]) if 'model.norm.weight' in weights: self.final_norm_weight_ = self._cuda(weights['model.norm.weight']) diff --git a/lightllm/models/bloom/layer_infer/post_layer_infer.py b/lightllm/models/bloom/layer_infer/post_layer_infer.py index bbff0bff5..cdb637fa2 100644 --- a/lightllm/models/bloom/layer_infer/post_layer_infer.py +++ b/lightllm/models/bloom/layer_infer/post_layer_infer.py @@ -29,7 +29,7 @@ def soft_max(self, data): def token_forward(self, input_embdings, infer_state: InferStateInfo, layer_weight: BloomPreAndPostLayerWeight, return_logics=False): batch_size = infer_state.batch_size - last_input = torch.empty((batch_size, self.embed_dim_), device=input_embdings.device, dtype=torch.float16) + last_input = torch.empty((batch_size, self.embed_dim_), 
device=input_embdings.device, dtype=input_embdings.dtype) if infer_state.is_prefill: last_index = torch.cumsum(infer_state.b_seq_len - infer_state.b_ready_cache_len, dim=0, dtype=torch.long) - 1 last_input[:, :] = input_embdings[last_index, :] @@ -44,7 +44,7 @@ def token_forward(self, input_embdings, infer_state: InferStateInfo, layer_weigh if self.world_size_ == 1: gather_data = logic_batch else: - gather_data = torch.empty((self.vocab_size_, batch_size), device=logic_batch.device, dtype=torch.float16) + gather_data = torch.empty((self.vocab_size_, batch_size), device=logic_batch.device, dtype=input_embdings.dtype) split_size = self.vocab_size_ // self.world_size_ dist.all_gather([gather_data[i * split_size: (i + 1) * split_size, :] for i in range(self.world_size_)], logic_batch, group=None, async_op=False) diff --git a/lightllm/models/bloom/layer_weights/hf_load_utils.py b/lightllm/models/bloom/layer_weights/hf_load_utils.py index b29f9d079..01c4c5862 100755 --- a/lightllm/models/bloom/layer_weights/hf_load_utils.py +++ b/lightllm/models/bloom/layer_weights/hf_load_utils.py @@ -5,7 +5,8 @@ def load_hf_weights(data_type, weight_dir, pre_post_layer=None, transformer_layer_list=None, weight_dict=None): - data_type = torch.float16 if data_type == 'fp16' else torch.float32 + if isinstance(data_type, str): + data_type = torch.float16 if data_type == 'fp16' else torch.float32 if pre_post_layer is not None: assert pre_post_layer.data_type_ == data_type, "type is not right" if transformer_layer_list is not None: diff --git a/lightllm/models/bloom/model.py b/lightllm/models/bloom/model.py index c1aa83fb5..2730d45ab 100644 --- a/lightllm/models/bloom/model.py +++ b/lightllm/models/bloom/model.py @@ -40,13 +40,13 @@ def _reset_num_key_value_heads(self): return def _init_weights(self): - self.pre_post_weight = self.pre_and_post_weight_class(self.tp_rank_, self.world_size_, torch.float16, network_config=self.config, mode=self.mode) + self.pre_post_weight = self.pre_and_post_weight_class(self.tp_rank_, self.world_size_, self.data_type, network_config=self.config, mode=self.mode) self.trans_layers_weight = [ - self.transformer_weight_class(i, self.tp_rank_, self.world_size_, torch.float16, network_config=self.config, mode=self.mode) + self.transformer_weight_class(i, self.tp_rank_, self.world_size_, self.data_type, network_config=self.config, mode=self.mode) for i in range(self.config["n_layer"]) ] load_hf_weights( - "fp16", + self.data_type, weight_dir=self.weight_dir_, pre_post_layer=self.pre_post_weight, transformer_layer_list=self.trans_layers_weight, diff --git a/lightllm/models/bloom/triton_kernel/token_attention_nopad_reduceV.py b/lightllm/models/bloom/triton_kernel/token_attention_nopad_reduceV.py index 1f437c1e0..52621be76 100644 --- a/lightllm/models/bloom/triton_kernel/token_attention_nopad_reduceV.py +++ b/lightllm/models/bloom/triton_kernel/token_attention_nopad_reduceV.py @@ -37,7 +37,7 @@ def _fwd_kernel_token_att2( v_value = tl.load(V + v_offs + v_loc[:, None] * stride_vbs, mask=(start_n + offs_n[:, None]) < cur_batch_seq_len, other=0.0) acc += tl.sum(p_value[:, None] * v_value, 0) - acc = acc.to(tl.float16) + acc = acc.to(Out.dtype.element_ty) off_o = cur_batch * stride_obs + cur_head * stride_oh + offs_d * stride_od out_ptrs = Out + off_o tl.store(out_ptrs, acc) diff --git a/lightllm/models/chatglm2/model.py b/lightllm/models/chatglm2/model.py index 6675c0842..de9d507ce 100644 --- a/lightllm/models/chatglm2/model.py +++ b/lightllm/models/chatglm2/model.py @@ -70,6 +70,6 @@ def 
_init_to_get_rotary(self, base=10000): t = torch.arange(max_seq_len + 1024 * 64, device="cpu", dtype=torch.float32) / rope_scaling_factor freqs = torch.outer(t, inv_freq) - self._cos_cached = torch.cos(freqs).to(torch.float16).cuda() - self._sin_cached = torch.sin(freqs).to(torch.float16).cuda() + self._cos_cached = torch.cos(freqs).to(self.data_type).cuda() + self._sin_cached = torch.sin(freqs).to(self.data_type).cuda() return diff --git a/lightllm/models/gemma_2b/model.py b/lightllm/models/gemma_2b/model.py index 329a37098..4d1a31dd9 100644 --- a/lightllm/models/gemma_2b/model.py +++ b/lightllm/models/gemma_2b/model.py @@ -45,7 +45,7 @@ def _init_custom(self): def _init_mem_manager(self): self.mem_manager = MemoryManager(self.max_total_token_num, - dtype=torch.float16, + dtype=self.data_type, head_num=self.config["num_key_value_heads"], # [SYM] always == 1 head_dim=self.config["hidden_size"] // self.config["num_attention_heads"], layer_num=self.config["num_hidden_layers"]) @@ -73,7 +73,7 @@ def _init_to_get_rotary(self, default_base=10000): t = torch.arange(max_seq_len + 1024 * 64, device="cpu", dtype=torch.float32) / rope_scaling_factor freqs = torch.outer(t, inv_freq) - self._cos_cached = torch.cos(freqs).to(torch.float16).cuda() - self._sin_cached = torch.sin(freqs).to(torch.float16).cuda() + self._cos_cached = torch.cos(freqs).to(self.data_type).cuda() + self._sin_cached = torch.sin(freqs).to(self.data_type).cuda() return \ No newline at end of file diff --git a/lightllm/models/gemma_2b/triton_kernel/gelu_and_mul.py b/lightllm/models/gemma_2b/triton_kernel/gelu_and_mul.py index 05a48fc88..4b08165f2 100644 --- a/lightllm/models/gemma_2b/triton_kernel/gelu_and_mul.py +++ b/lightllm/models/gemma_2b/triton_kernel/gelu_and_mul.py @@ -56,7 +56,7 @@ def _gelu_and_mul_kernel( ).to(tl.float32) gate = gelu(gate) - gate = gate.to(tl.float16) + gate = gate.to(input_ptr.dtype.element_ty) tl.store( input_ptr + res_offsets, diff --git a/lightllm/models/llama/layer_infer/post_layer_infer.py b/lightllm/models/llama/layer_infer/post_layer_infer.py index 031568739..f12ac8264 100644 --- a/lightllm/models/llama/layer_infer/post_layer_infer.py +++ b/lightllm/models/llama/layer_infer/post_layer_infer.py @@ -30,7 +30,7 @@ def _slice_get_last_input(self, input_embdings, infer_state: LlamaInferStateInfo if infer_state.is_splitfuse: # for SplitFuse batch_size = infer_state.batch_size - last_input = torch.empty((batch_size, self.embed_dim_), device=input_embdings.device, dtype=torch.float16) + last_input = torch.empty((batch_size, self.embed_dim_), device=input_embdings.device, dtype=input_embdings.dtype) tmp_ = torch.cat( [ torch.ones(infer_state.decode_req_num, dtype=torch.int32, device="cuda"), @@ -44,7 +44,7 @@ def _slice_get_last_input(self, input_embdings, infer_state: LlamaInferStateInfo if not infer_state.is_splitfuse and infer_state.is_prefill and not infer_state.return_all_prompt_logprobs: batch_size = infer_state.batch_size - last_input = torch.empty((batch_size, self.embed_dim_), device=input_embdings.device, dtype=torch.float16) + last_input = torch.empty((batch_size, self.embed_dim_), device=input_embdings.device, dtype=input_embdings.dtype) last_index = ( torch.cumsum(infer_state.b_seq_len - infer_state.b_ready_cache_len, dim=0, dtype=torch.long) - 1 ) @@ -81,7 +81,7 @@ def token_forward( if self.world_size_ == 1: gather_data = logic_batch else: - gather_data = torch.empty((self.vocab_size_, token_num), device=logic_batch.device, dtype=torch.float16) + gather_data = torch.empty((self.vocab_size_, 
token_num), device=logic_batch.device, dtype=input_embdings.dtype) split_indexes = np.linspace(0, self.vocab_size_, self.world_size_ + 1, dtype=np.int64) dist.all_gather( [gather_data[split_indexes[i] : split_indexes[i + 1], :] for i in range(self.world_size_)], diff --git a/lightllm/models/llama/model.py b/lightllm/models/llama/model.py index 8b5a4d549..7ec4ab3bb 100644 --- a/lightllm/models/llama/model.py +++ b/lightllm/models/llama/model.py @@ -55,7 +55,7 @@ def _verify_params(self): def _init_mem_manager(self): self.mem_manager = select_mem_manager_class(self.mode)(self.max_total_token_num, - dtype=torch.float16, + dtype=self.data_type, head_num=self.config["num_key_value_heads"] // self.world_size_, head_dim=self.config["hidden_size"] // self.config["num_attention_heads"], layer_num=self.config["num_hidden_layers"]) @@ -74,21 +74,21 @@ def _init_custom(self): return def _init_weights(self): - self.pre_post_weight = self.pre_and_post_weight_class(self.tp_rank_, self.world_size_, torch.float16, network_config=self.config, mode=self.mode) + self.pre_post_weight = self.pre_and_post_weight_class(self.tp_rank_, self.world_size_, self.data_type, network_config=self.config, mode=self.mode) self.trans_layers_weight = [ - self.transformer_weight_class(i, self.tp_rank_, self.world_size_, torch.float16, network_config=self.config, mode=self.mode) + self.transformer_weight_class(i, self.tp_rank_, self.world_size_, self.data_type, network_config=self.config, mode=self.mode) for i in range(self.config["n_layer"]) ] if self.load_way == 'HF': load_hf_weights( - "fp16", + self.data_type, weight_dir=self.weight_dir_, pre_post_layer=self.pre_post_weight, transformer_layer_list=self.trans_layers_weight, weight_dict=self.weight_dict) else: load_ds_weights( - "fp16", + self.data_type, weight_dir=self.weight_dir_, pre_post_layer=self.pre_post_weight, transformer_layer_list=self.trans_layers_weight, @@ -132,8 +132,8 @@ def _init_to_get_rotary(self, default_base=10000): t = torch.arange(max_seq_len + 1024 * 128, device="cpu", dtype=torch.float32) / rope_scaling_factor freqs = torch.outer(t, inv_freq) - self._cos_cached = torch.cos(freqs).to(torch.float16).cuda() - self._sin_cached = torch.sin(freqs).to(torch.float16).cuda() + self._cos_cached = torch.cos(freqs).to(self.data_type).cuda() + self._sin_cached = torch.sin(freqs).to(self.data_type).cuda() return def _init_to_get_dynamic_ntk_rotary(self): @@ -145,22 +145,22 @@ def _init_to_get_dynamic_ntk_rotary(self): else: scaling_factor = self.config.get("rope_scaling", {}).get("factor", 1.0) max_seq_len = max(self.max_seq_length, max_position_embeddings) - self._cos_cached = torch.zeros((max_seq_len, partial_head_dim // 2), dtype=torch.float16, device="cuda") - self._sin_cached = torch.zeros((max_seq_len, partial_head_dim // 2), dtype=torch.float16, device="cuda") + self._cos_cached = torch.zeros((max_seq_len, partial_head_dim // 2), dtype=self.data_type, device="cuda") + self._sin_cached = torch.zeros((max_seq_len, partial_head_dim // 2), dtype=self.data_type, device="cuda") inv_freq = 1.0 / (base ** (torch.arange(0, partial_head_dim, 2, device="cpu", dtype=torch.float32) / partial_head_dim)) t = torch.arange(max_position_embeddings, device="cpu", dtype=torch.float32) freqs = torch.outer(t, inv_freq) - self._cos_cached[0:max_position_embeddings, :] = torch.cos(freqs).to(torch.float16).cuda() - self._sin_cached[0:max_position_embeddings, :] = torch.sin(freqs).to(torch.float16).cuda() + self._cos_cached[0:max_position_embeddings, :] = 
torch.cos(freqs).to(self.data_type).cuda() + self._sin_cached[0:max_position_embeddings, :] = torch.sin(freqs).to(self.data_type).cuda() for seq_loc_index in range(max_position_embeddings, max_seq_len, 1): new_base = base * ((scaling_factor * (seq_loc_index + 1) / max_position_embeddings) -(scaling_factor - 1)) ** (partial_head_dim / (partial_head_dim - 2)) inv_freq = 1.0 / (new_base ** (torch.arange(0, partial_head_dim, 2, device="cpu", dtype=torch.float32) / partial_head_dim)) t = torch.tensor([seq_loc_index,], device="cpu", dtype=torch.float32) freqs = torch.outer(t, inv_freq) - self._cos_cached[seq_loc_index:seq_loc_index + 1, :] = torch.cos(freqs).to(torch.float16).cuda() - self._sin_cached[seq_loc_index:seq_loc_index + 1, :] = torch.sin(freqs).to(torch.float16).cuda() + self._cos_cached[seq_loc_index:seq_loc_index + 1, :] = torch.cos(freqs).to(self.data_type).cuda() + self._sin_cached[seq_loc_index:seq_loc_index + 1, :] = torch.sin(freqs).to(self.data_type).cuda() return def _init_to_get_yarn_rotary(self): @@ -194,8 +194,8 @@ def _init_to_get_yarn_rotary(self): freqs = torch.einsum("i,j->ij", t, inv_freq) # Different from paper, but it uses a different permutation in order to obtain the same calculation emb = torch.cat((freqs, freqs), dim=-1) - self._cos_cached = emb.cos().to(torch.float16).cuda() * mscale - self._sin_cached = emb.sin().to(torch.float16).cuda() * mscale + self._cos_cached = emb.cos().to(self.data_type).cuda() * mscale + self._sin_cached = emb.sin().to(self.data_type).cuda() * mscale return diff --git a/lightllm/models/llama/triton_kernel/rmsnorm.py b/lightllm/models/llama/triton_kernel/rmsnorm.py index d42a329b3..80b88ae34 100644 --- a/lightllm/models/llama/triton_kernel/rmsnorm.py +++ b/lightllm/models/llama/triton_kernel/rmsnorm.py @@ -35,7 +35,7 @@ def _rms_norm_fwd_fused( x_hat = x * rstd y = x_hat * w # Write output - tl.store(Y + cols, y.to(tl.float16), mask=mask) + tl.store(Y + cols, y.to(Y.dtype.element_ty), mask=mask) def rmsnorm_forward(x, weight, eps): diff --git a/lightllm/models/llama/triton_kernel/silu_and_mul.py b/lightllm/models/llama/triton_kernel/silu_and_mul.py index a7aaa0fc3..743532d33 100644 --- a/lightllm/models/llama/triton_kernel/silu_and_mul.py +++ b/lightllm/models/llama/triton_kernel/silu_and_mul.py @@ -44,7 +44,7 @@ def _silu_and_mul_kernel( ).to(tl.float32) gate = gate / (1 + tl.exp(-gate)) - gate = gate.to(tl.float16) + gate = gate.to(input_ptr.dtype.element_ty) tl.store( input_ptr + res_offsets, diff --git a/lightllm/models/llama/triton_kernel/splitfuse_context_flashattention_nopad.py b/lightllm/models/llama/triton_kernel/splitfuse_context_flashattention_nopad.py index fc5b228b6..77d9d6b72 100644 --- a/lightllm/models/llama/triton_kernel/splitfuse_context_flashattention_nopad.py +++ b/lightllm/models/llama/triton_kernel/splitfuse_context_flashattention_nopad.py @@ -301,8 +301,8 @@ def _fwd_kernel_int8( vs_ptrs + kv_loc[:, None] * stride_vsbs, mask=(start_n + offs_n)[:, None] < cur_batch_seq_len, other=0.0 ) - p = p.to(tl.float16) - acc += tl.dot(p, v.to(tl.float16) * v_scale) + p = p.to(V.dtype.element_ty) + acc += tl.dot(p, v.to(V.dtype.element_ty) * v_scale) # update m_i and l_i l_i = l_i_new diff --git a/lightllm/models/mistral/model.py b/lightllm/models/mistral/model.py index 404efedea..41c103ccd 100644 --- a/lightllm/models/mistral/model.py +++ b/lightllm/models/mistral/model.py @@ -51,7 +51,7 @@ def _init_custom(self): def _init_mem_manager(self): self.mem_manager = MemoryManager(self.max_total_token_num, # [SYM] should be 
sliding window? - dtype=torch.float16, + dtype=self.data_type, head_num=self.config["num_key_value_heads"] // self.world_size_, head_dim=self.config["hidden_size"] // self.config["num_attention_heads"], layer_num=self.config["num_hidden_layers"], @@ -79,7 +79,7 @@ def _init_to_get_rotary(self, default_base=10000): t = torch.arange(max_seq_len + 1024 * 64, device="cpu", dtype=torch.float32) / rope_scaling_factor freqs = torch.outer(t, inv_freq) - self._cos_cached = torch.cos(freqs).to(torch.float16).cuda() - self._sin_cached = torch.sin(freqs).to(torch.float16).cuda() + self._cos_cached = torch.cos(freqs).to(self.data_type).cuda() + self._sin_cached = torch.sin(freqs).to(self.data_type).cuda() return \ No newline at end of file diff --git a/lightllm/models/mistral/triton_kernel/token_attention_nopad_reduceV.py b/lightllm/models/mistral/triton_kernel/token_attention_nopad_reduceV.py index 6031085a4..8019820b2 100644 --- a/lightllm/models/mistral/triton_kernel/token_attention_nopad_reduceV.py +++ b/lightllm/models/mistral/triton_kernel/token_attention_nopad_reduceV.py @@ -41,7 +41,7 @@ def _fwd_kernel_token_att2( v_value = tl.load(V + v_offs + v_loc[:, None] * stride_vbs, mask=(start_n + offs_n[:, None] + cur_batch_start_index) < cur_batch_seq_len, other=0.0) # [1, D] + [64, 1] = [64, D] acc += tl.sum(p_value[:, None] * v_value, 0) # [64, 1] * [64, D] = [64, D] -> [D] - acc = acc.to(tl.float16) + acc = acc.to(Out.dtype.element_ty) off_o = cur_batch * stride_obs + cur_head * stride_oh + offs_d * stride_od out_ptrs = Out + off_o tl.store(out_ptrs, acc) diff --git a/lightllm/models/mixtral/model.py b/lightllm/models/mixtral/model.py index 03970bc73..d7b874ee0 100644 --- a/lightllm/models/mixtral/model.py +++ b/lightllm/models/mixtral/model.py @@ -53,7 +53,7 @@ def _init_custom(self): def _init_mem_manager(self): self.mem_manager = MemoryManager(self.max_total_token_num, - dtype=torch.float16, + dtype=self.data_type, head_num=self.config["num_key_value_heads"] // self.world_size_, head_dim=self.config["hidden_size"] // self.config["num_attention_heads"], layer_num=self.config["num_hidden_layers"], @@ -82,7 +82,7 @@ def _init_to_get_rotary(self, default_base=10000): t = torch.arange(max_seq_len + 1024 * 64, device="cpu", dtype=torch.float32) / rope_scaling_factor freqs = torch.outer(t, inv_freq) - self._cos_cached = torch.cos(freqs).to(torch.float16).cuda() - self._sin_cached = torch.sin(freqs).to(torch.float16).cuda() + self._cos_cached = torch.cos(freqs).to(self.data_type).cuda() + self._sin_cached = torch.sin(freqs).to(self.data_type).cuda() return \ No newline at end of file diff --git a/lightllm/models/qwen/model.py b/lightllm/models/qwen/model.py index 975357ed4..44c0a0e30 100644 --- a/lightllm/models/qwen/model.py +++ b/lightllm/models/qwen/model.py @@ -78,8 +78,8 @@ def _init_qwen_dynamic_ntk(self): t = torch.arange(total_seq_len_supported + 128 * 1024, device="cpu", dtype=torch.float32) freqs = torch.outer(t, inv_freq) - self._cos_cached.append(torch.cos(freqs).to(torch.float16).cuda()) - self._sin_cached.append(torch.sin(freqs).to(torch.float16).cuda()) + self._cos_cached.append(torch.cos(freqs).to(self.data_type).cuda()) + self._sin_cached.append(torch.sin(freqs).to(self.data_type).cuda()) self._cos_cached = torch.stack(self._cos_cached, dim=0).contiguous() self._sin_cached = torch.stack(self._sin_cached, dim=0).contiguous() diff --git a/lightllm/models/qwen2/model.py b/lightllm/models/qwen2/model.py index df983f875..5824a7e14 100644 --- a/lightllm/models/qwen2/model.py +++ 
b/lightllm/models/qwen2/model.py @@ -52,7 +52,7 @@ def _init_custom(self): def _init_mem_manager(self): self.mem_manager = MemoryManager( self.max_total_token_num, # [SYM] should be sliding window? - dtype=torch.float16, + dtype=self.data_type, head_num=self.config["num_key_value_heads"] // self.world_size_, head_dim=self.config["hidden_size"] // self.config["num_attention_heads"], layer_num=self.config["num_hidden_layers"], @@ -82,6 +82,6 @@ def _init_to_get_rotary(self, default_base=10000): t = torch.arange(max_seq_len + 1024 * 64, device="cpu", dtype=torch.float32) / rope_scaling_factor freqs = torch.outer(t, inv_freq) - self._cos_cached = torch.cos(freqs).to(torch.float16).cuda() - self._sin_cached = torch.sin(freqs).to(torch.float16).cuda() + self._cos_cached = torch.cos(freqs).to(self.data_type).cuda() + self._sin_cached = torch.sin(freqs).to(self.data_type).cuda() return diff --git a/lightllm/models/starcoder/model.py b/lightllm/models/starcoder/model.py index 12a98a4db..936bdf2a4 100644 --- a/lightllm/models/starcoder/model.py +++ b/lightllm/models/starcoder/model.py @@ -43,7 +43,7 @@ def _verify_params(self): def _init_mem_manager(self): self.mem_manager = select_mem_manager_class(self.mode)(self.max_total_token_num, - dtype=torch.float16, + dtype=self.data_type, head_num=self.config["num_key_value_heads"], head_dim=self.config["hidden_size"] // self.config["num_attention_heads"], layer_num=self.config["num_hidden_layers"]) diff --git a/lightllm/models/starcoder2/model.py b/lightllm/models/starcoder2/model.py index 74e085e5e..79c17f33b 100644 --- a/lightllm/models/starcoder2/model.py +++ b/lightllm/models/starcoder2/model.py @@ -48,7 +48,7 @@ def _init_custom(self): def _init_mem_manager(self): self.mem_manager = select_mem_manager_class(self.mode)( self.max_total_token_num, - dtype=torch.float16, + dtype=self.data_type, head_num=self.config["num_key_value_heads"] // self.world_size_, head_dim=self.config["hidden_size"] // self.config["num_attention_heads"], layer_num=self.config["num_hidden_layers"], @@ -81,6 +81,6 @@ def _init_to_get_rotary(self, default_base=10000): t = torch.arange(max_seq_len + 1024 * 64, device="cpu", dtype=torch.float32) / rope_scaling_factor freqs = torch.outer(t, inv_freq) - self._cos_cached = torch.cos(freqs).to(torch.float16).cuda() - self._sin_cached = torch.sin(freqs).to(torch.float16).cuda() + self._cos_cached = torch.cos(freqs).to(self.data_type).cuda() + self._sin_cached = torch.sin(freqs).to(self.data_type).cuda() return diff --git a/lightllm/server/api_server.py b/lightllm/server/api_server.py index 38118bdb0..94ef42567 100755 --- a/lightllm/server/api_server.py +++ b/lightllm/server/api_server.py @@ -404,6 +404,11 @@ def main(): parser.add_argument( "--cache_reserved_ratio", type=float, default=0.5, help="cache server reserved capacity ratio after clear" ) + parser.add_argument( + "--data_type", type=str, + choices=["fp16", "float16", "bf16", "bfloat16", "fp32", "float32"], + default="float16", help="the data type of the model weight" + ) parser.add_argument("--return_all_prompt_logprobs", action="store_true", help="return all prompt tokens logprobs") parser.add_argument( "--long_truncation_mode", diff --git a/lightllm/server/router/manager.py b/lightllm/server/router/manager.py index 8ed402a89..125b83802 100644 --- a/lightllm/server/router/manager.py +++ b/lightllm/server/router/manager.py @@ -75,6 +75,7 @@ async def wait_to_model_ready(self): "splitfuse_block_size": self.splitfuse_block_size, "return_all_prompt_logprobs": 
self.args.return_all_prompt_logprobs, "use_dynamic_prompt_cache": self.args.use_dynamic_prompt_cache, + "data_type": self.args.data_type, "eos_id": self.eos_id, } init_model_ret.append(self.model_rpcs[rank_id].init_model(kvargs)) diff --git a/lightllm/server/router/model_infer/model_rpc.py b/lightllm/server/router/model_infer/model_rpc.py index 02cc7fd82..42622b7a1 100644 --- a/lightllm/server/router/model_infer/model_rpc.py +++ b/lightllm/server/router/model_infer/model_rpc.py @@ -99,8 +99,8 @@ def exposed_init_model(self, kvargs): "max_seq_length": kvargs.get("max_seq_length", 1024 * 5), "return_all_prompt_logprobs": self.return_all_prompt_logprobs, "use_dynamic_prompt_cache": self.use_dynamic_prompt_cache, + "data_type": kvargs.get("data_type", "float16"), } - is_weight_only_quant = any("w6a16" in mode_ or "w8a16" in mode_ or "w4a16" in mode_ for mode_ in self.mode) try: diff --git a/lightllm/server/tokenizer.py b/lightllm/server/tokenizer.py index 274d568bc..709f60f01 100644 --- a/lightllm/server/tokenizer.py +++ b/lightllm/server/tokenizer.py @@ -20,7 +20,6 @@ from transformers import AutoTokenizer, PreTrainedTokenizer, PreTrainedTokenizerFast from transformers.convert_slow_tokenizer import convert_slow_tokenizer -from transformers import LlamaTokenizer from transformers.configuration_utils import PretrainedConfig from lightllm.utils.log_utils import init_logger
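
Note on usage (appended for review, not part of the upstream commit): the patch threads a new `data_type` string from the api_server CLI through the router kvargs into `TpPartBaseModel`, where `_init_datatype` resolves it to a torch dtype that then replaces the hard-coded `torch.float16` in weight loading, the memory manager, kv buffers, and the rotary caches. A minimal sketch of that resolution step, mirroring the mapping in this patch (the standalone helper name `resolve_data_type` is illustrative only, not a function introduced by the diff):

    import torch

    def resolve_data_type(data_type: str) -> torch.dtype:
        # Mirrors BaseModel._init_datatype in this patch: accept the same
        # string aliases and reject anything else.
        if data_type in ("fp16", "float16"):
            return torch.float16
        if data_type in ("bf16", "bfloat16"):
            return torch.bfloat16
        if data_type in ("fp32", "float32"):
            return torch.float32
        raise ValueError(f"Unsupported data type {data_type!r}")

    # The resolved dtype is what the kv cache and layer weights are
    # allocated with instead of a fixed torch.float16.
    dtype = resolve_data_type("bf16")
    assert dtype is torch.bfloat16

With the `--data_type` argument added to api_server.py, a bf16 launch would look roughly like the line below; the other flags shown (such as --model_dir) are the server's pre-existing options and are an assumption here, not something this patch introduces:

    python -m lightllm.server.api_server --model_dir /path/to/model --data_type bf16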