Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
124 changes: 124 additions & 0 deletions optimum/exporters/openvino/model_configs.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,6 +148,7 @@
FluxTransfromerModelPatcher,
Gemma2ModelPatcher,
Gemma3LMModelPatcher,
Gemma3nLMModelPatcher,
GptJModelPatcher,
GptNeoModelPatcher,
GptNeoxModelPatcher,
Expand Down Expand Up @@ -261,6 +262,10 @@ def init_model_configs():
"transformers",
"Gemma3ForConditionalGeneration",
)
TasksManager._CUSTOM_CLASSES[("pt", "gemma3n", "image-text-to-text")] = (
"transformers",
"Gemma3nForConditionalGeneration",
)
TasksManager._CUSTOM_CLASSES[("pt", "idefics3", "image-text-to-text")] = (
"transformers",
"AutoModelForImageTextToText",
Expand Down Expand Up @@ -1465,6 +1470,99 @@ class Gemma3TextOpenVINOConfig(Gemma2OpenVINOConfig):
MIN_TRANSFORMERS_VERSION = "4.50.0"


class Gemma3nDummyPastKeyValuesGenerator(DummyPastKeyValuesGenerator):
    """Dummy past-key-values generator for Gemma3n text models.

    Gemma3n mixes full-attention and sliding-attention layers, and its trailing
    ``num_kv_shared_layers`` layers reuse a shared KV-cache instead of producing
    their own. This generator therefore emits one (key, value) tensor pair per
    cache-owning layer only, sizing sliding-attention layers to the sliding
    window and full-attention layers to the requested sequence length.
    """

    def __init__(
        self,
        task: str,
        normalized_config: NormalizedTextConfig,
        batch_size: int = DEFAULT_DUMMY_SHAPES["batch_size"],
        sequence_length: int = DEFAULT_DUMMY_SHAPES["sequence_length"],
        random_batch_size_range: Optional[Tuple[int, int]] = None,
        random_sequence_length_range: Optional[Tuple[int, int]] = None,
        **kwargs,
    ):
        super().__init__(
            task=task,
            normalized_config=normalized_config,
            batch_size=batch_size,
            sequence_length=sequence_length,
            random_batch_size_range=random_batch_size_range,
            random_sequence_length_range=random_sequence_length_range,
        )
        # Cache the Gemma3n-specific geometry needed to shape each layer's KV pair.
        self.num_key_value_heads = normalized_config.num_key_value_heads
        self.head_dim = normalized_config.head_dim
        self.layer_types = normalized_config.config.layer_types
        self.num_kv_shared_layers = normalized_config.config.num_kv_shared_layers
        self.sliding_window = normalized_config.config.sliding_window

    def generate(self, input_name: str, framework: str = "pt", int_dtype: str = "int64", float_dtype: str = "fp32"):
        """Return a list of random (key, value) pairs, one per cache-owning layer."""
        # The trailing shared layers do not produce their own KV-cache, so they
        # get no dummy tensors at all.
        cache_owning_layers = self.layer_types[: -self.num_kv_shared_layers]

        past_kv_values = []
        for layer_type in cache_owning_layers:
            # Sliding-attention layers cache at most `sliding_window` positions;
            # all other layers cache the full sequence length.
            kv_length = self.sliding_window if layer_type == "sliding_attention" else self.sequence_length
            shape = (self.batch_size, self.num_key_value_heads, kv_length, self.head_dim)
            key = self.random_float_tensor(shape, framework=framework, dtype=float_dtype)
            value = self.random_float_tensor(shape, framework=framework, dtype=float_dtype)
            past_kv_values.append((key, value))

        return past_kv_values


@register_in_tasks_manager(
    "gemma3n_text",
    *[
        "feature-extraction",
        "feature-extraction-with-past",
        "text-generation",
        "text-generation-with-past",
        "text-classification",
    ],
    library_name="transformers",
)
class Gemma3nTextOpenVINOConfig(Gemma3TextOpenVINOConfig):
    """OpenVINO export config for the Gemma3n text model.

    Gemma3n's trailing ``num_kv_shared_layers`` layers reuse an earlier layer's
    KV-cache, so past-key-values tensors are only declared for the layers that
    own a cache.
    """

    NORMALIZED_CONFIG_CLASS = NormalizedTextConfig
    DUMMY_INPUT_GENERATOR_CLASSES = (DummyTextInputGenerator, Gemma3nDummyPastKeyValuesGenerator)
    DUMMY_PKV_GENERATOR_CLASS = Gemma3nDummyPastKeyValuesGenerator
    MIN_TRANSFORMERS_VERSION = "4.50.0"

    def add_past_key_values(self, inputs_or_outputs: dict[str, dict[int, str]], direction: str):
        """Register past-KV tensor names and their dynamic axes for `direction`.

        Args:
            inputs_or_outputs: mapping updated in place with one key/value entry
                per cache-owning layer.
            direction: "inputs" (past_key_values) or "outputs" (present).

        Raises:
            ValueError: if `direction` is neither "inputs" nor "outputs".
        """
        if direction not in ["inputs", "outputs"]:
            raise ValueError(f'direction must either be "inputs" or "outputs", but {direction} was given')

        if direction == "inputs":
            decoder_sequence_name = "past_sequence_length"
            name = "past_key_values"
        else:
            decoder_sequence_name = "past_sequence_length + sequence_length"
            name = "present"

        # Layers past this count share an earlier layer's KV-cache and get no
        # tensors of their own.
        num_kv_shared_layers = self._normalized_config.config.num_kv_shared_layers
        num_cache_layers = len(self._normalized_config.config.layer_types) - num_kv_shared_layers

        # NOTE(review): the original code branched on layer_type ==
        # "sliding_attention" here, but both branches were byte-identical, so
        # the dead conditional is collapsed. If sliding-attention layers ever
        # need a distinct dynamic-axis name (their cache length is capped at
        # sliding_window), reintroduce the branch over layer_types.
        for i in range(num_cache_layers):
            inputs_or_outputs[f"{name}.{i}.key"] = {0: "batch_size", 2: decoder_sequence_name}
            inputs_or_outputs[f"{name}.{i}.value"] = {0: "batch_size", 2: decoder_sequence_name}


class DeciDummyPastKeyValuesGenerator(DummyPastKeyValuesGenerator):
def __init__(
self,
Expand Down Expand Up @@ -4155,6 +4253,32 @@ def with_behavior(
return super().with_behavior(behavior)


@register_in_tasks_manager("gemma3n", *["image-text-to-text"], library_name="transformers")
class Gemma3nOpenVINOConfig(Gemma3OpenVINOConfig):
    def with_behavior(self, behavior: Union[str, VLMConfigBehavior]):
        """
        Return the export config for the requested sub-model behavior.

        The LANGUAGE behavior is special-cased: it builds a text-generation
        config for the "gemma3n_text" model type (rather than relying on the
        underlying ``text_config.model_type`` value) and attaches the
        Gemma3n-specific model patcher. All other behaviors are delegated to
        the parent class.
        """
        if isinstance(behavior, str) and not isinstance(behavior, VLMConfigBehavior):
            behavior = VLMConfigBehavior(behavior)

        if behavior != VLMConfigBehavior.LANGUAGE:
            return super().with_behavior(behavior)

        # Force the Gemma3n-specific text model type so the language sub-model
        # picks up the dedicated gemma3n_text export config.
        return get_vlm_text_generation_config(
            "gemma3n_text",
            self._orig_config.text_config,
            self.int_dtype,
            self.float_dtype,
            model_patcher=Gemma3nLMModelPatcher,
            inputs_update={"token_type_ids": {0: "batch_size", 1: "sequence_length"}},
        )


class DummyVisionPositionIdsInputGenerator(DummyVisionInputGenerator):
SUPPORTED_INPUT_NAMES = ("patch_attention_mask", "patch_position_ids")

Expand Down
84 changes: 84 additions & 0 deletions optimum/exporters/openvino/model_patcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -4745,6 +4745,90 @@ def __exit__(self, exc_type, exc_value, traceback):
del self._model.model._orig_update_causual_mask


def _project_per_layer_inputs(
    self,
    inputs_embeds: torch.Tensor,
    per_layer_inputs: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """Project `inputs_embeds` into per-layer inputs and merge with `per_layer_inputs`.

    Patched replacement bound onto the Gemma3n language model in place of its
    `project_per_layer_inputs` method (see Gemma3nLMModelPatcher).

    Projects the embeddings, scales and reshapes them to one slice per hidden
    layer, normalizes, then (if `per_layer_inputs` is given) adds the provided
    per-layer inputs and rescales. Returns a tensor of shape
    (*inputs_embeds.shape[:-1], num_hidden_layers, hidden_size_per_layer_input).
    """
    per_layer_projection: torch.Tensor = self.per_layer_model_projection(inputs_embeds)
    # Scale is cast to the embeddings' dtype/device so the multiply does not
    # introduce a dtype mismatch during tracing.
    per_layer_projection = (
        self.per_layer_projection_scale.to(dtype=inputs_embeds.dtype, device=per_layer_projection.device)
        * per_layer_projection
    )

    # Split the projected hidden dimension into one slice per decoder layer.
    per_layer_projection = per_layer_projection.reshape(
        *inputs_embeds.shape[:-1],
        self.config.num_hidden_layers,
        self.hidden_size_per_layer_input,
    )
    per_layer_projection = self.per_layer_projection_norm(per_layer_projection)

    if per_layer_inputs is None:
        return per_layer_projection

    if per_layer_projection.shape != per_layer_inputs.shape:
        # per-layer inputs are sometimes padded with zeros, slice the relevant embeddings.
        per_layer_inputs = per_layer_inputs[..., : self.config.num_hidden_layers, :]

    return (per_layer_projection + per_layer_inputs) * self.per_layer_input_scale.to(
        dtype=inputs_embeds.dtype, device=per_layer_projection.device
    )


def _gaussian_topk(self, inputs: torch.Tensor) -> torch.Tensor:
    """Apply gaussian top-k activation sparsity to `inputs` along the last dim.

    Patched replacement bound onto each decoder MLP in place of its
    `_gaussian_topk` method (see Gemma3nLMModelPatcher). Computes a per-row
    cutoff ``mean + std * Phi^{-1}(activation_sparsity)`` and zeroes every
    activation below it via ReLU, so roughly `activation_sparsity` of the
    values are suppressed (assuming approximately normal activations).
    """
    target_sparsity_tensor = torch.tensor(self.activation_sparsity, dtype=torch.float32, device=inputs.device)

    def normal_icdf_approx(p):
        # Rational polynomial approximation of the standard normal inverse CDF
        # (central-region coefficients of Acklam's algorithm), used instead of
        # torch.distributions / erfinv so the graph stays trace-friendly.
        # NOTE(review): only the central-region formula is implemented; the
        # clamp below keeps p finite, but accuracy degrades for p outside
        # roughly (0.025, 0.975) — presumably activation_sparsity stays in the
        # central region, verify against the model config.
        p = torch.clamp(p, 1e-7, 1 - 1e-7)
        a1 = -3.969683028665376e01
        a2 = 2.209460984245205e02
        a3 = -2.759285104469687e02
        a4 = 1.383577518672690e02
        a5 = -3.066479806614716e01
        a6 = 2.506628277459239e00
        b1 = -5.447609879822406e01
        b2 = 1.615858368580409e02
        b3 = -1.556989798598866e02
        b4 = 6.680131188771972e01
        b5 = -1.328068155288572e01
        q = p - 0.5
        r = q * q
        num = (((((a1 * r + a2) * r + a3) * r + a4) * r + a5) * r + a6) * q
        den = ((((b1 * r + b2) * r + b3) * r + b4) * r + b5) * r + 1.0
        return num / den

    std_multiplier = normal_icdf_approx(target_sparsity_tensor)
    std_multiplier = std_multiplier.type(inputs.dtype)
    # Per-row statistics over the last dimension define the sparsity cutoff.
    inputs_mean = torch.mean(inputs, dim=-1, keepdim=True)
    inputs_std = torch.std(inputs, dim=-1, keepdim=True, unbiased=False)
    cutoff_x = inputs_mean + inputs_std * std_multiplier
    # Shift-then-ReLU zeroes everything below the cutoff.
    return nn.functional.relu(inputs - cutoff_x)


class Gemma3nLMModelPatcher(Gemma3LMModelPatcher):
    """Patcher for the Gemma3n language model.

    On top of the Gemma3 patching, temporarily replaces two methods that are
    not export-friendly: the language model's `project_per_layer_inputs` and
    each decoder MLP's `_gaussian_topk`, restoring the originals on exit.
    """

    def __enter__(self):
        super().__enter__()
        language_model = self._model.model.language_model

        # Swap in the trace-friendly per-layer input projection.
        language_model._orig_project_per_layer_inputs = language_model.project_per_layer_inputs
        language_model.project_per_layer_inputs = types.MethodType(_project_per_layer_inputs, language_model)

        # Swap in the trace-friendly gaussian top-k on every decoder MLP.
        for decoder_layer in language_model.layers:
            decoder_layer.mlp._orig_gaussian_topk = decoder_layer.mlp._gaussian_topk
            decoder_layer.mlp._gaussian_topk = types.MethodType(_gaussian_topk, decoder_layer.mlp)

    def __exit__(self, exc_type, exc_value, traceback):
        super().__exit__(exc_type, exc_value, traceback)
        language_model = self._model.model.language_model

        language_model.project_per_layer_inputs = language_model._orig_project_per_layer_inputs
        # Drop the bookkeeping attribute after restoring, consistent with how
        # the other patchers in this module delete their _orig_* attributes.
        del language_model._orig_project_per_layer_inputs

        for decoder_layer in language_model.layers:
            decoder_layer.mlp._gaussian_topk = decoder_layer.mlp._orig_gaussian_topk
            del decoder_layer.mlp._orig_gaussian_topk


class Idefics3ImageEmbeddingsModelPatcher(ModelPatcher):
def __init__(
self,
Expand Down
1 change: 1 addition & 0 deletions optimum/exporters/openvino/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -297,6 +297,7 @@ def get_submodels(model):
"qwen3_vl",
"got_ocr2",
"gemma3",
"gemma3n",
"idefics3",
"smolvlm",
"phi4mm",
Expand Down
6 changes: 5 additions & 1 deletion optimum/intel/openvino/modeling_visual_language.py
Original file line number Diff line number Diff line change
Expand Up @@ -3868,7 +3868,10 @@ def merge_vision_text_embeddings(
self.get_text_embeddings(torch.tensor([[self.config.image_token_index]], dtype=torch.long))[0]
)
else:
special_image_mask = (input_ids == self.config.image_token_index).unsqueeze(-1)
if self.config.model_type == "gemma3n":
special_image_mask = (input_ids == self.config.image_token_id).unsqueeze(-1)
else:
special_image_mask = (input_ids == self.config.image_token_index).unsqueeze(-1)
special_image_mask = special_image_mask.expand_as(inputs_embeds)

image_features = image_features.to(inputs_embeds.dtype)
Expand Down Expand Up @@ -4817,6 +4820,7 @@ def preprocess_inputs(
"qwen2_5_vl_text": _OVQwen2_5_VLForCausalLM,
"got_ocr2": _OVGotOCR2ForCausalLM,
"gemma3": _OVGemma3ForCausalLM,
"gemma3n": _OVGemma3ForCausalLM,
"idefics3": _OVIdefics3ForCausalLM,
"smolvlm": _OVSmolVLForCasualLM,
"phi4mm": _OVPhi4MMForCausalLM,
Expand Down
Loading