
Add Flash Attention forward kernel with MFMA32 register pipeline#133

Draft
yanguahe wants to merge 18 commits into main from hyg_mha

Conversation


@yanguahe yanguahe commented Feb 13, 2026

Motivation

Implement a high-performance Flash Attention forward kernel in FlyDSL, targeting AMD Instinct GPUs (MI308X/MI325X). This PR provides a pure FlyDSL implementation of causal multi-head attention (MHA) with MFMA32-based GEMM pipelines, aiming to match or approach the performance of hand-optimized CK (Composable Kernel) implementations.

Technical Details

  • Kernel architecture: MFMA32 register pipeline with mfma_f32_32x32x8f16 for both GEMM stages (S=K@Q^T and O=V^T@P).
  • Tile shape: BLOCK_M=128, BLOCK_N=32, 4 waves (256 threads per workgroup).
  • Register-based softmax: Online softmax computed entirely in registers over the KV dimension, avoiding LDS roundtrips for S/P matrices.
  • P register reuse: P (attention weights) are kept in MFMA32 register layout and fed directly to GEMM2 without LDS writeback.
  • LDS layout: K and V^T use separate single-buffered LDS regions per iteration.
  • Causal masking: Tile-level early-exit for fully masked tiles, element-wise masking for partial tiles.
  • Data layout: Q/K/V/O are 1D flattened from BSHD (batch, seq_len, num_heads, head_dim).
  • Grid mapping: (batch * num_q_tiles * num_heads,) where num_q_tiles = seq_len / BLOCK_M.
  • Constraints: head_dim % 32 == 0, head_dim >= 64, seq_len % 128 == 0.
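The register-based online softmax described above can be sketched in NumPy as a reference (this is not the FlyDSL kernel; it uses the mathematically equivalent `Q @ K^T` orientation rather than the kernel's `K @ Q^T`, and `block_n` mirrors `BLOCK_N`):

```python
import numpy as np

def online_softmax_attention(q, k, v, block_n=32):
    """Reference for online softmax over KV tiles: a running row max m,
    a running denominator l, and a rescaled output accumulator acc."""
    scale = 1.0 / np.sqrt(q.shape[-1])
    m = np.full(q.shape[0], -np.inf)          # running row max
    l = np.zeros(q.shape[0])                  # running softmax denominator
    acc = np.zeros((q.shape[0], v.shape[1]))  # unnormalized output
    for start in range(0, k.shape[0], block_n):
        s = (q @ k[start:start + block_n].T) * scale   # GEMM1 tile of scores
        m_new = np.maximum(m, s.max(axis=1))
        p = np.exp(s - m_new[:, None])                 # tile attention weights
        correction = np.exp(m - m_new)                 # rescale previous state
        l = l * correction + p.sum(axis=1)
        acc = acc * correction[:, None] + p @ v[start:start + block_n]  # GEMM2
        m = m_new
    return acc / l[:, None]
```

Because each tile only updates `m`, `l`, and `acc`, the intermediate S/P tiles never need to leave registers, which is the LDS-roundtrip saving the bullet points describe.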

Test Plan

  • Correctness: Compare against PyTorch F.scaled_dot_product_attention reference with max error < 1e-2 and cosine similarity > 0.99.
  • Benchmark: Profile with run_perftest using 100 iterations after 5 warmup iterations on MI325X and MI308X.
  • Test command:
python tests/kernels/test_flash_attn_func.py --batch 1 --num_heads 64 --seq_len 8192 --head_dim 128 --iters 100
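The two acceptance criteria (max error below 1e-2, cosine similarity above 0.99) can be sketched as follows; this is an illustrative check, not the test file's actual implementation, and assumes flattened float copies of the kernel and reference outputs:

```python
import numpy as np

def check_outputs(out, ref, max_err_thr=1e-2, cos_thr=0.99):
    """Compare kernel output against the reference by max absolute
    error and cosine similarity over the flattened tensors."""
    out = np.asarray(out, dtype=np.float64).ravel()
    ref = np.asarray(ref, dtype=np.float64).ravel()
    max_err = np.abs(out - ref).max()
    cos = float(out @ ref / (np.linalg.norm(out) * np.linalg.norm(ref)))
    return (max_err < max_err_thr) and (cos > cos_thr), max_err, cos
```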

Test Result

Configuration: B=1, S=8192, H=64, D=128, causal, fp16

MI325X

| Kernel       | Avg Time (us) | TFLOPS | Speedup vs CK    |
|--------------|---------------|--------|------------------|
| ASM (aiter)  | 1,720.3       | 639.1  | 1.52x            |
| CK (ck_tile) | 2,619.3       | 419.8  | 1.00x (baseline) |
| FlyDSL       | 2,811.4       | 391.1  | 0.93x            |

MI308X

| Kernel       | Avg Time (us) | TFLOPS | Speedup vs CK    |
|--------------|---------------|--------|------------------|
| ASM (aiter)  | 8,385.7       | 131.1  | 1.40x            |
| CK (ck_tile) | 11,700.2      | 94.0   | 1.00x (baseline) |
| FlyDSL       | 11,404.1      | 96.4   | 1.03x            |

  • On MI325X, FlyDSL achieves 93% of CK performance (391.1 vs 419.8 TFLOPS).
  • On MI308X, FlyDSL is slightly faster than CK (96.4 vs 94.0 TFLOPS, ~1.03x).
  • Correctness: max_err = 4.88e-04, cosine_similarity = 1.00000 on both platforms.


yanguahe and others added 14 commits February 3, 2026 23:13
- Add kernels/simple_gemm.py: Simple GEMM kernel (C = A × B^T) for AMD GPUs
  using MFMA instructions with XOR16 LDS swizzle and boundary checks for
  non-aligned M, N, K dimensions
- Add tests/kernels/test_simple_gemm.py: Test script with aligned and
  non-aligned dimension test cases
- Add tests/kernels/test_moe_stage1_simple.py: Standalone test script for
  MoE Stage1 kernel
- Add run.sh: Shell script for running tests and collecting ROCm thread traces
- Add input.yaml: ROCm profiler configuration for thread trace collection

Co-authored-by: Cursor <cursoragent@cursor.com>
…n simple GEMM

- Add waves_per_eu parameter to compiler.compile() for AMDGPU occupancy hints
- Implement _apply_waves_per_eu_on_llvm_funcs() to set amdgpu-waves-per-eu
  attribute on GPU kernel functions via LLVM passthrough
- Refactor simple_gemm to use mask-based loads/stores for M/N boundaries
  instead of host-side padding (Triton-like approach)
- Only K dimension is padded on host (required for MFMA vector loads)
- Add --waves_per_eu CLI argument to test_simple_gemm.py

Co-authored-by: Cursor <cursoragent@cursor.com>
- Add OOB_OFFSET (0x80000000) and MAX_NUM_RECORDS (0x7FFFFFFE) constants
  that match Triton's BufferOpsEmitter for reliable hardware OOB detection
- Update buffer load/store to use OOB_OFFSET for masked-out elements,
  ensuring hardware always detects OOB when mask=False
- Simplify GEMM kernel masking by removing redundant K boundary checks
  since K dimension is guaranteed to be padded to tile_k
- Enable additional test cases in run.sh

Co-authored-by: Cursor <cursoragent@cursor.com>
- Add unsafe_fp_math and fast_fp_math parameters to compiler pipeline
- Replace __ocml_exp2_f32 library calls with llvm.intr.exp2 intrinsics
- Apply unsafe-fp-math function attributes to GPU kernel llvm.func ops
- Add fastmath parameter support to arith.maximum operation
- Improve test reproducibility with seed control and MD5 hash comparison
- Add detailed array comparison utility for debugging numerical differences

Co-authored-by: Cursor <cursoragent@cursor.com>
- Switched the **active** `v4_4` kernel path to a true **MFMA32** pipeline (`mfma_f32_32x32x8f16`) with `BLOCK_M=128`, `BLOCK_N=32`, `NUM_WAVES=4`.
- Remapped compute flow to **`K @ Q^T -> online softmax -> V^T @ P`**.
- Kept intermediate **S/P in registers** (removed the previous `P -> LDS -> VGPR` roundtrip).
- Split LDS staging for K and `V^T` into separate regions and removed an inner-loop barrier to cut synchronization overhead.
- Updated test constraints and compile options in `test_flash_attention_v4_4.py` (`seq_len % 128`, `head_dim % 32`, `waves_per_eu=3`).
- Final measured result at target shape: **12350.8 us/iter**, with accuracy preserved (`diff.abs.max=4.88e-4`, `max_diff_thr=3.255208e-04`), about **2.17x faster** than the previous 26751.5 us.
Add gated CK-style N128/prefetch/reduction experiments plus ROCDL phase-fence wrappers so performance tuning can be A/B tested without regressing the stable target-shape path.

Co-authored-by: Cursor <cursoragent@cursor.com>
Align kernel, test, and run-script references so the renamed entrypoint is used consistently across build and benchmark workflows.

Co-authored-by: Cursor <cursoragent@cursor.com>
…c benchmarking.

Drop the v4.3 comparison path to keep tests focused on flash_attn_func and align run.sh defaults with the updated benchmark flow.

Co-authored-by: Cursor <cursoragent@cursor.com>
@yanguahe yanguahe changed the title Hyg mha Add Flash Attention forward kernel with MFMA32 register pipeline Feb 13, 2026
yanguahe and others added 2 commits February 13, 2026 20:26
Resolve merge conflict in flydsl/src/flydsl/compiler/compiler.py:
- Keep all PR functions: _replace_ocml_exp2_with_intrinsic,
  _apply_unsafe_fp_math_on_llvm_funcs, _apply_waves_per_eu_on_llvm_funcs,
  _apply_flat_work_group_size_on_llvm_funcs
- Keep main's _apply_waves_per_eu_hint (gpu.func level, complementary)
- Combine compile() signature: keep waves_per_eu/flat_work_group_size/
  unsafe_fp_math/fast_fp_math params from PR, adopt Optional return type
  from main

Co-authored-by: Cursor <cursoragent@cursor.com>
pass


def _apply_waves_per_eu_on_llvm_funcs(module: ir.Module, waves_per_eu: int) -> None:
Collaborator

There is already a similar function.

Author

Good catch! The old _apply_waves_per_eu_hint operated on gpu.func (pre-LLVM lowering) via rocdl.waves_per_eu, while the new _apply_waves_per_eu_on_llvm_funcs operates on llvm.func (post-lowering) via LLVM passthrough amdgpu-waves-per-eu. The passthrough approach is more reliable since it directly controls the LLVM backend. I've removed the old function in commit ac1d477.

return names


def _replace_ocml_exp2_with_intrinsic(module: ir.Module) -> ir.Module:
Collaborator

use exp2 directly?

Author

We can't use math.exp2 directly here — the convert-gpu-to-rocdl pass unconditionally lowers it to __ocml_exp2_f32 (a safe but slow 6-instruction library call with range reduction + v_ldexp_f32). There's no pass-level option to emit llvm.intr.exp2 instead.

This function is a post-lowering optimization: it replaces the OCML library call with llvm.intr.exp2 + fast math flags, giving us a single v_exp_f32 instruction. I've updated the docstring in commit ac1d477 to clarify this rationale and added a TODO to replace with a proper MLIR pass when upstream support is available.
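For context on why a fast base-2 exponential is sufficient here: softmax kernels only need `exp2`, since `exp(x) = 2^(x * log2(e))`. A quick numerical check of that identity (illustrative only, unrelated to the MLIR rewrite itself):

```python
import math

def exp_via_exp2(x):
    # exp(x) = 2^(x * log2(e)): the identity that lets attention kernels
    # route the softmax exponential through the hardware base-2 exp
    # (a single v_exp_f32 on AMD GPUs) instead of a libm-style exp.
    return 2.0 ** (x * math.log2(math.e))
```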

# Descriptor uses i32 bytes; clamp to the max representable.
if nbytes > 0xFFFFFFFF:
    nbytes = 0xFFFFFFFF
# Clamp to MAX_NUM_RECORDS to ensure OOB_OFFSET works correctly.
Collaborator

why change this? use dynamic shapes?

Author

Not related to dynamic shapes. This is a correctness fix for masked buffer loads/stores.

The previous code used num_records=0xFFFFFFFF with mask_offset=0x7FFFFFFF. The GPU does unsigned comparison offset < num_records for OOB detection, so 0x7FFFFFFF < 0xFFFFFFFF = true — the mask never triggers OOB, which is a bug.

Changed to match Triton's approach:

  • MAX_NUM_RECORDS = 0x7FFFFFFE
  • OOB_OFFSET = 0x80000000
  • Since 0x80000000 > 0x7FFFFFFE (unsigned), hardware OOB is always triggered when mask=False. ✅
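The hardware's unsigned bounds check can be modeled in a few lines, using the constants above; this is a sketch of the descriptor semantics, not the emitter code:

```python
MAX_NUM_RECORDS = 0x7FFFFFFE
OOB_OFFSET = 0x80000000

def buffer_access_in_bounds(offset, num_records=MAX_NUM_RECORDS):
    """Model of the buffer descriptor's OOB check: the access is valid
    only when offset < num_records, compared as unsigned 32-bit."""
    return (offset & 0xFFFFFFFF) < (num_records & 0xFFFFFFFF)

# Old scheme: mask offset 0x7FFFFFFF with num_records=0xFFFFFFFF is still
# "in bounds" (0x7FFFFFFF < 0xFFFFFFFF), so masked-out lanes touch memory.
# New scheme: OOB_OFFSET >= MAX_NUM_RECORDS, so masked lanes always trip
# the hardware OOB path and are dropped safely.
```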

@coderfeli coderfeli marked this pull request as draft February 19, 2026 01:16