This repository was archived by the owner on Jun 3, 2025. It is now read-only.

Commit 56b7854

Sara Adkins and dbogunowicz authored
Update Examples to New UX (#2301)
* update examples to use new ux
* add sparse model
* update paths
* update paths
* up samples for sparse model
* fix recipe
* remove extra files
* update sparse dtype

---------

Co-authored-by: dbogunowicz <97082108+dbogunowicz@users.noreply.github.com>
1 parent 53e98b6 commit 56b7854

6 files changed (+88, -27 lines)
Lines changed: 74 additions & 0 deletions
@@ -0,0 +1,74 @@
import torch
from datasets import load_dataset

from sparseml.transformers import (
    SparseAutoModelForCausalLM,
    SparseAutoTokenizer,
    oneshot,
)


# define a sparseml recipe for GPTQ W4A16 quantization
recipe = """
quant_stage:
    quant_modifiers:
        GPTQModifier:
            sequential_update: false
            ignore: ["lm_head"]
            config_groups:
                group_0:
                    weights:
                        num_bits: 4
                        type: "int"
                        symmetric: true
                        strategy: "channel"
                    targets: ["Linear"]
"""

# load in a 50% sparse model with 2:4 sparsity structure
# setting device_map to auto to spread the model evenly across all available GPUs
model_stub = "neuralmagic/SparseLlama-2-7b-cnn-daily-mail-pruned_50.2of4"
model = SparseAutoModelForCausalLM.from_pretrained(
    model_stub, torch_dtype=torch.bfloat16, device_map="auto"
)
tokenizer = SparseAutoTokenizer.from_pretrained(model_stub)

# for quantization calibration, we will use a subset of the dataset that was used to
# sparsify and finetune the model
dataset = load_dataset("abisee/cnn_dailymail", "1.0.0", split="train[:5%]")

# set dataset config parameters
max_seq_length = 4096
pad_to_max_length = False
num_calibration_samples = 1024


# preprocess the data into a single text entry, then tokenize the dataset
def process_sample(sample):
    formatted = "Article:\n{}\n\n### Summarization:\n{}".format(
        sample["article"], sample["highlights"]
    )
    return tokenizer(
        formatted, padding=pad_to_max_length, max_length=max_seq_length, truncation=True
    )


tokenized_dataset = dataset.map(
    process_sample, remove_columns=["article", "highlights", "id"]
)

# save location of quantized model out
output_dir = "./llama7b_sparse_24_w4a16_channel_compressed"

# apply quantization recipe to the model and save quantized output in int4 packed format
# the sparsity structure of the original model will be maintained
oneshot(
    model=model,
    dataset=tokenized_dataset,
    recipe=recipe,
    output_dir=output_dir,
    max_seq_length=max_seq_length,
    pad_to_max_length=pad_to_max_length,
    num_calibration_samples=num_calibration_samples,
    save_compressed=True,
)
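Not part of the commit, but for context: a minimal sketch of how the compressed checkpoint written above might be reloaded and smoke-tested, assuming a save_compressed checkpoint can be loaded back through SparseAutoModelForCausalLM. The prompt simply mirrors the calibration template above, and the sample article text is a made-up placeholder.

import torch
from sparseml.transformers import SparseAutoModelForCausalLM, SparseAutoTokenizer

# illustrative only: path taken from the example above, article text is a placeholder
compressed_dir = "./llama7b_sparse_24_w4a16_channel_compressed"
model = SparseAutoModelForCausalLM.from_pretrained(
    compressed_dir, torch_dtype=torch.bfloat16, device_map="auto"
)
tokenizer = SparseAutoTokenizer.from_pretrained(compressed_dir)

article = "The city council met on Tuesday to debate the new transit plan. " * 8
prompt = "Article:\n{}\n\n### Summarization:\n".format(article)
inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
output = model.generate(**inputs, max_new_tokens=64)
print(tokenizer.decode(output[0], skip_special_tokens=True))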

examples/llama7b_sparse_quantized/2:4_w4a16_recipe.yaml

Lines changed: 2 additions & 5 deletions
@@ -22,7 +22,8 @@ finetuning_stage:
 quantization_stage:
   run_type: oneshot
   quantization_modifiers:
-    vLLMQuantizationModifier:
+    GPTQModifier:
+      sequential_update: false
       ignore: ["lm_head"]
       config_groups:
         group_0:
@@ -32,7 +33,3 @@ quantization_stage:
             symmetric: true
             strategy: "channel"
           targets: ["Linear"]
-    SparseGPTModifier:
-      sparsity: 0.0
-      quantize: True
-      sequential_update: false
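A quick sanity check on the updated recipe (not part of the commit; assumes PyYAML is installed and that the recipe file parses as plain YAML): load the file and confirm the quantization stage now carries a single GPTQModifier.

import yaml  # PyYAML, assumed available

# path taken from the file heading above
with open("examples/llama7b_sparse_quantized/2:4_w4a16_recipe.yaml") as f:
    recipe = yaml.safe_load(f)

modifiers = recipe["quantization_stage"]["quantization_modifiers"]
assert "GPTQModifier" in modifiers                   # new single-modifier UX
assert "vLLMQuantizationModifier" not in modifiers   # removed by this commit
assert "SparseGPTModifier" not in modifiers          # removed by this commit
print(modifiers["GPTQModifier"]["config_groups"]["group_0"]["weights"])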

examples/llama7b_sparse_quantized/llama7b_sparse_w4a16.py

Lines changed: 2 additions & 2 deletions
@@ -24,12 +24,12 @@
 num_calibration_samples = 512
 
 # set training parameters for finetuning
-num_train_epochs = 1
+num_train_epochs = 0.5
 logging_steps = 500
 save_steps = 5000
 gradient_checkpointing = True # saves memory during training
 learning_rate = 0.0001
-bf16 = True # using bfloat16 for training
+bf16 = False # using full precision for training
 lr_scheduler_type = "cosine"
 warmup_ratio = 0.1
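The parameter names above line up with standard Hugging Face TrainingArguments fields. Purely as an illustration of what the updated values mean downstream (an assumption about how SparseML forwards them, not code from this commit), the equivalent plain-transformers configuration would look like:

from transformers import TrainingArguments

# illustrative mapping of the example's finetuning parameters (assumed, not from the commit)
training_args = TrainingArguments(
    output_dir="./finetune_output",  # hypothetical path
    num_train_epochs=0.5,            # halved by this commit
    logging_steps=500,
    save_steps=5000,
    gradient_checkpointing=True,     # saves memory during training
    learning_rate=0.0001,
    bf16=False,                      # full precision now that bf16 is disabled
    lr_scheduler_type="cosine",
    warmup_ratio=0.1,
)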

examples/llama7b_w4a16_quantization.ipynb

Lines changed: 5 additions & 9 deletions
@@ -25,10 +25,9 @@
    "cell_type": "markdown",
    "metadata": {},
    "source": [
-    "SparseML uses recipes to define configurations for different oneshot algorithms. Recipes can be defined as a string or a yaml file. Below we create a sample recipe for GPTQ quantization. The recipe is made up of two different algorithms, called modifiers.\n",
+    "SparseML uses recipes to define configurations for different oneshot algorithms. Recipes can be defined as a string or a yaml file. A recipe consists of one or more sparsification or quantization algorithms, called modifiers in SparseML. Below we create a sample recipe for GPTQ quantization that only requires a single modifier.\n",
     "\n",
-    "1. **vLLMQuantizationModifier**: calibrates the model for quantization by calculating scale and zero points from a small amount of calibration data\n",
-    "2. **SparseGPTModifier**: applies the GPTQ algorithm, using the result of the vLLMQuantizationModifier to determine the best quantization bin to place each linear weight into"
+    "This modifier specifies that we should quantize the weights of each linear layer to 4 bits, using a symmetric channelwise quantization pattern. The lm-head will not be quantized even though it is a Linear layer, because it is included in the ignore list."
    ]
   },
   {
@@ -37,10 +36,11 @@
    "metadata": {},
    "outputs": [],
    "source": [
-    "recipe=\"\"\"\n",
+    "recipe = \"\"\"\n",
     "quant_stage:\n",
     "    quant_modifiers:\n",
-    "        vLLMQuantizationModifier:\n",
+    "        GPTQModifier:\n",
+    "            sequential_update: false\n",
     "            ignore: [\"lm_head\"]\n",
     "            config_groups:\n",
     "                group_0:\n",
@@ -50,10 +50,6 @@
     "                        symmetric: true\n",
     "                        strategy: \"channel\"\n",
     "                    targets: [\"Linear\"]\n",
-    "        SparseGPTModifier:\n",
-    "            sparsity: 0.0\n",
-    "            quantize: True\n",
-    "            sequential_update: false\n",
     "\"\"\""
    ]
   },
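For readers who want to see concretely what "symmetric channelwise 4-bit weights" means, here is a small standalone PyTorch sketch (not from the commit): one scale per output channel, no zero point, round-to-nearest into the int4 range. GPTQ itself goes further by compensating rounding error using calibration data; this only shows the baseline scheme.

import torch


def quantize_w4_channelwise(weight: torch.Tensor):
    """Symmetric per-channel int4 quantization of a Linear weight of shape [out, in]."""
    qmax = 7  # symmetric int4 grid: integer values in [-8, 7], scale from |w|max / 7
    scales = weight.abs().amax(dim=1, keepdim=True) / qmax  # one scale per output channel
    scales = scales.clamp(min=1e-8)
    q = torch.clamp(torch.round(weight / scales), -qmax - 1, qmax).to(torch.int8)
    return q, scales


w = torch.randn(16, 32)
q, scales = quantize_w4_channelwise(w)
w_hat = q.to(torch.float32) * scales  # dequantize
print("max abs quantization error:", (w - w_hat).abs().max().item())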

examples/llama7b_w4a16_quantization.py

Lines changed: 2 additions & 5 deletions
@@ -7,7 +7,8 @@
 recipe = """
 quant_stage:
     quant_modifiers:
-        vLLMQuantizationModifier:
+        GPTQModifier:
+            sequential_update: false
             ignore: ["lm_head"]
             config_groups:
                 group_0:
@@ -17,10 +18,6 @@
                         symmetric: true
                         strategy: "channel"
                     targets: ["Linear"]
-        SparseGPTModifier:
-            sparsity: 0.0
-            quantize: true
-            sequential_update: false
 """
 
 # setting device_map to auto to spread the model evenly across all available GPUs
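Rough context for why the int4 packed ("compressed") save format matters for a ~7B-parameter model (back-of-the-envelope arithmetic, not a measurement from this commit):

# back-of-the-envelope weight-storage estimate (illustrative assumptions only)
params = 7_000_000_000
bf16_gb = params * 2 / 1e9    # 16-bit weights: ~14 GB
int4_gb = params * 0.5 / 1e9  # packed 4-bit weights: ~3.5 GB
# channelwise scales add one value per output channel, negligible next to the packed weights
print(f"bf16: ~{bf16_gb:.1f} GB, int4 packed: ~{int4_gb:.1f} GB")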

examples/llama7b_w8a8_quantization.py

Lines changed: 3 additions & 6 deletions
@@ -7,7 +7,8 @@
 recipe = """
 quant_stage:
     quant_modifiers:
-        vLLMQuantizationModifier:
+        GPTQModifier:
+            sequential_update: false
             ignore: ["lm_head"]
             config_groups:
                 group_0:
@@ -23,10 +24,6 @@
                         dynamic: True
                         strategy: "token"
                     targets: ["Linear"]
-        SparseGPTModifier:
-            sparsity: 0.0
-            quantize: true
-            sequential_update: false
 """
 
 # setting device_map to auto to spread the model evenly across all available GPUs
@@ -40,7 +37,7 @@
 dataset = "ultrachat-200k"
 
 # save location of quantized model out
-output_dir = "./output_llama7b_w8a8_channel_compressed"
+output_dir = "./output_llama7b_w8a8_channel_dynamic_compressed"
 
 # set dataset config parameters
 splits = {"calibration": "train_gen[:5%]"}
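The W8A8 recipe quantizes activations with dynamic: True and strategy: "token", meaning activation scales are computed on the fly, one per token, rather than calibrated ahead of time. A standalone PyTorch sketch of that idea (not from the commit, and much simpler than the real kernels):

import torch


def dynamic_per_token_int8(x: torch.Tensor):
    """Symmetric dynamic int8 quantization with one scale per token.

    x: activations of shape [num_tokens, hidden_dim].
    """
    qmax = 127
    scales = x.abs().amax(dim=-1, keepdim=True) / qmax  # recomputed at runtime per token
    scales = scales.clamp(min=1e-8)
    q = torch.clamp(torch.round(x / scales), -qmax - 1, qmax).to(torch.int8)
    return q, scales


x = torch.randn(4, 8) * 3.0
q, scales = dynamic_per_token_int8(x)
x_hat = q.to(torch.float32) * scales  # dequantize for comparison
print("max abs quantization error:", (x - x_hat).abs().max().item())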
