File tree Expand file tree Collapse file tree 2 files changed +5
-2
lines changed Expand file tree Collapse file tree 2 files changed +5
-2
lines changed Original file line number Diff line number Diff line change @@ -127,8 +127,10 @@ def __init__(
127
127
self .output_dtypes = output_dtypes
128
128
self .model = model
129
129
self .request = None if not self ._compile_only else self .model
130
+
131
+ generation_config = kwargs .get ("generation_config" , None )
130
132
if self .can_generate ():
131
- self .generation_config = kwargs . get ( " generation_config" , GenerationConfig .from_model_config (config ) )
133
+ self .generation_config = generation_config or GenerationConfig .from_model_config (config )
132
134
133
135
if is_transformers_version (">=" , "4.44.99" ):
134
136
misplaced_generation_parameters = self .config ._get_non_default_generation_parameters ()
Original file line number Diff line number Diff line change @@ -80,7 +80,8 @@ def __init__(
80
80
self .decoder_model = decoder
81
81
self .decoder_with_past_model = decoder_with_past
82
82
83
- self .generation_config = kwargs .get ("generation_config" , GenerationConfig .from_model_config (config ))
83
+ generation_config = kwargs .get ("generation_config" , None )
84
+ self .generation_config = generation_config or GenerationConfig .from_model_config (config )
84
85
85
86
if is_transformers_version (">=" , "4.44.99" ):
86
87
misplaced_generation_parameters = self .config ._get_non_default_generation_parameters ()
You can’t perform that action at this time.
0 commit comments