
Commit 7ae0119

change layernorm kernel
Parent: 60bb56b

File tree: 4 files changed (+26 −17 lines)


lightllm/models/cohere/layer_infer/post_layer_infer.py

+2 −3

@@ -22,10 +22,9 @@ def __init__(self, tp_rank, world_size, network_config, mode):
         return
 
     def _norm(self, input, infer_state, layer_weight: CoherePreAndPostLayerWeight) -> torch.Tensor:
-        input = layernorm_forward(
+        return layernorm_forward(
             input.unsqueeze(1), layer_weight.final_norm_weight_.unsqueeze(0), eps=self.eps_
-        ).squeeze_(1)
-        return input
+        ).squeeze(1)
 
     def _slice_get_last_input(self, input_embdings, infer_state: CohereInferStateInfo):
         if infer_state.is_splitfuse:
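
Since the kernel now writes to a freshly allocated output tensor (see the layernorm.py diff below), the caller no longer needs the in-place `squeeze_`; a plain `squeeze` just returns a view of the new tensor. A minimal standalone sketch of the difference, in plain PyTorch rather than lightllm code:

import torch

t = torch.zeros(4, 1, 8)
v = t.squeeze(1)    # out-of-place: v is a (4, 8) view, t keeps shape (4, 1, 8)
u = torch.zeros(4, 1, 8)
u.squeeze_(1)       # in-place: u itself becomes (4, 8)
print(t.shape, v.shape, u.shape)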

lightllm/models/cohere/layer_infer/transformer_layer_infer.py

+4 −7

@@ -6,7 +6,7 @@
 )
 from lightllm.models.cohere.infer_struct import CohereInferStateInfo
 from lightllm.models.cohere.layer_weights.transformer_layer_weight import CohereTransformerLayerWeight
-from lightllm.models.cohere.triton_kernels.layernorm import layernorm_forward
+from lightllm.models.cohere.triton_kernels.layernorm import layernorm_forward, torch_layernorm
 from lightllm.models.cohere.triton_kernels.rotary_emb import rotary_emb_fwd
 from lightllm.models.llama.layer_infer.transformer_layer_infer import LlamaTransformerLayerInfer
 from lightllm.models.llama.triton_kernel.silu_and_mul import silu_and_mul_fwd

@@ -42,16 +42,13 @@ def _bind_rotary_emb_fwd(self):
         self._rotary_emb_fwd = partial(CohereTransformerLayerInfer._rotary_emb_fwd, self)
 
     def _att_norm(self, input, infer_state, layer_weight):
-        input = layernorm_forward(input.unsqueeze(1), layer_weight.att_norm_weight_.unsqueeze(0), self.eps_).squeeze_(1)
-        return input
+        return layernorm_forward(input.unsqueeze(1), layer_weight.att_norm_weight_.unsqueeze(0), self.eps_).squeeze(1)
 
     def _q_norm(self, input, infer_state, layer_weight):
-        input = layernorm_forward(input, layer_weight.q_norm_weight_, self.eps_)
-        return input
+        return layernorm_forward(input, layer_weight.q_norm_weight_, self.eps_)
 
     def _k_norm(self, input, infer_state, layer_weight):
-        input = layernorm_forward(input, layer_weight.k_norm_weight_, self.eps_)
-        return input
+        return layernorm_forward(input, layer_weight.k_norm_weight_, self.eps_)
 
     def _bind_norm(self):
         self._att_norm = partial(CohereTransformerLayerInfer._att_norm, self)
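
The norm helpers are rebound per instance via `functools.partial`, as the `_bind_norm` hunk shows. A standalone sketch of that binding pattern (hypothetical `Demo` class, not lightllm code):

from functools import partial

class Demo:
    def _att_norm(self, x):
        # stand-in for the real layernorm call
        return x

    def _bind_norm(self):
        # same pattern as CohereTransformerLayerInfer._bind_norm:
        # bind the unbound class function to this instance
        self._att_norm = partial(Demo._att_norm, self)

d = Demo()
d._bind_norm()
print(d._att_norm(3))  # 3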

lightllm/models/cohere/model.py

+1 −0

@@ -4,6 +4,7 @@
 from lightllm.common.basemodel.layer_infer.template.transformer_layer_infer_cohere_template import (
     TransformerLayerCohereInferTpl,
 )
+from lightllm.common.mem_manager import MemoryManager
 from lightllm.models.cohere.infer_struct import CohereInferStateInfo
 from lightllm.models.cohere.layer_infer.post_layer_infer import CoherePostLayerInfer
 from lightllm.models.cohere.layer_infer.transformer_layer_infer import CohereTransformerLayerInfer

lightllm/models/cohere/triton_kernels/layernorm.py

+19 −7

@@ -13,6 +13,9 @@ def _layer_norm_fwd_kernel(
     stride_x_N,
     stride_x_hn,
     stride_x_hd,
+    stride_y_N,
+    stride_y_hn,
+    stride_y_hd,
     stride_w_hn,
     stride_w_hd,
     N,  # number of columns in X

@@ -23,7 +26,7 @@ def _layer_norm_fwd_kernel(
     H = tl.program_id(1)
 
     X += Seq * stride_x_N + H * stride_x_hn
-    Y += Seq * stride_x_N + H * stride_x_hn
+    Y += Seq * stride_y_N + H * stride_y_hn
     W += H * stride_w_hn
 
     _mean = tl.zeros([BLOCK_SIZE], dtype=tl.float32)

@@ -63,32 +66,41 @@ def layernorm_forward(
     assert X.shape[-1] == W.shape[-1]
     assert X.shape[-2] == W.shape[-2]
 
+    y = torch.empty_like(X)
+
     stride_x_N = X.stride(0)
     stride_x_hn = X.stride(1)
     stride_x_hd = X.stride(2)
+
+    stride_y_N = y.stride(0)
+    stride_y_hn = y.stride(1)
+    stride_y_hd = y.stride(2)
+
     stride_w_hn = W.stride(0)
     stride_w_hd = W.stride(1)
+
     N = X.shape[-1]
     BLOCK_SIZE = 128
 
-    Y = torch.empty_like(X)
-
     grid = (X.shape[0], X.shape[1])
     _layer_norm_fwd_kernel[grid](
         X,
         W,
-        Y,
+        y,
         stride_x_N,
         stride_x_hn,
         stride_x_hd,
+        stride_y_N,
+        stride_y_hn,
+        stride_y_hd,
         stride_w_hn,
         stride_w_hd,
         N,
         eps,
         BLOCK_SIZE,
     )
 
-    return Y
+    return y
 
 
 def torch_layernorm(x, weight, eps):

@@ -104,12 +116,12 @@ def torch_layernorm(x, weight, eps):
 def test_layernorm(eps=1e-5):
     # create data
     dtype = torch.float16
-    x_shape = (1000, 1, 128)
+    x_shape = (5, 1, 128)
     w_shape = (x_shape[-2], x_shape[-1])
     weight = torch.rand(w_shape, dtype=dtype, device="cuda")
     x = -2.3 + 0.5 * torch.randn(x_shape, dtype=dtype, device="cuda")
     # forward pass
-    y_ref = torch_layernorm(x.to(torch.float32), weight.to(torch.float32), eps).to(dtype)
+    y_ref = torch_layernorm(x, weight, eps).to(dtype)
     y_out = layernorm_forward(x, weight, eps)
 
     # compare
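
Taken together, these hunks make `layernorm_forward` genuinely out-of-place: the output pointer is now offset with the output tensor's own strides (the old code reused `stride_x_N`/`stride_x_hn`, which is only safe when input and output strides coincide), and the kernel writes into a fresh `y` instead of aliasing the input's layout. A quick sanity check mirroring `test_layernorm` (shapes taken from the test; treating the printed error as roughly fp16-level is an assumption):

import torch
from lightllm.models.cohere.triton_kernels.layernorm import layernorm_forward, torch_layernorm

eps = 1e-5
dtype = torch.float16
x = -2.3 + 0.5 * torch.randn((5, 1, 128), dtype=dtype, device="cuda")
weight = torch.rand((1, 128), dtype=dtype, device="cuda")

y_ref = torch_layernorm(x, weight, eps).to(dtype)  # eager PyTorch reference
y_out = layernorm_forward(x, weight, eps)          # Triton kernel, out-of-place

print(torch.max(torch.abs(y_out - y_ref)))         # expect a small fp16-level gap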
