diff --git a/neural_compressor/torch/algorithms/qat/__init__.py b/neural_compressor/torch/algorithms/qat/__init__.py
new file mode 100644
index 00000000000..e4c8a62c491
--- /dev/null
+++ b/neural_compressor/torch/algorithms/qat/__init__.py
@@ -0,0 +1,16 @@
+# Copyright (c) 2025 Intel Corporation
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# pylint:disable=import-error
+"""QAT (Quantization Aware Training)."""
diff --git a/neural_compressor/torch/algorithms/qat/quant_linear.py b/neural_compressor/torch/algorithms/qat/quant_linear.py
new file mode 100644
index 00000000000..2858c0f9420
--- /dev/null
+++ b/neural_compressor/torch/algorithms/qat/quant_linear.py
@@ -0,0 +1,85 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+#
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT License.
+#
+# Copyright (c) 2025 Intel Corporation
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Quantized Linear.""" + + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from .tensor_quantizer import TensorQuantizer + + +class QuantLinear(nn.Module): + """Quantized version of nn.Linear.""" + + def forward(self, input: torch.Tensor): + """Add weight/input/output of quantization for the original forward method.""" + qw = self.weight_quantizer(self.weight) + qi = self.input_quantizer(input) + out = F.linear(qi, qw, self.bias) + out = self.output_quantizer(out) + return out + + def _setup(self, quant_cfg): + """Init quantizer.""" + self.weight_quantizer = TensorQuantizer( + data_type=quant_cfg.data_type, + block_size=quant_cfg.group_size, + bits=quant_cfg.bits, + sym=quant_cfg.sym, + if_quant=True, + learn_exponent=False, + ) + self.input_quantizer = TensorQuantizer( + data_type=quant_cfg.act_data_type, + block_size=quant_cfg.act_group_size, + bits=quant_cfg.act_bits, + sym=quant_cfg.act_sym, + if_quant=True, + learn_exponent=False, + ) + self.output_quantizer = TensorQuantizer( + data_type=quant_cfg.act_data_type, + block_size=quant_cfg.act_group_size, + bits=quant_cfg.act_bits, + sym=quant_cfg.act_sym, + if_quant=False, + ) + # Currently don't quant output + self.output_quantizer.disable() + + # TODO: remove + self.original_weight_dtype = None if self.weight is None else self.weight.dtype + + def extra_repr(self) -> str: + """Generate extra_repr making sure import keys exist in self.__dict__.""" + return f"in_features={self.in_features}, out_features={self.out_features}, bias={self.bias is not None}" + + def __repr__(self): + """Overriding the __repr__ method, makes the output more concise and meaningful.""" + return ( + f"QuantLinear(\n" + f" in_features={self.in_features}, out_features={self.out_features}, bias={self.bias is not None}\n" + f" (input_quantizer): {self.input_quantizer}\n" + f" (output_quantizer): {self.output_quantizer}\n" + f" (weight_quantizer): {self.weight_quantizer}\n" + f")" + ) diff --git a/neural_compressor/torch/algorithms/qat/quant_utils.py b/neural_compressor/torch/algorithms/qat/quant_utils.py new file mode 100644 index 00000000000..bf99a36d5b3 --- /dev/null +++ b/neural_compressor/torch/algorithms/qat/quant_utils.py @@ -0,0 +1,153 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. +# +# Copyright (c) 2025 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+"""Utils for quantization.""" + +import types +from typing import Any + +import torch +import torch.nn as nn + +from .quant_linear import QuantLinear +from .tensor_quantizer import TensorQuantizer + + +def convert(module: nn.Module, quant_cfg=None, quant_module=None): + """Convert the model to a quantized one with quant config.""" + + # update class + original_cls = type(module) + module.__class__ = quant_module + module.forward = types.MethodType(quant_module.forward, module) + + # setup quantizers + module._setup(quant_cfg) + + return module + + +def replace_with_quant_linear(model, quant_cfg=None): + """Recursively replace the module with quantized module.""" + + # TODO: support more modules, like kv. + for name, child in model.named_children(): + if isinstance(child, nn.Linear): + if "lm_head" in name: + continue + # REPLACE on the parent (model), not on child + quantized = convert(child, quant_cfg, QuantLinear) + setattr(model, name, quantized) + + # now recurse into whichever module is now at `model.name` + replace_with_quant_linear(getattr(model, name), quant_cfg=quant_cfg) + + return model + + +def get_quant_config_with_scheme(scheme: str): + """Get quantization config.""" + + try: + # use scheme definitions from AutoRound since we utilize the quantization functions now + from auto_round.schemes import preset_name_to_scheme + + quant_cfg = preset_name_to_scheme(scheme) + return quant_cfg + except ImportError: + return None + + +def convert_model_with_mapping(model, mapping=None): + """Process mapping to quant config.""" + # key is torch module, TODO: support more key format, like layer name. + for key in mapping: + # TODO: support more torch modules + if key == nn.Linear: + quant_cfg = get_quant_config_with_scheme(mapping[key]) + if quant_cfg is None: + continue + replace_with_quant_linear(model, quant_cfg) + + replaced_modules = sum(isinstance(m, TensorQuantizer) for _, m in model.named_modules()) + print(f"Inserted {replaced_modules} quantizers") + + +def get_quant_config(scheme: str) -> dict[str, Any]: + """Generate quantization config for a torch model. + + Args: + model: The PyTorch model to analyze + + Returns: + Dictionary containing the quantization configuration + """ + + # TODO: support more quant config + try: + from auto_round.export.export_to_llmcompressor.config import initialize_quantization + + quantization_config = initialize_quantization(scheme=scheme) + quantization_config = quantization_config.to_dict() + quantization_config["provider"] = "auto-round" + quantization_config["config_groups"]["group_0"]["weights"]["is_mx"] = True + quantization_config["config_groups"]["group_0"]["input_activations"]["is_mx"] = True + + except ImportError: + quantization_config = None + + return quantization_config + + +def get_quantization_format(module) -> str | None: + """Gets the quantization string. + + Gets the quantization string by iterating through the module and its children. + The first non-None quantization string is returned. 
+ """ + + def _get_quantization_from_layer(layer): + weight_quantizer = getattr(layer, "weight_quantizer", None) + input_quantizer = getattr(layer, "input_quantizer", None) + + if weight_quantizer is None or weight_quantizer._disabled: + return None + + # TODO: support more quant format + if weight_quantizer.num_bits == 8 and weight_quantizer.data_type == "mx_fp8": + return "MXFP8" + + # Raise error for unsupported num_bits + raise NotImplementedError(f"Unsupported quantizer with num_bits: {weight_quantizer.num_bits}") + + quantization = _get_quantization_from_layer(module) + if quantization is not None: + return quantization + + for _, layer in module.named_children(): + format = get_quantization_format(layer) + if format is not None: + return format + + return None + + +def is_quantlinear(module: nn.Module) -> bool: + """Returns whether the module is a quantized linear layer.""" + return "QuantLinear" in type(module).__name__ diff --git a/neural_compressor/torch/algorithms/qat/tensor_quantizer.py b/neural_compressor/torch/algorithms/qat/tensor_quantizer.py new file mode 100644 index 00000000000..e8c0badad28 --- /dev/null +++ b/neural_compressor/torch/algorithms/qat/tensor_quantizer.py @@ -0,0 +1,185 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. +# +# Copyright (c) 2025 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""TensorQuantizer Module.""" + +import torch +from torch import nn + +try: + from auto_round.data_type import get_quant_func +except ImportError: + get_quant_func = None + + +class TensorQuantizer(nn.Module): + """Tensor quantizer module.""" + + def __init__( + self, + data_type="mx_fp8", + bits=8, + block_size=32, + sym=True, + if_quant=True, + learn_exponent=False, + amax=None, + scale_shape=None, + device=None, + ): + """Initialize quantizer and set up required variables.""" + super().__init__() + self.amax = amax + self.data_type = data_type + self.num_bits = bits + self.block_size = block_size + self.sym = sym + self._if_quant = if_quant + self.learn_exponent = learn_exponent + self._dequantize = False + self._input_dtype = None + self._fake_quant = True + + # enable quantizer + self.enable() + + assert ( + get_quant_func is not None + ), "The quantization function is imported from AutoRound, please install it. 'pip install auto-round'" + + # self.data_type will be overided 'mx_fp' -> 'mx_fp8' + self.quant_func, self.data_type = get_quant_func(self.data_type, self.num_bits, self.sym) + + if scale_shape is not None: + # E8M0 scales (exponent) + self.register_buffer( + "scale", + torch.empty(scale_shape[0], scale_shape[1] // self.block_size, dtype=torch.uint8, device=device), + ) + self.save_scale = True + else: + self.save_scale = False + + def forward(self, inputs: torch.Tensor): + """Apply tensor_quant function to inputs. + + Args: + inputs: A Tensor of type float32/float16/bfloat16. 
+ + Returns: + outputs: A Tensor of type output_dtype + """ + + if self._disabled or (not self._if_quant): + self._input_dtype = inputs.dtype + return inputs + + x = inputs + if not x.is_contiguous(): + x = x.contiguous() + + if self.fake_quant: + q = self._fake_quantize(x)[0] + else: + # TODO: add implementation + q = self._real_quantize(x) + + return q.to(inputs.dtype) + + def _fake_quantize(self, inputs: torch.Tensor): + """Fake quantization.""" + + # the shared_exp can be trainable + if self.learn_exponent: + q, shared_exp, _ = self.quant_func( + inputs, + bits=self.num_bits, + group_size=self.block_size, + data_type=self.data_type, + ) + else: + # wrapper no_grad, because the function includes extra trainable variables + with torch.no_grad(): + q, shared_exp, _ = self.quant_func( + inputs, + bits=self.num_bits, + group_size=self.block_size, + data_type=self.data_type, + ) + + # simple STE, since we add no_grad in the quant function + q = q.detach() + (inputs - inputs.detach()) + + if self.save_scale: + # TODO: PACK uint8 + self.scale.data.copy_(shared_exp.detach()) + + return q, shared_exp + + def _real_quantize(self, inputs: torch.Tensor): + raise NotImplementedError("This method hasn't be implemented.") + + @property + def fake_quant(self): + """Return True if fake quantization is used.""" + return self._fake_quant + + def disable(self): + """Bypass the module.""" + self._disabled = True + + def enable(self): + """Enable the module.""" + self._disabled = False + + def weight_pack(self, weight, scale): + """Pack weight and scale when saving.""" + original_shape = weight.shape + + # TODO: support more quantization format + if self.data_type == "mx_fp8": + qweight = (weight.reshape(-1, self.block_size) / torch.exp2(scale.float()).reshape(-1, 1)).to( + torch.float8_e4m3fn + ) + + e8m0_scale = (scale + 127).to(torch.uint8) + return qweight.reshape(original_shape), e8m0_scale.reshape(original_shape[0], -1) + + def __repr__(self): + if self._disabled or not self._if_quant: + return "TensorQuantizer(disabled)" + + qformat_str = f"({self.data_type}) format" + bits_str = f"({self.num_bits}) bit" + + if self.block_size: + bs_str = f"block_size={self.block_size}" + else: + bs_str = "block_size=None" + + # amax + amax_str = f"amax={self.amax}" if self.amax is not None else "amax=?" + # fake / real + mode_str = "fake" if self._fake_quant else "real" + # sym + sym_str = "sym" if self.sym else "asym" + # quant enable + qflag = "quant" if self._if_quant else "no-quant" + + return f"TensorQuantizer({qformat_str} {bits_str} {mode_str} {bs_str}, {amax_str} {qflag})" diff --git a/neural_compressor/torch/export/export_hf.py b/neural_compressor/torch/export/export_hf.py new file mode 100644 index 00000000000..e617ae122a9 --- /dev/null +++ b/neural_compressor/torch/export/export_hf.py @@ -0,0 +1,99 @@ +# Copyright (c) 2025 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+"""Export quantized hf model to compatible formats.""" + +import tempfile +import warnings +from pathlib import Path +from typing import Any + +import torch +import torch.nn as nn + + +def _export_quantized_weight(sub_module: nn.Module, quantization_format: str = None, weight_name: str = "weight"): + """For the given weight attr of the sub_module, export the quantization info of it. + + The export includes converting weight tensor to correct quantized values and quantized dtype, + and registering scaling factors. + """ + if quantization_format is None: + return + + weight: nn.Parameter = getattr(sub_module, weight_name) + weight_quantizer = getattr(sub_module, "weight_quantizer") + + qdq_weight, scale = weight_quantizer._fake_quantize(weight) + + # TODO: support more scale dtype when there are other quantization format except mxfp8/mxfp4 + quantized_weight, e8m0_scale = weight_quantizer.weight_pack(qdq_weight, scale) + + sub_module.register_buffer("weight_scale", e8m0_scale) + + setattr(sub_module, weight_name, nn.Parameter(quantized_weight, requires_grad=False)) + + +def _export_hf_checkpoint(model: nn.Module, scheme: str | None = None) -> tuple[dict[str, Any], dict[str, Any]]: + """Exports the torch model to the packed checkpoint with original HF naming. + + The packed checkpoint will be consumed by the TensorRT-LLM unified converter. + + Args: + model: the torch model. + dtype: the weights data type to export the unquantized layers or the default model data type if None. + + Returns: + post_state_dict: Dict containing quantized weights + quant_config: config information to export hf_quant_cfg.json + """ + # Create a model layer pool + # If `model.model` exists use that, otherwise use `model` itself, e.g., Nemotron-H + root = getattr(model, "model", model) + # If that has a `.layers`, use it, otherwise fall back to the object itself + root = getattr(root, "layers", root) + layer_pool = {f"model.layers.{name}": sub_module for name, sub_module in root.named_modules()} + + from ..algorithms.qat.quant_utils import get_quant_config, get_quantization_format, is_quantlinear + + # compressored config + quant_config = get_quant_config(scheme=scheme) + + for name, sub_module in layer_pool.items(): + quantization_format = get_quantization_format(sub_module) + if quantization_format is not None: + if is_quantlinear(sub_module): + _export_quantized_weight(sub_module, quantization_format) + + quantized_state_dict = model.state_dict() + + return quantized_state_dict, quant_config + + +def export_hf2compressored_model(model: nn.Module, export_dir: Path | str = tempfile.gettempdir(), scheme: str = None): + """Exports the torch model to the packed checkpoint with original HF naming. + + The packed checkpoint will be consumed by the VLLM. 
+ """ + export_dir = Path(export_dir) + export_dir.mkdir(parents=True, exist_ok=True) + + try: + _, quant_config = _export_hf_checkpoint(model, scheme) + model.save_pretrained(export_dir) + model.config.quantization_config = quant_config + model.config.save_pretrained(export_dir) + + except Exception as e: + warnings.warn("Cannot export model and config, the state can be saved with torch.save for further inspection.") + raise e diff --git a/neural_compressor/torch/quantization/config.py b/neural_compressor/torch/quantization/config.py index 27e5a85551e..883edc60f60 100644 --- a/neural_compressor/torch/quantization/config.py +++ b/neural_compressor/torch/quantization/config.py @@ -18,10 +18,11 @@ """Intel Neural Compressor Pytorch quantization config API.""" +import copy import importlib import json from collections import OrderedDict -from typing import Callable, Dict, List, NamedTuple, Optional +from typing import Any, Callable, Dict, List, NamedTuple, Optional from typing import OrderedDict as OrderedDictType from typing import Tuple, Union @@ -2167,3 +2168,15 @@ def get_config_set_for_tuning(cls, dtype="int8"): return cls._model_mapping[STATIC_QUANT].get_config_set_for_tuning() else: raise ValueError(f"Unsupported dtype: {dtype}, allowed values are 'fp8' and 'int8'.") + + +# TODO: support more mappings configurations. +# Default map for swapping float module to qat modules +DEFAULT_QAT_MODULE_MAPPINGS: dict[Callable, Any] = { + torch.nn.Linear: "MXFP8", +} + + +def get_default_qat_module_mappings() -> dict[Callable, Any]: + """Get default module mapping for quantization aware training.""" + return copy.deepcopy(DEFAULT_QAT_MODULE_MAPPINGS) diff --git a/neural_compressor/torch/quantization/quantize.py b/neural_compressor/torch/quantization/quantize.py index a313220c43e..84f770a4a71 100644 --- a/neural_compressor/torch/quantization/quantize.py +++ b/neural_compressor/torch/quantization/quantize.py @@ -185,6 +185,37 @@ def prepare( return prepared_model +@log_process(mode=Mode.PREPARE) +def prepare_qat( + model: torch.nn.Module, + mapping=None, + inplace: bool = True, +): + r"""Prepares a copy of the model for quantization calibration or + quantization-aware training and converts it to quantized version. + + Quantization configuration should be assigned preemptively + to individual submodules in `.qconfig` attribute. + + Args: + model: input model to be modified in-place + quant_config: quantization config that maps float modules to quantized modules to be + replaced. 
+ inplace: carry out model transformations in-place, the original module + is mutated + """ + assert model.training, "prepare_qat only works on models in training mode" + + from .config import get_default_qat_module_mappings + + if mapping is None: + mapping = get_default_qat_module_mappings() + + from ..algorithms.qat.quant_utils import convert_model_with_mapping + + return convert_model_with_mapping(model, mapping) + + @log_process(mode=Mode.CONVERT) def convert( model: torch.nn.Module, diff --git a/test/3x/torch/algorithms/qat/test_qat.py b/test/3x/torch/algorithms/qat/test_qat.py new file mode 100644 index 00000000000..83fc1dd4348 --- /dev/null +++ b/test/3x/torch/algorithms/qat/test_qat.py @@ -0,0 +1,67 @@ +import math +import types +import torch +import torch.nn as nn +import pytest + +# Skip the whole module if auto_round (needed for get_quant_func inside TensorQuantizer) is not available +auto_round = pytest.importorskip("auto_round") + +from neural_compressor.torch.quantization.quantize import prepare_qat +from neural_compressor.torch.algorithms.qat.tensor_quantizer import TensorQuantizer + + +def setup_seed(seed): + import numpy as np + import random + torch.manual_seed(seed) + torch.cuda.manual_seed_all(seed) + np.random.seed(seed) + random.seed(seed) + torch.backends.cudnn.deterministic = True + +class TinyModel(nn.Module): + """Simple hierarchical model for recursive replacement tests.""" + + def __init__(self): + super().__init__() + self.fc1 = nn.Linear(32, 64) + self.lm_head = nn.Linear(64, 2) + + def forward(self, x): + x = self.fc1(x) + return self.lm_head(x) + +def test_replace_quant_layer(): + """Check the inserted quant linear.""" + model = TinyModel() + + prepare_qat(model) + + replaced_modules = sum(isinstance(m, TensorQuantizer) for _, m in model.named_modules()) + + assert replaced_modules == 3 + + +def test_train(): + """QAT test.""" + setup_seed(20) + + model = TinyModel() + prepare_qat(model) + + inp = torch.randn([2, 32]) + + optimizer = torch.optim.AdamW(model.parameters(), lr=1e-5) + + with torch.autocast(device_type="cpu", dtype=torch.bfloat16): + output = model(inp) + loss = output.mean() + + optimizer.zero_grad() + loss.backward() + + # check the grad + for name, param in model.named_parameters(): + assert param.grad is not None + optimizer.step() diff --git a/test/3x/torch/algorithms/qat/test_quant_utils.py b/test/3x/torch/algorithms/qat/test_quant_utils.py new file mode 100644 index 00000000000..cca51126caf --- /dev/null +++ b/test/3x/torch/algorithms/qat/test_quant_utils.py @@ -0,0 +1,208 @@ +# -*- coding: utf-8 -*- + +import sys +import types +import importlib +from types import SimpleNamespace +from pathlib import Path + +import pytest +import torch +import torch.nn as nn + +from neural_compressor.torch.algorithms.qat import quant_utils + + +from neural_compressor.torch.algorithms.qat.tensor_quantizer import TensorQuantizer # type: ignore +from neural_compressor.torch.algorithms.qat.quant_linear import QuantLinear + + +class TinyModel(nn.Module): + """Simple hierarchical model for recursive replacement tests.""" + + def __init__(self): + super().__init__() + self.fc1 = nn.Linear(16, 8) + self.block = nn.Sequential( + nn.Linear(8, 8), + nn.ReLU(), + nn.Linear(8, 4), + ) + self.lm_head = nn.Linear(4, 2) + + def forward(self, x): + x = self.fc1(x) + x = self.block(x) + return self.lm_head(x) + + +@pytest.fixture +def sample_input(): + return torch.randn(2, 16) + +def make_quant_cfg( + *, + data_type="mx_fp8", + bits=8, + group_size=32, + sym=True, + 
act_data_type="mx_fp8", + act_bits=8, + act_group_size=32, + act_sym=True, +): + """ + Build a lightweight namespace mimicking the attributes QuantLinear._setup expects. + """ + return types.SimpleNamespace( + data_type=data_type, + bits=bits, + group_size=group_size, + sym=sym, + act_data_type=act_data_type, + act_bits=act_bits, + act_group_size=act_group_size, + act_sym=act_sym, + ) + +@pytest.fixture +def quant_cfg(): + return make_quant_cfg() + + +def test_convert_replaces_class_and_calls_setup(monkeypatch, quant_cfg): + linear = nn.Linear(4, 3) + + original_forward_id = id(QuantLinear.forward) + + quant_utils.convert(linear, quant_cfg=quant_cfg, quant_module=QuantLinear) + + assert isinstance(linear, QuantLinear) + assert hasattr(linear.forward, "__self__") and linear.forward.__self__ is linear + assert linear.forward.__func__ is QuantLinear.forward or id(linear.forward.__func__) == original_forward_id + + +def test_replace_with_quant_linear_recursive(monkeypatch, quant_cfg): + model = TinyModel() + + + quant_utils.replace_with_quant_linear(model, quant_cfg=quant_cfg) + + assert isinstance(model.fc1, QuantLinear) + assert isinstance(model.block[0], QuantLinear) + assert isinstance(model.block[2], QuantLinear) + assert isinstance(model.lm_head, nn.Linear) + + +def test_is_quantlinear_positive_and_negative(): + q = QuantLinear() + plain = nn.Linear(4, 2) + assert quant_utils.is_quantlinear(q) is True + assert quant_utils.is_quantlinear(plain) is False + + +def test_get_quantization_format_positive(monkeypatch): + layer = QuantLinear() + + layer.weight_quantizer = TensorQuantizer(bits=8, data_type="mx_fp8") + layer.weight_quantizer._disabled = False + layer.input_quantizer = TensorQuantizer(bits=8, data_type="mx_fp8") + layer.input_quantizer._disabled = False + + layer.weight = None + fmt = quant_utils.get_quantization_format(layer) + assert fmt == "MXFP8" + + +def test_get_quantization_format_none(): + layer = nn.Linear(4, 2) + fmt = quant_utils.get_quantization_format(layer) + assert fmt is None + + +def test_get_quantization_format_unsupported_bits_raises(): + layer = QuantLinear() + layer.weight_quantizer = TensorQuantizer(bits=4, data_type="mx_fp8") + layer.weight_quantizer._disabled = False + layer.input_quantizer = TensorQuantizer(bits=4, data_type="mx_fp8") + layer.input_quantizer._disabled = False + + with pytest.raises(NotImplementedError): + quant_utils.get_quantization_format(layer) + + +def test_get_quant_config_success(monkeypatch): + # dynamic fake module: auto_round.export.export_to_llmcompressor.config + module_name = "auto_round.export.export_to_llmcompressor.config" + + class DummyQuantCfg: + def __init__(self): + self.data = { + "provider": "dummy", + "config_groups": { + "group_0": { + "weights": {}, + "input_activations": {}, + } + }, + } + + def to_dict(self): + return self.data + + def initialize_quantization(scheme: str): + return DummyQuantCfg() + + # auto_round + auto_round = types.ModuleType("auto_round") + export = types.ModuleType("auto_round.export") + export_to = types.ModuleType("auto_round.export.export_to_llmcompressor") + config_mod = types.ModuleType(module_name) + config_mod.initialize_quantization = initialize_quantization + + sys.modules["auto_round"] = auto_round + sys.modules["auto_round.export"] = export + sys.modules["auto_round.export.export_to_llmcompressor"] = export_to + sys.modules[module_name] = config_mod + + cfg = quant_utils.get_quant_config(scheme="mxfp8") + assert isinstance(cfg, dict) + assert cfg["provider"] == "auto-round" + 
assert cfg["config_groups"]["group_0"]["weights"]["is_mx"] is True + assert cfg["config_groups"]["group_0"]["input_activations"]["is_mx"] is True + + +def test_convert_forward_executes(monkeypatch): + linear = nn.Linear(5, 3) + + def fake_forward(self, x): + return torch.zeros(x.shape[0], 3) + + monkeypatch.setattr(QuantLinear, "forward", fake_forward, raising=True) + + quant_utils.convert(linear, quant_cfg=make_quant_cfg(), quant_module=QuantLinear) + out = linear(torch.randn(2, 5)) + assert out.shape == (2, 3) + assert torch.all(out == 0) + + +def test_replace_with_quant_linear_idempotent(quant_cfg): + model = TinyModel() + quant_utils.replace_with_quant_linear(model, quant_cfg=quant_cfg) + quant_utils.replace_with_quant_linear(model, quant_cfg=quant_cfg) + assert isinstance(model.fc1, QuantLinear) + + +@pytest.mark.parametrize("disabled", [True, False]) +def test_get_quantization_format_disabled_returns_none(disabled): + layer = QuantLinear() + layer.weight_quantizer = TensorQuantizer(bits=8, data_type="mx_fp8") + layer.weight_quantizer._disabled = disabled + layer.input_quantizer = TensorQuantizer(bits=8, data_type="mx_fp8") + layer.input_quantizer._disabled = disabled + + fmt = quant_utils.get_quantization_format(layer) + if disabled: + assert fmt is None + else: + assert fmt == "MXFP8" diff --git a/test/3x/torch/algorithms/qat/test_quantizer_and_linear.py b/test/3x/torch/algorithms/qat/test_quantizer_and_linear.py new file mode 100644 index 00000000000..8f5c6108ba8 --- /dev/null +++ b/test/3x/torch/algorithms/qat/test_quantizer_and_linear.py @@ -0,0 +1,165 @@ +import math +import types +import torch +import pytest +import torch.nn as nn + +# Skip the whole module if auto_round (needed for get_quant_func inside TensorQuantizer) is not available +auto_round = pytest.importorskip("auto_round") + +from neural_compressor.torch.algorithms.qat.quant_linear import QuantLinear +from neural_compressor.torch.algorithms.qat.tensor_quantizer import TensorQuantizer + +def make_quant_cfg( + *, + data_type="mx_fp8", + bits=8, + group_size=32, + sym=True, + act_data_type="mx_fp8", + act_bits=8, + act_group_size=32, + act_sym=True, +): + """ + Build a lightweight namespace mimicking the attributes QuantLinear._setup expects. + """ + return types.SimpleNamespace( + data_type=data_type, + bits=bits, + group_size=group_size, + sym=sym, + act_data_type=act_data_type, + act_bits=act_bits, + act_group_size=act_group_size, + act_sym=act_sym, + ) + + +def build_quant_linear(in_features=32, out_features=16, bias=True, quant_cfg=None, device="cpu", dtype=torch.float32): + """ + Manually construct a QuantLinear since the class does not define an __init__. + + Steps: + 1. Instantiate the module + 2. Register parameter tensors (weight, bias) + 3. Add metadata attributes used by extra_repr / repr + 4. 
Call internal _setup with provided quant config + """ + if quant_cfg is None: + quant_cfg = make_quant_cfg(group_size=32, act_group_size=32) + + ql = QuantLinear() + ql.in_features = in_features + ql.out_features = out_features + + weight = torch.randn(out_features, in_features, device=device, dtype=dtype) + ql.register_parameter("weight", nn.Parameter(weight)) + + if bias: + b = torch.randn(out_features, device=device, dtype=dtype) + ql.register_parameter("bias", nn.Parameter(b)) + else: + ql.bias = None # make sure attribute exists + + ql._setup(quant_cfg) + return ql + + +@pytest.mark.parametrize("bias", [True, False]) +def test_quant_linear_forward_and_backward(bias): + torch.manual_seed(42) + + in_features = 32 + out_features = 16 + batch_size = 3 + + ql = build_quant_linear(in_features=in_features, out_features=out_features, bias=bias) + + # Create a deliberately non-contiguous input (transpose trick) + base = torch.randn(in_features, batch_size) + x = base.t() # shape (batch_size, in_features) but non-contiguous + assert not x.is_contiguous() + + x.requires_grad_(True) + out = ql(x) + + # Shape & dtype checks + assert out.shape == (batch_size, out_features) + assert out.dtype == x.dtype + + # Backward pass + loss = out.sum() + loss.backward() + + assert ql.weight.grad is not None, "Weight should receive gradient through fake quant path" + if bias: + assert ql.bias.grad is not None, "Bias should receive gradient" + else: + assert ql.bias is None + + # Ensure original weight dtype tracked + assert ql.original_weight_dtype == ql.weight.dtype + + # Output quantizer is explicitly disabled in _setup + assert "TensorQuantizer(disabled)" in repr(ql.output_quantizer) + + # Input/weight quantizers should be enabled (not containing 'disabled') + assert "disabled" not in repr(ql.input_quantizer) + assert "disabled" not in repr(ql.weight_quantizer) + + +def test_quant_linear_repr_and_extra_repr(): + ql = build_quant_linear(in_features=8, out_features=4, bias=True) + r = repr(ql) + # Basic structural checks + assert "QuantLinear(" in r + assert "(input_quantizer):" in r + assert "(weight_quantizer):" in r + assert "(output_quantizer):" in r + # extra_repr path + er = ql.extra_repr() + assert "in_features=8" in er + assert "out_features=4" in er + assert "bias=True" in er + + +def test_tensor_quantizer_disable_and_no_quant_path(): + tq = TensorQuantizer(if_quant=False) # constructed with quantization turned off + x = torch.randn(5, 7) + out = tq(x) + # When disabled (not quant) it should return the identical object (same memory) + assert out.data_ptr() == x.data_ptr() + assert repr(tq) == "TensorQuantizer(disabled)" + + +def test_tensor_quantizer_enable_disable_cycle(): + tq = TensorQuantizer() + x = torch.randn(4, 32) # group size default 32, matches last dim + y1 = tq(x) + assert y1.shape == x.shape + # Disable and ensure passthrough (pointer equality) + tq.disable() + y2 = tq(x) + assert y2.data_ptr() == x.data_ptr() + assert "disabled" in repr(tq) + # Re-enable + tq.enable() + y3 = tq(x) + assert y3.shape == x.shape + assert "disabled" not in repr(tq) + + +def test_tensor_quantizer_scale_persistence(): + # Provide scale_shape so internal buffer is registered & updated + tq = TensorQuantizer(scale_shape=(4, 32), block_size=32) + x = torch.randn(4, 32) + # Use internal fake quant function to generate scale + q, shared_exp = tq._fake_quantize(x) + # scale buffer should have been updated (shape (4, 1)) + assert hasattr(tq, "scale") + assert tq.scale.shape == (4, 1) + # We cannot be certain of 
values, but at least ensure it is uint8 and not all zeros (likely) + assert tq.scale.dtype == torch.uint8 + # Heuristic: at least one non-zero (if all zero it may still be valid, but improbable) + assert (tq.scale != 0).any() or (shared_exp == 0).all()
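Below is a minimal end-to-end sketch of how the pieces added in this patch compose: prepare_qat on a float model, a short training loop, then the optional export step. It assumes auto_round is installed; the toy model, the training loop, and the export directory/scheme strings in the trailing comment are illustrative placeholders rather than part of the patch.

# Illustrative QAT sketch using the APIs added in this patch (assumes auto_round is installed).
import torch
import torch.nn as nn

from neural_compressor.torch.quantization.quantize import prepare_qat


class TinyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(32, 64)
        self.lm_head = nn.Linear(64, 2)  # stays in float: "lm_head" is skipped by the replacer

    def forward(self, x):
        return self.lm_head(self.fc1(x))


model = TinyModel()
model.train()  # prepare_qat asserts the model is in training mode

# Swap nn.Linear modules with QuantLinear using the default {nn.Linear: "MXFP8"} mapping.
prepare_qat(model)

optimizer = torch.optim.AdamW(model.parameters(), lr=1e-5)
for _ in range(3):
    loss = model(torch.randn(2, 32)).mean()
    optimizer.zero_grad()
    loss.backward()  # gradients flow through the fake-quant STE path
    optimizer.step()

# For a Hugging Face model (anything exposing save_pretrained), the trained result can then be
# packed and saved, e.g.:
#   from neural_compressor.torch.export.export_hf import export_hf2compressored_model
#   export_hf2compressored_model(model, export_dir="qat-mxfp8-model", scheme="MXFP8")
# The directory name is a placeholder and the exact scheme string depends on the auto_round presets.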