@@ -147,16 +147,8 @@ def _fwd_preprocess(
147147 v_base_ptrs + gather_idx [:, None ] * stride_vn + offs_d [None , :]
148148 )
149149 if EVEN_HEADDIM :
150- k = tl .load (
151- k_ptrs ,
152- mask = valid_idx [:, None ],
153- other = 0.0 ,
154- )
155- v = tl .load (
156- v_ptrs ,
157- mask = valid_idx [:, None ],
158- other = 0.0 ,
159- )
150+ k = tl .load (k_ptrs , mask = valid_idx [:, None ], other = 0.0 )
151+ v = tl .load (v_ptrs , mask = valid_idx [:, None ], other = 0.0 )
160152 else :
161153 k = tl .load (
162154 k_ptrs ,
@@ -171,11 +163,7 @@ def _fwd_preprocess(
171163 b_ptrs = (
172164 b_base_ptrs + gather_idx * stride_bn
173165 )
174- b = tl .load (
175- b_ptrs ,
176- mask = valid_idx ,
177- other = 0.0 ,
178- )
166+ b = tl .load (b_ptrs , mask = valid_idx , other = 0.0 )
179167
180168 # Store to CuK, CuV, CuB
181169 cuk_ptrs = (
@@ -185,35 +173,21 @@ def _fwd_preprocess(
185173 cuv_base_ptrs + offs_k [:, None ] * stride_cvk + offs_d [None , :]
186174 )
187175 if EVEN_HEADDIM :
188- tl .store (
189- cuk_ptrs ,
190- k ,
191- mask = valid_idx [:, None ],
192- )
193- tl .store (
194- cuv_ptrs ,
195- v ,
196- mask = valid_idx [:, None ],
197- )
176+ tl .store (cuk_ptrs , k , mask = valid_idx [:, None ])
177+ tl .store (cuv_ptrs , v , mask = valid_idx [:, None ])
198178 else :
199179 tl .store (
200- cuk_ptrs ,
201- k ,
180+ cuk_ptrs , k ,
202181 mask = valid_idx [:, None ] & (offs_d [None , :] < headdim ),
203182 )
204183 tl .store (
205- cuv_ptrs ,
206- v ,
184+ cuv_ptrs , v ,
207185 mask = valid_idx [:, None ] & (offs_d [None , :] < headdim ),
208186 )
209187 cub_ptrs = (
210188 cub_base_ptrs + offs_k * stride_cbk
211189 )
212- tl .store (
213- cub_ptrs ,
214- b ,
215- mask = valid_idx ,
216- )
190+ tl .store (cub_ptrs , b , mask = valid_idx )
217191
218192 # Store mask to CuM
219193 for start_m in range (0 , seqlen_q , BLOCK_M ):
@@ -234,11 +208,7 @@ def _fwd_preprocess(
234208
235209 cum = tl .where (row_mask & col_mask [None , :], mask , False )
236210
237- tl .store (
238- cum_ptrs ,
239- cum ,
240- mask = row_mask & col_mask [None , :],
241- )
211+ tl .store (cum_ptrs , cum , mask = row_mask & col_mask [None , :])
242212
243213
244214@triton .autotune (
@@ -371,7 +341,9 @@ def _fwd_kernel(
371341 q = tl .load (q_ptrs , mask = offs_m [:, None ] < seqlen_q , other = 0.0 )
372342 else :
373343 q = tl .load (
374- q_ptrs , mask = (offs_m [:, None ] < seqlen_q ) & (offs_d [None , :] < headdim ), other = 0.0
344+ q_ptrs ,
345+ mask = (offs_m [:, None ] < seqlen_q ) & (offs_d [None , :] < headdim ),
346+ other = 0.0
375347 )
376348
377349 # Scale q
@@ -386,9 +358,7 @@ def _fwd_kernel(
386358 )
387359 # Load mask
388360 if EVEN_M & EVEN_N :
389- m = tl .load (
390- cum_ptrs ,
391- )
361+ m = tl .load (cum_ptrs )
392362 else :
393363 m = tl .load (
394364 cum_ptrs ,
@@ -408,15 +378,9 @@ def _fwd_kernel(
408378 )
409379 if EVEN_N :
410380 if EVEN_HEADDIM :
411- k = tl .load (
412- cuk_ptrs ,
413- )
381+ k = tl .load (cuk_ptrs )
414382 else :
415- k = tl .load (
416- cuk_ptrs ,
417- mask = offs_d [None , :] < headdim ,
418- other = 0.0
419- )
383+ k = tl .load (cuk_ptrs , mask = offs_d [None , :] < headdim , other = 0.0 )
420384 else :
421385 if EVEN_HEADDIM :
422386 k = tl .load (
@@ -436,9 +400,7 @@ def _fwd_kernel(
436400 cub_base_ptrs + (start_n + offs_n ) * stride_cbk
437401 )
438402 if EVEN_M & EVEN_N :
439- b = tl .load (
440- cub_ptrs
441- )
403+ b = tl .load (cub_ptrs )
442404 else :
443405 b = tl .load (
444406 cub_ptrs ,
@@ -475,15 +437,9 @@ def _fwd_kernel(
475437 )
476438 if EVEN_N :
477439 if EVEN_HEADDIM :
478- v = tl .load (
479- cuv_ptrs ,
480- )
440+ v = tl .load (cuv_ptrs )
481441 else :
482- v = tl .load (
483- cuv_ptrs ,
484- mask = offs_d [None , :] < headdim ,
485- other = 0.0
486- )
442+ v = tl .load (cuv_ptrs , mask = offs_d [None , :] < headdim , other = 0.0 )
487443 else :
488444 if EVEN_HEADDIM :
489445 v = tl .load (
@@ -532,7 +488,8 @@ def _fwd_kernel(
532488 tl .store (out_ptrs , acc_o , mask = offs_m [:, None ] < seqlen_q )
533489 else :
534490 tl .store (
535- out_ptrs , acc_o , mask = (offs_m [:, None ] < seqlen_q ) & (offs_d [None , :] < headdim )
491+ out_ptrs , acc_o ,
492+ mask = (offs_m [:, None ] < seqlen_q ) & (offs_d [None , :] < headdim )
536493 )
537494
538495
@@ -654,10 +611,14 @@ def _bwd_kernel_one_col_block(
654611 v = tl .load (cuv_ptrs , mask = offs_n [:, None ] < window_size , other = 0.0 )
655612 else :
656613 k = tl .load (
657- cuk_ptrs , mask = (offs_n [:, None ] < window_size ) & (offs_d [None , :] < headdim ), other = 0.0
614+ cuk_ptrs ,
615+ mask = (offs_n [:, None ] < window_size ) & (offs_d [None , :] < headdim ),
616+ other = 0.0
658617 )
659618 v = tl .load (
660- cuv_ptrs , mask = (offs_n [:, None ] < window_size ) & (offs_d [None , :] < headdim ), other = 0.0
619+ cuv_ptrs ,
620+ mask = (offs_n [:, None ] < window_size ) & (offs_d [None , :] < headdim ),
621+ other = 0.0
661622 )
662623 if EVEN_N :
663624 b = tl .load (cub_ptrs )
0 commit comments