Skip to content

Commit

Permalink
Expand activation scaling to other submodels
Browse files Browse the repository at this point in the history
  • Loading branch information
AlexKoff88 committed Nov 27, 2024
1 parent 20900b2 commit 66d5d9c
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 2 deletions.
5 changes: 3 additions & 2 deletions optimum/exporters/openvino/convert.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,10 +98,11 @@ def _set_runtime_options(
Tuple[Union["PreTrainedModel", "TFPreTrainedModel", "ModelMixin", "DiffusionPipeline"], "OnnxConfig"],
],
task: str,
library_name: str,
):
for model_name in models_and_export_configs.keys():
_, sub_export_config = models_and_export_configs[model_name]
if "vae_" in model_name or "text-generation" in task:
if "diffusers" in library_name or "text-generation" in task:
sub_export_config.runtime_options = {"ACTIVATIONS_SCALE_FACTOR": "8.0"}


Expand Down Expand Up @@ -754,7 +755,7 @@ def export_from_model(

model.save_config(output)

_set_runtime_options(models_and_export_configs, task)
_set_runtime_options(models_and_export_configs, task, library_name)

export_models(
models_and_export_configs=models_and_export_configs,
Expand Down
12 changes: 12 additions & 0 deletions tests/openvino/test_export.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,6 +140,18 @@ def _openvino_export(
self.assertTrue(
ov_model.vae_decoder.model.has_rt_info(["runtime_options", "ACTIVATIONS_SCALE_FACTOR"])
)
if hasattr(ov_model, "text_encoder") and ov_model.text_encoder:
self.assertTrue(
ov_model.text_encoder.model.has_rt_info(["runtime_options", "ACTIVATIONS_SCALE_FACTOR"])
)
if hasattr(ov_model, "unet") and ov_model.unet:
self.assertTrue(
ov_model.unet.model.has_rt_info(["runtime_options", "ACTIVATIONS_SCALE_FACTOR"])
)
if hasattr(ov_model, "transformer") and ov_model.transformer:
self.assertTrue(
ov_model.transformer.model.has_rt_info(["runtime_options", "ACTIVATIONS_SCALE_FACTOR"])
)

@parameterized.expand(SUPPORTED_ARCHITECTURES)
def test_export(self, model_type: str):
Expand Down

0 comments on commit 66d5d9c

Please sign in to comment.