
add position_ids in forward #456

Merged: 21 commits, Jan 8, 2024
Changes from 2 commits
42 changes: 36 additions & 6 deletions optimum/intel/generation/modeling.py
@@ -70,6 +70,7 @@ def prepare_jit_inputs(model: PreTrainedModel, task: str, use_cache: bool = Fals

def jit_trace(model: PreTrainedModel, task: str, use_cache: bool = False):
model_inputs = prepare_jit_inputs(model, task, use_cache)
has_position_ids = "position_ids" in model_inputs
# run the model once to check that model_inputs are valid
model(**model_inputs)
torch._C._jit_set_texpr_fuser_enabled(False)
@@ -88,7 +89,7 @@ def jit_trace(model: PreTrainedModel, task: str, use_cache: bool = False):
traced_model(**model_inputs)
traced_model(**model_inputs)

return traced_model
return traced_model, has_position_ids
Collaborator

I think we should keep jit_trace as it is

Suggested change
return traced_model, has_position_ids
return traced_model



class PreTrainedModel(OptimizedModel):
@@ -107,6 +108,7 @@ def __init__(
config: PretrainedConfig = None,
model_save_dir: Optional[Union[str, Path, TemporaryDirectory]] = None,
use_cache: bool = True,
has_position_ids: bool = False,
**kwargs,
):
super(BaseModelForCausalLM, self).__init__(model=model, config=config)
@@ -116,6 +118,7 @@ def __init__(
self._device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
self.normalized_config = NormalizedConfigManager.get_normalized_config_class(config.model_type)(config)
self.model_dtype = kwargs.get("model_dtype", None)
self.has_position_ids = has_position_ids
Collaborator

no need to have an attribute, we can use MODEL_TYPES_REQUIRING_POSITION_IDS directly

Collaborator Author

Thanks, I will use it.
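
For context, a minimal sketch of the suggested approach, assuming MODEL_TYPES_REQUIRING_POSITION_IDS is importable from optimum (the import path and the helper below are assumptions, not the PR's final implementation):

from optimum.utils.modeling_utils import MODEL_TYPES_REQUIRING_POSITION_IDS  # assumed location

def model_requires_position_ids(config) -> bool:
    # Look the architecture up instead of carrying a has_position_ids attribute,
    # e.g. config.model_type is "llama", "gpt2" or "bloom".
    return config.model_type in MODEL_TYPES_REQUIRING_POSITION_IDS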


if is_transformers_version("<=", "4.25.1"):
self.generation_config = None
@@ -145,7 +148,7 @@ def _save_pretrained(self, save_directory: Union[str, Path], file_name: Optional
torch.jit.save(self.model, os.path.join(save_directory, WEIGHTS_NAME))

# Adapted from transformers.models.gpt2.modeling_gpt2.GPT2LMHeadModel.prepare_inputs_for_generation
def prepare_inputs_for_generation(self, input_ids, past_key_values=None, **kwargs):
def prepare_inputs_for_generation(self, input_ids, past_key_values=None, attention_mask=None, **kwargs):
past_key_values = past_key_values or kwargs.get("past", None)

if self.use_cache and past_key_values is not None:
@@ -156,11 +159,19 @@ def prepare_inputs_for_generation(self, input_ids, past_key_values=None, **kwarg
if past_key_values[0][0].shape[0] == input_ids.shape[0]:
past_key_values = self._convert_to_bloom_cache(past_key_values)

position_ids = kwargs.get("position_ids", None)
if attention_mask is not None and position_ids is None:
# create position_ids on the fly for batch generation
position_ids = attention_mask.long().cumsum(-1) - 1
position_ids.masked_fill_(attention_mask == 0, 1)
if past_key_values:
position_ids = position_ids[:, -1].unsqueeze(-1)

return {
"input_ids": input_ids,
"past_key_values": past_key_values,
"use_cache": self.use_cache,
"position_ids": None,
"position_ids": position_ids,
"attention_mask": kwargs.get("attention_mask", None),
"token_type_ids": None,
}
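
As an illustration of the cumsum trick above (not part of the diff), here is what it produces for a left-padded batch:

import torch

attention_mask = torch.tensor([[0, 0, 1, 1, 1],
                               [1, 1, 1, 1, 1]])
position_ids = attention_mask.long().cumsum(-1) - 1
position_ids.masked_fill_(attention_mask == 0, 1)
# position_ids is now tensor([[1, 1, 0, 1, 2],
#                             [0, 1, 2, 3, 4]])
# with past_key_values present, only the last column [[2], [4]] is kept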
@@ -292,8 +303,8 @@ def forward(
pkv = tuple(empty_tensor for _ in range(nb_pkv))
else:
pkv = ()
for nb_pkv in range(nb_pkv):
if nb_pkv % 2 == 0:
for i in range(nb_pkv):
if i % 2 == 0:
new_shape = [input_ids.shape[0] * num_key_value_heads, d_k, 0]
else:
new_shape = [input_ids.shape[0] * num_key_value_heads, 0, d_k]
@@ -305,6 +316,23 @@ def forward(
past_key_values = tuple(tuple(pkv) for _ in range(num_layers))

inputs["past_key_values"] = past_key_values

position_ids = kwargs.get("position_ids", None)
if self.has_position_ids and position_ids is not None:
inputs.update({"position_ids": position_ids})
elif self.has_position_ids and position_ids is None:
seq_length = input_ids.shape[-1]
if not self.use_cache:
past_key_values_length = 0
else:
past_key_values_length = past_key_values[0][1].shape[-2]
position_ids = torch.arange(
past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=self._device
).unsqueeze(0)
inputs.update({"position_ids": position_ids})
elif not self.has_position_ids and position_ids is not None:
logger.warning("You miss the position_ids in the inputs")
Collaborator

I don't think we should generate the position_ids here since you already add them in prepare_inputs_for_generation. I would just pass them when needed by checking the graph, as done in https://github.com/huggingface/optimum/blob/e7bd60dd2c1e295263ba57a4e468a62ab5b179e8/optimum/onnxruntime/modeling_decoder.py#L229-L232
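
A rough sketch of what that graph check could look like for a TorchScript model, analogous to the linked onnxruntime code that inspects the session's input names (the helper name and the debugName handling are assumptions, not the actual implementation):

def traced_model_accepts_position_ids(traced_model) -> bool:
    # Collect the traced graph's input names and only pass position_ids when the
    # graph actually declares such an input. debugName() may carry a numeric
    # suffix such as "position_ids.1", hence the split.
    input_names = {inp.debugName().split(".")[0] for inp in traced_model.graph.inputs()}
    return "position_ids" in input_names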

Collaborator Author

Yes, that is more reasonable. However, for generation tasks, different decoding strategies produce different inputs. For example, llama in greedy_search includes position_ids in its inputs, but assisted_decoding only passes input_ids. Besides, we already generate the attention_mask in the forward. WDYT?

Collaborator
@echarlaix echarlaix Oct 20, 2023

I see your point. I'm OK with the modification, but I think we need to add a test for every architecture to verify that we create it correctly. For example, is past_key_values_length = past_key_values[0][1].shape[-2] right for every architecture? (It looks like it from the empty pkv generation above, but I would like to verify, and also make sure this stays compatible if we add support for new architectures.)
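
A hypothetical test sketch along those lines, comparing greedy generation of the exported model against the original transformers model for each architecture (the model id and the exported-model class are placeholders, not the actual test suite):

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

def check_traced_generation_matches(model_id: str, traced_model_cls):
    # A wrong past_key_values_length or position_ids computation would surface
    # here as diverging tokens between the reference and the traced model.
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    inputs = tokenizer("This is a sample input", return_tensors="pt")
    reference = AutoModelForCausalLM.from_pretrained(model_id)
    traced = traced_model_cls.from_pretrained(model_id, export=True, use_cache=True)
    with torch.no_grad():
        expected = reference.generate(**inputs, max_new_tokens=8, do_sample=False)
        observed = traced.generate(**inputs, max_new_tokens=8, do_sample=False)
    assert torch.equal(expected, observed)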


outputs = self.model(**inputs)

if isinstance(outputs, (list, tuple)):
@@ -313,6 +341,7 @@ def forward(
else:
logits = outputs["logits"]
past_key_values = outputs["past_key_values"] if self.use_cache else None

return CausalLMOutputWithPast(logits=logits, past_key_values=past_key_values)


@@ -412,7 +441,7 @@ def _from_transformers(
if model.config.model_type == "llama":
model.model._prepare_decoder_attention_mask = _prepare_decoder_attention_mask

traced_model = jit_trace(model, task, use_cache)
traced_model, has_position_ids = jit_trace(model, task, use_cache)
save_dir = TemporaryDirectory()
save_dir_path = Path(save_dir.name)
torch.jit.save(traced_model, save_dir_path / WEIGHTS_NAME)
@@ -427,5 +456,6 @@ def _from_transformers(
force_download=force_download,
cache_dir=cache_dir,
local_files_only=local_files_only,
has_position_ids=has_position_ids,
**kwargs,
)
3 changes: 2 additions & 1 deletion optimum/intel/ipex/inference.py
@@ -115,7 +115,7 @@ def __enter__(self):
use_cache = False
if hasattr(self._original.config, "use_cache") and self._original.config.use_cache:
use_cache = True
model = jit_trace(
model, has_position_ids = jit_trace(
model=model,
task=self._model.task,
use_cache=use_cache,
@@ -126,6 +126,7 @@ def __enter__(self):
config=self._original.config,
use_cache=use_cache,
model_dtype=self._original.dtype,
has_position_ids=has_position_ids,
)
except Exception as e:
logger.warning(f"failed to use PyTorch jit mode due to: {e}.")
3 changes: 2 additions & 1 deletion optimum/intel/neural_compressor/modeling_base.py
@@ -235,7 +235,7 @@ def _from_transformers(
if task == "text-generation":
model = patch_decoder_attention_mask(model)

traced_model = jit_trace(model, task, use_cache)
traced_model, has_position_ids = jit_trace(model, task, use_cache)
save_dir = TemporaryDirectory()
save_dir_path = Path(save_dir.name)
torch.jit.save(traced_model, save_dir_path / WEIGHTS_NAME)
@@ -250,6 +250,7 @@ def _from_transformers(
force_download=force_download,
cache_dir=cache_dir,
local_files_only=local_files_only,
has_position_ids=has_position_ids,
**kwargs,
)
