Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

loading generation config if it is part of model #750

Merged
merged 4 commits into from
Jun 6, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 26 additions & 1 deletion optimum/intel/openvino/modeling_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,10 @@ def __init__(

self.model = model
self.request = None
self.generation_config = GenerationConfig.from_model_config(config) if self.can_generate() else None
if self.can_generate():
self.generation_config = kwargs.get("generation_config", GenerationConfig.from_model_config(config))
else:
self.generation_config = None

self._openvino_config = None
if quantization_config:
Expand Down Expand Up @@ -155,6 +158,14 @@ def _save_pretrained(self, save_directory: Union[str, Path]):
"""
dst_path = os.path.join(save_directory, OV_XML_FILE_NAME)
openvino.save_model(self.model, dst_path, compress_to_fp16=False)
generation_config = getattr(self, "generation_config", None)
if generation_config is not None:
try:
generation_config.save_pretrained(save_directory)
except Exception as exception:
logger.warning(
f"The generation config will not be saved, saving failed with following error:\n{exception}"
)

self._save_openvino_config(save_directory)

Expand Down Expand Up @@ -240,6 +251,20 @@ def _from_pretrained(
quantization_config = cls._prepare_weight_quantization_config(quantization_config, load_in_8bit)

model = cls.load_model(model_cache_path, quantization_config=quantization_config)

try:
generation_config = GenerationConfig.from_pretrained(
model_id,
token=token,
revision=revision,
subfolder=subfolder,
force_download=force_download,
cache_dir=cache_dir,
)
kwargs["generation_config"] = generation_config
except Exception:
pass

return cls(
model,
config=config,
Expand Down
25 changes: 24 additions & 1 deletion optimum/intel/openvino/modeling_base_seq2seq.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,10 @@ def __init__(
self.encoder_model = encoder
self.decoder_model = decoder
self.decoder_with_past_model = decoder_with_past
self.generation_config = GenerationConfig.from_model_config(config) if self.can_generate() else None
if self.can_generate():
self.generation_config = kwargs.get("generation_config", GenerationConfig.from_model_config(config))
else:
self.generation_config = None
self._openvino_config = None
if quantization_config:
self._openvino_config = OVConfig(quantization_config=quantization_config)
Expand All @@ -104,6 +107,13 @@ def _save_pretrained(self, save_directory: Union[str, Path]):
openvino.save_model(src_file, dst_path, compress_to_fp16=False)

self._save_openvino_config(save_directory)
if self.generation_config is not None:
try:
self.generation_config.save_pretrained(save_directory)
except Exception as exception:
logger.warning(
f"The generation config will not be saved, saving failed with following error:\n{exception}"
)

@classmethod
def _from_pretrained(
Expand Down Expand Up @@ -218,6 +228,19 @@ def _from_pretrained(
if use_cache:
decoder_with_past = cls.load_model(file_names["decoder_with_past"], quantization_config)

try:
generation_config = GenerationConfig.from_pretrained(
model_id,
token=token,
revision=revision,
cache_dir=cache_dir,
force_download=force_download,
local_files_only=local_files_only,
)
kwargs["generation_config"] = generation_config
except Exception:
pass

return cls(
encoder=encoder,
decoder=decoder,
Expand Down
20 changes: 20 additions & 0 deletions optimum/intel/openvino/modeling_decoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -224,6 +224,14 @@ def _save_pretrained(self, save_directory: Union[str, Path]):
dst_path = os.path.join(save_directory, OV_XML_FILE_NAME)
openvino.save_model(model_to_save, dst_path, compress_to_fp16=False)

if self.generation_config is not None:
try:
self.generation_config.save_pretrained(save_directory)
except Exception as exception:
logger.warning(
f"The generation config will not be saved, saving failed with following error:\n{exception}"
)

self._save_openvino_config(save_directory)

@classmethod
Expand Down Expand Up @@ -763,6 +771,18 @@ def _from_pretrained(
init_cls = cls

enable_compilation = kwargs.pop("compile", True) and not load_in_4bit
try:
generation_config = GenerationConfig.from_pretrained(
model_id,
token=token,
revision=revision,
cache_dir=cache_dir,
force_download=force_download,
local_files_only=local_files_only,
)
kwargs["generation_config"] = generation_config
except Exception:
pass
causal_model = init_cls(
model=model,
config=config,
Expand Down
49 changes: 48 additions & 1 deletion tests/openvino/test_export.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
import torch
from parameterized import parameterized
from sentence_transformers import SentenceTransformer, models
from transformers import AutoConfig, AutoTokenizer
from transformers import AutoConfig, AutoTokenizer, GenerationConfig
from utils_tests import MODEL_NAMES

from optimum.exporters.onnx.constants import SDPA_ARCHS_ONNX_EXPORT_NOT_SUPPORTED
Expand Down Expand Up @@ -71,6 +71,8 @@ class ExportModelTest(unittest.TestCase):
"latent-consistency": OVLatentConsistencyModelPipeline,
}

GENERATIVE_MODELS = ("pix2struct", "t5", "bart", "gpt2", "whisper")

def _openvino_export(
self,
model_type: str,
Expand Down Expand Up @@ -124,6 +126,51 @@ def _openvino_export(
def test_export(self, model_type: str):
self._openvino_export(model_type)

@parameterized.expand(GENERATIVE_MODELS)
def test_export_with_custom_gen_config(self, model_type):
    """Check that a customised generation config survives export and re-saving.

    The source model's ``generation_config`` is modified, the model is
    exported with ``export_from_model``, and the result is loaded through the
    matching OpenVINO model class. The loaded model must expose a
    ``GenerationConfig`` carrying the modified ``top_k``. The same check is
    repeated after ``save_pretrained`` / ``from_pretrained`` on the loaded
    OpenVINO model.

    Args:
        model_type: Key into ``SUPPORTED_ARCHITECTURES`` and ``MODEL_NAMES``
            selecting the architecture under test.
    """
    auto_model = self.SUPPORTED_ARCHITECTURES[model_type]
    task = auto_model.export_feature
    model_name = MODEL_NAMES[model_type]
    loading_kwargs = {"attn_implementation": "eager"} if model_type in SDPA_ARCHS_ONNX_EXPORT_NOT_SUPPORTED else {}

    model = auto_model.auto_model_class.from_pretrained(model_name, **loading_kwargs)

    # Mark the config with recognisable values; ``top_k`` is what the
    # assertions below look for after each load.
    model.generation_config.top_k = 42
    model.generation_config.do_sample = True

    # Preprocessors are only loaded for pix2struct; every other model type
    # is exported with ``preprocessors=None``.
    if getattr(model.config, "model_type", None) == "pix2struct":
        preprocessors = maybe_load_preprocessors(model_name)
    else:
        preprocessors = None

    def assert_custom_generation_config(loaded_model):
        """Assert that ``loaded_model`` can generate and kept the custom ``top_k``."""
        self.assertIsInstance(loaded_model, OVBaseModel)
        self.assertTrue(loaded_model.can_generate())
        self.assertIsNotNone(loaded_model.generation_config)
        self.assertIsInstance(loaded_model.generation_config, GenerationConfig)
        self.assertEqual(loaded_model.generation_config.top_k, 42)

    # Text-generation tasks are additionally exercised in their "-with-past" variant.
    supported_tasks = (task, task + "-with-past") if "text-generation" in task else (task,)
    for supported_task in supported_tasks:
        with TemporaryDirectory() as tmpdirname:
            export_from_model(
                model=model,
                output=Path(tmpdirname),
                task=supported_task,
                preprocessors=preprocessors,
            )

            use_cache = supported_task.endswith("-with-past")
            ov_model = auto_model.from_pretrained(tmpdirname, use_cache=use_cache)
            assert_custom_generation_config(ov_model)

            # check that generate config remains after repeated saving
            with TemporaryDirectory() as tmpdirname2:
                ov_model.save_pretrained(tmpdirname2)
                ov_model = auto_model.from_pretrained(tmpdirname2, use_cache=use_cache)
                assert_custom_generation_config(ov_model)


class CustomExportModelTest(unittest.TestCase):
def test_custom_export_config_model(self):
Expand Down
Loading