Skip to content

Commit 660632a

Browse files
committed
fix
1 parent d1b54c1 commit 660632a

File tree

1 file changed

+6
-6
lines changed

1 file changed

+6
-6
lines changed

optimum/intel/openvino/modeling_decoder.py

Lines changed: 6 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -42,7 +42,7 @@
4242
from ..utils.modeling_utils import MULTI_QUERY_ATTN_MODELS
4343
from .configuration import _DEFAULT_4BIT_CONFIGS, OVConfig, OVWeightQuantizationConfig, _check_default_4bit_configs
4444
from .modeling import _TOKENIZER_FOR_DOC, INPUTS_DOCSTRING, MODEL_START_DOCSTRING, OVModel
45-
from .utils import ONNX_WEIGHTS_NAME, OV_XML_FILE_NAME, STR_TO_OV_TYPE
45+
from .utils import ONNX_WEIGHTS_NAME, OV_TO_NP_TYPE, OV_XML_FILE_NAME, STR_TO_OV_TYPE
4646

4747

4848
if TYPE_CHECKING:
@@ -409,7 +409,7 @@ def prepare_inputs(
409409
elif self.use_cache:
410410
for input_name in self.key_value_input_names:
411411
model_inputs = self.model.input(input_name)
412-
# dtype = OV_TO_NP_TYPE[model_inputs.get_element_type().get_type_name()]
412+
dtype = OV_TO_NP_TYPE[model_inputs.get_element_type().get_type_name()]
413413
shape = model_inputs.get_partial_shape()
414414
if self.config.model_type == "chatglm":
415415
shape[0] = 0
@@ -420,7 +420,7 @@ def prepare_inputs(
420420
shape[2] = 0
421421
else:
422422
shape[1] = 0
423-
inputs[input_name] = Tensor(model_inputs.get_element_type(), shape.get_shape())
423+
inputs[input_name] = np.empty([dim.get_length() for dim in shape], dtype=dtype)
424424
else:
425425
# past_key_values are not used explicitly, instead they are handled inside the model
426426
if past_key_values is None:
@@ -587,11 +587,11 @@ def _deduplicate_inputs(self, model_inputs: Dict):
587587
)
588588
for input_name, input_tensor in model_inputs.items():
589589
if input_name not in ["input_ids", "beam_idx"]:
590-
if not isinstance(input_tensor, Tensor):
590+
if input_name not in self.key_value_input_names:
591591
upd_model_inputs[input_name] = input_tensor[indicies]
592592
else:
593-
shape = input_tensor.shape
594-
dtype = input_tensor.element_type
593+
shape = input_tensor.shape if isinstance(input_tensor, Tensor) else list(input_tensor.shape)
594+
dtype = input_tensor.element_type if isinstance(input_tensor, Tensor) else Type(input_tensor.dtype)
595595
upd_batch_size = indicies.shape[0]
596596
if self.config.model_type == "bloom":
597597
upd_batch_size *= self.config.num_attention_heads

0 commit comments

Comments
 (0)