Skip to content

Commit 7afb702

Browse files
Add normalized config for gemma3
1 parent e9f5bdd commit 7afb702

File tree

1 file changed

+56
-8
lines changed

1 file changed

+56
-8
lines changed

optimum/utils/normalized_config.py

Lines changed: 56 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,14 @@ def __getattr__(self, attr_name):
6262

6363
attr = getattr(config, leaf_attr_name, None)
6464

65+
if attr is None:
66+
# Fall back to this class's own mapping: look up the upper-cased leaf name (e.g. HEAD_DIM) and read the attribute it maps to from the config.
67+
try:
68+
mapped_attr_name = super().__getattribute__(leaf_attr_name.upper())
69+
attr = getattr(config, mapped_attr_name, None)
70+
except AttributeError:
71+
pass
72+
6573
# If the attribute was not specified manually, try to fallback on the attribute_map.
6674
if attr is None:
6775
attribute_map = getattr(self.config, "attribute_map", {})
@@ -134,9 +142,40 @@ def __getattr__(self, attr_name):
134142
return super().__getattr__(attr_name)
135143

136144

137-
Pix2StructNormalizedTextConfig = NormalizedTextAndVisionConfig.with_args(
138-
text_config="text_config", vision_config="vision_config"
139-
)
145+
def create_normalized_text_and_vision_config(
    text_config_cls: type[NormalizedTextConfig] = NormalizedTextConfig,
    vision_config_cls: type[NormalizedVisionConfig] = NormalizedVisionConfig,
) -> type[NormalizedConfig]:
    """
    Build a normalized config class for a model that has both a text and a vision sub-config.

    The returned class inherits from both `text_config_cls` and `vision_config_cls`, so custom
    parameter renamings declared on either of them apply to the matching nested config.

    Example usage:
    >>> MyNormalizedTextAndVisionConfigWithGQA = create_normalized_text_and_vision_config(
    ...     text_config_cls=NormalizedTextConfigWithGQA
    ... ).with_args(text_config="text_config", vision_config="vision_config")

    Args:
        text_config_cls ([`type[NormalizedTextConfig]`]):
            Normalized configuration class to use for the text config.

        vision_config_cls ([`type[NormalizedVisionConfig]`]):
            Normalized configuration class to use for the vision config.

    Returns:
        `type[NormalizedConfig]`: A new class whose `TEXT_CONFIG` and `VISION_CONFIG` are `None`
        until they are set (e.g. through `with_args`, as in the example above).
    """

    class CustomNormalizedTextAndVisionConfig(text_config_cls, vision_config_cls):
        TEXT_CONFIG = None
        VISION_CONFIG = None

        def __getattr__(self, attr_name):
            upper_name = attr_name.upper()
            # Order matters: the text config takes precedence over the vision config when
            # both classes declare the same upper-cased attribute.
            for prefix, owner_cls in (
                (self.TEXT_CONFIG, text_config_cls),
                (self.VISION_CONFIG, vision_config_cls),
            ):
                if prefix is not None and upper_name in dir(owner_cls):
                    # Route the lookup into the nested config via a dotted path.
                    attr_name = f"{prefix}.{attr_name}"
                    break
            return super().__getattr__(attr_name)

    return CustomNormalizedTextAndVisionConfig
140179

141180

142181
class NormalizedEncoderDecoderConfig(NormalizedConfig):
@@ -161,7 +200,6 @@ def __getattr__(self, attr_name):
161200
num_attention_heads="encoder_attention_heads",
162201
hidden_size="d_model",
163202
)
164-
165203
GPT2LikeNormalizedTextConfig = NormalizedTextConfig.with_args(num_attention_heads="n_head", hidden_size="n_embd")
166204
T5LikeNormalizedTextConfig = NormalizedTextConfig.with_args(
167205
num_attention_heads="num_heads",
@@ -173,23 +211,32 @@ def __getattr__(self, attr_name):
173211
GPTBigCodeNormalizedTextConfig = NormalizedTextConfig.with_args(
174212
num_attention_heads="n_head", hidden_size="n_embd", num_layers="n_layer"
175213
)
176-
177214
WhisperLikeNormalizedTextConfig = NormalizedTextConfig.with_args(
178215
hidden_size="d_model",
179216
)
180-
181217
TrOCRLikeNormalizedTextConfig = NormalizedTextConfig.with_args(
182218
num_layers="decoder_layers",
183219
num_attention_heads="decoder_attention_heads",
184220
hidden_size="hidden_size",
185221
)
186-
187222
SpeechToTextLikeNormalizedTextConfig = NormalizedSeq2SeqConfig.with_args(
188223
decoder_num_layers="decoder_layers",
189224
num_layers="decoder_layers",
190225
input_features_per_channel="input_feat_per_channel",
191226
allow_new=True,
192227
)
228+
Pix2StructNormalizedTextConfig = NormalizedTextAndVisionConfig.with_args(
229+
text_config="text_config", vision_config="vision_config"
230+
)
231+
232+
233+
class Gemma3NormalizedTextConfigWithGQA(NormalizedTextConfigWithGQA):
    """
    Normalized text config (with GQA) for Gemma3.

    Maps the normalized `head_dim` attribute to the dotted path `text_config.head_dim`.
    NOTE(review): presumably the dotted path is resolved against the nested text config by the
    base class `__getattr__` — confirm against `NormalizedConfig`.
    """

    HEAD_DIM = "text_config.head_dim"
235+
236+
237+
Gemma3NormalizedTextAndVisionConfig = create_normalized_text_and_vision_config(
238+
text_config_cls=Gemma3NormalizedTextConfigWithGQA
239+
).with_args(text_config="text_config", vision_config="vision_config")
193240

194241

195242
class NormalizedConfigManager:
@@ -253,11 +300,13 @@ class NormalizedConfigManager:
253300
"electra": NormalizedTextConfig,
254301
"encoder-decoder": NormalizedEncoderDecoderConfig,
255302
"gemma": NormalizedTextConfigWithGQA,
303+
"gemma3": Gemma3NormalizedTextAndVisionConfig,
256304
"gpt2": GPT2LikeNormalizedTextConfig,
257305
"gpt_bigcode": GPTBigCodeNormalizedTextConfig,
258306
"gpt_neo": NormalizedTextConfig.with_args(num_attention_heads="num_heads"),
259307
"gpt_neox": NormalizedTextConfig,
260308
"gptj": GPT2LikeNormalizedTextConfig,
309+
"granite": NormalizedTextConfigWithGQA,
261310
"imagegpt": GPT2LikeNormalizedTextConfig,
262311
"internlm2": NormalizedTextConfigWithGQA,
263312
"llama": NormalizedTextConfigWithGQA,
@@ -298,7 +347,6 @@ class NormalizedConfigManager:
298347
"qwen3": NormalizedTextConfig,
299348
"qwen3_moe": NormalizedTextConfig,
300349
"smollm3": NormalizedTextConfig,
301-
"granite": NormalizedTextConfigWithGQA,
302350
}
303351

304352
@classmethod

0 commit comments

Comments
 (0)