@@ -138,12 +138,11 @@ void run_mha_bwd_hdim32(Flash_bwd_params &params, cudaStream_t stream) {
138138 C10_CUDA_CHECK (status_);
139139 }
140140 if (max_smem_per_block >= 104 * 1024 ) { // H100 and A100
141- // 104KB
141+ // 104KB, 1 CTA in A100, 2 CTAs in H100.
142142 run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 128 , 128 , 8 , 4 , 4 , 4 , false , false , T>, Is_causal>(params, stream);
143143 } else { // sm86 and sm89
144- // 96KB
145- // We need to adjust no_double_buffer to save some smem, because is_v_in_regs=true will still allocate smem that may overflow
146- run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 128 , 128 , 8 , 4 , 4 , 4 , false , true , T>, Is_causal>(params, stream);
144+ // 96KB, 2 CTAs in sm86 and sm89.
145+ run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 128 , 128 , 8 , 4 , 4 , 4 , true , false , T>, Is_causal>(params, stream);
147146 }
148147}
149148
@@ -158,17 +157,17 @@ void run_mha_bwd_hdim64(Flash_bwd_params &params, cudaStream_t stream) {
158157 if (status_ != cudaSuccess) {
159158 C10_CUDA_CHECK (status_);
160159 }
161- // printf("max_smem_per_block = %d\n", max_smem_per_block);
162- // Changing AtomLayoutMdQ from 2 to 4 takes the same time
163- // This is slightly faster. We want to split M more so we need fewer registers to store LSE.
164160 if (max_smem_per_block >= 144 * 1024 ) { // H100 and A100
165- // 144KB
161+ // In fwd, multi-CTA configurations are faster, but in bwd, their speeds are very close.
162+ // 56KB, 1 CTA in sm86 and sm89, 2 CTAs in A100, 4 CTAs in H100.
163+ // run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 64, 8, 4, 2, 2, false, false, T>, Is_causal>(params, stream);
164+ // 72KB, 1 CTA in sm86 and sm89, 2 CTAs in A100, 3 CTAs in H100.
165+ // run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 128, 8, 2, 4, 4, true, false, T>, Is_causal>(params, stream);
166+ // 144KB, N/A CTAs in sm86 and sm89, 1 CTA in A100, 1 CTA in H100.
166167 run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 128 , 128 , 8 , 4 , 4 , 4 , false , false , T>, Is_causal>(params, stream);
167- // This has a lot of register spilling
168- // run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 128, 128, 8, 4, 4, 4, true, false, T>>(params, stream);
169168 } else { // sm86 and sm89
170- // 88KB
171- run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64 , 128 , 8 , 2 , 4 , 4 , false , false , T>, Is_causal>(params, stream);
169+ // 72KB, 1 CTA in sm86 and sm89.
170+ run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64 , 128 , 8 , 2 , 4 , 4 , true , false , T>, Is_causal>(params, stream);
172171 }
173172 // M=128, N=64 is quite slow, I think because we need to read/write dQaccum twice as many times
174173}
@@ -186,11 +185,11 @@ void run_mha_bwd_hdim96(Flash_bwd_params &params, cudaStream_t stream) {
186185 }
187186 // printf("max_smem_per_block = %d\n", max_smem_per_block);
188187 if (max_smem_per_block >= 116 * 1024 ) { // H100 and A100
189- // 116KB
188+ // 116KB, 1 CTA in A100, 1 CTA in H100.
190189 run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64 , 128 , 8 , 2 , 4 , 4 , false , false , T>, Is_causal>(params, stream);
191190 } else { // sm86 and sm89
192- // 80KB
193- run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64 , 64 , 8 , 2 , 4 , 4 , false , false , T>, Is_causal>(params, stream);
191+ // 92KB, 1 CTA in sm86 and sm89.
192+ run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64 , 128 , 8 , 2 , 4 , 4 , true , false , T>, Is_causal>(params, stream);
194193 }
195194}
196195
@@ -205,20 +204,12 @@ void run_mha_bwd_hdim128(Flash_bwd_params &params, cudaStream_t stream) {
205204 if (status_ != cudaSuccess) {
206205 C10_CUDA_CHECK (status_);
207206 }
208- // printf("max_smem_per_block = %d\n", max_smem_per_block);
209- // run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 32, 64, 8, 2, 2, 2, false, false, T>>(params, stream);
210- // This is faster, in the case of sequence-parallel bwd (where we need fewer registers).
211- // Out of these three, the 2nd one is slightly faster (2% faster than the first). Idk why.
212- // run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 64, 8, 2, 2, 2, false, false, T>>(params, stream);
213- if (max_smem_per_block >= 224 * 1024 ) { // H100
214- // 224KB
215- run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 128 , 128 , 8 , 2 , 4 , 2 , false , false , T>, Is_causal>(params, stream);
216- } else if (max_smem_per_block >= 144 * 1024 ) { // A100
217- // 144KB
207+ if (max_smem_per_block >= 144 * 1024 ) { // H100 and A100
208+ // 144KB, 1 CTA in A100, 1 CTA in H100.
218209 run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64 , 128 , 8 , 2 , 4 , 2 , false , false , T>, Is_causal>(params, stream);
219210 } else { // sm86 and sm89
220- // 88KB
221- run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64 , 64 , 8 , 4 , 2 , 2 , false , true , T>, Is_causal>(params, stream);
211+ // 88KB, 1 CTA in sm86 and sm89.
212+ run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64 , 64 , 8 , 4 , 2 , 2 , true , false , T>, Is_causal>(params, stream);
222213 }
223214}
224215
@@ -233,15 +224,12 @@ void run_mha_bwd_hdim192(Flash_bwd_params &params, cudaStream_t stream) {
233224 if (status_ != cudaSuccess) {
234225 C10_CUDA_CHECK (status_);
235226 }
236- if (max_smem_per_block >= 208 * 1024 ) { // H100
237- // 208KB
238- run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64 , 128 , 8 , 4 , 2 , 2 , false , false , T>, Is_causal>(params, stream);
239- } else if (max_smem_per_block >= 152 * 1024 ) { // A100
240- // 152KB
227+ if (max_smem_per_block >= 136 * 1024 ) { // H100 and A100
228+ // 136KB, 1 CTA in A100, 1 CTA in H100.
241229 run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64 , 64 , 8 , 4 , 2 , 2 , false , false , T>, Is_causal>(params, stream);
242230 } else { // sm86 and sm89
243- // 88KB
244- run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 32 , 64 , 8 , 4 , 2 , 2 , false , true , T>, Is_causal>(params, stream);
231+ // 96KB, 1 CTA in sm86 and sm89.
232+ run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64 , 64 , 8 , 4 , 2 , 2 , true , true , T>, Is_causal>(params, stream);
245233 }
246234}
247235
@@ -256,15 +244,15 @@ void run_mha_bwd_hdim256(Flash_bwd_params &params, cudaStream_t stream) {
256244 if (status_ != cudaSuccess) {
257245 C10_CUDA_CHECK (status_);
258246 }
259- if (max_smem_per_block >= 200 * 1024 ) { // H100
260- // 200KB
247+ if (max_smem_per_block >= 176 * 1024 ) { // H100
248+ // 176KB, 1 CTA in H100.
261249 run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64 , 64 , 8 , 4 , 2 , 2 , false , false , T>, Is_causal>(params, stream);
262- } else if (max_smem_per_block >= 132 * 1024 ) { // A100
263- // 132KB
264- run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 32 , 64 , 8 , 4 , 2 , 2 , false , false , T>, Is_causal>(params, stream);
250+ } else if (max_smem_per_block >= 144 * 1024 ) { // A100
251+ // 144KB, 1 CTA in A100.
252+ run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64 , 64 , 8 , 4 , 2 , 2 , false , true , T>, Is_causal>(params, stream);
265253 } else { // sm86 and sm89
266- // 82KB
267- run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 32 , 32 , 8 , 4 , 1 , 2 , true , false , T>, Is_causal>(params, stream);
254+ // 96KB, 1 CTA in sm86 and sm89.
255+ run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64 , 32 , 8 , 4 , 1 , 2 , true , true , T>, Is_causal>(params, stream);
268256 }
269257}
270258
0 commit comments