
Commit adba526

[gemma_2b] support (#346)
To be confirmed: in the triton attention operators for context and decode, head_dim == 256 for this model, so "assert Lk in {16, 32, 64, 128}" becomes "assert Lk in {16, 32, 64, 128, 256}"; will this affect performance?
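For reference, the change in question would look roughly like the following inside the context/decode attention wrappers (a sketch only; the exact variable names in the triton kernel files may differ):

    # Gemma-2B uses head_dim == 256, so the supported head-dim set is widened;
    # whether the larger dimension degrades kernel performance is still to be confirmed.
    Lq, Lk = q.shape[-1], k.shape[-1]
    assert Lk in {16, 32, 64, 128, 256}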
1 parent 414aa08 commit adba526

13 files changed: +385 -3 lines

lightllm/models/gemma_2b/__init__.py

Whitespace-only changes.

lightllm/models/gemma_2b/layer_infer/__init__.py

Whitespace-only changes.
lightllm/models/gemma_2b/layer_infer/pre_layer_infer.py
@@ -0,0 +1,52 @@
import torch
import torch.distributed as dist
import numpy as np

from lightllm.common.basemodel.splitfuse_infer_struct import SplitFuseInferStateInfo
from lightllm.models.gemma_2b.layer_weights.pre_and_post_layer_weight import Gemma_2bPreAndPostLayerWeight
from lightllm.models.llama.infer_struct import LlamaInferStateInfo
from lightllm.common.basemodel import PreLayerInferTpl
from lightllm.utils.infer_utils import mark_cost_time


class Gemma_2bPreLayerInfer(PreLayerInferTpl):
    """ """

    def __init__(self, tp_rank, world_size, network_config, mode):
        super().__init__(tp_rank, world_size, network_config, mode)
        tp_vob_ids = np.linspace(0, network_config["vocab_size"], self.world_size_ + 1, dtype=np.int64)
        self.vob_start_id_, self.vob_end_id_ = int(tp_vob_ids[self.tp_rank_]), int(tp_vob_ids[self.tp_rank_ + 1])
        # Gemma scales the input embeddings by sqrt(hidden_size)
        self.normfactor = network_config["hidden_size"] ** 0.5
        return

    def _norm(self, input, infer_state, layer_weight: Gemma_2bPreAndPostLayerWeight) -> torch.Tensor:
        return input * self.normfactor

    @mark_cost_time("pre context forward")
    def context_forward(self, input_ids, infer_state: LlamaInferStateInfo, layer_weight: Gemma_2bPreAndPostLayerWeight):
        input_mask = torch.logical_or(self.vob_start_id_ > input_ids, input_ids >= self.vob_end_id_)
        tmp_input_ids = input_ids - self.vob_start_id_
        tmp_input_ids[input_mask] = 0
        input_embdings = torch.embedding(layer_weight.wte_weight_, tmp_input_ids, padding_idx=-1)
        input_embdings[input_mask] = 0.0
        if self.world_size_ > 1:
            dist.all_reduce(input_embdings, op=dist.ReduceOp.SUM, async_op=False)
        input_embdings = self._norm(input_embdings, infer_state, layer_weight)
        return input_embdings

    def token_forward(self, input_ids, infer_state: LlamaInferStateInfo, layer_weight: Gemma_2bPreAndPostLayerWeight):
        input_mask = torch.logical_or(self.vob_start_id_ > input_ids, input_ids >= self.vob_end_id_)
        tmp_input_ids = input_ids - self.vob_start_id_
        tmp_input_ids[input_mask] = 0
        input_embdings = torch.embedding(layer_weight.wte_weight_, tmp_input_ids, padding_idx=-1)
        input_embdings[input_mask] = 0.0
        if self.world_size_ > 1:
            dist.all_reduce(input_embdings, op=dist.ReduceOp.SUM, async_op=False)
        input_embdings = self._norm(input_embdings, infer_state, layer_weight)
        return input_embdings

    # @mark_cost_time("splitfuse forward")
    def splitfuse_forward(
        self, input_ids, infer_state: SplitFuseInferStateInfo, layer_weight: Gemma_2bPreAndPostLayerWeight
    ):
        return self.token_forward(input_ids, infer_state, layer_weight)
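
Note: the pre layer above is a vocab-parallel embedding lookup (each rank owns a slice of the vocabulary, out-of-range ids are zeroed, then the shards are summed with all_reduce) followed by Gemma's sqrt(hidden_size) embedding scaling. A minimal single-GPU reference of the same computation, assuming a full, unsharded embedding table wte_weight of shape [vocab_size, hidden_size]:

    import torch

    def gemma_embed_reference(input_ids, wte_weight, hidden_size):
        # plain lookup, then the Gemma-specific sqrt(hidden_size) scaling
        emb = torch.embedding(wte_weight, input_ids)
        return emb * (hidden_size ** 0.5)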
lightllm/models/gemma_2b/layer_infer/transformer_layer_infer.py
@@ -0,0 +1,32 @@
import torch
import torch.nn.functional as F
import torch.distributed as dist
import numpy as np
from typing import Tuple
from functools import partial
import triton

from lightllm.models.gemma_2b.layer_weights.transformer_layer_weight import Gemma_2bTransformerLayerWeight
from lightllm.models.llama.layer_infer.transformer_layer_infer import LlamaTransformerLayerInfer
from lightllm.models.gemma_2b.triton_kernel.gelu_and_mul import gelu_and_mul_fwd

from lightllm.models.llama.infer_struct import LlamaInferStateInfo


class Gemma_2bTransformerLayerInfer(LlamaTransformerLayerInfer):
    """ """

    def __init__(self, layer_num, tp_rank, world_size, network_config, mode=[]):
        super().__init__(layer_num, tp_rank, world_size, network_config, mode)
        self.tp_k_head_num_ = network_config["num_key_value_heads"]  # [SYM] always == 1
        self.tp_v_head_num_ = network_config["num_key_value_heads"]
        return

    def _ffn(self, input, infer_state: LlamaInferStateInfo, layer_weight: Gemma_2bTransformerLayerWeight) -> torch.Tensor:
        up_gate_out = torch.mm(input.view(-1, self.embed_dim_), layer_weight.gate_up_proj)
        ffn1_out = gelu_and_mul_fwd(up_gate_out)
        input = None
        up_gate_out = None
        ffn2_out = torch.mm(ffn1_out, layer_weight.down_proj)
        ffn1_out = None
        return ffn2_out
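
The _ffn override above implements Gemma's GeGLU feed-forward with a fused gate/up matmul and the gelu_and_mul triton kernel. A plain-PyTorch sketch of the same math, assuming per-rank shards gate_proj and up_proj of shape [hidden, inter // world_size] and down_proj of shape [inter // world_size, hidden] (the layout produced by the weight loader further below):

    import torch

    def gemma_ffn_reference(x, gate_proj, up_proj, down_proj):
        # tanh-approximated GELU on the gate, elementwise product with up, then down-projection
        gate = torch.nn.functional.gelu(x @ gate_proj, approximate="tanh")
        return (gate * (x @ up_proj)) @ down_proj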

lightllm/models/gemma_2b/layer_weights/__init__.py

Whitespace-only changes.
lightllm/models/gemma_2b/layer_weights/pre_and_post_layer_weight.py
@@ -0,0 +1,28 @@
import torch
import numpy as np
from lightllm.models.llama.layer_weights.pre_and_post_layer_weight import LlamaPreAndPostLayerWeight


class Gemma_2bPreAndPostLayerWeight(LlamaPreAndPostLayerWeight):
    def __init__(self, tp_rank, world_size, data_type, network_config, mode):
        super().__init__(tp_rank, world_size, data_type, network_config, mode)
        return

    def load_hf_weights(self, weights):
        vob_size = self.network_config_["vocab_size"]
        split_indexes = np.linspace(0, vob_size, self.world_size_ + 1, dtype=np.int64)
        split_start = split_indexes[self.tp_rank_]
        split_end = split_indexes[self.tp_rank_ + 1]
        if "model.embed_tokens.weight" in weights:
            # print(weights['model.embed_tokens.weight'].shape)
            self.wte_weight_ = self._cuda(weights["model.embed_tokens.weight"][split_start:split_end, :])
        if "lm_head.weight" in weights:
            # print(weights['lm_head.weight'].shape)
            self.lm_head_weight_ = self._cuda(weights["lm_head.weight"][split_start:split_end, :])
        else:
            self.lm_head_weight_ = self.wte_weight_
        if "model.norm.weight" in weights:
            self.final_norm_weight_ = self._cuda(weights["model.norm.weight"])
            # Gemma stores RMSNorm scales as an offset from 1, so fold the +1 in at load time
            self.final_norm_weight_ = self.final_norm_weight_ + 1

        return
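
The `+ 1` applied to model.norm.weight (and to the per-layer norm weights below) reflects that Gemma checkpoints store RMSNorm scales as an offset from one: the HF implementation multiplies the normalized activations by (1 + weight). Folding the +1 in at load time lets the existing llama RMSNorm kernel, which multiplies by weight directly, be reused unchanged. A small sketch of the equivalence (eps is assumed to be the usual 1e-6):

    import torch

    def gemma_rmsnorm_reference(x, weight, eps=1e-6):
        # normalize, then scale by (1 + weight) as in HF Gemma
        x = x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + eps)
        return x * (1 + weight)

    # with w_loaded = weight + 1 computed once at load time (as above),
    # this is simply x_normalized * w_loaded, i.e. the stock llama RMSNorm.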
lightllm/models/gemma_2b/layer_weights/transformer_layer_weight.py
@@ -0,0 +1,74 @@
import torch
import math
import numpy as np
from lightllm.models.llama.layer_weights.transformer_layer_weight import LlamaTransformerLayerWeight


class Gemma_2bTransformerLayerWeight(LlamaTransformerLayerWeight):
    def __init__(self, layer_num, tp_rank, world_size, data_type, network_config, mode=[]):
        super().__init__(layer_num, tp_rank, world_size, data_type, network_config, mode)
        return

    def _load_qkvo_weights(self, weights):
        # input layernorm params
        if f"model.layers.{self.layer_num_}.input_layernorm.weight" in weights:
            self.att_norm_weight_ = self._cuda(weights[f"model.layers.{self.layer_num_}.input_layernorm.weight"])
            self.att_norm_weight_ += 1

        n_embed = self.network_config_["hidden_size"]
        q_split_n_embed = n_embed // self.world_size_

        # q k v weights for llama
        if f"model.layers.{self.layer_num_}.self_attn.q_proj.weight" in weights:
            self.q_weight_ = weights[f"model.layers.{self.layer_num_}.self_attn.q_proj.weight"]
            self.q_weight_ = self.q_weight_[q_split_n_embed * self.tp_rank_ : q_split_n_embed * (self.tp_rank_ + 1), :]
            self.q_weight_ = self._cuda(self.q_weight_.transpose(0, 1))

        if f"model.layers.{self.layer_num_}.self_attn.k_proj.weight" in weights:
            k_weight_ = weights[f"model.layers.{self.layer_num_}.self_attn.k_proj.weight"]
            self.k_weight_ = k_weight_.transpose(0, 1)

        if f"model.layers.{self.layer_num_}.self_attn.v_proj.weight" in weights:
            v_weight_ = weights[f"model.layers.{self.layer_num_}.self_attn.v_proj.weight"]
            self.v_weight_ = v_weight_.transpose(0, 1)

        # attention output dense params
        if f"model.layers.{self.layer_num_}.self_attn.o_proj.weight" in weights:
            self.o_weight_ = weights[f"model.layers.{self.layer_num_}.self_attn.o_proj.weight"]
            self.o_weight_ = self.o_weight_[:, q_split_n_embed * self.tp_rank_ : q_split_n_embed * (self.tp_rank_ + 1)]
            self.o_weight_ = self._cuda(self.o_weight_.transpose(0, 1))

        self._try_cat_to(["k_weight_", "v_weight_"], "kv_weight_", cat_dim=1)

        return

    def _load_ffn_weights(self, weights):
        if f"model.layers.{self.layer_num_}.post_attention_layernorm.weight" in weights:
            self.ffn_norm_weight_ = self._cuda(
                weights[f"model.layers.{self.layer_num_}.post_attention_layernorm.weight"]
            )
            self.ffn_norm_weight_ += 1

        inter_size = self.network_config_["intermediate_size"]
        split_inter_size = inter_size // self.world_size_

        if f"model.layers.{self.layer_num_}.mlp.up_proj.weight" in weights:
            up_proj = weights[f"model.layers.{self.layer_num_}.mlp.up_proj.weight"][
                split_inter_size * self.tp_rank_ : split_inter_size * (self.tp_rank_ + 1), :
            ]
            self.up_proj = up_proj.transpose(0, 1)

        if f"model.layers.{self.layer_num_}.mlp.gate_proj.weight" in weights:
            gate_proj = weights[f"model.layers.{self.layer_num_}.mlp.gate_proj.weight"][
                split_inter_size * self.tp_rank_ : split_inter_size * (self.tp_rank_ + 1), :
            ]
            self.gate_proj = gate_proj.transpose(0, 1)

        self._try_cat_to(["gate_proj", "up_proj"], "gate_up_proj", cat_dim=1)

        if f"model.layers.{self.layer_num_}.mlp.down_proj.weight" in weights:
            self.down_proj = weights[f"model.layers.{self.layer_num_}.mlp.down_proj.weight"][
                :, split_inter_size * self.tp_rank_ : split_inter_size * (self.tp_rank_ + 1)
            ]
            self.down_proj = self._cuda(self.down_proj.transpose(0, 1))
        return
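
Note on the tensor-parallel layout: q_proj and o_proj are sharded along the attention-head dimension (rows of q_proj, columns of o_proj), gate_proj/up_proj along the intermediate dimension, and down_proj along its input dimension, while k_proj/v_proj are kept whole on every rank because Gemma-2B has a single KV head. A small sketch of the per-rank slice bounds this produces, using Gemma-2B-like sizes purely for illustration:

    hidden, inter, world_size = 2048, 16384, 2   # illustrative sizes only
    q_split = hidden // world_size               # q_proj rows (and o_proj cols) per rank
    inter_split = inter // world_size            # gate/up_proj rows per rank
    for rank in range(world_size):
        print(rank,
              (q_split * rank, q_split * (rank + 1)),
              (inter_split * rank, inter_split * (rank + 1)))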

lightllm/models/gemma_2b/model.py

+79
@@ -0,0 +1,79 @@
import os
import json
import torch

from lightllm.common.basemodel import TpPartBaseModel
from lightllm.models.gemma_2b.layer_weights.transformer_layer_weight import Gemma_2bTransformerLayerWeight
from lightllm.models.gemma_2b.layer_weights.pre_and_post_layer_weight import Gemma_2bPreAndPostLayerWeight
from lightllm.models.gemma_2b.layer_infer.pre_layer_infer import Gemma_2bPreLayerInfer
from lightllm.models.gemma_2b.layer_infer.transformer_layer_infer import Gemma_2bTransformerLayerInfer
from lightllm.models.llama.layer_infer.post_layer_infer import LlamaPostLayerInfer
from lightllm.models.llama.infer_struct import LlamaInferStateInfo

from lightllm.common.mem_utils import MemoryManager


class Gemma_2bTpPartModel(TpPartBaseModel):
    # weight class
    pre_and_post_weight_class = Gemma_2bPreAndPostLayerWeight
    transformer_weight_class = Gemma_2bTransformerLayerWeight

    # infer class
    pre_layer_infer_class = Gemma_2bPreLayerInfer
    post_layer_infer_class = LlamaPostLayerInfer
    transformer_layer_infer_class = Gemma_2bTransformerLayerInfer

    # infer state class
    infer_state_class = LlamaInferStateInfo

    def __init__(self, kvargs):
        super().__init__(kvargs)
        return

    def _init_config(self):
        super()._init_config()
        return

    def _verify_params(self):
        assert self.load_way in ["HF"], "gemma only supports HF format to load now!"
        # assert self.config["num_key_value_heads"] % self.world_size_ == 0
        assert self.config["num_attention_heads"] % self.world_size_ == 0
        return

    def _init_custom(self):
        self._init_to_get_rotary()
        return

    def _init_mem_manager(self):
        self.mem_manager = MemoryManager(
            self.max_total_token_num,
            dtype=torch.float16,
            head_num=self.config["num_key_value_heads"],  # [SYM] always == 1
            head_dim=self.config["hidden_size"] // self.config["num_attention_heads"],
            layer_num=self.config["num_hidden_layers"],
        )
        return

    def _init_to_get_rotary(self, default_base=10000):
        if self.config.get("rope_scaling", {}) is None:
            rope_scaling_factor = 1.0
        else:
            rope_scaling_factor = self.config.get("rope_scaling", {}).get("factor", 1.0)

        base = self.config.get("rope_theta", float(default_base))

        if "max_sequence_length" in self.config:
            max_seq_len = self.config["max_sequence_length"]
        else:
            max_position_embeddings = self.config.get(
                "max_position_embeddings",
                2048 if base <= 10000.0 + 1e-5 else 16384,
            )
            max_seq_len = max_position_embeddings * rope_scaling_factor

        inv_freq = 1.0 / (
            base ** (torch.arange(0, self.head_dim_, 2, device="cpu", dtype=torch.float32) / self.head_dim_)
        )
        t = torch.arange(max_seq_len + 1024 * 64, device="cpu", dtype=torch.float32) / rope_scaling_factor
        freqs = torch.outer(t, inv_freq)

        self._cos_cached = torch.cos(freqs).to(torch.float16).cuda()
        self._sin_cached = torch.sin(freqs).to(torch.float16).cuda()
        return
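
_init_to_get_rotary only precomputes the rotary cos/sin tables; the llama attention kernels that this model reuses index them by token position to rotate q and k. A minimal sketch of how such caches are typically consumed (not lightllm's actual triton kernel, just the standard half-split RoPE formulation):

    import torch

    def apply_rotary_reference(x, cos_cached, sin_cached, position_ids):
        # x: [num_tokens, head_dim]; cos/sin caches: [max_seq, head_dim // 2]
        c, s = cos_cached[position_ids], sin_cached[position_ids]
        half = x.shape[-1] // 2
        x1, x2 = x[..., :half], x[..., half:]
        return torch.cat((x1 * c - x2 * s, x1 * s + x2 * c), dim=-1)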

lightllm/models/gemma_2b/triton_kernel/__init__.py

Whitespace-only changes.
lightllm/models/gemma_2b/triton_kernel/gelu_and_mul.py
@@ -0,0 +1,114 @@
import torch
import math
import triton
import triton.language as tl

# copy from xformers impl.
_kAlpha = math.sqrt(2.0 / math.pi)


@triton.jit
def tanh(x):
    # Tanh is just a scaled sigmoid
    return 2 * tl.sigmoid(2 * x) - 1


@triton.jit
def gelu(x):
    """
    GeLU_ activation - Gaussian error linear unit

    .. _GeLU: https://arxiv.org/pdf/1606.08415.pdf
    """
    return 0.5 * x * (1 + tanh(_kAlpha * (x + 0.044715 * x * x * x)))


@triton.jit
def _gelu_and_mul_kernel(
    input_ptr,
    stride_input_m,
    stride_input_n,
    stride_output_m,
    stride_output_n,
    size_m,
    size_n,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
):
    tid = tl.program_id(0)
    input_m_offsets = tid * BLOCK_M + tl.arange(0, BLOCK_M)
    output_m_offsets = tid * BLOCK_M + tl.arange(0, BLOCK_M)

    pid = tl.program_id(1)
    input_n_offsets = pid * BLOCK_N + tl.arange(0, BLOCK_N)
    output_n_offsets = pid * BLOCK_N + tl.arange(0, BLOCK_N)

    up_offsets = input_m_offsets[:, None] * stride_input_m + (input_n_offsets[None, :] + size_n) * stride_input_n
    gate_offsets = input_m_offsets[:, None] * stride_input_m + input_n_offsets[None, :] * stride_input_n
    res_offsets = output_m_offsets[:, None] * stride_output_m + output_n_offsets[None, :] * stride_output_n

    up = tl.load(
        input_ptr + up_offsets,
        mask=(input_n_offsets < size_n)[None, :] * (input_m_offsets < size_m)[:, None],
        other=0.0,
    )
    gate = tl.load(
        input_ptr + gate_offsets,
        mask=(input_n_offsets < size_n)[None, :] * (input_m_offsets < size_m)[:, None],
        other=0.0,
    ).to(tl.float32)

    gate = gelu(gate)
    gate = gate.to(tl.float16)

    tl.store(
        input_ptr + res_offsets,
        up * gate,
        mask=(output_n_offsets < size_n)[None, :] * (output_m_offsets < size_m)[:, None],
    )


def gelu_and_mul_fwd(input):
    # NOTE: the result is written in place into the first half of `input`;
    # the returned tensor is a view of that half.
    stride_input_m = input.stride(0)
    stride_input_n = input.stride(1)
    stride_output_m = input.stride(0)
    stride_output_n = input.stride(1)
    size_m = input.shape[0]
    size_n = input.shape[-1] // 2
    BLOCK_M = 128
    BLOCK_N = 128
    grid = (
        triton.cdiv(size_m, BLOCK_M),
        triton.cdiv(size_n, BLOCK_N),
    )
    _gelu_and_mul_kernel[grid](
        input,
        stride_input_m,
        stride_input_n,
        stride_output_m,
        stride_output_n,
        size_m,
        size_n,
        BLOCK_M,
        BLOCK_N,
    )
    return input[:, 0 : (input.shape[-1] // 2)]


def torch_gelu_and_mul(input: torch.Tensor):
    return torch.nn.functional.gelu(input[:, 0 : (input.shape[-1] // 2)]) * input[:, (input.shape[-1] // 2) :]


def test_gelu_and_mul(M, N, dtype, device="cuda"):
    # create data
    X = torch.randn((M, N), dtype=dtype, device=device)

    # run: compute the torch reference first, since gelu_and_mul_fwd
    # overwrites the first half of X in place
    y_ref = torch_gelu_and_mul(X)
    y_tri = gelu_and_mul_fwd(X)

    # compare
    print("type:", y_tri.dtype, y_ref.dtype)
    print("max delta:", torch.max(torch.abs(y_tri - y_ref)))
    assert torch.allclose(y_tri, y_ref, atol=1e-2, rtol=0)
    return


# test_gelu_and_mul(16, 4096, torch.float16, device='cuda')
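
Usage note: the kernel expects its input laid out as [gate | up] along the last dimension (matching gate_up_proj = cat([gate_proj, up_proj], dim=1) in the weight loader) and overwrites the gate half with the result, returning a view of it, so the original buffer should not be reused afterwards. A quick check of that contract, assuming gelu_and_mul_fwd is imported from this module (sizes are arbitrary):

    import torch

    x = torch.randn(4, 2 * 8, dtype=torch.float16, device="cuda")  # [M, 2*N], gate | up
    out = gelu_and_mul_fwd(x)               # shape [4, 8], dtype float16
    assert out.data_ptr() == x.data_ptr()   # out is a view into the first half of x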
