diff --git a/contrib/models/Trinity/README.md b/contrib/models/Trinity/README.md new file mode 100644 index 0000000..d3e91fd --- /dev/null +++ b/contrib/models/Trinity/README.md @@ -0,0 +1,735 @@ +# Contrib Model: Trinity + +NeuronX Distributed Inference implementation of the Trinity model family (AfmoeForCausalLM) from Arcee AI. A single unified implementation supports all three model sizes. + +## Model Family + +| Model | HuggingFace ID | Total Params | Active Params | Instance | +|-------|----------------|-------------|---------------|----------| +| **Nano** | `arcee-ai/Trinity-Nano-Preview` | ~6B | ~1B | inf2.xlarge* / inf2.8xlarge / trn2.3xlarge | +| **Mini** | `arcee-ai/Trinity-Mini` | ~26B | ~4.5B | trn2.3xlarge (TP=4) | +| **Large** | `arcee-ai/Trinity-Large-Preview` | ~250B | ~15B | trn2.48xlarge (TP=64) | + +**License:** Apache 2.0 + +## Architecture Details + +| Feature | Nano | Mini | Large | +|---------|------|------|-------| +| Layers | 56 (2 dense + 54 MoE) | 32 (2 dense + 30 MoE) | 60 (6 dense + 54 MoE) | +| Hidden Size | 1024 | 2048 | 3072 | +| Attention Heads | 8 | 32 | 48 | +| KV Heads (GQA) | 2 | 4 | 8 | +| Head Dim | 128 | 128 | 128 | +| Experts per MoE layer | 128 | 128 | 256 | +| Active Experts (TopK) | 8 | 8 | 4 | +| Shared Experts | 1 | 1 | 1 | +| Dense Intermediate | 3072 | 6144 | 12288 | +| MoE Intermediate | 256 | 1024 | 3072 | +| Sliding Window | 2048 | 2048 | 4096 | +| Max Position Embeddings | 131,072 | 131,072 | 262,144 | +| Vocabulary | 200,192 | 200,192 | 200,192 | +| Routing | Sigmoid + normalize (scale baked into weights) | +| Activation | SiLU gated MLP (`glu_type="glu"`) | +| Position Encoding | RoPE (sliding attention layers only) | +| Normalization | RMSNorm (4 per layer) | + +### Unique Architecture Features + +- **Mixed Attention:** Alternating sliding window and full attention (every 4th layer) +- **Gated Attention:** Sigmoid gate applied to attention output before o_proj +- **QK Normalization:** Per-head RMSNorm on Q and 
K +- **muP Scaling:** Embedding output scaled by hidden_size^0.5 +- **Expert Bias:** Learned bias added to routing scores for expert selection +- **Conditional RoPE:** Rotary embeddings applied only to sliding attention layers + +## Validation Results + +**Validated:** 2026-03-06 +**SDK:** NxDI 0.8.0, neuronx-cc 2.23.6484, torch-neuronx 2.9.0.2.12, transformers 4.57.6 (SDK 2.28) +**Benchmarked:** 2026-03-06 (Nano + Mini, trn2.3xlarge + inf2.8xlarge, with bucketing) +**Neuron vs CPU accuracy verified:** 2026-03-06 (Nano, trn2.3xlarge TP=2) + +All results below are from the **unified `modeling_trinity.py`** (this code), SDK 2.28 with bucketing enabled. Fused TKG results from SDK 2.28. + +### Neuron vs CPU Accuracy (Trinity-Nano, TP=2, SDK 2.28) + +Full-precision CPU reference comparison using HuggingFace `AutoModelForCausalLM` (bf16) vs Neuron compiled model: + +| Prompt | Neuron Top-1 | CPU Top-1 | Match | Top-20 Cosine | Full-Vocab Cosine | +|--------|-------------|-----------|-------|---------------|-------------------| +| "Hello, how are you?" | I | I | YES | 0.9995 | 0.9398 | +| "Explain quantum computing in simple terms." | Answer | Answer | YES | 0.9995 | 0.9905 | +| "Write a Python function that calculates the Fibonacci sequence." | The | The | YES | 0.9999 | 0.9835 | + +**Summary:** 3/3 top-1 match, avg top-20 cosine similarity 0.9996, avg full-vocab cosine similarity 0.9712. Lower full-vocab cosine (esp. prompt 1 at 0.94) is expected due to MoE bf16 accumulation in NKI blockwise matmul kernel -- the tail of the logit distribution diverges while the top of the distribution is preserved. + +### Trinity-Nano on trn2.3xlarge (TP=2, LNC=2) + +| Metric | Result | +|--------|--------| +| Compilation Time | 5.1 min | +| Load Time | 2.2 min | +| Forward Pass Latency | ~0.50s | + +**First-token predictions:** + +| Prompt | Top-1 Token | Logit | Top-5 | +|--------|-------------|-------|-------| +| "Hello, how are you?" 
| I | 17.75 | I, Hello, How | +| "Explain quantum computing in simple terms." | Answer | 21.00 | Answer, Quantum, What | +| "Write a Python function that calculates the Fibonacci sequence." | The | 24.75 | The, Your, Additionally | + +**Generation (5 tokens):** +- "Hello, how are you?" -> "I am fine, thank" +- "Explain quantum computing in simple terms." -> "Answer: Quantum computing uses" + +### Trinity-Mini on trn2.3xlarge (TP=4, LNC=2) + +| Metric | Result | +|--------|--------| +| Compilation Time | 4.9 min | +| Load Time | 4.1 min (from pre-compiled) | +| Forward Pass Latency | ~0.37s | + +**First-token predictions:** + +| Prompt | Top-1 Token | Logit | Top-5 | +|--------|-------------|-------|-------| +| "Hello, how are you?" | I | 20.12 | I, This, My | +| "Explain quantum computing in simple terms." | What | 20.75 | What, How, Quantum | +| "Write a Python function that calculates the Fibonacci sequence." | The | 28.00 | The, Your, It | + +**Generation (5 tokens):** +- "Hello, how are you?" -> "I'm fine, thank" +- "Explain quantum computing in simple terms." -> "What are the key differences" + +### Trinity-Nano on inf2.8xlarge (TP=1, no LNC) + +| Metric | Result | +|--------|--------| +| Compilation Time | Reused from trn2.3xlarge | +| Load Time | 47.7s | +| Forward Pass Latency | ~0.73s | + +**Note:** inf2.xlarge (16GB system RAM) cannot run Nano with standard loading -- OOM killed at 15.3GB RSS during weight loading. However, **pre-sharded weights bypass this entirely** (see Pre-Sharded Deployment below). inf2.8xlarge (123GB system RAM) works with standard loading at TP=1. NxDI auto-converts GQA to MHA when `TP=1` and `num_kv_heads=2`. + +### Trinity-Large on trn2.48xlarge (TP=64, LNC=2) + +| Metric | Result | +|--------|--------| +| Compilation Time | 8.6 min | +| Load Time | 15.6 min | +| Forward Pass Latency | ~1.15s | + +**First-token predictions:** + +| Prompt | Top-1 Token | +|--------|-------------| +| "Hello, how are you?" 
| I | +| "Explain quantum computing in simple terms." | Quantum | +| "Write a Python function that calculates the Fibonacci sequence." | The | + +**Notes:** +- TP=32 is insufficient -- sharded weights consume ~23.5GB per logical NeuronCore, exceeding the ~24GB HBM per physical NC and leaving no room for scratchpad/KV cache. TP=64 (all 64 logical cores on trn2.48xlarge) is required. +- Model is ~516GB on disk (31 safetensors in bf16). Root EBS volume (600GB) is insufficient -- NVMe instance store is required for model storage (`/mnt/nvme/`). +- Set `TMPDIR`, `BASE_COMPILE_WORK_DIR`, and `NEURON_COMPILE_CACHE_URL` to NVMe paths to avoid filling root disk during compilation. + +## Performance Benchmarks + +**SDK 2.28**, seq_len=2048, BF16, bucketing enabled, measured with proper CTE+TKG pipeline (KV cache). TTFT averaged over 10 iterations (3 warmup). TKG averaged over 28 tokens (3 warmup discarded from 32 generated). Throughput = `batch_size * (1000 / avg_tkg_ms)` (steady-state TKG-based). 
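As a quick sanity check of the methodology, here is the throughput formula applied to one measured row (a standalone sketch; the numbers are taken from the Nano benchmark table below):

```python
# Throughput methodology used in the benchmark tables:
#   aggregate tok/s     = batch_size * (1000 / avg_tkg_ms)
#   per-sequence tok/s  = 1000 / avg_tkg_ms

def throughput_tok_s(batch_size: int, avg_tkg_ms: float) -> float:
    """Steady-state decode throughput in tokens/second."""
    return batch_size * (1000.0 / avg_tkg_ms)

# Example row: inf2.8xlarge, TP=2, BS=4, avg TKG = 13.6 ms/token
print(round(throughput_tok_s(4, 13.6)))  # -> 294 aggregate tok/s
print(round(throughput_tok_s(1, 13.6)))  # -> 74 per-sequence tok/s
```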
+ +### Trinity-Nano (~6B total, ~1B active) + +| Instance | TP | BS | TTFT (ms) | TKG (ms/tok) | Throughput (tok/s) | Per-seq (tok/s) | Compile | +|----------|-----|------|-----------|-------------|-------------------|-----------------|---------| +| inf2.xlarge* | 1 | 1 | 706 | 9.0 | 112 | 112 | N/A (pre-sharded) | +| inf2.8xlarge | 1 | 1 | 706 | 9.2 | 109 | 109 | 7.9 min | +| inf2.8xlarge | 1 | 2 | 901 | 13.3 | 150 | 75 | 8.8 min | +| inf2.8xlarge | 1 | 4 | 1347 | 20.8 | 192 | 48 | 11.7 min | +| inf2.8xlarge | 2 | 1 | 516 | 7.6 | 131 | 131 | 4.8 min | +| inf2.8xlarge | 2 | 2 | 674 | 9.4 | 212 | 106 | 6.6 min | +| inf2.8xlarge | 2 | 4 | 993 | 13.6 | 294 | 74 | 8.5 min | +| trn2.3xlarge | 2 | 1 | 516 | 10.8 | 93 | 93 | 4.9 min | +| trn2.3xlarge | 2 | 2 | 680 | 13.9 | 144 | 72 | 7.4 min | +| trn2.3xlarge | 2 | 4 | 930 | 16.3 | 245 | 61 | 9.4 min | +| trn2.3xlarge | 4 | 1 | 476 | 9.2 | 109 | 109 | 5.0 min | +| trn2.3xlarge | 4 | 2 | 600 | 12.4 | 161 | 81 | 6.5 min | +| trn2.3xlarge | 4 | 4 | 817 | 14.9 | 269 | 67 | 8.5 min | + +**Whole-instance throughput** (TP x DP = all available cores): + +| Instance | Config | BS | Throughput (tok/s) | Notes | +|----------|--------|----|--------------------|-------| +| inf2.8xlarge | TP=2 DP=1 | 4 | **294** | Best (use all 2 cores, single replica) | +| inf2.8xlarge | TP=1 DP=2 | 1 | 218* | Calculated: 2 x 109 tok/s | +| trn2.3xlarge | TP=4 DP=1 | 4 | **269** | Best (use all 4 cores, single replica) | +| trn2.3xlarge | TP=2 DP=2 | 1 | 186* | Calculated: 2 x 93 tok/s | + +*DP throughput calculated mathematically (replicas are independent on separate NeuronCores). + +**Recommended config**: inf2.8xlarge TP=2 BS=4 for max throughput (294 tok/s) with lowest TTFT on inf2 (516ms), or trn2.3xlarge TP=4 BS=4 for max throughput on trn2 (269 tok/s). *inf2.xlarge requires pre-sharded weights (see Pre-Sharded Deployment). 
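The starred DP rows above are derived rather than measured; since replicas run independently on disjoint NeuronCores, the arithmetic is simply:

```python
# DP replicas are independent (separate NeuronCores), so aggregate
# throughput is replicas * per-replica throughput -- the starred rows above.

def dp_aggregate_tok_s(dp_degree: int, per_replica_tok_s: float) -> float:
    return dp_degree * per_replica_tok_s

print(dp_aggregate_tok_s(2, 109))  # inf2.8xlarge TP=1 DP=2 -> 218 tok/s
print(dp_aggregate_tok_s(2, 93))   # trn2.3xlarge TP=2 DP=2 -> 186 tok/s
```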
+ +### Trinity-Mini (~26B total, ~4.5B active) + +| Instance | TP | BS | TTFT (ms) | TKG (ms/tok) | Throughput (tok/s) | Per-seq (tok/s) | Compile | +|----------|-----|------|-----------|-------------|-------------------|-----------------|---------| +| trn2.3xlarge | 4 | 1 | 371 | 11.8 | 85 | 85 | 3.9 min | +| trn2.3xlarge | 4 | 2 | 598 | 11.5 | 174 | 87 | 6.8 min | +| trn2.3xlarge | 4 | 4 | 805 | 13.6 | 295 | 74 | 9.1 min | + +Mini requires TP=4 (all cores on trn2.3xlarge), so DP is not applicable. + +**Recommended config**: trn2.3xlarge TP=4 BS=4 for best throughput/latency balance (295 tok/s, 13.6ms TKG), or BS=1 for lowest TTFT (371 ms). + +### Trinity-Large (~250B total, ~15B active) + +| Instance | TP | BS | TTFT (ms) | TKG (ms/tok) | Throughput (tok/s) | Per-seq (tok/s) | Compile | Load | +|----------|-----|------|-----------|-------------|-------------------|-----------------|---------|------| +| trn2.48xlarge | 64 | 1 | 1161 | 14.7 | 68 | 68 | 9.2 min | 851s | +| trn2.48xlarge | 64 | 2 | 1657 | 19.1 | 102 | 51 | 11.1 min | 867s | +| trn2.48xlarge | 64 | 4 | 1980 | 29.0 | 137 | 34 | 14.0 min | 873s | + +**Recommended config**: trn2.48xlarge TP=64 BS=1 for lowest latency (1.16s TTFT, 14.7ms TKG), or BS=4 for max aggregate throughput (137 tok/s). NVMe instance store required for model storage (~743GB on disk). + +### GPU Comparison (g5.12xlarge, 4x NVIDIA A10G) + +Benchmarked via vLLM 0.16.0, bf16. Shows single-request latency and aggregate throughput at various concurrency levels. GPU uses continuous batching with CUDA graphs (PagedAttention v2). 
+ +**Trinity-Nano** on 1x A10G (TP=1): + +| Concurrency | TTFT (ms) | TKG (ms/tok) | Throughput (tok/s) | +|-------------|-----------|-------------|-------------------| +| 1 | 20 | 6.9 | 137 | +| 4 | 30 | -- | 400 | +| 16 | 56 | -- | 1140 | +| 64 | 65 | -- | 2782 | + +**Trinity-Mini** on 4x A10G (TP=4, max_num_seqs=32): + +| Concurrency | TTFT (ms) | TKG (ms/tok) | Throughput (tok/s) | +|-------------|-----------|-------------|-------------------| +| 1 | 24 | 6.7 | 138 | +| 4 | 42 | -- | 337 | +| 16 | 79 | -- | 857 | + +**GPU vs Neuron** (single-request, BS=1): + +| Model | Metric | GPU (A10G) | Neuron (best) | Notes | +|-------|--------|-----------|---------------|-------| +| Nano | TTFT | 20 ms | 476 ms | GPU 24x faster (CUDA graphs vs CTE forward) | +| Nano | TKG | 6.9 ms | 7.6 ms | GPU 1.1x faster (inf2 TP=2) | +| Mini | TTFT | 24 ms | 371 ms | GPU 15x faster | +| Mini | TKG | 6.7 ms | 11.5 ms | GPU 1.7x faster | + +GPU TTFT advantage comes from vLLM's CUDA graph capture eliminating kernel launch overhead. Neuron TTFT is dominated by the CTE forward pass through compiled HLO graphs. A vLLM-Neuron serving stack would narrow this gap. 
### Key Observations

- **Batching scales well**: BS=4 gives 1.8-3.5x aggregate throughput vs BS=1, with per-token (TKG) latency increasing 15-125% depending on config
- **Mini has the lowest TTFT**: 371ms at TP=4 BS=1, vs 476ms (Nano TP=4) and 1161ms (Large TP=64)
- **inf2.8xlarge TP=2 is best for Nano**: 294 tok/s at BS=4 (516ms TTFT at BS=1) -- better throughput than trn2 TP=4 (269 tok/s)
- **TP=2 on inf2 outperforms TP=1**: 20-53% higher throughput across batch sizes (TKG drops from 9.2ms to 7.6ms at BS=1)
- **DP gives higher aggregate throughput than TP for small models**: trn2 TP=2 DP=2 at BS=1 yields 186 tok/s vs 109 tok/s for TP=4 DP=1 at BS=1, but at higher per-token latency
- **TP=4 vs TP=2 on trn2**: TP=4 has 9-15% lower TKG latency (better per-sequence), but TP=2 enables DP=2 for higher aggregate throughput
- **Compile time grows with batch size**: BS=4 takes 8.5-9.4 min vs 3.9-5.0 min at BS=1
- **Large TKG is comparable to smaller models**: 14.7ms despite 250B total params -- MoE activates only ~15B
- **Load time dominates Large**: 14.2 min to shard 516GB across 64 cores; compile is only 9.2 min
- **GPU has a massive TTFT advantage**: 20-24ms vs 371-706ms (15-35x) due to CUDA graphs vs a compiled HLO forward pass
- **GPU aggregate throughput scales with concurrency**: 2782 tok/s (Nano, 64 concurrent) vs 294 tok/s (Neuron inf2 TP=2 BS=4) -- continuous batching vs static batching
- **GPU TKG is 1.1-1.7x faster**: 6.7-6.9ms vs 7.6-11.8ms on Neuron
- **inf2.xlarge cannot run Nano with standard loading**: 16GB system RAM is insufficient to load the 12GB bf16 weights for sharding (OOM), even with pre-compiled artifacts. **Pre-sharded weights solve this** (1.39 GB RSS, 112 tok/s).
+ +## Usage + +### Trinity-Nano-Preview (~6B total, ~1B active) + +```python +import torch +from transformers import AutoTokenizer +from neuronx_distributed_inference.models.config import MoENeuronConfig + +from src.modeling_trinity import NeuronTrinityForCausalLM, TrinityInferenceConfig + +model_path = "/path/to/arcee-ai/Trinity-Nano-Preview/" +compiled_path = "/path/to/compiled-nano/" + +neuron_config = MoENeuronConfig( + tp_degree=2, # Nano is small enough for TP=2 + batch_size=1, + seq_len=2048, # Max tested: 40960 (TP=2), 49152 (TP=4) + torch_dtype=torch.bfloat16, +) + +config = TrinityInferenceConfig.from_pretrained( + model_path, neuron_config=neuron_config +) + +model = NeuronTrinityForCausalLM(model_path, config) +model.compile(compiled_path) +model.load(compiled_path) + +tokenizer = AutoTokenizer.from_pretrained( + model_path, padding_side="right", trust_remote_code=True +) +``` + +**With bucketing** (for variable-length inputs): + +```python +neuron_config = MoENeuronConfig( + tp_degree=2, + batch_size=1, + seq_len=4096, + torch_dtype=torch.bfloat16, + enable_bucketing=True, + apply_seq_ids_mask=True, + context_encoding_buckets=[2048, 4096], # Must be >= sliding_window (2048) + token_generation_buckets=[2048, 4096], +) +``` + +**Instance:** inf2.xlarge (TP=1, pre-sharded weights required), inf2.8xlarge (TP=1), or trn2.3xlarge (TP=2/4). 
+ +### Trinity-Mini (~26B total, ~4.5B active) + +```python +import torch +from transformers import AutoTokenizer +from neuronx_distributed_inference.models.config import MoENeuronConfig + +from src.modeling_trinity import NeuronTrinityForCausalLM, TrinityInferenceConfig + +model_path = "/path/to/arcee-ai/Trinity-Mini/" +compiled_path = "/path/to/compiled-mini/" + +neuron_config = MoENeuronConfig( + tp_degree=4, + batch_size=1, + seq_len=2048, # Max tested: 32768 + torch_dtype=torch.bfloat16, +) + +config = TrinityInferenceConfig.from_pretrained( + model_path, neuron_config=neuron_config +) + +model = NeuronTrinityForCausalLM(model_path, config) +model.compile(compiled_path) +model.load(compiled_path) + +tokenizer = AutoTokenizer.from_pretrained( + model_path, padding_side="right", trust_remote_code=True +) +``` + +**With bucketing** (for variable-length inputs): + +```python +neuron_config = MoENeuronConfig( + tp_degree=4, + batch_size=1, + seq_len=4096, + torch_dtype=torch.bfloat16, + enable_bucketing=True, + apply_seq_ids_mask=True, + context_encoding_buckets=[2048, 4096], # Must be >= sliding_window (2048) + token_generation_buckets=[2048, 4096], +) +``` + +**Instance:** trn2.3xlarge (TP=4). Does NOT fit inf2.8xlarge (~48GB bf16). 
+ +### Trinity-Large-Preview (~250B total, ~15B active) + +```python +import torch +from transformers import AutoTokenizer +from neuronx_distributed_inference.models.config import MoENeuronConfig + +from src.modeling_trinity import NeuronTrinityForCausalLM, TrinityInferenceConfig + +model_path = "/path/to/arcee-ai/Trinity-Large-Preview/" +compiled_path = "/path/to/compiled-large/" + +neuron_config = MoENeuronConfig( + tp_degree=64, + batch_size=1, + seq_len=4096, # Max tested: 30720 + torch_dtype=torch.bfloat16, +) + +config = TrinityInferenceConfig.from_pretrained( + model_path, neuron_config=neuron_config +) + +model = NeuronTrinityForCausalLM(model_path, config) +model.compile(compiled_path) +model.load(compiled_path) + +tokenizer = AutoTokenizer.from_pretrained( + model_path, padding_side="right", trust_remote_code=True +) +``` + +**With bucketing** (for variable-length inputs): + +```python +neuron_config = MoENeuronConfig( + tp_degree=64, + batch_size=1, + seq_len=8192, + torch_dtype=torch.bfloat16, + enable_bucketing=True, + apply_seq_ids_mask=True, + context_encoding_buckets=[4096, 8192], # Must be >= sliding_window (4096) + token_generation_buckets=[4096, 8192], +) +``` + +**Instance:** trn2.48xlarge only (TP=64, capacity block required, NVMe instance store for model storage). + +## Pre-Sharded Deployment (inf2.xlarge) + +The standard NxDI load path downloads the full HuggingFace checkpoint into CPU RAM, converts it to Neuron format, and shards weights by TP rank. For Trinity-Nano (~12GB bf16), this requires 15+ GB system RAM — exceeding inf2.xlarge's 16GB. + +**Pre-sharded weights** bypass this entirely. During compilation on a larger instance, setting `save_sharded_checkpoint=True` saves per-rank weight files (`weights/tp{rank}_sharded_checkpoint.safetensors`). During loading, NxDI reads directly from these files without loading the full HF checkpoint. + +### Workflow + +1. 
**Compile on a larger instance** (inf2.8xlarge or trn2): + +```python +neuron_config = MoENeuronConfig( + tp_degree=1, + batch_size=1, + seq_len=2048, + torch_dtype=torch.bfloat16, + save_sharded_checkpoint=True, # Key flag +) +config = TrinityInferenceConfig.from_pretrained(model_path, neuron_config=neuron_config) +model = NeuronTrinityForCausalLM(model_path, config) +model.compile(compiled_path) # Saves model.pt, neuron_config.json, weights/tp0_sharded_checkpoint.safetensors +``` + +2. **Upload the compiled artifact** (model.pt + neuron_config.json + weights/) to HuggingFace or S3. + +3. **Load on inf2.xlarge** (16GB RAM): + +```python +neuron_config = MoENeuronConfig( + tp_degree=1, + batch_size=1, + seq_len=2048, + torch_dtype=torch.bfloat16, + save_sharded_checkpoint=True, # Must match compilation +) +config = TrinityInferenceConfig.from_pretrained(model_path, neuron_config=neuron_config) +model = NeuronTrinityForCausalLM(model_path, config) +model.load(compiled_artifact_path) # Reads directly from sharded files — only 1.4 GB RSS +``` + +### Pre-Sharded Results (inf2.xlarge) + +| Instance | TP | BS | TTFT (ms) | TKG (ms/tok) | Throughput (tok/s) | Load | Peak RSS | +|----------|-----|------|-----------|-------------|-------------------|------|----------| +| inf2.xlarge | 1 | 1 | 706 | 9.0 | 112 | 18.4s | 1.39 GB | + +- **Memory**: 1.9 GB system RAM used (12.6%) vs 15+ GB that causes OOM with standard loading +- **Performance**: Identical to inf2.8xlarge (706 vs 707ms TTFT, 9.0 vs 9.1ms TKG) +- **Load time**: 18.4s (comparable to inf2.8xlarge's 11s pre-sharded / 48s standard) + +### Available Pre-Compiled Artifacts + +| Model | TP | BS | seq_len | SDK | HuggingFace Repo | +|-------|-----|------|---------|-----|-----------------| +| Nano | 1 | 1 | 2048 | 2.28 | `jburtoft/Trinity-Nano-Neuron-TP1` | + +## Caveats + +1. **`padding_side="right"` required** -- NKI flash attention kernel does not support left-padding. 
Always set `padding_side="right"` on the tokenizer.

2. **MoE v2 bf16 accumulation** -- The NxDI MoE v2 NKI kernel accumulates in bf16, causing ~23x more divergence per MoE layer than dense layers. Full-vocab cosine similarity vs the CPU reference is ~0.936, but top-1 token accuracy is preserved. A fix ticket has been filed.

3. **`trust_remote_code=True` required** -- Trinity uses a custom `AfmoeForCausalLM` architecture that is not in standard transformers. Downloading and loading the HuggingFace checkpoint requires `trust_remote_code=True`.

4. **transformers version sensitivity** -- Pair the transformers version with the SDK: 4.56.2 for SDK 2.27, 4.57.6 for SDK 2.28 (see the SDK Configuration table). Reference outputs may vary across transformers versions.

5. **GLU type** -- Trinity uses `SiLU(gate) * up`, which maps to NxDI's `glu_type="glu"`, NOT `"swiglu"`. This is handled automatically by the config class.

6. **route_scale baked into weights** -- NxDI MoE v2 does not support `route_scale` natively. The scale is baked into the routed experts' `down_proj` weights during weight conversion. Shared expert weights are NOT scaled.

7. **Gate padding at high TP** -- When `num_attention_heads` is not evenly divisible by `tp_degree` (e.g., Large at TP=64: 48/64), gate weights are padded with an interleaved layout matching the Q projection. This is handled automatically during weight conversion.

8. **Mixed attention KV cache (`TrinityKVCacheManager`)** -- Trinity alternates sliding-window and full attention (every 4th layer is full attention). The standard `KVCacheManager` applies a single `sliding_window` modulation to all layers, which causes out-of-bounds writes for full-attention layers, whose KV caches are larger. `TrinityKVCacheManager` manages the cache per layer: uniform `max_length` cache buffers (safe for CTE `fill_prefix`), per-layer scatter modulation during TKG (sliding: `pos % sliding_window`; global: no modulation), and per-layer KV read slicing (sliding: `sliding_window`; global: `max_length`). This is enabled automatically.
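The per-layer cache arithmetic in caveat 8 can be sketched in a few lines. This is a pure-Python illustration, not the actual `TrinityKVCacheManager` code; the function names and the exact phase of the "every 4th layer" pattern are illustrative:

```python
# Illustrative sketch of per-layer TKG cache indexing (not NxDI code).
# Sliding layers write into a ring buffer of length sliding_window;
# full-attention layers write at the raw position.

SLIDING_WINDOW = 2048   # Nano/Mini value
MAX_LENGTH = 4096       # uniform cache buffer length (seq_len)

def is_full_attention(layer_idx: int) -> bool:
    # "every 4th layer is full attention"; exact phase is an assumption here
    return layer_idx % 4 == 3

def scatter_index(layer_idx: int, position: int) -> int:
    """Where a new KV entry is written during token generation."""
    if is_full_attention(layer_idx):
        return position                   # global layer: no modulation
    return position % SLIDING_WINDOW      # sliding layer: ring-buffer write

def read_length(layer_idx: int) -> int:
    """How much of the uniform cache buffer a layer reads back."""
    return MAX_LENGTH if is_full_attention(layer_idx) else SLIDING_WINDOW

# At position 3000, a sliding layer wraps around; a full layer does not.
print(scatter_index(0, 3000))  # -> 952  (3000 % 2048)
print(scatter_index(3, 3000))  # -> 3000
```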
+ +## Maximum Sequence Length + +Validated with token generation (5 tokens per prompt) at each max seq_len: + +| Model | Instance | TP | Max seq_len | Compile | Load | Gen Latency | +|-------|----------|-----|------------|---------|------|-------------| +| Nano | trn2.3xlarge | 2 | **40,960** | 1.5 min | 3.2 min | 2.4s/tok | +| Nano | trn2.3xlarge | 4 | **49,152** | 1.4 min | 1.4 min | 2.4s/tok | +| Mini | trn2.3xlarge | 4 | **32,768** | 0.9 min | 7.7 min | 2.4s/tok | +| Large | trn2.48xlarge | 64 | **30,720** | 1.6 min | 16.5 min | 2.9s/tok | + +Compile times above are for cache-hit runs. First compilation at each seq_len takes 5-25 min. + +Higher TP gives more headroom for KV cache (Nano TP=4 fits 49K vs 41K at TP=2). The failure mode at the limit is compilation timeout, not OOM. + +## Bucketing + +Bucketing compiles separate NEFFs for different sequence length buckets, enabling efficient inference for variable-length inputs without padding every input to `seq_len`. + +### Configuration + +Enable bucketing by adding `enable_bucketing=True` and `apply_seq_ids_mask=True` to the neuron config: + +```python +neuron_config = MoENeuronConfig( + tp_degree=2, + batch_size=1, + seq_len=4096, + torch_dtype=torch.bfloat16, + enable_bucketing=True, + apply_seq_ids_mask=True, # Required for mixed attention bucketing + context_encoding_buckets=[2048, 4096], # Optional: custom CTE buckets + token_generation_buckets=[2048, 4096], # Optional: custom TKG buckets +) +``` + +If `context_encoding_buckets` / `token_generation_buckets` are omitted, NxDI auto-generates power-of-2 buckets from 128 to `seq_len`. + +### Restrictions + +1. **`apply_seq_ids_mask=True` is required.** Without it, TKG fails with a shape mismatch: full-attention layers have KV caches sized to `seq_len` but the TKG attention mask is only sized to the bucket value. + +2. 
**All context encoding buckets must be >= `sliding_window`.** The `get_last_kv_window` function in windowed attention gathers indices up to `sliding_window - 1` from the K/V tensors. A CTE bucket smaller than `sliding_window` produces K/V tensors too short for those indices, causing an out-of-bounds memory access. For Trinity-Nano and Trinity-Mini (`sliding_window=2048`), the smallest CTE bucket must be at least 2048. For Trinity-Large (`sliding_window=4096`), the smallest CTE bucket must be at least 4096. Short prompts are padded to the smallest qualifying bucket. + +3. **Token generation buckets have no minimum.** TKG buckets control the `n_positions` dimension for the attention mask. The `apply_seq_ids_mask` flag dynamically pads the mask when needed. + +### Validated Bucketing Results + +**Trinity-Nano** on trn2.3xlarge (TP=2, seq_len=4096, buckets=[2048, 4096]): + +| Prompt | Input Tokens | Top-1 | Forward | Status | +|--------|-------------|-------|---------|--------| +| "Hello!" | 3 | "I" | 0.51s | PASS | +| "What is the capital of France?" | 8 | "(" | 0.50s | PASS | +| (20-token prompt) | 20 | "Provide" | 0.50s | PASS | +| (124-token prompt) | 124 | `<\|im_end\|>` | 0.50s | PASS | + +- Compile: 7.0 min (4 NEFFs: 2 CTE + 2 TKG) +- Load: 37.6s +- Token generation: ~0.5s/tok (5 tokens generated per prompt) +- Backward compatible with non-bucketing flow + +**Trinity-Mini** on trn2.3xlarge (TP=4, seq_len=4096, buckets=[2048, 4096]): + +| Prompt | Input Tokens | Top-1 | Forward | Status | +|--------|-------------|-------|---------|--------| +| "Hello!" | 3 | "I" | 0.37s | PASS | +| "What is the capital of France?" 
| 8 | "Paris" | 0.37s | PASS | +| (20-token prompt) | 20 | "Also" | 0.37s | PASS | +| (124-token prompt) | 124 | "Conclude" | 0.37s | PASS | + +- Compile: 5.5 min (4 NEFFs: 2 CTE + 2 TKG) +- Load: 78.3s (1.3 min) +- Token generation: ~0.4s/tok (5 tokens generated per prompt) + +**Trinity-Large** on trn2.48xlarge (TP=64, seq_len=8192, buckets=[4096, 8192]): + +| Prompt | Input Tokens | Top-1 | Forward | Status | +|--------|-------------|-------|---------|--------| +| "Hello!" | 3 | "I" | 1.16s | PASS | +| "What is the capital of France?" | 8 | "\n" | 1.15s | PASS | +| (20-token prompt) | 20 | "\n" | 1.15s | PASS | +| (124-token prompt) | 124 | `<\|end_of_text\|>` | 1.15s | PASS | + +- Compile: 12.5 min (4 NEFFs: 2 CTE + 2 TKG) +- Load: 15.7 min +- Token generation: ~1.2s/tok (5 tokens generated per prompt) + +### How It Works + +Trinity's mixed attention (sliding window + full attention every 4th layer) requires three mechanisms working together for bucketing: + +1. **`has_mixed_attn=True`** on the model base tells the framework to generate dual attention masks: a global causal mask for full-attention layers and a local windowed mask for sliding layers. The decoder layer selects the appropriate mask per layer type (Llama4 pattern). + +2. **`apply_seq_ids_mask=True`** enables dynamic mask padding in `compute_for_token_gen`. When a full-attention layer's KV cache (sized `max_length`) exceeds the TKG bucket's attention mask (sized `n_positions`), the mask is automatically padded with zeros. + +3. **`TrinityKVCacheManager`** replaces the standard `KVCacheManager` with per-layer awareness. All layers share uniform `max_length` cache buffers (required for CTE `fill_prefix` safety), but during TKG, scatter indices are modulated per-layer (sliding: `position % sliding_window`, global: raw position) and KV reads are sliced per-layer (sliding: `sliding_window`, global: `max_length`). 
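The bucket restrictions above can be captured in a small validation helper. This is a hypothetical sketch, not an NxDI API: it encodes only restriction 2 (CTE buckets must be at least `sliding_window`) plus the obvious upper bound at `seq_len`:

```python
# Hypothetical bucket-config checker (not part of NxDI), encoding the
# restriction that context-encoding buckets must be >= sliding_window.

def validate_buckets(context_encoding_buckets, token_generation_buckets,
                     seq_len, sliding_window):
    for b in context_encoding_buckets:
        if b < sliding_window:
            raise ValueError(
                f"CTE bucket {b} < sliding_window {sliding_window}: "
                "get_last_kv_window would read out of bounds")
    for b in context_encoding_buckets + token_generation_buckets:
        if b > seq_len:
            raise ValueError(f"bucket {b} exceeds seq_len {seq_len}")
    return True

# Valid for Nano/Mini (sliding_window=2048):
assert validate_buckets([2048, 4096], [2048, 4096], 4096, 2048)

# A 1024 CTE bucket would be rejected:
try:
    validate_buckets([1024, 4096], [1024, 4096], 4096, 2048)
except ValueError as e:
    print(e)
```

TKG buckets intentionally have no lower-bound check, matching restriction 3.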
+ +## Compatibility Matrix + +| Model | Instance | TP | LNC | Max seq_len | Status | +|-------|----------|-----|-----|------------|--------| +| Nano | inf2.xlarge | 1 | N/A | -- | PASS with pre-sharded weights (standard load OOMs at 16GB system RAM) | +| Nano | inf2.8xlarge | 1 | N/A | -- | Validated (not seq_len tested) | +| Nano | inf2.8xlarge | 2 | N/A | -- | Validated (best throughput on inf2) | +| Nano | trn2.3xlarge | 2 | 2 | 40,960 | Validated | +| Nano | trn2.3xlarge | 4 | 2 | 49,152 | Validated | +| Mini | inf2.8xlarge | -- | -- | -- | Does NOT fit | +| Mini | trn2.3xlarge | 4 | 2 | 32,768 | Validated | +| Large | trn2.48xlarge | 32 | 2 | -- | FAIL (HBM OOM per NC) | +| Large | trn2.48xlarge | 64 | 2 | 30,720 | Validated | + +### Minimum Requirements by Model Size + +| Model | Min HBM | Min TP | Min Instance | +|-------|---------|--------|-------------| +| Nano | ~12GB bf16 | 1 | inf2.xlarge (pre-sharded weights) or inf2.8xlarge | +| Mini | ~48GB bf16 | 4 | trn2.3xlarge | +| Large | ~500GB bf16 | 64 | trn2.48xlarge (capacity block, NVMe storage) | + +### SDK Configuration + +| Component | SDK 2.27 | SDK 2.28 | +|-----------|----------|----------| +| NxDI | 0.7.15063 | 0.8.0 | +| neuronx-cc | 2.22.12471 | 2.23.6484 | +| torch-neuronx | 2.9.0.2.11 | 2.9.0.2.12.22436 | +| torch | 2.9.0 | 2.9.0 | +| transformers | 4.56.2 | 4.57.6 | +| Venv | `/opt/aws_neuronx_venv_pytorch_inference_vllm_0_13/` | same | + +Both SDK versions are validated. Fused MoE TKG requires SDK 2.28. 
+ +## Testing + +```bash +# Set paths for your model +export TRINITY_MODEL_PATH="/path/to/model" +export TRINITY_COMPILED_PATH="/path/to/compiled" + +# Run integration tests +pytest test/integration/test_trinity.py --capture=tee-sys + +# Or run directly +python test/integration/test_trinity.py +``` + +**Prerequisites:** +- Pre-compiled model at `TRINITY_COMPILED_PATH` +- HuggingFace model weights downloaded with `trust_remote_code=True` +- Appropriate instance for model size (see Compatibility Matrix) + +## Key Porting Challenges + +This model required solving several non-trivial porting challenges: + +1. **GLU type mismatch:** Trinity uses `SiLU(gate)*up` which maps to NxDI's `"glu"` type, NOT `"swiglu"` (`gate*SiLU(gate)*up`). +2. **Gated attention:** Trinity applies `sigmoid(gate(input))` to attention output before o_proj. Solved via inline override of attention forward methods (required for Neuron tracer compatibility). +3. **Dual intermediate sizes:** Dense layers use `intermediate_size`, MoE experts use `moe_intermediate_size`. Config swaps values for MoE module compatibility. +4. **route_scale not supported by NxDI MoE v2:** Baked into expert `down_proj` weights during conversion. +5. **expert_bias not supported by NxDI:** Created custom `RouterTopKWithBias` subclass. +6. **Conditional RoPE:** Only sliding attention layers get rotary embeddings. +7. **Mixed attention masks and KV cache:** Framework provides both global and local masks via `has_mixed_attn=True`; decoder layer selects based on layer type. `TrinityKVCacheManager` provides per-layer KV cache management (uniform buffers, per-layer scatter modulation and read slicing) to handle the different cache sizes of sliding vs full-attention layers. +8. **Gate weight padding at high TP:** Interleaved padding matching Q projection layout (prevents wrong-head gating on 54/64 cores). +9. **Shared expert weight loading:** Standalone module for reliable weight mapping vs NxDI built-in shared expert handling. 
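Challenges 4 and 5 can be sketched in miniature. This is a pure-Python illustration under stated assumptions, not the NxDI `RouterTopKWithBias` implementation: in particular, whether the bias influences only expert *selection* (and not the combine weights) is a simplifying assumption of this sketch:

```python
import math

def sigmoid(x: float) -> float:
    return 1.0 / (1.0 + math.exp(-x))

def route(logits, expert_bias, top_k):
    """Sigmoid routing with a learned expert bias (illustrative sketch).

    Assumption: bias is added only for selection; combine weights are the
    normalized raw sigmoid scores of the selected experts.
    """
    scores = [sigmoid(l) for l in logits]
    biased = [s + b for s, b in zip(scores, expert_bias)]
    selected = sorted(range(len(scores)), key=lambda i: biased[i],
                      reverse=True)[:top_k]
    total = sum(scores[i] for i in selected)
    return {i: scores[i] / total for i in selected}

def bake_route_scale(down_proj_rows, route_scale):
    """Fold route_scale into routed-expert down_proj weights at conversion
    time, so no extra multiply is needed at runtime (challenge 4)."""
    return [[w * route_scale for w in row] for row in down_proj_rows]

weights = route(logits=[2.0, -1.0, 0.5, 1.5],
                expert_bias=[0.0, 0.9, 0.0, 0.0], top_k=2)
print(sorted(weights))  # the bias promotes expert 1 into the top-2
print(bake_route_scale([[1.0, 2.0]], route_scale=2.5))  # -> [[2.5, 5.0]]
```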
+ +## Fused MoE TKG NKI Kernel (SDK 2.28+) + +The SDK 2.28 fused MoE TKG kernel (`moe_token_gen_selective_load_kernel`) combines RMSNorm + Router TopK + Expert MLP into a single NKI kernel for token generation, reducing HBM round-trips. + +### Configuration + +```python +neuron_config = MoENeuronConfig( + tp_degree=2, + batch_size=1, + seq_len=2048, + torch_dtype=torch.bfloat16, + moe_fused_nki_kernel_enabled=True, + router_topk_nki_kernel_enabled=False, # Must be False for sigmoid routing + expert_mlp_nki_kernel_enabled=True, +) +``` + +### Sigmoid Routing Support + +The fused kernel's NKI router asserts `router_act_fn == SOFTMAX`, but Trinity uses sigmoid. The implementation patches the kernel to use the ISA router fallback (`use_router_topk_nki_kernel=False`), which supports both sigmoid and softmax. This is done automatically via `_PatchedKernelCall` wrapper applied during `setup_attr_for_model()`. + +### Alignment Constraint + +The fused kernel requires `moe_intermediate_size / tp_degree % 128 == 0`: + +| Model | moe_intermediate | TP | Per-TP | Eligible? | +|-------|-----------------|-----|--------|-----------| +| Nano | 256 | 2 | 128 | YES | +| Nano | 256 | 4 | 64 | NO | +| Mini | 1024 | 4 | 256 | YES | +| Large | 3072 | 64 | 48 | NO | + +The config class automatically enables/disables fused TKG based on this alignment check. + +### Test Results (SDK 2.28, Trinity-Nano, trn2.3xlarge, TP=2) + +**CTE (context encoding) -- exact match with non-fused baseline:** + +| Prompt | Non-fused Top-1 | Fused Top-1 | Match | +|--------|----------------|-------------|-------| +| Hello, how are you? | I | I | YES | +| What is the capital of France? | ( | ( | YES | +| 1 + 1 = | (newline) | (newline) | YES | + +**TKG (autoregressive generation, 8 tokens):** + +| Prompt | Non-fused TKG | Fused TKG | First Token | +|--------|---------------|-----------|-------------| +| Hello, how are you? | I am fine, thank you. And | I am fine, thanks! 
And you | MATCH (diverges at token 4) | +| What is the capital of France? | (Answer: Paris)... | (A) Paris (B) London | MATCH (diverges at token 2) | +| 1 + 1 = | 2, 1 + 2 = | 2, 2 + 1 = | MATCH (diverges at token 2) | + +TKG tokens diverge after the first token due to **expert_bias not being used by the fused kernel** (known limitation -- the fused kernel omits per-layer expert bias). Both paths produce coherent, sensible text. CTE outputs are identical because the CTE path always uses the non-fused compute-bound pipeline. + +**Latency:** + +| Metric | Non-fused | Fused | Notes | +|--------|-----------|-------|-------| +| Compile | 357s (5.9 min) | 258s (4.3 min) | Fused compiles faster | +| CTE latency | ~0.52s | ~0.49s | Similar | +| TKG latency | ~0.011s | ~0.014s | Fused is slower on Nano | + +The fused kernel does not improve TKG latency on Trinity-Nano (intermediate_size=256 is too small for the selective loading to pay off). The kernel is designed for larger models where expert weight loading from HBM is the bottleneck. Mini (intermediate=1024) is expected to benefit. + +### Known Limitations + +1. **Expert bias omitted** -- The fused NKI kernel does not apply per-layer `expert_bias` during routing. Non-fused routing uses `RouterTopKWithBias` which adds bias. This causes TKG output divergence after the first token. +2. **Nano TP=4 and Large TP=64 ineligible** -- Alignment constraint `intermediate/TP % 128 != 0` prevents use. +3. **No latency benefit on Nano** -- Expert weights (256 intermediate) are too small for selective loading overhead to pay off. + +## NKI Kernels + +The NxDI framework uses several NKI (Neuron Kernel Interface) kernels during Trinity compilation and inference. These are hardware-accelerated kernels that execute directly on Neuron cores. 
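
As a quick reference for the fused MoE TKG entries above and below, the eligibility rule (`moe_intermediate_size / tp_degree % 128 == 0`) reduces to a one-line check. This helper is illustrative only (it mirrors the alignment table, it is not part of the implementation):

```python
def fused_tkg_eligible(moe_intermediate_size: int, tp_degree: int) -> bool:
    """Fused MoE TKG requires the per-TP intermediate slice to be a
    multiple of 128 (the NKI kernel's alignment constraint)."""
    return (moe_intermediate_size // tp_degree) % 128 == 0

# Mirrors the alignment table:
assert fused_tkg_eligible(256, 2)        # Nano  TP=2  -> 128 per TP: YES
assert not fused_tkg_eligible(256, 4)    # Nano  TP=4  ->  64 per TP: NO
assert fused_tkg_eligible(1024, 4)       # Mini  TP=4  -> 256 per TP: YES
assert not fused_tkg_eligible(3072, 64)  # Large TP=64 ->  48 per TP: NO
```

The config class performs the equivalent check automatically, so an ineligible configuration silently falls back to the standard blockwise matmul path rather than failing compilation.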
+ +| Kernel | Source | Purpose | +|--------|--------|---------| +| **Flash Attention (Context Encoding)** | `neuronxcc.nki._pre_prod_kernels.attn_fwd` | Full-sequence attention during context encoding (prompt processing). Fused QKV attention with causal masking and sliding window support. | +| **Flash Attention ISA** | `neuronxcc.nki.kernels.attention.attention_isa_kernel` | ISA-level flash attention implementation used as BIR (Built-in Runtime) fallback for context encoding. | +| **Token Gen Attention** | `neuronxcc.nki._private_kernels.attention.attention_tkg_fwd_isa_kernel` | Single-token attention with KV cache lookup during autoregressive token generation. | +| **Token Gen Attention Block (Fused)** | `neuronxcc.nki._pre_prod_kernels.attention_token_gen.llama3_nki_attention_block_token_gen_kernel` | Fused kernel combining attention + RMSNorm + residual connection for token generation. Used when `attn_block_tkg_nki_kernel_enabled` is true. | +| **Blockwise Matmul (MoE Experts)** | `neuronx_distributed.modules.moe.blockwise.BlockwiseMatmulNKIFunc` | Expert MLP computation in MoE layers (gate, up, down projections). Handles sparse expert dispatch with token routing. **Note:** Accumulates in bf16, causing slightly higher numerical divergence vs CPU reference. | +| **Custom RMSNorm** | `neuronx_distributed_inference.modules.custom_calls.CustomRMSNorm` | Hardware-accelerated RMSNorm via `AwsNeuronRmsNorm` custom call. Used 4 times per decoder layer (input_norm, post_attn_norm, pre_ff_norm, post_ff_norm). | +| **Cumsum** | `neuronxcc.nki.kernels.cumsum` | Attention mask computation for causal mask prefix sums. Used in both context encoding and token generation paths. | +| **Router TopK** | `neuronx_distributed.kernels.router_topk_kernel` | Expert selection in MoE routing -- selects top-k experts from sigmoid routing scores. Used once per MoE layer. 
| +| **Fused MoE TKG** | `neuronxcc.nki._pre_prod_kernels.moe_token_gen.moe_token_gen_selective_load_kernel` | Combines RMSNorm + Router TopK + Expert MLP for token generation. Selectively loads expert weights from HBM. SDK 2.28+. Uses ISA router fallback for sigmoid. | + +### NKI Kernel Interaction with Trinity-Specific Features + +- **Gated attention bypass:** When NKI fused attention block kernels are enabled (`attn_block_tkg_nki_kernel_enabled` or `attn_block_cte_nki_kernel_enabled`), Trinity's custom gated attention is bypassed and the base class fused kernel is used instead. The gated attention path is used when fused kernels are disabled. +- **MoE bf16 accumulation:** The blockwise matmul NKI kernel accumulates expert outputs in bf16 rather than fp32, which is the primary source of numerical divergence between Neuron and CPU reference outputs. Top-1 token accuracy is preserved. +- **Left-padding unsupported:** The NKI flash attention kernels require right-padding (`padding_side="right"`). Left-padding produces incorrect results. + +## Example Checkpoints + +- `arcee-ai/Trinity-Nano-Preview` (requires `trust_remote_code=True`) +- `arcee-ai/Trinity-Mini` (requires `trust_remote_code=True`) +- `arcee-ai/Trinity-Large-Preview` (requires `trust_remote_code=True`) + +## Maintainer + +Jim Burtoft + +**Last Updated:** 2026-03-06 (re-benchmarked Nano + Mini with bucketing, added inf2 TP=2 and whole-instance throughput) diff --git a/contrib/models/Trinity/src/__init__.py b/contrib/models/Trinity/src/__init__.py new file mode 100644 index 0000000..1082485 --- /dev/null +++ b/contrib/models/Trinity/src/__init__.py @@ -0,0 +1,24 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). +# You may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Trinity (AfmoeForCausalLM) model implementation for NeuronX Distributed Inference.""" + +from .modeling_trinity import ( + TrinityInferenceConfig, + NeuronTrinityModel, + NeuronTrinityForCausalLM, + NeuronTrinityAttention, + NeuronTrinityMLP, + NeuronTrinitySharedExpert, + NeuronTrinityDecoderLayer, +) diff --git a/contrib/models/Trinity/src/modeling_trinity.py b/contrib/models/Trinity/src/modeling_trinity.py new file mode 100644 index 0000000..1ec8463 --- /dev/null +++ b/contrib/models/Trinity/src/modeling_trinity.py @@ -0,0 +1,1871 @@ +#!/usr/bin/env python3 +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). +# You may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Unified NeuronX Distributed Inference implementation for the Trinity model family +(AfmoeForCausalLM) from Arcee AI. 
+ +Supports all three Trinity sizes from a single codebase: +- Trinity-Nano-Preview (~6B total, ~1B active) +- Trinity-Mini (~26B total, ~4.5B active) +- Trinity-Large-Preview (~250B total, ~15B active) + +Architecture (shared across all sizes): +- AfmoeForCausalLM: Arcee Foundation Mixture of Experts +- Mixed attention: sliding_attention + full_attention (every 4th layer) +- Gated attention: gate_proj + sigmoid on attention output +- QK normalization: RMSNorm on Q and K per head +- Dual layer norms: pre/post for both attention and MLP (4 per layer) +- muP scaling: hidden_size**0.5 on input embeddings +- Sigmoid routing with normalization +- SiLU gated MLP (gate_proj, up_proj, down_proj) +- Expert bias on routing scores + +Key porting decisions: +- glu_type="glu" (NOT "swiglu") -- Trinity uses SiLU(gate)*up, which is NxDI's "glu" +- route_scale baked into routed expert down_proj weights (NxDI MoE v2 doesn't support it) +- muP scaling baked into embedding weights during conversion +- expert_bias handled via custom RouterTopKWithBias subclass +- Gated attention handled via inline override of attention forward methods +- Gate weight padding uses interleaved layout matching Q projection (for high TP) +""" + +import json +import os +import math +import logging +from typing import List, Optional, Tuple, Type + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from neuronx_distributed_inference.models.config import ( + InferenceConfig, + NeuronConfig, + MoENeuronConfig, +) +from neuronx_distributed_inference.models.model_base import ( + NeuronBaseForCausalLM, + NeuronBaseModel, +) +from neuronx_distributed_inference.modules.attention.attention_base import ( + NeuronAttentionBase, +) +from neuronx_distributed_inference.modules.attention.utils import RotaryEmbedding +from neuronx_distributed_inference.modules.attention.gqa import ( + determine_sharding_strategy, + get_shardable_head_counts, +) +from neuronx_distributed_inference.modules.custom_calls 
import CustomRMSNorm +from neuronx_distributed_inference.modules.kvcache.utils import ( + dynamic_update_slice, + fill_prefix, +) +from neuronx_distributed_inference.modules.generation.sampling import Sampler + +from neuronx_distributed.parallel_layers.layers import ( + ColumnParallelLinear, + ParallelEmbedding, + RowParallelLinear, +) +from neuronx_distributed.parallel_layers import parallel_state +from neuronx_distributed.parallel_layers import utils as nxd_utils +from neuronx_distributed.utils import cpu_mode + +logger = logging.getLogger(__name__) + +# MoE v2 module (required for MoE layers) +try: + from neuronx_distributed_inference.modules.moe_v2 import initialize_moe_module + from neuronx_distributed.modules.moe.routing import RouterTopK + + MOE_V2_AVAILABLE = True +except ImportError: + MOE_V2_AVAILABLE = False + logger.warning("moe_v2 not available, MoE layers will not work") + + +def _patch_fused_tkg_for_sigmoid(): + """Patch MoEFusedTKG kernel to use ISA router fallback for sigmoid routing. + + The SDK 2.28 fused MoE TKG NKI kernel's router_topk_kernel_nki only supports + softmax activation. Trinity uses sigmoid routing. The kernel also has an ISA + router fallback (router_topk_isa_kernel) that supports both sigmoid and softmax. + + This patch wraps the selective-load kernel call to force + use_router_topk_nki_kernel=False, which uses the ISA router fallback. + + NOTE: The fused TKG kernel does NOT support expert_selection_bias. + The compiled NKI router kernels and nkilib utilities are incompatible with + user-level @nki.jit kernel inlining (TensorView.get_view() requires .ap() + which is unavailable in the NKI trace context). For models that need + expert_bias (all Trinity models), use the non-fused path instead + (moe_fused_nki_kernel_enabled=False). + + Must be called before model.compile(). 
+ """ + try: + import neuronx_distributed.modules.moe.moe_fused_tkg as fused_tkg_mod + + original_kernel = fused_tkg_mod._moe_token_gen_selective_load_kernel_nki_call + if original_kernel is None: + logger.warning( + "Fused TKG selective load kernel not available, skipping patch" + ) + return + + # The NKI kernel call object supports [grid](**kwargs) invocation. + # We wrap it to inject use_router_topk_nki_kernel=False. + class _PatchedKernelCall: + """Wrapper that injects use_router_topk_nki_kernel=False into kernel calls.""" + + def __init__(self, original): + self._original = original + + def __getitem__(self, grid): + original_grid_call = self._original[grid] + + def patched_call(*args, **kwargs): + kwargs["use_router_topk_nki_kernel"] = False + return original_grid_call(*args, **kwargs) + + return patched_call + + fused_tkg_mod._moe_token_gen_selective_load_kernel_nki_call = ( + _PatchedKernelCall(original_kernel) + ) + + # Also patch the forward-all-experts kernel if it has the same issue + original_all = fused_tkg_mod._moe_tkg_forward_all_experts_nki_call + if original_all is not None: + fused_tkg_mod._moe_tkg_forward_all_experts_nki_call = _PatchedKernelCall( + original_all + ) + + logger.warning( + "Patched MoEFusedTKG for sigmoid routing (ISA fallback). " + "expert_bias NOT supported in fused TKG - TKG tokens may diverge from non-fused." + ) + except ImportError: + logger.info("moe_fused_tkg module not available (SDK < 2.28), skipping patch") + except Exception as e: + logger.warning("Failed to patch MoEFusedTKG for sigmoid: %s", e) + + +# --------------------------------------------------------------------------- +# TrinityKVCacheManager: Per-layer KV cache sizing for mixed attention +# --------------------------------------------------------------------------- +# Adapted from GptOssKVCacheManager in NxDI. 
Trinity has mixed attention: +# most layers use sliding-window attention (KV cache = sliding_window - 1), +# while every 4th layer uses full attention (KV cache = max_length). +# +# The standard KVCacheManager applies a single sliding_window modulation to +# ALL layers in _get_index_to_update_new_position, which causes OOB when +# full-attention layers have larger KV cache buffers. This custom manager +# creates per-layer cache buffers and applies per-layer scatter modulation. +# --------------------------------------------------------------------------- + + +def _slice_kv_cacheline(padding_side, seq_len, cache, transposed): + """Slice KV cache to seq_len along the sequence dimension.""" + seqlen_dim = 3 if transposed else 2 + if padding_side == "right": + return torch.ops.aten.slice(cache, dim=seqlen_dim, start=0, end=seq_len) + max_idx = cache.shape[seqlen_dim] + return torch.ops.aten.slice( + cache, dim=seqlen_dim, start=max_idx - seq_len, end=max_idx + ) + + +class TrinityKVCacheManager(nn.Module): + """Per-layer KV cache manager for Trinity's mixed attention. + + Sliding-window layers get a smaller KV cache (sliding_window - 1 positions). + Full-attention layers get the full max_length cache. Each layer's scatter + index is modulated correctly to stay within its own buffer. 
+ """ + + def __init__( + self, config, num_kv_head, layer_types, sliding_window, global_rank=None + ): + super().__init__() + self.config = config + self.neuron_config = config.neuron_config + self.padding_side = config.neuron_config.padding_side + self.is_continuous_batching = config.neuron_config.is_continuous_batching + self.num_kv_head = num_kv_head + self.batch_size = config.neuron_config.max_batch_size + self.k_cache_transposed = config.neuron_config.k_cache_transposed + self.global_rank = global_rank + self.num_layers = config.num_hidden_layers + self.num_attention_heads = config.num_attention_heads + self.tp_degree = config.neuron_config.tp_degree + self.dtype = ( + config.neuron_config.attention_dtype + if config.neuron_config.attention_dtype is not None + else config.neuron_config.torch_dtype + ) + + # Per-layer attention type + self.layer_types = layer_types + # Use sliding_window directly as the cache size. GptOss uses + # sliding_window - 1 because its attention kernel convention differs, + # but Trinity's windowed_attention_forward creates TKG masks of size + # sliding_window, so the cache must match exactly. 
+ self.sliding_window = sliding_window + + self._init_kv_shape() + + self.past_key_values = nn.ParameterList( + [ + nn.Parameter(torch.zeros(shape, dtype=self.dtype), requires_grad=False) + for layer_idx in range(self.num_layers) + for shape in [self.k_shapes[layer_idx], self.v_shapes[layer_idx]] + ] + ) + + def _get_num_kv_heads_per_rank(self): + gqa_sharding_strategy = determine_sharding_strategy( + self.tp_degree, self.num_kv_head + ) + _, num_key_value_heads = get_shardable_head_counts( + self.tp_degree, + self.num_attention_heads, + self.num_kv_head, + gqa_sharding_strategy, + ) + if parallel_state.model_parallel_is_initialized(): + return nxd_utils.divide(num_key_value_heads, self.tp_degree) + return num_key_value_heads + + def _init_kv_shape(self): + self.k_shapes = [] + self.v_shapes = [] + num_kv_heads_per_rank = self._get_num_kv_heads_per_rank() + head_dim = getattr( + self.config, + "head_dim", + self.config.hidden_size // self.config.num_attention_heads, + ) + max_length = self.config.neuron_config.max_length + + # All layers get max_length cache. During CTE, fill_prefix writes + # the full context (up to max_length tokens) into the cache. A + # smaller sliding-window cache would cause OOB when the CTE bucket + # exceeds sliding_window. The sliding-window optimization is applied + # only at TKG time via _get_index_to_update_new_position (wrapping + # position_ids modulo sliding_window for sliding layers) and via + # get_kv_by_layer_id (slicing the cache to sliding_window during read). 
+ for layer_idx in range(self.num_layers): + shape = (self.batch_size, num_kv_heads_per_rank, max_length, head_dim) + self.k_shapes.append(shape) + self.v_shapes.append(shape) + + def _fetch_cache(self, idx, kvcache_buffer=None): + if kvcache_buffer is not None: + if ( + len(kvcache_buffer) == len(self.past_key_values) // 2 + and len(kvcache_buffer[0]) == 2 + ): + return kvcache_buffer[idx][0], kvcache_buffer[idx][1] + elif len(kvcache_buffer) == len(self.past_key_values): + return kvcache_buffer[2 * idx], kvcache_buffer[2 * idx + 1] + else: + raise ValueError( + f"kvcache_buffer length {len(kvcache_buffer)} not recognized" + ) + return self.past_key_values[2 * idx], self.past_key_values[2 * idx + 1] + + def get_kv_by_layer_id( + self, + idx, + seq_len, + skip_slice=False, + kvcache_buffer=None, + **kwargs, + ): + k_cache, v_cache = self._fetch_cache(idx, kvcache_buffer) + + # Override seq_len with the per-layer effective size: + # - Sliding layers: min(sliding_window, max_length) — only read the + # last sliding_window positions during TKG. + # - Full-attention layers: max_length — read everything. + # During CTE, seq_len from the caller equals n_positions (the bucket); + # we still override to ensure the slice matches what attention expects. 
+ if hasattr(self, "v_shapes"): + is_sliding = self.layer_types[idx] == "sliding_attention" + max_len = self.v_shapes[idx][2] # always max_length + if is_sliding and self.sliding_window and self.sliding_window < max_len: + seq_len = self.sliding_window + else: + seq_len = max_len + + if not skip_slice: + k_cache = _slice_kv_cacheline( + self.padding_side, seq_len, k_cache, self.k_cache_transposed + ) + v_cache = _slice_kv_cacheline(self.padding_side, seq_len, v_cache, False) + return k_cache, v_cache + + def get_cache(self, seq_len, skip_slice=False, kvcache_buffer=None, **kwargs): + past_key_values = [] + for idx in range(len(self.past_key_values) // 2): + k_cache, v_cache = self.get_kv_by_layer_id( + idx=idx, + seq_len=seq_len, + skip_slice=skip_slice, + kvcache_buffer=kvcache_buffer, + **kwargs, + ) + past_key_values.append([k_cache, v_cache]) + return past_key_values + + def update_cache( + self, + is_for_context_encoding, + seq_ids, + position_ids, + new_key_values, + seq_len, + scatter_index=None, + kv_active_mask=None, + kvcache_buffer=None, + **kwargs, + ): + updated_kv_cache = [] + for idx, kv_per_layer in enumerate(new_key_values): + k_cache, v_cache = self.update_kv_by_layer_id( + idx=idx, + is_for_context_encoding=is_for_context_encoding, + seq_ids=seq_ids, + position_ids=position_ids, + kv_per_layer=kv_per_layer, + seq_len=seq_len, + scatter_index=scatter_index, + kv_active_mask=kv_active_mask, + kvcache_buffer=kvcache_buffer, + ) + updated_kv_cache.append(k_cache) + updated_kv_cache.append(v_cache) + return updated_kv_cache + + def update_kv_by_layer_id( + self, + idx, + is_for_context_encoding, + seq_ids, + position_ids, + kv_per_layer, + seq_len, + scatter_index=None, + kv_active_mask=None, + kvcache_buffer=None, + **kwargs, + ): + latest_k, latest_v = kv_per_layer[0], kv_per_layer[1] + k_cache, v_cache = self._fetch_cache(idx, kvcache_buffer) + + if is_for_context_encoding: + if self.is_continuous_batching: + assert seq_ids.dim() == 1 and 
seq_ids.shape[0] == 1 + if self.k_cache_transposed: + cache_idx = seq_ids + indices = [cache_idx] + [ + torch.zeros(1, device=seq_ids.device) + for _ in range(k_cache.dim() - 1) + ] + indices = [t.squeeze().to(torch.int32) for t in indices] + k_cache = dynamic_update_slice(k_cache, latest_k, indices) + v_cache = dynamic_update_slice(v_cache, latest_v, indices) + else: + from neuronx_distributed_inference.modules.kvcache.utils import ( + update_cache_const_indices, + ) + + k_cache = update_cache_const_indices(k_cache, latest_k, seq_ids) + v_cache = update_cache_const_indices(v_cache, latest_v, seq_ids) + else: + k_cache = fill_prefix(k_cache, latest_k) + v_cache = fill_prefix(v_cache, latest_v) + else: + # Token generation: scatter new KV into the correct position. + # Per-layer modulation keeps indices within the buffer bounds. + scatter_index_k = self._get_index_to_update_new_position( + scatter_index, position_ids, latest_k, self.k_cache_transposed, idx + ) + scatter_index_v = self._get_index_to_update_new_position( + scatter_index, position_ids, latest_v, False, idx + ) + k_cache = torch.scatter( + input=k_cache, + dim=(2 if not self.k_cache_transposed else 3), + index=scatter_index_k, + src=latest_k, + ) + v_cache = torch.scatter( + input=v_cache, dim=2, index=scatter_index_v, src=latest_v + ) + return k_cache, v_cache + + def _get_index_to_update_new_position( + self, scatter_index, position_ids, full_k, transposed, layer_idx + ): + """Per-layer scatter index modulation. 
+ + Sliding-window layers: position_ids % sliding_window (wraps within window) + Full-attention layers: position_ids as-is (no modulation needed) + """ + is_sliding = self.layer_types[layer_idx] == "sliding_attention" + if is_sliding and self.sliding_window: + position_ids = position_ids % self.sliding_window + index = position_ids + view_shape = ( + (-1, 1, index.shape[-1], 1) + if not transposed + else (-1, 1, 1, index.shape[-1]) + ) + return index.view(*view_shape).expand_as(full_k) + + +class RouterTopKWithBias(RouterTopK): + """RouterTopK with expert_bias support for Trinity. + + Trinity uses expert_bias to influence which experts are selected: + - Sigmoid scores are computed: scores = sigmoid(logits) + - For top-k selection: topk(scores + expert_bias) + - For actual routing weights: gather scores at selected indices (no bias) + + The bias only affects WHICH experts are selected, not their weights. + """ + + def __init__(self, expert_bias_size, **kwargs): + super().__init__(**kwargs) + self.register_buffer( + "expert_bias", + torch.zeros(expert_bias_size, dtype=torch.float32), + ) + + def forward(self, hidden_states): + router_logits = self.get_router_logits(hidden_states) + expert_affinities = self.apply_activation_fn(router_logits) + expert_affinities = expert_affinities.to(dtype=hidden_states.dtype) + + # Top-k selection with expert_bias added to scores. + scores_for_selection = expert_affinities.float() + self.expert_bias.float() + _, expert_index = torch.topk(scores_for_selection, self.top_k) + expert_index = expert_index.detach().to(dtype=torch.long) + + return router_logits, expert_affinities, expert_index + + +def initialize_moe_with_expert_bias(config, init_tkg_module=False, rmsnorm=None): + """Initialize MoE module with expert_bias support. + + Args: + config: TrinityInferenceConfig + init_tkg_module: If True, enable fused MoE TKG NKI kernel path. + Requires SDK 2.28+ and moe_intermediate_size/tp % 128 == 0. 
+ rmsnorm: RMSNorm module to fuse into MoE (required when init_tkg_module=True). + The fused kernel applies this norm internally, so the caller must skip it. + """ + if init_tkg_module: + try: + moe = initialize_moe_module( + config=config, init_tkg_module=True, rmsnorm=rmsnorm + ) + except TypeError: + # SDK 2.27 or older: initialize_moe_module doesn't accept these args + logger.warning( + "Fused MoE TKG not supported by this SDK version. " + "Falling back to standard path." + ) + moe = initialize_moe_module(config=config) + else: + moe = initialize_moe_module(config=config) + + old_router = moe.router + new_router = RouterTopKWithBias( + expert_bias_size=config.num_local_experts, + num_experts=old_router.num_experts, + top_k=old_router.top_k, + hidden_size=old_router.hidden_size, + dtype=old_router.dtype, + device=old_router.device, + act_fn=old_router.act_fn, + sequence_parallel_enabled=old_router.sequence_parallel_enabled, + sequence_dimension=old_router.sequence_dimension, + bias=old_router.bias, + apply_act_fn_over_topk=old_router.apply_act_fn_over_topk, + store_transposed_weights=old_router.store_transposed_weights, + ) + new_router.linear_router = old_router.linear_router + if hasattr(old_router, "weight_T"): + new_router.weight_T = old_router.weight_T + + moe.router = new_router + moe.eval() + return moe + + +def get_rmsnorm_cls(): + """Get the appropriate RMSNorm class based on execution mode.""" + if cpu_mode(): + + class StandardRMSNorm(nn.Module): + def __init__(self, hidden_size, eps=1e-6): + super().__init__() + self.weight = nn.Parameter(torch.ones(hidden_size)) + self.variance_epsilon = eps + + def forward(self, hidden_states): + input_dtype = hidden_states.dtype + hidden_states = hidden_states.to(torch.float32) + variance = hidden_states.pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states * torch.rsqrt( + variance + self.variance_epsilon + ) + return (self.weight * hidden_states).to(input_dtype) + + return StandardRMSNorm + else: + 
return CustomRMSNorm + + +class TrinityInferenceConfig(InferenceConfig): + """Configuration for Trinity (AfmoeForCausalLM) inference. + + Handles all Trinity model sizes (Nano, Mini, Large) via config-driven values. + + IMPORTANT: initialize_moe_module reads config.intermediate_size for expert MLP + dimensions. Trinity has two different intermediate sizes: + - intermediate_size: used for dense MLP layers (first num_dense_layers) + - moe_intermediate_size: used for MoE expert MLPs + + We store the dense size as dense_intermediate_size and set intermediate_size to + moe_intermediate_size so that initialize_moe_module gets the correct value. + """ + + def __init__(self, neuron_config=None, **kwargs): + # Model architecture parameters from AfmoeConfig + self.vocab_size = kwargs.pop("vocab_size", 200192) + self.hidden_size = kwargs.pop("hidden_size", 2048) + + # CRITICAL: intermediate_size must be the MoE intermediate size for initialize_moe_module + dense_intermediate = kwargs.pop("intermediate_size", 6144) + moe_intermediate = kwargs.pop("moe_intermediate_size", 1024) + self.dense_intermediate_size = dense_intermediate + self.intermediate_size = moe_intermediate + self.moe_intermediate_size = moe_intermediate + + self.num_hidden_layers = kwargs.pop("num_hidden_layers", 32) + self.num_dense_layers = kwargs.pop("num_dense_layers", 2) + self.num_attention_heads = kwargs.pop("num_attention_heads", 32) + self.num_key_value_heads = kwargs.pop("num_key_value_heads", 4) + self.head_dim = kwargs.pop("head_dim", 128) + self.hidden_act = kwargs.pop("hidden_act", "silu") + self.max_position_embeddings = kwargs.pop("max_position_embeddings", 131072) + self.initializer_range = kwargs.pop("initializer_range", 0.02) + self.rms_norm_eps = kwargs.pop("rms_norm_eps", 1e-5) + self.use_cache = kwargs.pop("use_cache", True) + self.rope_theta = kwargs.pop("rope_theta", 10000.0) + self.rope_scaling = kwargs.pop("rope_scaling", None) + self.tie_word_embeddings = 
kwargs.pop("tie_word_embeddings", False) + self.attention_dropout = kwargs.pop("attention_dropout", 0.0) + + # MoE parameters + self.num_experts = kwargs.pop("num_experts", 128) + self.num_local_experts = kwargs.pop("num_local_experts", None) + if self.num_local_experts is None: + self.num_local_experts = self.num_experts + self.num_experts_per_tok = kwargs.pop("num_experts_per_tok", 8) + self.num_shared_experts = kwargs.pop("num_shared_experts", 1) + # IMPORTANT: Set n_shared_experts=0 for initialize_moe_module so the NxDI MoE + # module does NOT create its own SharedExperts. We handle shared experts ourselves + # in NeuronTrinityDecoderLayer to ensure proper weight loading. + self.n_shared_experts = 0 + self.num_expert_groups = kwargs.pop("num_expert_groups", 1) + self.num_limited_groups = kwargs.pop("num_limited_groups", 1) + self.score_func = kwargs.pop("score_func", "sigmoid") + self.route_norm = kwargs.pop("route_norm", True) + self.route_scale = kwargs.pop("route_scale", 1.0) + self.n_group = kwargs.pop("n_group", 1) + self.topk_group = kwargs.pop("topk_group", 1) + self.load_balance_coeff = kwargs.pop("load_balance_coeff", 0.001) + + # Attention patterns + self.global_attn_every_n_layers = kwargs.pop("global_attn_every_n_layers", 4) + self.sliding_window = kwargs.pop("sliding_window", 2048) + self.layer_types = kwargs.pop("layer_types", None) + + # Clamp sliding_window to seq_len if seq_len < sliding_window. + # The KV cache is sized by seq_len (via n_positions), and sliding window + # attention creates masks of size sliding_window. These must match. 
+ if neuron_config is not None and hasattr(neuron_config, "seq_len"): + if neuron_config.seq_len < self.sliding_window: + logger.info( + "Clamping sliding_window from %d to %d to match seq_len", + self.sliding_window, + neuron_config.seq_len, + ) + self.sliding_window = neuron_config.seq_len + + if self.layer_types is None: + self.layer_types = [ + "sliding_attention" + if bool((i + 1) % self.global_attn_every_n_layers) + else "full_attention" + for i in range(self.num_hidden_layers) + ] + + # muP + self.mup_enabled = kwargs.pop("mup_enabled", True) + + # Standard attributes + self.pad_token_id = kwargs.pop("pad_token_id", None) + self.bos_token_id = kwargs.pop("bos_token_id", None) + self.eos_token_id = kwargs.pop("eos_token_id", None) + self.torch_dtype = kwargs.pop("torch_dtype", "bfloat16") + self.attention_bias = kwargs.pop("attention_bias", False) + self.output_attentions = kwargs.pop("output_attentions", False) + self.output_hidden_states = kwargs.pop("output_hidden_states", False) + + # Pop HF-specific keys not used by our config + kwargs.pop("auto_map", None) + kwargs.pop("architectures", None) + kwargs.pop("model_type", None) + kwargs.pop("transformers_version", None) + kwargs.pop("dtype", None) + kwargs.pop("use_grouped_mm", None) + + super().__init__(neuron_config=neuron_config, **kwargs) + + # Adjust num_local_experts for expert parallelism + if hasattr(self, "neuron_config") and self.neuron_config is not None: + ep_degree = getattr(self.neuron_config, "ep_degree", 1) + if ep_degree > 1: + self.num_local_experts = self.num_experts // ep_degree + + # Set MoE neuron config parameters + if hasattr(self, "neuron_config") and self.neuron_config is not None: + if not hasattr(self.neuron_config, "glu_mlp"): + self.neuron_config.glu_mlp = True + # Trinity uses SiLU(gate)*up which is NxDI's "glu" type, + # NOT "swiglu" which computes gate*SiLU(gate)*up + self.neuron_config.glu_type = "glu" + # Trinity uses sigmoid routing (not softmax) + if 
hasattr(self.neuron_config, "router_config"): + self.neuron_config.router_config.act_fn = "sigmoid" + + def add_derived_config(self): + """Add derived configuration parameters.""" + self.num_cores_per_group = 1 + + # Enable fused MoE TKG kernel if alignment constraint is met. + # The fused NKI kernel requires intermediate_size_per_tp % 128 == 0. + # This is auto-validated; if the constraint fails, we fall back to + # the standard blockwise matmul path silently. + self._enable_fused_moe_tkg() + + def _enable_fused_moe_tkg(self): + """Check and enable fused MoE TKG NKI kernel (SDK 2.28+). + + The fused kernel combines RMSNorm + Router TopK + Expert MLP into a + single NKI kernel launch, reducing HBM round-trips during token gen. + + Requires: moe_intermediate_size / moe_tp_degree % 128 == 0. + """ + MOE_TKG_MK_INTERMEDIATE_PER_TP = 128 + if not hasattr(self, "neuron_config") or self.neuron_config is None: + return + + # Check if user explicitly requested fused kernel + fused_requested = getattr( + self.neuron_config, "moe_fused_nki_kernel_enabled", None + ) + if fused_requested is None: + return # Not requested, don't enable + + moe_tp = getattr(self.neuron_config, "moe_tp_degree", None) + if moe_tp is None: + moe_tp = getattr(self.neuron_config, "tp_degree", 1) + + i_per_tp = self.moe_intermediate_size // moe_tp + if i_per_tp % MOE_TKG_MK_INTERMEDIATE_PER_TP != 0: + logger.warning( + "Cannot enable fused MoE TKG kernel: " + "moe_intermediate_size/tp (%d/%d=%d) is not divisible by %d. 
" + "Falling back to standard blockwise matmul path.", + self.moe_intermediate_size, + moe_tp, + i_per_tp, + MOE_TKG_MK_INTERMEDIATE_PER_TP, + ) + self.neuron_config.moe_fused_nki_kernel_enabled = None + self.moe_fused_nki_kernel_enabled = False + else: + self.moe_fused_nki_kernel_enabled = True + logger.info( + "Fused MoE TKG NKI kernel enabled (intermediate_per_tp=%d)", i_per_tp + ) + + def get_required_attributes(self) -> List[str]: + return [ + "hidden_size", + "num_attention_heads", + "num_hidden_layers", + "num_key_value_heads", + "vocab_size", + "max_position_embeddings", + "num_local_experts", + "num_experts_per_tok", + "intermediate_size", + "head_dim", + ] + + @classmethod + def get_neuron_config_cls(cls) -> Type[NeuronConfig]: + return MoENeuronConfig + + @classmethod + def from_pretrained(cls, model_path: str, **kwargs) -> "TrinityInferenceConfig": + neuron_config = kwargs.pop("neuron_config", None) + model_path = os.path.expanduser(model_path) + config_path = os.path.join(model_path, "config.json") + if not os.path.exists(config_path): + raise FileNotFoundError(f"Configuration file not found at {config_path}") + with open(config_path, "r") as f: + config_dict = json.load(f) + config_dict.update(kwargs) + config = cls(neuron_config=neuron_config, **config_dict) + return config + + +class NeuronTrinityAttention(NeuronAttentionBase): + """Trinity attention with QK norms, conditional RoPE, and gated output. + + Key differences from standard attention: + 1. QK norms: RMSNorm applied to Q and K per head before attention + 2. Conditional RoPE: Only applied for sliding_attention layers, not full_attention + 3. Gated output: output = o_proj(attn_out * sigmoid(gate_proj(input))) + + Gating strategy (inline override): + The Neuron tracer cannot follow tensor flow through mutable state, closures, + or dynamic method replacement. The ONLY working approach is to have the gate + computation INLINE in the same method that calls o_proj. 
+ + We override standard_causal_attention_forward and windowed_attention_forward + to insert gate_values = sigmoid(attn_gate_proj(original_hidden_states)) + and apply attn_output = attn_output * gate_values before the o_proj call. + """ + + def __init__(self, config: TrinityInferenceConfig, layer_idx: int): + is_sliding = config.layer_types[layer_idx] == "sliding_attention" + + # RoPE only for sliding attention layers + if is_sliding: + rotary_emb = RotaryEmbedding( + config.head_dim, + max_position_embeddings=config.max_position_embeddings, + base=config.rope_theta, + ) + else: + rotary_emb = None + + # Set sliding_window on sliding layers so the base class dispatches to + # windowed_attention_forward (which uses the local_mask from the + # framework's mixed-attention flow). Full-attention layers get None. + sliding_window = config.sliding_window if is_sliding else None + + # Per-head QK norm + rmsnorm_cls = get_rmsnorm_cls() + q_norm = rmsnorm_cls(config.head_dim, eps=config.rms_norm_eps) + k_norm = rmsnorm_cls(config.head_dim, eps=config.rms_norm_eps) + + super().__init__( + config=config, + hidden_size=config.hidden_size, + num_attention_heads=config.num_attention_heads, + num_key_value_heads=config.num_key_value_heads, + head_dim=config.head_dim, + rotary_emb=rotary_emb, + rope_theta=config.rope_theta if is_sliding else None, + rms_norm_eps=config.rms_norm_eps, + use_qk_norm=False, + q_layernorm=q_norm, + k_layernorm=k_norm, + sliding_window=sliding_window, + ) + + self.layer_idx = layer_idx + self.is_sliding = is_sliding + + # Gated attention: gate_proj applied before o_proj. + # Must match the actual per-rank attention output size from NxDI. + # When num_attention_heads is not divisible by TP, NxDI pads to + # ceil(num_heads/tp) heads per rank. We must match that padding. 
+ tp_degree = config.neuron_config.tp_degree + heads_per_rank = math.ceil(config.num_attention_heads / tp_degree) + padded_total_heads = heads_per_rank * tp_degree + gate_output_size = padded_total_heads * config.head_dim + + self.attn_gate_proj = ColumnParallelLinear( + config.hidden_size, + gate_output_size, + bias=False, + gather_output=False, + dtype=config.neuron_config.torch_dtype, + ) + + def _apply_gated_o_proj(self, attn_output, gate_hidden_states, adapter_ids=None): + """Apply gating then o_proj, all inline for Neuron tracing. + + This method MUST be called from within the same forward pass where + gate_hidden_states is a live tensor in the traced graph. + """ + gate_values = torch.sigmoid(self.attn_gate_proj(gate_hidden_states)) + attn_output = attn_output * gate_values + return self.get_o_proj()(attn_output, adapter_ids=adapter_ids) + + def standard_causal_attention_forward( + self, + hidden_states, + attention_mask=None, + position_ids=None, + past_key_value=None, + active_mask=None, + adapter_ids=None, + cos_cache=None, + sin_cache=None, + rmsnorm=None, + rotary_position_ids=None, + kv_mgr=None, + get_kv_per_layer=False, + update_kv_per_layer=False, + residual=None, + windowed_context_encoding_window_idx=-1, + **kwargs, + ): + """Override base class to insert gating before o_proj. + + Copied from NeuronAttentionBase.standard_causal_attention_forward (NxDI 0.8.0) + with one change: the o_proj call is replaced with _apply_gated_o_proj. 
+ """ + from neuronx_distributed_inference.modules.attention.attention_base import ( + NeuronAttentionBaseOutput, + ) + + use_polar_compatible_rope = kwargs.get("use_polar_compatible_rope", False) + + # Save original hidden_states for gate computation BEFORE dtype conversion + gate_hidden_states = hidden_states + + original_dtype = hidden_states.dtype + hidden_states = hidden_states.to(self.torch_dtype) + seq_ids = kwargs.get("seq_ids") + is_context_parallel = past_key_value is None and self.cp_degree > 1 + is_data_parallel = past_key_value is not None and self.dp_degree > 1 + if is_context_parallel: + attention_mask, hidden_states, position_ids, cos_cache, sin_cache = ( + self._split_inputs_for_context_parallel( + attention_mask, hidden_states, position_ids, cos_cache, sin_cache + ) + ) + + if is_data_parallel: + from neuronx_distributed_inference.modules.attention.attention_base import ( + get_dp_rank, + split_along_dim, + get_data_parallel_attention_dp_group, + gather_from_tensor_model_parallel_region_with_dim, + ) + + dp_rank = get_dp_rank( + self.rank_util.get_rank(), + self.tp_degree, + self.dp_degree, + self.neuron_config.switch_cc, + ) + hidden_states = split_along_dim( + hidden_states, dim=0, rank=dp_rank, num_partitions=self.dp_degree + ) + attention_mask = split_along_dim( + attention_mask, dim=0, rank=dp_rank, num_partitions=self.dp_degree + ) + position_ids = split_along_dim( + position_ids, dim=0, rank=dp_rank, num_partitions=self.dp_degree + ) + + bsz, q_len, _ = hidden_states.size() + if self.sequence_parallel_enabled: + q_len *= self.tensor_model_parallel_group.size() + + if rotary_position_ids is None: + rotary_position_ids = position_ids + + if get_kv_per_layer: + assert kv_mgr is not None + past_key_value = kv_mgr.get_kv_by_layer_id(**kwargs) + + is_token_gen = past_key_value is not None + + if windowed_context_encoding_window_idx >= 0: + is_token_gen = False + + if self.neuron_config.is_prefix_caching: + is_token_gen = is_token_gen and q_len < 
128 + + # NKI kernel paths -- delegate to base class (no custom gating in fused kernels) + if self.attn_block_tkg_nki_kernel_enabled and is_token_gen: + return super().standard_causal_attention_forward( + gate_hidden_states.to(self.torch_dtype) + if is_context_parallel or is_data_parallel + else gate_hidden_states, + attention_mask, + position_ids, + past_key_value, + active_mask, + adapter_ids, + cos_cache, + sin_cache, + rmsnorm, + rotary_position_ids, + kv_mgr, + get_kv_per_layer, + update_kv_per_layer, + residual, + windowed_context_encoding_window_idx, + **kwargs, + ) + + if ( + self.attn_block_cte_nki_kernel_enabled + and not is_token_gen + and not self.neuron_config.is_prefix_caching + ): + return super().standard_causal_attention_forward( + gate_hidden_states.to(self.torch_dtype) + if is_context_parallel or is_data_parallel + else gate_hidden_states, + attention_mask, + position_ids, + past_key_value, + active_mask, + adapter_ids, + cos_cache, + sin_cache, + rmsnorm, + rotary_position_ids, + kv_mgr, + get_kv_per_layer, + update_kv_per_layer, + residual, + windowed_context_encoding_window_idx, + **kwargs, + ) + + tkg_attn_kernel_fused_rope = ( + is_token_gen and self.attn_tkg_builtin_kernel_enabled + ) + + Q, K, V, cos_cache, sin_cache, residual = self.prep_qkv_tensors( + rotary_position_ids, + hidden_states, + past_key_value, + adapter_ids=adapter_ids, + cos_cache=cos_cache, + sin_cache=sin_cache, + rmsnorm=rmsnorm, + skip_rope=tkg_attn_kernel_fused_rope, + residual=residual, + use_polar_compatible_rope=use_polar_compatible_rope, + ) + + if is_token_gen: + if tkg_attn_kernel_fused_rope: + attn_output, K = self.attention_tokengen_kernel_builtin( + Q, + K, + V, + position_ids, + past_key_value, + attention_mask, + active_mask, + rotary_position_ids, + ) + else: + attn_output = self.attention_tokengen( + Q, + K, + V, + attention_mask, + position_ids, + past_key_value, + active_mask, + **kwargs, + ) + attn_output = attn_output.transpose(1, 2).contiguous() + 
else: + attn_output, K, V = self.attention_context_encode( + Q, K, V, q_len, bsz, attention_mask, past_key_value, active_mask + ) + + # merge multi head hidden + attn_output = attn_output.reshape(bsz, q_len, self.num_heads * self.head_dim) + + # *** GATED ATTENTION: apply gate BEFORE o_proj, all inline *** + attn_output = self._apply_gated_o_proj( + attn_output, gate_hidden_states, adapter_ids=adapter_ids + ) + + if self.k_cache_transposed: + K = K.permute(0, 1, 3, 2) + + kv = (K, V) + + if update_kv_per_layer: + assert kv_mgr is not None + kv = kv_mgr.update_kv_by_layer_id( + kv_per_layer=kv, + position_ids=position_ids, + **kwargs, + ) + + if is_context_parallel and not self.sequence_parallel_enabled: + from neuronx_distributed_inference.modules.attention.attention_base import ( + gather_from_tensor_model_parallel_region_with_dim, + get_context_parallel_attention_cp_group, + ) + + attn_output = gather_from_tensor_model_parallel_region_with_dim( + attn_output, + gather_dim=1, + process_group=get_context_parallel_attention_cp_group(), + ) + + if is_data_parallel: + from neuronx_distributed_inference.modules.attention.attention_base import ( + gather_from_tensor_model_parallel_region_with_dim, + get_data_parallel_attention_dp_group, + ) + + attn_output = gather_from_tensor_model_parallel_region_with_dim( + attn_output, + gather_dim=0, + process_group=get_data_parallel_attention_dp_group(), + ) + + attn_output = attn_output.to(original_dtype) + + return NeuronAttentionBaseOutput( + attn_output, kv, cos_cache, sin_cache, residual + ) + + def windowed_attention_forward( + self, + hidden_states, + attention_mask=None, + position_ids=None, + past_key_value=None, + active_mask=None, + adapter_ids=None, + cos_cache=None, + sin_cache=None, + rmsnorm=None, + rotary_position_ids=None, + kv_mgr=None, + get_kv_per_layer=False, + update_kv_per_layer=False, + residual=None, + windowed_context_encoding_window_idx=-1, + **kwargs, + ): + """Override base class to insert gating 
before o_proj. + + Copied from NeuronAttentionBase.windowed_attention_forward (NxDI 0.8.0) + with one change: the o_proj call is replaced with _apply_gated_o_proj. + """ + from neuronx_distributed_inference.modules.attention.attention_base import ( + NeuronAttentionBaseOutput, + get_last_kv_window, + ) + + # Save original hidden_states for gate computation BEFORE any modifications + gate_hidden_states = hidden_states + + is_context_parallel = past_key_value is None and self.cp_degree > 1 + is_data_parallel = past_key_value is not None and self.dp_degree > 1 + + full_position_ids = position_ids.clone() + + if is_context_parallel: + attention_mask, hidden_states, position_ids, cos_cache, sin_cache = ( + self._split_inputs_for_context_parallel( + attention_mask, hidden_states, position_ids, cos_cache, sin_cache + ) + ) + + if is_data_parallel: + from neuronx_distributed_inference.modules.attention.attention_base import ( + get_dp_rank, + split_along_dim, + ) + + dp_rank = get_dp_rank( + self.rank_util.get_rank(), + self.tp_degree, + self.dp_degree, + self.neuron_config.switch_cc, + ) + hidden_states = split_along_dim( + hidden_states, dim=0, rank=dp_rank, num_partitions=self.dp_degree + ) + attention_mask = split_along_dim( + attention_mask, dim=0, rank=dp_rank, num_partitions=self.dp_degree + ) + position_ids = split_along_dim( + position_ids, dim=0, rank=dp_rank, num_partitions=self.dp_degree + ) + + bsz, q_len, _ = hidden_states.size() + if self.sequence_parallel_enabled: + q_len *= self.tensor_model_parallel_group.size() + + if rotary_position_ids is None: + rotary_position_ids = position_ids + + if get_kv_per_layer: + assert kv_mgr is not None + past_key_value = kv_mgr.get_kv_by_layer_id(**kwargs) + + is_token_gen = past_key_value is not None + + if windowed_context_encoding_window_idx >= 0: + is_token_gen = False + + # NKI kernel path -- delegate to base class (no gating) + if self.attn_block_tkg_nki_kernel_enabled and is_token_gen: + return 
super().windowed_attention_forward( + gate_hidden_states, + attention_mask, + position_ids, + past_key_value, + active_mask, + adapter_ids, + cos_cache, + sin_cache, + rmsnorm, + rotary_position_ids, + kv_mgr, + get_kv_per_layer, + update_kv_per_layer, + residual, + windowed_context_encoding_window_idx, + **kwargs, + ) + + tkg_attn_kernel_fused_rope = ( + is_token_gen and self.attn_tkg_builtin_kernel_enabled + ) + + Q, K, V, cos_cache, sin_cache, residual = self.prep_qkv_tensors( + rotary_position_ids, + hidden_states, + past_key_value, + adapter_ids=adapter_ids, + cos_cache=cos_cache, + sin_cache=sin_cache, + rmsnorm=rmsnorm, + skip_rope=tkg_attn_kernel_fused_rope, + residual=residual, + ) + + if is_token_gen: + attn_output = self.attention_tokengen( + Q, + K, + V, + attention_mask, + position_ids, + past_key_value, + active_mask, + **kwargs, + ) + attn_output = attn_output.transpose(1, 2).contiguous() + else: + attn_output, K, V = self.attention_context_encode_windowed_attention( + Q, + K, + V, + q_len, + bsz, + attention_mask, + self.sliding_window, + past_key_value, + active_mask, + ) + K, V = get_last_kv_window( + self.sliding_window, + full_position_ids, + K, + V, + windowed_context_encoding_window_idx, + self.neuron_config.speculation_length, + ) + + # merge multi head hidden + attn_output = attn_output.reshape(bsz, q_len, self.num_heads * self.head_dim) + + # *** GATED ATTENTION: apply gate BEFORE o_proj, all inline *** + attn_output = self._apply_gated_o_proj( + attn_output, gate_hidden_states, adapter_ids=adapter_ids + ) + + if self.k_cache_transposed: + K = K.permute(0, 1, 3, 2) + + kv = (K, V) + + if update_kv_per_layer: + assert kv_mgr is not None + kv = kv_mgr.update_kv_by_layer_id( + kv_per_layer=kv, + position_ids=position_ids, + **kwargs, + ) + + if is_context_parallel and not self.sequence_parallel_enabled: + from neuronx_distributed_inference.modules.attention.attention_base import ( + gather_from_tensor_model_parallel_region_with_dim, + 
get_context_parallel_attention_cp_group, + ) + + attn_output = gather_from_tensor_model_parallel_region_with_dim( + attn_output, + gather_dim=1, + process_group=get_context_parallel_attention_cp_group(), + ) + + return NeuronAttentionBaseOutput( + attn_output, kv, cos_cache, sin_cache, residual + ) + + +class NeuronTrinityMLP(nn.Module): + """Dense MLP for non-MoE layers (first num_dense_layers layers). + + Uses dense_intermediate_size, NOT the MoE intermediate_size. + """ + + def __init__(self, config: TrinityInferenceConfig): + super().__init__() + self.hidden_size = config.hidden_size + intermediate = config.dense_intermediate_size + + self.gate_proj = ColumnParallelLinear( + config.hidden_size, + intermediate, + bias=False, + gather_output=False, + dtype=config.neuron_config.torch_dtype, + ) + self.up_proj = ColumnParallelLinear( + config.hidden_size, + intermediate, + bias=False, + gather_output=False, + dtype=config.neuron_config.torch_dtype, + ) + self.down_proj = RowParallelLinear( + intermediate, + config.hidden_size, + bias=False, + input_is_parallel=True, + dtype=config.neuron_config.torch_dtype, + ) + self.act_fn = F.silu + + def forward(self, x): + return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) + + +class NeuronTrinitySharedExpert(nn.Module): + """Shared expert MLP for MoE layers. + + Trinity has num_shared_experts=1. Each MoE layer has a shared expert whose + output is added to the routed expert output for every token. Uses the same + SiLU-gated MLP architecture as the dense layers but with moe_intermediate_size. + + Implemented as a standalone module (separate from NxDI's MoE SharedExperts) + to ensure reliable weight loading via standard ColumnParallelLinear/RowParallelLinear. 
+ """ + + def __init__(self, config: TrinityInferenceConfig): + super().__init__() + self.hidden_size = config.hidden_size + intermediate = config.moe_intermediate_size + + self.gate_proj = ColumnParallelLinear( + config.hidden_size, + intermediate, + bias=False, + gather_output=False, + dtype=config.neuron_config.torch_dtype, + ) + self.up_proj = ColumnParallelLinear( + config.hidden_size, + intermediate, + bias=False, + gather_output=False, + dtype=config.neuron_config.torch_dtype, + ) + self.down_proj = RowParallelLinear( + intermediate, + config.hidden_size, + bias=False, + input_is_parallel=True, + dtype=config.neuron_config.torch_dtype, + ) + self.act_fn = F.silu + + def forward(self, x): + return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) + + +class NeuronTrinityDecoderLayer(nn.Module): + """Trinity decoder layer with dual layer norms and conditional MoE. + + Structure: + - input_layernorm -> attention -> post_attention_layernorm -> residual + - pre_mlp_layernorm -> MLP/MoE -> post_mlp_layernorm -> residual + """ + + def __init__(self, config: TrinityInferenceConfig, layer_idx: int): + super().__init__() + self.hidden_size = config.hidden_size + self.layer_idx = layer_idx + + self.self_attn = NeuronTrinityAttention(config, layer_idx) + self.attention_type = config.layer_types[layer_idx] + + rmsnorm_cls = get_rmsnorm_cls() + self.input_layernorm = rmsnorm_cls(config.hidden_size, eps=config.rms_norm_eps) + self.post_attention_layernorm = rmsnorm_cls( + config.hidden_size, eps=config.rms_norm_eps + ) + self.pre_mlp_layernorm = rmsnorm_cls( + config.hidden_size, eps=config.rms_norm_eps + ) + self.post_mlp_layernorm = rmsnorm_cls( + config.hidden_size, eps=config.rms_norm_eps + ) + + # MoE for layers >= num_dense_layers, dense MLP otherwise + self.moe_enabled = layer_idx >= config.num_dense_layers + self.moe_fused_tkg = getattr(config, "moe_fused_nki_kernel_enabled", False) + if self.moe_enabled and MOE_V2_AVAILABLE: + self.mlp = 
initialize_moe_with_expert_bias( + config=config, + init_tkg_module=self.moe_fused_tkg, + # Pass rmsnorm=None so MoE's _forward_compute_bound does NOT + # re-normalize during CTE (we normalize in the decoder forward). + # For the TKG fused kernel, we provide a separate RMSNorm + # instance below so the kernel can access gamma/eps. + rmsnorm=None, + ) + # For fused TKG: the kernel needs gamma/eps for its internal + # RMSNorm. Since we passed rmsnorm=None above, we must provide + # a separate (non-shared) RMSNorm instance on MoEFusedTKG. + # This avoids the shared-module aliasing issue that corrupted + # weight loading in the CTE path. + if self.moe_fused_tkg and hasattr(self.mlp, "moe_fused_tkg"): + fused_tkg = self.mlp.moe_fused_tkg + if fused_tkg is not None: + moe_rmsnorm = rmsnorm_cls( + config.hidden_size, eps=config.rms_norm_eps + ) + fused_tkg.post_attention_layernorm = moe_rmsnorm + # Shared expert: handled outside NxDI MoE to ensure reliable weight loading + if config.num_shared_experts > 0: + self.shared_expert = NeuronTrinitySharedExpert(config) + else: + self.shared_expert = None + else: + self.mlp = NeuronTrinityMLP(config) + self.shared_expert = None + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + use_cache: Optional[bool] = False, + cache_position: Optional[torch.LongTensor] = None, + padding_mask: Optional[torch.Tensor] = None, + **kwargs, + ) -> Tuple[torch.FloatTensor, ...]: + residual = hidden_states + normed = self.input_layernorm(hidden_states) + + # Mixed attention mask selection (matches Llama4 pattern): + # - Sliding layers use local_mask (windowed, sized to sliding_window) + # - Full-attention layers use attention_mask (global, sized to n_positions; + # padded by apply_seq_ids_mask in compute_for_token_gen when KV cache + # is larger than the mask) + local_mask = 
kwargs.pop("local_mask", None)
+        if self.attention_type == "sliding_attention" and local_mask is not None:
+            mask = local_mask
+        else:
+            mask = attention_mask
+
+        # NeuronAttentionBaseOutput has five fields (attn_output, kv, cos_cache,
+        # sin_cache, residual); discard the residual since this layer tracks its
+        # own residual stream.
+        attn_output, present_key_value, cos_cache, sin_cache, _ = self.self_attn(
+            hidden_states=normed,
+            attention_mask=mask,
+            position_ids=position_ids,
+            past_key_value=past_key_value,
+            use_cache=use_cache,
+            cache_position=cache_position,
+            **kwargs,
+        )
+
+        attn_output = self.post_attention_layernorm(attn_output)
+        hidden_states = residual + attn_output
+
+        # MLP with dual norms
+        residual = hidden_states
+        # Normalization strategy for fused MoE TKG:
+        # - CTE (seq_len > 1): Decoder applies pre_mlp_layernorm.
+        #   MoE's _forward_compute_bound skips norm (rmsnorm=None).
+        # - TKG (seq_len == 1): Decoder skips pre_mlp_layernorm.
+        #   Fused kernel applies norm internally using its own RMSNorm.
+        #   MoEFusedTKG fallback also skips norm (post_attn_layernorm
+        #   handles it when kernel is disabled).
+        # When fused TKG is not enabled, decoder always applies norm.
+        is_tkg = self.moe_fused_tkg and hidden_states.shape[1] == 1
+        if not is_tkg:
+            hidden_states = self.pre_mlp_layernorm(hidden_states)
+
+        if self.moe_enabled and MOE_V2_AVAILABLE:
+            mlp_output = self.mlp(hidden_states, padding_mask)[0]
+            # Add shared expert output (applied to every token)
+            if self.shared_expert is not None:
+                # In TKG mode, hidden_states is un-normed (fused kernel
+                # handles norm internally). Shared expert needs normed input.
+ shared_input = ( + self.pre_mlp_layernorm(hidden_states) if is_tkg else hidden_states + ) + shared_output = self.shared_expert(shared_input) + mlp_output = mlp_output + shared_output + hidden_states = mlp_output + else: + hidden_states = self.mlp(hidden_states) + + hidden_states = self.post_mlp_layernorm(hidden_states) + hidden_states = residual + hidden_states + + outputs = (hidden_states, present_key_value, cos_cache, sin_cache, None) + return outputs + + +class NeuronTrinityModel(NeuronBaseModel): + """NeuronX Trinity base model (all sizes).""" + + def setup_attr_for_model(self, config: TrinityInferenceConfig): + self.on_device_sampling = ( + config.neuron_config.on_device_sampling_config is not None + ) + self.tp_degree = config.neuron_config.tp_degree + self.hidden_size = config.hidden_size + self.num_attention_heads = config.num_attention_heads + self.num_key_value_heads = config.num_key_value_heads + self.max_batch_size = config.neuron_config.max_batch_size + self.buckets = getattr(config.neuron_config, "buckets", None) + + # Mixed attention: set has_mixed_attn so the framework creates both + # global and local masks. Set self.sliding_window so the framework + # generates local_attn_mask via _create_windowed_attn_mask_tkg. + # Full-attention layers use the global mask (padded by apply_seq_ids_mask + # in compute_for_token_gen when KV cache > mask size). + # Sliding layers use local_mask (sized to sliding_window). + self.sliding_window = getattr(config, "sliding_window", None) + self.has_mixed_attn = True + + # Store layer_types and raw sliding_window for the custom KV cache manager. + self._layer_types = config.layer_types + self._config_sliding_window = getattr(config, "sliding_window", None) + + # Patch fused MoE TKG kernel for sigmoid routing (must happen before compile). + # Trinity uses sigmoid routing but the fused NKI kernel's router_topk_kernel_nki + # only supports softmax. This forces the ISA router fallback which supports both. 
+ if getattr(config, "moe_fused_nki_kernel_enabled", False): + _patch_fused_tkg_for_sigmoid() + + def init_inference_optimization(self, config: TrinityInferenceConfig): + """Override to use TrinityKVCacheManager for per-layer KV cache sizing. + + The standard KVCacheManager cannot handle mixed attention with bucketing: + its _get_index_to_update_new_position applies uniform sliding_window + modulation to ALL layers, causing OOB for full-attention layers. + + TrinityKVCacheManager creates per-layer cache buffers and applies + per-layer scatter modulation (sliding layers: position % window, + full-attention layers: no modulation). + """ + if self.on_device_sampling: + self.sampler = Sampler(config.neuron_config) + + self.kv_mgr = TrinityKVCacheManager( + config, + num_kv_head=self.num_key_value_heads, + layer_types=self._layer_types, + sliding_window=self._config_sliding_window, + global_rank=self.rank_util, + ) + + def init_model(self, config: TrinityInferenceConfig): + self.padding_idx = getattr(config, "pad_token_id", None) + self.vocab_size = config.vocab_size + self.mup_enabled = getattr(config, "mup_enabled", False) + self.mup_scale = math.sqrt(config.hidden_size) if self.mup_enabled else 1.0 + + self.embed_tokens = ParallelEmbedding( + config.vocab_size, + config.hidden_size, + self.padding_idx, + dtype=config.neuron_config.torch_dtype, + shard_across_embedding=True, + ) + + self.layers = nn.ModuleList( + [ + NeuronTrinityDecoderLayer(config, layer_idx) + for layer_idx in range(config.num_hidden_layers) + ] + ) + + rmsnorm_cls = get_rmsnorm_cls() + self.norm = rmsnorm_cls(config.hidden_size, eps=config.rms_norm_eps) + + # Pad vocab_size to be divisible by TP degree for ColumnParallelLinear + tp_degree = config.neuron_config.tp_degree + padded_vocab = config.vocab_size + if padded_vocab % tp_degree != 0: + padded_vocab = ((padded_vocab // tp_degree) + 1) * tp_degree + self.padded_vocab_size = padded_vocab + self.actual_vocab_size = config.vocab_size + + 
self.lm_head = ColumnParallelLinear( + config.hidden_size, + padded_vocab, + gather_output=False if self.on_device_sampling else True, + bias=False, + dtype=config.neuron_config.torch_dtype, + pad=True, + ) + + +class NeuronTrinityForCausalLM(NeuronBaseForCausalLM): + """NeuronX wrapper for Trinity causal language models (all sizes). + + Supports: + - arcee-ai/Trinity-Nano-Preview (~6B total, ~1B active) + - arcee-ai/Trinity-Mini (~26B total, ~4.5B active) + - arcee-ai/Trinity-Large-Preview (~250B total, ~15B active) + """ + + _model_cls = NeuronTrinityModel + + @classmethod + def get_config_cls(cls): + return TrinityInferenceConfig + + @staticmethod + def convert_hf_to_neuron_state_dict( + state_dict: dict, config: InferenceConfig + ) -> dict: + """Convert HuggingFace AfmoeForCausalLM state dict to NeuronX format. + + Key transformations: + 1. Remove 'model.' prefix from HF keys + 2. Rename QK norms: q_norm -> q_layernorm, k_norm -> k_layernorm + 3. Map attention gate_proj to attn_gate_proj (gated attention) + 4. Stack per-expert weights into [E, H, 2*I] gate_up_proj format + 5. Map router: router.gate.weight -> router.linear_router.weight + 6. Map shared expert weights to standalone shared_expert module + 7. Bake muP scaling into embedding weights + 8. Bake route_scale into routed expert down_proj weights + 9. Pad gate_proj weights with interleaved layout (when num_heads % TP != 0) + 10. 
Pad lm_head weights (when vocab_size % TP != 0) + """ + neuron_state_dict = {} + neuron_config = config.neuron_config + target_dtype = torch.bfloat16 + + has_model_prefix = any(k.startswith("model.") for k in state_dict.keys()) + + def strip_prefix(key): + if has_model_prefix and key.startswith("model."): + return key[6:] + return key + + # Direct mappings: embeddings, final norm, lm_head + for key, value in state_dict.items(): + stripped = strip_prefix(key) + + if stripped == "embed_tokens.weight": + embed_weight = value.to(target_dtype) + mup_enabled = getattr(config, "mup_enabled", False) + if mup_enabled: + mup_scale = math.sqrt(config.hidden_size) + embed_weight = embed_weight * mup_scale + neuron_state_dict["embed_tokens.weight"] = embed_weight + continue + if stripped == "norm.weight": + neuron_state_dict["norm.weight"] = value.to(target_dtype) + continue + if key == "lm_head.weight": + lm_weight = value.to(target_dtype) + # Pad lm_head to be divisible by TP degree + tp_degree = neuron_config.tp_degree + vocab_size = lm_weight.shape[0] + if vocab_size % tp_degree != 0: + padded_vocab = ((vocab_size // tp_degree) + 1) * tp_degree + pad_rows = padded_vocab - vocab_size + lm_weight = torch.cat( + [ + lm_weight, + torch.zeros( + pad_rows, lm_weight.shape[1], dtype=target_dtype + ), + ], + dim=0, + ) + neuron_state_dict["lm_head.weight"] = lm_weight + continue + + # Layer-by-layer conversion + num_layers = config.num_hidden_layers + num_experts = config.num_local_experts + moe_intermediate = config.moe_intermediate_size + hidden_size = config.hidden_size + num_dense_layers = getattr(config, "num_dense_layers", 2) + + for layer_idx in range(num_layers): + if has_model_prefix: + hf_prefix = f"model.layers.{layer_idx}" + else: + hf_prefix = f"layers.{layer_idx}" + neuron_prefix = f"layers.{layer_idx}" + + # Layer norms (4 per layer) + for norm_name in [ + "input_layernorm", + "post_attention_layernorm", + "pre_mlp_layernorm", + "post_mlp_layernorm", + ]: + hf_key = 
f"{hf_prefix}.{norm_name}.weight" + if hf_key in state_dict: + neuron_state_dict[f"{neuron_prefix}.{norm_name}.weight"] = ( + state_dict[hf_key].to(target_dtype) + ) + + # Attention Q, K, V projections + for proj in ["q_proj", "k_proj", "v_proj"]: + hf_key = f"{hf_prefix}.self_attn.{proj}.weight" + if hf_key in state_dict: + neuron_state_dict[ + f"{neuron_prefix}.self_attn.qkv_proj.{proj}.weight" + ] = state_dict[hf_key].to(target_dtype) + + # O projection + hf_key = f"{hf_prefix}.self_attn.o_proj.weight" + if hf_key in state_dict: + neuron_state_dict[f"{neuron_prefix}.self_attn.o_proj.weight"] = ( + state_dict[hf_key].to(target_dtype) + ) + + # QK norm weights: q_norm -> q_layernorm, k_norm -> k_layernorm + for hf_norm, neuron_norm in [ + ("q_norm", "q_layernorm"), + ("k_norm", "k_layernorm"), + ]: + hf_key = f"{hf_prefix}.self_attn.{hf_norm}.weight" + if hf_key in state_dict: + neuron_state_dict[ + f"{neuron_prefix}.self_attn.{neuron_norm}.weight" + ] = state_dict[hf_key].to(target_dtype) + + # Attention gate_proj (gated attention, Trinity-specific) + # CRITICAL: Must use INTERLEAVED padding matching Q projection layout. + # NxDI pads Q with maybe_pad_interleaved (REPLICATE_TO_TP_DEGREE), + # inserting zero heads between KV groups. The gate_proj output is + # element-wise multiplied with the attention output (which follows + # the Q head layout), so gate_proj MUST use the same interleaved + # padding pattern. Using tail padding causes cores to apply gate + # weights from the wrong head. + hf_key = f"{hf_prefix}.self_attn.gate_proj.weight" + if hf_key in state_dict: + gate_weight = state_dict[hf_key].to(target_dtype) + tp_degree = neuron_config.tp_degree + num_heads = config.num_attention_heads + num_kv_heads = config.num_key_value_heads + if num_heads % tp_degree != 0: + # Use interleaved padding matching Q layout. + # Gate weight is (num_heads, hidden_size) -- one row per head. 
+ # Split into KV groups, pad each group with zero rows, + # then concatenate back to (padded_total_heads, hidden_size). + padded_total_heads = math.ceil(num_heads / tp_degree) * tp_degree + group_size = num_heads // num_kv_heads # Q heads per KV group + groups = gate_weight.split(group_size, dim=0) + pad_per_group = (padded_total_heads - num_heads) // num_kv_heads + interleaved = [] + for group in groups: + interleaved.append(group) + interleaved.append( + torch.zeros( + pad_per_group, + gate_weight.shape[1], + dtype=target_dtype, + ) + ) + gate_weight = torch.cat(interleaved, dim=0) + neuron_state_dict[ + f"{neuron_prefix}.self_attn.attn_gate_proj.weight" + ] = gate_weight + + # MLP weights + if layer_idx < num_dense_layers: + # Dense layers (uses dense_intermediate_size) + for proj_name in ["gate_proj", "up_proj", "down_proj"]: + hf_key = f"{hf_prefix}.mlp.{proj_name}.weight" + if hf_key in state_dict: + neuron_state_dict[f"{neuron_prefix}.mlp.{proj_name}.weight"] = ( + state_dict[hf_key].to(target_dtype) + ) + else: + # MoE layers + # Router: router.gate.weight -> router.linear_router.weight + hf_router_key = f"{hf_prefix}.mlp.router.gate.weight" + if hf_router_key in state_dict: + neuron_state_dict[ + f"{neuron_prefix}.mlp.router.linear_router.weight" + ] = state_dict[hf_router_key].to(target_dtype) + + # Expert bias (Trinity-specific routing parameter) + hf_bias_key = f"{hf_prefix}.mlp.expert_bias" + if hf_bias_key in state_dict: + neuron_state_dict[f"{neuron_prefix}.mlp.router.expert_bias"] = ( + state_dict[hf_bias_key].to(torch.float32) + ) + + # Stack expert weights for NxDI MoE v2 format + gate_up_proj = torch.empty( + num_experts, hidden_size, 2 * moe_intermediate, dtype=target_dtype + ) + down_proj = torch.empty( + num_experts, moe_intermediate, hidden_size, dtype=target_dtype + ) + + all_experts_found = True + for e in range(num_experts): + gate_key = f"{hf_prefix}.mlp.experts.{e}.gate_proj.weight" + up_key = 
f"{hf_prefix}.mlp.experts.{e}.up_proj.weight" + down_key = f"{hf_prefix}.mlp.experts.{e}.down_proj.weight" + + if ( + gate_key in state_dict + and up_key in state_dict + and down_key in state_dict + ): + gate_w = state_dict[gate_key].to(target_dtype) + up_w = state_dict[up_key].to(target_dtype) + down_w = state_dict[down_key].to(target_dtype) + + gate_up_concat = torch.cat([gate_w, up_w], dim=0) + gate_up_proj[e] = gate_up_concat.T + down_proj[e] = down_w.T + else: + all_experts_found = False + break + + if all_experts_found: + # Bake route_scale into routed expert down_proj weights. + # NxDI MoE v2 does NOT support route_scale natively. + # Shared experts are NOT scaled. + route_scale = getattr(config, "route_scale", 1.0) + if route_scale != 1.0: + down_proj = down_proj * route_scale + + neuron_state_dict[ + f"{neuron_prefix}.mlp.expert_mlps.mlp_op.gate_up_proj.weight" + ] = gate_up_proj + neuron_state_dict[ + f"{neuron_prefix}.mlp.expert_mlps.mlp_op.down_proj.weight" + ] = down_proj + + # Shared expert weights (mapped to standalone NeuronTrinitySharedExpert) + for proj_name in ["gate_proj", "up_proj", "down_proj"]: + hf_key = f"{hf_prefix}.mlp.shared_experts.{proj_name}.weight" + if hf_key in state_dict: + neuron_state_dict[ + f"{neuron_prefix}.shared_expert.{proj_name}.weight" + ] = state_dict[hf_key].to(target_dtype) + + # Fused MoE TKG aliased weights. + # When init_tkg_module=True, the MoE module stores the + # pre_mlp_layernorm as moe.rmsnorm and also inside + # moe.moe_fused_tkg.post_attention_layernorm. These are the + # same Python object (aliased), but appear as separate keys in + # the state dict. We must provide both so the framework loads + # them correctly. + # Similarly, the router stores a transposed weight (weight_T) + # alongside linear_router.weight. + if getattr(config, "moe_fused_nki_kernel_enabled", False): + # MoEFusedTKG has a separate (non-shared) RMSNorm that + # needs the same weights as pre_mlp_layernorm. 
This is + # a distinct module (not aliased) so we copy the weight. + # Note: moe.rmsnorm is None (not a module), so we do NOT + # provide mlp.rmsnorm.weight. + pre_mlp_key = f"{neuron_prefix}.pre_mlp_layernorm.weight" + if pre_mlp_key in neuron_state_dict: + pre_mlp_w = neuron_state_dict[pre_mlp_key] + neuron_state_dict[ + f"{neuron_prefix}.mlp.moe_fused_tkg.post_attention_layernorm.weight" + ] = pre_mlp_w.clone() + + # Router transposed weight (generated by preshard_hook, + # but we provide it here too for completeness) + router_key = f"{neuron_prefix}.mlp.router.linear_router.weight" + if router_key in neuron_state_dict: + neuron_state_dict[f"{neuron_prefix}.mlp.router.weight_T"] = ( + neuron_state_dict[router_key].detach().T.clone() + ) + + # Rank utilities for tensor parallel + tp_degree = neuron_config.tp_degree + neuron_state_dict["rank_util.rank"] = torch.arange( + 0, tp_degree, dtype=torch.int32 + ) + for i in range(num_layers): + neuron_state_dict[f"layers.{i}.self_attn.rank_util.rank"] = torch.arange( + 0, tp_degree, dtype=torch.int32 + ) + + return neuron_state_dict + + def get_compiler_args(self): + """Get compiler arguments for Trinity models.""" + return "--model-type=transformer -O1" diff --git a/contrib/models/Trinity/test/__init__.py b/contrib/models/Trinity/test/__init__.py new file mode 100644 index 0000000..ce63f8f --- /dev/null +++ b/contrib/models/Trinity/test/__init__.py @@ -0,0 +1,13 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). +# You may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/contrib/models/Trinity/test/integration/__init__.py b/contrib/models/Trinity/test/integration/__init__.py new file mode 100644 index 0000000..ce63f8f --- /dev/null +++ b/contrib/models/Trinity/test/integration/__init__.py @@ -0,0 +1,13 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). +# You may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/contrib/models/Trinity/test/integration/test_model.py b/contrib/models/Trinity/test/integration/test_model.py new file mode 100644 index 0000000..9a1d40b --- /dev/null +++ b/contrib/models/Trinity/test/integration/test_model.py @@ -0,0 +1,404 @@ +#!/usr/bin/env python3 +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). +# You may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Integration tests for Trinity (AfmoeForCausalLM) NeuronX implementation. 
+ +Supports all three Trinity model sizes (Nano, Mini, Large) via environment variables. + +Usage: + # Set paths for your model size + export TRINITY_MODEL_PATH="/path/to/model" + export TRINITY_COMPILED_PATH="/path/to/compiled" + + # Run tests + pytest test/integration/test_model.py --capture=tee-sys + +Prerequisites: + - Pre-compiled model at TRINITY_COMPILED_PATH + - HuggingFace model weights at TRINITY_MODEL_PATH (downloaded with trust_remote_code=True) + - Appropriate instance for model size (see README.md) +""" + +import json +import logging +import os +import sys +import time +from pathlib import Path + +import pytest +import torch +from transformers import AutoTokenizer + +from neuronx_distributed_inference.models.config import MoENeuronConfig +from neuronx_distributed_inference.utils.hf_adapter import load_pretrained_config + +# Import from src directory +sys.path.insert(0, str(Path(__file__).parent.parent.parent / "src")) +from modeling_trinity import NeuronTrinityForCausalLM, TrinityInferenceConfig + +logger = logging.getLogger(__name__) + +# Configuration via environment variables +MODEL_PATH = os.environ.get("TRINITY_MODEL_PATH") +COMPILED_MODEL_PATH = os.environ.get("TRINITY_COMPILED_PATH") + +# Performance thresholds (conservative upper bounds -- fail if exceeded) +# These are generous limits to catch regressions, not tight benchmarks.
+# TTFT: max acceptable latency in ms per forward pass +MAX_TTFT_MS = float(os.environ.get("TRINITY_MAX_TTFT_MS", "5000")) +# Throughput: min acceptable tokens/second (naive loop, not CTE+TKG pipeline) +MIN_THROUGHPUT_TOK_S = float(os.environ.get("TRINITY_MIN_THROUGHPUT_TOK_S", "0.5")) + +_MISSING_ENV = [] +if not MODEL_PATH: + _MISSING_ENV.append("TRINITY_MODEL_PATH") +if not COMPILED_MODEL_PATH: + _MISSING_ENV.append("TRINITY_COMPILED_PATH") + +if _MISSING_ENV: + pytestmark = pytest.mark.skip( + reason=f"Required environment variables not set: {', '.join(_MISSING_ENV)}" + ) + + +def load_neuron_config_from_compiled(compiled_path: str): + """Load neuron configuration from compiled model's neuron_config.json.""" + config_path = Path(compiled_path) / "neuron_config.json" + + if not config_path.exists(): + raise FileNotFoundError(f"neuron_config.json not found: {config_path}") + + with open(config_path) as f: + config_data = json.load(f) + + if "neuron_config" in config_data: + return config_data["neuron_config"] + else: + return config_data + + +def create_model_for_inference(compiled_path: str, model_path: str): + """Create model for inference using compiled neuron_config.""" + neuron_config_dict = load_neuron_config_from_compiled(compiled_path) + + dtype_str = neuron_config_dict.get("torch_dtype", "torch.bfloat16") + if isinstance(dtype_str, str): + dtype = ( + getattr(torch, dtype_str.split(".")[1]) + if dtype_str.startswith("torch.") + else torch.bfloat16 + ) + else: + dtype = dtype_str + + neuron_config_kwargs = { + "tp_degree": neuron_config_dict.get("tp_degree", 4), + "batch_size": neuron_config_dict.get("batch_size", 1), + "seq_len": neuron_config_dict.get("seq_len", 2048), + "torch_dtype": dtype, + } + + neuron_config = MoENeuronConfig(**neuron_config_kwargs) + + try: + model_config = TrinityInferenceConfig.from_pretrained( + model_path, neuron_config=neuron_config + ) + except (TypeError, AttributeError): + model_config = TrinityInferenceConfig( + 
neuron_config, + load_config=load_pretrained_config(model_path), + ) + + model = NeuronTrinityForCausalLM(model_path, model_config) + return model, neuron_config + + +def generate_with_neuron_model(model, input_ids, max_new_tokens: int): + """Generate tokens using manual forward pass loop.""" + generated_ids = input_ids.clone() + + for _ in range(max_new_tokens): + seq_len = generated_ids.shape[1] + position_ids = ( + torch.arange(seq_len).unsqueeze(0).expand(generated_ids.shape[0], -1) + ) + + with torch.no_grad(): + outputs = model(generated_ids, position_ids=position_ids) + + if hasattr(outputs, "logits"): + logits = outputs.logits + elif isinstance(outputs, tuple): + logits = outputs[0] + else: + logits = outputs + + next_token_logits = logits[:, -1, :] + next_token = torch.argmax(next_token_logits, dim=-1).unsqueeze(-1) + generated_ids = torch.cat([generated_ids, next_token], dim=-1) + + return generated_ids + + +@pytest.fixture(scope="module") +def compiled_model(): + """Load pre-compiled model.""" + model, neuron_config = create_model_for_inference(COMPILED_MODEL_PATH, MODEL_PATH) + model.load(COMPILED_MODEL_PATH) + return model + + +@pytest.fixture(scope="module") +def tokenizer(): + """Load tokenizer.""" + tokenizer = AutoTokenizer.from_pretrained( + MODEL_PATH, padding_side="right", trust_remote_code=True + ) + if tokenizer.pad_token is None: + tokenizer.pad_token = tokenizer.eos_token + return tokenizer + + +def test_model_loads(compiled_model): + """Test that model loads successfully (smoke test).""" + assert compiled_model is not None + assert hasattr(compiled_model, "config") + logger.info("Smoke test passed - Model loaded successfully") + + +def test_model_generates(compiled_model, tokenizer): + """Test that model can generate text.""" + prompt = "Hello, how are you?" 
+ inputs = tokenizer(prompt, return_tensors="pt", padding=True) + + generated_ids = generate_with_neuron_model( + compiled_model, inputs.input_ids, max_new_tokens=20 + ) + output_text = tokenizer.decode(generated_ids[0], skip_special_tokens=True) + + assert len(output_text) > len(prompt), "Output should be longer than prompt" + logger.info("Generation test passed") + logger.info(" Output: %s", output_text) + + +def test_output_coherence(compiled_model, tokenizer): + """Test that output is coherent (not gibberish or repetitive).""" + prompt = "Hello, how are you?" + inputs = tokenizer(prompt, return_tensors="pt", padding=True) + + generated_ids = generate_with_neuron_model( + compiled_model, inputs.input_ids, max_new_tokens=30 + ) + output_text = tokenizer.decode(generated_ids[0], skip_special_tokens=True) + + assert len(output_text.split()) > 3, "Output should have multiple words" + assert not _is_repetitive(output_text), "Output should not be repetitive" + + logger.info("Coherence test passed") + logger.info(" Output: %s...", output_text[:100]) + + +def test_top_token_valid(compiled_model, tokenizer): + """Test that the top predicted token is a valid, decodable token. + + Unlike model-specific tests, this does not check for a specific expected token + since different Trinity sizes produce different outputs. + """ + prompt = "Hello, how are you?" 
+ inputs = tokenizer(prompt, return_tensors="pt", padding=True) + + seq_len = inputs.input_ids.shape[1] + position_ids = ( + torch.arange(seq_len).unsqueeze(0).expand(inputs.input_ids.shape[0], -1) + ) + + with torch.no_grad(): + outputs = compiled_model(inputs.input_ids, position_ids=position_ids) + + if hasattr(outputs, "logits"): + logits = outputs.logits + elif isinstance(outputs, tuple): + logits = outputs[0] + else: + logits = outputs + + next_token_logits = logits[:, -1, :] + top_token_id = torch.argmax(next_token_logits, dim=-1).item() + top_token = tokenizer.decode([top_token_id]).strip() + + logger.info("Top token: '%s' (id=%d)", top_token, top_token_id) + logger.info("Top logit: %.2f", next_token_logits[0, top_token_id].item()) + + # The top token should be a non-empty, printable string + assert len(top_token) > 0, f"Top token should be non-empty, got '{top_token}'" + assert top_token_id < tokenizer.vocab_size, "Token ID should be within vocab range" + logger.info("Top token validation passed") + + +def _is_repetitive(text: str, max_repeat: int = 5) -> bool: + """Check if text has excessive repetition.""" + words = text.split() + if len(words) < 10: + return False + + for i in range(len(words) - max_repeat): + word = words[i] + if all(words[i + j] == word for j in range(max_repeat)): + return True + + new_text = text[-100:] if len(text) > 100 else text + if len(new_text) > 20: + char_counts = {} + for c in new_text: + char_counts[c] = char_counts.get(c, 0) + 1 + max_char_ratio = max(char_counts.values()) / len(new_text) + if max_char_ratio > 0.5: + return True + + return False + + +def test_performance_ttft(compiled_model, tokenizer): + """Test Time To First Token (TTFT) performance. + + Pass criteria: avg TTFT must be below MAX_TTFT_MS (default 5000ms). + Override with TRINITY_MAX_TTFT_MS environment variable. + """ + prompt = "Hello, how are you?" 
+ inputs = tokenizer(prompt, return_tensors="pt", padding=True) + input_ids = inputs.input_ids + + # Warmup + for _ in range(3): + seq_len = input_ids.shape[1] + position_ids = torch.arange(seq_len).unsqueeze(0).expand(input_ids.shape[0], -1) + with torch.no_grad(): + _ = compiled_model(input_ids, position_ids=position_ids) + + # Measure TTFT + times = [] + for _ in range(10): + seq_len = input_ids.shape[1] + position_ids = torch.arange(seq_len).unsqueeze(0).expand(input_ids.shape[0], -1) + + start = time.perf_counter() + with torch.no_grad(): + _ = compiled_model(input_ids, position_ids=position_ids) + end = time.perf_counter() + + times.append((end - start) * 1000) + + avg_ttft = sum(times) / len(times) + min_ttft = min(times) + max_ttft = max(times) + logger.info( + "TTFT: avg=%.2fms, min=%.2fms, max=%.2fms (threshold: %.0fms)", + avg_ttft, + min_ttft, + max_ttft, + MAX_TTFT_MS, + ) + assert avg_ttft < MAX_TTFT_MS, ( + f"TTFT regression: {avg_ttft:.1f}ms exceeds threshold {MAX_TTFT_MS:.0f}ms" + ) + + +def test_performance_throughput(compiled_model, tokenizer): + """Test token generation throughput using a naive forward loop. + + Pass criteria: throughput must exceed MIN_THROUGHPUT_TOK_S (default 0.5 tok/s). + Override with TRINITY_MIN_THROUGHPUT_TOK_S environment variable. + + NOTE: This uses a naive loop (re-encodes full context each token), so throughput + is much lower than with a proper CTE+TKG pipeline. The threshold is intentionally low.
+ """ + prompt = "Hello" + inputs = tokenizer(prompt, return_tensors="pt", padding=True) + input_ids = inputs.input_ids + num_tokens = 50 + + # Warmup + _ = generate_with_neuron_model(compiled_model, input_ids, max_new_tokens=5) + + # Measure throughput + start = time.perf_counter() + _ = generate_with_neuron_model(compiled_model, input_ids, max_new_tokens=num_tokens) + end = time.perf_counter() + + total_time = end - start + throughput = num_tokens / total_time + logger.info( + "Throughput: %.2f tok/s (%d tokens in %.1fs, threshold: %.1f tok/s)", + throughput, + num_tokens, + total_time, + MIN_THROUGHPUT_TOK_S, + ) + assert throughput > MIN_THROUGHPUT_TOK_S, ( + f"Throughput regression: {throughput:.2f} tok/s below threshold " + f"{MIN_THROUGHPUT_TOK_S:.1f} tok/s" + ) + + +if __name__ == "__main__": + logging.basicConfig( + level=logging.INFO, + format="%(asctime)s %(name)s %(levelname)s: %(message)s", + ) + + logger.info("=" * 80) + logger.info("Trinity (AfmoeForCausalLM) Integration Tests") + logger.info("=" * 80) + logger.info("Model path: %s", MODEL_PATH) + logger.info("Compiled path: %s", COMPILED_MODEL_PATH) + + logger.info("Loading compiled model from %s...", COMPILED_MODEL_PATH) + model, neuron_config = create_model_for_inference(COMPILED_MODEL_PATH, MODEL_PATH) + model.load(COMPILED_MODEL_PATH) + logger.info("Model loaded") + + tokenizer = AutoTokenizer.from_pretrained( + MODEL_PATH, padding_side="right", trust_remote_code=True + ) + if tokenizer.pad_token is None: + tokenizer.pad_token = tokenizer.eos_token + + logger.info("") + logger.info("=" * 80) + logger.info("Running Tests") + logger.info("=" * 80) + + logger.info("") + logger.info("1. Smoke Test (Model Loading)...") + test_model_loads(model) + + logger.info("") + logger.info("2. Generation Test...") + test_model_generates(model, tokenizer) + + logger.info("") + logger.info("3. Coherence Test...") + test_output_coherence(model, tokenizer) + + logger.info("") + logger.info("4. 
Top Token Validation...") + test_top_token_valid(model, tokenizer) + + logger.info("") + logger.info("=" * 80) + logger.info("All tests passed!") + logger.info("=" * 80) diff --git a/contrib/models/Trinity/test/unit/__init__.py b/contrib/models/Trinity/test/unit/__init__.py new file mode 100644 index 0000000..ce63f8f --- /dev/null +++ b/contrib/models/Trinity/test/unit/__init__.py @@ -0,0 +1,13 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). +# You may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/contrib/models/Trinity/test/unit/test_config.py b/contrib/models/Trinity/test/unit/test_config.py new file mode 100644 index 0000000..2b7ef46 --- /dev/null +++ b/contrib/models/Trinity/test/unit/test_config.py @@ -0,0 +1,299 @@ +#!/usr/bin/env python3 +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). +# You may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +CPU-only unit tests for TrinityInferenceConfig.
+ +These tests verify config parsing, parameter transformation, and derived +configuration without requiring Neuron hardware or a downloaded model. + +Usage: + pytest test/unit/test_config.py -v +""" + +import json +import os +import tempfile + +import pytest +import torch + +from neuronx_distributed_inference.models.config import MoENeuronConfig + +import sys +from pathlib import Path + +sys.path.insert(0, str(Path(__file__).parent.parent.parent / "src")) +from modeling_trinity import TrinityInferenceConfig + + +# Minimal Nano config dict (matches arcee-ai/Trinity-Nano-Preview/config.json) +NANO_CONFIG = { + "architectures": ["AfmoeForCausalLM"], + "attention_dropout": 0.0, + "auto_map": { + "AutoConfig": "configuration_afmoe.AfmoeConfig", + "AutoModel": "modeling_afmoe.AfmoeModel", + "AutoModelForCausalLM": "modeling_afmoe.AfmoeForCausalLM", + }, + "dtype": "bfloat16", + "global_attn_every_n_layers": 4, + "head_dim": 128, + "hidden_act": "silu", + "hidden_size": 1024, + "initializer_range": 0.02, + "intermediate_size": 3072, + "max_position_embeddings": 131072, + "model_type": "afmoe", + "moe_intermediate_size": 256, + "mup_enabled": True, + "n_group": 1, + "num_attention_heads": 8, + "num_dense_layers": 2, + "num_expert_groups": 1, + "num_experts": 128, + "num_experts_per_tok": 8, + "num_hidden_layers": 56, + "num_key_value_heads": 2, + "num_limited_groups": 1, + "num_shared_experts": 1, + "rms_norm_eps": 1e-05, + "rope_scaling": None, + "rope_theta": 10000, + "route_norm": True, + "route_scale": 2.826, + "score_func": "sigmoid", + "sliding_window": 2048, + "tie_word_embeddings": False, + "topk_group": 1, + "transformers_version": "4.57.3", + "use_cache": True, + "use_grouped_mm": True, + "vocab_size": 200192, +} + +# Mini config overrides (key differences from Nano) +MINI_OVERRIDES = { + "hidden_size": 2048, + "intermediate_size": 6144, + "moe_intermediate_size": 1024, + "num_attention_heads": 32, + "num_dense_layers": 2, + "num_hidden_layers": 32, + 
"num_key_value_heads": 4, +} + +# Large config overrides +LARGE_OVERRIDES = { + "hidden_size": 3072, + "intermediate_size": 12288, + "moe_intermediate_size": 3072, + "num_attention_heads": 48, + "num_dense_layers": 6, + "num_experts": 256, + "num_experts_per_tok": 4, + "num_hidden_layers": 60, + "num_key_value_heads": 8, + "sliding_window": 4096, +} + + +def make_config(overrides=None, tp_degree=2, seq_len=2048, batch_size=1): + """Create a TrinityInferenceConfig from dict with optional overrides.""" + config_dict = NANO_CONFIG.copy() + if overrides: + config_dict.update(overrides) + + neuron_config = MoENeuronConfig( + tp_degree=tp_degree, + batch_size=batch_size, + seq_len=seq_len, + torch_dtype=torch.bfloat16, + ) + return TrinityInferenceConfig(neuron_config=neuron_config, **config_dict) + + +class TestConfigParsing: + """Test that config is correctly parsed from HF config dict.""" + + def test_nano_basic_params(self): + config = make_config() + assert config.vocab_size == 200192 + assert config.hidden_size == 1024 + assert config.num_hidden_layers == 56 + assert config.num_attention_heads == 8 + assert config.num_key_value_heads == 2 + assert config.head_dim == 128 + assert config.num_experts == 128 + assert config.num_experts_per_tok == 8 + assert config.num_shared_experts == 1 + assert config.num_dense_layers == 2 + + def test_intermediate_size_swap(self): + """CRITICAL: intermediate_size must be MoE size, not dense size.""" + config = make_config() + # intermediate_size should be moe_intermediate_size (for NxDI MoE module) + assert config.intermediate_size == 256, ( + "intermediate_size must equal moe_intermediate_size for initialize_moe_module" + ) + # Dense intermediate preserved separately + assert config.dense_intermediate_size == 3072 + assert config.moe_intermediate_size == 256 + + def test_mini_intermediate_size_swap(self): + config = make_config(MINI_OVERRIDES, tp_degree=4) + assert config.intermediate_size == 1024 + assert 
config.dense_intermediate_size == 6144 + assert config.moe_intermediate_size == 1024 + + def test_large_intermediate_size_swap(self): + config = make_config(LARGE_OVERRIDES, tp_degree=64) + assert config.intermediate_size == 3072 + assert config.dense_intermediate_size == 12288 + assert config.moe_intermediate_size == 3072 + + def test_shared_experts_forced_zero(self): + """n_shared_experts must be 0 for NxDI (Trinity handles shared experts manually).""" + config = make_config() + assert config.n_shared_experts == 0 + # But num_shared_experts preserves the real count + assert config.num_shared_experts == 1 + + def test_glu_type_set(self): + """Trinity uses SiLU gated MLP which maps to NxDI glu_type='glu'.""" + config = make_config() + assert config.neuron_config.glu_type == "glu" + assert config.neuron_config.glu_mlp is True + + def test_score_func_sigmoid(self): + config = make_config() + assert config.score_func == "sigmoid" + + def test_route_scale(self): + config = make_config() + assert config.route_scale == 2.826 + + +class TestLayerTypes: + """Test mixed attention layer type generation.""" + + def test_nano_layer_types_count(self): + config = make_config() + assert len(config.layer_types) == 56 + + def test_layer_types_pattern(self): + """Every 4th layer (0-indexed: 3, 7, 11, ...) 
should be full_attention.""" + config = make_config() + for i, lt in enumerate(config.layer_types): + if (i + 1) % 4 == 0: + assert lt == "full_attention", f"Layer {i} should be full_attention" + else: + assert lt == "sliding_attention", ( + f"Layer {i} should be sliding_attention" + ) + + def test_nano_full_attention_count(self): + config = make_config() + full = sum(1 for lt in config.layer_types if lt == "full_attention") + sliding = sum(1 for lt in config.layer_types if lt == "sliding_attention") + assert full == 14 # 56 / 4 = 14 + assert sliding == 42 + + def test_mini_layer_types(self): + config = make_config(MINI_OVERRIDES, tp_degree=4) + assert len(config.layer_types) == 32 + full = sum(1 for lt in config.layer_types if lt == "full_attention") + assert full == 8 # 32 / 4 = 8 + + def test_explicit_layer_types_preserved(self): + """If layer_types is provided in config, it should be preserved as-is.""" + overrides = {"layer_types": ["sliding_attention", "full_attention"] * 28} + config = make_config(overrides) + assert config.layer_types == ["sliding_attention", "full_attention"] * 28 + + +class TestSlidingWindowClamping: + """Test that sliding_window is clamped to seq_len when seq_len < sliding_window.""" + + def test_no_clamping_when_seq_ge_window(self): + config = make_config(seq_len=2048) + assert config.sliding_window == 2048 + + def test_clamping_when_seq_lt_window(self): + config = make_config(seq_len=1024) + assert config.sliding_window == 1024 + + def test_no_clamping_large_seq(self): + config = make_config(seq_len=8192) + assert config.sliding_window == 2048 + + +class TestFromPretrained: + """Test from_pretrained loading from a config.json file.""" + + def test_from_pretrained_loads_config(self): + with tempfile.TemporaryDirectory() as tmpdir: + config_path = os.path.join(tmpdir, "config.json") + with open(config_path, "w") as f: + json.dump(NANO_CONFIG, f) + + neuron_config = MoENeuronConfig( + tp_degree=2, + batch_size=1, + seq_len=2048, + 
torch_dtype=torch.bfloat16, + ) + config = TrinityInferenceConfig.from_pretrained( + tmpdir, neuron_config=neuron_config + ) + + assert config.vocab_size == 200192 + assert config.hidden_size == 1024 + assert config.intermediate_size == 256 # Swapped to MoE size + assert config.num_hidden_layers == 56 + + def test_from_pretrained_missing_config(self): + with tempfile.TemporaryDirectory() as tmpdir: + with pytest.raises(FileNotFoundError): + TrinityInferenceConfig.from_pretrained(tmpdir) + + +class TestFusedMoeTkgEligibility: + """Test automatic fused MoE TKG kernel eligibility detection.""" + + def test_nano_tp2_eligible(self): + """Nano (moe_intermediate=256) at TP=2: 256/2=128, 128%128=0 → eligible.""" + config = make_config(tp_degree=2) + # After add_derived_config, check if fused TKG was enabled + # The check is: moe_intermediate_size / moe_tp_degree % 128 == 0 + per_tp = config.moe_intermediate_size // 2 + assert per_tp % 128 == 0 + + def test_nano_tp4_ineligible(self): + """Nano (moe_intermediate=256) at TP=4: 256/4=64, 64%128!=0 → ineligible.""" + config = make_config(tp_degree=4) + per_tp = config.moe_intermediate_size // 4 + assert per_tp % 128 != 0 + + def test_mini_tp4_eligible(self): + """Mini (moe_intermediate=1024) at TP=4: 1024/4=256, 256%128=0 → eligible.""" + config = make_config(MINI_OVERRIDES, tp_degree=4) + per_tp = config.moe_intermediate_size // 4 + assert per_tp % 128 == 0 + + def test_large_tp64_ineligible(self): + """Large (moe_intermediate=3072) at TP=64: 3072/64=48, 48%128!=0 → ineligible.""" + config = make_config(LARGE_OVERRIDES, tp_degree=64) + per_tp = config.moe_intermediate_size // 64 + assert per_tp % 128 != 0 diff --git a/contrib/models/Trinity/test/unit/test_weight_conversion.py b/contrib/models/Trinity/test/unit/test_weight_conversion.py new file mode 100644 index 0000000..7633064 --- /dev/null +++ b/contrib/models/Trinity/test/unit/test_weight_conversion.py @@ -0,0 +1,390 @@ +#!/usr/bin/env python3 +# Copyright Amazon.com, Inc. 
or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). +# You may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +CPU-only unit tests for Trinity HF→Neuron weight conversion. + +Tests verify key weight name mappings, transformations (muP scaling, +route_scale baking, expert stacking, gate padding), and edge cases +without requiring Neuron hardware or model weights. + +Usage: + pytest test/unit/test_weight_conversion.py -v +""" + +import math +import pytest +import torch + +from neuronx_distributed_inference.models.config import MoENeuronConfig + +import sys +from pathlib import Path + +sys.path.insert(0, str(Path(__file__).parent.parent.parent / "src")) +from modeling_trinity import NeuronTrinityForCausalLM, TrinityInferenceConfig + + +# Minimal Nano config for weight conversion tests +NANO_CONFIG = { + "vocab_size": 200192, + "hidden_size": 1024, + "intermediate_size": 3072, + "moe_intermediate_size": 256, + "num_hidden_layers": 56, + "num_dense_layers": 2, + "num_attention_heads": 8, + "num_key_value_heads": 2, + "head_dim": 128, + "hidden_act": "silu", + "max_position_embeddings": 131072, + "rms_norm_eps": 1e-05, + "rope_theta": 10000, + "num_experts": 128, + "num_experts_per_tok": 8, + "num_shared_experts": 1, + "score_func": "sigmoid", + "route_norm": True, + "route_scale": 2.826, + "sliding_window": 2048, + "mup_enabled": True, + "global_attn_every_n_layers": 4, + "tie_word_embeddings": False, + "n_group": 1, + "topk_group": 1, +} + + +def make_config(tp_degree=2): + """Create a 
TrinityInferenceConfig for testing.""" + neuron_config = MoENeuronConfig( + tp_degree=tp_degree, + batch_size=1, + seq_len=2048, + torch_dtype=torch.bfloat16, + ) + return TrinityInferenceConfig(neuron_config=neuron_config, **NANO_CONFIG.copy()) + + +def make_minimal_state_dict(config, num_layers=2, num_experts=4): + """Create a minimal HF state dict for testing weight conversion. + + Only includes enough layers and experts to validate transformations. + Uses small tensor sizes for fast CPU testing. + """ + H = config.hidden_size # 1024 + I_dense = config.dense_intermediate_size # 3072 + I_moe = config.moe_intermediate_size # 256 + V = config.vocab_size # 200192 + n_heads = config.num_attention_heads # 8 + n_kv_heads = config.num_key_value_heads # 2 + head_dim = config.head_dim # 128 + + sd = {} + + # Embedding + sd["model.embed_tokens.weight"] = torch.randn(V, H) + + # LM head + sd["model.norm.weight"] = torch.ones(H) + sd["lm_head.weight"] = torch.randn(V, H) + + for layer_idx in range(num_layers): + prefix = f"model.layers.{layer_idx}" + + # Attention weights + sd[f"{prefix}.self_attn.q_proj.weight"] = torch.randn(n_heads * head_dim, H) + sd[f"{prefix}.self_attn.k_proj.weight"] = torch.randn(n_kv_heads * head_dim, H) + sd[f"{prefix}.self_attn.v_proj.weight"] = torch.randn(n_kv_heads * head_dim, H) + sd[f"{prefix}.self_attn.o_proj.weight"] = torch.randn(H, n_heads * head_dim) + + # QK norms (Trinity-specific) + sd[f"{prefix}.self_attn.q_norm.weight"] = torch.ones(head_dim) + sd[f"{prefix}.self_attn.k_norm.weight"] = torch.ones(head_dim) + + # Attention gate (gated attention) + sd[f"{prefix}.self_attn.gate_proj.weight"] = torch.randn(n_heads, H) + + # Layer norms + sd[f"{prefix}.input_layernorm.weight"] = torch.ones(H) + sd[f"{prefix}.post_attention_layernorm.weight"] = torch.ones(H) + sd[f"{prefix}.pre_feedforward_layernorm.weight"] = torch.ones(H) + sd[f"{prefix}.post_feedforward_layernorm.weight"] = torch.ones(H) + + if layer_idx < config.num_dense_layers: + # 
Dense MLP layers + sd[f"{prefix}.mlp.gate_proj.weight"] = torch.randn(I_dense, H) + sd[f"{prefix}.mlp.up_proj.weight"] = torch.randn(I_dense, H) + sd[f"{prefix}.mlp.down_proj.weight"] = torch.randn(H, I_dense) + else: + # MoE layers + sd[f"{prefix}.mlp.router.gate.weight"] = torch.randn(num_experts, H) + sd[f"{prefix}.mlp.expert_bias"] = torch.randn(num_experts) + + for e in range(num_experts): + sd[f"{prefix}.mlp.experts.{e}.gate_proj.weight"] = torch.randn(I_moe, H) + sd[f"{prefix}.mlp.experts.{e}.up_proj.weight"] = torch.randn(I_moe, H) + sd[f"{prefix}.mlp.experts.{e}.down_proj.weight"] = torch.randn(H, I_moe) + + # Shared expert (no index -- Trinity uses shared_experts.{proj} directly) + sd[f"{prefix}.mlp.shared_experts.gate_proj.weight"] = torch.randn(I_moe, H) + sd[f"{prefix}.mlp.shared_experts.up_proj.weight"] = torch.randn(I_moe, H) + sd[f"{prefix}.mlp.shared_experts.down_proj.weight"] = torch.randn(H, I_moe) + + return sd + + +class TestModelPrefixRemoval: + """Test that 'model.' 
prefix is correctly removed from HF keys.""" + + def test_model_prefix_removed(self): + config = make_config() + sd = {"model.embed_tokens.weight": torch.randn(200192, 1024)} + result = NeuronTrinityForCausalLM.convert_hf_to_neuron_state_dict(sd, config) + assert "embed_tokens.weight" in result + assert "model.embed_tokens.weight" not in result + + +class TestMuPScaling: + """Test muP embedding scaling: embed_weight *= sqrt(hidden_size).""" + + def test_mup_scaling_applied(self): + config = make_config() + original_embed = torch.randn(200192, 1024) + sd = {"model.embed_tokens.weight": original_embed.clone()} + result = NeuronTrinityForCausalLM.convert_hf_to_neuron_state_dict(sd, config) + + # Conversion outputs bf16, so cast expected to match + expected = (original_embed * math.sqrt(1024)).to(torch.bfloat16) + assert torch.allclose(result["embed_tokens.weight"], expected, atol=1e-2), ( + "muP scaling should multiply embedding by sqrt(hidden_size)" + ) + + +class TestQKNormRename: + """Test q_norm → q_layernorm, k_norm → k_layernorm rename.""" + + def test_qk_norm_renamed(self): + config = make_config() + sd = { + "model.layers.0.self_attn.q_norm.weight": torch.ones(128), + "model.layers.0.self_attn.k_norm.weight": torch.ones(128), + } + result = NeuronTrinityForCausalLM.convert_hf_to_neuron_state_dict(sd, config) + + assert "layers.0.self_attn.q_layernorm.weight" in result + assert "layers.0.self_attn.k_layernorm.weight" in result + assert "layers.0.self_attn.q_norm.weight" not in result + assert "layers.0.self_attn.k_norm.weight" not in result + + +class TestAttentionGateRename: + """Test self_attn.gate_proj → self_attn.attn_gate_proj rename.""" + + def test_gate_renamed(self): + config = make_config() + sd = { + "model.layers.0.self_attn.gate_proj.weight": torch.randn(8, 1024), + } + result = NeuronTrinityForCausalLM.convert_hf_to_neuron_state_dict(sd, config) + + assert "layers.0.self_attn.attn_gate_proj.weight" in result + assert 
"layers.0.self_attn.gate_proj.weight" not in result + + +class TestRouterRename: + """Test router.gate.weight → router.linear_router.weight rename.""" + + def test_router_weight_renamed(self): + config = make_config() + sd = { + "model.layers.2.mlp.router.gate.weight": torch.randn(128, 1024), + } + result = NeuronTrinityForCausalLM.convert_hf_to_neuron_state_dict(sd, config) + + assert "layers.2.mlp.router.linear_router.weight" in result + assert "layers.2.mlp.router.gate.weight" not in result + + +class TestExpertBiasMapping: + """Test expert_bias mapping: mlp.expert_bias → mlp.router.expert_bias.""" + + def test_expert_bias_mapped(self): + config = make_config() + bias = torch.randn(128) + sd = {"model.layers.2.mlp.expert_bias": bias.clone()} + result = NeuronTrinityForCausalLM.convert_hf_to_neuron_state_dict(sd, config) + + assert "layers.2.mlp.router.expert_bias" in result + assert torch.equal(result["layers.2.mlp.router.expert_bias"], bias) + + def test_expert_bias_kept_float32(self): + """Expert bias should remain float32 (not converted to bf16).""" + config = make_config() + bias = torch.randn(128, dtype=torch.float32) + sd = {"model.layers.2.mlp.expert_bias": bias} + result = NeuronTrinityForCausalLM.convert_hf_to_neuron_state_dict(sd, config) + + assert result["layers.2.mlp.router.expert_bias"].dtype == torch.float32 + + +class TestExpertWeightStacking: + """Test per-expert weights are stacked into [E, H, 2*I] format.""" + + def test_expert_stacking(self): + config = make_config() + H = 1024 + I = 256 + # Must provide all num_experts from config (128) since the + # conversion iterates over all experts and aborts if any are missing. 
+ num_experts = config.num_local_experts + + sd = {} + for e in range(num_experts): + sd[f"model.layers.2.mlp.experts.{e}.gate_proj.weight"] = torch.randn(I, H) + sd[f"model.layers.2.mlp.experts.{e}.up_proj.weight"] = torch.randn(I, H) + sd[f"model.layers.2.mlp.experts.{e}.down_proj.weight"] = torch.randn(H, I) + + # Need router key too for the conversion to proceed + sd["model.layers.2.mlp.router.gate.weight"] = torch.randn(num_experts, H) + + result = NeuronTrinityForCausalLM.convert_hf_to_neuron_state_dict(sd, config) + + # gate_up_proj should be [E, H, 2*I] (key includes .mlp_op. and .weight) + gate_up_key = "layers.2.mlp.expert_mlps.mlp_op.gate_up_proj.weight" + assert gate_up_key in result, ( + f"Expected key '{gate_up_key}' in result. Keys: " + f"{[k for k in result if 'expert' in k]}" + ) + assert result[gate_up_key].shape == (num_experts, H, 2 * I) + + # down_proj should be [E, I, H] + down_key = "layers.2.mlp.expert_mlps.mlp_op.down_proj.weight" + assert down_key in result + assert result[down_key].shape == (num_experts, I, H) + + +class TestRouteScaleBaking: + """Test that route_scale is baked into routed expert down_proj weights.""" + + def test_route_scale_applied_to_down_proj(self): + config = make_config() + H = 1024 + I = 256 + # Must provide all num_experts from config + num_experts = config.num_local_experts + route_scale = config.route_scale # 2.826 + + sd = {} + down_projs = [] + for e in range(num_experts): + sd[f"model.layers.2.mlp.experts.{e}.gate_proj.weight"] = torch.randn(I, H) + sd[f"model.layers.2.mlp.experts.{e}.up_proj.weight"] = torch.randn(I, H) + down = torch.randn(H, I) + sd[f"model.layers.2.mlp.experts.{e}.down_proj.weight"] = down.clone() + down_projs.append(down) + + sd["model.layers.2.mlp.router.gate.weight"] = torch.randn(num_experts, H) + + result = NeuronTrinityForCausalLM.convert_hf_to_neuron_state_dict(sd, config) + + # The down_proj weights should be scaled by route_scale + result_down = 
result["layers.2.mlp.expert_mlps.mlp_op.down_proj.weight"] + for e in range(min(4, num_experts)): + # Original down_proj is transposed to [I, H], cast to bf16, then scaled + expected = (down_projs[e].T.to(torch.bfloat16)) * route_scale + assert torch.allclose(result_down[e], expected, atol=1e-2), ( + f"Expert {e} down_proj should be scaled by route_scale={route_scale}" + ) + + +class TestSharedExpertMapping: + """Test shared expert weight key mapping.""" + + def test_shared_expert_keys(self): + config = make_config() + H = 1024 + I = 256 + + # HF Trinity uses shared_experts.{proj} (no index) for single shared expert + sd = { + "model.layers.2.mlp.shared_experts.gate_proj.weight": torch.randn(I, H), + "model.layers.2.mlp.shared_experts.up_proj.weight": torch.randn(I, H), + "model.layers.2.mlp.shared_experts.down_proj.weight": torch.randn(H, I), + } + + result = NeuronTrinityForCausalLM.convert_hf_to_neuron_state_dict(sd, config) + + # Should be mapped to standalone shared_expert module + assert ( + "layers.2.shared_expert.gate_proj.weight" in result + or "layers.2.mlp.shared_expert.gate_proj.weight" in result + ), ( + f"Shared expert keys not found. 
Keys with 'shared': " + f"{[k for k in result if 'shared' in k]}" + ) + + +class TestGatePadding: + """Test gate weight padding when num_heads % TP != 0.""" + + def test_no_padding_when_divisible(self): + """Nano: 8 heads / TP=2 = 4 per rank, no padding needed.""" + config = make_config(tp_degree=2) + n_heads = 8 + H = 1024 + gate = torch.randn(n_heads, H) + sd = {"model.layers.0.self_attn.gate_proj.weight": gate.clone()} + result = NeuronTrinityForCausalLM.convert_hf_to_neuron_state_dict(sd, config) + result_gate = result["layers.0.self_attn.attn_gate_proj.weight"] + # No padding: shape should be (8, H) + assert result_gate.shape[0] == n_heads + + def test_padding_when_not_divisible(self): + """Large: 48 heads / TP=64 requires padding to 64.""" + # Create a Large-like config + large_config_dict = NANO_CONFIG.copy() + large_config_dict.update( + { + "hidden_size": 3072, + "intermediate_size": 12288, + "moe_intermediate_size": 3072, + "num_attention_heads": 48, + "num_key_value_heads": 8, + "num_hidden_layers": 60, + "num_dense_layers": 6, + "num_experts": 256, + "num_experts_per_tok": 4, + "sliding_window": 4096, + } + ) + neuron_config = MoENeuronConfig( + tp_degree=64, batch_size=1, seq_len=4096, torch_dtype=torch.bfloat16 + ) + config = TrinityInferenceConfig( + neuron_config=neuron_config, **large_config_dict + ) + + n_heads = 48 + H = 3072 + gate = torch.randn(n_heads, H) + sd = {"model.layers.0.self_attn.gate_proj.weight": gate.clone()} + result = NeuronTrinityForCausalLM.convert_hf_to_neuron_state_dict(sd, config) + result_gate = result["layers.0.self_attn.attn_gate_proj.weight"] + # Gate weight is (num_heads, hidden_size). After interleaved padding, + # the output should be (padded_total_heads, hidden_size). + padded_heads = 64 # next multiple of TP=64 >= 48 + assert result_gate.shape == (padded_heads, H), ( + f"Expected shape ({padded_heads}, {H}), got {result_gate.shape}" + )
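
The derived-configuration rules these tests assert (the every-4th-layer attention pattern, gate-weight head padding, and the fused MoE TKG eligibility check) can be sketched standalone. This is a minimal sketch using only the constants shown in this diff; the function names are illustrative, not the NxDI API:

```python
import math


def layer_types(num_layers, global_attn_every_n=4):
    # Every Nth layer (1-indexed) is full attention; the rest use a sliding window.
    return [
        "full_attention" if (i + 1) % global_attn_every_n == 0 else "sliding_attention"
        for i in range(num_layers)
    ]


def padded_heads(num_heads, tp_degree):
    # Gate weights are padded so the head count divides evenly across TP ranks.
    return math.ceil(num_heads / tp_degree) * tp_degree


def fused_tkg_eligible(moe_intermediate_size, moe_tp_degree):
    # The fused TKG kernel needs the per-rank MoE intermediate dim to be a multiple of 128.
    return (moe_intermediate_size // moe_tp_degree) % 128 == 0


nano = layer_types(56)
print(nano.count("full_attention"), nano.count("sliding_attention"))  # 14 42
print(padded_heads(48, 64), padded_heads(8, 2))                       # 64 8
print(fused_tkg_eligible(256, 2), fused_tkg_eligible(3072, 64))       # True False
```

These are the same invariants exercised by `TestLayerTypes`, `TestGatePadding`, and `TestFusedMoeTkgEligibility` above, reduced to their arithmetic.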