Skip to content

Commit 5e8bc61

Browse files
committed
Compact tl.load/tl.store calls and tidy formatting in triton kernels
- Simplify multi-line `tl.load` and `tl.store` calls into single-line forms where safe
- Normalize argument formatting and masks for better readability in `flash_dmattn_triton_special.py`
- No behavioral changes; purely cosmetic/formatting cleanup
1 parent 926bb35 commit 5e8bc61

File tree

1 file changed

+26
-65
lines changed

1 file changed

+26
-65
lines changed

flash_dmattn/flash_dmattn_triton_special.py

Lines changed: 26 additions & 65 deletions
Original file line numberDiff line numberDiff line change
@@ -147,16 +147,8 @@ def _fwd_preprocess(
147147
v_base_ptrs + gather_idx[:, None] * stride_vn + offs_d[None, :]
148148
)
149149
if EVEN_HEADDIM:
150-
k = tl.load(
151-
k_ptrs,
152-
mask=valid_idx[:, None],
153-
other=0.0,
154-
)
155-
v = tl.load(
156-
v_ptrs,
157-
mask=valid_idx[:, None],
158-
other=0.0,
159-
)
150+
k = tl.load(k_ptrs, mask=valid_idx[:, None], other=0.0)
151+
v = tl.load(v_ptrs, mask=valid_idx[:, None], other=0.0)
160152
else:
161153
k = tl.load(
162154
k_ptrs,
@@ -171,11 +163,7 @@ def _fwd_preprocess(
171163
b_ptrs = (
172164
b_base_ptrs + gather_idx * stride_bn
173165
)
174-
b = tl.load(
175-
b_ptrs,
176-
mask=valid_idx,
177-
other=0.0,
178-
)
166+
b = tl.load(b_ptrs, mask=valid_idx, other=0.0)
179167

180168
# Store to CuK, CuV, CuB
181169
cuk_ptrs = (
@@ -185,35 +173,21 @@ def _fwd_preprocess(
185173
cuv_base_ptrs + offs_k[:, None] * stride_cvk + offs_d[None, :]
186174
)
187175
if EVEN_HEADDIM:
188-
tl.store(
189-
cuk_ptrs,
190-
k,
191-
mask=valid_idx[:, None],
192-
)
193-
tl.store(
194-
cuv_ptrs,
195-
v,
196-
mask=valid_idx[:, None],
197-
)
176+
tl.store(cuk_ptrs, k, mask=valid_idx[:, None])
177+
tl.store(cuv_ptrs, v, mask=valid_idx[:, None])
198178
else:
199179
tl.store(
200-
cuk_ptrs,
201-
k,
180+
cuk_ptrs, k,
202181
mask=valid_idx[:, None] & (offs_d[None, :] < headdim),
203182
)
204183
tl.store(
205-
cuv_ptrs,
206-
v,
184+
cuv_ptrs, v,
207185
mask=valid_idx[:, None] & (offs_d[None, :] < headdim),
208186
)
209187
cub_ptrs = (
210188
cub_base_ptrs + offs_k * stride_cbk
211189
)
212-
tl.store(
213-
cub_ptrs,
214-
b,
215-
mask=valid_idx,
216-
)
190+
tl.store(cub_ptrs, b, mask=valid_idx)
217191

218192
# Store mask to CuM
219193
for start_m in range(0, seqlen_q, BLOCK_M):
@@ -234,11 +208,7 @@ def _fwd_preprocess(
234208

235209
cum = tl.where(row_mask & col_mask[None, :], mask, False)
236210

237-
tl.store(
238-
cum_ptrs,
239-
cum,
240-
mask=row_mask & col_mask[None, :],
241-
)
211+
tl.store(cum_ptrs, cum, mask=row_mask & col_mask[None, :])
242212

243213

244214
@triton.autotune(
@@ -371,7 +341,9 @@ def _fwd_kernel(
371341
q = tl.load(q_ptrs, mask=offs_m[:, None] < seqlen_q, other=0.0)
372342
else:
373343
q = tl.load(
374-
q_ptrs, mask=(offs_m[:, None] < seqlen_q) & (offs_d[None, :] < headdim), other=0.0
344+
q_ptrs,
345+
mask=(offs_m[:, None] < seqlen_q) & (offs_d[None, :] < headdim),
346+
other=0.0
375347
)
376348

377349
# Scale q
@@ -386,9 +358,7 @@ def _fwd_kernel(
386358
)
387359
# Load mask
388360
if EVEN_M & EVEN_N:
389-
m = tl.load(
390-
cum_ptrs,
391-
)
361+
m = tl.load(cum_ptrs)
392362
else:
393363
m = tl.load(
394364
cum_ptrs,
@@ -408,15 +378,9 @@ def _fwd_kernel(
408378
)
409379
if EVEN_N:
410380
if EVEN_HEADDIM:
411-
k = tl.load(
412-
cuk_ptrs,
413-
)
381+
k = tl.load(cuk_ptrs)
414382
else:
415-
k = tl.load(
416-
cuk_ptrs,
417-
mask=offs_d[None, :] < headdim,
418-
other=0.0
419-
)
383+
k = tl.load(cuk_ptrs, mask=offs_d[None, :] < headdim, other=0.0)
420384
else:
421385
if EVEN_HEADDIM:
422386
k = tl.load(
@@ -436,9 +400,7 @@ def _fwd_kernel(
436400
cub_base_ptrs + (start_n + offs_n) * stride_cbk
437401
)
438402
if EVEN_M & EVEN_N:
439-
b = tl.load(
440-
cub_ptrs
441-
)
403+
b = tl.load(cub_ptrs)
442404
else:
443405
b = tl.load(
444406
cub_ptrs,
@@ -475,15 +437,9 @@ def _fwd_kernel(
475437
)
476438
if EVEN_N:
477439
if EVEN_HEADDIM:
478-
v = tl.load(
479-
cuv_ptrs,
480-
)
440+
v = tl.load(cuv_ptrs)
481441
else:
482-
v = tl.load(
483-
cuv_ptrs,
484-
mask=offs_d[None, :] < headdim,
485-
other=0.0
486-
)
442+
v = tl.load(cuv_ptrs, mask=offs_d[None, :] < headdim, other=0.0)
487443
else:
488444
if EVEN_HEADDIM:
489445
v = tl.load(
@@ -532,7 +488,8 @@ def _fwd_kernel(
532488
tl.store(out_ptrs, acc_o, mask=offs_m[:, None] < seqlen_q)
533489
else:
534490
tl.store(
535-
out_ptrs, acc_o, mask=(offs_m[:, None] < seqlen_q) & (offs_d[None, :] < headdim)
491+
out_ptrs, acc_o,
492+
mask=(offs_m[:, None] < seqlen_q) & (offs_d[None, :] < headdim)
536493
)
537494

538495

@@ -654,10 +611,14 @@ def _bwd_kernel_one_col_block(
654611
v = tl.load(cuv_ptrs, mask=offs_n[:, None] < window_size, other=0.0)
655612
else:
656613
k = tl.load(
657-
cuk_ptrs, mask=(offs_n[:, None] < window_size) & (offs_d[None, :] < headdim), other=0.0
614+
cuk_ptrs,
615+
mask=(offs_n[:, None] < window_size) & (offs_d[None, :] < headdim),
616+
other=0.0
658617
)
659618
v = tl.load(
660-
cuv_ptrs, mask=(offs_n[:, None] < window_size) & (offs_d[None, :] < headdim), other=0.0
619+
cuv_ptrs,
620+
mask=(offs_n[:, None] < window_size) & (offs_d[None, :] < headdim),
621+
other=0.0
661622
)
662623
if EVEN_N:
663624
b = tl.load(cub_ptrs)

0 commit comments

Comments
 (0)