From df55ea7599463decf16f5c5fb34d5cd5d8b8f429 Mon Sep 17 00:00:00 2001
From: Nick Riasanovsky
Date: Tue, 14 Oct 2025 17:55:09 -0700
Subject: [PATCH 1/2] Added tmp change

---
 .../kernels/blackwell_triton_fused_attention.py | 17 ++++++++++++-----
 .../blackwell_triton_fused_attention_dp.py      |  6 +++---
 2 files changed, 15 insertions(+), 8 deletions(-)

diff --git a/tritonbench/kernels/blackwell_triton_fused_attention.py b/tritonbench/kernels/blackwell_triton_fused_attention.py
index f5e401ede..65f74f097 100644
--- a/tritonbench/kernels/blackwell_triton_fused_attention.py
+++ b/tritonbench/kernels/blackwell_triton_fused_attention.py
@@ -76,11 +76,8 @@ def _attn_fwd_subtile(
         qk = _fma_f32x2(qk, qk_scale, -m_ij[:, None])
     else:
         qk = qk * qk_scale - m_ij[:, None]
-    p = tl.math.exp2(qk)
     # -- compute correction factor
     alpha = tl.math.exp2(m_i - m_ij)
-    if not FADD2_REDUCE:
-        l_ij = tl.sum(p, 1)
 
     # -- update output accumulator --
     BM: tl.constexpr = acc.shape[0]
@@ -98,6 +95,7 @@
     else:
         acc = acc * alpha[:, None]
 
+    p = tl.math.exp2(qk)
     PM: tl.constexpr = p.shape[0]
     PN: tl.constexpr = p.shape[1]
     if FADD2_REDUCE:
@@ -105,6 +103,8 @@
         l_ij0, l_ij1 = tl.reduce((p0, p1), axis=1, combine_fn=_reduce_fadd2)
         l_i0 = l_i0 * alpha + l_ij0
         l_i1 = l_i1 * alpha + l_ij1
+    else:
+        l_ij = tl.sum(p, 1)
 
     # prepare p and v for the dot
     p = p.to(dtype)
@@ -259,7 +259,7 @@ def make_standard_config(BM, BN, s, w, subtile, vectmul, add2reduce, maxreg):
     if HAS_REG_AUTO_WS:
         extra_kwargs["minRegAutoWS"] = 24
         extra_kwargs["maxRegAutoWS"] = maxreg
-        extra_kwargs["data_partition_factor"] = 2
+        # extra_kwargs["data_partition_factor"] = 2
 
     return triton.Config(config_kwargs, **extra_kwargs)
 
@@ -556,6 +556,7 @@ def _attn_fwd_persist(
     SUBTILING: tl.constexpr,
     VECT_MUL: tl.constexpr,
     FADD2_REDUCE: tl.constexpr,
+    data_partition_factor: tl.constexpr,
 ):
     n_tile_num = tl.cdiv(N_CTX, BLOCK_M)
     prog_id = tl.program_id(0)
@@ -594,7 +595,12 @@
     )
 
     # inner loop warpspec vs. outer loop warpspec
-    for _ in tl.range(0, tiles_per_sm, warp_specialize=warp_specialize and OUTER_LOOP):
+    for _ in tl.range(
+        0,
+        tiles_per_sm,
+        warp_specialize=warp_specialize and OUTER_LOOP,
+        data_partition_factor=data_partition_factor,
+    ):
         pid = tile_idx % n_tile_num
         off_hz = tile_idx // n_tile_num
         _attn_fwd_tma_dp(
@@ -707,6 +713,7 @@ def grid_debug(META):
             warp_specialize=warp_specialize,
             OUTER_LOOP=True,
             dtype=torch_dtype_to_triton(q.dtype),
+            data_partition_factor=2,
             **extra_kern_args,
         )
     else:
diff --git a/tritonbench/kernels/blackwell_triton_fused_attention_dp.py b/tritonbench/kernels/blackwell_triton_fused_attention_dp.py
index 446440fd7..f8db8ab77 100644
--- a/tritonbench/kernels/blackwell_triton_fused_attention_dp.py
+++ b/tritonbench/kernels/blackwell_triton_fused_attention_dp.py
@@ -80,11 +80,8 @@ def _attn_fwd_subtile(
         qk = _fma_f32x2(qk, qk_scale, -m_ij[:, None])
     else:
         qk = qk * qk_scale - m_ij[:, None]
-    p = tl.math.exp2(qk)
     # -- compute correction factor
     alpha = tl.math.exp2(m_i - m_ij)
-    if not FADD2_REDUCE:
-        l_ij = tl.sum(p, 1)
 
     # -- update output accumulator --
     BM: tl.constexpr = acc.shape[0]
@@ -104,6 +101,7 @@
 
     # update m_i and l_i
     # place this at the end of the loop to reduce register pressure
+    p = tl.math.exp2(qk)
     PM: tl.constexpr = p.shape[0]
     PN: tl.constexpr = p.shape[1]
     if FADD2_REDUCE:
@@ -111,6 +109,8 @@
         l_ij0, l_ij1 = tl.reduce((p0, p1), axis=1, combine_fn=_reduce_fadd2)
         l_i0 = l_i0 * alpha + l_ij0
         l_i1 = l_i1 * alpha + l_ij1
+    else:
+        l_ij = tl.sum(p, 1)
 
     # We can potentially move these to be before updating l_ij, so the dot
     # is not blocked.

From f54077caec9933f85d6d65b6537de8645576d7e9 Mon Sep 17 00:00:00 2001
From: Nick Riasanovsky
Date: Wed, 15 Oct 2025 08:08:03 -0700
Subject: [PATCH 2/2] revert data_partition_factor change

---
 .../kernels/blackwell_triton_fused_attention.py | 10 ++--------
 1 file changed, 2 insertions(+), 8 deletions(-)

diff --git a/tritonbench/kernels/blackwell_triton_fused_attention.py b/tritonbench/kernels/blackwell_triton_fused_attention.py
index 65f74f097..90dc93b39 100644
--- a/tritonbench/kernels/blackwell_triton_fused_attention.py
+++ b/tritonbench/kernels/blackwell_triton_fused_attention.py
@@ -259,7 +259,7 @@ def make_standard_config(BM, BN, s, w, subtile, vectmul, add2reduce, maxreg):
     if HAS_REG_AUTO_WS:
         extra_kwargs["minRegAutoWS"] = 24
         extra_kwargs["maxRegAutoWS"] = maxreg
-        # extra_kwargs["data_partition_factor"] = 2
+        extra_kwargs["data_partition_factor"] = 2
 
     return triton.Config(config_kwargs, **extra_kwargs)
 
@@ -556,7 +556,6 @@ def _attn_fwd_persist(
     SUBTILING: tl.constexpr,
     VECT_MUL: tl.constexpr,
     FADD2_REDUCE: tl.constexpr,
-    data_partition_factor: tl.constexpr,
 ):
     n_tile_num = tl.cdiv(N_CTX, BLOCK_M)
     prog_id = tl.program_id(0)
@@ -595,12 +594,7 @@
     )
 
     # inner loop warpspec vs. outer loop warpspec
-    for _ in tl.range(
-        0,
-        tiles_per_sm,
-        warp_specialize=warp_specialize and OUTER_LOOP,
-        data_partition_factor=data_partition_factor,
-    ):
+    for _ in tl.range(0, tiles_per_sm, warp_specialize=warp_specialize and OUTER_LOOP):
         pid = tile_idx % n_tile_num
         off_hz = tile_idx // n_tile_num
         _attn_fwd_tma_dp(
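
The _attn_fwd_subtile hunks above only move the online-softmax bookkeeping: p = tl.math.exp2(qk) and the row-sum reduction now run after the accumulator rescale, which (per the "reduce register pressure" comment in the dp variant) shortens how long p stays live without changing the math. Below is a minimal NumPy sketch of one streaming step that illustrates why the reordering is safe; the function name, argument names, and shapes are illustrative only (not taken from the kernel), and the kernel's qk_scale folding and subtiling are omitted.

import numpy as np


def online_softmax_step(qk, v_block, m_i, l_i, acc):
    """One flash-attention style streaming softmax step, with the reordered
    bookkeeping: probabilities and their row sum are computed only after the
    old accumulator has been rescaled."""
    m_ij = np.maximum(m_i, qk.max(axis=1))   # updated running row max
    qk = qk - m_ij[:, None]                  # shift scores by the new max
    alpha = np.exp2(m_i - m_ij)              # correction factor for old state
    acc = acc * alpha[:, None]               # rescale old accumulator first
    p = np.exp2(qk)                          # probabilities, computed late
    l_ij = p.sum(axis=1)                     # row sum, also computed late
    l_i = l_i * alpha + l_ij                 # update running normalizer
    acc = acc + p @ v_block                  # accumulate unnormalized output
    return m_ij, l_i, acc

Nothing that feeds alpha or the acc rescale depends on p or l_ij, so deferring them only reorders independent operations; the final (m_ij, l_i, acc) is identical to computing p up front.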