[Contribution] GLM4MoeForCausalLM Support #58
Open
lifelongeeek wants to merge 8 commits into aws-neuron:main
Conversation
Adds NXD inference support for GLM-4.5 MoE (Glm4MoeForCausalLM) models. Based on the DeepSeek architecture with group-limited routing, sigmoid activation, and optional partial RoPE.

Key components:
- NeuronGlm4MoeForCausalLM: top-level CausalLM model class
- NeuronGlm4MoeModel: transformer body with dense + MoE layer selection
- NeuronGlm4MoeAttention: multi-head GQA with partial RoPE support
- NeuronGlm4MoeDecoderLayer: decoder layer dispatching dense vs. MoE MLP
- Glm4MoeInferenceConfig: config loader with Glm4Moe-specific field mapping
- NeuronGlm4MoeRouter: sigmoid-based group-limited top-k routing
- initialize_glm4_moe_module: wires router + ExpertMLPsV2 + SharedExperts

Supports tp_degree/moe_tp_degree/moe_ep_degree sharding via NXD process groups.

Co-authored-by: Sisyphus <clio-agent@sisyphuslabs.ai>
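The sigmoid group-limited top-k routing mentioned above can be sketched in plain Python. This is a minimal illustration, not the PR's actual `NeuronGlm4MoeRouter`: the function name is invented, and the per-group scoring convention (best expert per group) is one common choice; DeepSeek-V3-style routers instead score a group by the sum of its top-2 expert scores.

```python
def group_limited_topk(scores, n_group, topk_group, top_k):
    """Toy sketch of group-limited top-k routing (names are assumptions).

    scores: per-expert sigmoid router scores for one token.
    Experts are split into n_group contiguous groups; only the topk_group
    best-scoring groups stay eligible, then top_k experts are chosen overall.
    """
    n_experts = len(scores)
    group_size = n_experts // n_group
    # Score each group by its best expert (one common convention).
    group_scores = [max(scores[g * group_size:(g + 1) * group_size])
                    for g in range(n_group)]
    best_groups = sorted(range(n_group),
                         key=lambda g: group_scores[g],
                         reverse=True)[:topk_group]
    # Only experts inside the selected groups may be routed to.
    eligible = [i for g in best_groups
                for i in range(g * group_size, (g + 1) * group_size)]
    # Final top-k over the eligible experts, highest score first.
    return sorted(eligible, key=lambda i: scores[i], reverse=True)[:top_k]
```

For example, with 8 experts in 4 groups, `topk_group=2`, and `top_k=2`, only experts from the two best groups can win the final top-k, even if a third group contains an individually strong expert.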
- examples/generation_glm4_moe_demo.py: compile-and-run inference demo with configurable tp_degree, seq_len, and model/traced-model paths
- test_glm4_moe_accuracy.py: CPU (HuggingFace) vs. Neuron token-matching accuracy test; passes with greedy decoding (top_k=1)
- create_glm4_tiny_random.py: utility to create a small random-weight GLM-4.5 MoE checkpoint for local testing without downloading the full model

Co-authored-by: Sisyphus <clio-agent@sisyphuslabs.ai>
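The token-matching check described above reduces to a simple comparison, since greedy decoding (top_k=1) makes both the CPU and Neuron runs deterministic. A toy sketch (the helper name is invented, not the test file's API):

```python
def token_match_rate(ref_tokens, neuron_tokens):
    """Toy sketch of a token-matching accuracy check (invented helper).

    With greedy decoding both runs are deterministic, so the generated
    token ids should agree position by position; returns the fraction
    of positions that match over the overlapping length.
    """
    n = min(len(ref_tokens), len(neuron_tokens))
    if n == 0:
        return 0.0
    matches = sum(r == t for r, t in zip(ref_tokens, neuron_tokens))
    return matches / n
```

A perfect greedy run yields a match rate of 1.0; any divergence pinpoints the first position where the two backends disagree.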
- docs/glm4_moe_implementation.md: architecture overview, module breakdown, weight conversion details, sharding configuration guide
- docs/glm4_moe_testing.md: step-by-step testing guide with tiny random model, expected outputs, and troubleshooting notes

Co-authored-by: Sisyphus <clio-agent@sisyphuslabs.ai>
- Add tensor_capture_hook = kwargs.get('tensor_capture_hook', None) to
prepare_inputs_for_generation to fix NameError when tensor_capture_hook
was referenced before assignment
- Add import inspect
- Remove unconditional tensor_capture_hook from model_inputs dict
- Conditionally include tensor_capture_hook only when the model's forward()
signature accepts it (multimodal models only), preventing TypeError for
text-only models like GLM-4.5 MoE
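The signature-gated fix above can be sketched as follows. This is an illustrative reduction, not the repository's actual `prepare_inputs_for_generation`; the dummy model classes are invented for the example:

```python
import inspect


def prepare_inputs_for_generation(model, input_ids, **kwargs):
    # Pull the hook out of kwargs up front so it is always bound,
    # avoiding the NameError described in the commit message.
    tensor_capture_hook = kwargs.get('tensor_capture_hook', None)
    model_inputs = {'input_ids': input_ids}
    # Only pass the hook through when the model's forward() actually
    # accepts it (multimodal models); text-only forward() signatures
    # would otherwise raise TypeError on the unexpected kwarg.
    forward_params = inspect.signature(model.forward).parameters
    if tensor_capture_hook is not None and 'tensor_capture_hook' in forward_params:
        model_inputs['tensor_capture_hook'] = tensor_capture_hook
    return model_inputs
```

With a text-only model whose `forward(self, input_ids)` lacks the parameter, the hook is silently dropped; a multimodal-style `forward(self, input_ids, tensor_capture_hook=None)` receives it.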
…tegration
- Add contrib/models/glm4_moe/ following NxDI contrib structure
- Source model: src/glm4_moe/modeling_glm4_moe.py (Glm4MoeInferenceConfig, NeuronGlm4MoeForCausalLM, partial RoPE, sigmoid group routing, shared experts)
- Unit tests: test/unit/ — router top-k, partial RoPE, decoder dispatch (49 tests, all PASS)
- Integration tests: test/integration/ — compile + check_accuracy_logits_v2 with reduced 2-layer random-weight config on trn2.3xlarge (PASS)
- Examples: examples/generation_glm4_moe_demo.py with CLI args
- vLLM integration: vllm/run_offline_inference.py + start-vllm-server.sh
- README.md with architecture details, compatibility matrix, validation results

Tested on trn2.3xlarge (LNC=2, TP=2), NxDI 2.21+, transformers>=4.56.0
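The dense-vs-MoE decoder dispatch exercised by the unit tests follows the first_k_dense_replace convention: the first k decoder layers use a dense MLP and every later layer routes through the MoE block. A toy illustration (the helper name is invented, not the PR's code):

```python
def build_layer_types(num_layers, first_k_dense_replace=1):
    """Toy sketch (invented helper) of first_k_dense_replace dispatch.

    The first `first_k_dense_replace` decoder layers use a dense MLP;
    all subsequent layers use the MoE block with routed + shared experts.
    """
    return ['dense' if i < first_k_dense_replace else 'moe'
            for i in range(num_layers)]
```

GLM-4.5 MoE uses first_k_dense_replace=1, so only layer 0 is dense.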
Following the pattern of examples/generation_qwen3_moe_demo.py. Targets trn2.48xlarge (tp=32, moe_tp=4, moe_ep=8, bs=4, seq=4096). Adds contrib src path via sys.path for the contrib-based model.
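A minimal CLI skeleton for such a demo might look like the following. The flag names and structure are assumptions mirroring the described defaults (tp=32, moe_tp=4, moe_ep=8, bs=4, seq=4096), not the actual script; note that moe_tp × moe_ep = 4 × 8 factors the overall tp=32:

```python
import argparse


def build_parser():
    # Hypothetical CLI sketch; flag names and defaults mirror the demo
    # description but are not taken from the actual script.
    p = argparse.ArgumentParser(description="GLM-4.5 MoE generation demo")
    p.add_argument("--tp-degree", type=int, default=32)
    p.add_argument("--moe-tp-degree", type=int, default=4)
    p.add_argument("--moe-ep-degree", type=int, default=8)
    p.add_argument("--batch-size", type=int, default=4)
    p.add_argument("--seq-len", type=int, default=4096)
    p.add_argument("--model-path", type=str, default=None)
    p.add_argument("--traced-model-path", type=str, default=None)
    return p
```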
Description
Adds NeuronX Distributed Inference (NxDI) support for GLM-4.5 MoE (Glm4MoeForCausalLM), a ~70B Mixture-of-Experts language model from ZhipuAI / Tsinghua University. Follows the contrib contribution guidelines and uses PR #34 as a structural reference.

Model Information
Model Name: GLM-4.5 MoE (Air variant)
HuggingFace: zai-org/GLM-4.5-Air
Model Architecture: Decoder-only MoE transformer with partial RoPE, sigmoid group-limited routing, shared experts, and dense-first layers.
Parameters: ~70B total, ~9B active per token (128 routed experts, top-8 per token)
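The partial RoPE in this architecture (partial_rotary_factor=0.5) rotates only the leading half of each head's dimensions and passes the rest through unchanged. A pure-Python sketch for a single head vector; the function name and layout (rotate-half pairing) are assumptions, not the PR's implementation:

```python
import math


def partial_rope(x, position, theta=10000.0, partial_rotary_factor=0.5):
    """Toy sketch of partial RoPE (invented helper, not the PR's code).

    Applies rotary position embedding to only the leading
    partial_rotary_factor fraction of head_dim; the remaining
    dimensions are passed through untouched.
    """
    head_dim = len(x)
    rot = int(head_dim * partial_rotary_factor)  # dims that get rotated
    half = rot // 2
    out = list(x)
    for i in range(half):
        # Standard RoPE frequency schedule over the rotated dims only.
        angle = position / (theta ** (2 * i / rot))
        c, s = math.cos(angle), math.sin(angle)
        x1, x2 = x[i], x[i + half]          # rotate-half pairing
        out[i] = x1 * c - x2 * s
        out[i + half] = x2 * c + x1 * s
    return out
```

At position 0 the rotation is the identity, and at any position the trailing (1 − partial_rotary_factor) fraction of dimensions is returned unchanged.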
Checklist
Please ensure your PR includes the following items. Refer to the contrib/CONTRIBUTING.md for detailed guidelines.
Required Components
Accuracy Test (contrib/models/glm4_moe/test/integration/test_model.py): check_accuracy_logits_v2 with a reduced 2-layer random-weight model (divergence_difference_tol=0.001)
README.md (contrib/models/glm4_moe/README.md)
Source Code (contrib/models/glm4_moe/src/glm4_moe/): modeling_glm4_moe.py, the full NxDI implementation, under contrib/models/glm4_moe/

Optional Components
Unit Tests (contrib/models/glm4_moe/test/unit/):
- test_router.py: sigmoid group-limited top-k routing (10 tests, CPU-only)
- test_attention.py: partial RoPE, QK norm, GQA (24 tests, CPU-only)
- test_decoder.py: dense vs. MoE layer dispatch via first_k_dense_replace (15 tests, CPU-only)
vLLM Integration (contrib/models/glm4_moe/vllm/)

Folder Structure
Architecture Notes
GLM-4.5 MoE has several differences from standard MoE models (e.g. Qwen3MoE) that required custom implementations:
- Partial RoPE (partial_rotary_factor=0.5): rotary applied to part of head_dim only
- Attention bias (attention_bias=True)
- Group-limited routing (n_group, topk_group)
- e_score_correction_bias (frozen buffer)
- norm_topk_prob + routed_scaling_factor
- n_shared_experts=1 (always active)
- first_k_dense_replace=1

Testing
How did you test this change?
Tests included:
- check_accuracy_logits_v2

Test Results (2026-03-06, trn2.3xlarge, NxDI 2.21+, transformers 4.56.2):
- check_accuracy_logits_v2 passed (divergence_difference_tol=0.001)