Skip to content
1 change: 1 addition & 0 deletions docs/source/openvino/models.mdx
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,7 @@ Here is the list of the supported architectures :
- MiniCPM-o
- MiniCPMV
- Mistral
- Mistral 3
- Mixtral
- MobileBert
- MobileNet v1
Expand Down
126 changes: 126 additions & 0 deletions optimum/exporters/openvino/model_configs.py
Original file line number Diff line number Diff line change
Expand Up @@ -175,6 +175,8 @@
MiniCPMModelPatcher,
MiniCPMVImageEmbeddingsModelPatcher,
MiniCPMVResamplerModelPatcher,
Mistral3ImageEmbeddingModelPatcher,
Mistral3MultiModalProjectorPatcher,
MistralModelPatcher,
MixtralModelPatcher,
MPTModelPatcher,
Expand Down Expand Up @@ -2054,6 +2056,130 @@ def patch_model_for_export(self, model: PreTrainedModel, model_kwargs: Optional[
return LlavaNextVideoImageEmbeddingModelPatcher(self, model, model_kwargs)


class Mistral3ConfigBehavior(str, enum.Enum):
    """Submodels the Mistral 3 VLM is split into for OpenVINO export."""

    LANGUAGE = "language"
    # VISION_EMBEDDINGS extracts visual features and applies projector.norm().
    # Combined with the cycle block
    # (https://github.com/huggingface/transformers/blob/v5.2.0/src/transformers/models/mistral3/modeling_mistral3.py#L76-L94)
    # and MULTI_MODAL_PROJECTOR, this is equivalent to get_image_features
    # (https://github.com/huggingface/transformers/blob/v5.2.0/src/transformers/models/mistral3/modeling_mistral3.py#L223-L248).
    VISION_EMBEDDINGS = "vision_embeddings"
    TEXT_EMBEDDINGS = "text_embeddings"
    MULTI_MODAL_PROJECTOR = "multi_modal_projector"


class DummyMistral3MultiModalProjectorInputGenerator(DummyLLavaMultiModalProjectorInputGenerator):
    """Generates dummy ``image_features`` inputs for exporting the Mistral 3 multi-modal projector."""

    def __init__(
        self,
        task: str,
        normalized_config: NormalizedVisionConfig,
        batch_size: int = DEFAULT_DUMMY_SHAPES["batch_size"],
        random_batch_size_range: Optional[Tuple[int, int]] = None,
        **kwargs,
    ):
        super().__init__(task, normalized_config, batch_size, random_batch_size_range, **kwargs)
        # Prefer the value from the wrapped config, fall back to the normalized
        # config itself, and finally to the Mistral 3 default of 2.
        merge_size = getattr(
            normalized_config.config, "spatial_merge_size", getattr(normalized_config, "spatial_merge_size", 2)
        )
        self.spatial_merge_size = merge_size
        self.num_merged_patches = self.num_patches // (merge_size**2)

    def generate(
        self,
        input_name: str,
        framework: str = "pt",
        int_dtype: str = "int64",
        float_dtype: str = "fp32",
    ):
        # After the patch-merger unfold, each merged token concatenates
        # spatial_merge_size**2 patch embeddings along the channel axis.
        merged_dim = self.hidden_size * self.spatial_merge_size**2
        return self.random_float_tensor([self.num_merged_patches, merged_dim], framework=framework, dtype=float_dtype)


class Mistral3MultiModalProjectorOpenVINOConfig(OnnxConfig):
    """Export config for the Mistral 3 multi-modal projector submodel.

    The projector is exported without its ``norm`` (applied in the vision
    embeddings submodel) and without the patch-merger unfold loop (run in
    PyTorch at inference time).
    """

    DUMMY_INPUT_GENERATOR_CLASSES = (DummyMistral3MultiModalProjectorInputGenerator,)
    NORMALIZED_CONFIG_CLASS = NormalizedVisionConfig
    _MODEL_PATCHER = Mistral3MultiModalProjectorPatcher

    @property
    def inputs(self) -> Dict[str, Dict[int, str]]:
        # The patch-merger cycle flattens all images' patches into a single 2D
        # tensor [total_merged_patches, dim], so the projector input has only
        # one dynamic dimension.
        return {"image_features": {0: "num_patches"}}

    @property
    def outputs(self) -> Dict[str, Dict[int, str]]:
        return {"hidden_states": {0: "num_patches"}}


@register_in_tasks_manager("mistral3", *["image-text-to-text"], library_name="transformers")
class Mistral3OpenVINOConfig(BaseVLMOpenVINOConfig):
    """OpenVINO export configuration for Mistral 3 vision-language models.

    The model is split into the submodels listed in ``Mistral3ConfigBehavior``.
    The patch-merger unfold loop is deliberately kept out of the exported
    graphs (it cannot be traced to OpenVINO IR) and runs in PyTorch at
    inference time.
    """

    MIN_TRANSFORMERS_VERSION = "4.50.0"
    SUPPORTED_BEHAVIORS = [model_type.value for model_type in Mistral3ConfigBehavior]

    def __init__(
        self,
        config: "PretrainedConfig",
        task: str = "feature-extraction",
        int_dtype: str = "int64",
        float_dtype: str = "fp32",
        behavior: VLMConfigBehavior = VLMConfigBehavior.VISION_EMBEDDINGS,
        preprocessors: Optional[List[Any]] = None,
        **kwargs,
    ):
        # NOTE(review): `behavior` is accepted here but not forwarded to
        # super().__init__() — confirm the base class establishes the intended
        # self._behavior without it.
        super().__init__(
            config=config,
            task=task,
            int_dtype=int_dtype,
            float_dtype=float_dtype,
            preprocessors=preprocessors,
        )
        # Keep the full multimodal config so with_behavior() can derive
        # per-submodel configs from it after self._config is narrowed below.
        self._orig_config = config
        if self._behavior == VLMConfigBehavior.VISION_EMBEDDINGS and hasattr(config, "vision_config"):
            # The vision tower is exported against its own sub-config.
            self._config = config.vision_config
            self._normalized_config = self.NORMALIZED_CONFIG_CLASS(self._config)

    def with_behavior(
        self,
        behavior: Union[str, Mistral3ConfigBehavior],
    ):
        """Return an export config for the requested submodel behavior."""
        if isinstance(behavior, str) and not isinstance(behavior, Mistral3ConfigBehavior):
            behavior = Mistral3ConfigBehavior(behavior)

        # The projector gets its dedicated config; every other behavior is
        # handled by the base VLM implementation.
        if behavior == Mistral3ConfigBehavior.MULTI_MODAL_PROJECTOR:
            return Mistral3MultiModalProjectorOpenVINOConfig(
                self._orig_config.vision_config,
                task="feature-extraction",
                int_dtype=self.int_dtype,
                float_dtype=self.float_dtype,
            )

        return super().with_behavior(behavior)

    def get_model_for_behavior(self, model, behavior: Union[str, Mistral3ConfigBehavior]):
        """Select the submodule of `model` to export for `behavior`."""
        if isinstance(behavior, str) and not isinstance(behavior, Mistral3ConfigBehavior):
            behavior = Mistral3ConfigBehavior(behavior)

        if behavior == Mistral3ConfigBehavior.MULTI_MODAL_PROJECTOR:
            # The projector may hang off the top-level model or off
            # model.model; try both.
            return (
                model.multi_modal_projector
                if hasattr(model, "multi_modal_projector")
                else model.model.multi_modal_projector
            )

        return super().get_model_for_behavior(model, behavior)

    def patch_model_for_export(self, model: PreTrainedModel, model_kwargs: Optional[Dict[str, Any]] = None):
        """Patch `model` for export; vision embeddings get a custom forward."""
        model_kwargs = model_kwargs or {}

        if self._behavior != VLMConfigBehavior.VISION_EMBEDDINGS:
            return super().patch_model_for_export(model, model_kwargs)

        # The vision submodel needs a forward that stops after
        # projector.norm() (see mistral3_vision_embed_forward).
        return Mistral3ImageEmbeddingModelPatcher(self, model, model_kwargs)

    def generate_dummy_inputs(self, framework: str = "pt", **kwargs) -> Dict:
        # Pin the dummy batch to 1 for pixtral vision towers; downstream the
        # per-image patches are flattened into a single patch dimension.
        if self._behavior == VLMConfigBehavior.VISION_EMBEDDINGS and self._config.model_type == "pixtral":
            kwargs["batch_size"] = 1
        return super().generate_dummy_inputs(framework, **kwargs)


@register_in_tasks_manager(
"maira2", *["image-text-to-text", "text-generation", "text-generation-with-past"], library_name="transformers"
)
Expand Down
67 changes: 64 additions & 3 deletions optimum/exporters/openvino/model_patcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -2248,7 +2248,7 @@ def _persimmon_self_attn_sdpa_forward(
fused_qkv = self.query_key_value(hidden_states)

# 3 x [batch_size, seq_length, num_heads, head_dim]
(query_states, key_states, value_states) = self._split_heads(fused_qkv)
query_states, key_states, value_states = self._split_heads(fused_qkv)

if self.qk_layernorm:
query_states = self.q_layernorm(query_states)
Expand Down Expand Up @@ -3309,6 +3309,67 @@ def __exit__(self, exc_type, exc_value, traceback):
self._model.forward = self._model.__orig_forward


# Adopted from https://github.com/huggingface/transformers/blob/v5.2.0/src/transformers/models/mistral3/modeling_mistral3.py#L223-L248
# Mistral3Model.get_image_features() with only projector.norm() applied instead of the full projector forward,
# because the patch_merger cycle block (unfold loop) cannot be traced to OpenVINO IR.
def mistral3_vision_embed_forward(self, pixel_values):
    """Run the vision tower, select the configured hidden state(s), and apply only the projector norm."""
    tower_out = self.vision_tower(pixel_values, output_hidden_states=True)
    layer_spec = self.config.vision_feature_layer
    if isinstance(layer_spec, int):
        # A single layer index selects one hidden state.
        selected = tower_out.hidden_states[layer_spec]
    else:
        # A list of indices concatenates the chosen hidden states channel-wise.
        selected = torch.cat([tower_out.hidden_states[idx] for idx in layer_spec], dim=-1)
    # Drop the batch dim and normalize; the rest of the projector runs in the
    # separate multi_modal_projector submodel.
    return self.multi_modal_projector.norm(selected.squeeze(0))


# Adopted from https://github.com/huggingface/transformers/blob/v5.2.0/src/transformers/models/mistral3/modeling_mistral3.py#L76-L94
# and https://github.com/huggingface/transformers/blob/v5.2.0/src/transformers/models/mistral3/modeling_mistral3.py#L118-L124
# Mistral3MultiModalProjector.forward() and Mistral3PatchMerger.forward() with norm and cycle block excluded:
# the norm is applied in the vision embeddings submodel, and the unfold cycle runs in PyTorch at runtime.
def mistral3_multi_modal_projector_forward(self, image_features):
    """Project pre-merged patch features: merging layer, then the two-layer MLP."""
    merged = self.patch_merger.merging_layer(image_features)
    return self.linear_2(self.act(self.linear_1(merged)))


class Mistral3ImageEmbeddingModelPatcher(ModelPatcher):
    """Temporarily swaps the vision model's forward for
    ``mistral3_vision_embed_forward`` for the duration of the export."""

    def __init__(
        self,
        config: "OnnxConfig",
        model: "PreTrainedModel",
        model_kwargs: Dict[str, Any],
    ):
        # Stash the original forward so __exit__ can restore it. Both the
        # write here and the read in __exit__ are name-mangled to the same
        # class-private attribute, so the round-trip is consistent.
        model.__orig_forward = model.forward
        model.forward = types.MethodType(mistral3_vision_embed_forward, model)

        super().__init__(config, model, model_kwargs)

    def __exit__(self, exc_type, exc_value, traceback):
        # Restore the original forward after the export context exits.
        super().__exit__(exc_type, exc_value, traceback)
        self._model.forward = self._model.__orig_forward


class Mistral3MultiModalProjectorPatcher(ModelPatcher):
    """Temporarily swaps the projector's forward for
    ``mistral3_multi_modal_projector_forward`` for the duration of the export."""

    def __init__(
        self,
        config: "OnnxConfig",
        model: "PreTrainedModel",
        model_kwargs: Dict[str, Any],
    ):
        # Stash the original forward so __exit__ can restore it (the
        # name-mangled attribute is written and read within this class only).
        model.__orig_forward = model.forward
        model.forward = types.MethodType(mistral3_multi_modal_projector_forward, model)

        super().__init__(config, model, model_kwargs)

    def __exit__(self, exc_type, exc_value, traceback):
        # Restore the original forward after the export context exits.
        super().__exit__(exc_type, exc_value, traceback)
        self._model.forward = self._model.__orig_forward


def _embednb_forward(self, ids: torch.Tensor) -> torch.Tensor:
def rope(pos: torch.Tensor, dim: int, theta: int) -> torch.Tensor:
assert dim % 2 == 0, "The dimension must be even."
Expand Down Expand Up @@ -6278,8 +6339,8 @@ def __enter__(self):
if is_transformers_version(">=", "4.56"):
# openvino is not able to trace through the new chunked_overlay with left_padding
self.original_chunked_overlay = transformers.masking_utils.chunked_overlay
transformers.masking_utils.chunked_overlay = (
lambda chunk_size, left_padding: transformers.masking_utils._legacy_chunked_overlay(chunk_size)
transformers.masking_utils.chunked_overlay = lambda chunk_size, left_padding: (
transformers.masking_utils._legacy_chunked_overlay(chunk_size)
)

def __exit__(self, exc_type, exc_value, traceback):
Expand Down
1 change: 1 addition & 0 deletions optimum/exporters/openvino/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -288,6 +288,7 @@ def get_submodels(model):
"llava_next",
"llava_next_video",
"llava-qwen2",
"mistral3",
"internvl_chat",
"maira2",
"minicpmv",
Expand Down
86 changes: 86 additions & 0 deletions optimum/intel/openvino/modeling_visual_language.py
Original file line number Diff line number Diff line change
Expand Up @@ -1699,6 +1699,91 @@ def get_video_features(self, pixel_values, input_ids=None, **kwargs):
return video_features


class _OVMistral3ForCausalLM(OVModelForVisualCausalLM):
    """OpenVINO inference wrapper for Mistral 3 vision-language models.

    The exported vision submodel stops after ``projector.norm()``; the
    patch-merger unfold below runs in PyTorch, and the separately exported
    ``multi_modal_projector`` submodel completes the projection.
    """

    additional_parts = ["multi_modal_projector"]

    def get_vision_embeddings(self, pixel_values, input_ids=None, image_sizes=None, **kwargs):
        """Compute projected image features for the given pixel values."""
        # During autoregressive decoding (a single new token) there is nothing
        # visual to recompute.
        if input_ids is not None and input_ids.shape[1] == 1:
            return None

        image_features = self.vision_embeddings(pixel_values).last_hidden_state
        image_features = torch.from_numpy(image_features) if isinstance(image_features, np.ndarray) else image_features

        # Adopted from https://github.com/huggingface/transformers/blob/v5.2.0/src/transformers/models/mistral3/modeling_mistral3.py#L75-L96
        patch_size = self.config.vision_config.patch_size
        spatial_merge_size = self.config.spatial_merge_size
        d = image_features.shape[-1]

        # NOTE(review): assumes `image_sizes` is always provided alongside
        # pixel_values (it is iterated unconditionally) — confirm the caller
        # guarantees this.
        image_sizes_scaled = [(size[0] // patch_size, size[1] // patch_size) for size in image_sizes]
        tokens_per_image = [h * w for h, w in image_sizes_scaled]

        # Re-group each image's flat patch sequence into its (h, w) grid and
        # unfold it so every merged token concatenates the embeddings of a
        # spatial_merge_size x spatial_merge_size patch window.
        permuted_tensor = []
        for image_index, image_tokens in enumerate(image_features.split(tokens_per_image)):
            h, w = image_sizes_scaled[image_index]
            image_grid = image_tokens.view(h, w, d).permute(2, 0, 1).unsqueeze(0)
            grid = torch.nn.functional.unfold(
                image_grid,
                kernel_size=spatial_merge_size,
                stride=spatial_merge_size,
            )
            grid = grid.view(d * spatial_merge_size**2, -1).t()
            permuted_tensor.append(grid)

        # All images' merged patches are flattened into one 2D tensor before
        # running the exported projector submodel.
        image_features = torch.cat(permuted_tensor, dim=0)
        image_features = self.multi_modal_projector(image_features)

        return image_features

    # Adopted from https://github.com/huggingface/transformers/blob/v5.2.0/src/transformers/models/mistral3/modeling_mistral3.py#L258-L280
    # and https://github.com/huggingface/transformers/blob/v5.2.0/src/transformers/models/mistral3/modeling_mistral3.py#L313-L324
    def merge_vision_text_embeddings(
        self,
        vision_embeds,
        inputs_embeds,
        input_ids=None,
        attention_mask=None,
        position_ids=None,
        **kwargs,
    ):
        """Scatter image features into the text embeddings at image-token positions."""
        image_features = torch.from_numpy(vision_embeds) if isinstance(vision_embeds, np.ndarray) else vision_embeds
        inputs_embeds = torch.from_numpy(inputs_embeds) if isinstance(inputs_embeds, np.ndarray) else inputs_embeds

        # Positions holding the image placeholder token receive the visual
        # features; all other embeddings are left untouched.
        special_image_mask = (input_ids == self.config.image_token_id).unsqueeze(-1)
        special_image_mask = special_image_mask.expand_as(inputs_embeds).to(inputs_embeds.device)
        image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype)
        inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features)

        return inputs_embeds, attention_mask, position_ids

    @staticmethod
    def preprocess_inputs(
        text: str,
        image: Optional["Image"] = None,
        processor: Optional[AutoImageProcessor] = None,
        tokenizer: Optional[PreTrainedTokenizer] = None,
        config: Optional[PretrainedConfig] = None,
        video: Optional["VideoInput"] = None,
        audio: Optional[np.ndarray] = None,
    ):
        """Build model inputs from raw text (and an optional image) via the processor's chat template."""
        if processor is None:
            raise ValueError("Processor is required.")
        if video is not None or audio is not None:
            raise ValueError("Video/Audio input is not supported for Mistral3")

        conversation = [
            {
                "role": "user",
                "content": [{"type": "text", "text": text}],
            }
        ]
        # The image placeholder must precede the text in the user turn.
        if image is not None:
            conversation[0]["content"].insert(0, {"type": "image"})

        prompt = processor.apply_chat_template(conversation, add_generation_prompt=True)
        inputs = processor(images=image, text=prompt, return_tensors="pt")
        return inputs


class _OVInternVLForCausalLM(OVModelForVisualCausalLM):
def get_vision_embeddings(self, pixel_values, input_ids=None, **kwargs):
if input_ids is not None and input_ids.shape[1] == 1:
Expand Down Expand Up @@ -4806,6 +4891,7 @@ def preprocess_inputs(
"llava": _OVLlavaForCausalLM,
"llava_next": _OVLlavaNextForCausalLM,
"llava_next_video": _OVLlavaNextVideoForCausalLM,
"mistral3": _OVMistral3ForCausalLM,
"minicpmv": _OVMiniCPMVForCausalLM,
"llava-qwen2": _OVNanoLlavaForCausalLM,
"maira2": _OVMaira2ForCausalLM,
Expand Down
3 changes: 3 additions & 0 deletions tests/openvino/test_export.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,9 @@ class ExportModelTest(unittest.TestCase):
if is_transformers_version(">=", "4.49"):
SUPPORTED_ARCHITECTURES.update({"zamba2": OVModelForCausalLM})

if is_transformers_version(">=", "4.50.0"):
SUPPORTED_ARCHITECTURES.update({"mistral3": OVModelForVisualCausalLM})

if is_transformers_version(">=", "4.53.0"):
SUPPORTED_ARCHITECTURES.update({"granitemoehybrid": OVModelForCausalLM})

Expand Down
Loading