Remove modeling patching (#782)
jiqing-feng authored Jul 1, 2024
1 parent 60532db commit 1d29904
Showing 3 changed files with 2 additions and 26 deletions.
optimum/intel/generation/modeling.py (1 addition, 2 deletions)
@@ -34,7 +34,7 @@
 
 from ..utils.constant import _TASK_ALIASES
 from ..utils.import_utils import is_torch_version
-from ..utils.modeling_utils import MULTI_QUERY_ATTN_MODELS, patch_decoder_attention_mask
+from ..utils.modeling_utils import MULTI_QUERY_ATTN_MODELS
 
 
 logger = logging.getLogger(__name__)
@@ -454,7 +454,6 @@ def _from_transformers(
         }
 
         model = TasksManager.get_model_from_task(task, model_id, **model_kwargs)
-        model = patch_decoder_attention_mask(model)
 
         traced_model = jit_trace(model, task, use_cache)
         save_dir = TemporaryDirectory()
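
After this hunk, the TorchScript export in `_from_transformers` traces the Transformers model directly, with no intermediate patching step. A minimal sketch of the resulting flow, assuming `jit_trace` lives in this same module where the hunk calls it; the task and checkpoint names are placeholders, not values from the commit:

```python
from optimum.exporters.tasks import TasksManager
from optimum.intel.generation.modeling import jit_trace  # assumed import path

# Placeholder task/checkpoint; _from_transformers receives these as arguments.
model = TasksManager.get_model_from_task("text-generation", "gpt2")
# Previously patch_decoder_attention_mask(model) ran here; now the model's own
# attention-mask preparation is used as-is during tracing.
traced_model = jit_trace(model, "text-generation", use_cache=True)
```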
optimum/intel/ipex/modeling_base.py (1 addition, 2 deletions)
@@ -54,7 +54,7 @@
 from ...exporters.ipex.model_patcher import _IPEX_EXPORTED_TASK, _IPEX_MINIMUM_VERSION_FOR_PATCHING, _patch_model
 from ..generation.modeling import prepare_jit_inputs
 from ..utils.import_utils import is_ipex_version, is_torch_version, is_transformers_version
-from ..utils.modeling_utils import MULTI_QUERY_ATTN_MODELS, patch_decoder_attention_mask, recursive_to_device
+from ..utils.modeling_utils import MULTI_QUERY_ATTN_MODELS, recursive_to_device
 
 
 logger = logging.getLogger(__name__)
@@ -95,7 +95,6 @@ def ipex_jit_trace(model, task, use_cache):
         # Use Tensor Processing Primitives to accelerate linear, see https://arxiv.org/abs/2104.05755.
         _enable_tpp()
     else:
-        model = patch_decoder_attention_mask(model)
         sample_inputs = prepare_jit_inputs(model, task, use_cache)
 
     model.config.return_dict = False
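
The IPEX hunk mirrors the change above: the non-TPP branch now goes straight from the guard to input preparation. A condensed, hypothetical sketch of that branch; `exported_with_ipex` and `enable_tpp` are stand-ins for names defined outside this diff:

```python
from optimum.intel.generation.modeling import prepare_jit_inputs  # import shown in the hunk


def traced_branch_sketch(model, task: str, use_cache: bool, exported_with_ipex: bool, enable_tpp):
    """Hypothetical condensation of ipex_jit_trace after this commit."""
    sample_inputs = None
    if exported_with_ipex:
        # Tensor Processing Primitives accelerate linear layers (arXiv:2104.05755).
        enable_tpp()
    else:
        # The removed patch_decoder_attention_mask(model) call used to run first here.
        sample_inputs = prepare_jit_inputs(model, task, use_cache)
    model.config.return_dict = False  # torch.jit tracing needs tuple outputs
    return model, sample_inputs
```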
optimum/intel/utils/modeling_utils.py (0 additions, 22 deletions)
@@ -18,7 +18,6 @@
 
 import torch
 from huggingface_hub import HfApi, HfFolder
-from transformers.modeling_utils import PreTrainedModel
 
 
 MULTI_QUERY_ATTN_MODELS = {"falcon", "gpt_bigcode"}
@@ -132,27 +131,6 @@ def _prepare_decoder_sliding_window_attention_mask(
     return combined_attention_mask
 
 
-def patch_decoder_attention_mask(model: "PreTrainedModel"):
-    """
-    Apply patch on decoder with past model forward to resolve first inference based on model architecture
-
-    Args:
-        model (PretrainedModel): The model to patch.
-
-    Returns:
-        model with applied patch
-    """
-    if model.config.model_type in {"bloom", "mpt"}:
-        model.transformer._prepare_attn_mask = _prepare_attn_mask
-    elif model.config.model_type == "llama":
-        model.model._prepare_decoder_attention_mask = _prepare_decoder_attention_mask
-    elif model.config.model_type == "mistral":
-        model.model._prepare_decoder_attention_mask = _prepare_decoder_sliding_window_attention_mask
-    elif model.config.model_type in {"blenderbot-small", "blenderbot", "opt", "pegasus", "bart"}:
-        model.model.decoder._prepare_decoder_attention_mask = _prepare_decoder_attention_mask
-    return model
-
-
 def get_model_device(model: torch.nn.Module) -> torch.device:
     """
     Determines the device on which a PyTorch model is currently residing.
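
The deleted `patch_decoder_attention_mask` helper worked by dispatching on `model.config.model_type` and overwriting a private mask-preparation method on the matching submodule. A self-contained sketch of that monkey-patching pattern; the mask function body and the checkpoint are hypothetical, and the private hook only exists on the older transformers releases this helper targeted:

```python
from transformers import AutoModelForCausalLM


def my_prepare_decoder_attention_mask(attention_mask, input_shape, inputs_embeds, past_key_values_length):
    """Hypothetical stand-in for the _prepare_decoder_attention_mask variants deleted above."""
    raise NotImplementedError("mask construction elided in this sketch")


model = AutoModelForCausalLM.from_pretrained("huggyllama/llama-7b")  # placeholder checkpoint
if model.config.model_type == "llama":
    # Same mechanism as the deleted helper: assigning to the instance attribute
    # shadows the class method, so the first forward pass during tracing picks
    # up the custom mask logic instead of the library's.
    model.model._prepare_decoder_attention_mask = my_prepare_decoder_attention_mask
```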
