
fix forward without pkv
Signed-off-by: jiqing-feng <jiqing.feng@intel.com>
jiqing-feng committed Dec 4, 2024
1 parent b84274c commit 8113800
Showing 1 changed file with 4 additions and 5 deletions.
9 changes: 4 additions & 5 deletions optimum/exporters/ipex/modeling_utils.py
@@ -208,7 +208,6 @@ def _llama_model_forward(
         hidden_states = hidden_states.view(-1, hidden_states.shape[-1])
 
     input_lens = attention_mask.cumsum(-1)[:, -1].to(torch.int32)
-    setattr(past_key_values, "input_lens", input_lens)
 
     for idx, decoder_layer in enumerate(self.layers):
         if output_hidden_states:
@@ -222,6 +221,7 @@ def _llama_model_forward(
             output_attentions=output_attentions,
             use_cache=use_cache,
             position_embeddings=position_embeddings,
+            input_lens=input_lens,
         )
 
         hidden_states = layer_outputs[0]
@@ -322,7 +322,6 @@ def _falcon_model_forward(
     else:
         hidden_states = hidden_states.view(-1, hidden_states.shape[-1])
     input_lens = attention_mask.cumsum(-1)[:, -1].to(torch.int32)
-    setattr(past_key_values, "input_lens", input_lens)
 
     next_decoder_cache = None
     all_self_attentions = () if output_attentions else None
@@ -343,6 +342,7 @@ def _falcon_model_forward(
             alibi=None,
             cache_position=cache_position,
             position_embeddings=position_embeddings,
+            input_lens=input_lens,
         )
 
         hidden_states = outputs[0]
@@ -445,8 +445,6 @@ def _gpt2_model_forward(
         hidden_states = hidden_states.view(-1, hidden_states.shape[-1])
 
     input_lens = attention_mask.cumsum(-1)[:, -1].to(torch.int32)
-    if past_key_values is not None:
-        setattr(past_key_values, "input_lens", input_lens)
 
     presents = None
     all_self_attentions = () if output_attentions else None
@@ -465,6 +463,7 @@ def _gpt2_model_forward(
             encoder_attention_mask=encoder_attention_mask,
             use_cache=use_cache,
             output_attentions=output_attentions,
+            input_lens=input_lens,
         )
 
         hidden_states = outputs[0]
@@ -535,7 +534,7 @@ def forward(
     ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
         if past_key_value is None and kwargs.get("layer_past", None) is not None:
             past_key_value = kwargs.pop("layer_past", None)
-        input_lens = getattr(past_key_value, "input_lens", None)
+        input_lens = kwargs.pop("input_lens", None)
         past_len = 0
         if past_key_value is not None:
             past_len = past_key_value.get_seq_length()
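A minimal, hypothetical sketch of why this fixes forward without pkv (the names below are illustrative, not the actual optimum-intel classes): when forward runs without past_key_values, the old pattern of stashing input_lens on the cache object via setattr has nothing to attach to, whereas passing input_lens as an explicit keyword argument and popping it from kwargs in the attention forward works with or without a cache.

import torch

# Hypothetical standalone example, not the real optimum-intel code.
attention_mask = torch.tensor([[1, 1, 1, 0]])
input_lens = attention_mask.cumsum(-1)[:, -1].to(torch.int32)

past_key_values = None  # forward called without pkv (e.g. use_cache=False)

# Old pattern: attaching input_lens to the cache object breaks when it is None.
try:
    setattr(past_key_values, "input_lens", input_lens)
except AttributeError as err:
    print("old pattern fails without pkv:", err)

# New pattern: the model forward passes input_lens explicitly and the
# attention layer pops it from kwargs, so no cache object is required.
def attention_forward(hidden_states, past_key_value=None, **kwargs):
    input_lens = kwargs.pop("input_lens", None)
    return hidden_states, input_lens

_, lens = attention_forward(torch.zeros(4, 8), past_key_value=None, input_lens=input_lens)
print(lens)  # tensor([3], dtype=torch.int32)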
