From 1d2990464477109c29744e5da5cbd0f5cf381a63 Mon Sep 17 00:00:00 2001
From: jiqing-feng <107918818+jiqing-feng@users.noreply.github.com>
Date: Mon, 1 Jul 2024 16:36:27 +0800
Subject: [PATCH] Remove modeling patching (#782)

---
 optimum/intel/generation/modeling.py  |  3 +--
 optimum/intel/ipex/modeling_base.py   |  3 +--
 optimum/intel/utils/modeling_utils.py | 22 ----------------------
 3 files changed, 2 insertions(+), 26 deletions(-)

diff --git a/optimum/intel/generation/modeling.py b/optimum/intel/generation/modeling.py
index 054ef44bfe..7d7e854311 100644
--- a/optimum/intel/generation/modeling.py
+++ b/optimum/intel/generation/modeling.py
@@ -34,7 +34,7 @@
 
 from ..utils.constant import _TASK_ALIASES
 from ..utils.import_utils import is_torch_version
-from ..utils.modeling_utils import MULTI_QUERY_ATTN_MODELS, patch_decoder_attention_mask
+from ..utils.modeling_utils import MULTI_QUERY_ATTN_MODELS
 
 
 logger = logging.getLogger(__name__)
@@ -454,7 +454,6 @@ def _from_transformers(
         }
 
         model = TasksManager.get_model_from_task(task, model_id, **model_kwargs)
-        model = patch_decoder_attention_mask(model)
         traced_model = jit_trace(model, task, use_cache)
         save_dir = TemporaryDirectory()
 
diff --git a/optimum/intel/ipex/modeling_base.py b/optimum/intel/ipex/modeling_base.py
index 000ccfa3fe..9f4c0d1056 100644
--- a/optimum/intel/ipex/modeling_base.py
+++ b/optimum/intel/ipex/modeling_base.py
@@ -54,7 +54,7 @@
 from ...exporters.ipex.model_patcher import _IPEX_EXPORTED_TASK, _IPEX_MINIMUM_VERSION_FOR_PATCHING, _patch_model
 from ..generation.modeling import prepare_jit_inputs
 from ..utils.import_utils import is_ipex_version, is_torch_version, is_transformers_version
-from ..utils.modeling_utils import MULTI_QUERY_ATTN_MODELS, patch_decoder_attention_mask, recursive_to_device
+from ..utils.modeling_utils import MULTI_QUERY_ATTN_MODELS, recursive_to_device
 
 
 logger = logging.getLogger(__name__)
@@ -95,7 +95,6 @@ def ipex_jit_trace(model, task, use_cache):
         # Use Tensor Processing Primitives to accelerate linear, see https://arxiv.org/abs/2104.05755.
         _enable_tpp()
     else:
-        model = patch_decoder_attention_mask(model)
 
     sample_inputs = prepare_jit_inputs(model, task, use_cache)
     model.config.return_dict = False
diff --git a/optimum/intel/utils/modeling_utils.py b/optimum/intel/utils/modeling_utils.py
index 49069778e5..ef04151119 100644
--- a/optimum/intel/utils/modeling_utils.py
+++ b/optimum/intel/utils/modeling_utils.py
@@ -18,7 +18,6 @@
 
 import torch
 from huggingface_hub import HfApi, HfFolder
-from transformers.modeling_utils import PreTrainedModel
 
 
 MULTI_QUERY_ATTN_MODELS = {"falcon", "gpt_bigcode"}
@@ -132,27 +131,6 @@ def _prepare_decoder_sliding_window_attention_mask(
     return combined_attention_mask
 
 
-def patch_decoder_attention_mask(model: "PreTrainedModel"):
-    """
-    Apply patch on decoder with past model forward to resolve first inference based on model architecture
-
-    Args:
-        model (PretrainedModel): The model to patch.
-
-    Returns:
-        model with applied patch
-    """
-    if model.config.model_type in {"bloom", "mpt"}:
-        model.transformer._prepare_attn_mask = _prepare_attn_mask
-    elif model.config.model_type == "llama":
-        model.model._prepare_decoder_attention_mask = _prepare_decoder_attention_mask
-    elif model.config.model_type == "mistral":
-        model.model._prepare_decoder_attention_mask = _prepare_decoder_sliding_window_attention_mask
-    elif model.config.model_type in {"blenderbot-small", "blenderbot", "opt", "pegasus", "bart"}:
-        model.model.decoder._prepare_decoder_attention_mask = _prepare_decoder_attention_mask
-    return model
-
-
 def get_model_device(model: torch.nn.Module) -> torch.device:
     """
     Determines the device on which a PyTorch model is currently residing.
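Taken together, the patch drops the per-architecture monkey-patching of private attention-mask helpers (_prepare_attn_mask, _prepare_decoder_attention_mask) and traces models exactly as transformers provides them. Below is a minimal sketch of the traced-export path after this change; "gpt2" is an illustrative checkpoint, the import path for jit_trace is assumed from this file's location, and the call shape follows the jit_trace(model, task, use_cache) usage visible in _from_transformers above.

# Minimal sketch (not part of the patch): exporting a decoder model after
# this change. Assumes optimum-intel is installed; "gpt2" is an illustrative
# checkpoint, not one named in the patch.
from transformers import AutoModelForCausalLM

from optimum.intel.generation.modeling import jit_trace

model = AutoModelForCausalLM.from_pretrained("gpt2")
# No patch_decoder_attention_mask() call precedes tracing anymore:
# recent transformers releases prepare decoder attention masks internally.
traced_model = jit_trace(model, "text-generation", use_cache=True)

In practice the same path is exercised through the public entry points, e.g. IPEXModelForCausalLM.from_pretrained(model_id, export=True), which now hands the unmodified transformers model straight to the tracer.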