def _init_warmup(self):
    """Run two throwaway forward passes on dummy inputs.

    Per the original author's note, the first 2 forwards of an IPEX model
    include some preprocessing steps and the results of the compute are
    unpredictable, so they are spent here instead of on a caller's data.

    Reads ``self.input_names`` and ``self.export_feature``; the outputs of
    both passes are discarded.
    """
    # The model is treated as cache-aware only if it declares a
    # ``past_key_values`` input.
    cache_expected = "past_key_values" in self.input_names

    # NOTE(review): presumably ``prepare_jit_inputs`` returns a mapping of
    # keyword arguments accepted by ``self.__call__`` - confirm against
    # ..generation.modeling.
    sample_batch = prepare_jit_inputs(self, self.export_feature, cache_expected)

    # Build the inputs once and feed the same batch through both passes.
    passes_left = 2
    while passes_left > 0:
        self(**sample_batch)
        passes_left -= 1
def forward(
    self,
    input_ids: torch.Tensor,
    attention_mask: torch.Tensor,
    token_type_ids: torch.Tensor = None,
    **kwargs,
):
    """Run the question-answering model and return its start/end logits.

    Args:
        input_ids: Forwarded to the model under the ``input_ids`` key.
        attention_mask: Forwarded to the model under the ``attention_mask`` key.
        token_type_ids: Forwarded only when ``"token_type_ids"`` is listed in
            ``self.input_names``; otherwise it is ignored.
        **kwargs: Accepted for call-site compatibility; not forwarded.

    Returns:
        A ``ModelOutput`` with ``start_logits`` and ``end_logits`` fields.
    """
    model_inputs = {"input_ids": input_ids, "attention_mask": attention_mask}

    # Only hand the model inputs it declares; an undeclared keyword is
    # dropped rather than passed through.
    if "token_type_ids" in self.input_names:
        model_inputs["token_type_ids"] = token_type_ids

    raw = self._call_model(**model_inputs)

    # The wrapped model may answer with a dict keyed by name or with a
    # positional sequence (start logits first, end logits second).
    if isinstance(raw, dict):
        start_logits = raw["start_logits"]
        end_logits = raw["end_logits"]
    else:
        start_logits = raw[0]
        end_logits = raw[1]

    return ModelOutput(start_logits=start_logits, end_logits=end_logits)