From 4bbb41308881b89dab1fc748cfd642cf77dc8ada Mon Sep 17 00:00:00 2001 From: "Kazantsev, Roman" Date: Wed, 18 Feb 2026 10:49:47 +0400 Subject: [PATCH 1/3] [OpenVINO] Add support for GLM4.7 --- optimum/exporters/openvino/model_configs.py | 16 ++++ optimum/exporters/openvino/model_patcher.py | 100 ++++++++++++++++++++ tests/openvino/test_decoder.py | 4 + tests/openvino/utils_tests.py | 1 + 4 files changed, 121 insertions(+) diff --git a/optimum/exporters/openvino/model_configs.py b/optimum/exporters/openvino/model_configs.py index 9a86c06832..15e59ebce9 100644 --- a/optimum/exporters/openvino/model_configs.py +++ b/optimum/exporters/openvino/model_configs.py @@ -148,6 +148,7 @@ FluxTransfromerModelPatcher, Gemma2ModelPatcher, Gemma3LMModelPatcher, + Glm4MoePatcher, GptJModelPatcher, GptNeoModelPatcher, GptNeoxModelPatcher, @@ -3931,6 +3932,21 @@ class GLM4OpenVINOConfig(LlamaOpenVINOConfig): MIN_TRANSFORMERS_VERSION = "4.51.3" +@register_in_tasks_manager( + "glm4_moe", + *[ + "feature-extraction", + "feature-extraction-with-past", + "text-generation", + "text-generation-with-past", + ], + library_name="transformers", +) +class Glm4MoeOpenVINOConfig(LlamaOpenVINOConfig): + MIN_TRANSFORMERS_VERSION = "4.57.0" + _MODEL_PATCHER = Glm4MoePatcher + + @register_in_tasks_manager( "granite", *[ diff --git a/optimum/exporters/openvino/model_patcher.py b/optimum/exporters/openvino/model_patcher.py index 75f1498159..a1cf03d662 100644 --- a/optimum/exporters/openvino/model_patcher.py +++ b/optimum/exporters/openvino/model_patcher.py @@ -7634,6 +7634,106 @@ def __exit__(self, exc_type, exc_value, traceback): del afmoe_moe.down_projs, afmoe_moe.gate_projs, afmoe_moe.up_projs +def glm4_moe_forward_patched(self, hidden_states): + """ + Vectorized MoE forward for Glm4MoeMoE. + + Replaces the original for-loop over experts with a batched matmul approach, + avoiding data-dependent control flow (if token_indices.numel() > 0) that breaks + torch.jit.trace. 
Also produces a much smaller OpenVINO graph. + """ + num_experts = self.config.n_routed_experts + batch_size, seq_len, hidden_dim = hidden_states.shape + + # Router: returns topk_indices [B*S, top_k] and topk_weights [B*S, top_k] + topk_indices, topk_weights = self.gate(hidden_states) + + # Build full routing weight matrix [B*S, num_experts] + new_routing_weights = torch.zeros( + batch_size * seq_len, num_experts, dtype=topk_weights.dtype, device=topk_weights.device + ) + new_routing_weights.scatter_(dim=1, index=topk_indices, src=topk_weights) + + hidden_states = hidden_states.view(-1, hidden_dim) + + # Process shared experts + shared_output = self.shared_experts(hidden_states) + + # Vectorized expert computation using batched matmul + hidden_states = hidden_states.repeat(num_experts, 1) + hidden_states = hidden_states.view(num_experts, -1, hidden_dim) + act_fn = self.experts[0].act_fn + + gate = torch.bmm(hidden_states, self.gate_projs) + up = torch.bmm(hidden_states, self.up_projs) + gate_up = act_fn(gate) * up + next_states = torch.bmm(gate_up, self.down_projs) + + next_states = next_states.view(num_experts, batch_size, -1, hidden_dim) + next_states = next_states * new_routing_weights.transpose(0, 1).view(num_experts, batch_size, -1)[..., None] + next_states = next_states.sum(dim=0) + + shared_output = shared_output.view(batch_size, -1, hidden_dim) + output = shared_output + next_states + return output.view(batch_size, seq_len, hidden_dim) + + +class Glm4MoePatcher(OVDecoderModelPatcher): + """Model patcher for Glm4Moe models (e.g., GLM-4.7-Flash). + + Patches the MoE forward to use vectorized batched matmul instead of + a for-loop over experts with data-dependent conditional branching. 
+ """ + + def __enter__(self): + super().__enter__() + for layer in self._model.model.layers: + if isinstance(layer.mlp, type) or not hasattr(layer.mlp, "experts"): + continue + if not hasattr(layer.mlp, "experts") or not hasattr(layer.mlp, "gate"): + continue + + moe = layer.mlp + num_experts = moe.config.n_routed_experts + moe._orig_forward = moe.forward + moe.forward = types.MethodType(glm4_moe_forward_patched, moe) + + # Fuse expert weights for vectorized batched matmul + moe.gate_projs = ( + torch.concat( + tuple(moe.experts[i].gate_proj.weight.unsqueeze(0) for i in range(num_experts)), + dim=0, + ) + .transpose(1, 2) + .float() + ) + moe.up_projs = ( + torch.concat( + tuple(moe.experts[i].up_proj.weight.unsqueeze(0) for i in range(num_experts)), + dim=0, + ) + .transpose(1, 2) + .float() + ) + moe.down_projs = ( + torch.concat( + tuple(moe.experts[i].down_proj.weight.unsqueeze(0) for i in range(num_experts)), + dim=0, + ) + .transpose(1, 2) + .float() + ) + + def __exit__(self, exc_type, exc_value, traceback): + super().__exit__(exc_type, exc_value, traceback) + for layer in self._model.model.layers: + if hasattr(layer.mlp, "_orig_forward"): + layer.mlp.forward = layer.mlp._orig_forward + del layer.mlp._orig_forward + if hasattr(layer.mlp, "gate_projs"): + del layer.mlp.gate_projs, layer.mlp.up_projs, layer.mlp.down_projs + + # adopted from https://github.com/huggingface/transformers/blob/v4.57.6/src/transformers/models/llama/modeling_llama.py#L197 class LlamaEagle3Attention(LlamaAttention): """ diff --git a/tests/openvino/test_decoder.py b/tests/openvino/test_decoder.py index 1b4df9de96..96620677de 100644 --- a/tests/openvino/test_decoder.py +++ b/tests/openvino/test_decoder.py @@ -150,6 +150,9 @@ class OVModelForCausalLMIntegrationTest(unittest.TestCase): if is_transformers_version(">=", "4.55.0") and is_transformers_version("<", "4.58.0"): SUPPORTED_ARCHITECTURES += ("afmoe",) + if is_transformers_version(">=", "4.57.0"): + SUPPORTED_ARCHITECTURES += 
("glm4_moe",) + if is_transformers_version("<", "4.56.0"): SUPPORTED_ARCHITECTURES += ("qwen", "chatglm", "chatglm4") @@ -224,6 +227,7 @@ class OVModelForCausalLMIntegrationTest(unittest.TestCase): "mixtral_awq": 2, "gemma3_text": 2, "glm4": 2, + "glm4_moe": 2, "qwen3": 2, "qwen3_moe": 2, "mamba": 0, diff --git a/tests/openvino/utils_tests.py b/tests/openvino/utils_tests.py index 85f79801cd..3f1e6927e9 100644 --- a/tests/openvino/utils_tests.py +++ b/tests/openvino/utils_tests.py @@ -215,6 +215,7 @@ "xglm": "optimum-intel-internal-testing/tiny-random-XGLMForCausalLM", "xverse": "optimum-intel-internal-testing/tiny-random-xverse", "glm4": "optimum-intel-internal-testing/tiny-random-glm4", + "glm4_moe": "optimum-intel-internal-testing/tiny-random-glm4-moe", "glm": "optimum-intel-internal-testing/tiny-random-glm-edge", "open-clip": "optimum-intel-internal-testing/tiny-open-clip-model", "open-clip-ov": "optimum-intel-internal-testing/tiny-open-clip-model", From b8d22c37a9b6407ff453b52d2b3af6ee857181c2 Mon Sep 17 00:00:00 2001 From: "Kazantsev, Roman" Date: Wed, 18 Feb 2026 14:16:47 +0400 Subject: [PATCH 2/3] Document GLM4.7 support --- docs/source/openvino/models.mdx | 1 + 1 file changed, 1 insertion(+) diff --git a/docs/source/openvino/models.mdx b/docs/source/openvino/models.mdx index 10d9f38304..68c64171b9 100644 --- a/docs/source/openvino/models.mdx +++ b/docs/source/openvino/models.mdx @@ -61,6 +61,7 @@ Here is the list of the supported architectures : - Falcon-Mamba - Flaubert - GLM-4 +- GLM-4 MoE (GLM-4.7) - GLM-Edge - GPT-2 - GPT-BigCode From 256598dbbb5395d956d123264d5147f4f7b7d2e5 Mon Sep 17 00:00:00 2001 From: "Kazantsev, Roman" Date: Wed, 18 Feb 2026 18:45:45 +0400 Subject: [PATCH 3/3] Add missed tests for exporting --- tests/openvino/test_export.py | 3 +++ tests/openvino/test_exporters_cli.py | 7 +++++++ 2 files changed, 10 insertions(+) diff --git a/tests/openvino/test_export.py b/tests/openvino/test_export.py index 18811bd121..8dffd19c12 100644 --- 
a/tests/openvino/test_export.py +++ b/tests/openvino/test_export.py @@ -107,6 +107,9 @@ class ExportModelTest(unittest.TestCase): if is_transformers_version(">=", "4.55.0") and is_transformers_version("<", "4.58.0"): SUPPORTED_ARCHITECTURES.update({"afmoe": OVModelForCausalLM}) + if is_transformers_version(">=", "4.57.0"): + SUPPORTED_ARCHITECTURES.update({"glm4_moe": OVModelForCausalLM}) + EXPECTED_DIFFUSERS_SCALE_FACTORS = { "stable-diffusion-xl": {"vae_encoder": "128.0", "vae_decoder": "128.0"}, "stable-diffusion-3": {"text_encoder_3": "8.0"}, diff --git a/tests/openvino/test_exporters_cli.py b/tests/openvino/test_exporters_cli.py index 4be27f43e5..4f5ef1a6de 100644 --- a/tests/openvino/test_exporters_cli.py +++ b/tests/openvino/test_exporters_cli.py @@ -152,6 +152,13 @@ class OVCLIExportTestCase(unittest.TestCase): ] ) + if is_transformers_version(">=", "4.57.0"): + SUPPORTED_ARCHITECTURES.extend( + [ + ("text-generation-with-past", "glm4_moe"), + ] + ) + EXPECTED_NUMBER_OF_TOKENIZER_MODELS = { "gpt2": 2, "t5": 2,