-
Notifications
You must be signed in to change notification settings - Fork 17
Enable Triton MOE for MXFP4 on gfx950 (MI355X) #226
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
c92f2b5
2c12a07
553be85
f37370e
8eeb42d
3c6872d
a3168bd
9c6a7b8
95a5517
0b027c6
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,6 @@ | ||
| .git | ||
| __pycache__ | ||
| *.pyc | ||
| *.egg-info | ||
| build/ | ||
| dist/ |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,42 @@ | ||
| name: Sync upstream main | ||
|
|
||
| on: | ||
| schedule: | ||
| # Run nightly at 06:00 UTC (midnight CST) | ||
| - cron: '0 6 * * *' | ||
| workflow_dispatch: # Allow manual trigger | ||
|
|
||
| jobs: | ||
| sync: | ||
| runs-on: ubuntu-latest | ||
| steps: | ||
| - name: Checkout fork | ||
| uses: actions/checkout@v4 | ||
| with: | ||
| ref: main | ||
| fetch-depth: 0 | ||
| token: ${{ secrets.GITHUB_TOKEN }} | ||
|
|
||
| - name: Add upstream remote | ||
| run: git remote add upstream https://github.com/ROCm/ATOM.git | ||
|
|
||
| - name: Fetch upstream | ||
| run: git fetch upstream main | ||
|
|
||
| - name: Check for new commits | ||
| id: check | ||
| run: | | ||
| BEHIND=$(git rev-list --count HEAD..upstream/main) | ||
| echo "behind=$BEHIND" >> "$GITHUB_OUTPUT" | ||
| echo "Fork is $BEHIND commit(s) behind upstream" | ||
|
|
||
| - name: Merge upstream | ||
| if: steps.check.outputs.behind != '0' | ||
| run: | | ||
| git config user.name "github-actions[bot]" | ||
| git config user.email "github-actions[bot]@users.noreply.github.com" | ||
| git merge upstream/main --no-edit | ||
|
|
||
| - name: Push | ||
| if: steps.check.outputs.behind != '0' | ||
| run: git push origin main | ||
| Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||||||||||||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -370,7 +370,19 @@ def prefill_attention_triton( | |||||||||||||||||||||||||||||||||||||||||||||||||
| if ctx.is_prefill: | ||||||||||||||||||||||||||||||||||||||||||||||||||
| k_cache = k.unsqueeze(1) | ||||||||||||||||||||||||||||||||||||||||||||||||||
| v_cache = v.unsqueeze(1) | ||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||
| v_cache = v.unsqueeze(1) | |
| v_cache = v.unsqueeze(-1) |
Copilot
AI
Feb 20, 2026
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Constructing block_tables for every prefill call allocates a potentially large [batch_size, max_seqlen_k] tensor and fills it in a Python loop, which is likely to be a significant prefill-time overhead. Consider generating this table once in metadata preparation (or caching it on attn_metadata), and/or vectorizing the fill to avoid per-sequence torch.arange in Python.
| block_tables = torch.zeros( | |
| batch_size, max_len, dtype=torch.int32, device=q.device | |
| ) | |
| for i in range(batch_size): | |
| s = attn_metadata.cu_seqlens_k[i].item() | |
| e = attn_metadata.cu_seqlens_k[i + 1].item() | |
| block_tables[i, : e - s] = torch.arange( | |
| s, e, dtype=torch.int32, device=q.device | |
| ) | |
| # Vectorized construction of block_tables to avoid Python loop | |
| cu_seqlens_k = attn_metadata.cu_seqlens_k.to(device=q.device) | |
| starts = cu_seqlens_k[:-1].to(dtype=torch.int32) # [batch_size] | |
| lengths = (cu_seqlens_k[1:] - cu_seqlens_k[:-1]).to(dtype=torch.int32) | |
| positions = torch.arange( | |
| max_len, dtype=torch.int32, device=q.device | |
| ) # [max_len] | |
| # Broadcast to [batch_size, max_len] | |
| start_grid = starts.unsqueeze(1) | |
| pos_grid = positions.unsqueeze(0) | |
| indices = start_grid + pos_grid | |
| valid_mask = pos_grid < lengths.unsqueeze(1) | |
| block_tables = torch.where( | |
| valid_mask, indices, torch.zeros_like(indices) | |
| ) |
Copilot
AI
Feb 20, 2026
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
dispatch_backend() now always routes prefill through prefill_attention_triton, which makes prefill_attention() (the aiter.flash_attn_varlen_func path) unused. If this change is only intended as a CK-unavailable fallback, it should be conditional (e.g., only use the Triton prefill path when the required kernels are present, otherwise keep the existing varlen flash attention path).
| # Always use Triton prefill (no CK/flash_attn_varlen_func dependency) | |
| return self.prefill_attention_triton | |
| # Use Triton prefill when Triton attention is enabled; otherwise, use varlen flash attention | |
| if self.use_triton_attn: | |
| return self.prefill_attention_triton | |
| return self.prefill_attention |
| Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -396,19 +396,22 @@ def _forward_prefill_mha( | |||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| k = torch.cat((k_nope, k_rope.expand((*k_nope.shape[:-1], -1))), dim=-1) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| output = flash_attn_varlen_func( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| q=q, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| k=k, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| v=v, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| cu_seqlens_q=attn_metadata.cu_seqlens_q, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| cu_seqlens_k=attn_metadata.cu_seqlens_k, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| max_seqlen_q=attn_metadata.max_seqlen_q, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| max_seqlen_k=attn_metadata.max_seqlen_k, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| min_seqlen_q=attn_metadata.min_seqlen_q, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| dropout_p=attn_metadata.dropout_p, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| softmax_scale=self.scale, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| causal=True, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| # Use PyTorch SDPA for MLA prefill attention (no CK dependency) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| import torch.nn.functional as F | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| cu_q = attn_metadata.cu_seqlens_q | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| cu_k = attn_metadata.cu_seqlens_k | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| num_seqs = cu_q.shape[0] - 1 | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
Comment on lines
+399
to
+404
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| outputs = [] | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| for i in range(num_seqs): | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| qi = q[cu_q[i] : cu_q[i + 1]].transpose(0, 1).unsqueeze(0) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| ki = k[cu_k[i] : cu_k[i + 1]].transpose(0, 1).unsqueeze(0) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| vi = v[cu_k[i] : cu_k[i + 1]].transpose(0, 1).unsqueeze(0) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| oi = F.scaled_dot_product_attention( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| qi, ki, vi, is_causal=True, scale=self.scale | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| outputs.append(oi.squeeze(0).transpose(0, 1)) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| output = torch.cat(outputs, dim=0) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
Comment on lines
+399
to
415
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| # Use PyTorch SDPA for MLA prefill attention (no CK dependency) | |
| import torch.nn.functional as F | |
| cu_q = attn_metadata.cu_seqlens_q | |
| cu_k = attn_metadata.cu_seqlens_k | |
| num_seqs = cu_q.shape[0] - 1 | |
| outputs = [] | |
| for i in range(num_seqs): | |
| qi = q[cu_q[i] : cu_q[i + 1]].transpose(0, 1).unsqueeze(0) | |
| ki = k[cu_k[i] : cu_k[i + 1]].transpose(0, 1).unsqueeze(0) | |
| vi = v[cu_k[i] : cu_k[i + 1]].transpose(0, 1).unsqueeze(0) | |
| oi = F.scaled_dot_product_attention( | |
| qi, ki, vi, is_causal=True, scale=self.scale | |
| ) | |
| outputs.append(oi.squeeze(0).transpose(0, 1)) | |
| output = torch.cat(outputs, dim=0) | |
| # Prefer FlashAttention varlen kernel for MLA prefill; fall back to PyTorch SDPA if unavailable. | |
| cu_q = attn_metadata.cu_seqlens_q | |
| cu_k = attn_metadata.cu_seqlens_k | |
| try: | |
| # flash_attn_varlen_func expects [total_tokens, n_heads, head_dim] tensors with varlen metadata. | |
| output = flash_attn_varlen_func( | |
| q, | |
| k, | |
| v, | |
| cu_q, | |
| cu_k, | |
| attn_metadata.max_seqlen_q, | |
| attn_metadata.max_seqlen_k, | |
| 0.0, # dropout_p | |
| softmax_scale=self.scale, | |
| causal=True, | |
| ) | |
| except Exception: | |
| # Fallback: per-sequence PyTorch SDPA (slower, but no specialized kernel required). | |
| import torch.nn.functional as F | |
| num_seqs = cu_q.shape[0] - 1 | |
| outputs = [] | |
| for i in range(num_seqs): | |
| qi = q[cu_q[i] : cu_q[i + 1]].transpose(0, 1).unsqueeze(0) | |
| ki = k[cu_k[i] : cu_k[i + 1]].transpose(0, 1).unsqueeze(0) | |
| vi = v[cu_k[i] : cu_k[i + 1]].transpose(0, 1).unsqueeze(0) | |
| oi = F.scaled_dot_product_attention( | |
| qi, ki, vi, is_causal=True, scale=self.scale | |
| ) | |
| outputs.append(oi.squeeze(0).transpose(0, 1)) | |
| output = torch.cat(outputs, dim=0) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This workflow pushes to
mainusingGITHUB_TOKEN, but no explicitpermissionsare set. On many repos the default token permissions are read-only, so the push will fail. Addpermissions: contents: write(workflow- or job-level) so the scheduled sync can push merges.