
Commit 859a48b

add mode ppl_int4kv_flashdecoding (#367)
1 parent 986f93d commit 859a48b

File tree

6 files changed (+278, -19 lines)

lightllm/common/mem_utils.py (+4)

@@ -1,6 +1,7 @@
 from lightllm.common.mem_manager import MemoryManager
 from lightllm.common.int8kv_mem_manager import INT8KVMemoryManager
 from lightllm.common.ppl_int8kv_mem_manager import PPLINT8KVMemoryManager
+from lightllm.common.ppl_int4kv_mem_manager import PPLINT4KVMemoryManager
 from lightllm.utils.log_utils import init_logger
 
 logger = init_logger(__name__)
@@ -11,6 +12,9 @@ def select_mem_manager_class(mode):
     if "ppl_int8kv" in mode or "ppl_int8kv_flashdecoding" in mode:
         memory_manager_class = PPLINT8KVMemoryManager
         logger.info(f"Model kv cache using mode {mode}")
+    elif "ppl_int4kv_flashdecoding" in mode:
+        memory_manager_class = PPLINT4KVMemoryManager
+        logger.info(f"Model kv cache using mode {mode}")
     elif "triton_int8kv" in mode:
         memory_manager_class = INT8KVMemoryManager
         logger.info("Model kv cache using mode triton int8kv")
lightllm/common/ppl_int4kv_mem_manager.py (new file, +22)

@@ -0,0 +1,22 @@
+import torch
+
+from .mem_manager import MemoryManager
+
+
+class PPLINT4KVMemoryManager(MemoryManager):
+    def __init__(self, size, dtype, head_num, head_dim, layer_num, always_copy=True):
+        super().__init__(size, dtype, head_num, head_dim, layer_num, always_copy=True)
+
+    def _init_buffers(self, size, dtype, head_num, head_dim, layer_num):
+        group_quant_size = 8
+        self.kv_buffer = [
+            torch.empty((size, 2 * head_num, head_dim // 2), dtype=torch.int8, device="cuda") for _ in range(layer_num)
+        ]
+        self.scale_buffer = [
+            torch.empty((size, 2 * head_num, head_dim // group_quant_size), dtype=dtype, device="cuda")
+            for _ in range(layer_num)
+        ]
+
+    def _free_buffers(self):
+        self.kv_buffer = None
+        self.scale_buffer = None
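
For orientation, not part of the commit: the buffer shapes above store, per token and per layer, head_dim // 2 int8 bytes of packed int4 values plus one fp16 scale for every group of 8 elements, for both K and V heads. A back-of-the-envelope sketch with illustrative head_num and head_dim values:

# Hedged sketch, not from the repo: per-token, per-layer KV-cache bytes implied by
# the buffer shapes above, with illustrative head_num=32, head_dim=128.
head_num, head_dim = 32, 128
group_quant_size = 8  # eight values share one scale, as in _init_buffers

int4_payload = 2 * head_num * (head_dim // 2)                    # packed int4 pairs, 1 byte each
scale_bytes = 2 * head_num * (head_dim // group_quant_size) * 2  # fp16 scales, 2 bytes each
fp16_baseline = 2 * head_num * head_dim * 2                      # unquantized fp16 K and V

print(int4_payload + scale_bytes, fp16_baseline)  # 6144 vs 16384 bytes, roughly 2.7x smaller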

lightllm/models/llama/layer_infer/transformer_layer_infer.py (+30)

@@ -63,6 +63,11 @@ def _bind_attention(self):
                 LlamaTransformerLayerInfer._token_decode_attention_ppl_int8kv_flashdecoding, self
             )
             self._copy_kv_to_mem_cache = partial(LlamaTransformerLayerInfer._copy_kv_to_mem_cache_ppl_int8kv, self)
+        elif "ppl_int4kv_flashdecoding" in self.mode:
+            self._token_attention_kernel = partial(
+                LlamaTransformerLayerInfer._token_decode_attention_ppl_int4kv_flashdecoding, self
+            )
+            self._copy_kv_to_mem_cache = partial(LlamaTransformerLayerInfer._copy_kv_to_mem_cache_ppl_int4kv, self)
         elif "ppl_fp16" in self.mode:
             self._token_attention_kernel = partial(LlamaTransformerLayerInfer._token_decode_attention_ppl_fp16, self)
             self._copy_kv_to_mem_cache = partial(LlamaTransformerLayerInfer._copy_kv_to_mem_cache_normal, self)
@@ -298,6 +303,14 @@ def _copy_kv_to_mem_cache_ppl_int8kv(self, buffer, mem_index, mem_manager):
         )
         return
 
+    def _copy_kv_to_mem_cache_ppl_int4kv(self, buffer, mem_index, mem_manager):
+        from lightllm.models.llama.triton_kernel.ppl_int4kv_copy_kv import destindex_copy_int4kv
+
+        destindex_copy_int4kv(
+            buffer, mem_index, mem_manager.kv_buffer[self.layer_num_], mem_manager.scale_buffer[self.layer_num_]
+        )
+        return
+
     def _token_decode_attention_normal(self, q, infer_state: LlamaInferStateInfo, layer_weight, out=None):
         total_token_num = infer_state.total_token_num
         batch_size = infer_state.batch_size
@@ -524,3 +537,20 @@ def _token_decode_attention_ppl_int8kv_flashdecoding(
         return token_decode_attention_flash_decoding(
             q, infer_state, self.tp_q_head_num_, self.head_dim_, cache_k, cache_k_scale, cache_v, cache_v_scale, out=out
         )
+
+    def _token_decode_attention_ppl_int4kv_flashdecoding(
+        self, q, infer_state: LlamaInferStateInfo, layer_weight, out=None
+    ):
+        from lightllm.models.llama.triton_kernel.ppl_int4kv_flash_decoding import token_decode_attention_flash_decoding
+
+        cache_k = infer_state.mem_manager.kv_buffer[self.layer_num_][:, 0 : self.tp_k_head_num_, :]
+        cache_k_scale = infer_state.mem_manager.scale_buffer[self.layer_num_][:, 0 : self.tp_k_head_num_, :]
+        cache_v = infer_state.mem_manager.kv_buffer[self.layer_num_][
+            :, self.tp_k_head_num_ : self.tp_k_head_num_ + self.tp_v_head_num_, :
+        ]
+        cache_v_scale = infer_state.mem_manager.scale_buffer[self.layer_num_][
+            :, self.tp_k_head_num_ : self.tp_k_head_num_ + self.tp_v_head_num_, :
+        ]
+        return token_decode_attention_flash_decoding(
+            q, infer_state, self.tp_q_head_num_, self.head_dim_, cache_k, cache_k_scale, cache_v, cache_v_scale, out=out
+        )
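
The int4 decode path above slices K and V views out of the shared per-layer buffers, whose second dimension lays out the K heads first and the V heads after them (the 2 * head_num layout from PPLINT4KVMemoryManager). A small sketch of that slicing with illustrative sizes, not taken from the repo:

import torch

# Hedged sketch: mirrors the cache_k / cache_v slicing above on toy-sized buffers.
size, head_num, head_dim = 4, 8, 128
kv_buffer = torch.empty((size, 2 * head_num, head_dim // 2), dtype=torch.int8)
scale_buffer = torch.empty((size, 2 * head_num, head_dim // 8), dtype=torch.float16)

tp_k_head_num = tp_v_head_num = head_num  # single-GPU case, no tensor-parallel split
cache_k = kv_buffer[:, 0:tp_k_head_num, :]
cache_v = kv_buffer[:, tp_k_head_num : tp_k_head_num + tp_v_head_num, :]
cache_k_scale = scale_buffer[:, 0:tp_k_head_num, :]
cache_v_scale = scale_buffer[:, tp_k_head_num : tp_k_head_num + tp_v_head_num, :]
assert cache_k.shape == cache_v.shape == (size, head_num, head_dim // 2)
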
lightllm/models/llama/triton_kernel/ppl_int4kv_copy_kv.py (new file, +134)

@@ -0,0 +1,134 @@
+import torch
+
+import triton
+import triton.language as tl
+
+
+@triton.jit
+def _fwd_kernel_destindex_copy_quantize_int4_kv(
+    K,
+    Dest_loc,
+    Out,
+    Out_scale,
+    stride_k_bs,
+    stride_k_h,
+    stride_k_g,
+    stride_k_d,
+    stride_o_bs,
+    stride_o_h,
+    stride_o_g,
+    stride_o_d,
+    stride_os_bs,
+    stride_os_h,
+    stride_os_g,
+    group_size,
+    BLOCK_GROUP_NUM: tl.constexpr,
+    BLOCK_GROUP_DIM: tl.constexpr,
+):
+    cur_index = tl.program_id(0)
+    cur_head = tl.program_id(1)
+
+    offs_g = tl.arange(0, BLOCK_GROUP_NUM)
+    offs_d = tl.arange(0, BLOCK_GROUP_DIM // 2)
+
+    dest_index = tl.load(Dest_loc + cur_index)
+
+    src_data_0 = tl.load(
+        K + cur_index * stride_k_bs + cur_head * stride_k_h + offs_g[:, None] * stride_k_g + offs_d[None, :] * 2,
+        mask=offs_g[:, None] < group_size,
+        other=0.0,
+    )
+    src_data_1 = tl.load(
+        K + cur_index * stride_k_bs + cur_head * stride_k_h + offs_g[:, None] * stride_k_g + offs_d[None, :] * 2 + 1,
+        mask=offs_g[:, None] < group_size,
+        other=0.0,
+    )
+
+    abs_data_0 = tl.abs(src_data_0)
+    abs_data_1 = tl.abs(src_data_1)
+
+    data_scale = (tl.maximum(tl.max(abs_data_0, axis=1), tl.max(abs_data_1, axis=1)) / 7.0).to(tl.float16)
+    q_src_data_0 = (src_data_0 / data_scale[:, None]).to(tl.int8)
+    q_src_data_0 = tl.where(q_src_data_0 > 7, 7, q_src_data_0)
+    q_src_data_0 = tl.where(q_src_data_0 < -7, -7, q_src_data_0)
+
+    q_src_data_1 = (src_data_1 / data_scale[:, None]).to(tl.int8)
+    q_src_data_1 = tl.where(q_src_data_1 > 7, 7, q_src_data_1)
+    q_src_data_1 = tl.where(q_src_data_1 < -7, -7, q_src_data_1)
+
+    low_4 = ((q_src_data_0 & 0x80) >> 4) | (q_src_data_0 & 0xF)
+    high_4 = (((q_src_data_1 & 0x80) >> 4) | (q_src_data_1 & 0xF)) << 4
+
+    # tl.device_print(low_4)
+    # tl.device_print(high_4)
+
+    out_data = low_4 | high_4
+
+    o_ptrs = Out + dest_index * stride_o_bs + cur_head * stride_o_h + offs_g[:, None] * stride_o_g + offs_d[None, :]
+    os_ptrs = Out_scale + dest_index * stride_os_bs + cur_head * stride_os_h + offs_g
+    tl.store(o_ptrs, out_data, mask=offs_g[:, None] < group_size)
+    tl.store(os_ptrs, data_scale, mask=offs_g < group_size)
+    return
+
+
+@torch.no_grad()
+def destindex_copy_int4kv(K, DestLoc, Out, Out_scale):
+    seq_len = DestLoc.shape[0]
+    head_num = K.shape[1]
+    head_dim = K.shape[2]
+    quant_group_dim = 8
+
+    assert head_dim % quant_group_dim == 0, "error head dim, can not been supported to copy quant kv"
+    grid = (seq_len, head_num)
+    num_warps = 1
+
+    group_size = head_dim // quant_group_dim
+    group_dim = quant_group_dim
+
+    K = K.view((K.shape[0], K.shape[1], group_size, group_dim))
+    Out = Out.view(
+        Out.shape[0], Out.shape[1], group_size, group_dim // 2
+    )  # Out is int8; two int4 values are packed into one int8, hence group_dim // 2
+
+    _fwd_kernel_destindex_copy_quantize_int4_kv[grid](
+        K,
+        DestLoc,
+        Out,
+        Out_scale,
+        K.stride(0),
+        K.stride(1),
+        K.stride(2),
+        K.stride(3),
+        Out.stride(0),
+        Out.stride(1),
+        Out.stride(2),
+        Out.stride(3),
+        Out_scale.stride(0),
+        Out_scale.stride(1),
+        Out_scale.stride(2),
+        group_size,
+        BLOCK_GROUP_NUM=triton.next_power_of_2(group_size),
+        BLOCK_GROUP_DIM=group_dim,
+        num_warps=num_warps,
+        num_stages=1,
+    )
+    return
+
+
+def test2():
+    import time
+
+    src = torch.randn((1, 1, 16), dtype=torch.float16).cuda()
+    src[0, 0, :] = torch.tensor([-2, 1, 2, 0, 4, 5, 6, 7, -2, 1, 2, 0, 4, 5, 6, 7]).cuda()
+    dest_loc = torch.arange(0, 1, dtype=torch.int32).cuda()
+    value_dest = torch.randn((1, 1, 8), dtype=torch.float16).cuda().to(torch.int8)
+    scale_dest = torch.randn((1, 1, 2), dtype=torch.float16).cuda()
+
+    destindex_copy_int4kv(src, dest_loc, value_dest, scale_dest)
+
+    print(value_dest)
+    print(scale_dest)
+
+
+if __name__ == "__main__":
+    test2()
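
A plain-PyTorch reference of the quantize-and-pack step can help when eyeballing the kernel's output. The sketch below was written for this page and is not code from the repo; it mirrors what the kernel does: one fp16 scale per group of 8 values (max|x| / 7), truncation toward zero, clamping to [-7, 7], and packing the even element of each pair into the low nibble and the odd element into the high nibble of one int8:

import torch

def ref_int4_pack(k, group_dim=8):
    # k: (tokens, heads, head_dim) float16, head_dim divisible by group_dim
    tokens, heads, head_dim = k.shape
    g = k.view(tokens, heads, head_dim // group_dim, group_dim).float()
    scale = g.abs().amax(dim=-1, keepdim=True) / 7.0  # an all-zero group would divide by zero, as in the kernel
    q = (g / scale).trunc().clamp(-7, 7).to(torch.int32)
    even, odd = q[..., 0::2] & 0xF, q[..., 1::2] & 0xF  # 4-bit two's-complement nibbles
    packed = (even | (odd << 4)).to(torch.uint8)        # same bit pattern as the kernel's int8 output
    return packed.view(tokens, heads, head_dim // 2), scale.squeeze(-1).to(torch.float16)

k = torch.randn(1, 1, 16, dtype=torch.float16)
packed, scales = ref_int4_pack(k)
print(packed.shape, scales.shape)  # torch.Size([1, 1, 8]) torch.Size([1, 1, 2])
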
lightllm/models/llama/triton_kernel/ppl_int4kv_flash_decoding.py (new file, +44)

@@ -0,0 +1,44 @@
+import torch
+
+
+def token_decode_attention_flash_decoding(
+    q, infer_state, q_head_num, head_dim, cache_k, cache_k_scale, cache_v, cache_v_scale, out=None
+):
+    BLOCK_SEQ = 256
+    batch_size = infer_state.batch_size
+    max_len_in_batch = infer_state.max_len_in_batch
+    calcu_shape1 = (batch_size, q_head_num, head_dim)
+
+    from lightllm_ppl_int4kv_flashdecoding_kernel import group8_int4kv_flashdecoding_stage1
+    from .flash_decoding_stage2 import flash_decode_stage2
+
+    o_tensor = torch.empty_like(q) if out is None else out
+
+    if getattr(infer_state, "mid_o", None) is None:
+        infer_state.mid_o = torch.empty(
+            [batch_size, q_head_num, max_len_in_batch // BLOCK_SEQ + 1, head_dim], dtype=torch.float16, device="cuda"
+        )
+        infer_state.mid_o_logexpsum = torch.empty(
+            [batch_size, q_head_num, max_len_in_batch // BLOCK_SEQ + 1], dtype=torch.float16, device="cuda"
+        )
+
+    mid_o = infer_state.mid_o
+    mid_o_logexpsum = infer_state.mid_o_logexpsum
+    group8_int4kv_flashdecoding_stage1(
+        BLOCK_SEQ,
+        mid_o,
+        mid_o_logexpsum,
+        1.0 / (head_dim ** 0.5),
+        q.view(calcu_shape1),
+        cache_k,
+        cache_k_scale,
+        cache_v,
+        cache_v_scale,
+        infer_state.req_manager.req_to_token_indexs,
+        infer_state.b_req_idx,
+        infer_state.b_seq_len,
+        infer_state.max_len_in_batch,
+    )
+
+    flash_decode_stage2(mid_o, mid_o_logexpsum, infer_state.b_seq_len, o_tensor.view(calcu_shape1), BLOCK_SEQ)
+    return o_tensor
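
Stage 1 is group8_int4kv_flashdecoding_stage1 from the external lightllm_ppl_int4kv_flashdecoding_kernel extension, which writes one partial output and one log-sum-exp per BLOCK_SEQ-sized chunk of the KV cache; flash_decode_stage2 then merges the chunks. Conceptually that merge is the standard flash-decoding reduction; the sketch below shows it in plain PyTorch and is not the repo's stage-2 kernel (which also limits each sequence to its own b_seq_len blocks):

import torch

def combine_partials(mid_o, mid_o_logexpsum, num_blocks):
    # Hedged sketch of the log-sum-exp combine over per-block partial results.
    # mid_o: (batch, heads, max_blocks, head_dim); mid_o_logexpsum: (batch, heads, max_blocks)
    lse = mid_o_logexpsum[:, :, :num_blocks].float()
    part = mid_o[:, :, :num_blocks, :].float()
    w = torch.softmax(lse, dim=-1)              # exp(lse_b) / sum_b exp(lse_b)
    return (w.unsqueeze(-1) * part).sum(dim=2)  # (batch, heads, head_dim)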

lightllm/models/llama/triton_kernel/ppl_quant_copy_kv.py (+44, -19)

@@ -6,32 +6,46 @@
 
 @triton.jit
 def _fwd_kernel_destindex_copy_quantize_kv(
-    K, Dest_loc, Out, Out_scale,
-    stride_k_bs, stride_k_h, stride_k_g, stride_k_d,
-    stride_o_bs, stride_o_h, stride_o_g, stride_o_d,
-    stride_os_bs, stride_os_h, stride_os_g,
+    K,
+    Dest_loc,
+    Out,
+    Out_scale,
+    stride_k_bs,
+    stride_k_h,
+    stride_k_g,
+    stride_k_d,
+    stride_o_bs,
+    stride_o_h,
+    stride_o_g,
+    stride_o_d,
+    stride_os_bs,
+    stride_os_h,
+    stride_os_g,
     group_size,
     BLOCK_GROUP_NUM: tl.constexpr,
-    BLOCK_GROUP_DIM: tl.constexpr
+    BLOCK_GROUP_DIM: tl.constexpr,
 ):
     cur_index = tl.program_id(0)
     cur_head = tl.program_id(1)
-
+
     offs_g = tl.arange(0, BLOCK_GROUP_NUM)
     offs_d = tl.arange(0, BLOCK_GROUP_DIM)
 
     dest_index = tl.load(Dest_loc + cur_index)
 
-    src_data = tl.load(K + cur_index * stride_k_bs + cur_head * stride_k_h + offs_g[:, None] * stride_k_g + offs_d[None, :],
-                       mask=offs_g[:, None] < group_size, other=0.0)
+    src_data = tl.load(
+        K + cur_index * stride_k_bs + cur_head * stride_k_h + offs_g[:, None] * stride_k_g + offs_d[None, :],
+        mask=offs_g[:, None] < group_size,
+        other=0.0,
+    )
     abs_data = tl.abs(src_data)
-    data_scale = (tl.max(abs_data, axis=1) / 127.).to(tl.float16)
+    data_scale = (tl.max(abs_data, axis=1) / 127.0).to(tl.float16)
     q_src_data = (src_data / data_scale[:, None]).to(tl.int8)
-
-    o_ptrs = Out + dest_index * stride_o_bs + cur_head * stride_o_h + offs_g[:, None] * stride_o_g + offs_d[None, :]
+
+    o_ptrs = Out + dest_index * stride_o_bs + cur_head * stride_o_h + offs_g[:, None] * stride_o_g + offs_d[None, :]
     os_ptrs = Out_scale + dest_index * stride_os_bs + cur_head * stride_os_h + offs_g
-    tl.store(o_ptrs, q_src_data, mask=offs_g[:, None]<group_size)
-    tl.store(os_ptrs, data_scale)
+    tl.store(o_ptrs, q_src_data, mask=offs_g[:, None] < group_size)
+    tl.store(os_ptrs, data_scale, mask=offs_g < group_size)
     return
 
 
@@ -53,13 +67,24 @@ def destindex_copy_quantize_kv(K, DestLoc, Out, Out_scale):
     Out = Out.view(Out.shape[0], Out.shape[1], group_size, group_dim)
 
     _fwd_kernel_destindex_copy_quantize_kv[grid](
-        K, DestLoc, Out, Out_scale,
-        K.stride(0), K.stride(1), K.stride(2), K.stride(3),
-        Out.stride(0), Out.stride(1), Out.stride(2), Out.stride(3),
-        Out_scale.stride(0), Out_scale.stride(1), Out_scale.stride(2),
+        K,
+        DestLoc,
+        Out,
+        Out_scale,
+        K.stride(0),
+        K.stride(1),
+        K.stride(2),
+        K.stride(3),
+        Out.stride(0),
+        Out.stride(1),
+        Out.stride(2),
+        Out.stride(3),
+        Out_scale.stride(0),
+        Out_scale.stride(1),
+        Out_scale.stride(2),
         group_size,
         BLOCK_GROUP_NUM=triton.next_power_of_2(group_size),
-        BLOCK_GROUP_DIM=group_dim,
+        BLOCK_GROUP_DIM=group_dim,
         num_warps=num_warps,
         num_stages=1,
     )
@@ -93,5 +118,5 @@ def test2():
     print("cos ", cos(src.flatten().to(torch.float32), (value_dest * scale_dest).flatten().to(torch.float32)))
 
 
-if __name__ == '__main__':
+if __name__ == "__main__":
     test2()
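
Most of this file's change is formatting; the one behavioral difference appears to be the new mask on the scale store, which stops padded group slots (BLOCK_GROUP_NUM is rounded up to a power of two) from writing scale entries they should not touch. For reference, a plain-PyTorch sketch of the per-group int8 quantization this kernel performs (scale = max|x| / 127 per group of group_dim values); written for this page, not taken from the repo:

import torch

def ref_int8_quant(k, group_dim=8):
    # k: (tokens, heads, head_dim) float16, head_dim divisible by group_dim
    tokens, heads, head_dim = k.shape
    g = k.view(tokens, heads, head_dim // group_dim, group_dim).float()
    scale = g.abs().amax(dim=-1, keepdim=True) / 127.0
    q = (g / scale).trunc().to(torch.int8)  # truncation toward zero, like the kernel's .to(tl.int8)
    return q.view(tokens, heads, head_dim), scale.squeeze(-1).to(torch.float16)

k = torch.randn(2, 4, 64, dtype=torch.float16)
q, scales = ref_int8_quant(k)
# round-trip check: dequantized values should track the originals closely
deq = q.float() * scales.float().repeat_interleave(8, dim=-1)
print(torch.nn.functional.cosine_similarity(k.float().flatten(), deq.flatten(), dim=0))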
