13 | 13 | )
14 | 14 | from lightllm.models.llama.triton_kernel.rotary_emb import rotary_emb_fwd
15 | 15 | from lightllm.models.llama.infer_struct import LlamaInferStateInfo
   | 16 | +from lightllm.models.llama.splitfuse_infer_struct import LlamaSplitFuseInferStateInfo
16 | 17 | from lightllm.models.llama.layer_infer.transformer_layer_infer import LlamaTransformerLayerInfer
17 | 18 | from lightllm.common.basemodel import TransformerLayerInferActivationWeightQuantTpl
18 | 19 | from lightllm.common.basemodel.cuda_kernel.ppl_awquant import (
@@ -220,6 +221,33 @@ def _token_ffn(self, input_embdings, infer_state: LlamaInferStateInfo, layer_wei
220 | 221 |         input_embdings.add_(ffn_out.view(-1, self.embed_dim_))
221 | 222 |         return
222 | 223 |
    | 224 | +    def _splitfuse_attention(self, input_embding, infer_state: LlamaSplitFuseInferStateInfo, layer_weight):
    | 225 | +        # LlamaSplitFuseInferStateInfo has no is_prefill member, but the downstream matmul
    | 226 | +        # kernel entry functions require one as an input, so we attach it here and default it to True.
    | 227 | +        infer_state.is_prefill = True
    | 228 | +
    | 229 | +        input1, token_scale, skip_out = self._awquant_att_norm(input_embding, infer_state, layer_weight)
    | 230 | +        cache_kv = self._pre_cache_kv(infer_state, layer_weight)
    | 231 | +        q, cache_kv = self._get_qkv(input1, cache_kv, token_scale, infer_state, layer_weight)
    | 232 | +        input1 = None
    | 233 | +        self._post_cache_kv(cache_kv, infer_state, layer_weight)
    | 234 | +        o = self._splitfuse_attention_kernel(q, infer_state, layer_weight)
    | 235 | +        q = None
    | 236 | +        o = self._get_o(o, infer_state, layer_weight)
    | 237 | +        if self.world_size_ > 1:
    | 238 | +            dist.all_reduce(o, op=dist.ReduceOp.SUM, async_op=False)
    | 239 | +        input_embding.add_(o.view(-1, self.embed_dim_))
    | 240 | +        return
    | 241 | +
    | 242 | +    def _splitfuse_ffn(self, input_embdings, infer_state: LlamaSplitFuseInferStateInfo, layer_weight):
    | 243 | +        input1, token_scale, skip_out = self._awquant_ffn_norm(input_embdings, infer_state, layer_weight)
    | 244 | +        ffn_out = self._ffn(input1, token_scale, infer_state, layer_weight)
    | 245 | +        input1 = None
    | 246 | +        if self.world_size_ > 1:
    | 247 | +            dist.all_reduce(ffn_out, op=dist.ReduceOp.SUM, async_op=False)
    | 248 | +        input_embdings.add_(ffn_out.view(-1, self.embed_dim_))
    | 249 | +        return
    | 250 | +
223 | 251 |     def _awquant_matmul_ppl_int8_quant_dequant(
224 | 252 |         self, input, quant_weight_params, is_prefill, token_scale=None, out=None, bias=None, has_act=False
225 | 253 |     ):
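For orientation, here is a minimal sketch of how the base-class template presumably dispatches the two new hooks for one transformer layer during a split-fuse step. The dispatcher name splitfuse_forward and its exact signature are assumptions inferred from the existing context_forward / token_forward pattern in lightllm's layer templates; the dispatcher itself is not part of this diff.

    # Hypothetical dispatcher (not in this diff): mirrors the existing
    # context/token forward pattern of the transformer-layer template.
    def splitfuse_forward(self, input_embdings, infer_state, layer_weight):
        # Both hooks mutate input_embdings in place via .add_(...), so neither
        # returns a tensor; the residual stream lives in a single buffer.
        self._splitfuse_attention(input_embdings, infer_state, layer_weight)
        self._splitfuse_ffn(input_embdings, infer_state, layer_weight)
        return input_embdings

Note also that forcing infer_state.is_prefill = True means the _awquant_matmul_* helpers, whose signatures take is_prefill (see line 252 above), presumably always follow their prefill code path for split-fuse batches.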
|