Enable Triton MOE for MXFP4 on gfx950 (MI355X)#226

Draft
sunway513 wants to merge 10 commits into ROCm:main from sunway513:feat/triton-moe-gfx950

Conversation

@sunway513 (Collaborator)

Summary

  • Enable the existing Triton MOE path (triton_kernels matmul_ogs + routing) on gfx950 (MI355X) for MXFP4 models like gpt-oss-120b
  • Change gate from get_gfx().startswith("gfx94") to explicit get_gfx() in ("gfx942", "gfx950") with graceful has_triton_kernels() check
  • Add CK-unavailable fallback: when CK MOE sorting is missing (ENABLE_CK=0 builds), automatically fall back to Triton
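
The gate change and fallback described above can be sketched as backend-selection logic. This is a minimal, torch-free sketch: the helper names `get_gfx`, `has_triton_kernels`, and `_has_ck_moe_sorting` come from `moe.py`, but the standalone function wrapper and string return value are purely illustrative.

```python
def select_moe_backend(gfx: str, triton_available: bool, ck_sorting_available: bool) -> str:
    """Sketch of the MOE backend gate described in the PR summary (illustrative)."""
    # Explicit whitelist replaces the old get_gfx().startswith("gfx94") check,
    # so gfx950 (MI355X) is now included alongside gfx942.
    use_triton = gfx in ("gfx942", "gfx950") and triton_available
    # CK-unavailable fallback: ENABLE_CK=0 builds have no CK MOE sorting,
    # so fall back to Triton when it is importable.
    if not use_triton and not ck_sorting_available and triton_available:
        use_triton = True
    return "triton" if use_triton else "ck"
```

Note that, as written, the fallback branch can select Triton even on architectures outside the whitelist; whether that is desirable depends on kernel support for the running GPU.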

Context

gpt-oss-120b (128 experts, Swiglu, MXFP4/per_1x32) was forced through the CK-Tile 2-stage MOE path on MI355X because the Triton path was gated to gfx94* only. The Triton path via triton_kernels already supports Swiglu activation and MXFP4 with GFX950MXScaleLayout, and it bypasses AITER's fused_moe entirely (no CK dependency).

Test plan

  • Verify triton_kernels installed: python -c "import triton_kernels; print('OK')"
  • Launch gpt-oss-120b on MI355X: python -m atom.entrypoints.openai_server --model /models/openai/gpt-oss-120b --kv_cache_dtype fp8
  • Check logs for "using Triton MOE kernels for MXFP4" (NOT CK-Tile)
  • Run inference test with simple_inference
  • Benchmark 1k/1k ISL/OSL concurrency 128, compare TTFT/TPOT against CK baseline

sunway513 and others added 9 commits February 7, 2026 21:28
Create 8 documentation guides covering all major ATOM subsystems:
- Architecture & Design (request lifecycle, component diagram)
- Configuration & CLI Reference (all config classes, env vars)
- Model Support (8 architectures, weight loading, adding new models)
- Model Operations & AITER Integration (kernel mapping, fused ops)
- Scheduling & KV Cache (prefill-first scheduler, block manager, prefix caching)
- Compilation & CUDA Graphs (4 levels, 5 modes, piecewise compilation)
- Distributed Inference (TP, DP, EP with MORI all-to-all)
- Serving & Benchmarks (OpenAI server, profiling, MTP speculative decoding)

All guides are fact-checked against the codebase. README updated with
expanded features (OpenAI API, quantization, multi-GPU, speculative
decoding, prefix caching), supported models table, documentation links,
and improved section structure.
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
1. Route prefill to prefill_attention_triton when use_triton_attn=True
   (models with head_dim!=128 or sliding_window). Previously prefill
   always used CK-based flash_attn_varlen_func, which fails when
   AITER is built with ENABLE_CK=0.

2. Create fake_block_tables inline from cu_seqlens_k in
   prefill_attention_triton. The attn_metadata.fake_block_tables field
   was never populated, causing NoneType stride errors.

3. Dockerfile: add ENABLE_CK build arg for Triton-only AITER builds,
   install triton_kernels package (required for MXFP4 MoE on gfx94x),
   and conditionally skip CK submodule init when ENABLE_CK=0.

Tested with GPT-OSS-120B (head_dim=64, MXFP4 MoE, sliding_window=128)
on MI300X using ENABLE_CK=0 Docker image.
Enable Triton-only prefill attention and ENABLE_CK Docker support
* Add Dockerfile.clean for minimal ATOM build from public sources

- Dockerfile.clean: clean Docker build using rocm/dev-ubuntu-22.04:7.2-complete
  base, PyTorch nightly ROCm 7.2 wheel, ROCm Triton 3.5.x from source
  (replaces incompatible Triton 3.6.0), and AITER Triton-only build (ENABLE_CK=0)
- Fix scheduler.py: initialize num_rejected=0 before speculative-decode branch
  to prevent UnboundLocalError in non-speculative path (regression from ROCm#219)
- Fix test_scheduler.py: add required num_rejected param to ScheduledBatchOutput
- Add .dockerignore to exclude .git and build artifacts from Docker context

* Fix shard_offset UnboundLocalError in MergedColumnParallelLinear for per_Token/per_1x32 quant types

weight_loader() only handled the per_1x128 and per_Tensor quant types when computing
shard_offset for scale parameters. For the per_Token and per_1x32 quant types (used by
DeepSeek-R1 FP8), shard_offset was left undefined, causing an UnboundLocalError.

Add else clause with same shard_offset logic as normal weights.
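
A hypothetical sketch of the fix described above, assuming a helper that mirrors weight_loader()'s offset computation. The quant-type names come from the commit message; the specific offset formulas are illustrative, not the actual ones in `linear.py`.

```python
def scale_shard_offset(quant_type: str, shard_id: int, shard_size: int, block_size: int = 128) -> int:
    """Illustrative: ensure shard_offset is bound for every quant type."""
    if quant_type == "per_1x128":
        # Blockwise scales: one scale per block_size weights (illustrative formula).
        shard_offset = shard_id * shard_size // block_size
    elif quant_type == "per_Tensor":
        # One scale per shard (illustrative formula).
        shard_offset = shard_id
    else:
        # per_Token / per_1x32: same offset logic as normal weights, so
        # shard_offset can no longer be left unbound -- this else clause is the fix.
        shard_offset = shard_id * shard_size
    return shard_offset
```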

* Add multi-stage wheel build and Triton MOE fallback

- Dockerfile.wheels: builder stage that compiles/downloads all wheels
  (PyTorch ROCm 7.2, Triton 3.5.x, AITER with ENABLE_CK=0)
- Dockerfile.clean: rewritten for zero-compilation install from pre-built
  wheels via bind-mount (37.9GB vs 67.9GB)
- moe.py: add Triton MOE fallback for FP8 when CK sorting kernel is
  unavailable (CompressedTensorsFp8MoEMethod + Fp8MoEMethod), skip
  weight shuffle in Triton path

* Add MORI and FlyDSL wheel builds to Docker multi-stage build

Dockerfile.wheels:
- Add LLVM/MLIR build from ROCm/llvm-project (blobless clone for speed)
- Add FlyDSL wheel build from ROCm/FlyDSL source
- Add MORI wheel build from ROCm/mori source
- Patch Caffe2Config.cmake to work with ROCm nightly torch
- Filter torch/triton from MORI requirements to preserve ROCm wheels

Dockerfile.clean:
- Add openmpi-bin, libopenmpi-dev, libdw1 for MORI runtime
- Install mori and flydsl wheels before amd_aiter
- Add FlyDSL import check; use pip show for MORI (segfaults without GPU)
- MOE: handle fused shared expert top_k mismatch (actual_top_k = topk_ids.numel() // M)
- MLA prefill: replace flash_attn_varlen_func with PyTorch SDPA (no CK dependency)
- MHA prefill: always dispatch to prefill_attention_triton (layer 0 uses MHA)
- aiter_mla: populate paged KV metadata (kv_indptr, kv_indices) in prepare_prefill
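
For the fused shared-expert bullet above, the effective top_k is recovered from tensor sizes. A tiny sketch: only the `topk_ids.numel() // M` expression comes from the commit message, the wrapper function is hypothetical.

```python
def effective_top_k(topk_ids_numel: int, num_tokens: int) -> int:
    # topk_ids flattens from shape [M, actual_top_k]; with shared experts fused
    # in, actual_top_k can exceed the router's configured top_k, so recover it
    # from the element count instead of trusting the config.
    assert topk_ids_numel % num_tokens == 0, "topk_ids size must be a multiple of M"
    return topk_ids_numel // num_tokens
```

For example, 4 tokens with top_k=2 plus one fused shared expert yield 12 entries, giving an effective top_k of 3.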
The Triton MOE path (triton_kernels matmul_ogs + routing) was gated by
gfx94* prefix check, excluding gfx950. Enable it on both gfx942 and
gfx950 with graceful triton_kernels availability check.

Also add CK-unavailable fallback: when CK MOE sorting is missing (e.g.
ENABLE_CK=0 builds), automatically fall back to Triton if available.
Copilot AI review requested due to automatic review settings February 20, 2026 05:38
Copilot AI (Contributor) left a comment

Pull request overview

This PR primarily aims to enable/extend Triton-based MoE execution (including CK-unavailable fallback behavior) for ROCm GPUs, targeting MXFP4 models on gfx950 (MI355X), while also adding container/CI automation changes and modifying attention prefill implementations.

Changes:

  • Extend MoE backend selection to support gfx950 and add fallback logic to use Triton when CK MoE sorting is unavailable.
  • Add wheel-only Docker build flows (builder + clean install) and update the existing Dockerfile to support ENABLE_CK=0 AITER builds.
  • Modify attention prefill paths (MLA + MHA) to avoid CK/flash-attn varlen dependencies.

Reviewed changes

Copilot reviewed 12 out of 12 changed files in this pull request and generated 8 comments.

Summary per file:

  • atom/model_ops/moe.py: Extends Triton MoE enablement and introduces CK-missing fallback paths (plus a Triton FP8 MoE helper).
  • atom/model_ops/attention_mla.py: Replaces varlen flash attention prefill with a per-sequence PyTorch SDPA loop.
  • atom/model_ops/attention_mha.py: Forces Triton prefill path and constructs fake block tables for unified_attention-based prefill.
  • atom/model_ops/attentions/aiter_mla.py: Prepares paged-KV metadata needed by MLA prefill paths.
  • atom/model_ops/linear.py: Fixes shard offset/size handling for scale tensors for additional quant types.
  • atom/model_engine/scheduler.py: Minor internal variable initialization related to num_rejected.
  • tests/test_scheduler.py: Updates tests to match the ScheduledBatchOutput constructor signature (adds num_rejected).
  • docker/Dockerfile: Adds ENABLE_CK build arg and installs triton_kernels.
  • docker/Dockerfile.wheels: New builder image that produces a wheel bundle for "clean" installs.
  • docker/Dockerfile.clean: New image that installs the runtime from prebuilt wheels (no source compilation).
  • .github/workflows/sync-upstream.yaml: New scheduled workflow to auto-merge upstream main into fork main.
  • .dockerignore: New docker ignore rules to reduce build context noise.


Comment on lines +839 to +846
self.use_triton = get_gfx() in ("gfx942", "gfx950") and has_triton_kernels()
if not self.use_triton and not _has_ck_moe_sorting():
    if has_triton_kernels():
        self.use_triton = True
        _moe_logger.info(
            "CK MOE sorting not available, "
            "using Triton MOE kernels for MXFP4"
        )
Copilot AI Feb 20, 2026

In the CK-unavailable fallback, self.use_triton can be set to True solely based on has_triton_kernels() even when get_gfx() is not in the supported whitelist. This can unintentionally enable the MXFP4 Triton path on unsupported GPUs when CK is missing. Keep the same arch gating in the fallback (or explicitly verify the current gfx is supported) before switching to Triton.
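
The arch-gated fallback this comment asks for could look like the following sketch. The whitelist and helper semantics are assumed from the quoted diff; the error path is illustrative, not the actual code.

```python
SUPPORTED_GFX = ("gfx942", "gfx950")  # assumed whitelist from the quoted diff

def select_backend_gated(gfx: str, triton_available: bool, ck_sorting_available: bool) -> str:
    """Illustrative: fallback re-checks the arch whitelist before enabling Triton."""
    use_triton = gfx in SUPPORTED_GFX and triton_available
    if not use_triton and not ck_sorting_available:
        # Keep the same arch gating in the fallback instead of checking
        # has_triton_kernels() alone, so unsupported GPUs never get the
        # MXFP4 Triton path silently enabled.
        if triton_available and gfx in SUPPORTED_GFX:
            use_triton = True
        else:
            raise RuntimeError(
                f"CK MOE sorting unavailable and Triton MOE unsupported on {gfx}"
            )
    return "triton" if use_triton else "ck"
```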

Copilot uses AI. Check for mistakes.
@@ -370,7 +370,19 @@ def prefill_attention_triton(
    if ctx.is_prefill:
        k_cache = k.unsqueeze(1)
        v_cache = v.unsqueeze(1)
Copilot AI Feb 20, 2026

unified_attention expects the V cache to be in the [num_blocks, num_kv_heads, head_dim, block_size] layout (consistent with how v_cache is created in ModelRunner.allocate_kv_cache). For the prefill block_size=1 case, v_cache should therefore be v.unsqueeze(-1) rather than v.unsqueeze(1), otherwise the value layout is incorrect.

Suggested change
-        v_cache = v.unsqueeze(1)
+        v_cache = v.unsqueeze(-1)

Comment on lines +377 to +385
block_tables = torch.zeros(
    batch_size, max_len, dtype=torch.int32, device=q.device
)
for i in range(batch_size):
    s = attn_metadata.cu_seqlens_k[i].item()
    e = attn_metadata.cu_seqlens_k[i + 1].item()
    block_tables[i, : e - s] = torch.arange(
        s, e, dtype=torch.int32, device=q.device
    )
Copilot AI Feb 20, 2026

Constructing block_tables for every prefill call allocates a potentially large [batch_size, max_seqlen_k] tensor and fills it in a Python loop, which is likely to be a significant prefill-time overhead. Consider generating this table once in metadata preparation (or caching it on attn_metadata), and/or vectorizing the fill to avoid per-sequence torch.arange in Python.

Suggested change
-block_tables = torch.zeros(
-    batch_size, max_len, dtype=torch.int32, device=q.device
-)
-for i in range(batch_size):
-    s = attn_metadata.cu_seqlens_k[i].item()
-    e = attn_metadata.cu_seqlens_k[i + 1].item()
-    block_tables[i, : e - s] = torch.arange(
-        s, e, dtype=torch.int32, device=q.device
-    )
+# Vectorized construction of block_tables to avoid Python loop
+cu_seqlens_k = attn_metadata.cu_seqlens_k.to(device=q.device)
+starts = cu_seqlens_k[:-1].to(dtype=torch.int32)  # [batch_size]
+lengths = (cu_seqlens_k[1:] - cu_seqlens_k[:-1]).to(dtype=torch.int32)
+positions = torch.arange(
+    max_len, dtype=torch.int32, device=q.device
+)  # [max_len]
+# Broadcast to [batch_size, max_len]
+start_grid = starts.unsqueeze(1)
+pos_grid = positions.unsqueeze(0)
+indices = start_grid + pos_grid
+valid_mask = pos_grid < lengths.unsqueeze(1)
+block_tables = torch.where(
+    valid_mask, indices, torch.zeros_like(indices)
+)
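
To see that the loop and vectorized forms agree, the indexing logic can be modeled torch-free; pure-Python lists stand in for tensors here, so this is illustrative only.

```python
def block_tables_loop(cu_seqlens_k, max_len):
    # Mirrors the Python-loop version: row i holds block ids s..e-1, zero-padded.
    batch = len(cu_seqlens_k) - 1
    table = [[0] * max_len for _ in range(batch)]
    for i in range(batch):
        s, e = cu_seqlens_k[i], cu_seqlens_k[i + 1]
        for j, blk in enumerate(range(s, e)):
            table[i][j] = blk
    return table

def block_tables_vectorized(cu_seqlens_k, max_len):
    # Mirrors the broadcast: index = start + position, masked by position < length.
    starts = cu_seqlens_k[:-1]
    lengths = [e - s for s, e in zip(cu_seqlens_k[:-1], cu_seqlens_k[1:])]
    return [
        [s + p if p < l else 0 for p in range(max_len)]
        for s, l in zip(starts, lengths)
    ]
```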

Comment on lines +422 to +423
        # Always use Triton prefill (no CK/flash_attn_varlen_func dependency)
        return self.prefill_attention_triton
Copilot AI Feb 20, 2026

dispatch_backend() now always routes prefill through prefill_attention_triton, which makes prefill_attention() (the aiter.flash_attn_varlen_func path) unused. If this change is only intended as a CK-unavailable fallback, it should be conditional (e.g., only use the Triton prefill path when the required kernels are present, otherwise keep the existing varlen flash attention path).

Suggested change
-# Always use Triton prefill (no CK/flash_attn_varlen_func dependency)
-return self.prefill_attention_triton
+# Use Triton prefill when Triton attention is enabled; otherwise, use varlen flash attention
+if self.use_triton_attn:
+    return self.prefill_attention_triton
+return self.prefill_attention

Comment on lines +399 to 415
# Use PyTorch SDPA for MLA prefill attention (no CK dependency)
import torch.nn.functional as F

cu_q = attn_metadata.cu_seqlens_q
cu_k = attn_metadata.cu_seqlens_k
num_seqs = cu_q.shape[0] - 1
outputs = []
for i in range(num_seqs):
    qi = q[cu_q[i] : cu_q[i + 1]].transpose(0, 1).unsqueeze(0)
    ki = k[cu_k[i] : cu_k[i + 1]].transpose(0, 1).unsqueeze(0)
    vi = v[cu_k[i] : cu_k[i + 1]].transpose(0, 1).unsqueeze(0)
    oi = F.scaled_dot_product_attention(
        qi, ki, vi, is_causal=True, scale=self.scale
    )
    outputs.append(oi.squeeze(0).transpose(0, 1))
output = torch.cat(outputs, dim=0)

Copilot AI Feb 20, 2026

This replaces a single flash_attn_varlen_func call with a Python loop over sequences and per-sequence scaled_dot_product_attention calls, which will scale poorly with batch size and likely be a major prefill performance regression. Consider keeping flash_attn_varlen_func as the fast path and only falling back to the SDPA loop when the varlen kernel is unavailable.

Suggested change
-# Use PyTorch SDPA for MLA prefill attention (no CK dependency)
-import torch.nn.functional as F
-
-cu_q = attn_metadata.cu_seqlens_q
-cu_k = attn_metadata.cu_seqlens_k
-num_seqs = cu_q.shape[0] - 1
-outputs = []
-for i in range(num_seqs):
-    qi = q[cu_q[i] : cu_q[i + 1]].transpose(0, 1).unsqueeze(0)
-    ki = k[cu_k[i] : cu_k[i + 1]].transpose(0, 1).unsqueeze(0)
-    vi = v[cu_k[i] : cu_k[i + 1]].transpose(0, 1).unsqueeze(0)
-    oi = F.scaled_dot_product_attention(
-        qi, ki, vi, is_causal=True, scale=self.scale
-    )
-    outputs.append(oi.squeeze(0).transpose(0, 1))
-output = torch.cat(outputs, dim=0)
+# Prefer FlashAttention varlen kernel for MLA prefill; fall back to PyTorch SDPA if unavailable.
+cu_q = attn_metadata.cu_seqlens_q
+cu_k = attn_metadata.cu_seqlens_k
+try:
+    # flash_attn_varlen_func expects [total_tokens, n_heads, head_dim] tensors with varlen metadata.
+    output = flash_attn_varlen_func(
+        q,
+        k,
+        v,
+        cu_q,
+        cu_k,
+        attn_metadata.max_seqlen_q,
+        attn_metadata.max_seqlen_k,
+        0.0,  # dropout_p
+        softmax_scale=self.scale,
+        causal=True,
+    )
+except Exception:
+    # Fallback: per-sequence PyTorch SDPA (slower, but no specialized kernel required).
+    import torch.nn.functional as F
+    num_seqs = cu_q.shape[0] - 1
+    outputs = []
+    for i in range(num_seqs):
+        qi = q[cu_q[i] : cu_q[i + 1]].transpose(0, 1).unsqueeze(0)
+        ki = k[cu_k[i] : cu_k[i + 1]].transpose(0, 1).unsqueeze(0)
+        vi = v[cu_k[i] : cu_k[i + 1]].transpose(0, 1).unsqueeze(0)
+        oi = F.scaled_dot_product_attention(
+            qi, ki, vi, is_causal=True, scale=self.scale
+        )
+        outputs.append(oi.squeeze(0).transpose(0, 1))
+    output = torch.cat(outputs, dim=0)
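
For intuition about what the per-sequence SDPA loop computes, here is a minimal pure-Python model of single-head causal attention over one sequence. It uses no torch and is purely illustrative, not the fallback implementation itself.

```python
import math

def causal_sdpa(q, k, v, scale=1.0):
    """q, k, v: lists of equal-length vectors for one sequence, one head."""
    out = []
    for t, qt in enumerate(q):
        # Causal mask: query t attends only to keys 0..t.
        scores = [scale * sum(a * b for a, b in zip(qt, k[j])) for j in range(t + 1)]
        m = max(scores)  # subtract the max for a numerically stable softmax
        weights = [math.exp(s - m) for s in scores]
        z = sum(weights)
        dim = len(v[0])
        out.append([sum(weights[j] / z * v[j][d] for j in range(t + 1)) for d in range(dim)])
    return out
```

With two tokens whose keys are identical, the second query averages both values, while the first query sees only its own value.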


jobs:
  sync:
    runs-on: ubuntu-latest
Copilot AI Feb 20, 2026

This workflow pushes to main using GITHUB_TOKEN, but no explicit permissions are set. On many repos the default token permissions are read-only, so the push will fail. Add permissions: contents: write (workflow- or job-level) so the scheduled sync can push merges.

Suggested change
-    runs-on: ubuntu-latest
+    runs-on: ubuntu-latest
+    permissions:
+      contents: write

Comment on lines +399 to +404
# Use PyTorch SDPA for MLA prefill attention (no CK dependency)
import torch.nn.functional as F

cu_q = attn_metadata.cu_seqlens_q
cu_k = attn_metadata.cu_seqlens_k
num_seqs = cu_q.shape[0] - 1
Copilot AI Feb 20, 2026

The PR description focuses on enabling Triton MoE on gfx950 / CK-unavailable fallback for MoE, but this diff also changes the MLA prefill attention implementation (switching kernels/backends). Please update the PR description (and test plan) to cover this additional behavioral/performance change, or split it into a separate PR if it’s not required for the MoE enablement.

Comment on lines +140 to +151
RUN git clone --depth=1 --branch ${AITER_BRANCH} ${AITER_REPO} /build/aiter

# Set AITER build env for all subsequent commands in this layer
RUN cd /build/aiter \
    && pip3 install --break-system-packages -r requirements.txt \
    && export ENABLE_CK=0 PREBUILD_TRITON=${PREBUILD_TRITON} \
       PREBUILD_TRITON_ARCHS="gfx942,gfx950" \
       MAX_JOBS=${MAX_JOBS} GPU_ARCHS=${GPU_ARCH_LIST} \
    && pip3 install --break-system-packages --no-build-isolation -e . \
    && python3 -c "import aiter; print('editable install OK')" \
    && echo "install" > aiter/install_mode \
    && python3 setup.py bdist_wheel \
Copilot AI Feb 20, 2026

docker/Dockerfile.wheels pulls aiter from the third-party repo https://github.com/sunway513/aiter.git at a mutable branch (AITER_BRANCH="feat/prebuild-triton"), then builds and packages it into the amd_aiter wheel used by downstream images. If this fork or branch is ever compromised, malicious code will be fetched and executed during image build and at runtime with whatever privileges and data access the ATOM container has. To reduce supply-chain risk, fetch aiter from the official ROCm repo or pin this dependency to an immutable identifier (e.g., a specific commit hash or a signed, versioned release) and apply integrity verification where possible.

- Rename GFX950MXScaleLayout to CDNA4MXScaleLayout to match upstream
  triton_kernels (triton-lang/triton release/3.5.x)
- Add block_m=128 constraint for gfx950 to avoid LDS overflow
  (default CDNA4 block_m=256 needs 162KB, MI355X limit is 160KB)
@sunway513 sunway513 marked this pull request as draft February 20, 2026 16:33