from typing import Tuple

import numpy as np
import torch
import torch.distributed as dist
import torch.functional as F
import triton
from functools import partial

from lightllm.models.llama.triton_kernel.rotary_emb import rotary_emb_fwd
from lightllm.models.llama.triton_kernel.silu_and_mul import silu_and_mul_fwd
from lightllm.models.llama.infer_struct import LlamaInferStateInfo
from lightllm.models.llama.layer_infer.transformer_layer_infer import LlamaTransformerLayerInfer
from lightllm.models.llama_quik.layer_weights.transformer_layer_weight import LlamaTransformerLayerWeightQuik
from lightllm.utils.infer_utils import mark_cost_time


class LlamaTransformerLayerInferQuik(LlamaTransformerLayerInfer):
    """Transformer layer inference for LLaMA models with QUIK-quantized layer weights."""

    def __init__(self, layer_num, tp_rank, world_size, network_config, mode=[]):
        super().__init__(layer_num, tp_rank, world_size, network_config, mode)
        self.eps_ = network_config["rms_norm_eps"]
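        # Attention and key/value heads are split evenly across tensor-parallel ranks.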
        self.tp_q_head_num_ = network_config["num_attention_heads"] // self.world_size_
        self.tp_k_head_num_ = network_config["num_key_value_heads"] // self.world_size_
        self.tp_v_head_num_ = network_config["num_key_value_heads"] // self.world_size_
        self.tp_o_head_num_ = self.tp_q_head_num_
        self.head_dim_ = network_config["hidden_size"] // network_config["num_attention_heads"]
        self.embed_dim_ = network_config["hidden_size"]

        self.inter_dim_ = network_config["intermediate_size"]
        self._bind_func()
        return

    def _get_qkv(
        self,
        input: torch.Tensor,
        cache_kv: torch.Tensor,
        infer_state: LlamaInferStateInfo,
        layer_weight: LlamaTransformerLayerWeightQuik,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        q = layer_weight.q_proj(input.view(-1, self.embed_dim_))
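        # QUIK weights may expose a fused KV projection (cat_kv_) or separate K and V projections.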
        if layer_weight.cat_kv_:
            cache_kv = layer_weight.kv_proj(input.view(-1, self.embed_dim_)).view(
                -1, self.tp_k_head_num_ + self.tp_v_head_num_, self.head_dim_
            )
        else:
            cache_k = layer_weight.k_proj(input.view(-1, self.embed_dim_)).view(-1, self.tp_k_head_num_, self.head_dim_)
            cache_v = layer_weight.v_proj(input.view(-1, self.embed_dim_)).view(-1, self.tp_v_head_num_, self.head_dim_)
            cache_kv = torch.cat([cache_k, cache_v], dim=1)
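        # Apply rotary position embeddings in place to q and to the K part of cache_kv.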
        rotary_emb_fwd(
            q.view(-1, self.tp_q_head_num_, self.head_dim_),
            cache_kv[:, 0 : self.tp_k_head_num_, :],
            infer_state.position_cos,
            infer_state.position_sin,
        )
        return q, cache_kv

    def _get_o(
        self, input, infer_state: LlamaInferStateInfo, layer_weight: LlamaTransformerLayerWeightQuik
    ) -> torch.Tensor:
        return layer_weight.o_proj(input.view(-1, self.embed_dim_))

    def _ffn(
        self,
        input,
        infer_state: LlamaInferStateInfo,
        layer_weight: LlamaTransformerLayerWeightQuik,
    ) -> torch.Tensor:
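        # Two paths: separate gate/up projections with an in-place SiLU, or a fused gate_up
        # projection followed by the silu_and_mul Triton kernel.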
        if not layer_weight.cat_gate_up_:
            gate_out = layer_weight.gate_proj(input.view(-1, self.embed_dim_))
            up_out = layer_weight.up_proj(input.view(-1, self.embed_dim_))
            torch.nn.functional.silu(gate_out, inplace=True)
            gate_out.mul_(up_out)
            input = None
            ffn2_out = layer_weight.down_proj(gate_out)
            gate_out, up_out = None, None
        else:
            gate_up_out = layer_weight.gate_up_proj(input.view(-1, self.embed_dim_)).view(
                -1, self.inter_dim_ * 2 // self.world_size_
            )
            # gate_out, up_out = torch.split(gate_up_out, split_size_or_sections=1, dim=1)
            ffn1_out = silu_and_mul_fwd(gate_up_out)
            input = None
            gate_up_out = None
            ffn2_out = layer_weight.down_proj(ffn1_out)
            ffn1_out = None

        return ffn2_out

    @mark_cost_time(
        "trans context flash forward time cost"
    )  # do not remove this decorator; removing it hurts performance for reasons that are not yet understood
    def _context_attention(self, input_embding, infer_state: LlamaInferStateInfo, layer_weight):
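        # Prefill (context) pass: pre-norm, QKV projection, KV cache write, context attention,
        # output projection, residual add.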
        input1 = self._att_norm(input_embding, infer_state, layer_weight)
        cache_kv = self._pre_cache_kv(infer_state, layer_weight)
        q, cache_kv = self._get_qkv(input1, cache_kv, infer_state, layer_weight)
        input1 = None
        self._post_cache_kv(cache_kv, infer_state, layer_weight)
        o = self._context_attention_kernel(q, cache_kv, infer_state, layer_weight)
        q = None
        o = self._get_o(o, infer_state, layer_weight)
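        # Sum the partial attention outputs across tensor-parallel ranks before the residual add.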
        if self.world_size_ > 1:
            dist.all_reduce(o, op=dist.ReduceOp.SUM, async_op=False)
        input_embding.add_(o.view(-1, self.embed_dim_))
        return

    @mark_cost_time(
        "trans context ffn forward time cost"
    )  # do not remove this decorator; removing it hurts performance for reasons that are not yet understood
    def _context_ffn(self, input_embdings, infer_state: LlamaInferStateInfo, layer_weight):
        input1 = self._ffn_norm(input_embdings, infer_state, layer_weight)
        ffn_out = self._ffn(input1, infer_state, layer_weight)
        input1 = None
        if self.world_size_ > 1:
            dist.all_reduce(ffn_out, op=dist.ReduceOp.SUM, async_op=False)
        input_embdings.add_(ffn_out.view(-1, self.embed_dim_))
        return

    # this implementation does not use @mark_cost_time
    def _token_attention(self, input_embding, infer_state: LlamaInferStateInfo, layer_weight):
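        # Decode (token) pass: attention for the newly generated token, reading keys/values from the KV cache.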
        input1 = self._att_norm(input_embding, infer_state, layer_weight)
        cache_kv = self._pre_cache_kv(infer_state, layer_weight)
        q, cache_kv = self._get_qkv(input1, cache_kv, infer_state, layer_weight)
        input1 = None
        self._post_cache_kv(cache_kv, infer_state, layer_weight)
        o = self._token_attention_kernel(q, infer_state, layer_weight)
        q = None
        o = self._get_o(o, infer_state, layer_weight)
        if self.world_size_ > 1:
            dist.all_reduce(o, op=dist.ReduceOp.SUM, async_op=False)
        input_embding.add_(o.view(-1, self.embed_dim_))
        return

    # this implementation does not use @mark_cost_time
    def _token_ffn(self, input_embdings, infer_state: LlamaInferStateInfo, layer_weight):
        input1 = self._ffn_norm(input_embdings, infer_state, layer_weight)
        ffn_out = self._ffn(input1, infer_state, layer_weight)
        input1 = None
        if self.world_size_ > 1:
            dist.all_reduce(ffn_out, op=dist.ReduceOp.SUM, async_op=False)
        input_embdings.add_(ffn_out.view(-1, self.embed_dim_))
        return