From 6b748dc0bb5806ddadd22da025cc78055f6bd3c3 Mon Sep 17 00:00:00 2001 From: Kaihui-intel Date: Tue, 23 Sep 2025 13:43:49 +0800 Subject: [PATCH 01/12] remove mllm flag Signed-off-by: Kaihui-intel --- .../quantization/auto_round/mllm.py | 1 - .../torch/algorithms/weight_only/autoround.py | 132 ++++++------------ .../torch/quantization/algorithm_entry.py | 2 - .../torch/quantization/config.py | 3 - .../transformers/quantization/utils.py | 6 +- .../transformers/utils/quantization_config.py | 2 - .../weight_only/test_autoround.py | 1 - .../weight_only/test_transformers.py | 2 - 8 files changed, 49 insertions(+), 100 deletions(-) diff --git a/examples/3.x_api/pytorch/multimodal-modeling/quantization/auto_round/mllm.py b/examples/3.x_api/pytorch/multimodal-modeling/quantization/auto_round/mllm.py index 84f4976bb31..bad8eebdfd0 100644 --- a/examples/3.x_api/pytorch/multimodal-modeling/quantization/auto_round/mllm.py +++ b/examples/3.x_api/pytorch/multimodal-modeling/quantization/auto_round/mllm.py @@ -223,7 +223,6 @@ def tune(args): use_auto_mapping = True woq_config = AutoRoundConfig( - is_vlm=True, bits=args.bits, sym=not args.asym, group_size=args.group_size, diff --git a/neural_compressor/torch/algorithms/weight_only/autoround.py b/neural_compressor/torch/algorithms/weight_only/autoround.py index b3c6b292831..e5ef3cf591a 100644 --- a/neural_compressor/torch/algorithms/weight_only/autoround.py +++ b/neural_compressor/torch/algorithms/weight_only/autoround.py @@ -35,9 +35,9 @@ def _is_auto_round_available(): from auto_round import AutoRound, AutoRoundMLLM # pylint: disable=E0401 from auto_round.export.export_to_itrex.export import pack_model # pylint: disable=E0401 -from auto_round.mllm import lmms_eval, mllm_eval -from auto_round.mllm.template import Template, get_template from auto_round.schemes import QuantizationScheme +from auto_round.compressors.mllm.eval import mllm_eval, lmms_eval +from auto_round.compressors.mllm.template import Template, get_template from neural_compressor.torch.algorithms import Quantizer from neural_compressor.torch.utils import get_accelerator, logger @@ -85,7 +85,6 @@ def __init__( enable_norm_bias_tuning: bool = False, enable_torch_compile: bool = None, # mllm - is_mllm: bool = False, quant_nontext_module: bool = False, extra_data_dir: str = None, image_processor=None, @@ -155,7 +154,6 @@ def __init__( enable_norm_bias_tuning (bool): Whether to enable fast norm/layer_bias tuning. enable_torch_compile (bool): Whether to enable torch compile to optimize quant_block/layer, torch>=2.6 True. quant_nontext_module (bool): Whether to quantize nontext module. - is_mllm (bool): Indicates whether the model to be quantized is a multi-modal model (MLLM). extra_data_dir (str): The path for extra data such as images, audio or videos. processor (transformers.AutoProcessor): Any multi-modal model will require an object to encode or decode the data that groups several modalities (among text, vision and audio). @@ -202,7 +200,6 @@ def __init__( self.export_format = export_format self.enable_norm_bias_tuning = enable_norm_bias_tuning self.enable_torch_compile = enable_torch_compile - self.is_mllm = is_mllm self.quant_nontext_module = quant_nontext_module self.extra_data_dir = extra_data_dir self.processor = processor @@ -239,85 +236,48 @@ def convert(self, model: torch.nn.Module, *args, **kwargs): """ dataloader = CapturedDataloader(model.args_list, model.kwargs_list) model = model.orig_model - if self.is_mllm: - rounder = AutoRoundMLLM( - model, - tokenizer=self.tokenizer, - scheme=self.scheme, - processor=self.processor, - image_processor=self.image_processor, - layer_config=self.quant_config, - batch_size=self.batch_size, - amp=self.amp, - device_map=self.device_map, - lr_scheduler=self.lr_scheduler, - dataset=dataloader, - extra_data_dir=self.extra_data_dir, - template=self.template, - quant_nontext_module=self.quant_nontext_module, - enable_quanted_input=self.enable_quanted_input, - enable_minmax_tuning=self.enable_minmax_tuning, - lr=self.lr, - minmax_lr=self.minmax_lr, - low_gpu_mem_usage=self.low_gpu_mem_usage, - low_cpu_mem_usage=self.low_gpu_mem_usage, - iters=self.iters, - seqlen=self.seqlen, - nsamples=self.nsamples, - sampler=self.sampler, - seed=self.seed, - nblocks=self.nblocks, - gradient_accumulate_steps=self.gradient_accumulate_steps, - not_use_best_mse=self.not_use_best_mse, - dynamic_max_gap=self.dynamic_max_gap, - data_type=self.data_type, - scale_dtype=self.scale_dtype, - act_bits=self.act_bits, - act_group_size=self.act_group_size, - act_sym=self.act_sym, - act_dynamic=self.act_dynamic, - to_quant_block_names=self.to_quant_block_names, - enable_norm_bias_tuning=self.enable_norm_bias_tuning, - truncation=self.truncation, - enable_torch_compile=self.enable_torch_compile, - ) - else: - rounder = AutoRound( - model=model, - tokenizer=self.tokenizer, - scheme=self.scheme, - dataset=dataloader, - layer_config=self.quant_config or {}, - enable_full_range=self.enable_full_range, - batch_size=self.batch_size, - amp=self.amp, - device_map=self.device_map, - lr_scheduler=self.lr_scheduler, - enable_quanted_input=self.enable_quanted_input, - enable_minmax_tuning=self.enable_minmax_tuning, - lr=self.lr, - minmax_lr=self.minmax_lr, - low_gpu_mem_usage=self.low_gpu_mem_usage, - iters=self.iters, - seqlen=self.seqlen, - nsamples=self.nsamples, - sampler=self.sampler, - seed=self.seed, - nblocks=self.nblocks, - gradient_accumulate_steps=self.gradient_accumulate_steps, - not_use_best_mse=self.not_use_best_mse, - dynamic_max_gap=self.dynamic_max_gap, - data_type=self.data_type, - scale_dtype=self.scale_dtype, - to_quant_block_names=self.to_quant_block_names, - act_bits=self.act_bits, - act_group_size=self.act_group_size, - act_sym=self.act_sym, - act_dynamic=self.act_dynamic, - low_cpu_mem_usage=self.low_cpu_mem_usage, - enable_norm_bias_tuning=self.enable_norm_bias_tuning, - enable_torch_compile=self.enable_torch_compile, - ) + rounder = AutoRound( + model, + tokenizer=self.tokenizer, + scheme=self.scheme, + processor=self.processor, + image_processor=self.image_processor, + layer_config=self.quant_config, + enable_full_range=self.enable_full_range, + batch_size=self.batch_size, + amp=self.amp, + device_map=self.device_map, + lr_scheduler=self.lr_scheduler, + dataset=dataloader, + extra_data_dir=self.extra_data_dir, + template=self.template, + quant_nontext_module=self.quant_nontext_module, + enable_quanted_input=self.enable_quanted_input, + enable_minmax_tuning=self.enable_minmax_tuning, + lr=self.lr, + minmax_lr=self.minmax_lr, + low_gpu_mem_usage=self.low_gpu_mem_usage, + low_cpu_mem_usage=self.low_gpu_mem_usage, + iters=self.iters, + seqlen=self.seqlen, + nsamples=self.nsamples, + sampler=self.sampler, + seed=self.seed, + nblocks=self.nblocks, + gradient_accumulate_steps=self.gradient_accumulate_steps, + not_use_best_mse=self.not_use_best_mse, + dynamic_max_gap=self.dynamic_max_gap, + data_type=self.data_type, + scale_dtype=self.scale_dtype, + act_bits=self.act_bits, + act_group_size=self.act_group_size, + act_sym=self.act_sym, + act_dynamic=self.act_dynamic, + to_quant_block_names=self.to_quant_block_names, + enable_norm_bias_tuning=self.enable_norm_bias_tuning, + truncation=self.truncation, + enable_torch_compile=self.enable_torch_compile, + ) model, weight_config = rounder.quantize() model.autoround_config = weight_config if self.enable_w4afp8: @@ -389,8 +349,8 @@ def get_mllm_dataloader( DataLoader: The DataLoader for the calibrated datasets. """ from auto_round.calib_dataset import CALIB_DATASETS - from auto_round.mllm.autoround_mllm import _only_text_test - from auto_round.mllm.mllm_dataset import get_mllm_dataloader # pylint: disable=E0401 + from auto_round.compressors.mllm.compressor import _only_text_test + from auto_round.compressors.mllm.dataset import get_mllm_dataloader # pylint: disable=E0401 template = template if template is not None else model.config.model_type template = get_template( diff --git a/neural_compressor/torch/quantization/algorithm_entry.py b/neural_compressor/torch/quantization/algorithm_entry.py index 5bb11fb69bc..fc148db54f2 100644 --- a/neural_compressor/torch/quantization/algorithm_entry.py +++ b/neural_compressor/torch/quantization/algorithm_entry.py @@ -622,7 +622,6 @@ def autoround_quantize_entry( export_format = quant_config.export_format enable_norm_bias_tuning = quant_config.enable_norm_bias_tuning enable_torch_compile = quant_config.enable_torch_compile - is_mllm = quant_config.is_mllm quant_nontext_module = quant_config.quant_nontext_module extra_data_dir = quant_config.extra_data_dir processor = quant_config.processor @@ -661,7 +660,6 @@ def autoround_quantize_entry( export_format=export_format, enable_norm_bias_tuning=enable_norm_bias_tuning, enable_torch_compile=enable_torch_compile, - is_mllm=is_mllm, quant_nontext_module=quant_nontext_module, extra_data_dir=extra_data_dir, processor=processor, diff --git a/neural_compressor/torch/quantization/config.py b/neural_compressor/torch/quantization/config.py index 27e5a85551e..559fcf17ec9 100644 --- a/neural_compressor/torch/quantization/config.py +++ b/neural_compressor/torch/quantization/config.py @@ -975,7 +975,6 @@ def __init__( scheme: str | dict = "W4A16", device_map: str = None, # mllm - is_mllm: bool = False, quant_nontext_module: bool = False, extra_data_dir: str = None, processor=None, @@ -1024,7 +1023,6 @@ def __init__( enable_torch_compile (bool): Whether to enable torch compile to optimize quant_block/layer, torch>=2.6 True. quant_nontext_module (bool): Whether to quantize nontext module. extra_data_dir (str): The path for extra data such as images, audio or videos. - is_mllm (bool): Indicates whether the model to be quantized is a multi-modal model (MLLM). processor (transformers.AutoProcessor): Any multi-modal model will require an object to encode or decode the data that groups several modalities (among text, vision and audio). This is handled by objects called processors, which group together two or more processing objects such @@ -1071,7 +1069,6 @@ def __init__( self.export_format = export_format self.enable_norm_bias_tuning = enable_norm_bias_tuning self.enable_torch_compile = enable_torch_compile - self.is_mllm = is_mllm self.quant_nontext_module = quant_nontext_module self.extra_data_dir = extra_data_dir self.processor = processor diff --git a/neural_compressor/transformers/quantization/utils.py b/neural_compressor/transformers/quantization/utils.py index ced7ba174cd..cd33a01acf1 100644 --- a/neural_compressor/transformers/quantization/utils.py +++ b/neural_compressor/transformers/quantization/utils.py @@ -480,7 +480,8 @@ def convert_to_quantized_model(model, config, device="cpu", for_inference=True): run_fn(model, *run_args) model = convert(model) elif config.quant_method.value == "autoround": - if config.is_vlm is True: + from auto_round.utils import is_mllm_model + if is_mllm_model(model): from transformers import AutoProcessor, AutoTokenizer from neural_compressor.torch.algorithms.weight_only.autoround import ( @@ -543,7 +544,6 @@ def convert_to_quantized_model(model, config, device="cpu", for_inference=True): scale_dtype=config.scale_dtype, use_layer_wise=config.use_layer_wise, # vlm arguments - is_mllm=config.is_vlm, quant_nontext_module=config.quant_nontext_module, truncation=config.truncation, gradient_accumulate_steps=config.gradient_accumulate_steps, @@ -551,7 +551,7 @@ def convert_to_quantized_model(model, config, device="cpu", for_inference=True): ) # vlm set non-text module config - if config.is_vlm is True: + if is_mllm_model(model): from neural_compressor.torch.utils.utility import ( find_matching_blocks, get_layer_names_in_block, diff --git a/neural_compressor/transformers/utils/quantization_config.py b/neural_compressor/transformers/utils/quantization_config.py index 94c86a89de2..98bb8b20f0e 100644 --- a/neural_compressor/transformers/utils/quantization_config.py +++ b/neural_compressor/transformers/utils/quantization_config.py @@ -546,7 +546,6 @@ def __init__( use_layer_wise: bool = None, quant_lm_head: bool = False, # vlm arguments - is_vlm: bool = False, quant_nontext_module: bool = False, truncation: bool = False, gradient_accumulate_steps: int = 1, @@ -603,7 +602,6 @@ def __init__( self.model_path = kwargs.get("model_path", "") # vlm arguments - self.is_vlm = is_vlm self.quant_nontext_module = quant_nontext_module self.truncation = truncation self.gradient_accumulate_steps = gradient_accumulate_steps diff --git a/test/3x/torch/quantization/weight_only/test_autoround.py b/test/3x/torch/quantization/weight_only/test_autoround.py index b919e56cf4d..74d4511e4cf 100644 --- a/test/3x/torch/quantization/weight_only/test_autoround.py +++ b/test/3x/torch/quantization/weight_only/test_autoround.py @@ -260,7 +260,6 @@ def test_mllm(self): quant_config = AutoRoundConfig( bits=4, group_size=128, - is_mllm=True, nsamples=1, batch_size=batch_size, iters=1, diff --git a/test/3x/torch/quantization/weight_only/test_transformers.py b/test/3x/torch/quantization/weight_only/test_transformers.py index bb894b069cc..f7d07502f26 100644 --- a/test/3x/torch/quantization/weight_only/test_transformers.py +++ b/test/3x/torch/quantization/weight_only/test_transformers.py @@ -234,7 +234,6 @@ def test_vlm(self): woq_config = AutoRoundConfig( bits=4, group_size=128, - is_vlm=True, dataset="NeelNanda/pile-10k", iters=1, n_samples=1, @@ -267,7 +266,6 @@ def test_vlm(self): # woq_config = AutoRoundConfig( # bits=4, # group_size=128, - # is_vlm=True, # dataset="liuhaotian/llava_conv_58k", # iters=2, # n_samples=5, From 7c029aebfc9c4acaecabb0d4551882b1cb6a98b1 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 23 Sep 2025 06:48:53 +0000 Subject: [PATCH 02/12] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- neural_compressor/torch/algorithms/weight_only/autoround.py | 4 ++-- neural_compressor/transformers/quantization/utils.py | 1 + 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/neural_compressor/torch/algorithms/weight_only/autoround.py b/neural_compressor/torch/algorithms/weight_only/autoround.py index e5ef3cf591a..14e976f28a2 100644 --- a/neural_compressor/torch/algorithms/weight_only/autoround.py +++ b/neural_compressor/torch/algorithms/weight_only/autoround.py @@ -34,10 +34,10 @@ def _is_auto_round_available(): _is_auto_round_available() from auto_round import AutoRound, AutoRoundMLLM # pylint: disable=E0401 +from auto_round.compressors.mllm.eval import lmms_eval, mllm_eval +from auto_round.compressors.mllm.template import Template, get_template from auto_round.export.export_to_itrex.export import pack_model # pylint: disable=E0401 from auto_round.schemes import QuantizationScheme -from auto_round.compressors.mllm.eval import mllm_eval, lmms_eval -from auto_round.compressors.mllm.template import Template, get_template from neural_compressor.torch.algorithms import Quantizer from neural_compressor.torch.utils import get_accelerator, logger diff --git a/neural_compressor/transformers/quantization/utils.py b/neural_compressor/transformers/quantization/utils.py index cd33a01acf1..d9500efbea1 100644 --- a/neural_compressor/transformers/quantization/utils.py +++ b/neural_compressor/transformers/quantization/utils.py @@ -481,6 +481,7 @@ def convert_to_quantized_model(model, config, device="cpu", for_inference=True): model = convert(model) elif config.quant_method.value == "autoround": from auto_round.utils import is_mllm_model + if is_mllm_model(model): from transformers import AutoProcessor, AutoTokenizer From 8abc33dcd6803df2d367a7b7dafd2b6a77a071b7 Mon Sep 17 00:00:00 2001 From: Kaihui-intel Date: Thu, 25 Sep 2025 09:27:36 +0800 Subject: [PATCH 03/12] fix device_map in transformers Signed-off-by: Kaihui-intel --- neural_compressor/transformers/quantization/utils.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/neural_compressor/transformers/quantization/utils.py b/neural_compressor/transformers/quantization/utils.py index d9500efbea1..644f4ad1a78 100644 --- a/neural_compressor/transformers/quantization/utils.py +++ b/neural_compressor/transformers/quantization/utils.py @@ -34,7 +34,8 @@ convert, prepare, ) -from neural_compressor.torch.utils import is_ipex_available, is_package_available +from neural_compressor.torch.utils import is_ipex_available, is_package_available, get_accelerator + if is_ipex_available(): import intel_extension_for_pytorch as ipex @@ -530,11 +531,14 @@ def convert_to_quantized_model(model, config, device="cpu", for_inference=True): bs=config.batch_size, nsamples=config.n_samples, ) + + device_map = get_accelerator().current_device_name() quant_config = AutoRoundConfig( dtype=dtype, bits=config.bits, use_sym=config.sym, group_size=config.group_size, + device_map=device_map, enable_quanted_input=not config.disable_quanted_input, lr=config.lr, minmax_lr=config.minmax_lr, From e87109919ccd67a83147f7ed12e411d391d66df5 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 25 Sep 2025 02:29:57 +0000 Subject: [PATCH 04/12] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- neural_compressor/transformers/quantization/utils.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/neural_compressor/transformers/quantization/utils.py b/neural_compressor/transformers/quantization/utils.py index 644f4ad1a78..f825970c485 100644 --- a/neural_compressor/transformers/quantization/utils.py +++ b/neural_compressor/transformers/quantization/utils.py @@ -34,8 +34,7 @@ convert, prepare, ) -from neural_compressor.torch.utils import is_ipex_available, is_package_available, get_accelerator - +from neural_compressor.torch.utils import get_accelerator, is_ipex_available, is_package_available if is_ipex_available(): import intel_extension_for_pytorch as ipex From 324826496c02c44058b934362c43e9c1e4d5b686 Mon Sep 17 00:00:00 2001 From: Kaihui-intel Date: Thu, 25 Sep 2025 13:59:51 +0800 Subject: [PATCH 05/12] support lm_head Signed-off-by: Kaihui-intel --- .../torch/quantization/config.py | 20 +++++++++++++++++++ .../weight_only/test_autoround.py | 8 ++------ 2 files changed, 22 insertions(+), 6 deletions(-) diff --git a/neural_compressor/torch/quantization/config.py b/neural_compressor/torch/quantization/config.py index 559fcf17ec9..b035035edea 100644 --- a/neural_compressor/torch/quantization/config.py +++ b/neural_compressor/torch/quantization/config.py @@ -981,6 +981,7 @@ def __init__( image_processor=None, template=None, truncation: bool = False, + quant_lm_head: bool = False, white_list: Optional[List[OP_NAME_OR_MODULE_TYPE]] = DEFAULT_WHITE_LIST, **kwargs, ): @@ -1077,6 +1078,7 @@ def __init__( self.truncation = truncation self.scheme = scheme self.device_map = device_map + self.quant_lm_head = quant_lm_head self._post_init() @classmethod @@ -1110,6 +1112,24 @@ def get_model_info(model: torch.nn.Module) -> List[Tuple[str, Callable]]: filter_result.append(pair) logger.debug(f"Get model info: {filter_result}") return filter_result + + def to_config_mapping( + self, config_list: List[BaseConfig] = None, model_info: List[Tuple[str, str]] = None + ) -> OrderedDictType[Union[str, str], OrderedDictType[str, BaseConfig]]: + """Convert the configuration to a mapping. + + Args: + config_list (List[BaseConfig]): List of base configurations. Default is None. + model_info (List[Tuple[str, str]]): List of tuples containing the name and type of each module in the model. + Default is None. + + Returns: + OrderedDictType[Union[str, str], OrderedDictType[str, BaseConfig]]: The configuration mapping. + """ + if not self.quant_lm_head: + self.set_local(LM_HEAD_NAMES, AutoRoundConfig(dtype="fp32")) + config_mapping = super().to_config_mapping(config_list, model_info) + return config_mapping @classmethod def get_config_set_for_tuning(cls) -> Union[None, "AutoRoundConfig", List["AutoRoundConfig"]]: diff --git a/test/3x/torch/quantization/weight_only/test_autoround.py b/test/3x/torch/quantization/weight_only/test_autoround.py index 74d4511e4cf..e0afdfbbf26 100644 --- a/test/3x/torch/quantization/weight_only/test_autoround.py +++ b/test/3x/torch/quantization/weight_only/test_autoround.py @@ -103,9 +103,7 @@ def setup_method(self, method): @pytest.mark.parametrize("quant_lm_head", [True, False]) def test_autoround(self, quant_lm_head): fp32_model = copy.deepcopy(self.gptj) - quant_config = AutoRoundConfig(nsamples=32, seqlen=10, iters=10, amp=False ,scale_dtype="fp32") - if quant_lm_head is False: - quant_config.set_local("lm_head", AutoRoundConfig(dtype="fp32")) + quant_config = AutoRoundConfig(nsamples=32, seqlen=10, iters=10, amp=False ,scale_dtype="fp32", quant_lm_head=quant_lm_head) logger.info(f"Test AutoRound with config {quant_config}") # prepare + convert API @@ -397,9 +395,7 @@ def test_autoround_w4a8(self): @pytest.mark.parametrize("quant_lm_head", [True, False]) def test_autoround(self, quant_lm_head): fp32_model = copy.deepcopy(self.tiny_llama_model) - quant_config = AutoRoundConfig(nsamples=32, seqlen=10, iters=10, act_dtype="fp32", amp=False ,scale_dtype="fp32") - if quant_lm_head is False: - quant_config.set_local("lm_head", AutoRoundConfig(dtype="fp32")) + quant_config = AutoRoundConfig(nsamples=32, seqlen=10, iters=10, act_dtype="fp32", amp=False ,scale_dtype="fp32", quant_lm_head=quant_lm_head) logger.info(f"Test AutoRound with config {quant_config}") # prepare + convert API From 469bba46c2414d5ba04a3c08f41f11c7454af08e Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 25 Sep 2025 07:02:28 +0000 Subject: [PATCH 06/12] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- neural_compressor/torch/quantization/config.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/neural_compressor/torch/quantization/config.py b/neural_compressor/torch/quantization/config.py index b035035edea..801f3661c9c 100644 --- a/neural_compressor/torch/quantization/config.py +++ b/neural_compressor/torch/quantization/config.py @@ -1112,7 +1112,7 @@ def get_model_info(model: torch.nn.Module) -> List[Tuple[str, Callable]]: filter_result.append(pair) logger.debug(f"Get model info: {filter_result}") return filter_result - + def to_config_mapping( self, config_list: List[BaseConfig] = None, model_info: List[Tuple[str, str]] = None ) -> OrderedDictType[Union[str, str], OrderedDictType[str, BaseConfig]]: From c29976b1ce1670170a37e97551cc883e6eda3f43 Mon Sep 17 00:00:00 2001 From: Kaihui-intel Date: Thu, 25 Sep 2025 22:00:00 +0800 Subject: [PATCH 07/12] update scheme Signed-off-by: Kaihui-intel --- .../torch/algorithms/weight_only/autoround.py | 74 ++++++++++++------- .../torch/quantization/algorithm_entry.py | 46 ++++++++---- .../torch/quantization/config.py | 57 +++++++++----- .../torch/quantization/quantize.py | 7 ++ .../weight_only/test_autoround.py | 36 ++++++--- 5 files changed, 153 insertions(+), 67 deletions(-) diff --git a/neural_compressor/torch/algorithms/weight_only/autoround.py b/neural_compressor/torch/algorithms/weight_only/autoround.py index 14e976f28a2..03589a2a849 100644 --- a/neural_compressor/torch/algorithms/weight_only/autoround.py +++ b/neural_compressor/torch/algorithms/weight_only/autoround.py @@ -43,13 +43,21 @@ def _is_auto_round_available(): from neural_compressor.torch.utils import get_accelerator, logger from .utility import CapturedDataloader, InputCaptureModule - - class AutoRoundQuantizer(Quantizer): """AutoRound Quantizer.""" - def __init__( self, + bits: int = None, + group_size: int = None, + sym: bool = None, + data_type: str = None, + act_bits: int = None, + act_group_size: int = None, + act_sym: bool = None, + act_data_type: str = None, + act_dynamic: bool = None, + super_bits: int = None, + super_group_size: int = None, quant_config: dict = {}, enable_full_range: bool = False, ##for symmetric, TODO support later batch_size: int = 8, @@ -71,14 +79,8 @@ def __init__( gradient_accumulate_steps: int = 1, not_use_best_mse: bool = False, dynamic_max_gap: int = -1, - data_type: str = "int", scale_dtype: str = "fp16", to_quant_block_names: list = None, - act_bits: int = 32, - act_group_size: int = None, - act_sym: bool = None, - act_dynamic: bool = True, - act_data_type: Optional[str] = None, low_cpu_mem_usage: bool = False, export_format: str = "itrex", # v0.4 @@ -168,12 +170,24 @@ def __init__( The quantized model. """ super().__init__(quant_config) - self.tokenizer = "Placeholder" # for AutoRound initialization + self.tokenizer = kwargs.pop("tokenizer", "Placeholder") # for AutoRound initialization self.enable_full_range = enable_full_range + self.bits = bits + self.group_size = group_size + self.sym = sym + self.data_type = data_type + self.act_bits= act_bits + self.act_group_size= act_group_size + self.act_sym = act_sym + self.act_data_type = act_data_type + self.act_dynamic = act_dynamic + self.super_bits = super_bits + self.super_group_size = super_group_size self.batch_size = batch_size self.amp = amp self.device = get_accelerator(kwargs.pop("device", "auto")).name() self.lr_scheduler = lr_scheduler + self.dataset = dataset self.enable_quanted_input = enable_quanted_input self.enable_minmax_tuning = enable_minmax_tuning self.lr = lr @@ -188,14 +202,8 @@ def __init__( self.gradient_accumulate_steps = gradient_accumulate_steps self.not_use_best_mse = not_use_best_mse self.dynamic_max_gap = dynamic_max_gap - self.data_type = data_type self.scale_dtype = scale_dtype self.to_quant_block_names = to_quant_block_names - self.act_bits = act_bits - self.act_group_size = act_group_size - self.act_sym = act_sym - self.act_dynamic = act_dynamic - self.act_data_type = act_data_type self.low_cpu_mem_usage = low_cpu_mem_usage self.export_format = export_format self.enable_norm_bias_tuning = enable_norm_bias_tuning @@ -210,7 +218,7 @@ def __init__( self.device_map = device_map self.enable_w4afp8 = self._is_w4afp8() - def _is_w4afp8(self): + def _is_w4afp8(self) -> bool: return any([v.get("data_type", None) == "fp8_to_int_sym" for v in self.quant_config.values()]) def prepare(self, model: torch.nn.Module, *args, **kwargs): @@ -234,21 +242,37 @@ def convert(self, model: torch.nn.Module, *args, **kwargs): Returns: The quantized model. """ - dataloader = CapturedDataloader(model.args_list, model.kwargs_list) + + tokenizer = getattr(model.orig_model, "tokenizer", None) + if tokenizer is not None: + delattr(model.orig_model, "tokenizer") + else: + tokenizer = "Placeholder" + self.dataset = CapturedDataloader(model.args_list, model.kwargs_list) model = model.orig_model rounder = AutoRound( model, - tokenizer=self.tokenizer, + bits=self.bits, + data_type=self.data_type, + group_size=self.group_size, + sym=self.sym, + act_bits=self.act_bits, + act_group_size=self.act_group_size, + act_sym=self.act_sym, + act_data_type=self.act_data_type, + act_dynamic=self.act_dynamic, + super_bits=self.super_bits, + super_group_size=self.super_group_size, + tokenizer=tokenizer, scheme=self.scheme, processor=self.processor, image_processor=self.image_processor, - layer_config=self.quant_config, enable_full_range=self.enable_full_range, batch_size=self.batch_size, amp=self.amp, device_map=self.device_map, lr_scheduler=self.lr_scheduler, - dataset=dataloader, + dataset=self.dataset, extra_data_dir=self.extra_data_dir, template=self.template, quant_nontext_module=self.quant_nontext_module, @@ -267,12 +291,7 @@ def convert(self, model: torch.nn.Module, *args, **kwargs): gradient_accumulate_steps=self.gradient_accumulate_steps, not_use_best_mse=self.not_use_best_mse, dynamic_max_gap=self.dynamic_max_gap, - data_type=self.data_type, scale_dtype=self.scale_dtype, - act_bits=self.act_bits, - act_group_size=self.act_group_size, - act_sym=self.act_sym, - act_dynamic=self.act_dynamic, to_quant_block_names=self.to_quant_block_names, enable_norm_bias_tuning=self.enable_norm_bias_tuning, truncation=self.truncation, @@ -285,7 +304,8 @@ def convert(self, model: torch.nn.Module, *args, **kwargs): elif "itrex" in self.export_format: model = pack_model(model, weight_config, device=self.device, inplace=True) else: # pragma: no cover - model = rounder.save_quantized(output_dir="temp_auto_round", format=self.export_format, inplace=True) + pass + # model = rounder.save_quantized(output_dir="temp_auto_round", format=self.export_format, inplace=False) return model diff --git a/neural_compressor/torch/quantization/algorithm_entry.py b/neural_compressor/torch/quantization/algorithm_entry.py index fc148db54f2..c1d861b28b5 100644 --- a/neural_compressor/torch/quantization/algorithm_entry.py +++ b/neural_compressor/torch/quantization/algorithm_entry.py @@ -579,24 +579,33 @@ def autoround_quantize_entry( if quant_config.name != AUTOROUND or quant_config.dtype == "fp32": continue else: - dtype = quant_config.dtype bits = quant_config.bits - if dtype != "int" and "int" in dtype: - if dtype == "fp8_to_int_sym": + group_size = quant_config.group_size + sym = quant_config.use_sym + data_type = quant_config.dtype + act_bits= quant_config.act_bits + act_group_size= quant_config.act_group_size + act_sym = quant_config.act_sym + act_data_type = quant_config.act_dtype + act_dynamic = quant_config.act_dynamic + super_bits = quant_config.super_bits + super_group_size = quant_config.super_group_size + if data_type is not None and data_type != "int" and "int" in data_type: + if data_type == "fp8_to_int_sym": bits = 4 else: - bits = int(dtype.lstrip("int")) - dtype = "int" + bits = int(data_type.lstrip("int")) + data_type = "int" weight_config[op_name] = { - "data_type": dtype, + "data_type": data_type, "bits": bits, - "sym": quant_config.use_sym, - "group_size": quant_config.group_size, - "act_bits": quant_config.act_bits, - "act_group_size": quant_config.act_group_size, - "act_sym": quant_config.act_sym, - "act_dynamic": quant_config.act_dynamic, - "act_data_type": quant_config.act_dtype, + "sym": sym, + "group_size": group_size, + "act_bits": act_bits, + "act_group_size": act_group_size, + "act_sym": act_sym, + "act_dynamic": act_dynamic, + "act_data_type": act_data_type, } enable_full_range = quant_config.enable_full_range batch_size = quant_config.batch_size @@ -636,6 +645,17 @@ def autoround_quantize_entry( model, quantizer_cls=AutoRoundQuantizer, quant_config=weight_config, + bits=bits, + data_type=data_type, + group_size=group_size, + sym=sym, + act_bits=act_bits, + act_group_size=act_group_size, + act_sym=act_sym, + act_data_type=act_data_type, + act_dynamic=act_dynamic, + super_bits=super_bits, + super_group_size=super_group_size, enable_full_range=enable_full_range, batch_size=batch_size, amp=amp, diff --git a/neural_compressor/torch/quantization/config.py b/neural_compressor/torch/quantization/config.py index b035035edea..ae6e3e9ecbf 100644 --- a/neural_compressor/torch/quantization/config.py +++ b/neural_compressor/torch/quantization/config.py @@ -936,16 +936,18 @@ class AutoRoundConfig(TorchBaseConfig): def __init__( self, - dtype: str = "int", - bits: int = 4, - use_sym: bool = False, - group_size: int = 128, - # AUTOROUND - act_bits: int = 32, + bits: int = None, + group_size: int = None, + use_sym: bool = None, + dtype: str = None, + act_bits: int = None, act_group_size: int = None, act_sym: bool = None, - act_dynamic: bool = True, - act_dtype: Optional[str] = "int", + act_dtype: str = None, + act_dynamic: bool = None, + super_bits: int = None, + super_group_size: int = None, + # AUTOROUND enable_full_range: bool = False, batch_size: int = 8, amp: bool = True, @@ -970,10 +972,10 @@ def __init__( export_format: str = "itrex", # v0.4 enable_norm_bias_tuning: bool = False, - enable_torch_compile: bool = None, + enable_torch_compile: bool = False, # v0.7 scheme: str | dict = "W4A16", - device_map: str = None, + device_map: [str, int, torch.device, dict] = 0, # mllm quant_nontext_module: bool = False, extra_data_dir: str = None, @@ -982,6 +984,8 @@ def __init__( template=None, truncation: bool = False, quant_lm_head: bool = False, + #v0.8 + enable_adam: bool = False, white_list: Optional[List[OP_NAME_OR_MODULE_TYPE]] = DEFAULT_WHITE_LIST, **kwargs, ): @@ -1037,17 +1041,20 @@ def __init__( Default is DEFAULT_WHITE_LIST. """ super().__init__(white_list=white_list) - self.dtype = dtype + + self.enable_full_range = enable_full_range + self.batch_size = batch_size self.bits = bits - self.use_sym = use_sym self.group_size = group_size - self.act_bits = act_bits - self.act_group_size = act_group_size + self.use_sym = use_sym + self.dtype = dtype + self.act_bits= act_bits + self.act_group_size= act_group_size self.act_sym = act_sym - self.act_dynamic = act_dynamic self.act_dtype = act_dtype - self.enable_full_range = enable_full_range - self.batch_size = batch_size + self.act_dynamic = act_dynamic + self.super_bits = super_bits + self.super_group_size = super_group_size self.amp = amp self.lr_scheduler = lr_scheduler self.enable_quanted_input = enable_quanted_input @@ -1079,6 +1086,9 @@ def __init__( self.scheme = scheme self.device_map = device_map self.quant_lm_head = quant_lm_head + # add kwargs + for k, v in kwargs.items(): + setattr(self, k, v) self._post_init() @classmethod @@ -1095,6 +1105,19 @@ def register_supported_configs(cls) -> List[OperatorConfig]: supported_configs.append(OperatorConfig(config=linear_AUTOROUND_config, operators=operators)) cls.supported_configs = supported_configs + def get_params_dict(self): + """Get a dictionary containing the parameters and their values for the current instance. + + Returns: + A dictionary containing the parameters and their values. + """ + result = dict() + for param, value in self.__dict__.items(): + if param not in ["_global_config", "_local_config", "_white_list", "tokenizer"]: + result[param] = value + return result + + @staticmethod def get_model_info(model: torch.nn.Module) -> List[Tuple[str, Callable]]: """Get information about the model. diff --git a/neural_compressor/torch/quantization/quantize.py b/neural_compressor/torch/quantization/quantize.py index a313220c43e..2cbc190ff5b 100644 --- a/neural_compressor/torch/quantization/quantize.py +++ b/neural_compressor/torch/quantization/quantize.py @@ -25,6 +25,7 @@ HybridGPTQConfig, INT8StaticQuantConfig, SmoothQuantConfig, + AutoRoundConfig, ) from neural_compressor.torch.utils import is_ipex_available, logger from neural_compressor.torch.utils.utility import WHITE_MODULE_LIST, algos_mapping, get_model_info @@ -88,6 +89,12 @@ def preprocess_quant_config(model, quant_config, mode="prepare", example_inputs= scale_sharing=quant_config.scale_sharing, ) model_info = quant_config.get_model_info(model, example_inputs) + elif isinstance(quant_config, AutoRoundConfig): + _tokenizer_backup = getattr(quant_config, "tokenizer", None) + if _tokenizer_backup is not None: + setattr(model, "tokenizer", _tokenizer_backup) + delattr(quant_config, "tokenizer") + model_info = quant_config.get_model_info(model=model) else: model_info = quant_config.get_model_info(model=model) diff --git a/test/3x/torch/quantization/weight_only/test_autoround.py b/test/3x/torch/quantization/weight_only/test_autoround.py index e0afdfbbf26..e55ef2a077a 100644 --- a/test/3x/torch/quantization/weight_only/test_autoround.py +++ b/test/3x/torch/quantization/weight_only/test_autoround.py @@ -86,11 +86,11 @@ def setup_class(self): torchscript=True, ) self.inp = torch.ones([1, 10], dtype=torch.long) - tokenizer = transformers.AutoTokenizer.from_pretrained( + self.tokenizer = transformers.AutoTokenizer.from_pretrained( "hf-internal-testing/tiny-random-GPTJForCausalLM", trust_remote_code=True ) from neural_compressor.torch.algorithms.weight_only.autoround import get_dataloader - self.dataloader = get_dataloader(tokenizer, 32, dataset_name="NeelNanda/pile-10k", seed=42, bs=8, nsamples=10) + self.dataloader = get_dataloader(self.tokenizer, 32, dataset_name="NeelNanda/pile-10k", seed=42, bs=8, nsamples=10) self.label = self.gptj(self.inp)[0] @classmethod @@ -140,7 +140,8 @@ def test_int4_dtype(self): def test_autoround_with_quantize_API(self): gpt_j_model = copy.deepcopy(self.gptj) - quant_config = AutoRoundConfig(nsamples=32, seqlen=10, iters=10, amp=False ,scale_dtype="fp32") + quant_config = AutoRoundConfig(dtype="int", bits=4, act_dtype="int", act_bits=32,nsamples=32, seqlen=10, + iters=10, use_sym=False, group_size=128, amp=False ,scale_dtype="fp32") quant_config.set_local("lm_head", AutoRoundConfig(dtype="fp32")) logger.info(f"Test AutoRound with config {quant_config}") @@ -160,8 +161,9 @@ def test_save_and_load(self): fp32_model = copy.deepcopy(self.gptj) # known issue: scale_dtype="fp32" will cause accuracy gap between quantized model # (using auto-round WeightOnlyLinear) and reloaded model (using INCWeightOnlyLinear) - quant_config = AutoRoundConfig(nsamples=32, seqlen=10, iters=10, amp=False ,scale_dtype="fp16") - # quant_config.set_local("lm_head", AutoRoundConfig(dtype="fp32")) + quant_config = AutoRoundConfig(dtype="int", bits=4, act_dtype="int", act_bits=32,nsamples=32, seqlen=10, + iters=10, use_sym=False, group_size=128, amp=False ,scale_dtype="fp16") + quant_config.set_local("lm_head", AutoRoundConfig(dtype="fp32")) logger.info(f"Test AutoRound with config {quant_config}") # quantizer execute @@ -294,6 +296,7 @@ def test_mllm(self): def test_scheme(self, scheme): fp32_model = copy.deepcopy(self.gptj) quant_config = AutoRoundConfig( + tokenizer=self.tokenizer, nsamples=32, seqlen=10, iters=10, @@ -306,13 +309,26 @@ def test_scheme(self, scheme): # quantizer execute model = prepare(model=fp32_model, quant_config=quant_config) - run_fn(model, self.dataloader) q_model = convert(model) out = q_model(self.inp)[0] - assert q_model is not None, "Quantization failed!" - assert q_model.transformer.h[0].attn.k_proj.bits is 4 - assert torch.allclose(out, self.label, atol=1e-1) - + + from auto_round import AutoRound + fp32_model = copy.deepcopy(self.gptj) + ar = AutoRound( + model=fp32_model, + tokenizer=self.tokenizer, + nsamples=32, + seqlen=10, + iters=10, + amp=False, + scale_dtype="fp16", + scheme=scheme, + export_format="llm_compressor", + ) + ar.quantize() + out_ar = ar.model(self.inp)[0] + assert torch.all(out.eq(out_ar)) + @pytest.mark.skipif(not is_habana_framework_installed(), reason="Habana framework is not installed") @pytest.mark.skipif(os.getenv("PT_HPU_LAZY_MODE", "0") == "1", reason="Lazy mode is enabled") @pytest.mark.skipif(not auto_round_installed, reason="auto_round module is not installed") From fa053542434aa349e92e346b513360f7ce468405 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 25 Sep 2025 15:02:46 +0000 Subject: [PATCH 08/12] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- .../torch/algorithms/weight_only/autoround.py | 7 +++++-- neural_compressor/torch/quantization/algorithm_entry.py | 4 ++-- neural_compressor/torch/quantization/config.py | 9 ++++----- neural_compressor/torch/quantization/quantize.py | 2 +- 4 files changed, 12 insertions(+), 10 deletions(-) diff --git a/neural_compressor/torch/algorithms/weight_only/autoround.py b/neural_compressor/torch/algorithms/weight_only/autoround.py index 03589a2a849..86aa03db86f 100644 --- a/neural_compressor/torch/algorithms/weight_only/autoround.py +++ b/neural_compressor/torch/algorithms/weight_only/autoround.py @@ -43,8 +43,11 @@ def _is_auto_round_available(): from neural_compressor.torch.utils import get_accelerator, logger from .utility import CapturedDataloader, InputCaptureModule + + class AutoRoundQuantizer(Quantizer): """AutoRound Quantizer.""" + def __init__( self, bits: int = None, @@ -176,8 +179,8 @@ def __init__( self.group_size = group_size self.sym = sym self.data_type = data_type - self.act_bits= act_bits - self.act_group_size= act_group_size + self.act_bits = act_bits + self.act_group_size = act_group_size self.act_sym = act_sym self.act_data_type = act_data_type self.act_dynamic = act_dynamic diff --git a/neural_compressor/torch/quantization/algorithm_entry.py b/neural_compressor/torch/quantization/algorithm_entry.py index c1d861b28b5..ef0edd7e3fd 100644 --- a/neural_compressor/torch/quantization/algorithm_entry.py +++ b/neural_compressor/torch/quantization/algorithm_entry.py @@ -583,8 +583,8 @@ def autoround_quantize_entry( group_size = quant_config.group_size sym = quant_config.use_sym data_type = quant_config.dtype - act_bits= quant_config.act_bits - act_group_size= quant_config.act_group_size + act_bits = quant_config.act_bits + act_group_size = quant_config.act_group_size act_sym = quant_config.act_sym act_data_type = quant_config.act_dtype act_dynamic = quant_config.act_dynamic diff --git a/neural_compressor/torch/quantization/config.py b/neural_compressor/torch/quantization/config.py index 349e39d34d6..b4f8f3d4c57 100644 --- a/neural_compressor/torch/quantization/config.py +++ b/neural_compressor/torch/quantization/config.py @@ -984,7 +984,7 @@ def __init__( template=None, truncation: bool = False, quant_lm_head: bool = False, - #v0.8 + # v0.8 enable_adam: bool = False, white_list: Optional[List[OP_NAME_OR_MODULE_TYPE]] = DEFAULT_WHITE_LIST, **kwargs, @@ -1041,15 +1041,15 @@ def __init__( Default is DEFAULT_WHITE_LIST. """ super().__init__(white_list=white_list) - + self.enable_full_range = enable_full_range self.batch_size = batch_size self.bits = bits self.group_size = group_size self.use_sym = use_sym self.dtype = dtype - self.act_bits= act_bits - self.act_group_size= act_group_size + self.act_bits = act_bits + self.act_group_size = act_group_size self.act_sym = act_sym self.act_dtype = act_dtype self.act_dynamic = act_dynamic @@ -1117,7 +1117,6 @@ def get_params_dict(self): result[param] = value return result - @staticmethod def get_model_info(model: torch.nn.Module) -> List[Tuple[str, Callable]]: """Get information about the model. diff --git a/neural_compressor/torch/quantization/quantize.py b/neural_compressor/torch/quantization/quantize.py index 2cbc190ff5b..1e0defbe698 100644 --- a/neural_compressor/torch/quantization/quantize.py +++ b/neural_compressor/torch/quantization/quantize.py @@ -21,11 +21,11 @@ from neural_compressor.common.base_config import BaseConfig, ComposableConfig, config_registry from neural_compressor.common.utils import Mode, call_counter, log_process from neural_compressor.torch.quantization.config import ( + AutoRoundConfig, FP8Config, HybridGPTQConfig, INT8StaticQuantConfig, SmoothQuantConfig, - AutoRoundConfig, ) from neural_compressor.torch.utils import is_ipex_available, logger from neural_compressor.torch.utils.utility import WHITE_MODULE_LIST, algos_mapping, get_model_info From aafbea53b022232a7aa1ce403482a3e16e0ea5f4 Mon Sep 17 00:00:00 2001 From: Kaihui-intel Date: Fri, 26 Sep 2025 16:31:35 +0800 Subject: [PATCH 09/12] export&save&add layer_config, output_dir Signed-off-by: Kaihui-intel --- .../torch/algorithms/weight_only/autoround.py | 13 +++-- .../torch/quantization/algorithm_entry.py | 4 ++ .../weight_only/test_autoround.py | 56 +++++++++++++++---- 3 files changed, 58 insertions(+), 15 deletions(-) diff --git a/neural_compressor/torch/algorithms/weight_only/autoround.py b/neural_compressor/torch/algorithms/weight_only/autoround.py index 03589a2a849..0d0b681556b 100644 --- a/neural_compressor/torch/algorithms/weight_only/autoround.py +++ b/neural_compressor/torch/algorithms/weight_only/autoround.py @@ -58,7 +58,8 @@ def __init__( act_dynamic: bool = None, super_bits: int = None, super_group_size: int = None, - quant_config: dict = {}, + quant_config: dict = {}, # for INC + layer_config: dict[str, Union[str, dict, QuantizationScheme]] = None, enable_full_range: bool = False, ##for symmetric, TODO support later batch_size: int = 8, amp: bool = True, @@ -120,6 +121,7 @@ def __init__( bits (int): Number of bits for quantization (default is 4). group_size (int): Size of the quantization group (default is 128). sym (bool): Whether to use symmetric quantization. (default is None). + layer_config (dict, optional): Layer-wise quantization config. Defaults to None. bits (int): Number of bits for quantization (default is 4). group_size (int): Size of the quantization group (default is 128). sym (bool): Whether symmetric quantization is to be used (default is False). @@ -170,6 +172,8 @@ def __init__( The quantized model. """ super().__init__(quant_config) + self.layer_config = layer_config + self.output_dir = kwargs.pop("output_dir", "temp_auto_round") self.tokenizer = kwargs.pop("tokenizer", "Placeholder") # for AutoRound initialization self.enable_full_range = enable_full_range self.bits = bits @@ -252,6 +256,7 @@ def convert(self, model: torch.nn.Module, *args, **kwargs): model = model.orig_model rounder = AutoRound( model, + layer_config=self.layer_config, bits=self.bits, data_type=self.data_type, group_size=self.group_size, @@ -300,13 +305,11 @@ def convert(self, model: torch.nn.Module, *args, **kwargs): model, weight_config = rounder.quantize() model.autoround_config = weight_config if self.enable_w4afp8: - return rounder.save_quantized(output_dir="temp_auto_round", inplace=True) + return rounder.save_quantized(output_dir=self.output_dir, inplace=True) elif "itrex" in self.export_format: model = pack_model(model, weight_config, device=self.device, inplace=True) else: # pragma: no cover - pass - # model = rounder.save_quantized(output_dir="temp_auto_round", format=self.export_format, inplace=False) - + model = rounder.save_quantized(output_dir=self.output_dir, format=self.export_format, inplace=True) return model diff --git a/neural_compressor/torch/quantization/algorithm_entry.py b/neural_compressor/torch/quantization/algorithm_entry.py index c1d861b28b5..c905e84ec3f 100644 --- a/neural_compressor/torch/quantization/algorithm_entry.py +++ b/neural_compressor/torch/quantization/algorithm_entry.py @@ -607,6 +607,8 @@ def autoround_quantize_entry( "act_dynamic": act_dynamic, "act_data_type": act_data_type, } + layer_config = quant_config.to_dict().get("layer_config", None) + output_dir = quant_config.to_dict().get("output_dir", "temp_auto_round") enable_full_range = quant_config.enable_full_range batch_size = quant_config.batch_size amp = quant_config.amp @@ -656,6 +658,8 @@ def autoround_quantize_entry( act_dynamic=act_dynamic, super_bits=super_bits, super_group_size=super_group_size, + layer_config=layer_config, + output_dir=output_dir, enable_full_range=enable_full_range, batch_size=batch_size, amp=amp, diff --git a/test/3x/torch/quantization/weight_only/test_autoround.py b/test/3x/torch/quantization/weight_only/test_autoround.py index e55ef2a077a..d8c67e4e0e0 100644 --- a/test/3x/torch/quantization/weight_only/test_autoround.py +++ b/test/3x/torch/quantization/weight_only/test_autoround.py @@ -294,39 +294,75 @@ def test_mllm(self): @pytest.mark.skipif(not ct_installed, reason="The compressed-tensors module is not installed.") @pytest.mark.parametrize("scheme", ["MXFP4", "NVFP4"]) def test_scheme(self, scheme): - fp32_model = copy.deepcopy(self.gptj) + # INC API + from transformers import AutoModelForCausalLM, AutoTokenizer + fp32_model = AutoModelForCausalLM.from_pretrained( + "facebook/opt-125m", + torchscript=True, + device_map="auto", + ) + inp = torch.ones([1, 10], dtype=torch.long) + tokenizer = AutoTokenizer.from_pretrained( + "facebook/opt-125m", trust_remote_code=True) + + output_dir = "./saved_inc" quant_config = AutoRoundConfig( - tokenizer=self.tokenizer, + tokenizer=tokenizer, nsamples=32, seqlen=10, iters=10, amp=False, scale_dtype="fp16", scheme=scheme, - export_format="llm_compressor", + export_format="auto_round", + output_dir=output_dir, # default is "temp_auto_round" ) - logger.info(f"Test AutoRound with config {quant_config}") # quantizer execute model = prepare(model=fp32_model, quant_config=quant_config) - q_model = convert(model) - out = q_model(self.inp)[0] + inc_model = convert(model) + inc_model = AutoModelForCausalLM.from_pretrained( + output_dir, + torch_dtype="auto", + device_map="auto", + ) + out = inc_model(inp)[0] + # AutoRound API + from transformers import AutoModelForCausalLM, AutoTokenizer + fp32_model = transformers.AutoModelForCausalLM.from_pretrained( + "facebook/opt-125m", + torchscript=True, + device_map="auto", + ) + inp = torch.ones([1, 10], dtype=torch.long) + tokenizer = transformers.AutoTokenizer.from_pretrained( + "facebook/opt-125m", trust_remote_code=True) from auto_round import AutoRound - fp32_model = copy.deepcopy(self.gptj) ar = AutoRound( model=fp32_model, - tokenizer=self.tokenizer, + tokenizer=tokenizer, nsamples=32, seqlen=10, iters=10, amp=False, scale_dtype="fp16", scheme=scheme, - export_format="llm_compressor", ) + quantized_model_path = "./saved_ar" ar.quantize() - out_ar = ar.model(self.inp)[0] + model = ar.save_quantized(output_dir=quantized_model_path, inplace=True, format="auto_round") + model = AutoModelForCausalLM.from_pretrained( + quantized_model_path, + torch_dtype="auto", + device_map="auto", + ) + tokenizer = AutoTokenizer.from_pretrained(quantized_model_path) + out_ar = model(inp)[0] + assert torch.all(out_ar.eq(out)) + shutil.rmtree(output_dir, ignore_errors=True) + shutil.rmtree(quantized_model_path, ignore_errors=True) + assert torch.all(out.eq(out_ar)) @pytest.mark.skipif(not is_habana_framework_installed(), reason="Habana framework is not installed") From 95f20c5e65d57ba08294dbdb3320bc0780ca9c47 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 26 Sep 2025 09:34:39 +0000 Subject: [PATCH 10/12] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- neural_compressor/torch/algorithms/weight_only/autoround.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/neural_compressor/torch/algorithms/weight_only/autoround.py b/neural_compressor/torch/algorithms/weight_only/autoround.py index 9da3ebaebbb..1c0fe04c65a 100644 --- a/neural_compressor/torch/algorithms/weight_only/autoround.py +++ b/neural_compressor/torch/algorithms/weight_only/autoround.py @@ -61,7 +61,7 @@ def __init__( act_dynamic: bool = None, super_bits: int = None, super_group_size: int = None, - quant_config: dict = {}, # for INC + quant_config: dict = {}, # for INC layer_config: dict[str, Union[str, dict, QuantizationScheme]] = None, enable_full_range: bool = False, ##for symmetric, TODO support later batch_size: int = 8, @@ -175,7 +175,7 @@ def __init__( The quantized model. """ super().__init__(quant_config) - self.layer_config = layer_config + self.layer_config = layer_config self.output_dir = kwargs.pop("output_dir", "temp_auto_round") self.tokenizer = kwargs.pop("tokenizer", "Placeholder") # for AutoRound initialization self.enable_full_range = enable_full_range From e1cd0da232cdf4c1c2a7245df3b4180d9a43a539 Mon Sep 17 00:00:00 2001 From: Kaihui-intel Date: Fri, 26 Sep 2025 22:53:04 +0800 Subject: [PATCH 11/12] fix docstring Signed-off-by: Kaihui-intel --- neural_compressor/torch/algorithms/weight_only/autoround.py | 1 - 1 file changed, 1 deletion(-) diff --git a/neural_compressor/torch/algorithms/weight_only/autoround.py b/neural_compressor/torch/algorithms/weight_only/autoround.py index 1c0fe04c65a..2286b69fdc1 100644 --- a/neural_compressor/torch/algorithms/weight_only/autoround.py +++ b/neural_compressor/torch/algorithms/weight_only/autoround.py @@ -249,7 +249,6 @@ def convert(self, model: torch.nn.Module, *args, **kwargs): Returns: The quantized model. """ - tokenizer = getattr(model.orig_model, "tokenizer", None) if tokenizer is not None: delattr(model.orig_model, "tokenizer") From 1f7152024b76924a470c2d528d3cf3a903d298a7 Mon Sep 17 00:00:00 2001 From: chensuyue Date: Mon, 29 Sep 2025 15:04:43 +0800 Subject: [PATCH 12/12] run ci with auto-round v0.8.0rc Signed-off-by: chensuyue --- test/3x/torch/requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/3x/torch/requirements.txt b/test/3x/torch/requirements.txt index 592bddb963a..cd58c39fb38 100644 --- a/test/3x/torch/requirements.txt +++ b/test/3x/torch/requirements.txt @@ -1,4 +1,4 @@ -auto_round +auto_round @ git+https://github.com/intel/auto-round.git@v0.8.0rc compressed-tensors datasets deepspeed @ git+https://github.com/HabanaAI/DeepSpeed.git@1.22.0