add position_ids in forward #456

Merged · 21 commits · Jan 8, 2024 · Changes from 8 commits
24 changes: 20 additions & 4 deletions optimum/intel/generation/modeling.py
@@ -140,9 +140,7 @@ def prepare_inputs_for_generation(self, input_ids, past_key_values=None, **kwargs):
                 past_key_values = self._convert_to_bloom_cache(past_key_values)
 
         position_ids = kwargs.get("position_ids", None)
-
         attention_mask = kwargs.get("attention_mask", None)
-
         if attention_mask is not None and position_ids is None:
             # create position_ids on the fly for batch generation
             position_ids = attention_mask.long().cumsum(-1) - 1
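
For context, a small worked example (inputs assumed for illustration, not from the PR) of what the cumsum recipe in this hunk produces for a left-padded batch:

import torch

# Left-padded batch: the first row starts with two padding tokens.
attention_mask = torch.tensor([[0, 0, 1, 1, 1],
                               [1, 1, 1, 1, 1]])
position_ids = attention_mask.long().cumsum(-1) - 1
print(position_ids)
# tensor([[-1, -1,  0,  1,  2],
#         [ 0,  1,  2,  3,  4]])
# Padded slots come out as -1; the usual transformers recipe then masks them,
# e.g. position_ids.masked_fill_(attention_mask == 0, 1).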
@@ -264,6 +262,7 @@ def forward(
         }
 
         model_type = self.config.model_type.replace("_", "-")
+        has_position_ids = True if model_type in MODEL_TYPES_REQUIRING_POSITION_IDS else False
Collaborator:
Suggested change:
-        has_position_ids = True if model_type in MODEL_TYPES_REQUIRING_POSITION_IDS else False
+        has_position_ids = model_type in MODEL_TYPES_REQUIRING_POSITION_IDS


         if self.use_cache:
             if past_key_values is None:

@@ -296,8 +295,24 @@ def forward(

inputs["past_key_values"] = past_key_values

if position_ids is not None and model_type in MODEL_TYPES_REQUIRING_POSITION_IDS:
inputs["position_ids"] = position_ids
if has_position_ids and position_ids is not None:
inputs.update({"position_ids": position_ids})
elif has_position_ids and position_ids is None:
seq_length = input_ids.shape[-1]
if not self.use_cache:
past_key_values_length = 0
else:
past_key_values_length = (
past_key_values[0].shape[-2]
if model_type.replace("-", "_") in MULTI_QUERY_ATTN_MODELS
else past_key_values[0][1].shape[-2]
)
position_ids = torch.arange(
past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=self._device
).unsqueeze(0)
inputs.update({"position_ids": position_ids})
elif not has_position_ids and position_ids is not None:
logger.warning("You miss the position_ids in the inputs")
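
For context, a runnable illustration (values assumed) of the fallback above: the new tokens are numbered starting from the length of the KV cache.

import torch

past_key_values_length = 6  # tokens already held in the KV cache
seq_length = 2              # new input tokens in this forward pass

position_ids = torch.arange(
    past_key_values_length, seq_length + past_key_values_length, dtype=torch.long
).unsqueeze(0)
print(position_ids)  # tensor([[6, 7]])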
Collaborator:

I think it would be better to check directly in the model graph whether position_ids is one of the model's expected inputs, if that can be done. If not, this new addition will create issues for all previously exported models (the ones that were exported without any position_ids) across all architectures in MODEL_TYPES_REQUIRING_POSITION_IDS.

Suggested change:
-        if has_position_ids and position_ids is not None:
-            inputs.update({"position_ids": position_ids})
-        elif has_position_ids and position_ids is None:
-            seq_length = input_ids.shape[-1]
-            if not self.use_cache:
-                past_key_values_length = 0
-            else:
-                past_key_values_length = (
-                    past_key_values[0].shape[-2]
-                    if model_type.replace("-", "_") in MULTI_QUERY_ATTN_MODELS
-                    else past_key_values[0][1].shape[-2]
-                )
-            position_ids = torch.arange(
-                past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=self._device
-            ).unsqueeze(0)
-            inputs.update({"position_ids": position_ids})
-        elif not has_position_ids and position_ids is not None:
-            logger.warning("position_ids was passed but is not an expected input for this model; it will be ignored.")
+        if "position_ids" in self.input_names:
+            if position_ids is None:
+                position_ids = ...
+            inputs["position_ids"] = position_ids

Collaborator:

Also, concerning the position_ids (and attention_mask) computation, I think we should do the following:

if "attention_mask" in self.input_names or "position_ids" in self.input_names:
if attention_mask is not None:
attention_mask = np.array(attention_mask)
else:
attention_mask = np.ones(
(input_ids.shape[0], input_ids.shape[1] + past_len), dtype=inputs["input_ids"].dtype
)
if "attention_mask" in self.input_names:
inputs["attention_mask"] = attention_mask
if "position_ids" in self.input_names:
if position_ids is not None:
position_ids = np.array(position_ids)
else:
position_ids = np.cumsum(attention_mask, axis=1) - 1
position_ids[attention_mask == 0] = 1
if past_key_values:
position_ids = np.expand_dims(position_ids[:, -1], axis=-1)
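
A short worked example (inputs assumed) of that recipe, including the cached-decoding step:

import numpy as np

attention_mask = np.array([[0, 0, 1, 1, 1]])  # one left-padded sequence
position_ids = np.cumsum(attention_mask, axis=1) - 1
position_ids[attention_mask == 0] = 1
print(position_ids)  # [[1 1 0 1 2]]

# With a populated KV cache, only the last position is fed per step:
print(np.expand_dims(position_ids[:, -1], axis=-1))  # [[2]]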

Collaborator (Author):

I'm afraid there is no input_names attribute in this class.

Collaborator:

Yes, do you think this could be added?

Collaborator (Author):

Yes, I can try it.
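
A hedged sketch of what that could look like, with all names assumed (not from this PR): record input_names once when wrapping the traced model, so forward can skip position_ids for older exports.

import torch

class ToyCausalLM(torch.nn.Module):
    def forward(self, input_ids, attention_mask):  # exported without position_ids
        return input_ids * attention_mask

class TracedWrapper:
    def __init__(self, traced_model):
        self.model = traced_model
        # Skip graph input 0 (the module itself); keep the real argument names.
        self.input_names = {
            inp.debugName().split(".")[0]
            for inp in list(traced_model.graph.inputs())[1:]
        }

example = (torch.ones(1, 3, dtype=torch.long), torch.ones(1, 3, dtype=torch.long))
wrapper = TracedWrapper(torch.jit.trace(ToyCausalLM(), example))
print("position_ids" in wrapper.input_names)  # False -> don't pass position_ids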


         outputs = self.model(**inputs)

@@ -307,6 +322,7 @@ def forward(
         else:
             logits = outputs["logits"]
         past_key_values = outputs["past_key_values"] if self.use_cache else None
+
         return CausalLMOutputWithPast(logits=logits, past_key_values=past_key_values)

