@@ -62,6 +62,14 @@ def __getattr__(self, attr_name):
6262
6363 attr = getattr (config , leaf_attr_name , None )
6464
65+ if attr is None :
66+ # Fallback to checking if the leaf attribute name exists as mapped
67+ try :
68+ mapped_attr_name = super ().__getattribute__ (leaf_attr_name .upper ())
69+ attr = getattr (config , mapped_attr_name , None )
70+ except AttributeError :
71+ pass
72+
6573 # If the attribute was not specified manually, try to fallback on the attribute_map.
6674 if attr is None :
6775 attribute_map = getattr (self .config , "attribute_map" , {})
@@ -134,9 +142,40 @@ def __getattr__(self, attr_name):
134142 return super ().__getattr__ (attr_name )
135143
136144
137- Pix2StructNormalizedTextConfig = NormalizedTextAndVisionConfig .with_args (
138- text_config = "text_config" , vision_config = "vision_config"
139- )
def create_normalized_text_and_vision_config(
    text_config_cls: type[NormalizedTextConfig] = NormalizedTextConfig,
    vision_config_cls: type[NormalizedVisionConfig] = NormalizedVisionConfig,
) -> type[NormalizedConfig]:
    """
    Create a normalized config class for a model with both a text and a vision config.

    The returned class inherits from both `text_config_cls` and `vision_config_cls`, which allows for custom
    renaming of parameters within the nested configs.

    Example usage:
        >>> MyNormalizedTextAndVisionConfigWithGQA = create_normalized_text_and_vision_config(
        ...     text_config_cls=NormalizedTextConfigWithGQA
        ... ).with_args(text_config="text_config", vision_config="vision_config")

    Args:
        text_config_cls ([`type[NormalizedTextConfig]`]):
            Normalized configuration class to use for the text config.

        vision_config_cls ([`type[NormalizedVisionConfig]`]):
            Normalized configuration class to use for the vision config.

    Returns:
        `type[NormalizedConfig]`: A new class deriving from `text_config_cls` and `vision_config_cls` whose
        attribute lookups are routed to the nested text or vision config.
    """

    class CustomNormalizedTextAndVisionConfig(text_config_cls, vision_config_cls):
        # Names of the nested text / vision configs. They default to None, in which case no routing happens;
        # presumably they are populated through `with_args(text_config=..., vision_config=...)`.
        TEXT_CONFIG = None
        VISION_CONFIG = None

        def __getattr__(self, attr_name):
            # Normalized attributes are declared as upper-case class attributes, so the upper-cased name is
            # what gets looked up on the text / vision normalized classes.
            upper_name = attr_name.upper()
            if self.TEXT_CONFIG is not None and upper_name in dir(text_config_cls):
                # Text config takes precedence when both classes declare the attribute.
                # The dotted path must not contain any whitespace: "<text_config>.<attr_name>".
                attr_name = f"{self.TEXT_CONFIG}.{attr_name}"
            elif self.VISION_CONFIG is not None and upper_name in dir(vision_config_cls):
                attr_name = f"{self.VISION_CONFIG}.{attr_name}"
            # Anything else is forwarded unchanged to the parent lookup.
            return super().__getattr__(attr_name)

    return CustomNormalizedTextAndVisionConfig
140179
141180
142181class NormalizedEncoderDecoderConfig (NormalizedConfig ):
@@ -161,7 +200,6 @@ def __getattr__(self, attr_name):
161200 num_attention_heads = "encoder_attention_heads" ,
162201 hidden_size = "d_model" ,
163202)
164-
165203GPT2LikeNormalizedTextConfig = NormalizedTextConfig .with_args (num_attention_heads = "n_head" , hidden_size = "n_embd" )
166204T5LikeNormalizedTextConfig = NormalizedTextConfig .with_args (
167205 num_attention_heads = "num_heads" ,
@@ -173,23 +211,32 @@ def __getattr__(self, attr_name):
# GPT-BigCode naming: n_head / n_embd / n_layer.
GPTBigCodeNormalizedTextConfig = NormalizedTextConfig.with_args(
    num_attention_heads="n_head",
    hidden_size="n_embd",
    num_layers="n_layer",
)
# Whisper-like configs expose the hidden size as `d_model`.
WhisperLikeNormalizedTextConfig = NormalizedTextConfig.with_args(hidden_size="d_model")
# TrOCR-like configs read layer / head counts from the decoder_* fields.
TrOCRLikeNormalizedTextConfig = NormalizedTextConfig.with_args(
    num_layers="decoder_layers",
    num_attention_heads="decoder_attention_heads",
    hidden_size="hidden_size",
)
# Speech-to-text-like configs: both layer counts map to `decoder_layers`; `allow_new=True` is passed through
# to `with_args` together with the extra `input_features_per_channel` mapping.
SpeechToTextLikeNormalizedTextConfig = NormalizedSeq2SeqConfig.with_args(
    decoder_num_layers="decoder_layers",
    num_layers="decoder_layers",
    input_features_per_channel="input_feat_per_channel",
    allow_new=True,
)
# Pix2Struct keeps its text and vision settings in nested `text_config` / `vision_config` objects.
Pix2StructNormalizedTextConfig = NormalizedTextAndVisionConfig.with_args(
    text_config="text_config",
    vision_config="vision_config",
)
231+
232+
class Gemma3NormalizedTextConfigWithGQA(NormalizedTextConfigWithGQA):
    """Normalized GQA text config for Gemma3, with `head_dim` taken from the nested text config."""

    # Dotted path: resolve the normalized `head_dim` from `text_config.head_dim`.
    HEAD_DIM = "text_config.head_dim"
235+
236+
# Combined text + vision normalized config for Gemma3: the text side uses the Gemma3 GQA class, the vision
# side keeps the factory default, and the nested configs are named `text_config` / `vision_config`.
Gemma3NormalizedTextAndVisionConfig = create_normalized_text_and_vision_config(
    text_config_cls=Gemma3NormalizedTextConfigWithGQA,
).with_args(
    text_config="text_config",
    vision_config="vision_config",
)
193240
194241
195242class NormalizedConfigManager :
@@ -253,11 +300,13 @@ class NormalizedConfigManager:
253300 "electra" : NormalizedTextConfig ,
254301 "encoder-decoder" : NormalizedEncoderDecoderConfig ,
255302 "gemma" : NormalizedTextConfigWithGQA ,
303+ "gemma3" : Gemma3NormalizedTextAndVisionConfig ,
256304 "gpt2" : GPT2LikeNormalizedTextConfig ,
257305 "gpt_bigcode" : GPTBigCodeNormalizedTextConfig ,
258306 "gpt_neo" : NormalizedTextConfig .with_args (num_attention_heads = "num_heads" ),
259307 "gpt_neox" : NormalizedTextConfig ,
260308 "gptj" : GPT2LikeNormalizedTextConfig ,
309+ "granite" : NormalizedTextConfigWithGQA ,
261310 "imagegpt" : GPT2LikeNormalizedTextConfig ,
262311 "internlm2" : NormalizedTextConfigWithGQA ,
263312 "llama" : NormalizedTextConfigWithGQA ,
@@ -298,7 +347,6 @@ class NormalizedConfigManager:
298347 "qwen3" : NormalizedTextConfig ,
299348 "qwen3_moe" : NormalizedTextConfig ,
300349 "smollm3" : NormalizedTextConfig ,
301- "granite" : NormalizedTextConfigWithGQA ,
302350 }
303351
304352 @classmethod
0 commit comments