Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
175 changes: 24 additions & 151 deletions optimum/commands/export/openvino.py
Original file line number Diff line number Diff line change
Expand Up @@ -326,12 +326,13 @@ def parse_args(parser: "ArgumentParser"):
return parse_args_openvino(parser)

def run(self):
from optimum.utils.save_utils import maybe_load_preprocessors

from ...exporters.openvino.__main__ import infer_task, main_export, maybe_convert_tokenizers
from ...exporters.openvino.utils import save_preprocessors
from ...intel.openvino.configuration import _DEFAULT_4BIT_WQ_CONFIG, OVConfig, get_default_quantization_config
from ...intel.utils.import_utils import DIFFUSERS_IMPORT_ERROR, is_diffusers_available, is_nncf_available
from ...exporters.openvino.__main__ import main_export
from ...intel.openvino.configuration import (
_DEFAULT_4BIT_WQ_CONFIG,
OVConfig,
get_default_quantization_config,
)
from ...intel.utils.import_utils import is_nncf_available
from ...intel.utils.modeling_utils import _infer_library_from_model_name_or_path

if self.args.library is None:
Expand Down Expand Up @@ -430,151 +431,23 @@ def run(self):
quantization_config = prepare_q_config(self.args)
ov_config = OVConfig(quantization_config=quantization_config)

quantization_config = ov_config.quantization_config if ov_config else None
quantize_with_dataset = quantization_config and getattr(quantization_config, "dataset", None) is not None
task = infer_task(self.args.task, self.args.model, library_name=library_name)
# in some cases automatic task detection for multimodal models gives incorrect results
if self.args.task == "auto" and library_name == "transformers":
from transformers import AutoConfig

from ...exporters.openvino.utils import MULTI_MODAL_TEXT_GENERATION_MODELS

config = AutoConfig.from_pretrained(
self.args.model,
cache_dir=self.args.cache_dir,
trust_remote_code=self.args.trust_remote_code,
)
if getattr(config, "model_type", "") in MULTI_MODAL_TEXT_GENERATION_MODELS:
task = "image-text-to-text"

if library_name == "diffusers" and quantize_with_dataset:
if not is_diffusers_available():
raise ValueError(DIFFUSERS_IMPORT_ERROR.format("Export of diffusers models"))

from diffusers import DiffusionPipeline

diffusers_config = DiffusionPipeline.load_config(self.args.model)
class_name = diffusers_config.get("_class_name", None)

if class_name == "LatentConsistencyModelPipeline":
from optimum.intel import OVLatentConsistencyModelPipeline

model_cls = OVLatentConsistencyModelPipeline

elif class_name == "StableDiffusionXLPipeline":
from optimum.intel import OVStableDiffusionXLPipeline

model_cls = OVStableDiffusionXLPipeline
elif class_name == "StableDiffusionPipeline":
from optimum.intel import OVStableDiffusionPipeline

model_cls = OVStableDiffusionPipeline
elif class_name == "StableDiffusion3Pipeline":
from optimum.intel import OVStableDiffusion3Pipeline

model_cls = OVStableDiffusion3Pipeline
elif class_name == "FluxPipeline":
from optimum.intel import OVFluxPipeline

model_cls = OVFluxPipeline
elif class_name == "SanaPipeline":
from optimum.intel import OVSanaPipeline

model_cls = OVSanaPipeline
elif class_name == "SanaSprintPipeline":
from optimum.intel import OVSanaSprintPipeline

model_cls = OVSanaSprintPipeline

else:
raise NotImplementedError(f"Quantization isn't supported for class {class_name}.")

model = model_cls.from_pretrained(self.args.model, export=True, quantization_config=quantization_config)
model.save_pretrained(self.args.output)
if not self.args.disable_convert_tokenizer:
maybe_convert_tokenizers(library_name, self.args.output, model, task=task)
elif (
quantize_with_dataset
and (
task in ["fill-mask", "zero-shot-image-classification"]
or task.startswith("text-generation")
or task.startswith("text2text-generation")
or task.startswith("automatic-speech-recognition")
or task.startswith("feature-extraction")
)
or (task == "image-text-to-text" and quantization_config is not None)
):
if task.startswith("text-generation"):
from optimum.intel import OVModelForCausalLM

model_cls = OVModelForCausalLM
elif task.startswith("text2text-generation"):
from optimum.intel import OVModelForSeq2SeqLM

model_cls = OVModelForSeq2SeqLM
elif task == "image-text-to-text":
from optimum.intel import OVModelForVisualCausalLM

model_cls = OVModelForVisualCausalLM
elif "automatic-speech-recognition" in task:
from optimum.intel import OVModelForSpeechSeq2Seq

model_cls = OVModelForSpeechSeq2Seq
elif task.startswith("feature-extraction") and library_name == "transformers":
from ...intel import OVModelForFeatureExtraction

model_cls = OVModelForFeatureExtraction
elif task.startswith("feature-extraction") and library_name == "sentence_transformers":
from ...intel import OVSentenceTransformer

model_cls = OVSentenceTransformer
elif task == "fill-mask":
from ...intel import OVModelForMaskedLM

model_cls = OVModelForMaskedLM
elif task == "zero-shot-image-classification":
from ...intel import OVModelForZeroShotImageClassification

model_cls = OVModelForZeroShotImageClassification
else:
raise NotImplementedError(
f"Unable to find a matching model class for the task={task} and library_name={library_name}."
)

# In this case, to apply quantization an instance of a model class is required
model = model_cls.from_pretrained(
self.args.model,
export=True,
quantization_config=quantization_config,
stateful=not self.args.disable_stateful,
trust_remote_code=self.args.trust_remote_code,
variant=self.args.variant,
cache_dir=self.args.cache_dir,
)
model.save_pretrained(self.args.output)

preprocessors = maybe_load_preprocessors(self.args.model, trust_remote_code=self.args.trust_remote_code)
save_preprocessors(preprocessors, model.config, self.args.output, self.args.trust_remote_code)
if not self.args.disable_convert_tokenizer:
maybe_convert_tokenizers(library_name, self.args.output, preprocessors=preprocessors, task=task)
else:
# TODO : add input shapes
main_export(
model_name_or_path=self.args.model,
output=self.args.output,
task=self.args.task,
framework=self.args.framework,
cache_dir=self.args.cache_dir,
trust_remote_code=self.args.trust_remote_code,
pad_token_id=self.args.pad_token_id,
ov_config=ov_config,
stateful=not self.args.disable_stateful,
convert_tokenizer=not self.args.disable_convert_tokenizer,
library_name=library_name,
variant=self.args.variant,
model_kwargs=self.args.model_kwargs,
# **input_shapes,
)
# TODO : add input shapes
main_export(
model_name_or_path=self.args.model,
output=self.args.output,
task=self.args.task,
framework=self.args.framework,
cache_dir=self.args.cache_dir,
trust_remote_code=self.args.trust_remote_code,
pad_token_id=self.args.pad_token_id,
ov_config=ov_config,
stateful=not self.args.disable_stateful,
convert_tokenizer=not self.args.disable_convert_tokenizer,
library_name=library_name,
variant=self.args.variant,
model_kwargs=self.args.model_kwargs,
# **input_shapes,
)


def prepare_wc_config(args, default_configs):
Expand Down
Loading
Loading