-
Notifications
You must be signed in to change notification settings - Fork 207
[OpenVINO] Add support for Mistral3 #1627
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
f3e3433
4808de0
3100050
976843b
52c4b61
422709a
a6bb58e
56eb853
2f288b0
a46ff83
7dcc1e4
1d4c1e9
316ffa4
aa397b2
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||
|---|---|---|---|---|---|---|
|
|
@@ -175,6 +175,8 @@ | |||||
| MiniCPMModelPatcher, | ||||||
| MiniCPMVImageEmbeddingsModelPatcher, | ||||||
| MiniCPMVResamplerModelPatcher, | ||||||
| Mistral3ImageEmbeddingModelPatcher, | ||||||
| Mistral3MultiModalProjectorPatcher, | ||||||
| MistralModelPatcher, | ||||||
| MixtralModelPatcher, | ||||||
| MPTModelPatcher, | ||||||
|
|
@@ -2054,6 +2056,130 @@ def patch_model_for_export(self, model: PreTrainedModel, model_kwargs: Optional[ | |||||
| return LlavaNextVideoImageEmbeddingModelPatcher(self, model, model_kwargs) | ||||||
|
|
||||||
|
|
||||||
class Mistral3ConfigBehavior(str, enum.Enum):
    """Sub-model selector for the Mistral3 VLM export.

    Each member names one piece of the model that is exported as a
    separate OpenVINO sub-model.
    """

    LANGUAGE = "language"
    # VISION_EMBEDDINGS extracts the visual features and applies projector.norm().
    # Together with the patch-merging loop
    # (https://github.com/huggingface/transformers/blob/v5.2.0/src/transformers/models/mistral3/modeling_mistral3.py#L76-L94)
    # and MULTI_MODAL_PROJECTOR it reproduces get_image_features
    # (https://github.com/huggingface/transformers/blob/v5.2.0/src/transformers/models/mistral3/modeling_mistral3.py#L223-L248).
    VISION_EMBEDDINGS = "vision_embeddings"
    TEXT_EMBEDDINGS = "text_embeddings"
    MULTI_MODAL_PROJECTOR = "multi_modal_projector"
|
|
||||||
|
|
||||||
class DummyMistral3MultiModalProjectorInputGenerator(DummyLLavaMultiModalProjectorInputGenerator):
    """Dummy-input generator for the Mistral3 multi-modal projector.

    Unlike the LLava variant, Mistral3 merges ``spatial_merge_size ** 2``
    neighboring patches into one token before the projector, so the dummy
    tensor has fewer rows and a proportionally wider feature dimension.
    """

    def __init__(
        self,
        task: str,
        normalized_config: NormalizedVisionConfig,
        batch_size: int = DEFAULT_DUMMY_SHAPES["batch_size"],
        random_batch_size_range: Optional[Tuple[int, int]] = None,
        **kwargs,
    ):
        super().__init__(task, normalized_config, batch_size, random_batch_size_range, **kwargs)
        # spatial_merge_size may live on the wrapped config or on the
        # normalized config itself; fall back to 2 when absent.
        fallback = getattr(normalized_config, "spatial_merge_size", 2)
        self.spatial_merge_size = getattr(normalized_config.config, "spatial_merge_size", fallback)
        self.num_merged_patches = self.num_patches // (self.spatial_merge_size**2)

    def generate(
        self,
        input_name: str,
        framework: str = "pt",
        int_dtype: str = "int64",
        float_dtype: str = "fp32",
    ):
        # Merged patches concatenate the features of spatial_merge_size**2
        # original patches, hence the widened feature dimension.
        merge_factor = self.spatial_merge_size**2
        dummy_shape = [self.num_merged_patches, self.hidden_size * merge_factor]
        return self.random_float_tensor(dummy_shape, framework=framework, dtype=float_dtype)
|
|
||||||
|
|
||||||
class Mistral3MultiModalProjectorOpenVINOConfig(OnnxConfig):
    """Export config for the standalone Mistral3 multi-modal projector.

    The projector consumes merged patch embeddings and produces hidden
    states in the language-model embedding space; both sides are dynamic
    along the patch axis.
    """

    DUMMY_INPUT_GENERATOR_CLASSES = (DummyMistral3MultiModalProjectorInputGenerator,)
    NORMALIZED_CONFIG_CLASS = NormalizedVisionConfig
    _MODEL_PATCHER = Mistral3MultiModalProjectorPatcher

    @property
    def inputs(self) -> Dict[str, Dict[int, str]]:
        # Single input: merged image features, dynamic along axis 0.
        dynamic_axes = {0: "num_patches"}
        return {"image_features": dynamic_axes}

    @property
    def outputs(self) -> Dict[str, Dict[int, str]]:
        # Projected hidden states keep the same dynamic patch axis.
        dynamic_axes = {0: "num_patches"}
        return {"hidden_states": dynamic_axes}
|
|
||||||
|
|
||||||
@register_in_tasks_manager("mistral3", *["image-text-to-text"], library_name="transformers")
class Mistral3OpenVINOConfig(BaseVLMOpenVINOConfig):
    """OpenVINO export config for Mistral3 vision-language models.

    The model is exported as several sub-models (see ``Mistral3ConfigBehavior``):
    language model, text embeddings, vision embeddings, and a standalone
    multi-modal projector.
    """

    MIN_TRANSFORMERS_VERSION = "4.50.0"
    SUPPORTED_BEHAVIORS = [model_type.value for model_type in Mistral3ConfigBehavior]

    def __init__(
        self,
        config: "PretrainedConfig",
        task: str = "feature-extraction",
        int_dtype: str = "int64",
        float_dtype: str = "fp32",
        behavior: VLMConfigBehavior = VLMConfigBehavior.VISION_EMBEDDINGS,
        preprocessors: Optional[List[Any]] = None,
        **kwargs,
    ):
        super().__init__(
            config=config,
            task=task,
            int_dtype=int_dtype,
            float_dtype=float_dtype,
            preprocessors=preprocessors,
        )
        # BUGFIX: `behavior` was accepted but never stored, so the check
        # below (and all `self._behavior` dispatch in this class) silently
        # ignored the requested sub-model. Store it, matching the pattern
        # used by the other VLM configs in this file.
        self._behavior = behavior
        self._orig_config = config
        if self._behavior == VLMConfigBehavior.VISION_EMBEDDINGS and hasattr(config, "vision_config"):
            # Vision export works on the vision sub-config only.
            self._config = config.vision_config
            self._normalized_config = self.NORMALIZED_CONFIG_CLASS(self._config)

    def with_behavior(
        self,
        behavior: Union[str, Mistral3ConfigBehavior],
    ):
        """Return an export config for the requested sub-model.

        MULTI_MODAL_PROJECTOR has no counterpart in the base class, so it
        is handled here; all other behaviors are delegated to super().
        """
        if isinstance(behavior, str) and not isinstance(behavior, Mistral3ConfigBehavior):
            behavior = Mistral3ConfigBehavior(behavior)

        if behavior == Mistral3ConfigBehavior.MULTI_MODAL_PROJECTOR:
            return Mistral3MultiModalProjectorOpenVINOConfig(
                self._orig_config.vision_config,
                task="feature-extraction",
                int_dtype=self.int_dtype,
                float_dtype=self.float_dtype,
            )

        return super().with_behavior(behavior)

    def get_model_for_behavior(self, model, behavior: Union[str, Mistral3ConfigBehavior]):
        """Pick the sub-module of ``model`` matching ``behavior``."""
        if isinstance(behavior, str) and not isinstance(behavior, Mistral3ConfigBehavior):
            behavior = Mistral3ConfigBehavior(behavior)

        if behavior == Mistral3ConfigBehavior.MULTI_MODAL_PROJECTOR:
            # The projector lives either on the top-level model or one
            # level down on model.model, depending on transformers version.
            return (
                model.multi_modal_projector
                if hasattr(model, "multi_modal_projector")
                else model.model.multi_modal_projector
            )

        return super().get_model_for_behavior(model, behavior)

    def patch_model_for_export(self, model: PreTrainedModel, model_kwargs: Optional[Dict[str, Any]] = None):
        """Wrap ``model`` in the patcher appropriate for the current behavior."""
        model_kwargs = model_kwargs or {}

        if self._behavior != VLMConfigBehavior.VISION_EMBEDDINGS:
            return super().patch_model_for_export(model, model_kwargs)

        return Mistral3ImageEmbeddingModelPatcher(self, model, model_kwargs)

    def generate_dummy_inputs(self, framework: str = "pt", **kwargs) -> Dict:
        # Pixtral's vision tower does not support batched dummy images;
        # force batch_size=1 for the vision-embeddings export.
        if self._behavior == VLMConfigBehavior.VISION_EMBEDDINGS and self._config.model_type == "pixtral":
            kwargs["batch_size"] = 1
        return super().generate_dummy_inputs(framework, **kwargs)
|
|
||||||
|
|
||||||
| @register_in_tasks_manager( | ||||||
| "maira2", *["image-text-to-text", "text-generation", "text-generation-with-past"], library_name="transformers" | ||||||
| ) | ||||||
|
|
||||||
Uh oh!
There was an error while loading. Please reload this page.