From d03bdfe25d3587e985b4e017d846b535749a3a66 Mon Sep 17 00:00:00 2001
From: Sara Adkins
Date: Fri, 17 May 2024 17:45:39 +0000
Subject: [PATCH 1/6] infer fp8 format

---
 .../compression/quantization_format.py | 31 ++++++++++++++++++-
 1 file changed, 30 insertions(+), 1 deletion(-)

diff --git a/src/sparseml/transformers/compression/quantization_format.py b/src/sparseml/transformers/compression/quantization_format.py
index 5f8f8722753..8b099b7e36b 100644
--- a/src/sparseml/transformers/compression/quantization_format.py
+++ b/src/sparseml/transformers/compression/quantization_format.py
@@ -16,7 +16,11 @@
 from typing import Optional
 
 from compressed_tensors import CompressionFormat
-from compressed_tensors.quantization.utils import is_model_quantized
+from compressed_tensors.quantization.utils import (
+    is_model_quantized,
+    is_module_quantized,
+    iter_named_leaf_modules,
+)
 
 
 __all__ = ["infer_quantization_format"]
@@ -42,7 +46,32 @@ def infer_quantization_format(
         return quantization_format
 
     if save_compressed:
+        quant_types = _get_quant_types(model)
+        if quant_types == ["int4"]:  # save packed if everything is int4
+            return CompressionFormat.pack_quantized
+        elif quant_types == ["float8"]:
+            return CompressionFormat.float_quantized
+
+        # otherwise just quantize to int8
         return CompressionFormat.int_quantized
     else:
         # format will be inferred from config
         return None
+
+
+def _get_quant_types(model):
+    """
+    Gets a list of all the quantized types present in model
+    """
+    quant_info = []
+    for _, submodule in iter_named_leaf_modules(model):
+        if is_module_quantized(submodule):
+            weight_scheme = submodule.quantization_scheme.weights
+            if weight_scheme is not None:
+                weight_bit_depth = weight_scheme.num_bits
+                weight_type = weight_scheme.type
+                weight_info = f"{weight_type}{weight_bit_depth}"
+                if weight_info not in quant_info:
+                    quant_info.append(weight_info)
+
+    return quant_info

From 214efc08ae9b1d095e39ef5f1d28067fceb00c9c Mon Sep 17 00:00:00 2001
From: Sara Adkins
Date: Tue, 28 May 2024 17:07:20 +0000
Subject: [PATCH 2/6] integration tests

---
 .../compression/recipes/new_quant_fp8.yaml |  19 ++
 .../transformers/compression/test_fp8.py   | 176 ++++++++++++++++++
 2 files changed, 195 insertions(+)
 create mode 100644 tests/sparseml/transformers/compression/recipes/new_quant_fp8.yaml
 create mode 100644 tests/sparseml/transformers/compression/test_fp8.py

diff --git a/tests/sparseml/transformers/compression/recipes/new_quant_fp8.yaml b/tests/sparseml/transformers/compression/recipes/new_quant_fp8.yaml
new file mode 100644
index 00000000000..54b24871663
--- /dev/null
+++ b/tests/sparseml/transformers/compression/recipes/new_quant_fp8.yaml
@@ -0,0 +1,19 @@
+quant_stage:
+    quant_modifiers:
+        GPTQModifier:
+            sequential_update: false
+            ignore: ["lm_head"]
+            config_groups:
+                group_0:
+                    weights:
+                        num_bits: 8
+                        type: "float"
+                        symmetric: true
+                        strategy: channel
+                    input_activations:
+                        num_bits: 8
+                        type: "float"
+                        symmetric: true
+                        dynamic: true
+                        strategy: token
+                    targets: ["Linear"]
\ No newline at end of file
diff --git a/tests/sparseml/transformers/compression/test_fp8.py b/tests/sparseml/transformers/compression/test_fp8.py
new file mode 100644
index 00000000000..74403bc40bf
--- /dev/null
+++ b/tests/sparseml/transformers/compression/test_fp8.py
@@ -0,0 +1,176 @@
+# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+import shutil
+import tempfile
+import unittest
+
+import torch
+from torch.utils.data import DataLoader
+from transformers import DefaultDataCollator
+
+from compressed_tensors.quantization.utils import is_module_quantized
+from parameterized import parameterized_class
+from sparseml.pytorch.utils import tensors_to_device
+from sparseml.transformers import (
+    SparseAutoModelForCausalLM,
+    SparseAutoTokenizer,
+    oneshot,
+)
+from sparseml.transformers.finetune.data import TextGenerationDataset
+from sparseml.transformers.finetune.data.data_args import DataTrainingArguments
+from tests.testing_utils import requires_gpu, requires_torch
+
+
+@requires_torch
+@requires_gpu
+@parameterized_class(
+    ("recipe", "ppl_threshold"),
+    [("tests/sparseml/transformers/compression/recipes/new_quant_fp8.yaml", 5000)],
+)
+class TestQuantizationMatches(unittest.TestCase):
+    recipe = None
+    ppl_threshold = None
+    # TODO: use "TinyLlama/TinyLlama-1.1B-intermediate-step-1431k-3T" for nightly
+    # or weekly runs, but this smaller model is better for commit testing
+    model_stub = "Xenova/llama2.c-stories15M"
+    dataset = "ultrachat-200k"
+    output = "tiny_llama_out"
+    max_seq_length = 512
+    num_eval = 64
+
+    @classmethod
+    def setUpClass(cls):
+        cls.test_dir = tempfile.mkdtemp()
+
+        cls.model = SparseAutoModelForCausalLM.from_pretrained(
+            cls.model_stub, device_map="cuda:0"
+        )
+        cls._run_oneshot(
+            cls.model,
+            cls.recipe,
+            cls.dataset,
+            os.path.join(cls.test_dir, cls.output),
+        )
+
+    @classmethod
+    def tearDownClass(cls):
+        shutil.rmtree(cls.test_dir)
+        del cls.model
+        torch.cuda.empty_cache()
+
+    @staticmethod
+    def _run_oneshot(model, recipe, dataset, output_dir):
+        num_calibration_samples = 512
+        max_seq_length = 512
+        pad_to_max_length = False
+
+        oneshot(
+            model=model,
+            dataset=dataset,
+            overwrite_output_dir=True,
+            output_dir=output_dir,
+            max_seq_length=max_seq_length,
+            num_calibration_samples=num_calibration_samples,
+            recipe=recipe,
+            pad_to_max_length=pad_to_max_length,
+            clear_sparse_session=True,
+            splits={"calibration": "train_gen[:5%]"},
+        )
+
+    def _get_quant_info(self, model):
+        quant_info_weights = {}
+        quant_info_inputs = {}
+        for name, module in model.named_modules():
+            if is_module_quantized(module):
+                if module.quantization_scheme.weights is not None:
+                    quant_info_weights[name] = (
+                        module.weight_scale,
+                        module.weight_zero_point,
+                    )
+
+                if module.quantization_scheme.input_activations is not None:
+                    is_dynamic = module.quantization_scheme.input_activations.dynamic
+                    if not is_dynamic:
+                        quant_info_inputs[name] = (
+                            module.input_scale,
+                            module.input_zero_point,
+                        )
+
+        return quant_info_weights, quant_info_inputs
+
+    def test_quantization_reload(self):
+        model_reloaded = SparseAutoModelForCausalLM.from_pretrained(
+            os.path.join(self.test_dir, self.output), device_map="cuda:0"
+        )
+
+        og_weights, og_inputs = self._get_quant_info(self.model)
+        reloaded_weights, reloaded_inputs = self._get_quant_info(model_reloaded)
+
+        for name, (o_scale, o_zp) in og_weights.items():
+            n_scale, n_zp = reloaded_weights[name]
+            assert o_scale.dtype == n_scale.dtype
+            assert torch.equal(o_scale, n_scale)
+            assert o_zp.dtype == n_zp.dtype
+            assert torch.equal(o_zp, n_zp)
+
+        for name, (o_scale, o_zp) in og_inputs.items():
+            n_scale, n_zp = reloaded_inputs[name]
+            assert o_scale.dtype == n_scale.dtype
+            assert torch.equal(o_scale, n_scale)
+            assert o_zp.dtype == n_zp.dtype
+            assert torch.equal(o_zp, n_zp)
+
+    def _get_dataloader(self, data_args, tokenizer):
+        dataset_manager = TextGenerationDataset.load_from_registry(
+            data_args.dataset,
+            data_args=data_args,
+            split="train_gen[:5%]",
+            tokenizer=tokenizer,
+        )
+        calib_dataset = dataset_manager.tokenize_and_process(
+            dataset_manager.get_raw_dataset()
+        )
+        data_loader = DataLoader(
+            calib_dataset,
+            batch_size=1,
+            collate_fn=DefaultDataCollator(),
+            sampler=torch.utils.data.RandomSampler(calib_dataset),
+        )
+
+        return data_loader
+
+    @torch.no_grad()
+    def test_perplexity(self):
+        tokenizer = SparseAutoTokenizer.from_pretrained(self.model_stub)
+        data_args = DataTrainingArguments(
+            dataset="ultrachat-200k",
+            max_seq_length=self.max_seq_length,
+        )
+        dataloader = self._get_dataloader(data_args, tokenizer)
+
+        total_ppl = 0.0
+        total_non_nan = 0
+        for idx, sample in enumerate(dataloader):
+            if idx >= self.num_eval:
+                break
+            output = self.model(**tensors_to_device(sample, "cuda:0"))
+            if torch.isnan(output.loss):
+                continue
+            total_ppl += torch.exp(output.loss).item()
+            total_non_nan += 1
+
+        avg_ppl = total_ppl / total_non_nan
+        assert avg_ppl <= self.ppl_threshold

From 3aa99d1c9ec51e70b9aee077d8a5e20f46d71be1 Mon Sep 17 00:00:00 2001
From: Sara Adkins
Date: Tue, 28 May 2024 19:35:20 +0000
Subject: [PATCH 3/6] update tests

---
 .../transformers/compression/test_fp8.py | 24 ++++++++-------
 1 file changed, 16 insertions(+), 8 deletions(-)

diff --git a/tests/sparseml/transformers/compression/test_fp8.py b/tests/sparseml/transformers/compression/test_fp8.py
index 74403bc40bf..8125afbabfe 100644
--- a/tests/sparseml/transformers/compression/test_fp8.py
+++ b/tests/sparseml/transformers/compression/test_fp8.py
@@ -49,6 +49,7 @@ class TestQuantizationMatches(unittest.TestCase):
     dataset = "ultrachat-200k"
     output = "tiny_llama_out"
     max_seq_length = 512
+    weight_dtype = torch.bfloat16
     num_eval = 64
 
     @classmethod
@@ -56,7 +57,7 @@ def setUpClass(cls):
         cls.test_dir = tempfile.mkdtemp()
 
         cls.model = SparseAutoModelForCausalLM.from_pretrained(
-            cls.model_stub, device_map="cuda:0"
+            cls.model_stub, torch_dtype=cls.weight_dtype, device_map="cuda:0"
         )
         cls._run_oneshot(
             cls.model,
@@ -99,6 +100,7 @@ def _get_quant_info(self, model):
                     quant_info_weights[name] = (
                         module.weight_scale,
                         module.weight_zero_point,
+                        module.weight,
                     )
 
                 if module.quantization_scheme.input_activations is not None:
@@ -113,24 +115,30 @@ def _get_quant_info(self, model):
 
     def test_quantization_reload(self):
        model_reloaded = SparseAutoModelForCausalLM.from_pretrained(
-            os.path.join(self.test_dir, self.output), device_map="cuda:0"
+            os.path.join(self.test_dir, self.output),
+            torch_dtype="auto",
+            device_map="cuda:0",
         )
 
         og_weights, og_inputs = self._get_quant_info(self.model)
         reloaded_weights, reloaded_inputs = self._get_quant_info(model_reloaded)
 
-        for name, (o_scale, o_zp) in og_weights.items():
-            n_scale, n_zp = reloaded_weights[name]
-            assert o_scale.dtype == n_scale.dtype
+        for name, (o_scale, o_zp, o_weight) in og_weights.items():
+            n_scale, n_zp, n_weight = reloaded_weights[name]
+            assert o_scale.dtype == n_scale.dtype == self.weight_dtype
             assert torch.equal(o_scale, n_scale)
-            assert o_zp.dtype == n_zp.dtype
+            assert o_zp.dtype == n_zp.dtype == self.weight_dtype
             assert torch.equal(o_zp, n_zp)
 
+            # we don't expect an exact match here because o_weight still has the
+            # original weight and n_weight has been fake_quantized
+            assert n_weight.dtype == o_weight.dtype == self.weight_dtype
+
         for name, (o_scale, o_zp) in og_inputs.items():
             n_scale, n_zp = reloaded_inputs[name]
-            assert o_scale.dtype == n_scale.dtype
+            assert o_scale.dtype == n_scale.dtype == self.weight_dtype
             assert torch.equal(o_scale, n_scale)
-            assert o_zp.dtype == n_zp.dtype
+            assert o_zp.dtype == n_zp.dtype == self.weight_dtype
             assert torch.equal(o_zp, n_zp)
 
     def _get_dataloader(self, data_args, tokenizer):

From bb625d4a490bd93117e190688ca1920c96492195 Mon Sep 17 00:00:00 2001
From: Sara Adkins
Date: Wed, 29 May 2024 21:22:03 +0000
Subject: [PATCH 4/6] update tests

---
 tests/sparseml/transformers/compression/test_fp8.py | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/tests/sparseml/transformers/compression/test_fp8.py b/tests/sparseml/transformers/compression/test_fp8.py
index 8125afbabfe..aa98b2e39f5 100644
--- a/tests/sparseml/transformers/compression/test_fp8.py
+++ b/tests/sparseml/transformers/compression/test_fp8.py
@@ -49,7 +49,7 @@ class TestQuantizationMatches(unittest.TestCase):
     dataset = "ultrachat-200k"
     output = "tiny_llama_out"
     max_seq_length = 512
-    weight_dtype = torch.bfloat16
+    weight_dtype = torch.float16
     num_eval = 64
 
     @classmethod
@@ -127,7 +127,7 @@ def test_quantization_reload(self):
             n_scale, n_zp, n_weight = reloaded_weights[name]
             assert o_scale.dtype == n_scale.dtype == self.weight_dtype
             assert torch.equal(o_scale, n_scale)
-            assert o_zp.dtype == n_zp.dtype == self.weight_dtype
+            assert o_zp.dtype == n_zp.dtype == torch.float8_e4m3fn
             assert torch.equal(o_zp, n_zp)
 
             # we don't expect an exact match here because o_weight still has the
@@ -138,7 +138,7 @@ def test_quantization_reload(self):
             n_scale, n_zp = reloaded_inputs[name]
             assert o_scale.dtype == n_scale.dtype == self.weight_dtype
             assert torch.equal(o_scale, n_scale)
-            assert o_zp.dtype == n_zp.dtype == self.weight_dtype
+            assert o_zp.dtype == n_zp.dtype == torch.float8_e4m3fn
             assert torch.equal(o_zp, n_zp)
 
     def _get_dataloader(self, data_args, tokenizer):

From 793d50db82bb79ff354b22bacf8f2324715b5a54 Mon Sep 17 00:00:00 2001
From: Sara Adkins
Date: Fri, 7 Jun 2024 17:51:44 +0000
Subject: [PATCH 5/6] fp8 example

---
 examples/llama7b_fp8_quantization.py | 58 ++++++++++++++++++++++++++++
 1 file changed, 58 insertions(+)
 create mode 100644 examples/llama7b_fp8_quantization.py

diff --git a/examples/llama7b_fp8_quantization.py b/examples/llama7b_fp8_quantization.py
new file mode 100644
index 00000000000..abc87aefe1b
--- /dev/null
+++ b/examples/llama7b_fp8_quantization.py
@@ -0,0 +1,58 @@
+import torch
+
+from sparseml.transformers import SparseAutoModelForCausalLM, oneshot
+
+
+# define a sparseml recipe for GPTQ FP8 quantization
+recipe = """
+quant_stage:
+    quant_modifiers:
+        GPTQModifier:
+            sequential_update: false
+            ignore: ["lm_head"]
+            config_groups:
+                group_0:
+                    weights:
+                        num_bits: 8
+                        type: "float"
+                        symmetric: true
+                        strategy: "tensor"
+                    input_activations:
+                        num_bits: 8
+                        type: "float"
+                        symmetric: true
+                        strategy: "tensor"
+                    targets: ["Linear"]
+"""
+
+# setting device_map to auto to spread the model evenly across all available GPUs
+# load the model in as bfloat16 to save on memory and compute
+model_stub = "zoo:llama2-7b-ultrachat200k_llama2_pretrain-base"
+model = SparseAutoModelForCausalLM.from_pretrained(
+    model_stub, torch_dtype=torch.bfloat16, device_map="auto"
+)
+
+# uses SparseML's built-in preprocessing for ultra chat
+dataset = "ultrachat-200k"
+
+# save location of quantized model out
+output_dir = "./output_llama7b_fp8_compressed"
+
+# set dataset config parameters
+splits = {"calibration": "train_gen[:5%]"}
+max_seq_length = 512
+pad_to_max_length = False
+num_calibration_samples = 512
+
+# apply recipe to the model and save quantized output in fp8 format
+oneshot(
+    model=model,
+    dataset=dataset,
+    recipe=recipe,
+    output_dir=output_dir,
+    splits=splits,
+    max_seq_length=max_seq_length,
+    pad_to_max_length=pad_to_max_length,
+    num_calibration_samples=num_calibration_samples,
+    save_compressed=True,
+)

From 46dc4181a0a515b2ccc5e045d0edaeb12fa11a66 Mon Sep 17 00:00:00 2001
From: Sara Adkins
Date: Fri, 14 Jun 2024 21:36:46 +0000
Subject: [PATCH 6/6] update examples

---
 examples/llama7b_fp8_quantization.py   | 65 +++++++------------
 examples/llama7b_w8a8_quantization.py  |  9 ++-
 .../modifiers/quantization/__init__.py |  2 +
 3 files changed, 29 insertions(+), 47 deletions(-)

diff --git a/examples/llama7b_fp8_quantization.py b/examples/llama7b_fp8_quantization.py
index abc87aefe1b..e8979737cbf 100644
--- a/examples/llama7b_fp8_quantization.py
+++ b/examples/llama7b_fp8_quantization.py
@@ -1,58 +1,39 @@
 import torch
+from datasets import load_dataset
+from transformers import AutoTokenizer
 
+from sparseml.modifiers import GPTQModifier
 from sparseml.transformers import SparseAutoModelForCausalLM, oneshot
 
 
-# define a sparseml recipe for GPTQ FP8 quantization
-recipe = """
-quant_stage:
-    quant_modifiers:
-        GPTQModifier:
-            sequential_update: false
-            ignore: ["lm_head"]
-            config_groups:
-                group_0:
-                    weights:
-                        num_bits: 8
-                        type: "float"
-                        symmetric: true
-                        strategy: "tensor"
-                    input_activations:
-                        num_bits: 8
-                        type: "float"
-                        symmetric: true
-                        strategy: "tensor"
-                    targets: ["Linear"]
-"""
-
-# setting device_map to auto to spread the model evenly across all available GPUs
-# load the model in as bfloat16 to save on memory and compute
-model_stub = "zoo:llama2-7b-ultrachat200k_llama2_pretrain-base"
-model = SparseAutoModelForCausalLM.from_pretrained(
-    model_stub, torch_dtype=torch.bfloat16, device_map="auto"
-)
+model_stub = "meta-llama/Meta-Llama-3-8B-Instruct"
+output_dir = "Meta-Llama-3-8B-Instruct-FP8-Compressed"
+num_calibration_samples = 512
 
-# uses SparseML's built-in preprocessing for ultra chat
-dataset = "ultrachat-200k"
+tokenizer = AutoTokenizer.from_pretrained(model_stub, use_fast=True)
+tokenizer.pad_token = tokenizer.eos_token
 
-# save location of quantized model out
-output_dir = "./output_llama7b_fp8_compressed"
 
-# set dataset config parameters
-splits = {"calibration": "train_gen[:5%]"}
-max_seq_length = 512
-pad_to_max_length = False
-num_calibration_samples = 512
+def preprocess(batch):
+    text = tokenizer.apply_chat_template(batch["messages"], tokenize=False)
+    tokenized = tokenizer(text, padding=True, truncation=True, max_length=2048)
+    return tokenized
+
+
+ds = load_dataset("mgoin/ultrachat_2k", split="train_sft")
+examples = ds.map(preprocess, remove_columns=ds.column_names)
+
+recipe = GPTQModifier(targets=["Linear"], scheme="FP8", ignore=["lm_head"])
+
+model = SparseAutoModelForCausalLM.from_pretrained(
+    model_stub, torch_dtype=torch.bfloat16, device_map="auto"
+)
 
-# apply recipe to the model and save quantized output in fp8 format
 oneshot(
     model=model,
-    dataset=dataset,
+    dataset=examples,
     recipe=recipe,
     output_dir=output_dir,
-    splits=splits,
-    max_seq_length=max_seq_length,
-    pad_to_max_length=pad_to_max_length,
     num_calibration_samples=num_calibration_samples,
     save_compressed=True,
 )
diff --git a/examples/llama7b_w8a8_quantization.py b/examples/llama7b_w8a8_quantization.py
index c894613ffbb..6ffd4b0f623 100644
--- a/examples/llama7b_w8a8_quantization.py
+++ b/examples/llama7b_w8a8_quantization.py
@@ -16,19 +16,18 @@
                         num_bits: 8
                         type: "int"
                         symmetric: true
-                        strategy: "channel"
+                        strategy: "tensor"
                     input_activations:
                         num_bits: 8
                         type: "int"
                        symmetric: true
-                        dynamic: True
-                        strategy: "token"
+                        strategy: "tensor"
                     targets: ["Linear"]
 """
 
 # setting device_map to auto to spread the model evenly across all available GPUs
 # load the model in as bfloat16 to save on memory and compute
-model_stub = "zoo:llama2-7b-ultrachat200k_llama2_pretrain-base"
+model_stub = "TinyLlama/TinyLlama-1.1B-intermediate-step-1431k-3T"
 model = SparseAutoModelForCausalLM.from_pretrained(
     model_stub, torch_dtype=torch.bfloat16, device_map="auto"
 )
@@ -37,7 +36,7 @@
 dataset = "ultrachat-200k"
 
 # save location of quantized model out
-output_dir = "./output_llama7b_w8a8_channel_dynamic_compressed"
+output_dir = "./TEST_MAIN_BRANCH_TENSOR"
 
 # set dataset config parameters
 splits = {"calibration": "train_gen[:5%]"}
diff --git a/src/sparseml/modifiers/quantization/__init__.py b/src/sparseml/modifiers/quantization/__init__.py
index ebdf28a6d5b..fe2676473d5 100644
--- a/src/sparseml/modifiers/quantization/__init__.py
+++ b/src/sparseml/modifiers/quantization/__init__.py
@@ -13,3 +13,5 @@
 # limitations under the License.
 
 # flake8: noqa
+
+from .gptq import *
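
A note on the format inference added in PATCH 1/6: the mapping from the weight types collected by _get_quant_types() to an on-disk CompressionFormat reads as a small decision table. The sketch below restates that branch as a standalone Python function for illustration only; format_for is a hypothetical helper, not part of the patch, and it assumes nothing beyond the CompressionFormat members referenced in the diff above.

from compressed_tensors import CompressionFormat


def format_for(quant_types):
    # Restates the branch added to infer_quantization_format() for the
    # save_compressed=True path: uniformly int4 weights are packed,
    # uniformly float8 weights use the float format, and everything else
    # (int8, or any mixed-precision combination) falls back to plain int
    # quantization.
    if quant_types == ["int4"]:
        return CompressionFormat.pack_quantized
    elif quant_types == ["float8"]:
        return CompressionFormat.float_quantized
    return CompressionFormat.int_quantized


assert format_for(["float8"]) == CompressionFormat.float_quantized
assert format_for(["int4"]) == CompressionFormat.pack_quantized
assert format_for(["int8"]) == CompressionFormat.int_quantized
# a mixed int4/float8 model qualifies for neither specialized format
assert format_for(["int4", "float8"]) == CompressionFormat.int_quantized

Because _get_quant_types() deduplicates weight schemes into a list, the specialized formats are only chosen when every quantized module in the model shares a single weight scheme.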