Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions .dockerignore
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
.git
__pycache__
*.pyc
*.egg-info
build/
dist/
42 changes: 42 additions & 0 deletions .github/workflows/sync-upstream.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
name: Sync upstream main

on:
schedule:
# Run nightly at 06:00 UTC (midnight CST)
- cron: '0 6 * * *'
workflow_dispatch: # Allow manual trigger

jobs:
sync:
runs-on: ubuntu-latest
Copy link

Copilot AI Feb 20, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This workflow pushes to main using GITHUB_TOKEN, but no explicit permissions are set. On many repos the default token permissions are read-only, so the push will fail. Add permissions: contents: write (workflow- or job-level) so the scheduled sync can push merges.

Suggested change
runs-on: ubuntu-latest
runs-on: ubuntu-latest
permissions:
contents: write

Copilot uses AI. Check for mistakes.
steps:
- name: Checkout fork
uses: actions/checkout@v4
with:
ref: main
fetch-depth: 0
token: ${{ secrets.GITHUB_TOKEN }}

- name: Add upstream remote
run: git remote add upstream https://github.com/ROCm/ATOM.git

- name: Fetch upstream
run: git fetch upstream main

- name: Check for new commits
id: check
run: |
BEHIND=$(git rev-list --count HEAD..upstream/main)
echo "behind=$BEHIND" >> "$GITHUB_OUTPUT"
echo "Fork is $BEHIND commit(s) behind upstream"

- name: Merge upstream
if: steps.check.outputs.behind != '0'
run: |
git config user.name "github-actions[bot]"
git config user.email "github-actions[bot]@users.noreply.github.com"
git merge upstream/main --no-edit

- name: Push
if: steps.check.outputs.behind != '0'
run: git push origin main
1 change: 1 addition & 0 deletions atom/model_engine/scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -297,6 +297,7 @@ def postprocess(
continue
token_ids = prev_token_ids[seq.id]
num_new_token = len(token_ids)
num_rejected = 0
self.update_spec_stats(num_new_token)
idx = fwd_output.req_ids.index(seq.id)
if is_deferred_out or self.use_spec:
Expand Down
17 changes: 15 additions & 2 deletions atom/model_ops/attention_mha.py
Original file line number Diff line number Diff line change
Expand Up @@ -370,7 +370,19 @@ def prefill_attention_triton(
if ctx.is_prefill:
k_cache = k.unsqueeze(1)
v_cache = v.unsqueeze(1)
Copy link

Copilot AI Feb 20, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

unified_attention expects the V cache to be in the [num_blocks, num_kv_heads, head_dim, block_size] layout (consistent with how v_cache is created in ModelRunner.allocate_kv_cache). For the prefill block_size=1 case, v_cache should therefore be v.unsqueeze(-1) rather than v.unsqueeze(1), otherwise the value layout is incorrect.

Suggested change
v_cache = v.unsqueeze(1)
v_cache = v.unsqueeze(-1)

Copilot uses AI. Check for mistakes.
block_tables = attn_metadata.fake_block_tables
# Create fake block_tables for prefill: each token is its own
# "block" (block_size=1). Shape [num_seqs, max_seqlen_k].
batch_size = attn_metadata.cu_seqlens_k.shape[0] - 1
max_len = attn_metadata.max_seqlen_k
block_tables = torch.zeros(
batch_size, max_len, dtype=torch.int32, device=q.device
)
for i in range(batch_size):
s = attn_metadata.cu_seqlens_k[i].item()
e = attn_metadata.cu_seqlens_k[i + 1].item()
block_tables[i, : e - s] = torch.arange(
s, e, dtype=torch.int32, device=q.device
)
Comment on lines +377 to +385
Copy link

Copilot AI Feb 20, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Constructing block_tables for every prefill call allocates a potentially large [batch_size, max_seqlen_k] tensor and fills it in a Python loop, which is likely to be a significant prefill-time overhead. Consider generating this table once in metadata preparation (or caching it on attn_metadata), and/or vectorizing the fill to avoid per-sequence torch.arange in Python.

Suggested change
block_tables = torch.zeros(
batch_size, max_len, dtype=torch.int32, device=q.device
)
for i in range(batch_size):
s = attn_metadata.cu_seqlens_k[i].item()
e = attn_metadata.cu_seqlens_k[i + 1].item()
block_tables[i, : e - s] = torch.arange(
s, e, dtype=torch.int32, device=q.device
)
# Vectorized construction of block_tables to avoid Python loop
cu_seqlens_k = attn_metadata.cu_seqlens_k.to(device=q.device)
starts = cu_seqlens_k[:-1].to(dtype=torch.int32) # [batch_size]
lengths = (cu_seqlens_k[1:] - cu_seqlens_k[:-1]).to(dtype=torch.int32)
positions = torch.arange(
max_len, dtype=torch.int32, device=q.device
) # [max_len]
# Broadcast to [batch_size, max_len]
start_grid = starts.unsqueeze(1)
pos_grid = positions.unsqueeze(0)
indices = start_grid + pos_grid
valid_mask = pos_grid < lengths.unsqueeze(1)
block_tables = torch.where(
valid_mask, indices, torch.zeros_like(indices)
)

Copilot uses AI. Check for mistakes.

o = torch.empty_like(q)
descale_shape = (attn_metadata.cu_seqlens_q.shape[0] - 1, k.shape[1])
Expand Down Expand Up @@ -407,7 +419,8 @@ def dispatch_backend(self, fwd_ctx: ForwardContext):
ctx = fwd_ctx.context

if ctx.is_prefill:
return self.prefill_attention
# Always use Triton prefill (no CK/flash_attn_varlen_func dependency)
return self.prefill_attention_triton
Comment on lines +422 to +423
Copy link

Copilot AI Feb 20, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

dispatch_backend() now always routes prefill through prefill_attention_triton, which makes prefill_attention() (the aiter.flash_attn_varlen_func path) unused. If this change is only intended as a CK-unavailable fallback, it should be conditional (e.g., only use the Triton prefill path when the required kernels are present, otherwise keep the existing varlen flash attention path).

Suggested change
# Always use Triton prefill (no CK/flash_attn_varlen_func dependency)
return self.prefill_attention_triton
# Use Triton prefill when Triton attention is enabled; otherwise, use varlen flash attention
if self.use_triton_attn:
return self.prefill_attention_triton
return self.prefill_attention

Copilot uses AI. Check for mistakes.
else:
if self.use_triton_attn:
return self.paged_attention_triton
Expand Down
29 changes: 16 additions & 13 deletions atom/model_ops/attention_mla.py
Original file line number Diff line number Diff line change
Expand Up @@ -396,19 +396,22 @@ def _forward_prefill_mha(

k = torch.cat((k_nope, k_rope.expand((*k_nope.shape[:-1], -1))), dim=-1)

output = flash_attn_varlen_func(
q=q,
k=k,
v=v,
cu_seqlens_q=attn_metadata.cu_seqlens_q,
cu_seqlens_k=attn_metadata.cu_seqlens_k,
max_seqlen_q=attn_metadata.max_seqlen_q,
max_seqlen_k=attn_metadata.max_seqlen_k,
min_seqlen_q=attn_metadata.min_seqlen_q,
dropout_p=attn_metadata.dropout_p,
softmax_scale=self.scale,
causal=True,
)
# Use PyTorch SDPA for MLA prefill attention (no CK dependency)
import torch.nn.functional as F

cu_q = attn_metadata.cu_seqlens_q
cu_k = attn_metadata.cu_seqlens_k
num_seqs = cu_q.shape[0] - 1
Comment on lines +399 to +404
Copy link

Copilot AI Feb 20, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The PR description focuses on enabling Triton MoE on gfx950 / CK-unavailable fallback for MoE, but this diff also changes the MLA prefill attention implementation (switching kernels/backends). Please update the PR description (and test plan) to cover this additional behavioral/performance change, or split it into a separate PR if it’s not required for the MoE enablement.

Copilot uses AI. Check for mistakes.
outputs = []
for i in range(num_seqs):
qi = q[cu_q[i] : cu_q[i + 1]].transpose(0, 1).unsqueeze(0)
ki = k[cu_k[i] : cu_k[i + 1]].transpose(0, 1).unsqueeze(0)
vi = v[cu_k[i] : cu_k[i + 1]].transpose(0, 1).unsqueeze(0)
oi = F.scaled_dot_product_attention(
qi, ki, vi, is_causal=True, scale=self.scale
)
outputs.append(oi.squeeze(0).transpose(0, 1))
output = torch.cat(outputs, dim=0)

Comment on lines +399 to 415
Copy link

Copilot AI Feb 20, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This replaces a single flash_attn_varlen_func call with a Python loop over sequences and per-sequence scaled_dot_product_attention calls, which will scale poorly with batch size and likely be a major prefill performance regression. Consider keeping flash_attn_varlen_func as the fast path and only falling back to the SDPA loop when the varlen kernel is unavailable.

Suggested change
# Use PyTorch SDPA for MLA prefill attention (no CK dependency)
import torch.nn.functional as F
cu_q = attn_metadata.cu_seqlens_q
cu_k = attn_metadata.cu_seqlens_k
num_seqs = cu_q.shape[0] - 1
outputs = []
for i in range(num_seqs):
qi = q[cu_q[i] : cu_q[i + 1]].transpose(0, 1).unsqueeze(0)
ki = k[cu_k[i] : cu_k[i + 1]].transpose(0, 1).unsqueeze(0)
vi = v[cu_k[i] : cu_k[i + 1]].transpose(0, 1).unsqueeze(0)
oi = F.scaled_dot_product_attention(
qi, ki, vi, is_causal=True, scale=self.scale
)
outputs.append(oi.squeeze(0).transpose(0, 1))
output = torch.cat(outputs, dim=0)
# Prefer FlashAttention varlen kernel for MLA prefill; fall back to PyTorch SDPA if unavailable.
cu_q = attn_metadata.cu_seqlens_q
cu_k = attn_metadata.cu_seqlens_k
try:
# flash_attn_varlen_func expects [total_tokens, n_heads, head_dim] tensors with varlen metadata.
output = flash_attn_varlen_func(
q,
k,
v,
cu_q,
cu_k,
attn_metadata.max_seqlen_q,
attn_metadata.max_seqlen_k,
0.0, # dropout_p
softmax_scale=self.scale,
causal=True,
)
except Exception:
# Fallback: per-sequence PyTorch SDPA (slower, but no specialized kernel required).
import torch.nn.functional as F
num_seqs = cu_q.shape[0] - 1
outputs = []
for i in range(num_seqs):
qi = q[cu_q[i] : cu_q[i + 1]].transpose(0, 1).unsqueeze(0)
ki = k[cu_k[i] : cu_k[i + 1]].transpose(0, 1).unsqueeze(0)
vi = v[cu_k[i] : cu_k[i + 1]].transpose(0, 1).unsqueeze(0)
oi = F.scaled_dot_product_attention(
qi, ki, vi, is_causal=True, scale=self.scale
)
outputs.append(oi.squeeze(0).transpose(0, 1))
output = torch.cat(outputs, dim=0)

Copilot uses AI. Check for mistakes.
return self.o_proj(output.flatten(start_dim=-2))

Expand Down
50 changes: 50 additions & 0 deletions atom/model_ops/attentions/aiter_mla.py
Original file line number Diff line number Diff line change
Expand Up @@ -180,6 +180,56 @@ def prepare_prefill(self, batch: ScheduledBatch):
bs = batch.total_seqs_num_prefill
sum_scheduled_tokens = batch.total_tokens_num_prefill
var = self.model_runner.forward_vars

# Prepare paged KV metadata for MLA prefill paths
# (needed by mla_prefill_fwd for bf16, unified_attention for fp8)
if batch.block_tables:
context_lens = np.asarray(batch.context_lens[:bs], dtype=np.int32)
num_blocks_per_seq = cdiv(context_lens, self.block_size)
kv_indptr = np.cumsum(num_blocks_per_seq)
sum_blocks = kv_indptr[-1]

dst = var["kv_indices"].np
offset = 0
for i in range(bs):
bt = batch.block_tables[i]
n = len(bt)
dst[offset : offset + n] = bt
offset += n
sum_blocks_before_converted = offset

var["kv_indptr"].np[0] = 0
var["kv_indptr"].np[1 : bs + 1] = kv_indptr

attn_metadata.kv_indptr = var["kv_indptr"].copy_to_gpu(bs + 1)
attn_metadata.kv_indices = var["kv_indices"].copy_to_gpu(
sum_blocks_before_converted
)
attn_metadata.kv_last_page_lens = var["kv_last_page_lens"].gpu[:bs]

if self.block_ratio > 1:
kv_indices_convert_triton(
var["kv_indices"].gpu[:sum_blocks_before_converted],
var["kv_indices_converted"].gpu[:sum_blocks],
var["kv_indptr"].gpu[: bs + 1],
self.block_ratio,
self.block_size,
)
attn_metadata.kv_indices = var["kv_indices_converted"].gpu[:sum_blocks]

# Prepare block_tables for unified_attention (fp8 prefill)
if attn_metadata.block_tables is None:
self.prepare_block_tables(batch)
attn_metadata.block_tables = var["block_tables"].copy_to_gpu(bs)
if self.block_ratio > 1:
block_table_convert_triton(
var["block_tables"].gpu[:bs],
var["block_tables_converted"].gpu[:bs],
var["context_lens"].gpu[:bs],
self.block_ratio,
)
attn_metadata.block_tables = var["block_tables_converted"].gpu[:bs]

if self.is_sparse and attn_metadata.max_seqlen_k > self.index_topk:
if attn_metadata.block_tables is None:
self.prepare_block_tables(batch)
Expand Down
9 changes: 7 additions & 2 deletions atom/model_ops/fused_moe_triton.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,11 @@
from triton_kernels.matmul_ogs import FnSpecs, FusedActivation, matmul_ogs
from triton_kernels.routing import routing
from triton_kernels.matmul_ogs import PrecisionConfig
from triton_kernels.matmul_ogs import update_opt_flags_constraints

if get_gfx() == "gfx950":
# MI355X has 160KB LDS; default CDNA4 block_m=256 exceeds it.
update_opt_flags_constraints({"block_m": 128})
except (AttributeError, ImportError) as e:
logger.error(
"Failed to import Triton kernels. Please make sure your triton "
Expand All @@ -53,9 +58,9 @@ def _swizzle_mxfp4(quant_tensor, scale):
scale_layout_opts: dict[str, Any] = {}
value_layout = StridedLayout
if get_gfx() == "gfx950":
from triton_kernels.tensor_details.layout import GFX950MXScaleLayout
from triton_kernels.tensor_details.layout import CDNA4MXScaleLayout

scale_layout = GFX950MXScaleLayout
scale_layout = CDNA4MXScaleLayout
else:
scale_layout = StridedLayout

Expand Down
4 changes: 4 additions & 0 deletions atom/model_ops/linear.py
Original file line number Diff line number Diff line change
Expand Up @@ -700,6 +700,10 @@ def weight_loader(
elif self.quant_type == QuantType.per_Tensor:
shard_offset = loaded_shard_id
shard_size = 1
else:
# per_Token and per_1x32: scale dim 0 matches output_size
shard_offset = sum(self.output_sizes[:loaded_shard_id])
shard_size = self.output_sizes[loaded_shard_id]
else:
shard_offset = sum(self.output_sizes[:loaded_shard_id])
shard_size = self.output_sizes[loaded_shard_id]
Expand Down
Loading
Loading