diff --git a/examples/3.x_api/pytorch/multimodal-modeling/quantization/auto_round/mllm.py b/examples/3.x_api/pytorch/multimodal-modeling/quantization/auto_round/mllm.py
index 84f4976bb31..bad8eebdfd0 100644
--- a/examples/3.x_api/pytorch/multimodal-modeling/quantization/auto_round/mllm.py
+++ b/examples/3.x_api/pytorch/multimodal-modeling/quantization/auto_round/mllm.py
@@ -223,7 +223,6 @@ def tune(args):
             use_auto_mapping = True
 
         woq_config = AutoRoundConfig(
-            is_vlm=True,
             bits=args.bits,
             sym=not args.asym,
             group_size=args.group_size,
diff --git a/neural_compressor/torch/algorithms/weight_only/autoround.py b/neural_compressor/torch/algorithms/weight_only/autoround.py
index b3c6b292831..2286b69fdc1 100644
--- a/neural_compressor/torch/algorithms/weight_only/autoround.py
+++ b/neural_compressor/torch/algorithms/weight_only/autoround.py
@@ -34,9 +34,9 @@ def _is_auto_round_available():
 _is_auto_round_available()
 
 from auto_round import AutoRound, AutoRoundMLLM  # pylint: disable=E0401
+from auto_round.compressors.mllm.eval import lmms_eval, mllm_eval
+from auto_round.compressors.mllm.template import Template, get_template
 from auto_round.export.export_to_itrex.export import pack_model  # pylint: disable=E0401
-from auto_round.mllm import lmms_eval, mllm_eval
-from auto_round.mllm.template import Template, get_template
 from auto_round.schemes import QuantizationScheme
 
 from neural_compressor.torch.algorithms import Quantizer
@@ -50,7 +50,19 @@ class AutoRoundQuantizer(Quantizer):
 
     def __init__(
         self,
-        quant_config: dict = {},
+        bits: int = None,
+        group_size: int = None,
+        sym: bool = None,
+        data_type: str = None,
+        act_bits: int = None,
+        act_group_size: int = None,
+        act_sym: bool = None,
+        act_data_type: str = None,
+        act_dynamic: bool = None,
+        super_bits: int = None,
+        super_group_size: int = None,
+        quant_config: dict = {},  # for INC
+        layer_config: dict[str, Union[str, dict, QuantizationScheme]] = None,
         enable_full_range: bool = False,  ##for symmetric, TODO support later
         batch_size: int = 8,
         amp: bool = True,
@@ -71,21 +83,14 @@
         gradient_accumulate_steps: int = 1,
         not_use_best_mse: bool = False,
         dynamic_max_gap: int = -1,
-        data_type: str = "int",
         scale_dtype: str = "fp16",
         to_quant_block_names: list = None,
-        act_bits: int = 32,
-        act_group_size: int = None,
-        act_sym: bool = None,
-        act_dynamic: bool = True,
-        act_data_type: Optional[str] = None,
         low_cpu_mem_usage: bool = False,
         export_format: str = "itrex",
         # v0.4
         enable_norm_bias_tuning: bool = False,
         enable_torch_compile: bool = None,
         # mllm
-        is_mllm: bool = False,
         quant_nontext_module: bool = False,
         extra_data_dir: str = None,
         image_processor=None,
@@ -119,6 +124,7 @@
             bits (int): Number of bits for quantization (default is 4).
             group_size (int): Size of the quantization group (default is 128).
             sym (bool): Whether to use symmetric quantization. (default is None).
+            layer_config (dict, optional): Layer-wise quantization config. Defaults to None.
             bits (int): Number of bits for quantization (default is 4).
             group_size (int): Size of the quantization group (default is 128).
             sym (bool): Whether symmetric quantization is to be used (default is False).
@@ -155,7 +161,6 @@
             enable_norm_bias_tuning (bool): Whether to enable fast norm/layer_bias tuning.
             enable_torch_compile (bool): Whether to enable torch compile to optimize quant_block/layer, torch>=2.6 True.
             quant_nontext_module (bool): Whether to quantize nontext module.
-            is_mllm (bool): Indicates whether the model to be quantized is a multi-modal model (MLLM).
             extra_data_dir (str): The path for extra data such as images, audio or videos.
             processor (transformers.AutoProcessor): Any multi-modal model will require an object to encode or
                 decode the data that groups several modalities (among text, vision and audio).
@@ -170,12 +175,26 @@ def __init__(
             The quantized model.
         """
         super().__init__(quant_config)
-        self.tokenizer = "Placeholder"  # for AutoRound initialization
+        self.layer_config = layer_config
+        self.output_dir = kwargs.pop("output_dir", "temp_auto_round")
+        self.tokenizer = kwargs.pop("tokenizer", "Placeholder")  # for AutoRound initialization
         self.enable_full_range = enable_full_range
+        self.bits = bits
+        self.group_size = group_size
+        self.sym = sym
+        self.data_type = data_type
+        self.act_bits = act_bits
+        self.act_group_size = act_group_size
+        self.act_sym = act_sym
+        self.act_data_type = act_data_type
+        self.act_dynamic = act_dynamic
+        self.super_bits = super_bits
+        self.super_group_size = super_group_size
         self.batch_size = batch_size
         self.amp = amp
         self.device = get_accelerator(kwargs.pop("device", "auto")).name()
         self.lr_scheduler = lr_scheduler
+        self.dataset = dataset
         self.enable_quanted_input = enable_quanted_input
         self.enable_minmax_tuning = enable_minmax_tuning
         self.lr = lr
@@ -190,19 +209,12 @@ def __init__(
         self.gradient_accumulate_steps = gradient_accumulate_steps
         self.not_use_best_mse = not_use_best_mse
         self.dynamic_max_gap = dynamic_max_gap
-        self.data_type = data_type
         self.scale_dtype = scale_dtype
         self.to_quant_block_names = to_quant_block_names
-        self.act_bits = act_bits
-        self.act_group_size = act_group_size
-        self.act_sym = act_sym
-        self.act_dynamic = act_dynamic
-        self.act_data_type = act_data_type
         self.low_cpu_mem_usage = low_cpu_mem_usage
         self.export_format = export_format
         self.enable_norm_bias_tuning = enable_norm_bias_tuning
         self.enable_torch_compile = enable_torch_compile
-        self.is_mllm = is_mllm
         self.quant_nontext_module = quant_nontext_module
         self.extra_data_dir = extra_data_dir
         self.processor = processor
@@ -213,7 +225,7 @@ def __init__(
         self.device_map = device_map
         self.enable_w4afp8 = self._is_w4afp8()
 
-    def _is_w4afp8(self):
+    def _is_w4afp8(self) -> bool:
         return any([v.get("data_type", None) == "fp8_to_int_sym" for v in self.quant_config.values()])
 
     def prepare(self, model: torch.nn.Module, *args, **kwargs):
@@ -237,96 +249,69 @@ def convert(self, model: torch.nn.Module, *args, **kwargs):
 
         Returns:
             The quantized model.
""" - dataloader = CapturedDataloader(model.args_list, model.kwargs_list) - model = model.orig_model - if self.is_mllm: - rounder = AutoRoundMLLM( - model, - tokenizer=self.tokenizer, - scheme=self.scheme, - processor=self.processor, - image_processor=self.image_processor, - layer_config=self.quant_config, - batch_size=self.batch_size, - amp=self.amp, - device_map=self.device_map, - lr_scheduler=self.lr_scheduler, - dataset=dataloader, - extra_data_dir=self.extra_data_dir, - template=self.template, - quant_nontext_module=self.quant_nontext_module, - enable_quanted_input=self.enable_quanted_input, - enable_minmax_tuning=self.enable_minmax_tuning, - lr=self.lr, - minmax_lr=self.minmax_lr, - low_gpu_mem_usage=self.low_gpu_mem_usage, - low_cpu_mem_usage=self.low_gpu_mem_usage, - iters=self.iters, - seqlen=self.seqlen, - nsamples=self.nsamples, - sampler=self.sampler, - seed=self.seed, - nblocks=self.nblocks, - gradient_accumulate_steps=self.gradient_accumulate_steps, - not_use_best_mse=self.not_use_best_mse, - dynamic_max_gap=self.dynamic_max_gap, - data_type=self.data_type, - scale_dtype=self.scale_dtype, - act_bits=self.act_bits, - act_group_size=self.act_group_size, - act_sym=self.act_sym, - act_dynamic=self.act_dynamic, - to_quant_block_names=self.to_quant_block_names, - enable_norm_bias_tuning=self.enable_norm_bias_tuning, - truncation=self.truncation, - enable_torch_compile=self.enable_torch_compile, - ) + tokenizer = getattr(model.orig_model, "tokenizer", None) + if tokenizer is not None: + delattr(model.orig_model, "tokenizer") else: - rounder = AutoRound( - model=model, - tokenizer=self.tokenizer, - scheme=self.scheme, - dataset=dataloader, - layer_config=self.quant_config or {}, - enable_full_range=self.enable_full_range, - batch_size=self.batch_size, - amp=self.amp, - device_map=self.device_map, - lr_scheduler=self.lr_scheduler, - enable_quanted_input=self.enable_quanted_input, - enable_minmax_tuning=self.enable_minmax_tuning, - lr=self.lr, - minmax_lr=self.minmax_lr, - low_gpu_mem_usage=self.low_gpu_mem_usage, - iters=self.iters, - seqlen=self.seqlen, - nsamples=self.nsamples, - sampler=self.sampler, - seed=self.seed, - nblocks=self.nblocks, - gradient_accumulate_steps=self.gradient_accumulate_steps, - not_use_best_mse=self.not_use_best_mse, - dynamic_max_gap=self.dynamic_max_gap, - data_type=self.data_type, - scale_dtype=self.scale_dtype, - to_quant_block_names=self.to_quant_block_names, - act_bits=self.act_bits, - act_group_size=self.act_group_size, - act_sym=self.act_sym, - act_dynamic=self.act_dynamic, - low_cpu_mem_usage=self.low_cpu_mem_usage, - enable_norm_bias_tuning=self.enable_norm_bias_tuning, - enable_torch_compile=self.enable_torch_compile, - ) + tokenizer = "Placeholder" + self.dataset = CapturedDataloader(model.args_list, model.kwargs_list) + model = model.orig_model + rounder = AutoRound( + model, + layer_config=self.layer_config, + bits=self.bits, + data_type=self.data_type, + group_size=self.group_size, + sym=self.sym, + act_bits=self.act_bits, + act_group_size=self.act_group_size, + act_sym=self.act_sym, + act_data_type=self.act_data_type, + act_dynamic=self.act_dynamic, + super_bits=self.super_bits, + super_group_size=self.super_group_size, + tokenizer=tokenizer, + scheme=self.scheme, + processor=self.processor, + image_processor=self.image_processor, + enable_full_range=self.enable_full_range, + batch_size=self.batch_size, + amp=self.amp, + device_map=self.device_map, + lr_scheduler=self.lr_scheduler, + dataset=self.dataset, + 
+            extra_data_dir=self.extra_data_dir,
+            template=self.template,
+            quant_nontext_module=self.quant_nontext_module,
+            enable_quanted_input=self.enable_quanted_input,
+            enable_minmax_tuning=self.enable_minmax_tuning,
+            lr=self.lr,
+            minmax_lr=self.minmax_lr,
+            low_gpu_mem_usage=self.low_gpu_mem_usage,
+            low_cpu_mem_usage=self.low_gpu_mem_usage,
+            iters=self.iters,
+            seqlen=self.seqlen,
+            nsamples=self.nsamples,
+            sampler=self.sampler,
+            seed=self.seed,
+            nblocks=self.nblocks,
+            gradient_accumulate_steps=self.gradient_accumulate_steps,
+            not_use_best_mse=self.not_use_best_mse,
+            dynamic_max_gap=self.dynamic_max_gap,
+            scale_dtype=self.scale_dtype,
+            to_quant_block_names=self.to_quant_block_names,
+            enable_norm_bias_tuning=self.enable_norm_bias_tuning,
+            truncation=self.truncation,
+            enable_torch_compile=self.enable_torch_compile,
+        )
         model, weight_config = rounder.quantize()
         model.autoround_config = weight_config
         if self.enable_w4afp8:
-            return rounder.save_quantized(output_dir="temp_auto_round", inplace=True)
+            return rounder.save_quantized(output_dir=self.output_dir, inplace=True)
         elif "itrex" in self.export_format:
             model = pack_model(model, weight_config, device=self.device, inplace=True)
         else:  # pragma: no cover
-            model = rounder.save_quantized(output_dir="temp_auto_round", format=self.export_format, inplace=True)
-
+            model = rounder.save_quantized(output_dir=self.output_dir, format=self.export_format, inplace=True)
         return model
@@ -389,8 +374,8 @@ def get_mllm_dataloader(
         DataLoader: The DataLoader for the calibrated datasets.
     """
     from auto_round.calib_dataset import CALIB_DATASETS
-    from auto_round.mllm.autoround_mllm import _only_text_test
-    from auto_round.mllm.mllm_dataset import get_mllm_dataloader  # pylint: disable=E0401
+    from auto_round.compressors.mllm.compressor import _only_text_test
+    from auto_round.compressors.mllm.dataset import get_mllm_dataloader  # pylint: disable=E0401
 
     template = template if template is not None else model.config.model_type
     template = get_template(
diff --git a/neural_compressor/torch/quantization/algorithm_entry.py b/neural_compressor/torch/quantization/algorithm_entry.py
index 5bb11fb69bc..e7ff4d02d85 100644
--- a/neural_compressor/torch/quantization/algorithm_entry.py
+++ b/neural_compressor/torch/quantization/algorithm_entry.py
@@ -579,25 +579,36 @@ def autoround_quantize_entry(
         if quant_config.name != AUTOROUND or quant_config.dtype == "fp32":
             continue
         else:
-            dtype = quant_config.dtype
             bits = quant_config.bits
-            if dtype != "int" and "int" in dtype:
-                if dtype == "fp8_to_int_sym":
+            group_size = quant_config.group_size
+            sym = quant_config.use_sym
+            data_type = quant_config.dtype
+            act_bits = quant_config.act_bits
+            act_group_size = quant_config.act_group_size
+            act_sym = quant_config.act_sym
+            act_data_type = quant_config.act_dtype
+            act_dynamic = quant_config.act_dynamic
+            super_bits = quant_config.super_bits
+            super_group_size = quant_config.super_group_size
+            if data_type is not None and data_type != "int" and "int" in data_type:
+                if data_type == "fp8_to_int_sym":
                     bits = 4
                 else:
-                    bits = int(dtype.lstrip("int"))
-                dtype = "int"
+                    bits = int(data_type.lstrip("int"))
+                data_type = "int"
             weight_config[op_name] = {
-                "data_type": dtype,
+                "data_type": data_type,
                 "bits": bits,
-                "sym": quant_config.use_sym,
-                "group_size": quant_config.group_size,
-                "act_bits": quant_config.act_bits,
-                "act_group_size": quant_config.act_group_size,
-                "act_sym": quant_config.act_sym,
-                "act_dynamic": quant_config.act_dynamic,
-                "act_data_type": quant_config.act_dtype,
+ "sym": sym, + "group_size": group_size, + "act_bits": act_bits, + "act_group_size": act_group_size, + "act_sym": act_sym, + "act_dynamic": act_dynamic, + "act_data_type": act_data_type, } + layer_config = quant_config.to_dict().get("layer_config", None) + output_dir = quant_config.to_dict().get("output_dir", "temp_auto_round") enable_full_range = quant_config.enable_full_range batch_size = quant_config.batch_size amp = quant_config.amp @@ -622,7 +633,6 @@ def autoround_quantize_entry( export_format = quant_config.export_format enable_norm_bias_tuning = quant_config.enable_norm_bias_tuning enable_torch_compile = quant_config.enable_torch_compile - is_mllm = quant_config.is_mllm quant_nontext_module = quant_config.quant_nontext_module extra_data_dir = quant_config.extra_data_dir processor = quant_config.processor @@ -637,6 +647,19 @@ def autoround_quantize_entry( model, quantizer_cls=AutoRoundQuantizer, quant_config=weight_config, + bits=bits, + data_type=data_type, + group_size=group_size, + sym=sym, + act_bits=act_bits, + act_group_size=act_group_size, + act_sym=act_sym, + act_data_type=act_data_type, + act_dynamic=act_dynamic, + super_bits=super_bits, + super_group_size=super_group_size, + layer_config=layer_config, + output_dir=output_dir, enable_full_range=enable_full_range, batch_size=batch_size, amp=amp, @@ -661,7 +684,6 @@ def autoround_quantize_entry( export_format=export_format, enable_norm_bias_tuning=enable_norm_bias_tuning, enable_torch_compile=enable_torch_compile, - is_mllm=is_mllm, quant_nontext_module=quant_nontext_module, extra_data_dir=extra_data_dir, processor=processor, diff --git a/neural_compressor/torch/quantization/config.py b/neural_compressor/torch/quantization/config.py index 27e5a85551e..b4f8f3d4c57 100644 --- a/neural_compressor/torch/quantization/config.py +++ b/neural_compressor/torch/quantization/config.py @@ -936,16 +936,18 @@ class AutoRoundConfig(TorchBaseConfig): def __init__( self, - dtype: str = "int", - bits: int = 4, - use_sym: bool = False, - group_size: int = 128, - # AUTOROUND - act_bits: int = 32, + bits: int = None, + group_size: int = None, + use_sym: bool = None, + dtype: str = None, + act_bits: int = None, act_group_size: int = None, act_sym: bool = None, - act_dynamic: bool = True, - act_dtype: Optional[str] = "int", + act_dtype: str = None, + act_dynamic: bool = None, + super_bits: int = None, + super_group_size: int = None, + # AUTOROUND enable_full_range: bool = False, batch_size: int = 8, amp: bool = True, @@ -970,18 +972,20 @@ def __init__( export_format: str = "itrex", # v0.4 enable_norm_bias_tuning: bool = False, - enable_torch_compile: bool = None, + enable_torch_compile: bool = False, # v0.7 scheme: str | dict = "W4A16", - device_map: str = None, + device_map: [str, int, torch.device, dict] = 0, # mllm - is_mllm: bool = False, quant_nontext_module: bool = False, extra_data_dir: str = None, processor=None, image_processor=None, template=None, truncation: bool = False, + quant_lm_head: bool = False, + # v0.8 + enable_adam: bool = False, white_list: Optional[List[OP_NAME_OR_MODULE_TYPE]] = DEFAULT_WHITE_LIST, **kwargs, ): @@ -1024,7 +1028,6 @@ def __init__( enable_torch_compile (bool): Whether to enable torch compile to optimize quant_block/layer, torch>=2.6 True. quant_nontext_module (bool): Whether to quantize nontext module. extra_data_dir (str): The path for extra data such as images, audio or videos. - is_mllm (bool): Indicates whether the model to be quantized is a multi-modal model (MLLM). 
             processor (transformers.AutoProcessor): Any multi-modal model will require an object to encode or
                 decode the data that groups several modalities (among text, vision and audio).
                 This is handled by objects called processors, which group together two or more processing objects such
@@ -1038,17 +1041,20 @@ def __init__(
                 Default is DEFAULT_WHITE_LIST.
         """
         super().__init__(white_list=white_list)
-        self.dtype = dtype
+
+        self.enable_full_range = enable_full_range
+        self.batch_size = batch_size
         self.bits = bits
-        self.use_sym = use_sym
         self.group_size = group_size
+        self.use_sym = use_sym
+        self.dtype = dtype
         self.act_bits = act_bits
         self.act_group_size = act_group_size
         self.act_sym = act_sym
-        self.act_dynamic = act_dynamic
         self.act_dtype = act_dtype
-        self.enable_full_range = enable_full_range
-        self.batch_size = batch_size
+        self.act_dynamic = act_dynamic
+        self.super_bits = super_bits
+        self.super_group_size = super_group_size
         self.amp = amp
         self.lr_scheduler = lr_scheduler
         self.enable_quanted_input = enable_quanted_input
@@ -1071,7 +1077,6 @@ def __init__(
         self.export_format = export_format
         self.enable_norm_bias_tuning = enable_norm_bias_tuning
         self.enable_torch_compile = enable_torch_compile
-        self.is_mllm = is_mllm
         self.quant_nontext_module = quant_nontext_module
         self.extra_data_dir = extra_data_dir
         self.processor = processor
@@ -1080,6 +1085,10 @@ def __init__(
         self.truncation = truncation
         self.scheme = scheme
         self.device_map = device_map
+        self.quant_lm_head = quant_lm_head
+        # add kwargs
+        for k, v in kwargs.items():
+            setattr(self, k, v)
         self._post_init()
 
     @classmethod
@@ -1096,6 +1105,18 @@ def register_supported_configs(cls) -> List[OperatorConfig]:
             supported_configs.append(OperatorConfig(config=linear_AUTOROUND_config, operators=operators))
         cls.supported_configs = supported_configs
 
+    def get_params_dict(self):
+        """Get a dictionary containing the parameters and their values for the current instance.
+
+        Returns:
+            A dictionary containing the parameters and their values.
+        """
+        result = dict()
+        for param, value in self.__dict__.items():
+            if param not in ["_global_config", "_local_config", "_white_list", "tokenizer"]:
+                result[param] = value
+        return result
+
     @staticmethod
     def get_model_info(model: torch.nn.Module) -> List[Tuple[str, Callable]]:
         """Get information about the model.
@@ -1114,6 +1135,24 @@ def get_model_info(model: torch.nn.Module) -> List[Tuple[str, Callable]]:
         logger.debug(f"Get model info: {filter_result}")
         return filter_result
 
+    def to_config_mapping(
+        self, config_list: List[BaseConfig] = None, model_info: List[Tuple[str, str]] = None
+    ) -> OrderedDictType[Union[str, str], OrderedDictType[str, BaseConfig]]:
+        """Convert the configuration to a mapping.
+
+        Args:
+            config_list (List[BaseConfig]): List of base configurations. Default is None.
+            model_info (List[Tuple[str, str]]): List of tuples containing the name and type of each module in the model.
+                Default is None.
+
+        Returns:
+            OrderedDictType[Union[str, str], OrderedDictType[str, BaseConfig]]: The configuration mapping.
+        """
+        if not self.quant_lm_head:
+            self.set_local(LM_HEAD_NAMES, AutoRoundConfig(dtype="fp32"))
+        config_mapping = super().to_config_mapping(config_list, model_info)
+        return config_mapping
+
     @classmethod
     def get_config_set_for_tuning(cls) -> Union[None, "AutoRoundConfig", List["AutoRoundConfig"]]:
         """Get the default configuration set for tuning.
diff --git a/neural_compressor/torch/quantization/quantize.py b/neural_compressor/torch/quantization/quantize.py
index a313220c43e..1e0defbe698 100644
--- a/neural_compressor/torch/quantization/quantize.py
+++ b/neural_compressor/torch/quantization/quantize.py
@@ -21,6 +21,7 @@
 from neural_compressor.common.base_config import BaseConfig, ComposableConfig, config_registry
 from neural_compressor.common.utils import Mode, call_counter, log_process
 from neural_compressor.torch.quantization.config import (
+    AutoRoundConfig,
     FP8Config,
     HybridGPTQConfig,
     INT8StaticQuantConfig,
@@ -88,6 +89,12 @@ def preprocess_quant_config(model, quant_config, mode="prepare", example_inputs=
             scale_sharing=quant_config.scale_sharing,
         )
         model_info = quant_config.get_model_info(model, example_inputs)
+    elif isinstance(quant_config, AutoRoundConfig):
+        _tokenizer_backup = getattr(quant_config, "tokenizer", None)
+        if _tokenizer_backup is not None:
+            setattr(model, "tokenizer", _tokenizer_backup)
+            delattr(quant_config, "tokenizer")
+        model_info = quant_config.get_model_info(model=model)
     else:
         model_info = quant_config.get_model_info(model=model)
 
diff --git a/neural_compressor/transformers/quantization/utils.py b/neural_compressor/transformers/quantization/utils.py
index ced7ba174cd..f825970c485 100644
--- a/neural_compressor/transformers/quantization/utils.py
+++ b/neural_compressor/transformers/quantization/utils.py
@@ -34,7 +34,7 @@
     convert,
     prepare,
 )
-from neural_compressor.torch.utils import is_ipex_available, is_package_available
+from neural_compressor.torch.utils import get_accelerator, is_ipex_available, is_package_available
 
 if is_ipex_available():
     import intel_extension_for_pytorch as ipex
@@ -480,7 +480,9 @@ def convert_to_quantized_model(model, config, device="cpu", for_inference=True):
             run_fn(model, *run_args)
         model = convert(model)
     elif config.quant_method.value == "autoround":
-        if config.is_vlm is True:
+        from auto_round.utils import is_mllm_model
+
+        if is_mllm_model(model):
             from transformers import AutoProcessor, AutoTokenizer
 
             from neural_compressor.torch.algorithms.weight_only.autoround import (
@@ -528,11 +530,14 @@ def convert_to_quantized_model(model, config, device="cpu", for_inference=True):
                 bs=config.batch_size,
                 nsamples=config.n_samples,
             )
+
+        device_map = get_accelerator().current_device_name()
         quant_config = AutoRoundConfig(
             dtype=dtype,
             bits=config.bits,
             use_sym=config.sym,
             group_size=config.group_size,
+            device_map=device_map,
             enable_quanted_input=not config.disable_quanted_input,
             lr=config.lr,
             minmax_lr=config.minmax_lr,
@@ -543,7 +548,6 @@ def convert_to_quantized_model(model, config, device="cpu", for_inference=True):
             scale_dtype=config.scale_dtype,
             use_layer_wise=config.use_layer_wise,
             # vlm arguments
-            is_mllm=config.is_vlm,
             quant_nontext_module=config.quant_nontext_module,
             truncation=config.truncation,
             gradient_accumulate_steps=config.gradient_accumulate_steps,
@@ -551,7 +555,7 @@ def convert_to_quantized_model(model, config, device="cpu", for_inference=True):
         )
 
         # vlm set non-text module config
-        if config.is_vlm is True:
+        if is_mllm_model(model):
             from neural_compressor.torch.utils.utility import (
                 find_matching_blocks,
                 get_layer_names_in_block,
diff --git a/neural_compressor/transformers/utils/quantization_config.py b/neural_compressor/transformers/utils/quantization_config.py
index 94c86a89de2..98bb8b20f0e 100644
--- a/neural_compressor/transformers/utils/quantization_config.py
+++ b/neural_compressor/transformers/utils/quantization_config.py
@@ -546,7 +546,6 @@ def __init__(
         use_layer_wise: bool = None,
         quant_lm_head: bool = False,
         # vlm arguments
-        is_vlm: bool = False,
         quant_nontext_module: bool = False,
         truncation: bool = False,
         gradient_accumulate_steps: int = 1,
@@ -603,7 +602,6 @@ def __init__(
         self.model_path = kwargs.get("model_path", "")
 
         # vlm arguments
-        self.is_vlm = is_vlm
         self.quant_nontext_module = quant_nontext_module
         self.truncation = truncation
         self.gradient_accumulate_steps = gradient_accumulate_steps
diff --git a/test/3x/torch/quantization/weight_only/test_autoround.py b/test/3x/torch/quantization/weight_only/test_autoround.py
index b919e56cf4d..d8c67e4e0e0 100644
--- a/test/3x/torch/quantization/weight_only/test_autoround.py
+++ b/test/3x/torch/quantization/weight_only/test_autoround.py
@@ -86,11 +86,11 @@ def setup_class(self):
             torchscript=True,
         )
         self.inp = torch.ones([1, 10], dtype=torch.long)
-        tokenizer = transformers.AutoTokenizer.from_pretrained(
+        self.tokenizer = transformers.AutoTokenizer.from_pretrained(
            "hf-internal-testing/tiny-random-GPTJForCausalLM", trust_remote_code=True
         )
         from neural_compressor.torch.algorithms.weight_only.autoround import get_dataloader
-        self.dataloader = get_dataloader(tokenizer, 32, dataset_name="NeelNanda/pile-10k", seed=42, bs=8, nsamples=10)
+        self.dataloader = get_dataloader(self.tokenizer, 32, dataset_name="NeelNanda/pile-10k", seed=42, bs=8, nsamples=10)
         self.label = self.gptj(self.inp)[0]
 
     @classmethod
@@ -103,9 +103,7 @@ def setup_method(self, method):
     @pytest.mark.parametrize("quant_lm_head", [True, False])
     def test_autoround(self, quant_lm_head):
         fp32_model = copy.deepcopy(self.gptj)
-        quant_config = AutoRoundConfig(nsamples=32, seqlen=10, iters=10, amp=False ,scale_dtype="fp32")
-        if quant_lm_head is False:
-            quant_config.set_local("lm_head", AutoRoundConfig(dtype="fp32"))
+        quant_config = AutoRoundConfig(nsamples=32, seqlen=10, iters=10, amp=False ,scale_dtype="fp32", quant_lm_head=quant_lm_head)
         logger.info(f"Test AutoRound with config {quant_config}")
 
         # prepare + convert API
@@ -142,7 +140,8 @@ def test_int4_dtype(self):
 
     def test_autoround_with_quantize_API(self):
         gpt_j_model = copy.deepcopy(self.gptj)
-        quant_config = AutoRoundConfig(nsamples=32, seqlen=10, iters=10, amp=False ,scale_dtype="fp32")
+        quant_config = AutoRoundConfig(dtype="int", bits=4, act_dtype="int", act_bits=32,nsamples=32, seqlen=10,
+                                       iters=10, use_sym=False, group_size=128, amp=False ,scale_dtype="fp32")
         quant_config.set_local("lm_head", AutoRoundConfig(dtype="fp32"))
 
         logger.info(f"Test AutoRound with config {quant_config}")
@@ -162,8 +161,9 @@ def test_save_and_load(self):
         fp32_model = copy.deepcopy(self.gptj)
         # known issue: scale_dtype="fp32" will cause accuracy gap between quantized model
         # (using auto-round WeightOnlyLinear) and reloaded model (using INCWeightOnlyLinear)
-        quant_config = AutoRoundConfig(nsamples=32, seqlen=10, iters=10, amp=False ,scale_dtype="fp16")
-        # quant_config.set_local("lm_head", AutoRoundConfig(dtype="fp32"))
+        quant_config = AutoRoundConfig(dtype="int", bits=4, act_dtype="int", act_bits=32,nsamples=32, seqlen=10,
+                                       iters=10, use_sym=False, group_size=128, amp=False ,scale_dtype="fp16")
+        quant_config.set_local("lm_head", AutoRoundConfig(dtype="fp32"))
         logger.info(f"Test AutoRound with config {quant_config}")
 
         # quantizer execute
@@ -260,7 +260,6 @@ def test_mllm(self):
         quant_config = AutoRoundConfig(
             bits=4,
             group_size=128,
-            is_mllm=True,
             nsamples=1,
             batch_size=batch_size,
             iters=1,
@@ -295,27 +294,77 @@ def test_mllm(self):
     @pytest.mark.skipif(not ct_installed, reason="The compressed-tensors module is not installed.")
     @pytest.mark.parametrize("scheme", ["MXFP4", "NVFP4"])
     def test_scheme(self, scheme):
-        fp32_model = copy.deepcopy(self.gptj)
+        # INC API
+        from transformers import AutoModelForCausalLM, AutoTokenizer
+        fp32_model = AutoModelForCausalLM.from_pretrained(
+            "facebook/opt-125m",
+            torchscript=True,
+            device_map="auto",
+        )
+        inp = torch.ones([1, 10], dtype=torch.long)
+        tokenizer = AutoTokenizer.from_pretrained(
+            "facebook/opt-125m", trust_remote_code=True)
+
+        output_dir = "./saved_inc"
         quant_config = AutoRoundConfig(
+            tokenizer=tokenizer,
             nsamples=32,
             seqlen=10,
             iters=10,
             amp=False,
             scale_dtype="fp16",
             scheme=scheme,
-            export_format="llm_compressor",
+            export_format="auto_round",
+            output_dir=output_dir,  # default is "temp_auto_round"
         )
-        logger.info(f"Test AutoRound with config {quant_config}")
 
         # quantizer execute
         model = prepare(model=fp32_model, quant_config=quant_config)
-        run_fn(model, self.dataloader)
-        q_model = convert(model)
-        out = q_model(self.inp)[0]
-        assert q_model is not None, "Quantization failed!"
-        assert q_model.transformer.h[0].attn.k_proj.bits is 4
-        assert torch.allclose(out, self.label, atol=1e-1)
+        inc_model = convert(model)
+        inc_model = AutoModelForCausalLM.from_pretrained(
+            output_dir,
+            torch_dtype="auto",
+            device_map="auto",
+        )
+        out = inc_model(inp)[0]
+
+        # AutoRound API
+        from transformers import AutoModelForCausalLM, AutoTokenizer
+        fp32_model = transformers.AutoModelForCausalLM.from_pretrained(
+            "facebook/opt-125m",
+            torchscript=True,
+            device_map="auto",
+        )
+        inp = torch.ones([1, 10], dtype=torch.long)
+        tokenizer = transformers.AutoTokenizer.from_pretrained(
+            "facebook/opt-125m", trust_remote_code=True)
+        from auto_round import AutoRound
+        ar = AutoRound(
+            model=fp32_model,
+            tokenizer=tokenizer,
+            nsamples=32,
+            seqlen=10,
+            iters=10,
+            amp=False,
+            scale_dtype="fp16",
+            scheme=scheme,
+        )
+        quantized_model_path = "./saved_ar"
+        ar.quantize()
+        model = ar.save_quantized(output_dir=quantized_model_path, inplace=True, format="auto_round")
+        model = AutoModelForCausalLM.from_pretrained(
+            quantized_model_path,
+            torch_dtype="auto",
+            device_map="auto",
+        )
+        tokenizer = AutoTokenizer.from_pretrained(quantized_model_path)
+        out_ar = model(inp)[0]
+        assert torch.all(out_ar.eq(out))
+        shutil.rmtree(output_dir, ignore_errors=True)
+        shutil.rmtree(quantized_model_path, ignore_errors=True)
+        assert torch.all(out.eq(out_ar))
+
 
     @pytest.mark.skipif(not is_habana_framework_installed(), reason="Habana framework is not installed")
     @pytest.mark.skipif(os.getenv("PT_HPU_LAZY_MODE", "0") == "1", reason="Lazy mode is enabled")
     @pytest.mark.skipif(not auto_round_installed, reason="auto_round module is not installed")
@@ -398,9 +447,7 @@ def test_autoround_w4a8(self):
     @pytest.mark.parametrize("quant_lm_head", [True, False])
     def test_autoround(self, quant_lm_head):
         fp32_model = copy.deepcopy(self.tiny_llama_model)
-        quant_config = AutoRoundConfig(nsamples=32, seqlen=10, iters=10, act_dtype="fp32", amp=False ,scale_dtype="fp32")
-        if quant_lm_head is False:
-            quant_config.set_local("lm_head", AutoRoundConfig(dtype="fp32"))
+        quant_config = AutoRoundConfig(nsamples=32, seqlen=10, iters=10, act_dtype="fp32", amp=False ,scale_dtype="fp32", quant_lm_head=quant_lm_head)
         logger.info(f"Test AutoRound with config {quant_config}")
 
         # prepare + convert API
diff --git a/test/3x/torch/quantization/weight_only/test_transformers.py b/test/3x/torch/quantization/weight_only/test_transformers.py
index bb894b069cc..f7d07502f26 100644
--- a/test/3x/torch/quantization/weight_only/test_transformers.py
+++ b/test/3x/torch/quantization/weight_only/test_transformers.py
@@ -234,7 +234,6 @@ def test_vlm(self):
         woq_config = AutoRoundConfig(
             bits=4,
             group_size=128,
-            is_vlm=True,
             dataset="NeelNanda/pile-10k",
             iters=1,
             n_samples=1,
@@ -267,7 +266,6 @@ def test_vlm(self):
         # woq_config = AutoRoundConfig(
         #     bits=4,
         #     group_size=128,
-        #     is_vlm=True,
         #     dataset="liuhaotian/llava_conv_58k",
         #     iters=2,
         #     n_samples=5,
diff --git a/test/3x/torch/requirements.txt b/test/3x/torch/requirements.txt
index 592bddb963a..cd58c39fb38 100644
--- a/test/3x/torch/requirements.txt
+++ b/test/3x/torch/requirements.txt
@@ -1,4 +1,4 @@
-auto_round
+auto_round @ git+https://github.com/intel/auto-round.git@v0.8.0rc
 compressed-tensors
 datasets
 deepspeed @ git+https://github.com/HabanaAI/DeepSpeed.git@1.22.0
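
After this change, the `is_vlm`/`is_mllm` switches are gone (multi-modal models are detected via `auto_round.utils.is_mllm_model`), the per-tensor knobs and the tokenizer travel inside `AutoRoundConfig`, `quant_lm_head` replaces the manual `set_local("lm_head", ...)` call, and `output_dir` replaces the hard-coded `temp_auto_round` directory. The sketch below condenses the updated `test_scheme` flow from `test_autoround.py`; it is not part of the patch. The model name, calibration sizes, and save path are illustrative, and it assumes AutoRound falls back to its default calibration dataset when no explicit calibration run is captured between `prepare` and `convert`.

# Minimal usage sketch of the post-change INC API (assumptions noted above).
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

from neural_compressor.torch.quantization import AutoRoundConfig, convert, prepare

# Illustrative stand-in model; any causal LM works the same way.
model = AutoModelForCausalLM.from_pretrained("facebook/opt-125m", torchscript=True, device_map="auto")
tokenizer = AutoTokenizer.from_pretrained("facebook/opt-125m", trust_remote_code=True)

quant_config = AutoRoundConfig(
    tokenizer=tokenizer,        # passed through the config; stashed on the model during prepare()
    nsamples=32,
    seqlen=10,
    iters=10,
    amp=False,
    scale_dtype="fp16",
    scheme="W4A16",             # or "MXFP4" / "NVFP4", as parametrized in the test
    quant_lm_head=False,        # replaces quant_config.set_local("lm_head", AutoRoundConfig(dtype="fp32"))
    export_format="auto_round",
    output_dir="./saved_inc",   # replaces the hard-coded "temp_auto_round"
)

model = prepare(model=model, quant_config=quant_config)
model = convert(model)          # quantizes and exports to output_dir for non-itrex formats

# Reload the exported checkpoint the same way the updated test does.
inc_model = AutoModelForCausalLM.from_pretrained("./saved_inc", torch_dtype="auto", device_map="auto")
out = inc_model(torch.ones([1, 10], dtype=torch.long))[0]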