Skip to content

Commit 6bf2371

Browse files
authored
Fix modeling example
2 parents 801e816 + 2767510 commit 6bf2371

File tree

4 files changed

+102
-176
lines changed

4 files changed

+102
-176
lines changed

csrc/flash_api.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -311,7 +311,8 @@ std::tuple<at::Tensor, at::Tensor> set_params_splitkv(
311311
) {
312312

313313
// This must match kBlockN in run_mha_fwd_splitkv_dispatch.
314-
const int block_n = head_size <= 32 ? 128 : (head_size <= 128 ? 128 : 64);
314+
const int block_n = 64;
315+
// const int block_n = head_size <= 32 ? 128 : (head_size <= 128 ? 128 : 64);
315316
const int num_n_blocks = (max_seqlen_k + block_n - 1) / block_n;
316317
// Technically kBlockM = 64 only for the splitKV kernels, not the standard kernel.
317318
// In any case we don't expect seqlen_q to be larger than 64 for inference.

csrc/src/flash_bwd_launch_template.h

Lines changed: 29 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -138,12 +138,11 @@ void run_mha_bwd_hdim32(Flash_bwd_params &params, cudaStream_t stream) {
138138
C10_CUDA_CHECK(status_);
139139
}
140140
if (max_smem_per_block >= 104 * 1024) { // H100 and A100
141-
// 104KB
141+
// 104KB, 1 CTA in A100, 2 CTAs in H100.
142142
run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 128, 128, 8, 4, 4, 4, false, false, T>, Is_causal>(params, stream);
143143
} else { // sm86 and sm89
144-
// 96KB
145-
// We need to adjust no_double_buffer to save some smem, because is_v_in_regs=true will still allocate smem that may overflow
146-
run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 128, 128, 8, 4, 4, 4, false, true, T>, Is_causal>(params, stream);
144+
// 96KB, 2 CTAs in sm86 and sm89.
145+
run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 128, 128, 8, 4, 4, 4, true, false, T>, Is_causal>(params, stream);
147146
}
148147
}
149148

@@ -158,17 +157,17 @@ void run_mha_bwd_hdim64(Flash_bwd_params &params, cudaStream_t stream) {
158157
if (status_ != cudaSuccess) {
159158
C10_CUDA_CHECK(status_);
160159
}
161-
// printf("max_smem_per_block = %d\n", max_smem_per_block);
162-
// Changing AtomLayoutMdQ from 2 to 4 takes the same time
163-
// This is slightly faster. We want to split M more so we need fewer registers to store LSE.
164160
if (max_smem_per_block >= 144 * 1024) { // H100 and A100
165-
// 144KB
161+
// In the forward pass, multi-CTA configurations are faster; in the backward pass, the configurations below run at very similar speeds.
162+
// 56KB, 1 CTAs in sm86 and sm 89, 2 CTAs in A100, 4 CTAs in H100.
163+
// run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 64, 8, 4, 2, 2, false, false, T>, Is_causal>(params, stream);
164+
// 72KB, 1 CTAs in sm86 and sm 89, 2 CTAs in A100, 3 CTAs in H100.
165+
// run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 128, 8, 2, 4, 4, true, false, T>, Is_causal>(params, stream);
166+
// 144KB, N/A CTAs in sm86 and sm 89, 1 CTAs in A100, 1 CTAs in H100.
166167
run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 128, 128, 8, 4, 4, 4, false, false, T>, Is_causal>(params, stream);
167-
// This has a lot of register spilling
168-
// run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 128, 128, 8, 4, 4, 4, true, false, T>>(params, stream);
169168
} else { // sm86 and sm89
170-
// 88KB
171-
run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 128, 8, 2, 4, 4, false, false, T>, Is_causal>(params, stream);
169+
// 72KB, 1 CTAs in sm86 and sm 89.
170+
run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 128, 8, 2, 4, 4, true, false, T>, Is_causal>(params, stream);
172171
}
173172
// M=128, N=64 is quite slow, probably because dQaccum has to be read and written twice as many times.
174173
}
@@ -186,11 +185,11 @@ void run_mha_bwd_hdim96(Flash_bwd_params &params, cudaStream_t stream) {
186185
}
187186
// printf("max_smem_per_block = %d\n", max_smem_per_block);
188187
if (max_smem_per_block >= 116 * 1024) { // H100 and A100
189-
// 116KB
188+
// 116KB, 1 CTAs in A100, 1 CTAs in H100.
190189
run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 128, 8, 2, 4, 4, false, false, T>, Is_causal>(params, stream);
191190
} else { // sm86 and sm89
192-
// 80KB
193-
run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 64, 8, 2, 4, 4, false, false, T>, Is_causal>(params, stream);
191+
// 92KB, 1 CTAs in sm86 and sm 89.
192+
run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 128, 8, 2, 4, 4, true, false, T>, Is_causal>(params, stream);
194193
}
195194
}
196195

@@ -205,20 +204,12 @@ void run_mha_bwd_hdim128(Flash_bwd_params &params, cudaStream_t stream) {
205204
if (status_ != cudaSuccess) {
206205
C10_CUDA_CHECK(status_);
207206
}
208-
// printf("max_smem_per_block = %d\n", max_smem_per_block);
209-
// run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 32, 64, 8, 2, 2, 2, false, false, T>>(params, stream);
210-
// This is faster, in the case of sequence-parallel bwd (where we need fewer registers).
211-
// Out of these three, the 2nd one is slightly faster (2% faster than the first). Idk why.
212-
// run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 64, 8, 2, 2, 2, false, false, T>>(params, stream);
213-
if (max_smem_per_block >= 224 * 1024) { // H100
214-
// 224KB
215-
run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 128, 128, 8, 2, 4, 2, false, false, T>, Is_causal>(params, stream);
216-
} else if (max_smem_per_block >= 144 * 1024) { // A100
217-
// 144KB
207+
if (max_smem_per_block >= 144 * 1024) { // H100 and A100
208+
// 144KB, 1 CTAs in A100, 1 CTAs in H100.
218209
run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 128, 8, 2, 4, 2, false, false, T>, Is_causal>(params, stream);
219210
} else { // sm86 and sm89
220-
// 88KB
221-
run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 64, 8, 4, 2, 2, false, true, T>, Is_causal>(params, stream);
211+
// 88KB, 1 CTAs in sm86 and sm 89.
212+
run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 64, 8, 4, 2, 2, true, false, T>, Is_causal>(params, stream);
222213
}
223214
}
224215

@@ -233,15 +224,12 @@ void run_mha_bwd_hdim192(Flash_bwd_params &params, cudaStream_t stream) {
233224
if (status_ != cudaSuccess) {
234225
C10_CUDA_CHECK(status_);
235226
}
236-
if (max_smem_per_block >= 208 * 1024) { // H100
237-
// 208KB
238-
run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 128, 8, 4, 2, 2, false, false, T>, Is_causal>(params, stream);
239-
} else if (max_smem_per_block >= 152 * 1024) { // A100
240-
// 152KB
227+
if (max_smem_per_block >= 136 * 1024) { // H100 and A100
228+
// 136KB, 1 CTAs in A100, 1 CTAs in H100.
241229
run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 64, 8, 4, 2, 2, false, false, T>, Is_causal>(params, stream);
242230
} else { // sm86 and sm89
243-
// 88KB
244-
run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 32, 64, 8, 4, 2, 2, false, true, T>, Is_causal>(params, stream);
231+
// 96KB, 1 CTAs in sm86 and sm 89.
232+
run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 64, 8, 4, 2, 2, true, true, T>, Is_causal>(params, stream);
245233
}
246234
}
247235

@@ -256,15 +244,15 @@ void run_mha_bwd_hdim256(Flash_bwd_params &params, cudaStream_t stream) {
256244
if (status_ != cudaSuccess) {
257245
C10_CUDA_CHECK(status_);
258246
}
259-
if (max_smem_per_block >= 200 * 1024) { // H100
260-
// 200KB
247+
if (max_smem_per_block >= 176 * 1024) { // H100
248+
// 176KB, 1 CTAs in H100.
261249
run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 64, 8, 4, 2, 2, false, false, T>, Is_causal>(params, stream);
262-
} else if (max_smem_per_block >= 132 * 1024) { // A100
263-
// 132KB
264-
run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 32, 64, 8, 4, 2, 2, false, false, T>, Is_causal>(params, stream);
250+
} else if (max_smem_per_block >= 144 * 1024) { // A100
251+
// 144KB, 1 CTAs in A100.
252+
run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 64, 8, 4, 2, 2, false, true, T>, Is_causal>(params, stream);
265253
} else { // sm86 and sm89
266-
// 82KB
267-
run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 32, 32, 8, 4, 1, 2, true, false, T>, Is_causal>(params, stream);
254+
// 96KB, 1 CTAs in sm86 and sm 89.
255+
run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 32, 8, 4, 1, 2, true, true, T>, Is_causal>(params, stream);
268256
}
269257
}
270258

csrc/src/flash_fwd_launch_template.h

Lines changed: 53 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -155,7 +155,8 @@ void run_flash_splitkv_fwd(Flash_fwd_params &params, cudaStream_t stream) {
155155
template<typename T, int Headdim, bool Is_causal>
156156
void run_mha_fwd_splitkv_dispatch(Flash_fwd_params &params, cudaStream_t stream) {
157157
constexpr static int kBlockM = 64; // Fixed for all head dimensions
158-
constexpr static int kBlockN = Headdim <= 32 ? 128 : (Headdim <= 128 ? 128 : 64);
158+
constexpr static int kBlockN = 64; // Fixed for all head dimensions
159+
// constexpr static int kBlockN = Headdim <= 32 ? 128 : (Headdim <= 128 ? 128 : 64);
159160
run_flash_splitkv_fwd<Flash_fwd_kernel_traits<Headdim, kBlockM, kBlockN, 4, false, false, T>, Is_causal>(params, stream);
160161
}
161162

@@ -171,11 +172,18 @@ void run_mha_fwd_hdim32(Flash_fwd_params &params, cudaStream_t stream) {
171172
if (status_ != cudaSuccess) {
172173
C10_CUDA_CHECK(status_);
173174
}
174-
if (max_smem_per_block >= 176 * 1024) {
175-
run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 128, 4, false, false, T>, Is_causal>(params, stream);
175+
if (max_smem_per_block >= 164 * 1024) {
176+
// 28KB, 3 CTAs in sm86 and sm 89, 5 CTAs in A100, 8 CTAs in H100.
177+
run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 64, 64, 4, false, false, T>, Is_causal>(params, stream);
178+
// 48KB, 2 CTAs in sm86 and sm 89, 3 CTAs in A100, 4 CTAs in H100.
179+
// run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 4, false, false, T>, Is_causal>(params, stream);
180+
// 88KB, 1 CTAs in sm86 and sm 89, 1 CTAs in A100, 2 CTAs in H100.
181+
// run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 128, 4, false, false, T>, Is_causal>(params, stream);
176182
} else {
177-
run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 4, false, false, T>, Is_causal>(params, stream);
183+
// 24KB, 4 CTAs in sm86 and sm 89.
184+
run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 64, 64, 4, true, true, T>, Is_causal>(params, stream);
178185
}
186+
179187
}
180188

181189
template<typename T, bool Is_causal>
@@ -190,11 +198,18 @@ void run_mha_fwd_hdim64(Flash_fwd_params &params, cudaStream_t stream) {
190198
if (status_ != cudaSuccess) {
191199
C10_CUDA_CHECK(status_);
192200
}
193-
if (max_smem_per_block >= 224 * 1024) {
194-
run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 128, 4, false, false, T>, Is_causal>(params, stream);
195-
} else {
201+
if (max_smem_per_block >= 164 * 1024) { // H100 and A100
202+
// 40KB, 2 CTAs in sm86 and sm 89, 4 CTAs in A100, 5 CTAs in H100.
196203
run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 64, 64, 4, false, false, T>, Is_causal>(params, stream);
204+
// 64KB, 1 CTAs in sm86 and sm 89, 2 CTAs in A100, 3 CTAs in H100.
205+
// run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 64, 128, 4, false, false, T>, Is_causal>(params, stream);
206+
// 112KB, N/A CTAs in sm86 and sm89, 1 CTA in A100, 2 CTAs in H100.
207+
// run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 128, 4, false, false, T>, Is_causal>(params, stream);
208+
} else { // sm86 and sm89
209+
// 32KB, 3 CTAs in sm86 and sm 89.
210+
run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 64, 64, 4, true, true, T>, Is_causal>(params, stream);
197211
}
212+
198213
}
199214

200215
template<typename T, bool Is_causal>
@@ -209,9 +224,15 @@ void run_mha_fwd_hdim96(Flash_fwd_params &params, cudaStream_t stream) {
209224
if (status_ != cudaSuccess) {
210225
C10_CUDA_CHECK(status_);
211226
}
212-
if (max_smem_per_block >= 160 * 1024) {
213-
run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 4, false, false, T>, Is_causal>(params, stream);
214-
} else {
227+
if (max_smem_per_block >= 164 * 1024) { // H100 and A100
228+
// 52KB, 1 CTAs in sm86 and sm 89, 3 CTAs in A100, 4 CTAs in H100.
229+
run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 64, 64, 4, false, false, T>, Is_causal>(params, stream);
230+
// 80KB, 1 CTAs in sm86 and sm 89, 2 CTAs in A100, 2 CTAs in H100.
231+
// run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 4, false, false, T>, Is_causal>(params, stream);
232+
// 136KB, N/A CTAs in sm86 and sm 89, 1 CTAs in A100, 1 CTAs in H100.
233+
// run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 128, 4, false, false, T>, Is_causal>(params, stream);
234+
} else { // sm86 and sm89
235+
// 40KB, 2 CTAs in sm86 and sm 89.
215236
run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 64, 64, 4, true, true, T>, Is_causal>(params, stream);
216237
}
217238
}
@@ -228,19 +249,28 @@ void run_mha_fwd_hdim128(Flash_fwd_params &params, cudaStream_t stream) {
228249
if (status_ != cudaSuccess) {
229250
C10_CUDA_CHECK(status_);
230251
}
231-
if (max_smem_per_block >= 192 * 1024) {
232-
run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 4, false, false, T>, Is_causal>(params, stream);
233-
} else {
234-
// For sm86 or sm89, 64 x 64 (48 KB smem) is the fastest for causal and non-causal since we get 2 CTAs per SM.
235-
// Use block configuration (kBlockM = 64, kBlockN = 64) for better memory alignment
252+
if (max_smem_per_block >= 164 * 1024) { // H100 and A100
253+
// 64KB, 1 CTAs in sm86 and sm 89, 2 CTAs in A100, 3 CTAs in H100.
254+
run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 64, 64, 4, false, false, T>, Is_causal>(params, stream);
255+
// 96KB, 1 CTAs in sm86 and sm 89, 1 CTAs in A100, 2 CTAs in H100.
256+
// run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 4, false, false, T>, Is_causal>(params, stream);
257+
// 160KB, N/A CTAs in sm86 and sm 89, 1 CTAs in A100, 1 CTAs in H100.
258+
// run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 128, 4, false, false, T>, Is_causal>(params, stream);
259+
} else { // sm86 and sm89
260+
// 48KB, 2 CTAs in sm86 and sm 89.
236261
run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 64, 64, 4, true, true, T>, Is_causal>(params, stream);
237262
}
238263
}
239264

240265
template<typename T, bool Is_causal>
241266
void run_mha_fwd_hdim192(Flash_fwd_params &params, cudaStream_t stream) {
242267
constexpr static int Headdim = 192;
268+
// 88KB, 1 CTAs in sm86 and sm 89, 1 CTAs in A100, 2 CTAs in H100.
243269
run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 64, 64, 4, false, false, T>, Is_causal>(params, stream);
270+
// 128KB, N/A CTAs in sm86 and sm 89, 1 CTAs in A100, 1 CTAs in H100.
271+
// run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 4, false, false, T>, Is_causal>(params, stream);
272+
// 208KB, N/A CTAs in sm86 and sm 89, N/A CTAs in A100, 1 CTAs in H100.
273+
// run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 128, 4, false, false, T>, Is_causal>(params, stream);
244274
}
245275

246276
template<typename T, bool Is_causal>
@@ -255,9 +285,15 @@ void run_mha_fwd_hdim256(Flash_fwd_params &params, cudaStream_t stream) {
255285
if (status_ != cudaSuccess) {
256286
C10_CUDA_CHECK(status_);
257287
}
258-
if (max_smem_per_block >= 224 * 1024) {
288+
if (max_smem_per_block >= 112 * 1024) { // H100 and A100
289+
// 112KB, N/A CTAs in sm86 and sm 89, 1 CTAs in A100, 2 CTAs in H100.
259290
run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 64, 64, 4, false, false, T>, Is_causal>(params, stream);
260-
} else {
291+
// 192KB, N/A CTAs in sm86 and sm 89, N/A CTAs in A100, 1 CTAs in H100.
292+
// run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 4, false, false, T>, Is_causal>(params, stream);
293+
// 256KB, N/A CTAs in sm86 and sm 89, N/A CTAs in A100, N/A CTAs in H100.
294+
// run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 128, 4, false, false, T>, Is_causal>(params, stream);
295+
} else { // sm86 and sm89
296+
// 80KB, 1 CTAs in sm86 and sm 89.
261297
run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 64, 64, 4, true, true, T>, Is_causal>(params, stream);
262298
}
263299
}

0 commit comments

Comments
 (0)