rm warmup because no jit mode anymore
Signed-off-by: jiqing-feng <jiqing.feng@intel.com>
jiqing-feng committed Nov 26, 2024
1 parent 209760d commit 3d32f18
Showing 1 changed file with 3 additions and 65 deletions.
68 changes: 3 additions & 65 deletions optimum/intel/ipex/modeling_base.py
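
Background for the change: while the IPEX integration still ran models through TorchScript tracing ("jit mode"), the first couple of forward passes triggered graph capture and optimization, so the wrapper performed throwaway warmup forwards at init time. With tracing gone, the warmup machinery below can be deleted. A minimal, purely illustrative sketch of why traced modules were warmed up (plain PyTorch, not this repository's code):

import torch

model = torch.nn.Linear(8, 8).eval()
example = torch.randn(1, 8)
traced = torch.jit.trace(model, example)  # TorchScript graph capture

with torch.no_grad():
    # The first calls of a traced module can include profiling and
    # graph-optimization work, so wrappers often ran a few dummy
    # forwards up front to get them out of the way.
    for _ in range(2):
        traced(example)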
@@ -44,21 +44,16 @@
 from transformers.models.auto.auto_factory import _get_model_class as get_model_class
 
 from optimum.exporters import TasksManager
-from optimum.exporters.tasks import make_backend_config_constructor_for_task
 from optimum.modeling_base import OptimizedModel
 from optimum.utils import NormalizedConfigManager
 
 from ...exporters.ipex.cache_utils import IPEXPagedCache
-from ...exporters.ipex.model_config import ipex_onnx_config
 from ...exporters.ipex.model_patcher import (
     _IPEX_EXPORTED_GENERATION_TASKS,
     _IPEX_MINIMUM_VERSION_FOR_PATCHING,
     _patch_model,
 )
-from ..generation.modeling import get_float_type
-from ..utils.constant import _TASK_ALIASES
 from ..utils.import_utils import is_ipex_version, is_transformers_version
-from ..utils.modeling_utils import recursive_to_device
 
 
 logger = logging.getLogger(__name__)
@@ -76,38 +76,6 @@ def _is_patched_with_ipex(model, task, use_cache: bool = True):
     return model.config.model_type in _IPEX_SUPPORT_MODEL_TYPES
 
 
-def _prepare_inputs_for_ipex_model(model, task, use_cache):
-    task = _TASK_ALIASES.get(task, task)
-    signature = inspect.signature(model.forward) if hasattr(model, "forward") else inspect.signature(model.__call__)
-    if _is_patched_with_ipex(model, task, use_cache) and model.config.model_type in ipex_onnx_config:
-        onnx_config_class = make_backend_config_constructor_for_task(
-            ipex_onnx_config[model.config.model_type], task=task
-        )
-    else:
-        onnx_config_class = TasksManager.get_exporter_config_constructor(model=model, exporter="onnx", task=task)
-    float_dtype = get_float_type(model.dtype)
-    if "text-generation" in task:
-        onnx_config = onnx_config_class(
-            model.config, use_past=use_cache, use_past_in_inputs=use_cache, float_dtype=float_dtype
-        )
-    else:
-        onnx_config = onnx_config_class(model.config)
-
-    dummy_inputs = onnx_config.generate_dummy_inputs(framework="pt")
-
-    # Check attention_mask shape
-    if _is_patched_with_ipex(model, task, use_cache) and model.config.model_type in ipex_onnx_config:
-        past_len = dummy_inputs["past_key_values"][0][0].shape[-2]
-        input_len = dummy_inputs["input_ids"].shape[-1]
-        attention_len = dummy_inputs["attention_mask"].shape[-1]
-        if attention_len != input_len + past_len:
-            dummy_inputs["attention_mask"] = torch.ones([dummy_inputs["input_ids"].shape[0], input_len + past_len]).to(
-                dummy_inputs["input_ids"].dtype
-            )
-
-    return {key: dummy_inputs[key] for key in signature.parameters if dummy_inputs.get(key, None) is not None}
-
-
 class IPEXModel(OptimizedModel):
     auto_model_class = AutoModel
     export_feature = "feature-extraction"
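
The deleted `_prepare_inputs_for_ipex_model` helper above built dummy tensors from the ONNX exporter configs and then kept only those accepted by the model's forward signature, purely to feed the warmup passes. A minimal sketch of that signature-filtering idea, with a hypothetical stand-in model (not this repository's API):

import inspect
import torch

def filter_by_signature(model: torch.nn.Module, dummy_inputs: dict) -> dict:
    # Keep only the dummy inputs that the forward signature accepts.
    signature = inspect.signature(model.forward)
    return {k: v for k, v in dummy_inputs.items() if k in signature.parameters}

model = torch.nn.Transformer(d_model=16, nhead=2)
dummy = {"src": torch.randn(4, 1, 16), "tgt": torch.randn(4, 1, 16), "extra": None}
print(filter_by_signature(model, dummy).keys())  # dict_keys(['src', 'tgt'])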
@@ -121,7 +84,6 @@ def __init__(
         config: PretrainedConfig = None,
         export: bool = False,
         model_save_dir: Optional[Union[str, Path, TemporaryDirectory]] = None,
-        warmup: bool = True,
         **kwargs,
     ):
         config = config or model.config
@@ -141,8 +103,6 @@ def __init__(
         AutoConfig.register(self.base_model_prefix, AutoConfig)
         if hasattr(self.auto_model_class, "register"):
             self.auto_model_class.register(AutoConfig, self.__class__)
-        if warmup:
-            self._init_warmup()
 
     @classmethod
     def _from_transformers(cls, *args, **kwargs):
@@ -260,28 +220,12 @@ def add_patch(self) -> bool:
         return self._add_patch
 
     def to(self, device: Union[torch.device, str]):
-        self.model.to(self.device)
+        self.model.to(device)
         return self
 
     def can_generate(self):
         return isinstance(self, GenerationMixin)
 
-    def _call_model(self, *args, **kwargs):
-        out = self.model(*args, **kwargs)
-        return out
-
-    def _init_warmup(self):
-        # warmup, the first 2 forwards of an IPEX model include some preprocessing steps and
-        # the results of the compute are unpredictable
-        # TODO : add warmup for IPEX exported model
-        if not self._add_patch:
-            # use_cache = "past_key_values" in self.input_names
-            dummy_inputs = _prepare_inputs_for_ipex_model(self, self.export_feature, self.use_cache)
-            if self.device.type != "cpu":
-                dummy_inputs = recursive_to_device(value=dummy_inputs, device=self.device)
-            for _ in range(2):
-                self(**dummy_inputs)
-
 
 class IPEXModelForSequenceClassification(IPEXModel):
     auto_model_class = AutoModelForSequenceClassification
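
Two things happen in the hunk above besides the `_init_warmup` removal: the `_call_model` indirection (which only forwarded to `self.model`) is dropped, and `to()` is fixed. The old body moved the model to `self.device`, i.e. the device it was already on, so the requested device was silently ignored. A minimal sketch of the fix, assuming a hypothetical wrapper that exposes the wrapped module as `self.model`:

import torch

class Wrapper:
    def __init__(self, model: torch.nn.Module):
        self.model = model

    @property
    def device(self) -> torch.device:
        # the device the wrapped module currently lives on
        return next(self.model.parameters()).device

    def to(self, device):
        # before: self.model.to(self.device) -- a no-op by definition,
        # since self.device is where the parameters already are
        self.model.to(device)  # after: honor the requested device
        return self

w = Wrapper(torch.nn.Linear(2, 2))
print(w.to("cpu").device)  # cpu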
@@ -327,13 +271,9 @@ def __init__(
         export: bool = False,
         model_save_dir: Optional[Union[str, Path, TemporaryDirectory]] = None,
         use_cache: bool = True,
-        warmup: bool = True,
         **kwargs,
     ):
-        # Perform the initial warmup at the end of __init__
-        super().__init__(
-            model, config, export=export, model_save_dir=model_save_dir, warmup=False, use_cache=use_cache
-        )
+        super().__init__(model, config, export=export, model_save_dir=model_save_dir, use_cache=use_cache)
 
         self._supports_cache_class = getattr(model, "_supports_cache_class", None)
         self._supports_sdpa = getattr(model, "_supports_sdpa", None)
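
This hunk and the next one unwind the deferred-warmup pattern: the generation subclass used to suppress the base-class warmup (`warmup=False` in the `super().__init__` call) and re-run `_init_warmup()` only after its own attributes were in place. A minimal sketch of the pattern being removed, with hypothetical class names:

class Base:
    def __init__(self, warmup: bool = True):
        if warmup:
            self._init_warmup()

    def _init_warmup(self):
        print("warming up", type(self).__name__)

class Generation(Base):
    def __init__(self):
        # Defer: warmup exercises forward paths that depend on state
        # set below, so the base-class warmup is skipped here ...
        super().__init__(warmup=False)
        self.use_cache = True
        # ... and performed once the instance is fully initialized.
        self._init_warmup()

Generation()  # warms up exactly once, at the end of __init__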
@@ -363,8 +303,6 @@ def __init__(
             self._convert_to_standard_cache = self.model_cls._convert_to_standard_cache
         if hasattr(self.model_cls, "_convert_to_bloom_cache"):
             self._convert_to_bloom_cache = self.model_cls._convert_to_bloom_cache
-        if warmup:
-            self._init_warmup()
 
     def forward(
         self,
@@ -396,7 +334,7 @@ def forward(
             inputs["past_key_values"] = past_key_values
 
         # 2. Model forward
-        outputs = self._call_model(**inputs)
+        outputs = self.model(**inputs)
 
         # 3. Process model outputs
         if isinstance(outputs, (list, tuple)):
