@@ -223,7 +223,6 @@ def tune(args):
use_auto_mapping = True

woq_config = AutoRoundConfig(
is_vlm=True,
bits=args.bits,
sym=not args.asym,
group_size=args.group_size,
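With this PR the example script no longer passes `is_vlm=True`; the unified quantizer handles multimodal models on its own. A minimal sketch of the updated call, assuming the `args` namespace used elsewhere in the script:

```python
# Sketch of the updated config construction; the `args` fields are the script's own flags.
woq_config = AutoRoundConfig(
    bits=args.bits,              # weight bit-width, e.g. 4
    sym=not args.asym,           # symmetric unless --asym is passed
    group_size=args.group_size,  # quantization group size, e.g. 128
)
```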
neural_compressor/torch/algorithms/weight_only/autoround.py (90 additions, 105 deletions)
@@ -34,9 +34,9 @@ def _is_auto_round_available():
_is_auto_round_available()

from auto_round import AutoRound, AutoRoundMLLM # pylint: disable=E0401
from auto_round.compressors.mllm.eval import lmms_eval, mllm_eval
from auto_round.compressors.mllm.template import Template, get_template
from auto_round.export.export_to_itrex.export import pack_model # pylint: disable=E0401
from auto_round.mllm import lmms_eval, mllm_eval
from auto_round.mllm.template import Template, get_template
from auto_round.schemes import QuantizationScheme

from neural_compressor.torch.algorithms import Quantizer
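The `auto_round.mllm` namespace moved under `auto_round.compressors.mllm` in newer auto_round releases, which is what the import churn above reflects. A version-tolerant shim (a sketch, not part of this PR) could bridge both layouts:

```python
# Compatibility sketch: try the new auto_round layout first, fall back to the old one.
try:
    from auto_round.compressors.mllm.eval import lmms_eval, mllm_eval
    from auto_round.compressors.mllm.template import Template, get_template
except ImportError:  # older auto_round releases
    from auto_round.mllm import lmms_eval, mllm_eval
    from auto_round.mllm.template import Template, get_template
```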
@@ -50,7 +50,19 @@ class AutoRoundQuantizer(Quantizer):

def __init__(
self,
quant_config: dict = {},
bits: int = None,
group_size: int = None,
sym: bool = None,
data_type: str = None,
act_bits: int = None,
act_group_size: int = None,
act_sym: bool = None,
act_data_type: str = None,
act_dynamic: bool = None,
super_bits: int = None,
super_group_size: int = None,
quant_config: dict = {}, # for INC
layer_config: dict[str, Union[str, dict, QuantizationScheme]] = None,
enable_full_range: bool = False,  # for symmetric; TODO: support later
batch_size: int = 8,
amp: bool = True,
@@ -71,21 +83,14 @@ def __init__(
gradient_accumulate_steps: int = 1,
not_use_best_mse: bool = False,
dynamic_max_gap: int = -1,
data_type: str = "int",
scale_dtype: str = "fp16",
to_quant_block_names: list = None,
act_bits: int = 32,
act_group_size: int = None,
act_sym: bool = None,
act_dynamic: bool = True,
act_data_type: Optional[str] = None,
low_cpu_mem_usage: bool = False,
export_format: str = "itrex",
# v0.4
enable_norm_bias_tuning: bool = False,
enable_torch_compile: bool = None,
# mllm
is_mllm: bool = False,
quant_nontext_module: bool = False,
extra_data_dir: str = None,
image_processor=None,
@@ -119,6 +124,7 @@ def __init__(
bits (int): Number of bits for quantization (default is 4).
group_size (int): Size of the quantization group (default is 128).
sym (bool): Whether to use symmetric quantization. (default is None).
layer_config (dict, optional): Layer-wise quantization config. Defaults to None.
bits (int): Number of bits for quantization (default is 4).
group_size (int): Size of the quantization group (default is 128).
sym (bool): Whether symmetric quantization is to be used (default is False).
@@ -155,7 +161,6 @@ def __init__(
enable_norm_bias_tuning (bool): Whether to enable fast norm/layer_bias tuning.
enable_torch_compile (bool): Whether to enable torch.compile to optimize quant_block/layer; defaults to True for torch>=2.6.
quant_nontext_module (bool): Whether to quantize nontext module.
is_mllm (bool): Indicates whether the model to be quantized is a multi-modal model (MLLM).
extra_data_dir (str): The path for extra data such as images, audio or videos.
processor (transformers.AutoProcessor): Any multi-modal model will require an object to encode or
decode the data that groups several modalities (among text, vision and audio).
@@ -170,12 +175,26 @@ def __init__(
The quantized model.
"""
super().__init__(quant_config)
self.tokenizer = "Placeholder" # for AutoRound initialization
self.layer_config = layer_config
self.output_dir = kwargs.pop("output_dir", "temp_auto_round")
self.tokenizer = kwargs.pop("tokenizer", "Placeholder") # for AutoRound initialization
self.enable_full_range = enable_full_range
self.bits = bits
self.group_size = group_size
self.sym = sym
self.data_type = data_type
self.act_bits = act_bits
self.act_group_size = act_group_size
self.act_sym = act_sym
self.act_data_type = act_data_type
self.act_dynamic = act_dynamic
self.super_bits = super_bits
self.super_group_size = super_group_size
self.batch_size = batch_size
self.amp = amp
self.device = get_accelerator(kwargs.pop("device", "auto")).name()
self.lr_scheduler = lr_scheduler
self.dataset = dataset
self.enable_quanted_input = enable_quanted_input
self.enable_minmax_tuning = enable_minmax_tuning
self.lr = lr
@@ -190,19 +209,12 @@ def __init__(
self.gradient_accumulate_steps = gradient_accumulate_steps
self.not_use_best_mse = not_use_best_mse
self.dynamic_max_gap = dynamic_max_gap
self.data_type = data_type
self.scale_dtype = scale_dtype
self.to_quant_block_names = to_quant_block_names
self.act_bits = act_bits
self.act_group_size = act_group_size
self.act_sym = act_sym
self.act_dynamic = act_dynamic
self.act_data_type = act_data_type
self.low_cpu_mem_usage = low_cpu_mem_usage
self.export_format = export_format
self.enable_norm_bias_tuning = enable_norm_bias_tuning
self.enable_torch_compile = enable_torch_compile
self.is_mllm = is_mllm
self.quant_nontext_module = quant_nontext_module
self.extra_data_dir = extra_data_dir
self.processor = processor
@@ -213,7 +225,7 @@ def __init__(
self.device_map = device_map
self.enable_w4afp8 = self._is_w4afp8()
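Because the quantization knobs are now first-class keyword arguments rather than entries buried in `quant_config`, the quantizer can be configured directly. An illustrative construction under the signature above (values are examples, not library defaults):

```python
# Illustrative construction; the values shown are examples, not library defaults.
quantizer = AutoRoundQuantizer(
    bits=4,             # 4-bit weights
    group_size=128,     # quantize weights in groups of 128
    sym=True,           # symmetric weight quantization
    data_type="int",    # integer weight representation
    act_bits=32,        # leave activations unquantized
    layer_config=None,  # optional per-layer overrides
)
```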

def _is_w4afp8(self):
def _is_w4afp8(self) -> bool:
return any([v.get("data_type", None) == "fp8_to_int_sym" for v in self.quant_config.values()])
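`_is_w4afp8` scans every per-layer entry for the `fp8_to_int_sym` data type. An entry that would flip the W4A-FP8 path on looks like this (the layer name is hypothetical):

```python
# Hypothetical layer entry that makes _is_w4afp8() return True.
quant_config = {
    "model.layers.0.self_attn.q_proj": {  # hypothetical layer name
        "data_type": "fp8_to_int_sym",    # marker for the W4A-FP8 recipe
        "bits": 4,
    },
}
```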

def prepare(self, model: torch.nn.Module, *args, **kwargs):
@@ -237,96 +249,69 @@ def convert(self, model: torch.nn.Module, *args, **kwargs):
Returns:
The quantized model.
"""
dataloader = CapturedDataloader(model.args_list, model.kwargs_list)
model = model.orig_model
if self.is_mllm:
rounder = AutoRoundMLLM(
model,
tokenizer=self.tokenizer,
scheme=self.scheme,
processor=self.processor,
image_processor=self.image_processor,
layer_config=self.quant_config,
batch_size=self.batch_size,
amp=self.amp,
device_map=self.device_map,
lr_scheduler=self.lr_scheduler,
dataset=dataloader,
extra_data_dir=self.extra_data_dir,
template=self.template,
quant_nontext_module=self.quant_nontext_module,
enable_quanted_input=self.enable_quanted_input,
enable_minmax_tuning=self.enable_minmax_tuning,
lr=self.lr,
minmax_lr=self.minmax_lr,
low_gpu_mem_usage=self.low_gpu_mem_usage,
low_cpu_mem_usage=self.low_gpu_mem_usage,
iters=self.iters,
seqlen=self.seqlen,
nsamples=self.nsamples,
sampler=self.sampler,
seed=self.seed,
nblocks=self.nblocks,
gradient_accumulate_steps=self.gradient_accumulate_steps,
not_use_best_mse=self.not_use_best_mse,
dynamic_max_gap=self.dynamic_max_gap,
data_type=self.data_type,
scale_dtype=self.scale_dtype,
act_bits=self.act_bits,
act_group_size=self.act_group_size,
act_sym=self.act_sym,
act_dynamic=self.act_dynamic,
to_quant_block_names=self.to_quant_block_names,
enable_norm_bias_tuning=self.enable_norm_bias_tuning,
truncation=self.truncation,
enable_torch_compile=self.enable_torch_compile,
)
tokenizer = getattr(model.orig_model, "tokenizer", None)
if tokenizer is not None:
delattr(model.orig_model, "tokenizer")
else:
rounder = AutoRound(
model=model,
tokenizer=self.tokenizer,
scheme=self.scheme,
dataset=dataloader,
layer_config=self.quant_config or {},
enable_full_range=self.enable_full_range,
batch_size=self.batch_size,
amp=self.amp,
device_map=self.device_map,
lr_scheduler=self.lr_scheduler,
enable_quanted_input=self.enable_quanted_input,
enable_minmax_tuning=self.enable_minmax_tuning,
lr=self.lr,
minmax_lr=self.minmax_lr,
low_gpu_mem_usage=self.low_gpu_mem_usage,
iters=self.iters,
seqlen=self.seqlen,
nsamples=self.nsamples,
sampler=self.sampler,
seed=self.seed,
nblocks=self.nblocks,
gradient_accumulate_steps=self.gradient_accumulate_steps,
not_use_best_mse=self.not_use_best_mse,
dynamic_max_gap=self.dynamic_max_gap,
data_type=self.data_type,
scale_dtype=self.scale_dtype,
to_quant_block_names=self.to_quant_block_names,
act_bits=self.act_bits,
act_group_size=self.act_group_size,
act_sym=self.act_sym,
act_dynamic=self.act_dynamic,
low_cpu_mem_usage=self.low_cpu_mem_usage,
enable_norm_bias_tuning=self.enable_norm_bias_tuning,
enable_torch_compile=self.enable_torch_compile,
)
tokenizer = "Placeholder"
self.dataset = CapturedDataloader(model.args_list, model.kwargs_list)
model = model.orig_model
rounder = AutoRound(
model,
layer_config=self.layer_config,
bits=self.bits,
data_type=self.data_type,
group_size=self.group_size,
sym=self.sym,
act_bits=self.act_bits,
act_group_size=self.act_group_size,
act_sym=self.act_sym,
act_data_type=self.act_data_type,
act_dynamic=self.act_dynamic,
super_bits=self.super_bits,
super_group_size=self.super_group_size,
tokenizer=tokenizer,
scheme=self.scheme,
processor=self.processor,
image_processor=self.image_processor,
enable_full_range=self.enable_full_range,
batch_size=self.batch_size,
amp=self.amp,
device_map=self.device_map,
lr_scheduler=self.lr_scheduler,
dataset=self.dataset,
extra_data_dir=self.extra_data_dir,
template=self.template,
quant_nontext_module=self.quant_nontext_module,
enable_quanted_input=self.enable_quanted_input,
enable_minmax_tuning=self.enable_minmax_tuning,
lr=self.lr,
minmax_lr=self.minmax_lr,
low_gpu_mem_usage=self.low_gpu_mem_usage,
low_cpu_mem_usage=self.low_cpu_mem_usage,
iters=self.iters,
seqlen=self.seqlen,
nsamples=self.nsamples,
sampler=self.sampler,
seed=self.seed,
nblocks=self.nblocks,
gradient_accumulate_steps=self.gradient_accumulate_steps,
not_use_best_mse=self.not_use_best_mse,
dynamic_max_gap=self.dynamic_max_gap,
scale_dtype=self.scale_dtype,
to_quant_block_names=self.to_quant_block_names,
enable_norm_bias_tuning=self.enable_norm_bias_tuning,
truncation=self.truncation,
enable_torch_compile=self.enable_torch_compile,
)
model, weight_config = rounder.quantize()
model.autoround_config = weight_config
if self.enable_w4afp8:
return rounder.save_quantized(output_dir="temp_auto_round", inplace=True)
return rounder.save_quantized(output_dir=self.output_dir, inplace=True)
elif "itrex" in self.export_format:
model = pack_model(model, weight_config, device=self.device, inplace=True)
else: # pragma: no cover
model = rounder.save_quantized(output_dir="temp_auto_round", format=self.export_format, inplace=True)

model = rounder.save_quantized(output_dir=self.output_dir, format=self.export_format, inplace=True)
return model
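The single `AutoRound` call above replaces the old `AutoRoundMLLM`/`AutoRound` branching; the tokenizer is now harvested from the wrapped model instead of being selected by an `is_mllm` flag. End to end, the quantizer is still driven through the usual prepare/convert protocol; a sketch of the calling convention (`model` and `calib_dataloader` are placeholders):

```python
# Sketch of the prepare/convert round-trip; model and calib_dataloader are placeholders.
quantizer = AutoRoundQuantizer(bits=4, group_size=128, sym=True)
model = quantizer.prepare(model)  # wraps the model to capture calibration inputs
for batch in calib_dataloader:    # replay a few calibration batches
    model(**batch)
model = quantizer.convert(model)  # runs AutoRound, then packs or exports weights
```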


@@ -389,8 +374,8 @@ def get_mllm_dataloader(
DataLoader: The DataLoader for the calibrated datasets.
"""
from auto_round.calib_dataset import CALIB_DATASETS
from auto_round.mllm.autoround_mllm import _only_text_test
from auto_round.mllm.mllm_dataset import get_mllm_dataloader # pylint: disable=E0401
from auto_round.compressors.mllm.compressor import _only_text_test
from auto_round.compressors.mllm.dataset import get_mllm_dataloader # pylint: disable=E0401

template = template if template is not None else model.config.model_type
template = get_template(
neural_compressor/torch/quantization/algorithm_entry.py (37 additions, 15 deletions)
@@ -579,25 +579,36 @@ def autoround_quantize_entry(
if quant_config.name != AUTOROUND or quant_config.dtype == "fp32":
continue
else:
dtype = quant_config.dtype
bits = quant_config.bits
if dtype != "int" and "int" in dtype:
if dtype == "fp8_to_int_sym":
group_size = quant_config.group_size
sym = quant_config.use_sym
data_type = quant_config.dtype
act_bits = quant_config.act_bits
act_group_size = quant_config.act_group_size
act_sym = quant_config.act_sym
act_data_type = quant_config.act_dtype
act_dynamic = quant_config.act_dynamic
super_bits = quant_config.super_bits
super_group_size = quant_config.super_group_size
if data_type is not None and data_type != "int" and "int" in data_type:
if data_type == "fp8_to_int_sym":
bits = 4
else:
bits = int(dtype.lstrip("int"))
dtype = "int"
bits = int(data_type.lstrip("int"))
data_type = "int"
weight_config[op_name] = {
"data_type": dtype,
"data_type": data_type,
"bits": bits,
"sym": quant_config.use_sym,
"group_size": quant_config.group_size,
"act_bits": quant_config.act_bits,
"act_group_size": quant_config.act_group_size,
"act_sym": quant_config.act_sym,
"act_dynamic": quant_config.act_dynamic,
"act_data_type": quant_config.act_dtype,
"sym": sym,
"group_size": group_size,
"act_bits": act_bits,
"act_group_size": act_group_size,
"act_sym": act_sym,
"act_dynamic": act_dynamic,
"act_data_type": act_data_type,
}
layer_config = quant_config.to_dict().get("layer_config", None)
output_dir = quant_config.to_dict().get("output_dir", "temp_auto_round")
enable_full_range = quant_config.enable_full_range
batch_size = quant_config.batch_size
amp = quant_config.amp
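The branch above normalizes string dtypes: `intN` collapses to (`data_type="int"`, `bits=N`), and `fp8_to_int_sym` is treated as a 4-bit recipe. The rule in isolation, with a couple of worked cases:

```python
# Standalone restatement of the dtype-normalization rule used above.
def normalize(data_type, bits):
    if data_type is not None and data_type != "int" and "int" in data_type:
        if data_type == "fp8_to_int_sym":
            bits = 4                             # fp8-to-int recipe implies 4-bit weights
        else:
            bits = int(data_type.lstrip("int"))  # e.g. "int8" -> 8
        data_type = "int"
    return data_type, bits

assert normalize("int8", 4) == ("int", 8)
assert normalize("fp8_to_int_sym", 8) == ("int", 4)
assert normalize("int", 4) == ("int", 4)  # plain "int" keeps the configured bits
```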
@@ -622,7 +633,6 @@ def autoround_quantize_entry(
export_format = quant_config.export_format
enable_norm_bias_tuning = quant_config.enable_norm_bias_tuning
enable_torch_compile = quant_config.enable_torch_compile
is_mllm = quant_config.is_mllm
quant_nontext_module = quant_config.quant_nontext_module
extra_data_dir = quant_config.extra_data_dir
processor = quant_config.processor
@@ -637,6 +647,19 @@ def autoround_quantize_entry(
model,
quantizer_cls=AutoRoundQuantizer,
quant_config=weight_config,
bits=bits,
data_type=data_type,
group_size=group_size,
sym=sym,
act_bits=act_bits,
act_group_size=act_group_size,
act_sym=act_sym,
act_data_type=act_data_type,
act_dynamic=act_dynamic,
super_bits=super_bits,
super_group_size=super_group_size,
layer_config=layer_config,
output_dir=output_dir,
enable_full_range=enable_full_range,
batch_size=batch_size,
amp=amp,
Expand All @@ -661,7 +684,6 @@ def autoround_quantize_entry(
export_format=export_format,
enable_norm_bias_tuning=enable_norm_bias_tuning,
enable_torch_compile=enable_torch_compile,
is_mllm=is_mllm,
quant_nontext_module=quant_nontext_module,
extra_data_dir=extra_data_dir,
processor=processor,
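With the entry now forwarding every scheme field explicitly, nothing changes for callers of the public API. A typical invocation through INC's high-level interface (a sketch; `model` and `calib_dataloader` are placeholders, and the config values are illustrative):

```python
# End-to-end sketch with the public INC API; model/calib_dataloader are placeholders.
from neural_compressor.torch.quantization import AutoRoundConfig, convert, prepare

config = AutoRoundConfig(bits=4, use_sym=True, group_size=128)
model = prepare(model, config)
for batch in calib_dataloader:  # a handful of calibration batches
    model(**batch)
model = convert(model)
```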