enable compile for non-generation tasks
Signed-off-by: jiqing-feng <jiqing.feng@intel.com>
jiqing-feng committed Nov 27, 2024
1 parent 587837e commit 2902247
Showing 2 changed files with 12 additions and 19 deletions.
10 changes: 4 additions & 6 deletions optimum/exporters/ipex/model_patcher.py
@@ -75,7 +75,7 @@ def patch_op(m, target_m, new_op_name, new_op):
 def _patch_llama_model(model):
     """
     Patch llama model:
-        1. Use IPEX Rope and Paged cache
+        1. Use IPEX rope and paged cache
         2. Linear fusion with (2 Linears + Silu + Mul) and (Linear + Add)
     """
     convert_functions(model, LlamaModel, "forward", _llama_model_forward)
@@ -87,9 +87,8 @@ def _patch_llama_model(model):
 def _patch_falcon_model(model):
     """
     Patch falcon model:
-        1. Disable SDPA so the attention mask will be compatible to ipex attention.
-        2. Use IPEX Rope and paged cache
-        3. Linear fusion with (Linear + Gelu) and (Linear + Add + Add)
+        1. Use IPEX rope and paged cache
+        2. Linear fusion with (Linear + Gelu) and (Linear + Add + Add)
     """
     num_key_value_heads = (
         model.config.num_kv_heads if (model.config.new_decoder_architecture or not model.config.multi_query) else 1
@@ -104,8 +103,7 @@ def _patch_falcon_model(model):
 def _patch_gpt2_model(model):
     """
     Patch gpt2 model:
-        1. Disable SDPA so the attention mask will be compatible to ipex attention.
-        2. Use IAKV cache
+        1. Use IPEX paged attention
     """
     num_key_value_heads = model.config.num_attention_heads
     setattr(model.config, "num_key_value_heads", num_key_value_heads)
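
Note: the docstrings above describe the IPEX patching approach — module forwards are swapped for IPEX-optimized ones via helpers such as convert_functions. The snippet below is only a minimal sketch of that rebinding pattern, not the repository's actual implementation; the helper name convert_functions_sketch is invented for this note.

import types
import torch

def convert_functions_sketch(module: torch.nn.Module, target_cls, name, new_fn):
    # Rebind the attribute `name` (e.g. "forward") on every sub-module that is
    # an instance of `target_cls`, binding `new_fn` as a method of that instance.
    for submodule in module.modules():
        if isinstance(submodule, target_cls):
            setattr(submodule, name, types.MethodType(new_fn, submodule))
    return module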
21 changes: 8 additions & 13 deletions optimum/intel/ipex/modeling_base.py
@@ -42,7 +42,6 @@
 from transformers.modeling_outputs import CausalLMOutputWithPast
 from transformers.models.auto.auto_factory import _get_model_class as get_model_class
 
-from optimum.exporters import TasksManager
 from optimum.modeling_base import OptimizedModel
 from optimum.utils import NormalizedConfigManager
 
@@ -103,6 +102,11 @@ def __init__(
         if hasattr(self.auto_model_class, "register"):
             self.auto_model_class.register(AutoConfig, self.__class__)
 
+        # Non-generation tasks can use torch.compile to get acceleration.
+        if self.export_feature not in _IPEX_EXPORTED_GENERATION_TASKS:
+            logger.info("Enable torch.compile optimization, please warm up by your real case inputs")
+            self.model.forward = torch.compile(self.model.forward)
+
     @classmethod
     def _from_transformers(cls, *args, **kwargs):
         return cls._from_pretrained(*args, **kwargs)
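
Note: the added block above is the core of this commit — when export_feature is not a generation task, the model's forward is wrapped in torch.compile, and the log message asks users to warm up with inputs shaped like their real workload so compilation cost is paid before serving. A hedged usage sketch, assuming text-classification is outside _IPEX_EXPORTED_GENERATION_TASKS and using a placeholder checkpoint; exact loading flags may differ by optimum-intel version:

from transformers import AutoTokenizer
from optimum.intel import IPEXModelForSequenceClassification

model_id = "distilbert-base-uncased-finetuned-sst-2-english"  # placeholder checkpoint
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = IPEXModelForSequenceClassification.from_pretrained(model_id)

# Warm up: the first calls trigger torch.compile tracing/compilation.
warmup = tokenizer("warm-up text shaped like production traffic", return_tensors="pt")
for _ in range(2):
    model(**warmup)

# Subsequent calls with similar shapes run the compiled forward.
logits = model(**tokenizer("this movie was great", return_tensors="pt")).logits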
@@ -165,27 +169,18 @@ def _from_pretrained(
             )
             token = use_auth_token
 
-        commit_hash = kwargs.pop("_commit_hash", None)
-
         model_kwargs = {
             "revision": revision,
             "token": token,
             "cache_dir": cache_dir,
             "subfolder": subfolder,
             "local_files_only": local_files_only,
             "force_download": force_download,
+            "torch_dtype": torch_dtype,
+            "trust_remote_code": trust_remote_code,
         }
 
-        task = cls.export_feature
-        model = TasksManager.get_model_from_task(
-            task,
-            model_id,
-            library_name="transformers",
-            trust_remote_code=trust_remote_code,
-            torch_dtype=torch_dtype,
-            _commit_hash=commit_hash,
-            **model_kwargs,
-        )
+        model = cls.auto_model_class.from_pretrained(model_id, **model_kwargs)
         config = model.config
         return cls(model, config=config, export=True, **kwargs)
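
Note: the hunk above drops the TasksManager.get_model_from_task indirection in favor of the class's own auto_model_class, so torch_dtype and trust_remote_code now travel inside model_kwargs straight to from_pretrained. Roughly, and only as an illustration with placeholder values, the new loading path amounts to:

from transformers import AutoModelForSequenceClassification

model_kwargs = {
    "revision": None,
    "cache_dir": None,
    "subfolder": "",
    "local_files_only": False,
    "force_download": False,
    "torch_dtype": "auto",        # illustrative value
    "trust_remote_code": False,
}
model = AutoModelForSequenceClassification.from_pretrained(
    "distilbert-base-uncased-finetuned-sst-2-english",  # placeholder checkpoint
    **model_kwargs,
)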

