diff --git a/examples/llama7b_fp8_quantization.py b/examples/llama7b_fp8_quantization.py
new file mode 100644
index 00000000000..e8979737cbf
--- /dev/null
+++ b/examples/llama7b_fp8_quantization.py
@@ -0,0 +1,39 @@
+import torch
+from datasets import load_dataset
+from transformers import AutoTokenizer
+
+from sparseml.modifiers import GPTQModifier
+from sparseml.transformers import SparseAutoModelForCausalLM, oneshot
+
+
+model_stub = "meta-llama/Meta-Llama-3-8B-Instruct"
+output_dir = "Meta-Llama-3-8B-Instruct-FP8-Compressed"
+num_calibration_samples = 512
+
+tokenizer = AutoTokenizer.from_pretrained(model_stub, use_fast=True)
+tokenizer.pad_token = tokenizer.eos_token
+
+
+def preprocess(batch):
+    text = tokenizer.apply_chat_template(batch["messages"], tokenize=False)
+    tokenized = tokenizer(text, padding=True, truncation=True, max_length=2048)
+    return tokenized
+
+
+ds = load_dataset("mgoin/ultrachat_2k", split="train_sft")
+examples = ds.map(preprocess, remove_columns=ds.column_names)
+
+recipe = GPTQModifier(targets=["Linear"], scheme="FP8", ignore=["lm_head"])
+
+model = SparseAutoModelForCausalLM.from_pretrained(
+    model_stub, torch_dtype=torch.bfloat16, device_map="auto"
+)
+
+oneshot(
+    model=model,
+    dataset=examples,
+    recipe=recipe,
+    output_dir=output_dir,
+    num_calibration_samples=num_calibration_samples,
+    save_compressed=True,
+)
diff --git a/examples/llama7b_w8a8_quantization.py b/examples/llama7b_w8a8_quantization.py
index c894613ffbb..6ffd4b0f623 100644
--- a/examples/llama7b_w8a8_quantization.py
+++ b/examples/llama7b_w8a8_quantization.py
@@ -16,19 +16,18 @@
                         num_bits: 8
                         type: "int"
                         symmetric: true
-                        strategy: "channel"
+                        strategy: "tensor"
                     input_activations:
                         num_bits: 8
                         type: "int"
                         symmetric: true
-                        dynamic: True
-                        strategy: "token"
+                        strategy: "tensor"
                     targets: ["Linear"]
 """
 
 # setting device_map to auto to spread the model evenly across all available GPUs
 # load the model in as bfloat16 to save on memory and compute
-model_stub = "zoo:llama2-7b-ultrachat200k_llama2_pretrain-base"
+model_stub = "TinyLlama/TinyLlama-1.1B-intermediate-step-1431k-3T"
 model = SparseAutoModelForCausalLM.from_pretrained(
     model_stub, torch_dtype=torch.bfloat16, device_map="auto"
 )
@@ -37,7 +36,7 @@
 dataset = "ultrachat-200k"
 
 # save location of quantized model out
-output_dir = "./output_llama7b_w8a8_channel_dynamic_compressed"
+output_dir = "./TEST_MAIN_BRANCH_TENSOR"
 
 # set dataset config parameters
 splits = {"calibration": "train_gen[:5%]"}
diff --git a/src/sparseml/modifiers/quantization/__init__.py b/src/sparseml/modifiers/quantization/__init__.py
index ebdf28a6d5b..fe2676473d5 100644
--- a/src/sparseml/modifiers/quantization/__init__.py
+++ b/src/sparseml/modifiers/quantization/__init__.py
@@ -13,3 +13,5 @@
 # limitations under the License.
 
 # flake8: noqa
+
+from .gptq import *
diff --git a/src/sparseml/transformers/compression/quantization_format.py b/src/sparseml/transformers/compression/quantization_format.py
index 48d47d68d22..8b099b7e36b 100644
--- a/src/sparseml/transformers/compression/quantization_format.py
+++ b/src/sparseml/transformers/compression/quantization_format.py
@@ -46,9 +46,11 @@ def infer_quantization_format(
         return quantization_format
 
     if save_compressed:
-        quant_depths = _get_quant_depths(model)
-        if quant_depths == [4]:  # save packed if everything is int4
+        quant_types = _get_quant_types(model)
+        if quant_types == ["int4"]:  # save packed if everything is int4
             return CompressionFormat.pack_quantized
+        elif quant_types == ["float8"]:
+            return CompressionFormat.float_quantized
 
         # otherwise just quantize to int8
         return CompressionFormat.int_quantized
@@ -57,17 +59,19 @@
         return None
 
 
-def _get_quant_depths(model):
+def _get_quant_types(model):
     """
-    Gets a list of all the quantized bit depths present in model
+    Gets a list of all the quantized types present in model
     """
-    quant_depths = []
+    quant_info = []
     for _, submodule in iter_named_leaf_modules(model):
         if is_module_quantized(submodule):
             weight_scheme = submodule.quantization_scheme.weights
             if weight_scheme is not None:
                 weight_bit_depth = weight_scheme.num_bits
-                if weight_bit_depth not in quant_depths:
-                    quant_depths.append(weight_bit_depth)
+                weight_type = weight_scheme.type
+                weight_info = f"{weight_type}{weight_bit_depth}"
+                if weight_info not in quant_info:
+                    quant_info.append(weight_info)
 
-    return quant_depths
+    return quant_info
diff --git a/tests/sparseml/transformers/compression/recipes/new_quant_fp8.yaml b/tests/sparseml/transformers/compression/recipes/new_quant_fp8.yaml
new file mode 100644
index 00000000000..54b24871663
--- /dev/null
+++ b/tests/sparseml/transformers/compression/recipes/new_quant_fp8.yaml
@@ -0,0 +1,19 @@
+quant_stage:
+    quant_modifiers:
+        GPTQModifier:
+            sequential_update: false
+            ignore: ["lm_head"]
+            config_groups:
+                group_0:
+                    weights:
+                        num_bits: 8
+                        type: "float"
+                        symmetric: true
+                        strategy: channel
+                    input_activations:
+                        num_bits: 8
+                        type: "float"
+                        symmetric: true
+                        dynamic: true
+                        strategy: token
+                    targets: ["Linear"]
\ No newline at end of file
diff --git a/tests/sparseml/transformers/compression/test_fp8.py b/tests/sparseml/transformers/compression/test_fp8.py
new file mode 100644
index 00000000000..aa98b2e39f5
--- /dev/null
+++ b/tests/sparseml/transformers/compression/test_fp8.py
@@ -0,0 +1,184 @@
+# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+import shutil
+import tempfile
+import unittest
+
+import torch
+from torch.utils.data import DataLoader
+from transformers import DefaultDataCollator
+
+from compressed_tensors.quantization.utils import is_module_quantized
+from parameterized import parameterized_class
+from sparseml.pytorch.utils import tensors_to_device
+from sparseml.transformers import (
+    SparseAutoModelForCausalLM,
+    SparseAutoTokenizer,
+    oneshot,
+)
+from sparseml.transformers.finetune.data import TextGenerationDataset
+from sparseml.transformers.finetune.data.data_args import DataTrainingArguments
+from tests.testing_utils import requires_gpu, requires_torch
+
+
+@requires_torch
+@requires_gpu
+@parameterized_class(
+    ("recipe", "ppl_threshold"),
+    [("tests/sparseml/transformers/compression/recipes/new_quant_fp8.yaml", 5000)],
+)
+class TestQuantizationMatches(unittest.TestCase):
+    recipe = None
+    ppl_threshold = None
+    # TODO: use "TinyLlama/TinyLlama-1.1B-intermediate-step-1431k-3T" for nightly
+    # or weekly runs, but this smaller model is better for commit testing
+    model_stub = "Xenova/llama2.c-stories15M"
+    dataset = "ultrachat-200k"
+    output = "tiny_llama_out"
+    max_seq_length = 512
+    weight_dtype = torch.float16
+    num_eval = 64
+
+    @classmethod
+    def setUpClass(cls):
+        cls.test_dir = tempfile.mkdtemp()
+
+        cls.model = SparseAutoModelForCausalLM.from_pretrained(
+            cls.model_stub, torch_dtype=cls.weight_dtype, device_map="cuda:0"
+        )
+        cls._run_oneshot(
+            cls.model,
+            cls.recipe,
+            cls.dataset,
+            os.path.join(cls.test_dir, cls.output),
+        )
+
+    @classmethod
+    def tearDownClass(cls):
+        shutil.rmtree(cls.test_dir)
+        del cls.model
+        torch.cuda.empty_cache()
+
+    @staticmethod
+    def _run_oneshot(model, recipe, dataset, output_dir):
+        num_calibration_samples = 512
+        max_seq_length = 512
+        pad_to_max_length = False
+
+        oneshot(
+            model=model,
+            dataset=dataset,
+            overwrite_output_dir=True,
+            output_dir=output_dir,
+            max_seq_length=max_seq_length,
+            num_calibration_samples=num_calibration_samples,
+            recipe=recipe,
+            pad_to_max_length=pad_to_max_length,
+            clear_sparse_session=True,
+            splits={"calibration": "train_gen[:5%]"},
+        )
+
+    def _get_quant_info(self, model):
+        quant_info_weights = {}
+        quant_info_inputs = {}
+        for name, module in model.named_modules():
+            if is_module_quantized(module):
+                if module.quantization_scheme.weights is not None:
+                    quant_info_weights[name] = (
+                        module.weight_scale,
+                        module.weight_zero_point,
+                        module.weight,
+                    )
+
+                if module.quantization_scheme.input_activations is not None:
+                    is_dynamic = module.quantization_scheme.input_activations.dynamic
+                    if not is_dynamic:
+                        quant_info_inputs[name] = (
+                            module.input_scale,
+                            module.input_zero_point,
+                        )
+
+        return quant_info_weights, quant_info_inputs
+
+    def test_quantization_reload(self):
+        model_reloaded = SparseAutoModelForCausalLM.from_pretrained(
+            os.path.join(self.test_dir, self.output),
+            torch_dtype="auto",
+            device_map="cuda:0",
+        )
+
+        og_weights, og_inputs = self._get_quant_info(self.model)
+        reloaded_weights, reloaded_inputs = self._get_quant_info(model_reloaded)
+
+        for name, (o_scale, o_zp, o_weight) in og_weights.items():
+            n_scale, n_zp, n_weight = reloaded_weights[name]
+            assert o_scale.dtype == n_scale.dtype == self.weight_dtype
+            assert torch.equal(o_scale, n_scale)
+            assert o_zp.dtype == n_zp.dtype == torch.float8_e4m3fn
+            assert torch.equal(o_zp, n_zp)
+
+            # we don't expect an exact match here because o_weight still has the
+            # original weight and n_weight has been fake_quantized
+            assert n_weight.dtype == o_weight.dtype == self.weight_dtype
+
+        for name, (o_scale, o_zp) in og_inputs.items():
+            n_scale, n_zp = reloaded_inputs[name]
+            assert o_scale.dtype == n_scale.dtype == self.weight_dtype
+            assert torch.equal(o_scale, n_scale)
+            assert o_zp.dtype == n_zp.dtype == torch.float8_e4m3fn
+            assert torch.equal(o_zp, n_zp)
+
+    def _get_dataloader(self, data_args, tokenizer):
+        dataset_manager = TextGenerationDataset.load_from_registry(
+            data_args.dataset,
+            data_args=data_args,
+            split="train_gen[:5%]",
+            tokenizer=tokenizer,
+        )
+        calib_dataset = dataset_manager.tokenize_and_process(
+            dataset_manager.get_raw_dataset()
+        )
+        data_loader = DataLoader(
+            calib_dataset,
+            batch_size=1,
+            collate_fn=DefaultDataCollator(),
+            sampler=torch.utils.data.RandomSampler(calib_dataset),
+        )
+
+        return data_loader
+
+    @torch.no_grad()
+    def test_perplexity(self):
+        tokenizer = SparseAutoTokenizer.from_pretrained(self.model_stub)
+        data_args = DataTrainingArguments(
+            dataset="ultrachat-200k",
+            max_seq_length=self.max_seq_length,
+        )
+        dataloader = self._get_dataloader(data_args, tokenizer)
+
+        total_ppl = 0.0
+        total_non_nan = 0
+        for idx, sample in enumerate(dataloader):
+            if idx >= self.num_eval:
+                break
+            output = self.model(**tensors_to_device(sample, "cuda:0"))
+            if torch.isnan(output.loss):
+                continue
+            total_ppl += torch.exp(output.loss).item()
+            total_non_nan += 1
+
+        avg_ppl = total_ppl / total_non_nan
+        assert avg_ppl <= self.ppl_threshold
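
Not part of the patch above: a minimal usage sketch, assuming the FP8 example script has been run and its output_dir exists locally, showing how the compressed checkpoint could be reloaded and its per-layer weight scheme inspected. It reuses only calls that appear in this diff (SparseAutoModelForCausalLM.from_pretrained with torch_dtype="auto", is_module_quantized, and the quantization_scheme.weights fields); the loop and the printed fields are illustrative, not part of the change.

    # Usage sketch (not part of the PR): reload the FP8-compressed model saved by
    # examples/llama7b_fp8_quantization.py and print each quantized layer's scheme.
    from compressed_tensors.quantization.utils import is_module_quantized

    from sparseml.transformers import SparseAutoModelForCausalLM

    # Assumed local path: the output_dir used in the FP8 example above.
    output_dir = "Meta-Llama-3-8B-Instruct-FP8-Compressed"

    # torch_dtype="auto" keeps the dtypes stored in the compressed checkpoint,
    # mirroring how test_fp8.py reloads its quantized model.
    model = SparseAutoModelForCausalLM.from_pretrained(
        output_dir, torch_dtype="auto", device_map="auto"
    )

    # Every quantized Linear except lm_head should report 8-bit float weights.
    for name, module in model.named_modules():
        if is_module_quantized(module) and module.quantization_scheme.weights is not None:
            weights = module.quantization_scheme.weights
            print(f"{name}: type={weights.type}, num_bits={weights.num_bits}")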