Skip to content

Commit 3ec2fbe

Browse files
committed
fix generation config
1 parent 96ef48d commit 3ec2fbe

File tree

2 files changed

+5
-2
lines changed

2 files changed

+5
-2
lines changed

optimum/intel/openvino/modeling_base.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -127,8 +127,10 @@ def __init__(
127127
self.output_dtypes = output_dtypes
128128
self.model = model
129129
self.request = None if not self._compile_only else self.model
130+
131+
generation_config = kwargs.get("generation_config", None)
130132
if self.can_generate():
131-
self.generation_config = kwargs.get("generation_config", GenerationConfig.from_model_config(config))
133+
self.generation_config = generation_config or GenerationConfig.from_model_config(config)
132134

133135
if is_transformers_version(">=", "4.44.99"):
134136
misplaced_generation_parameters = self.config._get_non_default_generation_parameters()

optimum/intel/openvino/modeling_base_seq2seq.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -80,7 +80,8 @@ def __init__(
8080
self.decoder_model = decoder
8181
self.decoder_with_past_model = decoder_with_past
8282

83-
self.generation_config = kwargs.get("generation_config", GenerationConfig.from_model_config(config))
83+
generation_config = kwargs.get("generation_config", None)
84+
self.generation_config = generation_config or GenerationConfig.from_model_config(config)
8485

8586
if is_transformers_version(">=", "4.44.99"):
8687
misplaced_generation_parameters = self.config._get_non_default_generation_parameters()

0 commit comments

Comments
 (0)