Skip to content

Commit

Permalink
review comments and kv cache compression disable in fp
Browse files Browse the repository at this point in the history
  • Loading branch information
eaidova committed Jan 9, 2025
1 parent 1066e01 commit 75c653d
Show file tree
Hide file tree
Showing 3 changed files with 11 additions and 11 deletions.
14 changes: 7 additions & 7 deletions optimum/exporters/openvino/convert.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,6 @@
from optimum.exporters import TasksManager
from optimum.exporters.utils import (
DECODER_NAME,
DECODER_WITH_PAST_NAME,
ENCODER_NAME,
_get_submodels_for_export_encoder_decoder,
get_diffusion_models_for_export,
Expand All @@ -48,7 +47,6 @@
_transformers_version,
compare_versions,
is_diffusers_version,
is_openvino_version,
is_openvino_tokenizers_version,
is_tokenizers_version,
is_transformers_version,
Expand Down Expand Up @@ -110,10 +108,13 @@ def _set_runtime_options(
"diffusers" in library_name
or "text-generation" in task
or ("image-text-to-text" in task and model_name == "language_model")
or getattr(sub_export_config, "stateful", False)
):
sub_export_config.runtime_options["ACTIVATIONS_SCALE_FACTOR"] = "8.0"
if not quantized_model and (
"text-generation" in task or ("image-text-to-text" in task and model_name == "language_model")
"text-generation" in task
or ("image-text-to-text" in task and model_name == "language_model")
or getattr(sub_export_config, "stateful", False)
):
sub_export_config.runtime_options["KV_CACHE_PRECISION"] = "f16"

Expand Down Expand Up @@ -643,7 +644,7 @@ def export_from_model(
is_encoder_decoder = getattr(getattr(model, "config", {}), "is_encoder_decoder", False)
model_type = getattr(getattr(model, "config", {}), "model_type", "")
stateful = stateful and (
ensure_export_task_support_stateful(task, is_encoder_decoder) or ensure_model_type_support_stateful(model_type)
ensure_export_task_support_stateful(task) or ensure_model_type_support_stateful(model_type)
)

if stateful and is_encoder_decoder and not getattr(model, "_supports_cache_class", False):
Expand Down Expand Up @@ -1251,17 +1252,16 @@ def _get_encoder_decoder_stateful_models_for_export(
all_variants = "\n".join([f" - {name}: {description}" for name, description in export_config.VARIANTS.items()])
logger.info(f"Using the export variant {export_config.variant}. Available variants are:\n{all_variants}")

models_for_export = _get_submodels_for_export_encoder_decoder(model, use_past=True)
models_for_export = _get_submodels_for_export_encoder_decoder(model, use_past=False)

encoder_export_config = export_config.with_behavior("encoder")
models_for_export[ENCODER_NAME] = (models_for_export[ENCODER_NAME], encoder_export_config)

decoder_export_config_with_past = export_config.with_behavior("decoder", use_past=True, use_past_in_inputs=True)

decoder_export_config_with_past.stateful = True
decoder_with_past_model = models_for_export.pop(DECODER_WITH_PAST_NAME)
models_for_export[DECODER_NAME] = (
decoder_with_past_model,
models_for_export[DECODER_NAME],
decoder_export_config_with_past,
)
return None, models_for_export
2 changes: 1 addition & 1 deletion optimum/exporters/openvino/model_patcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -3700,4 +3700,4 @@ def patched_forward(*args, **kwargs):

model.forward = patched_forward

super().__init__(config, model, model_kwargs)
super().__init__(config, model, model_kwargs)
6 changes: 3 additions & 3 deletions optimum/intel/openvino/modeling_seq2seq.py
Original file line number Diff line number Diff line change
Expand Up @@ -374,9 +374,9 @@ def forward(
# Decode
if past_key_values is None or self.decoder_with_past is None:
decoder_outputs = self.decoder(
input_ids=decoder_input_ids[:, -1:]
if past_key_values is not None and self.use_cache
else decoder_input_ids,
input_ids=(
decoder_input_ids[:, -1:] if past_key_values is not None and self.use_cache else decoder_input_ids
),
past_key_values=past_key_values,
encoder_hidden_states=encoder_outputs.last_hidden_state,
encoder_attention_mask=attention_mask,
Expand Down

0 comments on commit 75c653d

Please sign in to comment.