Loading generation config if it is part of model (#750)
* loading generation config if it is part of model

* update test

* add saving generation config with save_pretrained call

* Update tests/openvino/test_export.py

Co-authored-by: Ella Charlaix <80481427+echarlaix@users.noreply.github.com>

---------

Co-authored-by: Ella Charlaix <80481427+echarlaix@users.noreply.github.com>
eaidova and echarlaix authored Jun 6, 2024
1 parent 66191a3 commit d5dbb3d
Showing 4 changed files with 118 additions and 3 deletions.
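
For context, a minimal usage sketch of the behaviour this commit enables; the directory names are illustrative and assume an already exported OpenVINO model:

    # Sketch only: "gpt2-ov" stands for any directory containing an exported OpenVINO model.
    from optimum.intel import OVModelForCausalLM

    # If the directory ships a generation_config.json, it is now loaded instead of
    # rebuilding the generation config from the model config.
    model = OVModelForCausalLM.from_pretrained("gpt2-ov")
    print(model.generation_config)

    # save_pretrained now also writes generation_config.json next to the OpenVINO IR,
    # so custom generation settings survive a save/reload round trip.
    model.generation_config.top_k = 42
    model.save_pretrained("gpt2-ov-resaved")
    reloaded = OVModelForCausalLM.from_pretrained("gpt2-ov-resaved")
    assert reloaded.generation_config.top_k == 42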
27 changes: 26 additions & 1 deletion optimum/intel/openvino/modeling_base.py
@@ -89,7 +89,10 @@ def __init__(
 
         self.model = model
         self.request = None
-        self.generation_config = GenerationConfig.from_model_config(config) if self.can_generate() else None
+        if self.can_generate():
+            self.generation_config = kwargs.get("generation_config", GenerationConfig.from_model_config(config))
+        else:
+            self.generation_config = None
 
         self._openvino_config = None
         if quantization_config:
@@ -155,6 +158,14 @@ def _save_pretrained(self, save_directory: Union[str, Path]):
         """
         dst_path = os.path.join(save_directory, OV_XML_FILE_NAME)
         openvino.save_model(self.model, dst_path, compress_to_fp16=False)
+        generation_config = getattr(self, "generation_config", None)
+        if generation_config is not None:
+            try:
+                generation_config.save_pretrained(save_directory)
+            except Exception as exception:
+                logger.warning(
+                    f"The generation config will not be saved, saving failed with following error:\n{exception}"
+                )
 
         self._save_openvino_config(save_directory)
 
@@ -240,6 +251,20 @@ def _from_pretrained(
         quantization_config = cls._prepare_weight_quantization_config(quantization_config, load_in_8bit)
 
         model = cls.load_model(model_cache_path, quantization_config=quantization_config)
+
+        try:
+            generation_config = GenerationConfig.from_pretrained(
+                model_id,
+                token=token,
+                revision=revision,
+                subfolder=subfolder,
+                force_download=force_download,
+                cache_dir=cache_dir,
+            )
+            kwargs["generation_config"] = generation_config
+        except Exception:
+            pass
+
         return cls(
             model,
             config=config,
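To isolate the loading logic added above: the generation config is read with GenerationConfig.from_pretrained when the model repository or directory provides one, and otherwise derived from the model config. A standalone sketch of that fallback, with "gpt2" as an illustrative model id only:

    from transformers import AutoConfig, GenerationConfig

    model_id = "gpt2"  # illustrative
    config = AutoConfig.from_pretrained(model_id)

    try:
        # Prefer the generation_config.json shipped with the model, if any.
        generation_config = GenerationConfig.from_pretrained(model_id)
    except Exception:
        # Otherwise keep the previous behaviour: derive defaults from the model config.
        generation_config = GenerationConfig.from_model_config(config)

    print(generation_config)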
25 changes: 24 additions & 1 deletion optimum/intel/openvino/modeling_base_seq2seq.py
@@ -78,7 +78,10 @@ def __init__(
         self.encoder_model = encoder
         self.decoder_model = decoder
         self.decoder_with_past_model = decoder_with_past
-        self.generation_config = GenerationConfig.from_model_config(config) if self.can_generate() else None
+        if self.can_generate():
+            self.generation_config = kwargs.get("generation_config", GenerationConfig.from_model_config(config))
+        else:
+            self.generation_config = None
         self._openvino_config = None
         if quantization_config:
             self._openvino_config = OVConfig(quantization_config=quantization_config)
@@ -104,6 +107,13 @@ def _save_pretrained(self, save_directory: Union[str, Path]):
             openvino.save_model(src_file, dst_path, compress_to_fp16=False)
 
         self._save_openvino_config(save_directory)
+        if self.generation_config is not None:
+            try:
+                self.generation_config.save_pretrained(save_directory)
+            except Exception as exception:
+                logger.warning(
+                    f"The generation config will not be saved, saving failed with following error:\n{exception}"
+                )
 
     @classmethod
     def _from_pretrained(
@@ -218,6 +228,19 @@ def _from_pretrained(
         if use_cache:
             decoder_with_past = cls.load_model(file_names["decoder_with_past"], quantization_config)
 
+        try:
+            generation_config = GenerationConfig.from_pretrained(
+                model_id,
+                token=token,
+                revision=revision,
+                cache_dir=cache_dir,
+                force_download=force_download,
+                local_files_only=local_files_only,
+            )
+            kwargs["generation_config"] = generation_config
+        except Exception:
+            pass
+
         return cls(
             encoder=encoder,
             decoder=decoder,
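The same round trip can be checked for a seq2seq model. A hedged sketch, assuming "t5-small" is acceptable to export on the local machine:

    from optimum.intel import OVModelForSeq2SeqLM

    # Export a small seq2seq model to OpenVINO and tweak its generation defaults.
    ov_model = OVModelForSeq2SeqLM.from_pretrained("t5-small", export=True)
    ov_model.generation_config.max_length = 64

    # _save_pretrained above also serializes generation_config.json,
    # so the tweak is still present after reloading.
    ov_model.save_pretrained("t5-small-ov")
    reloaded = OVModelForSeq2SeqLM.from_pretrained("t5-small-ov")
    assert reloaded.generation_config.max_length == 64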
20 changes: 20 additions & 0 deletions optimum/intel/openvino/modeling_decoder.py
@@ -223,6 +223,14 @@ def _save_pretrained(self, save_directory: Union[str, Path]):
         dst_path = os.path.join(save_directory, OV_XML_FILE_NAME)
         openvino.save_model(model_to_save, dst_path, compress_to_fp16=False)
 
+        if self.generation_config is not None:
+            try:
+                self.generation_config.save_pretrained(save_directory)
+            except Exception as exception:
+                logger.warning(
+                    f"The generation config will not be saved, saving failed with following error:\n{exception}"
+                )
+
         self._save_openvino_config(save_directory)
 
     @classmethod
@@ -765,6 +773,18 @@ def _from_pretrained(
             init_cls = cls
 
         enable_compilation = kwargs.pop("compile", True) and not load_in_4bit
+        try:
+            generation_config = GenerationConfig.from_pretrained(
+                model_id,
+                token=token,
+                revision=revision,
+                cache_dir=cache_dir,
+                force_download=force_download,
+                local_files_only=local_files_only,
+            )
+            kwargs["generation_config"] = generation_config
+        except Exception:
+            pass
         causal_model = init_cls(
             model=model,
             config=config,
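For decoder-only models, the loaded generation_config also supplies the defaults picked up by generate(). A sketch, using a tiny test checkpoint as an assumption:

    from transformers import AutoTokenizer
    from optimum.intel import OVModelForCausalLM

    model_id = "hf-internal-testing/tiny-random-gpt2"  # illustrative tiny checkpoint
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    ov_model = OVModelForCausalLM.from_pretrained(model_id, export=True)

    # Values stored on generation_config act as defaults for generate().
    ov_model.generation_config.do_sample = True
    ov_model.generation_config.top_k = 42

    inputs = tokenizer("Hello", return_tensors="pt")
    outputs = ov_model.generate(**inputs, max_new_tokens=5)
    print(tokenizer.decode(outputs[0]))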
49 changes: 48 additions & 1 deletion tests/openvino/test_export.py
@@ -21,7 +21,7 @@
 import torch
 from parameterized import parameterized
 from sentence_transformers import SentenceTransformer, models
-from transformers import AutoConfig, AutoTokenizer
+from transformers import AutoConfig, AutoTokenizer, GenerationConfig
 from utils_tests import MODEL_NAMES
 
 from optimum.exporters.onnx.constants import SDPA_ARCHS_ONNX_EXPORT_NOT_SUPPORTED
@@ -71,6 +71,8 @@ class ExportModelTest(unittest.TestCase):
         "latent-consistency": OVLatentConsistencyModelPipeline,
     }
 
+    GENERATIVE_MODELS = ("pix2struct", "t5", "bart", "gpt2", "whisper")
+
     def _openvino_export(
         self,
         model_type: str,
@@ -124,6 +126,51 @@ def _openvino_export(
     def test_export(self, model_type: str):
         self._openvino_export(model_type)
 
+    @parameterized.expand(GENERATIVE_MODELS)
+    def test_export_with_custom_gen_config(self, model_type):
+        auto_model = self.SUPPORTED_ARCHITECTURES[model_type]
+        task = auto_model.export_feature
+        model_name = MODEL_NAMES[model_type]
+        loading_kwargs = {"attn_implementation": "eager"} if model_type in SDPA_ARCHS_ONNX_EXPORT_NOT_SUPPORTED else {}
+
+        model = auto_model.auto_model_class.from_pretrained(model_name, **loading_kwargs)
+
+        model.generation_config.top_k = 42
+        model.generation_config.do_sample = True
+
+        if getattr(model.config, "model_type", None) == "pix2struct":
+            preprocessors = maybe_load_preprocessors(model_name)
+        else:
+            preprocessors = None
+
+        supported_tasks = (task, task + "-with-past") if "text-generation" in task else (task,)
+        for supported_task in supported_tasks:
+            with TemporaryDirectory() as tmpdirname:
+                export_from_model(
+                    model=model,
+                    output=Path(tmpdirname),
+                    task=supported_task,
+                    preprocessors=preprocessors,
+                )
+
+                use_cache = supported_task.endswith("-with-past")
+                ov_model = auto_model.from_pretrained(tmpdirname, use_cache=use_cache)
+                self.assertIsInstance(ov_model, OVBaseModel)
+                self.assertTrue(ov_model.can_generate())
+                self.assertTrue(ov_model.generation_config is not None)
+                self.assertIsInstance(ov_model.generation_config, GenerationConfig)
+                self.assertTrue(ov_model.generation_config.top_k == 42)
+
+                # check that generate config remains after repeated saving
+                with TemporaryDirectory() as tmpdirname2:
+                    ov_model.save_pretrained(tmpdirname2)
+                    ov_model = auto_model.from_pretrained(tmpdirname2, use_cache=use_cache)
+                    self.assertIsInstance(ov_model, OVBaseModel)
+                    self.assertTrue(ov_model.can_generate())
+                    self.assertTrue(ov_model.generation_config is not None)
+                    self.assertIsInstance(ov_model.generation_config, GenerationConfig)
+                    self.assertTrue(ov_model.generation_config.top_k == 42)
+
 
 class CustomExportModelTest(unittest.TestCase):
     def test_custom_export_config_model(self):
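To exercise only the new test locally (assuming a standard pytest setup in the repository), a command like `python -m pytest tests/openvino/test_export.py -k test_export_with_custom_gen_config` should select it.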
