Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add position_ids in forward #456

Merged
merged 21 commits into from
Jan 8, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 22 additions & 16 deletions optimum/intel/generation/modeling.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,26 +21,20 @@

import torch
from huggingface_hub import hf_hub_download
from transformers import AutoConfig, AutoModelForCausalLM, PretrainedConfig, PreTrainedModel
from transformers import AutoConfig, AutoModelForCausalLM, GenerationConfig, PretrainedConfig, PreTrainedModel
from transformers.generation import GenerationMixin
from transformers.modeling_outputs import CausalLMOutputWithPast
from transformers.utils import WEIGHTS_NAME

from optimum.exporters import TasksManager
from optimum.exporters.onnx import MODEL_TYPES_REQUIRING_POSITION_IDS
from optimum.modeling_base import OptimizedModel
from optimum.utils import NormalizedConfigManager

from ..utils.constant import _TASK_ALIASES
from ..utils.import_utils import is_torch_version, is_transformers_version
from ..utils.import_utils import is_torch_version
from ..utils.modeling_utils import MULTI_QUERY_ATTN_MODELS, patch_decoder_attention_mask


if is_transformers_version("<", "4.25.0"):
from transformers.generation_utils import GenerationMixin
else:
from transformers.generation import GenerationMixin


logger = logging.getLogger(__name__)


Expand Down Expand Up @@ -112,12 +106,14 @@ def __init__(
self.normalized_config = NormalizedConfigManager.get_normalized_config_class(config.model_type)(config)
self.model_dtype = kwargs.get("model_dtype", None)

if is_transformers_version("<=", "4.25.1"):
self.generation_config = None
if isinstance(model, torch.jit.ScriptModule):
self.input_names = {
inputs.debugName().split(".")[0] for inputs in model.graph.inputs() if inputs.debugName() != "self"
}
else:
from transformers import GenerationConfig
self.input_names = set()

self.generation_config = GenerationConfig.from_model_config(config)
self.generation_config = GenerationConfig.from_model_config(config)

# Avoid warnings when creating a transformers pipeline
AutoConfig.register(self.base_model_prefix, AutoConfig)
Expand Down Expand Up @@ -267,6 +263,7 @@ def forward(
position_ids: Optional[torch.FloatTensor] = None,
**kwargs,
) -> CausalLMOutputWithPast:
# 1. Prepare model inputs
if attention_mask is None:
attention_mask = torch.ones_like(input_ids)

Expand All @@ -275,6 +272,15 @@ def forward(
"attention_mask": attention_mask,
}

if "position_ids" in self.input_names and position_ids is None:
position_ids = attention_mask.long().cumsum(-1) - 1
position_ids.masked_fill_(attention_mask == 0, 1)
if past_key_values:
position_ids = position_ids[:, -1].unsqueeze(-1)

echarlaix marked this conversation as resolved.
Show resolved Hide resolved
if "position_ids" in self.input_names or not self.input_names:
inputs["position_ids"] = position_ids

model_type = self.config.model_type.replace("_", "-")

if self.use_cache:
Expand Down Expand Up @@ -308,17 +314,17 @@ def forward(

inputs["past_key_values"] = past_key_values

if position_ids is not None and model_type in MODEL_TYPES_REQUIRING_POSITION_IDS:
inputs["position_ids"] = position_ids

# 2. Model forward
outputs = self.model(**inputs)

# 3. Process model outputs
if isinstance(outputs, (list, tuple)):
logits = outputs[0]
past_key_values = outputs[1] if self.use_cache else None
else:
logits = outputs["logits"]
past_key_values = outputs["past_key_values"] if self.use_cache else None

return CausalLMOutputWithPast(logits=logits, past_key_values=past_key_values)


Expand Down
12 changes: 12 additions & 0 deletions tests/generation/test_modeling.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,3 +160,15 @@ def test_compare_with_and_without_past_key_values(self):
f"With pkv latency: {with_pkv_timer.elapsed:.3f} ms, without pkv latency: {without_pkv_timer.elapsed:.3f} ms,"
f" speedup: {without_pkv_timer.elapsed / with_pkv_timer.elapsed:.3f}",
)

@parameterized.expand(SUPPORTED_ARCHITECTURES)
def test_input_names(self, model_arch):
model_id = MODEL_NAMES[model_arch]
model = TSModelForCausalLM.from_pretrained(model_id, export=True)
self.assertTrue(isinstance(model.input_names, set))
self.assertTrue("input_ids" in model.input_names)
self.assertTrue("attention_mask" in model.input_names)
if model.use_cache:
self.assertTrue("past_key_values" in model.input_names)
else:
self.assertTrue("past_key_values" not in model.input_names)
Loading