Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions docs/source/openvino/models.mdx
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@ Here is the list of the supported architectures :
- Falcon-Mamba
- Flaubert
- GLM-4
- GLM-4 MoE (GLM-4.7)
- GLM-Edge
- GPT-2
- GPT-BigCode
Expand Down
16 changes: 16 additions & 0 deletions optimum/exporters/openvino/model_configs.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,6 +148,7 @@
FluxTransfromerModelPatcher,
Gemma2ModelPatcher,
Gemma3LMModelPatcher,
Glm4MoePatcher,
GptJModelPatcher,
GptNeoModelPatcher,
GptNeoxModelPatcher,
Expand Down Expand Up @@ -3931,6 +3932,21 @@ class GLM4OpenVINOConfig(LlamaOpenVINOConfig):
MIN_TRANSFORMERS_VERSION = "4.51.3"


@register_in_tasks_manager(
    "glm4_moe",
    *[
        "feature-extraction",
        "feature-extraction-with-past",
        "text-generation",
        "text-generation-with-past",
    ],
    library_name="transformers",
)
class Glm4MoeOpenVINOConfig(LlamaOpenVINOConfig):
    """Export config for GLM-4 MoE models.

    Reuses the Llama OpenVINO export config and swaps in ``Glm4MoePatcher``,
    which replaces the per-expert Python loop with batched matmuls so the
    model can be traced.  The architecture is only available from
    transformers 4.57.0 onward.
    """

    MIN_TRANSFORMERS_VERSION = "4.57.0"
    _MODEL_PATCHER = Glm4MoePatcher


@register_in_tasks_manager(
"granite",
*[
Expand Down
100 changes: 100 additions & 0 deletions optimum/exporters/openvino/model_patcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -7634,6 +7634,106 @@ def __exit__(self, exc_type, exc_value, traceback):
del afmoe_moe.down_projs, afmoe_moe.gate_projs, afmoe_moe.up_projs


def glm4_moe_forward_patched(self, hidden_states):
    """
    Vectorized MoE forward for Glm4MoeMoE.

    Replaces the original per-expert Python loop with batched matmuls over a
    pre-fused weight stack, avoiding the data-dependent branching
    (``if token_indices.numel() > 0``) that breaks ``torch.jit.trace`` and
    yielding a much smaller OpenVINO graph.
    """
    n_experts = self.config.n_routed_experts
    bsz, q_len, hid = hidden_states.shape

    # Router: topk expert ids and weights, each shaped [B*S, top_k].
    expert_ids, expert_weights = self.gate(hidden_states)

    # Scatter the sparse top-k weights into a dense [B*S, n_experts] matrix;
    # non-selected experts keep weight 0 and are cancelled after the matmuls.
    dense_weights = torch.zeros(
        bsz * q_len, n_experts, dtype=expert_weights.dtype, device=expert_weights.device
    )
    dense_weights.scatter_(dim=1, index=expert_ids, src=expert_weights)

    hidden_states = hidden_states.view(-1, hid)

    # Shared experts run on every token unconditionally.
    shared_out = self.shared_experts(hidden_states)

    # Broadcast every token to every expert: [n_experts, B*S, hid].
    stacked = hidden_states.repeat(n_experts, 1).view(n_experts, -1, hid)
    activation = self.experts[0].act_fn

    # Batched SwiGLU-style expert MLP over the fused weight stacks.
    gate_out = torch.bmm(stacked, self.gate_projs)
    up_out = torch.bmm(stacked, self.up_projs)
    expert_out = torch.bmm(activation(gate_out) * up_out, self.down_projs)

    # Weight each expert's output by its routing weight and reduce over experts.
    expert_out = expert_out.view(n_experts, bsz, -1, hid)
    weight_bcast = dense_weights.transpose(0, 1).view(n_experts, bsz, -1)[..., None]
    routed = (expert_out * weight_bcast).sum(dim=0)

    combined = shared_out.view(bsz, -1, hid) + routed
    return combined.view(bsz, q_len, hid)


class Glm4MoePatcher(OVDecoderModelPatcher):
    """Model patcher for Glm4Moe models (e.g., GLM-4.7-Flash).

    Patches the MoE forward to use vectorized batched matmul instead of
    a for-loop over experts with data-dependent conditional branching.
    On enter, each MoE layer's expert weights are fused into stacked
    [num_experts, in, out] tensors consumed by ``glm4_moe_forward_patched``;
    on exit, the original forwards are restored and the fused tensors dropped.
    """

    def __enter__(self):
        super().__enter__()
        for layer in self._model.model.layers:
            mlp = layer.mlp
            # Skip non-MoE layers (dense MLPs) and anything that is not a
            # routed-expert module; the duplicate `experts` check of the
            # original two-step guard is folded into one condition here.
            if isinstance(mlp, type) or not hasattr(mlp, "experts") or not hasattr(mlp, "gate"):
                continue

            num_experts = mlp.config.n_routed_experts
            mlp._orig_forward = mlp.forward
            mlp.forward = types.MethodType(glm4_moe_forward_patched, mlp)

            # Fuse per-expert weights into [num_experts, hidden, inter] stacks
            # for torch.bmm; transpose because nn.Linear stores [out, in].
            # .float() upcasts so tracing is dtype-stable regardless of the
            # checkpoint's storage dtype.
            mlp.gate_projs = (
                torch.stack([mlp.experts[i].gate_proj.weight for i in range(num_experts)], dim=0)
                .transpose(1, 2)
                .float()
            )
            mlp.up_projs = (
                torch.stack([mlp.experts[i].up_proj.weight for i in range(num_experts)], dim=0)
                .transpose(1, 2)
                .float()
            )
            mlp.down_projs = (
                torch.stack([mlp.experts[i].down_proj.weight for i in range(num_experts)], dim=0)
                .transpose(1, 2)
                .float()
            )

    def __exit__(self, exc_type, exc_value, traceback):
        super().__exit__(exc_type, exc_value, traceback)
        for layer in self._model.model.layers:
            # Restore the original forward and release the fused weight
            # copies so the patched state does not leak past the context.
            if hasattr(layer.mlp, "_orig_forward"):
                layer.mlp.forward = layer.mlp._orig_forward
                del layer.mlp._orig_forward
            if hasattr(layer.mlp, "gate_projs"):
                del layer.mlp.gate_projs, layer.mlp.up_projs, layer.mlp.down_projs


# adopted from https://github.com/huggingface/transformers/blob/v4.57.6/src/transformers/models/llama/modeling_llama.py#L197
class LlamaEagle3Attention(LlamaAttention):
"""
Expand Down
4 changes: 4 additions & 0 deletions tests/openvino/test_decoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,6 +150,9 @@ class OVModelForCausalLMIntegrationTest(unittest.TestCase):
if is_transformers_version(">=", "4.55.0") and is_transformers_version("<", "4.58.0"):
SUPPORTED_ARCHITECTURES += ("afmoe",)

if is_transformers_version(">=", "4.57.0"):
SUPPORTED_ARCHITECTURES += ("glm4_moe",)

if is_transformers_version("<", "4.56.0"):
SUPPORTED_ARCHITECTURES += ("qwen", "chatglm", "chatglm4")

Expand Down Expand Up @@ -224,6 +227,7 @@ class OVModelForCausalLMIntegrationTest(unittest.TestCase):
"mixtral_awq": 2,
"gemma3_text": 2,
"glm4": 2,
"glm4_moe": 2,
"qwen3": 2,
"qwen3_moe": 2,
"mamba": 0,
Expand Down
3 changes: 3 additions & 0 deletions tests/openvino/test_export.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,9 @@ class ExportModelTest(unittest.TestCase):
if is_transformers_version(">=", "4.55.0") and is_transformers_version("<", "4.58.0"):
SUPPORTED_ARCHITECTURES.update({"afmoe": OVModelForCausalLM})

if is_transformers_version(">=", "4.57.0"):
SUPPORTED_ARCHITECTURES.update({"glm4_moe": OVModelForCausalLM})

EXPECTED_DIFFUSERS_SCALE_FACTORS = {
"stable-diffusion-xl": {"vae_encoder": "128.0", "vae_decoder": "128.0"},
"stable-diffusion-3": {"text_encoder_3": "8.0"},
Expand Down
7 changes: 7 additions & 0 deletions tests/openvino/test_exporters_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,6 +152,13 @@ class OVCLIExportTestCase(unittest.TestCase):
]
)

if is_transformers_version(">=", "4.57.0"):
SUPPORTED_ARCHITECTURES.extend(
[
("text-generation-with-past", "glm4_moe"),
]
)

EXPECTED_NUMBER_OF_TOKENIZER_MODELS = {
"gpt2": 2,
"t5": 2,
Expand Down
1 change: 1 addition & 0 deletions tests/openvino/utils_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -215,6 +215,7 @@
"xglm": "optimum-intel-internal-testing/tiny-random-XGLMForCausalLM",
"xverse": "optimum-intel-internal-testing/tiny-random-xverse",
"glm4": "optimum-intel-internal-testing/tiny-random-glm4",
"glm4_moe": "optimum-intel-internal-testing/tiny-random-glm4-moe",
"glm": "optimum-intel-internal-testing/tiny-random-glm-edge",
"open-clip": "optimum-intel-internal-testing/tiny-open-clip-model",
"open-clip-ov": "optimum-intel-internal-testing/tiny-open-clip-model",
Expand Down
Loading