
Commit be142b5

jiqing-feng authored and echarlaix committed
Add INC modeling position_ids generation (huggingface#456)
* add position_ids in forward
* check if jit model need position_ids
* use MODEL_TYPES_REQUIRING_POSITION_IDS
* fix has_position_ids
* fix position_ids length
* rm useless params
* check model inputs by input names
* fix format
* check input names in graph model
* fix style
* consider eager model in input_names
* add input names
* add text input names
* fix styl;e
* Update optimum/intel/generation/modeling.py
* fix format
* Update optimum/intel/generation/modeling.py

---------

Co-authored-by: Ella Charlaix <ella@huggingface.co>
Co-authored-by: Ella Charlaix <80481427+echarlaix@users.noreply.github.com>
1 parent 26e850c commit be142b5

File tree

optimum/intel/generation/modeling.py
tests/generation/test_modeling.py

2 files changed: +34 additions, -16 deletions

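In short, forward() stops consulting the static MODEL_TYPES_REQUIRING_POSITION_IDS table from optimum.exporters.onnx and instead decides, from the input names recorded off the TorchScript graph, whether position_ids must be built and passed. The gating logic, pulled out of the patch into a standalone sketch (the helper name and its arguments are illustrative, not part of the commit):

from typing import Optional, Set

import torch


def maybe_add_position_ids(
    inputs: dict,
    input_names: Set[str],
    attention_mask: torch.Tensor,
    position_ids: Optional[torch.Tensor] = None,
    has_past: bool = False,
) -> dict:
    # Build position_ids only when the traced graph actually declares them.
    if "position_ids" in input_names and position_ids is None:
        position_ids = attention_mask.long().cumsum(-1) - 1
        position_ids.masked_fill_(attention_mask == 0, 1)
        if has_past:
            # With a KV cache, only the newest token's position is needed.
            position_ids = position_ids[:, -1].unsqueeze(-1)

    # Pass them when the graph expects them, or when input_names is empty
    # (the eager-model case, where the signature was not introspected).
    if "position_ids" in input_names or not input_names:
        inputs["position_ids"] = position_ids
    return inputs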

optimum/intel/generation/modeling.py

Lines changed: 22 additions & 16 deletions
@@ -21,26 +21,20 @@
 
 import torch
 from huggingface_hub import hf_hub_download
-from transformers import AutoConfig, AutoModelForCausalLM, PretrainedConfig, PreTrainedModel
+from transformers import AutoConfig, AutoModelForCausalLM, GenerationConfig, PretrainedConfig, PreTrainedModel
+from transformers.generation import GenerationMixin
 from transformers.modeling_outputs import CausalLMOutputWithPast
 from transformers.utils import WEIGHTS_NAME
 
 from optimum.exporters import TasksManager
-from optimum.exporters.onnx import MODEL_TYPES_REQUIRING_POSITION_IDS
 from optimum.modeling_base import OptimizedModel
 from optimum.utils import NormalizedConfigManager
 
 from ..utils.constant import _TASK_ALIASES
-from ..utils.import_utils import is_torch_version, is_transformers_version
+from ..utils.import_utils import is_torch_version
 from ..utils.modeling_utils import MULTI_QUERY_ATTN_MODELS, patch_decoder_attention_mask
 
 
-if is_transformers_version("<", "4.25.0"):
-    from transformers.generation_utils import GenerationMixin
-else:
-    from transformers.generation import GenerationMixin
-
-
 logger = logging.getLogger(__name__)
@@ -112,12 +106,14 @@ def __init__(
         self.normalized_config = NormalizedConfigManager.get_normalized_config_class(config.model_type)(config)
         self.model_dtype = kwargs.get("model_dtype", None)
 
-        if is_transformers_version("<=", "4.25.1"):
-            self.generation_config = None
+        if isinstance(model, torch.jit.ScriptModule):
+            self.input_names = {
+                inputs.debugName().split(".")[0] for inputs in model.graph.inputs() if inputs.debugName() != "self"
+            }
         else:
-            from transformers import GenerationConfig
+            self.input_names = set()
 
-            self.generation_config = GenerationConfig.from_model_config(config)
+        self.generation_config = GenerationConfig.from_model_config(config)
 
         # Avoid warnings when creating a transformers pipeline
         AutoConfig.register(self.base_model_prefix, AutoConfig)
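The set comprehension above is how the constructor discovers which inputs a torch.jit.ScriptModule was traced with: every graph input except self contributes its debugName(), with any trace suffix (e.g. "input_ids.1") stripped off. A self-contained illustration on a toy module (the module is invented for the example; the sketch filters on the split name to stay robust across PyTorch versions, whereas the patch compares the raw debugName() against "self"):

import torch


class Toy(torch.nn.Module):
    def forward(self, input_ids, attention_mask):
        # Stand-in computation; only the argument names matter here.
        return input_ids * attention_mask


example = (torch.ones(1, 4, dtype=torch.long), torch.ones(1, 4, dtype=torch.long))
traced = torch.jit.trace(Toy(), example)

input_names = {
    i.debugName().split(".")[0]
    for i in traced.graph.inputs()
    if i.debugName().split(".")[0] != "self"
}
print(input_names)  # expected: {"input_ids", "attention_mask"}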
@@ -267,6 +263,7 @@ def forward(
         position_ids: Optional[torch.FloatTensor] = None,
         **kwargs,
     ) -> CausalLMOutputWithPast:
+        # 1. Prepare model inputs
         if attention_mask is None:
             attention_mask = torch.ones_like(input_ids)
 
@@ -275,6 +272,15 @@ def forward(
             "attention_mask": attention_mask,
         }
 
+        if "position_ids" in self.input_names and position_ids is None:
+            position_ids = attention_mask.long().cumsum(-1) - 1
+            position_ids.masked_fill_(attention_mask == 0, 1)
+            if past_key_values:
+                position_ids = position_ids[:, -1].unsqueeze(-1)
+
+        if "position_ids" in self.input_names or not self.input_names:
+            inputs["position_ids"] = position_ids
+
         model_type = self.config.model_type.replace("_", "-")
 
         if self.use_cache:
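The added block is the usual left-padding-aware recipe for deriving positions from the attention mask: cumulative-sum, shift by one, and overwrite padded slots with a dummy value of 1. A quick worked example (the tensor values are illustrative, not from the repository's tests):

import torch

attention_mask = torch.tensor([[0, 0, 1, 1, 1],
                               [1, 1, 1, 1, 1]])

position_ids = attention_mask.long().cumsum(-1) - 1
# tensor([[-1, -1,  0,  1,  2],
#         [ 0,  1,  2,  3,  4]])

position_ids.masked_fill_(attention_mask == 0, 1)
# tensor([[1, 1, 0, 1, 2],
#         [0, 1, 2, 3, 4]])

# With past_key_values present, only the last position is fed to the model:
last = position_ids[:, -1].unsqueeze(-1)
# tensor([[2],
#         [4]])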
@@ -308,17 +314,17 @@ def forward(
 
         inputs["past_key_values"] = past_key_values
 
-        if position_ids is not None and model_type in MODEL_TYPES_REQUIRING_POSITION_IDS:
-            inputs["position_ids"] = position_ids
-
+        # 2. Model forward
         outputs = self.model(**inputs)
 
+        # 3. Process model outputs
         if isinstance(outputs, (list, tuple)):
             logits = outputs[0]
             past_key_values = outputs[1] if self.use_cache else None
         else:
             logits = outputs["logits"]
             past_key_values = outputs["past_key_values"] if self.use_cache else None
+
         return CausalLMOutputWithPast(logits=logits, past_key_values=past_key_values)
tests/generation/test_modeling.py

Lines changed: 12 additions & 0 deletions
@@ -160,3 +160,15 @@ def test_compare_with_and_without_past_key_values(self):
             f"With pkv latency: {with_pkv_timer.elapsed:.3f} ms, without pkv latency: {without_pkv_timer.elapsed:.3f} ms,"
             f" speedup: {without_pkv_timer.elapsed / with_pkv_timer.elapsed:.3f}",
         )
+
+    @parameterized.expand(SUPPORTED_ARCHITECTURES)
+    def test_input_names(self, model_arch):
+        model_id = MODEL_NAMES[model_arch]
+        model = TSModelForCausalLM.from_pretrained(model_id, export=True)
+        self.assertTrue(isinstance(model.input_names, set))
+        self.assertTrue("input_ids" in model.input_names)
+        self.assertTrue("attention_mask" in model.input_names)
+        if model.use_cache:
+            self.assertTrue("past_key_values" in model.input_names)
+        else:
+            self.assertTrue("past_key_values" not in model.input_names)
