
[OpenVINO] Add support for Mistral3 #1627

Open
kyoui-dev wants to merge 14 commits into huggingface:main from kyoui-dev:mistral3

Conversation

@kyoui-dev

@kyoui-dev kyoui-dev commented Mar 2, 2026

What does this PR do?

Conversion command line for mistralai/Mistral-Small-3.1-24B-Instruct-2503:

optimum-cli export openvino -m mistralai/Mistral-Small-3.1-24B-Instruct-2503 ./Mistral-Small-3.1-24B --task image-text-to-text

Inference of mistralai/Mistral-Small-3.1-24B-Instruct-2503 using OpenVINO backend:

from transformers import AutoTokenizer, AutoProcessor
from transformers.image_utils import load_image
from huggingface_hub import hf_hub_download
from optimum.intel.openvino import OVModelForVisualCausalLM


model_dir = "./Mistral-Small-3.1-24B"

tokenizer = AutoTokenizer.from_pretrained(model_dir)
processor = AutoProcessor.from_pretrained(model_dir)
model = OVModelForVisualCausalLM.from_pretrained(model_dir)

# Prepare image input
image_path = hf_hub_download(
                repo_id="raushan-testing-hf/images_test",
                filename="australia.jpg",
                repo_type="dataset",
        )
image_input = load_image(image_path)
question = "Describe this image."
inputs = model.preprocess_inputs(processor=processor, text=question, image=image_input)

# Run inference
output_ids = model.generate(**inputs, max_new_tokens=10)
output_text = tokenizer.decode(output_ids[0])

print(output_text)

Fixes #1338

Before submitting

  • N/A This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you make sure to update the documentation with your changes?
  • Did you write any new necessary tests?

@kyoui-dev
Author

kyoui-dev commented Mar 2, 2026

Hi @popovaan,

Could you please take a look at my PR?

Thank you!

@popovaan
Collaborator

popovaan commented Mar 2, 2026

Please add tests to the PR and use a local path for now, until we have a published tiny model.

Collaborator

@rkazants rkazants left a comment

Please add tests and create a tiny model for it.

@kyoui-dev
Author

kyoui-dev commented Mar 3, 2026

Hi @popovaan and @rkazants,

I've added the tests and created a tiny model locally for now.

Here is the script I used to create the tiny model:

import os
from transformers import (
    AutoConfig,
    AutoModelForImageTextToText,
    AutoProcessor,
    AutoTokenizer,
)

model_id = "mistralai/Mistral-Small-3.1-24B-Instruct-2503"
config = AutoConfig.from_pretrained(model_id)

config.text_config.num_hidden_layers = 2
config.text_config.hidden_size = 8
config.text_config.intermediate_size = 64
config.text_config.num_attention_heads = 8
config.text_config.num_key_value_heads = 4
config.text_config.head_dim = 32

config.vision_config.num_hidden_layers = 2
config.vision_config.hidden_size = 128
config.vision_config.intermediate_size = 64
config.vision_config.num_attention_heads = 4
config.vision_config.head_dim = 32

model = AutoModelForImageTextToText.from_config(config)
tokenizer = AutoTokenizer.from_pretrained(model_id)
processor = AutoProcessor.from_pretrained(model_id)

output_dir = "./tiny-random-mistral3"
os.makedirs(output_dir, exist_ok=True)
model.save_pretrained(output_dir)
tokenizer.save_pretrained(output_dir)
processor.save_pretrained(output_dir)
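As an aside, the config-shrinking pattern this script relies on can be sketched generically. The `shrink` helper below is hypothetical (it is not part of the PR or of transformers) and uses `SimpleNamespace` stand-ins for the nested sub-configs that the real script gets from `AutoConfig`:

```python
# Hypothetical sketch of the "shrink the config" pattern used in the script
# above. SimpleNamespace stands in for an AutoConfig with nested sub-configs;
# the layer/size values are illustrative, not the real model's.
from types import SimpleNamespace


def shrink(config, overrides):
    """Apply {sub_config_name: {field: value}} overrides in place; return config."""
    for sub_name, fields in overrides.items():
        sub = getattr(config, sub_name)
        for field, value in fields.items():
            setattr(sub, field, value)
    return config


config = SimpleNamespace(
    text_config=SimpleNamespace(num_hidden_layers=40, hidden_size=5120),
    vision_config=SimpleNamespace(num_hidden_layers=24, hidden_size=1024),
)
tiny = shrink(config, {
    "text_config": {"num_hidden_layers": 2, "hidden_size": 8},
    "vision_config": {"num_hidden_layers": 2, "hidden_size": 128},
})
print(tiny.text_config.num_hidden_layers, tiny.vision_config.hidden_size)  # prints: 2 128
```

The real script then passes the shrunken config to `AutoModelForImageTextToText.from_config`, which is why the saved model is tiny while keeping the full model's structure.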

I’d appreciate it if you could review the updates.

Contributor

Copilot AI left a comment

Pull request overview

Adds OpenVINO export + inference support for the new mistral3 (Mistral-Small-3.1) visual language model family, integrating it into the OpenVINO VLM export pipeline and test matrices.

Changes:

  • Introduces a Mistral3-specific OVModelForVisualCausalLM implementation and export-time patchers to handle non-traceable vision components.
  • Registers new OpenVINO export configs/behaviors for Mistral3, including a dedicated multi_modal_projector submodel.
  • Extends OpenVINO test coverage + documentation to include the mistral3 architecture (gated by transformers>=4.50.0).

Reviewed changes

Copilot reviewed 11 out of 11 changed files in this pull request and generated 2 comments.

Show a summary per file
File Description
optimum/intel/openvino/modeling_visual_language.py Adds _OVMistral3ForCausalLM runtime logic (vision embeddings + vision/text merge + preprocessing) and registers it in the architecture mapping.
optimum/exporters/openvino/model_patcher.py Adds Mistral3-specific forward patchers to make vision embedding + projector exportable to OV IR.
optimum/exporters/openvino/model_configs.py Registers Mistral3 in TasksManager custom loading, adds Mistral3 OpenVINO config + multi_modal_projector export config and dummy input generator.
optimum/exporters/openvino/utils.py Marks mistral3 as a multi-submodel VLM for OpenVINO export.
tests/openvino/utils_tests.py Adds mistral3 model fixture + expected INT8 node counts for its exported submodels.
tests/openvino/test_seq2seq.py Adds mistral3 to supported visual-causal-lm integration tests for transformers>=4.50.0.
tests/openvino/test_export.py Adds mistral3 to supported export architectures for transformers>=4.50.0.
tests/openvino/test_exporters_cli.py Adds CLI exporter coverage for (image-text-to-text, mistral3) for transformers>=4.50.0.
tests/openvino/test_quantization.py Adds mistral3 to auto-compression architecture list for transformers>=4.50.0.
tests/openvino/test_genai.py Allows mistral3 to be routed through AutoModelForImageTextToText in the GenAI pipeline helper.
docs/source/openvino/models.mdx Documents Mistral3 as a supported architecture.


@popovaan
Collaborator

popovaan commented Mar 4, 2026

> (quoting kyoui-dev's previous comment with the tiny-model creation script)

Thanks for sharing this script. I’ve published this model on Hugging Face, please use it in the tests in your PR.

https://huggingface.co/optimum-intel-internal-testing/tiny-random-mistral3

@kyoui-dev
Author

Thanks! I've updated it. Please let me know if anything else is needed.

Collaborator

@echarlaix echarlaix left a comment

thanks for the addition @kyoui-dev !

Comment on lines +2069 to +2102
class DummyMistral3MultiModalProjectorInputGenerator(DummyInputGenerator):
    SUPPORTED_INPUT_NAMES = ["image_features"]

    def __init__(
        self,
        task: str,
        normalized_config: NormalizedVisionConfig,
        batch_size: int = DEFAULT_DUMMY_SHAPES["batch_size"],
        random_batch_size_range: Optional[Tuple[int, int]] = None,
        **kwargs,
    ):
        self.task = task
        self.batch_size = batch_size
        self.hidden_size = normalized_config.hidden_size
        self.spatial_merge_size = getattr(
            normalized_config.config, "spatial_merge_size",
            getattr(normalized_config, "spatial_merge_size", 2)
        )
        image_size = normalized_config.image_size
        patch_size = normalized_config.patch_size
        patches_per_side = image_size // patch_size
        merged_per_side = patches_per_side // self.spatial_merge_size
        self.num_merged_patches = merged_per_side * merged_per_side

    def generate(
        self,
        input_name: str,
        framework: str = "pt",
        int_dtype: str = "int64",
        float_dtype: str = "fp32",
    ):
        input_dim = self.hidden_size * self.spatial_merge_size ** 2
        shape = [self.num_merged_patches, input_dim]
        return self.random_float_tensor(shape, framework=framework, dtype=float_dtype)
Collaborator

Why not inherit from DummyLLavaMultiModalProjectorInputGenerator and just add the spatial_merge_size attribute? num_merged_patches can be inferred from self.num_patches, no? Also, the resulting input shape here does not vary with batch_size; could you expand on why?

Author

Good points! I'll refactor to inherit from DummyLLavaMultiModalProjectorInputGenerator and derive num_merged_patches from self.num_patches.
Regarding batch_size: in the original transformers implementation, Mistral3PatchMerger.forward() concatenates the patches from all images via torch.cat(permuted_tensor, dim=0) into a flat 2D tensor [total_merged_patches, dim], because each image can have different spatial dimensions and thus produce a different number of merged patches, so they cannot form a uniform batch axis. Since this concatenation loop runs in PyTorch at runtime, the projector submodel receives the already-flattened 2D input, which is why batch_size does not appear in the shape.


@property
def inputs(self) -> Dict[str, Dict[int, str]]:
    return {"image_features": {0: "num_patches"}}
Collaborator

why not :

Suggested change
return {"image_features": {0: "num_patches"}}
return {"image_features": {0: "batch_size", 1: "sequence_length"}}

Author

This is related to the batch_size discussion above: since the concatenation loop flattens all images' patches into a single 2D tensor [total_merged_patches, dim], the projector input has only one dynamic dimension. So {0: "num_patches"} reflects the actual runtime shape.
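The shape argument can be made concrete with plain arithmetic. The sketch below uses illustrative image and patch sizes (not values taken from the PR) to show why per-image merged-patch counts differ and therefore get concatenated along a single flat axis:

```python
# Sketch: per-image merged-patch counts for images of different sizes,
# mirroring the patches_per_side // spatial_merge_size math in the
# dummy input generator above. Sizes are illustrative.
def merged_patches(height, width, patch_size=14, spatial_merge_size=2):
    # patches per side, then merged (spatially pooled) patches per side
    ph, pw = height // patch_size, width // patch_size
    return (ph // spatial_merge_size) * (pw // spatial_merge_size)


# Two images with different spatial dimensions...
counts = [merged_patches(224, 224), merged_patches(448, 336)]
# ...yield different patch counts, so they cannot share a batch axis;
# the projector instead sees one flattened [total_merged_patches, dim] tensor.
total = sum(counts)
print(counts, total)  # prints: [64, 192] 256
```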

@rkazants
Collaborator

rkazants commented Mar 5, 2026

@kyoui-dev, please double-check that all newly added tests are passing locally on your machine:
(screenshot attached)

@kyoui-dev
Author

Hi @echarlaix, thanks for the review!

I've addressed all the comments you left. Could you take another look when you get a chance?

@kyoui-dev
Author

Hi @rkazants,

All newly added tests are passing on my local machine. Could you take a look?

(optimum-intel) kyoui-dev@kyoui-MacBookPro optimum-intel % pytest \
  tests/openvino/test_export.py \
  tests/openvino/test_exporters_cli.py \
  tests/openvino/test_quantization.py \
  tests/openvino/test_seq2seq.py \
  -k "mistral3 or test_exporters_cli_25_image_text_to_text or test_ovmodel_load_with_compressed_weights_17 or test_ovmodel_load_with_uncompressed_weights_17" \
  -v
============================================================= test session starts ==============================================================
platform darwin -- Python 3.13.7, pytest-7.4.4, pluggy-1.6.0 -- /Users/kyoui-dev/Desktop/GitHub/optimum-intel/.venv/bin/python3
cachedir: .pytest_cache
rootdir: /Users/kyoui-dev/Desktop/GitHub/optimum-intel
configfile: pyproject.toml
plugins: anyio-4.12.1
collected 633 items / 626 deselected / 7 selected                                                                                              

tests/openvino/test_export.py::ExportModelTest::test_export_27_mistral3 PASSED                                                           [ 14%]
tests/openvino/test_exporters_cli.py::OVCLIExportTestCase::test_exporters_cli_25_image_text_to_text PASSED                               [ 28%]
tests/openvino/test_quantization.py::OVWeightCompressionTest::test_ovmodel_load_with_compressed_weights_17 PASSED                        [ 42%]
tests/openvino/test_quantization.py::OVWeightCompressionTest::test_ovmodel_load_with_uncompressed_weights_17 PASSED                      [ 57%]
tests/openvino/test_seq2seq.py::OVModelForVisualCausalLMIntegrationTest::test_compare_to_transformers_03_mistral3 PASSED                 [ 71%]
tests/openvino/test_seq2seq.py::OVModelForVisualCausalLMIntegrationTest::test_generate_utils_03_mistral3 PASSED                          [ 85%]
tests/openvino/test_seq2seq.py::OVModelForVisualCausalLMIntegrationTest::test_model_can_be_loaded_after_saving_03_mistral3 PASSED        [100%]

=============================================================== warnings summary ===============================================================
<frozen importlib._bootstrap>:488
  <frozen importlib._bootstrap>:488: DeprecationWarning: builtin type SwigPyPacked has no __module__ attribute

<frozen importlib._bootstrap>:488
  <frozen importlib._bootstrap>:488: DeprecationWarning: builtin type SwigPyObject has no __module__ attribute

.venv/lib/python3.13/site-packages/torch/jit/_script.py:1480
  /Users/kyoui-dev/Desktop/GitHub/optimum-intel/.venv/lib/python3.13/site-packages/torch/jit/_script.py:1480: DeprecationWarning: `torch.jit.script` is deprecated. Please switch to `torch.compile` or `torch.export`.
    warnings.warn(

tests/openvino/test_export.py::ExportModelTest::test_export_27_mistral3
tests/openvino/test_quantization.py::OVWeightCompressionTest::test_ovmodel_load_with_compressed_weights_17
tests/openvino/test_quantization.py::OVWeightCompressionTest::test_ovmodel_load_with_uncompressed_weights_17
tests/openvino/test_seq2seq.py::OVModelForVisualCausalLMIntegrationTest::test_compare_to_transformers_03_mistral3
tests/openvino/test_seq2seq.py::OVModelForVisualCausalLMIntegrationTest::test_generate_utils_03_mistral3
tests/openvino/test_seq2seq.py::OVModelForVisualCausalLMIntegrationTest::test_model_can_be_loaded_after_saving_03_mistral3
  /Users/kyoui-dev/Desktop/GitHub/optimum-intel/.venv/lib/python3.13/site-packages/optimum/exporters/base.py:151: FutureWarning: functools.partial will be a method descriptor in future Python versions; wrap it in staticmethod() if you want to preserve the old behavior
    self._normalized_config = self.NORMALIZED_CONFIG_CLASS(self._config)

tests/openvino/test_export.py::ExportModelTest::test_export_27_mistral3
tests/openvino/test_quantization.py::OVWeightCompressionTest::test_ovmodel_load_with_compressed_weights_17
tests/openvino/test_quantization.py::OVWeightCompressionTest::test_ovmodel_load_with_uncompressed_weights_17
tests/openvino/test_seq2seq.py::OVModelForVisualCausalLMIntegrationTest::test_compare_to_transformers_03_mistral3
tests/openvino/test_seq2seq.py::OVModelForVisualCausalLMIntegrationTest::test_generate_utils_03_mistral3
tests/openvino/test_seq2seq.py::OVModelForVisualCausalLMIntegrationTest::test_model_can_be_loaded_after_saving_03_mistral3
  /Users/kyoui-dev/Desktop/GitHub/optimum-intel/optimum/exporters/openvino/model_configs.py:1772: FutureWarning: functools.partial will be a method descriptor in future Python versions; wrap it in staticmethod() if you want to preserve the old behavior
    InputEmbedOpenvVINOConfig.NORMALIZED_CONFIG_CLASS = internal_export_config.NORMALIZED_CONFIG_CLASS

tests/openvino/test_export.py: 4 warnings
tests/openvino/test_quantization.py: 8 warnings
tests/openvino/test_seq2seq.py: 12 warnings
  /Users/kyoui-dev/Desktop/GitHub/optimum-intel/.venv/lib/python3.13/site-packages/torch/jit/_trace.py:1000: DeprecationWarning: `torch.jit.trace` is deprecated. Please switch to `torch.compile` or `torch.export`.
    warnings.warn(

tests/openvino/test_export.py: 8 warnings
tests/openvino/test_quantization.py: 16 warnings
tests/openvino/test_seq2seq.py: 24 warnings
  /Users/kyoui-dev/Desktop/GitHub/optimum-intel/.venv/lib/python3.13/site-packages/torch/jit/_trace.py:1139: DeprecationWarning: `torch.jit.trace_method` is deprecated. Please switch to `torch.compile` or `torch.export`.
    warnings.warn(

tests/openvino/test_export.py::ExportModelTest::test_export_27_mistral3
tests/openvino/test_quantization.py::OVWeightCompressionTest::test_ovmodel_load_with_compressed_weights_17
tests/openvino/test_quantization.py::OVWeightCompressionTest::test_ovmodel_load_with_uncompressed_weights_17
tests/openvino/test_seq2seq.py::OVModelForVisualCausalLMIntegrationTest::test_compare_to_transformers_03_mistral3
tests/openvino/test_seq2seq.py::OVModelForVisualCausalLMIntegrationTest::test_generate_utils_03_mistral3
tests/openvino/test_seq2seq.py::OVModelForVisualCausalLMIntegrationTest::test_model_can_be_loaded_after_saving_03_mistral3
  /Users/kyoui-dev/Desktop/GitHub/optimum-intel/.venv/lib/python3.13/site-packages/transformers/cache_utils.py:132: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
    if not self.is_initialized or self.keys.numel() == 0:

tests/openvino/test_export.py::ExportModelTest::test_export_27_mistral3
tests/openvino/test_quantization.py::OVWeightCompressionTest::test_ovmodel_load_with_compressed_weights_17
tests/openvino/test_quantization.py::OVWeightCompressionTest::test_ovmodel_load_with_uncompressed_weights_17
tests/openvino/test_seq2seq.py::OVModelForVisualCausalLMIntegrationTest::test_compare_to_transformers_03_mistral3
tests/openvino/test_seq2seq.py::OVModelForVisualCausalLMIntegrationTest::test_generate_utils_03_mistral3
tests/openvino/test_seq2seq.py::OVModelForVisualCausalLMIntegrationTest::test_model_can_be_loaded_after_saving_03_mistral3
  /Users/kyoui-dev/Desktop/GitHub/optimum-intel/.venv/lib/python3.13/site-packages/transformers/masking_utils.py:207: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
    if (padding_length := kv_length + kv_offset - attention_mask.shape[-1]) > 0:

tests/openvino/test_export.py::ExportModelTest::test_export_27_mistral3
tests/openvino/test_quantization.py::OVWeightCompressionTest::test_ovmodel_load_with_compressed_weights_17
tests/openvino/test_quantization.py::OVWeightCompressionTest::test_ovmodel_load_with_uncompressed_weights_17
tests/openvino/test_seq2seq.py::OVModelForVisualCausalLMIntegrationTest::test_compare_to_transformers_03_mistral3
tests/openvino/test_seq2seq.py::OVModelForVisualCausalLMIntegrationTest::test_generate_utils_03_mistral3
tests/openvino/test_seq2seq.py::OVModelForVisualCausalLMIntegrationTest::test_model_can_be_loaded_after_saving_03_mistral3
  /Users/kyoui-dev/Desktop/GitHub/optimum-intel/optimum/exporters/openvino/model_patcher.py:233: TracerWarning: torch.tensor results are registered as constants in the trace. You can safely ignore this warning if you use this function to create tensors out of constant variables that would be the same every time you call this function. In any other case, this might cause the trace to be incorrect.
    torch.tensor(0.0, device=mask.device, dtype=dtype),

tests/openvino/test_export.py::ExportModelTest::test_export_27_mistral3
tests/openvino/test_quantization.py::OVWeightCompressionTest::test_ovmodel_load_with_compressed_weights_17
tests/openvino/test_quantization.py::OVWeightCompressionTest::test_ovmodel_load_with_uncompressed_weights_17
tests/openvino/test_seq2seq.py::OVModelForVisualCausalLMIntegrationTest::test_compare_to_transformers_03_mistral3
tests/openvino/test_seq2seq.py::OVModelForVisualCausalLMIntegrationTest::test_generate_utils_03_mistral3
tests/openvino/test_seq2seq.py::OVModelForVisualCausalLMIntegrationTest::test_model_can_be_loaded_after_saving_03_mistral3
  /Users/kyoui-dev/Desktop/GitHub/optimum-intel/optimum/exporters/openvino/model_patcher.py:234: TracerWarning: torch.tensor results are registered as constants in the trace. You can safely ignore this warning if you use this function to create tensors out of constant variables that would be the same every time you call this function. In any other case, this might cause the trace to be incorrect.
    torch.tensor(torch.finfo(torch.float16).min, device=mask.device, dtype=dtype),

tests/openvino/test_export.py::ExportModelTest::test_export_27_mistral3
tests/openvino/test_quantization.py::OVWeightCompressionTest::test_ovmodel_load_with_compressed_weights_17
tests/openvino/test_quantization.py::OVWeightCompressionTest::test_ovmodel_load_with_uncompressed_weights_17
tests/openvino/test_seq2seq.py::OVModelForVisualCausalLMIntegrationTest::test_compare_to_transformers_03_mistral3
tests/openvino/test_seq2seq.py::OVModelForVisualCausalLMIntegrationTest::test_generate_utils_03_mistral3
tests/openvino/test_seq2seq.py::OVModelForVisualCausalLMIntegrationTest::test_model_can_be_loaded_after_saving_03_mistral3
  /Users/kyoui-dev/Desktop/GitHub/optimum-intel/.venv/lib/python3.13/site-packages/transformers/integrations/sdpa_attention.py:81: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
    is_causal = query.shape[2] > 1 and attention_mask is None and getattr(module, "is_causal", True)

tests/openvino/test_export.py::ExportModelTest::test_export_27_mistral3
tests/openvino/test_quantization.py::OVWeightCompressionTest::test_ovmodel_load_with_compressed_weights_17
tests/openvino/test_quantization.py::OVWeightCompressionTest::test_ovmodel_load_with_uncompressed_weights_17
tests/openvino/test_seq2seq.py::OVModelForVisualCausalLMIntegrationTest::test_compare_to_transformers_03_mistral3
tests/openvino/test_seq2seq.py::OVModelForVisualCausalLMIntegrationTest::test_generate_utils_03_mistral3
tests/openvino/test_seq2seq.py::OVModelForVisualCausalLMIntegrationTest::test_model_can_be_loaded_after_saving_03_mistral3
  /Users/kyoui-dev/Desktop/GitHub/optimum-intel/.venv/lib/python3.13/site-packages/transformers/models/pixtral/modeling_pixtral.py:482: TracerWarning: Iterating over a tensor might cause the trace to be incorrect. Passing a tensor of different shape won't change the number of iterations executed (and might lead to errors or silently give incorrect results).
    for embed, size in zip(patch_embeds, image_sizes)

tests/openvino/test_export.py::ExportModelTest::test_export_27_mistral3
tests/openvino/test_quantization.py::OVWeightCompressionTest::test_ovmodel_load_with_compressed_weights_17
tests/openvino/test_quantization.py::OVWeightCompressionTest::test_ovmodel_load_with_uncompressed_weights_17
tests/openvino/test_seq2seq.py::OVModelForVisualCausalLMIntegrationTest::test_compare_to_transformers_03_mistral3
tests/openvino/test_seq2seq.py::OVModelForVisualCausalLMIntegrationTest::test_generate_utils_03_mistral3
tests/openvino/test_seq2seq.py::OVModelForVisualCausalLMIntegrationTest::test_model_can_be_loaded_after_saving_03_mistral3
  /Users/kyoui-dev/Desktop/GitHub/optimum-intel/.venv/lib/python3.13/site-packages/transformers/models/pixtral/modeling_pixtral.py:429: TracerWarning: torch.tensor results are registered as constants in the trace. You can safely ignore this warning if you use this function to create tensors out of constant variables that would be the same every time you call this function. In any other case, this might cause the trace to be incorrect.
    block_end_idx = torch.tensor(patch_embeds_list).cumsum(-1)

tests/openvino/test_export.py::ExportModelTest::test_export_27_mistral3
tests/openvino/test_quantization.py::OVWeightCompressionTest::test_ovmodel_load_with_compressed_weights_17
tests/openvino/test_quantization.py::OVWeightCompressionTest::test_ovmodel_load_with_uncompressed_weights_17
tests/openvino/test_seq2seq.py::OVModelForVisualCausalLMIntegrationTest::test_compare_to_transformers_03_mistral3
tests/openvino/test_seq2seq.py::OVModelForVisualCausalLMIntegrationTest::test_generate_utils_03_mistral3
tests/openvino/test_seq2seq.py::OVModelForVisualCausalLMIntegrationTest::test_model_can_be_loaded_after_saving_03_mistral3
  /Users/kyoui-dev/Desktop/GitHub/optimum-intel/.venv/lib/python3.13/site-packages/transformers/models/pixtral/modeling_pixtral.py:430: TracerWarning: torch.tensor results are registered as constants in the trace. You can safely ignore this warning if you use this function to create tensors out of constant variables that would be the same every time you call this function. In any other case, this might cause the trace to be incorrect.
    block_start_idx = torch.tensor([0] + patch_embeds_list[:-1]).cumsum(-1)

tests/openvino/test_export.py::ExportModelTest::test_export_27_mistral3
tests/openvino/test_quantization.py::OVWeightCompressionTest::test_ovmodel_load_with_compressed_weights_17
tests/openvino/test_quantization.py::OVWeightCompressionTest::test_ovmodel_load_with_uncompressed_weights_17
tests/openvino/test_seq2seq.py::OVModelForVisualCausalLMIntegrationTest::test_compare_to_transformers_03_mistral3
tests/openvino/test_seq2seq.py::OVModelForVisualCausalLMIntegrationTest::test_generate_utils_03_mistral3
tests/openvino/test_seq2seq.py::OVModelForVisualCausalLMIntegrationTest::test_model_can_be_loaded_after_saving_03_mistral3
  /Users/kyoui-dev/Desktop/GitHub/optimum-intel/.venv/lib/python3.13/site-packages/transformers/models/pixtral/modeling_pixtral.py:431: TracerWarning: Iterating over a tensor might cause the trace to be incorrect. Passing a tensor of different shape won't change the number of iterations executed (and might lead to errors or silently give incorrect results).
    for start, end in zip(block_start_idx, block_end_idx):

-- Docs: https://docs.pytest.org/en/stable/how-to/capture-warnings.html
========================================= 7 passed, 626 deselected, 141 warnings in 154.94s (0:02:34) ==========================================

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@popovaan
Collaborator

This error seems to be related to this PR; please take a look:
https://github.com/huggingface/optimum-intel/actions/runs/22792915683/job/66592942836?pr=1627

Also, please fix the code style with the following commands:

ruff check --config pyproject.toml --fix .
ruff format --config pyproject.toml .
ruff check --config pyproject.toml .
ruff format --check --config pyproject.toml .
 
black .
black --check .

@kyoui-dev
Author

Hi @popovaan,

Thank you for letting me know. I've fixed the test and the code style with ruff and black. Could you check?

@rkazants
Collaborator

@kyoui-dev, please check tests locally before running our CI.

@kyoui-dev
Copy link
Author

@rkazants, I just ran the related tests locally and they passed on my end. Please let me know if anything else is needed.

@kyoui-dev kyoui-dev requested review from echarlaix and rkazants March 17, 2026 03:39
Comment on lines +343 to +350
(
    {"encoder": 30, "decoder": 52, "decoder_with_past": 61}
    if is_transformers_version("<=", "4.45")
    else {
        "encoder": 30,
        "decoder": 52,
    }
),
Collaborator

There are some unrelated changes in the PR, probably due to code style applying? Please remove the changes from unrelated files.

@popovaan
Copy link
Collaborator

popovaan commented Mar 17, 2026

Could you please locally run OpenVINO GenAI WhoWhatBenchmark tool to check the accuracy of the full model (not the tiny one) and share the results?
https://github.com/openvinotoolkit/openvino.genai/tree/master/tools/who_what_benchmark

Here are the instructions: https://github.com/openvinotoolkit/openvino.genai/blob/master/tools/who_what_benchmark/README.md#compare-visual-language-models-with-image-inputs-vlms



Development

Successfully merging this pull request may close these issues.

Please add support for mistral3 models for openvino export

6 participants