
Commit 8e34dad

Authored by huakaigo
support llama-quik, a w4a4 quantization method (#386)
Support the w4a4 quantization method for the llama model. The quantization method is from the [QUIK paper](https://arxiv.org/abs/2310.09259), and the implementation also refers to the [QUIK GitHub repository](https://github.com/IST-DASLab/QUIK). --------- Co-authored-by: huakaigo <liuahuakai@sensetime.com>
1 parent a231505 commit 8e34dad
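For context on the method being added: QUIK keeps a small set of outlier activation columns in fp16 and quantizes the remaining columns to 4-bit integers with a per-row scale and zero point; the int4 and fp16 partial products are then combined at dequantization time. Below is a minimal, self-contained PyTorch sketch of that activation split; all tensor names here are illustrative and do not appear in this commit.

import torch

# Illustrative QUIK-style activation split: outlier columns stay in fp16,
# the rest are asymmetrically quantized per row into the 4-bit range [0, 15].
x = torch.randn(4, 16)
fp_indices = torch.tensor([3, 7])                                 # outlier columns kept in full precision
int_indices = torch.tensor([i for i in range(16) if i not in (3, 7)])

x_fp = x[:, fp_indices]                                           # full-precision part
x_int = x[:, int_indices]                                         # part to be quantized

xmin = x_int.min(dim=1, keepdim=True).values
xmax = x_int.max(dim=1, keepdim=True).values
scale = (xmax - xmin) / 15.0                                      # asymmetric 4-bit scale per row
zero = xmin
q = torch.clamp(torch.round((x_int - zero) / scale), 0, 15).to(torch.uint8)

# Dequantize the quantized part to check the reconstruction error.
x_hat = q.float() * scale + zero
print((x_hat - x_int).abs().max())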

File tree

10 files changed: +934 -0 lines changed


lightllm/models/llama_quik/__init__.py

Whitespace-only changes.

lightllm/models/llama_quik/cuda_kernel/__init__.py

Whitespace-only changes.
@@ -0,0 +1,128 @@
import torch
from typing import Tuple


def CONTIGUOUS_TENSOR(tensor: torch.Tensor):
    """ Helper function """
    if tensor.is_contiguous(): return tensor
    else: return tensor.contiguous()

def int4Matmul(
    input: torch.Tensor,
    weight: torch.Tensor,
) -> torch.Tensor:
    from lightllm_quik_kernel.matmul import int4Matmul
    return int4Matmul(
        CONTIGUOUS_TENSOR(input), CONTIGUOUS_TENSOR(weight))

def int8Matmul(
    input: torch.Tensor,
    weight: torch.Tensor,
) -> torch.Tensor:
    from lightllm_quik_kernel.matmul import int8Matmul
    return int8Matmul(
        CONTIGUOUS_TENSOR(input), CONTIGUOUS_TENSOR(weight))

def asym_quantize(
    src: torch.Tensor,
    int_indices: torch.Tensor,
    fp_indices: torch.Tensor,
    bits: int,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    Asymmetric quantization for activations of QUIK.

    Returns:
        quantized_weights: int8 tensor of shape (rows, len(int_indices)) for bits=8
            or (rows, len(int_indices)/2) for bits=4
        meta: float16 tensor of shape (2*rows,), layout is [scale, zero, scale, zero, ...]
        float_weights: float16 tensor of shape (rows, len(fp_indices))
    """
    from lightllm_quik_kernel.asymmetric import quantize

    return quantize(
        CONTIGUOUS_TENSOR(src),
        CONTIGUOUS_TENSOR(int_indices),
        CONTIGUOUS_TENSOR(fp_indices),
        bits)

def asym_dequantize(
    int_result: torch.Tensor,
    act_meta: torch.Tensor,
    weight_scale: torch.Tensor,
    wReduced: torch.Tensor,
    fp_result: torch.Tensor,
    bits: int
) -> torch.Tensor:
    """
    Asymmetric dequantization for activations of QUIK.

    Args:
        int_result: the result of matmul(q_act, q_weight)
        act_meta: the packed tensor of activation scales and zeros, layout is [scale, zero, ...]
        weight_scale: the tensor of weight scales
        wReduced: the constant term for dequantization
        fp_result: the result of matmul(fp_act, fp_weight)
        bits: 4 or 8
    Returns:
        the dequantized result tensor
    """
    from lightllm_quik_kernel.asymmetric import dequantize

    return dequantize(
        CONTIGUOUS_TENSOR(int_result),
        CONTIGUOUS_TENSOR(act_meta),
        CONTIGUOUS_TENSOR(weight_scale),
        CONTIGUOUS_TENSOR(wReduced),
        CONTIGUOUS_TENSOR(fp_result),
        bits)

def sym_quantize(
    src: torch.Tensor,
    scale: torch.Tensor,
    bits: int
) -> torch.Tensor:
    """
    Symmetric quantization for activations of QUIK.

    Args:
        src: the tensor to be quantized
        scale: the scale tensor calculated externally using S = (fp_max - fp_min) / (q_max - q_min)
            Example: (torch.max(torch.abs(x), dim=1)[0].unsqueeze(1) / ((1 << (bits - 1)) - 1)).to(torch.float16)
        bits: 4 or 8
    Returns:
        the quantized tensor
    """
    from lightllm_quik_kernel.symmetric import quantize

    return quantize(
        CONTIGUOUS_TENSOR(src),
        CONTIGUOUS_TENSOR(scale),
        bits)

def sym_dequantize(
    int_result: torch.Tensor,
    act_scale: torch.Tensor,
    weight_scale: torch.Tensor,
    fp_result: torch.Tensor,
    bits: int
) -> torch.Tensor:
    """
    Symmetric dequantization for activations of QUIK.

    Args:
        int_result: the result of matmul(q_act, q_weight)
        act_scale: the tensor of activation scales
        weight_scale: the tensor of weight scales
        fp_result: the result of matmul(fp_act, fp_weight)
        bits: 4 or 8

    Returns:
        the dequantized result
    """
    from lightllm_quik_kernel.symmetric import dequantize

    return dequantize(
        CONTIGUOUS_TENSOR(int_result),
        CONTIGUOUS_TENSOR(act_scale),
        CONTIGUOUS_TENSOR(weight_scale),
        CONTIGUOUS_TENSOR(fp_result),
        bits)
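For reference, the per-row symmetric scale that the sym_quantize docstring expects can be computed in plain PyTorch as below. This is a standalone sketch that only spells out the docstring's example expression (note the parentheses around the bit shift); it does not call into lightllm_quik_kernel, and the helper name sym_act_scale is illustrative, not part of this commit.

import torch

def sym_act_scale(x: torch.Tensor, bits: int) -> torch.Tensor:
    # Per-row symmetric scale: max(|x|) mapped onto the signed integer maximum,
    # i.e. q_max = (1 << (bits - 1)) - 1 (7 for 4-bit, 127 for 8-bit).
    q_max = (1 << (bits - 1)) - 1
    return (torch.max(torch.abs(x), dim=1)[0].unsqueeze(1) / q_max).to(torch.float16)

x = torch.randn(8, 64)
print(sym_act_scale(x, bits=4).shape)  # torch.Size([8, 1])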

lightllm/models/llama_quik/layer_infer/__init__.py

Whitespace-only changes.
@@ -0,0 +1,139 @@
from typing import Tuple

import numpy as np
import torch
import torch.distributed as dist
import torch.nn.functional as F
import triton
from functools import partial

from lightllm.models.llama.triton_kernel.rotary_emb import rotary_emb_fwd
from lightllm.models.llama.triton_kernel.silu_and_mul import silu_and_mul_fwd
from lightllm.models.llama.infer_struct import LlamaInferStateInfo
from lightllm.models.llama.layer_infer.transformer_layer_infer import LlamaTransformerLayerInfer
from lightllm.models.llama_quik.layer_weights.transformer_layer_weight import LlamaTransformerLayerWeightQuik
from lightllm.utils.infer_utils import mark_cost_time


class LlamaTransformerLayerInferQuik(LlamaTransformerLayerInfer):
    """ """

    def __init__(self, layer_num, tp_rank, world_size, network_config, mode=[]):
        super().__init__(layer_num, tp_rank, world_size, network_config, mode)
        self.eps_ = network_config["rms_norm_eps"]
        self.tp_q_head_num_ = network_config["num_attention_heads"] // self.world_size_
        self.tp_k_head_num_ = network_config["num_key_value_heads"] // self.world_size_
        self.tp_v_head_num_ = network_config["num_key_value_heads"] // self.world_size_
        self.tp_o_head_num_ = self.tp_q_head_num_
        self.head_dim_ = network_config["hidden_size"] // network_config["num_attention_heads"]
        self.embed_dim_ = network_config["hidden_size"]

        self.inter_dim_ = network_config["intermediate_size"]
        self._bind_func()
        return

    def _get_qkv(
        self,
        input: torch.Tensor,
        cache_kv: torch.Tensor,
        infer_state: LlamaInferStateInfo,
        layer_weight: LlamaTransformerLayerWeightQuik,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        q = layer_weight.q_proj(input.view(-1, self.embed_dim_))
        if layer_weight.cat_kv_:
            cache_kv = layer_weight.kv_proj(input.view(-1, self.embed_dim_)).view(
                -1, self.tp_k_head_num_ + self.tp_v_head_num_, self.head_dim_
            )
        else:
            cache_k = layer_weight.k_proj(input.view(-1, self.embed_dim_)).view(-1, self.tp_k_head_num_, self.head_dim_)
            cache_v = layer_weight.v_proj(input.view(-1, self.embed_dim_)).view(-1, self.tp_v_head_num_, self.head_dim_)
            cache_kv = torch.cat([cache_k, cache_v], dim=1)
        rotary_emb_fwd(
            q.view(-1, self.tp_q_head_num_, self.head_dim_),
            cache_kv[:, 0 : self.tp_k_head_num_, :],
            infer_state.position_cos,
            infer_state.position_sin,
        )
        return q, cache_kv

    def _get_o(
        self, input, infer_state: LlamaInferStateInfo, layer_weight: LlamaTransformerLayerWeightQuik
    ) -> torch.Tensor:
        return layer_weight.o_proj(input.view(-1, self.embed_dim_))

    def _ffn(
        self,
        input,
        infer_state: LlamaInferStateInfo,
        layer_weight: LlamaTransformerLayerWeightQuik,
    ) -> torch.Tensor:
        if not layer_weight.cat_gate_up_:
            gate_out = layer_weight.gate_proj(input.view(-1, self.embed_dim_))
            up_out = layer_weight.up_proj(input.view(-1, self.embed_dim_))
            torch.nn.functional.silu(gate_out, inplace=True)
            gate_out.mul_(up_out)
            input = None
            ffn2_out = layer_weight.down_proj(gate_out)
            gate_out, up_out = None, None
        else:
            gate_up_out = layer_weight.gate_up_proj(input.view(-1, self.embed_dim_)).view(
                -1, self.inter_dim_ * 2 // self.world_size_
            )
            # gate_out, up_out = torch.split(gate_up_out, split_size_or_sections=1, dim=1)
            ffn1_out = silu_and_mul_fwd(gate_up_out)
            input = None
            gate_up_out = None
            ffn2_out = layer_weight.down_proj(ffn1_out)
            ffn1_out = None

        return ffn2_out

    @mark_cost_time(
        "trans context flash forward time cost"
    )  # do not remove this; removing it degrades performance for reasons that are not yet understood
    def _context_attention(self, input_embding, infer_state: LlamaInferStateInfo, layer_weight):
        input1 = self._att_norm(input_embding, infer_state, layer_weight)
        cache_kv = self._pre_cache_kv(infer_state, layer_weight)
        q, cache_kv = self._get_qkv(input1, cache_kv, infer_state, layer_weight)
        input1 = None
        self._post_cache_kv(cache_kv, infer_state, layer_weight)
        o = self._context_attention_kernel(q, cache_kv, infer_state, layer_weight)
        q = None
        o = self._get_o(o, infer_state, layer_weight)
        if self.world_size_ > 1:
            dist.all_reduce(o, op=dist.ReduceOp.SUM, async_op=False)
        input_embding.add_(o.view(-1, self.embed_dim_))
        return

    @mark_cost_time(
        "trans context ffn forward time cost"
    )  # do not remove this; removing it degrades performance for reasons that are not yet understood
    def _context_ffn(self, input_embdings, infer_state: LlamaInferStateInfo, layer_weight):
        input1 = self._ffn_norm(input_embdings, infer_state, layer_weight)
        ffn_out = self._ffn(input1, infer_state, layer_weight)
        input1 = None
        if self.world_size_ > 1:
            dist.all_reduce(ffn_out, op=dist.ReduceOp.SUM, async_op=False)
        input_embdings.add_(ffn_out.view(-1, self.embed_dim_))
        return

    # this impl does not use @mark_cost_time
    def _token_attention(self, input_embding, infer_state: LlamaInferStateInfo, layer_weight):
        input1 = self._att_norm(input_embding, infer_state, layer_weight)
        cache_kv = self._pre_cache_kv(infer_state, layer_weight)
        q, cache_kv = self._get_qkv(input1, cache_kv, infer_state, layer_weight)
        input1 = None
        self._post_cache_kv(cache_kv, infer_state, layer_weight)
        o = self._token_attention_kernel(q, infer_state, layer_weight)
        q = None
        o = self._get_o(o, infer_state, layer_weight)
        if self.world_size_ > 1:
            dist.all_reduce(o, op=dist.ReduceOp.SUM, async_op=False)
        input_embding.add_(o.view(-1, self.embed_dim_))
        return

    # this impl does not use @mark_cost_time
    def _token_ffn(self, input_embdings, infer_state: LlamaInferStateInfo, layer_weight):
        input1 = self._ffn_norm(input_embdings, infer_state, layer_weight)
        ffn_out = self._ffn(input1, infer_state, layer_weight)
        input1 = None
        if self.world_size_ > 1:
            dist.all_reduce(ffn_out, op=dist.ReduceOp.SUM, async_op=False)
        input_embdings.add_(ffn_out.view(-1, self.embed_dim_))
        return
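As a plain-PyTorch reference for the fused branch of _ffn above: when cat_gate_up_ is set, silu_and_mul_fwd consumes the concatenated gate/up projection, which should be equivalent to the non-fused branch's silu(gate) * up. The sketch below is a hypothetical standalone reference, not the Triton kernel this commit uses, and it assumes the gate half comes first in the concatenated output, mirroring the separate gate_proj / up_proj branch.

import torch

def silu_and_mul_reference(gate_up: torch.Tensor) -> torch.Tensor:
    # gate_up: (tokens, 2 * inter_dim_per_rank); assumed layout is [gate | up].
    gate, up = gate_up.chunk(2, dim=-1)
    return torch.nn.functional.silu(gate) * up

x = torch.randn(4, 2 * 128)
print(silu_and_mul_reference(x).shape)  # torch.Size([4, 128])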

lightllm/models/llama_quik/layer_weights/__init__.py

Whitespace-only changes.
