from ..utils.modeling_utils import MULTI_QUERY_ATTN_MODELS
from .configuration import _DEFAULT_4BIT_CONFIGS, OVConfig, OVWeightQuantizationConfig, _check_default_4bit_configs
from .modeling import _TOKENIZER_FOR_DOC, INPUTS_DOCSTRING, MODEL_START_DOCSTRING, OVModel
-from .utils import ONNX_WEIGHTS_NAME, OV_XML_FILE_NAME, STR_TO_OV_TYPE
+from .utils import ONNX_WEIGHTS_NAME, OV_TO_NP_TYPE, OV_XML_FILE_NAME, STR_TO_OV_TYPE


if TYPE_CHECKING:
@@ -409,7 +409,7 @@ def prepare_inputs(
            elif self.use_cache:
                for input_name in self.key_value_input_names:
                    model_inputs = self.model.input(input_name)
-                    # dtype = OV_TO_NP_TYPE[model_inputs.get_element_type().get_type_name()]
+                    dtype = OV_TO_NP_TYPE[model_inputs.get_element_type().get_type_name()]
                    shape = model_inputs.get_partial_shape()
                    if self.config.model_type == "chatglm":
                        shape[0] = 0
@@ -420,7 +420,7 @@ def prepare_inputs(
                            shape[2] = 0
                        else:
                            shape[1] = 0
-                    inputs[input_name] = Tensor(model_inputs.get_element_type(), shape.get_shape())
+                    inputs[input_name] = np.empty([dim.get_length() for dim in shape], dtype=dtype)
        else:
            # past_key_values are not used explicitly, instead they are handled inside the model
            if past_key_values is None:
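
Note on the hunk above: the empty past_key_values placeholder is now a plain numpy array rather than an openvino Tensor, with the sequence-length dimension zeroed out. A minimal sketch of what such a placeholder amounts to (the cache dimensions below are made up for illustration):

import numpy as np

# A per-layer KV-cache shape whose sequence axis has been set to 0, e.g. after shape[2] = 0.
empty_past = np.empty((1, 32, 0, 128), dtype=np.float32)
print(empty_past.shape)   # (1, 32, 0, 128)
print(empty_past.nbytes)  # 0 -- nothing is allocated for the first generation step
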
@@ -587,11 +587,11 @@ def _deduplicate_inputs(self, model_inputs: Dict):
        )
        for input_name, input_tensor in model_inputs.items():
            if input_name not in ["input_ids", "beam_idx"]:
-                if not isinstance(input_tensor, Tensor):
+                if input_name not in self.key_value_input_names:
                    upd_model_inputs[input_name] = input_tensor[indicies]
                else:
-                    shape = input_tensor.shape
-                    dtype = input_tensor.element_type
+                    shape = input_tensor.shape if isinstance(input_tensor, Tensor) else list(input_tensor.shape)
+                    dtype = input_tensor.element_type if isinstance(input_tensor, Tensor) else Type(input_tensor.dtype)
                    upd_batch_size = indicies.shape[0]
                    if self.config.model_type == "bloom":
                        upd_batch_size *= self.config.num_attention_heads
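
Note on the hunk above: after this change a KV-cache entry reaching _deduplicate_inputs may be either an openvino Tensor or a numpy array, so shape and dtype are recovered per type. A minimal standalone sketch of that dispatch (the helper name is made up; Tensor and Type are assumed to come from openvino.runtime):

import numpy as np
from openvino.runtime import Tensor, Type

def cache_shape_and_dtype(input_tensor):
    # openvino Tensor exposes .shape / .element_type; numpy arrays expose .shape / .dtype
    if isinstance(input_tensor, Tensor):
        return input_tensor.shape, input_tensor.element_type
    return list(input_tensor.shape), Type(input_tensor.dtype)

print(cache_shape_and_dtype(np.empty((1, 32, 0, 128), dtype=np.float32)))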