diff --git a/optimum/exporters/openvino/convert.py b/optimum/exporters/openvino/convert.py
index 36074fcc00..779776754a 100644
--- a/optimum/exporters/openvino/convert.py
+++ b/optimum/exporters/openvino/convert.py
@@ -31,14 +31,7 @@
 from optimum.exporters.onnx.model_patcher import DecoderModelPatcher
 from optimum.utils import is_diffusers_available
 
-from ...intel.utils.import_utils import (
-    _torch_version,
-    _transformers_version,
-    is_nncf_available,
-    is_optimum_version,
-    is_torch_version,
-    is_transformers_version,
-)
+from ...intel.utils.import_utils import is_nncf_available, is_optimum_version
 from .model_patcher import patch_model_with_bettertransformer
 from .stateful import ensure_stateful_is_available, patch_stateful
 from .utils import (
@@ -331,18 +324,6 @@ def export_pytorch(
     output = Path(output)
 
     if stateful:
-        if is_transformers_version("<", "4.36") or is_torch_version("<", "2.1.1"):
-            COLOR_RED = "\033[1;31m"
-            COLOR_RESET = "\033[0m"
-            logger.warning(
-                COLOR_RED
-                + "[WARNING] For good performance with stateful models, transformers>=4.36.2 and PyTorch>=2.1.1 are required. "
-                f"This Python environment has Transformers {_transformers_version} and PyTorch {_torch_version}. "
-                "Consider upgrading PyTorch and Transformers, for example by running "
-                "`pip install --upgrade --upgrade-strategy eager optimum[openvino,nncf]`, and export the model again"
-                + COLOR_RESET
-            )
-
         # Trigger bettertransformer together with stateful model because OpenVINO HW-dependent transformations expect
         # both of them are applied to demonstrate the best performance.
         # TODO: Consider applying bettertransformer regardless of stateful flag -- requires additional validation.
diff --git a/optimum/exporters/openvino/model_patcher.py b/optimum/exporters/openvino/model_patcher.py
index 37106eacf8..f953771a7a 100644
--- a/optimum/exporters/openvino/model_patcher.py
+++ b/optimum/exporters/openvino/model_patcher.py
@@ -14,16 +14,31 @@
 
 import logging as log
 
-from optimum.intel.utils.import_utils import is_torch_version
+from optimum.intel.utils.import_utils import (
+    _torch_version,
+    _transformers_version,
+    is_torch_version,
+    is_transformers_version,
+)
 
 
 def patch_model_with_bettertransformer(model):
-    if is_torch_version("<", "2.0"):
+    # check that the model has not yet been patched
+    if hasattr(model, "use_bettertransformer") and model.use_bettertransformer is True:
+        return model
+
+    if is_transformers_version("<", "4.36") or is_torch_version("<", "2.1.1"):
+        COLOR_RED = "\033[1;31m"
+        COLOR_RESET = "\033[0m"
         log.warn(
-            "integration Scaled Dot Product Attention optimization supported only with torch > 2.0."
-            "Usage model with stateful=True may be non-effective if model does not contain torch.functional.scaled_dot_product_attention"
-            "It is recommended to upgrade PyTorch version for using stateful model or use stateful=False"
+            COLOR_RED
+            + "[WARNING] For good performance with stateful models, transformers>=4.36.2 and PyTorch>=2.1.1 are required. "
+            f"This Python environment has Transformers {_transformers_version} and PyTorch {_torch_version}. "
" + "Consider upgrading PyTorch and Transformers, for example by running " + "`pip install --upgrade --upgrade-strategy eager optimum[openvino,nncf]`, and export the model again" + + COLOR_RESET ) + # model already has required SDPA implementation if getattr(model, "_supports_sdpa", False) and getattr(model.config, "_attn_implementation", "eager") == "sdpa": return model diff --git a/optimum/intel/openvino/quantization.py b/optimum/intel/openvino/quantization.py index 3a2e55978c..cb6e577291 100644 --- a/optimum/intel/openvino/quantization.py +++ b/optimum/intel/openvino/quantization.py @@ -24,22 +24,26 @@ import transformers from accelerate.data_loader import DataLoaderStateMixin from datasets import Dataset, load_dataset -from nncf import NNCFConfig, compress_weights +from nncf import NNCFConfig from nncf.torch import create_compressed_model, register_default_init_args, register_module from nncf.torch.dynamic_graph.io_handling import wrap_nncf_model_inputs_with_objwalk from nncf.torch.initialization import PTInitializingDataLoader from openvino._offline_transformations import compress_quantize_weights_transformation from openvino.runtime import Core, Tensor +from torch.utils._pytree import tree_map from torch.utils.data import DataLoader, RandomSampler from transformers import DataCollator, PreTrainedModel, default_data_collator from transformers.pytorch_utils import Conv1D +from optimum.exporters.onnx.convert import check_dummy_inputs_are_allowed from optimum.exporters.tasks import TasksManager from optimum.quantization_base import OptimumQuantizer from ...exporters.openvino import export, export_pytorch_via_onnx -from ...exporters.openvino.stateful import ensure_export_task_support_stateful +from ...exporters.openvino.model_patcher import patch_model_with_bettertransformer +from ...exporters.openvino.stateful import ensure_export_task_support_stateful, ensure_stateful_is_available from ..utils.constant import _TASK_ALIASES +from ..utils.modeling_utils import get_model_device from .configuration import OVConfig from .modeling_base import OVBaseModel from .modeling_decoder import OVBaseDecoderModel @@ -361,9 +365,7 @@ def _quantize_ovcausallm( self.model.model, quantization_dataset, model_type=nncf.ModelType.TRANSFORMER if not kwargs.get("model_type") else kwargs.get("model_type"), - fast_bias_correction=True - if not kwargs.get("fast_bias_correction") - else kwargs.get("fast_bias_correction"), + fast_bias_correction=kwargs.get("fast_bias_correction", True), **kwargs, ) self.model.model = quantized_model @@ -405,13 +407,42 @@ def _quantize_torchmodel( if file_name is None and ov_config.save_onnx_model else Path(ov_file_name).with_suffix(".onnx") ) + + task = self.task + model = self.model + self.model.config.save_pretrained(save_directory) + if task.startswith("text-generation"): + onnx_config = onnx_config_class( + model.config, use_past=model.config.use_cache, use_past_in_inputs=model.config.use_cache + ) + if model.config.use_cache: + task = "text-generation-with-past" + else: + onnx_config = onnx_config_class(model.config) + + stateful = ensure_stateful_is_available() and ensure_export_task_support_stateful(task) + if weights_only: - if getattr(self.model.config, "tie_word_embeddings", True): - # to fix problem with shared embedding weights in nncf compress_weights() - self.model.tie_weights() - compressed_model = compress_weights(self.model) - self.model = compressed_model + if stateful: + # patch model before weight compression + model = patch_model_with_bettertransformer(model) + + 
+            dummy_inputs = onnx_config.generate_dummy_inputs(framework="pt")
+            device = get_model_device(model)
+            dummy_inputs = tree_map(
+                lambda value: value.to(device) if isinstance(value, torch.Tensor) else value, dummy_inputs
+            )
+            check_dummy_inputs_are_allowed(model, dummy_inputs)
+
+            nncf.compress_weights(model, dataset=nncf.Dataset([dummy_inputs]))
        else:
+            if stateful:
+                logger.warn(
+                    "Quantization algorithm does not support optimized stateful models. "
+                    "The original model without optimization will be quantized and exported."
+                )
+                stateful = False
+
            calibration_dataloader = self._get_calibration_dataloader(
                calibration_dataset=calibration_dataset,
                batch_size=batch_size,
@@ -423,22 +454,10 @@
            ov_config.add_input_info(model_inputs)
            nncf_config = NNCFConfig.from_dict(ov_config.__dict__)
            nncf_config = register_default_init_args(nncf_config, calibration_dataloader)
-            controller, compressed_model = create_compressed_model(
-                self.model, nncf_config, wrap_inputs_fn=wrap_nncf_model_inputs_with_objwalk
-            )
-            compressed_model = controller.strip(do_copy=False)
-
-        task = self.task
-        model = self.model
-        self.model.config.save_pretrained(save_directory)
-        if task.startswith("text-generation"):
-            onnx_config = onnx_config_class(
-                model.config, use_past=model.config.use_cache, use_past_in_inputs=model.config.use_cache
+            controller, model = create_compressed_model(
+                model, nncf_config, wrap_inputs_fn=wrap_nncf_model_inputs_with_objwalk
             )
-            if model.config.use_cache:
-                task = "text-generation-with-past"
-        else:
-            onnx_config = onnx_config_class(model.config)
+            model = controller.strip(do_copy=False)
 
        model_path = save_directory / (onnx_file_name if ov_config.save_onnx_model else ov_file_name)
        onnx_path = save_directory / onnx_file_name
@@ -447,7 +466,8 @@
            opset = max(opset, MIN_ONNX_QDQ_OPSET)
            kwargs = {}
            if not ov_config.save_onnx_model:
-                kwargs = {"stateful": ensure_export_task_support_stateful(task)}
+                kwargs = {"stateful": stateful}
+
            _, _, is_onnx = export_fn(model=model, config=onnx_config, output=model_path, opset=opset, **kwargs)
            if is_onnx:
                # Load and save the compressed model
diff --git a/optimum/intel/utils/modeling_utils.py b/optimum/intel/utils/modeling_utils.py
index 1a3b6fbede..99ad42aafa 100644
--- a/optimum/intel/utils/modeling_utils.py
+++ b/optimum/intel/utils/modeling_utils.py
@@ -148,3 +148,24 @@ def patch_decoder_attention_mask(model: "PreTrainedModel"):
     elif model.config.model_type in {"blenderbot-small", "blenderbot", "opt", "pegasus", "bart"}:
         model.model.decoder._prepare_decoder_attention_mask = _prepare_decoder_attention_mask
     return model
+
+
+def get_model_device(model: torch.nn.Module) -> torch.device:
+    """
+    Determines the device on which a PyTorch model currently resides.
+
+    Args:
+        model: The PyTorch model to query.
+
+    Returns:
+        torch.device: The device where the model's parameters are located.
+
+    Note:
+        Falls back to CPU when the model has no parameters.
+    """
+    try:
+        device = next(model.parameters()).device
+    except StopIteration:
+        # The model has no parameters at all; any device works, default to CPU
+        device = torch.device("cpu")
+    return device
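
Reviewer note: the new weights-only path generates framework dummy inputs, moves every tensor leaf to the model's device with tree_map, validates them with check_dummy_inputs_are_allowed, and then runs nncf.compress_weights with a one-sample nncf.Dataset. Below is a minimal, torch-only sketch of the device-alignment step; TinyModel is a hypothetical stand-in for self.model, and the hardcoded dummy_inputs dict stands in for what onnx_config.generate_dummy_inputs(framework="pt") would return.

import torch
from torch.utils._pytree import tree_map


def get_model_device(model: torch.nn.Module) -> torch.device:
    # Same fallback behavior as the helper added in modeling_utils.py above.
    try:
        return next(model.parameters()).device
    except StopIteration:
        return torch.device("cpu")


class TinyModel(torch.nn.Module):  # hypothetical stand-in for a real transformer
    def __init__(self):
        super().__init__()
        self.proj = torch.nn.Linear(8, 8)

    def forward(self, input_ids, attention_mask=None):
        return self.proj(input_ids)


model = TinyModel()
device = get_model_device(model)

# Dummy inputs may be arbitrarily nested; tensor leaves must live on the
# same device as the model before they are used for tracing.
dummy_inputs = {"input_ids": torch.randn(2, 8), "attention_mask": torch.ones(2, 8)}

# tree_map visits every leaf of the nested structure and moves only tensors,
# leaving non-tensor values (ints, strings, None) untouched.
dummy_inputs = tree_map(
    lambda value: value.to(device) if isinstance(value, torch.Tensor) else value,
    dummy_inputs,
)

assert all(t.device == device for t in dummy_inputs.values())
# In the diff itself, the aligned inputs then feed
# nncf.compress_weights(model, dataset=nncf.Dataset([dummy_inputs])).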