Skip to content
This repository was archived by the owner on Jun 3, 2025. It is now read-only.

Commit 7bb3db3

Browse files
authored
[GPTQ UX] Add scheme arg with QuantizationScheme support (#2286)
* Update GHA file to install compressed-tensors from source
* Missed commit (#2300)
* Remove src from import
* Style
* Full scheme support
* Add a small test for accepting a full scheme
1 parent c672b9a commit 7bb3db3

File tree

2 files changed

+89
-0
lines changed

2 files changed

+89
-0
lines changed

src/sparseml/modifiers/quantization/gptq/base.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,10 @@ class GPTQModifier(Modifier):
6868
not be updated. Leave None to not disable observers during QAT. Default is None
6969
:param num_calibration_steps: Number of steps to run post training calibration for.
7070
When None, the entire calibration_dataloader is used
71+
:param scheme: [Used only if a quantization modifier is not specified] the
    quantization scheme to apply to the model. This is a dictionary supporting
    all keys of QuantizationScheme except ``targets``, which is taken from the
    ``targets`` parameter set at the modifier level.
7175
"""
7276

7377
sequential_update: Optional[bool] = False
@@ -79,6 +83,7 @@ class GPTQModifier(Modifier):
7983
ignore: List[str] = Field(default_factory=list)
8084
disable_quantization_observer_epoch: Optional[float] = None
8185
num_calibration_steps: Optional[int] = None
86+
scheme: Optional[Dict[str, Any]] = None
8287
compressible_layers_: Optional[List] = None
8388
quantization_modifier_: Any = None
8489

@@ -156,6 +161,14 @@ def _build_quant_modifier(self, framework):
156161
if getattr(self, key, False)
157162
}
158163

164+
if self.scheme is not None:
165+
# takes precedence over config_groups
166+
targets = self.targets or ["Linear"]
167+
config_group = QuantizationScheme.model_validate(
168+
{"targets": targets, **self.scheme}
169+
)
170+
quant_args["config_groups"] = {"config_group_0": config_group}
171+
159172
if "config_groups" not in quant_args:
160173
default_quant_scheme = QuantizationScheme.default_scheme(
161174
targets=self.targets
Lines changed: 76 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,76 @@
1+
# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing,
10+
# software distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
16+
import shutil
17+
import unittest
18+
19+
from sparseml.transformers.sparsification.sparse_model import SparseAutoModelForCausalLM
20+
from tests.testing_utils import requires_torch
21+
22+
23+
@requires_torch
class TestGPTQOneShotWithFullScheme(unittest.TestCase):
    """Smoke test for one-shot GPTQ quantization driven by a full
    QuantizationScheme passed through the GPTQModifier ``scheme`` argument.

    Runs ``oneshot`` on a tiny model, reloads the saved model, and checks
    that quantization was applied.
    """

    def setUp(self):
        # Torch is imported lazily so this module can be collected even
        # when torch is absent (the class is skipped by @requires_torch).
        import torch

        self.output = "./oneshot_output"
        self.model = "roneneldan/TinyStories-1M"
        self.dataset = "open_platypus"
        self.device = "cuda:0" if torch.cuda.is_available() else "cpu"

        # Recipe exercises the full-scheme path: every QuantizationScheme
        # key except `targets`, which is set at the modifier level.
        # NOTE(review): YAML nesting reconstructed from the diff view —
        # confirm indentation matches the original recipe.
        self.recipe = """
        first_stage:
            quant_modifiers:
                GPTQModifier:
                    ignore: ["lm_head"]
                    sequential_update: True
                    dampening_frac: 0.001
                    block_size: 128
                    targets: ["Linear"]
                    scheme:
                        input_activations: null
                        output_activations: null
                        weights:
                            num_bits: 8
                            type: "int"
                            symmetric: true
                            strategy: "tensor"
                            group_size: 128
        """

    def test_oneshot_application(self):
        from sparseml.transformers import oneshot

        oneshot(
            model=self.model,
            dataset=self.dataset,
            output_dir=self.output,
            overwrite_output_dir=True,
            recipe=self.recipe,
            oneshot_device=self.device,
            num_calibration_samples=9,
        )

        model_loaded = SparseAutoModelForCausalLM.from_pretrained(self.output)

        # The reloaded model must carry a quantization config.
        assert model_loaded.quantization_config is not None

        # Spot-check that a targeted Linear layer received a scheme.
        targeted_linear_layer = model_loaded.transformer.h[0].attn.attention.k_proj
        assert hasattr(targeted_linear_layer, "quantization_scheme")

    def tearDown(self):
        # ignore_errors: the output dir may not exist if oneshot failed
        # before writing anything — don't mask the real failure with a
        # FileNotFoundError here.
        shutil.rmtree(self.output, ignore_errors=True)

0 commit comments

Comments
 (0)