 import torch
 from huggingface_hub import hf_hub_download
-from transformers import AutoConfig, AutoModelForCausalLM, PretrainedConfig, PreTrainedModel
+from transformers import AutoConfig, AutoModelForCausalLM, GenerationConfig, PretrainedConfig, PreTrainedModel
+from transformers.generation import GenerationMixin
 from transformers.modeling_outputs import CausalLMOutputWithPast
 from transformers.utils import WEIGHTS_NAME

 from optimum.exporters import TasksManager
-from optimum.exporters.onnx import MODEL_TYPES_REQUIRING_POSITION_IDS
 from optimum.modeling_base import OptimizedModel
 from optimum.utils import NormalizedConfigManager

 from ..utils.constant import _TASK_ALIASES
-from ..utils.import_utils import is_torch_version, is_transformers_version
+from ..utils.import_utils import is_torch_version
 from ..utils.modeling_utils import MULTI_QUERY_ATTN_MODELS, patch_decoder_attention_mask


-if is_transformers_version("<", "4.25.0"):
-    from transformers.generation_utils import GenerationMixin
-else:
-    from transformers.generation import GenerationMixin
-
-
 logger = logging.getLogger(__name__)
@@ -112,12 +106,14 @@ def __init__(
         self.normalized_config = NormalizedConfigManager.get_normalized_config_class(config.model_type)(config)
         self.model_dtype = kwargs.get("model_dtype", None)

-        if is_transformers_version("<=", "4.25.1"):
-            self.generation_config = None
+        if isinstance(model, torch.jit.ScriptModule):
+            self.input_names = {
+                inputs.debugName().split(".")[0] for inputs in model.graph.inputs() if inputs.debugName() != "self"
+            }
         else:
-            from transformers import GenerationConfig
+            self.input_names = set()

-            self.generation_config = GenerationConfig.from_model_config(config)
+        self.generation_config = GenerationConfig.from_model_config(config)

         # Avoid warnings when creating a transformers pipeline
         AutoConfig.register(self.base_model_prefix, AutoConfig)
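
Aside: the new torch.jit.ScriptModule branch discovers the model's expected inputs by walking the traced graph rather than relying on a hard-coded list. A minimal sketch of that introspection (not part of this commit; TinyModel is a hypothetical stand-in for a traced decoder):

    import torch

    class TinyModel(torch.nn.Module):
        def forward(self, input_ids, attention_mask):
            return input_ids * attention_mask

    example = (torch.ones(1, 4, dtype=torch.long), torch.ones(1, 4, dtype=torch.long))
    traced = torch.jit.trace(TinyModel(), example)

    # graph.inputs() yields the module itself plus one Value per forward argument;
    # debugName() returns names such as "input_ids" or "input_ids.1", so the
    # numeric suffix is stripped the same way the __init__ branch above does.
    input_names = {
        i.debugName().split(".")[0] for i in traced.graph.inputs() if i.debugName() != "self"
    }
    print(input_names)  # e.g. {'input_ids', 'attention_mask'}
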
@@ -267,6 +263,7 @@ def forward(
         position_ids: Optional[torch.FloatTensor] = None,
         **kwargs,
     ) -> CausalLMOutputWithPast:
+        # 1. Prepare model inputs
         if attention_mask is None:
             attention_mask = torch.ones_like(input_ids)
@@ -275,6 +272,15 @@ def forward(
             "attention_mask": attention_mask,
         }

+        if "position_ids" in self.input_names and position_ids is None:
+            position_ids = attention_mask.long().cumsum(-1) - 1
+            position_ids.masked_fill_(attention_mask == 0, 1)
+            if past_key_values:
+                position_ids = position_ids[:, -1].unsqueeze(-1)
+
+        if "position_ids" in self.input_names or not self.input_names:
+            inputs["position_ids"] = position_ids
+
         model_type = self.config.model_type.replace("_", "-")

         if self.use_cache:
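
Aside: a worked example (not part of this commit) of the position-id derivation added above, for a left-padded sequence. Positions are the running count of attended tokens, padded slots get a dummy value of 1, and once a KV cache is populated only the last position is kept:

    import torch

    attention_mask = torch.tensor([[0, 0, 1, 1, 1]])  # one left-padded sequence

    position_ids = attention_mask.long().cumsum(-1) - 1  # [[-1, -1, 0, 1, 2]]
    position_ids.masked_fill_(attention_mask == 0, 1)    # [[ 1,  1, 0, 1, 2]]
    print(position_ids)

    # With a non-empty cache only the current token's position is needed:
    position_ids = position_ids[:, -1].unsqueeze(-1)
    print(position_ids)  # tensor([[2]])
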
@@ -308,17 +314,17 @@ def forward(
             inputs["past_key_values"] = past_key_values

-        if position_ids is not None and model_type in MODEL_TYPES_REQUIRING_POSITION_IDS:
-            inputs["position_ids"] = position_ids
-
+        # 2. Model forward
         outputs = self.model(**inputs)

+        # 3. Process model outputs
         if isinstance(outputs, (list, tuple)):
             logits = outputs[0]
             past_key_values = outputs[1] if self.use_cache else None
         else:
             logits = outputs["logits"]
             past_key_values = outputs["past_key_values"] if self.use_cache else None
+
         return CausalLMOutputWithPast(logits=logits, past_key_values=past_key_values)
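
Aside: a small sketch (not part of this commit) of the output handling above. Traced TorchScript modules typically return plain tuples while eager transformers models return dict-like outputs, so both shapes are folded into a CausalLMOutputWithPast (the dummy tensor shapes below are illustrative only):

    import torch
    from transformers.modeling_outputs import CausalLMOutputWithPast

    use_cache = True

    def to_causal_lm_output(outputs):
        # Tuples come from TorchScript graphs, mappings from eager models.
        if isinstance(outputs, (list, tuple)):
            logits = outputs[0]
            past_key_values = outputs[1] if use_cache else None
        else:
            logits = outputs["logits"]
            past_key_values = outputs["past_key_values"] if use_cache else None
        return CausalLMOutputWithPast(logits=logits, past_key_values=past_key_values)

    dummy = (torch.zeros(1, 1, 50257), ((torch.zeros(1, 12, 1, 64),) * 2,))
    print(to_causal_lm_output(dummy).logits.shape)  # torch.Size([1, 1, 50257])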