Commit 6db6200

Author: yunqian (committed)
feat: add c4ai
1 parent eecc9d0 commit 6db6200

File tree

14 files changed, +624 -0 lines changed

@@ -0,0 +1,202 @@
from functools import partial
from typing import Tuple

import torch
import torch.distributed as dist

from lightllm.common.basemodel.layer_infer.template.transformer_layer_infer_template import TransformerLayerInferTpl
from lightllm.utils.infer_utils import mark_cost_time

from ...infer_struct import InferStateInfo
from ...splitfuse_infer_struct import SplitFuseInferStateInfo
from ..transformer_layer_infer import TransformerLayerInfer


class TransformerLayerCohereInferTpl(TransformerLayerInferTpl):
    """Transformer layer inference template for Cohere-style models (optional QK norm, parallel attention/FFN residual)."""

    def __init__(self, layer_num, tp_rank, world_size, network_config, mode):
        super().__init__(layer_num, tp_rank, world_size, network_config, mode)

        self.use_qk_norm_ = self.network_config_.get("use_qk_norm", False)
        return

    def _att_norm(
        self, input, infer_state: InferStateInfo, layer_weight
    ) -> torch.Tensor:
        raise Exception("need to impl")

    def _q_norm(self, input, infer_state: InferStateInfo, layer_weight) -> torch.Tensor:
        raise Exception("need to impl")

    def _k_norm(self, input, infer_state: InferStateInfo, layer_weight) -> torch.Tensor:
        raise Exception("need to impl")

    def _bind_norm(self):
        # bind each norm hook to this template's implementation; concrete subclasses
        # override these methods (or this binding) with real kernels
        self._att_norm = partial(TransformerLayerCohereInferTpl._att_norm, self)
        self._q_norm = partial(TransformerLayerCohereInferTpl._q_norm, self)
        self._k_norm = partial(TransformerLayerCohereInferTpl._k_norm, self)
        return

    def _rotary_emb_fwd(self, q, kv, position_cos, position_sin):
        raise Exception("need to impl")

    def _bind_rotary_emb_fwd(self):
        raise Exception("need to impl")

    def _get_qkv(
        self, input, cache_kv, infer_state: InferStateInfo, layer_weight
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        q = torch.mm(input.view(-1, self.embed_dim_), layer_weight.q_weight_)
        torch.mm(
            input.view(-1, self.embed_dim_),
            layer_weight.kv_weight_,
            out=cache_kv.view(
                -1, (self.tp_k_head_num_ + self.tp_v_head_num_) * self.head_dim_
            ),
        )
        if self.use_qk_norm_:
            # apply per-head norm to q and k before the rotary embedding
            q = q.view(-1, self.tp_q_head_num_, self.head_dim_)
            k = cache_kv[:, 0 : self.tp_k_head_num_, :]
            q = self._q_norm(q, infer_state, layer_weight)
            cache_kv[:, 0 : self.tp_k_head_num_, :] = self._k_norm(
                k, infer_state, layer_weight
            )
        self._rotary_emb_fwd(
            q, cache_kv, infer_state.position_cos, infer_state.position_sin
        )
        return q, cache_kv

    def _context_attention_kernel(
        self, q, kv, infer_state: InferStateInfo, layer_weight, out=None
    ) -> torch.Tensor:
        raise Exception("need to impl")

    def _token_attention_kernel(
        self, q, infer_state: InferStateInfo, layer_weight, out=None
    ) -> torch.Tensor:
        raise Exception("need to impl")

    def _splitfuse_attention_kernel(
        self, q, infer_state: SplitFuseInferStateInfo, layer_weight, out=None
    ) -> torch.Tensor:
        raise Exception("need to impl")

    def _get_o(self, input, infer_state: InferStateInfo, layer_weight) -> torch.Tensor:
        raise Exception("need to impl")

    def _ffn(self, input, infer_state: InferStateInfo, layer_weight) -> torch.Tensor:
        raise Exception("need to impl")

    @mark_cost_time(
        "trans context flash forward time cost"
    )  # do not remove this decorator; removing it degrades performance for reasons not yet understood
    def _context_attention(
        self, input_embding, infer_state: InferStateInfo, layer_weight
    ):
        cache_kv = self._pre_cache_kv(infer_state, layer_weight)
        q, cache_kv = self._get_qkv(input_embding, cache_kv, infer_state, layer_weight)
        self._post_cache_kv(cache_kv, infer_state, layer_weight)
        o = self._context_attention_kernel(q, cache_kv, infer_state, layer_weight)
        q = None
        o = self._get_o(o, infer_state, layer_weight)
        if self.world_size_ > 1:
            dist.all_reduce(o, op=dist.ReduceOp.SUM, async_op=False)
        infer_state._attn_out = o
        return

    @mark_cost_time(
        "trans context ffn forward time cost"
    )  # do not remove this decorator; removing it degrades performance for reasons not yet understood
    def _context_ffn(self, input_embdings, infer_state: InferStateInfo, layer_weight):
        ffn_out = self._ffn(input_embdings, infer_state, layer_weight)
        if self.world_size_ > 1:
            dist.all_reduce(ffn_out, op=dist.ReduceOp.SUM, async_op=False)
        infer_state._ffn_out = ffn_out
        return

    # this impl does not use @mark_cost_time
    def _token_attention(
        self, input_embding, infer_state: InferStateInfo, layer_weight
    ):
        cache_kv = self._pre_cache_kv(infer_state, layer_weight)
        q, cache_kv = self._get_qkv(input_embding, cache_kv, infer_state, layer_weight)
        self._post_cache_kv(cache_kv, infer_state, layer_weight)
        o = self._token_attention_kernel(q, infer_state, layer_weight)
        q = None
        o = self._get_o(o, infer_state, layer_weight)
        if self.world_size_ > 1:
            dist.all_reduce(o, op=dist.ReduceOp.SUM, async_op=False)
        infer_state._attn_out = o
        return

    # this impl does not use @mark_cost_time
    def _token_ffn(self, input_embdings, infer_state: InferStateInfo, layer_weight):
        ffn_out = self._ffn(input_embdings, infer_state, layer_weight)
        if self.world_size_ > 1:
            dist.all_reduce(ffn_out, op=dist.ReduceOp.SUM, async_op=False)
        infer_state._ffn_out = ffn_out
        return

    # @mark_cost_time("trans context flash forward time cost")  # decorator disabled here; when enabled, removing it degrades performance for reasons not yet understood
    def _splitfuse_attention(
        self, input_embding, infer_state: SplitFuseInferStateInfo, layer_weight
    ):
        cache_kv = self._pre_cache_kv(infer_state, layer_weight)
        q, cache_kv = self._get_qkv(input_embding, cache_kv, infer_state, layer_weight)
        self._post_cache_kv(cache_kv, infer_state, layer_weight)
        o = self._splitfuse_attention_kernel(q, infer_state, layer_weight)
        q = None
        o = self._get_o(o, infer_state, layer_weight)
        if self.world_size_ > 1:
            dist.all_reduce(o, op=dist.ReduceOp.SUM, async_op=False)
        infer_state._attn_out = o
        return

    # @mark_cost_time("trans context ffn forward time cost")  # decorator disabled here; when enabled, removing it degrades performance for reasons not yet understood
    def _splitfuse_ffn(
        self, input_embdings, infer_state: SplitFuseInferStateInfo, layer_weight
    ):
        ffn_out = self._ffn(input_embdings, infer_state, layer_weight)
        if self.world_size_ > 1:
            dist.all_reduce(ffn_out, op=dist.ReduceOp.SUM, async_op=False)
        infer_state._ffn_out = ffn_out
        return

    def _cohere_residual(self, input_embdings, infer_state: InferStateInfo):
        # Cohere-style parallel residual: the attention and FFN outputs are both
        # added back onto the original input embeddings
        # emb_addr = input_embdings.data_ptr()
        # attn_out_addr = infer_state._attn_out.data_ptr()
        # ffn_addr = infer_state._ffn_out.data_ptr()
        # assert emb_addr != attn_out_addr
        # assert emb_addr != ffn_addr
        # assert attn_out_addr != ffn_addr
        input_embdings.add_(
            infer_state._attn_out.view(-1, self.embed_dim_)
            + infer_state._ffn_out.view(-1, self.embed_dim_)
        )

    def context_forward(
        self, input_embdings, infer_state: InferStateInfo, layer_weight
    ):
        input1 = self._att_norm(input_embdings, infer_state, layer_weight)
        self._context_attention(input1, infer_state, layer_weight=layer_weight)
        self._context_ffn(input1, infer_state, layer_weight)
        self._cohere_residual(input_embdings, infer_state)
        return input_embdings

    def token_forward(self, input_embdings, infer_state: InferStateInfo, layer_weight):
        input1 = self._att_norm(input_embdings, infer_state, layer_weight)
        self._token_attention(input1, infer_state, layer_weight=layer_weight)
        self._token_ffn(input1, infer_state, layer_weight)
        self._cohere_residual(input_embdings, infer_state)
        return input_embdings

    def splitfuse_forward(
        self, input_embdings, infer_state: SplitFuseInferStateInfo, layer_weight
    ):
        input1 = self._att_norm(input_embdings, infer_state, layer_weight)
        self._splitfuse_attention(input1, infer_state, layer_weight=layer_weight)
        self._splitfuse_ffn(input1, infer_state, layer_weight)
        self._cohere_residual(input_embdings, infer_state)
        return input_embdings
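
A concrete Cohere layer implementation is expected to subclass this template and supply the norm, rotary, and attention hooks. Below is a minimal, hypothetical sketch of the norm wiring only; the subclass name, the weight attribute names (att_norm_weight_, q_norm_weight_, k_norm_weight_), the template's module path, and the exact signatures of layernorm_forward / multi_head_layernorm_forward are assumptions for illustration, not part of this diff. Overriding the methods directly serves the same purpose as rebinding them via _bind_norm.

import torch

# module path assumed; this page does not show the template file's name
from lightllm.common.basemodel.layer_infer.template.transformer_layer_infer_cohere_template import (
    TransformerLayerCohereInferTpl,
)
# kernels added elsewhere in this commit; keyword signature assumed to mirror CoherePostLayerInfer._norm
from lightllm.models.cohere.triton_kernels.layernorm import layernorm_forward, multi_head_layernorm_forward


class CohereTransformerLayerInfer(TransformerLayerCohereInferTpl):
    """Hypothetical subclass showing how the abstract norm hooks could be filled in."""

    def __init__(self, layer_num, tp_rank, world_size, network_config, mode):
        super().__init__(layer_num, tp_rank, world_size, network_config, mode)
        self.eps_ = network_config["layer_norm_eps"]

    def _att_norm(self, input, infer_state, layer_weight) -> torch.Tensor:
        # pre-attention layernorm over the hidden dimension (weight name assumed)
        return layernorm_forward(input, layer_weight.att_norm_weight_, eps=self.eps_)

    def _q_norm(self, input, infer_state, layer_weight) -> torch.Tensor:
        # per-head layernorm on q, only reached when use_qk_norm is enabled (weight name assumed)
        return multi_head_layernorm_forward(input, layer_weight.q_norm_weight_, eps=self.eps_)

    def _k_norm(self, input, infer_state, layer_weight) -> torch.Tensor:
        # per-head layernorm on k, only reached when use_qk_norm is enabled (weight name assumed)
        return multi_head_layernorm_forward(input, layer_weight.k_norm_weight_, eps=self.eps_)
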

lightllm/models/cohere/__init__.py

Whitespace-only changes.

lightllm/models/cohere/infer_struct.py

+8
@@ -0,0 +1,8 @@
from lightllm.models.llama.infer_struct import LlamaInferStateInfo


class CohereInferStateInfo(LlamaInferStateInfo):
    def __init__(self):
        super().__init__()
        self._attn_out = None
        self._ffn_out = None

lightllm/models/cohere/layer_infer/__init__.py

Whitespace-only changes.
@@ -0,0 +1,138 @@
import torch
import torch.distributed as dist
import numpy as np

from lightllm.models.cohere.infer_struct import CohereInferStateInfo
from lightllm.models.cohere.layer_weights.pre_and_post_layer_weight import CoherePreAndPostLayerWeight
from lightllm.models.cohere.triton_kernels.layernorm import layernorm_forward, multi_head_layernorm_forward
from lightllm.common.basemodel.layer_weights.base_layer_weight import BaseLayerWeight
from lightllm.common.basemodel.splitfuse_infer_struct import SplitFuseInferStateInfo

from einops import rearrange
from lightllm.common.basemodel import PostLayerInferTpl


class CoherePostLayerInfer(PostLayerInferTpl):
    def __init__(self, tp_rank, world_size, network_config, mode):
        super().__init__(tp_rank, world_size, network_config, mode)
        self.eps_ = network_config["layer_norm_eps"]
        self.vocab_size_ = network_config["vocab_size"]
        self.embed_dim_ = network_config["n_embed"]
        self.logits_scale = network_config["logit_scale"]
        return

    def _norm(self, input, infer_state, layer_weight: CoherePreAndPostLayerWeight) -> torch.Tensor:
        return layernorm_forward(input, layer_weight.final_norm_weight_, eps=self.eps_)

    def _slice_get_last_input(self, input_embdings, infer_state: CohereInferStateInfo):
        if infer_state.is_splitfuse:
            # for SplitFuse: decode requests contribute one token each, prefill requests
            # contribute their un-cached suffix; take the last token of every request
            batch_size = infer_state.batch_size
            last_input = torch.empty(
                (batch_size, self.embed_dim_), device=input_embdings.device, dtype=input_embdings.dtype
            )
            tmp_ = torch.cat(
                [
                    torch.ones(infer_state.decode_req_num, dtype=torch.int32, device="cuda"),
                    infer_state.prefill_b_seq_len - infer_state.prefill_b_split_ready_cache_len,
                ],
                dim=0,
            )
            last_index = torch.cumsum(tmp_, dim=0, dtype=torch.long) - 1
            last_input[:, :] = input_embdings[last_index, :]
            return last_input, batch_size

        if infer_state.is_prefill and infer_state.is_token_healing:
            # token healing needs the last two tokens of every multi-token request
            batch_size = infer_state.batch_size
            b_seq_len_numpy = (infer_state.b_seq_len - infer_state.b_ready_cache_len).detach().cpu().numpy()
            select_index = []
            start_index = 0
            select_token_num = 0
            for cur_len in b_seq_len_numpy:
                if cur_len == 1:
                    select_index.append(start_index + cur_len - 1)
                    start_index += cur_len
                    select_token_num += 1
                else:
                    select_index.append(start_index + cur_len - 2)
                    select_index.append(start_index + cur_len - 1)
                    start_index += cur_len
                    select_token_num += 2

            last_index = torch.tensor(select_index, dtype=torch.long, device=input_embdings.device)
            last_input = torch.empty(
                (select_token_num, self.embed_dim_), device=input_embdings.device, dtype=input_embdings.dtype
            )

            last_input[:, :] = input_embdings[last_index, :]
            return last_input, select_token_num

        if not infer_state.is_splitfuse and infer_state.is_prefill and not infer_state.return_all_prompt_logics:
            batch_size = infer_state.batch_size
            last_input = torch.empty(
                (batch_size, self.embed_dim_), device=input_embdings.device, dtype=input_embdings.dtype
            )
            last_index = (
                torch.cumsum(infer_state.b_seq_len - infer_state.b_ready_cache_len, dim=0, dtype=torch.long) - 1
            )
            last_input[:, :] = input_embdings[last_index, :]
            return last_input, batch_size

        if not infer_state.is_splitfuse and infer_state.is_prefill and infer_state.return_all_prompt_logics:
            total_tokens = infer_state.total_token_num
            return input_embdings, total_tokens

        if not infer_state.is_splitfuse and not infer_state.is_prefill:
            batch_size = infer_state.batch_size
            return input_embdings[-batch_size:, :], batch_size

        assert False, "Error State"

    def soft_max(self, data):
        return torch.softmax(data.permute(1, 0).float(), dim=-1)

    def token_forward(
        self,
        input_embdings,
        infer_state: CohereInferStateInfo,
        layer_weight: CoherePreAndPostLayerWeight,
        return_logics=False,
    ):
        last_input, token_num = self._slice_get_last_input(input_embdings, infer_state)
        input_embdings_dtype = input_embdings.dtype
        input_embdings = None
        last_input = self._norm(last_input, infer_state, layer_weight)
        last_input = rearrange(last_input, "batch embed_dim -> embed_dim batch").contiguous().reshape(-1, token_num)
        logic_batch = torch.mm(layer_weight.lm_head_weight_, last_input)

        last_input = None
        if self.world_size_ == 1:
            gather_data = logic_batch
        else:
            gather_data = torch.empty(
                (self.vocab_size_, token_num), device=logic_batch.device, dtype=input_embdings_dtype
            )
            split_indexes = np.linspace(0, self.vocab_size_, self.world_size_ + 1, dtype=np.int64)
            dist.all_gather(
                [gather_data[split_indexes[i] : split_indexes[i + 1], :] for i in range(self.world_size_)],
                logic_batch,
                group=None,
                async_op=False,
            )
        gather_data = gather_data * self.logits_scale
        logic_batch = None

        if not return_logics:
            prob_out = self.soft_max(gather_data)
            gather_data = None
            return prob_out
        else:
            ans_logics = gather_data.permute(1, 0).float()
            gather_data = None
            return ans_logics

    # @mark_cost_time("splitfuse post forward")
    def splitfuse_forward(
        self, input_embdings, infer_state: SplitFuseInferStateInfo, layer_weight: BaseLayerWeight, return_logics=False
    ):
        return self.token_forward(input_embdings, infer_state, layer_weight, return_logics=return_logics)
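
For reference, a small standalone example (not part of the commit) of the last-token indexing that _slice_get_last_input performs for an ordinary prefill batch: the tokens of all requests are packed along dim 0, so the cumulative sum of per-request lengths minus one gives the row index of each request's final token.

import torch

b_seq_len = torch.tensor([3, 5, 2])          # per-request prompt lengths
b_ready_cache_len = torch.tensor([0, 0, 0])  # no prefix cache reused in this example
last_index = torch.cumsum(b_seq_len - b_ready_cache_len, dim=0, dtype=torch.long) - 1
print(last_index)  # tensor([2, 7, 9]): rows 2, 7 and 9 of the packed hidden states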
