
add position_ids in forward #456

Merged (21 commits) on Jan 8, 2024

Conversation

@jiqing-feng (Collaborator) commented on Oct 17, 2023:

Hi @echarlaix,

Do you think we should add position_ids to the forward of the generation model? Optimum already supports generating position_ids as of this PR.

cc @changwangss

@HuggingFaceDocBuilderDev commented:

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint.

@echarlaix (Collaborator) left a comment:

Thanks for the addition @jiqing-feng. You're right, we need to add support for position_ids now that it has been integrated into Optimum.

@@ -88,7 +89,7 @@ def jit_trace(model: PreTrainedModel, task: str, use_cache: bool = False):
        traced_model(**model_inputs)
        traced_model(**model_inputs)

-    return traced_model
+    return traced_model, has_position_ids
@echarlaix (Collaborator) commented:

I think we should keep jit_trace as it is

Suggested change:
-    return traced_model, has_position_ids
+    return traced_model
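For background, the helper under discussion follows the usual trace, freeze, and warm-up pattern. Below is a minimal standalone sketch of that pattern, not the repo's actual jit_trace implementation; it assumes PyTorch 2.x and transformers, and uses gpt2 purely as an example:

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# Sketch only: trace a causal LM with dummy keyword inputs, freeze the graph,
# then call it twice so later calls run the fully optimized graph.
model = AutoModelForCausalLM.from_pretrained("gpt2", torchscript=True)
model.eval()
tokenizer = AutoTokenizer.from_pretrained("gpt2")
model_inputs = dict(tokenizer("a dummy input", return_tensors="pt"))

with torch.no_grad():
    traced_model = torch.jit.trace(model, example_kwarg_inputs=model_inputs, strict=False)
    traced_model = torch.jit.freeze(traced_model)
    traced_model(**model_inputs)  # first call triggers graph optimization
    traced_model(**model_inputs)  # second call exercises the optimized graph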

@@ -116,6 +118,7 @@ def __init__(
        self._device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
        self.normalized_config = NormalizedConfigManager.get_normalized_config_class(config.model_type)(config)
        self.model_dtype = kwargs.get("model_dtype", None)
+       self.has_position_ids = has_position_ids
@echarlaix (Collaborator) commented:

No need to have an attribute; we can use MODEL_TYPES_REQUIRING_POSITION_IDS directly.

@jiqing-feng (Collaborator, Author) replied:

Thanks, I will use it.

Comment on lines 320 to 334
position_ids = kwargs.get("position_ids", None)
if self.has_position_ids and position_ids is not None:
    inputs.update({"position_ids": position_ids})
elif self.has_position_ids and position_ids is None:
    seq_length = input_ids.shape[-1]
    if not self.use_cache:
        past_key_values_length = 0
    else:
        past_key_values_length = past_key_values[0][1].shape[-2]
    position_ids = torch.arange(
        past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=self._device
    ).unsqueeze(0)
    inputs.update({"position_ids": position_ids})
elif not self.has_position_ids and position_ids is not None:
    logger.warning("You miss the position_ids in the inputs")
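For illustration, a tiny standalone example of what the torch.arange computation above produces during one decoding step (the lengths are arbitrary):

import torch

# With 4 cached tokens and 1 new token, the new token sits at position 4.
past_key_values_length = 4
seq_length = 1
position_ids = torch.arange(
    past_key_values_length, seq_length + past_key_values_length, dtype=torch.long
).unsqueeze(0)
print(position_ids)  # tensor([[4]])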
@echarlaix (Collaborator) commented:

I don't think we should generate the position_ids here, since you already added it in prepare_inputs_for_generation. I would just pass it when needed by checking the graph, as done in https://github.com/huggingface/optimum/blob/e7bd60dd2c1e295263ba57a4e468a62ab5b179e8/optimum/onnxruntime/modeling_decoder.py#L229-L232
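For context, here is a hedged sketch of the input-name check the linked optimum code performs ("model.onnx" is a placeholder path, not a file from this PR):

import onnxruntime as ort

# Ask the ONNX Runtime session which inputs the exported graph declares,
# and only feed position_ids when the graph actually expects it.
session = ort.InferenceSession("model.onnx", providers=["CPUExecutionProvider"])
input_names = {node.name for node in session.get_inputs()}
needs_position_ids = "position_ids" in input_names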

@jiqing-feng (Collaborator, Author) replied:

Yes, that is more reasonable. However, for generation tasks, different decoding strategies produce different inputs. For example, llama's greedy_search includes position_ids in the inputs, but assisted_decoding only passes input_ids. Besides, we already generate attention_mask in the forward. WDYT?

@echarlaix (Collaborator) replied on Oct 20, 2023:

I see your point, and I'm OK with the modification, but I think we need to add a test for every architecture to verify we create it correctly. For example, is past_key_values_length = past_key_values[0][1].shape[-2] correct for every architecture? (It looks like it from the empty pkv generation above, but I would like to verify, also to make sure this stays compatible if we add support for new architectures.)

@jiqing-feng (Collaborator, Author) commented:

Hi @echarlaix. I have fixed all the tests; could you please review these changes? Thanks!

@jiqing-feng (Collaborator, Author) commented:

Hi @echarlaix, could you take a look at this change? Thanks 😄!

@jiqing-feng (Collaborator, Author) commented:

Hi @echarlaix. Could you please review these changes? This change avoids forward failures caused by position_ids. For now, position_ids behaves like past_key_values in some models: it should be included in the model inputs, since the JIT-trace dummy inputs contain position_ids.

Comment on lines 298 to 315
if has_position_ids and position_ids is not None:
    inputs.update({"position_ids": position_ids})
elif has_position_ids and position_ids is None:
    seq_length = input_ids.shape[-1]
    if not self.use_cache:
        past_key_values_length = 0
    else:
        past_key_values_length = (
            past_key_values[0].shape[-2]
            if model_type.replace("-", "_") in MULTI_QUERY_ATTN_MODELS
            else past_key_values[0][1].shape[-2]
        )
    position_ids = torch.arange(
        past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=self._device
    ).unsqueeze(0)
    inputs.update({"position_ids": position_ids})
elif not has_position_ids and position_ids is not None:
    logger.warning("You miss the position_ids in the inputs")
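To make the two past_key_values layouts in this branch concrete, here is a small shape illustration; the layouts are assumptions based on common transformers conventions, with gpt_bigcode as a typical multi-query model:

import torch

# Most decoders cache a (key, value) pair per layer, each of shape
# [batch, num_heads, past_len, head_dim], so past_key_values[0][1].shape[-2]
# is the cached length.
batch, heads, past_len, head_dim = 1, 12, 4, 64
standard_pkv = ((torch.zeros(batch, heads, past_len, head_dim),) * 2,)
print(standard_pkv[0][1].shape[-2])  # 4

# Multi-query models fuse key and value into a single tensor per layer,
# hence past_key_values[0].shape[-2].
mqa_pkv = (torch.zeros(batch, past_len, 2 * head_dim),)
print(mqa_pkv[0].shape[-2])  # 4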
@echarlaix (Collaborator) commented:

I think it would be better to check directly in the model graph whether position_ids is one of the model's expected inputs, if that can be done. If not, this new addition will create issues for all previously exported models (the ones that were exported without any position_ids) across all architectures in MODEL_TYPES_REQUIRING_POSITION_IDS.

Suggested change:
-if has_position_ids and position_ids is not None:
-    inputs.update({"position_ids": position_ids})
-elif has_position_ids and position_ids is None:
-    seq_length = input_ids.shape[-1]
-    if not self.use_cache:
-        past_key_values_length = 0
-    else:
-        past_key_values_length = (
-            past_key_values[0].shape[-2]
-            if model_type.replace("-", "_") in MULTI_QUERY_ATTN_MODELS
-            else past_key_values[0][1].shape[-2]
-        )
-    position_ids = torch.arange(
-        past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=self._device
-    ).unsqueeze(0)
-    inputs.update({"position_ids": position_ids})
-elif not has_position_ids and position_ids is not None:
-    logger.warning("You miss the position_ids in the inputs")
+if "position_ids" in self.input_names:
+    if position_ids is None:
+        position_ids = ...
+    inputs["position_ids"] = position_ids

@echarlaix (Collaborator) commented:

Also, concerning the position_ids (and attention_mask) computation, I think we should do as follows:

if "attention_mask" in self.input_names or "position_ids" in self.input_names:
if attention_mask is not None:
attention_mask = np.array(attention_mask)
else:
attention_mask = np.ones(
(input_ids.shape[0], input_ids.shape[1] + past_len), dtype=inputs["input_ids"].dtype
)
if "attention_mask" in self.input_names:
inputs["attention_mask"] = attention_mask
if "position_ids" in self.input_names:
if position_ids is not None:
position_ids = np.array(position_ids)
else:
position_ids = np.cumsum(attention_mask, axis=1) - 1
position_ids[attention_mask == 0] = 1
if past_key_values:
position_ids = np.expand_dims(position_ids[:, -1], axis=-1)
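A small worked example of the masked-cumsum trick in the snippet above, using a left-padded batch with arbitrary values:

import numpy as np

# Two padding tokens followed by three real tokens.
attention_mask = np.array([[0, 0, 1, 1, 1]])
position_ids = np.cumsum(attention_mask, axis=1) - 1
position_ids[attention_mask == 0] = 1
print(position_ids)  # [[1 1 0 1 2]]: real tokens get positions 0, 1, 2

# During decoding (past_key_values present), only the newest position is kept.
decode_step = np.expand_dims(position_ids[:, -1], axis=-1)
print(decode_step)  # [[2]]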

@jiqing-feng (Collaborator, Author) replied:

I am afraid there is no input_names attribute in this class.

@echarlaix (Collaborator) replied:

Yes, do you think this could be added?

@jiqing-feng (Collaborator, Author) replied:

Yes, I can try it.

@@ -264,6 +262,7 @@ def forward(
        }

        model_type = self.config.model_type.replace("_", "-")
+       has_position_ids = True if model_type in MODEL_TYPES_REQUIRING_POSITION_IDS else False
@echarlaix (Collaborator) commented:

Suggested change:
-has_position_ids = True if model_type in MODEL_TYPES_REQUIRING_POSITION_IDS else False
+has_position_ids = model_type in MODEL_TYPES_REQUIRING_POSITION_IDS

@jiqing-feng (Collaborator, Author) commented on Dec 14, 2023:

Hi @echarlaix. I am afraid input_names is not a great approach. For example, model_input_names in GPT2 is ["input_ids", "attention_mask"], see here. However, GPT2 still prepares position_ids for generation tasks, even though it is not in model_input_names, see here.

I think it would be better to keep the current approach and not use input_names, since our goal is to let JIT-traced models support generation tasks. WDYT?
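To illustrate the point, here is a short check one could run; this reflects behavior observed in recent transformers releases and should be treated as an assumption rather than a guarantee:

from transformers import AutoModelForCausalLM, AutoTokenizer

# GPT-2's tokenizer does not list position_ids among its model_input_names,
# yet prepare_inputs_for_generation still builds them from the attention mask.
tokenizer = AutoTokenizer.from_pretrained("gpt2")
print(tokenizer.model_input_names)  # ['input_ids', 'attention_mask']

model = AutoModelForCausalLM.from_pretrained("gpt2")
encoded = tokenizer("hello world", return_tensors="pt")
prepared = model.prepare_inputs_for_generation(**encoded)
print("position_ids" in prepared)  # True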

@echarlaix (Collaborator) replied:

I was suggesting to check the model graph directly, as done in https://github.com/huggingface/optimum-intel/blob/v1.12.1/optimum/intel/openvino/modeling_base.py#L82 (to check whether position_ids is one of the model's expected inputs). If that can't be done, this PR might result in issues for all previously exported models (the ones that were exported without any position_ids) across all architectures in MODEL_TYPES_REQUIRING_POSITION_IDS.

@jiqing-feng (Collaborator, Author) commented on Dec 20, 2023:

Hi @echarlaix. I think I understand what you mean now: the forward inputs are checked against the graph model's inputs. Could you please help me review these changes? Thanks!

@jiqing-feng (Collaborator, Author) commented on Dec 22, 2023:

Hi @echarlaix, sorry for the misunderstanding. I just found that there is no way to get the input names from a TorchScript model, so I can only collect the input names while tracing the model. I would like to hear your opinion. Thanks!

@echarlaix (Collaborator) replied:

I was able to get something with:

    input_names = [inputs.debugName() for inputs in model.graph.inputs()]

Can you check it out?
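A self-contained way to try this snippet out, using a toy module as a stand-in for a traced LM (debug names may carry numeric suffixes depending on the PyTorch version):

import torch

class Toy(torch.nn.Module):
    def forward(self, input_ids, position_ids):
        return input_ids + position_ids

# Trace the module, then read the input names off its TorchScript graph.
# The first graph input is the module itself ("self").
example = (torch.zeros(1, 4, dtype=torch.long), torch.arange(4).unsqueeze(0))
traced = torch.jit.trace(Toy(), example)
input_names = [inp.debugName() for inp in traced.graph.inputs()]
print(input_names)  # e.g. ['self', 'input_ids', 'position_ids']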

@jiqing-feng (Collaborator, Author) commented on Jan 5, 2024:

Hi @echarlaix. Thanks for your advice; it fixed my problem perfectly. Could you please review these changes? Thanks!

Also, the failing CI jobs are not related to my changes.

@echarlaix (Collaborator) replied:

Added updates in jiqing-feng#2, can you take a look?

@echarlaix (Collaborator) commented:

Also, could you add a test before we merge?

@jiqing-feng (Collaborator, Author) commented on Jan 8, 2024:

Hi @echarlaix. I have merged your changes and added the tests. Could you please review the test function? Thanks!

By the way, the failing CI jobs do not seem related to our changes.

@echarlaix merged commit c64025d into huggingface:main on Jan 8, 2024, with 9 of 10 checks passed.
@jiqing-feng deleted the position_ids branch on January 8, 2024.
PenghuiCheng pushed a commit to PenghuiCheng/optimum-intel that referenced this pull request on Jan 16, 2024:

* add position_ids in forward
* check if jit model need position_ids
* use MODEL_TYPES_REQUIRING_POSITION_IDS
* fix has_position_ids
* fix position_ids length
* rm useless params
* check model inputs by input names
* fix format
* check input names in graph model
* fix style
* consider eager model in input_names
* add input names
* add text input names
* fix styl;e
* Update optimum/intel/generation/modeling.py
* fix format
* Update optimum/intel/generation/modeling.py

---------

Co-authored-by: Ella Charlaix <ella@huggingface.co>
Co-authored-by: Ella Charlaix <80481427+echarlaix@users.noreply.github.com>