Commit 9b7d640

Author: wangzaijun
Parent: c8160a4

ppl_w8a8 support splitfuse mode

2 files changed (+30, -1)


lightllm/models/llama_awquant/layer_infer/transformer_layer_infer.py (+28)
@@ -13,6 +13,7 @@
 )
 from lightllm.models.llama.triton_kernel.rotary_emb import rotary_emb_fwd
 from lightllm.models.llama.infer_struct import LlamaInferStateInfo
+from lightllm.models.llama.splitfuse_infer_struct import LlamaSplitFuseInferStateInfo
 from lightllm.models.llama.layer_infer.transformer_layer_infer import LlamaTransformerLayerInfer
 from lightllm.common.basemodel import TransformerLayerInferActivationWeightQuantTpl
 from lightllm.common.basemodel.cuda_kernel.ppl_awquant import (
@@ -220,6 +221,33 @@ def _token_ffn(self, input_embdings, infer_state: LlamaInferStateInfo, layer_wei
         input_embdings.add_(ffn_out.view(-1, self.embed_dim_))
         return
 
+    def _splitfuse_attention(self, input_embding, infer_state: LlamaSplitFuseInferStateInfo, layer_weight):
+        # The LlamaSplitFuseInferStateInfo object has no is_prefill member, but the downstream
+        # matmul kernel entry functions take it as an input, so add one up front and default it to True.
+        infer_state.is_prefill = True
+
+        input1, token_scale, skip_out = self._awquant_att_norm(input_embding, infer_state, layer_weight)
+        cache_kv = self._pre_cache_kv(infer_state, layer_weight)
+        q, cache_kv = self._get_qkv(input1, cache_kv, token_scale, infer_state, layer_weight)
+        input1 = None
+        self._post_cache_kv(cache_kv, infer_state, layer_weight)
+        o = self._splitfuse_attention_kernel(q, infer_state, layer_weight)
+        q = None
+        o = self._get_o(o, infer_state, layer_weight)
+        if self.world_size_ > 1:
+            dist.all_reduce(o, op=dist.ReduceOp.SUM, async_op=False)
+        input_embding.add_(o.view(-1, self.embed_dim_))
+        return
+
+    def _splitfuse_ffn(self, input_embdings, infer_state: LlamaSplitFuseInferStateInfo, layer_weight):
+        input1, token_scale, skip_out = self._awquant_ffn_norm(input_embdings, infer_state, layer_weight)
+        ffn_out = self._ffn(input1, token_scale, infer_state, layer_weight)
+        input1 = None
+        if self.world_size_ > 1:
+            dist.all_reduce(ffn_out, op=dist.ReduceOp.SUM, async_op=False)
+        input_embdings.add_(ffn_out.view(-1, self.embed_dim_))
+        return
+
     def _awquant_matmul_ppl_int8_quant_dequant(
         self, input, quant_weight_params, is_prefill, token_scale=None, out=None, bias=None, has_act=False
     ):
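
The two new methods mirror the existing _token_attention/_token_ffn pair and update their
input tensor in place via add_(). A minimal sketch of how a splitfuse forward pass might
drive them layer by layer follows; the wrapper function and its parameter names are
hypothetical, not part of this commit:

    # Hypothetical driver, assuming a list of layer-infer objects and their
    # matching weights. Both methods mutate input_embdings in place, so the
    # tensor flows through every layer without reassignment.
    def splitfuse_forward(layer_infers, layer_weights, input_embdings, infer_state):
        for infer, weight in zip(layer_infers, layer_weights):
            infer._splitfuse_attention(input_embdings, infer_state, weight)
            infer._splitfuse_ffn(input_embdings, infer_state, weight)
        return input_embdings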

lightllm/server/httpserver/manager.py (+2, -1)
@@ -183,7 +183,8 @@ async def generate(
         logger.debug(
             f"req_id:{group_request_id},start:{start_time}s,first_token_cost:{first_token_cost_ms}ms\n"
             f"total_cost_time:{total_cost_time_ms}ms,out_token_counter:{out_token_counter}\n"
-            f"mean_per_token_cost_time: {total_cost_time_ms/out_token_counter}ms"
+            f"mean_per_token_cost_time: {total_cost_time_ms/out_token_counter}ms\n"
+            f"prompt_token_num:{prompt_tokens}"
         )
         monitor.histogram_observe("lightllm_request_inference_duration", total_cost_time_ms)
         monitor.histogram_observe(
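
With the added f-string line, the per-request debug message now also reports the prompt
token count. Using illustrative values (every number below is hypothetical), the log
would read roughly:

    req_id:8,start:1714023456.7s,first_token_cost:42.0ms
    total_cost_time:960.0ms,out_token_counter:120
    mean_per_token_cost_time: 8.0ms
    prompt_token_num:512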
