
Commit 2767510

Optimizes flash attention backward kernel configurations
Adjusts backward-kernel parameters for each head dimension to improve shared memory usage and performance across GPU architectures. Updates shared memory requirements and CTA counts for better utilization on sm86, sm89, A100, and H100. Enables double buffering and adjusts block sizes to shrink the memory footprint while matching or improving performance on each hardware configuration.
1 parent 1082e72 commit 2767510
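
Each run_mha_bwd_hdim* dispatcher touched below branches on the device's opt-in shared memory per block. For reference, here is a minimal standalone sketch of that dispatch pattern using only the CUDA runtime API; the 144KB threshold comes from the hdim64/hdim128 branches in this diff, and the per-architecture limits in the comment are approximate figures assumed here, not values from the commit.

#include <cstdio>
#include <cuda_runtime.h>

int main() {
    int device = 0, max_smem_per_block = 0;
    cudaGetDevice(&device);
    // Opt-in per-block shared memory is roughly 99KB on sm86/sm89,
    // 163KB on A100 (sm80), and 227KB on H100 (sm90).
    cudaError_t status_ = cudaDeviceGetAttribute(
        &max_smem_per_block, cudaDevAttrMaxSharedMemoryPerBlockOptin, device);
    if (status_ != cudaSuccess) { return 1; }
    if (max_smem_per_block >= 144 * 1024) {
        printf("A100/H100 path: larger tiles, more smem per CTA\n");
    } else {
        printf("sm86/sm89 path: smaller tiles, V in registers, or adjusted buffering\n");
    }
    return 0;
}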

File tree: 1 file changed, +29 −41 lines


csrc/src/flash_bwd_launch_template.h

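These per-head-dimension launchers are selected by a head-dim dispatch that lives outside this diff. A hypothetical sketch of that selection, assuming only this header's own declarations (Flash_bwd_params, its head-dim field params.d, and the run_mha_bwd_hdim* templates):

// Hypothetical dispatcher, illustrative only; the repo's real
// dispatch is not part of this commit.
template<typename T, bool Is_causal>
void dispatch_mha_bwd(Flash_bwd_params &params, cudaStream_t stream) {
    if      (params.d <= 32)  { run_mha_bwd_hdim32 <T, Is_causal>(params, stream); }
    else if (params.d <= 64)  { run_mha_bwd_hdim64 <T, Is_causal>(params, stream); }
    else if (params.d <= 96)  { run_mha_bwd_hdim96 <T, Is_causal>(params, stream); }
    else if (params.d <= 128) { run_mha_bwd_hdim128<T, Is_causal>(params, stream); }
    else if (params.d <= 192) { run_mha_bwd_hdim192<T, Is_causal>(params, stream); }
    else                      { run_mha_bwd_hdim256<T, Is_causal>(params, stream); }
}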
@@ -138,12 +138,11 @@ void run_mha_bwd_hdim32(Flash_bwd_params &params, cudaStream_t stream) {
         C10_CUDA_CHECK(status_);
     }
     if (max_smem_per_block >= 104 * 1024) { // H100 and A100
-        // 104KB
+        // 104KB, 1 CTA on A100, 2 CTAs on H100.
         run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 128, 128, 8, 4, 4, 4, false, false, T>, Is_causal>(params, stream);
     } else { // sm86 and sm89
-        // 96KB
-        // We need to adjust no_double_buffer to save some smem, because is_v_in_regs=true will still allocate smem that may overflow
-        run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 128, 128, 8, 4, 4, 4, false, true, T>, Is_causal>(params, stream);
+        // 96KB, 2 CTAs on sm86 and sm89.
+        run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 128, 128, 8, 4, 4, 4, true, false, T>, Is_causal>(params, stream);
     }
 }

@@ -158,17 +157,17 @@ void run_mha_bwd_hdim64(Flash_bwd_params &params, cudaStream_t stream) {
     if (status_ != cudaSuccess) {
         C10_CUDA_CHECK(status_);
     }
-    // printf("max_smem_per_block = %d\n", max_smem_per_block);
-    // Changing AtomLayoutMdQ from 2 to 4 takes the same time
-    // This is slightly faster. We want to split M more so we need fewer registers to store LSE.
     if (max_smem_per_block >= 144 * 1024) { // H100 and A100
-        // 144KB
+        // In fwd, multi-CTA configurations are faster, but in bwd their speeds are very close.
+        // 56KB, 1 CTA on sm86 and sm89, 2 CTAs on A100, 4 CTAs on H100.
+        // run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 64, 8, 4, 2, 2, false, false, T>, Is_causal>(params, stream);
+        // 72KB, 1 CTA on sm86 and sm89, 2 CTAs on A100, 3 CTAs on H100.
+        // run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 128, 8, 2, 4, 4, true, false, T>, Is_causal>(params, stream);
+        // 144KB, N/A on sm86 and sm89, 1 CTA on A100, 1 CTA on H100.
         run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 128, 128, 8, 4, 4, 4, false, false, T>, Is_causal>(params, stream);
-        // This has a lot of register spilling
-        // run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 128, 128, 8, 4, 4, 4, true, false, T>>(params, stream);
     } else { // sm86 and sm89
-        // 88KB
-        run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 128, 8, 2, 4, 4, false, false, T>, Is_causal>(params, stream);
+        // 72KB, 1 CTA on sm86 and sm89.
+        run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 128, 8, 2, 4, 4, true, false, T>, Is_causal>(params, stream);
     }
     // M=128, N=64 is quite slow, I think because we need to read/write dQaccum twice as many times
 }

@@ -186,11 +185,11 @@ void run_mha_bwd_hdim96(Flash_bwd_params &params, cudaStream_t stream) {
     }
     // printf("max_smem_per_block = %d\n", max_smem_per_block);
     if (max_smem_per_block >= 116 * 1024) { // H100 and A100
-        // 116KB
+        // 116KB, 1 CTA on A100, 1 CTA on H100.
         run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 128, 8, 2, 4, 4, false, false, T>, Is_causal>(params, stream);
     } else { // sm86 and sm89
-        // 80KB
-        run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 64, 8, 2, 4, 4, false, false, T>, Is_causal>(params, stream);
+        // 92KB, 1 CTA on sm86 and sm89.
+        run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 128, 8, 2, 4, 4, true, false, T>, Is_causal>(params, stream);
     }
 }

@@ -205,20 +204,12 @@ void run_mha_bwd_hdim128(Flash_bwd_params &params, cudaStream_t stream) {
     if (status_ != cudaSuccess) {
         C10_CUDA_CHECK(status_);
     }
-    // printf("max_smem_per_block = %d\n", max_smem_per_block);
-    // run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 32, 64, 8, 2, 2, 2, false, false, T>>(params, stream);
-    // This is faster, in the case of sequence-parallel bwd (where we need fewer registers).
-    // Out of these three, the 2nd one is slightly faster (2% faster than the first). Idk why.
-    // run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 64, 8, 2, 2, 2, false, false, T>>(params, stream);
-    if (max_smem_per_block >= 224 * 1024) { // H100
-        // 224KB
-        run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 128, 128, 8, 2, 4, 2, false, false, T>, Is_causal>(params, stream);
-    } else if (max_smem_per_block >= 144 * 1024) { // A100
-        // 144KB
+    if (max_smem_per_block >= 144 * 1024) { // H100 and A100
+        // 144KB, 1 CTA on A100, 1 CTA on H100.
         run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 128, 8, 2, 4, 2, false, false, T>, Is_causal>(params, stream);
     } else { // sm86 and sm89
-        // 88KB
-        run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 64, 8, 4, 2, 2, false, true, T>, Is_causal>(params, stream);
+        // 88KB, 1 CTA on sm86 and sm89.
+        run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 64, 8, 4, 2, 2, true, false, T>, Is_causal>(params, stream);
     }
 }

@@ -233,15 +224,12 @@ void run_mha_bwd_hdim192(Flash_bwd_params &params, cudaStream_t stream) {
     if (status_ != cudaSuccess) {
         C10_CUDA_CHECK(status_);
     }
-    if (max_smem_per_block >= 208 * 1024) { // H100
-        // 208KB
-        run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 128, 8, 4, 2, 2, false, false, T>, Is_causal>(params, stream);
-    } else if (max_smem_per_block >= 152 * 1024) { // A100
-        // 152KB
+    if (max_smem_per_block >= 136 * 1024) { // H100 and A100
+        // 136KB, 1 CTA on A100, 1 CTA on H100.
         run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 64, 8, 4, 2, 2, false, false, T>, Is_causal>(params, stream);
     } else { // sm86 and sm89
-        // 88KB
-        run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 32, 64, 8, 4, 2, 2, false, true, T>, Is_causal>(params, stream);
+        // 96KB, 1 CTA on sm86 and sm89.
+        run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 64, 8, 4, 2, 2, true, true, T>, Is_causal>(params, stream);
     }
 }

@@ -256,15 +244,15 @@ void run_mha_bwd_hdim256(Flash_bwd_params &params, cudaStream_t stream) {
     if (status_ != cudaSuccess) {
         C10_CUDA_CHECK(status_);
     }
-    if (max_smem_per_block >= 200 * 1024) { // H100
-        // 200KB
+    if (max_smem_per_block >= 176 * 1024) { // H100
+        // 176KB, 1 CTA on H100.
         run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 64, 8, 4, 2, 2, false, false, T>, Is_causal>(params, stream);
-    } else if (max_smem_per_block >= 132 * 1024) { // A100
-        // 132KB
-        run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 32, 64, 8, 4, 2, 2, false, false, T>, Is_causal>(params, stream);
+    } else if (max_smem_per_block >= 144 * 1024) { // A100
+        // 144KB, 1 CTA on A100.
+        run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 64, 8, 4, 2, 2, false, true, T>, Is_causal>(params, stream);
     } else { // sm86 and sm89
-        // 82KB
-        run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 32, 32, 8, 4, 1, 2, true, false, T>, Is_causal>(params, stream);
+        // 96KB, 1 CTA on sm86 and sm89.
+        run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 32, 8, 4, 1, 2, true, true, T>, Is_causal>(params, stream);
     }
 }

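A closing note on the two trailing booleans flipped throughout the diff: the comment removed in the hdim32 hunk names them is_v_in_regs and no_double_buffer, so these changes trade V's shared-memory tile against double buffering. The toy stand-in below assumes that parameter order (matching the upstream Flash_bwd_kernel_traits as understood here) and uses a deliberately simplified smem formula, not the kernel's real accounting, just to show the direction of the tradeoff.

#include <cstdio>

// Toy stand-in for Flash_bwd_kernel_traits; the parameter order is an
// assumption and the smem formula is illustrative only.
template<int kHeadDim, int kBlockM, int kBlockN, int kNWarps,
         int AtomLayoutMSdP, int AtomLayoutNdKV, int AtomLayoutMdQ,
         bool Is_V_in_regs, bool No_double_buffer, typename Elem>
struct Toy_bwd_traits {
    // 2 bytes per fp16/bf16 element; the Q/dO tiles are double-buffered
    // unless No_double_buffer, and V's smem tile is dropped when V is
    // held in registers instead.
    static constexpr int kSmemQdO = kBlockM * kHeadDim * (No_double_buffer ? 1 : 2) * 2;
    static constexpr int kSmemKV  = kBlockN * kHeadDim * (Is_V_in_regs ? 1 : 2) * 2;
    static constexpr int kSmemBytes = kSmemQdO + kSmemKV;
};

struct Half {};  // placeholder element type

int main() {
    // hdim32, sm86/sm89 branch: the old config vs. the one this commit installs.
    using OldCfg = Toy_bwd_traits<32, 128, 128, 8, 4, 4, 4, false, true, Half>;
    using NewCfg = Toy_bwd_traits<32, 128, 128, 8, 4, 4, 4, true, false, Half>;
    printf("old ~%d bytes, new ~%d bytes (toy estimate)\n",
           OldCfg::kSmemBytes, NewCfg::kSmemBytes);
    return 0;
}

In this toy accounting the swap is roughly smem-neutral: the bytes freed by holding V in registers pay for re-enabling double buffering inside the same 96KB budget, consistent with the unchanged 96KB figure in the hdim32 comment.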