14 changes: 14 additions & 0 deletions optimum/exporters/openvino/model_configs.py
@@ -163,6 +163,7 @@
InternVLChatImageEmbeddingModelPatcher,
JaisModelPatcher,
Lfm2ModelPatcher,
Lfm2MoeModelPatcher,
Llama4ImageEmbeddingsModelPatcher,
Llama4TextModelPatcher,
LlavaImageEmbeddingModelPatcher,
@@ -4581,6 +4582,19 @@ def inputs(self) -> Dict[str, Dict[int, str]]:
return common_inputs


@register_in_tasks_manager(
"lfm2_moe",
*[
"text-generation",
"text-generation-with-past",
],
library_name="transformers",
)
class LFM2MoeOpenVINOConfig(LFM2OpenVINOConfig):
MAX_TRANSFORMERS_VERSION = "4.57.99"
_MODEL_PATCHER = Lfm2MoeModelPatcher


@register_in_tasks_manager(
"granitemoehybrid", *["text-generation", "text-generation-with-past"], library_name="transformers"
)
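For context, a rough sketch of how a config registered this way is typically resolved and used at export time; the checkpoint id is a placeholder and the actual conversion call is elided:

from optimum.exporters.tasks import TasksManager
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("org/lfm2-moe-checkpoint")  # hypothetical model id

# The @register_in_tasks_manager decorator above makes this lookup resolve to LFM2MoeOpenVINOConfig.
export_config_ctor = TasksManager.get_exporter_config_constructor(
    exporter="openvino", model=model, task="text-generation-with-past", library_name="transformers"
)
export_config = export_config_ctor(model.config)

# _MODEL_PATCHER (Lfm2MoeModelPatcher, defined in model_patcher.py below) is applied while tracing.
with export_config.patch_model_for_export(model):
    pass  # the actual OpenVINO conversion runs here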
106 changes: 101 additions & 5 deletions optimum/exporters/openvino/model_patcher.py
@@ -7162,17 +7162,113 @@ def __enter__(self):
conv_layer.slow_forward = types.MethodType(lfm2_short_conv_forward_patched, conv_layer)

def __exit__(self, exc_type, exc_value, traceback):
super().__exit__(exc_type, exc_value, traceback)
setattr(self._model, self.orig_forward_name, self.model_orig_forward)


def lfm2_moe_sparse_block_forward_patched(self, hidden_states: torch.Tensor):
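# Export-friendly MoE forward: every expert runs on every token via batched matmuls, and the
# top-k routing is applied as dense per-expert weights, avoiding data-dependent expert dispatch.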
batch_size, sequence_length, hidden_dim = hidden_states.shape
num_tokens = batch_size * sequence_length
num_experts = self.num_experts

hidden_states = hidden_states.view(num_tokens, hidden_dim)

router_logits = self.gate(hidden_states)

if self.router_temperature != 1.0:
router_logits = router_logits / self.router_temperature

if self.score_function == "sigmoid":
routing_scores = router_logits.sigmoid()

if self.expert_bias is not None:
scores_for_routing = routing_scores + self.expert_bias
_, selected_experts = torch.topk(scores_for_routing, k=self.top_k, dim=-1)
routing_weights = torch.gather(routing_scores, dim=1, index=selected_experts)
else:
routing_weights, selected_experts = torch.topk(routing_scores, k=self.top_k, dim=-1)

if self.norm_topk_prob:
routing_weights = routing_weights / (routing_weights.sum(dim=-1, keepdim=True) + 1e-6)

elif self.score_function == "softmax":
scores_for_routing, selected_experts = torch.topk(router_logits, k=self.top_k, dim=-1)
routing_weights = torch.softmax(scores_for_routing, dim=-1, dtype=torch.float32)

else:
raise ValueError(f"Unsupported router score function: {self.score_function}")

if self.routed_scaling_factor:
routing_weights = routing_weights * self.routed_scaling_factor

routing_weights = routing_weights.to(hidden_states.dtype)

dense_routing_weights = torch.zeros(
num_tokens,
num_experts,
device=hidden_states.device,
dtype=hidden_states.dtype,
)
dense_routing_weights.scatter_(dim=1, index=selected_experts, src=routing_weights)
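# Run every expert on every token: replicate the tokens once per expert so each expert MLP becomes one batched matmul.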
hidden_states_expanded = hidden_states.repeat(num_experts, 1) # (num_experts * num_tokens, hidden_dim)
hidden_states_expanded = hidden_states_expanded.view(
num_experts, -1, hidden_dim
) # (num_experts, num_tokens, hidden_dim)

# self.w2(F.silu(self.w1(x)) * self.w3(x))
silu_out = F.silu(torch.bmm(hidden_states_expanded, self.w1_stacked.transpose(1, 2)))
x_w3 = torch.bmm(hidden_states_expanded, self.w3_stacked.transpose(1, 2))
silu_x_w3 = silu_out * x_w3
next_states = torch.bmm(silu_x_w3, self.w2_stacked.transpose(1, 2))

next_states = next_states.view(num_experts, batch_size, -1, hidden_dim)
next_states = next_states * dense_routing_weights.transpose(0, 1).view(num_experts, batch_size, -1)[..., None]
next_states = next_states.sum(dim=0)

return next_states, router_logits


class Lfm2MoeModelPatcher(Lfm2ModelPatcher):
def __enter__(self):
super().__enter__()
setattr(self._model, self.orig_forward_name, self.patched_forward)

for idx, layer in enumerate(self._model.model.layers):
if hasattr(layer, "conv"):
conv_layer = layer.conv
conv_layer._orig_forward = conv_layer.slow_forward
conv_layer.slow_forward = types.MethodType(lfm2_short_conv_forward_patched, conv_layer)
if hasattr(layer, "feed_forward") and hasattr(layer.feed_forward, "num_experts"):
sparse_moe_block = layer.feed_forward
num_experts = sparse_moe_block.num_experts
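# Stack each expert's w1/w2/w3 Linear weights into (num_experts, out_features, in_features) tensors
# so the patched forward above can apply all experts with torch.bmm instead of a per-expert loop.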
sparse_moe_block.w1_stacked = torch.concat(
tuple(sparse_moe_block.experts[i].w1.weight.unsqueeze(0) for i in range(num_experts)),
dim=0,
)
sparse_moe_block.w2_stacked = torch.concat(
tuple(sparse_moe_block.experts[i].w2.weight.unsqueeze(0) for i in range(num_experts)),
dim=0,
)
sparse_moe_block.w3_stacked = torch.concat(
tuple(sparse_moe_block.experts[i].w3.weight.unsqueeze(0) for i in range(num_experts)), dim=0
)
sparse_moe_block._orig_forward = sparse_moe_block.forward
sparse_moe_block.forward = types.MethodType(lfm2_moe_sparse_block_forward_patched, sparse_moe_block)

def __exit__(self, exc_type, exc_value, traceback):
super().__exit__(exc_type, exc_value, traceback)
setattr(self._model, self.orig_forward_name, self.model_orig_forward)

for layer in self._model.model.layers:
if hasattr(layer, "conv") and isinstance(layer.conv, Lfm2ShortConv):
if hasattr(layer, "conv"):
conv_layer = layer.conv
else:
continue
conv_layer.slow_forward = conv_layer._orig_forward
conv_layer.slow_forward = conv_layer._orig_forward
elif hasattr(layer, "feed_forward") and hasattr(layer.feed_forward, "num_experts"):
sparse_moe_block = layer.feed_forward
sparse_moe_block.forward = sparse_moe_block._orig_forward
delattr(sparse_moe_block, "w1_stacked")
delattr(sparse_moe_block, "w2_stacked")
delattr(sparse_moe_block, "w3_stacked")


class GptOssModelPatcher(OVDecoderModelPatcher):
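For reference, a minimal self-contained sketch (not part of the patch; all sizes and weights are made up) showing that the dense-routing plus batched-matmul formulation used in lfm2_moe_sparse_block_forward_patched matches the usual per-token top-k expert loop:

import torch

torch.manual_seed(0)
num_tokens, hidden_dim, num_experts, top_k = 4, 8, 3, 2

hidden_states = torch.randn(num_tokens, hidden_dim)
# One weight matrix per "expert"; the real block uses a gated w1/w3/w2 MLP per expert.
expert_weights = [torch.randn(hidden_dim, hidden_dim) for _ in range(num_experts)]

routing_scores = torch.randn(num_tokens, num_experts).sigmoid()
routing_weights, selected_experts = torch.topk(routing_scores, k=top_k, dim=-1)

# Reference: dispatch each token only to its selected experts.
reference = torch.zeros_like(hidden_states)
for t in range(num_tokens):
    for k in range(top_k):
        e = selected_experts[t, k]
        reference[t] += routing_weights[t, k] * (hidden_states[t] @ expert_weights[e])

# Export-friendly form: dense routing weights + one batched matmul over stacked expert weights.
dense_routing_weights = torch.zeros(num_tokens, num_experts)
dense_routing_weights.scatter_(dim=1, index=selected_experts, src=routing_weights)
stacked = torch.stack(expert_weights)                     # (num_experts, hidden_dim, hidden_dim)
expanded = hidden_states.expand(num_experts, -1, -1)      # (num_experts, num_tokens, hidden_dim)
per_expert_out = torch.bmm(expanded, stacked)             # (num_experts, num_tokens, hidden_dim)
batched = (per_expert_out * dense_routing_weights.transpose(0, 1)[..., None]).sum(dim=0)

assert torch.allclose(reference, batched, atol=1e-5)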
2 changes: 1 addition & 1 deletion optimum/exporters/openvino/utils.py
@@ -304,7 +304,7 @@ def get_submodels(model):
"minicpmo",
]

SSM_MODELS = ["mamba", "falcon_mamba", "zamba2", "lfm2", "granitemoehybrid"]
SSM_MODELS = ["mamba", "falcon_mamba", "zamba2", "lfm2", "lfm2_moe", "granitemoehybrid"]

# All transformers, diffusers, timm and sentence transformers models that are supported via optimum-onnx OnnxConfigs but that have currently no test
# TODO: add tests for all models that are compatible and remove support for all others
2 changes: 1 addition & 1 deletion optimum/intel/openvino/modeling_decoder.py
@@ -1440,7 +1440,7 @@ def prepare_inputs_for_generation(
# decoding stage so it takes the last token
input_ids = input_ids[:, -1].unsqueeze(-1)

if self.config.model_type not in ["lfm2", "granitemoehybrid"]:
if self.config.model_type not in ["lfm2", "lfm2_moe", "granitemoehybrid"]:
# LFM2 and GraniteMoeHybrid (Granite-4.0) require the attention mask
# to be the length of the full context, so default mask from OVModelForCausalLM needs to be used.
# Other models like Mamba typically do not require an attention_mask
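A small illustration of the shape difference this model_type check preserves; the tensors and lengths are made up and not part of the diff:

import torch

# At decode time only the last token id is fed (see input_ids[:, -1] above), but for
# lfm2 / lfm2_moe / granitemoehybrid the attention_mask must still span the whole context.
past_length = 7
input_ids = torch.tensor([[42]])                 # (batch, 1)
attention_mask = torch.ones(1, past_length + 1)  # (batch, 8) - full-context mask kept for these models
print(input_ids.shape, attention_mask.shape)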