Commit c8160a4

add high performance layernorm triton kernels. (#432)
1 parent 62c006c commit c8160a4

File tree

7 files changed: 380 additions, 12 deletions

README.md (+2, -1)

@@ -46,9 +46,10 @@ LightLLM is a Python-based LLM (Large Language Model) inference and serving fram
 - [Qwen-VL-Chat](https://huggingface.co/Qwen/Qwen-VL-Chat)
 - [Llava-7b](https://huggingface.co/liuhaotian/llava-v1.5-7b)
 - [Llava-13b](https://huggingface.co/liuhaotian/llava-v1.5-13b)
-- [Mixtral]()
+- [Mixtral](https://huggingface.co/mistralai/Mixtral-8x7B-Instruct-v0.1)
 - [Stablelm](https://huggingface.co/stabilityai/stablelm-2-1_6b)
 - [MiniCPM](https://huggingface.co/openbmb/MiniCPM-2B-sft-bf16)
+- [CohereForAI](https://huggingface.co/CohereForAI/c4ai-command-r-plus)
 
 > When you start Qwen-7b, you need to set the parameter '--eos_id 151643 --trust_remote_code'.
 

lightllm/common/basemodel/layer_infer/template/transformer_layer_infer_cohere_template.py (+6, -1)

@@ -55,7 +55,12 @@ def _get_qkv(
         k = cache_kv[:, 0 : self.tp_k_head_num_, :]
         q = self._q_norm(q, infer_state, layer_weight)
         cache_kv[:, 0 : self.tp_k_head_num_, :] = self._k_norm(k, infer_state, layer_weight)
-        self._rotary_emb_fwd(q, cache_kv, infer_state.position_cos, infer_state.position_sin)
+        self._rotary_emb_fwd(
+            q.view(-1, self.tp_q_head_num_, self.head_dim_),
+            cache_kv[:, 0 : self.tp_k_head_num_, :],
+            infer_state.position_cos,
+            infer_state.position_sin,
+        )
         return q, cache_kv
 
     def _context_attention_kernel(self, q, kv, infer_state: InferStateInfo, layer_weight, out=None) -> torch.Tensor:
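
The template now hands `_rotary_emb_fwd` a per-head view of q instead of the flattened projection, so q and the k slice of `cache_kv` arrive with the same 3-D [token_num, head_num, head_dim] layout that the new kernels operate on. The following is a minimal shape sketch, not LightLLM code; the tensor sizes are made up for illustration.

import torch

token_num, head_num, head_dim = 8, 32, 128                  # hypothetical sizes
q = torch.randn(token_num, head_num * head_dim)             # flattened q projection output
cache_kv = torch.randn(token_num, 2 * head_num, head_dim)   # k and v stacked along dim 1 (assumed layout)

q_3d = q.view(-1, head_num, head_dim)   # the view the template now passes to _rotary_emb_fwd
k_3d = cache_kv[:, 0:head_num, :]       # the k slice, already [token_num, head_num, head_dim]

assert q_3d.shape == (token_num, head_num, head_dim)
assert k_3d.shape == (token_num, head_num, head_dim)
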

lightllm/models/cohere/layer_infer/post_layer_infer.py (+4, -2)

@@ -4,7 +4,7 @@
 
 from lightllm.models.cohere.infer_struct import CohereInferStateInfo
 from lightllm.models.cohere.layer_weights.pre_and_post_layer_weight import CoherePreAndPostLayerWeight
-from lightllm.models.cohere.triton_kernels.layernorm import layernorm_forward, multi_head_layernorm_forward
+from lightllm.models.cohere.triton_kernels.layernorm import layernorm_forward
 from lightllm.common.basemodel.layer_weights.base_layer_weight import BaseLayerWeight
 from lightllm.common.basemodel.splitfuse_infer_struct import SplitFuseInferStateInfo
 
@@ -22,7 +22,9 @@ def __init__(self, tp_rank, world_size, network_config, mode):
         return
 
     def _norm(self, input, infer_state, layer_weight: CoherePreAndPostLayerWeight) -> torch.Tensor:
-        return layernorm_forward(input, layer_weight.final_norm_weight_, eps=self.eps_)
+        return layernorm_forward(
+            input.unsqueeze(1), layer_weight.final_norm_weight_.unsqueeze(0), eps=self.eps_
+        ).squeeze(1)
 
     def _slice_get_last_input(self, input_embdings, infer_state: CohereInferStateInfo):
         if infer_state.is_splitfuse:
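
The final-norm activations and weight are 2-D, while the new Triton kernel expects X of shape [N, head_num, head_dim] and W of shape [head_num, head_dim]; the unsqueeze/squeeze pair simply adds and removes a dummy head axis around the call. A minimal usage sketch (sizes and eps assumed, CUDA device required) that checks the wrapped call against the eager torch_layernorm reference shipped in the same module:

import torch
from lightllm.models.cohere.triton_kernels.layernorm import layernorm_forward, torch_layernorm

token_num, hidden_dim = 4, 512               # hypothetical sizes
x = torch.randn(token_num, hidden_dim, dtype=torch.float16, device="cuda")
w = torch.rand(hidden_dim, dtype=torch.float16, device="cuda")

# Dummy head axis in, dummy head axis out: X becomes [N, 1, D], W becomes [1, D].
y = layernorm_forward(x.unsqueeze(1), w.unsqueeze(0), eps=1e-5).squeeze(1)
y_ref = torch_layernorm(x, w, 1e-5)          # eager reference from the same module
assert torch.allclose(y, y_ref, atol=1e-2, rtol=0)
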

lightllm/models/cohere/layer_infer/transformer_layer_infer.py (+5, -5)

@@ -6,9 +6,9 @@
 )
 from lightllm.models.cohere.infer_struct import CohereInferStateInfo
 from lightllm.models.cohere.layer_weights.transformer_layer_weight import CohereTransformerLayerWeight
-from lightllm.models.cohere.triton_kernels.layernorm import layernorm_forward, multi_head_layernorm_forward
+from lightllm.models.cohere.triton_kernels.layernorm import layernorm_forward, torch_layernorm
+from lightllm.models.cohere.triton_kernels.rotary_emb import rotary_emb_fwd
 from lightllm.models.llama.layer_infer.transformer_layer_infer import LlamaTransformerLayerInfer
-from lightllm.models.llama.triton_kernel.rotary_emb import rotary_emb_fwd
 from lightllm.models.llama.triton_kernel.silu_and_mul import silu_and_mul_fwd
 
 
@@ -42,13 +42,13 @@ def _bind_rotary_emb_fwd(self):
         self._rotary_emb_fwd = partial(CohereTransformerLayerInfer._rotary_emb_fwd, self)
 
     def _att_norm(self, input, infer_state, layer_weight):
-        return layernorm_forward(input, layer_weight.att_norm_weight_, self.eps_)
+        return layernorm_forward(input.unsqueeze(1), layer_weight.att_norm_weight_.unsqueeze(0), self.eps_).squeeze(1)
 
     def _q_norm(self, input, infer_state, layer_weight):
-        return multi_head_layernorm_forward(input, layer_weight.q_norm_weight_, self.eps_)
+        return layernorm_forward(input, layer_weight.q_norm_weight_, self.eps_)
 
     def _k_norm(self, input, infer_state, layer_weight):
-        return multi_head_layernorm_forward(input, layer_weight.k_norm_weight_, self.eps_)
+        return layernorm_forward(input, layer_weight.k_norm_weight_, self.eps_)
 
     def _bind_norm(self):
         self._att_norm = partial(CohereTransformerLayerInfer._att_norm, self)
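
For q and k the norm weights are genuinely per-head ([head_num, head_dim]), so each head is normalized with its own weight row and the old multi_head_layernorm_forward helper is replaced by a direct call to the Triton kernel. A small sketch (head counts and eps are assumed, CUDA required) of that per-head call:

import torch
from lightllm.models.cohere.triton_kernels.layernorm import layernorm_forward, torch_layernorm

token_num, q_head_num, head_dim = 16, 32, 128   # hypothetical sizes; eps assumed to be 1e-5
q = torch.randn(token_num, q_head_num, head_dim, dtype=torch.float16, device="cuda")
q_norm_weight = torch.rand(q_head_num, head_dim, dtype=torch.float16, device="cuda")

# The kernel launches one program per (token, head); each head uses its own weight row.
q_normed = layernorm_forward(q, q_norm_weight, 1e-5)
assert torch.allclose(q_normed, torch_layernorm(q, q_norm_weight, 1e-5), atol=1e-2, rtol=0)
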

lightllm/models/cohere/model.py (+45)

@@ -1,7 +1,10 @@
+import os
+import torch
 from lightllm.common.basemodel.basemodel import TpPartBaseModel
 from lightllm.common.basemodel.layer_infer.template.transformer_layer_infer_cohere_template import (
     TransformerLayerCohereInferTpl,
 )
+from lightllm.common.mem_manager import MemoryManager
 from lightllm.models.cohere.infer_struct import CohereInferStateInfo
 from lightllm.models.cohere.layer_infer.post_layer_infer import CoherePostLayerInfer
 from lightllm.models.cohere.layer_infer.transformer_layer_infer import CohereTransformerLayerInfer
@@ -10,6 +13,9 @@
 from lightllm.models.cohere.splitfuse_infer_struct import CohereSplitFuseInferStateInfo
 from lightllm.models.llama.layer_infer.pre_layer_infer import LlamaPreLayerInfer
 from lightllm.models.llama.model import LlamaTpPartModel
+from lightllm.utils.log_utils import init_logger
+
+logger = init_logger(__name__)
 
 
 class CohereTpPartModel(LlamaTpPartModel):
@@ -22,3 +28,42 @@ class CohereTpPartModel(LlamaTpPartModel):
 
     infer_state_class = CohereInferStateInfo
     splitfuse_infer_state_class = CohereSplitFuseInferStateInfo
+
+    def _init_to_get_rotary(self, default_base=10000):
+        partial_head_dim = int(self.config.get("partial_rotary_factor", 1) * self.head_dim_)
+        if self.config.get("rope_scaling", {}) is None:
+            rope_scaling_factor = 1.0
+        else:
+            rope_scaling_factor = self.config.get("rope_scaling", {}).get("factor", 1.0)
+
+        base = self.config.get("rope_theta", float(default_base))
+
+        if "max_sequence_length" in self.config:
+            max_seq_len = self.config["max_sequence_length"]
+        else:
+            max_position_embeddings = self.config.get(
+                "max_position_embeddings", 2048 if base <= 10000.0 + 1e-5 else 16384
+            )
+            max_seq_len = max_position_embeddings * rope_scaling_factor
+
+        # NTK
+        try:
+            ntk_alpha = float(os.environ.get("LIGHTLLM_NTK_ALPHA", 1))
+            assert ntk_alpha >= 1
+            if ntk_alpha > 1:
+                logger.info(f"Note: NTK enabled, alpha set to {ntk_alpha}")
+            max_seq_len *= ntk_alpha
+            base = base * (ntk_alpha ** (partial_head_dim / (partial_head_dim - 2)))  # Base change formula
+        except:
+            pass
+
+        inv_freq = 1.0 / (
+            base ** (torch.arange(0, partial_head_dim, 2, device="cpu", dtype=torch.float32) / partial_head_dim)
+        )
+        t = torch.arange(max_seq_len + 1024 * 128, device="cpu", dtype=torch.float32) / rope_scaling_factor
+        freqs = torch.outer(t, inv_freq)
+        freqs = torch.repeat_interleave(freqs, 2, dim=-1)
+
+        self._cos_cached = torch.cos(freqs).to(self.data_type).cuda()
+        self._sin_cached = torch.sin(freqs).to(self.data_type).cuda()
+        return
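
This override differs from the inherited Llama rotary cache in one key way: the frequencies are repeated pair-wise with torch.repeat_interleave instead of being concatenated as two halves, matching Cohere's interleaved rotary layout, and the optional NTK path rescales the base as base * alpha^(d/(d-2)). A tiny numeric sketch (sizes chosen purely for illustration) of the resulting cache layout:

import torch

head_dim, base, seq_len = 8, 10000.0, 4      # tiny, purely illustrative sizes
inv_freq = 1.0 / (base ** (torch.arange(0, head_dim, 2, dtype=torch.float32) / head_dim))
t = torch.arange(seq_len, dtype=torch.float32)
freqs = torch.repeat_interleave(torch.outer(t, inv_freq), 2, dim=-1)

print(freqs.shape)   # torch.Size([4, 8]): cos/sin caches cover the full head_dim
print(freqs[1])      # position 1: pattern f0, f0, f1, f1, f2, f2, f3, f3 (interleaved pairs)
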
lightllm/models/cohere/triton_kernels/layernorm.py
@@ -1,15 +1,131 @@
 import torch
+import triton
+import triton.language as tl
 
+
+# LayerNorm adapted from triton tutorial, used for Cohere q, k norm
+# X [N, head_num, head_dim]
+# W [head_num, head_dim]
+@triton.jit
+def _layer_norm_fwd_kernel(
+    X,  # pointer to the input
+    W,  # pointer to the weights
+    Y,
+    stride_x_N,
+    stride_x_hn,
+    stride_x_hd,
+    stride_y_N,
+    stride_y_hn,
+    stride_y_hd,
+    stride_w_hn,
+    stride_w_hd,
+    N,  # number of columns in X
+    eps,  # epsilon to avoid division by zero
+    BLOCK_SIZE: tl.constexpr,
+):
+    Seq = tl.program_id(0)
+    H = tl.program_id(1)
 
-def layernorm_forward(x, weight, eps):
-    return torch.layer_norm(x, (x.shape[-1],), weight, bias=None, eps=eps)
+    X += Seq * stride_x_N + H * stride_x_hn
+    Y += Seq * stride_y_N + H * stride_y_hn
+    W += H * stride_w_hn
 
+    _mean = tl.zeros([BLOCK_SIZE], dtype=tl.float32)
+    for off in range(0, N, BLOCK_SIZE):
+        cols = off + tl.arange(0, BLOCK_SIZE)
+        a = tl.load(X + cols, mask=cols < N, other=0.0).to(tl.float32)
+        _mean += a
+    mean = tl.sum(_mean, axis=0) / N
 
-def multi_head_layernorm_forward(x, weight, eps):
+    _var = tl.zeros([BLOCK_SIZE], dtype=tl.float32)
+    for off in range(0, N, BLOCK_SIZE):
+        cols = off + tl.arange(0, BLOCK_SIZE)
+        x = tl.load(X + cols, mask=cols < N, other=0.0).to(tl.float32)
+        x = tl.where(cols < N, x - mean, 0.0)
+        _var += x * x
+    var = tl.sum(_var, axis=0) / N
+    rstd = 1 / tl.sqrt(var + eps)
+
+    for off in range(0, N, BLOCK_SIZE):
+        cols = off + tl.arange(0, BLOCK_SIZE)
+        mask = cols < N
+        w = tl.load(W + cols, mask=mask).to(tl.float32)
+        x = tl.load(X + cols, mask=mask, other=0.0).to(tl.float32)
+        x_hat = (x - mean) * rstd
+        y = x_hat * w
+
+        tl.store(Y + cols, y.to(X.dtype.element_ty), mask=mask)
+
+
+def layernorm_forward(
+    X,  # pointer to the input
+    W,  # pointer to the weights
+    eps,  # epsilon to avoid division by zero
+):
+    assert len(X.shape) == 3
+    assert len(W.shape) == 2
+    assert X.shape[-1] == W.shape[-1]
+    assert X.shape[-2] == W.shape[-2]
+
+    y = torch.empty_like(X)
+
+    stride_x_N = X.stride(0)
+    stride_x_hn = X.stride(1)
+    stride_x_hd = X.stride(2)
+
+    stride_y_N = y.stride(0)
+    stride_y_hn = y.stride(1)
+    stride_y_hd = y.stride(2)
+
+    stride_w_hn = W.stride(0)
+    stride_w_hd = W.stride(1)
+
+    N = X.shape[-1]
+    BLOCK_SIZE = 128
+
+    grid = (X.shape[0], X.shape[1])
+    _layer_norm_fwd_kernel[grid](
+        X,
+        W,
+        y,
+        stride_x_N,
+        stride_x_hn,
+        stride_x_hd,
+        stride_y_N,
+        stride_y_hn,
+        stride_y_hd,
+        stride_w_hn,
+        stride_w_hd,
+        N,
+        eps,
+        BLOCK_SIZE,
+    )
+
+    return y
+
+
+def torch_layernorm(x, weight, eps):
     inp_dtype = x.dtype
     x = x.to(torch.float32)
     mean = x.mean(-1, keepdim=True)
     variance = (x - mean).pow(2).mean(-1, keepdim=True)
     x = (x - mean) * torch.rsqrt(variance + eps)
     x = weight.to(torch.float32) * x
     return x.to(inp_dtype)
+
+
+def test_layernorm(eps=1e-5):
+    # create data
+    dtype = torch.float16
+    x_shape = (5, 1, 128)
+    w_shape = (x_shape[-2], x_shape[-1])
+    weight = torch.rand(w_shape, dtype=dtype, device="cuda")
+    x = -2.3 + 0.5 * torch.randn(x_shape, dtype=dtype, device="cuda")
+    # forward pass
+    y_ref = torch_layernorm(x, weight, eps).to(dtype)
+    y_out = layernorm_forward(x, weight, eps)
+
+    # compare
+    print("type:", y_out.dtype, y_ref.dtype)
+    print("max delta:", torch.max(torch.abs(y_out - y_ref)))
+    assert torch.allclose(y_out, y_ref, atol=1e-2, rtol=0)
+    return
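
Since the point of the commit is performance, a quick sanity check is to time the Triton kernel against the eager torch_layernorm fallback from the same module. This is a rough benchmark sketch, not part of the PR: the tensor sizes are assumed, and it relies on triton.testing.do_bench under the assumption that it returns a time in milliseconds.

import torch
import triton
from lightllm.models.cohere.triton_kernels.layernorm import layernorm_forward, torch_layernorm

token_num, head_num, head_dim = 1024, 64, 128   # hypothetical Cohere-like q shape
x = torch.randn(token_num, head_num, head_dim, dtype=torch.float16, device="cuda")
w = torch.rand(head_num, head_dim, dtype=torch.float16, device="cuda")

ms_triton = triton.testing.do_bench(lambda: layernorm_forward(x, w, 1e-5))
ms_eager = triton.testing.do_bench(lambda: torch_layernorm(x, w, 1e-5))
print(f"triton: {ms_triton:.3f} ms  |  eager: {ms_eager:.3f} ms")
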
