File tree: 3 files changed, +47 -2
lines changed Original file line number Diff line number Diff line change @@ -89,7 +89,10 @@ def __init__(
89
89
90
90
self .model = model
91
91
self .request = None
92
- self .generation_config = GenerationConfig .from_model_config (config ) if self .can_generate () else None
92
+ if self .can_generate ():
93
+ self .generation_config = kwargs .get ("generation_config" , GenerationConfig .from_model_config (config ))
94
+ else :
95
+ self .generation_config = None
93
96
94
97
self ._openvino_config = None
95
98
if quantization_config :
@@ -240,6 +243,20 @@ def _from_pretrained(
240
243
quantization_config = cls ._prepare_weight_quantization_config (quantization_config , load_in_8bit )
241
244
242
245
model = cls .load_model (model_cache_path , quantization_config = quantization_config )
246
+
247
+ try :
248
+ generation_config = GenerationConfig .from_pretrained (
249
+ model_id ,
250
+ token = token ,
251
+ revision = revision ,
252
+ subfolder = subfolder ,
253
+ force_download = force_download ,
254
+ cache_dir = cache_dir
255
+ )
256
+ kwargs ["generation_config" ] = generation_config
257
+ except Exception :
258
+ pass
259
+
243
260
return cls (
244
261
model ,
245
262
config = config ,
Original file line number Diff line number Diff line change @@ -78,7 +78,10 @@ def __init__(
78
78
self .encoder_model = encoder
79
79
self .decoder_model = decoder
80
80
self .decoder_with_past_model = decoder_with_past
81
- self .generation_config = GenerationConfig .from_model_config (config ) if self .can_generate () else None
81
+ if self .can_generate ():
82
+ self .generation_config = kwargs .get ("generation_config" , GenerationConfig .from_model_config (config ))
83
+ else :
84
+ self .generation_config = None
82
85
self ._openvino_config = None
83
86
if quantization_config :
84
87
self ._openvino_config = OVConfig (quantization_config = quantization_config )
@@ -218,6 +221,19 @@ def _from_pretrained(
218
221
if use_cache :
219
222
decoder_with_past = cls .load_model (file_names ["decoder_with_past" ], quantization_config )
220
223
224
+ try :
225
+ generation_config = GenerationConfig .from_pretrained (
226
+ model_id ,
227
+ token = token ,
228
+ revision = revision ,
229
+ cache_dir = cache_dir ,
230
+ force_download = force_download ,
231
+ local_files_only = local_files_only
232
+ )
233
+ kwargs ["generation_config" ] = generation_config
234
+ except Exception :
235
+ pass
236
+
221
237
return cls (
222
238
encoder = encoder ,
223
239
decoder = decoder ,
Original file line number Diff line number Diff line change @@ -763,6 +763,18 @@ def _from_pretrained(
763
763
init_cls = cls
764
764
765
765
enable_compilation = kwargs .pop ("compile" , True ) and not load_in_4bit
766
+ try :
767
+ generation_config = GenerationConfig .from_pretrained (
768
+ model_id ,
769
+ token = token ,
770
+ revision = revision ,
771
+ cache_dir = cache_dir ,
772
+ force_download = force_download ,
773
+ local_files_only = local_files_only
774
+ )
775
+ kwargs ["generation_config" ] = generation_config
776
+ except Exception :
777
+ pass
766
778
causal_model = init_cls (
767
779
model = model ,
768
780
config = config ,
You can’t perform that action at this time.
0 commit comments