Skip to content

Commit ce8d1bf

Browse files
committed
loading generation config if it is part of model
1 parent dad18d1 commit ce8d1bf

File tree

3 files changed

+47
-2
lines changed

3 files changed

+47
-2
lines changed

optimum/intel/openvino/modeling_base.py

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -89,7 +89,10 @@ def __init__(
8989

9090
self.model = model
9191
self.request = None
92-
self.generation_config = GenerationConfig.from_model_config(config) if self.can_generate() else None
92+
if self.can_generate():
93+
self.generation_config = kwargs.get("generation_config", GenerationConfig.from_model_config(config))
94+
else:
95+
self.generation_config = None
9396

9497
self._openvino_config = None
9598
if quantization_config:
@@ -240,6 +243,20 @@ def _from_pretrained(
240243
quantization_config = cls._prepare_weight_quantization_config(quantization_config, load_in_8bit)
241244

242245
model = cls.load_model(model_cache_path, quantization_config=quantization_config)
246+
247+
try:
248+
generation_config = GenerationConfig.from_pretrained(
249+
model_id,
250+
token=token,
251+
revision=revision,
252+
subfolder=subfolder,
253+
force_download=force_download,
254+
cache_dir=cache_dir
255+
)
256+
kwargs["generation_config"] = generation_config
257+
except Exception:
258+
pass
259+
243260
return cls(
244261
model,
245262
config=config,

optimum/intel/openvino/modeling_base_seq2seq.py

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -78,7 +78,10 @@ def __init__(
7878
self.encoder_model = encoder
7979
self.decoder_model = decoder
8080
self.decoder_with_past_model = decoder_with_past
81-
self.generation_config = GenerationConfig.from_model_config(config) if self.can_generate() else None
81+
if self.can_generate():
82+
self.generation_config = kwargs.get("generation_config", GenerationConfig.from_model_config(config))
83+
else:
84+
self.generation_config = None
8285
self._openvino_config = None
8386
if quantization_config:
8487
self._openvino_config = OVConfig(quantization_config=quantization_config)
@@ -218,6 +221,19 @@ def _from_pretrained(
218221
if use_cache:
219222
decoder_with_past = cls.load_model(file_names["decoder_with_past"], quantization_config)
220223

224+
try:
225+
generation_config = GenerationConfig.from_pretrained(
226+
model_id,
227+
token=token,
228+
revision=revision,
229+
cache_dir=cache_dir,
230+
force_download=force_download,
231+
local_files_only=local_files_only
232+
)
233+
kwargs["generation_config"] = generation_config
234+
except Exception:
235+
pass
236+
221237
return cls(
222238
encoder=encoder,
223239
decoder=decoder,

optimum/intel/openvino/modeling_decoder.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -763,6 +763,18 @@ def _from_pretrained(
763763
init_cls = cls
764764

765765
enable_compilation = kwargs.pop("compile", True) and not load_in_4bit
766+
try:
767+
generation_config = GenerationConfig.from_pretrained(
768+
model_id,
769+
token=token,
770+
revision=revision,
771+
cache_dir=cache_dir,
772+
force_download=force_download,
773+
local_files_only=local_files_only
774+
)
775+
kwargs["generation_config"] = generation_config
776+
except Exception:
777+
pass
766778
causal_model = init_cls(
767779
model=model,
768780
config=config,

0 commit comments

Comments
 (0)