diff --git a/optimum/exporters/openvino/model_configs.py b/optimum/exporters/openvino/model_configs.py
index ca12d455be..d0a6dd92f8 100644
--- a/optimum/exporters/openvino/model_configs.py
+++ b/optimum/exporters/openvino/model_configs.py
@@ -163,6 +163,7 @@
     InternVLChatImageEmbeddingModelPatcher,
     JaisModelPatcher,
     Lfm2ModelPatcher,
+    Lfm2MoeModelPatcher,
     Llama4ImageEmbeddingsModelPatcher,
     Llama4TextModelPatcher,
     LlavaImageEmbeddingModelPatcher,
@@ -4581,6 +4582,19 @@ def inputs(self) -> Dict[str, Dict[int, str]]:
         return common_inputs
 
 
+@register_in_tasks_manager(
+    "lfm2_moe",
+    *[
+        "text-generation",
+        "text-generation-with-past",
+    ],
+    library_name="transformers",
+)
+class LFM2MoeOpenVINOConfig(LFM2OpenVINOConfig):
+    MAX_TRANSFORMERS_VERSION = "4.57.99"
+    _MODEL_PATCHER = Lfm2MoeModelPatcher
+
+
 @register_in_tasks_manager(
     "granitemoehybrid", *["text-generation", "text-generation-with-past"], library_name="transformers"
 )
diff --git a/optimum/exporters/openvino/model_patcher.py b/optimum/exporters/openvino/model_patcher.py
index 56e550858c..138b970ada 100644
--- a/optimum/exporters/openvino/model_patcher.py
+++ b/optimum/exporters/openvino/model_patcher.py
@@ -7162,17 +7162,113 @@ def __enter__(self):
                 conv_layer.slow_forward = types.MethodType(lfm2_short_conv_forward_patched, conv_layer)
 
     def __exit__(self, exc_type, exc_value, traceback):
-        from transformers.models.lfm2.modeling_lfm2 import Lfm2ShortConv
+        super().__exit__(exc_type, exc_value, traceback)
+        setattr(self._model, self.orig_forward_name, self.model_orig_forward)
+
+
+def lfm2_moe_sparse_block_forward_patched(self, hidden_states: torch.Tensor):
+    batch_size, sequence_length, hidden_dim = hidden_states.shape
+    num_tokens = batch_size * sequence_length
+    num_experts = self.num_experts
+
+    hidden_states = hidden_states.view(num_tokens, hidden_dim)
+
+    router_logits = self.gate(hidden_states)
+
+    if self.router_temperature != 1.0:
+        router_logits = router_logits / self.router_temperature
+
+    if self.score_function == "sigmoid":
+        routing_scores = router_logits.sigmoid()
+        if self.expert_bias is not None:
+            scores_for_routing = routing_scores + self.expert_bias
+            _, selected_experts = torch.topk(scores_for_routing, k=self.top_k, dim=-1)
+            routing_weights = torch.gather(routing_scores, dim=1, index=selected_experts)
+        else:
+            routing_weights, selected_experts = torch.topk(routing_scores, k=self.top_k, dim=-1)
+
+        if self.norm_topk_prob:
+            routing_weights = routing_weights / (routing_weights.sum(dim=-1, keepdim=True) + 1e-6)
+
+    elif self.score_function == "softmax":
+        scores_for_routing, selected_experts = torch.topk(router_logits, k=self.top_k, dim=-1)
+        routing_weights = torch.softmax(scores_for_routing, dim=-1, dtype=torch.float32)
+
+    else:
+        raise ValueError(f"Unsupported router score function: {self.score_function}")
+
+    if self.routed_scaling_factor:
+        routing_weights = routing_weights * self.routed_scaling_factor
+
+    routing_weights = routing_weights.to(hidden_states.dtype)
+
+    dense_routing_weights = torch.zeros(
+        num_tokens,
+        num_experts,
+        device=hidden_states.device,
+        dtype=hidden_states.dtype,
+    )
+    dense_routing_weights.scatter_(dim=1, index=selected_experts, src=routing_weights)
+    hidden_states_expanded = hidden_states.repeat(num_experts, 1)  # (num_experts * num_tokens, hidden_dim)
+    hidden_states_expanded = hidden_states_expanded.view(
+        num_experts, -1, hidden_dim
+    )  # (num_experts, num_tokens, hidden_dim)
+
+    # self.w2(F.silu(self.w1(x)) * self.w3(x))
+    silu_out = F.silu(torch.bmm(hidden_states_expanded, self.w1_stacked.transpose(1, 2)))
+    x_w3 = torch.bmm(hidden_states_expanded, self.w3_stacked.transpose(1, 2))
+    silu_x_w3 = silu_out * x_w3
+    next_states = torch.bmm(silu_x_w3, self.w2_stacked.transpose(1, 2))
+
+    next_states = next_states.view(num_experts, batch_size, -1, hidden_dim)
+    next_states = next_states * dense_routing_weights.transpose(0, 1).view(num_experts, batch_size, -1)[..., None]
+    next_states = next_states.sum(dim=0)
+
+    return next_states, router_logits
+
+
+class Lfm2MoeModelPatcher(Lfm2ModelPatcher):
+    def __enter__(self):
+        super().__enter__()
+        setattr(self._model, self.orig_forward_name, self.patched_forward)
+
+        for idx, layer in enumerate(self._model.model.layers):
+            if hasattr(layer, "conv"):
+                conv_layer = layer.conv
+                conv_layer._orig_forward = conv_layer.slow_forward
+                conv_layer.slow_forward = types.MethodType(lfm2_short_conv_forward_patched, conv_layer)
+            if hasattr(layer, "feed_forward") and hasattr(layer.feed_forward, "num_experts"):
+                sparse_moe_block = layer.feed_forward
+                num_experts = sparse_moe_block.num_experts
+                sparse_moe_block.w1_stacked = torch.concat(
+                    tuple(sparse_moe_block.experts[i].w1.weight.unsqueeze(0) for i in range(num_experts)),
+                    dim=0,
+                )
+                sparse_moe_block.w2_stacked = torch.concat(
+                    tuple(sparse_moe_block.experts[i].w2.weight.unsqueeze(0) for i in range(num_experts)),
+                    dim=0,
+                )
+                sparse_moe_block.w3_stacked = torch.concat(
+                    tuple(sparse_moe_block.experts[i].w3.weight.unsqueeze(0) for i in range(num_experts)), dim=0
+                )
+                sparse_moe_block._orig_forward = sparse_moe_block.forward
+                sparse_moe_block.forward = types.MethodType(lfm2_moe_sparse_block_forward_patched, sparse_moe_block)
 
+    def __exit__(self, exc_type, exc_value, traceback):
         super().__exit__(exc_type, exc_value, traceback)
         setattr(self._model, self.orig_forward_name, self.model_orig_forward)
 
         for layer in self._model.model.layers:
-            if hasattr(layer, "conv") and isinstance(layer.conv, Lfm2ShortConv):
+            if hasattr(layer, "conv"):
                 conv_layer = layer.conv
-            else:
-                continue
-            conv_layer.slow_forward = conv_layer._orig_forward
+                conv_layer.slow_forward = conv_layer._orig_forward
+            elif hasattr(layer, "feed_forward") and hasattr(layer.feed_forward, "num_experts"):
+                sparse_moe_block = layer.feed_forward
+                sparse_moe_block.forward = sparse_moe_block._orig_forward
+                delattr(sparse_moe_block, "w1_stacked")
+                delattr(sparse_moe_block, "w2_stacked")
+                delattr(sparse_moe_block, "w3_stacked")
 
 
 class GptOssModelPatcher(OVDecoderModelPatcher):
diff --git a/optimum/exporters/openvino/utils.py b/optimum/exporters/openvino/utils.py
index de92645017..45f9409d73 100644
--- a/optimum/exporters/openvino/utils.py
+++ b/optimum/exporters/openvino/utils.py
@@ -304,7 +304,7 @@ def get_submodels(model):
     "minicpmo",
 ]
 
-SSM_MODELS = ["mamba", "falcon_mamba", "zamba2", "lfm2", "granitemoehybrid"]
+SSM_MODELS = ["mamba", "falcon_mamba", "zamba2", "lfm2", "lfm2_moe", "granitemoehybrid"]
 
 # All transformers, diffusers, timm and sentence transformers models that are supported via optimum-onnx OnnxConfigs but that have currently no test
 # TODO: add tests for all models that are compatible and remove support for all others
diff --git a/optimum/intel/openvino/modeling_decoder.py b/optimum/intel/openvino/modeling_decoder.py
index 4b9d4ce027..ed2b068684 100644
--- a/optimum/intel/openvino/modeling_decoder.py
+++ b/optimum/intel/openvino/modeling_decoder.py
@@ -1440,7 +1440,7 @@ def prepare_inputs_for_generation(
             # decoding stage so it takes the last token
             input_ids = input_ids[:, -1].unsqueeze(-1)
 
-        if self.config.model_type not in ["lfm2", "granitemoehybrid"]:
+        if self.config.model_type not in ["lfm2", "lfm2_moe", "granitemoehybrid"]:
             # LFM2 and GraniteMoeHybrid (Granite-4.0) require the attention mask
             # to be the length of the full context, so default mask from OVModelForCausalLM needs to be used.
             # Other models like Mamba typically do not require an attention_mask
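
The key idea in `lfm2_moe_sparse_block_forward_patched` is that data-dependent expert dispatch (masked loops or index_select over tokens) does not trace into a static OpenVINO graph, so the patcher stacks the per-expert `w1`/`w2`/`w3` weights, runs every expert on every token with `torch.bmm`, and zeroes out the non-selected experts via a dense routing matrix built with `scatter_`. The standalone sketch below (illustrative shapes and names only, not code from this PR) checks that the dense formulation matches a naive per-expert dispatch loop:

```python
import torch
import torch.nn.functional as F

num_experts, top_k, num_tokens, hidden_dim, ffn_dim = 4, 2, 6, 8, 16

# Per-expert weights, stacked the same way Lfm2MoeModelPatcher stacks w1/w2/w3.
w1 = torch.randn(num_experts, ffn_dim, hidden_dim)
w3 = torch.randn(num_experts, ffn_dim, hidden_dim)
w2 = torch.randn(num_experts, hidden_dim, ffn_dim)

x = torch.randn(num_tokens, hidden_dim)
routing_weights, selected = torch.topk(torch.rand(num_tokens, num_experts), k=top_k, dim=-1)

# Reference: dispatch each token to its selected experts in a Python loop.
ref = torch.zeros_like(x)
for t in range(num_tokens):
    for weight, e in zip(routing_weights[t], selected[t]):
        h = F.silu(x[t] @ w1[e].T) * (x[t] @ w3[e].T)
        ref[t] += weight * (h @ w2[e].T)

# Dense formulation: run all experts on all tokens, mask with a dense routing matrix.
dense = torch.zeros(num_tokens, num_experts)
dense.scatter_(dim=1, index=selected, src=routing_weights)
xe = x.repeat(num_experts, 1).view(num_experts, num_tokens, hidden_dim)
h = F.silu(torch.bmm(xe, w1.transpose(1, 2))) * torch.bmm(xe, w3.transpose(1, 2))
out = torch.bmm(h, w2.transpose(1, 2))       # (num_experts, num_tokens, hidden_dim)
out = (out * dense.T[..., None]).sum(dim=0)  # weight each expert's output, reduce over experts

assert torch.allclose(ref, out, atol=1e-4)
```

The trade-off is that every expert now processes every token, so the exported graph does more FLOPs than sparse dispatch, but it stays fully static and traceable.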
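
With `lfm2_moe` registered in the tasks manager, export should go through the usual optimum-intel entry points. A minimal usage sketch, assuming a hypothetical checkpoint id (replace with an actual LFM2-MoE model):

```python
from optimum.intel import OVModelForCausalLM
from transformers import AutoTokenizer

model_id = "org/lfm2-moe-checkpoint"  # placeholder, not a real model id

# export=True converts the PyTorch checkpoint to OpenVINO IR on the fly,
# applying Lfm2MoeModelPatcher while the model is traced.
model = OVModelForCausalLM.from_pretrained(model_id, export=True)
tokenizer = AutoTokenizer.from_pretrained(model_id)

inputs = tokenizer("Hello", return_tensors="pt")
outputs = model.generate(**inputs, max_new_tokens=16)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
```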