diff --git a/src/sparseml/modifiers/quantization/gptq/pytorch.py b/src/sparseml/modifiers/quantization/gptq/pytorch.py
index 2eb14e8d2d7..e9e3f715625 100644
--- a/src/sparseml/modifiers/quantization/gptq/pytorch.py
+++ b/src/sparseml/modifiers/quantization/gptq/pytorch.py
@@ -176,7 +176,6 @@ def _pruning_arguments(self):
         """
         Gather the parameters needed for root module compression in a dict
 
-        :param sparsity: target sparsity
         :return: dict of params for pruning
         """
         return {
diff --git a/tests/sparseml/transformers/compression/configs/channelwise_1.1b.yaml b/tests/sparseml/transformers/compression/configs/channelwise_1.1b.yaml
new file mode 100644
index 00000000000..05d7ffea467
--- /dev/null
+++ b/tests/sparseml/transformers/compression/configs/channelwise_1.1b.yaml
@@ -0,0 +1,5 @@
+cadence: "nightly"
+test_type: "regression"
+model_stub: "TinyLlama/TinyLlama-1.1B-intermediate-step-1431k-3T"
+old_recipe: "tests/sparseml/transformers/compression/recipes/old_quant_channel.yaml"
+new_recipe: "tests/sparseml/transformers/compression/recipes/new_quant_channel.yaml"
\ No newline at end of file
diff --git a/tests/sparseml/transformers/compression/configs/channelwise_15m.yaml b/tests/sparseml/transformers/compression/configs/channelwise_15m.yaml
new file mode 100644
index 00000000000..77548c5a155
--- /dev/null
+++ b/tests/sparseml/transformers/compression/configs/channelwise_15m.yaml
@@ -0,0 +1,5 @@
+cadence: "commit"
+test_type: "regression"
+model_stub: "Xenova/llama2.c-stories15M"
+old_recipe: "tests/sparseml/transformers/compression/recipes/old_quant_channel.yaml"
+new_recipe: "tests/sparseml/transformers/compression/recipes/new_quant_channel.yaml"
\ No newline at end of file
diff --git a/tests/sparseml/transformers/compression/configs/inputs_1.1b.yaml b/tests/sparseml/transformers/compression/configs/inputs_1.1b.yaml
new file mode 100644
index 00000000000..305be80af89
--- /dev/null
+++ b/tests/sparseml/transformers/compression/configs/inputs_1.1b.yaml
@@ -0,0 +1,5 @@
+cadence: "nightly"
+test_type: "regression"
+model_stub: "TinyLlama/TinyLlama-1.1B-intermediate-step-1431k-3T"
+old_recipe: "tests/sparseml/transformers/compression/recipes/old_quant_full.yaml"
+new_recipe: "tests/sparseml/transformers/compression/recipes/new_quant_full.yaml"
\ No newline at end of file
diff --git a/tests/sparseml/transformers/compression/configs/inputs_15m.yaml b/tests/sparseml/transformers/compression/configs/inputs_15m.yaml
new file mode 100644
index 00000000000..237750b40bd
--- /dev/null
+++ b/tests/sparseml/transformers/compression/configs/inputs_15m.yaml
@@ -0,0 +1,5 @@
+cadence: "commit"
+test_type: "regression"
+model_stub: "Xenova/llama2.c-stories15M"
+old_recipe: "tests/sparseml/transformers/compression/recipes/old_quant_full.yaml"
+new_recipe: "tests/sparseml/transformers/compression/recipes/new_quant_full.yaml"
\ No newline at end of file
diff --git a/tests/sparseml/transformers/compression/configs/weights_only_1.1b.yaml b/tests/sparseml/transformers/compression/configs/weights_only_1.1b.yaml
new file mode 100644
index 00000000000..0dad96251dd
--- /dev/null
+++ b/tests/sparseml/transformers/compression/configs/weights_only_1.1b.yaml
@@ -0,0 +1,5 @@
+cadence: "nightly"
+test_type: "regression"
+model_stub: "TinyLlama/TinyLlama-1.1B-intermediate-step-1431k-3T"
+old_recipe: "tests/sparseml/transformers/compression/recipes/old_quant_weight.yaml"
+new_recipe: "tests/sparseml/transformers/compression/recipes/new_quant_weight.yaml"
\ No newline at end of file
diff --git a/tests/sparseml/transformers/compression/configs/weights_only_15m.yaml b/tests/sparseml/transformers/compression/configs/weights_only_15m.yaml
new file mode 100644
index 00000000000..b5d2b1cd87f
--- /dev/null
+++ b/tests/sparseml/transformers/compression/configs/weights_only_15m.yaml
@@ -0,0 +1,5 @@
+cadence: "commit"
+test_type: "regression"
+model_stub: "Xenova/llama2.c-stories15M"
+old_recipe: "tests/sparseml/transformers/compression/recipes/old_quant_weight.yaml"
+new_recipe: "tests/sparseml/transformers/compression/recipes/new_quant_weight.yaml"
\ No newline at end of file
diff --git a/tests/sparseml/transformers/compression/recipes/new_quant_channel.yaml b/tests/sparseml/transformers/compression/recipes/new_quant_channel.yaml
new file mode 100644
index 00000000000..48df197537c
--- /dev/null
+++ b/tests/sparseml/transformers/compression/recipes/new_quant_channel.yaml
@@ -0,0 +1,18 @@
+test_stage:
+  quant_modifiers:
+    vLLMQuantizationModifier:
+      ignore: ["lm_head", "model.layers.0.mlp.down_proj"]
+      config_groups:
+        group_0:
+          weights:
+            num_bits: 4
+            type: "int"
+            symmetric: False
+            strategy: "channel"
+          input_activations: null
+          output_activations: null
+          targets: ["Linear"]
+    GPTQModifier:
+      block_size: 128
+      sequential_update: False
+      targets: ["model.layers.0", "model.layers.1", "model.layers.2"]
\ No newline at end of file
diff --git a/tests/sparseml/transformers/compression/recipes/new_quant_full.yaml b/tests/sparseml/transformers/compression/recipes/new_quant_full.yaml
index 409a168ecfd..924dcd6e3f6 100644
--- a/tests/sparseml/transformers/compression/recipes/new_quant_full.yaml
+++ b/tests/sparseml/transformers/compression/recipes/new_quant_full.yaml
@@ -28,5 +28,4 @@ test_stage:
     GPTQModifier:
       block_size: 128
       sequential_update: False
-      percdamp: 0.01
       targets: ["re:model.layers.\\d+$"]
\ No newline at end of file
diff --git a/tests/sparseml/transformers/compression/recipes/new_quant_weight.yaml b/tests/sparseml/transformers/compression/recipes/new_quant_weight.yaml
index 68bf42e1bc5..19b9d196e6a 100644
--- a/tests/sparseml/transformers/compression/recipes/new_quant_weight.yaml
+++ b/tests/sparseml/transformers/compression/recipes/new_quant_weight.yaml
@@ -15,5 +15,4 @@ test_stage:
     GPTQModifier:
       block_size: 128
       sequential_update: False
-      percdamp: 0.01
       targets: ["re:model.layers.\\d+$"]
\ No newline at end of file
diff --git a/tests/sparseml/transformers/compression/recipes/old_quant_channel.yaml b/tests/sparseml/transformers/compression/recipes/old_quant_channel.yaml
new file mode 100644
index 00000000000..350d07ce1c2
--- /dev/null
+++ b/tests/sparseml/transformers/compression/recipes/old_quant_channel.yaml
@@ -0,0 +1,28 @@
+test_stage:
+  quant_modifiers:
+    QuantizationModifier:
+      ignore:
+        - model.layers.0.mlp.down_proj
+        - lm_head
+        - LlamaRotaryEmbedding
+        - LlamaRMSNorm
+        - SiLU
+        - MatMulLeftInput_QK
+        - MatMulRightInput_QK
+        - MatMulOutput_QK
+        - MatMulLeftInput_PV
+        - MatMulRightInput_PV
+        - MatMulOutput_PV
+        - Embedding
+      scheme_overrides:
+        Linear:
+          weights:
+            num_bits: 4
+            symmetric: false
+            strategy: "channel"
+          input_activations: null
+          output_activations: null
+    GPTQModifier:
+      block_size: 128
+      sequential_update: False
+      targets: ["model.layers.0", "model.layers.1", "model.layers.2"]
\ No newline at end of file
diff --git a/tests/sparseml/transformers/compression/recipes/old_quant_full.yaml b/tests/sparseml/transformers/compression/recipes/old_quant_full.yaml
index 95edd24628e..9d67e334fef 100644
--- a/tests/sparseml/transformers/compression/recipes/old_quant_full.yaml
+++ b/tests/sparseml/transformers/compression/recipes/old_quant_full.yaml
@@ -34,5 +34,4 @@ test_stage:
     GPTQModifier:
       block_size: 128
       sequential_update: False
-      percdamp: 0.01
       targets: ["re:model.layers.\\d+$"]
\ No newline at end of file
diff --git a/tests/sparseml/transformers/compression/recipes/old_quant_weight.yaml b/tests/sparseml/transformers/compression/recipes/old_quant_weight.yaml
index 375dcfceb6c..78e49595fe2 100644
--- a/tests/sparseml/transformers/compression/recipes/old_quant_weight.yaml
+++ b/tests/sparseml/transformers/compression/recipes/old_quant_weight.yaml
@@ -31,5 +31,4 @@ test_stage:
     GPTQModifier:
       block_size: 128
       sequential_update: False
-      percdamp: 0.01
       targets: ["re:model.layers.\\d+$"]
\ No newline at end of file
diff --git a/tests/sparseml/transformers/compression/test_quantization.py b/tests/sparseml/transformers/compression/test_quantization.py
index 03d86954609..4b250018d6f 100644
--- a/tests/sparseml/transformers/compression/test_quantization.py
+++ b/tests/sparseml/transformers/compression/test_quantization.py
@@ -12,16 +12,17 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-import math
 import os
 import shutil
 import tempfile
 import unittest
 
+import pytest
 import torch
 from torch.utils.data import DataLoader
 from transformers import DefaultDataCollator
 
+from compressed_tensors.quantization import fake_quantize
 from compressed_tensors.quantization.utils import is_module_quantized
 from parameterized import parameterized_class
 from sparseml.pytorch.utils import tensors_to_device
@@ -32,30 +33,31 @@
 )
 from sparseml.transformers.finetune.data import TextGenerationDataset
 from sparseml.transformers.finetune.data.data_args import DataTrainingArguments
-from tests.testing_utils import requires_gpu, requires_torch
+from tests.testing_utils import parse_params, requires_gpu, requires_torch
+
+
+CONFIGS_DIRECTORY = "tests/sparseml/transformers/compression/configs"
 
 
 @requires_torch
 @requires_gpu
-@parameterized_class(
-    ("old_recipe", "new_recipe"),
-    [
-        (
-            "tests/sparseml/transformers/compression/recipes/old_quant_full.yaml",
-            "tests/sparseml/transformers/compression/recipes/new_quant_full.yaml",
-        ),
-        (
-            "tests/sparseml/transformers/compression/recipes/old_quant_weight.yaml",
-            "tests/sparseml/transformers/compression/recipes/new_quant_weight.yaml",
-        ),
-    ],
-)
+@pytest.mark.integration
+@parameterized_class(parse_params(CONFIGS_DIRECTORY))
 class TestQuantizationMatches(unittest.TestCase):
+    """
+    Tests that the new compressed-tensors quantization format matches the performance
+    of the old sparseml format. For setup, this class runs a full oneshot with both an
+    old and a new quantization recipe that should be equivalent, then tests:
+    - quantization structure matches after oneshot
+    - quantized weights match
+    - decompressing the new model yields the expected weights on reload
+    - no perplexity regression from the old quantization framework; asserts the new
+      model's perplexity is no more than 2% worse
+    """
+
     old_recipe = None
     new_recipe = None
-    # TODO: use "TinyLlama/TinyLlama-1.1B-intermediate-step-1431k-3T" for nightly
-    # or weekly runs, but this smaller model is better for commit testing
-    model_stub = "Xenova/llama2.c-stories15M"
+    model_stub = None
     dataset = "open_platypus"
     old_output = "tiny_llama_old"
     new_output = "tiny_llama_new"
@@ -86,16 +88,9 @@ def setUpClass(cls):
             os.path.join(cls.test_dir, cls.new_output),
         )
 
-    @classmethod
-    def tearDownClass(cls):
-        shutil.rmtree(cls.test_dir)
-        del cls.model_new
-        del cls.model_old
-        torch.cuda.empty_cache()
-
     @staticmethod
     def _run_oneshot(model, recipe, dataset, output_dir):
-        num_calibration_samples = 512
+        num_calibration_samples = 256
         max_seq_length = 512
         pad_to_max_length = False
 
@@ -116,12 +111,13 @@ def _get_quant_info_old(self, model):
         quant_info_inputs = {}
         for name, module in model.named_modules():
             if hasattr(module, "weight_fake_quant"):
-                scale = module.weight_fake_quant.scale.item()
-                zp = module.weight_fake_quant.zero_point.item()
-                quant_info_weights[name] = (scale, zp)
+                scale = module.weight_fake_quant.scale
+                zp = module.weight_fake_quant.zero_point
+                weight = module.weight_fake_quant(module.weight)
+                quant_info_weights[name] = (scale, zp, weight)
             elif hasattr(module, "quant"):
-                scale = module.quant.activation_post_process.scale.item()
-                zp = module.quant.activation_post_process.zero_point.item()
+                scale = module.quant.activation_post_process.scale
+                zp = module.quant.activation_post_process.zero_point
                 quant_info_inputs[name] = (scale, zp)
 
         return quant_info_weights, quant_info_inputs
@@ -133,17 +129,42 @@ def _get_quant_info_new(self, model):
             if is_module_quantized(module):
                 if module.quantization_scheme.weights is not None:
                     quant_info_weights[name] = (
-                        module.weight_scale.item(),
-                        module.weight_zero_point.item(),
+                        module.weight_scale,
+                        module.weight_zero_point,
+                        fake_quantize(
+                            module.weight,
+                            module.weight_scale,
+                            module.weight_zero_point,
+                            module.quantization_scheme.weights,
+                        ),
                     )
                 if module.quantization_scheme.input_activations is not None:
                     quant_info_inputs[name] = (
-                        module.input_scale.item(),
-                        module.input_zero_point.item(),
+                        module.input_scale,
+                        module.input_zero_point,
                     )
 
         return quant_info_weights, quant_info_inputs
 
+    def _get_dataloader(self, data_args, tokenizer):
+        dataset_manager = TextGenerationDataset.load_from_registry(
+            data_args.dataset,
+            data_args=data_args,
+            split="train",
+            tokenizer=tokenizer,
+        )
+        calib_dataset = dataset_manager.tokenize_and_process(
+            dataset_manager.get_raw_dataset()
+        )
+        data_loader = DataLoader(
+            calib_dataset,
+            batch_size=1,
+            collate_fn=DefaultDataCollator(),
+            sampler=torch.utils.data.RandomSampler(calib_dataset),
+        )
+
+        return data_loader
+
     def test_quantization_counts(self):
         old_quant_weights, old_quant_inputs = self._get_quant_info_old(self.model_old)
         new_quant_weights, new_quant_inputs = self._get_quant_info_new(self.model_new)
@@ -151,16 +172,24 @@ def test_quantization_counts(self):
         assert len(old_quant_weights) == len(new_quant_weights)
         assert len(old_quant_inputs) == len(new_quant_inputs)
 
-    def test_quantization_scale_and_zp(self):
-        old_quant_weights, old_quant_inputs = self._get_quant_info_old(self.model_old)
-        new_quant_weights, new_quant_inputs = self._get_quant_info_new(self.model_new)
+    def test_quantization_matches(self):
+        old_quant_weights, _ = self._get_quant_info_old(self.model_old)
+        new_quant_weights, _ = self._get_quant_info_new(self.model_new)
 
-        for name, (o_scale, o_zp) in old_quant_weights.items():
+        for name, (o_scale, o_zp, _) in old_quant_weights.items():
             if name.endswith(".module"):
                 name = name[:-7]
-            n_scale, n_zp = new_quant_weights[name]
-            assert math.isclose(o_scale, n_scale, abs_tol=1e-3, rel_tol=1e-3)
-            assert o_zp == n_zp
+            n_scale, n_zp, _ = new_quant_weights[name]
+            if n_scale.ndim == 2:  # channelwise
+                n_scale = n_scale[:, 0]
+                n_zp = n_zp[:, 0]
+            elif n_scale.ndim == 0:  # tensor
+                n_scale = torch.unsqueeze(n_scale, 0)
+                n_zp = torch.unsqueeze(n_zp, 0)
+
+            assert torch.all(
+                torch.isclose(o_scale.cpu(), n_scale.cpu(), atol=1e-3, rtol=1e-3)
+            )
 
     def test_quantization_reload(self):
         model_reloaded = SparseAutoModelForCausalLM.from_pretrained(
@@ -170,34 +199,15 @@ def test_quantization_reload(self):
         og_weights, og_inputs = self._get_quant_info_new(self.model_new)
         reloaded_weights, reloaded_inputs = self._get_quant_info_new(model_reloaded)
 
-        for name, (o_scale, o_zp) in og_weights.items():
-            n_scale, n_zp = reloaded_weights[name]
-            assert o_scale == n_scale
-            assert o_zp == n_zp
+        for name, (o_scale, o_zp, _) in og_weights.items():
+            n_scale, n_zp, _ = reloaded_weights[name]
+            assert torch.equal(o_scale.cpu(), n_scale.cpu())
+            assert torch.equal(o_zp.cpu(), n_zp.cpu())
 
         for name, (o_scale, o_zp) in og_inputs.items():
             n_scale, n_zp = reloaded_inputs[name]
-            assert o_scale == n_scale
-            assert o_zp == n_zp
-
-    def _get_dataloader(self, data_args, tokenizer):
-        dataset_manager = TextGenerationDataset.load_from_registry(
-            data_args.dataset,
-            data_args=data_args,
-            split="train",
-            tokenizer=tokenizer,
-        )
-        calib_dataset = dataset_manager.tokenize_and_process(
-            dataset_manager.get_raw_dataset()
-        )
-        data_loader = DataLoader(
-            calib_dataset,
-            batch_size=1,
-            collate_fn=DefaultDataCollator(),
-            sampler=torch.utils.data.RandomSampler(calib_dataset),
-        )
-
-        return data_loader
+            assert torch.equal(o_scale.cpu(), n_scale.cpu())
+            assert torch.equal(o_zp.cpu(), n_zp.cpu())
 
     @torch.no_grad()
     def test_perplexity(self):
@@ -228,3 +238,10 @@ def test_perplexity(self):
             total_ppl_old / total_non_nan
         )
         assert avg_ppl_ratio <= 1.02
+
+    @classmethod
+    def tearDownClass(cls):
+        shutil.rmtree(cls.test_dir)
+        del cls.model_new
+        del cls.model_old
+        torch.cuda.empty_cache()
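Note on the new test parameterization: each YAML under `tests/sparseml/transformers/compression/configs` supplies the class attributes (`model_stub`, `old_recipe`, `new_recipe`, plus `cadence` and `test_type` for run selection) that `parameterized_class(parse_params(CONFIGS_DIRECTORY))` injects into `TestQuantizationMatches`. The real `parse_params` helper lives in `tests/testing_utils.py` and is not part of this diff, so the loader below is only a sketch of the behavior it needs to provide, not the actual implementation.

```python
# Hypothetical sketch of a parse_params-style loader; the real helper in
# tests/testing_utils.py is not shown in this diff and may differ.
import os
from typing import Dict, List

import yaml


def load_test_configs(configs_directory: str) -> List[Dict]:
    """Read every YAML config in the directory into a dict of class attributes."""
    configs = []
    for file_name in sorted(os.listdir(configs_directory)):
        if not file_name.endswith(".yaml"):
            continue
        with open(os.path.join(configs_directory, file_name)) as yaml_file:
            # each config carries cadence, test_type, model_stub, old_recipe, new_recipe
            configs.append(yaml.safe_load(yaml_file))
    return configs


# parameterized_class accepts a list of dicts whose keys become class attributes,
# so this would generate one TestQuantizationMatches variant per config file.
```

With this layout, covering a new model or recipe pairing only requires dropping another config file into the directory rather than editing the test module.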
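The shape handling in `test_quantization_matches` is the least obvious part of the change: the old fake-quant modules expose per-channel scales as 1-D tensors (one value per output channel), while the compressed-tensors modules store channelwise scales as 2-D `(num_channels, 1)` tensors and per-tensor scales as 0-D scalars. A small standalone illustration of that normalization follows; the helper name is illustrative only, not part of the diff.

```python
import torch


def normalize_new_scale(n_scale: torch.Tensor) -> torch.Tensor:
    """Collapse compressed-tensors scale shapes to the 1-D layout of the old framework."""
    if n_scale.ndim == 2:  # channelwise: (num_channels, 1) -> (num_channels,)
        return n_scale[:, 0]
    if n_scale.ndim == 0:  # per-tensor scalar -> length-1 vector
        return torch.unsqueeze(n_scale, 0)
    return n_scale  # already 1-D


# old framework: per-channel scale as a flat tensor
old_scale = torch.tensor([0.1, 0.2])
# new framework: channelwise scale stored as (num_channels, 1)
new_scale = torch.tensor([[0.1001], [0.2001]])
assert torch.all(
    torch.isclose(old_scale, normalize_new_scale(new_scale), atol=1e-3, rtol=1e-3)
)
```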