Enable Triton MOE for MXFP4 on gfx950 (MI355X)#226
Conversation
Create 8 documentation guides covering all major ATOM subsystems:
- Architecture & Design (request lifecycle, component diagram)
- Configuration & CLI Reference (all config classes, env vars)
- Model Support (8 architectures, weight loading, adding new models)
- Model Operations & AITER Integration (kernel mapping, fused ops)
- Scheduling & KV Cache (prefill-first scheduler, block manager, prefix caching)
- Compilation & CUDA Graphs (4 levels, 5 modes, piecewise compilation)
- Distributed Inference (TP, DP, EP with MORI all-to-all)
- Serving & Benchmarks (OpenAI server, profiling, MTP speculative decoding)

All guides are fact-checked against the codebase. README updated with expanded features (OpenAI API, quantization, multi-GPU, speculative decoding, prefix caching), a supported-models table, documentation links, and improved section structure.
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
1. Route prefill to prefill_attention_triton when use_triton_attn=True (models with head_dim != 128 or sliding_window). Previously prefill always used the CK-based flash_attn_varlen_func, which fails when AITER is built with ENABLE_CK=0.
2. Create fake_block_tables inline from cu_seqlens_k in prefill_attention_triton. The attn_metadata.fake_block_tables field was never populated, causing NoneType stride errors.
3. Dockerfile: add an ENABLE_CK build arg for Triton-only AITER builds, install the triton_kernels package (required for MXFP4 MoE on gfx94x), and conditionally skip CK submodule init when ENABLE_CK=0.

Tested with GPT-OSS-120B (head_dim=64, MXFP4 MoE, sliding_window=128) on MI300X using an ENABLE_CK=0 Docker image.
Enable Triton-only prefill attention and ENABLE_CK Docker support
* Add Dockerfile.clean for minimal ATOM build from public sources
  - Dockerfile.clean: clean Docker build using the rocm/dev-ubuntu-22.04:7.2-complete base, PyTorch nightly ROCm 7.2 wheel, ROCm Triton 3.5.x from source (replaces incompatible Triton 3.6.0), and an AITER Triton-only build (ENABLE_CK=0)
  - Fix scheduler.py: initialize num_rejected=0 before the speculative-decode branch to prevent UnboundLocalError in the non-speculative path (regression from ROCm#219)
  - Fix test_scheduler.py: add the required num_rejected param to ScheduledBatchOutput
  - Add .dockerignore to exclude .git and build artifacts from the Docker context
* Fix shard_offset UnboundLocalError in MergedColumnParallelLinear for per_Token/per_1x32 quant types
  - weight_loader() only handled the per_1x128 and per_Tensor quant types when computing shard_offset for scale parameters. For the per_Token and per_1x32 quant types (used by DeepSeek-R1 FP8), shard_offset was left undefined, causing UnboundLocalError. Add an else clause with the same shard_offset logic as normal weights.
* Add multi-stage wheel build and Triton MOE fallback
  - Dockerfile.wheels: builder stage that compiles/downloads all wheels (PyTorch ROCm 7.2, Triton 3.5.x, AITER with ENABLE_CK=0)
  - Dockerfile.clean: rewritten for a zero-compilation install from pre-built wheels via bind-mount (37.9GB vs 67.9GB)
  - moe.py: add a Triton MOE fallback for FP8 when the CK sorting kernel is unavailable (CompressedTensorsFp8MoEMethod + Fp8MoEMethod); skip weight shuffle in the Triton path
* Add MORI and FlyDSL wheel builds to the Docker multi-stage build
  Dockerfile.wheels:
  - Add LLVM/MLIR build from ROCm/llvm-project (blobless clone for speed)
  - Add FlyDSL wheel build from ROCm/FlyDSL source
  - Add MORI wheel build from ROCm/mori source
  - Patch Caffe2Config.cmake to work with ROCm nightly torch
  - Filter torch/triton from MORI requirements to preserve ROCm wheels
  Dockerfile.clean:
  - Add openmpi-bin, libopenmpi-dev, libdw1 for the MORI runtime
  - Install mori and flydsl wheels before amd_aiter
  - Add a FlyDSL import check; use pip show for MORI (segfaults without GPU)
- MOE: handle fused shared expert top_k mismatch (actual_top_k = topk_ids.numel() // M)
- MLA prefill: replace flash_attn_varlen_func with PyTorch SDPA (no CK dependency)
- MHA prefill: always dispatch to prefill_attention_triton (layer 0 uses MHA)
- aiter_mla: populate paged KV metadata (kv_indptr, kv_indices) in prepare_prefill
The Triton MOE path (triton_kernels matmul_ogs + routing) was gated by gfx94* prefix check, excluding gfx950. Enable it on both gfx942 and gfx950 with graceful triton_kernels availability check. Also add CK-unavailable fallback: when CK MOE sorting is missing (e.g. ENABLE_CK=0 builds), automatically fall back to Triton if available.
Pull request overview
This PR primarily aims to enable/extend Triton-based MoE execution (including CK-unavailable fallback behavior) for ROCm GPUs, targeting MXFP4 models on gfx950 (MI355X), while also adding container/CI automation changes and modifying attention prefill implementations.
Changes:
- Extend MoE backend selection to support gfx950 and add fallback logic to use Triton when CK MoE sorting is unavailable.
- Add wheel-only Docker build flows (builder + clean install) and update the existing Dockerfile to support ENABLE_CK=0 AITER builds.
- Modify attention prefill paths (MLA + MHA) to avoid CK/flash-attn varlen dependencies.
Reviewed changes
Copilot reviewed 12 out of 12 changed files in this pull request and generated 8 comments.
| File | Description |
|---|---|
| atom/model_ops/moe.py | Extends Triton MoE enablement and introduces CK-missing fallback paths (plus a Triton FP8 MoE helper). |
| atom/model_ops/attention_mla.py | Replaces varlen flash attention prefill with a per-sequence PyTorch SDPA loop. |
| atom/model_ops/attention_mha.py | Forces the Triton prefill path and constructs fake block tables for unified_attention-based prefill. |
| atom/model_ops/attentions/aiter_mla.py | Prepares paged-KV metadata needed by MLA prefill paths. |
| atom/model_ops/linear.py | Fixes shard offset/size handling for scale tensors for additional quant types. |
| atom/model_engine/scheduler.py | Minor internal variable initialization related to num_rejected. |
| tests/test_scheduler.py | Updates tests to match the ScheduledBatchOutput constructor signature (adds num_rejected). |
| docker/Dockerfile | Adds the ENABLE_CK build arg and installs triton_kernels. |
| docker/Dockerfile.wheels | New: builder image that produces a wheel bundle for "clean" installs. |
| docker/Dockerfile.clean | New: installs the runtime from prebuilt wheels (no source compilation). |
| .github/workflows/sync-upstream.yaml | New scheduled workflow to auto-merge upstream main into fork main. |
| .dockerignore | New Docker ignore rules to reduce build context noise. |
```python
self.use_triton = get_gfx() in ("gfx942", "gfx950") and has_triton_kernels()
if not self.use_triton and not _has_ck_moe_sorting():
    if has_triton_kernels():
        self.use_triton = True
        _moe_logger.info(
            "CK MOE sorting not available, "
            "using Triton MOE kernels for MXFP4"
        )
```
In the CK-unavailable fallback, self.use_triton can be set to True solely based on has_triton_kernels() even when get_gfx() is not in the supported whitelist. This can unintentionally enable the MXFP4 Triton path on unsupported GPUs when CK is missing. Keep the same arch gating in the fallback (or explicitly verify the current gfx is supported) before switching to Triton.
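The reviewer's fix can be sketched as follows. This is a minimal illustration, not the PR's actual code: the `mxfp4` flag stands in for whatever condition selects the primary Triton path, and the boolean parameters stand in for the `has_triton_kernels()` / `_has_ck_moe_sorting()` helpers from the diff.

```python
SUPPORTED_TRITON_ARCHS = ("gfx942", "gfx950")

def should_use_triton_moe(
    gfx: str,
    mxfp4: bool,                 # stands in for the primary-path condition
    triton_available: bool,      # stands in for has_triton_kernels()
    ck_sorting_available: bool,  # stands in for _has_ck_moe_sorting()
) -> bool:
    # Primary path: MXFP4 on a whitelisted arch with triton_kernels present.
    use_triton = mxfp4 and gfx in SUPPORTED_TRITON_ARCHS and triton_available
    # CK-unavailable fallback: the reviewer's point is to keep the same arch
    # gate here rather than trusting triton availability alone.
    if not use_triton and not ck_sorting_available:
        use_triton = triton_available and gfx in SUPPORTED_TRITON_ARCHS
    return use_triton

# An unsupported GPU stays off the Triton path even when CK is missing.
assert should_use_triton_moe("gfx90a", False, True, False) is False
# A whitelisted GPU falls back to Triton when CK sorting is unavailable.
assert should_use_triton_moe("gfx942", False, True, False) is True
```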
```diff
@@ -370,7 +370,19 @@ def prefill_attention_triton(
 if ctx.is_prefill:
     k_cache = k.unsqueeze(1)
     v_cache = v.unsqueeze(1)
```
unified_attention expects the V cache to be in the [num_blocks, num_kv_heads, head_dim, block_size] layout (consistent with how v_cache is created in ModelRunner.allocate_kv_cache). For the prefill block_size=1 case, v_cache should therefore be v.unsqueeze(-1) rather than v.unsqueeze(1), otherwise the value layout is incorrect.
Suggested change:
```diff
-v_cache = v.unsqueeze(1)
+v_cache = v.unsqueeze(-1)
```
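The layout point can be checked without a GPU by mimicking `unsqueeze` on plain shape tuples. `unsqueeze_shape` is an illustrative helper, not a library function, and the shape values are made up:

```python
def unsqueeze_shape(shape: tuple, dim: int) -> tuple:
    # Mimic torch.Tensor.unsqueeze: insert a size-1 axis at `dim`
    # (negative dims count from the end of the *output* shape).
    if dim < 0:
        dim = len(shape) + dim + 1
    return shape[:dim] + (1,) + shape[dim:]

# v is packed as [num_tokens, num_kv_heads, head_dim]
v_shape = (512, 8, 128)
# unsqueeze(1) -> [num_tokens, 1, num_kv_heads, head_dim]: wrong V-cache layout
assert unsqueeze_shape(v_shape, 1) == (512, 1, 8, 128)
# unsqueeze(-1) -> [num_tokens, num_kv_heads, head_dim, 1], matching the
# expected [num_blocks, num_kv_heads, head_dim, block_size] with block_size=1
assert unsqueeze_shape(v_shape, -1) == (512, 8, 128, 1)
```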
```python
block_tables = torch.zeros(
    batch_size, max_len, dtype=torch.int32, device=q.device
)
for i in range(batch_size):
    s = attn_metadata.cu_seqlens_k[i].item()
    e = attn_metadata.cu_seqlens_k[i + 1].item()
    block_tables[i, : e - s] = torch.arange(
        s, e, dtype=torch.int32, device=q.device
    )
```
Constructing block_tables for every prefill call allocates a potentially large [batch_size, max_seqlen_k] tensor and fills it in a Python loop, which is likely to be a significant prefill-time overhead. Consider generating this table once in metadata preparation (or caching it on attn_metadata), and/or vectorizing the fill to avoid per-sequence torch.arange in Python.
Suggested change:
```diff
-block_tables = torch.zeros(
-    batch_size, max_len, dtype=torch.int32, device=q.device
-)
-for i in range(batch_size):
-    s = attn_metadata.cu_seqlens_k[i].item()
-    e = attn_metadata.cu_seqlens_k[i + 1].item()
-    block_tables[i, : e - s] = torch.arange(
-        s, e, dtype=torch.int32, device=q.device
-    )
+# Vectorized construction of block_tables to avoid Python loop
+cu_seqlens_k = attn_metadata.cu_seqlens_k.to(device=q.device)
+starts = cu_seqlens_k[:-1].to(dtype=torch.int32)  # [batch_size]
+lengths = (cu_seqlens_k[1:] - cu_seqlens_k[:-1]).to(dtype=torch.int32)
+positions = torch.arange(
+    max_len, dtype=torch.int32, device=q.device
+)  # [max_len]
+# Broadcast to [batch_size, max_len]
+start_grid = starts.unsqueeze(1)
+pos_grid = positions.unsqueeze(0)
+indices = start_grid + pos_grid
+valid_mask = pos_grid < lengths.unsqueeze(1)
+block_tables = torch.where(
+    valid_mask, indices, torch.zeros_like(indices)
+)
```
```python
# Always use Triton prefill (no CK/flash_attn_varlen_func dependency)
return self.prefill_attention_triton
```
dispatch_backend() now always routes prefill through prefill_attention_triton, which makes prefill_attention() (the aiter.flash_attn_varlen_func path) unused. If this change is only intended as a CK-unavailable fallback, it should be conditional (e.g., only use the Triton prefill path when the required kernels are present, otherwise keep the existing varlen flash attention path).
Suggested change:
```diff
-# Always use Triton prefill (no CK/flash_attn_varlen_func dependency)
-return self.prefill_attention_triton
+# Use Triton prefill when Triton attention is enabled; otherwise, use varlen flash attention
+if self.use_triton_attn:
+    return self.prefill_attention_triton
+return self.prefill_attention
```
```python
# Use PyTorch SDPA for MLA prefill attention (no CK dependency)
import torch.nn.functional as F

cu_q = attn_metadata.cu_seqlens_q
cu_k = attn_metadata.cu_seqlens_k
num_seqs = cu_q.shape[0] - 1
outputs = []
for i in range(num_seqs):
    qi = q[cu_q[i] : cu_q[i + 1]].transpose(0, 1).unsqueeze(0)
    ki = k[cu_k[i] : cu_k[i + 1]].transpose(0, 1).unsqueeze(0)
    vi = v[cu_k[i] : cu_k[i + 1]].transpose(0, 1).unsqueeze(0)
    oi = F.scaled_dot_product_attention(
        qi, ki, vi, is_causal=True, scale=self.scale
    )
    outputs.append(oi.squeeze(0).transpose(0, 1))
output = torch.cat(outputs, dim=0)
```
This replaces a single flash_attn_varlen_func call with a Python loop over sequences and per-sequence scaled_dot_product_attention calls, which will scale poorly with batch size and likely be a major prefill performance regression. Consider keeping flash_attn_varlen_func as the fast path and only falling back to the SDPA loop when the varlen kernel is unavailable.
Suggested change:
```diff
-# Use PyTorch SDPA for MLA prefill attention (no CK dependency)
-import torch.nn.functional as F
-cu_q = attn_metadata.cu_seqlens_q
-cu_k = attn_metadata.cu_seqlens_k
-num_seqs = cu_q.shape[0] - 1
-outputs = []
-for i in range(num_seqs):
-    qi = q[cu_q[i] : cu_q[i + 1]].transpose(0, 1).unsqueeze(0)
-    ki = k[cu_k[i] : cu_k[i + 1]].transpose(0, 1).unsqueeze(0)
-    vi = v[cu_k[i] : cu_k[i + 1]].transpose(0, 1).unsqueeze(0)
-    oi = F.scaled_dot_product_attention(
-        qi, ki, vi, is_causal=True, scale=self.scale
-    )
-    outputs.append(oi.squeeze(0).transpose(0, 1))
-output = torch.cat(outputs, dim=0)
+# Prefer FlashAttention varlen kernel for MLA prefill; fall back to PyTorch SDPA if unavailable.
+cu_q = attn_metadata.cu_seqlens_q
+cu_k = attn_metadata.cu_seqlens_k
+try:
+    # flash_attn_varlen_func expects [total_tokens, n_heads, head_dim] tensors with varlen metadata.
+    output = flash_attn_varlen_func(
+        q,
+        k,
+        v,
+        cu_q,
+        cu_k,
+        attn_metadata.max_seqlen_q,
+        attn_metadata.max_seqlen_k,
+        0.0,  # dropout_p
+        softmax_scale=self.scale,
+        causal=True,
+    )
+except Exception:
+    # Fallback: per-sequence PyTorch SDPA (slower, but no specialized kernel required).
+    import torch.nn.functional as F
+    num_seqs = cu_q.shape[0] - 1
+    outputs = []
+    for i in range(num_seqs):
+        qi = q[cu_q[i] : cu_q[i + 1]].transpose(0, 1).unsqueeze(0)
+        ki = k[cu_k[i] : cu_k[i + 1]].transpose(0, 1).unsqueeze(0)
+        vi = v[cu_k[i] : cu_k[i + 1]].transpose(0, 1).unsqueeze(0)
+        oi = F.scaled_dot_product_attention(
+            qi, ki, vi, is_causal=True, scale=self.scale
+        )
+        outputs.append(oi.squeeze(0).transpose(0, 1))
+    output = torch.cat(outputs, dim=0)
```
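For readers unfamiliar with the varlen layout both paths consume: tokens from all sequences are packed into one tensor, and `cu_seqlens` holds exclusive prefix sums of the per-sequence lengths, so sequence `i` occupies `packed[cu[i]:cu[i+1]]`. A toy illustration with made-up values:

```python
# Packed tokens for two sequences of lengths 3 and 2.
packed = ["a0", "a1", "a2", "b0", "b1"]
cu_seqlens = [0, 3, 5]  # exclusive prefix sums of [3, 2]

# Recover each sequence by slicing between consecutive prefix sums,
# exactly as the SDPA loop above slices q/k/v with cu_q and cu_k.
seqs = [
    packed[cu_seqlens[i] : cu_seqlens[i + 1]]
    for i in range(len(cu_seqlens) - 1)
]
assert seqs == [["a0", "a1", "a2"], ["b0", "b1"]]
```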
```yaml
jobs:
  sync:
    runs-on: ubuntu-latest
```
This workflow pushes to main using GITHUB_TOKEN, but no explicit permissions are set. On many repos the default token permissions are read-only, so the push will fail. Add permissions: contents: write (workflow- or job-level) so the scheduled sync can push merges.
Suggested change:
```diff
 runs-on: ubuntu-latest
+permissions:
+  contents: write
```
```python
# Use PyTorch SDPA for MLA prefill attention (no CK dependency)
import torch.nn.functional as F

cu_q = attn_metadata.cu_seqlens_q
cu_k = attn_metadata.cu_seqlens_k
num_seqs = cu_q.shape[0] - 1
```
The PR description focuses on enabling Triton MoE on gfx950 / CK-unavailable fallback for MoE, but this diff also changes the MLA prefill attention implementation (switching kernels/backends). Please update the PR description (and test plan) to cover this additional behavioral/performance change, or split it into a separate PR if it’s not required for the MoE enablement.
```dockerfile
RUN git clone --depth=1 --branch ${AITER_BRANCH} ${AITER_REPO} /build/aiter

# Set AITER build env for all subsequent commands in this layer
RUN cd /build/aiter \
    && pip3 install --break-system-packages -r requirements.txt \
    && export ENABLE_CK=0 PREBUILD_TRITON=${PREBUILD_TRITON} \
       PREBUILD_TRITON_ARCHS="gfx942,gfx950" \
       MAX_JOBS=${MAX_JOBS} GPU_ARCHS=${GPU_ARCH_LIST} \
    && pip3 install --break-system-packages --no-build-isolation -e . \
    && python3 -c "import aiter; print('editable install OK')" \
    && echo "install" > aiter/install_mode \
    && python3 setup.py bdist_wheel \
```
docker/Dockerfile.wheels pulls aiter from the third-party repo https://github.com/sunway513/aiter.git at a mutable branch (AITER_BRANCH="feat/prebuild-triton"), then builds and packages it into the amd_aiter wheel used by downstream images. If this fork or branch is ever compromised, malicious code will be fetched and executed during image build and at runtime with whatever privileges and data access the ATOM container has. To reduce supply-chain risk, fetch aiter from the official ROCm repo or pin this dependency to an immutable identifier (e.g., a specific commit hash or a signed, versioned release) and apply integrity verification where possible.
- Rename GFX950MXScaleLayout to CDNA4MXScaleLayout to match upstream triton_kernels (triton-lang/triton release/3.5.x)
- Add a block_m=128 constraint for gfx950 to avoid LDS overflow (the default CDNA4 block_m=256 needs 162KB; the MI355X limit is 160KB)
Summary
- Enable the Triton MOE path (`triton_kernels` matmul_ogs + routing) on gfx950 (MI355X) for MXFP4 models like gpt-oss-120b
- Change the gate from `get_gfx().startswith("gfx94")` to an explicit `get_gfx() in ("gfx942", "gfx950")` with a graceful `has_triton_kernels()` check

Context
gpt-oss-120b (128 experts, Swiglu, MXFP4/per_1x32) was forced through the CK-Tile 2-stage MOE path on MI355X because the Triton path was gated to `gfx94*` only. The Triton path via `triton_kernels` already supports Swiglu activation, MXFP4 with `GFX950MXScaleLayout`, and bypasses AITER's fused_moe entirely (no CK dependency).

Test plan
- Verify `triton_kernels` is installed: `python -c "import triton_kernels; print('OK')"`
- Launch the server: `python -m atom.entrypoints.openai_server --model /models/openai/gpt-oss-120b --kv_cache_dtype fp8`
- Confirm the log shows `"using Triton MOE kernels for MXFP4"` (NOT CK-Tile)