
Commit fc6a3a4

fix
1 parent 1d15e35 commit fc6a3a4

File tree

4 files changed (+36, -14 lines)

optimum/intel/ipex/inference.py

Lines changed: 22 additions & 1 deletion

@@ -31,10 +31,24 @@
     IPEXModelForMaskedLM,
     IPEXModelForSequenceClassification,
     IPEXModelForTokenClassification,
+    IPEXBloomForCausalLM,
+    IPEXMPTForCausalLM,
+    IPEXOPTForCausalLM,
+    IPEXGPTBigCodeForCausalLM,
 )
+
+
 from .utils import _HEAD_TO_AUTOMODELS


+_MODEL_TYPE_TO_AUTOMODELS = {
+    "bloom": IPEXBloomForCausalLM,
+    "mpt": IPEXMPTForCausalLM,
+    "opt": IPEXOPTForCausalLM,
+    "gpt-bigcode": IPEXGPTBigCodeForCausalLM,
+}
+
+
 logger = logging.getLogger(__name__)

 IPEX_NOT_AVAILABLE_ERROR_MSG = (

@@ -131,7 +145,14 @@ def __enter__(self):
             )
             if task in _HEAD_TO_AUTOMODELS:
                 model = jit_trace(model, task, use_cache)
-                model = eval(_HEAD_TO_AUTOMODELS[task])(model, self._original.config, use_cache=use_cache)
+                model_type = getattr(self._original.config, "model_type", "").replace("_", "-")
+
+                if task == "text-generation" and model_type in _MODEL_TYPE_TO_AUTOMODELS:
+                    auto_model_class = _MODEL_TYPE_TO_AUTOMODELS[model_type]
+                else:
+                    auto_model_class = eval(_HEAD_TO_AUTOMODELS[task])
+
+                model = auto_model_class(model, self._original.config, use_cache=use_cache)

             # Enable automatic mixed precision (AMP) if we are going to target `bfloat16`
             with torch.cpu.amp.autocast(enabled=self._dtype == torch.bfloat16):
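The change replaces a single `eval` dispatch with a two-step lookup: normalize `config.model_type`, then prefer a model-specific wrapper for text-generation before falling back to the generic task-level class. A minimal, self-contained sketch of that lookup, using strings in place of the real wrapper classes (`DummyConfig` and `resolve_wrapper` are hypothetical names, not part of the codebase):

# Hypothetical stand-ins to illustrate the dispatch; not optimum-intel code.
class DummyConfig:
    def __init__(self, model_type):
        self.model_type = model_type

_MODEL_TYPE_TO_AUTOMODELS = {
    "bloom": "IPEXBloomForCausalLM",
    "mpt": "IPEXMPTForCausalLM",
    "opt": "IPEXOPTForCausalLM",
    "gpt-bigcode": "IPEXGPTBigCodeForCausalLM",
}

def resolve_wrapper(task, config, fallback="IPEXModelForCausalLM"):
    # Same normalization as the diff: Transformers model types use "_",
    # the lookup table uses "-".
    model_type = getattr(config, "model_type", "").replace("_", "-")
    if task == "text-generation" and model_type in _MODEL_TYPE_TO_AUTOMODELS:
        return _MODEL_TYPE_TO_AUTOMODELS[model_type]
    return fallback

assert resolve_wrapper("text-generation", DummyConfig("gpt_bigcode")) == "IPEXGPTBigCodeForCausalLM"
assert resolve_wrapper("text-generation", DummyConfig("llama")) == "IPEXModelForCausalLM"

Note that the table must be keyed on the normalized form ("gpt-bigcode", not "gpt_bigcode"), otherwise the lookup silently falls through to the generic wrapper.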

optimum/intel/ipex/modeling_base.py

Lines changed: 11 additions & 10 deletions

@@ -158,17 +158,10 @@ def _from_pretrained(

         model = torch.jit.load(model_cache_path)
         torch.jit.freeze(model.eval())
-
+        model_type = config.model_type.replace("_", "-")
         init_cls = cls
-        if cls is IPEXModelForCausalLM:
-            if config.model_type == "bloom":
-                init_cls = IPEXBloomForCausalLM
-            elif config.model_type == "mpt":
-                init_cls = IPEXMPTForCausalLM
-            elif config.model_type == "opt":
-                init_cls = IPEXOPTForCausalLM
-            elif config.model_type == "gpt_bigcode":
-                init_cls = IPEXGPTBigCodeForCausalLM
+        if cls.export_feature == "text-generation" and model_type in _MODEL_TYPE_TO_AUTOMODELS:
+            init_cls = _MODEL_TYPE_TO_AUTOMODELS[model_type]

         return init_cls(model, config=config, model_save_dir=model_save_dir, **kwargs)

@@ -535,3 +528,11 @@ def prepare_inputs_for_generation(self, input_ids, past_key_values=None, **kwargs):
             "position_ids": None,
             "attention_mask": attention_mask,
         }
+
+
+_MODEL_TYPE_TO_AUTOMODELS = {
+    "bloom": IPEXBloomForCausalLM,
+    "mpt": IPEXMPTForCausalLM,
+    "opt": IPEXOPTForCausalLM,
+    "gpt-bigcode": IPEXGPTBigCodeForCausalLM,
+}
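A hedged sketch of what this buys at load time, assuming the usual `from_pretrained(..., export=True)` entry point exported by optimum-intel (the model id is illustrative):

from optimum.intel import IPEXModelForCausalLM

# export_feature of IPEXModelForCausalLM is "text-generation" and the
# config's model_type is "bloom", so _from_pretrained is expected to
# return the dedicated IPEXBloomForCausalLM subclass.
model = IPEXModelForCausalLM.from_pretrained("bigscience/bloom-560m", export=True)
print(type(model).__name__)  # expected: IPEXBloomForCausalLM

Since `_MODEL_TYPE_TO_AUTOMODELS` sits at module scope below the classes it references, it is fully populated by the time `_from_pretrained` executes.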

optimum/intel/ipex/utils.py

Lines changed: 1 addition & 1 deletion

@@ -17,5 +17,5 @@
     "text-generation": "IPEXModelForCausalLM",
     "text-classification": "IPEXModelForSequenceClassification",
     "token-classification": "IPEXModelForTokenClassification",
-    "question-answering": "IPEXModelForQuestionAnswering",
+    # "question-answering": "IPEXModelForQuestionAnswering",
 }
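With the question-answering entry commented out, the `if task in _HEAD_TO_AUTOMODELS` guard in `inference.py` no longer fires for QA pipelines, so they skip the jit-trace-and-wrap step. A small sketch of that guard against the mapping as it stands after this change:

_HEAD_TO_AUTOMODELS = {
    "text-generation": "IPEXModelForCausalLM",
    "text-classification": "IPEXModelForSequenceClassification",
    "token-classification": "IPEXModelForTokenClassification",
    # "question-answering": "IPEXModelForQuestionAnswering",
}

for task in ("text-generation", "question-answering"):
    # Mirrors the guard in inference.py: unmapped tasks keep the original model.
    print(task, "->", "traced and wrapped" if task in _HEAD_TO_AUTOMODELS else "left untouched")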

tests/ipex/test_inference.py

Lines changed: 2 additions & 2 deletions

@@ -28,7 +28,7 @@
 )

 from optimum.intel import inference_mode as ipex_inference_mode
-from optimum.intel.generation.modeling import TSModelForCausalLM
+from optimum.intel.ipex.modeling_base import IPEXModel


 MODEL_NAMES = {

@@ -112,6 +112,6 @@ def test_text_generation_pipeline_inference(self, model_arch):
                 text_generator, dtype=model.config.torch_dtype, verbose=False, jit=True
             ) as ipex_text_generator:
                 output_ipex = ipex_text_generator(inputs)
-                self.assertTrue(isinstance(ipex_text_generator.model._optimized, TSModelForCausalLM))
+                self.assertTrue(isinstance(ipex_text_generator.model._optimized, IPEXModel))
                 self.assertTrue(isinstance(ipex_text_generator.model._optimized.model, torch.jit.RecursiveScriptModule))
                 self.assertEqual(output[0]["generated_text"], output_ipex[0]["generated_text"])
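For orientation, a sketch of the pipeline-level flow this test exercises, assuming a standard transformers text-generation pipeline (model id and prompt are illustrative, not taken from the test file):

from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline

from optimum.intel import inference_mode as ipex_inference_mode

model_id = "gpt2"  # illustrative; the test parametrizes over MODEL_NAMES
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id)
text_generator = pipeline("text-generation", model=model, tokenizer=tokenizer)

with ipex_inference_mode(
    text_generator, dtype=model.config.torch_dtype, verbose=False, jit=True
) as ipex_text_generator:
    output = ipex_text_generator("This is a sample", max_new_tokens=10)
    # After this commit, the optimized handle is an IPEXModel subclass whose
    # .model attribute is the underlying torch.jit.RecursiveScriptModule.
    print(type(ipex_text_generator.model._optimized).__name__)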
