-
Notifications
You must be signed in to change notification settings - Fork 0
/
0001-Feat-Add-support-for-kleidiai-quantization-schemes.patch
407 lines (387 loc) · 18.1 KB
/
0001-Feat-Add-support-for-kleidiai-quantization-schemes.patch
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
From 2e7db1250fecad8fbc239ae9e3dcbac8dba2a62d Mon Sep 17 00:00:00 2001
From: Nikhil Gupta <nikhil.gupta2@arm.com>
Date: Wed, 14 Aug 2024 15:45:44 +0000
Subject: [PATCH 1/1] [Feat]: Add support for kleidiai quantization schemes
Description:
1. Allow Int4WeightOnlyQuantizer to work with channelwise and groupwise
symmetric quantization schemes
2. KleidiAI supports channelwise and 32 groupwise quantized matmul
kernels
Signed-off-by: Nikhil Gupta <nikhil.gupta2@arm.com>
---
torchao/quantization/GPTQ.py | 144 +++++++++++++++++++++++++---------
torchao/quantization/utils.py | 112 ++++++++++++++++++++++++++
2 files changed, 217 insertions(+), 39 deletions(-)
diff --git a/torchao/quantization/GPTQ.py b/torchao/quantization/GPTQ.py
index e45bb26..8a23663 100644
--- a/torchao/quantization/GPTQ.py
+++ b/torchao/quantization/GPTQ.py
@@ -36,6 +36,7 @@ from .utils import (
groupwise_affine_dequantize_tensor_from_qparams,
pack_tinygemm_scales_and_zeros,
groupwise_affine_quantize_tensor,
+ prepare_int4_weight_and_scales_and_zeros,
)
aten = torch.ops.aten
@@ -546,6 +547,20 @@ def linear_forward_int4(
c = c.reshape(new_shape)
return c
+def linear_forward_int4_symmetric_groupwise(x, weight_int4pack, out_features, in_features):
+ origin_x_size = x.size()
+ c = torch.ops.aten._kai_input_quant_mm_int4(x, weight_int4pack, x.shape[-2],out_features, in_features, 32)
+ new_shape = origin_x_size[:-1] + (out_features,)
+ c = c.reshape(new_shape)
+ return c
+
+def linear_forward_int4_symmetric_channelwise(x, weight_int4pack, out_features, in_features):
+ origin_x_size = x.size()
+ c = torch.ops.aten._kai_input_quant_mm_int4(x, weight_int4pack, x.shape[-2],out_features, in_features, 0)
+ new_shape = origin_x_size[:-1] + (out_features,)
+ c = c.reshape(new_shape)
+ return c
+
class WeightOnlyInt4Linear(torch.nn.Module):
__constants__ = ['in_features', 'out_features']
in_features: int
@@ -557,9 +572,10 @@ class WeightOnlyInt4Linear(torch.nn.Module):
# TODO: remove dtype field, not used
bias=False, device=None, dtype=None, groupsize: int = 128, inner_k_tiles: int = 8,
precision: torch.dtype = torch.bfloat16, scales_precision: torch.dtype = torch.bfloat16,
+ scheme=None,
) -> None:
super().__init__()
- self.padding = not _check_linear_int4_k(in_features, groupsize, inner_k_tiles)
+ self.padding = scheme is None and not _check_linear_int4_k(in_features, groupsize, inner_k_tiles)
if self.padding:
from .utils import find_multiple
self.origin_in_features = in_features
@@ -573,35 +589,58 @@ class WeightOnlyInt4Linear(torch.nn.Module):
self.inner_k_tiles = inner_k_tiles
self.precision = precision
self.scales_precision = scales_precision
+ self.scheme = scheme
if dtype is not None:
raise ValueError("Please specify 'precision' instead of 'dtype'")
-
- assert out_features % 8 == 0, "require out_features % 8 == 0"
- assert in_features % (inner_k_tiles * 16) == 0, "require in_features % (innerKTiles * 16) == 0"
- self.register_buffer(
+ if scheme is None:
+ assert out_features % 8 == 0, "require out_features % 8 == 0"
+ assert in_features % (inner_k_tiles * 16) == 0, "require in_features % (innerKTiles * 16) == 0"
+ self.register_buffer(
+ "weight",
+ torch.empty((out_features // 8, in_features // (inner_k_tiles * 16), 32, inner_k_tiles // 2), dtype=torch.int32, device=device)
+ )
+ self.dtype = dtype
+ self.register_buffer(
+ "scales_and_zeros",
+ torch.empty((in_features // groupsize, out_features, 2), dtype=self.scales_precision, device=device)
+ )
+ elif scheme == "symmetric_groupwise":
+ self.register_buffer(
"weight",
- torch.empty((out_features // 8, in_features // (inner_k_tiles * 16), 32, inner_k_tiles // 2), dtype=torch.int32, device=device)
- )
- self.dtype = dtype
- self.register_buffer(
- "scales_and_zeros",
- torch.empty((in_features // groupsize, out_features, 2), dtype=self.scales_precision, device=device)
- )
+ torch.empty((torch.ops.aten.get_kai_weight_pack_int4_size(out_features,in_features,groupsize)), dtype=torch.uint8)
+ )
+ self.register_buffer(
+ "scales_and_zeros",
+ torch.empty((0), dtype=self.scales_precision)
+ )
+ elif scheme == "symmetric_channelwise":
+ self.register_buffer(
+ "weight",
+ torch.empty((torch.ops.aten.get_kai_weight_pack_int4_size(out_features,in_features,groupsize)), dtype=torch.uint8)
+ )
+ self.register_buffer(
+ "scales_and_zeros",
+ torch.empty((out_features), dtype=self.scales_precision)
+ )
def forward(self, input: torch.Tensor) -> torch.Tensor:
if self.padding:
input = F.pad(input, pad=(0, self.in_features - self.origin_in_features))
- return linear_forward_int4(
- input,
- self.weight,
- self.scales_and_zeros,
- self.out_features,
- self.groupsize,
- self.precision,
- self.scales_precision,
- )
-
+ if self.scheme is None:
+ return linear_forward_int4(
+ input,
+ self.weight,
+ self.scales_and_zeros,
+ self.out_features,
+ self.groupsize,
+ self.precision,
+ self.scales_precision,
+ )
+ elif self.scheme == "symmetric_groupwise":
+ return linear_forward_int4_symmetric_groupwise(input, self.weight, self.out_features, self.in_features)
+ elif self.scheme == "symmetric_channelwise":
+ return linear_forward_int4_symmetric_channelwise(input, self.weight, self.out_features, self.in_features)
def _replace_linear_int4(
module: torch.nn.Module,
@@ -613,10 +652,19 @@ def _replace_linear_int4(
scales_precision: torch.dtype = torch.bfloat16,
linear_class: Type[torch.nn.Module] = WeightOnlyInt4Linear,
copy_weights: bool = False,
+ scheme=None,
):
for name, child in module.named_children():
if isinstance(child, nn.Linear) and (skip_layer_func is None or not skip_layer_func(child.weight)):
- if _check_linear_int4_k(child.in_features, groupsize, inner_k_tiles) or padding_allowed:
+ if scheme is not None or _check_linear_int4_k(child.in_features, groupsize, inner_k_tiles) or padding_allowed:
+ if scheme == "symmetric_channelwise":
+ groupsize = child.in_features
+ scales_precision = torch.float32
+ elif scheme == "symmetric_groupwise":
+ # Scales are actually f16 but they are packed along with weights
+ # To maintain api compatibility we populate scaled with no element in this scheme
+ scales_precision = torch.float32
+
new_linear = linear_class(
child.in_features,
child.out_features,
@@ -626,6 +674,7 @@ def _replace_linear_int4(
inner_k_tiles=inner_k_tiles,
precision=precision,
scales_precision=scales_precision,
+ scheme=scheme,
)
# TODO: merge with 8da4w?
# In distributed training, the model may be instantiated
@@ -645,6 +694,7 @@ def _replace_linear_int4(
scales_precision,
linear_class,
copy_weights,
+ scheme=scheme,
)
@@ -667,10 +717,11 @@ class Int4WeightOnlyQuantizer(Quantizer):
inner_k_tiles: Optional[int] = 8,
device: torch.device = torch.device("cuda"),
precision: torch.dtype = torch.bfloat16,
+ scheme = None
) -> None:
super().__init__()
assert inner_k_tiles in [2, 4, 8]
- assert groupsize in [32, 64, 128, 256]
+ assert groupsize in [0, 32, 64, 128, 256] # 0 group size is allowed for channelwise scheme where groupsize = row size
self.inner_k_tiles = inner_k_tiles
self.groupsize: int = groupsize
@@ -678,6 +729,7 @@ class Int4WeightOnlyQuantizer(Quantizer):
self.device: torch.device = device
# precision and dtype are being used interchangeably here
self.precision: torch.dtype = precision
+ self.scheme = scheme
@torch.no_grad()
def _create_quantized_state_dict(
@@ -692,12 +744,13 @@ class Int4WeightOnlyQuantizer(Quantizer):
# assert out_features % 8 == 0, "require out_features % 8 == 0"
logging.info(f"linear: {fqn}, in={in_features}, out={out_features}")
- assert (
- in_features % self.groupsize == 0
- ), f"require in_features:{in_features} % self.groupsize:{self.groupsize} == 0"
+ if self.scheme is None:
+ assert (
+ in_features % self.groupsize == 0
+ ), f"require in_features:{in_features} % self.groupsize:{self.groupsize} == 0"
weight = mod.weight.data
- if not _check_linear_int4_k(
+ if self.scheme is None and not _check_linear_int4_k(
in_features, self.groupsize, self.inner_k_tiles
):
if self.padding_allowed:
@@ -710,17 +763,29 @@ class Int4WeightOnlyQuantizer(Quantizer):
logging.warn(f"warning: {fqn} is skipped, int4 requires that in_features is 32, 64, or is divisible by 1024, " +
"and that groupsize and inner_k_tiles*16 evenly divide into it")
continue
- (
- w_int4x8,
- scales_and_zeros
- ) = groupwise_affine_quantize_tensor(
- weight,
- 4, # n_bit
- self.groupsize,
- self.precision, # dtype for scales_and_zeros
- )
- # TODO: just get the device from mod.weight.device?
- weight_int4pack = torch.ops.aten._convert_weight_to_int4pack(w_int4x8.to(self.device), self.inner_k_tiles)
+ if self.scheme == "symmetric_channelwise" or self.scheme == "symmetric_groupwise" :
+ (
+ w_int4x8,
+ scales_and_zeros
+ ) = prepare_int4_weight_and_scales_and_zeros(
+ weight.to(self.precision), self.groupsize, self.inner_k_tiles, self.scheme, precision=self.precision
+ )
+ if self.scheme == "symmetric_channelwise":
+ weight_int4pack = torch.ops.aten._kai_weight_pack_int4(w_int4x8.to(self.device),scales_and_zeros,mod.out_features,mod.in_features,0)
+ elif self.scheme == "symmetric_groupwise":
+ weight_int4pack = torch.ops.aten._kai_weight_pack_int4(w_int4x8.to(self.device),scales_and_zeros,mod.out_features,mod.in_features,self.groupsize)
+ else:
+ (
+ w_int4x8,
+ scales_and_zeros
+ ) = groupwise_affine_quantize_tensor(
+ weight,
+ 4, # n_bit
+ self.groupsize,
+ self.precision, # dtype for scales_and_zeros
+ )
+ # TODO: just get the device from mod.weight.device?
+ weight_int4pack = torch.ops.aten._convert_weight_to_int4pack(w_int4x8.to(self.device), self.inner_k_tiles)
cur_state_dict[f"{fqn}.weight"] = weight_int4pack.to(self.device)
cur_state_dict[f"{fqn}.scales_and_zeros"] = scales_and_zeros.to(self.device)
return cur_state_dict
@@ -734,6 +799,7 @@ class Int4WeightOnlyQuantizer(Quantizer):
skip_layer_func=None,
precision=self.precision,
scales_precision=self.precision,
+ scheme=self.scheme,
)
return model
diff --git a/torchao/quantization/utils.py b/torchao/quantization/utils.py
index d5f9ed1..f2d9d26 100644
--- a/torchao/quantization/utils.py
+++ b/torchao/quantization/utils.py
@@ -367,6 +367,7 @@ def groupwise_affine_quantize_tensor_from_qparams(
int_data = int_data.to(device='mps')
return int_data
+
def groupwise_affine_dequantize_tensor_from_qparams(
w_int4x8,
scales,
@@ -396,6 +397,7 @@ def groupwise_affine_dequantize_tensor_from_qparams(
quant_max = 2**n_bit - 1
return dequantize_affine(w_int32, block_size, scales, zeros, input_dtype, quant_min, quant_max, zero_point_domain=ZeroPointDomain.FLOAT, output_dtype=scales.dtype)
+
def groupwise_affine_quantize_tensor(w, n_bit=4, groupsize=128, dtype=torch.bfloat16):
scales, zeros = get_groupwise_affine_qparams(w, n_bit, groupsize, dtype)
w_int4x8 = groupwise_affine_quantize_tensor_from_qparams(
@@ -405,6 +407,116 @@ def groupwise_affine_quantize_tensor(w, n_bit=4, groupsize=128, dtype=torch.bflo
return w_int4x8, scales_and_zeros
+def get_group_qparams(w, n_bit=4, groupsize=128, scheme="symmetric_channelwise",precision=torch.bfloat16):
+ if groupsize > w.shape[-1] or scheme == "symmetric_channelwise":
+ groupsize = w.shape[-1]
+ assert groupsize > 1
+ assert w.shape[-1] % groupsize == 0
+ assert w.dim() == 2
+
+ to_quant = w.reshape(-1, groupsize)
+ assert torch.isnan(to_quant).sum() == 0
+
+ # improved symmetric 4 bit quantization that uses bin correspondingto -8 from [-8,7] ( -2^(b-1) , 2^(b-1)-1 ) range
+ if scheme == "symmetric_groupwise":
+ to_quant_abs = to_quant.abs()
+ max_abs_indices = to_quant_abs.argmax(dim=1, keepdim=True)
+ max_val = torch.gather(to_quant, 1, max_abs_indices)
+ scales = max_val / -8
+ zeros = torch.zeros_like(scales)
+
+ elif scheme == "symmetric_channelwise":
+ to_quant_abs = to_quant.abs()
+ max_abs_indices = to_quant_abs.argmax(dim=1, keepdim=True)
+ max_val = torch.gather(to_quant, 1, max_abs_indices)
+ scales = max_val / -8
+ zeros = torch.zeros_like(scales)
+ return scales.to(precision).reshape(w.shape[0], -1), zeros.to(
+ precision
+ ).reshape(w.shape[0], -1)
+
+
+def group_quantize_tensor_from_qparams(w, scales, zeros, n_bit=4, groupsize=128, scheme="symmetric_channelwise", precision=torch.bfloat16):
+ assert groupsize > 1
+ if groupsize > w.shape[-1] and scales.shape[-1] == 1:
+ groupsize = w.shape[-1]
+
+ assert w.shape[-1] % groupsize == 0
+ assert w.dim() == 2
+
+ to_quant = w.reshape(-1, groupsize)
+ assert torch.isnan(to_quant).sum() == 0
+
+ scales = scales.reshape(-1, 1)
+ zeros = zeros.reshape(-1, 1)
+ if scheme == "symmetric_groupwise":
+ half_block_size = groupsize//2
+ max_int = 2**n_bit - 1
+ w_int8 = (
+ to_quant.div(scales)
+ .add(8.5)
+ .to(torch.int8)
+ .clamp(max=max_int)
+ )
+ # For KleidiAI q8si32 x q4si32 kernel we do packing as below. The kernel performs [ RHS * LHS^T = Output^T ]
+ # NOTE : Packing ith ( lower 4 bits ) and groupsize//2 + ith ( upper 4 bits ) 4 bit values together in uint8
+ w_uint8 = torch.zeros((w_int8.shape[-2], half_block_size), dtype=torch.uint8)
+ # Process each pair of values
+ for i in range(half_block_size):
+ # Get values from the first and second half of the block
+ first_half_values = w_int8[:, i]
+ second_half_values = w_int8[:, i + half_block_size]
+
+ # Combine values: second half values in upper 4 bits, first half values in lower 4 bits
+ combined = (second_half_values.to(torch.int32) << 4) | (first_half_values.to(torch.int32) & 0x0F)
+
+ # Store the combined values in w_uint8
+ w_uint8[:, i] = combined.to(torch.uint8)
+
+ # Pack one 16 bit scale as two 8 bits and store in front of each group
+ scales = scales.to(dtype=torch.float16)
+ scales = scales.view(torch.uint8)
+ w_uint8 = torch.cat((scales,w_uint8),dim=1)
+ elif scheme == "symmetric_channelwise":
+ max_int = 2**n_bit - 1
+ w_int8 = (
+ to_quant.div(scales)
+ .add(8.5)
+ .to(torch.int8)
+ .clamp(max=max_int)
+ )
+ # We pack every odd index value in upper 4 bits and every even index value in lower 4 bits
+ w_uint8 = (w_int8[::,1::2] << 4 | w_int8[::,::2] ).to(torch.uint8)
+ return w_uint8
+
+
+def pack_scales_and_zeros(scales, zeros, scheme="symmetric_channelwise"):
+ assert scales.shape == zeros.shape
+ if scheme == "symmetric_groupwise":
+ # We pack scales inside weight tensor for this scheme
+ scales_zeros = torch.empty(0)
+ elif scheme == "symmetric_channelwise":
+ scales_zeros = scales.squeeze().contiguous()
+
+ return scales_zeros
+
+
+def group_quantize_tensor(w, n_bit=4, groupsize=128, scheme="symmetric_channelwise", precision=torch.bfloat16):
+ scales, zeros = get_group_qparams(w, n_bit, groupsize, scheme=scheme, precision=precision)
+ w_uint8 = group_quantize_tensor_from_qparams(w, scales, zeros, n_bit, groupsize, scheme=scheme)
+ scales_and_zeros = pack_scales_and_zeros(scales, zeros, scheme=scheme)
+ return w_uint8, scales_and_zeros
+
+
+def prepare_int4_weight_and_scales_and_zeros(weights, groupsize, inner_k_tiles, scheme="symmetric_channelwise", precision=torch.bfloat16):
+ if groupsize > weights.shape[-1] or scheme == "symmetric_channelwise":
+ groupsize = weights.shape[-1]
+ weight_int4pack, scales_and_zeros = group_quantize_tensor(
+ weights, n_bit=4, groupsize=groupsize, scheme=scheme, precision=precision
+ )
+ return weight_int4pack, scales_and_zeros
+
+
def groupwise_affine_dequantize_tensor(
w_int4x8,
scales_and_zeros,
--
2.34.1