This repository was archived by the owner on Jun 3, 2025. It is now read-only.

Commit 5c1de1c

Authored by bfineran, Sara Adkins, and dbogunowicz

update llama7b_sparse_quantized example (#2322)

* update llama7b_sparse_quantized example
* one shot llama example
* Update examples/llama7b_sparse_quantized/README.md

  Co-authored-by: dbogunowicz <97082108+dbogunowicz@users.noreply.github.com>

* Fix GPTQ Aliases (#2327)

  * fix alias application with unit tests
  * style

---------

Co-authored-by: Sara Adkins <sara@neuralmagic.com>
Co-authored-by: dbogunowicz <97082108+dbogunowicz@users.noreply.github.com>

1 parent 4e2ad0a · commit 5c1de1c

File tree: 5 files changed, +189 −65 lines

Lines changed: 50 additions & 0 deletions
# Creating a Quantized Llama Model in One Shot

Quantizing a model to a lower precision can save on both memory and speed at inference time.
This example demonstrates how to use the SparseML API to quantize a Llama model from 16 bits
to 4 bits and save it to a compressed-tensors format for inference with vLLM.

## Step 1: Select a model and dataset

For this example, we will use a TinyLlama model and the open_platypus dataset; however,
these can be swapped out for any Hugging Face compatible model and dataset.

```python
model = "TinyLlama/TinyLlama-1.1B-intermediate-step-1431k-3T"
dataset = "open_platypus"
```

## Step 2: Configure a `GPTQModifier`

Modifiers in SparseML are used to apply optimizations to models. In this example we use a
`GPTQModifier` to apply the GPTQ algorithm to our model. We target all `Linear` layers
for 4-bit weight quantization. These options may be swapped out for any valid `QuantizationScheme`.

```python
from sparseml.modifiers.quantization.gptq import GPTQModifier

gptq = GPTQModifier(
    targets="Linear",
    scheme="W4A16"
)
```

## Step 3: One-Shot Compression

The `oneshot` API applies the created modifier to the target model and dataset.
Setting `save_compressed` to `True` runs the model through `compressed_tensors` compression
after quantization is completed.

```python
from sparseml.transformers import oneshot

oneshot(
    model=model,
    dataset=dataset,
    recipe=gptq,
    save_compressed=True,
    output_dir="llama-compressed-example",
    overwrite_output_dir=True,
    max_seq_length=256,
    num_calibration_samples=256,
)
```
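The new README above targets inference with vLLM. As a hedged illustration (not part of the commit), the saved `llama-compressed-example` directory could be loaded roughly as follows, assuming a vLLM installation with compressed-tensors support; the prompt and sampling parameters are illustrative only:

```python
# Hypothetical usage sketch, not part of this commit: load the one-shot output
# directory produced above with vLLM. Assumes vLLM is installed and can read
# the compressed-tensors checkpoint saved to output_dir in the example.
from vllm import LLM, SamplingParams

llm = LLM(model="llama-compressed-example")  # path to the saved, compressed model
params = SamplingParams(temperature=0.8, max_tokens=64)

outputs = llm.generate(["Quantization saves memory because"], params)
print(outputs[0].outputs[0].text)
```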
Lines changed: 56 additions & 24 deletions
````diff
@@ -1,52 +1,84 @@
 # Creating a Sparse Quantized Llama7b Model
 
-The example in this folder runs in multiple stages to create a Llama 7b model with
-a 2:4 sparsity pattern and W4A16 post training quantization (PTW). The model is
-calibrated and trained with the ultachat200k dataset. At least 75GB of GPU memory is
-required to run this example.
+This example uses SparseML and Compressed-Tensors to create a 2:4 sparse and quantized Llama2-7b model.
+The model is calibrated and trained with the ultrachat200k dataset.
+At least 75GB of GPU memory is required to run this example.
 
-## Recipe Summary
+Follow the steps below, or run the full example with `python examples/llama7b_sparse_quantized/llama7b_sparse_w4a16.py`.
 
-The recipe used for this flow is located in [2:4_w4a16_recipe.yaml](./2:4_w4a16_recipe.yaml). It contains 3 stages that are outlined below.
+## Step 1: Select a model, dataset, and recipe
+In this step, we select which model to use as a baseline for sparsification, a dataset to
+use for calibration and finetuning, and a recipe.
 
+Models can reference a local directory, a model in the Hugging Face Hub, or a model in the SparseZoo.
 
-### Stage 1: Sparsification
+Datasets can come from a local compatible directory or the Hugging Face Hub.
 
-Runs the SparseGPT one-shot algorithm to prune the model to 50% sparsity with a 2:4
-sparsity pattern. This means that 2 weights out of every group of 4 weights are masked to 0.
+Recipes are YAML files that describe how a model should be optimized during or after training.
+The recipe used for this flow is located in [2:4_w4a16_recipe.yaml](./2:4_w4a16_recipe.yaml).
+It contains instructions to prune the model to 2:4 sparsity, run one epoch of recovery finetuning,
+and quantize to 4 bits in one shot using GPTQ.
 
-### Stage 2: Finetuning Recovery
-
-This stage runs a single epoch of training on the ultrachat200k dataset while maintaining
-the sparsity mask from stage 1. The purpose of this stage is to recover any accuracy lost
-during the sparsification process.
+```python
+import torch
+from sparseml.transformers import SparseAutoModelForCausalLM
 
-### Stage 3: Quantization
+model_stub = "zoo:llama2-7b-ultrachat200k_llama2_pretrain-base"
+model = SparseAutoModelForCausalLM.from_pretrained(
+    model_stub, torch_dtype=torch.bfloat16, device_map="auto"
+)
 
-Finally, we run the GPTQ one-shot algorithm to quantize all linear weights to 4 bit
-channelwise.
+dataset = "ultrachat-200k"
+splits = {"calibration": "train_gen[:5%]", "train": "train_gen"}
 
-## How to Run
+recipe = "2:4_w4a16_recipe.yaml"
+```
 
-We can run the entire staged recipe with one call to SparseML's `apply` pathway. This
-will save a checkpoint of the model after each stage.
+## Step 2: Run sparsification using `apply`
+The `apply` function applies the given recipe to our model and dataset.
+The hardcoded kwargs may be altered based on each model's needs.
+After running, the sparsified model will be saved to `output_llama7b_2:4_w4a16_channel`.
+
+```python
+from sparseml.transformers import apply
+
+output_dir = "output_llama7b_2:4_w4a16_channel"
+
+apply(
+    model=model,
+    dataset=dataset,
+    recipe=recipe,
+    bf16=False,  # use full precision for training
+    output_dir=output_dir,
+    splits=splits,
+    max_seq_length=512,
+    num_calibration_samples=512,
+    num_train_epochs=0.5,
+    logging_steps=500,
+    save_steps=5000,
+    gradient_checkpointing=True,
+    learning_rate=0.0001,
+    lr_scheduler_type="cosine",
+    warmup_ratio=0.1,
+)
+```
 
-```python examples/llama7b_sparse_quantized/llama7b_sparse_w4a16.py```
 
-### Compression
+### Step 3: Compression
 
 The resulting model will be uncompressed. To save a final compressed copy of the model
 run the following:
 
-```
+```python
 import torch
 from sparseml.transformers import SparseAutoModelForCausalLM
 
+compressed_output_dir = "output_llama7b_2:4_w4a16_channel_compressed"
 model = SparseAutoModelForCausalLM.from_pretrained(output_dir, torch_dtype=torch.bfloat16)
 model.save_pretrained(compressed_output_dir, save_compressed=True)
 ```
 
 ### Custom Quantization
 The current repo supports multiple quantization techniques configured using a recipe. Supported strategies are `tensor`, `group` and `channel`.
 The above recipe (`2:4_w4a16_recipe.yaml`) uses channel-wise quantization specified by `strategy: "channel"` in its config group.
-To use quantize per tensor, change strategy from `channel` to `tensor`. To use group size quantization, change from `channel` to `group` and specify its value, say 128, by including `group_size: 128`. Group size quantization example is shown in `2:4_w4a16_group-128_recipe.yaml`
+To quantize per tensor, change the strategy from `channel` to `tensor`. To use group-size quantization, change from `channel` to `group` and specify its value, say 128, by including `group_size: 128`. A group-size quantization example is shown in `2:4_w4a16_group-128_recipe.yaml`.
````
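For readers who prefer configuring the group-size variant from Python rather than editing the recipe YAML, here is a hedged sketch (not part of the commit) that mirrors the `GPTQModifier` construction used in the updated tests below; the variable name and the `group_size=128` / `group_0` choices are illustrative assumptions:

```python
# Hypothetical sketch, not part of this commit: a GPTQModifier configured for
# 4-bit group-size quantization (group_size=128) instead of channel-wise,
# mirroring recipe_modifier_full from tests/sparseml/transformers/gptq/test_oneshot.py.
from compressed_tensors.quantization import QuantizationArgs, QuantizationScheme
from sparseml.modifiers.quantization.gptq import GPTQModifier

group_quant_modifier = GPTQModifier(
    ignore=["lm_head"],
    sequential_update=False,
    config_groups={
        "group_0": QuantizationScheme(
            targets=["Linear"],
            # strategy "group" with group_size replaces strategy "channel"
            weights=QuantizationArgs(num_bits=4, strategy="group", group_size=128),
        )
    },
)
```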

src/sparseml/modifiers/quantization/gptq/base.py

Lines changed: 22 additions & 20 deletions
```diff
@@ -18,9 +18,9 @@
 from pydantic import Field
 
 from compressed_tensors.quantization import (
-    QuantizationConfig,
     QuantizationScheme,
     is_preset_scheme,
+    preset_name_to_scheme,
 )
 from sparseml.core import Modifier
 from sparseml.core.factory import ModifierFactory
@@ -77,6 +77,7 @@ class GPTQModifier(Modifier):
         QuantizationScheme except targets, which will be set to the targets parameter
         set at the modifier level. Can also be set to a dictionary of the format
         `preset_scheme_name: targets` for example: `W8A8: ['Linear']` for weight 8 bit
+        or a string of a preset scheme if targets is provided
         and activation 8 bit quantization on the Linear layers.
     """
 
@@ -89,7 +90,7 @@
     ignore: List[str] = Field(default_factory=list)
     disable_quantization_observer_epoch: Optional[float] = None
     num_calibration_steps: Optional[int] = None
-    scheme: Optional[Dict[str, Any]] = None
+    scheme: Optional[Union[str, Dict[str, Any]]] = None
     compressible_layers_: Optional[List] = None
     quantization_modifier_: Any = None
 
@@ -167,32 +168,33 @@ def _build_quant_modifier(self, framework):
             if getattr(self, key, False)
         }
 
+        if isinstance(self.targets, str):
+            self.targets = [self.targets]
+
         if self.scheme is not None:
             # takes precedence over config_groups
 
-            if any(is_preset_scheme(key) for key in self.scheme.keys()):
-                config_groups = QuantizationConfig(
-                    config_groups=self.scheme
-                ).config_groups
-                quant_args["config_groups"] = config_groups
-            else:
-                targets = self.targets or ["Linear"]
-                config_group = QuantizationScheme.model_validate(
-                    {"targets": targets, **self.scheme}
-                )
-                quant_args["config_groups"] = {"config_group_0": config_group}
+            if isinstance(self.scheme, str) and is_preset_scheme(self.scheme):
+                # attach targets to scheme
+                self.scheme = {self.scheme: self.targets}
 
-            targets = self.targets or ["Linear"]
-            config_group = QuantizationScheme.model_validate(
-                {"targets": targets, **self.scheme}
-            )
-            quant_args["config_groups"] = {"config_group_0": config_group}
+            quant_args["config_groups"] = {}
+            for idx, key in enumerate(self.scheme.keys()):
+                if is_preset_scheme(key):
+                    scheme = preset_name_to_scheme(key, self.scheme[key])
+                else:
+                    scheme = QuantizationScheme.model_validate(
+                        {"targets": self.scheme[key], **self.scheme}
+                    )
+
+                group_name = f"group_{idx}"
+                quant_args["config_groups"][group_name] = scheme
 
-        if "config_groups" not in quant_args:
+        if "config_groups" not in quant_args or len("config_groups") == 0:
             default_quant_scheme = QuantizationScheme.default_scheme(
                 targets=self.targets
             )
-            quant_args["config_groups"] = {"config_group_0": default_quant_scheme}
+            quant_args["config_groups"] = {"group_0": default_quant_scheme}
         _LOGGER.info(f"Building quantization modifier with args: {quant_args}")
         vllm_quant_config = {"QuantizationModifier": quant_args}
         self._build_quant_modifier_from_dict(vllm_quant_config, framework)
```
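As a hedged illustration of what the refactored scheme handling above enables (grounded in the parameterized cases added to `tests/sparseml/transformers/gptq/test_oneshot.py` below), the following shorthand constructions are intended to resolve to the same single config group; the variable names are illustrative:

```python
# Sketch of the scheme shorthands exercised by the new tests below; both forms
# should resolve to a single "group_0" config group targeting Linear layers.
from sparseml.modifiers.quantization.gptq import GPTQModifier

# preset scheme name as a string, with targets given separately
shorthand_a = GPTQModifier(
    ignore=["lm_head"], sequential_update=False, targets="Linear", scheme="W4A16"
)

# preset scheme name mapped directly to its target layers
shorthand_b = GPTQModifier(
    ignore=["lm_head"], sequential_update=False, scheme={"W4A16": ["Linear"]}
)
```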

tests/sparseml/pytorch/modifiers/pruning/sparsegpt/test_pytorch.py

Lines changed: 1 addition & 1 deletion
```diff
@@ -95,7 +95,7 @@ def test_create_default_quant_modifier(self):
         modifier.on_initialize_structure(testing_harness.get_state())
         assert modifier.quantize
         assert isinstance(modifier.quantization_modifier_, QuantizationModifier)
-        default_config_group_name = "config_group_0"
+        default_config_group_name = "group_0"
         should_be_default_quant_scheme = modifier.quantization_modifier_.config_groups[
             default_config_group_name
         ]
```

tests/sparseml/transformers/gptq/test_oneshot.py

Lines changed: 60 additions & 20 deletions
```diff
@@ -16,11 +16,57 @@
 import shutil
 import unittest
 
+from compressed_tensors.quantization import QuantizationArgs, QuantizationScheme
+from parameterized import parameterized_class
+from sparseml.modifiers.quantization.gptq import GPTQModifier
 from sparseml.transformers.sparsification.sparse_model import SparseAutoModelForCausalLM
 from tests.testing_utils import requires_torch
 
 
+recipe_str = """
+quant_stage:
+    quant_modifiers:
+        GPTQModifier:
+            sequential_update: false
+            ignore: ["lm_head"]
+            config_groups:
+                group_0:
+                    weights:
+                        num_bits: 4
+                        type: "int"
+                        symmetric: true
+                        strategy: "channel"
+                    targets: ["Linear"]
+"""
+
+recipe_modifier_full = GPTQModifier(
+    ignore=["lm_head"],
+    sequential_update=False,
+    config_groups={
+        "group_0": QuantizationScheme(
+            targets=["Linear"], weights=QuantizationArgs(num_bits=4, strategy="channel")
+        )
+    },
+)
+
+recipe_modifier_shorthand_a = GPTQModifier(
+    ignore=["lm_head"], sequential_update=False, targets="Linear", scheme="W4A16"
+)
+
+recipe_modifier_shorthand_b = GPTQModifier(
+    ignore=["lm_head"], sequential_update=False, scheme={"W4A16": ["Linear"]}
+)
+
+
 @requires_torch
+@parameterized_class(
+    [
+        {"recipe": recipe_str},
+        {"recipe": recipe_modifier_full},
+        {"recipe": recipe_modifier_shorthand_a},
+        {"recipe": recipe_modifier_shorthand_b},
+    ]
+)
 class TestGPTQOneShotWithFullScheme(unittest.TestCase):
     def setUp(self):
         import torch
@@ -30,26 +76,6 @@ def setUp(self):
         self.dataset = "open_platypus"
         self.device = "cuda:0" if torch.cuda.is_available() else "cpu"
 
-        self.recipe = """
-        first_stage:
-            quant_modifiers:
-                GPTQModifier:
-                    ignore: ["lm_head"]
-                    sequential_update: True
-                    dampening_frac: 0.001
-                    block_size: 128
-                    targets: ["Linear"]
-                    scheme:
-                        input_activations: null
-                        output_activations: null
-                        weights:
-                            num_bits: 8
-                            type: "int"
-                            symmetric: true
-                            strategy: "tensor"
-                            group_size: 128
-        """
-
     def test_oneshot_application(self):
         from sparseml.transformers import oneshot
@@ -68,9 +94,23 @@ def test_oneshot_application(self):
         # Check that the model is quantized
         assert model_loaded.quantization_config is not None
 
+        # check config is set properly
+        assert model_loaded.quantization_config.ignore == ["lm_head"]
+        assert len(model_loaded.quantization_config.config_groups) == 1
+        quant_scheme = model_loaded.quantization_config.config_groups["group_0"]
+        assert isinstance(quant_scheme, QuantizationScheme)
+        assert quant_scheme.targets == ["Linear"]
+        weight_args = model_loaded.quantization_config.config_groups["group_0"].weights
+        assert isinstance(weight_args, QuantizationArgs)
+        assert weight_args.num_bits == 4
+
         # Check a specific layer is quantized
         targetted_linear_layer = model_loaded.transformer.h[0].attn.attention.k_proj
         assert hasattr(targetted_linear_layer, "quantization_scheme")
 
+        # Check lm-head is not quantized
+        not_targetted = model_loaded.lm_head
+        assert not hasattr(not_targetted, "quantization_scheme")
+
     def tearDown(self):
         shutil.rmtree(self.output)
```

0 commit comments
