FP8 Quantization Support #2306

Closed
wants to merge 14 commits
39 changes: 39 additions & 0 deletions examples/llama7b_fp8_quantization.py
@@ -0,0 +1,39 @@
import torch
from datasets import load_dataset
from transformers import AutoTokenizer

from sparseml.modifiers import GPTQModifier
from sparseml.transformers import SparseAutoModelForCausalLM, oneshot


model_stub = "meta-llama/Meta-Llama-3-8B-Instruct"
output_dir = "Meta-Llama-3-8B-Instruct-FP8-Compressed"
num_calibration_samples = 512

tokenizer = AutoTokenizer.from_pretrained(model_stub, use_fast=True)
tokenizer.pad_token = tokenizer.eos_token


def preprocess(batch):
    text = tokenizer.apply_chat_template(batch["messages"], tokenize=False)
    tokenized = tokenizer(text, padding=True, truncation=True, max_length=2048)
    return tokenized


ds = load_dataset("mgoin/ultrachat_2k", split="train_sft")
examples = ds.map(preprocess, remove_columns=ds.column_names)

recipe = GPTQModifier(targets=["Linear"], scheme="FP8", ignore=["lm_head"])

model = SparseAutoModelForCausalLM.from_pretrained(
    model_stub, torch_dtype=torch.bfloat16, device_map="auto"
)

oneshot(
    model=model,
    dataset=examples,
    recipe=recipe,
    output_dir=output_dir,
    num_calibration_samples=num_calibration_samples,
    save_compressed=True,
)
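Once the script completes, the compressed checkpoint can be reloaded through the same SparseAutoModelForCausalLM API the test file below uses. A minimal sketch, assuming the output_dir above and a single CUDA device; the prompt and generation settings are purely illustrative:

import torch
from transformers import AutoTokenizer

from sparseml.transformers import SparseAutoModelForCausalLM

# reload the compressed FP8 checkpoint written by the example above
model = SparseAutoModelForCausalLM.from_pretrained(
    "Meta-Llama-3-8B-Instruct-FP8-Compressed",
    torch_dtype="auto",
    device_map="cuda:0",
)
# load the tokenizer from the original stub, in case it was not saved
# alongside the compressed weights
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3-8B-Instruct")

# quick smoke test: generate a few tokens from the reloaded model
inputs = tokenizer("The capital of France is", return_tensors="pt").to("cuda:0")
with torch.no_grad():
    output = model.generate(**inputs, max_new_tokens=20)
print(tokenizer.decode(output[0]))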
9 changes: 4 additions & 5 deletions examples/llama7b_w8a8_quantization.py
@@ -16,19 +16,18 @@
        num_bits: 8
        type: "int"
        symmetric: true
-        strategy: "channel"
+        strategy: "tensor"
      input_activations:
        num_bits: 8
        type: "int"
        symmetric: true
-        dynamic: True
-        strategy: "token"
+        strategy: "tensor"
      targets: ["Linear"]
"""

# setting device_map to auto to spread the model evenly across all available GPUs
# load the model in as bfloat16 to save on memory and compute
model_stub = "zoo:llama2-7b-ultrachat200k_llama2_pretrain-base"
model_stub = "TinyLlama/TinyLlama-1.1B-intermediate-step-1431k-3T"
model = SparseAutoModelForCausalLM.from_pretrained(
    model_stub, torch_dtype=torch.bfloat16, device_map="auto"
)
Expand All @@ -37,7 +36,7 @@
dataset = "ultrachat-200k"

# save location for the quantized model output
output_dir = "./output_llama7b_w8a8_channel_dynamic_compressed"
output_dir = "./TEST_MAIN_BRANCH_TENSOR"

# set dataset config parameters
splits = {"calibration": "train_gen[:5%]"}
2 changes: 2 additions & 0 deletions src/sparseml/modifiers/quantization/__init__.py
@@ -13,3 +13,5 @@
# limitations under the License.

# flake8: noqa
+
+from .gptq import *
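This re-export is what lets the new example import GPTQModifier from the top-level sparseml.modifiers package. A quick sanity check, assuming the star export includes the class (hypothetical snippet, not part of the PR):

# after this change, both import paths should resolve to the same class
from sparseml.modifiers import GPTQModifier
from sparseml.modifiers.quantization import GPTQModifier as QuantGPTQModifier

assert GPTQModifier is QuantGPTQModifier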
20 changes: 12 additions & 8 deletions src/sparseml/transformers/compression/quantization_format.py
@@ -46,9 +46,11 @@ def infer_quantization_format(
        return quantization_format

    if save_compressed:
-        quant_depths = _get_quant_depths(model)
-        if quant_depths == [4]:  # save packed if everything is int4
+        quant_types = _get_quant_types(model)
+        if quant_types == ["int4"]:  # save packed if everything is int4
            return CompressionFormat.pack_quantized
+        elif quant_types == ["float8"]:
+            return CompressionFormat.float_quantized

        # otherwise just quantize to int8
        return CompressionFormat.int_quantized
@@ -57,17 +59,19 @@ def infer_quantization_format(
    return None


-def _get_quant_depths(model):
+def _get_quant_types(model):
    """
-    Gets a list of all the quantized bit depths present in model
+    Gets a list of all the quantized types present in model
    """
-    quant_depths = []
+    quant_info = []
    for _, submodule in iter_named_leaf_modules(model):
        if is_module_quantized(submodule):
            weight_scheme = submodule.quantization_scheme.weights
            if weight_scheme is not None:
                weight_bit_depth = weight_scheme.num_bits
-                if weight_bit_depth not in quant_depths:
-                    quant_depths.append(weight_bit_depth)
+                weight_type = weight_scheme.type
+                weight_info = f"{weight_type}{weight_bit_depth}"
+                if weight_info not in quant_info:
+                    quant_info.append(weight_info)

-    return quant_depths
+    return quant_info
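Taken together, each quantized weight scheme is now summarized as an f"{type}{num_bits}" string ("int4", "float8", ...), and infer_quantization_format keys the compression format off the unique set of those strings. A rough restatement of the decision logic (illustrative sketch, not the literal source):

# illustrative restatement of the format selection above
def choose_format(quant_types):
    if quant_types == ["int4"]:
        return "pack_quantized"  # uniformly int4 -> packed storage
    if quant_types == ["float8"]:
        return "float_quantized"  # uniformly fp8 -> new float format
    return "int_quantized"  # mixed schemes or int8 fall back to int storage


assert choose_format(["float8"]) == "float_quantized"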
19 changes: 19 additions & 0 deletions tests/sparseml/transformers/compression/recipes/new_quant_fp8.yaml
@@ -0,0 +1,19 @@
quant_stage:
  quant_modifiers:
    GPTQModifier:
      sequential_update: false
      ignore: ["lm_head"]
      config_groups:
        group_0:
          weights:
            num_bits: 8
            type: "float"
            symmetric: true
            strategy: channel
          input_activations:
            num_bits: 8
            type: "float"
            symmetric: true
            dynamic: true
            strategy: token
          targets: ["Linear"]
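This recipe is the YAML counterpart of the programmatic form used in the new example script. Assuming the preset scheme="FP8" expands to the same 8-bit float weights and dynamic per-token float activations configured above, the two should be roughly interchangeable:

from sparseml.modifiers import GPTQModifier

# assumed equivalent of the YAML recipe above; the "FP8" preset is
# expected to cover the float-8 weight and activation settings
recipe = GPTQModifier(targets=["Linear"], scheme="FP8", ignore=["lm_head"])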
184 changes: 184 additions & 0 deletions tests/sparseml/transformers/compression/test_fp8.py
@@ -0,0 +1,184 @@
# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import shutil
import tempfile
import unittest

import torch
from torch.utils.data import DataLoader
from transformers import DefaultDataCollator

from compressed_tensors.quantization.utils import is_module_quantized
from parameterized import parameterized_class
from sparseml.pytorch.utils import tensors_to_device
from sparseml.transformers import (
    SparseAutoModelForCausalLM,
    SparseAutoTokenizer,
    oneshot,
)
from sparseml.transformers.finetune.data import TextGenerationDataset
from sparseml.transformers.finetune.data.data_args import DataTrainingArguments
from tests.testing_utils import requires_gpu, requires_torch


@requires_torch
@requires_gpu
@parameterized_class(
    ("recipe", "ppl_threshold"),
    [("tests/sparseml/transformers/compression/recipes/new_quant_fp8.yaml", 5000)],
)
class TestQuantizationMatches(unittest.TestCase):
    recipe = None
    ppl_threshold = None
    # TODO: use "TinyLlama/TinyLlama-1.1B-intermediate-step-1431k-3T" for nightly
    # or weekly runs, but this smaller model is better for commit testing
    model_stub = "Xenova/llama2.c-stories15M"
    dataset = "ultrachat-200k"
    output = "tiny_llama_out"
    max_seq_length = 512
    weight_dtype = torch.float16
    num_eval = 64

    @classmethod
    def setUpClass(cls):
        cls.test_dir = tempfile.mkdtemp()

        cls.model = SparseAutoModelForCausalLM.from_pretrained(
            cls.model_stub, torch_dtype=cls.weight_dtype, device_map="cuda:0"
        )
        cls._run_oneshot(
            cls.model,
            cls.recipe,
            cls.dataset,
            os.path.join(cls.test_dir, cls.output),
        )

    @classmethod
    def tearDownClass(cls):
        shutil.rmtree(cls.test_dir)
        del cls.model
        torch.cuda.empty_cache()

    @staticmethod
    def _run_oneshot(model, recipe, dataset, output_dir):
        num_calibration_samples = 512
        max_seq_length = 512
        pad_to_max_length = False

        oneshot(
            model=model,
            dataset=dataset,
            overwrite_output_dir=True,
            output_dir=output_dir,
            max_seq_length=max_seq_length,
            num_calibration_samples=num_calibration_samples,
            recipe=recipe,
            pad_to_max_length=pad_to_max_length,
            clear_sparse_session=True,
            splits={"calibration": "train_gen[:5%]"},
        )

    def _get_quant_info(self, model):
        quant_info_weights = {}
        quant_info_inputs = {}
        for name, module in model.named_modules():
            if is_module_quantized(module):
                if module.quantization_scheme.weights is not None:
                    quant_info_weights[name] = (
                        module.weight_scale,
                        module.weight_zero_point,
                        module.weight,
                    )

                if module.quantization_scheme.input_activations is not None:
                    is_dynamic = module.quantization_scheme.input_activations.dynamic
                    if not is_dynamic:
                        quant_info_inputs[name] = (
                            module.input_scale,
                            module.input_zero_point,
                        )

        return quant_info_weights, quant_info_inputs

    def test_quantization_reload(self):
        model_reloaded = SparseAutoModelForCausalLM.from_pretrained(
            os.path.join(self.test_dir, self.output),
            torch_dtype="auto",
            device_map="cuda:0",
        )

        og_weights, og_inputs = self._get_quant_info(self.model)
        reloaded_weights, reloaded_inputs = self._get_quant_info(model_reloaded)

        for name, (o_scale, o_zp, o_weight) in og_weights.items():
            n_scale, n_zp, n_weight = reloaded_weights[name]
            assert o_scale.dtype == n_scale.dtype == self.weight_dtype
            assert torch.equal(o_scale, n_scale)
            assert o_zp.dtype == n_zp.dtype == torch.float8_e4m3fn
            assert torch.equal(o_zp, n_zp)

            # we don't expect an exact match here because o_weight still has the
            # original weight and n_weight has been fake_quantized
            assert n_weight.dtype == o_weight.dtype == self.weight_dtype

        for name, (o_scale, o_zp) in og_inputs.items():
            n_scale, n_zp = reloaded_inputs[name]
            assert o_scale.dtype == n_scale.dtype == self.weight_dtype
            assert torch.equal(o_scale, n_scale)
            assert o_zp.dtype == n_zp.dtype == torch.float8_e4m3fn
            assert torch.equal(o_zp, n_zp)

    def _get_dataloader(self, data_args, tokenizer):
        dataset_manager = TextGenerationDataset.load_from_registry(
            data_args.dataset,
            data_args=data_args,
            split="train_gen[:5%]",
            tokenizer=tokenizer,
        )
        calib_dataset = dataset_manager.tokenize_and_process(
            dataset_manager.get_raw_dataset()
        )
        data_loader = DataLoader(
            calib_dataset,
            batch_size=1,
            collate_fn=DefaultDataCollator(),
            sampler=torch.utils.data.RandomSampler(calib_dataset),
        )

        return data_loader

    @torch.no_grad()
    def test_perplexity(self):
        tokenizer = SparseAutoTokenizer.from_pretrained(self.model_stub)
        data_args = DataTrainingArguments(
            dataset="ultrachat-200k",
            max_seq_length=self.max_seq_length,
        )
        dataloader = self._get_dataloader(data_args, tokenizer)

        total_ppl = 0.0
        total_non_nan = 0
        for idx, sample in enumerate(dataloader):
            if idx >= self.num_eval:
                break
            output = self.model(**tensors_to_device(sample, "cuda:0"))
            if torch.isnan(output.loss):
                continue
            total_ppl += torch.exp(output.loss).item()
            total_non_nan += 1

        avg_ppl = total_ppl / total_non_nan
        assert avg_ppl <= self.ppl_threshold