diff --git a/optimum/exporters/ipex/modeling_utils.py b/optimum/exporters/ipex/modeling_utils.py
index da7a42715b..4a7622265c 100755
--- a/optimum/exporters/ipex/modeling_utils.py
+++ b/optimum/exporters/ipex/modeling_utils.py
@@ -208,7 +208,6 @@ def _llama_model_forward(
         hidden_states = hidden_states.view(-1, hidden_states.shape[-1])
 
     input_lens = attention_mask.cumsum(-1)[:, -1].to(torch.int32)
-    setattr(past_key_values, "input_lens", input_lens)
 
     for idx, decoder_layer in enumerate(self.layers):
         if output_hidden_states:
@@ -222,6 +221,7 @@ def _llama_model_forward(
             output_attentions=output_attentions,
             use_cache=use_cache,
             position_embeddings=position_embeddings,
+            input_lens=input_lens,
         )
 
         hidden_states = layer_outputs[0]
@@ -322,7 +322,6 @@ def _falcon_model_forward(
     else:
         hidden_states = hidden_states.view(-1, hidden_states.shape[-1])
 
     input_lens = attention_mask.cumsum(-1)[:, -1].to(torch.int32)
-    setattr(past_key_values, "input_lens", input_lens)
     next_decoder_cache = None
     all_self_attentions = () if output_attentions else None
@@ -343,6 +342,7 @@ def _falcon_model_forward(
             alibi=None,
             cache_position=cache_position,
             position_embeddings=position_embeddings,
+            input_lens=input_lens,
         )
 
         hidden_states = outputs[0]
@@ -445,8 +445,6 @@ def _gpt2_model_forward(
         hidden_states = hidden_states.view(-1, hidden_states.shape[-1])
 
     input_lens = attention_mask.cumsum(-1)[:, -1].to(torch.int32)
-    if past_key_values is not None:
-        setattr(past_key_values, "input_lens", input_lens)
 
     presents = None
     all_self_attentions = () if output_attentions else None
@@ -465,6 +463,7 @@ def _gpt2_model_forward(
             encoder_attention_mask=encoder_attention_mask,
             use_cache=use_cache,
             output_attentions=output_attentions,
+            input_lens=input_lens,
         )
 
         hidden_states = outputs[0]
@@ -535,7 +534,7 @@ def forward(
     ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
         if past_key_value is None and kwargs.get("layer_past", None) is not None:
             past_key_value = kwargs.pop("layer_past", None)
-        input_lens = getattr(past_key_value, "input_lens", None)
+        input_lens = kwargs.pop("input_lens", None)
         past_len = 0
         if past_key_value is not None:
             past_len = past_key_value.get_seq_length()