Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
18 commits
Select commit Hold shift + click to select a range
11b3967
Add simple GEMM kernel with MFMA 16x16x16 and test scripts
yanguahe Feb 3, 2026
7a3d05b
Add waves_per_eu support and switch to mask-based boundary handling i…
yanguahe Feb 4, 2026
5e11342
Fix hardware OOB handling in buffer ops to match Triton implementation
yanguahe Feb 4, 2026
516fe7d
Add unsafe_fp_math and fast_fp_math compiler options for faster GPU math
yanguahe Feb 6, 2026
ab58591
Add peng's Flash Attention PR: https://github.com/sunway513/FlyDSL/pu…
yanguahe Feb 12, 2026
9390ffa
Add flash_attention_v4_4_kernel and its test
yanguahe Feb 12, 2026
9ee814c
Remove flash_attention_v4_4_kernel
yanguahe Feb 13, 2026
d41beae
Add flash_attention_v4_4_kernel and its test
yanguahe Feb 13, 2026
7705b50
Optimize flash_attention_v4_4 with MFMA32 register pipeline
yanguahe Feb 13, 2026
b8fc969
[WIP] Opt flash_attention_v4_4_kernel
yanguahe Feb 13, 2026
73f96c7
Refine flash_attention_v4_4 convergence paths with safe defaults.
yanguahe Feb 13, 2026
4e0fff0
Rename flash_attention_v4_4 artifacts to flash_attn_func.
yanguahe Feb 13, 2026
5767f75
Remove legacy flash_attention_v4 variants and simplify flash_attn_fun…
yanguahe Feb 13, 2026
1c49a1c
Remove temp file
yanguahe Feb 13, 2026
5b039e2
Merge origin/main into hyg_mha, resolve conflict in compiler.py
yanguahe Feb 13, 2026
c274a87
Update test config
yanguahe Feb 14, 2026
ac1d477
Address review: remove unused _apply_waves_per_eu_hint, clarify exp2 …
yanguahe Feb 15, 2026
5401b36
Fix CI: update preshuffle_gemm to use compile(waves_per_eu=) instead …
yanguahe Feb 15, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
318 changes: 276 additions & 42 deletions flydsl/src/flydsl/compiler/compiler.py

Large diffs are not rendered by default.

11 changes: 8 additions & 3 deletions flydsl/src/flydsl/dialects/ext/arith.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,12 +154,13 @@ def f64(value: float, *, loc: Location = None, ip: InsertionPoint = None) -> "Ar
"""Create an f64 constant."""
return constant(value, type=F64Type.get(), loc=loc, ip=ip)

def maximum(lhs: Union["ArithValue", Value], rhs: Union["ArithValue", Value], *, loc: Location = None) -> "ArithValue":
def maximum(lhs: Union["ArithValue", Value], rhs: Union["ArithValue", Value], *, fastmath=None, loc: Location = None) -> "ArithValue":
"""Compute maximum of two values (automatically handles float/int types).

Args:
lhs: Left operand (ArithValue, Value, or Python number)
rhs: Right operand (ArithValue, Value, or Python number)
fastmath: Optional fast-math flags (e.g. arith.FastMathFlags.fast)
loc: Optional source location

Returns:
Expand All @@ -171,7 +172,7 @@ def maximum(lhs: Union["ArithValue", Value], rhs: Union["ArithValue", Value], *,
>>> c = arith.maximum(a, b) # Function style
>>> d = a.max(b) # Method style (equivalent)
"""
return _minmax_op(lhs, rhs, op_type="max", loc=loc)
return _minmax_op(lhs, rhs, op_type="max", fastmath=fastmath, loc=loc)

def minimum(lhs: Union["ArithValue", Value], rhs: Union["ArithValue", Value], *, loc: Location = None) -> "ArithValue":
"""Compute minimum of two values (automatically handles float/int types).
Expand Down Expand Up @@ -788,6 +789,7 @@ def _minmax_op(
rhs: "ArithValue",
op_type: str, # "max" or "min"
*,
fastmath=None,
loc: Location = None,
) -> "ArithValue":
"""Execute min/max operation based on operand types."""
Expand All @@ -809,7 +811,10 @@ def _minmax_op(
op_class = _arith.MaximumFOp
else:
op_class = _arith.MinimumFOp
result = op_class(lhs_val, rhs_val, loc=loc).result
if fastmath is not None:
result = op_class(lhs_val, rhs_val, fastmath=fastmath, loc=loc).result
else:
result = op_class(lhs_val, rhs_val, loc=loc).result
elif _is_integer_like_type(lhs_val.type):
# Integer min/max (signed/unsigned logic could be tricky, default to signed for now)
# TODO: Add unsigned support if needed
Expand Down
62 changes: 46 additions & 16 deletions flydsl/src/flydsl/dialects/ext/buffer_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,25 @@
'i32_select',
]

# =============================================================================
# Constants for Hardware OOB (Out-of-Bounds) Handling
# =============================================================================
# These values are chosen to match Triton's implementation for reliable hardware
# OOB detection in AMD buffer load/store operations.
#
# Reference: triton/third_party/amd/lib/TritonAMDGPUToLLVM/BufferOpsEmitter.cpp
# - OOB_OFFSET = static_cast<int>(std::numeric_limits<int>::max() + int64_t(1))
# - numRecordsByte = std::numeric_limits<int>::max() - 1
#
# How it works:
# - When mask=False, offset is replaced with OOB_OFFSET (0x80000000)
# - Hardware compares: if (offset >= num_records) -> return 0 (load) or ignore (store)
# - 0x80000000 (as unsigned) = 2147483648 > 0x7FFFFFFE = 2147483646
# - This guarantees hardware OOB detection triggers for masked-out elements
# =============================================================================
OOB_OFFSET = 0x80000000 # -2147483648 as signed i32, 2147483648 as unsigned
MAX_NUM_RECORDS = 0x7FFFFFFE # 2147483646 (std::numeric_limits<int>::max() - 1)


def create_llvm_ptr(value, address_space: int = 0) -> ir.Value:
"""Convert an index value to LLVM pointer.
Expand Down Expand Up @@ -195,34 +214,41 @@ def _num_records_from_memref_type() -> Optional[int]:

if num_records_bytes is not None:
# Caller-provided size in BYTES (preferred for exact hardware OOB behavior).
# NOTE: When using masks, num_records should not exceed MAX_NUM_RECORDS
# to ensure OOB_OFFSET always triggers hardware OOB detection.
if isinstance(num_records_bytes, int):
nbytes = int(num_records_bytes)
if nbytes <= 0:
nbytes = 0
# Descriptor uses i32 bytes; clamp to the max representable.
if nbytes > 0xFFFFFFFF:
nbytes = 0xFFFFFFFF
# Clamp to MAX_NUM_RECORDS to ensure OOB_OFFSET works correctly.
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why change this? use dynamic shapes?

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not related to dynamic shapes. This is a correctness fix for masked buffer loads/stores.

The previous code used num_records=0xFFFFFFFF with mask_offset=0x7FFFFFFF. The GPU does unsigned comparison offset < num_records for OOB detection, so 0x7FFFFFFF < 0xFFFFFFFF = true — the mask never triggers OOB, which is a bug.

Changed to match Triton's approach:

  • MAX_NUM_RECORDS = 0x7FFFFFFE
  • OOB_OFFSET = 0x80000000
  • Since 0x80000000 > 0x7FFFFFFE (unsigned), hardware OOB is always triggered when mask=False. ✅

if nbytes > MAX_NUM_RECORDS:
nbytes = MAX_NUM_RECORDS
num_records = _create_i32_constant(nbytes)
else:
# Value path: cast to i32 if needed.
# Note: For dynamic values, we trust the caller to provide valid sizes.
# If the buffer is larger than MAX_NUM_RECORDS, OOB detection may not
# work correctly for masked loads/stores.
v = _unwrap_value(num_records_bytes)
if not isinstance(v.type, ir.IntegerType) or v.type.width != 32:
op = std_arith.IndexCastOp(ir.IntegerType.get_signless(32), v)
v = _unwrap_value(op.result)
num_records = v
elif max_size:
# Use max for flexibility (hardware will check actual bounds)
# Note: flir's rocdl.make.buffer.rsrc requires i32, not i64
num_records = _create_i32_constant(0xFFFFFFFF) # FALLBACK_MAX_SIZE
# Use MAX_NUM_RECORDS for flexibility with proper OOB handling.
# This value (0x7FFFFFFE) ensures that OOB_OFFSET (0x80000000) will
# always trigger hardware OOB detection.
num_records = _create_i32_constant(MAX_NUM_RECORDS)
else:
# Use the logical memref size (in bytes) for hardware OOB checking.
nbytes = _num_records_from_memref_type()
if nbytes is None:
# Fall back to max-size if we can't infer statically.
num_records = _create_i32_constant(0xFFFFFFFF)
# Fall back to MAX_NUM_RECORDS if we can't infer statically.
num_records = _create_i32_constant(MAX_NUM_RECORDS)
else:
if nbytes > 0xFFFFFFFF:
nbytes = 0xFFFFFFFF
# Clamp to MAX_NUM_RECORDS for proper OOB handling with masks.
if nbytes > MAX_NUM_RECORDS:
nbytes = MAX_NUM_RECORDS
num_records = _create_i32_constant(int(nbytes))

# Create resource descriptor (returns !llvm.ptr<8>)
Expand Down Expand Up @@ -312,11 +338,13 @@ def buffer_load(rsrc: ir.Value,
op = std_arith.MulIOp(offset, bytes_const)
offset = _unwrap_value(op.result)

# Apply mask by setting invalid offsets to max
# Apply mask by setting invalid offsets to OOB_OFFSET
# When mask=False, offset becomes OOB_OFFSET (0x80000000), which is always
# >= MAX_NUM_RECORDS (0x7FFFFFFE), triggering hardware OOB (returns 0).
if mask is not None:
mask = _unwrap_value(mask)
max_offset = _create_i32_constant(0x7FFFFFFF)
op = std_arith.SelectOp(mask, offset, max_offset)
oob_offset = _create_i32_constant(OOB_OFFSET)
op = std_arith.SelectOp(mask, offset, oob_offset)
offset = _unwrap_value(op.result)

# Create vector type
Expand Down Expand Up @@ -400,11 +428,13 @@ def buffer_store(data: ir.Value,
op = std_arith.MulIOp(offset, bytes_const)
offset = _unwrap_value(op.result)

# Apply mask by setting invalid offsets to max
# Apply mask by setting invalid offsets to OOB_OFFSET
# When mask=False, offset becomes OOB_OFFSET (0x80000000), which is always
# >= MAX_NUM_RECORDS (0x7FFFFFFE), triggering hardware OOB (store ignored).
if mask is not None:
mask = _unwrap_value(mask)
max_offset = _create_i32_constant(0x7FFFFFFF)
op = std_arith.SelectOp(mask, offset, max_offset)
oob_offset = _create_i32_constant(OOB_OFFSET)
op = std_arith.SelectOp(mask, offset, oob_offset)
offset = _unwrap_value(op.result)

# Create instruction offset (soffset) and aux flags
Expand Down
147 changes: 146 additions & 1 deletion flydsl/src/flydsl/dialects/ext/rocdl.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from _mlir.dialects.rocdl import * # noqa: F401,F403

# Keep references to ODS-generated builders so we can wrap them without losing access.
_ods_mfma_f32_32x32x8f16 = globals().get("mfma_f32_32x32x8f16", None)
_ods_mfma_f32_16x16x16f16 = mfma_f32_16x16x16f16
_ods_mfma_f32_16x16x16bf16_1k = globals().get("mfma_f32_16x16x16bf16_1k", None)
_ods_mfma_f32_16x16x32_fp8_fp8 = mfma_f32_16x16x32_fp8_fp8
Expand All @@ -28,6 +29,8 @@
_ods_readlane = readlane
_ods_readfirstlane = readfirstlane
_ods_ds_swizzle = ds_swizzle
_ods_permlane16_swap = permlane16_swap
_ods_permlane32_swap = permlane32_swap
_ods_raw_ptr_buffer_atomic_fadd = raw_ptr_buffer_atomic_fadd

mask_mfma = 0x008
Expand All @@ -45,6 +48,61 @@ def sched_dswr(cnt):
sched_group_barrier(mask_dswr, cnt, 0)


def _unwrap_i32_scalar(v, *, loc=None):
    """Unwrap *v* to a raw MLIR Value coerced to signless i32."""
    from _mlir.ir import IntegerType
    from . import arith as _arith_ext

    i32 = IntegerType.get_signless(32)
    return _arith_ext.unwrap(v, type=i32, loc=loc)


def async_global_load_to_lds(global_ptr, lds_ptr, size, offset=0, aux=0, *, loc=None, ip=None):
    """Global->LDS async-style copy wrapper (closest stable ROCDL primitive).

    Pointer operands are auto-unwrapped; ``size``, ``offset`` and ``aux`` are
    coerced to signless i32 scalars before being handed to ``global_load_lds``.
    """
    from . import arith as _arith_ext

    src_ptr = _arith_ext.unwrap(global_ptr, loc=loc)
    dst_ptr = _arith_ext.unwrap(lds_ptr, loc=loc)
    size_i32 = _unwrap_i32_scalar(size, loc=loc)
    offset_i32 = _unwrap_i32_scalar(offset, loc=loc)
    aux_i32 = _unwrap_i32_scalar(aux, loc=loc)
    return global_load_lds(src_ptr, dst_ptr, size_i32, offset_i32, aux_i32, loc=loc, ip=ip)


def async_load_to_lds(global_ptr, lds_ptr, size, offset=0, aux=0, *, loc=None, ip=None):
    """Alias for ``load_to_lds`` with scalar auto-unwrapping.

    Same operand coercion as ``async_global_load_to_lds`` but lowered through
    ``load_to_lds`` instead of ``global_load_lds``.
    """
    from . import arith as _arith_ext

    src_ptr = _arith_ext.unwrap(global_ptr, loc=loc)
    dst_ptr = _arith_ext.unwrap(lds_ptr, loc=loc)
    size_i32 = _unwrap_i32_scalar(size, loc=loc)
    offset_i32 = _unwrap_i32_scalar(offset, loc=loc)
    aux_i32 = _unwrap_i32_scalar(aux, loc=loc)
    return load_to_lds(src_ptr, dst_ptr, size_i32, offset_i32, aux_i32, loc=loc, ip=ip)


def async_load_fence(wait_vmem=0, wait_ds=0, *, loc=None, ip=None):
    """Waitcnt-style fence helper for staged async copy scheduling.

    The ``wait_vmem``/``wait_ds`` arguments are accepted for API stability but
    currently ignored: wait_loadcnt/wait_dscnt lowerings are not stable on the
    current toolchain, so a conservative full ``s_waitcnt 0`` is emitted.
    """
    del wait_vmem, wait_ds  # intentionally unused — see docstring
    return s_waitcnt(0, loc=loc, ip=ip)


def phase_barrier(mask=0, *, loc=None, ip=None):
    """Phase fence for pipelined kernels, emitted as a ``sched_barrier``.

    ``mask`` is forwarded unchanged to the underlying scheduling barrier.
    """
    barrier_mask = mask
    return sched_barrier(barrier_mask, loc=loc, ip=ip)


def phase_group_barrier(mask, size, group_id=0, *, loc=None, ip=None):
    """Phase fence for pipelined kernels, emitted as a ``sched_group_barrier``.

    ``mask``, ``size`` and ``group_id`` are forwarded unchanged to the
    underlying group scheduling barrier.
    """
    return sched_group_barrier(mask, size, group_id, loc=loc, ip=ip)


def _unwrap_mfma_operand(v, *, loc=None):
"""MFMA operands are MLIR Values; some trailing operands are i32 flags.

Expand All @@ -68,6 +126,20 @@ def mfma_f32_16x16x16f16(result_type, operands, *, loc=None, ip=None):
"""Return the op result directly (no `.result` needed at call sites)."""
return mfma_f32_16x16x16f16_op(result_type, operands, loc=loc, ip=ip).result


def mfma_f32_32x32x8f16_op(result_type, operands, *, loc=None, ip=None):
    """Build the mfma.f32.32x32x8f16 op and return the op view (original behavior).

    Raises:
        AttributeError: if the ODS builder is absent from these bindings.
    """
    builder = _ods_mfma_f32_32x32x8f16
    if builder is None:
        raise AttributeError("ROCDL op not found: mfma_f32_32x32x8f16")
    unwrapped = [_unwrap_mfma_operand(operand, loc=loc) for operand in operands]
    return builder(result_type, unwrapped, loc=loc, ip=ip)


def mfma_f32_32x32x8f16(result_type, operands, *, loc=None, ip=None):
    """Convenience form: build the op and return its result Value directly."""
    op_view = mfma_f32_32x32x8f16_op(result_type, operands, loc=loc, ip=ip)
    return op_view.result


# for bf16 version mfma
def mfma_f32_16x16x16bf16_1k_op(result_type, operands, *, loc=None, ip=None):
"""Return the op view (original behavior)."""
Expand Down Expand Up @@ -138,6 +210,73 @@ def ds_swizzle(result_type, src, offset, *, loc=None, ip=None):
return _ods_ds_swizzle(result_type, _arith_ext.unwrap(src), _arith_ext.unwrap(offset), loc=loc, ip=ip)


def _unwrap_i32_lane_operand(v, *, loc=None):
    """Unwrap a permlane operand to a signless-i32 MLIR Value.

    Exactly the same coercion as ``_unwrap_i32_scalar`` (defined earlier in
    this module); kept as a named alias so the permlane wrappers read in terms
    of lane operands. Delegating avoids duplicating the unwrap logic.
    """
    return _unwrap_i32_scalar(v, loc=loc)


def _permlane_i32x2_struct_type():
    """Return the ``!llvm.struct<(i32, i32)>`` result type used by permlane swaps."""
    from _mlir import ir as _ir

    try:
        # Preferred spelling; some binding versions tolerate the space.
        return _ir.Type.parse("!llvm.struct<(i32, i32)>")
    except Exception:
        # Fall back to the compact spelling for stricter parsers.
        return _ir.Type.parse("!llvm.struct<(i32,i32)>")


def _extract_permlane_lane_i32(pair_val, *, loc=None, ip=None):
    """Extract element 0 (an i32) from a permlane i32x2 result struct."""
    from _mlir.dialects import llvm as _llvm
    from _mlir.ir import IntegerType

    lane_type = IntegerType.get_signless(32)
    return _llvm.extractvalue(lane_type, pair_val, [0], loc=loc, ip=ip)


def permlane16_swap_pair(old, src, fi=False, bound_control=False, *, loc=None, ip=None):
    """High-level permlane16 swap wrapper returning the raw i32x2 struct."""
    pair_type = _permlane_i32x2_struct_type()
    old_i32 = _unwrap_i32_lane_operand(old, loc=loc)
    src_i32 = _unwrap_i32_lane_operand(src, loc=loc)
    return _ods_permlane16_swap(pair_type, old_i32, src_i32, fi, bound_control, loc=loc, ip=ip)


def permlane16_swap_i32(old, src, fi=False, bound_control=False, *, loc=None, ip=None):
    """High-level permlane16 swap wrapper returning the swapped i32 lane value."""
    pair = permlane16_swap_pair(old, src, fi=fi, bound_control=bound_control, loc=loc, ip=ip)
    return _extract_permlane_lane_i32(pair, loc=loc, ip=ip)


def permlane32_swap_pair(old, src, fi=False, bound_control=False, *, loc=None, ip=None):
    """High-level permlane32 swap wrapper returning the raw i32x2 struct."""
    pair_type = _permlane_i32x2_struct_type()
    old_i32 = _unwrap_i32_lane_operand(old, loc=loc)
    src_i32 = _unwrap_i32_lane_operand(src, loc=loc)
    return _ods_permlane32_swap(pair_type, old_i32, src_i32, fi, bound_control, loc=loc, ip=ip)


def permlane32_swap_i32(old, src, fi=False, bound_control=False, *, loc=None, ip=None):
    """High-level permlane32 swap wrapper returning the swapped i32 lane value."""
    pair = permlane32_swap_pair(old, src, fi=fi, bound_control=bound_control, loc=loc, ip=ip)
    return _extract_permlane_lane_i32(pair, loc=loc, ip=ip)


def raw_ptr_buffer_atomic_fadd(val, rsrc, voffset, soffset, cache, *, loc=None, ip=None):
"""Atomic fadd that accepts `ArithValue` / wrappers (no explicit `arith.unwrap(...)` needed).

Expand Down Expand Up @@ -173,6 +312,7 @@ def raw_ptr_buffer_atomic_fadd(val, rsrc, voffset, soffset, cache, *, loc=None,
'barrier', 's_barrier', 's_barrier_signal', 's_barrier_wait',
's_waitcnt', 's_wait_loadcnt', 's_wait_storecnt',
's_wait_dscnt', 's_wait_expcnt',
'async_load_fence',

# Matrix operations - MFMA (Matrix Fused Multiply-Add)
'mfma_f32_32x32x8f16', 'mfma_f32_16x16x16f16',
Expand All @@ -182,7 +322,7 @@ def raw_ptr_buffer_atomic_fadd(val, rsrc, voffset, soffset, cache, *, loc=None,
'mfma_i32_16x16x32_i8',
'mfma_scale_f32_16x16x128_f8f6f4',
# Raw-op constructors (return op view) for the above
'mfma_f32_16x16x16f16_op', 'mfma_f32_16x16x32_fp8_fp8_op',
'mfma_f32_32x32x8f16_op', 'mfma_f32_16x16x16f16_op', 'mfma_f32_16x16x32_fp8_fp8_op',
'mfma_f32_16x16x16bf16_1k_op',
'mfma_i32_16x16x32_i8_op',
'mfma_scale_f32_16x16x128_f8f6f4_op',
Expand All @@ -198,6 +338,8 @@ def raw_ptr_buffer_atomic_fadd(val, rsrc, voffset, soffset, cache, *, loc=None,
# Shuffle and permutation
'ds_swizzle', 'ds_bpermute',
'permlanex16', 'permlane16_swap', 'permlane32_swap',
'permlane16_swap_pair', 'permlane16_swap_i32',
'permlane32_swap_pair', 'permlane32_swap_i32',
'readlane', 'readfirstlane',
'update_dpp',
'ballot',
Expand All @@ -206,6 +348,7 @@ def raw_ptr_buffer_atomic_fadd(val, rsrc, voffset, soffset, cache, *, loc=None,
'raw_buffer_load', 'raw_buffer_store',
'raw_ptr_buffer_load', 'raw_ptr_buffer_store',
'load_to_lds', 'global_load_lds',
'async_load_to_lds', 'async_global_load_to_lds',
'make_buffer_rsrc',

# Atomic operations
Expand All @@ -219,6 +362,8 @@ def raw_ptr_buffer_atomic_fadd(val, rsrc, voffset, soffset, cache, *, loc=None,
# Scheduling and optimization
's_setprio', 's_sleep',
'sched_barrier', 'sched_group_barrier',
'phase_barrier', 'phase_group_barrier',
'sched_mfma', 'sched_vmem', 'sched_dsrd', 'sched_dswr',
'iglp_opt',

# Type conversions
Expand Down
Loading